You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
go-library/vendor/github.com/upper/db/v4/internal/sqlbuilder/select.go

525 lines
11 KiB

package sqlbuilder
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"github.com/upper/db/v4"
"github.com/upper/db/v4/internal/adapter"
"github.com/upper/db/v4/internal/immutable"
"github.com/upper/db/v4/internal/sqladapter/exql"
)
// selectorQuery holds the state of a SELECT query while a selector chain is
// applied to it. Every SQL fragment is stored next to the slice of arguments
// for its placeholders; arguments() concatenates those slices in a fixed order
// and statement() turns the fragments into an exql.Statement.
type selectorQuery struct {
	// FROM expression set by From(), and its arguments.
	table     *exql.Columns
	tableArgs []interface{}
	// Set to true by Distinct().
	distinct bool
	// WHERE conditions appended by Where()/And() through and(); Where(nil)
	// resets both fields to empty values.
	where     *exql.Where
	whereArgs []interface{}
	// GROUP BY expression set by GroupBy(), and its arguments.
	groupBy     *exql.GroupBy
	groupByArgs []interface{}
	// ORDER BY expression set by OrderBy(), and its arguments.
	orderBy     *exql.OrderBy
	orderByArgs []interface{}
	// Set by Limit()/Offset(); negative inputs are clamped to 0 there.
	limit  exql.Limit
	offset exql.Offset
	// Selected columns appended by pushColumns(), and their arguments.
	columns     *exql.Columns
	columnsArgs []interface{}
	// Joins added by pushJoin(); joinsArgs also receives arguments from the
	// ON/USING clauses attached by On() and Using().
	joins     []*exql.Join
	joinsArgs []interface{}
	// Function set by Amend(); passed to exql.Statement.SetAmendment.
	amendFn func(string) string
}
// and compiles terms into a WHERE fragment using b's template and appends the
// fragment and its arguments to the query's existing conditions. The condition
// list is created on first use. It always returns nil.
func (sq *selectorQuery) and(b *sqlBuilder, terms ...interface{}) error {
	cond, condArgs := b.t.toWhereWithArguments(terms)

	if sq.where == nil {
		sq.where = &exql.Where{}
		sq.whereArgs = []interface{}{}
	}

	sq.where.Append(&cond)
	sq.whereArgs = append(sq.whereArgs, condArgs...)
	return nil
}
// arguments returns all placeholder arguments of the query as one slice.
func (sq *selectorQuery) arguments() []interface{} {
	// Concatenated in SELECT clause order: columns, FROM, JOIN, WHERE,
	// GROUP BY, ORDER BY. Positional placeholders depend on this order.
	return joinArguments(
		sq.columnsArgs, sq.tableArgs, sq.joinsArgs,
		sq.whereArgs, sq.groupByArgs, sq.orderByArgs,
	)
}
// statement converts the accumulated query state into an exql SELECT
// statement. Join conditions are attached only when at least one join exists,
// and the Amend() function (possibly nil) is always handed to the statement.
func (sq *selectorQuery) statement() *exql.Statement {
	stmt := &exql.Statement{Type: exql.Select}

	stmt.Table, stmt.Columns = sq.table, sq.columns
	stmt.Distinct = sq.distinct
	stmt.Limit, stmt.Offset = sq.limit, sq.offset
	stmt.Where, stmt.GroupBy, stmt.OrderBy = sq.where, sq.groupBy, sq.orderBy

	if len(sq.joins) != 0 {
		stmt.Joins = exql.JoinConditions(sq.joins...)
	}

	stmt.SetAmendment(sq.amendFn)
	return stmt
}
// pushJoin appends a join of kind t ("" for a plain JOIN, or e.g. "LEFT",
// "CROSS") over tables, recording the table arguments in joinsArgs. The new
// join has no condition yet; On() and Using() attach one to the last join.
func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error {
	fragments, args, err := columnFragments(tables)
	if err != nil {
		return err
	}

	// append on a nil slice allocates as needed, so no pre-initialisation.
	join := &exql.Join{Type: t, Table: exql.JoinColumns(fragments...)}
	sq.joins = append(sq.joins, join)
	sq.joinsArgs = append(sq.joinsArgs, args...)
	return nil
}
// selector is one node in a chain of query-building steps. Builder methods
// such as Where, OrderBy or Join never modify the receiver: they return a new
// selector (see frame) whose fn mutates a *selectorQuery and whose prev points
// at the selector it was derived from. build() hands the chain to
// immutable.FastForward, which presumably replays each fn onto the value
// returned by Base().
type selector struct {
	// builder is read by SQL() only at the root of the chain (prev == nil);
	// selectors created by frame() leave it unset.
	builder *sqlBuilder
	// fn applies this step to the query under construction; nil means no-op.
	fn func(*selectorQuery) error
	// prev is the selector this one was derived from, or nil at the root.
	prev *selector
}

