You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1063 lines
28 KiB
1063 lines
28 KiB
package gen
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/callbacks"
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/schema"
|
|
|
|
"gorm.io/gen/field"
|
|
"gorm.io/gen/helper"
|
|
)
|
|
|
|
// ResultInfo query/execute info.
//
// ResultInfo reports the outcome of a write executed through DO: the number
// of rows affected and the error reported by the underlying *gorm.DB.
type ResultInfo struct {
	// RowsAffected is copied from the executed *gorm.DB's RowsAffected.
	RowsAffected int64
	// Error is copied from the executed *gorm.DB's Error (nil on success).
	Error error
}
|
|
|
|
// Compile-time assertion that *DO implements Dao.
var _ Dao = new(DO)

// DO (data object) implements the basic query methods on top of a wrapped
// *gorm.DB. Its alias is used when the DO is rendered as a sub query or as a
// joined table.
type DO struct {
	// DOConfig holds the options applied by UseDB.
	*DOConfig

	// db is the wrapped *gorm.DB every method builds on.
	db *gorm.DB
	// alias is set by As; it is used when rendering as a sub query, in
	// UpdateFrom, and as the table qualifier in assignSet.
	alias string // for subquery
	// modelType is the (pointer-dereferenced) model type recorded by UseModel;
	// nil when no model was set.
	modelType reflect.Type
	// tableName is set by UseModel or UseTable and returned by TableName.
	tableName string

	// backfillData is the destination passed to Returning; when non-nil it is
	// used as the model/result pointer instead of a freshly allocated value.
	backfillData interface{}
}
|
|
|
|
func (d DO) getInstance(db *gorm.DB) *DO {
|
|
d.db = db
|
|
return &d
|
|
}
|
|
|
|
type doOptions func(*gorm.DB) *gorm.DB
|
|
|
|
var (
|
|
// Debug use DB in debug mode
|
|
Debug doOptions = func(db *gorm.DB) *gorm.DB { return db.Debug() }
|
|
)
|
|
|
|
// UseDB binds d to a session of db whose context is context.Background(), and
// rebuilds d's DOConfig from opts. Nil options are skipped; if an option's
// Apply returns an error, UseDB panics with that error.
func (d *DO) UseDB(db *gorm.DB, opts ...DOOption) {
	d.db = db.Session(&gorm.Session{Context: context.Background()})

	cfg := &DOConfig{}
	for _, o := range opts {
		if o == nil {
			continue
		}
		if err := o.Apply(cfg); err != nil {
			panic(err)
		}
	}
	d.DOConfig = cfg
}
|
|
|
|
// ReplaceDB replace db connection
|
|
func (d *DO) ReplaceDB(db *gorm.DB) {
|
|
d.db = db.Session(&gorm.Session{})
|
|
}
|
|
|
|
// ReplaceConnPool replace db connection pool
|
|
func (d *DO) ReplaceConnPool(pool gorm.ConnPool) {
|
|
d.db = d.db.Session(&gorm.Session{Initialized: true}).Session(&gorm.Session{})
|
|
d.db.Statement.ConnPool = pool
|
|
}
|
|
|
|
// UseModel records model's (dereferenced) type and parses model on the
// current statement; the table name is then taken from the parsed schema.
// It panics if the model cannot be parsed.
func (d *DO) UseModel(model interface{}) {
	d.modelType = d.indirect(model)

	stmt := d.db.Statement
	if err := stmt.Parse(model); err != nil {
		panic(fmt.Errorf("Cannot parse model: %+v\n%w", model, err))
	}
	d.tableName = stmt.Schema.Table
}
|
|
|
|
// indirect returns the type of value, dereferencing a single pointer level.
func (d *DO) indirect(value interface{}) reflect.Type {
	t := reflect.TypeOf(value)
	if t.Kind() != reflect.Ptr {
		return t
	}
	return t.Elem()
}
|
|
|
|
// UseTable specify table name
|
|
func (d *DO) UseTable(tableName string) {
|
|
d.db = d.db.Table(tableName).Session(new(gorm.Session))
|
|
//d.db.Statement.Schema.Table=tableName
|
|
d.tableName = tableName
|
|
}
|
|
|
|
// TableName return table name
|
|
func (d DO) TableName() string {
|
|
return d.tableName
|
|
}
|
|
|
|
// Returning backfill data
|
|
func (d DO) Returning(value interface{}, columns ...string) Dao {
|
|
d.backfillData = value
|
|
|
|
var targetCulumns []clause.Column
|
|
for _, column := range columns {
|
|
targetCulumns = append(targetCulumns, clause.Column{Name: column})
|
|
}
|
|
d.db = d.db.Clauses(clause.Returning{Columns: targetCulumns})
|
|
return &d
|
|
}
|
|
|
|
// Session replace db with new session
|
|
func (d *DO) Session(config *gorm.Session) Dao { return d.getInstance(d.db.Session(config)) }
|
|
|
|
// UnderlyingDB return the underlying database connection
|
|
func (d *DO) UnderlyingDB() *gorm.DB { return d.underlyingDB() }
|
|
|
|
// Quote return qutoed data
|
|
func (d *DO) Quote(raw string) string { return d.db.Statement.Quote(raw) }
|
|
|
|
// Build implement the interface of claues.Expression
|
|
// only call WHERE clause's Build
|
|
func (d *DO) Build(builder clause.Builder) {
|
|
for _, e := range d.buildCondition() {
|
|
e.Build(builder)
|
|
}
|
|
}
|
|
|
|
func (d *DO) buildCondition() []clause.Expression {
|
|
return d.db.Statement.BuildCondition(d.db)
|
|
}
|
|
|
|
// underlyingDO return self
|
|
func (d *DO) underlyingDO() *DO { return d }
|
|
|
|
// underlyingDB return self.db
|
|
func (d *DO) underlyingDB() *gorm.DB { return d.db }
|
|
|
|
func (d *DO) withError(err error) *DO {
|
|
if err == nil {
|
|
return d
|
|
}
|
|
|
|
newDB := d.db.Session(new(gorm.Session))
|
|
_ = newDB.AddError(err)
|
|
return d.getInstance(newDB)
|
|
}
|
|
|
|
// BeCond implements Condition by exposing d's built condition expressions.
func (d *DO) BeCond() interface{} {
	return d.buildCondition()
}

