master
hzsunshx 9 years ago
parent 4e45962240
commit b9f6943615

57
Godeps/Godeps.json generated

@ -0,0 +1,57 @@
{
"ImportPath": "github.com/codeskyblue/gosuv",
"GoVersion": "go1.4.2",
"Deps": [
{
"ImportPath": "github.com/bradfitz/http2",
"Rev": "f8202bc903bda493ebba4aa54922d78430c2c42f"
},
{
"ImportPath": "github.com/codegangsta/cli",
"Comment": "1.2.0-139-g142e6cd",
"Rev": "142e6cd241a4dfbf7f07a018f1f8225180018da4"
},
{
"ImportPath": "github.com/codeskyblue/kproc",
"Rev": "fcb55eb35ab6b7290f395ad596e80e8355d52d69"
},
{
"ImportPath": "github.com/franela/goreq",
"Rev": "72c51a544272e007ab3da4f7d9ac959b7af7af03"
},
{
"ImportPath": "github.com/golang/protobuf/proto",
"Rev": "1dceb1a2654bdc74ca97ad91f71f500eecc96269"
},
{
"ImportPath": "github.com/lunny/log",
"Rev": "9ae5e9effeefca29cf00d932a7b54f1b2d70fd58"
},
{
"ImportPath": "github.com/lunny/tango",
"Comment": "v0.4.6-25-g652e7bd",
"Rev": "652e7bd578c564f7ae9307ed102d188e564baf54"
},
{
"ImportPath": "github.com/qiniu/log",
"Comment": "v1.0.00-2-ge002bc2",
"Rev": "e002bc2020b19bfa61ed378cc5407383dbd2f346"
},
{
"ImportPath": "golang.org/x/net/context",
"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"
},
{
"ImportPath": "golang.org/x/net/internal/timeseries",
"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"
},
{
"ImportPath": "golang.org/x/net/trace",
"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"
},
{
"ImportPath": "google.golang.org/grpc",
"Rev": "6b6127d8c67b67f9f8509b84279eac053964e868"
}
]
}

5
Godeps/Readme generated

@ -0,0 +1,5 @@
This directory tree is generated automatically by godep.
Please do not edit.
See https://github.com/tools/godep for more information.

2
Godeps/_workspace/.gitignore generated vendored

@ -0,0 +1,2 @@
/pkg
/bin

@ -0,0 +1,20 @@
# This file is like Go's AUTHORS file: it lists Copyright holders.
# The list of humans who have contributd is in the CONTRIBUTORS file.
#
# To contribute to this project, because it will eventually be folded
# back in to Go itself, you need to submit a CLA:
#
# http://golang.org/doc/contribute.html#copyright
#
# Then you get added to CONTRIBUTORS and you or your company get added
# to the AUTHORS file.
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
Google, Inc.
Keith Rarick <kr@xph.us> github=kr
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
Matt Layher <mdlayher@gmail.com> github=mdlayher
Perry Abbott <perry.j.abbott@gmail.com> github=pabbott0
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t

@ -0,0 +1,20 @@
# This file is like Go's CONTRIBUTORS file: it lists humans.
# The list of copyright holders (which may be companies) are in the AUTHORS file.
#
# To contribute to this project, because it will eventually be folded
# back in to Go itself, you need to submit a CLA:
#
# http://golang.org/doc/contribute.html#copyright
#
# Then you get added to CONTRIBUTORS and you or your company get added
# to the AUTHORS file.
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
Brad Fitzpatrick <bradfitz@golang.org> github=bradfitz
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
Keith Rarick <kr@xph.us> github=kr
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
Matt Layher <mdlayher@gmail.com> github=mdlayher
Perry Abbott <perry.j.abbott@gmail.com> github=pabbott0
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t

@ -0,0 +1,44 @@
#
# This Dockerfile builds a recent curl with HTTP/2 client support, using
# a recent nghttp2 build.
#
# See the Makefile for how to tag it. If Docker and that image is found, the
# Go tests use this curl binary for integration tests.
#
FROM ubuntu:trusty
RUN apt-get update && \
apt-get upgrade -y && \
apt-get install -y git-core build-essential wget
RUN apt-get install -y --no-install-recommends \
autotools-dev libtool pkg-config zlib1g-dev \
libcunit1-dev libssl-dev libxml2-dev libevent-dev \
automake autoconf
# Note: setting NGHTTP2_VER before the git clone, so an old git clone isn't cached:
ENV NGHTTP2_VER af24f8394e43f4
RUN cd /root && git clone https://github.com/tatsuhiro-t/nghttp2.git
WORKDIR /root/nghttp2
RUN git reset --hard $NGHTTP2_VER
RUN autoreconf -i
RUN automake
RUN autoconf
RUN ./configure
RUN make
RUN make install
WORKDIR /root
RUN wget http://curl.haxx.se/download/curl-7.40.0.tar.gz
RUN tar -zxvf curl-7.40.0.tar.gz
WORKDIR /root/curl-7.40.0
RUN ./configure --with-ssl --with-nghttp2=/usr/local
RUN make
RUN make install
RUN ldconfig
CMD ["-h"]
ENTRYPOINT ["/usr/local/bin/curl"]

@ -0,0 +1,5 @@
We only accept contributions from users who have gone through Go's
contribution process (signed a CLA).
Please acknowledge whether you have (and use the same email) if
sending a pull request.

@ -0,0 +1,7 @@
Copyright 2014 Google & the Go AUTHORS
Go AUTHORS are:
See https://code.google.com/p/go/source/browse/AUTHORS
Licensed under the terms of Go itself:
https://code.google.com/p/go/source/browse/LICENSE

@ -0,0 +1,3 @@
curlimage:
docker build -t gohttp2/curl .

@ -0,0 +1,17 @@
This is a work-in-progress HTTP/2 implementation for Go.
It will eventually live in the Go standard library and won't require
any changes to your code to use. It will just be automatic.
Status:
* The server support is pretty good. A few things are missing
but are being worked on.
* The client work has just started but shares a lot of code
is coming along much quicker.
Docs are at https://godoc.org/github.com/bradfitz/http2
Demo test server at https://http2.golang.org/
Help & bug reports welcome.

@ -0,0 +1,76 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"errors"
)
// buffer is an io.ReadWriteCloser backed by a fixed size buffer.
// It never allocates, but moves old data as new data is written.
type buffer struct {
buf []byte
r, w int
closed bool
err error // err to return to reader
}
var (
errReadEmpty = errors.New("read from empty buffer")
errWriteClosed = errors.New("write on closed buffer")
errWriteFull = errors.New("write on full buffer")
)
// Read copies bytes from the buffer into p.
// It is an error to read when no data is available.
func (b *buffer) Read(p []byte) (n int, err error) {
n = copy(p, b.buf[b.r:b.w])
b.r += n
if b.closed && b.r == b.w {
err = b.err
} else if b.r == b.w && n == 0 {
err = errReadEmpty
}
return n, err
}
// Len returns the number of bytes of the unread portion of the buffer.
func (b *buffer) Len() int {
return b.w - b.r
}
// Write copies bytes from p into the buffer.
// It is an error to write more data than the buffer can hold.
func (b *buffer) Write(p []byte) (n int, err error) {
if b.closed {
return 0, errWriteClosed
}
// Slide existing data to beginning.
if b.r > 0 && len(p) > len(b.buf)-b.w {
copy(b.buf, b.buf[b.r:b.w])
b.w -= b.r
b.r = 0
}
// Write new data.
n = copy(b.buf[b.w:], p)
b.w += n
if n < len(p) {
err = errWriteFull
}
return n, err
}
// Close marks the buffer as closed. Future calls to Write will
// return an error. Future calls to Read, once the buffer is
// empty, will return err.
func (b *buffer) Close(err error) {
if !b.closed {
b.closed = true
b.err = err
}
}

@ -0,0 +1,154 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"io"
"reflect"
"testing"
)
var bufferReadTests = []struct {
buf buffer
read, wn int
werr error
wp []byte
wbuf buffer
}{
{
buffer{[]byte{'a', 0}, 0, 1, false, nil},
5, 1, nil, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, false, nil},
},
{
buffer{[]byte{'a', 0}, 0, 1, true, io.EOF},
5, 1, io.EOF, []byte{'a'},
buffer{[]byte{'a', 0}, 1, 1, true, io.EOF},
},
{
buffer{[]byte{0, 'a'}, 1, 2, false, nil},
5, 1, nil, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, false, nil},
},
{
buffer{[]byte{0, 'a'}, 1, 2, true, io.EOF},
5, 1, io.EOF, []byte{'a'},
buffer{[]byte{0, 'a'}, 2, 2, true, io.EOF},
},
{
buffer{[]byte{}, 0, 0, false, nil},
5, 0, errReadEmpty, []byte{},
buffer{[]byte{}, 0, 0, false, nil},
},
{
buffer{[]byte{}, 0, 0, true, io.EOF},
5, 0, io.EOF, []byte{},
buffer{[]byte{}, 0, 0, true, io.EOF},
},
}
func TestBufferRead(t *testing.T) {
for i, tt := range bufferReadTests {
read := make([]byte, tt.read)
n, err := tt.buf.Read(read)
if n != tt.wn {
t.Errorf("#%d: wn = %d want %d", i, n, tt.wn)
continue
}
if err != tt.werr {
t.Errorf("#%d: werr = %v want %v", i, err, tt.werr)
continue
}
read = read[:n]
if !reflect.DeepEqual(read, tt.wp) {
t.Errorf("#%d: read = %+v want %+v", i, read, tt.wp)
}
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
t.Errorf("#%d: buf = %+v want %+v", i, tt.buf, tt.wbuf)
}
}
}
var bufferWriteTests = []struct {
buf buffer
write, wn int
werr error
wbuf buffer
}{
{
buf: buffer{
buf: []byte{},
},
wbuf: buffer{
buf: []byte{},
},
},
{
buf: buffer{
buf: []byte{1, 'a'},
},
write: 1,
wn: 1,
wbuf: buffer{
buf: []byte{0, 'a'},
w: 1,
},
},
{
buf: buffer{
buf: []byte{'a', 1},
r: 1,
w: 1,
},
write: 2,
wn: 2,
wbuf: buffer{
buf: []byte{0, 0},
w: 2,
},
},
{
buf: buffer{
buf: []byte{},
r: 1,
closed: true,
},
write: 5,
werr: errWriteClosed,
wbuf: buffer{
buf: []byte{},
r: 1,
closed: true,
},
},
{
buf: buffer{
buf: []byte{},
},
write: 5,
werr: errWriteFull,
wbuf: buffer{
buf: []byte{},
},
},
}
func TestBufferWrite(t *testing.T) {
for i, tt := range bufferWriteTests {
n, err := tt.buf.Write(make([]byte, tt.write))
if n != tt.wn {
t.Errorf("#%d: wrote %d bytes; want %d", i, n, tt.wn)
continue
}
if err != tt.werr {
t.Errorf("#%d: error = %v; want %v", i, err, tt.werr)
continue
}
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
t.Errorf("#%d: buf = %+v; want %+v", i, tt.buf, tt.wbuf)
}
}
}

@ -0,0 +1,78 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "fmt"
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
type ErrCode uint32
const (
ErrCodeNo ErrCode = 0x0
ErrCodeProtocol ErrCode = 0x1
ErrCodeInternal ErrCode = 0x2
ErrCodeFlowControl ErrCode = 0x3
ErrCodeSettingsTimeout ErrCode = 0x4
ErrCodeStreamClosed ErrCode = 0x5
ErrCodeFrameSize ErrCode = 0x6
ErrCodeRefusedStream ErrCode = 0x7
ErrCodeCancel ErrCode = 0x8
ErrCodeCompression ErrCode = 0x9
ErrCodeConnect ErrCode = 0xa
ErrCodeEnhanceYourCalm ErrCode = 0xb
ErrCodeInadequateSecurity ErrCode = 0xc
ErrCodeHTTP11Required ErrCode = 0xd
)
var errCodeName = map[ErrCode]string{
ErrCodeNo: "NO_ERROR",
ErrCodeProtocol: "PROTOCOL_ERROR",
ErrCodeInternal: "INTERNAL_ERROR",
ErrCodeFlowControl: "FLOW_CONTROL_ERROR",
ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT",
ErrCodeStreamClosed: "STREAM_CLOSED",
ErrCodeFrameSize: "FRAME_SIZE_ERROR",
ErrCodeRefusedStream: "REFUSED_STREAM",
ErrCodeCancel: "CANCEL",
ErrCodeCompression: "COMPRESSION_ERROR",
ErrCodeConnect: "CONNECT_ERROR",
ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM",
ErrCodeInadequateSecurity: "INADEQUATE_SECURITY",
ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED",
}
func (e ErrCode) String() string {
if s, ok := errCodeName[e]; ok {
return s
}
return fmt.Sprintf("unknown error code 0x%x", uint32(e))
}
// ConnectionError is an error that results in the termination of the
// entire connection.
type ConnectionError ErrCode
func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: %s", ErrCode(e)) }
// StreamError is an error that only affects one stream within an
// HTTP/2 connection.
type StreamError struct {
StreamID uint32
Code ErrCode
}
func (e StreamError) Error() string {
return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
}
// 6.9.1 The Flow Control Window
// "If a sender receives a WINDOW_UPDATE that causes a flow control
// window to exceed this maximum it MUST terminate either the stream
// or the connection, as appropriate. For streams, [...]; for the
// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
type goAwayFlowError struct{}
func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }

@ -0,0 +1,27 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "testing"
func TestErrCodeString(t *testing.T) {
tests := []struct {
err ErrCode
want string
}{
{ErrCodeProtocol, "PROTOCOL_ERROR"},
{0xd, "HTTP_1_1_REQUIRED"},
{0xf, "unknown error code 0xf"},
}
for i, tt := range tests {
got := tt.err.String()
if got != tt.want {
t.Errorf("%d. Error = %q; want %q", i, got, tt.want)
}
}
}

@ -0,0 +1,51 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Flow control
package http2
// flow is the flow control window's size.
type flow struct {
// n is the number of DATA bytes we're allowed to send.
// A flow is kept both on a conn and a per-stream.
n int32
// conn points to the shared connection-level flow that is
// shared by all streams on that conn. It is nil for the flow
// that's on the conn directly.
conn *flow
}
func (f *flow) setConnFlow(cf *flow) { f.conn = cf }
func (f *flow) available() int32 {
n := f.n
if f.conn != nil && f.conn.n < n {
n = f.conn.n
}
return n
}
func (f *flow) take(n int32) {
if n > f.available() {
panic("internal error: took too much")
}
f.n -= n
if f.conn != nil {
f.conn.n -= n
}
}
// add adds n bytes (positive or negative) to the flow control window.
// It returns false if the sum would exceed 2^31-1.
func (f *flow) add(n int32) bool {
remain := (1<<31 - 1) - f.n
if n > remain {
return false
}
f.n += n
return true
}

@ -0,0 +1,54 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "testing"
func TestFlow(t *testing.T) {
var st flow
var conn flow
st.add(3)
conn.add(2)
if got, want := st.available(), int32(3); got != want {
t.Errorf("available = %d; want %d", got, want)
}
st.setConnFlow(&conn)
if got, want := st.available(), int32(2); got != want {
t.Errorf("after parent setup, available = %d; want %d", got, want)
}
st.take(2)
if got, want := conn.available(), int32(0); got != want {
t.Errorf("after taking 2, conn = %d; want %d", got, want)
}
if got, want := st.available(), int32(0); got != want {
t.Errorf("after taking 2, stream = %d; want %d", got, want)
}
}
func TestFlowAdd(t *testing.T) {
var f flow
if !f.add(1) {
t.Fatal("failed to add 1")
}
if !f.add(-1) {
t.Fatal("failed to add -1")
}
if got, want := f.available(), int32(0); got != want {
t.Fatalf("size = %d; want %d", got, want)
}
if !f.add(1<<31 - 1) {
t.Fatal("failed to add 2^31-1")
}
if got, want := f.available(), int32(1<<31-1); got != want {
t.Fatalf("size = %d; want %d", got, want)
}
if f.add(1) {
t.Fatal("adding 1 to max shouldn't be allowed")
}
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,597 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"bytes"
"reflect"
"strings"
"testing"
"unsafe"
)
func testFramer() (*Framer, *bytes.Buffer) {
buf := new(bytes.Buffer)
return NewFramer(buf, buf), buf
}
func TestFrameSizes(t *testing.T) {
// Catch people rearranging the FrameHeader fields.
if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want {
t.Errorf("FrameHeader size = %d; want %d", got, want)
}
}
func TestFrameTypeString(t *testing.T) {
tests := []struct {
ft FrameType
want string
}{
{FrameData, "DATA"},
{FramePing, "PING"},
{FrameGoAway, "GOAWAY"},
{0xf, "UNKNOWN_FRAME_TYPE_15"},
}
for i, tt := range tests {
got := tt.ft.String()
if got != tt.want {
t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want)
}
}
}
func TestWriteRST(t *testing.T) {
fr, buf := testFramer()
var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4
fr.WriteRSTStream(streamID, ErrCode(errCode))
const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &RSTStreamFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x3,
Flags: 0x0,
Length: 0x4,
StreamID: 0x1020304,
},
ErrCode: 0x7060504,
}
if !reflect.DeepEqual(f, want) {
t.Errorf("parsed back %#v; want %#v", f, want)
}
}
func TestWriteData(t *testing.T) {
fr, buf := testFramer()
var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
data := []byte("ABC")
fr.WriteData(streamID, true, data)
const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
df, ok := f.(*DataFrame)
if !ok {
t.Fatalf("got %T; want *DataFrame", f)
}
if !bytes.Equal(df.Data(), data) {
t.Errorf("got %q; want %q", df.Data(), data)
}
if f.Header().Flags&1 == 0 {
t.Errorf("didn't see END_STREAM flag")
}
}
func TestWriteHeaders(t *testing.T) {
tests := []struct {
name string
p HeadersFrameParam
wantEnc string
wantFrame *HeadersFrame
}{
{
"basic",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
Priority: PriorityParam{},
},
"\x00\x00\x03\x01\x00\x00\x00\x00*abc",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Length: uint32(len("abc")),
},
Priority: PriorityParam{},
headerFragBuf: []byte("abc"),
},
},
{
"basic + end flags",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
EndStream: true,
EndHeaders: true,
Priority: PriorityParam{},
},
"\x00\x00\x03\x01\x05\x00\x00\x00*abc",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders,
Length: uint32(len("abc")),
},
Priority: PriorityParam{},
headerFragBuf: []byte("abc"),
},
},
{
"with padding",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
EndStream: true,
EndHeaders: true,
PadLength: 5,
Priority: PriorityParam{},
},
"\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded,
Length: uint32(1 + len("abc") + 5), // pad length + contents + padding
},
Priority: PriorityParam{},
headerFragBuf: []byte("abc"),
},
},
{
"with priority",
HeadersFrameParam{
StreamID: 42,
BlockFragment: []byte("abc"),
EndStream: true,
EndHeaders: true,
PadLength: 2,
Priority: PriorityParam{
StreamDep: 15,
Exclusive: true,
Weight: 127,
},
},
"\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00",
&HeadersFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: 42,
Type: FrameHeaders,
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority,
Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding
},
Priority: PriorityParam{
StreamDep: 15,
Exclusive: true,
Weight: 127,
},
headerFragBuf: []byte("abc"),
},
},
}
for _, tt := range tests {
fr, buf := testFramer()
if err := fr.WriteHeaders(tt.p); err != nil {
t.Errorf("test %q: %v", tt.name, err)
continue
}
if buf.String() != tt.wantEnc {
t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(f, tt.wantFrame) {
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
}
}
}
func TestWriteContinuation(t *testing.T) {
const streamID = 42
tests := []struct {
name string
end bool
frag []byte
wantFrame *ContinuationFrame
}{
{
"not end",
false,
[]byte("abc"),
&ContinuationFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: streamID,
Type: FrameContinuation,
Length: uint32(len("abc")),
},
headerFragBuf: []byte("abc"),
},
},
{
"end",
true,
[]byte("def"),
&ContinuationFrame{
FrameHeader: FrameHeader{
valid: true,
StreamID: streamID,
Type: FrameContinuation,
Flags: FlagContinuationEndHeaders,
Length: uint32(len("def")),
},
headerFragBuf: []byte("def"),
},
},
}
for _, tt := range tests {
fr, _ := testFramer()
if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil {
t.Errorf("test %q: %v", tt.name, err)
continue
}
f, err := fr.ReadFrame()
if err != nil {
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(f, tt.wantFrame) {
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
}
}
}
func TestWritePriority(t *testing.T) {
const streamID = 42
tests := []struct {
name string
priority PriorityParam
wantFrame *PriorityFrame
}{
{
"not exclusive",
PriorityParam{
StreamDep: 2,
Exclusive: false,
Weight: 127,
},
&PriorityFrame{
FrameHeader{
valid: true,
StreamID: streamID,
Type: FramePriority,
Length: 5,
},
PriorityParam{
StreamDep: 2,
Exclusive: false,
Weight: 127,
},
},
},
{
"exclusive",
PriorityParam{
StreamDep: 3,
Exclusive: true,
Weight: 77,
},
&PriorityFrame{
FrameHeader{
valid: true,
StreamID: streamID,
Type: FramePriority,
Length: 5,
},
PriorityParam{
StreamDep: 3,
Exclusive: true,
Weight: 77,
},
},
},
}
for _, tt := range tests {
fr, _ := testFramer()
if err := fr.WritePriority(streamID, tt.priority); err != nil {
t.Errorf("test %q: %v", tt.name, err)
continue
}
f, err := fr.ReadFrame()
if err != nil {
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(f, tt.wantFrame) {
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
}
}
}
func TestWriteSettings(t *testing.T) {
fr, buf := testFramer()
settings := []Setting{{1, 2}, {3, 4}}
fr.WriteSettings(settings...)
const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
sf, ok := f.(*SettingsFrame)
if !ok {
t.Fatalf("Got a %T; want a SettingsFrame", f)
}
var got []Setting
sf.ForeachSetting(func(s Setting) error {
got = append(got, s)
valBack, ok := sf.Value(s.ID)
if !ok || valBack != s.Val {
t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val)
}
return nil
})
if !reflect.DeepEqual(settings, got) {
t.Errorf("Read settings %+v != written settings %+v", got, settings)
}
}
func TestWriteSettingsAck(t *testing.T) {
fr, buf := testFramer()
fr.WriteSettingsAck()
const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
}
func TestWriteWindowUpdate(t *testing.T) {
fr, buf := testFramer()
const streamID = 1<<24 + 2<<16 + 3<<8 + 4
const incr = 7<<24 + 6<<16 + 5<<8 + 4
if err := fr.WriteWindowUpdate(streamID, incr); err != nil {
t.Fatal(err)
}
const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &WindowUpdateFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x8,
Flags: 0x0,
Length: 0x4,
StreamID: 0x1020304,
},
Increment: 0x7060504,
}
if !reflect.DeepEqual(f, want) {
t.Errorf("parsed back %#v; want %#v", f, want)
}
}
func TestWritePing(t *testing.T) { testWritePing(t, false) }
func TestWritePingAck(t *testing.T) { testWritePing(t, true) }
func testWritePing(t *testing.T, ack bool) {
fr, buf := testFramer()
if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
t.Fatal(err)
}
var wantFlags Flags
if ack {
wantFlags = FlagPingAck
}
var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &PingFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x6,
Flags: wantFlags,
Length: 0x8,
StreamID: 0,
},
Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8},
}
if !reflect.DeepEqual(f, want) {
t.Errorf("parsed back %#v; want %#v", f, want)
}
}
func TestReadFrameHeader(t *testing.T) {
tests := []struct {
in string
want FrameHeader
}{
{in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}},
{in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{
Length: 66051, Type: 4, Flags: 5, StreamID: 101124105,
}},
// Ignore high bit:
{in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{
Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
{in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{
Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
}
for i, tt := range tests {
got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in))
if err != nil {
t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err)
continue
}
tt.want.valid = true
if got != tt.want {
t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want)
}
}
}
func TestReadWriteFrameHeader(t *testing.T) {
tests := []struct {
len uint32
typ FrameType
flags Flags
streamID uint32
}{
{len: 0, typ: 255, flags: 1, streamID: 0},
{len: 0, typ: 255, flags: 1, streamID: 1},
{len: 0, typ: 255, flags: 1, streamID: 255},
{len: 0, typ: 255, flags: 1, streamID: 256},
{len: 0, typ: 255, flags: 1, streamID: 65535},
{len: 0, typ: 255, flags: 1, streamID: 65536},
{len: 0, typ: 1, flags: 255, streamID: 1},
{len: 255, typ: 1, flags: 255, streamID: 1},
{len: 256, typ: 1, flags: 255, streamID: 1},
{len: 65535, typ: 1, flags: 255, streamID: 1},
{len: 65536, typ: 1, flags: 255, streamID: 1},
{len: 16777215, typ: 1, flags: 255, streamID: 1},
}
for _, tt := range tests {
fr, buf := testFramer()
fr.startWrite(tt.typ, tt.flags, tt.streamID)
fr.writeBytes(make([]byte, tt.len))
fr.endWrite()
fh, err := ReadFrameHeader(buf)
if err != nil {
t.Errorf("ReadFrameHeader(%+v) = %v", tt, err)
continue
}
if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID {
t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh)
}
}
}
func TestWriteTooLargeFrame(t *testing.T) {
fr, _ := testFramer()
fr.startWrite(0, 1, 1)
fr.writeBytes(make([]byte, 1<<24))
err := fr.endWrite()
if err != ErrFrameTooLarge {
t.Errorf("endWrite = %v; want errFrameTooLarge", err)
}
}
func TestWriteGoAway(t *testing.T) {
const debug = "foo"
fr, buf := testFramer()
if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil {
t.Fatal(err)
}
const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
want := &GoAwayFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x7,
Flags: 0,
Length: uint32(4 + 4 + len(debug)),
StreamID: 0,
},
LastStreamID: 0x01020304,
ErrCode: 0x05060708,
debugData: []byte(debug),
}
if !reflect.DeepEqual(f, want) {
t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
}
if got := string(f.(*GoAwayFrame).DebugData()); got != debug {
t.Errorf("debug data = %q; want %q", got, debug)
}
}
func TestWritePushPromise(t *testing.T) {
pp := PushPromiseParam{
StreamID: 42,
PromiseID: 42,
BlockFragment: []byte("abc"),
}
fr, buf := testFramer()
if err := fr.WritePushPromise(pp); err != nil {
t.Fatal(err)
}
const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc"
if buf.String() != wantEnc {
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
}
f, err := fr.ReadFrame()
if err != nil {
t.Fatal(err)
}
_, ok := f.(*PushPromiseFrame)
if !ok {
t.Fatalf("got %T; want *PushPromiseFrame", f)
}
want := &PushPromiseFrame{
FrameHeader: FrameHeader{
valid: true,
Type: 0x5,
Flags: 0x0,
Length: 0x7,
StreamID: 42,
},
PromiseID: 42,
headerFragBuf: []byte("abc"),
}
if !reflect.DeepEqual(f, want) {
t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
}
}

@ -0,0 +1,173 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Defensive debug-only utility to track that functions run on the
// goroutine that they're supposed to.
package http2
import (
"bytes"
"errors"
"fmt"
"os"
"runtime"
"strconv"
"sync"
)
var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
type goroutineLock uint64
func newGoroutineLock() goroutineLock {
if !DebugGoroutines {
return 0
}
return goroutineLock(curGoroutineID())
}
func (g goroutineLock) check() {
if !DebugGoroutines {
return
}
if curGoroutineID() != uint64(g) {
panic("running on the wrong goroutine")
}
}
func (g goroutineLock) checkNotOn() {
if !DebugGoroutines {
return
}
if curGoroutineID() == uint64(g) {
panic("running on the wrong goroutine")
}
}
var goroutineSpace = []byte("goroutine ")
func curGoroutineID() uint64 {
bp := littleBuf.Get().(*[]byte)
defer littleBuf.Put(bp)
b := *bp
b = b[:runtime.Stack(b, false)]
// Parse the 4707 out of "goroutine 4707 ["
b = bytes.TrimPrefix(b, goroutineSpace)
i := bytes.IndexByte(b, ' ')
if i < 0 {
panic(fmt.Sprintf("No space found in %q", b))
}
b = b[:i]
n, err := parseUintBytes(b, 10, 64)
if err != nil {
panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
}
return n
}
var littleBuf = sync.Pool{
New: func() interface{} {
buf := make([]byte, 64)
return &buf
},
}
// parseUintBytes is like strconv.ParseUint, but using a []byte.
func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
var cutoff, maxVal uint64
if bitSize == 0 {
bitSize = int(strconv.IntSize)
}
s0 := s
switch {
case len(s) < 1:
err = strconv.ErrSyntax
goto Error
case 2 <= base && base <= 36:
// valid base; nothing to do
case base == 0:
// Look for octal, hex prefix.
switch {
case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
base = 16
s = s[2:]
if len(s) < 1 {
err = strconv.ErrSyntax
goto Error
}
case s[0] == '0':
base = 8
default:
base = 10
}
default:
err = errors.New("invalid base " + strconv.Itoa(base))
goto Error
}
n = 0
cutoff = cutoff64(base)
maxVal = 1<<uint(bitSize) - 1
for i := 0; i < len(s); i++ {
var v byte
d := s[i]
switch {
case '0' <= d && d <= '9':
v = d - '0'
case 'a' <= d && d <= 'z':
v = d - 'a' + 10
case 'A' <= d && d <= 'Z':
v = d - 'A' + 10
default:
n = 0
err = strconv.ErrSyntax
goto Error
}
if int(v) >= base {
n = 0
err = strconv.ErrSyntax
goto Error
}
if n >= cutoff {
// n*base overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n *= uint64(base)
n1 := n + uint64(v)
if n1 < n || n1 > maxVal {
// n+v overflows
n = 1<<64 - 1
err = strconv.ErrRange
goto Error
}
n = n1
}
return n, nil
Error:
return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
}
// Return the first number n such that n*base >= 1<<64.
func cutoff64(base int) uint64 {
if base < 2 {
return 0
}
return (1<<64-1)/uint64(base) + 1
}

@ -0,0 +1,33 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"fmt"
"strings"
"testing"
)
func TestGoroutineLock(t *testing.T) {
DebugGoroutines = true
g := newGoroutineLock()
g.check()
sawPanic := make(chan interface{})
go func() {
defer func() { sawPanic <- recover() }()
g.check() // should panic
}()
e := <-sawPanic
if e == nil {
t.Fatal("did not see panic from check in other goroutine")
}
if !strings.Contains(fmt.Sprint(e), "wrong goroutine") {
t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e)
}
}

@ -0,0 +1,5 @@
h2demo
h2demo.linux
client-id.dat
client-secret.dat
token.dat

@ -0,0 +1,5 @@
h2demo.linux: h2demo.go
GOOS=linux go build --tags=h2demo -o h2demo.linux .
upload: h2demo.linux
cat h2demo.linux | go run launch.go --write_object=http2-demo-server-tls/h2demo --write_object_is_public

@ -0,0 +1,16 @@
Client:
-- Firefox nightly with about:config network.http.spdy.enabled.http2draft set true
-- Chrome: go to chrome://flags/#enable-spdy4, save and restart (button at bottom)
Make CA:
$ openssl genrsa -out rootCA.key 2048
$ openssl req -x509 -new -nodes -key rootCA.key -days 1024 -out rootCA.pem
... install that to Firefox
Make cert:
$ openssl genrsa -out server.key 2048
$ openssl req -new -key server.key -out server.csr
$ openssl x509 -req -in server.csr -CA rootCA.pem -CAkey rootCA.key -CAcreateserial -out server.crt -days 500

@ -0,0 +1,426 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// +build h2demo
package main
import (
"bytes"
"crypto/tls"
"flag"
"fmt"
"hash/crc32"
"image"
"image/jpeg"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"os/exec"
"path"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
"time"
"camlistore.org/pkg/googlestorage"
"camlistore.org/pkg/singleflight"
"github.com/bradfitz/http2"
)
var (
openFirefox = flag.Bool("openff", false, "Open Firefox")
addr = flag.String("addr", "localhost:4430", "TLS address to listen on")
httpAddr = flag.String("httpaddr", "", "If non-empty, address to listen for regular HTTP on")
prod = flag.Bool("prod", false, "Whether to configure itself to be the production http2.golang.org server.")
)
func homeOldHTTP(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, `<html>
<body>
<h1>Go + HTTP/2</h1>
<p>Welcome to <a href="https://golang.org/">the Go language</a>'s <a href="https://http2.github.io/">HTTP/2</a> demo & interop server.</p>
<p>Unfortunately, you're <b>not</b> using HTTP/2 right now. To do so:</p>
<ul>
<li>Use Firefox Nightly or go to <b>about:config</b> and enable "network.http.spdy.enabled.http2draft"</li>
<li>Use Google Chrome Canary and/or go to <b>chrome://flags/#enable-spdy4</b> to <i>Enable SPDY/4</i> (Chrome's name for HTTP/2)</li>
</ul>
<p>See code & instructions for connecting at <a href="https://github.com/bradfitz/http2">https://github.com/bradfitz/http2</a>.</p>
</body></html>`)
}
func home(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
io.WriteString(w, `<html>
<body>
<h1>Go + HTTP/2</h1>
<p>Welcome to <a href="https://golang.org/">the Go language</a>'s <a
href="https://http2.github.io/">HTTP/2</a> demo & interop server.</p>
<p>Congratulations, <b>you're using HTTP/2 right now</b>.</p>
<p>This server exists for others in the HTTP/2 community to test their HTTP/2 client implementations and point out flaws in our server.</p>
<p> The code is currently at <a
href="https://github.com/bradfitz/http2">github.com/bradfitz/http2</a>
but will move to the Go standard library at some point in the future
(enabled by default, without users needing to change their code).</p>
<p>Contact info: <i>bradfitz@golang.org</i>, or <a
href="https://github.com/bradfitz/http2/issues">file a bug</a>.</p>
<h2>Handlers for testing</h2>
<ul>
<li>GET <a href="/reqinfo">/reqinfo</a> to dump the request + headers received</li>
<li>GET <a href="/clockstream">/clockstream</a> streams the current time every second</li>
<li>GET <a href="/gophertiles">/gophertiles</a> to see a page with a bunch of images</li>
<li>GET <a href="/file/gopher.png">/file/gopher.png</a> for a small file (does If-Modified-Since, Content-Range, etc)</li>
<li>GET <a href="/file/go.src.tar.gz">/file/go.src.tar.gz</a> for a larger file (~10 MB)</li>
<li>GET <a href="/redirect">/redirect</a> to redirect back to / (this page)</li>
<li>GET <a href="/goroutines">/goroutines</a> to see all active goroutines in this server</li>
<li>PUT something to <a href="/crc32">/crc32</a> to get a count of number of bytes and its CRC-32</li>
</ul>
</body></html>`)
}
func reqInfoHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "Method: %s\n", r.Method)
fmt.Fprintf(w, "Protocol: %s\n", r.Proto)
fmt.Fprintf(w, "Host: %s\n", r.Host)
fmt.Fprintf(w, "RemoteAddr: %s\n", r.RemoteAddr)
fmt.Fprintf(w, "RequestURI: %q\n", r.RequestURI)
fmt.Fprintf(w, "URL: %#v\n", r.URL)
fmt.Fprintf(w, "Body.ContentLength: %d (-1 means unknown)\n", r.ContentLength)
fmt.Fprintf(w, "Close: %v (relevant for HTTP/1 only)\n", r.Close)
fmt.Fprintf(w, "TLS: %#v\n", r.TLS)
fmt.Fprintf(w, "\nHeaders:\n")
r.Header.Write(w)
}
func crcHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
http.Error(w, "PUT required.", 400)
return
}
crc := crc32.NewIEEE()
n, err := io.Copy(crc, r.Body)
if err == nil {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "bytes=%d, CRC32=%x", n, crc.Sum(nil))
}
}
var (
fsGrp singleflight.Group
fsMu sync.Mutex // guards fsCache
fsCache = map[string]http.Handler{}
)
// fileServer returns a file-serving handler that proxies URL.
// It lazily fetches URL on the first access and caches its contents forever.
func fileServer(url string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
hi, err := fsGrp.Do(url, func() (interface{}, error) {
fsMu.Lock()
if h, ok := fsCache[url]; ok {
fsMu.Unlock()
return h, nil
}
fsMu.Unlock()
res, err := http.Get(url)
if err != nil {
return nil, err
}
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
return nil, err
}
modTime := time.Now()
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.ServeContent(w, r, path.Base(url), modTime, bytes.NewReader(slurp))
})
fsMu.Lock()
fsCache[url] = h
fsMu.Unlock()
return h, nil
})
if err != nil {
http.Error(w, err.Error(), 500)
return
}
hi.(http.Handler).ServeHTTP(w, r)
})
}
func clockStreamHandler(w http.ResponseWriter, r *http.Request) {
clientGone := w.(http.CloseNotifier).CloseNotify()
w.Header().Set("Content-Type", "text/plain")
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
fmt.Fprintf(w, "# ~1KB of junk to force browsers to start rendering immediately: \n")
io.WriteString(w, strings.Repeat("# xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n", 13))
for {
fmt.Fprintf(w, "%v\n", time.Now())
w.(http.Flusher).Flush()
select {
case <-ticker.C:
case <-clientGone:
log.Printf("Client %v disconnected from the clock", r.RemoteAddr)
return
}
}
}
func registerHandlers() {
tiles := newGopherTilesHandler()
mux2 := http.NewServeMux()
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil {
if r.URL.Path == "/gophertiles" {
tiles.ServeHTTP(w, r)
return
}
http.Redirect(w, r, "https://http2.golang.org/", http.StatusFound)
return
}
if r.ProtoMajor == 1 {
if r.URL.Path == "/reqinfo" {
reqInfoHandler(w, r)
return
}
homeOldHTTP(w, r)
return
}
mux2.ServeHTTP(w, r)
})
mux2.HandleFunc("/", home)
mux2.Handle("/file/gopher.png", fileServer("https://golang.org/doc/gopher/frontpage.png"))
mux2.Handle("/file/go.src.tar.gz", fileServer("https://storage.googleapis.com/golang/go1.4.1.src.tar.gz"))
mux2.HandleFunc("/reqinfo", reqInfoHandler)
mux2.HandleFunc("/crc32", crcHandler)
mux2.HandleFunc("/clockstream", clockStreamHandler)
mux2.Handle("/gophertiles", tiles)
mux2.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", http.StatusFound)
})
stripHomedir := regexp.MustCompile(`/(Users|home)/\w+`)
mux2.HandleFunc("/goroutines", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
buf := make([]byte, 2<<20)
w.Write(stripHomedir.ReplaceAll(buf[:runtime.Stack(buf, true)], nil))
})
}
func newGopherTilesHandler() http.Handler {
const gopherURL = "https://blog.golang.org/go-programming-language-turns-two_gophers.jpg"
res, err := http.Get(gopherURL)
if err != nil {
log.Fatal(err)
}
if res.StatusCode != 200 {
log.Fatalf("Error fetching %s: %v", gopherURL, res.Status)
}
slurp, err := ioutil.ReadAll(res.Body)
res.Body.Close()
if err != nil {
log.Fatal(err)
}
im, err := jpeg.Decode(bytes.NewReader(slurp))
if err != nil {
if len(slurp) > 1024 {
slurp = slurp[:1024]
}
log.Fatalf("Failed to decode gopher image: %v (got %q)", err, slurp)
}
type subImager interface {
SubImage(image.Rectangle) image.Image
}
const tileSize = 32
xt := im.Bounds().Max.X / tileSize
yt := im.Bounds().Max.Y / tileSize
var tile [][][]byte // y -> x -> jpeg bytes
for yi := 0; yi < yt; yi++ {
var row [][]byte
for xi := 0; xi < xt; xi++ {
si := im.(subImager).SubImage(image.Rectangle{
Min: image.Point{xi * tileSize, yi * tileSize},
Max: image.Point{(xi + 1) * tileSize, (yi + 1) * tileSize},
})
buf := new(bytes.Buffer)
if err := jpeg.Encode(buf, si, &jpeg.Options{Quality: 90}); err != nil {
log.Fatal(err)
}
row = append(row, buf.Bytes())
}
tile = append(tile, row)
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ms, _ := strconv.Atoi(r.FormValue("latency"))
const nanosPerMilli = 1e6
if r.FormValue("x") != "" {
x, _ := strconv.Atoi(r.FormValue("x"))
y, _ := strconv.Atoi(r.FormValue("y"))
if ms <= 1000 {
time.Sleep(time.Duration(ms) * nanosPerMilli)
}
if x >= 0 && x < xt && y >= 0 && y < yt {
http.ServeContent(w, r, "", time.Time{}, bytes.NewReader(tile[y][x]))
return
}
}
io.WriteString(w, "<html><body>")
fmt.Fprintf(w, "A grid of %d tiled images is below. Compare:<p>", xt*yt)
for _, ms := range []int{0, 30, 200, 1000} {
d := time.Duration(ms) * nanosPerMilli
fmt.Fprintf(w, "[<a href='https://%s/gophertiles?latency=%d'>HTTP/2, %v latency</a>] [<a href='http://%s/gophertiles?latency=%d'>HTTP/1, %v latency</a>]<br>\n",
httpsHost(), ms, d,
httpHost(), ms, d,
)
}
io.WriteString(w, "<p>\n")
cacheBust := time.Now().UnixNano()
for y := 0; y < yt; y++ {
for x := 0; x < xt; x++ {
fmt.Fprintf(w, "<img width=%d height=%d src='/gophertiles?x=%d&y=%d&cachebust=%d&latency=%d'>",
tileSize, tileSize, x, y, cacheBust, ms)
}
io.WriteString(w, "<br/>\n")
}
io.WriteString(w, "<hr><a href='/'>&lt;&lt Back to Go HTTP/2 demo server</a></body></html>")
})
}
func httpsHost() string {
if *prod {
return "http2.golang.org"
}
if v := *addr; strings.HasPrefix(v, ":") {
return "localhost" + v
} else {
return v
}
}
func httpHost() string {
if *prod {
return "http2.golang.org"
}
if v := *httpAddr; strings.HasPrefix(v, ":") {
return "localhost" + v
} else {
return v
}
}
func serveProdTLS() error {
c, err := googlestorage.NewServiceClient()
if err != nil {
return err
}
slurp := func(key string) ([]byte, error) {
const bucket = "http2-demo-server-tls"
rc, _, err := c.GetObject(&googlestorage.Object{
Bucket: bucket,
Key: key,
})
if err != nil {
return nil, fmt.Errorf("Error fetching GCS object %q in bucket %q: %v", key, bucket, err)
}
defer rc.Close()
return ioutil.ReadAll(rc)
}
certPem, err := slurp("http2.golang.org.chained.pem")
if err != nil {
return err
}
keyPem, err := slurp("http2.golang.org.key")
if err != nil {
return err
}
cert, err := tls.X509KeyPair(certPem, keyPem)
if err != nil {
return err
}
srv := &http.Server{
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
}
http2.ConfigureServer(srv, &http2.Server{})
ln, err := net.Listen("tcp", ":443")
if err != nil {
return err
}
return srv.Serve(tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig))
}
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
tc, err := ln.AcceptTCP()
if err != nil {
return
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
return tc, nil
}
func serveProd() error {
errc := make(chan error, 2)
go func() { errc <- http.ListenAndServe(":80", nil) }()
go func() { errc <- serveProdTLS() }()
return <-errc
}
func main() {
var srv http.Server
flag.BoolVar(&http2.VerboseLogs, "verbose", false, "Verbose HTTP/2 debugging.")
flag.Parse()
srv.Addr = *addr
registerHandlers()
if *prod {
*httpAddr = "http2.golang.org"
log.Fatal(serveProd())
}
url := "https://" + *addr + "/"
log.Printf("Listening on " + url)
http2.ConfigureServer(&srv, &http2.Server{})
if *httpAddr != "" {
go func() { log.Fatal(http.ListenAndServe(*httpAddr, nil)) }()
}
go func() {
log.Fatal(srv.ListenAndServeTLS("server.crt", "server.key"))
}()
if *openFirefox && runtime.GOOS == "darwin" {
time.Sleep(250 * time.Millisecond)
exec.Command("open", "-b", "org.mozilla.nightly", "https://localhost:4430/").Run()
}
select {}
}

@ -0,0 +1,302 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
package main
import (
"bufio"
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
compute "google.golang.org/api/compute/v1"
)
var (
proj = flag.String("project", "symbolic-datum-552", "name of Project")
zone = flag.String("zone", "us-central1-a", "GCE zone")
mach = flag.String("machinetype", "n1-standard-1", "Machine type")
instName = flag.String("instance_name", "http2-demo", "Name of VM instance.")
sshPub = flag.String("ssh_public_key", "", "ssh public key file to authorize. Can modify later in Google's web UI anyway.")
staticIP = flag.String("static_ip", "130.211.116.44", "Static IP to use. If empty, automatic.")
writeObject = flag.String("write_object", "", "If non-empty, a VM isn't created and the flag value is Google Cloud Storage bucket/object to write. The contents from stdin.")
publicObject = flag.Bool("write_object_is_public", false, "Whether the object created by --write_object should be public.")
)
func readFile(v string) string {
slurp, err := ioutil.ReadFile(v)
if err != nil {
log.Fatalf("Error reading %s: %v", v, err)
}
return strings.TrimSpace(string(slurp))
}
var config = &oauth2.Config{
// The client-id and secret should be for an "Installed Application" when using
// the CLI. Later we'll use a web application with a callback.
ClientID: readFile("client-id.dat"),
ClientSecret: readFile("client-secret.dat"),
Endpoint: google.Endpoint,
Scopes: []string{
compute.DevstorageFull_controlScope,
compute.ComputeScope,
"https://www.googleapis.com/auth/sqlservice",
"https://www.googleapis.com/auth/sqlservice.admin",
},
RedirectURL: "urn:ietf:wg:oauth:2.0:oob",
}
const baseConfig = `#cloud-config
coreos:
units:
- name: h2demo.service
command: start
content: |
[Unit]
Description=HTTP2 Demo
[Service]
ExecStartPre=/bin/bash -c 'mkdir -p /opt/bin && curl -s -o /opt/bin/h2demo http://storage.googleapis.com/http2-demo-server-tls/h2demo && chmod +x /opt/bin/h2demo'
ExecStart=/opt/bin/h2demo --prod
RestartSec=5s
Restart=always
Type=simple
[Install]
WantedBy=multi-user.target
`
func main() {
flag.Parse()
if *proj == "" {
log.Fatalf("Missing --project flag")
}
prefix := "https://www.googleapis.com/compute/v1/projects/" + *proj
machType := prefix + "/zones/" + *zone + "/machineTypes/" + *mach
const tokenFileName = "token.dat"
tokenFile := tokenCacheFile(tokenFileName)
tokenSource := oauth2.ReuseTokenSource(nil, tokenFile)
token, err := tokenSource.Token()
if err != nil {
if *writeObject != "" {
log.Fatalf("Can't use --write_object without a valid token.dat file already cached.")
}
log.Printf("Error getting token from %s: %v", tokenFileName, err)
log.Printf("Get auth code from %v", config.AuthCodeURL("my-state"))
fmt.Print("\nEnter auth code: ")
sc := bufio.NewScanner(os.Stdin)
sc.Scan()
authCode := strings.TrimSpace(sc.Text())
token, err = config.Exchange(oauth2.NoContext, authCode)
if err != nil {
log.Fatalf("Error exchanging auth code for a token: %v", err)
}
if err := tokenFile.WriteToken(token); err != nil {
log.Fatalf("Error writing to %s: %v", tokenFileName, err)
}
tokenSource = oauth2.ReuseTokenSource(token, nil)
}
oauthClient := oauth2.NewClient(oauth2.NoContext, tokenSource)
if *writeObject != "" {
writeCloudStorageObject(oauthClient)
return
}
computeService, _ := compute.New(oauthClient)
natIP := *staticIP
if natIP == "" {
// Try to find it by name.
aggAddrList, err := computeService.Addresses.AggregatedList(*proj).Do()
if err != nil {
log.Fatal(err)
}
// http://godoc.org/code.google.com/p/google-api-go-client/compute/v1#AddressAggregatedList
IPLoop:
for _, asl := range aggAddrList.Items {
for _, addr := range asl.Addresses {
if addr.Name == *instName+"-ip" && addr.Status == "RESERVED" {
natIP = addr.Address
break IPLoop
}
}
}
}
cloudConfig := baseConfig
if *sshPub != "" {
key := strings.TrimSpace(readFile(*sshPub))
cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", key)
}
if os.Getenv("USER") == "bradfitz" {
cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEAwks9dwWKlRC+73gRbvYtVg0vdCwDSuIlyt4z6xa/YU/jTDynM4R4W10hm2tPjy8iR1k8XhDv4/qdxe6m07NjG/By1tkmGpm1mGwho4Pr5kbAAy/Qg+NLCSdAYnnE00FQEcFOC15GFVMOW2AzDGKisReohwH9eIzHPzdYQNPRWXE= bradfitz@papag.bradfitz.com")
}
const maxCloudConfig = 32 << 10 // per compute API docs
if len(cloudConfig) > maxCloudConfig {
log.Fatalf("cloud config length of %d bytes is over %d byte limit", len(cloudConfig), maxCloudConfig)
}
instance := &compute.Instance{
Name: *instName,
Description: "Go Builder",
MachineType: machType,
Disks: []*compute.AttachedDisk{instanceDisk(computeService)},
Tags: &compute.Tags{
Items: []string{"http-server", "https-server"},
},
Metadata: &compute.Metadata{
Items: []*compute.MetadataItems{
{
Key: "user-data",
Value: cloudConfig,
},
},
},
NetworkInterfaces: []*compute.NetworkInterface{
&compute.NetworkInterface{
AccessConfigs: []*compute.AccessConfig{
&compute.AccessConfig{
Type: "ONE_TO_ONE_NAT",
Name: "External NAT",
NatIP: natIP,
},
},
Network: prefix + "/global/networks/default",
},
},
ServiceAccounts: []*compute.ServiceAccount{
{
Email: "default",
Scopes: []string{
compute.DevstorageFull_controlScope,
compute.ComputeScope,
},
},
},
}
log.Printf("Creating instance...")
op, err := computeService.Instances.Insert(*proj, *zone, instance).Do()
if err != nil {
log.Fatalf("Failed to create instance: %v", err)
}
opName := op.Name
log.Printf("Created. Waiting on operation %v", opName)
OpLoop:
for {
time.Sleep(2 * time.Second)
op, err := computeService.ZoneOperations.Get(*proj, *zone, opName).Do()
if err != nil {
log.Fatalf("Failed to get op %s: %v", opName, err)
}
switch op.Status {
case "PENDING", "RUNNING":
log.Printf("Waiting on operation %v", opName)
continue
case "DONE":
if op.Error != nil {
for _, operr := range op.Error.Errors {
log.Printf("Error: %+v", operr)
}
log.Fatalf("Failed to start.")
}
log.Printf("Success. %+v", op)
break OpLoop
default:
log.Fatalf("Unknown status %q: %+v", op.Status, op)
}
}
inst, err := computeService.Instances.Get(*proj, *zone, *instName).Do()
if err != nil {
log.Fatalf("Error getting instance after creation: %v", err)
}
ij, _ := json.MarshalIndent(inst, "", " ")
log.Printf("Instance: %s", ij)
}
func instanceDisk(svc *compute.Service) *compute.AttachedDisk {
const imageURL = "https://www.googleapis.com/compute/v1/projects/coreos-cloud/global/images/coreos-stable-444-5-0-v20141016"
diskName := *instName + "-disk"
return &compute.AttachedDisk{
AutoDelete: true,
Boot: true,
Type: "PERSISTENT",
InitializeParams: &compute.AttachedDiskInitializeParams{
DiskName: diskName,
SourceImage: imageURL,
DiskSizeGb: 50,
},
}
}
func writeCloudStorageObject(httpClient *http.Client) {
content := os.Stdin
const maxSlurp = 1 << 20
var buf bytes.Buffer
n, err := io.CopyN(&buf, content, maxSlurp)
if err != nil && err != io.EOF {
log.Fatalf("Error reading from stdin: %v, %v", n, err)
}
contentType := http.DetectContentType(buf.Bytes())
req, err := http.NewRequest("PUT", "https://storage.googleapis.com/"+*writeObject, io.MultiReader(&buf, content))
if err != nil {
log.Fatal(err)
}
req.Header.Set("x-goog-api-version", "2")
if *publicObject {
req.Header.Set("x-goog-acl", "public-read")
}
req.Header.Set("Content-Type", contentType)
res, err := httpClient.Do(req)
if err != nil {
log.Fatal(err)
}
if res.StatusCode != 200 {
res.Write(os.Stderr)
log.Fatalf("Failed.")
}
log.Printf("Success.")
os.Exit(0)
}
type tokenCacheFile string
func (f tokenCacheFile) Token() (*oauth2.Token, error) {
slurp, err := ioutil.ReadFile(string(f))
if err != nil {
return nil, err
}
t := new(oauth2.Token)
if err := json.Unmarshal(slurp, t); err != nil {
return nil, err
}
return t, nil
}
func (f tokenCacheFile) WriteToken(t *oauth2.Token) error {
jt, err := json.Marshal(t)
if err != nil {
return err
}
return ioutil.WriteFile(string(f), jt, 0600)
}

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAt5fAjp4fTcekWUTfzsp0kyih1OYbsGL0KX1eRbSSR8Od0+9Q
62Hyny+GFwMTb4A/KU8mssoHvcceSAAbwfbxFK/+s51TobqUnORZrOoTZjkUygby
XDSK99YBbcR1Pip8vwMTm4XKuLtCigeBBdjjAQdgUO28LENGlsMnmeYkJfODVGnV
mr5Ltb9ANA8IKyTfsnHJ4iOCS/PlPbUj2q7YnoVLposUBMlgUb/CykX3mOoLb4yJ
JQyA/iST6ZxiIEj36D4yWZ5lg7YJl+UiiBQHGCnPdGyipqV06ex0heYWcaiW8LWZ
SUQ93jQ+WVCH8hT7DQO1dmsvUmXlq/JeAlwQ/QIDAQABAoIBAFFHV7JMAqPWnMYA
nezY6J81v9+XN+7xABNWM2Q8uv4WdksbigGLTXR3/680Z2hXqJ7LMeC5XJACFT/e
/Gr0vmpgOCygnCPfjGehGKpavtfksXV3edikUlnCXsOP1C//c1bFL+sMYmFCVgTx
qYdDK8yKzXNGrKYT6q5YG7IglyRNV1rsQa8lM/5taFYiD1Ck/3tQi3YIq8Lcuser
hrxsMABcQ6mi+EIvG6Xr4mfJug0dGJMHG4RG1UGFQn6RXrQq2+q53fC8ZbVUSi0j
NQ918aKFzktwv+DouKU0ME4I9toks03gM860bAL7zCbKGmwR3hfgX/TqzVCWpG9E
LDVfvekCgYEA8fk9N53jbBRmULUGEf4qWypcLGiZnNU0OeXWpbPV9aa3H0VDytA7
8fCN2dPAVDPqlthMDdVe983NCNwp2Yo8ZimDgowyIAKhdC25s1kejuaiH9OAPj3c
0f8KbriYX4n8zNHxFwK6Ae3pQ6EqOLJVCUsziUaZX9nyKY5aZlyX6xcCgYEAwjws
K62PjC64U5wYddNLp+kNdJ4edx+a7qBb3mEgPvSFT2RO3/xafJyG8kQB30Mfstjd
bRxyUV6N0vtX1zA7VQtRUAvfGCecpMo+VQZzcHXKzoRTnQ7eZg4Lmj5fQ9tOAKAo
QCVBoSW/DI4PZL26CAMDcAba4Pa22ooLapoRIQsCgYA6pIfkkbxLNkpxpt2YwLtt
Kr/590O7UaR9n6k8sW/aQBRDXNsILR1KDl2ifAIxpf9lnXgZJiwE7HiTfCAcW7c1
nzwDCI0hWuHcMTS/NYsFYPnLsstyyjVZI3FY0h4DkYKV9Q9z3zJLQ2hz/nwoD3gy
b2pHC7giFcTts1VPV4Nt8wKBgHeFn4ihHJweg76vZz3Z78w7VNRWGFklUalVdDK7
gaQ7w2y/ROn/146mo0OhJaXFIFRlrpvdzVrU3GDf2YXJYDlM5ZRkObwbZADjksev
WInzcgDy3KDg7WnPasRXbTfMU4t/AkW2p1QKbi3DnSVYuokDkbH2Beo45vxDxhKr
C69RAoGBAIyo3+OJenoZmoNzNJl2WPW5MeBUzSh8T/bgyjFTdqFHF5WiYRD/lfHj
x9Glyw2nutuT4hlOqHvKhgTYdDMsF2oQ72fe3v8Q5FU7FuKndNPEAyvKNXZaShVA
hnlhv5DjXKb0wFWnt5PCCiQLtzG0yyHaITrrEme7FikkIcTxaX/Y
-----END RSA PRIVATE KEY-----

@ -0,0 +1,26 @@
-----BEGIN CERTIFICATE-----
MIIEWjCCA0KgAwIBAgIJALfRlWsI8YQHMA0GCSqGSIb3DQEBBQUAMHsxCzAJBgNV
BAYTAlVTMQswCQYDVQQIEwJDQTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEUMBIG
A1UEChMLQnJhZGZpdHppbmMxEjAQBgNVBAMTCWxvY2FsaG9zdDEdMBsGCSqGSIb3
DQEJARYOYnJhZEBkYW5nYS5jb20wHhcNMTQwNzE1MjA0NjA1WhcNMTcwNTA0MjA0
NjA1WjB7MQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0ExFjAUBgNVBAcTDVNhbiBG
cmFuY2lzY28xFDASBgNVBAoTC0JyYWRmaXR6aW5jMRIwEAYDVQQDEwlsb2NhbGhv
c3QxHTAbBgkqhkiG9w0BCQEWDmJyYWRAZGFuZ2EuY29tMIIBIjANBgkqhkiG9w0B
AQEFAAOCAQ8AMIIBCgKCAQEAt5fAjp4fTcekWUTfzsp0kyih1OYbsGL0KX1eRbSS
R8Od0+9Q62Hyny+GFwMTb4A/KU8mssoHvcceSAAbwfbxFK/+s51TobqUnORZrOoT
ZjkUygbyXDSK99YBbcR1Pip8vwMTm4XKuLtCigeBBdjjAQdgUO28LENGlsMnmeYk
JfODVGnVmr5Ltb9ANA8IKyTfsnHJ4iOCS/PlPbUj2q7YnoVLposUBMlgUb/CykX3
mOoLb4yJJQyA/iST6ZxiIEj36D4yWZ5lg7YJl+UiiBQHGCnPdGyipqV06ex0heYW
caiW8LWZSUQ93jQ+WVCH8hT7DQO1dmsvUmXlq/JeAlwQ/QIDAQABo4HgMIHdMB0G
A1UdDgQWBBRcAROthS4P4U7vTfjByC569R7E6DCBrQYDVR0jBIGlMIGigBRcAROt
hS4P4U7vTfjByC569R7E6KF/pH0wezELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNB
MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRQwEgYDVQQKEwtCcmFkZml0emluYzES
MBAGA1UEAxMJbG9jYWxob3N0MR0wGwYJKoZIhvcNAQkBFg5icmFkQGRhbmdhLmNv
bYIJALfRlWsI8YQHMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAG6h
U9f9sNH0/6oBbGGy2EVU0UgITUQIrFWo9rFkrW5k/XkDjQm+3lzjT0iGR4IxE/Ao
eU6sQhua7wrWeFEn47GL98lnCsJdD7oZNhFmQ95Tb/LnDUjs5Yj9brP0NWzXfYU4
UK2ZnINJRcJpB8iRCaCxE8DdcUF0XqIEq6pA272snoLmiXLMvNl3kYEdm+je6voD
58SNVEUsztzQyXmJEhCpwVI0A6QCjzXj+qvpmw3ZZHi8JwXei8ZZBLTSFBki8Z7n
sH9BBH38/SzUmAN4QHSPy1gjqm00OAE8NaYDkh/bzE4d7mLGGMWp/WE3KPSu82HF
kPe6XoSbiLm/kxk32T0=
-----END CERTIFICATE-----

@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDPjCCAiYCCQDizia/MoUFnDANBgkqhkiG9w0BAQUFADB7MQswCQYDVQQGEwJV
UzELMAkGA1UECBMCQ0ExFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xFDASBgNVBAoT
C0JyYWRmaXR6aW5jMRIwEAYDVQQDEwlsb2NhbGhvc3QxHTAbBgkqhkiG9w0BCQEW
DmJyYWRAZGFuZ2EuY29tMB4XDTE0MDcxNTIwNTAyN1oXDTE1MTEyNzIwNTAyN1ow
RzELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMQswCQYDVQQHEwJTRjEeMBwGA1UE
ChMVYnJhZGZpdHogaHR0cDIgc2VydmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
MIIBCgKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDifx2l
gZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1LmJ4c2
dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nefb3HL
A7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55mjws
/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/fz88
F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABMA0GCSqGSIb3DQEBBQUAA4IB
AQC0zL+n/YpRZOdulSu9tS8FxrstXqGWoxfe+vIUgqfMZ5+0MkjJ/vW0FqlLDl2R
rn4XaR3e7FmWkwdDVbq/UB6lPmoAaFkCgh9/5oapMaclNVNnfF3fjCJfRr+qj/iD
EmJStTIN0ZuUjAlpiACmfnpEU55PafT5Zx+i1yE4FGjw8bJpFoyD4Hnm54nGjX19
KeCuvcYFUPnBm3lcL0FalF2AjqV02WTHYNQk7YF/oeO7NKBoEgvGvKG3x+xaOeBI
dwvdq175ZsGul30h+QjrRlXhH/twcuaT3GSdoysDl9cCYE8f1Mk8PD6gan3uBCJU
90p6/CbU71bGbfpM2PHot2fm
-----END CERTIFICATE-----

@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDi
fx2lgZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1Lm
J4c2dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nef
b3HLA7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55
mjws/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/
fz88F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABAoIBADQ2spUwbY+bcz4p
3M66ECrNQTBggP40gYl2XyHxGGOu2xhZ94f9ELf1hjRWU2DUKWco1rJcdZClV6q3
qwmXvcM2Q/SMS8JW0ImkNVl/0/NqPxGatEnj8zY30d/L8hGFb0orzFu/XYA5gCP4
NbN2WrXgk3ZLeqwcNxHHtSiJWGJ/fPyeDWAu/apy75u9Xf2GlzBZmV6HYD9EfK80
LTlI60f5FO487CrJnboL7ovPJrIHn+k05xRQqwma4orpz932rTXnTjs9Lg6KtbQN
a7PrqfAntIISgr11a66Mng3IYH1lYqJsWJJwX/xHT4WLEy0EH4/0+PfYemJekz2+
Co62drECgYEA6O9zVJZXrLSDsIi54cfxA7nEZWm5CAtkYWeAHa4EJ+IlZ7gIf9sL
W8oFcEfFGpvwVqWZ+AsQ70dsjXAv3zXaG0tmg9FtqWp7pzRSMPidifZcQwWkKeTO
gJnFmnVyed8h6GfjTEu4gxo1/S5U0V+mYSha01z5NTnN6ltKx1Or3b0CgYEAxRgm
S30nZxnyg/V7ys61AZhst1DG2tkZXEMcA7dYhabMoXPJAP/EfhlWwpWYYUs/u0gS
Wwmf5IivX5TlYScgmkvb/NYz0u4ZmOXkLTnLPtdKKFXhjXJcHjUP67jYmOxNlJLp
V4vLRnFxTpffAV+OszzRxsXX6fvruwZBANYJeXUCgYBVouLFsFgfWGYp2rpr9XP4
KK25kvrBqF6JKOIDB1zjxNJ3pUMKrl8oqccCFoCyXa4oTM2kUX0yWxHfleUjrMq4
yimwQKiOZmV7fVLSSjSw6e/VfBd0h3gb82ygcplZkN0IclkwTY5SNKqwn/3y07V5
drqdhkrgdJXtmQ6O5YYECQKBgATERcDToQ1USlI4sKrB/wyv1AlG8dg/IebiVJ4e
ZAyvcQmClFzq0qS+FiQUnB/WQw9TeeYrwGs1hxBHuJh16srwhLyDrbMvQP06qh8R
48F8UXXSRec22dV9MQphaROhu2qZdv1AC0WD3tqov6L33aqmEOi+xi8JgbT/PLk5
c/c1AoGBAI1A/02ryksW6/wc7/6SP2M2rTy4m1sD/GnrTc67EHnRcVBdKO6qH2RY
nqC8YcveC2ZghgPTDsA3VGuzuBXpwY6wTyV99q6jxQJ6/xcrD9/NUG6Uwv/xfCxl
IJLeBYEqQundSSny3VtaAUK8Ul1nxpTvVRNwtcyWTo8RHAAyNPWd
-----END RSA PRIVATE KEY-----

@ -0,0 +1,97 @@
# h2i
**h2i** is an interactive HTTP/2 ("h2") console debugger. Miss the good ol'
days of telnetting to your HTTP/1.n servers? We're bringing you
back.
Features:
- send raw HTTP/2 frames
- PING
- SETTINGS
- HEADERS
- etc
- type in HTTP/1.n and have it auto-HPACK/frame-ify it for HTTP/2
- pretty print all received HTTP/2 frames from the peer (including HPACK decoding)
- tab completion of commands, options
Not yet features, but soon:
- unnecessary CONTINUATION frames on short boundaries, to test peer implementations
- request bodies (DATA frames)
- send invalid frames for testing server implementations (supported by underlying Framer)
Later:
- act like a server
## Installation
```
$ go get github.com/bradfitz/http2/h2i
$ h2i <host>
```
## Demo
```
$ h2i
Usage: h2i <hostname>
-insecure
Whether to skip TLS cert validation
-nextproto string
Comma-separated list of NPN/ALPN protocol names to negotiate. (default "h2,h2-14")
$ h2i google.com
Connecting to google.com:443 ...
Connected to 74.125.224.41:443
Negotiated protocol "h2-14"
[FrameHeader SETTINGS len=18]
[MAX_CONCURRENT_STREAMS = 100]
[INITIAL_WINDOW_SIZE = 1048576]
[MAX_FRAME_SIZE = 16384]
[FrameHeader WINDOW_UPDATE len=4]
Window-Increment = 983041
h2i> PING h2iSayHI
[FrameHeader PING flags=ACK len=8]
Data = "h2iSayHI"
h2i> headers
(as HTTP/1.1)> GET / HTTP/1.1
(as HTTP/1.1)> Host: ip.appspot.com
(as HTTP/1.1)> User-Agent: h2i/brad-n-blake
(as HTTP/1.1)>
Opening Stream-ID 1:
:authority = ip.appspot.com
:method = GET
:path = /
:scheme = https
user-agent = h2i/brad-n-blake
[FrameHeader HEADERS flags=END_HEADERS stream=1 len=77]
:status = "200"
alternate-protocol = "443:quic,p=1"
content-length = "15"
content-type = "text/html"
date = "Fri, 01 May 2015 23:06:56 GMT"
server = "Google Frontend"
[FrameHeader DATA flags=END_STREAM stream=1 len=15]
"173.164.155.78\n"
[FrameHeader PING len=8]
Data = "\x00\x00\x00\x00\x00\x00\x00\x00"
h2i> ping
[FrameHeader PING flags=ACK len=8]
Data = "h2i_ping"
h2i> ping
[FrameHeader PING flags=ACK len=8]
Data = "h2i_ping"
h2i> ping
[FrameHeader GOAWAY len=22]
Last-Stream-ID = 1; Error-Code = PROTOCOL_ERROR (1)
ReadFrame: EOF
```
## Status
Quick few hour hack. So much yet to do. Feel free to file issues for
bugs or wishlist items, but [@bmizerany](https://github.com/bmizerany/)
and I aren't yet accepting pull requests until things settle down.

@ -0,0 +1,489 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
/*
The h2i command is an interactive HTTP/2 console.
Usage:
$ h2i [flags] <hostname>
Interactive commands in the console: (all parts case-insensitive)
ping [data]
settings ack
settings FOO=n BAR=z
headers (open a new stream by typing HTTP/1.1)
*/
package main
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"github.com/bradfitz/http2"
"github.com/bradfitz/http2/hpack"
"golang.org/x/crypto/ssh/terminal"
)
// Flags
var (
flagNextProto = flag.String("nextproto", "h2,h2-14", "Comma-separated list of NPN/ALPN protocol names to negotiate.")
flagInsecure = flag.Bool("insecure", false, "Whether to skip TLS cert validation")
)
type command struct {
run func(*h2i, []string) error // required
// complete optionally specifies tokens (case-insensitive) which are
// valid for this subcommand.
complete func() []string
}
var commands = map[string]command{
"ping": command{run: (*h2i).cmdPing},
"settings": command{
run: (*h2i).cmdSettings,
complete: func() []string {
return []string{
"ACK",
http2.SettingHeaderTableSize.String(),
http2.SettingEnablePush.String(),
http2.SettingMaxConcurrentStreams.String(),
http2.SettingInitialWindowSize.String(),
http2.SettingMaxFrameSize.String(),
http2.SettingMaxHeaderListSize.String(),
}
},
},
"quit": command{run: (*h2i).cmdQuit},
"headers": command{run: (*h2i).cmdHeaders},
}
func usage() {
fmt.Fprintf(os.Stderr, "Usage: h2i <hostname>\n\n")
flag.PrintDefaults()
os.Exit(1)
}
// withPort adds ":443" if another port isn't already present.
func withPort(host string) string {
if _, _, err := net.SplitHostPort(host); err != nil {
return net.JoinHostPort(host, "443")
}
return host
}
// h2i is the app's state.
type h2i struct {
host string
tc *tls.Conn
framer *http2.Framer
term *terminal.Terminal
// owned by the command loop:
streamID uint32
hbuf bytes.Buffer
henc *hpack.Encoder
// owned by the readFrames loop:
peerSetting map[http2.SettingID]uint32
hdec *hpack.Decoder
}
func main() {
flag.Usage = usage
flag.Parse()
if flag.NArg() != 1 {
usage()
}
log.SetFlags(0)
host := flag.Arg(0)
app := &h2i{
host: host,
peerSetting: make(map[http2.SettingID]uint32),
}
app.henc = hpack.NewEncoder(&app.hbuf)
if err := app.Main(); err != nil {
if app.term != nil {
app.logf("%v\n", err)
} else {
fmt.Fprintf(os.Stderr, "%v\n", err)
}
os.Exit(1)
}
fmt.Fprintf(os.Stdout, "\n")
}
func (app *h2i) Main() error {
cfg := &tls.Config{
ServerName: app.host,
NextProtos: strings.Split(*flagNextProto, ","),
InsecureSkipVerify: *flagInsecure,
}
hostAndPort := withPort(app.host)
log.Printf("Connecting to %s ...", hostAndPort)
tc, err := tls.Dial("tcp", hostAndPort, cfg)
if err != nil {
return fmt.Errorf("Error dialing %s: %v", withPort(app.host), err)
}
log.Printf("Connected to %v", tc.RemoteAddr())
defer tc.Close()
if err := tc.Handshake(); err != nil {
return fmt.Errorf("TLS handshake: %v", err)
}
if !*flagInsecure {
if err := tc.VerifyHostname(app.host); err != nil {
return fmt.Errorf("VerifyHostname: %v", err)
}
}
state := tc.ConnectionState()
log.Printf("Negotiated protocol %q", state.NegotiatedProtocol)
if !state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol == "" {
return fmt.Errorf("Could not negotiate protocol mutually")
}
if _, err := io.WriteString(tc, http2.ClientPreface); err != nil {
return err
}
app.framer = http2.NewFramer(tc, tc)
oldState, err := terminal.MakeRaw(0)
if err != nil {
return err
}
defer terminal.Restore(0, oldState)
var screen = struct {
io.Reader
io.Writer
}{os.Stdin, os.Stdout}
app.term = terminal.NewTerminal(screen, "h2i> ")
lastWord := regexp.MustCompile(`.+\W(\w+)$`)
app.term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
if key != '\t' {
return
}
if pos != len(line) {
// TODO: we're being lazy for now, only supporting tab completion at the end.
return
}
// Auto-complete for the command itself.
if !strings.Contains(line, " ") {
var name string
name, _, ok = lookupCommand(line)
if !ok {
return
}
return name, len(name), true
}
_, c, ok := lookupCommand(line[:strings.IndexByte(line, ' ')])
if !ok || c.complete == nil {
return
}
if strings.HasSuffix(line, " ") {
app.logf("%s", strings.Join(c.complete(), " "))
return line, pos, true
}
m := lastWord.FindStringSubmatch(line)
if m == nil {
return line, len(line), true
}
soFar := m[1]
var match []string
for _, cand := range c.complete() {
if len(soFar) > len(cand) || !strings.EqualFold(cand[:len(soFar)], soFar) {
continue
}
match = append(match, cand)
}
if len(match) == 0 {
return
}
if len(match) > 1 {
// TODO: auto-complete any common prefix
app.logf("%s", strings.Join(match, " "))
return line, pos, true
}
newLine = line[:len(line)-len(soFar)] + match[0]
return newLine, len(newLine), true
}
errc := make(chan error, 2)
go func() { errc <- app.readFrames() }()
go func() { errc <- app.readConsole() }()
return <-errc
}
func (app *h2i) logf(format string, args ...interface{}) {
fmt.Fprintf(app.term, format+"\n", args...)
}
func (app *h2i) readConsole() error {
for {
line, err := app.term.ReadLine()
if err == io.EOF {
return nil
}
if err != nil {
return fmt.Errorf("terminal.ReadLine: %v", err)
}
f := strings.Fields(line)
if len(f) == 0 {
continue
}
cmd, args := f[0], f[1:]
if _, c, ok := lookupCommand(cmd); ok {
err = c.run(app, args)
} else {
app.logf("Unknown command %q", line)
}
if err == errExitApp {
return nil
}
if err != nil {
return err
}
}
}
func lookupCommand(prefix string) (name string, c command, ok bool) {
prefix = strings.ToLower(prefix)
if c, ok = commands[prefix]; ok {
return prefix, c, ok
}
for full, candidate := range commands {
if strings.HasPrefix(full, prefix) {
if c.run != nil {
return "", command{}, false // ambiguous
}
c = candidate
name = full
}
}
return name, c, c.run != nil
}
var errExitApp = errors.New("internal sentinel error value to quit the console reading loop")
func (a *h2i) cmdQuit(args []string) error {
if len(args) > 0 {
a.logf("the QUIT command takes no argument")
return nil
}
return errExitApp
}
func (a *h2i) cmdSettings(args []string) error {
if len(args) == 1 && strings.EqualFold(args[0], "ACK") {
return a.framer.WriteSettingsAck()
}
var settings []http2.Setting
for _, arg := range args {
if strings.EqualFold(arg, "ACK") {
a.logf("Error: ACK must be only argument with the SETTINGS command")
return nil
}
eq := strings.Index(arg, "=")
if eq == -1 {
a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg)
return nil
}
sid, ok := settingByName(arg[:eq])
if !ok {
a.logf("Error: unknown setting name %q", arg[:eq])
return nil
}
val, err := strconv.ParseUint(arg[eq+1:], 10, 32)
if err != nil {
a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg)
return nil
}
settings = append(settings, http2.Setting{
ID: sid,
Val: uint32(val),
})
}
a.logf("Sending: %v", settings)
return a.framer.WriteSettings(settings...)
}
func settingByName(name string) (http2.SettingID, bool) {
for _, sid := range [...]http2.SettingID{
http2.SettingHeaderTableSize,
http2.SettingEnablePush,
http2.SettingMaxConcurrentStreams,
http2.SettingInitialWindowSize,
http2.SettingMaxFrameSize,
http2.SettingMaxHeaderListSize,
} {
if strings.EqualFold(sid.String(), name) {
return sid, true
}
}
return 0, false
}
func (app *h2i) cmdPing(args []string) error {
if len(args) > 1 {
app.logf("invalid PING usage: only accepts 0 or 1 args")
return nil // nil means don't end the program
}
var data [8]byte
if len(args) == 1 {
copy(data[:], args[0])
} else {
copy(data[:], "h2i_ping")
}
return app.framer.WritePing(false, data)
}
func (app *h2i) cmdHeaders(args []string) error {
if len(args) > 0 {
app.logf("Error: HEADERS doesn't yet take arguments.")
// TODO: flags for restricting window size, to force CONTINUATION
// frames.
return nil
}
var h1req bytes.Buffer
app.term.SetPrompt("(as HTTP/1.1)> ")
defer app.term.SetPrompt("h2i> ")
for {
line, err := app.term.ReadLine()
if err != nil {
return err
}
h1req.WriteString(line)
h1req.WriteString("\r\n")
if line == "" {
break
}
}
req, err := http.ReadRequest(bufio.NewReader(&h1req))
if err != nil {
app.logf("Invalid HTTP/1.1 request: %v", err)
return nil
}
if app.streamID == 0 {
app.streamID = 1
} else {
app.streamID += 2
}
app.logf("Opening Stream-ID %d:", app.streamID)
hbf := app.encodeHeaders(req)
if len(hbf) > 16<<10 {
app.logf("TODO: h2i doesn't yet write CONTINUATION frames. Copy it from transport.go")
return nil
}
return app.framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: app.streamID,
BlockFragment: hbf,
EndStream: req.Method == "GET" || req.Method == "HEAD", // good enough for now
EndHeaders: true, // for now
})
}
func (app *h2i) readFrames() error {
for {
f, err := app.framer.ReadFrame()
if err != nil {
return fmt.Errorf("ReadFrame: %v", err)
}
app.logf("%v", f)
switch f := f.(type) {
case *http2.PingFrame:
app.logf(" Data = %q", f.Data)
case *http2.SettingsFrame:
f.ForeachSetting(func(s http2.Setting) error {
app.logf(" %v", s)
app.peerSetting[s.ID] = s.Val
return nil
})
case *http2.WindowUpdateFrame:
app.logf(" Window-Increment = %v\n", f.Increment)
case *http2.GoAwayFrame:
app.logf(" Last-Stream-ID = %d; Error-Code = %v (%d)\n", f.LastStreamID, f.ErrCode, f.ErrCode)
case *http2.DataFrame:
app.logf(" %q", f.Data())
case *http2.HeadersFrame:
if f.HasPriority() {
app.logf(" PRIORITY = %v", f.Priority)
}
if app.hdec == nil {
// TODO: if the user uses h2i to send a SETTINGS frame advertising
// something larger, we'll need to respect SETTINGS_HEADER_TABLE_SIZE
// and stuff here instead of using the 4k default. But for now:
tableSize := uint32(4 << 10)
app.hdec = hpack.NewDecoder(tableSize, app.onNewHeaderField)
}
app.hdec.Write(f.HeaderBlockFragment())
}
}
}
// called from readLoop
func (app *h2i) onNewHeaderField(f hpack.HeaderField) {
if f.Sensitive {
app.logf(" %s = %q (SENSITIVE)", f.Name, f.Value)
}
app.logf(" %s = %q", f.Name, f.Value)
}
func (app *h2i) encodeHeaders(req *http.Request) []byte {
app.hbuf.Reset()
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
host := req.Host
if host == "" {
host = req.URL.Host
}
path := req.URL.Path
if path == "" {
path = "/"
}
app.writeHeader(":authority", host) // probably not right for all sites
app.writeHeader(":method", req.Method)
app.writeHeader(":path", path)
app.writeHeader(":scheme", "https")
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
if lowKey == "host" {
continue
}
for _, v := range vv {
app.writeHeader(lowKey, v)
}
}
return app.hbuf.Bytes()
}
func (app *h2i) writeHeader(name, value string) {
app.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
app.logf(" %s = %s", name, value)
}

@ -0,0 +1,80 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"net/http"
"strings"
)
var (
commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case
commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case
)
func init() {
for _, v := range []string{
"accept",
"accept-charset",
"accept-encoding",
"accept-language",
"accept-ranges",
"age",
"access-control-allow-origin",
"allow",
"authorization",
"cache-control",
"content-disposition",
"content-encoding",
"content-language",
"content-length",
"content-location",
"content-range",
"content-type",
"cookie",
"date",
"etag",
"expect",
"expires",
"from",
"host",
"if-match",
"if-modified-since",
"if-none-match",
"if-unmodified-since",
"last-modified",
"link",
"location",
"max-forwards",
"proxy-authenticate",
"proxy-authorization",
"range",
"referer",
"refresh",
"retry-after",
"server",
"set-cookie",
"strict-transport-security",
"transfer-encoding",
"user-agent",
"vary",
"via",
"www-authenticate",
} {
chk := http.CanonicalHeaderKey(v)
commonLowerHeader[chk] = v
commonCanonHeader[v] = chk
}
}
func lowerHeader(v string) string {
if s, ok := commonLowerHeader[v]; ok {
return s
}
return strings.ToLower(v)
}

@ -0,0 +1,252 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"io"
)
const (
uint32Max = ^uint32(0)
initialHeaderTableSize = 4096
)
type Encoder struct {
dynTab dynamicTable
// minSize is the minimum table size set by
// SetMaxDynamicTableSize after the previous Header Table Size
// Update.
minSize uint32
// maxSizeLimit is the maximum table size this encoder
// supports. This will protect the encoder from too large
// size.
maxSizeLimit uint32
// tableSizeUpdate indicates whether "Header Table Size
// Update" is required.
tableSizeUpdate bool
w io.Writer
buf []byte
}
// NewEncoder returns a new Encoder which performs HPACK encoding. An
// encoded data is written to w.
func NewEncoder(w io.Writer) *Encoder {
e := &Encoder{
minSize: uint32Max,
maxSizeLimit: initialHeaderTableSize,
tableSizeUpdate: false,
w: w,
}
e.dynTab.setMaxSize(initialHeaderTableSize)
return e
}
// WriteField encodes f into a single Write to e's underlying Writer.
// This function may also produce bytes for "Header Table Size Update"
// if necessary. If produced, it is done before encoding f.
func (e *Encoder) WriteField(f HeaderField) error {
e.buf = e.buf[:0]
if e.tableSizeUpdate {
e.tableSizeUpdate = false
if e.minSize < e.dynTab.maxSize {
e.buf = appendTableSize(e.buf, e.minSize)
}
e.minSize = uint32Max
e.buf = appendTableSize(e.buf, e.dynTab.maxSize)
}
idx, nameValueMatch := e.searchTable(f)
if nameValueMatch {
e.buf = appendIndexed(e.buf, idx)
} else {
indexing := e.shouldIndex(f)
if indexing {
e.dynTab.add(f)
}
if idx == 0 {
e.buf = appendNewName(e.buf, f, indexing)
} else {
e.buf = appendIndexedName(e.buf, f, idx, indexing)
}
}
n, err := e.w.Write(e.buf)
if err == nil && n != len(e.buf) {
err = io.ErrShortWrite
}
return err
}
// searchTable searches f in both stable and dynamic header tables.
// The static header table is searched first. Only when there is no
// exact match for both name and value, the dynamic header table is
// then searched. If there is no match, i is 0. If both name and value
// match, i is the matched index and nameValueMatch becomes true. If
// only name matches, i points to that index and nameValueMatch
// becomes false.
func (e *Encoder) searchTable(f HeaderField) (i uint64, nameValueMatch bool) {
for idx, hf := range staticTable {
if !constantTimeStringCompare(hf.Name, f.Name) {
continue
}
if i == 0 {
i = uint64(idx + 1)
}
if f.Sensitive {
continue
}
if !constantTimeStringCompare(hf.Value, f.Value) {
continue
}
i = uint64(idx + 1)
nameValueMatch = true
return
}
j, nameValueMatch := e.dynTab.search(f)
if nameValueMatch || (i == 0 && j != 0) {
i = j + uint64(len(staticTable))
}
return
}
// SetMaxDynamicTableSize changes the dynamic header table size to v.
// The actual size is bounded by the value passed to
// SetMaxDynamicTableSizeLimit.
func (e *Encoder) SetMaxDynamicTableSize(v uint32) {
if v > e.maxSizeLimit {
v = e.maxSizeLimit
}
if v < e.minSize {
e.minSize = v
}
e.tableSizeUpdate = true
e.dynTab.setMaxSize(v)
}
// SetMaxDynamicTableSizeLimit changes the maximum value that can be
// specified in SetMaxDynamicTableSize to v. By default, it is set to
// 4096, which is the same size of the default dynamic header table
// size described in HPACK specification. If the current maximum
// dynamic header table size is strictly greater than v, "Header Table
// Size Update" will be done in the next WriteField call and the
// maximum dynamic header table size is truncated to v.
func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) {
e.maxSizeLimit = v
if e.dynTab.maxSize > v {
e.tableSizeUpdate = true
e.dynTab.setMaxSize(v)
}
}
// shouldIndex reports whether f should be indexed.
func (e *Encoder) shouldIndex(f HeaderField) bool {
return !f.Sensitive && f.size() <= e.dynTab.maxSize
}
// appendIndexed appends index i, as encoded in "Indexed Header Field"
// representation, to dst and returns the extended buffer.
func appendIndexed(dst []byte, i uint64) []byte {
first := len(dst)
dst = appendVarInt(dst, 7, i)
dst[first] |= 0x80
return dst
}
// appendNewName appends f, as encoded in one of "Literal Header field
// - New Name" representation variants, to dst and returns the
// extended buffer.
//
// If f.Sensitive is true, "Never Indexed" representation is used. If
// f.Sensitive is false and indexing is true, "Inremental Indexing"
// representation is used.
func appendNewName(dst []byte, f HeaderField, indexing bool) []byte {
dst = append(dst, encodeTypeByte(indexing, f.Sensitive))
dst = appendHpackString(dst, f.Name)
return appendHpackString(dst, f.Value)
}
// appendIndexedName appends f and index i referring indexed name
// entry, as encoded in one of "Literal Header field - Indexed Name"
// representation variants, to dst and returns the extended buffer.
//
// If f.Sensitive is true, "Never Indexed" representation is used. If
// f.Sensitive is false and indexing is true, "Incremental Indexing"
// representation is used.
func appendIndexedName(dst []byte, f HeaderField, i uint64, indexing bool) []byte {
first := len(dst)
var n byte
if indexing {
n = 6
} else {
n = 4
}
dst = appendVarInt(dst, n, i)
dst[first] |= encodeTypeByte(indexing, f.Sensitive)
return appendHpackString(dst, f.Value)
}
// appendTableSize appends v, as encoded in "Header Table Size Update"
// representation, to dst and returns the extended buffer.
func appendTableSize(dst []byte, v uint32) []byte {
first := len(dst)
dst = appendVarInt(dst, 5, uint64(v))
dst[first] |= 0x20
return dst
}
// appendVarInt appends i, as encoded in variable integer form using n
// bit prefix, to dst and returns the extended buffer.
//
// See
// http://http2.github.io/http2-spec/compression.html#integer.representation
func appendVarInt(dst []byte, n byte, i uint64) []byte {
k := uint64((1 << n) - 1)
if i < k {
return append(dst, byte(i))
}
dst = append(dst, byte(k))
i -= k
for ; i >= 128; i >>= 7 {
dst = append(dst, byte(0x80|(i&0x7f)))
}
return append(dst, byte(i))
}
// appendHpackString appends s, as encoded in "String Literal"
// representation, to dst and returns the the extended buffer.
//
// s will be encoded in Huffman codes only when it produces strictly
// shorter byte string.
func appendHpackString(dst []byte, s string) []byte {
huffmanLength := HuffmanEncodeLength(s)
if huffmanLength < uint64(len(s)) {
first := len(dst)
dst = appendVarInt(dst, 7, huffmanLength)
dst = AppendHuffmanString(dst, s)
dst[first] |= 0x80
} else {
dst = appendVarInt(dst, 7, uint64(len(s)))
dst = append(dst, s...)
}
return dst
}
// encodeTypeByte returns type byte. If sensitive is true, type byte
// for "Never Indexed" representation is returned. If sensitive is
// false and indexing is true, type byte for "Incremental Indexing"
// representation is returned. Otherwise, type byte for "Without
// Indexing" is returned.
func encodeTypeByte(indexing, sensitive bool) byte {
if sensitive {
return 0x10
}
if indexing {
return 0x40
}
return 0
}

@ -0,0 +1,331 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"bytes"
"encoding/hex"
"reflect"
"strings"
"testing"
)
func TestEncoderTableSizeUpdate(t *testing.T) {
tests := []struct {
size1, size2 uint32
wantHex string
}{
// Should emit 2 table size updates (2048 and 4096)
{2048, 4096, "3fe10f 3fe11f 82"},
// Should emit 1 table size update (2048)
{16384, 2048, "3fe10f 82"},
}
for _, tt := range tests {
var buf bytes.Buffer
e := NewEncoder(&buf)
e.SetMaxDynamicTableSize(tt.size1)
e.SetMaxDynamicTableSize(tt.size2)
if err := e.WriteField(pair(":method", "GET")); err != nil {
t.Fatal(err)
}
want := removeSpace(tt.wantHex)
if got := hex.EncodeToString(buf.Bytes()); got != want {
t.Errorf("e.SetDynamicTableSize %v, %v = %q; want %q", tt.size1, tt.size2, got, want)
}
}
}
func TestEncoderWriteField(t *testing.T) {
var buf bytes.Buffer
e := NewEncoder(&buf)
var got []HeaderField
d := NewDecoder(4<<10, func(f HeaderField) {
got = append(got, f)
})
tests := []struct {
hdrs []HeaderField
}{
{[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
}},
{[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
pair("cache-control", "no-cache"),
}},
{[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "https"),
pair(":path", "/index.html"),
pair(":authority", "www.example.com"),
pair("custom-key", "custom-value"),
}},
}
for i, tt := range tests {
buf.Reset()
got = got[:0]
for _, hf := range tt.hdrs {
if err := e.WriteField(hf); err != nil {
t.Fatal(err)
}
}
_, err := d.Write(buf.Bytes())
if err != nil {
t.Errorf("%d. Decoder Write = %v", i, err)
}
if !reflect.DeepEqual(got, tt.hdrs) {
t.Errorf("%d. Decoded %+v; want %+v", i, got, tt.hdrs)
}
}
}
func TestEncoderSearchTable(t *testing.T) {
e := NewEncoder(nil)
e.dynTab.add(pair("foo", "bar"))
e.dynTab.add(pair("blake", "miz"))
e.dynTab.add(pair(":method", "GET"))
tests := []struct {
hf HeaderField
wantI uint64
wantMatch bool
}{
// Name and Value match
{pair("foo", "bar"), uint64(len(staticTable) + 3), true},
{pair("blake", "miz"), uint64(len(staticTable) + 2), true},
{pair(":method", "GET"), 2, true},
// Only name match because Sensitive == true
{HeaderField{":method", "GET", true}, 2, false},
// Only Name matches
{pair("foo", "..."), uint64(len(staticTable) + 3), false},
{pair("blake", "..."), uint64(len(staticTable) + 2), false},
{pair(":method", "..."), 2, false},
// None match
{pair("foo-", "bar"), 0, false},
}
for _, tt := range tests {
if gotI, gotMatch := e.searchTable(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch {
t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch)
}
}
}
func TestAppendVarInt(t *testing.T) {
tests := []struct {
n byte
i uint64
want []byte
}{
// Fits in a byte:
{1, 0, []byte{0}},
{2, 2, []byte{2}},
{3, 6, []byte{6}},
{4, 14, []byte{14}},
{5, 30, []byte{30}},
{6, 62, []byte{62}},
{7, 126, []byte{126}},
{8, 254, []byte{254}},
// Multiple bytes:
{5, 1337, []byte{31, 154, 10}},
}
for _, tt := range tests {
got := appendVarInt(nil, tt.n, tt.i)
if !bytes.Equal(got, tt.want) {
t.Errorf("appendVarInt(nil, %v, %v) = %v; want %v", tt.n, tt.i, got, tt.want)
}
}
}
func TestAppendHpackString(t *testing.T) {
tests := []struct {
s, wantHex string
}{
// Huffman encoded
{"www.example.com", "8c f1e3 c2e5 f23a 6ba0 ab90 f4ff"},
// Not Huffman encoded
{"a", "01 61"},
// zero length
{"", "00"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendHpackString(nil, tt.s)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendHpackString(nil, %q) = %q; want %q", tt.s, got, want)
}
}
}
func TestAppendIndexed(t *testing.T) {
tests := []struct {
i uint64
wantHex string
}{
// 1 byte
{1, "81"},
{126, "fe"},
// 2 bytes
{127, "ff00"},
{128, "ff01"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendIndexed(nil, tt.i)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendIndex(nil, %v) = %q; want %q", tt.i, got, want)
}
}
}
func TestAppendNewName(t *testing.T) {
tests := []struct {
f HeaderField
indexing bool
wantHex string
}{
// Incremental indexing
{HeaderField{"custom-key", "custom-value", false}, true, "40 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
// Without indexing
{HeaderField{"custom-key", "custom-value", false}, false, "00 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
// Never indexed
{HeaderField{"custom-key", "custom-value", true}, true, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
{HeaderField{"custom-key", "custom-value", true}, false, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendNewName(nil, tt.f, tt.indexing)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendNewName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want)
}
}
}
func TestAppendIndexedName(t *testing.T) {
tests := []struct {
f HeaderField
i uint64
indexing bool
wantHex string
}{
// Incremental indexing
{HeaderField{":status", "302", false}, 8, true, "48 82 6402"},
// Without indexing
{HeaderField{":status", "302", false}, 8, false, "08 82 6402"},
// Never indexed
{HeaderField{":status", "302", true}, 8, true, "18 82 6402"},
{HeaderField{":status", "302", true}, 8, false, "18 82 6402"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendIndexedName(nil, tt.f, tt.i, tt.indexing)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendIndexedName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want)
}
}
}
func TestAppendTableSize(t *testing.T) {
tests := []struct {
i uint32
wantHex string
}{
// Fits into 1 byte
{30, "3e"},
// Extra byte
{31, "3f00"},
{32, "3f01"},
}
for _, tt := range tests {
want := removeSpace(tt.wantHex)
buf := appendTableSize(nil, tt.i)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("appendTableSize(nil, %v) = %q; want %q", tt.i, got, want)
}
}
}
func TestEncoderSetMaxDynamicTableSize(t *testing.T) {
var buf bytes.Buffer
e := NewEncoder(&buf)
tests := []struct {
v uint32
wantUpdate bool
wantMinSize uint32
wantMaxSize uint32
}{
// Set new table size to 2048
{2048, true, 2048, 2048},
// Set new table size to 16384, but still limited to
// 4096
{16384, true, 2048, 4096},
}
for _, tt := range tests {
e.SetMaxDynamicTableSize(tt.v)
if got := e.tableSizeUpdate; tt.wantUpdate != got {
t.Errorf("e.tableSizeUpdate = %v; want %v", got, tt.wantUpdate)
}
if got := e.minSize; tt.wantMinSize != got {
t.Errorf("e.minSize = %v; want %v", got, tt.wantMinSize)
}
if got := e.dynTab.maxSize; tt.wantMaxSize != got {
t.Errorf("e.maxSize = %v; want %v", got, tt.wantMaxSize)
}
}
}
func TestEncoderSetMaxDynamicTableSizeLimit(t *testing.T) {
e := NewEncoder(nil)
// 4095 < initialHeaderTableSize means maxSize is truncated to
// 4095.
e.SetMaxDynamicTableSizeLimit(4095)
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
}
if got, want := e.maxSizeLimit, uint32(4095); got != want {
t.Errorf("e.maxSizeLimit = %v; want %v", got, want)
}
if got, want := e.tableSizeUpdate, true; got != want {
t.Errorf("e.tableSizeUpdate = %v; want %v", got, want)
}
// maxSize will be truncated to maxSizeLimit
e.SetMaxDynamicTableSize(16384)
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
}
// 8192 > current maxSizeLimit, so maxSize does not change.
e.SetMaxDynamicTableSizeLimit(8192)
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
}
if got, want := e.maxSizeLimit, uint32(8192); got != want {
t.Errorf("e.maxSizeLimit = %v; want %v", got, want)
}
}
func removeSpace(s string) string {
return strings.Replace(s, " ", "", -1)
}

@ -0,0 +1,445 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Package hpack implements HPACK, a compression format for
// efficiently representing HTTP header fields in the context of HTTP/2.
//
// See http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-09
package hpack
import (
"bytes"
"errors"
"fmt"
)
// A DecodingError is something the spec defines as a decoding error.
type DecodingError struct {
Err error
}
func (de DecodingError) Error() string {
return fmt.Sprintf("decoding error: %v", de.Err)
}
// An InvalidIndexError is returned when an encoder references a table
// entry before the static table or after the end of the dynamic table.
type InvalidIndexError int
func (e InvalidIndexError) Error() string {
return fmt.Sprintf("invalid indexed representation index %d", int(e))
}
// A HeaderField is a name-value pair. Both the name and value are
// treated as opaque sequences of octets.
type HeaderField struct {
Name, Value string
// Sensitive means that this header field should never be
// indexed.
Sensitive bool
}
func (hf *HeaderField) size() uint32 {
// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1
// "The size of the dynamic table is the sum of the size of
// its entries. The size of an entry is the sum of its name's
// length in octets (as defined in Section 5.2), its value's
// length in octets (see Section 5.2), plus 32. The size of
// an entry is calculated using the length of the name and
// value without any Huffman encoding applied."
// This can overflow if somebody makes a large HeaderField
// Name and/or Value by hand, but we don't care, because that
// won't happen on the wire because the encoding doesn't allow
// it.
return uint32(len(hf.Name) + len(hf.Value) + 32)
}
// A Decoder is the decoding context for incremental processing of
// header blocks.
type Decoder struct {
dynTab dynamicTable
emit func(f HeaderField)
// buf is the unparsed buffer. It's only written to
// saveBuf if it was truncated in the middle of a header
// block. Because it's usually not owned, we can only
// process it under Write.
buf []byte // usually not owned
saveBuf bytes.Buffer
}
func NewDecoder(maxSize uint32, emitFunc func(f HeaderField)) *Decoder {
d := &Decoder{
emit: emitFunc,
}
d.dynTab.allowedMaxSize = maxSize
d.dynTab.setMaxSize(maxSize)
return d
}
// TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their
// underlying buffers for garbage reasons.
func (d *Decoder) SetMaxDynamicTableSize(v uint32) {
d.dynTab.setMaxSize(v)
}
// SetAllowedMaxDynamicTableSize sets the upper bound that the encoded
// stream (via dynamic table size updates) may set the maximum size
// to.
func (d *Decoder) SetAllowedMaxDynamicTableSize(v uint32) {
d.dynTab.allowedMaxSize = v
}
type dynamicTable struct {
// ents is the FIFO described at
// http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2
// The newest (low index) is append at the end, and items are
// evicted from the front.
ents []HeaderField
size uint32
maxSize uint32 // current maxSize
allowedMaxSize uint32 // maxSize may go up to this, inclusive
}
func (dt *dynamicTable) setMaxSize(v uint32) {
dt.maxSize = v
dt.evict()
}
// TODO: change dynamicTable to be a struct with a slice and a size int field,
// per http://http2.github.io/http2-spec/compression.html#rfc.section.4.1:
//
//
// Then make add increment the size. maybe the max size should move from Decoder to
// dynamicTable and add should return an ok bool if there was enough space.
//
// Later we'll need a remove operation on dynamicTable.
func (dt *dynamicTable) add(f HeaderField) {
dt.ents = append(dt.ents, f)
dt.size += f.size()
dt.evict()
}
// If we're too big, evict old stuff (front of the slice)
func (dt *dynamicTable) evict() {
base := dt.ents // keep base pointer of slice
for dt.size > dt.maxSize {
dt.size -= dt.ents[0].size()
dt.ents = dt.ents[1:]
}
// Shift slice contents down if we evicted things.
if len(dt.ents) != len(base) {
copy(base, dt.ents)
dt.ents = base[:len(dt.ents)]
}
}
// constantTimeStringCompare compares string a and b in a constant
// time manner.
func constantTimeStringCompare(a, b string) bool {
if len(a) != len(b) {
return false
}
c := byte(0)
for i := 0; i < len(a); i++ {
c |= a[i] ^ b[i]
}
return c == 0
}
// Search searches f in the table. The return value i is 0 if there is
// no name match. If there is name match or name/value match, i is the
// index of that entry (1-based). If both name and value match,
// nameValueMatch becomes true.
func (dt *dynamicTable) search(f HeaderField) (i uint64, nameValueMatch bool) {
l := len(dt.ents)
for j := l - 1; j >= 0; j-- {
ent := dt.ents[j]
if !constantTimeStringCompare(ent.Name, f.Name) {
continue
}
if i == 0 {
i = uint64(l - j)
}
if f.Sensitive {
continue
}
if !constantTimeStringCompare(ent.Value, f.Value) {
continue
}
i = uint64(l - j)
nameValueMatch = true
return
}
return
}
func (d *Decoder) maxTableIndex() int {
return len(d.dynTab.ents) + len(staticTable)
}
func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
if i < 1 {
return
}
if i > uint64(d.maxTableIndex()) {
return
}
if i <= uint64(len(staticTable)) {
return staticTable[i-1], true
}
dents := d.dynTab.ents
return dents[len(dents)-(int(i)-len(staticTable))], true
}
// Decode decodes an entire block.
//
// TODO: remove this method and make it incremental later? This is
// easier for debugging now.
func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
var hf []HeaderField
saveFunc := d.emit
defer func() { d.emit = saveFunc }()
d.emit = func(f HeaderField) { hf = append(hf, f) }
if _, err := d.Write(p); err != nil {
return nil, err
}
if err := d.Close(); err != nil {
return nil, err
}
return hf, nil
}
func (d *Decoder) Close() error {
if d.saveBuf.Len() > 0 {
d.saveBuf.Reset()
return DecodingError{errors.New("truncated headers")}
}
return nil
}
func (d *Decoder) Write(p []byte) (n int, err error) {
if len(p) == 0 {
// Prevent state machine CPU attacks (making us redo
// work up to the point of finding out we don't have
// enough data)
return
}
// Only copy the data if we have to. Optimistically assume
// that p will contain a complete header block.
if d.saveBuf.Len() == 0 {
d.buf = p
} else {
d.saveBuf.Write(p)
d.buf = d.saveBuf.Bytes()
d.saveBuf.Reset()
}
for len(d.buf) > 0 {
err = d.parseHeaderFieldRepr()
if err != nil {
if err == errNeedMore {
err = nil
d.saveBuf.Write(d.buf)
}
break
}
}
return len(p), err
}
// errNeedMore is an internal sentinel error value that means the
// buffer is truncated and we need to read more data before we can
// continue parsing.
var errNeedMore = errors.New("need more data")
type indexType int
const (
indexedTrue indexType = iota
indexedFalse
indexedNever
)
func (v indexType) indexed() bool { return v == indexedTrue }
func (v indexType) sensitive() bool { return v == indexedNever }
// returns errNeedMore if there isn't enough data available.
// any other error is fatal.
// consumes d.buf iff it returns nil.
// precondition: must be called with len(d.buf) > 0
func (d *Decoder) parseHeaderFieldRepr() error {
b := d.buf[0]
switch {
case b&128 != 0:
// Indexed representation.
// High bit set?
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.1
return d.parseFieldIndexed()
case b&192 == 64:
// 6.2.1 Literal Header Field with Incremental Indexing
// 0b10xxxxxx: top two bits are 10
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.1
return d.parseFieldLiteral(6, indexedTrue)
case b&240 == 0:
// 6.2.2 Literal Header Field without Indexing
// 0b0000xxxx: top four bits are 0000
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.2
return d.parseFieldLiteral(4, indexedFalse)
case b&240 == 16:
// 6.2.3 Literal Header Field never Indexed
// 0b0001xxxx: top four bits are 0001
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.3
return d.parseFieldLiteral(4, indexedNever)
case b&224 == 32:
// 6.3 Dynamic Table Size Update
// Top three bits are '001'.
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.3
return d.parseDynamicTableSizeUpdate()
}
return DecodingError{errors.New("invalid encoding")}
}
// (same invariants and behavior as parseHeaderFieldRepr)
func (d *Decoder) parseFieldIndexed() error {
buf := d.buf
idx, buf, err := readVarInt(7, buf)
if err != nil {
return err
}
hf, ok := d.at(idx)
if !ok {
return DecodingError{InvalidIndexError(idx)}
}
d.emit(HeaderField{Name: hf.Name, Value: hf.Value})
d.buf = buf
return nil
}
// (same invariants and behavior as parseHeaderFieldRepr)
func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
buf := d.buf
nameIdx, buf, err := readVarInt(n, buf)
if err != nil {
return err
}
var hf HeaderField
if nameIdx > 0 {
ihf, ok := d.at(nameIdx)
if !ok {
return DecodingError{InvalidIndexError(nameIdx)}
}
hf.Name = ihf.Name
} else {
hf.Name, buf, err = readString(buf)
if err != nil {
return err
}
}
hf.Value, buf, err = readString(buf)
if err != nil {
return err
}
d.buf = buf
if it.indexed() {
d.dynTab.add(hf)
}
hf.Sensitive = it.sensitive()
d.emit(hf)
return nil
}
// (same invariants and behavior as parseHeaderFieldRepr)
func (d *Decoder) parseDynamicTableSizeUpdate() error {
buf := d.buf
size, buf, err := readVarInt(5, buf)
if err != nil {
return err
}
if size > uint64(d.dynTab.allowedMaxSize) {
return DecodingError{errors.New("dynamic table size update too large")}
}
d.dynTab.setMaxSize(uint32(size))
d.buf = buf
return nil
}
var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
// readVarInt reads an unsigned variable length integer off the
// beginning of p. n is the parameter as described in
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1.
//
// n must always be between 1 and 8.
//
// The returned remain buffer is either a smaller suffix of p, or err != nil.
// The error is errNeedMore if p doesn't contain a complete integer.
func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
if n < 1 || n > 8 {
panic("bad n")
}
if len(p) == 0 {
return 0, p, errNeedMore
}
i = uint64(p[0])
if n < 8 {
i &= (1 << uint64(n)) - 1
}
if i < (1<<uint64(n))-1 {
return i, p[1:], nil
}
origP := p
p = p[1:]
var m uint64
for len(p) > 0 {
b := p[0]
p = p[1:]
i += uint64(b&127) << m
if b&128 == 0 {
return i, p, nil
}
m += 7
if m >= 63 { // TODO: proper overflow check. making this up.
return 0, origP, errVarintOverflow
}
}
return 0, origP, errNeedMore
}
func readString(p []byte) (s string, remain []byte, err error) {
if len(p) == 0 {
return "", p, errNeedMore
}
isHuff := p[0]&128 != 0
strLen, p, err := readVarInt(7, p)
if err != nil {
return "", p, err
}
if uint64(len(p)) < strLen {
return "", p, errNeedMore
}
if !isHuff {
return string(p[:strLen]), p[strLen:], nil
}
// TODO: optimize this garbage:
var buf bytes.Buffer
if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
return "", nil, err
}
return buf.String(), p[strLen:], nil
}

@ -0,0 +1,648 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"bufio"
"bytes"
"encoding/hex"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"testing"
)
func TestStaticTable(t *testing.T) {
fromSpec := `
+-------+-----------------------------+---------------+
| 1 | :authority | |
| 2 | :method | GET |
| 3 | :method | POST |
| 4 | :path | / |
| 5 | :path | /index.html |
| 6 | :scheme | http |
| 7 | :scheme | https |
| 8 | :status | 200 |
| 9 | :status | 204 |
| 10 | :status | 206 |
| 11 | :status | 304 |
| 12 | :status | 400 |
| 13 | :status | 404 |
| 14 | :status | 500 |
| 15 | accept-charset | |
| 16 | accept-encoding | gzip, deflate |
| 17 | accept-language | |
| 18 | accept-ranges | |
| 19 | accept | |
| 20 | access-control-allow-origin | |
| 21 | age | |
| 22 | allow | |
| 23 | authorization | |
| 24 | cache-control | |
| 25 | content-disposition | |
| 26 | content-encoding | |
| 27 | content-language | |
| 28 | content-length | |
| 29 | content-location | |
| 30 | content-range | |
| 31 | content-type | |
| 32 | cookie | |
| 33 | date | |
| 34 | etag | |
| 35 | expect | |
| 36 | expires | |
| 37 | from | |
| 38 | host | |
| 39 | if-match | |
| 40 | if-modified-since | |
| 41 | if-none-match | |
| 42 | if-range | |
| 43 | if-unmodified-since | |
| 44 | last-modified | |
| 45 | link | |
| 46 | location | |
| 47 | max-forwards | |
| 48 | proxy-authenticate | |
| 49 | proxy-authorization | |
| 50 | range | |
| 51 | referer | |
| 52 | refresh | |
| 53 | retry-after | |
| 54 | server | |
| 55 | set-cookie | |
| 56 | strict-transport-security | |
| 57 | transfer-encoding | |
| 58 | user-agent | |
| 59 | vary | |
| 60 | via | |
| 61 | www-authenticate | |
+-------+-----------------------------+---------------+
`
bs := bufio.NewScanner(strings.NewReader(fromSpec))
re := regexp.MustCompile(`\| (\d+)\s+\| (\S+)\s*\| (\S(.*\S)?)?\s+\|`)
for bs.Scan() {
l := bs.Text()
if !strings.Contains(l, "|") {
continue
}
m := re.FindStringSubmatch(l)
if m == nil {
continue
}
i, err := strconv.Atoi(m[1])
if err != nil {
t.Errorf("Bogus integer on line %q", l)
continue
}
if i < 1 || i > len(staticTable) {
t.Errorf("Bogus index %d on line %q", i, l)
continue
}
if got, want := staticTable[i-1].Name, m[2]; got != want {
t.Errorf("header index %d name = %q; want %q", i, got, want)
}
if got, want := staticTable[i-1].Value, m[3]; got != want {
t.Errorf("header index %d value = %q; want %q", i, got, want)
}
}
if err := bs.Err(); err != nil {
t.Error(err)
}
}
func (d *Decoder) mustAt(idx int) HeaderField {
if hf, ok := d.at(uint64(idx)); !ok {
panic(fmt.Sprintf("bogus index %d", idx))
} else {
return hf
}
}
func TestDynamicTableAt(t *testing.T) {
d := NewDecoder(4096, nil)
at := d.mustAt
if got, want := at(2), (pair(":method", "GET")); got != want {
t.Errorf("at(2) = %v; want %v", got, want)
}
d.dynTab.add(pair("foo", "bar"))
d.dynTab.add(pair("blake", "miz"))
if got, want := at(len(staticTable)+1), (pair("blake", "miz")); got != want {
t.Errorf("at(dyn 1) = %v; want %v", got, want)
}
if got, want := at(len(staticTable)+2), (pair("foo", "bar")); got != want {
t.Errorf("at(dyn 2) = %v; want %v", got, want)
}
if got, want := at(3), (pair(":method", "POST")); got != want {
t.Errorf("at(3) = %v; want %v", got, want)
}
}
func TestDynamicTableSearch(t *testing.T) {
dt := dynamicTable{}
dt.setMaxSize(4096)
dt.add(pair("foo", "bar"))
dt.add(pair("blake", "miz"))
dt.add(pair(":method", "GET"))
tests := []struct {
hf HeaderField
wantI uint64
wantMatch bool
}{
// Name and Value match
{pair("foo", "bar"), 3, true},
{pair(":method", "GET"), 1, true},
// Only name match because of Sensitive == true
{HeaderField{"blake", "miz", true}, 2, false},
// Only Name matches
{pair("foo", "..."), 3, false},
{pair("blake", "..."), 2, false},
{pair(":method", "..."), 1, false},
// None match
{pair("foo-", "bar"), 0, false},
}
for _, tt := range tests {
if gotI, gotMatch := dt.search(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch {
t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch)
}
}
}
func TestDynamicTableSizeEvict(t *testing.T) {
d := NewDecoder(4096, nil)
if want := uint32(0); d.dynTab.size != want {
t.Fatalf("size = %d; want %d", d.dynTab.size, want)
}
add := d.dynTab.add
add(pair("blake", "eats pizza"))
if want := uint32(15 + 32); d.dynTab.size != want {
t.Fatalf("after pizza, size = %d; want %d", d.dynTab.size, want)
}
add(pair("foo", "bar"))
if want := uint32(15 + 32 + 6 + 32); d.dynTab.size != want {
t.Fatalf("after foo bar, size = %d; want %d", d.dynTab.size, want)
}
d.dynTab.setMaxSize(15 + 32 + 1 /* slop */)
if want := uint32(6 + 32); d.dynTab.size != want {
t.Fatalf("after setMaxSize, size = %d; want %d", d.dynTab.size, want)
}
if got, want := d.mustAt(len(staticTable)+1), (pair("foo", "bar")); got != want {
t.Errorf("at(dyn 1) = %v; want %v", got, want)
}
add(pair("long", strings.Repeat("x", 500)))
if want := uint32(0); d.dynTab.size != want {
t.Fatalf("after big one, size = %d; want %d", d.dynTab.size, want)
}
}
func TestDecoderDecode(t *testing.T) {
tests := []struct {
name string
in []byte
want []HeaderField
wantDynTab []HeaderField // newest entry first
}{
// C.2.1 Literal Header Field with Indexing
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.1
{"C.2.1", dehex("400a 6375 7374 6f6d 2d6b 6579 0d63 7573 746f 6d2d 6865 6164 6572"),
[]HeaderField{pair("custom-key", "custom-header")},
[]HeaderField{pair("custom-key", "custom-header")},
},
// C.2.2 Literal Header Field without Indexing
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.2
{"C.2.2", dehex("040c 2f73 616d 706c 652f 7061 7468"),
[]HeaderField{pair(":path", "/sample/path")},
[]HeaderField{}},
// C.2.3 Literal Header Field never Indexed
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.3
{"C.2.3", dehex("1008 7061 7373 776f 7264 0673 6563 7265 74"),
[]HeaderField{{"password", "secret", true}},
[]HeaderField{}},
// C.2.4 Indexed Header Field
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.4
{"C.2.4", []byte("\x82"),
[]HeaderField{pair(":method", "GET")},
[]HeaderField{}},
}
for _, tt := range tests {
d := NewDecoder(4096, nil)
hf, err := d.DecodeFull(tt.in)
if err != nil {
t.Errorf("%s: %v", tt.name, err)
continue
}
if !reflect.DeepEqual(hf, tt.want) {
t.Errorf("%s: Got %v; want %v", tt.name, hf, tt.want)
}
gotDynTab := d.dynTab.reverseCopy()
if !reflect.DeepEqual(gotDynTab, tt.wantDynTab) {
t.Errorf("%s: dynamic table after = %v; want %v", tt.name, gotDynTab, tt.wantDynTab)
}
}
}
func (dt *dynamicTable) reverseCopy() (hf []HeaderField) {
hf = make([]HeaderField, len(dt.ents))
for i := range hf {
hf[i] = dt.ents[len(dt.ents)-1-i]
}
return
}
type encAndWant struct {
enc []byte
want []HeaderField
wantDynTab []HeaderField
wantDynSize uint32
}
// C.3 Request Examples without Huffman Coding
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.3
func TestDecodeC3_NoHuffman(t *testing.T) {
testDecodeSeries(t, 4096, []encAndWant{
{dehex("8286 8441 0f77 7777 2e65 7861 6d70 6c65 2e63 6f6d"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
},
[]HeaderField{
pair(":authority", "www.example.com"),
},
57,
},
{dehex("8286 84be 5808 6e6f 2d63 6163 6865"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
pair("cache-control", "no-cache"),
},
[]HeaderField{
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
110,
},
{dehex("8287 85bf 400a 6375 7374 6f6d 2d6b 6579 0c63 7573 746f 6d2d 7661 6c75 65"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "https"),
pair(":path", "/index.html"),
pair(":authority", "www.example.com"),
pair("custom-key", "custom-value"),
},
[]HeaderField{
pair("custom-key", "custom-value"),
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
164,
},
})
}
// C.4 Request Examples with Huffman Coding
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.4
func TestDecodeC4_Huffman(t *testing.T) {
testDecodeSeries(t, 4096, []encAndWant{
{dehex("8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4 ff"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
},
[]HeaderField{
pair(":authority", "www.example.com"),
},
57,
},
{dehex("8286 84be 5886 a8eb 1064 9cbf"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "http"),
pair(":path", "/"),
pair(":authority", "www.example.com"),
pair("cache-control", "no-cache"),
},
[]HeaderField{
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
110,
},
{dehex("8287 85bf 4088 25a8 49e9 5ba9 7d7f 8925 a849 e95b b8e8 b4bf"),
[]HeaderField{
pair(":method", "GET"),
pair(":scheme", "https"),
pair(":path", "/index.html"),
pair(":authority", "www.example.com"),
pair("custom-key", "custom-value"),
},
[]HeaderField{
pair("custom-key", "custom-value"),
pair("cache-control", "no-cache"),
pair(":authority", "www.example.com"),
},
164,
},
})
}
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.5
// "This section shows several consecutive header lists, corresponding
// to HTTP responses, on the same connection. The HTTP/2 setting
// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256
// octets, causing some evictions to occur."
func TestDecodeC5_ResponsesNoHuff(t *testing.T) {
testDecodeSeries(t, 256, []encAndWant{
{dehex(`
4803 3330 3258 0770 7269 7661 7465 611d
4d6f 6e2c 2032 3120 4f63 7420 3230 3133
2032 303a 3133 3a32 3120 474d 546e 1768
7474 7073 3a2f 2f77 7777 2e65 7861 6d70
6c65 2e63 6f6d
`),
[]HeaderField{
pair(":status", "302"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
pair(":status", "302"),
},
222,
},
{dehex("4803 3330 37c1 c0bf"),
[]HeaderField{
pair(":status", "307"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair(":status", "307"),
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
},
222,
},
{dehex(`
88c1 611d 4d6f 6e2c 2032 3120 4f63 7420
3230 3133 2032 303a 3133 3a32 3220 474d
54c0 5a04 677a 6970 7738 666f 6f3d 4153
444a 4b48 514b 425a 584f 5157 454f 5049
5541 5851 5745 4f49 553b 206d 6178 2d61
6765 3d33 3630 303b 2076 6572 7369 6f6e
3d31
`),
[]HeaderField{
pair(":status", "200"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
pair("location", "https://www.example.com"),
pair("content-encoding", "gzip"),
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
},
[]HeaderField{
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
pair("content-encoding", "gzip"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
},
215,
},
})
}
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.6
// "This section shows the same examples as the previous section, but
// using Huffman encoding for the literal values. The HTTP/2 setting
// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256
// octets, causing some evictions to occur. The eviction mechanism
// uses the length of the decoded literal values, so the same
// evictions occurs as in the previous section."
func TestDecodeC6_ResponsesHuffman(t *testing.T) {
testDecodeSeries(t, 256, []encAndWant{
{dehex(`
4882 6402 5885 aec3 771a 4b61 96d0 7abe
9410 54d4 44a8 2005 9504 0b81 66e0 82a6
2d1b ff6e 919d 29ad 1718 63c7 8f0b 97c8
e9ae 82ae 43d3
`),
[]HeaderField{
pair(":status", "302"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
pair(":status", "302"),
},
222,
},
{dehex("4883 640e ffc1 c0bf"),
[]HeaderField{
pair(":status", "307"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("location", "https://www.example.com"),
},
[]HeaderField{
pair(":status", "307"),
pair("location", "https://www.example.com"),
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
pair("cache-control", "private"),
},
222,
},
{dehex(`
88c1 6196 d07a be94 1054 d444 a820 0595
040b 8166 e084 a62d 1bff c05a 839b d9ab
77ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b
3960 d5af 2708 7f36 72c1 ab27 0fb5 291f
9587 3160 65c0 03ed 4ee5 b106 3d50 07
`),
[]HeaderField{
pair(":status", "200"),
pair("cache-control", "private"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
pair("location", "https://www.example.com"),
pair("content-encoding", "gzip"),
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
},
[]HeaderField{
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
pair("content-encoding", "gzip"),
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
},
215,
},
})
}
func testDecodeSeries(t *testing.T, size uint32, steps []encAndWant) {
d := NewDecoder(size, nil)
for i, step := range steps {
hf, err := d.DecodeFull(step.enc)
if err != nil {
t.Fatalf("Error at step index %d: %v", i, err)
}
if !reflect.DeepEqual(hf, step.want) {
t.Fatalf("At step index %d: Got headers %v; want %v", i, hf, step.want)
}
gotDynTab := d.dynTab.reverseCopy()
if !reflect.DeepEqual(gotDynTab, step.wantDynTab) {
t.Errorf("After step index %d, dynamic table = %v; want %v", i, gotDynTab, step.wantDynTab)
}
if d.dynTab.size != step.wantDynSize {
t.Errorf("After step index %d, dynamic table size = %v; want %v", i, d.dynTab.size, step.wantDynSize)
}
}
}
func TestHuffmanDecode(t *testing.T) {
tests := []struct {
inHex, want string
}{
{"f1e3 c2e5 f23a 6ba0 ab90 f4ff", "www.example.com"},
{"a8eb 1064 9cbf", "no-cache"},
{"25a8 49e9 5ba9 7d7f", "custom-key"},
{"25a8 49e9 5bb8 e8b4 bf", "custom-value"},
{"6402", "302"},
{"aec3 771a 4b", "private"},
{"d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff", "Mon, 21 Oct 2013 20:13:21 GMT"},
{"9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3", "https://www.example.com"},
{"9bd9 ab", "gzip"},
{"94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07",
"foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"},
}
for i, tt := range tests {
var buf bytes.Buffer
in, err := hex.DecodeString(strings.Replace(tt.inHex, " ", "", -1))
if err != nil {
t.Errorf("%d. hex input error: %v", i, err)
continue
}
if _, err := HuffmanDecode(&buf, in); err != nil {
t.Errorf("%d. decode error: %v", i, err)
continue
}
if got := buf.String(); tt.want != got {
t.Errorf("%d. decode = %q; want %q", i, got, tt.want)
}
}
}
func TestAppendHuffmanString(t *testing.T) {
tests := []struct {
in, want string
}{
{"www.example.com", "f1e3 c2e5 f23a 6ba0 ab90 f4ff"},
{"no-cache", "a8eb 1064 9cbf"},
{"custom-key", "25a8 49e9 5ba9 7d7f"},
{"custom-value", "25a8 49e9 5bb8 e8b4 bf"},
{"302", "6402"},
{"private", "aec3 771a 4b"},
{"Mon, 21 Oct 2013 20:13:21 GMT", "d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff"},
{"https://www.example.com", "9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3"},
{"gzip", "9bd9 ab"},
{"foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1",
"94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07"},
}
for i, tt := range tests {
buf := []byte{}
want := strings.Replace(tt.want, " ", "", -1)
buf = AppendHuffmanString(buf, tt.in)
if got := hex.EncodeToString(buf); want != got {
t.Errorf("%d. encode = %q; want %q", i, got, want)
}
}
}
func TestReadVarInt(t *testing.T) {
type res struct {
i uint64
consumed int
err error
}
tests := []struct {
n byte
p []byte
want res
}{
// Fits in a byte:
{1, []byte{0}, res{0, 1, nil}},
{2, []byte{2}, res{2, 1, nil}},
{3, []byte{6}, res{6, 1, nil}},
{4, []byte{14}, res{14, 1, nil}},
{5, []byte{30}, res{30, 1, nil}},
{6, []byte{62}, res{62, 1, nil}},
{7, []byte{126}, res{126, 1, nil}},
{8, []byte{254}, res{254, 1, nil}},
// Doesn't fit in a byte:
{1, []byte{1}, res{0, 0, errNeedMore}},
{2, []byte{3}, res{0, 0, errNeedMore}},
{3, []byte{7}, res{0, 0, errNeedMore}},
{4, []byte{15}, res{0, 0, errNeedMore}},
{5, []byte{31}, res{0, 0, errNeedMore}},
{6, []byte{63}, res{0, 0, errNeedMore}},
{7, []byte{127}, res{0, 0, errNeedMore}},
{8, []byte{255}, res{0, 0, errNeedMore}},
// Ignoring top bits:
{5, []byte{255, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 111
{5, []byte{159, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 100
{5, []byte{191, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 101
// Extra byte:
{5, []byte{191, 154, 10, 2}, res{1337, 3, nil}}, // extra byte
// Short a byte:
{5, []byte{191, 154}, res{0, 0, errNeedMore}},
// integer overflow:
{1, []byte{255, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}, res{0, 0, errVarintOverflow}},
}
for _, tt := range tests {
i, remain, err := readVarInt(tt.n, tt.p)
consumed := len(tt.p) - len(remain)
got := res{i, consumed, err}
if got != tt.want {
t.Errorf("readVarInt(%d, %v ~ %x) = %+v; want %+v", tt.n, tt.p, tt.p, got, tt.want)
}
}
}
func dehex(s string) []byte {
s = strings.Replace(s, " ", "", -1)
s = strings.Replace(s, "\n", "", -1)
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}

@ -0,0 +1,159 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
import (
"bytes"
"io"
"sync"
)
var bufPool = sync.Pool{
New: func() interface{} { return new(bytes.Buffer) },
}
// HuffmanDecode decodes the string in v and writes the expanded
// result to w, returning the number of bytes written to w and the
// Write call's return value. At most one Write call is made.
func HuffmanDecode(w io.Writer, v []byte) (int, error) {
buf := bufPool.Get().(*bytes.Buffer)
buf.Reset()
defer bufPool.Put(buf)
n := rootHuffmanNode
cur, nbits := uint(0), uint8(0)
for _, b := range v {
cur = cur<<8 | uint(b)
nbits += 8
for nbits >= 8 {
n = n.children[byte(cur>>(nbits-8))]
if n.children == nil {
buf.WriteByte(n.sym)
nbits -= n.codeLen
n = rootHuffmanNode
} else {
nbits -= 8
}
}
}
for nbits > 0 {
n = n.children[byte(cur<<(8-nbits))]
if n.children != nil || n.codeLen > nbits {
break
}
buf.WriteByte(n.sym)
nbits -= n.codeLen
n = rootHuffmanNode
}
return w.Write(buf.Bytes())
}
type node struct {
// children is non-nil for internal nodes
children []*node
// The following are only valid if children is nil:
codeLen uint8 // number of bits that led to the output of sym
sym byte // output symbol
}
func newInternalNode() *node {
return &node{children: make([]*node, 256)}
}
var rootHuffmanNode = newInternalNode()
func init() {
for i, code := range huffmanCodes {
if i > 255 {
panic("too many huffman codes")
}
addDecoderNode(byte(i), code, huffmanCodeLen[i])
}
}
func addDecoderNode(sym byte, code uint32, codeLen uint8) {
cur := rootHuffmanNode
for codeLen > 8 {
codeLen -= 8
i := uint8(code >> codeLen)
if cur.children[i] == nil {
cur.children[i] = newInternalNode()
}
cur = cur.children[i]
}
shift := 8 - codeLen
start, end := int(uint8(code<<shift)), int(1<<shift)
for i := start; i < start+end; i++ {
cur.children[i] = &node{sym: sym, codeLen: codeLen}
}
}
// AppendHuffmanString appends s, as encoded in Huffman codes, to dst
// and returns the extended buffer.
func AppendHuffmanString(dst []byte, s string) []byte {
rembits := uint8(8)
for i := 0; i < len(s); i++ {
if rembits == 8 {
dst = append(dst, 0)
}
dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i])
}
if rembits < 8 {
// special EOS symbol
code := uint32(0x3fffffff)
nbits := uint8(30)
t := uint8(code >> (nbits - rembits))
dst[len(dst)-1] |= t
}
return dst
}
// HuffmanEncodeLength returns the number of bytes required to encode
// s in Huffman codes. The result is round up to byte boundary.
func HuffmanEncodeLength(s string) uint64 {
n := uint64(0)
for i := 0; i < len(s); i++ {
n += uint64(huffmanCodeLen[s[i]])
}
return (n + 7) / 8
}
// appendByteToHuffmanCode appends Huffman code for c to dst and
// returns the extended buffer and the remaining bits in the last
// element. The appending is not byte aligned and the remaining bits
// in the last element of dst is given in rembits.
func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) {
code := huffmanCodes[c]
nbits := huffmanCodeLen[c]
for {
if rembits > nbits {
t := uint8(code << (rembits - nbits))
dst[len(dst)-1] |= t
rembits -= nbits
break
}
t := uint8(code >> (nbits - rembits))
dst[len(dst)-1] |= t
nbits -= rembits
rembits = 8
if nbits == 0 {
break
}
dst = append(dst, 0)
}
return dst, rembits
}

@ -0,0 +1,353 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package hpack
func pair(name, value string) HeaderField {
return HeaderField{Name: name, Value: value}
}
// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-07#appendix-B
var staticTable = []HeaderField{
pair(":authority", ""), // index 1 (1-based)
pair(":method", "GET"),
pair(":method", "POST"),
pair(":path", "/"),
pair(":path", "/index.html"),
pair(":scheme", "http"),
pair(":scheme", "https"),
pair(":status", "200"),
pair(":status", "204"),
pair(":status", "206"),
pair(":status", "304"),
pair(":status", "400"),
pair(":status", "404"),
pair(":status", "500"),
pair("accept-charset", ""),
pair("accept-encoding", "gzip, deflate"),
pair("accept-language", ""),
pair("accept-ranges", ""),
pair("accept", ""),
pair("access-control-allow-origin", ""),
pair("age", ""),
pair("allow", ""),
pair("authorization", ""),
pair("cache-control", ""),
pair("content-disposition", ""),
pair("content-encoding", ""),
pair("content-language", ""),
pair("content-length", ""),
pair("content-location", ""),
pair("content-range", ""),
pair("content-type", ""),
pair("cookie", ""),
pair("date", ""),
pair("etag", ""),
pair("expect", ""),
pair("expires", ""),
pair("from", ""),
pair("host", ""),
pair("if-match", ""),
pair("if-modified-since", ""),
pair("if-none-match", ""),
pair("if-range", ""),
pair("if-unmodified-since", ""),
pair("last-modified", ""),
pair("link", ""),
pair("location", ""),
pair("max-forwards", ""),
pair("proxy-authenticate", ""),
pair("proxy-authorization", ""),
pair("range", ""),
pair("referer", ""),
pair("refresh", ""),
pair("retry-after", ""),
pair("server", ""),
pair("set-cookie", ""),
pair("strict-transport-security", ""),
pair("transfer-encoding", ""),
pair("user-agent", ""),
pair("vary", ""),
pair("via", ""),
pair("www-authenticate", ""),
}
var huffmanCodes = []uint32{
0x1ff8,
0x7fffd8,
0xfffffe2,
0xfffffe3,
0xfffffe4,
0xfffffe5,
0xfffffe6,
0xfffffe7,
0xfffffe8,
0xffffea,
0x3ffffffc,
0xfffffe9,
0xfffffea,
0x3ffffffd,
0xfffffeb,
0xfffffec,
0xfffffed,
0xfffffee,
0xfffffef,
0xffffff0,
0xffffff1,
0xffffff2,
0x3ffffffe,
0xffffff3,
0xffffff4,
0xffffff5,
0xffffff6,
0xffffff7,
0xffffff8,
0xffffff9,
0xffffffa,
0xffffffb,
0x14,
0x3f8,
0x3f9,
0xffa,
0x1ff9,
0x15,
0xf8,
0x7fa,
0x3fa,
0x3fb,
0xf9,
0x7fb,
0xfa,
0x16,
0x17,
0x18,
0x0,
0x1,
0x2,
0x19,
0x1a,
0x1b,
0x1c,
0x1d,
0x1e,
0x1f,
0x5c,
0xfb,
0x7ffc,
0x20,
0xffb,
0x3fc,
0x1ffa,
0x21,
0x5d,
0x5e,
0x5f,
0x60,
0x61,
0x62,
0x63,
0x64,
0x65,
0x66,
0x67,
0x68,
0x69,
0x6a,
0x6b,
0x6c,
0x6d,
0x6e,
0x6f,
0x70,
0x71,
0x72,
0xfc,
0x73,
0xfd,
0x1ffb,
0x7fff0,
0x1ffc,
0x3ffc,
0x22,
0x7ffd,
0x3,
0x23,
0x4,
0x24,
0x5,
0x25,
0x26,
0x27,
0x6,
0x74,
0x75,
0x28,
0x29,
0x2a,
0x7,
0x2b,
0x76,
0x2c,
0x8,
0x9,
0x2d,
0x77,
0x78,
0x79,
0x7a,
0x7b,
0x7ffe,
0x7fc,
0x3ffd,
0x1ffd,
0xffffffc,
0xfffe6,
0x3fffd2,
0xfffe7,
0xfffe8,
0x3fffd3,
0x3fffd4,
0x3fffd5,
0x7fffd9,
0x3fffd6,
0x7fffda,
0x7fffdb,
0x7fffdc,
0x7fffdd,
0x7fffde,
0xffffeb,
0x7fffdf,
0xffffec,
0xffffed,
0x3fffd7,
0x7fffe0,
0xffffee,
0x7fffe1,
0x7fffe2,
0x7fffe3,
0x7fffe4,
0x1fffdc,
0x3fffd8,
0x7fffe5,
0x3fffd9,
0x7fffe6,
0x7fffe7,
0xffffef,
0x3fffda,
0x1fffdd,
0xfffe9,
0x3fffdb,
0x3fffdc,
0x7fffe8,
0x7fffe9,
0x1fffde,
0x7fffea,
0x3fffdd,
0x3fffde,
0xfffff0,
0x1fffdf,
0x3fffdf,
0x7fffeb,
0x7fffec,
0x1fffe0,
0x1fffe1,
0x3fffe0,
0x1fffe2,
0x7fffed,
0x3fffe1,
0x7fffee,
0x7fffef,
0xfffea,
0x3fffe2,
0x3fffe3,
0x3fffe4,
0x7ffff0,
0x3fffe5,
0x3fffe6,
0x7ffff1,
0x3ffffe0,
0x3ffffe1,
0xfffeb,
0x7fff1,
0x3fffe7,
0x7ffff2,
0x3fffe8,
0x1ffffec,
0x3ffffe2,
0x3ffffe3,
0x3ffffe4,
0x7ffffde,
0x7ffffdf,
0x3ffffe5,
0xfffff1,
0x1ffffed,
0x7fff2,
0x1fffe3,
0x3ffffe6,
0x7ffffe0,
0x7ffffe1,
0x3ffffe7,
0x7ffffe2,
0xfffff2,
0x1fffe4,
0x1fffe5,
0x3ffffe8,
0x3ffffe9,
0xffffffd,
0x7ffffe3,
0x7ffffe4,
0x7ffffe5,
0xfffec,
0xfffff3,
0xfffed,
0x1fffe6,
0x3fffe9,
0x1fffe7,
0x1fffe8,
0x7ffff3,
0x3fffea,
0x3fffeb,
0x1ffffee,
0x1ffffef,
0xfffff4,
0xfffff5,
0x3ffffea,
0x7ffff4,
0x3ffffeb,
0x7ffffe6,
0x3ffffec,
0x3ffffed,
0x7ffffe7,
0x7ffffe8,
0x7ffffe9,
0x7ffffea,
0x7ffffeb,
0xffffffe,
0x7ffffec,
0x7ffffed,
0x7ffffee,
0x7ffffef,
0x7fffff0,
0x3ffffee,
}
var huffmanCodeLen = []uint8{
13, 23, 28, 28, 28, 28, 28, 28, 28, 24, 30, 28, 28, 30, 28, 28,
28, 28, 28, 28, 28, 28, 30, 28, 28, 28, 28, 28, 28, 28, 28, 28,
6, 10, 10, 12, 13, 6, 8, 11, 10, 10, 8, 11, 8, 6, 6, 6,
5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 8, 15, 6, 12, 10,
13, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 8, 7, 8, 13, 19, 13, 14, 6,
15, 5, 6, 5, 6, 5, 6, 6, 6, 5, 7, 7, 6, 6, 6, 5,
6, 7, 6, 5, 5, 6, 7, 7, 7, 7, 7, 15, 11, 14, 13, 28,
20, 22, 20, 20, 22, 22, 22, 23, 22, 23, 23, 23, 23, 23, 24, 23,
24, 24, 22, 23, 24, 23, 23, 23, 23, 21, 22, 23, 22, 23, 23, 24,
22, 21, 20, 22, 22, 23, 23, 21, 23, 22, 22, 24, 21, 22, 23, 23,
21, 21, 22, 21, 23, 22, 23, 23, 20, 22, 22, 22, 23, 22, 22, 23,
26, 26, 20, 19, 22, 23, 22, 25, 26, 26, 26, 27, 27, 26, 24, 25,
19, 21, 26, 27, 27, 26, 27, 24, 21, 21, 26, 26, 28, 27, 27, 27,
20, 24, 20, 21, 22, 21, 21, 23, 22, 22, 25, 25, 24, 24, 26, 23,
26, 27, 26, 26, 27, 27, 27, 27, 27, 28, 27, 27, 27, 27, 27, 26,
}

@ -0,0 +1,249 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
// Package http2 implements the HTTP/2 protocol.
//
// This is a work in progress. This package is low-level and intended
// to be used directly by very few people. Most users will use it
// indirectly through integration with the net/http package. See
// ConfigureServer. That ConfigureServer call will likely be automatic
// or available via an empty import in the future.
//
// See http://http2.github.io/
package http2
import (
"bufio"
"fmt"
"io"
"net/http"
"strconv"
"sync"
)
var VerboseLogs = false
const (
// ClientPreface is the string that must be sent by new
// connections from clients.
ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
// SETTINGS_MAX_FRAME_SIZE default
// http://http2.github.io/http2-spec/#rfc.section.6.5.2
initialMaxFrameSize = 16384
// NextProtoTLS is the NPN/ALPN protocol negotiated during
// HTTP/2's TLS setup.
NextProtoTLS = "h2"
// http://http2.github.io/http2-spec/#SettingValues
initialHeaderTableSize = 4096
initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
defaultMaxReadFrameSize = 1 << 20
)
var (
clientPreface = []byte(ClientPreface)
)
type streamState int
const (
stateIdle streamState = iota
stateOpen
stateHalfClosedLocal
stateHalfClosedRemote
stateResvLocal
stateResvRemote
stateClosed
)
var stateName = [...]string{
stateIdle: "Idle",
stateOpen: "Open",
stateHalfClosedLocal: "HalfClosedLocal",
stateHalfClosedRemote: "HalfClosedRemote",
stateResvLocal: "ResvLocal",
stateResvRemote: "ResvRemote",
stateClosed: "Closed",
}
func (st streamState) String() string {
return stateName[st]
}
// Setting is a setting parameter: which setting it is, and its value.
type Setting struct {
// ID is which setting is being set.
// See http://http2.github.io/http2-spec/#SettingValues
ID SettingID
// Val is the value.
Val uint32
}
func (s Setting) String() string {
return fmt.Sprintf("[%v = %d]", s.ID, s.Val)
}
// Valid reports whether the setting is valid.
func (s Setting) Valid() error {
// Limits and error codes from 6.5.2 Defined SETTINGS Parameters
switch s.ID {
case SettingEnablePush:
if s.Val != 1 && s.Val != 0 {
return ConnectionError(ErrCodeProtocol)
}
case SettingInitialWindowSize:
if s.Val > 1<<31-1 {
return ConnectionError(ErrCodeFlowControl)
}
case SettingMaxFrameSize:
if s.Val < 16384 || s.Val > 1<<24-1 {
return ConnectionError(ErrCodeProtocol)
}
}
return nil
}
// A SettingID is an HTTP/2 setting as defined in
// http://http2.github.io/http2-spec/#iana-settings
type SettingID uint16
const (
SettingHeaderTableSize SettingID = 0x1
SettingEnablePush SettingID = 0x2
SettingMaxConcurrentStreams SettingID = 0x3
SettingInitialWindowSize SettingID = 0x4
SettingMaxFrameSize SettingID = 0x5
SettingMaxHeaderListSize SettingID = 0x6
)
var settingName = map[SettingID]string{
SettingHeaderTableSize: "HEADER_TABLE_SIZE",
SettingEnablePush: "ENABLE_PUSH",
SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
SettingMaxFrameSize: "MAX_FRAME_SIZE",
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
}
func (s SettingID) String() string {
if v, ok := settingName[s]; ok {
return v
}
return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
}
func validHeader(v string) bool {
if len(v) == 0 {
return false
}
for _, r := range v {
// "Just as in HTTP/1.x, header field names are
// strings of ASCII characters that are compared in a
// case-insensitive fashion. However, header field
// names MUST be converted to lowercase prior to their
// encoding in HTTP/2. "
if r >= 127 || ('A' <= r && r <= 'Z') {
return false
}
}
return true
}
var httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n)
func init() {
for i := 100; i <= 999; i++ {
if v := http.StatusText(i); v != "" {
httpCodeStringCommon[i] = strconv.Itoa(i)
}
}
}
func httpCodeString(code int) string {
if s, ok := httpCodeStringCommon[code]; ok {
return s
}
return strconv.Itoa(code)
}
// from pkg io
type stringWriter interface {
WriteString(s string) (n int, err error)
}
// A gate lets two goroutines coordinate their activities.
type gate chan struct{}
func (g gate) Done() { g <- struct{}{} }
func (g gate) Wait() { <-g }
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
type closeWaiter chan struct{}
// Init makes a closeWaiter usable.
// It exists because so a closeWaiter value can be placed inside a
// larger struct and have the Mutex and Cond's memory in the same
// allocation.
func (cw *closeWaiter) Init() {
*cw = make(chan struct{})
}
// Close marks the closeWaiter as closed and unblocks any waiters.
func (cw closeWaiter) Close() {
close(cw)
}
// Wait waits for the closeWaiter to become closed.
func (cw closeWaiter) Wait() {
<-cw
}
// bufferedWriter is a buffered writer that writes to w.
// Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections.
type bufferedWriter struct {
w io.Writer // immutable
bw *bufio.Writer // non-nil when data is buffered
}
func newBufferedWriter(w io.Writer) *bufferedWriter {
return &bufferedWriter{w: w}
}
var bufWriterPool = sync.Pool{
New: func() interface{} {
// TODO: pick something better? this is a bit under
// (3 x typical 1500 byte MTU) at least.
return bufio.NewWriterSize(nil, 4<<10)
},
}
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w)
w.bw = bw
}
return w.bw.Write(p)
}
func (w *bufferedWriter) Flush() error {
bw := w.bw
if bw == nil {
return nil
}
err := bw.Flush()
bw.Reset(nil)
bufWriterPool.Put(bw)
w.bw = nil
return err
}

@ -0,0 +1,152 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"bytes"
"errors"
"flag"
"fmt"
"net/http"
"os/exec"
"strconv"
"strings"
"testing"
"github.com/bradfitz/http2/hpack"
)
var knownFailing = flag.Bool("known_failing", false, "Run known-failing tests.")
func condSkipFailingTest(t *testing.T) {
if !*knownFailing {
t.Skip("Skipping known-failing test without --known_failing")
}
}
func init() {
DebugGoroutines = true
flag.BoolVar(&VerboseLogs, "verboseh2", false, "Verbose HTTP/2 debug logging")
}
func TestSettingString(t *testing.T) {
tests := []struct {
s Setting
want string
}{
{Setting{SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"},
{Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"},
}
for i, tt := range tests {
got := fmt.Sprint(tt.s)
if got != tt.want {
t.Errorf("%d. for %#v, string = %q; want %q", i, tt.s, got, tt.want)
}
}
}
type twriter struct {
t testing.TB
st *serverTester // optional
}
func (w twriter) Write(p []byte) (n int, err error) {
if w.st != nil {
ps := string(p)
for _, phrase := range w.st.logFilter {
if strings.Contains(ps, phrase) {
return len(p), nil // no logging
}
}
}
w.t.Logf("%s", p)
return len(p), nil
}
// like encodeHeader, but don't add implicit psuedo headers.
func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte {
var buf bytes.Buffer
enc := hpack.NewEncoder(&buf)
for len(headers) > 0 {
k, v := headers[0], headers[1]
headers = headers[2:]
if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
}
}
return buf.Bytes()
}
// Verify that curl has http2.
func requireCurl(t *testing.T) {
out, err := dockerLogs(curl(t, "--version"))
if err != nil {
t.Skipf("failed to determine curl features; skipping test")
}
if !strings.Contains(string(out), "HTTP2") {
t.Skip("curl doesn't support HTTP2; skipping test")
}
}
func curl(t *testing.T, args ...string) (container string) {
out, err := exec.Command("docker", append([]string{"run", "-d", "--net=host", "gohttp2/curl"}, args...)...).CombinedOutput()
if err != nil {
t.Skipf("Failed to run curl in docker: %v, %s", err, out)
}
return strings.TrimSpace(string(out))
}
type puppetCommand struct {
fn func(w http.ResponseWriter, r *http.Request)
done chan<- bool
}
type handlerPuppet struct {
ch chan puppetCommand
}
func newHandlerPuppet() *handlerPuppet {
return &handlerPuppet{
ch: make(chan puppetCommand),
}
}
func (p *handlerPuppet) act(w http.ResponseWriter, r *http.Request) {
for cmd := range p.ch {
cmd.fn(w, r)
cmd.done <- true
}
}
func (p *handlerPuppet) done() { close(p.ch) }
func (p *handlerPuppet) do(fn func(http.ResponseWriter, *http.Request)) {
done := make(chan bool)
p.ch <- puppetCommand{fn, done}
<-done
}
func dockerLogs(container string) ([]byte, error) {
out, err := exec.Command("docker", "wait", container).CombinedOutput()
if err != nil {
return out, err
}
exitStatus, err := strconv.Atoi(strings.TrimSpace(string(out)))
if err != nil {
return out, errors.New("unexpected exit status from docker wait")
}
out, err = exec.Command("docker", "logs", container).CombinedOutput()
exec.Command("docker", "rm", container).Run()
if err == nil && exitStatus != 0 {
err = fmt.Errorf("exit status %d: %s", exitStatus, out)
}
return out, err
}
func kill(container string) {
exec.Command("docker", "kill", container).Run()
exec.Command("docker", "rm", container).Run()
}

@ -0,0 +1,43 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"sync"
)
type pipe struct {
b buffer
c sync.Cond
m sync.Mutex
}
// Read waits until data is available and copies bytes
// from the buffer into p.
func (r *pipe) Read(p []byte) (n int, err error) {
r.c.L.Lock()
defer r.c.L.Unlock()
for r.b.Len() == 0 && !r.b.closed {
r.c.Wait()
}
return r.b.Read(p)
}
// Write copies bytes from p into the buffer and wakes a reader.
// It is an error to write more data than the buffer can hold.
func (w *pipe) Write(p []byte) (n int, err error) {
w.c.L.Lock()
defer w.c.L.Unlock()
defer w.c.Signal()
return w.b.Write(p)
}
func (c *pipe) Close(err error) {
c.c.L.Lock()
defer c.c.L.Unlock()
defer c.c.Signal()
c.b.Close(err)
}

@ -0,0 +1,24 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"errors"
"testing"
)
func TestPipeClose(t *testing.T) {
var p pipe
p.c.L = &p.m
a := errors.New("a")
b := errors.New("b")
p.Close(a)
p.Close(b)
_, err := p.Read(make([]byte, 1))
if err != a {
t.Errorf("err = %v want %v", err, a)
}
}

@ -0,0 +1,121 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"testing"
)
func TestPriority(t *testing.T) {
// A -> B
// move A's parent to B
streams := make(map[uint32]*stream)
a := &stream{
parent: nil,
weight: 16,
}
streams[1] = a
b := &stream{
parent: a,
weight: 16,
}
streams[2] = b
adjustStreamPriority(streams, 1, PriorityParam{
Weight: 20,
StreamDep: 2,
})
if a.parent != b {
t.Errorf("Expected A's parent to be B")
}
if a.weight != 20 {
t.Errorf("Expected A's weight to be 20; got %d", a.weight)
}
if b.parent != nil {
t.Errorf("Expected B to have no parent")
}
if b.weight != 16 {
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
}
}
func TestPriorityExclusiveZero(t *testing.T) {
// A B and C are all children of the 0 stream.
// Exclusive reprioritization to any of the streams
// should bring the rest of the streams under the
// reprioritized stream
streams := make(map[uint32]*stream)
a := &stream{
parent: nil,
weight: 16,
}
streams[1] = a
b := &stream{
parent: nil,
weight: 16,
}
streams[2] = b
c := &stream{
parent: nil,
weight: 16,
}
streams[3] = c
adjustStreamPriority(streams, 3, PriorityParam{
Weight: 20,
StreamDep: 0,
Exclusive: true,
})
if a.parent != c {
t.Errorf("Expected A's parent to be C")
}
if a.weight != 16 {
t.Errorf("Expected A's weight to be 16; got %d", a.weight)
}
if b.parent != c {
t.Errorf("Expected B's parent to be C")
}
if b.weight != 16 {
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
}
if c.parent != nil {
t.Errorf("Expected C to have no parent")
}
if c.weight != 20 {
t.Errorf("Expected C's weight to be 20; got %d", b.weight)
}
}
func TestPriorityOwnParent(t *testing.T) {
streams := make(map[uint32]*stream)
a := &stream{
parent: nil,
weight: 16,
}
streams[1] = a
b := &stream{
parent: a,
weight: 16,
}
streams[2] = b
adjustStreamPriority(streams, 1, PriorityParam{
Weight: 20,
StreamDep: 1,
})
if a.parent != nil {
t.Errorf("Expected A's parent to be nil")
}
if a.weight != 20 {
t.Errorf("Expected A's weight to be 20; got %d", a.weight)
}
if b.parent != a {
t.Errorf("Expected B's parent to be A")
}
if b.weight != 16 {
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,553 @@
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE
package http2
import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"github.com/bradfitz/http2/hpack"
)
type Transport struct {
Fallback http.RoundTripper
// TODO: remove this and make more general with a TLS dial hook, like http
InsecureTLSDial bool
connMu sync.Mutex
conns map[string][]*clientConn // key is host:port
}
type clientConn struct {
t *Transport
tconn *tls.Conn
tlsState *tls.ConnectionState
connKey []string // key(s) this connection is cached in, in t.conns
readerDone chan struct{} // closed on error
readerErr error // set before readerDone is closed
hdec *hpack.Decoder
nextRes *http.Response
mu sync.Mutex
closed bool
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
streams map[uint32]*clientStream
nextStreamID uint32
bw *bufio.Writer
werr error // first write error that has occurred
br *bufio.Reader
fr *Framer
// Settings from peer:
maxFrameSize uint32
maxConcurrentStreams uint32
initialWindowSize uint32
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
}
type clientStream struct {
ID uint32
resc chan resAndError
pw *io.PipeWriter
pr *io.PipeReader
}
type stickyErrWriter struct {
w io.Writer
err *error
}
func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
if *sew.err != nil {
return 0, *sew.err
}
n, err = sew.w.Write(p)
*sew.err = err
return
}
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "https" {
if t.Fallback == nil {
return nil, errors.New("http2: unsupported scheme and no Fallback")
}
return t.Fallback.RoundTrip(req)
}
host, port, err := net.SplitHostPort(req.URL.Host)
if err != nil {
host = req.URL.Host
port = "443"
}
for {
cc, err := t.getClientConn(host, port)
if err != nil {
return nil, err
}
res, err := cc.roundTrip(req)
if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
continue
}
if err != nil {
return nil, err
}
return res, nil
}
}
// CloseIdleConnections closes any connections which were previously
// connected from previous requests but are now sitting idle.
// It does not interrupt any connections currently in use.
func (t *Transport) CloseIdleConnections() {
t.connMu.Lock()
defer t.connMu.Unlock()
for _, vv := range t.conns {
for _, cc := range vv {
cc.closeIfIdle()
}
}
}
var errClientConnClosed = errors.New("http2: client conn is closed")
func shouldRetryRequest(err error) bool {
// TODO: or GOAWAY graceful shutdown stuff
return err == errClientConnClosed
}
func (t *Transport) removeClientConn(cc *clientConn) {
t.connMu.Lock()
defer t.connMu.Unlock()
for _, key := range cc.connKey {
vv, ok := t.conns[key]
if !ok {
continue
}
newList := filterOutClientConn(vv, cc)
if len(newList) > 0 {
t.conns[key] = newList
} else {
delete(t.conns, key)
}
}
}
func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
out := in[:0]
for _, v := range in {
if v != exclude {
out = append(out, v)
}
}
return out
}
func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
t.connMu.Lock()
defer t.connMu.Unlock()
key := net.JoinHostPort(host, port)
for _, cc := range t.conns[key] {
if cc.canTakeNewRequest() {
return cc, nil
}
}
if t.conns == nil {
t.conns = make(map[string][]*clientConn)
}
cc, err := t.newClientConn(host, port, key)
if err != nil {
return nil, err
}
t.conns[key] = append(t.conns[key], cc)
return cc, nil
}
func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
cfg := &tls.Config{
ServerName: host,
NextProtos: []string{NextProtoTLS},
InsecureSkipVerify: t.InsecureTLSDial,
}
tconn, err := tls.Dial("tcp", host+":"+port, cfg)
if err != nil {
return nil, err
}
if err := tconn.Handshake(); err != nil {
return nil, err
}
if !t.InsecureTLSDial {
if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
return nil, err
}
}
state := tconn.ConnectionState()
if p := state.NegotiatedProtocol; p != NextProtoTLS {
// TODO(bradfitz): fall back to Fallback
return nil, fmt.Errorf("bad protocol: %v", p)
}
if !state.NegotiatedProtocolIsMutual {
return nil, errors.New("could not negotiate protocol mutually")
}
if _, err := tconn.Write(clientPreface); err != nil {
return nil, err
}
cc := &clientConn{
t: t,
tconn: tconn,
connKey: []string{key}, // TODO: cert's validated hostnames too
tlsState: &state,
readerDone: make(chan struct{}),
nextStreamID: 1,
maxFrameSize: 16 << 10, // spec default
initialWindowSize: 65535, // spec default
maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
streams: make(map[uint32]*clientStream),
}
cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
cc.br = bufio.NewReader(tconn)
cc.fr = NewFramer(cc.bw, cc.br)
cc.henc = hpack.NewEncoder(&cc.hbuf)
cc.fr.WriteSettings()
// TODO: re-send more conn-level flow control tokens when server uses all these.
cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
cc.bw.Flush()
if cc.werr != nil {
return nil, cc.werr
}
// Read the obligatory SETTINGS frame
f, err := cc.fr.ReadFrame()
if err != nil {
return nil, err
}
sf, ok := f.(*SettingsFrame)
if !ok {
return nil, fmt.Errorf("expected settings frame, got: %T", f)
}
cc.fr.WriteSettingsAck()
cc.bw.Flush()
sf.ForeachSetting(func(s Setting) error {
switch s.ID {
case SettingMaxFrameSize:
cc.maxFrameSize = s.Val
case SettingMaxConcurrentStreams:
cc.maxConcurrentStreams = s.Val
case SettingInitialWindowSize:
cc.initialWindowSize = s.Val
default:
// TODO(bradfitz): handle more
log.Printf("Unhandled Setting: %v", s)
}
return nil
})
// TODO: figure out henc size
cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
go cc.readLoop()
return cc, nil
}
func (cc *clientConn) setGoAway(f *GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.goAway = f
}
func (cc *clientConn) canTakeNewRequest() bool {
cc.mu.Lock()
defer cc.mu.Unlock()
return cc.goAway == nil &&
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
cc.nextStreamID < 2147483647
}
func (cc *clientConn) closeIfIdle() {
cc.mu.Lock()
if len(cc.streams) > 0 {
cc.mu.Unlock()
return
}
cc.closed = true
// TODO: do clients send GOAWAY too? maybe? Just Close:
cc.mu.Unlock()
cc.tconn.Close()
}
func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
cc.mu.Lock()
if cc.closed {
cc.mu.Unlock()
return nil, errClientConnClosed
}
cs := cc.newStream()
hasBody := false // TODO
// we send: HEADERS[+CONTINUATION] + (DATA?)
hdrs := cc.encodeHeaders(req)
first := true
for len(hdrs) > 0 {
chunk := hdrs
if len(chunk) > int(cc.maxFrameSize) {
chunk = chunk[:cc.maxFrameSize]
}
hdrs = hdrs[len(chunk):]
endHeaders := len(hdrs) == 0
if first {
cc.fr.WriteHeaders(HeadersFrameParam{
StreamID: cs.ID,
BlockFragment: chunk,
EndStream: !hasBody,
EndHeaders: endHeaders,
})
first = false
} else {
cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
}
}
cc.bw.Flush()
werr := cc.werr
cc.mu.Unlock()
if hasBody {
// TODO: write data. and it should probably be interleaved:
// go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
}
if werr != nil {
return nil, werr
}
re := <-cs.resc
if re.err != nil {
return nil, re.err
}
res := re.res
res.Request = req
res.TLS = cc.tlsState
return res, nil
}
// requires cc.mu be held.
func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
cc.hbuf.Reset()
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
host := req.Host
if host == "" {
host = req.URL.Host
}
path := req.URL.Path
if path == "" {
path = "/"
}
cc.writeHeader(":authority", host) // probably not right for all sites
cc.writeHeader(":method", req.Method)
cc.writeHeader(":path", path)
cc.writeHeader(":scheme", "https")
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
if lowKey == "host" {
continue
}
for _, v := range vv {
cc.writeHeader(lowKey, v)
}
}
return cc.hbuf.Bytes()
}
func (cc *clientConn) writeHeader(name, value string) {
log.Printf("sending %q = %q", name, value)
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
type resAndError struct {
res *http.Response
err error
}
// requires cc.mu be held.
func (cc *clientConn) newStream() *clientStream {
cs := &clientStream{
ID: cc.nextStreamID,
resc: make(chan resAndError, 1),
}
cc.nextStreamID += 2
cc.streams[cs.ID] = cs
return cs
}
func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
cc.mu.Lock()
defer cc.mu.Unlock()
cs := cc.streams[id]
if andRemove {
delete(cc.streams, id)
}
return cs
}
// runs in its own goroutine.
func (cc *clientConn) readLoop() {
defer cc.t.removeClientConn(cc)
defer close(cc.readerDone)
activeRes := map[uint32]*clientStream{} // keyed by streamID
// Close any response bodies if the server closes prematurely.
// TODO: also do this if we've written the headers but not
// gotten a response yet.
defer func() {
err := cc.readerErr
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
for _, cs := range activeRes {
cs.pw.CloseWithError(err)
}
}()
// continueStreamID is the stream ID we're waiting for
// continuation frames for.
var continueStreamID uint32
for {
f, err := cc.fr.ReadFrame()
if err != nil {
cc.readerErr = err
return
}
log.Printf("Transport received %v: %#v", f.Header(), f)
streamID := f.Header().StreamID
_, isContinue := f.(*ContinuationFrame)
if isContinue {
if streamID != continueStreamID {
log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
cc.readerErr = ConnectionError(ErrCodeProtocol)
return
}
} else if continueStreamID != 0 {
// Continue frames need to be adjacent in the stream
// and we were in the middle of headers.
log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
cc.readerErr = ConnectionError(ErrCodeProtocol)
return
}
if streamID%2 == 0 {
// Ignore streams pushed from the server for now.
// These always have an even stream id.
continue
}
streamEnded := false
if ff, ok := f.(streamEnder); ok {
streamEnded = ff.StreamEnded()
}
cs := cc.streamByID(streamID, streamEnded)
if cs == nil {
log.Printf("Received frame for untracked stream ID %d", streamID)
continue
}
switch f := f.(type) {
case *HeadersFrame:
cc.nextRes = &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: make(http.Header),
}
cs.pr, cs.pw = io.Pipe()
cc.hdec.Write(f.HeaderBlockFragment())
case *ContinuationFrame:
cc.hdec.Write(f.HeaderBlockFragment())
case *DataFrame:
log.Printf("DATA: %q", f.Data())
cs.pw.Write(f.Data())
case *GoAwayFrame:
cc.t.removeClientConn(cc)
if f.ErrCode != 0 {
// TODO: deal with GOAWAY more. particularly the error code
log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
}
cc.setGoAway(f)
default:
log.Printf("Transport: unhandled response frame type %T", f)
}
headersEnded := false
if he, ok := f.(headersEnder); ok {
headersEnded = he.HeadersEnded()
if headersEnded {
continueStreamID = 0
} else {
continueStreamID = streamID
}
}
if streamEnded {
cs.pw.Close()
delete(activeRes, streamID)
}
if headersEnded {
if cs == nil {
panic("couldn't find stream") // TODO be graceful
}
// TODO: set the Body to one which notes the
// Close and also sends the server a
// RST_STREAM
cc.nextRes.Body = cs.pr
res := cc.nextRes
activeRes[streamID] = cs
cs.resc <- resAndError{res: res}
}
}
}
func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
// TODO: verifiy pseudo headers come before non-pseudo headers
// TODO: verifiy the status is set
log.Printf("Header field: %+v", f)
if f.Name == ":status" {
code, err := strconv.Atoi(f.Value)
if err != nil {
panic("TODO: be graceful")
}
cc.nextRes.Status = f.Value + " " + http.StatusText(code)
cc.nextRes.StatusCode = code
return
}
if strings.HasPrefix(f.Name, ":") {
// "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
// TODO: treat as invalid?
return
}
cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
}

@ -0,0 +1,168 @@
// Copyright 2015 The Go Authors.
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://go.googlesource.com/go/+/master/LICENSE
package http2
import (
"flag"
"io"
"io/ioutil"
"net/http"
"os"
"reflect"
"strings"
"testing"
"time"
)
var (
extNet = flag.Bool("extnet", false, "do external network tests")
transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
insecure = flag.Bool("insecure", false, "insecure TLS dials")
)
func TestTransportExternal(t *testing.T) {
if !*extNet {
t.Skip("skipping external network test")
}
req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
rt := &Transport{
InsecureTLSDial: *insecure,
}
res, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("%v", err)
}
res.Write(os.Stdout)
}
func TestTransport(t *testing.T) {
const body = "sup"
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, body)
})
defer st.Close()
tr := &Transport{InsecureTLSDial: true}
defer tr.CloseIdleConnections()
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
t.Logf("Got res: %+v", res)
if g, w := res.StatusCode, 200; g != w {
t.Errorf("StatusCode = %v; want %v", g, w)
}
if g, w := res.Status, "200 OK"; g != w {
t.Errorf("Status = %q; want %q", g, w)
}
wantHeader := http.Header{
"Content-Length": []string{"3"},
"Content-Type": []string{"text/plain; charset=utf-8"},
}
if !reflect.DeepEqual(res.Header, wantHeader) {
t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
}
if res.Request != req {
t.Errorf("Response.Request = %p; want %p", res.Request, req)
}
if res.TLS == nil {
t.Error("Response.TLS = nil; want non-nil")
}
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("Body read: %v", err)
} else if string(slurp) != body {
t.Errorf("Body = %q; want %q", slurp, body)
}
}
func TestTransportReusesConns(t *testing.T) {
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, r.RemoteAddr)
}, optOnlyServer)
defer st.Close()
tr := &Transport{InsecureTLSDial: true}
defer tr.CloseIdleConnections()
get := func() string {
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
slurp, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Fatalf("Body read: %v", err)
}
addr := strings.TrimSpace(string(slurp))
if addr == "" {
t.Fatalf("didn't get an addr in response")
}
return addr
}
first := get()
second := get()
if first != second {
t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
}
}
func TestTransportAbortClosesPipes(t *testing.T) {
shutdown := make(chan struct{})
st := newServerTester(t,
func(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush()
<-shutdown
},
optOnlyServer,
)
defer st.Close()
defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
done := make(chan struct{})
requestMade := make(chan struct{})
go func() {
defer close(done)
tr := &Transport{
InsecureTLSDial: true,
}
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
res, err := tr.RoundTrip(req)
if err != nil {
t.Fatal(err)
}
defer res.Body.Close()
close(requestMade)
_, err = ioutil.ReadAll(res.Body)
if err == nil {
t.Error("expected error from res.Body.Read")
}
}()
<-requestMade
// Now force the serve loop to end, via closing the connection.
st.closeConn()
// deadlock? that's a bug.
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("timeout")
}
}

@ -0,0 +1,204 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"bytes"
"fmt"
"net/http"
"time"
"github.com/bradfitz/http2/hpack"
)
// writeFramer is implemented by any type that is used to write frames.
type writeFramer interface {
writeFrame(writeContext) error
}
// writeContext is the interface needed by the various frame writer
// types below. All the writeFrame methods below are scheduled via the
// frame writing scheduler (see writeScheduler in writesched.go).
//
// This interface is implemented by *serverConn.
// TODO: use it from the client code too, once it exists.
type writeContext interface {
Framer() *Framer
Flush() error
CloseConn() error
// HeaderEncoder returns an HPACK encoder that writes to the
// returned buffer.
HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
}
// endsStream reports whether the given frame writer w will locally
// close the stream.
func endsStream(w writeFramer) bool {
switch v := w.(type) {
case *writeData:
return v.endStream
case *writeResHeaders:
return v.endStream
}
return false
}
type flushFrameWriter struct{}
func (flushFrameWriter) writeFrame(ctx writeContext) error {
return ctx.Flush()
}
type writeSettings []Setting
func (s writeSettings) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteSettings([]Setting(s)...)
}
type writeGoAway struct {
maxStreamID uint32
code ErrCode
}
func (p *writeGoAway) writeFrame(ctx writeContext) error {
err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
if p.code != 0 {
ctx.Flush() // ignore error: we're hanging up on them anyway
time.Sleep(50 * time.Millisecond)
ctx.CloseConn()
}
return err
}
type writeData struct {
streamID uint32
p []byte
endStream bool
}
func (w *writeData) String() string {
return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream)
}
func (w *writeData) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
}
func (se StreamError) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
}
type writePingAck struct{ pf *PingFrame }
func (w writePingAck) writeFrame(ctx writeContext) error {
return ctx.Framer().WritePing(true, w.pf.Data)
}
type writeSettingsAck struct{}
func (writeSettingsAck) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteSettingsAck()
}
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
// for HTTP response headers from a server handler.
type writeResHeaders struct {
streamID uint32
httpResCode int
h http.Header // may be nil
endStream bool
contentType string
contentLength string
}
func (w *writeResHeaders) writeFrame(ctx writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(w.httpResCode)})
for k, vv := range w.h {
k = lowerHeader(k)
for _, v := range vv {
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
if k == "transfer-encoding" && v != "trailers" {
continue
}
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
}
}
if w.contentType != "" {
enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType})
}
if w.contentLength != "" {
enc.WriteField(hpack.HeaderField{Name: "content-length", Value: w.contentLength})
}
headerBlock := buf.Bytes()
if len(headerBlock) == 0 {
panic("unexpected empty hpack")
}
// For now we're lazy and just pick the minimum MAX_FRAME_SIZE
// that all peers must support (16KB). Later we could care
// more and send larger frames if the peer advertised it, but
// there's little point. Most headers are small anyway (so we
// generally won't have CONTINUATION frames), and extra frames
// only waste 9 bytes anyway.
const maxFrameSize = 16384
first := true
for len(headerBlock) > 0 {
frag := headerBlock
if len(frag) > maxFrameSize {
frag = frag[:maxFrameSize]
}
headerBlock = headerBlock[len(frag):]
endHeaders := len(headerBlock) == 0
var err error
if first {
first = false
err = ctx.Framer().WriteHeaders(HeadersFrameParam{
StreamID: w.streamID,
BlockFragment: frag,
EndStream: w.endStream,
EndHeaders: endHeaders,
})
} else {
err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag)
}
if err != nil {
return err
}
}
return nil
}
type write100ContinueHeadersFrame struct {
streamID uint32
}
func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error {
enc, buf := ctx.HeaderEncoder()
buf.Reset()
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
return ctx.Framer().WriteHeaders(HeadersFrameParam{
StreamID: w.streamID,
BlockFragment: buf.Bytes(),
EndStream: false,
EndHeaders: true,
})
}
type writeWindowUpdate struct {
streamID uint32 // or 0 for conn-level
n uint32
}
func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
}

@ -0,0 +1,286 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import "fmt"
// frameWriteMsg is a request to write a frame.
type frameWriteMsg struct {
// write is the interface value that does the writing, once the
// writeScheduler (below) has decided to select this frame
// to write. The write functions are all defined in write.go.
write writeFramer
stream *stream // used for prioritization. nil for non-stream frames.
// done, if non-nil, must be a buffered channel with space for
// 1 message and is sent the return value from write (or an
// earlier error) when the frame has been written.
done chan error
}
// for debugging only:
func (wm frameWriteMsg) String() string {
var streamID uint32
if wm.stream != nil {
streamID = wm.stream.id
}
var des string
if s, ok := wm.write.(fmt.Stringer); ok {
des = s.String()
} else {
des = fmt.Sprintf("%T", wm.write)
}
return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des)
}
// writeScheduler tracks pending frames to write, priorities, and decides
// the next one to use. It is not thread-safe.
type writeScheduler struct {
// zero are frames not associated with a specific stream.
// They're sent before any stream-specific freams.
zero writeQueue
// maxFrameSize is the maximum size of a DATA frame
// we'll write. Must be non-zero and between 16K-16M.
maxFrameSize uint32
// sq contains the stream-specific queues, keyed by stream ID.
// when a stream is idle, it's deleted from the map.
sq map[uint32]*writeQueue
// canSend is a slice of memory that's reused between frame
// scheduling decisions to hold the list of writeQueues (from sq)
// which have enough flow control data to send. After canSend is
// built, the best is selected.
canSend []*writeQueue
// pool of empty queues for reuse.
queuePool []*writeQueue
}
func (ws *writeScheduler) putEmptyQueue(q *writeQueue) {
if len(q.s) != 0 {
panic("queue must be empty")
}
ws.queuePool = append(ws.queuePool, q)
}
func (ws *writeScheduler) getEmptyQueue() *writeQueue {
ln := len(ws.queuePool)
if ln == 0 {
return new(writeQueue)
}
q := ws.queuePool[ln-1]
ws.queuePool = ws.queuePool[:ln-1]
return q
}
func (ws *writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 }
func (ws *writeScheduler) add(wm frameWriteMsg) {
st := wm.stream
if st == nil {
ws.zero.push(wm)
} else {
ws.streamQueue(st.id).push(wm)
}
}
func (ws *writeScheduler) streamQueue(streamID uint32) *writeQueue {
if q, ok := ws.sq[streamID]; ok {
return q
}
if ws.sq == nil {
ws.sq = make(map[uint32]*writeQueue)
}
q := ws.getEmptyQueue()
ws.sq[streamID] = q
return q
}
// take returns the most important frame to write and removes it from the scheduler.
// It is illegal to call this if the scheduler is empty or if there are no connection-level
// flow control bytes available.
func (ws *writeScheduler) take() (wm frameWriteMsg, ok bool) {
if ws.maxFrameSize == 0 {
panic("internal error: ws.maxFrameSize not initialized or invalid")
}
// If there any frames not associated with streams, prefer those first.
// These are usually SETTINGS, etc.
if !ws.zero.empty() {
return ws.zero.shift(), true
}
if len(ws.sq) == 0 {
return
}
// Next, prioritize frames on streams that aren't DATA frames (no cost).
for id, q := range ws.sq {
if q.firstIsNoCost() {
return ws.takeFrom(id, q)
}
}
// Now, all that remains are DATA frames with non-zero bytes to
// send. So pick the best one.
if len(ws.canSend) != 0 {
panic("should be empty")
}
for _, q := range ws.sq {
if n := ws.streamWritableBytes(q); n > 0 {
ws.canSend = append(ws.canSend, q)
}
}
if len(ws.canSend) == 0 {
return
}
defer ws.zeroCanSend()
// TODO: find the best queue
q := ws.canSend[0]
return ws.takeFrom(q.streamID(), q)
}
// zeroCanSend is defered from take.
func (ws *writeScheduler) zeroCanSend() {
for i := range ws.canSend {
ws.canSend[i] = nil
}
ws.canSend = ws.canSend[:0]
}
// streamWritableBytes returns the number of DATA bytes we could write
// from the given queue's stream, if this stream/queue were
// selected. It is an error to call this if q's head isn't a
// *writeData.
func (ws *writeScheduler) streamWritableBytes(q *writeQueue) int32 {
wm := q.head()
ret := wm.stream.flow.available() // max we can write
if ret == 0 {
return 0
}
if int32(ws.maxFrameSize) < ret {
ret = int32(ws.maxFrameSize)
}
if ret == 0 {
panic("internal error: ws.maxFrameSize not initialized or invalid")
}
wd := wm.write.(*writeData)
if len(wd.p) < int(ret) {
ret = int32(len(wd.p))
}
return ret
}
func (ws *writeScheduler) takeFrom(id uint32, q *writeQueue) (wm frameWriteMsg, ok bool) {
wm = q.head()
// If the first item in this queue costs flow control tokens
// and we don't have enough, write as much as we can.
if wd, ok := wm.write.(*writeData); ok && len(wd.p) > 0 {
allowed := wm.stream.flow.available() // max we can write
if allowed == 0 {
// No quota available. Caller can try the next stream.
return frameWriteMsg{}, false
}
if int32(ws.maxFrameSize) < allowed {
allowed = int32(ws.maxFrameSize)
}
// TODO: further restrict the allowed size, because even if
// the peer says it's okay to write 16MB data frames, we might
// want to write smaller ones to properly weight competing
// streams' priorities.
if len(wd.p) > int(allowed) {
wm.stream.flow.take(allowed)
chunk := wd.p[:allowed]
wd.p = wd.p[allowed:]
// Make up a new write message of a valid size, rather
// than shifting one off the queue.
return frameWriteMsg{
stream: wm.stream,
write: &writeData{
streamID: wd.streamID,
p: chunk,
// even if the original had endStream set, there
// arebytes remaining because len(wd.p) > allowed,
// so we know endStream is false:
endStream: false,
},
// our caller is blocking on the final DATA frame, not
// these intermediates, so no need to wait:
done: nil,
}, true
}
wm.stream.flow.take(int32(len(wd.p)))
}
q.shift()
if q.empty() {
ws.putEmptyQueue(q)
delete(ws.sq, id)
}
return wm, true
}
func (ws *writeScheduler) forgetStream(id uint32) {
q, ok := ws.sq[id]
if !ok {
return
}
delete(ws.sq, id)
// But keep it for others later.
for i := range q.s {
q.s[i] = frameWriteMsg{}
}
q.s = q.s[:0]
ws.putEmptyQueue(q)
}
type writeQueue struct {
s []frameWriteMsg
}
// streamID returns the stream ID for a non-empty stream-specific queue.
func (q *writeQueue) streamID() uint32 { return q.s[0].stream.id }
func (q *writeQueue) empty() bool { return len(q.s) == 0 }
func (q *writeQueue) push(wm frameWriteMsg) {
q.s = append(q.s, wm)
}
// head returns the next item that would be removed by shift.
func (q *writeQueue) head() frameWriteMsg {
if len(q.s) == 0 {
panic("invalid use of queue")
}
return q.s[0]
}
func (q *writeQueue) shift() frameWriteMsg {
if len(q.s) == 0 {
panic("invalid use of queue")
}
wm := q.s[0]
// TODO: less copy-happy queue.
copy(q.s, q.s[1:])
q.s[len(q.s)-1] = frameWriteMsg{}
q.s = q.s[:len(q.s)-1]
return wm
}
func (q *writeQueue) firstIsNoCost() bool {
if df, ok := q.s[0].write.(*writeData); ok {
return len(df.p) == 0
}
return true
}

@ -0,0 +1,357 @@
// Copyright 2014 The Go Authors.
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
// Licensed under the same terms as Go itself:
// https://code.google.com/p/go/source/browse/LICENSE
package http2
import (
"bytes"
"encoding/xml"
"flag"
"fmt"
"io"
"os"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"testing"
)
var coverSpec = flag.Bool("coverspec", false, "Run spec coverage tests")
// The global map of sentence coverage for the http2 spec.
var defaultSpecCoverage specCoverage
var loadSpecOnce sync.Once
func loadSpec() {
if f, err := os.Open("testdata/draft-ietf-httpbis-http2.xml"); err != nil {
panic(err)
} else {
defaultSpecCoverage = readSpecCov(f)
f.Close()
}
}
// covers marks all sentences for section sec in defaultSpecCoverage. Sentences not
// "covered" will be included in report outputed by TestSpecCoverage.
func covers(sec, sentences string) {
loadSpecOnce.Do(loadSpec)
defaultSpecCoverage.cover(sec, sentences)
}
type specPart struct {
section string
sentence string
}
func (ss specPart) Less(oo specPart) bool {
atoi := func(s string) int {
n, err := strconv.Atoi(s)
if err != nil {
panic(err)
}
return n
}
a := strings.Split(ss.section, ".")
b := strings.Split(oo.section, ".")
for len(a) > 0 {
if len(b) == 0 {
return false
}
x, y := atoi(a[0]), atoi(b[0])
if x == y {
a, b = a[1:], b[1:]
continue
}
return x < y
}
if len(b) > 0 {
return true
}
return false
}
type bySpecSection []specPart
func (a bySpecSection) Len() int { return len(a) }
func (a bySpecSection) Less(i, j int) bool { return a[i].Less(a[j]) }
func (a bySpecSection) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
type specCoverage struct {
coverage map[specPart]bool
d *xml.Decoder
}
func joinSection(sec []int) string {
s := fmt.Sprintf("%d", sec[0])
for _, n := range sec[1:] {
s = fmt.Sprintf("%s.%d", s, n)
}
return s
}
func (sc specCoverage) readSection(sec []int) {
var (
buf = new(bytes.Buffer)
sub = 0
)
for {
tk, err := sc.d.Token()
if err != nil {
if err == io.EOF {
return
}
panic(err)
}
switch v := tk.(type) {
case xml.StartElement:
if skipElement(v) {
if err := sc.d.Skip(); err != nil {
panic(err)
}
if v.Name.Local == "section" {
sub++
}
break
}
switch v.Name.Local {
case "section":
sub++
sc.readSection(append(sec, sub))
case "xref":
buf.Write(sc.readXRef(v))
}
case xml.CharData:
if len(sec) == 0 {
break
}
buf.Write(v)
case xml.EndElement:
if v.Name.Local == "section" {
sc.addSentences(joinSection(sec), buf.String())
return
}
}
}
}
func (sc specCoverage) readXRef(se xml.StartElement) []byte {
var b []byte
for {
tk, err := sc.d.Token()
if err != nil {
panic(err)
}
switch v := tk.(type) {
case xml.CharData:
if b != nil {
panic("unexpected CharData")
}
b = []byte(string(v))
case xml.EndElement:
if v.Name.Local != "xref" {
panic("expected </xref>")
}
if b != nil {
return b
}
sig := attrSig(se)
switch sig {
case "target":
return []byte(fmt.Sprintf("[%s]", attrValue(se, "target")))
case "fmt-of,rel,target", "fmt-,,rel,target":
return []byte(fmt.Sprintf("[%s, %s]", attrValue(se, "target"), attrValue(se, "rel")))
case "fmt-of,sec,target", "fmt-,,sec,target":
return []byte(fmt.Sprintf("[section %s of %s]", attrValue(se, "sec"), attrValue(se, "target")))
case "fmt-of,rel,sec,target":
return []byte(fmt.Sprintf("[section %s of %s, %s]", attrValue(se, "sec"), attrValue(se, "target"), attrValue(se, "rel")))
default:
panic(fmt.Sprintf("unknown attribute signature %q in %#v", sig, fmt.Sprintf("%#v", se)))
}
default:
panic(fmt.Sprintf("unexpected tag %q", v))
}
}
}
var skipAnchor = map[string]bool{
"intro": true,
"Overview": true,
}
var skipTitle = map[string]bool{
"Acknowledgements": true,
"Change Log": true,
"Document Organization": true,
"Conventions and Terminology": true,
}
func skipElement(s xml.StartElement) bool {
switch s.Name.Local {
case "artwork":
return true
case "section":
for _, attr := range s.Attr {
switch attr.Name.Local {
case "anchor":
if skipAnchor[attr.Value] || strings.HasPrefix(attr.Value, "changes.since.") {
return true
}
case "title":
if skipTitle[attr.Value] {
return true
}
}
}
}
return false
}
func readSpecCov(r io.Reader) specCoverage {
sc := specCoverage{
coverage: map[specPart]bool{},
d: xml.NewDecoder(r)}
sc.readSection(nil)
return sc
}
func (sc specCoverage) addSentences(sec string, sentence string) {
for _, s := range parseSentences(sentence) {
sc.coverage[specPart{sec, s}] = false
}
}
func (sc specCoverage) cover(sec string, sentence string) {
for _, s := range parseSentences(sentence) {
p := specPart{sec, s}
if _, ok := sc.coverage[p]; !ok {
panic(fmt.Sprintf("Not found in spec: %q, %q", sec, s))
}
sc.coverage[specPart{sec, s}] = true
}
}
var whitespaceRx = regexp.MustCompile(`\s+`)
func parseSentences(sens string) []string {
sens = strings.TrimSpace(sens)
if sens == "" {
return nil
}
ss := strings.Split(whitespaceRx.ReplaceAllString(sens, " "), ". ")
for i, s := range ss {
s = strings.TrimSpace(s)
if !strings.HasSuffix(s, ".") {
s += "."
}
ss[i] = s
}
return ss
}
func TestSpecParseSentences(t *testing.T) {
tests := []struct {
ss string
want []string
}{
{"Sentence 1. Sentence 2.",
[]string{
"Sentence 1.",
"Sentence 2.",
}},
{"Sentence 1. \nSentence 2.\tSentence 3.",
[]string{
"Sentence 1.",
"Sentence 2.",
"Sentence 3.",
}},
}
for i, tt := range tests {
got := parseSentences(tt.ss)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("%d: got = %q, want %q", i, got, tt.want)
}
}
}
func TestSpecCoverage(t *testing.T) {
if !*coverSpec {
t.Skip()
}
loadSpecOnce.Do(loadSpec)
var (
list []specPart
cv = defaultSpecCoverage.coverage
total = len(cv)
complete = 0
)
for sp, touched := range defaultSpecCoverage.coverage {
if touched {
complete++
} else {
list = append(list, sp)
}
}
sort.Stable(bySpecSection(list))
if testing.Short() && len(list) > 5 {
list = list[:5]
}
for _, p := range list {
t.Errorf("\tSECTION %s: %s", p.section, p.sentence)
}
t.Logf("%d/%d (%d%%) sentances covered", complete, total, (complete/total)*100)
}
func attrSig(se xml.StartElement) string {
var names []string
for _, attr := range se.Attr {
if attr.Name.Local == "fmt" {
names = append(names, "fmt-"+attr.Value)
} else {
names = append(names, attr.Name.Local)
}
}
sort.Strings(names)
return strings.Join(names, ",")
}
func attrValue(se xml.StartElement, attr string) string {
for _, a := range se.Attr {
if a.Name.Local == attr {
return a.Value
}
}
panic("unknown attribute " + attr)
}
func TestSpecPartLess(t *testing.T) {
tests := []struct {
sec1, sec2 string
want bool
}{
{"6.2.1", "6.2", false},
{"6.2", "6.2.1", true},
{"6.10", "6.10.1", true},
{"6.10", "6.1.1", false}, // 10, not 1
{"6.1", "6.1", false}, // equal, so not less
}
for _, tt := range tests {
got := (specPart{tt.sec1, "foo"}).Less(specPart{tt.sec2, "foo"})
if got != tt.want {
t.Errorf("Less(%q, %q) = %v; want %v", tt.sec1, tt.sec2, got, tt.want)
}
}
}

@ -0,0 +1,13 @@
language: go
sudo: false
go:
- 1.0.3
- 1.1.2
- 1.2.2
- 1.3.3
- 1.4.2
script:
- go vet ./...
- go test -v ./...

@ -0,0 +1,21 @@
Copyright (C) 2013 Jeremy Saenz
All Rights Reserved.
MIT LICENSE
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,308 @@
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli)
# cli.go
cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
You can view the API docs here:
http://godoc.org/github.com/codegangsta/cli
## Overview
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app.
**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive!
## Installation
Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html).
To install `cli.go`, simply run:
```
$ go get github.com/codegangsta/cli
```
Make sure your `PATH` includes to the `$GOPATH/bin` directory so your commands can be easily used:
```
export PATH=$PATH:$GOPATH/bin
```
## Getting Started
One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`.
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
cli.NewApp().Run(os.Args)
}
```
This app will run and show help text, but is not very useful. Let's give an action to execute and some help documentation:
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
app := cli.NewApp()
app.Name = "boom"
app.Usage = "make an explosive entrance"
app.Action = func(c *cli.Context) {
println("boom! I say!")
}
app.Run(os.Args)
}
```
Running this already gives you a ton of functionality, plus support for things like subcommands and flags, which are covered below.
## Example
Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness!
Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it:
``` go
package main
import (
"os"
"github.com/codegangsta/cli"
)
func main() {
app := cli.NewApp()
app.Name = "greet"
app.Usage = "fight the loneliness!"
app.Action = func(c *cli.Context) {
println("Hello friend!")
}
app.Run(os.Args)
}
```
Install our command to the `$GOPATH/bin` directory:
```
$ go install
```
Finally run our new command:
```
$ greet
Hello friend!
```
cli.go also generates some bitchass help text:
```
$ greet help
NAME:
greet - fight the loneliness!
USAGE:
greet [global options] command [command options] [arguments...]
VERSION:
0.0.0
COMMANDS:
help, h Shows a list of commands or help for one command
GLOBAL OPTIONS
--version Shows version information
```
### Arguments
You can lookup arguments by calling the `Args` function on `cli.Context`.
``` go
...
app.Action = func(c *cli.Context) {
println("Hello", c.Args()[0])
}
...
```
### Flags
Setting and querying flags is simple.
``` go
...
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang",
Value: "english",
Usage: "language for the greeting",
},
}
app.Action = func(c *cli.Context) {
name := "someone"
if len(c.Args()) > 0 {
name = c.Args()[0]
}
if c.String("lang") == "spanish" {
println("Hola", name)
} else {
println("Hello", name)
}
}
...
```
See full list of flags at http://godoc.org/github.com/codegangsta/cli
#### Alternate Names
You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
},
}
```
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
#### Values from the Environment
You can also have the default value set from the environment via `EnvVar`. e.g.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
EnvVar: "APP_LANG",
},
}
```
The `EnvVar` may also be given as a comma-delimited "cascade", where the first environment variable that resolves is used as the default.
``` go
app.Flags = []cli.Flag {
cli.StringFlag{
Name: "lang, l",
Value: "english",
Usage: "language for the greeting",
EnvVar: "LEGACY_COMPAT_LANG,APP_LANG,LANG",
},
}
```
### Subcommands
Subcommands can be defined for a more git-like command line app.
```go
...
app.Commands = []cli.Command{
{
Name: "add",
Aliases: []string{"a"},
Usage: "add a task to the list",
Action: func(c *cli.Context) {
println("added task: ", c.Args().First())
},
},
{
Name: "complete",
Aliases: []string{"c"},
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
},
{
Name: "template",
Aliases: []string{"r"},
Usage: "options for task templates",
Subcommands: []cli.Command{
{
Name: "add",
Usage: "add a new template",
Action: func(c *cli.Context) {
println("new task template: ", c.Args().First())
},
},
{
Name: "remove",
Usage: "remove an existing template",
Action: func(c *cli.Context) {
println("removed task template: ", c.Args().First())
},
},
},
},
}
...
```
### Bash Completion
You can enable completion commands by setting the `EnableBashCompletion`
flag on the `App` object. By default, this setting will only auto-complete to
show an app's subcommands, but you can write your own completion methods for
the App or its subcommands.
```go
...
var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"}
app := cli.NewApp()
app.EnableBashCompletion = true
app.Commands = []cli.Command{
{
Name: "complete",
Aliases: []string{"c"},
Usage: "complete a task on the list",
Action: func(c *cli.Context) {
println("completed task: ", c.Args().First())
},
BashComplete: func(c *cli.Context) {
// This will complete if no args are passed
if len(c.Args()) > 0 {
return
}
for _, t := range tasks {
fmt.Println(t)
}
},
}
}
...
```
#### To Enable
Source the `autocomplete/bash_autocomplete` file in your `.bashrc` file while
setting the `PROG` variable to the name of your program:
`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete`
#### To Distribute
Copy and modify `autocomplete/bash_autocomplete` to use your program name
rather than `$PROG` and have the user copy the file into
`/etc/bash_completion.d/` (or automatically install it there if you are
distributing a package). Alternatively you can just document that users should
source the generic `autocomplete/bash_autocomplete` with `$PROG` set to your
program name in their bash configuration.
## Contribution Guidelines
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch.
If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out.

@ -0,0 +1,308 @@
package cli
import (
"fmt"
"io"
"io/ioutil"
"os"
"time"
)
// App is the main structure of a cli application. It is recomended that
// an app be created with the cli.NewApp() function
type App struct {
// The name of the program. Defaults to os.Args[0]
Name string
// Description of the program.
Usage string
// Version of the program
Version string
// List of commands to execute
Commands []Command
// List of flags to parse
Flags []Flag
// Boolean to enable bash completion commands
EnableBashCompletion bool
// Boolean to hide built-in help command
HideHelp bool
// Boolean to hide built-in version flag
HideVersion bool
// An action to execute when the bash-completion flag is set
BashComplete func(context *Context)
// An action to execute before any subcommands are run, but after the context is ready
// If a non-nil error is returned, no subcommands are run
Before func(context *Context) error
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After func(context *Context) error
// The action to execute when no subcommands are specified
Action func(context *Context)
// Execute this function if the proper command cannot be found
CommandNotFound func(context *Context, command string)
// Compilation date
Compiled time.Time
// List of all authors who contributed
Authors []Author
// Copyright of the binary if any
Copyright string
// Name of Author (Note: Use App.Authors, this is deprecated)
Author string
// Email of Author (Note: Use App.Authors, this is deprecated)
Email string
// Writer writer to write output to
Writer io.Writer
}
// Tries to find out when this binary was compiled.
// Returns the current time if it fails to find it.
func compileTime() time.Time {
info, err := os.Stat(os.Args[0])
if err != nil {
return time.Now()
}
return info.ModTime()
}
// Creates a new cli Application with some reasonable defaults for Name, Usage, Version and Action.
func NewApp() *App {
return &App{
Name: os.Args[0],
Usage: "A new cli application",
Version: "0.0.0",
BashComplete: DefaultAppComplete,
Action: helpCommand.Action,
Compiled: compileTime(),
Writer: os.Stdout,
}
}
// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination
func (a *App) Run(arguments []string) (err error) {
if a.Author != "" || a.Email != "" {
a.Authors = append(a.Authors, Author{Name: a.Author, Email: a.Email})
}
// append help to commands
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
//append version/help flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
}
if !a.HideVersion {
a.appendFlag(VersionFlag)
}
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err = set.Parse(arguments[1:])
nerr := normalizeFlags(a.Flags, set)
if nerr != nil {
fmt.Fprintln(a.Writer, nerr)
context := NewContext(a, set, nil)
ShowAppHelp(context)
return nerr
}
context := NewContext(a, set, nil)
if err != nil {
fmt.Fprintln(a.Writer, "Incorrect Usage.")
fmt.Fprintln(a.Writer)
ShowAppHelp(context)
return err
}
if checkCompletions(context) {
return nil
}
if checkHelp(context) {
return nil
}
if checkVersion(context) {
return nil
}
if a.After != nil {
defer func() {
afterErr := a.After(context)
if afterErr != nil {
if err != nil {
err = NewMultiError(err, afterErr)
} else {
err = afterErr
}
}
}()
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
// Run default Action
a.Action(context)
return nil
}
// Another entry point to the cli app, takes care of passing arguments and error handling
func (a *App) RunAndExitOnError() {
if err := a.Run(os.Args); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags
func (a *App) RunAsSubcommand(ctx *Context) (err error) {
// append help to commands
if len(a.Commands) > 0 {
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
a.Commands = append(a.Commands, helpCommand)
if (HelpFlag != BoolFlag{}) {
a.appendFlag(HelpFlag)
}
}
}
// append flags
if a.EnableBashCompletion {
a.appendFlag(BashCompletionFlag)
}
// parse flags
set := flagSet(a.Name, a.Flags)
set.SetOutput(ioutil.Discard)
err = set.Parse(ctx.Args().Tail())
nerr := normalizeFlags(a.Flags, set)
context := NewContext(a, set, ctx)
if nerr != nil {
fmt.Fprintln(a.Writer, nerr)
fmt.Fprintln(a.Writer)
if len(a.Commands) > 0 {
ShowSubcommandHelp(context)
} else {
ShowCommandHelp(ctx, context.Args().First())
}
return nerr
}
if err != nil {
fmt.Fprintln(a.Writer, "Incorrect Usage.")
fmt.Fprintln(a.Writer)
ShowSubcommandHelp(context)
return err
}
if checkCompletions(context) {
return nil
}
if len(a.Commands) > 0 {
if checkSubcommandHelp(context) {
return nil
}
} else {
if checkCommandHelp(ctx, context.Args().First()) {
return nil
}
}
if a.After != nil {
defer func() {
afterErr := a.After(context)
if afterErr != nil {
if err != nil {
err = NewMultiError(err, afterErr)
} else {
err = afterErr
}
}
}()
}
if a.Before != nil {
err := a.Before(context)
if err != nil {
return err
}
}
args := context.Args()
if args.Present() {
name := args.First()
c := a.Command(name)
if c != nil {
return c.Run(context)
}
}
// Run default Action
a.Action(context)
return nil
}
// Returns the named command on App. Returns nil if the command does not exist
func (a *App) Command(name string) *Command {
for _, c := range a.Commands {
if c.HasName(name) {
return &c
}
}
return nil
}
func (a *App) hasFlag(flag Flag) bool {
for _, f := range a.Flags {
if flag == f {
return true
}
}
return false
}
func (a *App) appendFlag(flag Flag) {
if !a.hasFlag(flag) {
a.Flags = append(a.Flags, flag)
}
}
// Author represents someone who has contributed to a cli project.
type Author struct {
Name string // The Authors name
Email string // The Authors email
}
// String makes Author comply to the Stringer interface, to allow an easy print in the templating process
func (a Author) String() string {
e := ""
if a.Email != "" {
e = "<" + a.Email + "> "
}
return fmt.Sprintf("%v %v", a.Name, e)
}

@ -0,0 +1,867 @@
package cli
import (
"bytes"
"flag"
"fmt"
"io"
"os"
"strings"
"testing"
)
func ExampleApp() {
// set args for examples sake
os.Args = []string{"greet", "--name", "Jeremy"}
app := NewApp()
app.Name = "greet"
app.Flags = []Flag{
StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Action = func(c *Context) {
fmt.Printf("Hello %v\n", c.String("name"))
}
app.Author = "Harrison"
app.Email = "harrison@lolwut.com"
app.Authors = []Author{Author{Name: "Oliver Allen", Email: "oliver@toyshop.com"}}
app.Run(os.Args)
// Output:
// Hello Jeremy
}
func ExampleAppSubcommand() {
// set args for examples sake
os.Args = []string{"say", "hi", "english", "--name", "Jeremy"}
app := NewApp()
app.Name = "say"
app.Commands = []Command{
{
Name: "hello",
Aliases: []string{"hi"},
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []Command{
{
Name: "english",
Aliases: []string{"en"},
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []Flag{
StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *Context) {
fmt.Println("Hello,", c.String("name"))
},
},
},
},
}
app.Run(os.Args)
// Output:
// Hello, Jeremy
}
func ExampleAppHelp() {
// set args for examples sake
os.Args = []string{"greet", "h", "describeit"}
app := NewApp()
app.Name = "greet"
app.Flags = []Flag{
StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
}
app.Commands = []Command{
{
Name: "describeit",
Aliases: []string{"d"},
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *Context) {
fmt.Printf("i like to describe things")
},
},
}
app.Run(os.Args)
// Output:
// NAME:
// describeit - use it to see a description
//
// USAGE:
// command describeit [arguments...]
//
// DESCRIPTION:
// This is how we describe describeit the function
}
func ExampleAppBashComplete() {
// set args for examples sake
os.Args = []string{"greet", "--generate-bash-completion"}
app := NewApp()
app.Name = "greet"
app.EnableBashCompletion = true
app.Commands = []Command{
{
Name: "describeit",
Aliases: []string{"d"},
Usage: "use it to see a description",
Description: "This is how we describe describeit the function",
Action: func(c *Context) {
fmt.Printf("i like to describe things")
},
}, {
Name: "next",
Usage: "next example",
Description: "more stuff to see when generating bash completion",
Action: func(c *Context) {
fmt.Printf("the next example")
},
},
}
app.Run(os.Args)
// Output:
// describeit
// d
// next
// help
// h
}
func TestApp_Run(t *testing.T) {
s := ""
app := NewApp()
app.Action = func(c *Context) {
s = s + c.Args().First()
}
err := app.Run([]string{"command", "foo"})
expect(t, err, nil)
err = app.Run([]string{"command", "bar"})
expect(t, err, nil)
expect(t, s, "foobar")
}
var commandAppTests = []struct {
name string
expected bool
}{
{"foobar", true},
{"batbaz", true},
{"b", true},
{"f", true},
{"bat", false},
{"nothing", false},
}
func TestApp_Command(t *testing.T) {
app := NewApp()
fooCommand := Command{Name: "foobar", Aliases: []string{"f"}}
batCommand := Command{Name: "batbaz", Aliases: []string{"b"}}
app.Commands = []Command{
fooCommand,
batCommand,
}
for _, test := range commandAppTests {
expect(t, app.Command(test.name) != nil, test.expected)
}
}
func TestApp_CommandWithArgBeforeFlags(t *testing.T) {
var parsedOption, firstArg string
app := NewApp()
command := Command{
Name: "cmd",
Flags: []Flag{
StringFlag{Name: "option", Value: "", Usage: "some option"},
},
Action: func(c *Context) {
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []Command{command}
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"})
expect(t, parsedOption, "my-option")
expect(t, firstArg, "my-arg")
}
func TestApp_RunAsSubcommandParseFlags(t *testing.T) {
var context *Context
a := NewApp()
a.Commands = []Command{
{
Name: "foo",
Action: func(c *Context) {
context = c
},
Flags: []Flag{
StringFlag{
Name: "lang",
Value: "english",
Usage: "language for the greeting",
},
},
Before: func(_ *Context) error { return nil },
},
}
a.Run([]string{"", "foo", "--lang", "spanish", "abcd"})
expect(t, context.Args().Get(0), "abcd")
expect(t, context.String("lang"), "spanish")
}
func TestApp_CommandWithFlagBeforeTerminator(t *testing.T) {
var parsedOption string
var args []string
app := NewApp()
command := Command{
Name: "cmd",
Flags: []Flag{
StringFlag{Name: "option", Value: "", Usage: "some option"},
},
Action: func(c *Context) {
parsedOption = c.String("option")
args = c.Args()
},
}
app.Commands = []Command{command}
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option", "--", "--notARealFlag"})
expect(t, parsedOption, "my-option")
expect(t, args[0], "my-arg")
expect(t, args[1], "--")
expect(t, args[2], "--notARealFlag")
}
func TestApp_CommandWithNoFlagBeforeTerminator(t *testing.T) {
var args []string
app := NewApp()
command := Command{
Name: "cmd",
Action: func(c *Context) {
args = c.Args()
},
}
app.Commands = []Command{command}
app.Run([]string{"", "cmd", "my-arg", "--", "notAFlagAtAll"})
expect(t, args[0], "my-arg")
expect(t, args[1], "--")
expect(t, args[2], "notAFlagAtAll")
}
func TestApp_Float64Flag(t *testing.T) {
var meters float64
app := NewApp()
app.Flags = []Flag{
Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"},
}
app.Action = func(c *Context) {
meters = c.Float64("height")
}
app.Run([]string{"", "--height", "1.93"})
expect(t, meters, 1.93)
}
func TestApp_ParseSliceFlags(t *testing.T) {
var parsedOption, firstArg string
var parsedIntSlice []int
var parsedStringSlice []string
app := NewApp()
command := Command{
Name: "cmd",
Flags: []Flag{
IntSliceFlag{Name: "p", Value: &IntSlice{}, Usage: "set one or more ip addr"},
StringSliceFlag{Name: "ip", Value: &StringSlice{}, Usage: "set one or more ports to open"},
},
Action: func(c *Context) {
parsedIntSlice = c.IntSlice("p")
parsedStringSlice = c.StringSlice("ip")
parsedOption = c.String("option")
firstArg = c.Args().First()
},
}
app.Commands = []Command{command}
app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"})
IntsEquals := func(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
StrsEquals := func(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
var expectedIntSlice = []int{22, 80}
var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"}
if !IntsEquals(parsedIntSlice, expectedIntSlice) {
t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice)
}
if !StrsEquals(parsedStringSlice, expectedStringSlice) {
t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice)
}
}
func TestApp_ParseSliceFlagsWithMissingValue(t *testing.T) {
var parsedIntSlice []int
var parsedStringSlice []string
app := NewApp()
command := Command{
Name: "cmd",
Flags: []Flag{
IntSliceFlag{Name: "a", Usage: "set numbers"},
StringSliceFlag{Name: "str", Usage: "set strings"},
},
Action: func(c *Context) {
parsedIntSlice = c.IntSlice("a")
parsedStringSlice = c.StringSlice("str")
},
}
app.Commands = []Command{command}
app.Run([]string{"", "cmd", "my-arg", "-a", "2", "-str", "A"})
var expectedIntSlice = []int{2}
var expectedStringSlice = []string{"A"}
if parsedIntSlice[0] != expectedIntSlice[0] {
t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0])
}
if parsedStringSlice[0] != expectedStringSlice[0] {
t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0])
}
}
func TestApp_DefaultStdout(t *testing.T) {
app := NewApp()
if app.Writer != os.Stdout {
t.Error("Default output writer not set.")
}
}
type mockWriter struct {
written []byte
}
func (fw *mockWriter) Write(p []byte) (n int, err error) {
if fw.written == nil {
fw.written = p
} else {
fw.written = append(fw.written, p...)
}
return len(p), nil
}
func (fw *mockWriter) GetWritten() (b []byte) {
return fw.written
}
func TestApp_SetStdout(t *testing.T) {
w := &mockWriter{}
app := NewApp()
app.Name = "test"
app.Writer = w
err := app.Run([]string{"help"})
if err != nil {
t.Fatalf("Run error: %s", err)
}
if len(w.written) == 0 {
t.Error("App did not write output to desired writer.")
}
}
func TestApp_BeforeFunc(t *testing.T) {
beforeRun, subcommandRun := false, false
beforeError := fmt.Errorf("fail")
var err error
app := NewApp()
app.Before = func(c *Context) error {
beforeRun = true
s := c.String("opt")
if s == "fail" {
return beforeError
}
return nil
}
app.Commands = []Command{
Command{
Name: "sub",
Action: func(c *Context) {
subcommandRun = true
},
},
}
app.Flags = []Flag{
StringFlag{Name: "opt"},
}
// run with the Before() func succeeding
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
if err != nil {
t.Fatalf("Run error: %s", err)
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == false {
t.Errorf("Subcommand not executed when expected")
}
// reset
beforeRun, subcommandRun = false, false
// run with the Before() func failing
err = app.Run([]string{"command", "--opt", "fail", "sub"})
// should be the same error produced by the Before func
if err != beforeError {
t.Errorf("Run error expected, but not received")
}
if beforeRun == false {
t.Errorf("Before() not executed when expected")
}
if subcommandRun == true {
t.Errorf("Subcommand executed when NOT expected")
}
}
func TestApp_AfterFunc(t *testing.T) {
afterRun, subcommandRun := false, false
afterError := fmt.Errorf("fail")
var err error
app := NewApp()
app.After = func(c *Context) error {
afterRun = true
s := c.String("opt")
if s == "fail" {
return afterError
}
return nil
}
app.Commands = []Command{
Command{
Name: "sub",
Action: func(c *Context) {
subcommandRun = true
},
},
}
app.Flags = []Flag{
StringFlag{Name: "opt"},
}
// run with the After() func succeeding
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
if err != nil {
t.Fatalf("Run error: %s", err)
}
if afterRun == false {
t.Errorf("After() not executed when expected")
}
if subcommandRun == false {
t.Errorf("Subcommand not executed when expected")
}
// reset
afterRun, subcommandRun = false, false
// run with the Before() func failing
err = app.Run([]string{"command", "--opt", "fail", "sub"})
// should be the same error produced by the Before func
if err != afterError {
t.Errorf("Run error expected, but not received")
}
if afterRun == false {
t.Errorf("After() not executed when expected")
}
if subcommandRun == false {
t.Errorf("Subcommand not executed when expected")
}
}
func TestAppNoHelpFlag(t *testing.T) {
oldFlag := HelpFlag
defer func() {
HelpFlag = oldFlag
}()
HelpFlag = BoolFlag{}
app := NewApp()
err := app.Run([]string{"test", "-h"})
if err != flag.ErrHelp {
t.Errorf("expected error about missing help flag, but got: %s (%T)", err, err)
}
}
func TestAppHelpPrinter(t *testing.T) {
oldPrinter := HelpPrinter
defer func() {
HelpPrinter = oldPrinter
}()
var wasCalled = false
HelpPrinter = func(w io.Writer, template string, data interface{}) {
wasCalled = true
}
app := NewApp()
app.Run([]string{"-h"})
if wasCalled == false {
t.Errorf("Help printer expected to be called, but was not")
}
}
func TestAppVersionPrinter(t *testing.T) {
oldPrinter := VersionPrinter
defer func() {
VersionPrinter = oldPrinter
}()
var wasCalled = false
VersionPrinter = func(c *Context) {
wasCalled = true
}
app := NewApp()
ctx := NewContext(app, nil, nil)
ShowVersion(ctx)
if wasCalled == false {
t.Errorf("Version printer expected to be called, but was not")
}
}
func TestAppCommandNotFound(t *testing.T) {
beforeRun, subcommandRun := false, false
app := NewApp()
app.CommandNotFound = func(c *Context, command string) {
beforeRun = true
}
app.Commands = []Command{
Command{
Name: "bar",
Action: func(c *Context) {
subcommandRun = true
},
},
}
app.Run([]string{"command", "foo"})
expect(t, beforeRun, true)
expect(t, subcommandRun, false)
}
func TestGlobalFlag(t *testing.T) {
var globalFlag string
var globalFlagSet bool
app := NewApp()
app.Flags = []Flag{
StringFlag{Name: "global, g", Usage: "global"},
}
app.Action = func(c *Context) {
globalFlag = c.GlobalString("global")
globalFlagSet = c.GlobalIsSet("global")
}
app.Run([]string{"command", "-g", "foo"})
expect(t, globalFlag, "foo")
expect(t, globalFlagSet, true)
}
func TestGlobalFlagsInSubcommands(t *testing.T) {
subcommandRun := false
parentFlag := false
app := NewApp()
app.Flags = []Flag{
BoolFlag{Name: "debug, d", Usage: "Enable debugging"},
}
app.Commands = []Command{
Command{
Name: "foo",
Flags: []Flag{
BoolFlag{Name: "parent, p", Usage: "Parent flag"},
},
Subcommands: []Command{
{
Name: "bar",
Action: func(c *Context) {
if c.GlobalBool("debug") {
subcommandRun = true
}
if c.GlobalBool("parent") {
parentFlag = true
}
},
},
},
},
}
app.Run([]string{"command", "-d", "foo", "-p", "bar"})
expect(t, subcommandRun, true)
expect(t, parentFlag, true)
}
func TestApp_Run_CommandWithSubcommandHasHelpTopic(t *testing.T) {
var subcommandHelpTopics = [][]string{
{"command", "foo", "--help"},
{"command", "foo", "-h"},
{"command", "foo", "help"},
}
for _, flagSet := range subcommandHelpTopics {
t.Logf("==> checking with flags %v", flagSet)
app := NewApp()
buf := new(bytes.Buffer)
app.Writer = buf
subCmdBar := Command{
Name: "bar",
Usage: "does bar things",
}
subCmdBaz := Command{
Name: "baz",
Usage: "does baz things",
}
cmd := Command{
Name: "foo",
Description: "descriptive wall of text about how it does foo things",
Subcommands: []Command{subCmdBar, subCmdBaz},
}
app.Commands = []Command{cmd}
err := app.Run(flagSet)
if err != nil {
t.Error(err)
}
output := buf.String()
t.Logf("output: %q\n", buf.Bytes())
if strings.Contains(output, "No help topic for") {
t.Errorf("expect a help topic, got none: \n%q", output)
}
for _, shouldContain := range []string{
cmd.Name, cmd.Description,
subCmdBar.Name, subCmdBar.Usage,
subCmdBaz.Name, subCmdBaz.Usage,
} {
if !strings.Contains(output, shouldContain) {
t.Errorf("want help to contain %q, did not: \n%q", shouldContain, output)
}
}
}
}
func TestApp_Run_SubcommandFullPath(t *testing.T) {
app := NewApp()
buf := new(bytes.Buffer)
app.Writer = buf
subCmd := Command{
Name: "bar",
Usage: "does bar things",
}
cmd := Command{
Name: "foo",
Description: "foo commands",
Subcommands: []Command{subCmd},
}
app.Commands = []Command{cmd}
err := app.Run([]string{"command", "foo", "bar", "--help"})
if err != nil {
t.Error(err)
}
output := buf.String()
if !strings.Contains(output, "foo bar - does bar things") {
t.Errorf("expected full path to subcommand: %s", output)
}
if !strings.Contains(output, "command foo bar [arguments...]") {
t.Errorf("expected full path to subcommand: %s", output)
}
}
func TestApp_Run_Help(t *testing.T) {
var helpArguments = [][]string{{"boom", "--help"}, {"boom", "-h"}, {"boom", "help"}}
for _, args := range helpArguments {
buf := new(bytes.Buffer)
t.Logf("==> checking with arguments %v", args)
app := NewApp()
app.Name = "boom"
app.Usage = "make an explosive entrance"
app.Writer = buf
app.Action = func(c *Context) {
buf.WriteString("boom I say!")
}
err := app.Run(args)
if err != nil {
t.Error(err)
}
output := buf.String()
t.Logf("output: %q\n", buf.Bytes())
if !strings.Contains(output, "boom - make an explosive entrance") {
t.Errorf("want help to contain %q, did not: \n%q", "boom - make an explosive entrance", output)
}
}
}
func TestApp_Run_Version(t *testing.T) {
var versionArguments = [][]string{{"boom", "--version"}, {"boom", "-v"}}
for _, args := range versionArguments {
buf := new(bytes.Buffer)
t.Logf("==> checking with arguments %v", args)
app := NewApp()
app.Name = "boom"
app.Usage = "make an explosive entrance"
app.Version = "0.1.0"
app.Writer = buf
app.Action = func(c *Context) {
buf.WriteString("boom I say!")
}
err := app.Run(args)
if err != nil {
t.Error(err)
}
output := buf.String()
t.Logf("output: %q\n", buf.Bytes())
if !strings.Contains(output, "0.1.0") {
t.Errorf("want version to contain %q, did not: \n%q", "0.1.0", output)
}
}
}
func TestApp_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) {
app := NewApp()
app.Action = func(c *Context) {}
app.Before = func(c *Context) error { return fmt.Errorf("before error") }
app.After = func(c *Context) error { return fmt.Errorf("after error") }
err := app.Run([]string{"foo"})
if err == nil {
t.Fatalf("expected to recieve error from Run, got none")
}
if !strings.Contains(err.Error(), "before error") {
t.Errorf("expected text of error from Before method, but got none in \"%v\"", err)
}
if !strings.Contains(err.Error(), "after error") {
t.Errorf("expected text of error from After method, but got none in \"%v\"", err)
}
}
func TestApp_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) {
app := NewApp()
app.Commands = []Command{
Command{
Name: "bar",
Before: func(c *Context) error { return fmt.Errorf("before error") },
After: func(c *Context) error { return fmt.Errorf("after error") },
},
}
err := app.Run([]string{"foo", "bar"})
if err == nil {
t.Fatalf("expected to recieve error from Run, got none")
}
if !strings.Contains(err.Error(), "before error") {
t.Errorf("expected text of error from Before method, but got none in \"%v\"", err)
}
if !strings.Contains(err.Error(), "after error") {
t.Errorf("expected text of error from After method, but got none in \"%v\"", err)
}
}

@ -0,0 +1,13 @@
#! /bin/bash
_cli_bash_autocomplete() {
local cur prev opts base
COMPREPLY=()
cur="${COMP_WORDS[COMP_CWORD]}"
prev="${COMP_WORDS[COMP_CWORD-1]}"
opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion )
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )
return 0
}
complete -F _cli_bash_autocomplete $PROG

@ -0,0 +1,5 @@
autoload -U compinit && compinit
autoload -U bashcompinit && bashcompinit
script_dir=$(dirname $0)
source ${script_dir}/bash_autocomplete

@ -0,0 +1,40 @@
// Package cli provides a minimal framework for creating and organizing command line
// Go applications. cli is designed to be easy to understand and write, the most simple
// cli application can be written as follows:
// func main() {
// cli.NewApp().Run(os.Args)
// }
//
// Of course this application does not do much, so let's make this an actual application:
// func main() {
// app := cli.NewApp()
// app.Name = "greet"
// app.Usage = "say a greeting"
// app.Action = func(c *cli.Context) {
// println("Greetings")
// }
//
// app.Run(os.Args)
// }
package cli
import (
"strings"
)
type MultiError struct {
Errors []error
}
func NewMultiError(err ...error) MultiError {
return MultiError{Errors: err}
}
func (m MultiError) Error() string {
errs := make([]string, len(m.Errors))
for i, err := range m.Errors {
errs[i] = err.Error()
}
return strings.Join(errs, "\n")
}

@ -0,0 +1,98 @@
package cli
import (
"os"
)
func Example() {
app := NewApp()
app.Name = "todo"
app.Usage = "task list on the command line"
app.Commands = []Command{
{
Name: "add",
Aliases: []string{"a"},
Usage: "add a task to the list",
Action: func(c *Context) {
println("added task: ", c.Args().First())
},
},
{
Name: "complete",
Aliases: []string{"c"},
Usage: "complete a task on the list",
Action: func(c *Context) {
println("completed task: ", c.Args().First())
},
},
}
app.Run(os.Args)
}
func ExampleSubcommand() {
app := NewApp()
app.Name = "say"
app.Commands = []Command{
{
Name: "hello",
Aliases: []string{"hi"},
Usage: "use it to see a description",
Description: "This is how we describe hello the function",
Subcommands: []Command{
{
Name: "english",
Aliases: []string{"en"},
Usage: "sends a greeting in english",
Description: "greets someone in english",
Flags: []Flag{
StringFlag{
Name: "name",
Value: "Bob",
Usage: "Name of the person to greet",
},
},
Action: func(c *Context) {
println("Hello, ", c.String("name"))
},
}, {
Name: "spanish",
Aliases: []string{"sp"},
Usage: "sends a greeting in spanish",
Flags: []Flag{
StringFlag{
Name: "surname",
Value: "Jones",
Usage: "Surname of the person to greet",
},
},
Action: func(c *Context) {
println("Hola, ", c.String("surname"))
},
}, {
Name: "french",
Aliases: []string{"fr"},
Usage: "sends a greeting in french",
Flags: []Flag{
StringFlag{
Name: "nickname",
Value: "Stevie",
Usage: "Nickname of the person to greet",
},
},
Action: func(c *Context) {
println("Bonjour, ", c.String("nickname"))
},
},
},
}, {
Name: "bye",
Usage: "says goodbye",
Action: func(c *Context) {
println("bye")
},
},
}
app.Run(os.Args)
}

@ -0,0 +1,200 @@
package cli
import (
"fmt"
"io/ioutil"
"strings"
)
// Command is a subcommand for a cli.App.
type Command struct {
// The name of the command
Name string
// short name of the command. Typically one character (deprecated, use `Aliases`)
ShortName string
// A list of aliases for the command
Aliases []string
// A short description of the usage of this command
Usage string
// A longer explanation of how the command works
Description string
// The function to call when checking for bash command completions
BashComplete func(context *Context)
// An action to execute before any sub-subcommands are run, but after the context is ready
// If a non-nil error is returned, no sub-subcommands are run
Before func(context *Context) error
// An action to execute after any subcommands are run, but after the subcommand has finished
// It is run even if Action() panics
After func(context *Context) error
// The function to call when this command is invoked
Action func(context *Context)
// List of child commands
Subcommands []Command
// List of flags to parse
Flags []Flag
// Treat all flags as normal arguments if true
SkipFlagParsing bool
// Boolean to hide built-in help command
HideHelp bool
commandNamePath []string
}
// Returns the full name of the command.
// For subcommands this ensures that parent commands are part of the command path
func (c Command) FullName() string {
if c.commandNamePath == nil {
return c.Name
}
return strings.Join(c.commandNamePath, " ")
}
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags
func (c Command) Run(ctx *Context) error {
if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil {
return c.startApp(ctx)
}
if !c.HideHelp && (HelpFlag != BoolFlag{}) {
// append help to flags
c.Flags = append(
c.Flags,
HelpFlag,
)
}
if ctx.App.EnableBashCompletion {
c.Flags = append(c.Flags, BashCompletionFlag)
}
set := flagSet(c.Name, c.Flags)
set.SetOutput(ioutil.Discard)
firstFlagIndex := -1
terminatorIndex := -1
for index, arg := range ctx.Args() {
if arg == "--" {
terminatorIndex = index
break
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 {
firstFlagIndex = index
}
}
var err error
if firstFlagIndex > -1 && !c.SkipFlagParsing {
args := ctx.Args()
regularArgs := make([]string, len(args[1:firstFlagIndex]))
copy(regularArgs, args[1:firstFlagIndex])
var flagArgs []string
if terminatorIndex > -1 {
flagArgs = args[firstFlagIndex:terminatorIndex]
regularArgs = append(regularArgs, args[terminatorIndex:]...)
} else {
flagArgs = args[firstFlagIndex:]
}
err = set.Parse(append(flagArgs, regularArgs...))
} else {
err = set.Parse(ctx.Args().Tail())
}
if err != nil {
fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.")
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
return err
}
nerr := normalizeFlags(c.Flags, set)
if nerr != nil {
fmt.Fprintln(ctx.App.Writer, nerr)
fmt.Fprintln(ctx.App.Writer)
ShowCommandHelp(ctx, c.Name)
return nerr
}
context := NewContext(ctx.App, set, ctx)
if checkCommandCompletions(context, c.Name) {
return nil
}
if checkCommandHelp(context, c.Name) {
return nil
}
context.Command = c
c.Action(context)
return nil
}
func (c Command) Names() []string {
names := []string{c.Name}
if c.ShortName != "" {
names = append(names, c.ShortName)
}
return append(names, c.Aliases...)
}
// Returns true if Command.Name or Command.ShortName matches given name
func (c Command) HasName(name string) bool {
for _, n := range c.Names() {
if n == name {
return true
}
}
return false
}
func (c Command) startApp(ctx *Context) error {
app := NewApp()
// set the name and usage
app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
if c.Description != "" {
app.Usage = c.Description
} else {
app.Usage = c.Usage
}
// set CommandNotFound
app.CommandNotFound = ctx.App.CommandNotFound
// set the flags and commands
app.Commands = c.Subcommands
app.Flags = c.Flags
app.HideHelp = c.HideHelp
app.Version = ctx.App.Version
app.HideVersion = ctx.App.HideVersion
app.Compiled = ctx.App.Compiled
app.Author = ctx.App.Author
app.Email = ctx.App.Email
app.Writer = ctx.App.Writer
// bash completion
app.EnableBashCompletion = ctx.App.EnableBashCompletion
if c.BashComplete != nil {
app.BashComplete = c.BashComplete
}
// set the actions
app.Before = c.Before
app.After = c.After
if c.Action != nil {
app.Action = c.Action
} else {
app.Action = helpSubcommand.Action
}
var newCmds []Command
for _, cc := range app.Commands {
cc.commandNamePath = []string{c.Name, cc.Name}
newCmds = append(newCmds, cc)
}
app.Commands = newCmds
return app.RunAsSubcommand(ctx)
}

@ -0,0 +1,47 @@
package cli
import (
"flag"
"testing"
)
func TestCommandDoNotIgnoreFlags(t *testing.T) {
app := NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah", "-break"}
set.Parse(test)
c := NewContext(app, set, nil)
command := Command{
Name: "test-cmd",
Aliases: []string{"tc"},
Usage: "this is for testing",
Description: "testing",
Action: func(_ *Context) {},
}
err := command.Run(c)
expect(t, err.Error(), "flag provided but not defined: -break")
}
func TestCommandIgnoreFlags(t *testing.T) {
app := NewApp()
set := flag.NewFlagSet("test", 0)
test := []string{"blah", "blah"}
set.Parse(test)
c := NewContext(app, set, nil)
command := Command{
Name: "test-cmd",
Aliases: []string{"tc"},
Usage: "this is for testing",
Description: "testing",
Action: func(_ *Context) {},
SkipFlagParsing: true,
}
err := command.Run(c)
expect(t, err, nil)
}

@ -0,0 +1,388 @@
package cli
import (
"errors"
"flag"
"strconv"
"strings"
"time"
)
// Context is a type that is passed through to
// each Handler action in a cli application. Context
// can be used to retrieve context-specific Args and
// parsed command-line options.
type Context struct {
App *App
Command Command
flagSet *flag.FlagSet
setFlags map[string]bool
globalSetFlags map[string]bool
parentContext *Context
}
// Creates a new context. For use in when invoking an App or Command action.
func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context {
return &Context{App: app, flagSet: set, parentContext: parentCtx}
}
// Looks up the value of a local int flag, returns 0 if no int flag exists
func (c *Context) Int(name string) int {
return lookupInt(name, c.flagSet)
}
// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) Duration(name string) time.Duration {
return lookupDuration(name, c.flagSet)
}
// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists
func (c *Context) Float64(name string) float64 {
return lookupFloat64(name, c.flagSet)
}
// Looks up the value of a local bool flag, returns false if no bool flag exists
func (c *Context) Bool(name string) bool {
return lookupBool(name, c.flagSet)
}
// Looks up the value of a local boolT flag, returns false if no bool flag exists
func (c *Context) BoolT(name string) bool {
return lookupBoolT(name, c.flagSet)
}
// Looks up the value of a local string flag, returns "" if no string flag exists
func (c *Context) String(name string) string {
return lookupString(name, c.flagSet)
}
// Looks up the value of a local string slice flag, returns nil if no string slice flag exists
func (c *Context) StringSlice(name string) []string {
return lookupStringSlice(name, c.flagSet)
}
// Looks up the value of a local int slice flag, returns nil if no int slice flag exists
func (c *Context) IntSlice(name string) []int {
return lookupIntSlice(name, c.flagSet)
}
// Looks up the value of a local generic flag, returns nil if no generic flag exists
func (c *Context) Generic(name string) interface{} {
return lookupGeneric(name, c.flagSet)
}
// Looks up the value of a global int flag, returns 0 if no int flag exists
func (c *Context) GlobalInt(name string) int {
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupInt(name, fs)
}
return 0
}
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
func (c *Context) GlobalDuration(name string) time.Duration {
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupDuration(name, fs)
}
return 0
}
// Looks up the value of a global bool flag, returns false if no bool flag exists
func (c *Context) GlobalBool(name string) bool {
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupBool(name, fs)
}
return false
}
// Looks up the value of a global string flag, returns "" if no string flag exists
func (c *Context) GlobalString(name string) string {
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupString(name, fs)
}
return ""
}
// Looks up the value of a global string slice flag, returns nil if no string slice flag exists
func (c *Context) GlobalStringSlice(name string) []string {
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupStringSlice(name, fs)
}
return nil
}
// Looks up the value of a global int slice flag, returns nil if no int slice flag exists
func (c *Context) GlobalIntSlice(name string) []int {
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupIntSlice(name, fs)
}
return nil
}
// Looks up the value of a global generic flag, returns nil if no generic flag exists
func (c *Context) GlobalGeneric(name string) interface{} {
if fs := lookupGlobalFlagSet(name, c); fs != nil {
return lookupGeneric(name, fs)
}
return nil
}
// Returns the number of flags set
func (c *Context) NumFlags() int {
return c.flagSet.NFlag()
}
// Determines if the flag was actually set
func (c *Context) IsSet(name string) bool {
if c.setFlags == nil {
c.setFlags = make(map[string]bool)
c.flagSet.Visit(func(f *flag.Flag) {
c.setFlags[f.Name] = true
})
}
return c.setFlags[name] == true
}
// Determines if the global flag was actually set
func (c *Context) GlobalIsSet(name string) bool {
if c.globalSetFlags == nil {
c.globalSetFlags = make(map[string]bool)
ctx := c
if ctx.parentContext != nil {
ctx = ctx.parentContext
}
for ; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext {
ctx.flagSet.Visit(func(f *flag.Flag) {
c.globalSetFlags[f.Name] = true
})
}
}
return c.globalSetFlags[name]
}
// Returns a slice of flag names used in this context.
func (c *Context) FlagNames() (names []string) {
for _, flag := range c.Command.Flags {
name := strings.Split(flag.getName(), ",")[0]
if name == "help" {
continue
}
names = append(names, name)
}
return
}
// Returns a slice of global flag names used by the app.
func (c *Context) GlobalFlagNames() (names []string) {
for _, flag := range c.App.Flags {
name := strings.Split(flag.getName(), ",")[0]
if name == "help" || name == "version" {
continue
}
names = append(names, name)
}
return
}
// Returns the parent context, if any
func (c *Context) Parent() *Context {
return c.parentContext
}
type Args []string
// Returns the command line arguments associated with the context.
func (c *Context) Args() Args {
args := Args(c.flagSet.Args())
return args
}
// Returns the nth argument, or else a blank string
func (a Args) Get(n int) string {
if len(a) > n {
return a[n]
}
return ""
}
// Returns the first argument, or else a blank string
func (a Args) First() string {
return a.Get(0)
}
// Return the rest of the arguments (not the first one)
// or else an empty string slice
func (a Args) Tail() []string {
if len(a) >= 2 {
return []string(a)[1:]
}
return []string{}
}
// Checks if there are any arguments present
func (a Args) Present() bool {
return len(a) != 0
}
// Swaps arguments at the given indexes
func (a Args) Swap(from, to int) error {
if from >= len(a) || to >= len(a) {
return errors.New("index out of range")
}
a[from], a[to] = a[to], a[from]
return nil
}
func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet {
if ctx.parentContext != nil {
ctx = ctx.parentContext
}
for ; ctx != nil; ctx = ctx.parentContext {
if f := ctx.flagSet.Lookup(name); f != nil {
return ctx.flagSet
}
}
return nil
}
func lookupInt(name string, set *flag.FlagSet) int {
f := set.Lookup(name)
if f != nil {
val, err := strconv.Atoi(f.Value.String())
if err != nil {
return 0
}
return val
}
return 0
}
func lookupDuration(name string, set *flag.FlagSet) time.Duration {
f := set.Lookup(name)
if f != nil {
val, err := time.ParseDuration(f.Value.String())
if err == nil {
return val
}
}
return 0
}
func lookupFloat64(name string, set *flag.FlagSet) float64 {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseFloat(f.Value.String(), 64)
if err != nil {
return 0
}
return val
}
return 0
}
func lookupString(name string, set *flag.FlagSet) string {
f := set.Lookup(name)
if f != nil {
return f.Value.String()
}
return ""
}
func lookupStringSlice(name string, set *flag.FlagSet) []string {
f := set.Lookup(name)
if f != nil {
return (f.Value.(*StringSlice)).Value()
}
return nil
}
func lookupIntSlice(name string, set *flag.FlagSet) []int {
f := set.Lookup(name)
if f != nil {
return (f.Value.(*IntSlice)).Value()
}
return nil
}
func lookupGeneric(name string, set *flag.FlagSet) interface{} {
f := set.Lookup(name)
if f != nil {
return f.Value
}
return nil
}
func lookupBool(name string, set *flag.FlagSet) bool {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseBool(f.Value.String())
if err != nil {
return false
}
return val
}
return false
}
func lookupBoolT(name string, set *flag.FlagSet) bool {
f := set.Lookup(name)
if f != nil {
val, err := strconv.ParseBool(f.Value.String())
if err != nil {
return true
}
return val
}
return false
}
func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
switch ff.Value.(type) {
case *StringSlice:
default:
set.Set(name, ff.Value.String())
}
}
func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
visited := make(map[string]bool)
set.Visit(func(f *flag.Flag) {
visited[f.Name] = true
})
for _, f := range flags {
parts := strings.Split(f.getName(), ",")
if len(parts) == 1 {
continue
}
var ff *flag.Flag
for _, name := range parts {
name = strings.Trim(name, " ")
if visited[name] {
if ff != nil {
return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
}
ff = set.Lookup(name)
}
}
if ff == nil {
continue
}
for _, name := range parts {
name = strings.Trim(name, " ")
if !visited[name] {
copyFlag(name, ff, set)
}
}
}
return nil
}

@ -0,0 +1,113 @@
package cli
import (
"flag"
"testing"
"time"
)
func TestNewContext(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Int("myflag", 42, "doc")
globalCtx := NewContext(nil, globalSet, nil)
command := Command{Name: "mycommand"}
c := NewContext(nil, set, globalCtx)
c.Command = command
expect(t, c.Int("myflag"), 12)
expect(t, c.GlobalInt("myflag"), 42)
expect(t, c.Command.Name, "mycommand")
}
func TestContext_Int(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Int("myflag", 12, "doc")
c := NewContext(nil, set, nil)
expect(t, c.Int("myflag"), 12)
}
func TestContext_Duration(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Duration("myflag", time.Duration(12*time.Second), "doc")
c := NewContext(nil, set, nil)
expect(t, c.Duration("myflag"), time.Duration(12*time.Second))
}
func TestContext_String(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.String("myflag", "hello world", "doc")
c := NewContext(nil, set, nil)
expect(t, c.String("myflag"), "hello world")
}
func TestContext_Bool(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := NewContext(nil, set, nil)
expect(t, c.Bool("myflag"), false)
}
func TestContext_BoolT(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", true, "doc")
c := NewContext(nil, set, nil)
expect(t, c.BoolT("myflag"), true)
}
func TestContext_Args(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
c := NewContext(nil, set, nil)
set.Parse([]string{"--myflag", "bat", "baz"})
expect(t, len(c.Args()), 2)
expect(t, c.Bool("myflag"), true)
}
func TestContext_IsSet(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Bool("myflagGlobal", true, "doc")
globalCtx := NewContext(nil, globalSet, nil)
c := NewContext(nil, set, globalCtx)
set.Parse([]string{"--myflag", "bat", "baz"})
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
expect(t, c.IsSet("myflag"), true)
expect(t, c.IsSet("otherflag"), false)
expect(t, c.IsSet("bogusflag"), false)
expect(t, c.IsSet("myflagGlobal"), false)
}
func TestContext_GlobalIsSet(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Bool("myflagGlobal", true, "doc")
globalSet.Bool("myflagGlobalUnset", true, "doc")
globalCtx := NewContext(nil, globalSet, nil)
c := NewContext(nil, set, globalCtx)
set.Parse([]string{"--myflag", "bat", "baz"})
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
expect(t, c.GlobalIsSet("myflag"), false)
expect(t, c.GlobalIsSet("otherflag"), false)
expect(t, c.GlobalIsSet("bogusflag"), false)
expect(t, c.GlobalIsSet("myflagGlobal"), true)
expect(t, c.GlobalIsSet("myflagGlobalUnset"), false)
expect(t, c.GlobalIsSet("bogusGlobal"), false)
}
func TestContext_NumFlags(t *testing.T) {
set := flag.NewFlagSet("test", 0)
set.Bool("myflag", false, "doc")
set.String("otherflag", "hello world", "doc")
globalSet := flag.NewFlagSet("test", 0)
globalSet.Bool("myflagGlobal", true, "doc")
globalCtx := NewContext(nil, globalSet, nil)
c := NewContext(nil, set, globalCtx)
set.Parse([]string{"--myflag", "--otherflag=foo"})
globalSet.Parse([]string{"--myflagGlobal"})
expect(t, c.NumFlags(), 2)
}

@ -0,0 +1,497 @@
package cli
import (
"flag"
"fmt"
"os"
"strconv"
"strings"
"time"
)
// This flag enables bash-completion for all commands and subcommands
var BashCompletionFlag = BoolFlag{
Name: "generate-bash-completion",
}
// This flag prints the version for the application
var VersionFlag = BoolFlag{
Name: "version, v",
Usage: "print the version",
}
// This flag prints the help for all commands and subcommands
// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand
// unless HideHelp is set to true)
var HelpFlag = BoolFlag{
Name: "help, h",
Usage: "show help",
}
// Flag is a common interface related to parsing flags in cli.
// For more advanced flag parsing techniques, it is recomended that
// this interface be implemented.
type Flag interface {
fmt.Stringer
// Apply Flag settings to the given flag set
Apply(*flag.FlagSet)
getName() string
}
func flagSet(name string, flags []Flag) *flag.FlagSet {
set := flag.NewFlagSet(name, flag.ContinueOnError)
for _, f := range flags {
f.Apply(set)
}
return set
}
func eachName(longName string, fn func(string)) {
parts := strings.Split(longName, ",")
for _, name := range parts {
name = strings.Trim(name, " ")
fn(name)
}
}
// Generic is a generic parseable type identified by a specific flag
type Generic interface {
Set(value string) error
String() string
}
// GenericFlag is the flag type for types implementing Generic
type GenericFlag struct {
Name string
Value Generic
Usage string
EnvVar string
}
// String returns the string representation of the generic flag to display the
// help text to the user (uses the String() method of the generic flag to show
// the value)
func (f GenericFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage))
}
// Apply takes the flagset and calls Set on the generic flag with the value
// provided by the user for parsing by the flag
func (f GenericFlag) Apply(set *flag.FlagSet) {
val := f.Value
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
val.Set(envVal)
break
}
}
}
eachName(f.Name, func(name string) {
set.Var(f.Value, name, f.Usage)
})
}
func (f GenericFlag) getName() string {
return f.Name
}
// StringSlice is an opaque type for []string to satisfy flag.Value
type StringSlice []string
// Set appends the string value to the list of values
func (f *StringSlice) Set(value string) error {
*f = append(*f, value)
return nil
}
// String returns a readable representation of this value (for usage defaults)
func (f *StringSlice) String() string {
return fmt.Sprintf("%s", *f)
}
// Value returns the slice of strings set by this flag
func (f *StringSlice) Value() []string {
return *f
}
// StringSlice is a string flag that can be specified multiple times on the
// command-line
type StringSliceFlag struct {
Name string
Value *StringSlice
Usage string
EnvVar string
}
// String returns the usage
func (f StringSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f StringSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
newVal := &StringSlice{}
for _, s := range strings.Split(envVal, ",") {
s = strings.TrimSpace(s)
newVal.Set(s)
}
f.Value = newVal
break
}
}
}
eachName(f.Name, func(name string) {
if f.Value == nil {
f.Value = &StringSlice{}
}
set.Var(f.Value, name, f.Usage)
})
}
func (f StringSliceFlag) getName() string {
return f.Name
}
// StringSlice is an opaque type for []int to satisfy flag.Value
type IntSlice []int
// Set parses the value into an integer and appends it to the list of values
func (f *IntSlice) Set(value string) error {
tmp, err := strconv.Atoi(value)
if err != nil {
return err
} else {
*f = append(*f, tmp)
}
return nil
}
// String returns a readable representation of this value (for usage defaults)
func (f *IntSlice) String() string {
return fmt.Sprintf("%d", *f)
}
// Value returns the slice of ints set by this flag
func (f *IntSlice) Value() []int {
return *f
}
// IntSliceFlag is an int flag that can be specified multiple times on the
// command-line
type IntSliceFlag struct {
Name string
Value *IntSlice
Usage string
EnvVar string
}
// String returns the usage
func (f IntSliceFlag) String() string {
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
pref := prefixFor(firstName)
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f IntSliceFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
newVal := &IntSlice{}
for _, s := range strings.Split(envVal, ",") {
s = strings.TrimSpace(s)
err := newVal.Set(s)
if err != nil {
fmt.Fprintf(os.Stderr, err.Error())
}
}
f.Value = newVal
break
}
}
}
eachName(f.Name, func(name string) {
if f.Value == nil {
f.Value = &IntSlice{}
}
set.Var(f.Value, name, f.Usage)
})
}
func (f IntSliceFlag) getName() string {
return f.Name
}
// BoolFlag is a switch that defaults to false
type BoolFlag struct {
Name string
Usage string
EnvVar string
}
// String returns a readable representation of this value (for usage defaults)
func (f BoolFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f BoolFlag) Apply(set *flag.FlagSet) {
val := false
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
}
break
}
}
}
eachName(f.Name, func(name string) {
set.Bool(name, val, f.Usage)
})
}
func (f BoolFlag) getName() string {
return f.Name
}
// BoolTFlag this represents a boolean flag that is true by default, but can
// still be set to false by --some-flag=false
type BoolTFlag struct {
Name string
Usage string
EnvVar string
}
// String returns a readable representation of this value (for usage defaults)
func (f BoolTFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f BoolTFlag) Apply(set *flag.FlagSet) {
val := true
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValBool, err := strconv.ParseBool(envVal)
if err == nil {
val = envValBool
break
}
}
}
}
eachName(f.Name, func(name string) {
set.Bool(name, val, f.Usage)
})
}
func (f BoolTFlag) getName() string {
return f.Name
}
// StringFlag represents a flag that takes as string value
type StringFlag struct {
Name string
Value string
Usage string
EnvVar string
}
// String returns the usage
func (f StringFlag) String() string {
var fmtString string
fmtString = "%s %v\t%v"
if len(f.Value) > 0 {
fmtString = "%s \"%v\"\t%v"
} else {
fmtString = "%s %v\t%v"
}
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f StringFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
f.Value = envVal
break
}
}
}
eachName(f.Name, func(name string) {
set.String(name, f.Value, f.Usage)
})
}
func (f StringFlag) getName() string {
return f.Name
}
// IntFlag is a flag that takes an integer
// Errors if the value provided cannot be parsed
type IntFlag struct {
Name string
Value int
Usage string
EnvVar string
}
// String returns the usage
func (f IntFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f IntFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValInt, err := strconv.ParseInt(envVal, 0, 64)
if err == nil {
f.Value = int(envValInt)
break
}
}
}
}
eachName(f.Name, func(name string) {
set.Int(name, f.Value, f.Usage)
})
}
func (f IntFlag) getName() string {
return f.Name
}
// DurationFlag is a flag that takes a duration specified in Go's duration
// format: https://golang.org/pkg/time/#ParseDuration
type DurationFlag struct {
Name string
Value time.Duration
Usage string
EnvVar string
}
// String returns a readable representation of this value (for usage defaults)
func (f DurationFlag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f DurationFlag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValDuration, err := time.ParseDuration(envVal)
if err == nil {
f.Value = envValDuration
break
}
}
}
}
eachName(f.Name, func(name string) {
set.Duration(name, f.Value, f.Usage)
})
}
func (f DurationFlag) getName() string {
return f.Name
}
// Float64Flag is a flag that takes an float value
// Errors if the value provided cannot be parsed
type Float64Flag struct {
Name string
Value float64
Usage string
EnvVar string
}
// String returns the usage
func (f Float64Flag) String() string {
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
}
// Apply populates the flag given the flag set and environment
func (f Float64Flag) Apply(set *flag.FlagSet) {
if f.EnvVar != "" {
for _, envVar := range strings.Split(f.EnvVar, ",") {
envVar = strings.TrimSpace(envVar)
if envVal := os.Getenv(envVar); envVal != "" {
envValFloat, err := strconv.ParseFloat(envVal, 10)
if err == nil {
f.Value = float64(envValFloat)
}
}
}
}
eachName(f.Name, func(name string) {
set.Float64(name, f.Value, f.Usage)
})
}
func (f Float64Flag) getName() string {
return f.Name
}
func prefixFor(name string) (prefix string) {
if len(name) == 1 {
prefix = "-"
} else {
prefix = "--"
}
return
}
func prefixedNames(fullName string) (prefixed string) {
parts := strings.Split(fullName, ",")
for i, name := range parts {
name = strings.Trim(name, " ")
prefixed += prefixFor(name) + name
if i < len(parts)-1 {
prefixed += ", "
}
}
return
}
func withEnvHint(envVar, str string) string {
envText := ""
if envVar != "" {
envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $"))
}
return str + envText
}

@ -0,0 +1,740 @@
package cli
import (
"fmt"
"os"
"reflect"
"strings"
"testing"
)
var boolFlagTests = []struct {
name string
expected string
}{
{"help", "--help\t"},
{"h", "-h\t"},
}
func TestBoolFlagHelpOutput(t *testing.T) {
for _, test := range boolFlagTests {
flag := BoolFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
var stringFlagTests = []struct {
name string
value string
expected string
}{
{"help", "", "--help \t"},
{"h", "", "-h \t"},
{"h", "", "-h \t"},
{"test", "Something", "--test \"Something\"\t"},
}
func TestStringFlagHelpOutput(t *testing.T) {
for _, test := range stringFlagTests {
flag := StringFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestStringFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_FOO", "derp")
for _, test := range stringFlagTests {
flag := StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_FOO]") {
t.Errorf("%s does not end with [$APP_FOO]", output)
}
}
}
var stringSliceFlagTests = []struct {
name string
value *StringSlice
expected string
}{
{"help", func() *StringSlice {
s := &StringSlice{}
s.Set("")
return s
}(), "--help [--help option --help option]\t"},
{"h", func() *StringSlice {
s := &StringSlice{}
s.Set("")
return s
}(), "-h [-h option -h option]\t"},
{"h", func() *StringSlice {
s := &StringSlice{}
s.Set("")
return s
}(), "-h [-h option -h option]\t"},
{"test", func() *StringSlice {
s := &StringSlice{}
s.Set("Something")
return s
}(), "--test [--test option --test option]\t"},
}
func TestStringSliceFlagHelpOutput(t *testing.T) {
for _, test := range stringSliceFlagTests {
flag := StringSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_QWWX", "11,4")
for _, test := range stringSliceFlagTests {
flag := StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_QWWX]") {
t.Errorf("%q does not end with [$APP_QWWX]", output)
}
}
}
var intFlagTests = []struct {
name string
expected string
}{
{"help", "--help \"0\"\t"},
{"h", "-h \"0\"\t"},
}
func TestIntFlagHelpOutput(t *testing.T) {
for _, test := range intFlagTests {
flag := IntFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestIntFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_BAR", "2")
for _, test := range intFlagTests {
flag := IntFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var durationFlagTests = []struct {
name string
expected string
}{
{"help", "--help \"0\"\t"},
{"h", "-h \"0\"\t"},
}
func TestDurationFlagHelpOutput(t *testing.T) {
for _, test := range durationFlagTests {
flag := DurationFlag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_BAR", "2h3m6s")
for _, test := range durationFlagTests {
flag := DurationFlag{Name: test.name, EnvVar: "APP_BAR"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAR]") {
t.Errorf("%s does not end with [$APP_BAR]", output)
}
}
}
var intSliceFlagTests = []struct {
name string
value *IntSlice
expected string
}{
{"help", &IntSlice{}, "--help [--help option --help option]\t"},
{"h", &IntSlice{}, "-h [-h option -h option]\t"},
{"h", &IntSlice{}, "-h [-h option -h option]\t"},
{"test", func() *IntSlice {
i := &IntSlice{}
i.Set("9")
return i
}(), "--test [--test option --test option]\t"},
}
func TestIntSliceFlagHelpOutput(t *testing.T) {
for _, test := range intSliceFlagTests {
flag := IntSliceFlag{Name: test.name, Value: test.value}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_SMURF", "42,3")
for _, test := range intSliceFlagTests {
flag := IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_SMURF]") {
t.Errorf("%q does not end with [$APP_SMURF]", output)
}
}
}
var float64FlagTests = []struct {
name string
expected string
}{
{"help", "--help \"0\"\t"},
{"h", "-h \"0\"\t"},
}
func TestFloat64FlagHelpOutput(t *testing.T) {
for _, test := range float64FlagTests {
flag := Float64Flag{Name: test.name}
output := flag.String()
if output != test.expected {
t.Errorf("%s does not match %s", output, test.expected)
}
}
}
func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_BAZ", "99.4")
for _, test := range float64FlagTests {
flag := Float64Flag{Name: test.name, EnvVar: "APP_BAZ"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_BAZ]") {
t.Errorf("%s does not end with [$APP_BAZ]", output)
}
}
}
var genericFlagTests = []struct {
name string
value Generic
expected string
}{
{"test", &Parser{"abc", "def"}, "--test \"abc,def\"\ttest flag"},
{"t", &Parser{"abc", "def"}, "-t \"abc,def\"\ttest flag"},
}
func TestGenericFlagHelpOutput(t *testing.T) {
for _, test := range genericFlagTests {
flag := GenericFlag{Name: test.name, Value: test.value, Usage: "test flag"}
output := flag.String()
if output != test.expected {
t.Errorf("%q does not match %q", output, test.expected)
}
}
}
func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) {
os.Clearenv()
os.Setenv("APP_ZAP", "3")
for _, test := range genericFlagTests {
flag := GenericFlag{Name: test.name, EnvVar: "APP_ZAP"}
output := flag.String()
if !strings.HasSuffix(output, " [$APP_ZAP]") {
t.Errorf("%s does not end with [$APP_ZAP]", output)
}
}
}
func TestParseMultiString(t *testing.T) {
(&App{
Flags: []Flag{
StringFlag{Name: "serve, s"},
},
Action: func(ctx *Context) {
if ctx.String("serve") != "10" {
t.Errorf("main name not set")
}
if ctx.String("s") != "10" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10"})
}
func TestParseMultiStringFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_COUNT", "20")
(&App{
Flags: []Flag{
StringFlag{Name: "count, c", EnvVar: "APP_COUNT"},
},
Action: func(ctx *Context) {
if ctx.String("count") != "20" {
t.Errorf("main name not set")
}
if ctx.String("c") != "20" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_COUNT", "20")
(&App{
Flags: []Flag{
StringFlag{Name: "count, c", EnvVar: "COMPAT_COUNT,APP_COUNT"},
},
Action: func(ctx *Context) {
if ctx.String("count") != "20" {
t.Errorf("main name not set")
}
if ctx.String("c") != "20" {
t.Errorf("short name not set")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringSlice(t *testing.T) {
(&App{
Flags: []Flag{
StringSliceFlag{Name: "serve, s", Value: &StringSlice{}},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiStringSliceFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&App{
Flags: []Flag{
StringSliceFlag{Name: "intervals, i", Value: &StringSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiStringSliceFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&App{
Flags: []Flag{
StringSliceFlag{Name: "intervals, i", Value: &StringSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiInt(t *testing.T) {
a := App{
Flags: []Flag{
IntFlag{Name: "serve, s"},
},
Action: func(ctx *Context) {
if ctx.Int("serve") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("s") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10"})
}
func TestParseMultiIntFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "10")
a := App{
Flags: []Flag{
IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *Context) {
if ctx.Int("timeout") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("t") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiIntFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "10")
a := App{
Flags: []Flag{
IntFlag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *Context) {
if ctx.Int("timeout") != 10 {
t.Errorf("main name not set")
}
if ctx.Int("t") != 10 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiIntSlice(t *testing.T) {
(&App{
Flags: []Flag{
IntSliceFlag{Name: "serve, s", Value: &IntSlice{}},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) {
t.Errorf("short name not set")
}
},
}).Run([]string{"run", "-s", "10", "-s", "20"})
}
func TestParseMultiIntSliceFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&App{
Flags: []Flag{
IntSliceFlag{Name: "intervals, i", Value: &IntSlice{}, EnvVar: "APP_INTERVALS"},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiIntSliceFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_INTERVALS", "20,30,40")
(&App{
Flags: []Flag{
IntSliceFlag{Name: "intervals, i", Value: &IntSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
t.Errorf("short name not set from env")
}
},
}).Run([]string{"run"})
}
func TestParseMultiFloat64(t *testing.T) {
a := App{
Flags: []Flag{
Float64Flag{Name: "serve, s"},
},
Action: func(ctx *Context) {
if ctx.Float64("serve") != 10.2 {
t.Errorf("main name not set")
}
if ctx.Float64("s") != 10.2 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10.2"})
}
func TestParseMultiFloat64FromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
a := App{
Flags: []Flag{
Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *Context) {
if ctx.Float64("timeout") != 15.5 {
t.Errorf("main name not set")
}
if ctx.Float64("t") != 15.5 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiFloat64FromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
a := App{
Flags: []Flag{
Float64Flag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"},
},
Action: func(ctx *Context) {
if ctx.Float64("timeout") != 15.5 {
t.Errorf("main name not set")
}
if ctx.Float64("t") != 15.5 {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBool(t *testing.T) {
a := App{
Flags: []Flag{
BoolFlag{Name: "serve, s"},
},
Action: func(ctx *Context) {
if ctx.Bool("serve") != true {
t.Errorf("main name not set")
}
if ctx.Bool("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "1")
a := App{
Flags: []Flag{
BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *Context) {
if ctx.Bool("debug") != true {
t.Errorf("main name not set from env")
}
if ctx.Bool("d") != true {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "1")
a := App{
Flags: []Flag{
BoolFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"},
},
Action: func(ctx *Context) {
if ctx.Bool("debug") != true {
t.Errorf("main name not set from env")
}
if ctx.Bool("d") != true {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolT(t *testing.T) {
a := App{
Flags: []Flag{
BoolTFlag{Name: "serve, s"},
},
Action: func(ctx *Context) {
if ctx.BoolT("serve") != true {
t.Errorf("main name not set")
}
if ctx.BoolT("s") != true {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "--serve"})
}
func TestParseMultiBoolTFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "0")
a := App{
Flags: []Flag{
BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
},
Action: func(ctx *Context) {
if ctx.BoolT("debug") != false {
t.Errorf("main name not set from env")
}
if ctx.BoolT("d") != false {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseMultiBoolTFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_DEBUG", "0")
a := App{
Flags: []Flag{
BoolTFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"},
},
Action: func(ctx *Context) {
if ctx.BoolT("debug") != false {
t.Errorf("main name not set from env")
}
if ctx.BoolT("d") != false {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
type Parser [2]string
func (p *Parser) Set(value string) error {
parts := strings.Split(value, ",")
if len(parts) != 2 {
return fmt.Errorf("invalid format")
}
(*p)[0] = parts[0]
(*p)[1] = parts[1]
return nil
}
func (p *Parser) String() string {
return fmt.Sprintf("%s,%s", p[0], p[1])
}
func TestParseGeneric(t *testing.T) {
a := App{
Flags: []Flag{
GenericFlag{Name: "serve, s", Value: &Parser{}},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) {
t.Errorf("main name not set")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) {
t.Errorf("short name not set")
}
},
}
a.Run([]string{"run", "-s", "10,20"})
}
func TestParseGenericFromEnv(t *testing.T) {
os.Clearenv()
os.Setenv("APP_SERVE", "20,30")
a := App{
Flags: []Flag{
GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) {
t.Errorf("main name not set from env")
}
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) {
t.Errorf("short name not set from env")
}
},
}
a.Run([]string{"run"})
}
func TestParseGenericFromEnvCascade(t *testing.T) {
os.Clearenv()
os.Setenv("APP_FOO", "99,2000")
a := App{
Flags: []Flag{
GenericFlag{Name: "foos", Value: &Parser{}, EnvVar: "COMPAT_FOO,APP_FOO"},
},
Action: func(ctx *Context) {
if !reflect.DeepEqual(ctx.Generic("foos"), &Parser{"99", "2000"}) {
t.Errorf("value not set from env")
}
},
}
a.Run([]string{"run"})
}

@ -0,0 +1,238 @@
package cli
import (
"fmt"
"io"
"strings"
"text/tabwriter"
"text/template"
)
// The text template for the Default help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var AppHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
{{.Name}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} [arguments...]
{{if .Version}}
VERSION:
{{.Version}}
{{end}}{{if len .Authors}}
AUTHOR(S):
{{range .Authors}}{{ . }}{{end}}
{{end}}{{if .Commands}}
COMMANDS:
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
{{end}}{{end}}{{if .Flags}}
GLOBAL OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{end}}{{if .Copyright }}
COPYRIGHT:
{{.Copyright}}
{{end}}
`
// The text template for the command help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var CommandHelpTemplate = `NAME:
{{.FullName}} - {{.Usage}}
USAGE:
command {{.FullName}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}}
DESCRIPTION:
{{.Description}}{{end}}{{if .Flags}}
OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{ end }}
`
// The text template for the subcommand help topic.
// cli.go uses text/template to render templates. You can
// render custom help text by setting this variable.
var SubcommandHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}
USAGE:
{{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...]
COMMANDS:
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
{{end}}{{if .Flags}}
OPTIONS:
{{range .Flags}}{{.}}
{{end}}{{end}}
`
var helpCommand = Command{
Name: "help",
Aliases: []string{"h"},
Usage: "Shows a list of commands or help for one command",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
ShowCommandHelp(c, args.First())
} else {
ShowAppHelp(c)
}
},
}
var helpSubcommand = Command{
Name: "help",
Aliases: []string{"h"},
Usage: "Shows a list of commands or help for one command",
Action: func(c *Context) {
args := c.Args()
if args.Present() {
ShowCommandHelp(c, args.First())
} else {
ShowSubcommandHelp(c)
}
},
}
// Prints help for the App or Command
type helpPrinter func(w io.Writer, templ string, data interface{})
var HelpPrinter helpPrinter = printHelp
// Prints version for the App
var VersionPrinter = printVersion
func ShowAppHelp(c *Context) {
HelpPrinter(c.App.Writer, AppHelpTemplate, c.App)
}
// Prints the list of subcommands as the default app completion method
func DefaultAppComplete(c *Context) {
for _, command := range c.App.Commands {
for _, name := range command.Names() {
fmt.Fprintln(c.App.Writer, name)
}
}
}
// Prints help for the given command
func ShowCommandHelp(ctx *Context, command string) {
// show the subcommand help for a command with subcommands
if command == "" {
HelpPrinter(ctx.App.Writer, SubcommandHelpTemplate, ctx.App)
return
}
for _, c := range ctx.App.Commands {
if c.HasName(command) {
HelpPrinter(ctx.App.Writer, CommandHelpTemplate, c)
return
}
}
if ctx.App.CommandNotFound != nil {
ctx.App.CommandNotFound(ctx, command)
} else {
fmt.Fprintf(ctx.App.Writer, "No help topic for '%v'\n", command)
}
}
// Prints help for the given subcommand
func ShowSubcommandHelp(c *Context) {
ShowCommandHelp(c, c.Command.Name)
}
// Prints the version number of the App
func ShowVersion(c *Context) {
VersionPrinter(c)
}
func printVersion(c *Context) {
fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version)
}
// Prints the lists of commands within a given context
func ShowCompletions(c *Context) {
a := c.App
if a != nil && a.BashComplete != nil {
a.BashComplete(c)
}
}
// Prints the custom completions for a given command
func ShowCommandCompletions(ctx *Context, command string) {
c := ctx.App.Command(command)
if c != nil && c.BashComplete != nil {
c.BashComplete(ctx)
}
}
func printHelp(out io.Writer, templ string, data interface{}) {
funcMap := template.FuncMap{
"join": strings.Join,
}
w := tabwriter.NewWriter(out, 0, 8, 1, '\t', 0)
t := template.Must(template.New("help").Funcs(funcMap).Parse(templ))
err := t.Execute(w, data)
if err != nil {
panic(err)
}
w.Flush()
}
func checkVersion(c *Context) bool {
if c.GlobalBool("version") || c.GlobalBool("v") || c.Bool("version") || c.Bool("v") {
ShowVersion(c)
return true
}
return false
}
func checkHelp(c *Context) bool {
if c.GlobalBool("h") || c.GlobalBool("help") || c.Bool("h") || c.Bool("help") {
ShowAppHelp(c)
return true
}
return false
}
func checkCommandHelp(c *Context, name string) bool {
if c.Bool("h") || c.Bool("help") {
ShowCommandHelp(c, name)
return true
}
return false
}
func checkSubcommandHelp(c *Context) bool {
if c.GlobalBool("h") || c.GlobalBool("help") {
ShowSubcommandHelp(c)
return true
}
return false
}
func checkCompletions(c *Context) bool {
if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion {
ShowCompletions(c)
return true
}
return false
}
func checkCommandCompletions(c *Context, name string) bool {
if c.Bool(BashCompletionFlag.Name) && c.App.EnableBashCompletion {
ShowCommandCompletions(c, name)
return true
}
return false
}

@ -0,0 +1,36 @@
package cli
import (
"bytes"
"testing"
)
func Test_ShowAppHelp_NoAuthor(t *testing.T) {
output := new(bytes.Buffer)
app := NewApp()
app.Writer = output
c := NewContext(app, nil, nil)
ShowAppHelp(c)
if bytes.Index(output.Bytes(), []byte("AUTHOR(S):")) != -1 {
t.Errorf("expected\n%snot to include %s", output.String(), "AUTHOR(S):")
}
}
func Test_ShowAppHelp_NoVersion(t *testing.T) {
output := new(bytes.Buffer)
app := NewApp()
app.Writer = output
app.Version = ""
c := NewContext(app, nil, nil)
ShowAppHelp(c)
if bytes.Index(output.Bytes(), []byte("VERSION:")) != -1 {
t.Errorf("expected\n%snot to include %s", output.String(), "VERSION:")
}
}

@ -0,0 +1,19 @@
package cli
import (
"reflect"
"testing"
)
/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}
func refute(t *testing.T, a interface{}, b interface{}) {
if a == b {
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
}
}

@ -0,0 +1,22 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe

@ -0,0 +1,24 @@
# kproc
[![GoDoc](https://godoc.org/github.com/codeskyblue/kproc?status.svg)](https://godoc.org/github.com/codeskyblue/kproc)
This is a golang lib, offer a better way to kill all child process.
Tested on _windows, linux, darwin._
This lib has been used in [fswatch](https://github.com/codeskyblue/fswatch).
## Usage
go get -v github.com/codeskyblue/kproc
example:
func main() {
p := kproc.ProcString("python flask_main.py")
p.Start()
time.Sleep(3 * time.Second)
err := p.Terminate(syscall.SIGKILL)
if err != nil {
log.Println(err)
}
}

@ -0,0 +1,7 @@
package kproc
import "os/exec"
type Process struct {
*exec.Cmd
}

@ -0,0 +1,45 @@
// +build !windows
package kproc
import (
"log"
"os"
"os/exec"
"syscall"
)
func setupCmd(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{}
cmd.SysProcAttr.Setsid = true
}
func ProcCommand(cmd *exec.Cmd) *Process {
setupCmd(cmd)
return &Process{
Cmd: cmd,
}
}
func ProcString(command string) *Process {
cmd := exec.Command("/bin/bash", "-c", command)
setupCmd(cmd)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return &Process{
Cmd: cmd,
}
}
func (p *Process) Terminate(sig os.Signal) (err error) {
if p.Process == nil {
return
}
// find pgid, ref: http://unix.stackexchange.com/questions/14815/process-descendants
group, err := os.FindProcess(-1 * p.Process.Pid)
log.Println(group)
if err == nil {
err = group.Signal(sig)
}
return err
}

@ -0,0 +1,33 @@
package kproc
import (
"os"
"os/exec"
"strconv"
)
func ProcCommand(cmd *exec.Cmd) *Process {
return &Process{
Cmd: cmd,
}
}
func ProcString(command string) *Process {
cmd := exec.Command("cmd", "/c", command)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return &Process{
Cmd: cmd,
}
}
func (p *Process) Terminate(sig os.Signal) (err error) {
if p.Process == nil {
return nil
}
pid := p.Process.Pid
c := exec.Command("taskkill", "/t", "/f", "/pid", strconv.Itoa(pid))
c.Stdout = os.Stdout
c.Stderr = os.Stderr
return c.Run()
}

@ -0,0 +1 @@
web: python flask_main.py

@ -0,0 +1,11 @@
import flask
app = flask.Flask(__name__)
@app.route('/')
def homepage():
return 'Home'
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0')

@ -0,0 +1,22 @@
package main
import (
"fmt"
"kproc"
"log"
"os/exec"
"syscall"
"time"
)
func main() {
p := kproc.ProcString("python flask_main.py")
p.Start()
time.Sleep(10 * time.Second)
err := p.Terminate(syscall.SIGKILL)
if err != nil {
log.Println(err)
}
out, _ := exec.Command("lsof", "-i:5000").CombinedOutput()
fmt.Println(string(out))
}

Binary file not shown.

@ -0,0 +1,24 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
src

@ -0,0 +1,8 @@
language: go
go:
- 1.2
- tip
notifications:
email:
- ionathan@gmail.com
- marcosnils@gmail.com

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2013 Jonathan Leibiusky and Marcos Lilljedahl
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,3 @@
test:
go get -v -d -t ./...
go test -v

@ -0,0 +1,444 @@
[![Build Status](https://img.shields.io/travis/franela/goreq/master.svg)](https://travis-ci.org/franela/goreq)
[![GoDoc](https://godoc.org/github.com/franela/goreq?status.svg)](https://godoc.org/github.com/franela/goreq)
GoReq
=======
Simple and sane HTTP request library for Go language.
**Table of Contents**
- [Why GoReq?](#user-content-why-goreq)
- [How do I install it?](#user-content-how-do-i-install-it)
- [What can I do with it?](#user-content-what-can-i-do-with-it)
- [Making requests with different methods](#user-content-making-requests-with-different-methods)
- [GET](#user-content-get)
- [Tags](#user-content-tags)
- [POST](#user-content-post)
- [Sending payloads in the Body](#user-content-sending-payloads-in-the-body)
- [Specifiying request headers](#user-content-specifiying-request-headers)
- [Sending Cookies](#cookie-support)
- [Setting timeouts](#user-content-setting-timeouts)
- [Using the Response and Error](#user-content-using-the-response-and-error)
- [Receiving JSON](#user-content-receiving-json)
- [Sending/Receiving Compressed Payloads](#user-content-sendingreceiving-compressed-payloads)
- [Using gzip compression:](#user-content-using-gzip-compression)
- [Using deflate compression:](#user-content-using-deflate-compression)
- [Using compressed responses:](#user-content-using-compressed-responses)
- [Proxy](#proxy)
- [Debugging requests](#debug)
- [Getting raw Request & Response](#getting-raw-request--response)
- [TODO:](#user-content-todo)
Why GoReq?
==========
Go has very nice native libraries that allows you to do lots of cool things. But sometimes those libraries are too low level, which means that to do a simple thing, like an HTTP Request, it takes some time. And if you want to do something as simple as adding a timeout to a request, you will end up writing several lines of code.
This is why we think GoReq is useful. Because you can do all your HTTP requests in a very simple and comprehensive way, while enabling you to do more advanced stuff by giving you access to the native API.
How do I install it?
====================
```bash
go get github.com/franela/goreq
```
What can I do with it?
======================
## Making requests with different methods
#### GET
```go
res, err := goreq.Request{ Uri: "http://www.google.com" }.Do()
```
GoReq default method is GET.
You can also set value to GET method easily
```go
type Item struct {
Limit int
Skip int
Fields string
}
item := Item {
Limit: 3,
Skip: 5,
Fields: "Value",
}
res, err := goreq.Request{
Uri: "http://localhost:3000/",
QueryString: item,
}.Do()
```
The sample above will send `http://localhost:3000/?limit=3&skip=5&fields=Value`
Alternatively the `url` tag can be used in struct fields to customize encoding properties
```go
type Item struct {
TheLimit int `url:"the_limit"`
TheSkip string `url:"the_skip,omitempty"`
TheFields string `url:"-"`
}
item := Item {
TheLimit: 3,
TheSkip: "",
TheFields: "Value",
}
res, err := goreq.Request{
Uri: "http://localhost:3000/",
QueryString: item,
}.Do()
```
The sample above will send `http://localhost:3000/?the_limit=3`
QueryString also support url.Values
```go
item := url.Values{}
item.Set("Limit", 3)
item.Add("Field", "somefield")
item.Add("Field", "someotherfield")
res, err := goreq.Request{
Uri: "http://localhost:3000/",
QueryString: item,
}.Do()
```
The sample above will send `http://localhost:3000/?limit=3&field=somefield&field=someotherfield`
### Tags
Struct field `url` tag is mainly used as the request parameter name.
Tags can be comma separated multiple values, 1st value is for naming and rest has special meanings.
- special tag for 1st value
- `-`: value is ignored if set this
- special tag for rest 2nd value
- `omitempty`: zero-value is ignored if set this
- `squash`: the fields of embedded struct is used for parameter
#### Tag Examples
```go
type Place struct {
Country string `url:"country"`
City string `url:"city"`
ZipCode string `url:"zipcode,omitempty"`
}
type Person struct {
Place `url:",squash"`
FirstName string `url:"first_name"`
LastName string `url:"last_name"`
Age string `url:"age,omitempty"`
Password string `url:"-"`
}
johnbull := Person{
Place: Place{ // squash the embedded struct value
Country: "UK",
City: "London",
ZipCode: "SW1",
},
FirstName: "John",
LastName: "Doe",
Age: "35",
Password: "my-secret", // ignored for parameter
}
goreq.Request{
Uri: "http://localhost/",
QueryString: johnbull,
}.Do()
// => `http://localhost/?first_name=John&last_name=Doe&age=35&country=UK&city=London&zip_code=SW1`
// age and zipcode will be ignored because of `omitempty`
// but firstname isn't.
samurai := Person{
Place: Place{ // squash the embedded struct value
Country: "Japan",
City: "Tokyo",
},
LastName: "Yagyu",
}
goreq.Request{
Uri: "http://localhost/",
QueryString: samurai,
}.Do()
// => `http://localhost/?first_name=&last_name=yagyu&country=Japan&city=Tokyo`
```
#### POST
```go
res, err := goreq.Request{ Method: "POST", Uri: "http://www.google.com" }.Do()
```
## Sending payloads in the Body
You can send ```string```, ```Reader``` or ```interface{}``` in the body. The first two will be sent as text. The last one will be marshalled to JSON, if possible.
```go
type Item struct {
Id int
Name string
}
item := Item{ Id: 1111, Name: "foobar" }
res, err := goreq.Request{
Method: "POST",
Uri: "http://www.google.com",
Body: item,
}.Do()
```
## Specifiying request headers
We think that most of the times the request headers that you use are: ```Host```, ```Content-Type```, ```Accept``` and ```User-Agent```. This is why we decided to make it very easy to set these headers.
```go
res, err := goreq.Request{
Uri: "http://www.google.com",
Host: "foobar.com",
Accept: "application/json",
ContentType: "application/json",
UserAgent: "goreq",
}.Do()
```
But sometimes you need to set other headers. You can still do it.
```go
req := goreq.Request{ Uri: "http://www.google.com" }
req.AddHeader("X-Custom", "somevalue")
req.Do()
```
Alternatively you can use the `WithHeader` function to keep the syntax short
```go
res, err = goreq.Request{ Uri: "http://www.google.com" }.WithHeader("X-Custom", "somevalue").Do()
```
## Cookie support
Cookies can be either set at the request level by sending a [CookieJar](http://golang.org/pkg/net/http/cookiejar/) in the `CookieJar` request field
or you can use goreq's one-liner WithCookie method as shown below
```go
res, err := goreq.Request{
Uri: "http://www.google.com",
}.
WithCookie(&http.Cookie{Name: "c1", Value: "v1"}).
Do()
```
## Setting timeouts
GoReq supports 2 kind of timeouts. A general connection timeout and a request specific one. By default the connection timeout is of 1 second. There is no default for request timeout, which means it will wait forever.
You can change the connection timeout doing:
```go
goreq.SetConnectTimeout(100 * time.Millisecond)
```
And specify the request timeout doing:
```go
res, err := goreq.Request{
Uri: "http://www.google.com",
Timeout: 500 * time.Millisecond,
}.Do()
```
## Using the Response and Error
GoReq will always return 2 values: a ```Response``` and an ```Error```.
If ```Error``` is not ```nil``` it means that an error happened while doing the request and you shouldn't use the ```Response``` in any way.
You can check what happened by getting the error message:
```go
fmt.Println(err.Error())
```
And to make it easy to know if it was a timeout error, you can ask the error or return it:
```go
if serr, ok := err.(*goreq.Error); ok {
if serr.Timeout() {
...
}
}
return err
```
If you don't get an error, you can safely use the ```Response```.
```go
res.Uri // return final URL location of the response (fulfilled after redirect was made)
res.StatusCode // return the status code of the response
res.Body // gives you access to the body
res.Body.ToString() // will return the body as a string
res.Header.Get("Content-Type") // gives you access to all the response headers
```
Remember that you should **always** close `res.Body` if it's not `nil`
## Receiving JSON
GoReq will help you to receive and unmarshal JSON.
```go
type Item struct {
Id int
Name string
}
var item Item
res.Body.FromJsonTo(&item)
```
## Sending/Receiving Compressed Payloads
GoReq supports gzip, deflate and zlib compression of requests' body and transparent decompression of responses provided they have a correct `Content-Encoding` header.
#####Using gzip compression:
```go
res, err := goreq.Request{
Method: "POST",
Uri: "http://www.google.com",
Body: item,
Compression: goreq.Gzip(),
}.Do()
```
#####Using deflate/zlib compression:
```go
res, err := goreq.Request{
Method: "POST",
Uri: "http://www.google.com",
Body: item,
Compression: goreq.Deflate(),
}.Do()
```
#####Using compressed responses:
If servers replies a correct and matching `Content-Encoding` header (gzip requires `Content-Encoding: gzip` and deflate `Content-Encoding: deflate`) goreq transparently decompresses the response so the previous example should always work:
```go
type Item struct {
Id int
Name string
}
res, err := goreq.Request{
Method: "POST",
Uri: "http://www.google.com",
Body: item,
Compression: goreq.Gzip(),
}.Do()
var item Item
res.Body.FromJsonTo(&item)
```
If no `Content-Encoding` header is replied by the server GoReq will return the crude response.
## Proxy
If you need to use a proxy for your requests GoReq supports the standard `http_proxy` env variable as well as manually setting the proxy for each request
```go
res, err := goreq.Request{
Method: "GET",
Proxy: "http://myproxy:myproxyport",
Uri: "http://www.google.com",
}.Do()
```
### Proxy basic auth is also supported
```go
res, err := goreq.Request{
Method: "GET",
Proxy: "http://user:pass@myproxy:myproxyport",
Uri: "http://www.google.com",
}.Do()
```
## Debug
If you need to debug your http requests, it can print the http request detail.
```go
res, err := goreq.Request{
Method: "GET",
Uri: "http://www.google.com",
Compression: goreq.Gzip(),
ShowDebug: true,
}.Do()
fmt.Println(res, err)
```
and it will print the log:
```
GET / HTTP/1.1
Host: www.google.com
Accept:
Accept-Encoding: gzip
Content-Encoding: gzip
Content-Type:
```
### Getting raw Request & Response
To get the Request:
```go
req := goreq.Request{
Host: "foobar.com",
}
//req.Request will return a new instance of an http.Request so you can safely use it for something else
request, _ := req.NewRequest()
```
To get the Response:
```go
res, err := goreq.Request{
Method: "GET",
Uri: "http://www.google.com",
Compression: goreq.Gzip(),
ShowDebug: true,
}.Do()
// res.Response will contain the original http.Response structure
fmt.Println(res.Response, err)
```
TODO:
-----
We do have a couple of [issues](https://github.com/franela/goreq/issues) pending we'll be addressing soon. But feel free to
contribute and send us PRs (with tests please :smile:).

@ -0,0 +1,494 @@
package goreq
import (
"bufio"
"bytes"
"compress/gzip"
"compress/zlib"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"reflect"
"strings"
"time"
)
type Request struct {
headers []headerTuple
cookies []*http.Cookie
Method string
Uri string
Body interface{}
QueryString interface{}
Timeout time.Duration
ContentType string
Accept string
Host string
UserAgent string
Insecure bool
MaxRedirects int
RedirectHeaders bool
Proxy string
Compression *compression
BasicAuthUsername string
BasicAuthPassword string
CookieJar http.CookieJar
ShowDebug bool
OnBeforeRequest func(goreq *Request, httpreq *http.Request)
}
type compression struct {
writer func(buffer io.Writer) (io.WriteCloser, error)
reader func(buffer io.Reader) (io.ReadCloser, error)
ContentEncoding string
}
type Response struct {
*http.Response
Uri string
Body *Body
req *http.Request
}
func (r Response) CancelRequest() {
cancelRequest(r.req)
}
func cancelRequest(r *http.Request) {
if transport, ok := DefaultTransport.(transportRequestCanceler); ok {
transport.CancelRequest(r)
}
}
type headerTuple struct {
name string
value string
}
type Body struct {
reader io.ReadCloser
compressedReader io.ReadCloser
}
type Error struct {
timeout bool
Err error
}
type transportRequestCanceler interface {
CancelRequest(*http.Request)
}
func (e *Error) Timeout() bool {
return e.timeout
}
func (e *Error) Error() string {
return e.Err.Error()
}
func (b *Body) Read(p []byte) (int, error) {
if b.compressedReader != nil {
return b.compressedReader.Read(p)
}
return b.reader.Read(p)
}
func (b *Body) Close() error {
err := b.reader.Close()
if b.compressedReader != nil {
return b.compressedReader.Close()
}
return err
}
func (b *Body) FromJsonTo(o interface{}) error {
return json.NewDecoder(b).Decode(o)
}
func (b *Body) ToString() (string, error) {
body, err := ioutil.ReadAll(b)
if err != nil {
return "", err
}
return string(body), nil
}
func Gzip() *compression {
reader := func(buffer io.Reader) (io.ReadCloser, error) {
return gzip.NewReader(buffer)
}
writer := func(buffer io.Writer) (io.WriteCloser, error) {
return gzip.NewWriter(buffer), nil
}
return &compression{writer: writer, reader: reader, ContentEncoding: "gzip"}
}
func Deflate() *compression {
reader := func(buffer io.Reader) (io.ReadCloser, error) {
return zlib.NewReader(buffer)
}
writer := func(buffer io.Writer) (io.WriteCloser, error) {
return zlib.NewWriter(buffer), nil
}
return &compression{writer: writer, reader: reader, ContentEncoding: "deflate"}
}
func Zlib() *compression {
return Deflate()
}
func paramParse(query interface{}) (string, error) {
switch query.(type) {
case url.Values:
return query.(url.Values).Encode(), nil
case *url.Values:
return query.(*url.Values).Encode(), nil
default:
var v = &url.Values{}
err := paramParseStruct(v, query)
return v.Encode(), err
}
}
func paramParseStruct(v *url.Values, query interface{}) error {
var (
s = reflect.ValueOf(query)
t = reflect.TypeOf(query)
)
for t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface {
s = s.Elem()
t = s.Type()
}
if t.Kind() != reflect.Struct {
return errors.New("Can not parse QueryString.")
}
for i := 0; i < t.NumField(); i++ {
var name string
field := s.Field(i)
typeField := t.Field(i)
if !field.CanInterface() {
continue
}
urlTag := typeField.Tag.Get("url")
if urlTag == "-" {
continue
}
name, opts := parseTag(urlTag)
var omitEmpty, squash bool
omitEmpty = opts.Contains("omitempty")
squash = opts.Contains("squash")
if squash {
err := paramParseStruct(v, field.Interface())
if err != nil {
return err
}
continue
}
if urlTag == "" {
name = strings.ToLower(typeField.Name)
}
if val := fmt.Sprintf("%v", field.Interface()); !(omitEmpty && len(val) == 0) {
v.Add(name, val)
}
}
return nil
}
func prepareRequestBody(b interface{}) (io.Reader, error) {
switch b.(type) {
case string:
// treat is as text
return strings.NewReader(b.(string)), nil
case io.Reader:
// treat is as text
return b.(io.Reader), nil
case []byte:
//treat as byte array
return bytes.NewReader(b.([]byte)), nil
case nil:
return nil, nil
default:
// try to jsonify it
j, err := json.Marshal(b)
if err == nil {
return bytes.NewReader(j), nil
}
return nil, err
}
}
var DefaultDialer = &net.Dialer{Timeout: 1000 * time.Millisecond}
var DefaultTransport http.RoundTripper = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyFromEnvironment}
var DefaultClient = &http.Client{Transport: DefaultTransport}
var proxyTransport http.RoundTripper
var proxyClient *http.Client
func SetConnectTimeout(duration time.Duration) {
DefaultDialer.Timeout = duration
}
func (r *Request) AddHeader(name string, value string) {
if r.headers == nil {
r.headers = []headerTuple{}
}
r.headers = append(r.headers, headerTuple{name: name, value: value})
}
func (r Request) WithHeader(name string, value string) Request {
r.AddHeader(name, value)
return r
}
func (r *Request) AddCookie(c *http.Cookie) {
r.cookies = append(r.cookies, c)
}
func (r Request) WithCookie(c *http.Cookie) Request {
r.AddCookie(c)
return r
}
func (r Request) Do() (*Response, error) {
var client = DefaultClient
var transport = DefaultTransport
var resUri string
var redirectFailed bool
r.Method = valueOrDefault(r.Method, "GET")
// use a client with a cookie jar if necessary. We create a new client not
// to modify the default one.
if r.CookieJar != nil {
client = &http.Client{
Transport: transport,
Jar: r.CookieJar,
}
}
if r.Proxy != "" {
proxyUrl, err := url.Parse(r.Proxy)
if err != nil {
// proxy address is in a wrong format
return nil, &Error{Err: err}
}
//If jar is specified new client needs to be built
if proxyTransport == nil || client.Jar != nil {
proxyTransport = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyURL(proxyUrl)}
proxyClient = &http.Client{Transport: proxyTransport, Jar: client.Jar}
} else if proxyTransport, ok := proxyTransport.(*http.Transport); ok {
proxyTransport.Proxy = http.ProxyURL(proxyUrl)
}
transport = proxyTransport
client = proxyClient
}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > r.MaxRedirects {
redirectFailed = true
return errors.New("Error redirecting. MaxRedirects reached")
}
resUri = req.URL.String()
//By default Golang will not redirect request headers
// https://code.google.com/p/go/issues/detail?id=4800&q=request%20header
if r.RedirectHeaders {
for key, val := range via[0].Header {
req.Header[key] = val
}
}
return nil
}
if transport, ok := transport.(*http.Transport); ok {
if r.Insecure {
if transport.TLSClientConfig != nil {
transport.TLSClientConfig.InsecureSkipVerify = true
} else {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
} else if transport.TLSClientConfig != nil {
// the default TLS client (when transport.TLSClientConfig==nil) is
// already set to verify, so do nothing in that case
transport.TLSClientConfig.InsecureSkipVerify = false
}
}
req, err := r.NewRequest()
if err != nil {
// we couldn't parse the URL.
return nil, &Error{Err: err}
}
timeout := false
var timer *time.Timer
if r.Timeout > 0 {
timer = time.AfterFunc(r.Timeout, func() {
cancelRequest(req)
timeout = true
})
}
if r.ShowDebug {
dump, err := httputil.DumpRequest(req, true)
if err != nil {
log.Println(err)
}
log.Println(string(dump))
}
if r.OnBeforeRequest != nil {
r.OnBeforeRequest(&r, req)
}
res, err := client.Do(req)
if timer != nil {
timer.Stop()
}
if err != nil {
if !timeout {
switch err := err.(type) {
case *net.OpError:
timeout = err.Timeout()
case *url.Error:
if op, ok := err.Err.(*net.OpError); ok {
timeout = op.Timeout()
}
}
}
var response *Response
//If redirect fails we still want to return response data
if redirectFailed {
if res != nil {
response = &Response{res, resUri, &Body{reader: res.Body}, req}
} else {
response = &Response{res, resUri, nil, req}
}
}
//If redirect fails and we haven't set a redirect count we shouldn't return an error
if redirectFailed && r.MaxRedirects == 0 {
return response, nil
}
return response, &Error{timeout: timeout, Err: err}
}
if r.Compression != nil && strings.Contains(res.Header.Get("Content-Encoding"), r.Compression.ContentEncoding) {
compressedReader, err := r.Compression.reader(res.Body)
if err != nil {
return nil, &Error{Err: err}
}
return &Response{res, resUri, &Body{reader: res.Body, compressedReader: compressedReader}, req}, nil
}
return &Response{res, resUri, &Body{reader: res.Body}, req}, nil
}
func (r Request) addHeaders(headersMap http.Header) {
if len(r.UserAgent) > 0 {
headersMap.Add("User-Agent", r.UserAgent)
}
if r.Accept != "" {
headersMap.Add("Accept", r.Accept)
}
if r.ContentType != "" {
headersMap.Add("Content-Type", r.ContentType)
}
}
func (r Request) NewRequest() (*http.Request, error) {
b, e := prepareRequestBody(r.Body)
if e != nil {
// there was a problem marshaling the body
return nil, &Error{Err: e}
}
if r.QueryString != nil {
param, e := paramParse(r.QueryString)
if e != nil {
return nil, &Error{Err: e}
}
r.Uri = r.Uri + "?" + param
}
var bodyReader io.Reader
if b != nil && r.Compression != nil {
buffer := bytes.NewBuffer([]byte{})
readBuffer := bufio.NewReader(b)
writer, err := r.Compression.writer(buffer)
if err != nil {
return nil, &Error{Err: err}
}
_, e = readBuffer.WriteTo(writer)
writer.Close()
if e != nil {
return nil, &Error{Err: e}
}
bodyReader = buffer
} else {
bodyReader = b
}
req, err := http.NewRequest(r.Method, r.Uri, bodyReader)
if err != nil {
return nil, err
}
// add headers to the request
req.Host = r.Host
r.addHeaders(req.Header)
if r.Compression != nil {
req.Header.Add("Content-Encoding", r.Compression.ContentEncoding)
req.Header.Add("Accept-Encoding", r.Compression.ContentEncoding)
}
if r.headers != nil {
for _, header := range r.headers {
req.Header.Add(header.name, header.value)
}
}
//use basic auth if required
if r.BasicAuthUsername != "" {
req.SetBasicAuth(r.BasicAuthUsername, r.BasicAuthPassword)
}
for _, c := range r.cookies {
req.AddCookie(c)
}
return req, nil
}
// Return value if nonempty, def otherwise.
func valueOrDefault(value, def string) string {
if value != "" {
return value
}
return def
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,64 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package goreq
import (
"strings"
"unicode"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
if idx := strings.Index(tag, ","); idx != -1 {
return tag[:idx], tagOptions(tag[idx+1:])
}
return tag, tagOptions("")
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var next string
i := strings.Index(s, ",")
if i >= 0 {
s, next = s[:i], s[i+1:]
}
if s == optionName {
return true
}
s = next
}
return false
}
func isValidTag(s string) bool {
if s == "" {
return false
}
for _, c := range s {
switch {
case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c):
// Backslash and quote chars are reserved, but
// otherwise any punctuation chars are allowed
// in a tag name.
default:
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
return false
}
}
}
return true
}

@ -0,0 +1,43 @@
# Go support for Protocol Buffers - Google's data interchange format
#
# Copyright 2010 The Go Authors. All rights reserved.
# https://github.com/golang/protobuf
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
install:
go install
test: install generate-test-pbs
go test
generate-test-pbs:
make install
make -C testdata
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata:. proto3_proto/proto3.proto
make

File diff suppressed because it is too large Load Diff

@ -0,0 +1,223 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer deep copy and merge.
// TODO: MessageSet and RawMessage.
package proto
import (
"log"
"reflect"
"strings"
)
// Clone returns a deep copy of a protocol buffer.
func Clone(pb Message) Message {
in := reflect.ValueOf(pb)
if in.IsNil() {
return pb
}
out := reflect.New(in.Type().Elem())
// out is empty so a merge is a deep copy.
mergeStruct(out.Elem(), in.Elem())
return out.Interface().(Message)
}
// Merge merges src into dst.
// Required and optional fields that are set in src will be set to that value in dst.
// Elements of repeated fields will be appended.
// Merge panics if src and dst are not the same type, or if dst is nil.
func Merge(dst, src Message) {
in := reflect.ValueOf(src)
out := reflect.ValueOf(dst)
if out.IsNil() {
panic("proto: nil destination")
}
if in.Type() != out.Type() {
// Explicit test prior to mergeStruct so that mistyped nils will fail
panic("proto: type mismatch")
}
if in.IsNil() {
// Merging nil into non-nil is a quiet no-op
return
}
mergeStruct(out.Elem(), in.Elem())
}
func mergeStruct(out, in reflect.Value) {
sprop := GetProperties(in.Type())
for i := 0; i < in.NumField(); i++ {
f := in.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
}
if emIn, ok := in.Addr().Interface().(extendableProto); ok {
emOut := out.Addr().Interface().(extendableProto)
mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
}
uf := in.FieldByName("XXX_unrecognized")
if !uf.IsValid() {
return
}
uin := uf.Bytes()
if len(uin) > 0 {
out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...))
}
}
// mergeAny performs a merge between two values of the same type.
// viaPtr indicates whether the values were indirected through a pointer (implying proto2).
// prop is set if this is a struct field (it may be nil).
func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
if in.Type() == protoMessageType {
if !in.IsNil() {
if out.IsNil() {
out.Set(reflect.ValueOf(Clone(in.Interface().(Message))))
} else {
Merge(out.Interface().(Message), in.Interface().(Message))
}
}
return
}
switch in.Kind() {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
if !viaPtr && isProto3Zero(in) {
return
}
out.Set(in)
case reflect.Interface:
// Probably a oneof field; copy non-nil values.
if in.IsNil() {
return
}
// Allocate destination if it is not set, or set to a different type.
// Otherwise we will merge as normal.
if out.IsNil() || out.Elem().Type() != in.Elem().Type() {
out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
}
mergeAny(out.Elem(), in.Elem(), false, nil)
case reflect.Map:
if in.Len() == 0 {
return
}
if out.IsNil() {
out.Set(reflect.MakeMap(in.Type()))
}
// For maps with value types of *T or []byte we need to deep copy each value.
elemKind := in.Type().Elem().Kind()
for _, key := range in.MapKeys() {
var val reflect.Value
switch elemKind {
case reflect.Ptr:
val = reflect.New(in.Type().Elem().Elem())
mergeAny(val, in.MapIndex(key), false, nil)
case reflect.Slice:
val = in.MapIndex(key)
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
default:
val = in.MapIndex(key)
}
out.SetMapIndex(key, val)
}
case reflect.Ptr:
if in.IsNil() {
return
}
if out.IsNil() {
out.Set(reflect.New(in.Elem().Type()))
}
mergeAny(out.Elem(), in.Elem(), true, nil)
case reflect.Slice:
if in.IsNil() {
return
}
if in.Type().Elem().Kind() == reflect.Uint8 {
// []byte is a scalar bytes field, not a repeated field.
// Edge case: if this is in a proto3 message, a zero length
// bytes field is considered the zero value, and should not
// be merged.
if prop != nil && prop.proto3 && in.Len() == 0 {
return
}
// Make a deep copy.
// Append to []byte{} instead of []byte(nil) so that we never end up
// with a nil result.
out.SetBytes(append([]byte{}, in.Bytes()...))
return
}
n := in.Len()
if out.IsNil() {
out.Set(reflect.MakeSlice(in.Type(), 0, n))
}
switch in.Type().Elem().Kind() {
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
reflect.String, reflect.Uint32, reflect.Uint64:
out.Set(reflect.AppendSlice(out, in))
default:
for i := 0; i < n; i++ {
x := reflect.Indirect(reflect.New(in.Type().Elem()))
mergeAny(x, in.Index(i), false, nil)
out.Set(reflect.Append(out, x))
}
}
case reflect.Struct:
mergeStruct(out, in)
default:
// unknown type, so not a protocol buffer
log.Printf("proto: don't know how to copy %v", in)
}
}
func mergeExtension(out, in map[int32]Extension) {
for extNum, eIn := range in {
eOut := Extension{desc: eIn.desc}
if eIn.value != nil {
v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
mergeAny(v, reflect.ValueOf(eIn.value), false, nil)
eOut.value = v.Interface()
}
if eIn.enc != nil {
eOut.enc = make([]byte, len(eIn.enc))
copy(eOut.enc, eIn.enc)
}
out[extNum] = eOut
}
}

@ -0,0 +1,267 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
"github.com/golang/protobuf/proto"
proto3pb "github.com/golang/protobuf/proto/proto3_proto"
pb "github.com/golang/protobuf/proto/testdata"
)
var cloneTestMessage = &pb.MyMessage{
Count: proto.Int32(42),
Name: proto.String("Dave"),
Pet: []string{"bunny", "kitty", "horsey"},
Inner: &pb.InnerMessage{
Host: proto.String("niles"),
Port: proto.Int32(9099),
Connected: proto.Bool(true),
},
Others: []*pb.OtherMessage{
{
Value: []byte("some bytes"),
},
},
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(6),
},
RepBytes: [][]byte{[]byte("sham"), []byte("wow")},
}
func init() {
ext := &pb.Ext{
Data: proto.String("extension"),
}
if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_More, ext); err != nil {
panic("SetExtension: " + err.Error())
}
}
func TestClone(t *testing.T) {
m := proto.Clone(cloneTestMessage).(*pb.MyMessage)
if !proto.Equal(m, cloneTestMessage) {
t.Errorf("Clone(%v) = %v", cloneTestMessage, m)
}
// Verify it was a deep copy.
*m.Inner.Port++
if proto.Equal(m, cloneTestMessage) {
t.Error("Mutating clone changed the original")
}
// Byte fields and repeated fields should be copied.
if &m.Pet[0] == &cloneTestMessage.Pet[0] {
t.Error("Pet: repeated field not copied")
}
if &m.Others[0] == &cloneTestMessage.Others[0] {
t.Error("Others: repeated field not copied")
}
if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] {
t.Error("Others[0].Value: bytes field not copied")
}
if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] {
t.Error("RepBytes: repeated field not copied")
}
if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] {
t.Error("RepBytes[0]: bytes field not copied")
}
}
func TestCloneNil(t *testing.T) {
var m *pb.MyMessage
if c := proto.Clone(m); !proto.Equal(m, c) {
t.Errorf("Clone(%v) = %v", m, c)
}
}
var mergeTests = []struct {
src, dst, want proto.Message
}{
{
src: &pb.MyMessage{
Count: proto.Int32(42),
},
dst: &pb.MyMessage{
Name: proto.String("Dave"),
},
want: &pb.MyMessage{
Count: proto.Int32(42),
Name: proto.String("Dave"),
},
},
{
src: &pb.MyMessage{
Inner: &pb.InnerMessage{
Host: proto.String("hey"),
Connected: proto.Bool(true),
},
Pet: []string{"horsey"},
Others: []*pb.OtherMessage{
{
Value: []byte("some bytes"),
},
},
},
dst: &pb.MyMessage{
Inner: &pb.InnerMessage{
Host: proto.String("niles"),
Port: proto.Int32(9099),
},
Pet: []string{"bunny", "kitty"},
Others: []*pb.OtherMessage{
{
Key: proto.Int64(31415926535),
},
{
// Explicitly test a src=nil field
Inner: nil,
},
},
},
want: &pb.MyMessage{
Inner: &pb.InnerMessage{
Host: proto.String("hey"),
Connected: proto.Bool(true),
Port: proto.Int32(9099),
},
Pet: []string{"bunny", "kitty", "horsey"},
Others: []*pb.OtherMessage{
{
Key: proto.Int64(31415926535),
},
{},
{
Value: []byte("some bytes"),
},
},
},
},
{
src: &pb.MyMessage{
RepBytes: [][]byte{[]byte("wow")},
},
dst: &pb.MyMessage{
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(6),
},
RepBytes: [][]byte{[]byte("sham")},
},
want: &pb.MyMessage{
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: proto.Int32(6),
},
RepBytes: [][]byte{[]byte("sham"), []byte("wow")},
},
},
// Check that a scalar bytes field replaces rather than appends.
{
src: &pb.OtherMessage{Value: []byte("foo")},
dst: &pb.OtherMessage{Value: []byte("bar")},
want: &pb.OtherMessage{Value: []byte("foo")},
},
{
src: &pb.MessageWithMap{
NameMapping: map[int32]string{6: "Nigel"},
MsgMapping: map[int64]*pb.FloatingPoint{
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
},
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
},
dst: &pb.MessageWithMap{
NameMapping: map[int32]string{
6: "Bruce", // should be overwritten
7: "Andrew",
},
},
want: &pb.MessageWithMap{
NameMapping: map[int32]string{
6: "Nigel",
7: "Andrew",
},
MsgMapping: map[int64]*pb.FloatingPoint{
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
},
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
},
},
// proto3 shouldn't merge zero values,
// in the same way that proto2 shouldn't merge nils.
{
src: &proto3pb.Message{
Name: "Aaron",
Data: []byte(""), // zero value, but not nil
},
dst: &proto3pb.Message{
HeightInCm: 176,
Data: []byte("texas!"),
},
want: &proto3pb.Message{
Name: "Aaron",
HeightInCm: 176,
Data: []byte("texas!"),
},
},
// Oneof fields should merge by assignment.
{
src: &pb.Communique{
Union: &pb.Communique_Number{41},
},
dst: &pb.Communique{
Union: &pb.Communique_Name{"Bobby Tables"},
},
want: &pb.Communique{
Union: &pb.Communique_Number{41},
},
},
// Oneof nil is the same as not set.
{
src: &pb.Communique{},
dst: &pb.Communique{
Union: &pb.Communique_Name{"Bobby Tables"},
},
want: &pb.Communique{
Union: &pb.Communique_Name{"Bobby Tables"},
},
},
}
func TestMerge(t *testing.T) {
for _, m := range mergeTests {
got := proto.Clone(m.dst)
proto.Merge(got, m.src)
if !proto.Equal(got, m.want) {
t.Errorf("Merge(%v, %v)\n got %v\nwant %v\n", m.dst, m.src, got, m.want)
}
}
}

@ -0,0 +1,863 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
/*
* Routines for decoding protocol buffer data to construct in-memory representations.
*/
import (
"errors"
"fmt"
"io"
"os"
"reflect"
)
// errOverflow is returned when an integer is too large to be represented.
var errOverflow = errors.New("proto: integer overflow")
// ErrInternalBadWireType is returned by generated code when an incorrect
// wire type is encountered. It does not get returned to user code.
var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
// The fundamental decoders that interpret bytes on the wire.
// Those that take integer types all return uint64 and are
// therefore of type valueDecoder.
// DecodeVarint reads a varint-encoded integer from the slice.
// It returns the integer and the number of bytes consumed, or
// zero if there is not enough.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
func DecodeVarint(buf []byte) (x uint64, n int) {
// x, n already 0
for shift := uint(0); shift < 64; shift += 7 {
if n >= len(buf) {
return 0, 0
}
b := uint64(buf[n])
n++
x |= (b & 0x7F) << shift
if (b & 0x80) == 0 {
return x, n
}
}
// The number is too large to represent in a 64-bit value.
return 0, 0
}
// DecodeVarint reads a varint-encoded integer from the Buffer.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
// protocol buffer types.
func (p *Buffer) DecodeVarint() (x uint64, err error) {
// x, err already 0
i := p.index
l := len(p.buf)
for shift := uint(0); shift < 64; shift += 7 {
if i >= l {
err = io.ErrUnexpectedEOF
return
}
b := p.buf[i]
i++
x |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
p.index = i
return
}
}
// The number is too large to represent in a 64-bit value.
err = errOverflow
return
}
// DecodeFixed64 reads a 64-bit integer from the Buffer.
// This is the format for the
// fixed64, sfixed64, and double protocol buffer types.
func (p *Buffer) DecodeFixed64() (x uint64, err error) {
// x, err already 0
i := p.index + 8
if i < 0 || i > len(p.buf) {
err = io.ErrUnexpectedEOF
return
}
p.index = i
x = uint64(p.buf[i-8])
x |= uint64(p.buf[i-7]) << 8
x |= uint64(p.buf[i-6]) << 16
x |= uint64(p.buf[i-5]) << 24
x |= uint64(p.buf[i-4]) << 32
x |= uint64(p.buf[i-3]) << 40
x |= uint64(p.buf[i-2]) << 48
x |= uint64(p.buf[i-1]) << 56
return
}
// DecodeFixed32 reads a 32-bit integer from the Buffer.
// This is the format for the
// fixed32, sfixed32, and float protocol buffer types.
func (p *Buffer) DecodeFixed32() (x uint64, err error) {
// x, err already 0
i := p.index + 4
if i < 0 || i > len(p.buf) {
err = io.ErrUnexpectedEOF
return
}
p.index = i
x = uint64(p.buf[i-4])
x |= uint64(p.buf[i-3]) << 8
x |= uint64(p.buf[i-2]) << 16
x |= uint64(p.buf[i-1]) << 24
return
}
// DecodeZigzag64 reads a zigzag-encoded 64-bit integer
// from the Buffer.
// This is the format used for the sint64 protocol buffer type.
func (p *Buffer) DecodeZigzag64() (x uint64, err error) {
x, err = p.DecodeVarint()
if err != nil {
return
}
x = (x >> 1) ^ uint64((int64(x&1)<<63)>>63)
return
}
// DecodeZigzag32 reads a zigzag-encoded 32-bit integer
// from the Buffer.
// This is the format used for the sint32 protocol buffer type.
func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
x, err = p.DecodeVarint()
if err != nil {
return
}
x = uint64((uint32(x) >> 1) ^ uint32((int32(x&1)<<31)>>31))
return
}
// These are not ValueDecoders: they produce an array of bytes or a string.
// bytes, embedded messages
// DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
// This is the format used for the bytes protocol buffer
// type and for embedded messages.
func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
n, err := p.DecodeVarint()
if err != nil {
return nil, err
}
nb := int(n)
if nb < 0 {
return nil, fmt.Errorf("proto: bad byte length %d", nb)
}
end := p.index + nb
if end < p.index || end > len(p.buf) {
return nil, io.ErrUnexpectedEOF
}
if !alloc {
// todo: check if can get more uses of alloc=false
buf = p.buf[p.index:end]
p.index += nb
return
}
buf = make([]byte, nb)
copy(buf, p.buf[p.index:])
p.index += nb
return
}
// DecodeStringBytes reads an encoded string from the Buffer.
// This is the format used for the proto2 string type.
func (p *Buffer) DecodeStringBytes() (s string, err error) {
buf, err := p.DecodeRawBytes(false)
if err != nil {
return
}
return string(buf), nil
}
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
// If the protocol buffer has extensions, and the field matches, add it as an extension.
// Otherwise, if the XXX_unrecognized field exists, append the skipped data there.
func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error {
oi := o.index
err := o.skip(t, tag, wire)
if err != nil {
return err
}
if !unrecField.IsValid() {
return nil
}
ptr := structPointer_Bytes(base, unrecField)
// Add the skipped field to struct field
obuf := o.buf
o.buf = *ptr
o.EncodeVarint(uint64(tag<<3 | wire))
*ptr = append(o.buf, obuf[oi:o.index]...)
o.buf = obuf
return nil
}
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
func (o *Buffer) skip(t reflect.Type, tag, wire int) error {
var u uint64
var err error
switch wire {
case WireVarint:
_, err = o.DecodeVarint()
case WireFixed64:
_, err = o.DecodeFixed64()
case WireBytes:
_, err = o.DecodeRawBytes(false)
case WireFixed32:
_, err = o.DecodeFixed32()
case WireStartGroup:
for {
u, err = o.DecodeVarint()
if err != nil {
break
}
fwire := int(u & 0x7)
if fwire == WireEndGroup {
break
}
ftag := int(u >> 3)
err = o.skip(t, ftag, fwire)
if err != nil {
break
}
}
default:
err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t)
}
return err
}
// Unmarshaler is the interface representing objects that can
// unmarshal themselves. The method should reset the receiver before
// decoding starts. The argument points to data that may be
// overwritten, so implementations should not keep references to the
// buffer.
type Unmarshaler interface {
Unmarshal([]byte) error
}
// Unmarshal parses the protocol buffer representation in buf and places the
// decoded result in pb. If the struct underlying pb does not match
// the data in buf, the results can be unpredictable.
//
// Unmarshal resets pb before starting to unmarshal, so any
// existing data in pb is always removed. Use UnmarshalMerge
// to preserve and append to existing data.
func Unmarshal(buf []byte, pb Message) error {
pb.Reset()
return UnmarshalMerge(buf, pb)
}
// UnmarshalMerge parses the protocol buffer representation in buf and
// writes the decoded result to pb. If the struct underlying pb does not match
// the data in buf, the results can be unpredictable.
//
// UnmarshalMerge merges into existing data in pb.
// Most code should use Unmarshal instead.
func UnmarshalMerge(buf []byte, pb Message) error {
// If the object can unmarshal itself, let it.
if u, ok := pb.(Unmarshaler); ok {
return u.Unmarshal(buf)
}
return NewBuffer(buf).Unmarshal(pb)
}
// DecodeMessage reads a count-delimited message from the Buffer.
func (p *Buffer) DecodeMessage(pb Message) error {
enc, err := p.DecodeRawBytes(false)
if err != nil {
return err
}
return NewBuffer(enc).Unmarshal(pb)
}
// DecodeGroup reads a tag-delimited group from the Buffer.
func (p *Buffer) DecodeGroup(pb Message) error {
typ, base, err := getbase(pb)
if err != nil {
return err
}
return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
}
// Unmarshal parses the protocol buffer representation in the
// Buffer and places the decoded result in pb. If the struct
// underlying pb does not match the data in the buffer, the results can be
// unpredictable.
func (p *Buffer) Unmarshal(pb Message) error {
// If the object can unmarshal itself, let it.
if u, ok := pb.(Unmarshaler); ok {
err := u.Unmarshal(p.buf[p.index:])
p.index = len(p.buf)
return err
}
typ, base, err := getbase(pb)
if err != nil {
return err
}
err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base)
if collectStats {
stats.Decode++
}
return err
}
// unmarshalType does the work of unmarshaling a structure.
func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error {
var state errorState
required, reqFields := prop.reqCount, uint64(0)
var err error
for err == nil && o.index < len(o.buf) {
oi := o.index
var u uint64
u, err = o.DecodeVarint()
if err != nil {
break
}
wire := int(u & 0x7)
if wire == WireEndGroup {
if is_group {
return nil // input is satisfied
}
return fmt.Errorf("proto: %s: wiretype end group for non-group", st)
}
tag := int(u >> 3)
if tag <= 0 {
return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire)
}
fieldnum, ok := prop.decoderTags.get(tag)
if !ok {
// Maybe it's an extension?
if prop.extendable {
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
ext := e.ExtensionMap()[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
e.ExtensionMap()[int32(tag)] = ext
}
continue
}
}
// Maybe it's a oneof?
if prop.oneofUnmarshaler != nil {
m := structPointer_Interface(base, st).(Message)
// First return value indicates whether tag is a oneof field.
ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
if err == ErrInternalBadWireType {
// Map the error to something more descriptive.
// Do the formatting here to save generated code space.
err = fmt.Errorf("bad wiretype for oneof field in %T", m)
}
if ok {
continue
}
}
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
continue
}
p := prop.Prop[fieldnum]
if p.dec == nil {
fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name)
continue
}
dec := p.dec
if wire != WireStartGroup && wire != p.WireType {
if wire == WireBytes && p.packedDec != nil {
// a packable field
dec = p.packedDec
} else {
err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType)
continue
}
}
decErr := dec(o, p, base)
if decErr != nil && !state.shouldContinue(decErr, p) {
err = decErr
}
if err == nil && p.Required {
// Successfully decoded a required field.
if tag <= 64 {
// use bitmap for fields 1-64 to catch field reuse.
var mask uint64 = 1 << uint64(tag-1)
if reqFields&mask == 0 {
// new required field
reqFields |= mask
required--
}
} else {
// This is imprecise. It can be fooled by a required field
// with a tag > 64 that is encoded twice; that's very rare.
// A fully correct implementation would require allocating
// a data structure, which we would like to avoid.
required--
}
}
}
if err == nil {
if is_group {
return io.ErrUnexpectedEOF
}
if state.err != nil {
return state.err
}
if required > 0 {
// Not enough information to determine the exact field. If we use extra
// CPU, we could determine the field only if the missing required field
// has a tag <= 64 and we check reqFields.
return &RequiredNotSetError{"{Unknown}"}
}
}
return err
}
// Individual type decoders
// For each,
// u is the decoded value,
// v is a pointer to the field (pointer) in the struct
// Sizes of the pools to allocate inside the Buffer.
// The goal is modest amortization and allocation
// on at least 16-byte boundaries.
const (
boolPoolSize = 16
uint32PoolSize = 8
uint64PoolSize = 4
)
// Decode a bool.
func (o *Buffer) dec_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
if len(o.bools) == 0 {
o.bools = make([]bool, boolPoolSize)
}
o.bools[0] = u != 0
*structPointer_Bool(base, p.field) = &o.bools[0]
o.bools = o.bools[1:]
return nil
}
func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
*structPointer_BoolVal(base, p.field) = u != 0
return nil
}
// Decode an int32.
func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word32_Set(structPointer_Word32(base, p.field), o, uint32(u))
return nil
}
func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u))
return nil
}
// Decode an int64.
func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word64_Set(structPointer_Word64(base, p.field), o, u)
return nil
}
func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
word64Val_Set(structPointer_Word64Val(base, p.field), o, u)
return nil
}
// Decode a string.
func (o *Buffer) dec_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
*structPointer_String(base, p.field) = &s
return nil
}
func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
*structPointer_StringVal(base, p.field) = s
return nil
}
// Decode a slice of bytes ([]byte).
func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error {
b, err := o.DecodeRawBytes(true)
if err != nil {
return err
}
*structPointer_Bytes(base, p.field) = b
return nil
}
// Decode a slice of bools ([]bool).
func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
v := structPointer_BoolSlice(base, p.field)
*v = append(*v, u != 0)
return nil
}
// Decode a slice of bools ([]bool) in packed format.
func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error {
v := structPointer_BoolSlice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded bools
y := *v
for i := 0; i < nb; i++ {
u, err := p.valDec(o)
if err != nil {
return err
}
y = append(y, u != 0)
}
*v = y
return nil
}
// Decode a slice of int32s ([]int32).
func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
structPointer_Word32Slice(base, p.field).Append(uint32(u))
return nil
}
// Decode a slice of int32s ([]int32) in packed format.
func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error {
v := structPointer_Word32Slice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded int32s
fin := o.index + nb
if fin < o.index {
return errOverflow
}
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
}
v.Append(uint32(u))
}
return nil
}
// Decode a slice of int64s ([]int64).
func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error {
u, err := p.valDec(o)
if err != nil {
return err
}
structPointer_Word64Slice(base, p.field).Append(u)
return nil
}
// Decode a slice of int64s ([]int64) in packed format.
func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error {
v := structPointer_Word64Slice(base, p.field)
nn, err := o.DecodeVarint()
if err != nil {
return err
}
nb := int(nn) // number of bytes of encoded int64s
fin := o.index + nb
if fin < o.index {
return errOverflow
}
for o.index < fin {
u, err := p.valDec(o)
if err != nil {
return err
}
v.Append(u)
}
return nil
}
// Decode a slice of strings ([]string).
func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error {
s, err := o.DecodeStringBytes()
if err != nil {
return err
}
v := structPointer_StringSlice(base, p.field)
*v = append(*v, s)
return nil
}
// Decode a slice of slice of bytes ([][]byte).
func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
b, err := o.DecodeRawBytes(true)
if err != nil {
return err
}
v := structPointer_BytesSlice(base, p.field)
*v = append(*v, b)
return nil
}
// Decode a map field.
func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
raw, err := o.DecodeRawBytes(false)
if err != nil {
return err
}
oi := o.index // index at the end of this map entry
o.index -= len(raw) // move buffer back to start of map entry
mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V
if mptr.Elem().IsNil() {
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
}
v := mptr.Elem() // map[K]V
// Prepare addressable doubly-indirect placeholders for the key and value types.
// See enc_new_map for why.
keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K
keybase := toStructPointer(keyptr.Addr()) // **K
var valbase structPointer
var valptr reflect.Value
switch p.mtype.Elem().Kind() {
case reflect.Slice:
// []byte
var dummy []byte
valptr = reflect.ValueOf(&dummy) // *[]byte
valbase = toStructPointer(valptr) // *[]byte
case reflect.Ptr:
// message; valptr is **Msg; need to allocate the intermediate pointer
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valptr.Set(reflect.New(valptr.Type().Elem()))
valbase = toStructPointer(valptr)
default:
// everything else
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
valbase = toStructPointer(valptr.Addr()) // **V
}
// Decode.
// This parses a restricted wire format, namely the encoding of a message
// with two fields. See enc_new_map for the format.
for o.index < oi {
// tagcode for key and value properties are always a single byte
// because they have tags 1 and 2.
tagcode := o.buf[o.index]
o.index++
switch tagcode {
case p.mkeyprop.tagcode[0]:
if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil {
return err
}
case p.mvalprop.tagcode[0]:
if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil {
return err
}
default:
// TODO: Should we silently skip this instead?
return fmt.Errorf("proto: bad map data tag %d", raw[0])
}
}
keyelem, valelem := keyptr.Elem(), valptr.Elem()
if !keyelem.IsValid() || !valelem.IsValid() {
// We did not decode the key or the value in the map entry.
// Either way, it's an invalid map entry.
return fmt.Errorf("proto: bad map data: missing key/val")
}
v.SetMapIndex(keyelem, valelem)
return nil
}
// Decode a group.
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
bas := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(bas) {
// allocate new nested message
bas = toStructPointer(reflect.New(p.stype))
structPointer_SetStructPointer(base, p.field, bas)
}
return o.unmarshalType(p.stype, p.sprop, true, bas)
}
// Decode an embedded message.
func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) {
raw, e := o.DecodeRawBytes(false)
if e != nil {
return e
}
bas := structPointer_GetStructPointer(base, p.field)
if structPointer_IsNil(bas) {
// allocate new nested message
bas = toStructPointer(reflect.New(p.stype))
structPointer_SetStructPointer(base, p.field, bas)
}
// If the object can unmarshal itself, let it.
if p.isUnmarshaler {
iv := structPointer_Interface(bas, p.stype)
return iv.(Unmarshaler).Unmarshal(raw)
}
obuf := o.buf
oi := o.index
o.buf = raw
o.index = 0
err = o.unmarshalType(p.stype, p.sprop, false, bas)
o.buf = obuf
o.index = oi
return err
}
// Decode a slice of embedded messages.
func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error {
return o.dec_slice_struct(p, false, base)
}
// Decode a slice of embedded groups.
func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error {
return o.dec_slice_struct(p, true, base)
}
// Decode a slice of structs ([]*struct).
func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error {
v := reflect.New(p.stype)
bas := toStructPointer(v)
structPointer_StructPointerSlice(base, p.field).Append(bas)
if is_group {
err := o.unmarshalType(p.stype, p.sprop, is_group, bas)
return err
}
raw, err := o.DecodeRawBytes(false)
if err != nil {
return err
}
// If the object can unmarshal itself, let it.
if p.isUnmarshaler {
iv := v.Interface()
return iv.(Unmarshaler).Unmarshal(raw)
}
obuf := o.buf
oi := o.index
o.buf = raw
o.index = 0
err = o.unmarshalType(p.stype, p.sprop, is_group, bas)
o.buf = obuf
o.index = oi
return err
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,267 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Protocol buffer comparison.
// TODO: MessageSet.
package proto
import (
"bytes"
"log"
"reflect"
"strings"
)
/*
Equal returns true iff protocol buffers a and b are equal.
The arguments must both be pointers to protocol buffer structs.
Equality is defined in this way:
- Two messages are equal iff they are the same type,
corresponding fields are equal, unknown field sets
are equal, and extensions sets are equal.
- Two set scalar fields are equal iff their values are equal.
If the fields are of a floating-point type, remember that
NaN != x for all x, including NaN.
- Two repeated fields are equal iff their lengths are the same,
and their corresponding elements are equal (a "bytes" field,
although represented by []byte, is not a repeated field)
- Two unset fields are equal.
- Two unknown field sets are equal if their current
encoded state is equal.
- Two extension sets are equal iff they have corresponding
elements that are pairwise equal.
- Every other combination of things are not equal.
The return value is undefined if a and b are not protocol buffers.
*/
func Equal(a, b Message) bool {
if a == nil || b == nil {
return a == b
}
v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b)
if v1.Type() != v2.Type() {
return false
}
if v1.Kind() == reflect.Ptr {
if v1.IsNil() {
return v2.IsNil()
}
if v2.IsNil() {
return false
}
v1, v2 = v1.Elem(), v2.Elem()
}
if v1.Kind() != reflect.Struct {
return false
}
return equalStruct(v1, v2)
}
// v1 and v2 are known to have the same type.
func equalStruct(v1, v2 reflect.Value) bool {
for i := 0; i < v1.NumField(); i++ {
f := v1.Type().Field(i)
if strings.HasPrefix(f.Name, "XXX_") {
continue
}
f1, f2 := v1.Field(i), v2.Field(i)
if f.Type.Kind() == reflect.Ptr {
if n1, n2 := f1.IsNil(), f2.IsNil(); n1 && n2 {
// both unset
continue
} else if n1 != n2 {
// set/unset mismatch
return false
}
b1, ok := f1.Interface().(raw)
if ok {
b2 := f2.Interface().(raw)
// RawMessage
if !bytes.Equal(b1.Bytes(), b2.Bytes()) {
return false
}
continue
}
f1, f2 = f1.Elem(), f2.Elem()
}
if !equalAny(f1, f2) {
return false
}
}
if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
em2 := v2.FieldByName("XXX_extensions")
if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
return false
}
}
uf := v1.FieldByName("XXX_unrecognized")
if !uf.IsValid() {
return true
}
u1 := uf.Bytes()
u2 := v2.FieldByName("XXX_unrecognized").Bytes()
if !bytes.Equal(u1, u2) {
return false
}
return true
}
// v1 and v2 are known to have the same type.
func equalAny(v1, v2 reflect.Value) bool {
if v1.Type() == protoMessageType {
m1, _ := v1.Interface().(Message)
m2, _ := v2.Interface().(Message)
return Equal(m1, m2)
}
switch v1.Kind() {
case reflect.Bool:
return v1.Bool() == v2.Bool()
case reflect.Float32, reflect.Float64:
return v1.Float() == v2.Float()
case reflect.Int32, reflect.Int64:
return v1.Int() == v2.Int()
case reflect.Interface:
// Probably a oneof field; compare the inner values.
n1, n2 := v1.IsNil(), v2.IsNil()
if n1 || n2 {
return n1 == n2
}
e1, e2 := v1.Elem(), v2.Elem()
if e1.Type() != e2.Type() {
return false
}
return equalAny(e1, e2)
case reflect.Map:
if v1.Len() != v2.Len() {
return false
}
for _, key := range v1.MapKeys() {
val2 := v2.MapIndex(key)
if !val2.IsValid() {
// This key was not found in the second map.
return false
}
if !equalAny(v1.MapIndex(key), val2) {
return false
}
}
return true
case reflect.Ptr:
return equalAny(v1.Elem(), v2.Elem())
case reflect.Slice:
if v1.Type().Elem().Kind() == reflect.Uint8 {
// short circuit: []byte
if v1.IsNil() != v2.IsNil() {
return false
}
return bytes.Equal(v1.Interface().([]byte), v2.Interface().([]byte))
}
if v1.Len() != v2.Len() {
return false
}
for i := 0; i < v1.Len(); i++ {
if !equalAny(v1.Index(i), v2.Index(i)) {
return false
}
}
return true
case reflect.String:
return v1.Interface().(string) == v2.Interface().(string)
case reflect.Struct:
return equalStruct(v1, v2)
case reflect.Uint32, reflect.Uint64:
return v1.Uint() == v2.Uint()
}
// unknown type, so not a protocol buffer
log.Printf("proto: don't know how to compare %v", v1)
return false
}
// base is the struct type that the extensions are based on.
// em1 and em2 are extension maps.
func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
if len(em1) != len(em2) {
return false
}
for extNum, e1 := range em1 {
e2, ok := em2[extNum]
if !ok {
return false
}
m1, m2 := e1.value, e2.value
if m1 != nil && m2 != nil {
// Both are unencoded.
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
return false
}
continue
}
// At least one is encoded. To do a semantically correct comparison
// we need to unmarshal them first.
var desc *ExtensionDesc
if m := extensionMaps[base]; m != nil {
desc = m[extNum]
}
if desc == nil {
log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
continue
}
var err error
if m1 == nil {
m1, err = decodeExtension(e1.enc, desc)
}
if m2 == nil && err == nil {
m2, err = decodeExtension(e2.enc, desc)
}
if err != nil {
// The encoded form is invalid.
log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
return false
}
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
return false
}
}
return true
}

@ -0,0 +1,209 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2011 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"testing"
. "github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/proto/testdata"
)
// Four identical base messages.
// The init function adds extensions to some of them.
var messageWithoutExtension = &pb.MyMessage{Count: Int32(7)}
var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)}
var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)}
var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)}
// Two messages with non-message extensions.
var messageWithInt32Extension1 = &pb.MyMessage{Count: Int32(8)}
var messageWithInt32Extension2 = &pb.MyMessage{Count: Int32(8)}
func init() {
ext1 := &pb.Ext{Data: String("Kirk")}
ext2 := &pb.Ext{Data: String("Picard")}
// messageWithExtension1a has ext1, but never marshals it.
if err := SetExtension(messageWithExtension1a, pb.E_Ext_More, ext1); err != nil {
panic("SetExtension on 1a failed: " + err.Error())
}
// messageWithExtension1b is the unmarshaled form of messageWithExtension1a.
if err := SetExtension(messageWithExtension1b, pb.E_Ext_More, ext1); err != nil {
panic("SetExtension on 1b failed: " + err.Error())
}
buf, err := Marshal(messageWithExtension1b)
if err != nil {
panic("Marshal of 1b failed: " + err.Error())
}
messageWithExtension1b.Reset()
if err := Unmarshal(buf, messageWithExtension1b); err != nil {
panic("Unmarshal of 1b failed: " + err.Error())
}
// messageWithExtension2 has ext2.
if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil {
panic("SetExtension on 2 failed: " + err.Error())
}
if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(23)); err != nil {
panic("SetExtension on Int32-1 failed: " + err.Error())
}
if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(24)); err != nil {
panic("SetExtension on Int32-2 failed: " + err.Error())
}
}
var EqualTests = []struct {
desc string
a, b Message
exp bool
}{
{"different types", &pb.GoEnum{}, &pb.GoTestField{}, false},
{"equal empty", &pb.GoEnum{}, &pb.GoEnum{}, true},
{"nil vs nil", nil, nil, true},
{"typed nil vs typed nil", (*pb.GoEnum)(nil), (*pb.GoEnum)(nil), true},
{"typed nil vs empty", (*pb.GoEnum)(nil), &pb.GoEnum{}, false},
{"different typed nil", (*pb.GoEnum)(nil), (*pb.GoTestField)(nil), false},
{"one set field, one unset field", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{}, false},
{"one set field zero, one unset field", &pb.GoTest{Param: Int32(0)}, &pb.GoTest{}, false},
{"different set fields", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("bar")}, false},
{"equal set", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("foo")}, true},
{"repeated, one set", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{}, false},
{"repeated, different length", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{F_Int32Repeated: []int32{2}}, false},
{"repeated, different value", &pb.GoTest{F_Int32Repeated: []int32{2}}, &pb.GoTest{F_Int32Repeated: []int32{3}}, false},
{"repeated, equal", &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, true},
{"repeated, nil equal nil", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: nil}, true},
{"repeated, nil equal empty", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: []int32{}}, true},
{"repeated, empty equal nil", &pb.GoTest{F_Int32Repeated: []int32{}}, &pb.GoTest{F_Int32Repeated: nil}, true},
{
"nested, different",
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("foo")}},
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("bar")}},
false,
},
{
"nested, equal",
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
true,
},
{"bytes", &pb.OtherMessage{Value: []byte("foo")}, &pb.OtherMessage{Value: []byte("foo")}, true},
{"bytes, empty", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: []byte{}}, true},
{"bytes, empty vs nil", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: nil}, false},
{
"repeated bytes",
&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
true,
},
{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
{"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false},
{"int32 extension vs. itself", messageWithInt32Extension1, messageWithInt32Extension1, true},
{"int32 extension vs. a different int32", messageWithInt32Extension1, messageWithInt32Extension2, false},
{
"message with group",
&pb.MyMessage{
Count: Int32(1),
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: Int32(5),
},
},
&pb.MyMessage{
Count: Int32(1),
Somegroup: &pb.MyMessage_SomeGroup{
GroupField: Int32(5),
},
},
true,
},
{
"map same",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
true,
},
{
"map different entry",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Rob"}},
false,
},
{
"map different key only",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Ken"}},
false,
},
{
"map different value only",
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
false,
},
{
"oneof same",
&pb.Communique{Union: &pb.Communique_Number{41}},
&pb.Communique{Union: &pb.Communique_Number{41}},
true,
},
{
"oneof one nil",
&pb.Communique{Union: &pb.Communique_Number{41}},
&pb.Communique{},
false,
},
{
"oneof different",
&pb.Communique{Union: &pb.Communique_Number{41}},
&pb.Communique{Union: &pb.Communique_Name{"Bobby Tables"}},
false,
},
}
func TestEqual(t *testing.T) {
for _, tc := range EqualTests {
if res := Equal(tc.a, tc.b); res != tc.exp {
t.Errorf("%v: Equal(%v, %v) = %v, want %v", tc.desc, tc.a, tc.b, res, tc.exp)
}
}
}

@ -0,0 +1,400 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
/*
* Types and routines for supporting protocol buffer extensions.
*/
import (
"errors"
"fmt"
"reflect"
"strconv"
"sync"
)
// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
var ErrMissingExtension = errors.New("proto: missing extension")
// ExtensionRange represents a range of message extensions for a protocol buffer.
// Used in code generated by the protocol compiler.
type ExtensionRange struct {
Start, End int32 // both inclusive
}
// extendableProto is an interface implemented by any protocol buffer that may be extended.
type extendableProto interface {
Message
ExtensionRangeArray() []ExtensionRange
ExtensionMap() map[int32]Extension
}
var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
type ExtensionDesc struct {
ExtendedType Message // nil pointer to the type that is being extended
ExtensionType interface{} // nil pointer to the extension type
Field int32 // field number
Name string // fully-qualified name of extension, for text formatting
Tag string // protobuf tag style
}
func (ed *ExtensionDesc) repeated() bool {
t := reflect.TypeOf(ed.ExtensionType)
return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
}
// Extension represents an extension in a message.
type Extension struct {
// When an extension is stored in a message using SetExtension
// only desc and value are set. When the message is marshaled
// enc will be set to the encoded form of the message.
//
// When a message is unmarshaled and contains extensions, each
// extension will have only enc set. When such an extension is
// accessed using GetExtension (or GetExtensions) desc and value
// will be set.
desc *ExtensionDesc
value interface{}
enc []byte
}
// SetRawExtension is for testing only.
func SetRawExtension(base extendableProto, id int32, b []byte) {
base.ExtensionMap()[id] = Extension{enc: b}
}
// isExtensionField returns true iff the given field number is in an extension range.
func isExtensionField(pb extendableProto, field int32) bool {
for _, er := range pb.ExtensionRangeArray() {
if er.Start <= field && field <= er.End {
return true
}
}
return false
}
// checkExtensionTypes checks that the given extension is valid for pb.
func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
// Check the extended type.
if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
}
// Check the range.
if !isExtensionField(pb, extension.Field) {
return errors.New("proto: bad extension number; not in declared ranges")
}
return nil
}
// extPropKey is sufficient to uniquely identify an extension.
type extPropKey struct {
base reflect.Type
field int32
}
var extProp = struct {
sync.RWMutex
m map[extPropKey]*Properties
}{
m: make(map[extPropKey]*Properties),
}
func extensionProperties(ed *ExtensionDesc) *Properties {
key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
extProp.RLock()
if prop, ok := extProp.m[key]; ok {
extProp.RUnlock()
return prop
}
extProp.RUnlock()
extProp.Lock()
defer extProp.Unlock()
// Check again.
if prop, ok := extProp.m[key]; ok {
return prop
}
prop := new(Properties)
prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
extProp.m[key] = prop
return prop
}
// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
func encodeExtensionMap(m map[int32]Extension) error {
for k, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
continue
}
// We don't skip extensions that have an encoded form set,
// because the extension value may have been mutated after
// the last time this function was called.
et := reflect.TypeOf(e.desc.ExtensionType)
props := extensionProperties(e.desc)
p := NewBuffer(nil)
// If e.value has type T, the encoder expects a *struct{ X T }.
// Pass a *T with a zero field and hope it all works out.
x := reflect.New(et)
x.Elem().Set(reflect.ValueOf(e.value))
if err := props.enc(p, props, toStructPointer(x)); err != nil {
return err
}
e.enc = p.buf
m[k] = e
}
return nil
}
func sizeExtensionMap(m map[int32]Extension) (n int) {
for _, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
n += len(e.enc)
continue
}
// We don't skip extensions that have an encoded form set,
// because the extension value may have been mutated after
// the last time this function was called.
et := reflect.TypeOf(e.desc.ExtensionType)
props := extensionProperties(e.desc)
// If e.value has type T, the encoder expects a *struct{ X T }.
// Pass a *T with a zero field and hope it all works out.
x := reflect.New(et)
x.Elem().Set(reflect.ValueOf(e.value))
n += props.size(props, toStructPointer(x))
}
return
}
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
_, ok := pb.ExtensionMap()[extension.Field]
return ok
}
// ClearExtension removes the given extension from pb.
func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
// TODO: Check types, field numbers, etc.?
delete(pb.ExtensionMap(), extension.Field)
}
// GetExtension parses and returns the given extension of pb.
// If the extension is not present and has no default value it returns ErrMissingExtension.
func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) {
if err := checkExtensionTypes(pb, extension); err != nil {
return nil, err
}
emap := pb.ExtensionMap()
e, ok := emap[extension.Field]
if !ok {
// defaultExtensionValue returns the default value or
// ErrMissingExtension if there is no default.
return defaultExtensionValue(extension)
}
if e.value != nil {
// Already decoded. Check the descriptor, though.
if e.desc != extension {
// This shouldn't happen. If it does, it means that
// GetExtension was called twice with two different
// descriptors with the same field number.
return nil, errors.New("proto: descriptor conflict")
}
return e.value, nil
}
v, err := decodeExtension(e.enc, extension)
if err != nil {
return nil, err
}
// Remember the decoded version and drop the encoded version.
// That way it is safe to mutate what we return.
e.value = v
e.desc = extension
e.enc = nil
emap[extension.Field] = e
return e.value, nil
}
// defaultExtensionValue returns the default value for extension.
// If no default for an extension is defined ErrMissingExtension is returned.
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
t := reflect.TypeOf(extension.ExtensionType)
props := extensionProperties(extension)
sf, _, err := fieldDefault(t, props)
if err != nil {
return nil, err
}
if sf == nil || sf.value == nil {
// There is no default value.
return nil, ErrMissingExtension
}
if t.Kind() != reflect.Ptr {
// We do not need to return a Ptr, we can directly return sf.value.
return sf.value, nil
}
// We need to return an interface{} that is a pointer to sf.value.
value := reflect.New(t).Elem()
value.Set(reflect.New(value.Type().Elem()))
if sf.kind == reflect.Int32 {
// We may have an int32 or an enum, but the underlying data is int32.
// Since we can't set an int32 into a non int32 reflect.value directly
// set it as a int32.
value.Elem().SetInt(int64(sf.value.(int32)))
} else {
value.Elem().Set(reflect.ValueOf(sf.value))
}
return value.Interface(), nil
}
// decodeExtension decodes an extension encoded in b.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
o := NewBuffer(b)
t := reflect.TypeOf(extension.ExtensionType)
rep := extension.repeated()
props := extensionProperties(extension)
// t is a pointer to a struct, pointer to basic type or a slice.
// Allocate a "field" to store the pointer/slice itself; the
// pointer/slice will be stored here. We pass
// the address of this field to props.dec.
// This passes a zero field and a *t and lets props.dec
// interpret it as a *struct{ x t }.
value := reflect.New(t).Elem()
for {
// Discard wire type and field number varint. It isn't needed.
if _, err := o.DecodeVarint(); err != nil {
return nil, err
}
if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
return nil, err
}
if !rep || o.index >= len(o.buf) {
break
}
}
return value.Interface(), nil
}
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es; missing extensions will appear as nil elements.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
epb, ok := pb.(extendableProto)
if !ok {
err = errors.New("proto: not an extendable proto")
return
}
extensions = make([]interface{}, len(es))
for i, e := range es {
extensions[i], err = GetExtension(epb, e)
if err == ErrMissingExtension {
err = nil
}
if err != nil {
return
}
}
return
}
// SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error {
if err := checkExtensionTypes(pb, extension); err != nil {
return err
}
typ := reflect.TypeOf(extension.ExtensionType)
if typ != reflect.TypeOf(value) {
return errors.New("proto: bad extension value type")
}
// nil extension values need to be caught early, because the
// encoder can't distinguish an ErrNil due to a nil extension
// from an ErrNil due to a missing field. Extensions are
// always optional, so the encoder would just swallow the error
// and drop all the extensions from the encoded message.
if reflect.ValueOf(value).IsNil() {
return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
}
pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
return nil
}
// A global registry of extensions.
// The generated code will register the generated descriptors by calling RegisterExtension.
var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
// RegisterExtension is called from the generated code.
func RegisterExtension(desc *ExtensionDesc) {
st := reflect.TypeOf(desc.ExtendedType).Elem()
m := extensionMaps[st]
if m == nil {
m = make(map[int32]*ExtensionDesc)
extensionMaps[st] = m
}
if _, ok := m[desc.Field]; ok {
panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
}
m[desc.Field] = desc
}
// RegisteredExtensions returns a map of the registered extensions of a
// protocol buffer struct, indexed by the extension number.
// The argument pb should be a nil pointer to the struct type.
func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
return extensionMaps[reflect.TypeOf(pb).Elem()]
}

@ -0,0 +1,292 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"fmt"
"reflect"
"testing"
"github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/proto/testdata"
)
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Fatalf("Could not set ext1: %s", ext1)
}
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
pb.E_Ext_More,
pb.E_Ext_Text,
})
if err != nil {
t.Fatalf("GetExtensions() failed: %s", err)
}
if exts[0] != ext1 {
t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
}
if exts[1] != nil {
t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
}
}
func TestGetExtensionStability(t *testing.T) {
check := func(m *pb.MyMessage) bool {
ext1, err := proto.GetExtension(m, pb.E_Ext_More)
if err != nil {
t.Fatalf("GetExtension() failed: %s", err)
}
ext2, err := proto.GetExtension(m, pb.E_Ext_More)
if err != nil {
t.Fatalf("GetExtension() failed: %s", err)
}
return ext1 == ext2
}
msg := &pb.MyMessage{Count: proto.Int32(4)}
ext0 := &pb.Ext{}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
t.Fatalf("Could not set ext1: %s", ext0)
}
if !check(msg) {
t.Errorf("GetExtension() not stable before marshaling")
}
bb, err := proto.Marshal(msg)
if err != nil {
t.Fatalf("Marshal() failed: %s", err)
}
msg1 := &pb.MyMessage{}
err = proto.Unmarshal(bb, msg1)
if err != nil {
t.Fatalf("Unmarshal() failed: %s", err)
}
if !check(msg1) {
t.Errorf("GetExtension() not stable after unmarshaling")
}
}
func TestGetExtensionDefaults(t *testing.T) {
var setFloat64 float64 = 1
var setFloat32 float32 = 2
var setInt32 int32 = 3
var setInt64 int64 = 4
var setUint32 uint32 = 5
var setUint64 uint64 = 6
var setBool = true
var setBool2 = false
var setString = "Goodnight string"
var setBytes = []byte("Goodnight bytes")
var setEnum = pb.DefaultsMessage_TWO
type testcase struct {
ext *proto.ExtensionDesc // Extension we are testing.
want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
def interface{} // Expected value of extension after ClearExtension().
}
tests := []testcase{
{pb.E_NoDefaultDouble, setFloat64, nil},
{pb.E_NoDefaultFloat, setFloat32, nil},
{pb.E_NoDefaultInt32, setInt32, nil},
{pb.E_NoDefaultInt64, setInt64, nil},
{pb.E_NoDefaultUint32, setUint32, nil},
{pb.E_NoDefaultUint64, setUint64, nil},
{pb.E_NoDefaultSint32, setInt32, nil},
{pb.E_NoDefaultSint64, setInt64, nil},
{pb.E_NoDefaultFixed32, setUint32, nil},
{pb.E_NoDefaultFixed64, setUint64, nil},
{pb.E_NoDefaultSfixed32, setInt32, nil},
{pb.E_NoDefaultSfixed64, setInt64, nil},
{pb.E_NoDefaultBool, setBool, nil},
{pb.E_NoDefaultBool, setBool2, nil},
{pb.E_NoDefaultString, setString, nil},
{pb.E_NoDefaultBytes, setBytes, nil},
{pb.E_NoDefaultEnum, setEnum, nil},
{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
{pb.E_DefaultFloat, setFloat32, float32(3.14)},
{pb.E_DefaultInt32, setInt32, int32(42)},
{pb.E_DefaultInt64, setInt64, int64(43)},
{pb.E_DefaultUint32, setUint32, uint32(44)},
{pb.E_DefaultUint64, setUint64, uint64(45)},
{pb.E_DefaultSint32, setInt32, int32(46)},
{pb.E_DefaultSint64, setInt64, int64(47)},
{pb.E_DefaultFixed32, setUint32, uint32(48)},
{pb.E_DefaultFixed64, setUint64, uint64(49)},
{pb.E_DefaultSfixed32, setInt32, int32(50)},
{pb.E_DefaultSfixed64, setInt64, int64(51)},
{pb.E_DefaultBool, setBool, true},
{pb.E_DefaultBool, setBool2, true},
{pb.E_DefaultString, setString, "Hello, string"},
{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
}
checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
val, err := proto.GetExtension(msg, test.ext)
if err != nil {
if valWant != nil {
return fmt.Errorf("GetExtension(): %s", err)
}
if want := proto.ErrMissingExtension; err != want {
return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
}
return nil
}
// All proto2 extension values are either a pointer to a value or a slice of values.
ty := reflect.TypeOf(val)
tyWant := reflect.TypeOf(test.ext.ExtensionType)
if got, want := ty, tyWant; got != want {
return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
}
tye := ty.Elem()
tyeWant := tyWant.Elem()
if got, want := tye, tyeWant; got != want {
return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
}
// Check the name of the type of the value.
// If it is an enum it will be type int32 with the name of the enum.
if got, want := tye.Name(), tye.Name(); got != want {
return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
}
// Check that value is what we expect.
// If we have a pointer in val, get the value it points to.
valExp := val
if ty.Kind() == reflect.Ptr {
valExp = reflect.ValueOf(val).Elem().Interface()
}
if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
}
return nil
}
setTo := func(test testcase) interface{} {
setTo := reflect.ValueOf(test.want)
if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
setTo = reflect.New(typ).Elem()
setTo.Set(reflect.New(setTo.Type().Elem()))
setTo.Elem().Set(reflect.ValueOf(test.want))
}
return setTo.Interface()
}
for _, test := range tests {
msg := &pb.DefaultsMessage{}
name := test.ext.Name
// Check the initial value.
if err := checkVal(test, msg, test.def); err != nil {
t.Errorf("%s: %v", name, err)
}
// Set the per-type value and check value.
name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
t.Errorf("%s: SetExtension(): %v", name, err)
continue
}
if err := checkVal(test, msg, test.want); err != nil {
t.Errorf("%s: %v", name, err)
continue
}
// Set and check the value.
name += " (cleared)"
proto.ClearExtension(msg, test.ext)
if err := checkVal(test, msg, test.def); err != nil {
t.Errorf("%s: %v", name, err)
}
}
}
func TestExtensionsRoundTrip(t *testing.T) {
msg := &pb.MyMessage{}
ext1 := &pb.Ext{
Data: proto.String("hi"),
}
ext2 := &pb.Ext{
Data: proto.String("there"),
}
exists := proto.HasExtension(msg, pb.E_Ext_More)
if exists {
t.Error("Extension More present unexpectedly")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
t.Error(err)
}
if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
t.Error(err)
}
e, err := proto.GetExtension(msg, pb.E_Ext_More)
if err != nil {
t.Error(err)
}
x, ok := e.(*pb.Ext)
if !ok {
t.Errorf("e has type %T, expected testdata.Ext", e)
} else if *x.Data != "there" {
t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
}
proto.ClearExtension(msg, pb.E_Ext_More)
if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
t.Errorf("got %v, expected ErrMissingExtension", e)
}
if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
t.Error("expected bad extension error, got nil")
}
if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
t.Error("expected extension err")
}
if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
t.Error("expected some sort of type mismatch error, got nil")
}
}
func TestNilExtension(t *testing.T) {
msg := &pb.MyMessage{
Count: proto.Int32(1),
}
if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
t.Fatal(err)
}
if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
t.Error("expected SetExtension to fail due to a nil extension")
} else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
t.Errorf("expected error %v, got %v", want, err)
}
// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
}

@ -0,0 +1,883 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/*
Package proto converts data structures to and from the wire format of
protocol buffers. It works in concert with the Go source code generated
for .proto files by the protocol compiler.
A summary of the properties of the protocol buffer interface
for a protocol buffer variable v:
- Names are turned from camel_case to CamelCase for export.
- There are no methods on v to set fields; just treat
them as structure fields.
- There are getters that return a field's value if set,
and return the field's default value if unset.
The getters work even if the receiver is a nil message.
- The zero value for a struct is its correct initialization state.
All desired fields must be set before marshaling.
- A Reset() method will restore a protobuf struct to its zero state.
- Non-repeated fields are pointers to the values; nil means unset.
That is, optional or required field int32 f becomes F *int32.
- Repeated fields are slices.
- Helper functions are available to aid the setting of fields.
msg.Foo = proto.String("hello") // set field
- Constants are defined to hold the default values of all fields that
have them. They have the form Default_StructName_FieldName.
Because the getter methods handle defaulted values,
direct use of these constants should be rare.
- Enums are given type names and maps from names to values.
Enum values are prefixed by the enclosing message's name, or by the
enum's type name if it is a top-level enum. Enum types have a String
method, and a Enum method to assist in message construction.
- Nested messages, groups and enums have type names prefixed with the name of
the surrounding message type.
- Extensions are given descriptor names that start with E_,
followed by an underscore-delimited list of the nested messages
that contain it (if any) followed by the CamelCased name of the
extension field itself. HasExtension, ClearExtension, GetExtension
and SetExtension are functions for manipulating extensions.
- Oneof field sets are given a single field in their message,
with distinguished wrapper types for each possible field value.
- Marshal and Unmarshal are functions to encode and decode the wire format.
The simplest way to describe this is to see an example.
Given file test.proto, containing
package example;
enum FOO { X = 17; }
message Test {
required string label = 1;
optional int32 type = 2 [default=77];
repeated int64 reps = 3;
optional group OptionalGroup = 4 {
required string RequiredField = 5;
}
oneof union {
int32 number = 6;
string name = 7;
}
}
The resulting file, test.pb.go, is:
package example
import proto "github.com/golang/protobuf/proto"
import math "math"
type FOO int32
const (
FOO_X FOO = 17
)
var FOO_name = map[int32]string{
17: "X",
}
var FOO_value = map[string]int32{
"X": 17,
}
func (x FOO) Enum() *FOO {
p := new(FOO)
*p = x
return p
}
func (x FOO) String() string {
return proto.EnumName(FOO_name, int32(x))
}
func (x *FOO) UnmarshalJSON(data []byte) error {
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
if err != nil {
return err
}
*x = FOO(value)
return nil
}
type Test struct {
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
// Types that are valid to be assigned to Union:
// *Test_Number
// *Test_Name
Union isTest_Union `protobuf_oneof:"union"`
XXX_unrecognized []byte `json:"-"`
}
func (m *Test) Reset() { *m = Test{} }
func (m *Test) String() string { return proto.CompactTextString(m) }
func (*Test) ProtoMessage() {}
type isTest_Union interface {
isTest_Union()
}
type Test_Number struct {
Number int32 `protobuf:"varint,6,opt,name=number"`
}
type Test_Name struct {
Name string `protobuf:"bytes,7,opt,name=name"`
}
func (*Test_Number) isTest_Union() {}
func (*Test_Name) isTest_Union() {}
func (m *Test) GetUnion() isTest_Union {
if m != nil {
return m.Union
}
return nil
}
const Default_Test_Type int32 = 77
func (m *Test) GetLabel() string {
if m != nil && m.Label != nil {
return *m.Label
}
return ""
}
func (m *Test) GetType() int32 {
if m != nil && m.Type != nil {
return *m.Type
}
return Default_Test_Type
}
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
if m != nil {
return m.Optionalgroup
}
return nil
}
type Test_OptionalGroup struct {
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
}
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
func (m *Test_OptionalGroup) GetRequiredField() string {
if m != nil && m.RequiredField != nil {
return *m.RequiredField
}
return ""
}
func (m *Test) GetNumber() int32 {
if x, ok := m.GetUnion().(*Test_Number); ok {
return x.Number
}
return 0
}
func (m *Test) GetName() string {
if x, ok := m.GetUnion().(*Test_Name); ok {
return x.Name
}
return ""
}
func init() {
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
}
To create and play with a Test object:
package main
import (
"log"
"github.com/golang/protobuf/proto"
pb "./example.pb"
)
func main() {
test := &pb.Test{
Label: proto.String("hello"),
Type: proto.Int32(17),
Optionalgroup: &pb.Test_OptionalGroup{
RequiredField: proto.String("good bye"),
},
Union: &pb.Test_Name{"fred"},
}
data, err := proto.Marshal(test)
if err != nil {
log.Fatal("marshaling error: ", err)
}
newTest := &pb.Test{}
err = proto.Unmarshal(data, newTest)
if err != nil {
log.Fatal("unmarshaling error: ", err)
}
// Now test and newTest contain the same data.
if test.GetLabel() != newTest.GetLabel() {
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
}
// Use a type switch to determine which oneof was set.
switch u := test.Union.(type) {
case *pb.Test_Number: // u.Number contains the number.
case *pb.Test_Name: // u.Name contains the string.
}
// etc.
}
*/
package proto
import (
"encoding/json"
"fmt"
"log"
"reflect"
"sort"
"strconv"
"sync"
)
// Message is implemented by generated protocol buffer messages.
type Message interface {
Reset()
String() string
ProtoMessage()
}
// Stats records allocation details about the protocol buffer encoders
// and decoders. Useful for tuning the library itself.
type Stats struct {
Emalloc uint64 // mallocs in encode
Dmalloc uint64 // mallocs in decode
Encode uint64 // number of encodes
Decode uint64 // number of decodes
Chit uint64 // number of cache hits
Cmiss uint64 // number of cache misses
Size uint64 // number of sizes
}
// Set to true to enable stats collection.
const collectStats = false
var stats Stats
// GetStats returns a copy of the global Stats structure.
func GetStats() Stats { return stats }
// A Buffer is a buffer manager for marshaling and unmarshaling
// protocol buffers. It may be reused between invocations to
// reduce memory usage. It is not necessary to use a Buffer;
// the global functions Marshal and Unmarshal create a
// temporary Buffer and are fine for most applications.
type Buffer struct {
buf []byte // encode/decode byte stream
index int // write point
// pools of basic types to amortize allocation.
bools []bool
uint32s []uint32
uint64s []uint64
// extra pools, only used with pointer_reflect.go
int32s []int32
int64s []int64
float32s []float32
float64s []float64
}
// NewBuffer allocates a new Buffer and initializes its internal data to
// the contents of the argument slice.
func NewBuffer(e []byte) *Buffer {
return &Buffer{buf: e}
}
// Reset resets the Buffer, ready for marshaling a new protocol buffer.
func (p *Buffer) Reset() {
p.buf = p.buf[0:0] // for reading/writing
p.index = 0 // for reading
}
// SetBuf replaces the internal buffer with the slice,
// ready for unmarshaling the contents of the slice.
func (p *Buffer) SetBuf(s []byte) {
p.buf = s
p.index = 0
}
// Bytes returns the contents of the Buffer.
func (p *Buffer) Bytes() []byte { return p.buf }
/*
* Helper routines for simplifying the creation of optional fields of basic type.
*/
// Bool is a helper routine that allocates a new bool value
// to store v and returns a pointer to it.
func Bool(v bool) *bool {
return &v
}
// Int32 is a helper routine that allocates a new int32 value
// to store v and returns a pointer to it.
func Int32(v int32) *int32 {
return &v
}
// Int is a helper routine that allocates a new int32 value
// to store v and returns a pointer to it, but unlike Int32
// its argument value is an int.
func Int(v int) *int32 {
p := new(int32)
*p = int32(v)
return p
}
// Int64 is a helper routine that allocates a new int64 value
// to store v and returns a pointer to it.
func Int64(v int64) *int64 {
return &v
}
// Float32 is a helper routine that allocates a new float32 value
// to store v and returns a pointer to it.
func Float32(v float32) *float32 {
return &v
}
// Float64 is a helper routine that allocates a new float64 value
// to store v and returns a pointer to it.
func Float64(v float64) *float64 {
return &v
}
// Uint32 is a helper routine that allocates a new uint32 value
// to store v and returns a pointer to it.
func Uint32(v uint32) *uint32 {
return &v
}
// Uint64 is a helper routine that allocates a new uint64 value
// to store v and returns a pointer to it.
func Uint64(v uint64) *uint64 {
return &v
}
// String is a helper routine that allocates a new string value
// to store v and returns a pointer to it.
func String(v string) *string {
return &v
}
// EnumName is a helper function to simplify printing protocol buffer enums
// by name. Given an enum map and a value, it returns a useful string.
func EnumName(m map[int32]string, v int32) string {
s, ok := m[v]
if ok {
return s
}
return strconv.Itoa(int(v))
}
// UnmarshalJSONEnum is a helper function to simplify recovering enum int values
// from their JSON-encoded representation. Given a map from the enum's symbolic
// names to its int values, and a byte buffer containing the JSON-encoded
// value, it returns an int32 that can be cast to the enum type by the caller.
//
// The function can deal with both JSON representations, numeric and symbolic.
func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32, error) {
if data[0] == '"' {
// New style: enums are strings.
var repr string
if err := json.Unmarshal(data, &repr); err != nil {
return -1, err
}
val, ok := m[repr]
if !ok {
return 0, fmt.Errorf("unrecognized enum %s value %q", enumName, repr)
}
return val, nil
}
// Old style: enums are ints.
var val int32
if err := json.Unmarshal(data, &val); err != nil {
return 0, fmt.Errorf("cannot unmarshal %#q into enum %s", data, enumName)
}
return val, nil
}
// DebugPrint dumps the encoded data in b in a debugging format with a header
// including the string s. Used in testing but made available for general debugging.
func (p *Buffer) DebugPrint(s string, b []byte) {
var u uint64
obuf := p.buf
index := p.index
p.buf = b
p.index = 0
depth := 0
fmt.Printf("\n--- %s ---\n", s)
out:
for {
for i := 0; i < depth; i++ {
fmt.Print(" ")
}
index := p.index
if index == len(p.buf) {
break
}
op, err := p.DecodeVarint()
if err != nil {
fmt.Printf("%3d: fetching op err %v\n", index, err)
break out
}
tag := op >> 3
wire := op & 7
switch wire {
default:
fmt.Printf("%3d: t=%3d unknown wire=%d\n",
index, tag, wire)
break out
case WireBytes:
var r []byte
r, err = p.DecodeRawBytes(false)
if err != nil {
break out
}
fmt.Printf("%3d: t=%3d bytes [%d]", index, tag, len(r))
if len(r) <= 6 {
for i := 0; i < len(r); i++ {
fmt.Printf(" %.2x", r[i])
}
} else {
for i := 0; i < 3; i++ {
fmt.Printf(" %.2x", r[i])
}
fmt.Printf(" ..")
for i := len(r) - 3; i < len(r); i++ {
fmt.Printf(" %.2x", r[i])
}
}
fmt.Printf("\n")
case WireFixed32:
u, err = p.DecodeFixed32()
if err != nil {
fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u)
case WireFixed64:
u, err = p.DecodeFixed64()
if err != nil {
fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u)
case WireVarint:
u, err = p.DecodeVarint()
if err != nil {
fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err)
break out
}
fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u)
case WireStartGroup:
fmt.Printf("%3d: t=%3d start\n", index, tag)
depth++
case WireEndGroup:
depth--
fmt.Printf("%3d: t=%3d end\n", index, tag)
}
}
if depth != 0 {
fmt.Printf("%3d: start-end not balanced %d\n", p.index, depth)
}
fmt.Printf("\n")
p.buf = obuf
p.index = index
}
// SetDefaults sets unset protocol buffer fields to their default values.
// It only modifies fields that are both unset and have defined defaults.
// It recursively sets default values in any non-nil sub-messages.
func SetDefaults(pb Message) {
setDefaults(reflect.ValueOf(pb), true, false)
}
// v is a pointer to a struct.
func setDefaults(v reflect.Value, recur, zeros bool) {
v = v.Elem()
defaultMu.RLock()
dm, ok := defaults[v.Type()]
defaultMu.RUnlock()
if !ok {
dm = buildDefaultMessage(v.Type())
defaultMu.Lock()
defaults[v.Type()] = dm
defaultMu.Unlock()
}
for _, sf := range dm.scalars {
f := v.Field(sf.index)
if !f.IsNil() {
// field already set
continue
}
dv := sf.value
if dv == nil && !zeros {
// no explicit default, and don't want to set zeros
continue
}
fptr := f.Addr().Interface() // **T
// TODO: Consider batching the allocations we do here.
switch sf.kind {
case reflect.Bool:
b := new(bool)
if dv != nil {
*b = dv.(bool)
}
*(fptr.(**bool)) = b
case reflect.Float32:
f := new(float32)
if dv != nil {
*f = dv.(float32)
}
*(fptr.(**float32)) = f
case reflect.Float64:
f := new(float64)
if dv != nil {
*f = dv.(float64)
}
*(fptr.(**float64)) = f
case reflect.Int32:
// might be an enum
if ft := f.Type(); ft != int32PtrType {
// enum
f.Set(reflect.New(ft.Elem()))
if dv != nil {
f.Elem().SetInt(int64(dv.(int32)))
}
} else {
// int32 field
i := new(int32)
if dv != nil {
*i = dv.(int32)
}
*(fptr.(**int32)) = i
}
case reflect.Int64:
i := new(int64)
if dv != nil {
*i = dv.(int64)
}
*(fptr.(**int64)) = i
case reflect.String:
s := new(string)
if dv != nil {
*s = dv.(string)
}
*(fptr.(**string)) = s
case reflect.Uint8:
// exceptional case: []byte
var b []byte
if dv != nil {
db := dv.([]byte)
b = make([]byte, len(db))
copy(b, db)
} else {
b = []byte{}
}
*(fptr.(*[]byte)) = b
case reflect.Uint32:
u := new(uint32)
if dv != nil {
*u = dv.(uint32)
}
*(fptr.(**uint32)) = u
case reflect.Uint64:
u := new(uint64)
if dv != nil {
*u = dv.(uint64)
}
*(fptr.(**uint64)) = u
default:
log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind)
}
}
for _, ni := range dm.nested {
f := v.Field(ni)
// f is *T or []*T or map[T]*T
switch f.Kind() {
case reflect.Ptr:
if f.IsNil() {
continue
}
setDefaults(f, recur, zeros)
case reflect.Slice:
for i := 0; i < f.Len(); i++ {
e := f.Index(i)
if e.IsNil() {
continue
}
setDefaults(e, recur, zeros)
}
case reflect.Map:
for _, k := range f.MapKeys() {
e := f.MapIndex(k)
if e.IsNil() {
continue
}
setDefaults(e, recur, zeros)
}
}
}
}
var (
// defaults maps a protocol buffer struct type to a slice of the fields,
// with its scalar fields set to their proto-declared non-zero default values.
defaultMu sync.RWMutex
defaults = make(map[reflect.Type]defaultMessage)
int32PtrType = reflect.TypeOf((*int32)(nil))
)
// defaultMessage represents information about the default values of a message.
type defaultMessage struct {
scalars []scalarField
nested []int // struct field index of nested messages
}
type scalarField struct {
index int // struct field index
kind reflect.Kind // element type (the T in *T or []T)
value interface{} // the proto-declared default value, or nil
}
// t is a struct type.
func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
sprop := GetProperties(t)
for _, prop := range sprop.Prop {
fi, ok := sprop.decoderTags.get(prop.Tag)
if !ok {
// XXX_unrecognized
continue
}
ft := t.Field(fi).Type
sf, nested, err := fieldDefault(ft, prop)
switch {
case err != nil:
log.Print(err)
case nested:
dm.nested = append(dm.nested, fi)
case sf != nil:
sf.index = fi
dm.scalars = append(dm.scalars, *sf)
}
}
return dm
}
// fieldDefault returns the scalarField for field type ft.
// sf will be nil if the field can not have a default.
// nestedMessage will be true if this is a nested message.
// Note that sf.index is not set on return.
func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMessage bool, err error) {
var canHaveDefault bool
switch ft.Kind() {
case reflect.Ptr:
if ft.Elem().Kind() == reflect.Struct {
nestedMessage = true
} else {
canHaveDefault = true // proto2 scalar field
}
case reflect.Slice:
switch ft.Elem().Kind() {
case reflect.Ptr:
nestedMessage = true // repeated message
case reflect.Uint8:
canHaveDefault = true // bytes field
}
case reflect.Map:
if ft.Elem().Kind() == reflect.Ptr {
nestedMessage = true // map with message values
}
}
if !canHaveDefault {
if nestedMessage {
return nil, true, nil
}
return nil, false, nil
}
// We now know that ft is a pointer or slice.
sf = &scalarField{kind: ft.Elem().Kind()}
// scalar fields without defaults
if !prop.HasDefault {
return sf, false, nil
}
// a scalar field: either *T or []byte
switch ft.Elem().Kind() {
case reflect.Bool:
x, err := strconv.ParseBool(prop.Default)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default bool %q: %v", prop.Default, err)
}
sf.value = x
case reflect.Float32:
x, err := strconv.ParseFloat(prop.Default, 32)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default float32 %q: %v", prop.Default, err)
}
sf.value = float32(x)
case reflect.Float64:
x, err := strconv.ParseFloat(prop.Default, 64)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default float64 %q: %v", prop.Default, err)
}
sf.value = x
case reflect.Int32:
x, err := strconv.ParseInt(prop.Default, 10, 32)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default int32 %q: %v", prop.Default, err)
}
sf.value = int32(x)
case reflect.Int64:
x, err := strconv.ParseInt(prop.Default, 10, 64)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default int64 %q: %v", prop.Default, err)
}
sf.value = x
case reflect.String:
sf.value = prop.Default
case reflect.Uint8:
// []byte (not *uint8)
sf.value = []byte(prop.Default)
case reflect.Uint32:
x, err := strconv.ParseUint(prop.Default, 10, 32)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default uint32 %q: %v", prop.Default, err)
}
sf.value = uint32(x)
case reflect.Uint64:
x, err := strconv.ParseUint(prop.Default, 10, 64)
if err != nil {
return nil, false, fmt.Errorf("proto: bad default uint64 %q: %v", prop.Default, err)
}
sf.value = x
default:
return nil, false, fmt.Errorf("proto: unhandled def kind %v", ft.Elem().Kind())
}
return sf, false, nil
}
// Map fields may have key types of non-float scalars, strings and enums.
// The easiest way to sort them in some deterministic order is to use fmt.
// If this turns out to be inefficient we can always consider other options,
// such as doing a Schwartzian transform.
func mapKeys(vs []reflect.Value) sort.Interface {
s := mapKeySorter{
vs: vs,
// default Less function: textual comparison
less: func(a, b reflect.Value) bool {
return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface())
},
}
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps;
// numeric keys are sorted numerically.
if len(vs) == 0 {
return s
}
switch vs[0].Kind() {
case reflect.Int32, reflect.Int64:
s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
case reflect.Uint32, reflect.Uint64:
s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
}
return s
}
type mapKeySorter struct {
vs []reflect.Value
less func(a, b reflect.Value) bool
}
func (s mapKeySorter) Len() int { return len(s.vs) }
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
func (s mapKeySorter) Less(i, j int) bool {
return s.less(s.vs[i], s.vs[j])
}
// isProto3Zero reports whether v is a zero proto3 value.
func isProto3Zero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return !v.Bool()
case reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint32, reflect.Uint64:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.String:
return v.String() == ""
}
return false
}

@ -0,0 +1,287 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
/*
* Support for message sets.
*/
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"sort"
)
// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
// A message type ID is required for storing a protocol buffer in a message set.
var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
// The first two types (_MessageSet_Item and MessageSet)
// model what the protocol compiler produces for the following protocol message:
// message MessageSet {
// repeated group Item = 1 {
// required int32 type_id = 2;
// required string message = 3;
// };
// }
// That is the MessageSet wire format. We can't use a proto to generate these
// because that would introduce a circular dependency between it and this package.
//
// When a proto1 proto has a field that looks like:
// optional message<MessageSet> info = 3;
// the protocol compiler produces a field in the generated struct that looks like:
// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"`
// The package is automatically inserted so there is no need for that proto file to
// import this package.
type _MessageSet_Item struct {
TypeId *int32 `protobuf:"varint,2,req,name=type_id"`
Message []byte `protobuf:"bytes,3,req,name=message"`
}
type MessageSet struct {
Item []*_MessageSet_Item `protobuf:"group,1,rep"`
XXX_unrecognized []byte
// TODO: caching?
}
// Make sure MessageSet is a Message.
var _ Message = (*MessageSet)(nil)
// messageTypeIder is an interface satisfied by a protocol buffer type
// that may be stored in a MessageSet.
type messageTypeIder interface {
MessageTypeId() int32
}
func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
mti, ok := pb.(messageTypeIder)
if !ok {
return nil
}
id := mti.MessageTypeId()
for _, item := range ms.Item {
if *item.TypeId == id {
return item
}
}
return nil
}
func (ms *MessageSet) Has(pb Message) bool {
if ms.find(pb) != nil {
return true
}
return false
}
func (ms *MessageSet) Unmarshal(pb Message) error {
if item := ms.find(pb); item != nil {
return Unmarshal(item.Message, pb)
}
if _, ok := pb.(messageTypeIder); !ok {
return ErrNoMessageTypeId
}
return nil // TODO: return error instead?
}
func (ms *MessageSet) Marshal(pb Message) error {
msg, err := Marshal(pb)
if err != nil {
return err
}
if item := ms.find(pb); item != nil {
// reuse existing item
item.Message = msg
return nil
}
mti, ok := pb.(messageTypeIder)
if !ok {
return ErrNoMessageTypeId
}
mtid := mti.MessageTypeId()
ms.Item = append(ms.Item, &_MessageSet_Item{
TypeId: &mtid,
Message: msg,
})
return nil
}
func (ms *MessageSet) Reset() { *ms = MessageSet{} }
func (ms *MessageSet) String() string { return CompactTextString(ms) }
func (*MessageSet) ProtoMessage() {}
// Support for the message_set_wire_format message option.
func skipVarint(buf []byte) []byte {
i := 0
for ; buf[i]&0x80 != 0; i++ {
}
return buf[i+1:]
}
// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
if err := encodeExtensionMap(m); err != nil {
return nil, err
}
// Sort extension IDs to provide a deterministic encoding.
// See also enc_map in encode.go.
ids := make([]int, 0, len(m))
for id := range m {
ids = append(ids, int(id))
}
sort.Ints(ids)
ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
for _, id := range ids {
e := m[int32(id)]
// Remove the wire type and field number varint, as well as the length varint.
msg := skipVarint(skipVarint(e.enc))
ms.Item = append(ms.Item, &_MessageSet_Item{
TypeId: Int32(int32(id)),
Message: msg,
})
}
return Marshal(ms)
}
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
ms := new(MessageSet)
if err := Unmarshal(buf, ms); err != nil {
return err
}
for _, item := range ms.Item {
id := *item.TypeId
msg := item.Message
// Restore wire type and field number varint, plus length varint.
// Be careful to preserve duplicate items.
b := EncodeVarint(uint64(id)<<3 | WireBytes)
if ext, ok := m[id]; ok {
// Existing data; rip off the tag and length varint
// so we join the new data correctly.
// We can assume that ext.enc is set because we are unmarshaling.
o := ext.enc[len(b):] // skip wire type and field number
_, n := DecodeVarint(o) // calculate length of length varint
o = o[n:] // skip length varint
msg = append(o, msg...) // join old data and new data
}
b = append(b, EncodeVarint(uint64(len(msg)))...)
b = append(b, msg...)
m[id] = Extension{enc: b}
}
return nil
}
// MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
var b bytes.Buffer
b.WriteByte('{')
// Process the map in key order for deterministic output.
ids := make([]int32, 0, len(m))
for id := range m {
ids = append(ids, id)
}
sort.Sort(int32Slice(ids)) // int32Slice defined in text.go
for i, id := range ids {
ext := m[id]
if i > 0 {
b.WriteByte(',')
}
msd, ok := messageSetMap[id]
if !ok {
// Unknown type; we can't render it, so skip it.
continue
}
fmt.Fprintf(&b, `"[%s]":`, msd.name)
x := ext.value
if x == nil {
x = reflect.New(msd.t.Elem()).Interface()
if err := Unmarshal(ext.enc, x.(Message)); err != nil {
return nil, err
}
}
d, err := json.Marshal(x)
if err != nil {
return nil, err
}
b.Write(d)
}
b.WriteByte('}')
return b.Bytes(), nil
}
// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error {
// Common-case fast path.
if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
return nil
}
// This is fairly tricky, and it's not clear that it is needed.
return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented")
}
// A global registry of types that can be used in a MessageSet.
var messageSetMap = make(map[int32]messageSetDesc)
type messageSetDesc struct {
t reflect.Type // pointer to struct
name string
}
// RegisterMessageSetType is called from the generated code.
func RegisterMessageSetType(m Message, fieldNum int32, name string) {
messageSetMap[fieldNum] = messageSetDesc{
t: reflect.TypeOf(m),
name: name,
}
}

@ -0,0 +1,66 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto
import (
"bytes"
"testing"
)
func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
// Check that a repeated message set entry will be concatenated.
in := &MessageSet{
Item: []*_MessageSet_Item{
{TypeId: Int32(12345), Message: []byte("hoo")},
{TypeId: Int32(12345), Message: []byte("hah")},
},
}
b, err := Marshal(in)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
t.Logf("Marshaled bytes: %q", b)
m := make(map[int32]Extension)
if err := UnmarshalMessageSet(b, m); err != nil {
t.Fatalf("UnmarshalMessageSet: %v", err)
}
ext, ok := m[12345]
if !ok {
t.Fatalf("Didn't retrieve extension 12345; map is %v", m)
}
// Skip wire type/field number and length varints.
got := skipVarint(skipVarint(ext.enc))
if want := []byte("hoohah"); !bytes.Equal(got, want) {
t.Errorf("Combined extension is %q, want %q", got, want)
}
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save