You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-library/vendor/github.com/upper/db/v4/internal/sqlbuilder/builder.go

633 lines
15 KiB

// Copyright (c) 2012-present The upper.io/db authors. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

// Package sqlbuilder provides tools for building custom SQL queries.
package sqlbuilder
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
db "github.com/upper/db/v4"
"github.com/upper/db/v4/internal/adapter"
"github.com/upper/db/v4/internal/reflectx"
"github.com/upper/db/v4/internal/sqladapter/compat"
"github.com/upper/db/v4/internal/sqladapter/exql"
)
// MapOptions represents options for the mapper.
//
// Both options only affect struct fields tagged with "omitempty"; see Map.
type MapOptions struct {
// IncludeZeroed keeps zero-valued "omitempty" fields in the output of Map
// instead of skipping them.
IncludeZeroed bool
// IncludeNil keeps nil-pointer "omitempty" fields in the output of Map
// instead of skipping them.
IncludeNil bool
}
// defaultMapOptions is the option set Map falls back to when it is called
// with a nil *MapOptions: both inclusion flags are off.
var defaultMapOptions = MapOptions{
IncludeZeroed: false,
IncludeNil: false,
}
// hasPaginator is satisfied by values that can produce a db.Paginator.
// columnFragments uses it to embed a paginated query as a parenthesized
// column expression.
type hasPaginator interface {
Paginator() (db.Paginator, error)
}
// isCompilable is satisfied by values that can render themselves into a SQL
// string (Compile) and expose the arguments that go with it (Arguments).
type isCompilable interface {
Compile() (string, error)
Arguments() []interface{}
}
// hasIsZero is satisfied by values that report their own zero state. Map
// prefers this method over its generic zero-value checks.
type hasIsZero interface {
IsZero() bool
}
// iterator walks over the rows of a query result and remembers the first
// error it ran into.
type iterator struct {
// sess is the session the iterator was created from.
sess exprDB
cursor *sql.Rows // This is the main query cursor. It starts as a nil value.
// err is the sticky error returned by Err; once set, most methods bail out.
err error
}
// fieldValue holds two parallel slices: fields[i] is the column name that
// belongs to values[i]. Its Len, Swap and Less methods let the pairs be
// sorted together by field name.
type fieldValue struct {
fields []string
values []interface{}
}
var (
// sqlPlaceholder is a raw fragment that holds a single `?` placeholder.
sqlPlaceholder = &exql.Raw{Value: `?`}
)
var (
// errDeprecatedJSONBTag is returned by Map when a struct field still carries
// the "jsonb" tag option.
errDeprecatedJSONBTag = errors.New(`Tag "jsonb" is deprecated. See "PostgreSQL: jsonb tag" at https://github.com/upper/db/releases/tag/v3.4.0`)
)
// exprDB is the session contract the builder relies on: it executes,
// prepares and queries exql statements and supplies a default context.
type exprDB interface {
// StatementExec runs stmt with args and returns the driver result.
StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error)
// StatementPrepare turns stmt into a prepared statement.
StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error)
// StatementQuery runs stmt with args and returns the resulting rows.
StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error)
// StatementQueryRow runs stmt with args and returns a single row.
StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error)
// Context returns the context used by the non-Context builder methods.
Context() context.Context
}
// sqlBuilder is the db.SQL implementation of this package.
type sqlBuilder struct {
// sess is the session statements are sent to. It is left nil when the
// builder is created with WithTemplate.
sess exprDB
// t is the SQL template, wrapped with helper utilities.
t *templateWithUtils
}
// WithSession returns a query builder that is bound to the given database session.
//
// sess must either be a *sql.DB or a value that implements the internal
// exprDB interface. A *sql.DB is wrapped in an exprProxy that compiles
// statements with t before handing them to the database handle. Any other
// value that does not implement exprDB makes the function panic.
func WithSession(sess interface{}, t *exql.Template) db.SQL {
	if sqlDB, ok := sess.(*sql.DB); ok {
		// *sql.DB does not implement exprDB by itself; without the proxy the
		// type assertion below would always panic for plain database handles.
		sess = &exprProxy{db: sqlDB, t: t}
	}
	return &sqlBuilder{
		sess: sess.(exprDB), // Let it panic, it will show the developer an informative error.
		t:    newTemplateWithUtils(t),
	}
}
// WithTemplate returns a builder that is based on the given template.
//
// The builder has no session attached, only the wrapped template.
func WithTemplate(t *exql.Template) db.SQL {
	builder := new(sqlBuilder)
	builder.t = newTemplateWithUtils(t)
	return builder
}
// NewIteratorContext wraps an already opened *sql.Rows cursor into a
// db.Iterator bound to the builder's session. The ctx argument is accepted
// but not used by this implementation.
func (b *sqlBuilder) NewIteratorContext(ctx context.Context, rows *sql.Rows) db.Iterator {
	return &iterator{
		sess:   b.sess,
		cursor: rows,
	}
}
// NewIterator wraps rows into a db.Iterator, using the session's default
// context.
func (b *sqlBuilder) NewIterator(rows *sql.Rows) db.Iterator {
	ctx := b.sess.Context()
	return b.NewIteratorContext(ctx, rows)
}
// Iterator runs query with args under the session's default context and
// returns an iterator over the result.
func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) db.Iterator {
	ctx := b.sess.Context()
	return b.IteratorContext(ctx, query, args...)
}
// IteratorContext runs query with args under ctx and returns an iterator
// over the result. A query error is not returned directly; it is stored in
// the iterator and surfaces through its Err method.
func (b *sqlBuilder) IteratorContext(ctx context.Context, query interface{}, args ...interface{}) db.Iterator {
	cursor, queryErr := b.QueryContext(ctx, query, args...)
	return &iterator{
		sess:   b.sess,
		cursor: cursor,
		err:    queryErr,
	}
}
// Prepare creates a prepared statement for query using the session's
// default context.
func (b *sqlBuilder) Prepare(query interface{}) (*sql.Stmt, error) {
	ctx := b.sess.Context()
	return b.PrepareContext(ctx, query)
}
// PrepareContext creates a prepared statement for query under ctx.
//
// query may be an *exql.Statement, a raw SQL string or an *adapter.RawExpr
// (only its raw text is used here). Any other type yields an error.
func (b *sqlBuilder) PrepareContext(ctx context.Context, query interface{}) (*sql.Stmt, error) {
	var stmt *exql.Statement

	switch typed := query.(type) {
	case *adapter.RawExpr:
		// Unwrap the expression and go through the dispatch again.
		return b.PrepareContext(ctx, typed.Raw())
	case string:
		stmt = exql.RawSQL(typed)
	case *exql.Statement:
		stmt = typed
	default:
		return nil, fmt.Errorf("unsupported query type %T", query)
	}

	return b.sess.StatementPrepare(ctx, stmt)
}
// Exec executes query with args using the session's default context.
func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, error) {
	ctx := b.sess.Context()
	return b.ExecContext(ctx, query, args...)
}
// ExecContext executes query under ctx and returns the driver result.
//
// query may be an *exql.Statement, a raw SQL string or an *adapter.RawExpr.
// For a raw expression the expression's own arguments are used and args is
// ignored. Any other type yields an error.
func (b *sqlBuilder) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) {
	var stmt *exql.Statement

	switch typed := query.(type) {
	case *adapter.RawExpr:
		return b.ExecContext(ctx, typed.Raw(), typed.Arguments()...)
	case string:
		stmt = exql.RawSQL(typed)
	case *exql.Statement:
		stmt = typed
	default:
		return nil, fmt.Errorf("unsupported query type %T", query)
	}

	return b.sess.StatementExec(ctx, stmt, args...)
}
// Query runs query with args using the session's default context and
// returns the resulting rows.
func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, error) {
	ctx := b.sess.Context()
	return b.QueryContext(ctx, query, args...)
}
// QueryContext runs query under ctx and returns the resulting rows.
//
// query may be an *exql.Statement, a raw SQL string or an *adapter.RawExpr.
// For a raw expression the expression's own arguments are used and args is
// ignored. Any other type yields an error.
func (b *sqlBuilder) QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) {
	var stmt *exql.Statement

	switch typed := query.(type) {
	case *adapter.RawExpr:
		return b.QueryContext(ctx, typed.Raw(), typed.Arguments()...)
	case string:
		stmt = exql.RawSQL(typed)
	case *exql.Statement:
		stmt = typed
	default:
		return nil, fmt.Errorf("unsupported query type %T", query)
	}

	return b.sess.StatementQuery(ctx, stmt, args...)
}
// QueryRow runs query with args using the session's default context and
// returns a single row.
func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) {
	ctx := b.sess.Context()
	return b.QueryRowContext(ctx, query, args...)
}
// QueryRowContext runs query under ctx and returns a single row.
//
// query may be an *exql.Statement, a raw SQL string or an *adapter.RawExpr.
// For a raw expression the expression's own arguments are used and args is
// ignored. Any other type yields an error.
func (b *sqlBuilder) QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) {
	var stmt *exql.Statement

	switch typed := query.(type) {
	case *adapter.RawExpr:
		return b.QueryRowContext(ctx, typed.Raw(), typed.Arguments()...)
	case string:
		stmt = exql.RawSQL(typed)
	case *exql.Statement:
		stmt = typed
	default:
		return nil, fmt.Errorf("unsupported query type %T", query)
	}

	return b.sess.StatementQueryRow(ctx, stmt, args...)
}
// SelectFrom starts a selector bound to this builder and sets the given
// tables as its source.
func (b *sqlBuilder) SelectFrom(table ...interface{}) db.Selector {
	sel := new(selector)
	sel.builder = b
	return sel.From(table...)
}
// Select starts a selector bound to this builder and sets the given columns
// on it.
func (b *sqlBuilder) Select(columns ...interface{}) db.Selector {
	sel := new(selector)
	sel.builder = b
	return sel.Columns(columns...)
}
// InsertInto starts an inserter bound to this builder that targets table.
func (b *sqlBuilder) InsertInto(table string) db.Inserter {
	ins := new(inserter)
	ins.builder = b
	return ins.Into(table)
}
// DeleteFrom starts a deleter bound to this builder that targets table.
func (b *sqlBuilder) DeleteFrom(table string) db.Deleter {
	del := new(deleter)
	del.builder = b
	return del.setTable(table)
}
// Update starts an updater bound to this builder that targets table.
func (b *sqlBuilder) Update(table string) db.Updater {
	upd := new(updater)
	upd.builder = b
	return upd.setTable(table)
}
// Map receives a pointer to map or struct and maps it to columns and values.
//
// item may be a struct, a map, or a pointer to either one. options may be
// nil, in which case defaultMapOptions is used.
//
// The returned slices are parallel (fields[i] belongs to values[i]) and are
// sorted by field name. A nil item, a typed nil pointer, or a pointer to a
// nil interface yields three nil results. Any other kind of value yields
// ErrExpectingPointerToEitherMapOrStruct.
//
// For structs, a field tagged "jsonb" aborts with errDeprecatedJSONBTag.
// Fields tagged "omitempty" are skipped when they are nil pointers or zero
// values, unless options.IncludeNil or options.IncludeZeroed say otherwise;
// when such a field is kept, its value is replaced by sqlDefault.
func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) {
	var fv fieldValue
	if options == nil {
		options = &defaultMapOptions
	}

	itemV := reflect.ValueOf(item)
	if !itemV.IsValid() {
		return nil, nil, nil
	}

	itemT := itemV.Type()

	if itemT.Kind() == reflect.Ptr {
		// Elem() of a nil pointer is the zero Value and calling Interface()
		// on it panics, so treat a typed nil like an untyped nil item.
		if itemV.IsNil() {
			return nil, nil, nil
		}
		// Single dereference. Just in case the user passes a pointer to struct
		// instead of a struct.
		item = itemV.Elem().Interface()
		itemV = reflect.ValueOf(item)
		if !itemV.IsValid() {
			// The pointer referred to a nil interface value.
			return nil, nil, nil
		}
		itemT = itemV.Type()
	}

	switch itemT.Kind() {
	case reflect.Struct:
		fieldMap := Mapper.TypeMap(itemT).Names
		nfields := len(fieldMap)

		fv.values = make([]interface{}, 0, nfields)
		fv.fields = make([]string, 0, nfields)

		for _, fi := range fieldMap {
			// Check for deprecated JSONB tag
			if _, hasJSONBTag := fi.Options["jsonb"]; hasJSONBTag {
				return nil, nil, errDeprecatedJSONBTag
			}

			// Field options
			_, tagOmitEmpty := fi.Options["omitempty"]

			fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index)
			if fld.Kind() == reflect.Ptr && fld.IsNil() {
				if tagOmitEmpty && !options.IncludeNil {
					continue
				}
				fv.fields = append(fv.fields, fi.Name)
				if tagOmitEmpty {
					fv.values = append(fv.values, sqlDefault)
				} else {
					fv.values = append(fv.values, nil)
				}
				continue
			}

			value := fld.Interface()

			// Zero detection, in order of preference: the value's own IsZero
			// method, emptiness for arrays and slices, and finally a deep
			// comparison against the field's zero value.
			isZero := false
			if t, ok := value.(hasIsZero); ok {
				if t.IsZero() {
					isZero = true
				}
			} else if fld.Kind() == reflect.Array || fld.Kind() == reflect.Slice {
				if fld.Len() == 0 {
					isZero = true
				}
			} else if reflect.DeepEqual(fi.Zero.Interface(), value) {
				isZero = true
			}

			if isZero && tagOmitEmpty && !options.IncludeZeroed {
				continue
			}

			fv.fields = append(fv.fields, fi.Name)
			v, err := marshal(value)
			if err != nil {
				return nil, nil, err
			}
			if isZero && tagOmitEmpty {
				v = sqlDefault
			}
			fv.values = append(fv.values, v)
		}

	case reflect.Map:
		nfields := itemV.Len()
		fv.values = make([]interface{}, nfields)
		fv.fields = make([]string, nfields)
		mkeys := itemV.MapKeys()

		for i, keyV := range mkeys {
			valv := itemV.MapIndex(keyV)
			// Keys of any type are rendered with their default formatting.
			fv.fields[i] = fmt.Sprintf("%v", keyV.Interface())

			v, err := marshal(valv.Interface())
			if err != nil {
				return nil, nil, err
			}

			fv.values[i] = v
		}

	default:
		return nil, nil, ErrExpectingPointerToEitherMapOrStruct
	}

	// Map iteration order is random; sorting makes the output deterministic.
	sort.Sort(&fv)

	return fv.fields, fv.values, nil
}
// columnFragments converts a list of column-like values into exql fragments
// plus the query arguments those fragments carry.
//
// The returned fragment slice has one entry per input column, in the same
// order. Supported inputs are:
//
//   - values with a Paginator method: rendered as a parenthesized subquery;
//   - compilable values: rendered as raw SQL, parenthesized for db.Selector;
//   - *adapter.FuncExpr: rendered as NAME(?, ?, ...) with one placeholder
//     per function argument;
//   - *adapter.RawExpr: rendered as raw SQL;
//   - exql.Fragment: used as is;
//   - string and fmt.Stringer: treated as column names;
//   - anything else: converted with exql.NewRawValue.
//
// The first error encountered aborts the conversion.
func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, error) {
	f := make([]exql.Fragment, len(columns))
	args := []interface{}{}

	for i := range columns {
		switch v := columns[i].(type) {
		case hasPaginator:
			p, err := v.Paginator()
			if err != nil {
				return nil, nil, err
			}

			q, a := Preprocess(p.String(), p.Arguments())

			f[i] = &exql.Raw{Value: "(" + q + ")"}
			args = append(args, a...)
		case isCompilable:
			c, err := v.Compile()
			if err != nil {
				return nil, nil, err
			}
			q, a := Preprocess(c, v.Arguments())
			if _, ok := v.(db.Selector); ok {
				q = "(" + q + ")"
			}
			f[i] = &exql.Raw{Value: q}
			args = append(args, a...)
		case *adapter.FuncExpr:
			fnName, fnArgs := v.Name(), v.Arguments()
			if len(fnArgs) == 0 {
				fnName = fnName + "()"
			} else {
				// One placeholder per argument, comma separated: "(?, ?, ?)".
				fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")"
			}
			fnName, fnArgs = Preprocess(fnName, fnArgs)
			f[i] = &exql.Raw{Value: fnName}
			args = append(args, fnArgs...)
		case *adapter.RawExpr:
			q, a := Preprocess(v.Raw(), v.Arguments())
			f[i] = &exql.Raw{Value: q}
			args = append(args, a...)
		case exql.Fragment:
			f[i] = v
		case string:
			f[i] = exql.ColumnWithName(v)
		case fmt.Stringer:
			f[i] = exql.ColumnWithName(v.String())
		default:
			var err error
			f[i], err = exql.NewRawValue(columns[i])
			if err != nil {
				return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument: %w", v, err)
			}
		}
	}

	return f, args, nil
}
// prepareQueryForDisplay normalizes a query string for printing.
//
// Runs of spaces, tabs, carriage returns and newlines collapse into a single
// space, leading and trailing whitespace is dropped, and every `?` is
// replaced with a numbered placeholder ($1, $2, ...) in order of appearance.
func prepareQueryForDisplay(in string) string {
	buf := make([]byte, 0, len(in))
	nextPlaceholder := 1
	// gapPending records that whitespace was seen after some output; the
	// separating space is only written once another token actually follows,
	// which keeps trailing whitespace out of the result.
	gapPending := false

	for i := 0; i < len(in); i++ {
		c := in[i]

		switch c {
		case ' ', '\r', '\n', '\t':
			// Leading whitespace (nothing written yet) never produces a gap.
			gapPending = len(buf) > 0
			continue
		}

		if gapPending {
			buf = append(buf, ' ')
			gapPending = false
		}

		if c == '?' {
			buf = append(buf, '$')
			buf = append(buf, strconv.Itoa(nextPlaceholder)...)
			nextPlaceholder++
			continue
		}

		buf = append(buf, c)
	}

	return string(buf)
}
// NextScan advances the cursor and scans the new row into dst.
//
// When no row is available it returns the iterator's stored error or, if
// there is none, db.ErrNoMoreRows.
func (iter *iterator) NextScan(dst ...interface{}) error {
	if !iter.Next() {
		if err := iter.Err(); err != nil {
			return err
		}
		return db.ErrNoMoreRows
	}
	return iter.Scan(dst...)
}
// ScanOne scans the next row into dst and always closes the iterator
// afterwards. The error from closing is discarded.
func (iter *iterator) ScanOne(dst ...interface{}) error {
	defer func() {
		_ = iter.Close()
	}()
	return iter.NextScan(dst...)
}
// Scan copies the columns of the current row into dst. It returns the
// iterator's stored error, if any, without touching the cursor.
func (iter *iterator) Scan(dst ...interface{}) error {
	if iter.err != nil {
		return iter.err
	}
	return iter.cursor.Scan(dst...)
}
// setErr stores err as the iterator's error (overwriting any previous one,
// including with nil) and hands the same value back.
func (iter *iterator) setErr(err error) error {
	iter.err = err
	return err
}
// One fetches a single row into dst and closes the iterator. The outcome of
// the fetch, nil included, becomes the iterator's stored error.
func (iter *iterator) One(dst interface{}) error {
	if prev := iter.Err(); prev != nil {
		return prev
	}
	defer func() {
		_ = iter.Close()
	}()
	fetchErr := iter.next(dst)
	return iter.setErr(fetchErr)
}
// All fetches every remaining row into dst and closes the iterator. A fetch
// failure is stored as the iterator's error and returned.
func (iter *iterator) All(dst interface{}) error {
	if prev := iter.Err(); prev != nil {
		return prev
	}
	defer func() {
		_ = iter.Close()
	}()

	// Drain the cursor into dst.
	err := fetchRows(iter, dst)
	if err == nil {
		return nil
	}
	return iter.setErr(err)
}
// Err reports the error stored in the iterator, or nil when there is none.
func (iter *iterator) Err() (err error) {
	err = iter.err
	return err
}
// Next advances the iterator, optionally fetching the row into dst, and
// reports whether a row was available.
//
// It returns false when the iterator already holds an error, when the rows
// are exhausted, or when fetching fails. Only the last case stores a new
// error; running out of rows is not recorded here.
func (iter *iterator) Next(dst ...interface{}) bool {
	if iter.Err() != nil {
		return false
	}

	err := iter.next(dst...)
	switch {
	case err == nil:
		return true
	case errors.Is(err, db.ErrNoMoreRows):
		// Exhausted cursor: just stop, nothing to remember.
	default:
		_ = iter.setErr(err)
	}
	return false
}
// next moves the cursor forward by one row.
//
// With no destination it only advances the cursor. With exactly one
// destination it delegates to fetchRow. More than one destination is
// rejected. When there is no cursor, db.ErrNoMoreRows is stored as the
// iterator's error and returned. Whenever advancing or fetching fails, the
// iterator is closed before the error is returned.
func (iter *iterator) next(dst ...interface{}) error {
	if iter.cursor == nil {
		return iter.setErr(db.ErrNoMoreRows)
	}

	if len(dst) > 1 {
		return errors.New("Next does not currently supports more than one parameters")
	}

	if len(dst) == 1 {
		err := fetchRow(iter, dst[0])
		if err != nil {
			_ = iter.Close()
		}
		return err
	}

	if iter.cursor.Next() {
		return nil
	}

	// Read the cursor error before closing: Close drops the cursor.
	err := iter.cursor.Err()
	_ = iter.Close()
	if err != nil {
		return err
	}
	return db.ErrNoMoreRows
}
// Close closes the underlying cursor, if there is one, and forgets it so
// that later calls are no-ops that return nil.
func (iter *iterator) Close() (err error) {
	if iter.cursor == nil {
		return nil
	}
	err = iter.cursor.Close()
	iter.cursor = nil
	return err
}
// marshal converts v through its MarshalDB method when v implements
// db.Marshaler; any other value is returned untouched.
func marshal(v interface{}) (interface{}, error) {
	m, ok := v.(db.Marshaler)
	if !ok {
		return v, nil
	}

	out, err := m.MarshalDB()
	if err != nil {
		return nil, err
	}
	return out, nil
}
// Len reports the number of field names held by fv.
func (fv *fieldValue) Len() int {
	n := len(fv.fields)
	return n
}
// Swap exchanges the pairs at positions i and j, keeping every field name
// aligned with its value.
func (fv *fieldValue) Swap(i, j int) {
	fieldI, fieldJ := fv.fields[i], fv.fields[j]
	fv.fields[i] = fieldJ
	fv.fields[j] = fieldI

	valueI, valueJ := fv.values[i], fv.values[j]
	fv.values[i] = valueJ
	fv.values[j] = valueI
}
// Less orders pairs by field name, using plain string comparison.
func (fv *fieldValue) Less(i, j int) bool {
	left, right := fv.fields[i], fv.fields[j]
	return left < right
}
// exprProxy implements exprDB on top of a plain *sql.DB: its methods compile
// each statement with the template before handing it to the database handle.
type exprProxy struct {
// db is the database handle compiled statements are sent to.
db *sql.DB
// t is the template used to compile statements into SQL strings.
t *exql.Template
}
// Context returns the default context of the proxy, which is always the
// background context.
func (p *exprProxy) Context() context.Context {
	ctx := context.Background()
	return ctx
}
// StatementExec compiles stmt with the proxy's template and executes the
// resulting SQL with args. A compile error is returned before the database
// is contacted.
func (p *exprProxy) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) {
	compiled, compileErr := stmt.Compile(p.t)
	if compileErr != nil {
		return nil, compileErr
	}
	return compat.ExecContext(p.db, ctx, compiled, args)
}
// StatementPrepare compiles stmt with the proxy's template and prepares the
// resulting SQL. A compile error is returned before the database is
// contacted.
func (p *exprProxy) StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error) {
	compiled, compileErr := stmt.Compile(p.t)
	if compileErr != nil {
		return nil, compileErr
	}
	return compat.PrepareContext(p.db, ctx, compiled)
}
// StatementQuery compiles stmt with the proxy's template and runs the
// resulting SQL with args, returning the rows. A compile error is returned
// before the database is contacted.
func (p *exprProxy) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) {
	compiled, compileErr := stmt.Compile(p.t)
	if compileErr != nil {
		return nil, compileErr
	}
	return compat.QueryContext(p.db, ctx, compiled, args)
}
// StatementQueryRow compiles stmt with the proxy's template and runs the
// resulting SQL with args, returning a single row. A compile error is
// returned before the database is contacted.
func (p *exprProxy) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) {
	compiled, compileErr := stmt.Compile(p.t)
	if compileErr != nil {
		return nil, compileErr
	}
	row := compat.QueryRowContext(p.db, ctx, compiled, args)
	return row, nil
}
// Compile-time checks: the build fails if *sqlBuilder stops satisfying
// db.SQL or *exprProxy stops satisfying exprDB.
var (
_ = db.SQL(&sqlBuilder{})
_ = exprDB(&exprProxy{})
)
// joinArguments concatenates several argument lists into one, preserving
// order. It returns nil when the lists hold no elements at all.
func joinArguments(args ...[]interface{}) []interface{} {
	var size int
	for _, chunk := range args {
		size += len(chunk)
	}
	if size == 0 {
		return nil
	}

	// The exact size is known, so fill a fixed-length slice chunk by chunk.
	joined := make([]interface{}, size)
	pos := 0
	for _, chunk := range args {
		pos += copy(joined[pos:], chunk)
	}
	return joined
}