This commit is contained in:
2024-05-10 10:27:21 +08:00
parent 90f807defc
commit 184a8722d8
9 changed files with 285 additions and 423 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
.idea
vendor/
go.sum
CHANGELOG.md
/test_local/
/go.work*

25
License Normal file
View File

@ -0,0 +1,25 @@
The MIT License (MIT)
Copyright (c) 2013-NOW
Jinzhu <wosmvp@gmail.com>,
Steve Fan <stevefan1999-personal@github>,
CengSin <zephone0724@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

39
README.md Normal file
View File

@ -0,0 +1,39 @@
# GORM Oracle Driver
## Description
GORM Oracle driver for connect Oracle DB and Manage Oracle DB, Based on [CengSin/oracle](https://github.com/CengSin/oracle)
not recommended for use in a production environment
## Required dependency Install
- Oracle 12C+
- Golang 1.13+
- see [ODPI-C Installation.](https://oracle.github.io/odpi/doc/installation.html)
- gorm 1.24.0+
## Quick Start
### how to install
```bash
go get github.com/dzwvip/oracle
```
### usage
```go
import (
"fmt"
"github.com/dzwvip/oracle"
"gorm.io/gorm"
"log"
)
func main() {
db, err := gorm.Open(oracle.Open("system/oracle@127.0.0.1:1521/XE"), &gorm.Config{})
if err != nil {
// panic error or log error info
}
// do somethings
}
```

303
create.go
View File

@ -1,255 +1,152 @@
package oracle package oracle
import ( import (
"bytes"
"database/sql" "database/sql"
"reflect" "reflect"
"github.com/thoas/go-funk"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/callbacks" "gorm.io/gorm/callbacks"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
gormSchema "gorm.io/gorm/schema"
"git.charlienet.top/go/oracle/clauses"
) )
func Create(db *gorm.DB) { func Create(db *gorm.DB) {
if db.Error != nil {
return
}
stmt := db.Statement stmt := db.Statement
if stmt == nil { schema := stmt.Schema
boundVars := make(map[string]int)
if stmt == nil || schema == nil {
return return
} }
stmtSchema := stmt.Schema hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0
if stmtSchema == nil {
return
}
if !stmt.Unscoped { if !stmt.Unscoped {
for _, c := range stmtSchema.CreateClauses { for _, c := range schema.CreateClauses {
stmt.AddClause(c) stmt.AddClause(c)
} }
} }
if stmt.SQL.Len() == 0 { if stmt.SQL.String() == "" {
var ( values := callbacks.ConvertToCreateValues(stmt)
createValues = callbacks.ConvertToCreateValues(stmt) onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) // are all columns in value the primary fields in schema only?
) if hasConflict && funk.Contains(
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }),
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }),
) {
stmt.AddClauseIfNotExists(clauses.Merge{
Using: []clause.Interface{
clause.Select{
Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column {
// HACK: I can not come up with a better alternative for now
// I want to add a value to the list of variable and then capture the bind variable position as well
buf := bytes.NewBufferString("")
stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)])
stmt.BindVarTo(buf, stmt, nil)
if hasConflict { column.Alias = column.Name
if len(stmtSchema.PrimaryFields) > 0 { // then the captured bind var will be the name
columnsMap := map[string]bool{} column.Name = buf.String()
for _, column := range createValues.Columns { return column
columnsMap[column.Name] = true }).([]clause.Column),
} },
clause.From{
for _, field := range stmtSchema.PrimaryFields { Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}},
if _, ok := columnsMap[field.DBName]; !ok { },
hasConflict = false },
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression {
return clause.Eq{
Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
} }
} }).([]clause.Expression),
} else { })
hasConflict = false stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
} stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})
}
hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0 stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED")
if hasConflict {
MergeCreate(db, onConflict, createValues)
} else { } else {
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}}) stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}}) stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}})
if hasDefaultValues { if hasDefaultValues {
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue)) stmt.AddClauseIfNotExists(clause.Returning{
for idx, field := range stmtSchema.FieldsWithDefaultDBValue { Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column {
columns[idx] = clause.Column{Name: field.DBName} return clause.Column{Name: field.DBName}
} }).([]clause.Column),
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns}) })
} }
stmt.Build("INSERT", "VALUES", "RETURNING") stmt.Build("INSERT", "VALUES", "RETURNING")
if hasDefaultValues { if hasDefaultValues {
_, _ = stmt.WriteString(" INTO ") stmt.WriteString(" INTO ")
for idx, field := range stmtSchema.FieldsWithDefaultDBValue { for idx, field := range schema.FieldsWithDefaultDBValue {
if idx > 0 { if idx > 0 {
_ = stmt.WriteByte(',') stmt.WriteByte(',')
} }
boundVars[field.Name] = len(stmt.Vars)
stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()}) stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
} }
_, _ = stmt.WriteString(" /*-sql.Out{}-*/")
} }
} }
if !db.DryRun && db.Error == nil { if !db.DryRun {
if hasConflict { for idx, vals := range values.Values {
for i, val := range stmt.Vars { // HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned for idx, val := range vals {
stmt.Vars[i] = convertValue(val) switch v := val.(type) {
} case bool:
if v {
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...) val = 1
if db.AddError(err) == nil { } else {
db.RowsAffected, _ = result.RowsAffected() val = 0
// TODO: get merged returning
}
} else {
for idx, values := range createValues.Values {
for i, val := range values {
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
stmt.Vars[i] = convertValue(val)
}
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if hasDefaultValues {
getDefaultValues(db, idx)
} }
} }
stmt.Vars[idx] = val
} }
} // and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert)
} // we keep track of the index so that the sub-reflected value is also correct
}
}
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) { // BIG BUG: what if any of the transactions failed? some result might already be inserted that oracle is so
var dummyTable string // sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point,
switch d := ptrDereference(db.Dialector).(type) { // resulting in dangling row entries, so we might need to delete them if an error happens
case Dialector:
dummyTable = d.DummyTableName()
default:
dummyTable = "DUAL"
}
_, _ = db.Statement.WriteString("MERGE INTO ") switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err {
db.Statement.WriteQuoted(db.Statement.Table) case nil: // success
_, _ = db.Statement.WriteString(" USING (") db.RowsAffected, _ = result.RowsAffected()
for idx, value := range values.Values { insertTo := stmt.ReflectValue
if idx > 0 { switch insertTo.Kind() {
_, _ = db.Statement.WriteString(" UNION ALL ") case reflect.Slice, reflect.Array:
} insertTo = insertTo.Index(idx)
_, _ = db.Statement.WriteString("SELECT ")
for i, v := range value {
if i > 0 {
_ = db.Statement.WriteByte(',')
}
column := values.Columns[i]
db.Statement.AddVar(db.Statement, v)
_, _ = db.Statement.WriteString(" AS ")
db.Statement.WriteQuoted(column.Name)
}
_, _ = db.Statement.WriteString(" FROM ")
_, _ = db.Statement.WriteString(dummyTable)
}
_, _ = db.Statement.WriteString(`) `)
db.Statement.WriteQuoted("excluded")
_, _ = db.Statement.WriteString(" ON (")
var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "excluded", Name: field.DBName},
})
}
where.Build(db.Statement)
_ = db.Statement.WriteByte(')')
if len(onConflict.DoUpdates) > 0 {
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}
_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
written := false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(column.Name)
}
}
_, _ = db.Statement.WriteString(") VALUES (")
written = false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(clause.Column{
Table: "excluded",
Name: column.Name,
})
}
}
_, _ = db.Statement.WriteString(")")
}
func convertValue(val interface{}) interface{} {
val = ptrDereference(val)
switch v := val.(type) {
case bool:
if v {
val = 1
} else {
val = 0
}
// case string:
// if len(v) > 2000 {
// val = go_ora.Clob{String: v, Valid: true}
// }
default:
val = convertCustomType(val)
}
return val
}
func getDefaultValues(db *gorm.DB, idx int) {
insertTo := db.Statement.ReflectValue
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
insertTo = insertTo.Index(idx)
default:
}
if insertTo.Kind() == reflect.Pointer {
insertTo = insertTo.Elem()
}
for _, val := range db.Statement.Vars {
switch v := val.(type) {
case sql.Out:
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
for i := insertTo.Len() - 1; i >= 0; i-- {
rv := insertTo.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
} }
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv) if hasDefaultValues {
if isZero { // bind returning value back to reflected value in the respective fields
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest)) funk.ForEach(
funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool {
return funk.Contains(boundVars, field.Name)
}),
func(field *gormSchema.Field) {
switch insertTo.Kind() {
case reflect.Struct:
if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil {
db.AddError(err)
}
case reflect.Map:
// todo 设置id的值
}
},
)
} }
default: // failure
db.AddError(err)
} }
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
}
default:
} }
default:
} }
} }
} }

