- update vendor
continuous-integration/drone/push Build is passing Details
continuous-integration/drone/tag Build is passing Details

master v1.0.2
李光春 2 years ago
parent 503a1ab383
commit a3664d42b0

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2014 Chris Hines
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,38 @@
[![GoDoc](https://godoc.org/github.com/go-stack/stack?status.svg)](https://godoc.org/github.com/go-stack/stack)
[![Go Report Card](https://goreportcard.com/badge/go-stack/stack)](https://goreportcard.com/report/go-stack/stack)
[![TravisCI](https://travis-ci.org/go-stack/stack.svg?branch=master)](https://travis-ci.org/go-stack/stack)
[![Coverage Status](https://coveralls.io/repos/github/go-stack/stack/badge.svg?branch=master)](https://coveralls.io/github/go-stack/stack?branch=master)
# stack
Package stack implements utilities to capture, manipulate, and format call
stacks. It provides a simpler API than package runtime.
The implementation takes care of the minutia and special cases of interpreting
the program counter (pc) values returned by runtime.Callers.
## Versioning
Package stack publishes releases via [semver](http://semver.org/) compatible Git
tags prefixed with a single 'v'. The master branch always contains the latest
release. The develop branch contains unreleased commits.
## Formatting
Package stack's types implement fmt.Formatter, which provides a simple and
flexible way to declaratively configure formatting when used with logging or
error tracking packages.
```go
func DoTheThing() {
c := stack.Caller(0)
log.Print(c) // "source.go:10"
log.Printf("%+v", c) // "pkg/path/source.go:10"
log.Printf("%n", c) // "DoTheThing"
s := stack.Trace().TrimRuntime()
log.Print(s) // "[source.go:15 caller.go:42 main.go:14]"
}
```
See the docs for all of the supported formatting options.

@ -0,0 +1,400 @@
// +build go1.7
// Package stack implements utilities to capture, manipulate, and format call
// stacks. It provides a simpler API than package runtime.
//
// The implementation takes care of the minutia and special cases of
// interpreting the program counter (pc) values returned by runtime.Callers.
//
// Package stack's types implement fmt.Formatter, which provides a simple and
// flexible way to declaratively configure formatting when used with logging
// or error tracking packages.
package stack
import (
"bytes"
"errors"
"fmt"
"io"
"runtime"
"strconv"
"strings"
)
// Call records a single function invocation from a goroutine stack.
type Call struct {
frame runtime.Frame
}
// Caller returns a Call from the stack of the current goroutine. The argument
// skip is the number of stack frames to ascend, with 0 identifying the
// calling function.
func Caller(skip int) Call {
// As of Go 1.9 we need room for up to three PC entries.
//
// 0. An entry for the stack frame prior to the target to check for
// special handling needed if that prior entry is runtime.sigpanic.
// 1. A possible second entry to hold metadata about skipped inlined
// functions. If inline functions were not skipped the target frame
// PC will be here.
// 2. A third entry for the target frame PC when the second entry
// is used for skipped inline functions.
var pcs [3]uintptr
n := runtime.Callers(skip+1, pcs[:])
frames := runtime.CallersFrames(pcs[:n])
frame, _ := frames.Next()
frame, _ = frames.Next()
return Call{
frame: frame,
}
}
// String implements fmt.Stinger. It is equivalent to fmt.Sprintf("%v", c).
func (c Call) String() string {
return fmt.Sprint(c)
}
// MarshalText implements encoding.TextMarshaler. It formats the Call the same
// as fmt.Sprintf("%v", c).
func (c Call) MarshalText() ([]byte, error) {
if c.frame == (runtime.Frame{}) {
return nil, ErrNoFunc
}
buf := bytes.Buffer{}
fmt.Fprint(&buf, c)
return buf.Bytes(), nil
}
// ErrNoFunc means that the Call has a nil *runtime.Func. The most likely
// cause is a Call with the zero value.
var ErrNoFunc = errors.New("no call stack information")
// Format implements fmt.Formatter with support for the following verbs.
//
// %s source file
// %d line number
// %n function name
// %k last segment of the package path
// %v equivalent to %s:%d
//
// It accepts the '+' and '#' flags for most of the verbs as follows.
//
// %+s path of source file relative to the compile time GOPATH,
// or the module path joined to the path of source file relative
// to module root
// %#s full path of source file
// %+n import path qualified function name
// %+k full package path
// %+v equivalent to %+s:%d
// %#v equivalent to %#s:%d
func (c Call) Format(s fmt.State, verb rune) {
if c.frame == (runtime.Frame{}) {
fmt.Fprintf(s, "%%!%c(NOFUNC)", verb)
return
}
switch verb {
case 's', 'v':
file := c.frame.File
switch {
case s.Flag('#'):
// done
case s.Flag('+'):
file = pkgFilePath(&c.frame)
default:
const sep = "/"
if i := strings.LastIndex(file, sep); i != -1 {
file = file[i+len(sep):]
}
}
io.WriteString(s, file)
if verb == 'v' {
buf := [7]byte{':'}
s.Write(strconv.AppendInt(buf[:1], int64(c.frame.Line), 10))
}
case 'd':
buf := [6]byte{}
s.Write(strconv.AppendInt(buf[:0], int64(c.frame.Line), 10))
case 'k':
name := c.frame.Function
const pathSep = "/"
start, end := 0, len(name)
if i := strings.LastIndex(name, pathSep); i != -1 {
start = i + len(pathSep)
}
const pkgSep = "."
if i := strings.Index(name[start:], pkgSep); i != -1 {
end = start + i
}
if s.Flag('+') {
start = 0
}
io.WriteString(s, name[start:end])
case 'n':
name := c.frame.Function
if !s.Flag('+') {
const pathSep = "/"
if i := strings.LastIndex(name, pathSep); i != -1 {
name = name[i+len(pathSep):]
}
const pkgSep = "."
if i := strings.Index(name, pkgSep); i != -1 {
name = name[i+len(pkgSep):]
}
}
io.WriteString(s, name)
}
}
// Frame returns the call frame infomation for the Call.
func (c Call) Frame() runtime.Frame {
return c.frame
}
// PC returns the program counter for this call frame; multiple frames may
// have the same PC value.
//
// Deprecated: Use Call.Frame instead.
func (c Call) PC() uintptr {
return c.frame.PC
}
// CallStack records a sequence of function invocations from a goroutine
// stack.
type CallStack []Call
// String implements fmt.Stinger. It is equivalent to fmt.Sprintf("%v", cs).
func (cs CallStack) String() string {
return fmt.Sprint(cs)
}
var (
openBracketBytes = []byte("[")
closeBracketBytes = []byte("]")
spaceBytes = []byte(" ")
)
// MarshalText implements encoding.TextMarshaler. It formats the CallStack the
// same as fmt.Sprintf("%v", cs).
func (cs CallStack) MarshalText() ([]byte, error) {
buf := bytes.Buffer{}
buf.Write(openBracketBytes)
for i, pc := range cs {
if i > 0 {
buf.Write(spaceBytes)
}
fmt.Fprint(&buf, pc)
}
buf.Write(closeBracketBytes)
return buf.Bytes(), nil
}
// Format implements fmt.Formatter by printing the CallStack as square brackets
// ([, ]) surrounding a space separated list of Calls each formatted with the
// supplied verb and options.
func (cs CallStack) Format(s fmt.State, verb rune) {
s.Write(openBracketBytes)
for i, pc := range cs {
if i > 0 {
s.Write(spaceBytes)
}
pc.Format(s, verb)
}
s.Write(closeBracketBytes)
}
// Trace returns a CallStack for the current goroutine with element 0
// identifying the calling function.
func Trace() CallStack {
var pcs [512]uintptr
n := runtime.Callers(1, pcs[:])
frames := runtime.CallersFrames(pcs[:n])
cs := make(CallStack, 0, n)
// Skip extra frame retrieved just to make sure the runtime.sigpanic
// special case is handled.
frame, more := frames.Next()
for more {
frame, more = frames.Next()
cs = append(cs, Call{frame: frame})
}
return cs
}
// TrimBelow returns a slice of the CallStack with all entries below c
// removed.
func (cs CallStack) TrimBelow(c Call) CallStack {
for len(cs) > 0 && cs[0] != c {
cs = cs[1:]
}
return cs
}
// TrimAbove returns a slice of the CallStack with all entries above c
// removed.
func (cs CallStack) TrimAbove(c Call) CallStack {
for len(cs) > 0 && cs[len(cs)-1] != c {
cs = cs[:len(cs)-1]
}
return cs
}
// pkgIndex returns the index that results in file[index:] being the path of
// file relative to the compile time GOPATH, and file[:index] being the
// $GOPATH/src/ portion of file. funcName must be the name of a function in
// file as returned by runtime.Func.Name.
func pkgIndex(file, funcName string) int {
// As of Go 1.6.2 there is no direct way to know the compile time GOPATH
// at runtime, but we can infer the number of path segments in the GOPATH.
// We note that runtime.Func.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// file[:idx] == /home/user/src/
// file[idx:] == pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired result for file[idx:]. We count separators from the
// end of the file path until it finds two more than in the function name
// and then move one character forward to preserve the initial path
// segment without a leading separator.
const sep = "/"
i := len(file)
for n := strings.Count(funcName, sep) + 2; n > 0; n-- {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
return i + len(sep)
}
// pkgFilePath returns the frame's filepath relative to the compile-time GOPATH,
// or its module path joined to its path relative to the module root.
//
// As of Go 1.11 there is no direct way to know the compile time GOPATH or
// module paths at runtime, but we can piece together the desired information
// from available information. We note that runtime.Frame.Function contains the
// function name qualified by the package path, which includes the module path
// but not the GOPATH. We can extract the package path from that and append the
// last segments of the file path to arrive at the desired package qualified
// file path. For example, given:
//
// GOPATH /home/user
// import path pkg/sub
// frame.File /home/user/src/pkg/sub/file.go
// frame.Function pkg/sub.Type.Method
// Desired return pkg/sub/file.go
//
// It appears that we simply need to trim ".Type.Method" from frame.Function and
// append "/" + path.Base(file).
//
// But there are other wrinkles. Although it is idiomatic to do so, the internal
// name of a package is not required to match the last segment of its import
// path. In addition, the introduction of modules in Go 1.11 allows working
// without a GOPATH. So we also must make these work right:
//
// GOPATH /home/user
// import path pkg/go-sub
// package name sub
// frame.File /home/user/src/pkg/go-sub/file.go
// frame.Function pkg/sub.Type.Method
// Desired return pkg/go-sub/file.go
//
// Module path pkg/v2
// import path pkg/v2/go-sub
// package name sub
// frame.File /home/user/cloned-pkg/go-sub/file.go
// frame.Function pkg/v2/sub.Type.Method
// Desired return pkg/v2/go-sub/file.go
//
// We can handle all of these situations by using the package path extracted
// from frame.Function up to, but not including, the last segment as the prefix
// and the last two segments of frame.File as the suffix of the returned path.
// This preserves the existing behavior when working in a GOPATH without modules
// and a semantically equivalent behavior when used in module aware project.
func pkgFilePath(frame *runtime.Frame) string {
pre := pkgPrefix(frame.Function)
post := pathSuffix(frame.File)
if pre == "" {
return post
}
return pre + "/" + post
}
// pkgPrefix returns the import path of the function's package with the final
// segment removed.
func pkgPrefix(funcName string) string {
const pathSep = "/"
end := strings.LastIndex(funcName, pathSep)
if end == -1 {
return ""
}
return funcName[:end]
}
// pathSuffix returns the last two segments of path.
func pathSuffix(path string) string {
const pathSep = "/"
lastSep := strings.LastIndex(path, pathSep)
if lastSep == -1 {
return path
}
return path[strings.LastIndex(path[:lastSep], pathSep)+1:]
}
var runtimePath string
func init() {
var pcs [3]uintptr
runtime.Callers(0, pcs[:])
frames := runtime.CallersFrames(pcs[:])
frame, _ := frames.Next()
file := frame.File
idx := pkgIndex(frame.File, frame.Function)
runtimePath = file[:idx]
if runtime.GOOS == "windows" {
runtimePath = strings.ToLower(runtimePath)
}
}
func inGoroot(c Call) bool {
file := c.frame.File
if len(file) == 0 || file[0] == '?' {
return true
}
if runtime.GOOS == "windows" {
file = strings.ToLower(file)
}
return strings.HasPrefix(file, runtimePath) || strings.HasSuffix(file, "/_testmain.go")
}
// TrimRuntime returns a slice of the CallStack with the topmost entries from
// the go runtime removed. It considers any calls originating from unknown
// files, files under GOROOT, or _testmain.go as part of the runtime.
func (cs CallStack) TrimRuntime() CallStack {
for len(cs) > 0 && inGoroot(cs[len(cs)-1]) {
cs = cs[:len(cs)-1]
}
return cs
}

@ -0,0 +1,16 @@
cmd/snappytool/snappytool
testdata/bench
# These explicitly listed benchmark data files are for an obsolete version of
# snappy_test.go.
testdata/alice29.txt
testdata/asyoulik.txt
testdata/fireworks.jpeg
testdata/geo.protodata
testdata/html
testdata/html_x_4
testdata/kppkn.gtb
testdata/lcet10.txt
testdata/paper-100k.pdf
testdata/plrabn12.txt
testdata/urls.10K

@ -0,0 +1,18 @@
# This is the official list of Snappy-Go authors for copyright purposes.
# This file is distinct from the CONTRIBUTORS files.
# See the latter for an explanation.
# Names should be added to this file as
# Name or Organization <email address>
# The email address is not required for organizations.
# Please keep the list sorted.
Amazon.com, Inc
Damian Gryski <dgryski@gmail.com>
Eric Buth <eric@topos.com>
Google Inc.
Jan Mercl <0xjnml@gmail.com>
Klaus Post <klauspost@gmail.com>
Rodolfo Carvalho <rhcarvalho@gmail.com>
Sebastien Binet <seb.binet@gmail.com>

@ -0,0 +1,41 @@
# This is the official list of people who can contribute
# (and typically have contributed) code to the Snappy-Go repository.
# The AUTHORS file lists the copyright holders; this file
# lists people. For example, Google employees are listed here
# but not in AUTHORS, because Google holds the copyright.
#
# The submission process automatically checks to make sure
# that people submitting code are listed in this file (by email address).
#
# Names should be added to this file only after verifying that
# the individual or the individual's organization has agreed to
# the appropriate Contributor License Agreement, found here:
#
# http://code.google.com/legal/individual-cla-v1.0.html
# http://code.google.com/legal/corporate-cla-v1.0.html
#
# The agreement for individuals can be filled out on the web.
#
# When adding J Random Contributor's name to this file,
# either J's name or J's organization's name should be
# added to the AUTHORS file, depending on whether the
# individual or corporate CLA was used.
# Names should be added to this file like so:
# Name <email address>
# Please keep the list sorted.
Alex Legg <alexlegg@google.com>
Damian Gryski <dgryski@gmail.com>
Eric Buth <eric@topos.com>
Jan Mercl <0xjnml@gmail.com>
Jonathan Swinney <jswinney@amazon.com>
Kai Backman <kaib@golang.org>
Klaus Post <klauspost@gmail.com>
Marc-Antoine Ruel <maruel@chromium.org>
Nigel Tao <nigeltao@golang.org>
Rob Pike <r@golang.org>
Rodolfo Carvalho <rhcarvalho@gmail.com>
Russ Cox <rsc@golang.org>
Sebastien Binet <seb.binet@gmail.com>

@ -0,0 +1,27 @@
Copyright (c) 2011 The Snappy-Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,107 @@
The Snappy compression format in the Go programming language.
To download and install from source:
$ go get github.com/golang/snappy
Unless otherwise noted, the Snappy-Go source files are distributed
under the BSD-style license found in the LICENSE file.
Benchmarks.
The golang/snappy benchmarks include compressing (Z) and decompressing (U) ten
or so files, the same set used by the C++ Snappy code (github.com/google/snappy
and note the "google", not "golang"). On an "Intel(R) Core(TM) i7-3770 CPU @
3.40GHz", Go's GOARCH=amd64 numbers as of 2016-05-29:
"go test -test.bench=."
_UFlat0-8 2.19GB/s ± 0% html
_UFlat1-8 1.41GB/s ± 0% urls
_UFlat2-8 23.5GB/s ± 2% jpg
_UFlat3-8 1.91GB/s ± 0% jpg_200
_UFlat4-8 14.0GB/s ± 1% pdf
_UFlat5-8 1.97GB/s ± 0% html4
_UFlat6-8 814MB/s ± 0% txt1
_UFlat7-8 785MB/s ± 0% txt2
_UFlat8-8 857MB/s ± 0% txt3
_UFlat9-8 719MB/s ± 1% txt4
_UFlat10-8 2.84GB/s ± 0% pb
_UFlat11-8 1.05GB/s ± 0% gaviota
_ZFlat0-8 1.04GB/s ± 0% html
_ZFlat1-8 534MB/s ± 0% urls
_ZFlat2-8 15.7GB/s ± 1% jpg
_ZFlat3-8 740MB/s ± 3% jpg_200
_ZFlat4-8 9.20GB/s ± 1% pdf
_ZFlat5-8 991MB/s ± 0% html4
_ZFlat6-8 379MB/s ± 0% txt1
_ZFlat7-8 352MB/s ± 0% txt2
_ZFlat8-8 396MB/s ± 1% txt3
_ZFlat9-8 327MB/s ± 1% txt4
_ZFlat10-8 1.33GB/s ± 1% pb
_ZFlat11-8 605MB/s ± 1% gaviota
"go test -test.bench=. -tags=noasm"
_UFlat0-8 621MB/s ± 2% html
_UFlat1-8 494MB/s ± 1% urls
_UFlat2-8 23.2GB/s ± 1% jpg
_UFlat3-8 1.12GB/s ± 1% jpg_200
_UFlat4-8 4.35GB/s ± 1% pdf
_UFlat5-8 609MB/s ± 0% html4
_UFlat6-8 296MB/s ± 0% txt1
_UFlat7-8 288MB/s ± 0% txt2
_UFlat8-8 309MB/s ± 1% txt3
_UFlat9-8 280MB/s ± 1% txt4
_UFlat10-8 753MB/s ± 0% pb
_UFlat11-8 400MB/s ± 0% gaviota
_ZFlat0-8 409MB/s ± 1% html
_ZFlat1-8 250MB/s ± 1% urls
_ZFlat2-8 12.3GB/s ± 1% jpg
_ZFlat3-8 132MB/s ± 0% jpg_200
_ZFlat4-8 2.92GB/s ± 0% pdf
_ZFlat5-8 405MB/s ± 1% html4
_ZFlat6-8 179MB/s ± 1% txt1
_ZFlat7-8 170MB/s ± 1% txt2
_ZFlat8-8 189MB/s ± 1% txt3
_ZFlat9-8 164MB/s ± 1% txt4
_ZFlat10-8 479MB/s ± 1% pb
_ZFlat11-8 270MB/s ± 1% gaviota
For comparison (Go's encoded output is byte-for-byte identical to C++'s), here
are the numbers from C++ Snappy's
make CXXFLAGS="-O2 -DNDEBUG -g" clean snappy_unittest.log && cat snappy_unittest.log
BM_UFlat/0 2.4GB/s html
BM_UFlat/1 1.4GB/s urls
BM_UFlat/2 21.8GB/s jpg
BM_UFlat/3 1.5GB/s jpg_200
BM_UFlat/4 13.3GB/s pdf
BM_UFlat/5 2.1GB/s html4
BM_UFlat/6 1.0GB/s txt1
BM_UFlat/7 959.4MB/s txt2
BM_UFlat/8 1.0GB/s txt3
BM_UFlat/9 864.5MB/s txt4
BM_UFlat/10 2.9GB/s pb
BM_UFlat/11 1.2GB/s gaviota
BM_ZFlat/0 944.3MB/s html (22.31 %)
BM_ZFlat/1 501.6MB/s urls (47.78 %)
BM_ZFlat/2 14.3GB/s jpg (99.95 %)
BM_ZFlat/3 538.3MB/s jpg_200 (73.00 %)
BM_ZFlat/4 8.3GB/s pdf (83.30 %)
BM_ZFlat/5 903.5MB/s html4 (22.52 %)
BM_ZFlat/6 336.0MB/s txt1 (57.88 %)
BM_ZFlat/7 312.3MB/s txt2 (61.91 %)
BM_ZFlat/8 353.1MB/s txt3 (54.99 %)
BM_ZFlat/9 289.9MB/s txt4 (66.26 %)
BM_ZFlat/10 1.2GB/s pb (19.68 %)
BM_ZFlat/11 527.4MB/s gaviota (37.72 %)

@ -0,0 +1,264 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
"io"
)
var (
// ErrCorrupt reports that the input is invalid.
ErrCorrupt = errors.New("snappy: corrupt input")
// ErrTooLarge reports that the uncompressed length is too large.
ErrTooLarge = errors.New("snappy: decoded block is too large")
// ErrUnsupported reports that the input isn't supported.
ErrUnsupported = errors.New("snappy: unsupported input")
errUnsupportedLiteralLength = errors.New("snappy: unsupported literal length")
)
// DecodedLen returns the length of the decoded block.
func DecodedLen(src []byte) (int, error) {
v, _, err := decodedLen(src)
return v, err
}
// decodedLen returns the length of the decoded block and the number of bytes
// that the length header occupied.
func decodedLen(src []byte) (blockLen, headerLen int, err error) {
v, n := binary.Uvarint(src)
if n <= 0 || v > 0xffffffff {
return 0, 0, ErrCorrupt
}
const wordSize = 32 << (^uint(0) >> 32 & 1)
if wordSize == 32 && v > 0x7fffffff {
return 0, 0, ErrTooLarge
}
return int(v), n, nil
}
const (
decodeErrCodeCorrupt = 1
decodeErrCodeUnsupportedLiteralLength = 2
)
// Decode returns the decoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire decoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
//
// Decode handles the Snappy block format, not the Snappy stream format.
func Decode(dst, src []byte) ([]byte, error) {
dLen, s, err := decodedLen(src)
if err != nil {
return nil, err
}
if dLen <= len(dst) {
dst = dst[:dLen]
} else {
dst = make([]byte, dLen)
}
switch decode(dst, src[s:]) {
case 0:
return dst, nil
case decodeErrCodeUnsupportedLiteralLength:
return nil, errUnsupportedLiteralLength
}
return nil, ErrCorrupt
}
// NewReader returns a new Reader that decompresses from r, using the framing
// format described at
// https://github.com/google/snappy/blob/master/framing_format.txt
func NewReader(r io.Reader) *Reader {
return &Reader{
r: r,
decoded: make([]byte, maxBlockSize),
buf: make([]byte, maxEncodedLenOfMaxBlockSize+checksumSize),
}
}
// Reader is an io.Reader that can read Snappy-compressed bytes.
//
// Reader handles the Snappy stream format, not the Snappy block format.
type Reader struct {
r io.Reader
err error
decoded []byte
buf []byte
// decoded[i:j] contains decoded bytes that have not yet been passed on.
i, j int
readHeader bool
}
// Reset discards any buffered data, resets all state, and switches the Snappy
// reader to read from r. This permits reusing a Reader rather than allocating
// a new one.
func (r *Reader) Reset(reader io.Reader) {
r.r = reader
r.err = nil
r.i = 0
r.j = 0
r.readHeader = false
}
func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
if _, r.err = io.ReadFull(r.r, p); r.err != nil {
if r.err == io.ErrUnexpectedEOF || (r.err == io.EOF && !allowEOF) {
r.err = ErrCorrupt
}
return false
}
return true
}
func (r *Reader) fill() error {
for r.i >= r.j {
if !r.readFull(r.buf[:4], true) {
return r.err
}
chunkType := r.buf[0]
if !r.readHeader {
if chunkType != chunkTypeStreamIdentifier {
r.err = ErrCorrupt
return r.err
}
r.readHeader = true
}
chunkLen := int(r.buf[1]) | int(r.buf[2])<<8 | int(r.buf[3])<<16
if chunkLen > len(r.buf) {
r.err = ErrUnsupported
return r.err
}
// The chunk types are specified at
// https://github.com/google/snappy/blob/master/framing_format.txt
switch chunkType {
case chunkTypeCompressedData:
// Section 4.2. Compressed data (chunk type 0x00).
if chunkLen < checksumSize {
r.err = ErrCorrupt
return r.err
}
buf := r.buf[:chunkLen]
if !r.readFull(buf, false) {
return r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
buf = buf[checksumSize:]
n, err := DecodedLen(buf)
if err != nil {
r.err = err
return r.err
}
if n > len(r.decoded) {
r.err = ErrCorrupt
return r.err
}
if _, err := Decode(r.decoded, buf); err != nil {
r.err = err
return r.err
}
if crc(r.decoded[:n]) != checksum {
r.err = ErrCorrupt
return r.err
}
r.i, r.j = 0, n
continue
case chunkTypeUncompressedData:
// Section 4.3. Uncompressed data (chunk type 0x01).
if chunkLen < checksumSize {
r.err = ErrCorrupt
return r.err
}
buf := r.buf[:checksumSize]
if !r.readFull(buf, false) {
return r.err
}
checksum := uint32(buf[0]) | uint32(buf[1])<<8 | uint32(buf[2])<<16 | uint32(buf[3])<<24
// Read directly into r.decoded instead of via r.buf.
n := chunkLen - checksumSize
if n > len(r.decoded) {
r.err = ErrCorrupt
return r.err
}
if !r.readFull(r.decoded[:n], false) {
return r.err
}
if crc(r.decoded[:n]) != checksum {
r.err = ErrCorrupt
return r.err
}
r.i, r.j = 0, n
continue
case chunkTypeStreamIdentifier:
// Section 4.1. Stream identifier (chunk type 0xff).
if chunkLen != len(magicBody) {
r.err = ErrCorrupt
return r.err
}
if !r.readFull(r.buf[:len(magicBody)], false) {
return r.err
}
for i := 0; i < len(magicBody); i++ {
if r.buf[i] != magicBody[i] {
r.err = ErrCorrupt
return r.err
}
}
continue
}
if chunkType <= 0x7f {
// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
r.err = ErrUnsupported
return r.err
}
// Section 4.4 Padding (chunk type 0xfe).
// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
if !r.readFull(r.buf[:chunkLen], false) {
return r.err
}
}
return nil
}
// Read satisfies the io.Reader interface.
func (r *Reader) Read(p []byte) (int, error) {
if r.err != nil {
return 0, r.err
}
if err := r.fill(); err != nil {
return 0, err
}
n := copy(p, r.decoded[r.i:r.j])
r.i += n
return n, nil
}
// ReadByte satisfies the io.ByteReader interface.
func (r *Reader) ReadByte() (byte, error) {
if r.err != nil {
return 0, r.err
}
if err := r.fill(); err != nil {
return 0, err
}
c := r.decoded[r.i]
r.i++
return c, nil
}

@ -0,0 +1,490 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The asm code generally follows the pure Go code in decode_other.go, except
// where marked with a "!!!".
// func decode(dst, src []byte) int
//
// All local variables fit into registers. The non-zero stack size is only to
// spill registers and push args when issuing a CALL. The register allocation:
// - AX scratch
// - BX scratch
// - CX length or x
// - DX offset
// - SI &src[s]
// - DI &dst[d]
// + R8 dst_base
// + R9 dst_len
// + R10 dst_base + dst_len
// + R11 src_base
// + R12 src_len
// + R13 src_base + src_len
// - R14 used by doCopy
// - R15 used by doCopy
//
// The registers R8-R13 (marked with a "+") are set at the start of the
// function, and after a CALL returns, and are not otherwise modified.
//
// The d variable is implicitly DI - R8, and len(dst)-d is R10 - DI.
// The s variable is implicitly SI - R11, and len(src)-s is R13 - SI.
TEXT ·decode(SB), NOSPLIT, $48-56
// Initialize SI, DI and R8-R13.
MOVQ dst_base+0(FP), R8
MOVQ dst_len+8(FP), R9
MOVQ R8, DI
MOVQ R8, R10
ADDQ R9, R10
MOVQ src_base+24(FP), R11
MOVQ src_len+32(FP), R12
MOVQ R11, SI
MOVQ R11, R13
ADDQ R12, R13
loop:
// for s < len(src)
CMPQ SI, R13
JEQ end
// CX = uint32(src[s])
//
// switch src[s] & 0x03
MOVBLZX (SI), CX
MOVL CX, BX
ANDL $3, BX
CMPL BX, $1
JAE tagCopy
// ----------------------------------------
// The code below handles literal tags.
// case tagLiteral:
// x := uint32(src[s] >> 2)
// switch
SHRL $2, CX
CMPL CX, $60
JAE tagLit60Plus
// case x < 60:
// s++
INCQ SI
doLit:
// This is the end of the inner "switch", when we have a literal tag.
//
// We assume that CX == x and x fits in a uint32, where x is the variable
// used in the pure Go decode_other.go code.
// length = int(x) + 1
//
// Unlike the pure Go code, we don't need to check if length <= 0 because
// CX can hold 64 bits, so the increment cannot overflow.
INCQ CX
// Prepare to check if copying length bytes will run past the end of dst or
// src.
//
// AX = len(dst) - d
// BX = len(src) - s
MOVQ R10, AX
SUBQ DI, AX
MOVQ R13, BX
SUBQ SI, BX
// !!! Try a faster technique for short (16 or fewer bytes) copies.
//
// if length > 16 || len(dst)-d < 16 || len(src)-s < 16 {
// goto callMemmove // Fall back on calling runtime·memmove.
// }
//
// The C++ snappy code calls this TryFastAppend. It also checks len(src)-s
// against 21 instead of 16, because it cannot assume that all of its input
// is contiguous in memory and so it needs to leave enough source bytes to
// read the next tag without refilling buffers, but Go's Decode assumes
// contiguousness (the src argument is a []byte).
CMPQ CX, $16
JGT callMemmove
CMPQ AX, $16
JLT callMemmove
CMPQ BX, $16
JLT callMemmove
// !!! Implement the copy from src to dst as a 16-byte load and store.
// (Decode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only length bytes, but that's
// OK. If the input is a valid Snappy encoding then subsequent iterations
// will fix up the overrun. Otherwise, Decode returns a nil []byte (and a
// non-nil error), so the overrun will be ignored.
//
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
MOVOU 0(SI), X0
MOVOU X0, 0(DI)
// d += length
// s += length
ADDQ CX, DI
ADDQ CX, SI
JMP loop
callMemmove:
// if length > len(dst)-d || length > len(src)-s { etc }
CMPQ CX, AX
JGT errCorrupt
CMPQ CX, BX
JGT errCorrupt
// copy(dst[d:], src[s:s+length])
//
// This means calling runtime·memmove(&dst[d], &src[s], length), so we push
// DI, SI and CX as arguments. Coincidentally, we also need to spill those
// three registers to the stack, to save local variables across the CALL.
MOVQ DI, 0(SP)
MOVQ SI, 8(SP)
MOVQ CX, 16(SP)
MOVQ DI, 24(SP)
MOVQ SI, 32(SP)
MOVQ CX, 40(SP)
CALL runtime·memmove(SB)
// Restore local variables: unspill registers from the stack and
// re-calculate R8-R13.
MOVQ 24(SP), DI
MOVQ 32(SP), SI
MOVQ 40(SP), CX
MOVQ dst_base+0(FP), R8
MOVQ dst_len+8(FP), R9
MOVQ R8, R10
ADDQ R9, R10
MOVQ src_base+24(FP), R11
MOVQ src_len+32(FP), R12
MOVQ R11, R13
ADDQ R12, R13
// d += length
// s += length
ADDQ CX, DI
ADDQ CX, SI
JMP loop
tagLit60Plus:
// !!! This fragment does the
//
// s += x - 58; if uint(s) > uint(len(src)) { etc }
//
// checks. In the asm version, we code it once instead of once per switch case.
ADDQ CX, SI
SUBQ $58, SI
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// case x == 60:
CMPL CX, $61
JEQ tagLit61
JA tagLit62Plus
// x = uint32(src[s-1])
MOVBLZX -1(SI), CX
JMP doLit
tagLit61:
// case x == 61:
// x = uint32(src[s-2]) | uint32(src[s-1])<<8
MOVWLZX -2(SI), CX
JMP doLit
tagLit62Plus:
CMPL CX, $62
JA tagLit63
// case x == 62:
// x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
MOVWLZX -3(SI), CX
MOVBLZX -1(SI), BX
SHLL $16, BX
ORL BX, CX
JMP doLit
tagLit63:
// case x == 63:
// x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
MOVL -4(SI), CX
JMP doLit
// The code above handles literal tags.
// ----------------------------------------
// The code below handles copy tags.
tagCopy4:
// case tagCopy4:
// s += 5
ADDQ $5, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// length = 1 + int(src[s-5])>>2
SHRQ $2, CX
INCQ CX
// offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
MOVLQZX -4(SI), DX
JMP doCopy
tagCopy2:
// case tagCopy2:
// s += 3
ADDQ $3, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// length = 1 + int(src[s-3])>>2
SHRQ $2, CX
INCQ CX
// offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
MOVWQZX -2(SI), DX
JMP doCopy
tagCopy:
// We have a copy tag. We assume that:
// - BX == src[s] & 0x03
// - CX == src[s]
CMPQ BX, $2
JEQ tagCopy2
JA tagCopy4
// case tagCopy1:
// s += 2
ADDQ $2, SI
// if uint(s) > uint(len(src)) { etc }
MOVQ SI, BX
SUBQ R11, BX
CMPQ BX, R12
JA errCorrupt
// offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
MOVQ CX, DX
ANDQ $0xe0, DX
SHLQ $3, DX
MOVBQZX -1(SI), BX
ORQ BX, DX
// length = 4 + int(src[s-2])>>2&0x7
SHRQ $2, CX
ANDQ $7, CX
ADDQ $4, CX
doCopy:
// This is the end of the outer "switch", when we have a copy tag.
//
// We assume that:
// - CX == length && CX > 0
// - DX == offset
// if offset <= 0 { etc }
CMPQ DX, $0
JLE errCorrupt
// if d < offset { etc }
MOVQ DI, BX
SUBQ R8, BX
CMPQ BX, DX
JLT errCorrupt
// if length > len(dst)-d { etc }
MOVQ R10, BX
SUBQ DI, BX
CMPQ CX, BX
JGT errCorrupt
// forwardCopy(dst[d:d+length], dst[d-offset:]); d += length
//
// Set:
// - R14 = len(dst)-d
// - R15 = &dst[d-offset]
MOVQ R10, R14
SUBQ DI, R14
MOVQ DI, R15
SUBQ DX, R15
// !!! Try a faster technique for short (16 or fewer bytes) forward copies.
//
// First, try using two 8-byte load/stores, similar to the doLit technique
// above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is
// still OK if offset >= 8. Note that this has to be two 8-byte load/stores
// and not one 16-byte load/store, and the first store has to be before the
// second load, due to the overlap if offset is in the range [8, 16).
//
// if length > 16 || offset < 8 || len(dst)-d < 16 {
// goto slowForwardCopy
// }
// copy 16 bytes
// d += length
CMPQ CX, $16
JGT slowForwardCopy
CMPQ DX, $8
JLT slowForwardCopy
CMPQ R14, $16
JLT slowForwardCopy
MOVQ 0(R15), AX
MOVQ AX, 0(DI)
MOVQ 8(R15), BX
MOVQ BX, 8(DI)
ADDQ CX, DI
JMP loop
slowForwardCopy:
// !!! If the forward copy is longer than 16 bytes, or if offset < 8, we
// can still try 8-byte load stores, provided we can overrun up to 10 extra
// bytes. As above, the overrun will be fixed up by subsequent iterations
// of the outermost loop.
//
// The C++ snappy code calls this technique IncrementalCopyFastPath. Its
// commentary says:
//
// ----
//
// The main part of this loop is a simple copy of eight bytes at a time
// until we've copied (at least) the requested amount of bytes. However,
// if d and d-offset are less than eight bytes apart (indicating a
// repeating pattern of length < 8), we first need to expand the pattern in
// order to get the correct results. For instance, if the buffer looks like
// this, with the eight-byte <d-offset> and <d> patterns marked as
// intervals:
//
// abxxxxxxxxxxxx
// [------] d-offset
// [------] d
//
// a single eight-byte copy from <d-offset> to <d> will repeat the pattern
// once, after which we can move <d> two bytes without moving <d-offset>:
//
// ababxxxxxxxxxx
// [------] d-offset
// [------] d
//
// and repeat the exercise until the two no longer overlap.
//
// This allows us to do very well in the special case of one single byte
// repeated many times, without taking a big hit for more general cases.
//
// The worst case of extra writing past the end of the match occurs when
// offset == 1 and length == 1; the last copy will read from byte positions
// [0..7] and write to [4..11], whereas it was only supposed to write to
// position 1. Thus, ten excess bytes.
//
// ----
//
// That "10 byte overrun" worst case is confirmed by Go's
// TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy
// and finishSlowForwardCopy algorithm.
//
// if length > len(dst)-d-10 {
// goto verySlowForwardCopy
// }
SUBQ $10, R14
CMPQ CX, R14
JGT verySlowForwardCopy
makeOffsetAtLeast8:
// !!! As above, expand the pattern so that offset >= 8 and we can use
// 8-byte load/stores.
//
// for offset < 8 {
// copy 8 bytes from dst[d-offset:] to dst[d:]
// length -= offset
// d += offset
// offset += offset
// // The two previous lines together means that d-offset, and therefore
// // R15, is unchanged.
// }
CMPQ DX, $8
JGE fixUpSlowForwardCopy
MOVQ (R15), BX
MOVQ BX, (DI)
SUBQ DX, CX
ADDQ DX, DI
ADDQ DX, DX
JMP makeOffsetAtLeast8
fixUpSlowForwardCopy:
// !!! Add length (which might be negative now) to d (implied by DI being
// &dst[d]) so that d ends up at the right place when we jump back to the
// top of the loop. Before we do that, though, we save DI to AX so that, if
// length is positive, copying the remaining length bytes will write to the
// right place.
MOVQ DI, AX
ADDQ CX, DI
finishSlowForwardCopy:
// !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative
// length means that we overrun, but as above, that will be fixed up by
// subsequent iterations of the outermost loop.
CMPQ CX, $0
JLE loop
MOVQ (R15), BX
MOVQ BX, (AX)
ADDQ $8, R15
ADDQ $8, AX
SUBQ $8, CX
JMP finishSlowForwardCopy
verySlowForwardCopy:
// verySlowForwardCopy is a simple implementation of forward copy. In C
// parlance, this is a do/while loop instead of a while loop, since we know
// that length > 0. In Go syntax:
//
// for {
// dst[d] = dst[d - offset]
// d++
// length--
// if length == 0 {
// break
// }
// }
MOVB (R15), BX
MOVB BX, (DI)
INCQ R15
INCQ DI
DECQ CX
JNZ verySlowForwardCopy
JMP loop
// The code above handles copy tags.
// ----------------------------------------
end:
// This is the end of the "for s < len(src)".
//
// if d != len(dst) { etc }
CMPQ DI, R10
JNE errCorrupt
// return 0
MOVQ $0, ret+48(FP)
RET
errCorrupt:
// return decodeErrCodeCorrupt
MOVQ $1, ret+48(FP)
RET

@ -0,0 +1,494 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The asm code generally follows the pure Go code in decode_other.go, except
// where marked with a "!!!".
// func decode(dst, src []byte) int
//
// All local variables fit into registers. The non-zero stack size is only to
// spill registers and push args when issuing a CALL. The register allocation:
// - R2 scratch
// - R3 scratch
// - R4 length or x
// - R5 offset
// - R6 &src[s]
// - R7 &dst[d]
// + R8 dst_base
// + R9 dst_len
// + R10 dst_base + dst_len
// + R11 src_base
// + R12 src_len
// + R13 src_base + src_len
// - R14 used by doCopy
// - R15 used by doCopy
//
// The registers R8-R13 (marked with a "+") are set at the start of the
// function, and after a CALL returns, and are not otherwise modified.
//
// The d variable is implicitly R7 - R8, and len(dst)-d is R10 - R7.
// The s variable is implicitly R6 - R11, and len(src)-s is R13 - R6.
TEXT ·decode(SB), NOSPLIT, $56-56
// Initialize R6, R7 and R8-R13.
MOVD dst_base+0(FP), R8
MOVD dst_len+8(FP), R9
MOVD R8, R7
MOVD R8, R10
ADD R9, R10, R10
MOVD src_base+24(FP), R11
MOVD src_len+32(FP), R12
MOVD R11, R6
MOVD R11, R13
ADD R12, R13, R13
loop:
// for s < len(src)
CMP R13, R6
BEQ end
// R4 = uint32(src[s])
//
// switch src[s] & 0x03
MOVBU (R6), R4
MOVW R4, R3
ANDW $3, R3
MOVW $1, R1
CMPW R1, R3
BGE tagCopy
// ----------------------------------------
// The code below handles literal tags.
// case tagLiteral:
// x := uint32(src[s] >> 2)
// switch
MOVW $60, R1
LSRW $2, R4, R4
CMPW R4, R1
BLS tagLit60Plus
// case x < 60:
// s++
ADD $1, R6, R6
doLit:
// This is the end of the inner "switch", when we have a literal tag.
//
// We assume that R4 == x and x fits in a uint32, where x is the variable
// used in the pure Go decode_other.go code.
// length = int(x) + 1
//
// Unlike the pure Go code, we don't need to check if length <= 0 because
// R4 can hold 64 bits, so the increment cannot overflow.
ADD $1, R4, R4
// Prepare to check if copying length bytes will run past the end of dst or
// src.
//
// R2 = len(dst) - d
// R3 = len(src) - s
MOVD R10, R2
SUB R7, R2, R2
MOVD R13, R3
SUB R6, R3, R3
// !!! Try a faster technique for short (16 or fewer bytes) copies.
//
// if length > 16 || len(dst)-d < 16 || len(src)-s < 16 {
// goto callMemmove // Fall back on calling runtime·memmove.
// }
//
// The C++ snappy code calls this TryFastAppend. It also checks len(src)-s
// against 21 instead of 16, because it cannot assume that all of its input
// is contiguous in memory and so it needs to leave enough source bytes to
// read the next tag without refilling buffers, but Go's Decode assumes
// contiguousness (the src argument is a []byte).
CMP $16, R4
BGT callMemmove
CMP $16, R2
BLT callMemmove
CMP $16, R3
BLT callMemmove
// !!! Implement the copy from src to dst as a 16-byte load and store.
// (Decode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only length bytes, but that's
// OK. If the input is a valid Snappy encoding then subsequent iterations
// will fix up the overrun. Otherwise, Decode returns a nil []byte (and a
// non-nil error), so the overrun will be ignored.
//
// Note that on arm64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
LDP 0(R6), (R14, R15)
STP (R14, R15), 0(R7)
// d += length
// s += length
ADD R4, R7, R7
ADD R4, R6, R6
B loop
callMemmove:
// if length > len(dst)-d || length > len(src)-s { etc }
CMP R2, R4
BGT errCorrupt
CMP R3, R4
BGT errCorrupt
// copy(dst[d:], src[s:s+length])
//
// This means calling runtime·memmove(&dst[d], &src[s], length), so we push
// R7, R6 and R4 as arguments. Coincidentally, we also need to spill those
// three registers to the stack, to save local variables across the CALL.
MOVD R7, 8(RSP)
MOVD R6, 16(RSP)
MOVD R4, 24(RSP)
MOVD R7, 32(RSP)
MOVD R6, 40(RSP)
MOVD R4, 48(RSP)
CALL runtime·memmove(SB)
// Restore local variables: unspill registers from the stack and
// re-calculate R8-R13.
MOVD 32(RSP), R7
MOVD 40(RSP), R6
MOVD 48(RSP), R4
MOVD dst_base+0(FP), R8
MOVD dst_len+8(FP), R9
MOVD R8, R10
ADD R9, R10, R10
MOVD src_base+24(FP), R11
MOVD src_len+32(FP), R12
MOVD R11, R13
ADD R12, R13, R13
// d += length
// s += length
ADD R4, R7, R7
ADD R4, R6, R6
B loop
tagLit60Plus:
// !!! This fragment does the
//
// s += x - 58; if uint(s) > uint(len(src)) { etc }
//
// checks. In the asm version, we code it once instead of once per switch case.
ADD R4, R6, R6
SUB $58, R6, R6
MOVD R6, R3
SUB R11, R3, R3
CMP R12, R3
BGT errCorrupt
// case x == 60:
MOVW $61, R1
CMPW R1, R4
BEQ tagLit61
BGT tagLit62Plus
// x = uint32(src[s-1])
MOVBU -1(R6), R4
B doLit
tagLit61:
// case x == 61:
// x = uint32(src[s-2]) | uint32(src[s-1])<<8
MOVHU -2(R6), R4
B doLit
tagLit62Plus:
CMPW $62, R4
BHI tagLit63
// case x == 62:
// x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
MOVHU -3(R6), R4
MOVBU -1(R6), R3
ORR R3<<16, R4
B doLit
tagLit63:
// case x == 63:
// x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
MOVWU -4(R6), R4
B doLit
// The code above handles literal tags.
// ----------------------------------------
// The code below handles copy tags.
tagCopy4:
// case tagCopy4:
// s += 5
ADD $5, R6, R6
// if uint(s) > uint(len(src)) { etc }
MOVD R6, R3
SUB R11, R3, R3
CMP R12, R3
BGT errCorrupt
// length = 1 + int(src[s-5])>>2
MOVD $1, R1
ADD R4>>2, R1, R4
// offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
MOVWU -4(R6), R5
B doCopy
tagCopy2:
// case tagCopy2:
// s += 3
ADD $3, R6, R6
// if uint(s) > uint(len(src)) { etc }
MOVD R6, R3
SUB R11, R3, R3
CMP R12, R3
BGT errCorrupt
// length = 1 + int(src[s-3])>>2
MOVD $1, R1
ADD R4>>2, R1, R4
// offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
MOVHU -2(R6), R5
B doCopy
tagCopy:
// We have a copy tag. We assume that:
// - R3 == src[s] & 0x03
// - R4 == src[s]
CMP $2, R3
BEQ tagCopy2
BGT tagCopy4
// case tagCopy1:
// s += 2
ADD $2, R6, R6
// if uint(s) > uint(len(src)) { etc }
MOVD R6, R3
SUB R11, R3, R3
CMP R12, R3
BGT errCorrupt
// offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
MOVD R4, R5
AND $0xe0, R5
MOVBU -1(R6), R3
ORR R5<<3, R3, R5
// length = 4 + int(src[s-2])>>2&0x7
MOVD $7, R1
AND R4>>2, R1, R4
ADD $4, R4, R4
doCopy:
// This is the end of the outer "switch", when we have a copy tag.
//
// We assume that:
// - R4 == length && R4 > 0
// - R5 == offset
// if offset <= 0 { etc }
MOVD $0, R1
CMP R1, R5
BLE errCorrupt
// if d < offset { etc }
MOVD R7, R3
SUB R8, R3, R3
CMP R5, R3
BLT errCorrupt
// if length > len(dst)-d { etc }
MOVD R10, R3
SUB R7, R3, R3
CMP R3, R4
BGT errCorrupt
// forwardCopy(dst[d:d+length], dst[d-offset:]); d += length
//
// Set:
// - R14 = len(dst)-d
// - R15 = &dst[d-offset]
MOVD R10, R14
SUB R7, R14, R14
MOVD R7, R15
SUB R5, R15, R15
// !!! Try a faster technique for short (16 or fewer bytes) forward copies.
//
// First, try using two 8-byte load/stores, similar to the doLit technique
// above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is
// still OK if offset >= 8. Note that this has to be two 8-byte load/stores
// and not one 16-byte load/store, and the first store has to be before the
// second load, due to the overlap if offset is in the range [8, 16).
//
// if length > 16 || offset < 8 || len(dst)-d < 16 {
// goto slowForwardCopy
// }
// copy 16 bytes
// d += length
CMP $16, R4
BGT slowForwardCopy
CMP $8, R5
BLT slowForwardCopy
CMP $16, R14
BLT slowForwardCopy
MOVD 0(R15), R2
MOVD R2, 0(R7)
MOVD 8(R15), R3
MOVD R3, 8(R7)
ADD R4, R7, R7
B loop
slowForwardCopy:
// !!! If the forward copy is longer than 16 bytes, or if offset < 8, we
// can still try 8-byte load stores, provided we can overrun up to 10 extra
// bytes. As above, the overrun will be fixed up by subsequent iterations
// of the outermost loop.
//
// The C++ snappy code calls this technique IncrementalCopyFastPath. Its
// commentary says:
//
// ----
//
// The main part of this loop is a simple copy of eight bytes at a time
// until we've copied (at least) the requested amount of bytes. However,
// if d and d-offset are less than eight bytes apart (indicating a
// repeating pattern of length < 8), we first need to expand the pattern in
// order to get the correct results. For instance, if the buffer looks like
// this, with the eight-byte <d-offset> and <d> patterns marked as
// intervals:
//
// abxxxxxxxxxxxx
// [------] d-offset
// [------] d
//
// a single eight-byte copy from <d-offset> to <d> will repeat the pattern
// once, after which we can move <d> two bytes without moving <d-offset>:
//
// ababxxxxxxxxxx
// [------] d-offset
// [------] d
//
// and repeat the exercise until the two no longer overlap.
//
// This allows us to do very well in the special case of one single byte
// repeated many times, without taking a big hit for more general cases.
//
// The worst case of extra writing past the end of the match occurs when
// offset == 1 and length == 1; the last copy will read from byte positions
// [0..7] and write to [4..11], whereas it was only supposed to write to
// position 1. Thus, ten excess bytes.
//
// ----
//
// That "10 byte overrun" worst case is confirmed by Go's
// TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy
// and finishSlowForwardCopy algorithm.
//
// if length > len(dst)-d-10 {
// goto verySlowForwardCopy
// }
SUB $10, R14, R14
CMP R14, R4
BGT verySlowForwardCopy
makeOffsetAtLeast8:
// !!! As above, expand the pattern so that offset >= 8 and we can use
// 8-byte load/stores.
//
// for offset < 8 {
// copy 8 bytes from dst[d-offset:] to dst[d:]
// length -= offset
// d += offset
// offset += offset
// // The two previous lines together means that d-offset, and therefore
// // R15, is unchanged.
// }
CMP $8, R5
BGE fixUpSlowForwardCopy
MOVD (R15), R3
MOVD R3, (R7)
SUB R5, R4, R4
ADD R5, R7, R7
ADD R5, R5, R5
B makeOffsetAtLeast8
fixUpSlowForwardCopy:
// !!! Add length (which might be negative now) to d (implied by R7 being
// &dst[d]) so that d ends up at the right place when we jump back to the
// top of the loop. Before we do that, though, we save R7 to R2 so that, if
// length is positive, copying the remaining length bytes will write to the
// right place.
MOVD R7, R2
ADD R4, R7, R7
finishSlowForwardCopy:
// !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative
// length means that we overrun, but as above, that will be fixed up by
// subsequent iterations of the outermost loop.
MOVD $0, R1
CMP R1, R4
BLE loop
MOVD (R15), R3
MOVD R3, (R2)
ADD $8, R15, R15
ADD $8, R2, R2
SUB $8, R4, R4
B finishSlowForwardCopy
verySlowForwardCopy:
// verySlowForwardCopy is a simple implementation of forward copy. In C
// parlance, this is a do/while loop instead of a while loop, since we know
// that length > 0. In Go syntax:
//
// for {
// dst[d] = dst[d - offset]
// d++
// length--
// if length == 0 {
// break
// }
// }
MOVB (R15), R3
MOVB R3, (R7)
ADD $1, R15, R15
ADD $1, R7, R7
SUB $1, R4, R4
CBNZ R4, verySlowForwardCopy
B loop
// The code above handles copy tags.
// ----------------------------------------
end:
// This is the end of the "for s < len(src)".
//
// if d != len(dst) { etc }
CMP R10, R7
BNE errCorrupt
// return 0
MOVD $0, ret+48(FP)
RET
errCorrupt:
// return decodeErrCodeCorrupt
MOVD $1, R2
MOVD R2, ret+48(FP)
RET

@ -0,0 +1,15 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
// +build amd64 arm64
package snappy
// decode has the same semantics as in decode_other.go.
//
//go:noescape
func decode(dst, src []byte) int

@ -0,0 +1,115 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64,!arm64 appengine !gc noasm
package snappy
// decode writes the decoding of src to dst. It assumes that the varint-encoded
// length of the decompressed bytes has already been read, and that len(dst)
// equals that length.
//
// It returns 0 on success or a decodeErrCodeXxx error code on failure.
func decode(dst, src []byte) int {
var d, s, offset, length int
for s < len(src) {
switch src[s] & 0x03 {
case tagLiteral:
x := uint32(src[s] >> 2)
switch {
case x < 60:
s++
case x == 60:
s += 2
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-1])
case x == 61:
s += 3
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-2]) | uint32(src[s-1])<<8
case x == 62:
s += 4
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16
case x == 63:
s += 5
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24
}
length = int(x) + 1
if length <= 0 {
return decodeErrCodeUnsupportedLiteralLength
}
if length > len(dst)-d || length > len(src)-s {
return decodeErrCodeCorrupt
}
copy(dst[d:], src[s:s+length])
d += length
s += length
continue
case tagCopy1:
s += 2
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
length = 4 + int(src[s-2])>>2&0x7
offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1]))
case tagCopy2:
s += 3
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
length = 1 + int(src[s-3])>>2
offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8)
case tagCopy4:
s += 5
if uint(s) > uint(len(src)) { // The uint conversions catch overflow from the previous line.
return decodeErrCodeCorrupt
}
length = 1 + int(src[s-5])>>2
offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24)
}
if offset <= 0 || d < offset || length > len(dst)-d {
return decodeErrCodeCorrupt
}
// Copy from an earlier sub-slice of dst to a later sub-slice.
// If no overlap, use the built-in copy:
if offset >= length {
copy(dst[d:d+length], dst[d-offset:])
d += length
continue
}
// Unlike the built-in copy function, this byte-by-byte copy always runs
// forwards, even if the slices overlap. Conceptually, this is:
//
// d += forwardCopy(dst[d:d+length], dst[d-offset:])
//
// We align the slices into a and b and show the compiler they are the same size.
// This allows the loop to run without bounds checks.
a := dst[d : d+length]
b := dst[d-offset:]
b = b[:len(a)]
for i := range a {
a[i] = b[i]
}
d += length
}
if d != len(dst) {
return decodeErrCodeCorrupt
}
return 0
}

