You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-library/vendor/gorm.io/plugin/dbresolver/dbresolver.go

203 lines
4.3 KiB

package dbresolver
import (
"errors"
"sync"
"gorm.io/gorm"
)
const (
Write Operation = "write"
Read Operation = "read"
)
type DBResolver struct {
*gorm.DB
configs []Config
resolvers map[string]*resolver
global *resolver
prepareStmtStore map[gorm.ConnPool]*gorm.PreparedStmtDB
compileCallbacks []func(gorm.ConnPool) error
}
type Config struct {
Sources []gorm.Dialector
Replicas []gorm.Dialector
Policy Policy
datas []interface{}
TraceResolverMode bool
}
func Register(config Config, datas ...interface{}) *DBResolver {
return (&DBResolver{}).Register(config, datas...)
}
func (dr *DBResolver) Register(config Config, datas ...interface{}) *DBResolver {
if dr.prepareStmtStore == nil {
dr.prepareStmtStore = map[gorm.ConnPool]*gorm.PreparedStmtDB{}
}
if dr.resolvers == nil {
dr.resolvers = map[string]*resolver{}
}
if config.Policy == nil {
config.Policy = RandomPolicy{}
}
config.datas = datas
dr.configs = append(dr.configs, config)
if dr.DB != nil {
dr.compileConfig(config)
}
return dr
}
func (dr *DBResolver) Name() string {
return "gorm:db_resolver"
}
func (dr *DBResolver) Initialize(db *gorm.DB) error {
dr.DB = db
dr.registerCallbacks(db)
return dr.compile()
}
func (dr *DBResolver) compile() error {
for _, config := range dr.configs {
if err := dr.compileConfig(config); err != nil {
return err
}
}
return nil
}
func (dr *DBResolver) compileConfig(config Config) (err error) {
var (
connPool = dr.DB.Config.ConnPool
r = resolver{
dbResolver: dr,
policy: config.Policy,
traceResolverMode: config.TraceResolverMode,
}
)
if preparedStmtDB, ok := connPool.(*gorm.PreparedStmtDB); ok {
connPool = preparedStmtDB.ConnPool
}
if len(config.Sources) == 0 {
r.sources = []gorm.ConnPool{connPool}
} else if r.sources, err = dr.convertToConnPool(config.Sources); err != nil {
return err
}
if len(config.Replicas) == 0 {
r.replicas = r.sources
} else if r.replicas, err = dr.convertToConnPool(config.Replicas); err != nil {
return err
}
if len(config.datas) > 0 {
for _, data := range config.datas {
if t, ok := data.(string); ok {
dr.resolvers[t] = &r
} else {
stmt := &gorm.Statement{DB: dr.DB}
if err := stmt.Parse(data); err == nil {
dr.resolvers[stmt.Table] = &r
} else {
return err
}
}
}
} else if dr.global == nil {
dr.global = &r
} else {
return errors.New("conflicted global resolver")
}
for _, fc := range dr.compileCallbacks {
if err = r.call(fc); err != nil {
return err
}
}
if config.TraceResolverMode {
dr.Logger = NewResolverModeLogger(dr.Logger)
}
return nil
}
func (dr *DBResolver) convertToConnPool(dialectors []gorm.Dialector) (connPools []gorm.ConnPool, err error) {
config := *dr.DB.Config
for _, dialector := range dialectors {
if db, err := gorm.Open(dialector, &config); err == nil {
connPool := db.Config.ConnPool
if preparedStmtDB, ok := connPool.(*gorm.PreparedStmtDB); ok {
connPool = preparedStmtDB.ConnPool
}
dr.prepareStmtStore[connPool] = &gorm.PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Stmts: map[string]*gorm.Stmt{},
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
connPools = append(connPools, connPool)
} else {
return nil, err
}
}
return connPools, err
}
func (dr *DBResolver) resolve(stmt *gorm.Statement, op Operation) gorm.ConnPool {
if r := dr.getResolver(stmt); r != nil {
return r.resolve(stmt, op)
}
return stmt.ConnPool
}
func (dr *DBResolver) getResolver(stmt *gorm.Statement) *resolver {
if len(dr.resolvers) > 0 {
if u, ok := stmt.Clauses[usingName].Expression.(using); ok && u.Use != "" {
if r, ok := dr.resolvers[u.Use]; ok {
return r
}
}
if stmt.Table != "" {
if r, ok := dr.resolvers[stmt.Table]; ok {
return r
}
}
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err == nil {
if r, ok := dr.resolvers[stmt.Table]; ok {
return r
}
}
}
if stmt.Schema != nil {
if r, ok := dr.resolvers[stmt.Schema.Table]; ok {
return r
}
}
if rawSQL := stmt.SQL.String(); rawSQL != "" {
if r, ok := dr.resolvers[getTableFromRawSQL(rawSQL)]; ok {
return r
}
}
}
return dr.global
}