update
This commit is contained in:
parent
65b9f3795d
commit
90f807defc
49
clauses/merge.go
Normal file
49
clauses/merge.go
Normal file
@ -0,0 +1,49 @@
|
||||
package clauses
|
||||
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type Merge struct {
|
||||
Table clause.Table
|
||||
Using []clause.Interface
|
||||
On []clause.Expression
|
||||
}
|
||||
|
||||
func (merge Merge) Name() string {
|
||||
return "MERGE"
|
||||
}
|
||||
|
||||
func MergeDefaultExcludeName() string {
|
||||
return "exclude"
|
||||
}
|
||||
|
||||
// Build build from clause
|
||||
func (merge Merge) Build(builder clause.Builder) {
|
||||
clause.Insert{}.Build(builder)
|
||||
builder.WriteString(" USING (")
|
||||
for idx, iface := range merge.Using {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
builder.WriteString(iface.Name())
|
||||
builder.WriteByte(' ')
|
||||
iface.Build(builder)
|
||||
}
|
||||
builder.WriteString(") ")
|
||||
builder.WriteString(MergeDefaultExcludeName())
|
||||
builder.WriteString(" ON (")
|
||||
for idx, on := range merge.On {
|
||||
if idx > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
on.Build(builder)
|
||||
}
|
||||
builder.WriteString(")")
|
||||
}
|
||||
|
||||
// MergeClause merge values clauses
|
||||
func (merge Merge) MergeClause(clause *clause.Clause) {
|
||||
clause.Name = merge.Name()
|
||||
clause.Expression = merge
|
||||
}
|
10
clauses/returning_into.go
Normal file
10
clauses/returning_into.go
Normal file
@ -0,0 +1,10 @@
|
||||
package clauses
|
||||
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type ReturningInto struct {
|
||||
Variables []clause.Column
|
||||
Into []*clause.Values
|
||||
}
|
39
clauses/when_matched.go
Normal file
39
clauses/when_matched.go
Normal file
@ -0,0 +1,39 @@
|
||||
package clauses
|
||||
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type WhenMatched struct {
|
||||
clause.Set
|
||||
Where, Delete clause.Where
|
||||
}
|
||||
|
||||
func (w WhenMatched) Name() string {
|
||||
return "WHEN MATCHED"
|
||||
}
|
||||
|
||||
func (w WhenMatched) Build(builder clause.Builder) {
|
||||
if len(w.Set) > 0 {
|
||||
builder.WriteString(" THEN")
|
||||
builder.WriteString(" UPDATE ")
|
||||
builder.WriteString(w.Name())
|
||||
builder.WriteByte(' ')
|
||||
w.Build(builder)
|
||||
|
||||
buildWhere := func(where clause.Where) {
|
||||
builder.WriteString(where.Name())
|
||||
builder.WriteByte(' ')
|
||||
where.Build(builder)
|
||||
}
|
||||
|
||||
if len(w.Where.Exprs) > 0 {
|
||||
buildWhere(w.Where)
|
||||
}
|
||||
|
||||
if len(w.Delete.Exprs) > 0 {
|
||||
builder.WriteString(" DELETE ")
|
||||
buildWhere(w.Delete)
|
||||
}
|
||||
}
|
||||
}
|
32
clauses/when_not_matched.go
Normal file
32
clauses/when_not_matched.go
Normal file
@ -0,0 +1,32 @@
|
||||
package clauses
|
||||
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
type WhenNotMatched struct {
|
||||
clause.Values
|
||||
Where clause.Where
|
||||
}
|
||||
|
||||
func (w WhenNotMatched) Name() string {
|
||||
return "WHEN NOT MATCHED"
|
||||
}
|
||||
|
||||
func (w WhenNotMatched) Build(builder clause.Builder) {
|
||||
if len(w.Columns) > 0 {
|
||||
if len(w.Values.Values) != 1 {
|
||||
panic("cannot insert more than one rows due to Oracle SQL language restriction")
|
||||
}
|
||||
|
||||
builder.WriteString(" THEN")
|
||||
builder.WriteString(" INSERT ")
|
||||
w.Build(builder)
|
||||
|
||||
if len(w.Where.Exprs) > 0 {
|
||||
builder.WriteString(w.Where.Name())
|
||||
builder.WriteByte(' ')
|
||||
w.Where.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
255
create.go
Normal file
255
create.go
Normal file
@ -0,0 +1,255 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stmt := db.Statement
|
||||
if stmt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
stmtSchema := stmt.Schema
|
||||
if stmtSchema == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !stmt.Unscoped {
|
||||
for _, c := range stmtSchema.CreateClauses {
|
||||
stmt.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() == 0 {
|
||||
var (
|
||||
createValues = callbacks.ConvertToCreateValues(stmt)
|
||||
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
|
||||
)
|
||||
|
||||
if hasConflict {
|
||||
if len(stmtSchema.PrimaryFields) > 0 {
|
||||
columnsMap := map[string]bool{}
|
||||
for _, column := range createValues.Columns {
|
||||
columnsMap[column.Name] = true
|
||||
}
|
||||
|
||||
for _, field := range stmtSchema.PrimaryFields {
|
||||
if _, ok := columnsMap[field.DBName]; !ok {
|
||||
hasConflict = false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
hasConflict = false
|
||||
}
|
||||
}
|
||||
|
||||
hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0
|
||||
if hasConflict {
|
||||
MergeCreate(db, onConflict, createValues)
|
||||
} else {
|
||||
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
|
||||
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})
|
||||
|
||||
if hasDefaultValues {
|
||||
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue))
|
||||
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
|
||||
columns[idx] = clause.Column{Name: field.DBName}
|
||||
}
|
||||
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns})
|
||||
}
|
||||
stmt.Build("INSERT", "VALUES", "RETURNING")
|
||||
|
||||
if hasDefaultValues {
|
||||
_, _ = stmt.WriteString(" INTO ")
|
||||
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
|
||||
if idx > 0 {
|
||||
_ = stmt.WriteByte(',')
|
||||
}
|
||||
stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
|
||||
}
|
||||
_, _ = stmt.WriteString(" /*-sql.Out{}-*/")
|
||||
}
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if hasConflict {
|
||||
for i, val := range stmt.Vars {
|
||||
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
|
||||
stmt.Vars[i] = convertValue(val)
|
||||
}
|
||||
|
||||
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
// TODO: get merged returning
|
||||
}
|
||||
} else {
|
||||
for idx, values := range createValues.Values {
|
||||
for i, val := range values {
|
||||
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
|
||||
stmt.Vars[i] = convertValue(val)
|
||||
}
|
||||
|
||||
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if hasDefaultValues {
|
||||
getDefaultValues(db, idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
|
||||
var dummyTable string
|
||||
switch d := ptrDereference(db.Dialector).(type) {
|
||||
case Dialector:
|
||||
dummyTable = d.DummyTableName()
|
||||
default:
|
||||
dummyTable = "DUAL"
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString("MERGE INTO ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
_, _ = db.Statement.WriteString(" USING (")
|
||||
|
||||
for idx, value := range values.Values {
|
||||
if idx > 0 {
|
||||
_, _ = db.Statement.WriteString(" UNION ALL ")
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString("SELECT ")
|
||||
for i, v := range value {
|
||||
if i > 0 {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
column := values.Columns[i]
|
||||
db.Statement.AddVar(db.Statement, v)
|
||||
_, _ = db.Statement.WriteString(" AS ")
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
_, _ = db.Statement.WriteString(" FROM ")
|
||||
_, _ = db.Statement.WriteString(dummyTable)
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(`) `)
|
||||
db.Statement.WriteQuoted("excluded")
|
||||
_, _ = db.Statement.WriteString(" ON (")
|
||||
|
||||
var where clause.Where
|
||||
for _, field := range db.Statement.Schema.PrimaryFields {
|
||||
where.Exprs = append(where.Exprs, clause.Eq{
|
||||
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
|
||||
Value: clause.Column{Table: "excluded", Name: field.DBName},
|
||||
})
|
||||
}
|
||||
where.Build(db.Statement)
|
||||
_ = db.Statement.WriteByte(')')
|
||||
|
||||
if len(onConflict.DoUpdates) > 0 {
|
||||
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
|
||||
onConflict.DoUpdates.Build(db.Statement)
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
|
||||
|
||||
written := false
|
||||
for _, column := range values.Columns {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
|
||||
if written {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
written = true
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(") VALUES (")
|
||||
|
||||
written = false
|
||||
for _, column := range values.Columns {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
|
||||
if written {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
written = true
|
||||
db.Statement.WriteQuoted(clause.Column{
|
||||
Table: "excluded",
|
||||
Name: column.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
_, _ = db.Statement.WriteString(")")
|
||||
}
|
||||
|
||||
func convertValue(val interface{}) interface{} {
|
||||
val = ptrDereference(val)
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
val = 1
|
||||
} else {
|
||||
val = 0
|
||||
}
|
||||
// case string:
|
||||
// if len(v) > 2000 {
|
||||
// val = go_ora.Clob{String: v, Valid: true}
|
||||
// }
|
||||
default:
|
||||
val = convertCustomType(val)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func getDefaultValues(db *gorm.DB, idx int) {
|
||||
insertTo := db.Statement.ReflectValue
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
insertTo = insertTo.Index(idx)
|
||||
default:
|
||||
}
|
||||
if insertTo.Kind() == reflect.Pointer {
|
||||
insertTo = insertTo.Elem()
|
||||
}
|
||||
|
||||
for _, val := range db.Statement.Vars {
|
||||
switch v := val.(type) {
|
||||
case sql.Out:
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := insertTo.Len() - 1; i >= 0; i-- {
|
||||
rv := insertTo.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
|
||||
if isZero {
|
||||
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
|
||||
}
|
||||
default:
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
17
go.mod
17
go.mod
@ -1,3 +1,20 @@
|
||||
module git.charlienet.top/go/oracle
|
||||
|
||||
go 1.22
|
||||
|
||||
require (
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/godror/godror v0.42.2
|
||||
github.com/sijms/go-ora/v2 v2.8.16
|
||||
github.com/thoas/go-funk v0.9.3
|
||||
gorm.io/gorm v1.25.10
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/godror/knownpb v0.1.1 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
)
|
||||
|
46
go.sum
Normal file
46
go.sum
Normal file
@ -0,0 +1,46 @@
|
||||
github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4=
|
||||
github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
|
||||
github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/godror/godror v0.42.2 h1:TmOV0fr4jJxwDD6vWtSieQFOZ75LnSbkJaudfuJOsVQ=
|
||||
github.com/godror/godror v0.42.2/go.mod h1:82Uc/HdjsFVnzR5c9Yf6IkTBalK80jzm/U6xojbTo94=
|
||||
github.com/godror/knownpb v0.1.1 h1:A4J7jdx7jWBhJm18NntafzSC//iZDHkDi1+juwQ5pTI=
|
||||
github.com/godror/knownpb v0.1.1/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE=
|
||||
github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
|
||||
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sijms/go-ora/v2 v2.8.16 h1:XiBAqHjoXUnJnGnRoHD/6NJdiYQ2pr0gPciJM8BE+70=
|
||||
github.com/sijms/go-ora/v2 v2.8.16/go.mod h1:EHxlY6x7y9HAsdfumurRfTd+v8NrEOTR3Xl4FWlH6xk=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw=
|
||||
github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q=
|
||||
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81 h1:6R2FC06FonbXQ8pK11/PDFY6N6LWlf9KlzibaCapmqc=
|
||||
golang.org/x/exp v0.0.0-20240318143956-a85f2c67cd81/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
|
||||
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29 h1:w8s32wxx3sY+OjLlv9qltkLU5yvJzxjjgiHWLjdIcw4=
|
||||
golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c=
|
||||
golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o=
|
||||
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
|
||||
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
|
||||
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
325
migrator.go
Normal file
325
migrator.go
Normal file
@ -0,0 +1,325 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
m.DB.Raw(
|
||||
fmt.Sprintf(`SELECT ORA_DATABASE_NAME as "Current Database" FROM %s`, m.Dialector.(Dialector).DummyTableName()),
|
||||
).Row().Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range values {
|
||||
m.TryQuotifyReservedWords(value)
|
||||
m.TryRemoveOnUpdate(value)
|
||||
}
|
||||
return m.Migrator.CreateTable(values...)
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
values = m.ReorderModels(values, false)
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
value := values[i]
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if m.HasTable(value) {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int64
|
||||
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
|
||||
ownertable := strings.Split(stmt.Schema.Table, ".")
|
||||
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? and TABLE_NAME = ?", ownertable[0], ownertable[1]).Row().Scan(&count)
|
||||
} else {
|
||||
return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", stmt.Table).Row().Scan(&count)
|
||||
}
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Schema.Table).Where("ROWNUM = 1").Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = rows.Close()
|
||||
}()
|
||||
|
||||
var rawColumnTypes []*sql.ColumnType
|
||||
rawColumnTypes, err = rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
columnTypes = append(columnTypes, migrator.ColumnType{SQLColumnType: c})
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
func (m Migrator) RenameTable(oldName, newName interface{}) (err error) {
|
||||
resolveTable := func(name interface{}) (result string, err error) {
|
||||
if v, ok := name.(string); ok {
|
||||
result = v
|
||||
} else {
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if err = stmt.Parse(name); err == nil {
|
||||
result = stmt.Table
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var oldTable, newTable string
|
||||
|
||||
if oldTable, err = resolveTable(oldName); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if newTable, err = resolveTable(newName); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !m.HasTable(oldTable) {
|
||||
return
|
||||
}
|
||||
|
||||
return m.DB.Exec("RENAME TABLE ? TO ?",
|
||||
clause.Table{Name: oldTable},
|
||||
clause.Table{Name: newTable},
|
||||
).Error
|
||||
}
|
||||
|
||||
func (m Migrator) AddColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ADD ? ?",
|
||||
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: field.DBName}, m.DB.Migrator().FullDataTypeOf(field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
if !m.HasColumn(value, name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? DROP ?",
|
||||
clause.Table{Name: stmt.Schema.Table},
|
||||
clause.Column{Name: name},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
if !m.HasColumn(value, field) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? MODIFY ? ?",
|
||||
clause.Table{Name: stmt.Schema.Table},
|
||||
clause.Column{Name: field.DBName},
|
||||
m.AlterDataTypeOf(stmt, field),
|
||||
).Error
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
var count int64
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
|
||||
ownertable := strings.Split(stmt.Schema.Table, ".")
|
||||
return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownertable[0], ownertable[1], field).Row().Scan(&count)
|
||||
} else {
|
||||
return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field).Row().Scan(&count)
|
||||
}
|
||||
|
||||
}) == nil && count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterDataTypeOf(stmt *gorm.Statement, field *schema.Field) (expr clause.Expr) {
|
||||
expr.SQL = m.DataTypeOf(field)
|
||||
|
||||
var nullable = ""
|
||||
if stmt.Schema != nil && strings.Contains(stmt.Schema.Table, ".") {
|
||||
ownertable := strings.Split(stmt.Schema.Table, ".")
|
||||
m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", ownertable[0], ownertable[1], field.DBName).Row().Scan(&nullable)
|
||||
} else {
|
||||
m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, field.DBName).Row().Scan(&nullable)
|
||||
}
|
||||
if field.NotNull && nullable == "Y" {
|
||||
expr.SQL += " NOT NULL"
|
||||
}
|
||||
|
||||
if field.Unique {
|
||||
expr.SQL += " UNIQUE"
|
||||
}
|
||||
|
||||
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
||||
if field.DefaultValueInterface != nil {
|
||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
|
||||
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
|
||||
} else if field.DefaultValue != "(-)" {
|
||||
expr.SQL += " DEFAULT " + field.DefaultValue
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
m.TryRemoveOnUpdate(value)
|
||||
return m.Migrator.CreateConstraint(value, name)
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
if chk.Name == name {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? DROP CHECK ?",
|
||||
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name},
|
||||
).Error
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? DROP CONSTRAINT ?",
|
||||
clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: name},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw(
|
||||
"SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
}) == nil && count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}, clause.Table{Name: stmt.Schema.Table}).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ? AND INDEX_NAME = ?",
|
||||
m.Migrator.DB.NamingStrategy.TableName(stmt.Table),
|
||||
m.Migrator.DB.NamingStrategy.IndexName(stmt.Table, name),
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
// ALTER INDEX oldindex RENAME TO newindex;
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Exec(
|
||||
"ALTER INDEX ? RENAME TO ?",
|
||||
clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error {
|
||||
for _, value := range values {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
constraint := rel.ParseConstraint()
|
||||
if constraint != nil {
|
||||
rel.Field.TagSettings["CONSTRAINT"] = strings.ReplaceAll(rel.Field.TagSettings["CONSTRAINT"], fmt.Sprintf("ON UPDATE %s", constraint.OnUpdate), "")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error {
|
||||
for _, value := range values {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
for idx, v := range stmt.Schema.DBNames {
|
||||
if IsReservedWord(v) {
|
||||
stmt.Schema.DBNames[idx] = fmt.Sprintf(`"%s"`, v)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range stmt.Schema.Fields {
|
||||
if IsReservedWord(v.DBName) {
|
||||
v.DBName = fmt.Sprintf(`"%s"`, v.DBName)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
59
namer.go
Normal file
59
namer.go
Normal file
@ -0,0 +1,59 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var _ schema.Namer = Namer{}
|
||||
|
||||
type Namer struct {
|
||||
NamingStrategy schema.Namer
|
||||
DBName string
|
||||
}
|
||||
|
||||
func (n Namer) ConvertNameToFormat(x string) string {
|
||||
return strings.ToUpper(x)
|
||||
}
|
||||
|
||||
func (n Namer) TableName(table string) (name string) {
|
||||
println("TableName", table)
|
||||
|
||||
tableName := n.ConvertNameToFormat(n.NamingStrategy.TableName(table))
|
||||
if len(n.DBName) > 0 {
|
||||
return fmt.Sprintf("%s.%s", n.DBName, tableName)
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
func (n Namer) ColumnName(table, column string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.ColumnName(table, column))
|
||||
}
|
||||
|
||||
func (n Namer) JoinTableName(table string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.JoinTableName(table))
|
||||
}
|
||||
|
||||
func (n Namer) RelationshipFKName(relationship schema.Relationship) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.RelationshipFKName(relationship))
|
||||
}
|
||||
|
||||
func (n Namer) CheckerName(table, column string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.CheckerName(table, column))
|
||||
}
|
||||
|
||||
func (n Namer) IndexName(table, column string) (name string) {
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.IndexName(table, column))
|
||||
}
|
||||
|
||||
func (n Namer) SchemaName(table string) string {
|
||||
println("SchemaName", table)
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.SchemaName(table))
|
||||
}
|
||||
|
||||
func (n Namer) UniqueName(table, column string) string {
|
||||
println("UniqueName", table, column)
|
||||
return n.ConvertNameToFormat(n.NamingStrategy.UniqueName(table, column))
|
||||
}
|
355
oracle.go
355
oracle.go
@ -1,5 +1,360 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/thoas/go-funk"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
|
||||
_ "github.com/godror/godror"
|
||||
_ "github.com/sijms/go-ora/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
RowNumberAliasForOracle11 = "ROW_NUM"
|
||||
)
|
||||
|
||||
func Hello() string {
|
||||
return "GORM Oracle Driver"
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
DriverName string
|
||||
DSN string
|
||||
Conn *sql.DB
|
||||
DefaultStringSize uint
|
||||
DBVer string
|
||||
DBName string // 库名
|
||||
}
|
||||
|
||||
type Dialector struct {
|
||||
*Config
|
||||
}
|
||||
|
||||
func New(config Config) gorm.Dialector {
|
||||
return &Dialector{Config: &config}
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{Config: &Config{DSN: dsn}}
|
||||
}
|
||||
|
||||
func (d Dialector) DummyTableName() string { return "DUAL" }
|
||||
|
||||
func (d Dialector) Name() string { return "oracle" }
|
||||
|
||||
func (d Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
db.NamingStrategy = Namer{
|
||||
NamingStrategy: db.NamingStrategy,
|
||||
DBName: d.DBName,
|
||||
}
|
||||
d.DefaultStringSize = 1024
|
||||
|
||||
// register callbacks
|
||||
callbackConfig := &callbacks.Config{}
|
||||
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
|
||||
|
||||
// d.DriverName = "godror"
|
||||
d.DriverName = "oracle"
|
||||
|
||||
if d.Conn != nil {
|
||||
db.ConnPool = d.Conn
|
||||
} else {
|
||||
db.ConnPool, err = sql.Open(d.DriverName, d.DSN)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = db.ConnPool.QueryRowContext(context.Background(), "select version from product_component_version where rownum = 1").Scan(&d.DBVer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.Logger.Info(context.Background(), "DBVer:%s", d.DBVer)
|
||||
|
||||
if err = db.Callback().Create().Replace("gorm:create", Create); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range d.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d Dialector) ClauseBuilders() (clauseBuilders map[string]clause.ClauseBuilder) {
|
||||
clauseBuilders = make(map[string]clause.ClauseBuilder)
|
||||
if dbVer, _ := strconv.Atoi(strings.Split(d.DBVer, ".")[0]); dbVer > 11 {
|
||||
clauseBuilders["LIMIT"] = d.RewriteLimit
|
||||
} else {
|
||||
clauseBuilders["LIMIT"] = d.RewriteLimit11
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (d Dialector) RewriteLimit(c clause.Clause, builder clause.Builder) {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
|
||||
if stmt, ok := builder.(*gorm.Statement); ok {
|
||||
if _, ok := stmt.Clauses["ORDER BY"]; !ok {
|
||||
s := stmt.Schema
|
||||
builder.WriteString("ORDER BY ")
|
||||
if s != nil && s.PrioritizedPrimaryField != nil {
|
||||
builder.WriteQuoted(s.PrioritizedPrimaryField.DBName)
|
||||
builder.WriteByte(' ')
|
||||
} else {
|
||||
builder.WriteString("(SELECT NULL FROM ")
|
||||
builder.WriteString(d.DummyTableName())
|
||||
builder.WriteString(")")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if offset := limit.Offset; offset > 0 {
|
||||
builder.WriteString(" OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(offset))
|
||||
builder.WriteString(" ROWS")
|
||||
}
|
||||
if limit := limit.Limit; *limit > 0 {
|
||||
builder.WriteString(" FETCH NEXT ")
|
||||
builder.WriteString(strconv.Itoa(*limit))
|
||||
builder.WriteString(" ROWS ONLY")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d Dialector) RewriteLimit11(c clause.Clause, builder clause.Builder) {
|
||||
println("rewrite limit oracle 11g")
|
||||
limit, ok := c.Expression.(clause.Limit)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
offsetRows := limit.Offset
|
||||
hasOffset := offsetRows > 0
|
||||
limitRows, hasLimit := d.getLimitRows(limit)
|
||||
if (!hasOffset && !hasLimit) || limitRows == 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var stmt *gorm.Statement
|
||||
if stmt, ok = builder.(*gorm.Statement); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if hasLimit && hasOffset {
|
||||
subQuerySQL := fmt.Sprintf(
|
||||
"SELECT * FROM (SELECT T.*, ROW_NUMBER() OVER (ORDER BY %s) AS %s FROM (%s) T) WHERE %s BETWEEN %d AND %d",
|
||||
d.getOrderByColumns(stmt),
|
||||
RowNumberAliasForOracle11,
|
||||
strings.TrimSpace(stmt.SQL.String()),
|
||||
RowNumberAliasForOracle11,
|
||||
offsetRows+1,
|
||||
offsetRows+limitRows,
|
||||
)
|
||||
stmt.SQL.Reset()
|
||||
stmt.SQL.WriteString(subQuerySQL)
|
||||
} else if hasLimit {
|
||||
// only limit
|
||||
d.rewriteRownumStmt(stmt, builder, " <= ", limitRows)
|
||||
} else {
|
||||
// only offset
|
||||
d.rewriteRownumStmt(stmt, builder, " > ", offsetRows)
|
||||
}
|
||||
}
|
||||
|
||||
func (d Dialector) rewriteRownumStmt(stmt *gorm.Statement, builder clause.Builder, operator string, rows int) {
|
||||
limitSql := strings.Builder{}
|
||||
if _, ok := stmt.Clauses["WHERE"]; !ok {
|
||||
limitSql.WriteString(" WHERE ")
|
||||
} else {
|
||||
limitSql.WriteString(" AND ")
|
||||
}
|
||||
limitSql.WriteString("ROWNUM")
|
||||
limitSql.WriteString(operator)
|
||||
limitSql.WriteString(strconv.Itoa(rows))
|
||||
|
||||
if _, hasOrderBy := stmt.Clauses["ORDER BY"]; !hasOrderBy {
|
||||
_, _ = builder.WriteString(limitSql.String())
|
||||
} else {
|
||||
// "ORDER BY" before insert
|
||||
sqlTmp := strings.Builder{}
|
||||
sqlOld := stmt.SQL.String()
|
||||
orderIndex := strings.Index(sqlOld, "ORDER BY") - 1
|
||||
sqlTmp.WriteString(sqlOld[:orderIndex])
|
||||
sqlTmp.WriteString(limitSql.String())
|
||||
sqlTmp.WriteString(sqlOld[orderIndex:])
|
||||
stmt.SQL = sqlTmp
|
||||
}
|
||||
}
|
||||
|
||||
func (d Dialector) getOrderByColumns(stmt *gorm.Statement) string {
|
||||
if orderByClause, ok := stmt.Clauses["ORDER BY"]; ok {
|
||||
var orderBy clause.OrderBy
|
||||
|
||||
if orderBy, ok = orderByClause.Expression.(clause.OrderBy); ok && len(orderBy.Columns) > 0 {
|
||||
println("order columns:", orderBy.Columns)
|
||||
|
||||
orderByBuilder := strings.Builder{}
|
||||
for i, column := range orderBy.Columns {
|
||||
if i > 0 {
|
||||
orderByBuilder.WriteString(", ")
|
||||
}
|
||||
orderByBuilder.WriteString(column.Column.Name)
|
||||
if column.Desc {
|
||||
orderByBuilder.WriteString(" DESC")
|
||||
}
|
||||
}
|
||||
return orderByBuilder.String()
|
||||
}
|
||||
}
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
func (d Dialector) getLimitRows(limit clause.Limit) (limitRows int, hasLimit bool) {
|
||||
if l := limit.Limit; l != nil {
|
||||
limitRows = *l
|
||||
hasLimit = limitRows > 0
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{
|
||||
Migrator: migrator.Migrator{
|
||||
Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: d,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (d Dialector) DataTypeOf(field *schema.Field) string {
|
||||
delete(field.TagSettings, "RESTRICT")
|
||||
|
||||
var sqlType string
|
||||
|
||||
switch field.DataType {
|
||||
case schema.Bool, schema.Int, schema.Uint, schema.Float:
|
||||
sqlType = "INTEGER"
|
||||
|
||||
switch {
|
||||
case field.DataType == schema.Float:
|
||||
sqlType = "FLOAT"
|
||||
case field.Size <= 8:
|
||||
sqlType = "SMALLINT"
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["AUTOINCREMENT"]; ok && utils.CheckTruth(val) {
|
||||
sqlType += " GENERATED BY DEFAULT AS IDENTITY"
|
||||
}
|
||||
case schema.String:
|
||||
size := field.Size
|
||||
defaultSize := d.DefaultStringSize
|
||||
|
||||
if size == 0 {
|
||||
if defaultSize > 0 {
|
||||
size = int(defaultSize)
|
||||
} else {
|
||||
hasIndex := field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUE"] != ""
|
||||
// TEXT, GEOMETRY or JSON column can't have a default value
|
||||
if field.PrimaryKey || field.HasDefaultValue || hasIndex {
|
||||
size = 191 // utf8mb4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if size >= 2000 {
|
||||
sqlType = "CLOB"
|
||||
} else {
|
||||
sqlType = fmt.Sprintf("VARCHAR2(%d)", size)
|
||||
}
|
||||
|
||||
case schema.Time:
|
||||
sqlType = "TIMESTAMP WITH TIME ZONE"
|
||||
if field.NotNull || field.PrimaryKey {
|
||||
sqlType += " NOT NULL"
|
||||
}
|
||||
case schema.Bytes:
|
||||
sqlType = "BLOB"
|
||||
default:
|
||||
sqlType := string(field.DataType)
|
||||
|
||||
if strings.EqualFold(sqlType, "text") {
|
||||
sqlType = "CLOB"
|
||||
}
|
||||
|
||||
if sqlType == "" {
|
||||
panic(fmt.Sprintf("invalid sql type %s (%s) for oracle", field.FieldType.Name(), field.FieldType.String()))
|
||||
}
|
||||
|
||||
notNull := field.TagSettings["NOT NULL"]
|
||||
unique := field.TagSettings["UNIQUE"]
|
||||
additionalType := fmt.Sprintf("%s %s", notNull, unique)
|
||||
if value, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
additionalType = fmt.Sprintf("%s %s %s%s", "DEFAULT", value, additionalType, func() string {
|
||||
if value, ok := field.TagSettings["COMMENT"]; ok {
|
||||
return " COMMENT " + value
|
||||
}
|
||||
return ""
|
||||
}())
|
||||
}
|
||||
|
||||
sqlType = fmt.Sprintf("%v %v", sqlType, additionalType)
|
||||
return sqlType
|
||||
}
|
||||
|
||||
return sqlType
|
||||
}
|
||||
|
||||
func (d Dialector) DefaultValueOf(*schema.Field) clause.Expression {
|
||||
return clause.Expr{SQL: "VALUES (DEFAULT)"}
|
||||
}
|
||||
|
||||
func (d Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v any) {}
|
||||
|
||||
func (d Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteString(str)
|
||||
}
|
||||
|
||||
var numericPlaceholder = regexp.MustCompile(`:(\d+)`)
|
||||
|
||||
func (d Dialector) Explain(sql string, vars ...any) string {
|
||||
return logger.ExplainSQL(sql, numericPlaceholder, `'`, funk.Map(vars, func(v any) any {
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}).([]any)...)
|
||||
}
|
||||
|
||||
func (d Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
func (d Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return tx.Error
|
||||
}
|
||||
|
28
reserved.go
Normal file
28
reserved.go
Normal file
@ -0,0 +1,28 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"github.com/emirpasic/gods/sets/hashset"
|
||||
"github.com/thoas/go-funk"
|
||||
)
|
||||
|
||||
var ReservedWords = hashset.New(funk.Map(ReservedWordsList, func(s string) interface{} { return s }).([]interface{})...)
|
||||
|
||||
func IsReservedWord(v string) bool {
|
||||
return ReservedWords.Contains(v)
|
||||
}
|
||||
|
||||
var ReservedWordsList = []string{
|
||||
"AGGREGATE", "AGGREGATES", "ALL", "ALLOW", "ANALYZE", "ANCESTOR", "AND", "ANY", "AS", "ASC", "AT", "AVG", "BETWEEN",
|
||||
"BINARY_DOUBLE", "BINARY_FLOAT", "BLOB", "BRANCH", "BUILD", "BY", "BYTE", "CASE", "CAST", "CHAR", "CHILD", "CLEAR",
|
||||
"CLOB", "COMMIT", "COMPILE", "CONSIDER", "COUNT", "DATATYPE", "DATE", "DATE_MEASURE", "DAY", "DECIMAL", "DELETE",
|
||||
"DESC", "DESCENDANT", "DIMENSION", "DISALLOW", "DIVISION", "DML", "ELSE", "END", "ESCAPE", "EXECUTE", "FIRST",
|
||||
"FLOAT", "FOR", "FROM", "HIERARCHIES", "HIERARCHY", "HOUR", "IGNORE", "IN", "INFINITE", "INSERT", "INTEGER",
|
||||
"INTERVAL", "INTO", "IS", "LAST", "LEAF_DESCENDANT", "LEAVES", "LEVEL", "LIKE", "LIKEC", "LIKE2", "LIKE4", "LOAD",
|
||||
"LOCAL", "LOG_SPEC", "LONG", "MAINTAIN", "MAX", "MEASURE", "MEASURES", "MEMBER", "MEMBERS", "MERGE", "MLSLABEL",
|
||||
"MIN", "MINUTE", "MODEL", "MONTH", "NAN", "NCHAR", "NCLOB", "NO", "NONE", "NOT", "NULL", "NULLS", "NUMBER",
|
||||
"NVARCHAR2", "OF", "OLAP", "OLAP_DML_EXPRESSION", "ON", "ONLY", "OPERATOR", "OR", "ORDER", "OVER", "OVERFLOW",
|
||||
"PARALLEL", "PARENT", "PLSQL", "PRUNE", "RAW", "RELATIVE", "ROOT_ANCESTOR", "ROWID", "SCN", "SECOND", "SELF",
|
||||
"SERIAL", "SET", "SOLVE", "SOME", "SORT", "SPEC", "SUM", "SYNCH", "TEXT_MEASURE", "THEN", "TIME", "TIMESTAMP",
|
||||
"TO", "UNBRANCH", "UPDATE", "USING", "VALIDATE", "VALUES", "VARCHAR2", "WHEN", "WHERE", "WITHIN", "WITH", "YEAR",
|
||||
"ZERO", "ZONE",
|
||||
}
|
68
utils.go
Normal file
68
utils.go
Normal file
@ -0,0 +1,68 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func convertCustomType(val interface{}) interface{} {
|
||||
rv := reflect.ValueOf(val)
|
||||
ri := rv.Interface()
|
||||
typeName := reflect.TypeOf(ri).Name()
|
||||
if reflect.TypeOf(val).Kind() == reflect.Ptr {
|
||||
if rv.IsNil() {
|
||||
typeName = rv.Type().Elem().Name()
|
||||
} else {
|
||||
for rv.Kind() == reflect.Ptr {
|
||||
rv = rv.Elem()
|
||||
}
|
||||
ri = rv.Interface()
|
||||
typeName = reflect.TypeOf(ri).Name()
|
||||
}
|
||||
}
|
||||
if typeName == "DeletedAt" {
|
||||
// gorm.DeletedAt
|
||||
if rv.IsZero() {
|
||||
val = sql.NullTime{}
|
||||
} else {
|
||||
val = getTimeValue(ri.(gorm.DeletedAt).Time)
|
||||
}
|
||||
} else if m := rv.MethodByName("Time"); m.IsValid() && m.Type().NumIn() == 0 {
|
||||
// custom time type
|
||||
for _, result := range m.Call([]reflect.Value{}) {
|
||||
if reflect.TypeOf(result.Interface()).Name() == "Time" {
|
||||
val = getTimeValue(result.Interface().(time.Time))
|
||||
}
|
||||
}
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func ptrDereference(obj interface{}) (value interface{}) {
|
||||
if obj == nil {
|
||||
return obj
|
||||
}
|
||||
if t := reflect.TypeOf(obj); t.Kind() != reflect.Ptr {
|
||||
return obj
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(obj)
|
||||
for v.Kind() == reflect.Ptr && !v.IsNil() {
|
||||
v = v.Elem()
|
||||
}
|
||||
if !v.IsValid() || v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
return obj
|
||||
}
|
||||
value = v.Interface()
|
||||
return
|
||||
}
|
||||
|
||||
func getTimeValue(t time.Time) interface{} {
|
||||
if t.IsZero() {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return t
|
||||
}
|
Loading…
Reference in New Issue
Block a user