update
This commit is contained in:
255
create.go
Normal file
255
create.go
Normal file
@ -0,0 +1,255 @@
|
||||
package oracle
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
func Create(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
stmt := db.Statement
|
||||
if stmt == nil {
|
||||
return
|
||||
}
|
||||
|
||||
stmtSchema := stmt.Schema
|
||||
if stmtSchema == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !stmt.Unscoped {
|
||||
for _, c := range stmtSchema.CreateClauses {
|
||||
stmt.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() == 0 {
|
||||
var (
|
||||
createValues = callbacks.ConvertToCreateValues(stmt)
|
||||
onConflict, hasConflict = stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
|
||||
)
|
||||
|
||||
if hasConflict {
|
||||
if len(stmtSchema.PrimaryFields) > 0 {
|
||||
columnsMap := map[string]bool{}
|
||||
for _, column := range createValues.Columns {
|
||||
columnsMap[column.Name] = true
|
||||
}
|
||||
|
||||
for _, field := range stmtSchema.PrimaryFields {
|
||||
if _, ok := columnsMap[field.DBName]; !ok {
|
||||
hasConflict = false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
hasConflict = false
|
||||
}
|
||||
}
|
||||
|
||||
hasDefaultValues := len(stmtSchema.FieldsWithDefaultDBValue) > 0
|
||||
if hasConflict {
|
||||
MergeCreate(db, onConflict, createValues)
|
||||
} else {
|
||||
stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Schema.Table}})
|
||||
stmt.AddClause(clause.Values{Columns: createValues.Columns, Values: [][]interface{}{createValues.Values[0]}})
|
||||
|
||||
if hasDefaultValues {
|
||||
columns := make([]clause.Column, len(stmtSchema.FieldsWithDefaultDBValue))
|
||||
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
|
||||
columns[idx] = clause.Column{Name: field.DBName}
|
||||
}
|
||||
stmt.AddClauseIfNotExists(clause.Returning{Columns: columns})
|
||||
}
|
||||
stmt.Build("INSERT", "VALUES", "RETURNING")
|
||||
|
||||
if hasDefaultValues {
|
||||
_, _ = stmt.WriteString(" INTO ")
|
||||
for idx, field := range stmtSchema.FieldsWithDefaultDBValue {
|
||||
if idx > 0 {
|
||||
_ = stmt.WriteByte(',')
|
||||
}
|
||||
stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
|
||||
}
|
||||
_, _ = stmt.WriteString(" /*-sql.Out{}-*/")
|
||||
}
|
||||
}
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if hasConflict {
|
||||
for i, val := range stmt.Vars {
|
||||
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
|
||||
stmt.Vars[i] = convertValue(val)
|
||||
}
|
||||
|
||||
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
// TODO: get merged returning
|
||||
}
|
||||
} else {
|
||||
for idx, values := range createValues.Values {
|
||||
for i, val := range values {
|
||||
// HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
|
||||
stmt.Vars[i] = convertValue(val)
|
||||
}
|
||||
|
||||
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
|
||||
if hasDefaultValues {
|
||||
getDefaultValues(db, idx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MergeCreate(db *gorm.DB, onConflict clause.OnConflict, values clause.Values) {
|
||||
var dummyTable string
|
||||
switch d := ptrDereference(db.Dialector).(type) {
|
||||
case Dialector:
|
||||
dummyTable = d.DummyTableName()
|
||||
default:
|
||||
dummyTable = "DUAL"
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString("MERGE INTO ")
|
||||
db.Statement.WriteQuoted(db.Statement.Table)
|
||||
_, _ = db.Statement.WriteString(" USING (")
|
||||
|
||||
for idx, value := range values.Values {
|
||||
if idx > 0 {
|
||||
_, _ = db.Statement.WriteString(" UNION ALL ")
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString("SELECT ")
|
||||
for i, v := range value {
|
||||
if i > 0 {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
column := values.Columns[i]
|
||||
db.Statement.AddVar(db.Statement, v)
|
||||
_, _ = db.Statement.WriteString(" AS ")
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
_, _ = db.Statement.WriteString(" FROM ")
|
||||
_, _ = db.Statement.WriteString(dummyTable)
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(`) `)
|
||||
db.Statement.WriteQuoted("excluded")
|
||||
_, _ = db.Statement.WriteString(" ON (")
|
||||
|
||||
var where clause.Where
|
||||
for _, field := range db.Statement.Schema.PrimaryFields {
|
||||
where.Exprs = append(where.Exprs, clause.Eq{
|
||||
Column: clause.Column{Table: db.Statement.Table, Name: field.DBName},
|
||||
Value: clause.Column{Table: "excluded", Name: field.DBName},
|
||||
})
|
||||
}
|
||||
where.Build(db.Statement)
|
||||
_ = db.Statement.WriteByte(')')
|
||||
|
||||
if len(onConflict.DoUpdates) > 0 {
|
||||
_, _ = db.Statement.WriteString(" WHEN MATCHED THEN UPDATE SET ")
|
||||
onConflict.DoUpdates.Build(db.Statement)
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(" WHEN NOT MATCHED THEN INSERT (")
|
||||
|
||||
written := false
|
||||
for _, column := range values.Columns {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
|
||||
if written {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
written = true
|
||||
db.Statement.WriteQuoted(column.Name)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = db.Statement.WriteString(") VALUES (")
|
||||
|
||||
written = false
|
||||
for _, column := range values.Columns {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement || db.Statement.Schema.PrioritizedPrimaryField.DBName != column.Name {
|
||||
if written {
|
||||
_ = db.Statement.WriteByte(',')
|
||||
}
|
||||
written = true
|
||||
db.Statement.WriteQuoted(clause.Column{
|
||||
Table: "excluded",
|
||||
Name: column.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
_, _ = db.Statement.WriteString(")")
|
||||
}
|
||||
|
||||
func convertValue(val interface{}) interface{} {
|
||||
val = ptrDereference(val)
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
val = 1
|
||||
} else {
|
||||
val = 0
|
||||
}
|
||||
// case string:
|
||||
// if len(v) > 2000 {
|
||||
// val = go_ora.Clob{String: v, Valid: true}
|
||||
// }
|
||||
default:
|
||||
val = convertCustomType(val)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func getDefaultValues(db *gorm.DB, idx int) {
|
||||
insertTo := db.Statement.ReflectValue
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
insertTo = insertTo.Index(idx)
|
||||
default:
|
||||
}
|
||||
if insertTo.Kind() == reflect.Pointer {
|
||||
insertTo = insertTo.Elem()
|
||||
}
|
||||
|
||||
for _, val := range db.Statement.Vars {
|
||||
switch v := val.(type) {
|
||||
case sql.Out:
|
||||
switch insertTo.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := insertTo.Len() - 1; i >= 0; i-- {
|
||||
rv := insertTo.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, v.Dest))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, insertTo)
|
||||
if isZero {
|
||||
_ = db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, insertTo, v.Dest))
|
||||
}
|
||||
default:
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user