You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
679 lines
12 KiB
679 lines
12 KiB
package pq
|
|
|
|
import (
|
|
"bufio"
|
|
"crypto/md5"
|
|
"crypto/tls"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/user"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
var (
|
|
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
|
|
ErrNotSupported = errors.New("pq: invalid command")
|
|
)
|
|
|
|
type drv struct{}
|
|
|
|
func (d *drv) Open(name string) (driver.Conn, error) {
|
|
return Open(name)
|
|
}
|
|
|
|
func init() {
|
|
sql.Register("postgres", &drv{})
|
|
}
|
|
|
|
type conn struct {
|
|
c net.Conn
|
|
buf *bufio.Reader
|
|
namei int
|
|
}
|
|
|
|
func Open(name string) (_ driver.Conn, err error) {
|
|
defer errRecover(&err)
|
|
defer errRecoverWithPGReason(&err)
|
|
|
|
o := make(Values)
|
|
|
|
// A number of defaults are applied here, in this order:
|
|
//
|
|
// * Very low precedence defaults applied in every situation
|
|
// * Environment variables
|
|
// * Explicitly passed connection information
|
|
o.Set("host", "localhost")
|
|
o.Set("port", "5432")
|
|
|
|
// Default the username, but ignore errors, because a user
|
|
// passed in via environment variable or connection string
|
|
// would be okay. This can result in connections failing
|
|
// *sometimes* if the client relies on being able to determine
|
|
// the current username and there are intermittent problems.
|
|
u, err := user.Current()
|
|
if err == nil {
|
|
o.Set("user", u.Username)
|
|
}
|
|
|
|
for k, v := range parseEnviron(os.Environ()) {
|
|
o.Set(k, v)
|
|
}
|
|
|
|
parseOpts(name, o)
|
|
|
|
c, err := net.Dial(network(o))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cn := &conn{c: c}
|
|
cn.ssl(o)
|
|
cn.buf = bufio.NewReader(cn.c)
|
|
cn.startup(o)
|
|
return cn, nil
|
|
}
|
|
|
|
func network(o Values) (string, string) {
|
|
host := o.Get("host")
|
|
|
|
if strings.HasPrefix(host, "/") {
|
|
sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
|
|
return "unix", sockPath
|
|
}
|
|
|
|
return "tcp", host + ":" + o.Get("port")
|
|
}
|
|
|
|
type Values map[string]string
|
|
|
|
func (vs Values) Set(k, v string) {
|
|
vs[k] = v
|
|
}
|
|
|
|
func (vs Values) Get(k string) (v string) {
|
|
v, _ = vs[k]
|
|
return
|
|
}
|
|
|
|
func parseOpts(name string, o Values) {
|
|
if len(name) == 0 {
|
|
return
|
|
}
|
|
|
|
ps := strings.Split(name, " ")
|
|
for _, p := range ps {
|
|
kv := strings.Split(p, "=")
|
|
if len(kv) < 2 {
|
|
errorf("invalid option: %q", p)
|
|
}
|
|
o.Set(kv[0], kv[1])
|
|
}
|
|
}
|
|
|
|
func (cn *conn) Begin() (driver.Tx, error) {
|
|
_, err := cn.Exec("BEGIN", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return cn, err
|
|
}
|
|
|
|
func (cn *conn) Commit() error {
|
|
_, err := cn.Exec("COMMIT", nil)
|
|
return err
|
|
}
|
|
|
|
func (cn *conn) Rollback() error {
|
|
_, err := cn.Exec("ROLLBACK", nil)
|
|
return err
|
|
}
|
|
|
|
func (cn *conn) gname() string {
|
|
cn.namei++
|
|
return strconv.FormatInt(int64(cn.namei), 10)
|
|
}
|
|
|
|
func (cn *conn) simpleQuery(q string) (res driver.Result, err error) {
|
|
defer errRecover(&err)
|
|
|
|
b := newWriteBuf('Q')
|
|
b.string(q)
|
|
cn.send(b)
|
|
|
|
for {
|
|
t, r := cn.recv1()
|
|
switch t {
|
|
case 'C':
|
|
res = parseComplete(r.string())
|
|
case 'Z':
|
|
// done
|
|
return
|
|
case 'E':
|
|
err = parseError(r)
|
|
case 'T', 'N', 'S':
|
|
// ignore
|
|
default:
|
|
errorf("unknown response for simple query: %q", t)
|
|
}
|
|
}
|
|
panic("not reached")
|
|
}
|
|
|
|
func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
|
|
defer errRecover(&err)
|
|
|
|
st := &stmt{cn: cn, name: stmtName, query: q}
|
|
|
|
b := newWriteBuf('P')
|
|
b.string(st.name)
|
|
b.string(q)
|
|
b.int16(0)
|
|
cn.send(b)
|
|
|
|
b = newWriteBuf('D')
|
|
b.byte('S')
|
|
b.string(st.name)
|
|
cn.send(b)
|
|
|
|
cn.send(newWriteBuf('S'))
|
|
|
|
for {
|
|
t, r := cn.recv1()
|
|
switch t {
|
|
case '1', '2', 'N':
|
|
case 't':
|
|
st.nparams = int(r.int16())
|
|
st.paramTyps = make([]oid, st.nparams, st.nparams)
|
|
|
|
for i := 0; i < st.nparams; i += 1 {
|
|
st.paramTyps[i] = r.oid()
|
|
}
|
|
case 'T':
|
|
n := r.int16()
|
|
st.cols = make([]string, n)
|
|
st.rowTyps = make([]oid, n)
|
|
for i := range st.cols {
|
|
st.cols[i] = r.string()
|
|
r.next(6)
|
|
st.rowTyps[i] = r.oid()
|
|
r.next(8)
|
|
}
|
|
case 'n':
|
|
// no data
|
|
case 'Z':
|
|
return st, err
|
|
case 'E':
|
|
err = parseError(r)
|
|
default:
|
|
errorf("unexpected describe rows response: %q", t)
|
|
}
|
|
}
|
|
|
|
panic("not reached")
|
|
}
|
|
|
|
func (cn *conn) Prepare(q string) (driver.Stmt, error) {
|
|
return cn.prepareTo(q, cn.gname())
|
|
}
|
|
|
|
func (cn *conn) Close() (err error) {
|
|
defer errRecover(&err)
|
|
cn.send(newWriteBuf('X'))
|
|
|
|
return cn.c.Close()
|
|
}
|
|
|
|
// Implement the optional "Execer" interface for one-shot queries
|
|
func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
|
|
defer errRecover(&err)
|
|
|
|
// Check to see if we can use the "simpleQuery" interface, which is
|
|
// *much* faster than going through prepare/exec
|
|
if len(args) == 0 {
|
|
return cn.simpleQuery(query)
|
|
}
|
|
|
|
// Use the unnamed statement to defer planning until bind
|
|
// time, or else value-based selectivity estimates cannot be
|
|
// used.
|
|
st, err := cn.prepareTo(query, "")
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
r, err := st.Exec(args)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return r, err
|
|
}
|
|
|
|
// Assumes len(*m) is > 5
|
|
func (cn *conn) send(m *writeBuf) {
|
|
b := (*m)[1:]
|
|
binary.BigEndian.PutUint32(b, uint32(len(b)))
|
|
|
|
if (*m)[0] == 0 {
|
|
*m = b
|
|
}
|
|
|
|
_, err := cn.c.Write(*m)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func (cn *conn) recv() (t byte, r *readBuf) {
|
|
for {
|
|
t, r = cn.recv1()
|
|
switch t {
|
|
case 'E':
|
|
panic(parseError(r))
|
|
case 'N':
|
|
// ignore
|
|
default:
|
|
return
|
|
}
|
|
}
|
|
|
|
panic("not reached")
|
|
}
|
|
|
|
func (cn *conn) recv1() (byte, *readBuf) {
|
|
x := make([]byte, 5)
|
|
_, err := io.ReadFull(cn.buf, x)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
b := readBuf(x[1:])
|
|
y := make([]byte, b.int32()-4)
|
|
_, err = io.ReadFull(cn.buf, y)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
return x[0], (*readBuf)(&y)
|
|
}
|
|
|
|
func (cn *conn) ssl(o Values) {
|
|
tlsConf := tls.Config{}
|
|
switch mode := o.Get("sslmode"); mode {
|
|
case "require", "":
|
|
tlsConf.InsecureSkipVerify = true
|
|
case "verify-full":
|
|
// fall out
|
|
case "disable":
|
|
return
|
|
default:
|
|
errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode)
|
|
}
|
|
|
|
w := newWriteBuf(0)
|
|
w.int32(80877103)
|
|
cn.send(w)
|
|
|
|
b := make([]byte, 1)
|
|
_, err := io.ReadFull(cn.c, b)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if b[0] != 'S' {
|
|
panic(ErrSSLNotSupported)
|
|
}
|
|
|
|
cn.c = tls.Client(cn.c, &tlsConf)
|
|
}
|
|
|
|
func (cn *conn) startup(o Values) {
|
|
w := newWriteBuf(0)
|
|
w.int32(196608)
|
|
w.string("user")
|
|
w.string(o.Get("user"))
|
|
w.string("database")
|
|
w.string(o.Get("dbname"))
|
|
w.string("")
|
|
cn.send(w)
|
|
|
|
for {
|
|
t, r := cn.recv()
|
|
switch t {
|
|
case 'K', 'S':
|
|
case 'R':
|
|
cn.auth(r, o)
|
|
case 'Z':
|
|
return
|
|
default:
|
|
errorf("unknown response for startup: %q", t)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (cn *conn) auth(r *readBuf, o Values) {
|
|
switch code := r.int32(); code {
|
|
case 0:
|
|
// OK
|
|
case 3:
|
|
w := newWriteBuf('p')
|
|
w.string(o.Get("password"))
|
|
cn.send(w)
|
|
|
|
t, r := cn.recv()
|
|
if t != 'R' {
|
|
errorf("unexpected password response: %q", t)
|
|
}
|
|
|
|
if r.int32() != 0 {
|
|
errorf("unexpected authentication response: %q", t)
|
|
}
|
|
case 5:
|
|
s := string(r.next(4))
|
|
w := newWriteBuf('p')
|
|
w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
|
|
cn.send(w)
|
|
|
|
t, r := cn.recv()
|
|
if t != 'R' {
|
|
errorf("unexpected password response: %q", t)
|
|
}
|
|
|
|
if r.int32() != 0 {
|
|
errorf("unexpected authentication resoonse: %q", t)
|
|
}
|
|
default:
|
|
errorf("unknown authentication response: %d", code)
|
|
}
|
|
}
|
|
|
|
type stmt struct {
|
|
cn *conn
|
|
name string
|
|
query string
|
|
cols []string
|
|
nparams int
|
|
rowTyps []oid
|
|
paramTyps []oid
|
|
closed bool
|
|
}
|
|
|
|
func (st *stmt) Close() (err error) {
|
|
if st.closed {
|
|
return nil
|
|
}
|
|
|
|
defer errRecover(&err)
|
|
|
|
w := newWriteBuf('C')
|
|
w.byte('S')
|
|
w.string(st.name)
|
|
st.cn.send(w)
|
|
|
|
st.cn.send(newWriteBuf('S'))
|
|
|
|
t, _ := st.cn.recv()
|
|
if t != '3' {
|
|
errorf("unexpected close response: %q", t)
|
|
}
|
|
st.closed = true
|
|
|
|
t, _ = st.cn.recv()
|
|
if t != 'Z' {
|
|
errorf("expected ready for query, but got: %q", t)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) {
|
|
defer errRecover(&err)
|
|
st.exec(v)
|
|
return &rows{st: st}, nil
|
|
}
|
|
|
|
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
|
|
defer errRecover(&err)
|
|
|
|
if len(v) == 0 {
|
|
return st.cn.simpleQuery(st.query)
|
|
}
|
|
st.exec(v)
|
|
|
|
for {
|
|
t, r := st.cn.recv1()
|
|
switch t {
|
|
case 'E':
|
|
err = parseError(r)
|
|
case 'C':
|
|
res = parseComplete(r.string())
|
|
case 'Z':
|
|
// done
|
|
return
|
|
case 'D':
|
|
errorf("unexpected data row returned in Exec; check your query")
|
|
case 'S', 'N':
|
|
// Ignore
|
|
default:
|
|
errorf("unknown exec response: %q", t)
|
|
}
|
|
}
|
|
|
|
panic("not reached")
|
|
}
|
|
|
|
func (st *stmt) exec(v []driver.Value) {
|
|
w := newWriteBuf('B')
|
|
w.string("")
|
|
w.string(st.name)
|
|
w.int16(0)
|
|
w.int16(len(v))
|
|
for i, x := range v {
|
|
if x == nil {
|
|
w.int32(-1)
|
|
} else {
|
|
b := encode(x, st.paramTyps[i])
|
|
w.int32(len(b))
|
|
w.bytes(b)
|
|
}
|
|
}
|
|
w.int16(0)
|
|
st.cn.send(w)
|
|
|
|
w = newWriteBuf('E')
|
|
w.string("")
|
|
w.int32(0)
|
|
st.cn.send(w)
|
|
|
|
st.cn.send(newWriteBuf('S'))
|
|
|
|
var err error
|
|
for {
|
|
t, r := st.cn.recv1()
|
|
switch t {
|
|
case 'E':
|
|
err = parseError(r)
|
|
case '2':
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
case 'Z':
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return
|
|
case 'N':
|
|
// ignore
|
|
default:
|
|
errorf("unexpected bind response: %q", t)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (st *stmt) NumInput() int {
|
|
return st.nparams
|
|
}
|
|
|
|
type result int64
|
|
|
|
func (i result) RowsAffected() (int64, error) {
|
|
return int64(i), nil
|
|
}
|
|
|
|
func (i result) LastInsertId() (int64, error) {
|
|
return 0, ErrNotSupported
|
|
}
|
|
|
|
func parseComplete(s string) driver.Result {
|
|
parts := strings.Split(s, " ")
|
|
n, _ := strconv.ParseInt(parts[len(parts)-1], 10, 64)
|
|
return result(n)
|
|
}
|
|
|
|
type rows struct {
|
|
st *stmt
|
|
done bool
|
|
}
|
|
|
|
func (rs *rows) Close() error {
|
|
for {
|
|
err := rs.Next(nil)
|
|
switch err {
|
|
case nil:
|
|
case io.EOF:
|
|
return nil
|
|
default:
|
|
return err
|
|
}
|
|
}
|
|
panic("not reached")
|
|
}
|
|
|
|
func (rs *rows) Columns() []string {
|
|
return rs.st.cols
|
|
}
|
|
|
|
func (rs *rows) Next(dest []driver.Value) (err error) {
|
|
if rs.done {
|
|
return io.EOF
|
|
}
|
|
|
|
defer errRecover(&err)
|
|
|
|
for {
|
|
t, r := rs.st.cn.recv1()
|
|
switch t {
|
|
case 'E':
|
|
err = parseError(r)
|
|
case 'C', 'S', 'N':
|
|
continue
|
|
case 'Z':
|
|
rs.done = true
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return io.EOF
|
|
case 'D':
|
|
n := r.int16()
|
|
for i := 0; i < len(dest) && i < n; i++ {
|
|
l := r.int32()
|
|
if l == -1 {
|
|
dest[i] = nil
|
|
continue
|
|
}
|
|
dest[i] = decode(r.next(l), rs.st.rowTyps[i])
|
|
}
|
|
return
|
|
default:
|
|
errorf("unexpected message after execute: %q", t)
|
|
}
|
|
}
|
|
|
|
panic("not reached")
|
|
}
|
|
|
|
func md5s(s string) string {
|
|
h := md5.New()
|
|
h.Write([]byte(s))
|
|
return fmt.Sprintf("%x", h.Sum(nil))
|
|
}
|
|
|
|
// parseEnviron tries to mimic some of libpq's environment handling
|
|
//
|
|
// To ease testing, it does not directly reference os.Environ, but is
|
|
// designed to accept its output.
|
|
//
|
|
// Environment-set connection information is intended to have a higher
|
|
// precedence than a library default but lower than any explicitly
|
|
// passed information (such as in the URL or connection string).
|
|
func parseEnviron(env []string) (out map[string]string) {
|
|
out = make(map[string]string)
|
|
|
|
for _, v := range env {
|
|
parts := strings.SplitN(v, "=", 2)
|
|
|
|
accrue := func(keyname string) {
|
|
out[keyname] = parts[1]
|
|
}
|
|
|
|
// The order of these is the same as is seen in the
|
|
// PostgreSQL 9.1 manual, with omissions briefly
|
|
// noted.
|
|
switch parts[0] {
|
|
case "PGHOST":
|
|
accrue("host")
|
|
case "PGHOSTADDR":
|
|
accrue("hostaddr")
|
|
case "PGPORT":
|
|
accrue("port")
|
|
case "PGDATABASE":
|
|
accrue("dbname")
|
|
case "PGUSER":
|
|
accrue("user")
|
|
case "PGPASSWORD":
|
|
accrue("password")
|
|
// skip PGPASSFILE, PGSERVICE, PGSERVICEFILE,
|
|
// PGREALM
|
|
case "PGOPTIONS":
|
|
accrue("options")
|
|
case "PGAPPNAME":
|
|
accrue("application_name")
|
|
case "PGSSLMODE":
|
|
accrue("sslmode")
|
|
case "PGREQUIRESSL":
|
|
accrue("requiressl")
|
|
case "PGSSLCERT":
|
|
accrue("sslcert")
|
|
case "PGSSLKEY":
|
|
accrue("sslkey")
|
|
case "PGSSLROOTCERT":
|
|
accrue("sslrootcert")
|
|
case "PGSSLCRL":
|
|
accrue("sslcrl")
|
|
case "PGREQUIREPEER":
|
|
accrue("requirepeer")
|
|
case "PGKRBSRVNAME":
|
|
accrue("krbsrvname")
|
|
case "PGGSSLIB":
|
|
accrue("gsslib")
|
|
case "PGCONNECT_TIMEOUT":
|
|
accrue("connect_timeout")
|
|
case "PGCLIENTENCODING":
|
|
accrue("client_encoding")
|
|
// skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR,
|
|
// PGLOCALEDIR
|
|
}
|
|
}
|
|
|
|
return out
|
|
}
|