You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
go-library/vendor/github.com/uptrace/bun/driver/pgdriver/config.go

408 lines
9.3 KiB

package pgdriver
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
"strconv"
"strings"
"time"
)
// Config holds the settings used to open PostgreSQL connections.
// newDefaultConfig builds a populated default, and Option functions
// (WithAddr, WithDSN, ...) modify it.
type Config struct {
// Network type, either tcp or unix.
// Default is tcp.
Network string
// TCP host:port or Unix socket depending on Network.
Addr string
// Dial timeout for establishing new connections.
// Default is 5 seconds.
DialTimeout time.Duration
// Dialer creates new network connection and has priority over
// Network and Addr options.
// The default Dialer uses net.Dialer with DialTimeout and a 5-minute keep-alive.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
// TLS config for secure connections.
// NOTE(review): nil appears to mean "no TLS" — WithInsecure(true) and
// sslmode=disable both set it to nil; confirm against the connection code.
TLSConfig *tls.Config
// User name to authenticate as. Must be non-empty (checked by verify).
// Defaults to $PGUSER, or "postgres".
User string
// Password for User. Empty by default.
Password string
// Database to connect to. Defaults to $PGDATABASE, or "postgres".
Database string
// Application name, set by WithApplicationName or the application_name
// DSN parameter.
AppName string
// PostgreSQL session parameters updated with `SET` command when a connection is created.
ConnParams map[string]interface{}
// Timeout for socket reads. If reached, commands fail with a timeout instead of blocking.
// Default is 10 seconds.
ReadTimeout time.Duration
// Timeout for socket writes. If reached, commands fail with a timeout instead of blocking.
// Default is 5 seconds.
WriteTimeout time.Duration
}
// newDefaultConfig returns a Config seeded from the PGHOST, PGPORT, PGUSER
// and PGDATABASE environment variables. When they are unset it falls back to
// localhost:5432, user "postgres" and database "postgres".
//
// The default TLS config skips certificate verification. The default Dialer
// reads cfg.DialTimeout each time it is called, so later changes to that
// field (e.g. by WithDialTimeout) take effect.
func newDefaultConfig() *Config {
	cfg := new(Config)
	cfg.Network = "tcp"
	cfg.Addr = net.JoinHostPort(env("PGHOST", "localhost"), env("PGPORT", "5432"))
	cfg.DialTimeout = 5 * time.Second
	cfg.TLSConfig = &tls.Config{InsecureSkipVerify: true}
	cfg.User = env("PGUSER", "postgres")
	cfg.Database = env("PGDATABASE", "postgres")
	cfg.ReadTimeout = 10 * time.Second
	cfg.WriteTimeout = 5 * time.Second

	cfg.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
		d := net.Dialer{Timeout: cfg.DialTimeout, KeepAlive: 5 * time.Minute}
		return d.DialContext(ctx, network, addr)
	}
	return cfg
}
// Option modifies a Config. The With* functions in this file return Options;
// WithDSN bundles the Options derived from a connection string.
type Option func(cfg *Config)

