update
This commit is contained in:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
.idea
|
||||
vendor/
|
||||
go.sum
|
||||
CHANGELOG.md
|
||||
/test_local/
|
||||
/go.work*
|
25
License
Normal file
25
License
Normal file
@ -0,0 +1,25 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW
|
||||
|
||||
Jinzhu <wosmvp@gmail.com>,
|
||||
Steve Fan <stevefan1999-personal@github>,
|
||||
CengSin <zephone0724@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
39
README.md
Normal file
39
README.md
Normal file
@ -0,0 +1,39 @@
|
||||
# GORM Oracle Driver
|
||||
|
||||
|
||||
## Description
|
||||
|
||||
GORM Oracle driver for connect Oracle DB and Manage Oracle DB, Based on [CengSin/oracle](https://github.com/CengSin/oracle)
|
||||
,not recommended for use in a production environment
|
||||
|
||||
## Required dependency Install
|
||||
|
||||
- Oracle 12C+
|
||||
- Golang 1.13+
|
||||
- see [ODPI-C Installation.](https://oracle.github.io/odpi/doc/installation.html)
|
||||
- gorm 1.24.0+
|
||||
|
||||
## Quick Start
|
||||
### how to install
|
||||
```bash
|
||||
go get github.com/dzwvip/oracle
|
||||
```
|
||||
### usage
|
||||
|
||||
```go
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/dzwvip/oracle"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := gorm.Open(oracle.Open("system/oracle@127.0.0.1:1521/XE"), &gorm.Config{})
|
||||
if err != nil {
|
||||
// panic error or log error info
|
||||
}
|
||||
|
||||
// do somethings
|
||||
}
|
||||
```
|
303
create.go
303
create.go
@ -1,255 +1,152 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/thoas/go-funk"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
gormSchema "gorm.io/gorm/schema"
|
||||
|
||||
"git.charlienet.top/go/oracle/clauses"
|
||||
)
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stmt := db.Statement
|
||||
if stmt == nil {
|
||||
schema := stmt.Schema
|
||||
boundVars := make(map[string]int)
|
||||
|
||||
if stmt == nil || schema == nil {
|
||||
return
|
||||
}
|
||||
|
||||
stmtSchema := stmt.Schema
|
||||
if stmtSchema == nil {
|
||||
return
|
||||
}
|
||||
hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0
|
||||
|
||||
if !stmt.Unscoped {
|
||||
for _, c := range stmtSchema.CreateClauses {
|
||||
for _, c := range schema.CreateClauses {
|
||||
stmt.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() == 0 {
|
||||
var (
|
||||
createValues = callbacks.ConvertToCreateValues(stmt)
|
||||
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
|
||||
)
|
||||
if stmt.SQL.String() == "" {
|
||||
values := callbacks.ConvertToCreateValues(stmt)
|
||||
onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
|
||||
// are all columns in value the primary fields in schema only?
|
||||
if hasConflict && funk.Contains(
|
||||
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }),
|
||||
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }),
|
||||
) {
|
||||
stmt.AddClauseIfNotExists(clauses.Merge{
|
||||
Using: []clause.Interface{
|
||||
clause.Select{
|
||||
Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column {
|
||||
// HACK: I can not come up with a better alternative for now
|
||||
// I want to add a value to the list of variable and then capture the bind variable position as well
|
||||
buf := bytes.NewBufferString("")
|
||||
stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)])
|
||||
stmt.BindVarTo(buf, stmt, nil)
|
||||
|
||||
if hasConflict {
|
||||
if len(stmtSchema.PrimaryFields) > 0 {
|
||||
columnsMap := map[string]bool{}
|
||||
for _, column := range createValues.Columns {
|
||||
columnsMap[column.Name] = true
|
||||
}
|
||||
|
||||
for _, field := range stmtSchema.PrimaryFields {
|
||||
if _, ok := columnsMap[field.DBName]; !ok {
|
||||
hasConflict = false
|
||||
column.Alias = column.Name
|
||||
// then the captured bind var will be the name
|
||||
column.Name = buf.String()
|
||||
return column
|
||||
}).([]clause.Column),
|
||||
},
|
||||
clause.From{
|
||||
Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}},
|
||||
},
|
||||
},
|
||||
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression {
|
||||
return clause.Eq{
|
||||
Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName},
|
||||
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
|
||||
}
|
||||
}
|
||||
} else {
|
||||
hasConflict = false
|
||||
}
|
||||
}
|
||||
}).([]clause.Expression),
|
||||
})
|
||||
stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
|
||||
stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})
|
||||
|
||||
hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0
|
||||
if hasConflict {
|
||||
MergeCreate(db, onConflict, createValues)
|
||||
stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED")
|
||||
} else {
|
||||
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
|
||||
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})
|
||||
|
||||
stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}})
|
||||
if hasDefaultValues {
|
||||
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue))
|
||||
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
|
||||
columns[idx] = clause.Column{Name: field.DBName}
|
||||
}
|
||||
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns})
|
||||
stmt.AddClauseIfNotExists(clause.Returning{
|
||||
Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column {
|
||||
return clause.Column{Name: field.DBName}
|
||||
}).([]clause.Column),
|
||||
})
|
||||
}
|
||||
stmt.Build("INSERT", "VALUES", "RETURNING")
|
||||
|
||||
if hasDefaultValues {
|
||||
_, _ = stmt.WriteString(" INTO ")
|
||||
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
|
||||
stmt.WriteString(" INTO ")
|
||||
for idx, field := range schema.FieldsWithDefaultDBValue {
|
||||
if idx > 0 {
|
||||
_ = stmt.WriteByte(',')
|
||||
stmt.WriteByte(',')
|
||||
}
|
||||
boundVars[field.Name] = len(stmt.Vars)
|
||||
stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
|
||||
}
|
||||
_, _ = stmt.WriteString(" /*-sql.Out{}-*/")
|
||||
}
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if hasConflict {
|
||||
for i, val := range stmt.Vars {
|
||||
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
|
||||
stmt.Vars[i] = convertValue(val)
|
||||
}
|
||||
|
||||
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
// TODO: get merged returning
|
||||
}
|
||||
} else {
|
||||
for idx, values := range createValues.Values {
|
||||
for i, val := range values {
|
||||
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
|
||||
stmt.Vars[i] = convertValue(val)
|
||||
}
|
||||
|
||||
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if hasDefaultValues {
|
||||
getDefaultValues(db, idx)
|
||||
if !db.DryRun {
|
||||
for idx, vals := range values.Values {
|
||||
// HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
|
||||
for idx, val := range vals {
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
val = 1
|
||||
} else {
|
||||
val = 0
|
||||
}
|
||||
}
|
||||
|
||||
stmt.Vars[idx] = val
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert)
|
||||
// we keep track of the index so that the sub-reflected value is also correct
|
||||
|
||||
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
|
||||
var dummyTable string
|
||||
switch d := ptrDereference(db.Dialector).(type) {
|
||||
case Dialector:
|
||||
dummyTable = d.DummyTableName()
|
||||
default:
|
||||
dummyTable = "DUAL"
|
||||
}
|
||||
// BIG BUG: what if any of the transactions failed? some result might already be inserted that oracle is so
|
||||
// sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point,
|
||||
// resulting in dangling row entries, so we might need to delete them if an error happens
|
||||
|
||||
_, _ = db.Statement.WriteString("MERGE INTO ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
_, _ = db.Statement.WriteString(" USING (")
|
||||
switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err {
|
||||
case nil: // success
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
for idx, value := range values.Values {
|
||||
if idx > 0 {
|
||||
_, _ = db.Statement.WriteString(" UNION ALL ")
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString("SELECT ")
|
||||
for i, v := range value {
|
||||
if i > 0 {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
column := values.Columns[i]
|
||||
db.Statement.AddVar(db.Statement, v)
|
||||
_, _ = db.Statement.WriteString(" AS ")
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
_, _ = db.Statement.WriteString(" FROM ")
|
||||
_, _ = db.Statement.WriteString(dummyTable)
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(`) `)
|
||||
db.Statement.WriteQuoted("excluded")
|
||||
_, _ = db.Statement.WriteString(" ON (")
|
||||
|
||||
var where clause.Where
|
||||
for _, field := range db.Statement.Schema.PrimaryFields {
|
||||
where.Exprs = append(where.Exprs, clause.Eq{
|
||||
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
|
||||
Value: clause.Column{Table: "excluded", Name: field.DBName},
|
||||
})
|
||||
}
|
||||
where.Build(db.Statement)
|
||||
_ = db.Statement.WriteByte(')')
|
||||
|
||||
if len(onConflict.DoUpdates) > 0 {
|
||||
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
|
||||
onConflict.DoUpdates.Build(db.Statement)
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
|
||||
|
||||
written := false
|
||||
for _, column := range values.Columns {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
|
||||
if written {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
written = true
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(") VALUES (")
|
||||
|
||||
written = false
|
||||
for _, column := range values.Columns {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
|
||||
if written {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
written = true
|
||||
db.Statement.WriteQuoted(clause.Column{
|
||||
Table: "excluded",
|
||||
Name: column.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
_, _ = db.Statement.WriteString(")")
|
||||
}
|
||||
|
||||
func convertValue(val interface{}) interface{} {
|
||||
val = ptrDereference(val)
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
val = 1
|
||||
} else {
|
||||
val = 0
|
||||
}
|
||||
// case string:
|
||||
// if len(v) > 2000 {
|
||||
// val = go_ora.Clob{String: v, Valid: true}
|
||||
// }
|
||||
default:
|
||||
val = convertCustomType(val)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func getDefaultValues(db *gorm.DB, idx int) {
|
||||
insertTo := db.Statement.ReflectValue
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
insertTo = insertTo.Index(idx)
|
||||
default:
|
||||
}
|
||||
if insertTo.Kind() == reflect.Pointer {
|
||||
insertTo = insertTo.Elem()
|
||||
}
|
||||
|
||||
for _, val := range db.Statement.Vars {
|
||||
switch v := val.(type) {
|
||||
case sql.Out:
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := insertTo.Len() - 1; i >= 0; i-- {
|
||||
rv := insertTo.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
insertTo := stmt.ReflectValue
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
insertTo = insertTo.Index(idx)
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest))
|
||||
if hasDefaultValues {
|
||||
// bind returning value back to reflected value in the respective fields
|
||||
funk.ForEach(
|
||||
funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool {
|
||||
return funk.Contains(boundVars, field.Name)
|
||||
}),
|
||||
func(field *gormSchema.Field) {
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Struct:
|
||||
if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil {
|
||||
db.AddError(err)
|
||||
}
|
||||
case reflect.Map:
|
||||
// todo 设置id的值
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
default: // failure
|
||||
db.AddError(err)
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
|
||||
if isZero {
|
||||
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
|
||||
}
|
||||
default:
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
8
go.mod
8
go.mod
@ -4,10 +4,10 @@ go 1.22
|
||||
|
||||
require (
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/godror/godror v0.42.2
|
||||
github.com/sijms/go-ora/v2 v2.8.16
|
||||
github.com/godror/godror v0.43.0
|
||||
github.com/thoas/go-funk v0.9.3
|
||||
gorm.io/gorm v1.25.10
|
||||
|
||||
)
|
||||
|
||||
require (
|
||||
@ -15,6 +15,6 @@ require (
|
||||
github.com/godror/knownpb v0.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
10
go.sum
10
go.sum
@ -8,8 +8,8 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/godror/godror v0.42.2 h1:TmOV0fr4jJxwDD6vWtSieQFOZ75LnSbkJaudfuJOsVQ=
|
||||
github.com/godror/godror v0.42.2/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94=
|
||||
github.com/godror/godror v0.43.0 h1:qMbQwG0ejJnKma3bBvrJg1rkiyP5b4v6uxvx9zDrKJw=
|
||||
github.com/godror/godror v0.43.0/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94=
|
||||
github.com/godror/knownpb v0.1.1 h1:A4J7jdx7jWBhJm18NntafzSC//iZDHkDi1+juwQ5pTI=
|
||||
github.com/godror/knownpb v0.1.1/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
@ -22,8 +22,6 @@ github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
|
||||
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sijms/go-ora/v2 v2.8.16 h1:XiBAqHjoXUnJnGnRoHD/6NJdiYQ2pr0gPciJM8BE+70=
|
||||
github.com/sijms/go-ora/v2 v2.8.16/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
@ -31,6 +29,8 @@ github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
|
||||
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
|
||||
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 h1:6R2FC06FonbXQ8pK11/PDFY6N6LWlf9KlzibaCapmqc=
|
||||
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4=
|
||||
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
@ -39,6 +39,8 @@ golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
|
||||
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
@ -3,9 +3,8 @@ package oracle
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
@ -276,11 +275,11 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
|
||||
// https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
// ALTER INDEX oldindex RENAME TO newindex;
|
||||
panic("TODO")
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Exec(
|
||||
"ALTER INDEX ? RENAME TO ?",
|
||||
clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
"ALTER INDEX ?.? RENAME TO ?", // wat
|
||||
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
39
namer.go
39
namer.go
@ -1,59 +1,38 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var _ schema.Namer = Namer{}
|
||||
|
||||
type Namer struct {
|
||||
NamingStrategy schema.Namer
|
||||
DBName string
|
||||
schema.NamingStrategy
|
||||
}
|
||||
|
||||
func (n Namer) ConvertNameToFormat(x string) string {
|
||||
func ConvertNameToFormat(x string) string {
|
||||
return strings.ToUpper(x)
|
||||
}
|
||||
|
||||
func (n Namer) TableName(table string) (name string) {
|
||||
println("TableName", table)
|
||||
|
||||
tableName := n.ConvertNameToFormat(n.NamingStrategy.TableName(table))
|
||||
if len(n.DBName) > 0 {
|
||||
return fmt.Sprintf("%s.%s", n.DBName, tableName)
|
||||
}
|
||||
return tableName
|
||||
return ConvertNameToFormat(n.NamingStrategy.TableName(table))
|
||||
}
|
||||
|
||||
func (n Namer) ColumnName(table, column string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
|
||||
return ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
|
||||
}
|
||||
|
||||
func (n Namer) JoinTableName(table string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
|
||||
return ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
|
||||
}
|
||||
|
||||
func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
|
||||
return ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
|
||||
}
|
||||
|
||||
func (n Namer) CheckerName(table, column string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
|
||||
return ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
|
||||
}
|
||||
|
||||
func (n Namer) IndexName(table, column string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
|
||||
}
|
||||
|
||||
func (n Namer) SchemaName(table string) string {
|
||||
println("SchemaName", table)
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table))
|
||||
}
|
||||
|
||||
func (n Namer) UniqueName(table, column string) string {
|
||||
println("UniqueName", table, column)
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column))
|
||||
return ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
|
||||
}
|
||||
|
269
oracle.go
269
oracle.go
@ -4,10 +4,14 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
|
||||
_ "github.com/godror/godror"
|
||||
"github.com/thoas/go-funk"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
@ -15,75 +19,61 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
|
||||
_ "github.com/godror/godror"
|
||||
_ "github.com/sijms/go-ora/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
RowNumberAliasForOracle11 = "ROW_NUM"
|
||||
)
|
||||
|
||||
func Hello() string {
|
||||
return "GORM Oracle Driver"
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
DriverName string
|
||||
DSN string
|
||||
Conn *sql.DB
|
||||
Conn gorm.ConnPool //*sql.DB
|
||||
DefaultStringSize uint
|
||||
DBName string
|
||||
DBVer string
|
||||
DBName string // 库名
|
||||
}
|
||||
|
||||
type Dialector struct {
|
||||
*Config
|
||||
}
|
||||
|
||||
func New(config Config) gorm.Dialector {
|
||||
return &Dialector{Config: &config}
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{Config: &Config{DSN: dsn}}
|
||||
}
|
||||
|
||||
func (d Dialector) DummyTableName() string { return "DUAL" }
|
||||
func New(config Config) gorm.Dialector {
|
||||
return &Dialector{Config: &config}
|
||||
}
|
||||
|
||||
func (d Dialector) Name() string { return "oracle" }
|
||||
func (d Dialector) DummyTableName() string {
|
||||
return "DUAL"
|
||||
}
|
||||
|
||||
func (d Dialector) Name() string {
|
||||
return "oracle"
|
||||
}
|
||||
|
||||
func (d Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
db.NamingStrategy = Namer{
|
||||
NamingStrategy: db.NamingStrategy,
|
||||
DBName: d.DBName,
|
||||
}
|
||||
db.NamingStrategy = Namer{db.NamingStrategy.(schema.NamingStrategy)}
|
||||
d.DefaultStringSize = 1024
|
||||
|
||||
// register callbacks
|
||||
callbackConfig := &callbacks.Config{}
|
||||
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
|
||||
//callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: true})
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
})
|
||||
|
||||
// d.DriverName = "godror"
|
||||
d.DriverName = "oracle"
|
||||
d.DriverName = "godror"
|
||||
|
||||
if d.Conn != nil {
|
||||
db.ConnPool = d.Conn
|
||||
} else {
|
||||
db.ConnPool, err = sql.Open(d.DriverName, d.DSN)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Info(context.Background(), "DBVer:%s", d.DBVer)
|
||||
|
||||
//log.Println("DBver:" + d.DBVer)
|
||||
if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
|
||||
return
|
||||
}
|
||||
@ -91,24 +81,26 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
for k, v := range d.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) {
|
||||
clauseBuilders = make(map[string]clause.ClauseBuilder)
|
||||
if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 {
|
||||
clauseBuilders["LIMIT"] = d.RewriteLimit
|
||||
func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
dbver, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0])
|
||||
if dbver > 0 && dbver < 12 {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"LIMIT": d.RewriteLimit11,
|
||||
}
|
||||
|
||||
} else {
|
||||
clauseBuilders["LIMIT"] = d.RewriteLimit11
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"LIMIT": d.RewriteLimit,
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
|
||||
if stmt, ok := builder.(*gorm.Statement); ok {
|
||||
if _, ok := stmt.Clauses["ORDER BY"]; !ok {
|
||||
s := stmt.Schema
|
||||
@ -137,100 +129,38 @@ func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
|
||||
}
|
||||
}
|
||||
|
||||
// Oracle11 Limit
|
||||
func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) {
|
||||
println("rewrite limit oracle 11g")
|
||||
limit, ok := c.Expression.(clause.Limit)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
offsetRows := limit.Offset
|
||||
hasOffset := offsetRows > 0
|
||||
limitRows, hasLimit := d.getLimitRows(limit)
|
||||
if (!hasOffset && !hasLimit) || limitRows == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var stmt *gorm.Statement
|
||||
if stmt, ok = builder.(*gorm.Statement); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if hasLimit && hasOffset {
|
||||
subQuerySQL := fmt.Sprintf(
|
||||
"SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY %s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d",
|
||||
d.getOrderByColumns(stmt),
|
||||
RowNumberAliasForOracle11,
|
||||
strings.TrimSpace(stmt.SQL.String()),
|
||||
RowNumberAliasForOracle11,
|
||||
offsetRows+1,
|
||||
offsetRows+limitRows,
|
||||
)
|
||||
stmt.SQL.Reset()
|
||||
stmt.SQL.WriteString(subQuerySQL)
|
||||
} else if hasLimit {
|
||||
// only limit
|
||||
d.rewriteRownumStmt(stmt, builder, " <= ", limitRows)
|
||||
} else {
|
||||
// only offset
|
||||
d.rewriteRownumStmt(stmt, builder, " > ", offsetRows)
|
||||
}
|
||||
}
|
||||
|
||||
func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) {
|
||||
limitSql := strings.Builder{}
|
||||
if _, ok := stmt.Clauses["WHERE"]; !ok {
|
||||
limitSql.WriteString(" WHERE ")
|
||||
} else {
|
||||
limitSql.WriteString(" AND ")
|
||||
}
|
||||
limitSql.WriteString("ROWNUM")
|
||||
limitSql.WriteString(operator)
|
||||
limitSql.WriteString(strconv.Itoa(rows))
|
||||
|
||||
if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy {
|
||||
_, _ = builder.WriteString(limitSql.String())
|
||||
} else {
|
||||
// "ORDER BY" before insert
|
||||
sqlTmp := strings.Builder{}
|
||||
sqlOld := stmt.SQL.String()
|
||||
orderIndex := strings.Index(sqlOld, "ORDER BY") - 1
|
||||
sqlTmp.WriteString(sqlOld[:orderIndex])
|
||||
sqlTmp.WriteString(limitSql.String())
|
||||
sqlTmp.WriteString(sqlOld[orderIndex:])
|
||||
stmt.SQL = sqlTmp
|
||||
}
|
||||
}
|
||||
|
||||
func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string {
|
||||
if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok {
|
||||
var orderBy clause.OrderBy
|
||||
|
||||
if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 {
|
||||
println("order columns:", orderBy.Columns)
|
||||
|
||||
orderByBuilder := strings.Builder{}
|
||||
for i, column := range orderBy.Columns {
|
||||
if i > 0 {
|
||||
orderByBuilder.WriteString(", ")
|
||||
}
|
||||
orderByBuilder.WriteString(column.Column.Name)
|
||||
if column.Desc {
|
||||
orderByBuilder.WriteString(" DESC")
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
if stmt, ok := builder.(*gorm.Statement); ok {
|
||||
limitsql := strings.Builder{}
|
||||
if limit := limit.Limit; *limit > 0 {
|
||||
if _, ok := stmt.Clauses["WHERE"]; !ok {
|
||||
limitsql.WriteString(" WHERE ")
|
||||
} else {
|
||||
limitsql.WriteString(" AND ")
|
||||
}
|
||||
limitsql.WriteString("ROWNUM <= ")
|
||||
limitsql.WriteString(strconv.Itoa(*limit))
|
||||
}
|
||||
if _, ok := stmt.Clauses["ORDER BY"]; !ok {
|
||||
builder.WriteString(limitsql.String())
|
||||
} else {
|
||||
// "ORDER BY" before insert
|
||||
sqltmp := strings.Builder{}
|
||||
sqlold := stmt.SQL.String()
|
||||
orderindx := strings.Index(sqlold, "ORDER BY") - 1
|
||||
sqltmp.WriteString(sqlold[:orderindx])
|
||||
sqltmp.WriteString(limitsql.String())
|
||||
sqltmp.WriteString(sqlold[orderindx:])
|
||||
log.Println(sqltmp.String())
|
||||
stmt.SQL = sqltmp
|
||||
}
|
||||
return orderByBuilder.String()
|
||||
}
|
||||
}
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) {
|
||||
if l := limit.Limit; l != nil {
|
||||
limitRows = *l
|
||||
hasLimit = limitRows > 0
|
||||
}
|
||||
return
|
||||
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
|
||||
return clause.Expr{SQL: "VALUES (DEFAULT)"}
|
||||
}
|
||||
|
||||
func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
@ -245,8 +175,35 @@ func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
}
|
||||
}
|
||||
|
||||
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteString(":")
|
||||
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
|
||||
}
|
||||
|
||||
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteString(str)
|
||||
}
|
||||
|
||||
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
|
||||
|
||||
func (d Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} {
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}).([]interface{})...)
|
||||
}
|
||||
|
||||
func (d Dialector) DataTypeOf(field *schema.Field) string {
|
||||
delete(field.TagSettings, "RESTRICT")
|
||||
if _, found := field.TagSettings["RESTRICT"]; found {
|
||||
delete(field.TagSettings, "RESTRICT")
|
||||
}
|
||||
|
||||
var sqlType string
|
||||
|
||||
@ -264,7 +221,7 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
|
||||
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
|
||||
sqlType += " GENERATED BY DEFAULT AS IDENTITY"
|
||||
}
|
||||
case schema.String:
|
||||
case schema.String, "VARCHAR2":
|
||||
size := field.Size
|
||||
defaultSize := d.DefaultStringSize
|
||||
|
||||
@ -288,13 +245,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
|
||||
|
||||
case schema.Time:
|
||||
sqlType = "TIMESTAMP WITH TIME ZONE"
|
||||
if field.NotNull || field.PrimaryKey {
|
||||
sqlType += " NOT NULL"
|
||||
}
|
||||
|
||||
case schema.Bytes:
|
||||
sqlType = "BLOB"
|
||||
default:
|
||||
sqlType := string(field.DataType)
|
||||
sqlType = string(field.DataType)
|
||||
|
||||
if strings.EqualFold(sqlType, "text") {
|
||||
sqlType = "CLOB"
|
||||
@ -304,51 +259,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String()))
|
||||
}
|
||||
|
||||
notNull := field.TagSettings["NOT NULL"]
|
||||
unique := field.TagSettings["UNIQUE"]
|
||||
additionalType := fmt.Sprintf("%s %s", notNull, unique)
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string {
|
||||
if value, ok := field.TagSettings["COMMENT"]; ok {
|
||||
return " COMMENT " + value
|
||||
}
|
||||
return ""
|
||||
}())
|
||||
}
|
||||
|
||||
sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
return sqlType
|
||||
}
|
||||
|
||||
return sqlType
|
||||
}
|
||||
|
||||
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
|
||||
return clause.Expr{SQL: "VALUES (DEFAULT)"}
|
||||
}
|
||||
|
||||
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {}
|
||||
|
||||
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteString(str)
|
||||
}
|
||||
|
||||
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
|
||||
|
||||
func (d Dialector) Explain(sql string, vars ...any) string {
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v any) any {
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}).([]any)...)
|
||||
}
|
||||
|
||||
func (d Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return tx.Error
|
||||
|
Reference in New Issue
Block a user