This commit is contained in:
2024-05-10 10:27:21 +08:00
parent 90f807defc
commit 184a8722d8
9 changed files with 285 additions and 423 deletions

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
.idea
vendor/
go.sum
CHANGELOG.md
/test_local/
/go.work*

25
License Normal file
View File

@ -0,0 +1,25 @@
The MIT License (MIT)
Copyright (c) 2013-NOW
Jinzhu <wosmvp@gmail.com>,
Steve Fan <stevefan1999-personal@github>,
CengSin <zephone0724@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

39
README.md Normal file
View File

@ -0,0 +1,39 @@
# GORM Oracle Driver
## Description
GORM Oracle driver for connect Oracle DB and Manage Oracle DB, Based on [CengSin/oracle](https://github.com/CengSin/oracle)
not recommended for use in a production environment
## Required dependency Install
- Oracle 12C+
- Golang 1.13+
- see [ODPI-C Installation.](https://oracle.github.io/odpi/doc/installation.html)
- gorm 1.24.0+
## Quick Start
### how to install
```bash
go get github.com/dzwvip/oracle
```
### usage
```go
import (
"fmt"
"github.com/dzwvip/oracle"
"gorm.io/gorm"
"log"
)
func main() {
db, err := gorm.Open(oracle.Open("system/oracle@127.0.0.1:1521/XE"), &gorm.Config{})
if err != nil {
// panic error or log error info
}
// do somethings
}
```

289
create.go
View File

@ -1,203 +1,102 @@
package oracle
import (
"bytes"
"database/sql"
"reflect"
"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
gormSchema "gorm.io/gorm/schema"
"git.charlienet.top/go/oracle/clauses"
)
func Create(db *gorm.DB) {
if db.Error != nil {
return
}
stmt := db.Statement
if stmt == nil {
schema := stmt.Schema
boundVars := make(map[string]int)
if stmt == nil || schema == nil {
return
}
stmtSchema := stmt.Schema
if stmtSchema == nil {
return
}
hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0
if !stmt.Unscoped {
for _, c := range stmtSchema.CreateClauses {
for _, c := range schema.CreateClauses {
stmt.AddClause(c)
}
}
if stmt.SQL.Len() == 0 {
var (
createValues = callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
)
if stmt.SQL.String() == "" {
values := callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
// are all columns in value the primary fields in schema only?
if hasConflict && funk.Contains(
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }),
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }),
) {
stmt.AddClauseIfNotExists(clauses.Merge{
Using: []clause.Interface{
clause.Select{
Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column {
// HACK: I can not come up with a better alternative for now
// I want to add a value to the list of variable and then capture the bind variable position as well
buf := bytes.NewBufferString("")
stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)])
stmt.BindVarTo(buf, stmt, nil)
if hasConflict {
if len(stmtSchema.PrimaryFields) > 0 {
columnsMap := map[string]bool{}
for _, column := range createValues.Columns {
columnsMap[column.Name] = true
column.Alias = column.Name
// then the captured bind var will be the name
column.Name = buf.String()
return column
}).([]clause.Column),
},
clause.From{
Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}},
},
},
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression {
return clause.Eq{
Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
}
}).([]clause.Expression),
})
stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})
for _, field := range stmtSchema.PrimaryFields {
if _, ok := columnsMap[field.DBName]; !ok {
hasConflict = false
}
}
} else {
hasConflict = false
}
}
hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0
if hasConflict {
MergeCreate(db, onConflict, createValues)
stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED")
} else {
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})
stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}})
if hasDefaultValues {
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue))
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
columns[idx] = clause.Column{Name: field.DBName}
}
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns})
stmt.AddClauseIfNotExists(clause.Returning{
Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column {
return clause.Column{Name: field.DBName}
}).([]clause.Column),
})
}
stmt.Build("INSERT", "VALUES", "RETURNING")
if hasDefaultValues {
_, _ = stmt.WriteString(" INTO ")
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
stmt.WriteString(" INTO ")
for idx, field := range schema.FieldsWithDefaultDBValue {
if idx > 0 {
_ = stmt.WriteByte(',')
stmt.WriteByte(',')
}
boundVars[field.Name] = len(stmt.Vars)
stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
}
_, _ = stmt.WriteString(" /*-sql.Out{}-*/")
}
}
if !db.DryRun && db.Error == nil {
if hasConflict {
for i, val := range stmt.Vars {
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
stmt.Vars[i] = convertValue(val)
}
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
// TODO: get merged returning
}
} else {
for idx, values := range createValues.Values {
for i, val := range values {
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
stmt.Vars[i] = convertValue(val)
}
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if hasDefaultValues {
getDefaultValues(db, idx)
}
}
}
}
}
}
}
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
var dummyTable string
switch d := ptrDereference(db.Dialector).(type) {
case Dialector:
dummyTable = d.DummyTableName()
default:
dummyTable = "DUAL"
}
_, _ = db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
_, _ = db.Statement.WriteString(" USING (")
for idx, value := range values.Values {
if idx > 0 {
_, _ = db.Statement.WriteString(" UNION ALL ")
}
_, _ = db.Statement.WriteString("SELECT ")
for i, v := range value {
if i > 0 {
_ = db.Statement.WriteByte(',')
}
column := values.Columns[i]
db.Statement.AddVar(db.Statement, v)
_, _ = db.Statement.WriteString(" AS ")
db.Statement.WriteQuoted(column.Name)
}
_, _ = db.Statement.WriteString(" FROM ")
_, _ = db.Statement.WriteString(dummyTable)
}
_, _ = db.Statement.WriteString(`) `)
db.Statement.WriteQuoted("excluded")
_, _ = db.Statement.WriteString(" ON (")
var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "excluded", Name: field.DBName},
})
}
where.Build(db.Statement)
_ = db.Statement.WriteByte(')')
if len(onConflict.DoUpdates) > 0 {
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}
_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
written := false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(column.Name)
}
}
_, _ = db.Statement.WriteString(") VALUES (")
written = false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(clause.Column{
Table: "excluded",
Name: column.Name,
})
}
}
_, _ = db.Statement.WriteString(")")
}
func convertValue(val interface{}) interface{} {
val = ptrDereference(val)
if !db.DryRun {
for idx, vals := range values.Values {
// HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
for idx, val := range vals {
switch v := val.(type) {
case bool:
if v {
@ -205,51 +104,49 @@ func convertValue(val interface{}) interface{} {
} else {
val = 0
}
// case string:
// if len(v) > 2000 {
// val = go_ora.Clob{String: v, Valid: true}
// }
default:
val = convertCustomType(val)
}
return val
}
func getDefaultValues(db *gorm.DB, idx int) {
insertTo := db.Statement.ReflectValue
stmt.Vars[idx] = val
}
// and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert)
// we keep track of the index so that the sub-reflected value is also correct
// BIG BUG: what if any of the transactions failed? some result might already be inserted that oracle is so
// sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point,
// resulting in dangling row entries, so we might need to delete them if an error happens
switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err {
case nil: // success
db.RowsAffected, _ = result.RowsAffected()
insertTo := stmt.ReflectValue
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
insertTo = insertTo.Index(idx)
default:
}
if insertTo.Kind() == reflect.Pointer {
insertTo = insertTo.Elem()
}
for _, val := range db.Statement.Vars {
switch v := val.(type) {
case sql.Out:
if hasDefaultValues {
// bind returning value back to reflected value in the respective fields
funk.ForEach(
funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool {
return funk.Contains(boundVars, field.Name)
}),
func(field *gormSchema.Field) {
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
for i := insertTo.Len() - 1; i >= 0; i-- {
rv := insertTo.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest))
}
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil {
db.AddError(err)
}
case reflect.Map:
// todo 设置id的值
}
},
)
}
default: // failure
db.AddError(err)
}
default:
}
default:
}
}
}