// Compile-time check that *selector implements immutable.Immutable.
var _ = immutable.Immutable(&selector{})
// SQL returns the sqlBuilder of the root selector of this chain, i.e. the
// first selector reached by following prev links that has no prev.
func (sel *selector) SQL() *sqlBuilder {
	root := sel
	for root.prev != nil {
		root = root.prev
	}
	return root.builder
}
// String compiles the query and returns it formatted for display by
// prepareQueryForDisplay. It panics with the error text if compilation fails.
func (sel *selector) String() string {
	compiled, err := sel.Compile()
	if err == nil {
		return prepareQueryForDisplay(compiled)
	}
	panic(err.Error())
}
// frame returns a new selector, derived from sel, that applies fn after all of
// sel's steps. sel itself is left untouched.
func (sel *selector) frame(fn func(*selectorQuery) error) *selector {
	next := &selector{fn: fn}
	next.prev = sel
	return next
}
// clone returns a new selector with sel's history plus a step that does
// nothing.
func (sel *selector) clone() db.Selector {
	noop := func(*selectorQuery) error { return nil }
	return sel.frame(noop)
}
// From sets the FROM expression to tables, replacing the expression and
// arguments set by any earlier From call in the chain.
func (sel *selector) From(tables ...interface{}) db.Selector {
	setTable := func(sq *selectorQuery) error {
		fragments, args, err := columnFragments(tables)
		if err != nil {
			return err
		}
		sq.table, sq.tableArgs = exql.JoinColumns(fragments...), args
		return nil
	}
	return sel.frame(setTable)
}
// setColumns replaces the selected column list with columns, discarding the
// columns added by earlier steps together with their placeholder arguments.
func (sel *selector) setColumns(columns ...interface{}) db.Selector {
	return sel.frame(func(sq *selectorQuery) error {
		// Reset the arguments along with the columns: pushColumns appends to
		// columnsArgs, so arguments of discarded columns would otherwise stay
		// in front of the new ones and shift every placeholder.
		sq.columns, sq.columnsArgs = nil, nil
		return sq.pushColumns(columns...)
	})
}
// Columns appends columns (and their arguments) to the selected column list.
func (sel *selector) Columns(columns ...interface{}) db.Selector {
	push := func(sq *selectorQuery) error {
		return sq.pushColumns(columns...)
	}
	return sel.frame(push)
}
// pushColumns converts columns to exql fragments and appends them, with their
// arguments, to the query's selected columns. The column list is created on
// first use.
func (sq *selectorQuery) pushColumns(columns ...interface{}) error {
	fragments, args, err := columnFragments(columns)
	if err != nil {
		return err
	}

	joined := exql.JoinColumns(fragments...)
	switch {
	case sq.columns == nil:
		sq.columns = joined
	default:
		sq.columns.Append(joined)
	}

	sq.columnsArgs = append(sq.columnsArgs, args...)
	return nil
}
// Distinct sets the query's distinct flag and appends exps to the selected
// columns.
func (sel *selector) Distinct(exps ...interface{}) db.Selector {
	markDistinct := func(sq *selectorQuery) error {
		sq.distinct = true
		return sq.pushColumns(exps...)
	}
	return sel.frame(markDistinct)
}
// Where appends terms to the query's WHERE conditions (via and()). Calling it
// with exactly one nil term instead resets the conditions to an empty list.
func (sel *selector) Where(terms ...interface{}) db.Selector {
	return sel.frame(func(sq *selectorQuery) error {
		reset := len(terms) == 1 && terms[0] == nil
		if !reset {
			return sq.and(sel.SQL(), terms...)
		}
		sq.where, sq.whereArgs = &exql.Where{}, []interface{}{}
		return nil
	})
}
// And appends terms to the query's existing WHERE conditions.
func (sel *selector) And(terms ...interface{}) db.Selector {
	appendTerms := func(sq *selectorQuery) error {
		return sq.and(sel.SQL(), terms...)
	}
	return sel.frame(appendTerms)
}
// Amend stores fn, replacing any function from an earlier Amend call. The
// function is passed to exql.Statement.SetAmendment when the statement is
// built (presumably to rewrite the compiled SQL text — confirm in exql).
func (sel *selector) Amend(fn func(string) string) db.Selector {
	setAmend := func(sq *selectorQuery) error {
		sq.amendFn = fn
		return nil
	}
	return sel.frame(setAmend)
}
// Arguments returns the placeholder arguments of the built query. It returns
// nil when building the query fails; the build error itself is dropped.
func (sel *selector) Arguments() []interface{} {
	if sq, err := sel.build(); err == nil {
		return sq.arguments()
	}
	return nil
}
// GroupBy sets the GROUP BY columns and replaces the GROUP BY arguments.
//
// NOTE(review): when columnFragments yields nil fragments the previous GROUP BY
// expression is kept while its arguments are still replaced — confirm that
// this is intended.
func (sel *selector) GroupBy(columns ...interface{}) db.Selector {
	return sel.frame(func(sq *selectorQuery) error {
		fragments, args, err := columnFragments(columns)
		if err != nil {
			return err
		}

		sq.groupByArgs = args
		if fragments == nil {
			return nil
		}
		sq.groupBy = exql.GroupByColumns(fragments...)
		return nil
	})
}
// OrderBy sets the ORDER BY clause. It replaces both the clause and its
// arguments from any earlier OrderBy call. Each column may be:
//
//   - a string: a column name. A leading "-" selects descending order.
//     Otherwise the text after the first space, if it equals "DESC"
//     case-insensitively, selects descending order, and anything else
//     sorts ascending.
//   - a *adapter.RawExpr: used as raw SQL after placeholder preprocessing.
//   - a *adapter.FuncExpr: rendered as name() or name(?, ?, ...) with one
//     placeholder per argument, then preprocessed.
//
// Calling it with exactly one nil column clears the ORDER BY clause. Any other
// column type makes the query fail to build.
func (sel *selector) OrderBy(columns ...interface{}) db.Selector {
	return sel.frame(func(sq *selectorQuery) error {
		if len(columns) == 1 && columns[0] == nil {
			sq.orderBy = nil
			sq.orderByArgs = nil
			return nil
		}

		// Arguments are collected locally and assigned at the end: this call
		// replaces any earlier ORDER BY, so its arguments must replace the
		// earlier ones too. Appending to sq.orderByArgs would leave stale
		// arguments behind and misalign the placeholders.
		var sortColumns exql.SortColumns
		var orderByArgs []interface{}

		for _, column := range columns {
			var sort *exql.SortColumn
			switch value := column.(type) {
			case *adapter.RawExpr:
				query, args := Preprocess(value.Raw(), value.Arguments())
				sort = &exql.SortColumn{
					Column: &exql.Raw{Value: query},
				}
				orderByArgs = append(orderByArgs, args...)
			case *adapter.FuncExpr:
				fnName, fnArgs := value.Name(), value.Arguments()
				if len(fnArgs) == 0 {
					fnName += "()"
				} else {
					// One comma-separated placeholder per argument: "(?, ?, ?)".
					fnName += "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")"
				}
				fnName, fnArgs = Preprocess(fnName, fnArgs)
				sort = &exql.SortColumn{
					Column: &exql.Raw{Value: fnName},
				}
				orderByArgs = append(orderByArgs, fnArgs...)
			case string:
				if strings.HasPrefix(value, "-") {
					sort = &exql.SortColumn{
						Column: exql.ColumnWithName(value[1:]),
						Order:  exql.Order_Descendent,
					}
				} else {
					chunks := strings.SplitN(value, " ", 2)
					order := exql.Order_Ascendent
					if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" {
						order = exql.Order_Descendent
					}
					sort = &exql.SortColumn{
						Column: exql.ColumnWithName(chunks[0]),
						Order:  order,
					}
				}
			default:
				return fmt.Errorf("Can't sort by type %T", value)
			}
			sortColumns.Columns = append(sortColumns.Columns, sort)
		}

		sq.orderBy = &exql.OrderBy{
			SortColumns: &sortColumns,
		}
		sq.orderByArgs = orderByArgs
		return nil
	})
}
// Using attaches a USING column list (and its arguments) to the most recent
// join. It fails if no join precedes it or if that join already has an ON
// condition.
func (sel *selector) Using(columns ...interface{}) db.Selector {
	return sel.frame(func(sq *selectorQuery) error {
		if len(sq.joins) == 0 {
			return errors.New(`cannot use Using() without a preceding Join() expression`)
		}

		target := sq.joins[len(sq.joins)-1]
		if target.On != nil {
			return errors.New(`cannot use Using() and On() with the same Join() expression`)
		}

		fragments, args, err := columnFragments(columns)
		if err != nil {
			return err
		}

		sq.joinsArgs = append(sq.joinsArgs, args...)
		target.Using = exql.UsingColumns(fragments...)
		return nil
	})
}
// FullJoin adds a join of type "FULL" over tables. Attach its condition with
// On() or Using().
func (sel *selector) FullJoin(tables ...interface{}) db.Selector {
	join := func(sq *selectorQuery) error { return sq.pushJoin("FULL", tables) }
	return sel.frame(join)
}

