From 184a8722d8d46091e8a102b9bde3894cc19ac965 Mon Sep 17 00:00:00 2001 From: Charlie <3140647@qq.com> Date: Fri, 10 May 2024 10:27:21 +0800 Subject: [PATCH] update --- .gitignore | 6 ++ License | 25 +++++ README.md | 39 +++++++ create.go | 303 +++++++++++++++++----------------------------------- go.mod | 8 +- go.sum | 10 +- migrator.go | 9 +- namer.go | 39 ++----- oracle.go | 269 ++++++++++++++++------------------------------ 9 files changed, 285 insertions(+), 423 deletions(-) create mode 100644 .gitignore create mode 100644 License create mode 100644 README.md diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f701f50 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.idea +vendor/ +go.sum +CHANGELOG.md +/test_local/ +/go.work* diff --git a/License b/License new file mode 100644 index 0000000..be54fe7 --- /dev/null +++ b/License @@ -0,0 +1,25 @@ +The MIT License (MIT) + +Copyright (c) 2013-NOW + +Jinzhu , +Steve Fan , +CengSin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..eeccbcc --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ +# GORM Oracle Driver + + +## Description + +GORM Oracle driver for connect Oracle DB and Manage Oracle DB, Based on [CengSin/oracle](https://github.com/CengSin/oracle) +,not recommended for use in a production environment + +## Required dependency Install + +- Oracle 12C+ +- Golang 1.13+ +- see [ODPI-C Installation.](https://oracle.github.io/odpi/doc/installation.html) +- gorm 1.24.0+ + +## Quick Start +### how to install +```bash +go get github.com/dzwvip/oracle +``` +### usage + +```go +import ( + "fmt" + "github.com/dzwvip/oracle" + "gorm.io/gorm" + "log" +) + +func main() { + db, err := gorm.Open(oracle.Open("system/oracle@127.0.0.1:1521/XE"), &gorm.Config{}) + if err != nil { + // panic error or log error info + } + + // do somethings +} +``` diff --git a/create.go b/create.go index 31c414d..b5656be 100644 --- a/create.go +++ b/create.go @@ -1,255 +1,152 @@ package oracle import ( + "bytes" "database/sql" "reflect" + "github.com/thoas/go-funk" "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" + gormSchema "gorm.io/gorm/schema" + + "git.charlienet.top/go/oracle/clauses" ) func Create(db *gorm.DB) { - if db.Error != nil { - return - } - stmt := db.Statement - if stmt == nil { + schema := stmt.Schema + boundVars := make(map[string]int) + + if stmt == nil || schema == nil { return } - stmtSchema := stmt.Schema - if stmtSchema == nil { - return - } + hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0 if !stmt.Unscoped { - for _, c := range stmtSchema.CreateClauses { + for _, c := range schema.CreateClauses { stmt.AddClause(c) } } - if stmt.SQL.Len() == 0 { - var ( - createValues = callbacks.ConvertToCreateValues(stmt) - onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) - ) + if stmt.SQL.String() == "" { + values := callbacks.ConvertToCreateValues(stmt) + onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) + // are all columns in value the primary fields in schema only? + if hasConflict && funk.Contains( + funk.Map(values.Columns, func(c clause.Column) string { return c.Name }), + funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }), + ) { + stmt.AddClauseIfNotExists(clauses.Merge{ + Using: []clause.Interface{ + clause.Select{ + Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column { + // HACK: I can not come up with a better alternative for now + // I want to add a value to the list of variable and then capture the bind variable position as well + buf := bytes.NewBufferString("") + stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)]) + stmt.BindVarTo(buf, stmt, nil) - if hasConflict { - if len(stmtSchema.PrimaryFields) > 0 { - columnsMap := map[string]bool{} - for _, column := range createValues.Columns { - columnsMap[column.Name] = true - } - - for _, field := range stmtSchema.PrimaryFields { - if _, ok := columnsMap[field.DBName]; !ok { - hasConflict = false + column.Alias = column.Name + // then the captured bind var will be the name + column.Name = buf.String() + return column + }).([]clause.Column), + }, + clause.From{ + Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}}, + }, + }, + On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression { + return clause.Eq{ + Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName}, + Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName}, } - } - } else { - hasConflict = false - } - } + }).([]clause.Expression), + }) + stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates}) + stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values}) - hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0 - if hasConflict { - MergeCreate(db, onConflict, createValues) + stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED") } else { stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}}) - stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}}) - + stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}}) if hasDefaultValues { - columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue)) - for idx, field := range stmtSchema.FieldsWithDefaultDBValue { - columns[idx] = clause.Column{Name: field.DBName} - } - stmt.AddClauseIfNotExists(clause.Returning{Columns: columns}) + stmt.AddClauseIfNotExists(clause.Returning{ + Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column { + return clause.Column{Name: field.DBName} + }).([]clause.Column), + }) } stmt.Build("INSERT", "VALUES", "RETURNING") - if hasDefaultValues { - _, _ = stmt.WriteString(" INTO ") - for idx, field := range stmtSchema.FieldsWithDefaultDBValue { + stmt.WriteString(" INTO ") + for idx, field := range schema.FieldsWithDefaultDBValue { if idx > 0 { - _ = stmt.WriteByte(',') + stmt.WriteByte(',') } + boundVars[field.Name] = len(stmt.Vars) stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()}) } - _, _ = stmt.WriteString(" /*-sql.Out{}-*/") } } - if !db.DryRun && db.Error == nil { - if hasConflict { - for i, val := range stmt.Vars { - // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned - stmt.Vars[i] = convertValue(val) - } - - result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...) - if db.AddError(err) == nil { - db.RowsAffected, _ = result.RowsAffected() - // TODO: get merged returning - } - } else { - for idx, values := range createValues.Values { - for i, val := range values { - // HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned - stmt.Vars[i] = convertValue(val) - } - - result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...) - if db.AddError(err) == nil { - db.RowsAffected, _ = result.RowsAffected() - - if hasDefaultValues { - getDefaultValues(db, idx) + if !db.DryRun { + for idx, vals := range values.Values { + // HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned + for idx, val := range vals { + switch v := val.(type) { + case bool: + if v { + val = 1 + } else { + val = 0 } } + + stmt.Vars[idx] = val } - } - } - } -} + // and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert) + // we keep track of the index so that the sub-reflected value is also correct -func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) { - var dummyTable string - switch d := ptrDereference(db.Dialector).(type) { - case Dialector: - dummyTable = d.DummyTableName() - default: - dummyTable = "DUAL" - } + // BIG BUG: what if any of the transactions failed? some result might already be inserted that oracle is so + // sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point, + // resulting in dangling row entries, so we might need to delete them if an error happens - _, _ = db.Statement.WriteString("MERGE INTO ") - db.Statement.WriteQuoted(db.Statement.Table) - _, _ = db.Statement.WriteString(" USING (") + switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err { + case nil: // success + db.RowsAffected, _ = result.RowsAffected() - for idx, value := range values.Values { - if idx > 0 { - _, _ = db.Statement.WriteString(" UNION ALL ") - } - - _, _ = db.Statement.WriteString("SELECT ") - for i, v := range value { - if i > 0 { - _ = db.Statement.WriteByte(',') - } - column := values.Columns[i] - db.Statement.AddVar(db.Statement, v) - _, _ = db.Statement.WriteString(" AS ") - db.Statement.WriteQuoted(column.Name) - } - _, _ = db.Statement.WriteString(" FROM ") - _, _ = db.Statement.WriteString(dummyTable) - } - - _, _ = db.Statement.WriteString(`) `) - db.Statement.WriteQuoted("excluded") - _, _ = db.Statement.WriteString(" ON (") - - var where clause.Where - for _, field := range db.Statement.Schema.PrimaryFields { - where.Exprs = append(where.Exprs, clause.Eq{ - Column: clause.Column{Table: db.Statement.Table, Name: field.DBName}, - Value: clause.Column{Table: "excluded", Name: field.DBName}, - }) - } - where.Build(db.Statement) - _ = db.Statement.WriteByte(')') - - if len(onConflict.DoUpdates) > 0 { - _, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ") - onConflict.DoUpdates.Build(db.Statement) - } - - _, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (") - - written := false - for _, column := range values.Columns { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name { - if written { - _ = db.Statement.WriteByte(',') - } - written = true - db.Statement.WriteQuoted(column.Name) - } - } - - _, _ = db.Statement.WriteString(") VALUES (") - - written = false - for _, column := range values.Columns { - if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name { - if written { - _ = db.Statement.WriteByte(',') - } - written = true - db.Statement.WriteQuoted(clause.Column{ - Table: "excluded", - Name: column.Name, - }) - } - } - _, _ = db.Statement.WriteString(")") -} - -func convertValue(val interface{}) interface{} { - val = ptrDereference(val) - switch v := val.(type) { - case bool: - if v { - val = 1 - } else { - val = 0 - } - // case string: - // if len(v) > 2000 { - // val = go_ora.Clob{String: v, Valid: true} - // } - default: - val = convertCustomType(val) - } - return val -} - -func getDefaultValues(db *gorm.DB, idx int) { - insertTo := db.Statement.ReflectValue - switch insertTo.Kind() { - case reflect.Slice, reflect.Array: - insertTo = insertTo.Index(idx) - default: - } - if insertTo.Kind() == reflect.Pointer { - insertTo = insertTo.Elem() - } - - for _, val := range db.Statement.Vars { - switch v := val.(type) { - case sql.Out: - switch insertTo.Kind() { - case reflect.Slice, reflect.Array: - for i := insertTo.Len() - 1; i >= 0; i-- { - rv := insertTo.Index(i) - if reflect.Indirect(rv).Kind() != reflect.Struct { - break + insertTo := stmt.ReflectValue + switch insertTo.Kind() { + case reflect.Slice, reflect.Array: + insertTo = insertTo.Index(idx) } - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) - if isZero { - _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest)) + if hasDefaultValues { + // bind returning value back to reflected value in the respective fields + funk.ForEach( + funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool { + return funk.Contains(boundVars, field.Name) + }), + func(field *gormSchema.Field) { + switch insertTo.Kind() { + case reflect.Struct: + if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil { + db.AddError(err) + } + case reflect.Map: + // todo 设置id的值 + } + }, + ) } + default: // failure + db.AddError(err) } - case reflect.Struct: - _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo) - if isZero { - _ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest)) - } - default: } - default: } } } diff --git a/go.mod b/go.mod index 3f1effe..a45d74a 100644 --- a/go.mod +++ b/go.mod @@ -4,10 +4,10 @@ go 1.22 require ( github.com/emirpasic/gods v1.18.1 - github.com/godror/godror v0.42.2 - github.com/sijms/go-ora/v2 v2.8.16 + github.com/godror/godror v0.43.0 github.com/thoas/go-funk v0.9.3 gorm.io/gorm v1.25.10 + ) require ( @@ -15,6 +15,6 @@ require ( github.com/godror/knownpb v0.1.1 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 // indirect - google.golang.org/protobuf v1.33.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect + google.golang.org/protobuf v1.34.1 // indirect ) diff --git a/go.sum b/go.sum index 14d6913..3145790 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/godror/godror v0.42.2 h1:TmOV0fr4jJxwDD6vWtSieQFOZ75LnSbkJaudfuJOsVQ= -github.com/godror/godror v0.42.2/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94= +github.com/godror/godror v0.43.0 h1:qMbQwG0ejJnKma3bBvrJg1rkiyP5b4v6uxvx9zDrKJw= +github.com/godror/godror v0.43.0/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94= github.com/godror/knownpb v0.1.1 h1:A4J7jdx7jWBhJm18NntafzSC//iZDHkDi1+juwQ5pTI= github.com/godror/knownpb v0.1.1/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= @@ -22,8 +22,6 @@ github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc= github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sijms/go-ora/v2 v2.8.16 h1:XiBAqHjoXUnJnGnRoHD/6NJdiYQ2pr0gPciJM8BE+70= -github.com/sijms/go-ora/v2 v2.8.16/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -31,6 +29,8 @@ github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw= github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 h1:6R2FC06FonbXQ8pK11/PDFY6N6LWlf9KlzibaCapmqc= golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4= golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= @@ -39,6 +39,8 @@ golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/migrator.go b/migrator.go index 2673cd3..97d73ae 100644 --- a/migrator.go +++ b/migrator.go @@ -3,9 +3,8 @@ package oracle import ( "database/sql" "fmt" - "strings" - "gorm.io/gorm/schema" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -276,11 +275,11 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { // https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { - // ALTER INDEX oldindex RENAME TO newindex; + panic("TODO") return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( - "ALTER INDEX ? RENAME TO ?", - clause.Column{Name: oldName}, clause.Column{Name: newName}, + "ALTER INDEX ?.? RENAME TO ?", // wat + clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } diff --git a/namer.go b/namer.go index c493b83..a7dd323 100644 --- a/namer.go +++ b/namer.go @@ -1,59 +1,38 @@ package oracle import ( - "fmt" - "strings" - "gorm.io/gorm/schema" + "strings" ) -var _ schema.Namer = Namer{} - type Namer struct { - NamingStrategy schema.Namer - DBName string + schema.NamingStrategy } -func (n Namer) ConvertNameToFormat(x string) string { +func ConvertNameToFormat(x string) string { return strings.ToUpper(x) } func (n Namer) TableName(table string) (name string) { - println("TableName", table) - - tableName := n.ConvertNameToFormat(n.NamingStrategy.TableName(table)) - if len(n.DBName) > 0 { - return fmt.Sprintf("%s.%s", n.DBName, tableName) - } - return tableName + return ConvertNameToFormat(n.NamingStrategy.TableName(table)) } func (n Namer) ColumnName(table, column string) (name string) { - return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column)) + return ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column)) } func (n Namer) JoinTableName(table string) (name string) { - return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table)) + return ConvertNameToFormat(n.NamingStrategy.JoinTableName(table)) } func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) { - return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship)) + return ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship)) } func (n Namer) CheckerName(table, column string) (name string) { - return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column)) + return ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column)) } func (n Namer) IndexName(table, column string) (name string) { - return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column)) -} - -func (n Namer) SchemaName(table string) string { - println("SchemaName", table) - return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table)) -} - -func (n Namer) UniqueName(table, column string) string { - println("UniqueName", table, column) - return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column)) + return ConvertNameToFormat(n.NamingStrategy.IndexName(table, column)) } diff --git a/oracle.go b/oracle.go index bfc1e68..d747c4c 100644 --- a/oracle.go +++ b/oracle.go @@ -4,10 +4,14 @@ import ( "context" "database/sql" "fmt" + "log" "regexp" "strconv" "strings" + "gorm.io/gorm/utils" + + _ "github.com/godror/godror" "github.com/thoas/go-funk" "gorm.io/gorm" "gorm.io/gorm/callbacks" @@ -15,75 +19,61 @@ import ( "gorm.io/gorm/logger" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" - "gorm.io/gorm/utils" - - _ "github.com/godror/godror" - _ "github.com/sijms/go-ora/v2" ) -const ( - RowNumberAliasForOracle11 = "ROW_NUM" -) - -func Hello() string { - return "GORM Oracle Driver" -} - type Config struct { DriverName string DSN string - Conn *sql.DB + Conn gorm.ConnPool //*sql.DB DefaultStringSize uint + DBName string DBVer string - DBName string // 库名 } type Dialector struct { *Config } -func New(config Config) gorm.Dialector { - return &Dialector{Config: &config} -} - func Open(dsn string) gorm.Dialector { return &Dialector{Config: &Config{DSN: dsn}} } -func (d Dialector) DummyTableName() string { return "DUAL" } +func New(config Config) gorm.Dialector { + return &Dialector{Config: &config} +} -func (d Dialector) Name() string { return "oracle" } +func (d Dialector) DummyTableName() string { + return "DUAL" +} + +func (d Dialector) Name() string { + return "oracle" +} func (d Dialector) Initialize(db *gorm.DB) (err error) { - db.NamingStrategy = Namer{ - NamingStrategy: db.NamingStrategy, - DBName: d.DBName, - } + db.NamingStrategy = Namer{db.NamingStrategy.(schema.NamingStrategy)} d.DefaultStringSize = 1024 // register callbacks - callbackConfig := &callbacks.Config{} - callbacks.RegisterDefaultCallbacks(db, callbackConfig) + //callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: true}) + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + }) - // d.DriverName = "godror" - d.DriverName = "oracle" + d.DriverName = "godror" if d.Conn != nil { db.ConnPool = d.Conn } else { db.ConnPool, err = sql.Open(d.DriverName, d.DSN) - if err != nil { - return - } } - err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer) if err != nil { return err } - - db.Logger.Info(context.Background(), "DBVer:%s", d.DBVer) - + //log.Println("DBver:" + d.DBVer) if err = db.Callback().Create().Replace("gorm:create", Create); err != nil { return } @@ -91,24 +81,26 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) { for k, v := range d.ClauseBuilders() { db.ClauseBuilders[k] = v } - return } -func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) { - clauseBuilders = make(map[string]clause.ClauseBuilder) - if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 { - clauseBuilders["LIMIT"] = d.RewriteLimit +func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + dbver, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]) + if dbver > 0 && dbver < 12 { + return map[string]clause.ClauseBuilder{ + "LIMIT": d.RewriteLimit11, + } + } else { - clauseBuilders["LIMIT"] = d.RewriteLimit11 + return map[string]clause.ClauseBuilder{ + "LIMIT": d.RewriteLimit, + } } - return } func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) { if limit, ok := c.Expression.(clause.Limit); ok { - if stmt, ok := builder.(*gorm.Statement); ok { if _, ok := stmt.Clauses["ORDER BY"]; !ok { s := stmt.Schema @@ -137,100 +129,38 @@ func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) { } } +// Oracle11 Limit func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) { - println("rewrite limit oracle 11g") - limit, ok := c.Expression.(clause.Limit) - if !ok { - return - } - - offsetRows := limit.Offset - hasOffset := offsetRows > 0 - limitRows, hasLimit := d.getLimitRows(limit) - if (!hasOffset && !hasLimit) || limitRows == 1 { - return - } - - var stmt *gorm.Statement - if stmt, ok = builder.(*gorm.Statement); !ok { - return - } - - if hasLimit && hasOffset { - subQuerySQL := fmt.Sprintf( - "SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY %s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d", - d.getOrderByColumns(stmt), - RowNumberAliasForOracle11, - strings.TrimSpace(stmt.SQL.String()), - RowNumberAliasForOracle11, - offsetRows+1, - offsetRows+limitRows, - ) - stmt.SQL.Reset() - stmt.SQL.WriteString(subQuerySQL) - } else if hasLimit { - // only limit - d.rewriteRownumStmt(stmt, builder, " <= ", limitRows) - } else { - // only offset - d.rewriteRownumStmt(stmt, builder, " > ", offsetRows) - } -} - -func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) { - limitSql := strings.Builder{} - if _, ok := stmt.Clauses["WHERE"]; !ok { - limitSql.WriteString(" WHERE ") - } else { - limitSql.WriteString(" AND ") - } - limitSql.WriteString("ROWNUM") - limitSql.WriteString(operator) - limitSql.WriteString(strconv.Itoa(rows)) - - if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy { - _, _ = builder.WriteString(limitSql.String()) - } else { - // "ORDER BY" before insert - sqlTmp := strings.Builder{} - sqlOld := stmt.SQL.String() - orderIndex := strings.Index(sqlOld, "ORDER BY") - 1 - sqlTmp.WriteString(sqlOld[:orderIndex]) - sqlTmp.WriteString(limitSql.String()) - sqlTmp.WriteString(sqlOld[orderIndex:]) - stmt.SQL = sqlTmp - } -} - -func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string { - if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok { - var orderBy clause.OrderBy - - if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 { - println("order columns:", orderBy.Columns) - - orderByBuilder := strings.Builder{} - for i, column := range orderBy.Columns { - if i > 0 { - orderByBuilder.WriteString(", ") - } - orderByBuilder.WriteString(column.Column.Name) - if column.Desc { - orderByBuilder.WriteString(" DESC") + if limit, ok := c.Expression.(clause.Limit); ok { + if stmt, ok := builder.(*gorm.Statement); ok { + limitsql := strings.Builder{} + if limit := limit.Limit; *limit > 0 { + if _, ok := stmt.Clauses["WHERE"]; !ok { + limitsql.WriteString(" WHERE ") + } else { + limitsql.WriteString(" AND ") } + limitsql.WriteString("ROWNUM <= ") + limitsql.WriteString(strconv.Itoa(*limit)) + } + if _, ok := stmt.Clauses["ORDER BY"]; !ok { + builder.WriteString(limitsql.String()) + } else { + // "ORDER BY" before insert + sqltmp := strings.Builder{} + sqlold := stmt.SQL.String() + orderindx := strings.Index(sqlold, "ORDER BY") - 1 + sqltmp.WriteString(sqlold[:orderindx]) + sqltmp.WriteString(limitsql.String()) + sqltmp.WriteString(sqlold[orderindx:]) + log.Println(sqltmp.String()) + stmt.SQL = sqltmp } - return orderByBuilder.String() } } - return "NULL" } - -func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) { - if l := limit.Limit; l != nil { - limitRows = *l - hasLimit = limitRows > 0 - } - return +func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression { + return clause.Expr{SQL: "VALUES (DEFAULT)"} } func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator { @@ -245,8 +175,35 @@ func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator { } } +func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteString(":") + writer.WriteString(strconv.Itoa(len(stmt.Vars))) +} + +func (d Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteString(str) +} + +var numericPlaceholder = regexp.MustCompile(`:(\d+)`) + +func (d Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} { + switch v := v.(type) { + case bool: + if v { + return 1 + } + return 0 + default: + return v + } + }).([]interface{})...) +} + func (d Dialector) DataTypeOf(field *schema.Field) string { - delete(field.TagSettings, "RESTRICT") + if _, found := field.TagSettings["RESTRICT"]; found { + delete(field.TagSettings, "RESTRICT") + } var sqlType string @@ -264,7 +221,7 @@ func (d Dialector) DataTypeOf(field *schema.Field) string { if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { sqlType += " GENERATED BY DEFAULT AS IDENTITY" } - case schema.String: + case schema.String, "VARCHAR2": size := field.Size defaultSize := d.DefaultStringSize @@ -288,13 +245,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string { case schema.Time: sqlType = "TIMESTAMP WITH TIME ZONE" - if field.NotNull || field.PrimaryKey { - sqlType += " NOT NULL" - } + case schema.Bytes: sqlType = "BLOB" default: - sqlType := string(field.DataType) + sqlType = string(field.DataType) if strings.EqualFold(sqlType, "text") { sqlType = "CLOB" @@ -304,51 +259,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string { panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String())) } - notNull := field.TagSettings["NOT NULL"] - unique := field.TagSettings["UNIQUE"] - additionalType := fmt.Sprintf("%s %s", notNull, unique) - if value, ok := field.TagSettings["DEFAULT"]; ok { - additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string { - if value, ok := field.TagSettings["COMMENT"]; ok { - return " COMMENT " + value - } - return "" - }()) - } - - sqlType = fmt.Sprintf("%v %v", sqlType, additionalType) - return sqlType } return sqlType } -func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression { - return clause.Expr{SQL: "VALUES (DEFAULT)"} -} - -func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {} - -func (d Dialector) QuoteTo(writer clause.Writer, str string) { - writer.WriteString(str) -} - -var numericPlaceholder = regexp.MustCompile(`:(\d+)`) - -func (d Dialector) Explain(sql string, vars ...any) string { - return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v any) any { - switch v := v.(type) { - case bool: - if v { - return 1 - } - return 0 - default: - return v - } - }).([]any)...) -} - func (d Dialector) SavePoint(tx *gorm.DB, name string) error { tx.Exec("SAVEPOINT " + name) return tx.Error