8
go.mod
View File

@ -4,10 +4,10 @@ go 1.22
require (
github.com/emirpasic/gods v1.18.1
github.com/godror/godror v0.42.2
github.com/sijms/go-ora/v2 v2.8.16
github.com/godror/godror v0.43.0
github.com/thoas/go-funk v0.9.3
gorm.io/gorm v1.25.10
)
require (
@ -15,6 +15,6 @@ require (
github.com/godror/knownpb v0.1.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 // indirect
google.golang.org/protobuf v1.33.0 // indirect
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)

10
go.sum
View File

@ -8,8 +8,8 @@ github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/godror/godror v0.42.2 h1:TmOV0fr4jJxwDD6vWtSieQFOZ75LnSbkJaudfuJOsVQ=
github.com/godror/godror v0.42.2/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94=
github.com/godror/godror v0.43.0 h1:qMbQwG0ejJnKma3bBvrJg1rkiyP5b4v6uxvx9zDrKJw=
github.com/godror/godror v0.43.0/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94=
github.com/godror/knownpb v0.1.1 h1:A4J7jdx7jWBhJm18NntafzSC//iZDHkDi1+juwQ5pTI=
github.com/godror/knownpb v0.1.1/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE=
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
@ -22,8 +22,6 @@ github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sijms/go-ora/v2 v2.8.16 h1:XiBAqHjoXUnJnGnRoHD/6NJdiYQ2pr0gPciJM8BE+70=
github.com/sijms/go-ora/v2 v2.8.16/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
@ -31,6 +29,8 @@ github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 h1:6R2FC06FonbXQ8pK11/PDFY6N6LWlf9KlzibaCapmqc=
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4=
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
@ -39,6 +39,8 @@ golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@ -3,9 +3,8 @@ package oracle
import (
"database/sql"
"fmt"
"strings"
"gorm.io/gorm/schema"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
@ -276,11 +275,11 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
// https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
// ALTER INDEX oldindex RENAME TO newindex;
panic("TODO")
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
"ALTER INDEX ? RENAME TO ?",
clause.Column{Name: oldName}, clause.Column{Name: newName},
"ALTER INDEX ?.? RENAME TO ?", // wat
clause.Table{Name: stmt.Table}, clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
})
}