// DriverOption is an alias of Option kept for backward compatibility.
//
// Deprecated: Use Option instead.
type DriverOption = Option
// WithNetwork sets the network type ("tcp" or "unix") used to reach the
// server. It panics immediately if network is empty.
func WithNetwork(network string) Option {
	if len(network) == 0 {
		panic("network is empty")
	}
	return func(c *Config) { c.Network = network }
}
// WithAddr sets the server address: host:port for TCP, or a socket path when
// the network is unix. It panics immediately if addr is empty.
func WithAddr(addr string) Option {
	if addr != "" {
		return func(c *Config) { c.Addr = addr }
	}
	panic("addr is empty")
}
// WithTLSConfig replaces the TLS configuration. Passing nil clears it.
func WithTLSConfig(tlsConfig *tls.Config) Option {
	apply := func(c *Config) {
		c.TLSConfig = tlsConfig
	}
	return apply
}
// WithInsecure sets the TLS config to nil when on is true. When on is false it
// installs a fresh TLS config that skips server certificate verification.
func WithInsecure(on bool) Option {
	return func(c *Config) {
		var tlsConfig *tls.Config
		if !on {
			tlsConfig = &tls.Config{InsecureSkipVerify: true}
		}
		c.TLSConfig = tlsConfig
	}
}
// WithUser sets the user name to authenticate as. It panics immediately if
// user is empty.
func WithUser(user string) Option {
	if user != "" {
		return func(c *Config) { c.User = user }
	}
	panic("user is empty")
}
// WithPassword sets the password used for authentication. An empty password
// is accepted.
func WithPassword(password string) Option {
	setPassword := func(c *Config) {
		c.Password = password
	}
	return setPassword
}
// WithDatabase sets the database name. It panics immediately if database is
// empty.
func WithDatabase(database string) Option {
	if len(database) == 0 {
		panic("database is empty")
	}
	return func(c *Config) { c.Database = database }
}
// WithApplicationName sets Config.AppName. An empty name is accepted.
func WithApplicationName(appName string) Option {
	setAppName := func(c *Config) {
		c.AppName = appName
	}
	return setAppName
}
// WithConnParams sets the session parameters applied with SET on new
// connections. The map is stored as given (not copied), replacing any
// previously configured parameters.
func WithConnParams(params map[string]interface{}) Option {
	setParams := func(c *Config) {
		c.ConnParams = params
	}
	return setParams
}
// WithTimeout sets the dial, read and write timeouts to the same value.
func WithTimeout(timeout time.Duration) Option {
	return func(c *Config) {
		c.DialTimeout, c.ReadTimeout, c.WriteTimeout = timeout, timeout, timeout
	}
}
// WithDialTimeout sets the timeout for establishing new connections.
func WithDialTimeout(dialTimeout time.Duration) Option {
	setDial := func(c *Config) {
		c.DialTimeout = dialTimeout
	}
	return setDial
}
// WithReadTimeout sets the socket read timeout.
func WithReadTimeout(readTimeout time.Duration) Option {
	setRead := func(c *Config) {
		c.ReadTimeout = readTimeout
	}
	return setRead
}
// WithWriteTimeout sets the socket write timeout.
func WithWriteTimeout(writeTimeout time.Duration) Option {
	setWrite := func(c *Config) {
		c.WriteTimeout = writeTimeout
	}
	return setWrite
}
// WithDSN applies, in order, every Option derived from a postgres://,
// postgresql:// or unix:// connection string. The DSN is parsed each time
// the returned Option is applied, and a parse error causes a panic at that
// point rather than when WithDSN is called.
func WithDSN(dsn string) Option {
	return func(c *Config) {
		opts, err := parseDSN(dsn)
		if err != nil {
			panic(err)
		}
		for i := range opts {
			opts[i](c)
		}
	}
}
// env returns the value of the environment variable key, or defValue when
// the variable is unset or set to the empty string.
func env(key, defValue string) string {
	if value, ok := os.LookupEnv(key); ok && value != "" {
		return value
	}
	return defValue
}
//------------------------------------------------------------------------------
// parseDSN converts a connection string into the Options it describes.
//
// Supported forms:
//   - postgres://user:password@host:port/database?param=value (or postgresql://).
//     A host without a port gets port 5432. A "host" query parameter
//     overrides the address; if it starts with '/' the network becomes "unix".
//   - unix://user:password@database/path/to/socket
//
// An empty user name in the DSN is ignored, so the configured default user
// is kept.
//
// Recognized query parameters are application_name, sslmode, sslrootcert,
// timeout, dial_timeout, connect_timeout, read_timeout and write_timeout.
// Every other parameter is returned as a session parameter via
// WithConnParams.
//
// sslmode mapping:
//   - disable: nil TLS config.
//   - allow, prefer, or empty: TLS without certificate verification.
//   - require: as above, or as verify-ca when sslrootcert is set.
//   - verify-ca: chain verified against sslrootcert (system roots if unset),
//     host name not checked.
//   - verify-full: standard verification including the server name.
//
// It returns an error for an unparsable URL, an unknown scheme or sslmode,
// an unreadable root CA file, or an invalid timeout value.
func parseDSN(dsn string) ([]Option, error) {
	u, err := url.Parse(dsn)
	if err != nil {
		return nil, err
	}

	q := queryOptions{q: u.Query()}
	var opts []Option

	switch u.Scheme {
	case "postgres", "postgresql":
		if u.Host != "" {
			addr := u.Host
			// Rely on the URL's own port parsing rather than looking for ':'
			// so that a bracketed IPv6 literal such as "[::1]" still gets the
			// default port.
			if u.Port() == "" {
				addr = net.JoinHostPort(u.Hostname(), "5432")
			}
			opts = append(opts, WithAddr(addr))
		}
		if len(u.Path) > 1 {
			opts = append(opts, WithDatabase(u.Path[1:]))
		}
		if host := q.string("host"); host != "" {
			opts = append(opts, WithAddr(host))
			if strings.HasPrefix(host, "/") {
				opts = append(opts, WithNetwork("unix"))
			}
		}
	case "unix":
		if u.Path == "" {
			// Include the DSN for context but never its password: this error
			// can surface in a panic (WithDSN) or in logs.
			redacted := *u
			if u.User != nil {
				if _, ok := u.User.Password(); ok {
					redacted.User = url.UserPassword(u.User.Username(), "xxxxx")
				}
			}
			return nil, fmt.Errorf("pgdriver: unix socket DSN requires a path: %s", redacted.String())
		}
		opts = append(opts, WithNetwork("unix"))
		if u.Host != "" {
			opts = append(opts, WithDatabase(u.Host))
		}
		opts = append(opts, WithAddr(u.Path))
	default:
		return nil, errors.New("pgdriver: invalid scheme: " + u.Scheme)
	}

	if u.User != nil {
		// WithUser panics on an empty name; "postgres://:pw@host" or
		// "postgres://@host" must not crash the parser, so keep the default
		// user in that case.
		if user := u.User.Username(); user != "" {
			opts = append(opts, WithUser(user))
		}
		if password, ok := u.User.Password(); ok {
			opts = append(opts, WithPassword(password))
		}
	}

	if appName := q.string("application_name"); appName != "" {
		opts = append(opts, WithApplicationName(appName))
	}

	if sslMode, sslRootCert := q.string("sslmode"), q.string("sslrootcert"); sslMode != "" || sslRootCert != "" {
		tlsConfig := &tls.Config{}
		switch sslMode {
		case "disable":
			tlsConfig = nil
		case "allow", "prefer", "":
			tlsConfig.InsecureSkipVerify = true
		case "require":
			if sslRootCert == "" {
				tlsConfig.InsecureSkipVerify = true
				break
			}
			// For backwards compatibility reasons, in the presence of `sslrootcert`,
			// `sslmode` = `require` must act as if `sslmode` = `verify-ca`. See the note at
			// https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES .
			fallthrough
		case "verify-ca":
			// The default certificate verification also checks the host name,
			// which verify-ca must not do. tls.Config has no "verify chain
			// but skip server name" option
			// (https://github.com/golang/go/issues/21971), so disable the
			// built-in check and verify the chain manually.
			tlsConfig.InsecureSkipVerify = true
			tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
				if len(rawCerts) == 0 {
					return errors.New("pgdriver: server presented no certificate")
				}
				certs := make([]*x509.Certificate, 0, len(rawCerts))
				for _, rawCert := range rawCerts {
					cert, err := x509.ParseCertificate(rawCert)
					if err != nil {
						return fmt.Errorf("pgdriver: failed to parse certificate: %w", err)
					}
					certs = append(certs, cert)
				}
				intermediates := x509.NewCertPool()
				for _, cert := range certs[1:] {
					intermediates.AddCert(cert)
				}
				// RootCAs is read at handshake time, after it has been filled
				// from sslrootcert below.
				_, err := certs[0].Verify(x509.VerifyOptions{
					Roots:         tlsConfig.RootCAs,
					Intermediates: intermediates,
				})
				return err
			}
		case "verify-full":
			// Hostname strips both the port and IPv6 brackets ("[::1]" -> "::1").
			tlsConfig.ServerName = u.Hostname()
		default:
			return nil, fmt.Errorf("pgdriver: sslmode '%s' is not supported", sslMode)
		}

		if tlsConfig != nil && sslRootCert != "" {
			rawCA, err := ioutil.ReadFile(sslRootCert)
			if err != nil {
				return nil, fmt.Errorf("pgdriver: failed to read root CA: %w", err)
			}
			certPool := x509.NewCertPool()
			if !certPool.AppendCertsFromPEM(rawCA) {
				return nil, errors.New("pgdriver: failed to append root CA")
			}
			tlsConfig.RootCAs = certPool
		}
		opts = append(opts, WithTLSConfig(tlsConfig))
	}

	// A zero result from q.duration means "not given"; -1 means "disabled".
	if d := q.duration("timeout"); d != 0 {
		opts = append(opts, WithTimeout(d))
	}
	if d := q.duration("dial_timeout"); d != 0 {
		opts = append(opts, WithDialTimeout(d))
	}
	if d := q.duration("connect_timeout"); d != 0 {
		opts = append(opts, WithDialTimeout(d))
	}
	if d := q.duration("read_timeout"); d != 0 {
		opts = append(opts, WithReadTimeout(d))
	}
	if d := q.duration("write_timeout"); d != 0 {
		opts = append(opts, WithWriteTimeout(d))
	}

	// Everything not consumed above becomes a session parameter.
	rem, err := q.remaining()
	if err != nil {
		return nil, err
	}
	if len(rem) > 0 {
		params := make(map[string]interface{}, len(rem))
		for k, v := range rem {
			params[k] = v
		}
		opts = append(opts, WithConnParams(params))
	}

	return opts, nil
}
// verify reports whether the config is usable, returning a non-nil error
// that describes the first problem found. Currently the only requirement is
// a non-empty User; further checks can be added here.
func (c *Config) verify() error {
	if c.User != "" {
		return nil
	}
	return errors.New("pgdriver: User option is empty (to configure, use WithUser).")
}
// queryOptions wraps DSN query parameters. Each accessor removes the
// parameter it reads, so remaining() yields only unconsumed parameters. The
// first parse error is kept in err and reported by remaining().
type queryOptions struct {
	q   url.Values
	err error
}

