package oracle

import (
	"database/sql"
	"fmt"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/migrator"
	"gorm.io/gorm/schema"
)

// Migrator is the Oracle schema migrator. It embeds gorm's generic
// migrator.Migrator and overrides the operations whose SQL or data-dictionary
// queries differ on Oracle.
type Migrator struct {
	migrator.Migrator
}

// CurrentDatabase returns ORA_DATABASE_NAME, selected from the dialector's
// dummy table. The query error is deliberately ignored because the method has
// no error return; name is left empty when the query fails.
func (m Migrator) CurrentDatabase() (name string) {
	m.DB.Raw(
		fmt.Sprintf(`SELECT ORA_DATABASE_NAME as "Current Database" FROM %s`, m.Dialector.(Dialector).DummyTableName()),
	).Row().Scan(&name)
	return
}

// CreateTable creates the tables for the given models. Before delegating to
// the generic implementation it quotes reserved-word column names and strips
// the OnUpdate setting from relationship constraints (Oracle foreign keys have
// no ON UPDATE action). Errors from those preparation steps are returned.
func (m Migrator) CreateTable(values ...interface{}) error {
	for _, value := range values {
		if err := m.TryQuotifyReservedWords(value); err != nil {
			return err
		}
		if err := m.TryRemoveOnUpdate(value); err != nil {
			return err
		}
	}
	return m.Migrator.CreateTable(values...)
}

// DropTable drops the tables for the given models using
// DROP TABLE ... CASCADE CONSTRAINTS. The ReorderModels result is walked from
// the end (presumably so dependent tables go first). Missing tables are skipped.
func (m Migrator) DropTable(values ...interface{}) error {
	values = m.ReorderModels(values, false)
	for i := len(values) - 1; i >= 0; i-- {
		value := values[i]
		if !m.HasTable(value) {
			continue
		}
		tx := m.DB.Session(&gorm.Session{})
		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
			return tx.Exec("DROP TABLE ? CASCADE CONSTRAINTS", clause.Table{Name: stmt.Table}).Error
		}); err != nil {
			return err
		}
	}
	return nil
}

// HasTable reports whether the table for value exists. A schema table name of
// the form OWNER.TABLE is looked up in ALL_TABLES; anything else is looked up
// in USER_TABLES. Query errors are reported as "does not exist".
func (m Migrator) HasTable(value interface{}) bool {
	var count int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if owner, table, ok := splitOwnerTable(stmt); ok {
			return m.DB.Raw("SELECT COUNT(*) FROM ALL_TABLES WHERE OWNER = ? and TABLE_NAME = ?", owner, table).Row().Scan(&count)
		}
		return m.DB.Raw("SELECT COUNT(*) FROM USER_TABLES WHERE TABLE_NAME = ?", stmt.Table).Row().Scan(&count)
	})
	return count > 0
}

// ColumnTypes returns the column metadata of the table for value, taken from
// the driver's description of a "ROWNUM = 1" query against that table.
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
	columnTypes := make([]gorm.ColumnType, 0)
	execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
		rows, err := m.DB.Session(&gorm.Session{}).Table(statementTable(stmt)).Where("ROWNUM = 1").Rows()
		if err != nil {
			return err
		}
		// Surface a Close failure only when nothing failed earlier, so it can
		// never mask the error from ColumnTypes.
		defer func() {
			if closeErr := rows.Close(); err == nil {
				err = closeErr
			}
		}()

		var rawColumnTypes []*sql.ColumnType
		if rawColumnTypes, err = rows.ColumnTypes(); err != nil {
			return err
		}
		for _, c := range rawColumnTypes {
			columnTypes = append(columnTypes, migrator.ColumnType{SQLColumnType: c})
		}
		return nil
	})
	return columnTypes, execErr
}

// RenameTable renames oldName to newName; each may be a table-name string or a
// model. It returns nil without doing anything when the old table does not
// exist.
func (m Migrator) RenameTable(oldName, newName interface{}) (err error) {
	resolveTable := func(name interface{}) (string, error) {
		if v, ok := name.(string); ok {
			return v, nil
		}
		stmt := &gorm.Statement{DB: m.DB}
		if err := stmt.Parse(name); err != nil {
			return "", err
		}
		return stmt.Table, nil
	}

	var oldTable, newTable string
	if oldTable, err = resolveTable(oldName); err != nil {
		return err
	}
	if newTable, err = resolveTable(newName); err != nil {
		return err
	}

	if !m.HasTable(oldTable) {
		return nil
	}

	// Oracle has no "RENAME TABLE" statement; ALTER TABLE ... RENAME TO is the
	// supported form.
	return m.DB.Exec(
		"ALTER TABLE ? RENAME TO ?",
		clause.Table{Name: oldTable}, clause.Table{Name: newTable},
	).Error
}

// AddColumn adds the column for the named field (Go field name or DB name)
// using the migrator's full data type for that field.
func (m Migrator) AddColumn(value interface{}, field string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if f := stmt.Schema.LookUpField(field); f != nil {
				return m.DB.Exec(
					"ALTER TABLE ? ADD ? ?",
					clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
				).Error
			}
		}
		return fmt.Errorf("failed to look up field with name: %s", field)
	})
}

// DropColumn drops the named column (Go field name or DB name). It is a no-op
// when HasColumn reports that the column does not exist.
func (m Migrator) DropColumn(value interface{}, name string) error {
	if !m.HasColumn(value, name) {
		return nil
	}
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if field := stmt.Schema.LookUpField(name); field != nil {
				name = field.DBName
			}
		}
		// Oracle's drop-column clause requires the COLUMN keyword.
		return m.DB.Exec(
			"ALTER TABLE ? DROP COLUMN ?",
			clause.Table{Name: statementTable(stmt)}, clause.Column{Name: name},
		).Error
	})
}

// AlterColumn modifies the named column to match its field definition, using
// the clause built by AlterDataTypeOf. It is a no-op when HasColumn reports
// that the column does not exist.
func (m Migrator) AlterColumn(value interface{}, field string) error {
	if !m.HasColumn(value, field) {
		return nil
	}
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if f := stmt.Schema.LookUpField(field); f != nil {
				return m.DB.Exec(
					"ALTER TABLE ? MODIFY ? ?",
					clause.Table{Name: stmt.Schema.Table}, clause.Column{Name: f.DBName}, m.AlterDataTypeOf(stmt, f),
				).Error
			}
		}
		return fmt.Errorf("failed to look up field with name: %s", field)
	})
}

// HasColumn reports whether the table for value has the given column. field
// may be a Go field name or a DB name; it is resolved to the field's DBName
// when the schema knows it. Surrounding double quotes are stripped before the
// data-dictionary lookup. Query errors are reported as "does not exist".
func (m Migrator) HasColumn(value interface{}, field string) bool {
	var count int64
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		name := field
		if stmt.Schema != nil {
			if f := stmt.Schema.LookUpField(field); f != nil {
				name = f.DBName
			}
		}
		name = dictionaryName(name)

		if owner, table, ok := splitOwnerTable(stmt); ok {
			return m.DB.Raw("SELECT COUNT(*) FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", owner, table, name).Row().Scan(&count)
		}
		return m.DB.Raw("SELECT COUNT(*) FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, name).Row().Scan(&count)
	}) == nil && count > 0
}

// AlterDataTypeOf builds the column definition used by AlterColumn's
// ALTER TABLE ... MODIFY: data type, then DEFAULT, then NOT NULL / UNIQUE.
// Oracle's column-definition grammar places DEFAULT before inline constraints.
func (m Migrator) AlterDataTypeOf(stmt *gorm.Statement, field *schema.Field) (expr clause.Expr) {
	expr.SQL = m.DataTypeOf(field)

	// Read the column's current NULLABLE flag. NOT NULL is added only when the
	// column is currently nullable ("Y"), because Oracle rejects modifying an
	// already NOT NULL column to NOT NULL (ORA-01442). A failed lookup leaves
	// nullable empty, so NOT NULL is not added.
	var nullable string
	column := dictionaryName(field.DBName)
	if owner, table, ok := splitOwnerTable(stmt); ok {
		m.DB.Raw("SELECT NULLABLE FROM ALL_TAB_COLUMNS WHERE OWNER = ? and TABLE_NAME = ? AND COLUMN_NAME = ?", owner, table, column).Row().Scan(&nullable)
	} else {
		m.DB.Raw("SELECT NULLABLE FROM USER_TAB_COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = ?", stmt.Table, column).Row().Scan(&nullable)
	}

	if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
		if field.DefaultValueInterface != nil {
			defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
			m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
			expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
		} else if field.DefaultValue != "(-)" {
			expr.SQL += " DEFAULT " + field.DefaultValue
		}
	}

	if field.NotNull && nullable == "Y" {
		expr.SQL += " NOT NULL"
	}
	if field.Unique {
		expr.SQL += " UNIQUE"
	}
	return
}

// CreateConstraint strips the OnUpdate setting from the model's relationship
// constraints, then delegates to the generic implementation.
func (m Migrator) CreateConstraint(value interface{}, name string) error {
	if err := m.TryRemoveOnUpdate(value); err != nil {
		return err
	}
	return m.Migrator.CreateConstraint(value, name)
}

// DropConstraint drops the named constraint. Oracle drops every constraint
// kind, including CHECK constraints, with ALTER TABLE ... DROP CONSTRAINT.
func (m Migrator) DropConstraint(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		return m.DB.Exec(
			"ALTER TABLE ? DROP CONSTRAINT ?",
			clause.Table{Name: statementTable(stmt)}, clause.Column{Name: name},
		).Error
	})
}

// HasConstraint reports whether USER_CONSTRAINTS lists a constraint with the
// given name on stmt.Table. Owner-qualified tables are not handled specially.
func (m Migrator) HasConstraint(value interface{}, name string) bool {
	var count int64
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		return m.DB.Raw(
			"SELECT COUNT(*) FROM USER_CONSTRAINTS WHERE TABLE_NAME = ? AND CONSTRAINT_NAME = ?", stmt.Table, name,
		).Row().Scan(&count)
	}) == nil && count > 0
}

// DropIndex drops the named index. name may refer to an index declared on the
// model, in which case that index's name is used.
func (m Migrator) DropIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}
		return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
	})
}

// HasIndex reports whether USER_INDEXES lists the index for value. Query
// errors are reported as "does not exist".
func (m Migrator) HasIndex(value interface{}, name string) bool {
	var count int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if stmt.Schema != nil {
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}
		// NOTE(review): the looked-up names are NamingStrategy.TableName(stmt.Table)
		// and NamingStrategy.IndexName(stmt.Table, name), not stmt.Table/name
		// themselves. Confirm this matches how the index names were created.
		return m.DB.Raw(
			"SELECT COUNT(*) FROM USER_INDEXES WHERE TABLE_NAME = ? AND INDEX_NAME = ?",
			m.Migrator.DB.NamingStrategy.TableName(stmt.Table),
			m.Migrator.DB.NamingStrategy.IndexName(stmt.Table, name),
		).Row().Scan(&count)
	})
	return count > 0
}

// RenameIndex renames an index with ALTER INDEX ... RENAME TO.
// https://docs.oracle.com/database/121/SPATL/alter-index-rename.htm
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		return m.DB.Exec(
			"ALTER INDEX ? RENAME TO ?",
			clause.Column{Name: oldName}, clause.Column{Name: newName},
		).Error
	})
}

// TryRemoveOnUpdate removes the OnUpdate entry from the CONSTRAINT tag setting
// of every relationship whose parsed constraint has an OnUpdate action, so
// that no ON UPDATE clause is generated. It mutates the parsed schema in place.
func (m Migrator) TryRemoveOnUpdate(values ...interface{}) error {
	for _, value := range values {
		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
			if stmt.Schema == nil {
				return nil
			}
			for _, rel := range stmt.Schema.Relationships.Relations {
				if constraint := rel.ParseConstraint(); constraint != nil && constraint.OnUpdate != "" {
					rel.Field.TagSettings["CONSTRAINT"] = removeOnUpdateSetting(rel.Field.TagSettings["CONSTRAINT"])
				}
			}
			return nil
		}); err != nil {
			return err
		}
	}
	return nil
}

// TryQuotifyReservedWords wraps in double quotes every DB name (in both
// Schema.DBNames and each field's DBName) for which IsReservedWord is true.
// It mutates the parsed schema in place.
func (m Migrator) TryQuotifyReservedWords(values ...interface{}) error {
	for _, value := range values {
		if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
			if stmt.Schema == nil {
				return nil
			}
			for idx, v := range stmt.Schema.DBNames {
				if IsReservedWord(v) {
					stmt.Schema.DBNames[idx] = fmt.Sprintf(`"%s"`, v)
				}
			}
			for _, v := range stmt.Schema.Fields {
				if IsReservedWord(v.DBName) {
					v.DBName = fmt.Sprintf(`"%s"`, v.DBName)
				}
			}
			return nil
		}); err != nil {
			return err
		}
	}
	return nil
}

// splitOwnerTable reports whether the statement's schema table name is
// qualified as OWNER.TABLE and, if so, returns both parts. Only the first dot
// separates owner from table.
func splitOwnerTable(stmt *gorm.Statement) (owner, table string, ok bool) {
	if stmt.Schema == nil {
		return "", "", false
	}
	parts := strings.SplitN(stmt.Schema.Table, ".", 2)
	if len(parts) != 2 {
		return "", "", false
	}
	return parts[0], parts[1], true
}

// statementTable returns the parsed schema's table name when the statement
// has a schema, and stmt.Table otherwise.
func statementTable(stmt *gorm.Statement) string {
	if stmt.Schema != nil {
		return stmt.Schema.Table
	}
	return stmt.Table
}

// dictionaryName strips surrounding double quotes (such as those added by
// TryQuotifyReservedWords) so the name can be compared with the unquoted
// names stored in Oracle's *_TAB_COLUMNS views.
func dictionaryName(name string) string {
	return strings.Trim(name, `"`)
}

// removeOnUpdateSetting blanks every comma-separated entry of a CONSTRAINT tag
// setting whose key (the text before the first ':') is OnUpdate, compared
// case-insensitively. Entries are blanked rather than removed so that the
// remaining entries, including any leading constraint name, keep their
// positions.
func removeOnUpdateSetting(tag string) string {
	entries := strings.Split(tag, ",")
	for i, entry := range entries {
		key := strings.SplitN(entry, ":", 2)[0]
		if strings.EqualFold(strings.TrimSpace(key), "ONUPDATE") {
			entries[i] = ""
		}
	}
	return strings.Join(entries, ",")
}