View File

@ -1,59 +1,38 @@
package oracle
import (
"fmt"
"strings"
"gorm.io/gorm/schema"
"strings"
)
var _ schema.Namer = Namer{}
type Namer struct {
NamingStrategy schema.Namer
DBName string
schema.NamingStrategy
}
func (n Namer) ConvertNameToFormat(x string) string {
func ConvertNameToFormat(x string) string {
return strings.ToUpper(x)
}
func (n Namer) TableName(table string) (name string) {
println("TableName", table)
tableName := n.ConvertNameToFormat(n.NamingStrategy.TableName(table))
if len(n.DBName) > 0 {
return fmt.Sprintf("%s.%s", n.DBName, tableName)
}
return tableName
return ConvertNameToFormat(n.NamingStrategy.TableName(table))
}
func (n Namer) ColumnName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
return ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
}
func (n Namer) JoinTableName(table string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
return ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
}
func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
return ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
}
func (n Namer) CheckerName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
return ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
}
func (n Namer) IndexName(table, column string) (name string) {
return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
}
func (n Namer) SchemaName(table string) string {
println("SchemaName", table)
return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table))
}
func (n Namer) UniqueName(table, column string) string {
println("UniqueName", table, column)
return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column))
return ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
}

259
oracle.go
View File

@ -4,10 +4,14 @@ import (
"context"
"database/sql"
"fmt"
"log"
"regexp"
"strconv"
"strings"
"gorm.io/gorm/utils"
_ "github.com/godror/godror"
"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
@ -15,75 +19,61 @@ import (
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
_ "github.com/godror/godror"
_ "github.com/sijms/go-ora/v2"
)
const (
RowNumberAliasForOracle11 = "ROW_NUM"
)
func Hello() string {
return "GORM Oracle Driver"
}
type Config struct {
DriverName string
DSN string
Conn *sql.DB
Conn gorm.ConnPool //*sql.DB
DefaultStringSize uint
DBName string
DBVer string
DBName string // 库名
}
type Dialector struct {
*Config
}
func New(config Config) gorm.Dialector {
return &Dialector{Config: &config}
}
func Open(dsn string) gorm.Dialector {
return &Dialector{Config: &Config{DSN: dsn}}
}
func (d Dialector) DummyTableName() string { return "DUAL" }
func New(config Config) gorm.Dialector {
return &Dialector{Config: &config}
}
func (d Dialector) Name() string { return "oracle" }
func (d Dialector) DummyTableName() string {
return "DUAL"
}
func (d Dialector) Name() string {
return "oracle"
}
func (d Dialector) Initialize(db *gorm.DB) (err error) {
db.NamingStrategy = Namer{
NamingStrategy: db.NamingStrategy,
DBName: d.DBName,
}
db.NamingStrategy = Namer{db.NamingStrategy.(schema.NamingStrategy)}
d.DefaultStringSize = 1024
// register callbacks
callbackConfig := &callbacks.Config{}
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
//callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{WithReturning: true})
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
})
// d.DriverName = "godror"
d.DriverName = "oracle"
d.DriverName = "godror"
if d.Conn != nil {
db.ConnPool = d.Conn
} else {
db.ConnPool, err = sql.Open(d.DriverName, d.DSN)
if err != nil {
return
}
}
err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer)
if err != nil {
return err
}
db.Logger.Info(context.Background(), "DBVer:%s", d.DBVer)
//log.Println("DBver:" + d.DBVer)
if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
return
}
@ -91,24 +81,26 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
for k, v := range d.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) {
clauseBuilders = make(map[string]clause.ClauseBuilder)
if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 {
clauseBuilders["LIMIT"] = d.RewriteLimit
func (d Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
dbver, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0])
if dbver > 0 && dbver < 12 {
return map[string]clause.ClauseBuilder{
"LIMIT": d.RewriteLimit11,
}
} else {
clauseBuilders["LIMIT"] = d.RewriteLimit11
return map[string]clause.ClauseBuilder{
"LIMIT": d.RewriteLimit,
}
}
return
}
func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok {
if stmt, ok := builder.(*gorm.Statement); ok {
if _, ok := stmt.Clauses["ORDER BY"]; !ok {
s := stmt.Schema
@ -137,100 +129,38 @@ func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
}
}
// Oracle11 Limit
func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) {
println("rewrite limit oracle 11g")
limit, ok := c.Expression.(clause.Limit)
if !ok {
return
}
offsetRows := limit.Offset
hasOffset := offsetRows > 0
limitRows, hasLimit := d.getLimitRows(limit)
if (!hasOffset && !hasLimit) || limitRows == 1 {
return
}
var stmt *gorm.Statement
if stmt, ok = builder.(*gorm.Statement); !ok {
return
}
if hasLimit && hasOffset {
subQuerySQL := fmt.Sprintf(
"SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY %s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d",
d.getOrderByColumns(stmt),
RowNumberAliasForOracle11,
strings.TrimSpace(stmt.SQL.String()),
RowNumberAliasForOracle11,
offsetRows+1,
offsetRows+limitRows,
)
stmt.SQL.Reset()
stmt.SQL.WriteString(subQuerySQL)
} else if hasLimit {
// only limit
d.rewriteRownumStmt(stmt, builder, " <= ", limitRows)
} else {
// only offset
d.rewriteRownumStmt(stmt, builder, " > ", offsetRows)
}
}
func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) {
limitSql := strings.Builder{}
if limit, ok := c.Expression.(clause.Limit); ok {
if stmt, ok := builder.(*gorm.Statement); ok {
limitsql := strings.Builder{}
if limit := limit.Limit; *limit > 0 {
if _, ok := stmt.Clauses["WHERE"]; !ok {
limitSql.WriteString(" WHERE ")
limitsql.WriteString(" WHERE ")
} else {
limitSql.WriteString(" AND ")
limitsql.WriteString(" AND ")
}
limitSql.WriteString("ROWNUM")
limitSql.WriteString(operator)
limitSql.WriteString(strconv.Itoa(rows))
if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy {
_, _ = builder.WriteString(limitSql.String())
limitsql.WriteString("ROWNUM <= ")
limitsql.WriteString(strconv.Itoa(*limit))
}
if _, ok := stmt.Clauses["ORDER BY"]; !ok {
builder.WriteString(limitsql.String())
} else {
// "ORDER BY" before insert
sqlTmp := strings.Builder{}
sqlOld := stmt.SQL.String()
orderIndex := strings.Index(sqlOld, "ORDER BY") - 1
sqlTmp.WriteString(sqlOld[:orderIndex])
sqlTmp.WriteString(limitSql.String())
sqlTmp.WriteString(sqlOld[orderIndex:])
stmt.SQL = sqlTmp
sqltmp := strings.Builder{}
sqlold := stmt.SQL.String()
orderindx := strings.Index(sqlold, "ORDER BY") - 1
sqltmp.WriteString(sqlold[:orderindx])
sqltmp.WriteString(limitsql.String())
sqltmp.WriteString(sqlold[orderindx:])
log.Println(sqltmp.String())
stmt.SQL = sqltmp
}
}
func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string {
if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok {
var orderBy clause.OrderBy
if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 {
println("order columns:", orderBy.Columns)
orderByBuilder := strings.Builder{}
for i, column := range orderBy.Columns {
if i > 0 {
orderByBuilder.WriteString(", ")
}
orderByBuilder.WriteString(column.Column.Name)
if column.Desc {
orderByBuilder.WriteString(" DESC")
}
}
return orderByBuilder.String()
}
}
return "NULL"
}
func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) {
if l := limit.Limit; l != nil {
limitRows = *l
hasLimit = limitRows > 0
}
return
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
return clause.Expr{SQL: "VALUES (DEFAULT)"}
}
func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
@ -245,8 +175,35 @@ func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
}
}
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteString(":")
writer.WriteString(strconv.Itoa(len(stmt.Vars)))
}
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString(str)
}
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
func (d Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v interface{}) interface{} {
switch v := v.(type) {
case bool:
if v {
return 1
}
return 0
default:
return v
}
}).([]interface{})...)
}
func (d Dialector) DataTypeOf(field *schema.Field) string {
if _, found := field.TagSettings["RESTRICT"]; found {
delete(field.TagSettings, "RESTRICT")
}
var sqlType string
@ -264,7 +221,7 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
sqlType += " GENERATED BY DEFAULT AS IDENTITY"
}
case schema.String:
case schema.String, "VARCHAR2":
size := field.Size
defaultSize := d.DefaultStringSize
@ -288,13 +245,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
case schema.Time:
sqlType = "TIMESTAMP WITH TIME ZONE"
if field.NotNull || field.PrimaryKey {
sqlType += " NOT NULL"
}
case schema.Bytes:
sqlType = "BLOB"
default:
sqlType := string(field.DataType)
sqlType = string(field.DataType)
if strings.EqualFold(sqlType, "text") {
sqlType = "CLOB"
@ -304,51 +259,11 @@ func (d Dialector) DataTypeOf(field *schema.Field) string {
panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String()))
}
notNull := field.TagSettings["NOT NULL"]
unique := field.TagSettings["UNIQUE"]
additionalType := fmt.Sprintf("%s %s", notNull, unique)
if value, ok := field.TagSettings["DEFAULT"]; ok {
additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string {
if value, ok := field.TagSettings["COMMENT"]; ok {
return " COMMENT " + value
}
return ""
}())
}
sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
return sqlType
}
return sqlType
}
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
return clause.Expr{SQL: "VALUES (DEFAULT)"}
}
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {}
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteString(str)
}
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
func (d Dialector) Explain(sql string, vars ...any) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v any) any {
switch v := v.(type) {
case bool:
if v {
return 1
}
return 0
default:
return v
}
}).([]any)...)
}
func (d Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return tx.Error