// string returns the last value of the named parameter, or "" when it is
// absent, and removes the parameter from o.q.
func (o *queryOptions) string(name string) string {
	vs := o.q[name]
	if len(vs) == 0 {
		return ""
	}
	delete(o.q, name) // enable detection of unknown parameters
	return vs[len(vs)-1]
}

// duration parses the named parameter as a timeout. A plain integer is a
// number of seconds; anything else must be accepted by time.ParseDuration.
//
// It returns:
//   - 0 when the parameter is absent or invalid (the error is recorded in o.err).
//   - -1 when the value is zero or negative, meaning "disable the timeout".
//   - The positive duration otherwise.
func (o *queryOptions) duration(name string) time.Duration {
	s := o.string(name)
	if s == "" {
		return 0
	}

	// try plain number first
	if i, err := strconv.Atoi(s); err == nil {
		if i <= 0 {
			// disable timeouts
			return -1
		}
		return time.Duration(i) * time.Second
	}

	dur, err := time.ParseDuration(s)
	if err != nil {
		if o.err == nil {
			o.err = fmt.Errorf("pgdriver: invalid %s duration: %w", name, err)
		}
		return 0
	}
	if dur <= 0 {
		// Treat "0s"/"-5s" like a non-positive integer. Returning 0 here
		// would be indistinguishable from "parameter absent" and silently
		// keep the default timeout.
		return -1
	}
	return dur
}

// remaining returns the recorded parse error, if any. Otherwise it returns
// the last value of every parameter not yet consumed, or nil when none are
// left.
func (o *queryOptions) remaining() (map[string]string, error) {
	if o.err != nil {
		return nil, o.err
	}
	if len(o.q) == 0 {
		return nil, nil
	}
	m := make(map[string]string, len(o.q))
	for k, ss := range o.q {
		m[k] = ss[len(ss)-1]
	}
	return m, nil
}