// CondError implements Condition; a DO never reports a condition error.
func (d *DO) CondError() error {
	return nil
}

// Debug returns a DO whose db is in debug mode.
func (d *DO) Debug() Dao {
	return d.getInstance(d.db.Debug())
}

// WithContext returns a DO whose db carries ctx.
func (d *DO) WithContext(ctx context.Context) Dao {
	return d.getInstance(d.db.WithContext(ctx))
}
|
|
|
|
// Clauses specify Clauses
|
|
func (d *DO) Clauses(conds ...clause.Expression) Dao {
|
|
if err := checkConds(conds); err != nil {
|
|
newDB := d.db.Session(new(gorm.Session))
|
|
_ = newDB.AddError(err)
|
|
return d.getInstance(newDB)
|
|
}
|
|
return d.getInstance(d.db.Clauses(conds...))
|
|
}
|
|
|
|
// As alias cannot be heired, As must used on tail
|
|
func (d DO) As(alias string) Dao {
|
|
d.alias = alias
|
|
d.db = d.db.Table(fmt.Sprintf("%s AS %s", d.Quote(d.TableName()), d.Quote(alias)))
|
|
return &d
|
|
}
|
|
|
|
// Alias return alias name
|
|
func (d *DO) Alias() string { return d.alias }
|
|
|
|
// Columns return columns for Subquery
|
|
func (*DO) Columns(cols ...field.Expr) Columns { return cols }
|
|
|
|
// ======================== chainable api ========================
|
|
|
|
// Not ...
|
|
func (d *DO) Not(conds ...Condition) Dao {
|
|
exprs, err := condToExpression(conds)
|
|
if err != nil {
|
|
return d.withError(err)
|
|
}
|
|
if len(exprs) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Clauses(clause.Where{Exprs: []clause.Expression{clause.Not(exprs...)}}))
|
|
}
|
|
|
|
// Or ...
|
|
func (d *DO) Or(conds ...Condition) Dao {
|
|
exprs, err := condToExpression(conds)
|
|
if err != nil {
|
|
return d.withError(err)
|
|
}
|
|
if len(exprs) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Clauses(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(exprs...))}}))
|
|
}
|
|
|
|
// Select ...
|
|
func (d *DO) Select(columns ...field.Expr) Dao {
|
|
if len(columns) == 0 {
|
|
return d.getInstance(d.db.Clauses(clause.Select{}))
|
|
}
|
|
query, args := buildExpr4Select(d.db.Statement, columns...)
|
|
return d.getInstance(d.db.Select(query, args...))
|
|
}
|
|
|
|
// Where ...
|
|
func (d *DO) Where(conds ...Condition) Dao {
|
|
exprs, err := condToExpression(conds)
|
|
if err != nil {
|
|
return d.withError(err)
|
|
}
|
|
if len(exprs) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Clauses(clause.Where{Exprs: exprs}))
|
|
}
|
|
|
|
// Order ...
|
|
func (d *DO) Order(columns ...field.Expr) Dao {
|
|
// lazy build Columns
|
|
// if c, ok := d.db.Statement.Clauses[clause.OrderBy{}.Name()]; ok {
|
|
// if order, ok := c.Expression.(clause.OrderBy); ok {
|
|
// if expr, ok := order.Expression.(clause.CommaExpression); ok {
|
|
// expr.Exprs = append(expr.Exprs, toExpression(columns)...)
|
|
// return d.newInstance(d.db.Clauses(clause.OrderBy{Expression: expr}))
|
|
// }
|
|
// }
|
|
// }
|
|
// return d.newInstance(d.db.Clauses(clause.OrderBy{Expression: clause.CommaExpression{Exprs: toExpression(columns)}}))
|
|
if len(columns) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Order(d.toOrderValue(columns...)))
|
|
}
|
|
|
|
func (d *DO) toOrderValue(columns ...field.Expr) string {
|
|
// eager build Columns
|
|
stmt := &gorm.Statement{DB: d.db.Statement.DB, Table: d.db.Statement.Table, Schema: d.db.Statement.Schema}
|
|
|
|
for i, c := range columns {
|
|
if i != 0 {
|
|
stmt.WriteByte(',')
|
|
}
|
|
c.Build(stmt)
|
|
}
|
|
|
|
return d.db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
|
}
|
|
|
|
// Distinct ...
|
|
func (d *DO) Distinct(columns ...field.Expr) Dao {
|
|
return d.getInstance(d.db.Distinct(toInterfaceSlice(toColExprFullName(d.db.Statement, columns...))...))
|
|
}
|
|
|
|
// Omit ...
|
|
func (d *DO) Omit(columns ...field.Expr) Dao {
|
|
if len(columns) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Omit(getColumnName(columns...)...))
|
|
}
|
|
|
|
// Group ...
|
|
func (d *DO) Group(columns ...field.Expr) Dao {
|
|
if len(columns) == 0 {
|
|
return d
|
|
}
|
|
|
|
stmt := &gorm.Statement{DB: d.db.Statement.DB, Table: d.db.Statement.Table, Schema: d.db.Statement.Schema}
|
|
|
|
for i, c := range columns {
|
|
if i != 0 {
|
|
stmt.WriteByte(',')
|
|
}
|
|
c.Build(stmt)
|
|
}
|
|
|
|
return d.getInstance(d.db.Group(stmt.SQL.String()))
|
|
}
|
|
|
|
// Having ...
|
|
func (d *DO) Having(conds ...Condition) Dao {
|
|
exprs, err := condToExpression(conds)
|
|
if err != nil {
|
|
return d.withError(err)
|
|
}
|
|
if len(exprs) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Clauses(clause.GroupBy{Having: exprs}))
|
|
}
|
|
|
|
// Limit ...
|
|
func (d *DO) Limit(limit int) Dao {
|
|
return d.getInstance(d.db.Limit(limit))
|
|
}
|
|
|
|
// Offset ...
|
|
func (d *DO) Offset(offset int) Dao {
|
|
return d.getInstance(d.db.Offset(offset))
|
|
}
|
|
|
|
// Scopes ...
|
|
func (d *DO) Scopes(funcs ...func(Dao) Dao) Dao {
|
|
fcs := make([]func(*gorm.DB) *gorm.DB, len(funcs))
|
|
for i, f := range funcs {
|
|
sf := f
|
|
fcs[i] = func(tx *gorm.DB) *gorm.DB { return sf(d.getInstance(tx)).(*DO).db }
|
|
}
|
|
return d.getInstance(d.db.Scopes(fcs...))
|
|
}
|
|
|
|
// Unscoped ...
|
|
func (d *DO) Unscoped() Dao {
|
|
return d.getInstance(d.db.Unscoped())
|
|
}
|
|
|
|
// Join ...
|
|
func (d *DO) Join(table schema.Tabler, conds ...field.Expr) Dao {
|
|
return d.join(table, clause.InnerJoin, conds)
|
|
}
|
|
|
|
// LeftJoin ...
|
|
func (d *DO) LeftJoin(table schema.Tabler, conds ...field.Expr) Dao {
|
|
return d.join(table, clause.LeftJoin, conds)
|
|
}
|
|
|
|
// RightJoin ...
|
|
func (d *DO) RightJoin(table schema.Tabler, conds ...field.Expr) Dao {
|
|
return d.join(table, clause.RightJoin, conds)
|
|
}
|
|
|
|
func (d *DO) join(table schema.Tabler, joinType clause.JoinType, conds []field.Expr) Dao {
|
|
if len(conds) == 0 {
|
|
return d.withError(ErrEmptyCondition)
|
|
}
|
|
|
|
join := clause.Join{
|
|
Type: joinType,
|
|
Table: clause.Table{Name: table.TableName()},
|
|
ON: clause.Where{Exprs: toExpression(conds...)},
|
|
}
|
|
if do, ok := table.(Dao); ok {
|
|
join.Expression = helper.NewJoinTblExpr(join, Table(do).underlyingDB().Statement.TableExpr)
|
|
}
|
|
if al, ok := table.(interface{ Alias() string }); ok {
|
|
join.Table.Alias = al.Alias()
|
|
}
|
|
|
|
from := getFromClause(d.db)
|
|
from.Joins = append(from.Joins, join)
|
|
return d.getInstance(d.db.Clauses(from))
|
|
}
|
|
|
|
// Attrs ...
|
|
func (d *DO) Attrs(attrs ...field.AssignExpr) Dao {
|
|
if len(attrs) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Attrs(d.attrsValue(attrs)...))
|
|
}
|
|
|
|
// Assign ...
|
|
func (d *DO) Assign(attrs ...field.AssignExpr) Dao {
|
|
if len(attrs) == 0 {
|
|
return d
|
|
}
|
|
return d.getInstance(d.db.Assign(d.attrsValue(attrs)...))
|
|
}
|
|
|
|
func (d *DO) attrsValue(attrs []field.AssignExpr) []interface{} {
|
|
values := make([]interface{}, 0, len(attrs))
|
|
for _, attr := range attrs {
|
|
if expr, ok := attr.AssignExpr().(field.IValues); ok {
|
|
values = append(values, expr.Values())
|
|
} else if expr, ok := attr.AssignExpr().(clause.Eq); ok {
|
|
values = append(values, expr)
|
|
}
|
|
}
|
|
return values
|
|
}
|
|
|
|
// Joins adds a join on the relation described by field (gorm Joins on
// field.Path()) and forwards the relation's extra settings as join args:
//
//   - conditions: only =, <>, >, >=, <, <= and LIKE expressions are forwarded.
//     Each bare clause.Column is re-qualified with the relation name
//     (field.Name()), and they are carried in the WHERE clause of a *gorm.DB
//     derived from d.db.
//     NOTE(review): other condition kinds (e.g. IN, IS NULL) are silently
//     dropped here, and d.db.Clauses may merge with a WHERE already on d.db.
//     Confirm both are intended.
//   - selected columns, order columns, extra clauses, scopes and paging.
//
// The column re-qualification is shared by all comparison kinds through one
// helper instead of being repeated per kind.
func (d *DO) Joins(field field.RelationField) Dao {
	var args []interface{}
	relName := field.Name()

	// qualify points a clause.Column at the joined relation; any other column
	// value is returned unchanged.
	qualify := func(col interface{}) interface{} {
		if c, ok := col.(clause.Column); ok {
			c.Table = relName
			return c
		}
		return col
	}

	if conds := field.GetConds(); len(conds) > 0 {
		var exprs []clause.Expression
		for _, oe := range toExpression(conds...) {
			switch e := oe.(type) {
			case clause.Eq:
				e.Column = qualify(e.Column)
				exprs = append(exprs, e)
			case clause.Neq:
				e.Column = qualify(e.Column)
				exprs = append(exprs, e)
			case clause.Gt:
				e.Column = qualify(e.Column)
				exprs = append(exprs, e)
			case clause.Gte:
				e.Column = qualify(e.Column)
				exprs = append(exprs, e)
			case clause.Lt:
				e.Column = qualify(e.Column)
				exprs = append(exprs, e)
			case clause.Lte:
				e.Column = qualify(e.Column)
				exprs = append(exprs, e)
			case clause.Like:
				e.Column = qualify(e.Column)
				exprs = append(exprs, e)
			}
		}
		args = append(args, d.db.Clauses(clause.Where{Exprs: exprs}))
	}

	if columns := field.GetSelects(); len(columns) > 0 {
		colNames := make([]string, len(columns))
		for i, c := range columns {
			colNames[i] = string(c.ColumnName())
		}
		args = append(args, func(db *gorm.DB) *gorm.DB {
			return db.Select(colNames)
		})
	}

	if columns := field.GetOrderCol(); len(columns) > 0 {
		var orders []string
		for _, oe := range columns {
			e, ok := oe.RawExpr().(clause.Expr)
			if !ok {
				continue
			}
			// Re-point column vars at the relation; non-column vars are not kept.
			vs := []interface{}{}
			for _, v := range e.Vars {
				if c, ok := v.(clause.Column); ok {
					vs = append(vs, clause.Column{Table: relName, Name: c.Name, Alias: c.Alias, Raw: c.Raw})
				}
			}
			e.Vars = vs

			newStmt := &gorm.Statement{DB: d.db.Statement.DB, Table: d.db.Statement.Table, Schema: d.db.Statement.Schema}
			e.Build(newStmt)
			orders = append(orders, newStmt.SQL.String())
		}
		args = append(args, d.db.Order(strings.Join(orders, ",")))
	}

	if clauses := field.GetClauses(); len(clauses) > 0 {
		args = append(args, func(db *gorm.DB) *gorm.DB {
			return db.Clauses(clauses...)
		})
	}

	for _, f := range field.GetScopes() {
		args = append(args, (func(*gorm.DB) *gorm.DB)(f))
	}

	if offset, limit := field.GetPage(); offset|limit != 0 {
		args = append(args, func(db *gorm.DB) *gorm.DB {
			return db.Offset(offset).Limit(limit)
		})
	}

	return d.getInstance(d.db.Joins(field.Path(), args...))
}
|
|
|
|
// Preload ...
|
|
func (d *DO) Preload(field field.RelationField) Dao {
|
|
var args []interface{}
|
|
if conds := field.GetConds(); len(conds) > 0 {
|
|
args = append(args, toExpressionInterface(conds...)...)
|
|
}
|
|
if columns := field.GetSelects(); len(columns) > 0 {
|
|
colNames := make([]string, len(columns))
|
|
for i, c := range columns {
|
|
colNames[i] = string(c.ColumnName())
|
|
}
|
|
args = append(args, func(db *gorm.DB) *gorm.DB {
|
|
return db.Select(colNames)
|
|
})
|
|
}
|
|
if columns := field.GetOrderCol(); len(columns) > 0 {
|
|
args = append(args, func(db *gorm.DB) *gorm.DB {
|
|
return db.Order(d.toOrderValue(columns...))
|
|
})
|
|
}
|
|
if clauses := field.GetClauses(); len(clauses) > 0 {
|
|
args = append(args, func(db *gorm.DB) *gorm.DB {
|
|
return db.Clauses(clauses...)
|
|
})
|
|
}
|
|
if funcs := field.GetScopes(); len(funcs) > 0 {
|
|
for _, f := range funcs {
|
|
args = append(args, (func(*gorm.DB) *gorm.DB)(f))
|
|
}
|
|
}
|
|
if offset, limit := field.GetPage(); offset|limit != 0 {
|
|
args = append(args, func(db *gorm.DB) *gorm.DB {
|
|
return db.Offset(offset).Limit(limit)
|
|
})
|
|
}
|
|
return d.getInstance(d.db.Preload(field.Path(), args...))
|
|
}
|
|
|
|
// UpdateFrom returns a DO whose UPDATE target is the raw table list
// "<d's table> [AS alias],<q>".
//
// q is rendered as a parenthesized sub query when it has a SELECT clause or
// Selects. Otherwise it is rendered as its quoted table name. q's alias is
// appended when set.
func (d *DO) UpdateFrom(q SubQuery) Dao {
	stmt := d.db.Statement
	qDB, qDO := q.underlyingDB(), q.underlyingDO()

	var target strings.Builder
	stmt.QuoteTo(&target, d.TableName())
	if d.alias != "" {
		target.WriteString(" AS ")
		stmt.QuoteTo(&target, d.alias)
	}
	target.WriteByte(',')

	_, hasSelectClause := qDB.Statement.Clauses["SELECT"]
	if hasSelectClause || len(qDB.Statement.Selects) > 0 {
		sub := qDB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Table(qDO.TableName()).Find(nil) })
		target.WriteString("(" + sub + ")")
	} else {
		stmt.QuoteTo(&target, qDO.TableName())
	}

	if qDO.alias != "" {
		target.WriteString(" AS ")
		stmt.QuoteTo(&target, qDO.alias)
	}

	return d.getInstance(d.db.Clauses(clause.Update{Table: clause.Table{Name: target.String(), Raw: true}}))
}
|
|
|
|
func getFromClause(db *gorm.DB) *clause.From {
|
|
if db == nil || db.Statement == nil {
|
|
return &clause.From{}
|
|
}
|
|
c, ok := db.Statement.Clauses[clause.From{}.Name()]
|
|
if !ok || c.Expression == nil {
|
|
return &clause.From{}
|
|
}
|
|
from, ok := c.Expression.(clause.From)
|
|
if !ok {
|
|
return &clause.From{}
|
|
}
|
|
return &from
|
|
}
|
|
|
|
// ======================== finisher api ========================
|
|
|
|
// Create ...
|
|
func (d *DO) Create(value interface{}) error {
|
|
return d.db.Create(value).Error
|
|
}
|
|
|
|
// CreateInBatches ...
|
|
func (d *DO) CreateInBatches(value interface{}, batchSize int) error {
|
|
return d.db.CreateInBatches(value, batchSize).Error
|
|
}
|
|
|
|
// Save ...
|
|
func (d *DO) Save(value interface{}) error {
|
|
return d.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(value).Error
|
|
}
|
|
|
|
// First ...
|
|
func (d *DO) First() (result interface{}, err error) {
|
|
return d.singleQuery(d.db.First)
|
|
}
|
|
|
|
// Take ...
|
|
func (d *DO) Take() (result interface{}, err error) {
|
|
return d.singleQuery(d.db.Take)
|
|
}
|
|
|
|
// Last ...
|
|
func (d *DO) Last() (result interface{}, err error) {
|
|
return d.singleQuery(d.db.Last)
|
|
}
|
|
|
|
func (d *DO) singleQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (result interface{}, err error) {
|
|
if d.modelType == nil {
|
|
return d.singleScan()
|
|
}
|
|
|
|
result = d.newResultPointer()
|
|
if err := query(result).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (d *DO) singleScan() (result interface{}, err error) {
|
|
result = map[string]interface{}{}
|
|
err = d.db.Scan(result).Error
|
|
return
|
|
}
|
|
|
|
// Find ...
|
|
func (d *DO) Find() (results interface{}, err error) {
|
|
return d.multiQuery(d.db.Find)
|
|
}
|
|
|
|
// multiQuery runs query into a new *[]*Model and returns the dereferenced
// slice along with the query error. The slice is returned even on error.
// Without a model type, it delegates to findToMap.
func (d *DO) multiQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (results interface{}, err error) {
	if d.modelType == nil {
		return d.findToMap()
	}

	slicePtr := d.newResultSlicePointer()
	queryErr := query(slicePtr).Error
	return reflect.ValueOf(slicePtr).Elem().Interface(), queryErr
}
|
|
|
|
func (d *DO) findToMap() (interface{}, error) {
|
|
var results []map[string]interface{}
|
|
err := d.db.Find(&results).Error
|
|
return results, err
|
|
}
|
|
|
|
// FindInBatches ...
|
|
func (d *DO) FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error {
|
|
return d.db.FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
|
|
}
|
|
|
|
// FirstOrInit ...
|
|
func (d *DO) FirstOrInit() (result interface{}, err error) {
|
|
return d.singleQuery(d.db.FirstOrInit)
|
|
}
|
|
|
|
// FirstOrCreate ...
|
|
func (d *DO) FirstOrCreate() (result interface{}, err error) {
|
|
return d.singleQuery(d.db.FirstOrCreate)
|
|
}
|
|
|
|
// Update ...
|
|
func (d *DO) Update(column field.Expr, value interface{}) (info ResultInfo, err error) {
|
|
tx := d.db.Model(d.newResultPointer())
|
|
columnStr := column.BuildColumn(d.db.Statement, field.WithoutQuote).String()
|
|
|
|
var result *gorm.DB
|
|
switch value := value.(type) {
|
|
case field.AssignExpr:
|
|
result = tx.Update(columnStr, value.AssignExpr())
|
|
case SubQuery:
|
|
result = tx.Update(columnStr, value.underlyingDB())
|
|
default:
|
|
result = tx.Update(columnStr, value)
|
|
}
|
|
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
|
|
}
|
|
|
|
// UpdateSimple ...
|
|
func (d *DO) UpdateSimple(columns ...field.AssignExpr) (info ResultInfo, err error) {
|
|
if len(columns) == 0 {
|
|
return
|
|
}
|
|
|
|
result := d.db.Model(d.newResultPointer()).Clauses(d.assignSet(columns)).Omit("*").Updates(map[string]interface{}{})
|
|
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
|
|
}
|
|
|
|
// Updates ...
|
|
func (d *DO) Updates(value interface{}) (info ResultInfo, err error) {
|
|
var rawTyp, valTyp reflect.Type
|
|
|
|
rawTyp = reflect.TypeOf(value)
|
|
if rawTyp.Kind() == reflect.Ptr {
|
|
valTyp = rawTyp.Elem()
|
|
} else {
|
|
valTyp = rawTyp
|
|
}
|
|
|
|
tx := d.db
|
|
if d.backfillData != nil {
|
|
tx = tx.Model(d.backfillData)
|
|
}
|
|
switch {
|
|
case valTyp != d.modelType: // different type with model
|
|
if d.backfillData == nil {
|
|
tx = tx.Model(d.newResultPointer())
|
|
}
|
|
case rawTyp.Kind() == reflect.Ptr: // ignore ptr value
|
|
default: // for fixing "reflect.Value.Addr of unaddressable value" panic
|
|
ptr := reflect.New(d.modelType)
|
|
ptr.Elem().Set(reflect.ValueOf(value))
|
|
value = ptr.Interface()
|
|
}
|
|
result := tx.Updates(value)
|
|
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
|
|
}
|
|
|
|
// UpdateColumn ...
|
|
func (d *DO) UpdateColumn(column field.Expr, value interface{}) (info ResultInfo, err error) {
|
|
tx := d.db.Model(d.newResultPointer())
|
|
columnStr := column.BuildColumn(d.db.Statement, field.WithoutQuote).String()
|
|
|
|
var result *gorm.DB
|
|
switch value := value.(type) {
|
|
case field.Expr:
|
|
result = tx.UpdateColumn(columnStr, value.RawExpr())
|
|
case SubQuery:
|
|
result = d.db.UpdateColumn(columnStr, value.underlyingDB())
|
|
default:
|
|
result = d.db.UpdateColumn(columnStr, value)
|
|
}
|
|
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
|
|
}
|
|
|
|
// UpdateColumnSimple ...
|
|
func (d *DO) UpdateColumnSimple(columns ...field.AssignExpr) (info ResultInfo, err error) {
|
|
if len(columns) == 0 {
|
|
return
|
|
}
|
|
|
|
result := d.db.Model(d.newResultPointer()).Clauses(d.assignSet(columns)).Omit("*").UpdateColumns(map[string]interface{}{})
|
|
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
|
|
}
|
|
|
|
// UpdateColumns ...
|
|
func (d *DO) UpdateColumns(value interface{}) (info ResultInfo, err error) {
|
|
result := d.db.Model(d.newResultPointer()).UpdateColumns(value)
|
|
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
|
|
}
|
|
|
|
// assignSet fetch all set
|
|
func (d *DO) assignSet(exprs []field.AssignExpr) (set clause.Set) {
|
|
for _, expr := range exprs {
|
|
column := clause.Column{Table: d.alias, Name: string(expr.ColumnName())}
|
|
switch e := expr.AssignExpr().(type) {
|
|
case clause.Expr:
|
|
set = append(set, clause.Assignment{Column: column, Value: e})
|
|
case clause.Eq:
|
|
set = append(set, clause.Assignment{Column: column, Value: e.Value})
|
|
case clause.Set:
|
|
set = append(set, e...)
|
|
}
|
|
}
|
|
|
|
stmt := d.db.Session(&gorm.Session{}).Statement
|
|
stmt.Dest = map[string]interface{}{}
|
|
return append(set, callbacks.ConvertToAssignments(stmt)...)
|
|
}
|
|
|
|
// Delete deletes records of the model type.
//
// With no arguments, or when models[0] has length 0, it deletes through the
// current conditions, using a session scoped to the model from
// newResultPointer and a zero model value as the target. Otherwise models[0]
// must be a slice or array of *Model values; only models[0] is inspected. Its
// elements are copied into a []*Model, which is passed to gorm's Delete.
// A models[0] that is nil or not length-bearing makes reflect panic.
func (d *DO) Delete(models ...interface{}) (info ResultInfo, err error) {
	var result *gorm.DB
	if len(models) == 0 || reflect.ValueOf(models[0]).Len() == 0 {
		result = d.db.Model(d.newResultPointer()).Delete(reflect.New(d.modelType).Interface())
	} else {
		items := reflect.ValueOf(models[0])
		n := items.Len()
		// Capacity is the element count of models[0]; len(models) is always 1
		// here and says nothing about how many records are deleted.
		targets := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(d.modelType)), 0, n)
		for i := 0; i < n; i++ {
			targets = reflect.Append(targets, items.Index(i))
		}
		result = d.db.Delete(targets.Interface())
	}
	return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}