@ -0,0 +1,289 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
"io"
)
// Encode returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
//
// The dst and src must not overlap. It is valid to pass a nil dst.
//
// Encode handles the Snappy block format, not the Snappy stream format.
func Encode(dst, src []byte) []byte {
if n := MaxEncodedLen(len(src)); n < 0 {
panic(ErrTooLarge)
} else if len(dst) < n {
dst = make([]byte, n)
}
// The block starts with the varint-encoded length of the decompressed bytes.
d := binary.PutUvarint(dst, uint64(len(src)))
for len(src) > 0 {
p := src
src = nil
if len(p) > maxBlockSize {
p, src = p[:maxBlockSize], p[maxBlockSize:]
}
if len(p) < minNonLiteralBlockSize {
d += emitLiteral(dst[d:], p)
} else {
d += encodeBlock(dst[d:], p)
}
}
return dst[:d]
}
// inputMargin is the minimum number of extra input bytes to keep, inside
// encodeBlock's inner loop. On some architectures, this margin lets us
// implement a fast path for emitLiteral, where the copy of short (<= 16 byte)
// literals can be implemented as a single load to and store from a 16-byte
// register. That literal's actual length can be as short as 1 byte, so this
// can copy up to 15 bytes too much, but that's OK as subsequent iterations of
// the encoding loop will fix up the copy overrun, and this inputMargin ensures
// that we don't overrun the dst and src buffers.
const inputMargin = 16 - 1
// minNonLiteralBlockSize is the minimum size of the input to encodeBlock that
// could be encoded with a copy tag. This is the minimum with respect to the
// algorithm used by encodeBlock, not a minimum enforced by the file format.
//
// The encoded output must start with at least a 1 byte literal, as there are
// no previous bytes to copy. A minimal (1 byte) copy after that, generated
// from an emitCopy call in encodeBlock's main loop, would require at least
// another inputMargin bytes, for the reason above: we want any emitLiteral
// calls inside encodeBlock's main loop to use the fast path if possible, which
// requires being able to overrun by inputMargin bytes. Thus,
// minNonLiteralBlockSize equals 1 + 1 + inputMargin.
//
// The C++ code doesn't use this exact threshold, but it could, as discussed at
// https://groups.google.com/d/topic/snappy-compression/oGbhsdIJSJ8/discussion
// The difference between Go (2+inputMargin) and C++ (inputMargin) is purely an
// optimization. It should not affect the encoded form. This is tested by
// TestSameEncodingAsCppShortCopies.
const minNonLiteralBlockSize = 1 + 1 + inputMargin
// MaxEncodedLen returns the maximum length of a snappy block, given its
// uncompressed length.
//
// It will return a negative value if srcLen is too large to encode.
func MaxEncodedLen(srcLen int) int {
n := uint64(srcLen)
if n > 0xffffffff {
return -1
}
// Compressed data can be defined as:
// compressed := item* literal*
// item := literal* copy
//
// The trailing literal sequence has a space blowup of at most 62/60
// since a literal of length 60 needs one tag byte + one extra byte
// for length information.
//
// Item blowup is trickier to measure. Suppose the "copy" op copies
// 4 bytes of data. Because of a special check in the encoding code,
// we produce a 4-byte copy only if the offset is < 65536. Therefore
// the copy op takes 3 bytes to encode, and this type of item leads
// to at most the 62/60 blowup for representing literals.
//
// Suppose the "copy" op copies 5 bytes of data. If the offset is big
// enough, it will take 5 bytes to encode the copy op. Therefore the
// worst case here is a one-byte literal followed by a five-byte copy.
// That is, 6 bytes of input turn into 7 bytes of "compressed" data.
//
// This last factor dominates the blowup, so the final estimate is:
n = 32 + n + n/6
if n > 0xffffffff {
return -1
}
return int(n)
}
var errClosed = errors.New("snappy: Writer is closed")
// NewWriter returns a new Writer that compresses to w.
//
// The Writer returned does not buffer writes. There is no need to Flush or
// Close such a Writer.
//
// Deprecated: the Writer returned is not suitable for many small writes, only
// for few large writes. Use NewBufferedWriter instead, which is efficient
// regardless of the frequency and shape of the writes, and remember to Close
// that Writer when done.
func NewWriter(w io.Writer) *Writer {
return &Writer{
w: w,
obuf: make([]byte, obufLen),
}
}
// NewBufferedWriter returns a new Writer that compresses to w, using the
// framing format described at
// https://github.com/google/snappy/blob/master/framing_format.txt
//
// The Writer returned buffers writes. Users must call Close to guarantee all
// data has been forwarded to the underlying io.Writer. They may also call
// Flush zero or more times before calling Close.
func NewBufferedWriter(w io.Writer) *Writer {
return &Writer{
w: w,
ibuf: make([]byte, 0, maxBlockSize),
obuf: make([]byte, obufLen),
}
}
// Writer is an io.Writer that can write Snappy-compressed bytes.
//
// Writer handles the Snappy stream format, not the Snappy block format.
type Writer struct {
w io.Writer
err error
// ibuf is a buffer for the incoming (uncompressed) bytes.
//
// Its use is optional. For backwards compatibility, Writers created by the
// NewWriter function have ibuf == nil, do not buffer incoming bytes, and
// therefore do not need to be Flush'ed or Close'd.
ibuf []byte
// obuf is a buffer for the outgoing (compressed) bytes.
obuf []byte
// wroteStreamHeader is whether we have written the stream header.
wroteStreamHeader bool
}
// Reset discards the writer's state and switches the Snappy writer to write to
// w. This permits reusing a Writer rather than allocating a new one.
func (w *Writer) Reset(writer io.Writer) {
w.w = writer
w.err = nil
if w.ibuf != nil {
w.ibuf = w.ibuf[:0]
}
w.wroteStreamHeader = false
}
// Write satisfies the io.Writer interface.
func (w *Writer) Write(p []byte) (nRet int, errRet error) {
if w.ibuf == nil {
// Do not buffer incoming bytes. This does not perform or compress well
// if the caller of Writer.Write writes many small slices. This
// behavior is therefore deprecated, but still supported for backwards
// compatibility with code that doesn't explicitly Flush or Close.
return w.write(p)
}
// The remainder of this method is based on bufio.Writer.Write from the
// standard library.
for len(p) > (cap(w.ibuf)-len(w.ibuf)) && w.err == nil {
var n int
if len(w.ibuf) == 0 {
// Large write, empty buffer.
// Write directly from p to avoid copy.
n, _ = w.write(p)
} else {
n = copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
w.ibuf = w.ibuf[:len(w.ibuf)+n]
w.Flush()
}
nRet += n
p = p[n:]
}
if w.err != nil {
return nRet, w.err
}
n := copy(w.ibuf[len(w.ibuf):cap(w.ibuf)], p)
w.ibuf = w.ibuf[:len(w.ibuf)+n]
nRet += n
return nRet, nil
}
func (w *Writer) write(p []byte) (nRet int, errRet error) {
if w.err != nil {
return 0, w.err
}
for len(p) > 0 {
obufStart := len(magicChunk)
if !w.wroteStreamHeader {
w.wroteStreamHeader = true
copy(w.obuf, magicChunk)
obufStart = 0
}
var uncompressed []byte
if len(p) > maxBlockSize {
uncompressed, p = p[:maxBlockSize], p[maxBlockSize:]
} else {
uncompressed, p = p, nil
}
checksum := crc(uncompressed)
// Compress the buffer, discarding the result if the improvement
// isn't at least 12.5%.
compressed := Encode(w.obuf[obufHeaderLen:], uncompressed)
chunkType := uint8(chunkTypeCompressedData)
chunkLen := 4 + len(compressed)
obufEnd := obufHeaderLen + len(compressed)
if len(compressed) >= len(uncompressed)-len(uncompressed)/8 {
chunkType = chunkTypeUncompressedData
chunkLen = 4 + len(uncompressed)
obufEnd = obufHeaderLen
}
// Fill in the per-chunk header that comes before the body.
w.obuf[len(magicChunk)+0] = chunkType
w.obuf[len(magicChunk)+1] = uint8(chunkLen >> 0)
w.obuf[len(magicChunk)+2] = uint8(chunkLen >> 8)
w.obuf[len(magicChunk)+3] = uint8(chunkLen >> 16)
w.obuf[len(magicChunk)+4] = uint8(checksum >> 0)
w.obuf[len(magicChunk)+5] = uint8(checksum >> 8)
w.obuf[len(magicChunk)+6] = uint8(checksum >> 16)
w.obuf[len(magicChunk)+7] = uint8(checksum >> 24)
if _, err := w.w.Write(w.obuf[obufStart:obufEnd]); err != nil {
w.err = err
return nRet, err
}
if chunkType == chunkTypeUncompressedData {
if _, err := w.w.Write(uncompressed); err != nil {
w.err = err
return nRet, err
}
}
nRet += len(uncompressed)
}
return nRet, nil
}
// Flush flushes the Writer to its underlying io.Writer.
func (w *Writer) Flush() error {
if w.err != nil {
return w.err
}
if len(w.ibuf) == 0 {
return nil
}
w.write(w.ibuf)
w.ibuf = w.ibuf[:0]
return w.err
}
// Close calls Flush and then closes the Writer.
func (w *Writer) Close() error {
w.Flush()
ret := w.err
if w.err == nil {
w.err = errClosed
}
return ret
}