8
go.mod
View File

@ -4,10 +4,10 @@ go 1.22
require ( require (
github.com/emirpasic/gods v1.18.1 github.com/emirpasic/gods v1.18.1
github.com/godror/godror v0.42.2 github.com/godror/godror v0.43.0
github.com/sijms/go-ora/v2 v2.8.16
github.com/thoas/go-funk v0.9.3 github.com/thoas/go-funk v0.9.3
gorm.io/gorm v1.25.10 gorm.io/gorm v1.25.10
) )
require ( require (
@ -15,6 +15,6 @@ require (
github.com/godror/knownpb v0.1.1 // indirect github.com/godror/knownpb v0.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 // indirect golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
google.golang.org/protobuf v1.33.0 // indirect google.golang.org/protobuf v1.34.1 // indirect
) )

10
go.sum
View File

@ -8,8 +8,8 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/godror/godror v0.42.2 h1:TmOV0fr4jJxwDD6vWtSieQFOZ75LnSbkJaudfuJOsVQ= github.com/godror/godror v0.43.0 h1:qMbQwG0ejJnKma3bBvrJg1rkiyP5b4v6uxvx9zDrKJw=
github.com/godror/godror v0.42.2/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94= github.com/godror/godror v0.43.0/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94=
github.com/godror/knownpb v0.1.1 h1:A4J7jdx7jWBhJm18NntafzSC//iZDHkDi1+juwQ5pTI= github.com/godror/knownpb v0.1.1 h1:A4J7jdx7jWBhJm18NntafzSC//iZDHkDi1+juwQ5pTI=
github.com/godror/knownpb v0.1.1/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE= github.com/godror/knownpb v0.1.1/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
@ -22,8 +22,6 @@ github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68= github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sijms/go-ora/v2 v2.8.16 h1:XiBAqHjoXUnJnGnRoHD/6NJdiYQ2pr0gPciJM8BE+70=
github.com/sijms/go-ora/v2 v2.8.16/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
@ -31,6 +29,8 @@ github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 h1:6R2FC06FonbXQ8pK11/PDFY6N6LWlf9KlzibaCapmqc= golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 h1:6R2FC06FonbXQ8pK11/PDFY6N6LWlf9KlzibaCapmqc=
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4= golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4=
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
@ -39,6 +39,8 @@ golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@ -3,9 +3,8 @@ package oracle
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"strings"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -276,11 +275,11 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
// https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm // https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
// ALTER INDEX oldindex RENAME TO newindex; panic("TODO")
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec( return m.DB.Exec(
"ALTER INDEX ? RENAME TO ?", "ALTER INDEX ?.? RENAME TO ?", // wat
clause.Column{Name: oldName}, clause.Column{Name: newName}, clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error ).Error
}) })
} }

