You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

420 lines
9.8 KiB

package pgdriver
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"net/url"
	"os"
	"strconv"
	"strings"
	"time"
)
type Config struct {
// Network type, either tcp or unix.
// Default is tcp.
Network string
// TCP host:port or Unix socket depending on Network.
Addr string
// Dial timeout for establishing new connections.
// Default is 5 seconds.
DialTimeout time.Duration
// Dialer creates new network connection and has priority over
// Network and Addr options.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// TLS config for secure connections.
TLSConfig *tls.Config
User string
Password string
Database string
AppName string
// PostgreSQL session parameters updated with `SET` command when a connection is created.
ConnParams map[string]interface{}
// Timeout for socket reads. If reached, commands fail with a timeout instead of blocking.
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands fail with a timeout instead of blocking.
WriteTimeout time.Duration
// ResetSessionFunc is called prior to executing a query on a connection that has been used before.
ResetSessionFunc func(context.Context, *Conn) error
func newDefaultConfig() *Config {
host := env("PGHOST", "localhost")
port := env("PGPORT", "5432")
cfg := &Config{
Network: "tcp",
Addr: net.JoinHostPort(host, port),
DialTimeout: 5 * time.Second,
TLSConfig: &tls.Config{InsecureSkipVerify: true},
User: env("PGUSER", "postgres"),
Database: env("PGDATABASE", "postgres"),
ReadTimeout: 10 * time.Second,
WriteTimeout: 5 * time.Second,
cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialer := &net.Dialer{
Timeout: cfg.DialTimeout,
KeepAlive: 5 * time.Minute,
return netDialer.DialContext(ctx, network, addr)
return cfg
// Option modifies a Config. Options are produced by the With* constructors
// and take effect when called with the Config to change.
type Option func(cfg *Config)
// DriverOption is an alias for Option.
//
// Deprecated: use Option instead.
type DriverOption = Option
func WithNetwork(network string) Option {
if network == "" {
panic("network is empty")
return func(cfg *Config) {
cfg.Network = network
func WithAddr(addr string) Option {
if addr == "" {
panic("addr is empty")
return func(cfg *Config) {
cfg.Addr = addr
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *Config) {
cfg.TLSConfig = tlsConfig
func WithInsecure(on bool) Option {
return func(cfg *Config) {
if on {
cfg.TLSConfig = nil
} else {
cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true}
func WithUser(user string) Option {
if user == "" {
panic("user is empty")
return func(cfg *Config) {
cfg.User = user
func WithPassword(password string) Option {
return func(cfg *Config) {
cfg.Password = password
func WithDatabase(database string) Option {
if database == "" {
panic("database is empty")
return func(cfg *Config) {
cfg.Database = database
func WithApplicationName(appName string) Option {
return func(cfg *Config) {
cfg.AppName = appName
func WithConnParams(params map[string]interface{}) Option {
return func(cfg *Config) {
cfg.ConnParams = params
func WithTimeout(timeout time.Duration) Option {
return func(cfg *Config) {
cfg.DialTimeout = timeout
cfg.ReadTimeout = timeout
cfg.WriteTimeout = timeout
func WithDialTimeout(dialTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.DialTimeout = dialTimeout
func WithReadTimeout(readTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.ReadTimeout = readTimeout
func WithWriteTimeout(writeTimeout time.Duration) Option {
return func(cfg *Config) {
cfg.WriteTimeout = writeTimeout
// WithResetSessionFunc configures a function that is called prior to executing
// a query on a connection that has been used before.
// If the func returns driver.ErrBadConn, the connection is discarded.
func WithResetSessionFunc(fn func(context.Context, *Conn) error) Option {
return func(cfg *Config) {
cfg.ResetSessionFunc = fn
func WithDSN(dsn string) Option {
return func(cfg *Config) {
opts, err := parseDSN(dsn)
if err != nil {
for _, opt := range opts {
// env returns the value of the environment variable key, or defValue when the
// variable is unset or set to the empty string.
func env(key, defValue string) string {
	if s := os.Getenv(key); s != "" {
		return s
	}
	return defValue
}
func parseDSN(dsn string) ([]Option, error) {
u, err := url.Parse(dsn)
if err != nil {
return nil, err
q := queryOptions{q: u.Query()}
var opts []Option
switch u.Scheme {
case "postgres", "postgresql":
if u.Host != "" {
addr := u.Host
if !strings.Contains(addr, ":") {
addr += ":5432"
opts = append(opts, WithAddr(addr))
if len(u.Path) > 1 {
opts = append(opts, WithDatabase(u.Path[1:]))
if host := q.string("host"); host != "" {
opts = append(opts, WithAddr(host))
if host[0] == '/' {
opts = append(opts, WithNetwork("unix"))
case "unix":
if len(u.Path) == 0 {
return nil, fmt.Errorf("unix socket DSN requires a path: %s", dsn)
opts = append(opts, WithNetwork("unix"))
if u.Host != "" {
opts = append(opts, WithDatabase(u.Host))
opts = append(opts, WithAddr(u.Path))
return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
if u.User != nil {
opts = append(opts, WithUser(u.User.Username()))
if password, ok := u.User.Password(); ok {
opts = append(opts, WithPassword(password))
if appName := q.string("application_name"); appName != "" {
opts = append(opts, WithApplicationName(appName))
if sslMode, sslRootCert := q.string("sslmode"), q.string("sslrootcert"); sslMode != "" || sslRootCert != "" {
tlsConfig := &tls.Config{}
switch sslMode {
case "disable":
tlsConfig = nil
case "allow", "prefer", "":
tlsConfig.InsecureSkipVerify = true
case "require":
if sslRootCert == "" {
tlsConfig.InsecureSkipVerify = true
// For backwards compatibility reasons, in the presence of `sslrootcert`,
// `sslmode` = `require` must act as if `sslmode` = `verify-ca`. See the note at
// .
case "verify-ca":
// The default certificate verification will also verify the host name
// which is not the behavior of `verify-ca`. As such, we need to manually
// check the certificate chain.
// At the time of writing, tls.Config has no option for this behavior
// (verify chain, but skip server name).
// See .
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, 0, len(rawCerts))
for _, rawCert := range rawCerts {
cert, err := x509.ParseCertificate(rawCert)
if err != nil {
return fmt.Errorf("pgdriver: failed to parse certificate: %w", err)
certs = append(certs, cert)
intermediates := x509.NewCertPool()
for _, cert := range certs[1:] {
_, err := certs[0].Verify(x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
Intermediates: intermediates,
return err
case "verify-full":
tlsConfig.ServerName = u.Host
if host, _, err := net.SplitHostPort(u.Host); err == nil {
tlsConfig.ServerName = host
return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
if tlsConfig != nil && sslRootCert != "" {
rawCA, err := ioutil.ReadFile(sslRootCert)
if err != nil {
return nil, fmt.Errorf("pgdriver: failed to read root CA: %w", err)
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(rawCA) {
return nil, fmt.Errorf("pgdriver: failed to append root CA")
tlsConfig.RootCAs = certPool
opts = append(opts, WithTLSConfig(tlsConfig))
if d := q.duration("timeout"); d != 0 {
opts = append(opts, WithTimeout(d))
if d := q.duration("dial_timeout"); d != 0 {
opts = append(opts, WithDialTimeout(d))
if d := q.duration("connect_timeout"); d != 0 {
opts = append(opts, WithDialTimeout(d))
if d := q.duration("read_timeout"); d != 0 {
opts = append(opts, WithReadTimeout(d))
if d := q.duration("write_timeout"); d != 0 {
opts = append(opts, WithWriteTimeout(d))
rem, err := q.remaining()
if err != nil {
return nil, q.err
if len(rem) > 0 {
params := make(map[string]interface{}, len(rem))
for k, v := range rem {
params[k] = v
opts = append(opts, WithConnParams(params))
return opts, nil
// verify is a method to make sure if the config is legitimate
// in the case it detects any errors, it returns with a non-nil error
// it can be extended to check other parameters
func (c *Config) verify() error {
if c.User == "" {
return errors.New("pgdriver: User option is empty (to configure, use WithUser).")
return nil
// queryOptions reads DSN query parameters one at a time. Each parameter that
// is read is removed from q, so whatever is left afterwards is the set of
// parameters nobody asked for. The first parsing error is kept in err.
type queryOptions struct {
	q   url.Values
	err error
}

// string returns the last value given for name and removes name from the
// pending parameters. It returns "" when name is absent.
func (o *queryOptions) string(name string) string {
	vs := o.q[name]
	if len(vs) == 0 {
		return ""
	}

	delete(o.q, name) // enable detection of unknown parameters
	return vs[len(vs)-1]
}

// duration reads name as a timeout. A plain integer is a number of seconds,
// with zero or a negative number returning -1 (timeouts disabled); anything
// else is parsed with time.ParseDuration. It returns 0 when name is absent or
// malformed; in the malformed case the first such error is recorded in o.err.
func (o *queryOptions) duration(name string) time.Duration {
	s := o.string(name)
	if s == "" {
		return 0
	}

	// try plain number first
	if i, err := strconv.Atoi(s); err == nil {
		if i <= 0 {
			// disable timeouts
			return -1
		}
		return time.Duration(i) * time.Second
	}

	dur, err := time.ParseDuration(s)
	if err == nil {
		return dur
	}
	if o.err == nil {
		o.err = fmt.Errorf("pgdriver: invalid %s duration: %w", name, err)
	}
	return 0
}

// remaining returns the parameters that have not been read, mapped to their
// last value, or the recorded error if one occurred. It returns a nil map
// when nothing is left.
func (o *queryOptions) remaining() (map[string]string, error) {
	if o.err != nil {
		return nil, o.err
	}
	if len(o.q) == 0 {
		return nil, nil
	}

	m := make(map[string]string, len(o.q))
	for k, ss := range o.q {
		if len(ss) == 0 {
			// A key without values carries no setting; indexing it would panic.
			continue
		}
		m[k] = ss[len(ss)-1]
	}
	return m, nil
}