This commit is contained in:
Charlie 2024-05-09 16:59:25 +08:00
parent 65b9f3795d
commit 90f807defc
12 changed files with 1283 additions and 0 deletions

49
clauses/merge.go Normal file
View File

@ -0,0 +1,49 @@
package clauses
import (
"gorm.io/gorm/clause"
)
type Merge struct {
Table clause.Table
Using []clause.Interface
On []clause.Expression
}
func (merge Merge) Name() string {
return "MERGE"
}
func MergeDefaultExcludeName() string {
return "exclude"
}
// Build build from clause
func (merge Merge) Build(builder clause.Builder) {
clause.Insert{}.Build(builder)
builder.WriteString(" USING (")
for idx, iface := range merge.Using {
if idx > 0 {
builder.WriteByte(' ')
}
builder.WriteString(iface.Name())
builder.WriteByte(' ')
iface.Build(builder)
}
builder.WriteString(") ")
builder.WriteString(MergeDefaultExcludeName())
builder.WriteString(" ON (")
for idx, on := range merge.On {
if idx > 0 {
builder.WriteString(", ")
}
on.Build(builder)
}
builder.WriteString(")")
}
// MergeClause merge values clauses
func (merge Merge) MergeClause(clause *clause.Clause) {
clause.Name = merge.Name()
clause.Expression = merge
}

10
clauses/returning_into.go Normal file
View File

@ -0,0 +1,10 @@
package clauses
import (
"gorm.io/gorm/clause"
)
type ReturningInto struct {
Variables []clause.Column
Into []*clause.Values
}

39
clauses/when_matched.go Normal file
View File

@ -0,0 +1,39 @@
package clauses
import (
"gorm.io/gorm/clause"
)
type WhenMatched struct {
clause.Set
Where, Delete clause.Where
}
func (w WhenMatched) Name() string {
return "WHEN MATCHED"
}
func (w WhenMatched) Build(builder clause.Builder) {
if len(w.Set) > 0 {
builder.WriteString(" THEN")
builder.WriteString(" UPDATE ")
builder.WriteString(w.Name())
builder.WriteByte(' ')
w.Build(builder)
buildWhere := func(where clause.Where) {
builder.WriteString(where.Name())
builder.WriteByte(' ')
where.Build(builder)
}
if len(w.Where.Exprs) > 0 {
buildWhere(w.Where)
}
if len(w.Delete.Exprs) > 0 {
builder.WriteString(" DELETE ")
buildWhere(w.Delete)
}
}
}

View File

@ -0,0 +1,32 @@
package clauses
import (
"gorm.io/gorm/clause"
)
type WhenNotMatched struct {
clause.Values
Where clause.Where
}
func (w WhenNotMatched) Name() string {
return "WHEN NOT MATCHED"
}
func (w WhenNotMatched) Build(builder clause.Builder) {
if len(w.Columns) > 0 {
if len(w.Values.Values) != 1 {
panic("cannot insert more than one rows due to Oracle SQL language restriction")
}
builder.WriteString(" THEN")
builder.WriteString(" INSERT ")
w.Build(builder)
if len(w.Where.Exprs) > 0 {
builder.WriteString(w.Where.Name())
builder.WriteByte(' ')
w.Where.Build(builder)
}
}
}

255
create.go Normal file
View File

