package sqlbuilder import ( "context" "database/sql" "errors" "fmt" "strings" "github.com/upper/db/v4" "github.com/upper/db/v4/internal/adapter" "github.com/upper/db/v4/internal/immutable" "github.com/upper/db/v4/internal/sqladapter/exql" ) type selectorQuery struct { table *exql.Columns tableArgs []interface{} distinct bool where *exql.Where whereArgs []interface{} groupBy *exql.GroupBy groupByArgs []interface{} orderBy *exql.OrderBy orderByArgs []interface{} limit exql.Limit offset exql.Offset columns *exql.Columns columnsArgs []interface{} joins []*exql.Join joinsArgs []interface{} amendFn func(string) string } func (sq *selectorQuery) and(b *sqlBuilder, terms ...interface{}) error { where, whereArgs := b.t.toWhereWithArguments(terms) if sq.where == nil { sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} } sq.where.Append(&where) sq.whereArgs = append(sq.whereArgs, whereArgs...) return nil } func (sq *selectorQuery) arguments() []interface{} { return joinArguments( sq.columnsArgs, sq.tableArgs, sq.joinsArgs, sq.whereArgs, sq.groupByArgs, sq.orderByArgs, ) } func (sq *selectorQuery) statement() *exql.Statement { stmt := &exql.Statement{ Type: exql.Select, Table: sq.table, Columns: sq.columns, Distinct: sq.distinct, Limit: sq.limit, Offset: sq.offset, Where: sq.where, OrderBy: sq.orderBy, GroupBy: sq.groupBy, } if len(sq.joins) > 0 { stmt.Joins = exql.JoinConditions(sq.joins...) } stmt.SetAmendment(sq.amendFn) return stmt } func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { fragments, args, err := columnFragments(tables) if err != nil { return err } if sq.joins == nil { sq.joins = []*exql.Join{} } sq.joins = append(sq.joins, &exql.Join{ Type: t, Table: exql.JoinColumns(fragments...), }, ) sq.joinsArgs = append(sq.joinsArgs, args...) return nil } type selector struct { builder *sqlBuilder fn func(*selectorQuery) error prev *selector } var _ = immutable.Immutable(&selector{}) func (sel *selector) SQL() *sqlBuilder { if sel.prev == nil { return sel.builder } return sel.prev.SQL() } func (sel *selector) String() string { s, err := sel.Compile() if err != nil { panic(err.Error()) } return prepareQueryForDisplay(s) } func (sel *selector) frame(fn func(*selectorQuery) error) *selector { return &selector{prev: sel, fn: fn} } func (sel *selector) clone() db.Selector { return sel.frame(func(*selectorQuery) error { return nil }) } func (sel *selector) From(tables ...interface{}) db.Selector { return sel.frame( func(sq *selectorQuery) error { fragments, args, err := columnFragments(tables) if err != nil { return err } sq.table = exql.JoinColumns(fragments...) sq.tableArgs = args return nil }, ) } func (sel *selector) setColumns(columns ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { sq.columns = nil return sq.pushColumns(columns...) }) } func (sel *selector) Columns(columns ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { return sq.pushColumns(columns...) }) } func (sq *selectorQuery) pushColumns(columns ...interface{}) error { f, args, err := columnFragments(columns) if err != nil { return err } c := exql.JoinColumns(f...) if sq.columns != nil { sq.columns.Append(c) } else { sq.columns = c } sq.columnsArgs = append(sq.columnsArgs, args...) return nil } func (sel *selector) Distinct(exps ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { sq.distinct = true return sq.pushColumns(exps...) }) } func (sel *selector) Where(terms ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { if len(terms) == 1 && terms[0] == nil { sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} return nil } return sq.and(sel.SQL(), terms...) }) } func (sel *selector) And(terms ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { return sq.and(sel.SQL(), terms...) }) } func (sel *selector) Amend(fn func(string) string) db.Selector { return sel.frame(func(sq *selectorQuery) error { sq.amendFn = fn return nil }) } func (sel *selector) Arguments() []interface{} { sq, err := sel.build() if err != nil { return nil } return sq.arguments() } func (sel *selector) GroupBy(columns ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { fragments, args, err := columnFragments(columns) if err != nil { return err } if fragments != nil { sq.groupBy = exql.GroupByColumns(fragments...) } sq.groupByArgs = args return nil }) } func (sel *selector) OrderBy(columns ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { if len(columns) == 1 && columns[0] == nil { sq.orderBy = nil sq.orderByArgs = nil return nil } var sortColumns exql.SortColumns for i := range columns { var sort *exql.SortColumn switch value := columns[i].(type) { case *adapter.RawExpr: query, args := Preprocess(value.Raw(), value.Arguments()) sort = &exql.SortColumn{ Column: &exql.Raw{Value: query}, } sq.orderByArgs = append(sq.orderByArgs, args...) case *adapter.FuncExpr: fnName, fnArgs := value.Name(), value.Arguments() if len(fnArgs) == 0 { fnName = fnName + "()" } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } fnName, fnArgs = Preprocess(fnName, fnArgs) sort = &exql.SortColumn{ Column: &exql.Raw{Value: fnName}, } sq.orderByArgs = append(sq.orderByArgs, fnArgs...) case string: if strings.HasPrefix(value, "-") { sort = &exql.SortColumn{ Column: exql.ColumnWithName(value[1:]), Order: exql.Order_Descendent, } } else { chunks := strings.SplitN(value, " ", 2) order := exql.Order_Ascendent if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { order = exql.Order_Descendent } sort = &exql.SortColumn{ Column: exql.ColumnWithName(chunks[0]), Order: order, } } default: return fmt.Errorf("Can't sort by type %T", value) } sortColumns.Columns = append(sortColumns.Columns, sort) } sq.orderBy = &exql.OrderBy{ SortColumns: &sortColumns, } return nil }) } func (sel *selector) Using(columns ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { joins := len(sq.joins) if joins == 0 { return errors.New(`cannot use Using() without a preceding Join() expression`) } lastJoin := sq.joins[joins-1] if lastJoin.On != nil { return errors.New(`cannot use Using() and On() with the same Join() expression`) } fragments, args, err := columnFragments(columns) if err != nil { return err } sq.joinsArgs = append(sq.joinsArgs, args...) lastJoin.Using = exql.UsingColumns(fragments...) return nil }) } func (sel *selector) FullJoin(tables ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { return sq.pushJoin("FULL", tables) }) } func (sel *selector) CrossJoin(tables ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { return sq.pushJoin("CROSS", tables) }) } func (sel *selector) RightJoin(tables ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { return sq.pushJoin("RIGHT", tables) }) } func (sel *selector) LeftJoin(tables ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { return sq.pushJoin("LEFT", tables) }) } func (sel *selector) Join(tables ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { return sq.pushJoin("", tables) }) } func (sel *selector) On(terms ...interface{}) db.Selector { return sel.frame(func(sq *selectorQuery) error { joins := len(sq.joins) if joins == 0 { return errors.New(`cannot use On() without a preceding Join() expression`) } lastJoin := sq.joins[joins-1] if lastJoin.On != nil { return errors.New(`cannot use Using() and On() with the same Join() expression`) } w, a := sel.SQL().t.toWhereWithArguments(terms) o := exql.On(w) lastJoin.On = &o sq.joinsArgs = append(sq.joinsArgs, a...) return nil }) } func (sel *selector) Limit(n int) db.Selector { return sel.frame(func(sq *selectorQuery) error { if n < 0 { n = 0 } sq.limit = exql.Limit(n) return nil }) } func (sel *selector) Offset(n int) db.Selector { return sel.frame(func(sq *selectorQuery) error { if n < 0 { n = 0 } sq.offset = exql.Offset(n) return nil }) } func (sel *selector) template() *exql.Template { return sel.SQL().t.Template } func (sel *selector) As(alias string) db.Selector { return sel.frame(func(sq *selectorQuery) error { if sq.table == nil { return errors.New("Cannot use As() without a preceding From() expression") } last := len(sq.table.Columns) - 1 if raw, ok := sq.table.Columns[last].(*exql.Raw); ok { compiled, err := exql.ColumnWithName(alias).Compile(sel.template()) if err != nil { return err } sq.table.Columns[last] = &exql.Raw{Value: raw.Value + " AS " + compiled} } return nil }) } func (sel *selector) statement() *exql.Statement { sq, _ := sel.build() return sq.statement() } func (sel *selector) QueryRow() (*sql.Row, error) { return sel.QueryRowContext(sel.SQL().sess.Context()) } func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) { sq, err := sel.build() if err != nil { return nil, err } return sel.SQL().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...) } func (sel *selector) Prepare() (*sql.Stmt, error) { return sel.PrepareContext(sel.SQL().sess.Context()) } func (sel *selector) PrepareContext(ctx context.Context) (*sql.Stmt, error) { sq, err := sel.build() if err != nil { return nil, err } return sel.SQL().sess.StatementPrepare(ctx, sq.statement()) } func (sel *selector) Query() (*sql.Rows, error) { return sel.QueryContext(sel.SQL().sess.Context()) } func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) { sq, err := sel.build() if err != nil { return nil, err } return sel.SQL().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) } func (sel *selector) Iterator() db.Iterator { return sel.IteratorContext(sel.SQL().sess.Context()) } func (sel *selector) IteratorContext(ctx context.Context) db.Iterator { sess := sel.SQL().sess sq, err := sel.build() if err != nil { return &iterator{sess, nil, err} } rows, err := sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) return &iterator{sess, rows, err} } func (sel *selector) Paginate(pageSize uint) db.Paginator { return newPaginator(sel.clone(), pageSize) } func (sel *selector) All(destSlice interface{}) error { return sel.Iterator().All(destSlice) } func (sel *selector) One(dest interface{}) error { return sel.Iterator().One(dest) } func (sel *selector) build() (*selectorQuery, error) { sq, err := immutable.FastForward(sel) if err != nil { return nil, err } return sq.(*selectorQuery), nil } func (sel *selector) Compile() (string, error) { return sel.statement().Compile(sel.template()) } func (sel *selector) Prev() immutable.Immutable { if sel == nil { return nil } return sel.prev } func (sel *selector) Fn(in interface{}) error { if sel.fn == nil { return nil } return sel.fn(in.(*selectorQuery)) } func (sel *selector) Base() interface{} { return &selectorQuery{} }