// CrossJoin adds a join of type "CROSS" over tables.
func (sel *selector) CrossJoin(tables ...interface{}) db.Selector {
	join := func(sq *selectorQuery) error { return sq.pushJoin("CROSS", tables) }
	return sel.frame(join)
}

// RightJoin adds a join of type "RIGHT" over tables. Attach its condition
// with On() or Using().
func (sel *selector) RightJoin(tables ...interface{}) db.Selector {
	join := func(sq *selectorQuery) error { return sq.pushJoin("RIGHT", tables) }
	return sel.frame(join)
}

// LeftJoin adds a join of type "LEFT" over tables. Attach its condition with
// On() or Using().
func (sel *selector) LeftJoin(tables ...interface{}) db.Selector {
	join := func(sq *selectorQuery) error { return sq.pushJoin("LEFT", tables) }
	return sel.frame(join)
}

// Join adds a join with an empty type (a plain join) over tables. Attach its
// condition with On() or Using().
func (sel *selector) Join(tables ...interface{}) db.Selector {
	join := func(sq *selectorQuery) error { return sq.pushJoin("", tables) }
	return sel.frame(join)
}
// On attaches an ON condition built from terms to the most recent join and
// records the condition's arguments in joinsArgs. It fails if no join precedes
// it, if that join already has a USING clause, or if On() was already applied
// to it.
func (sel *selector) On(terms ...interface{}) db.Selector {
	return sel.frame(func(sq *selectorQuery) error {
		joins := len(sq.joins)
		if joins == 0 {
			return errors.New(`cannot use On() without a preceding Join() expression`)
		}
		lastJoin := sq.joins[joins-1]
		// ON and USING are mutually exclusive for one join. Using() rejects a
		// join that already has ON; this mirrors that check for Using-then-On.
		if lastJoin.Using != nil {
			return errors.New(`cannot use Using() and On() with the same Join() expression`)
		}
		if lastJoin.On != nil {
			return errors.New(`cannot use On() more than once with the same Join() expression`)
		}
		w, a := sel.SQL().t.toWhereWithArguments(terms)
		o := exql.On(w)
		lastJoin.On = &o
		sq.joinsArgs = append(sq.joinsArgs, a...)
		return nil
	})
}
// Limit sets the query's limit to n, replacing any earlier value. Negative
// values are clamped to 0.
func (sel *selector) Limit(n int) db.Selector {
	// Clamp once, up front, instead of mutating the captured n inside the
	// step on every build.
	if n < 0 {
		n = 0
	}
	limit := exql.Limit(n)
	return sel.frame(func(sq *selectorQuery) error {
		sq.limit = limit
		return nil
	})
}
// Offset sets the query's offset to n, replacing any earlier value. Negative
// values are clamped to 0.
func (sel *selector) Offset(n int) db.Selector {
	// Clamp once, up front, instead of mutating the captured n inside the
	// step on every build.
	if n < 0 {
		n = 0
	}
	offset := exql.Offset(n)
	return sel.frame(func(sq *selectorQuery) error {
		sq.offset = offset
		return nil
	})
}
// template returns the exql template of the chain's root sqlBuilder.
func (sel *selector) template() *exql.Template {
	b := sel.SQL()
	return b.t.Template
}
// As aliases the last FROM expression. The alias is compiled as a column name
// with the builder's template and appended as "<expr> AS <alias>". This applies
// only when the last FROM expression is an *exql.Raw; any other fragment is
// left unchanged. It fails when no FROM expression has been set.
func (sel *selector) As(alias string) db.Selector {
	return sel.frame(func(sq *selectorQuery) error {
		// Also reject an empty table list (e.g. From() called with no
		// arguments): indexing its last element would panic.
		if sq.table == nil || len(sq.table.Columns) == 0 {
			return errors.New("Cannot use As() without a preceding From() expression")
		}
		last := len(sq.table.Columns) - 1
		raw, ok := sq.table.Columns[last].(*exql.Raw)
		if !ok {
			// Only raw expressions are aliased; other fragments are kept.
			return nil
		}
		compiled, err := exql.ColumnWithName(alias).Compile(sel.template())
		if err != nil {
			return err
		}
		sq.table.Columns[last] = &exql.Raw{Value: raw.Value + " AS " + compiled}
		return nil
	})
}
// statement builds the query and returns it as an exql statement.
//
// It has no error result, so a build failure panics. It panics with the build
// error's message, matching String(); previously a nil *selectorQuery was
// dereferenced. Callers that can report errors should call build() directly.
func (sel *selector) statement() *exql.Statement {
	sq, err := sel.build()
	if err != nil {
		panic(err.Error())
	}
	return sq.statement()
}
// QueryRow is QueryRowContext using the session's context.
func (sel *selector) QueryRow() (*sql.Row, error) {
	ctx := sel.SQL().sess.Context()
	return sel.QueryRowContext(ctx)
}