@ -0,0 +1,255 @@
package oracle
import (
"database/sql"
"reflect"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
)
func Create(db *gorm.DB) {
if db.Error != nil {
return
}
stmt := db.Statement
if stmt == nil {
return
}
stmtSchema := stmt.Schema
if stmtSchema == nil {
return
}
if !stmt.Unscoped {
for _, c := range stmtSchema.CreateClauses {
stmt.AddClause(c)
}
}
if stmt.SQL.Len() == 0 {
var (
createValues = callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
)
if hasConflict {
if len(stmtSchema.PrimaryFields) > 0 {
columnsMap := map[string]bool{}
for _, column := range createValues.Columns {
columnsMap[column.Name] = true
}
for _, field := range stmtSchema.PrimaryFields {
if _, ok := columnsMap[field.DBName]; !ok {
hasConflict = false
}
}
} else {
hasConflict = false
}
}
hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0
if hasConflict {
MergeCreate(db, onConflict, createValues)
} else {
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})
if hasDefaultValues {
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue))
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
columns[idx] = clause.Column{Name: field.DBName}
}
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns})
}
stmt.Build("INSERT", "VALUES", "RETURNING")
if hasDefaultValues {
_, _ = stmt.WriteString(" INTO ")
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
if idx > 0 {
_ = stmt.WriteByte(',')
}
stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
}
_, _ = stmt.WriteString(" /*-sql.Out{}-*/")
}
}
if !db.DryRun && db.Error == nil {
if hasConflict {
for i, val := range stmt.Vars {
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
stmt.Vars[i] = convertValue(val)
}
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
// TODO: get merged returning
}
} else {
for idx, values := range createValues.Values {
for i, val := range values {
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
stmt.Vars[i] = convertValue(val)
}
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if hasDefaultValues {
getDefaultValues(db, idx)
}
}
}
}
}
}
}
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
var dummyTable string
switch d := ptrDereference(db.Dialector).(type) {
case Dialector:
dummyTable = d.DummyTableName()
default:
dummyTable = "DUAL"
}
_, _ = db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
_, _ = db.Statement.WriteString(" USING (")
for idx, value := range values.Values {
if idx > 0 {
_, _ = db.Statement.WriteString(" UNION ALL ")
}
_, _ = db.Statement.WriteString("SELECT ")
for i, v := range value {
if i > 0 {
_ = db.Statement.WriteByte(',')
}
column := values.Columns[i]
db.Statement.AddVar(db.Statement, v)
_, _ = db.Statement.WriteString(" AS ")
db.Statement.WriteQuoted(column.Name)
}
_, _ = db.Statement.WriteString(" FROM ")
_, _ = db.Statement.WriteString(dummyTable)
}
_, _ = db.Statement.WriteString(`) `)
db.Statement.WriteQuoted("excluded")
_, _ = db.Statement.WriteString(" ON (")
var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "excluded", Name: field.DBName},
})
}
where.Build(db.Statement)
_ = db.Statement.WriteByte(')')
if len(onConflict.DoUpdates) > 0 {
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}
_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
written := false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(column.Name)
}
}
_, _ = db.Statement.WriteString(") VALUES (")
written = false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(clause.Column{
Table: "excluded",
Name: column.Name,
})
}
}
_, _ = db.Statement.WriteString(")")
}
func convertValue(val interface{}) interface{} {
val = ptrDereference(val)
switch v := val.(type) {
case bool:
if v {
val = 1
} else {
val = 0
}
// case string:
// if len(v) > 2000 {
// val = go_ora.Clob{String: v, Valid: true}
// }
default:
val = convertCustomType(val)
}
return val
}
func getDefaultValues(db *gorm.DB, idx int) {
insertTo := db.Statement.ReflectValue
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
insertTo = insertTo.Index(idx)
default:
}
if insertTo.Kind() == reflect.Pointer {
insertTo = insertTo.Elem()
}
for _, val := range db.Statement.Vars {
switch v := val.(type) {
case sql.Out:
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
for i := insertTo.Len() - 1; i >= 0; i-- {
rv := insertTo.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest))
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
}
default:
}
default:
}
}
}

17
go.mod
View File

@ -1,3 +1,20 @@
module git.charlienet.top/go/oracle
go 1.22
require (
github.com/emirpasic/gods v1.18.1
github.com/godror/godror v0.42.2
github.com/sijms/go-ora/v2 v2.8.16
github.com/thoas/go-funk v0.9.3
gorm.io/gorm v1.25.10
)
require (
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/godror/knownpb v0.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 // indirect
google.golang.org/protobuf v1.33.0 // indirect
)

