You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1101 lines
22 KiB
1101 lines
22 KiB
2 years ago
|
package pgdriver
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"context"
|
||
|
"crypto/md5"
|
||
|
"crypto/tls"
|
||
|
"database/sql"
|
||
|
"database/sql/driver"
|
||
|
"encoding/binary"
|
||
|
"encoding/hex"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"math"
|
||
|
"strconv"
|
||
|
"sync"
|
||
|
"time"
|
||
|
"unicode/utf8"
|
||
|
|
||
|
"mellium.im/sasl"
|
||
|
)
|
||
|
|
||
|
// https://www.postgresql.org/docs/current/protocol-message-formats.html
|
||
|
//nolint:deadcode,varcheck,unused
|
||
|
const (
	// Authentication request codes. auth reads one of these as an int32
	// right after an authenticationOKMsg message type.
	authenticationOK                = 0
	authenticationCleartextPassword = 3
	authenticationMD5Password       = 5
	authenticationSASL              = 10

	// Authentication exchange. All server-side authentication messages
	// share the 'R' type byte and all client replies share 'p'.
	authenticationOKMsg           = 'R'
	authenticationSASLContinueMsg = 'R'
	authenticationSASLFinalMsg    = 'R'
	passwordMessageMsg            = 'p'
	saslInitialResponseMsg        = 'p'
	saslResponseMsg               = 'p'

	// Session-level messages.
	backendKeyDataMsg       = 'K'
	parameterStatusMsg      = 'S'
	noticeResponseMsg       = 'N'
	errorResponseMsg        = 'E'
	notificationResponseMsg = 'A'
	readyForQueryMsg        = 'Z'
	terminateMsg            = 'X'

	// Simple query and result messages.
	queryMsg              = 'Q'
	emptyQueryResponseMsg = 'I'
	commandCompleteMsg    = 'C'
	rowDescriptionMsg     = 'T'
	dataRowMsg            = 'D'
	noDataMsg             = 'n'

	// Extended query messages. Some type bytes are reused between the
	// frontend and backend directions (for example 'D', 'E', 'S', 'C', 'H').
	parseMsg                = 'P'
	parseCompleteMsg        = '1'
	bindMsg                 = 'B'
	bindCompleteMsg         = '2'
	describeMsg             = 'D'
	parameterDescriptionMsg = 't'
	executeMsg              = 'E'
	syncMsg                 = 'S'
	flushMsg                = 'H'
	closeMsg                = 'C'
	closeCompleteMsg        = '3'

	// COPY messages.
	copyInResponseMsg  = 'G'
	copyOutResponseMsg = 'H'
	copyDataMsg        = 'd'
	copyDoneMsg        = 'c'
)
|
||
|
|
||
|
// errEmptyQuery is recorded by the read loops in this file when the server
// answers with an emptyQueryResponseMsg.
var errEmptyQuery = errors.New("pgdriver: query is empty")
|
||
|
|
||
|
// reader is a buffered reader with a small reusable scratch buffer for
// short, fixed-size reads.
type reader struct {
	*bufio.Reader
	// buf is the scratch space handed out by ReadTemp for reads that fit.
	buf []byte
}

// newReader wraps r in a bufio.Reader and allocates a 128-byte scratch buffer.
func newReader(r io.Reader) *reader {
	return &reader{
		Reader: bufio.NewReader(r),
		buf:    make([]byte, 128),
	}
}

// ReadTemp reads exactly n bytes and returns them.
//
// When n fits in the scratch buffer the returned slice aliases it and is
// only valid until the next ReadTemp call; larger reads get a fresh slice.
// A negative n (for example a length computed from a malformed message)
// yields an error instead of a slice-bounds panic. Errors from io.ReadFull
// are returned together with the partially filled slice.
func (r *reader) ReadTemp(n int) ([]byte, error) {
	if n < 0 {
		return nil, fmt.Errorf("pgdriver: invalid read length: %d", n)
	}

	var b []byte
	if n <= len(r.buf) {
		b = r.buf[:n]
	} else {
		b = make([]byte, n)
	}

	_, err := io.ReadFull(r.Reader, b)
	return b, err
}