// QueryRowContext builds the query and runs it through the session's
// StatementQueryRow with ctx and the query's arguments. A build error is
// returned without calling the session.
func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) {
	sq, err := sel.build()
	if err != nil {
		return nil, err
	}
	stmt, args := sq.statement(), sq.arguments()
	return sel.SQL().sess.StatementQueryRow(ctx, stmt, args...)
}
// Prepare is PrepareContext using the session's context.
func (sel *selector) Prepare() (*sql.Stmt, error) {
	ctx := sel.SQL().sess.Context()
	return sel.PrepareContext(ctx)
}

// PrepareContext builds the query and prepares its statement through the
// session with ctx. Only the statement is passed; the query's arguments are
// not. A build error is returned without calling the session.
func (sel *selector) PrepareContext(ctx context.Context) (*sql.Stmt, error) {
	sq, err := sel.build()
	if err != nil {
		return nil, err
	}
	sess := sel.SQL().sess
	return sess.StatementPrepare(ctx, sq.statement())
}
// Query is QueryContext using the session's context.
func (sel *selector) Query() (*sql.Rows, error) {
	ctx := sel.SQL().sess.Context()
	return sel.QueryContext(ctx)
}

// QueryContext builds the query and runs it through the session's
// StatementQuery with ctx and the query's arguments. A build error is returned
// without calling the session.
func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) {
	sq, err := sel.build()
	if err != nil {
		return nil, err
	}
	stmt, args := sq.statement(), sq.arguments()
	return sel.SQL().sess.StatementQuery(ctx, stmt, args...)
}
// Iterator is IteratorContext using the session's context.
func (sel *selector) Iterator() db.Iterator {
	ctx := sel.SQL().sess.Context()
	return sel.IteratorContext(ctx)
}