|
|
|
|
// Count ...
|
|
func (d *DO) Count() (count int64, err error) {
|
|
return count, d.db.Session(&gorm.Session{}).Model(d.newResultPointer()).Count(&count).Error
|
|
}
|
|
|
|
// Row runs the query on a session scoped to the model from newResultPointer
// and returns a *sql.Row.
func (d *DO) Row() *sql.Row {
	tx := d.db.Model(d.newResultPointer())
	return tx.Row()
}

// Rows runs the query on a session scoped to the model from newResultPointer
// and returns the *sql.Rows.
func (d *DO) Rows() (*sql.Rows, error) {
	tx := d.db.Model(d.newResultPointer())
	return tx.Rows()
}

// Scan scans the query result into dest.
func (d *DO) Scan(dest interface{}) error {
	tx := d.db.Model(d.newResultPointer())
	return tx.Scan(dest).Error
}

// Pluck queries the single column named by column.ColumnName() into dest.
func (d *DO) Pluck(column field.Expr, dest interface{}) error {
	tx := d.db.Model(d.newResultPointer())
	return tx.Pluck(column.ColumnName().String(), dest).Error
}

// ScanRows scans the current row of rows into dest.
func (d *DO) ScanRows(rows *sql.Rows, dest interface{}) error {
	tx := d.db.Model(d.newResultPointer())
	return tx.ScanRows(rows, dest)
}
|
|
|
|
// WithResult ...
|
|
func (d DO) WithResult(fc func(tx Dao)) ResultInfo {
|
|
d.db = d.db.Set("", "")
|
|
fc(&d)
|
|
return ResultInfo{RowsAffected: d.db.RowsAffected, Error: d.db.Error}
|
|
}
|
|
|
|
// newResultPointer returns the Returning backfill destination when set. It
// returns nil when no model type is known, and otherwise a new *Model.
func (d *DO) newResultPointer() interface{} {
	switch {
	case d.backfillData != nil:
		return d.backfillData
	case d.modelType == nil:
		return nil
	default:
		return reflect.New(d.modelType).Interface()
	}
}