@ -0,0 +1,730 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The XXX lines assemble on Go 1.4, 1.5 and 1.7, but not 1.6, due to a
// Go toolchain regression. See https://github.com/golang/go/issues/15426 and
// https://github.com/golang/snappy/issues/29
//
// As a workaround, the package was built with a known good assembler, and
// those instructions were disassembled by "objdump -d" to yield the
// 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
// style comments, in AT&T asm syntax. Note that rsp here is a physical
// register, not Go/asm's SP pseudo-register (see https://golang.org/doc/asm).
// The instructions were then encoded as "BYTE $0x.." sequences, which assemble
// fine on Go 1.6.
// The asm code generally follows the pure Go code in encode_other.go, except
// where marked with a "!!!".
// ----------------------------------------------------------------------------
// func emitLiteral(dst, lit []byte) int
//
// All local variables fit into registers. The register allocation:
// - AX len(lit)
// - BX n
// - DX return value
// - DI &dst[i]
// - R10 &lit[0]
//
// The 24 bytes of stack space is to call runtime·memmove.
//
// The unusual register allocation of local variables, such as R10 for the
// source pointer, matches the allocation used at the call site in encodeBlock,
// which makes it easier to manually inline this function.
TEXT ·emitLiteral(SB), NOSPLIT, $24-56
MOVQ dst_base+0(FP), DI
MOVQ lit_base+24(FP), R10
MOVQ lit_len+32(FP), AX
MOVQ AX, DX
MOVL AX, BX
SUBL $1, BX
CMPL BX, $60
JLT oneByte
CMPL BX, $256
JLT twoBytes
threeBytes:
MOVB $0xf4, 0(DI)
MOVW BX, 1(DI)
ADDQ $3, DI
ADDQ $3, DX
JMP memmove
twoBytes:
MOVB $0xf0, 0(DI)
MOVB BX, 1(DI)
ADDQ $2, DI
ADDQ $2, DX
JMP memmove
oneByte:
SHLB $2, BX
MOVB BX, 0(DI)
ADDQ $1, DI
ADDQ $1, DX
memmove:
MOVQ DX, ret+48(FP)
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// DI, R10 and AX as arguments.
MOVQ DI, 0(SP)
MOVQ R10, 8(SP)
MOVQ AX, 16(SP)
CALL runtime·memmove(SB)
RET
// ----------------------------------------------------------------------------
// func emitCopy(dst []byte, offset, length int) int
//
// All local variables fit into registers. The register allocation:
// - AX length
// - SI &dst[0]
// - DI &dst[i]
// - R11 offset
//
// The unusual register allocation of local variables, such as R11 for the
// offset, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·emitCopy(SB), NOSPLIT, $0-48
MOVQ dst_base+0(FP), DI
MOVQ DI, SI
MOVQ offset+24(FP), R11
MOVQ length+32(FP), AX
loop0:
// for length >= 68 { etc }
CMPL AX, $68
JLT step1
// Emit a length 64 copy, encoded as 3 bytes.
MOVB $0xfe, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $64, AX
JMP loop0
step1:
// if length > 64 { etc }
CMPL AX, $64
JLE step2
// Emit a length 60 copy, encoded as 3 bytes.
MOVB $0xee, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $60, AX
step2:
// if length >= 12 || offset >= 2048 { goto step3 }
CMPL AX, $12
JGE step3
CMPL R11, $2048
JGE step3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(DI)
SHRL $8, R11
SHLB $5, R11
SUBB $4, AX
SHLB $2, AX
ORB AX, R11
ORB $1, R11
MOVB R11, 0(DI)
ADDQ $2, DI
// Return the number of bytes written.
SUBQ SI, DI
MOVQ DI, ret+40(FP)
RET
step3:
// Emit the remaining copy, encoded as 3 bytes.
SUBL $1, AX
SHLB $2, AX
ORB $2, AX
MOVB AX, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
// Return the number of bytes written.
SUBQ SI, DI
MOVQ DI, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func extendMatch(src []byte, i, j int) int
//
// All local variables fit into registers. The register allocation:
// - DX &src[0]
// - SI &src[j]
// - R13 &src[len(src) - 8]
// - R14 &src[len(src)]
// - R15 &src[i]
//
// The unusual register allocation of local variables, such as R15 for a source
// pointer, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·extendMatch(SB), NOSPLIT, $0-48
MOVQ src_base+0(FP), DX
MOVQ src_len+8(FP), R14
MOVQ i+24(FP), R15
MOVQ j+32(FP), SI
ADDQ DX, R14
ADDQ DX, R15
ADDQ DX, SI
MOVQ R14, R13
SUBQ $8, R13
cmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMPQ SI, R13
JA cmp1
MOVQ (R15), AX
MOVQ (SI), BX
CMPQ AX, BX
JNE bsf
ADDQ $8, R15
ADDQ $8, SI
JMP cmp8
bsf:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs. The BSF instruction finds the
// least significant 1 bit, the amd64 architecture is little-endian, and
// the shift by 3 converts a bit index to a byte index.
XORQ AX, BX
BSFQ BX, BX
SHRQ $3, BX
ADDQ BX, SI
// Convert from &src[ret] to ret.
SUBQ DX, SI
MOVQ SI, ret+40(FP)
RET
cmp1:
// In src's tail, compare 1 byte at a time.
CMPQ SI, R14
JAE extendMatchEnd
MOVB (R15), AX
MOVB (SI), BX
CMPB AX, BX
JNE extendMatchEnd
ADDQ $1, R15
ADDQ $1, SI
JMP cmp1
extendMatchEnd:
// Convert from &src[ret] to ret.
SUBQ DX, SI
MOVQ SI, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func encodeBlock(dst, src []byte) (d int)
//
// All local variables fit into registers, other than "var table". The register
// allocation:
// - AX . .
// - BX . .
// - CX 56 shift (note that amd64 shifts by non-immediates must use CX).
// - DX 64 &src[0], tableSize
// - SI 72 &src[s]
// - DI 80 &dst[d]
// - R9 88 sLimit
// - R10 . &src[nextEmit]
// - R11 96 prevHash, currHash, nextHash, offset
// - R12 104 &src[base], skip
// - R13 . &src[nextS], &src[len(src) - 8]
// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x
// - R15 112 candidate
//
// The second column (56, 64, etc) is the stack offset to spill the registers
// when calling other functions. We could pack this slightly tighter, but it's
// simpler to have a dedicated spill map independent of the function called.
//
// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An
// extra 56 bytes, to call other functions, and an extra 64 bytes, to spill
// local variables (registers) during calls gives 32768 + 56 + 64 = 32888.
TEXT ·encodeBlock(SB), 0, $32888-56
MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI
MOVQ src_len+32(FP), R14
// shift, tableSize := uint32(32-8), 1<<8
MOVQ $24, CX
MOVQ $256, DX
calcShift:
// for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
// shift--
// }
CMPQ DX, $16384
JGE varTable
CMPQ DX, R14
JGE varTable
SUBQ $1, CX
SHLQ $1, DX
JMP calcShift
varTable:
// var table [maxTableSize]uint16
//
// In the asm code, unlike the Go code, we can zero-initialize only the
// first tableSize elements. Each uint16 element is 2 bytes and each MOVOU
// writes 16 bytes, so we can do only tableSize/8 writes instead of the
// 2048 writes that would zero-initialize all of table's 32768 bytes.
SHRQ $3, DX
LEAQ table-32768(SP), BX
PXOR X0, X0
memclr:
MOVOU X0, 0(BX)
ADDQ $16, BX
SUBQ $1, DX
JNZ memclr
// !!! DX = &src[0]
MOVQ SI, DX
// sLimit := len(src) - inputMargin
MOVQ R14, R9
SUBQ $15, R9
// !!! Pre-emptively spill CX, DX and R9 to the stack. Their values don't
// change for the rest of the function.
MOVQ CX, 56(SP)
MOVQ DX, 64(SP)
MOVQ R9, 88(SP)
// nextEmit := 0
MOVQ DX, R10
// s := 1
ADDQ $1, SI
// nextHash := hash(load32(src, s), shift)
MOVL 0(SI), R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
outer:
// for { etc }
// skip := 32
MOVQ $32, R12
// nextS := s
MOVQ SI, R13
// candidate := 0
MOVQ $0, R15
inner0:
// for { etc }
// s := nextS
MOVQ R13, SI
// bytesBetweenHashLookups := skip >> 5
MOVQ R12, R14
SHRQ $5, R14
// nextS = s + bytesBetweenHashLookups
ADDQ R14, R13
// skip += bytesBetweenHashLookups
ADDQ R14, R12
// if nextS > sLimit { goto emitRemainder }
MOVQ R13, AX
SUBQ DX, AX
CMPQ AX, R9
JA emitRemainder
// candidate = int(table[nextHash])
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
BYTE $0x4e
BYTE $0x0f
BYTE $0xb7
BYTE $0x7c
BYTE $0x5c
BYTE $0x78
// table[nextHash] = uint16(s)
MOVQ SI, AX
SUBQ DX, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// nextHash = hash(load32(src, nextS), shift)
MOVL 0(R13), R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// if load32(src, s) != load32(src, candidate) { continue } break
MOVL 0(SI), AX
MOVL (DX)(R15*1), BX
CMPL AX, BX
JNE inner0
fourByteMatch:
// As per the encode_other.go code:
//
// A 4-byte match has been found. We'll later see etc.
// !!! Jump to a fast path for short (<= 16 byte) literals. See the comment
// on inputMargin in encode.go.
MOVQ SI, AX
SUBQ R10, AX
CMPQ AX, $16
JLE emitLiteralFastPath
// ----------------------------------------
// Begin inline of the emitLiteral call.
//
// d += emitLiteral(dst[d:], src[nextEmit:s])
MOVL AX, BX
SUBL $1, BX
CMPL BX, $60
JLT inlineEmitLiteralOneByte
CMPL BX, $256
JLT inlineEmitLiteralTwoBytes
inlineEmitLiteralThreeBytes:
MOVB $0xf4, 0(DI)
MOVW BX, 1(DI)
ADDQ $3, DI
JMP inlineEmitLiteralMemmove
inlineEmitLiteralTwoBytes:
MOVB $0xf0, 0(DI)
MOVB BX, 1(DI)
ADDQ $2, DI
JMP inlineEmitLiteralMemmove
inlineEmitLiteralOneByte:
SHLB $2, BX
MOVB BX, 0(DI)
ADDQ $1, DI
inlineEmitLiteralMemmove:
// Spill local variables (registers) onto the stack; call; unspill.
//
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// DI, R10 and AX as arguments.
MOVQ DI, 0(SP)
MOVQ R10, 8(SP)
MOVQ AX, 16(SP)
ADDQ AX, DI // Finish the "d +=" part of "d += emitLiteral(etc)".
MOVQ SI, 72(SP)
MOVQ DI, 80(SP)
MOVQ R15, 112(SP)
CALL runtime·memmove(SB)
MOVQ 56(SP), CX
MOVQ 64(SP), DX
MOVQ 72(SP), SI
MOVQ 80(SP), DI
MOVQ 88(SP), R9
MOVQ 112(SP), R15
JMP inner1
inlineEmitLiteralEnd:
// End inline of the emitLiteral call.
// ----------------------------------------
emitLiteralFastPath:
// !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2".
MOVB AX, BX
SUBB $1, BX
SHLB $2, BX
MOVB BX, (DI)
ADDQ $1, DI
// !!! Implement the copy from lit to dst as a 16-byte load and store.
// (Encode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only len(lit) bytes, but that's
// OK. Subsequent iterations will fix up the overrun.
//
// Note that on amd64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
MOVOU 0(R10), X0
MOVOU X0, 0(DI)
ADDQ AX, DI
inner1:
// for { etc }
// base := s
MOVQ SI, R12
// !!! offset := base - candidate
MOVQ R12, R11
SUBQ R15, R11
SUBQ DX, R11
// ----------------------------------------
// Begin inline of the extendMatch call.
//
// s = extendMatch(src, candidate+4, s+4)
// !!! R14 = &src[len(src)]
MOVQ src_len+32(FP), R14
ADDQ DX, R14
// !!! R13 = &src[len(src) - 8]
MOVQ R14, R13
SUBQ $8, R13
// !!! R15 = &src[candidate + 4]
ADDQ $4, R15
ADDQ DX, R15
// !!! s += 4
ADDQ $4, SI
inlineExtendMatchCmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMPQ SI, R13
JA inlineExtendMatchCmp1
MOVQ (R15), AX
MOVQ (SI), BX
CMPQ AX, BX
JNE inlineExtendMatchBSF
ADDQ $8, R15
ADDQ $8, SI
JMP inlineExtendMatchCmp8
inlineExtendMatchBSF:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs. The BSF instruction finds the
// least significant 1 bit, the amd64 architecture is little-endian, and
// the shift by 3 converts a bit index to a byte index.
XORQ AX, BX
BSFQ BX, BX
SHRQ $3, BX
ADDQ BX, SI
JMP inlineExtendMatchEnd
inlineExtendMatchCmp1:
// In src's tail, compare 1 byte at a time.
CMPQ SI, R14
JAE inlineExtendMatchEnd
MOVB (R15), AX
MOVB (SI), BX
CMPB AX, BX
JNE inlineExtendMatchEnd
ADDQ $1, R15
ADDQ $1, SI
JMP inlineExtendMatchCmp1
inlineExtendMatchEnd:
// End inline of the extendMatch call.
// ----------------------------------------
// ----------------------------------------
// Begin inline of the emitCopy call.
//
// d += emitCopy(dst[d:], base-candidate, s-base)
// !!! length := s - base
MOVQ SI, AX
SUBQ R12, AX
inlineEmitCopyLoop0:
// for length >= 68 { etc }
CMPL AX, $68
JLT inlineEmitCopyStep1
// Emit a length 64 copy, encoded as 3 bytes.
MOVB $0xfe, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $64, AX
JMP inlineEmitCopyLoop0
inlineEmitCopyStep1:
// if length > 64 { etc }
CMPL AX, $64
JLE inlineEmitCopyStep2
// Emit a length 60 copy, encoded as 3 bytes.
MOVB $0xee, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
SUBL $60, AX
inlineEmitCopyStep2:
// if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 }
CMPL AX, $12
JGE inlineEmitCopyStep3
CMPL R11, $2048
JGE inlineEmitCopyStep3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(DI)
SHRL $8, R11
SHLB $5, R11
SUBB $4, AX
SHLB $2, AX
ORB AX, R11
ORB $1, R11
MOVB R11, 0(DI)
ADDQ $2, DI
JMP inlineEmitCopyEnd
inlineEmitCopyStep3:
// Emit the remaining copy, encoded as 3 bytes.
SUBL $1, AX
SHLB $2, AX
ORB $2, AX
MOVB AX, 0(DI)
MOVW R11, 1(DI)
ADDQ $3, DI
inlineEmitCopyEnd:
// End inline of the emitCopy call.
// ----------------------------------------
// nextEmit = s
MOVQ SI, R10
// if s >= sLimit { goto emitRemainder }
MOVQ SI, AX
SUBQ DX, AX
CMPQ AX, R9
JAE emitRemainder
// As per the encode_other.go code:
//
// We could immediately etc.
// x := load64(src, s-1)
MOVQ -1(SI), R14
// prevHash := hash(uint32(x>>0), shift)
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// table[prevHash] = uint16(s-1)
MOVQ SI, AX
SUBQ DX, AX
SUBQ $1, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// currHash := hash(uint32(x>>8), shift)
SHRQ $8, R14
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// candidate = int(table[currHash])
// XXX: MOVWQZX table-32768(SP)(R11*2), R15
// XXX: 4e 0f b7 7c 5c 78 movzwq 0x78(%rsp,%r11,2),%r15
BYTE $0x4e
BYTE $0x0f
BYTE $0xb7
BYTE $0x7c
BYTE $0x5c
BYTE $0x78
// table[currHash] = uint16(s)
ADDQ $1, AX
// XXX: MOVW AX, table-32768(SP)(R11*2)
// XXX: 66 42 89 44 5c 78 mov %ax,0x78(%rsp,%r11,2)
BYTE $0x66
BYTE $0x42
BYTE $0x89
BYTE $0x44
BYTE $0x5c
BYTE $0x78
// if uint32(x>>8) == load32(src, candidate) { continue }
MOVL (DX)(R15*1), BX
CMPL R14, BX
JEQ inner1
// nextHash = hash(uint32(x>>16), shift)
SHRQ $8, R14
MOVL R14, R11
IMULL $0x1e35a7bd, R11
SHRL CX, R11
// s++
ADDQ $1, SI
// break out of the inner1 for loop, i.e. continue the outer loop.
JMP outer
emitRemainder:
// if nextEmit < len(src) { etc }
MOVQ src_len+32(FP), AX
ADDQ DX, AX
CMPQ R10, AX
JEQ encodeBlockEnd
// d += emitLiteral(dst[d:], src[nextEmit:])
//
// Push args.
MOVQ DI, 0(SP)
MOVQ $0, 8(SP) // Unnecessary, as the callee ignores it, but conservative.
MOVQ $0, 16(SP) // Unnecessary, as the callee ignores it, but conservative.
MOVQ R10, 24(SP)
SUBQ R10, AX
MOVQ AX, 32(SP)
MOVQ AX, 40(SP) // Unnecessary, as the callee ignores it, but conservative.
// Spill local variables (registers) onto the stack; call; unspill.
MOVQ DI, 80(SP)
CALL ·emitLiteral(SB)
MOVQ 80(SP), DI
// Finish the "d +=" part of "d += emitLiteral(etc)".
ADDQ 48(SP), DI
encodeBlockEnd:
MOVQ dst_base+0(FP), AX
SUBQ AX, DI
MOVQ DI, d+48(FP)
RET