// IteratorContext builds and runs the query with ctx and wraps the result in
// an iterator. Errors are not returned directly: a build error or a query error
// is stored in the returned iterator (with nil rows on a build error).
func (sel *selector) IteratorContext(ctx context.Context) db.Iterator {
	sess := sel.SQL().sess
	sq, buildErr := sel.build()
	if buildErr != nil {
		return &iterator{sess, nil, buildErr}
	}
	rows, queryErr := sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
	return &iterator{sess, rows, queryErr}
}
// Paginate returns a paginator built from a clone of this selector with the
// given page size.
func (sel *selector) Paginate(pageSize uint) db.Paginator {
	base := sel.clone()
	return newPaginator(base, pageSize)
}
// All runs the query and delegates to the iterator's All with destSlice.
func (sel *selector) All(destSlice interface{}) error {
	it := sel.Iterator()
	return it.All(destSlice)
}

// One runs the query and delegates to the iterator's One with dest.
func (sel *selector) One(dest interface{}) error {
	it := sel.Iterator()
	return it.One(dest)
}
// build applies the selector chain with immutable.FastForward and returns the
// resulting *selectorQuery, or nil and the error FastForward reported.
func (sel *selector) build() (*selectorQuery, error) {
	result, err := immutable.FastForward(sel)
	if err != nil {
		return nil, err
	}
	sq := result.(*selectorQuery)
	return sq, nil
}
// Compile builds the query and renders it as SQL text with the builder's
// template. An error from any step of the chain, such as an unsupported
// OrderBy argument, is returned. Previously the build error was discarded and
// a nil *selectorQuery was dereferenced, which panicked.
func (sel *selector) Compile() (string, error) {
	sq, err := sel.build()
	if err != nil {
		return "", err
	}
	return sq.statement().Compile(sel.template())
}
// Prev implements immutable.Immutable by returning the selector this one was
// derived from. On a nil *selector it returns an untyped nil.
//
// For a root selector the result is a nil *selector wrapped in a non-nil
// interface. Calling Prev on that value hits the nil-receiver branch.
func (sel *selector) Prev() immutable.Immutable {
	if sel != nil {
		return sel.prev
	}
	return nil
}
// Fn implements immutable.Immutable by applying this selector's step to in,
// which must be a *selectorQuery. A selector without a step returns nil and
// does not inspect in.
func (sel *selector) Fn(in interface{}) error {
	if step := sel.fn; step != nil {
		return step(in.(*selectorQuery))
	}
	return nil
}
// Base implements immutable.Immutable by returning a new, empty
// *selectorQuery as the starting state.
func (sel *selector) Base() interface{} {
	var sq selectorQuery
	return &sq
}