// Discard skips the next n bytes without allocating a buffer for them.
//
// The error contract mirrors io.ReadFull: io.EOF when nothing could be
// skipped and io.ErrUnexpectedEOF when the stream ended part-way. A negative
// n returns bufio.ErrNegativeCount.
func (r *reader) Discard(n int) error {
	discarded, err := r.Reader.Discard(n)
	if err == io.EOF && discarded > 0 {
		// bufio reports a plain EOF for a short skip; keep the error that
		// callers saw when Discard was implemented on top of io.ReadFull.
		err = io.ErrUnexpectedEOF
	}
	return err
}
|
||
|
|
||
|
func enableSSL(ctx context.Context, cn *Conn, tlsConf *tls.Config) error {
|
||
|
if err := writeSSLMsg(ctx, cn); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
|
||
|
c, err := rd.ReadByte()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if c != 'S' {
|
||
|
return errors.New("pgdriver: SSL is not enabled on the server")
|
||
|
}
|
||
|
|
||
|
tlsCN := tls.Client(cn.netConn, tlsConf)
|
||
|
if err := tlsCN.HandshakeContext(ctx); err != nil {
|
||
|
return fmt.Errorf("pgdriver: TLS handshake failed: %w", err)
|
||
|
}
|
||
|
cn.netConn = tlsCN
|
||
|
rd.Reset(cn.netConn)
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func writeSSLMsg(ctx context.Context, cn *Conn) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(0)
|
||
|
wb.WriteInt32(80877103)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func startup(ctx context.Context, cn *Conn) error {
|
||
|
if err := writeStartup(ctx, cn); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case backendKeyDataMsg:
|
||
|
processID, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
secretKey, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
cn.processID = processID
|
||
|
cn.secretKey = secretKey
|
||
|
case authenticationOKMsg:
|
||
|
if err := auth(ctx, cn, rd); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
case readyForQueryMsg:
|
||
|
return rd.Discard(msgLen)
|
||
|
case parameterStatusMsg, noticeResponseMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return e
|
||
|
default:
|
||
|
return fmt.Errorf("pgdriver: unexpected startup message: %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func writeStartup(ctx context.Context, cn *Conn) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(0)
|
||
|
wb.WriteInt32(196608)
|
||
|
wb.WriteString("user")
|
||
|
wb.WriteString(cn.cfg.User)
|
||
|
wb.WriteString("database")
|
||
|
wb.WriteString(cn.cfg.Database)
|
||
|
if cn.cfg.AppName != "" {
|
||
|
wb.WriteString("application_name")
|
||
|
wb.WriteString(cn.cfg.AppName)
|
||
|
}
|
||
|
wb.WriteString("")
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func auth(ctx context.Context, cn *Conn, rd *reader) error {
|
||
|
num, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch num {
|
||
|
case authenticationOK:
|
||
|
return nil
|
||
|
case authenticationCleartextPassword:
|
||
|
return authCleartext(ctx, cn, rd)
|
||
|
case authenticationMD5Password:
|
||
|
return authMD5(ctx, cn, rd)
|
||
|
case authenticationSASL:
|
||
|
if err := authSASL(ctx, cn, rd); err != nil {
|
||
|
return fmt.Errorf("pgdriver: SASL: %w", err)
|
||
|
}
|
||
|
return nil
|
||
|
default:
|
||
|
return fmt.Errorf("pgdriver: unknown authentication message: %q", num)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func authCleartext(ctx context.Context, cn *Conn, rd *reader) error {
|
||
|
if err := writePassword(ctx, cn, cn.cfg.Password); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return readAuthOK(cn, rd)
|
||
|
}
|
||
|
|
||
|
func readAuthOK(cn *Conn, rd *reader) error {
|
||
|
c, _, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case authenticationOKMsg:
|
||
|
num, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if num != 0 {
|
||
|
return fmt.Errorf("pgdriver: unexpected authentication code: %q", num)
|
||
|
}
|
||
|
return nil
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return e
|
||
|
default:
|
||
|
return fmt.Errorf("pgdriver: unknown password message: %q", c)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func authMD5(ctx context.Context, cn *Conn, rd *reader) error {
|
||
|
b, err := rd.ReadTemp(4)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
secret := "md5" + md5s(md5s(cn.cfg.Password+cn.cfg.User)+string(b))
|
||
|
if err := writePassword(ctx, cn, secret); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return readAuthOK(cn, rd)
|
||
|
}
|
||
|
|
||
|
func writePassword(ctx context.Context, cn *Conn, password string) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(passwordMessageMsg)
|
||
|
wb.WriteString(password)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
// md5s returns the lowercase hexadecimal MD5 digest of s.
func md5s(s string) string {
	digest := md5.Sum([]byte(s))
	out := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(out, digest[:])
	return string(out)
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func authSASL(ctx context.Context, cn *Conn, rd *reader) error {
|
||
|
var saslMech sasl.Mechanism
|
||
|
|
||
|
loop:
|
||
|
for {
|
||
|
s, err := readString(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch s {
|
||
|
case "":
|
||
|
break loop
|
||
|
case sasl.ScramSha256.Name:
|
||
|
saslMech = sasl.ScramSha256
|
||
|
case sasl.ScramSha256Plus.Name:
|
||
|
// ignore
|
||
|
default:
|
||
|
return fmt.Errorf("got %q, wanted %q", s, sasl.ScramSha256.Name)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
creds := sasl.Credentials(func() (Username, Password, Identity []byte) {
|
||
|
return []byte(cn.cfg.User), []byte(cn.cfg.Password), nil
|
||
|
})
|
||
|
client := sasl.NewClient(saslMech, creds)
|
||
|
|
||
|
_, resp, err := client.Step(nil)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("client.Step 1 failed: %w", err)
|
||
|
}
|
||
|
|
||
|
if err := saslWriteInitialResponse(ctx, cn, saslMech, resp); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case authenticationSASLContinueMsg:
|
||
|
c11, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if c11 != 11 {
|
||
|
return fmt.Errorf("got %q, wanted %q", c, 11)
|
||
|
}
|
||
|
|
||
|
b, err := rd.ReadTemp(msgLen - 4)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
_, resp, err = client.Step(b)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("client.Step 2 failed: %w", err)
|
||
|
}
|
||
|
|
||
|
if err := saslWriteResponse(ctx, cn, resp); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
resp, err = saslReadAuthFinal(cn, rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if _, _, err := client.Step(resp); err != nil {
|
||
|
return fmt.Errorf("client.Step 3 failed: %w", err)
|
||
|
}
|
||
|
|
||
|
if client.State() != sasl.ValidServerResponse {
|
||
|
return fmt.Errorf("got state=%q, wanted %q", client.State(), sasl.ValidServerResponse)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return e
|
||
|
default:
|
||
|
return fmt.Errorf("got %q, wanted %q", c, authenticationSASLContinueMsg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func saslWriteInitialResponse(
|
||
|
ctx context.Context, cn *Conn, saslMech sasl.Mechanism, resp []byte,
|
||
|
) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(saslInitialResponseMsg)
|
||
|
wb.WriteString(saslMech.Name)
|
||
|
wb.WriteInt32(int32(len(resp)))
|
||
|
if _, err := wb.Write(resp); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
func saslWriteResponse(ctx context.Context, cn *Conn, resp []byte) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(saslResponseMsg)
|
||
|
if _, err := wb.Write(resp); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
func saslReadAuthFinal(cn *Conn, rd *reader) ([]byte, error) {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case authenticationSASLFinalMsg:
|
||
|
c12, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if c12 != 12 {
|
||
|
return nil, fmt.Errorf("got %q, wanted %q", c, 12)
|
||
|
}
|
||
|
|
||
|
resp := make([]byte, msgLen-4)
|
||
|
if _, err := io.ReadFull(rd, resp); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if err := readAuthOK(cn, rd); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return resp, nil
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return nil, e
|
||
|
default:
|
||
|
return nil, fmt.Errorf("got %q, wanted %q", c, authenticationSASLFinalMsg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func writeQuery(ctx context.Context, cn *Conn, query string) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(queryMsg)
|
||
|
wb.WriteString(query)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
func readQuery(ctx context.Context, cn *Conn) (sql.Result, error) {
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
|
||
|
var res driver.Result
|
||
|
var firstErr error
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr == nil {
|
||
|
firstErr = e
|
||
|
}
|
||
|
case emptyQueryResponseMsg:
|
||
|
if firstErr == nil {
|
||
|
firstErr = errEmptyQuery
|
||
|
}
|
||
|
case commandCompleteMsg:
|
||
|
tmp, err := rd.ReadTemp(msgLen)
|
||
|
if err != nil {
|
||
|
firstErr = err
|
||
|
break
|
||
|
}
|
||
|
|
||
|
r, err := parseResult(tmp)
|
||
|
if err != nil {
|
||
|
firstErr = err
|
||
|
} else {
|
||
|
res = r
|
||
|
}
|
||
|
case describeMsg,
|
||
|
rowDescriptionMsg,
|
||
|
noticeResponseMsg,
|
||
|
parameterStatusMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
case readyForQueryMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return res, firstErr
|
||
|
default:
|
||
|
return nil, fmt.Errorf("pgdriver: Exec: unexpected message %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func readQueryData(ctx context.Context, cn *Conn) (*rows, error) {
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
var firstErr error
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case rowDescriptionMsg:
|
||
|
rowDesc, err := readRowDescription(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return newRows(cn, rowDesc, true), nil
|
||
|
case commandCompleteMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
case readyForQueryMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr != nil {
|
||
|
return nil, firstErr
|
||
|
}
|
||
|
return &rows{closed: true}, nil
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr == nil {
|
||
|
firstErr = e
|
||
|
}
|
||
|
case emptyQueryResponseMsg:
|
||
|
if firstErr == nil {
|
||
|
firstErr = errEmptyQuery
|
||
|
}
|
||
|
case noticeResponseMsg, parameterStatusMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("pgdriver: newRows: unexpected message %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
// rowDescPool is the pool readRowDescription draws *rowDescription values
// from; no New function is set here, so Get may return nil.
var rowDescPool sync.Pool
|
||
|
|
||
|
// rowDescription holds the column names and data types read from a row
// description message.
type rowDescription struct {
	// buf is the byte storage that addName copies column names into.
	buf []byte
	// names has one entry per column, built from slices of buf.
	names []string
	// types has one data type value per column, in column order.
	types []int32
	// numInput is -1 after construction or reset; readParseDescribeSync
	// sets it to the parameter count reported by the server.
	numInput int16
}
|
||
|
|
||
|
func newRowDescription(numCol int) *rowDescription {
|
||
|
if numCol < 16 {
|
||
|
numCol = 16
|
||
|
}
|
||
|
return &rowDescription{
|
||
|
buf: make([]byte, 0, 16*numCol),
|
||
|
names: make([]string, 0, numCol),
|
||
|
types: make([]int32, 0, numCol),
|
||
|
numInput: -1,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (d *rowDescription) reset(numCol int) {
|
||
|
d.buf = make([]byte, 0, 16*numCol)
|
||
|
d.names = d.names[:0]
|
||
|
d.types = d.types[:0]
|
||
|
d.numInput = -1
|
||
|
}
|
||
|
|
||
|
func (d *rowDescription) addName(name []byte) {
|
||
|
if len(d.buf)+len(name) > cap(d.buf) {
|
||
|
d.buf = make([]byte, 0, cap(d.buf))
|
||
|
}
|
||
|
|
||
|
i := len(d.buf)
|
||
|
d.buf = append(d.buf, name...)
|
||
|
d.names = append(d.names, bytesToString(d.buf[i:]))
|
||
|
}
|
||
|
|
||
|
// addType appends dataType to d.types.
func (d *rowDescription) addType(dataType int32) {
	d.types = append(d.types, dataType)
}
|
||
|
|
||
|
func readRowDescription(rd *reader) (*rowDescription, error) {
|
||
|
numCol, err := readInt16(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
rowDesc, ok := rowDescPool.Get().(*rowDescription)
|
||
|
if !ok {
|
||
|
rowDesc = newRowDescription(int(numCol))
|
||
|
} else {
|
||
|
rowDesc.reset(int(numCol))
|
||
|
}
|
||
|
|
||
|
for i := 0; i < int(numCol); i++ {
|
||
|
name, err := rd.ReadSlice(0)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
rowDesc.addName(name[:len(name)-1])
|
||
|
|
||
|
if _, err := rd.ReadTemp(6); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
dataType, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
rowDesc.addType(dataType)
|
||
|
|
||
|
if _, err := rd.ReadTemp(8); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return rowDesc, nil
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func readNotification(ctx context.Context, rd *reader) (channel, payload string, err error) {
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return "", "", err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case commandCompleteMsg, readyForQueryMsg, noticeResponseMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return "", "", err
|
||
|
}
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return "", "", err
|
||
|
}
|
||
|
return "", "", e
|
||
|
case notificationResponseMsg:
|
||
|
if err := rd.Discard(4); err != nil {
|
||
|
return "", "", err
|
||
|
}
|
||
|
channel, err = readString(rd)
|
||
|
if err != nil {
|
||
|
return "", "", err
|
||
|
}
|
||
|
payload, err = readString(rd)
|
||
|
if err != nil {
|
||
|
return "", "", err
|
||
|
}
|
||
|
return channel, payload, nil
|
||
|
default:
|
||
|
return "", "", fmt.Errorf("pgdriver: readNotification: unexpected message %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func writeParseDescribeSync(ctx context.Context, cn *Conn, name, query string) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(parseMsg)
|
||
|
wb.WriteString(name)
|
||
|
wb.WriteString(query)
|
||
|
wb.WriteInt16(0)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
wb.StartMessage(describeMsg)
|
||
|
wb.WriteByte('S')
|
||
|
wb.WriteString(name)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
wb.StartMessage(syncMsg)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
func readParseDescribeSync(ctx context.Context, cn *Conn) (*rowDescription, error) {
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
var numParam int16
|
||
|
var rowDesc *rowDescription
|
||
|
var firstErr error
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case parseCompleteMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
case rowDescriptionMsg: // response to DESCRIBE message.
|
||
|
rowDesc, err = readRowDescription(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
rowDesc.numInput = numParam
|
||
|
case parameterDescriptionMsg: // response to DESCRIBE message.
|
||
|
numParam, err = readInt16(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
for i := 0; i < int(numParam); i++ {
|
||
|
if _, err := readInt32(rd); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
case noDataMsg: // response to DESCRIBE message.
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
case readyForQueryMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr != nil {
|
||
|
return nil, firstErr
|
||
|
}
|
||
|
return rowDesc, err
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr == nil {
|
||
|
firstErr = e
|
||
|
}
|
||
|
case noticeResponseMsg, parameterStatusMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("pgdriver: readParseDescribeSync: unexpected message %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func writeBindExecute(ctx context.Context, cn *Conn, name string, args []driver.NamedValue) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(bindMsg)
|
||
|
wb.WriteString("")
|
||
|
wb.WriteString(name)
|
||
|
wb.WriteInt16(0)
|
||
|
wb.WriteInt16(int16(len(args)))
|
||
|
for i := range args {
|
||
|
wb.StartParam()
|
||
|
bytes, err := appendStmtArg(wb.Bytes, args[i].Value)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if bytes != nil {
|
||
|
wb.Bytes = bytes
|
||
|
wb.FinishParam()
|
||
|
} else {
|
||
|
wb.FinishNullParam()
|
||
|
}
|
||
|
}
|
||
|
wb.WriteInt16(0)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
wb.StartMessage(executeMsg)
|
||
|
wb.WriteString("")
|
||
|
wb.WriteInt32(0)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
wb.StartMessage(syncMsg)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
func readExtQuery(ctx context.Context, cn *Conn) (driver.Result, error) {
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
var res driver.Result
|
||
|
var firstErr error
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case bindCompleteMsg, dataRowMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
case commandCompleteMsg: // response to EXECUTE message.
|
||
|
tmp, err := rd.ReadTemp(msgLen)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
r, err := parseResult(tmp)
|
||
|
if err != nil {
|
||
|
if firstErr == nil {
|
||
|
firstErr = err
|
||
|
}
|
||
|
} else {
|
||
|
res = r
|
||
|
}
|
||
|
case readyForQueryMsg: // Response to SYNC message.
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr != nil {
|
||
|
return nil, firstErr
|
||
|
}
|
||
|
return res, nil
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr == nil {
|
||
|
firstErr = e
|
||
|
}
|
||
|
case emptyQueryResponseMsg:
|
||
|
if firstErr == nil {
|
||
|
firstErr = errEmptyQuery
|
||
|
}
|
||
|
case noticeResponseMsg, parameterStatusMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("pgdriver: readExtQuery: unexpected message %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func readExtQueryData(ctx context.Context, cn *Conn, rowDesc *rowDescription) (*rows, error) {
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
var firstErr error
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case bindCompleteMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return newRows(cn, rowDesc, false), nil
|
||
|
case commandCompleteMsg: // response to EXECUTE message.
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
case readyForQueryMsg: // Response to SYNC message.
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr != nil {
|
||
|
return nil, firstErr
|
||
|
}
|
||
|
return &rows{closed: true}, nil
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if firstErr == nil {
|
||
|
firstErr = e
|
||
|
}
|
||
|
case emptyQueryResponseMsg:
|
||
|
if firstErr == nil {
|
||
|
firstErr = errEmptyQuery
|
||
|
}
|
||
|
case noticeResponseMsg, parameterStatusMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("pgdriver: readExtQueryData: unexpected message %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func writeCloseStmt(ctx context.Context, cn *Conn, name string) error {
|
||
|
wb := getWriteBuffer()
|
||
|
defer putWriteBuffer(wb)
|
||
|
|
||
|
wb.StartMessage(closeMsg)
|
||
|
wb.WriteByte('S') //nolint
|
||
|
wb.WriteString(name)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
wb.StartMessage(flushMsg)
|
||
|
wb.FinishMessage()
|
||
|
|
||
|
return cn.write(ctx, wb)
|
||
|
}
|
||
|
|
||
|
func readCloseStmtComplete(ctx context.Context, cn *Conn) error {
|
||
|
rd := cn.reader(ctx, -1)
|
||
|
for {
|
||
|
c, msgLen, err := readMessageType(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
switch c {
|
||
|
case closeCompleteMsg:
|
||
|
return rd.Discard(msgLen)
|
||
|
case errorResponseMsg:
|
||
|
e, err := readError(rd)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return e
|
||
|
case noticeResponseMsg, parameterStatusMsg:
|
||
|
if err := rd.Discard(msgLen); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
default:
|
||
|
return fmt.Errorf("pgdriver: readCloseCompleteMsg: unexpected message %q", c)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func readMessageType(rd *reader) (byte, int, error) {
|
||
|
c, err := rd.ReadByte()
|
||
|
if err != nil {
|
||
|
return 0, 0, err
|
||
|
}
|
||
|
l, err := readInt32(rd)
|
||
|
if err != nil {
|
||
|
return 0, 0, err
|
||
|
}
|
||
|
return c, int(l) - 4, nil
|
||
|
}
|
||
|
|
||
|
func readInt16(rd *reader) (int16, error) {
|
||
|
b, err := rd.ReadTemp(2)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
return int16(binary.BigEndian.Uint16(b)), nil
|
||
|
}
|
||
|
|
||
|
func readInt32(rd *reader) (int32, error) {
|
||
|
b, err := rd.ReadTemp(4)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
return int32(binary.BigEndian.Uint32(b)), nil
|
||
|
}
|
||
|
|
||
|
func readString(rd *reader) (string, error) {
|
||
|
b, err := rd.ReadSlice(0)
|
||
|
if err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
return string(b[:len(b)-1]), nil
|
||
|
}
|
||
|
|
||
|
func readError(rd *reader) (error, error) {
|
||
|
m := make(map[byte]string)
|
||
|
for {
|
||
|
c, err := rd.ReadByte()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if c == 0 {
|
||
|
break
|
||
|
}
|
||
|
s, err := readString(rd)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
m[c] = s
|
||
|
}
|
||
|
switch err := (Error{m: m}); err.Field('V') {
|
||
|
case "FATAL", "PANIC":
|
||
|
// Return this as an error and stop processing.
|
||
|
return nil, err
|
||
|
default:
|
||
|
// Return this as an error message and continue processing.
|
||
|
return err, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
// appendStmtArg appends the text encoding of v to b and returns the extended
// slice.
//
// A nil result with a nil error means the value is NULL: this is returned
// for a nil value, a nil []byte and a zero time.Time. Encodings:
//   - int64: base-10 digits
//   - float64: "NaN", "Infinity", "-Infinity" or shortest 'f' formatting
//   - bool: "TRUE" or "FALSE"
//   - []byte: `\x` followed by lowercase hex
//   - string: the string's runes re-encoded as UTF-8 with NUL runes dropped
//   - time.Time: UTC, formatted as "2006-01-02 15:04:05.999999-07:00"
//
// Any other type yields an error.
func appendStmtArg(b []byte, v driver.Value) ([]byte, error) {
	switch arg := v.(type) {
	case nil:
		return nil, nil

	case bool:
		text := "FALSE"
		if arg {
			text = "TRUE"
		}
		return append(b, text...), nil

	case int64:
		return strconv.AppendInt(b, arg, 10), nil

	case float64:
		if math.IsNaN(arg) {
			return append(b, "NaN"...), nil
		}
		if math.IsInf(arg, 1) {
			return append(b, "Infinity"...), nil
		}
		if math.IsInf(arg, -1) {
			return append(b, "-Infinity"...), nil
		}
		return strconv.AppendFloat(b, arg, 'f', -1, 64), nil

	case string:
		var enc [utf8.UTFMax]byte
		for _, r := range arg {
			switch {
			case r == 0:
				// NUL runes are dropped from the output.
			case r < utf8.RuneSelf:
				b = append(b, byte(r))
			default:
				n := utf8.EncodeRune(enc[:], r)
				b = append(b, enc[:n]...)
			}
		}
		return b, nil

	case []byte:
		if arg == nil {
			return nil, nil
		}
		b = append(b, `\x`...)
		start := len(b)
		// Grow b by the encoded size, then hex-encode into the new tail.
		b = append(b, make([]byte, hex.EncodedLen(len(arg)))...)
		hex.Encode(b[start:], arg)
		return b, nil

	case time.Time:
		if arg.IsZero() {
			return nil, nil
		}
		return arg.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00"), nil

	default:
		return nil, fmt.Errorf("pgdriver: unexpected arg: %T", arg)
	}
}
|