parent
f3904a1605
commit
bcb246b16f
@ -0,0 +1,9 @@
|
|||||||
|
package dorm
|
||||||
|
|
||||||
|
import "github.com/upper/db/v4"
|
||||||
|
|
||||||
|
// UpperClient
|
||||||
|
// https://upper.io/
|
||||||
|
type UpperClient struct {
|
||||||
|
Db *db.Session // 驱动
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
package dorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/upper/db/v4/adapter/mysql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewUpperMysqlClient(settings mysql.ConnectionURL) (*UpperClient, error) {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c := &UpperClient{}
|
||||||
|
|
||||||
|
sess, err := mysql.Open(settings)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
|
||||||
|
}
|
||||||
|
defer sess.Close()
|
||||||
|
|
||||||
|
c.Db = &sess
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
@ -0,0 +1,23 @@
|
|||||||
|
package dorm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/upper/db/v4/adapter/postgresql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewUpperPostgresqlClient(settings postgresql.ConnectionURL) (*UpperClient, error) {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
c := &UpperClient{}
|
||||||
|
|
||||||
|
sess, err := postgresql.Open(settings)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
|
||||||
|
}
|
||||||
|
defer sess.Close()
|
||||||
|
|
||||||
|
c.Db = &sess
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
@ -0,0 +1,9 @@
|
|||||||
|
language: go
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.x
|
||||||
|
- tip
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
allow_failures:
|
||||||
|
- go: tip
|
@ -0,0 +1,22 @@
|
|||||||
|
Copyright (c) 2019 Jack Christensen
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,8 @@
|
|||||||
|
[![](https://godoc.org/github.com/jackc/chunkreader?status.svg)](https://godoc.org/github.com/jackc/chunkreader)
|
||||||
|
[![Build Status](https://travis-ci.org/jackc/chunkreader.svg)](https://travis-ci.org/jackc/chunkreader)
|
||||||
|
|
||||||
|
# chunkreader
|
||||||
|
|
||||||
|
Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
|
||||||
|
|
||||||
|
Extracted from original implementation in https://github.com/jackc/pgx.
|
@ -0,0 +1,104 @@
|
|||||||
|
// Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
|
||||||
|
package chunkreader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
|
||||||
|
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
|
||||||
|
// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy.
|
||||||
|
//
|
||||||
|
// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is
|
||||||
|
// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare
|
||||||
|
// cases it would be advantageous to copy the bytes to another slice.
|
||||||
|
type ChunkReader struct {
|
||||||
|
r io.Reader
|
||||||
|
|
||||||
|
buf []byte
|
||||||
|
rp, wp int // buf read position and write position
|
||||||
|
|
||||||
|
config Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config contains configuration parameters for ChunkReader.
|
||||||
|
type Config struct {
|
||||||
|
MinBufLen int // Minimum buffer length
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates and returns a new ChunkReader for r with default configuration.
|
||||||
|
func New(r io.Reader) *ChunkReader {
|
||||||
|
cr, err := NewConfig(r, Config{})
|
||||||
|
if err != nil {
|
||||||
|
panic("default config can't be bad")
|
||||||
|
}
|
||||||
|
|
||||||
|
return cr
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfig creates and a new ChunkReader for r configured by config.
|
||||||
|
func NewConfig(r io.Reader, config Config) (*ChunkReader, error) {
|
||||||
|
if config.MinBufLen == 0 {
|
||||||
|
// By historical reasons Postgres currently has 8KB send buffer inside,
|
||||||
|
// so here we want to have at least the same size buffer.
|
||||||
|
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
|
||||||
|
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
|
||||||
|
config.MinBufLen = 8192
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ChunkReader{
|
||||||
|
r: r,
|
||||||
|
buf: make([]byte, config.MinBufLen),
|
||||||
|
config: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy
|
||||||
|
// of buf. If an error occurs, buf will be nil.
|
||||||
|
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||||
|
// n bytes already in buf
|
||||||
|
if (r.wp - r.rp) >= n {
|
||||||
|
buf = r.buf[r.rp : r.rp+n]
|
||||||
|
r.rp += n
|
||||||
|
return buf, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// available space in buf is less than n
|
||||||
|
if len(r.buf) < n {
|
||||||
|
r.copyBufContents(r.newBuf(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||||
|
minReadCount := n - (r.wp - r.rp)
|
||||||
|
if (len(r.buf) - r.wp) < minReadCount {
|
||||||
|
newBuf := r.newBuf(n)
|
||||||
|
r.copyBufContents(newBuf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.appendAtLeast(minReadCount); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = r.buf[r.rp : r.rp+n]
|
||||||
|
r.rp += n
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ChunkReader) appendAtLeast(fillLen int) error {
|
||||||
|
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
|
||||||
|
r.wp += n
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ChunkReader) newBuf(size int) []byte {
|
||||||
|
if size < r.config.MinBufLen {
|
||||||
|
size = r.config.MinBufLen
|
||||||
|
}
|
||||||
|
return make([]byte, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ChunkReader) copyBufContents(dest []byte) {
|
||||||
|
r.wp = copy(dest, r.buf[r.rp:r.wp])
|
||||||
|
r.rp = 0
|
||||||
|
r.buf = dest
|
||||||
|
}
|
@ -0,0 +1,3 @@
|
|||||||
|
.envrc
|
||||||
|
vendor/
|
||||||
|
.vscode
|
@ -0,0 +1,161 @@
|
|||||||
|
# 1.14.0 (February 11, 2023)
|
||||||
|
|
||||||
|
* Fix: each connection attempt to new node gets own timeout (Nathan Giardina)
|
||||||
|
* Set SNI for SSL connections (Stas Kelvich)
|
||||||
|
* Fix: CopyFrom I/O race (Tommy Reilly)
|
||||||
|
* Minor dependency upgrades
|
||||||
|
|
||||||
|
# 1.13.0 (August 6, 2022)
|
||||||
|
|
||||||
|
* Add sslpassword support (Eric McCormack and yun.xu)
|
||||||
|
* Add prefer-standby target_session_attrs support (sergey.bashilov)
|
||||||
|
* Fix GSS ErrorResponse handling (Oliver Tan)
|
||||||
|
|
||||||
|
# 1.12.1 (May 7, 2022)
|
||||||
|
|
||||||
|
* Fix: setting krbspn and krbsrvname in connection string (sireax)
|
||||||
|
* Add support for Unix sockets on Windows (Eno Compton)
|
||||||
|
* Stop ignoring ErrorResponse during SCRAM auth (Rafi Shamim)
|
||||||
|
|
||||||
|
# 1.12.0 (April 21, 2022)
|
||||||
|
|
||||||
|
* Add pluggable GSSAPI support (Oliver Tan)
|
||||||
|
* Fix: Consider any "0A000" error a possible cached plan changed error due to locale
|
||||||
|
* Better match psql fallback behavior with multiple hosts
|
||||||
|
|
||||||
|
# 1.11.0 (February 7, 2022)
|
||||||
|
|
||||||
|
* Support port in ip from LookupFunc to override config (James Hartig)
|
||||||
|
* Fix TLS connection timeout (Blake Embrey)
|
||||||
|
* Add support for read-only, primary, standby, prefer-standby target_session_attributes (Oscar)
|
||||||
|
* Fix connect when receiving NoticeResponse
|
||||||
|
|
||||||
|
# 1.10.1 (November 20, 2021)
|
||||||
|
|
||||||
|
* Close without waiting for response (Kei Kamikawa)
|
||||||
|
* Save waiting for network round-trip in CopyFrom (Rueian)
|
||||||
|
* Fix concurrency issue with ContextWatcher
|
||||||
|
* LRU.Get always checks context for cancellation / expiration (Georges Varouchas)
|
||||||
|
|
||||||
|
# 1.10.0 (July 24, 2021)
|
||||||
|
|
||||||
|
* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned.
|
||||||
|
|
||||||
|
# 1.9.0 (July 10, 2021)
|
||||||
|
|
||||||
|
* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr)
|
||||||
|
* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle)
|
||||||
|
* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard)
|
||||||
|
* Fix default host when parsing URL without host but with port
|
||||||
|
* Allow dbname query parameter in URL conn string
|
||||||
|
* Update underlying dependencies
|
||||||
|
|
||||||
|
# 1.8.1 (March 25, 2021)
|
||||||
|
|
||||||
|
* Better connection string sanitization (ip.novikov)
|
||||||
|
* Use proper pgpass location on Windows (Moshe Katz)
|
||||||
|
* Use errors instead of golang.org/x/xerrors
|
||||||
|
* Resume fallback on server error in Connect (Andrey Borodin)
|
||||||
|
|
||||||
|
# 1.8.0 (December 3, 2020)
|
||||||
|
|
||||||
|
* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes)
|
||||||
|
|
||||||
|
# 1.7.2 (November 3, 2020)
|
||||||
|
|
||||||
|
* Fix data value slices into work buffer with capacities larger than length.
|
||||||
|
|
||||||
|
# 1.7.1 (October 31, 2020)
|
||||||
|
|
||||||
|
* Do not asyncClose after receiving FATAL error from PostgreSQL server
|
||||||
|
|
||||||
|
# 1.7.0 (September 26, 2020)
|
||||||
|
|
||||||
|
* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
|
||||||
|
* Add ReceiveResults (Sebastiaan Mannem)
|
||||||
|
* Fix parsing DSN connection with bad backslash
|
||||||
|
* Add PgConn.CleanupDone so connection pools can determine when async close is complete
|
||||||
|
|
||||||
|
# 1.6.4 (July 29, 2020)
|
||||||
|
|
||||||
|
* Fix deadlock on error after CommandComplete but before ReadyForQuery
|
||||||
|
* Fix panic on parsing DSN with trailing '='
|
||||||
|
|
||||||
|
# 1.6.3 (July 22, 2020)
|
||||||
|
|
||||||
|
* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo)
|
||||||
|
|
||||||
|
# 1.6.2 (July 14, 2020)
|
||||||
|
|
||||||
|
* Update pgservicefile library
|
||||||
|
|
||||||
|
# 1.6.1 (June 27, 2020)
|
||||||
|
|
||||||
|
* Update golang.org/x/crypto to latest
|
||||||
|
* Update golang.org/x/text to 0.3.3
|
||||||
|
* Fix error handling for bad PGSERVICE definition
|
||||||
|
* Redact passwords in ParseConfig errors (Lukas Vogel)
|
||||||
|
|
||||||
|
# 1.6.0 (June 6, 2020)
|
||||||
|
|
||||||
|
* Fix panic when closing conn during cancellable query
|
||||||
|
* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný)
|
||||||
|
* Fix field descriptions available after command concluded (Tobias Salzmann)
|
||||||
|
* Support connect_timeout (georgysavva)
|
||||||
|
* Handle IPv6 in connection URLs (Lukas Vogel)
|
||||||
|
* Fix ValidateConnect with cancelable context
|
||||||
|
* Improve CopyFrom performance
|
||||||
|
* Add Config.Copy (georgysavva)
|
||||||
|
|
||||||
|
# 1.5.0 (March 30, 2020)
|
||||||
|
|
||||||
|
* Update golang.org/x/crypto for security fix
|
||||||
|
* Implement "verify-ca" SSL mode (Greg Curtis)
|
||||||
|
|
||||||
|
# 1.4.0 (March 7, 2020)
|
||||||
|
|
||||||
|
* Fix ExecParams and ExecPrepared handling of empty query.
|
||||||
|
* Support reading config from PostgreSQL service files.
|
||||||
|
|
||||||
|
# 1.3.2 (February 14, 2020)
|
||||||
|
|
||||||
|
* Update chunkreader to v2.0.1 for optimized default buffer size.
|
||||||
|
|
||||||
|
# 1.3.1 (February 5, 2020)
|
||||||
|
|
||||||
|
* Fix CopyFrom deadlock when multiple NoticeResponse received during copy
|
||||||
|
|
||||||
|
# 1.3.0 (January 23, 2020)
|
||||||
|
|
||||||
|
* Add Hijack and Construct.
|
||||||
|
* Update pgproto3 to v2.0.1.
|
||||||
|
|
||||||
|
# 1.2.1 (January 13, 2020)
|
||||||
|
|
||||||
|
* Fix data race in context cancellation introduced in v1.2.0.
|
||||||
|
|
||||||
|
# 1.2.0 (January 11, 2020)
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag.
|
||||||
|
* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases.
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
* Improve performance when context.Background() is used. (bakape)
|
||||||
|
* CommandTag.RowsAffected is faster and does not allocate.
|
||||||
|
|
||||||
|
## Fixes
|
||||||
|
|
||||||
|
* Try to cancel any in-progress query when a conn is closed by ctx cancel.
|
||||||
|
* Handle NoticeResponse during CopyFrom.
|
||||||
|
* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish.
|
||||||
|
|
||||||
|
# 1.1.0 (October 12, 2019)
|
||||||
|
|
||||||
|
* Add PgConn.IsBusy() method.
|
||||||
|
|
||||||
|
# 1.0.1 (September 19, 2019)
|
||||||
|
|
||||||
|
* Fix statement cache not properly cleaning discarded statements.
|
@ -0,0 +1,22 @@
|
|||||||
|
Copyright (c) 2019-2021 Jack Christensen
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,62 @@
|
|||||||
|
[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn)
|
||||||
|
![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# pgconn
|
||||||
|
|
||||||
|
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
|
||||||
|
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
|
||||||
|
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
|
||||||
|
low-level access to PostgreSQL functionality.
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("pgconn failed to connect:", err)
|
||||||
|
}
|
||||||
|
defer pgConn.Close(context.Background())
|
||||||
|
|
||||||
|
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
|
||||||
|
for result.NextRow() {
|
||||||
|
fmt.Println("User 123 has email:", string(result.Values()[0]))
|
||||||
|
}
|
||||||
|
_, err = result.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("failed reading result:", err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
|
||||||
|
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
|
||||||
|
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
|
||||||
|
environment variable handling.
|
||||||
|
|
||||||
|
### Example Test Environment
|
||||||
|
|
||||||
|
Connect to your PostgreSQL server and run:
|
||||||
|
|
||||||
|
```
|
||||||
|
create database pgx_test;
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can run the tests:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection and Authentication Tests
|
||||||
|
|
||||||
|
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
|
||||||
|
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
|
||||||
|
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
|
||||||
|
authentication code.
|
@ -0,0 +1,270 @@
|
|||||||
|
// SCRAM-SHA-256 authentication
|
||||||
|
//
|
||||||
|
// Resources:
|
||||||
|
// https://tools.ietf.org/html/rfc5802
|
||||||
|
// https://tools.ietf.org/html/rfc8265
|
||||||
|
// https://www.postgresql.org/docs/current/sasl-authentication.html
|
||||||
|
//
|
||||||
|
// Inspiration drawn from other implementations:
|
||||||
|
// https://github.com/lib/pq/pull/608
|
||||||
|
// https://github.com/lib/pq/pull/788
|
||||||
|
// https://github.com/lib/pq/pull/833
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
"golang.org/x/crypto/pbkdf2"
|
||||||
|
"golang.org/x/text/secure/precis"
|
||||||
|
)
|
||||||
|
|
||||||
|
const clientNonceLen = 18
|
||||||
|
|
||||||
|
// Perform SCRAM authentication.
|
||||||
|
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||||
|
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send client-first-message in a SASLInitialResponse
|
||||||
|
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||||
|
AuthMechanism: "SCRAM-SHA-256",
|
||||||
|
Data: sc.clientFirstMessage(),
|
||||||
|
}
|
||||||
|
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||||
|
saslContinue, err := c.rxSASLContinue()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send client-final-message in a SASLResponse
|
||||||
|
saslResponse := &pgproto3.SASLResponse{
|
||||||
|
Data: []byte(sc.clientFinalMessage()),
|
||||||
|
}
|
||||||
|
_, err = c.conn.Write(saslResponse.Encode(nil))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||||
|
saslFinal, err := c.rxSASLFinal()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch m := msg.(type) {
|
||||||
|
case *pgproto3.AuthenticationSASLContinue:
|
||||||
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
switch m := msg.(type) {
|
||||||
|
case *pgproto3.AuthenticationSASLFinal:
|
||||||
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
type scramClient struct {
|
||||||
|
serverAuthMechanisms []string
|
||||||
|
password []byte
|
||||||
|
clientNonce []byte
|
||||||
|
|
||||||
|
clientFirstMessageBare []byte
|
||||||
|
|
||||||
|
serverFirstMessage []byte
|
||||||
|
clientAndServerNonce []byte
|
||||||
|
salt []byte
|
||||||
|
iterations int
|
||||||
|
|
||||||
|
saltedPassword []byte
|
||||||
|
authMessage []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
|
||||||
|
sc := &scramClient{
|
||||||
|
serverAuthMechanisms: serverAuthMechanisms,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure server supports SCRAM-SHA-256
|
||||||
|
hasScramSHA256 := false
|
||||||
|
for _, mech := range sc.serverAuthMechanisms {
|
||||||
|
if mech == "SCRAM-SHA-256" {
|
||||||
|
hasScramSHA256 = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasScramSHA256 {
|
||||||
|
return nil, errors.New("server does not support SCRAM-SHA-256")
|
||||||
|
}
|
||||||
|
|
||||||
|
// precis.OpaqueString is equivalent to SASLprep for password.
|
||||||
|
var err error
|
||||||
|
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
|
||||||
|
if err != nil {
|
||||||
|
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
|
||||||
|
sc.password = []byte(password)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, clientNonceLen)
|
||||||
|
_, err = rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
|
||||||
|
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
|
||||||
|
|
||||||
|
return sc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) clientFirstMessage() []byte {
|
||||||
|
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
|
||||||
|
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||||
|
sc.serverFirstMessage = serverFirstMessage
|
||||||
|
buf := serverFirstMessage
|
||||||
|
if !bytes.HasPrefix(buf, []byte("r=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(buf, ',')
|
||||||
|
if idx == -1 {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||||
|
}
|
||||||
|
sc.clientAndServerNonce = buf[:idx]
|
||||||
|
buf = buf[idx+1:]
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(buf, []byte("s=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(buf, ',')
|
||||||
|
if idx == -1 {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||||
|
}
|
||||||
|
saltStr := buf[:idx]
|
||||||
|
buf = buf[idx+1:]
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(buf, []byte("i=")) {
|
||||||
|
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||||
|
}
|
||||||
|
buf = buf[2:]
|
||||||
|
iterationsStr := buf
|
||||||
|
|
||||||
|
var err error
|
||||||
|
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sc.iterations, err = strconv.Atoi(string(iterationsStr))
|
||||||
|
if err != nil || sc.iterations <= 0 {
|
||||||
|
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
|
||||||
|
return errors.New("invalid SCRAM nonce: did not start with client nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
|
||||||
|
return errors.New("invalid SCRAM nonce: did not include server nonce")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) clientFinalMessage() string {
|
||||||
|
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
|
||||||
|
|
||||||
|
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
|
||||||
|
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
|
||||||
|
|
||||||
|
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
|
||||||
|
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
|
||||||
|
return errors.New("invalid SCRAM server-final-message received from server")
|
||||||
|
}
|
||||||
|
|
||||||
|
serverSignature := serverFinalMessage[2:]
|
||||||
|
|
||||||
|
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
|
||||||
|
return errors.New("invalid SCRAM ServerSignature received from server")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeHMAC(key, msg []byte) []byte {
|
||||||
|
mac := hmac.New(sha256.New, key)
|
||||||
|
mac.Write(msg)
|
||||||
|
return mac.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||||
|
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||||
|
storedKey := sha256.Sum256(clientKey)
|
||||||
|
clientSignature := computeHMAC(storedKey[:], authMessage)
|
||||||
|
|
||||||
|
clientProof := make([]byte, len(clientSignature))
|
||||||
|
for i := 0; i < len(clientSignature); i++ {
|
||||||
|
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
|
||||||
|
base64.StdEncoding.Encode(buf, clientProof)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||||
|
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||||
|
serverSignature := computeHMAC(serverKey, authMessage)
|
||||||
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||||
|
base64.StdEncoding.Encode(buf, serverSignature)
|
||||||
|
return buf
|
||||||
|
}
|
@ -0,0 +1,905 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/chunkreader/v2"
|
||||||
|
"github.com/jackc/pgpassfile"
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
"github.com/jackc/pgservicefile"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
type GetSSLPasswordFunc func(ctx context.Context) string
|
||||||
|
|
||||||
|
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
|
||||||
|
// manually initialized Config will cause ConnectConfig to panic.
|
||||||
|
type Config struct {
|
||||||
|
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
Database string
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||||
|
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||||
|
BuildFrontend BuildFrontendFunc
|
||||||
|
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||||
|
|
||||||
|
KerberosSrvName string
|
||||||
|
KerberosSpn string
|
||||||
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
|
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||||
|
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||||
|
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||||
|
ValidateConnect ValidateConnectFunc
|
||||||
|
|
||||||
|
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||||
|
// or prepare statements). If this returns an error the connection attempt fails.
|
||||||
|
AfterConnect AfterConnectFunc
|
||||||
|
|
||||||
|
// OnNotice is a callback function called when a notice response is received.
|
||||||
|
OnNotice NoticeHandler
|
||||||
|
|
||||||
|
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||||
|
OnNotification NotificationHandler
|
||||||
|
|
||||||
|
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
|
||||||
|
type ParseConfigOptions struct {
|
||||||
|
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
|
||||||
|
// PQsetSSLKeyPassHook_OpenSSL.
|
||||||
|
GetSSLPassword GetSSLPasswordFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||||
|
// The only exception is the TLSConfig field:
|
||||||
|
// according to the tls.Config docs it must not be modified after creation.
|
||||||
|
func (c *Config) Copy() *Config {
|
||||||
|
newConf := new(Config)
|
||||||
|
*newConf = *c
|
||||||
|
if newConf.TLSConfig != nil {
|
||||||
|
newConf.TLSConfig = c.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
if newConf.RuntimeParams != nil {
|
||||||
|
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||||
|
for k, v := range c.RuntimeParams {
|
||||||
|
newConf.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if newConf.Fallbacks != nil {
|
||||||
|
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||||
|
for i, fallback := range c.Fallbacks {
|
||||||
|
newFallback := new(FallbackConfig)
|
||||||
|
*newFallback = *fallback
|
||||||
|
if newFallback.TLSConfig != nil {
|
||||||
|
newFallback.TLSConfig = fallback.TLSConfig.Clone()
|
||||||
|
}
|
||||||
|
newConf.Fallbacks[i] = newFallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newConf
|
||||||
|
}
|
||||||
|
|
||||||
|
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||||
|
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||||
|
type FallbackConfig struct {
|
||||||
|
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||||
|
Port uint16
|
||||||
|
TLSConfig *tls.Config // nil disables TLS
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAbsolutePath checks if the provided value is an absolute path either
|
||||||
|
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||||
|
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||||
|
func isAbsolutePath(path string) bool {
|
||||||
|
isWindowsPath := func(p string) bool {
|
||||||
|
if len(p) < 3 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
drive := p[0]
|
||||||
|
colon := p[1]
|
||||||
|
backslash := p[2]
|
||||||
|
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.HasPrefix(path, "/") || isWindowsPath(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||||
|
// net.Dial.
|
||||||
|
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||||
|
if isAbsolutePath(host) {
|
||||||
|
network = "unix"
|
||||||
|
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||||
|
} else {
|
||||||
|
network = "tcp"
|
||||||
|
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||||
|
}
|
||||||
|
return network, address
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
|
||||||
|
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
|
||||||
|
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
|
||||||
|
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
|
||||||
|
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||||
|
//
|
||||||
|
// # Example DSN
|
||||||
|
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||||
|
//
|
||||||
|
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
|
||||||
|
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
|
||||||
|
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
|
||||||
|
// not be modified individually. They should all be modified or all left unchanged.
|
||||||
|
//
|
||||||
|
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||||
|
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||||
|
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||||
|
//
|
||||||
|
// # Example URL
|
||||||
|
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||||
|
//
|
||||||
|
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||||
|
// via database URL or DSN:
|
||||||
|
//
|
||||||
|
// PGHOST
|
||||||
|
// PGPORT
|
||||||
|
// PGDATABASE
|
||||||
|
// PGUSER
|
||||||
|
// PGPASSWORD
|
||||||
|
// PGPASSFILE
|
||||||
|
// PGSERVICE
|
||||||
|
// PGSERVICEFILE
|
||||||
|
// PGSSLMODE
|
||||||
|
// PGSSLCERT
|
||||||
|
// PGSSLKEY
|
||||||
|
// PGSSLROOTCERT
|
||||||
|
// PGSSLPASSWORD
|
||||||
|
// PGAPPNAME
|
||||||
|
// PGCONNECT_TIMEOUT
|
||||||
|
// PGTARGETSESSIONATTRS
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||||
|
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||||
|
//
|
||||||
|
// Important Security Notes:
|
||||||
|
//
|
||||||
|
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||||
|
// not set.
|
||||||
|
//
|
||||||
|
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||||
|
// security each sslmode provides.
|
||||||
|
//
|
||||||
|
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
||||||
|
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
|
||||||
|
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
|
||||||
|
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
||||||
|
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
||||||
|
// TLSConfig.
|
||||||
|
//
|
||||||
|
// Other known differences with libpq:
|
||||||
|
//
|
||||||
|
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||||
|
// does not.
|
||||||
|
//
|
||||||
|
// In addition, ParseConfig accepts the following options:
|
||||||
|
//
|
||||||
|
// min_read_buffer_size
|
||||||
|
// The minimum size of the internal read buffer. Default 8192.
|
||||||
|
// servicefile
|
||||||
|
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||||
|
// part of the connection string.
|
||||||
|
func ParseConfig(connString string) (*Config, error) {
|
||||||
|
var parseConfigOptions ParseConfigOptions
|
||||||
|
return ParseConfigWithOptions(connString, parseConfigOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
|
||||||
|
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
|
||||||
|
// get the SSL password.
|
||||||
|
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
|
||||||
|
defaultSettings := defaultSettings()
|
||||||
|
envSettings := parseEnvSettings()
|
||||||
|
|
||||||
|
connStringSettings := make(map[string]string)
|
||||||
|
if connString != "" {
|
||||||
|
var err error
|
||||||
|
// connString may be a database URL or a DSN
|
||||||
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
|
connStringSettings, err = parseURLSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
connStringSettings, err = parseDSNSettings(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
|
||||||
|
if service, present := settings["service"]; present {
|
||||||
|
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||||
|
}
|
||||||
|
|
||||||
|
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
createdByParseConfig: true,
|
||||||
|
Database: settings["database"],
|
||||||
|
User: settings["user"],
|
||||||
|
Password: settings["password"],
|
||||||
|
RuntimeParams: make(map[string]string),
|
||||||
|
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||||
|
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||||
|
}
|
||||||
|
config.ConnectTimeout = connectTimeout
|
||||||
|
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||||
|
} else {
|
||||||
|
defaultDialer := makeDefaultDialer()
|
||||||
|
config.DialFunc = defaultDialer.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||||
|
|
||||||
|
notRuntimeParams := map[string]struct{}{
|
||||||
|
"host": {},
|
||||||
|
"port": {},
|
||||||
|
"database": {},
|
||||||
|
"user": {},
|
||||||
|
"password": {},
|
||||||
|
"passfile": {},
|
||||||
|
"connect_timeout": {},
|
||||||
|
"sslmode": {},
|
||||||
|
"sslkey": {},
|
||||||
|
"sslcert": {},
|
||||||
|
"sslrootcert": {},
|
||||||
|
"sslpassword": {},
|
||||||
|
"sslsni": {},
|
||||||
|
"krbspn": {},
|
||||||
|
"krbsrvname": {},
|
||||||
|
"target_session_attrs": {},
|
||||||
|
"min_read_buffer_size": {},
|
||||||
|
"service": {},
|
||||||
|
"servicefile": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding kerberos configuration
|
||||||
|
if _, present := settings["krbsrvname"]; present {
|
||||||
|
config.KerberosSrvName = settings["krbsrvname"]
|
||||||
|
}
|
||||||
|
if _, present := settings["krbspn"]; present {
|
||||||
|
config.KerberosSpn = settings["krbspn"]
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range settings {
|
||||||
|
if _, present := notRuntimeParams[k]; present {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
config.RuntimeParams[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
fallbacks := []*FallbackConfig{}
|
||||||
|
|
||||||
|
hosts := strings.Split(settings["host"], ",")
|
||||||
|
ports := strings.Split(settings["port"], ",")
|
||||||
|
|
||||||
|
for i, host := range hosts {
|
||||||
|
var portStr string
|
||||||
|
if i < len(ports) {
|
||||||
|
portStr = ports[i]
|
||||||
|
} else {
|
||||||
|
portStr = ports[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := parsePort(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tlsConfigs []*tls.Config
|
||||||
|
|
||||||
|
// Ignore TLS settings if Unix domain socket like libpq
|
||||||
|
if network, _ := NetworkAddress(host, port); network == "unix" {
|
||||||
|
tlsConfigs = append(tlsConfigs, nil)
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
tlsConfigs, err = configTLS(settings, host, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tlsConfig := range tlsConfigs {
|
||||||
|
fallbacks = append(fallbacks, &FallbackConfig{
|
||||||
|
Host: host,
|
||||||
|
Port: port,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Host = fallbacks[0].Host
|
||||||
|
config.Port = fallbacks[0].Port
|
||||||
|
config.TLSConfig = fallbacks[0].TLSConfig
|
||||||
|
config.Fallbacks = fallbacks[1:]
|
||||||
|
|
||||||
|
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||||
|
if err == nil {
|
||||||
|
if config.Password == "" {
|
||||||
|
host := config.Host
|
||||||
|
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||||
|
host = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tsa := settings["target_session_attrs"]; tsa {
|
||||||
|
case "read-write":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||||
|
case "read-only":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
|
||||||
|
case "primary":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||||
|
case "standby":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||||
|
case "prefer-standby":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
|
||||||
|
case "any":
|
||||||
|
// do nothing
|
||||||
|
default:
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
for _, s2 := range settingSets {
|
||||||
|
for k, v := range s2 {
|
||||||
|
settings[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEnvSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"PGHOST": "host",
|
||||||
|
"PGPORT": "port",
|
||||||
|
"PGDATABASE": "database",
|
||||||
|
"PGUSER": "user",
|
||||||
|
"PGPASSWORD": "password",
|
||||||
|
"PGPASSFILE": "passfile",
|
||||||
|
"PGAPPNAME": "application_name",
|
||||||
|
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||||
|
"PGSSLMODE": "sslmode",
|
||||||
|
"PGSSLKEY": "sslkey",
|
||||||
|
"PGSSLCERT": "sslcert",
|
||||||
|
"PGSSLSNI": "sslsni",
|
||||||
|
"PGSSLROOTCERT": "sslrootcert",
|
||||||
|
"PGSSLPASSWORD": "sslpassword",
|
||||||
|
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||||
|
"PGSERVICE": "service",
|
||||||
|
"PGSERVICEFILE": "servicefile",
|
||||||
|
}
|
||||||
|
|
||||||
|
for envname, realname := range nameMap {
|
||||||
|
value := os.Getenv(envname)
|
||||||
|
if value != "" {
|
||||||
|
settings[realname] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseURLSettings(connString string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
url, err := url.Parse(connString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if url.User != nil {
|
||||||
|
settings["user"] = url.User.Username()
|
||||||
|
if password, present := url.User.Password(); present {
|
||||||
|
settings["password"] = password
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||||
|
var hosts []string
|
||||||
|
var ports []string
|
||||||
|
for _, host := range strings.Split(url.Host, ",") {
|
||||||
|
if host == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isIPOnly(host) {
|
||||||
|
hosts = append(hosts, strings.Trim(host, "[]"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h, p, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
||||||
|
}
|
||||||
|
if h != "" {
|
||||||
|
hosts = append(hosts, h)
|
||||||
|
}
|
||||||
|
if p != "" {
|
||||||
|
ports = append(ports, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(hosts) > 0 {
|
||||||
|
settings["host"] = strings.Join(hosts, ",")
|
||||||
|
}
|
||||||
|
if len(ports) > 0 {
|
||||||
|
settings["port"] = strings.Join(ports, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
database := strings.TrimLeft(url.Path, "/")
|
||||||
|
if database != "" {
|
||||||
|
settings["database"] = database
|
||||||
|
}
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range url.Query() {
|
||||||
|
if k2, present := nameMap[k]; present {
|
||||||
|
k = k2
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[k] = v[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIPOnly(host string) bool {
|
||||||
|
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
|
||||||
|
}
|
||||||
|
|
||||||
|
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||||
|
|
||||||
|
func parseDSNSettings(s string) (map[string]string, error) {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(s) > 0 {
|
||||||
|
var key, val string
|
||||||
|
eqIdx := strings.IndexRune(s, '=')
|
||||||
|
if eqIdx < 0 {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||||
|
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
||||||
|
if len(s) == 0 {
|
||||||
|
} else if s[0] != '\'' {
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if asciiSpace[s[end]] == 1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("invalid backslash")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
} else { // quoted string
|
||||||
|
s = s[1:]
|
||||||
|
end := 0
|
||||||
|
for ; end < len(s); end++ {
|
||||||
|
if s[end] == '\'' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if s[end] == '\\' {
|
||||||
|
end++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if end == len(s) {
|
||||||
|
return nil, errors.New("unterminated quoted string in connection info string")
|
||||||
|
}
|
||||||
|
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||||
|
if end == len(s) {
|
||||||
|
s = ""
|
||||||
|
} else {
|
||||||
|
s = s[end+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if k, ok := nameMap[key]; ok {
|
||||||
|
key = k
|
||||||
|
}
|
||||||
|
|
||||||
|
if key == "" {
|
||||||
|
return nil, errors.New("invalid dsn")
|
||||||
|
}
|
||||||
|
|
||||||
|
settings[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
|
||||||
|
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := servicefile.GetService(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to find service: %v", serviceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
nameMap := map[string]string{
|
||||||
|
"dbname": "database",
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := make(map[string]string, len(service.Settings))
|
||||||
|
for k, v := range service.Settings {
|
||||||
|
if k2, present := nameMap[k]; present {
|
||||||
|
k = k2
|
||||||
|
}
|
||||||
|
settings[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return settings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||||
|
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||||
|
// "prefer" allow fallback.
|
||||||
|
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
|
||||||
|
host := thisHost
|
||||||
|
sslmode := settings["sslmode"]
|
||||||
|
sslrootcert := settings["sslrootcert"]
|
||||||
|
sslcert := settings["sslcert"]
|
||||||
|
sslkey := settings["sslkey"]
|
||||||
|
sslpassword := settings["sslpassword"]
|
||||||
|
sslsni := settings["sslsni"]
|
||||||
|
|
||||||
|
// Match libpq default behavior
|
||||||
|
if sslmode == "" {
|
||||||
|
sslmode = "prefer"
|
||||||
|
}
|
||||||
|
if sslsni == "" {
|
||||||
|
sslsni = "1"
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "disable":
|
||||||
|
return []*tls.Config{nil}, nil
|
||||||
|
case "allow", "prefer":
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
case "require":
|
||||||
|
// According to PostgreSQL documentation, if a root CA file exists,
|
||||||
|
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
||||||
|
if sslrootcert != "" {
|
||||||
|
goto nextCase
|
||||||
|
}
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
break
|
||||||
|
nextCase:
|
||||||
|
fallthrough
|
||||||
|
case "verify-ca":
|
||||||
|
// Don't perform the default certificate verification because it
|
||||||
|
// will verify the hostname. Instead, verify the server's
|
||||||
|
// certificate chain ourselves in VerifyPeerCertificate and
|
||||||
|
// ignore the server name. This emulates libpq's verify-ca
|
||||||
|
// behavior.
|
||||||
|
//
|
||||||
|
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
||||||
|
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
||||||
|
// for more info.
|
||||||
|
tlsConfig.InsecureSkipVerify = true
|
||||||
|
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
||||||
|
certs := make([]*x509.Certificate, len(certificates))
|
||||||
|
for i, asn1Data := range certificates {
|
||||||
|
cert, err := x509.ParseCertificate(asn1Data)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||||
|
}
|
||||||
|
certs[i] = cert
|
||||||
|
}
|
||||||
|
|
||||||
|
// Leave DNSName empty to skip hostname verification.
|
||||||
|
opts := x509.VerifyOptions{
|
||||||
|
Roots: tlsConfig.RootCAs,
|
||||||
|
Intermediates: x509.NewCertPool(),
|
||||||
|
}
|
||||||
|
// Skip the first cert because it's the leaf. All others
|
||||||
|
// are intermediates.
|
||||||
|
for _, cert := range certs[1:] {
|
||||||
|
opts.Intermediates.AddCert(cert)
|
||||||
|
}
|
||||||
|
_, err := certs[0].Verify(opts)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "verify-full":
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
default:
|
||||||
|
return nil, errors.New("sslmode is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslrootcert != "" {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
|
||||||
|
caPath := sslrootcert
|
||||||
|
caCert, err := ioutil.ReadFile(caPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||||
|
return nil, errors.New("unable to add CA to cert pool")
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig.RootCAs = caCertPool
|
||||||
|
tlsConfig.ClientCAs = caCertPool
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||||
|
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sslcert != "" && sslkey != "" {
|
||||||
|
buf, err := ioutil.ReadFile(sslkey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||||
|
}
|
||||||
|
block, _ := pem.Decode(buf)
|
||||||
|
var pemKey []byte
|
||||||
|
var decryptedKey []byte
|
||||||
|
var decryptedError error
|
||||||
|
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||||
|
if x509.IsEncryptedPEMBlock(block) {
|
||||||
|
// Attempt decryption with pass phrase
|
||||||
|
// NOTE: only supports RSA (PKCS#1)
|
||||||
|
if sslpassword != "" {
|
||||||
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
}
|
||||||
|
//if sslpassword not provided or has decryption error when use it
|
||||||
|
//try to find sslpassword with callback function
|
||||||
|
if sslpassword == "" || decryptedError != nil {
|
||||||
|
if parseConfigOptions.GetSSLPassword != nil {
|
||||||
|
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||||
|
}
|
||||||
|
if sslpassword == "" {
|
||||||
|
return nil, fmt.Errorf("unable to find sslpassword")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
|
// Should we also provide warning for PKCS#1 needed?
|
||||||
|
if decryptedError != nil {
|
||||||
|
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pemBytes := pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: decryptedKey,
|
||||||
|
}
|
||||||
|
pemKey = pem.EncodeToMemory(&pemBytes)
|
||||||
|
} else {
|
||||||
|
pemKey = pem.EncodeToMemory(block)
|
||||||
|
}
|
||||||
|
certfile, err := ioutil.ReadFile(sslcert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||||
|
}
|
||||||
|
cert, err := tls.X509KeyPair(certfile, pemKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to load cert: %w", err)
|
||||||
|
}
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
||||||
|
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
||||||
|
// or IPv6).
|
||||||
|
if sslsni == "1" && net.ParseIP(host) == nil {
|
||||||
|
tlsConfig.ServerName = host
|
||||||
|
}
|
||||||
|
|
||||||
|
switch sslmode {
|
||||||
|
case "allow":
|
||||||
|
return []*tls.Config{nil, tlsConfig}, nil
|
||||||
|
case "prefer":
|
||||||
|
return []*tls.Config{tlsConfig, nil}, nil
|
||||||
|
case "require", "verify-ca", "verify-full":
|
||||||
|
return []*tls.Config{tlsConfig}, nil
|
||||||
|
default:
|
||||||
|
panic("BUG: bad sslmode should already have been caught")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePort(s string) (uint16, error) {
|
||||||
|
port, err := strconv.ParseUint(s, 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if port < 1 || port > math.MaxUint16 {
|
||||||
|
return 0, errors.New("outside range")
|
||||||
|
}
|
||||||
|
return uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultDialer() *net.Dialer {
|
||||||
|
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultResolver() *net.Resolver {
|
||||||
|
return net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
||||||
|
return func(r io.Reader, w io.Writer) Frontend {
|
||||||
|
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
|
||||||
|
}
|
||||||
|
frontend := pgproto3.NewFrontend(cr, w)
|
||||||
|
|
||||||
|
return frontend
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||||
|
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if timeout < 0 {
|
||||||
|
return 0, errors.New("negative timeout")
|
||||||
|
}
|
||||||
|
return time.Duration(timeout) * time.Second, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||||
|
d := makeDefaultDialer()
|
||||||
|
d.Timeout = timeout
|
||||||
|
return d.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=read-write.
|
||||||
|
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) == "on" {
|
||||||
|
return errors.New("read only connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=read-only.
|
||||||
|
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "on" {
|
||||||
|
return errors.New("connection is not read only")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=standby.
|
||||||
|
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "t" {
|
||||||
|
return errors.New("server is not in hot standby mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=primary.
|
||||||
|
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) == "t" {
|
||||||
|
return errors.New("server is in standby mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=prefer-standby.
|
||||||
|
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "t" {
|
||||||
|
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,65 @@
|
|||||||
|
//go:build !windows
|
||||||
|
// +build !windows
|
||||||
|
|
||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
settings["host"] = defaultHost()
|
||||||
|
settings["port"] = "5432"
|
||||||
|
|
||||||
|
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||||
|
// OS. The client application will simply have to specify the user in that
|
||||||
|
// case (which they typically will be doing anyway).
|
||||||
|
user, err := user.Current()
|
||||||
|
if err == nil {
|
||||||
|
settings["user"] = user.Username
|
||||||
|
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||||
|
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||||
|
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
|
||||||
|
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
|
||||||
|
if _, err := os.Stat(sslcert); err == nil {
|
||||||
|
if _, err := os.Stat(sslkey); err == nil {
|
||||||
|
// Both the cert and key must be present to use them, or do not use either
|
||||||
|
settings["sslcert"] = sslcert
|
||||||
|
settings["sslkey"] = sslkey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
|
||||||
|
if _, err := os.Stat(sslrootcert); err == nil {
|
||||||
|
settings["sslrootcert"] = sslrootcert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["target_session_attrs"] = "any"
|
||||||
|
|
||||||
|
settings["min_read_buffer_size"] = "8192"
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||||
|
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||||
|
// checks the existence of common locations.
|
||||||
|
func defaultHost() string {
|
||||||
|
candidatePaths := []string{
|
||||||
|
"/var/run/postgresql", // Debian
|
||||||
|
"/private/tmp", // OSX - homebrew
|
||||||
|
"/tmp", // standard PostgreSQL
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range candidatePaths {
|
||||||
|
if _, err := os.Stat(path); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "localhost"
|
||||||
|
}
|
@ -0,0 +1,59 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func defaultSettings() map[string]string {
|
||||||
|
settings := make(map[string]string)
|
||||||
|
|
||||||
|
settings["host"] = defaultHost()
|
||||||
|
settings["port"] = "5432"
|
||||||
|
|
||||||
|
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||||
|
// OS. The client application will simply have to specify the user in that
|
||||||
|
// case (which they typically will be doing anyway).
|
||||||
|
user, err := user.Current()
|
||||||
|
appData := os.Getenv("APPDATA")
|
||||||
|
if err == nil {
|
||||||
|
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||||
|
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||||
|
username := user.Username
|
||||||
|
if strings.Contains(username, "\\") {
|
||||||
|
username = username[strings.LastIndex(username, "\\")+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["user"] = username
|
||||||
|
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
|
||||||
|
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||||
|
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
|
||||||
|
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
|
||||||
|
if _, err := os.Stat(sslcert); err == nil {
|
||||||
|
if _, err := os.Stat(sslkey); err == nil {
|
||||||
|
// Both the cert and key must be present to use them, or do not use either
|
||||||
|
settings["sslcert"] = sslcert
|
||||||
|
settings["sslkey"] = sslkey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
|
||||||
|
if _, err := os.Stat(sslrootcert); err == nil {
|
||||||
|
settings["sslrootcert"] = sslrootcert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
settings["target_session_attrs"] = "any"
|
||||||
|
|
||||||
|
settings["min_read_buffer_size"] = "8192"
|
||||||
|
|
||||||
|
return settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||||
|
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||||
|
// checks the existence of common locations.
|
||||||
|
func defaultHost() string {
|
||||||
|
return "localhost"
|
||||||
|
}
|
@ -0,0 +1,29 @@
|
|||||||
|
// Package pgconn is a low-level PostgreSQL database driver.
|
||||||
|
/*
|
||||||
|
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
||||||
|
nearly the same level is the C library libpq.
|
||||||
|
|
||||||
|
Establishing a Connection
|
||||||
|
|
||||||
|
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
||||||
|
libpq style environment variables.
|
||||||
|
|
||||||
|
Executing a Query
|
||||||
|
|
||||||
|
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
||||||
|
reads all rows into memory.
|
||||||
|
|
||||||
|
Executing Multiple Queries in a Single Round Trip
|
||||||
|
|
||||||
|
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||||
|
result. The ReadAll method reads all query results into memory.
|
||||||
|
|
||||||
|
Context Support
|
||||||
|
|
||||||
|
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
|
||||||
|
method immediately returns. In most circumstances, this will close the underlying connection.
|
||||||
|
|
||||||
|
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
|
||||||
|
client to abort.
|
||||||
|
*/
|
||||||
|
package pgconn
|
@ -0,0 +1,238 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||||
|
func SafeToRetry(err error) bool {
|
||||||
|
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||||
|
return e.SafeToRetry()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||||
|
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||||
|
func Timeout(err error) bool {
|
||||||
|
var timeoutErr *errTimeout
|
||||||
|
return errors.As(err, &timeoutErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
|
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||||
|
// detailed field description.
|
||||||
|
type PgError struct {
|
||||||
|
Severity string
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pe *PgError) Error() string {
|
||||||
|
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLState returns the SQLState of the error.
|
||||||
|
func (pe *PgError) SQLState() string {
|
||||||
|
return pe.Code
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectError struct {
|
||||||
|
config *Config
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connectError) Error() string {
|
||||||
|
sb := &strings.Builder{}
|
||||||
|
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||||
|
if e.err != nil {
|
||||||
|
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connectError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type connLockError struct {
|
||||||
|
status string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connLockError) SafeToRetry() bool {
|
||||||
|
return true // a lock failure by definition happens before the connection is used.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *connLockError) Error() string {
|
||||||
|
return e.status
|
||||||
|
}
|
||||||
|
|
||||||
|
type parseConfigError struct {
|
||||||
|
connString string
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *parseConfigError) Error() string {
|
||||||
|
connString := redactPW(e.connString)
|
||||||
|
if e.err == nil {
|
||||||
|
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *parseConfigError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||||
|
// true. Otherwise returns err.
|
||||||
|
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||||
|
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||||
|
return &errTimeout{err: ctx.Err()}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type pgconnError struct {
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) Error() string {
|
||||||
|
if e.msg == "" {
|
||||||
|
return e.err.Error()
|
||||||
|
}
|
||||||
|
if e.err == nil {
|
||||||
|
return e.msg
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *pgconnError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
|
||||||
|
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
|
||||||
|
type errTimeout struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) Error() string {
|
||||||
|
return fmt.Sprintf("timeout: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) SafeToRetry() bool {
|
||||||
|
return SafeToRetry(e.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *errTimeout) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextAlreadyDoneError struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) Error() string {
|
||||||
|
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
|
||||||
|
func newContextAlreadyDoneError(ctx context.Context) (err error) {
|
||||||
|
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
||||||
|
}
|
||||||
|
|
||||||
|
type writeError struct {
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) Error() string {
|
||||||
|
return fmt.Sprintf("write failed: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *writeError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactPW(connString string) string {
|
||||||
|
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||||
|
if u, err := url.Parse(connString); err == nil {
|
||||||
|
return redactURL(u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
quotedDSN := regexp.MustCompile(`password='[^']*'`)
|
||||||
|
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||||
|
plainDSN := regexp.MustCompile(`password=[^ ]*`)
|
||||||
|
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||||
|
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
|
||||||
|
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
|
||||||
|
return connString
|
||||||
|
}
|
||||||
|
|
||||||
|
func redactURL(u *url.URL) string {
|
||||||
|
if u == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if _, pwSet := u.User.Password(); pwSet {
|
||||||
|
u.User = url.UserPassword(u.User.Username(), "xxxxx")
|
||||||
|
}
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type NotPreferredError struct {
|
||||||
|
err error
|
||||||
|
safeToRetry bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) Error() string {
|
||||||
|
return fmt.Sprintf("standby server not found: %s", e.err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) SafeToRetry() bool {
|
||||||
|
return e.safeToRetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotPreferredError) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
@ -0,0 +1,73 @@
|
|||||||
|
package ctxwatch
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||||
|
// time.
|
||||||
|
type ContextWatcher struct {
|
||||||
|
onCancel func()
|
||||||
|
onUnwatchAfterCancel func()
|
||||||
|
unwatchChan chan struct{}
|
||||||
|
|
||||||
|
lock sync.Mutex
|
||||||
|
watchInProgress bool
|
||||||
|
onCancelWasCalled bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||||
|
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||||
|
// onCancel called.
|
||||||
|
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
||||||
|
cw := &ContextWatcher{
|
||||||
|
onCancel: onCancel,
|
||||||
|
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
||||||
|
unwatchChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return cw
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
|
||||||
|
func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||||
|
cw.lock.Lock()
|
||||||
|
defer cw.lock.Unlock()
|
||||||
|
|
||||||
|
if cw.watchInProgress {
|
||||||
|
panic("Watch already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.onCancelWasCalled = false
|
||||||
|
|
||||||
|
if ctx.Done() != nil {
|
||||||
|
cw.watchInProgress = true
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
cw.onCancel()
|
||||||
|
cw.onCancelWasCalled = true
|
||||||
|
<-cw.unwatchChan
|
||||||
|
case <-cw.unwatchChan:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
cw.watchInProgress = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
|
||||||
|
// called then onUnwatchAfterCancel will also be called.
|
||||||
|
func (cw *ContextWatcher) Unwatch() {
|
||||||
|
cw.lock.Lock()
|
||||||
|
defer cw.lock.Unlock()
|
||||||
|
|
||||||
|
if cw.watchInProgress {
|
||||||
|
cw.unwatchChan <- struct{}{}
|
||||||
|
if cw.onCancelWasCalled {
|
||||||
|
cw.onUnwatchAfterCancel()
|
||||||
|
}
|
||||||
|
cw.watchInProgress = false
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,99 @@
|
|||||||
|
package pgconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgproto3/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewGSSFunc creates a GSS authentication provider, for use with
|
||||||
|
// RegisterGSSProvider.
|
||||||
|
type NewGSSFunc func() (GSS, error)
|
||||||
|
|
||||||
|
var newGSS NewGSSFunc
|
||||||
|
|
||||||
|
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
||||||
|
// you need to use Kerberos to authenticate with your server, add this to your
|
||||||
|
// main package:
|
||||||
|
//
|
||||||
|
// import "github.com/otan/gopgkrb5"
|
||||||
|
//
|
||||||
|
// func init() {
|
||||||
|
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
|
||||||
|
// }
|
||||||
|
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
||||||
|
newGSS = newGSSArg
|
||||||
|
}
|
||||||
|
|
||||||
|
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
||||||
|
type GSS interface {
|
||||||
|
GetInitToken(host string, service string) ([]byte, error)
|
||||||
|
GetInitTokenFromSPN(spn string) ([]byte, error)
|
||||||
|
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) gssAuth() error {
|
||||||
|
if newGSS == nil {
|
||||||
|
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
|
||||||
|
}
|
||||||
|
cli, err := newGSS()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var nextData []byte
|
||||||
|
if c.config.KerberosSpn != "" {
|
||||||
|
// Use the supplied SPN if provided.
|
||||||
|
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
|
||||||
|
} else {
|
||||||
|
// Allow the kerberos service name to be overridden
|
||||||
|
service := "postgres"
|
||||||
|
if c.config.KerberosSrvName != "" {
|
||||||
|
service = c.config.KerberosSrvName
|
||||||
|
}
|
||||||
|
nextData, err = cli.GetInitToken(c.config.Host, service)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
gssResponse := &pgproto3.GSSResponse{
|
||||||
|
Data: nextData,
|
||||||
|
}
|
||||||
|
_, err = c.conn.Write(gssResponse.Encode(nil))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp, err := c.rxGSSContinue()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var done bool
|
||||||
|
done, nextData, err = cli.Continue(resp.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if done {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
||||||
|
msg, err := c.receiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch m := msg.(type) {
|
||||||
|
case *pgproto3.AuthenticationGSSContinue:
|
||||||
|
return m, nil
|
||||||
|
case *pgproto3.ErrorResponse:
|
||||||
|
return nil, ErrorResponseToPgError(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,169 @@
|
|||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
var lruCount uint64
|
||||||
|
|
||||||
|
// LRU implements Cache with a Least Recently Used (LRU) cache.
|
||||||
|
type LRU struct {
|
||||||
|
conn *pgconn.PgConn
|
||||||
|
mode int
|
||||||
|
cap int
|
||||||
|
prepareCount int
|
||||||
|
m map[string]*list.Element
|
||||||
|
l *list.List
|
||||||
|
psNamePrefix string
|
||||||
|
stmtsToClear []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
|
||||||
|
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
|
||||||
|
mustBeValidMode(mode)
|
||||||
|
mustBeValidCap(cap)
|
||||||
|
|
||||||
|
n := atomic.AddUint64(&lruCount, 1)
|
||||||
|
|
||||||
|
return &LRU{
|
||||||
|
conn: conn,
|
||||||
|
mode: mode,
|
||||||
|
cap: cap,
|
||||||
|
m: make(map[string]*list.Element),
|
||||||
|
l: list.New(),
|
||||||
|
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||||
|
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||||
|
if ctx != context.Background() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush an outstanding bad statements
|
||||||
|
txStatus := c.conn.TxStatus()
|
||||||
|
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
|
||||||
|
for _, stmt := range c.stmtsToClear {
|
||||||
|
err := c.clearStmt(ctx, stmt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if el, ok := c.m[sql]; ok {
|
||||||
|
c.l.MoveToFront(el)
|
||||||
|
return el.Value.(*pgconn.StatementDescription), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.l.Len() == c.cap {
|
||||||
|
err := c.removeOldest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
psd, err := c.prepare(ctx, sql)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
el := c.l.PushFront(psd)
|
||||||
|
c.m[sql] = el
|
||||||
|
|
||||||
|
return psd, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||||
|
func (c *LRU) Clear(ctx context.Context) error {
|
||||||
|
for c.l.Len() > 0 {
|
||||||
|
err := c.removeOldest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) StatementErrored(sql string, err error) {
|
||||||
|
pgErr, ok := err.(*pgconn.PgError)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/1162
|
||||||
|
//
|
||||||
|
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||||
|
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||||
|
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||||
|
// have so it should be safe.
|
||||||
|
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||||
|
if possibleInvalidCachedPlanError {
|
||||||
|
c.stmtsToClear = append(c.stmtsToClear, sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
|
||||||
|
elem, inMap := c.m[sql]
|
||||||
|
if !inMap {
|
||||||
|
// The statement probably fell off the back of the list. In that case, we've
|
||||||
|
// ensured that it isn't in the cache, so we can declare victory.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.l.Remove(elem)
|
||||||
|
|
||||||
|
psd := elem.Value.(*pgconn.StatementDescription)
|
||||||
|
delete(c.m, psd.SQL)
|
||||||
|
if c.mode == ModePrepare {
|
||||||
|
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
func (c *LRU) Len() int {
|
||||||
|
return c.l.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
func (c *LRU) Cap() int {
|
||||||
|
return c.cap
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||||
|
func (c *LRU) Mode() int {
|
||||||
|
return c.mode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||||
|
var name string
|
||||||
|
if c.mode == ModePrepare {
|
||||||
|
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
|
||||||
|
c.prepareCount += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.conn.Prepare(ctx, name, sql, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LRU) removeOldest(ctx context.Context) error {
|
||||||
|
oldest := c.l.Back()
|
||||||
|
c.l.Remove(oldest)
|
||||||
|
psd := oldest.Value.(*pgconn.StatementDescription)
|
||||||
|
delete(c.m, psd.SQL)
|
||||||
|
if c.mode == ModePrepare {
|
||||||
|
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
|
||||||
|
package stmtcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModePrepare = iota // Cache should prepare named statements.
|
||||||
|
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
|
||||||
|
)
|
||||||
|
|
||||||
|
// Cache prepares and caches prepared statement descriptions.
|
||||||
|
type Cache interface {
|
||||||
|
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||||
|
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
|
||||||
|
|
||||||
|
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||||
|
Clear(ctx context.Context) error
|
||||||
|
|
||||||
|
// StatementErrored informs the cache that the given statement resulted in an error when it
|
||||||
|
// was last used against the database. In some cases, this will cause the cache to maer that
|
||||||
|
// statement as bad. The bad statement will instead be flushed during the next call to Get
|
||||||
|
// that occurs outside of a failed transaction.
|
||||||
|
StatementErrored(sql string, err error)
|
||||||
|
|
||||||
|
// Len returns the number of cached prepared statement descriptions.
|
||||||
|
Len() int
|
||||||
|
|
||||||
|
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||||
|
Cap() int
|
||||||
|
|
||||||
|
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||||
|
Mode() int
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
|
||||||
|
// the maximum size of the cache.
|
||||||
|
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
|
||||||
|
mustBeValidMode(mode)
|
||||||
|
mustBeValidCap(cap)
|
||||||
|
|
||||||
|
return NewLRU(conn, mode, cap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustBeValidMode(mode int) {
|
||||||
|
if mode != ModePrepare && mode != ModeDescribe {
|
||||||
|
panic("mode must be ModePrepare or ModeDescribe")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustBeValidCap(cap int) {
|
||||||
|
if cap < 1 {
|
||||||
|
panic("cache must have cap of >= 1")
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,9 @@
|
|||||||
|
language: go
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.x
|
||||||
|
- tip
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
allow_failures:
|
||||||
|
- go: tip
|
@ -0,0 +1,22 @@
|
|||||||
|
Copyright (c) 2019 Jack Christensen
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,11 @@
|
|||||||
|
[![](https://godoc.org/github.com/jackc/pgio?status.svg)](https://godoc.org/github.com/jackc/pgio)
|
||||||
|
[![Build Status](https://travis-ci.org/jackc/pgio.svg)](https://travis-ci.org/jackc/pgio)
|
||||||
|
|
||||||
|
# pgio
|
||||||
|
|
||||||
|
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||||
|
|
||||||
|
pgio provides functions for appending integers to a []byte while doing byte
|
||||||
|
order conversion.
|
||||||
|
|
||||||
|
Extracted from original implementation in https://github.com/jackc/pgx.
|
@ -0,0 +1,6 @@
|
|||||||
|
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||||
|
/*
|
||||||
|
pgio provides functions for appending integers to a []byte while doing byte
|
||||||
|
order conversion.
|
||||||
|
*/
|
||||||
|
package pgio
|
@ -0,0 +1,40 @@
|
|||||||
|
package pgio
|
||||||
|
|
||||||
|
import "encoding/binary"
|
||||||
|
|
||||||
|
func AppendUint16(buf []byte, n uint16) []byte {
|
||||||
|
wp := len(buf)
|
||||||
|
buf = append(buf, 0, 0)
|
||||||
|
binary.BigEndian.PutUint16(buf[wp:], n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendUint32(buf []byte, n uint32) []byte {
|
||||||
|
wp := len(buf)
|
||||||
|
buf = append(buf, 0, 0, 0, 0)
|
||||||
|
binary.BigEndian.PutUint32(buf[wp:], n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendUint64(buf []byte, n uint64) []byte {
|
||||||
|
wp := len(buf)
|
||||||
|
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||||
|
binary.BigEndian.PutUint64(buf[wp:], n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendInt16(buf []byte, n int16) []byte {
|
||||||
|
return AppendUint16(buf, uint16(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendInt32(buf []byte, n int32) []byte {
|
||||||
|
return AppendUint32(buf, uint32(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func AppendInt64(buf []byte, n int64) []byte {
|
||||||
|
return AppendUint64(buf, uint64(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetInt32(buf []byte, n int32) {
|
||||||
|
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||||
|
}
|
@ -0,0 +1,9 @@
|
|||||||
|
language: go
|
||||||
|
|
||||||
|
go:
|
||||||
|
- 1.x
|
||||||
|
- tip
|
||||||
|
|
||||||
|
matrix:
|
||||||
|
allow_failures:
|
||||||
|
- go: tip
|
@ -0,0 +1,22 @@
|
|||||||
|
Copyright (c) 2019 Jack Christensen
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,18 @@
|
|||||||
|
[![](https://godoc.org/github.com/jackc/pgproto3?status.svg)](https://godoc.org/github.com/jackc/pgproto3)
|
||||||
|
[![Build Status](https://travis-ci.org/jackc/pgproto3.svg)](https://travis-ci.org/jackc/pgproto3)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# pgproto3
|
||||||
|
|
||||||
|
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||||
|
|
||||||
|
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
|
||||||
|
|
||||||
|
See example/pgfortune for a playful example of a fake PostgreSQL server.
|
||||||
|
|
||||||
|
Extracted from original implementation in https://github.com/jackc/pgx.
|
@ -0,0 +1,52 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
||||||
|
type AuthenticationCleartextPassword struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationCleartextPassword) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeCleartextPassword {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationCleartextPassword",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthenticationGSS struct{}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Backend() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeGSS {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 4)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationGSS",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,67 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthenticationGSSContinue struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Backend() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeGSSCont {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Data = src[4:]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||||
|
dst = append(dst, a.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationGSSContinue",
|
||||||
|
Data: a.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Data = msg.Data
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,77 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
|
||||||
|
type AuthenticationMD5Password struct {
|
||||||
|
Salt [4]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationMD5Password) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationMD5Password) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||||
|
if len(src) != 8 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeMD5Password {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(dst.Salt[:], src[4:8])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 12)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||||
|
dst = append(dst, src.Salt[:]...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Salt [4]byte
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationMD5Password",
|
||||||
|
Salt: src.Salt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Salt [4]byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Salt = msg.Salt
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
||||||
|
type AuthenticationOk struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationOk) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationOk) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||||
|
if len(src) != 4 {
|
||||||
|
return errors.New("bad authentication message size")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeOk {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationOK",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,75 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
|
||||||
|
type AuthenticationSASL struct {
|
||||||
|
AuthMechanisms []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASL) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASL) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASL {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
authMechanisms := src[4:]
|
||||||
|
for len(authMechanisms) > 1 {
|
||||||
|
idx := bytes.IndexByte(authMechanisms, 0)
|
||||||
|
if idx > 0 {
|
||||||
|
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||||
|
authMechanisms = authMechanisms[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||||
|
|
||||||
|
for _, s := range src.AuthMechanisms {
|
||||||
|
dst = append(dst, []byte(s)...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
}
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
AuthMechanisms []string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASL",
|
||||||
|
AuthMechanisms: src.AuthMechanisms,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,81 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
|
||||||
|
type AuthenticationSASLContinue struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLContinue) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLContinue {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASLContinue",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,81 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
|
||||||
|
type AuthenticationSASLFinal struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*AuthenticationSASLFinal) Backend() {}
|
||||||
|
|
||||||
|
// Backend identifies this message as an authentication response.
|
||||||
|
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
authType := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if authType != AuthTypeSASLFinal {
|
||||||
|
return errors.New("bad auth type")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = src[4:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'R')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "AuthenticationSASLFinal",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,213 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Backend acts as a server for the PostgreSQL wire protocol version 3.
|
||||||
|
type Backend struct {
|
||||||
|
cr ChunkReader
|
||||||
|
w io.Writer
|
||||||
|
|
||||||
|
// Frontend message flyweights
|
||||||
|
bind Bind
|
||||||
|
cancelRequest CancelRequest
|
||||||
|
_close Close
|
||||||
|
copyFail CopyFail
|
||||||
|
copyData CopyData
|
||||||
|
copyDone CopyDone
|
||||||
|
describe Describe
|
||||||
|
execute Execute
|
||||||
|
flush Flush
|
||||||
|
functionCall FunctionCall
|
||||||
|
gssEncRequest GSSEncRequest
|
||||||
|
parse Parse
|
||||||
|
query Query
|
||||||
|
sslRequest SSLRequest
|
||||||
|
startupMessage StartupMessage
|
||||||
|
sync Sync
|
||||||
|
terminate Terminate
|
||||||
|
|
||||||
|
bodyLen int
|
||||||
|
msgType byte
|
||||||
|
partialMsg bool
|
||||||
|
authType uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||||
|
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewBackend creates a new Backend.
|
||||||
|
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
|
||||||
|
return &Backend{cr: cr, w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the frontend.
|
||||||
|
func (b *Backend) Send(msg BackendMessage) error {
|
||||||
|
_, err := b.w.Write(msg.Encode(nil))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||||
|
// because the initial connection message is "special" and does not include the message type as the first byte. This
|
||||||
|
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
|
||||||
|
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||||
|
buf, err := b.cr.Next(4)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||||
|
|
||||||
|
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||||
|
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf, err = b.cr.Next(msgSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := binary.BigEndian.Uint32(buf)
|
||||||
|
|
||||||
|
switch code {
|
||||||
|
case ProtocolVersionNumber:
|
||||||
|
err = b.startupMessage.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.startupMessage, nil
|
||||||
|
case sslRequestNumber:
|
||||||
|
err = b.sslRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.sslRequest, nil
|
||||||
|
case cancelRequestCode:
|
||||||
|
err = b.cancelRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.cancelRequest, nil
|
||||||
|
case gssEncReqNumber:
|
||||||
|
err = b.gssEncRequest.Decode(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &b.gssEncRequest, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown startup message code: %d", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
|
||||||
|
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||||
|
if !b.partialMsg {
|
||||||
|
header, err := b.cr.Next(5)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.msgType = header[0]
|
||||||
|
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
|
b.partialMsg = true
|
||||||
|
if b.bodyLen < 0 {
|
||||||
|
return nil, errors.New("invalid message with negative body length received")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg FrontendMessage
|
||||||
|
switch b.msgType {
|
||||||
|
case 'B':
|
||||||
|
msg = &b.bind
|
||||||
|
case 'C':
|
||||||
|
msg = &b._close
|
||||||
|
case 'D':
|
||||||
|
msg = &b.describe
|
||||||
|
case 'E':
|
||||||
|
msg = &b.execute
|
||||||
|
case 'F':
|
||||||
|
msg = &b.functionCall
|
||||||
|
case 'f':
|
||||||
|
msg = &b.copyFail
|
||||||
|
case 'd':
|
||||||
|
msg = &b.copyData
|
||||||
|
case 'c':
|
||||||
|
msg = &b.copyDone
|
||||||
|
case 'H':
|
||||||
|
msg = &b.flush
|
||||||
|
case 'P':
|
||||||
|
msg = &b.parse
|
||||||
|
case 'p':
|
||||||
|
switch b.authType {
|
||||||
|
case AuthTypeSASL:
|
||||||
|
msg = &SASLInitialResponse{}
|
||||||
|
case AuthTypeSASLContinue:
|
||||||
|
msg = &SASLResponse{}
|
||||||
|
case AuthTypeSASLFinal:
|
||||||
|
msg = &SASLResponse{}
|
||||||
|
case AuthTypeGSS, AuthTypeGSSCont:
|
||||||
|
msg = &GSSResponse{}
|
||||||
|
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
// to maintain backwards compatability
|
||||||
|
msg = &PasswordMessage{}
|
||||||
|
}
|
||||||
|
case 'Q':
|
||||||
|
msg = &b.query
|
||||||
|
case 'S':
|
||||||
|
msg = &b.sync
|
||||||
|
case 'X':
|
||||||
|
msg = &b.terminate
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgBody, err := b.cr.Next(b.bodyLen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.partialMsg = false
|
||||||
|
|
||||||
|
err = msg.Decode(msgBody)
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAuthType sets the authentication type in the backend.
|
||||||
|
// Since multiple message types can start with 'p', SetAuthType allows
|
||||||
|
// contextual identification of FrontendMessages. For example, in the
|
||||||
|
// PG message flow documentation for PasswordMessage:
|
||||||
|
//
|
||||||
|
// Byte1('p')
|
||||||
|
//
|
||||||
|
// Identifies the message as a password response. Note that this is also used for
|
||||||
|
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
|
||||||
|
// the context.
|
||||||
|
//
|
||||||
|
// Since the Frontend does not know about the state of a backend, it is important
|
||||||
|
// to call SetAuthType() after an authentication request is received by the Frontend.
|
||||||
|
func (b *Backend) SetAuthType(authType uint32) error {
|
||||||
|
switch authType {
|
||||||
|
case AuthTypeOk,
|
||||||
|
AuthTypeCleartextPassword,
|
||||||
|
AuthTypeMD5Password,
|
||||||
|
AuthTypeSCMCreds,
|
||||||
|
AuthTypeGSS,
|
||||||
|
AuthTypeGSSCont,
|
||||||
|
AuthTypeSSPI,
|
||||||
|
AuthTypeSASL,
|
||||||
|
AuthTypeSASLContinue,
|
||||||
|
AuthTypeSASLFinal:
|
||||||
|
b.authType = authType
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("authType not recognized: %d", authType)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BackendKeyData struct {
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*BackendKeyData) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||||
|
if len(src) != 8 {
|
||||||
|
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||||
|
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'K')
|
||||||
|
dst = pgio.AppendUint32(dst, 12)
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}{
|
||||||
|
Type: "BackendKeyData",
|
||||||
|
ProcessID: src.ProcessID,
|
||||||
|
SecretKey: src.SecretKey,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BigEndianBuf [8]byte
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||||
|
buf := b[0:2]
|
||||||
|
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||||
|
buf := b[0:2]
|
||||||
|
binary.BigEndian.PutUint16(buf, n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||||
|
buf := b[0:4]
|
||||||
|
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||||
|
buf := b[0:4]
|
||||||
|
binary.BigEndian.PutUint32(buf, n)
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||||
|
buf := b[0:8]
|
||||||
|
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||||
|
return buf
|
||||||
|
}
|
@ -0,0 +1,216 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bind struct {
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters [][]byte
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Bind) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Bind) Decode(src []byte) error {
|
||||||
|
*dst = Bind{}
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
dst.DestinationPortal = string(src[:idx])
|
||||||
|
rp := idx + 1
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
dst.PreparedStatement = string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
if parameterFormatCodeCount > 0 {
|
||||||
|
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
||||||
|
|
||||||
|
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||||
|
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
if parameterCount > 0 {
|
||||||
|
dst.Parameters = make([][]byte, parameterCount)
|
||||||
|
|
||||||
|
for i := 0; i < parameterCount; i++ {
|
||||||
|
if len(src[rp:]) < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
// null
|
||||||
|
if msgSize == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < msgSize {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Parameters[i] = src[rp : rp+msgSize]
|
||||||
|
rp += msgSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
||||||
|
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||||
|
}
|
||||||
|
for i := 0; i < resultFormatCodeCount; i++ {
|
||||||
|
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Bind) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'B')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.DestinationPortal...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.PreparedStatement...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||||
|
for _, fc := range src.ParameterFormatCodes {
|
||||||
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||||
|
for _, p := range src.Parameters {
|
||||||
|
if p == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(p)))
|
||||||
|
dst = append(dst, p...)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||||
|
for _, fc := range src.ResultFormatCodes {
|
||||||
|
dst = pgio.AppendInt16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Bind) MarshalJSON() ([]byte, error) {
|
||||||
|
formattedParameters := make([]map[string]string, len(src.Parameters))
|
||||||
|
for i, p := range src.Parameters {
|
||||||
|
if p == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
textFormat := true
|
||||||
|
if len(src.ParameterFormatCodes) == 1 {
|
||||||
|
textFormat = src.ParameterFormatCodes[0] == 0
|
||||||
|
} else if len(src.ParameterFormatCodes) > 1 {
|
||||||
|
textFormat = src.ParameterFormatCodes[i] == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if textFormat {
|
||||||
|
formattedParameters[i] = map[string]string{"text": string(p)}
|
||||||
|
} else {
|
||||||
|
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters []map[string]string
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}{
|
||||||
|
Type: "Bind",
|
||||||
|
DestinationPortal: src.DestinationPortal,
|
||||||
|
PreparedStatement: src.PreparedStatement,
|
||||||
|
ParameterFormatCodes: src.ParameterFormatCodes,
|
||||||
|
Parameters: formattedParameters,
|
||||||
|
ResultFormatCodes: src.ResultFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Bind) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
DestinationPortal string
|
||||||
|
PreparedStatement string
|
||||||
|
ParameterFormatCodes []int16
|
||||||
|
Parameters []map[string]string
|
||||||
|
ResultFormatCodes []int16
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(data, &msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.DestinationPortal = msg.DestinationPortal
|
||||||
|
dst.PreparedStatement = msg.PreparedStatement
|
||||||
|
dst.ParameterFormatCodes = msg.ParameterFormatCodes
|
||||||
|
dst.Parameters = make([][]byte, len(msg.Parameters))
|
||||||
|
dst.ResultFormatCodes = msg.ResultFormatCodes
|
||||||
|
for n, parameter := range msg.Parameters {
|
||||||
|
dst.Parameters[n], err = getValueFromJSON(parameter)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot get param %d: %w", n, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BindComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*BindComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *BindComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '2', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src BindComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "BindComplete",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,58 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const cancelRequestCode = 80877102
|
||||||
|
|
||||||
|
type CancelRequest struct {
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CancelRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *CancelRequest) Decode(src []byte) error {
|
||||||
|
if len(src) != 12 {
|
||||||
|
return errors.New("bad cancel request size")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != cancelRequestCode {
|
||||||
|
return errors.New("bad cancel request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||||
|
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 16)
|
||||||
|
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||||
|
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProcessID uint32
|
||||||
|
SecretKey uint32
|
||||||
|
}{
|
||||||
|
Type: "CancelRequest",
|
||||||
|
ProcessID: src.ProcessID,
|
||||||
|
SecretKey: src.SecretKey,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,19 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/jackc/chunkreader/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package.
|
||||||
|
type ChunkReader interface {
|
||||||
|
// Next returns buf filled with the next n bytes. If an error (including a partial read) occurs,
|
||||||
|
// buf must be nil. Next must preserve any partially read data. Next must not reuse buf.
|
||||||
|
Next(n int) (buf []byte, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChunkReader creates and returns a new default ChunkReader.
|
||||||
|
func NewChunkReader(r io.Reader) ChunkReader {
|
||||||
|
return chunkreader.New(r)
|
||||||
|
}
|
@ -0,0 +1,89 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Close struct {
|
||||||
|
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Close) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Close) Decode(src []byte) error {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Close"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = src[0]
|
||||||
|
rp := 1
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx != len(src[rp:])-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Close"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Name = string(src[rp : len(src)-1])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Close) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'C')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.ObjectType)
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Close) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}{
|
||||||
|
Type: "Close",
|
||||||
|
ObjectType: string(src.ObjectType),
|
||||||
|
Name: src.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Close) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.ObjectType) != 1 {
|
||||||
|
return errors.New("invalid length for Close.ObjectType")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = byte(msg.ObjectType[0])
|
||||||
|
dst.Name = msg.Name
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CloseComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CloseComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CloseComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '3', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CloseComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "CloseComplete",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,71 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CommandComplete struct {
|
||||||
|
CommandTag []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CommandComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CommandComplete) Decode(src []byte) error {
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.CommandTag = src[:idx]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'C')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.CommandTag...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CommandComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
CommandTag string
|
||||||
|
}{
|
||||||
|
Type: "CommandComplete",
|
||||||
|
CommandTag: string(src.CommandTag),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
CommandTag string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.CommandTag = []byte(msg.CommandTag)
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,95 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyBothResponse struct {
|
||||||
|
OverallFormat byte
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyBothResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 3 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
overallFormat := buf.Next(1)[0]
|
||||||
|
|
||||||
|
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
if buf.Len() != columnCount*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnFormatCodes := make([]uint16, columnCount)
|
||||||
|
for i := 0; i < columnCount; i++ {
|
||||||
|
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'W')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
dst = append(dst, src.OverallFormat)
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}{
|
||||||
|
Type: "CopyBothResponse",
|
||||||
|
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
OverallFormat string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.OverallFormat) != 1 {
|
||||||
|
return errors.New("invalid length for CopyBothResponse.OverallFormat")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.OverallFormat = msg.OverallFormat[0]
|
||||||
|
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,62 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyData struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyData) Backend() {}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CopyData) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyData) Decode(src []byte) error {
|
||||||
|
dst.Data = src
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyData) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'd')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "CopyData",
|
||||||
|
Data: hex.EncodeToString(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyData) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyDone struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyDone) Backend() {}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CopyDone) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyDone) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyDone) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'c', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyDone) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "CopyDone",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,53 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyFail struct {
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*CopyFail) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyFail) Decode(src []byte) error {
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Message = string(src[:idx])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'f')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Message...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyFail) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Message string
|
||||||
|
}{
|
||||||
|
Type: "CopyFail",
|
||||||
|
Message: src.Message,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,96 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyInResponse struct {
|
||||||
|
OverallFormat byte
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*CopyInResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyInResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 3 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
overallFormat := buf.Next(1)[0]
|
||||||
|
|
||||||
|
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
if buf.Len() != columnCount*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnFormatCodes := make([]uint16, columnCount)
|
||||||
|
for i := 0; i < columnCount; i++ {
|
||||||
|
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'G')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.OverallFormat)
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyInResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}{
|
||||||
|
Type: "CopyInResponse",
|
||||||
|
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
OverallFormat string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.OverallFormat) != 1 {
|
||||||
|
return errors.New("invalid length for CopyInResponse.OverallFormat")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.OverallFormat = msg.OverallFormat[0]
|
||||||
|
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,96 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CopyOutResponse struct {
|
||||||
|
OverallFormat byte
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*CopyOutResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 3 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
overallFormat := buf.Next(1)[0]
|
||||||
|
|
||||||
|
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
if buf.Len() != columnCount*2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
columnFormatCodes := make([]uint16, columnCount)
|
||||||
|
for i := 0; i < columnCount; i++ {
|
||||||
|
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'H')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.OverallFormat)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||||
|
for _, fc := range src.ColumnFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, fc)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}{
|
||||||
|
Type: "CopyOutResponse",
|
||||||
|
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
OverallFormat string
|
||||||
|
ColumnFormatCodes []uint16
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msg.OverallFormat) != 1 {
|
||||||
|
return errors.New("invalid length for CopyOutResponse.OverallFormat")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.OverallFormat = msg.OverallFormat[0]
|
||||||
|
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,142 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DataRow struct {
|
||||||
|
Values [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*DataRow) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *DataRow) Decode(src []byte) error {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||||
|
}
|
||||||
|
rp := 0
|
||||||
|
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
// If the capacity of the values slice is too small OR substantially too
|
||||||
|
// large reallocate. This is too avoid one row with many columns from
|
||||||
|
// permanently allocating memory.
|
||||||
|
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
||||||
|
newCap := 32
|
||||||
|
if newCap < fieldCount {
|
||||||
|
newCap = fieldCount
|
||||||
|
}
|
||||||
|
dst.Values = make([][]byte, fieldCount, newCap)
|
||||||
|
} else {
|
||||||
|
dst.Values = dst.Values[:fieldCount]
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < fieldCount; i++ {
|
||||||
|
if len(src[rp:]) < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
// null
|
||||||
|
if msgSize == -1 {
|
||||||
|
dst.Values[i] = nil
|
||||||
|
} else {
|
||||||
|
if len(src[rp:]) < msgSize {
|
||||||
|
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize]
|
||||||
|
rp += msgSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *DataRow) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'D')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||||
|
for _, v := range src.Values {
|
||||||
|
if v == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(v)))
|
||||||
|
dst = append(dst, v...)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src DataRow) MarshalJSON() ([]byte, error) {
|
||||||
|
formattedValues := make([]map[string]string, len(src.Values))
|
||||||
|
for i, v := range src.Values {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var hasNonPrintable bool
|
||||||
|
for _, b := range v {
|
||||||
|
if b < 32 {
|
||||||
|
hasNonPrintable = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasNonPrintable {
|
||||||
|
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
|
||||||
|
} else {
|
||||||
|
formattedValues[i] = map[string]string{"text": string(v)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Values []map[string]string
|
||||||
|
}{
|
||||||
|
Type: "DataRow",
|
||||||
|
Values: formattedValues,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *DataRow) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Values []map[string]string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Values = make([][]byte, len(msg.Values))
|
||||||
|
for n, parameter := range msg.Values {
|
||||||
|
var err error
|
||||||
|
dst.Values[n], err = getValueFromJSON(parameter)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,88 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Describe struct {
|
||||||
|
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Describe) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Describe) Decode(src []byte) error {
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = src[0]
|
||||||
|
rp := 1
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx != len(src[rp:])-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Name = string(src[rp : len(src)-1])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Describe) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'D')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.ObjectType)
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Describe) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}{
|
||||||
|
Type: "Describe",
|
||||||
|
ObjectType: string(src.ObjectType),
|
||||||
|
Name: src.Name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *Describe) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
ObjectType string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(msg.ObjectType) != 1 {
|
||||||
|
return errors.New("invalid length for Describe.ObjectType")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ObjectType = byte(msg.ObjectType[0])
|
||||||
|
dst.Name = msg.Name
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||||
|
//
|
||||||
|
// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages.
|
||||||
|
package pgproto3
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EmptyQueryResponse struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*EmptyQueryResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'I', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "EmptyQueryResponse",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,334 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Severity string
|
||||||
|
SeverityUnlocalized string // only in 9.6 and greater
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
|
||||||
|
UnknownFields map[byte]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ErrorResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ErrorResponse) Decode(src []byte) error {
|
||||||
|
*dst = ErrorResponse{}
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
for {
|
||||||
|
k, err := buf.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if k == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
vb, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
v := string(vb[:len(vb)-1])
|
||||||
|
|
||||||
|
switch k {
|
||||||
|
case 'S':
|
||||||
|
dst.Severity = v
|
||||||
|
case 'V':
|
||||||
|
dst.SeverityUnlocalized = v
|
||||||
|
case 'C':
|
||||||
|
dst.Code = v
|
||||||
|
case 'M':
|
||||||
|
dst.Message = v
|
||||||
|
case 'D':
|
||||||
|
dst.Detail = v
|
||||||
|
case 'H':
|
||||||
|
dst.Hint = v
|
||||||
|
case 'P':
|
||||||
|
s := v
|
||||||
|
n, _ := strconv.ParseInt(s, 10, 32)
|
||||||
|
dst.Position = int32(n)
|
||||||
|
case 'p':
|
||||||
|
s := v
|
||||||
|
n, _ := strconv.ParseInt(s, 10, 32)
|
||||||
|
dst.InternalPosition = int32(n)
|
||||||
|
case 'q':
|
||||||
|
dst.InternalQuery = v
|
||||||
|
case 'W':
|
||||||
|
dst.Where = v
|
||||||
|
case 's':
|
||||||
|
dst.SchemaName = v
|
||||||
|
case 't':
|
||||||
|
dst.TableName = v
|
||||||
|
case 'c':
|
||||||
|
dst.ColumnName = v
|
||||||
|
case 'd':
|
||||||
|
dst.DataTypeName = v
|
||||||
|
case 'n':
|
||||||
|
dst.ConstraintName = v
|
||||||
|
case 'F':
|
||||||
|
dst.File = v
|
||||||
|
case 'L':
|
||||||
|
s := v
|
||||||
|
n, _ := strconv.ParseInt(s, 10, 32)
|
||||||
|
dst.Line = int32(n)
|
||||||
|
case 'R':
|
||||||
|
dst.Routine = v
|
||||||
|
|
||||||
|
default:
|
||||||
|
if dst.UnknownFields == nil {
|
||||||
|
dst.UnknownFields = make(map[byte]string)
|
||||||
|
}
|
||||||
|
dst.UnknownFields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, src.marshalBinary('E')...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||||
|
var bigEndian BigEndianBuf
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
|
||||||
|
buf.WriteByte(typeByte)
|
||||||
|
buf.Write(bigEndian.Uint32(0))
|
||||||
|
|
||||||
|
if src.Severity != "" {
|
||||||
|
buf.WriteByte('S')
|
||||||
|
buf.WriteString(src.Severity)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.SeverityUnlocalized != "" {
|
||||||
|
buf.WriteByte('V')
|
||||||
|
buf.WriteString(src.SeverityUnlocalized)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Code != "" {
|
||||||
|
buf.WriteByte('C')
|
||||||
|
buf.WriteString(src.Code)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Message != "" {
|
||||||
|
buf.WriteByte('M')
|
||||||
|
buf.WriteString(src.Message)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Detail != "" {
|
||||||
|
buf.WriteByte('D')
|
||||||
|
buf.WriteString(src.Detail)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Hint != "" {
|
||||||
|
buf.WriteByte('H')
|
||||||
|
buf.WriteString(src.Hint)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Position != 0 {
|
||||||
|
buf.WriteByte('P')
|
||||||
|
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.InternalPosition != 0 {
|
||||||
|
buf.WriteByte('p')
|
||||||
|
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.InternalQuery != "" {
|
||||||
|
buf.WriteByte('q')
|
||||||
|
buf.WriteString(src.InternalQuery)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Where != "" {
|
||||||
|
buf.WriteByte('W')
|
||||||
|
buf.WriteString(src.Where)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.SchemaName != "" {
|
||||||
|
buf.WriteByte('s')
|
||||||
|
buf.WriteString(src.SchemaName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.TableName != "" {
|
||||||
|
buf.WriteByte('t')
|
||||||
|
buf.WriteString(src.TableName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.ColumnName != "" {
|
||||||
|
buf.WriteByte('c')
|
||||||
|
buf.WriteString(src.ColumnName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.DataTypeName != "" {
|
||||||
|
buf.WriteByte('d')
|
||||||
|
buf.WriteString(src.DataTypeName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.ConstraintName != "" {
|
||||||
|
buf.WriteByte('n')
|
||||||
|
buf.WriteString(src.ConstraintName)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.File != "" {
|
||||||
|
buf.WriteByte('F')
|
||||||
|
buf.WriteString(src.File)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Line != 0 {
|
||||||
|
buf.WriteByte('L')
|
||||||
|
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
if src.Routine != "" {
|
||||||
|
buf.WriteByte('R')
|
||||||
|
buf.WriteString(src.Routine)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range src.UnknownFields {
|
||||||
|
buf.WriteByte(k)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
buf.WriteString(v)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteByte(0)
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ErrorResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Severity string
|
||||||
|
SeverityUnlocalized string // only in 9.6 and greater
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
|
||||||
|
UnknownFields map[byte]string
|
||||||
|
}{
|
||||||
|
Type: "ErrorResponse",
|
||||||
|
Severity: src.Severity,
|
||||||
|
SeverityUnlocalized: src.SeverityUnlocalized,
|
||||||
|
Code: src.Code,
|
||||||
|
Message: src.Message,
|
||||||
|
Detail: src.Detail,
|
||||||
|
Hint: src.Hint,
|
||||||
|
Position: src.Position,
|
||||||
|
InternalPosition: src.InternalPosition,
|
||||||
|
InternalQuery: src.InternalQuery,
|
||||||
|
Where: src.Where,
|
||||||
|
SchemaName: src.SchemaName,
|
||||||
|
TableName: src.TableName,
|
||||||
|
ColumnName: src.ColumnName,
|
||||||
|
DataTypeName: src.DataTypeName,
|
||||||
|
ConstraintName: src.ConstraintName,
|
||||||
|
File: src.File,
|
||||||
|
Line: src.Line,
|
||||||
|
Routine: src.Routine,
|
||||||
|
UnknownFields: src.UnknownFields,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Type string
|
||||||
|
Severity string
|
||||||
|
SeverityUnlocalized string // only in 9.6 and greater
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
Detail string
|
||||||
|
Hint string
|
||||||
|
Position int32
|
||||||
|
InternalPosition int32
|
||||||
|
InternalQuery string
|
||||||
|
Where string
|
||||||
|
SchemaName string
|
||||||
|
TableName string
|
||||||
|
ColumnName string
|
||||||
|
DataTypeName string
|
||||||
|
ConstraintName string
|
||||||
|
File string
|
||||||
|
Line int32
|
||||||
|
Routine string
|
||||||
|
|
||||||
|
UnknownFields map[byte]string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Severity = msg.Severity
|
||||||
|
dst.SeverityUnlocalized = msg.SeverityUnlocalized
|
||||||
|
dst.Code = msg.Code
|
||||||
|
dst.Message = msg.Message
|
||||||
|
dst.Detail = msg.Detail
|
||||||
|
dst.Hint = msg.Hint
|
||||||
|
dst.Position = msg.Position
|
||||||
|
dst.InternalPosition = msg.InternalPosition
|
||||||
|
dst.InternalQuery = msg.InternalQuery
|
||||||
|
dst.Where = msg.Where
|
||||||
|
dst.SchemaName = msg.SchemaName
|
||||||
|
dst.TableName = msg.TableName
|
||||||
|
dst.ColumnName = msg.ColumnName
|
||||||
|
dst.DataTypeName = msg.DataTypeName
|
||||||
|
dst.ConstraintName = msg.ConstraintName
|
||||||
|
dst.File = msg.File
|
||||||
|
dst.Line = msg.Line
|
||||||
|
dst.Routine = msg.Routine
|
||||||
|
|
||||||
|
dst.UnknownFields = msg.UnknownFields
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,65 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Execute struct {
|
||||||
|
Portal string
|
||||||
|
MaxRows uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Execute) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Execute) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Portal = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
if buf.Len() < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Execute"}
|
||||||
|
}
|
||||||
|
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Execute) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'E')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Portal...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Execute) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Portal string
|
||||||
|
MaxRows uint32
|
||||||
|
}{
|
||||||
|
Type: "Execute",
|
||||||
|
Portal: src.Portal,
|
||||||
|
MaxRows: src.MaxRows,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Flush struct{}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Flush) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Flush) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Flush) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'H', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Flush) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "Flush",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,206 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
|
||||||
|
type Frontend struct {
|
||||||
|
cr ChunkReader
|
||||||
|
w io.Writer
|
||||||
|
|
||||||
|
// Backend message flyweights
|
||||||
|
authenticationOk AuthenticationOk
|
||||||
|
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||||
|
authenticationMD5Password AuthenticationMD5Password
|
||||||
|
authenticationGSS AuthenticationGSS
|
||||||
|
authenticationGSSContinue AuthenticationGSSContinue
|
||||||
|
authenticationSASL AuthenticationSASL
|
||||||
|
authenticationSASLContinue AuthenticationSASLContinue
|
||||||
|
authenticationSASLFinal AuthenticationSASLFinal
|
||||||
|
backendKeyData BackendKeyData
|
||||||
|
bindComplete BindComplete
|
||||||
|
closeComplete CloseComplete
|
||||||
|
commandComplete CommandComplete
|
||||||
|
copyBothResponse CopyBothResponse
|
||||||
|
copyData CopyData
|
||||||
|
copyInResponse CopyInResponse
|
||||||
|
copyOutResponse CopyOutResponse
|
||||||
|
copyDone CopyDone
|
||||||
|
dataRow DataRow
|
||||||
|
emptyQueryResponse EmptyQueryResponse
|
||||||
|
errorResponse ErrorResponse
|
||||||
|
functionCallResponse FunctionCallResponse
|
||||||
|
noData NoData
|
||||||
|
noticeResponse NoticeResponse
|
||||||
|
notificationResponse NotificationResponse
|
||||||
|
parameterDescription ParameterDescription
|
||||||
|
parameterStatus ParameterStatus
|
||||||
|
parseComplete ParseComplete
|
||||||
|
readyForQuery ReadyForQuery
|
||||||
|
rowDescription RowDescription
|
||||||
|
portalSuspended PortalSuspended
|
||||||
|
|
||||||
|
bodyLen int
|
||||||
|
msgType byte
|
||||||
|
partialMsg bool
|
||||||
|
authType uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFrontend creates a new Frontend.
|
||||||
|
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
|
||||||
|
return &Frontend{cr: cr, w: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send sends a message to the backend.
|
||||||
|
func (f *Frontend) Send(msg FrontendMessage) error {
|
||||||
|
_, err := f.w.Write(msg.Encode(nil))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateEOFtoErrUnexpectedEOF(err error) error {
|
||||||
|
if err == io.EOF {
|
||||||
|
return io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
|
||||||
|
func (f *Frontend) Receive() (BackendMessage, error) {
|
||||||
|
if !f.partialMsg {
|
||||||
|
header, err := f.cr.Next(5)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.msgType = header[0]
|
||||||
|
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||||
|
f.partialMsg = true
|
||||||
|
if f.bodyLen < 0 {
|
||||||
|
return nil, errors.New("invalid message with negative body length received")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msgBody, err := f.cr.Next(f.bodyLen)
|
||||||
|
if err != nil {
|
||||||
|
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.partialMsg = false
|
||||||
|
|
||||||
|
var msg BackendMessage
|
||||||
|
switch f.msgType {
|
||||||
|
case '1':
|
||||||
|
msg = &f.parseComplete
|
||||||
|
case '2':
|
||||||
|
msg = &f.bindComplete
|
||||||
|
case '3':
|
||||||
|
msg = &f.closeComplete
|
||||||
|
case 'A':
|
||||||
|
msg = &f.notificationResponse
|
||||||
|
case 'c':
|
||||||
|
msg = &f.copyDone
|
||||||
|
case 'C':
|
||||||
|
msg = &f.commandComplete
|
||||||
|
case 'd':
|
||||||
|
msg = &f.copyData
|
||||||
|
case 'D':
|
||||||
|
msg = &f.dataRow
|
||||||
|
case 'E':
|
||||||
|
msg = &f.errorResponse
|
||||||
|
case 'G':
|
||||||
|
msg = &f.copyInResponse
|
||||||
|
case 'H':
|
||||||
|
msg = &f.copyOutResponse
|
||||||
|
case 'I':
|
||||||
|
msg = &f.emptyQueryResponse
|
||||||
|
case 'K':
|
||||||
|
msg = &f.backendKeyData
|
||||||
|
case 'n':
|
||||||
|
msg = &f.noData
|
||||||
|
case 'N':
|
||||||
|
msg = &f.noticeResponse
|
||||||
|
case 'R':
|
||||||
|
var err error
|
||||||
|
msg, err = f.findAuthenticationMessageType(msgBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case 's':
|
||||||
|
msg = &f.portalSuspended
|
||||||
|
case 'S':
|
||||||
|
msg = &f.parameterStatus
|
||||||
|
case 't':
|
||||||
|
msg = &f.parameterDescription
|
||||||
|
case 'T':
|
||||||
|
msg = &f.rowDescription
|
||||||
|
case 'V':
|
||||||
|
msg = &f.functionCallResponse
|
||||||
|
case 'W':
|
||||||
|
msg = &f.copyBothResponse
|
||||||
|
case 'Z':
|
||||||
|
msg = &f.readyForQuery
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = msg.Decode(msgBody)
|
||||||
|
return msg, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authentication message type constants.
|
||||||
|
// See src/include/libpq/pqcomm.h for all
|
||||||
|
// constants.
|
||||||
|
const (
|
||||||
|
AuthTypeOk = 0
|
||||||
|
AuthTypeCleartextPassword = 3
|
||||||
|
AuthTypeMD5Password = 5
|
||||||
|
AuthTypeSCMCreds = 6
|
||||||
|
AuthTypeGSS = 7
|
||||||
|
AuthTypeGSSCont = 8
|
||||||
|
AuthTypeSSPI = 9
|
||||||
|
AuthTypeSASL = 10
|
||||||
|
AuthTypeSASLContinue = 11
|
||||||
|
AuthTypeSASLFinal = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return nil, errors.New("authentication message too short")
|
||||||
|
}
|
||||||
|
f.authType = binary.BigEndian.Uint32(src[:4])
|
||||||
|
|
||||||
|
switch f.authType {
|
||||||
|
case AuthTypeOk:
|
||||||
|
return &f.authenticationOk, nil
|
||||||
|
case AuthTypeCleartextPassword:
|
||||||
|
return &f.authenticationCleartextPassword, nil
|
||||||
|
case AuthTypeMD5Password:
|
||||||
|
return &f.authenticationMD5Password, nil
|
||||||
|
case AuthTypeSCMCreds:
|
||||||
|
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
||||||
|
case AuthTypeGSS:
|
||||||
|
return &f.authenticationGSS, nil
|
||||||
|
case AuthTypeGSSCont:
|
||||||
|
return &f.authenticationGSSContinue, nil
|
||||||
|
case AuthTypeSSPI:
|
||||||
|
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
||||||
|
case AuthTypeSASL:
|
||||||
|
return &f.authenticationSASL, nil
|
||||||
|
case AuthTypeSASLContinue:
|
||||||
|
return &f.authenticationSASLContinue, nil
|
||||||
|
case AuthTypeSASLFinal:
|
||||||
|
return &f.authenticationSASLFinal, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthType returns the authType used in the current state of the frontend.
|
||||||
|
// See SetAuthType for more information.
|
||||||
|
func (f *Frontend) GetAuthType() uint32 {
|
||||||
|
return f.authType
|
||||||
|
}
|
@ -0,0 +1,94 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Function uint32
|
||||||
|
ArgFormatCodes []uint16
|
||||||
|
Arguments [][]byte
|
||||||
|
ResultFormatCode uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*FunctionCall) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *FunctionCall) Decode(src []byte) error {
|
||||||
|
*dst = FunctionCall{}
|
||||||
|
rp := 0
|
||||||
|
// Specifies the object ID of the function to call.
|
||||||
|
dst.Function = binary.BigEndian.Uint32(src[rp:])
|
||||||
|
rp += 4
|
||||||
|
// The number of argument format codes that follow (denoted C below).
|
||||||
|
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
|
||||||
|
// or one, in which case the specified format code is applied to all arguments;
|
||||||
|
// or it can equal the actual number of arguments.
|
||||||
|
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
argumentCodes := make([]uint16, nArgumentCodes)
|
||||||
|
for i := 0; i < nArgumentCodes; i++ {
|
||||||
|
// The argument format codes. Each must presently be zero (text) or one (binary).
|
||||||
|
ac := binary.BigEndian.Uint16(src[rp:])
|
||||||
|
if ac != 0 && ac != 1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||||
|
}
|
||||||
|
argumentCodes[i] = ac
|
||||||
|
rp += 2
|
||||||
|
}
|
||||||
|
dst.ArgFormatCodes = argumentCodes
|
||||||
|
|
||||||
|
// Specifies the number of arguments being supplied to the function.
|
||||||
|
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
arguments := make([][]byte, nArguments)
|
||||||
|
for i := 0; i < nArguments; i++ {
|
||||||
|
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
|
||||||
|
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
|
||||||
|
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
if argumentLength == -1 {
|
||||||
|
arguments[i] = nil
|
||||||
|
} else {
|
||||||
|
// The value of the argument, in the format indicated by the associated format code. n is the above length.
|
||||||
|
argumentValue := src[rp : rp+argumentLength]
|
||||||
|
rp += argumentLength
|
||||||
|
arguments[i] = argumentValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst.Arguments = arguments
|
||||||
|
// The format code for the function result. Must presently be zero (text) or one (binary).
|
||||||
|
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
|
||||||
|
if resultFormatCode != 0 && resultFormatCode != 1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||||
|
}
|
||||||
|
dst.ResultFormatCode = resultFormatCode
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *FunctionCall) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'F')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
|
||||||
|
dst = pgio.AppendUint32(dst, src.Function)
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
|
||||||
|
for _, argFormatCode := range src.ArgFormatCodes {
|
||||||
|
dst = pgio.AppendUint16(dst, argFormatCode)
|
||||||
|
}
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
|
||||||
|
for _, argument := range src.Arguments {
|
||||||
|
if argument == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
} else {
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(argument)))
|
||||||
|
dst = append(dst, argument...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
return dst
|
||||||
|
}
|
@ -0,0 +1,101 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FunctionCallResponse struct {
|
||||||
|
Result []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*FunctionCallResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||||
|
}
|
||||||
|
rp := 0
|
||||||
|
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
if resultSize == -1 {
|
||||||
|
dst.Result = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src[rp:]) != resultSize {
|
||||||
|
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Result = src[rp:]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'V')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
if src.Result == nil {
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
} else {
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
|
||||||
|
dst = append(dst, src.Result...)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
var formattedValue map[string]string
|
||||||
|
var hasNonPrintable bool
|
||||||
|
for _, b := range src.Result {
|
||||||
|
if b < 32 {
|
||||||
|
hasNonPrintable = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasNonPrintable {
|
||||||
|
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
|
||||||
|
} else {
|
||||||
|
formattedValue = map[string]string{"text": string(src.Result)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Result map[string]string
|
||||||
|
}{
|
||||||
|
Type: "FunctionCallResponse",
|
||||||
|
Result: formattedValue,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
Result map[string]string
|
||||||
|
}
|
||||||
|
err := json.Unmarshal(data, &msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Result, err = getValueFromJSON(msg.Result)
|
||||||
|
return err
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const gssEncReqNumber = 80877104
|
||||||
|
|
||||||
|
type GSSEncRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*GSSEncRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *GSSEncRequest) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("gss encoding request too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != gssEncReqNumber {
|
||||||
|
return errors.New("bad gss encoding request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *GSSEncRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendInt32(dst, gssEncReqNumber)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src GSSEncRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "GSSEncRequest",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,48 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GSSResponse struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (g *GSSResponse) Frontend() {}
|
||||||
|
|
||||||
|
func (g *GSSResponse) Decode(data []byte) error {
|
||||||
|
g.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GSSResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
|
||||||
|
dst = append(dst, g.Data...)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data []byte
|
||||||
|
}{
|
||||||
|
Type: "GSSResponse",
|
||||||
|
Data: g.Data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
var msg struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
g.Data = msg.Data
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NoData struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*NoData) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *NoData) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *NoData) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'n', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src NoData) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "NoData",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,17 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
type NoticeResponse ErrorResponse
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*NoticeResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *NoticeResponse) Decode(src []byte) error {
|
||||||
|
return (*ErrorResponse)(dst).Decode(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
||||||
|
}
|
@ -0,0 +1,73 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type NotificationResponse struct {
|
||||||
|
PID uint32
|
||||||
|
Channel string
|
||||||
|
Payload string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*NotificationResponse) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *NotificationResponse) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
channel := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
b, err = buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'A')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, src.PID)
|
||||||
|
dst = append(dst, src.Channel...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.Payload...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src NotificationResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
PID uint32
|
||||||
|
Channel string
|
||||||
|
Payload string
|
||||||
|
}{
|
||||||
|
Type: "NotificationResponse",
|
||||||
|
PID: src.PID,
|
||||||
|
Channel: src.Channel,
|
||||||
|
Payload: src.Payload,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,66 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParameterDescription struct {
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ParameterDescription) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ParameterDescription) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
if buf.Len() < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reported parameter count will be incorrect when number of args is greater than uint16
|
||||||
|
buf.Next(2)
|
||||||
|
// Instead infer parameter count by remaining size of message
|
||||||
|
parameterCount := buf.Len() / 4
|
||||||
|
|
||||||
|
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
||||||
|
|
||||||
|
for i := 0; i < parameterCount; i++ {
|
||||||
|
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 't')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||||
|
for _, oid := range src.ParameterOIDs {
|
||||||
|
dst = pgio.AppendUint32(dst, oid)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ParameterDescription) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}{
|
||||||
|
Type: "ParameterDescription",
|
||||||
|
ParameterOIDs: src.ParameterOIDs,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,66 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParameterStatus struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ParameterStatus) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ParameterStatus) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
name := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
b, err = buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
value := string(b[:len(b)-1])
|
||||||
|
|
||||||
|
*dst = ParameterStatus{Name: name, Value: value}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'S')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.Value...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (ps ParameterStatus) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}{
|
||||||
|
Type: "ParameterStatus",
|
||||||
|
Name: ps.Name,
|
||||||
|
Value: ps.Value,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,88 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Parse struct {
|
||||||
|
Name string
|
||||||
|
Query string
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Parse) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Parse) Decode(src []byte) error {
|
||||||
|
*dst = Parse{}
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Name = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
b, err = buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Query = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
if buf.Len() < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||||
|
}
|
||||||
|
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||||
|
|
||||||
|
for i := 0; i < parameterOIDCount; i++ {
|
||||||
|
if buf.Len() < 4 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||||
|
}
|
||||||
|
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Parse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'P')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, src.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, src.Query...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||||
|
for _, oid := range src.ParameterOIDs {
|
||||||
|
dst = pgio.AppendUint32(dst, oid)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Parse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Name string
|
||||||
|
Query string
|
||||||
|
ParameterOIDs []uint32
|
||||||
|
}{
|
||||||
|
Type: "Parse",
|
||||||
|
Name: src.Name,
|
||||||
|
Query: src.Query,
|
||||||
|
ParameterOIDs: src.ParameterOIDs,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ParseComplete struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ParseComplete) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ParseComplete) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ParseComplete) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, '1', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ParseComplete) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "ParseComplete",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,54 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PasswordMessage struct {
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*PasswordMessage) Frontend() {}
|
||||||
|
|
||||||
|
// Frontend identifies this message as an authentication response.
|
||||||
|
func (*PasswordMessage) InitialResponse() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *PasswordMessage) Decode(src []byte) error {
|
||||||
|
buf := bytes.NewBuffer(src)
|
||||||
|
|
||||||
|
b, err := buf.ReadBytes(0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Password = string(b[:len(b)-1])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
||||||
|
|
||||||
|
dst = append(dst, src.Password...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src PasswordMessage) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Password string
|
||||||
|
}{
|
||||||
|
Type: "PasswordMessage",
|
||||||
|
Password: src.Password,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,65 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message is the interface implemented by an object that can decode and encode
|
||||||
|
// a particular PostgreSQL message.
|
||||||
|
type Message interface {
|
||||||
|
// Decode is allowed and expected to retain a reference to data after
|
||||||
|
// returning (unlike encoding.BinaryUnmarshaler).
|
||||||
|
Decode(data []byte) error
|
||||||
|
|
||||||
|
// Encode appends itself to dst and returns the new buffer.
|
||||||
|
Encode(dst []byte) []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type FrontendMessage interface {
|
||||||
|
Message
|
||||||
|
Frontend() // no-op method to distinguish frontend from backend methods
|
||||||
|
}
|
||||||
|
|
||||||
|
type BackendMessage interface {
|
||||||
|
Message
|
||||||
|
Backend() // no-op method to distinguish frontend from backend methods
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthenticationResponseMessage interface {
|
||||||
|
BackendMessage
|
||||||
|
AuthenticationResponse() // no-op method to distinguish authentication responses
|
||||||
|
}
|
||||||
|
|
||||||
|
type invalidMessageLenErr struct {
|
||||||
|
messageType string
|
||||||
|
expectedLen int
|
||||||
|
actualLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidMessageLenErr) Error() string {
|
||||||
|
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
type invalidMessageFormatErr struct {
|
||||||
|
messageType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidMessageFormatErr) Error() string {
|
||||||
|
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getValueFromJSON gets the value from a protocol message representation in JSON.
|
||||||
|
func getValueFromJSON(v map[string]string) ([]byte, error) {
|
||||||
|
if v == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if text, ok := v["text"]; ok {
|
||||||
|
return []byte(text), nil
|
||||||
|
}
|
||||||
|
if binary, ok := v["binary"]; ok {
|
||||||
|
return hex.DecodeString(binary)
|
||||||
|
}
|
||||||
|
return nil, errors.New("unknown protocol representation")
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PortalSuspended struct{}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*PortalSuspended) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *PortalSuspended) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *PortalSuspended) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 's', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src PortalSuspended) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "PortalSuspended",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,50 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Query struct {
|
||||||
|
String string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Query) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Query) Decode(src []byte) error {
|
||||||
|
i := bytes.IndexByte(src, 0)
|
||||||
|
if i != len(src)-1 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "Query"}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.String = string(src[:i])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Query) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'Q')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
|
||||||
|
|
||||||
|
dst = append(dst, src.String...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Query) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
String string
|
||||||
|
}{
|
||||||
|
Type: "Query",
|
||||||
|
String: src.String,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ReadyForQuery struct {
|
||||||
|
TxStatus byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*ReadyForQuery) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *ReadyForQuery) Decode(src []byte) error {
|
||||||
|
if len(src) != 1 {
|
||||||
|
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.TxStatus = src[0]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *ReadyForQuery) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src ReadyForQuery) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
TxStatus string
|
||||||
|
}{
|
||||||
|
Type: "ReadyForQuery",
|
||||||
|
TxStatus: string(src.TxStatus),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
TxStatus string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(msg.TxStatus) != 1 {
|
||||||
|
return errors.New("invalid length for ReadyForQuery.TxStatus")
|
||||||
|
}
|
||||||
|
dst.TxStatus = msg.TxStatus[0]
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,165 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TextFormat = 0
|
||||||
|
BinaryFormat = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
type FieldDescription struct {
|
||||||
|
Name []byte
|
||||||
|
TableOID uint32
|
||||||
|
TableAttributeNumber uint16
|
||||||
|
DataTypeOID uint32
|
||||||
|
DataTypeSize int16
|
||||||
|
TypeModifier int32
|
||||||
|
Format int16
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (fd FieldDescription) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Name string
|
||||||
|
TableOID uint32
|
||||||
|
TableAttributeNumber uint16
|
||||||
|
DataTypeOID uint32
|
||||||
|
DataTypeSize int16
|
||||||
|
TypeModifier int32
|
||||||
|
Format int16
|
||||||
|
}{
|
||||||
|
Name: string(fd.Name),
|
||||||
|
TableOID: fd.TableOID,
|
||||||
|
TableAttributeNumber: fd.TableAttributeNumber,
|
||||||
|
DataTypeOID: fd.DataTypeOID,
|
||||||
|
DataTypeSize: fd.DataTypeSize,
|
||||||
|
TypeModifier: fd.TypeModifier,
|
||||||
|
Format: fd.Format,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type RowDescription struct {
|
||||||
|
Fields []FieldDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
|
func (*RowDescription) Backend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *RowDescription) Decode(src []byte) error {
|
||||||
|
|
||||||
|
if len(src) < 2 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||||
|
}
|
||||||
|
fieldCount := int(binary.BigEndian.Uint16(src))
|
||||||
|
rp := 2
|
||||||
|
|
||||||
|
dst.Fields = dst.Fields[0:0]
|
||||||
|
|
||||||
|
for i := 0; i < fieldCount; i++ {
|
||||||
|
var fd FieldDescription
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||||
|
}
|
||||||
|
fd.Name = src[rp : rp+idx]
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
// Since buf.Next() doesn't return an error if we hit the end of the buffer
|
||||||
|
// check Len ahead of time
|
||||||
|
if len(src[rp:]) < 18 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||||
|
}
|
||||||
|
|
||||||
|
fd.TableOID = binary.BigEndian.Uint32(src[rp:])
|
||||||
|
rp += 4
|
||||||
|
fd.TableAttributeNumber = binary.BigEndian.Uint16(src[rp:])
|
||||||
|
rp += 2
|
||||||
|
fd.DataTypeOID = binary.BigEndian.Uint32(src[rp:])
|
||||||
|
rp += 4
|
||||||
|
fd.DataTypeSize = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
fd.TypeModifier = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
fd.Format = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||||
|
rp += 2
|
||||||
|
|
||||||
|
dst.Fields = append(dst.Fields, fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *RowDescription) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'T')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
|
||||||
|
for _, fd := range src.Fields {
|
||||||
|
dst = append(dst, fd.Name...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, fd.TableOID)
|
||||||
|
dst = pgio.AppendUint16(dst, fd.TableAttributeNumber)
|
||||||
|
dst = pgio.AppendUint32(dst, fd.DataTypeOID)
|
||||||
|
dst = pgio.AppendInt16(dst, fd.DataTypeSize)
|
||||||
|
dst = pgio.AppendInt32(dst, fd.TypeModifier)
|
||||||
|
dst = pgio.AppendInt16(dst, fd.Format)
|
||||||
|
}
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src RowDescription) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Fields []FieldDescription
|
||||||
|
}{
|
||||||
|
Type: "RowDescription",
|
||||||
|
Fields: src.Fields,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *RowDescription) UnmarshalJSON(data []byte) error {
|
||||||
|
var msg struct {
|
||||||
|
Fields []struct {
|
||||||
|
Name string
|
||||||
|
TableOID uint32
|
||||||
|
TableAttributeNumber uint16
|
||||||
|
DataTypeOID uint32
|
||||||
|
DataTypeSize int16
|
||||||
|
TypeModifier int32
|
||||||
|
Format int16
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Fields = make([]FieldDescription, len(msg.Fields))
|
||||||
|
for n, field := range msg.Fields {
|
||||||
|
dst.Fields[n] = FieldDescription{
|
||||||
|
Name: []byte(field.Name),
|
||||||
|
TableOID: field.TableOID,
|
||||||
|
TableAttributeNumber: field.TableAttributeNumber,
|
||||||
|
DataTypeOID: field.DataTypeOID,
|
||||||
|
DataTypeSize: field.DataTypeSize,
|
||||||
|
TypeModifier: field.TypeModifier,
|
||||||
|
Format: field.Format,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,87 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SASLInitialResponse struct {
|
||||||
|
AuthMechanism string
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*SASLInitialResponse) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *SASLInitialResponse) Decode(src []byte) error {
|
||||||
|
*dst = SASLInitialResponse{}
|
||||||
|
|
||||||
|
rp := 0
|
||||||
|
|
||||||
|
idx := bytes.IndexByte(src, 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return errors.New("invalid SASLInitialResponse")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.AuthMechanism = string(src[rp:idx])
|
||||||
|
rp = idx + 1
|
||||||
|
|
||||||
|
rp += 4 // The rest of the message is data so we can just skip the size
|
||||||
|
dst.Data = src[rp:]
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = append(dst, []byte(src.AuthMechanism)...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src SASLInitialResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
AuthMechanism string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "SASLInitialResponse",
|
||||||
|
AuthMechanism: src.AuthMechanism,
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
// Ignore null, like in the main JSON package.
|
||||||
|
if string(data) == "null" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg struct {
|
||||||
|
AuthMechanism string
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.AuthMechanism = msg.AuthMechanism
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,54 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SASLResponse struct {
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*SASLResponse) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *SASLResponse) Decode(src []byte) error {
|
||||||
|
*dst = SASLResponse{Data: src}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *SASLResponse) Encode(dst []byte) []byte {
|
||||||
|
dst = append(dst, 'p')
|
||||||
|
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||||
|
|
||||||
|
dst = append(dst, src.Data...)
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src SASLResponse) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
Data string
|
||||||
|
}{
|
||||||
|
Type: "SASLResponse",
|
||||||
|
Data: string(src.Data),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||||
|
func (dst *SASLResponse) UnmarshalJSON(data []byte) error {
|
||||||
|
var msg struct {
|
||||||
|
Data string
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dst.Data = []byte(msg.Data)
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,49 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const sslRequestNumber = 80877103
|
||||||
|
|
||||||
|
type SSLRequest struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*SSLRequest) Frontend() {}
|
||||||
|
|
||||||
|
func (dst *SSLRequest) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("ssl request too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
requestCode := binary.BigEndian.Uint32(src)
|
||||||
|
|
||||||
|
if requestCode != sslRequestNumber {
|
||||||
|
return errors.New("bad ssl request code")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||||
|
func (src *SSLRequest) Encode(dst []byte) []byte {
|
||||||
|
dst = pgio.AppendInt32(dst, 8)
|
||||||
|
dst = pgio.AppendInt32(dst, sslRequestNumber)
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src SSLRequest) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "SSLRequest",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,96 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ProtocolVersionNumber = 196608 // 3.0
|
||||||
|
|
||||||
|
type StartupMessage struct {
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*StartupMessage) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *StartupMessage) Decode(src []byte) error {
|
||||||
|
if len(src) < 4 {
|
||||||
|
return errors.New("startup message too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.ProtocolVersion = binary.BigEndian.Uint32(src)
|
||||||
|
rp := 4
|
||||||
|
|
||||||
|
if dst.ProtocolVersion != ProtocolVersionNumber {
|
||||||
|
return fmt.Errorf("Bad startup message version number. Expected %d, got %d", ProtocolVersionNumber, dst.ProtocolVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.Parameters = make(map[string]string)
|
||||||
|
for {
|
||||||
|
idx := bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||||
|
}
|
||||||
|
key := string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
idx = bytes.IndexByte(src[rp:], 0)
|
||||||
|
if idx < 0 {
|
||||||
|
return &invalidMessageFormatErr{messageType: "StartupMesage"}
|
||||||
|
}
|
||||||
|
value := string(src[rp : rp+idx])
|
||||||
|
rp += idx + 1
|
||||||
|
|
||||||
|
dst.Parameters[key] = value
|
||||||
|
|
||||||
|
if len(src[rp:]) == 1 {
|
||||||
|
if src[rp] != 0 {
|
||||||
|
return fmt.Errorf("Bad startup message last byte. Expected 0, got %d", src[rp])
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *StartupMessage) Encode(dst []byte) []byte {
|
||||||
|
sp := len(dst)
|
||||||
|
dst = pgio.AppendInt32(dst, -1)
|
||||||
|
|
||||||
|
dst = pgio.AppendUint32(dst, src.ProtocolVersion)
|
||||||
|
for k, v := range src.Parameters {
|
||||||
|
dst = append(dst, k...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
dst = append(dst, v...)
|
||||||
|
dst = append(dst, 0)
|
||||||
|
}
|
||||||
|
dst = append(dst, 0)
|
||||||
|
|
||||||
|
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||||
|
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src StartupMessage) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
ProtocolVersion uint32
|
||||||
|
Parameters map[string]string
|
||||||
|
}{
|
||||||
|
Type: "StartupMessage",
|
||||||
|
ProtocolVersion: src.ProtocolVersion,
|
||||||
|
Parameters: src.Parameters,
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Sync struct{}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Sync) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Sync) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "Sync", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Sync) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'S', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Sync) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "Sync",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,34 @@
|
|||||||
|
package pgproto3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Terminate struct{}
|
||||||
|
|
||||||
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
|
func (*Terminate) Frontend() {}
|
||||||
|
|
||||||
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
// type identifier and 4 byte message length.
|
||||||
|
func (dst *Terminate) Decode(src []byte) error {
|
||||||
|
if len(src) != 0 {
|
||||||
|
return &invalidMessageLenErr{messageType: "Terminate", expectedLen: 0, actualLen: len(src)}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||||
|
func (src *Terminate) Encode(dst []byte) []byte {
|
||||||
|
return append(dst, 'X', 0, 0, 0, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements encoding/json.Marshaler.
|
||||||
|
func (src Terminate) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
Type string
|
||||||
|
}{
|
||||||
|
Type: "Terminate",
|
||||||
|
})
|
||||||
|
}
|
@ -0,0 +1,164 @@
|
|||||||
|
# 1.14.0 (February 11, 2023)
|
||||||
|
|
||||||
|
* Fix: BC timestamp text format support (jozeflami)
|
||||||
|
* Add Scanner and Valuer interfaces to CIDR (Yurii Popivniak)
|
||||||
|
* Fix crash when nilifying pointer to sql.Scanner
|
||||||
|
|
||||||
|
# 1.13.0 (December 1, 2022)
|
||||||
|
|
||||||
|
* Fix: Reset jsonb before unmarshal (Tomas Odinas)
|
||||||
|
* Fix: return correct zero value when UUID conversion fails (ndrpnt)
|
||||||
|
* Fix: EncodeText for Lseg includes [ and ]
|
||||||
|
* Support sql Value and Scan for custom date type (Hubert Krauze)
|
||||||
|
* Support Ltree binary encoding (AmineChikhaoui)
|
||||||
|
* Fix: dates with "BC" (jozeflami)
|
||||||
|
|
||||||
|
# 1.12.0 (August 6, 2022)
|
||||||
|
|
||||||
|
* Add JSONArray (Jakob Ackermann)
|
||||||
|
* Support Inet from fmt.Stringer and encoding.TextMarshaler (Ville Skyttä)
|
||||||
|
* Support UUID from fmt.Stringer interface (Lasse Hyldahl Jensen)
|
||||||
|
* Fix: shopspring-numeric extension does not panic on NaN
|
||||||
|
* Numeric can be assigned to string
|
||||||
|
* Fix: Do not send IPv4 networks as IPv4-mapped IPv6 (William Storey)
|
||||||
|
* Fix: PlanScan for interface{}(nil) (James Hartig)
|
||||||
|
* Fix: *sql.Scanner for NULL handling (James Hartig)
|
||||||
|
* Timestamp[tz].Set() supports string (Harmen)
|
||||||
|
* Fix: Hstore AssignTo with map of *string (Diego Becciolini)
|
||||||
|
|
||||||
|
# 1.11.0 (April 21, 2022)
|
||||||
|
|
||||||
|
* Add multirange for numeric, int4, and int8 (Vu)
|
||||||
|
* JSONBArray now supports json.RawMessage (Jens Emil Schulz Østergaard)
|
||||||
|
* Add RecordArray (WGH)
|
||||||
|
* Add UnmarshalJSON to pgtype.Int2
|
||||||
|
* Hstore.Set accepts map[string]Text
|
||||||
|
|
||||||
|
# 1.10.0 (February 7, 2022)
|
||||||
|
|
||||||
|
* Normalize UTC timestamps to comply with stdlib (Torkel Rogstad)
|
||||||
|
* Assign Numeric to *big.Rat (Oleg Lomaka)
|
||||||
|
* Fix typo in float8 error message (Pinank Solanki)
|
||||||
|
* Scan type aliases for floating point types (Collin Forsyth)
|
||||||
|
|
||||||
|
# 1.9.1 (November 28, 2021)
|
||||||
|
|
||||||
|
* Fix: binary timestamp is assumed to be in UTC (restored behavior changed in v1.9.0)
|
||||||
|
|
||||||
|
# 1.9.0 (November 20, 2021)
|
||||||
|
|
||||||
|
* Fix binary hstore null decoding
|
||||||
|
* Add shopspring/decimal.NullDecimal support to integration (Eli Treuherz)
|
||||||
|
* Inet.Set supports bare IP address (Carl Dunham)
|
||||||
|
* Add zeronull.Float8
|
||||||
|
* Fix NULL being lost when scanning unknown OID into sql.Scanner
|
||||||
|
* Fix BPChar.AssignTo **rune
|
||||||
|
* Add support for fmt.Stringer and driver.Valuer in String fields encoding (Jan Dubsky)
|
||||||
|
* Fix really big timestamp(tz)s binary format parsing (e.g. year 294276) (Jim Tsao)
|
||||||
|
* Support `map[string]*string` as hstore (Adrian Sieger)
|
||||||
|
* Fix parsing text array with negative bounds
|
||||||
|
* Add infinity support for numeric (Jim Tsao)
|
||||||
|
|
||||||
|
# 1.8.1 (July 24, 2021)
|
||||||
|
|
||||||
|
* Cleaned up Go module dependency chain
|
||||||
|
|
||||||
|
# 1.8.0 (July 10, 2021)
|
||||||
|
|
||||||
|
* Maintain host bits for inet types (Cameron Daniel)
|
||||||
|
* Support pointers of wrapping structs (Ivan Daunis)
|
||||||
|
* Register JSONBArray at NewConnInfo() (Rueian)
|
||||||
|
* CompositeTextScanner handles backslash escapes
|
||||||
|
|
||||||
|
# 1.7.0 (March 25, 2021)
|
||||||
|
|
||||||
|
* Fix scanning int into **sql.Scanner implementor
|
||||||
|
* Add tsrange array type (Vasilii Novikov)
|
||||||
|
* Fix: escaped strings when they start or end with a newline char (Stephane Martin)
|
||||||
|
* Accept nil *time.Time in Time.Set
|
||||||
|
* Fix numeric NaN support
|
||||||
|
* Use Go 1.13 errors instead of xerrors
|
||||||
|
|
||||||
|
# 1.6.2 (December 3, 2020)
|
||||||
|
|
||||||
|
* Fix panic on assigning empty array to non-slice or array
|
||||||
|
* Fix text array parsing disambiguates NULL and "NULL"
|
||||||
|
* Fix Timestamptz.DecodeText with too short text
|
||||||
|
|
||||||
|
# 1.6.1 (October 31, 2020)
|
||||||
|
|
||||||
|
* Fix simple protocol empty array support
|
||||||
|
|
||||||
|
# 1.6.0 (October 24, 2020)
|
||||||
|
|
||||||
|
* Fix AssignTo pointer to pointer to slice and named types.
|
||||||
|
* Fix zero length array assignment (Simo Haasanen)
|
||||||
|
* Add float64, float32 convert to int2, int4, int8 (lqu3j)
|
||||||
|
* Support setting infinite timestamps (Erik Agsjö)
|
||||||
|
* Polygon improvements (duohedron)
|
||||||
|
* Fix Inet.Set with nil (Tomas Volf)
|
||||||
|
|
||||||
|
# 1.5.0 (September 26, 2020)
|
||||||
|
|
||||||
|
* Add slice of slice mapping to multi-dimensional arrays (Simo Haasanen)
|
||||||
|
* Fix JSONBArray
|
||||||
|
* Fix selecting empty array
|
||||||
|
* Text formatted values except bytea can be directly scanned to []byte
|
||||||
|
* Add JSON marshalling for UUID (bakmataliev)
|
||||||
|
* Improve point type conversions (bakmataliev)
|
||||||
|
|
||||||
|
# 1.4.2 (July 22, 2020)
|
||||||
|
|
||||||
|
* Fix encoding of a large composite data type (Yaz Saito)
|
||||||
|
|
||||||
|
# 1.4.1 (July 14, 2020)
|
||||||
|
|
||||||
|
* Fix ArrayType DecodeBinary empty array breaks future reads
|
||||||
|
|
||||||
|
# 1.4.0 (June 27, 2020)
|
||||||
|
|
||||||
|
* Add JSON support to ext/gofrs-uuid
|
||||||
|
* Performance improvements in Scan path
|
||||||
|
* Improved ext/shopspring-numeric binary decoding performance
|
||||||
|
* Add composite type support (Maxim Ivanov and Jack Christensen)
|
||||||
|
* Add better generic enum type support
|
||||||
|
* Add generic array type support
|
||||||
|
* Clarify and normalize Value semantics
|
||||||
|
* Fix hstore with empty string values
|
||||||
|
* Numeric supports NaN values (leighhopcroft)
|
||||||
|
* Add slice of pointer support to array types (megaturbo)
|
||||||
|
* Add jsonb array type (tserakhau)
|
||||||
|
* Allow converting intervals with months and days to duration
|
||||||
|
|
||||||
|
# 1.3.0 (March 30, 2020)
|
||||||
|
|
||||||
|
* Get implemented on T instead of *T
|
||||||
|
* Set will call Get on src if possible
|
||||||
|
* Range types Set method supports its own type, string, and nil
|
||||||
|
* Date.Set parses string
|
||||||
|
* Fix correct format verb for unknown type error (Robert Welin)
|
||||||
|
* Truncate nanoseconds in EncodeText for Timestamptz and Timestamp
|
||||||
|
|
||||||
|
# 1.2.0 (February 5, 2020)
|
||||||
|
|
||||||
|
* Add zeronull package for easier NULL <-> zero conversion
|
||||||
|
* Add JSON marshalling for shopspring-numeric extension
|
||||||
|
* Add JSON marshalling for Bool, Date, JSON/B, Timestamptz (Jeffrey Stiles)
|
||||||
|
* Fix null status in UnmarshalJSON for some types (Jeffrey Stiles)
|
||||||
|
|
||||||
|
# 1.1.0 (January 11, 2020)
|
||||||
|
|
||||||
|
* Add PostgreSQL time type support
|
||||||
|
* Add more automatic conversions of integer arrays of different types (Jean-Philippe Quéméner)
|
||||||
|
|
||||||
|
# 1.0.3 (November 16, 2019)
|
||||||
|
|
||||||
|
* Support initializing Array types from a slice of the value (Alex Gaynor)
|
||||||
|
|
||||||
|
# 1.0.2 (October 22, 2019)
|
||||||
|
|
||||||
|
* Fix scan into null into pointer to pointer implementing Decode* interface. (Jeremy Altavilla)
|
||||||
|
|
||||||
|
# 1.0.1 (September 19, 2019)
|
||||||
|
|
||||||
|
* Fix daterange OID
|
@ -0,0 +1,22 @@
|
|||||||
|
Copyright (c) 2013-2021 Jack Christensen
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,14 @@
|
|||||||
|
[![](https://godoc.org/github.com/jackc/pgtype?status.svg)](https://godoc.org/github.com/jackc/pgtype)
|
||||||
|
![CI](https://github.com/jackc/pgtype/workflows/CI/badge.svg)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# pgtype
|
||||||
|
|
||||||
|
pgtype implements Go types for over 70 PostgreSQL types. pgtype is the type system underlying the
|
||||||
|
https://github.com/jackc/pgx PostgreSQL driver. These types support the binary format for enhanced performance with pgx.
|
||||||
|
They also support the database/sql `Scan` and `Value` interfaces and can be used with https://github.com/lib/pq.
|
@ -0,0 +1,138 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ACLItem is used for PostgreSQL's aclitem data type. A sample aclitem
|
||||||
|
// might look like this:
|
||||||
|
//
|
||||||
|
// postgres=arwdDxt/postgres
|
||||||
|
//
|
||||||
|
// Note, however, that because the user/role name part of an aclitem is
|
||||||
|
// an identifier, it follows all the usual formatting rules for SQL
|
||||||
|
// identifiers: if it contains spaces and other special characters,
|
||||||
|
// it should appear in double-quotes:
|
||||||
|
//
|
||||||
|
// postgres=arwdDxt/"role with spaces"
|
||||||
|
//
|
||||||
|
type ACLItem struct {
|
||||||
|
String string
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ACLItem) Set(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = ACLItem{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := src.(interface{ Get() interface{} }); ok {
|
||||||
|
value2 := value.Get()
|
||||||
|
if value2 != value {
|
||||||
|
return dst.Set(value2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value := src.(type) {
|
||||||
|
case string:
|
||||||
|
*dst = ACLItem{String: value, Status: Present}
|
||||||
|
case *string:
|
||||||
|
if value == nil {
|
||||||
|
*dst = ACLItem{Status: Null}
|
||||||
|
} else {
|
||||||
|
*dst = ACLItem{String: *value, Status: Present}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if originalSrc, ok := underlyingStringType(src); ok {
|
||||||
|
return dst.Set(originalSrc)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot convert %v to ACLItem", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst ACLItem) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case Present:
|
||||||
|
return dst.String
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *ACLItem) AssignTo(dst interface{}) error {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
switch v := dst.(type) {
|
||||||
|
case *string:
|
||||||
|
*v = src.String
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||||
|
return src.AssignTo(nextDst)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to assign to %T", dst)
|
||||||
|
}
|
||||||
|
case Null:
|
||||||
|
return NullAssignTo(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot decode %#v into %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ACLItem) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = ACLItem{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = ACLItem{String: string(src), Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src ACLItem) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(buf, src.String...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *ACLItem) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = ACLItem{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
srcCopy := make([]byte, len(src))
|
||||||
|
copy(srcCopy, src)
|
||||||
|
return dst.DecodeText(nil, srcCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src ACLItem) Value() (driver.Value, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
return src.String, nil
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
default:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,428 @@
|
|||||||
|
// Code generated by erb. DO NOT EDIT.
|
||||||
|
|
||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ACLItemArray struct {
|
||||||
|
Elements []ACLItem
|
||||||
|
Dimensions []ArrayDimension
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ACLItemArray) Set(src interface{}) error {
|
||||||
|
// untyped nil and typed nil interfaces are different
|
||||||
|
if src == nil {
|
||||||
|
*dst = ACLItemArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := src.(interface{ Get() interface{} }); ok {
|
||||||
|
value2 := value.Get()
|
||||||
|
if value2 != value {
|
||||||
|
return dst.Set(value2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to match to select common types:
|
||||||
|
switch value := src.(type) {
|
||||||
|
|
||||||
|
case []string:
|
||||||
|
if value == nil {
|
||||||
|
*dst = ACLItemArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = ACLItemArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]ACLItem, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].Set(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*dst = ACLItemArray{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []*string:
|
||||||
|
if value == nil {
|
||||||
|
*dst = ACLItemArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = ACLItemArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]ACLItem, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].Set(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*dst = ACLItemArray{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []ACLItem:
|
||||||
|
if value == nil {
|
||||||
|
*dst = ACLItemArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = ACLItemArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
*dst = ACLItemArray{
|
||||||
|
Elements: value,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Fallback to reflection if an optimised match was not found.
|
||||||
|
// The reflection is necessary for arrays and multidimensional slices,
|
||||||
|
// but it comes with a 20-50% performance penalty for large arrays/slices
|
||||||
|
reflectedValue := reflect.ValueOf(src)
|
||||||
|
if !reflectedValue.IsValid() || reflectedValue.IsZero() {
|
||||||
|
*dst = ACLItemArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("cannot find dimensions of %v for ACLItemArray", src)
|
||||||
|
}
|
||||||
|
if elementsLength == 0 {
|
||||||
|
*dst = ACLItemArray{Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(dimensions) == 0 {
|
||||||
|
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||||
|
return dst.Set(originalSrc)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot convert %v to ACLItemArray", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = ACLItemArray{
|
||||||
|
Elements: make([]ACLItem, elementsLength),
|
||||||
|
Dimensions: dimensions,
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
// Maybe the target was one dimension too far, try again:
|
||||||
|
if len(dst.Dimensions) > 1 {
|
||||||
|
dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
|
||||||
|
elementsLength = 0
|
||||||
|
for _, dim := range dst.Dimensions {
|
||||||
|
if elementsLength == 0 {
|
||||||
|
elementsLength = int(dim.Length)
|
||||||
|
} else {
|
||||||
|
elementsLength *= int(dim.Length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst.Elements = make([]ACLItem, elementsLength)
|
||||||
|
elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if elementCount != len(dst.Elements) {
|
||||||
|
return fmt.Errorf("cannot convert %v to ACLItemArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ACLItemArray) setRecursive(value reflect.Value, index, dimension int) (int, error) {
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Slice:
|
||||||
|
if len(dst.Dimensions) == dimension {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
valueLen := value.Len()
|
||||||
|
if int32(valueLen) != dst.Dimensions[dimension].Length {
|
||||||
|
return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
|
||||||
|
}
|
||||||
|
for i := 0; i < valueLen; i++ {
|
||||||
|
var err error
|
||||||
|
index, err = dst.setRecursive(value.Index(i), index, dimension+1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
if !value.CanInterface() {
|
||||||
|
return 0, fmt.Errorf("cannot convert all values to ACLItemArray")
|
||||||
|
}
|
||||||
|
if err := dst.Elements[index].Set(value.Interface()); err != nil {
|
||||||
|
return 0, fmt.Errorf("%v in ACLItemArray", err)
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst ACLItemArray) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case Present:
|
||||||
|
return dst
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *ACLItemArray) AssignTo(dst interface{}) error {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
if len(src.Dimensions) <= 1 {
|
||||||
|
// Attempt to match to select common types:
|
||||||
|
switch v := dst.(type) {
|
||||||
|
|
||||||
|
case *[]string:
|
||||||
|
*v = make([]string, len(src.Elements))
|
||||||
|
for i := range src.Elements {
|
||||||
|
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case *[]*string:
|
||||||
|
*v = make([]*string, len(src.Elements))
|
||||||
|
for i := range src.Elements {
|
||||||
|
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to convert to something AssignTo can use directly.
|
||||||
|
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||||
|
return src.AssignTo(nextDst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to reflection if an optimised match was not found.
|
||||||
|
// The reflection is necessary for arrays and multidimensional slices,
|
||||||
|
// but it comes with a 20-50% performance penalty for large arrays/slices
|
||||||
|
value := reflect.ValueOf(dst)
|
||||||
|
if value.Kind() == reflect.Ptr {
|
||||||
|
value = value.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot assign %T to %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.Elements) == 0 {
|
||||||
|
if value.Kind() == reflect.Slice {
|
||||||
|
value.Set(reflect.MakeSlice(value.Type(), 0, 0))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount, err := src.assignToRecursive(value, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if elementCount != len(src.Elements) {
|
||||||
|
return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case Null:
|
||||||
|
return NullAssignTo(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot decode %#v into %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *ACLItemArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
|
||||||
|
switch kind := value.Kind(); kind {
|
||||||
|
case reflect.Array:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Slice:
|
||||||
|
if len(src.Dimensions) == dimension {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
length := int(src.Dimensions[dimension].Length)
|
||||||
|
if reflect.Array == kind {
|
||||||
|
typ := value.Type()
|
||||||
|
if typ.Len() != length {
|
||||||
|
return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
|
||||||
|
}
|
||||||
|
value.Set(reflect.New(typ).Elem())
|
||||||
|
} else {
|
||||||
|
value.Set(reflect.MakeSlice(value.Type(), length, length))
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
if len(src.Dimensions) != dimension {
|
||||||
|
return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
|
||||||
|
}
|
||||||
|
if !value.CanAddr() {
|
||||||
|
return 0, fmt.Errorf("cannot assign all values from ACLItemArray")
|
||||||
|
}
|
||||||
|
addr := value.Addr()
|
||||||
|
if !addr.CanInterface() {
|
||||||
|
return 0, fmt.Errorf("cannot assign all values from ACLItemArray")
|
||||||
|
}
|
||||||
|
if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ACLItemArray) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = ACLItemArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
uta, err := ParseUntypedTextArray(string(src))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var elements []ACLItem
|
||||||
|
|
||||||
|
if len(uta.Elements) > 0 {
|
||||||
|
elements = make([]ACLItem, len(uta.Elements))
|
||||||
|
|
||||||
|
for i, s := range uta.Elements {
|
||||||
|
var elem ACLItem
|
||||||
|
var elemSrc []byte
|
||||||
|
if s != "NULL" || uta.Quoted[i] {
|
||||||
|
elemSrc = []byte(s)
|
||||||
|
}
|
||||||
|
err = elem.DecodeText(ci, elemSrc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
elements[i] = elem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = ACLItemArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src ACLItemArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.Dimensions) == 0 {
|
||||||
|
return append(buf, '{', '}'), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||||
|
|
||||||
|
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||||
|
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||||
|
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||||
|
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||||
|
// or '}'.
|
||||||
|
dimElemCounts := make([]int, len(src.Dimensions))
|
||||||
|
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||||
|
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||||
|
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
inElemBuf := make([]byte, 0, 32)
|
||||||
|
for i, elem := range src.Elements {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if i%dec == 0 {
|
||||||
|
buf = append(buf, '{')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf == nil {
|
||||||
|
buf = append(buf, `NULL`...)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if (i+1)%dec == 0 {
|
||||||
|
buf = append(buf, '}')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *ACLItemArray) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
return dst.DecodeText(nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
srcCopy := make([]byte, len(src))
|
||||||
|
copy(srcCopy, src)
|
||||||
|
return dst.DecodeText(nil, srcCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src ACLItemArray) Value() (driver.Value, error) {
|
||||||
|
buf, err := src.EncodeText(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if buf == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(buf), nil
|
||||||
|
}
|
@ -0,0 +1,381 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Information on the internals of PostgreSQL arrays can be found in
|
||||||
|
// src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of
|
||||||
|
// particular interest is the array_send function.
|
||||||
|
|
||||||
|
type ArrayHeader struct {
|
||||||
|
ContainsNull bool
|
||||||
|
ElementOID int32
|
||||||
|
Dimensions []ArrayDimension
|
||||||
|
}
|
||||||
|
|
||||||
|
type ArrayDimension struct {
|
||||||
|
Length int32
|
||||||
|
LowerBound int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ArrayHeader) DecodeBinary(ci *ConnInfo, src []byte) (int, error) {
|
||||||
|
if len(src) < 12 {
|
||||||
|
return 0, fmt.Errorf("array header too short: %d", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
rp := 0
|
||||||
|
|
||||||
|
numDims := int(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
dst.ElementOID = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
if numDims > 0 {
|
||||||
|
dst.Dimensions = make([]ArrayDimension, numDims)
|
||||||
|
}
|
||||||
|
if len(src) < 12+numDims*8 {
|
||||||
|
return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src))
|
||||||
|
}
|
||||||
|
for i := range dst.Dimensions {
|
||||||
|
dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
|
||||||
|
dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:]))
|
||||||
|
rp += 4
|
||||||
|
}
|
||||||
|
|
||||||
|
return rp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src ArrayHeader) EncodeBinary(ci *ConnInfo, buf []byte) []byte {
|
||||||
|
buf = pgio.AppendInt32(buf, int32(len(src.Dimensions)))
|
||||||
|
|
||||||
|
var containsNull int32
|
||||||
|
if src.ContainsNull {
|
||||||
|
containsNull = 1
|
||||||
|
}
|
||||||
|
buf = pgio.AppendInt32(buf, containsNull)
|
||||||
|
|
||||||
|
buf = pgio.AppendInt32(buf, src.ElementOID)
|
||||||
|
|
||||||
|
for i := range src.Dimensions {
|
||||||
|
buf = pgio.AppendInt32(buf, src.Dimensions[i].Length)
|
||||||
|
buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
type UntypedTextArray struct {
|
||||||
|
Elements []string
|
||||||
|
Quoted []bool
|
||||||
|
Dimensions []ArrayDimension
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseUntypedTextArray(src string) (*UntypedTextArray, error) {
|
||||||
|
dst := &UntypedTextArray{}
|
||||||
|
|
||||||
|
buf := bytes.NewBufferString(src)
|
||||||
|
|
||||||
|
skipWhitespace(buf)
|
||||||
|
|
||||||
|
r, _, err := buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var explicitDimensions []ArrayDimension
|
||||||
|
|
||||||
|
// Array has explicit dimensions
|
||||||
|
if r == '[' {
|
||||||
|
buf.UnreadRune()
|
||||||
|
|
||||||
|
for {
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == '=' {
|
||||||
|
break
|
||||||
|
} else if r != '[' {
|
||||||
|
return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
lower, err := arrayParseInteger(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r != ':' {
|
||||||
|
return nil, fmt.Errorf("invalid array, expected ':' got %v", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
upper, err := arrayParseInteger(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r != ']' {
|
||||||
|
return nil, fmt.Errorf("invalid array, expected ']' got %v", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1})
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if r != '{' {
|
||||||
|
return nil, fmt.Errorf("invalid array, expected '{': %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}}
|
||||||
|
|
||||||
|
// Consume all initial opening brackets. This provides number of dimensions.
|
||||||
|
for {
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r == '{' {
|
||||||
|
implicitDimensions[len(implicitDimensions)-1].Length = 1
|
||||||
|
implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1})
|
||||||
|
} else {
|
||||||
|
buf.UnreadRune()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
currentDim := len(implicitDimensions) - 1
|
||||||
|
counterDim := currentDim
|
||||||
|
|
||||||
|
for {
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '{':
|
||||||
|
if currentDim == counterDim {
|
||||||
|
implicitDimensions[currentDim].Length++
|
||||||
|
}
|
||||||
|
currentDim++
|
||||||
|
case ',':
|
||||||
|
case '}':
|
||||||
|
currentDim--
|
||||||
|
if currentDim < counterDim {
|
||||||
|
counterDim = currentDim
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
buf.UnreadRune()
|
||||||
|
value, quoted, err := arrayParseValue(buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid array value: %v", err)
|
||||||
|
}
|
||||||
|
if currentDim == counterDim {
|
||||||
|
implicitDimensions[currentDim].Length++
|
||||||
|
}
|
||||||
|
dst.Quoted = append(dst.Quoted, quoted)
|
||||||
|
dst.Elements = append(dst.Elements, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentDim < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
skipWhitespace(buf)
|
||||||
|
|
||||||
|
if buf.Len() > 0 {
|
||||||
|
return nil, fmt.Errorf("unexpected trailing data: %v", buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(dst.Elements) == 0 {
|
||||||
|
dst.Dimensions = nil
|
||||||
|
} else if len(explicitDimensions) > 0 {
|
||||||
|
dst.Dimensions = explicitDimensions
|
||||||
|
} else {
|
||||||
|
dst.Dimensions = implicitDimensions
|
||||||
|
}
|
||||||
|
|
||||||
|
return dst, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func skipWhitespace(buf *bytes.Buffer) {
|
||||||
|
var r rune
|
||||||
|
var err error
|
||||||
|
for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() {
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != io.EOF {
|
||||||
|
buf.UnreadRune()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func arrayParseValue(buf *bytes.Buffer) (string, bool, error) {
|
||||||
|
r, _, err := buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
if r == '"' {
|
||||||
|
return arrayParseQuotedValue(buf)
|
||||||
|
}
|
||||||
|
buf.UnreadRune()
|
||||||
|
|
||||||
|
s := &bytes.Buffer{}
|
||||||
|
|
||||||
|
for {
|
||||||
|
r, _, err := buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case ',', '}':
|
||||||
|
buf.UnreadRune()
|
||||||
|
return s.String(), false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) {
|
||||||
|
s := &bytes.Buffer{}
|
||||||
|
|
||||||
|
for {
|
||||||
|
r, _, err := buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r {
|
||||||
|
case '\\':
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
case '"':
|
||||||
|
r, _, err = buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return "", false, err
|
||||||
|
}
|
||||||
|
buf.UnreadRune()
|
||||||
|
return s.String(), true, nil
|
||||||
|
}
|
||||||
|
s.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func arrayParseInteger(buf *bytes.Buffer) (int32, error) {
|
||||||
|
s := &bytes.Buffer{}
|
||||||
|
|
||||||
|
for {
|
||||||
|
r, _, err := buf.ReadRune()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ('0' <= r && r <= '9') || r == '-' {
|
||||||
|
s.WriteRune(r)
|
||||||
|
} else {
|
||||||
|
buf.UnreadRune()
|
||||||
|
n, err := strconv.ParseInt(s.String(), 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return int32(n), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func EncodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte {
|
||||||
|
var customDimensions bool
|
||||||
|
for _, dim := range dimensions {
|
||||||
|
if dim.LowerBound != 1 {
|
||||||
|
customDimensions = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !customDimensions {
|
||||||
|
return buf
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dim := range dimensions {
|
||||||
|
buf = append(buf, '[')
|
||||||
|
buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...)
|
||||||
|
buf = append(buf, ':')
|
||||||
|
buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...)
|
||||||
|
buf = append(buf, ']')
|
||||||
|
}
|
||||||
|
|
||||||
|
return append(buf, '=')
|
||||||
|
}
|
||||||
|
|
||||||
|
var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
|
||||||
|
|
||||||
|
func quoteArrayElement(src string) string {
|
||||||
|
return `"` + quoteArrayReplacer.Replace(src) + `"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSpace(ch byte) bool {
|
||||||
|
// see https://github.com/postgres/postgres/blob/REL_12_STABLE/src/backend/parser/scansup.c#L224
|
||||||
|
return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\f'
|
||||||
|
}
|
||||||
|
|
||||||
|
func QuoteArrayElementIfNeeded(src string) string {
|
||||||
|
if src == "" || (len(src) == 4 && strings.ToLower(src) == "null") || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) {
|
||||||
|
return quoteArrayElement(src)
|
||||||
|
}
|
||||||
|
return src
|
||||||
|
}
|
||||||
|
|
||||||
|
func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) {
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Slice:
|
||||||
|
length := value.Len()
|
||||||
|
if 0 == elementsLength {
|
||||||
|
elementsLength = length
|
||||||
|
} else {
|
||||||
|
elementsLength *= length
|
||||||
|
}
|
||||||
|
dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1})
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok {
|
||||||
|
return d, l, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dimensions, elementsLength, true
|
||||||
|
}
|
@ -0,0 +1,353 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ArrayType represents an array type. While it implements Value, this is only in service of its type conversion duties
|
||||||
|
// when registered as a data type in a ConnType. It should not be used directly as a Value. ArrayType is a convenience
|
||||||
|
// type for types that do not have a concrete array type.
|
||||||
|
type ArrayType struct {
|
||||||
|
elements []ValueTranscoder
|
||||||
|
dimensions []ArrayDimension
|
||||||
|
|
||||||
|
typeName string
|
||||||
|
newElement func() ValueTranscoder
|
||||||
|
|
||||||
|
elementOID uint32
|
||||||
|
status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewArrayType(typeName string, elementOID uint32, newElement func() ValueTranscoder) *ArrayType {
|
||||||
|
return &ArrayType{typeName: typeName, elementOID: elementOID, newElement: newElement}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (at *ArrayType) NewTypeValue() Value {
|
||||||
|
return &ArrayType{
|
||||||
|
elements: at.elements,
|
||||||
|
dimensions: at.dimensions,
|
||||||
|
status: at.status,
|
||||||
|
|
||||||
|
typeName: at.typeName,
|
||||||
|
elementOID: at.elementOID,
|
||||||
|
newElement: at.newElement,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (at *ArrayType) TypeName() string {
|
||||||
|
return at.typeName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ArrayType) setNil() {
|
||||||
|
dst.elements = nil
|
||||||
|
dst.dimensions = nil
|
||||||
|
dst.status = Null
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ArrayType) Set(src interface{}) error {
|
||||||
|
// untyped nil and typed nil interfaces are different
|
||||||
|
if src == nil {
|
||||||
|
dst.setNil()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sliceVal := reflect.ValueOf(src)
|
||||||
|
if sliceVal.Kind() != reflect.Slice {
|
||||||
|
return fmt.Errorf("cannot set non-slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
if sliceVal.IsNil() {
|
||||||
|
dst.setNil()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.elements = make([]ValueTranscoder, sliceVal.Len())
|
||||||
|
for i := range dst.elements {
|
||||||
|
v := dst.newElement()
|
||||||
|
err := v.Set(sliceVal.Index(i).Interface())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.elements[i] = v
|
||||||
|
}
|
||||||
|
dst.dimensions = []ArrayDimension{{Length: int32(len(dst.elements)), LowerBound: 1}}
|
||||||
|
dst.status = Present
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst ArrayType) Get() interface{} {
|
||||||
|
switch dst.status {
|
||||||
|
case Present:
|
||||||
|
elementValues := make([]interface{}, len(dst.elements))
|
||||||
|
for i := range dst.elements {
|
||||||
|
elementValues[i] = dst.elements[i].Get()
|
||||||
|
}
|
||||||
|
return elementValues
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *ArrayType) AssignTo(dst interface{}) error {
|
||||||
|
ptrSlice := reflect.ValueOf(dst)
|
||||||
|
if ptrSlice.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("cannot assign to non-pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
sliceVal := ptrSlice.Elem()
|
||||||
|
sliceType := sliceVal.Type()
|
||||||
|
|
||||||
|
if sliceType.Kind() != reflect.Slice {
|
||||||
|
return fmt.Errorf("cannot assign to pointer to non-slice")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src.status {
|
||||||
|
case Present:
|
||||||
|
slice := reflect.MakeSlice(sliceType, len(src.elements), len(src.elements))
|
||||||
|
elemType := sliceType.Elem()
|
||||||
|
|
||||||
|
for i := range src.elements {
|
||||||
|
ptrElem := reflect.New(elemType)
|
||||||
|
err := src.elements[i].AssignTo(ptrElem.Interface())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
slice.Index(i).Set(ptrElem.Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
sliceVal.Set(slice)
|
||||||
|
return nil
|
||||||
|
case Null:
|
||||||
|
sliceVal.Set(reflect.Zero(sliceType))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot decode %#v into %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ArrayType) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
dst.setNil()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
uta, err := ParseUntypedTextArray(string(src))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var elements []ValueTranscoder
|
||||||
|
|
||||||
|
if len(uta.Elements) > 0 {
|
||||||
|
elements = make([]ValueTranscoder, len(uta.Elements))
|
||||||
|
|
||||||
|
for i, s := range uta.Elements {
|
||||||
|
elem := dst.newElement()
|
||||||
|
var elemSrc []byte
|
||||||
|
if s != "NULL" {
|
||||||
|
elemSrc = []byte(s)
|
||||||
|
}
|
||||||
|
err = elem.DecodeText(ci, elemSrc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
elements[i] = elem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.elements = elements
|
||||||
|
dst.dimensions = uta.Dimensions
|
||||||
|
dst.status = Present
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *ArrayType) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
dst.setNil()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var arrayHeader ArrayHeader
|
||||||
|
rp, err := arrayHeader.DecodeBinary(ci, src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var elements []ValueTranscoder
|
||||||
|
|
||||||
|
if len(arrayHeader.Dimensions) == 0 {
|
||||||
|
dst.elements = elements
|
||||||
|
dst.dimensions = arrayHeader.Dimensions
|
||||||
|
dst.status = Present
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount := arrayHeader.Dimensions[0].Length
|
||||||
|
for _, d := range arrayHeader.Dimensions[1:] {
|
||||||
|
elementCount *= d.Length
|
||||||
|
}
|
||||||
|
|
||||||
|
elements = make([]ValueTranscoder, elementCount)
|
||||||
|
|
||||||
|
for i := range elements {
|
||||||
|
elem := dst.newElement()
|
||||||
|
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
var elemSrc []byte
|
||||||
|
if elemLen >= 0 {
|
||||||
|
elemSrc = src[rp : rp+elemLen]
|
||||||
|
rp += elemLen
|
||||||
|
}
|
||||||
|
err = elem.DecodeBinary(ci, elemSrc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
elements[i] = elem
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.elements = elements
|
||||||
|
dst.dimensions = arrayHeader.Dimensions
|
||||||
|
dst.status = Present
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src ArrayType) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.dimensions) == 0 {
|
||||||
|
return append(buf, '{', '}'), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = EncodeTextArrayDimensions(buf, src.dimensions)
|
||||||
|
|
||||||
|
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||||
|
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||||
|
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||||
|
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||||
|
// or '}'.
|
||||||
|
dimElemCounts := make([]int, len(src.dimensions))
|
||||||
|
dimElemCounts[len(src.dimensions)-1] = int(src.dimensions[len(src.dimensions)-1].Length)
|
||||||
|
for i := len(src.dimensions) - 2; i > -1; i-- {
|
||||||
|
dimElemCounts[i] = int(src.dimensions[i].Length) * dimElemCounts[i+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
inElemBuf := make([]byte, 0, 32)
|
||||||
|
for i, elem := range src.elements {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if i%dec == 0 {
|
||||||
|
buf = append(buf, '{')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf == nil {
|
||||||
|
buf = append(buf, `NULL`...)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if (i+1)%dec == 0 {
|
||||||
|
buf = append(buf, '}')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src ArrayType) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayHeader := ArrayHeader{
|
||||||
|
Dimensions: src.dimensions,
|
||||||
|
ElementOID: int32(src.elementOID),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range src.elements {
|
||||||
|
if src.elements[i].Get() == nil {
|
||||||
|
arrayHeader.ContainsNull = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = arrayHeader.EncodeBinary(ci, buf)
|
||||||
|
|
||||||
|
for i := range src.elements {
|
||||||
|
sp := len(buf)
|
||||||
|
buf = pgio.AppendInt32(buf, -1)
|
||||||
|
|
||||||
|
elemBuf, err := src.elements[i].EncodeBinary(ci, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf != nil {
|
||||||
|
buf = elemBuf
|
||||||
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *ArrayType) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
return dst.DecodeText(nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
srcCopy := make([]byte, len(src))
|
||||||
|
copy(srcCopy, src)
|
||||||
|
return dst.DecodeText(nil, srcCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src ArrayType) Value() (driver.Value, error) {
|
||||||
|
buf, err := src.EncodeText(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if buf == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(buf), nil
|
||||||
|
}
|
@ -0,0 +1,45 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bit Varbit
|
||||||
|
|
||||||
|
func (dst *Bit) Set(src interface{}) error {
|
||||||
|
return (*Varbit)(dst).Set(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst Bit) Get() interface{} {
|
||||||
|
return (Varbit)(dst).Get()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *Bit) AssignTo(dst interface{}) error {
|
||||||
|
return (*Varbit)(src).AssignTo(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Bit) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
return (*Varbit)(dst).DecodeBinary(ci, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src Bit) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
return (Varbit)(src).EncodeBinary(ci, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Bit) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
return (*Varbit)(dst).DecodeText(ci, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src Bit) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
return (Varbit)(src).EncodeText(ci, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *Bit) Scan(src interface{}) error {
|
||||||
|
return (*Varbit)(dst).Scan(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src Bit) Value() (driver.Value, error) {
|
||||||
|
return (Varbit)(src).Value()
|
||||||
|
}
|
@ -0,0 +1,217 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Bool struct {
|
||||||
|
Bool bool
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Bool) Set(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Bool{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := src.(interface{ Get() interface{} }); ok {
|
||||||
|
value2 := value.Get()
|
||||||
|
if value2 != value {
|
||||||
|
return dst.Set(value2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value := src.(type) {
|
||||||
|
case bool:
|
||||||
|
*dst = Bool{Bool: value, Status: Present}
|
||||||
|
case string:
|
||||||
|
bb, err := strconv.ParseBool(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*dst = Bool{Bool: bb, Status: Present}
|
||||||
|
case *bool:
|
||||||
|
if value == nil {
|
||||||
|
*dst = Bool{Status: Null}
|
||||||
|
} else {
|
||||||
|
return dst.Set(*value)
|
||||||
|
}
|
||||||
|
case *string:
|
||||||
|
if value == nil {
|
||||||
|
*dst = Bool{Status: Null}
|
||||||
|
} else {
|
||||||
|
return dst.Set(*value)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if originalSrc, ok := underlyingBoolType(src); ok {
|
||||||
|
return dst.Set(originalSrc)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot convert %v to Bool", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst Bool) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case Present:
|
||||||
|
return dst.Bool
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *Bool) AssignTo(dst interface{}) error {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
switch v := dst.(type) {
|
||||||
|
case *bool:
|
||||||
|
*v = src.Bool
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||||
|
return src.AssignTo(nextDst)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to assign to %T", dst)
|
||||||
|
}
|
||||||
|
case Null:
|
||||||
|
return NullAssignTo(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot decode %#v into %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Bool) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Bool{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src) != 1 {
|
||||||
|
return fmt.Errorf("invalid length for bool: %v", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = Bool{Bool: src[0] == 't', Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Bool) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Bool{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src) != 1 {
|
||||||
|
return fmt.Errorf("invalid length for bool: %v", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = Bool{Bool: src[0] == 1, Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src Bool) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if src.Bool {
|
||||||
|
buf = append(buf, 't')
|
||||||
|
} else {
|
||||||
|
buf = append(buf, 'f')
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src Bool) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if src.Bool {
|
||||||
|
buf = append(buf, 1)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *Bool) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Bool{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case bool:
|
||||||
|
*dst = Bool{Bool: src, Status: Present}
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
srcCopy := make([]byte, len(src))
|
||||||
|
copy(srcCopy, src)
|
||||||
|
return dst.DecodeText(nil, srcCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src Bool) Value() (driver.Value, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
return src.Bool, nil
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
default:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src Bool) MarshalJSON() ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
if src.Bool {
|
||||||
|
return []byte("true"), nil
|
||||||
|
} else {
|
||||||
|
return []byte("false"), nil
|
||||||
|
}
|
||||||
|
case Null:
|
||||||
|
return []byte("null"), nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errBadStatus
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Bool) UnmarshalJSON(b []byte) error {
|
||||||
|
var v *bool
|
||||||
|
err := json.Unmarshal(b, &v)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if v == nil {
|
||||||
|
*dst = Bool{Status: Null}
|
||||||
|
} else {
|
||||||
|
*dst = Bool{Bool: *v, Status: Present}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,517 @@
|
|||||||
|
// Code generated by erb. DO NOT EDIT.
|
||||||
|
|
||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BoolArray struct {
|
||||||
|
Elements []Bool
|
||||||
|
Dimensions []ArrayDimension
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BoolArray) Set(src interface{}) error {
|
||||||
|
// untyped nil and typed nil interfaces are different
|
||||||
|
if src == nil {
|
||||||
|
*dst = BoolArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := src.(interface{ Get() interface{} }); ok {
|
||||||
|
value2 := value.Get()
|
||||||
|
if value2 != value {
|
||||||
|
return dst.Set(value2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to match to select common types:
|
||||||
|
switch value := src.(type) {
|
||||||
|
|
||||||
|
case []bool:
|
||||||
|
if value == nil {
|
||||||
|
*dst = BoolArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = BoolArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]Bool, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].Set(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*dst = BoolArray{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []*bool:
|
||||||
|
if value == nil {
|
||||||
|
*dst = BoolArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = BoolArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]Bool, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].Set(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*dst = BoolArray{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []Bool:
|
||||||
|
if value == nil {
|
||||||
|
*dst = BoolArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = BoolArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
*dst = BoolArray{
|
||||||
|
Elements: value,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Fallback to reflection if an optimised match was not found.
|
||||||
|
// The reflection is necessary for arrays and multidimensional slices,
|
||||||
|
// but it comes with a 20-50% performance penalty for large arrays/slices
|
||||||
|
reflectedValue := reflect.ValueOf(src)
|
||||||
|
if !reflectedValue.IsValid() || reflectedValue.IsZero() {
|
||||||
|
*dst = BoolArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("cannot find dimensions of %v for BoolArray", src)
|
||||||
|
}
|
||||||
|
if elementsLength == 0 {
|
||||||
|
*dst = BoolArray{Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(dimensions) == 0 {
|
||||||
|
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||||
|
return dst.Set(originalSrc)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot convert %v to BoolArray", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = BoolArray{
|
||||||
|
Elements: make([]Bool, elementsLength),
|
||||||
|
Dimensions: dimensions,
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
// Maybe the target was one dimension too far, try again:
|
||||||
|
if len(dst.Dimensions) > 1 {
|
||||||
|
dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
|
||||||
|
elementsLength = 0
|
||||||
|
for _, dim := range dst.Dimensions {
|
||||||
|
if elementsLength == 0 {
|
||||||
|
elementsLength = int(dim.Length)
|
||||||
|
} else {
|
||||||
|
elementsLength *= int(dim.Length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst.Elements = make([]Bool, elementsLength)
|
||||||
|
elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if elementCount != len(dst.Elements) {
|
||||||
|
return fmt.Errorf("cannot convert %v to BoolArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BoolArray) setRecursive(value reflect.Value, index, dimension int) (int, error) {
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Slice:
|
||||||
|
if len(dst.Dimensions) == dimension {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
valueLen := value.Len()
|
||||||
|
if int32(valueLen) != dst.Dimensions[dimension].Length {
|
||||||
|
return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
|
||||||
|
}
|
||||||
|
for i := 0; i < valueLen; i++ {
|
||||||
|
var err error
|
||||||
|
index, err = dst.setRecursive(value.Index(i), index, dimension+1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
if !value.CanInterface() {
|
||||||
|
return 0, fmt.Errorf("cannot convert all values to BoolArray")
|
||||||
|
}
|
||||||
|
if err := dst.Elements[index].Set(value.Interface()); err != nil {
|
||||||
|
return 0, fmt.Errorf("%v in BoolArray", err)
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst BoolArray) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case Present:
|
||||||
|
return dst
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *BoolArray) AssignTo(dst interface{}) error {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
if len(src.Dimensions) <= 1 {
|
||||||
|
// Attempt to match to select common types:
|
||||||
|
switch v := dst.(type) {
|
||||||
|
|
||||||
|
case *[]bool:
|
||||||
|
*v = make([]bool, len(src.Elements))
|
||||||
|
for i := range src.Elements {
|
||||||
|
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case *[]*bool:
|
||||||
|
*v = make([]*bool, len(src.Elements))
|
||||||
|
for i := range src.Elements {
|
||||||
|
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to convert to something AssignTo can use directly.
|
||||||
|
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||||
|
return src.AssignTo(nextDst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to reflection if an optimised match was not found.
|
||||||
|
// The reflection is necessary for arrays and multidimensional slices,
|
||||||
|
// but it comes with a 20-50% performance penalty for large arrays/slices
|
||||||
|
value := reflect.ValueOf(dst)
|
||||||
|
if value.Kind() == reflect.Ptr {
|
||||||
|
value = value.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot assign %T to %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.Elements) == 0 {
|
||||||
|
if value.Kind() == reflect.Slice {
|
||||||
|
value.Set(reflect.MakeSlice(value.Type(), 0, 0))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount, err := src.assignToRecursive(value, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if elementCount != len(src.Elements) {
|
||||||
|
return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case Null:
|
||||||
|
return NullAssignTo(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot decode %#v into %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *BoolArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
|
||||||
|
switch kind := value.Kind(); kind {
|
||||||
|
case reflect.Array:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Slice:
|
||||||
|
if len(src.Dimensions) == dimension {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
length := int(src.Dimensions[dimension].Length)
|
||||||
|
if reflect.Array == kind {
|
||||||
|
typ := value.Type()
|
||||||
|
if typ.Len() != length {
|
||||||
|
return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
|
||||||
|
}
|
||||||
|
value.Set(reflect.New(typ).Elem())
|
||||||
|
} else {
|
||||||
|
value.Set(reflect.MakeSlice(value.Type(), length, length))
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
if len(src.Dimensions) != dimension {
|
||||||
|
return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
|
||||||
|
}
|
||||||
|
if !value.CanAddr() {
|
||||||
|
return 0, fmt.Errorf("cannot assign all values from BoolArray")
|
||||||
|
}
|
||||||
|
addr := value.Addr()
|
||||||
|
if !addr.CanInterface() {
|
||||||
|
return 0, fmt.Errorf("cannot assign all values from BoolArray")
|
||||||
|
}
|
||||||
|
if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BoolArray) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = BoolArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
uta, err := ParseUntypedTextArray(string(src))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var elements []Bool
|
||||||
|
|
||||||
|
if len(uta.Elements) > 0 {
|
||||||
|
elements = make([]Bool, len(uta.Elements))
|
||||||
|
|
||||||
|
for i, s := range uta.Elements {
|
||||||
|
var elem Bool
|
||||||
|
var elemSrc []byte
|
||||||
|
if s != "NULL" || uta.Quoted[i] {
|
||||||
|
elemSrc = []byte(s)
|
||||||
|
}
|
||||||
|
err = elem.DecodeText(ci, elemSrc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
elements[i] = elem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = BoolArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BoolArray) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = BoolArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var arrayHeader ArrayHeader
|
||||||
|
rp, err := arrayHeader.DecodeBinary(ci, src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(arrayHeader.Dimensions) == 0 {
|
||||||
|
*dst = BoolArray{Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount := arrayHeader.Dimensions[0].Length
|
||||||
|
for _, d := range arrayHeader.Dimensions[1:] {
|
||||||
|
elementCount *= d.Length
|
||||||
|
}
|
||||||
|
|
||||||
|
elements := make([]Bool, elementCount)
|
||||||
|
|
||||||
|
for i := range elements {
|
||||||
|
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
var elemSrc []byte
|
||||||
|
if elemLen >= 0 {
|
||||||
|
elemSrc = src[rp : rp+elemLen]
|
||||||
|
rp += elemLen
|
||||||
|
}
|
||||||
|
err = elements[i].DecodeBinary(ci, elemSrc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = BoolArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src BoolArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.Dimensions) == 0 {
|
||||||
|
return append(buf, '{', '}'), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||||
|
|
||||||
|
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||||
|
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||||
|
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||||
|
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||||
|
// or '}'.
|
||||||
|
dimElemCounts := make([]int, len(src.Dimensions))
|
||||||
|
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||||
|
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||||
|
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
inElemBuf := make([]byte, 0, 32)
|
||||||
|
for i, elem := range src.Elements {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if i%dec == 0 {
|
||||||
|
buf = append(buf, '{')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf == nil {
|
||||||
|
buf = append(buf, `NULL`...)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if (i+1)%dec == 0 {
|
||||||
|
buf = append(buf, '}')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src BoolArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayHeader := ArrayHeader{
|
||||||
|
Dimensions: src.Dimensions,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dt, ok := ci.DataTypeForName("bool"); ok {
|
||||||
|
arrayHeader.ElementOID = int32(dt.OID)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unable to find oid for type name %v", "bool")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range src.Elements {
|
||||||
|
if src.Elements[i].Status == Null {
|
||||||
|
arrayHeader.ContainsNull = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = arrayHeader.EncodeBinary(ci, buf)
|
||||||
|
|
||||||
|
for i := range src.Elements {
|
||||||
|
sp := len(buf)
|
||||||
|
buf = pgio.AppendInt32(buf, -1)
|
||||||
|
|
||||||
|
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf != nil {
|
||||||
|
buf = elemBuf
|
||||||
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *BoolArray) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
return dst.DecodeText(nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
srcCopy := make([]byte, len(src))
|
||||||
|
copy(srcCopy, src)
|
||||||
|
return dst.DecodeText(nil, srcCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src BoolArray) Value() (driver.Value, error) {
|
||||||
|
buf, err := src.EncodeText(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if buf == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(buf), nil
|
||||||
|
}
|
@ -0,0 +1,165 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Box struct {
|
||||||
|
P [2]Vec2
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Box) Set(src interface{}) error {
|
||||||
|
return fmt.Errorf("cannot convert %v to Box", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst Box) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case Present:
|
||||||
|
return dst
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *Box) AssignTo(dst interface{}) error {
|
||||||
|
return fmt.Errorf("cannot assign %v to %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Box) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Box{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src) < 11 {
|
||||||
|
return fmt.Errorf("invalid length for Box: %v", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
str := string(src[1:])
|
||||||
|
|
||||||
|
var end int
|
||||||
|
end = strings.IndexByte(str, ',')
|
||||||
|
|
||||||
|
x1, err := strconv.ParseFloat(str[:end], 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
str = str[end+1:]
|
||||||
|
end = strings.IndexByte(str, ')')
|
||||||
|
|
||||||
|
y1, err := strconv.ParseFloat(str[:end], 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
str = str[end+3:]
|
||||||
|
end = strings.IndexByte(str, ',')
|
||||||
|
|
||||||
|
x2, err := strconv.ParseFloat(str[:end], 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
str = str[end+1 : len(str)-1]
|
||||||
|
|
||||||
|
y2, err := strconv.ParseFloat(str, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = Box{P: [2]Vec2{{x1, y1}, {x2, y2}}, Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *Box) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Box{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src) != 32 {
|
||||||
|
return fmt.Errorf("invalid length for Box: %v", len(src))
|
||||||
|
}
|
||||||
|
|
||||||
|
x1 := binary.BigEndian.Uint64(src)
|
||||||
|
y1 := binary.BigEndian.Uint64(src[8:])
|
||||||
|
x2 := binary.BigEndian.Uint64(src[16:])
|
||||||
|
y2 := binary.BigEndian.Uint64(src[24:])
|
||||||
|
|
||||||
|
*dst = Box{
|
||||||
|
P: [2]Vec2{
|
||||||
|
{math.Float64frombits(x1), math.Float64frombits(y1)},
|
||||||
|
{math.Float64frombits(x2), math.Float64frombits(y2)},
|
||||||
|
},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src Box) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = append(buf, fmt.Sprintf(`(%s,%s),(%s,%s)`,
|
||||||
|
strconv.FormatFloat(src.P[0].X, 'f', -1, 64),
|
||||||
|
strconv.FormatFloat(src.P[0].Y, 'f', -1, 64),
|
||||||
|
strconv.FormatFloat(src.P[1].X, 'f', -1, 64),
|
||||||
|
strconv.FormatFloat(src.P[1].Y, 'f', -1, 64),
|
||||||
|
)...)
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src Box) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].X))
|
||||||
|
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[0].Y))
|
||||||
|
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].X))
|
||||||
|
buf = pgio.AppendUint64(buf, math.Float64bits(src.P[1].Y))
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *Box) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = Box{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
srcCopy := make([]byte, len(src))
|
||||||
|
copy(srcCopy, src)
|
||||||
|
return dst.DecodeText(nil, srcCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src Box) Value() (driver.Value, error) {
|
||||||
|
return EncodeValueText(src)
|
||||||
|
}
|
@ -0,0 +1,93 @@
|
|||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BPChar is fixed-length, blank padded char type
|
||||||
|
// character(n), char(n)
|
||||||
|
type BPChar Text
|
||||||
|
|
||||||
|
// Set converts from src to dst.
|
||||||
|
func (dst *BPChar) Set(src interface{}) error {
|
||||||
|
return (*Text)(dst).Set(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns underlying value
|
||||||
|
func (dst BPChar) Get() interface{} {
|
||||||
|
return (Text)(dst).Get()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssignTo assigns from src to dst.
|
||||||
|
func (src *BPChar) AssignTo(dst interface{}) error {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
switch v := dst.(type) {
|
||||||
|
case *rune:
|
||||||
|
runes := []rune(src.String)
|
||||||
|
if len(runes) == 1 {
|
||||||
|
*v = runes[0]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case *string:
|
||||||
|
*v = src.String
|
||||||
|
return nil
|
||||||
|
case *[]byte:
|
||||||
|
*v = make([]byte, len(src.String))
|
||||||
|
copy(*v, src.String)
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||||
|
return src.AssignTo(nextDst)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("unable to assign to %T", dst)
|
||||||
|
}
|
||||||
|
case Null:
|
||||||
|
return NullAssignTo(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot decode %#v into %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BPChar) PreferredResultFormat() int16 {
|
||||||
|
return TextFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BPChar) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
return (*Text)(dst).DecodeText(ci, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BPChar) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
return (*Text)(dst).DecodeBinary(ci, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (BPChar) PreferredParamFormat() int16 {
|
||||||
|
return TextFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src BPChar) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
return (Text)(src).EncodeText(ci, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src BPChar) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
return (Text)(src).EncodeBinary(ci, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *BPChar) Scan(src interface{}) error {
|
||||||
|
return (*Text)(dst).Scan(src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src BPChar) Value() (driver.Value, error) {
|
||||||
|
return (Text)(src).Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src BPChar) MarshalJSON() ([]byte, error) {
|
||||||
|
return (Text)(src).MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BPChar) UnmarshalJSON(b []byte) error {
|
||||||
|
return (*Text)(dst).UnmarshalJSON(b)
|
||||||
|
}
|
@ -0,0 +1,517 @@
|
|||||||
|
// Code generated by erb. DO NOT EDIT.
|
||||||
|
|
||||||
|
package pgtype
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/jackc/pgio"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BPCharArray struct {
|
||||||
|
Elements []BPChar
|
||||||
|
Dimensions []ArrayDimension
|
||||||
|
Status Status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BPCharArray) Set(src interface{}) error {
|
||||||
|
// untyped nil and typed nil interfaces are different
|
||||||
|
if src == nil {
|
||||||
|
*dst = BPCharArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, ok := src.(interface{ Get() interface{} }); ok {
|
||||||
|
value2 := value.Get()
|
||||||
|
if value2 != value {
|
||||||
|
return dst.Set(value2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to match to select common types:
|
||||||
|
switch value := src.(type) {
|
||||||
|
|
||||||
|
case []string:
|
||||||
|
if value == nil {
|
||||||
|
*dst = BPCharArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = BPCharArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]BPChar, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].Set(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*dst = BPCharArray{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []*string:
|
||||||
|
if value == nil {
|
||||||
|
*dst = BPCharArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = BPCharArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
elements := make([]BPChar, len(value))
|
||||||
|
for i := range value {
|
||||||
|
if err := elements[i].Set(value[i]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*dst = BPCharArray{
|
||||||
|
Elements: elements,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case []BPChar:
|
||||||
|
if value == nil {
|
||||||
|
*dst = BPCharArray{Status: Null}
|
||||||
|
} else if len(value) == 0 {
|
||||||
|
*dst = BPCharArray{Status: Present}
|
||||||
|
} else {
|
||||||
|
*dst = BPCharArray{
|
||||||
|
Elements: value,
|
||||||
|
Dimensions: []ArrayDimension{{Length: int32(len(value)), LowerBound: 1}},
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Fallback to reflection if an optimised match was not found.
|
||||||
|
// The reflection is necessary for arrays and multidimensional slices,
|
||||||
|
// but it comes with a 20-50% performance penalty for large arrays/slices
|
||||||
|
reflectedValue := reflect.ValueOf(src)
|
||||||
|
if !reflectedValue.IsValid() || reflectedValue.IsZero() {
|
||||||
|
*dst = BPCharArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dimensions, elementsLength, ok := findDimensionsFromValue(reflectedValue, nil, 0)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("cannot find dimensions of %v for BPCharArray", src)
|
||||||
|
}
|
||||||
|
if elementsLength == 0 {
|
||||||
|
*dst = BPCharArray{Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(dimensions) == 0 {
|
||||||
|
if originalSrc, ok := underlyingSliceType(src); ok {
|
||||||
|
return dst.Set(originalSrc)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("cannot convert %v to BPCharArray", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = BPCharArray{
|
||||||
|
Elements: make([]BPChar, elementsLength),
|
||||||
|
Dimensions: dimensions,
|
||||||
|
Status: Present,
|
||||||
|
}
|
||||||
|
elementCount, err := dst.setRecursive(reflectedValue, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
// Maybe the target was one dimension too far, try again:
|
||||||
|
if len(dst.Dimensions) > 1 {
|
||||||
|
dst.Dimensions = dst.Dimensions[:len(dst.Dimensions)-1]
|
||||||
|
elementsLength = 0
|
||||||
|
for _, dim := range dst.Dimensions {
|
||||||
|
if elementsLength == 0 {
|
||||||
|
elementsLength = int(dim.Length)
|
||||||
|
} else {
|
||||||
|
elementsLength *= int(dim.Length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst.Elements = make([]BPChar, elementsLength)
|
||||||
|
elementCount, err = dst.setRecursive(reflectedValue, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if elementCount != len(dst.Elements) {
|
||||||
|
return fmt.Errorf("cannot convert %v to BPCharArray, expected %d dst.Elements, but got %d instead", src, len(dst.Elements), elementCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BPCharArray) setRecursive(value reflect.Value, index, dimension int) (int, error) {
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Array:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Slice:
|
||||||
|
if len(dst.Dimensions) == dimension {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
valueLen := value.Len()
|
||||||
|
if int32(valueLen) != dst.Dimensions[dimension].Length {
|
||||||
|
return 0, fmt.Errorf("multidimensional arrays must have array expressions with matching dimensions")
|
||||||
|
}
|
||||||
|
for i := 0; i < valueLen; i++ {
|
||||||
|
var err error
|
||||||
|
index, err = dst.setRecursive(value.Index(i), index, dimension+1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
if !value.CanInterface() {
|
||||||
|
return 0, fmt.Errorf("cannot convert all values to BPCharArray")
|
||||||
|
}
|
||||||
|
if err := dst.Elements[index].Set(value.Interface()); err != nil {
|
||||||
|
return 0, fmt.Errorf("%v in BPCharArray", err)
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst BPCharArray) Get() interface{} {
|
||||||
|
switch dst.Status {
|
||||||
|
case Present:
|
||||||
|
return dst
|
||||||
|
case Null:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return dst.Status
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *BPCharArray) AssignTo(dst interface{}) error {
|
||||||
|
switch src.Status {
|
||||||
|
case Present:
|
||||||
|
if len(src.Dimensions) <= 1 {
|
||||||
|
// Attempt to match to select common types:
|
||||||
|
switch v := dst.(type) {
|
||||||
|
|
||||||
|
case *[]string:
|
||||||
|
*v = make([]string, len(src.Elements))
|
||||||
|
for i := range src.Elements {
|
||||||
|
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case *[]*string:
|
||||||
|
*v = make([]*string, len(src.Elements))
|
||||||
|
for i := range src.Elements {
|
||||||
|
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to convert to something AssignTo can use directly.
|
||||||
|
if nextDst, retry := GetAssignToDstType(dst); retry {
|
||||||
|
return src.AssignTo(nextDst)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to reflection if an optimised match was not found.
|
||||||
|
// The reflection is necessary for arrays and multidimensional slices,
|
||||||
|
// but it comes with a 20-50% performance penalty for large arrays/slices
|
||||||
|
value := reflect.ValueOf(dst)
|
||||||
|
if value.Kind() == reflect.Ptr {
|
||||||
|
value = value.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
switch value.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot assign %T to %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.Elements) == 0 {
|
||||||
|
if value.Kind() == reflect.Slice {
|
||||||
|
value.Set(reflect.MakeSlice(value.Type(), 0, 0))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount, err := src.assignToRecursive(value, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if elementCount != len(src.Elements) {
|
||||||
|
return fmt.Errorf("cannot assign %v, needed to assign %d elements, but only assigned %d", dst, len(src.Elements), elementCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case Null:
|
||||||
|
return NullAssignTo(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot decode %#v into %T", src, dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *BPCharArray) assignToRecursive(value reflect.Value, index, dimension int) (int, error) {
|
||||||
|
switch kind := value.Kind(); kind {
|
||||||
|
case reflect.Array:
|
||||||
|
fallthrough
|
||||||
|
case reflect.Slice:
|
||||||
|
if len(src.Dimensions) == dimension {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
length := int(src.Dimensions[dimension].Length)
|
||||||
|
if reflect.Array == kind {
|
||||||
|
typ := value.Type()
|
||||||
|
if typ.Len() != length {
|
||||||
|
return 0, fmt.Errorf("expected size %d array, but %s has size %d array", length, typ, typ.Len())
|
||||||
|
}
|
||||||
|
value.Set(reflect.New(typ).Elem())
|
||||||
|
} else {
|
||||||
|
value.Set(reflect.MakeSlice(value.Type(), length, length))
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for i := 0; i < length; i++ {
|
||||||
|
index, err = src.assignToRecursive(value.Index(i), index, dimension+1)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
if len(src.Dimensions) != dimension {
|
||||||
|
return 0, fmt.Errorf("incorrect dimensions, expected %d, found %d", len(src.Dimensions), dimension)
|
||||||
|
}
|
||||||
|
if !value.CanAddr() {
|
||||||
|
return 0, fmt.Errorf("cannot assign all values from BPCharArray")
|
||||||
|
}
|
||||||
|
addr := value.Addr()
|
||||||
|
if !addr.CanInterface() {
|
||||||
|
return 0, fmt.Errorf("cannot assign all values from BPCharArray")
|
||||||
|
}
|
||||||
|
if err := src.Elements[index].AssignTo(addr.Interface()); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
return index, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BPCharArray) DecodeText(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = BPCharArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
uta, err := ParseUntypedTextArray(string(src))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var elements []BPChar
|
||||||
|
|
||||||
|
if len(uta.Elements) > 0 {
|
||||||
|
elements = make([]BPChar, len(uta.Elements))
|
||||||
|
|
||||||
|
for i, s := range uta.Elements {
|
||||||
|
var elem BPChar
|
||||||
|
var elemSrc []byte
|
||||||
|
if s != "NULL" || uta.Quoted[i] {
|
||||||
|
elemSrc = []byte(s)
|
||||||
|
}
|
||||||
|
err = elem.DecodeText(ci, elemSrc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
elements[i] = elem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = BPCharArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dst *BPCharArray) DecodeBinary(ci *ConnInfo, src []byte) error {
|
||||||
|
if src == nil {
|
||||||
|
*dst = BPCharArray{Status: Null}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var arrayHeader ArrayHeader
|
||||||
|
rp, err := arrayHeader.DecodeBinary(ci, src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(arrayHeader.Dimensions) == 0 {
|
||||||
|
*dst = BPCharArray{Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
elementCount := arrayHeader.Dimensions[0].Length
|
||||||
|
for _, d := range arrayHeader.Dimensions[1:] {
|
||||||
|
elementCount *= d.Length
|
||||||
|
}
|
||||||
|
|
||||||
|
elements := make([]BPChar, elementCount)
|
||||||
|
|
||||||
|
for i := range elements {
|
||||||
|
elemLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||||
|
rp += 4
|
||||||
|
var elemSrc []byte
|
||||||
|
if elemLen >= 0 {
|
||||||
|
elemSrc = src[rp : rp+elemLen]
|
||||||
|
rp += elemLen
|
||||||
|
}
|
||||||
|
err = elements[i].DecodeBinary(ci, elemSrc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*dst = BPCharArray{Elements: elements, Dimensions: arrayHeader.Dimensions, Status: Present}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src BPCharArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(src.Dimensions) == 0 {
|
||||||
|
return append(buf, '{', '}'), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = EncodeTextArrayDimensions(buf, src.Dimensions)
|
||||||
|
|
||||||
|
// dimElemCounts is the multiples of elements that each array lies on. For
|
||||||
|
// example, a single dimension array of length 4 would have a dimElemCounts of
|
||||||
|
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
|
||||||
|
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
|
||||||
|
// or '}'.
|
||||||
|
dimElemCounts := make([]int, len(src.Dimensions))
|
||||||
|
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
|
||||||
|
for i := len(src.Dimensions) - 2; i > -1; i-- {
|
||||||
|
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
|
||||||
|
}
|
||||||
|
|
||||||
|
inElemBuf := make([]byte, 0, 32)
|
||||||
|
for i, elem := range src.Elements {
|
||||||
|
if i > 0 {
|
||||||
|
buf = append(buf, ',')
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if i%dec == 0 {
|
||||||
|
buf = append(buf, '{')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elemBuf, err := elem.EncodeText(ci, inElemBuf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf == nil {
|
||||||
|
buf = append(buf, `NULL`...)
|
||||||
|
} else {
|
||||||
|
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dec := range dimElemCounts {
|
||||||
|
if (i+1)%dec == 0 {
|
||||||
|
buf = append(buf, '}')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src BPCharArray) EncodeBinary(ci *ConnInfo, buf []byte) ([]byte, error) {
|
||||||
|
switch src.Status {
|
||||||
|
case Null:
|
||||||
|
return nil, nil
|
||||||
|
case Undefined:
|
||||||
|
return nil, errUndefined
|
||||||
|
}
|
||||||
|
|
||||||
|
arrayHeader := ArrayHeader{
|
||||||
|
Dimensions: src.Dimensions,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dt, ok := ci.DataTypeForName("bpchar"); ok {
|
||||||
|
arrayHeader.ElementOID = int32(dt.OID)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unable to find oid for type name %v", "bpchar")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range src.Elements {
|
||||||
|
if src.Elements[i].Status == Null {
|
||||||
|
arrayHeader.ContainsNull = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf = arrayHeader.EncodeBinary(ci, buf)
|
||||||
|
|
||||||
|
for i := range src.Elements {
|
||||||
|
sp := len(buf)
|
||||||
|
buf = pgio.AppendInt32(buf, -1)
|
||||||
|
|
||||||
|
elemBuf, err := src.Elements[i].EncodeBinary(ci, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if elemBuf != nil {
|
||||||
|
buf = elemBuf
|
||||||
|
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the database/sql Scanner interface.
|
||||||
|
func (dst *BPCharArray) Scan(src interface{}) error {
|
||||||
|
if src == nil {
|
||||||
|
return dst.DecodeText(nil, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch src := src.(type) {
|
||||||
|
case string:
|
||||||
|
return dst.DecodeText(nil, []byte(src))
|
||||||
|
case []byte:
|
||||||
|
srcCopy := make([]byte, len(src))
|
||||||
|
copy(srcCopy, src)
|
||||||
|
return dst.DecodeText(nil, srcCopy)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the database/sql/driver Valuer interface.
|
||||||
|
func (src BPCharArray) Value() (driver.Value, error) {
|
||||||
|
buf, err := src.EncodeText(nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if buf == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(buf), nil
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue