You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
420 lines
9.8 KiB
420 lines
9.8 KiB
2 years ago
|
package pgdriver
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io/ioutil"
|
||
|
"net"
|
||
|
"net/url"
|
||
|
"os"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
type Config struct {
|
||
|
// Network type, either tcp or unix.
|
||
|
// Default is tcp.
|
||
|
Network string
|
||
|
// TCP host:port or Unix socket depending on Network.
|
||
|
Addr string
|
||
|
// Dial timeout for establishing new connections.
|
||
|
// Default is 5 seconds.
|
||
|
DialTimeout time.Duration
|
||
|
// Dialer creates new network connection and has priority over
|
||
|
// Network and Addr options.
|
||
|
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||
|
|
||
|
// TLS config for secure connections.
|
||
|
TLSConfig *tls.Config
|
||
|
|
||
|
User string
|
||
|
Password string
|
||
|
Database string
|
||
|
AppName string
|
||
|
// PostgreSQL session parameters updated with `SET` command when a connection is created.
|
||
|
ConnParams map[string]interface{}
|
||
|
|
||
|
// Timeout for socket reads. If reached, commands fail with a timeout instead of blocking.
|
||
|
ReadTimeout time.Duration
|
||
|
// Timeout for socket writes. If reached, commands fail with a timeout instead of blocking.
|
||
|
WriteTimeout time.Duration
|
||
|
|
||
|
// ResetSessionFunc is called prior to executing a query on a connection that has been used before.
|
||
|
ResetSessionFunc func(context.Context, *Conn) error
|
||
|
}
|
||
|
|
||
|
func newDefaultConfig() *Config {
|
||
|
host := env("PGHOST", "localhost")
|
||
|
port := env("PGPORT", "5432")
|
||
|
|
||
|
cfg := &Config{
|
||
|
Network: "tcp",
|
||
|
Addr: net.JoinHostPort(host, port),
|
||
|
DialTimeout: 5 * time.Second,
|
||
|
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||
|
|
||
|
User: env("PGUSER", "postgres"),
|
||
|
Database: env("PGDATABASE", "postgres"),
|
||
|
|
||
|
ReadTimeout: 10 * time.Second,
|
||
|
WriteTimeout: 5 * time.Second,
|
||
|
}
|
||
|
|
||
|
cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||
|
netDialer := &net.Dialer{
|
||
|
Timeout: cfg.DialTimeout,
|
||
|
KeepAlive: 5 * time.Minute,
|
||
|
}
|
||
|
return netDialer.DialContext(ctx, network, addr)
|
||
|
}
|
||
|
|
||
|
return cfg
|
||
|
}
|
||
|
|
||
|
type Option func(cfg *Config)
|
||
|
|
||
|
// Deprecated. Use Option instead.
|
||
|
type DriverOption = Option
|
||
|
|
||
|
func WithNetwork(network string) Option {
|
||
|
if network == "" {
|
||
|
panic("network is empty")
|
||
|
}
|
||
|
return func(cfg *Config) {
|
||
|
cfg.Network = network
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithAddr(addr string) Option {
|
||
|
if addr == "" {
|
||
|
panic("addr is empty")
|
||
|
}
|
||
|
return func(cfg *Config) {
|
||
|
cfg.Addr = addr
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.TLSConfig = tlsConfig
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithInsecure(on bool) Option {
|
||
|
return func(cfg *Config) {
|
||
|
if on {
|
||
|
cfg.TLSConfig = nil
|
||
|
} else {
|
||
|
cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithUser(user string) Option {
|
||
|
if user == "" {
|
||
|
panic("user is empty")
|
||
|
}
|
||
|
return func(cfg *Config) {
|
||
|
cfg.User = user
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithPassword(password string) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.Password = password
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithDatabase(database string) Option {
|
||
|
if database == "" {
|
||
|
panic("database is empty")
|
||
|
}
|
||
|
return func(cfg *Config) {
|
||
|
cfg.Database = database
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithApplicationName(appName string) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.AppName = appName
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithConnParams(params map[string]interface{}) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.ConnParams = params
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithTimeout(timeout time.Duration) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.DialTimeout = timeout
|
||
|
cfg.ReadTimeout = timeout
|
||
|
cfg.WriteTimeout = timeout
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithDialTimeout(dialTimeout time.Duration) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.DialTimeout = dialTimeout
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithReadTimeout(readTimeout time.Duration) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.ReadTimeout = readTimeout
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithWriteTimeout(writeTimeout time.Duration) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.WriteTimeout = writeTimeout
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// WithResetSessionFunc configures a function that is called prior to executing
|
||
|
// a query on a connection that has been used before.
|
||
|
// If the func returns driver.ErrBadConn, the connection is discarded.
|
||
|
func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option {
|
||
|
return func(cfg *Config) {
|
||
|
cfg.ResetSessionFunc = fn
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func WithDSN(dsn string) Option {
|
||
|
return func(cfg *Config) {
|
||
|
opts, err := parseDSN(dsn)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
for _, opt := range opts {
|
||
|
opt(cfg)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// env returns the value of the environment variable key, or defValue when
// the variable is unset or set to the empty string.
func env(key, defValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defValue
	}
	return value
}
|
||
|
|
||
|
//------------------------------------------------------------------------------
|
||
|
|
||
|
func parseDSN(dsn string) ([]Option, error) {
|
||
|
u, err := url.Parse(dsn)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
q := queryOptions{q: u.Query()}
|
||
|
var opts []Option
|
||
|
|
||
|
switch u.Scheme {
|
||
|
case "postgres", "postgresql":
|
||
|
if u.Host != "" {
|
||
|
addr := u.Host
|
||
|
if !strings.Contains(addr, ":") {
|
||
|
addr += ":5432"
|
||
|
}
|
||
|
opts = append(opts, WithAddr(addr))
|
||
|
}
|
||
|
|
||
|
if len(u.Path) > 1 {
|
||
|
opts = append(opts, WithDatabase(u.Path[1:]))
|
||
|
}
|
||
|
|
||
|
if host := q.string("host"); host != "" {
|
||
|
opts = append(opts, WithAddr(host))
|
||
|
if host[0] == '/' {
|
||
|
opts = append(opts, WithNetwork("unix"))
|
||
|
}
|
||
|
}
|
||
|
case "unix":
|
||
|
if len(u.Path) == 0 {
|
||
|
return nil, fmt.Errorf("unix socket DSN requires a path: %s", dsn)
|
||
|
}
|
||
|
|
||
|
opts = append(opts, WithNetwork("unix"))
|
||
|
if u.Host != "" {
|
||
|
opts = append(opts, WithDatabase(u.Host))
|
||
|
}
|
||
|
opts = append(opts, WithAddr(u.Path))
|
||
|
default:
|
||
|
return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
|
||
|
}
|
||
|
|
||
|
if u.User != nil {
|
||
|
opts = append(opts, WithUser(u.User.Username()))
|
||
|
if password, ok := u.User.Password(); ok {
|
||
|
opts = append(opts, WithPassword(password))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if appName := q.string("application_name"); appName != "" {
|
||
|
opts = append(opts, WithApplicationName(appName))
|
||
|
}
|
||
|
|
||
|
if sslMode, sslRootCert := q.string("sslmode"), q.string("sslrootcert"); sslMode != "" || sslRootCert != "" {
|
||
|
tlsConfig := &tls.Config{}
|
||
|
switch sslMode {
|
||
|
case "disable":
|
||
|
tlsConfig = nil
|
||
|
case "allow", "prefer", "":
|
||
|
tlsConfig.InsecureSkipVerify = true
|
||
|
case "require":
|
||
|
if sslRootCert == "" {
|
||
|
tlsConfig.InsecureSkipVerify = true
|
||
|
break
|
||
|
}
|
||
|
// For backwards compatibility reasons, in the presence of `sslrootcert`,
|
||
|
// `sslmode` = `require` must act as if `sslmode` = `verify-ca`. See the note at
|
||
|
// https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES .
|
||
|
fallthrough
|
||
|
case "verify-ca":
|
||
|
// The default certificate verification will also verify the host name
|
||
|
// which is not the behavior of `verify-ca`. As such, we need to manually
|
||
|
// check the certificate chain.
|
||
|
// At the time of writing, tls.Config has no option for this behavior
|
||
|
// (verify chain, but skip server name).
|
||
|
// See https://github.com/golang/go/issues/21971 .
|
||
|
tlsConfig.InsecureSkipVerify = true
|
||
|
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||
|
certs := make([]*x509.Certificate, 0, len(rawCerts))
|
||
|
for _, rawCert := range rawCerts {
|
||
|
cert, err := x509.ParseCertificate(rawCert)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("pgdriver: failed to parse certificate: %w", err)
|
||
|
}
|
||
|
certs = append(certs, cert)
|
||
|
}
|
||
|
intermediates := x509.NewCertPool()
|
||
|
for _, cert := range certs[1:] {
|
||
|
intermediates.AddCert(cert)
|
||
|
}
|
||
|
_, err := certs[0].Verify(x509.VerifyOptions{
|
||
|
Roots: tlsConfig.RootCAs,
|
||
|
Intermediates: intermediates,
|
||
|
})
|
||
|
return err
|
||
|
}
|
||
|
case "verify-full":
|
||
|
tlsConfig.ServerName = u.Host
|
||
|
if host, _, err := net.SplitHostPort(u.Host); err == nil {
|
||
|
tlsConfig.ServerName = host
|
||
|
}
|
||
|
default:
|
||
|
return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
|
||
|
}
|
||
|
if tlsConfig != nil && sslRootCert != "" {
|
||
|
rawCA, err := ioutil.ReadFile(sslRootCert)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("pgdriver: failed to read root CA: %w", err)
|
||
|
}
|
||
|
certPool := x509.NewCertPool()
|
||
|
if !certPool.AppendCertsFromPEM(rawCA) {
|
||
|
return nil, fmt.Errorf("pgdriver: failed to append root CA")
|
||
|
}
|
||
|
tlsConfig.RootCAs = certPool
|
||
|
}
|
||
|
opts = append(opts, WithTLSConfig(tlsConfig))
|
||
|
}
|
||
|
|
||
|
if d := q.duration("timeout"); d != 0 {
|
||
|
opts = append(opts, WithTimeout(d))
|
||
|
}
|
||
|
if d := q.duration("dial_timeout"); d != 0 {
|
||
|
opts = append(opts, WithDialTimeout(d))
|
||
|
}
|
||
|
if d := q.duration("connect_timeout"); d != 0 {
|
||
|
opts = append(opts, WithDialTimeout(d))
|
||
|
}
|
||
|
if d := q.duration("read_timeout"); d != 0 {
|
||
|
opts = append(opts, WithReadTimeout(d))
|
||
|
}
|
||
|
if d := q.duration("write_timeout"); d != 0 {
|
||
|
opts = append(opts, WithWriteTimeout(d))
|
||
|
}
|
||
|
|
||
|
rem, err := q.remaining()
|
||
|
if err != nil {
|
||
|
return nil, q.err
|
||
|
}
|
||
|
|
||
|
if len(rem) > 0 {
|
||
|
params := make(map[string]interface{}, len(rem))
|
||
|
for k, v := range rem {
|
||
|
params[k] = v
|
||
|
}
|
||
|
opts = append(opts, WithConnParams(params))
|
||
|
}
|
||
|
|
||
|
return opts, nil
|
||
|
}
|
||
|
|
||
|
// verify is a method to make sure if the config is legitimate
|
||
|
// in the case it detects any errors, it returns with a non-nil error
|
||
|
// it can be extended to check other parameters
|
||
|
func (c *Config) verify() error {
|
||
|
if c.User == "" {
|
||
|
return errors.New("pgdriver: User option is empty (to configure, use WithUser).")
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// queryOptions consumes DSN query parameters one at a time. Each parameter
// that is read is removed from q, so whatever is left afterwards can be
// treated as unrecognised by the parser. The first parsing error is kept
// in err and reported by remaining.
type queryOptions struct {
	q   url.Values
	err error
}

// string returns the last value given for name and removes the parameter
// from the set. It returns "" when the parameter is absent.
func (o *queryOptions) string(name string) string {
	vs := o.q[name]
	if len(vs) == 0 {
		return ""
	}
	delete(o.q, name) // enable detection of unknown parameters
	return vs[len(vs)-1]
}

// duration reads the parameter name as a duration and removes it from the
// set.
//
// A plain integer is interpreted as a number of seconds; zero or a negative
// integer returns -1, meaning "disable timeouts". Any other value must be
// accepted by time.ParseDuration (for example "1500ms"). An absent or empty
// parameter returns 0. An invalid or out-of-range value also returns 0 and
// records an error that remaining reports.
func (o *queryOptions) duration(name string) time.Duration {
	s := o.string(name)
	if s == "" {
		return 0
	}

	// try plain number first
	if secs, err := strconv.Atoi(s); err == nil {
		// Largest whole number of seconds a time.Duration can hold.
		const maxSecs = (1<<63 - 1) / int64(time.Second)

		if secs <= 0 {
			// disable timeouts
			return -1
		}
		if int64(secs) > maxSecs {
			// Multiplying would silently wrap around to a bogus value.
			o.invalidDuration(name, fmt.Errorf("%q seconds is out of range", s))
			return 0
		}
		return time.Duration(secs) * time.Second
	}

	dur, err := time.ParseDuration(s)
	if err != nil {
		o.invalidDuration(name, err)
		return 0
	}
	return dur
}

// invalidDuration records cause as the parsing error for the parameter
// name, unless an earlier error has already been recorded.
func (o *queryOptions) invalidDuration(name string, cause error) {
	if o.err == nil {
		o.err = fmt.Errorf("pgdriver: invalid %s duration: %w", name, cause)
	}
}

// remaining returns the parameters that have not been read, keeping the
// last value of each. It returns the first recorded parsing error instead,
// if there is one, and a nil map when nothing is left.
func (o *queryOptions) remaining() (map[string]string, error) {
	if o.err != nil {
		return nil, o.err
	}
	if len(o.q) == 0 {
		return nil, nil
	}
	m := make(map[string]string, len(o.q))
	for k, ss := range o.q {
		m[k] = ss[len(ss)-1]
	}
	return m, nil
}
|