parent
4e45962240
commit
b9f6943615
@ -0,0 +1,57 @@
|
||||
{
|
||||
"ImportPath": "github.com/codeskyblue/gosuv",
|
||||
"GoVersion": "go1.4.2",
|
||||
"Deps": [
|
||||
{
|
||||
"ImportPath": "github.com/bradfitz/http2",
|
||||
"Rev": "f8202bc903bda493ebba4aa54922d78430c2c42f"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/codegangsta/cli",
|
||||
"Comment": "1.2.0-139-g142e6cd",
|
||||
"Rev": "142e6cd241a4dfbf7f07a018f1f8225180018da4"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/codeskyblue/kproc",
|
||||
"Rev": "fcb55eb35ab6b7290f395ad596e80e8355d52d69"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/franela/goreq",
|
||||
"Rev": "72c51a544272e007ab3da4f7d9ac959b7af7af03"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/golang/protobuf/proto",
|
||||
"Rev": "1dceb1a2654bdc74ca97ad91f71f500eecc96269"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/lunny/log",
|
||||
"Rev": "9ae5e9effeefca29cf00d932a7b54f1b2d70fd58"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/lunny/tango",
|
||||
"Comment": "v0.4.6-25-g652e7bd",
|
||||
"Rev": "652e7bd578c564f7ae9307ed102d188e564baf54"
|
||||
},
|
||||
{
|
||||
"ImportPath": "github.com/qiniu/log",
|
||||
"Comment": "v1.0.00-2-ge002bc2",
|
||||
"Rev": "e002bc2020b19bfa61ed378cc5407383dbd2f346"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/net/context",
|
||||
"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/net/internal/timeseries",
|
||||
"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"
|
||||
},
|
||||
{
|
||||
"ImportPath": "golang.org/x/net/trace",
|
||||
"Rev": "db8e4de5b2d6653f66aea53094624468caad15d2"
|
||||
},
|
||||
{
|
||||
"ImportPath": "google.golang.org/grpc",
|
||||
"Rev": "6b6127d8c67b67f9f8509b84279eac053964e868"
|
||||
}
|
||||
]
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
This directory tree is generated automatically by godep.
|
||||
|
||||
Please do not edit.
|
||||
|
||||
See https://github.com/tools/godep for more information.
|
@ -0,0 +1,2 @@
|
||||
/pkg
|
||||
/bin
|
@ -0,0 +1,2 @@
|
||||
*~
|
||||
h2i/h2i
|
@ -0,0 +1,20 @@
|
||||
# This file is like Go's AUTHORS file: it lists Copyright holders.
|
||||
# The list of humans who have contributd is in the CONTRIBUTORS file.
|
||||
#
|
||||
# To contribute to this project, because it will eventually be folded
|
||||
# back in to Go itself, you need to submit a CLA:
|
||||
#
|
||||
# http://golang.org/doc/contribute.html#copyright
|
||||
#
|
||||
# Then you get added to CONTRIBUTORS and you or your company get added
|
||||
# to the AUTHORS file.
|
||||
|
||||
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
|
||||
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
|
||||
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
|
||||
Google, Inc.
|
||||
Keith Rarick <kr@xph.us> github=kr
|
||||
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
|
||||
Matt Layher <mdlayher@gmail.com> github=mdlayher
|
||||
Perry Abbott <perry.j.abbott@gmail.com> github=pabbott0
|
||||
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t
|
@ -0,0 +1,20 @@
|
||||
# This file is like Go's CONTRIBUTORS file: it lists humans.
|
||||
# The list of copyright holders (which may be companies) are in the AUTHORS file.
|
||||
#
|
||||
# To contribute to this project, because it will eventually be folded
|
||||
# back in to Go itself, you need to submit a CLA:
|
||||
#
|
||||
# http://golang.org/doc/contribute.html#copyright
|
||||
#
|
||||
# Then you get added to CONTRIBUTORS and you or your company get added
|
||||
# to the AUTHORS file.
|
||||
|
||||
Blake Mizerany <blake.mizerany@gmail.com> github=bmizerany
|
||||
Brad Fitzpatrick <bradfitz@golang.org> github=bradfitz
|
||||
Daniel Morsing <daniel.morsing@gmail.com> github=DanielMorsing
|
||||
Gabriel Aszalos <gabriel.aszalos@gmail.com> github=gbbr
|
||||
Keith Rarick <kr@xph.us> github=kr
|
||||
Matthew Keenan <tank.en.mate@gmail.com> <github@mattkeenan.net> github=mattkeenan
|
||||
Matt Layher <mdlayher@gmail.com> github=mdlayher
|
||||
Perry Abbott <perry.j.abbott@gmail.com> github=pabbott0
|
||||
Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> github=tatsuhiro-t
|
@ -0,0 +1,44 @@
|
||||
#
|
||||
# This Dockerfile builds a recent curl with HTTP/2 client support, using
|
||||
# a recent nghttp2 build.
|
||||
#
|
||||
# See the Makefile for how to tag it. If Docker and that image is found, the
|
||||
# Go tests use this curl binary for integration tests.
|
||||
#
|
||||
|
||||
FROM ubuntu:trusty
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get upgrade -y && \
|
||||
apt-get install -y git-core build-essential wget
|
||||
|
||||
RUN apt-get install -y --no-install-recommends \
|
||||
autotools-dev libtool pkg-config zlib1g-dev \
|
||||
libcunit1-dev libssl-dev libxml2-dev libevent-dev \
|
||||
automake autoconf
|
||||
|
||||
# Note: setting NGHTTP2_VER before the git clone, so an old git clone isn't cached:
|
||||
ENV NGHTTP2_VER af24f8394e43f4
|
||||
RUN cd /root && git clone https://github.com/tatsuhiro-t/nghttp2.git
|
||||
|
||||
WORKDIR /root/nghttp2
|
||||
RUN git reset --hard $NGHTTP2_VER
|
||||
RUN autoreconf -i
|
||||
RUN automake
|
||||
RUN autoconf
|
||||
RUN ./configure
|
||||
RUN make
|
||||
RUN make install
|
||||
|
||||
WORKDIR /root
|
||||
RUN wget http://curl.haxx.se/download/curl-7.40.0.tar.gz
|
||||
RUN tar -zxvf curl-7.40.0.tar.gz
|
||||
WORKDIR /root/curl-7.40.0
|
||||
RUN ./configure --with-ssl --with-nghttp2=/usr/local
|
||||
RUN make
|
||||
RUN make install
|
||||
RUN ldconfig
|
||||
|
||||
CMD ["-h"]
|
||||
ENTRYPOINT ["/usr/local/bin/curl"]
|
||||
|
@ -0,0 +1,5 @@
|
||||
We only accept contributions from users who have gone through Go's
|
||||
contribution process (signed a CLA).
|
||||
|
||||
Please acknowledge whether you have (and use the same email) if
|
||||
sending a pull request.
|
@ -0,0 +1,7 @@
|
||||
Copyright 2014 Google & the Go AUTHORS
|
||||
|
||||
Go AUTHORS are:
|
||||
See https://code.google.com/p/go/source/browse/AUTHORS
|
||||
|
||||
Licensed under the terms of Go itself:
|
||||
https://code.google.com/p/go/source/browse/LICENSE
|
@ -0,0 +1,3 @@
|
||||
curlimage:
|
||||
docker build -t gohttp2/curl .
|
||||
|
@ -0,0 +1,17 @@
|
||||
This is a work-in-progress HTTP/2 implementation for Go.
|
||||
|
||||
It will eventually live in the Go standard library and won't require
|
||||
any changes to your code to use. It will just be automatic.
|
||||
|
||||
Status:
|
||||
|
||||
* The server support is pretty good. A few things are missing
|
||||
but are being worked on.
|
||||
* The client work has just started but shares a lot of code
|
||||
is coming along much quicker.
|
||||
|
||||
Docs are at https://godoc.org/github.com/bradfitz/http2
|
||||
|
||||
Demo test server at https://http2.golang.org/
|
||||
|
||||
Help & bug reports welcome.
|
@ -0,0 +1,76 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// buffer is an io.ReadWriteCloser backed by a fixed size buffer.
|
||||
// It never allocates, but moves old data as new data is written.
|
||||
type buffer struct {
|
||||
buf []byte
|
||||
r, w int
|
||||
closed bool
|
||||
err error // err to return to reader
|
||||
}
|
||||
|
||||
var (
|
||||
errReadEmpty = errors.New("read from empty buffer")
|
||||
errWriteClosed = errors.New("write on closed buffer")
|
||||
errWriteFull = errors.New("write on full buffer")
|
||||
)
|
||||
|
||||
// Read copies bytes from the buffer into p.
|
||||
// It is an error to read when no data is available.
|
||||
func (b *buffer) Read(p []byte) (n int, err error) {
|
||||
n = copy(p, b.buf[b.r:b.w])
|
||||
b.r += n
|
||||
if b.closed && b.r == b.w {
|
||||
err = b.err
|
||||
} else if b.r == b.w && n == 0 {
|
||||
err = errReadEmpty
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Len returns the number of bytes of the unread portion of the buffer.
|
||||
func (b *buffer) Len() int {
|
||||
return b.w - b.r
|
||||
}
|
||||
|
||||
// Write copies bytes from p into the buffer.
|
||||
// It is an error to write more data than the buffer can hold.
|
||||
func (b *buffer) Write(p []byte) (n int, err error) {
|
||||
if b.closed {
|
||||
return 0, errWriteClosed
|
||||
}
|
||||
|
||||
// Slide existing data to beginning.
|
||||
if b.r > 0 && len(p) > len(b.buf)-b.w {
|
||||
copy(b.buf, b.buf[b.r:b.w])
|
||||
b.w -= b.r
|
||||
b.r = 0
|
||||
}
|
||||
|
||||
// Write new data.
|
||||
n = copy(b.buf[b.w:], p)
|
||||
b.w += n
|
||||
if n < len(p) {
|
||||
err = errWriteFull
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close marks the buffer as closed. Future calls to Write will
|
||||
// return an error. Future calls to Read, once the buffer is
|
||||
// empty, will return err.
|
||||
func (b *buffer) Close(err error) {
|
||||
if !b.closed {
|
||||
b.closed = true
|
||||
b.err = err
|
||||
}
|
||||
}
|
@ -0,0 +1,154 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var bufferReadTests = []struct {
|
||||
buf buffer
|
||||
read, wn int
|
||||
werr error
|
||||
wp []byte
|
||||
wbuf buffer
|
||||
}{
|
||||
{
|
||||
buffer{[]byte{'a', 0}, 0, 1, false, nil},
|
||||
5, 1, nil, []byte{'a'},
|
||||
buffer{[]byte{'a', 0}, 1, 1, false, nil},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{'a', 0}, 0, 1, true, io.EOF},
|
||||
5, 1, io.EOF, []byte{'a'},
|
||||
buffer{[]byte{'a', 0}, 1, 1, true, io.EOF},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{0, 'a'}, 1, 2, false, nil},
|
||||
5, 1, nil, []byte{'a'},
|
||||
buffer{[]byte{0, 'a'}, 2, 2, false, nil},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{0, 'a'}, 1, 2, true, io.EOF},
|
||||
5, 1, io.EOF, []byte{'a'},
|
||||
buffer{[]byte{0, 'a'}, 2, 2, true, io.EOF},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{}, 0, 0, false, nil},
|
||||
5, 0, errReadEmpty, []byte{},
|
||||
buffer{[]byte{}, 0, 0, false, nil},
|
||||
},
|
||||
{
|
||||
buffer{[]byte{}, 0, 0, true, io.EOF},
|
||||
5, 0, io.EOF, []byte{},
|
||||
buffer{[]byte{}, 0, 0, true, io.EOF},
|
||||
},
|
||||
}
|
||||
|
||||
func TestBufferRead(t *testing.T) {
|
||||
for i, tt := range bufferReadTests {
|
||||
read := make([]byte, tt.read)
|
||||
n, err := tt.buf.Read(read)
|
||||
if n != tt.wn {
|
||||
t.Errorf("#%d: wn = %d want %d", i, n, tt.wn)
|
||||
continue
|
||||
}
|
||||
if err != tt.werr {
|
||||
t.Errorf("#%d: werr = %v want %v", i, err, tt.werr)
|
||||
continue
|
||||
}
|
||||
read = read[:n]
|
||||
if !reflect.DeepEqual(read, tt.wp) {
|
||||
t.Errorf("#%d: read = %+v want %+v", i, read, tt.wp)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
|
||||
t.Errorf("#%d: buf = %+v want %+v", i, tt.buf, tt.wbuf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var bufferWriteTests = []struct {
|
||||
buf buffer
|
||||
write, wn int
|
||||
werr error
|
||||
wbuf buffer
|
||||
}{
|
||||
{
|
||||
buf: buffer{
|
||||
buf: []byte{},
|
||||
},
|
||||
wbuf: buffer{
|
||||
buf: []byte{},
|
||||
},
|
||||
},
|
||||
{
|
||||
buf: buffer{
|
||||
buf: []byte{1, 'a'},
|
||||
},
|
||||
write: 1,
|
||||
wn: 1,
|
||||
wbuf: buffer{
|
||||
buf: []byte{0, 'a'},
|
||||
w: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
buf: buffer{
|
||||
buf: []byte{'a', 1},
|
||||
r: 1,
|
||||
w: 1,
|
||||
},
|
||||
write: 2,
|
||||
wn: 2,
|
||||
wbuf: buffer{
|
||||
buf: []byte{0, 0},
|
||||
w: 2,
|
||||
},
|
||||
},
|
||||
{
|
||||
buf: buffer{
|
||||
buf: []byte{},
|
||||
r: 1,
|
||||
closed: true,
|
||||
},
|
||||
write: 5,
|
||||
werr: errWriteClosed,
|
||||
wbuf: buffer{
|
||||
buf: []byte{},
|
||||
r: 1,
|
||||
closed: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
buf: buffer{
|
||||
buf: []byte{},
|
||||
},
|
||||
write: 5,
|
||||
werr: errWriteFull,
|
||||
wbuf: buffer{
|
||||
buf: []byte{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestBufferWrite(t *testing.T) {
|
||||
for i, tt := range bufferWriteTests {
|
||||
n, err := tt.buf.Write(make([]byte, tt.write))
|
||||
if n != tt.wn {
|
||||
t.Errorf("#%d: wrote %d bytes; want %d", i, n, tt.wn)
|
||||
continue
|
||||
}
|
||||
if err != tt.werr {
|
||||
t.Errorf("#%d: error = %v; want %v", i, err, tt.werr)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.buf, tt.wbuf) {
|
||||
t.Errorf("#%d: buf = %+v; want %+v", i, tt.buf, tt.wbuf)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,78 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import "fmt"
|
||||
|
||||
// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
|
||||
type ErrCode uint32
|
||||
|
||||
const (
|
||||
ErrCodeNo ErrCode = 0x0
|
||||
ErrCodeProtocol ErrCode = 0x1
|
||||
ErrCodeInternal ErrCode = 0x2
|
||||
ErrCodeFlowControl ErrCode = 0x3
|
||||
ErrCodeSettingsTimeout ErrCode = 0x4
|
||||
ErrCodeStreamClosed ErrCode = 0x5
|
||||
ErrCodeFrameSize ErrCode = 0x6
|
||||
ErrCodeRefusedStream ErrCode = 0x7
|
||||
ErrCodeCancel ErrCode = 0x8
|
||||
ErrCodeCompression ErrCode = 0x9
|
||||
ErrCodeConnect ErrCode = 0xa
|
||||
ErrCodeEnhanceYourCalm ErrCode = 0xb
|
||||
ErrCodeInadequateSecurity ErrCode = 0xc
|
||||
ErrCodeHTTP11Required ErrCode = 0xd
|
||||
)
|
||||
|
||||
var errCodeName = map[ErrCode]string{
|
||||
ErrCodeNo: "NO_ERROR",
|
||||
ErrCodeProtocol: "PROTOCOL_ERROR",
|
||||
ErrCodeInternal: "INTERNAL_ERROR",
|
||||
ErrCodeFlowControl: "FLOW_CONTROL_ERROR",
|
||||
ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT",
|
||||
ErrCodeStreamClosed: "STREAM_CLOSED",
|
||||
ErrCodeFrameSize: "FRAME_SIZE_ERROR",
|
||||
ErrCodeRefusedStream: "REFUSED_STREAM",
|
||||
ErrCodeCancel: "CANCEL",
|
||||
ErrCodeCompression: "COMPRESSION_ERROR",
|
||||
ErrCodeConnect: "CONNECT_ERROR",
|
||||
ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM",
|
||||
ErrCodeInadequateSecurity: "INADEQUATE_SECURITY",
|
||||
ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED",
|
||||
}
|
||||
|
||||
func (e ErrCode) String() string {
|
||||
if s, ok := errCodeName[e]; ok {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("unknown error code 0x%x", uint32(e))
|
||||
}
|
||||
|
||||
// ConnectionError is an error that results in the termination of the
|
||||
// entire connection.
|
||||
type ConnectionError ErrCode
|
||||
|
||||
func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: %s", ErrCode(e)) }
|
||||
|
||||
// StreamError is an error that only affects one stream within an
|
||||
// HTTP/2 connection.
|
||||
type StreamError struct {
|
||||
StreamID uint32
|
||||
Code ErrCode
|
||||
}
|
||||
|
||||
func (e StreamError) Error() string {
|
||||
return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
|
||||
}
|
||||
|
||||
// 6.9.1 The Flow Control Window
|
||||
// "If a sender receives a WINDOW_UPDATE that causes a flow control
|
||||
// window to exceed this maximum it MUST terminate either the stream
|
||||
// or the connection, as appropriate. For streams, [...]; for the
|
||||
// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
|
||||
type goAwayFlowError struct{}
|
||||
|
||||
func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
|
@ -0,0 +1,27 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestErrCodeString(t *testing.T) {
|
||||
tests := []struct {
|
||||
err ErrCode
|
||||
want string
|
||||
}{
|
||||
{ErrCodeProtocol, "PROTOCOL_ERROR"},
|
||||
{0xd, "HTTP_1_1_REQUIRED"},
|
||||
{0xf, "unknown error code 0xf"},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
got := tt.err.String()
|
||||
if got != tt.want {
|
||||
t.Errorf("%d. Error = %q; want %q", i, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
// Flow control
|
||||
|
||||
package http2
|
||||
|
||||
// flow is the flow control window's size.
|
||||
type flow struct {
|
||||
// n is the number of DATA bytes we're allowed to send.
|
||||
// A flow is kept both on a conn and a per-stream.
|
||||
n int32
|
||||
|
||||
// conn points to the shared connection-level flow that is
|
||||
// shared by all streams on that conn. It is nil for the flow
|
||||
// that's on the conn directly.
|
||||
conn *flow
|
||||
}
|
||||
|
||||
func (f *flow) setConnFlow(cf *flow) { f.conn = cf }
|
||||
|
||||
func (f *flow) available() int32 {
|
||||
n := f.n
|
||||
if f.conn != nil && f.conn.n < n {
|
||||
n = f.conn.n
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (f *flow) take(n int32) {
|
||||
if n > f.available() {
|
||||
panic("internal error: took too much")
|
||||
}
|
||||
f.n -= n
|
||||
if f.conn != nil {
|
||||
f.conn.n -= n
|
||||
}
|
||||
}
|
||||
|
||||
// add adds n bytes (positive or negative) to the flow control window.
|
||||
// It returns false if the sum would exceed 2^31-1.
|
||||
func (f *flow) add(n int32) bool {
|
||||
remain := (1<<31 - 1) - f.n
|
||||
if n > remain {
|
||||
return false
|
||||
}
|
||||
f.n += n
|
||||
return true
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFlow(t *testing.T) {
|
||||
var st flow
|
||||
var conn flow
|
||||
st.add(3)
|
||||
conn.add(2)
|
||||
|
||||
if got, want := st.available(), int32(3); got != want {
|
||||
t.Errorf("available = %d; want %d", got, want)
|
||||
}
|
||||
st.setConnFlow(&conn)
|
||||
if got, want := st.available(), int32(2); got != want {
|
||||
t.Errorf("after parent setup, available = %d; want %d", got, want)
|
||||
}
|
||||
|
||||
st.take(2)
|
||||
if got, want := conn.available(), int32(0); got != want {
|
||||
t.Errorf("after taking 2, conn = %d; want %d", got, want)
|
||||
}
|
||||
if got, want := st.available(), int32(0); got != want {
|
||||
t.Errorf("after taking 2, stream = %d; want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlowAdd(t *testing.T) {
|
||||
var f flow
|
||||
if !f.add(1) {
|
||||
t.Fatal("failed to add 1")
|
||||
}
|
||||
if !f.add(-1) {
|
||||
t.Fatal("failed to add -1")
|
||||
}
|
||||
if got, want := f.available(), int32(0); got != want {
|
||||
t.Fatalf("size = %d; want %d", got, want)
|
||||
}
|
||||
if !f.add(1<<31 - 1) {
|
||||
t.Fatal("failed to add 2^31-1")
|
||||
}
|
||||
if got, want := f.available(), int32(1<<31-1); got != want {
|
||||
t.Fatalf("size = %d; want %d", got, want)
|
||||
}
|
||||
if f.add(1) {
|
||||
t.Fatal("adding 1 to max shouldn't be allowed")
|
||||
}
|
||||
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,597 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func testFramer() (*Framer, *bytes.Buffer) {
|
||||
buf := new(bytes.Buffer)
|
||||
return NewFramer(buf, buf), buf
|
||||
}
|
||||
|
||||
func TestFrameSizes(t *testing.T) {
|
||||
// Catch people rearranging the FrameHeader fields.
|
||||
if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want {
|
||||
t.Errorf("FrameHeader size = %d; want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrameTypeString(t *testing.T) {
|
||||
tests := []struct {
|
||||
ft FrameType
|
||||
want string
|
||||
}{
|
||||
{FrameData, "DATA"},
|
||||
{FramePing, "PING"},
|
||||
{FrameGoAway, "GOAWAY"},
|
||||
{0xf, "UNKNOWN_FRAME_TYPE_15"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := tt.ft.String()
|
||||
if got != tt.want {
|
||||
t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteRST(t *testing.T) {
|
||||
fr, buf := testFramer()
|
||||
var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
|
||||
var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4
|
||||
fr.WriteRSTStream(streamID, ErrCode(errCode))
|
||||
const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04"
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := &RSTStreamFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
Type: 0x3,
|
||||
Flags: 0x0,
|
||||
Length: 0x4,
|
||||
StreamID: 0x1020304,
|
||||
},
|
||||
ErrCode: 0x7060504,
|
||||
}
|
||||
if !reflect.DeepEqual(f, want) {
|
||||
t.Errorf("parsed back %#v; want %#v", f, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteData(t *testing.T) {
|
||||
fr, buf := testFramer()
|
||||
var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4
|
||||
data := []byte("ABC")
|
||||
fr.WriteData(streamID, true, data)
|
||||
const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC"
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
df, ok := f.(*DataFrame)
|
||||
if !ok {
|
||||
t.Fatalf("got %T; want *DataFrame", f)
|
||||
}
|
||||
if !bytes.Equal(df.Data(), data) {
|
||||
t.Errorf("got %q; want %q", df.Data(), data)
|
||||
}
|
||||
if f.Header().Flags&1 == 0 {
|
||||
t.Errorf("didn't see END_STREAM flag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
p HeadersFrameParam
|
||||
wantEnc string
|
||||
wantFrame *HeadersFrame
|
||||
}{
|
||||
{
|
||||
"basic",
|
||||
HeadersFrameParam{
|
||||
StreamID: 42,
|
||||
BlockFragment: []byte("abc"),
|
||||
Priority: PriorityParam{},
|
||||
},
|
||||
"\x00\x00\x03\x01\x00\x00\x00\x00*abc",
|
||||
&HeadersFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
StreamID: 42,
|
||||
Type: FrameHeaders,
|
||||
Length: uint32(len("abc")),
|
||||
},
|
||||
Priority: PriorityParam{},
|
||||
headerFragBuf: []byte("abc"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"basic + end flags",
|
||||
HeadersFrameParam{
|
||||
StreamID: 42,
|
||||
BlockFragment: []byte("abc"),
|
||||
EndStream: true,
|
||||
EndHeaders: true,
|
||||
Priority: PriorityParam{},
|
||||
},
|
||||
"\x00\x00\x03\x01\x05\x00\x00\x00*abc",
|
||||
&HeadersFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
StreamID: 42,
|
||||
Type: FrameHeaders,
|
||||
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders,
|
||||
Length: uint32(len("abc")),
|
||||
},
|
||||
Priority: PriorityParam{},
|
||||
headerFragBuf: []byte("abc"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"with padding",
|
||||
HeadersFrameParam{
|
||||
StreamID: 42,
|
||||
BlockFragment: []byte("abc"),
|
||||
EndStream: true,
|
||||
EndHeaders: true,
|
||||
PadLength: 5,
|
||||
Priority: PriorityParam{},
|
||||
},
|
||||
"\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00",
|
||||
&HeadersFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
StreamID: 42,
|
||||
Type: FrameHeaders,
|
||||
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded,
|
||||
Length: uint32(1 + len("abc") + 5), // pad length + contents + padding
|
||||
},
|
||||
Priority: PriorityParam{},
|
||||
headerFragBuf: []byte("abc"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"with priority",
|
||||
HeadersFrameParam{
|
||||
StreamID: 42,
|
||||
BlockFragment: []byte("abc"),
|
||||
EndStream: true,
|
||||
EndHeaders: true,
|
||||
PadLength: 2,
|
||||
Priority: PriorityParam{
|
||||
StreamDep: 15,
|
||||
Exclusive: true,
|
||||
Weight: 127,
|
||||
},
|
||||
},
|
||||
"\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00",
|
||||
&HeadersFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
StreamID: 42,
|
||||
Type: FrameHeaders,
|
||||
Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority,
|
||||
Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding
|
||||
},
|
||||
Priority: PriorityParam{
|
||||
StreamDep: 15,
|
||||
Exclusive: true,
|
||||
Weight: 127,
|
||||
},
|
||||
headerFragBuf: []byte("abc"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
fr, buf := testFramer()
|
||||
if err := fr.WriteHeaders(tt.p); err != nil {
|
||||
t.Errorf("test %q: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
if buf.String() != tt.wantEnc {
|
||||
t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc)
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(f, tt.wantFrame) {
|
||||
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteContinuation(t *testing.T) {
|
||||
const streamID = 42
|
||||
tests := []struct {
|
||||
name string
|
||||
end bool
|
||||
frag []byte
|
||||
|
||||
wantFrame *ContinuationFrame
|
||||
}{
|
||||
{
|
||||
"not end",
|
||||
false,
|
||||
[]byte("abc"),
|
||||
&ContinuationFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
StreamID: streamID,
|
||||
Type: FrameContinuation,
|
||||
Length: uint32(len("abc")),
|
||||
},
|
||||
headerFragBuf: []byte("abc"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"end",
|
||||
true,
|
||||
[]byte("def"),
|
||||
&ContinuationFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
StreamID: streamID,
|
||||
Type: FrameContinuation,
|
||||
Flags: FlagContinuationEndHeaders,
|
||||
Length: uint32(len("def")),
|
||||
},
|
||||
headerFragBuf: []byte("def"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
fr, _ := testFramer()
|
||||
if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil {
|
||||
t.Errorf("test %q: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(f, tt.wantFrame) {
|
||||
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePriority(t *testing.T) {
|
||||
const streamID = 42
|
||||
tests := []struct {
|
||||
name string
|
||||
priority PriorityParam
|
||||
wantFrame *PriorityFrame
|
||||
}{
|
||||
{
|
||||
"not exclusive",
|
||||
PriorityParam{
|
||||
StreamDep: 2,
|
||||
Exclusive: false,
|
||||
Weight: 127,
|
||||
},
|
||||
&PriorityFrame{
|
||||
FrameHeader{
|
||||
valid: true,
|
||||
StreamID: streamID,
|
||||
Type: FramePriority,
|
||||
Length: 5,
|
||||
},
|
||||
PriorityParam{
|
||||
StreamDep: 2,
|
||||
Exclusive: false,
|
||||
Weight: 127,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
"exclusive",
|
||||
PriorityParam{
|
||||
StreamDep: 3,
|
||||
Exclusive: true,
|
||||
Weight: 77,
|
||||
},
|
||||
&PriorityFrame{
|
||||
FrameHeader{
|
||||
valid: true,
|
||||
StreamID: streamID,
|
||||
Type: FramePriority,
|
||||
Length: 5,
|
||||
},
|
||||
PriorityParam{
|
||||
StreamDep: 3,
|
||||
Exclusive: true,
|
||||
Weight: 77,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
fr, _ := testFramer()
|
||||
if err := fr.WritePriority(streamID, tt.priority); err != nil {
|
||||
t.Errorf("test %q: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Errorf("test %q: failed to read the frame back: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(f, tt.wantFrame) {
|
||||
t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSettings(t *testing.T) {
|
||||
fr, buf := testFramer()
|
||||
settings := []Setting{{1, 2}, {3, 4}}
|
||||
fr.WriteSettings(settings...)
|
||||
const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04"
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sf, ok := f.(*SettingsFrame)
|
||||
if !ok {
|
||||
t.Fatalf("Got a %T; want a SettingsFrame", f)
|
||||
}
|
||||
var got []Setting
|
||||
sf.ForeachSetting(func(s Setting) error {
|
||||
got = append(got, s)
|
||||
valBack, ok := sf.Value(s.ID)
|
||||
if !ok || valBack != s.Val {
|
||||
t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if !reflect.DeepEqual(settings, got) {
|
||||
t.Errorf("Read settings %+v != written settings %+v", got, settings)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSettingsAck(t *testing.T) {
|
||||
fr, buf := testFramer()
|
||||
fr.WriteSettingsAck()
|
||||
const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00"
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteWindowUpdate(t *testing.T) {
|
||||
fr, buf := testFramer()
|
||||
const streamID = 1<<24 + 2<<16 + 3<<8 + 4
|
||||
const incr = 7<<24 + 6<<16 + 5<<8 + 4
|
||||
if err := fr.WriteWindowUpdate(streamID, incr); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04"
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := &WindowUpdateFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
Type: 0x8,
|
||||
Flags: 0x0,
|
||||
Length: 0x4,
|
||||
StreamID: 0x1020304,
|
||||
},
|
||||
Increment: 0x7060504,
|
||||
}
|
||||
if !reflect.DeepEqual(f, want) {
|
||||
t.Errorf("parsed back %#v; want %#v", f, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePing(t *testing.T) { testWritePing(t, false) }
|
||||
func TestWritePingAck(t *testing.T) { testWritePing(t, true) }
|
||||
|
||||
func testWritePing(t *testing.T, ack bool) {
|
||||
fr, buf := testFramer()
|
||||
if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var wantFlags Flags
|
||||
if ack {
|
||||
wantFlags = FlagPingAck
|
||||
}
|
||||
var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08"
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := &PingFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
Type: 0x6,
|
||||
Flags: wantFlags,
|
||||
Length: 0x8,
|
||||
StreamID: 0,
|
||||
},
|
||||
Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
}
|
||||
if !reflect.DeepEqual(f, want) {
|
||||
t.Errorf("parsed back %#v; want %#v", f, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrameHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
in string
|
||||
want FrameHeader
|
||||
}{
|
||||
{in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}},
|
||||
{in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{
|
||||
Length: 66051, Type: 4, Flags: 5, StreamID: 101124105,
|
||||
}},
|
||||
// Ignore high bit:
|
||||
{in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{
|
||||
Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
|
||||
{in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{
|
||||
Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in))
|
||||
if err != nil {
|
||||
t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err)
|
||||
continue
|
||||
}
|
||||
tt.want.valid = true
|
||||
if got != tt.want {
|
||||
t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadWriteFrameHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
len uint32
|
||||
typ FrameType
|
||||
flags Flags
|
||||
streamID uint32
|
||||
}{
|
||||
{len: 0, typ: 255, flags: 1, streamID: 0},
|
||||
{len: 0, typ: 255, flags: 1, streamID: 1},
|
||||
{len: 0, typ: 255, flags: 1, streamID: 255},
|
||||
{len: 0, typ: 255, flags: 1, streamID: 256},
|
||||
{len: 0, typ: 255, flags: 1, streamID: 65535},
|
||||
{len: 0, typ: 255, flags: 1, streamID: 65536},
|
||||
|
||||
{len: 0, typ: 1, flags: 255, streamID: 1},
|
||||
{len: 255, typ: 1, flags: 255, streamID: 1},
|
||||
{len: 256, typ: 1, flags: 255, streamID: 1},
|
||||
{len: 65535, typ: 1, flags: 255, streamID: 1},
|
||||
{len: 65536, typ: 1, flags: 255, streamID: 1},
|
||||
{len: 16777215, typ: 1, flags: 255, streamID: 1},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
fr, buf := testFramer()
|
||||
fr.startWrite(tt.typ, tt.flags, tt.streamID)
|
||||
fr.writeBytes(make([]byte, tt.len))
|
||||
fr.endWrite()
|
||||
fh, err := ReadFrameHeader(buf)
|
||||
if err != nil {
|
||||
t.Errorf("ReadFrameHeader(%+v) = %v", tt, err)
|
||||
continue
|
||||
}
|
||||
if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID {
|
||||
t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestWriteTooLargeFrame(t *testing.T) {
|
||||
fr, _ := testFramer()
|
||||
fr.startWrite(0, 1, 1)
|
||||
fr.writeBytes(make([]byte, 1<<24))
|
||||
err := fr.endWrite()
|
||||
if err != ErrFrameTooLarge {
|
||||
t.Errorf("endWrite = %v; want errFrameTooLarge", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteGoAway(t *testing.T) {
|
||||
const debug = "foo"
|
||||
fr, buf := testFramer()
|
||||
if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := &GoAwayFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
Type: 0x7,
|
||||
Flags: 0,
|
||||
Length: uint32(4 + 4 + len(debug)),
|
||||
StreamID: 0,
|
||||
},
|
||||
LastStreamID: 0x01020304,
|
||||
ErrCode: 0x05060708,
|
||||
debugData: []byte(debug),
|
||||
}
|
||||
if !reflect.DeepEqual(f, want) {
|
||||
t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
|
||||
}
|
||||
if got := string(f.(*GoAwayFrame).DebugData()); got != debug {
|
||||
t.Errorf("debug data = %q; want %q", got, debug)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePushPromise(t *testing.T) {
|
||||
pp := PushPromiseParam{
|
||||
StreamID: 42,
|
||||
PromiseID: 42,
|
||||
BlockFragment: []byte("abc"),
|
||||
}
|
||||
fr, buf := testFramer()
|
||||
if err := fr.WritePushPromise(pp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc"
|
||||
if buf.String() != wantEnc {
|
||||
t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc)
|
||||
}
|
||||
f, err := fr.ReadFrame()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ok := f.(*PushPromiseFrame)
|
||||
if !ok {
|
||||
t.Fatalf("got %T; want *PushPromiseFrame", f)
|
||||
}
|
||||
want := &PushPromiseFrame{
|
||||
FrameHeader: FrameHeader{
|
||||
valid: true,
|
||||
Type: 0x5,
|
||||
Flags: 0x0,
|
||||
Length: 0x7,
|
||||
StreamID: 42,
|
||||
},
|
||||
PromiseID: 42,
|
||||
headerFragBuf: []byte("abc"),
|
||||
}
|
||||
if !reflect.DeepEqual(f, want) {
|
||||
t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want)
|
||||
}
|
||||
}
|
@ -0,0 +1,173 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
// Defensive debug-only utility to track that functions run on the
|
||||
// goroutine that they're supposed to.
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
|
||||
|
||||
type goroutineLock uint64
|
||||
|
||||
func newGoroutineLock() goroutineLock {
|
||||
if !DebugGoroutines {
|
||||
return 0
|
||||
}
|
||||
return goroutineLock(curGoroutineID())
|
||||
}
|
||||
|
||||
func (g goroutineLock) check() {
|
||||
if !DebugGoroutines {
|
||||
return
|
||||
}
|
||||
if curGoroutineID() != uint64(g) {
|
||||
panic("running on the wrong goroutine")
|
||||
}
|
||||
}
|
||||
|
||||
func (g goroutineLock) checkNotOn() {
|
||||
if !DebugGoroutines {
|
||||
return
|
||||
}
|
||||
if curGoroutineID() == uint64(g) {
|
||||
panic("running on the wrong goroutine")
|
||||
}
|
||||
}
|
||||
|
||||
var goroutineSpace = []byte("goroutine ")
|
||||
|
||||
func curGoroutineID() uint64 {
|
||||
bp := littleBuf.Get().(*[]byte)
|
||||
defer littleBuf.Put(bp)
|
||||
b := *bp
|
||||
b = b[:runtime.Stack(b, false)]
|
||||
// Parse the 4707 out of "goroutine 4707 ["
|
||||
b = bytes.TrimPrefix(b, goroutineSpace)
|
||||
i := bytes.IndexByte(b, ' ')
|
||||
if i < 0 {
|
||||
panic(fmt.Sprintf("No space found in %q", b))
|
||||
}
|
||||
b = b[:i]
|
||||
n, err := parseUintBytes(b, 10, 64)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
var littleBuf = sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 64)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// parseUintBytes is like strconv.ParseUint, but using a []byte.
|
||||
func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
|
||||
var cutoff, maxVal uint64
|
||||
|
||||
if bitSize == 0 {
|
||||
bitSize = int(strconv.IntSize)
|
||||
}
|
||||
|
||||
s0 := s
|
||||
switch {
|
||||
case len(s) < 1:
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
|
||||
case 2 <= base && base <= 36:
|
||||
// valid base; nothing to do
|
||||
|
||||
case base == 0:
|
||||
// Look for octal, hex prefix.
|
||||
switch {
|
||||
case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
|
||||
base = 16
|
||||
s = s[2:]
|
||||
if len(s) < 1 {
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
}
|
||||
case s[0] == '0':
|
||||
base = 8
|
||||
default:
|
||||
base = 10
|
||||
}
|
||||
|
||||
default:
|
||||
err = errors.New("invalid base " + strconv.Itoa(base))
|
||||
goto Error
|
||||
}
|
||||
|
||||
n = 0
|
||||
cutoff = cutoff64(base)
|
||||
maxVal = 1<<uint(bitSize) - 1
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
var v byte
|
||||
d := s[i]
|
||||
switch {
|
||||
case '0' <= d && d <= '9':
|
||||
v = d - '0'
|
||||
case 'a' <= d && d <= 'z':
|
||||
v = d - 'a' + 10
|
||||
case 'A' <= d && d <= 'Z':
|
||||
v = d - 'A' + 10
|
||||
default:
|
||||
n = 0
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
}
|
||||
if int(v) >= base {
|
||||
n = 0
|
||||
err = strconv.ErrSyntax
|
||||
goto Error
|
||||
}
|
||||
|
||||
if n >= cutoff {
|
||||
// n*base overflows
|
||||
n = 1<<64 - 1
|
||||
err = strconv.ErrRange
|
||||
goto Error
|
||||
}
|
||||
n *= uint64(base)
|
||||
|
||||
n1 := n + uint64(v)
|
||||
if n1 < n || n1 > maxVal {
|
||||
// n+v overflows
|
||||
n = 1<<64 - 1
|
||||
err = strconv.ErrRange
|
||||
goto Error
|
||||
}
|
||||
n = n1
|
||||
}
|
||||
|
||||
return n, nil
|
||||
|
||||
Error:
|
||||
return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
|
||||
}
|
||||
|
||||
// Return the first number n such that n*base >= 1<<64.
|
||||
func cutoff64(base int) uint64 {
|
||||
if base < 2 {
|
||||
return 0
|
||||
}
|
||||
return (1<<64-1)/uint64(base) + 1
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGoroutineLock(t *testing.T) {
|
||||
DebugGoroutines = true
|
||||
g := newGoroutineLock()
|
||||
g.check()
|
||||
|
||||
sawPanic := make(chan interface{})
|
||||
go func() {
|
||||
defer func() { sawPanic <- recover() }()
|
||||
g.check() // should panic
|
||||
}()
|
||||
e := <-sawPanic
|
||||
if e == nil {
|
||||
t.Fatal("did not see panic from check in other goroutine")
|
||||
}
|
||||
if !strings.Contains(fmt.Sprint(e), "wrong goroutine") {
|
||||
t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e)
|
||||
}
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
h2demo
|
||||
h2demo.linux
|
||||
client-id.dat
|
||||
client-secret.dat
|
||||
token.dat
|
@ -0,0 +1,5 @@
|
||||
h2demo.linux: h2demo.go
|
||||
GOOS=linux go build --tags=h2demo -o h2demo.linux .
|
||||
|
||||
upload: h2demo.linux
|
||||
cat h2demo.linux | go run launch.go --write_object=http2-demo-server-tls/h2demo --write_object_is_public
|
@ -0,0 +1,16 @@
|
||||
|
||||
Client:
|
||||
-- Firefox nightly with about:config network.http.spdy.enabled.http2draft set true
|
||||
-- Chrome: go to chrome://flags/#enable-spdy4, save and restart (button at bottom)
|
||||
|
||||
Make CA:
|
||||
$ openssl genrsa -out rootCA.key 2048
|
||||
$ openssl req -x509 -new -nodes -key rootCA.key -days 1024 -out rootCA.pem
|
||||
... install that to Firefox
|
||||
|
||||
Make cert:
|
||||
$ openssl genrsa -out server.key 2048
|
||||
$ openssl req -new -key server.key -out server.csr
|
||||
$ openssl x509 -req -in server.csr -CA rootCA.pem -CAkey rootCA.key -CAcreateserial -out server.crt -days 500
|
||||
|
||||
|
@ -0,0 +1,426 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
// +build h2demo
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"path"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"camlistore.org/pkg/googlestorage"
|
||||
"camlistore.org/pkg/singleflight"
|
||||
"github.com/bradfitz/http2"
|
||||
)
|
||||
|
||||
var (
|
||||
openFirefox = flag.Bool("openff", false, "Open Firefox")
|
||||
addr = flag.String("addr", "localhost:4430", "TLS address to listen on")
|
||||
httpAddr = flag.String("httpaddr", "", "If non-empty, address to listen for regular HTTP on")
|
||||
prod = flag.Bool("prod", false, "Whether to configure itself to be the production http2.golang.org server.")
|
||||
)
|
||||
|
||||
func homeOldHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, `<html>
|
||||
<body>
|
||||
<h1>Go + HTTP/2</h1>
|
||||
<p>Welcome to <a href="https://golang.org/">the Go language</a>'s <a href="https://http2.github.io/">HTTP/2</a> demo & interop server.</p>
|
||||
<p>Unfortunately, you're <b>not</b> using HTTP/2 right now. To do so:</p>
|
||||
<ul>
|
||||
<li>Use Firefox Nightly or go to <b>about:config</b> and enable "network.http.spdy.enabled.http2draft"</li>
|
||||
<li>Use Google Chrome Canary and/or go to <b>chrome://flags/#enable-spdy4</b> to <i>Enable SPDY/4</i> (Chrome's name for HTTP/2)</li>
|
||||
</ul>
|
||||
<p>See code & instructions for connecting at <a href="https://github.com/bradfitz/http2">https://github.com/bradfitz/http2</a>.</p>
|
||||
|
||||
</body></html>`)
|
||||
}
|
||||
|
||||
func home(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
io.WriteString(w, `<html>
|
||||
<body>
|
||||
<h1>Go + HTTP/2</h1>
|
||||
|
||||
<p>Welcome to <a href="https://golang.org/">the Go language</a>'s <a
|
||||
href="https://http2.github.io/">HTTP/2</a> demo & interop server.</p>
|
||||
|
||||
<p>Congratulations, <b>you're using HTTP/2 right now</b>.</p>
|
||||
|
||||
<p>This server exists for others in the HTTP/2 community to test their HTTP/2 client implementations and point out flaws in our server.</p>
|
||||
|
||||
<p> The code is currently at <a
|
||||
href="https://github.com/bradfitz/http2">github.com/bradfitz/http2</a>
|
||||
but will move to the Go standard library at some point in the future
|
||||
(enabled by default, without users needing to change their code).</p>
|
||||
|
||||
<p>Contact info: <i>bradfitz@golang.org</i>, or <a
|
||||
href="https://github.com/bradfitz/http2/issues">file a bug</a>.</p>
|
||||
|
||||
<h2>Handlers for testing</h2>
|
||||
<ul>
|
||||
<li>GET <a href="/reqinfo">/reqinfo</a> to dump the request + headers received</li>
|
||||
<li>GET <a href="/clockstream">/clockstream</a> streams the current time every second</li>
|
||||
<li>GET <a href="/gophertiles">/gophertiles</a> to see a page with a bunch of images</li>
|
||||
<li>GET <a href="/file/gopher.png">/file/gopher.png</a> for a small file (does If-Modified-Since, Content-Range, etc)</li>
|
||||
<li>GET <a href="/file/go.src.tar.gz">/file/go.src.tar.gz</a> for a larger file (~10 MB)</li>
|
||||
<li>GET <a href="/redirect">/redirect</a> to redirect back to / (this page)</li>
|
||||
<li>GET <a href="/goroutines">/goroutines</a> to see all active goroutines in this server</li>
|
||||
<li>PUT something to <a href="/crc32">/crc32</a> to get a count of number of bytes and its CRC-32</li>
|
||||
</ul>
|
||||
|
||||
</body></html>`)
|
||||
}
|
||||
|
||||
func reqInfoHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprintf(w, "Method: %s\n", r.Method)
|
||||
fmt.Fprintf(w, "Protocol: %s\n", r.Proto)
|
||||
fmt.Fprintf(w, "Host: %s\n", r.Host)
|
||||
fmt.Fprintf(w, "RemoteAddr: %s\n", r.RemoteAddr)
|
||||
fmt.Fprintf(w, "RequestURI: %q\n", r.RequestURI)
|
||||
fmt.Fprintf(w, "URL: %#v\n", r.URL)
|
||||
fmt.Fprintf(w, "Body.ContentLength: %d (-1 means unknown)\n", r.ContentLength)
|
||||
fmt.Fprintf(w, "Close: %v (relevant for HTTP/1 only)\n", r.Close)
|
||||
fmt.Fprintf(w, "TLS: %#v\n", r.TLS)
|
||||
fmt.Fprintf(w, "\nHeaders:\n")
|
||||
r.Header.Write(w)
|
||||
}
|
||||
|
||||
func crcHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PUT" {
|
||||
http.Error(w, "PUT required.", 400)
|
||||
return
|
||||
}
|
||||
crc := crc32.NewIEEE()
|
||||
n, err := io.Copy(crc, r.Body)
|
||||
if err == nil {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
fmt.Fprintf(w, "bytes=%d, CRC32=%x", n, crc.Sum(nil))
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
fsGrp singleflight.Group
|
||||
fsMu sync.Mutex // guards fsCache
|
||||
fsCache = map[string]http.Handler{}
|
||||
)
|
||||
|
||||
// fileServer returns a file-serving handler that proxies URL.
|
||||
// It lazily fetches URL on the first access and caches its contents forever.
|
||||
func fileServer(url string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
hi, err := fsGrp.Do(url, func() (interface{}, error) {
|
||||
fsMu.Lock()
|
||||
if h, ok := fsCache[url]; ok {
|
||||
fsMu.Unlock()
|
||||
return h, nil
|
||||
}
|
||||
fsMu.Unlock()
|
||||
|
||||
res, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
slurp, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modTime := time.Now()
|
||||
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.ServeContent(w, r, path.Base(url), modTime, bytes.NewReader(slurp))
|
||||
})
|
||||
fsMu.Lock()
|
||||
fsCache[url] = h
|
||||
fsMu.Unlock()
|
||||
return h, nil
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
hi.(http.Handler).ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func clockStreamHandler(w http.ResponseWriter, r *http.Request) {
|
||||
clientGone := w.(http.CloseNotifier).CloseNotify()
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
fmt.Fprintf(w, "# ~1KB of junk to force browsers to start rendering immediately: \n")
|
||||
io.WriteString(w, strings.Repeat("# xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n", 13))
|
||||
|
||||
for {
|
||||
fmt.Fprintf(w, "%v\n", time.Now())
|
||||
w.(http.Flusher).Flush()
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-clientGone:
|
||||
log.Printf("Client %v disconnected from the clock", r.RemoteAddr)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func registerHandlers() {
|
||||
tiles := newGopherTilesHandler()
|
||||
|
||||
mux2 := http.NewServeMux()
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS == nil {
|
||||
if r.URL.Path == "/gophertiles" {
|
||||
tiles.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, "https://http2.golang.org/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
if r.ProtoMajor == 1 {
|
||||
if r.URL.Path == "/reqinfo" {
|
||||
reqInfoHandler(w, r)
|
||||
return
|
||||
}
|
||||
homeOldHTTP(w, r)
|
||||
return
|
||||
}
|
||||
mux2.ServeHTTP(w, r)
|
||||
})
|
||||
mux2.HandleFunc("/", home)
|
||||
mux2.Handle("/file/gopher.png", fileServer("https://golang.org/doc/gopher/frontpage.png"))
|
||||
mux2.Handle("/file/go.src.tar.gz", fileServer("https://storage.googleapis.com/golang/go1.4.1.src.tar.gz"))
|
||||
mux2.HandleFunc("/reqinfo", reqInfoHandler)
|
||||
mux2.HandleFunc("/crc32", crcHandler)
|
||||
mux2.HandleFunc("/clockstream", clockStreamHandler)
|
||||
mux2.Handle("/gophertiles", tiles)
|
||||
mux2.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
})
|
||||
stripHomedir := regexp.MustCompile(`/(Users|home)/\w+`)
|
||||
mux2.HandleFunc("/goroutines", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
buf := make([]byte, 2<<20)
|
||||
w.Write(stripHomedir.ReplaceAll(buf[:runtime.Stack(buf, true)], nil))
|
||||
})
|
||||
}
|
||||
|
||||
func newGopherTilesHandler() http.Handler {
|
||||
const gopherURL = "https://blog.golang.org/go-programming-language-turns-two_gophers.jpg"
|
||||
res, err := http.Get(gopherURL)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
log.Fatalf("Error fetching %s: %v", gopherURL, res.Status)
|
||||
}
|
||||
slurp, err := ioutil.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
im, err := jpeg.Decode(bytes.NewReader(slurp))
|
||||
if err != nil {
|
||||
if len(slurp) > 1024 {
|
||||
slurp = slurp[:1024]
|
||||
}
|
||||
log.Fatalf("Failed to decode gopher image: %v (got %q)", err, slurp)
|
||||
}
|
||||
|
||||
type subImager interface {
|
||||
SubImage(image.Rectangle) image.Image
|
||||
}
|
||||
const tileSize = 32
|
||||
xt := im.Bounds().Max.X / tileSize
|
||||
yt := im.Bounds().Max.Y / tileSize
|
||||
var tile [][][]byte // y -> x -> jpeg bytes
|
||||
for yi := 0; yi < yt; yi++ {
|
||||
var row [][]byte
|
||||
for xi := 0; xi < xt; xi++ {
|
||||
si := im.(subImager).SubImage(image.Rectangle{
|
||||
Min: image.Point{xi * tileSize, yi * tileSize},
|
||||
Max: image.Point{(xi + 1) * tileSize, (yi + 1) * tileSize},
|
||||
})
|
||||
buf := new(bytes.Buffer)
|
||||
if err := jpeg.Encode(buf, si, &jpeg.Options{Quality: 90}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
row = append(row, buf.Bytes())
|
||||
}
|
||||
tile = append(tile, row)
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ms, _ := strconv.Atoi(r.FormValue("latency"))
|
||||
const nanosPerMilli = 1e6
|
||||
if r.FormValue("x") != "" {
|
||||
x, _ := strconv.Atoi(r.FormValue("x"))
|
||||
y, _ := strconv.Atoi(r.FormValue("y"))
|
||||
if ms <= 1000 {
|
||||
time.Sleep(time.Duration(ms) * nanosPerMilli)
|
||||
}
|
||||
if x >= 0 && x < xt && y >= 0 && y < yt {
|
||||
http.ServeContent(w, r, "", time.Time{}, bytes.NewReader(tile[y][x]))
|
||||
return
|
||||
}
|
||||
}
|
||||
io.WriteString(w, "<html><body>")
|
||||
fmt.Fprintf(w, "A grid of %d tiled images is below. Compare:<p>", xt*yt)
|
||||
for _, ms := range []int{0, 30, 200, 1000} {
|
||||
d := time.Duration(ms) * nanosPerMilli
|
||||
fmt.Fprintf(w, "[<a href='https://%s/gophertiles?latency=%d'>HTTP/2, %v latency</a>] [<a href='http://%s/gophertiles?latency=%d'>HTTP/1, %v latency</a>]<br>\n",
|
||||
httpsHost(), ms, d,
|
||||
httpHost(), ms, d,
|
||||
)
|
||||
}
|
||||
io.WriteString(w, "<p>\n")
|
||||
cacheBust := time.Now().UnixNano()
|
||||
for y := 0; y < yt; y++ {
|
||||
for x := 0; x < xt; x++ {
|
||||
fmt.Fprintf(w, "<img width=%d height=%d src='/gophertiles?x=%d&y=%d&cachebust=%d&latency=%d'>",
|
||||
tileSize, tileSize, x, y, cacheBust, ms)
|
||||
}
|
||||
io.WriteString(w, "<br/>\n")
|
||||
}
|
||||
io.WriteString(w, "<hr><a href='/'><< Back to Go HTTP/2 demo server</a></body></html>")
|
||||
})
|
||||
}
|
||||
|
||||
func httpsHost() string {
|
||||
if *prod {
|
||||
return "http2.golang.org"
|
||||
}
|
||||
if v := *addr; strings.HasPrefix(v, ":") {
|
||||
return "localhost" + v
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func httpHost() string {
|
||||
if *prod {
|
||||
return "http2.golang.org"
|
||||
}
|
||||
if v := *httpAddr; strings.HasPrefix(v, ":") {
|
||||
return "localhost" + v
|
||||
} else {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
func serveProdTLS() error {
|
||||
c, err := googlestorage.NewServiceClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slurp := func(key string) ([]byte, error) {
|
||||
const bucket = "http2-demo-server-tls"
|
||||
rc, _, err := c.GetObject(&googlestorage.Object{
|
||||
Bucket: bucket,
|
||||
Key: key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error fetching GCS object %q in bucket %q: %v", key, bucket, err)
|
||||
}
|
||||
defer rc.Close()
|
||||
return ioutil.ReadAll(rc)
|
||||
}
|
||||
certPem, err := slurp("http2.golang.org.chained.pem")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
keyPem, err := slurp("http2.golang.org.key")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cert, err := tls.X509KeyPair(certPem, keyPem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv := &http.Server{
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
},
|
||||
}
|
||||
http2.ConfigureServer(srv, &http2.Server{})
|
||||
ln, err := net.Listen("tcp", ":443")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.Serve(tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig))
|
||||
}
|
||||
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tc.SetKeepAlive(true)
|
||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func serveProd() error {
|
||||
errc := make(chan error, 2)
|
||||
go func() { errc <- http.ListenAndServe(":80", nil) }()
|
||||
go func() { errc <- serveProdTLS() }()
|
||||
return <-errc
|
||||
}
|
||||
|
||||
func main() {
|
||||
var srv http.Server
|
||||
flag.BoolVar(&http2.VerboseLogs, "verbose", false, "Verbose HTTP/2 debugging.")
|
||||
flag.Parse()
|
||||
srv.Addr = *addr
|
||||
|
||||
registerHandlers()
|
||||
|
||||
if *prod {
|
||||
*httpAddr = "http2.golang.org"
|
||||
log.Fatal(serveProd())
|
||||
}
|
||||
|
||||
url := "https://" + *addr + "/"
|
||||
log.Printf("Listening on " + url)
|
||||
http2.ConfigureServer(&srv, &http2.Server{})
|
||||
|
||||
if *httpAddr != "" {
|
||||
go func() { log.Fatal(http.ListenAndServe(*httpAddr, nil)) }()
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Fatal(srv.ListenAndServeTLS("server.crt", "server.key"))
|
||||
}()
|
||||
if *openFirefox && runtime.GOOS == "darwin" {
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
exec.Command("open", "-b", "org.mozilla.nightly", "https://localhost:4430/").Run()
|
||||
}
|
||||
select {}
|
||||
}
|
@ -0,0 +1,302 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
compute "google.golang.org/api/compute/v1"
|
||||
)
|
||||
|
||||
var (
|
||||
proj = flag.String("project", "symbolic-datum-552", "name of Project")
|
||||
zone = flag.String("zone", "us-central1-a", "GCE zone")
|
||||
mach = flag.String("machinetype", "n1-standard-1", "Machine type")
|
||||
instName = flag.String("instance_name", "http2-demo", "Name of VM instance.")
|
||||
sshPub = flag.String("ssh_public_key", "", "ssh public key file to authorize. Can modify later in Google's web UI anyway.")
|
||||
staticIP = flag.String("static_ip", "130.211.116.44", "Static IP to use. If empty, automatic.")
|
||||
|
||||
writeObject = flag.String("write_object", "", "If non-empty, a VM isn't created and the flag value is Google Cloud Storage bucket/object to write. The contents from stdin.")
|
||||
publicObject = flag.Bool("write_object_is_public", false, "Whether the object created by --write_object should be public.")
|
||||
)
|
||||
|
||||
func readFile(v string) string {
|
||||
slurp, err := ioutil.ReadFile(v)
|
||||
if err != nil {
|
||||
log.Fatalf("Error reading %s: %v", v, err)
|
||||
}
|
||||
return strings.TrimSpace(string(slurp))
|
||||
}
|
||||
|
||||
var config = &oauth2.Config{
|
||||
// The client-id and secret should be for an "Installed Application" when using
|
||||
// the CLI. Later we'll use a web application with a callback.
|
||||
ClientID: readFile("client-id.dat"),
|
||||
ClientSecret: readFile("client-secret.dat"),
|
||||
Endpoint: google.Endpoint,
|
||||
Scopes: []string{
|
||||
compute.DevstorageFull_controlScope,
|
||||
compute.ComputeScope,
|
||||
"https://www.googleapis.com/auth/sqlservice",
|
||||
"https://www.googleapis.com/auth/sqlservice.admin",
|
||||
},
|
||||
RedirectURL: "urn:ietf:wg:oauth:2.0:oob",
|
||||
}
|
||||
|
||||
const baseConfig = `#cloud-config
|
||||
coreos:
|
||||
units:
|
||||
- name: h2demo.service
|
||||
command: start
|
||||
content: |
|
||||
[Unit]
|
||||
Description=HTTP2 Demo
|
||||
|
||||
[Service]
|
||||
ExecStartPre=/bin/bash -c 'mkdir -p /opt/bin && curl -s -o /opt/bin/h2demo http://storage.googleapis.com/http2-demo-server-tls/h2demo && chmod +x /opt/bin/h2demo'
|
||||
ExecStart=/opt/bin/h2demo --prod
|
||||
RestartSec=5s
|
||||
Restart=always
|
||||
Type=simple
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
`
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
if *proj == "" {
|
||||
log.Fatalf("Missing --project flag")
|
||||
}
|
||||
prefix := "https://www.googleapis.com/compute/v1/projects/" + *proj
|
||||
machType := prefix + "/zones/" + *zone + "/machineTypes/" + *mach
|
||||
|
||||
const tokenFileName = "token.dat"
|
||||
tokenFile := tokenCacheFile(tokenFileName)
|
||||
tokenSource := oauth2.ReuseTokenSource(nil, tokenFile)
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
if *writeObject != "" {
|
||||
log.Fatalf("Can't use --write_object without a valid token.dat file already cached.")
|
||||
}
|
||||
log.Printf("Error getting token from %s: %v", tokenFileName, err)
|
||||
log.Printf("Get auth code from %v", config.AuthCodeURL("my-state"))
|
||||
fmt.Print("\nEnter auth code: ")
|
||||
sc := bufio.NewScanner(os.Stdin)
|
||||
sc.Scan()
|
||||
authCode := strings.TrimSpace(sc.Text())
|
||||
token, err = config.Exchange(oauth2.NoContext, authCode)
|
||||
if err != nil {
|
||||
log.Fatalf("Error exchanging auth code for a token: %v", err)
|
||||
}
|
||||
if err := tokenFile.WriteToken(token); err != nil {
|
||||
log.Fatalf("Error writing to %s: %v", tokenFileName, err)
|
||||
}
|
||||
tokenSource = oauth2.ReuseTokenSource(token, nil)
|
||||
}
|
||||
|
||||
oauthClient := oauth2.NewClient(oauth2.NoContext, tokenSource)
|
||||
|
||||
if *writeObject != "" {
|
||||
writeCloudStorageObject(oauthClient)
|
||||
return
|
||||
}
|
||||
|
||||
computeService, _ := compute.New(oauthClient)
|
||||
|
||||
natIP := *staticIP
|
||||
if natIP == "" {
|
||||
// Try to find it by name.
|
||||
aggAddrList, err := computeService.Addresses.AggregatedList(*proj).Do()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// http://godoc.org/code.google.com/p/google-api-go-client/compute/v1#AddressAggregatedList
|
||||
IPLoop:
|
||||
for _, asl := range aggAddrList.Items {
|
||||
for _, addr := range asl.Addresses {
|
||||
if addr.Name == *instName+"-ip" && addr.Status == "RESERVED" {
|
||||
natIP = addr.Address
|
||||
break IPLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cloudConfig := baseConfig
|
||||
if *sshPub != "" {
|
||||
key := strings.TrimSpace(readFile(*sshPub))
|
||||
cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", key)
|
||||
}
|
||||
if os.Getenv("USER") == "bradfitz" {
|
||||
cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEAwks9dwWKlRC+73gRbvYtVg0vdCwDSuIlyt4z6xa/YU/jTDynM4R4W10hm2tPjy8iR1k8XhDv4/qdxe6m07NjG/By1tkmGpm1mGwho4Pr5kbAAy/Qg+NLCSdAYnnE00FQEcFOC15GFVMOW2AzDGKisReohwH9eIzHPzdYQNPRWXE= bradfitz@papag.bradfitz.com")
|
||||
}
|
||||
const maxCloudConfig = 32 << 10 // per compute API docs
|
||||
if len(cloudConfig) > maxCloudConfig {
|
||||
log.Fatalf("cloud config length of %d bytes is over %d byte limit", len(cloudConfig), maxCloudConfig)
|
||||
}
|
||||
|
||||
instance := &compute.Instance{
|
||||
Name: *instName,
|
||||
Description: "Go Builder",
|
||||
MachineType: machType,
|
||||
Disks: []*compute.AttachedDisk{instanceDisk(computeService)},
|
||||
Tags: &compute.Tags{
|
||||
Items: []string{"http-server", "https-server"},
|
||||
},
|
||||
Metadata: &compute.Metadata{
|
||||
Items: []*compute.MetadataItems{
|
||||
{
|
||||
Key: "user-data",
|
||||
Value: cloudConfig,
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkInterfaces: []*compute.NetworkInterface{
|
||||
&compute.NetworkInterface{
|
||||
AccessConfigs: []*compute.AccessConfig{
|
||||
&compute.AccessConfig{
|
||||
Type: "ONE_TO_ONE_NAT",
|
||||
Name: "External NAT",
|
||||
NatIP: natIP,
|
||||
},
|
||||
},
|
||||
Network: prefix + "/global/networks/default",
|
||||
},
|
||||
},
|
||||
ServiceAccounts: []*compute.ServiceAccount{
|
||||
{
|
||||
Email: "default",
|
||||
Scopes: []string{
|
||||
compute.DevstorageFull_controlScope,
|
||||
compute.ComputeScope,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
log.Printf("Creating instance...")
|
||||
op, err := computeService.Instances.Insert(*proj, *zone, instance).Do()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create instance: %v", err)
|
||||
}
|
||||
opName := op.Name
|
||||
log.Printf("Created. Waiting on operation %v", opName)
|
||||
OpLoop:
|
||||
for {
|
||||
time.Sleep(2 * time.Second)
|
||||
op, err := computeService.ZoneOperations.Get(*proj, *zone, opName).Do()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get op %s: %v", opName, err)
|
||||
}
|
||||
switch op.Status {
|
||||
case "PENDING", "RUNNING":
|
||||
log.Printf("Waiting on operation %v", opName)
|
||||
continue
|
||||
case "DONE":
|
||||
if op.Error != nil {
|
||||
for _, operr := range op.Error.Errors {
|
||||
log.Printf("Error: %+v", operr)
|
||||
}
|
||||
log.Fatalf("Failed to start.")
|
||||
}
|
||||
log.Printf("Success. %+v", op)
|
||||
break OpLoop
|
||||
default:
|
||||
log.Fatalf("Unknown status %q: %+v", op.Status, op)
|
||||
}
|
||||
}
|
||||
|
||||
inst, err := computeService.Instances.Get(*proj, *zone, *instName).Do()
|
||||
if err != nil {
|
||||
log.Fatalf("Error getting instance after creation: %v", err)
|
||||
}
|
||||
ij, _ := json.MarshalIndent(inst, "", " ")
|
||||
log.Printf("Instance: %s", ij)
|
||||
}
|
||||
|
||||
func instanceDisk(svc *compute.Service) *compute.AttachedDisk {
|
||||
const imageURL = "https://www.googleapis.com/compute/v1/projects/coreos-cloud/global/images/coreos-stable-444-5-0-v20141016"
|
||||
diskName := *instName + "-disk"
|
||||
|
||||
return &compute.AttachedDisk{
|
||||
AutoDelete: true,
|
||||
Boot: true,
|
||||
Type: "PERSISTENT",
|
||||
InitializeParams: &compute.AttachedDiskInitializeParams{
|
||||
DiskName: diskName,
|
||||
SourceImage: imageURL,
|
||||
DiskSizeGb: 50,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func writeCloudStorageObject(httpClient *http.Client) {
|
||||
content := os.Stdin
|
||||
const maxSlurp = 1 << 20
|
||||
var buf bytes.Buffer
|
||||
n, err := io.CopyN(&buf, content, maxSlurp)
|
||||
if err != nil && err != io.EOF {
|
||||
log.Fatalf("Error reading from stdin: %v, %v", n, err)
|
||||
}
|
||||
contentType := http.DetectContentType(buf.Bytes())
|
||||
|
||||
req, err := http.NewRequest("PUT", "https://storage.googleapis.com/"+*writeObject, io.MultiReader(&buf, content))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
req.Header.Set("x-goog-api-version", "2")
|
||||
if *publicObject {
|
||||
req.Header.Set("x-goog-acl", "public-read")
|
||||
}
|
||||
req.Header.Set("Content-Type", contentType)
|
||||
res, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if res.StatusCode != 200 {
|
||||
res.Write(os.Stderr)
|
||||
log.Fatalf("Failed.")
|
||||
}
|
||||
log.Printf("Success.")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
type tokenCacheFile string
|
||||
|
||||
func (f tokenCacheFile) Token() (*oauth2.Token, error) {
|
||||
slurp, err := ioutil.ReadFile(string(f))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := new(oauth2.Token)
|
||||
if err := json.Unmarshal(slurp, t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (f tokenCacheFile) WriteToken(t *oauth2.Token) error {
|
||||
jt, err := json.Marshal(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil.WriteFile(string(f), jt, 0600)
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEAt5fAjp4fTcekWUTfzsp0kyih1OYbsGL0KX1eRbSSR8Od0+9Q
|
||||
62Hyny+GFwMTb4A/KU8mssoHvcceSAAbwfbxFK/+s51TobqUnORZrOoTZjkUygby
|
||||
XDSK99YBbcR1Pip8vwMTm4XKuLtCigeBBdjjAQdgUO28LENGlsMnmeYkJfODVGnV
|
||||
mr5Ltb9ANA8IKyTfsnHJ4iOCS/PlPbUj2q7YnoVLposUBMlgUb/CykX3mOoLb4yJ
|
||||
JQyA/iST6ZxiIEj36D4yWZ5lg7YJl+UiiBQHGCnPdGyipqV06ex0heYWcaiW8LWZ
|
||||
SUQ93jQ+WVCH8hT7DQO1dmsvUmXlq/JeAlwQ/QIDAQABAoIBAFFHV7JMAqPWnMYA
|
||||
nezY6J81v9+XN+7xABNWM2Q8uv4WdksbigGLTXR3/680Z2hXqJ7LMeC5XJACFT/e
|
||||
/Gr0vmpgOCygnCPfjGehGKpavtfksXV3edikUlnCXsOP1C//c1bFL+sMYmFCVgTx
|
||||
qYdDK8yKzXNGrKYT6q5YG7IglyRNV1rsQa8lM/5taFYiD1Ck/3tQi3YIq8Lcuser
|
||||
hrxsMABcQ6mi+EIvG6Xr4mfJug0dGJMHG4RG1UGFQn6RXrQq2+q53fC8ZbVUSi0j
|
||||
NQ918aKFzktwv+DouKU0ME4I9toks03gM860bAL7zCbKGmwR3hfgX/TqzVCWpG9E
|
||||
LDVfvekCgYEA8fk9N53jbBRmULUGEf4qWypcLGiZnNU0OeXWpbPV9aa3H0VDytA7
|
||||
8fCN2dPAVDPqlthMDdVe983NCNwp2Yo8ZimDgowyIAKhdC25s1kejuaiH9OAPj3c
|
||||
0f8KbriYX4n8zNHxFwK6Ae3pQ6EqOLJVCUsziUaZX9nyKY5aZlyX6xcCgYEAwjws
|
||||
K62PjC64U5wYddNLp+kNdJ4edx+a7qBb3mEgPvSFT2RO3/xafJyG8kQB30Mfstjd
|
||||
bRxyUV6N0vtX1zA7VQtRUAvfGCecpMo+VQZzcHXKzoRTnQ7eZg4Lmj5fQ9tOAKAo
|
||||
QCVBoSW/DI4PZL26CAMDcAba4Pa22ooLapoRIQsCgYA6pIfkkbxLNkpxpt2YwLtt
|
||||
Kr/590O7UaR9n6k8sW/aQBRDXNsILR1KDl2ifAIxpf9lnXgZJiwE7HiTfCAcW7c1
|
||||
nzwDCI0hWuHcMTS/NYsFYPnLsstyyjVZI3FY0h4DkYKV9Q9z3zJLQ2hz/nwoD3gy
|
||||
b2pHC7giFcTts1VPV4Nt8wKBgHeFn4ihHJweg76vZz3Z78w7VNRWGFklUalVdDK7
|
||||
gaQ7w2y/ROn/146mo0OhJaXFIFRlrpvdzVrU3GDf2YXJYDlM5ZRkObwbZADjksev
|
||||
WInzcgDy3KDg7WnPasRXbTfMU4t/AkW2p1QKbi3DnSVYuokDkbH2Beo45vxDxhKr
|
||||
C69RAoGBAIyo3+OJenoZmoNzNJl2WPW5MeBUzSh8T/bgyjFTdqFHF5WiYRD/lfHj
|
||||
x9Glyw2nutuT4hlOqHvKhgTYdDMsF2oQ72fe3v8Q5FU7FuKndNPEAyvKNXZaShVA
|
||||
hnlhv5DjXKb0wFWnt5PCCiQLtzG0yyHaITrrEme7FikkIcTxaX/Y
|
||||
-----END RSA PRIVATE KEY-----
|
@ -0,0 +1,26 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIEWjCCA0KgAwIBAgIJALfRlWsI8YQHMA0GCSqGSIb3DQEBBQUAMHsxCzAJBgNV
|
||||
BAYTAlVTMQswCQYDVQQIEwJDQTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEUMBIG
|
||||
A1UEChMLQnJhZGZpdHppbmMxEjAQBgNVBAMTCWxvY2FsaG9zdDEdMBsGCSqGSIb3
|
||||
DQEJARYOYnJhZEBkYW5nYS5jb20wHhcNMTQwNzE1MjA0NjA1WhcNMTcwNTA0MjA0
|
||||
NjA1WjB7MQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0ExFjAUBgNVBAcTDVNhbiBG
|
||||
cmFuY2lzY28xFDASBgNVBAoTC0JyYWRmaXR6aW5jMRIwEAYDVQQDEwlsb2NhbGhv
|
||||
c3QxHTAbBgkqhkiG9w0BCQEWDmJyYWRAZGFuZ2EuY29tMIIBIjANBgkqhkiG9w0B
|
||||
AQEFAAOCAQ8AMIIBCgKCAQEAt5fAjp4fTcekWUTfzsp0kyih1OYbsGL0KX1eRbSS
|
||||
R8Od0+9Q62Hyny+GFwMTb4A/KU8mssoHvcceSAAbwfbxFK/+s51TobqUnORZrOoT
|
||||
ZjkUygbyXDSK99YBbcR1Pip8vwMTm4XKuLtCigeBBdjjAQdgUO28LENGlsMnmeYk
|
||||
JfODVGnVmr5Ltb9ANA8IKyTfsnHJ4iOCS/PlPbUj2q7YnoVLposUBMlgUb/CykX3
|
||||
mOoLb4yJJQyA/iST6ZxiIEj36D4yWZ5lg7YJl+UiiBQHGCnPdGyipqV06ex0heYW
|
||||
caiW8LWZSUQ93jQ+WVCH8hT7DQO1dmsvUmXlq/JeAlwQ/QIDAQABo4HgMIHdMB0G
|
||||
A1UdDgQWBBRcAROthS4P4U7vTfjByC569R7E6DCBrQYDVR0jBIGlMIGigBRcAROt
|
||||
hS4P4U7vTfjByC569R7E6KF/pH0wezELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNB
|
||||
MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRQwEgYDVQQKEwtCcmFkZml0emluYzES
|
||||
MBAGA1UEAxMJbG9jYWxob3N0MR0wGwYJKoZIhvcNAQkBFg5icmFkQGRhbmdhLmNv
|
||||
bYIJALfRlWsI8YQHMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAG6h
|
||||
U9f9sNH0/6oBbGGy2EVU0UgITUQIrFWo9rFkrW5k/XkDjQm+3lzjT0iGR4IxE/Ao
|
||||
eU6sQhua7wrWeFEn47GL98lnCsJdD7oZNhFmQ95Tb/LnDUjs5Yj9brP0NWzXfYU4
|
||||
UK2ZnINJRcJpB8iRCaCxE8DdcUF0XqIEq6pA272snoLmiXLMvNl3kYEdm+je6voD
|
||||
58SNVEUsztzQyXmJEhCpwVI0A6QCjzXj+qvpmw3ZZHi8JwXei8ZZBLTSFBki8Z7n
|
||||
sH9BBH38/SzUmAN4QHSPy1gjqm00OAE8NaYDkh/bzE4d7mLGGMWp/WE3KPSu82HF
|
||||
kPe6XoSbiLm/kxk32T0=
|
||||
-----END CERTIFICATE-----
|
@ -0,0 +1 @@
|
||||
E2CE26BF3285059C
|
@ -0,0 +1,20 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDPjCCAiYCCQDizia/MoUFnDANBgkqhkiG9w0BAQUFADB7MQswCQYDVQQGEwJV
|
||||
UzELMAkGA1UECBMCQ0ExFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xFDASBgNVBAoT
|
||||
C0JyYWRmaXR6aW5jMRIwEAYDVQQDEwlsb2NhbGhvc3QxHTAbBgkqhkiG9w0BCQEW
|
||||
DmJyYWRAZGFuZ2EuY29tMB4XDTE0MDcxNTIwNTAyN1oXDTE1MTEyNzIwNTAyN1ow
|
||||
RzELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMQswCQYDVQQHEwJTRjEeMBwGA1UE
|
||||
ChMVYnJhZGZpdHogaHR0cDIgc2VydmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A
|
||||
MIIBCgKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDifx2l
|
||||
gZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1LmJ4c2
|
||||
dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nefb3HL
|
||||
A7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55mjws
|
||||
/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/fz88
|
||||
F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABMA0GCSqGSIb3DQEBBQUAA4IB
|
||||
AQC0zL+n/YpRZOdulSu9tS8FxrstXqGWoxfe+vIUgqfMZ5+0MkjJ/vW0FqlLDl2R
|
||||
rn4XaR3e7FmWkwdDVbq/UB6lPmoAaFkCgh9/5oapMaclNVNnfF3fjCJfRr+qj/iD
|
||||
EmJStTIN0ZuUjAlpiACmfnpEU55PafT5Zx+i1yE4FGjw8bJpFoyD4Hnm54nGjX19
|
||||
KeCuvcYFUPnBm3lcL0FalF2AjqV02WTHYNQk7YF/oeO7NKBoEgvGvKG3x+xaOeBI
|
||||
dwvdq175ZsGul30h+QjrRlXhH/twcuaT3GSdoysDl9cCYE8f1Mk8PD6gan3uBCJU
|
||||
90p6/CbU71bGbfpM2PHot2fm
|
||||
-----END CERTIFICATE-----
|
@ -0,0 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEowIBAAKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDi
|
||||
fx2lgZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1Lm
|
||||
J4c2dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nef
|
||||
b3HLA7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55
|
||||
mjws/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/
|
||||
fz88F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABAoIBADQ2spUwbY+bcz4p
|
||||
3M66ECrNQTBggP40gYl2XyHxGGOu2xhZ94f9ELf1hjRWU2DUKWco1rJcdZClV6q3
|
||||
qwmXvcM2Q/SMS8JW0ImkNVl/0/NqPxGatEnj8zY30d/L8hGFb0orzFu/XYA5gCP4
|
||||
NbN2WrXgk3ZLeqwcNxHHtSiJWGJ/fPyeDWAu/apy75u9Xf2GlzBZmV6HYD9EfK80
|
||||
LTlI60f5FO487CrJnboL7ovPJrIHn+k05xRQqwma4orpz932rTXnTjs9Lg6KtbQN
|
||||
a7PrqfAntIISgr11a66Mng3IYH1lYqJsWJJwX/xHT4WLEy0EH4/0+PfYemJekz2+
|
||||
Co62drECgYEA6O9zVJZXrLSDsIi54cfxA7nEZWm5CAtkYWeAHa4EJ+IlZ7gIf9sL
|
||||
W8oFcEfFGpvwVqWZ+AsQ70dsjXAv3zXaG0tmg9FtqWp7pzRSMPidifZcQwWkKeTO
|
||||
gJnFmnVyed8h6GfjTEu4gxo1/S5U0V+mYSha01z5NTnN6ltKx1Or3b0CgYEAxRgm
|
||||
S30nZxnyg/V7ys61AZhst1DG2tkZXEMcA7dYhabMoXPJAP/EfhlWwpWYYUs/u0gS
|
||||
Wwmf5IivX5TlYScgmkvb/NYz0u4ZmOXkLTnLPtdKKFXhjXJcHjUP67jYmOxNlJLp
|
||||
V4vLRnFxTpffAV+OszzRxsXX6fvruwZBANYJeXUCgYBVouLFsFgfWGYp2rpr9XP4
|
||||
KK25kvrBqF6JKOIDB1zjxNJ3pUMKrl8oqccCFoCyXa4oTM2kUX0yWxHfleUjrMq4
|
||||
yimwQKiOZmV7fVLSSjSw6e/VfBd0h3gb82ygcplZkN0IclkwTY5SNKqwn/3y07V5
|
||||
drqdhkrgdJXtmQ6O5YYECQKBgATERcDToQ1USlI4sKrB/wyv1AlG8dg/IebiVJ4e
|
||||
ZAyvcQmClFzq0qS+FiQUnB/WQw9TeeYrwGs1hxBHuJh16srwhLyDrbMvQP06qh8R
|
||||
48F8UXXSRec22dV9MQphaROhu2qZdv1AC0WD3tqov6L33aqmEOi+xi8JgbT/PLk5
|
||||
c/c1AoGBAI1A/02ryksW6/wc7/6SP2M2rTy4m1sD/GnrTc67EHnRcVBdKO6qH2RY
|
||||
nqC8YcveC2ZghgPTDsA3VGuzuBXpwY6wTyV99q6jxQJ6/xcrD9/NUG6Uwv/xfCxl
|
||||
IJLeBYEqQundSSny3VtaAUK8Ul1nxpTvVRNwtcyWTo8RHAAyNPWd
|
||||
-----END RSA PRIVATE KEY-----
|
@ -0,0 +1,97 @@
|
||||
# h2i
|
||||
|
||||
**h2i** is an interactive HTTP/2 ("h2") console debugger. Miss the good ol'
|
||||
days of telnetting to your HTTP/1.n servers? We're bringing you
|
||||
back.
|
||||
|
||||
Features:
|
||||
- send raw HTTP/2 frames
|
||||
- PING
|
||||
- SETTINGS
|
||||
- HEADERS
|
||||
- etc
|
||||
- type in HTTP/1.n and have it auto-HPACK/frame-ify it for HTTP/2
|
||||
- pretty print all received HTTP/2 frames from the peer (including HPACK decoding)
|
||||
- tab completion of commands, options
|
||||
|
||||
Not yet features, but soon:
|
||||
- unnecessary CONTINUATION frames on short boundaries, to test peer implementations
|
||||
- request bodies (DATA frames)
|
||||
- send invalid frames for testing server implementations (supported by underlying Framer)
|
||||
|
||||
Later:
|
||||
- act like a server
|
||||
|
||||
## Installation
|
||||
|
||||
```
|
||||
$ go get github.com/bradfitz/http2/h2i
|
||||
$ h2i <host>
|
||||
```
|
||||
|
||||
## Demo
|
||||
|
||||
```
|
||||
$ h2i
|
||||
Usage: h2i <hostname>
|
||||
|
||||
-insecure
|
||||
Whether to skip TLS cert validation
|
||||
-nextproto string
|
||||
Comma-separated list of NPN/ALPN protocol names to negotiate. (default "h2,h2-14")
|
||||
|
||||
$ h2i google.com
|
||||
Connecting to google.com:443 ...
|
||||
Connected to 74.125.224.41:443
|
||||
Negotiated protocol "h2-14"
|
||||
[FrameHeader SETTINGS len=18]
|
||||
[MAX_CONCURRENT_STREAMS = 100]
|
||||
[INITIAL_WINDOW_SIZE = 1048576]
|
||||
[MAX_FRAME_SIZE = 16384]
|
||||
[FrameHeader WINDOW_UPDATE len=4]
|
||||
Window-Increment = 983041
|
||||
|
||||
h2i> PING h2iSayHI
|
||||
[FrameHeader PING flags=ACK len=8]
|
||||
Data = "h2iSayHI"
|
||||
h2i> headers
|
||||
(as HTTP/1.1)> GET / HTTP/1.1
|
||||
(as HTTP/1.1)> Host: ip.appspot.com
|
||||
(as HTTP/1.1)> User-Agent: h2i/brad-n-blake
|
||||
(as HTTP/1.1)>
|
||||
Opening Stream-ID 1:
|
||||
:authority = ip.appspot.com
|
||||
:method = GET
|
||||
:path = /
|
||||
:scheme = https
|
||||
user-agent = h2i/brad-n-blake
|
||||
[FrameHeader HEADERS flags=END_HEADERS stream=1 len=77]
|
||||
:status = "200"
|
||||
alternate-protocol = "443:quic,p=1"
|
||||
content-length = "15"
|
||||
content-type = "text/html"
|
||||
date = "Fri, 01 May 2015 23:06:56 GMT"
|
||||
server = "Google Frontend"
|
||||
[FrameHeader DATA flags=END_STREAM stream=1 len=15]
|
||||
"173.164.155.78\n"
|
||||
[FrameHeader PING len=8]
|
||||
Data = "\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
h2i> ping
|
||||
[FrameHeader PING flags=ACK len=8]
|
||||
Data = "h2i_ping"
|
||||
h2i> ping
|
||||
[FrameHeader PING flags=ACK len=8]
|
||||
Data = "h2i_ping"
|
||||
h2i> ping
|
||||
[FrameHeader GOAWAY len=22]
|
||||
Last-Stream-ID = 1; Error-Code = PROTOCOL_ERROR (1)
|
||||
|
||||
ReadFrame: EOF
|
||||
```
|
||||
|
||||
## Status
|
||||
|
||||
Quick few hour hack. So much yet to do. Feel free to file issues for
|
||||
bugs or wishlist items, but [@bmizerany](https://github.com/bmizerany/)
|
||||
and I aren't yet accepting pull requests until things settle down.
|
||||
|
@ -0,0 +1,489 @@
|
||||
// Copyright 2015 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
/*
|
||||
The h2i command is an interactive HTTP/2 console.
|
||||
|
||||
Usage:
|
||||
$ h2i [flags] <hostname>
|
||||
|
||||
Interactive commands in the console: (all parts case-insensitive)
|
||||
|
||||
ping [data]
|
||||
settings ack
|
||||
settings FOO=n BAR=z
|
||||
headers (open a new stream by typing HTTP/1.1)
|
||||
*/
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bradfitz/http2"
|
||||
"github.com/bradfitz/http2/hpack"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
// Flags
|
||||
var (
|
||||
flagNextProto = flag.String("nextproto", "h2,h2-14", "Comma-separated list of NPN/ALPN protocol names to negotiate.")
|
||||
flagInsecure = flag.Bool("insecure", false, "Whether to skip TLS cert validation")
|
||||
)
|
||||
|
||||
type command struct {
|
||||
run func(*h2i, []string) error // required
|
||||
|
||||
// complete optionally specifies tokens (case-insensitive) which are
|
||||
// valid for this subcommand.
|
||||
complete func() []string
|
||||
}
|
||||
|
||||
var commands = map[string]command{
|
||||
"ping": command{run: (*h2i).cmdPing},
|
||||
"settings": command{
|
||||
run: (*h2i).cmdSettings,
|
||||
complete: func() []string {
|
||||
return []string{
|
||||
"ACK",
|
||||
http2.SettingHeaderTableSize.String(),
|
||||
http2.SettingEnablePush.String(),
|
||||
http2.SettingMaxConcurrentStreams.String(),
|
||||
http2.SettingInitialWindowSize.String(),
|
||||
http2.SettingMaxFrameSize.String(),
|
||||
http2.SettingMaxHeaderListSize.String(),
|
||||
}
|
||||
},
|
||||
},
|
||||
"quit": command{run: (*h2i).cmdQuit},
|
||||
"headers": command{run: (*h2i).cmdHeaders},
|
||||
}
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: h2i <hostname>\n\n")
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// withPort adds ":443" if another port isn't already present.
|
||||
func withPort(host string) string {
|
||||
if _, _, err := net.SplitHostPort(host); err != nil {
|
||||
return net.JoinHostPort(host, "443")
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// h2i is the app's state.
|
||||
type h2i struct {
|
||||
host string
|
||||
tc *tls.Conn
|
||||
framer *http2.Framer
|
||||
term *terminal.Terminal
|
||||
|
||||
// owned by the command loop:
|
||||
streamID uint32
|
||||
hbuf bytes.Buffer
|
||||
henc *hpack.Encoder
|
||||
|
||||
// owned by the readFrames loop:
|
||||
peerSetting map[http2.SettingID]uint32
|
||||
hdec *hpack.Decoder
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
if flag.NArg() != 1 {
|
||||
usage()
|
||||
}
|
||||
log.SetFlags(0)
|
||||
|
||||
host := flag.Arg(0)
|
||||
app := &h2i{
|
||||
host: host,
|
||||
peerSetting: make(map[http2.SettingID]uint32),
|
||||
}
|
||||
app.henc = hpack.NewEncoder(&app.hbuf)
|
||||
|
||||
if err := app.Main(); err != nil {
|
||||
if app.term != nil {
|
||||
app.logf("%v\n", err)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Fprintf(os.Stdout, "\n")
|
||||
}
|
||||
|
||||
func (app *h2i) Main() error {
|
||||
cfg := &tls.Config{
|
||||
ServerName: app.host,
|
||||
NextProtos: strings.Split(*flagNextProto, ","),
|
||||
InsecureSkipVerify: *flagInsecure,
|
||||
}
|
||||
|
||||
hostAndPort := withPort(app.host)
|
||||
log.Printf("Connecting to %s ...", hostAndPort)
|
||||
tc, err := tls.Dial("tcp", hostAndPort, cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error dialing %s: %v", withPort(app.host), err)
|
||||
}
|
||||
log.Printf("Connected to %v", tc.RemoteAddr())
|
||||
defer tc.Close()
|
||||
|
||||
if err := tc.Handshake(); err != nil {
|
||||
return fmt.Errorf("TLS handshake: %v", err)
|
||||
}
|
||||
if !*flagInsecure {
|
||||
if err := tc.VerifyHostname(app.host); err != nil {
|
||||
return fmt.Errorf("VerifyHostname: %v", err)
|
||||
}
|
||||
}
|
||||
state := tc.ConnectionState()
|
||||
log.Printf("Negotiated protocol %q", state.NegotiatedProtocol)
|
||||
if !state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol == "" {
|
||||
return fmt.Errorf("Could not negotiate protocol mutually")
|
||||
}
|
||||
|
||||
if _, err := io.WriteString(tc, http2.ClientPreface); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
app.framer = http2.NewFramer(tc, tc)
|
||||
|
||||
oldState, err := terminal.MakeRaw(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer terminal.Restore(0, oldState)
|
||||
|
||||
var screen = struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{os.Stdin, os.Stdout}
|
||||
|
||||
app.term = terminal.NewTerminal(screen, "h2i> ")
|
||||
lastWord := regexp.MustCompile(`.+\W(\w+)$`)
|
||||
app.term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) {
|
||||
if key != '\t' {
|
||||
return
|
||||
}
|
||||
if pos != len(line) {
|
||||
// TODO: we're being lazy for now, only supporting tab completion at the end.
|
||||
return
|
||||
}
|
||||
// Auto-complete for the command itself.
|
||||
if !strings.Contains(line, " ") {
|
||||
var name string
|
||||
name, _, ok = lookupCommand(line)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
return name, len(name), true
|
||||
}
|
||||
_, c, ok := lookupCommand(line[:strings.IndexByte(line, ' ')])
|
||||
if !ok || c.complete == nil {
|
||||
return
|
||||
}
|
||||
if strings.HasSuffix(line, " ") {
|
||||
app.logf("%s", strings.Join(c.complete(), " "))
|
||||
return line, pos, true
|
||||
}
|
||||
m := lastWord.FindStringSubmatch(line)
|
||||
if m == nil {
|
||||
return line, len(line), true
|
||||
}
|
||||
soFar := m[1]
|
||||
var match []string
|
||||
for _, cand := range c.complete() {
|
||||
if len(soFar) > len(cand) || !strings.EqualFold(cand[:len(soFar)], soFar) {
|
||||
continue
|
||||
}
|
||||
match = append(match, cand)
|
||||
}
|
||||
if len(match) == 0 {
|
||||
return
|
||||
}
|
||||
if len(match) > 1 {
|
||||
// TODO: auto-complete any common prefix
|
||||
app.logf("%s", strings.Join(match, " "))
|
||||
return line, pos, true
|
||||
}
|
||||
newLine = line[:len(line)-len(soFar)] + match[0]
|
||||
return newLine, len(newLine), true
|
||||
|
||||
}
|
||||
|
||||
errc := make(chan error, 2)
|
||||
go func() { errc <- app.readFrames() }()
|
||||
go func() { errc <- app.readConsole() }()
|
||||
return <-errc
|
||||
}
|
||||
|
||||
func (app *h2i) logf(format string, args ...interface{}) {
|
||||
fmt.Fprintf(app.term, format+"\n", args...)
|
||||
}
|
||||
|
||||
func (app *h2i) readConsole() error {
|
||||
for {
|
||||
line, err := app.term.ReadLine()
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("terminal.ReadLine: %v", err)
|
||||
}
|
||||
f := strings.Fields(line)
|
||||
if len(f) == 0 {
|
||||
continue
|
||||
}
|
||||
cmd, args := f[0], f[1:]
|
||||
if _, c, ok := lookupCommand(cmd); ok {
|
||||
err = c.run(app, args)
|
||||
} else {
|
||||
app.logf("Unknown command %q", line)
|
||||
}
|
||||
if err == errExitApp {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func lookupCommand(prefix string) (name string, c command, ok bool) {
|
||||
prefix = strings.ToLower(prefix)
|
||||
if c, ok = commands[prefix]; ok {
|
||||
return prefix, c, ok
|
||||
}
|
||||
|
||||
for full, candidate := range commands {
|
||||
if strings.HasPrefix(full, prefix) {
|
||||
if c.run != nil {
|
||||
return "", command{}, false // ambiguous
|
||||
}
|
||||
c = candidate
|
||||
name = full
|
||||
}
|
||||
}
|
||||
return name, c, c.run != nil
|
||||
}
|
||||
|
||||
var errExitApp = errors.New("internal sentinel error value to quit the console reading loop")
|
||||
|
||||
func (a *h2i) cmdQuit(args []string) error {
|
||||
if len(args) > 0 {
|
||||
a.logf("the QUIT command takes no argument")
|
||||
return nil
|
||||
}
|
||||
return errExitApp
|
||||
}
|
||||
|
||||
func (a *h2i) cmdSettings(args []string) error {
|
||||
if len(args) == 1 && strings.EqualFold(args[0], "ACK") {
|
||||
return a.framer.WriteSettingsAck()
|
||||
}
|
||||
var settings []http2.Setting
|
||||
for _, arg := range args {
|
||||
if strings.EqualFold(arg, "ACK") {
|
||||
a.logf("Error: ACK must be only argument with the SETTINGS command")
|
||||
return nil
|
||||
}
|
||||
eq := strings.Index(arg, "=")
|
||||
if eq == -1 {
|
||||
a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg)
|
||||
return nil
|
||||
}
|
||||
sid, ok := settingByName(arg[:eq])
|
||||
if !ok {
|
||||
a.logf("Error: unknown setting name %q", arg[:eq])
|
||||
return nil
|
||||
}
|
||||
val, err := strconv.ParseUint(arg[eq+1:], 10, 32)
|
||||
if err != nil {
|
||||
a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg)
|
||||
return nil
|
||||
}
|
||||
settings = append(settings, http2.Setting{
|
||||
ID: sid,
|
||||
Val: uint32(val),
|
||||
})
|
||||
}
|
||||
a.logf("Sending: %v", settings)
|
||||
return a.framer.WriteSettings(settings...)
|
||||
}
|
||||
|
||||
func settingByName(name string) (http2.SettingID, bool) {
|
||||
for _, sid := range [...]http2.SettingID{
|
||||
http2.SettingHeaderTableSize,
|
||||
http2.SettingEnablePush,
|
||||
http2.SettingMaxConcurrentStreams,
|
||||
http2.SettingInitialWindowSize,
|
||||
http2.SettingMaxFrameSize,
|
||||
http2.SettingMaxHeaderListSize,
|
||||
} {
|
||||
if strings.EqualFold(sid.String(), name) {
|
||||
return sid, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (app *h2i) cmdPing(args []string) error {
|
||||
if len(args) > 1 {
|
||||
app.logf("invalid PING usage: only accepts 0 or 1 args")
|
||||
return nil // nil means don't end the program
|
||||
}
|
||||
var data [8]byte
|
||||
if len(args) == 1 {
|
||||
copy(data[:], args[0])
|
||||
} else {
|
||||
copy(data[:], "h2i_ping")
|
||||
}
|
||||
return app.framer.WritePing(false, data)
|
||||
}
|
||||
|
||||
func (app *h2i) cmdHeaders(args []string) error {
|
||||
if len(args) > 0 {
|
||||
app.logf("Error: HEADERS doesn't yet take arguments.")
|
||||
// TODO: flags for restricting window size, to force CONTINUATION
|
||||
// frames.
|
||||
return nil
|
||||
}
|
||||
var h1req bytes.Buffer
|
||||
app.term.SetPrompt("(as HTTP/1.1)> ")
|
||||
defer app.term.SetPrompt("h2i> ")
|
||||
for {
|
||||
line, err := app.term.ReadLine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h1req.WriteString(line)
|
||||
h1req.WriteString("\r\n")
|
||||
if line == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
req, err := http.ReadRequest(bufio.NewReader(&h1req))
|
||||
if err != nil {
|
||||
app.logf("Invalid HTTP/1.1 request: %v", err)
|
||||
return nil
|
||||
}
|
||||
if app.streamID == 0 {
|
||||
app.streamID = 1
|
||||
} else {
|
||||
app.streamID += 2
|
||||
}
|
||||
app.logf("Opening Stream-ID %d:", app.streamID)
|
||||
hbf := app.encodeHeaders(req)
|
||||
if len(hbf) > 16<<10 {
|
||||
app.logf("TODO: h2i doesn't yet write CONTINUATION frames. Copy it from transport.go")
|
||||
return nil
|
||||
}
|
||||
return app.framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: app.streamID,
|
||||
BlockFragment: hbf,
|
||||
EndStream: req.Method == "GET" || req.Method == "HEAD", // good enough for now
|
||||
EndHeaders: true, // for now
|
||||
})
|
||||
}
|
||||
|
||||
func (app *h2i) readFrames() error {
|
||||
for {
|
||||
f, err := app.framer.ReadFrame()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ReadFrame: %v", err)
|
||||
}
|
||||
app.logf("%v", f)
|
||||
switch f := f.(type) {
|
||||
case *http2.PingFrame:
|
||||
app.logf(" Data = %q", f.Data)
|
||||
case *http2.SettingsFrame:
|
||||
f.ForeachSetting(func(s http2.Setting) error {
|
||||
app.logf(" %v", s)
|
||||
app.peerSetting[s.ID] = s.Val
|
||||
return nil
|
||||
})
|
||||
case *http2.WindowUpdateFrame:
|
||||
app.logf(" Window-Increment = %v\n", f.Increment)
|
||||
case *http2.GoAwayFrame:
|
||||
app.logf(" Last-Stream-ID = %d; Error-Code = %v (%d)\n", f.LastStreamID, f.ErrCode, f.ErrCode)
|
||||
case *http2.DataFrame:
|
||||
app.logf(" %q", f.Data())
|
||||
case *http2.HeadersFrame:
|
||||
if f.HasPriority() {
|
||||
app.logf(" PRIORITY = %v", f.Priority)
|
||||
}
|
||||
if app.hdec == nil {
|
||||
// TODO: if the user uses h2i to send a SETTINGS frame advertising
|
||||
// something larger, we'll need to respect SETTINGS_HEADER_TABLE_SIZE
|
||||
// and stuff here instead of using the 4k default. But for now:
|
||||
tableSize := uint32(4 << 10)
|
||||
app.hdec = hpack.NewDecoder(tableSize, app.onNewHeaderField)
|
||||
}
|
||||
app.hdec.Write(f.HeaderBlockFragment())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// called from readLoop
|
||||
func (app *h2i) onNewHeaderField(f hpack.HeaderField) {
|
||||
if f.Sensitive {
|
||||
app.logf(" %s = %q (SENSITIVE)", f.Name, f.Value)
|
||||
}
|
||||
app.logf(" %s = %q", f.Name, f.Value)
|
||||
}
|
||||
|
||||
func (app *h2i) encodeHeaders(req *http.Request) []byte {
|
||||
app.hbuf.Reset()
|
||||
|
||||
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
path := req.URL.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
app.writeHeader(":authority", host) // probably not right for all sites
|
||||
app.writeHeader(":method", req.Method)
|
||||
app.writeHeader(":path", path)
|
||||
app.writeHeader(":scheme", "https")
|
||||
|
||||
for k, vv := range req.Header {
|
||||
lowKey := strings.ToLower(k)
|
||||
if lowKey == "host" {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
app.writeHeader(lowKey, v)
|
||||
}
|
||||
}
|
||||
return app.hbuf.Bytes()
|
||||
}
|
||||
|
||||
func (app *h2i) writeHeader(name, value string) {
|
||||
app.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
app.logf(" %s = %s", name, value)
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case
|
||||
commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case
|
||||
)
|
||||
|
||||
func init() {
|
||||
for _, v := range []string{
|
||||
"accept",
|
||||
"accept-charset",
|
||||
"accept-encoding",
|
||||
"accept-language",
|
||||
"accept-ranges",
|
||||
"age",
|
||||
"access-control-allow-origin",
|
||||
"allow",
|
||||
"authorization",
|
||||
"cache-control",
|
||||
"content-disposition",
|
||||
"content-encoding",
|
||||
"content-language",
|
||||
"content-length",
|
||||
"content-location",
|
||||
"content-range",
|
||||
"content-type",
|
||||
"cookie",
|
||||
"date",
|
||||
"etag",
|
||||
"expect",
|
||||
"expires",
|
||||
"from",
|
||||
"host",
|
||||
"if-match",
|
||||
"if-modified-since",
|
||||
"if-none-match",
|
||||
"if-unmodified-since",
|
||||
"last-modified",
|
||||
"link",
|
||||
"location",
|
||||
"max-forwards",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"range",
|
||||
"referer",
|
||||
"refresh",
|
||||
"retry-after",
|
||||
"server",
|
||||
"set-cookie",
|
||||
"strict-transport-security",
|
||||
"transfer-encoding",
|
||||
"user-agent",
|
||||
"vary",
|
||||
"via",
|
||||
"www-authenticate",
|
||||
} {
|
||||
chk := http.CanonicalHeaderKey(v)
|
||||
commonLowerHeader[chk] = v
|
||||
commonCanonHeader[v] = chk
|
||||
}
|
||||
}
|
||||
|
||||
func lowerHeader(v string) string {
|
||||
if s, ok := commonLowerHeader[v]; ok {
|
||||
return s
|
||||
}
|
||||
return strings.ToLower(v)
|
||||
}
|
@ -0,0 +1,252 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
uint32Max = ^uint32(0)
|
||||
initialHeaderTableSize = 4096
|
||||
)
|
||||
|
||||
type Encoder struct {
|
||||
dynTab dynamicTable
|
||||
// minSize is the minimum table size set by
|
||||
// SetMaxDynamicTableSize after the previous Header Table Size
|
||||
// Update.
|
||||
minSize uint32
|
||||
// maxSizeLimit is the maximum table size this encoder
|
||||
// supports. This will protect the encoder from too large
|
||||
// size.
|
||||
maxSizeLimit uint32
|
||||
// tableSizeUpdate indicates whether "Header Table Size
|
||||
// Update" is required.
|
||||
tableSizeUpdate bool
|
||||
w io.Writer
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// NewEncoder returns a new Encoder which performs HPACK encoding. An
|
||||
// encoded data is written to w.
|
||||
func NewEncoder(w io.Writer) *Encoder {
|
||||
e := &Encoder{
|
||||
minSize: uint32Max,
|
||||
maxSizeLimit: initialHeaderTableSize,
|
||||
tableSizeUpdate: false,
|
||||
w: w,
|
||||
}
|
||||
e.dynTab.setMaxSize(initialHeaderTableSize)
|
||||
return e
|
||||
}
|
||||
|
||||
// WriteField encodes f into a single Write to e's underlying Writer.
|
||||
// This function may also produce bytes for "Header Table Size Update"
|
||||
// if necessary. If produced, it is done before encoding f.
|
||||
func (e *Encoder) WriteField(f HeaderField) error {
|
||||
e.buf = e.buf[:0]
|
||||
|
||||
if e.tableSizeUpdate {
|
||||
e.tableSizeUpdate = false
|
||||
if e.minSize < e.dynTab.maxSize {
|
||||
e.buf = appendTableSize(e.buf, e.minSize)
|
||||
}
|
||||
e.minSize = uint32Max
|
||||
e.buf = appendTableSize(e.buf, e.dynTab.maxSize)
|
||||
}
|
||||
|
||||
idx, nameValueMatch := e.searchTable(f)
|
||||
if nameValueMatch {
|
||||
e.buf = appendIndexed(e.buf, idx)
|
||||
} else {
|
||||
indexing := e.shouldIndex(f)
|
||||
if indexing {
|
||||
e.dynTab.add(f)
|
||||
}
|
||||
|
||||
if idx == 0 {
|
||||
e.buf = appendNewName(e.buf, f, indexing)
|
||||
} else {
|
||||
e.buf = appendIndexedName(e.buf, f, idx, indexing)
|
||||
}
|
||||
}
|
||||
n, err := e.w.Write(e.buf)
|
||||
if err == nil && n != len(e.buf) {
|
||||
err = io.ErrShortWrite
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// searchTable searches f in both stable and dynamic header tables.
|
||||
// The static header table is searched first. Only when there is no
|
||||
// exact match for both name and value, the dynamic header table is
|
||||
// then searched. If there is no match, i is 0. If both name and value
|
||||
// match, i is the matched index and nameValueMatch becomes true. If
|
||||
// only name matches, i points to that index and nameValueMatch
|
||||
// becomes false.
|
||||
func (e *Encoder) searchTable(f HeaderField) (i uint64, nameValueMatch bool) {
|
||||
for idx, hf := range staticTable {
|
||||
if !constantTimeStringCompare(hf.Name, f.Name) {
|
||||
continue
|
||||
}
|
||||
if i == 0 {
|
||||
i = uint64(idx + 1)
|
||||
}
|
||||
if f.Sensitive {
|
||||
continue
|
||||
}
|
||||
if !constantTimeStringCompare(hf.Value, f.Value) {
|
||||
continue
|
||||
}
|
||||
i = uint64(idx + 1)
|
||||
nameValueMatch = true
|
||||
return
|
||||
}
|
||||
|
||||
j, nameValueMatch := e.dynTab.search(f)
|
||||
if nameValueMatch || (i == 0 && j != 0) {
|
||||
i = j + uint64(len(staticTable))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetMaxDynamicTableSize changes the dynamic header table size to v.
|
||||
// The actual size is bounded by the value passed to
|
||||
// SetMaxDynamicTableSizeLimit.
|
||||
func (e *Encoder) SetMaxDynamicTableSize(v uint32) {
|
||||
if v > e.maxSizeLimit {
|
||||
v = e.maxSizeLimit
|
||||
}
|
||||
if v < e.minSize {
|
||||
e.minSize = v
|
||||
}
|
||||
e.tableSizeUpdate = true
|
||||
e.dynTab.setMaxSize(v)
|
||||
}
|
||||
|
||||
// SetMaxDynamicTableSizeLimit changes the maximum value that can be
|
||||
// specified in SetMaxDynamicTableSize to v. By default, it is set to
|
||||
// 4096, which is the same size of the default dynamic header table
|
||||
// size described in HPACK specification. If the current maximum
|
||||
// dynamic header table size is strictly greater than v, "Header Table
|
||||
// Size Update" will be done in the next WriteField call and the
|
||||
// maximum dynamic header table size is truncated to v.
|
||||
func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) {
|
||||
e.maxSizeLimit = v
|
||||
if e.dynTab.maxSize > v {
|
||||
e.tableSizeUpdate = true
|
||||
e.dynTab.setMaxSize(v)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldIndex reports whether f should be indexed.
|
||||
func (e *Encoder) shouldIndex(f HeaderField) bool {
|
||||
return !f.Sensitive && f.size() <= e.dynTab.maxSize
|
||||
}
|
||||
|
||||
// appendIndexed appends index i, as encoded in "Indexed Header Field"
|
||||
// representation, to dst and returns the extended buffer.
|
||||
func appendIndexed(dst []byte, i uint64) []byte {
|
||||
first := len(dst)
|
||||
dst = appendVarInt(dst, 7, i)
|
||||
dst[first] |= 0x80
|
||||
return dst
|
||||
}
|
||||
|
||||
// appendNewName appends f, as encoded in one of "Literal Header field
|
||||
// - New Name" representation variants, to dst and returns the
|
||||
// extended buffer.
|
||||
//
|
||||
// If f.Sensitive is true, "Never Indexed" representation is used. If
|
||||
// f.Sensitive is false and indexing is true, "Inremental Indexing"
|
||||
// representation is used.
|
||||
func appendNewName(dst []byte, f HeaderField, indexing bool) []byte {
|
||||
dst = append(dst, encodeTypeByte(indexing, f.Sensitive))
|
||||
dst = appendHpackString(dst, f.Name)
|
||||
return appendHpackString(dst, f.Value)
|
||||
}
|
||||
|
||||
// appendIndexedName appends f and index i referring indexed name
|
||||
// entry, as encoded in one of "Literal Header field - Indexed Name"
|
||||
// representation variants, to dst and returns the extended buffer.
|
||||
//
|
||||
// If f.Sensitive is true, "Never Indexed" representation is used. If
|
||||
// f.Sensitive is false and indexing is true, "Incremental Indexing"
|
||||
// representation is used.
|
||||
func appendIndexedName(dst []byte, f HeaderField, i uint64, indexing bool) []byte {
|
||||
first := len(dst)
|
||||
var n byte
|
||||
if indexing {
|
||||
n = 6
|
||||
} else {
|
||||
n = 4
|
||||
}
|
||||
dst = appendVarInt(dst, n, i)
|
||||
dst[first] |= encodeTypeByte(indexing, f.Sensitive)
|
||||
return appendHpackString(dst, f.Value)
|
||||
}
|
||||
|
||||
// appendTableSize appends v, as encoded in "Header Table Size Update"
|
||||
// representation, to dst and returns the extended buffer.
|
||||
func appendTableSize(dst []byte, v uint32) []byte {
|
||||
first := len(dst)
|
||||
dst = appendVarInt(dst, 5, uint64(v))
|
||||
dst[first] |= 0x20
|
||||
return dst
|
||||
}
|
||||
|
||||
// appendVarInt appends i, as encoded in variable integer form using n
|
||||
// bit prefix, to dst and returns the extended buffer.
|
||||
//
|
||||
// See
|
||||
// http://http2.github.io/http2-spec/compression.html#integer.representation
|
||||
func appendVarInt(dst []byte, n byte, i uint64) []byte {
|
||||
k := uint64((1 << n) - 1)
|
||||
if i < k {
|
||||
return append(dst, byte(i))
|
||||
}
|
||||
dst = append(dst, byte(k))
|
||||
i -= k
|
||||
for ; i >= 128; i >>= 7 {
|
||||
dst = append(dst, byte(0x80|(i&0x7f)))
|
||||
}
|
||||
return append(dst, byte(i))
|
||||
}
|
||||
|
||||
// appendHpackString appends s, as encoded in "String Literal"
|
||||
// representation, to dst and returns the the extended buffer.
|
||||
//
|
||||
// s will be encoded in Huffman codes only when it produces strictly
|
||||
// shorter byte string.
|
||||
func appendHpackString(dst []byte, s string) []byte {
|
||||
huffmanLength := HuffmanEncodeLength(s)
|
||||
if huffmanLength < uint64(len(s)) {
|
||||
first := len(dst)
|
||||
dst = appendVarInt(dst, 7, huffmanLength)
|
||||
dst = AppendHuffmanString(dst, s)
|
||||
dst[first] |= 0x80
|
||||
} else {
|
||||
dst = appendVarInt(dst, 7, uint64(len(s)))
|
||||
dst = append(dst, s...)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// encodeTypeByte returns type byte. If sensitive is true, type byte
|
||||
// for "Never Indexed" representation is returned. If sensitive is
|
||||
// false and indexing is true, type byte for "Incremental Indexing"
|
||||
// representation is returned. Otherwise, type byte for "Without
|
||||
// Indexing" is returned.
|
||||
func encodeTypeByte(indexing, sensitive bool) byte {
|
||||
if sensitive {
|
||||
return 0x10
|
||||
}
|
||||
if indexing {
|
||||
return 0x40
|
||||
}
|
||||
return 0
|
||||
}
|
@ -0,0 +1,331 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncoderTableSizeUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
size1, size2 uint32
|
||||
wantHex string
|
||||
}{
|
||||
// Should emit 2 table size updates (2048 and 4096)
|
||||
{2048, 4096, "3fe10f 3fe11f 82"},
|
||||
|
||||
// Should emit 1 table size update (2048)
|
||||
{16384, 2048, "3fe10f 82"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
var buf bytes.Buffer
|
||||
e := NewEncoder(&buf)
|
||||
e.SetMaxDynamicTableSize(tt.size1)
|
||||
e.SetMaxDynamicTableSize(tt.size2)
|
||||
if err := e.WriteField(pair(":method", "GET")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := removeSpace(tt.wantHex)
|
||||
if got := hex.EncodeToString(buf.Bytes()); got != want {
|
||||
t.Errorf("e.SetDynamicTableSize %v, %v = %q; want %q", tt.size1, tt.size2, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderWriteField(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
e := NewEncoder(&buf)
|
||||
var got []HeaderField
|
||||
d := NewDecoder(4<<10, func(f HeaderField) {
|
||||
got = append(got, f)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
hdrs []HeaderField
|
||||
}{
|
||||
{[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":path", "/"),
|
||||
pair(":authority", "www.example.com"),
|
||||
}},
|
||||
{[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":path", "/"),
|
||||
pair(":authority", "www.example.com"),
|
||||
pair("cache-control", "no-cache"),
|
||||
}},
|
||||
{[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "https"),
|
||||
pair(":path", "/index.html"),
|
||||
pair(":authority", "www.example.com"),
|
||||
pair("custom-key", "custom-value"),
|
||||
}},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
buf.Reset()
|
||||
got = got[:0]
|
||||
for _, hf := range tt.hdrs {
|
||||
if err := e.WriteField(hf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
_, err := d.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
t.Errorf("%d. Decoder Write = %v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.hdrs) {
|
||||
t.Errorf("%d. Decoded %+v; want %+v", i, got, tt.hdrs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderSearchTable(t *testing.T) {
|
||||
e := NewEncoder(nil)
|
||||
|
||||
e.dynTab.add(pair("foo", "bar"))
|
||||
e.dynTab.add(pair("blake", "miz"))
|
||||
e.dynTab.add(pair(":method", "GET"))
|
||||
|
||||
tests := []struct {
|
||||
hf HeaderField
|
||||
wantI uint64
|
||||
wantMatch bool
|
||||
}{
|
||||
// Name and Value match
|
||||
{pair("foo", "bar"), uint64(len(staticTable) + 3), true},
|
||||
{pair("blake", "miz"), uint64(len(staticTable) + 2), true},
|
||||
{pair(":method", "GET"), 2, true},
|
||||
|
||||
// Only name match because Sensitive == true
|
||||
{HeaderField{":method", "GET", true}, 2, false},
|
||||
|
||||
// Only Name matches
|
||||
{pair("foo", "..."), uint64(len(staticTable) + 3), false},
|
||||
{pair("blake", "..."), uint64(len(staticTable) + 2), false},
|
||||
{pair(":method", "..."), 2, false},
|
||||
|
||||
// None match
|
||||
{pair("foo-", "bar"), 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if gotI, gotMatch := e.searchTable(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch {
|
||||
t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendVarInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
n byte
|
||||
i uint64
|
||||
want []byte
|
||||
}{
|
||||
// Fits in a byte:
|
||||
{1, 0, []byte{0}},
|
||||
{2, 2, []byte{2}},
|
||||
{3, 6, []byte{6}},
|
||||
{4, 14, []byte{14}},
|
||||
{5, 30, []byte{30}},
|
||||
{6, 62, []byte{62}},
|
||||
{7, 126, []byte{126}},
|
||||
{8, 254, []byte{254}},
|
||||
|
||||
// Multiple bytes:
|
||||
{5, 1337, []byte{31, 154, 10}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := appendVarInt(nil, tt.n, tt.i)
|
||||
if !bytes.Equal(got, tt.want) {
|
||||
t.Errorf("appendVarInt(nil, %v, %v) = %v; want %v", tt.n, tt.i, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendHpackString(t *testing.T) {
|
||||
tests := []struct {
|
||||
s, wantHex string
|
||||
}{
|
||||
// Huffman encoded
|
||||
{"www.example.com", "8c f1e3 c2e5 f23a 6ba0 ab90 f4ff"},
|
||||
|
||||
// Not Huffman encoded
|
||||
{"a", "01 61"},
|
||||
|
||||
// zero length
|
||||
{"", "00"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
want := removeSpace(tt.wantHex)
|
||||
buf := appendHpackString(nil, tt.s)
|
||||
if got := hex.EncodeToString(buf); want != got {
|
||||
t.Errorf("appendHpackString(nil, %q) = %q; want %q", tt.s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendIndexed(t *testing.T) {
|
||||
tests := []struct {
|
||||
i uint64
|
||||
wantHex string
|
||||
}{
|
||||
// 1 byte
|
||||
{1, "81"},
|
||||
{126, "fe"},
|
||||
|
||||
// 2 bytes
|
||||
{127, "ff00"},
|
||||
{128, "ff01"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
want := removeSpace(tt.wantHex)
|
||||
buf := appendIndexed(nil, tt.i)
|
||||
if got := hex.EncodeToString(buf); want != got {
|
||||
t.Errorf("appendIndex(nil, %v) = %q; want %q", tt.i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendNewName(t *testing.T) {
|
||||
tests := []struct {
|
||||
f HeaderField
|
||||
indexing bool
|
||||
wantHex string
|
||||
}{
|
||||
// Incremental indexing
|
||||
{HeaderField{"custom-key", "custom-value", false}, true, "40 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
|
||||
|
||||
// Without indexing
|
||||
{HeaderField{"custom-key", "custom-value", false}, false, "00 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
|
||||
|
||||
// Never indexed
|
||||
{HeaderField{"custom-key", "custom-value", true}, true, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
|
||||
{HeaderField{"custom-key", "custom-value", true}, false, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
want := removeSpace(tt.wantHex)
|
||||
buf := appendNewName(nil, tt.f, tt.indexing)
|
||||
if got := hex.EncodeToString(buf); want != got {
|
||||
t.Errorf("appendNewName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendIndexedName(t *testing.T) {
|
||||
tests := []struct {
|
||||
f HeaderField
|
||||
i uint64
|
||||
indexing bool
|
||||
wantHex string
|
||||
}{
|
||||
// Incremental indexing
|
||||
{HeaderField{":status", "302", false}, 8, true, "48 82 6402"},
|
||||
|
||||
// Without indexing
|
||||
{HeaderField{":status", "302", false}, 8, false, "08 82 6402"},
|
||||
|
||||
// Never indexed
|
||||
{HeaderField{":status", "302", true}, 8, true, "18 82 6402"},
|
||||
{HeaderField{":status", "302", true}, 8, false, "18 82 6402"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
want := removeSpace(tt.wantHex)
|
||||
buf := appendIndexedName(nil, tt.f, tt.i, tt.indexing)
|
||||
if got := hex.EncodeToString(buf); want != got {
|
||||
t.Errorf("appendIndexedName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendTableSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
i uint32
|
||||
wantHex string
|
||||
}{
|
||||
// Fits into 1 byte
|
||||
{30, "3e"},
|
||||
|
||||
// Extra byte
|
||||
{31, "3f00"},
|
||||
{32, "3f01"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
want := removeSpace(tt.wantHex)
|
||||
buf := appendTableSize(nil, tt.i)
|
||||
if got := hex.EncodeToString(buf); want != got {
|
||||
t.Errorf("appendTableSize(nil, %v) = %q; want %q", tt.i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderSetMaxDynamicTableSize(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
e := NewEncoder(&buf)
|
||||
tests := []struct {
|
||||
v uint32
|
||||
wantUpdate bool
|
||||
wantMinSize uint32
|
||||
wantMaxSize uint32
|
||||
}{
|
||||
// Set new table size to 2048
|
||||
{2048, true, 2048, 2048},
|
||||
|
||||
// Set new table size to 16384, but still limited to
|
||||
// 4096
|
||||
{16384, true, 2048, 4096},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
e.SetMaxDynamicTableSize(tt.v)
|
||||
if got := e.tableSizeUpdate; tt.wantUpdate != got {
|
||||
t.Errorf("e.tableSizeUpdate = %v; want %v", got, tt.wantUpdate)
|
||||
}
|
||||
if got := e.minSize; tt.wantMinSize != got {
|
||||
t.Errorf("e.minSize = %v; want %v", got, tt.wantMinSize)
|
||||
}
|
||||
if got := e.dynTab.maxSize; tt.wantMaxSize != got {
|
||||
t.Errorf("e.maxSize = %v; want %v", got, tt.wantMaxSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderSetMaxDynamicTableSizeLimit(t *testing.T) {
|
||||
e := NewEncoder(nil)
|
||||
// 4095 < initialHeaderTableSize means maxSize is truncated to
|
||||
// 4095.
|
||||
e.SetMaxDynamicTableSizeLimit(4095)
|
||||
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
|
||||
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
|
||||
}
|
||||
if got, want := e.maxSizeLimit, uint32(4095); got != want {
|
||||
t.Errorf("e.maxSizeLimit = %v; want %v", got, want)
|
||||
}
|
||||
if got, want := e.tableSizeUpdate, true; got != want {
|
||||
t.Errorf("e.tableSizeUpdate = %v; want %v", got, want)
|
||||
}
|
||||
// maxSize will be truncated to maxSizeLimit
|
||||
e.SetMaxDynamicTableSize(16384)
|
||||
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
|
||||
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
|
||||
}
|
||||
// 8192 > current maxSizeLimit, so maxSize does not change.
|
||||
e.SetMaxDynamicTableSizeLimit(8192)
|
||||
if got, want := e.dynTab.maxSize, uint32(4095); got != want {
|
||||
t.Errorf("e.dynTab.maxSize = %v; want %v", got, want)
|
||||
}
|
||||
if got, want := e.maxSizeLimit, uint32(8192); got != want {
|
||||
t.Errorf("e.maxSizeLimit = %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func removeSpace(s string) string {
|
||||
return strings.Replace(s, " ", "", -1)
|
||||
}
|
@ -0,0 +1,445 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
// Package hpack implements HPACK, a compression format for
|
||||
// efficiently representing HTTP header fields in the context of HTTP/2.
|
||||
//
|
||||
// See http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-09
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// A DecodingError is something the spec defines as a decoding error.
|
||||
type DecodingError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (de DecodingError) Error() string {
|
||||
return fmt.Sprintf("decoding error: %v", de.Err)
|
||||
}
|
||||
|
||||
// An InvalidIndexError is returned when an encoder references a table
|
||||
// entry before the static table or after the end of the dynamic table.
|
||||
type InvalidIndexError int
|
||||
|
||||
func (e InvalidIndexError) Error() string {
|
||||
return fmt.Sprintf("invalid indexed representation index %d", int(e))
|
||||
}
|
||||
|
||||
// A HeaderField is a name-value pair. Both the name and value are
|
||||
// treated as opaque sequences of octets.
|
||||
type HeaderField struct {
|
||||
Name, Value string
|
||||
|
||||
// Sensitive means that this header field should never be
|
||||
// indexed.
|
||||
Sensitive bool
|
||||
}
|
||||
|
||||
func (hf *HeaderField) size() uint32 {
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.4.1
|
||||
// "The size of the dynamic table is the sum of the size of
|
||||
// its entries. The size of an entry is the sum of its name's
|
||||
// length in octets (as defined in Section 5.2), its value's
|
||||
// length in octets (see Section 5.2), plus 32. The size of
|
||||
// an entry is calculated using the length of the name and
|
||||
// value without any Huffman encoding applied."
|
||||
|
||||
// This can overflow if somebody makes a large HeaderField
|
||||
// Name and/or Value by hand, but we don't care, because that
|
||||
// won't happen on the wire because the encoding doesn't allow
|
||||
// it.
|
||||
return uint32(len(hf.Name) + len(hf.Value) + 32)
|
||||
}
|
||||
|
||||
// A Decoder is the decoding context for incremental processing of
|
||||
// header blocks.
|
||||
type Decoder struct {
|
||||
dynTab dynamicTable
|
||||
emit func(f HeaderField)
|
||||
|
||||
// buf is the unparsed buffer. It's only written to
|
||||
// saveBuf if it was truncated in the middle of a header
|
||||
// block. Because it's usually not owned, we can only
|
||||
// process it under Write.
|
||||
buf []byte // usually not owned
|
||||
saveBuf bytes.Buffer
|
||||
}
|
||||
|
||||
func NewDecoder(maxSize uint32, emitFunc func(f HeaderField)) *Decoder {
|
||||
d := &Decoder{
|
||||
emit: emitFunc,
|
||||
}
|
||||
d.dynTab.allowedMaxSize = maxSize
|
||||
d.dynTab.setMaxSize(maxSize)
|
||||
return d
|
||||
}
|
||||
|
||||
// TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their
|
||||
// underlying buffers for garbage reasons.
|
||||
|
||||
func (d *Decoder) SetMaxDynamicTableSize(v uint32) {
|
||||
d.dynTab.setMaxSize(v)
|
||||
}
|
||||
|
||||
// SetAllowedMaxDynamicTableSize sets the upper bound that the encoded
|
||||
// stream (via dynamic table size updates) may set the maximum size
|
||||
// to.
|
||||
func (d *Decoder) SetAllowedMaxDynamicTableSize(v uint32) {
|
||||
d.dynTab.allowedMaxSize = v
|
||||
}
|
||||
|
||||
type dynamicTable struct {
|
||||
// ents is the FIFO described at
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2
|
||||
// The newest (low index) is append at the end, and items are
|
||||
// evicted from the front.
|
||||
ents []HeaderField
|
||||
size uint32
|
||||
maxSize uint32 // current maxSize
|
||||
allowedMaxSize uint32 // maxSize may go up to this, inclusive
|
||||
}
|
||||
|
||||
func (dt *dynamicTable) setMaxSize(v uint32) {
|
||||
dt.maxSize = v
|
||||
dt.evict()
|
||||
}
|
||||
|
||||
// TODO: change dynamicTable to be a struct with a slice and a size int field,
|
||||
// per http://http2.github.io/http2-spec/compression.html#rfc.section.4.1:
|
||||
//
|
||||
//
|
||||
// Then make add increment the size. maybe the max size should move from Decoder to
|
||||
// dynamicTable and add should return an ok bool if there was enough space.
|
||||
//
|
||||
// Later we'll need a remove operation on dynamicTable.
|
||||
|
||||
func (dt *dynamicTable) add(f HeaderField) {
|
||||
dt.ents = append(dt.ents, f)
|
||||
dt.size += f.size()
|
||||
dt.evict()
|
||||
}
|
||||
|
||||
// If we're too big, evict old stuff (front of the slice)
|
||||
func (dt *dynamicTable) evict() {
|
||||
base := dt.ents // keep base pointer of slice
|
||||
for dt.size > dt.maxSize {
|
||||
dt.size -= dt.ents[0].size()
|
||||
dt.ents = dt.ents[1:]
|
||||
}
|
||||
|
||||
// Shift slice contents down if we evicted things.
|
||||
if len(dt.ents) != len(base) {
|
||||
copy(base, dt.ents)
|
||||
dt.ents = base[:len(dt.ents)]
|
||||
}
|
||||
}
|
||||
|
||||
// constantTimeStringCompare compares string a and b in a constant
|
||||
// time manner.
|
||||
func constantTimeStringCompare(a, b string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
c := byte(0)
|
||||
|
||||
for i := 0; i < len(a); i++ {
|
||||
c |= a[i] ^ b[i]
|
||||
}
|
||||
|
||||
return c == 0
|
||||
}
|
||||
|
||||
// Search searches f in the table. The return value i is 0 if there is
|
||||
// no name match. If there is name match or name/value match, i is the
|
||||
// index of that entry (1-based). If both name and value match,
|
||||
// nameValueMatch becomes true.
|
||||
func (dt *dynamicTable) search(f HeaderField) (i uint64, nameValueMatch bool) {
|
||||
l := len(dt.ents)
|
||||
for j := l - 1; j >= 0; j-- {
|
||||
ent := dt.ents[j]
|
||||
if !constantTimeStringCompare(ent.Name, f.Name) {
|
||||
continue
|
||||
}
|
||||
if i == 0 {
|
||||
i = uint64(l - j)
|
||||
}
|
||||
if f.Sensitive {
|
||||
continue
|
||||
}
|
||||
if !constantTimeStringCompare(ent.Value, f.Value) {
|
||||
continue
|
||||
}
|
||||
i = uint64(l - j)
|
||||
nameValueMatch = true
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Decoder) maxTableIndex() int {
|
||||
return len(d.dynTab.ents) + len(staticTable)
|
||||
}
|
||||
|
||||
func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) {
|
||||
if i < 1 {
|
||||
return
|
||||
}
|
||||
if i > uint64(d.maxTableIndex()) {
|
||||
return
|
||||
}
|
||||
if i <= uint64(len(staticTable)) {
|
||||
return staticTable[i-1], true
|
||||
}
|
||||
dents := d.dynTab.ents
|
||||
return dents[len(dents)-(int(i)-len(staticTable))], true
|
||||
}
|
||||
|
||||
// Decode decodes an entire block.
|
||||
//
|
||||
// TODO: remove this method and make it incremental later? This is
|
||||
// easier for debugging now.
|
||||
func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) {
|
||||
var hf []HeaderField
|
||||
saveFunc := d.emit
|
||||
defer func() { d.emit = saveFunc }()
|
||||
d.emit = func(f HeaderField) { hf = append(hf, f) }
|
||||
if _, err := d.Write(p); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := d.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hf, nil
|
||||
}
|
||||
|
||||
func (d *Decoder) Close() error {
|
||||
if d.saveBuf.Len() > 0 {
|
||||
d.saveBuf.Reset()
|
||||
return DecodingError{errors.New("truncated headers")}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Decoder) Write(p []byte) (n int, err error) {
|
||||
if len(p) == 0 {
|
||||
// Prevent state machine CPU attacks (making us redo
|
||||
// work up to the point of finding out we don't have
|
||||
// enough data)
|
||||
return
|
||||
}
|
||||
// Only copy the data if we have to. Optimistically assume
|
||||
// that p will contain a complete header block.
|
||||
if d.saveBuf.Len() == 0 {
|
||||
d.buf = p
|
||||
} else {
|
||||
d.saveBuf.Write(p)
|
||||
d.buf = d.saveBuf.Bytes()
|
||||
d.saveBuf.Reset()
|
||||
}
|
||||
|
||||
for len(d.buf) > 0 {
|
||||
err = d.parseHeaderFieldRepr()
|
||||
if err != nil {
|
||||
if err == errNeedMore {
|
||||
err = nil
|
||||
d.saveBuf.Write(d.buf)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return len(p), err
|
||||
}
|
||||
|
||||
// errNeedMore is an internal sentinel error value that means the
|
||||
// buffer is truncated and we need to read more data before we can
|
||||
// continue parsing.
|
||||
var errNeedMore = errors.New("need more data")
|
||||
|
||||
type indexType int
|
||||
|
||||
const (
|
||||
indexedTrue indexType = iota
|
||||
indexedFalse
|
||||
indexedNever
|
||||
)
|
||||
|
||||
func (v indexType) indexed() bool { return v == indexedTrue }
|
||||
func (v indexType) sensitive() bool { return v == indexedNever }
|
||||
|
||||
// returns errNeedMore if there isn't enough data available.
|
||||
// any other error is fatal.
|
||||
// consumes d.buf iff it returns nil.
|
||||
// precondition: must be called with len(d.buf) > 0
|
||||
func (d *Decoder) parseHeaderFieldRepr() error {
|
||||
b := d.buf[0]
|
||||
switch {
|
||||
case b&128 != 0:
|
||||
// Indexed representation.
|
||||
// High bit set?
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.1
|
||||
return d.parseFieldIndexed()
|
||||
case b&192 == 64:
|
||||
// 6.2.1 Literal Header Field with Incremental Indexing
|
||||
// 0b10xxxxxx: top two bits are 10
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.1
|
||||
return d.parseFieldLiteral(6, indexedTrue)
|
||||
case b&240 == 0:
|
||||
// 6.2.2 Literal Header Field without Indexing
|
||||
// 0b0000xxxx: top four bits are 0000
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.2
|
||||
return d.parseFieldLiteral(4, indexedFalse)
|
||||
case b&240 == 16:
|
||||
// 6.2.3 Literal Header Field never Indexed
|
||||
// 0b0001xxxx: top four bits are 0001
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.3
|
||||
return d.parseFieldLiteral(4, indexedNever)
|
||||
case b&224 == 32:
|
||||
// 6.3 Dynamic Table Size Update
|
||||
// Top three bits are '001'.
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.6.3
|
||||
return d.parseDynamicTableSizeUpdate()
|
||||
}
|
||||
|
||||
return DecodingError{errors.New("invalid encoding")}
|
||||
}
|
||||
|
||||
// (same invariants and behavior as parseHeaderFieldRepr)
|
||||
func (d *Decoder) parseFieldIndexed() error {
|
||||
buf := d.buf
|
||||
idx, buf, err := readVarInt(7, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hf, ok := d.at(idx)
|
||||
if !ok {
|
||||
return DecodingError{InvalidIndexError(idx)}
|
||||
}
|
||||
d.emit(HeaderField{Name: hf.Name, Value: hf.Value})
|
||||
d.buf = buf
|
||||
return nil
|
||||
}
|
||||
|
||||
// (same invariants and behavior as parseHeaderFieldRepr)
|
||||
func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error {
|
||||
buf := d.buf
|
||||
nameIdx, buf, err := readVarInt(n, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var hf HeaderField
|
||||
if nameIdx > 0 {
|
||||
ihf, ok := d.at(nameIdx)
|
||||
if !ok {
|
||||
return DecodingError{InvalidIndexError(nameIdx)}
|
||||
}
|
||||
hf.Name = ihf.Name
|
||||
} else {
|
||||
hf.Name, buf, err = readString(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
hf.Value, buf, err = readString(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.buf = buf
|
||||
if it.indexed() {
|
||||
d.dynTab.add(hf)
|
||||
}
|
||||
hf.Sensitive = it.sensitive()
|
||||
d.emit(hf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// (same invariants and behavior as parseHeaderFieldRepr)
|
||||
func (d *Decoder) parseDynamicTableSizeUpdate() error {
|
||||
buf := d.buf
|
||||
size, buf, err := readVarInt(5, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if size > uint64(d.dynTab.allowedMaxSize) {
|
||||
return DecodingError{errors.New("dynamic table size update too large")}
|
||||
}
|
||||
d.dynTab.setMaxSize(uint32(size))
|
||||
d.buf = buf
|
||||
return nil
|
||||
}
|
||||
|
||||
var errVarintOverflow = DecodingError{errors.New("varint integer overflow")}
|
||||
|
||||
// readVarInt reads an unsigned variable length integer off the
|
||||
// beginning of p. n is the parameter as described in
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1.
|
||||
//
|
||||
// n must always be between 1 and 8.
|
||||
//
|
||||
// The returned remain buffer is either a smaller suffix of p, or err != nil.
|
||||
// The error is errNeedMore if p doesn't contain a complete integer.
|
||||
func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) {
|
||||
if n < 1 || n > 8 {
|
||||
panic("bad n")
|
||||
}
|
||||
if len(p) == 0 {
|
||||
return 0, p, errNeedMore
|
||||
}
|
||||
i = uint64(p[0])
|
||||
if n < 8 {
|
||||
i &= (1 << uint64(n)) - 1
|
||||
}
|
||||
if i < (1<<uint64(n))-1 {
|
||||
return i, p[1:], nil
|
||||
}
|
||||
|
||||
origP := p
|
||||
p = p[1:]
|
||||
var m uint64
|
||||
for len(p) > 0 {
|
||||
b := p[0]
|
||||
p = p[1:]
|
||||
i += uint64(b&127) << m
|
||||
if b&128 == 0 {
|
||||
return i, p, nil
|
||||
}
|
||||
m += 7
|
||||
if m >= 63 { // TODO: proper overflow check. making this up.
|
||||
return 0, origP, errVarintOverflow
|
||||
}
|
||||
}
|
||||
return 0, origP, errNeedMore
|
||||
}
|
||||
|
||||
func readString(p []byte) (s string, remain []byte, err error) {
|
||||
if len(p) == 0 {
|
||||
return "", p, errNeedMore
|
||||
}
|
||||
isHuff := p[0]&128 != 0
|
||||
strLen, p, err := readVarInt(7, p)
|
||||
if err != nil {
|
||||
return "", p, err
|
||||
}
|
||||
if uint64(len(p)) < strLen {
|
||||
return "", p, errNeedMore
|
||||
}
|
||||
if !isHuff {
|
||||
return string(p[:strLen]), p[strLen:], nil
|
||||
}
|
||||
|
||||
// TODO: optimize this garbage:
|
||||
var buf bytes.Buffer
|
||||
if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return buf.String(), p[strLen:], nil
|
||||
}
|
@ -0,0 +1,648 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStaticTable(t *testing.T) {
|
||||
fromSpec := `
|
||||
+-------+-----------------------------+---------------+
|
||||
| 1 | :authority | |
|
||||
| 2 | :method | GET |
|
||||
| 3 | :method | POST |
|
||||
| 4 | :path | / |
|
||||
| 5 | :path | /index.html |
|
||||
| 6 | :scheme | http |
|
||||
| 7 | :scheme | https |
|
||||
| 8 | :status | 200 |
|
||||
| 9 | :status | 204 |
|
||||
| 10 | :status | 206 |
|
||||
| 11 | :status | 304 |
|
||||
| 12 | :status | 400 |
|
||||
| 13 | :status | 404 |
|
||||
| 14 | :status | 500 |
|
||||
| 15 | accept-charset | |
|
||||
| 16 | accept-encoding | gzip, deflate |
|
||||
| 17 | accept-language | |
|
||||
| 18 | accept-ranges | |
|
||||
| 19 | accept | |
|
||||
| 20 | access-control-allow-origin | |
|
||||
| 21 | age | |
|
||||
| 22 | allow | |
|
||||
| 23 | authorization | |
|
||||
| 24 | cache-control | |
|
||||
| 25 | content-disposition | |
|
||||
| 26 | content-encoding | |
|
||||
| 27 | content-language | |
|
||||
| 28 | content-length | |
|
||||
| 29 | content-location | |
|
||||
| 30 | content-range | |
|
||||
| 31 | content-type | |
|
||||
| 32 | cookie | |
|
||||
| 33 | date | |
|
||||
| 34 | etag | |
|
||||
| 35 | expect | |
|
||||
| 36 | expires | |
|
||||
| 37 | from | |
|
||||
| 38 | host | |
|
||||
| 39 | if-match | |
|
||||
| 40 | if-modified-since | |
|
||||
| 41 | if-none-match | |
|
||||
| 42 | if-range | |
|
||||
| 43 | if-unmodified-since | |
|
||||
| 44 | last-modified | |
|
||||
| 45 | link | |
|
||||
| 46 | location | |
|
||||
| 47 | max-forwards | |
|
||||
| 48 | proxy-authenticate | |
|
||||
| 49 | proxy-authorization | |
|
||||
| 50 | range | |
|
||||
| 51 | referer | |
|
||||
| 52 | refresh | |
|
||||
| 53 | retry-after | |
|
||||
| 54 | server | |
|
||||
| 55 | set-cookie | |
|
||||
| 56 | strict-transport-security | |
|
||||
| 57 | transfer-encoding | |
|
||||
| 58 | user-agent | |
|
||||
| 59 | vary | |
|
||||
| 60 | via | |
|
||||
| 61 | www-authenticate | |
|
||||
+-------+-----------------------------+---------------+
|
||||
`
|
||||
bs := bufio.NewScanner(strings.NewReader(fromSpec))
|
||||
re := regexp.MustCompile(`\| (\d+)\s+\| (\S+)\s*\| (\S(.*\S)?)?\s+\|`)
|
||||
for bs.Scan() {
|
||||
l := bs.Text()
|
||||
if !strings.Contains(l, "|") {
|
||||
continue
|
||||
}
|
||||
m := re.FindStringSubmatch(l)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
i, err := strconv.Atoi(m[1])
|
||||
if err != nil {
|
||||
t.Errorf("Bogus integer on line %q", l)
|
||||
continue
|
||||
}
|
||||
if i < 1 || i > len(staticTable) {
|
||||
t.Errorf("Bogus index %d on line %q", i, l)
|
||||
continue
|
||||
}
|
||||
if got, want := staticTable[i-1].Name, m[2]; got != want {
|
||||
t.Errorf("header index %d name = %q; want %q", i, got, want)
|
||||
}
|
||||
if got, want := staticTable[i-1].Value, m[3]; got != want {
|
||||
t.Errorf("header index %d value = %q; want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
if err := bs.Err(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Decoder) mustAt(idx int) HeaderField {
|
||||
if hf, ok := d.at(uint64(idx)); !ok {
|
||||
panic(fmt.Sprintf("bogus index %d", idx))
|
||||
} else {
|
||||
return hf
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicTableAt(t *testing.T) {
|
||||
d := NewDecoder(4096, nil)
|
||||
at := d.mustAt
|
||||
if got, want := at(2), (pair(":method", "GET")); got != want {
|
||||
t.Errorf("at(2) = %v; want %v", got, want)
|
||||
}
|
||||
d.dynTab.add(pair("foo", "bar"))
|
||||
d.dynTab.add(pair("blake", "miz"))
|
||||
if got, want := at(len(staticTable)+1), (pair("blake", "miz")); got != want {
|
||||
t.Errorf("at(dyn 1) = %v; want %v", got, want)
|
||||
}
|
||||
if got, want := at(len(staticTable)+2), (pair("foo", "bar")); got != want {
|
||||
t.Errorf("at(dyn 2) = %v; want %v", got, want)
|
||||
}
|
||||
if got, want := at(3), (pair(":method", "POST")); got != want {
|
||||
t.Errorf("at(3) = %v; want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicTableSearch(t *testing.T) {
|
||||
dt := dynamicTable{}
|
||||
dt.setMaxSize(4096)
|
||||
|
||||
dt.add(pair("foo", "bar"))
|
||||
dt.add(pair("blake", "miz"))
|
||||
dt.add(pair(":method", "GET"))
|
||||
|
||||
tests := []struct {
|
||||
hf HeaderField
|
||||
wantI uint64
|
||||
wantMatch bool
|
||||
}{
|
||||
// Name and Value match
|
||||
{pair("foo", "bar"), 3, true},
|
||||
{pair(":method", "GET"), 1, true},
|
||||
|
||||
// Only name match because of Sensitive == true
|
||||
{HeaderField{"blake", "miz", true}, 2, false},
|
||||
|
||||
// Only Name matches
|
||||
{pair("foo", "..."), 3, false},
|
||||
{pair("blake", "..."), 2, false},
|
||||
{pair(":method", "..."), 1, false},
|
||||
|
||||
// None match
|
||||
{pair("foo-", "bar"), 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if gotI, gotMatch := dt.search(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch {
|
||||
t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicTableSizeEvict(t *testing.T) {
|
||||
d := NewDecoder(4096, nil)
|
||||
if want := uint32(0); d.dynTab.size != want {
|
||||
t.Fatalf("size = %d; want %d", d.dynTab.size, want)
|
||||
}
|
||||
add := d.dynTab.add
|
||||
add(pair("blake", "eats pizza"))
|
||||
if want := uint32(15 + 32); d.dynTab.size != want {
|
||||
t.Fatalf("after pizza, size = %d; want %d", d.dynTab.size, want)
|
||||
}
|
||||
add(pair("foo", "bar"))
|
||||
if want := uint32(15 + 32 + 6 + 32); d.dynTab.size != want {
|
||||
t.Fatalf("after foo bar, size = %d; want %d", d.dynTab.size, want)
|
||||
}
|
||||
d.dynTab.setMaxSize(15 + 32 + 1 /* slop */)
|
||||
if want := uint32(6 + 32); d.dynTab.size != want {
|
||||
t.Fatalf("after setMaxSize, size = %d; want %d", d.dynTab.size, want)
|
||||
}
|
||||
if got, want := d.mustAt(len(staticTable)+1), (pair("foo", "bar")); got != want {
|
||||
t.Errorf("at(dyn 1) = %v; want %v", got, want)
|
||||
}
|
||||
add(pair("long", strings.Repeat("x", 500)))
|
||||
if want := uint32(0); d.dynTab.size != want {
|
||||
t.Fatalf("after big one, size = %d; want %d", d.dynTab.size, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecoderDecode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in []byte
|
||||
want []HeaderField
|
||||
wantDynTab []HeaderField // newest entry first
|
||||
}{
|
||||
// C.2.1 Literal Header Field with Indexing
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.1
|
||||
{"C.2.1", dehex("400a 6375 7374 6f6d 2d6b 6579 0d63 7573 746f 6d2d 6865 6164 6572"),
|
||||
[]HeaderField{pair("custom-key", "custom-header")},
|
||||
[]HeaderField{pair("custom-key", "custom-header")},
|
||||
},
|
||||
|
||||
// C.2.2 Literal Header Field without Indexing
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.2
|
||||
{"C.2.2", dehex("040c 2f73 616d 706c 652f 7061 7468"),
|
||||
[]HeaderField{pair(":path", "/sample/path")},
|
||||
[]HeaderField{}},
|
||||
|
||||
// C.2.3 Literal Header Field never Indexed
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.3
|
||||
{"C.2.3", dehex("1008 7061 7373 776f 7264 0673 6563 7265 74"),
|
||||
[]HeaderField{{"password", "secret", true}},
|
||||
[]HeaderField{}},
|
||||
|
||||
// C.2.4 Indexed Header Field
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.4
|
||||
{"C.2.4", []byte("\x82"),
|
||||
[]HeaderField{pair(":method", "GET")},
|
||||
[]HeaderField{}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
d := NewDecoder(4096, nil)
|
||||
hf, err := d.DecodeFull(tt.in)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", tt.name, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(hf, tt.want) {
|
||||
t.Errorf("%s: Got %v; want %v", tt.name, hf, tt.want)
|
||||
}
|
||||
gotDynTab := d.dynTab.reverseCopy()
|
||||
if !reflect.DeepEqual(gotDynTab, tt.wantDynTab) {
|
||||
t.Errorf("%s: dynamic table after = %v; want %v", tt.name, gotDynTab, tt.wantDynTab)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dt *dynamicTable) reverseCopy() (hf []HeaderField) {
|
||||
hf = make([]HeaderField, len(dt.ents))
|
||||
for i := range hf {
|
||||
hf[i] = dt.ents[len(dt.ents)-1-i]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type encAndWant struct {
|
||||
enc []byte
|
||||
want []HeaderField
|
||||
wantDynTab []HeaderField
|
||||
wantDynSize uint32
|
||||
}
|
||||
|
||||
// C.3 Request Examples without Huffman Coding
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.3
|
||||
func TestDecodeC3_NoHuffman(t *testing.T) {
|
||||
testDecodeSeries(t, 4096, []encAndWant{
|
||||
{dehex("8286 8441 0f77 7777 2e65 7861 6d70 6c65 2e63 6f6d"),
|
||||
[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":path", "/"),
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
57,
|
||||
},
|
||||
{dehex("8286 84be 5808 6e6f 2d63 6163 6865"),
|
||||
[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":path", "/"),
|
||||
pair(":authority", "www.example.com"),
|
||||
pair("cache-control", "no-cache"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("cache-control", "no-cache"),
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
110,
|
||||
},
|
||||
{dehex("8287 85bf 400a 6375 7374 6f6d 2d6b 6579 0c63 7573 746f 6d2d 7661 6c75 65"),
|
||||
[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "https"),
|
||||
pair(":path", "/index.html"),
|
||||
pair(":authority", "www.example.com"),
|
||||
pair("custom-key", "custom-value"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("custom-key", "custom-value"),
|
||||
pair("cache-control", "no-cache"),
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
164,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// C.4 Request Examples with Huffman Coding
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.4
|
||||
func TestDecodeC4_Huffman(t *testing.T) {
|
||||
testDecodeSeries(t, 4096, []encAndWant{
|
||||
{dehex("8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4 ff"),
|
||||
[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":path", "/"),
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
57,
|
||||
},
|
||||
{dehex("8286 84be 5886 a8eb 1064 9cbf"),
|
||||
[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":path", "/"),
|
||||
pair(":authority", "www.example.com"),
|
||||
pair("cache-control", "no-cache"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("cache-control", "no-cache"),
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
110,
|
||||
},
|
||||
{dehex("8287 85bf 4088 25a8 49e9 5ba9 7d7f 8925 a849 e95b b8e8 b4bf"),
|
||||
[]HeaderField{
|
||||
pair(":method", "GET"),
|
||||
pair(":scheme", "https"),
|
||||
pair(":path", "/index.html"),
|
||||
pair(":authority", "www.example.com"),
|
||||
pair("custom-key", "custom-value"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("custom-key", "custom-value"),
|
||||
pair("cache-control", "no-cache"),
|
||||
pair(":authority", "www.example.com"),
|
||||
},
|
||||
164,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.5
|
||||
// "This section shows several consecutive header lists, corresponding
|
||||
// to HTTP responses, on the same connection. The HTTP/2 setting
|
||||
// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256
|
||||
// octets, causing some evictions to occur."
|
||||
func TestDecodeC5_ResponsesNoHuff(t *testing.T) {
|
||||
testDecodeSeries(t, 256, []encAndWant{
|
||||
{dehex(`
|
||||
4803 3330 3258 0770 7269 7661 7465 611d
|
||||
4d6f 6e2c 2032 3120 4f63 7420 3230 3133
|
||||
2032 303a 3133 3a32 3120 474d 546e 1768
|
||||
7474 7073 3a2f 2f77 7777 2e65 7861 6d70
|
||||
6c65 2e63 6f6d
|
||||
`),
|
||||
[]HeaderField{
|
||||
pair(":status", "302"),
|
||||
pair("cache-control", "private"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("location", "https://www.example.com"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("location", "https://www.example.com"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("cache-control", "private"),
|
||||
pair(":status", "302"),
|
||||
},
|
||||
222,
|
||||
},
|
||||
{dehex("4803 3330 37c1 c0bf"),
|
||||
[]HeaderField{
|
||||
pair(":status", "307"),
|
||||
pair("cache-control", "private"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("location", "https://www.example.com"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair(":status", "307"),
|
||||
pair("location", "https://www.example.com"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("cache-control", "private"),
|
||||
},
|
||||
222,
|
||||
},
|
||||
{dehex(`
|
||||
88c1 611d 4d6f 6e2c 2032 3120 4f63 7420
|
||||
3230 3133 2032 303a 3133 3a32 3220 474d
|
||||
54c0 5a04 677a 6970 7738 666f 6f3d 4153
|
||||
444a 4b48 514b 425a 584f 5157 454f 5049
|
||||
5541 5851 5745 4f49 553b 206d 6178 2d61
|
||||
6765 3d33 3630 303b 2076 6572 7369 6f6e
|
||||
3d31
|
||||
`),
|
||||
[]HeaderField{
|
||||
pair(":status", "200"),
|
||||
pair("cache-control", "private"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
|
||||
pair("location", "https://www.example.com"),
|
||||
pair("content-encoding", "gzip"),
|
||||
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
|
||||
pair("content-encoding", "gzip"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
|
||||
},
|
||||
215,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// http://http2.github.io/http2-spec/compression.html#rfc.section.C.6
|
||||
// "This section shows the same examples as the previous section, but
|
||||
// using Huffman encoding for the literal values. The HTTP/2 setting
|
||||
// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256
|
||||
// octets, causing some evictions to occur. The eviction mechanism
|
||||
// uses the length of the decoded literal values, so the same
|
||||
// evictions occurs as in the previous section."
|
||||
func TestDecodeC6_ResponsesHuffman(t *testing.T) {
|
||||
testDecodeSeries(t, 256, []encAndWant{
|
||||
{dehex(`
|
||||
4882 6402 5885 aec3 771a 4b61 96d0 7abe
|
||||
9410 54d4 44a8 2005 9504 0b81 66e0 82a6
|
||||
2d1b ff6e 919d 29ad 1718 63c7 8f0b 97c8
|
||||
e9ae 82ae 43d3
|
||||
`),
|
||||
[]HeaderField{
|
||||
pair(":status", "302"),
|
||||
pair("cache-control", "private"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("location", "https://www.example.com"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("location", "https://www.example.com"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("cache-control", "private"),
|
||||
pair(":status", "302"),
|
||||
},
|
||||
222,
|
||||
},
|
||||
{dehex("4883 640e ffc1 c0bf"),
|
||||
[]HeaderField{
|
||||
pair(":status", "307"),
|
||||
pair("cache-control", "private"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("location", "https://www.example.com"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair(":status", "307"),
|
||||
pair("location", "https://www.example.com"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"),
|
||||
pair("cache-control", "private"),
|
||||
},
|
||||
222,
|
||||
},
|
||||
{dehex(`
|
||||
88c1 6196 d07a be94 1054 d444 a820 0595
|
||||
040b 8166 e084 a62d 1bff c05a 839b d9ab
|
||||
77ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b
|
||||
3960 d5af 2708 7f36 72c1 ab27 0fb5 291f
|
||||
9587 3160 65c0 03ed 4ee5 b106 3d50 07
|
||||
`),
|
||||
[]HeaderField{
|
||||
pair(":status", "200"),
|
||||
pair("cache-control", "private"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
|
||||
pair("location", "https://www.example.com"),
|
||||
pair("content-encoding", "gzip"),
|
||||
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
|
||||
},
|
||||
[]HeaderField{
|
||||
pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"),
|
||||
pair("content-encoding", "gzip"),
|
||||
pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"),
|
||||
},
|
||||
215,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func testDecodeSeries(t *testing.T, size uint32, steps []encAndWant) {
|
||||
d := NewDecoder(size, nil)
|
||||
for i, step := range steps {
|
||||
hf, err := d.DecodeFull(step.enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Error at step index %d: %v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(hf, step.want) {
|
||||
t.Fatalf("At step index %d: Got headers %v; want %v", i, hf, step.want)
|
||||
}
|
||||
gotDynTab := d.dynTab.reverseCopy()
|
||||
if !reflect.DeepEqual(gotDynTab, step.wantDynTab) {
|
||||
t.Errorf("After step index %d, dynamic table = %v; want %v", i, gotDynTab, step.wantDynTab)
|
||||
}
|
||||
if d.dynTab.size != step.wantDynSize {
|
||||
t.Errorf("After step index %d, dynamic table size = %v; want %v", i, d.dynTab.size, step.wantDynSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHuffmanDecode(t *testing.T) {
|
||||
tests := []struct {
|
||||
inHex, want string
|
||||
}{
|
||||
{"f1e3 c2e5 f23a 6ba0 ab90 f4ff", "www.example.com"},
|
||||
{"a8eb 1064 9cbf", "no-cache"},
|
||||
{"25a8 49e9 5ba9 7d7f", "custom-key"},
|
||||
{"25a8 49e9 5bb8 e8b4 bf", "custom-value"},
|
||||
{"6402", "302"},
|
||||
{"aec3 771a 4b", "private"},
|
||||
{"d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff", "Mon, 21 Oct 2013 20:13:21 GMT"},
|
||||
{"9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3", "https://www.example.com"},
|
||||
{"9bd9 ab", "gzip"},
|
||||
{"94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07",
|
||||
"foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
var buf bytes.Buffer
|
||||
in, err := hex.DecodeString(strings.Replace(tt.inHex, " ", "", -1))
|
||||
if err != nil {
|
||||
t.Errorf("%d. hex input error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if _, err := HuffmanDecode(&buf, in); err != nil {
|
||||
t.Errorf("%d. decode error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if got := buf.String(); tt.want != got {
|
||||
t.Errorf("%d. decode = %q; want %q", i, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendHuffmanString(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, want string
|
||||
}{
|
||||
{"www.example.com", "f1e3 c2e5 f23a 6ba0 ab90 f4ff"},
|
||||
{"no-cache", "a8eb 1064 9cbf"},
|
||||
{"custom-key", "25a8 49e9 5ba9 7d7f"},
|
||||
{"custom-value", "25a8 49e9 5bb8 e8b4 bf"},
|
||||
{"302", "6402"},
|
||||
{"private", "aec3 771a 4b"},
|
||||
{"Mon, 21 Oct 2013 20:13:21 GMT", "d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff"},
|
||||
{"https://www.example.com", "9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3"},
|
||||
{"gzip", "9bd9 ab"},
|
||||
{"foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1",
|
||||
"94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07"},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
buf := []byte{}
|
||||
want := strings.Replace(tt.want, " ", "", -1)
|
||||
buf = AppendHuffmanString(buf, tt.in)
|
||||
if got := hex.EncodeToString(buf); want != got {
|
||||
t.Errorf("%d. encode = %q; want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadVarInt(t *testing.T) {
|
||||
type res struct {
|
||||
i uint64
|
||||
consumed int
|
||||
err error
|
||||
}
|
||||
tests := []struct {
|
||||
n byte
|
||||
p []byte
|
||||
want res
|
||||
}{
|
||||
// Fits in a byte:
|
||||
{1, []byte{0}, res{0, 1, nil}},
|
||||
{2, []byte{2}, res{2, 1, nil}},
|
||||
{3, []byte{6}, res{6, 1, nil}},
|
||||
{4, []byte{14}, res{14, 1, nil}},
|
||||
{5, []byte{30}, res{30, 1, nil}},
|
||||
{6, []byte{62}, res{62, 1, nil}},
|
||||
{7, []byte{126}, res{126, 1, nil}},
|
||||
{8, []byte{254}, res{254, 1, nil}},
|
||||
|
||||
// Doesn't fit in a byte:
|
||||
{1, []byte{1}, res{0, 0, errNeedMore}},
|
||||
{2, []byte{3}, res{0, 0, errNeedMore}},
|
||||
{3, []byte{7}, res{0, 0, errNeedMore}},
|
||||
{4, []byte{15}, res{0, 0, errNeedMore}},
|
||||
{5, []byte{31}, res{0, 0, errNeedMore}},
|
||||
{6, []byte{63}, res{0, 0, errNeedMore}},
|
||||
{7, []byte{127}, res{0, 0, errNeedMore}},
|
||||
{8, []byte{255}, res{0, 0, errNeedMore}},
|
||||
|
||||
// Ignoring top bits:
|
||||
{5, []byte{255, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 111
|
||||
{5, []byte{159, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 100
|
||||
{5, []byte{191, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 101
|
||||
|
||||
// Extra byte:
|
||||
{5, []byte{191, 154, 10, 2}, res{1337, 3, nil}}, // extra byte
|
||||
|
||||
// Short a byte:
|
||||
{5, []byte{191, 154}, res{0, 0, errNeedMore}},
|
||||
|
||||
// integer overflow:
|
||||
{1, []byte{255, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}, res{0, 0, errVarintOverflow}},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
i, remain, err := readVarInt(tt.n, tt.p)
|
||||
consumed := len(tt.p) - len(remain)
|
||||
got := res{i, consumed, err}
|
||||
if got != tt.want {
|
||||
t.Errorf("readVarInt(%d, %v ~ %x) = %+v; want %+v", tt.n, tt.p, tt.p, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dehex(s string) []byte {
|
||||
s = strings.Replace(s, " ", "", -1)
|
||||
s = strings.Replace(s, "\n", "", -1)
|
||||
b, err := hex.DecodeString(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
@ -0,0 +1,159 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package hpack
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} { return new(bytes.Buffer) },
|
||||
}
|
||||
|
||||
// HuffmanDecode decodes the string in v and writes the expanded
|
||||
// result to w, returning the number of bytes written to w and the
|
||||
// Write call's return value. At most one Write call is made.
|
||||
func HuffmanDecode(w io.Writer, v []byte) (int, error) {
|
||||
buf := bufPool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
defer bufPool.Put(buf)
|
||||
|
||||
n := rootHuffmanNode
|
||||
cur, nbits := uint(0), uint8(0)
|
||||
for _, b := range v {
|
||||
cur = cur<<8 | uint(b)
|
||||
nbits += 8
|
||||
for nbits >= 8 {
|
||||
n = n.children[byte(cur>>(nbits-8))]
|
||||
if n.children == nil {
|
||||
buf.WriteByte(n.sym)
|
||||
nbits -= n.codeLen
|
||||
n = rootHuffmanNode
|
||||
} else {
|
||||
nbits -= 8
|
||||
}
|
||||
}
|
||||
}
|
||||
for nbits > 0 {
|
||||
n = n.children[byte(cur<<(8-nbits))]
|
||||
if n.children != nil || n.codeLen > nbits {
|
||||
break
|
||||
}
|
||||
buf.WriteByte(n.sym)
|
||||
nbits -= n.codeLen
|
||||
n = rootHuffmanNode
|
||||
}
|
||||
return w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
type node struct {
|
||||
// children is non-nil for internal nodes
|
||||
children []*node
|
||||
|
||||
// The following are only valid if children is nil:
|
||||
codeLen uint8 // number of bits that led to the output of sym
|
||||
sym byte // output symbol
|
||||
}
|
||||
|
||||
func newInternalNode() *node {
|
||||
return &node{children: make([]*node, 256)}
|
||||
}
|
||||
|
||||
var rootHuffmanNode = newInternalNode()
|
||||
|
||||
func init() {
|
||||
for i, code := range huffmanCodes {
|
||||
if i > 255 {
|
||||
panic("too many huffman codes")
|
||||
}
|
||||
addDecoderNode(byte(i), code, huffmanCodeLen[i])
|
||||
}
|
||||
}
|
||||
|
||||
func addDecoderNode(sym byte, code uint32, codeLen uint8) {
|
||||
cur := rootHuffmanNode
|
||||
for codeLen > 8 {
|
||||
codeLen -= 8
|
||||
i := uint8(code >> codeLen)
|
||||
if cur.children[i] == nil {
|
||||
cur.children[i] = newInternalNode()
|
||||
}
|
||||
cur = cur.children[i]
|
||||
}
|
||||
shift := 8 - codeLen
|
||||
start, end := int(uint8(code<<shift)), int(1<<shift)
|
||||
for i := start; i < start+end; i++ {
|
||||
cur.children[i] = &node{sym: sym, codeLen: codeLen}
|
||||
}
|
||||
}
|
||||
|
||||
// AppendHuffmanString appends s, as encoded in Huffman codes, to dst
|
||||
// and returns the extended buffer.
|
||||
func AppendHuffmanString(dst []byte, s string) []byte {
|
||||
rembits := uint8(8)
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
if rembits == 8 {
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i])
|
||||
}
|
||||
|
||||
if rembits < 8 {
|
||||
// special EOS symbol
|
||||
code := uint32(0x3fffffff)
|
||||
nbits := uint8(30)
|
||||
|
||||
t := uint8(code >> (nbits - rembits))
|
||||
dst[len(dst)-1] |= t
|
||||
}
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// HuffmanEncodeLength returns the number of bytes required to encode
|
||||
// s in Huffman codes. The result is round up to byte boundary.
|
||||
func HuffmanEncodeLength(s string) uint64 {
|
||||
n := uint64(0)
|
||||
for i := 0; i < len(s); i++ {
|
||||
n += uint64(huffmanCodeLen[s[i]])
|
||||
}
|
||||
return (n + 7) / 8
|
||||
}
|
||||
|
||||
// appendByteToHuffmanCode appends Huffman code for c to dst and
|
||||
// returns the extended buffer and the remaining bits in the last
|
||||
// element. The appending is not byte aligned and the remaining bits
|
||||
// in the last element of dst is given in rembits.
|
||||
func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) {
|
||||
code := huffmanCodes[c]
|
||||
nbits := huffmanCodeLen[c]
|
||||
|
||||
for {
|
||||
if rembits > nbits {
|
||||
t := uint8(code << (rembits - nbits))
|
||||
dst[len(dst)-1] |= t
|
||||
rembits -= nbits
|
||||
break
|
||||
}
|
||||
|
||||
t := uint8(code >> (nbits - rembits))
|
||||
dst[len(dst)-1] |= t
|
||||
|
||||
nbits -= rembits
|
||||
rembits = 8
|
||||
|
||||
if nbits == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
|
||||
return dst, rembits
|
||||
}
|
@ -0,0 +1,353 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package hpack
|
||||
|
||||
func pair(name, value string) HeaderField {
|
||||
return HeaderField{Name: name, Value: value}
|
||||
}
|
||||
|
||||
// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-07#appendix-B
|
||||
var staticTable = []HeaderField{
|
||||
pair(":authority", ""), // index 1 (1-based)
|
||||
pair(":method", "GET"),
|
||||
pair(":method", "POST"),
|
||||
pair(":path", "/"),
|
||||
pair(":path", "/index.html"),
|
||||
pair(":scheme", "http"),
|
||||
pair(":scheme", "https"),
|
||||
pair(":status", "200"),
|
||||
pair(":status", "204"),
|
||||
pair(":status", "206"),
|
||||
pair(":status", "304"),
|
||||
pair(":status", "400"),
|
||||
pair(":status", "404"),
|
||||
pair(":status", "500"),
|
||||
pair("accept-charset", ""),
|
||||
pair("accept-encoding", "gzip, deflate"),
|
||||
pair("accept-language", ""),
|
||||
pair("accept-ranges", ""),
|
||||
pair("accept", ""),
|
||||
pair("access-control-allow-origin", ""),
|
||||
pair("age", ""),
|
||||
pair("allow", ""),
|
||||
pair("authorization", ""),
|
||||
pair("cache-control", ""),
|
||||
pair("content-disposition", ""),
|
||||
pair("content-encoding", ""),
|
||||
pair("content-language", ""),
|
||||
pair("content-length", ""),
|
||||
pair("content-location", ""),
|
||||
pair("content-range", ""),
|
||||
pair("content-type", ""),
|
||||
pair("cookie", ""),
|
||||
pair("date", ""),
|
||||
pair("etag", ""),
|
||||
pair("expect", ""),
|
||||
pair("expires", ""),
|
||||
pair("from", ""),
|
||||
pair("host", ""),
|
||||
pair("if-match", ""),
|
||||
pair("if-modified-since", ""),
|
||||
pair("if-none-match", ""),
|
||||
pair("if-range", ""),
|
||||
pair("if-unmodified-since", ""),
|
||||
pair("last-modified", ""),
|
||||
pair("link", ""),
|
||||
pair("location", ""),
|
||||
pair("max-forwards", ""),
|
||||
pair("proxy-authenticate", ""),
|
||||
pair("proxy-authorization", ""),
|
||||
pair("range", ""),
|
||||
pair("referer", ""),
|
||||
pair("refresh", ""),
|
||||
pair("retry-after", ""),
|
||||
pair("server", ""),
|
||||
pair("set-cookie", ""),
|
||||
pair("strict-transport-security", ""),
|
||||
pair("transfer-encoding", ""),
|
||||
pair("user-agent", ""),
|
||||
pair("vary", ""),
|
||||
pair("via", ""),
|
||||
pair("www-authenticate", ""),
|
||||
}
|
||||
|
||||
var huffmanCodes = []uint32{
|
||||
0x1ff8,
|
||||
0x7fffd8,
|
||||
0xfffffe2,
|
||||
0xfffffe3,
|
||||
0xfffffe4,
|
||||
0xfffffe5,
|
||||
0xfffffe6,
|
||||
0xfffffe7,
|
||||
0xfffffe8,
|
||||
0xffffea,
|
||||
0x3ffffffc,
|
||||
0xfffffe9,
|
||||
0xfffffea,
|
||||
0x3ffffffd,
|
||||
0xfffffeb,
|
||||
0xfffffec,
|
||||
0xfffffed,
|
||||
0xfffffee,
|
||||
0xfffffef,
|
||||
0xffffff0,
|
||||
0xffffff1,
|
||||
0xffffff2,
|
||||
0x3ffffffe,
|
||||
0xffffff3,
|
||||
0xffffff4,
|
||||
0xffffff5,
|
||||
0xffffff6,
|
||||
0xffffff7,
|
||||
0xffffff8,
|
||||
0xffffff9,
|
||||
0xffffffa,
|
||||
0xffffffb,
|
||||
0x14,
|
||||
0x3f8,
|
||||
0x3f9,
|
||||
0xffa,
|
||||
0x1ff9,
|
||||
0x15,
|
||||
0xf8,
|
||||
0x7fa,
|
||||
0x3fa,
|
||||
0x3fb,
|
||||
0xf9,
|
||||
0x7fb,
|
||||
0xfa,
|
||||
0x16,
|
||||
0x17,
|
||||
0x18,
|
||||
0x0,
|
||||
0x1,
|
||||
0x2,
|
||||
0x19,
|
||||
0x1a,
|
||||
0x1b,
|
||||
0x1c,
|
||||
0x1d,
|
||||
0x1e,
|
||||
0x1f,
|
||||
0x5c,
|
||||
0xfb,
|
||||
0x7ffc,
|
||||
0x20,
|
||||
0xffb,
|
||||
0x3fc,
|
||||
0x1ffa,
|
||||
0x21,
|
||||
0x5d,
|
||||
0x5e,
|
||||
0x5f,
|
||||
0x60,
|
||||
0x61,
|
||||
0x62,
|
||||
0x63,
|
||||
0x64,
|
||||
0x65,
|
||||
0x66,
|
||||
0x67,
|
||||
0x68,
|
||||
0x69,
|
||||
0x6a,
|
||||
0x6b,
|
||||
0x6c,
|
||||
0x6d,
|
||||
0x6e,
|
||||
0x6f,
|
||||
0x70,
|
||||
0x71,
|
||||
0x72,
|
||||
0xfc,
|
||||
0x73,
|
||||
0xfd,
|
||||
0x1ffb,
|
||||
0x7fff0,
|
||||
0x1ffc,
|
||||
0x3ffc,
|
||||
0x22,
|
||||
0x7ffd,
|
||||
0x3,
|
||||
0x23,
|
||||
0x4,
|
||||
0x24,
|
||||
0x5,
|
||||
0x25,
|
||||
0x26,
|
||||
0x27,
|
||||
0x6,
|
||||
0x74,
|
||||
0x75,
|
||||
0x28,
|
||||
0x29,
|
||||
0x2a,
|
||||
0x7,
|
||||
0x2b,
|
||||
0x76,
|
||||
0x2c,
|
||||
0x8,
|
||||
0x9,
|
||||
0x2d,
|
||||
0x77,
|
||||
0x78,
|
||||
0x79,
|
||||
0x7a,
|
||||
0x7b,
|
||||
0x7ffe,
|
||||
0x7fc,
|
||||
0x3ffd,
|
||||
0x1ffd,
|
||||
0xffffffc,
|
||||
0xfffe6,
|
||||
0x3fffd2,
|
||||
0xfffe7,
|
||||
0xfffe8,
|
||||
0x3fffd3,
|
||||
0x3fffd4,
|
||||
0x3fffd5,
|
||||
0x7fffd9,
|
||||
0x3fffd6,
|
||||
0x7fffda,
|
||||
0x7fffdb,
|
||||
0x7fffdc,
|
||||
0x7fffdd,
|
||||
0x7fffde,
|
||||
0xffffeb,
|
||||
0x7fffdf,
|
||||
0xffffec,
|
||||
0xffffed,
|
||||
0x3fffd7,
|
||||
0x7fffe0,
|
||||
0xffffee,
|
||||
0x7fffe1,
|
||||
0x7fffe2,
|
||||
0x7fffe3,
|
||||
0x7fffe4,
|
||||
0x1fffdc,
|
||||
0x3fffd8,
|
||||
0x7fffe5,
|
||||
0x3fffd9,
|
||||
0x7fffe6,
|
||||
0x7fffe7,
|
||||
0xffffef,
|
||||
0x3fffda,
|
||||
0x1fffdd,
|
||||
0xfffe9,
|
||||
0x3fffdb,
|
||||
0x3fffdc,
|
||||
0x7fffe8,
|
||||
0x7fffe9,
|
||||
0x1fffde,
|
||||
0x7fffea,
|
||||
0x3fffdd,
|
||||
0x3fffde,
|
||||
0xfffff0,
|
||||
0x1fffdf,
|
||||
0x3fffdf,
|
||||
0x7fffeb,
|
||||
0x7fffec,
|
||||
0x1fffe0,
|
||||
0x1fffe1,
|
||||
0x3fffe0,
|
||||
0x1fffe2,
|
||||
0x7fffed,
|
||||
0x3fffe1,
|
||||
0x7fffee,
|
||||
0x7fffef,
|
||||
0xfffea,
|
||||
0x3fffe2,
|
||||
0x3fffe3,
|
||||
0x3fffe4,
|
||||
0x7ffff0,
|
||||
0x3fffe5,
|
||||
0x3fffe6,
|
||||
0x7ffff1,
|
||||
0x3ffffe0,
|
||||
0x3ffffe1,
|
||||
0xfffeb,
|
||||
0x7fff1,
|
||||
0x3fffe7,
|
||||
0x7ffff2,
|
||||
0x3fffe8,
|
||||
0x1ffffec,
|
||||
0x3ffffe2,
|
||||
0x3ffffe3,
|
||||
0x3ffffe4,
|
||||
0x7ffffde,
|
||||
0x7ffffdf,
|
||||
0x3ffffe5,
|
||||
0xfffff1,
|
||||
0x1ffffed,
|
||||
0x7fff2,
|
||||
0x1fffe3,
|
||||
0x3ffffe6,
|
||||
0x7ffffe0,
|
||||
0x7ffffe1,
|
||||
0x3ffffe7,
|
||||
0x7ffffe2,
|
||||
0xfffff2,
|
||||
0x1fffe4,
|
||||
0x1fffe5,
|
||||
0x3ffffe8,
|
||||
0x3ffffe9,
|
||||
0xffffffd,
|
||||
0x7ffffe3,
|
||||
0x7ffffe4,
|
||||
0x7ffffe5,
|
||||
0xfffec,
|
||||
0xfffff3,
|
||||
0xfffed,
|
||||
0x1fffe6,
|
||||
0x3fffe9,
|
||||
0x1fffe7,
|
||||
0x1fffe8,
|
||||
0x7ffff3,
|
||||
0x3fffea,
|
||||
0x3fffeb,
|
||||
0x1ffffee,
|
||||
0x1ffffef,
|
||||
0xfffff4,
|
||||
0xfffff5,
|
||||
0x3ffffea,
|
||||
0x7ffff4,
|
||||
0x3ffffeb,
|
||||
0x7ffffe6,
|
||||
0x3ffffec,
|
||||
0x3ffffed,
|
||||
0x7ffffe7,
|
||||
0x7ffffe8,
|
||||
0x7ffffe9,
|
||||
0x7ffffea,
|
||||
0x7ffffeb,
|
||||
0xffffffe,
|
||||
0x7ffffec,
|
||||
0x7ffffed,
|
||||
0x7ffffee,
|
||||
0x7ffffef,
|
||||
0x7fffff0,
|
||||
0x3ffffee,
|
||||
}
|
||||
|
||||
var huffmanCodeLen = []uint8{
|
||||
13, 23, 28, 28, 28, 28, 28, 28, 28, 24, 30, 28, 28, 30, 28, 28,
|
||||
28, 28, 28, 28, 28, 28, 30, 28, 28, 28, 28, 28, 28, 28, 28, 28,
|
||||
6, 10, 10, 12, 13, 6, 8, 11, 10, 10, 8, 11, 8, 6, 6, 6,
|
||||
5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 8, 15, 6, 12, 10,
|
||||
13, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
|
||||
7, 7, 7, 7, 7, 7, 7, 7, 8, 7, 8, 13, 19, 13, 14, 6,
|
||||
15, 5, 6, 5, 6, 5, 6, 6, 6, 5, 7, 7, 6, 6, 6, 5,
|
||||
6, 7, 6, 5, 5, 6, 7, 7, 7, 7, 7, 15, 11, 14, 13, 28,
|
||||
20, 22, 20, 20, 22, 22, 22, 23, 22, 23, 23, 23, 23, 23, 24, 23,
|
||||
24, 24, 22, 23, 24, 23, 23, 23, 23, 21, 22, 23, 22, 23, 23, 24,
|
||||
22, 21, 20, 22, 22, 23, 23, 21, 23, 22, 22, 24, 21, 22, 23, 23,
|
||||
21, 21, 22, 21, 23, 22, 23, 23, 20, 22, 22, 22, 23, 22, 22, 23,
|
||||
26, 26, 20, 19, 22, 23, 22, 25, 26, 26, 26, 27, 27, 26, 24, 25,
|
||||
19, 21, 26, 27, 27, 26, 27, 24, 21, 21, 26, 26, 28, 27, 27, 27,
|
||||
20, 24, 20, 21, 22, 21, 21, 23, 22, 22, 25, 25, 24, 24, 26, 23,
|
||||
26, 27, 26, 26, 27, 27, 27, 27, 27, 28, 27, 27, 27, 27, 27, 26,
|
||||
}
|
@ -0,0 +1,249 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
// Package http2 implements the HTTP/2 protocol.
|
||||
//
|
||||
// This is a work in progress. This package is low-level and intended
|
||||
// to be used directly by very few people. Most users will use it
|
||||
// indirectly through integration with the net/http package. See
|
||||
// ConfigureServer. That ConfigureServer call will likely be automatic
|
||||
// or available via an empty import in the future.
|
||||
//
|
||||
// See http://http2.github.io/
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var VerboseLogs = false
|
||||
|
||||
const (
|
||||
// ClientPreface is the string that must be sent by new
|
||||
// connections from clients.
|
||||
ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
|
||||
|
||||
// SETTINGS_MAX_FRAME_SIZE default
|
||||
// http://http2.github.io/http2-spec/#rfc.section.6.5.2
|
||||
initialMaxFrameSize = 16384
|
||||
|
||||
// NextProtoTLS is the NPN/ALPN protocol negotiated during
|
||||
// HTTP/2's TLS setup.
|
||||
NextProtoTLS = "h2"
|
||||
|
||||
// http://http2.github.io/http2-spec/#SettingValues
|
||||
initialHeaderTableSize = 4096
|
||||
|
||||
initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
|
||||
|
||||
defaultMaxReadFrameSize = 1 << 20
|
||||
)
|
||||
|
||||
var (
|
||||
clientPreface = []byte(ClientPreface)
|
||||
)
|
||||
|
||||
type streamState int
|
||||
|
||||
const (
|
||||
stateIdle streamState = iota
|
||||
stateOpen
|
||||
stateHalfClosedLocal
|
||||
stateHalfClosedRemote
|
||||
stateResvLocal
|
||||
stateResvRemote
|
||||
stateClosed
|
||||
)
|
||||
|
||||
var stateName = [...]string{
|
||||
stateIdle: "Idle",
|
||||
stateOpen: "Open",
|
||||
stateHalfClosedLocal: "HalfClosedLocal",
|
||||
stateHalfClosedRemote: "HalfClosedRemote",
|
||||
stateResvLocal: "ResvLocal",
|
||||
stateResvRemote: "ResvRemote",
|
||||
stateClosed: "Closed",
|
||||
}
|
||||
|
||||
func (st streamState) String() string {
|
||||
return stateName[st]
|
||||
}
|
||||
|
||||
// Setting is a setting parameter: which setting it is, and its value.
|
||||
type Setting struct {
|
||||
// ID is which setting is being set.
|
||||
// See http://http2.github.io/http2-spec/#SettingValues
|
||||
ID SettingID
|
||||
|
||||
// Val is the value.
|
||||
Val uint32
|
||||
}
|
||||
|
||||
func (s Setting) String() string {
|
||||
return fmt.Sprintf("[%v = %d]", s.ID, s.Val)
|
||||
}
|
||||
|
||||
// Valid reports whether the setting is valid.
|
||||
func (s Setting) Valid() error {
|
||||
// Limits and error codes from 6.5.2 Defined SETTINGS Parameters
|
||||
switch s.ID {
|
||||
case SettingEnablePush:
|
||||
if s.Val != 1 && s.Val != 0 {
|
||||
return ConnectionError(ErrCodeProtocol)
|
||||
}
|
||||
case SettingInitialWindowSize:
|
||||
if s.Val > 1<<31-1 {
|
||||
return ConnectionError(ErrCodeFlowControl)
|
||||
}
|
||||
case SettingMaxFrameSize:
|
||||
if s.Val < 16384 || s.Val > 1<<24-1 {
|
||||
return ConnectionError(ErrCodeProtocol)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A SettingID is an HTTP/2 setting as defined in
|
||||
// http://http2.github.io/http2-spec/#iana-settings
|
||||
type SettingID uint16
|
||||
|
||||
const (
|
||||
SettingHeaderTableSize SettingID = 0x1
|
||||
SettingEnablePush SettingID = 0x2
|
||||
SettingMaxConcurrentStreams SettingID = 0x3
|
||||
SettingInitialWindowSize SettingID = 0x4
|
||||
SettingMaxFrameSize SettingID = 0x5
|
||||
SettingMaxHeaderListSize SettingID = 0x6
|
||||
)
|
||||
|
||||
var settingName = map[SettingID]string{
|
||||
SettingHeaderTableSize: "HEADER_TABLE_SIZE",
|
||||
SettingEnablePush: "ENABLE_PUSH",
|
||||
SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
|
||||
SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
|
||||
SettingMaxFrameSize: "MAX_FRAME_SIZE",
|
||||
SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
|
||||
}
|
||||
|
||||
func (s SettingID) String() string {
|
||||
if v, ok := settingName[s]; ok {
|
||||
return v
|
||||
}
|
||||
return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
|
||||
}
|
||||
|
||||
func validHeader(v string) bool {
|
||||
if len(v) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range v {
|
||||
// "Just as in HTTP/1.x, header field names are
|
||||
// strings of ASCII characters that are compared in a
|
||||
// case-insensitive fashion. However, header field
|
||||
// names MUST be converted to lowercase prior to their
|
||||
// encoding in HTTP/2. "
|
||||
if r >= 127 || ('A' <= r && r <= 'Z') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n)
|
||||
|
||||
func init() {
|
||||
for i := 100; i <= 999; i++ {
|
||||
if v := http.StatusText(i); v != "" {
|
||||
httpCodeStringCommon[i] = strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func httpCodeString(code int) string {
|
||||
if s, ok := httpCodeStringCommon[code]; ok {
|
||||
return s
|
||||
}
|
||||
return strconv.Itoa(code)
|
||||
}
|
||||
|
||||
// from pkg io
|
||||
type stringWriter interface {
|
||||
WriteString(s string) (n int, err error)
|
||||
}
|
||||
|
||||
// A gate lets two goroutines coordinate their activities.
|
||||
type gate chan struct{}
|
||||
|
||||
func (g gate) Done() { g <- struct{}{} }
|
||||
func (g gate) Wait() { <-g }
|
||||
|
||||
// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
|
||||
type closeWaiter chan struct{}
|
||||
|
||||
// Init makes a closeWaiter usable.
|
||||
// It exists because so a closeWaiter value can be placed inside a
|
||||
// larger struct and have the Mutex and Cond's memory in the same
|
||||
// allocation.
|
||||
func (cw *closeWaiter) Init() {
|
||||
*cw = make(chan struct{})
|
||||
}
|
||||
|
||||
// Close marks the closeWaiter as closed and unblocks any waiters.
|
||||
func (cw closeWaiter) Close() {
|
||||
close(cw)
|
||||
}
|
||||
|
||||
// Wait waits for the closeWaiter to become closed.
|
||||
func (cw closeWaiter) Wait() {
|
||||
<-cw
|
||||
}
|
||||
|
||||
// bufferedWriter is a buffered writer that writes to w.
|
||||
// Its buffered writer is lazily allocated as needed, to minimize
|
||||
// idle memory usage with many connections.
|
||||
type bufferedWriter struct {
|
||||
w io.Writer // immutable
|
||||
bw *bufio.Writer // non-nil when data is buffered
|
||||
}
|
||||
|
||||
func newBufferedWriter(w io.Writer) *bufferedWriter {
|
||||
return &bufferedWriter{w: w}
|
||||
}
|
||||
|
||||
var bufWriterPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
// TODO: pick something better? this is a bit under
|
||||
// (3 x typical 1500 byte MTU) at least.
|
||||
return bufio.NewWriterSize(nil, 4<<10)
|
||||
},
|
||||
}
|
||||
|
||||
func (w *bufferedWriter) Write(p []byte) (n int, err error) {
|
||||
if w.bw == nil {
|
||||
bw := bufWriterPool.Get().(*bufio.Writer)
|
||||
bw.Reset(w.w)
|
||||
w.bw = bw
|
||||
}
|
||||
return w.bw.Write(p)
|
||||
}
|
||||
|
||||
func (w *bufferedWriter) Flush() error {
|
||||
bw := w.bw
|
||||
if bw == nil {
|
||||
return nil
|
||||
}
|
||||
err := bw.Flush()
|
||||
bw.Reset(nil)
|
||||
bufWriterPool.Put(bw)
|
||||
w.bw = nil
|
||||
return err
|
||||
}
|
@ -0,0 +1,152 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bradfitz/http2/hpack"
|
||||
)
|
||||
|
||||
var knownFailing = flag.Bool("known_failing", false, "Run known-failing tests.")
|
||||
|
||||
func condSkipFailingTest(t *testing.T) {
|
||||
if !*knownFailing {
|
||||
t.Skip("Skipping known-failing test without --known_failing")
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
DebugGoroutines = true
|
||||
flag.BoolVar(&VerboseLogs, "verboseh2", false, "Verbose HTTP/2 debug logging")
|
||||
}
|
||||
|
||||
func TestSettingString(t *testing.T) {
|
||||
tests := []struct {
|
||||
s Setting
|
||||
want string
|
||||
}{
|
||||
{Setting{SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"},
|
||||
{Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
got := fmt.Sprint(tt.s)
|
||||
if got != tt.want {
|
||||
t.Errorf("%d. for %#v, string = %q; want %q", i, tt.s, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type twriter struct {
|
||||
t testing.TB
|
||||
st *serverTester // optional
|
||||
}
|
||||
|
||||
func (w twriter) Write(p []byte) (n int, err error) {
|
||||
if w.st != nil {
|
||||
ps := string(p)
|
||||
for _, phrase := range w.st.logFilter {
|
||||
if strings.Contains(ps, phrase) {
|
||||
return len(p), nil // no logging
|
||||
}
|
||||
}
|
||||
}
|
||||
w.t.Logf("%s", p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// like encodeHeader, but don't add implicit psuedo headers.
|
||||
func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte {
|
||||
var buf bytes.Buffer
|
||||
enc := hpack.NewEncoder(&buf)
|
||||
for len(headers) > 0 {
|
||||
k, v := headers[0], headers[1]
|
||||
headers = headers[2:]
|
||||
if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil {
|
||||
t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// Verify that curl has http2.
|
||||
func requireCurl(t *testing.T) {
|
||||
out, err := dockerLogs(curl(t, "--version"))
|
||||
if err != nil {
|
||||
t.Skipf("failed to determine curl features; skipping test")
|
||||
}
|
||||
if !strings.Contains(string(out), "HTTP2") {
|
||||
t.Skip("curl doesn't support HTTP2; skipping test")
|
||||
}
|
||||
}
|
||||
|
||||
func curl(t *testing.T, args ...string) (container string) {
|
||||
out, err := exec.Command("docker", append([]string{"run", "-d", "--net=host", "gohttp2/curl"}, args...)...).CombinedOutput()
|
||||
if err != nil {
|
||||
t.Skipf("Failed to run curl in docker: %v, %s", err, out)
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
}
|
||||
|
||||
type puppetCommand struct {
|
||||
fn func(w http.ResponseWriter, r *http.Request)
|
||||
done chan<- bool
|
||||
}
|
||||
|
||||
type handlerPuppet struct {
|
||||
ch chan puppetCommand
|
||||
}
|
||||
|
||||
func newHandlerPuppet() *handlerPuppet {
|
||||
return &handlerPuppet{
|
||||
ch: make(chan puppetCommand),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *handlerPuppet) act(w http.ResponseWriter, r *http.Request) {
|
||||
for cmd := range p.ch {
|
||||
cmd.fn(w, r)
|
||||
cmd.done <- true
|
||||
}
|
||||
}
|
||||
|
||||
func (p *handlerPuppet) done() { close(p.ch) }
|
||||
func (p *handlerPuppet) do(fn func(http.ResponseWriter, *http.Request)) {
|
||||
done := make(chan bool)
|
||||
p.ch <- puppetCommand{fn, done}
|
||||
<-done
|
||||
}
|
||||
func dockerLogs(container string) ([]byte, error) {
|
||||
out, err := exec.Command("docker", "wait", container).CombinedOutput()
|
||||
if err != nil {
|
||||
return out, err
|
||||
}
|
||||
exitStatus, err := strconv.Atoi(strings.TrimSpace(string(out)))
|
||||
if err != nil {
|
||||
return out, errors.New("unexpected exit status from docker wait")
|
||||
}
|
||||
out, err = exec.Command("docker", "logs", container).CombinedOutput()
|
||||
exec.Command("docker", "rm", container).Run()
|
||||
if err == nil && exitStatus != 0 {
|
||||
err = fmt.Errorf("exit status %d: %s", exitStatus, out)
|
||||
}
|
||||
return out, err
|
||||
}
|
||||
|
||||
func kill(container string) {
|
||||
exec.Command("docker", "kill", container).Run()
|
||||
exec.Command("docker", "rm", container).Run()
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type pipe struct {
|
||||
b buffer
|
||||
c sync.Cond
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
// Read waits until data is available and copies bytes
|
||||
// from the buffer into p.
|
||||
func (r *pipe) Read(p []byte) (n int, err error) {
|
||||
r.c.L.Lock()
|
||||
defer r.c.L.Unlock()
|
||||
for r.b.Len() == 0 && !r.b.closed {
|
||||
r.c.Wait()
|
||||
}
|
||||
return r.b.Read(p)
|
||||
}
|
||||
|
||||
// Write copies bytes from p into the buffer and wakes a reader.
|
||||
// It is an error to write more data than the buffer can hold.
|
||||
func (w *pipe) Write(p []byte) (n int, err error) {
|
||||
w.c.L.Lock()
|
||||
defer w.c.L.Unlock()
|
||||
defer w.c.Signal()
|
||||
return w.b.Write(p)
|
||||
}
|
||||
|
||||
func (c *pipe) Close(err error) {
|
||||
c.c.L.Lock()
|
||||
defer c.c.L.Unlock()
|
||||
defer c.c.Signal()
|
||||
c.b.Close(err)
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPipeClose(t *testing.T) {
|
||||
var p pipe
|
||||
p.c.L = &p.m
|
||||
a := errors.New("a")
|
||||
b := errors.New("b")
|
||||
p.Close(a)
|
||||
p.Close(b)
|
||||
_, err := p.Read(make([]byte, 1))
|
||||
if err != a {
|
||||
t.Errorf("err = %v want %v", err, a)
|
||||
}
|
||||
}
|
@ -0,0 +1,121 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPriority(t *testing.T) {
|
||||
// A -> B
|
||||
// move A's parent to B
|
||||
streams := make(map[uint32]*stream)
|
||||
a := &stream{
|
||||
parent: nil,
|
||||
weight: 16,
|
||||
}
|
||||
streams[1] = a
|
||||
b := &stream{
|
||||
parent: a,
|
||||
weight: 16,
|
||||
}
|
||||
streams[2] = b
|
||||
adjustStreamPriority(streams, 1, PriorityParam{
|
||||
Weight: 20,
|
||||
StreamDep: 2,
|
||||
})
|
||||
if a.parent != b {
|
||||
t.Errorf("Expected A's parent to be B")
|
||||
}
|
||||
if a.weight != 20 {
|
||||
t.Errorf("Expected A's weight to be 20; got %d", a.weight)
|
||||
}
|
||||
if b.parent != nil {
|
||||
t.Errorf("Expected B to have no parent")
|
||||
}
|
||||
if b.weight != 16 {
|
||||
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPriorityExclusiveZero(t *testing.T) {
|
||||
// A B and C are all children of the 0 stream.
|
||||
// Exclusive reprioritization to any of the streams
|
||||
// should bring the rest of the streams under the
|
||||
// reprioritized stream
|
||||
streams := make(map[uint32]*stream)
|
||||
a := &stream{
|
||||
parent: nil,
|
||||
weight: 16,
|
||||
}
|
||||
streams[1] = a
|
||||
b := &stream{
|
||||
parent: nil,
|
||||
weight: 16,
|
||||
}
|
||||
streams[2] = b
|
||||
c := &stream{
|
||||
parent: nil,
|
||||
weight: 16,
|
||||
}
|
||||
streams[3] = c
|
||||
adjustStreamPriority(streams, 3, PriorityParam{
|
||||
Weight: 20,
|
||||
StreamDep: 0,
|
||||
Exclusive: true,
|
||||
})
|
||||
if a.parent != c {
|
||||
t.Errorf("Expected A's parent to be C")
|
||||
}
|
||||
if a.weight != 16 {
|
||||
t.Errorf("Expected A's weight to be 16; got %d", a.weight)
|
||||
}
|
||||
if b.parent != c {
|
||||
t.Errorf("Expected B's parent to be C")
|
||||
}
|
||||
if b.weight != 16 {
|
||||
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
|
||||
}
|
||||
if c.parent != nil {
|
||||
t.Errorf("Expected C to have no parent")
|
||||
}
|
||||
if c.weight != 20 {
|
||||
t.Errorf("Expected C's weight to be 20; got %d", b.weight)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPriorityOwnParent(t *testing.T) {
|
||||
streams := make(map[uint32]*stream)
|
||||
a := &stream{
|
||||
parent: nil,
|
||||
weight: 16,
|
||||
}
|
||||
streams[1] = a
|
||||
b := &stream{
|
||||
parent: a,
|
||||
weight: 16,
|
||||
}
|
||||
streams[2] = b
|
||||
adjustStreamPriority(streams, 1, PriorityParam{
|
||||
Weight: 20,
|
||||
StreamDep: 1,
|
||||
})
|
||||
if a.parent != nil {
|
||||
t.Errorf("Expected A's parent to be nil")
|
||||
}
|
||||
if a.weight != 20 {
|
||||
t.Errorf("Expected A's weight to be 20; got %d", a.weight)
|
||||
}
|
||||
if b.parent != a {
|
||||
t.Errorf("Expected B's parent to be A")
|
||||
}
|
||||
if b.weight != 16 {
|
||||
t.Errorf("Expected B's weight to be 16; got %d", b.weight)
|
||||
}
|
||||
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,553 @@
|
||||
// Copyright 2015 The Go Authors.
|
||||
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://go.googlesource.com/go/+/master/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bradfitz/http2/hpack"
|
||||
)
|
||||
|
||||
type Transport struct {
|
||||
Fallback http.RoundTripper
|
||||
|
||||
// TODO: remove this and make more general with a TLS dial hook, like http
|
||||
InsecureTLSDial bool
|
||||
|
||||
connMu sync.Mutex
|
||||
conns map[string][]*clientConn // key is host:port
|
||||
}
|
||||
|
||||
type clientConn struct {
|
||||
t *Transport
|
||||
tconn *tls.Conn
|
||||
tlsState *tls.ConnectionState
|
||||
connKey []string // key(s) this connection is cached in, in t.conns
|
||||
|
||||
readerDone chan struct{} // closed on error
|
||||
readerErr error // set before readerDone is closed
|
||||
hdec *hpack.Decoder
|
||||
nextRes *http.Response
|
||||
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
|
||||
streams map[uint32]*clientStream
|
||||
nextStreamID uint32
|
||||
bw *bufio.Writer
|
||||
werr error // first write error that has occurred
|
||||
br *bufio.Reader
|
||||
fr *Framer
|
||||
// Settings from peer:
|
||||
maxFrameSize uint32
|
||||
maxConcurrentStreams uint32
|
||||
initialWindowSize uint32
|
||||
hbuf bytes.Buffer // HPACK encoder writes into this
|
||||
henc *hpack.Encoder
|
||||
}
|
||||
|
||||
type clientStream struct {
|
||||
ID uint32
|
||||
resc chan resAndError
|
||||
pw *io.PipeWriter
|
||||
pr *io.PipeReader
|
||||
}
|
||||
|
||||
type stickyErrWriter struct {
|
||||
w io.Writer
|
||||
err *error
|
||||
}
|
||||
|
||||
func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
|
||||
if *sew.err != nil {
|
||||
return 0, *sew.err
|
||||
}
|
||||
n, err = sew.w.Write(p)
|
||||
*sew.err = err
|
||||
return
|
||||
}
|
||||
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Scheme != "https" {
|
||||
if t.Fallback == nil {
|
||||
return nil, errors.New("http2: unsupported scheme and no Fallback")
|
||||
}
|
||||
return t.Fallback.RoundTrip(req)
|
||||
}
|
||||
|
||||
host, port, err := net.SplitHostPort(req.URL.Host)
|
||||
if err != nil {
|
||||
host = req.URL.Host
|
||||
port = "443"
|
||||
}
|
||||
|
||||
for {
|
||||
cc, err := t.getClientConn(host, port)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := cc.roundTrip(req)
|
||||
if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)?
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CloseIdleConnections closes any connections which were previously
|
||||
// connected from previous requests but are now sitting idle.
|
||||
// It does not interrupt any connections currently in use.
|
||||
func (t *Transport) CloseIdleConnections() {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
for _, vv := range t.conns {
|
||||
for _, cc := range vv {
|
||||
cc.closeIfIdle()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var errClientConnClosed = errors.New("http2: client conn is closed")
|
||||
|
||||
func shouldRetryRequest(err error) bool {
|
||||
// TODO: or GOAWAY graceful shutdown stuff
|
||||
return err == errClientConnClosed
|
||||
}
|
||||
|
||||
func (t *Transport) removeClientConn(cc *clientConn) {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
for _, key := range cc.connKey {
|
||||
vv, ok := t.conns[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
newList := filterOutClientConn(vv, cc)
|
||||
if len(newList) > 0 {
|
||||
t.conns[key] = newList
|
||||
} else {
|
||||
delete(t.conns, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn {
|
||||
out := in[:0]
|
||||
for _, v := range in {
|
||||
if v != exclude {
|
||||
out = append(out, v)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Transport) getClientConn(host, port string) (*clientConn, error) {
|
||||
t.connMu.Lock()
|
||||
defer t.connMu.Unlock()
|
||||
|
||||
key := net.JoinHostPort(host, port)
|
||||
|
||||
for _, cc := range t.conns[key] {
|
||||
if cc.canTakeNewRequest() {
|
||||
return cc, nil
|
||||
}
|
||||
}
|
||||
if t.conns == nil {
|
||||
t.conns = make(map[string][]*clientConn)
|
||||
}
|
||||
cc, err := t.newClientConn(host, port, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.conns[key] = append(t.conns[key], cc)
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) {
|
||||
cfg := &tls.Config{
|
||||
ServerName: host,
|
||||
NextProtos: []string{NextProtoTLS},
|
||||
InsecureSkipVerify: t.InsecureTLSDial,
|
||||
}
|
||||
tconn, err := tls.Dial("tcp", host+":"+port, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tconn.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !t.InsecureTLSDial {
|
||||
if err := tconn.VerifyHostname(cfg.ServerName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
state := tconn.ConnectionState()
|
||||
if p := state.NegotiatedProtocol; p != NextProtoTLS {
|
||||
// TODO(bradfitz): fall back to Fallback
|
||||
return nil, fmt.Errorf("bad protocol: %v", p)
|
||||
}
|
||||
if !state.NegotiatedProtocolIsMutual {
|
||||
return nil, errors.New("could not negotiate protocol mutually")
|
||||
}
|
||||
if _, err := tconn.Write(clientPreface); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cc := &clientConn{
|
||||
t: t,
|
||||
tconn: tconn,
|
||||
connKey: []string{key}, // TODO: cert's validated hostnames too
|
||||
tlsState: &state,
|
||||
readerDone: make(chan struct{}),
|
||||
nextStreamID: 1,
|
||||
maxFrameSize: 16 << 10, // spec default
|
||||
initialWindowSize: 65535, // spec default
|
||||
maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough.
|
||||
streams: make(map[uint32]*clientStream),
|
||||
}
|
||||
cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr})
|
||||
cc.br = bufio.NewReader(tconn)
|
||||
cc.fr = NewFramer(cc.bw, cc.br)
|
||||
cc.henc = hpack.NewEncoder(&cc.hbuf)
|
||||
|
||||
cc.fr.WriteSettings()
|
||||
// TODO: re-send more conn-level flow control tokens when server uses all these.
|
||||
cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs?
|
||||
cc.bw.Flush()
|
||||
if cc.werr != nil {
|
||||
return nil, cc.werr
|
||||
}
|
||||
|
||||
// Read the obligatory SETTINGS frame
|
||||
f, err := cc.fr.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sf, ok := f.(*SettingsFrame)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected settings frame, got: %T", f)
|
||||
}
|
||||
cc.fr.WriteSettingsAck()
|
||||
cc.bw.Flush()
|
||||
|
||||
sf.ForeachSetting(func(s Setting) error {
|
||||
switch s.ID {
|
||||
case SettingMaxFrameSize:
|
||||
cc.maxFrameSize = s.Val
|
||||
case SettingMaxConcurrentStreams:
|
||||
cc.maxConcurrentStreams = s.Val
|
||||
case SettingInitialWindowSize:
|
||||
cc.initialWindowSize = s.Val
|
||||
default:
|
||||
// TODO(bradfitz): handle more
|
||||
log.Printf("Unhandled Setting: %v", s)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// TODO: figure out henc size
|
||||
cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField)
|
||||
|
||||
go cc.readLoop()
|
||||
return cc, nil
|
||||
}
|
||||
|
||||
func (cc *clientConn) setGoAway(f *GoAwayFrame) {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
cc.goAway = f
|
||||
}
|
||||
|
||||
func (cc *clientConn) canTakeNewRequest() bool {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
return cc.goAway == nil &&
|
||||
int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
|
||||
cc.nextStreamID < 2147483647
|
||||
}
|
||||
|
||||
func (cc *clientConn) closeIfIdle() {
|
||||
cc.mu.Lock()
|
||||
if len(cc.streams) > 0 {
|
||||
cc.mu.Unlock()
|
||||
return
|
||||
}
|
||||
cc.closed = true
|
||||
// TODO: do clients send GOAWAY too? maybe? Just Close:
|
||||
cc.mu.Unlock()
|
||||
|
||||
cc.tconn.Close()
|
||||
}
|
||||
|
||||
func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) {
|
||||
cc.mu.Lock()
|
||||
|
||||
if cc.closed {
|
||||
cc.mu.Unlock()
|
||||
return nil, errClientConnClosed
|
||||
}
|
||||
|
||||
cs := cc.newStream()
|
||||
hasBody := false // TODO
|
||||
|
||||
// we send: HEADERS[+CONTINUATION] + (DATA?)
|
||||
hdrs := cc.encodeHeaders(req)
|
||||
first := true
|
||||
for len(hdrs) > 0 {
|
||||
chunk := hdrs
|
||||
if len(chunk) > int(cc.maxFrameSize) {
|
||||
chunk = chunk[:cc.maxFrameSize]
|
||||
}
|
||||
hdrs = hdrs[len(chunk):]
|
||||
endHeaders := len(hdrs) == 0
|
||||
if first {
|
||||
cc.fr.WriteHeaders(HeadersFrameParam{
|
||||
StreamID: cs.ID,
|
||||
BlockFragment: chunk,
|
||||
EndStream: !hasBody,
|
||||
EndHeaders: endHeaders,
|
||||
})
|
||||
first = false
|
||||
} else {
|
||||
cc.fr.WriteContinuation(cs.ID, endHeaders, chunk)
|
||||
}
|
||||
}
|
||||
cc.bw.Flush()
|
||||
werr := cc.werr
|
||||
cc.mu.Unlock()
|
||||
|
||||
if hasBody {
|
||||
// TODO: write data. and it should probably be interleaved:
|
||||
// go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc
|
||||
}
|
||||
|
||||
if werr != nil {
|
||||
return nil, werr
|
||||
}
|
||||
|
||||
re := <-cs.resc
|
||||
if re.err != nil {
|
||||
return nil, re.err
|
||||
}
|
||||
res := re.res
|
||||
res.Request = req
|
||||
res.TLS = cc.tlsState
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// requires cc.mu be held.
|
||||
func (cc *clientConn) encodeHeaders(req *http.Request) []byte {
|
||||
cc.hbuf.Reset()
|
||||
|
||||
// TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
host = req.URL.Host
|
||||
}
|
||||
|
||||
path := req.URL.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
cc.writeHeader(":authority", host) // probably not right for all sites
|
||||
cc.writeHeader(":method", req.Method)
|
||||
cc.writeHeader(":path", path)
|
||||
cc.writeHeader(":scheme", "https")
|
||||
|
||||
for k, vv := range req.Header {
|
||||
lowKey := strings.ToLower(k)
|
||||
if lowKey == "host" {
|
||||
continue
|
||||
}
|
||||
for _, v := range vv {
|
||||
cc.writeHeader(lowKey, v)
|
||||
}
|
||||
}
|
||||
return cc.hbuf.Bytes()
|
||||
}
|
||||
|
||||
func (cc *clientConn) writeHeader(name, value string) {
|
||||
log.Printf("sending %q = %q", name, value)
|
||||
cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
|
||||
}
|
||||
|
||||
type resAndError struct {
|
||||
res *http.Response
|
||||
err error
|
||||
}
|
||||
|
||||
// requires cc.mu be held.
|
||||
func (cc *clientConn) newStream() *clientStream {
|
||||
cs := &clientStream{
|
||||
ID: cc.nextStreamID,
|
||||
resc: make(chan resAndError, 1),
|
||||
}
|
||||
cc.nextStreamID += 2
|
||||
cc.streams[cs.ID] = cs
|
||||
return cs
|
||||
}
|
||||
|
||||
func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream {
|
||||
cc.mu.Lock()
|
||||
defer cc.mu.Unlock()
|
||||
cs := cc.streams[id]
|
||||
if andRemove {
|
||||
delete(cc.streams, id)
|
||||
}
|
||||
return cs
|
||||
}
|
||||
|
||||
// runs in its own goroutine.
|
||||
func (cc *clientConn) readLoop() {
|
||||
defer cc.t.removeClientConn(cc)
|
||||
defer close(cc.readerDone)
|
||||
|
||||
activeRes := map[uint32]*clientStream{} // keyed by streamID
|
||||
// Close any response bodies if the server closes prematurely.
|
||||
// TODO: also do this if we've written the headers but not
|
||||
// gotten a response yet.
|
||||
defer func() {
|
||||
err := cc.readerErr
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
for _, cs := range activeRes {
|
||||
cs.pw.CloseWithError(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// continueStreamID is the stream ID we're waiting for
|
||||
// continuation frames for.
|
||||
var continueStreamID uint32
|
||||
|
||||
for {
|
||||
f, err := cc.fr.ReadFrame()
|
||||
if err != nil {
|
||||
cc.readerErr = err
|
||||
return
|
||||
}
|
||||
log.Printf("Transport received %v: %#v", f.Header(), f)
|
||||
|
||||
streamID := f.Header().StreamID
|
||||
|
||||
_, isContinue := f.(*ContinuationFrame)
|
||||
if isContinue {
|
||||
if streamID != continueStreamID {
|
||||
log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID)
|
||||
cc.readerErr = ConnectionError(ErrCodeProtocol)
|
||||
return
|
||||
}
|
||||
} else if continueStreamID != 0 {
|
||||
// Continue frames need to be adjacent in the stream
|
||||
// and we were in the middle of headers.
|
||||
log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID)
|
||||
cc.readerErr = ConnectionError(ErrCodeProtocol)
|
||||
return
|
||||
}
|
||||
|
||||
if streamID%2 == 0 {
|
||||
// Ignore streams pushed from the server for now.
|
||||
// These always have an even stream id.
|
||||
continue
|
||||
}
|
||||
streamEnded := false
|
||||
if ff, ok := f.(streamEnder); ok {
|
||||
streamEnded = ff.StreamEnded()
|
||||
}
|
||||
|
||||
cs := cc.streamByID(streamID, streamEnded)
|
||||
if cs == nil {
|
||||
log.Printf("Received frame for untracked stream ID %d", streamID)
|
||||
continue
|
||||
}
|
||||
|
||||
switch f := f.(type) {
|
||||
case *HeadersFrame:
|
||||
cc.nextRes = &http.Response{
|
||||
Proto: "HTTP/2.0",
|
||||
ProtoMajor: 2,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
cs.pr, cs.pw = io.Pipe()
|
||||
cc.hdec.Write(f.HeaderBlockFragment())
|
||||
case *ContinuationFrame:
|
||||
cc.hdec.Write(f.HeaderBlockFragment())
|
||||
case *DataFrame:
|
||||
log.Printf("DATA: %q", f.Data())
|
||||
cs.pw.Write(f.Data())
|
||||
case *GoAwayFrame:
|
||||
cc.t.removeClientConn(cc)
|
||||
if f.ErrCode != 0 {
|
||||
// TODO: deal with GOAWAY more. particularly the error code
|
||||
log.Printf("transport got GOAWAY with error code = %v", f.ErrCode)
|
||||
}
|
||||
cc.setGoAway(f)
|
||||
default:
|
||||
log.Printf("Transport: unhandled response frame type %T", f)
|
||||
}
|
||||
headersEnded := false
|
||||
if he, ok := f.(headersEnder); ok {
|
||||
headersEnded = he.HeadersEnded()
|
||||
if headersEnded {
|
||||
continueStreamID = 0
|
||||
} else {
|
||||
continueStreamID = streamID
|
||||
}
|
||||
}
|
||||
|
||||
if streamEnded {
|
||||
cs.pw.Close()
|
||||
delete(activeRes, streamID)
|
||||
}
|
||||
if headersEnded {
|
||||
if cs == nil {
|
||||
panic("couldn't find stream") // TODO be graceful
|
||||
}
|
||||
// TODO: set the Body to one which notes the
|
||||
// Close and also sends the server a
|
||||
// RST_STREAM
|
||||
cc.nextRes.Body = cs.pr
|
||||
res := cc.nextRes
|
||||
activeRes[streamID] = cs
|
||||
cs.resc <- resAndError{res: res}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) {
|
||||
// TODO: verifiy pseudo headers come before non-pseudo headers
|
||||
// TODO: verifiy the status is set
|
||||
log.Printf("Header field: %+v", f)
|
||||
if f.Name == ":status" {
|
||||
code, err := strconv.Atoi(f.Value)
|
||||
if err != nil {
|
||||
panic("TODO: be graceful")
|
||||
}
|
||||
cc.nextRes.Status = f.Value + " " + http.StatusText(code)
|
||||
cc.nextRes.StatusCode = code
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(f.Name, ":") {
|
||||
// "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document."
|
||||
// TODO: treat as invalid?
|
||||
return
|
||||
}
|
||||
cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value)
|
||||
}
|
@ -0,0 +1,168 @@
|
||||
// Copyright 2015 The Go Authors.
|
||||
// See https://go.googlesource.com/go/+/master/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://go.googlesource.com/go/+/master/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
extNet = flag.Bool("extnet", false, "do external network tests")
|
||||
transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
|
||||
insecure = flag.Bool("insecure", false, "insecure TLS dials")
|
||||
)
|
||||
|
||||
func TestTransportExternal(t *testing.T) {
|
||||
if !*extNet {
|
||||
t.Skip("skipping external network test")
|
||||
}
|
||||
req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
|
||||
rt := &Transport{
|
||||
InsecureTLSDial: *insecure,
|
||||
}
|
||||
res, err := rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
res.Write(os.Stdout)
|
||||
}
|
||||
|
||||
func TestTransport(t *testing.T) {
|
||||
const body = "sup"
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, body)
|
||||
})
|
||||
defer st.Close()
|
||||
|
||||
tr := &Transport{InsecureTLSDial: true}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
req, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
t.Logf("Got res: %+v", res)
|
||||
if g, w := res.StatusCode, 200; g != w {
|
||||
t.Errorf("StatusCode = %v; want %v", g, w)
|
||||
}
|
||||
if g, w := res.Status, "200 OK"; g != w {
|
||||
t.Errorf("Status = %q; want %q", g, w)
|
||||
}
|
||||
wantHeader := http.Header{
|
||||
"Content-Length": []string{"3"},
|
||||
"Content-Type": []string{"text/plain; charset=utf-8"},
|
||||
}
|
||||
if !reflect.DeepEqual(res.Header, wantHeader) {
|
||||
t.Errorf("res Header = %v; want %v", res.Header, wantHeader)
|
||||
}
|
||||
if res.Request != req {
|
||||
t.Errorf("Response.Request = %p; want %p", res.Request, req)
|
||||
}
|
||||
if res.TLS == nil {
|
||||
t.Error("Response.TLS = nil; want non-nil")
|
||||
}
|
||||
slurp, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Body read: %v", err)
|
||||
} else if string(slurp) != body {
|
||||
t.Errorf("Body = %q; want %q", slurp, body)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestTransportReusesConns(t *testing.T) {
|
||||
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
io.WriteString(w, r.RemoteAddr)
|
||||
}, optOnlyServer)
|
||||
defer st.Close()
|
||||
tr := &Transport{InsecureTLSDial: true}
|
||||
defer tr.CloseIdleConnections()
|
||||
get := func() string {
|
||||
req, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
slurp, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Body read: %v", err)
|
||||
}
|
||||
addr := strings.TrimSpace(string(slurp))
|
||||
if addr == "" {
|
||||
t.Fatalf("didn't get an addr in response")
|
||||
}
|
||||
return addr
|
||||
}
|
||||
first := get()
|
||||
second := get()
|
||||
if first != second {
|
||||
t.Errorf("first and second responses were on different connections: %q vs %q", first, second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportAbortClosesPipes(t *testing.T) {
|
||||
shutdown := make(chan struct{})
|
||||
st := newServerTester(t,
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.(http.Flusher).Flush()
|
||||
<-shutdown
|
||||
},
|
||||
optOnlyServer,
|
||||
)
|
||||
defer st.Close()
|
||||
defer close(shutdown) // we must shutdown before st.Close() to avoid hanging
|
||||
|
||||
done := make(chan struct{})
|
||||
requestMade := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
tr := &Transport{
|
||||
InsecureTLSDial: true,
|
||||
}
|
||||
req, err := http.NewRequest("GET", st.ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
close(requestMade)
|
||||
_, err = ioutil.ReadAll(res.Body)
|
||||
if err == nil {
|
||||
t.Error("expected error from res.Body.Read")
|
||||
}
|
||||
}()
|
||||
|
||||
<-requestMade
|
||||
// Now force the serve loop to end, via closing the connection.
|
||||
st.closeConn()
|
||||
// deadlock? that's a bug.
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
@ -0,0 +1,204 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bradfitz/http2/hpack"
|
||||
)
|
||||
|
||||
// writeFramer is implemented by any type that is used to write frames.
|
||||
type writeFramer interface {
|
||||
writeFrame(writeContext) error
|
||||
}
|
||||
|
||||
// writeContext is the interface needed by the various frame writer
|
||||
// types below. All the writeFrame methods below are scheduled via the
|
||||
// frame writing scheduler (see writeScheduler in writesched.go).
|
||||
//
|
||||
// This interface is implemented by *serverConn.
|
||||
// TODO: use it from the client code too, once it exists.
|
||||
type writeContext interface {
|
||||
Framer() *Framer
|
||||
Flush() error
|
||||
CloseConn() error
|
||||
// HeaderEncoder returns an HPACK encoder that writes to the
|
||||
// returned buffer.
|
||||
HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
|
||||
}
|
||||
|
||||
// endsStream reports whether the given frame writer w will locally
|
||||
// close the stream.
|
||||
func endsStream(w writeFramer) bool {
|
||||
switch v := w.(type) {
|
||||
case *writeData:
|
||||
return v.endStream
|
||||
case *writeResHeaders:
|
||||
return v.endStream
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type flushFrameWriter struct{}
|
||||
|
||||
func (flushFrameWriter) writeFrame(ctx writeContext) error {
|
||||
return ctx.Flush()
|
||||
}
|
||||
|
||||
type writeSettings []Setting
|
||||
|
||||
func (s writeSettings) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteSettings([]Setting(s)...)
|
||||
}
|
||||
|
||||
type writeGoAway struct {
|
||||
maxStreamID uint32
|
||||
code ErrCode
|
||||
}
|
||||
|
||||
func (p *writeGoAway) writeFrame(ctx writeContext) error {
|
||||
err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
|
||||
if p.code != 0 {
|
||||
ctx.Flush() // ignore error: we're hanging up on them anyway
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
ctx.CloseConn()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type writeData struct {
|
||||
streamID uint32
|
||||
p []byte
|
||||
endStream bool
|
||||
}
|
||||
|
||||
func (w *writeData) String() string {
|
||||
return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream)
|
||||
}
|
||||
|
||||
func (w *writeData) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
|
||||
}
|
||||
|
||||
func (se StreamError) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
|
||||
}
|
||||
|
||||
type writePingAck struct{ pf *PingFrame }
|
||||
|
||||
func (w writePingAck) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WritePing(true, w.pf.Data)
|
||||
}
|
||||
|
||||
type writeSettingsAck struct{}
|
||||
|
||||
func (writeSettingsAck) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteSettingsAck()
|
||||
}
|
||||
|
||||
// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
|
||||
// for HTTP response headers from a server handler.
|
||||
type writeResHeaders struct {
|
||||
streamID uint32
|
||||
httpResCode int
|
||||
h http.Header // may be nil
|
||||
endStream bool
|
||||
|
||||
contentType string
|
||||
contentLength string
|
||||
}
|
||||
|
||||
func (w *writeResHeaders) writeFrame(ctx writeContext) error {
|
||||
enc, buf := ctx.HeaderEncoder()
|
||||
buf.Reset()
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(w.httpResCode)})
|
||||
for k, vv := range w.h {
|
||||
k = lowerHeader(k)
|
||||
for _, v := range vv {
|
||||
// TODO: more of "8.1.2.2 Connection-Specific Header Fields"
|
||||
if k == "transfer-encoding" && v != "trailers" {
|
||||
continue
|
||||
}
|
||||
enc.WriteField(hpack.HeaderField{Name: k, Value: v})
|
||||
}
|
||||
}
|
||||
if w.contentType != "" {
|
||||
enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType})
|
||||
}
|
||||
if w.contentLength != "" {
|
||||
enc.WriteField(hpack.HeaderField{Name: "content-length", Value: w.contentLength})
|
||||
}
|
||||
|
||||
headerBlock := buf.Bytes()
|
||||
if len(headerBlock) == 0 {
|
||||
panic("unexpected empty hpack")
|
||||
}
|
||||
|
||||
// For now we're lazy and just pick the minimum MAX_FRAME_SIZE
|
||||
// that all peers must support (16KB). Later we could care
|
||||
// more and send larger frames if the peer advertised it, but
|
||||
// there's little point. Most headers are small anyway (so we
|
||||
// generally won't have CONTINUATION frames), and extra frames
|
||||
// only waste 9 bytes anyway.
|
||||
const maxFrameSize = 16384
|
||||
|
||||
first := true
|
||||
for len(headerBlock) > 0 {
|
||||
frag := headerBlock
|
||||
if len(frag) > maxFrameSize {
|
||||
frag = frag[:maxFrameSize]
|
||||
}
|
||||
headerBlock = headerBlock[len(frag):]
|
||||
endHeaders := len(headerBlock) == 0
|
||||
var err error
|
||||
if first {
|
||||
first = false
|
||||
err = ctx.Framer().WriteHeaders(HeadersFrameParam{
|
||||
StreamID: w.streamID,
|
||||
BlockFragment: frag,
|
||||
EndStream: w.endStream,
|
||||
EndHeaders: endHeaders,
|
||||
})
|
||||
} else {
|
||||
err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type write100ContinueHeadersFrame struct {
|
||||
streamID uint32
|
||||
}
|
||||
|
||||
func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error {
|
||||
enc, buf := ctx.HeaderEncoder()
|
||||
buf.Reset()
|
||||
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"})
|
||||
return ctx.Framer().WriteHeaders(HeadersFrameParam{
|
||||
StreamID: w.streamID,
|
||||
BlockFragment: buf.Bytes(),
|
||||
EndStream: false,
|
||||
EndHeaders: true,
|
||||
})
|
||||
}
|
||||
|
||||
type writeWindowUpdate struct {
|
||||
streamID uint32 // or 0 for conn-level
|
||||
n uint32
|
||||
}
|
||||
|
||||
func (wu writeWindowUpdate) writeFrame(ctx writeContext) error {
|
||||
return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
|
||||
}
|
@ -0,0 +1,286 @@
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import "fmt"
|
||||
|
||||
// frameWriteMsg is a request to write a frame.
|
||||
type frameWriteMsg struct {
|
||||
// write is the interface value that does the writing, once the
|
||||
// writeScheduler (below) has decided to select this frame
|
||||
// to write. The write functions are all defined in write.go.
|
||||
write writeFramer
|
||||
|
||||
stream *stream // used for prioritization. nil for non-stream frames.
|
||||
|
||||
// done, if non-nil, must be a buffered channel with space for
|
||||
// 1 message and is sent the return value from write (or an
|
||||
// earlier error) when the frame has been written.
|
||||
done chan error
|
||||
}
|
||||
|
||||
// for debugging only:
|
||||
func (wm frameWriteMsg) String() string {
|
||||
var streamID uint32
|
||||
if wm.stream != nil {
|
||||
streamID = wm.stream.id
|
||||
}
|
||||
var des string
|
||||
if s, ok := wm.write.(fmt.Stringer); ok {
|
||||
des = s.String()
|
||||
} else {
|
||||
des = fmt.Sprintf("%T", wm.write)
|
||||
}
|
||||
return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des)
|
||||
}
|
||||
|
||||
// writeScheduler tracks pending frames to write, priorities, and decides
|
||||
// the next one to use. It is not thread-safe.
|
||||
type writeScheduler struct {
|
||||
// zero are frames not associated with a specific stream.
|
||||
// They're sent before any stream-specific freams.
|
||||
zero writeQueue
|
||||
|
||||
// maxFrameSize is the maximum size of a DATA frame
|
||||
// we'll write. Must be non-zero and between 16K-16M.
|
||||
maxFrameSize uint32
|
||||
|
||||
// sq contains the stream-specific queues, keyed by stream ID.
|
||||
// when a stream is idle, it's deleted from the map.
|
||||
sq map[uint32]*writeQueue
|
||||
|
||||
// canSend is a slice of memory that's reused between frame
|
||||
// scheduling decisions to hold the list of writeQueues (from sq)
|
||||
// which have enough flow control data to send. After canSend is
|
||||
// built, the best is selected.
|
||||
canSend []*writeQueue
|
||||
|
||||
// pool of empty queues for reuse.
|
||||
queuePool []*writeQueue
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) putEmptyQueue(q *writeQueue) {
|
||||
if len(q.s) != 0 {
|
||||
panic("queue must be empty")
|
||||
}
|
||||
ws.queuePool = append(ws.queuePool, q)
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) getEmptyQueue() *writeQueue {
|
||||
ln := len(ws.queuePool)
|
||||
if ln == 0 {
|
||||
return new(writeQueue)
|
||||
}
|
||||
q := ws.queuePool[ln-1]
|
||||
ws.queuePool = ws.queuePool[:ln-1]
|
||||
return q
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 }
|
||||
|
||||
func (ws *writeScheduler) add(wm frameWriteMsg) {
|
||||
st := wm.stream
|
||||
if st == nil {
|
||||
ws.zero.push(wm)
|
||||
} else {
|
||||
ws.streamQueue(st.id).push(wm)
|
||||
}
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) streamQueue(streamID uint32) *writeQueue {
|
||||
if q, ok := ws.sq[streamID]; ok {
|
||||
return q
|
||||
}
|
||||
if ws.sq == nil {
|
||||
ws.sq = make(map[uint32]*writeQueue)
|
||||
}
|
||||
q := ws.getEmptyQueue()
|
||||
ws.sq[streamID] = q
|
||||
return q
|
||||
}
|
||||
|
||||
// take returns the most important frame to write and removes it from the scheduler.
|
||||
// It is illegal to call this if the scheduler is empty or if there are no connection-level
|
||||
// flow control bytes available.
|
||||
func (ws *writeScheduler) take() (wm frameWriteMsg, ok bool) {
|
||||
if ws.maxFrameSize == 0 {
|
||||
panic("internal error: ws.maxFrameSize not initialized or invalid")
|
||||
}
|
||||
|
||||
// If there any frames not associated with streams, prefer those first.
|
||||
// These are usually SETTINGS, etc.
|
||||
if !ws.zero.empty() {
|
||||
return ws.zero.shift(), true
|
||||
}
|
||||
if len(ws.sq) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Next, prioritize frames on streams that aren't DATA frames (no cost).
|
||||
for id, q := range ws.sq {
|
||||
if q.firstIsNoCost() {
|
||||
return ws.takeFrom(id, q)
|
||||
}
|
||||
}
|
||||
|
||||
// Now, all that remains are DATA frames with non-zero bytes to
|
||||
// send. So pick the best one.
|
||||
if len(ws.canSend) != 0 {
|
||||
panic("should be empty")
|
||||
}
|
||||
for _, q := range ws.sq {
|
||||
if n := ws.streamWritableBytes(q); n > 0 {
|
||||
ws.canSend = append(ws.canSend, q)
|
||||
}
|
||||
}
|
||||
if len(ws.canSend) == 0 {
|
||||
return
|
||||
}
|
||||
defer ws.zeroCanSend()
|
||||
|
||||
// TODO: find the best queue
|
||||
q := ws.canSend[0]
|
||||
|
||||
return ws.takeFrom(q.streamID(), q)
|
||||
}
|
||||
|
||||
// zeroCanSend is defered from take.
|
||||
func (ws *writeScheduler) zeroCanSend() {
|
||||
for i := range ws.canSend {
|
||||
ws.canSend[i] = nil
|
||||
}
|
||||
ws.canSend = ws.canSend[:0]
|
||||
}
|
||||
|
||||
// streamWritableBytes returns the number of DATA bytes we could write
|
||||
// from the given queue's stream, if this stream/queue were
|
||||
// selected. It is an error to call this if q's head isn't a
|
||||
// *writeData.
|
||||
func (ws *writeScheduler) streamWritableBytes(q *writeQueue) int32 {
|
||||
wm := q.head()
|
||||
ret := wm.stream.flow.available() // max we can write
|
||||
if ret == 0 {
|
||||
return 0
|
||||
}
|
||||
if int32(ws.maxFrameSize) < ret {
|
||||
ret = int32(ws.maxFrameSize)
|
||||
}
|
||||
if ret == 0 {
|
||||
panic("internal error: ws.maxFrameSize not initialized or invalid")
|
||||
}
|
||||
wd := wm.write.(*writeData)
|
||||
if len(wd.p) < int(ret) {
|
||||
ret = int32(len(wd.p))
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) takeFrom(id uint32, q *writeQueue) (wm frameWriteMsg, ok bool) {
|
||||
wm = q.head()
|
||||
// If the first item in this queue costs flow control tokens
|
||||
// and we don't have enough, write as much as we can.
|
||||
if wd, ok := wm.write.(*writeData); ok && len(wd.p) > 0 {
|
||||
allowed := wm.stream.flow.available() // max we can write
|
||||
if allowed == 0 {
|
||||
// No quota available. Caller can try the next stream.
|
||||
return frameWriteMsg{}, false
|
||||
}
|
||||
if int32(ws.maxFrameSize) < allowed {
|
||||
allowed = int32(ws.maxFrameSize)
|
||||
}
|
||||
// TODO: further restrict the allowed size, because even if
|
||||
// the peer says it's okay to write 16MB data frames, we might
|
||||
// want to write smaller ones to properly weight competing
|
||||
// streams' priorities.
|
||||
|
||||
if len(wd.p) > int(allowed) {
|
||||
wm.stream.flow.take(allowed)
|
||||
chunk := wd.p[:allowed]
|
||||
wd.p = wd.p[allowed:]
|
||||
// Make up a new write message of a valid size, rather
|
||||
// than shifting one off the queue.
|
||||
return frameWriteMsg{
|
||||
stream: wm.stream,
|
||||
write: &writeData{
|
||||
streamID: wd.streamID,
|
||||
p: chunk,
|
||||
// even if the original had endStream set, there
|
||||
// arebytes remaining because len(wd.p) > allowed,
|
||||
// so we know endStream is false:
|
||||
endStream: false,
|
||||
},
|
||||
// our caller is blocking on the final DATA frame, not
|
||||
// these intermediates, so no need to wait:
|
||||
done: nil,
|
||||
}, true
|
||||
}
|
||||
wm.stream.flow.take(int32(len(wd.p)))
|
||||
}
|
||||
|
||||
q.shift()
|
||||
if q.empty() {
|
||||
ws.putEmptyQueue(q)
|
||||
delete(ws.sq, id)
|
||||
}
|
||||
return wm, true
|
||||
}
|
||||
|
||||
func (ws *writeScheduler) forgetStream(id uint32) {
|
||||
q, ok := ws.sq[id]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(ws.sq, id)
|
||||
|
||||
// But keep it for others later.
|
||||
for i := range q.s {
|
||||
q.s[i] = frameWriteMsg{}
|
||||
}
|
||||
q.s = q.s[:0]
|
||||
ws.putEmptyQueue(q)
|
||||
}
|
||||
|
||||
type writeQueue struct {
|
||||
s []frameWriteMsg
|
||||
}
|
||||
|
||||
// streamID returns the stream ID for a non-empty stream-specific queue.
|
||||
func (q *writeQueue) streamID() uint32 { return q.s[0].stream.id }
|
||||
|
||||
func (q *writeQueue) empty() bool { return len(q.s) == 0 }
|
||||
|
||||
func (q *writeQueue) push(wm frameWriteMsg) {
|
||||
q.s = append(q.s, wm)
|
||||
}
|
||||
|
||||
// head returns the next item that would be removed by shift.
|
||||
func (q *writeQueue) head() frameWriteMsg {
|
||||
if len(q.s) == 0 {
|
||||
panic("invalid use of queue")
|
||||
}
|
||||
return q.s[0]
|
||||
}
|
||||
|
||||
func (q *writeQueue) shift() frameWriteMsg {
|
||||
if len(q.s) == 0 {
|
||||
panic("invalid use of queue")
|
||||
}
|
||||
wm := q.s[0]
|
||||
// TODO: less copy-happy queue.
|
||||
copy(q.s, q.s[1:])
|
||||
q.s[len(q.s)-1] = frameWriteMsg{}
|
||||
q.s = q.s[:len(q.s)-1]
|
||||
return wm
|
||||
}
|
||||
|
||||
func (q *writeQueue) firstIsNoCost() bool {
|
||||
if df, ok := q.s[0].write.(*writeData); ok {
|
||||
return len(df.p) == 0
|
||||
}
|
||||
return true
|
||||
}
|
@ -0,0 +1,357 @@
|
||||
// Copyright 2014 The Go Authors.
|
||||
// See https://code.google.com/p/go/source/browse/CONTRIBUTORS
|
||||
// Licensed under the same terms as Go itself:
|
||||
// https://code.google.com/p/go/source/browse/LICENSE
|
||||
|
||||
package http2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/xml"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var coverSpec = flag.Bool("coverspec", false, "Run spec coverage tests")
|
||||
|
||||
// The global map of sentence coverage for the http2 spec.
|
||||
var defaultSpecCoverage specCoverage
|
||||
|
||||
var loadSpecOnce sync.Once
|
||||
|
||||
func loadSpec() {
|
||||
if f, err := os.Open("testdata/draft-ietf-httpbis-http2.xml"); err != nil {
|
||||
panic(err)
|
||||
} else {
|
||||
defaultSpecCoverage = readSpecCov(f)
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// covers marks all sentences for section sec in defaultSpecCoverage. Sentences not
|
||||
// "covered" will be included in report outputed by TestSpecCoverage.
|
||||
func covers(sec, sentences string) {
|
||||
loadSpecOnce.Do(loadSpec)
|
||||
defaultSpecCoverage.cover(sec, sentences)
|
||||
}
|
||||
|
||||
type specPart struct {
|
||||
section string
|
||||
sentence string
|
||||
}
|
||||
|
||||
func (ss specPart) Less(oo specPart) bool {
|
||||
atoi := func(s string) int {
|
||||
n, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
a := strings.Split(ss.section, ".")
|
||||
b := strings.Split(oo.section, ".")
|
||||
for len(a) > 0 {
|
||||
if len(b) == 0 {
|
||||
return false
|
||||
}
|
||||
x, y := atoi(a[0]), atoi(b[0])
|
||||
if x == y {
|
||||
a, b = a[1:], b[1:]
|
||||
continue
|
||||
}
|
||||
return x < y
|
||||
}
|
||||
if len(b) > 0 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type bySpecSection []specPart
|
||||
|
||||
func (a bySpecSection) Len() int { return len(a) }
|
||||
func (a bySpecSection) Less(i, j int) bool { return a[i].Less(a[j]) }
|
||||
func (a bySpecSection) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
|
||||
type specCoverage struct {
|
||||
coverage map[specPart]bool
|
||||
d *xml.Decoder
|
||||
}
|
||||
|
||||
func joinSection(sec []int) string {
|
||||
s := fmt.Sprintf("%d", sec[0])
|
||||
for _, n := range sec[1:] {
|
||||
s = fmt.Sprintf("%s.%d", s, n)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (sc specCoverage) readSection(sec []int) {
|
||||
var (
|
||||
buf = new(bytes.Buffer)
|
||||
sub = 0
|
||||
)
|
||||
for {
|
||||
tk, err := sc.d.Token()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
panic(err)
|
||||
}
|
||||
switch v := tk.(type) {
|
||||
case xml.StartElement:
|
||||
if skipElement(v) {
|
||||
if err := sc.d.Skip(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if v.Name.Local == "section" {
|
||||
sub++
|
||||
}
|
||||
break
|
||||
}
|
||||
switch v.Name.Local {
|
||||
case "section":
|
||||
sub++
|
||||
sc.readSection(append(sec, sub))
|
||||
case "xref":
|
||||
buf.Write(sc.readXRef(v))
|
||||
}
|
||||
case xml.CharData:
|
||||
if len(sec) == 0 {
|
||||
break
|
||||
}
|
||||
buf.Write(v)
|
||||
case xml.EndElement:
|
||||
if v.Name.Local == "section" {
|
||||
sc.addSentences(joinSection(sec), buf.String())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sc specCoverage) readXRef(se xml.StartElement) []byte {
|
||||
var b []byte
|
||||
for {
|
||||
tk, err := sc.d.Token()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
switch v := tk.(type) {
|
||||
case xml.CharData:
|
||||
if b != nil {
|
||||
panic("unexpected CharData")
|
||||
}
|
||||
b = []byte(string(v))
|
||||
case xml.EndElement:
|
||||
if v.Name.Local != "xref" {
|
||||
panic("expected </xref>")
|
||||
}
|
||||
if b != nil {
|
||||
return b
|
||||
}
|
||||
sig := attrSig(se)
|
||||
switch sig {
|
||||
case "target":
|
||||
return []byte(fmt.Sprintf("[%s]", attrValue(se, "target")))
|
||||
case "fmt-of,rel,target", "fmt-,,rel,target":
|
||||
return []byte(fmt.Sprintf("[%s, %s]", attrValue(se, "target"), attrValue(se, "rel")))
|
||||
case "fmt-of,sec,target", "fmt-,,sec,target":
|
||||
return []byte(fmt.Sprintf("[section %s of %s]", attrValue(se, "sec"), attrValue(se, "target")))
|
||||
case "fmt-of,rel,sec,target":
|
||||
return []byte(fmt.Sprintf("[section %s of %s, %s]", attrValue(se, "sec"), attrValue(se, "target"), attrValue(se, "rel")))
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown attribute signature %q in %#v", sig, fmt.Sprintf("%#v", se)))
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected tag %q", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var skipAnchor = map[string]bool{
|
||||
"intro": true,
|
||||
"Overview": true,
|
||||
}
|
||||
|
||||
var skipTitle = map[string]bool{
|
||||
"Acknowledgements": true,
|
||||
"Change Log": true,
|
||||
"Document Organization": true,
|
||||
"Conventions and Terminology": true,
|
||||
}
|
||||
|
||||
func skipElement(s xml.StartElement) bool {
|
||||
switch s.Name.Local {
|
||||
case "artwork":
|
||||
return true
|
||||
case "section":
|
||||
for _, attr := range s.Attr {
|
||||
switch attr.Name.Local {
|
||||
case "anchor":
|
||||
if skipAnchor[attr.Value] || strings.HasPrefix(attr.Value, "changes.since.") {
|
||||
return true
|
||||
}
|
||||
case "title":
|
||||
if skipTitle[attr.Value] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func readSpecCov(r io.Reader) specCoverage {
|
||||
sc := specCoverage{
|
||||
coverage: map[specPart]bool{},
|
||||
d: xml.NewDecoder(r)}
|
||||
sc.readSection(nil)
|
||||
return sc
|
||||
}
|
||||
|
||||
func (sc specCoverage) addSentences(sec string, sentence string) {
|
||||
for _, s := range parseSentences(sentence) {
|
||||
sc.coverage[specPart{sec, s}] = false
|
||||
}
|
||||
}
|
||||
|
||||
func (sc specCoverage) cover(sec string, sentence string) {
|
||||
for _, s := range parseSentences(sentence) {
|
||||
p := specPart{sec, s}
|
||||
if _, ok := sc.coverage[p]; !ok {
|
||||
panic(fmt.Sprintf("Not found in spec: %q, %q", sec, s))
|
||||
}
|
||||
sc.coverage[specPart{sec, s}] = true
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
var whitespaceRx = regexp.MustCompile(`\s+`)
|
||||
|
||||
func parseSentences(sens string) []string {
|
||||
sens = strings.TrimSpace(sens)
|
||||
if sens == "" {
|
||||
return nil
|
||||
}
|
||||
ss := strings.Split(whitespaceRx.ReplaceAllString(sens, " "), ". ")
|
||||
for i, s := range ss {
|
||||
s = strings.TrimSpace(s)
|
||||
if !strings.HasSuffix(s, ".") {
|
||||
s += "."
|
||||
}
|
||||
ss[i] = s
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
func TestSpecParseSentences(t *testing.T) {
|
||||
tests := []struct {
|
||||
ss string
|
||||
want []string
|
||||
}{
|
||||
{"Sentence 1. Sentence 2.",
|
||||
[]string{
|
||||
"Sentence 1.",
|
||||
"Sentence 2.",
|
||||
}},
|
||||
{"Sentence 1. \nSentence 2.\tSentence 3.",
|
||||
[]string{
|
||||
"Sentence 1.",
|
||||
"Sentence 2.",
|
||||
"Sentence 3.",
|
||||
}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := parseSentences(tt.ss)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("%d: got = %q, want %q", i, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpecCoverage(t *testing.T) {
|
||||
if !*coverSpec {
|
||||
t.Skip()
|
||||
}
|
||||
|
||||
loadSpecOnce.Do(loadSpec)
|
||||
|
||||
var (
|
||||
list []specPart
|
||||
cv = defaultSpecCoverage.coverage
|
||||
total = len(cv)
|
||||
complete = 0
|
||||
)
|
||||
|
||||
for sp, touched := range defaultSpecCoverage.coverage {
|
||||
if touched {
|
||||
complete++
|
||||
} else {
|
||||
list = append(list, sp)
|
||||
}
|
||||
}
|
||||
sort.Stable(bySpecSection(list))
|
||||
|
||||
if testing.Short() && len(list) > 5 {
|
||||
list = list[:5]
|
||||
}
|
||||
|
||||
for _, p := range list {
|
||||
t.Errorf("\tSECTION %s: %s", p.section, p.sentence)
|
||||
}
|
||||
|
||||
t.Logf("%d/%d (%d%%) sentances covered", complete, total, (complete/total)*100)
|
||||
}
|
||||
|
||||
func attrSig(se xml.StartElement) string {
|
||||
var names []string
|
||||
for _, attr := range se.Attr {
|
||||
if attr.Name.Local == "fmt" {
|
||||
names = append(names, "fmt-"+attr.Value)
|
||||
} else {
|
||||
names = append(names, attr.Name.Local)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return strings.Join(names, ",")
|
||||
}
|
||||
|
||||
func attrValue(se xml.StartElement, attr string) string {
|
||||
for _, a := range se.Attr {
|
||||
if a.Name.Local == attr {
|
||||
return a.Value
|
||||
}
|
||||
}
|
||||
panic("unknown attribute " + attr)
|
||||
}
|
||||
|
||||
func TestSpecPartLess(t *testing.T) {
|
||||
tests := []struct {
|
||||
sec1, sec2 string
|
||||
want bool
|
||||
}{
|
||||
{"6.2.1", "6.2", false},
|
||||
{"6.2", "6.2.1", true},
|
||||
{"6.10", "6.10.1", true},
|
||||
{"6.10", "6.1.1", false}, // 10, not 1
|
||||
{"6.1", "6.1", false}, // equal, so not less
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := (specPart{tt.sec1, "foo"}).Less(specPart{tt.sec2, "foo"})
|
||||
if got != tt.want {
|
||||
t.Errorf("Less(%q, %q) = %v; want %v", tt.sec1, tt.sec2, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
language: go
|
||||
sudo: false
|
||||
|
||||
go:
|
||||
- 1.0.3
|
||||
- 1.1.2
|
||||
- 1.2.2
|
||||
- 1.3.3
|
||||
- 1.4.2
|
||||
|
||||
script:
|
||||
- go vet ./...
|
||||
- go test -v ./...
|
@ -0,0 +1,21 @@
|
||||
Copyright (C) 2013 Jeremy Saenz
|
||||
All Rights Reserved.
|
||||
|
||||
MIT LICENSE
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,308 @@
|
||||
[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli)
|
||||
|
||||
# cli.go
|
||||
cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way.
|
||||
|
||||
You can view the API docs here:
|
||||
http://godoc.org/github.com/codegangsta/cli
|
||||
|
||||
## Overview
|
||||
Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app.
|
||||
|
||||
**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive!
|
||||
|
||||
## Installation
|
||||
Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html).
|
||||
|
||||
To install `cli.go`, simply run:
|
||||
```
|
||||
$ go get github.com/codegangsta/cli
|
||||
```
|
||||
|
||||
Make sure your `PATH` includes to the `$GOPATH/bin` directory so your commands can be easily used:
|
||||
```
|
||||
export PATH=$PATH:$GOPATH/bin
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`.
|
||||
|
||||
``` go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cli.NewApp().Run(os.Args)
|
||||
}
|
||||
```
|
||||
|
||||
This app will run and show help text, but is not very useful. Let's give an action to execute and some help documentation:
|
||||
|
||||
``` go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := cli.NewApp()
|
||||
app.Name = "boom"
|
||||
app.Usage = "make an explosive entrance"
|
||||
app.Action = func(c *cli.Context) {
|
||||
println("boom! I say!")
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
}
|
||||
```
|
||||
|
||||
Running this already gives you a ton of functionality, plus support for things like subcommands and flags, which are covered below.
|
||||
|
||||
## Example
|
||||
|
||||
Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness!
|
||||
|
||||
Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it:
|
||||
|
||||
``` go
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"github.com/codegangsta/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
app := cli.NewApp()
|
||||
app.Name = "greet"
|
||||
app.Usage = "fight the loneliness!"
|
||||
app.Action = func(c *cli.Context) {
|
||||
println("Hello friend!")
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
}
|
||||
```
|
||||
|
||||
Install our command to the `$GOPATH/bin` directory:
|
||||
|
||||
```
|
||||
$ go install
|
||||
```
|
||||
|
||||
Finally run our new command:
|
||||
|
||||
```
|
||||
$ greet
|
||||
Hello friend!
|
||||
```
|
||||
|
||||
cli.go also generates some bitchass help text:
|
||||
```
|
||||
$ greet help
|
||||
NAME:
|
||||
greet - fight the loneliness!
|
||||
|
||||
USAGE:
|
||||
greet [global options] command [command options] [arguments...]
|
||||
|
||||
VERSION:
|
||||
0.0.0
|
||||
|
||||
COMMANDS:
|
||||
help, h Shows a list of commands or help for one command
|
||||
|
||||
GLOBAL OPTIONS
|
||||
--version Shows version information
|
||||
```
|
||||
|
||||
### Arguments
|
||||
You can lookup arguments by calling the `Args` function on `cli.Context`.
|
||||
|
||||
``` go
|
||||
...
|
||||
app.Action = func(c *cli.Context) {
|
||||
println("Hello", c.Args()[0])
|
||||
}
|
||||
...
|
||||
```
|
||||
|
||||
### Flags
|
||||
Setting and querying flags is simple.
|
||||
``` go
|
||||
...
|
||||
app.Flags = []cli.Flag {
|
||||
cli.StringFlag{
|
||||
Name: "lang",
|
||||
Value: "english",
|
||||
Usage: "language for the greeting",
|
||||
},
|
||||
}
|
||||
app.Action = func(c *cli.Context) {
|
||||
name := "someone"
|
||||
if len(c.Args()) > 0 {
|
||||
name = c.Args()[0]
|
||||
}
|
||||
if c.String("lang") == "spanish" {
|
||||
println("Hola", name)
|
||||
} else {
|
||||
println("Hello", name)
|
||||
}
|
||||
}
|
||||
...
|
||||
```
|
||||
|
||||
See full list of flags at http://godoc.org/github.com/codegangsta/cli
|
||||
|
||||
#### Alternate Names
|
||||
|
||||
You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g.
|
||||
|
||||
``` go
|
||||
app.Flags = []cli.Flag {
|
||||
cli.StringFlag{
|
||||
Name: "lang, l",
|
||||
Value: "english",
|
||||
Usage: "language for the greeting",
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error.
|
||||
|
||||
#### Values from the Environment
|
||||
|
||||
You can also have the default value set from the environment via `EnvVar`. e.g.
|
||||
|
||||
``` go
|
||||
app.Flags = []cli.Flag {
|
||||
cli.StringFlag{
|
||||
Name: "lang, l",
|
||||
Value: "english",
|
||||
Usage: "language for the greeting",
|
||||
EnvVar: "APP_LANG",
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
The `EnvVar` may also be given as a comma-delimited "cascade", where the first environment variable that resolves is used as the default.
|
||||
|
||||
``` go
|
||||
app.Flags = []cli.Flag {
|
||||
cli.StringFlag{
|
||||
Name: "lang, l",
|
||||
Value: "english",
|
||||
Usage: "language for the greeting",
|
||||
EnvVar: "LEGACY_COMPAT_LANG,APP_LANG,LANG",
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### Subcommands
|
||||
|
||||
Subcommands can be defined for a more git-like command line app.
|
||||
```go
|
||||
...
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "add",
|
||||
Aliases: []string{"a"},
|
||||
Usage: "add a task to the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("added task: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "complete",
|
||||
Aliases: []string{"c"},
|
||||
Usage: "complete a task on the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("completed task: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "template",
|
||||
Aliases: []string{"r"},
|
||||
Usage: "options for task templates",
|
||||
Subcommands: []cli.Command{
|
||||
{
|
||||
Name: "add",
|
||||
Usage: "add a new template",
|
||||
Action: func(c *cli.Context) {
|
||||
println("new task template: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "remove",
|
||||
Usage: "remove an existing template",
|
||||
Action: func(c *cli.Context) {
|
||||
println("removed task template: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
...
|
||||
```
|
||||
|
||||
### Bash Completion
|
||||
|
||||
You can enable completion commands by setting the `EnableBashCompletion`
|
||||
flag on the `App` object. By default, this setting will only auto-complete to
|
||||
show an app's subcommands, but you can write your own completion methods for
|
||||
the App or its subcommands.
|
||||
```go
|
||||
...
|
||||
var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"}
|
||||
app := cli.NewApp()
|
||||
app.EnableBashCompletion = true
|
||||
app.Commands = []cli.Command{
|
||||
{
|
||||
Name: "complete",
|
||||
Aliases: []string{"c"},
|
||||
Usage: "complete a task on the list",
|
||||
Action: func(c *cli.Context) {
|
||||
println("completed task: ", c.Args().First())
|
||||
},
|
||||
BashComplete: func(c *cli.Context) {
|
||||
// This will complete if no args are passed
|
||||
if len(c.Args()) > 0 {
|
||||
return
|
||||
}
|
||||
for _, t := range tasks {
|
||||
fmt.Println(t)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
...
|
||||
```
|
||||
|
||||
#### To Enable
|
||||
|
||||
Source the `autocomplete/bash_autocomplete` file in your `.bashrc` file while
|
||||
setting the `PROG` variable to the name of your program:
|
||||
|
||||
`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete`
|
||||
|
||||
#### To Distribute
|
||||
|
||||
Copy and modify `autocomplete/bash_autocomplete` to use your program name
|
||||
rather than `$PROG` and have the user copy the file into
|
||||
`/etc/bash_completion.d/` (or automatically install it there if you are
|
||||
distributing a package). Alternatively you can just document that users should
|
||||
source the generic `autocomplete/bash_autocomplete` with `$PROG` set to your
|
||||
program name in their bash configuration.
|
||||
|
||||
## Contribution Guidelines
|
||||
Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch.
|
||||
|
||||
If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together.
|
||||
|
||||
If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out.
|
@ -0,0 +1,308 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// App is the main structure of a cli application. It is recomended that
|
||||
// an app be created with the cli.NewApp() function
|
||||
type App struct {
|
||||
// The name of the program. Defaults to os.Args[0]
|
||||
Name string
|
||||
// Description of the program.
|
||||
Usage string
|
||||
// Version of the program
|
||||
Version string
|
||||
// List of commands to execute
|
||||
Commands []Command
|
||||
// List of flags to parse
|
||||
Flags []Flag
|
||||
// Boolean to enable bash completion commands
|
||||
EnableBashCompletion bool
|
||||
// Boolean to hide built-in help command
|
||||
HideHelp bool
|
||||
// Boolean to hide built-in version flag
|
||||
HideVersion bool
|
||||
// An action to execute when the bash-completion flag is set
|
||||
BashComplete func(context *Context)
|
||||
// An action to execute before any subcommands are run, but after the context is ready
|
||||
// If a non-nil error is returned, no subcommands are run
|
||||
Before func(context *Context) error
|
||||
// An action to execute after any subcommands are run, but after the subcommand has finished
|
||||
// It is run even if Action() panics
|
||||
After func(context *Context) error
|
||||
// The action to execute when no subcommands are specified
|
||||
Action func(context *Context)
|
||||
// Execute this function if the proper command cannot be found
|
||||
CommandNotFound func(context *Context, command string)
|
||||
// Compilation date
|
||||
Compiled time.Time
|
||||
// List of all authors who contributed
|
||||
Authors []Author
|
||||
// Copyright of the binary if any
|
||||
Copyright string
|
||||
// Name of Author (Note: Use App.Authors, this is deprecated)
|
||||
Author string
|
||||
// Email of Author (Note: Use App.Authors, this is deprecated)
|
||||
Email string
|
||||
// Writer writer to write output to
|
||||
Writer io.Writer
|
||||
}
|
||||
|
||||
// Tries to find out when this binary was compiled.
|
||||
// Returns the current time if it fails to find it.
|
||||
func compileTime() time.Time {
|
||||
info, err := os.Stat(os.Args[0])
|
||||
if err != nil {
|
||||
return time.Now()
|
||||
}
|
||||
return info.ModTime()
|
||||
}
|
||||
|
||||
// Creates a new cli Application with some reasonable defaults for Name, Usage, Version and Action.
|
||||
func NewApp() *App {
|
||||
return &App{
|
||||
Name: os.Args[0],
|
||||
Usage: "A new cli application",
|
||||
Version: "0.0.0",
|
||||
BashComplete: DefaultAppComplete,
|
||||
Action: helpCommand.Action,
|
||||
Compiled: compileTime(),
|
||||
Writer: os.Stdout,
|
||||
}
|
||||
}
|
||||
|
||||
// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination
|
||||
func (a *App) Run(arguments []string) (err error) {
|
||||
if a.Author != "" || a.Email != "" {
|
||||
a.Authors = append(a.Authors, Author{Name: a.Author, Email: a.Email})
|
||||
}
|
||||
|
||||
// append help to commands
|
||||
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
|
||||
a.Commands = append(a.Commands, helpCommand)
|
||||
if (HelpFlag != BoolFlag{}) {
|
||||
a.appendFlag(HelpFlag)
|
||||
}
|
||||
}
|
||||
|
||||
//append version/help flags
|
||||
if a.EnableBashCompletion {
|
||||
a.appendFlag(BashCompletionFlag)
|
||||
}
|
||||
|
||||
if !a.HideVersion {
|
||||
a.appendFlag(VersionFlag)
|
||||
}
|
||||
|
||||
// parse flags
|
||||
set := flagSet(a.Name, a.Flags)
|
||||
set.SetOutput(ioutil.Discard)
|
||||
err = set.Parse(arguments[1:])
|
||||
nerr := normalizeFlags(a.Flags, set)
|
||||
if nerr != nil {
|
||||
fmt.Fprintln(a.Writer, nerr)
|
||||
context := NewContext(a, set, nil)
|
||||
ShowAppHelp(context)
|
||||
return nerr
|
||||
}
|
||||
context := NewContext(a, set, nil)
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintln(a.Writer, "Incorrect Usage.")
|
||||
fmt.Fprintln(a.Writer)
|
||||
ShowAppHelp(context)
|
||||
return err
|
||||
}
|
||||
|
||||
if checkCompletions(context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if checkHelp(context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if checkVersion(context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if a.After != nil {
|
||||
defer func() {
|
||||
afterErr := a.After(context)
|
||||
if afterErr != nil {
|
||||
if err != nil {
|
||||
err = NewMultiError(err, afterErr)
|
||||
} else {
|
||||
err = afterErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if a.Before != nil {
|
||||
err := a.Before(context)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
args := context.Args()
|
||||
if args.Present() {
|
||||
name := args.First()
|
||||
c := a.Command(name)
|
||||
if c != nil {
|
||||
return c.Run(context)
|
||||
}
|
||||
}
|
||||
|
||||
// Run default Action
|
||||
a.Action(context)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Another entry point to the cli app, takes care of passing arguments and error handling
|
||||
func (a *App) RunAndExitOnError() {
|
||||
if err := a.Run(os.Args); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags
|
||||
func (a *App) RunAsSubcommand(ctx *Context) (err error) {
|
||||
// append help to commands
|
||||
if len(a.Commands) > 0 {
|
||||
if a.Command(helpCommand.Name) == nil && !a.HideHelp {
|
||||
a.Commands = append(a.Commands, helpCommand)
|
||||
if (HelpFlag != BoolFlag{}) {
|
||||
a.appendFlag(HelpFlag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// append flags
|
||||
if a.EnableBashCompletion {
|
||||
a.appendFlag(BashCompletionFlag)
|
||||
}
|
||||
|
||||
// parse flags
|
||||
set := flagSet(a.Name, a.Flags)
|
||||
set.SetOutput(ioutil.Discard)
|
||||
err = set.Parse(ctx.Args().Tail())
|
||||
nerr := normalizeFlags(a.Flags, set)
|
||||
context := NewContext(a, set, ctx)
|
||||
|
||||
if nerr != nil {
|
||||
fmt.Fprintln(a.Writer, nerr)
|
||||
fmt.Fprintln(a.Writer)
|
||||
if len(a.Commands) > 0 {
|
||||
ShowSubcommandHelp(context)
|
||||
} else {
|
||||
ShowCommandHelp(ctx, context.Args().First())
|
||||
}
|
||||
return nerr
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintln(a.Writer, "Incorrect Usage.")
|
||||
fmt.Fprintln(a.Writer)
|
||||
ShowSubcommandHelp(context)
|
||||
return err
|
||||
}
|
||||
|
||||
if checkCompletions(context) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(a.Commands) > 0 {
|
||||
if checkSubcommandHelp(context) {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
if checkCommandHelp(ctx, context.Args().First()) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if a.After != nil {
|
||||
defer func() {
|
||||
afterErr := a.After(context)
|
||||
if afterErr != nil {
|
||||
if err != nil {
|
||||
err = NewMultiError(err, afterErr)
|
||||
} else {
|
||||
err = afterErr
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if a.Before != nil {
|
||||
err := a.Before(context)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
args := context.Args()
|
||||
if args.Present() {
|
||||
name := args.First()
|
||||
c := a.Command(name)
|
||||
if c != nil {
|
||||
return c.Run(context)
|
||||
}
|
||||
}
|
||||
|
||||
// Run default Action
|
||||
a.Action(context)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the named command on App. Returns nil if the command does not exist
|
||||
func (a *App) Command(name string) *Command {
|
||||
for _, c := range a.Commands {
|
||||
if c.HasName(name) {
|
||||
return &c
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *App) hasFlag(flag Flag) bool {
|
||||
for _, f := range a.Flags {
|
||||
if flag == f {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *App) appendFlag(flag Flag) {
|
||||
if !a.hasFlag(flag) {
|
||||
a.Flags = append(a.Flags, flag)
|
||||
}
|
||||
}
|
||||
|
||||
// Author represents someone who has contributed to a cli project.
|
||||
type Author struct {
|
||||
Name string // The Authors name
|
||||
Email string // The Authors email
|
||||
}
|
||||
|
||||
// String makes Author comply to the Stringer interface, to allow an easy print in the templating process
|
||||
func (a Author) String() string {
|
||||
e := ""
|
||||
if a.Email != "" {
|
||||
e = "<" + a.Email + "> "
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v %v", a.Name, e)
|
||||
}
|
@ -0,0 +1,867 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ExampleApp() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "--name", "Jeremy"}
|
||||
|
||||
app := NewApp()
|
||||
app.Name = "greet"
|
||||
app.Flags = []Flag{
|
||||
StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
|
||||
}
|
||||
app.Action = func(c *Context) {
|
||||
fmt.Printf("Hello %v\n", c.String("name"))
|
||||
}
|
||||
app.Author = "Harrison"
|
||||
app.Email = "harrison@lolwut.com"
|
||||
app.Authors = []Author{Author{Name: "Oliver Allen", Email: "oliver@toyshop.com"}}
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// Hello Jeremy
|
||||
}
|
||||
|
||||
func ExampleAppSubcommand() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"say", "hi", "english", "--name", "Jeremy"}
|
||||
app := NewApp()
|
||||
app.Name = "say"
|
||||
app.Commands = []Command{
|
||||
{
|
||||
Name: "hello",
|
||||
Aliases: []string{"hi"},
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe hello the function",
|
||||
Subcommands: []Command{
|
||||
{
|
||||
Name: "english",
|
||||
Aliases: []string{"en"},
|
||||
Usage: "sends a greeting in english",
|
||||
Description: "greets someone in english",
|
||||
Flags: []Flag{
|
||||
StringFlag{
|
||||
Name: "name",
|
||||
Value: "Bob",
|
||||
Usage: "Name of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
fmt.Println("Hello,", c.String("name"))
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// Hello, Jeremy
|
||||
}
|
||||
|
||||
func ExampleAppHelp() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "h", "describeit"}
|
||||
|
||||
app := NewApp()
|
||||
app.Name = "greet"
|
||||
app.Flags = []Flag{
|
||||
StringFlag{Name: "name", Value: "bob", Usage: "a name to say"},
|
||||
}
|
||||
app.Commands = []Command{
|
||||
{
|
||||
Name: "describeit",
|
||||
Aliases: []string{"d"},
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe describeit the function",
|
||||
Action: func(c *Context) {
|
||||
fmt.Printf("i like to describe things")
|
||||
},
|
||||
},
|
||||
}
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// NAME:
|
||||
// describeit - use it to see a description
|
||||
//
|
||||
// USAGE:
|
||||
// command describeit [arguments...]
|
||||
//
|
||||
// DESCRIPTION:
|
||||
// This is how we describe describeit the function
|
||||
}
|
||||
|
||||
func ExampleAppBashComplete() {
|
||||
// set args for examples sake
|
||||
os.Args = []string{"greet", "--generate-bash-completion"}
|
||||
|
||||
app := NewApp()
|
||||
app.Name = "greet"
|
||||
app.EnableBashCompletion = true
|
||||
app.Commands = []Command{
|
||||
{
|
||||
Name: "describeit",
|
||||
Aliases: []string{"d"},
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe describeit the function",
|
||||
Action: func(c *Context) {
|
||||
fmt.Printf("i like to describe things")
|
||||
},
|
||||
}, {
|
||||
Name: "next",
|
||||
Usage: "next example",
|
||||
Description: "more stuff to see when generating bash completion",
|
||||
Action: func(c *Context) {
|
||||
fmt.Printf("the next example")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
// Output:
|
||||
// describeit
|
||||
// d
|
||||
// next
|
||||
// help
|
||||
// h
|
||||
}
|
||||
|
||||
func TestApp_Run(t *testing.T) {
|
||||
s := ""
|
||||
|
||||
app := NewApp()
|
||||
app.Action = func(c *Context) {
|
||||
s = s + c.Args().First()
|
||||
}
|
||||
|
||||
err := app.Run([]string{"command", "foo"})
|
||||
expect(t, err, nil)
|
||||
err = app.Run([]string{"command", "bar"})
|
||||
expect(t, err, nil)
|
||||
expect(t, s, "foobar")
|
||||
}
|
||||
|
||||
var commandAppTests = []struct {
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{"foobar", true},
|
||||
{"batbaz", true},
|
||||
{"b", true},
|
||||
{"f", true},
|
||||
{"bat", false},
|
||||
{"nothing", false},
|
||||
}
|
||||
|
||||
func TestApp_Command(t *testing.T) {
|
||||
app := NewApp()
|
||||
fooCommand := Command{Name: "foobar", Aliases: []string{"f"}}
|
||||
batCommand := Command{Name: "batbaz", Aliases: []string{"b"}}
|
||||
app.Commands = []Command{
|
||||
fooCommand,
|
||||
batCommand,
|
||||
}
|
||||
|
||||
for _, test := range commandAppTests {
|
||||
expect(t, app.Command(test.name) != nil, test.expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_CommandWithArgBeforeFlags(t *testing.T) {
|
||||
var parsedOption, firstArg string
|
||||
|
||||
app := NewApp()
|
||||
command := Command{
|
||||
Name: "cmd",
|
||||
Flags: []Flag{
|
||||
StringFlag{Name: "option", Value: "", Usage: "some option"},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
parsedOption = c.String("option")
|
||||
firstArg = c.Args().First()
|
||||
},
|
||||
}
|
||||
app.Commands = []Command{command}
|
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"})
|
||||
|
||||
expect(t, parsedOption, "my-option")
|
||||
expect(t, firstArg, "my-arg")
|
||||
}
|
||||
|
||||
func TestApp_RunAsSubcommandParseFlags(t *testing.T) {
|
||||
var context *Context
|
||||
|
||||
a := NewApp()
|
||||
a.Commands = []Command{
|
||||
{
|
||||
Name: "foo",
|
||||
Action: func(c *Context) {
|
||||
context = c
|
||||
},
|
||||
Flags: []Flag{
|
||||
StringFlag{
|
||||
Name: "lang",
|
||||
Value: "english",
|
||||
Usage: "language for the greeting",
|
||||
},
|
||||
},
|
||||
Before: func(_ *Context) error { return nil },
|
||||
},
|
||||
}
|
||||
a.Run([]string{"", "foo", "--lang", "spanish", "abcd"})
|
||||
|
||||
expect(t, context.Args().Get(0), "abcd")
|
||||
expect(t, context.String("lang"), "spanish")
|
||||
}
|
||||
|
||||
func TestApp_CommandWithFlagBeforeTerminator(t *testing.T) {
|
||||
var parsedOption string
|
||||
var args []string
|
||||
|
||||
app := NewApp()
|
||||
command := Command{
|
||||
Name: "cmd",
|
||||
Flags: []Flag{
|
||||
StringFlag{Name: "option", Value: "", Usage: "some option"},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
parsedOption = c.String("option")
|
||||
args = c.Args()
|
||||
},
|
||||
}
|
||||
app.Commands = []Command{command}
|
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "--option", "my-option", "--", "--notARealFlag"})
|
||||
|
||||
expect(t, parsedOption, "my-option")
|
||||
expect(t, args[0], "my-arg")
|
||||
expect(t, args[1], "--")
|
||||
expect(t, args[2], "--notARealFlag")
|
||||
}
|
||||
|
||||
func TestApp_CommandWithNoFlagBeforeTerminator(t *testing.T) {
|
||||
var args []string
|
||||
|
||||
app := NewApp()
|
||||
command := Command{
|
||||
Name: "cmd",
|
||||
Action: func(c *Context) {
|
||||
args = c.Args()
|
||||
},
|
||||
}
|
||||
app.Commands = []Command{command}
|
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "--", "notAFlagAtAll"})
|
||||
|
||||
expect(t, args[0], "my-arg")
|
||||
expect(t, args[1], "--")
|
||||
expect(t, args[2], "notAFlagAtAll")
|
||||
}
|
||||
|
||||
func TestApp_Float64Flag(t *testing.T) {
|
||||
var meters float64
|
||||
|
||||
app := NewApp()
|
||||
app.Flags = []Flag{
|
||||
Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"},
|
||||
}
|
||||
app.Action = func(c *Context) {
|
||||
meters = c.Float64("height")
|
||||
}
|
||||
|
||||
app.Run([]string{"", "--height", "1.93"})
|
||||
expect(t, meters, 1.93)
|
||||
}
|
||||
|
||||
func TestApp_ParseSliceFlags(t *testing.T) {
|
||||
var parsedOption, firstArg string
|
||||
var parsedIntSlice []int
|
||||
var parsedStringSlice []string
|
||||
|
||||
app := NewApp()
|
||||
command := Command{
|
||||
Name: "cmd",
|
||||
Flags: []Flag{
|
||||
IntSliceFlag{Name: "p", Value: &IntSlice{}, Usage: "set one or more ip addr"},
|
||||
StringSliceFlag{Name: "ip", Value: &StringSlice{}, Usage: "set one or more ports to open"},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
parsedIntSlice = c.IntSlice("p")
|
||||
parsedStringSlice = c.StringSlice("ip")
|
||||
parsedOption = c.String("option")
|
||||
firstArg = c.Args().First()
|
||||
},
|
||||
}
|
||||
app.Commands = []Command{command}
|
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"})
|
||||
|
||||
IntsEquals := func(a, b []int) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
StrsEquals := func(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
var expectedIntSlice = []int{22, 80}
|
||||
var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"}
|
||||
|
||||
if !IntsEquals(parsedIntSlice, expectedIntSlice) {
|
||||
t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice)
|
||||
}
|
||||
|
||||
if !StrsEquals(parsedStringSlice, expectedStringSlice) {
|
||||
t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_ParseSliceFlagsWithMissingValue(t *testing.T) {
|
||||
var parsedIntSlice []int
|
||||
var parsedStringSlice []string
|
||||
|
||||
app := NewApp()
|
||||
command := Command{
|
||||
Name: "cmd",
|
||||
Flags: []Flag{
|
||||
IntSliceFlag{Name: "a", Usage: "set numbers"},
|
||||
StringSliceFlag{Name: "str", Usage: "set strings"},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
parsedIntSlice = c.IntSlice("a")
|
||||
parsedStringSlice = c.StringSlice("str")
|
||||
},
|
||||
}
|
||||
app.Commands = []Command{command}
|
||||
|
||||
app.Run([]string{"", "cmd", "my-arg", "-a", "2", "-str", "A"})
|
||||
|
||||
var expectedIntSlice = []int{2}
|
||||
var expectedStringSlice = []string{"A"}
|
||||
|
||||
if parsedIntSlice[0] != expectedIntSlice[0] {
|
||||
t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0])
|
||||
}
|
||||
|
||||
if parsedStringSlice[0] != expectedStringSlice[0] {
|
||||
t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_DefaultStdout(t *testing.T) {
|
||||
app := NewApp()
|
||||
|
||||
if app.Writer != os.Stdout {
|
||||
t.Error("Default output writer not set.")
|
||||
}
|
||||
}
|
||||
|
||||
type mockWriter struct {
|
||||
written []byte
|
||||
}
|
||||
|
||||
func (fw *mockWriter) Write(p []byte) (n int, err error) {
|
||||
if fw.written == nil {
|
||||
fw.written = p
|
||||
} else {
|
||||
fw.written = append(fw.written, p...)
|
||||
}
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (fw *mockWriter) GetWritten() (b []byte) {
|
||||
return fw.written
|
||||
}
|
||||
|
||||
func TestApp_SetStdout(t *testing.T) {
|
||||
w := &mockWriter{}
|
||||
|
||||
app := NewApp()
|
||||
app.Name = "test"
|
||||
app.Writer = w
|
||||
|
||||
err := app.Run([]string{"help"})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Run error: %s", err)
|
||||
}
|
||||
|
||||
if len(w.written) == 0 {
|
||||
t.Error("App did not write output to desired writer.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_BeforeFunc(t *testing.T) {
|
||||
beforeRun, subcommandRun := false, false
|
||||
beforeError := fmt.Errorf("fail")
|
||||
var err error
|
||||
|
||||
app := NewApp()
|
||||
|
||||
app.Before = func(c *Context) error {
|
||||
beforeRun = true
|
||||
s := c.String("opt")
|
||||
if s == "fail" {
|
||||
return beforeError
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
app.Commands = []Command{
|
||||
Command{
|
||||
Name: "sub",
|
||||
Action: func(c *Context) {
|
||||
subcommandRun = true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Flags = []Flag{
|
||||
StringFlag{Name: "opt"},
|
||||
}
|
||||
|
||||
// run with the Before() func succeeding
|
||||
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Run error: %s", err)
|
||||
}
|
||||
|
||||
if beforeRun == false {
|
||||
t.Errorf("Before() not executed when expected")
|
||||
}
|
||||
|
||||
if subcommandRun == false {
|
||||
t.Errorf("Subcommand not executed when expected")
|
||||
}
|
||||
|
||||
// reset
|
||||
beforeRun, subcommandRun = false, false
|
||||
|
||||
// run with the Before() func failing
|
||||
err = app.Run([]string{"command", "--opt", "fail", "sub"})
|
||||
|
||||
// should be the same error produced by the Before func
|
||||
if err != beforeError {
|
||||
t.Errorf("Run error expected, but not received")
|
||||
}
|
||||
|
||||
if beforeRun == false {
|
||||
t.Errorf("Before() not executed when expected")
|
||||
}
|
||||
|
||||
if subcommandRun == true {
|
||||
t.Errorf("Subcommand executed when NOT expected")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestApp_AfterFunc(t *testing.T) {
|
||||
afterRun, subcommandRun := false, false
|
||||
afterError := fmt.Errorf("fail")
|
||||
var err error
|
||||
|
||||
app := NewApp()
|
||||
|
||||
app.After = func(c *Context) error {
|
||||
afterRun = true
|
||||
s := c.String("opt")
|
||||
if s == "fail" {
|
||||
return afterError
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
app.Commands = []Command{
|
||||
Command{
|
||||
Name: "sub",
|
||||
Action: func(c *Context) {
|
||||
subcommandRun = true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Flags = []Flag{
|
||||
StringFlag{Name: "opt"},
|
||||
}
|
||||
|
||||
// run with the After() func succeeding
|
||||
err = app.Run([]string{"command", "--opt", "succeed", "sub"})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Run error: %s", err)
|
||||
}
|
||||
|
||||
if afterRun == false {
|
||||
t.Errorf("After() not executed when expected")
|
||||
}
|
||||
|
||||
if subcommandRun == false {
|
||||
t.Errorf("Subcommand not executed when expected")
|
||||
}
|
||||
|
||||
// reset
|
||||
afterRun, subcommandRun = false, false
|
||||
|
||||
// run with the Before() func failing
|
||||
err = app.Run([]string{"command", "--opt", "fail", "sub"})
|
||||
|
||||
// should be the same error produced by the Before func
|
||||
if err != afterError {
|
||||
t.Errorf("Run error expected, but not received")
|
||||
}
|
||||
|
||||
if afterRun == false {
|
||||
t.Errorf("After() not executed when expected")
|
||||
}
|
||||
|
||||
if subcommandRun == false {
|
||||
t.Errorf("Subcommand not executed when expected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppNoHelpFlag(t *testing.T) {
|
||||
oldFlag := HelpFlag
|
||||
defer func() {
|
||||
HelpFlag = oldFlag
|
||||
}()
|
||||
|
||||
HelpFlag = BoolFlag{}
|
||||
|
||||
app := NewApp()
|
||||
err := app.Run([]string{"test", "-h"})
|
||||
|
||||
if err != flag.ErrHelp {
|
||||
t.Errorf("expected error about missing help flag, but got: %s (%T)", err, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppHelpPrinter(t *testing.T) {
|
||||
oldPrinter := HelpPrinter
|
||||
defer func() {
|
||||
HelpPrinter = oldPrinter
|
||||
}()
|
||||
|
||||
var wasCalled = false
|
||||
HelpPrinter = func(w io.Writer, template string, data interface{}) {
|
||||
wasCalled = true
|
||||
}
|
||||
|
||||
app := NewApp()
|
||||
app.Run([]string{"-h"})
|
||||
|
||||
if wasCalled == false {
|
||||
t.Errorf("Help printer expected to be called, but was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppVersionPrinter(t *testing.T) {
|
||||
oldPrinter := VersionPrinter
|
||||
defer func() {
|
||||
VersionPrinter = oldPrinter
|
||||
}()
|
||||
|
||||
var wasCalled = false
|
||||
VersionPrinter = func(c *Context) {
|
||||
wasCalled = true
|
||||
}
|
||||
|
||||
app := NewApp()
|
||||
ctx := NewContext(app, nil, nil)
|
||||
ShowVersion(ctx)
|
||||
|
||||
if wasCalled == false {
|
||||
t.Errorf("Version printer expected to be called, but was not")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppCommandNotFound(t *testing.T) {
|
||||
beforeRun, subcommandRun := false, false
|
||||
app := NewApp()
|
||||
|
||||
app.CommandNotFound = func(c *Context, command string) {
|
||||
beforeRun = true
|
||||
}
|
||||
|
||||
app.Commands = []Command{
|
||||
Command{
|
||||
Name: "bar",
|
||||
Action: func(c *Context) {
|
||||
subcommandRun = true
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run([]string{"command", "foo"})
|
||||
|
||||
expect(t, beforeRun, true)
|
||||
expect(t, subcommandRun, false)
|
||||
}
|
||||
|
||||
func TestGlobalFlag(t *testing.T) {
|
||||
var globalFlag string
|
||||
var globalFlagSet bool
|
||||
app := NewApp()
|
||||
app.Flags = []Flag{
|
||||
StringFlag{Name: "global, g", Usage: "global"},
|
||||
}
|
||||
app.Action = func(c *Context) {
|
||||
globalFlag = c.GlobalString("global")
|
||||
globalFlagSet = c.GlobalIsSet("global")
|
||||
}
|
||||
app.Run([]string{"command", "-g", "foo"})
|
||||
expect(t, globalFlag, "foo")
|
||||
expect(t, globalFlagSet, true)
|
||||
|
||||
}
|
||||
|
||||
func TestGlobalFlagsInSubcommands(t *testing.T) {
|
||||
subcommandRun := false
|
||||
parentFlag := false
|
||||
app := NewApp()
|
||||
|
||||
app.Flags = []Flag{
|
||||
BoolFlag{Name: "debug, d", Usage: "Enable debugging"},
|
||||
}
|
||||
|
||||
app.Commands = []Command{
|
||||
Command{
|
||||
Name: "foo",
|
||||
Flags: []Flag{
|
||||
BoolFlag{Name: "parent, p", Usage: "Parent flag"},
|
||||
},
|
||||
Subcommands: []Command{
|
||||
{
|
||||
Name: "bar",
|
||||
Action: func(c *Context) {
|
||||
if c.GlobalBool("debug") {
|
||||
subcommandRun = true
|
||||
}
|
||||
if c.GlobalBool("parent") {
|
||||
parentFlag = true
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run([]string{"command", "-d", "foo", "-p", "bar"})
|
||||
|
||||
expect(t, subcommandRun, true)
|
||||
expect(t, parentFlag, true)
|
||||
}
|
||||
|
||||
func TestApp_Run_CommandWithSubcommandHasHelpTopic(t *testing.T) {
|
||||
var subcommandHelpTopics = [][]string{
|
||||
{"command", "foo", "--help"},
|
||||
{"command", "foo", "-h"},
|
||||
{"command", "foo", "help"},
|
||||
}
|
||||
|
||||
for _, flagSet := range subcommandHelpTopics {
|
||||
t.Logf("==> checking with flags %v", flagSet)
|
||||
|
||||
app := NewApp()
|
||||
buf := new(bytes.Buffer)
|
||||
app.Writer = buf
|
||||
|
||||
subCmdBar := Command{
|
||||
Name: "bar",
|
||||
Usage: "does bar things",
|
||||
}
|
||||
subCmdBaz := Command{
|
||||
Name: "baz",
|
||||
Usage: "does baz things",
|
||||
}
|
||||
cmd := Command{
|
||||
Name: "foo",
|
||||
Description: "descriptive wall of text about how it does foo things",
|
||||
Subcommands: []Command{subCmdBar, subCmdBaz},
|
||||
}
|
||||
|
||||
app.Commands = []Command{cmd}
|
||||
err := app.Run(flagSet)
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
t.Logf("output: %q\n", buf.Bytes())
|
||||
|
||||
if strings.Contains(output, "No help topic for") {
|
||||
t.Errorf("expect a help topic, got none: \n%q", output)
|
||||
}
|
||||
|
||||
for _, shouldContain := range []string{
|
||||
cmd.Name, cmd.Description,
|
||||
subCmdBar.Name, subCmdBar.Usage,
|
||||
subCmdBaz.Name, subCmdBaz.Usage,
|
||||
} {
|
||||
if !strings.Contains(output, shouldContain) {
|
||||
t.Errorf("want help to contain %q, did not: \n%q", shouldContain, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_Run_SubcommandFullPath(t *testing.T) {
|
||||
app := NewApp()
|
||||
buf := new(bytes.Buffer)
|
||||
app.Writer = buf
|
||||
|
||||
subCmd := Command{
|
||||
Name: "bar",
|
||||
Usage: "does bar things",
|
||||
}
|
||||
cmd := Command{
|
||||
Name: "foo",
|
||||
Description: "foo commands",
|
||||
Subcommands: []Command{subCmd},
|
||||
}
|
||||
app.Commands = []Command{cmd}
|
||||
|
||||
err := app.Run([]string{"command", "foo", "bar", "--help"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
if !strings.Contains(output, "foo bar - does bar things") {
|
||||
t.Errorf("expected full path to subcommand: %s", output)
|
||||
}
|
||||
if !strings.Contains(output, "command foo bar [arguments...]") {
|
||||
t.Errorf("expected full path to subcommand: %s", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_Run_Help(t *testing.T) {
|
||||
var helpArguments = [][]string{{"boom", "--help"}, {"boom", "-h"}, {"boom", "help"}}
|
||||
|
||||
for _, args := range helpArguments {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
t.Logf("==> checking with arguments %v", args)
|
||||
|
||||
app := NewApp()
|
||||
app.Name = "boom"
|
||||
app.Usage = "make an explosive entrance"
|
||||
app.Writer = buf
|
||||
app.Action = func(c *Context) {
|
||||
buf.WriteString("boom I say!")
|
||||
}
|
||||
|
||||
err := app.Run(args)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
t.Logf("output: %q\n", buf.Bytes())
|
||||
|
||||
if !strings.Contains(output, "boom - make an explosive entrance") {
|
||||
t.Errorf("want help to contain %q, did not: \n%q", "boom - make an explosive entrance", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_Run_Version(t *testing.T) {
|
||||
var versionArguments = [][]string{{"boom", "--version"}, {"boom", "-v"}}
|
||||
|
||||
for _, args := range versionArguments {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
t.Logf("==> checking with arguments %v", args)
|
||||
|
||||
app := NewApp()
|
||||
app.Name = "boom"
|
||||
app.Usage = "make an explosive entrance"
|
||||
app.Version = "0.1.0"
|
||||
app.Writer = buf
|
||||
app.Action = func(c *Context) {
|
||||
buf.WriteString("boom I say!")
|
||||
}
|
||||
|
||||
err := app.Run(args)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
output := buf.String()
|
||||
t.Logf("output: %q\n", buf.Bytes())
|
||||
|
||||
if !strings.Contains(output, "0.1.0") {
|
||||
t.Errorf("want version to contain %q, did not: \n%q", "0.1.0", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) {
|
||||
app := NewApp()
|
||||
app.Action = func(c *Context) {}
|
||||
app.Before = func(c *Context) error { return fmt.Errorf("before error") }
|
||||
app.After = func(c *Context) error { return fmt.Errorf("after error") }
|
||||
|
||||
err := app.Run([]string{"foo"})
|
||||
if err == nil {
|
||||
t.Fatalf("expected to recieve error from Run, got none")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "before error") {
|
||||
t.Errorf("expected text of error from Before method, but got none in \"%v\"", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "after error") {
|
||||
t.Errorf("expected text of error from After method, but got none in \"%v\"", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApp_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) {
|
||||
app := NewApp()
|
||||
app.Commands = []Command{
|
||||
Command{
|
||||
Name: "bar",
|
||||
Before: func(c *Context) error { return fmt.Errorf("before error") },
|
||||
After: func(c *Context) error { return fmt.Errorf("after error") },
|
||||
},
|
||||
}
|
||||
|
||||
err := app.Run([]string{"foo", "bar"})
|
||||
if err == nil {
|
||||
t.Fatalf("expected to recieve error from Run, got none")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "before error") {
|
||||
t.Errorf("expected text of error from Before method, but got none in \"%v\"", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "after error") {
|
||||
t.Errorf("expected text of error from After method, but got none in \"%v\"", err)
|
||||
}
|
||||
}
|
13
Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete
generated
vendored
13
Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete
generated
vendored
@ -0,0 +1,13 @@
|
||||
#! /bin/bash
|
||||
|
||||
_cli_bash_autocomplete() {
|
||||
local cur prev opts base
|
||||
COMPREPLY=()
|
||||
cur="${COMP_WORDS[COMP_CWORD]}"
|
||||
prev="${COMP_WORDS[COMP_CWORD-1]}"
|
||||
opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion )
|
||||
COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) )
|
||||
return 0
|
||||
}
|
||||
|
||||
complete -F _cli_bash_autocomplete $PROG
|
@ -0,0 +1,5 @@
|
||||
autoload -U compinit && compinit
|
||||
autoload -U bashcompinit && bashcompinit
|
||||
|
||||
script_dir=$(dirname $0)
|
||||
source ${script_dir}/bash_autocomplete
|
@ -0,0 +1,40 @@
|
||||
// Package cli provides a minimal framework for creating and organizing command line
|
||||
// Go applications. cli is designed to be easy to understand and write, the most simple
|
||||
// cli application can be written as follows:
|
||||
// func main() {
|
||||
// cli.NewApp().Run(os.Args)
|
||||
// }
|
||||
//
|
||||
// Of course this application does not do much, so let's make this an actual application:
|
||||
// func main() {
|
||||
// app := cli.NewApp()
|
||||
// app.Name = "greet"
|
||||
// app.Usage = "say a greeting"
|
||||
// app.Action = func(c *cli.Context) {
|
||||
// println("Greetings")
|
||||
// }
|
||||
//
|
||||
// app.Run(os.Args)
|
||||
// }
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MultiError struct {
|
||||
Errors []error
|
||||
}
|
||||
|
||||
func NewMultiError(err ...error) MultiError {
|
||||
return MultiError{Errors: err}
|
||||
}
|
||||
|
||||
func (m MultiError) Error() string {
|
||||
errs := make([]string, len(m.Errors))
|
||||
for i, err := range m.Errors {
|
||||
errs[i] = err.Error()
|
||||
}
|
||||
|
||||
return strings.Join(errs, "\n")
|
||||
}
|
@ -0,0 +1,98 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func Example() {
|
||||
app := NewApp()
|
||||
app.Name = "todo"
|
||||
app.Usage = "task list on the command line"
|
||||
app.Commands = []Command{
|
||||
{
|
||||
Name: "add",
|
||||
Aliases: []string{"a"},
|
||||
Usage: "add a task to the list",
|
||||
Action: func(c *Context) {
|
||||
println("added task: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "complete",
|
||||
Aliases: []string{"c"},
|
||||
Usage: "complete a task on the list",
|
||||
Action: func(c *Context) {
|
||||
println("completed task: ", c.Args().First())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
}
|
||||
|
||||
func ExampleSubcommand() {
|
||||
app := NewApp()
|
||||
app.Name = "say"
|
||||
app.Commands = []Command{
|
||||
{
|
||||
Name: "hello",
|
||||
Aliases: []string{"hi"},
|
||||
Usage: "use it to see a description",
|
||||
Description: "This is how we describe hello the function",
|
||||
Subcommands: []Command{
|
||||
{
|
||||
Name: "english",
|
||||
Aliases: []string{"en"},
|
||||
Usage: "sends a greeting in english",
|
||||
Description: "greets someone in english",
|
||||
Flags: []Flag{
|
||||
StringFlag{
|
||||
Name: "name",
|
||||
Value: "Bob",
|
||||
Usage: "Name of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
println("Hello, ", c.String("name"))
|
||||
},
|
||||
}, {
|
||||
Name: "spanish",
|
||||
Aliases: []string{"sp"},
|
||||
Usage: "sends a greeting in spanish",
|
||||
Flags: []Flag{
|
||||
StringFlag{
|
||||
Name: "surname",
|
||||
Value: "Jones",
|
||||
Usage: "Surname of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
println("Hola, ", c.String("surname"))
|
||||
},
|
||||
}, {
|
||||
Name: "french",
|
||||
Aliases: []string{"fr"},
|
||||
Usage: "sends a greeting in french",
|
||||
Flags: []Flag{
|
||||
StringFlag{
|
||||
Name: "nickname",
|
||||
Value: "Stevie",
|
||||
Usage: "Nickname of the person to greet",
|
||||
},
|
||||
},
|
||||
Action: func(c *Context) {
|
||||
println("Bonjour, ", c.String("nickname"))
|
||||
},
|
||||
},
|
||||
},
|
||||
}, {
|
||||
Name: "bye",
|
||||
Usage: "says goodbye",
|
||||
Action: func(c *Context) {
|
||||
println("bye")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
app.Run(os.Args)
|
||||
}
|
@ -0,0 +1,200 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Command is a subcommand for a cli.App.
|
||||
type Command struct {
|
||||
// The name of the command
|
||||
Name string
|
||||
// short name of the command. Typically one character (deprecated, use `Aliases`)
|
||||
ShortName string
|
||||
// A list of aliases for the command
|
||||
Aliases []string
|
||||
// A short description of the usage of this command
|
||||
Usage string
|
||||
// A longer explanation of how the command works
|
||||
Description string
|
||||
// The function to call when checking for bash command completions
|
||||
BashComplete func(context *Context)
|
||||
// An action to execute before any sub-subcommands are run, but after the context is ready
|
||||
// If a non-nil error is returned, no sub-subcommands are run
|
||||
Before func(context *Context) error
|
||||
// An action to execute after any subcommands are run, but after the subcommand has finished
|
||||
// It is run even if Action() panics
|
||||
After func(context *Context) error
|
||||
// The function to call when this command is invoked
|
||||
Action func(context *Context)
|
||||
// List of child commands
|
||||
Subcommands []Command
|
||||
// List of flags to parse
|
||||
Flags []Flag
|
||||
// Treat all flags as normal arguments if true
|
||||
SkipFlagParsing bool
|
||||
// Boolean to hide built-in help command
|
||||
HideHelp bool
|
||||
|
||||
commandNamePath []string
|
||||
}
|
||||
|
||||
// Returns the full name of the command.
|
||||
// For subcommands this ensures that parent commands are part of the command path
|
||||
func (c Command) FullName() string {
|
||||
if c.commandNamePath == nil {
|
||||
return c.Name
|
||||
}
|
||||
return strings.Join(c.commandNamePath, " ")
|
||||
}
|
||||
|
||||
// Invokes the command given the context, parses ctx.Args() to generate command-specific flags
|
||||
func (c Command) Run(ctx *Context) error {
|
||||
if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil {
|
||||
return c.startApp(ctx)
|
||||
}
|
||||
|
||||
if !c.HideHelp && (HelpFlag != BoolFlag{}) {
|
||||
// append help to flags
|
||||
c.Flags = append(
|
||||
c.Flags,
|
||||
HelpFlag,
|
||||
)
|
||||
}
|
||||
|
||||
if ctx.App.EnableBashCompletion {
|
||||
c.Flags = append(c.Flags, BashCompletionFlag)
|
||||
}
|
||||
|
||||
set := flagSet(c.Name, c.Flags)
|
||||
set.SetOutput(ioutil.Discard)
|
||||
|
||||
firstFlagIndex := -1
|
||||
terminatorIndex := -1
|
||||
for index, arg := range ctx.Args() {
|
||||
if arg == "--" {
|
||||
terminatorIndex = index
|
||||
break
|
||||
} else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 {
|
||||
firstFlagIndex = index
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if firstFlagIndex > -1 && !c.SkipFlagParsing {
|
||||
args := ctx.Args()
|
||||
regularArgs := make([]string, len(args[1:firstFlagIndex]))
|
||||
copy(regularArgs, args[1:firstFlagIndex])
|
||||
|
||||
var flagArgs []string
|
||||
if terminatorIndex > -1 {
|
||||
flagArgs = args[firstFlagIndex:terminatorIndex]
|
||||
regularArgs = append(regularArgs, args[terminatorIndex:]...)
|
||||
} else {
|
||||
flagArgs = args[firstFlagIndex:]
|
||||
}
|
||||
|
||||
err = set.Parse(append(flagArgs, regularArgs...))
|
||||
} else {
|
||||
err = set.Parse(ctx.Args().Tail())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.")
|
||||
fmt.Fprintln(ctx.App.Writer)
|
||||
ShowCommandHelp(ctx, c.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
nerr := normalizeFlags(c.Flags, set)
|
||||
if nerr != nil {
|
||||
fmt.Fprintln(ctx.App.Writer, nerr)
|
||||
fmt.Fprintln(ctx.App.Writer)
|
||||
ShowCommandHelp(ctx, c.Name)
|
||||
return nerr
|
||||
}
|
||||
context := NewContext(ctx.App, set, ctx)
|
||||
|
||||
if checkCommandCompletions(context, c.Name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if checkCommandHelp(context, c.Name) {
|
||||
return nil
|
||||
}
|
||||
context.Command = c
|
||||
c.Action(context)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Command) Names() []string {
|
||||
names := []string{c.Name}
|
||||
|
||||
if c.ShortName != "" {
|
||||
names = append(names, c.ShortName)
|
||||
}
|
||||
|
||||
return append(names, c.Aliases...)
|
||||
}
|
||||
|
||||
// Returns true if Command.Name or Command.ShortName matches given name
|
||||
func (c Command) HasName(name string) bool {
|
||||
for _, n := range c.Names() {
|
||||
if n == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c Command) startApp(ctx *Context) error {
|
||||
app := NewApp()
|
||||
|
||||
// set the name and usage
|
||||
app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name)
|
||||
if c.Description != "" {
|
||||
app.Usage = c.Description
|
||||
} else {
|
||||
app.Usage = c.Usage
|
||||
}
|
||||
|
||||
// set CommandNotFound
|
||||
app.CommandNotFound = ctx.App.CommandNotFound
|
||||
|
||||
// set the flags and commands
|
||||
app.Commands = c.Subcommands
|
||||
app.Flags = c.Flags
|
||||
app.HideHelp = c.HideHelp
|
||||
|
||||
app.Version = ctx.App.Version
|
||||
app.HideVersion = ctx.App.HideVersion
|
||||
app.Compiled = ctx.App.Compiled
|
||||
app.Author = ctx.App.Author
|
||||
app.Email = ctx.App.Email
|
||||
app.Writer = ctx.App.Writer
|
||||
|
||||
// bash completion
|
||||
app.EnableBashCompletion = ctx.App.EnableBashCompletion
|
||||
if c.BashComplete != nil {
|
||||
app.BashComplete = c.BashComplete
|
||||
}
|
||||
|
||||
// set the actions
|
||||
app.Before = c.Before
|
||||
app.After = c.After
|
||||
if c.Action != nil {
|
||||
app.Action = c.Action
|
||||
} else {
|
||||
app.Action = helpSubcommand.Action
|
||||
}
|
||||
|
||||
var newCmds []Command
|
||||
for _, cc := range app.Commands {
|
||||
cc.commandNamePath = []string{c.Name, cc.Name}
|
||||
newCmds = append(newCmds, cc)
|
||||
}
|
||||
app.Commands = newCmds
|
||||
|
||||
return app.RunAsSubcommand(ctx)
|
||||
}
|
@ -0,0 +1,47 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCommandDoNotIgnoreFlags(t *testing.T) {
|
||||
app := NewApp()
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
test := []string{"blah", "blah", "-break"}
|
||||
set.Parse(test)
|
||||
|
||||
c := NewContext(app, set, nil)
|
||||
|
||||
command := Command{
|
||||
Name: "test-cmd",
|
||||
Aliases: []string{"tc"},
|
||||
Usage: "this is for testing",
|
||||
Description: "testing",
|
||||
Action: func(_ *Context) {},
|
||||
}
|
||||
err := command.Run(c)
|
||||
|
||||
expect(t, err.Error(), "flag provided but not defined: -break")
|
||||
}
|
||||
|
||||
func TestCommandIgnoreFlags(t *testing.T) {
|
||||
app := NewApp()
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
test := []string{"blah", "blah"}
|
||||
set.Parse(test)
|
||||
|
||||
c := NewContext(app, set, nil)
|
||||
|
||||
command := Command{
|
||||
Name: "test-cmd",
|
||||
Aliases: []string{"tc"},
|
||||
Usage: "this is for testing",
|
||||
Description: "testing",
|
||||
Action: func(_ *Context) {},
|
||||
SkipFlagParsing: true,
|
||||
}
|
||||
err := command.Run(c)
|
||||
|
||||
expect(t, err, nil)
|
||||
}
|
@ -0,0 +1,388 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Context is a type that is passed through to
|
||||
// each Handler action in a cli application. Context
|
||||
// can be used to retrieve context-specific Args and
|
||||
// parsed command-line options.
|
||||
type Context struct {
|
||||
App *App
|
||||
Command Command
|
||||
flagSet *flag.FlagSet
|
||||
setFlags map[string]bool
|
||||
globalSetFlags map[string]bool
|
||||
parentContext *Context
|
||||
}
|
||||
|
||||
// Creates a new context. For use in when invoking an App or Command action.
|
||||
func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context {
|
||||
return &Context{App: app, flagSet: set, parentContext: parentCtx}
|
||||
}
|
||||
|
||||
// Looks up the value of a local int flag, returns 0 if no int flag exists
|
||||
func (c *Context) Int(name string) int {
|
||||
return lookupInt(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists
|
||||
func (c *Context) Duration(name string) time.Duration {
|
||||
return lookupDuration(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists
|
||||
func (c *Context) Float64(name string) float64 {
|
||||
return lookupFloat64(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local bool flag, returns false if no bool flag exists
|
||||
func (c *Context) Bool(name string) bool {
|
||||
return lookupBool(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local boolT flag, returns false if no bool flag exists
|
||||
func (c *Context) BoolT(name string) bool {
|
||||
return lookupBoolT(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local string flag, returns "" if no string flag exists
|
||||
func (c *Context) String(name string) string {
|
||||
return lookupString(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local string slice flag, returns nil if no string slice flag exists
|
||||
func (c *Context) StringSlice(name string) []string {
|
||||
return lookupStringSlice(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local int slice flag, returns nil if no int slice flag exists
|
||||
func (c *Context) IntSlice(name string) []int {
|
||||
return lookupIntSlice(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a local generic flag, returns nil if no generic flag exists
|
||||
func (c *Context) Generic(name string) interface{} {
|
||||
return lookupGeneric(name, c.flagSet)
|
||||
}
|
||||
|
||||
// Looks up the value of a global int flag, returns 0 if no int flag exists
|
||||
func (c *Context) GlobalInt(name string) int {
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupInt(name, fs)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists
|
||||
func (c *Context) GlobalDuration(name string) time.Duration {
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupDuration(name, fs)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Looks up the value of a global bool flag, returns false if no bool flag exists
|
||||
func (c *Context) GlobalBool(name string) bool {
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupBool(name, fs)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Looks up the value of a global string flag, returns "" if no string flag exists
|
||||
func (c *Context) GlobalString(name string) string {
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupString(name, fs)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Looks up the value of a global string slice flag, returns nil if no string slice flag exists
|
||||
func (c *Context) GlobalStringSlice(name string) []string {
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupStringSlice(name, fs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Looks up the value of a global int slice flag, returns nil if no int slice flag exists
|
||||
func (c *Context) GlobalIntSlice(name string) []int {
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupIntSlice(name, fs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Looks up the value of a global generic flag, returns nil if no generic flag exists
|
||||
func (c *Context) GlobalGeneric(name string) interface{} {
|
||||
if fs := lookupGlobalFlagSet(name, c); fs != nil {
|
||||
return lookupGeneric(name, fs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the number of flags set
|
||||
func (c *Context) NumFlags() int {
|
||||
return c.flagSet.NFlag()
|
||||
}
|
||||
|
||||
// Determines if the flag was actually set
|
||||
func (c *Context) IsSet(name string) bool {
|
||||
if c.setFlags == nil {
|
||||
c.setFlags = make(map[string]bool)
|
||||
c.flagSet.Visit(func(f *flag.Flag) {
|
||||
c.setFlags[f.Name] = true
|
||||
})
|
||||
}
|
||||
return c.setFlags[name] == true
|
||||
}
|
||||
|
||||
// Determines if the global flag was actually set
|
||||
func (c *Context) GlobalIsSet(name string) bool {
|
||||
if c.globalSetFlags == nil {
|
||||
c.globalSetFlags = make(map[string]bool)
|
||||
ctx := c
|
||||
if ctx.parentContext != nil {
|
||||
ctx = ctx.parentContext
|
||||
}
|
||||
for ; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext {
|
||||
ctx.flagSet.Visit(func(f *flag.Flag) {
|
||||
c.globalSetFlags[f.Name] = true
|
||||
})
|
||||
}
|
||||
}
|
||||
return c.globalSetFlags[name]
|
||||
}
|
||||
|
||||
// Returns a slice of flag names used in this context.
|
||||
func (c *Context) FlagNames() (names []string) {
|
||||
for _, flag := range c.Command.Flags {
|
||||
name := strings.Split(flag.getName(), ",")[0]
|
||||
if name == "help" {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Returns a slice of global flag names used by the app.
|
||||
func (c *Context) GlobalFlagNames() (names []string) {
|
||||
for _, flag := range c.App.Flags {
|
||||
name := strings.Split(flag.getName(), ",")[0]
|
||||
if name == "help" || name == "version" {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Returns the parent context, if any
|
||||
func (c *Context) Parent() *Context {
|
||||
return c.parentContext
|
||||
}
|
||||
|
||||
type Args []string
|
||||
|
||||
// Returns the command line arguments associated with the context.
|
||||
func (c *Context) Args() Args {
|
||||
args := Args(c.flagSet.Args())
|
||||
return args
|
||||
}
|
||||
|
||||
// Returns the nth argument, or else a blank string
|
||||
func (a Args) Get(n int) string {
|
||||
if len(a) > n {
|
||||
return a[n]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Returns the first argument, or else a blank string
|
||||
func (a Args) First() string {
|
||||
return a.Get(0)
|
||||
}
|
||||
|
||||
// Return the rest of the arguments (not the first one)
|
||||
// or else an empty string slice
|
||||
func (a Args) Tail() []string {
|
||||
if len(a) >= 2 {
|
||||
return []string(a)[1:]
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// Checks if there are any arguments present
|
||||
func (a Args) Present() bool {
|
||||
return len(a) != 0
|
||||
}
|
||||
|
||||
// Swaps arguments at the given indexes
|
||||
func (a Args) Swap(from, to int) error {
|
||||
if from >= len(a) || to >= len(a) {
|
||||
return errors.New("index out of range")
|
||||
}
|
||||
a[from], a[to] = a[to], a[from]
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet {
|
||||
if ctx.parentContext != nil {
|
||||
ctx = ctx.parentContext
|
||||
}
|
||||
for ; ctx != nil; ctx = ctx.parentContext {
|
||||
if f := ctx.flagSet.Lookup(name); f != nil {
|
||||
return ctx.flagSet
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupInt(name string, set *flag.FlagSet) int {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
val, err := strconv.Atoi(f.Value.String())
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func lookupDuration(name string, set *flag.FlagSet) time.Duration {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
val, err := time.ParseDuration(f.Value.String())
|
||||
if err == nil {
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func lookupFloat64(name string, set *flag.FlagSet) float64 {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
val, err := strconv.ParseFloat(f.Value.String(), 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func lookupString(name string, set *flag.FlagSet) string {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
return f.Value.String()
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func lookupStringSlice(name string, set *flag.FlagSet) []string {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
return (f.Value.(*StringSlice)).Value()
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupIntSlice(name string, set *flag.FlagSet) []int {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
return (f.Value.(*IntSlice)).Value()
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupGeneric(name string, set *flag.FlagSet) interface{} {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
return f.Value
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupBool(name string, set *flag.FlagSet) bool {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
val, err := strconv.ParseBool(f.Value.String())
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func lookupBoolT(name string, set *flag.FlagSet) bool {
|
||||
f := set.Lookup(name)
|
||||
if f != nil {
|
||||
val, err := strconv.ParseBool(f.Value.String())
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) {
|
||||
switch ff.Value.(type) {
|
||||
case *StringSlice:
|
||||
default:
|
||||
set.Set(name, ff.Value.String())
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeFlags(flags []Flag, set *flag.FlagSet) error {
|
||||
visited := make(map[string]bool)
|
||||
set.Visit(func(f *flag.Flag) {
|
||||
visited[f.Name] = true
|
||||
})
|
||||
for _, f := range flags {
|
||||
parts := strings.Split(f.getName(), ",")
|
||||
if len(parts) == 1 {
|
||||
continue
|
||||
}
|
||||
var ff *flag.Flag
|
||||
for _, name := range parts {
|
||||
name = strings.Trim(name, " ")
|
||||
if visited[name] {
|
||||
if ff != nil {
|
||||
return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name)
|
||||
}
|
||||
ff = set.Lookup(name)
|
||||
}
|
||||
}
|
||||
if ff == nil {
|
||||
continue
|
||||
}
|
||||
for _, name := range parts {
|
||||
name = strings.Trim(name, " ")
|
||||
if !visited[name] {
|
||||
copyFlag(name, ff, set)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,113 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Int("myflag", 12, "doc")
|
||||
globalSet := flag.NewFlagSet("test", 0)
|
||||
globalSet.Int("myflag", 42, "doc")
|
||||
globalCtx := NewContext(nil, globalSet, nil)
|
||||
command := Command{Name: "mycommand"}
|
||||
c := NewContext(nil, set, globalCtx)
|
||||
c.Command = command
|
||||
expect(t, c.Int("myflag"), 12)
|
||||
expect(t, c.GlobalInt("myflag"), 42)
|
||||
expect(t, c.Command.Name, "mycommand")
|
||||
}
|
||||
|
||||
func TestContext_Int(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Int("myflag", 12, "doc")
|
||||
c := NewContext(nil, set, nil)
|
||||
expect(t, c.Int("myflag"), 12)
|
||||
}
|
||||
|
||||
func TestContext_Duration(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Duration("myflag", time.Duration(12*time.Second), "doc")
|
||||
c := NewContext(nil, set, nil)
|
||||
expect(t, c.Duration("myflag"), time.Duration(12*time.Second))
|
||||
}
|
||||
|
||||
func TestContext_String(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.String("myflag", "hello world", "doc")
|
||||
c := NewContext(nil, set, nil)
|
||||
expect(t, c.String("myflag"), "hello world")
|
||||
}
|
||||
|
||||
func TestContext_Bool(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
c := NewContext(nil, set, nil)
|
||||
expect(t, c.Bool("myflag"), false)
|
||||
}
|
||||
|
||||
func TestContext_BoolT(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", true, "doc")
|
||||
c := NewContext(nil, set, nil)
|
||||
expect(t, c.BoolT("myflag"), true)
|
||||
}
|
||||
|
||||
func TestContext_Args(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
c := NewContext(nil, set, nil)
|
||||
set.Parse([]string{"--myflag", "bat", "baz"})
|
||||
expect(t, len(c.Args()), 2)
|
||||
expect(t, c.Bool("myflag"), true)
|
||||
}
|
||||
|
||||
func TestContext_IsSet(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
set.String("otherflag", "hello world", "doc")
|
||||
globalSet := flag.NewFlagSet("test", 0)
|
||||
globalSet.Bool("myflagGlobal", true, "doc")
|
||||
globalCtx := NewContext(nil, globalSet, nil)
|
||||
c := NewContext(nil, set, globalCtx)
|
||||
set.Parse([]string{"--myflag", "bat", "baz"})
|
||||
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
|
||||
expect(t, c.IsSet("myflag"), true)
|
||||
expect(t, c.IsSet("otherflag"), false)
|
||||
expect(t, c.IsSet("bogusflag"), false)
|
||||
expect(t, c.IsSet("myflagGlobal"), false)
|
||||
}
|
||||
|
||||
func TestContext_GlobalIsSet(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
set.String("otherflag", "hello world", "doc")
|
||||
globalSet := flag.NewFlagSet("test", 0)
|
||||
globalSet.Bool("myflagGlobal", true, "doc")
|
||||
globalSet.Bool("myflagGlobalUnset", true, "doc")
|
||||
globalCtx := NewContext(nil, globalSet, nil)
|
||||
c := NewContext(nil, set, globalCtx)
|
||||
set.Parse([]string{"--myflag", "bat", "baz"})
|
||||
globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"})
|
||||
expect(t, c.GlobalIsSet("myflag"), false)
|
||||
expect(t, c.GlobalIsSet("otherflag"), false)
|
||||
expect(t, c.GlobalIsSet("bogusflag"), false)
|
||||
expect(t, c.GlobalIsSet("myflagGlobal"), true)
|
||||
expect(t, c.GlobalIsSet("myflagGlobalUnset"), false)
|
||||
expect(t, c.GlobalIsSet("bogusGlobal"), false)
|
||||
}
|
||||
|
||||
func TestContext_NumFlags(t *testing.T) {
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.Bool("myflag", false, "doc")
|
||||
set.String("otherflag", "hello world", "doc")
|
||||
globalSet := flag.NewFlagSet("test", 0)
|
||||
globalSet.Bool("myflagGlobal", true, "doc")
|
||||
globalCtx := NewContext(nil, globalSet, nil)
|
||||
c := NewContext(nil, set, globalCtx)
|
||||
set.Parse([]string{"--myflag", "--otherflag=foo"})
|
||||
globalSet.Parse([]string{"--myflagGlobal"})
|
||||
expect(t, c.NumFlags(), 2)
|
||||
}
|
@ -0,0 +1,497 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// This flag enables bash-completion for all commands and subcommands
|
||||
var BashCompletionFlag = BoolFlag{
|
||||
Name: "generate-bash-completion",
|
||||
}
|
||||
|
||||
// This flag prints the version for the application
|
||||
var VersionFlag = BoolFlag{
|
||||
Name: "version, v",
|
||||
Usage: "print the version",
|
||||
}
|
||||
|
||||
// This flag prints the help for all commands and subcommands
|
||||
// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand
|
||||
// unless HideHelp is set to true)
|
||||
var HelpFlag = BoolFlag{
|
||||
Name: "help, h",
|
||||
Usage: "show help",
|
||||
}
|
||||
|
||||
// Flag is a common interface related to parsing flags in cli.
|
||||
// For more advanced flag parsing techniques, it is recomended that
|
||||
// this interface be implemented.
|
||||
type Flag interface {
|
||||
fmt.Stringer
|
||||
// Apply Flag settings to the given flag set
|
||||
Apply(*flag.FlagSet)
|
||||
getName() string
|
||||
}
|
||||
|
||||
func flagSet(name string, flags []Flag) *flag.FlagSet {
|
||||
set := flag.NewFlagSet(name, flag.ContinueOnError)
|
||||
|
||||
for _, f := range flags {
|
||||
f.Apply(set)
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func eachName(longName string, fn func(string)) {
|
||||
parts := strings.Split(longName, ",")
|
||||
for _, name := range parts {
|
||||
name = strings.Trim(name, " ")
|
||||
fn(name)
|
||||
}
|
||||
}
|
||||
|
||||
// Generic is a generic parseable type identified by a specific flag
|
||||
type Generic interface {
|
||||
Set(value string) error
|
||||
String() string
|
||||
}
|
||||
|
||||
// GenericFlag is the flag type for types implementing Generic
|
||||
type GenericFlag struct {
|
||||
Name string
|
||||
Value Generic
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the string representation of the generic flag to display the
|
||||
// help text to the user (uses the String() method of the generic flag to show
|
||||
// the value)
|
||||
func (f GenericFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply takes the flagset and calls Set on the generic flag with the value
|
||||
// provided by the user for parsing by the flag
|
||||
func (f GenericFlag) Apply(set *flag.FlagSet) {
|
||||
val := f.Value
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
val.Set(envVal)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
set.Var(f.Value, name, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f GenericFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// StringSlice is an opaque type for []string to satisfy flag.Value
|
||||
type StringSlice []string
|
||||
|
||||
// Set appends the string value to the list of values
|
||||
func (f *StringSlice) Set(value string) error {
|
||||
*f = append(*f, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f *StringSlice) String() string {
|
||||
return fmt.Sprintf("%s", *f)
|
||||
}
|
||||
|
||||
// Value returns the slice of strings set by this flag
|
||||
func (f *StringSlice) Value() []string {
|
||||
return *f
|
||||
}
|
||||
|
||||
// StringSlice is a string flag that can be specified multiple times on the
|
||||
// command-line
|
||||
type StringSliceFlag struct {
|
||||
Name string
|
||||
Value *StringSlice
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f StringSliceFlag) String() string {
|
||||
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
|
||||
pref := prefixFor(firstName)
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f StringSliceFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
newVal := &StringSlice{}
|
||||
for _, s := range strings.Split(envVal, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
newVal.Set(s)
|
||||
}
|
||||
f.Value = newVal
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Value == nil {
|
||||
f.Value = &StringSlice{}
|
||||
}
|
||||
set.Var(f.Value, name, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f StringSliceFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// StringSlice is an opaque type for []int to satisfy flag.Value
|
||||
type IntSlice []int
|
||||
|
||||
// Set parses the value into an integer and appends it to the list of values
|
||||
func (f *IntSlice) Set(value string) error {
|
||||
tmp, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return err
|
||||
} else {
|
||||
*f = append(*f, tmp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f *IntSlice) String() string {
|
||||
return fmt.Sprintf("%d", *f)
|
||||
}
|
||||
|
||||
// Value returns the slice of ints set by this flag
|
||||
func (f *IntSlice) Value() []int {
|
||||
return *f
|
||||
}
|
||||
|
||||
// IntSliceFlag is an int flag that can be specified multiple times on the
|
||||
// command-line
|
||||
type IntSliceFlag struct {
|
||||
Name string
|
||||
Value *IntSlice
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f IntSliceFlag) String() string {
|
||||
firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ")
|
||||
pref := prefixFor(firstName)
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f IntSliceFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
newVal := &IntSlice{}
|
||||
for _, s := range strings.Split(envVal, ",") {
|
||||
s = strings.TrimSpace(s)
|
||||
err := newVal.Set(s)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, err.Error())
|
||||
}
|
||||
}
|
||||
f.Value = newVal
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
if f.Value == nil {
|
||||
f.Value = &IntSlice{}
|
||||
}
|
||||
set.Var(f.Value, name, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f IntSliceFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// BoolFlag is a switch that defaults to false
|
||||
type BoolFlag struct {
|
||||
Name string
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f BoolFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f BoolFlag) Apply(set *flag.FlagSet) {
|
||||
val := false
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValBool, err := strconv.ParseBool(envVal)
|
||||
if err == nil {
|
||||
val = envValBool
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
set.Bool(name, val, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f BoolFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// BoolTFlag this represents a boolean flag that is true by default, but can
|
||||
// still be set to false by --some-flag=false
|
||||
type BoolTFlag struct {
|
||||
Name string
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f BoolTFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f BoolTFlag) Apply(set *flag.FlagSet) {
|
||||
val := true
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValBool, err := strconv.ParseBool(envVal)
|
||||
if err == nil {
|
||||
val = envValBool
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
set.Bool(name, val, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f BoolTFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// StringFlag represents a flag that takes as string value
|
||||
type StringFlag struct {
|
||||
Name string
|
||||
Value string
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f StringFlag) String() string {
|
||||
var fmtString string
|
||||
fmtString = "%s %v\t%v"
|
||||
|
||||
if len(f.Value) > 0 {
|
||||
fmtString = "%s \"%v\"\t%v"
|
||||
} else {
|
||||
fmtString = "%s %v\t%v"
|
||||
}
|
||||
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f StringFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
f.Value = envVal
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
set.String(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f StringFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// IntFlag is a flag that takes an integer
|
||||
// Errors if the value provided cannot be parsed
|
||||
type IntFlag struct {
|
||||
Name string
|
||||
Value int
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f IntFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f IntFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValInt, err := strconv.ParseInt(envVal, 0, 64)
|
||||
if err == nil {
|
||||
f.Value = int(envValInt)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
set.Int(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f IntFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// DurationFlag is a flag that takes a duration specified in Go's duration
|
||||
// format: https://golang.org/pkg/time/#ParseDuration
|
||||
type DurationFlag struct {
|
||||
Name string
|
||||
Value time.Duration
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns a readable representation of this value (for usage defaults)
|
||||
func (f DurationFlag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f DurationFlag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValDuration, err := time.ParseDuration(envVal)
|
||||
if err == nil {
|
||||
f.Value = envValDuration
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
set.Duration(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f DurationFlag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
// Float64Flag is a flag that takes an float value
|
||||
// Errors if the value provided cannot be parsed
|
||||
type Float64Flag struct {
|
||||
Name string
|
||||
Value float64
|
||||
Usage string
|
||||
EnvVar string
|
||||
}
|
||||
|
||||
// String returns the usage
|
||||
func (f Float64Flag) String() string {
|
||||
return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage))
|
||||
}
|
||||
|
||||
// Apply populates the flag given the flag set and environment
|
||||
func (f Float64Flag) Apply(set *flag.FlagSet) {
|
||||
if f.EnvVar != "" {
|
||||
for _, envVar := range strings.Split(f.EnvVar, ",") {
|
||||
envVar = strings.TrimSpace(envVar)
|
||||
if envVal := os.Getenv(envVar); envVal != "" {
|
||||
envValFloat, err := strconv.ParseFloat(envVal, 10)
|
||||
if err == nil {
|
||||
f.Value = float64(envValFloat)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
eachName(f.Name, func(name string) {
|
||||
set.Float64(name, f.Value, f.Usage)
|
||||
})
|
||||
}
|
||||
|
||||
func (f Float64Flag) getName() string {
|
||||
return f.Name
|
||||
}
|
||||
|
||||
func prefixFor(name string) (prefix string) {
|
||||
if len(name) == 1 {
|
||||
prefix = "-"
|
||||
} else {
|
||||
prefix = "--"
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func prefixedNames(fullName string) (prefixed string) {
|
||||
parts := strings.Split(fullName, ",")
|
||||
for i, name := range parts {
|
||||
name = strings.Trim(name, " ")
|
||||
prefixed += prefixFor(name) + name
|
||||
if i < len(parts)-1 {
|
||||
prefixed += ", "
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func withEnvHint(envVar, str string) string {
|
||||
envText := ""
|
||||
if envVar != "" {
|
||||
envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $"))
|
||||
}
|
||||
return str + envText
|
||||
}
|
@ -0,0 +1,740 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var boolFlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help\t"},
|
||||
{"h", "-h\t"},
|
||||
}
|
||||
|
||||
func TestBoolFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range boolFlagTests {
|
||||
flag := BoolFlag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var stringFlagTests = []struct {
|
||||
name string
|
||||
value string
|
||||
expected string
|
||||
}{
|
||||
{"help", "", "--help \t"},
|
||||
{"h", "", "-h \t"},
|
||||
{"h", "", "-h \t"},
|
||||
{"test", "Something", "--test \"Something\"\t"},
|
||||
}
|
||||
|
||||
func TestStringFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range stringFlagTests {
|
||||
flag := StringFlag{Name: test.name, Value: test.value}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_FOO", "derp")
|
||||
for _, test := range stringFlagTests {
|
||||
flag := StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_FOO]") {
|
||||
t.Errorf("%s does not end with [$APP_FOO]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var stringSliceFlagTests = []struct {
|
||||
name string
|
||||
value *StringSlice
|
||||
expected string
|
||||
}{
|
||||
{"help", func() *StringSlice {
|
||||
s := &StringSlice{}
|
||||
s.Set("")
|
||||
return s
|
||||
}(), "--help [--help option --help option]\t"},
|
||||
{"h", func() *StringSlice {
|
||||
s := &StringSlice{}
|
||||
s.Set("")
|
||||
return s
|
||||
}(), "-h [-h option -h option]\t"},
|
||||
{"h", func() *StringSlice {
|
||||
s := &StringSlice{}
|
||||
s.Set("")
|
||||
return s
|
||||
}(), "-h [-h option -h option]\t"},
|
||||
{"test", func() *StringSlice {
|
||||
s := &StringSlice{}
|
||||
s.Set("Something")
|
||||
return s
|
||||
}(), "--test [--test option --test option]\t"},
|
||||
}
|
||||
|
||||
func TestStringSliceFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range stringSliceFlagTests {
|
||||
flag := StringSliceFlag{Name: test.name, Value: test.value}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%q does not match %q", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_QWWX", "11,4")
|
||||
for _, test := range stringSliceFlagTests {
|
||||
flag := StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_QWWX]") {
|
||||
t.Errorf("%q does not end with [$APP_QWWX]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var intFlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help \"0\"\t"},
|
||||
{"h", "-h \"0\"\t"},
|
||||
}
|
||||
|
||||
func TestIntFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range intFlagTests {
|
||||
flag := IntFlag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_BAR", "2")
|
||||
for _, test := range intFlagTests {
|
||||
flag := IntFlag{Name: test.name, EnvVar: "APP_BAR"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAR]") {
|
||||
t.Errorf("%s does not end with [$APP_BAR]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var durationFlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help \"0\"\t"},
|
||||
{"h", "-h \"0\"\t"},
|
||||
}
|
||||
|
||||
func TestDurationFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range durationFlagTests {
|
||||
flag := DurationFlag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_BAR", "2h3m6s")
|
||||
for _, test := range durationFlagTests {
|
||||
flag := DurationFlag{Name: test.name, EnvVar: "APP_BAR"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAR]") {
|
||||
t.Errorf("%s does not end with [$APP_BAR]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var intSliceFlagTests = []struct {
|
||||
name string
|
||||
value *IntSlice
|
||||
expected string
|
||||
}{
|
||||
{"help", &IntSlice{}, "--help [--help option --help option]\t"},
|
||||
{"h", &IntSlice{}, "-h [-h option -h option]\t"},
|
||||
{"h", &IntSlice{}, "-h [-h option -h option]\t"},
|
||||
{"test", func() *IntSlice {
|
||||
i := &IntSlice{}
|
||||
i.Set("9")
|
||||
return i
|
||||
}(), "--test [--test option --test option]\t"},
|
||||
}
|
||||
|
||||
func TestIntSliceFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range intSliceFlagTests {
|
||||
flag := IntSliceFlag{Name: test.name, Value: test.value}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%q does not match %q", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_SMURF", "42,3")
|
||||
for _, test := range intSliceFlagTests {
|
||||
flag := IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_SMURF]") {
|
||||
t.Errorf("%q does not end with [$APP_SMURF]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var float64FlagTests = []struct {
|
||||
name string
|
||||
expected string
|
||||
}{
|
||||
{"help", "--help \"0\"\t"},
|
||||
{"h", "-h \"0\"\t"},
|
||||
}
|
||||
|
||||
func TestFloat64FlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range float64FlagTests {
|
||||
flag := Float64Flag{Name: test.name}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%s does not match %s", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_BAZ", "99.4")
|
||||
for _, test := range float64FlagTests {
|
||||
flag := Float64Flag{Name: test.name, EnvVar: "APP_BAZ"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_BAZ]") {
|
||||
t.Errorf("%s does not end with [$APP_BAZ]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var genericFlagTests = []struct {
|
||||
name string
|
||||
value Generic
|
||||
expected string
|
||||
}{
|
||||
{"test", &Parser{"abc", "def"}, "--test \"abc,def\"\ttest flag"},
|
||||
{"t", &Parser{"abc", "def"}, "-t \"abc,def\"\ttest flag"},
|
||||
}
|
||||
|
||||
func TestGenericFlagHelpOutput(t *testing.T) {
|
||||
|
||||
for _, test := range genericFlagTests {
|
||||
flag := GenericFlag{Name: test.name, Value: test.value, Usage: "test flag"}
|
||||
output := flag.String()
|
||||
|
||||
if output != test.expected {
|
||||
t.Errorf("%q does not match %q", output, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_ZAP", "3")
|
||||
for _, test := range genericFlagTests {
|
||||
flag := GenericFlag{Name: test.name, EnvVar: "APP_ZAP"}
|
||||
output := flag.String()
|
||||
|
||||
if !strings.HasSuffix(output, " [$APP_ZAP]") {
|
||||
t.Errorf("%s does not end with [$APP_ZAP]", output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMultiString(t *testing.T) {
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
StringFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.String("serve") != "10" {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.String("s") != "10" {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run", "-s", "10"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringFromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_COUNT", "20")
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
StringFlag{Name: "count, c", EnvVar: "APP_COUNT"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.String("count") != "20" {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.String("c") != "20" {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringFromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_COUNT", "20")
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
StringFlag{Name: "count, c", EnvVar: "COMPAT_COUNT,APP_COUNT"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.String("count") != "20" {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.String("c") != "20" {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringSlice(t *testing.T) {
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
StringSliceFlag{Name: "serve, s", Value: &StringSlice{}},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run", "-s", "10", "-s", "20"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringSliceFromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_INTERVALS", "20,30,40")
|
||||
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
StringSliceFlag{Name: "intervals, i", Value: &StringSlice{}, EnvVar: "APP_INTERVALS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiStringSliceFromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_INTERVALS", "20,30,40")
|
||||
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
StringSliceFlag{Name: "intervals, i", Value: &StringSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiInt(t *testing.T) {
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
IntFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Int("serve") != 10 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Int("s") != 10 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "-s", "10"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntFromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_TIMEOUT_SECONDS", "10")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Int("timeout") != 10 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Int("t") != 10 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntFromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_TIMEOUT_SECONDS", "10")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
IntFlag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Int("timeout") != 10 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Int("t") != 10 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntSlice(t *testing.T) {
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
IntSliceFlag{Name: "serve, s", Value: &IntSlice{}},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run", "-s", "10", "-s", "20"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntSliceFromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_INTERVALS", "20,30,40")
|
||||
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
IntSliceFlag{Name: "intervals, i", Value: &IntSlice{}, EnvVar: "APP_INTERVALS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiIntSliceFromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_INTERVALS", "20,30,40")
|
||||
|
||||
(&App{
|
||||
Flags: []Flag{
|
||||
IntSliceFlag{Name: "intervals, i", Value: &IntSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}).Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiFloat64(t *testing.T) {
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
Float64Flag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Float64("serve") != 10.2 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Float64("s") != 10.2 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "-s", "10.2"})
|
||||
}
|
||||
|
||||
func TestParseMultiFloat64FromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Float64("timeout") != 15.5 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Float64("t") != 15.5 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiFloat64FromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_TIMEOUT_SECONDS", "15.5")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
Float64Flag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Float64("timeout") != 15.5 {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Float64("t") != 15.5 {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiBool(t *testing.T) {
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
BoolFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Bool("serve") != true {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.Bool("s") != true {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "--serve"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolFromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_DEBUG", "1")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Bool("debug") != true {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if ctx.Bool("d") != true {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolFromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_DEBUG", "1")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
BoolFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.Bool("debug") != true {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if ctx.Bool("d") != true {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolT(t *testing.T) {
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
BoolTFlag{Name: "serve, s"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.BoolT("serve") != true {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if ctx.BoolT("s") != true {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "--serve"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolTFromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_DEBUG", "0")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.BoolT("debug") != false {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if ctx.BoolT("d") != false {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseMultiBoolTFromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_DEBUG", "0")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
BoolTFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if ctx.BoolT("debug") != false {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if ctx.BoolT("d") != false {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
type Parser [2]string
|
||||
|
||||
func (p *Parser) Set(value string) error {
|
||||
parts := strings.Split(value, ",")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid format")
|
||||
}
|
||||
|
||||
(*p)[0] = parts[0]
|
||||
(*p)[1] = parts[1]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Parser) String() string {
|
||||
return fmt.Sprintf("%s,%s", p[0], p[1])
|
||||
}
|
||||
|
||||
func TestParseGeneric(t *testing.T) {
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
GenericFlag{Name: "serve, s", Value: &Parser{}},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) {
|
||||
t.Errorf("main name not set")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) {
|
||||
t.Errorf("short name not set")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run", "-s", "10,20"})
|
||||
}
|
||||
|
||||
func TestParseGenericFromEnv(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_SERVE", "20,30")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) {
|
||||
t.Errorf("main name not set from env")
|
||||
}
|
||||
if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) {
|
||||
t.Errorf("short name not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
||||
|
||||
func TestParseGenericFromEnvCascade(t *testing.T) {
|
||||
os.Clearenv()
|
||||
os.Setenv("APP_FOO", "99,2000")
|
||||
a := App{
|
||||
Flags: []Flag{
|
||||
GenericFlag{Name: "foos", Value: &Parser{}, EnvVar: "COMPAT_FOO,APP_FOO"},
|
||||
},
|
||||
Action: func(ctx *Context) {
|
||||
if !reflect.DeepEqual(ctx.Generic("foos"), &Parser{"99", "2000"}) {
|
||||
t.Errorf("value not set from env")
|
||||
}
|
||||
},
|
||||
}
|
||||
a.Run([]string{"run"})
|
||||
}
|
@ -0,0 +1,238 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// The text template for the Default help topic.
|
||||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var AppHelpTemplate = `NAME:
|
||||
{{.Name}} - {{.Usage}}
|
||||
|
||||
USAGE:
|
||||
{{.Name}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} [arguments...]
|
||||
{{if .Version}}
|
||||
VERSION:
|
||||
{{.Version}}
|
||||
{{end}}{{if len .Authors}}
|
||||
AUTHOR(S):
|
||||
{{range .Authors}}{{ . }}{{end}}
|
||||
{{end}}{{if .Commands}}
|
||||
COMMANDS:
|
||||
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
|
||||
{{end}}{{end}}{{if .Flags}}
|
||||
GLOBAL OPTIONS:
|
||||
{{range .Flags}}{{.}}
|
||||
{{end}}{{end}}{{if .Copyright }}
|
||||
COPYRIGHT:
|
||||
{{.Copyright}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
// The text template for the command help topic.
|
||||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var CommandHelpTemplate = `NAME:
|
||||
{{.FullName}} - {{.Usage}}
|
||||
|
||||
USAGE:
|
||||
command {{.FullName}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}}
|
||||
|
||||
DESCRIPTION:
|
||||
{{.Description}}{{end}}{{if .Flags}}
|
||||
|
||||
OPTIONS:
|
||||
{{range .Flags}}{{.}}
|
||||
{{end}}{{ end }}
|
||||
`
|
||||
|
||||
// The text template for the subcommand help topic.
|
||||
// cli.go uses text/template to render templates. You can
|
||||
// render custom help text by setting this variable.
|
||||
var SubcommandHelpTemplate = `NAME:
|
||||
{{.Name}} - {{.Usage}}
|
||||
|
||||
USAGE:
|
||||
{{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...]
|
||||
|
||||
COMMANDS:
|
||||
{{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}}
|
||||
{{end}}{{if .Flags}}
|
||||
OPTIONS:
|
||||
{{range .Flags}}{{.}}
|
||||
{{end}}{{end}}
|
||||
`
|
||||
|
||||
var helpCommand = Command{
|
||||
Name: "help",
|
||||
Aliases: []string{"h"},
|
||||
Usage: "Shows a list of commands or help for one command",
|
||||
Action: func(c *Context) {
|
||||
args := c.Args()
|
||||
if args.Present() {
|
||||
ShowCommandHelp(c, args.First())
|
||||
} else {
|
||||
ShowAppHelp(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var helpSubcommand = Command{
|
||||
Name: "help",
|
||||
Aliases: []string{"h"},
|
||||
Usage: "Shows a list of commands or help for one command",
|
||||
Action: func(c *Context) {
|
||||
args := c.Args()
|
||||
if args.Present() {
|
||||
ShowCommandHelp(c, args.First())
|
||||
} else {
|
||||
ShowSubcommandHelp(c)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Prints help for the App or Command
|
||||
type helpPrinter func(w io.Writer, templ string, data interface{})
|
||||
|
||||
var HelpPrinter helpPrinter = printHelp
|
||||
|
||||
// Prints version for the App
|
||||
var VersionPrinter = printVersion
|
||||
|
||||
func ShowAppHelp(c *Context) {
|
||||
HelpPrinter(c.App.Writer, AppHelpTemplate, c.App)
|
||||
}
|
||||
|
||||
// Prints the list of subcommands as the default app completion method
|
||||
func DefaultAppComplete(c *Context) {
|
||||
for _, command := range c.App.Commands {
|
||||
for _, name := range command.Names() {
|
||||
fmt.Fprintln(c.App.Writer, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prints help for the given command
|
||||
func ShowCommandHelp(ctx *Context, command string) {
|
||||
// show the subcommand help for a command with subcommands
|
||||
if command == "" {
|
||||
HelpPrinter(ctx.App.Writer, SubcommandHelpTemplate, ctx.App)
|
||||
return
|
||||
}
|
||||
|
||||
for _, c := range ctx.App.Commands {
|
||||
if c.HasName(command) {
|
||||
HelpPrinter(ctx.App.Writer, CommandHelpTemplate, c)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.App.CommandNotFound != nil {
|
||||
ctx.App.CommandNotFound(ctx, command)
|
||||
} else {
|
||||
fmt.Fprintf(ctx.App.Writer, "No help topic for '%v'\n", command)
|
||||
}
|
||||
}
|
||||
|
||||
// Prints help for the given subcommand
|
||||
func ShowSubcommandHelp(c *Context) {
|
||||
ShowCommandHelp(c, c.Command.Name)
|
||||
}
|
||||
|
||||
// Prints the version number of the App
|
||||
func ShowVersion(c *Context) {
|
||||
VersionPrinter(c)
|
||||
}
|
||||
|
||||
func printVersion(c *Context) {
|
||||
fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version)
|
||||
}
|
||||
|
||||
// Prints the lists of commands within a given context
|
||||
func ShowCompletions(c *Context) {
|
||||
a := c.App
|
||||
if a != nil && a.BashComplete != nil {
|
||||
a.BashComplete(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Prints the custom completions for a given command
|
||||
func ShowCommandCompletions(ctx *Context, command string) {
|
||||
c := ctx.App.Command(command)
|
||||
if c != nil && c.BashComplete != nil {
|
||||
c.BashComplete(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func printHelp(out io.Writer, templ string, data interface{}) {
|
||||
funcMap := template.FuncMap{
|
||||
"join": strings.Join,
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(out, 0, 8, 1, '\t', 0)
|
||||
t := template.Must(template.New("help").Funcs(funcMap).Parse(templ))
|
||||
err := t.Execute(w, data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
|
||||
func checkVersion(c *Context) bool {
|
||||
if c.GlobalBool("version") || c.GlobalBool("v") || c.Bool("version") || c.Bool("v") {
|
||||
ShowVersion(c)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkHelp(c *Context) bool {
|
||||
if c.GlobalBool("h") || c.GlobalBool("help") || c.Bool("h") || c.Bool("help") {
|
||||
ShowAppHelp(c)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkCommandHelp(c *Context, name string) bool {
|
||||
if c.Bool("h") || c.Bool("help") {
|
||||
ShowCommandHelp(c, name)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkSubcommandHelp(c *Context) bool {
|
||||
if c.GlobalBool("h") || c.GlobalBool("help") {
|
||||
ShowSubcommandHelp(c)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkCompletions(c *Context) bool {
|
||||
if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion {
|
||||
ShowCompletions(c)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkCommandCompletions(c *Context, name string) bool {
|
||||
if c.Bool(BashCompletionFlag.Name) && c.App.EnableBashCompletion {
|
||||
ShowCommandCompletions(c, name)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ShowAppHelp_NoAuthor(t *testing.T) {
|
||||
output := new(bytes.Buffer)
|
||||
app := NewApp()
|
||||
app.Writer = output
|
||||
|
||||
c := NewContext(app, nil, nil)
|
||||
|
||||
ShowAppHelp(c)
|
||||
|
||||
if bytes.Index(output.Bytes(), []byte("AUTHOR(S):")) != -1 {
|
||||
t.Errorf("expected\n%snot to include %s", output.String(), "AUTHOR(S):")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ShowAppHelp_NoVersion(t *testing.T) {
|
||||
output := new(bytes.Buffer)
|
||||
app := NewApp()
|
||||
app.Writer = output
|
||||
|
||||
app.Version = ""
|
||||
|
||||
c := NewContext(app, nil, nil)
|
||||
|
||||
ShowAppHelp(c)
|
||||
|
||||
if bytes.Index(output.Bytes(), []byte("VERSION:")) != -1 {
|
||||
t.Errorf("expected\n%snot to include %s", output.String(), "VERSION:")
|
||||
}
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
/* Test Helpers */
|
||||
func expect(t *testing.T, a interface{}, b interface{}) {
|
||||
if a != b {
|
||||
t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
|
||||
}
|
||||
}
|
||||
|
||||
func refute(t *testing.T, a interface{}, b interface{}) {
|
||||
if a == b {
|
||||
t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a))
|
||||
}
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
@ -0,0 +1,24 @@
|
||||
# kproc
|
||||
[![GoDoc](https://godoc.org/github.com/codeskyblue/kproc?status.svg)](https://godoc.org/github.com/codeskyblue/kproc)
|
||||
|
||||
This is a golang lib, offer a better way to kill all child process.
|
||||
|
||||
Tested on _windows, linux, darwin._
|
||||
|
||||
This lib has been used in [fswatch](https://github.com/codeskyblue/fswatch).
|
||||
|
||||
## Usage
|
||||
|
||||
go get -v github.com/codeskyblue/kproc
|
||||
|
||||
example:
|
||||
|
||||
func main() {
|
||||
p := kproc.ProcString("python flask_main.py")
|
||||
p.Start()
|
||||
time.Sleep(3 * time.Second)
|
||||
err := p.Terminate(syscall.SIGKILL)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package kproc
|
||||
|
||||
import "os/exec"
|
||||
|
||||
type Process struct {
|
||||
*exec.Cmd
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
// +build !windows
|
||||
|
||||
package kproc
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setupCmd(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
cmd.SysProcAttr.Setsid = true
|
||||
}
|
||||
|
||||
func ProcCommand(cmd *exec.Cmd) *Process {
|
||||
setupCmd(cmd)
|
||||
return &Process{
|
||||
Cmd: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
func ProcString(command string) *Process {
|
||||
cmd := exec.Command("/bin/bash", "-c", command)
|
||||
setupCmd(cmd)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return &Process{
|
||||
Cmd: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) Terminate(sig os.Signal) (err error) {
|
||||
if p.Process == nil {
|
||||
return
|
||||
}
|
||||
// find pgid, ref: http://unix.stackexchange.com/questions/14815/process-descendants
|
||||
group, err := os.FindProcess(-1 * p.Process.Pid)
|
||||
log.Println(group)
|
||||
if err == nil {
|
||||
err = group.Signal(sig)
|
||||
}
|
||||
return err
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
package kproc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func ProcCommand(cmd *exec.Cmd) *Process {
|
||||
return &Process{
|
||||
Cmd: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
func ProcString(command string) *Process {
|
||||
cmd := exec.Command("cmd", "/c", command)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return &Process{
|
||||
Cmd: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) Terminate(sig os.Signal) (err error) {
|
||||
if p.Process == nil {
|
||||
return nil
|
||||
}
|
||||
pid := p.Process.Pid
|
||||
c := exec.Command("taskkill", "/t", "/f", "/pid", strconv.Itoa(pid))
|
||||
c.Stdout = os.Stdout
|
||||
c.Stderr = os.Stderr
|
||||
return c.Run()
|
||||
}
|
@ -0,0 +1 @@
|
||||
web: python flask_main.py
|
@ -0,0 +1,11 @@
|
||||
import flask
|
||||
|
||||
|
||||
app = flask.Flask(__name__)
|
||||
|
||||
@app.route('/')
|
||||
def homepage():
|
||||
return 'Home'
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(debug=True, host='0.0.0.0')
|
@ -0,0 +1,22 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"kproc"
|
||||
"log"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
p := kproc.ProcString("python flask_main.py")
|
||||
p.Start()
|
||||
time.Sleep(10 * time.Second)
|
||||
err := p.Terminate(syscall.SIGKILL)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
out, _ := exec.Command("lsof", "-i:5000").CombinedOutput()
|
||||
fmt.Println(string(out))
|
||||
}
|
Binary file not shown.
@ -0,0 +1,24 @@
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
|
||||
src
|
@ -0,0 +1,8 @@
|
||||
language: go
|
||||
go:
|
||||
- 1.2
|
||||
- tip
|
||||
notifications:
|
||||
email:
|
||||
- ionathan@gmail.com
|
||||
- marcosnils@gmail.com
|
@ -0,0 +1,20 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013 Jonathan Leibiusky and Marcos Lilljedahl
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,3 @@
|
||||
test:
|
||||
go get -v -d -t ./...
|
||||
go test -v
|
@ -0,0 +1,444 @@
|
||||
[![Build Status](https://img.shields.io/travis/franela/goreq/master.svg)](https://travis-ci.org/franela/goreq)
|
||||
[![GoDoc](https://godoc.org/github.com/franela/goreq?status.svg)](https://godoc.org/github.com/franela/goreq)
|
||||
|
||||
GoReq
|
||||
=======
|
||||
|
||||
Simple and sane HTTP request library for Go language.
|
||||
|
||||
|
||||
|
||||
**Table of Contents**
|
||||
|
||||
- [Why GoReq?](#user-content-why-goreq)
|
||||
- [How do I install it?](#user-content-how-do-i-install-it)
|
||||
- [What can I do with it?](#user-content-what-can-i-do-with-it)
|
||||
- [Making requests with different methods](#user-content-making-requests-with-different-methods)
|
||||
- [GET](#user-content-get)
|
||||
- [Tags](#user-content-tags)
|
||||
- [POST](#user-content-post)
|
||||
- [Sending payloads in the Body](#user-content-sending-payloads-in-the-body)
|
||||
- [Specifiying request headers](#user-content-specifiying-request-headers)
|
||||
- [Sending Cookies](#cookie-support)
|
||||
- [Setting timeouts](#user-content-setting-timeouts)
|
||||
- [Using the Response and Error](#user-content-using-the-response-and-error)
|
||||
- [Receiving JSON](#user-content-receiving-json)
|
||||
- [Sending/Receiving Compressed Payloads](#user-content-sendingreceiving-compressed-payloads)
|
||||
- [Using gzip compression:](#user-content-using-gzip-compression)
|
||||
- [Using deflate compression:](#user-content-using-deflate-compression)
|
||||
- [Using compressed responses:](#user-content-using-compressed-responses)
|
||||
- [Proxy](#proxy)
|
||||
- [Debugging requests](#debug)
|
||||
- [Getting raw Request & Response](#getting-raw-request--response)
|
||||
- [TODO:](#user-content-todo)
|
||||
|
||||
|
||||
|
||||
Why GoReq?
|
||||
==========
|
||||
|
||||
Go has very nice native libraries that allows you to do lots of cool things. But sometimes those libraries are too low level, which means that to do a simple thing, like an HTTP Request, it takes some time. And if you want to do something as simple as adding a timeout to a request, you will end up writing several lines of code.
|
||||
|
||||
This is why we think GoReq is useful. Because you can do all your HTTP requests in a very simple and comprehensive way, while enabling you to do more advanced stuff by giving you access to the native API.
|
||||
|
||||
How do I install it?
|
||||
====================
|
||||
|
||||
```bash
|
||||
go get github.com/franela/goreq
|
||||
```
|
||||
|
||||
What can I do with it?
|
||||
======================
|
||||
|
||||
## Making requests with different methods
|
||||
|
||||
#### GET
|
||||
```go
|
||||
res, err := goreq.Request{ Uri: "http://www.google.com" }.Do()
|
||||
```
|
||||
|
||||
GoReq default method is GET.
|
||||
|
||||
You can also set value to GET method easily
|
||||
|
||||
```go
|
||||
type Item struct {
|
||||
Limit int
|
||||
Skip int
|
||||
Fields string
|
||||
}
|
||||
|
||||
item := Item {
|
||||
Limit: 3,
|
||||
Skip: 5,
|
||||
Fields: "Value",
|
||||
}
|
||||
|
||||
res, err := goreq.Request{
|
||||
Uri: "http://localhost:3000/",
|
||||
QueryString: item,
|
||||
}.Do()
|
||||
```
|
||||
The sample above will send `http://localhost:3000/?limit=3&skip=5&fields=Value`
|
||||
|
||||
Alternatively the `url` tag can be used in struct fields to customize encoding properties
|
||||
|
||||
```go
|
||||
type Item struct {
|
||||
TheLimit int `url:"the_limit"`
|
||||
TheSkip string `url:"the_skip,omitempty"`
|
||||
TheFields string `url:"-"`
|
||||
}
|
||||
|
||||
item := Item {
|
||||
TheLimit: 3,
|
||||
TheSkip: "",
|
||||
TheFields: "Value",
|
||||
}
|
||||
|
||||
res, err := goreq.Request{
|
||||
Uri: "http://localhost:3000/",
|
||||
QueryString: item,
|
||||
}.Do()
|
||||
```
|
||||
The sample above will send `http://localhost:3000/?the_limit=3`
|
||||
|
||||
|
||||
QueryString also support url.Values
|
||||
|
||||
```go
|
||||
item := url.Values{}
|
||||
item.Set("Limit", 3)
|
||||
item.Add("Field", "somefield")
|
||||
item.Add("Field", "someotherfield")
|
||||
|
||||
res, err := goreq.Request{
|
||||
Uri: "http://localhost:3000/",
|
||||
QueryString: item,
|
||||
}.Do()
|
||||
```
|
||||
|
||||
The sample above will send `http://localhost:3000/?limit=3&field=somefield&field=someotherfield`
|
||||
|
||||
### Tags
|
||||
|
||||
Struct field `url` tag is mainly used as the request parameter name.
|
||||
Tags can be comma separated multiple values, 1st value is for naming and rest has special meanings.
|
||||
|
||||
- special tag for 1st value
|
||||
- `-`: value is ignored if set this
|
||||
|
||||
- special tag for rest 2nd value
|
||||
- `omitempty`: zero-value is ignored if set this
|
||||
- `squash`: the fields of embedded struct is used for parameter
|
||||
|
||||
#### Tag Examples
|
||||
|
||||
```go
|
||||
type Place struct {
|
||||
Country string `url:"country"`
|
||||
City string `url:"city"`
|
||||
ZipCode string `url:"zipcode,omitempty"`
|
||||
}
|
||||
|
||||
type Person struct {
|
||||
Place `url:",squash"`
|
||||
|
||||
FirstName string `url:"first_name"`
|
||||
LastName string `url:"last_name"`
|
||||
Age string `url:"age,omitempty"`
|
||||
Password string `url:"-"`
|
||||
}
|
||||
|
||||
johnbull := Person{
|
||||
Place: Place{ // squash the embedded struct value
|
||||
Country: "UK",
|
||||
City: "London",
|
||||
ZipCode: "SW1",
|
||||
},
|
||||
FirstName: "John",
|
||||
LastName: "Doe",
|
||||
Age: "35",
|
||||
Password: "my-secret", // ignored for parameter
|
||||
}
|
||||
|
||||
goreq.Request{
|
||||
Uri: "http://localhost/",
|
||||
QueryString: johnbull,
|
||||
}.Do()
|
||||
// => `http://localhost/?first_name=John&last_name=Doe&age=35&country=UK&city=London&zip_code=SW1`
|
||||
|
||||
|
||||
// age and zipcode will be ignored because of `omitempty`
|
||||
// but firstname isn't.
|
||||
samurai := Person{
|
||||
Place: Place{ // squash the embedded struct value
|
||||
Country: "Japan",
|
||||
City: "Tokyo",
|
||||
},
|
||||
LastName: "Yagyu",
|
||||
}
|
||||
|
||||
goreq.Request{
|
||||
Uri: "http://localhost/",
|
||||
QueryString: samurai,
|
||||
}.Do()
|
||||
// => `http://localhost/?first_name=&last_name=yagyu&country=Japan&city=Tokyo`
|
||||
```
|
||||
|
||||
|
||||
#### POST
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{ Method: "POST", Uri: "http://www.google.com" }.Do()
|
||||
```
|
||||
|
||||
## Sending payloads in the Body
|
||||
|
||||
You can send ```string```, ```Reader``` or ```interface{}``` in the body. The first two will be sent as text. The last one will be marshalled to JSON, if possible.
|
||||
|
||||
```go
|
||||
type Item struct {
|
||||
Id int
|
||||
Name string
|
||||
}
|
||||
|
||||
item := Item{ Id: 1111, Name: "foobar" }
|
||||
|
||||
res, err := goreq.Request{
|
||||
Method: "POST",
|
||||
Uri: "http://www.google.com",
|
||||
Body: item,
|
||||
}.Do()
|
||||
```
|
||||
|
||||
## Specifiying request headers
|
||||
|
||||
We think that most of the times the request headers that you use are: ```Host```, ```Content-Type```, ```Accept``` and ```User-Agent```. This is why we decided to make it very easy to set these headers.
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Uri: "http://www.google.com",
|
||||
Host: "foobar.com",
|
||||
Accept: "application/json",
|
||||
ContentType: "application/json",
|
||||
UserAgent: "goreq",
|
||||
}.Do()
|
||||
```
|
||||
|
||||
But sometimes you need to set other headers. You can still do it.
|
||||
|
||||
```go
|
||||
req := goreq.Request{ Uri: "http://www.google.com" }
|
||||
|
||||
req.AddHeader("X-Custom", "somevalue")
|
||||
|
||||
req.Do()
|
||||
```
|
||||
|
||||
Alternatively you can use the `WithHeader` function to keep the syntax short
|
||||
|
||||
```go
|
||||
res, err = goreq.Request{ Uri: "http://www.google.com" }.WithHeader("X-Custom", "somevalue").Do()
|
||||
```
|
||||
|
||||
## Cookie support
|
||||
|
||||
Cookies can be either set at the request level by sending a [CookieJar](http://golang.org/pkg/net/http/cookiejar/) in the `CookieJar` request field
|
||||
or you can use goreq's one-liner WithCookie method as shown below
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Uri: "http://www.google.com",
|
||||
}.
|
||||
WithCookie(&http.Cookie{Name: "c1", Value: "v1"}).
|
||||
Do()
|
||||
```
|
||||
|
||||
## Setting timeouts
|
||||
|
||||
GoReq supports 2 kind of timeouts. A general connection timeout and a request specific one. By default the connection timeout is of 1 second. There is no default for request timeout, which means it will wait forever.
|
||||
|
||||
You can change the connection timeout doing:
|
||||
|
||||
```go
|
||||
goreq.SetConnectTimeout(100 * time.Millisecond)
|
||||
```
|
||||
|
||||
And specify the request timeout doing:
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Uri: "http://www.google.com",
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}.Do()
|
||||
```
|
||||
|
||||
## Using the Response and Error
|
||||
|
||||
GoReq will always return 2 values: a ```Response``` and an ```Error```.
|
||||
If ```Error``` is not ```nil``` it means that an error happened while doing the request and you shouldn't use the ```Response``` in any way.
|
||||
You can check what happened by getting the error message:
|
||||
|
||||
```go
|
||||
fmt.Println(err.Error())
|
||||
```
|
||||
And to make it easy to know if it was a timeout error, you can ask the error or return it:
|
||||
|
||||
```go
|
||||
if serr, ok := err.(*goreq.Error); ok {
|
||||
if serr.Timeout() {
|
||||
...
|
||||
}
|
||||
}
|
||||
return err
|
||||
```
|
||||
|
||||
If you don't get an error, you can safely use the ```Response```.
|
||||
|
||||
```go
|
||||
res.Uri // return final URL location of the response (fulfilled after redirect was made)
|
||||
res.StatusCode // return the status code of the response
|
||||
res.Body // gives you access to the body
|
||||
res.Body.ToString() // will return the body as a string
|
||||
res.Header.Get("Content-Type") // gives you access to all the response headers
|
||||
```
|
||||
Remember that you should **always** close `res.Body` if it's not `nil`
|
||||
|
||||
## Receiving JSON
|
||||
|
||||
GoReq will help you to receive and unmarshal JSON.
|
||||
|
||||
```go
|
||||
type Item struct {
|
||||
Id int
|
||||
Name string
|
||||
}
|
||||
|
||||
var item Item
|
||||
|
||||
res.Body.FromJsonTo(&item)
|
||||
```
|
||||
|
||||
## Sending/Receiving Compressed Payloads
|
||||
GoReq supports gzip, deflate and zlib compression of requests' body and transparent decompression of responses provided they have a correct `Content-Encoding` header.
|
||||
|
||||
#####Using gzip compression:
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Method: "POST",
|
||||
Uri: "http://www.google.com",
|
||||
Body: item,
|
||||
Compression: goreq.Gzip(),
|
||||
}.Do()
|
||||
```
|
||||
#####Using deflate/zlib compression:
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Method: "POST",
|
||||
Uri: "http://www.google.com",
|
||||
Body: item,
|
||||
Compression: goreq.Deflate(),
|
||||
}.Do()
|
||||
```
|
||||
#####Using compressed responses:
|
||||
If servers replies a correct and matching `Content-Encoding` header (gzip requires `Content-Encoding: gzip` and deflate `Content-Encoding: deflate`) goreq transparently decompresses the response so the previous example should always work:
|
||||
```go
|
||||
type Item struct {
|
||||
Id int
|
||||
Name string
|
||||
}
|
||||
res, err := goreq.Request{
|
||||
Method: "POST",
|
||||
Uri: "http://www.google.com",
|
||||
Body: item,
|
||||
Compression: goreq.Gzip(),
|
||||
}.Do()
|
||||
var item Item
|
||||
res.Body.FromJsonTo(&item)
|
||||
```
|
||||
If no `Content-Encoding` header is replied by the server GoReq will return the crude response.
|
||||
|
||||
## Proxy
|
||||
If you need to use a proxy for your requests GoReq supports the standard `http_proxy` env variable as well as manually setting the proxy for each request
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Method: "GET",
|
||||
Proxy: "http://myproxy:myproxyport",
|
||||
Uri: "http://www.google.com",
|
||||
}.Do()
|
||||
```
|
||||
|
||||
### Proxy basic auth is also supported
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Method: "GET",
|
||||
Proxy: "http://user:pass@myproxy:myproxyport",
|
||||
Uri: "http://www.google.com",
|
||||
}.Do()
|
||||
```
|
||||
|
||||
## Debug
|
||||
If you need to debug your http requests, it can print the http request detail.
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Method: "GET",
|
||||
Uri: "http://www.google.com",
|
||||
Compression: goreq.Gzip(),
|
||||
ShowDebug: true,
|
||||
}.Do()
|
||||
fmt.Println(res, err)
|
||||
```
|
||||
|
||||
and it will print the log:
|
||||
```
|
||||
GET / HTTP/1.1
|
||||
Host: www.google.com
|
||||
Accept:
|
||||
Accept-Encoding: gzip
|
||||
Content-Encoding: gzip
|
||||
Content-Type:
|
||||
```
|
||||
|
||||
|
||||
### Getting raw Request & Response
|
||||
|
||||
To get the Request:
|
||||
|
||||
```go
|
||||
req := goreq.Request{
|
||||
Host: "foobar.com",
|
||||
}
|
||||
|
||||
//req.Request will return a new instance of an http.Request so you can safely use it for something else
|
||||
request, _ := req.NewRequest()
|
||||
|
||||
```
|
||||
|
||||
|
||||
To get the Response:
|
||||
|
||||
```go
|
||||
res, err := goreq.Request{
|
||||
Method: "GET",
|
||||
Uri: "http://www.google.com",
|
||||
Compression: goreq.Gzip(),
|
||||
ShowDebug: true,
|
||||
}.Do()
|
||||
|
||||
// res.Response will contain the original http.Response structure
|
||||
fmt.Println(res.Response, err)
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
TODO:
|
||||
-----
|
||||
|
||||
We do have a couple of [issues](https://github.com/franela/goreq/issues) pending we'll be addressing soon. But feel free to
|
||||
contribute and send us PRs (with tests please :smile:).
|
@ -0,0 +1,494 @@
|
||||
package goreq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"compress/zlib"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Request struct {
|
||||
headers []headerTuple
|
||||
cookies []*http.Cookie
|
||||
Method string
|
||||
Uri string
|
||||
Body interface{}
|
||||
QueryString interface{}
|
||||
Timeout time.Duration
|
||||
ContentType string
|
||||
Accept string
|
||||
Host string
|
||||
UserAgent string
|
||||
Insecure bool
|
||||
MaxRedirects int
|
||||
RedirectHeaders bool
|
||||
Proxy string
|
||||
Compression *compression
|
||||
BasicAuthUsername string
|
||||
BasicAuthPassword string
|
||||
CookieJar http.CookieJar
|
||||
ShowDebug bool
|
||||
OnBeforeRequest func(goreq *Request, httpreq *http.Request)
|
||||
}
|
||||
|
||||
type compression struct {
|
||||
writer func(buffer io.Writer) (io.WriteCloser, error)
|
||||
reader func(buffer io.Reader) (io.ReadCloser, error)
|
||||
ContentEncoding string
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
*http.Response
|
||||
Uri string
|
||||
Body *Body
|
||||
req *http.Request
|
||||
}
|
||||
|
||||
func (r Response) CancelRequest() {
|
||||
cancelRequest(r.req)
|
||||
|
||||
}
|
||||
|
||||
func cancelRequest(r *http.Request) {
|
||||
if transport, ok := DefaultTransport.(transportRequestCanceler); ok {
|
||||
transport.CancelRequest(r)
|
||||
}
|
||||
}
|
||||
|
||||
type headerTuple struct {
|
||||
name string
|
||||
value string
|
||||
}
|
||||
|
||||
type Body struct {
|
||||
reader io.ReadCloser
|
||||
compressedReader io.ReadCloser
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
timeout bool
|
||||
Err error
|
||||
}
|
||||
|
||||
type transportRequestCanceler interface {
|
||||
CancelRequest(*http.Request)
|
||||
}
|
||||
|
||||
func (e *Error) Timeout() bool {
|
||||
return e.timeout
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (b *Body) Read(p []byte) (int, error) {
|
||||
if b.compressedReader != nil {
|
||||
return b.compressedReader.Read(p)
|
||||
}
|
||||
return b.reader.Read(p)
|
||||
}
|
||||
|
||||
func (b *Body) Close() error {
|
||||
err := b.reader.Close()
|
||||
if b.compressedReader != nil {
|
||||
return b.compressedReader.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Body) FromJsonTo(o interface{}) error {
|
||||
return json.NewDecoder(b).Decode(o)
|
||||
}
|
||||
|
||||
func (b *Body) ToString() (string, error) {
|
||||
body, err := ioutil.ReadAll(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
func Gzip() *compression {
|
||||
reader := func(buffer io.Reader) (io.ReadCloser, error) {
|
||||
return gzip.NewReader(buffer)
|
||||
}
|
||||
writer := func(buffer io.Writer) (io.WriteCloser, error) {
|
||||
return gzip.NewWriter(buffer), nil
|
||||
}
|
||||
return &compression{writer: writer, reader: reader, ContentEncoding: "gzip"}
|
||||
}
|
||||
|
||||
func Deflate() *compression {
|
||||
reader := func(buffer io.Reader) (io.ReadCloser, error) {
|
||||
return zlib.NewReader(buffer)
|
||||
}
|
||||
writer := func(buffer io.Writer) (io.WriteCloser, error) {
|
||||
return zlib.NewWriter(buffer), nil
|
||||
}
|
||||
return &compression{writer: writer, reader: reader, ContentEncoding: "deflate"}
|
||||
}
|
||||
|
||||
func Zlib() *compression {
|
||||
return Deflate()
|
||||
}
|
||||
|
||||
func paramParse(query interface{}) (string, error) {
|
||||
switch query.(type) {
|
||||
case url.Values:
|
||||
return query.(url.Values).Encode(), nil
|
||||
case *url.Values:
|
||||
return query.(*url.Values).Encode(), nil
|
||||
default:
|
||||
var v = &url.Values{}
|
||||
err := paramParseStruct(v, query)
|
||||
return v.Encode(), err
|
||||
}
|
||||
}
|
||||
|
||||
func paramParseStruct(v *url.Values, query interface{}) error {
|
||||
var (
|
||||
s = reflect.ValueOf(query)
|
||||
t = reflect.TypeOf(query)
|
||||
)
|
||||
for t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface {
|
||||
s = s.Elem()
|
||||
t = s.Type()
|
||||
}
|
||||
if t.Kind() != reflect.Struct {
|
||||
return errors.New("Can not parse QueryString.")
|
||||
}
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
var name string
|
||||
|
||||
field := s.Field(i)
|
||||
typeField := t.Field(i)
|
||||
|
||||
if !field.CanInterface() {
|
||||
continue
|
||||
}
|
||||
|
||||
urlTag := typeField.Tag.Get("url")
|
||||
if urlTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
name, opts := parseTag(urlTag)
|
||||
|
||||
var omitEmpty, squash bool
|
||||
omitEmpty = opts.Contains("omitempty")
|
||||
squash = opts.Contains("squash")
|
||||
|
||||
if squash {
|
||||
err := paramParseStruct(v, field.Interface())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if urlTag == "" {
|
||||
name = strings.ToLower(typeField.Name)
|
||||
}
|
||||
|
||||
if val := fmt.Sprintf("%v", field.Interface()); !(omitEmpty && len(val) == 0) {
|
||||
v.Add(name, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareRequestBody(b interface{}) (io.Reader, error) {
|
||||
switch b.(type) {
|
||||
case string:
|
||||
// treat is as text
|
||||
return strings.NewReader(b.(string)), nil
|
||||
case io.Reader:
|
||||
// treat is as text
|
||||
return b.(io.Reader), nil
|
||||
case []byte:
|
||||
//treat as byte array
|
||||
return bytes.NewReader(b.([]byte)), nil
|
||||
case nil:
|
||||
return nil, nil
|
||||
default:
|
||||
// try to jsonify it
|
||||
j, err := json.Marshal(b)
|
||||
if err == nil {
|
||||
return bytes.NewReader(j), nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var DefaultDialer = &net.Dialer{Timeout: 1000 * time.Millisecond}
|
||||
var DefaultTransport http.RoundTripper = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyFromEnvironment}
|
||||
var DefaultClient = &http.Client{Transport: DefaultTransport}
|
||||
|
||||
var proxyTransport http.RoundTripper
|
||||
var proxyClient *http.Client
|
||||
|
||||
func SetConnectTimeout(duration time.Duration) {
|
||||
DefaultDialer.Timeout = duration
|
||||
}
|
||||
|
||||
func (r *Request) AddHeader(name string, value string) {
|
||||
if r.headers == nil {
|
||||
r.headers = []headerTuple{}
|
||||
}
|
||||
r.headers = append(r.headers, headerTuple{name: name, value: value})
|
||||
}
|
||||
|
||||
func (r Request) WithHeader(name string, value string) Request {
|
||||
r.AddHeader(name, value)
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Request) AddCookie(c *http.Cookie) {
|
||||
r.cookies = append(r.cookies, c)
|
||||
}
|
||||
|
||||
func (r Request) WithCookie(c *http.Cookie) Request {
|
||||
r.AddCookie(c)
|
||||
return r
|
||||
|
||||
}
|
||||
|
||||
func (r Request) Do() (*Response, error) {
|
||||
var client = DefaultClient
|
||||
var transport = DefaultTransport
|
||||
var resUri string
|
||||
var redirectFailed bool
|
||||
|
||||
r.Method = valueOrDefault(r.Method, "GET")
|
||||
|
||||
// use a client with a cookie jar if necessary. We create a new client not
|
||||
// to modify the default one.
|
||||
if r.CookieJar != nil {
|
||||
client = &http.Client{
|
||||
Transport: transport,
|
||||
Jar: r.CookieJar,
|
||||
}
|
||||
}
|
||||
|
||||
if r.Proxy != "" {
|
||||
proxyUrl, err := url.Parse(r.Proxy)
|
||||
if err != nil {
|
||||
// proxy address is in a wrong format
|
||||
return nil, &Error{Err: err}
|
||||
}
|
||||
|
||||
//If jar is specified new client needs to be built
|
||||
if proxyTransport == nil || client.Jar != nil {
|
||||
proxyTransport = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyURL(proxyUrl)}
|
||||
proxyClient = &http.Client{Transport: proxyTransport, Jar: client.Jar}
|
||||
} else if proxyTransport, ok := proxyTransport.(*http.Transport); ok {
|
||||
proxyTransport.Proxy = http.ProxyURL(proxyUrl)
|
||||
}
|
||||
transport = proxyTransport
|
||||
client = proxyClient
|
||||
}
|
||||
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
|
||||
if len(via) > r.MaxRedirects {
|
||||
redirectFailed = true
|
||||
return errors.New("Error redirecting. MaxRedirects reached")
|
||||
}
|
||||
|
||||
resUri = req.URL.String()
|
||||
|
||||
//By default Golang will not redirect request headers
|
||||
// https://code.google.com/p/go/issues/detail?id=4800&q=request%20header
|
||||
if r.RedirectHeaders {
|
||||
for key, val := range via[0].Header {
|
||||
req.Header[key] = val
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if transport, ok := transport.(*http.Transport); ok {
|
||||
if r.Insecure {
|
||||
if transport.TLSClientConfig != nil {
|
||||
transport.TLSClientConfig.InsecureSkipVerify = true
|
||||
} else {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
} else if transport.TLSClientConfig != nil {
|
||||
// the default TLS client (when transport.TLSClientConfig==nil) is
|
||||
// already set to verify, so do nothing in that case
|
||||
transport.TLSClientConfig.InsecureSkipVerify = false
|
||||
}
|
||||
}
|
||||
|
||||
req, err := r.NewRequest()
|
||||
|
||||
if err != nil {
|
||||
// we couldn't parse the URL.
|
||||
return nil, &Error{Err: err}
|
||||
}
|
||||
|
||||
timeout := false
|
||||
var timer *time.Timer
|
||||
if r.Timeout > 0 {
|
||||
timer = time.AfterFunc(r.Timeout, func() {
|
||||
cancelRequest(req)
|
||||
timeout = true
|
||||
})
|
||||
}
|
||||
|
||||
if r.ShowDebug {
|
||||
dump, err := httputil.DumpRequest(req, true)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
log.Println(string(dump))
|
||||
}
|
||||
|
||||
if r.OnBeforeRequest != nil {
|
||||
r.OnBeforeRequest(&r, req)
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !timeout {
|
||||
switch err := err.(type) {
|
||||
case *net.OpError:
|
||||
timeout = err.Timeout()
|
||||
case *url.Error:
|
||||
if op, ok := err.Err.(*net.OpError); ok {
|
||||
timeout = op.Timeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var response *Response
|
||||
//If redirect fails we still want to return response data
|
||||
if redirectFailed {
|
||||
if res != nil {
|
||||
response = &Response{res, resUri, &Body{reader: res.Body}, req}
|
||||
} else {
|
||||
response = &Response{res, resUri, nil, req}
|
||||
}
|
||||
}
|
||||
|
||||
//If redirect fails and we haven't set a redirect count we shouldn't return an error
|
||||
if redirectFailed && r.MaxRedirects == 0 {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
return response, &Error{timeout: timeout, Err: err}
|
||||
}
|
||||
|
||||
if r.Compression != nil && strings.Contains(res.Header.Get("Content-Encoding"), r.Compression.ContentEncoding) {
|
||||
compressedReader, err := r.Compression.reader(res.Body)
|
||||
if err != nil {
|
||||
return nil, &Error{Err: err}
|
||||
}
|
||||
return &Response{res, resUri, &Body{reader: res.Body, compressedReader: compressedReader}, req}, nil
|
||||
}
|
||||
|
||||
return &Response{res, resUri, &Body{reader: res.Body}, req}, nil
|
||||
}
|
||||
|
||||
func (r Request) addHeaders(headersMap http.Header) {
|
||||
if len(r.UserAgent) > 0 {
|
||||
headersMap.Add("User-Agent", r.UserAgent)
|
||||
}
|
||||
if r.Accept != "" {
|
||||
headersMap.Add("Accept", r.Accept)
|
||||
}
|
||||
if r.ContentType != "" {
|
||||
headersMap.Add("Content-Type", r.ContentType)
|
||||
}
|
||||
}
|
||||
|
||||
func (r Request) NewRequest() (*http.Request, error) {
|
||||
|
||||
b, e := prepareRequestBody(r.Body)
|
||||
if e != nil {
|
||||
// there was a problem marshaling the body
|
||||
return nil, &Error{Err: e}
|
||||
}
|
||||
|
||||
if r.QueryString != nil {
|
||||
param, e := paramParse(r.QueryString)
|
||||
if e != nil {
|
||||
return nil, &Error{Err: e}
|
||||
}
|
||||
r.Uri = r.Uri + "?" + param
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
if b != nil && r.Compression != nil {
|
||||
buffer := bytes.NewBuffer([]byte{})
|
||||
readBuffer := bufio.NewReader(b)
|
||||
writer, err := r.Compression.writer(buffer)
|
||||
if err != nil {
|
||||
return nil, &Error{Err: err}
|
||||
}
|
||||
_, e = readBuffer.WriteTo(writer)
|
||||
writer.Close()
|
||||
if e != nil {
|
||||
return nil, &Error{Err: e}
|
||||
}
|
||||
bodyReader = buffer
|
||||
} else {
|
||||
bodyReader = b
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(r.Method, r.Uri, bodyReader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// add headers to the request
|
||||
req.Host = r.Host
|
||||
|
||||
r.addHeaders(req.Header)
|
||||
if r.Compression != nil {
|
||||
req.Header.Add("Content-Encoding", r.Compression.ContentEncoding)
|
||||
req.Header.Add("Accept-Encoding", r.Compression.ContentEncoding)
|
||||
}
|
||||
if r.headers != nil {
|
||||
for _, header := range r.headers {
|
||||
req.Header.Add(header.name, header.value)
|
||||
}
|
||||
}
|
||||
|
||||
//use basic auth if required
|
||||
if r.BasicAuthUsername != "" {
|
||||
req.SetBasicAuth(r.BasicAuthUsername, r.BasicAuthPassword)
|
||||
}
|
||||
|
||||
for _, c := range r.cookies {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Return value if nonempty, def otherwise.
|
||||
func valueOrDefault(value, def string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
return def
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,64 @@
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package goreq
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// tagOptions is the string following a comma in a struct field's "json"
|
||||
// tag, or the empty string. It does not include the leading comma.
|
||||
type tagOptions string
|
||||
|
||||
// parseTag splits a struct field's json tag into its name and
|
||||
// comma-separated options.
|
||||
func parseTag(tag string) (string, tagOptions) {
|
||||
if idx := strings.Index(tag, ","); idx != -1 {
|
||||
return tag[:idx], tagOptions(tag[idx+1:])
|
||||
}
|
||||
return tag, tagOptions("")
|
||||
}
|
||||
|
||||
// Contains reports whether a comma-separated list of options
|
||||
// contains a particular substr flag. substr must be surrounded by a
|
||||
// string boundary or commas.
|
||||
func (o tagOptions) Contains(optionName string) bool {
|
||||
if len(o) == 0 {
|
||||
return false
|
||||
}
|
||||
s := string(o)
|
||||
for s != "" {
|
||||
var next string
|
||||
i := strings.Index(s, ",")
|
||||
if i >= 0 {
|
||||
s, next = s[:i], s[i+1:]
|
||||
}
|
||||
if s == optionName {
|
||||
return true
|
||||
}
|
||||
s = next
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isValidTag(s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, c := range s {
|
||||
switch {
|
||||
case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c):
|
||||
// Backslash and quote chars are reserved, but
|
||||
// otherwise any punctuation chars are allowed
|
||||
// in a tag name.
|
||||
default:
|
||||
if !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
# Go support for Protocol Buffers - Google's data interchange format
|
||||
#
|
||||
# Copyright 2010 The Go Authors. All rights reserved.
|
||||
# https://github.com/golang/protobuf
|
||||
#
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are
|
||||
# met:
|
||||
#
|
||||
# * Redistributions of source code must retain the above copyright
|
||||
# notice, this list of conditions and the following disclaimer.
|
||||
# * Redistributions in binary form must reproduce the above
|
||||
# copyright notice, this list of conditions and the following disclaimer
|
||||
# in the documentation and/or other materials provided with the
|
||||
# distribution.
|
||||
# * Neither the name of Google Inc. nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
#
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
install:
|
||||
go install
|
||||
|
||||
test: install generate-test-pbs
|
||||
go test
|
||||
|
||||
|
||||
generate-test-pbs:
|
||||
make install
|
||||
make -C testdata
|
||||
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata:. proto3_proto/proto3.proto
|
||||
make
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,223 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// Protocol buffer deep copy and merge.
|
||||
// TODO: MessageSet and RawMessage.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Clone returns a deep copy of a protocol buffer.
|
||||
func Clone(pb Message) Message {
|
||||
in := reflect.ValueOf(pb)
|
||||
if in.IsNil() {
|
||||
return pb
|
||||
}
|
||||
|
||||
out := reflect.New(in.Type().Elem())
|
||||
// out is empty so a merge is a deep copy.
|
||||
mergeStruct(out.Elem(), in.Elem())
|
||||
return out.Interface().(Message)
|
||||
}
|
||||
|
||||
// Merge merges src into dst.
|
||||
// Required and optional fields that are set in src will be set to that value in dst.
|
||||
// Elements of repeated fields will be appended.
|
||||
// Merge panics if src and dst are not the same type, or if dst is nil.
|
||||
func Merge(dst, src Message) {
|
||||
in := reflect.ValueOf(src)
|
||||
out := reflect.ValueOf(dst)
|
||||
if out.IsNil() {
|
||||
panic("proto: nil destination")
|
||||
}
|
||||
if in.Type() != out.Type() {
|
||||
// Explicit test prior to mergeStruct so that mistyped nils will fail
|
||||
panic("proto: type mismatch")
|
||||
}
|
||||
if in.IsNil() {
|
||||
// Merging nil into non-nil is a quiet no-op
|
||||
return
|
||||
}
|
||||
mergeStruct(out.Elem(), in.Elem())
|
||||
}
|
||||
|
||||
func mergeStruct(out, in reflect.Value) {
|
||||
sprop := GetProperties(in.Type())
|
||||
for i := 0; i < in.NumField(); i++ {
|
||||
f := in.Type().Field(i)
|
||||
if strings.HasPrefix(f.Name, "XXX_") {
|
||||
continue
|
||||
}
|
||||
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
|
||||
}
|
||||
|
||||
if emIn, ok := in.Addr().Interface().(extendableProto); ok {
|
||||
emOut := out.Addr().Interface().(extendableProto)
|
||||
mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
|
||||
}
|
||||
|
||||
uf := in.FieldByName("XXX_unrecognized")
|
||||
if !uf.IsValid() {
|
||||
return
|
||||
}
|
||||
uin := uf.Bytes()
|
||||
if len(uin) > 0 {
|
||||
out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...))
|
||||
}
|
||||
}
|
||||
|
||||
// mergeAny performs a merge between two values of the same type.
|
||||
// viaPtr indicates whether the values were indirected through a pointer (implying proto2).
|
||||
// prop is set if this is a struct field (it may be nil).
|
||||
func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) {
|
||||
if in.Type() == protoMessageType {
|
||||
if !in.IsNil() {
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.ValueOf(Clone(in.Interface().(Message))))
|
||||
} else {
|
||||
Merge(out.Interface().(Message), in.Interface().(Message))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
switch in.Kind() {
|
||||
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
|
||||
reflect.String, reflect.Uint32, reflect.Uint64:
|
||||
if !viaPtr && isProto3Zero(in) {
|
||||
return
|
||||
}
|
||||
out.Set(in)
|
||||
case reflect.Interface:
|
||||
// Probably a oneof field; copy non-nil values.
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
// Allocate destination if it is not set, or set to a different type.
|
||||
// Otherwise we will merge as normal.
|
||||
if out.IsNil() || out.Elem().Type() != in.Elem().Type() {
|
||||
out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T)
|
||||
}
|
||||
mergeAny(out.Elem(), in.Elem(), false, nil)
|
||||
case reflect.Map:
|
||||
if in.Len() == 0 {
|
||||
return
|
||||
}
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.MakeMap(in.Type()))
|
||||
}
|
||||
// For maps with value types of *T or []byte we need to deep copy each value.
|
||||
elemKind := in.Type().Elem().Kind()
|
||||
for _, key := range in.MapKeys() {
|
||||
var val reflect.Value
|
||||
switch elemKind {
|
||||
case reflect.Ptr:
|
||||
val = reflect.New(in.Type().Elem().Elem())
|
||||
mergeAny(val, in.MapIndex(key), false, nil)
|
||||
case reflect.Slice:
|
||||
val = in.MapIndex(key)
|
||||
val = reflect.ValueOf(append([]byte{}, val.Bytes()...))
|
||||
default:
|
||||
val = in.MapIndex(key)
|
||||
}
|
||||
out.SetMapIndex(key, val)
|
||||
}
|
||||
case reflect.Ptr:
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.New(in.Elem().Type()))
|
||||
}
|
||||
mergeAny(out.Elem(), in.Elem(), true, nil)
|
||||
case reflect.Slice:
|
||||
if in.IsNil() {
|
||||
return
|
||||
}
|
||||
if in.Type().Elem().Kind() == reflect.Uint8 {
|
||||
// []byte is a scalar bytes field, not a repeated field.
|
||||
|
||||
// Edge case: if this is in a proto3 message, a zero length
|
||||
// bytes field is considered the zero value, and should not
|
||||
// be merged.
|
||||
if prop != nil && prop.proto3 && in.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Make a deep copy.
|
||||
// Append to []byte{} instead of []byte(nil) so that we never end up
|
||||
// with a nil result.
|
||||
out.SetBytes(append([]byte{}, in.Bytes()...))
|
||||
return
|
||||
}
|
||||
n := in.Len()
|
||||
if out.IsNil() {
|
||||
out.Set(reflect.MakeSlice(in.Type(), 0, n))
|
||||
}
|
||||
switch in.Type().Elem().Kind() {
|
||||
case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64,
|
||||
reflect.String, reflect.Uint32, reflect.Uint64:
|
||||
out.Set(reflect.AppendSlice(out, in))
|
||||
default:
|
||||
for i := 0; i < n; i++ {
|
||||
x := reflect.Indirect(reflect.New(in.Type().Elem()))
|
||||
mergeAny(x, in.Index(i), false, nil)
|
||||
out.Set(reflect.Append(out, x))
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
mergeStruct(out, in)
|
||||
default:
|
||||
// unknown type, so not a protocol buffer
|
||||
log.Printf("proto: don't know how to copy %v", in)
|
||||
}
|
||||
}
|
||||
|
||||
func mergeExtension(out, in map[int32]Extension) {
|
||||
for extNum, eIn := range in {
|
||||
eOut := Extension{desc: eIn.desc}
|
||||
if eIn.value != nil {
|
||||
v := reflect.New(reflect.TypeOf(eIn.value)).Elem()
|
||||
mergeAny(v, reflect.ValueOf(eIn.value), false, nil)
|
||||
eOut.value = v.Interface()
|
||||
}
|
||||
if eIn.enc != nil {
|
||||
eOut.enc = make([]byte, len(eIn.enc))
|
||||
copy(eOut.enc, eIn.enc)
|
||||
}
|
||||
|
||||
out[extNum] = eOut
|
||||
}
|
||||
}
|
@ -0,0 +1,267 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
proto3pb "github.com/golang/protobuf/proto/proto3_proto"
|
||||
pb "github.com/golang/protobuf/proto/testdata"
|
||||
)
|
||||
|
||||
var cloneTestMessage = &pb.MyMessage{
|
||||
Count: proto.Int32(42),
|
||||
Name: proto.String("Dave"),
|
||||
Pet: []string{"bunny", "kitty", "horsey"},
|
||||
Inner: &pb.InnerMessage{
|
||||
Host: proto.String("niles"),
|
||||
Port: proto.Int32(9099),
|
||||
Connected: proto.Bool(true),
|
||||
},
|
||||
Others: []*pb.OtherMessage{
|
||||
{
|
||||
Value: []byte("some bytes"),
|
||||
},
|
||||
},
|
||||
Somegroup: &pb.MyMessage_SomeGroup{
|
||||
GroupField: proto.Int32(6),
|
||||
},
|
||||
RepBytes: [][]byte{[]byte("sham"), []byte("wow")},
|
||||
}
|
||||
|
||||
func init() {
|
||||
ext := &pb.Ext{
|
||||
Data: proto.String("extension"),
|
||||
}
|
||||
if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_More, ext); err != nil {
|
||||
panic("SetExtension: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClone(t *testing.T) {
|
||||
m := proto.Clone(cloneTestMessage).(*pb.MyMessage)
|
||||
if !proto.Equal(m, cloneTestMessage) {
|
||||
t.Errorf("Clone(%v) = %v", cloneTestMessage, m)
|
||||
}
|
||||
|
||||
// Verify it was a deep copy.
|
||||
*m.Inner.Port++
|
||||
if proto.Equal(m, cloneTestMessage) {
|
||||
t.Error("Mutating clone changed the original")
|
||||
}
|
||||
// Byte fields and repeated fields should be copied.
|
||||
if &m.Pet[0] == &cloneTestMessage.Pet[0] {
|
||||
t.Error("Pet: repeated field not copied")
|
||||
}
|
||||
if &m.Others[0] == &cloneTestMessage.Others[0] {
|
||||
t.Error("Others: repeated field not copied")
|
||||
}
|
||||
if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] {
|
||||
t.Error("Others[0].Value: bytes field not copied")
|
||||
}
|
||||
if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] {
|
||||
t.Error("RepBytes: repeated field not copied")
|
||||
}
|
||||
if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] {
|
||||
t.Error("RepBytes[0]: bytes field not copied")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloneNil(t *testing.T) {
|
||||
var m *pb.MyMessage
|
||||
if c := proto.Clone(m); !proto.Equal(m, c) {
|
||||
t.Errorf("Clone(%v) = %v", m, c)
|
||||
}
|
||||
}
|
||||
|
||||
var mergeTests = []struct {
|
||||
src, dst, want proto.Message
|
||||
}{
|
||||
{
|
||||
src: &pb.MyMessage{
|
||||
Count: proto.Int32(42),
|
||||
},
|
||||
dst: &pb.MyMessage{
|
||||
Name: proto.String("Dave"),
|
||||
},
|
||||
want: &pb.MyMessage{
|
||||
Count: proto.Int32(42),
|
||||
Name: proto.String("Dave"),
|
||||
},
|
||||
},
|
||||
{
|
||||
src: &pb.MyMessage{
|
||||
Inner: &pb.InnerMessage{
|
||||
Host: proto.String("hey"),
|
||||
Connected: proto.Bool(true),
|
||||
},
|
||||
Pet: []string{"horsey"},
|
||||
Others: []*pb.OtherMessage{
|
||||
{
|
||||
Value: []byte("some bytes"),
|
||||
},
|
||||
},
|
||||
},
|
||||
dst: &pb.MyMessage{
|
||||
Inner: &pb.InnerMessage{
|
||||
Host: proto.String("niles"),
|
||||
Port: proto.Int32(9099),
|
||||
},
|
||||
Pet: []string{"bunny", "kitty"},
|
||||
Others: []*pb.OtherMessage{
|
||||
{
|
||||
Key: proto.Int64(31415926535),
|
||||
},
|
||||
{
|
||||
// Explicitly test a src=nil field
|
||||
Inner: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: &pb.MyMessage{
|
||||
Inner: &pb.InnerMessage{
|
||||
Host: proto.String("hey"),
|
||||
Connected: proto.Bool(true),
|
||||
Port: proto.Int32(9099),
|
||||
},
|
||||
Pet: []string{"bunny", "kitty", "horsey"},
|
||||
Others: []*pb.OtherMessage{
|
||||
{
|
||||
Key: proto.Int64(31415926535),
|
||||
},
|
||||
{},
|
||||
{
|
||||
Value: []byte("some bytes"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
src: &pb.MyMessage{
|
||||
RepBytes: [][]byte{[]byte("wow")},
|
||||
},
|
||||
dst: &pb.MyMessage{
|
||||
Somegroup: &pb.MyMessage_SomeGroup{
|
||||
GroupField: proto.Int32(6),
|
||||
},
|
||||
RepBytes: [][]byte{[]byte("sham")},
|
||||
},
|
||||
want: &pb.MyMessage{
|
||||
Somegroup: &pb.MyMessage_SomeGroup{
|
||||
GroupField: proto.Int32(6),
|
||||
},
|
||||
RepBytes: [][]byte{[]byte("sham"), []byte("wow")},
|
||||
},
|
||||
},
|
||||
// Check that a scalar bytes field replaces rather than appends.
|
||||
{
|
||||
src: &pb.OtherMessage{Value: []byte("foo")},
|
||||
dst: &pb.OtherMessage{Value: []byte("bar")},
|
||||
want: &pb.OtherMessage{Value: []byte("foo")},
|
||||
},
|
||||
{
|
||||
src: &pb.MessageWithMap{
|
||||
NameMapping: map[int32]string{6: "Nigel"},
|
||||
MsgMapping: map[int64]*pb.FloatingPoint{
|
||||
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
|
||||
},
|
||||
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
|
||||
},
|
||||
dst: &pb.MessageWithMap{
|
||||
NameMapping: map[int32]string{
|
||||
6: "Bruce", // should be overwritten
|
||||
7: "Andrew",
|
||||
},
|
||||
},
|
||||
want: &pb.MessageWithMap{
|
||||
NameMapping: map[int32]string{
|
||||
6: "Nigel",
|
||||
7: "Andrew",
|
||||
},
|
||||
MsgMapping: map[int64]*pb.FloatingPoint{
|
||||
0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)},
|
||||
},
|
||||
ByteMapping: map[bool][]byte{true: []byte("wowsa")},
|
||||
},
|
||||
},
|
||||
// proto3 shouldn't merge zero values,
|
||||
// in the same way that proto2 shouldn't merge nils.
|
||||
{
|
||||
src: &proto3pb.Message{
|
||||
Name: "Aaron",
|
||||
Data: []byte(""), // zero value, but not nil
|
||||
},
|
||||
dst: &proto3pb.Message{
|
||||
HeightInCm: 176,
|
||||
Data: []byte("texas!"),
|
||||
},
|
||||
want: &proto3pb.Message{
|
||||
Name: "Aaron",
|
||||
HeightInCm: 176,
|
||||
Data: []byte("texas!"),
|
||||
},
|
||||
},
|
||||
// Oneof fields should merge by assignment.
|
||||
{
|
||||
src: &pb.Communique{
|
||||
Union: &pb.Communique_Number{41},
|
||||
},
|
||||
dst: &pb.Communique{
|
||||
Union: &pb.Communique_Name{"Bobby Tables"},
|
||||
},
|
||||
want: &pb.Communique{
|
||||
Union: &pb.Communique_Number{41},
|
||||
},
|
||||
},
|
||||
// Oneof nil is the same as not set.
|
||||
{
|
||||
src: &pb.Communique{},
|
||||
dst: &pb.Communique{
|
||||
Union: &pb.Communique_Name{"Bobby Tables"},
|
||||
},
|
||||
want: &pb.Communique{
|
||||
Union: &pb.Communique_Name{"Bobby Tables"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
for _, m := range mergeTests {
|
||||
got := proto.Clone(m.dst)
|
||||
proto.Merge(got, m.src)
|
||||
if !proto.Equal(got, m.want) {
|
||||
t.Errorf("Merge(%v, %v)\n got %v\nwant %v\n", m.dst, m.src, got, m.want)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,863 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
/*
|
||||
* Routines for decoding protocol buffer data to construct in-memory representations.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// errOverflow is returned when an integer is too large to be represented.
|
||||
var errOverflow = errors.New("proto: integer overflow")
|
||||
|
||||
// ErrInternalBadWireType is returned by generated code when an incorrect
|
||||
// wire type is encountered. It does not get returned to user code.
|
||||
var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof")
|
||||
|
||||
// The fundamental decoders that interpret bytes on the wire.
|
||||
// Those that take integer types all return uint64 and are
|
||||
// therefore of type valueDecoder.
|
||||
|
||||
// DecodeVarint reads a varint-encoded integer from the slice.
|
||||
// It returns the integer and the number of bytes consumed, or
|
||||
// zero if there is not enough.
|
||||
// This is the format for the
|
||||
// int32, int64, uint32, uint64, bool, and enum
|
||||
// protocol buffer types.
|
||||
func DecodeVarint(buf []byte) (x uint64, n int) {
|
||||
// x, n already 0
|
||||
for shift := uint(0); shift < 64; shift += 7 {
|
||||
if n >= len(buf) {
|
||||
return 0, 0
|
||||
}
|
||||
b := uint64(buf[n])
|
||||
n++
|
||||
x |= (b & 0x7F) << shift
|
||||
if (b & 0x80) == 0 {
|
||||
return x, n
|
||||
}
|
||||
}
|
||||
|
||||
// The number is too large to represent in a 64-bit value.
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// DecodeVarint reads a varint-encoded integer from the Buffer.
|
||||
// This is the format for the
|
||||
// int32, int64, uint32, uint64, bool, and enum
|
||||
// protocol buffer types.
|
||||
func (p *Buffer) DecodeVarint() (x uint64, err error) {
|
||||
// x, err already 0
|
||||
|
||||
i := p.index
|
||||
l := len(p.buf)
|
||||
|
||||
for shift := uint(0); shift < 64; shift += 7 {
|
||||
if i >= l {
|
||||
err = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
b := p.buf[i]
|
||||
i++
|
||||
x |= (uint64(b) & 0x7F) << shift
|
||||
if b < 0x80 {
|
||||
p.index = i
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// The number is too large to represent in a 64-bit value.
|
||||
err = errOverflow
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeFixed64 reads a 64-bit integer from the Buffer.
|
||||
// This is the format for the
|
||||
// fixed64, sfixed64, and double protocol buffer types.
|
||||
func (p *Buffer) DecodeFixed64() (x uint64, err error) {
|
||||
// x, err already 0
|
||||
i := p.index + 8
|
||||
if i < 0 || i > len(p.buf) {
|
||||
err = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
p.index = i
|
||||
|
||||
x = uint64(p.buf[i-8])
|
||||
x |= uint64(p.buf[i-7]) << 8
|
||||
x |= uint64(p.buf[i-6]) << 16
|
||||
x |= uint64(p.buf[i-5]) << 24
|
||||
x |= uint64(p.buf[i-4]) << 32
|
||||
x |= uint64(p.buf[i-3]) << 40
|
||||
x |= uint64(p.buf[i-2]) << 48
|
||||
x |= uint64(p.buf[i-1]) << 56
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeFixed32 reads a 32-bit integer from the Buffer.
|
||||
// This is the format for the
|
||||
// fixed32, sfixed32, and float protocol buffer types.
|
||||
func (p *Buffer) DecodeFixed32() (x uint64, err error) {
|
||||
// x, err already 0
|
||||
i := p.index + 4
|
||||
if i < 0 || i > len(p.buf) {
|
||||
err = io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
p.index = i
|
||||
|
||||
x = uint64(p.buf[i-4])
|
||||
x |= uint64(p.buf[i-3]) << 8
|
||||
x |= uint64(p.buf[i-2]) << 16
|
||||
x |= uint64(p.buf[i-1]) << 24
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeZigzag64 reads a zigzag-encoded 64-bit integer
|
||||
// from the Buffer.
|
||||
// This is the format used for the sint64 protocol buffer type.
|
||||
func (p *Buffer) DecodeZigzag64() (x uint64, err error) {
|
||||
x, err = p.DecodeVarint()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
x = (x >> 1) ^ uint64((int64(x&1)<<63)>>63)
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeZigzag32 reads a zigzag-encoded 32-bit integer
|
||||
// from the Buffer.
|
||||
// This is the format used for the sint32 protocol buffer type.
|
||||
func (p *Buffer) DecodeZigzag32() (x uint64, err error) {
|
||||
x, err = p.DecodeVarint()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
x = uint64((uint32(x) >> 1) ^ uint32((int32(x&1)<<31)>>31))
|
||||
return
|
||||
}
|
||||
|
||||
// These are not ValueDecoders: they produce an array of bytes or a string.
|
||||
// bytes, embedded messages
|
||||
|
||||
// DecodeRawBytes reads a count-delimited byte buffer from the Buffer.
|
||||
// This is the format used for the bytes protocol buffer
|
||||
// type and for embedded messages.
|
||||
func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) {
|
||||
n, err := p.DecodeVarint()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nb := int(n)
|
||||
if nb < 0 {
|
||||
return nil, fmt.Errorf("proto: bad byte length %d", nb)
|
||||
}
|
||||
end := p.index + nb
|
||||
if end < p.index || end > len(p.buf) {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
if !alloc {
|
||||
// todo: check if can get more uses of alloc=false
|
||||
buf = p.buf[p.index:end]
|
||||
p.index += nb
|
||||
return
|
||||
}
|
||||
|
||||
buf = make([]byte, nb)
|
||||
copy(buf, p.buf[p.index:])
|
||||
p.index += nb
|
||||
return
|
||||
}
|
||||
|
||||
// DecodeStringBytes reads an encoded string from the Buffer.
|
||||
// This is the format used for the proto2 string type.
|
||||
func (p *Buffer) DecodeStringBytes() (s string, err error) {
|
||||
buf, err := p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
|
||||
// If the protocol buffer has extensions, and the field matches, add it as an extension.
|
||||
// Otherwise, if the XXX_unrecognized field exists, append the skipped data there.
|
||||
func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error {
|
||||
oi := o.index
|
||||
|
||||
err := o.skip(t, tag, wire)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !unrecField.IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
ptr := structPointer_Bytes(base, unrecField)
|
||||
|
||||
// Add the skipped field to struct field
|
||||
obuf := o.buf
|
||||
|
||||
o.buf = *ptr
|
||||
o.EncodeVarint(uint64(tag<<3 | wire))
|
||||
*ptr = append(o.buf, obuf[oi:o.index]...)
|
||||
|
||||
o.buf = obuf
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Skip the next item in the buffer. Its wire type is decoded and presented as an argument.
|
||||
func (o *Buffer) skip(t reflect.Type, tag, wire int) error {
|
||||
|
||||
var u uint64
|
||||
var err error
|
||||
|
||||
switch wire {
|
||||
case WireVarint:
|
||||
_, err = o.DecodeVarint()
|
||||
case WireFixed64:
|
||||
_, err = o.DecodeFixed64()
|
||||
case WireBytes:
|
||||
_, err = o.DecodeRawBytes(false)
|
||||
case WireFixed32:
|
||||
_, err = o.DecodeFixed32()
|
||||
case WireStartGroup:
|
||||
for {
|
||||
u, err = o.DecodeVarint()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fwire := int(u & 0x7)
|
||||
if fwire == WireEndGroup {
|
||||
break
|
||||
}
|
||||
ftag := int(u >> 3)
|
||||
err = o.skip(t, ftag, fwire)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Unmarshaler is the interface representing objects that can
|
||||
// unmarshal themselves. The method should reset the receiver before
|
||||
// decoding starts. The argument points to data that may be
|
||||
// overwritten, so implementations should not keep references to the
|
||||
// buffer.
|
||||
type Unmarshaler interface {
|
||||
Unmarshal([]byte) error
|
||||
}
|
||||
|
||||
// Unmarshal parses the protocol buffer representation in buf and places the
|
||||
// decoded result in pb. If the struct underlying pb does not match
|
||||
// the data in buf, the results can be unpredictable.
|
||||
//
|
||||
// Unmarshal resets pb before starting to unmarshal, so any
|
||||
// existing data in pb is always removed. Use UnmarshalMerge
|
||||
// to preserve and append to existing data.
|
||||
func Unmarshal(buf []byte, pb Message) error {
|
||||
pb.Reset()
|
||||
return UnmarshalMerge(buf, pb)
|
||||
}
|
||||
|
||||
// UnmarshalMerge parses the protocol buffer representation in buf and
|
||||
// writes the decoded result to pb. If the struct underlying pb does not match
|
||||
// the data in buf, the results can be unpredictable.
|
||||
//
|
||||
// UnmarshalMerge merges into existing data in pb.
|
||||
// Most code should use Unmarshal instead.
|
||||
func UnmarshalMerge(buf []byte, pb Message) error {
|
||||
// If the object can unmarshal itself, let it.
|
||||
if u, ok := pb.(Unmarshaler); ok {
|
||||
return u.Unmarshal(buf)
|
||||
}
|
||||
return NewBuffer(buf).Unmarshal(pb)
|
||||
}
|
||||
|
||||
// DecodeMessage reads a count-delimited message from the Buffer.
|
||||
func (p *Buffer) DecodeMessage(pb Message) error {
|
||||
enc, err := p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return NewBuffer(enc).Unmarshal(pb)
|
||||
}
|
||||
|
||||
// DecodeGroup reads a tag-delimited group from the Buffer.
|
||||
func (p *Buffer) DecodeGroup(pb Message) error {
|
||||
typ, base, err := getbase(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base)
|
||||
}
|
||||
|
||||
// Unmarshal parses the protocol buffer representation in the
|
||||
// Buffer and places the decoded result in pb. If the struct
|
||||
// underlying pb does not match the data in the buffer, the results can be
|
||||
// unpredictable.
|
||||
func (p *Buffer) Unmarshal(pb Message) error {
|
||||
// If the object can unmarshal itself, let it.
|
||||
if u, ok := pb.(Unmarshaler); ok {
|
||||
err := u.Unmarshal(p.buf[p.index:])
|
||||
p.index = len(p.buf)
|
||||
return err
|
||||
}
|
||||
|
||||
typ, base, err := getbase(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base)
|
||||
|
||||
if collectStats {
|
||||
stats.Decode++
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// unmarshalType does the work of unmarshaling a structure.
|
||||
func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error {
|
||||
var state errorState
|
||||
required, reqFields := prop.reqCount, uint64(0)
|
||||
|
||||
var err error
|
||||
for err == nil && o.index < len(o.buf) {
|
||||
oi := o.index
|
||||
var u uint64
|
||||
u, err = o.DecodeVarint()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
wire := int(u & 0x7)
|
||||
if wire == WireEndGroup {
|
||||
if is_group {
|
||||
return nil // input is satisfied
|
||||
}
|
||||
return fmt.Errorf("proto: %s: wiretype end group for non-group", st)
|
||||
}
|
||||
tag := int(u >> 3)
|
||||
if tag <= 0 {
|
||||
return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire)
|
||||
}
|
||||
fieldnum, ok := prop.decoderTags.get(tag)
|
||||
if !ok {
|
||||
// Maybe it's an extension?
|
||||
if prop.extendable {
|
||||
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
|
||||
if err = o.skip(st, tag, wire); err == nil {
|
||||
ext := e.ExtensionMap()[int32(tag)] // may be missing
|
||||
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
|
||||
e.ExtensionMap()[int32(tag)] = ext
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Maybe it's a oneof?
|
||||
if prop.oneofUnmarshaler != nil {
|
||||
m := structPointer_Interface(base, st).(Message)
|
||||
// First return value indicates whether tag is a oneof field.
|
||||
ok, err = prop.oneofUnmarshaler(m, tag, wire, o)
|
||||
if err == ErrInternalBadWireType {
|
||||
// Map the error to something more descriptive.
|
||||
// Do the formatting here to save generated code space.
|
||||
err = fmt.Errorf("bad wiretype for oneof field in %T", m)
|
||||
}
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
err = o.skipAndSave(st, tag, wire, base, prop.unrecField)
|
||||
continue
|
||||
}
|
||||
p := prop.Prop[fieldnum]
|
||||
|
||||
if p.dec == nil {
|
||||
fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name)
|
||||
continue
|
||||
}
|
||||
dec := p.dec
|
||||
if wire != WireStartGroup && wire != p.WireType {
|
||||
if wire == WireBytes && p.packedDec != nil {
|
||||
// a packable field
|
||||
dec = p.packedDec
|
||||
} else {
|
||||
err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType)
|
||||
continue
|
||||
}
|
||||
}
|
||||
decErr := dec(o, p, base)
|
||||
if decErr != nil && !state.shouldContinue(decErr, p) {
|
||||
err = decErr
|
||||
}
|
||||
if err == nil && p.Required {
|
||||
// Successfully decoded a required field.
|
||||
if tag <= 64 {
|
||||
// use bitmap for fields 1-64 to catch field reuse.
|
||||
var mask uint64 = 1 << uint64(tag-1)
|
||||
if reqFields&mask == 0 {
|
||||
// new required field
|
||||
reqFields |= mask
|
||||
required--
|
||||
}
|
||||
} else {
|
||||
// This is imprecise. It can be fooled by a required field
|
||||
// with a tag > 64 that is encoded twice; that's very rare.
|
||||
// A fully correct implementation would require allocating
|
||||
// a data structure, which we would like to avoid.
|
||||
required--
|
||||
}
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
if is_group {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
if state.err != nil {
|
||||
return state.err
|
||||
}
|
||||
if required > 0 {
|
||||
// Not enough information to determine the exact field. If we use extra
|
||||
// CPU, we could determine the field only if the missing required field
|
||||
// has a tag <= 64 and we check reqFields.
|
||||
return &RequiredNotSetError{"{Unknown}"}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Individual type decoders
|
||||
// For each,
|
||||
// u is the decoded value,
|
||||
// v is a pointer to the field (pointer) in the struct
|
||||
|
||||
// Sizes of the pools to allocate inside the Buffer.
|
||||
// The goal is modest amortization and allocation
|
||||
// on at least 16-byte boundaries.
|
||||
const (
|
||||
boolPoolSize = 16
|
||||
uint32PoolSize = 8
|
||||
uint64PoolSize = 4
|
||||
)
|
||||
|
||||
// Decode a bool.
|
||||
func (o *Buffer) dec_bool(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(o.bools) == 0 {
|
||||
o.bools = make([]bool, boolPoolSize)
|
||||
}
|
||||
o.bools[0] = u != 0
|
||||
*structPointer_Bool(base, p.field) = &o.bools[0]
|
||||
o.bools = o.bools[1:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_BoolVal(base, p.field) = u != 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode an int32.
|
||||
func (o *Buffer) dec_int32(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word32_Set(structPointer_Word32(base, p.field), o, uint32(u))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode an int64.
|
||||
func (o *Buffer) dec_int64(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word64_Set(structPointer_Word64(base, p.field), o, u)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
word64Val_Set(structPointer_Word64Val(base, p.field), o, u)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a string.
|
||||
func (o *Buffer) dec_string(p *Properties, base structPointer) error {
|
||||
s, err := o.DecodeStringBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_String(base, p.field) = &s
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error {
|
||||
s, err := o.DecodeStringBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_StringVal(base, p.field) = s
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of bytes ([]byte).
|
||||
func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error {
|
||||
b, err := o.DecodeRawBytes(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*structPointer_Bytes(base, p.field) = b
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of bools ([]bool).
|
||||
func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := structPointer_BoolSlice(base, p.field)
|
||||
*v = append(*v, u != 0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of bools ([]bool) in packed format.
|
||||
func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error {
|
||||
v := structPointer_BoolSlice(base, p.field)
|
||||
|
||||
nn, err := o.DecodeVarint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nb := int(nn) // number of bytes of encoded bools
|
||||
|
||||
y := *v
|
||||
for i := 0; i < nb; i++ {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
y = append(y, u != 0)
|
||||
}
|
||||
|
||||
*v = y
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int32s ([]int32).
|
||||
func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
structPointer_Word32Slice(base, p.field).Append(uint32(u))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int32s ([]int32) in packed format.
|
||||
func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error {
|
||||
v := structPointer_Word32Slice(base, p.field)
|
||||
|
||||
nn, err := o.DecodeVarint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nb := int(nn) // number of bytes of encoded int32s
|
||||
|
||||
fin := o.index + nb
|
||||
if fin < o.index {
|
||||
return errOverflow
|
||||
}
|
||||
for o.index < fin {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Append(uint32(u))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int64s ([]int64).
|
||||
func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
structPointer_Word64Slice(base, p.field).Append(u)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of int64s ([]int64) in packed format.
|
||||
func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error {
|
||||
v := structPointer_Word64Slice(base, p.field)
|
||||
|
||||
nn, err := o.DecodeVarint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nb := int(nn) // number of bytes of encoded int64s
|
||||
|
||||
fin := o.index + nb
|
||||
if fin < o.index {
|
||||
return errOverflow
|
||||
}
|
||||
for o.index < fin {
|
||||
u, err := p.valDec(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v.Append(u)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of strings ([]string).
|
||||
func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error {
|
||||
s, err := o.DecodeStringBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := structPointer_StringSlice(base, p.field)
|
||||
*v = append(*v, s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a slice of slice of bytes ([][]byte).
|
||||
func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error {
|
||||
b, err := o.DecodeRawBytes(true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := structPointer_BytesSlice(base, p.field)
|
||||
*v = append(*v, b)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a map field.
|
||||
func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
|
||||
raw, err := o.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
oi := o.index // index at the end of this map entry
|
||||
o.index -= len(raw) // move buffer back to start of map entry
|
||||
|
||||
mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V
|
||||
if mptr.Elem().IsNil() {
|
||||
mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem()))
|
||||
}
|
||||
v := mptr.Elem() // map[K]V
|
||||
|
||||
// Prepare addressable doubly-indirect placeholders for the key and value types.
|
||||
// See enc_new_map for why.
|
||||
keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K
|
||||
keybase := toStructPointer(keyptr.Addr()) // **K
|
||||
|
||||
var valbase structPointer
|
||||
var valptr reflect.Value
|
||||
switch p.mtype.Elem().Kind() {
|
||||
case reflect.Slice:
|
||||
// []byte
|
||||
var dummy []byte
|
||||
valptr = reflect.ValueOf(&dummy) // *[]byte
|
||||
valbase = toStructPointer(valptr) // *[]byte
|
||||
case reflect.Ptr:
|
||||
// message; valptr is **Msg; need to allocate the intermediate pointer
|
||||
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
|
||||
valptr.Set(reflect.New(valptr.Type().Elem()))
|
||||
valbase = toStructPointer(valptr)
|
||||
default:
|
||||
// everything else
|
||||
valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V
|
||||
valbase = toStructPointer(valptr.Addr()) // **V
|
||||
}
|
||||
|
||||
// Decode.
|
||||
// This parses a restricted wire format, namely the encoding of a message
|
||||
// with two fields. See enc_new_map for the format.
|
||||
for o.index < oi {
|
||||
// tagcode for key and value properties are always a single byte
|
||||
// because they have tags 1 and 2.
|
||||
tagcode := o.buf[o.index]
|
||||
o.index++
|
||||
switch tagcode {
|
||||
case p.mkeyprop.tagcode[0]:
|
||||
if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil {
|
||||
return err
|
||||
}
|
||||
case p.mvalprop.tagcode[0]:
|
||||
if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
// TODO: Should we silently skip this instead?
|
||||
return fmt.Errorf("proto: bad map data tag %d", raw[0])
|
||||
}
|
||||
}
|
||||
keyelem, valelem := keyptr.Elem(), valptr.Elem()
|
||||
if !keyelem.IsValid() || !valelem.IsValid() {
|
||||
// We did not decode the key or the value in the map entry.
|
||||
// Either way, it's an invalid map entry.
|
||||
return fmt.Errorf("proto: bad map data: missing key/val")
|
||||
}
|
||||
|
||||
v.SetMapIndex(keyelem, valelem)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode a group.
|
||||
func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error {
|
||||
bas := structPointer_GetStructPointer(base, p.field)
|
||||
if structPointer_IsNil(bas) {
|
||||
// allocate new nested message
|
||||
bas = toStructPointer(reflect.New(p.stype))
|
||||
structPointer_SetStructPointer(base, p.field, bas)
|
||||
}
|
||||
return o.unmarshalType(p.stype, p.sprop, true, bas)
|
||||
}
|
||||
|
||||
// Decode an embedded message.
|
||||
func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) {
|
||||
raw, e := o.DecodeRawBytes(false)
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
|
||||
bas := structPointer_GetStructPointer(base, p.field)
|
||||
if structPointer_IsNil(bas) {
|
||||
// allocate new nested message
|
||||
bas = toStructPointer(reflect.New(p.stype))
|
||||
structPointer_SetStructPointer(base, p.field, bas)
|
||||
}
|
||||
|
||||
// If the object can unmarshal itself, let it.
|
||||
if p.isUnmarshaler {
|
||||
iv := structPointer_Interface(bas, p.stype)
|
||||
return iv.(Unmarshaler).Unmarshal(raw)
|
||||
}
|
||||
|
||||
obuf := o.buf
|
||||
oi := o.index
|
||||
o.buf = raw
|
||||
o.index = 0
|
||||
|
||||
err = o.unmarshalType(p.stype, p.sprop, false, bas)
|
||||
o.buf = obuf
|
||||
o.index = oi
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode a slice of embedded messages.
|
||||
func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error {
|
||||
return o.dec_slice_struct(p, false, base)
|
||||
}
|
||||
|
||||
// Decode a slice of embedded groups.
|
||||
func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error {
|
||||
return o.dec_slice_struct(p, true, base)
|
||||
}
|
||||
|
||||
// Decode a slice of structs ([]*struct).
|
||||
func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error {
|
||||
v := reflect.New(p.stype)
|
||||
bas := toStructPointer(v)
|
||||
structPointer_StructPointerSlice(base, p.field).Append(bas)
|
||||
|
||||
if is_group {
|
||||
err := o.unmarshalType(p.stype, p.sprop, is_group, bas)
|
||||
return err
|
||||
}
|
||||
|
||||
raw, err := o.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the object can unmarshal itself, let it.
|
||||
if p.isUnmarshaler {
|
||||
iv := v.Interface()
|
||||
return iv.(Unmarshaler).Unmarshal(raw)
|
||||
}
|
||||
|
||||
obuf := o.buf
|
||||
oi := o.index
|
||||
o.buf = raw
|
||||
o.index = 0
|
||||
|
||||
err = o.unmarshalType(p.stype, p.sprop, is_group, bas)
|
||||
|
||||
o.buf = obuf
|
||||
o.index = oi
|
||||
|
||||
return err
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,267 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
// Protocol buffer comparison.
|
||||
// TODO: MessageSet.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
/*
|
||||
Equal returns true iff protocol buffers a and b are equal.
|
||||
The arguments must both be pointers to protocol buffer structs.
|
||||
|
||||
Equality is defined in this way:
|
||||
- Two messages are equal iff they are the same type,
|
||||
corresponding fields are equal, unknown field sets
|
||||
are equal, and extensions sets are equal.
|
||||
- Two set scalar fields are equal iff their values are equal.
|
||||
If the fields are of a floating-point type, remember that
|
||||
NaN != x for all x, including NaN.
|
||||
- Two repeated fields are equal iff their lengths are the same,
|
||||
and their corresponding elements are equal (a "bytes" field,
|
||||
although represented by []byte, is not a repeated field)
|
||||
- Two unset fields are equal.
|
||||
- Two unknown field sets are equal if their current
|
||||
encoded state is equal.
|
||||
- Two extension sets are equal iff they have corresponding
|
||||
elements that are pairwise equal.
|
||||
- Every other combination of things are not equal.
|
||||
|
||||
The return value is undefined if a and b are not protocol buffers.
|
||||
*/
|
||||
func Equal(a, b Message) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b)
|
||||
if v1.Type() != v2.Type() {
|
||||
return false
|
||||
}
|
||||
if v1.Kind() == reflect.Ptr {
|
||||
if v1.IsNil() {
|
||||
return v2.IsNil()
|
||||
}
|
||||
if v2.IsNil() {
|
||||
return false
|
||||
}
|
||||
v1, v2 = v1.Elem(), v2.Elem()
|
||||
}
|
||||
if v1.Kind() != reflect.Struct {
|
||||
return false
|
||||
}
|
||||
return equalStruct(v1, v2)
|
||||
}
|
||||
|
||||
// v1 and v2 are known to have the same type.
|
||||
func equalStruct(v1, v2 reflect.Value) bool {
|
||||
for i := 0; i < v1.NumField(); i++ {
|
||||
f := v1.Type().Field(i)
|
||||
if strings.HasPrefix(f.Name, "XXX_") {
|
||||
continue
|
||||
}
|
||||
f1, f2 := v1.Field(i), v2.Field(i)
|
||||
if f.Type.Kind() == reflect.Ptr {
|
||||
if n1, n2 := f1.IsNil(), f2.IsNil(); n1 && n2 {
|
||||
// both unset
|
||||
continue
|
||||
} else if n1 != n2 {
|
||||
// set/unset mismatch
|
||||
return false
|
||||
}
|
||||
b1, ok := f1.Interface().(raw)
|
||||
if ok {
|
||||
b2 := f2.Interface().(raw)
|
||||
// RawMessage
|
||||
if !bytes.Equal(b1.Bytes(), b2.Bytes()) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
f1, f2 = f1.Elem(), f2.Elem()
|
||||
}
|
||||
if !equalAny(f1, f2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
|
||||
em2 := v2.FieldByName("XXX_extensions")
|
||||
if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
uf := v1.FieldByName("XXX_unrecognized")
|
||||
if !uf.IsValid() {
|
||||
return true
|
||||
}
|
||||
|
||||
u1 := uf.Bytes()
|
||||
u2 := v2.FieldByName("XXX_unrecognized").Bytes()
|
||||
if !bytes.Equal(u1, u2) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// v1 and v2 are known to have the same type.
|
||||
func equalAny(v1, v2 reflect.Value) bool {
|
||||
if v1.Type() == protoMessageType {
|
||||
m1, _ := v1.Interface().(Message)
|
||||
m2, _ := v2.Interface().(Message)
|
||||
return Equal(m1, m2)
|
||||
}
|
||||
switch v1.Kind() {
|
||||
case reflect.Bool:
|
||||
return v1.Bool() == v2.Bool()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v1.Float() == v2.Float()
|
||||
case reflect.Int32, reflect.Int64:
|
||||
return v1.Int() == v2.Int()
|
||||
case reflect.Interface:
|
||||
// Probably a oneof field; compare the inner values.
|
||||
n1, n2 := v1.IsNil(), v2.IsNil()
|
||||
if n1 || n2 {
|
||||
return n1 == n2
|
||||
}
|
||||
e1, e2 := v1.Elem(), v2.Elem()
|
||||
if e1.Type() != e2.Type() {
|
||||
return false
|
||||
}
|
||||
return equalAny(e1, e2)
|
||||
case reflect.Map:
|
||||
if v1.Len() != v2.Len() {
|
||||
return false
|
||||
}
|
||||
for _, key := range v1.MapKeys() {
|
||||
val2 := v2.MapIndex(key)
|
||||
if !val2.IsValid() {
|
||||
// This key was not found in the second map.
|
||||
return false
|
||||
}
|
||||
if !equalAny(v1.MapIndex(key), val2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.Ptr:
|
||||
return equalAny(v1.Elem(), v2.Elem())
|
||||
case reflect.Slice:
|
||||
if v1.Type().Elem().Kind() == reflect.Uint8 {
|
||||
// short circuit: []byte
|
||||
if v1.IsNil() != v2.IsNil() {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(v1.Interface().([]byte), v2.Interface().([]byte))
|
||||
}
|
||||
|
||||
if v1.Len() != v2.Len() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < v1.Len(); i++ {
|
||||
if !equalAny(v1.Index(i), v2.Index(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case reflect.String:
|
||||
return v1.Interface().(string) == v2.Interface().(string)
|
||||
case reflect.Struct:
|
||||
return equalStruct(v1, v2)
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
return v1.Uint() == v2.Uint()
|
||||
}
|
||||
|
||||
// unknown type, so not a protocol buffer
|
||||
log.Printf("proto: don't know how to compare %v", v1)
|
||||
return false
|
||||
}
|
||||
|
||||
// base is the struct type that the extensions are based on.
|
||||
// em1 and em2 are extension maps.
|
||||
func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
|
||||
if len(em1) != len(em2) {
|
||||
return false
|
||||
}
|
||||
|
||||
for extNum, e1 := range em1 {
|
||||
e2, ok := em2[extNum]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
m1, m2 := e1.value, e2.value
|
||||
|
||||
if m1 != nil && m2 != nil {
|
||||
// Both are unencoded.
|
||||
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// At least one is encoded. To do a semantically correct comparison
|
||||
// we need to unmarshal them first.
|
||||
var desc *ExtensionDesc
|
||||
if m := extensionMaps[base]; m != nil {
|
||||
desc = m[extNum]
|
||||
}
|
||||
if desc == nil {
|
||||
log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
|
||||
continue
|
||||
}
|
||||
var err error
|
||||
if m1 == nil {
|
||||
m1, err = decodeExtension(e1.enc, desc)
|
||||
}
|
||||
if m2 == nil && err == nil {
|
||||
m2, err = decodeExtension(e2.enc, desc)
|
||||
}
|
||||
if err != nil {
|
||||
// The encoded form is invalid.
|
||||
log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
|
||||
return false
|
||||
}
|
||||
if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
@ -0,0 +1,209 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/golang/protobuf/proto"
|
||||
pb "github.com/golang/protobuf/proto/testdata"
|
||||
)
|
||||
|
||||
// Four identical base messages.
|
||||
// The init function adds extensions to some of them.
|
||||
var messageWithoutExtension = &pb.MyMessage{Count: Int32(7)}
|
||||
var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)}
|
||||
var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)}
|
||||
var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)}
|
||||
|
||||
// Two messages with non-message extensions.
|
||||
var messageWithInt32Extension1 = &pb.MyMessage{Count: Int32(8)}
|
||||
var messageWithInt32Extension2 = &pb.MyMessage{Count: Int32(8)}
|
||||
|
||||
func init() {
|
||||
ext1 := &pb.Ext{Data: String("Kirk")}
|
||||
ext2 := &pb.Ext{Data: String("Picard")}
|
||||
|
||||
// messageWithExtension1a has ext1, but never marshals it.
|
||||
if err := SetExtension(messageWithExtension1a, pb.E_Ext_More, ext1); err != nil {
|
||||
panic("SetExtension on 1a failed: " + err.Error())
|
||||
}
|
||||
|
||||
// messageWithExtension1b is the unmarshaled form of messageWithExtension1a.
|
||||
if err := SetExtension(messageWithExtension1b, pb.E_Ext_More, ext1); err != nil {
|
||||
panic("SetExtension on 1b failed: " + err.Error())
|
||||
}
|
||||
buf, err := Marshal(messageWithExtension1b)
|
||||
if err != nil {
|
||||
panic("Marshal of 1b failed: " + err.Error())
|
||||
}
|
||||
messageWithExtension1b.Reset()
|
||||
if err := Unmarshal(buf, messageWithExtension1b); err != nil {
|
||||
panic("Unmarshal of 1b failed: " + err.Error())
|
||||
}
|
||||
|
||||
// messageWithExtension2 has ext2.
|
||||
if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil {
|
||||
panic("SetExtension on 2 failed: " + err.Error())
|
||||
}
|
||||
|
||||
if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(23)); err != nil {
|
||||
panic("SetExtension on Int32-1 failed: " + err.Error())
|
||||
}
|
||||
if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(24)); err != nil {
|
||||
panic("SetExtension on Int32-2 failed: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
var EqualTests = []struct {
|
||||
desc string
|
||||
a, b Message
|
||||
exp bool
|
||||
}{
|
||||
{"different types", &pb.GoEnum{}, &pb.GoTestField{}, false},
|
||||
{"equal empty", &pb.GoEnum{}, &pb.GoEnum{}, true},
|
||||
{"nil vs nil", nil, nil, true},
|
||||
{"typed nil vs typed nil", (*pb.GoEnum)(nil), (*pb.GoEnum)(nil), true},
|
||||
{"typed nil vs empty", (*pb.GoEnum)(nil), &pb.GoEnum{}, false},
|
||||
{"different typed nil", (*pb.GoEnum)(nil), (*pb.GoTestField)(nil), false},
|
||||
|
||||
{"one set field, one unset field", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{}, false},
|
||||
{"one set field zero, one unset field", &pb.GoTest{Param: Int32(0)}, &pb.GoTest{}, false},
|
||||
{"different set fields", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("bar")}, false},
|
||||
{"equal set", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("foo")}, true},
|
||||
|
||||
{"repeated, one set", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{}, false},
|
||||
{"repeated, different length", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{F_Int32Repeated: []int32{2}}, false},
|
||||
{"repeated, different value", &pb.GoTest{F_Int32Repeated: []int32{2}}, &pb.GoTest{F_Int32Repeated: []int32{3}}, false},
|
||||
{"repeated, equal", &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, true},
|
||||
{"repeated, nil equal nil", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: nil}, true},
|
||||
{"repeated, nil equal empty", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: []int32{}}, true},
|
||||
{"repeated, empty equal nil", &pb.GoTest{F_Int32Repeated: []int32{}}, &pb.GoTest{F_Int32Repeated: nil}, true},
|
||||
|
||||
{
|
||||
"nested, different",
|
||||
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("foo")}},
|
||||
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("bar")}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"nested, equal",
|
||||
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
|
||||
&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
|
||||
true,
|
||||
},
|
||||
|
||||
{"bytes", &pb.OtherMessage{Value: []byte("foo")}, &pb.OtherMessage{Value: []byte("foo")}, true},
|
||||
{"bytes, empty", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: []byte{}}, true},
|
||||
{"bytes, empty vs nil", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: nil}, false},
|
||||
{
|
||||
"repeated bytes",
|
||||
&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
|
||||
&pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}},
|
||||
true,
|
||||
},
|
||||
|
||||
{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
|
||||
{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
|
||||
{"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false},
|
||||
|
||||
{"int32 extension vs. itself", messageWithInt32Extension1, messageWithInt32Extension1, true},
|
||||
{"int32 extension vs. a different int32", messageWithInt32Extension1, messageWithInt32Extension2, false},
|
||||
|
||||
{
|
||||
"message with group",
|
||||
&pb.MyMessage{
|
||||
Count: Int32(1),
|
||||
Somegroup: &pb.MyMessage_SomeGroup{
|
||||
GroupField: Int32(5),
|
||||
},
|
||||
},
|
||||
&pb.MyMessage{
|
||||
Count: Int32(1),
|
||||
Somegroup: &pb.MyMessage_SomeGroup{
|
||||
GroupField: Int32(5),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
|
||||
{
|
||||
"map same",
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"map different entry",
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Rob"}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"map different key only",
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{2: "Ken"}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"map different value only",
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}},
|
||||
&pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"oneof same",
|
||||
&pb.Communique{Union: &pb.Communique_Number{41}},
|
||||
&pb.Communique{Union: &pb.Communique_Number{41}},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"oneof one nil",
|
||||
&pb.Communique{Union: &pb.Communique_Number{41}},
|
||||
&pb.Communique{},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"oneof different",
|
||||
&pb.Communique{Union: &pb.Communique_Number{41}},
|
||||
&pb.Communique{Union: &pb.Communique_Name{"Bobby Tables"}},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
func TestEqual(t *testing.T) {
|
||||
for _, tc := range EqualTests {
|
||||
if res := Equal(tc.a, tc.b); res != tc.exp {
|
||||
t.Errorf("%v: Equal(%v, %v) = %v, want %v", tc.desc, tc.a, tc.b, res, tc.exp)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,400 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
/*
|
||||
* Types and routines for supporting protocol buffer extensions.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
|
||||
var ErrMissingExtension = errors.New("proto: missing extension")
|
||||
|
||||
// ExtensionRange represents a range of message extensions for a protocol buffer.
|
||||
// Used in code generated by the protocol compiler.
|
||||
type ExtensionRange struct {
|
||||
Start, End int32 // both inclusive
|
||||
}
|
||||
|
||||
// extendableProto is an interface implemented by any protocol buffer that may be extended.
|
||||
type extendableProto interface {
|
||||
Message
|
||||
ExtensionRangeArray() []ExtensionRange
|
||||
ExtensionMap() map[int32]Extension
|
||||
}
|
||||
|
||||
var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
|
||||
|
||||
// ExtensionDesc represents an extension specification.
|
||||
// Used in generated code from the protocol compiler.
|
||||
type ExtensionDesc struct {
|
||||
ExtendedType Message // nil pointer to the type that is being extended
|
||||
ExtensionType interface{} // nil pointer to the extension type
|
||||
Field int32 // field number
|
||||
Name string // fully-qualified name of extension, for text formatting
|
||||
Tag string // protobuf tag style
|
||||
}
|
||||
|
||||
func (ed *ExtensionDesc) repeated() bool {
|
||||
t := reflect.TypeOf(ed.ExtensionType)
|
||||
return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
|
||||
}
|
||||
|
||||
// Extension represents an extension in a message.
|
||||
type Extension struct {
|
||||
// When an extension is stored in a message using SetExtension
|
||||
// only desc and value are set. When the message is marshaled
|
||||
// enc will be set to the encoded form of the message.
|
||||
//
|
||||
// When a message is unmarshaled and contains extensions, each
|
||||
// extension will have only enc set. When such an extension is
|
||||
// accessed using GetExtension (or GetExtensions) desc and value
|
||||
// will be set.
|
||||
desc *ExtensionDesc
|
||||
value interface{}
|
||||
enc []byte
|
||||
}
|
||||
|
||||
// SetRawExtension is for testing only.
|
||||
func SetRawExtension(base extendableProto, id int32, b []byte) {
|
||||
base.ExtensionMap()[id] = Extension{enc: b}
|
||||
}
|
||||
|
||||
// isExtensionField returns true iff the given field number is in an extension range.
|
||||
func isExtensionField(pb extendableProto, field int32) bool {
|
||||
for _, er := range pb.ExtensionRangeArray() {
|
||||
if er.Start <= field && field <= er.End {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkExtensionTypes checks that the given extension is valid for pb.
|
||||
func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
|
||||
// Check the extended type.
|
||||
if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
|
||||
return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
|
||||
}
|
||||
// Check the range.
|
||||
if !isExtensionField(pb, extension.Field) {
|
||||
return errors.New("proto: bad extension number; not in declared ranges")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// extPropKey is sufficient to uniquely identify an extension.
|
||||
type extPropKey struct {
|
||||
base reflect.Type
|
||||
field int32
|
||||
}
|
||||
|
||||
var extProp = struct {
|
||||
sync.RWMutex
|
||||
m map[extPropKey]*Properties
|
||||
}{
|
||||
m: make(map[extPropKey]*Properties),
|
||||
}
|
||||
|
||||
func extensionProperties(ed *ExtensionDesc) *Properties {
|
||||
key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
|
||||
|
||||
extProp.RLock()
|
||||
if prop, ok := extProp.m[key]; ok {
|
||||
extProp.RUnlock()
|
||||
return prop
|
||||
}
|
||||
extProp.RUnlock()
|
||||
|
||||
extProp.Lock()
|
||||
defer extProp.Unlock()
|
||||
// Check again.
|
||||
if prop, ok := extProp.m[key]; ok {
|
||||
return prop
|
||||
}
|
||||
|
||||
prop := new(Properties)
|
||||
prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
|
||||
extProp.m[key] = prop
|
||||
return prop
|
||||
}
|
||||
|
||||
// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
|
||||
func encodeExtensionMap(m map[int32]Extension) error {
|
||||
for k, e := range m {
|
||||
if e.value == nil || e.desc == nil {
|
||||
// Extension is only in its encoded form.
|
||||
continue
|
||||
}
|
||||
|
||||
// We don't skip extensions that have an encoded form set,
|
||||
// because the extension value may have been mutated after
|
||||
// the last time this function was called.
|
||||
|
||||
et := reflect.TypeOf(e.desc.ExtensionType)
|
||||
props := extensionProperties(e.desc)
|
||||
|
||||
p := NewBuffer(nil)
|
||||
// If e.value has type T, the encoder expects a *struct{ X T }.
|
||||
// Pass a *T with a zero field and hope it all works out.
|
||||
x := reflect.New(et)
|
||||
x.Elem().Set(reflect.ValueOf(e.value))
|
||||
if err := props.enc(p, props, toStructPointer(x)); err != nil {
|
||||
return err
|
||||
}
|
||||
e.enc = p.buf
|
||||
m[k] = e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sizeExtensionMap(m map[int32]Extension) (n int) {
|
||||
for _, e := range m {
|
||||
if e.value == nil || e.desc == nil {
|
||||
// Extension is only in its encoded form.
|
||||
n += len(e.enc)
|
||||
continue
|
||||
}
|
||||
|
||||
// We don't skip extensions that have an encoded form set,
|
||||
// because the extension value may have been mutated after
|
||||
// the last time this function was called.
|
||||
|
||||
et := reflect.TypeOf(e.desc.ExtensionType)
|
||||
props := extensionProperties(e.desc)
|
||||
|
||||
// If e.value has type T, the encoder expects a *struct{ X T }.
|
||||
// Pass a *T with a zero field and hope it all works out.
|
||||
x := reflect.New(et)
|
||||
x.Elem().Set(reflect.ValueOf(e.value))
|
||||
n += props.size(props, toStructPointer(x))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// HasExtension returns whether the given extension is present in pb.
|
||||
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
|
||||
// TODO: Check types, field numbers, etc.?
|
||||
_, ok := pb.ExtensionMap()[extension.Field]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ClearExtension removes the given extension from pb.
|
||||
func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
|
||||
// TODO: Check types, field numbers, etc.?
|
||||
delete(pb.ExtensionMap(), extension.Field)
|
||||
}
|
||||
|
||||
// GetExtension parses and returns the given extension of pb.
|
||||
// If the extension is not present and has no default value it returns ErrMissingExtension.
|
||||
func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) {
|
||||
if err := checkExtensionTypes(pb, extension); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emap := pb.ExtensionMap()
|
||||
e, ok := emap[extension.Field]
|
||||
if !ok {
|
||||
// defaultExtensionValue returns the default value or
|
||||
// ErrMissingExtension if there is no default.
|
||||
return defaultExtensionValue(extension)
|
||||
}
|
||||
|
||||
if e.value != nil {
|
||||
// Already decoded. Check the descriptor, though.
|
||||
if e.desc != extension {
|
||||
// This shouldn't happen. If it does, it means that
|
||||
// GetExtension was called twice with two different
|
||||
// descriptors with the same field number.
|
||||
return nil, errors.New("proto: descriptor conflict")
|
||||
}
|
||||
return e.value, nil
|
||||
}
|
||||
|
||||
v, err := decodeExtension(e.enc, extension)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Remember the decoded version and drop the encoded version.
|
||||
// That way it is safe to mutate what we return.
|
||||
e.value = v
|
||||
e.desc = extension
|
||||
e.enc = nil
|
||||
emap[extension.Field] = e
|
||||
return e.value, nil
|
||||
}
|
||||
|
||||
// defaultExtensionValue returns the default value for extension.
|
||||
// If no default for an extension is defined ErrMissingExtension is returned.
|
||||
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
|
||||
t := reflect.TypeOf(extension.ExtensionType)
|
||||
props := extensionProperties(extension)
|
||||
|
||||
sf, _, err := fieldDefault(t, props)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sf == nil || sf.value == nil {
|
||||
// There is no default value.
|
||||
return nil, ErrMissingExtension
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Ptr {
|
||||
// We do not need to return a Ptr, we can directly return sf.value.
|
||||
return sf.value, nil
|
||||
}
|
||||
|
||||
// We need to return an interface{} that is a pointer to sf.value.
|
||||
value := reflect.New(t).Elem()
|
||||
value.Set(reflect.New(value.Type().Elem()))
|
||||
if sf.kind == reflect.Int32 {
|
||||
// We may have an int32 or an enum, but the underlying data is int32.
|
||||
// Since we can't set an int32 into a non int32 reflect.value directly
|
||||
// set it as a int32.
|
||||
value.Elem().SetInt(int64(sf.value.(int32)))
|
||||
} else {
|
||||
value.Elem().Set(reflect.ValueOf(sf.value))
|
||||
}
|
||||
return value.Interface(), nil
|
||||
}
|
||||
|
||||
// decodeExtension decodes an extension encoded in b.
|
||||
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
|
||||
o := NewBuffer(b)
|
||||
|
||||
t := reflect.TypeOf(extension.ExtensionType)
|
||||
rep := extension.repeated()
|
||||
|
||||
props := extensionProperties(extension)
|
||||
|
||||
// t is a pointer to a struct, pointer to basic type or a slice.
|
||||
// Allocate a "field" to store the pointer/slice itself; the
|
||||
// pointer/slice will be stored here. We pass
|
||||
// the address of this field to props.dec.
|
||||
// This passes a zero field and a *t and lets props.dec
|
||||
// interpret it as a *struct{ x t }.
|
||||
value := reflect.New(t).Elem()
|
||||
|
||||
for {
|
||||
// Discard wire type and field number varint. It isn't needed.
|
||||
if _, err := o.DecodeVarint(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !rep || o.index >= len(o.buf) {
|
||||
break
|
||||
}
|
||||
}
|
||||
return value.Interface(), nil
|
||||
}
|
||||
|
||||
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
|
||||
// The returned slice has the same length as es; missing extensions will appear as nil elements.
|
||||
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
|
||||
epb, ok := pb.(extendableProto)
|
||||
if !ok {
|
||||
err = errors.New("proto: not an extendable proto")
|
||||
return
|
||||
}
|
||||
extensions = make([]interface{}, len(es))
|
||||
for i, e := range es {
|
||||
extensions[i], err = GetExtension(epb, e)
|
||||
if err == ErrMissingExtension {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// SetExtension sets the specified extension of pb to the specified value.
|
||||
func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error {
|
||||
if err := checkExtensionTypes(pb, extension); err != nil {
|
||||
return err
|
||||
}
|
||||
typ := reflect.TypeOf(extension.ExtensionType)
|
||||
if typ != reflect.TypeOf(value) {
|
||||
return errors.New("proto: bad extension value type")
|
||||
}
|
||||
// nil extension values need to be caught early, because the
|
||||
// encoder can't distinguish an ErrNil due to a nil extension
|
||||
// from an ErrNil due to a missing field. Extensions are
|
||||
// always optional, so the encoder would just swallow the error
|
||||
// and drop all the extensions from the encoded message.
|
||||
if reflect.ValueOf(value).IsNil() {
|
||||
return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
|
||||
}
|
||||
|
||||
pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A global registry of extensions.
|
||||
// The generated code will register the generated descriptors by calling RegisterExtension.
|
||||
|
||||
var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
|
||||
|
||||
// RegisterExtension is called from the generated code.
|
||||
func RegisterExtension(desc *ExtensionDesc) {
|
||||
st := reflect.TypeOf(desc.ExtendedType).Elem()
|
||||
m := extensionMaps[st]
|
||||
if m == nil {
|
||||
m = make(map[int32]*ExtensionDesc)
|
||||
extensionMaps[st] = m
|
||||
}
|
||||
if _, ok := m[desc.Field]; ok {
|
||||
panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
|
||||
}
|
||||
m[desc.Field] = desc
|
||||
}
|
||||
|
||||
// RegisteredExtensions returns a map of the registered extensions of a
|
||||
// protocol buffer struct, indexed by the extension number.
|
||||
// The argument pb should be a nil pointer to the struct type.
|
||||
func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
|
||||
return extensionMaps[reflect.TypeOf(pb).Elem()]
|
||||
}
|
@ -0,0 +1,292 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
pb "github.com/golang/protobuf/proto/testdata"
|
||||
)
|
||||
|
||||
func TestGetExtensionsWithMissingExtensions(t *testing.T) {
|
||||
msg := &pb.MyMessage{}
|
||||
ext1 := &pb.Ext{}
|
||||
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
|
||||
t.Fatalf("Could not set ext1: %s", ext1)
|
||||
}
|
||||
exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
|
||||
pb.E_Ext_More,
|
||||
pb.E_Ext_Text,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetExtensions() failed: %s", err)
|
||||
}
|
||||
if exts[0] != ext1 {
|
||||
t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
|
||||
}
|
||||
if exts[1] != nil {
|
||||
t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExtensionStability(t *testing.T) {
|
||||
check := func(m *pb.MyMessage) bool {
|
||||
ext1, err := proto.GetExtension(m, pb.E_Ext_More)
|
||||
if err != nil {
|
||||
t.Fatalf("GetExtension() failed: %s", err)
|
||||
}
|
||||
ext2, err := proto.GetExtension(m, pb.E_Ext_More)
|
||||
if err != nil {
|
||||
t.Fatalf("GetExtension() failed: %s", err)
|
||||
}
|
||||
return ext1 == ext2
|
||||
}
|
||||
msg := &pb.MyMessage{Count: proto.Int32(4)}
|
||||
ext0 := &pb.Ext{}
|
||||
if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
|
||||
t.Fatalf("Could not set ext1: %s", ext0)
|
||||
}
|
||||
if !check(msg) {
|
||||
t.Errorf("GetExtension() not stable before marshaling")
|
||||
}
|
||||
bb, err := proto.Marshal(msg)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal() failed: %s", err)
|
||||
}
|
||||
msg1 := &pb.MyMessage{}
|
||||
err = proto.Unmarshal(bb, msg1)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal() failed: %s", err)
|
||||
}
|
||||
if !check(msg1) {
|
||||
t.Errorf("GetExtension() not stable after unmarshaling")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExtensionDefaults(t *testing.T) {
|
||||
var setFloat64 float64 = 1
|
||||
var setFloat32 float32 = 2
|
||||
var setInt32 int32 = 3
|
||||
var setInt64 int64 = 4
|
||||
var setUint32 uint32 = 5
|
||||
var setUint64 uint64 = 6
|
||||
var setBool = true
|
||||
var setBool2 = false
|
||||
var setString = "Goodnight string"
|
||||
var setBytes = []byte("Goodnight bytes")
|
||||
var setEnum = pb.DefaultsMessage_TWO
|
||||
|
||||
type testcase struct {
|
||||
ext *proto.ExtensionDesc // Extension we are testing.
|
||||
want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail).
|
||||
def interface{} // Expected value of extension after ClearExtension().
|
||||
}
|
||||
tests := []testcase{
|
||||
{pb.E_NoDefaultDouble, setFloat64, nil},
|
||||
{pb.E_NoDefaultFloat, setFloat32, nil},
|
||||
{pb.E_NoDefaultInt32, setInt32, nil},
|
||||
{pb.E_NoDefaultInt64, setInt64, nil},
|
||||
{pb.E_NoDefaultUint32, setUint32, nil},
|
||||
{pb.E_NoDefaultUint64, setUint64, nil},
|
||||
{pb.E_NoDefaultSint32, setInt32, nil},
|
||||
{pb.E_NoDefaultSint64, setInt64, nil},
|
||||
{pb.E_NoDefaultFixed32, setUint32, nil},
|
||||
{pb.E_NoDefaultFixed64, setUint64, nil},
|
||||
{pb.E_NoDefaultSfixed32, setInt32, nil},
|
||||
{pb.E_NoDefaultSfixed64, setInt64, nil},
|
||||
{pb.E_NoDefaultBool, setBool, nil},
|
||||
{pb.E_NoDefaultBool, setBool2, nil},
|
||||
{pb.E_NoDefaultString, setString, nil},
|
||||
{pb.E_NoDefaultBytes, setBytes, nil},
|
||||
{pb.E_NoDefaultEnum, setEnum, nil},
|
||||
{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
|
||||
{pb.E_DefaultFloat, setFloat32, float32(3.14)},
|
||||
{pb.E_DefaultInt32, setInt32, int32(42)},
|
||||
{pb.E_DefaultInt64, setInt64, int64(43)},
|
||||
{pb.E_DefaultUint32, setUint32, uint32(44)},
|
||||
{pb.E_DefaultUint64, setUint64, uint64(45)},
|
||||
{pb.E_DefaultSint32, setInt32, int32(46)},
|
||||
{pb.E_DefaultSint64, setInt64, int64(47)},
|
||||
{pb.E_DefaultFixed32, setUint32, uint32(48)},
|
||||
{pb.E_DefaultFixed64, setUint64, uint64(49)},
|
||||
{pb.E_DefaultSfixed32, setInt32, int32(50)},
|
||||
{pb.E_DefaultSfixed64, setInt64, int64(51)},
|
||||
{pb.E_DefaultBool, setBool, true},
|
||||
{pb.E_DefaultBool, setBool2, true},
|
||||
{pb.E_DefaultString, setString, "Hello, string"},
|
||||
{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
|
||||
{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
|
||||
}
|
||||
|
||||
checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
|
||||
val, err := proto.GetExtension(msg, test.ext)
|
||||
if err != nil {
|
||||
if valWant != nil {
|
||||
return fmt.Errorf("GetExtension(): %s", err)
|
||||
}
|
||||
if want := proto.ErrMissingExtension; err != want {
|
||||
return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// All proto2 extension values are either a pointer to a value or a slice of values.
|
||||
ty := reflect.TypeOf(val)
|
||||
tyWant := reflect.TypeOf(test.ext.ExtensionType)
|
||||
if got, want := ty, tyWant; got != want {
|
||||
return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
|
||||
}
|
||||
tye := ty.Elem()
|
||||
tyeWant := tyWant.Elem()
|
||||
if got, want := tye, tyeWant; got != want {
|
||||
return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
|
||||
}
|
||||
|
||||
// Check the name of the type of the value.
|
||||
// If it is an enum it will be type int32 with the name of the enum.
|
||||
if got, want := tye.Name(), tye.Name(); got != want {
|
||||
return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
|
||||
}
|
||||
|
||||
// Check that value is what we expect.
|
||||
// If we have a pointer in val, get the value it points to.
|
||||
valExp := val
|
||||
if ty.Kind() == reflect.Ptr {
|
||||
valExp = reflect.ValueOf(val).Elem().Interface()
|
||||
}
|
||||
if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
|
||||
return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
setTo := func(test testcase) interface{} {
|
||||
setTo := reflect.ValueOf(test.want)
|
||||
if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
|
||||
setTo = reflect.New(typ).Elem()
|
||||
setTo.Set(reflect.New(setTo.Type().Elem()))
|
||||
setTo.Elem().Set(reflect.ValueOf(test.want))
|
||||
}
|
||||
return setTo.Interface()
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
msg := &pb.DefaultsMessage{}
|
||||
name := test.ext.Name
|
||||
|
||||
// Check the initial value.
|
||||
if err := checkVal(test, msg, test.def); err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
}
|
||||
|
||||
// Set the per-type value and check value.
|
||||
name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
|
||||
if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
|
||||
t.Errorf("%s: SetExtension(): %v", name, err)
|
||||
continue
|
||||
}
|
||||
if err := checkVal(test, msg, test.want); err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Set and check the value.
|
||||
name += " (cleared)"
|
||||
proto.ClearExtension(msg, test.ext)
|
||||
if err := checkVal(test, msg, test.def); err != nil {
|
||||
t.Errorf("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionsRoundTrip(t *testing.T) {
|
||||
msg := &pb.MyMessage{}
|
||||
ext1 := &pb.Ext{
|
||||
Data: proto.String("hi"),
|
||||
}
|
||||
ext2 := &pb.Ext{
|
||||
Data: proto.String("there"),
|
||||
}
|
||||
exists := proto.HasExtension(msg, pb.E_Ext_More)
|
||||
if exists {
|
||||
t.Error("Extension More present unexpectedly")
|
||||
}
|
||||
if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
e, err := proto.GetExtension(msg, pb.E_Ext_More)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
x, ok := e.(*pb.Ext)
|
||||
if !ok {
|
||||
t.Errorf("e has type %T, expected testdata.Ext", e)
|
||||
} else if *x.Data != "there" {
|
||||
t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
|
||||
}
|
||||
proto.ClearExtension(msg, pb.E_Ext_More)
|
||||
if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
|
||||
t.Errorf("got %v, expected ErrMissingExtension", e)
|
||||
}
|
||||
if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
|
||||
t.Error("expected bad extension error, got nil")
|
||||
}
|
||||
if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
|
||||
t.Error("expected extension err")
|
||||
}
|
||||
if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
|
||||
t.Error("expected some sort of type mismatch error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilExtension(t *testing.T) {
|
||||
msg := &pb.MyMessage{
|
||||
Count: proto.Int32(1),
|
||||
}
|
||||
if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
|
||||
t.Error("expected SetExtension to fail due to a nil extension")
|
||||
} else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
|
||||
t.Errorf("expected error %v, got %v", want, err)
|
||||
}
|
||||
// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
|
||||
// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
|
||||
}
|
@ -0,0 +1,883 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
/*
|
||||
Package proto converts data structures to and from the wire format of
|
||||
protocol buffers. It works in concert with the Go source code generated
|
||||
for .proto files by the protocol compiler.
|
||||
|
||||
A summary of the properties of the protocol buffer interface
|
||||
for a protocol buffer variable v:
|
||||
|
||||
- Names are turned from camel_case to CamelCase for export.
|
||||
- There are no methods on v to set fields; just treat
|
||||
them as structure fields.
|
||||
- There are getters that return a field's value if set,
|
||||
and return the field's default value if unset.
|
||||
The getters work even if the receiver is a nil message.
|
||||
- The zero value for a struct is its correct initialization state.
|
||||
All desired fields must be set before marshaling.
|
||||
- A Reset() method will restore a protobuf struct to its zero state.
|
||||
- Non-repeated fields are pointers to the values; nil means unset.
|
||||
That is, optional or required field int32 f becomes F *int32.
|
||||
- Repeated fields are slices.
|
||||
- Helper functions are available to aid the setting of fields.
|
||||
msg.Foo = proto.String("hello") // set field
|
||||
- Constants are defined to hold the default values of all fields that
|
||||
have them. They have the form Default_StructName_FieldName.
|
||||
Because the getter methods handle defaulted values,
|
||||
direct use of these constants should be rare.
|
||||
- Enums are given type names and maps from names to values.
|
||||
Enum values are prefixed by the enclosing message's name, or by the
|
||||
enum's type name if it is a top-level enum. Enum types have a String
|
||||
method, and a Enum method to assist in message construction.
|
||||
- Nested messages, groups and enums have type names prefixed with the name of
|
||||
the surrounding message type.
|
||||
- Extensions are given descriptor names that start with E_,
|
||||
followed by an underscore-delimited list of the nested messages
|
||||
that contain it (if any) followed by the CamelCased name of the
|
||||
extension field itself. HasExtension, ClearExtension, GetExtension
|
||||
and SetExtension are functions for manipulating extensions.
|
||||
- Oneof field sets are given a single field in their message,
|
||||
with distinguished wrapper types for each possible field value.
|
||||
- Marshal and Unmarshal are functions to encode and decode the wire format.
|
||||
|
||||
The simplest way to describe this is to see an example.
|
||||
Given file test.proto, containing
|
||||
|
||||
package example;
|
||||
|
||||
enum FOO { X = 17; }
|
||||
|
||||
message Test {
|
||||
required string label = 1;
|
||||
optional int32 type = 2 [default=77];
|
||||
repeated int64 reps = 3;
|
||||
optional group OptionalGroup = 4 {
|
||||
required string RequiredField = 5;
|
||||
}
|
||||
oneof union {
|
||||
int32 number = 6;
|
||||
string name = 7;
|
||||
}
|
||||
}
|
||||
|
||||
The resulting file, test.pb.go, is:
|
||||
|
||||
package example
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import math "math"
|
||||
|
||||
type FOO int32
|
||||
const (
|
||||
FOO_X FOO = 17
|
||||
)
|
||||
var FOO_name = map[int32]string{
|
||||
17: "X",
|
||||
}
|
||||
var FOO_value = map[string]int32{
|
||||
"X": 17,
|
||||
}
|
||||
|
||||
func (x FOO) Enum() *FOO {
|
||||
p := new(FOO)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
func (x FOO) String() string {
|
||||
return proto.EnumName(FOO_name, int32(x))
|
||||
}
|
||||
func (x *FOO) UnmarshalJSON(data []byte) error {
|
||||
value, err := proto.UnmarshalJSONEnum(FOO_value, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*x = FOO(value)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Test struct {
|
||||
Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
|
||||
Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
|
||||
Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
|
||||
Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"`
|
||||
// Types that are valid to be assigned to Union:
|
||||
// *Test_Number
|
||||
// *Test_Name
|
||||
Union isTest_Union `protobuf_oneof:"union"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
}
|
||||
func (m *Test) Reset() { *m = Test{} }
|
||||
func (m *Test) String() string { return proto.CompactTextString(m) }
|
||||
func (*Test) ProtoMessage() {}
|
||||
|
||||
type isTest_Union interface {
|
||||
isTest_Union()
|
||||
}
|
||||
|
||||
type Test_Number struct {
|
||||
Number int32 `protobuf:"varint,6,opt,name=number"`
|
||||
}
|
||||
type Test_Name struct {
|
||||
Name string `protobuf:"bytes,7,opt,name=name"`
|
||||
}
|
||||
|
||||
func (*Test_Number) isTest_Union() {}
|
||||
func (*Test_Name) isTest_Union() {}
|
||||
|
||||
func (m *Test) GetUnion() isTest_Union {
|
||||
if m != nil {
|
||||
return m.Union
|
||||
}
|
||||
return nil
|
||||
}
|
||||
const Default_Test_Type int32 = 77
|
||||
|
||||
func (m *Test) GetLabel() string {
|
||||
if m != nil && m.Label != nil {
|
||||
return *m.Label
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Test) GetType() int32 {
|
||||
if m != nil && m.Type != nil {
|
||||
return *m.Type
|
||||
}
|
||||
return Default_Test_Type
|
||||
}
|
||||
|
||||
func (m *Test) GetOptionalgroup() *Test_OptionalGroup {
|
||||
if m != nil {
|
||||
return m.Optionalgroup
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Test_OptionalGroup struct {
|
||||
RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"`
|
||||
}
|
||||
func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} }
|
||||
func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) }
|
||||
|
||||
func (m *Test_OptionalGroup) GetRequiredField() string {
|
||||
if m != nil && m.RequiredField != nil {
|
||||
return *m.RequiredField
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Test) GetNumber() int32 {
|
||||
if x, ok := m.GetUnion().(*Test_Number); ok {
|
||||
return x.Number
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Test) GetName() string {
|
||||
if x, ok := m.GetUnion().(*Test_Name); ok {
|
||||
return x.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterEnum("example.FOO", FOO_name, FOO_value)
|
||||
}
|
||||
|
||||
To create and play with a Test object:
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
pb "./example.pb"
|
||||
)
|
||||
|
||||
func main() {
|
||||
test := &pb.Test{
|
||||
Label: proto.String("hello"),
|
||||
Type: proto.Int32(17),
|
||||
Optionalgroup: &pb.Test_OptionalGroup{
|
||||
RequiredField: proto.String("good bye"),
|
||||
},
|
||||
Union: &pb.Test_Name{"fred"},
|
||||
}
|
||||
data, err := proto.Marshal(test)
|
||||
if err != nil {
|
||||
log.Fatal("marshaling error: ", err)
|
||||
}
|
||||
newTest := &pb.Test{}
|
||||
err = proto.Unmarshal(data, newTest)
|
||||
if err != nil {
|
||||
log.Fatal("unmarshaling error: ", err)
|
||||
}
|
||||
// Now test and newTest contain the same data.
|
||||
if test.GetLabel() != newTest.GetLabel() {
|
||||
log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel())
|
||||
}
|
||||
// Use a type switch to determine which oneof was set.
|
||||
switch u := test.Union.(type) {
|
||||
case *pb.Test_Number: // u.Number contains the number.
|
||||
case *pb.Test_Name: // u.Name contains the string.
|
||||
}
|
||||
// etc.
|
||||
}
|
||||
*/
|
||||
package proto
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Message is implemented by generated protocol buffer messages.
|
||||
type Message interface {
|
||||
Reset()
|
||||
String() string
|
||||
ProtoMessage()
|
||||
}
|
||||
|
||||
// Stats records allocation details about the protocol buffer encoders
|
||||
// and decoders. Useful for tuning the library itself.
|
||||
type Stats struct {
|
||||
Emalloc uint64 // mallocs in encode
|
||||
Dmalloc uint64 // mallocs in decode
|
||||
Encode uint64 // number of encodes
|
||||
Decode uint64 // number of decodes
|
||||
Chit uint64 // number of cache hits
|
||||
Cmiss uint64 // number of cache misses
|
||||
Size uint64 // number of sizes
|
||||
}
|
||||
|
||||
// Set to true to enable stats collection.
|
||||
const collectStats = false
|
||||
|
||||
var stats Stats
|
||||
|
||||
// GetStats returns a copy of the global Stats structure.
|
||||
func GetStats() Stats { return stats }
|
||||
|
||||
// A Buffer is a buffer manager for marshaling and unmarshaling
|
||||
// protocol buffers. It may be reused between invocations to
|
||||
// reduce memory usage. It is not necessary to use a Buffer;
|
||||
// the global functions Marshal and Unmarshal create a
|
||||
// temporary Buffer and are fine for most applications.
|
||||
type Buffer struct {
|
||||
buf []byte // encode/decode byte stream
|
||||
index int // write point
|
||||
|
||||
// pools of basic types to amortize allocation.
|
||||
bools []bool
|
||||
uint32s []uint32
|
||||
uint64s []uint64
|
||||
|
||||
// extra pools, only used with pointer_reflect.go
|
||||
int32s []int32
|
||||
int64s []int64
|
||||
float32s []float32
|
||||
float64s []float64
|
||||
}
|
||||
|
||||
// NewBuffer allocates a new Buffer and initializes its internal data to
|
||||
// the contents of the argument slice.
|
||||
func NewBuffer(e []byte) *Buffer {
|
||||
return &Buffer{buf: e}
|
||||
}
|
||||
|
||||
// Reset resets the Buffer, ready for marshaling a new protocol buffer.
|
||||
func (p *Buffer) Reset() {
|
||||
p.buf = p.buf[0:0] // for reading/writing
|
||||
p.index = 0 // for reading
|
||||
}
|
||||
|
||||
// SetBuf replaces the internal buffer with the slice,
|
||||
// ready for unmarshaling the contents of the slice.
|
||||
func (p *Buffer) SetBuf(s []byte) {
|
||||
p.buf = s
|
||||
p.index = 0
|
||||
}
|
||||
|
||||
// Bytes returns the contents of the Buffer.
|
||||
func (p *Buffer) Bytes() []byte { return p.buf }
|
||||
|
||||
/*
|
||||
* Helper routines for simplifying the creation of optional fields of basic type.
|
||||
*/
|
||||
|
||||
// Bool is a helper routine that allocates a new bool value
|
||||
// to store v and returns a pointer to it.
|
||||
func Bool(v bool) *bool {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int32 is a helper routine that allocates a new int32 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Int32(v int32) *int32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Int is a helper routine that allocates a new int32 value
|
||||
// to store v and returns a pointer to it, but unlike Int32
|
||||
// its argument value is an int.
|
||||
func Int(v int) *int32 {
|
||||
p := new(int32)
|
||||
*p = int32(v)
|
||||
return p
|
||||
}
|
||||
|
||||
// Int64 is a helper routine that allocates a new int64 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Int64(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Float32 is a helper routine that allocates a new float32 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Float32(v float32) *float32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Float64 is a helper routine that allocates a new float64 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Float64(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint32 is a helper routine that allocates a new uint32 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Uint32(v uint32) *uint32 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Uint64 is a helper routine that allocates a new uint64 value
|
||||
// to store v and returns a pointer to it.
|
||||
func Uint64(v uint64) *uint64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
// String is a helper routine that allocates a new string value
|
||||
// to store v and returns a pointer to it.
|
||||
func String(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
// EnumName is a helper function to simplify printing protocol buffer enums
|
||||
// by name. Given an enum map and a value, it returns a useful string.
|
||||
func EnumName(m map[int32]string, v int32) string {
|
||||
s, ok := m[v]
|
||||
if ok {
|
||||
return s
|
||||
}
|
||||
return strconv.Itoa(int(v))
|
||||
}
|
||||
|
||||
// UnmarshalJSONEnum is a helper function to simplify recovering enum int values
|
||||
// from their JSON-encoded representation. Given a map from the enum's symbolic
|
||||
// names to its int values, and a byte buffer containing the JSON-encoded
|
||||
// value, it returns an int32 that can be cast to the enum type by the caller.
|
||||
//
|
||||
// The function can deal with both JSON representations, numeric and symbolic.
|
||||
func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32, error) {
|
||||
if data[0] == '"' {
|
||||
// New style: enums are strings.
|
||||
var repr string
|
||||
if err := json.Unmarshal(data, &repr); err != nil {
|
||||
return -1, err
|
||||
}
|
||||
val, ok := m[repr]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unrecognized enum %s value %q", enumName, repr)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
// Old style: enums are ints.
|
||||
var val int32
|
||||
if err := json.Unmarshal(data, &val); err != nil {
|
||||
return 0, fmt.Errorf("cannot unmarshal %#q into enum %s", data, enumName)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// DebugPrint dumps the encoded data in b in a debugging format with a header
|
||||
// including the string s. Used in testing but made available for general debugging.
|
||||
func (p *Buffer) DebugPrint(s string, b []byte) {
|
||||
var u uint64
|
||||
|
||||
obuf := p.buf
|
||||
index := p.index
|
||||
p.buf = b
|
||||
p.index = 0
|
||||
depth := 0
|
||||
|
||||
fmt.Printf("\n--- %s ---\n", s)
|
||||
|
||||
out:
|
||||
for {
|
||||
for i := 0; i < depth; i++ {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
|
||||
index := p.index
|
||||
if index == len(p.buf) {
|
||||
break
|
||||
}
|
||||
|
||||
op, err := p.DecodeVarint()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: fetching op err %v\n", index, err)
|
||||
break out
|
||||
}
|
||||
tag := op >> 3
|
||||
wire := op & 7
|
||||
|
||||
switch wire {
|
||||
default:
|
||||
fmt.Printf("%3d: t=%3d unknown wire=%d\n",
|
||||
index, tag, wire)
|
||||
break out
|
||||
|
||||
case WireBytes:
|
||||
var r []byte
|
||||
|
||||
r, err = p.DecodeRawBytes(false)
|
||||
if err != nil {
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d bytes [%d]", index, tag, len(r))
|
||||
if len(r) <= 6 {
|
||||
for i := 0; i < len(r); i++ {
|
||||
fmt.Printf(" %.2x", r[i])
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < 3; i++ {
|
||||
fmt.Printf(" %.2x", r[i])
|
||||
}
|
||||
fmt.Printf(" ..")
|
||||
for i := len(r) - 3; i < len(r); i++ {
|
||||
fmt.Printf(" %.2x", r[i])
|
||||
}
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
case WireFixed32:
|
||||
u, err = p.DecodeFixed32()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u)
|
||||
|
||||
case WireFixed64:
|
||||
u, err = p.DecodeFixed64()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u)
|
||||
|
||||
case WireVarint:
|
||||
u, err = p.DecodeVarint()
|
||||
if err != nil {
|
||||
fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err)
|
||||
break out
|
||||
}
|
||||
fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u)
|
||||
|
||||
case WireStartGroup:
|
||||
fmt.Printf("%3d: t=%3d start\n", index, tag)
|
||||
depth++
|
||||
|
||||
case WireEndGroup:
|
||||
depth--
|
||||
fmt.Printf("%3d: t=%3d end\n", index, tag)
|
||||
}
|
||||
}
|
||||
|
||||
if depth != 0 {
|
||||
fmt.Printf("%3d: start-end not balanced %d\n", p.index, depth)
|
||||
}
|
||||
fmt.Printf("\n")
|
||||
|
||||
p.buf = obuf
|
||||
p.index = index
|
||||
}
|
||||
|
||||
// SetDefaults sets unset protocol buffer fields to their default values.
|
||||
// It only modifies fields that are both unset and have defined defaults.
|
||||
// It recursively sets default values in any non-nil sub-messages.
|
||||
func SetDefaults(pb Message) {
|
||||
setDefaults(reflect.ValueOf(pb), true, false)
|
||||
}
|
||||
|
||||
// v is a pointer to a struct.
|
||||
func setDefaults(v reflect.Value, recur, zeros bool) {
|
||||
v = v.Elem()
|
||||
|
||||
defaultMu.RLock()
|
||||
dm, ok := defaults[v.Type()]
|
||||
defaultMu.RUnlock()
|
||||
if !ok {
|
||||
dm = buildDefaultMessage(v.Type())
|
||||
defaultMu.Lock()
|
||||
defaults[v.Type()] = dm
|
||||
defaultMu.Unlock()
|
||||
}
|
||||
|
||||
for _, sf := range dm.scalars {
|
||||
f := v.Field(sf.index)
|
||||
if !f.IsNil() {
|
||||
// field already set
|
||||
continue
|
||||
}
|
||||
dv := sf.value
|
||||
if dv == nil && !zeros {
|
||||
// no explicit default, and don't want to set zeros
|
||||
continue
|
||||
}
|
||||
fptr := f.Addr().Interface() // **T
|
||||
// TODO: Consider batching the allocations we do here.
|
||||
switch sf.kind {
|
||||
case reflect.Bool:
|
||||
b := new(bool)
|
||||
if dv != nil {
|
||||
*b = dv.(bool)
|
||||
}
|
||||
*(fptr.(**bool)) = b
|
||||
case reflect.Float32:
|
||||
f := new(float32)
|
||||
if dv != nil {
|
||||
*f = dv.(float32)
|
||||
}
|
||||
*(fptr.(**float32)) = f
|
||||
case reflect.Float64:
|
||||
f := new(float64)
|
||||
if dv != nil {
|
||||
*f = dv.(float64)
|
||||
}
|
||||
*(fptr.(**float64)) = f
|
||||
case reflect.Int32:
|
||||
// might be an enum
|
||||
if ft := f.Type(); ft != int32PtrType {
|
||||
// enum
|
||||
f.Set(reflect.New(ft.Elem()))
|
||||
if dv != nil {
|
||||
f.Elem().SetInt(int64(dv.(int32)))
|
||||
}
|
||||
} else {
|
||||
// int32 field
|
||||
i := new(int32)
|
||||
if dv != nil {
|
||||
*i = dv.(int32)
|
||||
}
|
||||
*(fptr.(**int32)) = i
|
||||
}
|
||||
case reflect.Int64:
|
||||
i := new(int64)
|
||||
if dv != nil {
|
||||
*i = dv.(int64)
|
||||
}
|
||||
*(fptr.(**int64)) = i
|
||||
case reflect.String:
|
||||
s := new(string)
|
||||
if dv != nil {
|
||||
*s = dv.(string)
|
||||
}
|
||||
*(fptr.(**string)) = s
|
||||
case reflect.Uint8:
|
||||
// exceptional case: []byte
|
||||
var b []byte
|
||||
if dv != nil {
|
||||
db := dv.([]byte)
|
||||
b = make([]byte, len(db))
|
||||
copy(b, db)
|
||||
} else {
|
||||
b = []byte{}
|
||||
}
|
||||
*(fptr.(*[]byte)) = b
|
||||
case reflect.Uint32:
|
||||
u := new(uint32)
|
||||
if dv != nil {
|
||||
*u = dv.(uint32)
|
||||
}
|
||||
*(fptr.(**uint32)) = u
|
||||
case reflect.Uint64:
|
||||
u := new(uint64)
|
||||
if dv != nil {
|
||||
*u = dv.(uint64)
|
||||
}
|
||||
*(fptr.(**uint64)) = u
|
||||
default:
|
||||
log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ni := range dm.nested {
|
||||
f := v.Field(ni)
|
||||
// f is *T or []*T or map[T]*T
|
||||
switch f.Kind() {
|
||||
case reflect.Ptr:
|
||||
if f.IsNil() {
|
||||
continue
|
||||
}
|
||||
setDefaults(f, recur, zeros)
|
||||
|
||||
case reflect.Slice:
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
e := f.Index(i)
|
||||
if e.IsNil() {
|
||||
continue
|
||||
}
|
||||
setDefaults(e, recur, zeros)
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
for _, k := range f.MapKeys() {
|
||||
e := f.MapIndex(k)
|
||||
if e.IsNil() {
|
||||
continue
|
||||
}
|
||||
setDefaults(e, recur, zeros)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// defaults maps a protocol buffer struct type to a slice of the fields,
|
||||
// with its scalar fields set to their proto-declared non-zero default values.
|
||||
defaultMu sync.RWMutex
|
||||
defaults = make(map[reflect.Type]defaultMessage)
|
||||
|
||||
int32PtrType = reflect.TypeOf((*int32)(nil))
|
||||
)
|
||||
|
||||
// defaultMessage represents information about the default values of a message.
|
||||
type defaultMessage struct {
|
||||
scalars []scalarField
|
||||
nested []int // struct field index of nested messages
|
||||
}
|
||||
|
||||
type scalarField struct {
|
||||
index int // struct field index
|
||||
kind reflect.Kind // element type (the T in *T or []T)
|
||||
value interface{} // the proto-declared default value, or nil
|
||||
}
|
||||
|
||||
// t is a struct type.
|
||||
func buildDefaultMessage(t reflect.Type) (dm defaultMessage) {
|
||||
sprop := GetProperties(t)
|
||||
for _, prop := range sprop.Prop {
|
||||
fi, ok := sprop.decoderTags.get(prop.Tag)
|
||||
if !ok {
|
||||
// XXX_unrecognized
|
||||
continue
|
||||
}
|
||||
ft := t.Field(fi).Type
|
||||
|
||||
sf, nested, err := fieldDefault(ft, prop)
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Print(err)
|
||||
case nested:
|
||||
dm.nested = append(dm.nested, fi)
|
||||
case sf != nil:
|
||||
sf.index = fi
|
||||
dm.scalars = append(dm.scalars, *sf)
|
||||
}
|
||||
}
|
||||
|
||||
return dm
|
||||
}
|
||||
|
||||
// fieldDefault returns the scalarField for field type ft.
|
||||
// sf will be nil if the field can not have a default.
|
||||
// nestedMessage will be true if this is a nested message.
|
||||
// Note that sf.index is not set on return.
|
||||
func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMessage bool, err error) {
|
||||
var canHaveDefault bool
|
||||
switch ft.Kind() {
|
||||
case reflect.Ptr:
|
||||
if ft.Elem().Kind() == reflect.Struct {
|
||||
nestedMessage = true
|
||||
} else {
|
||||
canHaveDefault = true // proto2 scalar field
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Ptr:
|
||||
nestedMessage = true // repeated message
|
||||
case reflect.Uint8:
|
||||
canHaveDefault = true // bytes field
|
||||
}
|
||||
|
||||
case reflect.Map:
|
||||
if ft.Elem().Kind() == reflect.Ptr {
|
||||
nestedMessage = true // map with message values
|
||||
}
|
||||
}
|
||||
|
||||
if !canHaveDefault {
|
||||
if nestedMessage {
|
||||
return nil, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// We now know that ft is a pointer or slice.
|
||||
sf = &scalarField{kind: ft.Elem().Kind()}
|
||||
|
||||
// scalar fields without defaults
|
||||
if !prop.HasDefault {
|
||||
return sf, false, nil
|
||||
}
|
||||
|
||||
// a scalar field: either *T or []byte
|
||||
switch ft.Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
x, err := strconv.ParseBool(prop.Default)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default bool %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Float32:
|
||||
x, err := strconv.ParseFloat(prop.Default, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default float32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = float32(x)
|
||||
case reflect.Float64:
|
||||
x, err := strconv.ParseFloat(prop.Default, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default float64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.Int32:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default int32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = int32(x)
|
||||
case reflect.Int64:
|
||||
x, err := strconv.ParseInt(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default int64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
case reflect.String:
|
||||
sf.value = prop.Default
|
||||
case reflect.Uint8:
|
||||
// []byte (not *uint8)
|
||||
sf.value = []byte(prop.Default)
|
||||
case reflect.Uint32:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 32)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default uint32 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = uint32(x)
|
||||
case reflect.Uint64:
|
||||
x, err := strconv.ParseUint(prop.Default, 10, 64)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("proto: bad default uint64 %q: %v", prop.Default, err)
|
||||
}
|
||||
sf.value = x
|
||||
default:
|
||||
return nil, false, fmt.Errorf("proto: unhandled def kind %v", ft.Elem().Kind())
|
||||
}
|
||||
|
||||
return sf, false, nil
|
||||
}
|
||||
|
||||
// Map fields may have key types of non-float scalars, strings and enums.
|
||||
// The easiest way to sort them in some deterministic order is to use fmt.
|
||||
// If this turns out to be inefficient we can always consider other options,
|
||||
// such as doing a Schwartzian transform.
|
||||
|
||||
func mapKeys(vs []reflect.Value) sort.Interface {
|
||||
s := mapKeySorter{
|
||||
vs: vs,
|
||||
// default Less function: textual comparison
|
||||
less: func(a, b reflect.Value) bool {
|
||||
return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface())
|
||||
},
|
||||
}
|
||||
|
||||
// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps;
|
||||
// numeric keys are sorted numerically.
|
||||
if len(vs) == 0 {
|
||||
return s
|
||||
}
|
||||
switch vs[0].Kind() {
|
||||
case reflect.Int32, reflect.Int64:
|
||||
s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
type mapKeySorter struct {
|
||||
vs []reflect.Value
|
||||
less func(a, b reflect.Value) bool
|
||||
}
|
||||
|
||||
func (s mapKeySorter) Len() int { return len(s.vs) }
|
||||
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
|
||||
func (s mapKeySorter) Less(i, j int) bool {
|
||||
return s.less(s.vs[i], s.vs[j])
|
||||
}
|
||||
|
||||
// isProto3Zero reports whether v is a zero proto3 value.
|
||||
func isProto3Zero(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
return !v.Bool()
|
||||
case reflect.Int32, reflect.Int64:
|
||||
return v.Int() == 0
|
||||
case reflect.Uint32, reflect.Uint64:
|
||||
return v.Uint() == 0
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return v.Float() == 0
|
||||
case reflect.String:
|
||||
return v.String() == ""
|
||||
}
|
||||
return false
|
||||
}
|
@ -0,0 +1,287 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2010 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
/*
|
||||
* Support for message sets.
|
||||
*/
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID.
|
||||
// A message type ID is required for storing a protocol buffer in a message set.
|
||||
var ErrNoMessageTypeId = errors.New("proto does not have a message type ID")
|
||||
|
||||
// The first two types (_MessageSet_Item and MessageSet)
|
||||
// model what the protocol compiler produces for the following protocol message:
|
||||
// message MessageSet {
|
||||
// repeated group Item = 1 {
|
||||
// required int32 type_id = 2;
|
||||
// required string message = 3;
|
||||
// };
|
||||
// }
|
||||
// That is the MessageSet wire format. We can't use a proto to generate these
|
||||
// because that would introduce a circular dependency between it and this package.
|
||||
//
|
||||
// When a proto1 proto has a field that looks like:
|
||||
// optional message<MessageSet> info = 3;
|
||||
// the protocol compiler produces a field in the generated struct that looks like:
|
||||
// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"`
|
||||
// The package is automatically inserted so there is no need for that proto file to
|
||||
// import this package.
|
||||
|
||||
type _MessageSet_Item struct {
|
||||
TypeId *int32 `protobuf:"varint,2,req,name=type_id"`
|
||||
Message []byte `protobuf:"bytes,3,req,name=message"`
|
||||
}
|
||||
|
||||
type MessageSet struct {
|
||||
Item []*_MessageSet_Item `protobuf:"group,1,rep"`
|
||||
XXX_unrecognized []byte
|
||||
// TODO: caching?
|
||||
}
|
||||
|
||||
// Make sure MessageSet is a Message.
|
||||
var _ Message = (*MessageSet)(nil)
|
||||
|
||||
// messageTypeIder is an interface satisfied by a protocol buffer type
|
||||
// that may be stored in a MessageSet.
|
||||
type messageTypeIder interface {
|
||||
MessageTypeId() int32
|
||||
}
|
||||
|
||||
func (ms *MessageSet) find(pb Message) *_MessageSet_Item {
|
||||
mti, ok := pb.(messageTypeIder)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
id := mti.MessageTypeId()
|
||||
for _, item := range ms.Item {
|
||||
if *item.TypeId == id {
|
||||
return item
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Has(pb Message) bool {
|
||||
if ms.find(pb) != nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Unmarshal(pb Message) error {
|
||||
if item := ms.find(pb); item != nil {
|
||||
return Unmarshal(item.Message, pb)
|
||||
}
|
||||
if _, ok := pb.(messageTypeIder); !ok {
|
||||
return ErrNoMessageTypeId
|
||||
}
|
||||
return nil // TODO: return error instead?
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Marshal(pb Message) error {
|
||||
msg, err := Marshal(pb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if item := ms.find(pb); item != nil {
|
||||
// reuse existing item
|
||||
item.Message = msg
|
||||
return nil
|
||||
}
|
||||
|
||||
mti, ok := pb.(messageTypeIder)
|
||||
if !ok {
|
||||
return ErrNoMessageTypeId
|
||||
}
|
||||
|
||||
mtid := mti.MessageTypeId()
|
||||
ms.Item = append(ms.Item, &_MessageSet_Item{
|
||||
TypeId: &mtid,
|
||||
Message: msg,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ms *MessageSet) Reset() { *ms = MessageSet{} }
|
||||
func (ms *MessageSet) String() string { return CompactTextString(ms) }
|
||||
func (*MessageSet) ProtoMessage() {}
|
||||
|
||||
// Support for the message_set_wire_format message option.
|
||||
|
||||
func skipVarint(buf []byte) []byte {
|
||||
i := 0
|
||||
for ; buf[i]&0x80 != 0; i++ {
|
||||
}
|
||||
return buf[i+1:]
|
||||
}
|
||||
|
||||
// MarshalMessageSet encodes the extension map represented by m in the message set wire format.
|
||||
// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func MarshalMessageSet(m map[int32]Extension) ([]byte, error) {
|
||||
if err := encodeExtensionMap(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sort extension IDs to provide a deterministic encoding.
|
||||
// See also enc_map in encode.go.
|
||||
ids := make([]int, 0, len(m))
|
||||
for id := range m {
|
||||
ids = append(ids, int(id))
|
||||
}
|
||||
sort.Ints(ids)
|
||||
|
||||
ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))}
|
||||
for _, id := range ids {
|
||||
e := m[int32(id)]
|
||||
// Remove the wire type and field number varint, as well as the length varint.
|
||||
msg := skipVarint(skipVarint(e.enc))
|
||||
|
||||
ms.Item = append(ms.Item, &_MessageSet_Item{
|
||||
TypeId: Int32(int32(id)),
|
||||
Message: msg,
|
||||
})
|
||||
}
|
||||
return Marshal(ms)
|
||||
}
|
||||
|
||||
// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format.
|
||||
// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error {
|
||||
ms := new(MessageSet)
|
||||
if err := Unmarshal(buf, ms); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range ms.Item {
|
||||
id := *item.TypeId
|
||||
msg := item.Message
|
||||
|
||||
// Restore wire type and field number varint, plus length varint.
|
||||
// Be careful to preserve duplicate items.
|
||||
b := EncodeVarint(uint64(id)<<3 | WireBytes)
|
||||
if ext, ok := m[id]; ok {
|
||||
// Existing data; rip off the tag and length varint
|
||||
// so we join the new data correctly.
|
||||
// We can assume that ext.enc is set because we are unmarshaling.
|
||||
o := ext.enc[len(b):] // skip wire type and field number
|
||||
_, n := DecodeVarint(o) // calculate length of length varint
|
||||
o = o[n:] // skip length varint
|
||||
msg = append(o, msg...) // join old data and new data
|
||||
}
|
||||
b = append(b, EncodeVarint(uint64(len(msg)))...)
|
||||
b = append(b, msg...)
|
||||
|
||||
m[id] = Extension{enc: b}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalMessageSetJSON encodes the extension map represented by m in JSON format.
|
||||
// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
b.WriteByte('{')
|
||||
|
||||
// Process the map in key order for deterministic output.
|
||||
ids := make([]int32, 0, len(m))
|
||||
for id := range m {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
sort.Sort(int32Slice(ids)) // int32Slice defined in text.go
|
||||
|
||||
for i, id := range ids {
|
||||
ext := m[id]
|
||||
if i > 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
|
||||
msd, ok := messageSetMap[id]
|
||||
if !ok {
|
||||
// Unknown type; we can't render it, so skip it.
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&b, `"[%s]":`, msd.name)
|
||||
|
||||
x := ext.value
|
||||
if x == nil {
|
||||
x = reflect.New(msd.t.Elem()).Interface()
|
||||
if err := Unmarshal(ext.enc, x.(Message)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
d, err := json.Marshal(x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
b.Write(d)
|
||||
}
|
||||
b.WriteByte('}')
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format.
|
||||
// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option.
|
||||
func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error {
|
||||
// Common-case fast path.
|
||||
if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This is fairly tricky, and it's not clear that it is needed.
|
||||
return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented")
|
||||
}
|
||||
|
||||
// A global registry of types that can be used in a MessageSet.
|
||||
|
||||
var messageSetMap = make(map[int32]messageSetDesc)
|
||||
|
||||
type messageSetDesc struct {
|
||||
t reflect.Type // pointer to struct
|
||||
name string
|
||||
}
|
||||
|
||||
// RegisterMessageSetType is called from the generated code.
|
||||
func RegisterMessageSetType(m Message, fieldNum int32, name string) {
|
||||
messageSetMap[fieldNum] = messageSetDesc{
|
||||
t: reflect.TypeOf(m),
|
||||
name: name,
|
||||
}
|
||||
}
|
@ -0,0 +1,66 @@
|
||||
// Go support for Protocol Buffers - Google's data interchange format
|
||||
//
|
||||
// Copyright 2014 The Go Authors. All rights reserved.
|
||||
// https://github.com/golang/protobuf
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without
|
||||
// modification, are permitted provided that the following conditions are
|
||||
// met:
|
||||
//
|
||||
// * Redistributions of source code must retain the above copyright
|
||||
// notice, this list of conditions and the following disclaimer.
|
||||
// * Redistributions in binary form must reproduce the above
|
||||
// copyright notice, this list of conditions and the following disclaimer
|
||||
// in the documentation and/or other materials provided with the
|
||||
// distribution.
|
||||
// * Neither the name of Google Inc. nor the names of its
|
||||
// contributors may be used to endorse or promote products derived from
|
||||
// this software without specific prior written permission.
|
||||
//
|
||||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnmarshalMessageSetWithDuplicate(t *testing.T) {
|
||||
// Check that a repeated message set entry will be concatenated.
|
||||
in := &MessageSet{
|
||||
Item: []*_MessageSet_Item{
|
||||
{TypeId: Int32(12345), Message: []byte("hoo")},
|
||||
{TypeId: Int32(12345), Message: []byte("hah")},
|
||||
},
|
||||
}
|
||||
b, err := Marshal(in)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal: %v", err)
|
||||
}
|
||||
t.Logf("Marshaled bytes: %q", b)
|
||||
|
||||
m := make(map[int32]Extension)
|
||||
if err := UnmarshalMessageSet(b, m); err != nil {
|
||||
t.Fatalf("UnmarshalMessageSet: %v", err)
|
||||
}
|
||||
ext, ok := m[12345]
|
||||
if !ok {
|
||||
t.Fatalf("Didn't retrieve extension 12345; map is %v", m)
|
||||
}
|
||||
// Skip wire type/field number and length varints.
|
||||
got := skipVarint(skipVarint(ext.enc))
|
||||
if want := []byte("hoohah"); !bytes.Equal(got, want) {
|
||||
t.Errorf("Combined extension is %q, want %q", got, want)
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue