package postgres import ( "database/sql" "fmt" "regexp" "strings" "github.com/jackc/pgx/v4" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" ) const indexSql = ` select t.relname as table_name, i.relname as index_name, a.attname as column_name, ix.indisunique as non_unique, ix.indisprimary as primary from pg_class t, pg_class i, pg_index ix, pg_attribute a where t.oid = ix.indrelid and i.oid = ix.indexrelid and a.attrelid = t.oid and a.attnum = ANY(ix.indkey) and t.relkind = 'r' and t.relname = ? ` type Migrator struct { migrator.Migrator } func (m Migrator) CurrentDatabase() (name string) { m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name) return } func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) if opt.Expression != "" { str = opt.Expression } if opt.Collate != "" { str += " COLLATE " + opt.Collate } if opt.Sort != "" { str += " " + opt.Sort } results = append(results, clause.Expr{SQL: str}) } return } func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) return m.DB.Raw( "SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema, ).Scan(&count).Error }) return count > 0 } func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} createIndexSQL := "CREATE " if idx.Class != "" { createIndexSQL += idx.Class + " " } createIndexSQL += "INDEX " if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" { createIndexSQL += "CONCURRENTLY " } createIndexSQL += "IF NOT EXISTS ? ON ?" if idx.Type != "" { createIndexSQL += " USING " + idx.Type + "(?)" } else { createIndexSQL += " ?" } if idx.Where != "" { createIndexSQL += " WHERE " + idx.Where } return m.DB.Exec(createIndexSQL, values...).Error } return fmt.Errorf("failed to create index with name %v", name) }) } func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Exec( "ALTER INDEX ? RENAME TO ?", clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name } return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error }) } func (m Migrator) GetTables() (tableList []string, err error) { currentSchema, _ := m.CurrentSchema(m.DB.Statement, "") return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error } func (m Migrator) CreateTable(values ...interface{}) (err error) { if err = m.Migrator.CreateTable(values...); err != nil { return } for _, value := range m.ReorderModels(values, false) { if err = m.RunWithValue(value, func(stmt *gorm.Statement) error { for _, field := range stmt.Schema.FieldsByDBName { if field.Comment != "" { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), ).Error; err != nil { return err } } } return nil }); err != nil { return } } return } func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error }) return count > 0 } func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) tx := m.DB.Session(&gorm.Session{}) for i := len(values) - 1; i >= 0; i-- { if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error }); err != nil { return err } } return nil } func (m Migrator) AddColumn(value interface{}, field string) error { if err := m.Migrator.AddColumn(value, field); err != nil { return err } return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { if field.Comment != "" { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), ).Error; err != nil { return err } } } return nil }) } func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { name := field if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil { name = field.DBName } } currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentSchema, curTable, name, ).Scan(&count).Error }) return count > 0 } func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // skip primary field if !field.PrimaryKey { if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil { return err } } return m.RunWithValue(value, func(stmt *gorm.Statement) error { var description string currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table) values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema} checkSQL := "SELECT description FROM pg_catalog.pg_description " checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) " checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = " checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))" m.DB.Raw(checkSQL, values...).Scan(&description) comment := field.Comment if comment != "" { comment = comment[1 : len(comment)-1] } if field.Comment != "" && comment != description { if err := m.DB.Exec( "COMMENT ON COLUMN ?.? IS ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)), ).Error; err != nil { return err } } return nil }) } // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if field := stmt.Schema.LookUpField(field); field != nil { var ( columnTypes, _ = m.DB.Migrator().ColumnTypes(value) fieldColumnType *migrator.ColumnType ) for _, columnType := range columnTypes { if columnType.Name() == field.DBName { fieldColumnType, _ = columnType.(*migrator.ColumnType) } } fileType := clause.Expr{SQL: m.DataTypeOf(field)} if fieldColumnType.DatabaseTypeName() != fileType.SQL { filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement() if field.AutoIncrement && filedColumnAutoIncrement { // update serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType { if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { return err } } } else if field.AutoIncrement && !filedColumnAutoIncrement { // create serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL) if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil { return err } } else if !field.AutoIncrement && filedColumnAutoIncrement { // delete if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil { return err } } else { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { return err } } } if null, _ := fieldColumnType.Nullable(); null == field.NotNull { if field.NotNull { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } else { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil { return err } } } if uniq, _ := fieldColumnType.Unique(); uniq != field.Unique { idxName := clause.Column{Name: m.DB.Config.NamingStrategy.IndexName(stmt.Table, field.DBName)} if err := m.DB.Exec("ALTER TABLE ? ADD CONSTRAINT ? UNIQUE(?)", m.CurrentTable(stmt), idxName, clause.Column{Name: field.DBName}).Error; err != nil { return err } } if v, _ := fieldColumnType.DefaultValue(); v != field.DefaultValue { if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil { return err } } else if field.DefaultValue != "(-)" { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { return err } } else { if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil { return err } } } } return nil } return fmt.Errorf("failed to look up field with name: %s", field) }) } func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) currentSchema, curTable := m.CurrentSchema(stmt, table) if constraint != nil { name = constraint.Name } else if chk != nil { name = chk.Name } return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?", currentSchema, curTable, name, ).Scan(&count).Error }) return count > 0 } func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { var ( currentDatabase = m.DB.Migrator().CurrentDatabase() currentSchema, table = m.CurrentSchema(stmt, stmt.Table) columns, err = m.DB.Raw( "SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?", currentDatabase, currentSchema, table).Rows() ) if err != nil { return err } for columns.Next() { var ( column = &migrator.ColumnType{ PrimaryKeyValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, } datetimePrecision sql.NullInt64 radixValue sql.NullInt64 typeLenValue sql.NullInt64 ) err = columns.Scan( &column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue, &radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, ) if err != nil { return err } if typeLenValue.Valid && typeLenValue.Int64 > 0 { column.LengthValue = typeLenValue } if strings.HasPrefix(column.DefaultValueValue.String, "nextval('") && strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)") { column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true} column.DefaultValueValue = sql.NullString{} } if column.DefaultValueValue.Valid { column.DefaultValueValue.String = regexp.MustCompile(`'(.*)'::[\w]+$`).ReplaceAllString(column.DefaultValueValue.String, "$1") } if datetimePrecision.Valid { column.DecimalSizeValue = datetimePrecision } columnTypes = append(columnTypes, column) } columns.Close() // assign sql column type { rows, rowsErr := m.GetRows(currentSchema, table) if rowsErr != nil { return rowsErr } rawColumnTypes, err := rows.ColumnTypes() if err != nil { return err } for _, columnType := range columnTypes { for _, c := range rawColumnTypes { if c.Name() == columnType.Name() { columnType.(*migrator.ColumnType).SQLColumnType = c break } } } rows.Close() } // check primary, unique field { columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows() if err != nil { return err } uniqueContraints := map[string]int{} for columnTypeRows.Next() { var constraintName string columnTypeRows.Scan(&constraintName) uniqueContraints[constraintName]++ } columnTypeRows.Close() columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows() if err != nil { return err } for columnTypeRows.Next() { var name, constraintName, columnType string columnTypeRows.Scan(&name, &constraintName, &columnType) for _, c := range columnTypes { mc := c.(*migrator.ColumnType) if mc.NameValue.String == name { switch columnType { case "PRIMARY KEY": mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} case "UNIQUE": if uniqueContraints[constraintName] == 1 { mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} } } break } } } columnTypeRows.Close() } // check column type { dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?) WHERE a.attnum > 0 -- hide internal columns AND NOT a.attisdropped -- hide deleted columns AND b.relname = ?`, currentSchema, table).Rows() if err != nil { return err } for dataTypeRows.Next() { var name, dataType string dataTypeRows.Scan(&name, &dataType) for _, c := range columnTypes { mc := c.(*migrator.ColumnType) if mc.NameValue.String == name { mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true} // Handle array type: _text -> text[] , _int4 -> integer[] // Not support array size limits and array size limits because: // https://www.postgresql.org/docs/current/arrays.html#ARRAYS-DECLARATION if strings.HasPrefix(mc.DataTypeValue.String, "_") { mc.DataTypeValue = sql.NullString{String: dataType, Valid: true} } break } } } dataTypeRows.Close() } return err }) return } func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) { name := table.(string) if _, ok := currentSchema.(string); ok { name = fmt.Sprintf("%v.%v", currentSchema, table) } return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Scopes(func(d *gorm.DB) *gorm.DB { // use simple protocol if !m.DB.PrepareStmt { d.Statement.Vars = append(d.Statement.Vars, pgx.QuerySimpleProtocol(true)) } return d }).Rows() } func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) { if strings.Contains(table, ".") { if tables := strings.Split(table, `.`); len(tables) == 2 { return tables[0], tables[1] } } if stmt.TableExpr != nil { if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 { return strings.TrimPrefix(tables[0], `"`), table } } return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table } func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, serialDatabaseType string) (err error) { _, table := m.CurrentSchema(stmt, stmt.Table) tableName := table.(string) sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_") if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')", clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil { return err } if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?", clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil { return err } return } func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, serialDatabaseType string) (err error) { sequenceName, err := m.getColumnSequenceName(tx, stmt, field) if err != nil { return err } if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil { return err } return } func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field, fileType clause.Expr) (err error) { sequenceName, err := m.getColumnSequenceName(tx, stmt, field) if err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil { return err } if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil { return err } if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil { return err } return } func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) ( sequenceName string, err error) { _, table := m.CurrentSchema(stmt, stmt.Table) // DefaultValueValue is reset by ColumnTypes, search again. var columnDefault string err = tx.Raw( `SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`, table, field.DBName).Scan(&columnDefault).Error if err != nil { return } sequenceName = strings.TrimSuffix( strings.TrimPrefix(columnDefault, `nextval('`), `'::regclass)`, ) return } func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { indexes := make([]gorm.Index, 0) err := m.RunWithValue(value, func(stmt *gorm.Statement) error { result := make([]*Index, 0) scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error if scanErr != nil { return scanErr } indexMap := groupByIndexName(result) for _, idx := range indexMap { tempIdx := &migrator.Index{ TableName: idx[0].TableName, NameValue: idx[0].IndexName, PrimaryKeyValue: sql.NullBool{ Bool: idx[0].Primary, Valid: true, }, UniqueValue: sql.NullBool{ Bool: idx[0].NonUnique, Valid: true, }, } for _, x := range idx { tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName) } indexes = append(indexes, tempIdx) } return nil }) return indexes, err } // Index table index info type Index struct { TableName string `gorm:"column:table_name"` ColumnName string `gorm:"column:column_name"` IndexName string `gorm:"column:index_name"` NonUnique bool `gorm:"column:non_unique"` Primary bool `gorm:"column:primary"` } func groupByIndexName(indexList []*Index) map[string][]*Index { columnIndexMap := make(map[string][]*Index, len(indexList)) for _, idx := range indexList { columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx) } return columnIndexMap }