@ -0,0 +1,722 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
#include "textflag.h"
// The asm code generally follows the pure Go code in encode_other.go, except
// where marked with a "!!!".
// ----------------------------------------------------------------------------
// func emitLiteral(dst, lit []byte) int
//
// All local variables fit into registers. The register allocation:
// - R3 len(lit)
// - R4 n
// - R6 return value
// - R8 &dst[i]
// - R10 &lit[0]
//
// The 32 bytes of stack space is to call runtime·memmove.
//
// The unusual register allocation of local variables, such as R10 for the
// source pointer, matches the allocation used at the call site in encodeBlock,
// which makes it easier to manually inline this function.
TEXT ·emitLiteral(SB), NOSPLIT, $32-56
MOVD dst_base+0(FP), R8
MOVD lit_base+24(FP), R10
MOVD lit_len+32(FP), R3
MOVD R3, R6
MOVW R3, R4
SUBW $1, R4, R4
CMPW $60, R4
BLT oneByte
CMPW $256, R4
BLT twoBytes
threeBytes:
MOVD $0xf4, R2
MOVB R2, 0(R8)
MOVW R4, 1(R8)
ADD $3, R8, R8
ADD $3, R6, R6
B memmove
twoBytes:
MOVD $0xf0, R2
MOVB R2, 0(R8)
MOVB R4, 1(R8)
ADD $2, R8, R8
ADD $2, R6, R6
B memmove
oneByte:
LSLW $2, R4, R4
MOVB R4, 0(R8)
ADD $1, R8, R8
ADD $1, R6, R6
memmove:
MOVD R6, ret+48(FP)
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// R8, R10 and R3 as arguments.
MOVD R8, 8(RSP)
MOVD R10, 16(RSP)
MOVD R3, 24(RSP)
CALL runtime·memmove(SB)
RET
// ----------------------------------------------------------------------------
// func emitCopy(dst []byte, offset, length int) int
//
// All local variables fit into registers. The register allocation:
// - R3 length
// - R7 &dst[0]
// - R8 &dst[i]
// - R11 offset
//
// The unusual register allocation of local variables, such as R11 for the
// offset, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·emitCopy(SB), NOSPLIT, $0-48
MOVD dst_base+0(FP), R8
MOVD R8, R7
MOVD offset+24(FP), R11
MOVD length+32(FP), R3
loop0:
// for length >= 68 { etc }
CMPW $68, R3
BLT step1
// Emit a length 64 copy, encoded as 3 bytes.
MOVD $0xfe, R2
MOVB R2, 0(R8)
MOVW R11, 1(R8)
ADD $3, R8, R8
SUB $64, R3, R3
B loop0
step1:
// if length > 64 { etc }
CMP $64, R3
BLE step2
// Emit a length 60 copy, encoded as 3 bytes.
MOVD $0xee, R2
MOVB R2, 0(R8)
MOVW R11, 1(R8)
ADD $3, R8, R8
SUB $60, R3, R3
step2:
// if length >= 12 || offset >= 2048 { goto step3 }
CMP $12, R3
BGE step3
CMPW $2048, R11
BGE step3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(R8)
LSRW $3, R11, R11
AND $0xe0, R11, R11
SUB $4, R3, R3
LSLW $2, R3
AND $0xff, R3, R3
ORRW R3, R11, R11
ORRW $1, R11, R11
MOVB R11, 0(R8)
ADD $2, R8, R8
// Return the number of bytes written.
SUB R7, R8, R8
MOVD R8, ret+40(FP)
RET
step3:
// Emit the remaining copy, encoded as 3 bytes.
SUB $1, R3, R3
AND $0xff, R3, R3
LSLW $2, R3, R3
ORRW $2, R3, R3
MOVB R3, 0(R8)
MOVW R11, 1(R8)
ADD $3, R8, R8
// Return the number of bytes written.
SUB R7, R8, R8
MOVD R8, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func extendMatch(src []byte, i, j int) int
//
// All local variables fit into registers. The register allocation:
// - R6 &src[0]
// - R7 &src[j]
// - R13 &src[len(src) - 8]
// - R14 &src[len(src)]
// - R15 &src[i]
//
// The unusual register allocation of local variables, such as R15 for a source
// pointer, matches the allocation used at the call site in encodeBlock, which
// makes it easier to manually inline this function.
TEXT ·extendMatch(SB), NOSPLIT, $0-48
MOVD src_base+0(FP), R6
MOVD src_len+8(FP), R14
MOVD i+24(FP), R15
MOVD j+32(FP), R7
ADD R6, R14, R14
ADD R6, R15, R15
ADD R6, R7, R7
MOVD R14, R13
SUB $8, R13, R13
cmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMP R13, R7
BHI cmp1
MOVD (R15), R3
MOVD (R7), R4
CMP R4, R3
BNE bsf
ADD $8, R15, R15
ADD $8, R7, R7
B cmp8
bsf:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs.
// RBIT reverses the bit order, then CLZ counts the leading zeros, the
// combination of which finds the least significant bit which is set.
// The arm64 architecture is little-endian, and the shift by 3 converts
// a bit index to a byte index.
EOR R3, R4, R4
RBIT R4, R4
CLZ R4, R4
ADD R4>>3, R7, R7
// Convert from &src[ret] to ret.
SUB R6, R7, R7
MOVD R7, ret+40(FP)
RET
cmp1:
// In src's tail, compare 1 byte at a time.
CMP R7, R14
BLS extendMatchEnd
MOVB (R15), R3
MOVB (R7), R4
CMP R4, R3
BNE extendMatchEnd
ADD $1, R15, R15
ADD $1, R7, R7
B cmp1
extendMatchEnd:
// Convert from &src[ret] to ret.
SUB R6, R7, R7
MOVD R7, ret+40(FP)
RET
// ----------------------------------------------------------------------------
// func encodeBlock(dst, src []byte) (d int)
//
// All local variables fit into registers, other than "var table". The register
// allocation:
// - R3 . .
// - R4 . .
// - R5 64 shift
// - R6 72 &src[0], tableSize
// - R7 80 &src[s]
// - R8 88 &dst[d]
// - R9 96 sLimit
// - R10 . &src[nextEmit]
// - R11 104 prevHash, currHash, nextHash, offset
// - R12 112 &src[base], skip
// - R13 . &src[nextS], &src[len(src) - 8]
// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x
// - R15 120 candidate
// - R16 . hash constant, 0x1e35a7bd
// - R17 . &table
// - . 128 table
//
// The second column (64, 72, etc) is the stack offset to spill the registers
// when calling other functions. We could pack this slightly tighter, but it's
// simpler to have a dedicated spill map independent of the function called.
//
// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An
// extra 64 bytes, to call other functions, and an extra 64 bytes, to spill
// local variables (registers) during calls gives 32768 + 64 + 64 = 32896.
TEXT ·encodeBlock(SB), 0, $32896-56
MOVD dst_base+0(FP), R8
MOVD src_base+24(FP), R7
MOVD src_len+32(FP), R14
// shift, tableSize := uint32(32-8), 1<<8
MOVD $24, R5
MOVD $256, R6
MOVW $0xa7bd, R16
MOVKW $(0x1e35<<16), R16
calcShift:
// for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
// shift--
// }
MOVD $16384, R2
CMP R2, R6
BGE varTable
CMP R14, R6
BGE varTable
SUB $1, R5, R5
LSL $1, R6, R6
B calcShift
varTable:
// var table [maxTableSize]uint16
//
// In the asm code, unlike the Go code, we can zero-initialize only the
// first tableSize elements. Each uint16 element is 2 bytes and each
// iterations writes 64 bytes, so we can do only tableSize/32 writes
// instead of the 2048 writes that would zero-initialize all of table's
// 32768 bytes. This clear could overrun the first tableSize elements, but
// it won't overrun the allocated stack size.
ADD $128, RSP, R17
MOVD R17, R4
// !!! R6 = &src[tableSize]
ADD R6<<1, R17, R6
memclr:
STP.P (ZR, ZR), 64(R4)
STP (ZR, ZR), -48(R4)
STP (ZR, ZR), -32(R4)
STP (ZR, ZR), -16(R4)
CMP R4, R6
BHI memclr
// !!! R6 = &src[0]
MOVD R7, R6
// sLimit := len(src) - inputMargin
MOVD R14, R9
SUB $15, R9, R9
// !!! Pre-emptively spill R5, R6 and R9 to the stack. Their values don't
// change for the rest of the function.
MOVD R5, 64(RSP)
MOVD R6, 72(RSP)
MOVD R9, 96(RSP)
// nextEmit := 0
MOVD R6, R10
// s := 1
ADD $1, R7, R7
// nextHash := hash(load32(src, s), shift)
MOVW 0(R7), R11
MULW R16, R11, R11
LSRW R5, R11, R11
outer:
// for { etc }
// skip := 32
MOVD $32, R12
// nextS := s
MOVD R7, R13
// candidate := 0
MOVD $0, R15
inner0:
// for { etc }
// s := nextS
MOVD R13, R7
// bytesBetweenHashLookups := skip >> 5
MOVD R12, R14
LSR $5, R14, R14
// nextS = s + bytesBetweenHashLookups
ADD R14, R13, R13
// skip += bytesBetweenHashLookups
ADD R14, R12, R12
// if nextS > sLimit { goto emitRemainder }
MOVD R13, R3
SUB R6, R3, R3
CMP R9, R3
BHI emitRemainder
// candidate = int(table[nextHash])
MOVHU 0(R17)(R11<<1), R15
// table[nextHash] = uint16(s)
MOVD R7, R3
SUB R6, R3, R3
MOVH R3, 0(R17)(R11<<1)
// nextHash = hash(load32(src, nextS), shift)
MOVW 0(R13), R11
MULW R16, R11
LSRW R5, R11, R11
// if load32(src, s) != load32(src, candidate) { continue } break
MOVW 0(R7), R3
MOVW (R6)(R15), R4
CMPW R4, R3
BNE inner0
fourByteMatch:
// As per the encode_other.go code:
//
// A 4-byte match has been found. We'll later see etc.
// !!! Jump to a fast path for short (<= 16 byte) literals. See the comment
// on inputMargin in encode.go.
MOVD R7, R3
SUB R10, R3, R3
CMP $16, R3
BLE emitLiteralFastPath
// ----------------------------------------
// Begin inline of the emitLiteral call.
//
// d += emitLiteral(dst[d:], src[nextEmit:s])
MOVW R3, R4
SUBW $1, R4, R4
MOVW $60, R2
CMPW R2, R4
BLT inlineEmitLiteralOneByte
MOVW $256, R2
CMPW R2, R4
BLT inlineEmitLiteralTwoBytes
inlineEmitLiteralThreeBytes:
MOVD $0xf4, R1
MOVB R1, 0(R8)
MOVW R4, 1(R8)
ADD $3, R8, R8
B inlineEmitLiteralMemmove
inlineEmitLiteralTwoBytes:
MOVD $0xf0, R1
MOVB R1, 0(R8)
MOVB R4, 1(R8)
ADD $2, R8, R8
B inlineEmitLiteralMemmove
inlineEmitLiteralOneByte:
LSLW $2, R4, R4
MOVB R4, 0(R8)
ADD $1, R8, R8
inlineEmitLiteralMemmove:
// Spill local variables (registers) onto the stack; call; unspill.
//
// copy(dst[i:], lit)
//
// This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push
// R8, R10 and R3 as arguments.
MOVD R8, 8(RSP)
MOVD R10, 16(RSP)
MOVD R3, 24(RSP)
// Finish the "d +=" part of "d += emitLiteral(etc)".
ADD R3, R8, R8
MOVD R7, 80(RSP)
MOVD R8, 88(RSP)
MOVD R15, 120(RSP)
CALL runtime·memmove(SB)
MOVD 64(RSP), R5
MOVD 72(RSP), R6
MOVD 80(RSP), R7
MOVD 88(RSP), R8
MOVD 96(RSP), R9
MOVD 120(RSP), R15
ADD $128, RSP, R17
MOVW $0xa7bd, R16
MOVKW $(0x1e35<<16), R16
B inner1
inlineEmitLiteralEnd:
// End inline of the emitLiteral call.
// ----------------------------------------
emitLiteralFastPath:
// !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2".
MOVB R3, R4
SUBW $1, R4, R4
AND $0xff, R4, R4
LSLW $2, R4, R4
MOVB R4, (R8)
ADD $1, R8, R8
// !!! Implement the copy from lit to dst as a 16-byte load and store.
// (Encode's documentation says that dst and src must not overlap.)
//
// This always copies 16 bytes, instead of only len(lit) bytes, but that's
// OK. Subsequent iterations will fix up the overrun.
//
// Note that on arm64, it is legal and cheap to issue unaligned 8-byte or
// 16-byte loads and stores. This technique probably wouldn't be as
// effective on architectures that are fussier about alignment.
LDP 0(R10), (R0, R1)
STP (R0, R1), 0(R8)
ADD R3, R8, R8
inner1:
// for { etc }
// base := s
MOVD R7, R12
// !!! offset := base - candidate
MOVD R12, R11
SUB R15, R11, R11
SUB R6, R11, R11
// ----------------------------------------
// Begin inline of the extendMatch call.
//
// s = extendMatch(src, candidate+4, s+4)
// !!! R14 = &src[len(src)]
MOVD src_len+32(FP), R14
ADD R6, R14, R14
// !!! R13 = &src[len(src) - 8]
MOVD R14, R13
SUB $8, R13, R13
// !!! R15 = &src[candidate + 4]
ADD $4, R15, R15
ADD R6, R15, R15
// !!! s += 4
ADD $4, R7, R7
inlineExtendMatchCmp8:
// As long as we are 8 or more bytes before the end of src, we can load and
// compare 8 bytes at a time. If those 8 bytes are equal, repeat.
CMP R13, R7
BHI inlineExtendMatchCmp1
MOVD (R15), R3
MOVD (R7), R4
CMP R4, R3
BNE inlineExtendMatchBSF
ADD $8, R15, R15
ADD $8, R7, R7
B inlineExtendMatchCmp8
inlineExtendMatchBSF:
// If those 8 bytes were not equal, XOR the two 8 byte values, and return
// the index of the first byte that differs.
// RBIT reverses the bit order, then CLZ counts the leading zeros, the
// combination of which finds the least significant bit which is set.
// The arm64 architecture is little-endian, and the shift by 3 converts
// a bit index to a byte index.
EOR R3, R4, R4
RBIT R4, R4
CLZ R4, R4
ADD R4>>3, R7, R7
B inlineExtendMatchEnd
inlineExtendMatchCmp1:
// In src's tail, compare 1 byte at a time.
CMP R7, R14
BLS inlineExtendMatchEnd
MOVB (R15), R3
MOVB (R7), R4
CMP R4, R3
BNE inlineExtendMatchEnd
ADD $1, R15, R15
ADD $1, R7, R7
B inlineExtendMatchCmp1
inlineExtendMatchEnd:
// End inline of the extendMatch call.
// ----------------------------------------
// ----------------------------------------
// Begin inline of the emitCopy call.
//
// d += emitCopy(dst[d:], base-candidate, s-base)
// !!! length := s - base
MOVD R7, R3
SUB R12, R3, R3
inlineEmitCopyLoop0:
// for length >= 68 { etc }
MOVW $68, R2
CMPW R2, R3
BLT inlineEmitCopyStep1
// Emit a length 64 copy, encoded as 3 bytes.
MOVD $0xfe, R1
MOVB R1, 0(R8)
MOVW R11, 1(R8)
ADD $3, R8, R8
SUBW $64, R3, R3
B inlineEmitCopyLoop0
inlineEmitCopyStep1:
// if length > 64 { etc }
MOVW $64, R2
CMPW R2, R3
BLE inlineEmitCopyStep2
// Emit a length 60 copy, encoded as 3 bytes.
MOVD $0xee, R1
MOVB R1, 0(R8)
MOVW R11, 1(R8)
ADD $3, R8, R8
SUBW $60, R3, R3
inlineEmitCopyStep2:
// if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 }
MOVW $12, R2
CMPW R2, R3
BGE inlineEmitCopyStep3
MOVW $2048, R2
CMPW R2, R11
BGE inlineEmitCopyStep3
// Emit the remaining copy, encoded as 2 bytes.
MOVB R11, 1(R8)
LSRW $8, R11, R11
LSLW $5, R11, R11
SUBW $4, R3, R3
AND $0xff, R3, R3
LSLW $2, R3, R3
ORRW R3, R11, R11
ORRW $1, R11, R11
MOVB R11, 0(R8)
ADD $2, R8, R8
B inlineEmitCopyEnd
inlineEmitCopyStep3:
// Emit the remaining copy, encoded as 3 bytes.
SUBW $1, R3, R3
LSLW $2, R3, R3
ORRW $2, R3, R3
MOVB R3, 0(R8)
MOVW R11, 1(R8)
ADD $3, R8, R8
inlineEmitCopyEnd:
// End inline of the emitCopy call.
// ----------------------------------------
// nextEmit = s
MOVD R7, R10
// if s >= sLimit { goto emitRemainder }
MOVD R7, R3
SUB R6, R3, R3
CMP R3, R9
BLS emitRemainder
// As per the encode_other.go code:
//
// We could immediately etc.
// x := load64(src, s-1)
MOVD -1(R7), R14
// prevHash := hash(uint32(x>>0), shift)
MOVW R14, R11
MULW R16, R11, R11
LSRW R5, R11, R11
// table[prevHash] = uint16(s-1)
MOVD R7, R3
SUB R6, R3, R3
SUB $1, R3, R3
MOVHU R3, 0(R17)(R11<<1)
// currHash := hash(uint32(x>>8), shift)
LSR $8, R14, R14
MOVW R14, R11
MULW R16, R11, R11
LSRW R5, R11, R11
// candidate = int(table[currHash])
MOVHU 0(R17)(R11<<1), R15
// table[currHash] = uint16(s)
ADD $1, R3, R3
MOVHU R3, 0(R17)(R11<<1)
// if uint32(x>>8) == load32(src, candidate) { continue }
MOVW (R6)(R15), R4
CMPW R4, R14
BEQ inner1
// nextHash = hash(uint32(x>>16), shift)
LSR $8, R14, R14
MOVW R14, R11
MULW R16, R11, R11
LSRW R5, R11, R11
// s++
ADD $1, R7, R7
// break out of the inner1 for loop, i.e. continue the outer loop.
B outer
emitRemainder:
// if nextEmit < len(src) { etc }
MOVD src_len+32(FP), R3
ADD R6, R3, R3
CMP R3, R10
BEQ encodeBlockEnd
// d += emitLiteral(dst[d:], src[nextEmit:])
//
// Push args.
MOVD R8, 8(RSP)
MOVD $0, 16(RSP) // Unnecessary, as the callee ignores it, but conservative.
MOVD $0, 24(RSP) // Unnecessary, as the callee ignores it, but conservative.
MOVD R10, 32(RSP)
SUB R10, R3, R3
MOVD R3, 40(RSP)
MOVD R3, 48(RSP) // Unnecessary, as the callee ignores it, but conservative.
// Spill local variables (registers) onto the stack; call; unspill.
MOVD R8, 88(RSP)
CALL ·emitLiteral(SB)
MOVD 88(RSP), R8
// Finish the "d +=" part of "d += emitLiteral(etc)".
MOVD 56(RSP), R1
ADD R1, R8, R8
encodeBlockEnd:
MOVD dst_base+0(FP), R3
SUB R3, R8, R8
MOVD R8, d+48(FP)
RET

@ -0,0 +1,30 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// +build gc
// +build !noasm
// +build amd64 arm64
package snappy
// emitLiteral has the same semantics as in encode_other.go.
//
//go:noescape
func emitLiteral(dst, lit []byte) int
// emitCopy has the same semantics as in encode_other.go.
//
//go:noescape
func emitCopy(dst []byte, offset, length int) int
// extendMatch has the same semantics as in encode_other.go.
//
//go:noescape
func extendMatch(src []byte, i, j int) int
// encodeBlock has the same semantics as in encode_other.go.
//
//go:noescape
func encodeBlock(dst, src []byte) (d int)

@ -0,0 +1,238 @@
// Copyright 2016 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !amd64,!arm64 appengine !gc noasm
package snappy
func load32(b []byte, i int) uint32 {
b = b[i : i+4 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}
func load64(b []byte, i int) uint64 {
b = b[i : i+8 : len(b)] // Help the compiler eliminate bounds checks on the next line.
return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 |
uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56
}
// emitLiteral writes a literal chunk and returns the number of bytes written.
//
// It assumes that:
// dst is long enough to hold the encoded bytes
// 1 <= len(lit) && len(lit) <= 65536
func emitLiteral(dst, lit []byte) int {
i, n := 0, uint(len(lit)-1)
switch {
case n < 60:
dst[0] = uint8(n)<<2 | tagLiteral
i = 1
case n < 1<<8:
dst[0] = 60<<2 | tagLiteral
dst[1] = uint8(n)
i = 2
default:
dst[0] = 61<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
i = 3
}
return i + copy(dst[i:], lit)
}
// emitCopy writes a copy chunk and returns the number of bytes written.
//
// It assumes that:
// dst is long enough to hold the encoded bytes
// 1 <= offset && offset <= 65535
// 4 <= length && length <= 65535
func emitCopy(dst []byte, offset, length int) int {
i := 0
// The maximum length for a single tagCopy1 or tagCopy2 op is 64 bytes. The
// threshold for this loop is a little higher (at 68 = 64 + 4), and the
// length emitted down below is is a little lower (at 60 = 64 - 4), because
// it's shorter to encode a length 67 copy as a length 60 tagCopy2 followed
// by a length 7 tagCopy1 (which encodes as 3+2 bytes) than to encode it as
// a length 64 tagCopy2 followed by a length 3 tagCopy2 (which encodes as
// 3+3 bytes). The magic 4 in the 64±4 is because the minimum length for a
// tagCopy1 op is 4 bytes, which is why a length 3 copy has to be an
// encodes-as-3-bytes tagCopy2 instead of an encodes-as-2-bytes tagCopy1.
for length >= 68 {
// Emit a length 64 copy, encoded as 3 bytes.
dst[i+0] = 63<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
i += 3
length -= 64
}
if length > 64 {
// Emit a length 60 copy, encoded as 3 bytes.
dst[i+0] = 59<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
i += 3
length -= 60
}
if length >= 12 || offset >= 2048 {
// Emit the remaining copy, encoded as 3 bytes.
dst[i+0] = uint8(length-1)<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
return i + 3
}
// Emit the remaining copy, encoded as 2 bytes.
dst[i+0] = uint8(offset>>8)<<5 | uint8(length-4)<<2 | tagCopy1
dst[i+1] = uint8(offset)
return i + 2
}
// extendMatch returns the largest k such that k <= len(src) and that
// src[i:i+k-j] and src[j:k] have the same contents.
//
// It assumes that:
// 0 <= i && i < j && j <= len(src)
func extendMatch(src []byte, i, j int) int {
for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 {
}
return j
}
func hash(u, shift uint32) uint32 {
return (u * 0x1e35a7bd) >> shift
}
// encodeBlock encodes a non-empty src to a guaranteed-large-enough dst. It
// assumes that the varint-encoded length of the decompressed bytes has already
// been written.
//
// It also assumes that:
// len(dst) >= MaxEncodedLen(len(src)) &&
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlock(dst, src []byte) (d int) {
// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
// The table element type is uint16, as s < sLimit and sLimit < len(src)
// and len(src) <= maxBlockSize and maxBlockSize == 65536.
const (
maxTableSize = 1 << 14
// tableMask is redundant, but helps the compiler eliminate bounds
// checks.
tableMask = maxTableSize - 1
)
shift := uint32(32 - 8)
for tableSize := 1 << 8; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 {
shift--
}
// In Go, all array elements are zero-initialized, so there is no advantage
// to a smaller tableSize per se. However, it matches the C++ algorithm,
// and in the asm versions of this code, we can get away with zeroing only
// the first tableSize elements.
var table [maxTableSize]uint16
// sLimit is when to stop looking for offset/length copies. The inputMargin
// lets us use a fast path for emitLiteral in the main loop, while we are
// looking for copies.
sLimit := len(src) - inputMargin
// nextEmit is where in src the next emitLiteral should start from.
nextEmit := 0
// The encoded form must start with a literal, as there are no previous
// bytes to copy, so we start looking for hash matches at s == 1.
s := 1
nextHash := hash(load32(src, s), shift)
for {
// Copied from the C++ snappy implementation:
//
// Heuristic match skipping: If 32 bytes are scanned with no matches
// found, start looking only at every other byte. If 32 more bytes are
// scanned (or skipped), look at every third byte, etc.. When a match
// is found, immediately go back to looking at every byte. This is a
// small loss (~5% performance, ~0.1% density) for compressible data
// due to more bookkeeping, but for non-compressible data (such as
// JPEG) it's a huge win since the compressor quickly "realizes" the
// data is incompressible and doesn't bother looking for matches
// everywhere.
//
// The "skip" variable keeps track of how many bytes there are since
// the last match; dividing it by 32 (ie. right-shifting by five) gives
// the number of bytes to move ahead for each iteration.
skip := 32
nextS := s
candidate := 0
for {
s = nextS
bytesBetweenHashLookups := skip >> 5
nextS = s + bytesBetweenHashLookups
skip += bytesBetweenHashLookups
if nextS > sLimit {
goto emitRemainder
}
candidate = int(table[nextHash&tableMask])
table[nextHash&tableMask] = uint16(s)
nextHash = hash(load32(src, nextS), shift)
if load32(src, s) == load32(src, candidate) {
break
}
}
// A 4-byte match has been found. We'll later see if more than 4 bytes
// match. But, prior to the match, src[nextEmit:s] are unmatched. Emit
// them as literal bytes.
d += emitLiteral(dst[d:], src[nextEmit:s])
// Call emitCopy, and then see if another emitCopy could be our next
// move. Repeat until we find no match for the input immediately after
// what was consumed by the last emitCopy call.
//
// If we exit this loop normally then we need to call emitLiteral next,
// though we don't yet know how big the literal will be. We handle that
// by proceeding to the next iteration of the main loop. We also can
// exit this loop via goto if we get close to exhausting the input.
for {
// Invariant: we have a 4-byte match at s, and no need to emit any
// literal bytes prior to s.
base := s
// Extend the 4-byte match as long as possible.
//
// This is an inlined version of:
// s = extendMatch(src, candidate+4, s+4)
s += 4
for i := candidate + 4; s < len(src) && src[i] == src[s]; i, s = i+1, s+1 {
}
d += emitCopy(dst[d:], base-candidate, s-base)
nextEmit = s
if s >= sLimit {
goto emitRemainder
}
// We could immediately start working at s now, but to improve
// compression we first update the hash table at s-1 and at s. If
// another emitCopy is not our next move, also calculate nextHash
// at s+1. At least on GOARCH=amd64, these three hash calculations
// are faster as one load64 call (with some shifts) instead of
// three load32 calls.
x := load64(src, s-1)
prevHash := hash(uint32(x>>0), shift)
table[prevHash&tableMask] = uint16(s - 1)
currHash := hash(uint32(x>>8), shift)
candidate = int(table[currHash&tableMask])
table[currHash&tableMask] = uint16(s)
if uint32(x>>8) != load32(src, candidate) {
nextHash = hash(uint32(x>>16), shift)
s++
break
}
}
}
emitRemainder:
if nextEmit < len(src) {
d += emitLiteral(dst[d:], src[nextEmit:])
}
return d
}

@ -0,0 +1,98 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package snappy implements the Snappy compression format. It aims for very
// high speeds and reasonable compression.
//
// There are actually two Snappy formats: block and stream. They are related,
// but different: trying to decompress block-compressed data as a Snappy stream
// will fail, and vice versa. The block format is the Decode and Encode
// functions and the stream format is the Reader and Writer types.
//
// The block format, the more common case, is used when the complete size (the
// number of bytes) of the original data is known upfront, at the time
// compression starts. The stream format, also known as the framing format, is
// for when that isn't always true.
//
// The canonical, C++ implementation is at https://github.com/google/snappy and
// it only implements the block format.
package snappy // import "github.com/golang/snappy"
import (
"hash/crc32"
)
/*
Each encoded block begins with the varint-encoded length of the decoded data,
followed by a sequence of chunks. Chunks begin and end on byte boundaries. The
first byte of each chunk is broken into its 2 least and 6 most significant bits
called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag.
Zero means a literal tag. All other values mean a copy tag.
For literal tags:
- If m < 60, the next 1 + m bytes are literal bytes.
- Otherwise, let n be the little-endian unsigned integer denoted by the next
m - 59 bytes. The next 1 + n bytes after that are literal bytes.
For copy tags, length bytes are copied from offset bytes ago, in the style of
Lempel-Ziv compression algorithms. In particular:
- For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12).
The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10
of the offset. The next byte is bits 0-7 of the offset.
- For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65).
The length is 1 + m. The offset is the little-endian unsigned integer
denoted by the next 2 bytes.
- For l == 3, this tag is a legacy format that is no longer issued by most
encoders. Nonetheless, the offset ranges in [0, 1<<32) and the length in
[1, 65). The length is 1 + m. The offset is the little-endian unsigned
integer denoted by the next 4 bytes.
*/
const (
tagLiteral = 0x00
tagCopy1 = 0x01
tagCopy2 = 0x02
tagCopy4 = 0x03
)
const (
checksumSize = 4
chunkHeaderSize = 4
magicChunk = "\xff\x06\x00\x00" + magicBody
magicBody = "sNaPpY"
// maxBlockSize is the maximum size of the input to encodeBlock. It is not
// part of the wire format per se, but some parts of the encoder assume
// that an offset fits into a uint16.
//
// Also, for the framing format (Writer type instead of Encode function),
// https://github.com/google/snappy/blob/master/framing_format.txt says
// that "the uncompressed data in a chunk must be no longer than 65536
// bytes".
maxBlockSize = 65536
// maxEncodedLenOfMaxBlockSize equals MaxEncodedLen(maxBlockSize), but is
// hard coded to be a const instead of a variable, so that obufLen can also
// be a const. Their equivalence is confirmed by
// TestMaxEncodedLenOfMaxBlockSize.
maxEncodedLenOfMaxBlockSize = 76490
obufHeaderLen = len(magicChunk) + checksumSize + chunkHeaderSize
obufLen = obufHeaderLen + maxEncodedLenOfMaxBlockSize
)
const (
chunkTypeCompressedData = 0x00
chunkTypeUncompressedData = 0x01
chunkTypePadding = 0xfe
chunkTypeStreamIdentifier = 0xff
)
var crcTable = crc32.MakeTable(crc32.Castagnoli)
// crc implements the checksum specified in section 3 of
// https://github.com/google/snappy/blob/master/framing_format.txt
func crc(b []byte) uint32 {
c := crc32.Update(0, crcTable, b)
return uint32(c>>15|c<<17) + 0xa282ead8
}

@ -0,0 +1,9 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

@ -0,0 +1,22 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,8 @@
[![](https://godoc.org/github.com/jackc/chunkreader?status.svg)](https://godoc.org/github.com/jackc/chunkreader)
[![Build Status](https://travis-ci.org/jackc/chunkreader.svg)](https://travis-ci.org/jackc/chunkreader)
# chunkreader
Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
Extracted from original implementation in https://github.com/jackc/pgx.

@ -0,0 +1,104 @@
// Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
package chunkreader
import (
"io"
)
// ChunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy.
//
// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is
// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare
// cases it would be advantageous to copy the bytes to another slice.
type ChunkReader struct {
r io.Reader
buf []byte
rp, wp int // buf read position and write position
config Config
}
// Config contains configuration parameters for ChunkReader.
type Config struct {
MinBufLen int // Minimum buffer length
}
// New creates and returns a new ChunkReader for r with default configuration.
func New(r io.Reader) *ChunkReader {
cr, err := NewConfig(r, Config{})
if err != nil {
panic("default config can't be bad")
}
return cr
}
// NewConfig creates and a new ChunkReader for r configured by config.
func NewConfig(r io.Reader, config Config) (*ChunkReader, error) {
if config.MinBufLen == 0 {
// By historical reasons Postgres currently has 8KB send buffer inside,
// so here we want to have at least the same size buffer.
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
config.MinBufLen = 8192
}
return &ChunkReader{
r: r,
buf: make([]byte, config.MinBufLen),
config: config,
}, nil
}
// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy
// of buf. If an error occurs, buf will be nil.
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
// n bytes already in buf
if (r.wp - r.rp) >= n {
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, err
}
// available space in buf is less than n
if len(r.buf) < n {
r.copyBufContents(r.newBuf(n))
}
// buf is large enough, but need to shift filled area to start to make enough contiguous space
minReadCount := n - (r.wp - r.rp)
if (len(r.buf) - r.wp) < minReadCount {
newBuf := r.newBuf(n)
r.copyBufContents(newBuf)
}
if err := r.appendAtLeast(minReadCount); err != nil {
return nil, err
}
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, nil
}
func (r *ChunkReader) appendAtLeast(fillLen int) error {
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
r.wp += n
return err
}
func (r *ChunkReader) newBuf(size int) []byte {
if size < r.config.MinBufLen {
size = r.config.MinBufLen
}
return make([]byte, size)
}
func (r *ChunkReader) copyBufContents(dest []byte) {
r.wp = copy(dest, r.buf[r.rp:r.wp])
r.rp = 0
r.buf = dest
}

@ -0,0 +1,3 @@
.envrc
vendor/
.vscode

@ -0,0 +1,148 @@
# 1.12.1 (May 7, 2022)
* Fix: setting krbspn and krbsrvname in connection string (sireax)
* Add support for Unix sockets on Windows (Eno Compton)
* Stop ignoring ErrorResponse during SCRAM auth (Rafi Shamim)
# 1.12.0 (April 21, 2022)
* Add pluggable GSSAPI support (Oliver Tan)
* Fix: Consider any "0A000" error a possible cached plan changed error due to locale
* Better match psql fallback behavior with multiple hosts
# 1.11.0 (February 7, 2022)
* Support port in ip from LookupFunc to override config (James Hartig)
* Fix TLS connection timeout (Blake Embrey)
* Add support for read-only, primary, standby, prefer-standby target_session_attributes (Oscar)
* Fix connect when receiving NoticeResponse
# 1.10.1 (November 20, 2021)
* Close without waiting for response (Kei Kamikawa)
* Save waiting for network round-trip in CopyFrom (Rueian)
* Fix concurrency issue with ContextWatcher
* LRU.Get always checks context for cancellation / expiration (Georges Varouchas)
# 1.10.0 (July 24, 2021)
* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned.
# 1.9.0 (July 10, 2021)
* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr)
* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle)
* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard)
* Fix default host when parsing URL without host but with port
* Allow dbname query parameter in URL conn string
* Update underlying dependencies
# 1.8.1 (March 25, 2021)
* Better connection string sanitization (ip.novikov)
* Use proper pgpass location on Windows (Moshe Katz)
* Use errors instead of golang.org/x/xerrors
* Resume fallback on server error in Connect (Andrey Borodin)
# 1.8.0 (December 3, 2020)
* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes)
# 1.7.2 (November 3, 2020)
* Fix data value slices into work buffer with capacities larger than length.
# 1.7.1 (October 31, 2020)
* Do not asyncClose after receiving FATAL error from PostgreSQL server
# 1.7.0 (September 26, 2020)
* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
* Add ReceiveResults (Sebastiaan Mannem)
* Fix parsing DSN connection with bad backslash
* Add PgConn.CleanupDone so connection pools can determine when async close is complete
# 1.6.4 (July 29, 2020)
* Fix deadlock on error after CommandComplete but before ReadyForQuery
* Fix panic on parsing DSN with trailing '='
# 1.6.3 (July 22, 2020)
* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo)
# 1.6.2 (July 14, 2020)
* Update pgservicefile library
# 1.6.1 (June 27, 2020)
* Update golang.org/x/crypto to latest
* Update golang.org/x/text to 0.3.3
* Fix error handling for bad PGSERVICE definition
* Redact passwords in ParseConfig errors (Lukas Vogel)
# 1.6.0 (June 6, 2020)
* Fix panic when closing conn during cancellable query
* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný)
* Fix field descriptions available after command concluded (Tobias Salzmann)
* Support connect_timeout (georgysavva)
* Handle IPv6 in connection URLs (Lukas Vogel)
* Fix ValidateConnect with cancelable context
* Improve CopyFrom performance
* Add Config.Copy (georgysavva)
# 1.5.0 (March 30, 2020)
* Update golang.org/x/crypto for security fix
* Implement "verify-ca" SSL mode (Greg Curtis)
# 1.4.0 (March 7, 2020)
* Fix ExecParams and ExecPrepared handling of empty query.
* Support reading config from PostgreSQL service files.
# 1.3.2 (February 14, 2020)
* Update chunkreader to v2.0.1 for optimized default buffer size.
# 1.3.1 (February 5, 2020)
* Fix CopyFrom deadlock when multiple NoticeResponse received during copy
# 1.3.0 (January 23, 2020)
* Add Hijack and Construct.
* Update pgproto3 to v2.0.1.
# 1.2.1 (January 13, 2020)
* Fix data race in context cancellation introduced in v1.2.0.
# 1.2.0 (January 11, 2020)
## Features
* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag.
* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases.
## Performance
* Improve performance when context.Background() is used. (bakape)
* CommandTag.RowsAffected is faster and does not allocate.
## Fixes
* Try to cancel any in-progress query when a conn is closed by ctx cancel.
* Handle NoticeResponse during CopyFrom.
* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish.
# 1.1.0 (October 12, 2019)
* Add PgConn.IsBusy() method.
# 1.0.1 (September 19, 2019)
* Fix statement cache not properly cleaning discarded statements.