46
go.sum Normal file
View File

@ -0,0 +1,46 @@
github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4=
github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/godror/godror v0.42.2 h1:TmOV0fr4jJxwDD6vWtSieQFOZ75LnSbkJaudfuJOsVQ=
github.com/godror/godror v0.42.2/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94=
github.com/godror/knownpb v0.1.1 h1:A4J7jdx7jWBhJm18NntafzSC//iZDHkDi1+juwQ5pTI=
github.com/godror/knownpb v0.1.1/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sijms/go-ora/v2 v2.8.16 h1:XiBAqHjoXUnJnGnRoHD/6NJdiYQ2pr0gPciJM8BE+70=
github.com/sijms/go-ora/v2 v2.8.16/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 h1:6R2FC06FonbXQ8pK11/PDFY6N6LWlf9KlzibaCapmqc=
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4=
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

325
migrator.go Normal file
View File

@ -0,0 +1,325 @@
package oracle
import (
"database/sql"
"fmt"
"strings"
"gorm.io/gorm/schema"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
)
type Migrator struct {
migrator.Migrator
}
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw(
fmt.Sprintf(`SELECT ORA_DATABASE_NAME as "Current Database" FROM %s`, m.Dialector.(Dialector).DummyTableName()),
).Row().Scan(&name)
return
}
func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range values {
m.TryQuotifyReservedWords(value)
m.TryRemoveOnUpdate(value)
}
return m.Migrator.CreateTable(values...)
}
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- {
value := values[i]
tx := m.DB.Session(&gorm.Session{})
if m.HasTable(value) {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
}
return nil
}
func (m Migrator) HasTable(value interface{}) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
ownertable := strings.Split(stmt.Schema.Table, ".")
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? and TABLE_NAME = ?", ownertable[0], ownertable[1]).Row().Scan(&count)
} else {
return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", stmt.Table).Row().Scan(&count)
}
})
return count > 0
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Schema.Table).Where("ROWNUM = 1").Rows()
if err != nil {
return err
}
defer func() {
err = rows.Close()
}()
var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, migrator.ColumnType{SQLColumnType: c})
}
return
})
return columnTypes, execErr
}
func (m Migrator) RenameTable(oldName, newName interface{}) (err error) {
resolveTable := func(name interface{}) (result string, err error) {
if v, ok := name.(string); ok {
result = v
} else {
stmt := &gorm.Statement{DB: m.DB}
if err = stmt.Parse(name); err == nil {
result = stmt.Table
}
}
return
}
var oldTable, newTable string
if oldTable, err = resolveTable(oldName); err != nil {
return
}
if newTable, err = resolveTable(newName); err != nil {
return
}
if !m.HasTable(oldTable) {
return
}
return m.DB.Exec("RENAME TABLE ? TO ?",
clause.Table{Name: oldTable},
clause.Table{Name: newTable},
).Error
}
func (m Migrator) AddColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec(
"ALTER TABLE ? ADD ? ?",
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (m Migrator) DropColumn(value interface{}, name string) error {
if !m.HasColumn(value, name) {
return nil
}
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
return m.DB.Exec(
"ALTER TABLE ? DROP ?",
clause.Table{Name: stmt.Schema.Table},
clause.Column{Name: name},
).Error
})
}
func (m Migrator) AlterColumn(value interface{}, field string) error {
if !m.HasColumn(value, field) {
return nil
}
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec(
"ALTER TABLE ? MODIFY ? ?",
clause.Table{Name: stmt.Schema.Table},
clause.Column{Name: field.DBName},
m.AlterDataTypeOf(stmt, field),
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
ownertable := strings.Split(stmt.Schema.Table, ".")
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownertable[0], ownertable[1], field).Row().Scan(&count)
} else {
return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field).Row().Scan(&count)
}
}) == nil && count > 0
}
func (m Migrator) AlterDataTypeOf(stmt *gorm.Statement, field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field)
var nullable = ""
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
ownertable := strings.Split(stmt.Schema.Table, ".")
m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownertable[0], ownertable[1], field.DBName).Row().Scan(&nullable)
} else {
m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field.DBName).Row().Scan(&nullable)
}
if field.NotNull && nullable == "Y" {
expr.SQL += " NOT NULL"
}
if field.Unique {
expr.SQL += " UNIQUE"
}
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
} else if field.DefaultValue != "(-)" {
expr.SQL += " DEFAULT " + field.DefaultValue
}
}
return
}
func (m Migrator) CreateConstraint(value interface{}, name string) error {
m.TryRemoveOnUpdate(value)
return m.Migrator.CreateConstraint(value, name)
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if chk.Name == name {
return m.DB.Exec(
"ALTER TABLE ? DROP CHECK ?",
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name},
).Error
}
}
return m.DB.Exec(
"ALTER TABLE ? DROP CONSTRAINT ?",
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name},
).Error
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw(
"SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name,
).Row().Scan(&count)
}) == nil && count > 0
}
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: stmt.Schema.Table}).Error
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Raw(
"SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ? AND INDEX_NAME = ?",
m.Migrator.DB.NamingStrategy.TableName(stmt.Table),
m.Migrator.DB.NamingStrategy.IndexName(stmt.Table, name),
).Row().Scan(&count)
})
return count > 0
}
// https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
// ALTER INDEX oldindex RENAME TO newindex;
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
"ALTER INDEX ? RENAME TO ?",
clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
})
}
func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error {
for _, value := range values {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
for _, rel := range stmt.Schema.Relationships.Relations {
constraint := rel.ParseConstraint()
if constraint != nil {
rel.Field.TagSettings["CONSTRAINT"] = strings.ReplaceAll(rel.Field.TagSettings["CONSTRAINT"], fmt.Sprintf("ON UPDATE %s", constraint.OnUpdate), "")
}
}
return nil
}); err != nil {
return err
}
}
return nil
}
func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error {
for _, value := range values {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
for idx, v := range stmt.Schema.DBNames {
if IsReservedWord(v) {
stmt.Schema.DBNames[idx] = fmt.Sprintf(`"%s"`, v)
}
}
for _, v := range stmt.Schema.Fields {
if IsReservedWord(v.DBName) {
v.DBName = fmt.Sprintf(`"%s"`, v.DBName)
}
}
return nil
}); err != nil {
return err
}
}
return nil
}

