This commit is contained in:
2024-05-10 10:27:21 +08:00
parent 90f807defc
commit 184a8722d8
9 changed files with 285 additions and 423 deletions

303
create.go
View File

@ -1,255 +1,152 @@
package oracle
import (
"bytes"
"database/sql"
"reflect"
"github.com/thoas/go-funk"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
gormSchema "gorm.io/gorm/schema"
"git.charlienet.top/go/oracle/clauses"
)
func Create(db *gorm.DB) {
if db.Error != nil {
return
}
stmt := db.Statement
if stmt == nil {
schema := stmt.Schema
boundVars := make(map[string]int)
if stmt == nil || schema == nil {
return
}
stmtSchema := stmt.Schema
if stmtSchema == nil {
return
}
hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0
if !stmt.Unscoped {
for _, c := range stmtSchema.CreateClauses {
for _, c := range schema.CreateClauses {
stmt.AddClause(c)
}
}
if stmt.SQL.Len() == 0 {
var (
createValues = callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
)
if stmt.SQL.String() == "" {
values := callbacks.ConvertToCreateValues(stmt)
onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
// are all columns in value the primary fields in schema only?
if hasConflict && funk.Contains(
funk.Map(values.Columns, func(c clause.Column) string { return c.Name }),
funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) string { return field.DBName }),
) {
stmt.AddClauseIfNotExists(clauses.Merge{
Using: []clause.Interface{
clause.Select{
Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column {
// HACK: I can not come up with a better alternative for now
// I want to add a value to the list of variable and then capture the bind variable position as well
buf := bytes.NewBufferString("")
stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)])
stmt.BindVarTo(buf, stmt, nil)
if hasConflict {
if len(stmtSchema.PrimaryFields) > 0 {
columnsMap := map[string]bool{}
for _, column := range createValues.Columns {
columnsMap[column.Name] = true
}
for _, field := range stmtSchema.PrimaryFields {
if _, ok := columnsMap[field.DBName]; !ok {
hasConflict = false
column.Alias = column.Name
// then the captured bind var will be the name
column.Name = buf.String()
return column
}).([]clause.Column),
},
clause.From{
Tables: []clause.Table{{Name: db.Dialector.(Dialector).DummyTableName()}},
},
},
On: funk.Map(schema.PrimaryFields, func(field *gormSchema.Field) clause.Expression {
return clause.Eq{
Column: clause.Column{Table: stmt.Schema.Table, Name: field.DBName},
Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.DBName},
}
}
} else {
hasConflict = false
}
}
}).([]clause.Expression),
})
stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})
hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0
if hasConflict {
MergeCreate(db, onConflict, createValues)
stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED")
} else {
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})
stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}})
if hasDefaultValues {
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue))
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
columns[idx] = clause.Column{Name: field.DBName}
}
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns})
stmt.AddClauseIfNotExists(clause.Returning{
Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column {
return clause.Column{Name: field.DBName}
}).([]clause.Column),
})
}
stmt.Build("INSERT", "VALUES", "RETURNING")
if hasDefaultValues {
_, _ = stmt.WriteString(" INTO ")
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
stmt.WriteString(" INTO ")
for idx, field := range schema.FieldsWithDefaultDBValue {
if idx > 0 {
_ = stmt.WriteByte(',')
stmt.WriteByte(',')
}
boundVars[field.Name] = len(stmt.Vars)
stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
}
_, _ = stmt.WriteString(" /*-sql.Out{}-*/")
}
}
if !db.DryRun && db.Error == nil {
if hasConflict {
for i, val := range stmt.Vars {
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
stmt.Vars[i] = convertValue(val)
}
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
// TODO: get merged returning
}
} else {
for idx, values := range createValues.Values {
for i, val := range values {
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
stmt.Vars[i] = convertValue(val)
}
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
if hasDefaultValues {
getDefaultValues(db, idx)
if !db.DryRun {
for idx, vals := range values.Values {
// HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
for idx, val := range vals {
switch v := val.(type) {
case bool:
if v {
val = 1
} else {
val = 0
}
}
stmt.Vars[idx] = val
}
}
}
}
}
// and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert)
// we keep track of the index so that the sub-reflected value is also correct
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
var dummyTable string
switch d := ptrDereference(db.Dialector).(type) {
case Dialector:
dummyTable = d.DummyTableName()
default:
dummyTable = "DUAL"
}
// BIG BUG: what if any of the transactions failed? some result might already be inserted that oracle is so
// sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point,
// resulting in dangling row entries, so we might need to delete them if an error happens
_, _ = db.Statement.WriteString("MERGE INTO ")
db.Statement.WriteQuoted(db.Statement.Table)
_, _ = db.Statement.WriteString(" USING (")
switch result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err {
case nil: // success
db.RowsAffected, _ = result.RowsAffected()
for idx, value := range values.Values {
if idx > 0 {
_, _ = db.Statement.WriteString(" UNION ALL ")
}
_, _ = db.Statement.WriteString("SELECT ")
for i, v := range value {
if i > 0 {
_ = db.Statement.WriteByte(',')
}
column := values.Columns[i]
db.Statement.AddVar(db.Statement, v)
_, _ = db.Statement.WriteString(" AS ")
db.Statement.WriteQuoted(column.Name)
}
_, _ = db.Statement.WriteString(" FROM ")
_, _ = db.Statement.WriteString(dummyTable)
}
_, _ = db.Statement.WriteString(`) `)
db.Statement.WriteQuoted("excluded")
_, _ = db.Statement.WriteString(" ON (")
var where clause.Where
for _, field := range db.Statement.Schema.PrimaryFields {
where.Exprs = append(where.Exprs, clause.Eq{
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
Value: clause.Column{Table: "excluded", Name: field.DBName},
})
}
where.Build(db.Statement)
_ = db.Statement.WriteByte(')')
if len(onConflict.DoUpdates) > 0 {
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
onConflict.DoUpdates.Build(db.Statement)
}
_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
written := false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(column.Name)
}
}
_, _ = db.Statement.WriteString(") VALUES (")
written = false
for _, column := range values.Columns {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
if written {
_ = db.Statement.WriteByte(',')
}
written = true
db.Statement.WriteQuoted(clause.Column{
Table: "excluded",
Name: column.Name,
})
}
}
_, _ = db.Statement.WriteString(")")
}
func convertValue(val interface{}) interface{} {
val = ptrDereference(val)
switch v := val.(type) {
case bool:
if v {
val = 1
} else {
val = 0
}
// case string:
// if len(v) > 2000 {
// val = go_ora.Clob{String: v, Valid: true}
// }
default:
val = convertCustomType(val)
}
return val
}
func getDefaultValues(db *gorm.DB, idx int) {
insertTo := db.Statement.ReflectValue
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
insertTo = insertTo.Index(idx)
default:
}
if insertTo.Kind() == reflect.Pointer {
insertTo = insertTo.Elem()
}
for _, val := range db.Statement.Vars {
switch v := val.(type) {
case sql.Out:
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
for i := insertTo.Len() - 1; i >= 0; i-- {
rv := insertTo.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
insertTo := stmt.ReflectValue
switch insertTo.Kind() {
case reflect.Slice, reflect.Array:
insertTo = insertTo.Index(idx)
}
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest))
if hasDefaultValues {
// bind returning value back to reflected value in the respective fields
funk.ForEach(
funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool {
return funk.Contains(boundVars, field.Name)
}),
func(field *gormSchema.Field) {
switch insertTo.Kind() {
case reflect.Struct:
if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil {
db.AddError(err)
}
case reflect.Map:
// todo 设置id的值
}
},
)
}
default: // failure
db.AddError(err)
}
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
if isZero {
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
}
default:
}
default:
}
}
}