// Package postgres provides a gorm.Dialector for PostgreSQL. Unless a
// connection pool or an explicit database/sql driver name is supplied, it
// connects through pgx's database/sql adapter.
package postgres

import (
	"database/sql"
	"fmt"
	"regexp"
	"strconv"
	"strings"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/stdlib"
	"gorm.io/gorm"
	"gorm.io/gorm/callbacks"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/migrator"
	"gorm.io/gorm/schema"
)

// Dialector is the PostgreSQL implementation of gorm.Dialector. It embeds a
// *Config, so every Config field is accessible directly on the Dialector.
type Dialector struct {
	*Config
}

// Config holds the settings consulted by Dialector.Initialize.
type Config struct {
	// DriverName, when non-empty (and Conn is nil), is passed to sql.Open
	// together with DSN instead of using pgx directly.
	DriverName string
	// DSN is the connection string handed to sql.Open or pgx.ParseConfig.
	DSN string
	// PreferSimpleProtocol switches the pgx connection to
	// pgx.QueryExecModeSimpleProtocol. It is only consulted on the pgx path.
	PreferSimpleProtocol bool
	// WithoutReturning stops "RETURNING" from being appended to the create,
	// update and delete clause lists.
	WithoutReturning bool
	// Conn, when non-nil, is used as the connection pool as-is and takes
	// precedence over DriverName and DSN.
	Conn gorm.ConnPool
}

// Open returns a Dialector configured with only the given DSN.
func Open(dsn string) gorm.Dialector {
	return &Dialector{&Config{DSN: dsn}}
}

// New returns a Dialector backed by a copy of config.
func New(config Config) gorm.Dialector {
	return &Dialector{Config: &config}
}

// Name returns the dialect name, "postgres".
func (dialector Dialector) Name() string {
	return "postgres"
}

// timeZoneMatcher extracts a time_zone= or TimeZone= value from a DSN; the
// value ends at the end of the string, at '&', or at a space.
var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )")

// Initialize registers gorm's default callbacks on db and sets db.ConnPool.
//
// The connection pool is chosen in this order:
//  1. dialector.Conn, when non-nil;
//  2. sql.Open(DriverName, DSN), when DriverName is non-empty;
//  3. a pgx connection built from DSN via pgx.ParseConfig and stdlib.OpenDB.
//
// It returns the error from sql.Open or pgx.ParseConfig, if any. On error
// db.ConnPool is left untouched, so a typed-nil *sql.DB is never stored in
// the interface.
func (dialector Dialector) Initialize(db *gorm.DB) error {
	callbackConfig := &callbacks.Config{
		CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"},
		UpdateClauses: []string{"UPDATE", "SET", "WHERE"},
		DeleteClauses: []string{"DELETE", "FROM", "WHERE"},
	}
	// register callbacks
	if !dialector.WithoutReturning {
		callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
		callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
		callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
	}
	callbacks.RegisterDefaultCallbacks(db, callbackConfig)

	switch {
	case dialector.Conn != nil:
		db.ConnPool = dialector.Conn
	case dialector.DriverName != "":
		pool, err := sql.Open(dialector.DriverName, dialector.Config.DSN)
		if err != nil {
			return err
		}
		db.ConnPool = pool
	default:
		config, err := pgx.ParseConfig(dialector.Config.DSN)
		if err != nil {
			return err
		}
		if dialector.Config.PreferSimpleProtocol {
			config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
		}
		// Forward a time zone found in the DSN as the "timezone" runtime
		// parameter of the connection.
		if result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN); len(result) > 2 {
			config.RuntimeParams["timezone"] = result[2]
		}
		db.ConnPool = stdlib.OpenDB(*config)
	}
	return nil
}

// Migrator returns a Migrator bound to db and this dialector, configured to
// create indexes after the table itself has been created.
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
	return Migrator{migrator.Migrator{Config: migrator.Config{
		DB:                          db,
		Dialector:                   dialector,
		CreateIndexAfterCreateTable: true,
	}}}
}

// DefaultValueOf returns the SQL expression DEFAULT; field is not consulted.
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
	return clause.Expr{SQL: "DEFAULT"}
}

// BindVarTo writes a numbered placeholder to writer: '$' followed by the
// current length of stmt.Vars. The value v itself is not consulted.
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
	writer.WriteByte('$')
	writer.WriteString(strconv.Itoa(len(stmt.Vars)))
}

// QuoteTo writes str to writer wrapped in double quotes. A '.' acts as a
// separator between parts, each of which receives its own pair of quotes, and
// double-quote characters already present in str are taken into account so
// that a part supplied pre-quoted is not quoted a second time.
//
// The scan is a small byte-wise state machine whose behaviour depends on the
// exact statement order; change it with care.
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
	var (
		// underQuoted: an opening quote has been emitted for the current part.
		// selfQuoted: the current part began with a quote supplied in str.
		underQuoted, selfQuoted bool
		// continuousBacktick counts consecutive '"' bytes not yet emitted
		// (the name is historical; the character tracked is a double quote).
		continuousBacktick int8
		// shiftDelimiter counts bytes consumed since the last '.' that
		// closed a part.
		shiftDelimiter int8
	)

	for _, v := range []byte(str) {
		switch v {
		case '"':
			continuousBacktick++
			if continuousBacktick == 2 {
				writer.WriteString(`""`)
				continuousBacktick = 0
			}
		case '.':
			if continuousBacktick > 0 || !selfQuoted {
				// Close the current part and reset state for the next one.
				shiftDelimiter = 0
				underQuoted = false
				continuousBacktick = 0
				writer.WriteByte('"')
			}
			writer.WriteByte(v)
			continue
		default:
			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
				// First ordinary byte of a part: emit the opening quote.
				writer.WriteByte('"')
				underQuoted = true
				if selfQuoted = continuousBacktick > 0; selfQuoted {
					continuousBacktick -= 1
				}
			}

			// Flush any pending quotes as escaped (doubled) quotes.
			for ; continuousBacktick > 0; continuousBacktick -= 1 {
				writer.WriteString(`""`)
			}

			writer.WriteByte(v)
		}
		shiftDelimiter++
	}

	if continuousBacktick > 0 && !selfQuoted {
		writer.WriteString(`""`)
	}
	writer.WriteByte('"')
}

// numericPlaceholder matches numbered bind placeholders such as $1 or $12.
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)

// Explain delegates to logger.ExplainSQL, passing the numbered-placeholder
// pattern and a single quote along with sql and vars, and returns its result.
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
}

// DataTypeOf returns the PostgreSQL column type for field.
//
// Integer sizes map to smallint/integer/bigint, or to the matching serial
// type when the field auto-increments; unsigned fields have their size bumped
// by one before the comparison, so a size exactly at a threshold is promoted
// to the next wider type. Data types not handled here are resolved by
// getSchemaCustomType.
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
	switch field.DataType {
	case schema.Bool:
		return "boolean"
	case schema.Int, schema.Uint:
		size := field.Size
		if field.DataType == schema.Uint {
			size++
		}
		if field.AutoIncrement {
			switch {
			case size <= 16:
				return "smallserial"
			case size <= 32:
				return "serial"
			default:
				return "bigserial"
			}
		}
		switch {
		case size <= 16:
			return "smallint"
		case size <= 32:
			return "integer"
		default:
			return "bigint"
		}
	case schema.Float:
		if field.Precision > 0 {
			if field.Scale > 0 {
				return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
			}
			return fmt.Sprintf("numeric(%d)", field.Precision)
		}
		return "decimal"
	case schema.String:
		if field.Size > 0 {
			return fmt.Sprintf("varchar(%d)", field.Size)
		}
		return "text"
	case schema.Time:
		if field.Precision > 0 {
			return fmt.Sprintf("timestamptz(%d)", field.Precision)
		}
		return "timestamptz"
	case schema.Bytes:
		return "bytea"
	default:
		return dialector.getSchemaCustomType(field)
	}
}

// getSchemaCustomType returns field.DataType verbatim as the SQL type, unless
// the field auto-increments and the type name does not already contain
// "serial" (case-insensitively); in that case a serial type is picked from
// the field size, using the same thresholds as DataTypeOf.
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
	sqlType := string(field.DataType)

	if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") {
		size := field.Size
		if field.GORMDataType == schema.Uint {
			size++
		}
		switch {
		case size <= 16:
			sqlType = "smallserial"
		case size <= 32:
			sqlType = "serial"
		default:
			sqlType = "bigserial"
		}
	}

	return sqlType
}

// SavePoint executes "SAVEPOINT <name>" on tx and returns the error recorded
// on the resulting *gorm.DB, if any. name is concatenated into the statement
// without quoting.
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
	return tx.Exec("SAVEPOINT " + name).Error
}

// RollbackTo executes "ROLLBACK TO SAVEPOINT <name>" on tx and returns the
// error recorded on the resulting *gorm.DB, if any. name is concatenated into
// the statement without quoting.
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
	return tx.Exec("ROLLBACK TO SAVEPOINT " + name).Error
}

// getSerialDatabaseType maps a serial pseudo-type name to the integer type it
// corresponds to. ok is false (and dbType empty) when s is not one of
// "smallserial", "serial" or "bigserial"; the comparison is case-sensitive.
func getSerialDatabaseType(s string) (dbType string, ok bool) {
	switch s {
	case "smallserial":
		return "smallint", true
	case "serial":
		return "integer", true
	case "bigserial":
		return "bigint", true
	default:
		return "", false
	}
}