@ -0,0 +1,22 @@
Copyright (c) 2019-2021 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,56 @@
[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn)
![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg)
# pgconn
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
low-level access to PostgreSQL functionality.
## Example Usage
```go
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatalln("pgconn failed to connect:", err)
}
defer pgConn.Close(context.Background())
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
for result.NextRow() {
fmt.Println("User 123 has email:", string(result.Values()[0]))
}
_, err = result.Close()
if err != nil {
log.Fatalln("failed reading result:", err)
}
```
## Testing
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
environment variable handling.
### Example Test Environment
Connect to your PostgreSQL server and run:
```
create database pgx_test;
```
Now you can run the tests:
```bash
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
```
### Connection and Authentication Tests
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
authentication code.

@ -0,0 +1,270 @@
// SCRAM-SHA-256 authentication
//
// Resources:
// https://tools.ietf.org/html/rfc5802
// https://tools.ietf.org/html/rfc8265
// https://www.postgresql.org/docs/current/sasl-authentication.html
//
// Inspiration drawn from other implementations:
// https://github.com/lib/pq/pull/608
// https://github.com/lib/pq/pull/788
// https://github.com/lib/pq/pull/833
package pgconn
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strconv"
"github.com/jackc/pgproto3/v2"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/text/secure/precis"
)
const clientNonceLen = 18
// Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil {
return err
}
// Send client-first-message in a SASLInitialResponse
saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-first-message payload in a AuthenticationSASLContinue.
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
}
err = sc.recvServerFirstMessage(saslContinue.Data)
if err != nil {
return err
}
// Send client-final-message in a SASLResponse
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
_, err = c.conn.Write(saslResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-final-message payload in a AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err
}
return sc.recvServerFinalMessage(saslFinal.Data)
}
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLFinal:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
}
type scramClient struct {
serverAuthMechanisms []string
password []byte
clientNonce []byte
clientFirstMessageBare []byte
serverFirstMessage []byte
clientAndServerNonce []byte
salt []byte
iterations int
saltedPassword []byte
authMessage []byte
}
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
sc := &scramClient{
serverAuthMechanisms: serverAuthMechanisms,
}
// Ensure server supports SCRAM-SHA-256
hasScramSHA256 := false
for _, mech := range sc.serverAuthMechanisms {
if mech == "SCRAM-SHA-256" {
hasScramSHA256 = true
break
}
}
if !hasScramSHA256 {
return nil, errors.New("server does not support SCRAM-SHA-256")
}
// precis.OpaqueString is equivalent to SASLprep for password.
var err error
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
if err != nil {
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
sc.password = []byte(password)
}
buf := make([]byte, clientNonceLen)
_, err = rand.Read(buf)
if err != nil {
return nil, err
}
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
return sc, nil
}
func (sc *scramClient) clientFirstMessage() []byte {
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
}
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
sc.serverFirstMessage = serverFirstMessage
buf := serverFirstMessage
if !bytes.HasPrefix(buf, []byte("r=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
}
buf = buf[2:]
idx := bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
sc.clientAndServerNonce = buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("s=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
buf = buf[2:]
idx = bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
saltStr := buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("i=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
buf = buf[2:]
iterationsStr := buf
var err error
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
if err != nil {
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
}
sc.iterations, err = strconv.Atoi(string(iterationsStr))
if err != nil || sc.iterations <= 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
}
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not start with client nonce")
}
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not include server nonce")
}
return nil
}
func (sc *scramClient) clientFinalMessage() string {
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
}
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
return errors.New("invalid SCRAM server-final-message received from server")
}
serverSignature := serverFinalMessage[2:]
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
return errors.New("invalid SCRAM ServerSignature received from server")
}
return nil
}
func computeHMAC(key, msg []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(msg)
return mac.Sum(nil)
}
func computeClientProof(saltedPassword, authMessage []byte) []byte {
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
storedKey := sha256.Sum256(clientKey)
clientSignature := computeHMAC(storedKey[:], authMessage)
clientProof := make([]byte, len(clientSignature))
for i := 0; i < len(clientSignature); i++ {
clientProof[i] = clientKey[i] ^ clientSignature[i]
}
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
base64.StdEncoding.Encode(buf, clientProof)
return buf
}
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
serverSignature := computeHMAC(serverKey, authMessage)
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
base64.StdEncoding.Encode(buf, serverSignature)
return buf
}

@ -0,0 +1,812 @@
package pgconn
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/jackc/chunkreader/v2"
"github.com/jackc/pgpassfile"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgservicefile"
)
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
// manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string
KerberosSpn string
Fallbacks []*FallbackConfig
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
ValidateConnect ValidateConnectFunc
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
// or prepare statements). If this returns an error the connection attempt fails.
AfterConnect AfterConnectFunc
// OnNotice is a callback function called when a notice response is received.
OnNotice NoticeHandler
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the TLSConfig field:
// according to the tls.Config docs it must not be modified after creation.
func (c *Config) Copy() *Config {
newConf := new(Config)
*newConf = *c
if newConf.TLSConfig != nil {
newConf.TLSConfig = c.TLSConfig.Clone()
}
if newConf.RuntimeParams != nil {
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
for k, v := range c.RuntimeParams {
newConf.RuntimeParams[k] = v
}
}
if newConf.Fallbacks != nil {
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
for i, fallback := range c.Fallbacks {
newFallback := new(FallbackConfig)
*newFallback = *fallback
if newFallback.TLSConfig != nil {
newFallback.TLSConfig = fallback.TLSConfig.Clone()
}
newConf.Fallbacks[i] = newFallback
}
}
return newConf
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
type FallbackConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
TLSConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
func isAbsolutePath(path string) bool {
isWindowsPath := func(p string) bool {
if len(p) < 3 {
return false
}
drive := p[0]
colon := p[1]
backslash := p[2]
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
return true
}
return false
}
return strings.HasPrefix(path, "/") || isWindowsPath(path)
}
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
func NetworkAddress(host string, port uint16) (network, address string) {
if isAbsolutePath(host) {
network = "unix"
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
network = "tcp"
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
return network, address
}
// ParseConfig builds a *Config with similar behavior to the PostgreSQL standard C library libpq. It uses the same
// defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely matches
// the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example DSN
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
//
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
// not be modified individually. They should all be modified or all left unchanged.
//
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
// values that will be tried in order. This can be used as part of a high availability system. See
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
//
// # Example URL
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
// via database URL or DSN:
//
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGPASSFILE
// PGSERVICE
// PGSERVICEFILE
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
//
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
//
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
// usually but not always the environment variable name downcased and without the "PG" prefix.
//
// Important Security Notes:
//
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
// not set.
//
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
// security each sslmode provides.
//
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
// TLCConfig.
//
// Other known differences with libpq:
//
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not.
//
// In addition, ParseConfig accepts the following options:
//
// min_read_buffer_size
// The minimum size of the internal read buffer. Default 8192.
// servicefile
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
defaultSettings := defaultSettings()
envSettings := parseEnvSettings()
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or a DSN
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
}
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
}
config := &Config{
createdByParseConfig: true,
Database: settings["database"],
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else {
defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext
}
config.LookupFunc = makeDefaultResolver().LookupHost
notRuntimeParams := map[string]struct{}{
"host": {},
"port": {},
"database": {},
"user": {},
"password": {},
"passfile": {},
"connect_timeout": {},
"sslmode": {},
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
"min_read_buffer_size": {},
"service": {},
"servicefile": {},
}
// Adding kerberos configuration
if _, present := settings["krbsrvname"]; present {
config.KerberosSrvName = settings["krbsrvname"]
}
if _, present := settings["krbspn"]; present {
config.KerberosSpn = settings["krbspn"]
}
for k, v := range settings {
if _, present := notRuntimeParams[k]; present {
continue
}
config.RuntimeParams[k] = v
}
fallbacks := []*FallbackConfig{}
hosts := strings.Split(settings["host"], ",")
ports := strings.Split(settings["port"], ",")
for i, host := range hosts {
var portStr string
if i < len(ports) {
portStr = ports[i]
} else {
portStr = ports[0]
}
port, err := parsePort(portStr)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
// Ignore TLS settings if Unix domain socket like libpq
if network, _ := NetworkAddress(host, port); network == "unix" {
tlsConfigs = append(tlsConfigs, nil)
} else {
var err error
tlsConfigs, err = configTLS(settings, host)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
}
for _, tlsConfig := range tlsConfigs {
fallbacks = append(fallbacks, &FallbackConfig{
Host: host,
Port: port,
TLSConfig: tlsConfig,
})
}
}
config.Host = fallbacks[0].Host
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {
if config.Password == "" {
host := config.Host
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
host = "localhost"
}
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
}
}
switch tsa := settings["target_session_attrs"]; tsa {
case "read-write":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
case "read-only":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
case "primary":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
case "standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
case "any", "prefer-standby":
// do nothing
default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
}
func mergeSettings(settingSets ...map[string]string) map[string]string {
settings := make(map[string]string)
for _, s2 := range settingSets {
for k, v := range s2 {
settings[k] = v
}
}
return settings
}
func parseEnvSettings() map[string]string {
settings := make(map[string]string)
nameMap := map[string]string{
"PGHOST": "host",
"PGPORT": "port",
"PGDATABASE": "database",
"PGUSER": "user",
"PGPASSWORD": "password",
"PGPASSFILE": "passfile",
"PGAPPNAME": "application_name",
"PGCONNECT_TIMEOUT": "connect_timeout",
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLROOTCERT": "sslrootcert",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
}
for envname, realname := range nameMap {
value := os.Getenv(envname)
if value != "" {
settings[realname] = value
}
}
return settings
}
func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string)
url, err := url.Parse(connString)
if err != nil {
return nil, err
}
if url.User != nil {
settings["user"] = url.User.Username()
if password, present := url.User.Password(); present {
settings["password"] = password
}
}
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string
var ports []string
for _, host := range strings.Split(url.Host, ",") {
if host == "" {
continue
}
if isIPOnly(host) {
hosts = append(hosts, strings.Trim(host, "[]"))
continue
}
h, p, err := net.SplitHostPort(host)
if err != nil {
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
}
if h != "" {
hosts = append(hosts, h)
}
if p != "" {
ports = append(ports, p)
}
}
if len(hosts) > 0 {
settings["host"] = strings.Join(hosts, ",")
}
if len(ports) > 0 {
settings["port"] = strings.Join(ports, ",")
}
database := strings.TrimLeft(url.Path, "/")
if database != "" {
settings["database"] = database
}
nameMap := map[string]string{
"dbname": "database",
}
for k, v := range url.Query() {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v[0]
}
return settings, nil
}
func isIPOnly(host string) bool {
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseDSNSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
"dbname": "database",
}
for len(s) > 0 {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid dsn")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
if len(s) == 0 {
} else if s[0] != '\'' {
end := 0
for ; end < len(s); end++ {
if asciiSpace[s[end]] == 1 {
break
}
if s[end] == '\\' {
end++
if end == len(s) {
return nil, errors.New("invalid backslash")
}
}
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
} else { // quoted string
s = s[1:]
end := 0
for ; end < len(s); end++ {
if s[end] == '\'' {
break
}
if s[end] == '\\' {
end++
}
}
if end == len(s) {
return nil, errors.New("unterminated quoted string in connection info string")
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
}
if k, ok := nameMap[key]; ok {
key = k
}
if key == "" {
return nil, errors.New("invalid dsn")
}
settings[key] = val
}
return settings, nil
}
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
if err != nil {
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
}
service, err := servicefile.GetService(serviceName)
if err != nil {
return nil, fmt.Errorf("unable to find service: %v", serviceName)
}
nameMap := map[string]string{
"dbname": "database",
}
settings := make(map[string]string, len(service.Settings))
for k, v := range service.Settings {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v
}
return settings, nil
}
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
// necessary to allow returning multiple TLS configs as sslmode "allow" and
// "prefer" allow fallback.
func configTLS(settings map[string]string, thisHost string) ([]*tls.Config, error) {
host := thisHost
sslmode := settings["sslmode"]
sslrootcert := settings["sslrootcert"]
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
}
tlsConfig := &tls.Config{}
switch sslmode {
case "disable":
return []*tls.Config{nil}, nil
case "allow", "prefer":
tlsConfig.InsecureSkipVerify = true
case "require":
// According to PostgreSQL documentation, if a root CA file exists,
// the behavior of sslmode=require should be the same as that of verify-ca
//
// See https://www.postgresql.org/docs/12/libpq-ssl.html
if sslrootcert != "" {
goto nextCase
}
tlsConfig.InsecureSkipVerify = true
break
nextCase:
fallthrough
case "verify-ca":
// Don't perform the default certificate verification because it
// will verify the hostname. Instead, verify the server's
// certificate chain ourselves in VerifyPeerCertificate and
// ignore the server name. This emulates libpq's verify-ca
// behavior.
//
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
// for more info.
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(certificates))
for i, asn1Data := range certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return errors.New("failed to parse certificate from server: " + err.Error())
}
certs[i] = cert
}
// Leave DNSName empty to skip hostname verification.
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
Intermediates: x509.NewCertPool(),
}
// Skip the first cert because it's the leaf. All others
// are intermediates.
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}
case "verify-full":
tlsConfig.ServerName = host
default:
return nil, errors.New("sslmode is invalid")
}
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
caPath := sslrootcert
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
}
if sslcert != "" && sslkey != "" {
cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
switch sslmode {
case "allow":
return []*tls.Config{nil, tlsConfig}, nil
case "prefer":
return []*tls.Config{tlsConfig, nil}, nil
case "require", "verify-ca", "verify-full":
return []*tls.Config{tlsConfig}, nil
default:
panic("BUG: bad sslmode should already have been caught")
}
}
func parsePort(s string) (uint16, error) {
port, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return 0, err
}
if port < 1 || port > math.MaxUint16 {
return 0, errors.New("outside range")
}
return uint16(port), nil
}
func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultResolver() *net.Resolver {
return net.DefaultResolver
}
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
return func(r io.Reader, w io.Writer) Frontend {
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
if err != nil {
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
}
frontend := pgproto3.NewFrontend(cr, w)
return frontend
}
}
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, err
}
if timeout < 0 {
return 0, errors.New("negative timeout")
}
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer()
d.Timeout = timeout
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) == "on" {
return errors.New("read only connection")
}
return nil
}
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "on" {
return errors.New("connection is not read only")
}
return nil
}
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "t" {
return errors.New("server is not in hot standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) == "t" {
return errors.New("server is in standby mode")
}
return nil
}

@ -0,0 +1,64 @@
// +build !windows
package pgconn
import (
"os"
"os/user"
"path/filepath"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
if err == nil {
settings["user"] = user.Username
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
candidatePaths := []string{
"/var/run/postgresql", // Debian
"/private/tmp", // OSX - homebrew
"/tmp", // standard PostgreSQL
}
for _, path := range candidatePaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return "localhost"
}

@ -0,0 +1,59 @@
package pgconn
import (
"os"
"os/user"
"path/filepath"
"strings"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
appData := os.Getenv("APPDATA")
if err == nil {
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
// but the libpq default is just the `user` portion, so we strip off the first part.
username := user.Username
if strings.Contains(username, "\\") {
username = username[strings.LastIndex(username, "\\")+1:]
}
settings["user"] = username
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
return "localhost"
}

@ -0,0 +1,29 @@
// Package pgconn is a low-level PostgreSQL database driver.
/*
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
nearly the same level is the C library libpq.
Establishing a Connection
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
libpq style environment variables.
Executing a Query
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
reads all rows into memory.
Executing Multiple Queries in a Single Round Trip
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
result. The ReadAll method reads all query results into memory.
Context Support
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/
package pgconn

@ -0,0 +1,221 @@
package pgconn
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"regexp"
"strings"
)
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
var timeoutErr *errTimeout
return errors.As(err, &timeoutErr)
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// SQLState returns the SQLState of the error.
func (pe *PgError) SQLState() string {
return pe.Code
}
type connectError struct {
config *Config
msg string
err error
}
func (e *connectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return sb.String()
}
func (e *connectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
msg string
err error
}
func (e *parseConfigError) Error() string {
connString := redactPW(e.connString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
return e.err
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return &errTimeout{err: ctx.Err()}
}
return err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
type errTimeout struct {
err error
}
func (e *errTimeout) Error() string {
return fmt.Sprintf("timeout: %s", e.err.Error())
}
func (e *errTimeout) SafeToRetry() bool {
return SafeToRetry(e.err)
}
func (e *errTimeout) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
func newContextAlreadyDoneError(ctx context.Context) (err error) {
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
func redactPW(connString string) string {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
if u, err := url.Parse(connString); err == nil {
return redactURL(u)
}
}
quotedDSN := regexp.MustCompile(`password='[^']*'`)
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
plainDSN := regexp.MustCompile(`password=[^ ]*`)
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
return connString
}
func redactURL(u *url.URL) string {
if u == nil {
return ""
}
if _, pwSet := u.User.Password(); pwSet {
u.User = url.UserPassword(u.User.Username(), "xxxxx")
}
return u.String()
}

@ -0,0 +1,73 @@
package ctxwatch
import (
"context"
"sync"
)
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
onCancel func()
onUnwatchAfterCancel func()
unwatchChan chan struct{}
lock sync.Mutex
watchInProgress bool
onCancelWasCalled bool
}
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
cw := &ContextWatcher{
onCancel: onCancel,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}),
}
return cw
}
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
func (cw *ContextWatcher) Watch(ctx context.Context) {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
panic("Watch already in progress")
}
cw.onCancelWasCalled = false
if ctx.Done() != nil {
cw.watchInProgress = true
go func() {
select {
case <-ctx.Done():
cw.onCancel()
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
}
}()
} else {
cw.watchInProgress = false
}
}
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
// called then onUnwatchAfterCancel will also be called.
func (cw *ContextWatcher) Unwatch() {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.onUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}

@ -0,0 +1,94 @@
package pgconn
import (
"errors"
"github.com/jackc/pgproto3/v2"
)
// NewGSSFunc creates a GSS authentication provider, for use with
// RegisterGSSProvider.
type NewGSSFunc func() (GSS, error)
var newGSS NewGSSFunc
// RegisterGSSProvider registers a GSS authentication provider. For example, if
// you need to use Kerberos to authenticate with your server, add this to your
// main package:
//
// import "github.com/otan/gopgkrb5"
//
// func init() {
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
// }
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
newGSS = newGSSArg
}
// GSS provides GSSAPI authentication (e.g., Kerberos).
type GSS interface {
GetInitToken(host string, service string) ([]byte, error)
GetInitTokenFromSPN(spn string) ([]byte, error)
Continue(inToken []byte) (done bool, outToken []byte, err error)
}
func (c *PgConn) gssAuth() error {
if newGSS == nil {
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
}
cli, err := newGSS()
if err != nil {
return err
}
var nextData []byte
if c.config.KerberosSpn != "" {
// Use the supplied SPN if provided.
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
} else {
// Allow the kerberos service name to be overridden
service := "postgres"
if c.config.KerberosSrvName != "" {
service = c.config.KerberosSrvName
}
nextData, err = cli.GetInitToken(c.config.Host, service)
}
if err != nil {
return err
}
for {
gssResponse := &pgproto3.GSSResponse{
Data: nextData,
}
_, err = c.conn.Write(gssResponse.Encode(nil))
if err != nil {
return err
}
resp, err := c.rxGSSContinue()
if err != nil {
return err
}
var done bool
done, nextData, err = cli.Continue(resp.Data)
if err != nil {
return err
}
if done {
break
}
}
return nil
}
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
gssContinue, ok := msg.(*pgproto3.AuthenticationGSSContinue)
if ok {
return gssContinue, nil
}
return nil, errors.New("expected AuthenticationGSSContinue message but received unexpected message")
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,169 @@
package stmtcache
import (
"container/list"
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgconn"
)
var lruCount uint64
// LRU implements Cache with a Least Recently Used (LRU) cache.
type LRU struct {
conn *pgconn.PgConn
mode int
cap int
prepareCount int
m map[string]*list.Element
l *list.List
psNamePrefix string
stmtsToClear []string
}
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
mustBeValidMode(mode)
mustBeValidCap(cap)
n := atomic.AddUint64(&lruCount, 1)
return &LRU{
conn: conn,
mode: mode,
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
}
}
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// flush an outstanding bad statements
txStatus := c.conn.TxStatus()
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
for _, stmt := range c.stmtsToClear {
err := c.clearStmt(ctx, stmt)
if err != nil {
return nil, err
}
}
}
if el, ok := c.m[sql]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription), nil
}
if c.l.Len() == c.cap {
err := c.removeOldest(ctx)
if err != nil {
return nil, err
}
}
psd, err := c.prepare(ctx, sql)
if err != nil {
return nil, err
}
el := c.l.PushFront(psd)
c.m[sql] = el
return psd, nil
}
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
func (c *LRU) Clear(ctx context.Context) error {
for c.l.Len() > 0 {
err := c.removeOldest(ctx)
if err != nil {
return err
}
}
return nil
}
func (c *LRU) StatementErrored(sql string, err error) {
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return
}
// https://github.com/jackc/pgx/issues/1162
//
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
if possibleInvalidCachedPlanError {
c.stmtsToClear = append(c.stmtsToClear, sql)
}
}
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
elem, inMap := c.m[sql]
if !inMap {
// The statement probably fell off the back of the list. In that case, we've
// ensured that it isn't in the cache, so we can declare victory.
return nil
}
c.l.Remove(elem)
psd := elem.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRU) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRU) Cap() int {
return c.cap
}
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
func (c *LRU) Mode() int {
return c.mode
}
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
var name string
if c.mode == ModePrepare {
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
c.prepareCount += 1
}
return c.conn.Prepare(ctx, name, sql, nil)
}
func (c *LRU) removeOldest(ctx context.Context) error {
oldest := c.l.Back()
c.l.Remove(oldest)
psd := oldest.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}

@ -0,0 +1,58 @@
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
package stmtcache
import (
"context"
"github.com/jackc/pgconn"
)
const (
ModePrepare = iota // Cache should prepare named statements.
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
)
// Cache prepares and caches prepared statement descriptions.
type Cache interface {
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
Clear(ctx context.Context) error
// StatementErrored informs the cache that the given statement resulted in an error when it
// was last used against the database. In some cases, this will cause the cache to maer that
// statement as bad. The bad statement will instead be flushed during the next call to Get
// that occurs outside of a failed transaction.
StatementErrored(sql string, err error)
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
Mode() int
}
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
// the maximum size of the cache.
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
mustBeValidMode(mode)
mustBeValidCap(cap)
return NewLRU(conn, mode, cap)
}
func mustBeValidMode(mode int) {
if mode != ModePrepare && mode != ModeDescribe {
panic("mode must be ModePrepare or ModeDescribe")
}
}
func mustBeValidCap(cap int) {
if cap < 1 {
panic("cache must have cap of >= 1")
}
}

@ -0,0 +1,9 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

@ -0,0 +1,22 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,11 @@
[![](https://godoc.org/github.com/jackc/pgio?status.svg)](https://godoc.org/github.com/jackc/pgio)
[![Build Status](https://travis-ci.org/jackc/pgio.svg)](https://travis-ci.org/jackc/pgio)
# pgio
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
Extracted from original implementation in https://github.com/jackc/pgx.

@ -0,0 +1,6 @@
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
/*
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
*/
package pgio

@ -0,0 +1,40 @@
package pgio
import "encoding/binary"
func AppendUint16(buf []byte, n uint16) []byte {
wp := len(buf)
buf = append(buf, 0, 0)
binary.BigEndian.PutUint16(buf[wp:], n)
return buf
}
func AppendUint32(buf []byte, n uint32) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0)
binary.BigEndian.PutUint32(buf[wp:], n)
return buf
}
func AppendUint64(buf []byte, n uint64) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
binary.BigEndian.PutUint64(buf[wp:], n)
return buf
}
func AppendInt16(buf []byte, n int16) []byte {
return AppendUint16(buf, uint16(n))
}
func AppendInt32(buf []byte, n int32) []byte {
return AppendUint32(buf, uint32(n))
}
func AppendInt64(buf []byte, n int64) []byte {
return AppendUint64(buf, uint64(n))
}
func SetInt32(buf []byte, n int32) {
binary.BigEndian.PutUint32(buf, uint32(n))
}