59
namer.go Normal file
View File

@ -0,0 +1,59 @@
package oracle
import (
"fmt"
"strings"
"gorm.io/gorm/schema"
)
var _ schema.Namer = Namer{}
type Namer struct {
NamingStrategy schema.Namer
DBName string
}
func (n Namer) ConvertNameToFormat(x string) string {
return strings.ToUpper(x)
}
func (n Namer) TableName(table string) (name string) {
println("TableName", table)
tableName := n.ConvertNameToFormat(n.NamingStrategy.TableName(table))
if len(n.DBName) > 0 {
return fmt.Sprintf("%s.%s", n.DBName, tableName)
}
return tableName
}
func (n Namer) ColumnName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
}
func (n Namer) JoinTableName(table string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
}
func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
}
func (n Namer) CheckerName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
}
func (n Namer) IndexName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
}
func (n Namer) SchemaName(table string) string {
println("SchemaName", table)
return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table))
}
func (n Namer) UniqueName(table, column string) string {
println("UniqueName", table, column)
return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column))
}

355
oracle.go
View File

@ -1,5 +1,360 @@
package oracle
import (
"context"
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
_ "github.com/godror/godror"
_ "github.com/sijms/go-ora/v2"
)
const (
RowNumberAliasForOracle11 = "ROW_NUM"
)
func Hello() string {
return "GORM Oracle Driver"
}
type Config struct {
DriverName string
DSN string
Conn *sql.DB
DefaultStringSize uint
DBVer string
DBName string // 库名
}
type Dialector struct {
*Config
}
func New(config Config) gorm.Dialector {
return &Dialector{Config: &config}
}
func Open(dsn string) gorm.Dialector {
return &Dialector{Config: &Config{DSN: dsn}}
}
func (d Dialector) DummyTableName() string { return "DUAL" }
func (d Dialector) Name() string { return "oracle" }
func (d Dialector) Initialize(db *gorm.DB) (err error) {
db.NamingStrategy = Namer{
NamingStrategy: db.NamingStrategy,
DBName: d.DBName,
}
d.DefaultStringSize = 1024
// register callbacks
callbackConfig := &callbacks.Config{}
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
// d.DriverName = "godror"
d.DriverName = "oracle"
if d.Conn != nil {
db.ConnPool = d.Conn
} else {
db.ConnPool, err = sql.Open(d.DriverName, d.DSN)
if err != nil {
return
}
}
err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer)
if err != nil {
return err
}
db.Logger.Info(context.Background(), "DBVer:%s", d.DBVer)
if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
return
}
for k, v := range d.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) {
clauseBuilders = make(map[string]clause.ClauseBuilder)
if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 {
clauseBuilders["LIMIT"] = d.RewriteLimit
} else {
clauseBuilders["LIMIT"] = d.RewriteLimit11
}
return
}
func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok {
if stmt, ok := builder.(*gorm.Statement); ok {
if _, ok := stmt.Clauses["ORDER BY"]; !ok {
s := stmt.Schema
builder.WriteString("ORDER BY ")
if s != nil && s.PrioritizedPrimaryField != nil {
builder.WriteQuoted(s.PrioritizedPrimaryField.DBName)
builder.WriteByte(' ')
} else {
builder.WriteString("(SELECT NULL FROM ")
builder.WriteString(d.DummyTableName())
builder.WriteString(")")
}
}
}
if offset := limit.Offset; offset > 0 {
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(offset))
builder.WriteString(" ROWS")
}
if limit := limit.Limit; *limit > 0 {
builder.WriteString(" FETCH NEXT ")
builder.WriteString(strconv.Itoa(*limit))
builder.WriteString(" ROWS ONLY")
}
}
}
func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) {
println("rewrite limit oracle 11g")
limit, ok := c.Expression.(clause.Limit)
if !ok {
return
}
offsetRows := limit.Offset
hasOffset := offsetRows > 0
limitRows, hasLimit := d.getLimitRows(limit)
if (!hasOffset && !hasLimit) || limitRows == 1 {
return
}
var stmt *gorm.Statement
if stmt, ok = builder.(*gorm.Statement); !ok {
return
}
if hasLimit && hasOffset {
subQuerySQL := fmt.Sprintf(
"SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY %s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d",
d.getOrderByColumns(stmt),
RowNumberAliasForOracle11,
strings.TrimSpace(stmt.SQL.String()),
RowNumberAliasForOracle11,
offsetRows+1,
offsetRows+limitRows,
)
stmt.SQL.Reset()
stmt.SQL.WriteString(subQuerySQL)
} else if hasLimit {
// only limit
d.rewriteRownumStmt(stmt, builder, " <= ", limitRows)
} else {
// only offset
d.rewriteRownumStmt(stmt, builder, " > ", offsetRows)
}
}
func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) {
limitSql := strings.Builder{}
if _, ok := stmt.Clauses["WHERE"]; !ok {
limitSql.WriteString(" WHERE ")
} else {
limitSql.WriteString(" AND ")
}
limitSql.WriteString("ROWNUM")
limitSql.WriteString(operator)
limitSql.WriteString(strconv.Itoa(rows))
if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy {
_, _ = builder.WriteString(limitSql.String())
} else {
// "ORDER BY" before insert
sqlTmp := strings.Builder{}
sqlOld := stmt.SQL.String()
orderIndex := strings.Index(sqlOld, "ORDER BY") - 1
sqlTmp.WriteString(sqlOld[:orderIndex])
sqlTmp.WriteString(limitSql.String())
sqlTmp.WriteString(sqlOld[orderIndex:])
stmt.SQL = sqlTmp
}
}
func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string {
if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok {
var orderBy clause.OrderBy
if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 {
println("order columns:", orderBy.Columns)
orderByBuilder := strings.Builder{}
for i, column := range orderBy.Columns {
if i > 0 {
orderByBuilder.WriteString(", ")
}
orderByBuilder.WriteString(column.Column.Name)
if column.Desc {
orderByBuilder.WriteString(" DESC")
}
}
return orderByBuilder.String()
}
}
return "NULL"
}
func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) {
if l := limit.Limit; l != nil {
limitRows = *l
hasLimit = limitRows > 0
}
return
}
func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{
Migrator: migrator.Migrator{
Config: migrator.Config{
DB: db,
Dialector: d,
CreateIndexAfterCreateTable: true,
},
},
}
}
func (d Dialector) DataTypeOf(field *schema.Field) string {
delete(field.TagSettings, "RESTRICT")
var sqlType string
switch field.DataType {
case schema.Bool, schema.Int, schema.Uint, schema.Float:
sqlType = "INTEGER"
switch {
case field.DataType == schema.Float:
sqlType = "FLOAT"
case field.Size <= 8:
sqlType = "SMALLINT"
}
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
sqlType += " GENERATED BY DEFAULT AS IDENTITY"
}
case schema.String:
size := field.Size
defaultSize := d.DefaultStringSize
if size == 0 {
if defaultSize > 0 {
size = int(defaultSize)
} else {
hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
// TEXT, GEOMETRY or JSON column can't have a default value
if field.PrimaryKey || field.HasDefaultValue || hasIndex {
size = 191 // utf8mb4
}
}
}
if size >= 2000 {
sqlType = "CLOB"
} else {
sqlType = fmt.Sprintf("VARCHAR2(%d)", size)
}
case schema.Time:
sqlType = "TIMESTAMP WITH TIME ZONE"
if field.NotNull || field.PrimaryKey {
sqlType += " NOT NULL"
}
case schema.Bytes:
sqlType = "BLOB"
default:
sqlType := string(field.DataType)
if strings.EqualFold(sqlType, "text") {
sqlType = "CLOB"
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String()))
}
notNull := field.TagSettings["NOT NULL"]
unique := field.TagSettings["UNIQUE"]
additionalType := fmt.Sprintf("%s %s", notNull, unique)
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string {
if value, ok := field.TagSettings["COMMENT"]; ok {
return " COMMENT " + value
}
return ""
}())
}
sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType
}
return sqlType
}
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
return clause.Expr{SQL: "VALUES (DEFAULT)"}
}
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {}
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString(str)
}
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
func (d Dialector) Explain(sql string, vars ...any) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v any) any {
switch v := v.(type) {
case bool:
if v {
return 1
}
return 0
default:
return v
}
}).([]any)...)
}
func (d Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return tx.Error
}
func (d Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return tx.Error
}

28
reserved.go Normal file
View File

@ -0,0 +1,28 @@
package oracle
import (
"github.com/emirpasic/gods/sets/hashset"
"github.com/thoas/go-funk"
)
var ReservedWords = hashset.New(funk.Map(ReservedWordsList, func(s string) interface{} { return s }).([]interface{})...)
func IsReservedWord(v string) bool {
return ReservedWords.Contains(v)
}
var ReservedWordsList = []string{
"AGGREGATE", "AGGREGATES", "ALL", "ALLOW", "ANALYZE", "ANCESTOR", "AND", "ANY", "AS", "ASC", "AT", "AVG", "BETWEEN",
"BINARY_DOUBLE", "BINARY_FLOAT", "BLOB", "BRANCH", "BUILD", "BY", "BYTE", "CASE", "CAST", "CHAR", "CHILD", "CLEAR",
"CLOB", "COMMIT", "COMPILE", "CONSIDER", "COUNT", "DATATYPE", "DATE", "DATE_MEASURE", "DAY", "DECIMAL", "DELETE",
"DESC", "DESCENDANT", "DIMENSION", "DISALLOW", "DIVISION", "DML", "ELSE", "END", "ESCAPE", "EXECUTE", "FIRST",
"FLOAT", "FOR", "FROM", "HIERARCHIES", "HIERARCHY", "HOUR", "IGNORE", "IN", "INFINITE", "INSERT", "INTEGER",
"INTERVAL", "INTO", "IS", "LAST", "LEAF_DESCENDANT", "LEAVES", "LEVEL", "LIKE", "LIKEC", "LIKE2", "LIKE4", "LOAD",
"LOCAL", "LOG_SPEC", "LONG", "MAINTAIN", "MAX", "MEASURE", "MEASURES", "MEMBER", "MEMBERS", "MERGE", "MLSLABEL",
"MIN", "MINUTE", "MODEL", "MONTH", "NAN", "NCHAR", "NCLOB", "NO", "NONE", "NOT", "NULL", "NULLS", "NUMBER",
"NVARCHAR2", "OF", "OLAP", "OLAP_DML_EXPRESSION", "ON", "ONLY", "OPERATOR", "OR", "ORDER", "OVER", "OVERFLOW",
"PARALLEL", "PARENT", "PLSQL", "PRUNE", "RAW", "RELATIVE", "ROOT_ANCESTOR", "ROWID", "SCN", "SECOND", "SELF",
"SERIAL", "SET", "SOLVE", "SOME", "SORT", "SPEC", "SUM", "SYNCH", "TEXT_MEASURE", "THEN", "TIME", "TIMESTAMP",
"TO", "UNBRANCH", "UPDATE", "USING", "VALIDATE", "VALUES", "VARCHAR2", "WHEN", "WHERE", "WITHIN", "WITH", "YEAR",
"ZERO", "ZONE",
}

68
utils.go Normal file
View File

@ -0,0 +1,68 @@
package oracle
import (
"database/sql"
"reflect"
"time"
"gorm.io/gorm"
)
func convertCustomType(val interface{}) interface{} {
rv := reflect.ValueOf(val)
ri := rv.Interface()
typeName := reflect.TypeOf(ri).Name()
if reflect.TypeOf(val).Kind() == reflect.Ptr {
if rv.IsNil() {
typeName = rv.Type().Elem().Name()
} else {
for rv.Kind() == reflect.Ptr {
rv = rv.Elem()
}
ri = rv.Interface()
typeName = reflect.TypeOf(ri).Name()
}
}
if typeName == "DeletedAt" {
// gorm.DeletedAt
if rv.IsZero() {
val = sql.NullTime{}
} else {
val = getTimeValue(ri.(gorm.DeletedAt).Time)
}
} else if m := rv.MethodByName("Time"); m.IsValid() && m.Type().NumIn() == 0 {
// custom time type
for _, result := range m.Call([]reflect.Value{}) {
if reflect.TypeOf(result.Interface()).Name() == "Time" {
val = getTimeValue(result.Interface().(time.Time))
}
}
}
return val
}
func ptrDereference(obj interface{}) (value interface{}) {
if obj == nil {
return obj
}
if t := reflect.TypeOf(obj); t.Kind() != reflect.Ptr {
return obj
}
v := reflect.ValueOf(obj)
for v.Kind() == reflect.Ptr && !v.IsNil() {
v = v.Elem()
}
if !v.IsValid() || v.Kind() == reflect.Ptr && v.IsNil() {
return obj
}
value = v.Interface()
return
}
func getTimeValue(t time.Time) interface{} {
if t.IsZero() {
return sql.NullTime{}
}
return t
}