// newResultSlicePointer returns a pointer to a new, nil []*Model.
func (d *DO) newResultSlicePointer() interface{} {
	sliceType := reflect.SliceOf(reflect.PtrTo(d.modelType))
	return reflect.New(sliceType).Interface()
}
|
|
|
|
func (d *DO) AddError(err error) error {
|
|
return d.underlyingDB().AddError(err)
|
|
}
|
|
|
|
func toColExprFullName(stmt *gorm.Statement, columns ...field.Expr) []string {
|
|
return buildColExpr(stmt, columns, field.WithAll)
|
|
}
|
|
|
|
func getColumnName(columns ...field.Expr) (result []string) {
|
|
for _, c := range columns {
|
|
result = append(result, c.ColumnName().String())
|
|
}
|
|
return result
|
|
}
|
|
|
|
func buildColExpr(stmt *gorm.Statement, cols []field.Expr, opts ...field.BuildOpt) []string {
|
|
results := make([]string, len(cols))
|
|
for i, c := range cols {
|
|
switch c.RawExpr().(type) {
|
|
case clause.Column:
|
|
results[i] = c.BuildColumn(stmt, opts...).String()
|
|
case clause.Expression:
|
|
sql, args := c.BuildWithArgs(stmt)
|
|
results[i] = stmt.Dialector.Explain(sql.String(), args...)
|
|
}
|
|
}
|
|
return results
|
|
}
|
|
|
|
func buildExpr4Select(stmt *gorm.Statement, exprs ...field.Expr) (query string, args []interface{}) {
|
|
if len(exprs) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
var queryItems []string
|
|
for _, e := range exprs {
|
|
sql, vars := e.BuildWithArgs(stmt)
|
|
queryItems = append(queryItems, sql.String())
|
|
args = append(args, vars...)
|
|
}
|
|
if len(args) == 0 {
|
|
return queryItems[0], toInterfaceSlice(queryItems[1:])
|
|
}
|
|
return strings.Join(queryItems, ","), args
|
|
}
|
|
|
|
func toExpression(exprs ...field.Expr) []clause.Expression {
|
|
result := make([]clause.Expression, len(exprs))
|
|
for i, e := range exprs {
|
|
result[i] = singleExpr(e)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func toExpressionInterface(exprs ...field.Expr) []interface{} {
|
|
result := make([]interface{}, len(exprs))
|
|
for i, e := range exprs {
|
|
result[i] = singleExpr(e)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func singleExpr(e field.Expr) clause.Expression {
|
|
switch v := e.RawExpr().(type) {
|
|
case clause.Expression:
|
|
return v
|
|
case clause.Column:
|
|
return clause.NamedExpr{SQL: "?", Vars: []interface{}{v}}
|
|
default:
|
|
return clause.Expr{}
|
|
}
|
|
}
|
|
|
|
func toInterfaceSlice(value interface{}) []interface{} {
|
|
switch v := value.(type) {
|
|
case string:
|
|
return []interface{}{v}
|
|
case []string:
|
|
res := make([]interface{}, len(v))
|
|
for i, item := range v {
|
|
res[i] = item
|
|
}
|
|
return res
|
|
case []clause.Column:
|
|
res := make([]interface{}, len(v))
|
|
for i, item := range v {
|
|
res[i] = item
|
|
}
|
|
return res
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ======================== New Table ========================
|
|
|
|
// Table returns a new DO whose table is built from subQueries. Each sub query
// is rendered as "(?)", plus " AS <quoted alias>" when it has an alias, and
// the parts are joined with ", ". The result has to be used as a root node:
//
//	Table(u.Select(u.ID, u.Name).Where(u.Age.Gt(18))).Select()
//
// is equivalent to the SQL statement:
//
//	SELECT * FROM (SELECT `id`, `name` FROM `users_info` WHERE `age` > ?)
//
// With no sub queries an empty DO is returned. The new DO's db is a NewDB
// session of the first sub query's db.
func Table(subQueries ...SubQuery) Dao {
	if len(subQueries) == 0 {
		return &DO{}
	}

	placeholders := make([]string, 0, len(subQueries))
	exprs := make([]interface{}, 0, len(subQueries))
	for _, sq := range subQueries {
		do := sq.underlyingDO()
		// The inner statement uses the bare table name. Its alias goes on the
		// outer placeholder so it is not confused with the sub query alias.
		exprs = append(exprs, do.db.Table(do.TableName()))

		ph := "(?)"
		if do.alias != "" {
			ph += " AS " + do.Quote(do.alias)
		}
		placeholders = append(placeholders, ph)
	}

	base := subQueries[0].underlyingDO().db.Session(&gorm.Session{NewDB: true})
	return &DO{db: base.Table(strings.Join(placeholders, ", "), exprs...)}
}
|
|
|
|
// Exists EXISTS expression
|
|
// SELECT * FROM table WHERE EXISTS (SELECT NAME FROM users WHERE id = 1)
|
|
func Exists(subQuery SubQuery) Condition {
|
|
return field.CompareSubQuery(field.ExistsOp, nil, subQuery.underlyingDB())
|
|
}
|
|
|
|
// ======================== sub query method ========================
|
|
|
|
// Columns columns array
|
|
type Columns []field.Expr
|
|
|
|
// Set assign value by subquery
|
|
func (cs Columns) Set(query SubQuery) field.AssignExpr {
|
|
return field.AssignSubQuery(cs, query.underlyingDB())
|
|
}
|
|
|
|
// In accept query or value
|
|
func (cs Columns) In(queryOrValue Condition) field.Expr {
|
|
if len(cs) == 0 {
|
|
return field.EmptyExpr()
|
|
}
|
|
|
|
switch query := queryOrValue.(type) {
|
|
case field.Value:
|
|
return field.ContainsValue(cs, query)
|
|
case SubQuery:
|
|
return field.ContainsSubQuery(cs, query.underlyingDB())
|
|
default:
|
|
return field.EmptyExpr()
|
|
}
|
|
}
|
|
|
|
// NotIn ...
|
|
func (cs Columns) NotIn(queryOrValue Condition) field.Expr {
|
|
return field.Not(cs.In(queryOrValue))
|
|
}
|
|
|
|
// Eq ...
|
|
func (cs Columns) Eq(query SubQuery) field.Expr {
|
|
if len(cs) == 0 {
|
|
return field.EmptyExpr()
|
|
}
|
|
return field.CompareSubQuery(field.EqOp, cs[0], query.underlyingDB())
|
|
}
|
|
|
|
// Neq ...
|
|
func (cs Columns) Neq(query SubQuery) field.Expr {
|
|
if len(cs) == 0 {
|
|
return field.EmptyExpr()
|
|
}
|
|
return field.CompareSubQuery(field.NeqOp, cs[0], query.underlyingDB())
|
|
}
|
|
|
|
// Gt ...
|
|
func (cs Columns) Gt(query SubQuery) field.Expr {
|
|
if len(cs) == 0 {
|
|
return field.EmptyExpr()
|
|
}
|
|
return field.CompareSubQuery(field.GtOp, cs[0], query.underlyingDB())
|
|
}
|
|
|
|
// Gte ...
|
|
func (cs Columns) Gte(query SubQuery) field.Expr {
|
|
if len(cs) == 0 {
|
|
return field.EmptyExpr()
|
|
}
|
|
return field.CompareSubQuery(field.GteOp, cs[0], query.underlyingDB())
|
|
}
|
|
|
|
// Lt ...
|
|
func (cs Columns) Lt(query SubQuery) field.Expr {
|
|
if len(cs) == 0 {
|
|
return field.EmptyExpr()
|
|
}
|
|
return field.CompareSubQuery(field.LtOp, cs[0], query.underlyingDB())
|
|
}
|
|
|
|
// Lte ...
|
|
func (cs Columns) Lte(query SubQuery) field.Expr {
|
|
if len(cs) == 0 {
|
|
return field.EmptyExpr()
|
|
}
|
|
return field.CompareSubQuery(field.LteOp, cs[0], query.underlyingDB())
|
|
}
|