@ -0,0 +1,9 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

@ -0,0 +1,22 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,8 @@
[![](https://godoc.org/github.com/jackc/pgpassfile?status.svg)](https://godoc.org/github.com/jackc/pgpassfile)
[![Build Status](https://travis-ci.org/jackc/pgpassfile.svg)](https://travis-ci.org/jackc/pgpassfile)
# pgpassfile
Package pgpassfile is a parser PostgreSQL .pgpass files.
Extracted and rewritten from original implementation in https://github.com/jackc/pgx.

@ -0,0 +1,110 @@
// Package pgpassfile is a parser PostgreSQL .pgpass files.
package pgpassfile
import (
"bufio"
"io"
"os"
"regexp"
"strings"
)
// Entry represents a line in a PG passfile.
type Entry struct {
Hostname string
Port string
Database string
Username string
Password string
}
// Passfile is the in memory data structure representing a PG passfile.
type Passfile struct {
Entries []*Entry
}
// ReadPassfile reads the file at path and parses it into a Passfile.
func ReadPassfile(path string) (*Passfile, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
return ParsePassfile(f)
}
// ParsePassfile reads r and parses it into a Passfile.
func ParsePassfile(r io.Reader) (*Passfile, error) {
passfile := &Passfile{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
entry := parseLine(scanner.Text())
if entry != nil {
passfile.Entries = append(passfile.Entries, entry)
}
}
return passfile, scanner.Err()
}
// Match (not colons or escaped colon or escaped backslash)+. Essentially gives a split on unescaped
// colon.
var colonSplitterRegexp = regexp.MustCompile("(([^:]|(\\:)))+")
// var colonSplitterRegexp = regexp.MustCompile("((?:[^:]|(?:\\:)|(?:\\\\))+)")
// parseLine parses a line into an *Entry. It returns nil on comment lines or any other unparsable
// line.
func parseLine(line string) *Entry {
const (
tmpBackslash = "\r"
tmpColon = "\n"
)
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "#") {
return nil
}
line = strings.Replace(line, `\\`, tmpBackslash, -1)
line = strings.Replace(line, `\:`, tmpColon, -1)
parts := strings.Split(line, ":")
if len(parts) != 5 {
return nil
}
// Unescape escaped colons and backslashes
for i := range parts {
parts[i] = strings.Replace(parts[i], tmpBackslash, `\`, -1)
parts[i] = strings.Replace(parts[i], tmpColon, `:`, -1)
}
return &Entry{
Hostname: parts[0],
Port: parts[1],
Database: parts[2],
Username: parts[3],
Password: parts[4],
}
}
// FindPassword finds the password for the provided hostname, port, database, and username. For a
// Unix domain socket hostname must be set to "localhost". An empty string will be returned if no
// match is found.
//
// See https://www.postgresql.org/docs/current/libpq-pgpass.html for more password file information.
func (pf *Passfile) FindPassword(hostname, port, database, username string) (password string) {
for _, e := range pf.Entries {
if (e.Hostname == "*" || e.Hostname == hostname) &&
(e.Port == "*" || e.Port == port) &&
(e.Database == "*" || e.Database == database) &&
(e.Username == "*" || e.Username == username) {
return e.Password
}
}
return ""
}

@ -0,0 +1,9 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

@ -0,0 +1,22 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,12 @@
[![](https://godoc.org/github.com/jackc/pgproto3?status.svg)](https://godoc.org/github.com/jackc/pgproto3)
[![Build Status](https://travis-ci.org/jackc/pgproto3.svg)](https://travis-ci.org/jackc/pgproto3)
# pgproto3
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
See example/pgfortune for a playful example of a fake PostgreSQL server.
Extracted from original implementation in https://github.com/jackc/pgx.

@ -0,0 +1,52 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
type AuthenticationCleartextPassword struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeCleartextPassword {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationCleartextPassword",
})
}

@ -0,0 +1,58 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type AuthenticationGSS struct{}
func (a *AuthenticationGSS) Backend() {}
func (a *AuthenticationGSS) AuthenticationResponse() {}
func (a *AuthenticationGSS) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSS {
return errors.New("bad auth type")
}
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSS",
})
}
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
return nil
}

@ -0,0 +1,67 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type AuthenticationGSSContinue struct {
Data []byte
}
func (a *AuthenticationGSSContinue) Backend() {}
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSSCont {
return errors.New("bad auth type")
}
a.Data = src[4:]
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSSContinue",
Data: a.Data,
})
}
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
a.Data = msg.Data
return nil
}

@ -0,0 +1,77 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
type AuthenticationMD5Password struct {
Salt [4]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationMD5Password) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
if len(src) != 8 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeMD5Password {
return errors.New("bad auth type")
}
copy(dst.Salt[:], src[4:8])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Salt [4]byte
}{
Type: "AuthenticationMD5Password",
Salt: src.Salt,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Salt [4]byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Salt = msg.Salt
return nil
}

@ -0,0 +1,52 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
type AuthenticationOk struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationOk) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeOk {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationOK",
})
}

@ -0,0 +1,75 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
type AuthenticationSASL struct {
AuthMechanisms []string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASL) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASL {
return errors.New("bad auth type")
}
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
AuthMechanisms []string
}{
Type: "AuthenticationSASL",
AuthMechanisms: src.AuthMechanisms,
})
}

@ -0,0 +1,81 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
type AuthenticationSASLContinue struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLContinue {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLContinue",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

@ -0,0 +1,81 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
type AuthenticationSASLFinal struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLFinal {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Unmarshaler.
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLFinal",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

@ -0,0 +1,209 @@
package pgproto3
import (
"encoding/binary"
"fmt"
"io"
)
// Backend acts as a server for the PostgreSQL wire protocol version 3.
type Backend struct {
cr ChunkReader
w io.Writer
// Frontend message flyweights
bind Bind
cancelRequest CancelRequest
_close Close
copyFail CopyFail
copyData CopyData
copyDone CopyDone
describe Describe
execute Execute
flush Flush
functionCall FunctionCall
gssEncRequest GSSEncRequest
parse Parse
query Query
sslRequest SSLRequest
startupMessage StartupMessage
sync Sync
terminate Terminate
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
const (
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)
// NewBackend creates a new Backend.
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
}
// Send sends a message to the frontend.
func (b *Backend) Send(msg BackendMessage) error {
_, err := b.w.Write(msg.Encode(nil))
return err
}
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
// because the initial connection message is "special" and does not include the message type as the first byte. This
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
buf, err := b.cr.Next(4)
if err != nil {
return nil, err
}
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
}
buf, err = b.cr.Next(msgSize)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
code := binary.BigEndian.Uint32(buf)
switch code {
case ProtocolVersionNumber:
err = b.startupMessage.Decode(buf)
if err != nil {
return nil, err
}
return &b.startupMessage, nil
case sslRequestNumber:
err = b.sslRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.sslRequest, nil
case cancelRequestCode:
err = b.cancelRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.cancelRequest, nil
case gssEncReqNumber:
err = b.gssEncRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.gssEncRequest, nil
default:
return nil, fmt.Errorf("unknown startup message code: %d", code)
}
}
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
func (b *Backend) Receive() (FrontendMessage, error) {
if !b.partialMsg {
header, err := b.cr.Next(5)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
}
var msg FrontendMessage
switch b.msgType {
case 'B':
msg = &b.bind
case 'C':
msg = &b._close
case 'D':
msg = &b.describe
case 'E':
msg = &b.execute
case 'F':
msg = &b.functionCall
case 'f':
msg = &b.copyFail
case 'd':
msg = &b.copyData
case 'c':
msg = &b.copyDone
case 'H':
msg = &b.flush
case 'P':
msg = &b.parse
case 'p':
switch b.authType {
case AuthTypeSASL:
msg = &SASLInitialResponse{}
case AuthTypeSASLContinue:
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeGSS, AuthTypeGSSCont:
msg = &GSSResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
// to maintain backwards compatability
msg = &PasswordMessage{}
}
case 'Q':
msg = &b.query
case 'S':
msg = &b.sync
case 'X':
msg = &b.terminate
default:
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.partialMsg = false
err = msg.Decode(msgBody)
return msg, err
}
// SetAuthType sets the authentication type in the backend.
// Since multiple message types can start with 'p', SetAuthType allows
// contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage:
//
// Byte1('p')
//
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
//
// Since the Frontend does not know about the state of a backend, it is important
// to call SetAuthType() after an authentication request is received by the Frontend.
func (b *Backend) SetAuthType(authType uint32) error {
switch authType {
case AuthTypeOk,
AuthTypeCleartextPassword,
AuthTypeMD5Password,
AuthTypeSCMCreds,
AuthTypeGSS,
AuthTypeGSSCont,
AuthTypeSSPI,
AuthTypeSASL,
AuthTypeSASLContinue,
AuthTypeSASLFinal:
b.authType = authType
default:
return fmt.Errorf("authType not recognized: %d", authType)
}
return nil
}

@ -0,0 +1,51 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type BackendKeyData struct {
ProcessID uint32
SecretKey uint32
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*BackendKeyData) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *BackendKeyData) Decode(src []byte) error {
if len(src) != 8 {
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
}
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) []byte {
dst = append(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "BackendKeyData",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

@ -0,0 +1,37 @@
package pgproto3
import (
"encoding/binary"
)
type BigEndianBuf [8]byte
func (b BigEndianBuf) Int16(n int16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, uint16(n))
return buf
}
func (b BigEndianBuf) Uint16(n uint16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, n)
return buf
}
func (b BigEndianBuf) Int32(n int32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, uint32(n))
return buf
}
func (b BigEndianBuf) Uint32(n uint32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, n)
return buf
}
func (b BigEndianBuf) Int64(n int64) []byte {
buf := b[0:8]
binary.BigEndian.PutUint64(buf, uint64(n))
return buf
}

@ -0,0 +1,216 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/jackc/pgio"
)
type Bind struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters [][]byte
ResultFormatCodes []int16
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Bind) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Bind) Decode(src []byte) error {
*dst = Bind{}
idx := bytes.IndexByte(src, 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.DestinationPortal = string(src[:idx])
rp := idx + 1
idx = bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.PreparedStatement = string(src[rp : rp+idx])
rp += idx + 1
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if parameterFormatCodeCount > 0 {
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
for i := 0; i < parameterFormatCodeCount; i++ {
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
}
}
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if parameterCount > 0 {
dst.Parameters = make([][]byte, parameterCount)
for i := 0; i < parameterCount; i++ {
if len(src[rp:]) < 4 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
// null
if msgSize == -1 {
continue
}
if len(src[rp:]) < msgSize {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.Parameters[i] = src[rp : rp+msgSize]
rp += msgSize
}
}
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
for i := 0; i < resultFormatCodeCount; i++ {
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
dst = pgio.AppendInt32(dst, -1)
continue
}
dst = pgio.AppendInt32(dst, int32(len(p)))
dst = append(dst, p...)
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Bind) MarshalJSON() ([]byte, error) {
formattedParameters := make([]map[string]string, len(src.Parameters))
for i, p := range src.Parameters {
if p == nil {
continue
}
textFormat := true
if len(src.ParameterFormatCodes) == 1 {
textFormat = src.ParameterFormatCodes[0] == 0
} else if len(src.ParameterFormatCodes) > 1 {
textFormat = src.ParameterFormatCodes[i] == 0
}
if textFormat {
formattedParameters[i] = map[string]string{"text": string(p)}
} else {
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
}
}
return json.Marshal(struct {
Type string
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}{
Type: "Bind",
DestinationPortal: src.DestinationPortal,
PreparedStatement: src.PreparedStatement,
ParameterFormatCodes: src.ParameterFormatCodes,
Parameters: formattedParameters,
ResultFormatCodes: src.ResultFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Bind) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.DestinationPortal = msg.DestinationPortal
dst.PreparedStatement = msg.PreparedStatement
dst.ParameterFormatCodes = msg.ParameterFormatCodes
dst.Parameters = make([][]byte, len(msg.Parameters))
dst.ResultFormatCodes = msg.ResultFormatCodes
for n, parameter := range msg.Parameters {
dst.Parameters[n], err = getValueFromJSON(parameter)
if err != nil {
return fmt.Errorf("cannot get param %d: %w", n, err)
}
}
return nil
}

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type BindComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*BindComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *BindComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) []byte {
return append(dst, '2', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src BindComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "BindComplete",
})
}

@ -0,0 +1,58 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
const cancelRequestCode = 80877102
type CancelRequest struct {
ProcessID uint32
SecretKey uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CancelRequest) Frontend() {}
func (dst *CancelRequest) Decode(src []byte) error {
if len(src) != 12 {
return errors.New("bad cancel request size")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != cancelRequestCode {
return errors.New("bad cancel request code")
}
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CancelRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "CancelRequest",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

@ -0,0 +1,19 @@
package pgproto3
import (
"io"
"github.com/jackc/chunkreader/v2"
)
// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package.
type ChunkReader interface {
// Next returns buf filled with the next n bytes. If an error (including a partial read) occurs,
// buf must be nil. Next must preserve any partially read data. Next must not reuse buf.
Next(n int) (buf []byte, err error)
}
// NewChunkReader creates and returns a new default ChunkReader.
func NewChunkReader(r io.Reader) ChunkReader {
return chunkreader.New(r)
}

@ -0,0 +1,89 @@
package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type Close struct {
ObjectType byte // 'S' = prepared statement, 'P' = portal
Name string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Close) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Close) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.ObjectType = src[0]
rp := 1
idx := bytes.IndexByte(src[rp:], 0)
if idx != len(src[rp:])-1 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.Name = string(src[rp : len(src)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Close) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ObjectType string
Name string
}{
Type: "Close",
ObjectType: string(src.ObjectType),
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Close) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Close.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type CloseComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CloseComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CloseComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) []byte {
return append(dst, '3', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CloseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CloseComplete",
})
}

@ -0,0 +1,71 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type CommandComplete struct {
CommandTag []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CommandComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CommandComplete) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0)
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CommandComplete"}
}
dst.CommandTag = src[:idx]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CommandComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
CommandTag string
}{
Type: "CommandComplete",
CommandTag: string(src.CommandTag),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
CommandTag string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.CommandTag = []byte(msg.CommandTag)
return nil
}

@ -0,0 +1,95 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type CopyBothResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyBothResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyBothResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyBothResponse) Encode(dst []byte) []byte {
dst = append(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyBothResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyBothResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

@ -0,0 +1,62 @@
package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgio"
)
type CopyData struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyData) Backend() {}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyData) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyData) Decode(src []byte) error {
dst.Data = src
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyData) Encode(dst []byte) []byte {
dst = append(dst, 'd')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
dst = append(dst, src.Data...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "CopyData",
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyData) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

@ -0,0 +1,38 @@
package pgproto3
import (
"encoding/json"
)
type CopyDone struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyDone) Backend() {}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyDone) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyDone) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyDone) Encode(dst []byte) []byte {
return append(dst, 'c', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyDone) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CopyDone",
})
}

@ -0,0 +1,53 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type CopyFail struct {
Message string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyFail) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyFail) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0)
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CopyFail"}
}
dst.Message = string(src[:idx])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyFail) Encode(dst []byte) []byte {
dst = append(dst, 'f')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Message...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyFail) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Message string
}{
Type: "CopyFail",
Message: src.Message,
})
}

@ -0,0 +1,96 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type CopyInResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyInResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyInResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyInResponse) Encode(dst []byte) []byte {
dst = append(dst, 'G')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyInResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyInResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyInResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

@ -0,0 +1,96 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type CopyOutResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyOutResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyOutResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyOutResponse) Encode(dst []byte) []byte {
dst = append(dst, 'H')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyOutResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyOutResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

@ -0,0 +1,142 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgio"
)
type DataRow struct {
Values [][]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*DataRow) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *DataRow) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
rp := 0
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
// If the capacity of the values slice is too small OR substantially too
// large reallocate. This is too avoid one row with many columns from
// permanently allocating memory.
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
newCap := 32
if newCap < fieldCount {
newCap = fieldCount
}
dst.Values = make([][]byte, fieldCount, newCap)
} else {
dst.Values = dst.Values[:fieldCount]
}
for i := 0; i < fieldCount; i++ {
if len(src[rp:]) < 4 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
// null
if msgSize == -1 {
dst.Values[i] = nil
} else {
if len(src[rp:]) < msgSize {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize]
rp += msgSize
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
dst = pgio.AppendInt32(dst, -1)
continue
}
dst = pgio.AppendInt32(dst, int32(len(v)))
dst = append(dst, v...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src DataRow) MarshalJSON() ([]byte, error) {
formattedValues := make([]map[string]string, len(src.Values))
for i, v := range src.Values {
if v == nil {
continue
}
var hasNonPrintable bool
for _, b := range v {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
} else {
formattedValues[i] = map[string]string{"text": string(v)}
}
}
return json.Marshal(struct {
Type string
Values []map[string]string
}{
Type: "DataRow",
Values: formattedValues,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *DataRow) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Values []map[string]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Values = make([][]byte, len(msg.Values))
for n, parameter := range msg.Values {
var err error
dst.Values[n], err = getValueFromJSON(parameter)
if err != nil {
return err
}
}
return nil
}

@ -0,0 +1,88 @@
package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type Describe struct {
ObjectType byte // 'S' = prepared statement, 'P' = portal
Name string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Describe) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Describe) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "Describe"}
}
dst.ObjectType = src[0]
rp := 1
idx := bytes.IndexByte(src[rp:], 0)
if idx != len(src[rp:])-1 {
return &invalidMessageFormatErr{messageType: "Describe"}
}
dst.Name = string(src[rp : len(src)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Describe) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Describe) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ObjectType string
Name string
}{
Type: "Describe",
ObjectType: string(src.ObjectType),
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Describe) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Describe.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}

@ -0,0 +1,4 @@
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
//
// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages.
package pgproto3

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type EmptyQueryResponse struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*EmptyQueryResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *EmptyQueryResponse) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
return append(dst, 'I', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "EmptyQueryResponse",
})
}

@ -0,0 +1,334 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"strconv"
)
type ErrorResponse struct {
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ErrorResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ErrorResponse) Decode(src []byte) error {
*dst = ErrorResponse{}
buf := bytes.NewBuffer(src)
for {
k, err := buf.ReadByte()
if err != nil {
return err
}
if k == 0 {
break
}
vb, err := buf.ReadBytes(0)
if err != nil {
return err
}
v := string(vb[:len(vb)-1])
switch k {
case 'S':
dst.Severity = v
case 'V':
dst.SeverityUnlocalized = v
case 'C':
dst.Code = v
case 'M':
dst.Message = v
case 'D':
dst.Detail = v
case 'H':
dst.Hint = v
case 'P':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Position = int32(n)
case 'p':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.InternalPosition = int32(n)
case 'q':
dst.InternalQuery = v
case 'W':
dst.Where = v
case 's':
dst.SchemaName = v
case 't':
dst.TableName = v
case 'c':
dst.ColumnName = v
case 'd':
dst.DataTypeName = v
case 'n':
dst.ConstraintName = v
case 'F':
dst.File = v
case 'L':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Line = int32(n)
case 'R':
dst.Routine = v
default:
if dst.UnknownFields == nil {
dst.UnknownFields = make(map[byte]string)
}
dst.UnknownFields[k] = v
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ErrorResponse) Encode(dst []byte) []byte {
return append(dst, src.marshalBinary('E')...)
}
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
if src.Severity != "" {
buf.WriteByte('S')
buf.WriteString(src.Severity)
buf.WriteByte(0)
}
if src.SeverityUnlocalized != "" {
buf.WriteByte('V')
buf.WriteString(src.SeverityUnlocalized)
buf.WriteByte(0)
}
if src.Code != "" {
buf.WriteByte('C')
buf.WriteString(src.Code)
buf.WriteByte(0)
}
if src.Message != "" {
buf.WriteByte('M')
buf.WriteString(src.Message)
buf.WriteByte(0)
}
if src.Detail != "" {
buf.WriteByte('D')
buf.WriteString(src.Detail)
buf.WriteByte(0)
}
if src.Hint != "" {
buf.WriteByte('H')
buf.WriteString(src.Hint)
buf.WriteByte(0)
}
if src.Position != 0 {
buf.WriteByte('P')
buf.WriteString(strconv.Itoa(int(src.Position)))
buf.WriteByte(0)
}
if src.InternalPosition != 0 {
buf.WriteByte('p')
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
buf.WriteByte(0)
}
if src.InternalQuery != "" {
buf.WriteByte('q')
buf.WriteString(src.InternalQuery)
buf.WriteByte(0)
}
if src.Where != "" {
buf.WriteByte('W')
buf.WriteString(src.Where)
buf.WriteByte(0)
}
if src.SchemaName != "" {
buf.WriteByte('s')
buf.WriteString(src.SchemaName)
buf.WriteByte(0)
}
if src.TableName != "" {
buf.WriteByte('t')
buf.WriteString(src.TableName)
buf.WriteByte(0)
}
if src.ColumnName != "" {
buf.WriteByte('c')
buf.WriteString(src.ColumnName)
buf.WriteByte(0)
}
if src.DataTypeName != "" {
buf.WriteByte('d')
buf.WriteString(src.DataTypeName)
buf.WriteByte(0)
}
if src.ConstraintName != "" {
buf.WriteByte('n')
buf.WriteString(src.ConstraintName)
buf.WriteByte(0)
}
if src.File != "" {
buf.WriteByte('F')
buf.WriteString(src.File)
buf.WriteByte(0)
}
if src.Line != 0 {
buf.WriteByte('L')
buf.WriteString(strconv.Itoa(int(src.Line)))
buf.WriteByte(0)
}
if src.Routine != "" {
buf.WriteByte('R')
buf.WriteString(src.Routine)
buf.WriteByte(0)
}
for k, v := range src.UnknownFields {
buf.WriteByte(k)
buf.WriteByte(0)
buf.WriteString(v)
buf.WriteByte(0)
}
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes()
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ErrorResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}{
Type: "ErrorResponse",
Severity: src.Severity,
SeverityUnlocalized: src.SeverityUnlocalized,
Code: src.Code,
Message: src.Message,
Detail: src.Detail,
Hint: src.Hint,
Position: src.Position,
InternalPosition: src.InternalPosition,
InternalQuery: src.InternalQuery,
Where: src.Where,
SchemaName: src.SchemaName,
TableName: src.TableName,
ColumnName: src.ColumnName,
DataTypeName: src.DataTypeName,
ConstraintName: src.ConstraintName,
File: src.File,
Line: src.Line,
Routine: src.Routine,
UnknownFields: src.UnknownFields,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Severity = msg.Severity
dst.SeverityUnlocalized = msg.SeverityUnlocalized
dst.Code = msg.Code
dst.Message = msg.Message
dst.Detail = msg.Detail
dst.Hint = msg.Hint
dst.Position = msg.Position
dst.InternalPosition = msg.InternalPosition
dst.InternalQuery = msg.InternalQuery
dst.Where = msg.Where
dst.SchemaName = msg.SchemaName
dst.TableName = msg.TableName
dst.ColumnName = msg.ColumnName
dst.DataTypeName = msg.DataTypeName
dst.ConstraintName = msg.ConstraintName
dst.File = msg.File
dst.Line = msg.Line
dst.Routine = msg.Routine
dst.UnknownFields = msg.UnknownFields
return nil
}

@ -0,0 +1,65 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type Execute struct {
Portal string
MaxRows uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Execute) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Execute) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Portal = string(b[:len(b)-1])
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "Execute"}
}
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Execute) Encode(dst []byte) []byte {
dst = append(dst, 'E')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Portal...)
dst = append(dst, 0)
dst = pgio.AppendUint32(dst, src.MaxRows)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Execute) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Portal string
MaxRows uint32
}{
Type: "Execute",
Portal: src.Portal,
MaxRows: src.MaxRows,
})
}

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type Flush struct{}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Flush) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Flush) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Flush) Encode(dst []byte) []byte {
return append(dst, 'H', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Flush) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "Flush",
})
}

@ -0,0 +1,203 @@
package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
type Frontend struct {
cr ChunkReader
w io.Writer
// Backend message flyweights
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationGSS AuthenticationGSS
authenticationGSSContinue AuthenticationGSSContinue
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
backendKeyData BackendKeyData
bindComplete BindComplete
closeComplete CloseComplete
commandComplete CommandComplete
copyBothResponse CopyBothResponse
copyData CopyData
copyInResponse CopyInResponse
copyOutResponse CopyOutResponse
copyDone CopyDone
dataRow DataRow
emptyQueryResponse EmptyQueryResponse
errorResponse ErrorResponse
functionCallResponse FunctionCallResponse
noData NoData
noticeResponse NoticeResponse
notificationResponse NotificationResponse
parameterDescription ParameterDescription
parameterStatus ParameterStatus
parseComplete ParseComplete
readyForQuery ReadyForQuery
rowDescription RowDescription
portalSuspended PortalSuspended
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
// NewFrontend creates a new Frontend.
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
return &Frontend{cr: cr, w: w}
}
// Send sends a message to the backend.
func (f *Frontend) Send(msg FrontendMessage) error {
_, err := f.w.Write(msg.Encode(nil))
return err
}
func translateEOFtoErrUnexpectedEOF(err error) error {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
return err
}
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
func (f *Frontend) Receive() (BackendMessage, error) {
if !f.partialMsg {
header, err := f.cr.Next(5)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
f.msgType = header[0]
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
f.partialMsg = true
}
msgBody, err := f.cr.Next(f.bodyLen)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
f.partialMsg = false
var msg BackendMessage
switch f.msgType {
case '1':
msg = &f.parseComplete
case '2':
msg = &f.bindComplete
case '3':
msg = &f.closeComplete
case 'A':
msg = &f.notificationResponse
case 'c':
msg = &f.copyDone
case 'C':
msg = &f.commandComplete
case 'd':
msg = &f.copyData
case 'D':
msg = &f.dataRow
case 'E':
msg = &f.errorResponse
case 'G':
msg = &f.copyInResponse
case 'H':
msg = &f.copyOutResponse
case 'I':
msg = &f.emptyQueryResponse
case 'K':
msg = &f.backendKeyData
case 'n':
msg = &f.noData
case 'N':
msg = &f.noticeResponse
case 'R':
var err error
msg, err = f.findAuthenticationMessageType(msgBody)
if err != nil {
return nil, err
}
case 's':
msg = &f.portalSuspended
case 'S':
msg = &f.parameterStatus
case 't':
msg = &f.parameterDescription
case 'T':
msg = &f.rowDescription
case 'V':
msg = &f.functionCallResponse
case 'W':
msg = &f.copyBothResponse
case 'Z':
msg = &f.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
}
err = msg.Decode(msgBody)
return msg, err
}
// Authentication message type constants.
// See src/include/libpq/pqcomm.h for all
// constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSCMCreds = 6
AuthTypeGSS = 7
AuthTypeGSSCont = 8
AuthTypeSSPI = 9
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
if len(src) < 4 {
return nil, errors.New("authentication message too short")
}
f.authType = binary.BigEndian.Uint32(src[:4])
switch f.authType {
case AuthTypeOk:
return &f.authenticationOk, nil
case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return &f.authenticationGSS, nil
case AuthTypeGSSCont:
return &f.authenticationGSSContinue, nil
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:
return &f.authenticationSASL, nil
case AuthTypeSASLContinue:
return &f.authenticationSASLContinue, nil
case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
}
}
// GetAuthType returns the authType used in the current state of the frontend.
// See SetAuthType for more information.
func (f *Frontend) GetAuthType() uint32 {
return f.authType
}

@ -0,0 +1,94 @@
package pgproto3
import (
"encoding/binary"
"github.com/jackc/pgio"
)
type FunctionCall struct {
Function uint32
ArgFormatCodes []uint16
Arguments [][]byte
ResultFormatCode uint16
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*FunctionCall) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *FunctionCall) Decode(src []byte) error {
*dst = FunctionCall{}
rp := 0
// Specifies the object ID of the function to call.
dst.Function = binary.BigEndian.Uint32(src[rp:])
rp += 4
// The number of argument format codes that follow (denoted C below).
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
// or one, in which case the specified format code is applied to all arguments;
// or it can equal the actual number of arguments.
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
argumentCodes := make([]uint16, nArgumentCodes)
for i := 0; i < nArgumentCodes; i++ {
// The argument format codes. Each must presently be zero (text) or one (binary).
ac := binary.BigEndian.Uint16(src[rp:])
if ac != 0 && ac != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
argumentCodes[i] = ac
rp += 2
}
dst.ArgFormatCodes = argumentCodes
// Specifies the number of arguments being supplied to the function.
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
arguments := make([][]byte, nArguments)
for i := 0; i < nArguments; i++ {
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if argumentLength == -1 {
arguments[i] = nil
} else {
// The value of the argument, in the format indicated by the associated format code. n is the above length.
argumentValue := src[rp : rp+argumentLength]
rp += argumentLength
arguments[i] = argumentValue
}
}
dst.Arguments = arguments
// The format code for the function result. Must presently be zero (text) or one (binary).
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
if resultFormatCode != 0 && resultFormatCode != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
dst.ResultFormatCode = resultFormatCode
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCall) Encode(dst []byte) []byte {
dst = append(dst, 'F')
sp := len(dst)
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
dst = pgio.AppendUint32(dst, src.Function)
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(argument)))
dst = append(dst, argument...)
}
}
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}

@ -0,0 +1,101 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgio"
)
type FunctionCallResponse struct {
Result []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*FunctionCallResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *FunctionCallResponse) Decode(src []byte) error {
if len(src) < 4 {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
rp := 0
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if resultSize == -1 {
dst.Result = nil
return nil
}
if len(src[rp:]) != resultSize {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
dst.Result = src[rp:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
dst = append(dst, 'V')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
if src.Result == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
dst = append(dst, src.Result...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
var formattedValue map[string]string
var hasNonPrintable bool
for _, b := range src.Result {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
} else {
formattedValue = map[string]string{"text": string(src.Result)}
}
return json.Marshal(struct {
Type string
Result map[string]string
}{
Type: "FunctionCallResponse",
Result: formattedValue,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Result map[string]string
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.Result, err = getValueFromJSON(msg.Result)
return err
}

@ -0,0 +1,49 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
const gssEncReqNumber = 80877104
type GSSEncRequest struct {
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*GSSEncRequest) Frontend() {}
func (dst *GSSEncRequest) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("gss encoding request too short")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != gssEncReqNumber {
return errors.New("bad gss encoding request code")
}
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *GSSEncRequest) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, gssEncReqNumber)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src GSSEncRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProtocolVersion uint32
Parameters map[string]string
}{
Type: "GSSEncRequest",
})
}

@ -0,0 +1,48 @@
package pgproto3
import (
"encoding/json"
"github.com/jackc/pgio"
)
type GSSResponse struct {
Data []byte
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (g *GSSResponse) Frontend() {}
func (g *GSSResponse) Decode(data []byte) error {
g.Data = data
return nil
}
func (g *GSSResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
dst = append(dst, g.Data...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "GSSResponse",
Data: g.Data,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
g.Data = msg.Data
return nil
}

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type NoData struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NoData) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NoData) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoData) Encode(dst []byte) []byte {
return append(dst, 'n', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src NoData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "NoData",
})
}

@ -0,0 +1,17 @@
package pgproto3
type NoticeResponse ErrorResponse
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NoticeResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NoticeResponse) Decode(src []byte) error {
return (*ErrorResponse)(dst).Decode(src)
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoticeResponse) Encode(dst []byte) []byte {
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
}

@ -0,0 +1,73 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type NotificationResponse struct {
PID uint32
Channel string
Payload string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NotificationResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NotificationResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
channel := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
payload := string(b[:len(b)-1])
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NotificationResponse) Encode(dst []byte) []byte {
dst = append(dst, 'A')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, src.PID)
dst = append(dst, src.Channel...)
dst = append(dst, 0)
dst = append(dst, src.Payload...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src NotificationResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
PID uint32
Channel string
Payload string
}{
Type: "NotificationResponse",
PID: src.PID,
Channel: src.Channel,
Payload: src.Payload,
})
}

@ -0,0 +1,66 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type ParameterDescription struct {
ParameterOIDs []uint32
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParameterDescription) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParameterDescription) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
}
// Reported parameter count will be incorrect when number of args is greater than uint16
buf.Next(2)
// Instead infer parameter count by remaining size of message
parameterCount := buf.Len() / 4
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
for i := 0; i < parameterCount; i++ {
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterDescription) Encode(dst []byte) []byte {
dst = append(dst, 't')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ParameterDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ParameterOIDs []uint32
}{
Type: "ParameterDescription",
ParameterOIDs: src.ParameterOIDs,
})
}

@ -0,0 +1,66 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type ParameterStatus struct {
Name string
Value string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParameterStatus) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParameterStatus) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
name := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
value := string(b[:len(b)-1])
*dst = ParameterStatus{Name: name, Value: value}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterStatus) Encode(dst []byte) []byte {
dst = append(dst, 'S')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Value...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (ps ParameterStatus) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Value string
}{
Type: "ParameterStatus",
Name: ps.Name,
Value: ps.Value,
})
}

@ -0,0 +1,88 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type Parse struct {
Name string
Query string
ParameterOIDs []uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Parse) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Parse) Decode(src []byte) error {
*dst = Parse{}
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Name = string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
dst.Query = string(b[:len(b)-1])
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "Parse"}
}
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
for i := 0; i < parameterOIDCount; i++ {
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "Parse"}
}
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Parse) Encode(dst []byte) []byte {
dst = append(dst, 'P')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Query...)
dst = append(dst, 0)
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Parse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Query string
ParameterOIDs []uint32
}{
Type: "Parse",
Name: src.Name,
Query: src.Query,
ParameterOIDs: src.ParameterOIDs,
})
}

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type ParseComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParseComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParseComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParseComplete) Encode(dst []byte) []byte {
return append(dst, '1', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ParseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "ParseComplete",
})
}

@ -0,0 +1,54 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type PasswordMessage struct {
Password string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*PasswordMessage) Frontend() {}
// Frontend identifies this message as an authentication response.
func (*PasswordMessage) InitialResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *PasswordMessage) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Password = string(b[:len(b)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PasswordMessage) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
dst = append(dst, src.Password...)
dst = append(dst, 0)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src PasswordMessage) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Password string
}{
Type: "PasswordMessage",
Password: src.Password,
})
}

@ -0,0 +1,65 @@
package pgproto3
import (
"encoding/hex"
"errors"
"fmt"
)
// Message is the interface implemented by an object that can decode and encode
// a particular PostgreSQL message.
type Message interface {
// Decode is allowed and expected to retain a reference to data after
// returning (unlike encoding.BinaryUnmarshaler).
Decode(data []byte) error
// Encode appends itself to dst and returns the new buffer.
Encode(dst []byte) []byte
}
type FrontendMessage interface {
Message
Frontend() // no-op method to distinguish frontend from backend methods
}
type BackendMessage interface {
Message
Backend() // no-op method to distinguish frontend from backend methods
}
type AuthenticationResponseMessage interface {
BackendMessage
AuthenticationResponse() // no-op method to distinguish authentication responses
}
type invalidMessageLenErr struct {
messageType string
expectedLen int
actualLen int
}
func (e *invalidMessageLenErr) Error() string {
return fmt.Sprintf("%s body must have length of %d, but it is %d", e.messageType, e.expectedLen, e.actualLen)
}
type invalidMessageFormatErr struct {
messageType string
}
func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType)
}
// getValueFromJSON gets the value from a protocol message representation in JSON.
func getValueFromJSON(v map[string]string) ([]byte, error) {
if v == nil {
return nil, nil
}
if text, ok := v["text"]; ok {
return []byte(text), nil
}
if binary, ok := v["binary"]; ok {
return hex.DecodeString(binary)
}
return nil, errors.New("unknown protocol representation")
}

@ -0,0 +1,34 @@
package pgproto3
import (
"encoding/json"
)
type PortalSuspended struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*PortalSuspended) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *PortalSuspended) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "PortalSuspended", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PortalSuspended) Encode(dst []byte) []byte {
return append(dst, 's', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src PortalSuspended) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "PortalSuspended",
})
}

@ -0,0 +1,50 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type Query struct {
String string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Query) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Query) Decode(src []byte) error {
i := bytes.IndexByte(src, 0)
if i != len(src)-1 {
return &invalidMessageFormatErr{messageType: "Query"}
}
dst.String = string(src[:i])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Query) Encode(dst []byte) []byte {
dst = append(dst, 'Q')
dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1))
dst = append(dst, src.String...)
dst = append(dst, 0)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Query) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
String string
}{
Type: "Query",
String: src.String,
})
}

@ -0,0 +1,61 @@
package pgproto3
import (
"encoding/json"
"errors"
)
type ReadyForQuery struct {
TxStatus byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ReadyForQuery) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ReadyForQuery) Decode(src []byte) error {
if len(src) != 1 {
return &invalidMessageLenErr{messageType: "ReadyForQuery", expectedLen: 1, actualLen: len(src)}
}
dst.TxStatus = src[0]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ReadyForQuery) Encode(dst []byte) []byte {
return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ReadyForQuery) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
TxStatus string
}{
Type: "ReadyForQuery",
TxStatus: string(src.TxStatus),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ReadyForQuery) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
TxStatus string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.TxStatus) != 1 {
return errors.New("invalid length for ReadyForQuery.TxStatus")
}
dst.TxStatus = msg.TxStatus[0]
return nil
}

@ -0,0 +1,165 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
const (
TextFormat = 0
BinaryFormat = 1
)
type FieldDescription struct {
Name []byte
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
Format int16
}
// MarshalJSON implements encoding/json.Marshaler.
func (fd FieldDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
Format int16
}{
Name: string(fd.Name),
TableOID: fd.TableOID,
TableAttributeNumber: fd.TableAttributeNumber,
DataTypeOID: fd.DataTypeOID,
DataTypeSize: fd.DataTypeSize,
TypeModifier: fd.TypeModifier,
Format: fd.Format,
})
}
type RowDescription struct {
Fields []FieldDescription
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*RowDescription) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *RowDescription) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fieldCount := int(binary.BigEndian.Uint16(src))
rp := 2
dst.Fields = dst.Fields[0:0]
for i := 0; i < fieldCount; i++ {
var fd FieldDescription
idx := bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fd.Name = src[rp : rp+idx]
rp += idx + 1
// Since buf.Next() doesn't return an error if we hit the end of the buffer
// check Len ahead of time
if len(src[rp:]) < 18 {
return &invalidMessageFormatErr{messageType: "RowDescription"}
}
fd.TableOID = binary.BigEndian.Uint32(src[rp:])
rp += 4
fd.TableAttributeNumber = binary.BigEndian.Uint16(src[rp:])
rp += 2
fd.DataTypeOID = binary.BigEndian.Uint32(src[rp:])
rp += 4
fd.DataTypeSize = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
fd.TypeModifier = int32(binary.BigEndian.Uint32(src[rp:]))
rp += 4
fd.Format = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
dst.Fields = append(dst.Fields, fd)
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *RowDescription) Encode(dst []byte) []byte {
dst = append(dst, 'T')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint16(dst, uint16(len(src.Fields)))
for _, fd := range src.Fields {
dst = append(dst, fd.Name...)
dst = append(dst, 0)
dst = pgio.AppendUint32(dst, fd.TableOID)
dst = pgio.AppendUint16(dst, fd.TableAttributeNumber)
dst = pgio.AppendUint32(dst, fd.DataTypeOID)
dst = pgio.AppendInt16(dst, fd.DataTypeSize)
dst = pgio.AppendInt32(dst, fd.TypeModifier)
dst = pgio.AppendInt16(dst, fd.Format)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src RowDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Fields []FieldDescription
}{
Type: "RowDescription",
Fields: src.Fields,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *RowDescription) UnmarshalJSON(data []byte) error {
var msg struct {
Fields []struct {
Name string
TableOID uint32
TableAttributeNumber uint16
DataTypeOID uint32
DataTypeSize int16
TypeModifier int32
Format int16
}
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Fields = make([]FieldDescription, len(msg.Fields))
for n, field := range msg.Fields {
dst.Fields[n] = FieldDescription{
Name: []byte(field.Name),
TableOID: field.TableOID,
TableAttributeNumber: field.TableAttributeNumber,
DataTypeOID: field.DataTypeOID,
DataTypeSize: field.DataTypeSize,
TypeModifier: field.TypeModifier,
Format: field.Format,
}
}
return nil
}

@ -0,0 +1,94 @@
package pgproto3
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type SASLInitialResponse struct {
AuthMechanism string
Data []byte
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*SASLInitialResponse) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *SASLInitialResponse) Decode(src []byte) error {
*dst = SASLInitialResponse{}
rp := 0
idx := bytes.IndexByte(src, 0)
if idx < 0 {
return errors.New("invalid SASLInitialResponse")
}
dst.AuthMechanism = string(src[rp:idx])
rp = idx + 1
rp += 4 // The rest of the message is data so we can just skip the size
dst.Data = src[rp:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *SASLInitialResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, []byte(src.AuthMechanism)...)
dst = append(dst, 0)
dst = pgio.AppendInt32(dst, int32(len(src.Data)))
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src SASLInitialResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
AuthMechanism string
Data string
}{
Type: "SASLInitialResponse",
AuthMechanism: src.AuthMechanism,
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *SASLInitialResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
AuthMechanism string
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.AuthMechanism = msg.AuthMechanism
if msg.Data != "" {
decoded, err := hex.DecodeString(msg.Data)
if err != nil {
return err
}
dst.Data = decoded
}
return nil
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save