View File

@ -1,59 +1,38 @@
package oracle package oracle
import ( import (
"fmt"
"strings"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"strings"
) )
var _ schema.Namer = Namer{}
type Namer struct { type Namer struct {
NamingStrategy schema.Namer schema.NamingStrategy
DBName string
} }
func (n Namer) ConvertNameToFormat(x string) string { func ConvertNameToFormat(x string) string {
return strings.ToUpper(x) return strings.ToUpper(x)
} }
func (n Namer) TableName(table string) (name string) { func (n Namer) TableName(table string) (name string) {
println("TableName", table) return ConvertNameToFormat(n.NamingStrategy.TableName(table))
tableName := n.ConvertNameToFormat(n.NamingStrategy.TableName(table))
if len(n.DBName) > 0 {
return fmt.Sprintf("%s.%s", n.DBName, tableName)
}
return tableName
} }
func (n Namer) ColumnName(table, column string) (name string) { func (n Namer) ColumnName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column)) return ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
} }
func (n Namer) JoinTableName(table string) (name string) { func (n Namer) JoinTableName(table string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table)) return ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
} }
func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) { func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship)) return ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
} }
func (n Namer) CheckerName(table, column string) (name string) { func (n Namer) CheckerName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column)) return ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
} }
func (n Namer) IndexName(table, column string) (name string) { func (n Namer) IndexName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column)) return ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
}
func (n Namer) SchemaName(table string) string {
println("SchemaName", table)
return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table))
}
func (n Namer) UniqueName(table, column string) string {
println("UniqueName", table, column)
return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column))
} }

