You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-library/vendor/github.com/upper/db/v4/internal/sqladapter/session.go

1107 lines
25 KiB

package sqladapter
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"math"
"reflect"
"strconv"
"sync"
"sync/atomic"
"time"
db "github.com/upper/db/v4"
"github.com/upper/db/v4/internal/cache"
"github.com/upper/db/v4/internal/sqladapter/compat"
"github.com/upper/db/v4/internal/sqladapter/exql"
"github.com/upper/db/v4/internal/sqlbuilder"
)
var (
lastSessID uint64
lastTxID uint64
)
var (
slowQueryThreshold = time.Millisecond * 200
retryTransactionWaitTime = time.Millisecond * 10
retryTransactionMaxWaitTime = time.Second * 1
)
// hasCleanUp is implemented by structs that have a clean up routine that needs
// to be called before Close().
type hasCleanUp interface {
CleanUp() error
}
// statementExecer allows the adapter to have its own exec statement.
type statementExecer interface {
StatementExec(sess Session, ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
// statementCompiler transforms an internal statement into a format
// database/sql can understand.
type statementCompiler interface {
CompileStatement(sess Session, stmt *exql.Statement, args []interface{}) (string, []interface{}, error)
}
// sessValueConverter converts values before being passed to the underlying driver.
type sessValueConverter interface {
ConvertValue(in interface{}) interface{}
}
// sessValueConverterContext converts values before being passed to the underlying driver.
type sessValueConverterContext interface {
ConvertValueContext(ctx context.Context, in interface{}) interface{}
}
// valueConverter converts values before being passed to the underlying driver.
type valueConverter interface {
ConvertValue(in interface{}) interface {
sql.Scanner
driver.Valuer
}
}
// errorConverter converts an error value from the underlying driver into
// something different.
type errorConverter interface {
Err(errIn error) (errOut error)
}
// AdapterSession defines methods to be implemented by SQL adapters.
type AdapterSession interface {
Template() *exql.Template
NewCollection() CollectionAdapter
// Open opens a new connection
OpenDSN(sess Session, dsn string) (*sql.DB, error)
// Collections returns a list of non-system tables from the database.
Collections(sess Session) ([]string, error)
// TableExists returns an error if the given table does not exist.
TableExists(sess Session, name string) error
// LookupName returns the name of the database.
LookupName(sess Session) (string, error)
// PrimaryKeys returns all primary keys on the table.
PrimaryKeys(sess Session, name string) ([]string, error)
}
// Session satisfies db.Session.
type Session interface {
SQL() db.SQL
// PrimaryKeys returns all primary keys on the table.
PrimaryKeys(tableName string) ([]string, error)
// Collections returns a list of references to all collections in the
// database.
Collections() ([]db.Collection, error)
// Name returns the name of the database.
Name() string
// Close closes the database session
Close() error
// Ping checks if the database server is reachable.
Ping() error
// Reset clears all caches the session is using
Reset()
// Collection returns a new collection.
Collection(string) db.Collection
// ConnectionURL returns the ConnectionURL that was used to create the
// Session.
ConnectionURL() db.ConnectionURL
// Open attempts to establish a connection to the database server.
Open() error
// TableExists returns an error if the table doesn't exists.
TableExists(name string) error
// Driver returns the underlying driver the session is using
Driver() interface{}
Save(db.Record) error
Get(db.Record, interface{}) error
Delete(db.Record) error
// WaitForConnection attempts to run the given connection function a fixed
// number of times before failing.
WaitForConnection(func() error) error
// BindDB sets the *sql.DB the session will use.
BindDB(*sql.DB) error
// Session returns the *sql.DB the session is using.
DB() *sql.DB
// BindTx binds a transaction to the current session.
BindTx(context.Context, *sql.Tx) error
// Returns the current transaction the session is using.
Transaction() *sql.Tx
// NewClone clones the database using the given AdapterSession as base.
NewClone(AdapterSession, bool) (Session, error)
// Context returns the default context the session is using.
Context() context.Context
// SetContext sets the default context for the session.
SetContext(context.Context)
NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error)
Tx(fn func(sess db.Session) error) error
TxContext(ctx context.Context, fn func(sess db.Session) error, opts *sql.TxOptions) error
WithContext(context.Context) db.Session
IsTransaction() bool
Commit() error
Rollback() error
db.Settings
}
// NewTx wraps a *sql.Tx and returns a Tx.
func NewTx(adapter AdapterSession, tx *sql.Tx) (Session, error) {
sessTx := &sessionWithContext{
session: &session{
Settings: db.DefaultSettings,
sqlTx: tx,
adapter: adapter,
cachedPKs: cache.NewCache(),
cachedCollections: cache.NewCache(),
cachedStatements: cache.NewCache(),
},
ctx: context.Background(),
}
return sessTx, nil
}
// NewSession creates a new Session.
func NewSession(connURL db.ConnectionURL, adapter AdapterSession) Session {
sess := &sessionWithContext{
session: &session{
Settings: db.DefaultSettings,
connURL: connURL,
adapter: adapter,
cachedPKs: cache.NewCache(),
cachedCollections: cache.NewCache(),
cachedStatements: cache.NewCache(),
},
ctx: context.Background(),
}
return sess
}
type session struct {
db.Settings
adapter AdapterSession
connURL db.ConnectionURL
builder db.SQL
lookupNameOnce sync.Once
name string
mu sync.Mutex // guards ctx, txOptions
txOptions *sql.TxOptions
sqlDBMu sync.Mutex // guards sess, baseTx
sqlDB *sql.DB
sqlTx *sql.Tx
sessID uint64
txID uint64
cacheMu sync.Mutex // guards cachedStatements and cachedCollections
cachedPKs *cache.Cache
cachedStatements *cache.Cache
cachedCollections *cache.Cache
template *exql.Template
}
type sessionWithContext struct {
*session
ctx context.Context
}
func (sess *sessionWithContext) WithContext(ctx context.Context) db.Session {
if ctx == nil {
panic("nil context")
}
newSess := &sessionWithContext{
session: sess.session,
ctx: ctx,
}
return newSess
}
func (sess *sessionWithContext) Tx(fn func(sess db.Session) error) error {
return TxContext(sess.Context(), sess, fn, nil)
}
func (sess *sessionWithContext) TxContext(ctx context.Context, fn func(sess db.Session) error, opts *sql.TxOptions) error {
return TxContext(ctx, sess, fn, opts)
}
func (sess *sessionWithContext) SQL() db.SQL {
return sqlbuilder.WithSession(
sess,
sess.adapter.Template(),
)
}
func (sess *sessionWithContext) Err(errIn error) (errOur error) {
if convertError, ok := sess.adapter.(errorConverter); ok {
return convertError.Err(errIn)
}
return errIn
}
func (sess *sessionWithContext) PrimaryKeys(tableName string) ([]string, error) {
h := cache.NewHashable(hashTypePrimaryKeys, tableName)
cachedPK, ok := sess.cachedPKs.ReadRaw(h)
if ok {
return cachedPK.([]string), nil
}
pk, err := sess.adapter.PrimaryKeys(sess, tableName)
if err != nil {
return nil, err
}
sess.cachedPKs.Write(h, pk)
return pk, nil
}
func (sess *sessionWithContext) TableExists(name string) error {
return sess.adapter.TableExists(sess, name)
}
func (sess *sessionWithContext) NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) {
if ctx == nil {
ctx = context.Background()
}
clone, err := sess.NewClone(sess.adapter, false)
if err != nil {
return nil, err
}
connFn := func() error {
sqlTx, err := compat.BeginTx(clone.DB(), clone.Context(), opts)
if err == nil {
return clone.BindTx(ctx, sqlTx)
}
return err
}
if err := clone.WaitForConnection(connFn); err != nil {
return nil, err
}
return clone, nil
}
func (sess *sessionWithContext) Collections() ([]db.Collection, error) {
names, err := sess.adapter.Collections(sess)
if err != nil {
return nil, err
}
collections := make([]db.Collection, 0, len(names))
for i := range names {
collections = append(collections, sess.Collection(names[i]))
}
return collections, nil
}
func (sess *sessionWithContext) ConnectionURL() db.ConnectionURL {
return sess.connURL
}
func (sess *sessionWithContext) Open() error {
var sqlDB *sql.DB
var err error
connFn := func() error {
sqlDB, err = sess.adapter.OpenDSN(sess, sess.connURL.String())
if err != nil {
return err
}
sqlDB.SetConnMaxLifetime(sess.ConnMaxLifetime())
sqlDB.SetConnMaxIdleTime(sess.ConnMaxIdleTime())
sqlDB.SetMaxIdleConns(sess.MaxIdleConns())
sqlDB.SetMaxOpenConns(sess.MaxOpenConns())
return nil
}
if err := sess.WaitForConnection(connFn); err != nil {
return err
}
return sess.BindDB(sqlDB)
}
func (sess *sessionWithContext) Get(record db.Record, id interface{}) error {
store := record.Store(sess)
if getter, ok := store.(db.StoreGetter); ok {
return getter.Get(record, id)
}
return store.Find(id).One(record)
}
func (sess *sessionWithContext) Save(record db.Record) error {
if record == nil {
return db.ErrNilRecord
}
if reflect.TypeOf(record).Kind() != reflect.Ptr {
return db.ErrExpectingPointerToStruct
}
store := record.Store(sess)
if saver, ok := store.(db.StoreSaver); ok {
return saver.Save(record)
}
id := db.Cond{}
keys, values, err := recordPrimaryKeyFieldValues(store, record)
if err != nil {
return err
}
for i := range values {
if values[i] != reflect.Zero(reflect.TypeOf(values[i])).Interface() {
id[keys[i]] = values[i]
}
}
if len(id) > 0 && len(id) == len(values) {
// check if record exists before updating it
exists, _ := store.Find(id).Count()
if exists > 0 {
return recordUpdate(store, record)
}
}
return recordCreate(store, record)
}
func (sess *sessionWithContext) Delete(record db.Record) error {
if record == nil {
return db.ErrNilRecord
}
if reflect.TypeOf(record).Kind() != reflect.Ptr {
return db.ErrExpectingPointerToStruct
}
store := record.Store(sess)
if hook, ok := record.(db.BeforeDeleteHook); ok {
if err := hook.BeforeDelete(sess); err != nil {
return err
}
}
if deleter, ok := store.(db.StoreDeleter); ok {
if err := deleter.Delete(record); err != nil {
return err
}
} else {
conds, err := recordID(store, record)
if err != nil {
return err
}
if err := store.Find(conds).Delete(); err != nil {
return err
}
}
if hook, ok := record.(db.AfterDeleteHook); ok {
if err := hook.AfterDelete(sess); err != nil {
return err
}
}
return nil
}
func (sess *sessionWithContext) DB() *sql.DB {
return sess.sqlDB
}
func (sess *sessionWithContext) SetContext(ctx context.Context) {
sess.mu.Lock()
sess.ctx = ctx
sess.mu.Unlock()
}
func (sess *sessionWithContext) Context() context.Context {
return sess.ctx
}
func (sess *sessionWithContext) SetTxOptions(txOptions sql.TxOptions) {
sess.mu.Lock()
sess.txOptions = &txOptions
sess.mu.Unlock()
}
func (sess *sessionWithContext) TxOptions() *sql.TxOptions {
sess.mu.Lock()
defer sess.mu.Unlock()
if sess.txOptions == nil {
return nil
}
return sess.txOptions
}
func (sess *sessionWithContext) BindTx(ctx context.Context, tx *sql.Tx) error {
sess.sqlDBMu.Lock()
defer sess.sqlDBMu.Unlock()
sess.sqlTx = tx
sess.SetContext(ctx)
sess.txID = newBaseTxID()
return nil
}
func (sess *sessionWithContext) Commit() error {
if sess.sqlTx != nil {
return sess.sqlTx.Commit()
}
return db.ErrNotWithinTransaction
}
func (sess *sessionWithContext) Rollback() error {
if sess.sqlTx != nil {
return sess.sqlTx.Rollback()
}
return db.ErrNotWithinTransaction
}
func (sess *sessionWithContext) IsTransaction() bool {
return sess.sqlTx != nil
}
func (sess *sessionWithContext) Transaction() *sql.Tx {
return sess.sqlTx
}
func (sess *sessionWithContext) Name() string {
sess.lookupNameOnce.Do(func() {
if sess.name == "" {
sess.name, _ = sess.adapter.LookupName(sess)
}
})
return sess.name
}
func (sess *sessionWithContext) BindDB(sqlDB *sql.DB) error {
sess.sqlDBMu.Lock()
sess.sqlDB = sqlDB
sess.sqlDBMu.Unlock()
if err := sess.Ping(); err != nil {
return err
}
sess.sessID = newSessionID()
name, err := sess.adapter.LookupName(sess)
if err != nil {
return err
}
sess.name = name
return nil
}
func (sess *sessionWithContext) Ping() error {
if sess.sqlDB != nil {
return sess.sqlDB.Ping()
}
return db.ErrNotConnected
}
func (sess *sessionWithContext) SetConnMaxLifetime(t time.Duration) {
sess.Settings.SetConnMaxLifetime(t)
if sessDB := sess.DB(); sessDB != nil {
sessDB.SetConnMaxLifetime(sess.Settings.ConnMaxLifetime())
}
}
func (sess *sessionWithContext) SetConnMaxIdleTime(t time.Duration) {
sess.Settings.SetConnMaxIdleTime(t)
if sessDB := sess.DB(); sessDB != nil {
sessDB.SetConnMaxIdleTime(sess.Settings.ConnMaxIdleTime())
}
}
func (sess *sessionWithContext) SetMaxIdleConns(n int) {
sess.Settings.SetMaxIdleConns(n)
if sessDB := sess.DB(); sessDB != nil {
sessDB.SetMaxIdleConns(sess.Settings.MaxIdleConns())
}
}
func (sess *sessionWithContext) SetMaxOpenConns(n int) {
sess.Settings.SetMaxOpenConns(n)
if sessDB := sess.DB(); sessDB != nil {
sessDB.SetMaxOpenConns(sess.Settings.MaxOpenConns())
}
}
// Reset removes all caches.
func (sess *sessionWithContext) Reset() {
sess.cacheMu.Lock()
defer sess.cacheMu.Unlock()
sess.cachedPKs.Clear()
sess.cachedCollections.Clear()
sess.cachedStatements.Clear()
if sess.template != nil {
sess.template.Cache.Clear()
}
}
func (sess *sessionWithContext) NewClone(adapter AdapterSession, checkConn bool) (Session, error) {
newSess := NewSession(sess.connURL, adapter).(*sessionWithContext)
newSess.name = sess.name
newSess.sqlDB = sess.sqlDB
newSess.cachedPKs = sess.cachedPKs
if checkConn {
if err := newSess.Ping(); err != nil {
// Retry once if ping fails.
return sess.NewClone(adapter, false)
}
}
newSess.sessID = newSessionID()
// New transaction should inherit parent settings
copySettings(sess, newSess)
return newSess, nil
}
func (sess *sessionWithContext) Close() error {
defer func() {
sess.sqlDBMu.Lock()
sess.sqlDB = nil
sess.sqlTx = nil
sess.sqlDBMu.Unlock()
}()
if sess.sqlDB == nil {
return nil
}
sess.cachedCollections.Clear()
sess.cachedStatements.Clear() // Closes prepared statements as well.
if !sess.IsTransaction() {
if cleaner, ok := sess.adapter.(hasCleanUp); ok {
if err := cleaner.CleanUp(); err != nil {
return err
}
}
// Not within a transaction.
return sess.sqlDB.Close()
}
return nil
}
func (sess *sessionWithContext) Collection(name string) db.Collection {
sess.cacheMu.Lock()
defer sess.cacheMu.Unlock()
h := cache.NewHashable(hashTypeCollection, name)
col, ok := sess.cachedCollections.ReadRaw(h)
if !ok {
col = newCollection(name, sess.adapter.NewCollection())
sess.cachedCollections.Write(h, col)
}
return &collectionWithSession{
collection: col.(*collection),
session: sess,
}
}
func queryLog(status *db.QueryStatus) {
diff := status.End.Sub(status.Start)
slowQuery := false
if diff >= slowQueryThreshold {
status.Err = db.ErrWarnSlowQuery
slowQuery = true
}
if status.Err != nil || slowQuery {
db.LC().Warn(status)
return
}
db.LC().Debug(status)
}
func (sess *sessionWithContext) StatementPrepare(ctx context.Context, stmt *exql.Statement) (sqlStmt *sql.Stmt, err error) {
var query string
defer func(start time.Time) {
queryLog(&db.QueryStatus{
TxID: sess.txID,
SessID: sess.sessID,
RawQuery: query,
Err: err,
Start: start,
End: time.Now(),
Context: ctx,
})
}(time.Now())
query, _, err = sess.compileStatement(stmt, nil)
if err != nil {
return nil, err
}
tx := sess.Transaction()
if tx != nil {
sqlStmt, err = compat.PrepareContext(tx, ctx, query)
return
}
sqlStmt, err = compat.PrepareContext(sess.sqlDB, ctx, query)
return
}
func (sess *sessionWithContext) ConvertValue(value interface{}) interface{} {
if scannerValuer, ok := value.(sqlbuilder.ScannerValuer); ok {
return scannerValuer
}
dv := reflect.Indirect(reflect.ValueOf(value))
if dv.IsValid() {
if converter, ok := dv.Interface().(valueConverter); ok {
return converter.ConvertValue(dv.Interface())
}
}
if converter, ok := sess.adapter.(sessValueConverterContext); ok {
return converter.ConvertValueContext(sess.Context(), value)
}
if converter, ok := sess.adapter.(sessValueConverter); ok {
return converter.ConvertValue(value)
}
return value
}
func (sess *sessionWithContext) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) {
var query string
defer func(start time.Time) {
status := db.QueryStatus{
TxID: sess.txID,
SessID: sess.sessID,
RawQuery: query,
Args: args,
Err: err,
Start: start,
End: time.Now(),
Context: ctx,
}
if res != nil {
if rowsAffected, err := res.RowsAffected(); err == nil {
status.RowsAffected = &rowsAffected
}
if lastInsertID, err := res.LastInsertId(); err == nil {
status.LastInsertID = &lastInsertID
}
}
queryLog(&status)
}(time.Now())
if execer, ok := sess.adapter.(statementExecer); ok {
query, args, err = sess.compileStatement(stmt, args)
if err != nil {
return nil, err
}
res, err = execer.StatementExec(sess, ctx, query, args...)
return
}
tx := sess.Transaction()
if sess.Settings.PreparedStatementCacheEnabled() && tx == nil {
var p *Stmt
if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil {
return nil, err
}
defer p.Close()
res, err = compat.PreparedExecContext(p, ctx, args)
return
}
query, args, err = sess.compileStatement(stmt, args)
if err != nil {
return nil, err
}
if tx != nil {
res, err = compat.ExecContext(tx, ctx, query, args)
return
}
res, err = compat.ExecContext(sess.sqlDB, ctx, query, args)
return
}
// StatementQuery compiles and executes a statement that returns rows.
func (sess *sessionWithContext) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) {
var query string
defer func(start time.Time) {
status := db.QueryStatus{
TxID: sess.txID,
SessID: sess.sessID,
RawQuery: query,
Args: args,
Err: err,
Start: start,
End: time.Now(),
Context: ctx,
}
queryLog(&status)
}(time.Now())
tx := sess.Transaction()
if sess.Settings.PreparedStatementCacheEnabled() && tx == nil {
var p *Stmt
if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil {
return nil, err
}
defer p.Close()
rows, err = compat.PreparedQueryContext(p, ctx, args)
return
}
query, args, err = sess.compileStatement(stmt, args)
if err != nil {
return nil, err
}
if tx != nil {
rows, err = compat.QueryContext(tx, ctx, query, args)
return
}
rows, err = compat.QueryContext(sess.sqlDB, ctx, query, args)
return
}
// StatementQueryRow compiles and executes a statement that returns at most one
// row.
func (sess *sessionWithContext) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) {
var query string
defer func(start time.Time) {
status := db.QueryStatus{
TxID: sess.txID,
SessID: sess.sessID,
RawQuery: query,
Args: args,
Err: err,
Start: start,
End: time.Now(),
Context: ctx,
}
queryLog(&status)
}(time.Now())
tx := sess.Transaction()
if sess.Settings.PreparedStatementCacheEnabled() && tx == nil {
var p *Stmt
if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil {
return nil, err
}
defer p.Close()
row = compat.PreparedQueryRowContext(p, ctx, args)
return
}
query, args, err = sess.compileStatement(stmt, args)
if err != nil {
return nil, err
}
if tx != nil {
row = compat.QueryRowContext(tx, ctx, query, args)
return
}
row = compat.QueryRowContext(sess.sqlDB, ctx, query, args)
return
}
// Driver returns the underlying *sql.DB or *sql.Tx instance.
func (sess *sessionWithContext) Driver() interface{} {
if sess.sqlTx != nil {
return sess.sqlTx
}
return sess.sqlDB
}
// compileStatement compiles the given statement into a string.
func (sess *sessionWithContext) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}, error) {
for i := range args {
args[i] = sess.ConvertValue(args[i])
}
if statementCompiler, ok := sess.adapter.(statementCompiler); ok {
return statementCompiler.CompileStatement(sess, stmt, args)
}
compiled, err := stmt.Compile(sess.adapter.Template())
if err != nil {
return "", nil, err
}
query, args := sqlbuilder.Preprocess(compiled, args)
return query, args, nil
}
// prepareStatement compiles a query and tries to use previously generated
// statement.
func (sess *sessionWithContext) prepareStatement(ctx context.Context, stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) {
sess.sqlDBMu.Lock()
defer sess.sqlDBMu.Unlock()
sqlDB, tx := sess.sqlDB, sess.Transaction()
if sqlDB == nil && tx == nil {
return nil, "", nil, db.ErrNotConnected
}
pc, ok := sess.cachedStatements.ReadRaw(stmt)
if ok {
// The statement was cachesess.
ps, err := pc.(*Stmt).Open()
if err == nil {
_, args, err = sess.compileStatement(stmt, args)
if err != nil {
return nil, "", nil, err
}
return ps, ps.query, args, nil
}
}
query, args, err := sess.compileStatement(stmt, args)
if err != nil {
return nil, "", nil, err
}
sqlStmt, err := func(query *string) (*sql.Stmt, error) {
if tx != nil {
return compat.PrepareContext(tx, ctx, *query)
}
return compat.PrepareContext(sess.sqlDB, ctx, *query)
}(&query)
if err != nil {
return nil, "", nil, err
}
p, err := NewStatement(sqlStmt, query).Open()
if err != nil {
return nil, query, args, err
}
sess.cachedStatements.Write(stmt, p)
return p, p.query, args, nil
}
var waitForConnMu sync.Mutex
// WaitForConnection tries to execute the given connectFn function, if
// connectFn returns an error, then WaitForConnection will keep trying until
// connectFn returns nil. Maximum waiting time is 5s after having acquired the
// lock.
func (sess *sessionWithContext) WaitForConnection(connectFn func() error) error {
// This lock ensures first-come, first-served and prevents opening too many
// file descriptors.
waitForConnMu.Lock()
defer waitForConnMu.Unlock()
// Minimum waiting time.
waitTime := time.Millisecond * 10
// Waitig 5 seconds for a successful connection.
for timeStart := time.Now(); time.Since(timeStart) < time.Second*5; {
err := connectFn()
if err == nil {
return nil // Connected!
}
// Only attempt to reconnect if the error is too many clients.
if sess.Err(err) == db.ErrTooManyClients {
// Sleep and try again if, and only if, the server replied with a "too
// many clients" error.
time.Sleep(waitTime)
if waitTime < time.Millisecond*500 {
// Wait a bit more next time.
waitTime = waitTime * 2
}
continue
}
// Return any other error immediately.
return err
}
return db.ErrGivingUpTryingToConnect
}
// ReplaceWithDollarSign turns a SQL statament with '?' placeholders into
// dollar placeholders, like $1, $2, ..., $n
func ReplaceWithDollarSign(buf []byte) []byte {
z := bytes.Count(buf, []byte{'?'})
// the capacity is a quick estimation of the total memory required, this
// reduces reallocations
out := make([]byte, 0, len(buf)+z*3)
var i, k = 0, 1
for i < len(buf) {
if buf[i] == '?' {
out = append(out, buf[:i]...)
buf = buf[i+1:]
i = 0
if len(buf) > 0 && buf[0] == '?' {
out = append(out, '?')
buf = buf[1:]
continue
}
out = append(out, '$')
out = append(out, []byte(strconv.Itoa(k))...)
k = k + 1
continue
}
i = i + 1
}
out = append(out, buf[:len(buf)]...)
buf = nil
return out
}
func copySettings(from Session, into Session) {
into.SetPreparedStatementCache(from.PreparedStatementCacheEnabled())
into.SetConnMaxLifetime(from.ConnMaxLifetime())
into.SetConnMaxIdleTime(from.ConnMaxIdleTime())
into.SetMaxIdleConns(from.MaxIdleConns())
into.SetMaxOpenConns(from.MaxOpenConns())
}
func newSessionID() uint64 {
if atomic.LoadUint64(&lastSessID) == math.MaxUint64 {
atomic.StoreUint64(&lastSessID, 0)
return 0
}
return atomic.AddUint64(&lastSessID, 1)
}
func newBaseTxID() uint64 {
if atomic.LoadUint64(&lastTxID) == math.MaxUint64 {
atomic.StoreUint64(&lastTxID, 0)
return 0
}
return atomic.AddUint64(&lastTxID, 1)
}
// TxContext creates a transaction context and runs fn within it.
func TxContext(ctx context.Context, sess db.Session, fn func(tx db.Session) error, opts *sql.TxOptions) error {
txFn := func(sess db.Session) error {
tx, err := sess.(Session).NewTransaction(ctx, opts)
if err != nil {
return err
}
defer tx.Close()
if err := fn(tx); err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
return fmt.Errorf("%v: %w", rollbackErr, err)
}
return err
}
return tx.Commit()
}
retryTime := retryTransactionWaitTime
var txErr error
for i := 0; i < sess.MaxTransactionRetries(); i++ {
txErr = sess.(*sessionWithContext).Err(txFn(sess))
if txErr == nil {
return nil
}
if errors.Is(txErr, db.ErrTransactionAborted) {
time.Sleep(retryTime)
retryTime = retryTime * 2
if retryTime > retryTransactionMaxWaitTime {
retryTime = retryTransactionMaxWaitTime
}
continue
}
return txErr
}
return fmt.Errorf("db: giving up trying to commit transaction: %w", txErr)
}
var _ = db.Session(&sessionWithContext{})