You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
dorm/vendor/github.com/mailru/go-clickhouse/v2/conn.go

370 lines
8.7 KiB

package clickhouse
import (
	"compress/gzip"
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync/atomic"
	"time"

	"github.com/google/uuid"
)
// key is the unexported type of the context keys defined by this package,
// so values of it cannot collide with context keys from other packages.
type key int

// Context keys read by the driver when it builds an HTTP request.
const (
	// QueryID is the context key whose string value is sent to ClickHouse
	// as the query_id request parameter.
	QueryID key = iota
	// QuotaKey is the context key whose string value is sent to ClickHouse
	// as the quota_key request parameter.
	QuotaKey
)

// Names of the HTTP URL parameters filled from the QueryID and QuotaKey
// context values.
const (
	quotaKeyParamName = "quota_key"
	queryIDParamName  = "query_id"
)
// Sentinel errors used by the connection.
var (
	// errEmptyQueryID is returned by killQuery when the failed request has
	// no query_id parameter identifying the query to kill.
	errEmptyQueryID = errors.New("query id is empty")
)

// defaultKillQueryTimeout bounds the KILL QUERY request sent by killQuery
// when the configured KillQueryTimeout is zero.
var defaultKillQueryTimeout = time.Second
// conn is the driver connection to ClickHouse's HTTP interface. It provides
// Prepare, Close and Begin (driver.Conn) and also serves as the transaction
// object returned by Begin (driver.Tx, via Commit and Rollback).
type conn struct {
	// url is the request URL; newConn removes the userinfo from it.
	url *url.URL
	// user holds the credentials taken out of url; buildRequest sends them
	// via HTTP basic auth.
	user *url.Userinfo
	// location and useDBLocation are passed through to newTextRows.
	location      *time.Location
	useDBLocation bool
	// useGzipCompression makes buildRequest gzip the request body.
	useGzipCompression bool
	// transport sends every request; Close sets it to nil.
	transport *http.Transport
	// cancel cancels the context of the latest request started by
	// doRequest; Close calls it.
	cancel context.CancelFunc
	// txCtx is the context of the active transaction, nil when none.
	txCtx context.Context
	// stmts are the batch-mode statements queued during the transaction;
	// Commit runs them.
	stmts []*stmt
	// logger receives debug output; nil disables logging.
	logger *log.Logger
	// closed is set to 1 atomically by Close.
	closed int32
	// killQueryOnErr enables sending KILL QUERY after a failed query.
	killQueryOnErr bool
	// killQueryTimeout bounds the KILL QUERY request; zero means
	// defaultKillQueryTimeout.
	killQueryTimeout time.Duration
}
// newConn builds a connection from cfg.
//
// The connection URL always requests the TabSeparatedWithNamesAndTypes
// output format. The HTTP transport is configured from cfg's dial timeout,
// idle timeout, read timeout and TLS settings. A debug logger writing to
// stderr is created when cfg.Debug is set.
//
// Credentials are moved out of the URL into c.user because http.Transport
// ignores url.User; buildRequest sends them as basic auth instead.
func newConn(cfg *Config) *conn {
	var logger *log.Logger
	if cfg.Debug {
		logger = log.New(os.Stderr, "clickhouse: ", log.LstdFlags)
	}
	c := &conn{
		url:                cfg.url(map[string]string{"default_format": "TabSeparatedWithNamesAndTypes"}, false),
		location:           cfg.Location,
		useDBLocation:      cfg.UseDBLocation,
		useGzipCompression: cfg.GzipCompression,
		killQueryOnErr:     cfg.KillQueryOnErr,
		killQueryTimeout:   cfg.KillQueryTimeout,
		transport: &http.Transport{
			// net.Dialer.DualStack is deprecated: RFC 6555 fast fallback
			// is enabled by default (FallbackDelay >= 0), so it is not set.
			DialContext: (&net.Dialer{
				Timeout:   cfg.Timeout,
				KeepAlive: cfg.IdleTimeout,
			}).DialContext,
			MaxIdleConns:          1,
			IdleConnTimeout:       cfg.IdleTimeout,
			ResponseHeaderTimeout: cfg.ReadTimeout,
			TLSClientConfig:       getTLSConfigClone(cfg.TLSConfig),
		},
		logger: logger,
	}
	// Store userinfo separately; buildRequest applies it manually.
	c.user = c.url.User
	c.url.User = nil
	c.log("new connection", c.url.Scheme, c.url.Host, c.url.Path)
	return c
}
// log writes msg to the debug logger in log.Println form. It does nothing
// when debug logging is disabled (c.logger is nil).
func (c *conn) log(msg ...interface{}) {
	if l := c.logger; l != nil {
		l.Println(msg...)
	}
}
// Prepare returns a prepared statement, bound to this connection.
// It returns driver.ErrBadConn if the connection has been closed.
func (c *conn) Prepare(query string) (driver.Stmt, error) {
	s, err := c.prepare(query)
	if err != nil {
		// Return a literal nil: converting the nil *stmt to driver.Stmt
		// would produce a non-nil interface value.
		return nil, err
	}
	return s, nil
}
// Close marks the connection as closed. It cancels the in-flight request
// context (if any) and shuts down the transport's idle connections.
// Only the first call does any work; every call returns nil.
func (c *conn) Close() error {
	if !atomic.CompareAndSwapInt32(&c.closed, 0, 1) {
		return nil
	}
	c.log("close connection", c.url.Scheme, c.url.Host, c.url.Path)

	// Detach both resources first so later callers observe a closed conn.
	cancel, transport := c.cancel, c.transport
	c.cancel, c.transport = nil, nil

	if cancel != nil {
		cancel()
	}
	if transport != nil {
		transport.CloseIdleConnections()
	}
	return nil
}
// Begin starts a transaction bound to a background context; the connection
// itself is returned as the driver.Tx.
func (c *conn) Begin() (driver.Tx, error) {
	ctx := context.Background()
	return c.beginTx(ctx)
}
// Commit runs, in order, every statement batched during the transaction,
// using the transaction's context. It stops at the first failing statement
// and returns its error.
//
// The transaction state is cleared before any statement runs, so a failed
// Commit cannot be retried. Commit returns driver.ErrBadConn on a closed
// connection and sql.ErrTxDone when no transaction is active.
func (c *conn) Commit() error {
	if atomic.LoadInt32(&c.closed) != 0 {
		return driver.ErrBadConn
	}
	txCtx := c.txCtx
	if txCtx == nil {
		return sql.ErrTxDone
	}

	batch := c.stmts
	c.txCtx, c.stmts = nil, batch[:0]

	for _, s := range batch {
		c.log("commit statement: ", s.prefix, s.pattern)
		if err := s.commit(txCtx); err != nil {
			return err
		}
	}
	return nil
}
// Rollback ends the current transaction and drops its batched statements
// without running them; the statements themselves are closed by sql.Tx.
//
// It returns driver.ErrBadConn on a closed connection and sql.ErrTxDone when
// no transaction is active.
// NOTE(review): it also returns sql.ErrTxDone when the transaction had no
// batched statements, even though the state is reset. Confirm callers
// expect an error for rolling back an empty transaction.
func (c *conn) Rollback() error {
	switch {
	case atomic.LoadInt32(&c.closed) != 0:
		return driver.ErrBadConn
	case c.txCtx == nil:
		return sql.ErrTxDone
	}

	hadStatements := len(c.stmts) > 0
	c.txCtx = nil
	c.stmts = c.stmts[:0]

	if !hadStatements {
		// Nothing was batched, so there is nothing to roll back.
		return sql.ErrTxDone
	}
	return nil
}
// Exec implements driver.Execer by running query with args under a
// background context.
func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
	bg := context.Background()
	return c.exec(bg, query, args)
}
// Query implements driver.Queryer by running query with args under a
// background context.
func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
	bg := context.Background()
	return c.query(bg, query, args)
}
// beginTx records ctx as the transaction context. Commit uses it to run the
// statements batched during the transaction. The connection itself is
// returned as the driver.Tx; a closed connection yields driver.ErrBadConn.
func (c *conn) beginTx(ctx context.Context) (driver.Tx, error) {
	if closed := atomic.LoadInt32(&c.closed) != 0; closed {
		return nil, driver.ErrBadConn
	}
	c.txCtx = ctx
	return c, nil
}
// queryIDEscaper backslash-escapes the characters (backslash and single
// quote) that cannot appear verbatim inside a ClickHouse single-quoted
// string literal.
var queryIDEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// killQuery asks ClickHouse to stop the query started by req. The query is
// identified by the query_id URL parameter that buildRequest attached to req.
//
// It is a no-op unless killQueryOnErr is set. It returns errEmptyQueryID if
// req carries no query_id. The KILL QUERY request runs on a fresh background
// context bounded by killQueryTimeout (defaultKillQueryTimeout when zero), so
// it is sent even if the original request's context was cancelled.
//
// The second parameter (the failed query's arguments) is intentionally
// ignored. Passing it to buildRequest would interpolate those values into a
// KILL QUERY statement that has no placeholders for them.
func (c *conn) killQuery(req *http.Request, _ []driver.Value) error {
	if !c.killQueryOnErr {
		return nil
	}
	queryID := req.URL.Query().Get(queryIDParamName)
	if queryID == "" {
		return errEmptyQueryID
	}
	// The query ID may come from a caller-supplied context value, so escape
	// it instead of splicing it into the SQL verbatim.
	query := fmt.Sprintf("KILL QUERY WHERE query_id='%s'", queryIDEscaper.Replace(queryID))

	timeout := c.killQueryTimeout
	if timeout == 0 {
		timeout = defaultKillQueryTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	killReq, err := c.buildRequest(ctx, query, nil)
	if err != nil {
		return err
	}
	body, err := c.doRequest(ctx, killReq)
	if err != nil {
		return err
	}
	if body != nil {
		// Drain body to enable connection reuse.
		_, _ = io.Copy(ioutil.Discard, body)
		body.Close()
	}
	return nil
}
// query sends a read query to ClickHouse and returns rows parsed from the
// response body.
//
// If the request fails for a reason other than a ClickHouse error (*Error)
// or driver.ErrBadConn, for example a transport failure or a cancelled
// context, the query may still be running on the server. In that case
// killQuery is invoked; it does nothing unless killQueryOnErr is set. Its
// failure is only logged.
func (c *conn) query(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) {
	if atomic.LoadInt32(&c.closed) != 0 {
		return nil, driver.ErrBadConn
	}
	req, err := c.buildRequest(ctx, query, args)
	if err != nil {
		return nil, err
	}
	body, err := c.doRequest(ctx, req)
	if err != nil {
		if _, ok := err.(*Error); !ok && err != driver.ErrBadConn {
			if killErr := c.killQuery(req, args); killErr != nil {
				c.log("error from killQuery", killErr)
			}
		}
		return nil, err
	}
	rows, err := newTextRows(c, body, c.location, c.useDBLocation)
	if err != nil {
		// Release the response body so its connection is not leaked;
		// a second Close on an HTTP response body is harmless.
		_ = body.Close()
		return nil, err
	}
	return rows, nil
}
// exec sends a statement whose response content is not needed. It always
// returns emptyResult together with the request error (nil on success).
// Any response body is drained and closed.
func (c *conn) exec(ctx context.Context, query string, args []driver.Value) (driver.Result, error) {
	if atomic.LoadInt32(&c.closed) != 0 {
		return nil, driver.ErrBadConn
	}
	req, buildErr := c.buildRequest(ctx, query, args)
	if buildErr != nil {
		return nil, buildErr
	}
	body, sendErr := c.doRequest(ctx, req)
	if body == nil {
		return emptyResult, sendErr
	}
	// Read the body to the end before closing so the keep-alive connection
	// can be reused.
	_, _ = io.Copy(ioutil.Discard, body)
	body.Close()
	return emptyResult, sendErr
}
// doRequest sends req through the connection's transport under a child of
// ctx and returns the response body on HTTP 200. While that body is being
// consumed, the child context's cancel function is kept in c.cancel so that
// Close can abort the request.
//
// Errors:
//   - driver.ErrBadConn when the connection has been closed (transport nil);
//   - a wrapped transport error when the request could not be sent;
//   - on a non-200 status, the error built by newError from the response
//     text, or a wrapped error if that response could not be read.
//
// On every error path the child context is cancelled so it does not outlive
// the call.
func (c *conn) doRequest(ctx context.Context, req *http.Request) (io.ReadCloser, error) {
	ctx, cancel := context.WithCancel(ctx)
	transport := c.transport
	if transport == nil {
		c.cancel = nil
		cancel()
		// The request will never be sent, so release its body. For bodies
		// created by buildRequest this unblocks the goroutine writing it.
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, driver.ErrBadConn
	}
	c.cancel = cancel

	resp, err := transport.RoundTrip(req.WithContext(ctx))
	if err != nil {
		c.cancel = nil
		cancel()
		return nil, fmt.Errorf("doRequest: transport failed to send a request to ClickHouse: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		msg, err := readResponse(resp)
		// Cancel only after the error body has been read.
		c.cancel = nil
		cancel()
		if err != nil {
			return nil, fmt.Errorf("doRequest: failed to read the response with the status code %d: %w", resp.StatusCode, err)
		}
		// A non-200 response means ClickHouse sent an error message in the
		// response body.
		return nil, newError(string(msg))
	}
	return resp.Body, nil
}
// buildRequest turns query into a POST request against c.url. When params
// is non-empty, they are interpolated into query first.
//
// The query text is streamed as the request body through an io.Pipe by a
// background goroutine, gzip-compressed when useGzipCompression is set. The
// credentials in c.user are attached as basic auth. When ctx is non-nil,
// the quota_key and query_id URL parameters are filled from the QuotaKey and
// QueryID context values. If killQueryOnErr is set and ctx supplies no query
// ID, a random UUID is used so the query can later be killed.
func (c *conn) buildRequest(ctx context.Context, query string, params []driver.Value) (*http.Request, error) {
	if len(params) > 0 {
		var err error
		if query, err = interpolateParams(query, params); err != nil {
			return nil, fmt.Errorf("buildRequest: failed to interpolate params: %w", err)
		}
	}

	// The writer goroutine blocks until the body is read or bodyReader is
	// closed, so any path that abandons the request must close bodyReader.
	bodyReader, bodyWriter := io.Pipe()
	go func() {
		var err error
		if c.useGzipCompression {
			gz := gzip.NewWriter(bodyWriter)
			_, err = gz.Write([]byte(query))
			if closeErr := gz.Close(); err == nil {
				err = closeErr
			}
		} else {
			_, err = bodyWriter.Write([]byte(query))
		}
		// CloseWithError(nil) is a plain Close. A non-nil error reaches the
		// reader instead of a silently truncated body.
		_ = bodyWriter.CloseWithError(err)
	}()

	c.log("query: ", query)
	req, err := http.NewRequest(http.MethodPost, c.url.String(), bodyReader)
	if err != nil {
		// Nobody will read the body; unblock the writer goroutine.
		_ = bodyReader.Close()
		return nil, fmt.Errorf("buildRequest: failed to create a request: %w", err)
	}

	// http.Transport ignores url.User argument, handle it here.
	if c.user != nil {
		p, _ := c.user.Password()
		req.SetBasicAuth(c.user.Username(), p)
	}

	var reqQuery url.Values
	if ctx != nil {
		quotaKey, quotaOk := ctx.Value(QuotaKey).(string)
		if quotaOk && quotaKey != "" {
			if reqQuery == nil {
				reqQuery = req.URL.Query()
			}
			reqQuery.Add(quotaKeyParamName, quotaKey)
		}
		queryID, queryOk := ctx.Value(QueryID).(string)
		if c.killQueryOnErr && (!queryOk || queryID == "") {
			queryID = uuid.New().String()
		}
		if queryID != "" {
			if reqQuery == nil {
				reqQuery = req.URL.Query()
			}
			reqQuery.Add(queryIDParamName, queryID)
		}
	}
	if reqQuery != nil {
		req.URL.RawQuery = reqQuery.Encode()
	}
	if c.useGzipCompression {
		req.Header.Set("Content-Encoding", "gzip")
	}
	return req, nil
}
// prepare creates a statement for query bound to c. Inside a transaction
// (c.txCtx set), a statement that newStmt put in batch mode is queued on
// c.stmts for Commit to run. Outside a transaction, batch mode is switched
// off. A closed connection yields driver.ErrBadConn.
func (c *conn) prepare(query string) (*stmt, error) {
	if atomic.LoadInt32(&c.closed) != 0 {
		return nil, driver.ErrBadConn
	}
	c.log("new statement: ", query)

	s := newStmt(query)
	s.c = c

	inTx := c.txCtx != nil
	s.batchMode = s.batchMode && inTx
	if s.batchMode {
		c.stmts = append(c.stmts, s)
	}
	return s, nil
}