269
oracle.go
View File

@ -4,10 +4,14 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"log"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"gorm.io/gorm/utils"
_ "github.com/godror/godror"
"github.com/thoas/go-funk" "github.com/thoas/go-funk"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/callbacks" "gorm.io/gorm/callbacks"
@ -15,75 +19,61 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"gorm.io/gorm/migrator" "gorm.io/gorm/migrator"
"gorm.io/gorm/schema" "gorm.io/gorm/schema"
"gorm.io/gorm/utils"
_ "github.com/godror/godror"
_ "github.com/sijms/go-ora/v2"
) )
const (
RowNumberAliasForOracle11 = "ROW_NUM"
)
func Hello() string {
return "GORM Oracle Driver"
}
type Config struct { type Config struct {
DriverName string DriverName string
DSN string DSN string
Conn *sql.DB Conn gorm.ConnPool //*sql.DB
DefaultStringSize uint DefaultStringSize uint
DBName string
DBVer string DBVer string
DBName string // 库名
} }
type Dialector struct { type Dialector struct {
*Config *Config
} }
func New(config Config) gorm.Dialector {
return &Dialector{Config: &config}
}
func Open(dsn string) gorm.Dialector { func Open(dsn string) gorm.Dialector {
return &Dialector{Config: &Config{DSN: dsn}} return &Dialector{Config: &Config{DSN: dsn}}
} }
func (d Dialector) DummyTableName() string { return "DUAL" } func New(config Config) gorm.Dialector {
return &Dialector{Config: &config}
}
func (d Dialector) Name() string { return "oracle" } func (d Dialector) DummyTableName() string {
return "DUAL"
}
func (d Dialector) Name() string {
return "oracle"
}
func (d Dialector) Initialize(db *gorm.DB) (err error) { func (d Dialector) Initialize(db *gorm.DB) (err error) {
db.NamingStrategy = Namer{ db.NamingStrategy = Namer{db.NamingStrategy.(schema.NamingStrategy)}
NamingStrategy: db.NamingStrategy,
DBName: d.DBName,
}
d.DefaultStringSize = 1024 d.DefaultStringSize = 1024
// register callbacks // register callbacks
callbackConfig := &callbacks.Config{} //callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: true})
callbacks.RegisterDefaultCallbacks(db, callbackConfig) callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
})
// d.DriverName = "godror" d.DriverName = "godror"
d.DriverName = "oracle"
if d.Conn != nil { if d.Conn != nil {
db.ConnPool = d.Conn db.ConnPool = d.Conn
} else { } else {
db.ConnPool, err = sql.Open(d.DriverName, d.DSN) db.ConnPool, err = sql.Open(d.DriverName, d.DSN)
if err != nil {
return
}
} }
err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer) err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer)
if err != nil { if err != nil {
return err return err
} }
//log.Println("DBver:" + d.DBVer)
db.Logger.Info(context.Background(), "DBVer:%s", d.DBVer)
if err = db.Callback().Create().Replace("gorm:create", Create); err != nil { if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
return return
} }
@ -91,24 +81,26 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
for k, v := range d.ClauseBuilders() { for k, v := range d.ClauseBuilders() {
db.ClauseBuilders[k] = v db.ClauseBuilders[k] = v
} }
return return
} }
func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) { func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
clauseBuilders = make(map[string]clause.ClauseBuilder) dbver, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0])
if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 { if dbver > 0 && dbver < 12 {
clauseBuilders["LIMIT"] = d.RewriteLimit return map[string]clause.ClauseBuilder{
"LIMIT": d.RewriteLimit11,
}
} else { } else {
clauseBuilders["LIMIT"] = d.RewriteLimit11 return map[string]clause.ClauseBuilder{
"LIMIT": d.RewriteLimit,
}
} }
return
} }
func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) { func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok { if limit, ok := c.Expression.(clause.Limit); ok {
if stmt, ok := builder.(*gorm.Statement); ok { if stmt, ok := builder.(*gorm.Statement); ok {
if _, ok := stmt.Clauses["ORDER BY"]; !ok { if _, ok := stmt.Clauses["ORDER BY"]; !ok {
s := stmt.Schema s := stmt.Schema
@ -137,100 +129,38 @@ func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
} }
} }
// Oracle11 Limit
func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) { func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) {
println("rewrite limit oracle 11g") if limit, ok := c.Expression.(clause.Limit); ok {
limit, ok := c.Expression.(clause.Limit) if stmt, ok := builder.(*gorm.Statement); ok {
if !ok { limitsql := strings.Builder{}
return if limit := limit.Limit; *limit > 0 {
} if _, ok := stmt.Clauses["WHERE"]; !ok {
limitsql.WriteString(" WHERE ")
offsetRows := limit.Offset } else {
hasOffset := offsetRows > 0 limitsql.WriteString(" AND ")
limitRows, hasLimit := d.getLimitRows(limit)
if (!hasOffset && !hasLimit) || limitRows == 1 {
return
}
var stmt *gorm.Statement
if stmt, ok = builder.(*gorm.Statement); !ok {
return
}
if hasLimit && hasOffset {
subQuerySQL := fmt.Sprintf(
"SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY %s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d",
d.getOrderByColumns(stmt),
RowNumberAliasForOracle11,
strings.TrimSpace(stmt.SQL.String()),
RowNumberAliasForOracle11,
offsetRows+1,
offsetRows+limitRows,
)
stmt.SQL.Reset()
stmt.SQL.WriteString(subQuerySQL)
} else if hasLimit {
// only limit
d.rewriteRownumStmt(stmt, builder, " <= ", limitRows)
} else {
// only offset
d.rewriteRownumStmt(stmt, builder, " > ", offsetRows)
}
}
func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) {
limitSql := strings.Builder{}
if _, ok := stmt.Clauses["WHERE"]; !ok {
limitSql.WriteString(" WHERE ")
} else {
limitSql.WriteString(" AND ")
}
limitSql.WriteString("ROWNUM")
limitSql.WriteString(operator)
limitSql.WriteString(strconv.Itoa(rows))
if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy {
_, _ = builder.WriteString(limitSql.String())
} else {
// "ORDER BY" before insert
sqlTmp := strings.Builder{}
sqlOld := stmt.SQL.String()
orderIndex := strings.Index(sqlOld, "ORDER BY") - 1
sqlTmp.WriteString(sqlOld[:orderIndex])
sqlTmp.WriteString(limitSql.String())
sqlTmp.WriteString(sqlOld[orderIndex:])
stmt.SQL = sqlTmp
}
}
func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string {
if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok {
var orderBy clause.OrderBy
if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 {
println("order columns:", orderBy.Columns)
orderByBuilder := strings.Builder{}
for i, column := range orderBy.Columns {
if i > 0 {
orderByBuilder.WriteString(", ")
}
orderByBuilder.WriteString(column.Column.Name)
if column.Desc {
orderByBuilder.WriteString(" DESC")
} }
limitsql.WriteString("ROWNUM <= ")
limitsql.WriteString(strconv.Itoa(*limit))
}
if _, ok := stmt.Clauses["ORDER BY"]; !ok {
builder.WriteString(limitsql.String())
} else {
// "ORDER BY" before insert
sqltmp := strings.Builder{}
sqlold := stmt.SQL.String()
orderindx := strings.Index(sqlold, "ORDER BY") - 1
sqltmp.WriteString(sqlold[:orderindx])
sqltmp.WriteString(limitsql.String())
sqltmp.WriteString(sqlold[orderindx:])
log.Println(sqltmp.String())
stmt.SQL = sqltmp
} }
return orderByBuilder.String()
} }
} }
return "NULL"
} }
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) { return clause.Expr{SQL: "VALUES (DEFAULT)"}
if l := limit.Limit; l != nil {
limitRows = *l
hasLimit = limitRows > 0
}
return
} }
func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator { func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
@ -245,8 +175,35 @@ func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
} }
} }
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteString(":")
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
}
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString(str)
}
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
func (d Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} {
switch v := v.(type) {
case bool:
if v {
return 1
}
return 0
default:
return v
}
}).([]interface{})...)
}
func (d Dialector) DataTypeOf(field *schema.Field) string { func (d Dialector) DataTypeOf(field *schema.Field) string {
delete(field.TagSettings, "RESTRICT") if _, found := field.TagSettings["RESTRICT"]; found {
delete(field.TagSettings, "RESTRICT")
}
var sqlType string var sqlType string
@ -264,7 +221,7 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) { if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
sqlType += " GENERATED BY DEFAULT AS IDENTITY" sqlType += " GENERATED BY DEFAULT AS IDENTITY"
} }
case schema.String: case schema.String, "VARCHAR2":
size := field.Size size := field.Size
defaultSize := d.DefaultStringSize defaultSize := d.DefaultStringSize
@ -288,13 +245,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
case schema.Time: case schema.Time:
sqlType = "TIMESTAMP WITH TIME ZONE" sqlType = "TIMESTAMP WITH TIME ZONE"
if field.NotNull || field.PrimaryKey {
sqlType += " NOT NULL"
}
case schema.Bytes: case schema.Bytes:
sqlType = "BLOB" sqlType = "BLOB"
default: default:
sqlType := string(field.DataType) sqlType = string(field.DataType)
if strings.EqualFold(sqlType, "text") { if strings.EqualFold(sqlType, "text") {
sqlType = "CLOB" sqlType = "CLOB"
@ -304,51 +259,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String())) panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String()))
} }
notNull := field.TagSettings["NOT NULL"]
unique := field.TagSettings["UNIQUE"]
additionalType := fmt.Sprintf("%s %s", notNull, unique)
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string {
if value, ok := field.TagSettings["COMMENT"]; ok {
return " COMMENT " + value
}
return ""
}())
}
sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType
} }
return sqlType return sqlType
} }
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
return clause.Expr{SQL: "VALUES (DEFAULT)"}
}
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {}
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString(str)
}
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
func (d Dialector) Explain(sql string, vars ...any) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v any) any {
switch v := v.(type) {
case bool:
if v {
return 1
}
return 0
default:
return v
}
}).([]any)...)
}
func (d Dialector) SavePoint(tx *gorm.DB, name string) error { func (d Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name) tx.Exec("SAVEPOINT " + name)
return tx.Error return tx.Error