diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json new file mode 100644 index 0000000..10be2a1 --- /dev/null +++ b/Godeps/Godeps.json @@ -0,0 +1,57 @@ +{ + "ImportPath": "github.com/codeskyblue/gosuv", + "GoVersion": "go1.4.2", + "Deps": [ + { + "ImportPath": "github.com/bradfitz/http2", + "Rev": "f8202bc903bda493ebba4aa54922d78430c2c42f" + }, + { + "ImportPath": "github.com/codegangsta/cli", + "Comment": "1.2.0-139-g142e6cd", + "Rev": "142e6cd241a4dfbf7f07a018f1f8225180018da4" + }, + { + "ImportPath": "github.com/codeskyblue/kproc", + "Rev": "fcb55eb35ab6b7290f395ad596e80e8355d52d69" + }, + { + "ImportPath": "github.com/franela/goreq", + "Rev": "72c51a544272e007ab3da4f7d9ac959b7af7af03" + }, + { + "ImportPath": "github.com/golang/protobuf/proto", + "Rev": "1dceb1a2654bdc74ca97ad91f71f500eecc96269" + }, + { + "ImportPath": "github.com/lunny/log", + "Rev": "9ae5e9effeefca29cf00d932a7b54f1b2d70fd58" + }, + { + "ImportPath": "github.com/lunny/tango", + "Comment": "v0.4.6-25-g652e7bd", + "Rev": "652e7bd578c564f7ae9307ed102d188e564baf54" + }, + { + "ImportPath": "github.com/qiniu/log", + "Comment": "v1.0.00-2-ge002bc2", + "Rev": "e002bc2020b19bfa61ed378cc5407383dbd2f346" + }, + { + "ImportPath": "golang.org/x/net/context", + "Rev": "db8e4de5b2d6653f66aea53094624468caad15d2" + }, + { + "ImportPath": "golang.org/x/net/internal/timeseries", + "Rev": "db8e4de5b2d6653f66aea53094624468caad15d2" + }, + { + "ImportPath": "golang.org/x/net/trace", + "Rev": "db8e4de5b2d6653f66aea53094624468caad15d2" + }, + { + "ImportPath": "google.golang.org/grpc", + "Rev": "6b6127d8c67b67f9f8509b84279eac053964e868" + } + ] +} diff --git a/Godeps/Readme b/Godeps/Readme new file mode 100644 index 0000000..4cdaa53 --- /dev/null +++ b/Godeps/Readme @@ -0,0 +1,5 @@ +This directory tree is generated automatically by godep. + +Please do not edit. + +See https://github.com/tools/godep for more information. diff --git a/Godeps/_workspace/.gitignore b/Godeps/_workspace/.gitignore new file mode 100644 index 0000000..f037d68 --- /dev/null +++ b/Godeps/_workspace/.gitignore @@ -0,0 +1,2 @@ +/pkg +/bin diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/.gitignore b/Godeps/_workspace/src/github.com/bradfitz/http2/.gitignore new file mode 100644 index 0000000..190f122 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/.gitignore @@ -0,0 +1,2 @@ +*~ +h2i/h2i diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/AUTHORS b/Godeps/_workspace/src/github.com/bradfitz/http2/AUTHORS new file mode 100644 index 0000000..973453e --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/AUTHORS @@ -0,0 +1,20 @@ +# This file is like Go's AUTHORS file: it lists Copyright holders. +# The list of humans who have contributd is in the CONTRIBUTORS file. +# +# To contribute to this project, because it will eventually be folded +# back in to Go itself, you need to submit a CLA: +# +# http://golang.org/doc/contribute.html#copyright +# +# Then you get added to CONTRIBUTORS and you or your company get added +# to the AUTHORS file. + +Blake Mizerany github=bmizerany +Daniel Morsing github=DanielMorsing +Gabriel Aszalos github=gbbr +Google, Inc. +Keith Rarick github=kr +Matthew Keenan github=mattkeenan +Matt Layher github=mdlayher +Perry Abbott github=pabbott0 +Tatsuhiro Tsujikawa github=tatsuhiro-t diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/CONTRIBUTORS b/Godeps/_workspace/src/github.com/bradfitz/http2/CONTRIBUTORS new file mode 100644 index 0000000..f0128c1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/CONTRIBUTORS @@ -0,0 +1,20 @@ +# This file is like Go's CONTRIBUTORS file: it lists humans. +# The list of copyright holders (which may be companies) are in the AUTHORS file. +# +# To contribute to this project, because it will eventually be folded +# back in to Go itself, you need to submit a CLA: +# +# http://golang.org/doc/contribute.html#copyright +# +# Then you get added to CONTRIBUTORS and you or your company get added +# to the AUTHORS file. + +Blake Mizerany github=bmizerany +Brad Fitzpatrick github=bradfitz +Daniel Morsing github=DanielMorsing +Gabriel Aszalos github=gbbr +Keith Rarick github=kr +Matthew Keenan github=mattkeenan +Matt Layher github=mdlayher +Perry Abbott github=pabbott0 +Tatsuhiro Tsujikawa github=tatsuhiro-t diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/Dockerfile b/Godeps/_workspace/src/github.com/bradfitz/http2/Dockerfile new file mode 100644 index 0000000..b4e14d5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/Dockerfile @@ -0,0 +1,44 @@ +# +# This Dockerfile builds a recent curl with HTTP/2 client support, using +# a recent nghttp2 build. +# +# See the Makefile for how to tag it. If Docker and that image is found, the +# Go tests use this curl binary for integration tests. +# + +FROM ubuntu:trusty + +RUN apt-get update && \ + apt-get upgrade -y && \ + apt-get install -y git-core build-essential wget + +RUN apt-get install -y --no-install-recommends \ + autotools-dev libtool pkg-config zlib1g-dev \ + libcunit1-dev libssl-dev libxml2-dev libevent-dev \ + automake autoconf + +# Note: setting NGHTTP2_VER before the git clone, so an old git clone isn't cached: +ENV NGHTTP2_VER af24f8394e43f4 +RUN cd /root && git clone https://github.com/tatsuhiro-t/nghttp2.git + +WORKDIR /root/nghttp2 +RUN git reset --hard $NGHTTP2_VER +RUN autoreconf -i +RUN automake +RUN autoconf +RUN ./configure +RUN make +RUN make install + +WORKDIR /root +RUN wget http://curl.haxx.se/download/curl-7.40.0.tar.gz +RUN tar -zxvf curl-7.40.0.tar.gz +WORKDIR /root/curl-7.40.0 +RUN ./configure --with-ssl --with-nghttp2=/usr/local +RUN make +RUN make install +RUN ldconfig + +CMD ["-h"] +ENTRYPOINT ["/usr/local/bin/curl"] + diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/HACKING b/Godeps/_workspace/src/github.com/bradfitz/http2/HACKING new file mode 100644 index 0000000..69aafe4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/HACKING @@ -0,0 +1,5 @@ +We only accept contributions from users who have gone through Go's +contribution process (signed a CLA). + +Please acknowledge whether you have (and use the same email) if +sending a pull request. diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/LICENSE b/Godeps/_workspace/src/github.com/bradfitz/http2/LICENSE new file mode 100644 index 0000000..2dc6853 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/LICENSE @@ -0,0 +1,7 @@ +Copyright 2014 Google & the Go AUTHORS + +Go AUTHORS are: +See https://code.google.com/p/go/source/browse/AUTHORS + +Licensed under the terms of Go itself: +https://code.google.com/p/go/source/browse/LICENSE diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/Makefile b/Godeps/_workspace/src/github.com/bradfitz/http2/Makefile new file mode 100644 index 0000000..55fd826 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/Makefile @@ -0,0 +1,3 @@ +curlimage: + docker build -t gohttp2/curl . + diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/README b/Godeps/_workspace/src/github.com/bradfitz/http2/README new file mode 100644 index 0000000..16cb516 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/README @@ -0,0 +1,17 @@ +This is a work-in-progress HTTP/2 implementation for Go. + +It will eventually live in the Go standard library and won't require +any changes to your code to use. It will just be automatic. + +Status: + +* The server support is pretty good. A few things are missing + but are being worked on. +* The client work has just started but shares a lot of code + is coming along much quicker. + +Docs are at https://godoc.org/github.com/bradfitz/http2 + +Demo test server at https://http2.golang.org/ + +Help & bug reports welcome. diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/buffer.go b/Godeps/_workspace/src/github.com/bradfitz/http2/buffer.go new file mode 100644 index 0000000..c43954c --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/buffer.go @@ -0,0 +1,76 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "errors" +) + +// buffer is an io.ReadWriteCloser backed by a fixed size buffer. +// It never allocates, but moves old data as new data is written. +type buffer struct { + buf []byte + r, w int + closed bool + err error // err to return to reader +} + +var ( + errReadEmpty = errors.New("read from empty buffer") + errWriteClosed = errors.New("write on closed buffer") + errWriteFull = errors.New("write on full buffer") +) + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *buffer) Read(p []byte) (n int, err error) { + n = copy(p, b.buf[b.r:b.w]) + b.r += n + if b.closed && b.r == b.w { + err = b.err + } else if b.r == b.w && n == 0 { + err = errReadEmpty + } + return n, err +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *buffer) Len() int { + return b.w - b.r +} + +// Write copies bytes from p into the buffer. +// It is an error to write more data than the buffer can hold. +func (b *buffer) Write(p []byte) (n int, err error) { + if b.closed { + return 0, errWriteClosed + } + + // Slide existing data to beginning. + if b.r > 0 && len(p) > len(b.buf)-b.w { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + // Write new data. + n = copy(b.buf[b.w:], p) + b.w += n + if n < len(p) { + err = errWriteFull + } + return n, err +} + +// Close marks the buffer as closed. Future calls to Write will +// return an error. Future calls to Read, once the buffer is +// empty, will return err. +func (b *buffer) Close(err error) { + if !b.closed { + b.closed = true + b.err = err + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/buffer_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/buffer_test.go new file mode 100644 index 0000000..ea76905 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/buffer_test.go @@ -0,0 +1,154 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "io" + "reflect" + "testing" +) + +var bufferReadTests = []struct { + buf buffer + read, wn int + werr error + wp []byte + wbuf buffer +}{ + { + buffer{[]byte{'a', 0}, 0, 1, false, nil}, + 5, 1, nil, []byte{'a'}, + buffer{[]byte{'a', 0}, 1, 1, false, nil}, + }, + { + buffer{[]byte{'a', 0}, 0, 1, true, io.EOF}, + 5, 1, io.EOF, []byte{'a'}, + buffer{[]byte{'a', 0}, 1, 1, true, io.EOF}, + }, + { + buffer{[]byte{0, 'a'}, 1, 2, false, nil}, + 5, 1, nil, []byte{'a'}, + buffer{[]byte{0, 'a'}, 2, 2, false, nil}, + }, + { + buffer{[]byte{0, 'a'}, 1, 2, true, io.EOF}, + 5, 1, io.EOF, []byte{'a'}, + buffer{[]byte{0, 'a'}, 2, 2, true, io.EOF}, + }, + { + buffer{[]byte{}, 0, 0, false, nil}, + 5, 0, errReadEmpty, []byte{}, + buffer{[]byte{}, 0, 0, false, nil}, + }, + { + buffer{[]byte{}, 0, 0, true, io.EOF}, + 5, 0, io.EOF, []byte{}, + buffer{[]byte{}, 0, 0, true, io.EOF}, + }, +} + +func TestBufferRead(t *testing.T) { + for i, tt := range bufferReadTests { + read := make([]byte, tt.read) + n, err := tt.buf.Read(read) + if n != tt.wn { + t.Errorf("#%d: wn = %d want %d", i, n, tt.wn) + continue + } + if err != tt.werr { + t.Errorf("#%d: werr = %v want %v", i, err, tt.werr) + continue + } + read = read[:n] + if !reflect.DeepEqual(read, tt.wp) { + t.Errorf("#%d: read = %+v want %+v", i, read, tt.wp) + } + if !reflect.DeepEqual(tt.buf, tt.wbuf) { + t.Errorf("#%d: buf = %+v want %+v", i, tt.buf, tt.wbuf) + } + } +} + +var bufferWriteTests = []struct { + buf buffer + write, wn int + werr error + wbuf buffer +}{ + { + buf: buffer{ + buf: []byte{}, + }, + wbuf: buffer{ + buf: []byte{}, + }, + }, + { + buf: buffer{ + buf: []byte{1, 'a'}, + }, + write: 1, + wn: 1, + wbuf: buffer{ + buf: []byte{0, 'a'}, + w: 1, + }, + }, + { + buf: buffer{ + buf: []byte{'a', 1}, + r: 1, + w: 1, + }, + write: 2, + wn: 2, + wbuf: buffer{ + buf: []byte{0, 0}, + w: 2, + }, + }, + { + buf: buffer{ + buf: []byte{}, + r: 1, + closed: true, + }, + write: 5, + werr: errWriteClosed, + wbuf: buffer{ + buf: []byte{}, + r: 1, + closed: true, + }, + }, + { + buf: buffer{ + buf: []byte{}, + }, + write: 5, + werr: errWriteFull, + wbuf: buffer{ + buf: []byte{}, + }, + }, +} + +func TestBufferWrite(t *testing.T) { + for i, tt := range bufferWriteTests { + n, err := tt.buf.Write(make([]byte, tt.write)) + if n != tt.wn { + t.Errorf("#%d: wrote %d bytes; want %d", i, n, tt.wn) + continue + } + if err != tt.werr { + t.Errorf("#%d: error = %v; want %v", i, err, tt.werr) + continue + } + if !reflect.DeepEqual(tt.buf, tt.wbuf) { + t.Errorf("#%d: buf = %+v; want %+v", i, tt.buf, tt.wbuf) + } + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/errors.go b/Godeps/_workspace/src/github.com/bradfitz/http2/errors.go new file mode 100644 index 0000000..c885328 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/errors.go @@ -0,0 +1,78 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import "fmt" + +// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec. +type ErrCode uint32 + +const ( + ErrCodeNo ErrCode = 0x0 + ErrCodeProtocol ErrCode = 0x1 + ErrCodeInternal ErrCode = 0x2 + ErrCodeFlowControl ErrCode = 0x3 + ErrCodeSettingsTimeout ErrCode = 0x4 + ErrCodeStreamClosed ErrCode = 0x5 + ErrCodeFrameSize ErrCode = 0x6 + ErrCodeRefusedStream ErrCode = 0x7 + ErrCodeCancel ErrCode = 0x8 + ErrCodeCompression ErrCode = 0x9 + ErrCodeConnect ErrCode = 0xa + ErrCodeEnhanceYourCalm ErrCode = 0xb + ErrCodeInadequateSecurity ErrCode = 0xc + ErrCodeHTTP11Required ErrCode = 0xd +) + +var errCodeName = map[ErrCode]string{ + ErrCodeNo: "NO_ERROR", + ErrCodeProtocol: "PROTOCOL_ERROR", + ErrCodeInternal: "INTERNAL_ERROR", + ErrCodeFlowControl: "FLOW_CONTROL_ERROR", + ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT", + ErrCodeStreamClosed: "STREAM_CLOSED", + ErrCodeFrameSize: "FRAME_SIZE_ERROR", + ErrCodeRefusedStream: "REFUSED_STREAM", + ErrCodeCancel: "CANCEL", + ErrCodeCompression: "COMPRESSION_ERROR", + ErrCodeConnect: "CONNECT_ERROR", + ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM", + ErrCodeInadequateSecurity: "INADEQUATE_SECURITY", + ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED", +} + +func (e ErrCode) String() string { + if s, ok := errCodeName[e]; ok { + return s + } + return fmt.Sprintf("unknown error code 0x%x", uint32(e)) +} + +// ConnectionError is an error that results in the termination of the +// entire connection. +type ConnectionError ErrCode + +func (e ConnectionError) Error() string { return fmt.Sprintf("connection error: %s", ErrCode(e)) } + +// StreamError is an error that only affects one stream within an +// HTTP/2 connection. +type StreamError struct { + StreamID uint32 + Code ErrCode +} + +func (e StreamError) Error() string { + return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) +} + +// 6.9.1 The Flow Control Window +// "If a sender receives a WINDOW_UPDATE that causes a flow control +// window to exceed this maximum it MUST terminate either the stream +// or the connection, as appropriate. For streams, [...]; for the +// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code." +type goAwayFlowError struct{} + +func (goAwayFlowError) Error() string { return "connection exceeded flow control window size" } diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/errors_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/errors_test.go new file mode 100644 index 0000000..86d52ee --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/errors_test.go @@ -0,0 +1,27 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import "testing" + +func TestErrCodeString(t *testing.T) { + tests := []struct { + err ErrCode + want string + }{ + {ErrCodeProtocol, "PROTOCOL_ERROR"}, + {0xd, "HTTP_1_1_REQUIRED"}, + {0xf, "unknown error code 0xf"}, + } + for i, tt := range tests { + got := tt.err.String() + if got != tt.want { + t.Errorf("%d. Error = %q; want %q", i, got, tt.want) + } + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/flow.go b/Godeps/_workspace/src/github.com/bradfitz/http2/flow.go new file mode 100644 index 0000000..540fc42 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/flow.go @@ -0,0 +1,51 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +// Flow control + +package http2 + +// flow is the flow control window's size. +type flow struct { + // n is the number of DATA bytes we're allowed to send. + // A flow is kept both on a conn and a per-stream. + n int32 + + // conn points to the shared connection-level flow that is + // shared by all streams on that conn. It is nil for the flow + // that's on the conn directly. + conn *flow +} + +func (f *flow) setConnFlow(cf *flow) { f.conn = cf } + +func (f *flow) available() int32 { + n := f.n + if f.conn != nil && f.conn.n < n { + n = f.conn.n + } + return n +} + +func (f *flow) take(n int32) { + if n > f.available() { + panic("internal error: took too much") + } + f.n -= n + if f.conn != nil { + f.conn.n -= n + } +} + +// add adds n bytes (positive or negative) to the flow control window. +// It returns false if the sum would exceed 2^31-1. +func (f *flow) add(n int32) bool { + remain := (1<<31 - 1) - f.n + if n > remain { + return false + } + f.n += n + return true +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/flow_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/flow_test.go new file mode 100644 index 0000000..dbb656f --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/flow_test.go @@ -0,0 +1,54 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import "testing" + +func TestFlow(t *testing.T) { + var st flow + var conn flow + st.add(3) + conn.add(2) + + if got, want := st.available(), int32(3); got != want { + t.Errorf("available = %d; want %d", got, want) + } + st.setConnFlow(&conn) + if got, want := st.available(), int32(2); got != want { + t.Errorf("after parent setup, available = %d; want %d", got, want) + } + + st.take(2) + if got, want := conn.available(), int32(0); got != want { + t.Errorf("after taking 2, conn = %d; want %d", got, want) + } + if got, want := st.available(), int32(0); got != want { + t.Errorf("after taking 2, stream = %d; want %d", got, want) + } +} + +func TestFlowAdd(t *testing.T) { + var f flow + if !f.add(1) { + t.Fatal("failed to add 1") + } + if !f.add(-1) { + t.Fatal("failed to add -1") + } + if got, want := f.available(), int32(0); got != want { + t.Fatalf("size = %d; want %d", got, want) + } + if !f.add(1<<31 - 1) { + t.Fatal("failed to add 2^31-1") + } + if got, want := f.available(), int32(1<<31-1); got != want { + t.Fatalf("size = %d; want %d", got, want) + } + if f.add(1) { + t.Fatal("adding 1 to max shouldn't be allowed") + } + +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/frame.go b/Godeps/_workspace/src/github.com/bradfitz/http2/frame.go new file mode 100644 index 0000000..e8b872a --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/frame.go @@ -0,0 +1,1113 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "sync" +) + +const frameHeaderLen = 9 + +var padZeros = make([]byte, 255) // zeros for padding + +// A FrameType is a registered frame type as defined in +// http://http2.github.io/http2-spec/#rfc.section.11.2 +type FrameType uint8 + +const ( + FrameData FrameType = 0x0 + FrameHeaders FrameType = 0x1 + FramePriority FrameType = 0x2 + FrameRSTStream FrameType = 0x3 + FrameSettings FrameType = 0x4 + FramePushPromise FrameType = 0x5 + FramePing FrameType = 0x6 + FrameGoAway FrameType = 0x7 + FrameWindowUpdate FrameType = 0x8 + FrameContinuation FrameType = 0x9 +) + +var frameName = map[FrameType]string{ + FrameData: "DATA", + FrameHeaders: "HEADERS", + FramePriority: "PRIORITY", + FrameRSTStream: "RST_STREAM", + FrameSettings: "SETTINGS", + FramePushPromise: "PUSH_PROMISE", + FramePing: "PING", + FrameGoAway: "GOAWAY", + FrameWindowUpdate: "WINDOW_UPDATE", + FrameContinuation: "CONTINUATION", +} + +func (t FrameType) String() string { + if s, ok := frameName[t]; ok { + return s + } + return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t)) +} + +// Flags is a bitmask of HTTP/2 flags. +// The meaning of flags varies depending on the frame type. +type Flags uint8 + +// Has reports whether f contains all (0 or more) flags in v. +func (f Flags) Has(v Flags) bool { + return (f & v) == v +} + +// Frame-specific FrameHeader flag bits. +const ( + // Data Frame + FlagDataEndStream Flags = 0x1 + FlagDataPadded Flags = 0x8 + + // Headers Frame + FlagHeadersEndStream Flags = 0x1 + FlagHeadersEndHeaders Flags = 0x4 + FlagHeadersPadded Flags = 0x8 + FlagHeadersPriority Flags = 0x20 + + // Settings Frame + FlagSettingsAck Flags = 0x1 + + // Ping Frame + FlagPingAck Flags = 0x1 + + // Continuation Frame + FlagContinuationEndHeaders Flags = 0x4 + + FlagPushPromiseEndHeaders Flags = 0x4 + FlagPushPromisePadded Flags = 0x8 +) + +var flagName = map[FrameType]map[Flags]string{ + FrameData: { + FlagDataEndStream: "END_STREAM", + FlagDataPadded: "PADDED", + }, + FrameHeaders: { + FlagHeadersEndStream: "END_STREAM", + FlagHeadersEndHeaders: "END_HEADERS", + FlagHeadersPadded: "PADDED", + FlagHeadersPriority: "PRIORITY", + }, + FrameSettings: { + FlagSettingsAck: "ACK", + }, + FramePing: { + FlagPingAck: "ACK", + }, + FrameContinuation: { + FlagContinuationEndHeaders: "END_HEADERS", + }, + FramePushPromise: { + FlagPushPromiseEndHeaders: "END_HEADERS", + FlagPushPromisePadded: "PADDED", + }, +} + +// a frameParser parses a frame given its FrameHeader and payload +// bytes. The length of payload will always equal fh.Length (which +// might be 0). +type frameParser func(fh FrameHeader, payload []byte) (Frame, error) + +var frameParsers = map[FrameType]frameParser{ + FrameData: parseDataFrame, + FrameHeaders: parseHeadersFrame, + FramePriority: parsePriorityFrame, + FrameRSTStream: parseRSTStreamFrame, + FrameSettings: parseSettingsFrame, + FramePushPromise: parsePushPromise, + FramePing: parsePingFrame, + FrameGoAway: parseGoAwayFrame, + FrameWindowUpdate: parseWindowUpdateFrame, + FrameContinuation: parseContinuationFrame, +} + +func typeFrameParser(t FrameType) frameParser { + if f := frameParsers[t]; f != nil { + return f + } + return parseUnknownFrame +} + +// A FrameHeader is the 9 byte header of all HTTP/2 frames. +// +// See http://http2.github.io/http2-spec/#FrameHeader +type FrameHeader struct { + valid bool // caller can access []byte fields in the Frame + + // Type is the 1 byte frame type. There are ten standard frame + // types, but extension frame types may be written by WriteRawFrame + // and will be returned by ReadFrame (as UnknownFrame). + Type FrameType + + // Flags are the 1 byte of 8 potential bit flags per frame. + // They are specific to the frame type. + Flags Flags + + // Length is the length of the frame, not including the 9 byte header. + // The maximum size is one byte less than 16MB (uint24), but only + // frames up to 16KB are allowed without peer agreement. + Length uint32 + + // StreamID is which stream this frame is for. Certain frames + // are not stream-specific, in which case this field is 0. + StreamID uint32 +} + +// Header returns h. It exists so FrameHeaders can be embedded in other +// specific frame types and implement the Frame interface. +func (h FrameHeader) Header() FrameHeader { return h } + +func (h FrameHeader) String() string { + var buf bytes.Buffer + buf.WriteString("[FrameHeader ") + buf.WriteString(h.Type.String()) + if h.Flags != 0 { + buf.WriteString(" flags=") + set := 0 + for i := uint8(0); i < 8; i++ { + if h.Flags&(1< 1 { + buf.WriteByte('|') + } + name := flagName[h.Type][Flags(1<>24), + byte(streamID>>16), + byte(streamID>>8), + byte(streamID)) +} + +func (f *Framer) endWrite() error { + // Now that we know the final size, fill in the FrameHeader in + // the space previously reserved for it. Abuse append. + length := len(f.wbuf) - frameHeaderLen + if length >= (1 << 24) { + return ErrFrameTooLarge + } + _ = append(f.wbuf[:0], + byte(length>>16), + byte(length>>8), + byte(length)) + n, err := f.w.Write(f.wbuf) + if err == nil && n != len(f.wbuf) { + err = io.ErrShortWrite + } + return err +} + +func (f *Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } +func (f *Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) } +func (f *Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) } +func (f *Framer) writeUint32(v uint32) { + f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +const ( + minMaxFrameSize = 1 << 14 + maxFrameSize = 1<<24 - 1 +) + +// NewFramer returns a Framer that writes frames to w and reads them from r. +func NewFramer(w io.Writer, r io.Reader) *Framer { + fr := &Framer{ + w: w, + r: r, + } + fr.getReadBuf = func(size uint32) []byte { + if cap(fr.readBuf) >= int(size) { + return fr.readBuf[:size] + } + fr.readBuf = make([]byte, size) + return fr.readBuf + } + fr.SetMaxReadFrameSize(maxFrameSize) + return fr +} + +// SetMaxReadFrameSize sets the maximum size of a frame +// that will be read by a subsequent call to ReadFrame. +// It is the caller's responsibility to advertise this +// limit with a SETTINGS frame. +func (fr *Framer) SetMaxReadFrameSize(v uint32) { + if v > maxFrameSize { + v = maxFrameSize + } + fr.maxReadSize = v +} + +// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer +// sends a frame that is larger than declared with SetMaxReadFrameSize. +var ErrFrameTooLarge = errors.New("http2: frame too large") + +// ReadFrame reads a single frame. The returned Frame is only valid +// until the next call to ReadFrame. +// If the frame is larger than previously set with SetMaxReadFrameSize, +// the returned error is ErrFrameTooLarge. +func (fr *Framer) ReadFrame() (Frame, error) { + if fr.lastFrame != nil { + fr.lastFrame.invalidate() + } + fh, err := readFrameHeader(fr.headerBuf[:], fr.r) + if err != nil { + return nil, err + } + if fh.Length > fr.maxReadSize { + return nil, ErrFrameTooLarge + } + payload := fr.getReadBuf(fh.Length) + if _, err := io.ReadFull(fr.r, payload); err != nil { + return nil, err + } + f, err := typeFrameParser(fh.Type)(fh, payload) + if err != nil { + return nil, err + } + fr.lastFrame = f + return f, nil +} + +// A DataFrame conveys arbitrary, variable-length sequences of octets +// associated with a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.1 +type DataFrame struct { + FrameHeader + data []byte +} + +func (f *DataFrame) StreamEnded() bool { + return f.FrameHeader.Flags.Has(FlagDataEndStream) +} + +// Data returns the frame's data octets, not including any padding +// size byte or padding suffix bytes. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *DataFrame) Data() []byte { + f.checkValid() + return f.data +} + +func parseDataFrame(fh FrameHeader, payload []byte) (Frame, error) { + if fh.StreamID == 0 { + // DATA frames MUST be associated with a stream. If a + // DATA frame is received whose stream identifier + // field is 0x0, the recipient MUST respond with a + // connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + return nil, ConnectionError(ErrCodeProtocol) + } + f := &DataFrame{ + FrameHeader: fh, + } + var padSize byte + if fh.Flags.Has(FlagDataPadded) { + var err error + payload, padSize, err = readByte(payload) + if err != nil { + return nil, err + } + } + if int(padSize) > len(payload) { + // If the length of the padding is greater than the + // length of the frame payload, the recipient MUST + // treat this as a connection error. + // Filed: https://github.com/http2/http2-spec/issues/610 + return nil, ConnectionError(ErrCodeProtocol) + } + f.data = payload[:len(payload)-int(padSize)] + return f, nil +} + +var errStreamID = errors.New("invalid streamid") + +func validStreamID(streamID uint32) bool { + return streamID != 0 && streamID&(1<<31) == 0 +} + +// WriteData writes a DATA frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WriteData(streamID uint32, endStream bool, data []byte) error { + // TODO: ignoring padding for now. will add when somebody cares. + if !validStreamID(streamID) && !f.AllowIllegalWrites { + return errStreamID + } + var flags Flags + if endStream { + flags |= FlagDataEndStream + } + f.startWrite(FrameData, flags, streamID) + f.wbuf = append(f.wbuf, data...) + return f.endWrite() +} + +// A SettingsFrame conveys configuration parameters that affect how +// endpoints communicate, such as preferences and constraints on peer +// behavior. +// +// See http://http2.github.io/http2-spec/#SETTINGS +type SettingsFrame struct { + FrameHeader + p []byte +} + +func parseSettingsFrame(fh FrameHeader, p []byte) (Frame, error) { + if fh.Flags.Has(FlagSettingsAck) && fh.Length > 0 { + // When this (ACK 0x1) bit is set, the payload of the + // SETTINGS frame MUST be empty. Receipt of a + // SETTINGS frame with the ACK flag set and a length + // field value other than 0 MUST be treated as a + // connection error (Section 5.4.1) of type + // FRAME_SIZE_ERROR. + return nil, ConnectionError(ErrCodeFrameSize) + } + if fh.StreamID != 0 { + // SETTINGS frames always apply to a connection, + // never a single stream. The stream identifier for a + // SETTINGS frame MUST be zero (0x0). If an endpoint + // receives a SETTINGS frame whose stream identifier + // field is anything other than 0x0, the endpoint MUST + // respond with a connection error (Section 5.4.1) of + // type PROTOCOL_ERROR. + return nil, ConnectionError(ErrCodeProtocol) + } + if len(p)%6 != 0 { + // Expecting even number of 6 byte settings. + return nil, ConnectionError(ErrCodeFrameSize) + } + f := &SettingsFrame{FrameHeader: fh, p: p} + if v, ok := f.Value(SettingInitialWindowSize); ok && v > (1<<31)-1 { + // Values above the maximum flow control window size of 2^31 - 1 MUST + // be treated as a connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR. + return nil, ConnectionError(ErrCodeFlowControl) + } + return f, nil +} + +func (f *SettingsFrame) IsAck() bool { + return f.FrameHeader.Flags.Has(FlagSettingsAck) +} + +func (f *SettingsFrame) Value(s SettingID) (v uint32, ok bool) { + f.checkValid() + buf := f.p + for len(buf) > 0 { + settingID := SettingID(binary.BigEndian.Uint16(buf[:2])) + if settingID == s { + return binary.BigEndian.Uint32(buf[2:6]), true + } + buf = buf[6:] + } + return 0, false +} + +// ForeachSetting runs fn for each setting. +// It stops and returns the first error. +func (f *SettingsFrame) ForeachSetting(fn func(Setting) error) error { + f.checkValid() + buf := f.p + for len(buf) > 0 { + if err := fn(Setting{ + SettingID(binary.BigEndian.Uint16(buf[:2])), + binary.BigEndian.Uint32(buf[2:6]), + }); err != nil { + return err + } + buf = buf[6:] + } + return nil +} + +// WriteSettings writes a SETTINGS frame with zero or more settings +// specified and the ACK bit not set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WriteSettings(settings ...Setting) error { + f.startWrite(FrameSettings, 0, 0) + for _, s := range settings { + f.writeUint16(uint16(s.ID)) + f.writeUint32(s.Val) + } + return f.endWrite() +} + +// WriteSettings writes an empty SETTINGS frame with the ACK bit set. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WriteSettingsAck() error { + f.startWrite(FrameSettings, FlagSettingsAck, 0) + return f.endWrite() +} + +// A PingFrame is a mechanism for measuring a minimal round trip time +// from the sender, as well as determining whether an idle connection +// is still functional. +// See http://http2.github.io/http2-spec/#rfc.section.6.7 +type PingFrame struct { + FrameHeader + Data [8]byte +} + +func parsePingFrame(fh FrameHeader, payload []byte) (Frame, error) { + if len(payload) != 8 { + return nil, ConnectionError(ErrCodeFrameSize) + } + if fh.StreamID != 0 { + return nil, ConnectionError(ErrCodeProtocol) + } + f := &PingFrame{FrameHeader: fh} + copy(f.Data[:], payload) + return f, nil +} + +func (f *Framer) WritePing(ack bool, data [8]byte) error { + var flags Flags + if ack { + flags = FlagPingAck + } + f.startWrite(FramePing, flags, 0) + f.writeBytes(data[:]) + return f.endWrite() +} + +// A GoAwayFrame informs the remote peer to stop creating streams on this connection. +// See http://http2.github.io/http2-spec/#rfc.section.6.8 +type GoAwayFrame struct { + FrameHeader + LastStreamID uint32 + ErrCode ErrCode + debugData []byte +} + +// DebugData returns any debug data in the GOAWAY frame. Its contents +// are not defined. +// The caller must not retain the returned memory past the next +// call to ReadFrame. +func (f *GoAwayFrame) DebugData() []byte { + f.checkValid() + return f.debugData +} + +func parseGoAwayFrame(fh FrameHeader, p []byte) (Frame, error) { + if fh.StreamID != 0 { + return nil, ConnectionError(ErrCodeProtocol) + } + if len(p) < 8 { + return nil, ConnectionError(ErrCodeFrameSize) + } + return &GoAwayFrame{ + FrameHeader: fh, + LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1), + ErrCode: ErrCode(binary.BigEndian.Uint32(p[4:8])), + debugData: p[8:], + }, nil +} + +func (f *Framer) WriteGoAway(maxStreamID uint32, code ErrCode, debugData []byte) error { + f.startWrite(FrameGoAway, 0, 0) + f.writeUint32(maxStreamID & (1<<31 - 1)) + f.writeUint32(uint32(code)) + f.writeBytes(debugData) + return f.endWrite() +} + +// An UnknownFrame is the frame type returned when the frame type is unknown +// or no specific frame type parser exists. +type UnknownFrame struct { + FrameHeader + p []byte +} + +// Payload returns the frame's payload (after the header). It is not +// valid to call this method after a subsequent call to +// Framer.ReadFrame, nor is it valid to retain the returned slice. +// The memory is owned by the Framer and is invalidated when the next +// frame is read. +func (f *UnknownFrame) Payload() []byte { + f.checkValid() + return f.p +} + +func parseUnknownFrame(fh FrameHeader, p []byte) (Frame, error) { + return &UnknownFrame{fh, p}, nil +} + +// A WindowUpdateFrame is used to implement flow control. +// See http://http2.github.io/http2-spec/#rfc.section.6.9 +type WindowUpdateFrame struct { + FrameHeader + Increment uint32 +} + +func parseWindowUpdateFrame(fh FrameHeader, p []byte) (Frame, error) { + if len(p) != 4 { + return nil, ConnectionError(ErrCodeFrameSize) + } + inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff // mask off high reserved bit + if inc == 0 { + // A receiver MUST treat the receipt of a + // WINDOW_UPDATE frame with an flow control window + // increment of 0 as a stream error (Section 5.4.2) of + // type PROTOCOL_ERROR; errors on the connection flow + // control window MUST be treated as a connection + // error (Section 5.4.1). + if fh.StreamID == 0 { + return nil, ConnectionError(ErrCodeProtocol) + } + return nil, StreamError{fh.StreamID, ErrCodeProtocol} + } + return &WindowUpdateFrame{ + FrameHeader: fh, + Increment: inc, + }, nil +} + +// WriteWindowUpdate writes a WINDOW_UPDATE frame. +// The increment value must be between 1 and 2,147,483,647, inclusive. +// If the Stream ID is zero, the window update applies to the +// connection as a whole. +func (f *Framer) WriteWindowUpdate(streamID, incr uint32) error { + // "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets." + if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites { + return errors.New("illegal window increment value") + } + f.startWrite(FrameWindowUpdate, 0, streamID) + f.writeUint32(incr) + return f.endWrite() +} + +// A HeadersFrame is used to open a stream and additionally carries a +// header block fragment. +type HeadersFrame struct { + FrameHeader + + // Priority is set if FlagHeadersPriority is set in the FrameHeader. + Priority PriorityParam + + headerFragBuf []byte // not owned +} + +func (f *HeadersFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *HeadersFrame) HeadersEnded() bool { + return f.FrameHeader.Flags.Has(FlagHeadersEndHeaders) +} + +func (f *HeadersFrame) StreamEnded() bool { + return f.FrameHeader.Flags.Has(FlagHeadersEndStream) +} + +func (f *HeadersFrame) HasPriority() bool { + return f.FrameHeader.Flags.Has(FlagHeadersPriority) +} + +func parseHeadersFrame(fh FrameHeader, p []byte) (_ Frame, err error) { + hf := &HeadersFrame{ + FrameHeader: fh, + } + if fh.StreamID == 0 { + // HEADERS frames MUST be associated with a stream. If a HEADERS frame + // is received whose stream identifier field is 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + return nil, ConnectionError(ErrCodeProtocol) + } + var padLength uint8 + if fh.Flags.Has(FlagHeadersPadded) { + if p, padLength, err = readByte(p); err != nil { + return + } + } + if fh.Flags.Has(FlagHeadersPriority) { + var v uint32 + p, v, err = readUint32(p) + if err != nil { + return nil, err + } + hf.Priority.StreamDep = v & 0x7fffffff + hf.Priority.Exclusive = (v != hf.Priority.StreamDep) // high bit was set + p, hf.Priority.Weight, err = readByte(p) + if err != nil { + return nil, err + } + } + if len(p)-int(padLength) <= 0 { + return nil, StreamError{fh.StreamID, ErrCodeProtocol} + } + hf.headerFragBuf = p[:len(p)-int(padLength)] + return hf, nil +} + +// HeadersFrameParam are the parameters for writing a HEADERS frame. +type HeadersFrameParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndStream indicates that the header block is the last that + // the endpoint will send for the identified stream. Setting + // this flag causes the stream to enter one of "half closed" + // states. + EndStream bool + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 + + // Priority, if non-zero, includes stream priority information + // in the HEADER frame. + Priority PriorityParam +} + +// WriteHeaders writes a single HEADERS frame. +// +// This is a low-level header writing method. Encoding headers and +// splitting them into any necessary CONTINUATION frames is handled +// elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WriteHeaders(p HeadersFrameParam) error { + if !validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return errStreamID + } + var flags Flags + if p.PadLength != 0 { + flags |= FlagHeadersPadded + } + if p.EndStream { + flags |= FlagHeadersEndStream + } + if p.EndHeaders { + flags |= FlagHeadersEndHeaders + } + if !p.Priority.IsZero() { + flags |= FlagHeadersPriority + } + f.startWrite(FrameHeaders, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !p.Priority.IsZero() { + v := p.Priority.StreamDep + if !validStreamID(v) && !f.AllowIllegalWrites { + return errors.New("invalid dependent stream id") + } + if p.Priority.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Priority.Weight) + } + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, padZeros[:p.PadLength]...) + return f.endWrite() +} + +// A PriorityFrame specifies the sender-advised priority of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.3 +type PriorityFrame struct { + FrameHeader + PriorityParam +} + +// PriorityParam are the stream prioritzation parameters. +type PriorityParam struct { + // StreamDep is a 31-bit stream identifier for the + // stream that this stream depends on. Zero means no + // dependency. + StreamDep uint32 + + // Exclusive is whether the dependency is exclusive. + Exclusive bool + + // Weight is the stream's zero-indexed weight. It should be + // set together with StreamDep, or neither should be set. Per + // the spec, "Add one to the value to obtain a weight between + // 1 and 256." + Weight uint8 +} + +func (p PriorityParam) IsZero() bool { + return p == PriorityParam{} +} + +func parsePriorityFrame(fh FrameHeader, payload []byte) (Frame, error) { + if fh.StreamID == 0 { + return nil, ConnectionError(ErrCodeProtocol) + } + if len(payload) != 5 { + return nil, ConnectionError(ErrCodeFrameSize) + } + v := binary.BigEndian.Uint32(payload[:4]) + streamID := v & 0x7fffffff // mask off high bit + return &PriorityFrame{ + FrameHeader: fh, + PriorityParam: PriorityParam{ + Weight: payload[4], + StreamDep: streamID, + Exclusive: streamID != v, // was high bit set? + }, + }, nil +} + +// WritePriority writes a PRIORITY frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WritePriority(streamID uint32, p PriorityParam) error { + if !validStreamID(streamID) && !f.AllowIllegalWrites { + return errStreamID + } + f.startWrite(FramePriority, 0, streamID) + v := p.StreamDep + if p.Exclusive { + v |= 1 << 31 + } + f.writeUint32(v) + f.writeByte(p.Weight) + return f.endWrite() +} + +// A RSTStreamFrame allows for abnormal termination of a stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.4 +type RSTStreamFrame struct { + FrameHeader + ErrCode ErrCode +} + +func parseRSTStreamFrame(fh FrameHeader, p []byte) (Frame, error) { + if len(p) != 4 { + return nil, ConnectionError(ErrCodeFrameSize) + } + if fh.StreamID == 0 { + return nil, ConnectionError(ErrCodeProtocol) + } + return &RSTStreamFrame{fh, ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil +} + +// WriteRSTStream writes a RST_STREAM frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WriteRSTStream(streamID uint32, code ErrCode) error { + if !validStreamID(streamID) && !f.AllowIllegalWrites { + return errStreamID + } + f.startWrite(FrameRSTStream, 0, streamID) + f.writeUint32(uint32(code)) + return f.endWrite() +} + +// A ContinuationFrame is used to continue a sequence of header block fragments. +// See http://http2.github.io/http2-spec/#rfc.section.6.10 +type ContinuationFrame struct { + FrameHeader + headerFragBuf []byte +} + +func parseContinuationFrame(fh FrameHeader, p []byte) (Frame, error) { + return &ContinuationFrame{fh, p}, nil +} + +func (f *ContinuationFrame) StreamEnded() bool { + return f.FrameHeader.Flags.Has(FlagDataEndStream) +} + +func (f *ContinuationFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *ContinuationFrame) HeadersEnded() bool { + return f.FrameHeader.Flags.Has(FlagContinuationEndHeaders) +} + +// WriteContinuation writes a CONTINUATION frame. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if !validStreamID(streamID) && !f.AllowIllegalWrites { + return errStreamID + } + var flags Flags + if endHeaders { + flags |= FlagContinuationEndHeaders + } + f.startWrite(FrameContinuation, flags, streamID) + f.wbuf = append(f.wbuf, headerBlockFragment...) + return f.endWrite() +} + +// A PushPromiseFrame is used to initiate a server stream. +// See http://http2.github.io/http2-spec/#rfc.section.6.6 +type PushPromiseFrame struct { + FrameHeader + PromiseID uint32 + headerFragBuf []byte // not owned +} + +func (f *PushPromiseFrame) HeaderBlockFragment() []byte { + f.checkValid() + return f.headerFragBuf +} + +func (f *PushPromiseFrame) HeadersEnded() bool { + return f.FrameHeader.Flags.Has(FlagPushPromiseEndHeaders) +} + +func parsePushPromise(fh FrameHeader, p []byte) (_ Frame, err error) { + pp := &PushPromiseFrame{ + FrameHeader: fh, + } + if pp.StreamID == 0 { + // PUSH_PROMISE frames MUST be associated with an existing, + // peer-initiated stream. The stream identifier of a + // PUSH_PROMISE frame indicates the stream it is associated + // with. If the stream identifier field specifies the value + // 0x0, a recipient MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + return nil, ConnectionError(ErrCodeProtocol) + } + // The PUSH_PROMISE frame includes optional padding. + // Padding fields and flags are identical to those defined for DATA frames + var padLength uint8 + if fh.Flags.Has(FlagPushPromisePadded) { + if p, padLength, err = readByte(p); err != nil { + return + } + } + + p, pp.PromiseID, err = readUint32(p) + if err != nil { + return + } + pp.PromiseID = pp.PromiseID & (1<<31 - 1) + + if int(padLength) > len(p) { + // like the DATA frame, error out if padding is longer than the body. + return nil, ConnectionError(ErrCodeProtocol) + } + pp.headerFragBuf = p[:len(p)-int(padLength)] + return pp, nil +} + +// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame. +type PushPromiseParam struct { + // StreamID is the required Stream ID to initiate. + StreamID uint32 + + // PromiseID is the required Stream ID which this + // Push Promises + PromiseID uint32 + + // BlockFragment is part (or all) of a Header Block. + BlockFragment []byte + + // EndHeaders indicates that this frame contains an entire + // header block and is not followed by any + // CONTINUATION frames. + EndHeaders bool + + // PadLength is the optional number of bytes of zeros to add + // to this frame. + PadLength uint8 +} + +// WritePushPromise writes a single PushPromise Frame. +// +// As with Header Frames, This is the low level call for writing +// individual frames. Continuation frames are handled elsewhere. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility to not call other Write methods concurrently. +func (f *Framer) WritePushPromise(p PushPromiseParam) error { + if !validStreamID(p.StreamID) && !f.AllowIllegalWrites { + return errStreamID + } + var flags Flags + if p.PadLength != 0 { + flags |= FlagPushPromisePadded + } + if p.EndHeaders { + flags |= FlagPushPromiseEndHeaders + } + f.startWrite(FramePushPromise, flags, p.StreamID) + if p.PadLength != 0 { + f.writeByte(p.PadLength) + } + if !validStreamID(p.PromiseID) && !f.AllowIllegalWrites { + return errStreamID + } + f.writeUint32(p.PromiseID) + f.wbuf = append(f.wbuf, p.BlockFragment...) + f.wbuf = append(f.wbuf, padZeros[:p.PadLength]...) + return f.endWrite() +} + +// WriteRawFrame writes a raw frame. This can be used to write +// extension frames unknown to this package. +func (f *Framer) WriteRawFrame(t FrameType, flags Flags, streamID uint32, payload []byte) error { + f.startWrite(t, flags, streamID) + f.writeBytes(payload) + return f.endWrite() +} + +func readByte(p []byte) (remain []byte, b byte, err error) { + if len(p) == 0 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[1:], p[0], nil +} + +func readUint32(p []byte) (remain []byte, v uint32, err error) { + if len(p) < 4 { + return nil, 0, io.ErrUnexpectedEOF + } + return p[4:], binary.BigEndian.Uint32(p[:4]), nil +} + +type streamEnder interface { + StreamEnded() bool +} + +type headersEnder interface { + HeadersEnded() bool +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/frame_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/frame_test.go new file mode 100644 index 0000000..20027c5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/frame_test.go @@ -0,0 +1,597 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http2 + +import ( + "bytes" + "reflect" + "strings" + "testing" + "unsafe" +) + +func testFramer() (*Framer, *bytes.Buffer) { + buf := new(bytes.Buffer) + return NewFramer(buf, buf), buf +} + +func TestFrameSizes(t *testing.T) { + // Catch people rearranging the FrameHeader fields. + if got, want := int(unsafe.Sizeof(FrameHeader{})), 12; got != want { + t.Errorf("FrameHeader size = %d; want %d", got, want) + } +} + +func TestFrameTypeString(t *testing.T) { + tests := []struct { + ft FrameType + want string + }{ + {FrameData, "DATA"}, + {FramePing, "PING"}, + {FrameGoAway, "GOAWAY"}, + {0xf, "UNKNOWN_FRAME_TYPE_15"}, + } + + for i, tt := range tests { + got := tt.ft.String() + if got != tt.want { + t.Errorf("%d. String(FrameType %d) = %q; want %q", i, int(tt.ft), got, tt.want) + } + } +} + +func TestWriteRST(t *testing.T) { + fr, buf := testFramer() + var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 + var errCode uint32 = 7<<24 + 6<<16 + 5<<8 + 4 + fr.WriteRSTStream(streamID, ErrCode(errCode)) + const wantEnc = "\x00\x00\x04\x03\x00\x01\x02\x03\x04\x07\x06\x05\x04" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &RSTStreamFrame{ + FrameHeader: FrameHeader{ + valid: true, + Type: 0x3, + Flags: 0x0, + Length: 0x4, + StreamID: 0x1020304, + }, + ErrCode: 0x7060504, + } + if !reflect.DeepEqual(f, want) { + t.Errorf("parsed back %#v; want %#v", f, want) + } +} + +func TestWriteData(t *testing.T) { + fr, buf := testFramer() + var streamID uint32 = 1<<24 + 2<<16 + 3<<8 + 4 + data := []byte("ABC") + fr.WriteData(streamID, true, data) + const wantEnc = "\x00\x00\x03\x00\x01\x01\x02\x03\x04ABC" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + df, ok := f.(*DataFrame) + if !ok { + t.Fatalf("got %T; want *DataFrame", f) + } + if !bytes.Equal(df.Data(), data) { + t.Errorf("got %q; want %q", df.Data(), data) + } + if f.Header().Flags&1 == 0 { + t.Errorf("didn't see END_STREAM flag") + } +} + +func TestWriteHeaders(t *testing.T) { + tests := []struct { + name string + p HeadersFrameParam + wantEnc string + wantFrame *HeadersFrame + }{ + { + "basic", + HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + Priority: PriorityParam{}, + }, + "\x00\x00\x03\x01\x00\x00\x00\x00*abc", + &HeadersFrame{ + FrameHeader: FrameHeader{ + valid: true, + StreamID: 42, + Type: FrameHeaders, + Length: uint32(len("abc")), + }, + Priority: PriorityParam{}, + headerFragBuf: []byte("abc"), + }, + }, + { + "basic + end flags", + HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + Priority: PriorityParam{}, + }, + "\x00\x00\x03\x01\x05\x00\x00\x00*abc", + &HeadersFrame{ + FrameHeader: FrameHeader{ + valid: true, + StreamID: 42, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders, + Length: uint32(len("abc")), + }, + Priority: PriorityParam{}, + headerFragBuf: []byte("abc"), + }, + }, + { + "with padding", + HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + PadLength: 5, + Priority: PriorityParam{}, + }, + "\x00\x00\t\x01\r\x00\x00\x00*\x05abc\x00\x00\x00\x00\x00", + &HeadersFrame{ + FrameHeader: FrameHeader{ + valid: true, + StreamID: 42, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded, + Length: uint32(1 + len("abc") + 5), // pad length + contents + padding + }, + Priority: PriorityParam{}, + headerFragBuf: []byte("abc"), + }, + }, + { + "with priority", + HeadersFrameParam{ + StreamID: 42, + BlockFragment: []byte("abc"), + EndStream: true, + EndHeaders: true, + PadLength: 2, + Priority: PriorityParam{ + StreamDep: 15, + Exclusive: true, + Weight: 127, + }, + }, + "\x00\x00\v\x01-\x00\x00\x00*\x02\x80\x00\x00\x0f\u007fabc\x00\x00", + &HeadersFrame{ + FrameHeader: FrameHeader{ + valid: true, + StreamID: 42, + Type: FrameHeaders, + Flags: FlagHeadersEndStream | FlagHeadersEndHeaders | FlagHeadersPadded | FlagHeadersPriority, + Length: uint32(1 + 5 + len("abc") + 2), // pad length + priority + contents + padding + }, + Priority: PriorityParam{ + StreamDep: 15, + Exclusive: true, + Weight: 127, + }, + headerFragBuf: []byte("abc"), + }, + }, + } + for _, tt := range tests { + fr, buf := testFramer() + if err := fr.WriteHeaders(tt.p); err != nil { + t.Errorf("test %q: %v", tt.name, err) + continue + } + if buf.String() != tt.wantEnc { + t.Errorf("test %q: encoded %q; want %q", tt.name, buf.Bytes(), tt.wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) + continue + } + if !reflect.DeepEqual(f, tt.wantFrame) { + t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) + } + } +} + +func TestWriteContinuation(t *testing.T) { + const streamID = 42 + tests := []struct { + name string + end bool + frag []byte + + wantFrame *ContinuationFrame + }{ + { + "not end", + false, + []byte("abc"), + &ContinuationFrame{ + FrameHeader: FrameHeader{ + valid: true, + StreamID: streamID, + Type: FrameContinuation, + Length: uint32(len("abc")), + }, + headerFragBuf: []byte("abc"), + }, + }, + { + "end", + true, + []byte("def"), + &ContinuationFrame{ + FrameHeader: FrameHeader{ + valid: true, + StreamID: streamID, + Type: FrameContinuation, + Flags: FlagContinuationEndHeaders, + Length: uint32(len("def")), + }, + headerFragBuf: []byte("def"), + }, + }, + } + for _, tt := range tests { + fr, _ := testFramer() + if err := fr.WriteContinuation(streamID, tt.end, tt.frag); err != nil { + t.Errorf("test %q: %v", tt.name, err) + continue + } + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) + continue + } + if !reflect.DeepEqual(f, tt.wantFrame) { + t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) + } + } +} + +func TestWritePriority(t *testing.T) { + const streamID = 42 + tests := []struct { + name string + priority PriorityParam + wantFrame *PriorityFrame + }{ + { + "not exclusive", + PriorityParam{ + StreamDep: 2, + Exclusive: false, + Weight: 127, + }, + &PriorityFrame{ + FrameHeader{ + valid: true, + StreamID: streamID, + Type: FramePriority, + Length: 5, + }, + PriorityParam{ + StreamDep: 2, + Exclusive: false, + Weight: 127, + }, + }, + }, + + { + "exclusive", + PriorityParam{ + StreamDep: 3, + Exclusive: true, + Weight: 77, + }, + &PriorityFrame{ + FrameHeader{ + valid: true, + StreamID: streamID, + Type: FramePriority, + Length: 5, + }, + PriorityParam{ + StreamDep: 3, + Exclusive: true, + Weight: 77, + }, + }, + }, + } + for _, tt := range tests { + fr, _ := testFramer() + if err := fr.WritePriority(streamID, tt.priority); err != nil { + t.Errorf("test %q: %v", tt.name, err) + continue + } + f, err := fr.ReadFrame() + if err != nil { + t.Errorf("test %q: failed to read the frame back: %v", tt.name, err) + continue + } + if !reflect.DeepEqual(f, tt.wantFrame) { + t.Errorf("test %q: mismatch.\n got: %#v\nwant: %#v\n", tt.name, f, tt.wantFrame) + } + } +} + +func TestWriteSettings(t *testing.T) { + fr, buf := testFramer() + settings := []Setting{{1, 2}, {3, 4}} + fr.WriteSettings(settings...) + const wantEnc = "\x00\x00\f\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x03\x00\x00\x00\x04" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + sf, ok := f.(*SettingsFrame) + if !ok { + t.Fatalf("Got a %T; want a SettingsFrame", f) + } + var got []Setting + sf.ForeachSetting(func(s Setting) error { + got = append(got, s) + valBack, ok := sf.Value(s.ID) + if !ok || valBack != s.Val { + t.Errorf("Value(%d) = %v, %v; want %v, true", s.ID, valBack, ok, s.Val) + } + return nil + }) + if !reflect.DeepEqual(settings, got) { + t.Errorf("Read settings %+v != written settings %+v", got, settings) + } +} + +func TestWriteSettingsAck(t *testing.T) { + fr, buf := testFramer() + fr.WriteSettingsAck() + const wantEnc = "\x00\x00\x00\x04\x01\x00\x00\x00\x00" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } +} + +func TestWriteWindowUpdate(t *testing.T) { + fr, buf := testFramer() + const streamID = 1<<24 + 2<<16 + 3<<8 + 4 + const incr = 7<<24 + 6<<16 + 5<<8 + 4 + if err := fr.WriteWindowUpdate(streamID, incr); err != nil { + t.Fatal(err) + } + const wantEnc = "\x00\x00\x04\x08\x00\x01\x02\x03\x04\x07\x06\x05\x04" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &WindowUpdateFrame{ + FrameHeader: FrameHeader{ + valid: true, + Type: 0x8, + Flags: 0x0, + Length: 0x4, + StreamID: 0x1020304, + }, + Increment: 0x7060504, + } + if !reflect.DeepEqual(f, want) { + t.Errorf("parsed back %#v; want %#v", f, want) + } +} + +func TestWritePing(t *testing.T) { testWritePing(t, false) } +func TestWritePingAck(t *testing.T) { testWritePing(t, true) } + +func testWritePing(t *testing.T, ack bool) { + fr, buf := testFramer() + if err := fr.WritePing(ack, [8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { + t.Fatal(err) + } + var wantFlags Flags + if ack { + wantFlags = FlagPingAck + } + var wantEnc = "\x00\x00\x08\x06" + string(wantFlags) + "\x00\x00\x00\x00" + "\x01\x02\x03\x04\x05\x06\x07\x08" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &PingFrame{ + FrameHeader: FrameHeader{ + valid: true, + Type: 0x6, + Flags: wantFlags, + Length: 0x8, + StreamID: 0, + }, + Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}, + } + if !reflect.DeepEqual(f, want) { + t.Errorf("parsed back %#v; want %#v", f, want) + } +} + +func TestReadFrameHeader(t *testing.T) { + tests := []struct { + in string + want FrameHeader + }{ + {in: "\x00\x00\x00" + "\x00" + "\x00" + "\x00\x00\x00\x00", want: FrameHeader{}}, + {in: "\x01\x02\x03" + "\x04" + "\x05" + "\x06\x07\x08\x09", want: FrameHeader{ + Length: 66051, Type: 4, Flags: 5, StreamID: 101124105, + }}, + // Ignore high bit: + {in: "\xff\xff\xff" + "\xff" + "\xff" + "\xff\xff\xff\xff", want: FrameHeader{ + Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, + {in: "\xff\xff\xff" + "\xff" + "\xff" + "\x7f\xff\xff\xff", want: FrameHeader{ + Length: 16777215, Type: 255, Flags: 255, StreamID: 2147483647}}, + } + for i, tt := range tests { + got, err := readFrameHeader(make([]byte, 9), strings.NewReader(tt.in)) + if err != nil { + t.Errorf("%d. readFrameHeader(%q) = %v", i, tt.in, err) + continue + } + tt.want.valid = true + if got != tt.want { + t.Errorf("%d. readFrameHeader(%q) = %+v; want %+v", i, tt.in, got, tt.want) + } + } +} + +func TestReadWriteFrameHeader(t *testing.T) { + tests := []struct { + len uint32 + typ FrameType + flags Flags + streamID uint32 + }{ + {len: 0, typ: 255, flags: 1, streamID: 0}, + {len: 0, typ: 255, flags: 1, streamID: 1}, + {len: 0, typ: 255, flags: 1, streamID: 255}, + {len: 0, typ: 255, flags: 1, streamID: 256}, + {len: 0, typ: 255, flags: 1, streamID: 65535}, + {len: 0, typ: 255, flags: 1, streamID: 65536}, + + {len: 0, typ: 1, flags: 255, streamID: 1}, + {len: 255, typ: 1, flags: 255, streamID: 1}, + {len: 256, typ: 1, flags: 255, streamID: 1}, + {len: 65535, typ: 1, flags: 255, streamID: 1}, + {len: 65536, typ: 1, flags: 255, streamID: 1}, + {len: 16777215, typ: 1, flags: 255, streamID: 1}, + } + for _, tt := range tests { + fr, buf := testFramer() + fr.startWrite(tt.typ, tt.flags, tt.streamID) + fr.writeBytes(make([]byte, tt.len)) + fr.endWrite() + fh, err := ReadFrameHeader(buf) + if err != nil { + t.Errorf("ReadFrameHeader(%+v) = %v", tt, err) + continue + } + if fh.Type != tt.typ || fh.Flags != tt.flags || fh.Length != tt.len || fh.StreamID != tt.streamID { + t.Errorf("ReadFrameHeader(%+v) = %+v; mismatch", tt, fh) + } + } + +} + +func TestWriteTooLargeFrame(t *testing.T) { + fr, _ := testFramer() + fr.startWrite(0, 1, 1) + fr.writeBytes(make([]byte, 1<<24)) + err := fr.endWrite() + if err != ErrFrameTooLarge { + t.Errorf("endWrite = %v; want errFrameTooLarge", err) + } +} + +func TestWriteGoAway(t *testing.T) { + const debug = "foo" + fr, buf := testFramer() + if err := fr.WriteGoAway(0x01020304, 0x05060708, []byte(debug)); err != nil { + t.Fatal(err) + } + const wantEnc = "\x00\x00\v\a\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08" + debug + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + want := &GoAwayFrame{ + FrameHeader: FrameHeader{ + valid: true, + Type: 0x7, + Flags: 0, + Length: uint32(4 + 4 + len(debug)), + StreamID: 0, + }, + LastStreamID: 0x01020304, + ErrCode: 0x05060708, + debugData: []byte(debug), + } + if !reflect.DeepEqual(f, want) { + t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want) + } + if got := string(f.(*GoAwayFrame).DebugData()); got != debug { + t.Errorf("debug data = %q; want %q", got, debug) + } +} + +func TestWritePushPromise(t *testing.T) { + pp := PushPromiseParam{ + StreamID: 42, + PromiseID: 42, + BlockFragment: []byte("abc"), + } + fr, buf := testFramer() + if err := fr.WritePushPromise(pp); err != nil { + t.Fatal(err) + } + const wantEnc = "\x00\x00\x07\x05\x00\x00\x00\x00*\x00\x00\x00*abc" + if buf.String() != wantEnc { + t.Errorf("encoded as %q; want %q", buf.Bytes(), wantEnc) + } + f, err := fr.ReadFrame() + if err != nil { + t.Fatal(err) + } + _, ok := f.(*PushPromiseFrame) + if !ok { + t.Fatalf("got %T; want *PushPromiseFrame", f) + } + want := &PushPromiseFrame{ + FrameHeader: FrameHeader{ + valid: true, + Type: 0x5, + Flags: 0x0, + Length: 0x7, + StreamID: 42, + }, + PromiseID: 42, + headerFragBuf: []byte("abc"), + } + if !reflect.DeepEqual(f, want) { + t.Fatalf("parsed back:\n%#v\nwant:\n%#v", f, want) + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/gotrack.go b/Godeps/_workspace/src/github.com/bradfitz/http2/gotrack.go new file mode 100644 index 0000000..7dc2ef9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/gotrack.go @@ -0,0 +1,173 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +// Defensive debug-only utility to track that functions run on the +// goroutine that they're supposed to. + +package http2 + +import ( + "bytes" + "errors" + "fmt" + "os" + "runtime" + "strconv" + "sync" +) + +var DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" + +type goroutineLock uint64 + +func newGoroutineLock() goroutineLock { + if !DebugGoroutines { + return 0 + } + return goroutineLock(curGoroutineID()) +} + +func (g goroutineLock) check() { + if !DebugGoroutines { + return + } + if curGoroutineID() != uint64(g) { + panic("running on the wrong goroutine") + } +} + +func (g goroutineLock) checkNotOn() { + if !DebugGoroutines { + return + } + if curGoroutineID() == uint64(g) { + panic("running on the wrong goroutine") + } +} + +var goroutineSpace = []byte("goroutine ") + +func curGoroutineID() uint64 { + bp := littleBuf.Get().(*[]byte) + defer littleBuf.Put(bp) + b := *bp + b = b[:runtime.Stack(b, false)] + // Parse the 4707 out of "goroutine 4707 [" + b = bytes.TrimPrefix(b, goroutineSpace) + i := bytes.IndexByte(b, ' ') + if i < 0 { + panic(fmt.Sprintf("No space found in %q", b)) + } + b = b[:i] + n, err := parseUintBytes(b, 10, 64) + if err != nil { + panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err)) + } + return n +} + +var littleBuf = sync.Pool{ + New: func() interface{} { + buf := make([]byte, 64) + return &buf + }, +} + +// parseUintBytes is like strconv.ParseUint, but using a []byte. +func parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) { + var cutoff, maxVal uint64 + + if bitSize == 0 { + bitSize = int(strconv.IntSize) + } + + s0 := s + switch { + case len(s) < 1: + err = strconv.ErrSyntax + goto Error + + case 2 <= base && base <= 36: + // valid base; nothing to do + + case base == 0: + // Look for octal, hex prefix. + switch { + case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'): + base = 16 + s = s[2:] + if len(s) < 1 { + err = strconv.ErrSyntax + goto Error + } + case s[0] == '0': + base = 8 + default: + base = 10 + } + + default: + err = errors.New("invalid base " + strconv.Itoa(base)) + goto Error + } + + n = 0 + cutoff = cutoff64(base) + maxVal = 1<= base { + n = 0 + err = strconv.ErrSyntax + goto Error + } + + if n >= cutoff { + // n*base overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n *= uint64(base) + + n1 := n + uint64(v) + if n1 < n || n1 > maxVal { + // n+v overflows + n = 1<<64 - 1 + err = strconv.ErrRange + goto Error + } + n = n1 + } + + return n, nil + +Error: + return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err} +} + +// Return the first number n such that n*base >= 1<<64. +func cutoff64(base int) uint64 { + if base < 2 { + return 0 + } + return (1<<64-1)/uint64(base) + 1 +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/gotrack_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/gotrack_test.go new file mode 100644 index 0000000..3ff4957 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/gotrack_test.go @@ -0,0 +1,33 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "fmt" + "strings" + "testing" +) + +func TestGoroutineLock(t *testing.T) { + DebugGoroutines = true + g := newGoroutineLock() + g.check() + + sawPanic := make(chan interface{}) + go func() { + defer func() { sawPanic <- recover() }() + g.check() // should panic + }() + e := <-sawPanic + if e == nil { + t.Fatal("did not see panic from check in other goroutine") + } + if !strings.Contains(fmt.Sprint(e), "wrong goroutine") { + t.Errorf("expected on see panic about running on the wrong goroutine; got %v", e) + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/.gitignore b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/.gitignore new file mode 100644 index 0000000..0de86dd --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/.gitignore @@ -0,0 +1,5 @@ +h2demo +h2demo.linux +client-id.dat +client-secret.dat +token.dat diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/Makefile b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/Makefile new file mode 100644 index 0000000..3a77cf0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/Makefile @@ -0,0 +1,5 @@ +h2demo.linux: h2demo.go + GOOS=linux go build --tags=h2demo -o h2demo.linux . + +upload: h2demo.linux + cat h2demo.linux | go run launch.go --write_object=http2-demo-server-tls/h2demo --write_object_is_public diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/README b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/README new file mode 100644 index 0000000..212a96f --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/README @@ -0,0 +1,16 @@ + +Client: + -- Firefox nightly with about:config network.http.spdy.enabled.http2draft set true + -- Chrome: go to chrome://flags/#enable-spdy4, save and restart (button at bottom) + +Make CA: +$ openssl genrsa -out rootCA.key 2048 +$ openssl req -x509 -new -nodes -key rootCA.key -days 1024 -out rootCA.pem +... install that to Firefox + +Make cert: +$ openssl genrsa -out server.key 2048 +$ openssl req -new -key server.key -out server.csr +$ openssl x509 -req -in server.csr -CA rootCA.pem -CAkey rootCA.key -CAcreateserial -out server.crt -days 500 + + diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/h2demo.go b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/h2demo.go new file mode 100644 index 0000000..de6caaa --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/h2demo.go @@ -0,0 +1,426 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +// +build h2demo + +package main + +import ( + "bytes" + "crypto/tls" + "flag" + "fmt" + "hash/crc32" + "image" + "image/jpeg" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "os/exec" + "path" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "camlistore.org/pkg/googlestorage" + "camlistore.org/pkg/singleflight" + "github.com/bradfitz/http2" +) + +var ( + openFirefox = flag.Bool("openff", false, "Open Firefox") + addr = flag.String("addr", "localhost:4430", "TLS address to listen on") + httpAddr = flag.String("httpaddr", "", "If non-empty, address to listen for regular HTTP on") + prod = flag.Bool("prod", false, "Whether to configure itself to be the production http2.golang.org server.") +) + +func homeOldHTTP(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, ` + +

Go + HTTP/2

+

Welcome to the Go language's HTTP/2 demo & interop server.

+

Unfortunately, you're not using HTTP/2 right now. To do so:

+
    +
  • Use Firefox Nightly or go to about:config and enable "network.http.spdy.enabled.http2draft"
  • +
  • Use Google Chrome Canary and/or go to chrome://flags/#enable-spdy4 to Enable SPDY/4 (Chrome's name for HTTP/2)
  • +
+

See code & instructions for connecting at https://github.com/bradfitz/http2.

+ +`) +} + +func home(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + io.WriteString(w, ` + +

Go + HTTP/2

+ +

Welcome to the Go language's HTTP/2 demo & interop server.

+ +

Congratulations, you're using HTTP/2 right now.

+ +

This server exists for others in the HTTP/2 community to test their HTTP/2 client implementations and point out flaws in our server.

+ +

The code is currently at github.com/bradfitz/http2 +but will move to the Go standard library at some point in the future +(enabled by default, without users needing to change their code).

+ +

Contact info: bradfitz@golang.org, or file a bug.

+ +

Handlers for testing

+
    +
  • GET /reqinfo to dump the request + headers received
  • +
  • GET /clockstream streams the current time every second
  • +
  • GET /gophertiles to see a page with a bunch of images
  • +
  • GET /file/gopher.png for a small file (does If-Modified-Since, Content-Range, etc)
  • +
  • GET /file/go.src.tar.gz for a larger file (~10 MB)
  • +
  • GET /redirect to redirect back to / (this page)
  • +
  • GET /goroutines to see all active goroutines in this server
  • +
  • PUT something to /crc32 to get a count of number of bytes and its CRC-32
  • +
+ +`) +} + +func reqInfoHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprintf(w, "Method: %s\n", r.Method) + fmt.Fprintf(w, "Protocol: %s\n", r.Proto) + fmt.Fprintf(w, "Host: %s\n", r.Host) + fmt.Fprintf(w, "RemoteAddr: %s\n", r.RemoteAddr) + fmt.Fprintf(w, "RequestURI: %q\n", r.RequestURI) + fmt.Fprintf(w, "URL: %#v\n", r.URL) + fmt.Fprintf(w, "Body.ContentLength: %d (-1 means unknown)\n", r.ContentLength) + fmt.Fprintf(w, "Close: %v (relevant for HTTP/1 only)\n", r.Close) + fmt.Fprintf(w, "TLS: %#v\n", r.TLS) + fmt.Fprintf(w, "\nHeaders:\n") + r.Header.Write(w) +} + +func crcHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != "PUT" { + http.Error(w, "PUT required.", 400) + return + } + crc := crc32.NewIEEE() + n, err := io.Copy(crc, r.Body) + if err == nil { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprintf(w, "bytes=%d, CRC32=%x", n, crc.Sum(nil)) + } +} + +var ( + fsGrp singleflight.Group + fsMu sync.Mutex // guards fsCache + fsCache = map[string]http.Handler{} +) + +// fileServer returns a file-serving handler that proxies URL. +// It lazily fetches URL on the first access and caches its contents forever. +func fileServer(url string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hi, err := fsGrp.Do(url, func() (interface{}, error) { + fsMu.Lock() + if h, ok := fsCache[url]; ok { + fsMu.Unlock() + return h, nil + } + fsMu.Unlock() + + res, err := http.Get(url) + if err != nil { + return nil, err + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + return nil, err + } + + modTime := time.Now() + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.ServeContent(w, r, path.Base(url), modTime, bytes.NewReader(slurp)) + }) + fsMu.Lock() + fsCache[url] = h + fsMu.Unlock() + return h, nil + }) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + hi.(http.Handler).ServeHTTP(w, r) + }) +} + +func clockStreamHandler(w http.ResponseWriter, r *http.Request) { + clientGone := w.(http.CloseNotifier).CloseNotify() + w.Header().Set("Content-Type", "text/plain") + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + fmt.Fprintf(w, "# ~1KB of junk to force browsers to start rendering immediately: \n") + io.WriteString(w, strings.Repeat("# xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n", 13)) + + for { + fmt.Fprintf(w, "%v\n", time.Now()) + w.(http.Flusher).Flush() + select { + case <-ticker.C: + case <-clientGone: + log.Printf("Client %v disconnected from the clock", r.RemoteAddr) + return + } + } +} + +func registerHandlers() { + tiles := newGopherTilesHandler() + + mux2 := http.NewServeMux() + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil { + if r.URL.Path == "/gophertiles" { + tiles.ServeHTTP(w, r) + return + } + http.Redirect(w, r, "https://http2.golang.org/", http.StatusFound) + return + } + if r.ProtoMajor == 1 { + if r.URL.Path == "/reqinfo" { + reqInfoHandler(w, r) + return + } + homeOldHTTP(w, r) + return + } + mux2.ServeHTTP(w, r) + }) + mux2.HandleFunc("/", home) + mux2.Handle("/file/gopher.png", fileServer("https://golang.org/doc/gopher/frontpage.png")) + mux2.Handle("/file/go.src.tar.gz", fileServer("https://storage.googleapis.com/golang/go1.4.1.src.tar.gz")) + mux2.HandleFunc("/reqinfo", reqInfoHandler) + mux2.HandleFunc("/crc32", crcHandler) + mux2.HandleFunc("/clockstream", clockStreamHandler) + mux2.Handle("/gophertiles", tiles) + mux2.HandleFunc("/redirect", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/", http.StatusFound) + }) + stripHomedir := regexp.MustCompile(`/(Users|home)/\w+`) + mux2.HandleFunc("/goroutines", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + buf := make([]byte, 2<<20) + w.Write(stripHomedir.ReplaceAll(buf[:runtime.Stack(buf, true)], nil)) + }) +} + +func newGopherTilesHandler() http.Handler { + const gopherURL = "https://blog.golang.org/go-programming-language-turns-two_gophers.jpg" + res, err := http.Get(gopherURL) + if err != nil { + log.Fatal(err) + } + if res.StatusCode != 200 { + log.Fatalf("Error fetching %s: %v", gopherURL, res.Status) + } + slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + im, err := jpeg.Decode(bytes.NewReader(slurp)) + if err != nil { + if len(slurp) > 1024 { + slurp = slurp[:1024] + } + log.Fatalf("Failed to decode gopher image: %v (got %q)", err, slurp) + } + + type subImager interface { + SubImage(image.Rectangle) image.Image + } + const tileSize = 32 + xt := im.Bounds().Max.X / tileSize + yt := im.Bounds().Max.Y / tileSize + var tile [][][]byte // y -> x -> jpeg bytes + for yi := 0; yi < yt; yi++ { + var row [][]byte + for xi := 0; xi < xt; xi++ { + si := im.(subImager).SubImage(image.Rectangle{ + Min: image.Point{xi * tileSize, yi * tileSize}, + Max: image.Point{(xi + 1) * tileSize, (yi + 1) * tileSize}, + }) + buf := new(bytes.Buffer) + if err := jpeg.Encode(buf, si, &jpeg.Options{Quality: 90}); err != nil { + log.Fatal(err) + } + row = append(row, buf.Bytes()) + } + tile = append(tile, row) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ms, _ := strconv.Atoi(r.FormValue("latency")) + const nanosPerMilli = 1e6 + if r.FormValue("x") != "" { + x, _ := strconv.Atoi(r.FormValue("x")) + y, _ := strconv.Atoi(r.FormValue("y")) + if ms <= 1000 { + time.Sleep(time.Duration(ms) * nanosPerMilli) + } + if x >= 0 && x < xt && y >= 0 && y < yt { + http.ServeContent(w, r, "", time.Time{}, bytes.NewReader(tile[y][x])) + return + } + } + io.WriteString(w, "") + fmt.Fprintf(w, "A grid of %d tiled images is below. Compare:

", xt*yt) + for _, ms := range []int{0, 30, 200, 1000} { + d := time.Duration(ms) * nanosPerMilli + fmt.Fprintf(w, "[HTTP/2, %v latency] [HTTP/1, %v latency]
\n", + httpsHost(), ms, d, + httpHost(), ms, d, + ) + } + io.WriteString(w, "

\n") + cacheBust := time.Now().UnixNano() + for y := 0; y < yt; y++ { + for x := 0; x < xt; x++ { + fmt.Fprintf(w, "", + tileSize, tileSize, x, y, cacheBust, ms) + } + io.WriteString(w, "
\n") + } + io.WriteString(w, "


<< Back to Go HTTP/2 demo server") + }) +} + +func httpsHost() string { + if *prod { + return "http2.golang.org" + } + if v := *addr; strings.HasPrefix(v, ":") { + return "localhost" + v + } else { + return v + } +} + +func httpHost() string { + if *prod { + return "http2.golang.org" + } + if v := *httpAddr; strings.HasPrefix(v, ":") { + return "localhost" + v + } else { + return v + } +} + +func serveProdTLS() error { + c, err := googlestorage.NewServiceClient() + if err != nil { + return err + } + slurp := func(key string) ([]byte, error) { + const bucket = "http2-demo-server-tls" + rc, _, err := c.GetObject(&googlestorage.Object{ + Bucket: bucket, + Key: key, + }) + if err != nil { + return nil, fmt.Errorf("Error fetching GCS object %q in bucket %q: %v", key, bucket, err) + } + defer rc.Close() + return ioutil.ReadAll(rc) + } + certPem, err := slurp("http2.golang.org.chained.pem") + if err != nil { + return err + } + keyPem, err := slurp("http2.golang.org.key") + if err != nil { + return err + } + cert, err := tls.X509KeyPair(certPem, keyPem) + if err != nil { + return err + } + srv := &http.Server{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + } + http2.ConfigureServer(srv, &http2.Server{}) + ln, err := net.Listen("tcp", ":443") + if err != nil { + return err + } + return srv.Serve(tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, srv.TLSConfig)) +} + +type tcpKeepAliveListener struct { + *net.TCPListener +} + +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +func serveProd() error { + errc := make(chan error, 2) + go func() { errc <- http.ListenAndServe(":80", nil) }() + go func() { errc <- serveProdTLS() }() + return <-errc +} + +func main() { + var srv http.Server + flag.BoolVar(&http2.VerboseLogs, "verbose", false, "Verbose HTTP/2 debugging.") + flag.Parse() + srv.Addr = *addr + + registerHandlers() + + if *prod { + *httpAddr = "http2.golang.org" + log.Fatal(serveProd()) + } + + url := "https://" + *addr + "/" + log.Printf("Listening on " + url) + http2.ConfigureServer(&srv, &http2.Server{}) + + if *httpAddr != "" { + go func() { log.Fatal(http.ListenAndServe(*httpAddr, nil)) }() + } + + go func() { + log.Fatal(srv.ListenAndServeTLS("server.crt", "server.key")) + }() + if *openFirefox && runtime.GOOS == "darwin" { + time.Sleep(250 * time.Millisecond) + exec.Command("open", "-b", "org.mozilla.nightly", "https://localhost:4430/").Run() + } + select {} +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/launch.go b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/launch.go new file mode 100644 index 0000000..7466615 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/launch.go @@ -0,0 +1,302 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package main + +import ( + "bufio" + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "os" + "strings" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + compute "google.golang.org/api/compute/v1" +) + +var ( + proj = flag.String("project", "symbolic-datum-552", "name of Project") + zone = flag.String("zone", "us-central1-a", "GCE zone") + mach = flag.String("machinetype", "n1-standard-1", "Machine type") + instName = flag.String("instance_name", "http2-demo", "Name of VM instance.") + sshPub = flag.String("ssh_public_key", "", "ssh public key file to authorize. Can modify later in Google's web UI anyway.") + staticIP = flag.String("static_ip", "130.211.116.44", "Static IP to use. If empty, automatic.") + + writeObject = flag.String("write_object", "", "If non-empty, a VM isn't created and the flag value is Google Cloud Storage bucket/object to write. The contents from stdin.") + publicObject = flag.Bool("write_object_is_public", false, "Whether the object created by --write_object should be public.") +) + +func readFile(v string) string { + slurp, err := ioutil.ReadFile(v) + if err != nil { + log.Fatalf("Error reading %s: %v", v, err) + } + return strings.TrimSpace(string(slurp)) +} + +var config = &oauth2.Config{ + // The client-id and secret should be for an "Installed Application" when using + // the CLI. Later we'll use a web application with a callback. + ClientID: readFile("client-id.dat"), + ClientSecret: readFile("client-secret.dat"), + Endpoint: google.Endpoint, + Scopes: []string{ + compute.DevstorageFull_controlScope, + compute.ComputeScope, + "https://www.googleapis.com/auth/sqlservice", + "https://www.googleapis.com/auth/sqlservice.admin", + }, + RedirectURL: "urn:ietf:wg:oauth:2.0:oob", +} + +const baseConfig = `#cloud-config +coreos: + units: + - name: h2demo.service + command: start + content: | + [Unit] + Description=HTTP2 Demo + + [Service] + ExecStartPre=/bin/bash -c 'mkdir -p /opt/bin && curl -s -o /opt/bin/h2demo http://storage.googleapis.com/http2-demo-server-tls/h2demo && chmod +x /opt/bin/h2demo' + ExecStart=/opt/bin/h2demo --prod + RestartSec=5s + Restart=always + Type=simple + + [Install] + WantedBy=multi-user.target +` + +func main() { + flag.Parse() + if *proj == "" { + log.Fatalf("Missing --project flag") + } + prefix := "https://www.googleapis.com/compute/v1/projects/" + *proj + machType := prefix + "/zones/" + *zone + "/machineTypes/" + *mach + + const tokenFileName = "token.dat" + tokenFile := tokenCacheFile(tokenFileName) + tokenSource := oauth2.ReuseTokenSource(nil, tokenFile) + token, err := tokenSource.Token() + if err != nil { + if *writeObject != "" { + log.Fatalf("Can't use --write_object without a valid token.dat file already cached.") + } + log.Printf("Error getting token from %s: %v", tokenFileName, err) + log.Printf("Get auth code from %v", config.AuthCodeURL("my-state")) + fmt.Print("\nEnter auth code: ") + sc := bufio.NewScanner(os.Stdin) + sc.Scan() + authCode := strings.TrimSpace(sc.Text()) + token, err = config.Exchange(oauth2.NoContext, authCode) + if err != nil { + log.Fatalf("Error exchanging auth code for a token: %v", err) + } + if err := tokenFile.WriteToken(token); err != nil { + log.Fatalf("Error writing to %s: %v", tokenFileName, err) + } + tokenSource = oauth2.ReuseTokenSource(token, nil) + } + + oauthClient := oauth2.NewClient(oauth2.NoContext, tokenSource) + + if *writeObject != "" { + writeCloudStorageObject(oauthClient) + return + } + + computeService, _ := compute.New(oauthClient) + + natIP := *staticIP + if natIP == "" { + // Try to find it by name. + aggAddrList, err := computeService.Addresses.AggregatedList(*proj).Do() + if err != nil { + log.Fatal(err) + } + // http://godoc.org/code.google.com/p/google-api-go-client/compute/v1#AddressAggregatedList + IPLoop: + for _, asl := range aggAddrList.Items { + for _, addr := range asl.Addresses { + if addr.Name == *instName+"-ip" && addr.Status == "RESERVED" { + natIP = addr.Address + break IPLoop + } + } + } + } + + cloudConfig := baseConfig + if *sshPub != "" { + key := strings.TrimSpace(readFile(*sshPub)) + cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", key) + } + if os.Getenv("USER") == "bradfitz" { + cloudConfig += fmt.Sprintf("\nssh_authorized_keys:\n - %s\n", "ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAIEAwks9dwWKlRC+73gRbvYtVg0vdCwDSuIlyt4z6xa/YU/jTDynM4R4W10hm2tPjy8iR1k8XhDv4/qdxe6m07NjG/By1tkmGpm1mGwho4Pr5kbAAy/Qg+NLCSdAYnnE00FQEcFOC15GFVMOW2AzDGKisReohwH9eIzHPzdYQNPRWXE= bradfitz@papag.bradfitz.com") + } + const maxCloudConfig = 32 << 10 // per compute API docs + if len(cloudConfig) > maxCloudConfig { + log.Fatalf("cloud config length of %d bytes is over %d byte limit", len(cloudConfig), maxCloudConfig) + } + + instance := &compute.Instance{ + Name: *instName, + Description: "Go Builder", + MachineType: machType, + Disks: []*compute.AttachedDisk{instanceDisk(computeService)}, + Tags: &compute.Tags{ + Items: []string{"http-server", "https-server"}, + }, + Metadata: &compute.Metadata{ + Items: []*compute.MetadataItems{ + { + Key: "user-data", + Value: cloudConfig, + }, + }, + }, + NetworkInterfaces: []*compute.NetworkInterface{ + &compute.NetworkInterface{ + AccessConfigs: []*compute.AccessConfig{ + &compute.AccessConfig{ + Type: "ONE_TO_ONE_NAT", + Name: "External NAT", + NatIP: natIP, + }, + }, + Network: prefix + "/global/networks/default", + }, + }, + ServiceAccounts: []*compute.ServiceAccount{ + { + Email: "default", + Scopes: []string{ + compute.DevstorageFull_controlScope, + compute.ComputeScope, + }, + }, + }, + } + + log.Printf("Creating instance...") + op, err := computeService.Instances.Insert(*proj, *zone, instance).Do() + if err != nil { + log.Fatalf("Failed to create instance: %v", err) + } + opName := op.Name + log.Printf("Created. Waiting on operation %v", opName) +OpLoop: + for { + time.Sleep(2 * time.Second) + op, err := computeService.ZoneOperations.Get(*proj, *zone, opName).Do() + if err != nil { + log.Fatalf("Failed to get op %s: %v", opName, err) + } + switch op.Status { + case "PENDING", "RUNNING": + log.Printf("Waiting on operation %v", opName) + continue + case "DONE": + if op.Error != nil { + for _, operr := range op.Error.Errors { + log.Printf("Error: %+v", operr) + } + log.Fatalf("Failed to start.") + } + log.Printf("Success. %+v", op) + break OpLoop + default: + log.Fatalf("Unknown status %q: %+v", op.Status, op) + } + } + + inst, err := computeService.Instances.Get(*proj, *zone, *instName).Do() + if err != nil { + log.Fatalf("Error getting instance after creation: %v", err) + } + ij, _ := json.MarshalIndent(inst, "", " ") + log.Printf("Instance: %s", ij) +} + +func instanceDisk(svc *compute.Service) *compute.AttachedDisk { + const imageURL = "https://www.googleapis.com/compute/v1/projects/coreos-cloud/global/images/coreos-stable-444-5-0-v20141016" + diskName := *instName + "-disk" + + return &compute.AttachedDisk{ + AutoDelete: true, + Boot: true, + Type: "PERSISTENT", + InitializeParams: &compute.AttachedDiskInitializeParams{ + DiskName: diskName, + SourceImage: imageURL, + DiskSizeGb: 50, + }, + } +} + +func writeCloudStorageObject(httpClient *http.Client) { + content := os.Stdin + const maxSlurp = 1 << 20 + var buf bytes.Buffer + n, err := io.CopyN(&buf, content, maxSlurp) + if err != nil && err != io.EOF { + log.Fatalf("Error reading from stdin: %v, %v", n, err) + } + contentType := http.DetectContentType(buf.Bytes()) + + req, err := http.NewRequest("PUT", "https://storage.googleapis.com/"+*writeObject, io.MultiReader(&buf, content)) + if err != nil { + log.Fatal(err) + } + req.Header.Set("x-goog-api-version", "2") + if *publicObject { + req.Header.Set("x-goog-acl", "public-read") + } + req.Header.Set("Content-Type", contentType) + res, err := httpClient.Do(req) + if err != nil { + log.Fatal(err) + } + if res.StatusCode != 200 { + res.Write(os.Stderr) + log.Fatalf("Failed.") + } + log.Printf("Success.") + os.Exit(0) +} + +type tokenCacheFile string + +func (f tokenCacheFile) Token() (*oauth2.Token, error) { + slurp, err := ioutil.ReadFile(string(f)) + if err != nil { + return nil, err + } + t := new(oauth2.Token) + if err := json.Unmarshal(slurp, t); err != nil { + return nil, err + } + return t, nil +} + +func (f tokenCacheFile) WriteToken(t *oauth2.Token) error { + jt, err := json.Marshal(t) + if err != nil { + return err + } + return ioutil.WriteFile(string(f), jt, 0600) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.key b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.key new file mode 100644 index 0000000..a15a6ab --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAt5fAjp4fTcekWUTfzsp0kyih1OYbsGL0KX1eRbSSR8Od0+9Q +62Hyny+GFwMTb4A/KU8mssoHvcceSAAbwfbxFK/+s51TobqUnORZrOoTZjkUygby +XDSK99YBbcR1Pip8vwMTm4XKuLtCigeBBdjjAQdgUO28LENGlsMnmeYkJfODVGnV +mr5Ltb9ANA8IKyTfsnHJ4iOCS/PlPbUj2q7YnoVLposUBMlgUb/CykX3mOoLb4yJ +JQyA/iST6ZxiIEj36D4yWZ5lg7YJl+UiiBQHGCnPdGyipqV06ex0heYWcaiW8LWZ +SUQ93jQ+WVCH8hT7DQO1dmsvUmXlq/JeAlwQ/QIDAQABAoIBAFFHV7JMAqPWnMYA +nezY6J81v9+XN+7xABNWM2Q8uv4WdksbigGLTXR3/680Z2hXqJ7LMeC5XJACFT/e +/Gr0vmpgOCygnCPfjGehGKpavtfksXV3edikUlnCXsOP1C//c1bFL+sMYmFCVgTx +qYdDK8yKzXNGrKYT6q5YG7IglyRNV1rsQa8lM/5taFYiD1Ck/3tQi3YIq8Lcuser +hrxsMABcQ6mi+EIvG6Xr4mfJug0dGJMHG4RG1UGFQn6RXrQq2+q53fC8ZbVUSi0j +NQ918aKFzktwv+DouKU0ME4I9toks03gM860bAL7zCbKGmwR3hfgX/TqzVCWpG9E +LDVfvekCgYEA8fk9N53jbBRmULUGEf4qWypcLGiZnNU0OeXWpbPV9aa3H0VDytA7 +8fCN2dPAVDPqlthMDdVe983NCNwp2Yo8ZimDgowyIAKhdC25s1kejuaiH9OAPj3c +0f8KbriYX4n8zNHxFwK6Ae3pQ6EqOLJVCUsziUaZX9nyKY5aZlyX6xcCgYEAwjws +K62PjC64U5wYddNLp+kNdJ4edx+a7qBb3mEgPvSFT2RO3/xafJyG8kQB30Mfstjd +bRxyUV6N0vtX1zA7VQtRUAvfGCecpMo+VQZzcHXKzoRTnQ7eZg4Lmj5fQ9tOAKAo +QCVBoSW/DI4PZL26CAMDcAba4Pa22ooLapoRIQsCgYA6pIfkkbxLNkpxpt2YwLtt +Kr/590O7UaR9n6k8sW/aQBRDXNsILR1KDl2ifAIxpf9lnXgZJiwE7HiTfCAcW7c1 +nzwDCI0hWuHcMTS/NYsFYPnLsstyyjVZI3FY0h4DkYKV9Q9z3zJLQ2hz/nwoD3gy +b2pHC7giFcTts1VPV4Nt8wKBgHeFn4ihHJweg76vZz3Z78w7VNRWGFklUalVdDK7 +gaQ7w2y/ROn/146mo0OhJaXFIFRlrpvdzVrU3GDf2YXJYDlM5ZRkObwbZADjksev +WInzcgDy3KDg7WnPasRXbTfMU4t/AkW2p1QKbi3DnSVYuokDkbH2Beo45vxDxhKr +C69RAoGBAIyo3+OJenoZmoNzNJl2WPW5MeBUzSh8T/bgyjFTdqFHF5WiYRD/lfHj +x9Glyw2nutuT4hlOqHvKhgTYdDMsF2oQ72fe3v8Q5FU7FuKndNPEAyvKNXZaShVA +hnlhv5DjXKb0wFWnt5PCCiQLtzG0yyHaITrrEme7FikkIcTxaX/Y +-----END RSA PRIVATE KEY----- diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.pem b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.pem new file mode 100644 index 0000000..3a323e7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.pem @@ -0,0 +1,26 @@ +-----BEGIN CERTIFICATE----- +MIIEWjCCA0KgAwIBAgIJALfRlWsI8YQHMA0GCSqGSIb3DQEBBQUAMHsxCzAJBgNV +BAYTAlVTMQswCQYDVQQIEwJDQTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEUMBIG +A1UEChMLQnJhZGZpdHppbmMxEjAQBgNVBAMTCWxvY2FsaG9zdDEdMBsGCSqGSIb3 +DQEJARYOYnJhZEBkYW5nYS5jb20wHhcNMTQwNzE1MjA0NjA1WhcNMTcwNTA0MjA0 +NjA1WjB7MQswCQYDVQQGEwJVUzELMAkGA1UECBMCQ0ExFjAUBgNVBAcTDVNhbiBG +cmFuY2lzY28xFDASBgNVBAoTC0JyYWRmaXR6aW5jMRIwEAYDVQQDEwlsb2NhbGhv +c3QxHTAbBgkqhkiG9w0BCQEWDmJyYWRAZGFuZ2EuY29tMIIBIjANBgkqhkiG9w0B +AQEFAAOCAQ8AMIIBCgKCAQEAt5fAjp4fTcekWUTfzsp0kyih1OYbsGL0KX1eRbSS +R8Od0+9Q62Hyny+GFwMTb4A/KU8mssoHvcceSAAbwfbxFK/+s51TobqUnORZrOoT +ZjkUygbyXDSK99YBbcR1Pip8vwMTm4XKuLtCigeBBdjjAQdgUO28LENGlsMnmeYk +JfODVGnVmr5Ltb9ANA8IKyTfsnHJ4iOCS/PlPbUj2q7YnoVLposUBMlgUb/CykX3 +mOoLb4yJJQyA/iST6ZxiIEj36D4yWZ5lg7YJl+UiiBQHGCnPdGyipqV06ex0heYW +caiW8LWZSUQ93jQ+WVCH8hT7DQO1dmsvUmXlq/JeAlwQ/QIDAQABo4HgMIHdMB0G +A1UdDgQWBBRcAROthS4P4U7vTfjByC569R7E6DCBrQYDVR0jBIGlMIGigBRcAROt +hS4P4U7vTfjByC569R7E6KF/pH0wezELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNB +MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRQwEgYDVQQKEwtCcmFkZml0emluYzES +MBAGA1UEAxMJbG9jYWxob3N0MR0wGwYJKoZIhvcNAQkBFg5icmFkQGRhbmdhLmNv +bYIJALfRlWsI8YQHMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAG6h +U9f9sNH0/6oBbGGy2EVU0UgITUQIrFWo9rFkrW5k/XkDjQm+3lzjT0iGR4IxE/Ao +eU6sQhua7wrWeFEn47GL98lnCsJdD7oZNhFmQ95Tb/LnDUjs5Yj9brP0NWzXfYU4 +UK2ZnINJRcJpB8iRCaCxE8DdcUF0XqIEq6pA272snoLmiXLMvNl3kYEdm+je6voD +58SNVEUsztzQyXmJEhCpwVI0A6QCjzXj+qvpmw3ZZHi8JwXei8ZZBLTSFBki8Z7n +sH9BBH38/SzUmAN4QHSPy1gjqm00OAE8NaYDkh/bzE4d7mLGGMWp/WE3KPSu82HF +kPe6XoSbiLm/kxk32T0= +-----END CERTIFICATE----- diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.srl b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.srl new file mode 100644 index 0000000..6db3891 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/rootCA.srl @@ -0,0 +1 @@ +E2CE26BF3285059C diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/server.crt b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/server.crt new file mode 100644 index 0000000..c59059b --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/server.crt @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDPjCCAiYCCQDizia/MoUFnDANBgkqhkiG9w0BAQUFADB7MQswCQYDVQQGEwJV +UzELMAkGA1UECBMCQ0ExFjAUBgNVBAcTDVNhbiBGcmFuY2lzY28xFDASBgNVBAoT +C0JyYWRmaXR6aW5jMRIwEAYDVQQDEwlsb2NhbGhvc3QxHTAbBgkqhkiG9w0BCQEW +DmJyYWRAZGFuZ2EuY29tMB4XDTE0MDcxNTIwNTAyN1oXDTE1MTEyNzIwNTAyN1ow +RzELMAkGA1UEBhMCVVMxCzAJBgNVBAgTAkNBMQswCQYDVQQHEwJTRjEeMBwGA1UE +ChMVYnJhZGZpdHogaHR0cDIgc2VydmVyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDifx2l +gZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1LmJ4c2 +dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nefb3HL +A7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55mjws +/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/fz88 +F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABMA0GCSqGSIb3DQEBBQUAA4IB +AQC0zL+n/YpRZOdulSu9tS8FxrstXqGWoxfe+vIUgqfMZ5+0MkjJ/vW0FqlLDl2R +rn4XaR3e7FmWkwdDVbq/UB6lPmoAaFkCgh9/5oapMaclNVNnfF3fjCJfRr+qj/iD +EmJStTIN0ZuUjAlpiACmfnpEU55PafT5Zx+i1yE4FGjw8bJpFoyD4Hnm54nGjX19 +KeCuvcYFUPnBm3lcL0FalF2AjqV02WTHYNQk7YF/oeO7NKBoEgvGvKG3x+xaOeBI +dwvdq175ZsGul30h+QjrRlXhH/twcuaT3GSdoysDl9cCYE8f1Mk8PD6gan3uBCJU +90p6/CbU71bGbfpM2PHot2fm +-----END CERTIFICATE----- diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/server.key b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/server.key new file mode 100644 index 0000000..f329c14 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2demo/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAs1Y9CyLFrdL8VQWN1WaifDqaZFnoqjHhCMlc1TfG2zA+InDi +fx2lgZD3o8FeNnAcfM2sPlk3+ZleOYw9P/CklFVDlvqmpCv9ss/BEp/dDaWvy1Lm +J4c2dbQJfmTxn7CV1H3TsVJvKdwFmdoABb41NoBp6+NNO7OtDyhbIMiCI0pL3Nef +b3HLA7hIMo3DYbORTtJLTIH9W8YKrEWL0lwHLrYFx/UdutZnv+HjdmO6vCN4na55 +mjws/vjKQUmc7xeY7Xe20xDEG2oDKVkL2eD7FfyrYMS3rO1ExP2KSqlXYG/1S9I/ +fz88F0GK7HX55b5WjZCl2J3ERVdnv/0MQv+sYQIDAQABAoIBADQ2spUwbY+bcz4p +3M66ECrNQTBggP40gYl2XyHxGGOu2xhZ94f9ELf1hjRWU2DUKWco1rJcdZClV6q3 +qwmXvcM2Q/SMS8JW0ImkNVl/0/NqPxGatEnj8zY30d/L8hGFb0orzFu/XYA5gCP4 +NbN2WrXgk3ZLeqwcNxHHtSiJWGJ/fPyeDWAu/apy75u9Xf2GlzBZmV6HYD9EfK80 +LTlI60f5FO487CrJnboL7ovPJrIHn+k05xRQqwma4orpz932rTXnTjs9Lg6KtbQN +a7PrqfAntIISgr11a66Mng3IYH1lYqJsWJJwX/xHT4WLEy0EH4/0+PfYemJekz2+ +Co62drECgYEA6O9zVJZXrLSDsIi54cfxA7nEZWm5CAtkYWeAHa4EJ+IlZ7gIf9sL +W8oFcEfFGpvwVqWZ+AsQ70dsjXAv3zXaG0tmg9FtqWp7pzRSMPidifZcQwWkKeTO +gJnFmnVyed8h6GfjTEu4gxo1/S5U0V+mYSha01z5NTnN6ltKx1Or3b0CgYEAxRgm +S30nZxnyg/V7ys61AZhst1DG2tkZXEMcA7dYhabMoXPJAP/EfhlWwpWYYUs/u0gS +Wwmf5IivX5TlYScgmkvb/NYz0u4ZmOXkLTnLPtdKKFXhjXJcHjUP67jYmOxNlJLp +V4vLRnFxTpffAV+OszzRxsXX6fvruwZBANYJeXUCgYBVouLFsFgfWGYp2rpr9XP4 +KK25kvrBqF6JKOIDB1zjxNJ3pUMKrl8oqccCFoCyXa4oTM2kUX0yWxHfleUjrMq4 +yimwQKiOZmV7fVLSSjSw6e/VfBd0h3gb82ygcplZkN0IclkwTY5SNKqwn/3y07V5 +drqdhkrgdJXtmQ6O5YYECQKBgATERcDToQ1USlI4sKrB/wyv1AlG8dg/IebiVJ4e +ZAyvcQmClFzq0qS+FiQUnB/WQw9TeeYrwGs1hxBHuJh16srwhLyDrbMvQP06qh8R +48F8UXXSRec22dV9MQphaROhu2qZdv1AC0WD3tqov6L33aqmEOi+xi8JgbT/PLk5 +c/c1AoGBAI1A/02ryksW6/wc7/6SP2M2rTy4m1sD/GnrTc67EHnRcVBdKO6qH2RY +nqC8YcveC2ZghgPTDsA3VGuzuBXpwY6wTyV99q6jxQJ6/xcrD9/NUG6Uwv/xfCxl +IJLeBYEqQundSSny3VtaAUK8Ul1nxpTvVRNwtcyWTo8RHAAyNPWd +-----END RSA PRIVATE KEY----- diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2i/README.md b/Godeps/_workspace/src/github.com/bradfitz/http2/h2i/README.md new file mode 100644 index 0000000..3ceecb7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2i/README.md @@ -0,0 +1,97 @@ +# h2i + +**h2i** is an interactive HTTP/2 ("h2") console debugger. Miss the good ol' +days of telnetting to your HTTP/1.n servers? We're bringing you +back. + +Features: +- send raw HTTP/2 frames + - PING + - SETTINGS + - HEADERS + - etc +- type in HTTP/1.n and have it auto-HPACK/frame-ify it for HTTP/2 +- pretty print all received HTTP/2 frames from the peer (including HPACK decoding) +- tab completion of commands, options + +Not yet features, but soon: +- unnecessary CONTINUATION frames on short boundaries, to test peer implementations +- request bodies (DATA frames) +- send invalid frames for testing server implementations (supported by underlying Framer) + +Later: +- act like a server + +## Installation + +``` +$ go get github.com/bradfitz/http2/h2i +$ h2i +``` + +## Demo + +``` +$ h2i +Usage: h2i + + -insecure + Whether to skip TLS cert validation + -nextproto string + Comma-separated list of NPN/ALPN protocol names to negotiate. (default "h2,h2-14") + +$ h2i google.com +Connecting to google.com:443 ... +Connected to 74.125.224.41:443 +Negotiated protocol "h2-14" +[FrameHeader SETTINGS len=18] + [MAX_CONCURRENT_STREAMS = 100] + [INITIAL_WINDOW_SIZE = 1048576] + [MAX_FRAME_SIZE = 16384] +[FrameHeader WINDOW_UPDATE len=4] + Window-Increment = 983041 + +h2i> PING h2iSayHI +[FrameHeader PING flags=ACK len=8] + Data = "h2iSayHI" +h2i> headers +(as HTTP/1.1)> GET / HTTP/1.1 +(as HTTP/1.1)> Host: ip.appspot.com +(as HTTP/1.1)> User-Agent: h2i/brad-n-blake +(as HTTP/1.1)> +Opening Stream-ID 1: + :authority = ip.appspot.com + :method = GET + :path = / + :scheme = https + user-agent = h2i/brad-n-blake +[FrameHeader HEADERS flags=END_HEADERS stream=1 len=77] + :status = "200" + alternate-protocol = "443:quic,p=1" + content-length = "15" + content-type = "text/html" + date = "Fri, 01 May 2015 23:06:56 GMT" + server = "Google Frontend" +[FrameHeader DATA flags=END_STREAM stream=1 len=15] + "173.164.155.78\n" +[FrameHeader PING len=8] + Data = "\x00\x00\x00\x00\x00\x00\x00\x00" +h2i> ping +[FrameHeader PING flags=ACK len=8] + Data = "h2i_ping" +h2i> ping +[FrameHeader PING flags=ACK len=8] + Data = "h2i_ping" +h2i> ping +[FrameHeader GOAWAY len=22] + Last-Stream-ID = 1; Error-Code = PROTOCOL_ERROR (1) + +ReadFrame: EOF +``` + +## Status + +Quick few hour hack. So much yet to do. Feel free to file issues for +bugs or wishlist items, but [@bmizerany](https://github.com/bmizerany/) +and I aren't yet accepting pull requests until things settle down. + diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/h2i/h2i.go b/Godeps/_workspace/src/github.com/bradfitz/http2/h2i/h2i.go new file mode 100644 index 0000000..67b874d --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/h2i/h2i.go @@ -0,0 +1,489 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +/* +The h2i command is an interactive HTTP/2 console. + +Usage: + $ h2i [flags] + +Interactive commands in the console: (all parts case-insensitive) + + ping [data] + settings ack + settings FOO=n BAR=z + headers (open a new stream by typing HTTP/1.1) +*/ +package main + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "flag" + "fmt" + "io" + "log" + "net" + "net/http" + "os" + "regexp" + "strconv" + "strings" + + "github.com/bradfitz/http2" + "github.com/bradfitz/http2/hpack" + "golang.org/x/crypto/ssh/terminal" +) + +// Flags +var ( + flagNextProto = flag.String("nextproto", "h2,h2-14", "Comma-separated list of NPN/ALPN protocol names to negotiate.") + flagInsecure = flag.Bool("insecure", false, "Whether to skip TLS cert validation") +) + +type command struct { + run func(*h2i, []string) error // required + + // complete optionally specifies tokens (case-insensitive) which are + // valid for this subcommand. + complete func() []string +} + +var commands = map[string]command{ + "ping": command{run: (*h2i).cmdPing}, + "settings": command{ + run: (*h2i).cmdSettings, + complete: func() []string { + return []string{ + "ACK", + http2.SettingHeaderTableSize.String(), + http2.SettingEnablePush.String(), + http2.SettingMaxConcurrentStreams.String(), + http2.SettingInitialWindowSize.String(), + http2.SettingMaxFrameSize.String(), + http2.SettingMaxHeaderListSize.String(), + } + }, + }, + "quit": command{run: (*h2i).cmdQuit}, + "headers": command{run: (*h2i).cmdHeaders}, +} + +func usage() { + fmt.Fprintf(os.Stderr, "Usage: h2i \n\n") + flag.PrintDefaults() + os.Exit(1) +} + +// withPort adds ":443" if another port isn't already present. +func withPort(host string) string { + if _, _, err := net.SplitHostPort(host); err != nil { + return net.JoinHostPort(host, "443") + } + return host +} + +// h2i is the app's state. +type h2i struct { + host string + tc *tls.Conn + framer *http2.Framer + term *terminal.Terminal + + // owned by the command loop: + streamID uint32 + hbuf bytes.Buffer + henc *hpack.Encoder + + // owned by the readFrames loop: + peerSetting map[http2.SettingID]uint32 + hdec *hpack.Decoder +} + +func main() { + flag.Usage = usage + flag.Parse() + if flag.NArg() != 1 { + usage() + } + log.SetFlags(0) + + host := flag.Arg(0) + app := &h2i{ + host: host, + peerSetting: make(map[http2.SettingID]uint32), + } + app.henc = hpack.NewEncoder(&app.hbuf) + + if err := app.Main(); err != nil { + if app.term != nil { + app.logf("%v\n", err) + } else { + fmt.Fprintf(os.Stderr, "%v\n", err) + } + os.Exit(1) + } + fmt.Fprintf(os.Stdout, "\n") +} + +func (app *h2i) Main() error { + cfg := &tls.Config{ + ServerName: app.host, + NextProtos: strings.Split(*flagNextProto, ","), + InsecureSkipVerify: *flagInsecure, + } + + hostAndPort := withPort(app.host) + log.Printf("Connecting to %s ...", hostAndPort) + tc, err := tls.Dial("tcp", hostAndPort, cfg) + if err != nil { + return fmt.Errorf("Error dialing %s: %v", withPort(app.host), err) + } + log.Printf("Connected to %v", tc.RemoteAddr()) + defer tc.Close() + + if err := tc.Handshake(); err != nil { + return fmt.Errorf("TLS handshake: %v", err) + } + if !*flagInsecure { + if err := tc.VerifyHostname(app.host); err != nil { + return fmt.Errorf("VerifyHostname: %v", err) + } + } + state := tc.ConnectionState() + log.Printf("Negotiated protocol %q", state.NegotiatedProtocol) + if !state.NegotiatedProtocolIsMutual || state.NegotiatedProtocol == "" { + return fmt.Errorf("Could not negotiate protocol mutually") + } + + if _, err := io.WriteString(tc, http2.ClientPreface); err != nil { + return err + } + + app.framer = http2.NewFramer(tc, tc) + + oldState, err := terminal.MakeRaw(0) + if err != nil { + return err + } + defer terminal.Restore(0, oldState) + + var screen = struct { + io.Reader + io.Writer + }{os.Stdin, os.Stdout} + + app.term = terminal.NewTerminal(screen, "h2i> ") + lastWord := regexp.MustCompile(`.+\W(\w+)$`) + app.term.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { + if key != '\t' { + return + } + if pos != len(line) { + // TODO: we're being lazy for now, only supporting tab completion at the end. + return + } + // Auto-complete for the command itself. + if !strings.Contains(line, " ") { + var name string + name, _, ok = lookupCommand(line) + if !ok { + return + } + return name, len(name), true + } + _, c, ok := lookupCommand(line[:strings.IndexByte(line, ' ')]) + if !ok || c.complete == nil { + return + } + if strings.HasSuffix(line, " ") { + app.logf("%s", strings.Join(c.complete(), " ")) + return line, pos, true + } + m := lastWord.FindStringSubmatch(line) + if m == nil { + return line, len(line), true + } + soFar := m[1] + var match []string + for _, cand := range c.complete() { + if len(soFar) > len(cand) || !strings.EqualFold(cand[:len(soFar)], soFar) { + continue + } + match = append(match, cand) + } + if len(match) == 0 { + return + } + if len(match) > 1 { + // TODO: auto-complete any common prefix + app.logf("%s", strings.Join(match, " ")) + return line, pos, true + } + newLine = line[:len(line)-len(soFar)] + match[0] + return newLine, len(newLine), true + + } + + errc := make(chan error, 2) + go func() { errc <- app.readFrames() }() + go func() { errc <- app.readConsole() }() + return <-errc +} + +func (app *h2i) logf(format string, args ...interface{}) { + fmt.Fprintf(app.term, format+"\n", args...) +} + +func (app *h2i) readConsole() error { + for { + line, err := app.term.ReadLine() + if err == io.EOF { + return nil + } + if err != nil { + return fmt.Errorf("terminal.ReadLine: %v", err) + } + f := strings.Fields(line) + if len(f) == 0 { + continue + } + cmd, args := f[0], f[1:] + if _, c, ok := lookupCommand(cmd); ok { + err = c.run(app, args) + } else { + app.logf("Unknown command %q", line) + } + if err == errExitApp { + return nil + } + if err != nil { + return err + } + } +} + +func lookupCommand(prefix string) (name string, c command, ok bool) { + prefix = strings.ToLower(prefix) + if c, ok = commands[prefix]; ok { + return prefix, c, ok + } + + for full, candidate := range commands { + if strings.HasPrefix(full, prefix) { + if c.run != nil { + return "", command{}, false // ambiguous + } + c = candidate + name = full + } + } + return name, c, c.run != nil +} + +var errExitApp = errors.New("internal sentinel error value to quit the console reading loop") + +func (a *h2i) cmdQuit(args []string) error { + if len(args) > 0 { + a.logf("the QUIT command takes no argument") + return nil + } + return errExitApp +} + +func (a *h2i) cmdSettings(args []string) error { + if len(args) == 1 && strings.EqualFold(args[0], "ACK") { + return a.framer.WriteSettingsAck() + } + var settings []http2.Setting + for _, arg := range args { + if strings.EqualFold(arg, "ACK") { + a.logf("Error: ACK must be only argument with the SETTINGS command") + return nil + } + eq := strings.Index(arg, "=") + if eq == -1 { + a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg) + return nil + } + sid, ok := settingByName(arg[:eq]) + if !ok { + a.logf("Error: unknown setting name %q", arg[:eq]) + return nil + } + val, err := strconv.ParseUint(arg[eq+1:], 10, 32) + if err != nil { + a.logf("Error: invalid argument %q (expected SETTING_NAME=nnnn)", arg) + return nil + } + settings = append(settings, http2.Setting{ + ID: sid, + Val: uint32(val), + }) + } + a.logf("Sending: %v", settings) + return a.framer.WriteSettings(settings...) +} + +func settingByName(name string) (http2.SettingID, bool) { + for _, sid := range [...]http2.SettingID{ + http2.SettingHeaderTableSize, + http2.SettingEnablePush, + http2.SettingMaxConcurrentStreams, + http2.SettingInitialWindowSize, + http2.SettingMaxFrameSize, + http2.SettingMaxHeaderListSize, + } { + if strings.EqualFold(sid.String(), name) { + return sid, true + } + } + return 0, false +} + +func (app *h2i) cmdPing(args []string) error { + if len(args) > 1 { + app.logf("invalid PING usage: only accepts 0 or 1 args") + return nil // nil means don't end the program + } + var data [8]byte + if len(args) == 1 { + copy(data[:], args[0]) + } else { + copy(data[:], "h2i_ping") + } + return app.framer.WritePing(false, data) +} + +func (app *h2i) cmdHeaders(args []string) error { + if len(args) > 0 { + app.logf("Error: HEADERS doesn't yet take arguments.") + // TODO: flags for restricting window size, to force CONTINUATION + // frames. + return nil + } + var h1req bytes.Buffer + app.term.SetPrompt("(as HTTP/1.1)> ") + defer app.term.SetPrompt("h2i> ") + for { + line, err := app.term.ReadLine() + if err != nil { + return err + } + h1req.WriteString(line) + h1req.WriteString("\r\n") + if line == "" { + break + } + } + req, err := http.ReadRequest(bufio.NewReader(&h1req)) + if err != nil { + app.logf("Invalid HTTP/1.1 request: %v", err) + return nil + } + if app.streamID == 0 { + app.streamID = 1 + } else { + app.streamID += 2 + } + app.logf("Opening Stream-ID %d:", app.streamID) + hbf := app.encodeHeaders(req) + if len(hbf) > 16<<10 { + app.logf("TODO: h2i doesn't yet write CONTINUATION frames. Copy it from transport.go") + return nil + } + return app.framer.WriteHeaders(http2.HeadersFrameParam{ + StreamID: app.streamID, + BlockFragment: hbf, + EndStream: req.Method == "GET" || req.Method == "HEAD", // good enough for now + EndHeaders: true, // for now + }) +} + +func (app *h2i) readFrames() error { + for { + f, err := app.framer.ReadFrame() + if err != nil { + return fmt.Errorf("ReadFrame: %v", err) + } + app.logf("%v", f) + switch f := f.(type) { + case *http2.PingFrame: + app.logf(" Data = %q", f.Data) + case *http2.SettingsFrame: + f.ForeachSetting(func(s http2.Setting) error { + app.logf(" %v", s) + app.peerSetting[s.ID] = s.Val + return nil + }) + case *http2.WindowUpdateFrame: + app.logf(" Window-Increment = %v\n", f.Increment) + case *http2.GoAwayFrame: + app.logf(" Last-Stream-ID = %d; Error-Code = %v (%d)\n", f.LastStreamID, f.ErrCode, f.ErrCode) + case *http2.DataFrame: + app.logf(" %q", f.Data()) + case *http2.HeadersFrame: + if f.HasPriority() { + app.logf(" PRIORITY = %v", f.Priority) + } + if app.hdec == nil { + // TODO: if the user uses h2i to send a SETTINGS frame advertising + // something larger, we'll need to respect SETTINGS_HEADER_TABLE_SIZE + // and stuff here instead of using the 4k default. But for now: + tableSize := uint32(4 << 10) + app.hdec = hpack.NewDecoder(tableSize, app.onNewHeaderField) + } + app.hdec.Write(f.HeaderBlockFragment()) + } + } +} + +// called from readLoop +func (app *h2i) onNewHeaderField(f hpack.HeaderField) { + if f.Sensitive { + app.logf(" %s = %q (SENSITIVE)", f.Name, f.Value) + } + app.logf(" %s = %q", f.Name, f.Value) +} + +func (app *h2i) encodeHeaders(req *http.Request) []byte { + app.hbuf.Reset() + + // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go + host := req.Host + if host == "" { + host = req.URL.Host + } + + path := req.URL.Path + if path == "" { + path = "/" + } + + app.writeHeader(":authority", host) // probably not right for all sites + app.writeHeader(":method", req.Method) + app.writeHeader(":path", path) + app.writeHeader(":scheme", "https") + + for k, vv := range req.Header { + lowKey := strings.ToLower(k) + if lowKey == "host" { + continue + } + for _, v := range vv { + app.writeHeader(lowKey, v) + } + } + return app.hbuf.Bytes() +} + +func (app *h2i) writeHeader(name, value string) { + app.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) + app.logf(" %s = %s", name, value) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/headermap.go b/Godeps/_workspace/src/github.com/bradfitz/http2/headermap.go new file mode 100644 index 0000000..67c7c48 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/headermap.go @@ -0,0 +1,80 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "net/http" + "strings" +) + +var ( + commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case + commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case +) + +func init() { + for _, v := range []string{ + "accept", + "accept-charset", + "accept-encoding", + "accept-language", + "accept-ranges", + "age", + "access-control-allow-origin", + "allow", + "authorization", + "cache-control", + "content-disposition", + "content-encoding", + "content-language", + "content-length", + "content-location", + "content-range", + "content-type", + "cookie", + "date", + "etag", + "expect", + "expires", + "from", + "host", + "if-match", + "if-modified-since", + "if-none-match", + "if-unmodified-since", + "last-modified", + "link", + "location", + "max-forwards", + "proxy-authenticate", + "proxy-authorization", + "range", + "referer", + "refresh", + "retry-after", + "server", + "set-cookie", + "strict-transport-security", + "transfer-encoding", + "user-agent", + "vary", + "via", + "www-authenticate", + } { + chk := http.CanonicalHeaderKey(v) + commonLowerHeader[chk] = v + commonCanonHeader[v] = chk + } +} + +func lowerHeader(v string) string { + if s, ok := commonLowerHeader[v]; ok { + return s + } + return strings.ToLower(v) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/encode.go b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/encode.go new file mode 100644 index 0000000..19bd9f4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/encode.go @@ -0,0 +1,252 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package hpack + +import ( + "io" +) + +const ( + uint32Max = ^uint32(0) + initialHeaderTableSize = 4096 +) + +type Encoder struct { + dynTab dynamicTable + // minSize is the minimum table size set by + // SetMaxDynamicTableSize after the previous Header Table Size + // Update. + minSize uint32 + // maxSizeLimit is the maximum table size this encoder + // supports. This will protect the encoder from too large + // size. + maxSizeLimit uint32 + // tableSizeUpdate indicates whether "Header Table Size + // Update" is required. + tableSizeUpdate bool + w io.Writer + buf []byte +} + +// NewEncoder returns a new Encoder which performs HPACK encoding. An +// encoded data is written to w. +func NewEncoder(w io.Writer) *Encoder { + e := &Encoder{ + minSize: uint32Max, + maxSizeLimit: initialHeaderTableSize, + tableSizeUpdate: false, + w: w, + } + e.dynTab.setMaxSize(initialHeaderTableSize) + return e +} + +// WriteField encodes f into a single Write to e's underlying Writer. +// This function may also produce bytes for "Header Table Size Update" +// if necessary. If produced, it is done before encoding f. +func (e *Encoder) WriteField(f HeaderField) error { + e.buf = e.buf[:0] + + if e.tableSizeUpdate { + e.tableSizeUpdate = false + if e.minSize < e.dynTab.maxSize { + e.buf = appendTableSize(e.buf, e.minSize) + } + e.minSize = uint32Max + e.buf = appendTableSize(e.buf, e.dynTab.maxSize) + } + + idx, nameValueMatch := e.searchTable(f) + if nameValueMatch { + e.buf = appendIndexed(e.buf, idx) + } else { + indexing := e.shouldIndex(f) + if indexing { + e.dynTab.add(f) + } + + if idx == 0 { + e.buf = appendNewName(e.buf, f, indexing) + } else { + e.buf = appendIndexedName(e.buf, f, idx, indexing) + } + } + n, err := e.w.Write(e.buf) + if err == nil && n != len(e.buf) { + err = io.ErrShortWrite + } + return err +} + +// searchTable searches f in both stable and dynamic header tables. +// The static header table is searched first. Only when there is no +// exact match for both name and value, the dynamic header table is +// then searched. If there is no match, i is 0. If both name and value +// match, i is the matched index and nameValueMatch becomes true. If +// only name matches, i points to that index and nameValueMatch +// becomes false. +func (e *Encoder) searchTable(f HeaderField) (i uint64, nameValueMatch bool) { + for idx, hf := range staticTable { + if !constantTimeStringCompare(hf.Name, f.Name) { + continue + } + if i == 0 { + i = uint64(idx + 1) + } + if f.Sensitive { + continue + } + if !constantTimeStringCompare(hf.Value, f.Value) { + continue + } + i = uint64(idx + 1) + nameValueMatch = true + return + } + + j, nameValueMatch := e.dynTab.search(f) + if nameValueMatch || (i == 0 && j != 0) { + i = j + uint64(len(staticTable)) + } + return +} + +// SetMaxDynamicTableSize changes the dynamic header table size to v. +// The actual size is bounded by the value passed to +// SetMaxDynamicTableSizeLimit. +func (e *Encoder) SetMaxDynamicTableSize(v uint32) { + if v > e.maxSizeLimit { + v = e.maxSizeLimit + } + if v < e.minSize { + e.minSize = v + } + e.tableSizeUpdate = true + e.dynTab.setMaxSize(v) +} + +// SetMaxDynamicTableSizeLimit changes the maximum value that can be +// specified in SetMaxDynamicTableSize to v. By default, it is set to +// 4096, which is the same size of the default dynamic header table +// size described in HPACK specification. If the current maximum +// dynamic header table size is strictly greater than v, "Header Table +// Size Update" will be done in the next WriteField call and the +// maximum dynamic header table size is truncated to v. +func (e *Encoder) SetMaxDynamicTableSizeLimit(v uint32) { + e.maxSizeLimit = v + if e.dynTab.maxSize > v { + e.tableSizeUpdate = true + e.dynTab.setMaxSize(v) + } +} + +// shouldIndex reports whether f should be indexed. +func (e *Encoder) shouldIndex(f HeaderField) bool { + return !f.Sensitive && f.size() <= e.dynTab.maxSize +} + +// appendIndexed appends index i, as encoded in "Indexed Header Field" +// representation, to dst and returns the extended buffer. +func appendIndexed(dst []byte, i uint64) []byte { + first := len(dst) + dst = appendVarInt(dst, 7, i) + dst[first] |= 0x80 + return dst +} + +// appendNewName appends f, as encoded in one of "Literal Header field +// - New Name" representation variants, to dst and returns the +// extended buffer. +// +// If f.Sensitive is true, "Never Indexed" representation is used. If +// f.Sensitive is false and indexing is true, "Inremental Indexing" +// representation is used. +func appendNewName(dst []byte, f HeaderField, indexing bool) []byte { + dst = append(dst, encodeTypeByte(indexing, f.Sensitive)) + dst = appendHpackString(dst, f.Name) + return appendHpackString(dst, f.Value) +} + +// appendIndexedName appends f and index i referring indexed name +// entry, as encoded in one of "Literal Header field - Indexed Name" +// representation variants, to dst and returns the extended buffer. +// +// If f.Sensitive is true, "Never Indexed" representation is used. If +// f.Sensitive is false and indexing is true, "Incremental Indexing" +// representation is used. +func appendIndexedName(dst []byte, f HeaderField, i uint64, indexing bool) []byte { + first := len(dst) + var n byte + if indexing { + n = 6 + } else { + n = 4 + } + dst = appendVarInt(dst, n, i) + dst[first] |= encodeTypeByte(indexing, f.Sensitive) + return appendHpackString(dst, f.Value) +} + +// appendTableSize appends v, as encoded in "Header Table Size Update" +// representation, to dst and returns the extended buffer. +func appendTableSize(dst []byte, v uint32) []byte { + first := len(dst) + dst = appendVarInt(dst, 5, uint64(v)) + dst[first] |= 0x20 + return dst +} + +// appendVarInt appends i, as encoded in variable integer form using n +// bit prefix, to dst and returns the extended buffer. +// +// See +// http://http2.github.io/http2-spec/compression.html#integer.representation +func appendVarInt(dst []byte, n byte, i uint64) []byte { + k := uint64((1 << n) - 1) + if i < k { + return append(dst, byte(i)) + } + dst = append(dst, byte(k)) + i -= k + for ; i >= 128; i >>= 7 { + dst = append(dst, byte(0x80|(i&0x7f))) + } + return append(dst, byte(i)) +} + +// appendHpackString appends s, as encoded in "String Literal" +// representation, to dst and returns the the extended buffer. +// +// s will be encoded in Huffman codes only when it produces strictly +// shorter byte string. +func appendHpackString(dst []byte, s string) []byte { + huffmanLength := HuffmanEncodeLength(s) + if huffmanLength < uint64(len(s)) { + first := len(dst) + dst = appendVarInt(dst, 7, huffmanLength) + dst = AppendHuffmanString(dst, s) + dst[first] |= 0x80 + } else { + dst = appendVarInt(dst, 7, uint64(len(s))) + dst = append(dst, s...) + } + return dst +} + +// encodeTypeByte returns type byte. If sensitive is true, type byte +// for "Never Indexed" representation is returned. If sensitive is +// false and indexing is true, type byte for "Incremental Indexing" +// representation is returned. Otherwise, type byte for "Without +// Indexing" is returned. +func encodeTypeByte(indexing, sensitive bool) byte { + if sensitive { + return 0x10 + } + if indexing { + return 0x40 + } + return 0 +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/encode_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/encode_test.go new file mode 100644 index 0000000..dce66c9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/encode_test.go @@ -0,0 +1,331 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package hpack + +import ( + "bytes" + "encoding/hex" + "reflect" + "strings" + "testing" +) + +func TestEncoderTableSizeUpdate(t *testing.T) { + tests := []struct { + size1, size2 uint32 + wantHex string + }{ + // Should emit 2 table size updates (2048 and 4096) + {2048, 4096, "3fe10f 3fe11f 82"}, + + // Should emit 1 table size update (2048) + {16384, 2048, "3fe10f 82"}, + } + for _, tt := range tests { + var buf bytes.Buffer + e := NewEncoder(&buf) + e.SetMaxDynamicTableSize(tt.size1) + e.SetMaxDynamicTableSize(tt.size2) + if err := e.WriteField(pair(":method", "GET")); err != nil { + t.Fatal(err) + } + want := removeSpace(tt.wantHex) + if got := hex.EncodeToString(buf.Bytes()); got != want { + t.Errorf("e.SetDynamicTableSize %v, %v = %q; want %q", tt.size1, tt.size2, got, want) + } + } +} + +func TestEncoderWriteField(t *testing.T) { + var buf bytes.Buffer + e := NewEncoder(&buf) + var got []HeaderField + d := NewDecoder(4<<10, func(f HeaderField) { + got = append(got, f) + }) + + tests := []struct { + hdrs []HeaderField + }{ + {[]HeaderField{ + pair(":method", "GET"), + pair(":scheme", "http"), + pair(":path", "/"), + pair(":authority", "www.example.com"), + }}, + {[]HeaderField{ + pair(":method", "GET"), + pair(":scheme", "http"), + pair(":path", "/"), + pair(":authority", "www.example.com"), + pair("cache-control", "no-cache"), + }}, + {[]HeaderField{ + pair(":method", "GET"), + pair(":scheme", "https"), + pair(":path", "/index.html"), + pair(":authority", "www.example.com"), + pair("custom-key", "custom-value"), + }}, + } + for i, tt := range tests { + buf.Reset() + got = got[:0] + for _, hf := range tt.hdrs { + if err := e.WriteField(hf); err != nil { + t.Fatal(err) + } + } + _, err := d.Write(buf.Bytes()) + if err != nil { + t.Errorf("%d. Decoder Write = %v", i, err) + } + if !reflect.DeepEqual(got, tt.hdrs) { + t.Errorf("%d. Decoded %+v; want %+v", i, got, tt.hdrs) + } + } +} + +func TestEncoderSearchTable(t *testing.T) { + e := NewEncoder(nil) + + e.dynTab.add(pair("foo", "bar")) + e.dynTab.add(pair("blake", "miz")) + e.dynTab.add(pair(":method", "GET")) + + tests := []struct { + hf HeaderField + wantI uint64 + wantMatch bool + }{ + // Name and Value match + {pair("foo", "bar"), uint64(len(staticTable) + 3), true}, + {pair("blake", "miz"), uint64(len(staticTable) + 2), true}, + {pair(":method", "GET"), 2, true}, + + // Only name match because Sensitive == true + {HeaderField{":method", "GET", true}, 2, false}, + + // Only Name matches + {pair("foo", "..."), uint64(len(staticTable) + 3), false}, + {pair("blake", "..."), uint64(len(staticTable) + 2), false}, + {pair(":method", "..."), 2, false}, + + // None match + {pair("foo-", "bar"), 0, false}, + } + for _, tt := range tests { + if gotI, gotMatch := e.searchTable(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch { + t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch) + } + } +} + +func TestAppendVarInt(t *testing.T) { + tests := []struct { + n byte + i uint64 + want []byte + }{ + // Fits in a byte: + {1, 0, []byte{0}}, + {2, 2, []byte{2}}, + {3, 6, []byte{6}}, + {4, 14, []byte{14}}, + {5, 30, []byte{30}}, + {6, 62, []byte{62}}, + {7, 126, []byte{126}}, + {8, 254, []byte{254}}, + + // Multiple bytes: + {5, 1337, []byte{31, 154, 10}}, + } + for _, tt := range tests { + got := appendVarInt(nil, tt.n, tt.i) + if !bytes.Equal(got, tt.want) { + t.Errorf("appendVarInt(nil, %v, %v) = %v; want %v", tt.n, tt.i, got, tt.want) + } + } +} + +func TestAppendHpackString(t *testing.T) { + tests := []struct { + s, wantHex string + }{ + // Huffman encoded + {"www.example.com", "8c f1e3 c2e5 f23a 6ba0 ab90 f4ff"}, + + // Not Huffman encoded + {"a", "01 61"}, + + // zero length + {"", "00"}, + } + for _, tt := range tests { + want := removeSpace(tt.wantHex) + buf := appendHpackString(nil, tt.s) + if got := hex.EncodeToString(buf); want != got { + t.Errorf("appendHpackString(nil, %q) = %q; want %q", tt.s, got, want) + } + } +} + +func TestAppendIndexed(t *testing.T) { + tests := []struct { + i uint64 + wantHex string + }{ + // 1 byte + {1, "81"}, + {126, "fe"}, + + // 2 bytes + {127, "ff00"}, + {128, "ff01"}, + } + for _, tt := range tests { + want := removeSpace(tt.wantHex) + buf := appendIndexed(nil, tt.i) + if got := hex.EncodeToString(buf); want != got { + t.Errorf("appendIndex(nil, %v) = %q; want %q", tt.i, got, want) + } + } +} + +func TestAppendNewName(t *testing.T) { + tests := []struct { + f HeaderField + indexing bool + wantHex string + }{ + // Incremental indexing + {HeaderField{"custom-key", "custom-value", false}, true, "40 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"}, + + // Without indexing + {HeaderField{"custom-key", "custom-value", false}, false, "00 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"}, + + // Never indexed + {HeaderField{"custom-key", "custom-value", true}, true, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"}, + {HeaderField{"custom-key", "custom-value", true}, false, "10 88 25a8 49e9 5ba9 7d7f 89 25a8 49e9 5bb8 e8b4 bf"}, + } + for _, tt := range tests { + want := removeSpace(tt.wantHex) + buf := appendNewName(nil, tt.f, tt.indexing) + if got := hex.EncodeToString(buf); want != got { + t.Errorf("appendNewName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want) + } + } +} + +func TestAppendIndexedName(t *testing.T) { + tests := []struct { + f HeaderField + i uint64 + indexing bool + wantHex string + }{ + // Incremental indexing + {HeaderField{":status", "302", false}, 8, true, "48 82 6402"}, + + // Without indexing + {HeaderField{":status", "302", false}, 8, false, "08 82 6402"}, + + // Never indexed + {HeaderField{":status", "302", true}, 8, true, "18 82 6402"}, + {HeaderField{":status", "302", true}, 8, false, "18 82 6402"}, + } + for _, tt := range tests { + want := removeSpace(tt.wantHex) + buf := appendIndexedName(nil, tt.f, tt.i, tt.indexing) + if got := hex.EncodeToString(buf); want != got { + t.Errorf("appendIndexedName(nil, %+v, %v) = %q; want %q", tt.f, tt.indexing, got, want) + } + } +} + +func TestAppendTableSize(t *testing.T) { + tests := []struct { + i uint32 + wantHex string + }{ + // Fits into 1 byte + {30, "3e"}, + + // Extra byte + {31, "3f00"}, + {32, "3f01"}, + } + for _, tt := range tests { + want := removeSpace(tt.wantHex) + buf := appendTableSize(nil, tt.i) + if got := hex.EncodeToString(buf); want != got { + t.Errorf("appendTableSize(nil, %v) = %q; want %q", tt.i, got, want) + } + } +} + +func TestEncoderSetMaxDynamicTableSize(t *testing.T) { + var buf bytes.Buffer + e := NewEncoder(&buf) + tests := []struct { + v uint32 + wantUpdate bool + wantMinSize uint32 + wantMaxSize uint32 + }{ + // Set new table size to 2048 + {2048, true, 2048, 2048}, + + // Set new table size to 16384, but still limited to + // 4096 + {16384, true, 2048, 4096}, + } + for _, tt := range tests { + e.SetMaxDynamicTableSize(tt.v) + if got := e.tableSizeUpdate; tt.wantUpdate != got { + t.Errorf("e.tableSizeUpdate = %v; want %v", got, tt.wantUpdate) + } + if got := e.minSize; tt.wantMinSize != got { + t.Errorf("e.minSize = %v; want %v", got, tt.wantMinSize) + } + if got := e.dynTab.maxSize; tt.wantMaxSize != got { + t.Errorf("e.maxSize = %v; want %v", got, tt.wantMaxSize) + } + } +} + +func TestEncoderSetMaxDynamicTableSizeLimit(t *testing.T) { + e := NewEncoder(nil) + // 4095 < initialHeaderTableSize means maxSize is truncated to + // 4095. + e.SetMaxDynamicTableSizeLimit(4095) + if got, want := e.dynTab.maxSize, uint32(4095); got != want { + t.Errorf("e.dynTab.maxSize = %v; want %v", got, want) + } + if got, want := e.maxSizeLimit, uint32(4095); got != want { + t.Errorf("e.maxSizeLimit = %v; want %v", got, want) + } + if got, want := e.tableSizeUpdate, true; got != want { + t.Errorf("e.tableSizeUpdate = %v; want %v", got, want) + } + // maxSize will be truncated to maxSizeLimit + e.SetMaxDynamicTableSize(16384) + if got, want := e.dynTab.maxSize, uint32(4095); got != want { + t.Errorf("e.dynTab.maxSize = %v; want %v", got, want) + } + // 8192 > current maxSizeLimit, so maxSize does not change. + e.SetMaxDynamicTableSizeLimit(8192) + if got, want := e.dynTab.maxSize, uint32(4095); got != want { + t.Errorf("e.dynTab.maxSize = %v; want %v", got, want) + } + if got, want := e.maxSizeLimit, uint32(8192); got != want { + t.Errorf("e.maxSizeLimit = %v; want %v", got, want) + } +} + +func removeSpace(s string) string { + return strings.Replace(s, " ", "", -1) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/hpack.go b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/hpack.go new file mode 100644 index 0000000..c9e36f7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/hpack.go @@ -0,0 +1,445 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +// Package hpack implements HPACK, a compression format for +// efficiently representing HTTP header fields in the context of HTTP/2. +// +// See http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-09 +package hpack + +import ( + "bytes" + "errors" + "fmt" +) + +// A DecodingError is something the spec defines as a decoding error. +type DecodingError struct { + Err error +} + +func (de DecodingError) Error() string { + return fmt.Sprintf("decoding error: %v", de.Err) +} + +// An InvalidIndexError is returned when an encoder references a table +// entry before the static table or after the end of the dynamic table. +type InvalidIndexError int + +func (e InvalidIndexError) Error() string { + return fmt.Sprintf("invalid indexed representation index %d", int(e)) +} + +// A HeaderField is a name-value pair. Both the name and value are +// treated as opaque sequences of octets. +type HeaderField struct { + Name, Value string + + // Sensitive means that this header field should never be + // indexed. + Sensitive bool +} + +func (hf *HeaderField) size() uint32 { + // http://http2.github.io/http2-spec/compression.html#rfc.section.4.1 + // "The size of the dynamic table is the sum of the size of + // its entries. The size of an entry is the sum of its name's + // length in octets (as defined in Section 5.2), its value's + // length in octets (see Section 5.2), plus 32. The size of + // an entry is calculated using the length of the name and + // value without any Huffman encoding applied." + + // This can overflow if somebody makes a large HeaderField + // Name and/or Value by hand, but we don't care, because that + // won't happen on the wire because the encoding doesn't allow + // it. + return uint32(len(hf.Name) + len(hf.Value) + 32) +} + +// A Decoder is the decoding context for incremental processing of +// header blocks. +type Decoder struct { + dynTab dynamicTable + emit func(f HeaderField) + + // buf is the unparsed buffer. It's only written to + // saveBuf if it was truncated in the middle of a header + // block. Because it's usually not owned, we can only + // process it under Write. + buf []byte // usually not owned + saveBuf bytes.Buffer +} + +func NewDecoder(maxSize uint32, emitFunc func(f HeaderField)) *Decoder { + d := &Decoder{ + emit: emitFunc, + } + d.dynTab.allowedMaxSize = maxSize + d.dynTab.setMaxSize(maxSize) + return d +} + +// TODO: add method *Decoder.Reset(maxSize, emitFunc) to let callers re-use Decoders and their +// underlying buffers for garbage reasons. + +func (d *Decoder) SetMaxDynamicTableSize(v uint32) { + d.dynTab.setMaxSize(v) +} + +// SetAllowedMaxDynamicTableSize sets the upper bound that the encoded +// stream (via dynamic table size updates) may set the maximum size +// to. +func (d *Decoder) SetAllowedMaxDynamicTableSize(v uint32) { + d.dynTab.allowedMaxSize = v +} + +type dynamicTable struct { + // ents is the FIFO described at + // http://http2.github.io/http2-spec/compression.html#rfc.section.2.3.2 + // The newest (low index) is append at the end, and items are + // evicted from the front. + ents []HeaderField + size uint32 + maxSize uint32 // current maxSize + allowedMaxSize uint32 // maxSize may go up to this, inclusive +} + +func (dt *dynamicTable) setMaxSize(v uint32) { + dt.maxSize = v + dt.evict() +} + +// TODO: change dynamicTable to be a struct with a slice and a size int field, +// per http://http2.github.io/http2-spec/compression.html#rfc.section.4.1: +// +// +// Then make add increment the size. maybe the max size should move from Decoder to +// dynamicTable and add should return an ok bool if there was enough space. +// +// Later we'll need a remove operation on dynamicTable. + +func (dt *dynamicTable) add(f HeaderField) { + dt.ents = append(dt.ents, f) + dt.size += f.size() + dt.evict() +} + +// If we're too big, evict old stuff (front of the slice) +func (dt *dynamicTable) evict() { + base := dt.ents // keep base pointer of slice + for dt.size > dt.maxSize { + dt.size -= dt.ents[0].size() + dt.ents = dt.ents[1:] + } + + // Shift slice contents down if we evicted things. + if len(dt.ents) != len(base) { + copy(base, dt.ents) + dt.ents = base[:len(dt.ents)] + } +} + +// constantTimeStringCompare compares string a and b in a constant +// time manner. +func constantTimeStringCompare(a, b string) bool { + if len(a) != len(b) { + return false + } + + c := byte(0) + + for i := 0; i < len(a); i++ { + c |= a[i] ^ b[i] + } + + return c == 0 +} + +// Search searches f in the table. The return value i is 0 if there is +// no name match. If there is name match or name/value match, i is the +// index of that entry (1-based). If both name and value match, +// nameValueMatch becomes true. +func (dt *dynamicTable) search(f HeaderField) (i uint64, nameValueMatch bool) { + l := len(dt.ents) + for j := l - 1; j >= 0; j-- { + ent := dt.ents[j] + if !constantTimeStringCompare(ent.Name, f.Name) { + continue + } + if i == 0 { + i = uint64(l - j) + } + if f.Sensitive { + continue + } + if !constantTimeStringCompare(ent.Value, f.Value) { + continue + } + i = uint64(l - j) + nameValueMatch = true + return + } + return +} + +func (d *Decoder) maxTableIndex() int { + return len(d.dynTab.ents) + len(staticTable) +} + +func (d *Decoder) at(i uint64) (hf HeaderField, ok bool) { + if i < 1 { + return + } + if i > uint64(d.maxTableIndex()) { + return + } + if i <= uint64(len(staticTable)) { + return staticTable[i-1], true + } + dents := d.dynTab.ents + return dents[len(dents)-(int(i)-len(staticTable))], true +} + +// Decode decodes an entire block. +// +// TODO: remove this method and make it incremental later? This is +// easier for debugging now. +func (d *Decoder) DecodeFull(p []byte) ([]HeaderField, error) { + var hf []HeaderField + saveFunc := d.emit + defer func() { d.emit = saveFunc }() + d.emit = func(f HeaderField) { hf = append(hf, f) } + if _, err := d.Write(p); err != nil { + return nil, err + } + if err := d.Close(); err != nil { + return nil, err + } + return hf, nil +} + +func (d *Decoder) Close() error { + if d.saveBuf.Len() > 0 { + d.saveBuf.Reset() + return DecodingError{errors.New("truncated headers")} + } + return nil +} + +func (d *Decoder) Write(p []byte) (n int, err error) { + if len(p) == 0 { + // Prevent state machine CPU attacks (making us redo + // work up to the point of finding out we don't have + // enough data) + return + } + // Only copy the data if we have to. Optimistically assume + // that p will contain a complete header block. + if d.saveBuf.Len() == 0 { + d.buf = p + } else { + d.saveBuf.Write(p) + d.buf = d.saveBuf.Bytes() + d.saveBuf.Reset() + } + + for len(d.buf) > 0 { + err = d.parseHeaderFieldRepr() + if err != nil { + if err == errNeedMore { + err = nil + d.saveBuf.Write(d.buf) + } + break + } + } + + return len(p), err +} + +// errNeedMore is an internal sentinel error value that means the +// buffer is truncated and we need to read more data before we can +// continue parsing. +var errNeedMore = errors.New("need more data") + +type indexType int + +const ( + indexedTrue indexType = iota + indexedFalse + indexedNever +) + +func (v indexType) indexed() bool { return v == indexedTrue } +func (v indexType) sensitive() bool { return v == indexedNever } + +// returns errNeedMore if there isn't enough data available. +// any other error is fatal. +// consumes d.buf iff it returns nil. +// precondition: must be called with len(d.buf) > 0 +func (d *Decoder) parseHeaderFieldRepr() error { + b := d.buf[0] + switch { + case b&128 != 0: + // Indexed representation. + // High bit set? + // http://http2.github.io/http2-spec/compression.html#rfc.section.6.1 + return d.parseFieldIndexed() + case b&192 == 64: + // 6.2.1 Literal Header Field with Incremental Indexing + // 0b10xxxxxx: top two bits are 10 + // http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.1 + return d.parseFieldLiteral(6, indexedTrue) + case b&240 == 0: + // 6.2.2 Literal Header Field without Indexing + // 0b0000xxxx: top four bits are 0000 + // http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.2 + return d.parseFieldLiteral(4, indexedFalse) + case b&240 == 16: + // 6.2.3 Literal Header Field never Indexed + // 0b0001xxxx: top four bits are 0001 + // http://http2.github.io/http2-spec/compression.html#rfc.section.6.2.3 + return d.parseFieldLiteral(4, indexedNever) + case b&224 == 32: + // 6.3 Dynamic Table Size Update + // Top three bits are '001'. + // http://http2.github.io/http2-spec/compression.html#rfc.section.6.3 + return d.parseDynamicTableSizeUpdate() + } + + return DecodingError{errors.New("invalid encoding")} +} + +// (same invariants and behavior as parseHeaderFieldRepr) +func (d *Decoder) parseFieldIndexed() error { + buf := d.buf + idx, buf, err := readVarInt(7, buf) + if err != nil { + return err + } + hf, ok := d.at(idx) + if !ok { + return DecodingError{InvalidIndexError(idx)} + } + d.emit(HeaderField{Name: hf.Name, Value: hf.Value}) + d.buf = buf + return nil +} + +// (same invariants and behavior as parseHeaderFieldRepr) +func (d *Decoder) parseFieldLiteral(n uint8, it indexType) error { + buf := d.buf + nameIdx, buf, err := readVarInt(n, buf) + if err != nil { + return err + } + + var hf HeaderField + if nameIdx > 0 { + ihf, ok := d.at(nameIdx) + if !ok { + return DecodingError{InvalidIndexError(nameIdx)} + } + hf.Name = ihf.Name + } else { + hf.Name, buf, err = readString(buf) + if err != nil { + return err + } + } + hf.Value, buf, err = readString(buf) + if err != nil { + return err + } + d.buf = buf + if it.indexed() { + d.dynTab.add(hf) + } + hf.Sensitive = it.sensitive() + d.emit(hf) + return nil +} + +// (same invariants and behavior as parseHeaderFieldRepr) +func (d *Decoder) parseDynamicTableSizeUpdate() error { + buf := d.buf + size, buf, err := readVarInt(5, buf) + if err != nil { + return err + } + if size > uint64(d.dynTab.allowedMaxSize) { + return DecodingError{errors.New("dynamic table size update too large")} + } + d.dynTab.setMaxSize(uint32(size)) + d.buf = buf + return nil +} + +var errVarintOverflow = DecodingError{errors.New("varint integer overflow")} + +// readVarInt reads an unsigned variable length integer off the +// beginning of p. n is the parameter as described in +// http://http2.github.io/http2-spec/compression.html#rfc.section.5.1. +// +// n must always be between 1 and 8. +// +// The returned remain buffer is either a smaller suffix of p, or err != nil. +// The error is errNeedMore if p doesn't contain a complete integer. +func readVarInt(n byte, p []byte) (i uint64, remain []byte, err error) { + if n < 1 || n > 8 { + panic("bad n") + } + if len(p) == 0 { + return 0, p, errNeedMore + } + i = uint64(p[0]) + if n < 8 { + i &= (1 << uint64(n)) - 1 + } + if i < (1< 0 { + b := p[0] + p = p[1:] + i += uint64(b&127) << m + if b&128 == 0 { + return i, p, nil + } + m += 7 + if m >= 63 { // TODO: proper overflow check. making this up. + return 0, origP, errVarintOverflow + } + } + return 0, origP, errNeedMore +} + +func readString(p []byte) (s string, remain []byte, err error) { + if len(p) == 0 { + return "", p, errNeedMore + } + isHuff := p[0]&128 != 0 + strLen, p, err := readVarInt(7, p) + if err != nil { + return "", p, err + } + if uint64(len(p)) < strLen { + return "", p, errNeedMore + } + if !isHuff { + return string(p[:strLen]), p[strLen:], nil + } + + // TODO: optimize this garbage: + var buf bytes.Buffer + if _, err := HuffmanDecode(&buf, p[:strLen]); err != nil { + return "", nil, err + } + return buf.String(), p[strLen:], nil +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/hpack_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/hpack_test.go new file mode 100644 index 0000000..08bee41 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/hpack_test.go @@ -0,0 +1,648 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package hpack + +import ( + "bufio" + "bytes" + "encoding/hex" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" + "testing" +) + +func TestStaticTable(t *testing.T) { + fromSpec := ` + +-------+-----------------------------+---------------+ + | 1 | :authority | | + | 2 | :method | GET | + | 3 | :method | POST | + | 4 | :path | / | + | 5 | :path | /index.html | + | 6 | :scheme | http | + | 7 | :scheme | https | + | 8 | :status | 200 | + | 9 | :status | 204 | + | 10 | :status | 206 | + | 11 | :status | 304 | + | 12 | :status | 400 | + | 13 | :status | 404 | + | 14 | :status | 500 | + | 15 | accept-charset | | + | 16 | accept-encoding | gzip, deflate | + | 17 | accept-language | | + | 18 | accept-ranges | | + | 19 | accept | | + | 20 | access-control-allow-origin | | + | 21 | age | | + | 22 | allow | | + | 23 | authorization | | + | 24 | cache-control | | + | 25 | content-disposition | | + | 26 | content-encoding | | + | 27 | content-language | | + | 28 | content-length | | + | 29 | content-location | | + | 30 | content-range | | + | 31 | content-type | | + | 32 | cookie | | + | 33 | date | | + | 34 | etag | | + | 35 | expect | | + | 36 | expires | | + | 37 | from | | + | 38 | host | | + | 39 | if-match | | + | 40 | if-modified-since | | + | 41 | if-none-match | | + | 42 | if-range | | + | 43 | if-unmodified-since | | + | 44 | last-modified | | + | 45 | link | | + | 46 | location | | + | 47 | max-forwards | | + | 48 | proxy-authenticate | | + | 49 | proxy-authorization | | + | 50 | range | | + | 51 | referer | | + | 52 | refresh | | + | 53 | retry-after | | + | 54 | server | | + | 55 | set-cookie | | + | 56 | strict-transport-security | | + | 57 | transfer-encoding | | + | 58 | user-agent | | + | 59 | vary | | + | 60 | via | | + | 61 | www-authenticate | | + +-------+-----------------------------+---------------+ +` + bs := bufio.NewScanner(strings.NewReader(fromSpec)) + re := regexp.MustCompile(`\| (\d+)\s+\| (\S+)\s*\| (\S(.*\S)?)?\s+\|`) + for bs.Scan() { + l := bs.Text() + if !strings.Contains(l, "|") { + continue + } + m := re.FindStringSubmatch(l) + if m == nil { + continue + } + i, err := strconv.Atoi(m[1]) + if err != nil { + t.Errorf("Bogus integer on line %q", l) + continue + } + if i < 1 || i > len(staticTable) { + t.Errorf("Bogus index %d on line %q", i, l) + continue + } + if got, want := staticTable[i-1].Name, m[2]; got != want { + t.Errorf("header index %d name = %q; want %q", i, got, want) + } + if got, want := staticTable[i-1].Value, m[3]; got != want { + t.Errorf("header index %d value = %q; want %q", i, got, want) + } + } + if err := bs.Err(); err != nil { + t.Error(err) + } +} + +func (d *Decoder) mustAt(idx int) HeaderField { + if hf, ok := d.at(uint64(idx)); !ok { + panic(fmt.Sprintf("bogus index %d", idx)) + } else { + return hf + } +} + +func TestDynamicTableAt(t *testing.T) { + d := NewDecoder(4096, nil) + at := d.mustAt + if got, want := at(2), (pair(":method", "GET")); got != want { + t.Errorf("at(2) = %v; want %v", got, want) + } + d.dynTab.add(pair("foo", "bar")) + d.dynTab.add(pair("blake", "miz")) + if got, want := at(len(staticTable)+1), (pair("blake", "miz")); got != want { + t.Errorf("at(dyn 1) = %v; want %v", got, want) + } + if got, want := at(len(staticTable)+2), (pair("foo", "bar")); got != want { + t.Errorf("at(dyn 2) = %v; want %v", got, want) + } + if got, want := at(3), (pair(":method", "POST")); got != want { + t.Errorf("at(3) = %v; want %v", got, want) + } +} + +func TestDynamicTableSearch(t *testing.T) { + dt := dynamicTable{} + dt.setMaxSize(4096) + + dt.add(pair("foo", "bar")) + dt.add(pair("blake", "miz")) + dt.add(pair(":method", "GET")) + + tests := []struct { + hf HeaderField + wantI uint64 + wantMatch bool + }{ + // Name and Value match + {pair("foo", "bar"), 3, true}, + {pair(":method", "GET"), 1, true}, + + // Only name match because of Sensitive == true + {HeaderField{"blake", "miz", true}, 2, false}, + + // Only Name matches + {pair("foo", "..."), 3, false}, + {pair("blake", "..."), 2, false}, + {pair(":method", "..."), 1, false}, + + // None match + {pair("foo-", "bar"), 0, false}, + } + for _, tt := range tests { + if gotI, gotMatch := dt.search(tt.hf); gotI != tt.wantI || gotMatch != tt.wantMatch { + t.Errorf("d.search(%+v) = %v, %v; want %v, %v", tt.hf, gotI, gotMatch, tt.wantI, tt.wantMatch) + } + } +} + +func TestDynamicTableSizeEvict(t *testing.T) { + d := NewDecoder(4096, nil) + if want := uint32(0); d.dynTab.size != want { + t.Fatalf("size = %d; want %d", d.dynTab.size, want) + } + add := d.dynTab.add + add(pair("blake", "eats pizza")) + if want := uint32(15 + 32); d.dynTab.size != want { + t.Fatalf("after pizza, size = %d; want %d", d.dynTab.size, want) + } + add(pair("foo", "bar")) + if want := uint32(15 + 32 + 6 + 32); d.dynTab.size != want { + t.Fatalf("after foo bar, size = %d; want %d", d.dynTab.size, want) + } + d.dynTab.setMaxSize(15 + 32 + 1 /* slop */) + if want := uint32(6 + 32); d.dynTab.size != want { + t.Fatalf("after setMaxSize, size = %d; want %d", d.dynTab.size, want) + } + if got, want := d.mustAt(len(staticTable)+1), (pair("foo", "bar")); got != want { + t.Errorf("at(dyn 1) = %v; want %v", got, want) + } + add(pair("long", strings.Repeat("x", 500))) + if want := uint32(0); d.dynTab.size != want { + t.Fatalf("after big one, size = %d; want %d", d.dynTab.size, want) + } +} + +func TestDecoderDecode(t *testing.T) { + tests := []struct { + name string + in []byte + want []HeaderField + wantDynTab []HeaderField // newest entry first + }{ + // C.2.1 Literal Header Field with Indexing + // http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.1 + {"C.2.1", dehex("400a 6375 7374 6f6d 2d6b 6579 0d63 7573 746f 6d2d 6865 6164 6572"), + []HeaderField{pair("custom-key", "custom-header")}, + []HeaderField{pair("custom-key", "custom-header")}, + }, + + // C.2.2 Literal Header Field without Indexing + // http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.2 + {"C.2.2", dehex("040c 2f73 616d 706c 652f 7061 7468"), + []HeaderField{pair(":path", "/sample/path")}, + []HeaderField{}}, + + // C.2.3 Literal Header Field never Indexed + // http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.3 + {"C.2.3", dehex("1008 7061 7373 776f 7264 0673 6563 7265 74"), + []HeaderField{{"password", "secret", true}}, + []HeaderField{}}, + + // C.2.4 Indexed Header Field + // http://http2.github.io/http2-spec/compression.html#rfc.section.C.2.4 + {"C.2.4", []byte("\x82"), + []HeaderField{pair(":method", "GET")}, + []HeaderField{}}, + } + for _, tt := range tests { + d := NewDecoder(4096, nil) + hf, err := d.DecodeFull(tt.in) + if err != nil { + t.Errorf("%s: %v", tt.name, err) + continue + } + if !reflect.DeepEqual(hf, tt.want) { + t.Errorf("%s: Got %v; want %v", tt.name, hf, tt.want) + } + gotDynTab := d.dynTab.reverseCopy() + if !reflect.DeepEqual(gotDynTab, tt.wantDynTab) { + t.Errorf("%s: dynamic table after = %v; want %v", tt.name, gotDynTab, tt.wantDynTab) + } + } +} + +func (dt *dynamicTable) reverseCopy() (hf []HeaderField) { + hf = make([]HeaderField, len(dt.ents)) + for i := range hf { + hf[i] = dt.ents[len(dt.ents)-1-i] + } + return +} + +type encAndWant struct { + enc []byte + want []HeaderField + wantDynTab []HeaderField + wantDynSize uint32 +} + +// C.3 Request Examples without Huffman Coding +// http://http2.github.io/http2-spec/compression.html#rfc.section.C.3 +func TestDecodeC3_NoHuffman(t *testing.T) { + testDecodeSeries(t, 4096, []encAndWant{ + {dehex("8286 8441 0f77 7777 2e65 7861 6d70 6c65 2e63 6f6d"), + []HeaderField{ + pair(":method", "GET"), + pair(":scheme", "http"), + pair(":path", "/"), + pair(":authority", "www.example.com"), + }, + []HeaderField{ + pair(":authority", "www.example.com"), + }, + 57, + }, + {dehex("8286 84be 5808 6e6f 2d63 6163 6865"), + []HeaderField{ + pair(":method", "GET"), + pair(":scheme", "http"), + pair(":path", "/"), + pair(":authority", "www.example.com"), + pair("cache-control", "no-cache"), + }, + []HeaderField{ + pair("cache-control", "no-cache"), + pair(":authority", "www.example.com"), + }, + 110, + }, + {dehex("8287 85bf 400a 6375 7374 6f6d 2d6b 6579 0c63 7573 746f 6d2d 7661 6c75 65"), + []HeaderField{ + pair(":method", "GET"), + pair(":scheme", "https"), + pair(":path", "/index.html"), + pair(":authority", "www.example.com"), + pair("custom-key", "custom-value"), + }, + []HeaderField{ + pair("custom-key", "custom-value"), + pair("cache-control", "no-cache"), + pair(":authority", "www.example.com"), + }, + 164, + }, + }) +} + +// C.4 Request Examples with Huffman Coding +// http://http2.github.io/http2-spec/compression.html#rfc.section.C.4 +func TestDecodeC4_Huffman(t *testing.T) { + testDecodeSeries(t, 4096, []encAndWant{ + {dehex("8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4 ff"), + []HeaderField{ + pair(":method", "GET"), + pair(":scheme", "http"), + pair(":path", "/"), + pair(":authority", "www.example.com"), + }, + []HeaderField{ + pair(":authority", "www.example.com"), + }, + 57, + }, + {dehex("8286 84be 5886 a8eb 1064 9cbf"), + []HeaderField{ + pair(":method", "GET"), + pair(":scheme", "http"), + pair(":path", "/"), + pair(":authority", "www.example.com"), + pair("cache-control", "no-cache"), + }, + []HeaderField{ + pair("cache-control", "no-cache"), + pair(":authority", "www.example.com"), + }, + 110, + }, + {dehex("8287 85bf 4088 25a8 49e9 5ba9 7d7f 8925 a849 e95b b8e8 b4bf"), + []HeaderField{ + pair(":method", "GET"), + pair(":scheme", "https"), + pair(":path", "/index.html"), + pair(":authority", "www.example.com"), + pair("custom-key", "custom-value"), + }, + []HeaderField{ + pair("custom-key", "custom-value"), + pair("cache-control", "no-cache"), + pair(":authority", "www.example.com"), + }, + 164, + }, + }) +} + +// http://http2.github.io/http2-spec/compression.html#rfc.section.C.5 +// "This section shows several consecutive header lists, corresponding +// to HTTP responses, on the same connection. The HTTP/2 setting +// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256 +// octets, causing some evictions to occur." +func TestDecodeC5_ResponsesNoHuff(t *testing.T) { + testDecodeSeries(t, 256, []encAndWant{ + {dehex(` +4803 3330 3258 0770 7269 7661 7465 611d +4d6f 6e2c 2032 3120 4f63 7420 3230 3133 +2032 303a 3133 3a32 3120 474d 546e 1768 +7474 7073 3a2f 2f77 7777 2e65 7861 6d70 +6c65 2e63 6f6d +`), + []HeaderField{ + pair(":status", "302"), + pair("cache-control", "private"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("location", "https://www.example.com"), + }, + []HeaderField{ + pair("location", "https://www.example.com"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("cache-control", "private"), + pair(":status", "302"), + }, + 222, + }, + {dehex("4803 3330 37c1 c0bf"), + []HeaderField{ + pair(":status", "307"), + pair("cache-control", "private"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("location", "https://www.example.com"), + }, + []HeaderField{ + pair(":status", "307"), + pair("location", "https://www.example.com"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("cache-control", "private"), + }, + 222, + }, + {dehex(` +88c1 611d 4d6f 6e2c 2032 3120 4f63 7420 +3230 3133 2032 303a 3133 3a32 3220 474d +54c0 5a04 677a 6970 7738 666f 6f3d 4153 +444a 4b48 514b 425a 584f 5157 454f 5049 +5541 5851 5745 4f49 553b 206d 6178 2d61 +6765 3d33 3630 303b 2076 6572 7369 6f6e +3d31 +`), + []HeaderField{ + pair(":status", "200"), + pair("cache-control", "private"), + pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"), + pair("location", "https://www.example.com"), + pair("content-encoding", "gzip"), + pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"), + }, + []HeaderField{ + pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"), + pair("content-encoding", "gzip"), + pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"), + }, + 215, + }, + }) +} + +// http://http2.github.io/http2-spec/compression.html#rfc.section.C.6 +// "This section shows the same examples as the previous section, but +// using Huffman encoding for the literal values. The HTTP/2 setting +// parameter SETTINGS_HEADER_TABLE_SIZE is set to the value of 256 +// octets, causing some evictions to occur. The eviction mechanism +// uses the length of the decoded literal values, so the same +// evictions occurs as in the previous section." +func TestDecodeC6_ResponsesHuffman(t *testing.T) { + testDecodeSeries(t, 256, []encAndWant{ + {dehex(` +4882 6402 5885 aec3 771a 4b61 96d0 7abe +9410 54d4 44a8 2005 9504 0b81 66e0 82a6 +2d1b ff6e 919d 29ad 1718 63c7 8f0b 97c8 +e9ae 82ae 43d3 +`), + []HeaderField{ + pair(":status", "302"), + pair("cache-control", "private"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("location", "https://www.example.com"), + }, + []HeaderField{ + pair("location", "https://www.example.com"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("cache-control", "private"), + pair(":status", "302"), + }, + 222, + }, + {dehex("4883 640e ffc1 c0bf"), + []HeaderField{ + pair(":status", "307"), + pair("cache-control", "private"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("location", "https://www.example.com"), + }, + []HeaderField{ + pair(":status", "307"), + pair("location", "https://www.example.com"), + pair("date", "Mon, 21 Oct 2013 20:13:21 GMT"), + pair("cache-control", "private"), + }, + 222, + }, + {dehex(` +88c1 6196 d07a be94 1054 d444 a820 0595 +040b 8166 e084 a62d 1bff c05a 839b d9ab +77ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b +3960 d5af 2708 7f36 72c1 ab27 0fb5 291f +9587 3160 65c0 03ed 4ee5 b106 3d50 07 +`), + []HeaderField{ + pair(":status", "200"), + pair("cache-control", "private"), + pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"), + pair("location", "https://www.example.com"), + pair("content-encoding", "gzip"), + pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"), + }, + []HeaderField{ + pair("set-cookie", "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"), + pair("content-encoding", "gzip"), + pair("date", "Mon, 21 Oct 2013 20:13:22 GMT"), + }, + 215, + }, + }) +} + +func testDecodeSeries(t *testing.T, size uint32, steps []encAndWant) { + d := NewDecoder(size, nil) + for i, step := range steps { + hf, err := d.DecodeFull(step.enc) + if err != nil { + t.Fatalf("Error at step index %d: %v", i, err) + } + if !reflect.DeepEqual(hf, step.want) { + t.Fatalf("At step index %d: Got headers %v; want %v", i, hf, step.want) + } + gotDynTab := d.dynTab.reverseCopy() + if !reflect.DeepEqual(gotDynTab, step.wantDynTab) { + t.Errorf("After step index %d, dynamic table = %v; want %v", i, gotDynTab, step.wantDynTab) + } + if d.dynTab.size != step.wantDynSize { + t.Errorf("After step index %d, dynamic table size = %v; want %v", i, d.dynTab.size, step.wantDynSize) + } + } +} + +func TestHuffmanDecode(t *testing.T) { + tests := []struct { + inHex, want string + }{ + {"f1e3 c2e5 f23a 6ba0 ab90 f4ff", "www.example.com"}, + {"a8eb 1064 9cbf", "no-cache"}, + {"25a8 49e9 5ba9 7d7f", "custom-key"}, + {"25a8 49e9 5bb8 e8b4 bf", "custom-value"}, + {"6402", "302"}, + {"aec3 771a 4b", "private"}, + {"d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff", "Mon, 21 Oct 2013 20:13:21 GMT"}, + {"9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3", "https://www.example.com"}, + {"9bd9 ab", "gzip"}, + {"94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07", + "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1"}, + } + for i, tt := range tests { + var buf bytes.Buffer + in, err := hex.DecodeString(strings.Replace(tt.inHex, " ", "", -1)) + if err != nil { + t.Errorf("%d. hex input error: %v", i, err) + continue + } + if _, err := HuffmanDecode(&buf, in); err != nil { + t.Errorf("%d. decode error: %v", i, err) + continue + } + if got := buf.String(); tt.want != got { + t.Errorf("%d. decode = %q; want %q", i, got, tt.want) + } + } +} + +func TestAppendHuffmanString(t *testing.T) { + tests := []struct { + in, want string + }{ + {"www.example.com", "f1e3 c2e5 f23a 6ba0 ab90 f4ff"}, + {"no-cache", "a8eb 1064 9cbf"}, + {"custom-key", "25a8 49e9 5ba9 7d7f"}, + {"custom-value", "25a8 49e9 5bb8 e8b4 bf"}, + {"302", "6402"}, + {"private", "aec3 771a 4b"}, + {"Mon, 21 Oct 2013 20:13:21 GMT", "d07a be94 1054 d444 a820 0595 040b 8166 e082 a62d 1bff"}, + {"https://www.example.com", "9d29 ad17 1863 c78f 0b97 c8e9 ae82 ae43 d3"}, + {"gzip", "9bd9 ab"}, + {"foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1", + "94e7 821d d7f2 e6c7 b335 dfdf cd5b 3960 d5af 2708 7f36 72c1 ab27 0fb5 291f 9587 3160 65c0 03ed 4ee5 b106 3d50 07"}, + } + for i, tt := range tests { + buf := []byte{} + want := strings.Replace(tt.want, " ", "", -1) + buf = AppendHuffmanString(buf, tt.in) + if got := hex.EncodeToString(buf); want != got { + t.Errorf("%d. encode = %q; want %q", i, got, want) + } + } +} + +func TestReadVarInt(t *testing.T) { + type res struct { + i uint64 + consumed int + err error + } + tests := []struct { + n byte + p []byte + want res + }{ + // Fits in a byte: + {1, []byte{0}, res{0, 1, nil}}, + {2, []byte{2}, res{2, 1, nil}}, + {3, []byte{6}, res{6, 1, nil}}, + {4, []byte{14}, res{14, 1, nil}}, + {5, []byte{30}, res{30, 1, nil}}, + {6, []byte{62}, res{62, 1, nil}}, + {7, []byte{126}, res{126, 1, nil}}, + {8, []byte{254}, res{254, 1, nil}}, + + // Doesn't fit in a byte: + {1, []byte{1}, res{0, 0, errNeedMore}}, + {2, []byte{3}, res{0, 0, errNeedMore}}, + {3, []byte{7}, res{0, 0, errNeedMore}}, + {4, []byte{15}, res{0, 0, errNeedMore}}, + {5, []byte{31}, res{0, 0, errNeedMore}}, + {6, []byte{63}, res{0, 0, errNeedMore}}, + {7, []byte{127}, res{0, 0, errNeedMore}}, + {8, []byte{255}, res{0, 0, errNeedMore}}, + + // Ignoring top bits: + {5, []byte{255, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 111 + {5, []byte{159, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 100 + {5, []byte{191, 154, 10}, res{1337, 3, nil}}, // high dummy three bits: 101 + + // Extra byte: + {5, []byte{191, 154, 10, 2}, res{1337, 3, nil}}, // extra byte + + // Short a byte: + {5, []byte{191, 154}, res{0, 0, errNeedMore}}, + + // integer overflow: + {1, []byte{255, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128}, res{0, 0, errVarintOverflow}}, + } + for _, tt := range tests { + i, remain, err := readVarInt(tt.n, tt.p) + consumed := len(tt.p) - len(remain) + got := res{i, consumed, err} + if got != tt.want { + t.Errorf("readVarInt(%d, %v ~ %x) = %+v; want %+v", tt.n, tt.p, tt.p, got, tt.want) + } + } +} + +func dehex(s string) []byte { + s = strings.Replace(s, " ", "", -1) + s = strings.Replace(s, "\n", "", -1) + b, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return b +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/huffman.go b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/huffman.go new file mode 100644 index 0000000..9fe76f6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/huffman.go @@ -0,0 +1,159 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package hpack + +import ( + "bytes" + "io" + "sync" +) + +var bufPool = sync.Pool{ + New: func() interface{} { return new(bytes.Buffer) }, +} + +// HuffmanDecode decodes the string in v and writes the expanded +// result to w, returning the number of bytes written to w and the +// Write call's return value. At most one Write call is made. +func HuffmanDecode(w io.Writer, v []byte) (int, error) { + buf := bufPool.Get().(*bytes.Buffer) + buf.Reset() + defer bufPool.Put(buf) + + n := rootHuffmanNode + cur, nbits := uint(0), uint8(0) + for _, b := range v { + cur = cur<<8 | uint(b) + nbits += 8 + for nbits >= 8 { + n = n.children[byte(cur>>(nbits-8))] + if n.children == nil { + buf.WriteByte(n.sym) + nbits -= n.codeLen + n = rootHuffmanNode + } else { + nbits -= 8 + } + } + } + for nbits > 0 { + n = n.children[byte(cur<<(8-nbits))] + if n.children != nil || n.codeLen > nbits { + break + } + buf.WriteByte(n.sym) + nbits -= n.codeLen + n = rootHuffmanNode + } + return w.Write(buf.Bytes()) +} + +type node struct { + // children is non-nil for internal nodes + children []*node + + // The following are only valid if children is nil: + codeLen uint8 // number of bits that led to the output of sym + sym byte // output symbol +} + +func newInternalNode() *node { + return &node{children: make([]*node, 256)} +} + +var rootHuffmanNode = newInternalNode() + +func init() { + for i, code := range huffmanCodes { + if i > 255 { + panic("too many huffman codes") + } + addDecoderNode(byte(i), code, huffmanCodeLen[i]) + } +} + +func addDecoderNode(sym byte, code uint32, codeLen uint8) { + cur := rootHuffmanNode + for codeLen > 8 { + codeLen -= 8 + i := uint8(code >> codeLen) + if cur.children[i] == nil { + cur.children[i] = newInternalNode() + } + cur = cur.children[i] + } + shift := 8 - codeLen + start, end := int(uint8(code<> (nbits - rembits)) + dst[len(dst)-1] |= t + } + + return dst +} + +// HuffmanEncodeLength returns the number of bytes required to encode +// s in Huffman codes. The result is round up to byte boundary. +func HuffmanEncodeLength(s string) uint64 { + n := uint64(0) + for i := 0; i < len(s); i++ { + n += uint64(huffmanCodeLen[s[i]]) + } + return (n + 7) / 8 +} + +// appendByteToHuffmanCode appends Huffman code for c to dst and +// returns the extended buffer and the remaining bits in the last +// element. The appending is not byte aligned and the remaining bits +// in the last element of dst is given in rembits. +func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) { + code := huffmanCodes[c] + nbits := huffmanCodeLen[c] + + for { + if rembits > nbits { + t := uint8(code << (rembits - nbits)) + dst[len(dst)-1] |= t + rembits -= nbits + break + } + + t := uint8(code >> (nbits - rembits)) + dst[len(dst)-1] |= t + + nbits -= rembits + rembits = 8 + + if nbits == 0 { + break + } + + dst = append(dst, 0) + } + + return dst, rembits +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/tables.go b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/tables.go new file mode 100644 index 0000000..f898e25 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/hpack/tables.go @@ -0,0 +1,353 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package hpack + +func pair(name, value string) HeaderField { + return HeaderField{Name: name, Value: value} +} + +// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-07#appendix-B +var staticTable = []HeaderField{ + pair(":authority", ""), // index 1 (1-based) + pair(":method", "GET"), + pair(":method", "POST"), + pair(":path", "/"), + pair(":path", "/index.html"), + pair(":scheme", "http"), + pair(":scheme", "https"), + pair(":status", "200"), + pair(":status", "204"), + pair(":status", "206"), + pair(":status", "304"), + pair(":status", "400"), + pair(":status", "404"), + pair(":status", "500"), + pair("accept-charset", ""), + pair("accept-encoding", "gzip, deflate"), + pair("accept-language", ""), + pair("accept-ranges", ""), + pair("accept", ""), + pair("access-control-allow-origin", ""), + pair("age", ""), + pair("allow", ""), + pair("authorization", ""), + pair("cache-control", ""), + pair("content-disposition", ""), + pair("content-encoding", ""), + pair("content-language", ""), + pair("content-length", ""), + pair("content-location", ""), + pair("content-range", ""), + pair("content-type", ""), + pair("cookie", ""), + pair("date", ""), + pair("etag", ""), + pair("expect", ""), + pair("expires", ""), + pair("from", ""), + pair("host", ""), + pair("if-match", ""), + pair("if-modified-since", ""), + pair("if-none-match", ""), + pair("if-range", ""), + pair("if-unmodified-since", ""), + pair("last-modified", ""), + pair("link", ""), + pair("location", ""), + pair("max-forwards", ""), + pair("proxy-authenticate", ""), + pair("proxy-authorization", ""), + pair("range", ""), + pair("referer", ""), + pair("refresh", ""), + pair("retry-after", ""), + pair("server", ""), + pair("set-cookie", ""), + pair("strict-transport-security", ""), + pair("transfer-encoding", ""), + pair("user-agent", ""), + pair("vary", ""), + pair("via", ""), + pair("www-authenticate", ""), +} + +var huffmanCodes = []uint32{ + 0x1ff8, + 0x7fffd8, + 0xfffffe2, + 0xfffffe3, + 0xfffffe4, + 0xfffffe5, + 0xfffffe6, + 0xfffffe7, + 0xfffffe8, + 0xffffea, + 0x3ffffffc, + 0xfffffe9, + 0xfffffea, + 0x3ffffffd, + 0xfffffeb, + 0xfffffec, + 0xfffffed, + 0xfffffee, + 0xfffffef, + 0xffffff0, + 0xffffff1, + 0xffffff2, + 0x3ffffffe, + 0xffffff3, + 0xffffff4, + 0xffffff5, + 0xffffff6, + 0xffffff7, + 0xffffff8, + 0xffffff9, + 0xffffffa, + 0xffffffb, + 0x14, + 0x3f8, + 0x3f9, + 0xffa, + 0x1ff9, + 0x15, + 0xf8, + 0x7fa, + 0x3fa, + 0x3fb, + 0xf9, + 0x7fb, + 0xfa, + 0x16, + 0x17, + 0x18, + 0x0, + 0x1, + 0x2, + 0x19, + 0x1a, + 0x1b, + 0x1c, + 0x1d, + 0x1e, + 0x1f, + 0x5c, + 0xfb, + 0x7ffc, + 0x20, + 0xffb, + 0x3fc, + 0x1ffa, + 0x21, + 0x5d, + 0x5e, + 0x5f, + 0x60, + 0x61, + 0x62, + 0x63, + 0x64, + 0x65, + 0x66, + 0x67, + 0x68, + 0x69, + 0x6a, + 0x6b, + 0x6c, + 0x6d, + 0x6e, + 0x6f, + 0x70, + 0x71, + 0x72, + 0xfc, + 0x73, + 0xfd, + 0x1ffb, + 0x7fff0, + 0x1ffc, + 0x3ffc, + 0x22, + 0x7ffd, + 0x3, + 0x23, + 0x4, + 0x24, + 0x5, + 0x25, + 0x26, + 0x27, + 0x6, + 0x74, + 0x75, + 0x28, + 0x29, + 0x2a, + 0x7, + 0x2b, + 0x76, + 0x2c, + 0x8, + 0x9, + 0x2d, + 0x77, + 0x78, + 0x79, + 0x7a, + 0x7b, + 0x7ffe, + 0x7fc, + 0x3ffd, + 0x1ffd, + 0xffffffc, + 0xfffe6, + 0x3fffd2, + 0xfffe7, + 0xfffe8, + 0x3fffd3, + 0x3fffd4, + 0x3fffd5, + 0x7fffd9, + 0x3fffd6, + 0x7fffda, + 0x7fffdb, + 0x7fffdc, + 0x7fffdd, + 0x7fffde, + 0xffffeb, + 0x7fffdf, + 0xffffec, + 0xffffed, + 0x3fffd7, + 0x7fffe0, + 0xffffee, + 0x7fffe1, + 0x7fffe2, + 0x7fffe3, + 0x7fffe4, + 0x1fffdc, + 0x3fffd8, + 0x7fffe5, + 0x3fffd9, + 0x7fffe6, + 0x7fffe7, + 0xffffef, + 0x3fffda, + 0x1fffdd, + 0xfffe9, + 0x3fffdb, + 0x3fffdc, + 0x7fffe8, + 0x7fffe9, + 0x1fffde, + 0x7fffea, + 0x3fffdd, + 0x3fffde, + 0xfffff0, + 0x1fffdf, + 0x3fffdf, + 0x7fffeb, + 0x7fffec, + 0x1fffe0, + 0x1fffe1, + 0x3fffe0, + 0x1fffe2, + 0x7fffed, + 0x3fffe1, + 0x7fffee, + 0x7fffef, + 0xfffea, + 0x3fffe2, + 0x3fffe3, + 0x3fffe4, + 0x7ffff0, + 0x3fffe5, + 0x3fffe6, + 0x7ffff1, + 0x3ffffe0, + 0x3ffffe1, + 0xfffeb, + 0x7fff1, + 0x3fffe7, + 0x7ffff2, + 0x3fffe8, + 0x1ffffec, + 0x3ffffe2, + 0x3ffffe3, + 0x3ffffe4, + 0x7ffffde, + 0x7ffffdf, + 0x3ffffe5, + 0xfffff1, + 0x1ffffed, + 0x7fff2, + 0x1fffe3, + 0x3ffffe6, + 0x7ffffe0, + 0x7ffffe1, + 0x3ffffe7, + 0x7ffffe2, + 0xfffff2, + 0x1fffe4, + 0x1fffe5, + 0x3ffffe8, + 0x3ffffe9, + 0xffffffd, + 0x7ffffe3, + 0x7ffffe4, + 0x7ffffe5, + 0xfffec, + 0xfffff3, + 0xfffed, + 0x1fffe6, + 0x3fffe9, + 0x1fffe7, + 0x1fffe8, + 0x7ffff3, + 0x3fffea, + 0x3fffeb, + 0x1ffffee, + 0x1ffffef, + 0xfffff4, + 0xfffff5, + 0x3ffffea, + 0x7ffff4, + 0x3ffffeb, + 0x7ffffe6, + 0x3ffffec, + 0x3ffffed, + 0x7ffffe7, + 0x7ffffe8, + 0x7ffffe9, + 0x7ffffea, + 0x7ffffeb, + 0xffffffe, + 0x7ffffec, + 0x7ffffed, + 0x7ffffee, + 0x7ffffef, + 0x7fffff0, + 0x3ffffee, +} + +var huffmanCodeLen = []uint8{ + 13, 23, 28, 28, 28, 28, 28, 28, 28, 24, 30, 28, 28, 30, 28, 28, + 28, 28, 28, 28, 28, 28, 30, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 6, 10, 10, 12, 13, 6, 8, 11, 10, 10, 8, 11, 8, 6, 6, 6, + 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 8, 15, 6, 12, 10, + 13, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 8, 7, 8, 13, 19, 13, 14, 6, + 15, 5, 6, 5, 6, 5, 6, 6, 6, 5, 7, 7, 6, 6, 6, 5, + 6, 7, 6, 5, 5, 6, 7, 7, 7, 7, 7, 15, 11, 14, 13, 28, + 20, 22, 20, 20, 22, 22, 22, 23, 22, 23, 23, 23, 23, 23, 24, 23, + 24, 24, 22, 23, 24, 23, 23, 23, 23, 21, 22, 23, 22, 23, 23, 24, + 22, 21, 20, 22, 22, 23, 23, 21, 23, 22, 22, 24, 21, 22, 23, 23, + 21, 21, 22, 21, 23, 22, 23, 23, 20, 22, 22, 22, 23, 22, 22, 23, + 26, 26, 20, 19, 22, 23, 22, 25, 26, 26, 26, 27, 27, 26, 24, 25, + 19, 21, 26, 27, 27, 26, 27, 24, 21, 21, 26, 26, 28, 27, 27, 27, + 20, 24, 20, 21, 22, 21, 21, 23, 22, 22, 25, 25, 24, 24, 26, 23, + 26, 27, 26, 26, 27, 27, 27, 27, 27, 28, 27, 27, 27, 27, 27, 26, +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/http2.go b/Godeps/_workspace/src/github.com/bradfitz/http2/http2.go new file mode 100644 index 0000000..35f9b26 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/http2.go @@ -0,0 +1,249 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +// Package http2 implements the HTTP/2 protocol. +// +// This is a work in progress. This package is low-level and intended +// to be used directly by very few people. Most users will use it +// indirectly through integration with the net/http package. See +// ConfigureServer. That ConfigureServer call will likely be automatic +// or available via an empty import in the future. +// +// See http://http2.github.io/ +package http2 + +import ( + "bufio" + "fmt" + "io" + "net/http" + "strconv" + "sync" +) + +var VerboseLogs = false + +const ( + // ClientPreface is the string that must be sent by new + // connections from clients. + ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + // SETTINGS_MAX_FRAME_SIZE default + // http://http2.github.io/http2-spec/#rfc.section.6.5.2 + initialMaxFrameSize = 16384 + + // NextProtoTLS is the NPN/ALPN protocol negotiated during + // HTTP/2's TLS setup. + NextProtoTLS = "h2" + + // http://http2.github.io/http2-spec/#SettingValues + initialHeaderTableSize = 4096 + + initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size + + defaultMaxReadFrameSize = 1 << 20 +) + +var ( + clientPreface = []byte(ClientPreface) +) + +type streamState int + +const ( + stateIdle streamState = iota + stateOpen + stateHalfClosedLocal + stateHalfClosedRemote + stateResvLocal + stateResvRemote + stateClosed +) + +var stateName = [...]string{ + stateIdle: "Idle", + stateOpen: "Open", + stateHalfClosedLocal: "HalfClosedLocal", + stateHalfClosedRemote: "HalfClosedRemote", + stateResvLocal: "ResvLocal", + stateResvRemote: "ResvRemote", + stateClosed: "Closed", +} + +func (st streamState) String() string { + return stateName[st] +} + +// Setting is a setting parameter: which setting it is, and its value. +type Setting struct { + // ID is which setting is being set. + // See http://http2.github.io/http2-spec/#SettingValues + ID SettingID + + // Val is the value. + Val uint32 +} + +func (s Setting) String() string { + return fmt.Sprintf("[%v = %d]", s.ID, s.Val) +} + +// Valid reports whether the setting is valid. +func (s Setting) Valid() error { + // Limits and error codes from 6.5.2 Defined SETTINGS Parameters + switch s.ID { + case SettingEnablePush: + if s.Val != 1 && s.Val != 0 { + return ConnectionError(ErrCodeProtocol) + } + case SettingInitialWindowSize: + if s.Val > 1<<31-1 { + return ConnectionError(ErrCodeFlowControl) + } + case SettingMaxFrameSize: + if s.Val < 16384 || s.Val > 1<<24-1 { + return ConnectionError(ErrCodeProtocol) + } + } + return nil +} + +// A SettingID is an HTTP/2 setting as defined in +// http://http2.github.io/http2-spec/#iana-settings +type SettingID uint16 + +const ( + SettingHeaderTableSize SettingID = 0x1 + SettingEnablePush SettingID = 0x2 + SettingMaxConcurrentStreams SettingID = 0x3 + SettingInitialWindowSize SettingID = 0x4 + SettingMaxFrameSize SettingID = 0x5 + SettingMaxHeaderListSize SettingID = 0x6 +) + +var settingName = map[SettingID]string{ + SettingHeaderTableSize: "HEADER_TABLE_SIZE", + SettingEnablePush: "ENABLE_PUSH", + SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS", + SettingInitialWindowSize: "INITIAL_WINDOW_SIZE", + SettingMaxFrameSize: "MAX_FRAME_SIZE", + SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE", +} + +func (s SettingID) String() string { + if v, ok := settingName[s]; ok { + return v + } + return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s)) +} + +func validHeader(v string) bool { + if len(v) == 0 { + return false + } + for _, r := range v { + // "Just as in HTTP/1.x, header field names are + // strings of ASCII characters that are compared in a + // case-insensitive fashion. However, header field + // names MUST be converted to lowercase prior to their + // encoding in HTTP/2. " + if r >= 127 || ('A' <= r && r <= 'Z') { + return false + } + } + return true +} + +var httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n) + +func init() { + for i := 100; i <= 999; i++ { + if v := http.StatusText(i); v != "" { + httpCodeStringCommon[i] = strconv.Itoa(i) + } + } +} + +func httpCodeString(code int) string { + if s, ok := httpCodeStringCommon[code]; ok { + return s + } + return strconv.Itoa(code) +} + +// from pkg io +type stringWriter interface { + WriteString(s string) (n int, err error) +} + +// A gate lets two goroutines coordinate their activities. +type gate chan struct{} + +func (g gate) Done() { g <- struct{}{} } +func (g gate) Wait() { <-g } + +// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed). +type closeWaiter chan struct{} + +// Init makes a closeWaiter usable. +// It exists because so a closeWaiter value can be placed inside a +// larger struct and have the Mutex and Cond's memory in the same +// allocation. +func (cw *closeWaiter) Init() { + *cw = make(chan struct{}) +} + +// Close marks the closeWaiter as closed and unblocks any waiters. +func (cw closeWaiter) Close() { + close(cw) +} + +// Wait waits for the closeWaiter to become closed. +func (cw closeWaiter) Wait() { + <-cw +} + +// bufferedWriter is a buffered writer that writes to w. +// Its buffered writer is lazily allocated as needed, to minimize +// idle memory usage with many connections. +type bufferedWriter struct { + w io.Writer // immutable + bw *bufio.Writer // non-nil when data is buffered +} + +func newBufferedWriter(w io.Writer) *bufferedWriter { + return &bufferedWriter{w: w} +} + +var bufWriterPool = sync.Pool{ + New: func() interface{} { + // TODO: pick something better? this is a bit under + // (3 x typical 1500 byte MTU) at least. + return bufio.NewWriterSize(nil, 4<<10) + }, +} + +func (w *bufferedWriter) Write(p []byte) (n int, err error) { + if w.bw == nil { + bw := bufWriterPool.Get().(*bufio.Writer) + bw.Reset(w.w) + w.bw = bw + } + return w.bw.Write(p) +} + +func (w *bufferedWriter) Flush() error { + bw := w.bw + if bw == nil { + return nil + } + err := bw.Flush() + bw.Reset(nil) + bufWriterPool.Put(bw) + w.bw = nil + return err +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/http2_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/http2_test.go new file mode 100644 index 0000000..6a1b7e8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/http2_test.go @@ -0,0 +1,152 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "bytes" + "errors" + "flag" + "fmt" + "net/http" + "os/exec" + "strconv" + "strings" + "testing" + + "github.com/bradfitz/http2/hpack" +) + +var knownFailing = flag.Bool("known_failing", false, "Run known-failing tests.") + +func condSkipFailingTest(t *testing.T) { + if !*knownFailing { + t.Skip("Skipping known-failing test without --known_failing") + } +} + +func init() { + DebugGoroutines = true + flag.BoolVar(&VerboseLogs, "verboseh2", false, "Verbose HTTP/2 debug logging") +} + +func TestSettingString(t *testing.T) { + tests := []struct { + s Setting + want string + }{ + {Setting{SettingMaxFrameSize, 123}, "[MAX_FRAME_SIZE = 123]"}, + {Setting{1<<16 - 1, 123}, "[UNKNOWN_SETTING_65535 = 123]"}, + } + for i, tt := range tests { + got := fmt.Sprint(tt.s) + if got != tt.want { + t.Errorf("%d. for %#v, string = %q; want %q", i, tt.s, got, tt.want) + } + } +} + +type twriter struct { + t testing.TB + st *serverTester // optional +} + +func (w twriter) Write(p []byte) (n int, err error) { + if w.st != nil { + ps := string(p) + for _, phrase := range w.st.logFilter { + if strings.Contains(ps, phrase) { + return len(p), nil // no logging + } + } + } + w.t.Logf("%s", p) + return len(p), nil +} + +// like encodeHeader, but don't add implicit psuedo headers. +func encodeHeaderNoImplicit(t *testing.T, headers ...string) []byte { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + for len(headers) > 0 { + k, v := headers[0], headers[1] + headers = headers[2:] + if err := enc.WriteField(hpack.HeaderField{Name: k, Value: v}); err != nil { + t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) + } + } + return buf.Bytes() +} + +// Verify that curl has http2. +func requireCurl(t *testing.T) { + out, err := dockerLogs(curl(t, "--version")) + if err != nil { + t.Skipf("failed to determine curl features; skipping test") + } + if !strings.Contains(string(out), "HTTP2") { + t.Skip("curl doesn't support HTTP2; skipping test") + } +} + +func curl(t *testing.T, args ...string) (container string) { + out, err := exec.Command("docker", append([]string{"run", "-d", "--net=host", "gohttp2/curl"}, args...)...).CombinedOutput() + if err != nil { + t.Skipf("Failed to run curl in docker: %v, %s", err, out) + } + return strings.TrimSpace(string(out)) +} + +type puppetCommand struct { + fn func(w http.ResponseWriter, r *http.Request) + done chan<- bool +} + +type handlerPuppet struct { + ch chan puppetCommand +} + +func newHandlerPuppet() *handlerPuppet { + return &handlerPuppet{ + ch: make(chan puppetCommand), + } +} + +func (p *handlerPuppet) act(w http.ResponseWriter, r *http.Request) { + for cmd := range p.ch { + cmd.fn(w, r) + cmd.done <- true + } +} + +func (p *handlerPuppet) done() { close(p.ch) } +func (p *handlerPuppet) do(fn func(http.ResponseWriter, *http.Request)) { + done := make(chan bool) + p.ch <- puppetCommand{fn, done} + <-done +} +func dockerLogs(container string) ([]byte, error) { + out, err := exec.Command("docker", "wait", container).CombinedOutput() + if err != nil { + return out, err + } + exitStatus, err := strconv.Atoi(strings.TrimSpace(string(out))) + if err != nil { + return out, errors.New("unexpected exit status from docker wait") + } + out, err = exec.Command("docker", "logs", container).CombinedOutput() + exec.Command("docker", "rm", container).Run() + if err == nil && exitStatus != 0 { + err = fmt.Errorf("exit status %d: %s", exitStatus, out) + } + return out, err +} + +func kill(container string) { + exec.Command("docker", "kill", container).Run() + exec.Command("docker", "rm", container).Run() +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/pipe.go b/Godeps/_workspace/src/github.com/bradfitz/http2/pipe.go new file mode 100644 index 0000000..ce9aad5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/pipe.go @@ -0,0 +1,43 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "sync" +) + +type pipe struct { + b buffer + c sync.Cond + m sync.Mutex +} + +// Read waits until data is available and copies bytes +// from the buffer into p. +func (r *pipe) Read(p []byte) (n int, err error) { + r.c.L.Lock() + defer r.c.L.Unlock() + for r.b.Len() == 0 && !r.b.closed { + r.c.Wait() + } + return r.b.Read(p) +} + +// Write copies bytes from p into the buffer and wakes a reader. +// It is an error to write more data than the buffer can hold. +func (w *pipe) Write(p []byte) (n int, err error) { + w.c.L.Lock() + defer w.c.L.Unlock() + defer w.c.Signal() + return w.b.Write(p) +} + +func (c *pipe) Close(err error) { + c.c.L.Lock() + defer c.c.L.Unlock() + defer c.c.Signal() + c.b.Close(err) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/pipe_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/pipe_test.go new file mode 100644 index 0000000..10c4c32 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/pipe_test.go @@ -0,0 +1,24 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "errors" + "testing" +) + +func TestPipeClose(t *testing.T) { + var p pipe + p.c.L = &p.m + a := errors.New("a") + b := errors.New("b") + p.Close(a) + p.Close(b) + _, err := p.Read(make([]byte, 1)) + if err != a { + t.Errorf("err = %v want %v", err, a) + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/priority_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/priority_test.go new file mode 100644 index 0000000..d9648fd --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/priority_test.go @@ -0,0 +1,121 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "testing" +) + +func TestPriority(t *testing.T) { + // A -> B + // move A's parent to B + streams := make(map[uint32]*stream) + a := &stream{ + parent: nil, + weight: 16, + } + streams[1] = a + b := &stream{ + parent: a, + weight: 16, + } + streams[2] = b + adjustStreamPriority(streams, 1, PriorityParam{ + Weight: 20, + StreamDep: 2, + }) + if a.parent != b { + t.Errorf("Expected A's parent to be B") + } + if a.weight != 20 { + t.Errorf("Expected A's weight to be 20; got %d", a.weight) + } + if b.parent != nil { + t.Errorf("Expected B to have no parent") + } + if b.weight != 16 { + t.Errorf("Expected B's weight to be 16; got %d", b.weight) + } +} + +func TestPriorityExclusiveZero(t *testing.T) { + // A B and C are all children of the 0 stream. + // Exclusive reprioritization to any of the streams + // should bring the rest of the streams under the + // reprioritized stream + streams := make(map[uint32]*stream) + a := &stream{ + parent: nil, + weight: 16, + } + streams[1] = a + b := &stream{ + parent: nil, + weight: 16, + } + streams[2] = b + c := &stream{ + parent: nil, + weight: 16, + } + streams[3] = c + adjustStreamPriority(streams, 3, PriorityParam{ + Weight: 20, + StreamDep: 0, + Exclusive: true, + }) + if a.parent != c { + t.Errorf("Expected A's parent to be C") + } + if a.weight != 16 { + t.Errorf("Expected A's weight to be 16; got %d", a.weight) + } + if b.parent != c { + t.Errorf("Expected B's parent to be C") + } + if b.weight != 16 { + t.Errorf("Expected B's weight to be 16; got %d", b.weight) + } + if c.parent != nil { + t.Errorf("Expected C to have no parent") + } + if c.weight != 20 { + t.Errorf("Expected C's weight to be 20; got %d", b.weight) + } +} + +func TestPriorityOwnParent(t *testing.T) { + streams := make(map[uint32]*stream) + a := &stream{ + parent: nil, + weight: 16, + } + streams[1] = a + b := &stream{ + parent: a, + weight: 16, + } + streams[2] = b + adjustStreamPriority(streams, 1, PriorityParam{ + Weight: 20, + StreamDep: 1, + }) + if a.parent != nil { + t.Errorf("Expected A's parent to be nil") + } + if a.weight != 20 { + t.Errorf("Expected A's weight to be 20; got %d", a.weight) + } + if b.parent != a { + t.Errorf("Expected B's parent to be A") + } + if b.weight != 16 { + t.Errorf("Expected B's weight to be 16; got %d", b.weight) + } + +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/server.go b/Godeps/_workspace/src/github.com/bradfitz/http2/server.go new file mode 100644 index 0000000..aa0e7bd --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/server.go @@ -0,0 +1,1780 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +// TODO: replace all <-sc.doneServing with reads from the stream's cw +// instead, and make sure that on close we close all open +// streams. then remove doneServing? + +// TODO: finish GOAWAY support. Consider each incoming frame type and +// whether it should be ignored during a shutdown race. + +// TODO: disconnect idle clients. GFE seems to do 4 minutes. make +// configurable? or maximum number of idle clients and remove the +// oldest? + +// TODO: turn off the serve goroutine when idle, so +// an idle conn only has the readFrames goroutine active. (which could +// also be optimized probably to pin less memory in crypto/tls). This +// would involve tracking when the serve goroutine is active (atomic +// int32 read/CAS probably?) and starting it up when frames arrive, +// and shutting it down when all handlers exit. the occasional PING +// packets could use time.AfterFunc to call sc.wakeStartServeLoop() +// (which is a no-op if already running) and then queue the PING write +// as normal. The serve loop would then exit in most cases (if no +// Handlers running) and not be woken up again until the PING packet +// returns. + +// TODO (maybe): add a mechanism for Handlers to going into +// half-closed-local mode (rw.(io.Closer) test?) but not exit their +// handler, and continue to be able to read from the +// Request.Body. This would be a somewhat semantic change from HTTP/1 +// (or at least what we expose in net/http), so I'd probably want to +// add it there too. For now, this package says that returning from +// the Handler ServeHTTP function means you're both done reading and +// done writing, without a way to stop just one or the other. + +package http2 + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/bradfitz/http2/hpack" +) + +const ( + prefaceTimeout = 10 * time.Second + firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway + handlerChunkWriteSize = 4 << 10 + defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? +) + +var ( + errClientDisconnected = errors.New("client disconnected") + errClosedBody = errors.New("body closed by handler") + errStreamBroken = errors.New("http2: stream broken") +) + +var responseWriterStatePool = sync.Pool{ + New: func() interface{} { + rws := &responseWriterState{} + rws.bw = bufio.NewWriterSize(chunkWriter{rws}, handlerChunkWriteSize) + return rws + }, +} + +// Test hooks. +var ( + testHookOnConn func() + testHookGetServerConn func(*serverConn) + testHookOnPanicMu *sync.Mutex // nil except in tests + testHookOnPanic func(sc *serverConn, panicVal interface{}) (rePanic bool) +) + +// Server is an HTTP/2 server. +type Server struct { + // MaxHandlers limits the number of http.Handler ServeHTTP goroutines + // which may run at a time over all connections. + // Negative or zero no limit. + // TODO: implement + MaxHandlers int + + // MaxConcurrentStreams optionally specifies the number of + // concurrent streams that each client may have open at a + // time. This is unrelated to the number of http.Handler goroutines + // which may be active globally, which is MaxHandlers. + // If zero, MaxConcurrentStreams defaults to at least 100, per + // the HTTP/2 spec's recommendations. + MaxConcurrentStreams uint32 + + // MaxReadFrameSize optionally specifies the largest frame + // this server is willing to read. A valid value is between + // 16k and 16M, inclusive. If zero or otherwise invalid, a + // default value is used. + MaxReadFrameSize uint32 + + // PermitProhibitedCipherSuites, if true, permits the use of + // cipher suites prohibited by the HTTP/2 spec. + PermitProhibitedCipherSuites bool +} + +func (s *Server) maxReadFrameSize() uint32 { + if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize { + return v + } + return defaultMaxReadFrameSize +} + +func (s *Server) maxConcurrentStreams() uint32 { + if v := s.MaxConcurrentStreams; v > 0 { + return v + } + return defaultMaxStreams +} + +// ConfigureServer adds HTTP/2 support to a net/http Server. +// +// The configuration conf may be nil. +// +// ConfigureServer must be called before s begins serving. +func ConfigureServer(s *http.Server, conf *Server) { + if conf == nil { + conf = new(Server) + } + if s.TLSConfig == nil { + s.TLSConfig = new(tls.Config) + } + + // Note: not setting MinVersion to tls.VersionTLS12, + // as we don't want to interfere with HTTP/1.1 traffic + // on the user's server. We enforce TLS 1.2 later once + // we accept a connection. Ideally this should be done + // during next-proto selection, but using TLS <1.2 with + // HTTP/2 is still the client's bug. + + // Be sure we advertise tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + // at least. + // TODO: enable PreferServerCipherSuites? + if s.TLSConfig.CipherSuites != nil { + const requiredCipher = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + haveRequired := false + for _, v := range s.TLSConfig.CipherSuites { + if v == requiredCipher { + haveRequired = true + break + } + } + if !haveRequired { + s.TLSConfig.CipherSuites = append(s.TLSConfig.CipherSuites, requiredCipher) + } + } + + haveNPN := false + for _, p := range s.TLSConfig.NextProtos { + if p == NextProtoTLS { + haveNPN = true + break + } + } + if !haveNPN { + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS) + } + // h2-14 is temporary (as of 2015-03-05) while we wait for all browsers + // to switch to "h2". + s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14") + + if s.TLSNextProto == nil { + s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){} + } + protoHandler := func(hs *http.Server, c *tls.Conn, h http.Handler) { + if testHookOnConn != nil { + testHookOnConn() + } + conf.handleConn(hs, c, h) + } + s.TLSNextProto[NextProtoTLS] = protoHandler + s.TLSNextProto["h2-14"] = protoHandler // temporary; see above. +} + +func (srv *Server) handleConn(hs *http.Server, c net.Conn, h http.Handler) { + sc := &serverConn{ + srv: srv, + hs: hs, + conn: c, + remoteAddrStr: c.RemoteAddr().String(), + bw: newBufferedWriter(c), + handler: h, + streams: make(map[uint32]*stream), + readFrameCh: make(chan frameAndGate), + readFrameErrCh: make(chan error, 1), // must be buffered for 1 + wantWriteFrameCh: make(chan frameWriteMsg, 8), + wroteFrameCh: make(chan struct{}, 1), // buffered; one send in reading goroutine + bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way + doneServing: make(chan struct{}), + advMaxStreams: srv.maxConcurrentStreams(), + writeSched: writeScheduler{ + maxFrameSize: initialMaxFrameSize, + }, + initialWindowSize: initialWindowSize, + headerTableSize: initialHeaderTableSize, + serveG: newGoroutineLock(), + pushEnabled: true, + } + sc.flow.add(initialWindowSize) + sc.inflow.add(initialWindowSize) + sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) + sc.hpackDecoder = hpack.NewDecoder(initialHeaderTableSize, sc.onNewHeaderField) + + fr := NewFramer(sc.bw, c) + fr.SetMaxReadFrameSize(srv.maxReadFrameSize()) + sc.framer = fr + + if tc, ok := c.(*tls.Conn); ok { + sc.tlsState = new(tls.ConnectionState) + *sc.tlsState = tc.ConnectionState() + // 9.2 Use of TLS Features + // An implementation of HTTP/2 over TLS MUST use TLS + // 1.2 or higher with the restrictions on feature set + // and cipher suite described in this section. Due to + // implementation limitations, it might not be + // possible to fail TLS negotiation. An endpoint MUST + // immediately terminate an HTTP/2 connection that + // does not meet the TLS requirements described in + // this section with a connection error (Section + // 5.4.1) of type INADEQUATE_SECURITY. + if sc.tlsState.Version < tls.VersionTLS12 { + sc.rejectConn(ErrCodeInadequateSecurity, "TLS version too low") + return + } + + if sc.tlsState.ServerName == "" { + // Client must use SNI, but we don't enforce that anymore, + // since it was causing problems when connecting to bare IP + // addresses during development. + // + // TODO: optionally enforce? Or enforce at the time we receive + // a new request, and verify the the ServerName matches the :authority? + // But that precludes proxy situations, perhaps. + // + // So for now, do nothing here again. + } + + if !srv.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { + // "Endpoints MAY choose to generate a connection error + // (Section 5.4.1) of type INADEQUATE_SECURITY if one of + // the prohibited cipher suites are negotiated." + // + // We choose that. In my opinion, the spec is weak + // here. It also says both parties must support at least + // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 so there's no + // excuses here. If we really must, we could allow an + // "AllowInsecureWeakCiphers" option on the server later. + // Let's see how it plays out first. + sc.rejectConn(ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite)) + return + } + } + + if hook := testHookGetServerConn; hook != nil { + hook(sc) + } + sc.serve() +} + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +func isBadCipher(cipher uint16) bool { + switch cipher { + case tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + // Reject cipher suites from Appendix A. + // "This list includes those cipher suites that do not + // offer an ephemeral key exchange and those that are + // based on the TLS null, stream or block cipher type" + return true + default: + return false + } +} + +func (sc *serverConn) rejectConn(err ErrCode, debug string) { + log.Printf("REJECTING conn: %v, %s", err, debug) + // ignoring errors. hanging up anyway. + sc.framer.WriteGoAway(0, err, []byte(debug)) + sc.bw.Flush() + sc.conn.Close() +} + +// frameAndGates coordinates the readFrames and serve +// goroutines. Because the Framer interface only permits the most +// recently-read Frame from being accessed, the readFrames goroutine +// blocks until it has a frame, passes it to serve, and then waits for +// serve to be done with it before reading the next one. +type frameAndGate struct { + f Frame + g gate +} + +type serverConn struct { + // Immutable: + srv *Server + hs *http.Server + conn net.Conn + bw *bufferedWriter // writing to conn + handler http.Handler + framer *Framer + hpackDecoder *hpack.Decoder + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan frameAndGate // written by serverConn.readFrames + readFrameErrCh chan error + wantWriteFrameCh chan frameWriteMsg // from handlers -> serve + wroteFrameCh chan struct{} // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan bodyReadMsg // from handlers -> serve + testHookCh chan func() // code to run on the serve loop + flow flow // conn-wide (not stream-specific) outbound flow control + inflow flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http + remoteAddrStr string + + // Everything following is owned by the serve loop; use serveG.check(): + serveG goroutineLock // used to verify funcs are on serve() + pushEnabled bool + sawFirstSettings bool // got the initial SETTINGS frame after the preface + needToSendSettingsAck bool + unackedSettings int // how many SETTINGS have we sent without ACKs? + clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) + advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client + curOpenStreams uint32 // client's number of open streams + maxStreamID uint32 // max ever seen + streams map[uint32]*stream + initialWindowSize int32 + headerTableSize uint32 + maxHeaderListSize uint32 // zero means unknown (default) + canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case + req requestParam // non-zero while reading request headers + writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh + needsFrameFlush bool // last frame write wasn't a flush + writeSched writeScheduler + inGoAway bool // we've started to or sent GOAWAY + needToSendGoAway bool // we need to schedule a GOAWAY frame write + goAwayCode ErrCode + shutdownTimerCh <-chan time.Time // nil until used + shutdownTimer *time.Timer // nil until used + + // Owned by the writeFrameAsync goroutine: + headerWriteBuf bytes.Buffer + hpackEncoder *hpack.Encoder +} + +// requestParam is the state of the next request, initialized over +// potentially several frames HEADERS + zero or more CONTINUATION +// frames. +type requestParam struct { + // stream is non-nil if we're reading (HEADER or CONTINUATION) + // frames for a request (but not DATA). + stream *stream + header http.Header + method, path string + scheme, authority string + sawRegularHeader bool // saw a non-pseudo header already + invalidHeader bool // an invalid header was seen +} + +// stream represents a stream. This is the minimal metadata needed by +// the serve goroutine. Most of the actual stream state is owned by +// the http.Handler's goroutine in the responseWriter. Because the +// responseWriter's responseWriterState is recycled at the end of a +// handler, this struct intentionally has no pointer to the +// *responseWriter{,State} itself, as the Handler ending nils out the +// responseWriter's state field. +type stream struct { + // immutable: + id uint32 + body *pipe // non-nil if expecting DATA frames + cw closeWaiter // closed wait stream transitions to closed state + + // owned by serverConn's serve loop: + bodyBytes int64 // body bytes seen so far + declBodyBytes int64 // or -1 if undeclared + flow flow // limits writing from Handler to client + inflow flow // what the client is allowed to POST/etc to us + parent *stream // or nil + weight uint8 + state streamState + sentReset bool // only true once detached from streams map + gotReset bool // only true once detacted from streams map +} + +func (sc *serverConn) Framer() *Framer { return sc.framer } +func (sc *serverConn) CloseConn() error { return sc.conn.Close() } +func (sc *serverConn) Flush() error { return sc.bw.Flush() } +func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) { + return sc.hpackEncoder, &sc.headerWriteBuf +} + +func (sc *serverConn) state(streamID uint32) (streamState, *stream) { + sc.serveG.check() + // http://http2.github.io/http2-spec/#rfc.section.5.1 + if st, ok := sc.streams[streamID]; ok { + return st.state, st + } + // "The first use of a new stream identifier implicitly closes all + // streams in the "idle" state that might have been initiated by + // that peer with a lower-valued stream identifier. For example, if + // a client sends a HEADERS frame on stream 7 without ever sending a + // frame on stream 5, then stream 5 transitions to the "closed" + // state when the first frame for stream 7 is sent or received." + if streamID <= sc.maxStreamID { + return stateClosed, nil + } + return stateIdle, nil +} + +func (sc *serverConn) vlogf(format string, args ...interface{}) { + if VerboseLogs { + sc.logf(format, args...) + } +} + +func (sc *serverConn) logf(format string, args ...interface{}) { + if lg := sc.hs.ErrorLog; lg != nil { + lg.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +func (sc *serverConn) condlogf(err error, format string, args ...interface{}) { + if err == nil { + return + } + str := err.Error() + if err == io.EOF || strings.Contains(str, "use of closed network connection") { + // Boring, expected errors. + sc.vlogf(format, args...) + } else { + sc.logf(format, args...) + } +} + +func (sc *serverConn) onNewHeaderField(f hpack.HeaderField) { + sc.serveG.check() + sc.vlogf("got header field %+v", f) + switch { + case !validHeader(f.Name): + sc.req.invalidHeader = true + case strings.HasPrefix(f.Name, ":"): + if sc.req.sawRegularHeader { + sc.logf("pseudo-header after regular header") + sc.req.invalidHeader = true + return + } + var dst *string + switch f.Name { + case ":method": + dst = &sc.req.method + case ":path": + dst = &sc.req.path + case ":scheme": + dst = &sc.req.scheme + case ":authority": + dst = &sc.req.authority + default: + // 8.1.2.1 Pseudo-Header Fields + // "Endpoints MUST treat a request or response + // that contains undefined or invalid + // pseudo-header fields as malformed (Section + // 8.1.2.6)." + sc.logf("invalid pseudo-header %q", f.Name) + sc.req.invalidHeader = true + return + } + if *dst != "" { + sc.logf("duplicate pseudo-header %q sent", f.Name) + sc.req.invalidHeader = true + return + } + *dst = f.Value + case f.Name == "cookie": + sc.req.sawRegularHeader = true + if s, ok := sc.req.header["Cookie"]; ok && len(s) == 1 { + s[0] = s[0] + "; " + f.Value + } else { + sc.req.header.Add("Cookie", f.Value) + } + default: + sc.req.sawRegularHeader = true + sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value) + } +} + +func (sc *serverConn) canonicalHeader(v string) string { + sc.serveG.check() + cv, ok := commonCanonHeader[v] + if ok { + return cv + } + cv, ok = sc.canonHeader[v] + if ok { + return cv + } + if sc.canonHeader == nil { + sc.canonHeader = make(map[string]string) + } + cv = http.CanonicalHeaderKey(v) + sc.canonHeader[v] = cv + return cv +} + +// readFrames is the loop that reads incoming frames. +// It's run on its own goroutine. +func (sc *serverConn) readFrames() { + g := make(gate, 1) + for { + f, err := sc.framer.ReadFrame() + if err != nil { + sc.readFrameErrCh <- err + close(sc.readFrameCh) + return + } + sc.readFrameCh <- frameAndGate{f, g} + // We can't read another frame until this one is + // processed, as the ReadFrame interface doesn't copy + // memory. The Frame accessor methods access the last + // frame's (shared) buffer. So we wait for the + // serve goroutine to tell us it's done: + g.Wait() + } +} + +// writeFrameAsync runs in its own goroutine and writes a single frame +// and then reports when it's done. +// At most one goroutine can be running writeFrameAsync at a time per +// serverConn. +func (sc *serverConn) writeFrameAsync(wm frameWriteMsg) { + err := wm.write.writeFrame(sc) + if ch := wm.done; ch != nil { + select { + case ch <- err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write)) + } + } + sc.wroteFrameCh <- struct{}{} // tickle frame selection scheduler +} + +func (sc *serverConn) closeAllStreamsOnConnClose() { + sc.serveG.check() + for _, st := range sc.streams { + sc.closeStream(st, errClientDisconnected) + } +} + +func (sc *serverConn) stopShutdownTimer() { + sc.serveG.check() + if t := sc.shutdownTimer; t != nil { + t.Stop() + } +} + +func (sc *serverConn) notePanic() { + if testHookOnPanicMu != nil { + testHookOnPanicMu.Lock() + defer testHookOnPanicMu.Unlock() + } + if testHookOnPanic != nil { + if e := recover(); e != nil { + if testHookOnPanic(sc, e) { + panic(e) + } + } + } +} + +func (sc *serverConn) serve() { + sc.serveG.check() + defer sc.notePanic() + defer sc.conn.Close() + defer sc.closeAllStreamsOnConnClose() + defer sc.stopShutdownTimer() + defer close(sc.doneServing) // unblocks handlers trying to send + + sc.vlogf("HTTP/2 connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) + + sc.writeFrame(frameWriteMsg{ + write: writeSettings{ + {SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, + {SettingMaxConcurrentStreams, sc.advMaxStreams}, + + // TODO: more actual settings, notably + // SettingInitialWindowSize, but then we also + // want to bump up the conn window size the + // same amount here right after the settings + }, + }) + sc.unackedSettings++ + + if err := sc.readPreface(); err != nil { + sc.condlogf(err, "error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) + return + } + + go sc.readFrames() // closed by defer sc.conn.Close above + + settingsTimer := time.NewTimer(firstSettingsTimeout) + for { + select { + case wm := <-sc.wantWriteFrameCh: + sc.writeFrame(wm) + case <-sc.wroteFrameCh: + if sc.writingFrame != true { + panic("internal error: expected to be already writing a frame") + } + sc.writingFrame = false + sc.scheduleFrameWrite() + case fg, ok := <-sc.readFrameCh: + if !ok { + sc.readFrameCh = nil + } + if !sc.processFrameFromReader(fg, ok) { + return + } + if settingsTimer.C != nil { + settingsTimer.Stop() + settingsTimer.C = nil + } + case m := <-sc.bodyReadCh: + sc.noteBodyRead(m.st, m.n) + case <-settingsTimer.C: + sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) + return + case <-sc.shutdownTimerCh: + sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) + return + case fn := <-sc.testHookCh: + fn() + } + } +} + +// readPreface reads the ClientPreface greeting from the peer +// or returns an error on timeout or an invalid greeting. +func (sc *serverConn) readPreface() error { + errc := make(chan error, 1) + go func() { + // Read the client preface + buf := make([]byte, len(ClientPreface)) + if _, err := io.ReadFull(sc.conn, buf); err != nil { + errc <- err + } else if !bytes.Equal(buf, clientPreface) { + errc <- fmt.Errorf("bogus greeting %q", buf) + } else { + errc <- nil + } + }() + timer := time.NewTimer(prefaceTimeout) // TODO: configurable on *Server? + defer timer.Stop() + select { + case <-timer.C: + return errors.New("timeout waiting for client preface") + case err := <-errc: + if err == nil { + sc.vlogf("client %v said hello", sc.conn.RemoteAddr()) + } + return err + } +} + +// writeDataFromHandler writes the data described in req to stream.id. +// +// The provided ch is used to avoid allocating new channels for each +// write operation. It's expected that the caller reuses writeData and ch +// over time. +// +// The flow control currently happens in the Handler where it waits +// for 1 or more bytes to be available to then write here. So at this +// point we know that we have flow control. But this might have to +// change when priority is implemented, so the serve goroutine knows +// the total amount of bytes waiting to be sent and can can have more +// scheduling decisions available. +func (sc *serverConn) writeDataFromHandler(stream *stream, writeData *writeData, ch chan error) error { + sc.writeFrameFromHandler(frameWriteMsg{ + write: writeData, + stream: stream, + done: ch, + }) + select { + case err := <-ch: + return err + case <-sc.doneServing: + return errClientDisconnected + case <-stream.cw: + return errStreamBroken + } +} + +// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts +// if the connection has gone away. +// +// This must not be run from the serve goroutine itself, else it might +// deadlock writing to sc.wantWriteFrameCh (which is only mildly +// buffered and is read by serve itself). If you're on the serve +// goroutine, call writeFrame instead. +func (sc *serverConn) writeFrameFromHandler(wm frameWriteMsg) { + sc.serveG.checkNotOn() // NOT + select { + case sc.wantWriteFrameCh <- wm: + case <-sc.doneServing: + // Client has closed their connection to the server. + } +} + +// writeFrame schedules a frame to write and sends it if there's nothing +// already being written. +// +// There is no pushback here (the serve goroutine never blocks). It's +// the http.Handlers that block, waiting for their previous frames to +// make it onto the wire +// +// If you're not on the serve goroutine, use writeFrameFromHandler instead. +func (sc *serverConn) writeFrame(wm frameWriteMsg) { + sc.serveG.check() + sc.writeSched.add(wm) + sc.scheduleFrameWrite() +} + +// startFrameWrite starts a goroutine to write wm (in a separate +// goroutine since that might block on the network), and updates the +// serve goroutine's state about the world, updated from info in wm. +func (sc *serverConn) startFrameWrite(wm frameWriteMsg) { + sc.serveG.check() + if sc.writingFrame { + panic("internal error: can only be writing one frame at a time") + } + sc.writingFrame = true + + st := wm.stream + if st != nil { + switch st.state { + case stateHalfClosedLocal: + panic("internal error: attempt to send frame on half-closed-local stream") + case stateClosed: + if st.sentReset || st.gotReset { + // Skip this frame. But fake the frame write to reschedule: + sc.wroteFrameCh <- struct{}{} + return + } + panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm)) + } + } + + sc.needsFrameFlush = true + if endsStream(wm.write) { + if st == nil { + panic("internal error: expecting non-nil stream") + } + switch st.state { + case stateOpen: + // Here we would go to stateHalfClosedLocal in + // theory, but since our handler is done and + // the net/http package provides no mechanism + // for finishing writing to a ResponseWriter + // while still reading data (see possible TODO + // at top of this file), we go into closed + // state here anyway, after telling the peer + // we're hanging up on them. + st.state = stateHalfClosedLocal // won't last long, but necessary for closeStream via resetStream + errCancel := StreamError{st.id, ErrCodeCancel} + sc.resetStream(errCancel) + case stateHalfClosedRemote: + sc.closeStream(st, nil) + } + } + go sc.writeFrameAsync(wm) +} + +// scheduleFrameWrite tickles the frame writing scheduler. +// +// If a frame is already being written, nothing happens. This will be called again +// when the frame is done being written. +// +// If a frame isn't being written we need to send one, the best frame +// to send is selected, preferring first things that aren't +// stream-specific (e.g. ACKing settings), and then finding the +// highest priority stream. +// +// If a frame isn't being written and there's nothing else to send, we +// flush the write buffer. +func (sc *serverConn) scheduleFrameWrite() { + sc.serveG.check() + if sc.writingFrame { + return + } + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(frameWriteMsg{ + write: &writeGoAway{ + maxStreamID: sc.maxStreamID, + code: sc.goAwayCode, + }, + }) + return + } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(frameWriteMsg{write: writeSettingsAck{}}) + return + } + if !sc.inGoAway { + if wm, ok := sc.writeSched.take(); ok { + sc.startFrameWrite(wm) + return + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(frameWriteMsg{write: flushFrameWriter{}}) + sc.needsFrameFlush = false // after startFrameWrite, since it sets this true + return + } +} + +func (sc *serverConn) goAway(code ErrCode) { + sc.serveG.check() + if sc.inGoAway { + return + } + if code != ErrCodeNo { + sc.shutDownIn(250 * time.Millisecond) + } else { + // TODO: configurable + sc.shutDownIn(1 * time.Second) + } + sc.inGoAway = true + sc.needToSendGoAway = true + sc.goAwayCode = code + sc.scheduleFrameWrite() +} + +func (sc *serverConn) shutDownIn(d time.Duration) { + sc.serveG.check() + sc.shutdownTimer = time.NewTimer(d) + sc.shutdownTimerCh = sc.shutdownTimer.C +} + +func (sc *serverConn) resetStream(se StreamError) { + sc.serveG.check() + sc.writeFrame(frameWriteMsg{write: se}) + if st, ok := sc.streams[se.StreamID]; ok { + st.sentReset = true + sc.closeStream(st, se) + } +} + +// curHeaderStreamID returns the stream ID of the header block we're +// currently in the middle of reading. If this returns non-zero, the +// next frame must be a CONTINUATION with this stream id. +func (sc *serverConn) curHeaderStreamID() uint32 { + sc.serveG.check() + st := sc.req.stream + if st == nil { + return 0 + } + return st.id +} + +// processFrameFromReader processes the serve loop's read from readFrameCh from the +// frame-reading goroutine. +// processFrameFromReader returns whether the connection should be kept open. +func (sc *serverConn) processFrameFromReader(fg frameAndGate, fgValid bool) bool { + sc.serveG.check() + var clientGone bool + var err error + if !fgValid { + err = <-sc.readFrameErrCh + if err == ErrFrameTooLarge { + sc.goAway(ErrCodeFrameSize) + return true // goAway will close the loop + } + clientGone = err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") + if clientGone { + // TODO: could we also get into this state if + // the peer does a half close + // (e.g. CloseWrite) because they're done + // sending frames but they're still wanting + // our open replies? Investigate. + // TODO: add CloseWrite to crypto/tls.Conn first + // so we have a way to test this? I suppose + // just for testing we could have a non-TLS mode. + return false + } + } + + if fgValid { + f := fg.f + sc.vlogf("got %v: %#v", f.Header(), f) + err = sc.processFrame(f) + fg.g.Done() // unblock the readFrames goroutine + if err == nil { + return true + } + } + + switch ev := err.(type) { + case StreamError: + sc.resetStream(ev) + return true + case goAwayFlowError: + sc.goAway(ErrCodeFlowControl) + return true + case ConnectionError: + sc.logf("%v: %v", sc.conn.RemoteAddr(), ev) + sc.goAway(ErrCode(ev)) + return true // goAway will handle shutdown + default: + if !fgValid { + sc.logf("disconnecting; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err) + } else { + sc.logf("disconnection due to other error: %v", err) + } + } + return false +} + +func (sc *serverConn) processFrame(f Frame) error { + sc.serveG.check() + + // First frame received must be SETTINGS. + if !sc.sawFirstSettings { + if _, ok := f.(*SettingsFrame); !ok { + return ConnectionError(ErrCodeProtocol) + } + sc.sawFirstSettings = true + } + + if s := sc.curHeaderStreamID(); s != 0 { + if cf, ok := f.(*ContinuationFrame); !ok { + return ConnectionError(ErrCodeProtocol) + } else if cf.Header().StreamID != s { + return ConnectionError(ErrCodeProtocol) + } + } + + switch f := f.(type) { + case *SettingsFrame: + return sc.processSettings(f) + case *HeadersFrame: + return sc.processHeaders(f) + case *ContinuationFrame: + return sc.processContinuation(f) + case *WindowUpdateFrame: + return sc.processWindowUpdate(f) + case *PingFrame: + return sc.processPing(f) + case *DataFrame: + return sc.processData(f) + case *RSTStreamFrame: + return sc.processResetStream(f) + case *PriorityFrame: + return sc.processPriority(f) + case *PushPromiseFrame: + // A client cannot push. Thus, servers MUST treat the receipt of a PUSH_PROMISE + // frame as a connection error (Section 5.4.1) of type PROTOCOL_ERROR. + return ConnectionError(ErrCodeProtocol) + default: + log.Printf("Ignoring frame: %v", f.Header()) + return nil + } +} + +func (sc *serverConn) processPing(f *PingFrame) error { + sc.serveG.check() + if f.Flags.Has(FlagSettingsAck) { + // 6.7 PING: " An endpoint MUST NOT respond to PING frames + // containing this flag." + return nil + } + if f.StreamID != 0 { + // "PING frames are not associated with any individual + // stream. If a PING frame is received with a stream + // identifier field value other than 0x0, the recipient MUST + // respond with a connection error (Section 5.4.1) of type + // PROTOCOL_ERROR." + return ConnectionError(ErrCodeProtocol) + } + sc.writeFrame(frameWriteMsg{write: writePingAck{f}}) + return nil +} + +func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error { + sc.serveG.check() + switch { + case f.StreamID != 0: // stream-level flow control + st := sc.streams[f.StreamID] + if st == nil { + // "WINDOW_UPDATE can be sent by a peer that has sent a + // frame bearing the END_STREAM flag. This means that a + // receiver could receive a WINDOW_UPDATE frame on a "half + // closed (remote)" or "closed" stream. A receiver MUST + // NOT treat this as an error, see Section 5.1." + return nil + } + if !st.flow.add(int32(f.Increment)) { + return StreamError{f.StreamID, ErrCodeFlowControl} + } + default: // connection-level flow control + if !sc.flow.add(int32(f.Increment)) { + return goAwayFlowError{} + } + } + sc.scheduleFrameWrite() + return nil +} + +func (sc *serverConn) processResetStream(f *RSTStreamFrame) error { + sc.serveG.check() + + state, st := sc.state(f.StreamID) + if state == stateIdle { + // 6.4 "RST_STREAM frames MUST NOT be sent for a + // stream in the "idle" state. If a RST_STREAM frame + // identifying an idle stream is received, the + // recipient MUST treat this as a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + return ConnectionError(ErrCodeProtocol) + } + if st != nil { + st.gotReset = true + sc.closeStream(st, StreamError{f.StreamID, f.ErrCode}) + } + return nil +} + +func (sc *serverConn) closeStream(st *stream, err error) { + sc.serveG.check() + if st.state == stateIdle || st.state == stateClosed { + panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) + } + st.state = stateClosed + sc.curOpenStreams-- + delete(sc.streams, st.id) + if p := st.body; p != nil { + p.Close(err) + } + st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc + sc.writeSched.forgetStream(st.id) +} + +func (sc *serverConn) processSettings(f *SettingsFrame) error { + sc.serveG.check() + if f.IsAck() { + sc.unackedSettings-- + if sc.unackedSettings < 0 { + // Why is the peer ACKing settings we never sent? + // The spec doesn't mention this case, but + // hang up on them anyway. + return ConnectionError(ErrCodeProtocol) + } + return nil + } + if err := f.ForeachSetting(sc.processSetting); err != nil { + return err + } + sc.needToSendSettingsAck = true + sc.scheduleFrameWrite() + return nil +} + +func (sc *serverConn) processSetting(s Setting) error { + sc.serveG.check() + if err := s.Valid(); err != nil { + return err + } + sc.vlogf("processing setting %v", s) + switch s.ID { + case SettingHeaderTableSize: + sc.headerTableSize = s.Val + sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) + case SettingEnablePush: + sc.pushEnabled = s.Val != 0 + case SettingMaxConcurrentStreams: + sc.clientMaxStreams = s.Val + case SettingInitialWindowSize: + return sc.processSettingInitialWindowSize(s.Val) + case SettingMaxFrameSize: + sc.writeSched.maxFrameSize = s.Val + case SettingMaxHeaderListSize: + sc.maxHeaderListSize = s.Val + default: + // Unknown setting: "An endpoint that receives a SETTINGS + // frame with any unknown or unsupported identifier MUST + // ignore that setting." + } + return nil +} + +func (sc *serverConn) processSettingInitialWindowSize(val uint32) error { + sc.serveG.check() + // Note: val already validated to be within range by + // processSetting's Valid call. + + // "A SETTINGS frame can alter the initial flow control window + // size for all current streams. When the value of + // SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST + // adjust the size of all stream flow control windows that it + // maintains by the difference between the new value and the + // old value." + old := sc.initialWindowSize + sc.initialWindowSize = int32(val) + growth := sc.initialWindowSize - old // may be negative + for _, st := range sc.streams { + if !st.flow.add(growth) { + // 6.9.2 Initial Flow Control Window Size + // "An endpoint MUST treat a change to + // SETTINGS_INITIAL_WINDOW_SIZE that causes any flow + // control window to exceed the maximum size as a + // connection error (Section 5.4.1) of type + // FLOW_CONTROL_ERROR." + return ConnectionError(ErrCodeFlowControl) + } + } + return nil +} + +func (sc *serverConn) processData(f *DataFrame) error { + sc.serveG.check() + // "If a DATA frame is received whose stream is not in "open" + // or "half closed (local)" state, the recipient MUST respond + // with a stream error (Section 5.4.2) of type STREAM_CLOSED." + id := f.Header().StreamID + st, ok := sc.streams[id] + if !ok || st.state != stateOpen { + // This includes sending a RST_STREAM if the stream is + // in stateHalfClosedLocal (which currently means that + // the http.Handler returned, so it's done reading & + // done writing). Try to stop the client from sending + // more DATA. + return StreamError{id, ErrCodeStreamClosed} + } + if st.body == nil { + panic("internal error: should have a body in this state") + } + data := f.Data() + + // Sender sending more than they'd declared? + if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { + st.body.Close(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) + return StreamError{id, ErrCodeStreamClosed} + } + if len(data) > 0 { + // Check whether the client has flow control quota. + if int(st.inflow.available()) < len(data) { + return StreamError{id, ErrCodeFlowControl} + } + st.inflow.take(int32(len(data))) + wrote, err := st.body.Write(data) + if err != nil { + return StreamError{id, ErrCodeStreamClosed} + } + if wrote != len(data) { + panic("internal error: bad Writer") + } + st.bodyBytes += int64(len(data)) + } + if f.StreamEnded() { + if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes { + st.body.Close(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes", + st.declBodyBytes, st.bodyBytes)) + } else { + st.body.Close(io.EOF) + } + st.state = stateHalfClosedRemote + } + return nil +} + +func (sc *serverConn) processHeaders(f *HeadersFrame) error { + sc.serveG.check() + id := f.Header().StreamID + if sc.inGoAway { + // Ignore. + return nil + } + // http://http2.github.io/http2-spec/#rfc.section.5.1.1 + if id%2 != 1 || id <= sc.maxStreamID || sc.req.stream != nil { + // Streams initiated by a client MUST use odd-numbered + // stream identifiers. [...] The identifier of a newly + // established stream MUST be numerically greater than all + // streams that the initiating endpoint has opened or + // reserved. [...] An endpoint that receives an unexpected + // stream identifier MUST respond with a connection error + // (Section 5.4.1) of type PROTOCOL_ERROR. + return ConnectionError(ErrCodeProtocol) + } + if id > sc.maxStreamID { + sc.maxStreamID = id + } + st := &stream{ + id: id, + state: stateOpen, + } + if f.StreamEnded() { + st.state = stateHalfClosedRemote + } + st.cw.Init() + + st.flow.conn = &sc.flow // link to conn-level counter + st.flow.add(sc.initialWindowSize) + st.inflow.conn = &sc.inflow // link to conn-level counter + st.inflow.add(initialWindowSize) // TODO: update this when we send a higher initial window size in the initial settings + + sc.streams[id] = st + if f.HasPriority() { + adjustStreamPriority(sc.streams, st.id, f.Priority) + } + sc.curOpenStreams++ + sc.req = requestParam{ + stream: st, + header: make(http.Header), + } + return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) +} + +func (sc *serverConn) processContinuation(f *ContinuationFrame) error { + sc.serveG.check() + st := sc.streams[f.Header().StreamID] + if st == nil || sc.curHeaderStreamID() != st.id { + return ConnectionError(ErrCodeProtocol) + } + return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) +} + +func (sc *serverConn) processHeaderBlockFragment(st *stream, frag []byte, end bool) error { + sc.serveG.check() + if _, err := sc.hpackDecoder.Write(frag); err != nil { + // TODO: convert to stream error I assume? + return err + } + if !end { + return nil + } + if err := sc.hpackDecoder.Close(); err != nil { + // TODO: convert to stream error I assume? + return err + } + defer sc.resetPendingRequest() + if sc.curOpenStreams > sc.advMaxStreams { + // "Endpoints MUST NOT exceed the limit set by their + // peer. An endpoint that receives a HEADERS frame + // that causes their advertised concurrent stream + // limit to be exceeded MUST treat this as a stream + // error (Section 5.4.2) of type PROTOCOL_ERROR or + // REFUSED_STREAM." + if sc.unackedSettings == 0 { + // They should know better. + return StreamError{st.id, ErrCodeProtocol} + } + // Assume it's a network race, where they just haven't + // received our last SETTINGS update. But actually + // this can't happen yet, because we don't yet provide + // a way for users to adjust server parameters at + // runtime. + return StreamError{st.id, ErrCodeRefusedStream} + } + + rw, req, err := sc.newWriterAndRequest() + if err != nil { + return err + } + st.body = req.Body.(*requestBody).pipe // may be nil + st.declBodyBytes = req.ContentLength + go sc.runHandler(rw, req) + return nil +} + +func (sc *serverConn) processPriority(f *PriorityFrame) error { + adjustStreamPriority(sc.streams, f.StreamID, f.PriorityParam) + return nil +} + +func adjustStreamPriority(streams map[uint32]*stream, streamID uint32, priority PriorityParam) { + st, ok := streams[streamID] + if !ok { + // TODO: not quite correct (this streamID might + // already exist in the dep tree, but be closed), but + // close enough for now. + return + } + st.weight = priority.Weight + parent := streams[priority.StreamDep] // might be nil + if parent == st { + // if client tries to set this stream to be the parent of itself + // ignore and keep going + return + } + + // section 5.3.3: If a stream is made dependent on one of its + // own dependencies, the formerly dependent stream is first + // moved to be dependent on the reprioritized stream's previous + // parent. The moved dependency retains its weight. + for piter := parent; piter != nil; piter = piter.parent { + if piter == st { + parent.parent = st.parent + break + } + } + st.parent = parent + if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) { + for _, openStream := range streams { + if openStream != st && openStream.parent == st.parent { + openStream.parent = st + } + } + } +} + +// resetPendingRequest zeros out all state related to a HEADERS frame +// and its zero or more CONTINUATION frames sent to start a new +// request. +func (sc *serverConn) resetPendingRequest() { + sc.serveG.check() + sc.req = requestParam{} +} + +func (sc *serverConn) newWriterAndRequest() (*responseWriter, *http.Request, error) { + sc.serveG.check() + rp := &sc.req + if rp.invalidHeader || rp.method == "" || rp.path == "" || + (rp.scheme != "https" && rp.scheme != "http") { + // See 8.1.2.6 Malformed Requests and Responses: + // + // Malformed requests or responses that are detected + // MUST be treated as a stream error (Section 5.4.2) + // of type PROTOCOL_ERROR." + // + // 8.1.2.3 Request Pseudo-Header Fields + // "All HTTP/2 requests MUST include exactly one valid + // value for the :method, :scheme, and :path + // pseudo-header fields" + return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol} + } + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState + } + authority := rp.authority + if authority == "" { + authority = rp.header.Get("Host") + } + needsContinue := rp.header.Get("Expect") == "100-continue" + if needsContinue { + rp.header.Del("Expect") + } + bodyOpen := rp.stream.state == stateOpen + body := &requestBody{ + conn: sc, + stream: rp.stream, + needsContinue: needsContinue, + } + // TODO: handle asterisk '*' requests + test + url, err := url.ParseRequestURI(rp.path) + if err != nil { + // TODO: find the right error code? + return nil, nil, StreamError{rp.stream.id, ErrCodeProtocol} + } + req := &http.Request{ + Method: rp.method, + URL: url, + RemoteAddr: sc.remoteAddrStr, + Header: rp.header, + RequestURI: rp.path, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + TLS: tlsState, + Host: authority, + Body: body, + } + if bodyOpen { + body.pipe = &pipe{ + b: buffer{buf: make([]byte, initialWindowSize)}, // TODO: share/remove XXX + } + body.pipe.c.L = &body.pipe.m + + if vv, ok := rp.header["Content-Length"]; ok { + req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + } else { + req.ContentLength = -1 + } + } + + rws := responseWriterStatePool.Get().(*responseWriterState) + bwSave := rws.bw + *rws = responseWriterState{} // zero all the fields + rws.conn = sc + rws.bw = bwSave + rws.bw.Reset(chunkWriter{rws}) + rws.stream = rp.stream + rws.req = req + rws.body = body + rws.frameWriteCh = make(chan error, 1) + + rw := &responseWriter{rws: rws} + return rw, req, nil +} + +// Run on its own goroutine. +func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request) { + defer rw.handlerDone() + // TODO: catch panics like net/http.Server + sc.handler.ServeHTTP(rw, req) +} + +// called from handler goroutines. +// h may be nil. +func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders, tempCh chan error) { + sc.serveG.checkNotOn() // NOT on + var errc chan error + if headerData.h != nil { + // If there's a header map (which we don't own), so we have to block on + // waiting for this frame to be written, so an http.Flush mid-handler + // writes out the correct value of keys, before a handler later potentially + // mutates it. + errc = tempCh + } + sc.writeFrameFromHandler(frameWriteMsg{ + write: headerData, + stream: st, + done: errc, + }) + if errc != nil { + select { + case <-errc: + // Ignore. Just for synchronization. + // Any error will be handled in the writing goroutine. + case <-sc.doneServing: + // Client has closed the connection. + } + } +} + +// called from handler goroutines. +func (sc *serverConn) write100ContinueHeaders(st *stream) { + sc.writeFrameFromHandler(frameWriteMsg{ + write: write100ContinueHeadersFrame{st.id}, + stream: st, + }) +} + +// A bodyReadMsg tells the server loop that the http.Handler read n +// bytes of the DATA from the client on the given stream. +type bodyReadMsg struct { + st *stream + n int +} + +// called from handler goroutines. +// Notes that the handler for the given stream ID read n bytes of its body +// and schedules flow control tokens to be sent. +func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int) { + sc.serveG.checkNotOn() // NOT on + sc.bodyReadCh <- bodyReadMsg{st, n} +} + +func (sc *serverConn) noteBodyRead(st *stream, n int) { + sc.serveG.check() + sc.sendWindowUpdate(nil, n) // conn-level + if st.state != stateHalfClosedRemote && st.state != stateClosed { + // Don't send this WINDOW_UPDATE if the stream is closed + // remotely. + sc.sendWindowUpdate(st, n) + } +} + +// st may be nil for conn-level +func (sc *serverConn) sendWindowUpdate(st *stream, n int) { + sc.serveG.check() + // "The legal range for the increment to the flow control + // window is 1 to 2^31-1 (2,147,483,647) octets." + // A Go Read call on 64-bit machines could in theory read + // a larger Read than this. Very unlikely, but we handle it here + // rather than elsewhere for now. + const maxUint31 = 1<<31 - 1 + for n >= maxUint31 { + sc.sendWindowUpdate32(st, maxUint31) + n -= maxUint31 + } + sc.sendWindowUpdate32(st, int32(n)) +} + +// st may be nil for conn-level +func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) { + sc.serveG.check() + if n == 0 { + return + } + if n < 0 { + panic("negative update") + } + var streamID uint32 + if st != nil { + streamID = st.id + } + sc.writeFrame(frameWriteMsg{ + write: writeWindowUpdate{streamID: streamID, n: uint32(n)}, + stream: st, + }) + var ok bool + if st == nil { + ok = sc.inflow.add(n) + } else { + ok = st.inflow.add(n) + } + if !ok { + panic("internal error; sent too many window updates without decrements?") + } +} + +type requestBody struct { + stream *stream + conn *serverConn + closed bool + pipe *pipe // non-nil if we have a HTTP entity message body + needsContinue bool // need to send a 100-continue +} + +func (b *requestBody) Close() error { + if b.pipe != nil { + b.pipe.Close(errClosedBody) + } + b.closed = true + return nil +} + +func (b *requestBody) Read(p []byte) (n int, err error) { + if b.needsContinue { + b.needsContinue = false + b.conn.write100ContinueHeaders(b.stream) + } + if b.pipe == nil { + return 0, io.EOF + } + n, err = b.pipe.Read(p) + if n > 0 { + b.conn.noteBodyReadFromHandler(b.stream, n) + } + return +} + +// responseWriter is the http.ResponseWriter implementation. It's +// intentionally small (1 pointer wide) to minimize garbage. The +// responseWriterState pointer inside is zeroed at the end of a +// request (in handlerDone) and calls on the responseWriter thereafter +// simply crash (caller's mistake), but the much larger responseWriterState +// and buffers are reused between multiple requests. +type responseWriter struct { + rws *responseWriterState +} + +// Optional http.ResponseWriter interfaces implemented. +var ( + _ http.CloseNotifier = (*responseWriter)(nil) + _ http.Flusher = (*responseWriter)(nil) + _ stringWriter = (*responseWriter)(nil) +) + +type responseWriterState struct { + // immutable within a request: + stream *stream + req *http.Request + body *requestBody // to close at end of request, if DATA frames didn't + conn *serverConn + + // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc + bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState} + + // mutated by http.Handler goroutine: + handlerHeader http.Header // nil until called + snapHeader http.Header // snapshot of handlerHeader at WriteHeader time + status int // status code passed to WriteHeader + wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet. + sentHeader bool // have we sent the header frame? + handlerDone bool // handler has finished + curWrite writeData + frameWriteCh chan error // re-used whenever we need to block on a frame being written + + closeNotifierMu sync.Mutex // guards closeNotifierCh + closeNotifierCh chan bool // nil until first used +} + +type chunkWriter struct{ rws *responseWriterState } + +func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } + +// writeChunk writes chunks from the bufio.Writer. But because +// bufio.Writer may bypass its chunking, sometimes p may be +// arbitrarily large. +// +// writeChunk is also responsible (on the first chunk) for sending the +// HEADER response. +func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) { + if !rws.wroteHeader { + rws.writeHeader(200) + } + if !rws.sentHeader { + rws.sentHeader = true + var ctype, clen string // implicit ones, if we can calculate it + if rws.handlerDone && rws.snapHeader.Get("Content-Length") == "" { + clen = strconv.Itoa(len(p)) + } + if rws.snapHeader.Get("Content-Type") == "" { + ctype = http.DetectContentType(p) + } + endStream := rws.handlerDone && len(p) == 0 + rws.conn.writeHeaders(rws.stream, &writeResHeaders{ + streamID: rws.stream.id, + httpResCode: rws.status, + h: rws.snapHeader, + endStream: endStream, + contentType: ctype, + contentLength: clen, + }, rws.frameWriteCh) + if endStream { + return 0, nil + } + } + if len(p) == 0 && !rws.handlerDone { + return 0, nil + } + curWrite := &rws.curWrite + curWrite.streamID = rws.stream.id + curWrite.p = p + curWrite.endStream = rws.handlerDone + if err := rws.conn.writeDataFromHandler(rws.stream, curWrite, rws.frameWriteCh); err != nil { + return 0, err + } + return len(p), nil +} + +func (w *responseWriter) Flush() { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.bw.Buffered() > 0 { + if err := rws.bw.Flush(); err != nil { + // Ignore the error. The frame writer already knows. + return + } + } else { + // The bufio.Writer won't call chunkWriter.Write + // (writeChunk with zero bytes, so we have to do it + // ourselves to force the HTTP response header and/or + // final DATA frame (with END_STREAM) to be sent. + rws.writeChunk(nil) + } +} + +func (w *responseWriter) CloseNotify() <-chan bool { + rws := w.rws + if rws == nil { + panic("CloseNotify called after Handler finished") + } + rws.closeNotifierMu.Lock() + ch := rws.closeNotifierCh + if ch == nil { + ch = make(chan bool, 1) + rws.closeNotifierCh = ch + go func() { + rws.stream.cw.Wait() // wait for close + ch <- true + }() + } + rws.closeNotifierMu.Unlock() + return ch +} + +func (w *responseWriter) Header() http.Header { + rws := w.rws + if rws == nil { + panic("Header called after Handler finished") + } + if rws.handlerHeader == nil { + rws.handlerHeader = make(http.Header) + } + return rws.handlerHeader +} + +func (w *responseWriter) WriteHeader(code int) { + rws := w.rws + if rws == nil { + panic("WriteHeader called after Handler finished") + } + rws.writeHeader(code) +} + +func (rws *responseWriterState) writeHeader(code int) { + if !rws.wroteHeader { + rws.wroteHeader = true + rws.status = code + if len(rws.handlerHeader) > 0 { + rws.snapHeader = cloneHeader(rws.handlerHeader) + } + } +} + +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 +} + +// The Life Of A Write is like this: +// +// * Handler calls w.Write or w.WriteString -> +// * -> rws.bw (*bufio.Writer) -> +// * (Handler migth call Flush) +// * -> chunkWriter{rws} +// * -> responseWriterState.writeChunk(p []byte) +// * -> responseWriterState.writeChunk (most of the magic; see comment there) +func (w *responseWriter) Write(p []byte) (n int, err error) { + return w.write(len(p), p, "") +} + +func (w *responseWriter) WriteString(s string) (n int, err error) { + return w.write(len(s), nil, s) +} + +// either dataB or dataS is non-zero. +func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) { + rws := w.rws + if rws == nil { + panic("Write called after Handler finished") + } + if !rws.wroteHeader { + w.WriteHeader(200) + } + if dataB != nil { + return rws.bw.Write(dataB) + } else { + return rws.bw.WriteString(dataS) + } +} + +func (w *responseWriter) handlerDone() { + rws := w.rws + if rws == nil { + panic("handlerDone called twice") + } + rws.handlerDone = true + w.Flush() + w.rws = nil + responseWriterStatePool.Put(rws) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/server_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/server_test.go new file mode 100644 index 0000000..1133223 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/server_test.go @@ -0,0 +1,2252 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "bytes" + "crypto/tls" + "errors" + "flag" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "os" + "reflect" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/bradfitz/http2/hpack" +) + +var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered") + +type serverTester struct { + cc net.Conn // client conn + t testing.TB + ts *httptest.Server + fr *Framer + logBuf *bytes.Buffer + logFilter []string // substrings to filter out + scMu sync.Mutex // guards sc + sc *serverConn + + // writing headers: + headerBuf bytes.Buffer + hpackEnc *hpack.Encoder + + // reading frames: + frc chan Frame + frErrc chan error + readTimer *time.Timer +} + +func init() { + testHookOnPanicMu = new(sync.Mutex) +} + +func resetHooks() { + testHookOnPanicMu.Lock() + testHookOnPanic = nil + testHookOnPanicMu.Unlock() +} + +type serverTesterOpt string + +var optOnlyServer = serverTesterOpt("only_server") + +func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester { + resetHooks() + + logBuf := new(bytes.Buffer) + ts := httptest.NewUnstartedServer(handler) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + // The h2-14 is temporary, until curl is updated. (as used by unit tests + // in Docker) + NextProtos: []string{NextProtoTLS, "h2-14"}, + } + + onlyServer := false + for _, opt := range opts { + switch v := opt.(type) { + case func(*tls.Config): + v(tlsConfig) + case func(*httptest.Server): + v(ts) + case serverTesterOpt: + onlyServer = (v == optOnlyServer) + default: + t.Fatalf("unknown newServerTester option type %T", v) + } + } + + ConfigureServer(ts.Config, &Server{}) + + st := &serverTester{ + t: t, + ts: ts, + logBuf: logBuf, + frc: make(chan Frame, 1), + frErrc: make(chan error, 1), + } + st.hpackEnc = hpack.NewEncoder(&st.headerBuf) + + var stderrv io.Writer = ioutil.Discard + if *stderrVerbose { + stderrv = os.Stderr + } + + ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config + ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv, twriter{t: t, st: st}, logBuf), "", log.LstdFlags) + ts.StartTLS() + + if VerboseLogs { + t.Logf("Running test server at: %s", ts.URL) + } + testHookGetServerConn = func(v *serverConn) { + st.scMu.Lock() + defer st.scMu.Unlock() + st.sc = v + st.sc.testHookCh = make(chan func()) + } + log.SetOutput(io.MultiWriter(stderrv, twriter{t: t, st: st})) + if !onlyServer { + cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig) + if err != nil { + t.Fatal(err) + } + st.cc = cc + st.fr = NewFramer(cc, cc) + } + + return st +} + +func (st *serverTester) closeConn() { + st.scMu.Lock() + defer st.scMu.Unlock() + st.sc.conn.Close() +} + +func (st *serverTester) addLogFilter(phrase string) { + st.logFilter = append(st.logFilter, phrase) +} + +func (st *serverTester) stream(id uint32) *stream { + ch := make(chan *stream, 1) + st.sc.testHookCh <- func() { + ch <- st.sc.streams[id] + } + return <-ch +} + +func (st *serverTester) streamState(id uint32) streamState { + ch := make(chan streamState, 1) + st.sc.testHookCh <- func() { + state, _ := st.sc.state(id) + ch <- state + } + return <-ch +} + +func (st *serverTester) Close() { + st.ts.Close() + if st.cc != nil { + st.cc.Close() + } + log.SetOutput(os.Stderr) +} + +// greet initiates the client's HTTP/2 connection into a state where +// frames may be sent. +func (st *serverTester) greet() { + st.writePreface() + st.writeInitialSettings() + st.wantSettings() + st.writeSettingsAck() + st.wantSettingsAck() +} + +func (st *serverTester) writePreface() { + n, err := st.cc.Write(clientPreface) + if err != nil { + st.t.Fatalf("Error writing client preface: %v", err) + } + if n != len(clientPreface) { + st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface)) + } +} + +func (st *serverTester) writeInitialSettings() { + if err := st.fr.WriteSettings(); err != nil { + st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err) + } +} + +func (st *serverTester) writeSettingsAck() { + if err := st.fr.WriteSettingsAck(); err != nil { + st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err) + } +} + +func (st *serverTester) writeHeaders(p HeadersFrameParam) { + if err := st.fr.WriteHeaders(p); err != nil { + st.t.Fatalf("Error writing HEADERS: %v", err) + } +} + +func (st *serverTester) encodeHeaderField(k, v string) { + err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) + if err != nil { + st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err) + } +} + +// encodeHeader encodes headers and returns their HPACK bytes. headers +// must contain an even number of key/value pairs. There may be +// multiple pairs for keys (e.g. "cookie"). The :method, :path, and +// :scheme headers default to GET, / and https. +func (st *serverTester) encodeHeader(headers ...string) []byte { + if len(headers)%2 == 1 { + panic("odd number of kv args") + } + + st.headerBuf.Reset() + + if len(headers) == 0 { + // Fast path, mostly for benchmarks, so test code doesn't pollute + // profiles when we're looking to improve server allocations. + st.encodeHeaderField(":method", "GET") + st.encodeHeaderField(":path", "/") + st.encodeHeaderField(":scheme", "https") + return st.headerBuf.Bytes() + } + + if len(headers) == 2 && headers[0] == ":method" { + // Another fast path for benchmarks. + st.encodeHeaderField(":method", headers[1]) + st.encodeHeaderField(":path", "/") + st.encodeHeaderField(":scheme", "https") + return st.headerBuf.Bytes() + } + + pseudoCount := map[string]int{} + keys := []string{":method", ":path", ":scheme"} + vals := map[string][]string{ + ":method": {"GET"}, + ":path": {"/"}, + ":scheme": {"https"}, + } + for len(headers) > 0 { + k, v := headers[0], headers[1] + headers = headers[2:] + if _, ok := vals[k]; !ok { + keys = append(keys, k) + } + if strings.HasPrefix(k, ":") { + pseudoCount[k]++ + if pseudoCount[k] == 1 { + vals[k] = []string{v} + } else { + // Allows testing of invalid headers w/ dup pseudo fields. + vals[k] = append(vals[k], v) + } + } else { + vals[k] = append(vals[k], v) + } + } + st.headerBuf.Reset() + for _, k := range keys { + for _, v := range vals[k] { + st.encodeHeaderField(k, v) + } + } + return st.headerBuf.Bytes() +} + +// bodylessReq1 writes a HEADERS frames with StreamID 1 and EndStream and EndHeaders set. +func (st *serverTester) bodylessReq1(headers ...string) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(headers...), + EndStream: true, + EndHeaders: true, + }) +} + +func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) { + if err := st.fr.WriteData(streamID, endStream, data); err != nil { + st.t.Fatalf("Error writing DATA: %v", err) + } +} + +func (st *serverTester) readFrame() (Frame, error) { + go func() { + fr, err := st.fr.ReadFrame() + if err != nil { + st.frErrc <- err + } else { + st.frc <- fr + } + }() + t := st.readTimer + if t == nil { + t = time.NewTimer(2 * time.Second) + st.readTimer = t + } + t.Reset(2 * time.Second) + defer t.Stop() + select { + case f := <-st.frc: + return f, nil + case err := <-st.frErrc: + return nil, err + case <-t.C: + return nil, errors.New("timeout waiting for frame") + } +} + +func (st *serverTester) wantHeaders() *HeadersFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a HEADERS frame: %v", err) + } + hf, ok := f.(*HeadersFrame) + if !ok { + st.t.Fatalf("got a %T; want *HeadersFrame", f) + } + return hf +} + +func (st *serverTester) wantContinuation() *ContinuationFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a CONTINUATION frame: %v", err) + } + cf, ok := f.(*ContinuationFrame) + if !ok { + st.t.Fatalf("got a %T; want *ContinuationFrame", f) + } + return cf +} + +func (st *serverTester) wantData() *DataFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a DATA frame: %v", err) + } + df, ok := f.(*DataFrame) + if !ok { + st.t.Fatalf("got a %T; want *DataFrame", f) + } + return df +} + +func (st *serverTester) wantSettings() *SettingsFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err) + } + sf, ok := f.(*SettingsFrame) + if !ok { + st.t.Fatalf("got a %T; want *SettingsFrame", f) + } + return sf +} + +func (st *serverTester) wantPing() *PingFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a PING frame: %v", err) + } + pf, ok := f.(*PingFrame) + if !ok { + st.t.Fatalf("got a %T; want *PingFrame", f) + } + return pf +} + +func (st *serverTester) wantGoAway() *GoAwayFrame { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a GOAWAY frame: %v", err) + } + gf, ok := f.(*GoAwayFrame) + if !ok { + st.t.Fatalf("got a %T; want *GoAwayFrame", f) + } + return gf +} + +func (st *serverTester) wantRSTStream(streamID uint32, errCode ErrCode) { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting an RSTStream frame: %v", err) + } + rs, ok := f.(*RSTStreamFrame) + if !ok { + st.t.Fatalf("got a %T; want *RSTStreamFrame", f) + } + if rs.FrameHeader.StreamID != streamID { + st.t.Fatalf("RSTStream StreamID = %d; want %d", rs.FrameHeader.StreamID, streamID) + } + if rs.ErrCode != errCode { + st.t.Fatalf("RSTStream ErrCode = %d (%s); want %d (%s)", rs.ErrCode, rs.ErrCode, errCode, errCode) + } +} + +func (st *serverTester) wantWindowUpdate(streamID, incr uint32) { + f, err := st.readFrame() + if err != nil { + st.t.Fatalf("Error while expecting a WINDOW_UPDATE frame: %v", err) + } + wu, ok := f.(*WindowUpdateFrame) + if !ok { + st.t.Fatalf("got a %T; want *WindowUpdateFrame", f) + } + if wu.FrameHeader.StreamID != streamID { + st.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID) + } + if wu.Increment != incr { + st.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr) + } +} + +func (st *serverTester) wantSettingsAck() { + f, err := st.readFrame() + if err != nil { + st.t.Fatal(err) + } + sf, ok := f.(*SettingsFrame) + if !ok { + st.t.Fatalf("Wanting a settings ACK, received a %T", f) + } + if !sf.Header().Flags.Has(FlagSettingsAck) { + st.t.Fatal("Settings Frame didn't have ACK set") + } + +} + +func TestServer(t *testing.T) { + gotReq := make(chan bool, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Foo", "Bar") + gotReq <- true + }) + defer st.Close() + + covers("3.5", ` + The server connection preface consists of a potentially empty + SETTINGS frame ([SETTINGS]) that MUST be the first frame the + server sends in the HTTP/2 connection. + `) + + st.writePreface() + st.writeInitialSettings() + st.wantSettings() + st.writeSettingsAck() + st.wantSettingsAck() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(), + EndStream: true, // no DATA frames + EndHeaders: true, + }) + + select { + case <-gotReq: + case <-time.After(2 * time.Second): + t.Error("timeout waiting for request") + } +} + +func TestServer_Request_Get(t *testing.T) { + testServerRequest(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader("foo-bar", "some-value"), + EndStream: true, // no DATA frames + EndHeaders: true, + }) + }, func(r *http.Request) { + if r.Method != "GET" { + t.Errorf("Method = %q; want GET", r.Method) + } + if r.URL.Path != "/" { + t.Errorf("URL.Path = %q; want /", r.URL.Path) + } + if r.ContentLength != 0 { + t.Errorf("ContentLength = %v; want 0", r.ContentLength) + } + if r.Close { + t.Error("Close = true; want false") + } + if !strings.Contains(r.RemoteAddr, ":") { + t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr) + } + if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 { + t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor) + } + wantHeader := http.Header{ + "Foo-Bar": []string{"some-value"}, + } + if !reflect.DeepEqual(r.Header, wantHeader) { + t.Errorf("Header = %#v; want %#v", r.Header, wantHeader) + } + if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 { + t.Errorf("Read = %d, %v; want 0, EOF", n, err) + } + }) +} + +func TestServer_Request_Get_PathSlashes(t *testing.T) { + testServerRequest(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":path", "/%2f/"), + EndStream: true, // no DATA frames + EndHeaders: true, + }) + }, func(r *http.Request) { + if r.RequestURI != "/%2f/" { + t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI) + } + if r.URL.Path != "///" { + t.Errorf("URL.Path = %q; want ///", r.URL.Path) + } + }) +} + +// TODO: add a test with EndStream=true on the HEADERS but setting a +// Content-Length anyway. Should we just omit it and force it to +// zero? + +func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) { + testServerRequest(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: true, + EndHeaders: true, + }) + }, func(r *http.Request) { + if r.Method != "POST" { + t.Errorf("Method = %q; want POST", r.Method) + } + if r.ContentLength != 0 { + t.Errorf("ContentLength = %v; want 0", r.ContentLength) + } + if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 { + t.Errorf("Read = %d, %v; want 0, EOF", n, err) + } + }) +} + +func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) { + testBodyContents(t, -1, "", func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, // to say DATA frames are coming + EndHeaders: true, + }) + st.writeData(1, true, nil) // just kidding. empty body. + }) +} + +func TestServer_Request_Post_Body_OneData(t *testing.T) { + const content = "Some content" + testBodyContents(t, -1, content, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, // to say DATA frames are coming + EndHeaders: true, + }) + st.writeData(1, true, []byte(content)) + }) +} + +func TestServer_Request_Post_Body_TwoData(t *testing.T) { + const content = "Some content" + testBodyContents(t, -1, content, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, // to say DATA frames are coming + EndHeaders: true, + }) + st.writeData(1, false, []byte(content[:5])) + st.writeData(1, true, []byte(content[5:])) + }) +} + +func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) { + const content = "Some content" + testBodyContents(t, int64(len(content)), content, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader( + ":method", "POST", + "content-length", strconv.Itoa(len(content)), + ), + EndStream: false, // to say DATA frames are coming + EndHeaders: true, + }) + st.writeData(1, true, []byte(content)) + }) +} + +func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) { + testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes", + func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader( + ":method", "POST", + "content-length", "3", + ), + EndStream: false, // to say DATA frames are coming + EndHeaders: true, + }) + st.writeData(1, true, []byte("12")) + }) +} + +func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) { + testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes", + func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader( + ":method", "POST", + "content-length", "4", + ), + EndStream: false, // to say DATA frames are coming + EndHeaders: true, + }) + st.writeData(1, true, []byte("12345")) + }) +} + +func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) { + testServerRequest(t, write, func(r *http.Request) { + if r.Method != "POST" { + t.Errorf("Method = %q; want POST", r.Method) + } + if r.ContentLength != wantContentLength { + t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength) + } + all, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + if string(all) != wantBody { + t.Errorf("Read = %q; want %q", all, wantBody) + } + if err := r.Body.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + }) +} + +func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) { + testServerRequest(t, write, func(r *http.Request) { + if r.Method != "POST" { + t.Errorf("Method = %q; want POST", r.Method) + } + if r.ContentLength != wantContentLength { + t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength) + } + all, err := ioutil.ReadAll(r.Body) + if err == nil { + t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.", + wantReadError, all) + } + if !strings.Contains(err.Error(), wantReadError) { + t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError) + } + if err := r.Body.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + }) +} + +// Using a Host header, instead of :authority +func TestServer_Request_Get_Host(t *testing.T) { + const host = "example.com" + testServerRequest(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader("host", host), + EndStream: true, + EndHeaders: true, + }) + }, func(r *http.Request) { + if r.Host != host { + t.Errorf("Host = %q; want %q", r.Host, host) + } + }) +} + +// Using an :authority pseudo-header, instead of Host +func TestServer_Request_Get_Authority(t *testing.T) { + const host = "example.com" + testServerRequest(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":authority", host), + EndStream: true, + EndHeaders: true, + }) + }, func(r *http.Request) { + if r.Host != host { + t.Errorf("Host = %q; want %q", r.Host, host) + } + }) +} + +func TestServer_Request_WithContinuation(t *testing.T) { + wantHeader := http.Header{ + "Foo-One": []string{"value-one"}, + "Foo-Two": []string{"value-two"}, + "Foo-Three": []string{"value-three"}, + } + testServerRequest(t, func(st *serverTester) { + fullHeaders := st.encodeHeader( + "foo-one", "value-one", + "foo-two", "value-two", + "foo-three", "value-three", + ) + remain := fullHeaders + chunks := 0 + for len(remain) > 0 { + const maxChunkSize = 5 + chunk := remain + if len(chunk) > maxChunkSize { + chunk = chunk[:maxChunkSize] + } + remain = remain[len(chunk):] + + if chunks == 0 { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: chunk, + EndStream: true, // no DATA frames + EndHeaders: false, // we'll have continuation frames + }) + } else { + err := st.fr.WriteContinuation(1, len(remain) == 0, chunk) + if err != nil { + t.Fatal(err) + } + } + chunks++ + } + if chunks < 2 { + t.Fatal("too few chunks") + } + }, func(r *http.Request) { + if !reflect.DeepEqual(r.Header, wantHeader) { + t.Errorf("Header = %#v; want %#v", r.Header, wantHeader) + } + }) +} + +// Concatenated cookie headers. ("8.1.2.5 Compressing the Cookie Header Field") +func TestServer_Request_CookieConcat(t *testing.T) { + const host = "example.com" + testServerRequest(t, func(st *serverTester) { + st.bodylessReq1( + ":authority", host, + "cookie", "a=b", + "cookie", "c=d", + "cookie", "e=f", + ) + }, func(r *http.Request) { + const want = "a=b; c=d; e=f" + if got := r.Header.Get("Cookie"); got != want { + t.Errorf("Cookie = %q; want %q", got, want) + } + }) +} + +func TestServer_Request_Reject_CapitalHeader(t *testing.T) { + testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") }) +} + +func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) { + testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") }) +} + +func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) { + // 8.1.2.3 Request Pseudo-Header Fields + // "All HTTP/2 requests MUST include exactly one valid value" ... + testRejectRequest(t, func(st *serverTester) { + st.addLogFilter("duplicate pseudo-header") + st.bodylessReq1(":method", "GET", ":method", "POST") + }) +} + +func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) { + // 8.1.2.3 Request Pseudo-Header Fields + // "All pseudo-header fields MUST appear in the header block + // before regular header fields. Any request or response that + // contains a pseudo-header field that appears in a header + // block after a regular header field MUST be treated as + // malformed (Section 8.1.2.6)." + testRejectRequest(t, func(st *serverTester) { + st.addLogFilter("pseudo-header after regular header") + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"}) + enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"}) + enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"}) + enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"}) + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: buf.Bytes(), + EndStream: true, + EndHeaders: true, + }) + }) +} + +func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) { + testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") }) +} + +func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) { + testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") }) +} + +func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) { + testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") }) +} + +func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) { + testRejectRequest(t, func(st *serverTester) { + st.addLogFilter(`invalid pseudo-header ":unknown_thing"`) + st.bodylessReq1(":unknown_thing", "") + }) +} + +func testRejectRequest(t *testing.T, send func(*serverTester)) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + t.Fatal("server request made it to handler; should've been rejected") + }) + defer st.Close() + + st.greet() + send(st) + st.wantRSTStream(1, ErrCodeProtocol) +} + +func TestServer_Ping(t *testing.T) { + st := newServerTester(t, nil) + defer st.Close() + st.greet() + + // Server should ignore this one, since it has ACK set. + ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128} + if err := st.fr.WritePing(true, ackPingData); err != nil { + t.Fatal(err) + } + + // But the server should reply to this one, since ACK is false. + pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + if err := st.fr.WritePing(false, pingData); err != nil { + t.Fatal(err) + } + + pf := st.wantPing() + if !pf.Flags.Has(FlagPingAck) { + t.Error("response ping doesn't have ACK set") + } + if pf.Data != pingData { + t.Errorf("response ping has data %q; want %q", pf.Data, pingData) + } +} + +func TestServer_RejectsLargeFrames(t *testing.T) { + st := newServerTester(t, nil) + defer st.Close() + st.greet() + + // Write too large of a frame (too large by one byte) + // We ignore the return value because it's expected that the server + // will only read the first 9 bytes (the headre) and then disconnect. + st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1)) + + gf := st.wantGoAway() + if gf.ErrCode != ErrCodeFrameSize { + t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFrameSize) + } +} + +func TestServer_Handler_Sends_WindowUpdate(t *testing.T) { + puppet := newHandlerPuppet() + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + puppet.act(w, r) + }) + defer st.Close() + defer puppet.done() + + st.greet() + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, // data coming + EndHeaders: true, + }) + st.writeData(1, false, []byte("abcdef")) + puppet.do(readBodyHandler(t, "abc")) + st.wantWindowUpdate(0, 3) + st.wantWindowUpdate(1, 3) + + puppet.do(readBodyHandler(t, "def")) + st.wantWindowUpdate(0, 3) + st.wantWindowUpdate(1, 3) + + st.writeData(1, true, []byte("ghijkl")) // END_STREAM here + puppet.do(readBodyHandler(t, "ghi")) + puppet.do(readBodyHandler(t, "jkl")) + st.wantWindowUpdate(0, 3) + st.wantWindowUpdate(0, 3) // no more stream-level, since END_STREAM +} + +func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) { + st := newServerTester(t, nil) + defer st.Close() + st.greet() + if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil { + t.Fatal(err) + } + gf := st.wantGoAway() + if gf.ErrCode != ErrCodeFlowControl { + t.Errorf("GOAWAY err = %v; want %v", gf.ErrCode, ErrCodeFlowControl) + } + if gf.LastStreamID != 0 { + t.Errorf("GOAWAY last stream ID = %v; want %v", gf.LastStreamID, 0) + } +} + +func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) { + inHandler := make(chan bool) + blockHandler := make(chan bool) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + inHandler <- true + <-blockHandler + }) + defer st.Close() + defer close(blockHandler) + st.greet() + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, // keep it open + EndHeaders: true, + }) + <-inHandler + // Send a bogus window update: + if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil { + t.Fatal(err) + } + st.wantRSTStream(1, ErrCodeFlowControl) +} + +// testServerPostUnblock sends a hanging POST with unsent data to handler, +// then runs fn once in the handler, and verifies that the error returned from +// handler is acceptable. It fails if takes over 5 seconds for handler to exit. +func testServerPostUnblock(t *testing.T, + handler func(http.ResponseWriter, *http.Request) error, + fn func(*serverTester), + checkErr func(error), + otherHeaders ...string) { + inHandler := make(chan bool) + errc := make(chan error, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + inHandler <- true + errc <- handler(w, r) + }) + st.greet() + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...), + EndStream: false, // keep it open + EndHeaders: true, + }) + <-inHandler + fn(st) + select { + case err := <-errc: + if checkErr != nil { + checkErr(err) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for Handler to return") + } + st.Close() +} + +func TestServer_RSTStream_Unblocks_Read(t *testing.T) { + testServerPostUnblock(t, + func(w http.ResponseWriter, r *http.Request) (err error) { + _, err = r.Body.Read(make([]byte, 1)) + return + }, + func(st *serverTester) { + if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { + t.Fatal(err) + } + }, + func(err error) { + if err == nil { + t.Error("unexpected nil error from Request.Body.Read") + } + }, + ) +} + +func TestServer_DeadConn_Unblocks_Read(t *testing.T) { + testServerPostUnblock(t, + func(w http.ResponseWriter, r *http.Request) (err error) { + _, err = r.Body.Read(make([]byte, 1)) + return + }, + func(st *serverTester) { st.cc.Close() }, + func(err error) { + if err == nil { + t.Error("unexpected nil error from Request.Body.Read") + } + }, + ) +} + +var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error { + <-w.(http.CloseNotifier).CloseNotify() + return nil +} + +func TestServer_CloseNotify_After_RSTStream(t *testing.T) { + testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { + if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { + t.Fatal(err) + } + }, nil) +} + +func TestServer_CloseNotify_After_ConnClose(t *testing.T) { + testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil) +} + +// that CloseNotify unblocks after a stream error due to the client's +// problem that's unrelated to them explicitly canceling it (which is +// TestServer_CloseNotify_After_RSTStream above) +func TestServer_CloseNotify_After_StreamError(t *testing.T) { + testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { + // data longer than declared Content-Length => stream error + st.writeData(1, true, []byte("1234")) + }, nil, "content-length", "3") +} + +func TestServer_StateTransitions(t *testing.T) { + var st *serverTester + inHandler := make(chan bool) + writeData := make(chan bool) + leaveHandler := make(chan bool) + st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + inHandler <- true + if st.stream(1) == nil { + t.Errorf("nil stream 1 in handler") + } + if got, want := st.streamState(1), stateOpen; got != want { + t.Errorf("in handler, state is %v; want %v", got, want) + } + writeData <- true + if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF { + t.Errorf("body read = %d, %v; want 0, EOF", n, err) + } + if got, want := st.streamState(1), stateHalfClosedRemote; got != want { + t.Errorf("in handler, state is %v; want %v", got, want) + } + + <-leaveHandler + }) + st.greet() + if st.stream(1) != nil { + t.Fatal("stream 1 should be empty") + } + if got := st.streamState(1); got != stateIdle { + t.Fatalf("stream 1 should be idle; got %v", got) + } + + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, // keep it open + EndHeaders: true, + }) + <-inHandler + <-writeData + st.writeData(1, true, nil) + + leaveHandler <- true + hf := st.wantHeaders() + if !hf.StreamEnded() { + t.Fatal("expected END_STREAM flag") + } + + if got, want := st.streamState(1), stateClosed; got != want { + t.Errorf("at end, state is %v; want %v", got, want) + } + if st.stream(1) != nil { + t.Fatal("at end, stream 1 should be gone") + } +} + +// test HEADERS w/o EndHeaders + another HEADERS (should get rejected) +func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) { + testServerRejects(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: false, + }) + st.writeHeaders(HeadersFrameParam{ // Not a continuation. + StreamID: 3, // different stream. + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + }) +} + +// test HEADERS w/o EndHeaders + PING (should get rejected) +func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) { + testServerRejects(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: false, + }) + if err := st.fr.WritePing(false, [8]byte{}); err != nil { + t.Fatal(err) + } + }) +} + +// test HEADERS w/ EndHeaders + a continuation HEADERS (should get rejected) +func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) { + testServerRejects(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + st.wantHeaders() + if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil { + t.Fatal(err) + } + }) +} + +// test HEADERS w/o EndHeaders + a continuation HEADERS on wrong stream ID +func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) { + testServerRejects(t, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: false, + }) + if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil { + t.Fatal(err) + } + }) +} + +// No HEADERS on stream 0. +func TestServer_Rejects_Headers0(t *testing.T) { + testServerRejects(t, func(st *serverTester) { + st.fr.AllowIllegalWrites = true + st.writeHeaders(HeadersFrameParam{ + StreamID: 0, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + }) +} + +// No CONTINUATION on stream 0. +func TestServer_Rejects_Continuation0(t *testing.T) { + testServerRejects(t, func(st *serverTester) { + st.fr.AllowIllegalWrites = true + if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil { + t.Fatal(err) + } + }) +} + +func TestServer_Rejects_PushPromise(t *testing.T) { + testServerRejects(t, func(st *serverTester) { + pp := PushPromiseParam{ + StreamID: 1, + PromiseID: 3, + } + if err := st.fr.WritePushPromise(pp); err != nil { + t.Fatal(err) + } + }) +} + +// testServerRejects tests that the server hangs up with a GOAWAY +// frame and a server close after the client does something +// deserving a CONNECTION_ERROR. +func testServerRejects(t *testing.T, writeReq func(*serverTester)) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}) + st.addLogFilter("connection error: PROTOCOL_ERROR") + defer st.Close() + st.greet() + writeReq(st) + + st.wantGoAway() + errc := make(chan error, 1) + go func() { + fr, err := st.fr.ReadFrame() + if err == nil { + err = fmt.Errorf("got frame of type %T", fr) + } + errc <- err + }() + select { + case err := <-errc: + if err != io.EOF { + t.Errorf("ReadFrame = %v; want io.EOF", err) + } + case <-time.After(2 * time.Second): + t.Error("timeout waiting for disconnect") + } +} + +// testServerRequest sets up an idle HTTP/2 connection and lets you +// write a single request with writeReq, and then verify that the +// *http.Request is built correctly in checkReq. +func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) { + gotReq := make(chan bool, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if r.Body == nil { + t.Fatal("nil Body") + } + checkReq(r) + gotReq <- true + }) + defer st.Close() + + st.greet() + writeReq(st) + + select { + case <-gotReq: + case <-time.After(2 * time.Second): + t.Error("timeout waiting for request") + } +} + +func getSlash(st *serverTester) { st.bodylessReq1() } + +func TestServer_Response_NoData(t *testing.T) { + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + // Nothing. + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if !hf.StreamEnded() { + t.Fatal("want END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + }) +} + +func TestServer_Response_NoData_Header_FooBar(t *testing.T) { + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Foo-Bar", "some-value") + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if !hf.StreamEnded() { + t.Fatal("want END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"foo-bar", "some-value"}, + {"content-type", "text/plain; charset=utf-8"}, + {"content-length", "0"}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + }) +} + +func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) { + const msg = "this is HTML." + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "foo/bar") + io.WriteString(w, msg) + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("don't want END_STREAM, expecting data") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"content-type", "foo/bar"}, + {"content-length", strconv.Itoa(len(msg))}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + df := st.wantData() + if !df.StreamEnded() { + t.Error("expected DATA to have END_STREAM flag") + } + if got := string(df.Data()); got != msg { + t.Errorf("got DATA %q; want %q", got, msg) + } + }) +} + +func TestServer_Response_TransferEncoding_chunked(t *testing.T) { + const msg = "hi" + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Transfer-Encoding", "chunked") // should be stripped + io.WriteString(w, msg) + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"content-type", "text/plain; charset=utf-8"}, + {"content-length", strconv.Itoa(len(msg))}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + }) +} + +// Header accessed only after the initial write. +func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) { + const msg = "this is HTML." + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + io.WriteString(w, msg) + w.Header().Set("foo", "should be ignored") + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"content-type", "text/html; charset=utf-8"}, + {"content-length", strconv.Itoa(len(msg))}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + }) +} + +// Header accessed before the initial write and later mutated. +func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) { + const msg = "this is HTML." + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.Header().Set("foo", "proper value") + io.WriteString(w, msg) + w.Header().Set("foo", "should be ignored") + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"foo", "proper value"}, + {"content-type", "text/html; charset=utf-8"}, + {"content-length", strconv.Itoa(len(msg))}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + }) +} + +func TestServer_Response_Data_SniffLenType(t *testing.T) { + const msg = "this is HTML." + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + io.WriteString(w, msg) + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("don't want END_STREAM, expecting data") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"content-type", "text/html; charset=utf-8"}, + {"content-length", strconv.Itoa(len(msg))}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + df := st.wantData() + if !df.StreamEnded() { + t.Error("expected DATA to have END_STREAM flag") + } + if got := string(df.Data()); got != msg { + t.Errorf("got DATA %q; want %q", got, msg) + } + }) +} + +func TestServer_Response_Header_Flush_MidWrite(t *testing.T) { + const msg = "this is HTML" + const msg2 = ", and this is the next chunk" + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + io.WriteString(w, msg) + w.(http.Flusher).Flush() + io.WriteString(w, msg2) + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"content-type", "text/html; charset=utf-8"}, // sniffed + // and no content-length + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + { + df := st.wantData() + if df.StreamEnded() { + t.Error("unexpected END_STREAM flag") + } + if got := string(df.Data()); got != msg { + t.Errorf("got DATA %q; want %q", got, msg) + } + } + { + df := st.wantData() + if !df.StreamEnded() { + t.Error("wanted END_STREAM flag on last data chunk") + } + if got := string(df.Data()); got != msg2 { + t.Errorf("got DATA %q; want %q", got, msg2) + } + } + }) +} + +func TestServer_Response_LargeWrite(t *testing.T) { + const size = 1 << 20 + const maxFrameSize = 16 << 10 + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + n, err := w.Write(bytes.Repeat([]byte("a"), size)) + if err != nil { + return fmt.Errorf("Write error: %v", err) + } + if n != size { + return fmt.Errorf("wrong size %d from Write", n) + } + return nil + }, func(st *serverTester) { + if err := st.fr.WriteSettings( + Setting{SettingInitialWindowSize, 0}, + Setting{SettingMaxFrameSize, maxFrameSize}, + ); err != nil { + t.Fatal(err) + } + st.wantSettingsAck() + + getSlash(st) // make the single request + + // Give the handler quota to write: + if err := st.fr.WriteWindowUpdate(1, size); err != nil { + t.Fatal(err) + } + // Give the handler quota to write to connection-level + // window as well + if err := st.fr.WriteWindowUpdate(0, size); err != nil { + t.Fatal(err) + } + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "200"}, + {"content-type", "text/plain; charset=utf-8"}, // sniffed + // and no content-length + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + var bytes, frames int + for { + df := st.wantData() + bytes += len(df.Data()) + frames++ + for _, b := range df.Data() { + if b != 'a' { + t.Fatal("non-'a' byte seen in DATA") + } + } + if df.StreamEnded() { + break + } + } + if bytes != size { + t.Errorf("Got %d bytes; want %d", bytes, size) + } + if want := int(size / maxFrameSize); frames < want || frames > want*2 { + t.Errorf("Got %d frames; want %d", frames, size) + } + }) +} + +// Test that the handler can't write more than the client allows +func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) { + const size = 1 << 20 + const maxFrameSize = 16 << 10 + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.(http.Flusher).Flush() + n, err := w.Write(bytes.Repeat([]byte("a"), size)) + if err != nil { + return fmt.Errorf("Write error: %v", err) + } + if n != size { + return fmt.Errorf("wrong size %d from Write", n) + } + return nil + }, func(st *serverTester) { + // Set the window size to something explicit for this test. + // It's also how much initial data we expect. + const initWindowSize = 123 + if err := st.fr.WriteSettings( + Setting{SettingInitialWindowSize, initWindowSize}, + Setting{SettingMaxFrameSize, maxFrameSize}, + ); err != nil { + t.Fatal(err) + } + st.wantSettingsAck() + + getSlash(st) // make the single request + defer func() { st.fr.WriteRSTStream(1, ErrCodeCancel) }() + + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + + df := st.wantData() + if got := len(df.Data()); got != initWindowSize { + t.Fatalf("Initial window size = %d but got DATA with %d bytes", initWindowSize, got) + } + + for _, quota := range []int{1, 13, 127} { + if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil { + t.Fatal(err) + } + df := st.wantData() + if int(quota) != len(df.Data()) { + t.Fatalf("read %d bytes after giving %d quota", len(df.Data()), quota) + } + } + + if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { + t.Fatal(err) + } + }) +} + +// Test that the handler blocked in a Write is unblocked if the server sends a RST_STREAM. +func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) { + const size = 1 << 20 + const maxFrameSize = 16 << 10 + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.(http.Flusher).Flush() + errc := make(chan error, 1) + go func() { + _, err := w.Write(bytes.Repeat([]byte("a"), size)) + errc <- err + }() + select { + case err := <-errc: + if err == nil { + return errors.New("unexpected nil error from Write in handler") + } + return nil + case <-time.After(2 * time.Second): + return errors.New("timeout waiting for Write in handler") + } + }, func(st *serverTester) { + if err := st.fr.WriteSettings( + Setting{SettingInitialWindowSize, 0}, + Setting{SettingMaxFrameSize, maxFrameSize}, + ); err != nil { + t.Fatal(err) + } + st.wantSettingsAck() + + getSlash(st) // make the single request + + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + + if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil { + t.Fatal(err) + } + }) +} + +func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) { + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + w.(http.Flusher).Flush() + // Nothing; send empty DATA + return nil + }, func(st *serverTester) { + // Handler gets no data quota: + if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil { + t.Fatal(err) + } + st.wantSettingsAck() + + getSlash(st) // make the single request + + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + + df := st.wantData() + if got := len(df.Data()); got != 0 { + t.Fatalf("unexpected %d DATA bytes; want 0", got) + } + if !df.StreamEnded() { + t.Fatal("DATA didn't have END_STREAM") + } + }) +} + +func TestServer_Response_Automatic100Continue(t *testing.T) { + const msg = "foo" + const reply = "bar" + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + if v := r.Header.Get("Expect"); v != "" { + t.Errorf("Expect header = %q; want empty", v) + } + buf := make([]byte, len(msg)) + // This read should trigger the 100-continue being sent. + if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg { + return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg) + } + _, err := io.WriteString(w, reply) + return err + }, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-continue"), + EndStream: false, + EndHeaders: true, + }) + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth := decodeHeader(t, hf.HeaderBlockFragment()) + wanth := [][2]string{ + {":status", "100"}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Fatalf("Got headers %v; want %v", goth, wanth) + } + + // Okay, they sent status 100, so we can send our + // gigantic and/or sensitive "foo" payload now. + st.writeData(1, true, []byte(msg)) + + st.wantWindowUpdate(0, uint32(len(msg))) + + hf = st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("expected data to follow") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + goth = decodeHeader(t, hf.HeaderBlockFragment()) + wanth = [][2]string{ + {":status", "200"}, + {"content-type", "text/plain; charset=utf-8"}, + {"content-length", strconv.Itoa(len(reply))}, + } + if !reflect.DeepEqual(goth, wanth) { + t.Errorf("Got headers %v; want %v", goth, wanth) + } + + df := st.wantData() + if string(df.Data()) != reply { + t.Errorf("Client read %q; want %q", df.Data(), reply) + } + if !df.StreamEnded() { + t.Errorf("expect data stream end") + } + }) +} + +func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) { + errc := make(chan error, 1) + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + p := []byte("some data.\n") + for { + _, err := w.Write(p) + if err != nil { + errc <- err + return nil + } + } + }, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: false, + EndHeaders: true, + }) + hf := st.wantHeaders() + if hf.StreamEnded() { + t.Fatal("unexpected END_STREAM flag") + } + if !hf.HeadersEnded() { + t.Fatal("want END_HEADERS flag") + } + // Close the connection and wait for the handler to (hopefully) notice. + st.cc.Close() + select { + case <-errc: + case <-time.After(5 * time.Second): + t.Error("timeout") + } + }) +} + +func TestServer_Rejects_Too_Many_Streams(t *testing.T) { + const testPath = "/some/path" + + inHandler := make(chan uint32) + leaveHandler := make(chan bool) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + id := w.(*responseWriter).rws.stream.id + inHandler <- id + if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath { + t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath) + } + <-leaveHandler + }) + defer st.Close() + st.greet() + nextStreamID := uint32(1) + streamID := func() uint32 { + defer func() { nextStreamID += 2 }() + return nextStreamID + } + sendReq := func(id uint32, headers ...string) { + st.writeHeaders(HeadersFrameParam{ + StreamID: id, + BlockFragment: st.encodeHeader(headers...), + EndStream: true, + EndHeaders: true, + }) + } + for i := 0; i < defaultMaxStreams; i++ { + sendReq(streamID()) + <-inHandler + } + defer func() { + for i := 0; i < defaultMaxStreams; i++ { + leaveHandler <- true + } + }() + + // And this one should cross the limit: + // (It's also sent as a CONTINUATION, to verify we still track the decoder context, + // even if we're rejecting it) + rejectID := streamID() + headerBlock := st.encodeHeader(":path", testPath) + frag1, frag2 := headerBlock[:3], headerBlock[3:] + st.writeHeaders(HeadersFrameParam{ + StreamID: rejectID, + BlockFragment: frag1, + EndStream: true, + EndHeaders: false, // CONTINUATION coming + }) + if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil { + t.Fatal(err) + } + st.wantRSTStream(rejectID, ErrCodeProtocol) + + // But let a handler finish: + leaveHandler <- true + st.wantHeaders() + + // And now another stream should be able to start: + goodID := streamID() + sendReq(goodID, ":path", testPath) + select { + case got := <-inHandler: + if got != goodID { + t.Errorf("Got stream %d; want %d", got, goodID) + } + case <-time.After(3 * time.Second): + t.Error("timeout waiting for handler") + } +} + +// So many response headers that the server needs to use CONTINUATION frames: +func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) { + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + h := w.Header() + for i := 0; i < 5000; i++ { + h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i)) + } + return nil + }, func(st *serverTester) { + getSlash(st) + hf := st.wantHeaders() + if hf.HeadersEnded() { + t.Fatal("got unwanted END_HEADERS flag") + } + n := 0 + for { + n++ + cf := st.wantContinuation() + if cf.HeadersEnded() { + break + } + } + if n < 5 { + t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n) + } + }) +} + +// This previously crashed (reported by Mathieu Lonjaret as observed +// while using Camlistore) because we got a DATA frame from the client +// after the handler exited and our logic at the time was wrong, +// keeping a stream in the map in stateClosed, which tickled an +// invariant check later when we tried to remove that stream (via +// defer sc.closeAllStreamsOnConnClose) when the serverConn serve loop +// ended. +func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) { + testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error { + // nothing + return nil + }, func(st *serverTester) { + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, + BlockFragment: st.encodeHeader(), + EndStream: false, // DATA is coming + EndHeaders: true, + }) + hf := st.wantHeaders() + if !hf.HeadersEnded() || !hf.StreamEnded() { + t.Fatalf("want END_HEADERS+END_STREAM, got %v", hf) + } + + // Sent when the a Handler closes while a client has + // indicated it's still sending DATA: + st.wantRSTStream(1, ErrCodeCancel) + + // Now the handler has ended, so it's ended its + // stream, but the client hasn't closed its side + // (stateClosedLocal). So send more data and verify + // it doesn't crash with an internal invariant panic, like + // it did before. + st.writeData(1, true, []byte("foo")) + + // Sent after a peer sends data anyway (admittedly the + // previous RST_STREAM might've still been in-flight), + // but they'll get the more friendly 'cancel' code + // first. + st.wantRSTStream(1, ErrCodeStreamClosed) + + // Set up a bunch of machinery to record the panic we saw + // previously. + var ( + panMu sync.Mutex + panicVal interface{} + ) + + testHookOnPanicMu.Lock() + testHookOnPanic = func(sc *serverConn, pv interface{}) bool { + panMu.Lock() + panicVal = pv + panMu.Unlock() + return true + } + testHookOnPanicMu.Unlock() + + // Now force the serve loop to end, via closing the connection. + st.cc.Close() + select { + case <-st.sc.doneServing: + // Loop has exited. + panMu.Lock() + got := panicVal + panMu.Unlock() + if got != nil { + t.Errorf("Got panic: %v", got) + } + case <-time.After(5 * time.Second): + t.Error("timeout") + } + }) +} + +func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) } +func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) } + +func testRejectTLS(t *testing.T, max uint16) { + st := newServerTester(t, nil, func(c *tls.Config) { + c.MaxVersion = max + }) + defer st.Close() + gf := st.wantGoAway() + if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want { + t.Errorf("Got error code %v; want %v", got, want) + } +} + +func TestServer_Rejects_TLSBadCipher(t *testing.T) { + st := newServerTester(t, nil, func(c *tls.Config) { + // Only list bad ones: + c.CipherSuites = []uint16{ + tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + } + }) + defer st.Close() + gf := st.wantGoAway() + if got, want := gf.ErrCode, ErrCodeInadequateSecurity; got != want { + t.Errorf("Got error code %v; want %v", got, want) + } +} + +func TestServer_Advertises_Common_Cipher(t *testing.T) { + const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + st := newServerTester(t, nil, func(c *tls.Config) { + // Have the client only support the one required by the spec. + c.CipherSuites = []uint16{requiredSuite} + }, func(ts *httptest.Server) { + var srv *http.Server = ts.Config + // Have the server configured with one specific cipher suite + // which is banned. This tests that ConfigureServer ends up + // adding the good one to this list. + srv.TLSConfig = &tls.Config{ + CipherSuites: []uint16{tls.TLS_RSA_WITH_AES_128_CBC_SHA}, // just a banned one + } + }) + defer st.Close() + st.greet() +} + +// TODO: move this onto *serverTester, and re-use the same hpack +// decoding context throughout. We're just getting lucky here with +// creating a new decoder each time. +func decodeHeader(t *testing.T, headerBlock []byte) (pairs [][2]string) { + d := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) { + pairs = append(pairs, [2]string{f.Name, f.Value}) + }) + if _, err := d.Write(headerBlock); err != nil { + t.Fatalf("hpack decoding error: %v", err) + } + if err := d.Close(); err != nil { + t.Fatalf("hpack decoding error: %v", err) + } + return +} + +// testServerResponse sets up an idle HTTP/2 connection and lets you +// write a single request with writeReq, and then reply to it in some way with the provided handler, +// and then verify the output with the serverTester again (assuming the handler returns nil) +func testServerResponse(t testing.TB, + handler func(http.ResponseWriter, *http.Request) error, + client func(*serverTester), +) { + errc := make(chan error, 1) + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + if r.Body == nil { + t.Fatal("nil Body") + } + errc <- handler(w, r) + }) + defer st.Close() + + donec := make(chan bool) + go func() { + defer close(donec) + st.greet() + client(st) + }() + + select { + case <-donec: + return + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } + + select { + case err := <-errc: + if err != nil { + t.Fatalf("Error in handler: %v", err) + } + case <-time.After(2 * time.Second): + t.Error("timeout waiting for handler to finish") + } +} + +// readBodyHandler returns an http Handler func that reads len(want) +// bytes from r.Body and fails t if the contents read were not +// the value of want. +func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + buf := make([]byte, len(want)) + _, err := io.ReadFull(r.Body, buf) + if err != nil { + t.Error(err) + return + } + if string(buf) != want { + t.Errorf("read %q; want %q", buf, want) + } + } +} + +// TestServerWithCurl currently fails, hence the LenientCipherSuites test. See: +// https://github.com/tatsuhiro-t/nghttp2/issues/140 & +// http://sourceforge.net/p/curl/bugs/1472/ +func TestServerWithCurl(t *testing.T) { testServerWithCurl(t, false) } +func TestServerWithCurl_LenientCipherSuites(t *testing.T) { testServerWithCurl(t, true) } + +func testServerWithCurl(t *testing.T, permitProhibitedCipherSuites bool) { + if runtime.GOOS != "linux" { + t.Skip("skipping Docker test when not on Linux; requires --net which won't work with boot2docker anyway") + } + requireCurl(t) + const msg = "Hello from curl!\n" + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Foo", "Bar") + w.Header().Set("Client-Proto", r.Proto) + io.WriteString(w, msg) + })) + ConfigureServer(ts.Config, &Server{ + PermitProhibitedCipherSuites: permitProhibitedCipherSuites, + }) + ts.TLS = ts.Config.TLSConfig // the httptest.Server has its own copy of this TLS config + ts.StartTLS() + defer ts.Close() + + var gotConn int32 + testHookOnConn = func() { atomic.StoreInt32(&gotConn, 1) } + + t.Logf("Running test server for curl to hit at: %s", ts.URL) + container := curl(t, "--silent", "--http2", "--insecure", "-v", ts.URL) + defer kill(container) + resc := make(chan interface{}, 1) + go func() { + res, err := dockerLogs(container) + if err != nil { + resc <- err + } else { + resc <- res + } + }() + select { + case res := <-resc: + if err, ok := res.(error); ok { + t.Fatal(err) + } + if !strings.Contains(string(res.([]byte)), "foo: Bar") { + t.Errorf("didn't see foo: Bar header") + t.Logf("Got: %s", res) + } + if !strings.Contains(string(res.([]byte)), "client-proto: HTTP/2") { + t.Errorf("didn't see client-proto: HTTP/2 header") + t.Logf("Got: %s", res) + } + if !strings.Contains(string(res.([]byte)), msg) { + t.Errorf("didn't see %q content", msg) + t.Logf("Got: %s", res) + } + case <-time.After(3 * time.Second): + t.Errorf("timeout waiting for curl") + } + + if atomic.LoadInt32(&gotConn) == 0 { + t.Error("never saw an http2 connection") + } +} + +func BenchmarkServerGets(b *testing.B) { + b.ReportAllocs() + + const msg = "Hello, world" + st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, msg) + }) + defer st.Close() + st.greet() + + // Give the server quota to reply. (plus it has the the 64KB) + if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + id := 1 + uint32(i)*2 + st.writeHeaders(HeadersFrameParam{ + StreamID: id, + BlockFragment: st.encodeHeader(), + EndStream: true, + EndHeaders: true, + }) + st.wantHeaders() + df := st.wantData() + if !df.StreamEnded() { + b.Fatalf("DATA didn't have END_STREAM; got %v", df) + } + } +} + +func BenchmarkServerPosts(b *testing.B) { + b.ReportAllocs() + + const msg = "Hello, world" + st := newServerTester(b, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, msg) + }) + defer st.Close() + st.greet() + + // Give the server quota to reply. (plus it has the the 64KB) + if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + id := 1 + uint32(i)*2 + st.writeHeaders(HeadersFrameParam{ + StreamID: id, + BlockFragment: st.encodeHeader(":method", "POST"), + EndStream: false, + EndHeaders: true, + }) + st.writeData(id, true, nil) + st.wantHeaders() + df := st.wantData() + if !df.StreamEnded() { + b.Fatalf("DATA didn't have END_STREAM; got %v", df) + } + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/transport.go b/Godeps/_workspace/src/github.com/bradfitz/http2/transport.go new file mode 100644 index 0000000..ea62188 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/transport.go @@ -0,0 +1,553 @@ +// Copyright 2015 The Go Authors. +// See https://go.googlesource.com/go/+/master/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://go.googlesource.com/go/+/master/LICENSE + +package http2 + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "strconv" + "strings" + "sync" + + "github.com/bradfitz/http2/hpack" +) + +type Transport struct { + Fallback http.RoundTripper + + // TODO: remove this and make more general with a TLS dial hook, like http + InsecureTLSDial bool + + connMu sync.Mutex + conns map[string][]*clientConn // key is host:port +} + +type clientConn struct { + t *Transport + tconn *tls.Conn + tlsState *tls.ConnectionState + connKey []string // key(s) this connection is cached in, in t.conns + + readerDone chan struct{} // closed on error + readerErr error // set before readerDone is closed + hdec *hpack.Decoder + nextRes *http.Response + + mu sync.Mutex + closed bool + goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received + streams map[uint32]*clientStream + nextStreamID uint32 + bw *bufio.Writer + werr error // first write error that has occurred + br *bufio.Reader + fr *Framer + // Settings from peer: + maxFrameSize uint32 + maxConcurrentStreams uint32 + initialWindowSize uint32 + hbuf bytes.Buffer // HPACK encoder writes into this + henc *hpack.Encoder +} + +type clientStream struct { + ID uint32 + resc chan resAndError + pw *io.PipeWriter + pr *io.PipeReader +} + +type stickyErrWriter struct { + w io.Writer + err *error +} + +func (sew stickyErrWriter) Write(p []byte) (n int, err error) { + if *sew.err != nil { + return 0, *sew.err + } + n, err = sew.w.Write(p) + *sew.err = err + return +} + +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.URL.Scheme != "https" { + if t.Fallback == nil { + return nil, errors.New("http2: unsupported scheme and no Fallback") + } + return t.Fallback.RoundTrip(req) + } + + host, port, err := net.SplitHostPort(req.URL.Host) + if err != nil { + host = req.URL.Host + port = "443" + } + + for { + cc, err := t.getClientConn(host, port) + if err != nil { + return nil, err + } + res, err := cc.roundTrip(req) + if shouldRetryRequest(err) { // TODO: or clientconn is overloaded (too many outstanding requests)? + continue + } + if err != nil { + return nil, err + } + return res, nil + } +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle. +// It does not interrupt any connections currently in use. +func (t *Transport) CloseIdleConnections() { + t.connMu.Lock() + defer t.connMu.Unlock() + for _, vv := range t.conns { + for _, cc := range vv { + cc.closeIfIdle() + } + } +} + +var errClientConnClosed = errors.New("http2: client conn is closed") + +func shouldRetryRequest(err error) bool { + // TODO: or GOAWAY graceful shutdown stuff + return err == errClientConnClosed +} + +func (t *Transport) removeClientConn(cc *clientConn) { + t.connMu.Lock() + defer t.connMu.Unlock() + for _, key := range cc.connKey { + vv, ok := t.conns[key] + if !ok { + continue + } + newList := filterOutClientConn(vv, cc) + if len(newList) > 0 { + t.conns[key] = newList + } else { + delete(t.conns, key) + } + } +} + +func filterOutClientConn(in []*clientConn, exclude *clientConn) []*clientConn { + out := in[:0] + for _, v := range in { + if v != exclude { + out = append(out, v) + } + } + return out +} + +func (t *Transport) getClientConn(host, port string) (*clientConn, error) { + t.connMu.Lock() + defer t.connMu.Unlock() + + key := net.JoinHostPort(host, port) + + for _, cc := range t.conns[key] { + if cc.canTakeNewRequest() { + return cc, nil + } + } + if t.conns == nil { + t.conns = make(map[string][]*clientConn) + } + cc, err := t.newClientConn(host, port, key) + if err != nil { + return nil, err + } + t.conns[key] = append(t.conns[key], cc) + return cc, nil +} + +func (t *Transport) newClientConn(host, port, key string) (*clientConn, error) { + cfg := &tls.Config{ + ServerName: host, + NextProtos: []string{NextProtoTLS}, + InsecureSkipVerify: t.InsecureTLSDial, + } + tconn, err := tls.Dial("tcp", host+":"+port, cfg) + if err != nil { + return nil, err + } + if err := tconn.Handshake(); err != nil { + return nil, err + } + if !t.InsecureTLSDial { + if err := tconn.VerifyHostname(cfg.ServerName); err != nil { + return nil, err + } + } + state := tconn.ConnectionState() + if p := state.NegotiatedProtocol; p != NextProtoTLS { + // TODO(bradfitz): fall back to Fallback + return nil, fmt.Errorf("bad protocol: %v", p) + } + if !state.NegotiatedProtocolIsMutual { + return nil, errors.New("could not negotiate protocol mutually") + } + if _, err := tconn.Write(clientPreface); err != nil { + return nil, err + } + + cc := &clientConn{ + t: t, + tconn: tconn, + connKey: []string{key}, // TODO: cert's validated hostnames too + tlsState: &state, + readerDone: make(chan struct{}), + nextStreamID: 1, + maxFrameSize: 16 << 10, // spec default + initialWindowSize: 65535, // spec default + maxConcurrentStreams: 1000, // "infinite", per spec. 1000 seems good enough. + streams: make(map[uint32]*clientStream), + } + cc.bw = bufio.NewWriter(stickyErrWriter{tconn, &cc.werr}) + cc.br = bufio.NewReader(tconn) + cc.fr = NewFramer(cc.bw, cc.br) + cc.henc = hpack.NewEncoder(&cc.hbuf) + + cc.fr.WriteSettings() + // TODO: re-send more conn-level flow control tokens when server uses all these. + cc.fr.WriteWindowUpdate(0, 1<<30) // um, 0x7fffffff doesn't work to Google? it hangs? + cc.bw.Flush() + if cc.werr != nil { + return nil, cc.werr + } + + // Read the obligatory SETTINGS frame + f, err := cc.fr.ReadFrame() + if err != nil { + return nil, err + } + sf, ok := f.(*SettingsFrame) + if !ok { + return nil, fmt.Errorf("expected settings frame, got: %T", f) + } + cc.fr.WriteSettingsAck() + cc.bw.Flush() + + sf.ForeachSetting(func(s Setting) error { + switch s.ID { + case SettingMaxFrameSize: + cc.maxFrameSize = s.Val + case SettingMaxConcurrentStreams: + cc.maxConcurrentStreams = s.Val + case SettingInitialWindowSize: + cc.initialWindowSize = s.Val + default: + // TODO(bradfitz): handle more + log.Printf("Unhandled Setting: %v", s) + } + return nil + }) + // TODO: figure out henc size + cc.hdec = hpack.NewDecoder(initialHeaderTableSize, cc.onNewHeaderField) + + go cc.readLoop() + return cc, nil +} + +func (cc *clientConn) setGoAway(f *GoAwayFrame) { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.goAway = f +} + +func (cc *clientConn) canTakeNewRequest() bool { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.goAway == nil && + int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && + cc.nextStreamID < 2147483647 +} + +func (cc *clientConn) closeIfIdle() { + cc.mu.Lock() + if len(cc.streams) > 0 { + cc.mu.Unlock() + return + } + cc.closed = true + // TODO: do clients send GOAWAY too? maybe? Just Close: + cc.mu.Unlock() + + cc.tconn.Close() +} + +func (cc *clientConn) roundTrip(req *http.Request) (*http.Response, error) { + cc.mu.Lock() + + if cc.closed { + cc.mu.Unlock() + return nil, errClientConnClosed + } + + cs := cc.newStream() + hasBody := false // TODO + + // we send: HEADERS[+CONTINUATION] + (DATA?) + hdrs := cc.encodeHeaders(req) + first := true + for len(hdrs) > 0 { + chunk := hdrs + if len(chunk) > int(cc.maxFrameSize) { + chunk = chunk[:cc.maxFrameSize] + } + hdrs = hdrs[len(chunk):] + endHeaders := len(hdrs) == 0 + if first { + cc.fr.WriteHeaders(HeadersFrameParam{ + StreamID: cs.ID, + BlockFragment: chunk, + EndStream: !hasBody, + EndHeaders: endHeaders, + }) + first = false + } else { + cc.fr.WriteContinuation(cs.ID, endHeaders, chunk) + } + } + cc.bw.Flush() + werr := cc.werr + cc.mu.Unlock() + + if hasBody { + // TODO: write data. and it should probably be interleaved: + // go ... io.Copy(dataFrameWriter{cc, cs, ...}, req.Body) ... etc + } + + if werr != nil { + return nil, werr + } + + re := <-cs.resc + if re.err != nil { + return nil, re.err + } + res := re.res + res.Request = req + res.TLS = cc.tlsState + return res, nil +} + +// requires cc.mu be held. +func (cc *clientConn) encodeHeaders(req *http.Request) []byte { + cc.hbuf.Reset() + + // TODO(bradfitz): figure out :authority-vs-Host stuff between http2 and Go + host := req.Host + if host == "" { + host = req.URL.Host + } + + path := req.URL.Path + if path == "" { + path = "/" + } + + cc.writeHeader(":authority", host) // probably not right for all sites + cc.writeHeader(":method", req.Method) + cc.writeHeader(":path", path) + cc.writeHeader(":scheme", "https") + + for k, vv := range req.Header { + lowKey := strings.ToLower(k) + if lowKey == "host" { + continue + } + for _, v := range vv { + cc.writeHeader(lowKey, v) + } + } + return cc.hbuf.Bytes() +} + +func (cc *clientConn) writeHeader(name, value string) { + log.Printf("sending %q = %q", name, value) + cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value}) +} + +type resAndError struct { + res *http.Response + err error +} + +// requires cc.mu be held. +func (cc *clientConn) newStream() *clientStream { + cs := &clientStream{ + ID: cc.nextStreamID, + resc: make(chan resAndError, 1), + } + cc.nextStreamID += 2 + cc.streams[cs.ID] = cs + return cs +} + +func (cc *clientConn) streamByID(id uint32, andRemove bool) *clientStream { + cc.mu.Lock() + defer cc.mu.Unlock() + cs := cc.streams[id] + if andRemove { + delete(cc.streams, id) + } + return cs +} + +// runs in its own goroutine. +func (cc *clientConn) readLoop() { + defer cc.t.removeClientConn(cc) + defer close(cc.readerDone) + + activeRes := map[uint32]*clientStream{} // keyed by streamID + // Close any response bodies if the server closes prematurely. + // TODO: also do this if we've written the headers but not + // gotten a response yet. + defer func() { + err := cc.readerErr + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + for _, cs := range activeRes { + cs.pw.CloseWithError(err) + } + }() + + // continueStreamID is the stream ID we're waiting for + // continuation frames for. + var continueStreamID uint32 + + for { + f, err := cc.fr.ReadFrame() + if err != nil { + cc.readerErr = err + return + } + log.Printf("Transport received %v: %#v", f.Header(), f) + + streamID := f.Header().StreamID + + _, isContinue := f.(*ContinuationFrame) + if isContinue { + if streamID != continueStreamID { + log.Printf("Protocol violation: got CONTINUATION with id %d; want %d", streamID, continueStreamID) + cc.readerErr = ConnectionError(ErrCodeProtocol) + return + } + } else if continueStreamID != 0 { + // Continue frames need to be adjacent in the stream + // and we were in the middle of headers. + log.Printf("Protocol violation: got %T for stream %d, want CONTINUATION for %d", f, streamID, continueStreamID) + cc.readerErr = ConnectionError(ErrCodeProtocol) + return + } + + if streamID%2 == 0 { + // Ignore streams pushed from the server for now. + // These always have an even stream id. + continue + } + streamEnded := false + if ff, ok := f.(streamEnder); ok { + streamEnded = ff.StreamEnded() + } + + cs := cc.streamByID(streamID, streamEnded) + if cs == nil { + log.Printf("Received frame for untracked stream ID %d", streamID) + continue + } + + switch f := f.(type) { + case *HeadersFrame: + cc.nextRes = &http.Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: make(http.Header), + } + cs.pr, cs.pw = io.Pipe() + cc.hdec.Write(f.HeaderBlockFragment()) + case *ContinuationFrame: + cc.hdec.Write(f.HeaderBlockFragment()) + case *DataFrame: + log.Printf("DATA: %q", f.Data()) + cs.pw.Write(f.Data()) + case *GoAwayFrame: + cc.t.removeClientConn(cc) + if f.ErrCode != 0 { + // TODO: deal with GOAWAY more. particularly the error code + log.Printf("transport got GOAWAY with error code = %v", f.ErrCode) + } + cc.setGoAway(f) + default: + log.Printf("Transport: unhandled response frame type %T", f) + } + headersEnded := false + if he, ok := f.(headersEnder); ok { + headersEnded = he.HeadersEnded() + if headersEnded { + continueStreamID = 0 + } else { + continueStreamID = streamID + } + } + + if streamEnded { + cs.pw.Close() + delete(activeRes, streamID) + } + if headersEnded { + if cs == nil { + panic("couldn't find stream") // TODO be graceful + } + // TODO: set the Body to one which notes the + // Close and also sends the server a + // RST_STREAM + cc.nextRes.Body = cs.pr + res := cc.nextRes + activeRes[streamID] = cs + cs.resc <- resAndError{res: res} + } + } +} + +func (cc *clientConn) onNewHeaderField(f hpack.HeaderField) { + // TODO: verifiy pseudo headers come before non-pseudo headers + // TODO: verifiy the status is set + log.Printf("Header field: %+v", f) + if f.Name == ":status" { + code, err := strconv.Atoi(f.Value) + if err != nil { + panic("TODO: be graceful") + } + cc.nextRes.Status = f.Value + " " + http.StatusText(code) + cc.nextRes.StatusCode = code + return + } + if strings.HasPrefix(f.Name, ":") { + // "Endpoints MUST NOT generate pseudo-header fields other than those defined in this document." + // TODO: treat as invalid? + return + } + cc.nextRes.Header.Add(http.CanonicalHeaderKey(f.Name), f.Value) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/transport_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/transport_test.go new file mode 100644 index 0000000..d752da1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/transport_test.go @@ -0,0 +1,168 @@ +// Copyright 2015 The Go Authors. +// See https://go.googlesource.com/go/+/master/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://go.googlesource.com/go/+/master/LICENSE + +package http2 + +import ( + "flag" + "io" + "io/ioutil" + "net/http" + "os" + "reflect" + "strings" + "testing" + "time" +) + +var ( + extNet = flag.Bool("extnet", false, "do external network tests") + transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport") + insecure = flag.Bool("insecure", false, "insecure TLS dials") +) + +func TestTransportExternal(t *testing.T) { + if !*extNet { + t.Skip("skipping external network test") + } + req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil) + rt := &Transport{ + InsecureTLSDial: *insecure, + } + res, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("%v", err) + } + res.Write(os.Stdout) +} + +func TestTransport(t *testing.T) { + const body = "sup" + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, body) + }) + defer st.Close() + + tr := &Transport{InsecureTLSDial: true} + defer tr.CloseIdleConnections() + + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + t.Logf("Got res: %+v", res) + if g, w := res.StatusCode, 200; g != w { + t.Errorf("StatusCode = %v; want %v", g, w) + } + if g, w := res.Status, "200 OK"; g != w { + t.Errorf("Status = %q; want %q", g, w) + } + wantHeader := http.Header{ + "Content-Length": []string{"3"}, + "Content-Type": []string{"text/plain; charset=utf-8"}, + } + if !reflect.DeepEqual(res.Header, wantHeader) { + t.Errorf("res Header = %v; want %v", res.Header, wantHeader) + } + if res.Request != req { + t.Errorf("Response.Request = %p; want %p", res.Request, req) + } + if res.TLS == nil { + t.Error("Response.TLS = nil; want non-nil") + } + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("Body read: %v", err) + } else if string(slurp) != body { + t.Errorf("Body = %q; want %q", slurp, body) + } + +} + +func TestTransportReusesConns(t *testing.T) { + st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, r.RemoteAddr) + }, optOnlyServer) + defer st.Close() + tr := &Transport{InsecureTLSDial: true} + defer tr.CloseIdleConnections() + get := func() string { + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("Body read: %v", err) + } + addr := strings.TrimSpace(string(slurp)) + if addr == "" { + t.Fatalf("didn't get an addr in response") + } + return addr + } + first := get() + second := get() + if first != second { + t.Errorf("first and second responses were on different connections: %q vs %q", first, second) + } +} + +func TestTransportAbortClosesPipes(t *testing.T) { + shutdown := make(chan struct{}) + st := newServerTester(t, + func(w http.ResponseWriter, r *http.Request) { + w.(http.Flusher).Flush() + <-shutdown + }, + optOnlyServer, + ) + defer st.Close() + defer close(shutdown) // we must shutdown before st.Close() to avoid hanging + + done := make(chan struct{}) + requestMade := make(chan struct{}) + go func() { + defer close(done) + tr := &Transport{ + InsecureTLSDial: true, + } + req, err := http.NewRequest("GET", st.ts.URL, nil) + if err != nil { + t.Fatal(err) + } + res, err := tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + close(requestMade) + _, err = ioutil.ReadAll(res.Body) + if err == nil { + t.Error("expected error from res.Body.Read") + } + }() + + <-requestMade + // Now force the serve loop to end, via closing the connection. + st.closeConn() + // deadlock? that's a bug. + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("timeout") + } +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/write.go b/Godeps/_workspace/src/github.com/bradfitz/http2/write.go new file mode 100644 index 0000000..7b9bdd3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/write.go @@ -0,0 +1,204 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "bytes" + "fmt" + "net/http" + "time" + + "github.com/bradfitz/http2/hpack" +) + +// writeFramer is implemented by any type that is used to write frames. +type writeFramer interface { + writeFrame(writeContext) error +} + +// writeContext is the interface needed by the various frame writer +// types below. All the writeFrame methods below are scheduled via the +// frame writing scheduler (see writeScheduler in writesched.go). +// +// This interface is implemented by *serverConn. +// TODO: use it from the client code too, once it exists. +type writeContext interface { + Framer() *Framer + Flush() error + CloseConn() error + // HeaderEncoder returns an HPACK encoder that writes to the + // returned buffer. + HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) +} + +// endsStream reports whether the given frame writer w will locally +// close the stream. +func endsStream(w writeFramer) bool { + switch v := w.(type) { + case *writeData: + return v.endStream + case *writeResHeaders: + return v.endStream + } + return false +} + +type flushFrameWriter struct{} + +func (flushFrameWriter) writeFrame(ctx writeContext) error { + return ctx.Flush() +} + +type writeSettings []Setting + +func (s writeSettings) writeFrame(ctx writeContext) error { + return ctx.Framer().WriteSettings([]Setting(s)...) +} + +type writeGoAway struct { + maxStreamID uint32 + code ErrCode +} + +func (p *writeGoAway) writeFrame(ctx writeContext) error { + err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil) + if p.code != 0 { + ctx.Flush() // ignore error: we're hanging up on them anyway + time.Sleep(50 * time.Millisecond) + ctx.CloseConn() + } + return err +} + +type writeData struct { + streamID uint32 + p []byte + endStream bool +} + +func (w *writeData) String() string { + return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream) +} + +func (w *writeData) writeFrame(ctx writeContext) error { + return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) +} + +func (se StreamError) writeFrame(ctx writeContext) error { + return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) +} + +type writePingAck struct{ pf *PingFrame } + +func (w writePingAck) writeFrame(ctx writeContext) error { + return ctx.Framer().WritePing(true, w.pf.Data) +} + +type writeSettingsAck struct{} + +func (writeSettingsAck) writeFrame(ctx writeContext) error { + return ctx.Framer().WriteSettingsAck() +} + +// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames +// for HTTP response headers from a server handler. +type writeResHeaders struct { + streamID uint32 + httpResCode int + h http.Header // may be nil + endStream bool + + contentType string + contentLength string +} + +func (w *writeResHeaders) writeFrame(ctx writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: httpCodeString(w.httpResCode)}) + for k, vv := range w.h { + k = lowerHeader(k) + for _, v := range vv { + // TODO: more of "8.1.2.2 Connection-Specific Header Fields" + if k == "transfer-encoding" && v != "trailers" { + continue + } + enc.WriteField(hpack.HeaderField{Name: k, Value: v}) + } + } + if w.contentType != "" { + enc.WriteField(hpack.HeaderField{Name: "content-type", Value: w.contentType}) + } + if w.contentLength != "" { + enc.WriteField(hpack.HeaderField{Name: "content-length", Value: w.contentLength}) + } + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + endHeaders := len(headerBlock) == 0 + var err error + if first { + first = false + err = ctx.Framer().WriteHeaders(HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: endHeaders, + }) + } else { + err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) + } + if err != nil { + return err + } + } + return nil +} + +type write100ContinueHeadersFrame struct { + streamID uint32 +} + +func (w write100ContinueHeadersFrame) writeFrame(ctx writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + enc.WriteField(hpack.HeaderField{Name: ":status", Value: "100"}) + return ctx.Framer().WriteHeaders(HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: buf.Bytes(), + EndStream: false, + EndHeaders: true, + }) +} + +type writeWindowUpdate struct { + streamID uint32 // or 0 for conn-level + n uint32 +} + +func (wu writeWindowUpdate) writeFrame(ctx writeContext) error { + return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/writesched.go b/Godeps/_workspace/src/github.com/bradfitz/http2/writesched.go new file mode 100644 index 0000000..0e1b748 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/writesched.go @@ -0,0 +1,286 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import "fmt" + +// frameWriteMsg is a request to write a frame. +type frameWriteMsg struct { + // write is the interface value that does the writing, once the + // writeScheduler (below) has decided to select this frame + // to write. The write functions are all defined in write.go. + write writeFramer + + stream *stream // used for prioritization. nil for non-stream frames. + + // done, if non-nil, must be a buffered channel with space for + // 1 message and is sent the return value from write (or an + // earlier error) when the frame has been written. + done chan error +} + +// for debugging only: +func (wm frameWriteMsg) String() string { + var streamID uint32 + if wm.stream != nil { + streamID = wm.stream.id + } + var des string + if s, ok := wm.write.(fmt.Stringer); ok { + des = s.String() + } else { + des = fmt.Sprintf("%T", wm.write) + } + return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des) +} + +// writeScheduler tracks pending frames to write, priorities, and decides +// the next one to use. It is not thread-safe. +type writeScheduler struct { + // zero are frames not associated with a specific stream. + // They're sent before any stream-specific freams. + zero writeQueue + + // maxFrameSize is the maximum size of a DATA frame + // we'll write. Must be non-zero and between 16K-16M. + maxFrameSize uint32 + + // sq contains the stream-specific queues, keyed by stream ID. + // when a stream is idle, it's deleted from the map. + sq map[uint32]*writeQueue + + // canSend is a slice of memory that's reused between frame + // scheduling decisions to hold the list of writeQueues (from sq) + // which have enough flow control data to send. After canSend is + // built, the best is selected. + canSend []*writeQueue + + // pool of empty queues for reuse. + queuePool []*writeQueue +} + +func (ws *writeScheduler) putEmptyQueue(q *writeQueue) { + if len(q.s) != 0 { + panic("queue must be empty") + } + ws.queuePool = append(ws.queuePool, q) +} + +func (ws *writeScheduler) getEmptyQueue() *writeQueue { + ln := len(ws.queuePool) + if ln == 0 { + return new(writeQueue) + } + q := ws.queuePool[ln-1] + ws.queuePool = ws.queuePool[:ln-1] + return q +} + +func (ws *writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 } + +func (ws *writeScheduler) add(wm frameWriteMsg) { + st := wm.stream + if st == nil { + ws.zero.push(wm) + } else { + ws.streamQueue(st.id).push(wm) + } +} + +func (ws *writeScheduler) streamQueue(streamID uint32) *writeQueue { + if q, ok := ws.sq[streamID]; ok { + return q + } + if ws.sq == nil { + ws.sq = make(map[uint32]*writeQueue) + } + q := ws.getEmptyQueue() + ws.sq[streamID] = q + return q +} + +// take returns the most important frame to write and removes it from the scheduler. +// It is illegal to call this if the scheduler is empty or if there are no connection-level +// flow control bytes available. +func (ws *writeScheduler) take() (wm frameWriteMsg, ok bool) { + if ws.maxFrameSize == 0 { + panic("internal error: ws.maxFrameSize not initialized or invalid") + } + + // If there any frames not associated with streams, prefer those first. + // These are usually SETTINGS, etc. + if !ws.zero.empty() { + return ws.zero.shift(), true + } + if len(ws.sq) == 0 { + return + } + + // Next, prioritize frames on streams that aren't DATA frames (no cost). + for id, q := range ws.sq { + if q.firstIsNoCost() { + return ws.takeFrom(id, q) + } + } + + // Now, all that remains are DATA frames with non-zero bytes to + // send. So pick the best one. + if len(ws.canSend) != 0 { + panic("should be empty") + } + for _, q := range ws.sq { + if n := ws.streamWritableBytes(q); n > 0 { + ws.canSend = append(ws.canSend, q) + } + } + if len(ws.canSend) == 0 { + return + } + defer ws.zeroCanSend() + + // TODO: find the best queue + q := ws.canSend[0] + + return ws.takeFrom(q.streamID(), q) +} + +// zeroCanSend is defered from take. +func (ws *writeScheduler) zeroCanSend() { + for i := range ws.canSend { + ws.canSend[i] = nil + } + ws.canSend = ws.canSend[:0] +} + +// streamWritableBytes returns the number of DATA bytes we could write +// from the given queue's stream, if this stream/queue were +// selected. It is an error to call this if q's head isn't a +// *writeData. +func (ws *writeScheduler) streamWritableBytes(q *writeQueue) int32 { + wm := q.head() + ret := wm.stream.flow.available() // max we can write + if ret == 0 { + return 0 + } + if int32(ws.maxFrameSize) < ret { + ret = int32(ws.maxFrameSize) + } + if ret == 0 { + panic("internal error: ws.maxFrameSize not initialized or invalid") + } + wd := wm.write.(*writeData) + if len(wd.p) < int(ret) { + ret = int32(len(wd.p)) + } + return ret +} + +func (ws *writeScheduler) takeFrom(id uint32, q *writeQueue) (wm frameWriteMsg, ok bool) { + wm = q.head() + // If the first item in this queue costs flow control tokens + // and we don't have enough, write as much as we can. + if wd, ok := wm.write.(*writeData); ok && len(wd.p) > 0 { + allowed := wm.stream.flow.available() // max we can write + if allowed == 0 { + // No quota available. Caller can try the next stream. + return frameWriteMsg{}, false + } + if int32(ws.maxFrameSize) < allowed { + allowed = int32(ws.maxFrameSize) + } + // TODO: further restrict the allowed size, because even if + // the peer says it's okay to write 16MB data frames, we might + // want to write smaller ones to properly weight competing + // streams' priorities. + + if len(wd.p) > int(allowed) { + wm.stream.flow.take(allowed) + chunk := wd.p[:allowed] + wd.p = wd.p[allowed:] + // Make up a new write message of a valid size, rather + // than shifting one off the queue. + return frameWriteMsg{ + stream: wm.stream, + write: &writeData{ + streamID: wd.streamID, + p: chunk, + // even if the original had endStream set, there + // arebytes remaining because len(wd.p) > allowed, + // so we know endStream is false: + endStream: false, + }, + // our caller is blocking on the final DATA frame, not + // these intermediates, so no need to wait: + done: nil, + }, true + } + wm.stream.flow.take(int32(len(wd.p))) + } + + q.shift() + if q.empty() { + ws.putEmptyQueue(q) + delete(ws.sq, id) + } + return wm, true +} + +func (ws *writeScheduler) forgetStream(id uint32) { + q, ok := ws.sq[id] + if !ok { + return + } + delete(ws.sq, id) + + // But keep it for others later. + for i := range q.s { + q.s[i] = frameWriteMsg{} + } + q.s = q.s[:0] + ws.putEmptyQueue(q) +} + +type writeQueue struct { + s []frameWriteMsg +} + +// streamID returns the stream ID for a non-empty stream-specific queue. +func (q *writeQueue) streamID() uint32 { return q.s[0].stream.id } + +func (q *writeQueue) empty() bool { return len(q.s) == 0 } + +func (q *writeQueue) push(wm frameWriteMsg) { + q.s = append(q.s, wm) +} + +// head returns the next item that would be removed by shift. +func (q *writeQueue) head() frameWriteMsg { + if len(q.s) == 0 { + panic("invalid use of queue") + } + return q.s[0] +} + +func (q *writeQueue) shift() frameWriteMsg { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wm := q.s[0] + // TODO: less copy-happy queue. + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = frameWriteMsg{} + q.s = q.s[:len(q.s)-1] + return wm +} + +func (q *writeQueue) firstIsNoCost() bool { + if df, ok := q.s[0].write.(*writeData); ok { + return len(df.p) == 0 + } + return true +} diff --git a/Godeps/_workspace/src/github.com/bradfitz/http2/z_spec_test.go b/Godeps/_workspace/src/github.com/bradfitz/http2/z_spec_test.go new file mode 100644 index 0000000..07b53e3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/bradfitz/http2/z_spec_test.go @@ -0,0 +1,357 @@ +// Copyright 2014 The Go Authors. +// See https://code.google.com/p/go/source/browse/CONTRIBUTORS +// Licensed under the same terms as Go itself: +// https://code.google.com/p/go/source/browse/LICENSE + +package http2 + +import ( + "bytes" + "encoding/xml" + "flag" + "fmt" + "io" + "os" + "reflect" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "testing" +) + +var coverSpec = flag.Bool("coverspec", false, "Run spec coverage tests") + +// The global map of sentence coverage for the http2 spec. +var defaultSpecCoverage specCoverage + +var loadSpecOnce sync.Once + +func loadSpec() { + if f, err := os.Open("testdata/draft-ietf-httpbis-http2.xml"); err != nil { + panic(err) + } else { + defaultSpecCoverage = readSpecCov(f) + f.Close() + } +} + +// covers marks all sentences for section sec in defaultSpecCoverage. Sentences not +// "covered" will be included in report outputed by TestSpecCoverage. +func covers(sec, sentences string) { + loadSpecOnce.Do(loadSpec) + defaultSpecCoverage.cover(sec, sentences) +} + +type specPart struct { + section string + sentence string +} + +func (ss specPart) Less(oo specPart) bool { + atoi := func(s string) int { + n, err := strconv.Atoi(s) + if err != nil { + panic(err) + } + return n + } + a := strings.Split(ss.section, ".") + b := strings.Split(oo.section, ".") + for len(a) > 0 { + if len(b) == 0 { + return false + } + x, y := atoi(a[0]), atoi(b[0]) + if x == y { + a, b = a[1:], b[1:] + continue + } + return x < y + } + if len(b) > 0 { + return true + } + return false +} + +type bySpecSection []specPart + +func (a bySpecSection) Len() int { return len(a) } +func (a bySpecSection) Less(i, j int) bool { return a[i].Less(a[j]) } +func (a bySpecSection) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +type specCoverage struct { + coverage map[specPart]bool + d *xml.Decoder +} + +func joinSection(sec []int) string { + s := fmt.Sprintf("%d", sec[0]) + for _, n := range sec[1:] { + s = fmt.Sprintf("%s.%d", s, n) + } + return s +} + +func (sc specCoverage) readSection(sec []int) { + var ( + buf = new(bytes.Buffer) + sub = 0 + ) + for { + tk, err := sc.d.Token() + if err != nil { + if err == io.EOF { + return + } + panic(err) + } + switch v := tk.(type) { + case xml.StartElement: + if skipElement(v) { + if err := sc.d.Skip(); err != nil { + panic(err) + } + if v.Name.Local == "section" { + sub++ + } + break + } + switch v.Name.Local { + case "section": + sub++ + sc.readSection(append(sec, sub)) + case "xref": + buf.Write(sc.readXRef(v)) + } + case xml.CharData: + if len(sec) == 0 { + break + } + buf.Write(v) + case xml.EndElement: + if v.Name.Local == "section" { + sc.addSentences(joinSection(sec), buf.String()) + return + } + } + } +} + +func (sc specCoverage) readXRef(se xml.StartElement) []byte { + var b []byte + for { + tk, err := sc.d.Token() + if err != nil { + panic(err) + } + switch v := tk.(type) { + case xml.CharData: + if b != nil { + panic("unexpected CharData") + } + b = []byte(string(v)) + case xml.EndElement: + if v.Name.Local != "xref" { + panic("expected ") + } + if b != nil { + return b + } + sig := attrSig(se) + switch sig { + case "target": + return []byte(fmt.Sprintf("[%s]", attrValue(se, "target"))) + case "fmt-of,rel,target", "fmt-,,rel,target": + return []byte(fmt.Sprintf("[%s, %s]", attrValue(se, "target"), attrValue(se, "rel"))) + case "fmt-of,sec,target", "fmt-,,sec,target": + return []byte(fmt.Sprintf("[section %s of %s]", attrValue(se, "sec"), attrValue(se, "target"))) + case "fmt-of,rel,sec,target": + return []byte(fmt.Sprintf("[section %s of %s, %s]", attrValue(se, "sec"), attrValue(se, "target"), attrValue(se, "rel"))) + default: + panic(fmt.Sprintf("unknown attribute signature %q in %#v", sig, fmt.Sprintf("%#v", se))) + } + default: + panic(fmt.Sprintf("unexpected tag %q", v)) + } + } +} + +var skipAnchor = map[string]bool{ + "intro": true, + "Overview": true, +} + +var skipTitle = map[string]bool{ + "Acknowledgements": true, + "Change Log": true, + "Document Organization": true, + "Conventions and Terminology": true, +} + +func skipElement(s xml.StartElement) bool { + switch s.Name.Local { + case "artwork": + return true + case "section": + for _, attr := range s.Attr { + switch attr.Name.Local { + case "anchor": + if skipAnchor[attr.Value] || strings.HasPrefix(attr.Value, "changes.since.") { + return true + } + case "title": + if skipTitle[attr.Value] { + return true + } + } + } + } + return false +} + +func readSpecCov(r io.Reader) specCoverage { + sc := specCoverage{ + coverage: map[specPart]bool{}, + d: xml.NewDecoder(r)} + sc.readSection(nil) + return sc +} + +func (sc specCoverage) addSentences(sec string, sentence string) { + for _, s := range parseSentences(sentence) { + sc.coverage[specPart{sec, s}] = false + } +} + +func (sc specCoverage) cover(sec string, sentence string) { + for _, s := range parseSentences(sentence) { + p := specPart{sec, s} + if _, ok := sc.coverage[p]; !ok { + panic(fmt.Sprintf("Not found in spec: %q, %q", sec, s)) + } + sc.coverage[specPart{sec, s}] = true + } + +} + +var whitespaceRx = regexp.MustCompile(`\s+`) + +func parseSentences(sens string) []string { + sens = strings.TrimSpace(sens) + if sens == "" { + return nil + } + ss := strings.Split(whitespaceRx.ReplaceAllString(sens, " "), ". ") + for i, s := range ss { + s = strings.TrimSpace(s) + if !strings.HasSuffix(s, ".") { + s += "." + } + ss[i] = s + } + return ss +} + +func TestSpecParseSentences(t *testing.T) { + tests := []struct { + ss string + want []string + }{ + {"Sentence 1. Sentence 2.", + []string{ + "Sentence 1.", + "Sentence 2.", + }}, + {"Sentence 1. \nSentence 2.\tSentence 3.", + []string{ + "Sentence 1.", + "Sentence 2.", + "Sentence 3.", + }}, + } + + for i, tt := range tests { + got := parseSentences(tt.ss) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("%d: got = %q, want %q", i, got, tt.want) + } + } +} + +func TestSpecCoverage(t *testing.T) { + if !*coverSpec { + t.Skip() + } + + loadSpecOnce.Do(loadSpec) + + var ( + list []specPart + cv = defaultSpecCoverage.coverage + total = len(cv) + complete = 0 + ) + + for sp, touched := range defaultSpecCoverage.coverage { + if touched { + complete++ + } else { + list = append(list, sp) + } + } + sort.Stable(bySpecSection(list)) + + if testing.Short() && len(list) > 5 { + list = list[:5] + } + + for _, p := range list { + t.Errorf("\tSECTION %s: %s", p.section, p.sentence) + } + + t.Logf("%d/%d (%d%%) sentances covered", complete, total, (complete/total)*100) +} + +func attrSig(se xml.StartElement) string { + var names []string + for _, attr := range se.Attr { + if attr.Name.Local == "fmt" { + names = append(names, "fmt-"+attr.Value) + } else { + names = append(names, attr.Name.Local) + } + } + sort.Strings(names) + return strings.Join(names, ",") +} + +func attrValue(se xml.StartElement, attr string) string { + for _, a := range se.Attr { + if a.Name.Local == attr { + return a.Value + } + } + panic("unknown attribute " + attr) +} + +func TestSpecPartLess(t *testing.T) { + tests := []struct { + sec1, sec2 string + want bool + }{ + {"6.2.1", "6.2", false}, + {"6.2", "6.2.1", true}, + {"6.10", "6.10.1", true}, + {"6.10", "6.1.1", false}, // 10, not 1 + {"6.1", "6.1", false}, // equal, so not less + } + for _, tt := range tests { + got := (specPart{tt.sec1, "foo"}).Less(specPart{tt.sec2, "foo"}) + if got != tt.want { + t.Errorf("Less(%q, %q) = %v; want %v", tt.sec1, tt.sec2, got, tt.want) + } + } +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml b/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml new file mode 100644 index 0000000..34d39c8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/.travis.yml @@ -0,0 +1,13 @@ +language: go +sudo: false + +go: +- 1.0.3 +- 1.1.2 +- 1.2.2 +- 1.3.3 +- 1.4.2 + +script: +- go vet ./... +- go test -v ./... diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/LICENSE b/Godeps/_workspace/src/github.com/codegangsta/cli/LICENSE new file mode 100644 index 0000000..5515ccf --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/LICENSE @@ -0,0 +1,21 @@ +Copyright (C) 2013 Jeremy Saenz +All Rights Reserved. + +MIT LICENSE + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/README.md b/Godeps/_workspace/src/github.com/codegangsta/cli/README.md new file mode 100644 index 0000000..85b9cda --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/README.md @@ -0,0 +1,308 @@ +[![Build Status](https://travis-ci.org/codegangsta/cli.png?branch=master)](https://travis-ci.org/codegangsta/cli) + +# cli.go +cli.go is simple, fast, and fun package for building command line apps in Go. The goal is to enable developers to write fast and distributable command line applications in an expressive way. + +You can view the API docs here: +http://godoc.org/github.com/codegangsta/cli + +## Overview +Command line apps are usually so tiny that there is absolutely no reason why your code should *not* be self-documenting. Things like generating help text and parsing command flags/options should not hinder productivity when writing a command line app. + +**This is where cli.go comes into play.** cli.go makes command line programming fun, organized, and expressive! + +## Installation +Make sure you have a working Go environment (go 1.1+ is *required*). [See the install instructions](http://golang.org/doc/install.html). + +To install `cli.go`, simply run: +``` +$ go get github.com/codegangsta/cli +``` + +Make sure your `PATH` includes to the `$GOPATH/bin` directory so your commands can be easily used: +``` +export PATH=$PATH:$GOPATH/bin +``` + +## Getting Started +One of the philosophies behind cli.go is that an API should be playful and full of discovery. So a cli.go app can be as little as one line of code in `main()`. + +``` go +package main + +import ( + "os" + "github.com/codegangsta/cli" +) + +func main() { + cli.NewApp().Run(os.Args) +} +``` + +This app will run and show help text, but is not very useful. Let's give an action to execute and some help documentation: + +``` go +package main + +import ( + "os" + "github.com/codegangsta/cli" +) + +func main() { + app := cli.NewApp() + app.Name = "boom" + app.Usage = "make an explosive entrance" + app.Action = func(c *cli.Context) { + println("boom! I say!") + } + + app.Run(os.Args) +} +``` + +Running this already gives you a ton of functionality, plus support for things like subcommands and flags, which are covered below. + +## Example + +Being a programmer can be a lonely job. Thankfully by the power of automation that is not the case! Let's create a greeter app to fend off our demons of loneliness! + +Start by creating a directory named `greet`, and within it, add a file, `greet.go` with the following code in it: + +``` go +package main + +import ( + "os" + "github.com/codegangsta/cli" +) + +func main() { + app := cli.NewApp() + app.Name = "greet" + app.Usage = "fight the loneliness!" + app.Action = func(c *cli.Context) { + println("Hello friend!") + } + + app.Run(os.Args) +} +``` + +Install our command to the `$GOPATH/bin` directory: + +``` +$ go install +``` + +Finally run our new command: + +``` +$ greet +Hello friend! +``` + +cli.go also generates some bitchass help text: +``` +$ greet help +NAME: + greet - fight the loneliness! + +USAGE: + greet [global options] command [command options] [arguments...] + +VERSION: + 0.0.0 + +COMMANDS: + help, h Shows a list of commands or help for one command + +GLOBAL OPTIONS + --version Shows version information +``` + +### Arguments +You can lookup arguments by calling the `Args` function on `cli.Context`. + +``` go +... +app.Action = func(c *cli.Context) { + println("Hello", c.Args()[0]) +} +... +``` + +### Flags +Setting and querying flags is simple. +``` go +... +app.Flags = []cli.Flag { + cli.StringFlag{ + Name: "lang", + Value: "english", + Usage: "language for the greeting", + }, +} +app.Action = func(c *cli.Context) { + name := "someone" + if len(c.Args()) > 0 { + name = c.Args()[0] + } + if c.String("lang") == "spanish" { + println("Hola", name) + } else { + println("Hello", name) + } +} +... +``` + +See full list of flags at http://godoc.org/github.com/codegangsta/cli + +#### Alternate Names + +You can set alternate (or short) names for flags by providing a comma-delimited list for the `Name`. e.g. + +``` go +app.Flags = []cli.Flag { + cli.StringFlag{ + Name: "lang, l", + Value: "english", + Usage: "language for the greeting", + }, +} +``` + +That flag can then be set with `--lang spanish` or `-l spanish`. Note that giving two different forms of the same flag in the same command invocation is an error. + +#### Values from the Environment + +You can also have the default value set from the environment via `EnvVar`. e.g. + +``` go +app.Flags = []cli.Flag { + cli.StringFlag{ + Name: "lang, l", + Value: "english", + Usage: "language for the greeting", + EnvVar: "APP_LANG", + }, +} +``` + +The `EnvVar` may also be given as a comma-delimited "cascade", where the first environment variable that resolves is used as the default. + +``` go +app.Flags = []cli.Flag { + cli.StringFlag{ + Name: "lang, l", + Value: "english", + Usage: "language for the greeting", + EnvVar: "LEGACY_COMPAT_LANG,APP_LANG,LANG", + }, +} +``` + +### Subcommands + +Subcommands can be defined for a more git-like command line app. +```go +... +app.Commands = []cli.Command{ + { + Name: "add", + Aliases: []string{"a"}, + Usage: "add a task to the list", + Action: func(c *cli.Context) { + println("added task: ", c.Args().First()) + }, + }, + { + Name: "complete", + Aliases: []string{"c"}, + Usage: "complete a task on the list", + Action: func(c *cli.Context) { + println("completed task: ", c.Args().First()) + }, + }, + { + Name: "template", + Aliases: []string{"r"}, + Usage: "options for task templates", + Subcommands: []cli.Command{ + { + Name: "add", + Usage: "add a new template", + Action: func(c *cli.Context) { + println("new task template: ", c.Args().First()) + }, + }, + { + Name: "remove", + Usage: "remove an existing template", + Action: func(c *cli.Context) { + println("removed task template: ", c.Args().First()) + }, + }, + }, + }, +} +... +``` + +### Bash Completion + +You can enable completion commands by setting the `EnableBashCompletion` +flag on the `App` object. By default, this setting will only auto-complete to +show an app's subcommands, but you can write your own completion methods for +the App or its subcommands. +```go +... +var tasks = []string{"cook", "clean", "laundry", "eat", "sleep", "code"} +app := cli.NewApp() +app.EnableBashCompletion = true +app.Commands = []cli.Command{ + { + Name: "complete", + Aliases: []string{"c"}, + Usage: "complete a task on the list", + Action: func(c *cli.Context) { + println("completed task: ", c.Args().First()) + }, + BashComplete: func(c *cli.Context) { + // This will complete if no args are passed + if len(c.Args()) > 0 { + return + } + for _, t := range tasks { + fmt.Println(t) + } + }, + } +} +... +``` + +#### To Enable + +Source the `autocomplete/bash_autocomplete` file in your `.bashrc` file while +setting the `PROG` variable to the name of your program: + +`PROG=myprogram source /.../cli/autocomplete/bash_autocomplete` + +#### To Distribute + +Copy and modify `autocomplete/bash_autocomplete` to use your program name +rather than `$PROG` and have the user copy the file into +`/etc/bash_completion.d/` (or automatically install it there if you are +distributing a package). Alternatively you can just document that users should +source the generic `autocomplete/bash_autocomplete` with `$PROG` set to your +program name in their bash configuration. + +## Contribution Guidelines +Feel free to put up a pull request to fix a bug or maybe add a feature. I will give it a code review and make sure that it does not break backwards compatibility. If I or any other collaborators agree that it is in line with the vision of the project, we will work with you to get the code into a mergeable state and merge it into the master branch. + +If you have contributed something significant to the project, I will most likely add you as a collaborator. As a collaborator you are given the ability to merge others pull requests. It is very important that new code does not break existing code, so be careful about what code you do choose to merge. If you have any questions feel free to link @codegangsta to the issue in question and we can review it together. + +If you feel like you have contributed to the project but have not yet been added as a collaborator, I probably forgot to add you. Hit @codegangsta up over email and we will get it figured out. diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/app.go b/Godeps/_workspace/src/github.com/codegangsta/cli/app.go new file mode 100644 index 0000000..e7caec9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/app.go @@ -0,0 +1,308 @@ +package cli + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "time" +) + +// App is the main structure of a cli application. It is recomended that +// an app be created with the cli.NewApp() function +type App struct { + // The name of the program. Defaults to os.Args[0] + Name string + // Description of the program. + Usage string + // Version of the program + Version string + // List of commands to execute + Commands []Command + // List of flags to parse + Flags []Flag + // Boolean to enable bash completion commands + EnableBashCompletion bool + // Boolean to hide built-in help command + HideHelp bool + // Boolean to hide built-in version flag + HideVersion bool + // An action to execute when the bash-completion flag is set + BashComplete func(context *Context) + // An action to execute before any subcommands are run, but after the context is ready + // If a non-nil error is returned, no subcommands are run + Before func(context *Context) error + // An action to execute after any subcommands are run, but after the subcommand has finished + // It is run even if Action() panics + After func(context *Context) error + // The action to execute when no subcommands are specified + Action func(context *Context) + // Execute this function if the proper command cannot be found + CommandNotFound func(context *Context, command string) + // Compilation date + Compiled time.Time + // List of all authors who contributed + Authors []Author + // Copyright of the binary if any + Copyright string + // Name of Author (Note: Use App.Authors, this is deprecated) + Author string + // Email of Author (Note: Use App.Authors, this is deprecated) + Email string + // Writer writer to write output to + Writer io.Writer +} + +// Tries to find out when this binary was compiled. +// Returns the current time if it fails to find it. +func compileTime() time.Time { + info, err := os.Stat(os.Args[0]) + if err != nil { + return time.Now() + } + return info.ModTime() +} + +// Creates a new cli Application with some reasonable defaults for Name, Usage, Version and Action. +func NewApp() *App { + return &App{ + Name: os.Args[0], + Usage: "A new cli application", + Version: "0.0.0", + BashComplete: DefaultAppComplete, + Action: helpCommand.Action, + Compiled: compileTime(), + Writer: os.Stdout, + } +} + +// Entry point to the cli app. Parses the arguments slice and routes to the proper flag/args combination +func (a *App) Run(arguments []string) (err error) { + if a.Author != "" || a.Email != "" { + a.Authors = append(a.Authors, Author{Name: a.Author, Email: a.Email}) + } + + // append help to commands + if a.Command(helpCommand.Name) == nil && !a.HideHelp { + a.Commands = append(a.Commands, helpCommand) + if (HelpFlag != BoolFlag{}) { + a.appendFlag(HelpFlag) + } + } + + //append version/help flags + if a.EnableBashCompletion { + a.appendFlag(BashCompletionFlag) + } + + if !a.HideVersion { + a.appendFlag(VersionFlag) + } + + // parse flags + set := flagSet(a.Name, a.Flags) + set.SetOutput(ioutil.Discard) + err = set.Parse(arguments[1:]) + nerr := normalizeFlags(a.Flags, set) + if nerr != nil { + fmt.Fprintln(a.Writer, nerr) + context := NewContext(a, set, nil) + ShowAppHelp(context) + return nerr + } + context := NewContext(a, set, nil) + + if err != nil { + fmt.Fprintln(a.Writer, "Incorrect Usage.") + fmt.Fprintln(a.Writer) + ShowAppHelp(context) + return err + } + + if checkCompletions(context) { + return nil + } + + if checkHelp(context) { + return nil + } + + if checkVersion(context) { + return nil + } + + if a.After != nil { + defer func() { + afterErr := a.After(context) + if afterErr != nil { + if err != nil { + err = NewMultiError(err, afterErr) + } else { + err = afterErr + } + } + }() + } + + if a.Before != nil { + err := a.Before(context) + if err != nil { + return err + } + } + + args := context.Args() + if args.Present() { + name := args.First() + c := a.Command(name) + if c != nil { + return c.Run(context) + } + } + + // Run default Action + a.Action(context) + return nil +} + +// Another entry point to the cli app, takes care of passing arguments and error handling +func (a *App) RunAndExitOnError() { + if err := a.Run(os.Args); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +// Invokes the subcommand given the context, parses ctx.Args() to generate command-specific flags +func (a *App) RunAsSubcommand(ctx *Context) (err error) { + // append help to commands + if len(a.Commands) > 0 { + if a.Command(helpCommand.Name) == nil && !a.HideHelp { + a.Commands = append(a.Commands, helpCommand) + if (HelpFlag != BoolFlag{}) { + a.appendFlag(HelpFlag) + } + } + } + + // append flags + if a.EnableBashCompletion { + a.appendFlag(BashCompletionFlag) + } + + // parse flags + set := flagSet(a.Name, a.Flags) + set.SetOutput(ioutil.Discard) + err = set.Parse(ctx.Args().Tail()) + nerr := normalizeFlags(a.Flags, set) + context := NewContext(a, set, ctx) + + if nerr != nil { + fmt.Fprintln(a.Writer, nerr) + fmt.Fprintln(a.Writer) + if len(a.Commands) > 0 { + ShowSubcommandHelp(context) + } else { + ShowCommandHelp(ctx, context.Args().First()) + } + return nerr + } + + if err != nil { + fmt.Fprintln(a.Writer, "Incorrect Usage.") + fmt.Fprintln(a.Writer) + ShowSubcommandHelp(context) + return err + } + + if checkCompletions(context) { + return nil + } + + if len(a.Commands) > 0 { + if checkSubcommandHelp(context) { + return nil + } + } else { + if checkCommandHelp(ctx, context.Args().First()) { + return nil + } + } + + if a.After != nil { + defer func() { + afterErr := a.After(context) + if afterErr != nil { + if err != nil { + err = NewMultiError(err, afterErr) + } else { + err = afterErr + } + } + }() + } + + if a.Before != nil { + err := a.Before(context) + if err != nil { + return err + } + } + + args := context.Args() + if args.Present() { + name := args.First() + c := a.Command(name) + if c != nil { + return c.Run(context) + } + } + + // Run default Action + a.Action(context) + + return nil +} + +// Returns the named command on App. Returns nil if the command does not exist +func (a *App) Command(name string) *Command { + for _, c := range a.Commands { + if c.HasName(name) { + return &c + } + } + + return nil +} + +func (a *App) hasFlag(flag Flag) bool { + for _, f := range a.Flags { + if flag == f { + return true + } + } + + return false +} + +func (a *App) appendFlag(flag Flag) { + if !a.hasFlag(flag) { + a.Flags = append(a.Flags, flag) + } +} + +// Author represents someone who has contributed to a cli project. +type Author struct { + Name string // The Authors name + Email string // The Authors email +} + +// String makes Author comply to the Stringer interface, to allow an easy print in the templating process +func (a Author) String() string { + e := "" + if a.Email != "" { + e = "<" + a.Email + "> " + } + + return fmt.Sprintf("%v %v", a.Name, e) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/app_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/app_test.go new file mode 100644 index 0000000..4c6787a --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/app_test.go @@ -0,0 +1,867 @@ +package cli + +import ( + "bytes" + "flag" + "fmt" + "io" + "os" + "strings" + "testing" +) + +func ExampleApp() { + // set args for examples sake + os.Args = []string{"greet", "--name", "Jeremy"} + + app := NewApp() + app.Name = "greet" + app.Flags = []Flag{ + StringFlag{Name: "name", Value: "bob", Usage: "a name to say"}, + } + app.Action = func(c *Context) { + fmt.Printf("Hello %v\n", c.String("name")) + } + app.Author = "Harrison" + app.Email = "harrison@lolwut.com" + app.Authors = []Author{Author{Name: "Oliver Allen", Email: "oliver@toyshop.com"}} + app.Run(os.Args) + // Output: + // Hello Jeremy +} + +func ExampleAppSubcommand() { + // set args for examples sake + os.Args = []string{"say", "hi", "english", "--name", "Jeremy"} + app := NewApp() + app.Name = "say" + app.Commands = []Command{ + { + Name: "hello", + Aliases: []string{"hi"}, + Usage: "use it to see a description", + Description: "This is how we describe hello the function", + Subcommands: []Command{ + { + Name: "english", + Aliases: []string{"en"}, + Usage: "sends a greeting in english", + Description: "greets someone in english", + Flags: []Flag{ + StringFlag{ + Name: "name", + Value: "Bob", + Usage: "Name of the person to greet", + }, + }, + Action: func(c *Context) { + fmt.Println("Hello,", c.String("name")) + }, + }, + }, + }, + } + + app.Run(os.Args) + // Output: + // Hello, Jeremy +} + +func ExampleAppHelp() { + // set args for examples sake + os.Args = []string{"greet", "h", "describeit"} + + app := NewApp() + app.Name = "greet" + app.Flags = []Flag{ + StringFlag{Name: "name", Value: "bob", Usage: "a name to say"}, + } + app.Commands = []Command{ + { + Name: "describeit", + Aliases: []string{"d"}, + Usage: "use it to see a description", + Description: "This is how we describe describeit the function", + Action: func(c *Context) { + fmt.Printf("i like to describe things") + }, + }, + } + app.Run(os.Args) + // Output: + // NAME: + // describeit - use it to see a description + // + // USAGE: + // command describeit [arguments...] + // + // DESCRIPTION: + // This is how we describe describeit the function +} + +func ExampleAppBashComplete() { + // set args for examples sake + os.Args = []string{"greet", "--generate-bash-completion"} + + app := NewApp() + app.Name = "greet" + app.EnableBashCompletion = true + app.Commands = []Command{ + { + Name: "describeit", + Aliases: []string{"d"}, + Usage: "use it to see a description", + Description: "This is how we describe describeit the function", + Action: func(c *Context) { + fmt.Printf("i like to describe things") + }, + }, { + Name: "next", + Usage: "next example", + Description: "more stuff to see when generating bash completion", + Action: func(c *Context) { + fmt.Printf("the next example") + }, + }, + } + + app.Run(os.Args) + // Output: + // describeit + // d + // next + // help + // h +} + +func TestApp_Run(t *testing.T) { + s := "" + + app := NewApp() + app.Action = func(c *Context) { + s = s + c.Args().First() + } + + err := app.Run([]string{"command", "foo"}) + expect(t, err, nil) + err = app.Run([]string{"command", "bar"}) + expect(t, err, nil) + expect(t, s, "foobar") +} + +var commandAppTests = []struct { + name string + expected bool +}{ + {"foobar", true}, + {"batbaz", true}, + {"b", true}, + {"f", true}, + {"bat", false}, + {"nothing", false}, +} + +func TestApp_Command(t *testing.T) { + app := NewApp() + fooCommand := Command{Name: "foobar", Aliases: []string{"f"}} + batCommand := Command{Name: "batbaz", Aliases: []string{"b"}} + app.Commands = []Command{ + fooCommand, + batCommand, + } + + for _, test := range commandAppTests { + expect(t, app.Command(test.name) != nil, test.expected) + } +} + +func TestApp_CommandWithArgBeforeFlags(t *testing.T) { + var parsedOption, firstArg string + + app := NewApp() + command := Command{ + Name: "cmd", + Flags: []Flag{ + StringFlag{Name: "option", Value: "", Usage: "some option"}, + }, + Action: func(c *Context) { + parsedOption = c.String("option") + firstArg = c.Args().First() + }, + } + app.Commands = []Command{command} + + app.Run([]string{"", "cmd", "my-arg", "--option", "my-option"}) + + expect(t, parsedOption, "my-option") + expect(t, firstArg, "my-arg") +} + +func TestApp_RunAsSubcommandParseFlags(t *testing.T) { + var context *Context + + a := NewApp() + a.Commands = []Command{ + { + Name: "foo", + Action: func(c *Context) { + context = c + }, + Flags: []Flag{ + StringFlag{ + Name: "lang", + Value: "english", + Usage: "language for the greeting", + }, + }, + Before: func(_ *Context) error { return nil }, + }, + } + a.Run([]string{"", "foo", "--lang", "spanish", "abcd"}) + + expect(t, context.Args().Get(0), "abcd") + expect(t, context.String("lang"), "spanish") +} + +func TestApp_CommandWithFlagBeforeTerminator(t *testing.T) { + var parsedOption string + var args []string + + app := NewApp() + command := Command{ + Name: "cmd", + Flags: []Flag{ + StringFlag{Name: "option", Value: "", Usage: "some option"}, + }, + Action: func(c *Context) { + parsedOption = c.String("option") + args = c.Args() + }, + } + app.Commands = []Command{command} + + app.Run([]string{"", "cmd", "my-arg", "--option", "my-option", "--", "--notARealFlag"}) + + expect(t, parsedOption, "my-option") + expect(t, args[0], "my-arg") + expect(t, args[1], "--") + expect(t, args[2], "--notARealFlag") +} + +func TestApp_CommandWithNoFlagBeforeTerminator(t *testing.T) { + var args []string + + app := NewApp() + command := Command{ + Name: "cmd", + Action: func(c *Context) { + args = c.Args() + }, + } + app.Commands = []Command{command} + + app.Run([]string{"", "cmd", "my-arg", "--", "notAFlagAtAll"}) + + expect(t, args[0], "my-arg") + expect(t, args[1], "--") + expect(t, args[2], "notAFlagAtAll") +} + +func TestApp_Float64Flag(t *testing.T) { + var meters float64 + + app := NewApp() + app.Flags = []Flag{ + Float64Flag{Name: "height", Value: 1.5, Usage: "Set the height, in meters"}, + } + app.Action = func(c *Context) { + meters = c.Float64("height") + } + + app.Run([]string{"", "--height", "1.93"}) + expect(t, meters, 1.93) +} + +func TestApp_ParseSliceFlags(t *testing.T) { + var parsedOption, firstArg string + var parsedIntSlice []int + var parsedStringSlice []string + + app := NewApp() + command := Command{ + Name: "cmd", + Flags: []Flag{ + IntSliceFlag{Name: "p", Value: &IntSlice{}, Usage: "set one or more ip addr"}, + StringSliceFlag{Name: "ip", Value: &StringSlice{}, Usage: "set one or more ports to open"}, + }, + Action: func(c *Context) { + parsedIntSlice = c.IntSlice("p") + parsedStringSlice = c.StringSlice("ip") + parsedOption = c.String("option") + firstArg = c.Args().First() + }, + } + app.Commands = []Command{command} + + app.Run([]string{"", "cmd", "my-arg", "-p", "22", "-p", "80", "-ip", "8.8.8.8", "-ip", "8.8.4.4"}) + + IntsEquals := func(a, b []int) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true + } + + StrsEquals := func(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true + } + var expectedIntSlice = []int{22, 80} + var expectedStringSlice = []string{"8.8.8.8", "8.8.4.4"} + + if !IntsEquals(parsedIntSlice, expectedIntSlice) { + t.Errorf("%v does not match %v", parsedIntSlice, expectedIntSlice) + } + + if !StrsEquals(parsedStringSlice, expectedStringSlice) { + t.Errorf("%v does not match %v", parsedStringSlice, expectedStringSlice) + } +} + +func TestApp_ParseSliceFlagsWithMissingValue(t *testing.T) { + var parsedIntSlice []int + var parsedStringSlice []string + + app := NewApp() + command := Command{ + Name: "cmd", + Flags: []Flag{ + IntSliceFlag{Name: "a", Usage: "set numbers"}, + StringSliceFlag{Name: "str", Usage: "set strings"}, + }, + Action: func(c *Context) { + parsedIntSlice = c.IntSlice("a") + parsedStringSlice = c.StringSlice("str") + }, + } + app.Commands = []Command{command} + + app.Run([]string{"", "cmd", "my-arg", "-a", "2", "-str", "A"}) + + var expectedIntSlice = []int{2} + var expectedStringSlice = []string{"A"} + + if parsedIntSlice[0] != expectedIntSlice[0] { + t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0]) + } + + if parsedStringSlice[0] != expectedStringSlice[0] { + t.Errorf("%v does not match %v", parsedIntSlice[0], expectedIntSlice[0]) + } +} + +func TestApp_DefaultStdout(t *testing.T) { + app := NewApp() + + if app.Writer != os.Stdout { + t.Error("Default output writer not set.") + } +} + +type mockWriter struct { + written []byte +} + +func (fw *mockWriter) Write(p []byte) (n int, err error) { + if fw.written == nil { + fw.written = p + } else { + fw.written = append(fw.written, p...) + } + + return len(p), nil +} + +func (fw *mockWriter) GetWritten() (b []byte) { + return fw.written +} + +func TestApp_SetStdout(t *testing.T) { + w := &mockWriter{} + + app := NewApp() + app.Name = "test" + app.Writer = w + + err := app.Run([]string{"help"}) + + if err != nil { + t.Fatalf("Run error: %s", err) + } + + if len(w.written) == 0 { + t.Error("App did not write output to desired writer.") + } +} + +func TestApp_BeforeFunc(t *testing.T) { + beforeRun, subcommandRun := false, false + beforeError := fmt.Errorf("fail") + var err error + + app := NewApp() + + app.Before = func(c *Context) error { + beforeRun = true + s := c.String("opt") + if s == "fail" { + return beforeError + } + + return nil + } + + app.Commands = []Command{ + Command{ + Name: "sub", + Action: func(c *Context) { + subcommandRun = true + }, + }, + } + + app.Flags = []Flag{ + StringFlag{Name: "opt"}, + } + + // run with the Before() func succeeding + err = app.Run([]string{"command", "--opt", "succeed", "sub"}) + + if err != nil { + t.Fatalf("Run error: %s", err) + } + + if beforeRun == false { + t.Errorf("Before() not executed when expected") + } + + if subcommandRun == false { + t.Errorf("Subcommand not executed when expected") + } + + // reset + beforeRun, subcommandRun = false, false + + // run with the Before() func failing + err = app.Run([]string{"command", "--opt", "fail", "sub"}) + + // should be the same error produced by the Before func + if err != beforeError { + t.Errorf("Run error expected, but not received") + } + + if beforeRun == false { + t.Errorf("Before() not executed when expected") + } + + if subcommandRun == true { + t.Errorf("Subcommand executed when NOT expected") + } + +} + +func TestApp_AfterFunc(t *testing.T) { + afterRun, subcommandRun := false, false + afterError := fmt.Errorf("fail") + var err error + + app := NewApp() + + app.After = func(c *Context) error { + afterRun = true + s := c.String("opt") + if s == "fail" { + return afterError + } + + return nil + } + + app.Commands = []Command{ + Command{ + Name: "sub", + Action: func(c *Context) { + subcommandRun = true + }, + }, + } + + app.Flags = []Flag{ + StringFlag{Name: "opt"}, + } + + // run with the After() func succeeding + err = app.Run([]string{"command", "--opt", "succeed", "sub"}) + + if err != nil { + t.Fatalf("Run error: %s", err) + } + + if afterRun == false { + t.Errorf("After() not executed when expected") + } + + if subcommandRun == false { + t.Errorf("Subcommand not executed when expected") + } + + // reset + afterRun, subcommandRun = false, false + + // run with the Before() func failing + err = app.Run([]string{"command", "--opt", "fail", "sub"}) + + // should be the same error produced by the Before func + if err != afterError { + t.Errorf("Run error expected, but not received") + } + + if afterRun == false { + t.Errorf("After() not executed when expected") + } + + if subcommandRun == false { + t.Errorf("Subcommand not executed when expected") + } +} + +func TestAppNoHelpFlag(t *testing.T) { + oldFlag := HelpFlag + defer func() { + HelpFlag = oldFlag + }() + + HelpFlag = BoolFlag{} + + app := NewApp() + err := app.Run([]string{"test", "-h"}) + + if err != flag.ErrHelp { + t.Errorf("expected error about missing help flag, but got: %s (%T)", err, err) + } +} + +func TestAppHelpPrinter(t *testing.T) { + oldPrinter := HelpPrinter + defer func() { + HelpPrinter = oldPrinter + }() + + var wasCalled = false + HelpPrinter = func(w io.Writer, template string, data interface{}) { + wasCalled = true + } + + app := NewApp() + app.Run([]string{"-h"}) + + if wasCalled == false { + t.Errorf("Help printer expected to be called, but was not") + } +} + +func TestAppVersionPrinter(t *testing.T) { + oldPrinter := VersionPrinter + defer func() { + VersionPrinter = oldPrinter + }() + + var wasCalled = false + VersionPrinter = func(c *Context) { + wasCalled = true + } + + app := NewApp() + ctx := NewContext(app, nil, nil) + ShowVersion(ctx) + + if wasCalled == false { + t.Errorf("Version printer expected to be called, but was not") + } +} + +func TestAppCommandNotFound(t *testing.T) { + beforeRun, subcommandRun := false, false + app := NewApp() + + app.CommandNotFound = func(c *Context, command string) { + beforeRun = true + } + + app.Commands = []Command{ + Command{ + Name: "bar", + Action: func(c *Context) { + subcommandRun = true + }, + }, + } + + app.Run([]string{"command", "foo"}) + + expect(t, beforeRun, true) + expect(t, subcommandRun, false) +} + +func TestGlobalFlag(t *testing.T) { + var globalFlag string + var globalFlagSet bool + app := NewApp() + app.Flags = []Flag{ + StringFlag{Name: "global, g", Usage: "global"}, + } + app.Action = func(c *Context) { + globalFlag = c.GlobalString("global") + globalFlagSet = c.GlobalIsSet("global") + } + app.Run([]string{"command", "-g", "foo"}) + expect(t, globalFlag, "foo") + expect(t, globalFlagSet, true) + +} + +func TestGlobalFlagsInSubcommands(t *testing.T) { + subcommandRun := false + parentFlag := false + app := NewApp() + + app.Flags = []Flag{ + BoolFlag{Name: "debug, d", Usage: "Enable debugging"}, + } + + app.Commands = []Command{ + Command{ + Name: "foo", + Flags: []Flag{ + BoolFlag{Name: "parent, p", Usage: "Parent flag"}, + }, + Subcommands: []Command{ + { + Name: "bar", + Action: func(c *Context) { + if c.GlobalBool("debug") { + subcommandRun = true + } + if c.GlobalBool("parent") { + parentFlag = true + } + }, + }, + }, + }, + } + + app.Run([]string{"command", "-d", "foo", "-p", "bar"}) + + expect(t, subcommandRun, true) + expect(t, parentFlag, true) +} + +func TestApp_Run_CommandWithSubcommandHasHelpTopic(t *testing.T) { + var subcommandHelpTopics = [][]string{ + {"command", "foo", "--help"}, + {"command", "foo", "-h"}, + {"command", "foo", "help"}, + } + + for _, flagSet := range subcommandHelpTopics { + t.Logf("==> checking with flags %v", flagSet) + + app := NewApp() + buf := new(bytes.Buffer) + app.Writer = buf + + subCmdBar := Command{ + Name: "bar", + Usage: "does bar things", + } + subCmdBaz := Command{ + Name: "baz", + Usage: "does baz things", + } + cmd := Command{ + Name: "foo", + Description: "descriptive wall of text about how it does foo things", + Subcommands: []Command{subCmdBar, subCmdBaz}, + } + + app.Commands = []Command{cmd} + err := app.Run(flagSet) + + if err != nil { + t.Error(err) + } + + output := buf.String() + t.Logf("output: %q\n", buf.Bytes()) + + if strings.Contains(output, "No help topic for") { + t.Errorf("expect a help topic, got none: \n%q", output) + } + + for _, shouldContain := range []string{ + cmd.Name, cmd.Description, + subCmdBar.Name, subCmdBar.Usage, + subCmdBaz.Name, subCmdBaz.Usage, + } { + if !strings.Contains(output, shouldContain) { + t.Errorf("want help to contain %q, did not: \n%q", shouldContain, output) + } + } + } +} + +func TestApp_Run_SubcommandFullPath(t *testing.T) { + app := NewApp() + buf := new(bytes.Buffer) + app.Writer = buf + + subCmd := Command{ + Name: "bar", + Usage: "does bar things", + } + cmd := Command{ + Name: "foo", + Description: "foo commands", + Subcommands: []Command{subCmd}, + } + app.Commands = []Command{cmd} + + err := app.Run([]string{"command", "foo", "bar", "--help"}) + if err != nil { + t.Error(err) + } + + output := buf.String() + if !strings.Contains(output, "foo bar - does bar things") { + t.Errorf("expected full path to subcommand: %s", output) + } + if !strings.Contains(output, "command foo bar [arguments...]") { + t.Errorf("expected full path to subcommand: %s", output) + } +} + +func TestApp_Run_Help(t *testing.T) { + var helpArguments = [][]string{{"boom", "--help"}, {"boom", "-h"}, {"boom", "help"}} + + for _, args := range helpArguments { + buf := new(bytes.Buffer) + + t.Logf("==> checking with arguments %v", args) + + app := NewApp() + app.Name = "boom" + app.Usage = "make an explosive entrance" + app.Writer = buf + app.Action = func(c *Context) { + buf.WriteString("boom I say!") + } + + err := app.Run(args) + if err != nil { + t.Error(err) + } + + output := buf.String() + t.Logf("output: %q\n", buf.Bytes()) + + if !strings.Contains(output, "boom - make an explosive entrance") { + t.Errorf("want help to contain %q, did not: \n%q", "boom - make an explosive entrance", output) + } + } +} + +func TestApp_Run_Version(t *testing.T) { + var versionArguments = [][]string{{"boom", "--version"}, {"boom", "-v"}} + + for _, args := range versionArguments { + buf := new(bytes.Buffer) + + t.Logf("==> checking with arguments %v", args) + + app := NewApp() + app.Name = "boom" + app.Usage = "make an explosive entrance" + app.Version = "0.1.0" + app.Writer = buf + app.Action = func(c *Context) { + buf.WriteString("boom I say!") + } + + err := app.Run(args) + if err != nil { + t.Error(err) + } + + output := buf.String() + t.Logf("output: %q\n", buf.Bytes()) + + if !strings.Contains(output, "0.1.0") { + t.Errorf("want version to contain %q, did not: \n%q", "0.1.0", output) + } + } +} + +func TestApp_Run_DoesNotOverwriteErrorFromBefore(t *testing.T) { + app := NewApp() + app.Action = func(c *Context) {} + app.Before = func(c *Context) error { return fmt.Errorf("before error") } + app.After = func(c *Context) error { return fmt.Errorf("after error") } + + err := app.Run([]string{"foo"}) + if err == nil { + t.Fatalf("expected to recieve error from Run, got none") + } + + if !strings.Contains(err.Error(), "before error") { + t.Errorf("expected text of error from Before method, but got none in \"%v\"", err) + } + if !strings.Contains(err.Error(), "after error") { + t.Errorf("expected text of error from After method, but got none in \"%v\"", err) + } +} + +func TestApp_Run_SubcommandDoesNotOverwriteErrorFromBefore(t *testing.T) { + app := NewApp() + app.Commands = []Command{ + Command{ + Name: "bar", + Before: func(c *Context) error { return fmt.Errorf("before error") }, + After: func(c *Context) error { return fmt.Errorf("after error") }, + }, + } + + err := app.Run([]string{"foo", "bar"}) + if err == nil { + t.Fatalf("expected to recieve error from Run, got none") + } + + if !strings.Contains(err.Error(), "before error") { + t.Errorf("expected text of error from Before method, but got none in \"%v\"", err) + } + if !strings.Contains(err.Error(), "after error") { + t.Errorf("expected text of error from After method, but got none in \"%v\"", err) + } +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete b/Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete new file mode 100644 index 0000000..9b55dd9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/bash_autocomplete @@ -0,0 +1,13 @@ +#! /bin/bash + +_cli_bash_autocomplete() { + local cur prev opts base + COMPREPLY=() + cur="${COMP_WORDS[COMP_CWORD]}" + prev="${COMP_WORDS[COMP_CWORD-1]}" + opts=$( ${COMP_WORDS[@]:0:$COMP_CWORD} --generate-bash-completion ) + COMPREPLY=( $(compgen -W "${opts}" -- ${cur}) ) + return 0 + } + + complete -F _cli_bash_autocomplete $PROG \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/zsh_autocomplete b/Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/zsh_autocomplete new file mode 100644 index 0000000..5430a18 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/autocomplete/zsh_autocomplete @@ -0,0 +1,5 @@ +autoload -U compinit && compinit +autoload -U bashcompinit && bashcompinit + +script_dir=$(dirname $0) +source ${script_dir}/bash_autocomplete diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go b/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go new file mode 100644 index 0000000..31dc912 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/cli.go @@ -0,0 +1,40 @@ +// Package cli provides a minimal framework for creating and organizing command line +// Go applications. cli is designed to be easy to understand and write, the most simple +// cli application can be written as follows: +// func main() { +// cli.NewApp().Run(os.Args) +// } +// +// Of course this application does not do much, so let's make this an actual application: +// func main() { +// app := cli.NewApp() +// app.Name = "greet" +// app.Usage = "say a greeting" +// app.Action = func(c *cli.Context) { +// println("Greetings") +// } +// +// app.Run(os.Args) +// } +package cli + +import ( + "strings" +) + +type MultiError struct { + Errors []error +} + +func NewMultiError(err ...error) MultiError { + return MultiError{Errors: err} +} + +func (m MultiError) Error() string { + errs := make([]string, len(m.Errors)) + for i, err := range m.Errors { + errs[i] = err.Error() + } + + return strings.Join(errs, "\n") +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/cli_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/cli_test.go new file mode 100644 index 0000000..e54f8e2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/cli_test.go @@ -0,0 +1,98 @@ +package cli + +import ( + "os" +) + +func Example() { + app := NewApp() + app.Name = "todo" + app.Usage = "task list on the command line" + app.Commands = []Command{ + { + Name: "add", + Aliases: []string{"a"}, + Usage: "add a task to the list", + Action: func(c *Context) { + println("added task: ", c.Args().First()) + }, + }, + { + Name: "complete", + Aliases: []string{"c"}, + Usage: "complete a task on the list", + Action: func(c *Context) { + println("completed task: ", c.Args().First()) + }, + }, + } + + app.Run(os.Args) +} + +func ExampleSubcommand() { + app := NewApp() + app.Name = "say" + app.Commands = []Command{ + { + Name: "hello", + Aliases: []string{"hi"}, + Usage: "use it to see a description", + Description: "This is how we describe hello the function", + Subcommands: []Command{ + { + Name: "english", + Aliases: []string{"en"}, + Usage: "sends a greeting in english", + Description: "greets someone in english", + Flags: []Flag{ + StringFlag{ + Name: "name", + Value: "Bob", + Usage: "Name of the person to greet", + }, + }, + Action: func(c *Context) { + println("Hello, ", c.String("name")) + }, + }, { + Name: "spanish", + Aliases: []string{"sp"}, + Usage: "sends a greeting in spanish", + Flags: []Flag{ + StringFlag{ + Name: "surname", + Value: "Jones", + Usage: "Surname of the person to greet", + }, + }, + Action: func(c *Context) { + println("Hola, ", c.String("surname")) + }, + }, { + Name: "french", + Aliases: []string{"fr"}, + Usage: "sends a greeting in french", + Flags: []Flag{ + StringFlag{ + Name: "nickname", + Value: "Stevie", + Usage: "Nickname of the person to greet", + }, + }, + Action: func(c *Context) { + println("Bonjour, ", c.String("nickname")) + }, + }, + }, + }, { + Name: "bye", + Usage: "says goodbye", + Action: func(c *Context) { + println("bye") + }, + }, + } + + app.Run(os.Args) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/command.go b/Godeps/_workspace/src/github.com/codegangsta/cli/command.go new file mode 100644 index 0000000..54617af --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/command.go @@ -0,0 +1,200 @@ +package cli + +import ( + "fmt" + "io/ioutil" + "strings" +) + +// Command is a subcommand for a cli.App. +type Command struct { + // The name of the command + Name string + // short name of the command. Typically one character (deprecated, use `Aliases`) + ShortName string + // A list of aliases for the command + Aliases []string + // A short description of the usage of this command + Usage string + // A longer explanation of how the command works + Description string + // The function to call when checking for bash command completions + BashComplete func(context *Context) + // An action to execute before any sub-subcommands are run, but after the context is ready + // If a non-nil error is returned, no sub-subcommands are run + Before func(context *Context) error + // An action to execute after any subcommands are run, but after the subcommand has finished + // It is run even if Action() panics + After func(context *Context) error + // The function to call when this command is invoked + Action func(context *Context) + // List of child commands + Subcommands []Command + // List of flags to parse + Flags []Flag + // Treat all flags as normal arguments if true + SkipFlagParsing bool + // Boolean to hide built-in help command + HideHelp bool + + commandNamePath []string +} + +// Returns the full name of the command. +// For subcommands this ensures that parent commands are part of the command path +func (c Command) FullName() string { + if c.commandNamePath == nil { + return c.Name + } + return strings.Join(c.commandNamePath, " ") +} + +// Invokes the command given the context, parses ctx.Args() to generate command-specific flags +func (c Command) Run(ctx *Context) error { + if len(c.Subcommands) > 0 || c.Before != nil || c.After != nil { + return c.startApp(ctx) + } + + if !c.HideHelp && (HelpFlag != BoolFlag{}) { + // append help to flags + c.Flags = append( + c.Flags, + HelpFlag, + ) + } + + if ctx.App.EnableBashCompletion { + c.Flags = append(c.Flags, BashCompletionFlag) + } + + set := flagSet(c.Name, c.Flags) + set.SetOutput(ioutil.Discard) + + firstFlagIndex := -1 + terminatorIndex := -1 + for index, arg := range ctx.Args() { + if arg == "--" { + terminatorIndex = index + break + } else if strings.HasPrefix(arg, "-") && firstFlagIndex == -1 { + firstFlagIndex = index + } + } + + var err error + if firstFlagIndex > -1 && !c.SkipFlagParsing { + args := ctx.Args() + regularArgs := make([]string, len(args[1:firstFlagIndex])) + copy(regularArgs, args[1:firstFlagIndex]) + + var flagArgs []string + if terminatorIndex > -1 { + flagArgs = args[firstFlagIndex:terminatorIndex] + regularArgs = append(regularArgs, args[terminatorIndex:]...) + } else { + flagArgs = args[firstFlagIndex:] + } + + err = set.Parse(append(flagArgs, regularArgs...)) + } else { + err = set.Parse(ctx.Args().Tail()) + } + + if err != nil { + fmt.Fprintln(ctx.App.Writer, "Incorrect Usage.") + fmt.Fprintln(ctx.App.Writer) + ShowCommandHelp(ctx, c.Name) + return err + } + + nerr := normalizeFlags(c.Flags, set) + if nerr != nil { + fmt.Fprintln(ctx.App.Writer, nerr) + fmt.Fprintln(ctx.App.Writer) + ShowCommandHelp(ctx, c.Name) + return nerr + } + context := NewContext(ctx.App, set, ctx) + + if checkCommandCompletions(context, c.Name) { + return nil + } + + if checkCommandHelp(context, c.Name) { + return nil + } + context.Command = c + c.Action(context) + return nil +} + +func (c Command) Names() []string { + names := []string{c.Name} + + if c.ShortName != "" { + names = append(names, c.ShortName) + } + + return append(names, c.Aliases...) +} + +// Returns true if Command.Name or Command.ShortName matches given name +func (c Command) HasName(name string) bool { + for _, n := range c.Names() { + if n == name { + return true + } + } + return false +} + +func (c Command) startApp(ctx *Context) error { + app := NewApp() + + // set the name and usage + app.Name = fmt.Sprintf("%s %s", ctx.App.Name, c.Name) + if c.Description != "" { + app.Usage = c.Description + } else { + app.Usage = c.Usage + } + + // set CommandNotFound + app.CommandNotFound = ctx.App.CommandNotFound + + // set the flags and commands + app.Commands = c.Subcommands + app.Flags = c.Flags + app.HideHelp = c.HideHelp + + app.Version = ctx.App.Version + app.HideVersion = ctx.App.HideVersion + app.Compiled = ctx.App.Compiled + app.Author = ctx.App.Author + app.Email = ctx.App.Email + app.Writer = ctx.App.Writer + + // bash completion + app.EnableBashCompletion = ctx.App.EnableBashCompletion + if c.BashComplete != nil { + app.BashComplete = c.BashComplete + } + + // set the actions + app.Before = c.Before + app.After = c.After + if c.Action != nil { + app.Action = c.Action + } else { + app.Action = helpSubcommand.Action + } + + var newCmds []Command + for _, cc := range app.Commands { + cc.commandNamePath = []string{c.Name, cc.Name} + newCmds = append(newCmds, cc) + } + app.Commands = newCmds + + return app.RunAsSubcommand(ctx) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/command_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/command_test.go new file mode 100644 index 0000000..688d12c --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/command_test.go @@ -0,0 +1,47 @@ +package cli + +import ( + "flag" + "testing" +) + +func TestCommandDoNotIgnoreFlags(t *testing.T) { + app := NewApp() + set := flag.NewFlagSet("test", 0) + test := []string{"blah", "blah", "-break"} + set.Parse(test) + + c := NewContext(app, set, nil) + + command := Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(_ *Context) {}, + } + err := command.Run(c) + + expect(t, err.Error(), "flag provided but not defined: -break") +} + +func TestCommandIgnoreFlags(t *testing.T) { + app := NewApp() + set := flag.NewFlagSet("test", 0) + test := []string{"blah", "blah"} + set.Parse(test) + + c := NewContext(app, set, nil) + + command := Command{ + Name: "test-cmd", + Aliases: []string{"tc"}, + Usage: "this is for testing", + Description: "testing", + Action: func(_ *Context) {}, + SkipFlagParsing: true, + } + err := command.Run(c) + + expect(t, err, nil) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/context.go b/Godeps/_workspace/src/github.com/codegangsta/cli/context.go new file mode 100644 index 0000000..f541f41 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/context.go @@ -0,0 +1,388 @@ +package cli + +import ( + "errors" + "flag" + "strconv" + "strings" + "time" +) + +// Context is a type that is passed through to +// each Handler action in a cli application. Context +// can be used to retrieve context-specific Args and +// parsed command-line options. +type Context struct { + App *App + Command Command + flagSet *flag.FlagSet + setFlags map[string]bool + globalSetFlags map[string]bool + parentContext *Context +} + +// Creates a new context. For use in when invoking an App or Command action. +func NewContext(app *App, set *flag.FlagSet, parentCtx *Context) *Context { + return &Context{App: app, flagSet: set, parentContext: parentCtx} +} + +// Looks up the value of a local int flag, returns 0 if no int flag exists +func (c *Context) Int(name string) int { + return lookupInt(name, c.flagSet) +} + +// Looks up the value of a local time.Duration flag, returns 0 if no time.Duration flag exists +func (c *Context) Duration(name string) time.Duration { + return lookupDuration(name, c.flagSet) +} + +// Looks up the value of a local float64 flag, returns 0 if no float64 flag exists +func (c *Context) Float64(name string) float64 { + return lookupFloat64(name, c.flagSet) +} + +// Looks up the value of a local bool flag, returns false if no bool flag exists +func (c *Context) Bool(name string) bool { + return lookupBool(name, c.flagSet) +} + +// Looks up the value of a local boolT flag, returns false if no bool flag exists +func (c *Context) BoolT(name string) bool { + return lookupBoolT(name, c.flagSet) +} + +// Looks up the value of a local string flag, returns "" if no string flag exists +func (c *Context) String(name string) string { + return lookupString(name, c.flagSet) +} + +// Looks up the value of a local string slice flag, returns nil if no string slice flag exists +func (c *Context) StringSlice(name string) []string { + return lookupStringSlice(name, c.flagSet) +} + +// Looks up the value of a local int slice flag, returns nil if no int slice flag exists +func (c *Context) IntSlice(name string) []int { + return lookupIntSlice(name, c.flagSet) +} + +// Looks up the value of a local generic flag, returns nil if no generic flag exists +func (c *Context) Generic(name string) interface{} { + return lookupGeneric(name, c.flagSet) +} + +// Looks up the value of a global int flag, returns 0 if no int flag exists +func (c *Context) GlobalInt(name string) int { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupInt(name, fs) + } + return 0 +} + +// Looks up the value of a global time.Duration flag, returns 0 if no time.Duration flag exists +func (c *Context) GlobalDuration(name string) time.Duration { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupDuration(name, fs) + } + return 0 +} + +// Looks up the value of a global bool flag, returns false if no bool flag exists +func (c *Context) GlobalBool(name string) bool { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupBool(name, fs) + } + return false +} + +// Looks up the value of a global string flag, returns "" if no string flag exists +func (c *Context) GlobalString(name string) string { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupString(name, fs) + } + return "" +} + +// Looks up the value of a global string slice flag, returns nil if no string slice flag exists +func (c *Context) GlobalStringSlice(name string) []string { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupStringSlice(name, fs) + } + return nil +} + +// Looks up the value of a global int slice flag, returns nil if no int slice flag exists +func (c *Context) GlobalIntSlice(name string) []int { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupIntSlice(name, fs) + } + return nil +} + +// Looks up the value of a global generic flag, returns nil if no generic flag exists +func (c *Context) GlobalGeneric(name string) interface{} { + if fs := lookupGlobalFlagSet(name, c); fs != nil { + return lookupGeneric(name, fs) + } + return nil +} + +// Returns the number of flags set +func (c *Context) NumFlags() int { + return c.flagSet.NFlag() +} + +// Determines if the flag was actually set +func (c *Context) IsSet(name string) bool { + if c.setFlags == nil { + c.setFlags = make(map[string]bool) + c.flagSet.Visit(func(f *flag.Flag) { + c.setFlags[f.Name] = true + }) + } + return c.setFlags[name] == true +} + +// Determines if the global flag was actually set +func (c *Context) GlobalIsSet(name string) bool { + if c.globalSetFlags == nil { + c.globalSetFlags = make(map[string]bool) + ctx := c + if ctx.parentContext != nil { + ctx = ctx.parentContext + } + for ; ctx != nil && c.globalSetFlags[name] == false; ctx = ctx.parentContext { + ctx.flagSet.Visit(func(f *flag.Flag) { + c.globalSetFlags[f.Name] = true + }) + } + } + return c.globalSetFlags[name] +} + +// Returns a slice of flag names used in this context. +func (c *Context) FlagNames() (names []string) { + for _, flag := range c.Command.Flags { + name := strings.Split(flag.getName(), ",")[0] + if name == "help" { + continue + } + names = append(names, name) + } + return +} + +// Returns a slice of global flag names used by the app. +func (c *Context) GlobalFlagNames() (names []string) { + for _, flag := range c.App.Flags { + name := strings.Split(flag.getName(), ",")[0] + if name == "help" || name == "version" { + continue + } + names = append(names, name) + } + return +} + +// Returns the parent context, if any +func (c *Context) Parent() *Context { + return c.parentContext +} + +type Args []string + +// Returns the command line arguments associated with the context. +func (c *Context) Args() Args { + args := Args(c.flagSet.Args()) + return args +} + +// Returns the nth argument, or else a blank string +func (a Args) Get(n int) string { + if len(a) > n { + return a[n] + } + return "" +} + +// Returns the first argument, or else a blank string +func (a Args) First() string { + return a.Get(0) +} + +// Return the rest of the arguments (not the first one) +// or else an empty string slice +func (a Args) Tail() []string { + if len(a) >= 2 { + return []string(a)[1:] + } + return []string{} +} + +// Checks if there are any arguments present +func (a Args) Present() bool { + return len(a) != 0 +} + +// Swaps arguments at the given indexes +func (a Args) Swap(from, to int) error { + if from >= len(a) || to >= len(a) { + return errors.New("index out of range") + } + a[from], a[to] = a[to], a[from] + return nil +} + +func lookupGlobalFlagSet(name string, ctx *Context) *flag.FlagSet { + if ctx.parentContext != nil { + ctx = ctx.parentContext + } + for ; ctx != nil; ctx = ctx.parentContext { + if f := ctx.flagSet.Lookup(name); f != nil { + return ctx.flagSet + } + } + return nil +} + +func lookupInt(name string, set *flag.FlagSet) int { + f := set.Lookup(name) + if f != nil { + val, err := strconv.Atoi(f.Value.String()) + if err != nil { + return 0 + } + return val + } + + return 0 +} + +func lookupDuration(name string, set *flag.FlagSet) time.Duration { + f := set.Lookup(name) + if f != nil { + val, err := time.ParseDuration(f.Value.String()) + if err == nil { + return val + } + } + + return 0 +} + +func lookupFloat64(name string, set *flag.FlagSet) float64 { + f := set.Lookup(name) + if f != nil { + val, err := strconv.ParseFloat(f.Value.String(), 64) + if err != nil { + return 0 + } + return val + } + + return 0 +} + +func lookupString(name string, set *flag.FlagSet) string { + f := set.Lookup(name) + if f != nil { + return f.Value.String() + } + + return "" +} + +func lookupStringSlice(name string, set *flag.FlagSet) []string { + f := set.Lookup(name) + if f != nil { + return (f.Value.(*StringSlice)).Value() + + } + + return nil +} + +func lookupIntSlice(name string, set *flag.FlagSet) []int { + f := set.Lookup(name) + if f != nil { + return (f.Value.(*IntSlice)).Value() + + } + + return nil +} + +func lookupGeneric(name string, set *flag.FlagSet) interface{} { + f := set.Lookup(name) + if f != nil { + return f.Value + } + return nil +} + +func lookupBool(name string, set *flag.FlagSet) bool { + f := set.Lookup(name) + if f != nil { + val, err := strconv.ParseBool(f.Value.String()) + if err != nil { + return false + } + return val + } + + return false +} + +func lookupBoolT(name string, set *flag.FlagSet) bool { + f := set.Lookup(name) + if f != nil { + val, err := strconv.ParseBool(f.Value.String()) + if err != nil { + return true + } + return val + } + + return false +} + +func copyFlag(name string, ff *flag.Flag, set *flag.FlagSet) { + switch ff.Value.(type) { + case *StringSlice: + default: + set.Set(name, ff.Value.String()) + } +} + +func normalizeFlags(flags []Flag, set *flag.FlagSet) error { + visited := make(map[string]bool) + set.Visit(func(f *flag.Flag) { + visited[f.Name] = true + }) + for _, f := range flags { + parts := strings.Split(f.getName(), ",") + if len(parts) == 1 { + continue + } + var ff *flag.Flag + for _, name := range parts { + name = strings.Trim(name, " ") + if visited[name] { + if ff != nil { + return errors.New("Cannot use two forms of the same flag: " + name + " " + ff.Name) + } + ff = set.Lookup(name) + } + } + if ff == nil { + continue + } + for _, name := range parts { + name = strings.Trim(name, " ") + if !visited[name] { + copyFlag(name, ff, set) + } + } + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/context_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/context_test.go new file mode 100644 index 0000000..7f8e928 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/context_test.go @@ -0,0 +1,113 @@ +package cli + +import ( + "flag" + "testing" + "time" +) + +func TestNewContext(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Int("myflag", 12, "doc") + globalSet := flag.NewFlagSet("test", 0) + globalSet.Int("myflag", 42, "doc") + globalCtx := NewContext(nil, globalSet, nil) + command := Command{Name: "mycommand"} + c := NewContext(nil, set, globalCtx) + c.Command = command + expect(t, c.Int("myflag"), 12) + expect(t, c.GlobalInt("myflag"), 42) + expect(t, c.Command.Name, "mycommand") +} + +func TestContext_Int(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Int("myflag", 12, "doc") + c := NewContext(nil, set, nil) + expect(t, c.Int("myflag"), 12) +} + +func TestContext_Duration(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Duration("myflag", time.Duration(12*time.Second), "doc") + c := NewContext(nil, set, nil) + expect(t, c.Duration("myflag"), time.Duration(12*time.Second)) +} + +func TestContext_String(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.String("myflag", "hello world", "doc") + c := NewContext(nil, set, nil) + expect(t, c.String("myflag"), "hello world") +} + +func TestContext_Bool(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("myflag", false, "doc") + c := NewContext(nil, set, nil) + expect(t, c.Bool("myflag"), false) +} + +func TestContext_BoolT(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("myflag", true, "doc") + c := NewContext(nil, set, nil) + expect(t, c.BoolT("myflag"), true) +} + +func TestContext_Args(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("myflag", false, "doc") + c := NewContext(nil, set, nil) + set.Parse([]string{"--myflag", "bat", "baz"}) + expect(t, len(c.Args()), 2) + expect(t, c.Bool("myflag"), true) +} + +func TestContext_IsSet(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("myflag", false, "doc") + set.String("otherflag", "hello world", "doc") + globalSet := flag.NewFlagSet("test", 0) + globalSet.Bool("myflagGlobal", true, "doc") + globalCtx := NewContext(nil, globalSet, nil) + c := NewContext(nil, set, globalCtx) + set.Parse([]string{"--myflag", "bat", "baz"}) + globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) + expect(t, c.IsSet("myflag"), true) + expect(t, c.IsSet("otherflag"), false) + expect(t, c.IsSet("bogusflag"), false) + expect(t, c.IsSet("myflagGlobal"), false) +} + +func TestContext_GlobalIsSet(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("myflag", false, "doc") + set.String("otherflag", "hello world", "doc") + globalSet := flag.NewFlagSet("test", 0) + globalSet.Bool("myflagGlobal", true, "doc") + globalSet.Bool("myflagGlobalUnset", true, "doc") + globalCtx := NewContext(nil, globalSet, nil) + c := NewContext(nil, set, globalCtx) + set.Parse([]string{"--myflag", "bat", "baz"}) + globalSet.Parse([]string{"--myflagGlobal", "bat", "baz"}) + expect(t, c.GlobalIsSet("myflag"), false) + expect(t, c.GlobalIsSet("otherflag"), false) + expect(t, c.GlobalIsSet("bogusflag"), false) + expect(t, c.GlobalIsSet("myflagGlobal"), true) + expect(t, c.GlobalIsSet("myflagGlobalUnset"), false) + expect(t, c.GlobalIsSet("bogusGlobal"), false) +} + +func TestContext_NumFlags(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.Bool("myflag", false, "doc") + set.String("otherflag", "hello world", "doc") + globalSet := flag.NewFlagSet("test", 0) + globalSet.Bool("myflagGlobal", true, "doc") + globalCtx := NewContext(nil, globalSet, nil) + c := NewContext(nil, set, globalCtx) + set.Parse([]string{"--myflag", "--otherflag=foo"}) + globalSet.Parse([]string{"--myflagGlobal"}) + expect(t, c.NumFlags(), 2) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go b/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go new file mode 100644 index 0000000..531b091 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/flag.go @@ -0,0 +1,497 @@ +package cli + +import ( + "flag" + "fmt" + "os" + "strconv" + "strings" + "time" +) + +// This flag enables bash-completion for all commands and subcommands +var BashCompletionFlag = BoolFlag{ + Name: "generate-bash-completion", +} + +// This flag prints the version for the application +var VersionFlag = BoolFlag{ + Name: "version, v", + Usage: "print the version", +} + +// This flag prints the help for all commands and subcommands +// Set to the zero value (BoolFlag{}) to disable flag -- keeps subcommand +// unless HideHelp is set to true) +var HelpFlag = BoolFlag{ + Name: "help, h", + Usage: "show help", +} + +// Flag is a common interface related to parsing flags in cli. +// For more advanced flag parsing techniques, it is recomended that +// this interface be implemented. +type Flag interface { + fmt.Stringer + // Apply Flag settings to the given flag set + Apply(*flag.FlagSet) + getName() string +} + +func flagSet(name string, flags []Flag) *flag.FlagSet { + set := flag.NewFlagSet(name, flag.ContinueOnError) + + for _, f := range flags { + f.Apply(set) + } + return set +} + +func eachName(longName string, fn func(string)) { + parts := strings.Split(longName, ",") + for _, name := range parts { + name = strings.Trim(name, " ") + fn(name) + } +} + +// Generic is a generic parseable type identified by a specific flag +type Generic interface { + Set(value string) error + String() string +} + +// GenericFlag is the flag type for types implementing Generic +type GenericFlag struct { + Name string + Value Generic + Usage string + EnvVar string +} + +// String returns the string representation of the generic flag to display the +// help text to the user (uses the String() method of the generic flag to show +// the value) +func (f GenericFlag) String() string { + return withEnvHint(f.EnvVar, fmt.Sprintf("%s%s \"%v\"\t%v", prefixFor(f.Name), f.Name, f.Value, f.Usage)) +} + +// Apply takes the flagset and calls Set on the generic flag with the value +// provided by the user for parsing by the flag +func (f GenericFlag) Apply(set *flag.FlagSet) { + val := f.Value + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + val.Set(envVal) + break + } + } + } + + eachName(f.Name, func(name string) { + set.Var(f.Value, name, f.Usage) + }) +} + +func (f GenericFlag) getName() string { + return f.Name +} + +// StringSlice is an opaque type for []string to satisfy flag.Value +type StringSlice []string + +// Set appends the string value to the list of values +func (f *StringSlice) Set(value string) error { + *f = append(*f, value) + return nil +} + +// String returns a readable representation of this value (for usage defaults) +func (f *StringSlice) String() string { + return fmt.Sprintf("%s", *f) +} + +// Value returns the slice of strings set by this flag +func (f *StringSlice) Value() []string { + return *f +} + +// StringSlice is a string flag that can be specified multiple times on the +// command-line +type StringSliceFlag struct { + Name string + Value *StringSlice + Usage string + EnvVar string +} + +// String returns the usage +func (f StringSliceFlag) String() string { + firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") + pref := prefixFor(firstName) + return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f StringSliceFlag) Apply(set *flag.FlagSet) { + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + newVal := &StringSlice{} + for _, s := range strings.Split(envVal, ",") { + s = strings.TrimSpace(s) + newVal.Set(s) + } + f.Value = newVal + break + } + } + } + + eachName(f.Name, func(name string) { + if f.Value == nil { + f.Value = &StringSlice{} + } + set.Var(f.Value, name, f.Usage) + }) +} + +func (f StringSliceFlag) getName() string { + return f.Name +} + +// StringSlice is an opaque type for []int to satisfy flag.Value +type IntSlice []int + +// Set parses the value into an integer and appends it to the list of values +func (f *IntSlice) Set(value string) error { + tmp, err := strconv.Atoi(value) + if err != nil { + return err + } else { + *f = append(*f, tmp) + } + return nil +} + +// String returns a readable representation of this value (for usage defaults) +func (f *IntSlice) String() string { + return fmt.Sprintf("%d", *f) +} + +// Value returns the slice of ints set by this flag +func (f *IntSlice) Value() []int { + return *f +} + +// IntSliceFlag is an int flag that can be specified multiple times on the +// command-line +type IntSliceFlag struct { + Name string + Value *IntSlice + Usage string + EnvVar string +} + +// String returns the usage +func (f IntSliceFlag) String() string { + firstName := strings.Trim(strings.Split(f.Name, ",")[0], " ") + pref := prefixFor(firstName) + return withEnvHint(f.EnvVar, fmt.Sprintf("%s [%v]\t%v", prefixedNames(f.Name), pref+firstName+" option "+pref+firstName+" option", f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f IntSliceFlag) Apply(set *flag.FlagSet) { + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + newVal := &IntSlice{} + for _, s := range strings.Split(envVal, ",") { + s = strings.TrimSpace(s) + err := newVal.Set(s) + if err != nil { + fmt.Fprintf(os.Stderr, err.Error()) + } + } + f.Value = newVal + break + } + } + } + + eachName(f.Name, func(name string) { + if f.Value == nil { + f.Value = &IntSlice{} + } + set.Var(f.Value, name, f.Usage) + }) +} + +func (f IntSliceFlag) getName() string { + return f.Name +} + +// BoolFlag is a switch that defaults to false +type BoolFlag struct { + Name string + Usage string + EnvVar string +} + +// String returns a readable representation of this value (for usage defaults) +func (f BoolFlag) String() string { + return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f BoolFlag) Apply(set *flag.FlagSet) { + val := false + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + envValBool, err := strconv.ParseBool(envVal) + if err == nil { + val = envValBool + } + break + } + } + } + + eachName(f.Name, func(name string) { + set.Bool(name, val, f.Usage) + }) +} + +func (f BoolFlag) getName() string { + return f.Name +} + +// BoolTFlag this represents a boolean flag that is true by default, but can +// still be set to false by --some-flag=false +type BoolTFlag struct { + Name string + Usage string + EnvVar string +} + +// String returns a readable representation of this value (for usage defaults) +func (f BoolTFlag) String() string { + return withEnvHint(f.EnvVar, fmt.Sprintf("%s\t%v", prefixedNames(f.Name), f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f BoolTFlag) Apply(set *flag.FlagSet) { + val := true + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + envValBool, err := strconv.ParseBool(envVal) + if err == nil { + val = envValBool + break + } + } + } + } + + eachName(f.Name, func(name string) { + set.Bool(name, val, f.Usage) + }) +} + +func (f BoolTFlag) getName() string { + return f.Name +} + +// StringFlag represents a flag that takes as string value +type StringFlag struct { + Name string + Value string + Usage string + EnvVar string +} + +// String returns the usage +func (f StringFlag) String() string { + var fmtString string + fmtString = "%s %v\t%v" + + if len(f.Value) > 0 { + fmtString = "%s \"%v\"\t%v" + } else { + fmtString = "%s %v\t%v" + } + + return withEnvHint(f.EnvVar, fmt.Sprintf(fmtString, prefixedNames(f.Name), f.Value, f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f StringFlag) Apply(set *flag.FlagSet) { + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + f.Value = envVal + break + } + } + } + + eachName(f.Name, func(name string) { + set.String(name, f.Value, f.Usage) + }) +} + +func (f StringFlag) getName() string { + return f.Name +} + +// IntFlag is a flag that takes an integer +// Errors if the value provided cannot be parsed +type IntFlag struct { + Name string + Value int + Usage string + EnvVar string +} + +// String returns the usage +func (f IntFlag) String() string { + return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f IntFlag) Apply(set *flag.FlagSet) { + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + envValInt, err := strconv.ParseInt(envVal, 0, 64) + if err == nil { + f.Value = int(envValInt) + break + } + } + } + } + + eachName(f.Name, func(name string) { + set.Int(name, f.Value, f.Usage) + }) +} + +func (f IntFlag) getName() string { + return f.Name +} + +// DurationFlag is a flag that takes a duration specified in Go's duration +// format: https://golang.org/pkg/time/#ParseDuration +type DurationFlag struct { + Name string + Value time.Duration + Usage string + EnvVar string +} + +// String returns a readable representation of this value (for usage defaults) +func (f DurationFlag) String() string { + return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f DurationFlag) Apply(set *flag.FlagSet) { + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + envValDuration, err := time.ParseDuration(envVal) + if err == nil { + f.Value = envValDuration + break + } + } + } + } + + eachName(f.Name, func(name string) { + set.Duration(name, f.Value, f.Usage) + }) +} + +func (f DurationFlag) getName() string { + return f.Name +} + +// Float64Flag is a flag that takes an float value +// Errors if the value provided cannot be parsed +type Float64Flag struct { + Name string + Value float64 + Usage string + EnvVar string +} + +// String returns the usage +func (f Float64Flag) String() string { + return withEnvHint(f.EnvVar, fmt.Sprintf("%s \"%v\"\t%v", prefixedNames(f.Name), f.Value, f.Usage)) +} + +// Apply populates the flag given the flag set and environment +func (f Float64Flag) Apply(set *flag.FlagSet) { + if f.EnvVar != "" { + for _, envVar := range strings.Split(f.EnvVar, ",") { + envVar = strings.TrimSpace(envVar) + if envVal := os.Getenv(envVar); envVal != "" { + envValFloat, err := strconv.ParseFloat(envVal, 10) + if err == nil { + f.Value = float64(envValFloat) + } + } + } + } + + eachName(f.Name, func(name string) { + set.Float64(name, f.Value, f.Usage) + }) +} + +func (f Float64Flag) getName() string { + return f.Name +} + +func prefixFor(name string) (prefix string) { + if len(name) == 1 { + prefix = "-" + } else { + prefix = "--" + } + + return +} + +func prefixedNames(fullName string) (prefixed string) { + parts := strings.Split(fullName, ",") + for i, name := range parts { + name = strings.Trim(name, " ") + prefixed += prefixFor(name) + name + if i < len(parts)-1 { + prefixed += ", " + } + } + return +} + +func withEnvHint(envVar, str string) string { + envText := "" + if envVar != "" { + envText = fmt.Sprintf(" [$%s]", strings.Join(strings.Split(envVar, ","), ", $")) + } + return str + envText +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/flag_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/flag_test.go new file mode 100644 index 0000000..3606102 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/flag_test.go @@ -0,0 +1,740 @@ +package cli + +import ( + "fmt" + "os" + "reflect" + "strings" + "testing" +) + +var boolFlagTests = []struct { + name string + expected string +}{ + {"help", "--help\t"}, + {"h", "-h\t"}, +} + +func TestBoolFlagHelpOutput(t *testing.T) { + + for _, test := range boolFlagTests { + flag := BoolFlag{Name: test.name} + output := flag.String() + + if output != test.expected { + t.Errorf("%s does not match %s", output, test.expected) + } + } +} + +var stringFlagTests = []struct { + name string + value string + expected string +}{ + {"help", "", "--help \t"}, + {"h", "", "-h \t"}, + {"h", "", "-h \t"}, + {"test", "Something", "--test \"Something\"\t"}, +} + +func TestStringFlagHelpOutput(t *testing.T) { + + for _, test := range stringFlagTests { + flag := StringFlag{Name: test.name, Value: test.value} + output := flag.String() + + if output != test.expected { + t.Errorf("%s does not match %s", output, test.expected) + } + } +} + +func TestStringFlagWithEnvVarHelpOutput(t *testing.T) { + os.Clearenv() + os.Setenv("APP_FOO", "derp") + for _, test := range stringFlagTests { + flag := StringFlag{Name: test.name, Value: test.value, EnvVar: "APP_FOO"} + output := flag.String() + + if !strings.HasSuffix(output, " [$APP_FOO]") { + t.Errorf("%s does not end with [$APP_FOO]", output) + } + } +} + +var stringSliceFlagTests = []struct { + name string + value *StringSlice + expected string +}{ + {"help", func() *StringSlice { + s := &StringSlice{} + s.Set("") + return s + }(), "--help [--help option --help option]\t"}, + {"h", func() *StringSlice { + s := &StringSlice{} + s.Set("") + return s + }(), "-h [-h option -h option]\t"}, + {"h", func() *StringSlice { + s := &StringSlice{} + s.Set("") + return s + }(), "-h [-h option -h option]\t"}, + {"test", func() *StringSlice { + s := &StringSlice{} + s.Set("Something") + return s + }(), "--test [--test option --test option]\t"}, +} + +func TestStringSliceFlagHelpOutput(t *testing.T) { + + for _, test := range stringSliceFlagTests { + flag := StringSliceFlag{Name: test.name, Value: test.value} + output := flag.String() + + if output != test.expected { + t.Errorf("%q does not match %q", output, test.expected) + } + } +} + +func TestStringSliceFlagWithEnvVarHelpOutput(t *testing.T) { + os.Clearenv() + os.Setenv("APP_QWWX", "11,4") + for _, test := range stringSliceFlagTests { + flag := StringSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_QWWX"} + output := flag.String() + + if !strings.HasSuffix(output, " [$APP_QWWX]") { + t.Errorf("%q does not end with [$APP_QWWX]", output) + } + } +} + +var intFlagTests = []struct { + name string + expected string +}{ + {"help", "--help \"0\"\t"}, + {"h", "-h \"0\"\t"}, +} + +func TestIntFlagHelpOutput(t *testing.T) { + + for _, test := range intFlagTests { + flag := IntFlag{Name: test.name} + output := flag.String() + + if output != test.expected { + t.Errorf("%s does not match %s", output, test.expected) + } + } +} + +func TestIntFlagWithEnvVarHelpOutput(t *testing.T) { + os.Clearenv() + os.Setenv("APP_BAR", "2") + for _, test := range intFlagTests { + flag := IntFlag{Name: test.name, EnvVar: "APP_BAR"} + output := flag.String() + + if !strings.HasSuffix(output, " [$APP_BAR]") { + t.Errorf("%s does not end with [$APP_BAR]", output) + } + } +} + +var durationFlagTests = []struct { + name string + expected string +}{ + {"help", "--help \"0\"\t"}, + {"h", "-h \"0\"\t"}, +} + +func TestDurationFlagHelpOutput(t *testing.T) { + + for _, test := range durationFlagTests { + flag := DurationFlag{Name: test.name} + output := flag.String() + + if output != test.expected { + t.Errorf("%s does not match %s", output, test.expected) + } + } +} + +func TestDurationFlagWithEnvVarHelpOutput(t *testing.T) { + os.Clearenv() + os.Setenv("APP_BAR", "2h3m6s") + for _, test := range durationFlagTests { + flag := DurationFlag{Name: test.name, EnvVar: "APP_BAR"} + output := flag.String() + + if !strings.HasSuffix(output, " [$APP_BAR]") { + t.Errorf("%s does not end with [$APP_BAR]", output) + } + } +} + +var intSliceFlagTests = []struct { + name string + value *IntSlice + expected string +}{ + {"help", &IntSlice{}, "--help [--help option --help option]\t"}, + {"h", &IntSlice{}, "-h [-h option -h option]\t"}, + {"h", &IntSlice{}, "-h [-h option -h option]\t"}, + {"test", func() *IntSlice { + i := &IntSlice{} + i.Set("9") + return i + }(), "--test [--test option --test option]\t"}, +} + +func TestIntSliceFlagHelpOutput(t *testing.T) { + + for _, test := range intSliceFlagTests { + flag := IntSliceFlag{Name: test.name, Value: test.value} + output := flag.String() + + if output != test.expected { + t.Errorf("%q does not match %q", output, test.expected) + } + } +} + +func TestIntSliceFlagWithEnvVarHelpOutput(t *testing.T) { + os.Clearenv() + os.Setenv("APP_SMURF", "42,3") + for _, test := range intSliceFlagTests { + flag := IntSliceFlag{Name: test.name, Value: test.value, EnvVar: "APP_SMURF"} + output := flag.String() + + if !strings.HasSuffix(output, " [$APP_SMURF]") { + t.Errorf("%q does not end with [$APP_SMURF]", output) + } + } +} + +var float64FlagTests = []struct { + name string + expected string +}{ + {"help", "--help \"0\"\t"}, + {"h", "-h \"0\"\t"}, +} + +func TestFloat64FlagHelpOutput(t *testing.T) { + + for _, test := range float64FlagTests { + flag := Float64Flag{Name: test.name} + output := flag.String() + + if output != test.expected { + t.Errorf("%s does not match %s", output, test.expected) + } + } +} + +func TestFloat64FlagWithEnvVarHelpOutput(t *testing.T) { + os.Clearenv() + os.Setenv("APP_BAZ", "99.4") + for _, test := range float64FlagTests { + flag := Float64Flag{Name: test.name, EnvVar: "APP_BAZ"} + output := flag.String() + + if !strings.HasSuffix(output, " [$APP_BAZ]") { + t.Errorf("%s does not end with [$APP_BAZ]", output) + } + } +} + +var genericFlagTests = []struct { + name string + value Generic + expected string +}{ + {"test", &Parser{"abc", "def"}, "--test \"abc,def\"\ttest flag"}, + {"t", &Parser{"abc", "def"}, "-t \"abc,def\"\ttest flag"}, +} + +func TestGenericFlagHelpOutput(t *testing.T) { + + for _, test := range genericFlagTests { + flag := GenericFlag{Name: test.name, Value: test.value, Usage: "test flag"} + output := flag.String() + + if output != test.expected { + t.Errorf("%q does not match %q", output, test.expected) + } + } +} + +func TestGenericFlagWithEnvVarHelpOutput(t *testing.T) { + os.Clearenv() + os.Setenv("APP_ZAP", "3") + for _, test := range genericFlagTests { + flag := GenericFlag{Name: test.name, EnvVar: "APP_ZAP"} + output := flag.String() + + if !strings.HasSuffix(output, " [$APP_ZAP]") { + t.Errorf("%s does not end with [$APP_ZAP]", output) + } + } +} + +func TestParseMultiString(t *testing.T) { + (&App{ + Flags: []Flag{ + StringFlag{Name: "serve, s"}, + }, + Action: func(ctx *Context) { + if ctx.String("serve") != "10" { + t.Errorf("main name not set") + } + if ctx.String("s") != "10" { + t.Errorf("short name not set") + } + }, + }).Run([]string{"run", "-s", "10"}) +} + +func TestParseMultiStringFromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_COUNT", "20") + (&App{ + Flags: []Flag{ + StringFlag{Name: "count, c", EnvVar: "APP_COUNT"}, + }, + Action: func(ctx *Context) { + if ctx.String("count") != "20" { + t.Errorf("main name not set") + } + if ctx.String("c") != "20" { + t.Errorf("short name not set") + } + }, + }).Run([]string{"run"}) +} + +func TestParseMultiStringFromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_COUNT", "20") + (&App{ + Flags: []Flag{ + StringFlag{Name: "count, c", EnvVar: "COMPAT_COUNT,APP_COUNT"}, + }, + Action: func(ctx *Context) { + if ctx.String("count") != "20" { + t.Errorf("main name not set") + } + if ctx.String("c") != "20" { + t.Errorf("short name not set") + } + }, + }).Run([]string{"run"}) +} + +func TestParseMultiStringSlice(t *testing.T) { + (&App{ + Flags: []Flag{ + StringSliceFlag{Name: "serve, s", Value: &StringSlice{}}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.StringSlice("serve"), []string{"10", "20"}) { + t.Errorf("main name not set") + } + if !reflect.DeepEqual(ctx.StringSlice("s"), []string{"10", "20"}) { + t.Errorf("short name not set") + } + }, + }).Run([]string{"run", "-s", "10", "-s", "20"}) +} + +func TestParseMultiStringSliceFromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_INTERVALS", "20,30,40") + + (&App{ + Flags: []Flag{ + StringSliceFlag{Name: "intervals, i", Value: &StringSlice{}, EnvVar: "APP_INTERVALS"}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) { + t.Errorf("main name not set from env") + } + if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) { + t.Errorf("short name not set from env") + } + }, + }).Run([]string{"run"}) +} + +func TestParseMultiStringSliceFromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_INTERVALS", "20,30,40") + + (&App{ + Flags: []Flag{ + StringSliceFlag{Name: "intervals, i", Value: &StringSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.StringSlice("intervals"), []string{"20", "30", "40"}) { + t.Errorf("main name not set from env") + } + if !reflect.DeepEqual(ctx.StringSlice("i"), []string{"20", "30", "40"}) { + t.Errorf("short name not set from env") + } + }, + }).Run([]string{"run"}) +} + +func TestParseMultiInt(t *testing.T) { + a := App{ + Flags: []Flag{ + IntFlag{Name: "serve, s"}, + }, + Action: func(ctx *Context) { + if ctx.Int("serve") != 10 { + t.Errorf("main name not set") + } + if ctx.Int("s") != 10 { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run", "-s", "10"}) +} + +func TestParseMultiIntFromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_TIMEOUT_SECONDS", "10") + a := App{ + Flags: []Flag{ + IntFlag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"}, + }, + Action: func(ctx *Context) { + if ctx.Int("timeout") != 10 { + t.Errorf("main name not set") + } + if ctx.Int("t") != 10 { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseMultiIntFromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_TIMEOUT_SECONDS", "10") + a := App{ + Flags: []Flag{ + IntFlag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"}, + }, + Action: func(ctx *Context) { + if ctx.Int("timeout") != 10 { + t.Errorf("main name not set") + } + if ctx.Int("t") != 10 { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseMultiIntSlice(t *testing.T) { + (&App{ + Flags: []Flag{ + IntSliceFlag{Name: "serve, s", Value: &IntSlice{}}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.IntSlice("serve"), []int{10, 20}) { + t.Errorf("main name not set") + } + if !reflect.DeepEqual(ctx.IntSlice("s"), []int{10, 20}) { + t.Errorf("short name not set") + } + }, + }).Run([]string{"run", "-s", "10", "-s", "20"}) +} + +func TestParseMultiIntSliceFromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_INTERVALS", "20,30,40") + + (&App{ + Flags: []Flag{ + IntSliceFlag{Name: "intervals, i", Value: &IntSlice{}, EnvVar: "APP_INTERVALS"}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) { + t.Errorf("main name not set from env") + } + if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) { + t.Errorf("short name not set from env") + } + }, + }).Run([]string{"run"}) +} + +func TestParseMultiIntSliceFromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_INTERVALS", "20,30,40") + + (&App{ + Flags: []Flag{ + IntSliceFlag{Name: "intervals, i", Value: &IntSlice{}, EnvVar: "COMPAT_INTERVALS,APP_INTERVALS"}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.IntSlice("intervals"), []int{20, 30, 40}) { + t.Errorf("main name not set from env") + } + if !reflect.DeepEqual(ctx.IntSlice("i"), []int{20, 30, 40}) { + t.Errorf("short name not set from env") + } + }, + }).Run([]string{"run"}) +} + +func TestParseMultiFloat64(t *testing.T) { + a := App{ + Flags: []Flag{ + Float64Flag{Name: "serve, s"}, + }, + Action: func(ctx *Context) { + if ctx.Float64("serve") != 10.2 { + t.Errorf("main name not set") + } + if ctx.Float64("s") != 10.2 { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run", "-s", "10.2"}) +} + +func TestParseMultiFloat64FromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_TIMEOUT_SECONDS", "15.5") + a := App{ + Flags: []Flag{ + Float64Flag{Name: "timeout, t", EnvVar: "APP_TIMEOUT_SECONDS"}, + }, + Action: func(ctx *Context) { + if ctx.Float64("timeout") != 15.5 { + t.Errorf("main name not set") + } + if ctx.Float64("t") != 15.5 { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseMultiFloat64FromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_TIMEOUT_SECONDS", "15.5") + a := App{ + Flags: []Flag{ + Float64Flag{Name: "timeout, t", EnvVar: "COMPAT_TIMEOUT_SECONDS,APP_TIMEOUT_SECONDS"}, + }, + Action: func(ctx *Context) { + if ctx.Float64("timeout") != 15.5 { + t.Errorf("main name not set") + } + if ctx.Float64("t") != 15.5 { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseMultiBool(t *testing.T) { + a := App{ + Flags: []Flag{ + BoolFlag{Name: "serve, s"}, + }, + Action: func(ctx *Context) { + if ctx.Bool("serve") != true { + t.Errorf("main name not set") + } + if ctx.Bool("s") != true { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run", "--serve"}) +} + +func TestParseMultiBoolFromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_DEBUG", "1") + a := App{ + Flags: []Flag{ + BoolFlag{Name: "debug, d", EnvVar: "APP_DEBUG"}, + }, + Action: func(ctx *Context) { + if ctx.Bool("debug") != true { + t.Errorf("main name not set from env") + } + if ctx.Bool("d") != true { + t.Errorf("short name not set from env") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseMultiBoolFromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_DEBUG", "1") + a := App{ + Flags: []Flag{ + BoolFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"}, + }, + Action: func(ctx *Context) { + if ctx.Bool("debug") != true { + t.Errorf("main name not set from env") + } + if ctx.Bool("d") != true { + t.Errorf("short name not set from env") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseMultiBoolT(t *testing.T) { + a := App{ + Flags: []Flag{ + BoolTFlag{Name: "serve, s"}, + }, + Action: func(ctx *Context) { + if ctx.BoolT("serve") != true { + t.Errorf("main name not set") + } + if ctx.BoolT("s") != true { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run", "--serve"}) +} + +func TestParseMultiBoolTFromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_DEBUG", "0") + a := App{ + Flags: []Flag{ + BoolTFlag{Name: "debug, d", EnvVar: "APP_DEBUG"}, + }, + Action: func(ctx *Context) { + if ctx.BoolT("debug") != false { + t.Errorf("main name not set from env") + } + if ctx.BoolT("d") != false { + t.Errorf("short name not set from env") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseMultiBoolTFromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_DEBUG", "0") + a := App{ + Flags: []Flag{ + BoolTFlag{Name: "debug, d", EnvVar: "COMPAT_DEBUG,APP_DEBUG"}, + }, + Action: func(ctx *Context) { + if ctx.BoolT("debug") != false { + t.Errorf("main name not set from env") + } + if ctx.BoolT("d") != false { + t.Errorf("short name not set from env") + } + }, + } + a.Run([]string{"run"}) +} + +type Parser [2]string + +func (p *Parser) Set(value string) error { + parts := strings.Split(value, ",") + if len(parts) != 2 { + return fmt.Errorf("invalid format") + } + + (*p)[0] = parts[0] + (*p)[1] = parts[1] + + return nil +} + +func (p *Parser) String() string { + return fmt.Sprintf("%s,%s", p[0], p[1]) +} + +func TestParseGeneric(t *testing.T) { + a := App{ + Flags: []Flag{ + GenericFlag{Name: "serve, s", Value: &Parser{}}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"10", "20"}) { + t.Errorf("main name not set") + } + if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"10", "20"}) { + t.Errorf("short name not set") + } + }, + } + a.Run([]string{"run", "-s", "10,20"}) +} + +func TestParseGenericFromEnv(t *testing.T) { + os.Clearenv() + os.Setenv("APP_SERVE", "20,30") + a := App{ + Flags: []Flag{ + GenericFlag{Name: "serve, s", Value: &Parser{}, EnvVar: "APP_SERVE"}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.Generic("serve"), &Parser{"20", "30"}) { + t.Errorf("main name not set from env") + } + if !reflect.DeepEqual(ctx.Generic("s"), &Parser{"20", "30"}) { + t.Errorf("short name not set from env") + } + }, + } + a.Run([]string{"run"}) +} + +func TestParseGenericFromEnvCascade(t *testing.T) { + os.Clearenv() + os.Setenv("APP_FOO", "99,2000") + a := App{ + Flags: []Flag{ + GenericFlag{Name: "foos", Value: &Parser{}, EnvVar: "COMPAT_FOO,APP_FOO"}, + }, + Action: func(ctx *Context) { + if !reflect.DeepEqual(ctx.Generic("foos"), &Parser{"99", "2000"}) { + t.Errorf("value not set from env") + } + }, + } + a.Run([]string{"run"}) +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/help.go b/Godeps/_workspace/src/github.com/codegangsta/cli/help.go new file mode 100644 index 0000000..66ef2fb --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/help.go @@ -0,0 +1,238 @@ +package cli + +import ( + "fmt" + "io" + "strings" + "text/tabwriter" + "text/template" +) + +// The text template for the Default help topic. +// cli.go uses text/template to render templates. You can +// render custom help text by setting this variable. +var AppHelpTemplate = `NAME: + {{.Name}} - {{.Usage}} + +USAGE: + {{.Name}} {{if .Flags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} [arguments...] + {{if .Version}} +VERSION: + {{.Version}} + {{end}}{{if len .Authors}} +AUTHOR(S): + {{range .Authors}}{{ . }}{{end}} + {{end}}{{if .Commands}} +COMMANDS: + {{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}} + {{end}}{{end}}{{if .Flags}} +GLOBAL OPTIONS: + {{range .Flags}}{{.}} + {{end}}{{end}}{{if .Copyright }} +COPYRIGHT: + {{.Copyright}} + {{end}} +` + +// The text template for the command help topic. +// cli.go uses text/template to render templates. You can +// render custom help text by setting this variable. +var CommandHelpTemplate = `NAME: + {{.FullName}} - {{.Usage}} + +USAGE: + command {{.FullName}}{{if .Flags}} [command options]{{end}} [arguments...]{{if .Description}} + +DESCRIPTION: + {{.Description}}{{end}}{{if .Flags}} + +OPTIONS: + {{range .Flags}}{{.}} + {{end}}{{ end }} +` + +// The text template for the subcommand help topic. +// cli.go uses text/template to render templates. You can +// render custom help text by setting this variable. +var SubcommandHelpTemplate = `NAME: + {{.Name}} - {{.Usage}} + +USAGE: + {{.Name}} command{{if .Flags}} [command options]{{end}} [arguments...] + +COMMANDS: + {{range .Commands}}{{join .Names ", "}}{{ "\t" }}{{.Usage}} + {{end}}{{if .Flags}} +OPTIONS: + {{range .Flags}}{{.}} + {{end}}{{end}} +` + +var helpCommand = Command{ + Name: "help", + Aliases: []string{"h"}, + Usage: "Shows a list of commands or help for one command", + Action: func(c *Context) { + args := c.Args() + if args.Present() { + ShowCommandHelp(c, args.First()) + } else { + ShowAppHelp(c) + } + }, +} + +var helpSubcommand = Command{ + Name: "help", + Aliases: []string{"h"}, + Usage: "Shows a list of commands or help for one command", + Action: func(c *Context) { + args := c.Args() + if args.Present() { + ShowCommandHelp(c, args.First()) + } else { + ShowSubcommandHelp(c) + } + }, +} + +// Prints help for the App or Command +type helpPrinter func(w io.Writer, templ string, data interface{}) + +var HelpPrinter helpPrinter = printHelp + +// Prints version for the App +var VersionPrinter = printVersion + +func ShowAppHelp(c *Context) { + HelpPrinter(c.App.Writer, AppHelpTemplate, c.App) +} + +// Prints the list of subcommands as the default app completion method +func DefaultAppComplete(c *Context) { + for _, command := range c.App.Commands { + for _, name := range command.Names() { + fmt.Fprintln(c.App.Writer, name) + } + } +} + +// Prints help for the given command +func ShowCommandHelp(ctx *Context, command string) { + // show the subcommand help for a command with subcommands + if command == "" { + HelpPrinter(ctx.App.Writer, SubcommandHelpTemplate, ctx.App) + return + } + + for _, c := range ctx.App.Commands { + if c.HasName(command) { + HelpPrinter(ctx.App.Writer, CommandHelpTemplate, c) + return + } + } + + if ctx.App.CommandNotFound != nil { + ctx.App.CommandNotFound(ctx, command) + } else { + fmt.Fprintf(ctx.App.Writer, "No help topic for '%v'\n", command) + } +} + +// Prints help for the given subcommand +func ShowSubcommandHelp(c *Context) { + ShowCommandHelp(c, c.Command.Name) +} + +// Prints the version number of the App +func ShowVersion(c *Context) { + VersionPrinter(c) +} + +func printVersion(c *Context) { + fmt.Fprintf(c.App.Writer, "%v version %v\n", c.App.Name, c.App.Version) +} + +// Prints the lists of commands within a given context +func ShowCompletions(c *Context) { + a := c.App + if a != nil && a.BashComplete != nil { + a.BashComplete(c) + } +} + +// Prints the custom completions for a given command +func ShowCommandCompletions(ctx *Context, command string) { + c := ctx.App.Command(command) + if c != nil && c.BashComplete != nil { + c.BashComplete(ctx) + } +} + +func printHelp(out io.Writer, templ string, data interface{}) { + funcMap := template.FuncMap{ + "join": strings.Join, + } + + w := tabwriter.NewWriter(out, 0, 8, 1, '\t', 0) + t := template.Must(template.New("help").Funcs(funcMap).Parse(templ)) + err := t.Execute(w, data) + if err != nil { + panic(err) + } + w.Flush() +} + +func checkVersion(c *Context) bool { + if c.GlobalBool("version") || c.GlobalBool("v") || c.Bool("version") || c.Bool("v") { + ShowVersion(c) + return true + } + + return false +} + +func checkHelp(c *Context) bool { + if c.GlobalBool("h") || c.GlobalBool("help") || c.Bool("h") || c.Bool("help") { + ShowAppHelp(c) + return true + } + + return false +} + +func checkCommandHelp(c *Context, name string) bool { + if c.Bool("h") || c.Bool("help") { + ShowCommandHelp(c, name) + return true + } + + return false +} + +func checkSubcommandHelp(c *Context) bool { + if c.GlobalBool("h") || c.GlobalBool("help") { + ShowSubcommandHelp(c) + return true + } + + return false +} + +func checkCompletions(c *Context) bool { + if (c.GlobalBool(BashCompletionFlag.Name) || c.Bool(BashCompletionFlag.Name)) && c.App.EnableBashCompletion { + ShowCompletions(c) + return true + } + + return false +} + +func checkCommandCompletions(c *Context, name string) bool { + if c.Bool(BashCompletionFlag.Name) && c.App.EnableBashCompletion { + ShowCommandCompletions(c, name) + return true + } + + return false +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/help_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/help_test.go new file mode 100644 index 0000000..42d0284 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/help_test.go @@ -0,0 +1,36 @@ +package cli + +import ( + "bytes" + "testing" +) + +func Test_ShowAppHelp_NoAuthor(t *testing.T) { + output := new(bytes.Buffer) + app := NewApp() + app.Writer = output + + c := NewContext(app, nil, nil) + + ShowAppHelp(c) + + if bytes.Index(output.Bytes(), []byte("AUTHOR(S):")) != -1 { + t.Errorf("expected\n%snot to include %s", output.String(), "AUTHOR(S):") + } +} + +func Test_ShowAppHelp_NoVersion(t *testing.T) { + output := new(bytes.Buffer) + app := NewApp() + app.Writer = output + + app.Version = "" + + c := NewContext(app, nil, nil) + + ShowAppHelp(c) + + if bytes.Index(output.Bytes(), []byte("VERSION:")) != -1 { + t.Errorf("expected\n%snot to include %s", output.String(), "VERSION:") + } +} diff --git a/Godeps/_workspace/src/github.com/codegangsta/cli/helpers_test.go b/Godeps/_workspace/src/github.com/codegangsta/cli/helpers_test.go new file mode 100644 index 0000000..3ce8e93 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codegangsta/cli/helpers_test.go @@ -0,0 +1,19 @@ +package cli + +import ( + "reflect" + "testing" +) + +/* Test Helpers */ +func expect(t *testing.T, a interface{}, b interface{}) { + if a != b { + t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) + } +} + +func refute(t *testing.T, a interface{}, b interface{}) { + if a == b { + t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) + } +} diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/.gitignore b/Godeps/_workspace/src/github.com/codeskyblue/kproc/.gitignore new file mode 100644 index 0000000..0026861 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/README.md b/Godeps/_workspace/src/github.com/codeskyblue/kproc/README.md new file mode 100644 index 0000000..7b193dc --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/README.md @@ -0,0 +1,24 @@ +# kproc +[![GoDoc](https://godoc.org/github.com/codeskyblue/kproc?status.svg)](https://godoc.org/github.com/codeskyblue/kproc) + +This is a golang lib, offer a better way to kill all child process. + +Tested on _windows, linux, darwin._ + +This lib has been used in [fswatch](https://github.com/codeskyblue/fswatch). + +## Usage + + go get -v github.com/codeskyblue/kproc + +example: + + func main() { + p := kproc.ProcString("python flask_main.py") + p.Start() + time.Sleep(3 * time.Second) + err := p.Terminate(syscall.SIGKILL) + if err != nil { + log.Println(err) + } + } diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc.go b/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc.go new file mode 100644 index 0000000..660d3d8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc.go @@ -0,0 +1,7 @@ +package kproc + +import "os/exec" + +type Process struct { + *exec.Cmd +} diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc_posix.go b/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc_posix.go new file mode 100644 index 0000000..c3c815c --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc_posix.go @@ -0,0 +1,45 @@ +// +build !windows + +package kproc + +import ( + "log" + "os" + "os/exec" + "syscall" +) + +func setupCmd(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{} + cmd.SysProcAttr.Setsid = true +} + +func ProcCommand(cmd *exec.Cmd) *Process { + setupCmd(cmd) + return &Process{ + Cmd: cmd, + } +} + +func ProcString(command string) *Process { + cmd := exec.Command("/bin/bash", "-c", command) + setupCmd(cmd) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return &Process{ + Cmd: cmd, + } +} + +func (p *Process) Terminate(sig os.Signal) (err error) { + if p.Process == nil { + return + } + // find pgid, ref: http://unix.stackexchange.com/questions/14815/process-descendants + group, err := os.FindProcess(-1 * p.Process.Pid) + log.Println(group) + if err == nil { + err = group.Signal(sig) + } + return err +} diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc_windows.go b/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc_windows.go new file mode 100644 index 0000000..97e4148 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/proc_windows.go @@ -0,0 +1,33 @@ +package kproc + +import ( + "os" + "os/exec" + "strconv" +) + +func ProcCommand(cmd *exec.Cmd) *Process { + return &Process{ + Cmd: cmd, + } +} + +func ProcString(command string) *Process { + cmd := exec.Command("cmd", "/c", command) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return &Process{ + Cmd: cmd, + } +} + +func (p *Process) Terminate(sig os.Signal) (err error) { + if p.Process == nil { + return nil + } + pid := p.Process.Pid + c := exec.Command("taskkill", "/t", "/f", "/pid", strconv.Itoa(pid)) + c.Stdout = os.Stdout + c.Stderr = os.Stderr + return c.Run() +} diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/Procfile b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/Procfile new file mode 100644 index 0000000..c075a57 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/Procfile @@ -0,0 +1 @@ +web: python flask_main.py diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/flask_main.py b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/flask_main.py new file mode 100644 index 0000000..490d39c --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/flask_main.py @@ -0,0 +1,11 @@ +import flask + + +app = flask.Flask(__name__) + +@app.route('/') +def homepage(): + return 'Home' + +if __name__ == '__main__': + app.run(debug=True, host='0.0.0.0') \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/main.go b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/main.go new file mode 100644 index 0000000..19326f4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/main.go @@ -0,0 +1,22 @@ +package main + +import ( + "fmt" + "kproc" + "log" + "os/exec" + "syscall" + "time" +) + +func main() { + p := kproc.ProcString("python flask_main.py") + p.Start() + time.Sleep(10 * time.Second) + err := p.Terminate(syscall.SIGKILL) + if err != nil { + log.Println(err) + } + out, _ := exec.Command("lsof", "-i:5000").CombinedOutput() + fmt.Println(string(out)) +} diff --git a/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/tests b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/tests new file mode 100644 index 0000000..bad716a Binary files /dev/null and b/Godeps/_workspace/src/github.com/codeskyblue/kproc/tests/tests differ diff --git a/Godeps/_workspace/src/github.com/franela/goreq/.gitignore b/Godeps/_workspace/src/github.com/franela/goreq/.gitignore new file mode 100644 index 0000000..131fc16 --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe + +src diff --git a/Godeps/_workspace/src/github.com/franela/goreq/.travis.yml b/Godeps/_workspace/src/github.com/franela/goreq/.travis.yml new file mode 100644 index 0000000..629753d --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/.travis.yml @@ -0,0 +1,8 @@ +language: go +go: + - 1.2 + - tip +notifications: + email: + - ionathan@gmail.com + - marcosnils@gmail.com diff --git a/Godeps/_workspace/src/github.com/franela/goreq/LICENSE b/Godeps/_workspace/src/github.com/franela/goreq/LICENSE new file mode 100644 index 0000000..068dee1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013 Jonathan Leibiusky and Marcos Lilljedahl + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/Godeps/_workspace/src/github.com/franela/goreq/Makefile b/Godeps/_workspace/src/github.com/franela/goreq/Makefile new file mode 100644 index 0000000..0f04d65 --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/Makefile @@ -0,0 +1,3 @@ +test: + go get -v -d -t ./... + go test -v diff --git a/Godeps/_workspace/src/github.com/franela/goreq/README.md b/Godeps/_workspace/src/github.com/franela/goreq/README.md new file mode 100644 index 0000000..b6a0bfd --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/README.md @@ -0,0 +1,444 @@ +[![Build Status](https://img.shields.io/travis/franela/goreq/master.svg)](https://travis-ci.org/franela/goreq) +[![GoDoc](https://godoc.org/github.com/franela/goreq?status.svg)](https://godoc.org/github.com/franela/goreq) + +GoReq +======= + +Simple and sane HTTP request library for Go language. + + + +**Table of Contents** + +- [Why GoReq?](#user-content-why-goreq) +- [How do I install it?](#user-content-how-do-i-install-it) +- [What can I do with it?](#user-content-what-can-i-do-with-it) + - [Making requests with different methods](#user-content-making-requests-with-different-methods) + - [GET](#user-content-get) + - [Tags](#user-content-tags) + - [POST](#user-content-post) + - [Sending payloads in the Body](#user-content-sending-payloads-in-the-body) + - [Specifiying request headers](#user-content-specifiying-request-headers) + - [Sending Cookies](#cookie-support) + - [Setting timeouts](#user-content-setting-timeouts) + - [Using the Response and Error](#user-content-using-the-response-and-error) + - [Receiving JSON](#user-content-receiving-json) + - [Sending/Receiving Compressed Payloads](#user-content-sendingreceiving-compressed-payloads) + - [Using gzip compression:](#user-content-using-gzip-compression) + - [Using deflate compression:](#user-content-using-deflate-compression) + - [Using compressed responses:](#user-content-using-compressed-responses) + - [Proxy](#proxy) + - [Debugging requests](#debug) + - [Getting raw Request & Response](#getting-raw-request--response) + - [TODO:](#user-content-todo) + + + +Why GoReq? +========== + +Go has very nice native libraries that allows you to do lots of cool things. But sometimes those libraries are too low level, which means that to do a simple thing, like an HTTP Request, it takes some time. And if you want to do something as simple as adding a timeout to a request, you will end up writing several lines of code. + +This is why we think GoReq is useful. Because you can do all your HTTP requests in a very simple and comprehensive way, while enabling you to do more advanced stuff by giving you access to the native API. + +How do I install it? +==================== + +```bash +go get github.com/franela/goreq +``` + +What can I do with it? +====================== + +## Making requests with different methods + +#### GET +```go +res, err := goreq.Request{ Uri: "http://www.google.com" }.Do() +``` + +GoReq default method is GET. + +You can also set value to GET method easily + +```go +type Item struct { + Limit int + Skip int + Fields string +} + +item := Item { + Limit: 3, + Skip: 5, + Fields: "Value", +} + +res, err := goreq.Request{ + Uri: "http://localhost:3000/", + QueryString: item, +}.Do() +``` +The sample above will send `http://localhost:3000/?limit=3&skip=5&fields=Value` + +Alternatively the `url` tag can be used in struct fields to customize encoding properties + +```go +type Item struct { + TheLimit int `url:"the_limit"` + TheSkip string `url:"the_skip,omitempty"` + TheFields string `url:"-"` +} + +item := Item { + TheLimit: 3, + TheSkip: "", + TheFields: "Value", +} + +res, err := goreq.Request{ + Uri: "http://localhost:3000/", + QueryString: item, +}.Do() +``` +The sample above will send `http://localhost:3000/?the_limit=3` + + +QueryString also support url.Values + +```go +item := url.Values{} +item.Set("Limit", 3) +item.Add("Field", "somefield") +item.Add("Field", "someotherfield") + +res, err := goreq.Request{ + Uri: "http://localhost:3000/", + QueryString: item, +}.Do() +``` + +The sample above will send `http://localhost:3000/?limit=3&field=somefield&field=someotherfield` + +### Tags + +Struct field `url` tag is mainly used as the request parameter name. +Tags can be comma separated multiple values, 1st value is for naming and rest has special meanings. + +- special tag for 1st value + - `-`: value is ignored if set this + +- special tag for rest 2nd value + - `omitempty`: zero-value is ignored if set this + - `squash`: the fields of embedded struct is used for parameter + +#### Tag Examples + +```go +type Place struct { + Country string `url:"country"` + City string `url:"city"` + ZipCode string `url:"zipcode,omitempty"` +} + +type Person struct { + Place `url:",squash"` + + FirstName string `url:"first_name"` + LastName string `url:"last_name"` + Age string `url:"age,omitempty"` + Password string `url:"-"` +} + +johnbull := Person{ + Place: Place{ // squash the embedded struct value + Country: "UK", + City: "London", + ZipCode: "SW1", + }, + FirstName: "John", + LastName: "Doe", + Age: "35", + Password: "my-secret", // ignored for parameter +} + +goreq.Request{ + Uri: "http://localhost/", + QueryString: johnbull, +}.Do() +// => `http://localhost/?first_name=John&last_name=Doe&age=35&country=UK&city=London&zip_code=SW1` + + +// age and zipcode will be ignored because of `omitempty` +// but firstname isn't. +samurai := Person{ + Place: Place{ // squash the embedded struct value + Country: "Japan", + City: "Tokyo", + }, + LastName: "Yagyu", +} + +goreq.Request{ + Uri: "http://localhost/", + QueryString: samurai, +}.Do() +// => `http://localhost/?first_name=&last_name=yagyu&country=Japan&city=Tokyo` +``` + + +#### POST + +```go +res, err := goreq.Request{ Method: "POST", Uri: "http://www.google.com" }.Do() +``` + +## Sending payloads in the Body + +You can send ```string```, ```Reader``` or ```interface{}``` in the body. The first two will be sent as text. The last one will be marshalled to JSON, if possible. + +```go +type Item struct { + Id int + Name string +} + +item := Item{ Id: 1111, Name: "foobar" } + +res, err := goreq.Request{ + Method: "POST", + Uri: "http://www.google.com", + Body: item, +}.Do() +``` + +## Specifiying request headers + +We think that most of the times the request headers that you use are: ```Host```, ```Content-Type```, ```Accept``` and ```User-Agent```. This is why we decided to make it very easy to set these headers. + +```go +res, err := goreq.Request{ + Uri: "http://www.google.com", + Host: "foobar.com", + Accept: "application/json", + ContentType: "application/json", + UserAgent: "goreq", +}.Do() +``` + +But sometimes you need to set other headers. You can still do it. + +```go +req := goreq.Request{ Uri: "http://www.google.com" } + +req.AddHeader("X-Custom", "somevalue") + +req.Do() +``` + +Alternatively you can use the `WithHeader` function to keep the syntax short + +```go +res, err = goreq.Request{ Uri: "http://www.google.com" }.WithHeader("X-Custom", "somevalue").Do() +``` + +## Cookie support + +Cookies can be either set at the request level by sending a [CookieJar](http://golang.org/pkg/net/http/cookiejar/) in the `CookieJar` request field +or you can use goreq's one-liner WithCookie method as shown below + +```go +res, err := goreq.Request{ + Uri: "http://www.google.com", +}. +WithCookie(&http.Cookie{Name: "c1", Value: "v1"}). +Do() +``` + +## Setting timeouts + +GoReq supports 2 kind of timeouts. A general connection timeout and a request specific one. By default the connection timeout is of 1 second. There is no default for request timeout, which means it will wait forever. + +You can change the connection timeout doing: + +```go +goreq.SetConnectTimeout(100 * time.Millisecond) +``` + +And specify the request timeout doing: + +```go +res, err := goreq.Request{ + Uri: "http://www.google.com", + Timeout: 500 * time.Millisecond, +}.Do() +``` + +## Using the Response and Error + +GoReq will always return 2 values: a ```Response``` and an ```Error```. +If ```Error``` is not ```nil``` it means that an error happened while doing the request and you shouldn't use the ```Response``` in any way. +You can check what happened by getting the error message: + +```go +fmt.Println(err.Error()) +``` +And to make it easy to know if it was a timeout error, you can ask the error or return it: + +```go +if serr, ok := err.(*goreq.Error); ok { + if serr.Timeout() { + ... + } +} +return err +``` + +If you don't get an error, you can safely use the ```Response```. + +```go +res.Uri // return final URL location of the response (fulfilled after redirect was made) +res.StatusCode // return the status code of the response +res.Body // gives you access to the body +res.Body.ToString() // will return the body as a string +res.Header.Get("Content-Type") // gives you access to all the response headers +``` +Remember that you should **always** close `res.Body` if it's not `nil` + +## Receiving JSON + +GoReq will help you to receive and unmarshal JSON. + +```go +type Item struct { + Id int + Name string +} + +var item Item + +res.Body.FromJsonTo(&item) +``` + +## Sending/Receiving Compressed Payloads +GoReq supports gzip, deflate and zlib compression of requests' body and transparent decompression of responses provided they have a correct `Content-Encoding` header. + +#####Using gzip compression: +```go +res, err := goreq.Request{ + Method: "POST", + Uri: "http://www.google.com", + Body: item, + Compression: goreq.Gzip(), +}.Do() +``` +#####Using deflate/zlib compression: +```go +res, err := goreq.Request{ + Method: "POST", + Uri: "http://www.google.com", + Body: item, + Compression: goreq.Deflate(), +}.Do() +``` +#####Using compressed responses: +If servers replies a correct and matching `Content-Encoding` header (gzip requires `Content-Encoding: gzip` and deflate `Content-Encoding: deflate`) goreq transparently decompresses the response so the previous example should always work: +```go +type Item struct { + Id int + Name string +} +res, err := goreq.Request{ + Method: "POST", + Uri: "http://www.google.com", + Body: item, + Compression: goreq.Gzip(), +}.Do() +var item Item +res.Body.FromJsonTo(&item) +``` +If no `Content-Encoding` header is replied by the server GoReq will return the crude response. + +## Proxy +If you need to use a proxy for your requests GoReq supports the standard `http_proxy` env variable as well as manually setting the proxy for each request + +```go +res, err := goreq.Request{ + Method: "GET", + Proxy: "http://myproxy:myproxyport", + Uri: "http://www.google.com", +}.Do() +``` + +### Proxy basic auth is also supported + +```go +res, err := goreq.Request{ + Method: "GET", + Proxy: "http://user:pass@myproxy:myproxyport", + Uri: "http://www.google.com", +}.Do() +``` + +## Debug +If you need to debug your http requests, it can print the http request detail. + +```go +res, err := goreq.Request{ + Method: "GET", + Uri: "http://www.google.com", + Compression: goreq.Gzip(), + ShowDebug: true, +}.Do() +fmt.Println(res, err) +``` + +and it will print the log: +``` +GET / HTTP/1.1 +Host: www.google.com +Accept: +Accept-Encoding: gzip +Content-Encoding: gzip +Content-Type: +``` + + +### Getting raw Request & Response + +To get the Request: + +```go +req := goreq.Request{ + Host: "foobar.com", +} + +//req.Request will return a new instance of an http.Request so you can safely use it for something else +request, _ := req.NewRequest() + +``` + + +To get the Response: + +```go +res, err := goreq.Request{ + Method: "GET", + Uri: "http://www.google.com", + Compression: goreq.Gzip(), + ShowDebug: true, +}.Do() + +// res.Response will contain the original http.Response structure +fmt.Println(res.Response, err) +``` + + + + +TODO: +----- + +We do have a couple of [issues](https://github.com/franela/goreq/issues) pending we'll be addressing soon. But feel free to +contribute and send us PRs (with tests please :smile:). diff --git a/Godeps/_workspace/src/github.com/franela/goreq/goreq.go b/Godeps/_workspace/src/github.com/franela/goreq/goreq.go new file mode 100644 index 0000000..27a1561 --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/goreq.go @@ -0,0 +1,494 @@ +package goreq + +import ( + "bufio" + "bytes" + "compress/gzip" + "compress/zlib" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httputil" + "net/url" + "reflect" + "strings" + "time" +) + +type Request struct { + headers []headerTuple + cookies []*http.Cookie + Method string + Uri string + Body interface{} + QueryString interface{} + Timeout time.Duration + ContentType string + Accept string + Host string + UserAgent string + Insecure bool + MaxRedirects int + RedirectHeaders bool + Proxy string + Compression *compression + BasicAuthUsername string + BasicAuthPassword string + CookieJar http.CookieJar + ShowDebug bool + OnBeforeRequest func(goreq *Request, httpreq *http.Request) +} + +type compression struct { + writer func(buffer io.Writer) (io.WriteCloser, error) + reader func(buffer io.Reader) (io.ReadCloser, error) + ContentEncoding string +} + +type Response struct { + *http.Response + Uri string + Body *Body + req *http.Request +} + +func (r Response) CancelRequest() { + cancelRequest(r.req) + +} + +func cancelRequest(r *http.Request) { + if transport, ok := DefaultTransport.(transportRequestCanceler); ok { + transport.CancelRequest(r) + } +} + +type headerTuple struct { + name string + value string +} + +type Body struct { + reader io.ReadCloser + compressedReader io.ReadCloser +} + +type Error struct { + timeout bool + Err error +} + +type transportRequestCanceler interface { + CancelRequest(*http.Request) +} + +func (e *Error) Timeout() bool { + return e.timeout +} + +func (e *Error) Error() string { + return e.Err.Error() +} + +func (b *Body) Read(p []byte) (int, error) { + if b.compressedReader != nil { + return b.compressedReader.Read(p) + } + return b.reader.Read(p) +} + +func (b *Body) Close() error { + err := b.reader.Close() + if b.compressedReader != nil { + return b.compressedReader.Close() + } + return err +} + +func (b *Body) FromJsonTo(o interface{}) error { + return json.NewDecoder(b).Decode(o) +} + +func (b *Body) ToString() (string, error) { + body, err := ioutil.ReadAll(b) + if err != nil { + return "", err + } + return string(body), nil +} + +func Gzip() *compression { + reader := func(buffer io.Reader) (io.ReadCloser, error) { + return gzip.NewReader(buffer) + } + writer := func(buffer io.Writer) (io.WriteCloser, error) { + return gzip.NewWriter(buffer), nil + } + return &compression{writer: writer, reader: reader, ContentEncoding: "gzip"} +} + +func Deflate() *compression { + reader := func(buffer io.Reader) (io.ReadCloser, error) { + return zlib.NewReader(buffer) + } + writer := func(buffer io.Writer) (io.WriteCloser, error) { + return zlib.NewWriter(buffer), nil + } + return &compression{writer: writer, reader: reader, ContentEncoding: "deflate"} +} + +func Zlib() *compression { + return Deflate() +} + +func paramParse(query interface{}) (string, error) { + switch query.(type) { + case url.Values: + return query.(url.Values).Encode(), nil + case *url.Values: + return query.(*url.Values).Encode(), nil + default: + var v = &url.Values{} + err := paramParseStruct(v, query) + return v.Encode(), err + } +} + +func paramParseStruct(v *url.Values, query interface{}) error { + var ( + s = reflect.ValueOf(query) + t = reflect.TypeOf(query) + ) + for t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface { + s = s.Elem() + t = s.Type() + } + if t.Kind() != reflect.Struct { + return errors.New("Can not parse QueryString.") + } + + for i := 0; i < t.NumField(); i++ { + var name string + + field := s.Field(i) + typeField := t.Field(i) + + if !field.CanInterface() { + continue + } + + urlTag := typeField.Tag.Get("url") + if urlTag == "-" { + continue + } + + name, opts := parseTag(urlTag) + + var omitEmpty, squash bool + omitEmpty = opts.Contains("omitempty") + squash = opts.Contains("squash") + + if squash { + err := paramParseStruct(v, field.Interface()) + if err != nil { + return err + } + continue + } + + if urlTag == "" { + name = strings.ToLower(typeField.Name) + } + + if val := fmt.Sprintf("%v", field.Interface()); !(omitEmpty && len(val) == 0) { + v.Add(name, val) + } + } + return nil +} + +func prepareRequestBody(b interface{}) (io.Reader, error) { + switch b.(type) { + case string: + // treat is as text + return strings.NewReader(b.(string)), nil + case io.Reader: + // treat is as text + return b.(io.Reader), nil + case []byte: + //treat as byte array + return bytes.NewReader(b.([]byte)), nil + case nil: + return nil, nil + default: + // try to jsonify it + j, err := json.Marshal(b) + if err == nil { + return bytes.NewReader(j), nil + } + return nil, err + } +} + +var DefaultDialer = &net.Dialer{Timeout: 1000 * time.Millisecond} +var DefaultTransport http.RoundTripper = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyFromEnvironment} +var DefaultClient = &http.Client{Transport: DefaultTransport} + +var proxyTransport http.RoundTripper +var proxyClient *http.Client + +func SetConnectTimeout(duration time.Duration) { + DefaultDialer.Timeout = duration +} + +func (r *Request) AddHeader(name string, value string) { + if r.headers == nil { + r.headers = []headerTuple{} + } + r.headers = append(r.headers, headerTuple{name: name, value: value}) +} + +func (r Request) WithHeader(name string, value string) Request { + r.AddHeader(name, value) + return r +} + +func (r *Request) AddCookie(c *http.Cookie) { + r.cookies = append(r.cookies, c) +} + +func (r Request) WithCookie(c *http.Cookie) Request { + r.AddCookie(c) + return r + +} + +func (r Request) Do() (*Response, error) { + var client = DefaultClient + var transport = DefaultTransport + var resUri string + var redirectFailed bool + + r.Method = valueOrDefault(r.Method, "GET") + + // use a client with a cookie jar if necessary. We create a new client not + // to modify the default one. + if r.CookieJar != nil { + client = &http.Client{ + Transport: transport, + Jar: r.CookieJar, + } + } + + if r.Proxy != "" { + proxyUrl, err := url.Parse(r.Proxy) + if err != nil { + // proxy address is in a wrong format + return nil, &Error{Err: err} + } + + //If jar is specified new client needs to be built + if proxyTransport == nil || client.Jar != nil { + proxyTransport = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyURL(proxyUrl)} + proxyClient = &http.Client{Transport: proxyTransport, Jar: client.Jar} + } else if proxyTransport, ok := proxyTransport.(*http.Transport); ok { + proxyTransport.Proxy = http.ProxyURL(proxyUrl) + } + transport = proxyTransport + client = proxyClient + } + + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + + if len(via) > r.MaxRedirects { + redirectFailed = true + return errors.New("Error redirecting. MaxRedirects reached") + } + + resUri = req.URL.String() + + //By default Golang will not redirect request headers + // https://code.google.com/p/go/issues/detail?id=4800&q=request%20header + if r.RedirectHeaders { + for key, val := range via[0].Header { + req.Header[key] = val + } + } + return nil + } + + if transport, ok := transport.(*http.Transport); ok { + if r.Insecure { + if transport.TLSClientConfig != nil { + transport.TLSClientConfig.InsecureSkipVerify = true + } else { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + } else if transport.TLSClientConfig != nil { + // the default TLS client (when transport.TLSClientConfig==nil) is + // already set to verify, so do nothing in that case + transport.TLSClientConfig.InsecureSkipVerify = false + } + } + + req, err := r.NewRequest() + + if err != nil { + // we couldn't parse the URL. + return nil, &Error{Err: err} + } + + timeout := false + var timer *time.Timer + if r.Timeout > 0 { + timer = time.AfterFunc(r.Timeout, func() { + cancelRequest(req) + timeout = true + }) + } + + if r.ShowDebug { + dump, err := httputil.DumpRequest(req, true) + if err != nil { + log.Println(err) + } + log.Println(string(dump)) + } + + if r.OnBeforeRequest != nil { + r.OnBeforeRequest(&r, req) + } + res, err := client.Do(req) + if timer != nil { + timer.Stop() + } + + if err != nil { + if !timeout { + switch err := err.(type) { + case *net.OpError: + timeout = err.Timeout() + case *url.Error: + if op, ok := err.Err.(*net.OpError); ok { + timeout = op.Timeout() + } + } + } + + var response *Response + //If redirect fails we still want to return response data + if redirectFailed { + if res != nil { + response = &Response{res, resUri, &Body{reader: res.Body}, req} + } else { + response = &Response{res, resUri, nil, req} + } + } + + //If redirect fails and we haven't set a redirect count we shouldn't return an error + if redirectFailed && r.MaxRedirects == 0 { + return response, nil + } + + return response, &Error{timeout: timeout, Err: err} + } + + if r.Compression != nil && strings.Contains(res.Header.Get("Content-Encoding"), r.Compression.ContentEncoding) { + compressedReader, err := r.Compression.reader(res.Body) + if err != nil { + return nil, &Error{Err: err} + } + return &Response{res, resUri, &Body{reader: res.Body, compressedReader: compressedReader}, req}, nil + } + + return &Response{res, resUri, &Body{reader: res.Body}, req}, nil +} + +func (r Request) addHeaders(headersMap http.Header) { + if len(r.UserAgent) > 0 { + headersMap.Add("User-Agent", r.UserAgent) + } + if r.Accept != "" { + headersMap.Add("Accept", r.Accept) + } + if r.ContentType != "" { + headersMap.Add("Content-Type", r.ContentType) + } +} + +func (r Request) NewRequest() (*http.Request, error) { + + b, e := prepareRequestBody(r.Body) + if e != nil { + // there was a problem marshaling the body + return nil, &Error{Err: e} + } + + if r.QueryString != nil { + param, e := paramParse(r.QueryString) + if e != nil { + return nil, &Error{Err: e} + } + r.Uri = r.Uri + "?" + param + } + + var bodyReader io.Reader + if b != nil && r.Compression != nil { + buffer := bytes.NewBuffer([]byte{}) + readBuffer := bufio.NewReader(b) + writer, err := r.Compression.writer(buffer) + if err != nil { + return nil, &Error{Err: err} + } + _, e = readBuffer.WriteTo(writer) + writer.Close() + if e != nil { + return nil, &Error{Err: e} + } + bodyReader = buffer + } else { + bodyReader = b + } + + req, err := http.NewRequest(r.Method, r.Uri, bodyReader) + if err != nil { + return nil, err + } + // add headers to the request + req.Host = r.Host + + r.addHeaders(req.Header) + if r.Compression != nil { + req.Header.Add("Content-Encoding", r.Compression.ContentEncoding) + req.Header.Add("Accept-Encoding", r.Compression.ContentEncoding) + } + if r.headers != nil { + for _, header := range r.headers { + req.Header.Add(header.name, header.value) + } + } + + //use basic auth if required + if r.BasicAuthUsername != "" { + req.SetBasicAuth(r.BasicAuthUsername, r.BasicAuthPassword) + } + + for _, c := range r.cookies { + req.AddCookie(c) + } + return req, nil +} + +// Return value if nonempty, def otherwise. +func valueOrDefault(value, def string) string { + if value != "" { + return value + } + return def +} diff --git a/Godeps/_workspace/src/github.com/franela/goreq/goreq_test.go b/Godeps/_workspace/src/github.com/franela/goreq/goreq_test.go new file mode 100644 index 0000000..3b88c22 --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/goreq_test.go @@ -0,0 +1,1184 @@ +package goreq + +import ( + "compress/gzip" + "compress/zlib" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "math" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + . "github.com/franela/goblin" + . "github.com/onsi/gomega" +) + +type Query struct { + Limit int + Skip int +} + +func TestRequest(t *testing.T) { + + query := Query{ + Limit: 3, + Skip: 5, + } + + valuesQuery := url.Values{} + valuesQuery.Set("name", "marcos") + valuesQuery.Add("friend", "jonas") + valuesQuery.Add("friend", "peter") + + g := Goblin(t) + + RegisterFailHandler(func(m string, _ ...int) { g.Fail(m) }) + + g.Describe("Request", func() { + + g.Describe("General request methods", func() { + var ts *httptest.Server + var requestHeaders http.Header + + g.Before(func() { + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestHeaders = r.Header + if (r.Method == "GET" || r.Method == "OPTIONS" || r.Method == "TRACE" || r.Method == "PATCH" || r.Method == "FOOBAR") && r.URL.Path == "/foo" { + w.WriteHeader(200) + fmt.Fprint(w, "bar") + } + if r.Method == "GET" && r.URL.Path == "/getquery" { + w.WriteHeader(200) + fmt.Fprint(w, fmt.Sprintf("%v", r.URL)) + } + if r.Method == "GET" && r.URL.Path == "/getbody" { + w.WriteHeader(200) + io.Copy(w, r.Body) + } + if r.Method == "POST" && r.URL.Path == "/" { + w.Header().Add("Location", ts.URL+"/123") + w.WriteHeader(201) + io.Copy(w, r.Body) + } + if r.Method == "POST" && r.URL.Path == "/getquery" { + w.WriteHeader(200) + fmt.Fprint(w, fmt.Sprintf("%v", r.URL)) + } + if r.Method == "PUT" && r.URL.Path == "/foo/123" { + w.WriteHeader(200) + io.Copy(w, r.Body) + } + if r.Method == "DELETE" && r.URL.Path == "/foo/123" { + w.WriteHeader(204) + } + if r.Method == "GET" && r.URL.Path == "/redirect_test/301" { + http.Redirect(w, r, "/redirect_test/302", 301) + } + if r.Method == "GET" && r.URL.Path == "/redirect_test/302" { + http.Redirect(w, r, "/redirect_test/303", 302) + } + if r.Method == "GET" && r.URL.Path == "/redirect_test/303" { + http.Redirect(w, r, "/redirect_test/307", 303) + } + if r.Method == "GET" && r.URL.Path == "/redirect_test/307" { + http.Redirect(w, r, "/getquery", 307) + } + if r.Method == "GET" && r.URL.Path == "/redirect_test/destination" { + http.Redirect(w, r, ts.URL+"/destination", 301) + } + if r.Method == "GET" && r.URL.Path == "/getcookies" { + defer r.Body.Close() + w.WriteHeader(200) + fmt.Fprint(w, requestHeaders.Get("Cookie")) + } + if r.Method == "GET" && r.URL.Path == "/setcookies" { + defer r.Body.Close() + w.Header().Add("Set-Cookie", "foobar=42 ; Path=/") + w.WriteHeader(200) + } + if r.Method == "GET" && r.URL.Path == "/compressed" { + defer r.Body.Close() + b := "{\"foo\":\"bar\",\"fuu\":\"baz\"}" + gw := gzip.NewWriter(w) + defer gw.Close() + if strings.Contains(r.Header.Get("Content-Encoding"), "gzip") { + w.Header().Add("Content-Encoding", "gzip") + } + w.WriteHeader(200) + gw.Write([]byte(b)) + } + if r.Method == "GET" && r.URL.Path == "/compressed_deflate" { + defer r.Body.Close() + b := "{\"foo\":\"bar\",\"fuu\":\"baz\"}" + gw := zlib.NewWriter(w) + defer gw.Close() + if strings.Contains(r.Header.Get("Content-Encoding"), "deflate") { + w.Header().Add("Content-Encoding", "deflate") + } + w.WriteHeader(200) + gw.Write([]byte(b)) + } + if r.Method == "GET" && r.URL.Path == "/compressed_and_return_compressed_without_header" { + defer r.Body.Close() + b := "{\"foo\":\"bar\",\"fuu\":\"baz\"}" + gw := gzip.NewWriter(w) + defer gw.Close() + w.WriteHeader(200) + gw.Write([]byte(b)) + } + if r.Method == "GET" && r.URL.Path == "/compressed_deflate_and_return_compressed_without_header" { + defer r.Body.Close() + b := "{\"foo\":\"bar\",\"fuu\":\"baz\"}" + gw := zlib.NewWriter(w) + defer gw.Close() + w.WriteHeader(200) + gw.Write([]byte(b)) + } + if r.Method == "POST" && r.URL.Path == "/compressed" && r.Header.Get("Content-Encoding") == "gzip" { + defer r.Body.Close() + gr, _ := gzip.NewReader(r.Body) + defer gr.Close() + b, _ := ioutil.ReadAll(gr) + w.WriteHeader(201) + w.Write(b) + } + if r.Method == "POST" && r.URL.Path == "/compressed_deflate" && r.Header.Get("Content-Encoding") == "deflate" { + defer r.Body.Close() + gr, _ := zlib.NewReader(r.Body) + defer gr.Close() + b, _ := ioutil.ReadAll(gr) + w.WriteHeader(201) + w.Write(b) + } + if r.Method == "POST" && r.URL.Path == "/compressed_and_return_compressed" { + defer r.Body.Close() + w.Header().Add("Content-Encoding", "gzip") + w.WriteHeader(201) + io.Copy(w, r.Body) + } + if r.Method == "POST" && r.URL.Path == "/compressed_deflate_and_return_compressed" { + defer r.Body.Close() + w.Header().Add("Content-Encoding", "deflate") + w.WriteHeader(201) + io.Copy(w, r.Body) + } + if r.Method == "POST" && r.URL.Path == "/compressed_deflate_and_return_compressed_without_header" { + defer r.Body.Close() + w.WriteHeader(201) + io.Copy(w, r.Body) + } + if r.Method == "POST" && r.URL.Path == "/compressed_and_return_compressed_without_header" { + defer r.Body.Close() + w.WriteHeader(201) + io.Copy(w, r.Body) + } + })) + }) + + g.After(func() { + ts.Close() + }) + + g.Describe("GET", func() { + + g.It("Should do a GET", func() { + res, err := Request{Uri: ts.URL + "/foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("bar")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should return ContentLength", func() { + res, err := Request{Uri: ts.URL + "/foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("bar")) + Expect(res.StatusCode).Should(Equal(200)) + Expect(res.ContentLength).Should(Equal(int64(3))) + }) + + g.It("Should do a GET with querystring", func() { + res, err := Request{ + Uri: ts.URL + "/getquery", + QueryString: query, + }.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("/getquery?limit=3&skip=5")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should support url.Values in querystring", func() { + res, err := Request{ + Uri: ts.URL + "/getquery", + QueryString: valuesQuery, + }.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("/getquery?friend=jonas&friend=peter&name=marcos")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should support sending string body", func() { + res, err := Request{Uri: ts.URL + "/getbody", Body: "foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Shoulds support sending a Reader body", func() { + res, err := Request{Uri: ts.URL + "/getbody", Body: strings.NewReader("foo")}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Support sending any object that is json encodable", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Uri: ts.URL + "/getbody", Body: obj}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Support sending an array of bytes body", func() { + bdy := []byte{'f', 'o', 'o'} + res, err := Request{Uri: ts.URL + "/getbody", Body: bdy}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should return an error when body is not JSON encodable", func() { + res, err := Request{Uri: ts.URL + "/getbody", Body: math.NaN()}.Do() + + Expect(res).Should(BeNil()) + Expect(err).ShouldNot(BeNil()) + }) + + g.It("Should return a gzip reader if Content-Encoding is 'gzip'", func() { + res, err := Request{Uri: ts.URL + "/compressed", Compression: Gzip()}.Do() + b, _ := ioutil.ReadAll(res.Body) + Expect(err).Should(BeNil()) + Expect(res.Body.compressedReader).ShouldNot(BeNil()) + Expect(res.Body.reader).ShouldNot(BeNil()) + Expect(string(b)).Should(Equal("{\"foo\":\"bar\",\"fuu\":\"baz\"}")) + Expect(res.Body.compressedReader).ShouldNot(BeNil()) + Expect(res.Body.reader).ShouldNot(BeNil()) + }) + + g.It("Should close reader and compresserReader on Body close", func() { + res, err := Request{Uri: ts.URL + "/compressed", Compression: Gzip()}.Do() + Expect(err).Should(BeNil()) + + _, e := ioutil.ReadAll(res.Body.reader) + Expect(e).Should(BeNil()) + _, e = ioutil.ReadAll(res.Body.compressedReader) + Expect(e).Should(BeNil()) + + _, e = ioutil.ReadAll(res.Body.reader) + //when reading body again it doesnt error + Expect(e).Should(BeNil()) + + res.Body.Close() + _, e = ioutil.ReadAll(res.Body.reader) + //error because body is already closed + Expect(e).ShouldNot(BeNil()) + + _, e = ioutil.ReadAll(res.Body.compressedReader) + //compressedReaders dont error on reading when closed + Expect(e).Should(BeNil()) + }) + + g.It("Should not return a gzip reader if Content-Encoding is not 'gzip'", func() { + res, err := Request{Uri: ts.URL + "/compressed_and_return_compressed_without_header", Compression: Gzip()}.Do() + b, _ := ioutil.ReadAll(res.Body) + Expect(err).Should(BeNil()) + Expect(string(b)).ShouldNot(Equal("{\"foo\":\"bar\",\"fuu\":\"baz\"}")) + }) + + g.It("Should return a deflate reader if Content-Encoding is 'deflate'", func() { + res, err := Request{Uri: ts.URL + "/compressed_deflate", Compression: Deflate()}.Do() + b, _ := ioutil.ReadAll(res.Body) + Expect(err).Should(BeNil()) + Expect(string(b)).Should(Equal("{\"foo\":\"bar\",\"fuu\":\"baz\"}")) + }) + + g.It("Should not return a delfate reader if Content-Encoding is not 'deflate'", func() { + res, err := Request{Uri: ts.URL + "/compressed_deflate_and_return_compressed_without_header", Compression: Deflate()}.Do() + b, _ := ioutil.ReadAll(res.Body) + Expect(err).Should(BeNil()) + Expect(string(b)).ShouldNot(Equal("{\"foo\":\"bar\",\"fuu\":\"baz\"}")) + }) + + g.It("Should return a deflate reader when using zlib if Content-Encoding is 'deflate'", func() { + res, err := Request{Uri: ts.URL + "/compressed_deflate", Compression: Zlib()}.Do() + b, _ := ioutil.ReadAll(res.Body) + Expect(err).Should(BeNil()) + Expect(string(b)).Should(Equal("{\"foo\":\"bar\",\"fuu\":\"baz\"}")) + }) + + g.It("Should not return a delfate reader when using zlib if Content-Encoding is not 'deflate'", func() { + res, err := Request{Uri: ts.URL + "/compressed_deflate_and_return_compressed_without_header", Compression: Zlib()}.Do() + b, _ := ioutil.ReadAll(res.Body) + Expect(err).Should(BeNil()) + Expect(string(b)).ShouldNot(Equal("{\"foo\":\"bar\",\"fuu\":\"baz\"}")) + }) + + g.It("Should send cookies from the cookiejar", func() { + uri, err := url.Parse(ts.URL + "/getcookies") + Expect(err).Should(BeNil()) + + jar, err := cookiejar.New(nil) + Expect(err).Should(BeNil()) + + jar.SetCookies(uri, []*http.Cookie{ + { + Name: "bar", + Value: "foo", + Path: "/", + }, + }) + + res, err := Request{ + Uri: ts.URL + "/getcookies", + CookieJar: jar, + }.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("bar=foo")) + Expect(res.StatusCode).Should(Equal(200)) + Expect(res.ContentLength).Should(Equal(int64(7))) + }) + + g.It("Should send cookies added with .AddCookie", func() { + c1 := &http.Cookie{Name: "c1", Value: "v1"} + c2 := &http.Cookie{Name: "c2", Value: "v2"} + + req := Request{Uri: ts.URL + "/getcookies"} + req.AddCookie(c1) + req.AddCookie(c2) + + res, err := req.Do() + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("c1=v1; c2=v2")) + Expect(res.StatusCode).Should(Equal(200)) + Expect(res.ContentLength).Should(Equal(int64(12))) + }) + + g.It("Should send cookies added with .WithCookie", func() { + c1 := &http.Cookie{Name: "c1", Value: "v2"} + c2 := &http.Cookie{Name: "c2", Value: "v3"} + + res, err := Request{Uri: ts.URL + "/getcookies"}. + WithCookie(c1). + WithCookie(c2). + Do() + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("c1=v2; c2=v3")) + Expect(res.StatusCode).Should(Equal(200)) + Expect(res.ContentLength).Should(Equal(int64(12))) + }) + + g.It("Should populate the cookiejar", func() { + uri, err := url.Parse(ts.URL + "/setcookies") + Expect(err).Should(BeNil()) + + jar, _ := cookiejar.New(nil) + Expect(err).Should(BeNil()) + + res, err := Request{ + Uri: ts.URL + "/setcookies", + CookieJar: jar, + }.Do() + + Expect(err).Should(BeNil()) + + Expect(res.Header.Get("Set-Cookie")).Should(Equal("foobar=42 ; Path=/")) + + cookies := jar.Cookies(uri) + Expect(len(cookies)).Should(Equal(1)) + + cookie := cookies[0] + Expect(*cookie).Should(Equal(http.Cookie{ + Name: "foobar", + Value: "42", + })) + }) + }) + + g.Describe("POST", func() { + g.It("Should send a string", func() { + res, err := Request{Method: "POST", Uri: ts.URL, Body: "foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo")) + Expect(res.StatusCode).Should(Equal(201)) + Expect(res.Header.Get("Location")).Should(Equal(ts.URL + "/123")) + }) + + g.It("Should send a Reader", func() { + res, err := Request{Method: "POST", Uri: ts.URL, Body: strings.NewReader("foo")}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo")) + Expect(res.StatusCode).Should(Equal(201)) + Expect(res.Header.Get("Location")).Should(Equal(ts.URL + "/123")) + }) + + g.It("Send any object that is json encodable", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL, Body: obj}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + Expect(res.Header.Get("Location")).Should(Equal(ts.URL + "/123")) + }) + + g.It("Send an array of bytes", func() { + bdy := []byte{'f', 'o', 'o'} + res, err := Request{Method: "POST", Uri: ts.URL, Body: bdy}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo")) + Expect(res.StatusCode).Should(Equal(201)) + Expect(res.Header.Get("Location")).Should(Equal(ts.URL + "/123")) + }) + + g.It("Should return an error when body is not JSON encodable", func() { + res, err := Request{Method: "POST", Uri: ts.URL, Body: math.NaN()}.Do() + + Expect(res).Should(BeNil()) + Expect(err).ShouldNot(BeNil()) + }) + + g.It("Should do a POST with querystring", func() { + bdy := []byte{'f', 'o', 'o'} + res, err := Request{ + Method: "POST", + Uri: ts.URL + "/getquery", + Body: bdy, + QueryString: query, + }.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("/getquery?limit=3&skip=5")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should send body as gzip if compressed", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed", Body: obj, Compression: Gzip()}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as deflate if compressed", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_deflate", Body: obj, Compression: Deflate()}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as deflate using zlib if compressed", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_deflate", Body: obj, Compression: Zlib()}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as gzip if compressed and parse return body", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_and_return_compressed", Body: obj, Compression: Gzip()}.Do() + + Expect(err).Should(BeNil()) + b, _ := ioutil.ReadAll(res.Body) + Expect(string(b)).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as deflate if compressed and parse return body", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_deflate_and_return_compressed", Body: obj, Compression: Deflate()}.Do() + + Expect(err).Should(BeNil()) + b, _ := ioutil.ReadAll(res.Body) + Expect(string(b)).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as deflate using zlib if compressed and parse return body", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_deflate_and_return_compressed", Body: obj, Compression: Zlib()}.Do() + + Expect(err).Should(BeNil()) + b, _ := ioutil.ReadAll(res.Body) + Expect(string(b)).Should(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as gzip if compressed and not parse return body if header not set ", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_and_return_compressed_without_header", Body: obj, Compression: Gzip()}.Do() + + Expect(err).Should(BeNil()) + b, _ := ioutil.ReadAll(res.Body) + Expect(string(b)).ShouldNot(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as deflate if compressed and not parse return body if header not set ", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_deflate_and_return_compressed_without_header", Body: obj, Compression: Deflate()}.Do() + + Expect(err).Should(BeNil()) + b, _ := ioutil.ReadAll(res.Body) + Expect(string(b)).ShouldNot(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + + g.It("Should send body as deflate using zlib if compressed and not parse return body if header not set ", func() { + obj := map[string]string{"foo": "bar"} + res, err := Request{Method: "POST", Uri: ts.URL + "/compressed_deflate_and_return_compressed_without_header", Body: obj, Compression: Zlib()}.Do() + + Expect(err).Should(BeNil()) + b, _ := ioutil.ReadAll(res.Body) + Expect(string(b)).ShouldNot(Equal(`{"foo":"bar"}`)) + Expect(res.StatusCode).Should(Equal(201)) + }) + }) + + g.It("Should do a PUT", func() { + res, err := Request{Method: "PUT", Uri: ts.URL + "/foo/123", Body: "foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should do a DELETE", func() { + res, err := Request{Method: "DELETE", Uri: ts.URL + "/foo/123"}.Do() + + Expect(err).Should(BeNil()) + Expect(res.StatusCode).Should(Equal(204)) + }) + + g.It("Should do a OPTIONS", func() { + res, err := Request{Method: "OPTIONS", Uri: ts.URL + "/foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("bar")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should do a PATCH", func() { + res, err := Request{Method: "PATCH", Uri: ts.URL + "/foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("bar")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should do a TRACE", func() { + res, err := Request{Method: "TRACE", Uri: ts.URL + "/foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("bar")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should do a custom method", func() { + res, err := Request{Method: "FOOBAR", Uri: ts.URL + "/foo"}.Do() + + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(str).Should(Equal("bar")) + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.Describe("Responses", func() { + g.It("Should handle strings", func() { + res, _ := Request{Method: "POST", Uri: ts.URL, Body: "foo bar"}.Do() + + str, _ := res.Body.ToString() + Expect(str).Should(Equal("foo bar")) + }) + + g.It("Should handle io.ReaderCloser", func() { + res, _ := Request{Method: "POST", Uri: ts.URL, Body: "foo bar"}.Do() + + body, _ := ioutil.ReadAll(res.Body) + Expect(string(body)).Should(Equal("foo bar")) + }) + + g.It("Should handle parsing JSON", func() { + res, _ := Request{Method: "POST", Uri: ts.URL, Body: `{"foo": "bar"}`}.Do() + + var foobar map[string]string + + res.Body.FromJsonTo(&foobar) + + Expect(foobar).Should(Equal(map[string]string{"foo": "bar"})) + }) + + g.It("Should return the original request response", func() { + res, _ := Request{Method: "POST", Uri: ts.URL, Body: `{"foo": "bar"}`}.Do() + + Expect(res.Response).ShouldNot(BeNil()) + }) + }) + g.Describe("Redirects", func() { + g.It("Should not follow by default", func() { + res, _ := Request{ + Uri: ts.URL + "/redirect_test/301", + }.Do() + Expect(res.StatusCode).Should(Equal(301)) + }) + + g.It("Should not follow if method is explicitly specified", func() { + res, err := Request{ + Method: "GET", + Uri: ts.URL + "/redirect_test/301", + }.Do() + Expect(res.StatusCode).Should(Equal(301)) + Expect(err).ShouldNot(HaveOccurred()) + }) + + g.It("Should throw an error if MaxRedirect limit is exceeded", func() { + res, err := Request{ + Method: "GET", + MaxRedirects: 1, + Uri: ts.URL + "/redirect_test/301", + }.Do() + Expect(res.StatusCode).Should(Equal(302)) + Expect(err).Should(HaveOccurred()) + }) + + g.It("Should copy request headers headers when redirecting if specified", func() { + req := Request{ + Method: "GET", + Uri: ts.URL + "/redirect_test/301", + MaxRedirects: 4, + RedirectHeaders: true, + } + req.AddHeader("Testheader", "TestValue") + res, _ := req.Do() + Expect(res.StatusCode).Should(Equal(200)) + Expect(requestHeaders.Get("Testheader")).Should(Equal("TestValue")) + }) + + g.It("Should follow only specified number of MaxRedirects", func() { + res, _ := Request{ + Uri: ts.URL + "/redirect_test/301", + MaxRedirects: 1, + }.Do() + Expect(res.StatusCode).Should(Equal(302)) + res, _ = Request{ + Uri: ts.URL + "/redirect_test/301", + MaxRedirects: 2, + }.Do() + Expect(res.StatusCode).Should(Equal(303)) + res, _ = Request{ + Uri: ts.URL + "/redirect_test/301", + MaxRedirects: 3, + }.Do() + Expect(res.StatusCode).Should(Equal(307)) + res, _ = Request{ + Uri: ts.URL + "/redirect_test/301", + MaxRedirects: 4, + }.Do() + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should return final URL of the response when redirecting", func() { + res, _ := Request{ + Uri: ts.URL + "/redirect_test/destination", + MaxRedirects: 2, + }.Do() + Expect(res.Uri).Should(Equal(ts.URL + "/destination")) + }) + }) + }) + + g.Describe("Timeouts", func() { + + g.Describe("Connection timeouts", func() { + g.It("Should connect timeout after a default of 1000 ms", func() { + start := time.Now() + res, err := Request{Uri: "http://10.255.255.1"}.Do() + elapsed := time.Since(start) + + Expect(elapsed).Should(BeNumerically("<", 1100*time.Millisecond)) + Expect(elapsed).Should(BeNumerically(">=", 1000*time.Millisecond)) + Expect(res).Should(BeNil()) + Expect(err.(*Error).Timeout()).Should(BeTrue()) + }) + g.It("Should connect timeout after a custom amount of time", func() { + SetConnectTimeout(100 * time.Millisecond) + start := time.Now() + res, err := Request{Uri: "http://10.255.255.1"}.Do() + elapsed := time.Since(start) + + Expect(elapsed).Should(BeNumerically("<", 150*time.Millisecond)) + Expect(elapsed).Should(BeNumerically(">=", 100*time.Millisecond)) + Expect(res).Should(BeNil()) + Expect(err.(*Error).Timeout()).Should(BeTrue()) + }) + g.It("Should connect timeout after a custom amount of time even with method set", func() { + SetConnectTimeout(100 * time.Millisecond) + start := time.Now() + request := Request{ + Uri: "http://10.255.255.1", + Method: "GET", + } + res, err := request.Do() + elapsed := time.Since(start) + + Expect(elapsed).Should(BeNumerically("<", 150*time.Millisecond)) + Expect(elapsed).Should(BeNumerically(">=", 100*time.Millisecond)) + Expect(res).Should(BeNil()) + Expect(err.(*Error).Timeout()).Should(BeTrue()) + }) + }) + + g.Describe("Request timeout", func() { + var ts *httptest.Server + stop := make(chan bool) + + g.Before(func() { + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-stop + // just wait for someone to tell you when to end the request. this is used to simulate a slow server + })) + }) + g.After(func() { + stop <- true + ts.Close() + }) + g.It("Should request timeout after a custom amount of time", func() { + SetConnectTimeout(1000 * time.Millisecond) + + start := time.Now() + res, err := Request{Uri: ts.URL, Timeout: 500 * time.Millisecond}.Do() + elapsed := time.Since(start) + + Expect(elapsed).Should(BeNumerically("<", 550*time.Millisecond)) + Expect(elapsed).Should(BeNumerically(">=", 500*time.Millisecond)) + Expect(res).Should(BeNil()) + Expect(err.(*Error).Timeout()).Should(BeTrue()) + }) + }) + }) + + g.Describe("Misc", func() { + g.It("Should set default golang user agent when not explicitly passed", func() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Expect(r.Header.Get("User-Agent")).ShouldNot(BeZero()) + Expect(r.Host).Should(Equal("foobar.com")) + + w.WriteHeader(200) + })) + defer ts.Close() + + req := Request{Uri: ts.URL, Host: "foobar.com"} + res, err := req.Do() + Expect(err).ShouldNot(HaveOccurred()) + + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should offer to set request headers", func() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Expect(r.Header.Get("User-Agent")).Should(Equal("foobaragent")) + Expect(r.Host).Should(Equal("foobar.com")) + Expect(r.Header.Get("Accept")).Should(Equal("application/json")) + Expect(r.Header.Get("Content-Type")).Should(Equal("application/json")) + Expect(r.Header.Get("X-Custom")).Should(Equal("foobar")) + Expect(r.Header.Get("X-Custom2")).Should(Equal("barfoo")) + + w.WriteHeader(200) + })) + defer ts.Close() + + req := Request{Uri: ts.URL, Accept: "application/json", ContentType: "application/json", UserAgent: "foobaragent", Host: "foobar.com"} + req.AddHeader("X-Custom", "foobar") + res, _ := req.WithHeader("X-Custom2", "barfoo").Do() + + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should call hook before request", func() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + Expect(r.Header.Get("X-Custom")).Should(Equal("foobar")) + + w.WriteHeader(200) + })) + defer ts.Close() + + hook := func(goreq *Request, httpreq *http.Request) { + httpreq.Header.Add("X-Custom", "foobar") + } + req := Request{Uri: ts.URL, OnBeforeRequest: hook} + res, _ := req.Do() + + Expect(res.StatusCode).Should(Equal(200)) + }) + + g.It("Should not create a body by defualt", func() { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := ioutil.ReadAll(r.Body) + Expect(b).Should(HaveLen(0)) + w.WriteHeader(200) + })) + defer ts.Close() + + req := Request{Uri: ts.URL, Host: "foobar.com"} + req.Do() + }) + g.It("Should change transport TLS config if Request.Insecure is set", func() { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + req := Request{ + Insecure: true, + Uri: ts.URL, + Host: "foobar.com", + } + res, _ := req.Do() + + Expect(DefaultClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true)) + Expect(res.StatusCode).Should(Equal(200)) + }) + g.It("Should work if a different transport is specified", func() { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + var currentTransport = DefaultTransport + DefaultTransport = &http.Transport{Dial: DefaultDialer.Dial} + + req := Request{ + Insecure: true, + Uri: ts.URL, + Host: "foobar.com", + } + res, _ := req.Do() + + Expect(DefaultClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true)) + Expect(res.StatusCode).Should(Equal(200)) + + DefaultTransport = currentTransport + + }) + g.It("GetRequest should return the underlying httpRequest ", func() { + req := Request{ + Host: "foobar.com", + } + + request, _ := req.NewRequest() + Expect(request).ShouldNot(BeNil()) + Expect(request.Host).Should(Equal(req.Host)) + }) + + g.It("Response should allow to cancel in-flight request", func() { + unblockc := make(chan bool) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello") + w.(http.Flusher).Flush() + <-unblockc + })) + defer ts.Close() + defer close(unblockc) + + req := Request{ + Insecure: true, + Uri: ts.URL, + Host: "foobar.com", + } + res, _ := req.Do() + res.CancelRequest() + _, err := ioutil.ReadAll(res.Body) + g.Assert(err != nil).IsTrue() + }) + }) + + g.Describe("Errors", func() { + var ts *httptest.Server + + g.Before(func() { + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" && r.URL.Path == "/" { + w.Header().Add("Location", ts.URL+"/123") + w.WriteHeader(201) + io.Copy(w, r.Body) + } + })) + }) + + g.After(func() { + ts.Close() + }) + g.It("Should throw an error when FromJsonTo fails", func() { + res, _ := Request{Method: "POST", Uri: ts.URL, Body: `{"foo" "bar"}`}.Do() + var foobar map[string]string + + err := res.Body.FromJsonTo(&foobar) + Expect(err).Should(HaveOccurred()) + }) + g.It("Should handle Url parsing errors", func() { + _, err := Request{Uri: ":"}.Do() + + Expect(err).ShouldNot(BeNil()) + }) + g.It("Should handle DNS errors", func() { + _, err := Request{Uri: "http://.localhost"}.Do() + Expect(err).ShouldNot(BeNil()) + }) + }) + + g.Describe("Proxy", func() { + var ts *httptest.Server + var lastReq *http.Request + g.Before(func() { + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.URL.Path == "/" { + lastReq = r + w.Header().Add("x-forwarded-for", "test") + w.Header().Add("Set-Cookie", "foo=bar") + w.WriteHeader(200) + w.Write([]byte("")) + } else if r.Method == "GET" && r.URL.Path == "/redirect_test/301" { + http.Redirect(w, r, "/", 301) + } + })) + + }) + + g.BeforeEach(func() { + lastReq = nil + }) + + g.After(func() { + ts.Close() + }) + + g.It("Should use Proxy", func() { + proxiedHost := "www.google.com" + res, err := Request{Uri: "http://" + proxiedHost, Proxy: ts.URL}.Do() + Expect(err).Should(BeNil()) + Expect(res.Header.Get("x-forwarded-for")).Should(Equal("test")) + Expect(lastReq).ShouldNot(BeNil()) + Expect(lastReq.Host).Should(Equal(proxiedHost)) + }) + + g.It("Should not redirect if MaxRedirects is not set", func() { + res, err := Request{Uri: ts.URL + "/redirect_test/301", Proxy: ts.URL}.Do() + Expect(err).ShouldNot(HaveOccurred()) + Expect(res.StatusCode).Should(Equal(301)) + }) + + g.It("Should use Proxy authentication", func() { + proxiedHost := "www.google.com" + uri := strings.Replace(ts.URL, "http://", "http://user:pass@", -1) + res, err := Request{Uri: "http://" + proxiedHost, Proxy: uri}.Do() + Expect(err).Should(BeNil()) + Expect(res.Header.Get("x-forwarded-for")).Should(Equal("test")) + Expect(lastReq).ShouldNot(BeNil()) + Expect(lastReq.Header.Get("Proxy-Authorization")).Should(Equal("Basic dXNlcjpwYXNz")) + }) + + g.It("Should propagate cookies", func() { + proxiedHost, _ := url.Parse("http://www.google.com") + jar, _ := cookiejar.New(nil) + res, err := Request{Uri: proxiedHost.String(), Proxy: ts.URL, CookieJar: jar}.Do() + Expect(err).Should(BeNil()) + Expect(res.Header.Get("x-forwarded-for")).Should(Equal("test")) + + Expect(jar.Cookies(proxiedHost)).Should(HaveLen(1)) + Expect(jar.Cookies(proxiedHost)[0].Name).Should(Equal("foo")) + Expect(jar.Cookies(proxiedHost)[0].Value).Should(Equal("bar")) + }) + + }) + + g.Describe("BasicAuth", func() { + var ts *httptest.Server + + g.Before(func() { + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/basic_auth" { + auth_array := r.Header["Authorization"] + if len(auth_array) > 0 { + auth := strings.TrimSpace(auth_array[0]) + w.WriteHeader(200) + fmt.Fprint(w, auth) + } else { + w.WriteHeader(401) + fmt.Fprint(w, "private") + } + } + })) + + }) + + g.After(func() { + ts.Close() + }) + + g.It("Should support basic http authorization", func() { + res, err := Request{ + Uri: ts.URL + "/basic_auth", + BasicAuthUsername: "username", + BasicAuthPassword: "password", + }.Do() + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(res.StatusCode).Should(Equal(200)) + expectedStr := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) + Expect(str).Should(Equal(expectedStr)) + }) + + g.It("Should fail when basic http authorization is required and not provided", func() { + res, err := Request{ + Uri: ts.URL + "/basic_auth", + }.Do() + Expect(err).Should(BeNil()) + str, _ := res.Body.ToString() + Expect(res.StatusCode).Should(Equal(401)) + Expect(str).Should(Equal("private")) + }) + }) + }) +} + +func Test_paramParse(t *testing.T) { + type Form struct { + A string + B string + c string + } + + type AnnotedForm struct { + Foo string `url:"foo_bar"` + Baz string `url:"bad,omitempty"` + Norf string `url:"norf,omitempty"` + Qux string `url:"-"` + } + + type EmbedForm struct { + AnnotedForm `url:",squash"` + Form `url:",squash"` + Corge string `url:"corge"` + } + + g := Goblin(t) + RegisterFailHandler(func(m string, _ ...int) { g.Fail(m) }) + var form = Form{} + var aform = AnnotedForm{} + var eform = EmbedForm{} + var values = url.Values{} + const result = "a=1&b=2" + g.Describe("QueryString ParamParse", func() { + g.Before(func() { + form.A = "1" + form.B = "2" + form.c = "3" + aform.Foo = "xyz" + aform.Norf = "abc" + aform.Qux = "def" + eform.Form = form + eform.AnnotedForm = aform + eform.Corge = "xxx" + values.Add("a", "1") + values.Add("b", "2") + }) + g.It("Should accept struct and ignores unexported field", func() { + str, err := paramParse(form) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal(result)) + }) + g.It("Should accept struct and use the field annotations", func() { + str, err := paramParse(aform) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal("foo_bar=xyz&norf=abc")) + }) + g.It("Should accept pointer of struct", func() { + str, err := paramParse(&form) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal(result)) + }) + g.It("Should accept recursive pointer of struct", func() { + f := &form + ff := &f + str, err := paramParse(ff) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal(result)) + }) + g.It("Should accept embedded struct", func() { + str, err := paramParse(eform) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal("a=1&b=2&corge=xxx&foo_bar=xyz&norf=abc")) + }) + g.It("Should accept interface{} which forcely converted by struct", func() { + str, err := paramParse(interface{}(&form)) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal(result)) + }) + + g.It("Should accept url.Values", func() { + str, err := paramParse(values) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal(result)) + }) + g.It("Should accept &url.Values", func() { + str, err := paramParse(&values) + Expect(err).Should(BeNil()) + Expect(str).Should(Equal(result)) + }) + }) + +} diff --git a/Godeps/_workspace/src/github.com/franela/goreq/tags.go b/Godeps/_workspace/src/github.com/franela/goreq/tags.go new file mode 100644 index 0000000..ffe8789 --- /dev/null +++ b/Godeps/_workspace/src/github.com/franela/goreq/tags.go @@ -0,0 +1,64 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package goreq + +import ( + "strings" + "unicode" +) + +// tagOptions is the string following a comma in a struct field's "json" +// tag, or the empty string. It does not include the leading comma. +type tagOptions string + +// parseTag splits a struct field's json tag into its name and +// comma-separated options. +func parseTag(tag string) (string, tagOptions) { + if idx := strings.Index(tag, ","); idx != -1 { + return tag[:idx], tagOptions(tag[idx+1:]) + } + return tag, tagOptions("") +} + +// Contains reports whether a comma-separated list of options +// contains a particular substr flag. substr must be surrounded by a +// string boundary or commas. +func (o tagOptions) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + s := string(o) + for s != "" { + var next string + i := strings.Index(s, ",") + if i >= 0 { + s, next = s[:i], s[i+1:] + } + if s == optionName { + return true + } + s = next + } + return false +} + +func isValidTag(s string) bool { + if s == "" { + return false + } + for _, c := range s { + switch { + case strings.ContainsRune("!#$%&()*+-./:<=>?@[]^_{|}~ ", c): + // Backslash and quote chars are reserved, but + // otherwise any punctuation chars are allowed + // in a tag name. + default: + if !unicode.IsLetter(c) && !unicode.IsDigit(c) { + return false + } + } + } + return true +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/Makefile b/Godeps/_workspace/src/github.com/golang/protobuf/proto/Makefile new file mode 100644 index 0000000..f1f0656 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/Makefile @@ -0,0 +1,43 @@ +# Go support for Protocol Buffers - Google's data interchange format +# +# Copyright 2010 The Go Authors. All rights reserved. +# https://github.com/golang/protobuf +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +install: + go install + +test: install generate-test-pbs + go test + + +generate-test-pbs: + make install + make -C testdata + protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata:. proto3_proto/proto3.proto + make diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/all_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/all_test.go new file mode 100644 index 0000000..49ce01f --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/all_test.go @@ -0,0 +1,2156 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "math" + "math/rand" + "reflect" + "runtime/debug" + "strings" + "testing" + "time" + + . "github.com/golang/protobuf/proto" + . "github.com/golang/protobuf/proto/testdata" +) + +var globalO *Buffer + +func old() *Buffer { + if globalO == nil { + globalO = NewBuffer(nil) + } + globalO.Reset() + return globalO +} + +func equalbytes(b1, b2 []byte, t *testing.T) { + if len(b1) != len(b2) { + t.Errorf("wrong lengths: 2*%d != %d", len(b1), len(b2)) + return + } + for i := 0; i < len(b1); i++ { + if b1[i] != b2[i] { + t.Errorf("bad byte[%d]:%x %x: %s %s", i, b1[i], b2[i], b1, b2) + } + } +} + +func initGoTestField() *GoTestField { + f := new(GoTestField) + f.Label = String("label") + f.Type = String("type") + return f +} + +// These are all structurally equivalent but the tag numbers differ. +// (It's remarkable that required, optional, and repeated all have +// 8 letters.) +func initGoTest_RequiredGroup() *GoTest_RequiredGroup { + return &GoTest_RequiredGroup{ + RequiredField: String("required"), + } +} + +func initGoTest_OptionalGroup() *GoTest_OptionalGroup { + return &GoTest_OptionalGroup{ + RequiredField: String("optional"), + } +} + +func initGoTest_RepeatedGroup() *GoTest_RepeatedGroup { + return &GoTest_RepeatedGroup{ + RequiredField: String("repeated"), + } +} + +func initGoTest(setdefaults bool) *GoTest { + pb := new(GoTest) + if setdefaults { + pb.F_BoolDefaulted = Bool(Default_GoTest_F_BoolDefaulted) + pb.F_Int32Defaulted = Int32(Default_GoTest_F_Int32Defaulted) + pb.F_Int64Defaulted = Int64(Default_GoTest_F_Int64Defaulted) + pb.F_Fixed32Defaulted = Uint32(Default_GoTest_F_Fixed32Defaulted) + pb.F_Fixed64Defaulted = Uint64(Default_GoTest_F_Fixed64Defaulted) + pb.F_Uint32Defaulted = Uint32(Default_GoTest_F_Uint32Defaulted) + pb.F_Uint64Defaulted = Uint64(Default_GoTest_F_Uint64Defaulted) + pb.F_FloatDefaulted = Float32(Default_GoTest_F_FloatDefaulted) + pb.F_DoubleDefaulted = Float64(Default_GoTest_F_DoubleDefaulted) + pb.F_StringDefaulted = String(Default_GoTest_F_StringDefaulted) + pb.F_BytesDefaulted = Default_GoTest_F_BytesDefaulted + pb.F_Sint32Defaulted = Int32(Default_GoTest_F_Sint32Defaulted) + pb.F_Sint64Defaulted = Int64(Default_GoTest_F_Sint64Defaulted) + } + + pb.Kind = GoTest_TIME.Enum() + pb.RequiredField = initGoTestField() + pb.F_BoolRequired = Bool(true) + pb.F_Int32Required = Int32(3) + pb.F_Int64Required = Int64(6) + pb.F_Fixed32Required = Uint32(32) + pb.F_Fixed64Required = Uint64(64) + pb.F_Uint32Required = Uint32(3232) + pb.F_Uint64Required = Uint64(6464) + pb.F_FloatRequired = Float32(3232) + pb.F_DoubleRequired = Float64(6464) + pb.F_StringRequired = String("string") + pb.F_BytesRequired = []byte("bytes") + pb.F_Sint32Required = Int32(-32) + pb.F_Sint64Required = Int64(-64) + pb.Requiredgroup = initGoTest_RequiredGroup() + + return pb +} + +func fail(msg string, b *bytes.Buffer, s string, t *testing.T) { + data := b.Bytes() + ld := len(data) + ls := len(s) / 2 + + fmt.Printf("fail %s ld=%d ls=%d\n", msg, ld, ls) + + // find the interesting spot - n + n := ls + if ld < ls { + n = ld + } + j := 0 + for i := 0; i < n; i++ { + bs := hex(s[j])*16 + hex(s[j+1]) + j += 2 + if data[i] == bs { + continue + } + n = i + break + } + l := n - 10 + if l < 0 { + l = 0 + } + h := n + 10 + + // find the interesting spot - n + fmt.Printf("is[%d]:", l) + for i := l; i < h; i++ { + if i >= ld { + fmt.Printf(" --") + continue + } + fmt.Printf(" %.2x", data[i]) + } + fmt.Printf("\n") + + fmt.Printf("sb[%d]:", l) + for i := l; i < h; i++ { + if i >= ls { + fmt.Printf(" --") + continue + } + bs := hex(s[j])*16 + hex(s[j+1]) + j += 2 + fmt.Printf(" %.2x", bs) + } + fmt.Printf("\n") + + t.Fail() + + // t.Errorf("%s: \ngood: %s\nbad: %x", msg, s, b.Bytes()) + // Print the output in a partially-decoded format; can + // be helpful when updating the test. It produces the output + // that is pasted, with minor edits, into the argument to verify(). + // data := b.Bytes() + // nesting := 0 + // for b.Len() > 0 { + // start := len(data) - b.Len() + // var u uint64 + // u, err := DecodeVarint(b) + // if err != nil { + // fmt.Printf("decode error on varint:", err) + // return + // } + // wire := u & 0x7 + // tag := u >> 3 + // switch wire { + // case WireVarint: + // v, err := DecodeVarint(b) + // if err != nil { + // fmt.Printf("decode error on varint:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" // field %d, encoding %d, value %d\n", + // data[start:len(data)-b.Len()], tag, wire, v) + // case WireFixed32: + // v, err := DecodeFixed32(b) + // if err != nil { + // fmt.Printf("decode error on fixed32:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" // field %d, encoding %d, value %d\n", + // data[start:len(data)-b.Len()], tag, wire, v) + // case WireFixed64: + // v, err := DecodeFixed64(b) + // if err != nil { + // fmt.Printf("decode error on fixed64:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" // field %d, encoding %d, value %d\n", + // data[start:len(data)-b.Len()], tag, wire, v) + // case WireBytes: + // nb, err := DecodeVarint(b) + // if err != nil { + // fmt.Printf("decode error on bytes:", err) + // return + // } + // after_tag := len(data) - b.Len() + // str := make([]byte, nb) + // _, err = b.Read(str) + // if err != nil { + // fmt.Printf("decode error on bytes:", err) + // return + // } + // fmt.Printf("\t\t\"%x\" \"%x\" // field %d, encoding %d (FIELD)\n", + // data[start:after_tag], str, tag, wire) + // case WireStartGroup: + // nesting++ + // fmt.Printf("\t\t\"%x\"\t\t// start group field %d level %d\n", + // data[start:len(data)-b.Len()], tag, nesting) + // case WireEndGroup: + // fmt.Printf("\t\t\"%x\"\t\t// end group field %d level %d\n", + // data[start:len(data)-b.Len()], tag, nesting) + // nesting-- + // default: + // fmt.Printf("unrecognized wire type %d\n", wire) + // return + // } + // } +} + +func hex(c uint8) uint8 { + if '0' <= c && c <= '9' { + return c - '0' + } + if 'a' <= c && c <= 'f' { + return 10 + c - 'a' + } + if 'A' <= c && c <= 'F' { + return 10 + c - 'A' + } + return 0 +} + +func equal(b []byte, s string, t *testing.T) bool { + if 2*len(b) != len(s) { + // fail(fmt.Sprintf("wrong lengths: 2*%d != %d", len(b), len(s)), b, s, t) + fmt.Printf("wrong lengths: 2*%d != %d\n", len(b), len(s)) + return false + } + for i, j := 0, 0; i < len(b); i, j = i+1, j+2 { + x := hex(s[j])*16 + hex(s[j+1]) + if b[i] != x { + // fail(fmt.Sprintf("bad byte[%d]:%x %x", i, b[i], x), b, s, t) + fmt.Printf("bad byte[%d]:%x %x", i, b[i], x) + return false + } + } + return true +} + +func overify(t *testing.T, pb *GoTest, expected string) { + o := old() + err := o.Marshal(pb) + if err != nil { + fmt.Printf("overify marshal-1 err = %v", err) + o.DebugPrint("", o.Bytes()) + t.Fatalf("expected = %s", expected) + } + if !equal(o.Bytes(), expected, t) { + o.DebugPrint("overify neq 1", o.Bytes()) + t.Fatalf("expected = %s", expected) + } + + // Now test Unmarshal by recreating the original buffer. + pbd := new(GoTest) + err = o.Unmarshal(pbd) + if err != nil { + t.Fatalf("overify unmarshal err = %v", err) + o.DebugPrint("", o.Bytes()) + t.Fatalf("string = %s", expected) + } + o.Reset() + err = o.Marshal(pbd) + if err != nil { + t.Errorf("overify marshal-2 err = %v", err) + o.DebugPrint("", o.Bytes()) + t.Fatalf("string = %s", expected) + } + if !equal(o.Bytes(), expected, t) { + o.DebugPrint("overify neq 2", o.Bytes()) + t.Fatalf("string = %s", expected) + } +} + +// Simple tests for numeric encode/decode primitives (varint, etc.) +func TestNumericPrimitives(t *testing.T) { + for i := uint64(0); i < 1e6; i += 111 { + o := old() + if o.EncodeVarint(i) != nil { + t.Error("EncodeVarint") + break + } + x, e := o.DecodeVarint() + if e != nil { + t.Fatal("DecodeVarint") + } + if x != i { + t.Fatal("varint decode fail:", i, x) + } + + o = old() + if o.EncodeFixed32(i) != nil { + t.Fatal("encFixed32") + } + x, e = o.DecodeFixed32() + if e != nil { + t.Fatal("decFixed32") + } + if x != i { + t.Fatal("fixed32 decode fail:", i, x) + } + + o = old() + if o.EncodeFixed64(i*1234567) != nil { + t.Error("encFixed64") + break + } + x, e = o.DecodeFixed64() + if e != nil { + t.Error("decFixed64") + break + } + if x != i*1234567 { + t.Error("fixed64 decode fail:", i*1234567, x) + break + } + + o = old() + i32 := int32(i - 12345) + if o.EncodeZigzag32(uint64(i32)) != nil { + t.Fatal("EncodeZigzag32") + } + x, e = o.DecodeZigzag32() + if e != nil { + t.Fatal("DecodeZigzag32") + } + if x != uint64(uint32(i32)) { + t.Fatal("zigzag32 decode fail:", i32, x) + } + + o = old() + i64 := int64(i - 12345) + if o.EncodeZigzag64(uint64(i64)) != nil { + t.Fatal("EncodeZigzag64") + } + x, e = o.DecodeZigzag64() + if e != nil { + t.Fatal("DecodeZigzag64") + } + if x != uint64(i64) { + t.Fatal("zigzag64 decode fail:", i64, x) + } + } +} + +// fakeMarshaler is a simple struct implementing Marshaler and Message interfaces. +type fakeMarshaler struct { + b []byte + err error +} + +func (f *fakeMarshaler) Marshal() ([]byte, error) { return f.b, f.err } +func (f *fakeMarshaler) String() string { return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err) } +func (f *fakeMarshaler) ProtoMessage() {} +func (f *fakeMarshaler) Reset() {} + +type msgWithFakeMarshaler struct { + M *fakeMarshaler `protobuf:"bytes,1,opt,name=fake"` +} + +func (m *msgWithFakeMarshaler) String() string { return CompactTextString(m) } +func (m *msgWithFakeMarshaler) ProtoMessage() {} +func (m *msgWithFakeMarshaler) Reset() {} + +// Simple tests for proto messages that implement the Marshaler interface. +func TestMarshalerEncoding(t *testing.T) { + tests := []struct { + name string + m Message + want []byte + wantErr error + }{ + { + name: "Marshaler that fails", + m: &fakeMarshaler{ + err: errors.New("some marshal err"), + b: []byte{5, 6, 7}, + }, + // Since there's an error, nothing should be written to buffer. + want: nil, + wantErr: errors.New("some marshal err"), + }, + { + name: "Marshaler that fails with RequiredNotSetError", + m: &msgWithFakeMarshaler{ + M: &fakeMarshaler{ + err: &RequiredNotSetError{}, + b: []byte{5, 6, 7}, + }, + }, + // Since there's an error that can be continued after, + // the buffer should be written. + want: []byte{ + 10, 3, // for &msgWithFakeMarshaler + 5, 6, 7, // for &fakeMarshaler + }, + wantErr: &RequiredNotSetError{}, + }, + { + name: "Marshaler that succeeds", + m: &fakeMarshaler{ + b: []byte{0, 1, 2, 3, 4, 127, 255}, + }, + want: []byte{0, 1, 2, 3, 4, 127, 255}, + wantErr: nil, + }, + } + for _, test := range tests { + b := NewBuffer(nil) + err := b.Marshal(test.m) + if _, ok := err.(*RequiredNotSetError); ok { + // We're not in package proto, so we can only assert the type in this case. + err = &RequiredNotSetError{} + } + if !reflect.DeepEqual(test.wantErr, err) { + t.Errorf("%s: got err %v wanted %v", test.name, err, test.wantErr) + } + if !reflect.DeepEqual(test.want, b.Bytes()) { + t.Errorf("%s: got bytes %v wanted %v", test.name, b.Bytes(), test.want) + } + } +} + +// Simple tests for bytes +func TestBytesPrimitives(t *testing.T) { + o := old() + bytes := []byte{'n', 'o', 'w', ' ', 'i', 's', ' ', 't', 'h', 'e', ' ', 't', 'i', 'm', 'e'} + if o.EncodeRawBytes(bytes) != nil { + t.Error("EncodeRawBytes") + } + decb, e := o.DecodeRawBytes(false) + if e != nil { + t.Error("DecodeRawBytes") + } + equalbytes(bytes, decb, t) +} + +// Simple tests for strings +func TestStringPrimitives(t *testing.T) { + o := old() + s := "now is the time" + if o.EncodeStringBytes(s) != nil { + t.Error("enc_string") + } + decs, e := o.DecodeStringBytes() + if e != nil { + t.Error("dec_string") + } + if s != decs { + t.Error("string encode/decode fail:", s, decs) + } +} + +// Do we catch the "required bit not set" case? +func TestRequiredBit(t *testing.T) { + o := old() + pb := new(GoTest) + err := o.Marshal(pb) + if err == nil { + t.Error("did not catch missing required fields") + } else if strings.Index(err.Error(), "Kind") < 0 { + t.Error("wrong error type:", err) + } +} + +// Check that all fields are nil. +// Clearly silly, and a residue from a more interesting test with an earlier, +// different initialization property, but it once caught a compiler bug so +// it lives. +func checkInitialized(pb *GoTest, t *testing.T) { + if pb.F_BoolDefaulted != nil { + t.Error("New or Reset did not set boolean:", *pb.F_BoolDefaulted) + } + if pb.F_Int32Defaulted != nil { + t.Error("New or Reset did not set int32:", *pb.F_Int32Defaulted) + } + if pb.F_Int64Defaulted != nil { + t.Error("New or Reset did not set int64:", *pb.F_Int64Defaulted) + } + if pb.F_Fixed32Defaulted != nil { + t.Error("New or Reset did not set fixed32:", *pb.F_Fixed32Defaulted) + } + if pb.F_Fixed64Defaulted != nil { + t.Error("New or Reset did not set fixed64:", *pb.F_Fixed64Defaulted) + } + if pb.F_Uint32Defaulted != nil { + t.Error("New or Reset did not set uint32:", *pb.F_Uint32Defaulted) + } + if pb.F_Uint64Defaulted != nil { + t.Error("New or Reset did not set uint64:", *pb.F_Uint64Defaulted) + } + if pb.F_FloatDefaulted != nil { + t.Error("New or Reset did not set float:", *pb.F_FloatDefaulted) + } + if pb.F_DoubleDefaulted != nil { + t.Error("New or Reset did not set double:", *pb.F_DoubleDefaulted) + } + if pb.F_StringDefaulted != nil { + t.Error("New or Reset did not set string:", *pb.F_StringDefaulted) + } + if pb.F_BytesDefaulted != nil { + t.Error("New or Reset did not set bytes:", string(pb.F_BytesDefaulted)) + } + if pb.F_Sint32Defaulted != nil { + t.Error("New or Reset did not set int32:", *pb.F_Sint32Defaulted) + } + if pb.F_Sint64Defaulted != nil { + t.Error("New or Reset did not set int64:", *pb.F_Sint64Defaulted) + } +} + +// Does Reset() reset? +func TestReset(t *testing.T) { + pb := initGoTest(true) + // muck with some values + pb.F_BoolDefaulted = Bool(false) + pb.F_Int32Defaulted = Int32(237) + pb.F_Int64Defaulted = Int64(12346) + pb.F_Fixed32Defaulted = Uint32(32000) + pb.F_Fixed64Defaulted = Uint64(666) + pb.F_Uint32Defaulted = Uint32(323232) + pb.F_Uint64Defaulted = nil + pb.F_FloatDefaulted = nil + pb.F_DoubleDefaulted = Float64(0) + pb.F_StringDefaulted = String("gotcha") + pb.F_BytesDefaulted = []byte("asdfasdf") + pb.F_Sint32Defaulted = Int32(123) + pb.F_Sint64Defaulted = Int64(789) + pb.Reset() + checkInitialized(pb, t) +} + +// All required fields set, no defaults provided. +func TestEncodeDecode1(t *testing.T) { + pb := initGoTest(false) + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 0x20 + "714000000000000000"+ // field 14, encoding 1, value 0x40 + "78a019"+ // field 15, encoding 0, value 0xca0 = 3232 + "8001c032"+ // field 16, encoding 0, value 0x1940 = 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2, string "string" + "b304"+ // field 70, encoding 3, start group + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // field 70, encoding 4, end group + "aa0605"+"6279746573"+ // field 101, encoding 2, string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f") // field 103, encoding 0, 0x7f zigzag64 +} + +// All required fields set, defaults provided. +func TestEncodeDecode2(t *testing.T) { + pb := initGoTest(true) + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All default fields set to their default value by hand +func TestEncodeDecode3(t *testing.T) { + pb := initGoTest(false) + pb.F_BoolDefaulted = Bool(true) + pb.F_Int32Defaulted = Int32(32) + pb.F_Int64Defaulted = Int64(64) + pb.F_Fixed32Defaulted = Uint32(320) + pb.F_Fixed64Defaulted = Uint64(640) + pb.F_Uint32Defaulted = Uint32(3200) + pb.F_Uint64Defaulted = Uint64(6400) + pb.F_FloatDefaulted = Float32(314159) + pb.F_DoubleDefaulted = Float64(271828) + pb.F_StringDefaulted = String("hello, \"world!\"\n") + pb.F_BytesDefaulted = []byte("Bignose") + pb.F_Sint32Defaulted = Int32(-32) + pb.F_Sint64Defaulted = Int64(-64) + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All required fields set, defaults provided, all non-defaulted optional fields have values. +func TestEncodeDecode4(t *testing.T) { + pb := initGoTest(true) + pb.Table = String("hello") + pb.Param = Int32(7) + pb.OptionalField = initGoTestField() + pb.F_BoolOptional = Bool(true) + pb.F_Int32Optional = Int32(32) + pb.F_Int64Optional = Int64(64) + pb.F_Fixed32Optional = Uint32(3232) + pb.F_Fixed64Optional = Uint64(6464) + pb.F_Uint32Optional = Uint32(323232) + pb.F_Uint64Optional = Uint64(646464) + pb.F_FloatOptional = Float32(32.) + pb.F_DoubleOptional = Float64(64.) + pb.F_StringOptional = String("hello") + pb.F_BytesOptional = []byte("Bignose") + pb.F_Sint32Optional = Int32(-32) + pb.F_Sint64Optional = Int64(-64) + pb.Optionalgroup = initGoTest_OptionalGroup() + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "1205"+"68656c6c6f"+ // field 2, encoding 2, string "hello" + "1807"+ // field 3, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "320d"+"0a056c6162656c120474797065"+ // field 6, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "f00101"+ // field 30, encoding 0, value 1 + "f80120"+ // field 31, encoding 0, value 32 + "800240"+ // field 32, encoding 0, value 64 + "8d02a00c0000"+ // field 33, encoding 5, value 3232 + "91024019000000000000"+ // field 34, encoding 1, value 6464 + "9802a0dd13"+ // field 35, encoding 0, value 323232 + "a002c0ba27"+ // field 36, encoding 0, value 646464 + "ad0200000042"+ // field 37, encoding 5, value 32.0 + "b1020000000000005040"+ // field 38, encoding 1, value 64.0 + "ba0205"+"68656c6c6f"+ // field 39, encoding 2, string "hello" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "d305"+ // start group field 90 level 1 + "da0508"+"6f7074696f6e616c"+ // field 91, encoding 2, string "optional" + "d405"+ // end group field 90 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "ea1207"+"4269676e6f7365"+ // field 301, encoding 2, string "Bignose" + "f0123f"+ // field 302, encoding 0, value 63 + "f8127f"+ // field 303, encoding 0, value 127 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All required fields set, defaults provided, all repeated fields given two values. +func TestEncodeDecode5(t *testing.T) { + pb := initGoTest(true) + pb.RepeatedField = []*GoTestField{initGoTestField(), initGoTestField()} + pb.F_BoolRepeated = []bool{false, true} + pb.F_Int32Repeated = []int32{32, 33} + pb.F_Int64Repeated = []int64{64, 65} + pb.F_Fixed32Repeated = []uint32{3232, 3333} + pb.F_Fixed64Repeated = []uint64{6464, 6565} + pb.F_Uint32Repeated = []uint32{323232, 333333} + pb.F_Uint64Repeated = []uint64{646464, 656565} + pb.F_FloatRepeated = []float32{32., 33.} + pb.F_DoubleRepeated = []float64{64., 65.} + pb.F_StringRepeated = []string{"hello", "sailor"} + pb.F_BytesRepeated = [][]byte{[]byte("big"), []byte("nose")} + pb.F_Sint32Repeated = []int32{32, -32} + pb.F_Sint64Repeated = []int64{64, -64} + pb.Repeatedgroup = []*GoTest_RepeatedGroup{initGoTest_RepeatedGroup(), initGoTest_RepeatedGroup()} + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "2a0d"+"0a056c6162656c120474797065"+ // field 5, encoding 2 (GoTestField) + "2a0d"+"0a056c6162656c120474797065"+ // field 5, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "a00100"+ // field 20, encoding 0, value 0 + "a00101"+ // field 20, encoding 0, value 1 + "a80120"+ // field 21, encoding 0, value 32 + "a80121"+ // field 21, encoding 0, value 33 + "b00140"+ // field 22, encoding 0, value 64 + "b00141"+ // field 22, encoding 0, value 65 + "bd01a00c0000"+ // field 23, encoding 5, value 3232 + "bd01050d0000"+ // field 23, encoding 5, value 3333 + "c1014019000000000000"+ // field 24, encoding 1, value 6464 + "c101a519000000000000"+ // field 24, encoding 1, value 6565 + "c801a0dd13"+ // field 25, encoding 0, value 323232 + "c80195ac14"+ // field 25, encoding 0, value 333333 + "d001c0ba27"+ // field 26, encoding 0, value 646464 + "d001b58928"+ // field 26, encoding 0, value 656565 + "dd0100000042"+ // field 27, encoding 5, value 32.0 + "dd0100000442"+ // field 27, encoding 5, value 33.0 + "e1010000000000005040"+ // field 28, encoding 1, value 64.0 + "e1010000000000405040"+ // field 28, encoding 1, value 65.0 + "ea0105"+"68656c6c6f"+ // field 29, encoding 2, string "hello" + "ea0106"+"7361696c6f72"+ // field 29, encoding 2, string "sailor" + "c00201"+ // field 40, encoding 0, value 1 + "c80220"+ // field 41, encoding 0, value 32 + "d00240"+ // field 42, encoding 0, value 64 + "dd0240010000"+ // field 43, encoding 5, value 320 + "e1028002000000000000"+ // field 44, encoding 1, value 640 + "e8028019"+ // field 45, encoding 0, value 3200 + "f0028032"+ // field 46, encoding 0, value 6400 + "fd02e0659948"+ // field 47, encoding 5, value 314159.0 + "81030000000050971041"+ // field 48, encoding 1, value 271828.0 + "8a0310"+"68656c6c6f2c2022776f726c6421220a"+ // field 49, encoding 2 string "hello, \"world!\"\n" + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "8305"+ // start group field 80 level 1 + "8a0508"+"7265706561746564"+ // field 81, encoding 2, string "repeated" + "8405"+ // end group field 80 level 1 + "8305"+ // start group field 80 level 1 + "8a0508"+"7265706561746564"+ // field 81, encoding 2, string "repeated" + "8405"+ // end group field 80 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "ca0c03"+"626967"+ // field 201, encoding 2, string "big" + "ca0c04"+"6e6f7365"+ // field 201, encoding 2, string "nose" + "d00c40"+ // field 202, encoding 0, value 32 + "d00c3f"+ // field 202, encoding 0, value -32 + "d80c8001"+ // field 203, encoding 0, value 64 + "d80c7f"+ // field 203, encoding 0, value -64 + "8a1907"+"4269676e6f7365"+ // field 401, encoding 2, string "Bignose" + "90193f"+ // field 402, encoding 0, value 63 + "98197f") // field 403, encoding 0, value 127 + +} + +// All required fields set, all packed repeated fields given two values. +func TestEncodeDecode6(t *testing.T) { + pb := initGoTest(false) + pb.F_BoolRepeatedPacked = []bool{false, true} + pb.F_Int32RepeatedPacked = []int32{32, 33} + pb.F_Int64RepeatedPacked = []int64{64, 65} + pb.F_Fixed32RepeatedPacked = []uint32{3232, 3333} + pb.F_Fixed64RepeatedPacked = []uint64{6464, 6565} + pb.F_Uint32RepeatedPacked = []uint32{323232, 333333} + pb.F_Uint64RepeatedPacked = []uint64{646464, 656565} + pb.F_FloatRepeatedPacked = []float32{32., 33.} + pb.F_DoubleRepeatedPacked = []float64{64., 65.} + pb.F_Sint32RepeatedPacked = []int32{32, -32} + pb.F_Sint64RepeatedPacked = []int64{64, -64} + + overify(t, pb, + "0807"+ // field 1, encoding 0, value 7 + "220d"+"0a056c6162656c120474797065"+ // field 4, encoding 2 (GoTestField) + "5001"+ // field 10, encoding 0, value 1 + "5803"+ // field 11, encoding 0, value 3 + "6006"+ // field 12, encoding 0, value 6 + "6d20000000"+ // field 13, encoding 5, value 32 + "714000000000000000"+ // field 14, encoding 1, value 64 + "78a019"+ // field 15, encoding 0, value 3232 + "8001c032"+ // field 16, encoding 0, value 6464 + "8d0100004a45"+ // field 17, encoding 5, value 3232.0 + "9101000000000040b940"+ // field 18, encoding 1, value 6464.0 + "9a0106"+"737472696e67"+ // field 19, encoding 2 string "string" + "9203020001"+ // field 50, encoding 2, 2 bytes, value 0, value 1 + "9a03022021"+ // field 51, encoding 2, 2 bytes, value 32, value 33 + "a203024041"+ // field 52, encoding 2, 2 bytes, value 64, value 65 + "aa0308"+ // field 53, encoding 2, 8 bytes + "a00c0000050d0000"+ // value 3232, value 3333 + "b20310"+ // field 54, encoding 2, 16 bytes + "4019000000000000a519000000000000"+ // value 6464, value 6565 + "ba0306"+ // field 55, encoding 2, 6 bytes + "a0dd1395ac14"+ // value 323232, value 333333 + "c20306"+ // field 56, encoding 2, 6 bytes + "c0ba27b58928"+ // value 646464, value 656565 + "ca0308"+ // field 57, encoding 2, 8 bytes + "0000004200000442"+ // value 32.0, value 33.0 + "d20310"+ // field 58, encoding 2, 16 bytes + "00000000000050400000000000405040"+ // value 64.0, value 65.0 + "b304"+ // start group field 70 level 1 + "ba0408"+"7265717569726564"+ // field 71, encoding 2, string "required" + "b404"+ // end group field 70 level 1 + "aa0605"+"6279746573"+ // field 101, encoding 2 string "bytes" + "b0063f"+ // field 102, encoding 0, 0x3f zigzag32 + "b8067f"+ // field 103, encoding 0, 0x7f zigzag64 + "b21f02"+ // field 502, encoding 2, 2 bytes + "403f"+ // value 32, value -32 + "ba1f03"+ // field 503, encoding 2, 3 bytes + "80017f") // value 64, value -64 +} + +// Test that we can encode empty bytes fields. +func TestEncodeDecodeBytes1(t *testing.T) { + pb := initGoTest(false) + + // Create our bytes + pb.F_BytesRequired = []byte{} + pb.F_BytesRepeated = [][]byte{{}} + pb.F_BytesOptional = []byte{} + + d, err := Marshal(pb) + if err != nil { + t.Error(err) + } + + pbd := new(GoTest) + if err := Unmarshal(d, pbd); err != nil { + t.Error(err) + } + + if pbd.F_BytesRequired == nil || len(pbd.F_BytesRequired) != 0 { + t.Error("required empty bytes field is incorrect") + } + if pbd.F_BytesRepeated == nil || len(pbd.F_BytesRepeated) == 1 && pbd.F_BytesRepeated[0] == nil { + t.Error("repeated empty bytes field is incorrect") + } + if pbd.F_BytesOptional == nil || len(pbd.F_BytesOptional) != 0 { + t.Error("optional empty bytes field is incorrect") + } +} + +// Test that we encode nil-valued fields of a repeated bytes field correctly. +// Since entries in a repeated field cannot be nil, nil must mean empty value. +func TestEncodeDecodeBytes2(t *testing.T) { + pb := initGoTest(false) + + // Create our bytes + pb.F_BytesRepeated = [][]byte{nil} + + d, err := Marshal(pb) + if err != nil { + t.Error(err) + } + + pbd := new(GoTest) + if err := Unmarshal(d, pbd); err != nil { + t.Error(err) + } + + if len(pbd.F_BytesRepeated) != 1 || pbd.F_BytesRepeated[0] == nil { + t.Error("Unexpected value for repeated bytes field") + } +} + +// All required fields set, defaults provided, all repeated fields given two values. +func TestSkippingUnrecognizedFields(t *testing.T) { + o := old() + pb := initGoTestField() + + // Marshal it normally. + o.Marshal(pb) + + // Now new a GoSkipTest record. + skip := &GoSkipTest{ + SkipInt32: Int32(32), + SkipFixed32: Uint32(3232), + SkipFixed64: Uint64(6464), + SkipString: String("skipper"), + Skipgroup: &GoSkipTest_SkipGroup{ + GroupInt32: Int32(75), + GroupString: String("wxyz"), + }, + } + + // Marshal it into same buffer. + o.Marshal(skip) + + pbd := new(GoTestField) + o.Unmarshal(pbd) + + // The __unrecognized field should be a marshaling of GoSkipTest + skipd := new(GoSkipTest) + + o.SetBuf(pbd.XXX_unrecognized) + o.Unmarshal(skipd) + + if *skipd.SkipInt32 != *skip.SkipInt32 { + t.Error("skip int32", skipd.SkipInt32) + } + if *skipd.SkipFixed32 != *skip.SkipFixed32 { + t.Error("skip fixed32", skipd.SkipFixed32) + } + if *skipd.SkipFixed64 != *skip.SkipFixed64 { + t.Error("skip fixed64", skipd.SkipFixed64) + } + if *skipd.SkipString != *skip.SkipString { + t.Error("skip string", *skipd.SkipString) + } + if *skipd.Skipgroup.GroupInt32 != *skip.Skipgroup.GroupInt32 { + t.Error("skip group int32", skipd.Skipgroup.GroupInt32) + } + if *skipd.Skipgroup.GroupString != *skip.Skipgroup.GroupString { + t.Error("skip group string", *skipd.Skipgroup.GroupString) + } +} + +// Check that unrecognized fields of a submessage are preserved. +func TestSubmessageUnrecognizedFields(t *testing.T) { + nm := &NewMessage{ + Nested: &NewMessage_Nested{ + Name: String("Nigel"), + FoodGroup: String("carbs"), + }, + } + b, err := Marshal(nm) + if err != nil { + t.Fatalf("Marshal of NewMessage: %v", err) + } + + // Unmarshal into an OldMessage. + om := new(OldMessage) + if err := Unmarshal(b, om); err != nil { + t.Fatalf("Unmarshal to OldMessage: %v", err) + } + exp := &OldMessage{ + Nested: &OldMessage_Nested{ + Name: String("Nigel"), + // normal protocol buffer users should not do this + XXX_unrecognized: []byte("\x12\x05carbs"), + }, + } + if !Equal(om, exp) { + t.Errorf("om = %v, want %v", om, exp) + } + + // Clone the OldMessage. + om = Clone(om).(*OldMessage) + if !Equal(om, exp) { + t.Errorf("Clone(om) = %v, want %v", om, exp) + } + + // Marshal the OldMessage, then unmarshal it into an empty NewMessage. + if b, err = Marshal(om); err != nil { + t.Fatalf("Marshal of OldMessage: %v", err) + } + t.Logf("Marshal(%v) -> %q", om, b) + nm2 := new(NewMessage) + if err := Unmarshal(b, nm2); err != nil { + t.Fatalf("Unmarshal to NewMessage: %v", err) + } + if !Equal(nm, nm2) { + t.Errorf("NewMessage round-trip: %v => %v", nm, nm2) + } +} + +// Check that an int32 field can be upgraded to an int64 field. +func TestNegativeInt32(t *testing.T) { + om := &OldMessage{ + Num: Int32(-1), + } + b, err := Marshal(om) + if err != nil { + t.Fatalf("Marshal of OldMessage: %v", err) + } + + // Check the size. It should be 11 bytes; + // 1 for the field/wire type, and 10 for the negative number. + if len(b) != 11 { + t.Errorf("%v marshaled as %q, wanted 11 bytes", om, b) + } + + // Unmarshal into a NewMessage. + nm := new(NewMessage) + if err := Unmarshal(b, nm); err != nil { + t.Fatalf("Unmarshal to NewMessage: %v", err) + } + want := &NewMessage{ + Num: Int64(-1), + } + if !Equal(nm, want) { + t.Errorf("nm = %v, want %v", nm, want) + } +} + +// Check that we can grow an array (repeated field) to have many elements. +// This test doesn't depend only on our encoding; for variety, it makes sure +// we create, encode, and decode the correct contents explicitly. It's therefore +// a bit messier. +// This test also uses (and hence tests) the Marshal/Unmarshal functions +// instead of the methods. +func TestBigRepeated(t *testing.T) { + pb := initGoTest(true) + + // Create the arrays + const N = 50 // Internally the library starts much smaller. + pb.Repeatedgroup = make([]*GoTest_RepeatedGroup, N) + pb.F_Sint64Repeated = make([]int64, N) + pb.F_Sint32Repeated = make([]int32, N) + pb.F_BytesRepeated = make([][]byte, N) + pb.F_StringRepeated = make([]string, N) + pb.F_DoubleRepeated = make([]float64, N) + pb.F_FloatRepeated = make([]float32, N) + pb.F_Uint64Repeated = make([]uint64, N) + pb.F_Uint32Repeated = make([]uint32, N) + pb.F_Fixed64Repeated = make([]uint64, N) + pb.F_Fixed32Repeated = make([]uint32, N) + pb.F_Int64Repeated = make([]int64, N) + pb.F_Int32Repeated = make([]int32, N) + pb.F_BoolRepeated = make([]bool, N) + pb.RepeatedField = make([]*GoTestField, N) + + // Fill in the arrays with checkable values. + igtf := initGoTestField() + igtrg := initGoTest_RepeatedGroup() + for i := 0; i < N; i++ { + pb.Repeatedgroup[i] = igtrg + pb.F_Sint64Repeated[i] = int64(i) + pb.F_Sint32Repeated[i] = int32(i) + s := fmt.Sprint(i) + pb.F_BytesRepeated[i] = []byte(s) + pb.F_StringRepeated[i] = s + pb.F_DoubleRepeated[i] = float64(i) + pb.F_FloatRepeated[i] = float32(i) + pb.F_Uint64Repeated[i] = uint64(i) + pb.F_Uint32Repeated[i] = uint32(i) + pb.F_Fixed64Repeated[i] = uint64(i) + pb.F_Fixed32Repeated[i] = uint32(i) + pb.F_Int64Repeated[i] = int64(i) + pb.F_Int32Repeated[i] = int32(i) + pb.F_BoolRepeated[i] = i%2 == 0 + pb.RepeatedField[i] = igtf + } + + // Marshal. + buf, _ := Marshal(pb) + + // Now test Unmarshal by recreating the original buffer. + pbd := new(GoTest) + Unmarshal(buf, pbd) + + // Check the checkable values + for i := uint64(0); i < N; i++ { + if pbd.Repeatedgroup[i] == nil { // TODO: more checking? + t.Error("pbd.Repeatedgroup bad") + } + var x uint64 + x = uint64(pbd.F_Sint64Repeated[i]) + if x != i { + t.Error("pbd.F_Sint64Repeated bad", x, i) + } + x = uint64(pbd.F_Sint32Repeated[i]) + if x != i { + t.Error("pbd.F_Sint32Repeated bad", x, i) + } + s := fmt.Sprint(i) + equalbytes(pbd.F_BytesRepeated[i], []byte(s), t) + if pbd.F_StringRepeated[i] != s { + t.Error("pbd.F_Sint32Repeated bad", pbd.F_StringRepeated[i], i) + } + x = uint64(pbd.F_DoubleRepeated[i]) + if x != i { + t.Error("pbd.F_DoubleRepeated bad", x, i) + } + x = uint64(pbd.F_FloatRepeated[i]) + if x != i { + t.Error("pbd.F_FloatRepeated bad", x, i) + } + x = pbd.F_Uint64Repeated[i] + if x != i { + t.Error("pbd.F_Uint64Repeated bad", x, i) + } + x = uint64(pbd.F_Uint32Repeated[i]) + if x != i { + t.Error("pbd.F_Uint32Repeated bad", x, i) + } + x = pbd.F_Fixed64Repeated[i] + if x != i { + t.Error("pbd.F_Fixed64Repeated bad", x, i) + } + x = uint64(pbd.F_Fixed32Repeated[i]) + if x != i { + t.Error("pbd.F_Fixed32Repeated bad", x, i) + } + x = uint64(pbd.F_Int64Repeated[i]) + if x != i { + t.Error("pbd.F_Int64Repeated bad", x, i) + } + x = uint64(pbd.F_Int32Repeated[i]) + if x != i { + t.Error("pbd.F_Int32Repeated bad", x, i) + } + if pbd.F_BoolRepeated[i] != (i%2 == 0) { + t.Error("pbd.F_BoolRepeated bad", x, i) + } + if pbd.RepeatedField[i] == nil { // TODO: more checking? + t.Error("pbd.RepeatedField bad") + } + } +} + +// Verify we give a useful message when decoding to the wrong structure type. +func TestTypeMismatch(t *testing.T) { + pb1 := initGoTest(true) + + // Marshal + o := old() + o.Marshal(pb1) + + // Now Unmarshal it to the wrong type. + pb2 := initGoTestField() + err := o.Unmarshal(pb2) + if err == nil { + t.Error("expected error, got no error") + } else if !strings.Contains(err.Error(), "bad wiretype") { + t.Error("expected bad wiretype error, got", err) + } +} + +func encodeDecode(t *testing.T, in, out Message, msg string) { + buf, err := Marshal(in) + if err != nil { + t.Fatalf("failed marshaling %v: %v", msg, err) + } + if err := Unmarshal(buf, out); err != nil { + t.Fatalf("failed unmarshaling %v: %v", msg, err) + } +} + +func TestPackedNonPackedDecoderSwitching(t *testing.T) { + np, p := new(NonPackedTest), new(PackedTest) + + // non-packed -> packed + np.A = []int32{0, 1, 1, 2, 3, 5} + encodeDecode(t, np, p, "non-packed -> packed") + if !reflect.DeepEqual(np.A, p.B) { + t.Errorf("failed non-packed -> packed; np.A=%+v, p.B=%+v", np.A, p.B) + } + + // packed -> non-packed + np.Reset() + p.B = []int32{3, 1, 4, 1, 5, 9} + encodeDecode(t, p, np, "packed -> non-packed") + if !reflect.DeepEqual(p.B, np.A) { + t.Errorf("failed packed -> non-packed; p.B=%+v, np.A=%+v", p.B, np.A) + } +} + +func TestProto1RepeatedGroup(t *testing.T) { + pb := &MessageList{ + Message: []*MessageList_Message{ + { + Name: String("blah"), + Count: Int32(7), + }, + // NOTE: pb.Message[1] is a nil + nil, + }, + } + + o := old() + err := o.Marshal(pb) + if err == nil || !strings.Contains(err.Error(), "repeated field Message has nil") { + t.Fatalf("unexpected or no error when marshaling: %v", err) + } +} + +// Test that enums work. Checks for a bug introduced by making enums +// named types instead of int32: newInt32FromUint64 would crash with +// a type mismatch in reflect.PointTo. +func TestEnum(t *testing.T) { + pb := new(GoEnum) + pb.Foo = FOO_FOO1.Enum() + o := old() + if err := o.Marshal(pb); err != nil { + t.Fatal("error encoding enum:", err) + } + pb1 := new(GoEnum) + if err := o.Unmarshal(pb1); err != nil { + t.Fatal("error decoding enum:", err) + } + if *pb1.Foo != FOO_FOO1 { + t.Error("expected 7 but got ", *pb1.Foo) + } +} + +// Enum types have String methods. Check that enum fields can be printed. +// We don't care what the value actually is, just as long as it doesn't crash. +func TestPrintingNilEnumFields(t *testing.T) { + pb := new(GoEnum) + fmt.Sprintf("%+v", pb) +} + +// Verify that absent required fields cause Marshal/Unmarshal to return errors. +func TestRequiredFieldEnforcement(t *testing.T) { + pb := new(GoTestField) + _, err := Marshal(pb) + if err == nil { + t.Error("marshal: expected error, got nil") + } else if strings.Index(err.Error(), "Label") < 0 { + t.Errorf("marshal: bad error type: %v", err) + } + + // A slightly sneaky, yet valid, proto. It encodes the same required field twice, + // so simply counting the required fields is insufficient. + // field 1, encoding 2, value "hi" + buf := []byte("\x0A\x02hi\x0A\x02hi") + err = Unmarshal(buf, pb) + if err == nil { + t.Error("unmarshal: expected error, got nil") + } else if strings.Index(err.Error(), "{Unknown}") < 0 { + t.Errorf("unmarshal: bad error type: %v", err) + } +} + +func TestTypedNilMarshal(t *testing.T) { + // A typed nil should return ErrNil and not crash. + _, err := Marshal((*GoEnum)(nil)) + if err != ErrNil { + t.Errorf("Marshal: got err %v, want ErrNil", err) + } +} + +// A type that implements the Marshaler interface, but is not nillable. +type nonNillableInt uint64 + +func (nni nonNillableInt) Marshal() ([]byte, error) { + return EncodeVarint(uint64(nni)), nil +} + +type NNIMessage struct { + nni nonNillableInt +} + +func (*NNIMessage) Reset() {} +func (*NNIMessage) String() string { return "" } +func (*NNIMessage) ProtoMessage() {} + +// A type that implements the Marshaler interface and is nillable. +type nillableMessage struct { + x uint64 +} + +func (nm *nillableMessage) Marshal() ([]byte, error) { + return EncodeVarint(nm.x), nil +} + +type NMMessage struct { + nm *nillableMessage +} + +func (*NMMessage) Reset() {} +func (*NMMessage) String() string { return "" } +func (*NMMessage) ProtoMessage() {} + +// Verify a type that uses the Marshaler interface, but has a nil pointer. +func TestNilMarshaler(t *testing.T) { + // Try a struct with a Marshaler field that is nil. + // It should be directly marshable. + nmm := new(NMMessage) + if _, err := Marshal(nmm); err != nil { + t.Error("unexpected error marshaling nmm: ", err) + } + + // Try a struct with a Marshaler field that is not nillable. + nnim := new(NNIMessage) + nnim.nni = 7 + var _ Marshaler = nnim.nni // verify it is truly a Marshaler + if _, err := Marshal(nnim); err != nil { + t.Error("unexpected error marshaling nnim: ", err) + } +} + +func TestAllSetDefaults(t *testing.T) { + // Exercise SetDefaults with all scalar field types. + m := &Defaults{ + // NaN != NaN, so override that here. + F_Nan: Float32(1.7), + } + expected := &Defaults{ + F_Bool: Bool(true), + F_Int32: Int32(32), + F_Int64: Int64(64), + F_Fixed32: Uint32(320), + F_Fixed64: Uint64(640), + F_Uint32: Uint32(3200), + F_Uint64: Uint64(6400), + F_Float: Float32(314159), + F_Double: Float64(271828), + F_String: String(`hello, "world!"` + "\n"), + F_Bytes: []byte("Bignose"), + F_Sint32: Int32(-32), + F_Sint64: Int64(-64), + F_Enum: Defaults_GREEN.Enum(), + F_Pinf: Float32(float32(math.Inf(1))), + F_Ninf: Float32(float32(math.Inf(-1))), + F_Nan: Float32(1.7), + StrZero: String(""), + } + SetDefaults(m) + if !Equal(m, expected) { + t.Errorf("SetDefaults failed\n got %v\nwant %v", m, expected) + } +} + +func TestSetDefaultsWithSetField(t *testing.T) { + // Check that a set value is not overridden. + m := &Defaults{ + F_Int32: Int32(12), + } + SetDefaults(m) + if v := m.GetF_Int32(); v != 12 { + t.Errorf("m.FInt32 = %v, want 12", v) + } +} + +func TestSetDefaultsWithSubMessage(t *testing.T) { + m := &OtherMessage{ + Key: Int64(123), + Inner: &InnerMessage{ + Host: String("gopher"), + }, + } + expected := &OtherMessage{ + Key: Int64(123), + Inner: &InnerMessage{ + Host: String("gopher"), + Port: Int32(4000), + }, + } + SetDefaults(m) + if !Equal(m, expected) { + t.Errorf("\n got %v\nwant %v", m, expected) + } +} + +func TestSetDefaultsWithRepeatedSubMessage(t *testing.T) { + m := &MyMessage{ + RepInner: []*InnerMessage{{}}, + } + expected := &MyMessage{ + RepInner: []*InnerMessage{{ + Port: Int32(4000), + }}, + } + SetDefaults(m) + if !Equal(m, expected) { + t.Errorf("\n got %v\nwant %v", m, expected) + } +} + +func TestSetDefaultWithRepeatedNonMessage(t *testing.T) { + m := &MyMessage{ + Pet: []string{"turtle", "wombat"}, + } + expected := Clone(m) + SetDefaults(m) + if !Equal(m, expected) { + t.Errorf("\n got %v\nwant %v", m, expected) + } +} + +func TestMaximumTagNumber(t *testing.T) { + m := &MaxTag{ + LastField: String("natural goat essence"), + } + buf, err := Marshal(m) + if err != nil { + t.Fatalf("proto.Marshal failed: %v", err) + } + m2 := new(MaxTag) + if err := Unmarshal(buf, m2); err != nil { + t.Fatalf("proto.Unmarshal failed: %v", err) + } + if got, want := m2.GetLastField(), *m.LastField; got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestJSON(t *testing.T) { + m := &MyMessage{ + Count: Int32(4), + Pet: []string{"bunny", "kitty"}, + Inner: &InnerMessage{ + Host: String("cauchy"), + }, + Bikeshed: MyMessage_GREEN.Enum(), + } + const expected = `{"count":4,"pet":["bunny","kitty"],"inner":{"host":"cauchy"},"bikeshed":1}` + + b, err := json.Marshal(m) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + s := string(b) + if s != expected { + t.Errorf("got %s\nwant %s", s, expected) + } + + received := new(MyMessage) + if err := json.Unmarshal(b, received); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + if !Equal(received, m) { + t.Fatalf("got %s, want %s", received, m) + } + + // Test unmarshalling of JSON with symbolic enum name. + const old = `{"count":4,"pet":["bunny","kitty"],"inner":{"host":"cauchy"},"bikeshed":"GREEN"}` + received.Reset() + if err := json.Unmarshal([]byte(old), received); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + if !Equal(received, m) { + t.Fatalf("got %s, want %s", received, m) + } +} + +func TestBadWireType(t *testing.T) { + b := []byte{7<<3 | 6} // field 7, wire type 6 + pb := new(OtherMessage) + if err := Unmarshal(b, pb); err == nil { + t.Errorf("Unmarshal did not fail") + } else if !strings.Contains(err.Error(), "unknown wire type") { + t.Errorf("wrong error: %v", err) + } +} + +func TestBytesWithInvalidLength(t *testing.T) { + // If a byte sequence has an invalid (negative) length, Unmarshal should not panic. + b := []byte{2<<3 | WireBytes, 0xff, 0xff, 0xff, 0xff, 0xff, 0} + Unmarshal(b, new(MyMessage)) +} + +func TestLengthOverflow(t *testing.T) { + // Overflowing a length should not panic. + b := []byte{2<<3 | WireBytes, 1, 1, 3<<3 | WireBytes, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, 0x01} + Unmarshal(b, new(MyMessage)) +} + +func TestVarintOverflow(t *testing.T) { + // Overflowing a 64-bit length should not be allowed. + b := []byte{1<<3 | WireVarint, 0x01, 3<<3 | WireBytes, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01} + if err := Unmarshal(b, new(MyMessage)); err == nil { + t.Fatalf("Overflowed uint64 length without error") + } +} + +func TestUnmarshalFuzz(t *testing.T) { + const N = 1000 + seed := time.Now().UnixNano() + t.Logf("RNG seed is %d", seed) + rng := rand.New(rand.NewSource(seed)) + buf := make([]byte, 20) + for i := 0; i < N; i++ { + for j := range buf { + buf[j] = byte(rng.Intn(256)) + } + fuzzUnmarshal(t, buf) + } +} + +func TestMergeMessages(t *testing.T) { + pb := &MessageList{Message: []*MessageList_Message{{Name: String("x"), Count: Int32(1)}}} + data, err := Marshal(pb) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + pb1 := new(MessageList) + if err := Unmarshal(data, pb1); err != nil { + t.Fatalf("first Unmarshal: %v", err) + } + if err := Unmarshal(data, pb1); err != nil { + t.Fatalf("second Unmarshal: %v", err) + } + if len(pb1.Message) != 1 { + t.Errorf("two Unmarshals produced %d Messages, want 1", len(pb1.Message)) + } + + pb2 := new(MessageList) + if err := UnmarshalMerge(data, pb2); err != nil { + t.Fatalf("first UnmarshalMerge: %v", err) + } + if err := UnmarshalMerge(data, pb2); err != nil { + t.Fatalf("second UnmarshalMerge: %v", err) + } + if len(pb2.Message) != 2 { + t.Errorf("two UnmarshalMerges produced %d Messages, want 2", len(pb2.Message)) + } +} + +func TestExtensionMarshalOrder(t *testing.T) { + m := &MyMessage{Count: Int(123)} + if err := SetExtension(m, E_Ext_More, &Ext{Data: String("alpha")}); err != nil { + t.Fatalf("SetExtension: %v", err) + } + if err := SetExtension(m, E_Ext_Text, String("aleph")); err != nil { + t.Fatalf("SetExtension: %v", err) + } + if err := SetExtension(m, E_Ext_Number, Int32(1)); err != nil { + t.Fatalf("SetExtension: %v", err) + } + + // Serialize m several times, and check we get the same bytes each time. + var orig []byte + for i := 0; i < 100; i++ { + b, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if i == 0 { + orig = b + continue + } + if !bytes.Equal(b, orig) { + t.Errorf("Bytes differ on attempt #%d", i) + } + } +} + +// Many extensions, because small maps might not iterate differently on each iteration. +var exts = []*ExtensionDesc{ + E_X201, + E_X202, + E_X203, + E_X204, + E_X205, + E_X206, + E_X207, + E_X208, + E_X209, + E_X210, + E_X211, + E_X212, + E_X213, + E_X214, + E_X215, + E_X216, + E_X217, + E_X218, + E_X219, + E_X220, + E_X221, + E_X222, + E_X223, + E_X224, + E_X225, + E_X226, + E_X227, + E_X228, + E_X229, + E_X230, + E_X231, + E_X232, + E_X233, + E_X234, + E_X235, + E_X236, + E_X237, + E_X238, + E_X239, + E_X240, + E_X241, + E_X242, + E_X243, + E_X244, + E_X245, + E_X246, + E_X247, + E_X248, + E_X249, + E_X250, +} + +func TestMessageSetMarshalOrder(t *testing.T) { + m := &MyMessageSet{} + for _, x := range exts { + if err := SetExtension(m, x, &Empty{}); err != nil { + t.Fatalf("SetExtension: %v", err) + } + } + + buf, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + // Serialize m several times, and check we get the same bytes each time. + for i := 0; i < 10; i++ { + b1, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + if !bytes.Equal(b1, buf) { + t.Errorf("Bytes differ on re-Marshal #%d", i) + } + + m2 := &MyMessageSet{} + if err := Unmarshal(buf, m2); err != nil { + t.Errorf("Unmarshal: %v", err) + } + b2, err := Marshal(m2) + if err != nil { + t.Errorf("re-Marshal: %v", err) + } + if !bytes.Equal(b2, buf) { + t.Errorf("Bytes differ on round-trip #%d", i) + } + } +} + +func TestUnmarshalMergesMessages(t *testing.T) { + // If a nested message occurs twice in the input, + // the fields should be merged when decoding. + a := &OtherMessage{ + Key: Int64(123), + Inner: &InnerMessage{ + Host: String("polhode"), + Port: Int32(1234), + }, + } + aData, err := Marshal(a) + if err != nil { + t.Fatalf("Marshal(a): %v", err) + } + b := &OtherMessage{ + Weight: Float32(1.2), + Inner: &InnerMessage{ + Host: String("herpolhode"), + Connected: Bool(true), + }, + } + bData, err := Marshal(b) + if err != nil { + t.Fatalf("Marshal(b): %v", err) + } + want := &OtherMessage{ + Key: Int64(123), + Weight: Float32(1.2), + Inner: &InnerMessage{ + Host: String("herpolhode"), + Port: Int32(1234), + Connected: Bool(true), + }, + } + got := new(OtherMessage) + if err := Unmarshal(append(aData, bData...), got); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + if !Equal(got, want) { + t.Errorf("\n got %v\nwant %v", got, want) + } +} + +func TestEncodingSizes(t *testing.T) { + tests := []struct { + m Message + n int + }{ + {&Defaults{F_Int32: Int32(math.MaxInt32)}, 6}, + {&Defaults{F_Int32: Int32(math.MinInt32)}, 11}, + {&Defaults{F_Uint32: Uint32(uint32(math.MaxInt32) + 1)}, 6}, + {&Defaults{F_Uint32: Uint32(math.MaxUint32)}, 6}, + } + for _, test := range tests { + b, err := Marshal(test.m) + if err != nil { + t.Errorf("Marshal(%v): %v", test.m, err) + continue + } + if len(b) != test.n { + t.Errorf("Marshal(%v) yielded %d bytes, want %d bytes", test.m, len(b), test.n) + } + } +} + +func TestRequiredNotSetError(t *testing.T) { + pb := initGoTest(false) + pb.RequiredField.Label = nil + pb.F_Int32Required = nil + pb.F_Int64Required = nil + + expected := "0807" + // field 1, encoding 0, value 7 + "2206" + "120474797065" + // field 4, encoding 2 (GoTestField) + "5001" + // field 10, encoding 0, value 1 + "6d20000000" + // field 13, encoding 5, value 0x20 + "714000000000000000" + // field 14, encoding 1, value 0x40 + "78a019" + // field 15, encoding 0, value 0xca0 = 3232 + "8001c032" + // field 16, encoding 0, value 0x1940 = 6464 + "8d0100004a45" + // field 17, encoding 5, value 3232.0 + "9101000000000040b940" + // field 18, encoding 1, value 6464.0 + "9a0106" + "737472696e67" + // field 19, encoding 2, string "string" + "b304" + // field 70, encoding 3, start group + "ba0408" + "7265717569726564" + // field 71, encoding 2, string "required" + "b404" + // field 70, encoding 4, end group + "aa0605" + "6279746573" + // field 101, encoding 2, string "bytes" + "b0063f" + // field 102, encoding 0, 0x3f zigzag32 + "b8067f" // field 103, encoding 0, 0x7f zigzag64 + + o := old() + bytes, err := Marshal(pb) + if _, ok := err.(*RequiredNotSetError); !ok { + fmt.Printf("marshal-1 err = %v, want *RequiredNotSetError", err) + o.DebugPrint("", bytes) + t.Fatalf("expected = %s", expected) + } + if strings.Index(err.Error(), "RequiredField.Label") < 0 { + t.Errorf("marshal-1 wrong err msg: %v", err) + } + if !equal(bytes, expected, t) { + o.DebugPrint("neq 1", bytes) + t.Fatalf("expected = %s", expected) + } + + // Now test Unmarshal by recreating the original buffer. + pbd := new(GoTest) + err = Unmarshal(bytes, pbd) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Fatalf("unmarshal err = %v, want *RequiredNotSetError", err) + o.DebugPrint("", bytes) + t.Fatalf("string = %s", expected) + } + if strings.Index(err.Error(), "RequiredField.{Unknown}") < 0 { + t.Errorf("unmarshal wrong err msg: %v", err) + } + bytes, err = Marshal(pbd) + if _, ok := err.(*RequiredNotSetError); !ok { + t.Errorf("marshal-2 err = %v, want *RequiredNotSetError", err) + o.DebugPrint("", bytes) + t.Fatalf("string = %s", expected) + } + if strings.Index(err.Error(), "RequiredField.Label") < 0 { + t.Errorf("marshal-2 wrong err msg: %v", err) + } + if !equal(bytes, expected, t) { + o.DebugPrint("neq 2", bytes) + t.Fatalf("string = %s", expected) + } +} + +func fuzzUnmarshal(t *testing.T, data []byte) { + defer func() { + if e := recover(); e != nil { + t.Errorf("These bytes caused a panic: %+v", data) + t.Logf("Stack:\n%s", debug.Stack()) + t.FailNow() + } + }() + + pb := new(MyMessage) + Unmarshal(data, pb) +} + +func TestMapFieldMarshal(t *testing.T) { + m := &MessageWithMap{ + NameMapping: map[int32]string{ + 1: "Rob", + 4: "Ian", + 8: "Dave", + }, + } + b, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + + // b should be the concatenation of these three byte sequences in some order. + parts := []string{ + "\n\a\b\x01\x12\x03Rob", + "\n\a\b\x04\x12\x03Ian", + "\n\b\b\x08\x12\x04Dave", + } + ok := false + for i := range parts { + for j := range parts { + if j == i { + continue + } + for k := range parts { + if k == i || k == j { + continue + } + try := parts[i] + parts[j] + parts[k] + if bytes.Equal(b, []byte(try)) { + ok = true + break + } + } + } + } + if !ok { + t.Fatalf("Incorrect Marshal output.\n got %q\nwant %q (or a permutation of that)", b, parts[0]+parts[1]+parts[2]) + } + t.Logf("FYI b: %q", b) + + (new(Buffer)).DebugPrint("Dump of b", b) +} + +func TestMapFieldRoundTrips(t *testing.T) { + m := &MessageWithMap{ + NameMapping: map[int32]string{ + 1: "Rob", + 4: "Ian", + 8: "Dave", + }, + MsgMapping: map[int64]*FloatingPoint{ + 0x7001: &FloatingPoint{F: Float64(2.0)}, + }, + ByteMapping: map[bool][]byte{ + false: []byte("that's not right!"), + true: []byte("aye, 'tis true!"), + }, + } + b, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + t.Logf("FYI b: %q", b) + m2 := new(MessageWithMap) + if err := Unmarshal(b, m2); err != nil { + t.Fatalf("Unmarshal: %v", err) + } + for _, pair := range [][2]interface{}{ + {m.NameMapping, m2.NameMapping}, + {m.MsgMapping, m2.MsgMapping}, + {m.ByteMapping, m2.ByteMapping}, + } { + if !reflect.DeepEqual(pair[0], pair[1]) { + t.Errorf("Map did not survive a round trip.\ninitial: %v\n final: %v", pair[0], pair[1]) + } + } +} + +func TestMapFieldWithNil(t *testing.T) { + m := &MessageWithMap{ + MsgMapping: map[int64]*FloatingPoint{ + 1: nil, + }, + } + b, err := Marshal(m) + if err == nil { + t.Fatalf("Marshal of bad map should have failed, got these bytes: %v", b) + } +} + +func TestOneof(t *testing.T) { + m := &Communique{} + b, err := Marshal(m) + if err != nil { + t.Fatalf("Marshal of empty message with oneof: %v", err) + } + if len(b) != 0 { + t.Errorf("Marshal of empty message yielded too many bytes: %v", b) + } + + m = &Communique{ + Union: &Communique_Name{"Barry"}, + } + + // Round-trip. + b, err = Marshal(m) + if err != nil { + t.Fatalf("Marshal of message with oneof: %v", err) + } + if len(b) != 7 { // name tag/wire (1) + name len (1) + name (5) + t.Errorf("Incorrect marshal of message with oneof: %v", b) + } + m.Reset() + if err := Unmarshal(b, m); err != nil { + t.Fatalf("Unmarshal of message with oneof: %v", err) + } + if x, ok := m.Union.(*Communique_Name); !ok || x.Name != "Barry" { + t.Errorf("After round trip, Union = %+v", m.Union) + } + if name := m.GetName(); name != "Barry" { + t.Errorf("After round trip, GetName = %q, want %q", name, "Barry") + } + + // Let's try with a message in the oneof. + m.Union = &Communique_Msg{&Strings{StringField: String("deep deep string")}} + b, err = Marshal(m) + if err != nil { + t.Fatalf("Marshal of message with oneof set to message: %v", err) + } + if len(b) != 20 { // msg tag/wire (1) + msg len (1) + msg (1 + 1 + 16) + t.Errorf("Incorrect marshal of message with oneof set to message: %v", b) + } + m.Reset() + if err := Unmarshal(b, m); err != nil { + t.Fatalf("Unmarshal of message with oneof set to message: %v", err) + } + ss, ok := m.Union.(*Communique_Msg) + if !ok || ss.Msg.GetStringField() != "deep deep string" { + t.Errorf("After round trip with oneof set to message, Union = %+v", m.Union) + } +} + +// Benchmarks + +func testMsg() *GoTest { + pb := initGoTest(true) + const N = 1000 // Internally the library starts much smaller. + pb.F_Int32Repeated = make([]int32, N) + pb.F_DoubleRepeated = make([]float64, N) + for i := 0; i < N; i++ { + pb.F_Int32Repeated[i] = int32(i) + pb.F_DoubleRepeated[i] = float64(i) + } + return pb +} + +func bytesMsg() *GoTest { + pb := initGoTest(true) + buf := make([]byte, 4000) + for i := range buf { + buf[i] = byte(i) + } + pb.F_BytesDefaulted = buf + return pb +} + +func benchmarkMarshal(b *testing.B, pb Message, marshal func(Message) ([]byte, error)) { + d, _ := marshal(pb) + b.SetBytes(int64(len(d))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + marshal(pb) + } +} + +func benchmarkBufferMarshal(b *testing.B, pb Message) { + p := NewBuffer(nil) + benchmarkMarshal(b, pb, func(pb0 Message) ([]byte, error) { + p.Reset() + err := p.Marshal(pb0) + return p.Bytes(), err + }) +} + +func benchmarkSize(b *testing.B, pb Message) { + benchmarkMarshal(b, pb, func(pb0 Message) ([]byte, error) { + Size(pb) + return nil, nil + }) +} + +func newOf(pb Message) Message { + in := reflect.ValueOf(pb) + if in.IsNil() { + return pb + } + return reflect.New(in.Type().Elem()).Interface().(Message) +} + +func benchmarkUnmarshal(b *testing.B, pb Message, unmarshal func([]byte, Message) error) { + d, _ := Marshal(pb) + b.SetBytes(int64(len(d))) + pbd := newOf(pb) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + unmarshal(d, pbd) + } +} + +func benchmarkBufferUnmarshal(b *testing.B, pb Message) { + p := NewBuffer(nil) + benchmarkUnmarshal(b, pb, func(d []byte, pb0 Message) error { + p.SetBuf(d) + return p.Unmarshal(pb0) + }) +} + +// Benchmark{Marshal,BufferMarshal,Size,Unmarshal,BufferUnmarshal}{,Bytes} + +func BenchmarkMarshal(b *testing.B) { + benchmarkMarshal(b, testMsg(), Marshal) +} + +func BenchmarkBufferMarshal(b *testing.B) { + benchmarkBufferMarshal(b, testMsg()) +} + +func BenchmarkSize(b *testing.B) { + benchmarkSize(b, testMsg()) +} + +func BenchmarkUnmarshal(b *testing.B) { + benchmarkUnmarshal(b, testMsg(), Unmarshal) +} + +func BenchmarkBufferUnmarshal(b *testing.B) { + benchmarkBufferUnmarshal(b, testMsg()) +} + +func BenchmarkMarshalBytes(b *testing.B) { + benchmarkMarshal(b, bytesMsg(), Marshal) +} + +func BenchmarkBufferMarshalBytes(b *testing.B) { + benchmarkBufferMarshal(b, bytesMsg()) +} + +func BenchmarkSizeBytes(b *testing.B) { + benchmarkSize(b, bytesMsg()) +} + +func BenchmarkUnmarshalBytes(b *testing.B) { + benchmarkUnmarshal(b, bytesMsg(), Unmarshal) +} + +func BenchmarkBufferUnmarshalBytes(b *testing.B) { + benchmarkBufferUnmarshal(b, bytesMsg()) +} + +func BenchmarkUnmarshalUnrecognizedFields(b *testing.B) { + b.StopTimer() + pb := initGoTestField() + skip := &GoSkipTest{ + SkipInt32: Int32(32), + SkipFixed32: Uint32(3232), + SkipFixed64: Uint64(6464), + SkipString: String("skipper"), + Skipgroup: &GoSkipTest_SkipGroup{ + GroupInt32: Int32(75), + GroupString: String("wxyz"), + }, + } + + pbd := new(GoTestField) + p := NewBuffer(nil) + p.Marshal(pb) + p.Marshal(skip) + p2 := NewBuffer(nil) + + b.StartTimer() + for i := 0; i < b.N; i++ { + p2.SetBuf(p.Bytes()) + p2.Unmarshal(pbd) + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/clone.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/clone.go new file mode 100644 index 0000000..84bbaf4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/clone.go @@ -0,0 +1,223 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Protocol buffer deep copy and merge. +// TODO: MessageSet and RawMessage. + +package proto + +import ( + "log" + "reflect" + "strings" +) + +// Clone returns a deep copy of a protocol buffer. +func Clone(pb Message) Message { + in := reflect.ValueOf(pb) + if in.IsNil() { + return pb + } + + out := reflect.New(in.Type().Elem()) + // out is empty so a merge is a deep copy. + mergeStruct(out.Elem(), in.Elem()) + return out.Interface().(Message) +} + +// Merge merges src into dst. +// Required and optional fields that are set in src will be set to that value in dst. +// Elements of repeated fields will be appended. +// Merge panics if src and dst are not the same type, or if dst is nil. +func Merge(dst, src Message) { + in := reflect.ValueOf(src) + out := reflect.ValueOf(dst) + if out.IsNil() { + panic("proto: nil destination") + } + if in.Type() != out.Type() { + // Explicit test prior to mergeStruct so that mistyped nils will fail + panic("proto: type mismatch") + } + if in.IsNil() { + // Merging nil into non-nil is a quiet no-op + return + } + mergeStruct(out.Elem(), in.Elem()) +} + +func mergeStruct(out, in reflect.Value) { + sprop := GetProperties(in.Type()) + for i := 0; i < in.NumField(); i++ { + f := in.Type().Field(i) + if strings.HasPrefix(f.Name, "XXX_") { + continue + } + mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i]) + } + + if emIn, ok := in.Addr().Interface().(extendableProto); ok { + emOut := out.Addr().Interface().(extendableProto) + mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap()) + } + + uf := in.FieldByName("XXX_unrecognized") + if !uf.IsValid() { + return + } + uin := uf.Bytes() + if len(uin) > 0 { + out.FieldByName("XXX_unrecognized").SetBytes(append([]byte(nil), uin...)) + } +} + +// mergeAny performs a merge between two values of the same type. +// viaPtr indicates whether the values were indirected through a pointer (implying proto2). +// prop is set if this is a struct field (it may be nil). +func mergeAny(out, in reflect.Value, viaPtr bool, prop *Properties) { + if in.Type() == protoMessageType { + if !in.IsNil() { + if out.IsNil() { + out.Set(reflect.ValueOf(Clone(in.Interface().(Message)))) + } else { + Merge(out.Interface().(Message), in.Interface().(Message)) + } + } + return + } + switch in.Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, + reflect.String, reflect.Uint32, reflect.Uint64: + if !viaPtr && isProto3Zero(in) { + return + } + out.Set(in) + case reflect.Interface: + // Probably a oneof field; copy non-nil values. + if in.IsNil() { + return + } + // Allocate destination if it is not set, or set to a different type. + // Otherwise we will merge as normal. + if out.IsNil() || out.Elem().Type() != in.Elem().Type() { + out.Set(reflect.New(in.Elem().Elem().Type())) // interface -> *T -> T -> new(T) + } + mergeAny(out.Elem(), in.Elem(), false, nil) + case reflect.Map: + if in.Len() == 0 { + return + } + if out.IsNil() { + out.Set(reflect.MakeMap(in.Type())) + } + // For maps with value types of *T or []byte we need to deep copy each value. + elemKind := in.Type().Elem().Kind() + for _, key := range in.MapKeys() { + var val reflect.Value + switch elemKind { + case reflect.Ptr: + val = reflect.New(in.Type().Elem().Elem()) + mergeAny(val, in.MapIndex(key), false, nil) + case reflect.Slice: + val = in.MapIndex(key) + val = reflect.ValueOf(append([]byte{}, val.Bytes()...)) + default: + val = in.MapIndex(key) + } + out.SetMapIndex(key, val) + } + case reflect.Ptr: + if in.IsNil() { + return + } + if out.IsNil() { + out.Set(reflect.New(in.Elem().Type())) + } + mergeAny(out.Elem(), in.Elem(), true, nil) + case reflect.Slice: + if in.IsNil() { + return + } + if in.Type().Elem().Kind() == reflect.Uint8 { + // []byte is a scalar bytes field, not a repeated field. + + // Edge case: if this is in a proto3 message, a zero length + // bytes field is considered the zero value, and should not + // be merged. + if prop != nil && prop.proto3 && in.Len() == 0 { + return + } + + // Make a deep copy. + // Append to []byte{} instead of []byte(nil) so that we never end up + // with a nil result. + out.SetBytes(append([]byte{}, in.Bytes()...)) + return + } + n := in.Len() + if out.IsNil() { + out.Set(reflect.MakeSlice(in.Type(), 0, n)) + } + switch in.Type().Elem().Kind() { + case reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int32, reflect.Int64, + reflect.String, reflect.Uint32, reflect.Uint64: + out.Set(reflect.AppendSlice(out, in)) + default: + for i := 0; i < n; i++ { + x := reflect.Indirect(reflect.New(in.Type().Elem())) + mergeAny(x, in.Index(i), false, nil) + out.Set(reflect.Append(out, x)) + } + } + case reflect.Struct: + mergeStruct(out, in) + default: + // unknown type, so not a protocol buffer + log.Printf("proto: don't know how to copy %v", in) + } +} + +func mergeExtension(out, in map[int32]Extension) { + for extNum, eIn := range in { + eOut := Extension{desc: eIn.desc} + if eIn.value != nil { + v := reflect.New(reflect.TypeOf(eIn.value)).Elem() + mergeAny(v, reflect.ValueOf(eIn.value), false, nil) + eOut.value = v.Interface() + } + if eIn.enc != nil { + eOut.enc = make([]byte, len(eIn.enc)) + copy(eOut.enc, eIn.enc) + } + + out[extNum] = eOut + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/clone_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/clone_test.go new file mode 100644 index 0000000..76720f1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/clone_test.go @@ -0,0 +1,267 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "testing" + + "github.com/golang/protobuf/proto" + + proto3pb "github.com/golang/protobuf/proto/proto3_proto" + pb "github.com/golang/protobuf/proto/testdata" +) + +var cloneTestMessage = &pb.MyMessage{ + Count: proto.Int32(42), + Name: proto.String("Dave"), + Pet: []string{"bunny", "kitty", "horsey"}, + Inner: &pb.InnerMessage{ + Host: proto.String("niles"), + Port: proto.Int32(9099), + Connected: proto.Bool(true), + }, + Others: []*pb.OtherMessage{ + { + Value: []byte("some bytes"), + }, + }, + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(6), + }, + RepBytes: [][]byte{[]byte("sham"), []byte("wow")}, +} + +func init() { + ext := &pb.Ext{ + Data: proto.String("extension"), + } + if err := proto.SetExtension(cloneTestMessage, pb.E_Ext_More, ext); err != nil { + panic("SetExtension: " + err.Error()) + } +} + +func TestClone(t *testing.T) { + m := proto.Clone(cloneTestMessage).(*pb.MyMessage) + if !proto.Equal(m, cloneTestMessage) { + t.Errorf("Clone(%v) = %v", cloneTestMessage, m) + } + + // Verify it was a deep copy. + *m.Inner.Port++ + if proto.Equal(m, cloneTestMessage) { + t.Error("Mutating clone changed the original") + } + // Byte fields and repeated fields should be copied. + if &m.Pet[0] == &cloneTestMessage.Pet[0] { + t.Error("Pet: repeated field not copied") + } + if &m.Others[0] == &cloneTestMessage.Others[0] { + t.Error("Others: repeated field not copied") + } + if &m.Others[0].Value[0] == &cloneTestMessage.Others[0].Value[0] { + t.Error("Others[0].Value: bytes field not copied") + } + if &m.RepBytes[0] == &cloneTestMessage.RepBytes[0] { + t.Error("RepBytes: repeated field not copied") + } + if &m.RepBytes[0][0] == &cloneTestMessage.RepBytes[0][0] { + t.Error("RepBytes[0]: bytes field not copied") + } +} + +func TestCloneNil(t *testing.T) { + var m *pb.MyMessage + if c := proto.Clone(m); !proto.Equal(m, c) { + t.Errorf("Clone(%v) = %v", m, c) + } +} + +var mergeTests = []struct { + src, dst, want proto.Message +}{ + { + src: &pb.MyMessage{ + Count: proto.Int32(42), + }, + dst: &pb.MyMessage{ + Name: proto.String("Dave"), + }, + want: &pb.MyMessage{ + Count: proto.Int32(42), + Name: proto.String("Dave"), + }, + }, + { + src: &pb.MyMessage{ + Inner: &pb.InnerMessage{ + Host: proto.String("hey"), + Connected: proto.Bool(true), + }, + Pet: []string{"horsey"}, + Others: []*pb.OtherMessage{ + { + Value: []byte("some bytes"), + }, + }, + }, + dst: &pb.MyMessage{ + Inner: &pb.InnerMessage{ + Host: proto.String("niles"), + Port: proto.Int32(9099), + }, + Pet: []string{"bunny", "kitty"}, + Others: []*pb.OtherMessage{ + { + Key: proto.Int64(31415926535), + }, + { + // Explicitly test a src=nil field + Inner: nil, + }, + }, + }, + want: &pb.MyMessage{ + Inner: &pb.InnerMessage{ + Host: proto.String("hey"), + Connected: proto.Bool(true), + Port: proto.Int32(9099), + }, + Pet: []string{"bunny", "kitty", "horsey"}, + Others: []*pb.OtherMessage{ + { + Key: proto.Int64(31415926535), + }, + {}, + { + Value: []byte("some bytes"), + }, + }, + }, + }, + { + src: &pb.MyMessage{ + RepBytes: [][]byte{[]byte("wow")}, + }, + dst: &pb.MyMessage{ + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(6), + }, + RepBytes: [][]byte{[]byte("sham")}, + }, + want: &pb.MyMessage{ + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(6), + }, + RepBytes: [][]byte{[]byte("sham"), []byte("wow")}, + }, + }, + // Check that a scalar bytes field replaces rather than appends. + { + src: &pb.OtherMessage{Value: []byte("foo")}, + dst: &pb.OtherMessage{Value: []byte("bar")}, + want: &pb.OtherMessage{Value: []byte("foo")}, + }, + { + src: &pb.MessageWithMap{ + NameMapping: map[int32]string{6: "Nigel"}, + MsgMapping: map[int64]*pb.FloatingPoint{ + 0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)}, + }, + ByteMapping: map[bool][]byte{true: []byte("wowsa")}, + }, + dst: &pb.MessageWithMap{ + NameMapping: map[int32]string{ + 6: "Bruce", // should be overwritten + 7: "Andrew", + }, + }, + want: &pb.MessageWithMap{ + NameMapping: map[int32]string{ + 6: "Nigel", + 7: "Andrew", + }, + MsgMapping: map[int64]*pb.FloatingPoint{ + 0x4001: &pb.FloatingPoint{F: proto.Float64(2.0)}, + }, + ByteMapping: map[bool][]byte{true: []byte("wowsa")}, + }, + }, + // proto3 shouldn't merge zero values, + // in the same way that proto2 shouldn't merge nils. + { + src: &proto3pb.Message{ + Name: "Aaron", + Data: []byte(""), // zero value, but not nil + }, + dst: &proto3pb.Message{ + HeightInCm: 176, + Data: []byte("texas!"), + }, + want: &proto3pb.Message{ + Name: "Aaron", + HeightInCm: 176, + Data: []byte("texas!"), + }, + }, + // Oneof fields should merge by assignment. + { + src: &pb.Communique{ + Union: &pb.Communique_Number{41}, + }, + dst: &pb.Communique{ + Union: &pb.Communique_Name{"Bobby Tables"}, + }, + want: &pb.Communique{ + Union: &pb.Communique_Number{41}, + }, + }, + // Oneof nil is the same as not set. + { + src: &pb.Communique{}, + dst: &pb.Communique{ + Union: &pb.Communique_Name{"Bobby Tables"}, + }, + want: &pb.Communique{ + Union: &pb.Communique_Name{"Bobby Tables"}, + }, + }, +} + +func TestMerge(t *testing.T) { + for _, m := range mergeTests { + got := proto.Clone(m.dst) + proto.Merge(got, m.src) + if !proto.Equal(got, m.want) { + t.Errorf("Merge(%v, %v)\n got %v\nwant %v\n", m.dst, m.src, got, m.want) + } + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/decode.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/decode.go new file mode 100644 index 0000000..8486635 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/decode.go @@ -0,0 +1,863 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for decoding protocol buffer data to construct in-memory representations. + */ + +import ( + "errors" + "fmt" + "io" + "os" + "reflect" +) + +// errOverflow is returned when an integer is too large to be represented. +var errOverflow = errors.New("proto: integer overflow") + +// ErrInternalBadWireType is returned by generated code when an incorrect +// wire type is encountered. It does not get returned to user code. +var ErrInternalBadWireType = errors.New("proto: internal error: bad wiretype for oneof") + +// The fundamental decoders that interpret bytes on the wire. +// Those that take integer types all return uint64 and are +// therefore of type valueDecoder. + +// DecodeVarint reads a varint-encoded integer from the slice. +// It returns the integer and the number of bytes consumed, or +// zero if there is not enough. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func DecodeVarint(buf []byte) (x uint64, n int) { + // x, n already 0 + for shift := uint(0); shift < 64; shift += 7 { + if n >= len(buf) { + return 0, 0 + } + b := uint64(buf[n]) + n++ + x |= (b & 0x7F) << shift + if (b & 0x80) == 0 { + return x, n + } + } + + // The number is too large to represent in a 64-bit value. + return 0, 0 +} + +// DecodeVarint reads a varint-encoded integer from the Buffer. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func (p *Buffer) DecodeVarint() (x uint64, err error) { + // x, err already 0 + + i := p.index + l := len(p.buf) + + for shift := uint(0); shift < 64; shift += 7 { + if i >= l { + err = io.ErrUnexpectedEOF + return + } + b := p.buf[i] + i++ + x |= (uint64(b) & 0x7F) << shift + if b < 0x80 { + p.index = i + return + } + } + + // The number is too large to represent in a 64-bit value. + err = errOverflow + return +} + +// DecodeFixed64 reads a 64-bit integer from the Buffer. +// This is the format for the +// fixed64, sfixed64, and double protocol buffer types. +func (p *Buffer) DecodeFixed64() (x uint64, err error) { + // x, err already 0 + i := p.index + 8 + if i < 0 || i > len(p.buf) { + err = io.ErrUnexpectedEOF + return + } + p.index = i + + x = uint64(p.buf[i-8]) + x |= uint64(p.buf[i-7]) << 8 + x |= uint64(p.buf[i-6]) << 16 + x |= uint64(p.buf[i-5]) << 24 + x |= uint64(p.buf[i-4]) << 32 + x |= uint64(p.buf[i-3]) << 40 + x |= uint64(p.buf[i-2]) << 48 + x |= uint64(p.buf[i-1]) << 56 + return +} + +// DecodeFixed32 reads a 32-bit integer from the Buffer. +// This is the format for the +// fixed32, sfixed32, and float protocol buffer types. +func (p *Buffer) DecodeFixed32() (x uint64, err error) { + // x, err already 0 + i := p.index + 4 + if i < 0 || i > len(p.buf) { + err = io.ErrUnexpectedEOF + return + } + p.index = i + + x = uint64(p.buf[i-4]) + x |= uint64(p.buf[i-3]) << 8 + x |= uint64(p.buf[i-2]) << 16 + x |= uint64(p.buf[i-1]) << 24 + return +} + +// DecodeZigzag64 reads a zigzag-encoded 64-bit integer +// from the Buffer. +// This is the format used for the sint64 protocol buffer type. +func (p *Buffer) DecodeZigzag64() (x uint64, err error) { + x, err = p.DecodeVarint() + if err != nil { + return + } + x = (x >> 1) ^ uint64((int64(x&1)<<63)>>63) + return +} + +// DecodeZigzag32 reads a zigzag-encoded 32-bit integer +// from the Buffer. +// This is the format used for the sint32 protocol buffer type. +func (p *Buffer) DecodeZigzag32() (x uint64, err error) { + x, err = p.DecodeVarint() + if err != nil { + return + } + x = uint64((uint32(x) >> 1) ^ uint32((int32(x&1)<<31)>>31)) + return +} + +// These are not ValueDecoders: they produce an array of bytes or a string. +// bytes, embedded messages + +// DecodeRawBytes reads a count-delimited byte buffer from the Buffer. +// This is the format used for the bytes protocol buffer +// type and for embedded messages. +func (p *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) { + n, err := p.DecodeVarint() + if err != nil { + return nil, err + } + + nb := int(n) + if nb < 0 { + return nil, fmt.Errorf("proto: bad byte length %d", nb) + } + end := p.index + nb + if end < p.index || end > len(p.buf) { + return nil, io.ErrUnexpectedEOF + } + + if !alloc { + // todo: check if can get more uses of alloc=false + buf = p.buf[p.index:end] + p.index += nb + return + } + + buf = make([]byte, nb) + copy(buf, p.buf[p.index:]) + p.index += nb + return +} + +// DecodeStringBytes reads an encoded string from the Buffer. +// This is the format used for the proto2 string type. +func (p *Buffer) DecodeStringBytes() (s string, err error) { + buf, err := p.DecodeRawBytes(false) + if err != nil { + return + } + return string(buf), nil +} + +// Skip the next item in the buffer. Its wire type is decoded and presented as an argument. +// If the protocol buffer has extensions, and the field matches, add it as an extension. +// Otherwise, if the XXX_unrecognized field exists, append the skipped data there. +func (o *Buffer) skipAndSave(t reflect.Type, tag, wire int, base structPointer, unrecField field) error { + oi := o.index + + err := o.skip(t, tag, wire) + if err != nil { + return err + } + + if !unrecField.IsValid() { + return nil + } + + ptr := structPointer_Bytes(base, unrecField) + + // Add the skipped field to struct field + obuf := o.buf + + o.buf = *ptr + o.EncodeVarint(uint64(tag<<3 | wire)) + *ptr = append(o.buf, obuf[oi:o.index]...) + + o.buf = obuf + + return nil +} + +// Skip the next item in the buffer. Its wire type is decoded and presented as an argument. +func (o *Buffer) skip(t reflect.Type, tag, wire int) error { + + var u uint64 + var err error + + switch wire { + case WireVarint: + _, err = o.DecodeVarint() + case WireFixed64: + _, err = o.DecodeFixed64() + case WireBytes: + _, err = o.DecodeRawBytes(false) + case WireFixed32: + _, err = o.DecodeFixed32() + case WireStartGroup: + for { + u, err = o.DecodeVarint() + if err != nil { + break + } + fwire := int(u & 0x7) + if fwire == WireEndGroup { + break + } + ftag := int(u >> 3) + err = o.skip(t, ftag, fwire) + if err != nil { + break + } + } + default: + err = fmt.Errorf("proto: can't skip unknown wire type %d for %s", wire, t) + } + return err +} + +// Unmarshaler is the interface representing objects that can +// unmarshal themselves. The method should reset the receiver before +// decoding starts. The argument points to data that may be +// overwritten, so implementations should not keep references to the +// buffer. +type Unmarshaler interface { + Unmarshal([]byte) error +} + +// Unmarshal parses the protocol buffer representation in buf and places the +// decoded result in pb. If the struct underlying pb does not match +// the data in buf, the results can be unpredictable. +// +// Unmarshal resets pb before starting to unmarshal, so any +// existing data in pb is always removed. Use UnmarshalMerge +// to preserve and append to existing data. +func Unmarshal(buf []byte, pb Message) error { + pb.Reset() + return UnmarshalMerge(buf, pb) +} + +// UnmarshalMerge parses the protocol buffer representation in buf and +// writes the decoded result to pb. If the struct underlying pb does not match +// the data in buf, the results can be unpredictable. +// +// UnmarshalMerge merges into existing data in pb. +// Most code should use Unmarshal instead. +func UnmarshalMerge(buf []byte, pb Message) error { + // If the object can unmarshal itself, let it. + if u, ok := pb.(Unmarshaler); ok { + return u.Unmarshal(buf) + } + return NewBuffer(buf).Unmarshal(pb) +} + +// DecodeMessage reads a count-delimited message from the Buffer. +func (p *Buffer) DecodeMessage(pb Message) error { + enc, err := p.DecodeRawBytes(false) + if err != nil { + return err + } + return NewBuffer(enc).Unmarshal(pb) +} + +// DecodeGroup reads a tag-delimited group from the Buffer. +func (p *Buffer) DecodeGroup(pb Message) error { + typ, base, err := getbase(pb) + if err != nil { + return err + } + return p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), true, base) +} + +// Unmarshal parses the protocol buffer representation in the +// Buffer and places the decoded result in pb. If the struct +// underlying pb does not match the data in the buffer, the results can be +// unpredictable. +func (p *Buffer) Unmarshal(pb Message) error { + // If the object can unmarshal itself, let it. + if u, ok := pb.(Unmarshaler); ok { + err := u.Unmarshal(p.buf[p.index:]) + p.index = len(p.buf) + return err + } + + typ, base, err := getbase(pb) + if err != nil { + return err + } + + err = p.unmarshalType(typ.Elem(), GetProperties(typ.Elem()), false, base) + + if collectStats { + stats.Decode++ + } + + return err +} + +// unmarshalType does the work of unmarshaling a structure. +func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error { + var state errorState + required, reqFields := prop.reqCount, uint64(0) + + var err error + for err == nil && o.index < len(o.buf) { + oi := o.index + var u uint64 + u, err = o.DecodeVarint() + if err != nil { + break + } + wire := int(u & 0x7) + if wire == WireEndGroup { + if is_group { + return nil // input is satisfied + } + return fmt.Errorf("proto: %s: wiretype end group for non-group", st) + } + tag := int(u >> 3) + if tag <= 0 { + return fmt.Errorf("proto: %s: illegal tag %d (wire type %d)", st, tag, wire) + } + fieldnum, ok := prop.decoderTags.get(tag) + if !ok { + // Maybe it's an extension? + if prop.extendable { + if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) { + if err = o.skip(st, tag, wire); err == nil { + ext := e.ExtensionMap()[int32(tag)] // may be missing + ext.enc = append(ext.enc, o.buf[oi:o.index]...) + e.ExtensionMap()[int32(tag)] = ext + } + continue + } + } + // Maybe it's a oneof? + if prop.oneofUnmarshaler != nil { + m := structPointer_Interface(base, st).(Message) + // First return value indicates whether tag is a oneof field. + ok, err = prop.oneofUnmarshaler(m, tag, wire, o) + if err == ErrInternalBadWireType { + // Map the error to something more descriptive. + // Do the formatting here to save generated code space. + err = fmt.Errorf("bad wiretype for oneof field in %T", m) + } + if ok { + continue + } + } + err = o.skipAndSave(st, tag, wire, base, prop.unrecField) + continue + } + p := prop.Prop[fieldnum] + + if p.dec == nil { + fmt.Fprintf(os.Stderr, "proto: no protobuf decoder for %s.%s\n", st, st.Field(fieldnum).Name) + continue + } + dec := p.dec + if wire != WireStartGroup && wire != p.WireType { + if wire == WireBytes && p.packedDec != nil { + // a packable field + dec = p.packedDec + } else { + err = fmt.Errorf("proto: bad wiretype for field %s.%s: got wiretype %d, want %d", st, st.Field(fieldnum).Name, wire, p.WireType) + continue + } + } + decErr := dec(o, p, base) + if decErr != nil && !state.shouldContinue(decErr, p) { + err = decErr + } + if err == nil && p.Required { + // Successfully decoded a required field. + if tag <= 64 { + // use bitmap for fields 1-64 to catch field reuse. + var mask uint64 = 1 << uint64(tag-1) + if reqFields&mask == 0 { + // new required field + reqFields |= mask + required-- + } + } else { + // This is imprecise. It can be fooled by a required field + // with a tag > 64 that is encoded twice; that's very rare. + // A fully correct implementation would require allocating + // a data structure, which we would like to avoid. + required-- + } + } + } + if err == nil { + if is_group { + return io.ErrUnexpectedEOF + } + if state.err != nil { + return state.err + } + if required > 0 { + // Not enough information to determine the exact field. If we use extra + // CPU, we could determine the field only if the missing required field + // has a tag <= 64 and we check reqFields. + return &RequiredNotSetError{"{Unknown}"} + } + } + return err +} + +// Individual type decoders +// For each, +// u is the decoded value, +// v is a pointer to the field (pointer) in the struct + +// Sizes of the pools to allocate inside the Buffer. +// The goal is modest amortization and allocation +// on at least 16-byte boundaries. +const ( + boolPoolSize = 16 + uint32PoolSize = 8 + uint64PoolSize = 4 +) + +// Decode a bool. +func (o *Buffer) dec_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + if len(o.bools) == 0 { + o.bools = make([]bool, boolPoolSize) + } + o.bools[0] = u != 0 + *structPointer_Bool(base, p.field) = &o.bools[0] + o.bools = o.bools[1:] + return nil +} + +func (o *Buffer) dec_proto3_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + *structPointer_BoolVal(base, p.field) = u != 0 + return nil +} + +// Decode an int32. +func (o *Buffer) dec_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word32_Set(structPointer_Word32(base, p.field), o, uint32(u)) + return nil +} + +func (o *Buffer) dec_proto3_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word32Val_Set(structPointer_Word32Val(base, p.field), uint32(u)) + return nil +} + +// Decode an int64. +func (o *Buffer) dec_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word64_Set(structPointer_Word64(base, p.field), o, u) + return nil +} + +func (o *Buffer) dec_proto3_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + word64Val_Set(structPointer_Word64Val(base, p.field), o, u) + return nil +} + +// Decode a string. +func (o *Buffer) dec_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + *structPointer_String(base, p.field) = &s + return nil +} + +func (o *Buffer) dec_proto3_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + *structPointer_StringVal(base, p.field) = s + return nil +} + +// Decode a slice of bytes ([]byte). +func (o *Buffer) dec_slice_byte(p *Properties, base structPointer) error { + b, err := o.DecodeRawBytes(true) + if err != nil { + return err + } + *structPointer_Bytes(base, p.field) = b + return nil +} + +// Decode a slice of bools ([]bool). +func (o *Buffer) dec_slice_bool(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + v := structPointer_BoolSlice(base, p.field) + *v = append(*v, u != 0) + return nil +} + +// Decode a slice of bools ([]bool) in packed format. +func (o *Buffer) dec_slice_packed_bool(p *Properties, base structPointer) error { + v := structPointer_BoolSlice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded bools + + y := *v + for i := 0; i < nb; i++ { + u, err := p.valDec(o) + if err != nil { + return err + } + y = append(y, u != 0) + } + + *v = y + return nil +} + +// Decode a slice of int32s ([]int32). +func (o *Buffer) dec_slice_int32(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + structPointer_Word32Slice(base, p.field).Append(uint32(u)) + return nil +} + +// Decode a slice of int32s ([]int32) in packed format. +func (o *Buffer) dec_slice_packed_int32(p *Properties, base structPointer) error { + v := structPointer_Word32Slice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded int32s + + fin := o.index + nb + if fin < o.index { + return errOverflow + } + for o.index < fin { + u, err := p.valDec(o) + if err != nil { + return err + } + v.Append(uint32(u)) + } + return nil +} + +// Decode a slice of int64s ([]int64). +func (o *Buffer) dec_slice_int64(p *Properties, base structPointer) error { + u, err := p.valDec(o) + if err != nil { + return err + } + + structPointer_Word64Slice(base, p.field).Append(u) + return nil +} + +// Decode a slice of int64s ([]int64) in packed format. +func (o *Buffer) dec_slice_packed_int64(p *Properties, base structPointer) error { + v := structPointer_Word64Slice(base, p.field) + + nn, err := o.DecodeVarint() + if err != nil { + return err + } + nb := int(nn) // number of bytes of encoded int64s + + fin := o.index + nb + if fin < o.index { + return errOverflow + } + for o.index < fin { + u, err := p.valDec(o) + if err != nil { + return err + } + v.Append(u) + } + return nil +} + +// Decode a slice of strings ([]string). +func (o *Buffer) dec_slice_string(p *Properties, base structPointer) error { + s, err := o.DecodeStringBytes() + if err != nil { + return err + } + v := structPointer_StringSlice(base, p.field) + *v = append(*v, s) + return nil +} + +// Decode a slice of slice of bytes ([][]byte). +func (o *Buffer) dec_slice_slice_byte(p *Properties, base structPointer) error { + b, err := o.DecodeRawBytes(true) + if err != nil { + return err + } + v := structPointer_BytesSlice(base, p.field) + *v = append(*v, b) + return nil +} + +// Decode a map field. +func (o *Buffer) dec_new_map(p *Properties, base structPointer) error { + raw, err := o.DecodeRawBytes(false) + if err != nil { + return err + } + oi := o.index // index at the end of this map entry + o.index -= len(raw) // move buffer back to start of map entry + + mptr := structPointer_NewAt(base, p.field, p.mtype) // *map[K]V + if mptr.Elem().IsNil() { + mptr.Elem().Set(reflect.MakeMap(mptr.Type().Elem())) + } + v := mptr.Elem() // map[K]V + + // Prepare addressable doubly-indirect placeholders for the key and value types. + // See enc_new_map for why. + keyptr := reflect.New(reflect.PtrTo(p.mtype.Key())).Elem() // addressable *K + keybase := toStructPointer(keyptr.Addr()) // **K + + var valbase structPointer + var valptr reflect.Value + switch p.mtype.Elem().Kind() { + case reflect.Slice: + // []byte + var dummy []byte + valptr = reflect.ValueOf(&dummy) // *[]byte + valbase = toStructPointer(valptr) // *[]byte + case reflect.Ptr: + // message; valptr is **Msg; need to allocate the intermediate pointer + valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V + valptr.Set(reflect.New(valptr.Type().Elem())) + valbase = toStructPointer(valptr) + default: + // everything else + valptr = reflect.New(reflect.PtrTo(p.mtype.Elem())).Elem() // addressable *V + valbase = toStructPointer(valptr.Addr()) // **V + } + + // Decode. + // This parses a restricted wire format, namely the encoding of a message + // with two fields. See enc_new_map for the format. + for o.index < oi { + // tagcode for key and value properties are always a single byte + // because they have tags 1 and 2. + tagcode := o.buf[o.index] + o.index++ + switch tagcode { + case p.mkeyprop.tagcode[0]: + if err := p.mkeyprop.dec(o, p.mkeyprop, keybase); err != nil { + return err + } + case p.mvalprop.tagcode[0]: + if err := p.mvalprop.dec(o, p.mvalprop, valbase); err != nil { + return err + } + default: + // TODO: Should we silently skip this instead? + return fmt.Errorf("proto: bad map data tag %d", raw[0]) + } + } + keyelem, valelem := keyptr.Elem(), valptr.Elem() + if !keyelem.IsValid() || !valelem.IsValid() { + // We did not decode the key or the value in the map entry. + // Either way, it's an invalid map entry. + return fmt.Errorf("proto: bad map data: missing key/val") + } + + v.SetMapIndex(keyelem, valelem) + return nil +} + +// Decode a group. +func (o *Buffer) dec_struct_group(p *Properties, base structPointer) error { + bas := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(bas) { + // allocate new nested message + bas = toStructPointer(reflect.New(p.stype)) + structPointer_SetStructPointer(base, p.field, bas) + } + return o.unmarshalType(p.stype, p.sprop, true, bas) +} + +// Decode an embedded message. +func (o *Buffer) dec_struct_message(p *Properties, base structPointer) (err error) { + raw, e := o.DecodeRawBytes(false) + if e != nil { + return e + } + + bas := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(bas) { + // allocate new nested message + bas = toStructPointer(reflect.New(p.stype)) + structPointer_SetStructPointer(base, p.field, bas) + } + + // If the object can unmarshal itself, let it. + if p.isUnmarshaler { + iv := structPointer_Interface(bas, p.stype) + return iv.(Unmarshaler).Unmarshal(raw) + } + + obuf := o.buf + oi := o.index + o.buf = raw + o.index = 0 + + err = o.unmarshalType(p.stype, p.sprop, false, bas) + o.buf = obuf + o.index = oi + + return err +} + +// Decode a slice of embedded messages. +func (o *Buffer) dec_slice_struct_message(p *Properties, base structPointer) error { + return o.dec_slice_struct(p, false, base) +} + +// Decode a slice of embedded groups. +func (o *Buffer) dec_slice_struct_group(p *Properties, base structPointer) error { + return o.dec_slice_struct(p, true, base) +} + +// Decode a slice of structs ([]*struct). +func (o *Buffer) dec_slice_struct(p *Properties, is_group bool, base structPointer) error { + v := reflect.New(p.stype) + bas := toStructPointer(v) + structPointer_StructPointerSlice(base, p.field).Append(bas) + + if is_group { + err := o.unmarshalType(p.stype, p.sprop, is_group, bas) + return err + } + + raw, err := o.DecodeRawBytes(false) + if err != nil { + return err + } + + // If the object can unmarshal itself, let it. + if p.isUnmarshaler { + iv := v.Interface() + return iv.(Unmarshaler).Unmarshal(raw) + } + + obuf := o.buf + oi := o.index + o.buf = raw + o.index = 0 + + err = o.unmarshalType(p.stype, p.sprop, is_group, bas) + + o.buf = obuf + o.index = oi + + return err +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/encode.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/encode.go new file mode 100644 index 0000000..fe48cc7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/encode.go @@ -0,0 +1,1336 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for encoding data into the wire format for protocol buffers. + */ + +import ( + "errors" + "fmt" + "reflect" + "sort" +) + +// RequiredNotSetError is the error returned if Marshal is called with +// a protocol buffer struct whose required fields have not +// all been initialized. It is also the error returned if Unmarshal is +// called with an encoded protocol buffer that does not include all the +// required fields. +// +// When printed, RequiredNotSetError reports the first unset required field in a +// message. If the field cannot be precisely determined, it is reported as +// "{Unknown}". +type RequiredNotSetError struct { + field string +} + +func (e *RequiredNotSetError) Error() string { + return fmt.Sprintf("proto: required field %q not set", e.field) +} + +var ( + // errRepeatedHasNil is the error returned if Marshal is called with + // a struct with a repeated field containing a nil element. + errRepeatedHasNil = errors.New("proto: repeated field has nil element") + + // ErrNil is the error returned if Marshal is called with nil. + ErrNil = errors.New("proto: Marshal called with nil") +) + +// The fundamental encoders that put bytes on the wire. +// Those that take integer types all accept uint64 and are +// therefore of type valueEncoder. + +const maxVarintBytes = 10 // maximum length of a varint + +// EncodeVarint returns the varint encoding of x. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +// Not used by the package itself, but helpful to clients +// wishing to use the same encoding. +func EncodeVarint(x uint64) []byte { + var buf [maxVarintBytes]byte + var n int + for n = 0; x > 127; n++ { + buf[n] = 0x80 | uint8(x&0x7F) + x >>= 7 + } + buf[n] = uint8(x) + n++ + return buf[0:n] +} + +// EncodeVarint writes a varint-encoded integer to the Buffer. +// This is the format for the +// int32, int64, uint32, uint64, bool, and enum +// protocol buffer types. +func (p *Buffer) EncodeVarint(x uint64) error { + for x >= 1<<7 { + p.buf = append(p.buf, uint8(x&0x7f|0x80)) + x >>= 7 + } + p.buf = append(p.buf, uint8(x)) + return nil +} + +func sizeVarint(x uint64) (n int) { + for { + n++ + x >>= 7 + if x == 0 { + break + } + } + return n +} + +// EncodeFixed64 writes a 64-bit integer to the Buffer. +// This is the format for the +// fixed64, sfixed64, and double protocol buffer types. +func (p *Buffer) EncodeFixed64(x uint64) error { + p.buf = append(p.buf, + uint8(x), + uint8(x>>8), + uint8(x>>16), + uint8(x>>24), + uint8(x>>32), + uint8(x>>40), + uint8(x>>48), + uint8(x>>56)) + return nil +} + +func sizeFixed64(x uint64) int { + return 8 +} + +// EncodeFixed32 writes a 32-bit integer to the Buffer. +// This is the format for the +// fixed32, sfixed32, and float protocol buffer types. +func (p *Buffer) EncodeFixed32(x uint64) error { + p.buf = append(p.buf, + uint8(x), + uint8(x>>8), + uint8(x>>16), + uint8(x>>24)) + return nil +} + +func sizeFixed32(x uint64) int { + return 4 +} + +// EncodeZigzag64 writes a zigzag-encoded 64-bit integer +// to the Buffer. +// This is the format used for the sint64 protocol buffer type. +func (p *Buffer) EncodeZigzag64(x uint64) error { + // use signed number to get arithmetic right shift. + return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +func sizeZigzag64(x uint64) int { + return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63)))) +} + +// EncodeZigzag32 writes a zigzag-encoded 32-bit integer +// to the Buffer. +// This is the format used for the sint32 protocol buffer type. +func (p *Buffer) EncodeZigzag32(x uint64) error { + // use signed number to get arithmetic right shift. + return p.EncodeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31)))) +} + +func sizeZigzag32(x uint64) int { + return sizeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31)))) +} + +// EncodeRawBytes writes a count-delimited byte buffer to the Buffer. +// This is the format used for the bytes protocol buffer +// type and for embedded messages. +func (p *Buffer) EncodeRawBytes(b []byte) error { + p.EncodeVarint(uint64(len(b))) + p.buf = append(p.buf, b...) + return nil +} + +func sizeRawBytes(b []byte) int { + return sizeVarint(uint64(len(b))) + + len(b) +} + +// EncodeStringBytes writes an encoded string to the Buffer. +// This is the format used for the proto2 string type. +func (p *Buffer) EncodeStringBytes(s string) error { + p.EncodeVarint(uint64(len(s))) + p.buf = append(p.buf, s...) + return nil +} + +func sizeStringBytes(s string) int { + return sizeVarint(uint64(len(s))) + + len(s) +} + +// Marshaler is the interface representing objects that can marshal themselves. +type Marshaler interface { + Marshal() ([]byte, error) +} + +// Marshal takes the protocol buffer +// and encodes it into the wire format, returning the data. +func Marshal(pb Message) ([]byte, error) { + // Can the object marshal itself? + if m, ok := pb.(Marshaler); ok { + return m.Marshal() + } + p := NewBuffer(nil) + err := p.Marshal(pb) + var state errorState + if err != nil && !state.shouldContinue(err, nil) { + return nil, err + } + if p.buf == nil && err == nil { + // Return a non-nil slice on success. + return []byte{}, nil + } + return p.buf, err +} + +// EncodeMessage writes the protocol buffer to the Buffer, +// prefixed by a varint-encoded length. +func (p *Buffer) EncodeMessage(pb Message) error { + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return ErrNil + } + if err == nil { + var state errorState + err = p.enc_len_struct(GetProperties(t.Elem()), base, &state) + } + return err +} + +// Marshal takes the protocol buffer +// and encodes it into the wire format, writing the result to the +// Buffer. +func (p *Buffer) Marshal(pb Message) error { + // Can the object marshal itself? + if m, ok := pb.(Marshaler); ok { + data, err := m.Marshal() + if err != nil { + return err + } + p.buf = append(p.buf, data...) + return nil + } + + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return ErrNil + } + if err == nil { + err = p.enc_struct(GetProperties(t.Elem()), base) + } + + if collectStats { + stats.Encode++ + } + + return err +} + +// Size returns the encoded size of a protocol buffer. +func Size(pb Message) (n int) { + // Can the object marshal itself? If so, Size is slow. + // TODO: add Size to Marshaler, or add a Sizer interface. + if m, ok := pb.(Marshaler); ok { + b, _ := m.Marshal() + return len(b) + } + + t, base, err := getbase(pb) + if structPointer_IsNil(base) { + return 0 + } + if err == nil { + n = size_struct(GetProperties(t.Elem()), base) + } + + if collectStats { + stats.Size++ + } + + return +} + +// Individual type encoders. + +// Encode a bool. +func (o *Buffer) enc_bool(p *Properties, base structPointer) error { + v := *structPointer_Bool(base, p.field) + if v == nil { + return ErrNil + } + x := 0 + if *v { + x = 1 + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func (o *Buffer) enc_proto3_bool(p *Properties, base structPointer) error { + v := *structPointer_BoolVal(base, p.field) + if !v { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, 1) + return nil +} + +func size_bool(p *Properties, base structPointer) int { + v := *structPointer_Bool(base, p.field) + if v == nil { + return 0 + } + return len(p.tagcode) + 1 // each bool takes exactly one byte +} + +func size_proto3_bool(p *Properties, base structPointer) int { + v := *structPointer_BoolVal(base, p.field) + if !v { + return 0 + } + return len(p.tagcode) + 1 // each bool takes exactly one byte +} + +// Encode an int32. +func (o *Buffer) enc_int32(p *Properties, base structPointer) error { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return ErrNil + } + x := int32(word32_Get(v)) // permit sign extension to use full 64-bit range + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func (o *Buffer) enc_proto3_int32(p *Properties, base structPointer) error { + v := structPointer_Word32Val(base, p.field) + x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range + if x == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func size_int32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return 0 + } + x := int32(word32_Get(v)) // permit sign extension to use full 64-bit range + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +func size_proto3_int32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32Val(base, p.field) + x := int32(word32Val_Get(v)) // permit sign extension to use full 64-bit range + if x == 0 { + return 0 + } + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +// Encode a uint32. +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_uint32(p *Properties, base structPointer) error { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return ErrNil + } + x := word32_Get(v) + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func (o *Buffer) enc_proto3_uint32(p *Properties, base structPointer) error { + v := structPointer_Word32Val(base, p.field) + x := word32Val_Get(v) + if x == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, uint64(x)) + return nil +} + +func size_uint32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32(base, p.field) + if word32_IsNil(v) { + return 0 + } + x := word32_Get(v) + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +func size_proto3_uint32(p *Properties, base structPointer) (n int) { + v := structPointer_Word32Val(base, p.field) + x := word32Val_Get(v) + if x == 0 { + return 0 + } + n += len(p.tagcode) + n += p.valSize(uint64(x)) + return +} + +// Encode an int64. +func (o *Buffer) enc_int64(p *Properties, base structPointer) error { + v := structPointer_Word64(base, p.field) + if word64_IsNil(v) { + return ErrNil + } + x := word64_Get(v) + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, x) + return nil +} + +func (o *Buffer) enc_proto3_int64(p *Properties, base structPointer) error { + v := structPointer_Word64Val(base, p.field) + x := word64Val_Get(v) + if x == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, x) + return nil +} + +func size_int64(p *Properties, base structPointer) (n int) { + v := structPointer_Word64(base, p.field) + if word64_IsNil(v) { + return 0 + } + x := word64_Get(v) + n += len(p.tagcode) + n += p.valSize(x) + return +} + +func size_proto3_int64(p *Properties, base structPointer) (n int) { + v := structPointer_Word64Val(base, p.field) + x := word64Val_Get(v) + if x == 0 { + return 0 + } + n += len(p.tagcode) + n += p.valSize(x) + return +} + +// Encode a string. +func (o *Buffer) enc_string(p *Properties, base structPointer) error { + v := *structPointer_String(base, p.field) + if v == nil { + return ErrNil + } + x := *v + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(x) + return nil +} + +func (o *Buffer) enc_proto3_string(p *Properties, base structPointer) error { + v := *structPointer_StringVal(base, p.field) + if v == "" { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(v) + return nil +} + +func size_string(p *Properties, base structPointer) (n int) { + v := *structPointer_String(base, p.field) + if v == nil { + return 0 + } + x := *v + n += len(p.tagcode) + n += sizeStringBytes(x) + return +} + +func size_proto3_string(p *Properties, base structPointer) (n int) { + v := *structPointer_StringVal(base, p.field) + if v == "" { + return 0 + } + n += len(p.tagcode) + n += sizeStringBytes(v) + return +} + +// All protocol buffer fields are nillable, but be careful. +func isNil(v reflect.Value) bool { + switch v.Kind() { + case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return v.IsNil() + } + return false +} + +// Encode a message struct. +func (o *Buffer) enc_struct_message(p *Properties, base structPointer) error { + var state errorState + structp := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(structp) { + return ErrNil + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, err := m.Marshal() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(data) + return state.err + } + + o.buf = append(o.buf, p.tagcode...) + return o.enc_len_struct(p.sprop, structp, &state) +} + +func size_struct_message(p *Properties, base structPointer) int { + structp := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(structp) { + return 0 + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, _ := m.Marshal() + n0 := len(p.tagcode) + n1 := sizeRawBytes(data) + return n0 + n1 + } + + n0 := len(p.tagcode) + n1 := size_struct(p.sprop, structp) + n2 := sizeVarint(uint64(n1)) // size of encoded length + return n0 + n1 + n2 +} + +// Encode a group struct. +func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error { + var state errorState + b := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(b) { + return ErrNil + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) + err := o.enc_struct(p.sprop, b) + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup)) + return state.err +} + +func size_struct_group(p *Properties, base structPointer) (n int) { + b := structPointer_GetStructPointer(base, p.field) + if structPointer_IsNil(b) { + return 0 + } + + n += sizeVarint(uint64((p.Tag << 3) | WireStartGroup)) + n += size_struct(p.sprop, b) + n += sizeVarint(uint64((p.Tag << 3) | WireEndGroup)) + return +} + +// Encode a slice of bools ([]bool). +func (o *Buffer) enc_slice_bool(p *Properties, base structPointer) error { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return ErrNil + } + for _, x := range s { + o.buf = append(o.buf, p.tagcode...) + v := uint64(0) + if x { + v = 1 + } + p.valEnc(o, v) + } + return nil +} + +func size_slice_bool(p *Properties, base structPointer) int { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return 0 + } + return l * (len(p.tagcode) + 1) // each bool takes exactly one byte +} + +// Encode a slice of bools ([]bool) in packed format. +func (o *Buffer) enc_slice_packed_bool(p *Properties, base structPointer) error { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(l)) // each bool takes exactly one byte + for _, x := range s { + v := uint64(0) + if x { + v = 1 + } + p.valEnc(o, v) + } + return nil +} + +func size_slice_packed_bool(p *Properties, base structPointer) (n int) { + s := *structPointer_BoolSlice(base, p.field) + l := len(s) + if l == 0 { + return 0 + } + n += len(p.tagcode) + n += sizeVarint(uint64(l)) + n += l // each bool takes exactly one byte + return +} + +// Encode a slice of bytes ([]byte). +func (o *Buffer) enc_slice_byte(p *Properties, base structPointer) error { + s := *structPointer_Bytes(base, p.field) + if s == nil { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(s) + return nil +} + +func (o *Buffer) enc_proto3_slice_byte(p *Properties, base structPointer) error { + s := *structPointer_Bytes(base, p.field) + if len(s) == 0 { + return ErrNil + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(s) + return nil +} + +func size_slice_byte(p *Properties, base structPointer) (n int) { + s := *structPointer_Bytes(base, p.field) + if s == nil { + return 0 + } + n += len(p.tagcode) + n += sizeRawBytes(s) + return +} + +func size_proto3_slice_byte(p *Properties, base structPointer) (n int) { + s := *structPointer_Bytes(base, p.field) + if len(s) == 0 { + return 0 + } + n += len(p.tagcode) + n += sizeRawBytes(s) + return +} + +// Encode a slice of int32s ([]int32). +func (o *Buffer) enc_slice_int32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + p.valEnc(o, uint64(x)) + } + return nil +} + +func size_slice_int32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + n += p.valSize(uint64(x)) + } + return +} + +// Encode a slice of int32s ([]int32) in packed format. +func (o *Buffer) enc_slice_packed_int32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + p.valEnc(buf, uint64(x)) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_int32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + x := int32(s.Index(i)) // permit sign extension to use full 64-bit range + bufSize += p.valSize(uint64(x)) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of uint32s ([]uint32). +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_slice_uint32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + x := s.Index(i) + p.valEnc(o, uint64(x)) + } + return nil +} + +func size_slice_uint32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + x := s.Index(i) + n += p.valSize(uint64(x)) + } + return +} + +// Encode a slice of uint32s ([]uint32) in packed format. +// Exactly the same as int32, except for no sign extension. +func (o *Buffer) enc_slice_packed_uint32(p *Properties, base structPointer) error { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + p.valEnc(buf, uint64(s.Index(i))) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_uint32(p *Properties, base structPointer) (n int) { + s := structPointer_Word32Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + bufSize += p.valSize(uint64(s.Index(i))) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of int64s ([]int64). +func (o *Buffer) enc_slice_int64(p *Properties, base structPointer) error { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + p.valEnc(o, s.Index(i)) + } + return nil +} + +func size_slice_int64(p *Properties, base structPointer) (n int) { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + for i := 0; i < l; i++ { + n += len(p.tagcode) + n += p.valSize(s.Index(i)) + } + return +} + +// Encode a slice of int64s ([]int64) in packed format. +func (o *Buffer) enc_slice_packed_int64(p *Properties, base structPointer) error { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return ErrNil + } + // TODO: Reuse a Buffer. + buf := NewBuffer(nil) + for i := 0; i < l; i++ { + p.valEnc(buf, s.Index(i)) + } + + o.buf = append(o.buf, p.tagcode...) + o.EncodeVarint(uint64(len(buf.buf))) + o.buf = append(o.buf, buf.buf...) + return nil +} + +func size_slice_packed_int64(p *Properties, base structPointer) (n int) { + s := structPointer_Word64Slice(base, p.field) + l := s.Len() + if l == 0 { + return 0 + } + var bufSize int + for i := 0; i < l; i++ { + bufSize += p.valSize(s.Index(i)) + } + + n += len(p.tagcode) + n += sizeVarint(uint64(bufSize)) + n += bufSize + return +} + +// Encode a slice of slice of bytes ([][]byte). +func (o *Buffer) enc_slice_slice_byte(p *Properties, base structPointer) error { + ss := *structPointer_BytesSlice(base, p.field) + l := len(ss) + if l == 0 { + return ErrNil + } + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(ss[i]) + } + return nil +} + +func size_slice_slice_byte(p *Properties, base structPointer) (n int) { + ss := *structPointer_BytesSlice(base, p.field) + l := len(ss) + if l == 0 { + return 0 + } + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + n += sizeRawBytes(ss[i]) + } + return +} + +// Encode a slice of strings ([]string). +func (o *Buffer) enc_slice_string(p *Properties, base structPointer) error { + ss := *structPointer_StringSlice(base, p.field) + l := len(ss) + for i := 0; i < l; i++ { + o.buf = append(o.buf, p.tagcode...) + o.EncodeStringBytes(ss[i]) + } + return nil +} + +func size_slice_string(p *Properties, base structPointer) (n int) { + ss := *structPointer_StringSlice(base, p.field) + l := len(ss) + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + n += sizeStringBytes(ss[i]) + } + return +} + +// Encode a slice of message structs ([]*struct). +func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) error { + var state errorState + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + for i := 0; i < l; i++ { + structp := s.Index(i) + if structPointer_IsNil(structp) { + return errRepeatedHasNil + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, err := m.Marshal() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + o.buf = append(o.buf, p.tagcode...) + o.EncodeRawBytes(data) + continue + } + + o.buf = append(o.buf, p.tagcode...) + err := o.enc_len_struct(p.sprop, structp, &state) + if err != nil && !state.shouldContinue(err, nil) { + if err == ErrNil { + return errRepeatedHasNil + } + return err + } + } + return state.err +} + +func size_slice_struct_message(p *Properties, base structPointer) (n int) { + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + n += l * len(p.tagcode) + for i := 0; i < l; i++ { + structp := s.Index(i) + if structPointer_IsNil(structp) { + return // return the size up to this point + } + + // Can the object marshal itself? + if p.isMarshaler { + m := structPointer_Interface(structp, p.stype).(Marshaler) + data, _ := m.Marshal() + n += len(p.tagcode) + n += sizeRawBytes(data) + continue + } + + n0 := size_struct(p.sprop, structp) + n1 := sizeVarint(uint64(n0)) // size of encoded length + n += n0 + n1 + } + return +} + +// Encode a slice of group structs ([]*struct). +func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error { + var state errorState + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + for i := 0; i < l; i++ { + b := s.Index(i) + if structPointer_IsNil(b) { + return errRepeatedHasNil + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireStartGroup)) + + err := o.enc_struct(p.sprop, b) + + if err != nil && !state.shouldContinue(err, nil) { + if err == ErrNil { + return errRepeatedHasNil + } + return err + } + + o.EncodeVarint(uint64((p.Tag << 3) | WireEndGroup)) + } + return state.err +} + +func size_slice_struct_group(p *Properties, base structPointer) (n int) { + s := structPointer_StructPointerSlice(base, p.field) + l := s.Len() + + n += l * sizeVarint(uint64((p.Tag<<3)|WireStartGroup)) + n += l * sizeVarint(uint64((p.Tag<<3)|WireEndGroup)) + for i := 0; i < l; i++ { + b := s.Index(i) + if structPointer_IsNil(b) { + return // return size up to this point + } + + n += size_struct(p.sprop, b) + } + return +} + +// Encode an extension map. +func (o *Buffer) enc_map(p *Properties, base structPointer) error { + v := *structPointer_ExtMap(base, p.field) + if err := encodeExtensionMap(v); err != nil { + return err + } + // Fast-path for common cases: zero or one extensions. + if len(v) <= 1 { + for _, e := range v { + o.buf = append(o.buf, e.enc...) + } + return nil + } + + // Sort keys to provide a deterministic encoding. + keys := make([]int, 0, len(v)) + for k := range v { + keys = append(keys, int(k)) + } + sort.Ints(keys) + + for _, k := range keys { + o.buf = append(o.buf, v[int32(k)].enc...) + } + return nil +} + +func size_map(p *Properties, base structPointer) int { + v := *structPointer_ExtMap(base, p.field) + return sizeExtensionMap(v) +} + +// Encode a map field. +func (o *Buffer) enc_new_map(p *Properties, base structPointer) error { + var state errorState // XXX: or do we need to plumb this through? + + /* + A map defined as + map map_field = N; + is encoded in the same way as + message MapFieldEntry { + key_type key = 1; + value_type value = 2; + } + repeated MapFieldEntry map_field = N; + */ + + v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V + if v.Len() == 0 { + return nil + } + + keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype) + + enc := func() error { + if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil { + return err + } + if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil { + return err + } + return nil + } + + keys := v.MapKeys() + sort.Sort(mapKeys(keys)) + for _, key := range keys { + val := v.MapIndex(key) + + // The only illegal map entry values are nil message pointers. + if val.Kind() == reflect.Ptr && val.IsNil() { + return errors.New("proto: map has nil element") + } + + keycopy.Set(key) + valcopy.Set(val) + + o.buf = append(o.buf, p.tagcode...) + if err := o.enc_len_thing(enc, &state); err != nil { + return err + } + } + return nil +} + +func size_new_map(p *Properties, base structPointer) int { + v := structPointer_NewAt(base, p.field, p.mtype).Elem() // map[K]V + + keycopy, valcopy, keybase, valbase := mapEncodeScratch(p.mtype) + + n := 0 + for _, key := range v.MapKeys() { + val := v.MapIndex(key) + keycopy.Set(key) + valcopy.Set(val) + + // Tag codes for key and val are the responsibility of the sub-sizer. + keysize := p.mkeyprop.size(p.mkeyprop, keybase) + valsize := p.mvalprop.size(p.mvalprop, valbase) + entry := keysize + valsize + // Add on tag code and length of map entry itself. + n += len(p.tagcode) + sizeVarint(uint64(entry)) + entry + } + return n +} + +// mapEncodeScratch returns a new reflect.Value matching the map's value type, +// and a structPointer suitable for passing to an encoder or sizer. +func mapEncodeScratch(mapType reflect.Type) (keycopy, valcopy reflect.Value, keybase, valbase structPointer) { + // Prepare addressable doubly-indirect placeholders for the key and value types. + // This is needed because the element-type encoders expect **T, but the map iteration produces T. + + keycopy = reflect.New(mapType.Key()).Elem() // addressable K + keyptr := reflect.New(reflect.PtrTo(keycopy.Type())).Elem() // addressable *K + keyptr.Set(keycopy.Addr()) // + keybase = toStructPointer(keyptr.Addr()) // **K + + // Value types are more varied and require special handling. + switch mapType.Elem().Kind() { + case reflect.Slice: + // []byte + var dummy []byte + valcopy = reflect.ValueOf(&dummy).Elem() // addressable []byte + valbase = toStructPointer(valcopy.Addr()) + case reflect.Ptr: + // message; the generated field type is map[K]*Msg (so V is *Msg), + // so we only need one level of indirection. + valcopy = reflect.New(mapType.Elem()).Elem() // addressable V + valbase = toStructPointer(valcopy.Addr()) + default: + // everything else + valcopy = reflect.New(mapType.Elem()).Elem() // addressable V + valptr := reflect.New(reflect.PtrTo(valcopy.Type())).Elem() // addressable *V + valptr.Set(valcopy.Addr()) // + valbase = toStructPointer(valptr.Addr()) // **V + } + return +} + +// Encode a struct. +func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error { + var state errorState + // Encode fields in tag order so that decoders may use optimizations + // that depend on the ordering. + // https://developers.google.com/protocol-buffers/docs/encoding#order + for _, i := range prop.order { + p := prop.Prop[i] + if p.enc != nil { + err := p.enc(o, p, base) + if err != nil { + if err == ErrNil { + if p.Required && state.err == nil { + state.err = &RequiredNotSetError{p.Name} + } + } else if err == errRepeatedHasNil { + // Give more context to nil values in repeated fields. + return errors.New("repeated field " + p.OrigName + " has nil element") + } else if !state.shouldContinue(err, p) { + return err + } + } + } + } + + // Do oneof fields. + if prop.oneofMarshaler != nil { + m := structPointer_Interface(base, prop.stype).(Message) + if err := prop.oneofMarshaler(m, o); err != nil { + return err + } + } + + // Add unrecognized fields at the end. + if prop.unrecField.IsValid() { + v := *structPointer_Bytes(base, prop.unrecField) + if len(v) > 0 { + o.buf = append(o.buf, v...) + } + } + + return state.err +} + +func size_struct(prop *StructProperties, base structPointer) (n int) { + for _, i := range prop.order { + p := prop.Prop[i] + if p.size != nil { + n += p.size(p, base) + } + } + + // Add unrecognized fields at the end. + if prop.unrecField.IsValid() { + v := *structPointer_Bytes(base, prop.unrecField) + n += len(v) + } + + // Factor in any oneof fields. + // TODO: This could be faster and use less reflection. + if prop.oneofMarshaler != nil { + sv := reflect.ValueOf(structPointer_Interface(base, prop.stype)).Elem() + for i := 0; i < prop.stype.NumField(); i++ { + fv := sv.Field(i) + if fv.Kind() != reflect.Interface || fv.IsNil() { + continue + } + if prop.stype.Field(i).Tag.Get("protobuf_oneof") == "" { + continue + } + spv := fv.Elem() // interface -> *T + sv := spv.Elem() // *T -> T + sf := sv.Type().Field(0) // StructField inside T + var prop Properties + prop.Init(sf.Type, "whatever", sf.Tag.Get("protobuf"), &sf) + n += prop.size(&prop, toStructPointer(spv)) + } + } + + return +} + +var zeroes [20]byte // longer than any conceivable sizeVarint + +// Encode a struct, preceded by its encoded length (as a varint). +func (o *Buffer) enc_len_struct(prop *StructProperties, base structPointer, state *errorState) error { + return o.enc_len_thing(func() error { return o.enc_struct(prop, base) }, state) +} + +// Encode something, preceded by its encoded length (as a varint). +func (o *Buffer) enc_len_thing(enc func() error, state *errorState) error { + iLen := len(o.buf) + o.buf = append(o.buf, 0, 0, 0, 0) // reserve four bytes for length + iMsg := len(o.buf) + err := enc() + if err != nil && !state.shouldContinue(err, nil) { + return err + } + lMsg := len(o.buf) - iMsg + lLen := sizeVarint(uint64(lMsg)) + switch x := lLen - (iMsg - iLen); { + case x > 0: // actual length is x bytes larger than the space we reserved + // Move msg x bytes right. + o.buf = append(o.buf, zeroes[:x]...) + copy(o.buf[iMsg+x:], o.buf[iMsg:iMsg+lMsg]) + case x < 0: // actual length is x bytes smaller than the space we reserved + // Move msg x bytes left. + copy(o.buf[iMsg+x:], o.buf[iMsg:iMsg+lMsg]) + o.buf = o.buf[:len(o.buf)+x] // x is negative + } + // Encode the length in the reserved space. + o.buf = o.buf[:iLen] + o.EncodeVarint(uint64(lMsg)) + o.buf = o.buf[:len(o.buf)+lMsg] + return state.err +} + +// errorState maintains the first error that occurs and updates that error +// with additional context. +type errorState struct { + err error +} + +// shouldContinue reports whether encoding should continue upon encountering the +// given error. If the error is RequiredNotSetError, shouldContinue returns true +// and, if this is the first appearance of that error, remembers it for future +// reporting. +// +// If prop is not nil, it may update any error with additional context about the +// field with the error. +func (s *errorState) shouldContinue(err error, prop *Properties) bool { + // Ignore unset required fields. + reqNotSet, ok := err.(*RequiredNotSetError) + if !ok { + return false + } + if s.err == nil { + if prop != nil { + err = &RequiredNotSetError{prop.Name + "." + reqNotSet.field} + } + s.err = err + } + return true +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/equal.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/equal.go new file mode 100644 index 0000000..5475c3d --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/equal.go @@ -0,0 +1,267 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Protocol buffer comparison. +// TODO: MessageSet. + +package proto + +import ( + "bytes" + "log" + "reflect" + "strings" +) + +/* +Equal returns true iff protocol buffers a and b are equal. +The arguments must both be pointers to protocol buffer structs. + +Equality is defined in this way: + - Two messages are equal iff they are the same type, + corresponding fields are equal, unknown field sets + are equal, and extensions sets are equal. + - Two set scalar fields are equal iff their values are equal. + If the fields are of a floating-point type, remember that + NaN != x for all x, including NaN. + - Two repeated fields are equal iff their lengths are the same, + and their corresponding elements are equal (a "bytes" field, + although represented by []byte, is not a repeated field) + - Two unset fields are equal. + - Two unknown field sets are equal if their current + encoded state is equal. + - Two extension sets are equal iff they have corresponding + elements that are pairwise equal. + - Every other combination of things are not equal. + +The return value is undefined if a and b are not protocol buffers. +*/ +func Equal(a, b Message) bool { + if a == nil || b == nil { + return a == b + } + v1, v2 := reflect.ValueOf(a), reflect.ValueOf(b) + if v1.Type() != v2.Type() { + return false + } + if v1.Kind() == reflect.Ptr { + if v1.IsNil() { + return v2.IsNil() + } + if v2.IsNil() { + return false + } + v1, v2 = v1.Elem(), v2.Elem() + } + if v1.Kind() != reflect.Struct { + return false + } + return equalStruct(v1, v2) +} + +// v1 and v2 are known to have the same type. +func equalStruct(v1, v2 reflect.Value) bool { + for i := 0; i < v1.NumField(); i++ { + f := v1.Type().Field(i) + if strings.HasPrefix(f.Name, "XXX_") { + continue + } + f1, f2 := v1.Field(i), v2.Field(i) + if f.Type.Kind() == reflect.Ptr { + if n1, n2 := f1.IsNil(), f2.IsNil(); n1 && n2 { + // both unset + continue + } else if n1 != n2 { + // set/unset mismatch + return false + } + b1, ok := f1.Interface().(raw) + if ok { + b2 := f2.Interface().(raw) + // RawMessage + if !bytes.Equal(b1.Bytes(), b2.Bytes()) { + return false + } + continue + } + f1, f2 = f1.Elem(), f2.Elem() + } + if !equalAny(f1, f2) { + return false + } + } + + if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() { + em2 := v2.FieldByName("XXX_extensions") + if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) { + return false + } + } + + uf := v1.FieldByName("XXX_unrecognized") + if !uf.IsValid() { + return true + } + + u1 := uf.Bytes() + u2 := v2.FieldByName("XXX_unrecognized").Bytes() + if !bytes.Equal(u1, u2) { + return false + } + + return true +} + +// v1 and v2 are known to have the same type. +func equalAny(v1, v2 reflect.Value) bool { + if v1.Type() == protoMessageType { + m1, _ := v1.Interface().(Message) + m2, _ := v2.Interface().(Message) + return Equal(m1, m2) + } + switch v1.Kind() { + case reflect.Bool: + return v1.Bool() == v2.Bool() + case reflect.Float32, reflect.Float64: + return v1.Float() == v2.Float() + case reflect.Int32, reflect.Int64: + return v1.Int() == v2.Int() + case reflect.Interface: + // Probably a oneof field; compare the inner values. + n1, n2 := v1.IsNil(), v2.IsNil() + if n1 || n2 { + return n1 == n2 + } + e1, e2 := v1.Elem(), v2.Elem() + if e1.Type() != e2.Type() { + return false + } + return equalAny(e1, e2) + case reflect.Map: + if v1.Len() != v2.Len() { + return false + } + for _, key := range v1.MapKeys() { + val2 := v2.MapIndex(key) + if !val2.IsValid() { + // This key was not found in the second map. + return false + } + if !equalAny(v1.MapIndex(key), val2) { + return false + } + } + return true + case reflect.Ptr: + return equalAny(v1.Elem(), v2.Elem()) + case reflect.Slice: + if v1.Type().Elem().Kind() == reflect.Uint8 { + // short circuit: []byte + if v1.IsNil() != v2.IsNil() { + return false + } + return bytes.Equal(v1.Interface().([]byte), v2.Interface().([]byte)) + } + + if v1.Len() != v2.Len() { + return false + } + for i := 0; i < v1.Len(); i++ { + if !equalAny(v1.Index(i), v2.Index(i)) { + return false + } + } + return true + case reflect.String: + return v1.Interface().(string) == v2.Interface().(string) + case reflect.Struct: + return equalStruct(v1, v2) + case reflect.Uint32, reflect.Uint64: + return v1.Uint() == v2.Uint() + } + + // unknown type, so not a protocol buffer + log.Printf("proto: don't know how to compare %v", v1) + return false +} + +// base is the struct type that the extensions are based on. +// em1 and em2 are extension maps. +func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool { + if len(em1) != len(em2) { + return false + } + + for extNum, e1 := range em1 { + e2, ok := em2[extNum] + if !ok { + return false + } + + m1, m2 := e1.value, e2.value + + if m1 != nil && m2 != nil { + // Both are unencoded. + if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) { + return false + } + continue + } + + // At least one is encoded. To do a semantically correct comparison + // we need to unmarshal them first. + var desc *ExtensionDesc + if m := extensionMaps[base]; m != nil { + desc = m[extNum] + } + if desc == nil { + log.Printf("proto: don't know how to compare extension %d of %v", extNum, base) + continue + } + var err error + if m1 == nil { + m1, err = decodeExtension(e1.enc, desc) + } + if m2 == nil && err == nil { + m2, err = decodeExtension(e2.enc, desc) + } + if err != nil { + // The encoded form is invalid. + log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err) + return false + } + if !equalAny(reflect.ValueOf(m1), reflect.ValueOf(m2)) { + return false + } + } + + return true +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/equal_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/equal_test.go new file mode 100644 index 0000000..0e0db8a --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/equal_test.go @@ -0,0 +1,209 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2011 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "testing" + + . "github.com/golang/protobuf/proto" + pb "github.com/golang/protobuf/proto/testdata" +) + +// Four identical base messages. +// The init function adds extensions to some of them. +var messageWithoutExtension = &pb.MyMessage{Count: Int32(7)} +var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)} +var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)} +var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)} + +// Two messages with non-message extensions. +var messageWithInt32Extension1 = &pb.MyMessage{Count: Int32(8)} +var messageWithInt32Extension2 = &pb.MyMessage{Count: Int32(8)} + +func init() { + ext1 := &pb.Ext{Data: String("Kirk")} + ext2 := &pb.Ext{Data: String("Picard")} + + // messageWithExtension1a has ext1, but never marshals it. + if err := SetExtension(messageWithExtension1a, pb.E_Ext_More, ext1); err != nil { + panic("SetExtension on 1a failed: " + err.Error()) + } + + // messageWithExtension1b is the unmarshaled form of messageWithExtension1a. + if err := SetExtension(messageWithExtension1b, pb.E_Ext_More, ext1); err != nil { + panic("SetExtension on 1b failed: " + err.Error()) + } + buf, err := Marshal(messageWithExtension1b) + if err != nil { + panic("Marshal of 1b failed: " + err.Error()) + } + messageWithExtension1b.Reset() + if err := Unmarshal(buf, messageWithExtension1b); err != nil { + panic("Unmarshal of 1b failed: " + err.Error()) + } + + // messageWithExtension2 has ext2. + if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil { + panic("SetExtension on 2 failed: " + err.Error()) + } + + if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(23)); err != nil { + panic("SetExtension on Int32-1 failed: " + err.Error()) + } + if err := SetExtension(messageWithInt32Extension1, pb.E_Ext_Number, Int32(24)); err != nil { + panic("SetExtension on Int32-2 failed: " + err.Error()) + } +} + +var EqualTests = []struct { + desc string + a, b Message + exp bool +}{ + {"different types", &pb.GoEnum{}, &pb.GoTestField{}, false}, + {"equal empty", &pb.GoEnum{}, &pb.GoEnum{}, true}, + {"nil vs nil", nil, nil, true}, + {"typed nil vs typed nil", (*pb.GoEnum)(nil), (*pb.GoEnum)(nil), true}, + {"typed nil vs empty", (*pb.GoEnum)(nil), &pb.GoEnum{}, false}, + {"different typed nil", (*pb.GoEnum)(nil), (*pb.GoTestField)(nil), false}, + + {"one set field, one unset field", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{}, false}, + {"one set field zero, one unset field", &pb.GoTest{Param: Int32(0)}, &pb.GoTest{}, false}, + {"different set fields", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("bar")}, false}, + {"equal set", &pb.GoTestField{Label: String("foo")}, &pb.GoTestField{Label: String("foo")}, true}, + + {"repeated, one set", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{}, false}, + {"repeated, different length", &pb.GoTest{F_Int32Repeated: []int32{2, 3}}, &pb.GoTest{F_Int32Repeated: []int32{2}}, false}, + {"repeated, different value", &pb.GoTest{F_Int32Repeated: []int32{2}}, &pb.GoTest{F_Int32Repeated: []int32{3}}, false}, + {"repeated, equal", &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, &pb.GoTest{F_Int32Repeated: []int32{2, 4}}, true}, + {"repeated, nil equal nil", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: nil}, true}, + {"repeated, nil equal empty", &pb.GoTest{F_Int32Repeated: nil}, &pb.GoTest{F_Int32Repeated: []int32{}}, true}, + {"repeated, empty equal nil", &pb.GoTest{F_Int32Repeated: []int32{}}, &pb.GoTest{F_Int32Repeated: nil}, true}, + + { + "nested, different", + &pb.GoTest{RequiredField: &pb.GoTestField{Label: String("foo")}}, + &pb.GoTest{RequiredField: &pb.GoTestField{Label: String("bar")}}, + false, + }, + { + "nested, equal", + &pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}}, + &pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}}, + true, + }, + + {"bytes", &pb.OtherMessage{Value: []byte("foo")}, &pb.OtherMessage{Value: []byte("foo")}, true}, + {"bytes, empty", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: []byte{}}, true}, + {"bytes, empty vs nil", &pb.OtherMessage{Value: []byte{}}, &pb.OtherMessage{Value: nil}, false}, + { + "repeated bytes", + &pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}}, + &pb.MyMessage{RepBytes: [][]byte{[]byte("sham"), []byte("wow")}}, + true, + }, + + {"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false}, + {"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true}, + {"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false}, + + {"int32 extension vs. itself", messageWithInt32Extension1, messageWithInt32Extension1, true}, + {"int32 extension vs. a different int32", messageWithInt32Extension1, messageWithInt32Extension2, false}, + + { + "message with group", + &pb.MyMessage{ + Count: Int32(1), + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: Int32(5), + }, + }, + &pb.MyMessage{ + Count: Int32(1), + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: Int32(5), + }, + }, + true, + }, + + { + "map same", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + true, + }, + { + "map different entry", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{2: "Rob"}}, + false, + }, + { + "map different key only", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{2: "Ken"}}, + false, + }, + { + "map different value only", + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Ken"}}, + &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob"}}, + false, + }, + { + "oneof same", + &pb.Communique{Union: &pb.Communique_Number{41}}, + &pb.Communique{Union: &pb.Communique_Number{41}}, + true, + }, + { + "oneof one nil", + &pb.Communique{Union: &pb.Communique_Number{41}}, + &pb.Communique{}, + false, + }, + { + "oneof different", + &pb.Communique{Union: &pb.Communique_Number{41}}, + &pb.Communique{Union: &pb.Communique_Name{"Bobby Tables"}}, + false, + }, +} + +func TestEqual(t *testing.T) { + for _, tc := range EqualTests { + if res := Equal(tc.a, tc.b); res != tc.exp { + t.Errorf("%v: Equal(%v, %v) = %v, want %v", tc.desc, tc.a, tc.b, res, tc.exp) + } + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/extensions.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/extensions.go new file mode 100644 index 0000000..e591cce --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/extensions.go @@ -0,0 +1,400 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Types and routines for supporting protocol buffer extensions. + */ + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "sync" +) + +// ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. +var ErrMissingExtension = errors.New("proto: missing extension") + +// ExtensionRange represents a range of message extensions for a protocol buffer. +// Used in code generated by the protocol compiler. +type ExtensionRange struct { + Start, End int32 // both inclusive +} + +// extendableProto is an interface implemented by any protocol buffer that may be extended. +type extendableProto interface { + Message + ExtensionRangeArray() []ExtensionRange + ExtensionMap() map[int32]Extension +} + +var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() + +// ExtensionDesc represents an extension specification. +// Used in generated code from the protocol compiler. +type ExtensionDesc struct { + ExtendedType Message // nil pointer to the type that is being extended + ExtensionType interface{} // nil pointer to the extension type + Field int32 // field number + Name string // fully-qualified name of extension, for text formatting + Tag string // protobuf tag style +} + +func (ed *ExtensionDesc) repeated() bool { + t := reflect.TypeOf(ed.ExtensionType) + return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 +} + +// Extension represents an extension in a message. +type Extension struct { + // When an extension is stored in a message using SetExtension + // only desc and value are set. When the message is marshaled + // enc will be set to the encoded form of the message. + // + // When a message is unmarshaled and contains extensions, each + // extension will have only enc set. When such an extension is + // accessed using GetExtension (or GetExtensions) desc and value + // will be set. + desc *ExtensionDesc + value interface{} + enc []byte +} + +// SetRawExtension is for testing only. +func SetRawExtension(base extendableProto, id int32, b []byte) { + base.ExtensionMap()[id] = Extension{enc: b} +} + +// isExtensionField returns true iff the given field number is in an extension range. +func isExtensionField(pb extendableProto, field int32) bool { + for _, er := range pb.ExtensionRangeArray() { + if er.Start <= field && field <= er.End { + return true + } + } + return false +} + +// checkExtensionTypes checks that the given extension is valid for pb. +func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { + // Check the extended type. + if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b { + return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) + } + // Check the range. + if !isExtensionField(pb, extension.Field) { + return errors.New("proto: bad extension number; not in declared ranges") + } + return nil +} + +// extPropKey is sufficient to uniquely identify an extension. +type extPropKey struct { + base reflect.Type + field int32 +} + +var extProp = struct { + sync.RWMutex + m map[extPropKey]*Properties +}{ + m: make(map[extPropKey]*Properties), +} + +func extensionProperties(ed *ExtensionDesc) *Properties { + key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} + + extProp.RLock() + if prop, ok := extProp.m[key]; ok { + extProp.RUnlock() + return prop + } + extProp.RUnlock() + + extProp.Lock() + defer extProp.Unlock() + // Check again. + if prop, ok := extProp.m[key]; ok { + return prop + } + + prop := new(Properties) + prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) + extProp.m[key] = prop + return prop +} + +// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m. +func encodeExtensionMap(m map[int32]Extension) error { + for k, e := range m { + if e.value == nil || e.desc == nil { + // Extension is only in its encoded form. + continue + } + + // We don't skip extensions that have an encoded form set, + // because the extension value may have been mutated after + // the last time this function was called. + + et := reflect.TypeOf(e.desc.ExtensionType) + props := extensionProperties(e.desc) + + p := NewBuffer(nil) + // If e.value has type T, the encoder expects a *struct{ X T }. + // Pass a *T with a zero field and hope it all works out. + x := reflect.New(et) + x.Elem().Set(reflect.ValueOf(e.value)) + if err := props.enc(p, props, toStructPointer(x)); err != nil { + return err + } + e.enc = p.buf + m[k] = e + } + return nil +} + +func sizeExtensionMap(m map[int32]Extension) (n int) { + for _, e := range m { + if e.value == nil || e.desc == nil { + // Extension is only in its encoded form. + n += len(e.enc) + continue + } + + // We don't skip extensions that have an encoded form set, + // because the extension value may have been mutated after + // the last time this function was called. + + et := reflect.TypeOf(e.desc.ExtensionType) + props := extensionProperties(e.desc) + + // If e.value has type T, the encoder expects a *struct{ X T }. + // Pass a *T with a zero field and hope it all works out. + x := reflect.New(et) + x.Elem().Set(reflect.ValueOf(e.value)) + n += props.size(props, toStructPointer(x)) + } + return +} + +// HasExtension returns whether the given extension is present in pb. +func HasExtension(pb extendableProto, extension *ExtensionDesc) bool { + // TODO: Check types, field numbers, etc.? + _, ok := pb.ExtensionMap()[extension.Field] + return ok +} + +// ClearExtension removes the given extension from pb. +func ClearExtension(pb extendableProto, extension *ExtensionDesc) { + // TODO: Check types, field numbers, etc.? + delete(pb.ExtensionMap(), extension.Field) +} + +// GetExtension parses and returns the given extension of pb. +// If the extension is not present and has no default value it returns ErrMissingExtension. +func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) { + if err := checkExtensionTypes(pb, extension); err != nil { + return nil, err + } + + emap := pb.ExtensionMap() + e, ok := emap[extension.Field] + if !ok { + // defaultExtensionValue returns the default value or + // ErrMissingExtension if there is no default. + return defaultExtensionValue(extension) + } + + if e.value != nil { + // Already decoded. Check the descriptor, though. + if e.desc != extension { + // This shouldn't happen. If it does, it means that + // GetExtension was called twice with two different + // descriptors with the same field number. + return nil, errors.New("proto: descriptor conflict") + } + return e.value, nil + } + + v, err := decodeExtension(e.enc, extension) + if err != nil { + return nil, err + } + + // Remember the decoded version and drop the encoded version. + // That way it is safe to mutate what we return. + e.value = v + e.desc = extension + e.enc = nil + emap[extension.Field] = e + return e.value, nil +} + +// defaultExtensionValue returns the default value for extension. +// If no default for an extension is defined ErrMissingExtension is returned. +func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { + t := reflect.TypeOf(extension.ExtensionType) + props := extensionProperties(extension) + + sf, _, err := fieldDefault(t, props) + if err != nil { + return nil, err + } + + if sf == nil || sf.value == nil { + // There is no default value. + return nil, ErrMissingExtension + } + + if t.Kind() != reflect.Ptr { + // We do not need to return a Ptr, we can directly return sf.value. + return sf.value, nil + } + + // We need to return an interface{} that is a pointer to sf.value. + value := reflect.New(t).Elem() + value.Set(reflect.New(value.Type().Elem())) + if sf.kind == reflect.Int32 { + // We may have an int32 or an enum, but the underlying data is int32. + // Since we can't set an int32 into a non int32 reflect.value directly + // set it as a int32. + value.Elem().SetInt(int64(sf.value.(int32))) + } else { + value.Elem().Set(reflect.ValueOf(sf.value)) + } + return value.Interface(), nil +} + +// decodeExtension decodes an extension encoded in b. +func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { + o := NewBuffer(b) + + t := reflect.TypeOf(extension.ExtensionType) + rep := extension.repeated() + + props := extensionProperties(extension) + + // t is a pointer to a struct, pointer to basic type or a slice. + // Allocate a "field" to store the pointer/slice itself; the + // pointer/slice will be stored here. We pass + // the address of this field to props.dec. + // This passes a zero field and a *t and lets props.dec + // interpret it as a *struct{ x t }. + value := reflect.New(t).Elem() + + for { + // Discard wire type and field number varint. It isn't needed. + if _, err := o.DecodeVarint(); err != nil { + return nil, err + } + + if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil { + return nil, err + } + + if !rep || o.index >= len(o.buf) { + break + } + } + return value.Interface(), nil +} + +// GetExtensions returns a slice of the extensions present in pb that are also listed in es. +// The returned slice has the same length as es; missing extensions will appear as nil elements. +func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { + epb, ok := pb.(extendableProto) + if !ok { + err = errors.New("proto: not an extendable proto") + return + } + extensions = make([]interface{}, len(es)) + for i, e := range es { + extensions[i], err = GetExtension(epb, e) + if err == ErrMissingExtension { + err = nil + } + if err != nil { + return + } + } + return +} + +// SetExtension sets the specified extension of pb to the specified value. +func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error { + if err := checkExtensionTypes(pb, extension); err != nil { + return err + } + typ := reflect.TypeOf(extension.ExtensionType) + if typ != reflect.TypeOf(value) { + return errors.New("proto: bad extension value type") + } + // nil extension values need to be caught early, because the + // encoder can't distinguish an ErrNil due to a nil extension + // from an ErrNil due to a missing field. Extensions are + // always optional, so the encoder would just swallow the error + // and drop all the extensions from the encoded message. + if reflect.ValueOf(value).IsNil() { + return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) + } + + pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value} + return nil +} + +// A global registry of extensions. +// The generated code will register the generated descriptors by calling RegisterExtension. + +var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) + +// RegisterExtension is called from the generated code. +func RegisterExtension(desc *ExtensionDesc) { + st := reflect.TypeOf(desc.ExtendedType).Elem() + m := extensionMaps[st] + if m == nil { + m = make(map[int32]*ExtensionDesc) + extensionMaps[st] = m + } + if _, ok := m[desc.Field]; ok { + panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) + } + m[desc.Field] = desc +} + +// RegisteredExtensions returns a map of the registered extensions of a +// protocol buffer struct, indexed by the extension number. +// The argument pb should be a nil pointer to the struct type. +func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { + return extensionMaps[reflect.TypeOf(pb).Elem()] +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/extensions_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/extensions_test.go new file mode 100644 index 0000000..7255276 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/extensions_test.go @@ -0,0 +1,292 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2014 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "fmt" + "reflect" + "testing" + + "github.com/golang/protobuf/proto" + pb "github.com/golang/protobuf/proto/testdata" +) + +func TestGetExtensionsWithMissingExtensions(t *testing.T) { + msg := &pb.MyMessage{} + ext1 := &pb.Ext{} + if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { + t.Fatalf("Could not set ext1: %s", ext1) + } + exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{ + pb.E_Ext_More, + pb.E_Ext_Text, + }) + if err != nil { + t.Fatalf("GetExtensions() failed: %s", err) + } + if exts[0] != ext1 { + t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0]) + } + if exts[1] != nil { + t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1]) + } +} + +func TestGetExtensionStability(t *testing.T) { + check := func(m *pb.MyMessage) bool { + ext1, err := proto.GetExtension(m, pb.E_Ext_More) + if err != nil { + t.Fatalf("GetExtension() failed: %s", err) + } + ext2, err := proto.GetExtension(m, pb.E_Ext_More) + if err != nil { + t.Fatalf("GetExtension() failed: %s", err) + } + return ext1 == ext2 + } + msg := &pb.MyMessage{Count: proto.Int32(4)} + ext0 := &pb.Ext{} + if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil { + t.Fatalf("Could not set ext1: %s", ext0) + } + if !check(msg) { + t.Errorf("GetExtension() not stable before marshaling") + } + bb, err := proto.Marshal(msg) + if err != nil { + t.Fatalf("Marshal() failed: %s", err) + } + msg1 := &pb.MyMessage{} + err = proto.Unmarshal(bb, msg1) + if err != nil { + t.Fatalf("Unmarshal() failed: %s", err) + } + if !check(msg1) { + t.Errorf("GetExtension() not stable after unmarshaling") + } +} + +func TestGetExtensionDefaults(t *testing.T) { + var setFloat64 float64 = 1 + var setFloat32 float32 = 2 + var setInt32 int32 = 3 + var setInt64 int64 = 4 + var setUint32 uint32 = 5 + var setUint64 uint64 = 6 + var setBool = true + var setBool2 = false + var setString = "Goodnight string" + var setBytes = []byte("Goodnight bytes") + var setEnum = pb.DefaultsMessage_TWO + + type testcase struct { + ext *proto.ExtensionDesc // Extension we are testing. + want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail). + def interface{} // Expected value of extension after ClearExtension(). + } + tests := []testcase{ + {pb.E_NoDefaultDouble, setFloat64, nil}, + {pb.E_NoDefaultFloat, setFloat32, nil}, + {pb.E_NoDefaultInt32, setInt32, nil}, + {pb.E_NoDefaultInt64, setInt64, nil}, + {pb.E_NoDefaultUint32, setUint32, nil}, + {pb.E_NoDefaultUint64, setUint64, nil}, + {pb.E_NoDefaultSint32, setInt32, nil}, + {pb.E_NoDefaultSint64, setInt64, nil}, + {pb.E_NoDefaultFixed32, setUint32, nil}, + {pb.E_NoDefaultFixed64, setUint64, nil}, + {pb.E_NoDefaultSfixed32, setInt32, nil}, + {pb.E_NoDefaultSfixed64, setInt64, nil}, + {pb.E_NoDefaultBool, setBool, nil}, + {pb.E_NoDefaultBool, setBool2, nil}, + {pb.E_NoDefaultString, setString, nil}, + {pb.E_NoDefaultBytes, setBytes, nil}, + {pb.E_NoDefaultEnum, setEnum, nil}, + {pb.E_DefaultDouble, setFloat64, float64(3.1415)}, + {pb.E_DefaultFloat, setFloat32, float32(3.14)}, + {pb.E_DefaultInt32, setInt32, int32(42)}, + {pb.E_DefaultInt64, setInt64, int64(43)}, + {pb.E_DefaultUint32, setUint32, uint32(44)}, + {pb.E_DefaultUint64, setUint64, uint64(45)}, + {pb.E_DefaultSint32, setInt32, int32(46)}, + {pb.E_DefaultSint64, setInt64, int64(47)}, + {pb.E_DefaultFixed32, setUint32, uint32(48)}, + {pb.E_DefaultFixed64, setUint64, uint64(49)}, + {pb.E_DefaultSfixed32, setInt32, int32(50)}, + {pb.E_DefaultSfixed64, setInt64, int64(51)}, + {pb.E_DefaultBool, setBool, true}, + {pb.E_DefaultBool, setBool2, true}, + {pb.E_DefaultString, setString, "Hello, string"}, + {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, + {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, + } + + checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error { + val, err := proto.GetExtension(msg, test.ext) + if err != nil { + if valWant != nil { + return fmt.Errorf("GetExtension(): %s", err) + } + if want := proto.ErrMissingExtension; err != want { + return fmt.Errorf("Unexpected error: got %v, want %v", err, want) + } + return nil + } + + // All proto2 extension values are either a pointer to a value or a slice of values. + ty := reflect.TypeOf(val) + tyWant := reflect.TypeOf(test.ext.ExtensionType) + if got, want := ty, tyWant; got != want { + return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) + } + tye := ty.Elem() + tyeWant := tyWant.Elem() + if got, want := tye, tyeWant; got != want { + return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) + } + + // Check the name of the type of the value. + // If it is an enum it will be type int32 with the name of the enum. + if got, want := tye.Name(), tye.Name(); got != want { + return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) + } + + // Check that value is what we expect. + // If we have a pointer in val, get the value it points to. + valExp := val + if ty.Kind() == reflect.Ptr { + valExp = reflect.ValueOf(val).Elem().Interface() + } + if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { + return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) + } + + return nil + } + + setTo := func(test testcase) interface{} { + setTo := reflect.ValueOf(test.want) + if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr { + setTo = reflect.New(typ).Elem() + setTo.Set(reflect.New(setTo.Type().Elem())) + setTo.Elem().Set(reflect.ValueOf(test.want)) + } + return setTo.Interface() + } + + for _, test := range tests { + msg := &pb.DefaultsMessage{} + name := test.ext.Name + + // Check the initial value. + if err := checkVal(test, msg, test.def); err != nil { + t.Errorf("%s: %v", name, err) + } + + // Set the per-type value and check value. + name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want) + if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { + t.Errorf("%s: SetExtension(): %v", name, err) + continue + } + if err := checkVal(test, msg, test.want); err != nil { + t.Errorf("%s: %v", name, err) + continue + } + + // Set and check the value. + name += " (cleared)" + proto.ClearExtension(msg, test.ext) + if err := checkVal(test, msg, test.def); err != nil { + t.Errorf("%s: %v", name, err) + } + } +} + +func TestExtensionsRoundTrip(t *testing.T) { + msg := &pb.MyMessage{} + ext1 := &pb.Ext{ + Data: proto.String("hi"), + } + ext2 := &pb.Ext{ + Data: proto.String("there"), + } + exists := proto.HasExtension(msg, pb.E_Ext_More) + if exists { + t.Error("Extension More present unexpectedly") + } + if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { + t.Error(err) + } + if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil { + t.Error(err) + } + e, err := proto.GetExtension(msg, pb.E_Ext_More) + if err != nil { + t.Error(err) + } + x, ok := e.(*pb.Ext) + if !ok { + t.Errorf("e has type %T, expected testdata.Ext", e) + } else if *x.Data != "there" { + t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) + } + proto.ClearExtension(msg, pb.E_Ext_More) + if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension { + t.Errorf("got %v, expected ErrMissingExtension", e) + } + if _, err := proto.GetExtension(msg, pb.E_X215); err == nil { + t.Error("expected bad extension error, got nil") + } + if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil { + t.Error("expected extension err") + } + if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil { + t.Error("expected some sort of type mismatch error, got nil") + } +} + +func TestNilExtension(t *testing.T) { + msg := &pb.MyMessage{ + Count: proto.Int32(1), + } + if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil { + t.Fatal(err) + } + if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil { + t.Error("expected SetExtension to fail due to a nil extension") + } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want { + t.Errorf("expected error %v, got %v", want, err) + } + // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update + // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal. +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/lib.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/lib.go new file mode 100644 index 0000000..abbf758 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/lib.go @@ -0,0 +1,883 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* +Package proto converts data structures to and from the wire format of +protocol buffers. It works in concert with the Go source code generated +for .proto files by the protocol compiler. + +A summary of the properties of the protocol buffer interface +for a protocol buffer variable v: + + - Names are turned from camel_case to CamelCase for export. + - There are no methods on v to set fields; just treat + them as structure fields. + - There are getters that return a field's value if set, + and return the field's default value if unset. + The getters work even if the receiver is a nil message. + - The zero value for a struct is its correct initialization state. + All desired fields must be set before marshaling. + - A Reset() method will restore a protobuf struct to its zero state. + - Non-repeated fields are pointers to the values; nil means unset. + That is, optional or required field int32 f becomes F *int32. + - Repeated fields are slices. + - Helper functions are available to aid the setting of fields. + msg.Foo = proto.String("hello") // set field + - Constants are defined to hold the default values of all fields that + have them. They have the form Default_StructName_FieldName. + Because the getter methods handle defaulted values, + direct use of these constants should be rare. + - Enums are given type names and maps from names to values. + Enum values are prefixed by the enclosing message's name, or by the + enum's type name if it is a top-level enum. Enum types have a String + method, and a Enum method to assist in message construction. + - Nested messages, groups and enums have type names prefixed with the name of + the surrounding message type. + - Extensions are given descriptor names that start with E_, + followed by an underscore-delimited list of the nested messages + that contain it (if any) followed by the CamelCased name of the + extension field itself. HasExtension, ClearExtension, GetExtension + and SetExtension are functions for manipulating extensions. + - Oneof field sets are given a single field in their message, + with distinguished wrapper types for each possible field value. + - Marshal and Unmarshal are functions to encode and decode the wire format. + +The simplest way to describe this is to see an example. +Given file test.proto, containing + + package example; + + enum FOO { X = 17; } + + message Test { + required string label = 1; + optional int32 type = 2 [default=77]; + repeated int64 reps = 3; + optional group OptionalGroup = 4 { + required string RequiredField = 5; + } + oneof union { + int32 number = 6; + string name = 7; + } + } + +The resulting file, test.pb.go, is: + + package example + + import proto "github.com/golang/protobuf/proto" + import math "math" + + type FOO int32 + const ( + FOO_X FOO = 17 + ) + var FOO_name = map[int32]string{ + 17: "X", + } + var FOO_value = map[string]int32{ + "X": 17, + } + + func (x FOO) Enum() *FOO { + p := new(FOO) + *p = x + return p + } + func (x FOO) String() string { + return proto.EnumName(FOO_name, int32(x)) + } + func (x *FOO) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(FOO_value, data) + if err != nil { + return err + } + *x = FOO(value) + return nil + } + + type Test struct { + Label *string `protobuf:"bytes,1,req,name=label" json:"label,omitempty"` + Type *int32 `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"` + Reps []int64 `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"` + Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup" json:"optionalgroup,omitempty"` + // Types that are valid to be assigned to Union: + // *Test_Number + // *Test_Name + Union isTest_Union `protobuf_oneof:"union"` + XXX_unrecognized []byte `json:"-"` + } + func (m *Test) Reset() { *m = Test{} } + func (m *Test) String() string { return proto.CompactTextString(m) } + func (*Test) ProtoMessage() {} + + type isTest_Union interface { + isTest_Union() + } + + type Test_Number struct { + Number int32 `protobuf:"varint,6,opt,name=number"` + } + type Test_Name struct { + Name string `protobuf:"bytes,7,opt,name=name"` + } + + func (*Test_Number) isTest_Union() {} + func (*Test_Name) isTest_Union() {} + + func (m *Test) GetUnion() isTest_Union { + if m != nil { + return m.Union + } + return nil + } + const Default_Test_Type int32 = 77 + + func (m *Test) GetLabel() string { + if m != nil && m.Label != nil { + return *m.Label + } + return "" + } + + func (m *Test) GetType() int32 { + if m != nil && m.Type != nil { + return *m.Type + } + return Default_Test_Type + } + + func (m *Test) GetOptionalgroup() *Test_OptionalGroup { + if m != nil { + return m.Optionalgroup + } + return nil + } + + type Test_OptionalGroup struct { + RequiredField *string `protobuf:"bytes,5,req" json:"RequiredField,omitempty"` + } + func (m *Test_OptionalGroup) Reset() { *m = Test_OptionalGroup{} } + func (m *Test_OptionalGroup) String() string { return proto.CompactTextString(m) } + + func (m *Test_OptionalGroup) GetRequiredField() string { + if m != nil && m.RequiredField != nil { + return *m.RequiredField + } + return "" + } + + func (m *Test) GetNumber() int32 { + if x, ok := m.GetUnion().(*Test_Number); ok { + return x.Number + } + return 0 + } + + func (m *Test) GetName() string { + if x, ok := m.GetUnion().(*Test_Name); ok { + return x.Name + } + return "" + } + + func init() { + proto.RegisterEnum("example.FOO", FOO_name, FOO_value) + } + +To create and play with a Test object: + +package main + + import ( + "log" + + "github.com/golang/protobuf/proto" + pb "./example.pb" + ) + + func main() { + test := &pb.Test{ + Label: proto.String("hello"), + Type: proto.Int32(17), + Optionalgroup: &pb.Test_OptionalGroup{ + RequiredField: proto.String("good bye"), + }, + Union: &pb.Test_Name{"fred"}, + } + data, err := proto.Marshal(test) + if err != nil { + log.Fatal("marshaling error: ", err) + } + newTest := &pb.Test{} + err = proto.Unmarshal(data, newTest) + if err != nil { + log.Fatal("unmarshaling error: ", err) + } + // Now test and newTest contain the same data. + if test.GetLabel() != newTest.GetLabel() { + log.Fatalf("data mismatch %q != %q", test.GetLabel(), newTest.GetLabel()) + } + // Use a type switch to determine which oneof was set. + switch u := test.Union.(type) { + case *pb.Test_Number: // u.Number contains the number. + case *pb.Test_Name: // u.Name contains the string. + } + // etc. + } +*/ +package proto + +import ( + "encoding/json" + "fmt" + "log" + "reflect" + "sort" + "strconv" + "sync" +) + +// Message is implemented by generated protocol buffer messages. +type Message interface { + Reset() + String() string + ProtoMessage() +} + +// Stats records allocation details about the protocol buffer encoders +// and decoders. Useful for tuning the library itself. +type Stats struct { + Emalloc uint64 // mallocs in encode + Dmalloc uint64 // mallocs in decode + Encode uint64 // number of encodes + Decode uint64 // number of decodes + Chit uint64 // number of cache hits + Cmiss uint64 // number of cache misses + Size uint64 // number of sizes +} + +// Set to true to enable stats collection. +const collectStats = false + +var stats Stats + +// GetStats returns a copy of the global Stats structure. +func GetStats() Stats { return stats } + +// A Buffer is a buffer manager for marshaling and unmarshaling +// protocol buffers. It may be reused between invocations to +// reduce memory usage. It is not necessary to use a Buffer; +// the global functions Marshal and Unmarshal create a +// temporary Buffer and are fine for most applications. +type Buffer struct { + buf []byte // encode/decode byte stream + index int // write point + + // pools of basic types to amortize allocation. + bools []bool + uint32s []uint32 + uint64s []uint64 + + // extra pools, only used with pointer_reflect.go + int32s []int32 + int64s []int64 + float32s []float32 + float64s []float64 +} + +// NewBuffer allocates a new Buffer and initializes its internal data to +// the contents of the argument slice. +func NewBuffer(e []byte) *Buffer { + return &Buffer{buf: e} +} + +// Reset resets the Buffer, ready for marshaling a new protocol buffer. +func (p *Buffer) Reset() { + p.buf = p.buf[0:0] // for reading/writing + p.index = 0 // for reading +} + +// SetBuf replaces the internal buffer with the slice, +// ready for unmarshaling the contents of the slice. +func (p *Buffer) SetBuf(s []byte) { + p.buf = s + p.index = 0 +} + +// Bytes returns the contents of the Buffer. +func (p *Buffer) Bytes() []byte { return p.buf } + +/* + * Helper routines for simplifying the creation of optional fields of basic type. + */ + +// Bool is a helper routine that allocates a new bool value +// to store v and returns a pointer to it. +func Bool(v bool) *bool { + return &v +} + +// Int32 is a helper routine that allocates a new int32 value +// to store v and returns a pointer to it. +func Int32(v int32) *int32 { + return &v +} + +// Int is a helper routine that allocates a new int32 value +// to store v and returns a pointer to it, but unlike Int32 +// its argument value is an int. +func Int(v int) *int32 { + p := new(int32) + *p = int32(v) + return p +} + +// Int64 is a helper routine that allocates a new int64 value +// to store v and returns a pointer to it. +func Int64(v int64) *int64 { + return &v +} + +// Float32 is a helper routine that allocates a new float32 value +// to store v and returns a pointer to it. +func Float32(v float32) *float32 { + return &v +} + +// Float64 is a helper routine that allocates a new float64 value +// to store v and returns a pointer to it. +func Float64(v float64) *float64 { + return &v +} + +// Uint32 is a helper routine that allocates a new uint32 value +// to store v and returns a pointer to it. +func Uint32(v uint32) *uint32 { + return &v +} + +// Uint64 is a helper routine that allocates a new uint64 value +// to store v and returns a pointer to it. +func Uint64(v uint64) *uint64 { + return &v +} + +// String is a helper routine that allocates a new string value +// to store v and returns a pointer to it. +func String(v string) *string { + return &v +} + +// EnumName is a helper function to simplify printing protocol buffer enums +// by name. Given an enum map and a value, it returns a useful string. +func EnumName(m map[int32]string, v int32) string { + s, ok := m[v] + if ok { + return s + } + return strconv.Itoa(int(v)) +} + +// UnmarshalJSONEnum is a helper function to simplify recovering enum int values +// from their JSON-encoded representation. Given a map from the enum's symbolic +// names to its int values, and a byte buffer containing the JSON-encoded +// value, it returns an int32 that can be cast to the enum type by the caller. +// +// The function can deal with both JSON representations, numeric and symbolic. +func UnmarshalJSONEnum(m map[string]int32, data []byte, enumName string) (int32, error) { + if data[0] == '"' { + // New style: enums are strings. + var repr string + if err := json.Unmarshal(data, &repr); err != nil { + return -1, err + } + val, ok := m[repr] + if !ok { + return 0, fmt.Errorf("unrecognized enum %s value %q", enumName, repr) + } + return val, nil + } + // Old style: enums are ints. + var val int32 + if err := json.Unmarshal(data, &val); err != nil { + return 0, fmt.Errorf("cannot unmarshal %#q into enum %s", data, enumName) + } + return val, nil +} + +// DebugPrint dumps the encoded data in b in a debugging format with a header +// including the string s. Used in testing but made available for general debugging. +func (p *Buffer) DebugPrint(s string, b []byte) { + var u uint64 + + obuf := p.buf + index := p.index + p.buf = b + p.index = 0 + depth := 0 + + fmt.Printf("\n--- %s ---\n", s) + +out: + for { + for i := 0; i < depth; i++ { + fmt.Print(" ") + } + + index := p.index + if index == len(p.buf) { + break + } + + op, err := p.DecodeVarint() + if err != nil { + fmt.Printf("%3d: fetching op err %v\n", index, err) + break out + } + tag := op >> 3 + wire := op & 7 + + switch wire { + default: + fmt.Printf("%3d: t=%3d unknown wire=%d\n", + index, tag, wire) + break out + + case WireBytes: + var r []byte + + r, err = p.DecodeRawBytes(false) + if err != nil { + break out + } + fmt.Printf("%3d: t=%3d bytes [%d]", index, tag, len(r)) + if len(r) <= 6 { + for i := 0; i < len(r); i++ { + fmt.Printf(" %.2x", r[i]) + } + } else { + for i := 0; i < 3; i++ { + fmt.Printf(" %.2x", r[i]) + } + fmt.Printf(" ..") + for i := len(r) - 3; i < len(r); i++ { + fmt.Printf(" %.2x", r[i]) + } + } + fmt.Printf("\n") + + case WireFixed32: + u, err = p.DecodeFixed32() + if err != nil { + fmt.Printf("%3d: t=%3d fix32 err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d fix32 %d\n", index, tag, u) + + case WireFixed64: + u, err = p.DecodeFixed64() + if err != nil { + fmt.Printf("%3d: t=%3d fix64 err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d fix64 %d\n", index, tag, u) + + case WireVarint: + u, err = p.DecodeVarint() + if err != nil { + fmt.Printf("%3d: t=%3d varint err %v\n", index, tag, err) + break out + } + fmt.Printf("%3d: t=%3d varint %d\n", index, tag, u) + + case WireStartGroup: + fmt.Printf("%3d: t=%3d start\n", index, tag) + depth++ + + case WireEndGroup: + depth-- + fmt.Printf("%3d: t=%3d end\n", index, tag) + } + } + + if depth != 0 { + fmt.Printf("%3d: start-end not balanced %d\n", p.index, depth) + } + fmt.Printf("\n") + + p.buf = obuf + p.index = index +} + +// SetDefaults sets unset protocol buffer fields to their default values. +// It only modifies fields that are both unset and have defined defaults. +// It recursively sets default values in any non-nil sub-messages. +func SetDefaults(pb Message) { + setDefaults(reflect.ValueOf(pb), true, false) +} + +// v is a pointer to a struct. +func setDefaults(v reflect.Value, recur, zeros bool) { + v = v.Elem() + + defaultMu.RLock() + dm, ok := defaults[v.Type()] + defaultMu.RUnlock() + if !ok { + dm = buildDefaultMessage(v.Type()) + defaultMu.Lock() + defaults[v.Type()] = dm + defaultMu.Unlock() + } + + for _, sf := range dm.scalars { + f := v.Field(sf.index) + if !f.IsNil() { + // field already set + continue + } + dv := sf.value + if dv == nil && !zeros { + // no explicit default, and don't want to set zeros + continue + } + fptr := f.Addr().Interface() // **T + // TODO: Consider batching the allocations we do here. + switch sf.kind { + case reflect.Bool: + b := new(bool) + if dv != nil { + *b = dv.(bool) + } + *(fptr.(**bool)) = b + case reflect.Float32: + f := new(float32) + if dv != nil { + *f = dv.(float32) + } + *(fptr.(**float32)) = f + case reflect.Float64: + f := new(float64) + if dv != nil { + *f = dv.(float64) + } + *(fptr.(**float64)) = f + case reflect.Int32: + // might be an enum + if ft := f.Type(); ft != int32PtrType { + // enum + f.Set(reflect.New(ft.Elem())) + if dv != nil { + f.Elem().SetInt(int64(dv.(int32))) + } + } else { + // int32 field + i := new(int32) + if dv != nil { + *i = dv.(int32) + } + *(fptr.(**int32)) = i + } + case reflect.Int64: + i := new(int64) + if dv != nil { + *i = dv.(int64) + } + *(fptr.(**int64)) = i + case reflect.String: + s := new(string) + if dv != nil { + *s = dv.(string) + } + *(fptr.(**string)) = s + case reflect.Uint8: + // exceptional case: []byte + var b []byte + if dv != nil { + db := dv.([]byte) + b = make([]byte, len(db)) + copy(b, db) + } else { + b = []byte{} + } + *(fptr.(*[]byte)) = b + case reflect.Uint32: + u := new(uint32) + if dv != nil { + *u = dv.(uint32) + } + *(fptr.(**uint32)) = u + case reflect.Uint64: + u := new(uint64) + if dv != nil { + *u = dv.(uint64) + } + *(fptr.(**uint64)) = u + default: + log.Printf("proto: can't set default for field %v (sf.kind=%v)", f, sf.kind) + } + } + + for _, ni := range dm.nested { + f := v.Field(ni) + // f is *T or []*T or map[T]*T + switch f.Kind() { + case reflect.Ptr: + if f.IsNil() { + continue + } + setDefaults(f, recur, zeros) + + case reflect.Slice: + for i := 0; i < f.Len(); i++ { + e := f.Index(i) + if e.IsNil() { + continue + } + setDefaults(e, recur, zeros) + } + + case reflect.Map: + for _, k := range f.MapKeys() { + e := f.MapIndex(k) + if e.IsNil() { + continue + } + setDefaults(e, recur, zeros) + } + } + } +} + +var ( + // defaults maps a protocol buffer struct type to a slice of the fields, + // with its scalar fields set to their proto-declared non-zero default values. + defaultMu sync.RWMutex + defaults = make(map[reflect.Type]defaultMessage) + + int32PtrType = reflect.TypeOf((*int32)(nil)) +) + +// defaultMessage represents information about the default values of a message. +type defaultMessage struct { + scalars []scalarField + nested []int // struct field index of nested messages +} + +type scalarField struct { + index int // struct field index + kind reflect.Kind // element type (the T in *T or []T) + value interface{} // the proto-declared default value, or nil +} + +// t is a struct type. +func buildDefaultMessage(t reflect.Type) (dm defaultMessage) { + sprop := GetProperties(t) + for _, prop := range sprop.Prop { + fi, ok := sprop.decoderTags.get(prop.Tag) + if !ok { + // XXX_unrecognized + continue + } + ft := t.Field(fi).Type + + sf, nested, err := fieldDefault(ft, prop) + switch { + case err != nil: + log.Print(err) + case nested: + dm.nested = append(dm.nested, fi) + case sf != nil: + sf.index = fi + dm.scalars = append(dm.scalars, *sf) + } + } + + return dm +} + +// fieldDefault returns the scalarField for field type ft. +// sf will be nil if the field can not have a default. +// nestedMessage will be true if this is a nested message. +// Note that sf.index is not set on return. +func fieldDefault(ft reflect.Type, prop *Properties) (sf *scalarField, nestedMessage bool, err error) { + var canHaveDefault bool + switch ft.Kind() { + case reflect.Ptr: + if ft.Elem().Kind() == reflect.Struct { + nestedMessage = true + } else { + canHaveDefault = true // proto2 scalar field + } + + case reflect.Slice: + switch ft.Elem().Kind() { + case reflect.Ptr: + nestedMessage = true // repeated message + case reflect.Uint8: + canHaveDefault = true // bytes field + } + + case reflect.Map: + if ft.Elem().Kind() == reflect.Ptr { + nestedMessage = true // map with message values + } + } + + if !canHaveDefault { + if nestedMessage { + return nil, true, nil + } + return nil, false, nil + } + + // We now know that ft is a pointer or slice. + sf = &scalarField{kind: ft.Elem().Kind()} + + // scalar fields without defaults + if !prop.HasDefault { + return sf, false, nil + } + + // a scalar field: either *T or []byte + switch ft.Elem().Kind() { + case reflect.Bool: + x, err := strconv.ParseBool(prop.Default) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default bool %q: %v", prop.Default, err) + } + sf.value = x + case reflect.Float32: + x, err := strconv.ParseFloat(prop.Default, 32) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default float32 %q: %v", prop.Default, err) + } + sf.value = float32(x) + case reflect.Float64: + x, err := strconv.ParseFloat(prop.Default, 64) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default float64 %q: %v", prop.Default, err) + } + sf.value = x + case reflect.Int32: + x, err := strconv.ParseInt(prop.Default, 10, 32) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default int32 %q: %v", prop.Default, err) + } + sf.value = int32(x) + case reflect.Int64: + x, err := strconv.ParseInt(prop.Default, 10, 64) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default int64 %q: %v", prop.Default, err) + } + sf.value = x + case reflect.String: + sf.value = prop.Default + case reflect.Uint8: + // []byte (not *uint8) + sf.value = []byte(prop.Default) + case reflect.Uint32: + x, err := strconv.ParseUint(prop.Default, 10, 32) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default uint32 %q: %v", prop.Default, err) + } + sf.value = uint32(x) + case reflect.Uint64: + x, err := strconv.ParseUint(prop.Default, 10, 64) + if err != nil { + return nil, false, fmt.Errorf("proto: bad default uint64 %q: %v", prop.Default, err) + } + sf.value = x + default: + return nil, false, fmt.Errorf("proto: unhandled def kind %v", ft.Elem().Kind()) + } + + return sf, false, nil +} + +// Map fields may have key types of non-float scalars, strings and enums. +// The easiest way to sort them in some deterministic order is to use fmt. +// If this turns out to be inefficient we can always consider other options, +// such as doing a Schwartzian transform. + +func mapKeys(vs []reflect.Value) sort.Interface { + s := mapKeySorter{ + vs: vs, + // default Less function: textual comparison + less: func(a, b reflect.Value) bool { + return fmt.Sprint(a.Interface()) < fmt.Sprint(b.Interface()) + }, + } + + // Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps; + // numeric keys are sorted numerically. + if len(vs) == 0 { + return s + } + switch vs[0].Kind() { + case reflect.Int32, reflect.Int64: + s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() } + case reflect.Uint32, reflect.Uint64: + s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() } + } + + return s +} + +type mapKeySorter struct { + vs []reflect.Value + less func(a, b reflect.Value) bool +} + +func (s mapKeySorter) Len() int { return len(s.vs) } +func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] } +func (s mapKeySorter) Less(i, j int) bool { + return s.less(s.vs[i], s.vs[j]) +} + +// isProto3Zero reports whether v is a zero proto3 value. +func isProto3Zero(v reflect.Value) bool { + switch v.Kind() { + case reflect.Bool: + return !v.Bool() + case reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.String: + return v.String() == "" + } + return false +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/message_set.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/message_set.go new file mode 100644 index 0000000..9d912bc --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/message_set.go @@ -0,0 +1,287 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Support for message sets. + */ + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "sort" +) + +// ErrNoMessageTypeId occurs when a protocol buffer does not have a message type ID. +// A message type ID is required for storing a protocol buffer in a message set. +var ErrNoMessageTypeId = errors.New("proto does not have a message type ID") + +// The first two types (_MessageSet_Item and MessageSet) +// model what the protocol compiler produces for the following protocol message: +// message MessageSet { +// repeated group Item = 1 { +// required int32 type_id = 2; +// required string message = 3; +// }; +// } +// That is the MessageSet wire format. We can't use a proto to generate these +// because that would introduce a circular dependency between it and this package. +// +// When a proto1 proto has a field that looks like: +// optional message info = 3; +// the protocol compiler produces a field in the generated struct that looks like: +// Info *_proto_.MessageSet `protobuf:"bytes,3,opt,name=info"` +// The package is automatically inserted so there is no need for that proto file to +// import this package. + +type _MessageSet_Item struct { + TypeId *int32 `protobuf:"varint,2,req,name=type_id"` + Message []byte `protobuf:"bytes,3,req,name=message"` +} + +type MessageSet struct { + Item []*_MessageSet_Item `protobuf:"group,1,rep"` + XXX_unrecognized []byte + // TODO: caching? +} + +// Make sure MessageSet is a Message. +var _ Message = (*MessageSet)(nil) + +// messageTypeIder is an interface satisfied by a protocol buffer type +// that may be stored in a MessageSet. +type messageTypeIder interface { + MessageTypeId() int32 +} + +func (ms *MessageSet) find(pb Message) *_MessageSet_Item { + mti, ok := pb.(messageTypeIder) + if !ok { + return nil + } + id := mti.MessageTypeId() + for _, item := range ms.Item { + if *item.TypeId == id { + return item + } + } + return nil +} + +func (ms *MessageSet) Has(pb Message) bool { + if ms.find(pb) != nil { + return true + } + return false +} + +func (ms *MessageSet) Unmarshal(pb Message) error { + if item := ms.find(pb); item != nil { + return Unmarshal(item.Message, pb) + } + if _, ok := pb.(messageTypeIder); !ok { + return ErrNoMessageTypeId + } + return nil // TODO: return error instead? +} + +func (ms *MessageSet) Marshal(pb Message) error { + msg, err := Marshal(pb) + if err != nil { + return err + } + if item := ms.find(pb); item != nil { + // reuse existing item + item.Message = msg + return nil + } + + mti, ok := pb.(messageTypeIder) + if !ok { + return ErrNoMessageTypeId + } + + mtid := mti.MessageTypeId() + ms.Item = append(ms.Item, &_MessageSet_Item{ + TypeId: &mtid, + Message: msg, + }) + return nil +} + +func (ms *MessageSet) Reset() { *ms = MessageSet{} } +func (ms *MessageSet) String() string { return CompactTextString(ms) } +func (*MessageSet) ProtoMessage() {} + +// Support for the message_set_wire_format message option. + +func skipVarint(buf []byte) []byte { + i := 0 + for ; buf[i]&0x80 != 0; i++ { + } + return buf[i+1:] +} + +// MarshalMessageSet encodes the extension map represented by m in the message set wire format. +// It is called by generated Marshal methods on protocol buffer messages with the message_set_wire_format option. +func MarshalMessageSet(m map[int32]Extension) ([]byte, error) { + if err := encodeExtensionMap(m); err != nil { + return nil, err + } + + // Sort extension IDs to provide a deterministic encoding. + // See also enc_map in encode.go. + ids := make([]int, 0, len(m)) + for id := range m { + ids = append(ids, int(id)) + } + sort.Ints(ids) + + ms := &MessageSet{Item: make([]*_MessageSet_Item, 0, len(m))} + for _, id := range ids { + e := m[int32(id)] + // Remove the wire type and field number varint, as well as the length varint. + msg := skipVarint(skipVarint(e.enc)) + + ms.Item = append(ms.Item, &_MessageSet_Item{ + TypeId: Int32(int32(id)), + Message: msg, + }) + } + return Marshal(ms) +} + +// UnmarshalMessageSet decodes the extension map encoded in buf in the message set wire format. +// It is called by generated Unmarshal methods on protocol buffer messages with the message_set_wire_format option. +func UnmarshalMessageSet(buf []byte, m map[int32]Extension) error { + ms := new(MessageSet) + if err := Unmarshal(buf, ms); err != nil { + return err + } + for _, item := range ms.Item { + id := *item.TypeId + msg := item.Message + + // Restore wire type and field number varint, plus length varint. + // Be careful to preserve duplicate items. + b := EncodeVarint(uint64(id)<<3 | WireBytes) + if ext, ok := m[id]; ok { + // Existing data; rip off the tag and length varint + // so we join the new data correctly. + // We can assume that ext.enc is set because we are unmarshaling. + o := ext.enc[len(b):] // skip wire type and field number + _, n := DecodeVarint(o) // calculate length of length varint + o = o[n:] // skip length varint + msg = append(o, msg...) // join old data and new data + } + b = append(b, EncodeVarint(uint64(len(msg)))...) + b = append(b, msg...) + + m[id] = Extension{enc: b} + } + return nil +} + +// MarshalMessageSetJSON encodes the extension map represented by m in JSON format. +// It is called by generated MarshalJSON methods on protocol buffer messages with the message_set_wire_format option. +func MarshalMessageSetJSON(m map[int32]Extension) ([]byte, error) { + var b bytes.Buffer + b.WriteByte('{') + + // Process the map in key order for deterministic output. + ids := make([]int32, 0, len(m)) + for id := range m { + ids = append(ids, id) + } + sort.Sort(int32Slice(ids)) // int32Slice defined in text.go + + for i, id := range ids { + ext := m[id] + if i > 0 { + b.WriteByte(',') + } + + msd, ok := messageSetMap[id] + if !ok { + // Unknown type; we can't render it, so skip it. + continue + } + fmt.Fprintf(&b, `"[%s]":`, msd.name) + + x := ext.value + if x == nil { + x = reflect.New(msd.t.Elem()).Interface() + if err := Unmarshal(ext.enc, x.(Message)); err != nil { + return nil, err + } + } + d, err := json.Marshal(x) + if err != nil { + return nil, err + } + b.Write(d) + } + b.WriteByte('}') + return b.Bytes(), nil +} + +// UnmarshalMessageSetJSON decodes the extension map encoded in buf in JSON format. +// It is called by generated UnmarshalJSON methods on protocol buffer messages with the message_set_wire_format option. +func UnmarshalMessageSetJSON(buf []byte, m map[int32]Extension) error { + // Common-case fast path. + if len(buf) == 0 || bytes.Equal(buf, []byte("{}")) { + return nil + } + + // This is fairly tricky, and it's not clear that it is needed. + return errors.New("TODO: UnmarshalMessageSetJSON not yet implemented") +} + +// A global registry of types that can be used in a MessageSet. + +var messageSetMap = make(map[int32]messageSetDesc) + +type messageSetDesc struct { + t reflect.Type // pointer to struct + name string +} + +// RegisterMessageSetType is called from the generated code. +func RegisterMessageSetType(m Message, fieldNum int32, name string) { + messageSetMap[fieldNum] = messageSetDesc{ + t: reflect.TypeOf(m), + name: name, + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/message_set_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/message_set_test.go new file mode 100644 index 0000000..7c29bcc --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/message_set_test.go @@ -0,0 +1,66 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2014 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +import ( + "bytes" + "testing" +) + +func TestUnmarshalMessageSetWithDuplicate(t *testing.T) { + // Check that a repeated message set entry will be concatenated. + in := &MessageSet{ + Item: []*_MessageSet_Item{ + {TypeId: Int32(12345), Message: []byte("hoo")}, + {TypeId: Int32(12345), Message: []byte("hah")}, + }, + } + b, err := Marshal(in) + if err != nil { + t.Fatalf("Marshal: %v", err) + } + t.Logf("Marshaled bytes: %q", b) + + m := make(map[int32]Extension) + if err := UnmarshalMessageSet(b, m); err != nil { + t.Fatalf("UnmarshalMessageSet: %v", err) + } + ext, ok := m[12345] + if !ok { + t.Fatalf("Didn't retrieve extension 12345; map is %v", m) + } + // Skip wire type/field number and length varints. + got := skipVarint(skipVarint(ext.enc)) + if want := []byte("hoohah"); !bytes.Equal(got, want) { + t.Errorf("Combined extension is %q, want %q", got, want) + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/pointer_reflect.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/pointer_reflect.go new file mode 100644 index 0000000..749919d --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/pointer_reflect.go @@ -0,0 +1,479 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +build appengine + +// This file contains an implementation of proto field accesses using package reflect. +// It is slower than the code in pointer_unsafe.go but it avoids package unsafe and can +// be used on App Engine. + +package proto + +import ( + "math" + "reflect" +) + +// A structPointer is a pointer to a struct. +type structPointer struct { + v reflect.Value +} + +// toStructPointer returns a structPointer equivalent to the given reflect value. +// The reflect value must itself be a pointer to a struct. +func toStructPointer(v reflect.Value) structPointer { + return structPointer{v} +} + +// IsNil reports whether p is nil. +func structPointer_IsNil(p structPointer) bool { + return p.v.IsNil() +} + +// Interface returns the struct pointer as an interface value. +func structPointer_Interface(p structPointer, _ reflect.Type) interface{} { + return p.v.Interface() +} + +// A field identifies a field in a struct, accessible from a structPointer. +// In this implementation, a field is identified by the sequence of field indices +// passed to reflect's FieldByIndex. +type field []int + +// toField returns a field equivalent to the given reflect field. +func toField(f *reflect.StructField) field { + return f.Index +} + +// invalidField is an invalid field identifier. +var invalidField = field(nil) + +// IsValid reports whether the field identifier is valid. +func (f field) IsValid() bool { return f != nil } + +// field returns the given field in the struct as a reflect value. +func structPointer_field(p structPointer, f field) reflect.Value { + // Special case: an extension map entry with a value of type T + // passes a *T to the struct-handling code with a zero field, + // expecting that it will be treated as equivalent to *struct{ X T }, + // which has the same memory layout. We have to handle that case + // specially, because reflect will panic if we call FieldByIndex on a + // non-struct. + if f == nil { + return p.v.Elem() + } + + return p.v.Elem().FieldByIndex(f) +} + +// ifield returns the given field in the struct as an interface value. +func structPointer_ifield(p structPointer, f field) interface{} { + return structPointer_field(p, f).Addr().Interface() +} + +// Bytes returns the address of a []byte field in the struct. +func structPointer_Bytes(p structPointer, f field) *[]byte { + return structPointer_ifield(p, f).(*[]byte) +} + +// BytesSlice returns the address of a [][]byte field in the struct. +func structPointer_BytesSlice(p structPointer, f field) *[][]byte { + return structPointer_ifield(p, f).(*[][]byte) +} + +// Bool returns the address of a *bool field in the struct. +func structPointer_Bool(p structPointer, f field) **bool { + return structPointer_ifield(p, f).(**bool) +} + +// BoolVal returns the address of a bool field in the struct. +func structPointer_BoolVal(p structPointer, f field) *bool { + return structPointer_ifield(p, f).(*bool) +} + +// BoolSlice returns the address of a []bool field in the struct. +func structPointer_BoolSlice(p structPointer, f field) *[]bool { + return structPointer_ifield(p, f).(*[]bool) +} + +// String returns the address of a *string field in the struct. +func structPointer_String(p structPointer, f field) **string { + return structPointer_ifield(p, f).(**string) +} + +// StringVal returns the address of a string field in the struct. +func structPointer_StringVal(p structPointer, f field) *string { + return structPointer_ifield(p, f).(*string) +} + +// StringSlice returns the address of a []string field in the struct. +func structPointer_StringSlice(p structPointer, f field) *[]string { + return structPointer_ifield(p, f).(*[]string) +} + +// ExtMap returns the address of an extension map field in the struct. +func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { + return structPointer_ifield(p, f).(*map[int32]Extension) +} + +// NewAt returns the reflect.Value for a pointer to a field in the struct. +func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value { + return structPointer_field(p, f).Addr() +} + +// SetStructPointer writes a *struct field in the struct. +func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { + structPointer_field(p, f).Set(q.v) +} + +// GetStructPointer reads a *struct field in the struct. +func structPointer_GetStructPointer(p structPointer, f field) structPointer { + return structPointer{structPointer_field(p, f)} +} + +// StructPointerSlice the address of a []*struct field in the struct. +func structPointer_StructPointerSlice(p structPointer, f field) structPointerSlice { + return structPointerSlice{structPointer_field(p, f)} +} + +// A structPointerSlice represents the address of a slice of pointers to structs +// (themselves messages or groups). That is, v.Type() is *[]*struct{...}. +type structPointerSlice struct { + v reflect.Value +} + +func (p structPointerSlice) Len() int { return p.v.Len() } +func (p structPointerSlice) Index(i int) structPointer { return structPointer{p.v.Index(i)} } +func (p structPointerSlice) Append(q structPointer) { + p.v.Set(reflect.Append(p.v, q.v)) +} + +var ( + int32Type = reflect.TypeOf(int32(0)) + uint32Type = reflect.TypeOf(uint32(0)) + float32Type = reflect.TypeOf(float32(0)) + int64Type = reflect.TypeOf(int64(0)) + uint64Type = reflect.TypeOf(uint64(0)) + float64Type = reflect.TypeOf(float64(0)) +) + +// A word32 represents a field of type *int32, *uint32, *float32, or *enum. +// That is, v.Type() is *int32, *uint32, *float32, or *enum and v is assignable. +type word32 struct { + v reflect.Value +} + +// IsNil reports whether p is nil. +func word32_IsNil(p word32) bool { + return p.v.IsNil() +} + +// Set sets p to point at a newly allocated word with bits set to x. +func word32_Set(p word32, o *Buffer, x uint32) { + t := p.v.Type().Elem() + switch t { + case int32Type: + if len(o.int32s) == 0 { + o.int32s = make([]int32, uint32PoolSize) + } + o.int32s[0] = int32(x) + p.v.Set(reflect.ValueOf(&o.int32s[0])) + o.int32s = o.int32s[1:] + return + case uint32Type: + if len(o.uint32s) == 0 { + o.uint32s = make([]uint32, uint32PoolSize) + } + o.uint32s[0] = x + p.v.Set(reflect.ValueOf(&o.uint32s[0])) + o.uint32s = o.uint32s[1:] + return + case float32Type: + if len(o.float32s) == 0 { + o.float32s = make([]float32, uint32PoolSize) + } + o.float32s[0] = math.Float32frombits(x) + p.v.Set(reflect.ValueOf(&o.float32s[0])) + o.float32s = o.float32s[1:] + return + } + + // must be enum + p.v.Set(reflect.New(t)) + p.v.Elem().SetInt(int64(int32(x))) +} + +// Get gets the bits pointed at by p, as a uint32. +func word32_Get(p word32) uint32 { + elem := p.v.Elem() + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32 returns a reference to a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32(p structPointer, f field) word32 { + return word32{structPointer_field(p, f)} +} + +// A word32Val represents a field of type int32, uint32, float32, or enum. +// That is, v.Type() is int32, uint32, float32, or enum and v is assignable. +type word32Val struct { + v reflect.Value +} + +// Set sets *p to x. +func word32Val_Set(p word32Val, x uint32) { + switch p.v.Type() { + case int32Type: + p.v.SetInt(int64(x)) + return + case uint32Type: + p.v.SetUint(uint64(x)) + return + case float32Type: + p.v.SetFloat(float64(math.Float32frombits(x))) + return + } + + // must be enum + p.v.SetInt(int64(int32(x))) +} + +// Get gets the bits pointed at by p, as a uint32. +func word32Val_Get(p word32Val) uint32 { + elem := p.v + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32Val returns a reference to a int32, uint32, float32, or enum field in the struct. +func structPointer_Word32Val(p structPointer, f field) word32Val { + return word32Val{structPointer_field(p, f)} +} + +// A word32Slice is a slice of 32-bit values. +// That is, v.Type() is []int32, []uint32, []float32, or []enum. +type word32Slice struct { + v reflect.Value +} + +func (p word32Slice) Append(x uint32) { + n, m := p.v.Len(), p.v.Cap() + if n < m { + p.v.SetLen(n + 1) + } else { + t := p.v.Type().Elem() + p.v.Set(reflect.Append(p.v, reflect.Zero(t))) + } + elem := p.v.Index(n) + switch elem.Kind() { + case reflect.Int32: + elem.SetInt(int64(int32(x))) + case reflect.Uint32: + elem.SetUint(uint64(x)) + case reflect.Float32: + elem.SetFloat(float64(math.Float32frombits(x))) + } +} + +func (p word32Slice) Len() int { + return p.v.Len() +} + +func (p word32Slice) Index(i int) uint32 { + elem := p.v.Index(i) + switch elem.Kind() { + case reflect.Int32: + return uint32(elem.Int()) + case reflect.Uint32: + return uint32(elem.Uint()) + case reflect.Float32: + return math.Float32bits(float32(elem.Float())) + } + panic("unreachable") +} + +// Word32Slice returns a reference to a []int32, []uint32, []float32, or []enum field in the struct. +func structPointer_Word32Slice(p structPointer, f field) word32Slice { + return word32Slice{structPointer_field(p, f)} +} + +// word64 is like word32 but for 64-bit values. +type word64 struct { + v reflect.Value +} + +func word64_Set(p word64, o *Buffer, x uint64) { + t := p.v.Type().Elem() + switch t { + case int64Type: + if len(o.int64s) == 0 { + o.int64s = make([]int64, uint64PoolSize) + } + o.int64s[0] = int64(x) + p.v.Set(reflect.ValueOf(&o.int64s[0])) + o.int64s = o.int64s[1:] + return + case uint64Type: + if len(o.uint64s) == 0 { + o.uint64s = make([]uint64, uint64PoolSize) + } + o.uint64s[0] = x + p.v.Set(reflect.ValueOf(&o.uint64s[0])) + o.uint64s = o.uint64s[1:] + return + case float64Type: + if len(o.float64s) == 0 { + o.float64s = make([]float64, uint64PoolSize) + } + o.float64s[0] = math.Float64frombits(x) + p.v.Set(reflect.ValueOf(&o.float64s[0])) + o.float64s = o.float64s[1:] + return + } + panic("unreachable") +} + +func word64_IsNil(p word64) bool { + return p.v.IsNil() +} + +func word64_Get(p word64) uint64 { + elem := p.v.Elem() + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return elem.Uint() + case reflect.Float64: + return math.Float64bits(elem.Float()) + } + panic("unreachable") +} + +func structPointer_Word64(p structPointer, f field) word64 { + return word64{structPointer_field(p, f)} +} + +// word64Val is like word32Val but for 64-bit values. +type word64Val struct { + v reflect.Value +} + +func word64Val_Set(p word64Val, o *Buffer, x uint64) { + switch p.v.Type() { + case int64Type: + p.v.SetInt(int64(x)) + return + case uint64Type: + p.v.SetUint(x) + return + case float64Type: + p.v.SetFloat(math.Float64frombits(x)) + return + } + panic("unreachable") +} + +func word64Val_Get(p word64Val) uint64 { + elem := p.v + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return elem.Uint() + case reflect.Float64: + return math.Float64bits(elem.Float()) + } + panic("unreachable") +} + +func structPointer_Word64Val(p structPointer, f field) word64Val { + return word64Val{structPointer_field(p, f)} +} + +type word64Slice struct { + v reflect.Value +} + +func (p word64Slice) Append(x uint64) { + n, m := p.v.Len(), p.v.Cap() + if n < m { + p.v.SetLen(n + 1) + } else { + t := p.v.Type().Elem() + p.v.Set(reflect.Append(p.v, reflect.Zero(t))) + } + elem := p.v.Index(n) + switch elem.Kind() { + case reflect.Int64: + elem.SetInt(int64(int64(x))) + case reflect.Uint64: + elem.SetUint(uint64(x)) + case reflect.Float64: + elem.SetFloat(float64(math.Float64frombits(x))) + } +} + +func (p word64Slice) Len() int { + return p.v.Len() +} + +func (p word64Slice) Index(i int) uint64 { + elem := p.v.Index(i) + switch elem.Kind() { + case reflect.Int64: + return uint64(elem.Int()) + case reflect.Uint64: + return uint64(elem.Uint()) + case reflect.Float64: + return math.Float64bits(float64(elem.Float())) + } + panic("unreachable") +} + +func structPointer_Word64Slice(p structPointer, f field) word64Slice { + return word64Slice{structPointer_field(p, f)} +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/pointer_unsafe.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/pointer_unsafe.go new file mode 100644 index 0000000..e9be0fe --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/pointer_unsafe.go @@ -0,0 +1,266 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +build !appengine + +// This file contains the implementation of the proto field accesses using package unsafe. + +package proto + +import ( + "reflect" + "unsafe" +) + +// NOTE: These type_Foo functions would more idiomatically be methods, +// but Go does not allow methods on pointer types, and we must preserve +// some pointer type for the garbage collector. We use these +// funcs with clunky names as our poor approximation to methods. +// +// An alternative would be +// type structPointer struct { p unsafe.Pointer } +// but that does not registerize as well. + +// A structPointer is a pointer to a struct. +type structPointer unsafe.Pointer + +// toStructPointer returns a structPointer equivalent to the given reflect value. +func toStructPointer(v reflect.Value) structPointer { + return structPointer(unsafe.Pointer(v.Pointer())) +} + +// IsNil reports whether p is nil. +func structPointer_IsNil(p structPointer) bool { + return p == nil +} + +// Interface returns the struct pointer, assumed to have element type t, +// as an interface value. +func structPointer_Interface(p structPointer, t reflect.Type) interface{} { + return reflect.NewAt(t, unsafe.Pointer(p)).Interface() +} + +// A field identifies a field in a struct, accessible from a structPointer. +// In this implementation, a field is identified by its byte offset from the start of the struct. +type field uintptr + +// toField returns a field equivalent to the given reflect field. +func toField(f *reflect.StructField) field { + return field(f.Offset) +} + +// invalidField is an invalid field identifier. +const invalidField = ^field(0) + +// IsValid reports whether the field identifier is valid. +func (f field) IsValid() bool { + return f != ^field(0) +} + +// Bytes returns the address of a []byte field in the struct. +func structPointer_Bytes(p structPointer, f field) *[]byte { + return (*[]byte)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BytesSlice returns the address of a [][]byte field in the struct. +func structPointer_BytesSlice(p structPointer, f field) *[][]byte { + return (*[][]byte)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// Bool returns the address of a *bool field in the struct. +func structPointer_Bool(p structPointer, f field) **bool { + return (**bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BoolVal returns the address of a bool field in the struct. +func structPointer_BoolVal(p structPointer, f field) *bool { + return (*bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// BoolSlice returns the address of a []bool field in the struct. +func structPointer_BoolSlice(p structPointer, f field) *[]bool { + return (*[]bool)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// String returns the address of a *string field in the struct. +func structPointer_String(p structPointer, f field) **string { + return (**string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StringVal returns the address of a string field in the struct. +func structPointer_StringVal(p structPointer, f field) *string { + return (*string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StringSlice returns the address of a []string field in the struct. +func structPointer_StringSlice(p structPointer, f field) *[]string { + return (*[]string)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// ExtMap returns the address of an extension map field in the struct. +func structPointer_ExtMap(p structPointer, f field) *map[int32]Extension { + return (*map[int32]Extension)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// NewAt returns the reflect.Value for a pointer to a field in the struct. +func structPointer_NewAt(p structPointer, f field, typ reflect.Type) reflect.Value { + return reflect.NewAt(typ, unsafe.Pointer(uintptr(p)+uintptr(f))) +} + +// SetStructPointer writes a *struct field in the struct. +func structPointer_SetStructPointer(p structPointer, f field, q structPointer) { + *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) = q +} + +// GetStructPointer reads a *struct field in the struct. +func structPointer_GetStructPointer(p structPointer, f field) structPointer { + return *(*structPointer)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// StructPointerSlice the address of a []*struct field in the struct. +func structPointer_StructPointerSlice(p structPointer, f field) *structPointerSlice { + return (*structPointerSlice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// A structPointerSlice represents a slice of pointers to structs (themselves submessages or groups). +type structPointerSlice []structPointer + +func (v *structPointerSlice) Len() int { return len(*v) } +func (v *structPointerSlice) Index(i int) structPointer { return (*v)[i] } +func (v *structPointerSlice) Append(p structPointer) { *v = append(*v, p) } + +// A word32 is the address of a "pointer to 32-bit value" field. +type word32 **uint32 + +// IsNil reports whether *v is nil. +func word32_IsNil(p word32) bool { + return *p == nil +} + +// Set sets *v to point at a newly allocated word set to x. +func word32_Set(p word32, o *Buffer, x uint32) { + if len(o.uint32s) == 0 { + o.uint32s = make([]uint32, uint32PoolSize) + } + o.uint32s[0] = x + *p = &o.uint32s[0] + o.uint32s = o.uint32s[1:] +} + +// Get gets the value pointed at by *v. +func word32_Get(p word32) uint32 { + return **p +} + +// Word32 returns the address of a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32(p structPointer, f field) word32 { + return word32((**uint32)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// A word32Val is the address of a 32-bit value field. +type word32Val *uint32 + +// Set sets *p to x. +func word32Val_Set(p word32Val, x uint32) { + *p = x +} + +// Get gets the value pointed at by p. +func word32Val_Get(p word32Val) uint32 { + return *p +} + +// Word32Val returns the address of a *int32, *uint32, *float32, or *enum field in the struct. +func structPointer_Word32Val(p structPointer, f field) word32Val { + return word32Val((*uint32)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// A word32Slice is a slice of 32-bit values. +type word32Slice []uint32 + +func (v *word32Slice) Append(x uint32) { *v = append(*v, x) } +func (v *word32Slice) Len() int { return len(*v) } +func (v *word32Slice) Index(i int) uint32 { return (*v)[i] } + +// Word32Slice returns the address of a []int32, []uint32, []float32, or []enum field in the struct. +func structPointer_Word32Slice(p structPointer, f field) *word32Slice { + return (*word32Slice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} + +// word64 is like word32 but for 64-bit values. +type word64 **uint64 + +func word64_Set(p word64, o *Buffer, x uint64) { + if len(o.uint64s) == 0 { + o.uint64s = make([]uint64, uint64PoolSize) + } + o.uint64s[0] = x + *p = &o.uint64s[0] + o.uint64s = o.uint64s[1:] +} + +func word64_IsNil(p word64) bool { + return *p == nil +} + +func word64_Get(p word64) uint64 { + return **p +} + +func structPointer_Word64(p structPointer, f field) word64 { + return word64((**uint64)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// word64Val is like word32Val but for 64-bit values. +type word64Val *uint64 + +func word64Val_Set(p word64Val, o *Buffer, x uint64) { + *p = x +} + +func word64Val_Get(p word64Val) uint64 { + return *p +} + +func structPointer_Word64Val(p structPointer, f field) word64Val { + return word64Val((*uint64)(unsafe.Pointer(uintptr(p) + uintptr(f)))) +} + +// word64Slice is like word32Slice but for 64-bit values. +type word64Slice []uint64 + +func (v *word64Slice) Append(x uint64) { *v = append(*v, x) } +func (v *word64Slice) Len() int { return len(*v) } +func (v *word64Slice) Index(i int) uint64 { return (*v)[i] } + +func structPointer_Word64Slice(p structPointer, f field) *word64Slice { + return (*word64Slice)(unsafe.Pointer(uintptr(p) + uintptr(f))) +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/properties.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/properties.go new file mode 100644 index 0000000..692fe43 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/properties.go @@ -0,0 +1,799 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +/* + * Routines for encoding data into the wire format for protocol buffers. + */ + +import ( + "fmt" + "os" + "reflect" + "sort" + "strconv" + "strings" + "sync" +) + +const debug bool = false + +// Constants that identify the encoding of a value on the wire. +const ( + WireVarint = 0 + WireFixed64 = 1 + WireBytes = 2 + WireStartGroup = 3 + WireEndGroup = 4 + WireFixed32 = 5 +) + +const startSize = 10 // initial slice/string sizes + +// Encoders are defined in encode.go +// An encoder outputs the full representation of a field, including its +// tag and encoder type. +type encoder func(p *Buffer, prop *Properties, base structPointer) error + +// A valueEncoder encodes a single integer in a particular encoding. +type valueEncoder func(o *Buffer, x uint64) error + +// Sizers are defined in encode.go +// A sizer returns the encoded size of a field, including its tag and encoder +// type. +type sizer func(prop *Properties, base structPointer) int + +// A valueSizer returns the encoded size of a single integer in a particular +// encoding. +type valueSizer func(x uint64) int + +// Decoders are defined in decode.go +// A decoder creates a value from its wire representation. +// Unrecognized subelements are saved in unrec. +type decoder func(p *Buffer, prop *Properties, base structPointer) error + +// A valueDecoder decodes a single integer in a particular encoding. +type valueDecoder func(o *Buffer) (x uint64, err error) + +// A oneofMarshaler does the marshaling for all oneof fields in a message. +type oneofMarshaler func(Message, *Buffer) error + +// A oneofUnmarshaler does the unmarshaling for a oneof field in a message. +type oneofUnmarshaler func(Message, int, int, *Buffer) (bool, error) + +// tagMap is an optimization over map[int]int for typical protocol buffer +// use-cases. Encoded protocol buffers are often in tag order with small tag +// numbers. +type tagMap struct { + fastTags []int + slowTags map[int]int +} + +// tagMapFastLimit is the upper bound on the tag number that will be stored in +// the tagMap slice rather than its map. +const tagMapFastLimit = 1024 + +func (p *tagMap) get(t int) (int, bool) { + if t > 0 && t < tagMapFastLimit { + if t >= len(p.fastTags) { + return 0, false + } + fi := p.fastTags[t] + return fi, fi >= 0 + } + fi, ok := p.slowTags[t] + return fi, ok +} + +func (p *tagMap) put(t int, fi int) { + if t > 0 && t < tagMapFastLimit { + for len(p.fastTags) < t+1 { + p.fastTags = append(p.fastTags, -1) + } + p.fastTags[t] = fi + return + } + if p.slowTags == nil { + p.slowTags = make(map[int]int) + } + p.slowTags[t] = fi +} + +// StructProperties represents properties for all the fields of a struct. +// decoderTags and decoderOrigNames should only be used by the decoder. +type StructProperties struct { + Prop []*Properties // properties for each field + reqCount int // required count + decoderTags tagMap // map from proto tag to struct field number + decoderOrigNames map[string]int // map from original name to struct field number + order []int // list of struct field numbers in tag order + unrecField field // field id of the XXX_unrecognized []byte field + extendable bool // is this an extendable proto + + oneofMarshaler oneofMarshaler + oneofUnmarshaler oneofUnmarshaler + stype reflect.Type + + // OneofTypes contains information about the oneof fields in this message. + // It is keyed by the original name of a field. + OneofTypes map[string]*OneofProperties +} + +// OneofProperties represents information about a specific field in a oneof. +type OneofProperties struct { + Type reflect.Type // pointer to generated struct type for this oneof field + Field int // struct field number of the containing oneof in the message + Prop *Properties +} + +// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec. +// See encode.go, (*Buffer).enc_struct. + +func (sp *StructProperties) Len() int { return len(sp.order) } +func (sp *StructProperties) Less(i, j int) bool { + return sp.Prop[sp.order[i]].Tag < sp.Prop[sp.order[j]].Tag +} +func (sp *StructProperties) Swap(i, j int) { sp.order[i], sp.order[j] = sp.order[j], sp.order[i] } + +// Properties represents the protocol-specific behavior of a single struct field. +type Properties struct { + Name string // name of the field, for error messages + OrigName string // original name before protocol compiler (always set) + Wire string + WireType int + Tag int + Required bool + Optional bool + Repeated bool + Packed bool // relevant for repeated primitives only + Enum string // set for enum types only + proto3 bool // whether this is known to be a proto3 field; set for []byte only + + Default string // default value + HasDefault bool // whether an explicit default was provided + def_uint64 uint64 + + enc encoder + valEnc valueEncoder // set for bool and numeric types only + field field + tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType) + tagbuf [8]byte + stype reflect.Type // set for struct types only + sprop *StructProperties // set for struct types only + isMarshaler bool + isUnmarshaler bool + + mtype reflect.Type // set for map types only + mkeyprop *Properties // set for map types only + mvalprop *Properties // set for map types only + + size sizer + valSize valueSizer // set for bool and numeric types only + + dec decoder + valDec valueDecoder // set for bool and numeric types only + + // If this is a packable field, this will be the decoder for the packed version of the field. + packedDec decoder +} + +// String formats the properties in the protobuf struct field tag style. +func (p *Properties) String() string { + s := p.Wire + s = "," + s += strconv.Itoa(p.Tag) + if p.Required { + s += ",req" + } + if p.Optional { + s += ",opt" + } + if p.Repeated { + s += ",rep" + } + if p.Packed { + s += ",packed" + } + if p.OrigName != p.Name { + s += ",name=" + p.OrigName + } + if p.proto3 { + s += ",proto3" + } + if len(p.Enum) > 0 { + s += ",enum=" + p.Enum + } + if p.HasDefault { + s += ",def=" + p.Default + } + return s +} + +// Parse populates p by parsing a string in the protobuf struct field tag style. +func (p *Properties) Parse(s string) { + // "bytes,49,opt,name=foo,def=hello!" + fields := strings.Split(s, ",") // breaks def=, but handled below. + if len(fields) < 2 { + fmt.Fprintf(os.Stderr, "proto: tag has too few fields: %q\n", s) + return + } + + p.Wire = fields[0] + switch p.Wire { + case "varint": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeVarint + p.valDec = (*Buffer).DecodeVarint + p.valSize = sizeVarint + case "fixed32": + p.WireType = WireFixed32 + p.valEnc = (*Buffer).EncodeFixed32 + p.valDec = (*Buffer).DecodeFixed32 + p.valSize = sizeFixed32 + case "fixed64": + p.WireType = WireFixed64 + p.valEnc = (*Buffer).EncodeFixed64 + p.valDec = (*Buffer).DecodeFixed64 + p.valSize = sizeFixed64 + case "zigzag32": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeZigzag32 + p.valDec = (*Buffer).DecodeZigzag32 + p.valSize = sizeZigzag32 + case "zigzag64": + p.WireType = WireVarint + p.valEnc = (*Buffer).EncodeZigzag64 + p.valDec = (*Buffer).DecodeZigzag64 + p.valSize = sizeZigzag64 + case "bytes", "group": + p.WireType = WireBytes + // no numeric converter for non-numeric types + default: + fmt.Fprintf(os.Stderr, "proto: tag has unknown wire type: %q\n", s) + return + } + + var err error + p.Tag, err = strconv.Atoi(fields[1]) + if err != nil { + return + } + + for i := 2; i < len(fields); i++ { + f := fields[i] + switch { + case f == "req": + p.Required = true + case f == "opt": + p.Optional = true + case f == "rep": + p.Repeated = true + case f == "packed": + p.Packed = true + case strings.HasPrefix(f, "name="): + p.OrigName = f[5:] + case strings.HasPrefix(f, "enum="): + p.Enum = f[5:] + case f == "proto3": + p.proto3 = true + case strings.HasPrefix(f, "def="): + p.HasDefault = true + p.Default = f[4:] // rest of string + if i+1 < len(fields) { + // Commas aren't escaped, and def is always last. + p.Default += "," + strings.Join(fields[i+1:], ",") + break + } + } + } +} + +func logNoSliceEnc(t1, t2 reflect.Type) { + fmt.Fprintf(os.Stderr, "proto: no slice oenc for %T = []%T\n", t1, t2) +} + +var protoMessageType = reflect.TypeOf((*Message)(nil)).Elem() + +// Initialize the fields for encoding and decoding. +func (p *Properties) setEncAndDec(typ reflect.Type, f *reflect.StructField, lockGetProp bool) { + p.enc = nil + p.dec = nil + p.size = nil + + switch t1 := typ; t1.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no coders for %v\n", t1) + + // proto3 scalar types + + case reflect.Bool: + p.enc = (*Buffer).enc_proto3_bool + p.dec = (*Buffer).dec_proto3_bool + p.size = size_proto3_bool + case reflect.Int32: + p.enc = (*Buffer).enc_proto3_int32 + p.dec = (*Buffer).dec_proto3_int32 + p.size = size_proto3_int32 + case reflect.Uint32: + p.enc = (*Buffer).enc_proto3_uint32 + p.dec = (*Buffer).dec_proto3_int32 // can reuse + p.size = size_proto3_uint32 + case reflect.Int64, reflect.Uint64: + p.enc = (*Buffer).enc_proto3_int64 + p.dec = (*Buffer).dec_proto3_int64 + p.size = size_proto3_int64 + case reflect.Float32: + p.enc = (*Buffer).enc_proto3_uint32 // can just treat them as bits + p.dec = (*Buffer).dec_proto3_int32 + p.size = size_proto3_uint32 + case reflect.Float64: + p.enc = (*Buffer).enc_proto3_int64 // can just treat them as bits + p.dec = (*Buffer).dec_proto3_int64 + p.size = size_proto3_int64 + case reflect.String: + p.enc = (*Buffer).enc_proto3_string + p.dec = (*Buffer).dec_proto3_string + p.size = size_proto3_string + + case reflect.Ptr: + switch t2 := t1.Elem(); t2.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no encoder function for %v -> %v\n", t1, t2) + break + case reflect.Bool: + p.enc = (*Buffer).enc_bool + p.dec = (*Buffer).dec_bool + p.size = size_bool + case reflect.Int32: + p.enc = (*Buffer).enc_int32 + p.dec = (*Buffer).dec_int32 + p.size = size_int32 + case reflect.Uint32: + p.enc = (*Buffer).enc_uint32 + p.dec = (*Buffer).dec_int32 // can reuse + p.size = size_uint32 + case reflect.Int64, reflect.Uint64: + p.enc = (*Buffer).enc_int64 + p.dec = (*Buffer).dec_int64 + p.size = size_int64 + case reflect.Float32: + p.enc = (*Buffer).enc_uint32 // can just treat them as bits + p.dec = (*Buffer).dec_int32 + p.size = size_uint32 + case reflect.Float64: + p.enc = (*Buffer).enc_int64 // can just treat them as bits + p.dec = (*Buffer).dec_int64 + p.size = size_int64 + case reflect.String: + p.enc = (*Buffer).enc_string + p.dec = (*Buffer).dec_string + p.size = size_string + case reflect.Struct: + p.stype = t1.Elem() + p.isMarshaler = isMarshaler(t1) + p.isUnmarshaler = isUnmarshaler(t1) + if p.Wire == "bytes" { + p.enc = (*Buffer).enc_struct_message + p.dec = (*Buffer).dec_struct_message + p.size = size_struct_message + } else { + p.enc = (*Buffer).enc_struct_group + p.dec = (*Buffer).dec_struct_group + p.size = size_struct_group + } + } + + case reflect.Slice: + switch t2 := t1.Elem(); t2.Kind() { + default: + logNoSliceEnc(t1, t2) + break + case reflect.Bool: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_bool + p.size = size_slice_packed_bool + } else { + p.enc = (*Buffer).enc_slice_bool + p.size = size_slice_bool + } + p.dec = (*Buffer).dec_slice_bool + p.packedDec = (*Buffer).dec_slice_packed_bool + case reflect.Int32: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int32 + p.size = size_slice_packed_int32 + } else { + p.enc = (*Buffer).enc_slice_int32 + p.size = size_slice_int32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case reflect.Uint32: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_uint32 + p.size = size_slice_packed_uint32 + } else { + p.enc = (*Buffer).enc_slice_uint32 + p.size = size_slice_uint32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case reflect.Int64, reflect.Uint64: + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int64 + p.size = size_slice_packed_int64 + } else { + p.enc = (*Buffer).enc_slice_int64 + p.size = size_slice_int64 + } + p.dec = (*Buffer).dec_slice_int64 + p.packedDec = (*Buffer).dec_slice_packed_int64 + case reflect.Uint8: + p.enc = (*Buffer).enc_slice_byte + p.dec = (*Buffer).dec_slice_byte + p.size = size_slice_byte + // This is a []byte, which is either a bytes field, + // or the value of a map field. In the latter case, + // we always encode an empty []byte, so we should not + // use the proto3 enc/size funcs. + // f == nil iff this is the key/value of a map field. + if p.proto3 && f != nil { + p.enc = (*Buffer).enc_proto3_slice_byte + p.size = size_proto3_slice_byte + } + case reflect.Float32, reflect.Float64: + switch t2.Bits() { + case 32: + // can just treat them as bits + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_uint32 + p.size = size_slice_packed_uint32 + } else { + p.enc = (*Buffer).enc_slice_uint32 + p.size = size_slice_uint32 + } + p.dec = (*Buffer).dec_slice_int32 + p.packedDec = (*Buffer).dec_slice_packed_int32 + case 64: + // can just treat them as bits + if p.Packed { + p.enc = (*Buffer).enc_slice_packed_int64 + p.size = size_slice_packed_int64 + } else { + p.enc = (*Buffer).enc_slice_int64 + p.size = size_slice_int64 + } + p.dec = (*Buffer).dec_slice_int64 + p.packedDec = (*Buffer).dec_slice_packed_int64 + default: + logNoSliceEnc(t1, t2) + break + } + case reflect.String: + p.enc = (*Buffer).enc_slice_string + p.dec = (*Buffer).dec_slice_string + p.size = size_slice_string + case reflect.Ptr: + switch t3 := t2.Elem(); t3.Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no ptr oenc for %T -> %T -> %T\n", t1, t2, t3) + break + case reflect.Struct: + p.stype = t2.Elem() + p.isMarshaler = isMarshaler(t2) + p.isUnmarshaler = isUnmarshaler(t2) + if p.Wire == "bytes" { + p.enc = (*Buffer).enc_slice_struct_message + p.dec = (*Buffer).dec_slice_struct_message + p.size = size_slice_struct_message + } else { + p.enc = (*Buffer).enc_slice_struct_group + p.dec = (*Buffer).dec_slice_struct_group + p.size = size_slice_struct_group + } + } + case reflect.Slice: + switch t2.Elem().Kind() { + default: + fmt.Fprintf(os.Stderr, "proto: no slice elem oenc for %T -> %T -> %T\n", t1, t2, t2.Elem()) + break + case reflect.Uint8: + p.enc = (*Buffer).enc_slice_slice_byte + p.dec = (*Buffer).dec_slice_slice_byte + p.size = size_slice_slice_byte + } + } + + case reflect.Map: + p.enc = (*Buffer).enc_new_map + p.dec = (*Buffer).dec_new_map + p.size = size_new_map + + p.mtype = t1 + p.mkeyprop = &Properties{} + p.mkeyprop.init(reflect.PtrTo(p.mtype.Key()), "Key", f.Tag.Get("protobuf_key"), nil, lockGetProp) + p.mvalprop = &Properties{} + vtype := p.mtype.Elem() + if vtype.Kind() != reflect.Ptr && vtype.Kind() != reflect.Slice { + // The value type is not a message (*T) or bytes ([]byte), + // so we need encoders for the pointer to this type. + vtype = reflect.PtrTo(vtype) + } + p.mvalprop.init(vtype, "Value", f.Tag.Get("protobuf_val"), nil, lockGetProp) + } + + // precalculate tag code + wire := p.WireType + if p.Packed { + wire = WireBytes + } + x := uint32(p.Tag)<<3 | uint32(wire) + i := 0 + for i = 0; x > 127; i++ { + p.tagbuf[i] = 0x80 | uint8(x&0x7F) + x >>= 7 + } + p.tagbuf[i] = uint8(x) + p.tagcode = p.tagbuf[0 : i+1] + + if p.stype != nil { + if lockGetProp { + p.sprop = GetProperties(p.stype) + } else { + p.sprop = getPropertiesLocked(p.stype) + } + } +} + +var ( + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() +) + +// isMarshaler reports whether type t implements Marshaler. +func isMarshaler(t reflect.Type) bool { + // We're checking for (likely) pointer-receiver methods + // so if t is not a pointer, something is very wrong. + // The calls above only invoke isMarshaler on pointer types. + if t.Kind() != reflect.Ptr { + panic("proto: misuse of isMarshaler") + } + return t.Implements(marshalerType) +} + +// isUnmarshaler reports whether type t implements Unmarshaler. +func isUnmarshaler(t reflect.Type) bool { + // We're checking for (likely) pointer-receiver methods + // so if t is not a pointer, something is very wrong. + // The calls above only invoke isUnmarshaler on pointer types. + if t.Kind() != reflect.Ptr { + panic("proto: misuse of isUnmarshaler") + } + return t.Implements(unmarshalerType) +} + +// Init populates the properties from a protocol buffer struct tag. +func (p *Properties) Init(typ reflect.Type, name, tag string, f *reflect.StructField) { + p.init(typ, name, tag, f, true) +} + +func (p *Properties) init(typ reflect.Type, name, tag string, f *reflect.StructField, lockGetProp bool) { + // "bytes,49,opt,def=hello!" + p.Name = name + p.OrigName = name + if f != nil { + p.field = toField(f) + } + if tag == "" { + return + } + p.Parse(tag) + p.setEncAndDec(typ, f, lockGetProp) +} + +var ( + propertiesMu sync.RWMutex + propertiesMap = make(map[reflect.Type]*StructProperties) +) + +// GetProperties returns the list of properties for the type represented by t. +// t must represent a generated struct type of a protocol message. +func GetProperties(t reflect.Type) *StructProperties { + if t.Kind() != reflect.Struct { + panic("proto: type must have kind struct") + } + + // Most calls to GetProperties in a long-running program will be + // retrieving details for types we have seen before. + propertiesMu.RLock() + sprop, ok := propertiesMap[t] + propertiesMu.RUnlock() + if ok { + if collectStats { + stats.Chit++ + } + return sprop + } + + propertiesMu.Lock() + sprop = getPropertiesLocked(t) + propertiesMu.Unlock() + return sprop +} + +// getPropertiesLocked requires that propertiesMu is held. +func getPropertiesLocked(t reflect.Type) *StructProperties { + if prop, ok := propertiesMap[t]; ok { + if collectStats { + stats.Chit++ + } + return prop + } + if collectStats { + stats.Cmiss++ + } + + prop := new(StructProperties) + // in case of recursive protos, fill this in now. + propertiesMap[t] = prop + + // build properties + prop.extendable = reflect.PtrTo(t).Implements(extendableProtoType) + prop.unrecField = invalidField + prop.Prop = make([]*Properties, t.NumField()) + prop.order = make([]int, t.NumField()) + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + p := new(Properties) + name := f.Name + p.init(f.Type, name, f.Tag.Get("protobuf"), &f, false) + + if f.Name == "XXX_extensions" { // special case + p.enc = (*Buffer).enc_map + p.dec = nil // not needed + p.size = size_map + } + if f.Name == "XXX_unrecognized" { // special case + prop.unrecField = toField(&f) + } + oneof := f.Tag.Get("protobuf_oneof") != "" // special case + prop.Prop[i] = p + prop.order[i] = i + if debug { + print(i, " ", f.Name, " ", t.String(), " ") + if p.Tag > 0 { + print(p.String()) + } + print("\n") + } + if p.enc == nil && !strings.HasPrefix(f.Name, "XXX_") && !oneof { + fmt.Fprintln(os.Stderr, "proto: no encoder for", f.Name, f.Type.String(), "[GetProperties]") + } + } + + // Re-order prop.order. + sort.Sort(prop) + + type oneofMessage interface { + XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), []interface{}) + } + if om, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok { + var oots []interface{} + prop.oneofMarshaler, prop.oneofUnmarshaler, oots = om.XXX_OneofFuncs() + prop.stype = t + + // Interpret oneof metadata. + prop.OneofTypes = make(map[string]*OneofProperties) + for _, oot := range oots { + oop := &OneofProperties{ + Type: reflect.ValueOf(oot).Type(), // *T + Prop: new(Properties), + } + sft := oop.Type.Elem().Field(0) + oop.Prop.Name = sft.Name + oop.Prop.Parse(sft.Tag.Get("protobuf")) + // There will be exactly one interface field that + // this new value is assignable to. + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.Type.Kind() != reflect.Interface { + continue + } + if !oop.Type.AssignableTo(f.Type) { + continue + } + oop.Field = i + break + } + prop.OneofTypes[oop.Prop.OrigName] = oop + } + } + + // build required counts + // build tags + reqCount := 0 + prop.decoderOrigNames = make(map[string]int) + for i, p := range prop.Prop { + if strings.HasPrefix(p.Name, "XXX_") { + // Internal fields should not appear in tags/origNames maps. + // They are handled specially when encoding and decoding. + continue + } + if p.Required { + reqCount++ + } + prop.decoderTags.put(p.Tag, i) + prop.decoderOrigNames[p.OrigName] = i + } + prop.reqCount = reqCount + + return prop +} + +// Return the Properties object for the x[0]'th field of the structure. +func propByIndex(t reflect.Type, x []int) *Properties { + if len(x) != 1 { + fmt.Fprintf(os.Stderr, "proto: field index dimension %d (not 1) for type %s\n", len(x), t) + return nil + } + prop := GetProperties(t) + return prop.Prop[x[0]] +} + +// Get the address and type of a pointer to a struct from an interface. +func getbase(pb Message) (t reflect.Type, b structPointer, err error) { + if pb == nil { + err = ErrNil + return + } + // get the reflect type of the pointer to the struct. + t = reflect.TypeOf(pb) + // get the address of the struct. + value := reflect.ValueOf(pb) + b = toStructPointer(value) + return +} + +// A global registry of enum types. +// The generated code will register the generated maps by calling RegisterEnum. + +var enumValueMaps = make(map[string]map[string]int32) + +// RegisterEnum is called from the generated code to install the enum descriptor +// maps into the global table to aid parsing text format protocol buffers. +func RegisterEnum(typeName string, unusedNameMap map[int32]string, valueMap map[string]int32) { + if _, ok := enumValueMaps[typeName]; ok { + panic("proto: duplicate enum registered: " + typeName) + } + enumValueMaps[typeName] = valueMap +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_proto/proto3.pb.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_proto/proto3.pb.go new file mode 100644 index 0000000..37c7782 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_proto/proto3.pb.go @@ -0,0 +1,122 @@ +// Code generated by protoc-gen-go. +// source: proto3_proto/proto3.proto +// DO NOT EDIT! + +/* +Package proto3_proto is a generated protocol buffer package. + +It is generated from these files: + proto3_proto/proto3.proto + +It has these top-level messages: + Message + Nested + MessageWithMap +*/ +package proto3_proto + +import proto "github.com/golang/protobuf/proto" +import testdata "github.com/golang/protobuf/proto/testdata" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal + +type Message_Humour int32 + +const ( + Message_UNKNOWN Message_Humour = 0 + Message_PUNS Message_Humour = 1 + Message_SLAPSTICK Message_Humour = 2 + Message_BILL_BAILEY Message_Humour = 3 +) + +var Message_Humour_name = map[int32]string{ + 0: "UNKNOWN", + 1: "PUNS", + 2: "SLAPSTICK", + 3: "BILL_BAILEY", +} +var Message_Humour_value = map[string]int32{ + "UNKNOWN": 0, + "PUNS": 1, + "SLAPSTICK": 2, + "BILL_BAILEY": 3, +} + +func (x Message_Humour) String() string { + return proto.EnumName(Message_Humour_name, int32(x)) +} + +type Message struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + Hilarity Message_Humour `protobuf:"varint,2,opt,name=hilarity,enum=proto3_proto.Message_Humour" json:"hilarity,omitempty"` + HeightInCm uint32 `protobuf:"varint,3,opt,name=height_in_cm" json:"height_in_cm,omitempty"` + Data []byte `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` + ResultCount int64 `protobuf:"varint,7,opt,name=result_count" json:"result_count,omitempty"` + TrueScotsman bool `protobuf:"varint,8,opt,name=true_scotsman" json:"true_scotsman,omitempty"` + Score float32 `protobuf:"fixed32,9,opt,name=score" json:"score,omitempty"` + Key []uint64 `protobuf:"varint,5,rep,name=key" json:"key,omitempty"` + Nested *Nested `protobuf:"bytes,6,opt,name=nested" json:"nested,omitempty"` + Terrain map[string]*Nested `protobuf:"bytes,10,rep,name=terrain" json:"terrain,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + Proto2Field *testdata.SubDefaults `protobuf:"bytes,11,opt,name=proto2_field" json:"proto2_field,omitempty"` + Proto2Value map[string]*testdata.SubDefaults `protobuf:"bytes,13,rep,name=proto2_value" json:"proto2_value,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` +} + +func (m *Message) Reset() { *m = Message{} } +func (m *Message) String() string { return proto.CompactTextString(m) } +func (*Message) ProtoMessage() {} + +func (m *Message) GetNested() *Nested { + if m != nil { + return m.Nested + } + return nil +} + +func (m *Message) GetTerrain() map[string]*Nested { + if m != nil { + return m.Terrain + } + return nil +} + +func (m *Message) GetProto2Field() *testdata.SubDefaults { + if m != nil { + return m.Proto2Field + } + return nil +} + +func (m *Message) GetProto2Value() map[string]*testdata.SubDefaults { + if m != nil { + return m.Proto2Value + } + return nil +} + +type Nested struct { + Bunny string `protobuf:"bytes,1,opt,name=bunny" json:"bunny,omitempty"` +} + +func (m *Nested) Reset() { *m = Nested{} } +func (m *Nested) String() string { return proto.CompactTextString(m) } +func (*Nested) ProtoMessage() {} + +type MessageWithMap struct { + ByteMapping map[bool][]byte `protobuf:"bytes,1,rep,name=byte_mapping" json:"byte_mapping,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value,proto3"` +} + +func (m *MessageWithMap) Reset() { *m = MessageWithMap{} } +func (m *MessageWithMap) String() string { return proto.CompactTextString(m) } +func (*MessageWithMap) ProtoMessage() {} + +func (m *MessageWithMap) GetByteMapping() map[bool][]byte { + if m != nil { + return m.ByteMapping + } + return nil +} + +func init() { + proto.RegisterEnum("proto3_proto.Message_Humour", Message_Humour_name, Message_Humour_value) +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_proto/proto3.proto b/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_proto/proto3.proto new file mode 100644 index 0000000..e2311d9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_proto/proto3.proto @@ -0,0 +1,68 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2014 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto3"; + +import "testdata/test.proto"; + +package proto3_proto; + +message Message { + enum Humour { + UNKNOWN = 0; + PUNS = 1; + SLAPSTICK = 2; + BILL_BAILEY = 3; + } + + string name = 1; + Humour hilarity = 2; + uint32 height_in_cm = 3; + bytes data = 4; + int64 result_count = 7; + bool true_scotsman = 8; + float score = 9; + + repeated uint64 key = 5; + Nested nested = 6; + + map terrain = 10; + testdata.SubDefaults proto2_field = 11; + map proto2_value = 13; +} + +message Nested { + string bunny = 1; +} + +message MessageWithMap { + map byte_mapping = 1; +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_test.go new file mode 100644 index 0000000..462f805 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/proto3_test.go @@ -0,0 +1,125 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2014 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "testing" + + "github.com/golang/protobuf/proto" + pb "github.com/golang/protobuf/proto/proto3_proto" + tpb "github.com/golang/protobuf/proto/testdata" +) + +func TestProto3ZeroValues(t *testing.T) { + tests := []struct { + desc string + m proto.Message + }{ + {"zero message", &pb.Message{}}, + {"empty bytes field", &pb.Message{Data: []byte{}}}, + } + for _, test := range tests { + b, err := proto.Marshal(test.m) + if err != nil { + t.Errorf("%s: proto.Marshal: %v", test.desc, err) + continue + } + if len(b) > 0 { + t.Errorf("%s: Encoding is non-empty: %q", test.desc, b) + } + } +} + +func TestRoundTripProto3(t *testing.T) { + m := &pb.Message{ + Name: "David", // (2 | 1<<3): 0x0a 0x05 "David" + Hilarity: pb.Message_PUNS, // (0 | 2<<3): 0x10 0x01 + HeightInCm: 178, // (0 | 3<<3): 0x18 0xb2 0x01 + Data: []byte("roboto"), // (2 | 4<<3): 0x20 0x06 "roboto" + ResultCount: 47, // (0 | 7<<3): 0x38 0x2f + TrueScotsman: true, // (0 | 8<<3): 0x40 0x01 + Score: 8.1, // (5 | 9<<3): 0x4d <8.1> + + Key: []uint64{1, 0xdeadbeef}, + Nested: &pb.Nested{ + Bunny: "Monty", + }, + } + t.Logf(" m: %v", m) + + b, err := proto.Marshal(m) + if err != nil { + t.Fatalf("proto.Marshal: %v", err) + } + t.Logf(" b: %q", b) + + m2 := new(pb.Message) + if err := proto.Unmarshal(b, m2); err != nil { + t.Fatalf("proto.Unmarshal: %v", err) + } + t.Logf("m2: %v", m2) + + if !proto.Equal(m, m2) { + t.Errorf("proto.Equal returned false:\n m: %v\nm2: %v", m, m2) + } +} + +func TestProto3SetDefaults(t *testing.T) { + in := &pb.Message{ + Terrain: map[string]*pb.Nested{ + "meadow": new(pb.Nested), + }, + Proto2Field: new(tpb.SubDefaults), + Proto2Value: map[string]*tpb.SubDefaults{ + "badlands": new(tpb.SubDefaults), + }, + } + + got := proto.Clone(in).(*pb.Message) + proto.SetDefaults(got) + + // There are no defaults in proto3. Everything should be the zero value, but + // we need to remember to set defaults for nested proto2 messages. + want := &pb.Message{ + Terrain: map[string]*pb.Nested{ + "meadow": new(pb.Nested), + }, + Proto2Field: &tpb.SubDefaults{N: proto.Int64(7)}, + Proto2Value: map[string]*tpb.SubDefaults{ + "badlands": &tpb.SubDefaults{N: proto.Int64(7)}, + }, + } + + if !proto.Equal(got, want) { + t.Errorf("with in = %v\nproto.SetDefaults(in) =>\ngot %v\nwant %v", in, got, want) + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/size2_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/size2_test.go new file mode 100644 index 0000000..a2729c3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/size2_test.go @@ -0,0 +1,63 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +import ( + "testing" +) + +// This is a separate file and package from size_test.go because that one uses +// generated messages and thus may not be in package proto without having a circular +// dependency, whereas this file tests unexported details of size.go. + +func TestVarintSize(t *testing.T) { + // Check the edge cases carefully. + testCases := []struct { + n uint64 + size int + }{ + {0, 1}, + {1, 1}, + {127, 1}, + {128, 2}, + {16383, 2}, + {16384, 3}, + {1<<63 - 1, 9}, + {1 << 63, 10}, + } + for _, tc := range testCases { + size := sizeVarint(tc.n) + if size != tc.size { + t.Errorf("sizeVarint(%d) = %d, want %d", tc.n, size, tc.size) + } + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/size_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/size_test.go new file mode 100644 index 0000000..53806a3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/size_test.go @@ -0,0 +1,146 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2012 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "log" + "strings" + "testing" + + . "github.com/golang/protobuf/proto" + proto3pb "github.com/golang/protobuf/proto/proto3_proto" + pb "github.com/golang/protobuf/proto/testdata" +) + +var messageWithExtension1 = &pb.MyMessage{Count: Int32(7)} + +// messageWithExtension2 is in equal_test.go. +var messageWithExtension3 = &pb.MyMessage{Count: Int32(8)} + +func init() { + if err := SetExtension(messageWithExtension1, pb.E_Ext_More, &pb.Ext{Data: String("Abbott")}); err != nil { + log.Panicf("SetExtension: %v", err) + } + if err := SetExtension(messageWithExtension3, pb.E_Ext_More, &pb.Ext{Data: String("Costello")}); err != nil { + log.Panicf("SetExtension: %v", err) + } + + // Force messageWithExtension3 to have the extension encoded. + Marshal(messageWithExtension3) + +} + +var SizeTests = []struct { + desc string + pb Message +}{ + {"empty", &pb.OtherMessage{}}, + // Basic types. + {"bool", &pb.Defaults{F_Bool: Bool(true)}}, + {"int32", &pb.Defaults{F_Int32: Int32(12)}}, + {"negative int32", &pb.Defaults{F_Int32: Int32(-1)}}, + {"small int64", &pb.Defaults{F_Int64: Int64(1)}}, + {"big int64", &pb.Defaults{F_Int64: Int64(1 << 20)}}, + {"negative int64", &pb.Defaults{F_Int64: Int64(-1)}}, + {"fixed32", &pb.Defaults{F_Fixed32: Uint32(71)}}, + {"fixed64", &pb.Defaults{F_Fixed64: Uint64(72)}}, + {"uint32", &pb.Defaults{F_Uint32: Uint32(123)}}, + {"uint64", &pb.Defaults{F_Uint64: Uint64(124)}}, + {"float", &pb.Defaults{F_Float: Float32(12.6)}}, + {"double", &pb.Defaults{F_Double: Float64(13.9)}}, + {"string", &pb.Defaults{F_String: String("niles")}}, + {"bytes", &pb.Defaults{F_Bytes: []byte("wowsa")}}, + {"bytes, empty", &pb.Defaults{F_Bytes: []byte{}}}, + {"sint32", &pb.Defaults{F_Sint32: Int32(65)}}, + {"sint64", &pb.Defaults{F_Sint64: Int64(67)}}, + {"enum", &pb.Defaults{F_Enum: pb.Defaults_BLUE.Enum()}}, + // Repeated. + {"empty repeated bool", &pb.MoreRepeated{Bools: []bool{}}}, + {"repeated bool", &pb.MoreRepeated{Bools: []bool{false, true, true, false}}}, + {"packed repeated bool", &pb.MoreRepeated{BoolsPacked: []bool{false, true, true, false, true, true, true}}}, + {"repeated int32", &pb.MoreRepeated{Ints: []int32{1, 12203, 1729, -1}}}, + {"repeated int32 packed", &pb.MoreRepeated{IntsPacked: []int32{1, 12203, 1729}}}, + {"repeated int64 packed", &pb.MoreRepeated{Int64SPacked: []int64{ + // Need enough large numbers to verify that the header is counting the number of bytes + // for the field, not the number of elements. + 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, + 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, 1 << 62, + }}}, + {"repeated string", &pb.MoreRepeated{Strings: []string{"r", "ken", "gri"}}}, + {"repeated fixed", &pb.MoreRepeated{Fixeds: []uint32{1, 2, 3, 4}}}, + // Nested. + {"nested", &pb.OldMessage{Nested: &pb.OldMessage_Nested{Name: String("whatever")}}}, + {"group", &pb.GroupOld{G: &pb.GroupOld_G{X: Int32(12345)}}}, + // Other things. + {"unrecognized", &pb.MoreRepeated{XXX_unrecognized: []byte{13<<3 | 0, 4}}}, + {"extension (unencoded)", messageWithExtension1}, + {"extension (encoded)", messageWithExtension3}, + // proto3 message + {"proto3 empty", &proto3pb.Message{}}, + {"proto3 bool", &proto3pb.Message{TrueScotsman: true}}, + {"proto3 int64", &proto3pb.Message{ResultCount: 1}}, + {"proto3 uint32", &proto3pb.Message{HeightInCm: 123}}, + {"proto3 float", &proto3pb.Message{Score: 12.6}}, + {"proto3 string", &proto3pb.Message{Name: "Snezana"}}, + {"proto3 bytes", &proto3pb.Message{Data: []byte("wowsa")}}, + {"proto3 bytes, empty", &proto3pb.Message{Data: []byte{}}}, + {"proto3 enum", &proto3pb.Message{Hilarity: proto3pb.Message_PUNS}}, + {"proto3 map field with empty bytes", &proto3pb.MessageWithMap{ByteMapping: map[bool][]byte{false: []byte{}}}}, + + {"map field", &pb.MessageWithMap{NameMapping: map[int32]string{1: "Rob", 7: "Andrew"}}}, + {"map field with message", &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{0x7001: &pb.FloatingPoint{F: Float64(2.0)}}}}, + {"map field with bytes", &pb.MessageWithMap{ByteMapping: map[bool][]byte{true: []byte("this time for sure")}}}, + {"map field with empty bytes", &pb.MessageWithMap{ByteMapping: map[bool][]byte{true: []byte{}}}}, + + {"map field with big entry", &pb.MessageWithMap{NameMapping: map[int32]string{8: strings.Repeat("x", 125)}}}, + {"map field with big key and val", &pb.MessageWithMap{StrToStr: map[string]string{strings.Repeat("x", 70): strings.Repeat("y", 70)}}}, + {"map field with big numeric key", &pb.MessageWithMap{NameMapping: map[int32]string{0xf00d: "om nom nom"}}}, + + {"oneof not set", &pb.Communique{}}, + {"oneof int32", &pb.Communique{Union: &pb.Communique_Number{3}}}, + {"oneof string", &pb.Communique{Union: &pb.Communique_Name{"Rhythmic Fman"}}}, +} + +func TestSize(t *testing.T) { + for _, tc := range SizeTests { + size := Size(tc.pb) + b, err := Marshal(tc.pb) + if err != nil { + t.Errorf("%v: Marshal failed: %v", tc.desc, err) + continue + } + if size != len(b) { + t.Errorf("%v: Size(%v) = %d, want %d", tc.desc, tc.pb, size, len(b)) + t.Logf("%v: bytes: %#v", tc.desc, b) + } + } +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/text.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text.go new file mode 100644 index 0000000..dd69110 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text.go @@ -0,0 +1,796 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +// Functions for writing the text protocol buffer format. + +import ( + "bufio" + "bytes" + "encoding" + "errors" + "fmt" + "io" + "log" + "math" + "reflect" + "sort" + "strings" +) + +var ( + newline = []byte("\n") + spaces = []byte(" ") + gtNewline = []byte(">\n") + endBraceNewline = []byte("}\n") + backslashN = []byte{'\\', 'n'} + backslashR = []byte{'\\', 'r'} + backslashT = []byte{'\\', 't'} + backslashDQ = []byte{'\\', '"'} + backslashBS = []byte{'\\', '\\'} + posInf = []byte("inf") + negInf = []byte("-inf") + nan = []byte("nan") +) + +type writer interface { + io.Writer + WriteByte(byte) error +} + +// textWriter is an io.Writer that tracks its indentation level. +type textWriter struct { + ind int + complete bool // if the current position is a complete line + compact bool // whether to write out as a one-liner + w writer +} + +func (w *textWriter) WriteString(s string) (n int, err error) { + if !strings.Contains(s, "\n") { + if !w.compact && w.complete { + w.writeIndent() + } + w.complete = false + return io.WriteString(w.w, s) + } + // WriteString is typically called without newlines, so this + // codepath and its copy are rare. We copy to avoid + // duplicating all of Write's logic here. + return w.Write([]byte(s)) +} + +func (w *textWriter) Write(p []byte) (n int, err error) { + newlines := bytes.Count(p, newline) + if newlines == 0 { + if !w.compact && w.complete { + w.writeIndent() + } + n, err = w.w.Write(p) + w.complete = false + return n, err + } + + frags := bytes.SplitN(p, newline, newlines+1) + if w.compact { + for i, frag := range frags { + if i > 0 { + if err := w.w.WriteByte(' '); err != nil { + return n, err + } + n++ + } + nn, err := w.w.Write(frag) + n += nn + if err != nil { + return n, err + } + } + return n, nil + } + + for i, frag := range frags { + if w.complete { + w.writeIndent() + } + nn, err := w.w.Write(frag) + n += nn + if err != nil { + return n, err + } + if i+1 < len(frags) { + if err := w.w.WriteByte('\n'); err != nil { + return n, err + } + n++ + } + } + w.complete = len(frags[len(frags)-1]) == 0 + return n, nil +} + +func (w *textWriter) WriteByte(c byte) error { + if w.compact && c == '\n' { + c = ' ' + } + if !w.compact && w.complete { + w.writeIndent() + } + err := w.w.WriteByte(c) + w.complete = c == '\n' + return err +} + +func (w *textWriter) indent() { w.ind++ } + +func (w *textWriter) unindent() { + if w.ind == 0 { + log.Printf("proto: textWriter unindented too far") + return + } + w.ind-- +} + +func writeName(w *textWriter, props *Properties) error { + if _, err := w.WriteString(props.OrigName); err != nil { + return err + } + if props.Wire != "group" { + return w.WriteByte(':') + } + return nil +} + +var ( + messageSetType = reflect.TypeOf((*MessageSet)(nil)).Elem() +) + +// raw is the interface satisfied by RawMessage. +type raw interface { + Bytes() []byte +} + +func writeStruct(w *textWriter, sv reflect.Value) error { + if sv.Type() == messageSetType { + return writeMessageSet(w, sv.Addr().Interface().(*MessageSet)) + } + + st := sv.Type() + sprops := GetProperties(st) + for i := 0; i < sv.NumField(); i++ { + fv := sv.Field(i) + props := sprops.Prop[i] + name := st.Field(i).Name + + if strings.HasPrefix(name, "XXX_") { + // There are two XXX_ fields: + // XXX_unrecognized []byte + // XXX_extensions map[int32]proto.Extension + // The first is handled here; + // the second is handled at the bottom of this function. + if name == "XXX_unrecognized" && !fv.IsNil() { + if err := writeUnknownStruct(w, fv.Interface().([]byte)); err != nil { + return err + } + } + continue + } + if fv.Kind() == reflect.Ptr && fv.IsNil() { + // Field not filled in. This could be an optional field or + // a required field that wasn't filled in. Either way, there + // isn't anything we can show for it. + continue + } + if fv.Kind() == reflect.Slice && fv.IsNil() { + // Repeated field that is empty, or a bytes field that is unused. + continue + } + + if props.Repeated && fv.Kind() == reflect.Slice { + // Repeated field. + for j := 0; j < fv.Len(); j++ { + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + v := fv.Index(j) + if v.Kind() == reflect.Ptr && v.IsNil() { + // A nil message in a repeated field is not valid, + // but we can handle that more gracefully than panicking. + if _, err := w.Write([]byte("\n")); err != nil { + return err + } + continue + } + if err := writeAny(w, v, props); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + continue + } + if fv.Kind() == reflect.Map { + // Map fields are rendered as a repeated struct with key/value fields. + keys := fv.MapKeys() + sort.Sort(mapKeys(keys)) + for _, key := range keys { + val := fv.MapIndex(key) + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + // open struct + if err := w.WriteByte('<'); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + // key + if _, err := w.WriteString("key:"); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := writeAny(w, key, props.mkeyprop); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + // nil values aren't legal, but we can avoid panicking because of them. + if val.Kind() != reflect.Ptr || !val.IsNil() { + // value + if _, err := w.WriteString("value:"); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := writeAny(w, val, props.mvalprop); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + // close struct + w.unindent() + if err := w.WriteByte('>'); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + } + continue + } + if props.proto3 && fv.Kind() == reflect.Slice && fv.Len() == 0 { + // empty bytes field + continue + } + if fv.Kind() != reflect.Ptr && fv.Kind() != reflect.Slice { + // proto3 non-repeated scalar field; skip if zero value + if isProto3Zero(fv) { + continue + } + } + + if fv.Kind() == reflect.Interface { + // Check if it is a oneof. + if st.Field(i).Tag.Get("protobuf_oneof") != "" { + // fv is nil, or holds a pointer to generated struct. + // That generated struct has exactly one field, + // which has a protobuf struct tag. + if fv.IsNil() { + continue + } + inner := fv.Elem().Elem() // interface -> *T -> T + tag := inner.Type().Field(0).Tag.Get("protobuf") + props.Parse(tag) // Overwrite the outer props. + // Write the value in the oneof, not the oneof itself. + fv = inner.Field(0) + + // Special case to cope with malformed messages gracefully: + // If the value in the oneof is a nil pointer, don't panic + // in writeAny. + if fv.Kind() == reflect.Ptr && fv.IsNil() { + // Use errors.New so writeAny won't render quotes. + msg := errors.New("/* nil */") + fv = reflect.ValueOf(&msg).Elem() + } + } + } + + if err := writeName(w, props); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if b, ok := fv.Interface().(raw); ok { + if err := writeRaw(w, b.Bytes()); err != nil { + return err + } + continue + } + + // Enums have a String method, so writeAny will work fine. + if err := writeAny(w, fv, props); err != nil { + return err + } + + if err := w.WriteByte('\n'); err != nil { + return err + } + } + + // Extensions (the XXX_extensions field). + pv := sv.Addr() + if pv.Type().Implements(extendableProtoType) { + if err := writeExtensions(w, pv); err != nil { + return err + } + } + + return nil +} + +// writeRaw writes an uninterpreted raw message. +func writeRaw(w *textWriter, b []byte) error { + if err := w.WriteByte('<'); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + if err := writeUnknownStruct(w, b); err != nil { + return err + } + w.unindent() + if err := w.WriteByte('>'); err != nil { + return err + } + return nil +} + +// writeAny writes an arbitrary field. +func writeAny(w *textWriter, v reflect.Value, props *Properties) error { + v = reflect.Indirect(v) + + // Floats have special cases. + if v.Kind() == reflect.Float32 || v.Kind() == reflect.Float64 { + x := v.Float() + var b []byte + switch { + case math.IsInf(x, 1): + b = posInf + case math.IsInf(x, -1): + b = negInf + case math.IsNaN(x): + b = nan + } + if b != nil { + _, err := w.Write(b) + return err + } + // Other values are handled below. + } + + // We don't attempt to serialise every possible value type; only those + // that can occur in protocol buffers. + switch v.Kind() { + case reflect.Slice: + // Should only be a []byte; repeated fields are handled in writeStruct. + if err := writeString(w, string(v.Interface().([]byte))); err != nil { + return err + } + case reflect.String: + if err := writeString(w, v.String()); err != nil { + return err + } + case reflect.Struct: + // Required/optional group/message. + var bra, ket byte = '<', '>' + if props != nil && props.Wire == "group" { + bra, ket = '{', '}' + } + if err := w.WriteByte(bra); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte('\n'); err != nil { + return err + } + } + w.indent() + if tm, ok := v.Interface().(encoding.TextMarshaler); ok { + text, err := tm.MarshalText() + if err != nil { + return err + } + if _, err = w.Write(text); err != nil { + return err + } + } else if err := writeStruct(w, v); err != nil { + return err + } + w.unindent() + if err := w.WriteByte(ket); err != nil { + return err + } + default: + _, err := fmt.Fprint(w, v.Interface()) + return err + } + return nil +} + +// equivalent to C's isprint. +func isprint(c byte) bool { + return c >= 0x20 && c < 0x7f +} + +// writeString writes a string in the protocol buffer text format. +// It is similar to strconv.Quote except we don't use Go escape sequences, +// we treat the string as a byte sequence, and we use octal escapes. +// These differences are to maintain interoperability with the other +// languages' implementations of the text format. +func writeString(w *textWriter, s string) error { + // use WriteByte here to get any needed indent + if err := w.WriteByte('"'); err != nil { + return err + } + // Loop over the bytes, not the runes. + for i := 0; i < len(s); i++ { + var err error + // Divergence from C++: we don't escape apostrophes. + // There's no need to escape them, and the C++ parser + // copes with a naked apostrophe. + switch c := s[i]; c { + case '\n': + _, err = w.w.Write(backslashN) + case '\r': + _, err = w.w.Write(backslashR) + case '\t': + _, err = w.w.Write(backslashT) + case '"': + _, err = w.w.Write(backslashDQ) + case '\\': + _, err = w.w.Write(backslashBS) + default: + if isprint(c) { + err = w.w.WriteByte(c) + } else { + _, err = fmt.Fprintf(w.w, "\\%03o", c) + } + } + if err != nil { + return err + } + } + return w.WriteByte('"') +} + +func writeMessageSet(w *textWriter, ms *MessageSet) error { + for _, item := range ms.Item { + id := *item.TypeId + if msd, ok := messageSetMap[id]; ok { + // Known message set type. + if _, err := fmt.Fprintf(w, "[%s]: <\n", msd.name); err != nil { + return err + } + w.indent() + + pb := reflect.New(msd.t.Elem()) + if err := Unmarshal(item.Message, pb.Interface().(Message)); err != nil { + if _, err := fmt.Fprintf(w, "/* bad message: %v */\n", err); err != nil { + return err + } + } else { + if err := writeStruct(w, pb.Elem()); err != nil { + return err + } + } + } else { + // Unknown type. + if _, err := fmt.Fprintf(w, "[%d]: <\n", id); err != nil { + return err + } + w.indent() + if err := writeUnknownStruct(w, item.Message); err != nil { + return err + } + } + w.unindent() + if _, err := w.Write(gtNewline); err != nil { + return err + } + } + return nil +} + +func writeUnknownStruct(w *textWriter, data []byte) (err error) { + if !w.compact { + if _, err := fmt.Fprintf(w, "/* %d unknown bytes */\n", len(data)); err != nil { + return err + } + } + b := NewBuffer(data) + for b.index < len(b.buf) { + x, err := b.DecodeVarint() + if err != nil { + _, err := fmt.Fprintf(w, "/* %v */\n", err) + return err + } + wire, tag := x&7, x>>3 + if wire == WireEndGroup { + w.unindent() + if _, err := w.Write(endBraceNewline); err != nil { + return err + } + continue + } + if _, err := fmt.Fprint(w, tag); err != nil { + return err + } + if wire != WireStartGroup { + if err := w.WriteByte(':'); err != nil { + return err + } + } + if !w.compact || wire == WireStartGroup { + if err := w.WriteByte(' '); err != nil { + return err + } + } + switch wire { + case WireBytes: + buf, e := b.DecodeRawBytes(false) + if e == nil { + _, err = fmt.Fprintf(w, "%q", buf) + } else { + _, err = fmt.Fprintf(w, "/* %v */", e) + } + case WireFixed32: + x, err = b.DecodeFixed32() + err = writeUnknownInt(w, x, err) + case WireFixed64: + x, err = b.DecodeFixed64() + err = writeUnknownInt(w, x, err) + case WireStartGroup: + err = w.WriteByte('{') + w.indent() + case WireVarint: + x, err = b.DecodeVarint() + err = writeUnknownInt(w, x, err) + default: + _, err = fmt.Fprintf(w, "/* unknown wire type %d */", wire) + } + if err != nil { + return err + } + if err = w.WriteByte('\n'); err != nil { + return err + } + } + return nil +} + +func writeUnknownInt(w *textWriter, x uint64, err error) error { + if err == nil { + _, err = fmt.Fprint(w, x) + } else { + _, err = fmt.Fprintf(w, "/* %v */", err) + } + return err +} + +type int32Slice []int32 + +func (s int32Slice) Len() int { return len(s) } +func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] } +func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +// writeExtensions writes all the extensions in pv. +// pv is assumed to be a pointer to a protocol message struct that is extendable. +func writeExtensions(w *textWriter, pv reflect.Value) error { + emap := extensionMaps[pv.Type().Elem()] + ep := pv.Interface().(extendableProto) + + // Order the extensions by ID. + // This isn't strictly necessary, but it will give us + // canonical output, which will also make testing easier. + m := ep.ExtensionMap() + ids := make([]int32, 0, len(m)) + for id := range m { + ids = append(ids, id) + } + sort.Sort(int32Slice(ids)) + + for _, extNum := range ids { + ext := m[extNum] + var desc *ExtensionDesc + if emap != nil { + desc = emap[extNum] + } + if desc == nil { + // Unknown extension. + if err := writeUnknownStruct(w, ext.enc); err != nil { + return err + } + continue + } + + pb, err := GetExtension(ep, desc) + if err != nil { + return fmt.Errorf("failed getting extension: %v", err) + } + + // Repeated extensions will appear as a slice. + if !desc.repeated() { + if err := writeExtension(w, desc.Name, pb); err != nil { + return err + } + } else { + v := reflect.ValueOf(pb) + for i := 0; i < v.Len(); i++ { + if err := writeExtension(w, desc.Name, v.Index(i).Interface()); err != nil { + return err + } + } + } + } + return nil +} + +func writeExtension(w *textWriter, name string, pb interface{}) error { + if _, err := fmt.Fprintf(w, "[%s]:", name); err != nil { + return err + } + if !w.compact { + if err := w.WriteByte(' '); err != nil { + return err + } + } + if err := writeAny(w, reflect.ValueOf(pb), nil); err != nil { + return err + } + if err := w.WriteByte('\n'); err != nil { + return err + } + return nil +} + +func (w *textWriter) writeIndent() { + if !w.complete { + return + } + remain := w.ind * 2 + for remain > 0 { + n := remain + if n > len(spaces) { + n = len(spaces) + } + w.w.Write(spaces[:n]) + remain -= n + } + w.complete = false +} + +func marshalText(w io.Writer, pb Message, compact bool) error { + val := reflect.ValueOf(pb) + if pb == nil || val.IsNil() { + w.Write([]byte("")) + return nil + } + var bw *bufio.Writer + ww, ok := w.(writer) + if !ok { + bw = bufio.NewWriter(w) + ww = bw + } + aw := &textWriter{ + w: ww, + complete: true, + compact: compact, + } + + if tm, ok := pb.(encoding.TextMarshaler); ok { + text, err := tm.MarshalText() + if err != nil { + return err + } + if _, err = aw.Write(text); err != nil { + return err + } + if bw != nil { + return bw.Flush() + } + return nil + } + // Dereference the received pointer so we don't have outer < and >. + v := reflect.Indirect(val) + if err := writeStruct(aw, v); err != nil { + return err + } + if bw != nil { + return bw.Flush() + } + return nil +} + +// MarshalText writes a given protocol buffer in text format. +// The only errors returned are from w. +func MarshalText(w io.Writer, pb Message) error { + return marshalText(w, pb, false) +} + +// MarshalTextString is the same as MarshalText, but returns the string directly. +func MarshalTextString(pb Message) string { + var buf bytes.Buffer + marshalText(&buf, pb, false) + return buf.String() +} + +// CompactText writes a given protocol buffer in compact text format (one line). +func CompactText(w io.Writer, pb Message) error { return marshalText(w, pb, true) } + +// CompactTextString is the same as CompactText, but returns the string directly. +func CompactTextString(pb Message) string { + var buf bytes.Buffer + marshalText(&buf, pb, true) + return buf.String() +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_parser.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_parser.go new file mode 100644 index 0000000..8bce369 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_parser.go @@ -0,0 +1,784 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto + +// Functions for parsing the Text protocol buffer format. +// TODO: message sets. + +import ( + "encoding" + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "unicode/utf8" +) + +type ParseError struct { + Message string + Line int // 1-based line number + Offset int // 0-based byte offset from start of input +} + +func (p *ParseError) Error() string { + if p.Line == 1 { + // show offset only for first line + return fmt.Sprintf("line 1.%d: %v", p.Offset, p.Message) + } + return fmt.Sprintf("line %d: %v", p.Line, p.Message) +} + +type token struct { + value string + err *ParseError + line int // line number + offset int // byte number from start of input, not start of line + unquoted string // the unquoted version of value, if it was a quoted string +} + +func (t *token) String() string { + if t.err == nil { + return fmt.Sprintf("%q (line=%d, offset=%d)", t.value, t.line, t.offset) + } + return fmt.Sprintf("parse error: %v", t.err) +} + +type textParser struct { + s string // remaining input + done bool // whether the parsing is finished (success or error) + backed bool // whether back() was called + offset, line int + cur token +} + +func newTextParser(s string) *textParser { + p := new(textParser) + p.s = s + p.line = 1 + p.cur.line = 1 + return p +} + +func (p *textParser) errorf(format string, a ...interface{}) *ParseError { + pe := &ParseError{fmt.Sprintf(format, a...), p.cur.line, p.cur.offset} + p.cur.err = pe + p.done = true + return pe +} + +// Numbers and identifiers are matched by [-+._A-Za-z0-9] +func isIdentOrNumberChar(c byte) bool { + switch { + case 'A' <= c && c <= 'Z', 'a' <= c && c <= 'z': + return true + case '0' <= c && c <= '9': + return true + } + switch c { + case '-', '+', '.', '_': + return true + } + return false +} + +func isWhitespace(c byte) bool { + switch c { + case ' ', '\t', '\n', '\r': + return true + } + return false +} + +func (p *textParser) skipWhitespace() { + i := 0 + for i < len(p.s) && (isWhitespace(p.s[i]) || p.s[i] == '#') { + if p.s[i] == '#' { + // comment; skip to end of line or input + for i < len(p.s) && p.s[i] != '\n' { + i++ + } + if i == len(p.s) { + break + } + } + if p.s[i] == '\n' { + p.line++ + } + i++ + } + p.offset += i + p.s = p.s[i:len(p.s)] + if len(p.s) == 0 { + p.done = true + } +} + +func (p *textParser) advance() { + // Skip whitespace + p.skipWhitespace() + if p.done { + return + } + + // Start of non-whitespace + p.cur.err = nil + p.cur.offset, p.cur.line = p.offset, p.line + p.cur.unquoted = "" + switch p.s[0] { + case '<', '>', '{', '}', ':', '[', ']', ';', ',': + // Single symbol + p.cur.value, p.s = p.s[0:1], p.s[1:len(p.s)] + case '"', '\'': + // Quoted string + i := 1 + for i < len(p.s) && p.s[i] != p.s[0] && p.s[i] != '\n' { + if p.s[i] == '\\' && i+1 < len(p.s) { + // skip escaped char + i++ + } + i++ + } + if i >= len(p.s) || p.s[i] != p.s[0] { + p.errorf("unmatched quote") + return + } + unq, err := unquoteC(p.s[1:i], rune(p.s[0])) + if err != nil { + p.errorf("invalid quoted string %s: %v", p.s[0:i+1], err) + return + } + p.cur.value, p.s = p.s[0:i+1], p.s[i+1:len(p.s)] + p.cur.unquoted = unq + default: + i := 0 + for i < len(p.s) && isIdentOrNumberChar(p.s[i]) { + i++ + } + if i == 0 { + p.errorf("unexpected byte %#x", p.s[0]) + return + } + p.cur.value, p.s = p.s[0:i], p.s[i:len(p.s)] + } + p.offset += len(p.cur.value) +} + +var ( + errBadUTF8 = errors.New("proto: bad UTF-8") + errBadHex = errors.New("proto: bad hexadecimal") +) + +func unquoteC(s string, quote rune) (string, error) { + // This is based on C++'s tokenizer.cc. + // Despite its name, this is *not* parsing C syntax. + // For instance, "\0" is an invalid quoted string. + + // Avoid allocation in trivial cases. + simple := true + for _, r := range s { + if r == '\\' || r == quote { + simple = false + break + } + } + if simple { + return s, nil + } + + buf := make([]byte, 0, 3*len(s)/2) + for len(s) > 0 { + r, n := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && n == 1 { + return "", errBadUTF8 + } + s = s[n:] + if r != '\\' { + if r < utf8.RuneSelf { + buf = append(buf, byte(r)) + } else { + buf = append(buf, string(r)...) + } + continue + } + + ch, tail, err := unescape(s) + if err != nil { + return "", err + } + buf = append(buf, ch...) + s = tail + } + return string(buf), nil +} + +func unescape(s string) (ch string, tail string, err error) { + r, n := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && n == 1 { + return "", "", errBadUTF8 + } + s = s[n:] + switch r { + case 'a': + return "\a", s, nil + case 'b': + return "\b", s, nil + case 'f': + return "\f", s, nil + case 'n': + return "\n", s, nil + case 'r': + return "\r", s, nil + case 't': + return "\t", s, nil + case 'v': + return "\v", s, nil + case '?': + return "?", s, nil // trigraph workaround + case '\'', '"', '\\': + return string(r), s, nil + case '0', '1', '2', '3', '4', '5', '6', '7', 'x', 'X': + if len(s) < 2 { + return "", "", fmt.Errorf(`\%c requires 2 following digits`, r) + } + base := 8 + ss := s[:2] + s = s[2:] + if r == 'x' || r == 'X' { + base = 16 + } else { + ss = string(r) + ss + } + i, err := strconv.ParseUint(ss, base, 8) + if err != nil { + return "", "", err + } + return string([]byte{byte(i)}), s, nil + case 'u', 'U': + n := 4 + if r == 'U' { + n = 8 + } + if len(s) < n { + return "", "", fmt.Errorf(`\%c requires %d digits`, r, n) + } + + bs := make([]byte, n/2) + for i := 0; i < n; i += 2 { + a, ok1 := unhex(s[i]) + b, ok2 := unhex(s[i+1]) + if !ok1 || !ok2 { + return "", "", errBadHex + } + bs[i/2] = a<<4 | b + } + s = s[n:] + return string(bs), s, nil + } + return "", "", fmt.Errorf(`unknown escape \%c`, r) +} + +// Adapted from src/pkg/strconv/quote.go. +func unhex(b byte) (v byte, ok bool) { + switch { + case '0' <= b && b <= '9': + return b - '0', true + case 'a' <= b && b <= 'f': + return b - 'a' + 10, true + case 'A' <= b && b <= 'F': + return b - 'A' + 10, true + } + return 0, false +} + +// Back off the parser by one token. Can only be done between calls to next(). +// It makes the next advance() a no-op. +func (p *textParser) back() { p.backed = true } + +// Advances the parser and returns the new current token. +func (p *textParser) next() *token { + if p.backed || p.done { + p.backed = false + return &p.cur + } + p.advance() + if p.done { + p.cur.value = "" + } else if len(p.cur.value) > 0 && p.cur.value[0] == '"' { + // Look for multiple quoted strings separated by whitespace, + // and concatenate them. + cat := p.cur + for { + p.skipWhitespace() + if p.done || p.s[0] != '"' { + break + } + p.advance() + if p.cur.err != nil { + return &p.cur + } + cat.value += " " + p.cur.value + cat.unquoted += p.cur.unquoted + } + p.done = false // parser may have seen EOF, but we want to return cat + p.cur = cat + } + return &p.cur +} + +func (p *textParser) consumeToken(s string) error { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != s { + p.back() + return p.errorf("expected %q, found %q", s, tok.value) + } + return nil +} + +// Return a RequiredNotSetError indicating which required field was not set. +func (p *textParser) missingRequiredFieldError(sv reflect.Value) *RequiredNotSetError { + st := sv.Type() + sprops := GetProperties(st) + for i := 0; i < st.NumField(); i++ { + if !isNil(sv.Field(i)) { + continue + } + + props := sprops.Prop[i] + if props.Required { + return &RequiredNotSetError{fmt.Sprintf("%v.%v", st, props.OrigName)} + } + } + return &RequiredNotSetError{fmt.Sprintf("%v.", st)} // should not happen +} + +// Returns the index in the struct for the named field, as well as the parsed tag properties. +func structFieldByName(sprops *StructProperties, name string) (int, *Properties, bool) { + i, ok := sprops.decoderOrigNames[name] + if ok { + return i, sprops.Prop[i], true + } + return -1, nil, false +} + +// Consume a ':' from the input stream (if the next token is a colon), +// returning an error if a colon is needed but not present. +func (p *textParser) checkForColon(props *Properties, typ reflect.Type) *ParseError { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != ":" { + // Colon is optional when the field is a group or message. + needColon := true + switch props.Wire { + case "group": + needColon = false + case "bytes": + // A "bytes" field is either a message, a string, or a repeated field; + // those three become *T, *string and []T respectively, so we can check for + // this field being a pointer to a non-string. + if typ.Kind() == reflect.Ptr { + // *T or *string + if typ.Elem().Kind() == reflect.String { + break + } + } else if typ.Kind() == reflect.Slice { + // []T or []*T + if typ.Elem().Kind() != reflect.Ptr { + break + } + } else if typ.Kind() == reflect.String { + // The proto3 exception is for a string field, + // which requires a colon. + break + } + needColon = false + } + if needColon { + return p.errorf("expected ':', found %q", tok.value) + } + p.back() + } + return nil +} + +func (p *textParser) readStruct(sv reflect.Value, terminator string) error { + st := sv.Type() + sprops := GetProperties(st) + reqCount := sprops.reqCount + var reqFieldErr error + fieldSet := make(map[string]bool) + // A struct is a sequence of "name: value", terminated by one of + // '>' or '}', or the end of the input. A name may also be + // "[extension]". + for { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == terminator { + break + } + if tok.value == "[" { + // Looks like an extension. + // + // TODO: Check whether we need to handle + // namespace rooted names (e.g. ".something.Foo"). + tok = p.next() + if tok.err != nil { + return tok.err + } + var desc *ExtensionDesc + // This could be faster, but it's functional. + // TODO: Do something smarter than a linear scan. + for _, d := range RegisteredExtensions(reflect.New(st).Interface().(Message)) { + if d.Name == tok.value { + desc = d + break + } + } + if desc == nil { + return p.errorf("unrecognized extension %q", tok.value) + } + // Check the extension terminator. + tok = p.next() + if tok.err != nil { + return tok.err + } + if tok.value != "]" { + return p.errorf("unrecognized extension terminator %q", tok.value) + } + + props := &Properties{} + props.Parse(desc.Tag) + + typ := reflect.TypeOf(desc.ExtensionType) + if err := p.checkForColon(props, typ); err != nil { + return err + } + + rep := desc.repeated() + + // Read the extension structure, and set it in + // the value we're constructing. + var ext reflect.Value + if !rep { + ext = reflect.New(typ).Elem() + } else { + ext = reflect.New(typ.Elem()).Elem() + } + if err := p.readAny(ext, props); err != nil { + if _, ok := err.(*RequiredNotSetError); !ok { + return err + } + reqFieldErr = err + } + ep := sv.Addr().Interface().(extendableProto) + if !rep { + SetExtension(ep, desc, ext.Interface()) + } else { + old, err := GetExtension(ep, desc) + var sl reflect.Value + if err == nil { + sl = reflect.ValueOf(old) // existing slice + } else { + sl = reflect.MakeSlice(typ, 0, 1) + } + sl = reflect.Append(sl, ext) + SetExtension(ep, desc, sl.Interface()) + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + continue + } + + // This is a normal, non-extension field. + name := tok.value + var dst reflect.Value + fi, props, ok := structFieldByName(sprops, name) + if ok { + dst = sv.Field(fi) + } else if oop, ok := sprops.OneofTypes[name]; ok { + // It is a oneof. + props = oop.Prop + nv := reflect.New(oop.Type.Elem()) + dst = nv.Elem().Field(0) + sv.Field(oop.Field).Set(nv) + } + if !dst.IsValid() { + return p.errorf("unknown field name %q in %v", name, st) + } + + if dst.Kind() == reflect.Map { + // Consume any colon. + if err := p.checkForColon(props, dst.Type()); err != nil { + return err + } + + // Construct the map if it doesn't already exist. + if dst.IsNil() { + dst.Set(reflect.MakeMap(dst.Type())) + } + key := reflect.New(dst.Type().Key()).Elem() + val := reflect.New(dst.Type().Elem()).Elem() + + // The map entry should be this sequence of tokens: + // < key : KEY value : VALUE > + // Technically the "key" and "value" could come in any order, + // but in practice they won't. + + tok := p.next() + var terminator string + switch tok.value { + case "<": + terminator = ">" + case "{": + terminator = "}" + default: + return p.errorf("expected '{' or '<', found %q", tok.value) + } + if err := p.consumeToken("key"); err != nil { + return err + } + if err := p.consumeToken(":"); err != nil { + return err + } + if err := p.readAny(key, props.mkeyprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + if err := p.consumeToken("value"); err != nil { + return err + } + if err := p.checkForColon(props.mvalprop, dst.Type().Elem()); err != nil { + return err + } + if err := p.readAny(val, props.mvalprop); err != nil { + return err + } + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + if err := p.consumeToken(terminator); err != nil { + return err + } + + dst.SetMapIndex(key, val) + continue + } + + // Check that it's not already set if it's not a repeated field. + if !props.Repeated && fieldSet[name] { + return p.errorf("non-repeated field %q was repeated", name) + } + + if err := p.checkForColon(props, dst.Type()); err != nil { + return err + } + + // Parse into the field. + fieldSet[name] = true + if err := p.readAny(dst, props); err != nil { + if _, ok := err.(*RequiredNotSetError); !ok { + return err + } + reqFieldErr = err + } else if props.Required { + reqCount-- + } + + if err := p.consumeOptionalSeparator(); err != nil { + return err + } + + } + + if reqCount > 0 { + return p.missingRequiredFieldError(sv) + } + return reqFieldErr +} + +// consumeOptionalSeparator consumes an optional semicolon or comma. +// It is used in readStruct to provide backward compatibility. +func (p *textParser) consumeOptionalSeparator() error { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value != ";" && tok.value != "," { + p.back() + } + return nil +} + +func (p *textParser) readAny(v reflect.Value, props *Properties) error { + tok := p.next() + if tok.err != nil { + return tok.err + } + if tok.value == "" { + return p.errorf("unexpected EOF") + } + + switch fv := v; fv.Kind() { + case reflect.Slice: + at := v.Type() + if at.Elem().Kind() == reflect.Uint8 { + // Special case for []byte + if tok.value[0] != '"' && tok.value[0] != '\'' { + // Deliberately written out here, as the error after + // this switch statement would write "invalid []byte: ...", + // which is not as user-friendly. + return p.errorf("invalid string: %v", tok.value) + } + bytes := []byte(tok.unquoted) + fv.Set(reflect.ValueOf(bytes)) + return nil + } + // Repeated field. May already exist. + flen := fv.Len() + if flen == fv.Cap() { + nav := reflect.MakeSlice(at, flen, 2*flen+1) + reflect.Copy(nav, fv) + fv.Set(nav) + } + fv.SetLen(flen + 1) + + // Read one. + p.back() + return p.readAny(fv.Index(flen), props) + case reflect.Bool: + // Either "true", "false", 1 or 0. + switch tok.value { + case "true", "1": + fv.SetBool(true) + return nil + case "false", "0": + fv.SetBool(false) + return nil + } + case reflect.Float32, reflect.Float64: + v := tok.value + // Ignore 'f' for compatibility with output generated by C++, but don't + // remove 'f' when the value is "-inf" or "inf". + if strings.HasSuffix(v, "f") && tok.value != "-inf" && tok.value != "inf" { + v = v[:len(v)-1] + } + if f, err := strconv.ParseFloat(v, fv.Type().Bits()); err == nil { + fv.SetFloat(f) + return nil + } + case reflect.Int32: + if x, err := strconv.ParseInt(tok.value, 0, 32); err == nil { + fv.SetInt(x) + return nil + } + + if len(props.Enum) == 0 { + break + } + m, ok := enumValueMaps[props.Enum] + if !ok { + break + } + x, ok := m[tok.value] + if !ok { + break + } + fv.SetInt(int64(x)) + return nil + case reflect.Int64: + if x, err := strconv.ParseInt(tok.value, 0, 64); err == nil { + fv.SetInt(x) + return nil + } + + case reflect.Ptr: + // A basic field (indirected through pointer), or a repeated message/group + p.back() + fv.Set(reflect.New(fv.Type().Elem())) + return p.readAny(fv.Elem(), props) + case reflect.String: + if tok.value[0] == '"' || tok.value[0] == '\'' { + fv.SetString(tok.unquoted) + return nil + } + case reflect.Struct: + var terminator string + switch tok.value { + case "{": + terminator = "}" + case "<": + terminator = ">" + default: + return p.errorf("expected '{' or '<', found %q", tok.value) + } + // TODO: Handle nested messages which implement encoding.TextUnmarshaler. + return p.readStruct(fv, terminator) + case reflect.Uint32: + if x, err := strconv.ParseUint(tok.value, 0, 32); err == nil { + fv.SetUint(uint64(x)) + return nil + } + case reflect.Uint64: + if x, err := strconv.ParseUint(tok.value, 0, 64); err == nil { + fv.SetUint(x) + return nil + } + } + return p.errorf("invalid %v: %v", v.Type(), tok.value) +} + +// UnmarshalText reads a protocol buffer in Text format. UnmarshalText resets pb +// before starting to unmarshal, so any existing data in pb is always removed. +// If a required field is not set and no other error occurs, +// UnmarshalText returns *RequiredNotSetError. +func UnmarshalText(s string, pb Message) error { + if um, ok := pb.(encoding.TextUnmarshaler); ok { + err := um.UnmarshalText([]byte(s)) + return err + } + pb.Reset() + v := reflect.ValueOf(pb) + if pe := newTextParser(s).readStruct(v.Elem(), ""); pe != nil { + return pe + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_parser_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_parser_test.go new file mode 100644 index 0000000..a2a9604 --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_parser_test.go @@ -0,0 +1,523 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "math" + "reflect" + "testing" + + . "github.com/golang/protobuf/proto" + proto3pb "github.com/golang/protobuf/proto/proto3_proto" + . "github.com/golang/protobuf/proto/testdata" +) + +type UnmarshalTextTest struct { + in string + err string // if "", no error expected + out *MyMessage +} + +func buildExtStructTest(text string) UnmarshalTextTest { + msg := &MyMessage{ + Count: Int32(42), + } + SetExtension(msg, E_Ext_More, &Ext{ + Data: String("Hello, world!"), + }) + return UnmarshalTextTest{in: text, out: msg} +} + +func buildExtDataTest(text string) UnmarshalTextTest { + msg := &MyMessage{ + Count: Int32(42), + } + SetExtension(msg, E_Ext_Text, String("Hello, world!")) + SetExtension(msg, E_Ext_Number, Int32(1729)) + return UnmarshalTextTest{in: text, out: msg} +} + +func buildExtRepStringTest(text string) UnmarshalTextTest { + msg := &MyMessage{ + Count: Int32(42), + } + if err := SetExtension(msg, E_Greeting, []string{"bula", "hola"}); err != nil { + panic(err) + } + return UnmarshalTextTest{in: text, out: msg} +} + +var unMarshalTextTests = []UnmarshalTextTest{ + // Basic + { + in: " count:42\n name:\"Dave\" ", + out: &MyMessage{ + Count: Int32(42), + Name: String("Dave"), + }, + }, + + // Empty quoted string + { + in: `count:42 name:""`, + out: &MyMessage{ + Count: Int32(42), + Name: String(""), + }, + }, + + // Quoted string concatenation + { + in: `count:42 name: "My name is "` + "\n" + `"elsewhere"`, + out: &MyMessage{ + Count: Int32(42), + Name: String("My name is elsewhere"), + }, + }, + + // Quoted string with escaped apostrophe + { + in: `count:42 name: "HOLIDAY - New Year\'s Day"`, + out: &MyMessage{ + Count: Int32(42), + Name: String("HOLIDAY - New Year's Day"), + }, + }, + + // Quoted string with single quote + { + in: `count:42 name: 'Roger "The Ramster" Ramjet'`, + out: &MyMessage{ + Count: Int32(42), + Name: String(`Roger "The Ramster" Ramjet`), + }, + }, + + // Quoted string with all the accepted special characters from the C++ test + { + in: `count:42 name: ` + "\"\\\"A string with \\' characters \\n and \\r newlines and \\t tabs and \\001 slashes \\\\ and multiple spaces\"", + out: &MyMessage{ + Count: Int32(42), + Name: String("\"A string with ' characters \n and \r newlines and \t tabs and \001 slashes \\ and multiple spaces"), + }, + }, + + // Quoted string with quoted backslash + { + in: `count:42 name: "\\'xyz"`, + out: &MyMessage{ + Count: Int32(42), + Name: String(`\'xyz`), + }, + }, + + // Quoted string with UTF-8 bytes. + { + in: "count:42 name: '\303\277\302\201\xAB'", + out: &MyMessage{ + Count: Int32(42), + Name: String("\303\277\302\201\xAB"), + }, + }, + + // Bad quoted string + { + in: `inner: < host: "\0" >` + "\n", + err: `line 1.15: invalid quoted string "\0": \0 requires 2 following digits`, + }, + + // Number too large for int64 + { + in: "count: 1 others { key: 123456789012345678901 }", + err: "line 1.23: invalid int64: 123456789012345678901", + }, + + // Number too large for int32 + { + in: "count: 1234567890123", + err: "line 1.7: invalid int32: 1234567890123", + }, + + // Number in hexadecimal + { + in: "count: 0x2beef", + out: &MyMessage{ + Count: Int32(0x2beef), + }, + }, + + // Number in octal + { + in: "count: 024601", + out: &MyMessage{ + Count: Int32(024601), + }, + }, + + // Floating point number with "f" suffix + { + in: "count: 4 others:< weight: 17.0f >", + out: &MyMessage{ + Count: Int32(4), + Others: []*OtherMessage{ + { + Weight: Float32(17), + }, + }, + }, + }, + + // Floating point positive infinity + { + in: "count: 4 bigfloat: inf", + out: &MyMessage{ + Count: Int32(4), + Bigfloat: Float64(math.Inf(1)), + }, + }, + + // Floating point negative infinity + { + in: "count: 4 bigfloat: -inf", + out: &MyMessage{ + Count: Int32(4), + Bigfloat: Float64(math.Inf(-1)), + }, + }, + + // Number too large for float32 + { + in: "others:< weight: 12345678901234567890123456789012345678901234567890 >", + err: "line 1.17: invalid float32: 12345678901234567890123456789012345678901234567890", + }, + + // Number posing as a quoted string + { + in: `inner: < host: 12 >` + "\n", + err: `line 1.15: invalid string: 12`, + }, + + // Quoted string posing as int32 + { + in: `count: "12"`, + err: `line 1.7: invalid int32: "12"`, + }, + + // Quoted string posing a float32 + { + in: `others:< weight: "17.4" >`, + err: `line 1.17: invalid float32: "17.4"`, + }, + + // Enum + { + in: `count:42 bikeshed: BLUE`, + out: &MyMessage{ + Count: Int32(42), + Bikeshed: MyMessage_BLUE.Enum(), + }, + }, + + // Repeated field + { + in: `count:42 pet: "horsey" pet:"bunny"`, + out: &MyMessage{ + Count: Int32(42), + Pet: []string{"horsey", "bunny"}, + }, + }, + + // Repeated message with/without colon and <>/{} + { + in: `count:42 others:{} others{} others:<> others:{}`, + out: &MyMessage{ + Count: Int32(42), + Others: []*OtherMessage{ + {}, + {}, + {}, + {}, + }, + }, + }, + + // Missing colon for inner message + { + in: `count:42 inner < host: "cauchy.syd" >`, + out: &MyMessage{ + Count: Int32(42), + Inner: &InnerMessage{ + Host: String("cauchy.syd"), + }, + }, + }, + + // Missing colon for string field + { + in: `name "Dave"`, + err: `line 1.5: expected ':', found "\"Dave\""`, + }, + + // Missing colon for int32 field + { + in: `count 42`, + err: `line 1.6: expected ':', found "42"`, + }, + + // Missing required field + { + in: `name: "Pawel"`, + err: `proto: required field "testdata.MyMessage.count" not set`, + out: &MyMessage{ + Name: String("Pawel"), + }, + }, + + // Repeated non-repeated field + { + in: `name: "Rob" name: "Russ"`, + err: `line 1.12: non-repeated field "name" was repeated`, + }, + + // Group + { + in: `count: 17 SomeGroup { group_field: 12 }`, + out: &MyMessage{ + Count: Int32(17), + Somegroup: &MyMessage_SomeGroup{ + GroupField: Int32(12), + }, + }, + }, + + // Semicolon between fields + { + in: `count:3;name:"Calvin"`, + out: &MyMessage{ + Count: Int32(3), + Name: String("Calvin"), + }, + }, + // Comma between fields + { + in: `count:4,name:"Ezekiel"`, + out: &MyMessage{ + Count: Int32(4), + Name: String("Ezekiel"), + }, + }, + + // Extension + buildExtStructTest(`count: 42 [testdata.Ext.more]:`), + buildExtStructTest(`count: 42 [testdata.Ext.more] {data:"Hello, world!"}`), + buildExtDataTest(`count: 42 [testdata.Ext.text]:"Hello, world!" [testdata.Ext.number]:1729`), + buildExtRepStringTest(`count: 42 [testdata.greeting]:"bula" [testdata.greeting]:"hola"`), + + // Big all-in-one + { + in: "count:42 # Meaning\n" + + `name:"Dave" ` + + `quote:"\"I didn't want to go.\"" ` + + `pet:"bunny" ` + + `pet:"kitty" ` + + `pet:"horsey" ` + + `inner:<` + + ` host:"footrest.syd" ` + + ` port:7001 ` + + ` connected:true ` + + `> ` + + `others:<` + + ` key:3735928559 ` + + ` value:"\x01A\a\f" ` + + `> ` + + `others:<` + + " weight:58.9 # Atomic weight of Co\n" + + ` inner:<` + + ` host:"lesha.mtv" ` + + ` port:8002 ` + + ` >` + + `>`, + out: &MyMessage{ + Count: Int32(42), + Name: String("Dave"), + Quote: String(`"I didn't want to go."`), + Pet: []string{"bunny", "kitty", "horsey"}, + Inner: &InnerMessage{ + Host: String("footrest.syd"), + Port: Int32(7001), + Connected: Bool(true), + }, + Others: []*OtherMessage{ + { + Key: Int64(3735928559), + Value: []byte{0x1, 'A', '\a', '\f'}, + }, + { + Weight: Float32(58.9), + Inner: &InnerMessage{ + Host: String("lesha.mtv"), + Port: Int32(8002), + }, + }, + }, + }, + }, +} + +func TestUnmarshalText(t *testing.T) { + for i, test := range unMarshalTextTests { + pb := new(MyMessage) + err := UnmarshalText(test.in, pb) + if test.err == "" { + // We don't expect failure. + if err != nil { + t.Errorf("Test %d: Unexpected error: %v", i, err) + } else if !reflect.DeepEqual(pb, test.out) { + t.Errorf("Test %d: Incorrect populated \nHave: %v\nWant: %v", + i, pb, test.out) + } + } else { + // We do expect failure. + if err == nil { + t.Errorf("Test %d: Didn't get expected error: %v", i, test.err) + } else if err.Error() != test.err { + t.Errorf("Test %d: Incorrect error.\nHave: %v\nWant: %v", + i, err.Error(), test.err) + } else if _, ok := err.(*RequiredNotSetError); ok && test.out != nil && !reflect.DeepEqual(pb, test.out) { + t.Errorf("Test %d: Incorrect populated \nHave: %v\nWant: %v", + i, pb, test.out) + } + } + } +} + +func TestUnmarshalTextCustomMessage(t *testing.T) { + msg := &textMessage{} + if err := UnmarshalText("custom", msg); err != nil { + t.Errorf("Unexpected error from custom unmarshal: %v", err) + } + if UnmarshalText("not custom", msg) == nil { + t.Errorf("Didn't get expected error from custom unmarshal") + } +} + +// Regression test; this caused a panic. +func TestRepeatedEnum(t *testing.T) { + pb := new(RepeatedEnum) + if err := UnmarshalText("color: RED", pb); err != nil { + t.Fatal(err) + } + exp := &RepeatedEnum{ + Color: []RepeatedEnum_Color{RepeatedEnum_RED}, + } + if !Equal(pb, exp) { + t.Errorf("Incorrect populated \nHave: %v\nWant: %v", pb, exp) + } +} + +func TestProto3TextParsing(t *testing.T) { + m := new(proto3pb.Message) + const in = `name: "Wallace" true_scotsman: true` + want := &proto3pb.Message{ + Name: "Wallace", + TrueScotsman: true, + } + if err := UnmarshalText(in, m); err != nil { + t.Fatal(err) + } + if !Equal(m, want) { + t.Errorf("\n got %v\nwant %v", m, want) + } +} + +func TestMapParsing(t *testing.T) { + m := new(MessageWithMap) + const in = `name_mapping: name_mapping:` + + `msg_mapping:,>` + // separating commas are okay + `msg_mapping>` + // no colon after "value" + `byte_mapping:` + want := &MessageWithMap{ + NameMapping: map[int32]string{ + 1: "Beatles", + 1234: "Feist", + }, + MsgMapping: map[int64]*FloatingPoint{ + -4: {F: Float64(2.0)}, + -2: {F: Float64(4.0)}, + }, + ByteMapping: map[bool][]byte{ + true: []byte("so be it"), + }, + } + if err := UnmarshalText(in, m); err != nil { + t.Fatal(err) + } + if !Equal(m, want) { + t.Errorf("\n got %v\nwant %v", m, want) + } +} + +func TestOneofParsing(t *testing.T) { + const in = `name:"Shrek"` + m := new(Communique) + want := &Communique{Union: &Communique_Name{"Shrek"}} + if err := UnmarshalText(in, m); err != nil { + t.Fatal(err) + } + if !Equal(m, want) { + t.Errorf("\n got %v\nwant %v", m, want) + } +} + +var benchInput string + +func init() { + benchInput = "count: 4\n" + for i := 0; i < 1000; i++ { + benchInput += "pet: \"fido\"\n" + } + + // Check it is valid input. + pb := new(MyMessage) + err := UnmarshalText(benchInput, pb) + if err != nil { + panic("Bad benchmark input: " + err.Error()) + } +} + +func BenchmarkUnmarshalText(b *testing.B) { + pb := new(MyMessage) + for i := 0; i < b.N; i++ { + UnmarshalText(benchInput, pb) + } + b.SetBytes(int64(len(benchInput))) +} diff --git a/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_test.go b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_test.go new file mode 100644 index 0000000..3eabaca --- /dev/null +++ b/Godeps/_workspace/src/github.com/golang/protobuf/proto/text_test.go @@ -0,0 +1,474 @@ +// Go support for Protocol Buffers - Google's data interchange format +// +// Copyright 2010 The Go Authors. All rights reserved. +// https://github.com/golang/protobuf +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package proto_test + +import ( + "bytes" + "errors" + "io/ioutil" + "math" + "strings" + "testing" + + "github.com/golang/protobuf/proto" + + proto3pb "github.com/golang/protobuf/proto/proto3_proto" + pb "github.com/golang/protobuf/proto/testdata" +) + +// textMessage implements the methods that allow it to marshal and unmarshal +// itself as text. +type textMessage struct { +} + +func (*textMessage) MarshalText() ([]byte, error) { + return []byte("custom"), nil +} + +func (*textMessage) UnmarshalText(bytes []byte) error { + if string(bytes) != "custom" { + return errors.New("expected 'custom'") + } + return nil +} + +func (*textMessage) Reset() {} +func (*textMessage) String() string { return "" } +func (*textMessage) ProtoMessage() {} + +func newTestMessage() *pb.MyMessage { + msg := &pb.MyMessage{ + Count: proto.Int32(42), + Name: proto.String("Dave"), + Quote: proto.String(`"I didn't want to go."`), + Pet: []string{"bunny", "kitty", "horsey"}, + Inner: &pb.InnerMessage{ + Host: proto.String("footrest.syd"), + Port: proto.Int32(7001), + Connected: proto.Bool(true), + }, + Others: []*pb.OtherMessage{ + { + Key: proto.Int64(0xdeadbeef), + Value: []byte{1, 65, 7, 12}, + }, + { + Weight: proto.Float32(6.022), + Inner: &pb.InnerMessage{ + Host: proto.String("lesha.mtv"), + Port: proto.Int32(8002), + }, + }, + }, + Bikeshed: pb.MyMessage_BLUE.Enum(), + Somegroup: &pb.MyMessage_SomeGroup{ + GroupField: proto.Int32(8), + }, + // One normally wouldn't do this. + // This is an undeclared tag 13, as a varint (wire type 0) with value 4. + XXX_unrecognized: []byte{13<<3 | 0, 4}, + } + ext := &pb.Ext{ + Data: proto.String("Big gobs for big rats"), + } + if err := proto.SetExtension(msg, pb.E_Ext_More, ext); err != nil { + panic(err) + } + greetings := []string{"adg", "easy", "cow"} + if err := proto.SetExtension(msg, pb.E_Greeting, greetings); err != nil { + panic(err) + } + + // Add an unknown extension. We marshal a pb.Ext, and fake the ID. + b, err := proto.Marshal(&pb.Ext{Data: proto.String("3G skiing")}) + if err != nil { + panic(err) + } + b = append(proto.EncodeVarint(201<<3|proto.WireBytes), b...) + proto.SetRawExtension(msg, 201, b) + + // Extensions can be plain fields, too, so let's test that. + b = append(proto.EncodeVarint(202<<3|proto.WireVarint), 19) + proto.SetRawExtension(msg, 202, b) + + return msg +} + +const text = `count: 42 +name: "Dave" +quote: "\"I didn't want to go.\"" +pet: "bunny" +pet: "kitty" +pet: "horsey" +inner: < + host: "footrest.syd" + port: 7001 + connected: true +> +others: < + key: 3735928559 + value: "\001A\007\014" +> +others: < + weight: 6.022 + inner: < + host: "lesha.mtv" + port: 8002 + > +> +bikeshed: BLUE +SomeGroup { + group_field: 8 +} +/* 2 unknown bytes */ +13: 4 +[testdata.Ext.more]: < + data: "Big gobs for big rats" +> +[testdata.greeting]: "adg" +[testdata.greeting]: "easy" +[testdata.greeting]: "cow" +/* 13 unknown bytes */ +201: "\t3G skiing" +/* 3 unknown bytes */ +202: 19 +` + +func TestMarshalText(t *testing.T) { + buf := new(bytes.Buffer) + if err := proto.MarshalText(buf, newTestMessage()); err != nil { + t.Fatalf("proto.MarshalText: %v", err) + } + s := buf.String() + if s != text { + t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, text) + } +} + +func TestMarshalTextCustomMessage(t *testing.T) { + buf := new(bytes.Buffer) + if err := proto.MarshalText(buf, &textMessage{}); err != nil { + t.Fatalf("proto.MarshalText: %v", err) + } + s := buf.String() + if s != "custom" { + t.Errorf("Got %q, expected %q", s, "custom") + } +} +func TestMarshalTextNil(t *testing.T) { + want := "" + tests := []proto.Message{nil, (*pb.MyMessage)(nil)} + for i, test := range tests { + buf := new(bytes.Buffer) + if err := proto.MarshalText(buf, test); err != nil { + t.Fatal(err) + } + if got := buf.String(); got != want { + t.Errorf("%d: got %q want %q", i, got, want) + } + } +} + +func TestMarshalTextUnknownEnum(t *testing.T) { + // The Color enum only specifies values 0-2. + m := &pb.MyMessage{Bikeshed: pb.MyMessage_Color(3).Enum()} + got := m.String() + const want = `bikeshed:3 ` + if got != want { + t.Errorf("\n got %q\nwant %q", got, want) + } +} + +func TestTextOneof(t *testing.T) { + tests := []struct { + m proto.Message + want string + }{ + // zero message + {&pb.Communique{}, ``}, + // scalar field + {&pb.Communique{Union: &pb.Communique_Number{4}}, `number:4`}, + // message field + {&pb.Communique{Union: &pb.Communique_Msg{ + &pb.Strings{StringField: proto.String("why hello!")}, + }}, `msg:`}, + // bad oneof (should not panic) + {&pb.Communique{Union: &pb.Communique_Msg{nil}}, `msg:/* nil */`}, + } + for _, test := range tests { + got := strings.TrimSpace(test.m.String()) + if got != test.want { + t.Errorf("\n got %s\nwant %s", got, test.want) + } + } +} + +func BenchmarkMarshalTextBuffered(b *testing.B) { + buf := new(bytes.Buffer) + m := newTestMessage() + for i := 0; i < b.N; i++ { + buf.Reset() + proto.MarshalText(buf, m) + } +} + +func BenchmarkMarshalTextUnbuffered(b *testing.B) { + w := ioutil.Discard + m := newTestMessage() + for i := 0; i < b.N; i++ { + proto.MarshalText(w, m) + } +} + +func compact(src string) string { + // s/[ \n]+/ /g; s/ $//; + dst := make([]byte, len(src)) + space, comment := false, false + j := 0 + for i := 0; i < len(src); i++ { + if strings.HasPrefix(src[i:], "/*") { + comment = true + i++ + continue + } + if comment && strings.HasPrefix(src[i:], "*/") { + comment = false + i++ + continue + } + if comment { + continue + } + c := src[i] + if c == ' ' || c == '\n' { + space = true + continue + } + if j > 0 && (dst[j-1] == ':' || dst[j-1] == '<' || dst[j-1] == '{') { + space = false + } + if c == '{' { + space = false + } + if space { + dst[j] = ' ' + j++ + space = false + } + dst[j] = c + j++ + } + if space { + dst[j] = ' ' + j++ + } + return string(dst[0:j]) +} + +var compactText = compact(text) + +func TestCompactText(t *testing.T) { + s := proto.CompactTextString(newTestMessage()) + if s != compactText { + t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v\n===\n", s, compactText) + } +} + +func TestStringEscaping(t *testing.T) { + testCases := []struct { + in *pb.Strings + out string + }{ + { + // Test data from C++ test (TextFormatTest.StringEscape). + // Single divergence: we don't escape apostrophes. + &pb.Strings{StringField: proto.String("\"A string with ' characters \n and \r newlines and \t tabs and \001 slashes \\ and multiple spaces")}, + "string_field: \"\\\"A string with ' characters \\n and \\r newlines and \\t tabs and \\001 slashes \\\\ and multiple spaces\"\n", + }, + { + // Test data from the same C++ test. + &pb.Strings{StringField: proto.String("\350\260\267\346\255\214")}, + "string_field: \"\\350\\260\\267\\346\\255\\214\"\n", + }, + { + // Some UTF-8. + &pb.Strings{StringField: proto.String("\x00\x01\xff\x81")}, + `string_field: "\000\001\377\201"` + "\n", + }, + } + + for i, tc := range testCases { + var buf bytes.Buffer + if err := proto.MarshalText(&buf, tc.in); err != nil { + t.Errorf("proto.MarsalText: %v", err) + continue + } + s := buf.String() + if s != tc.out { + t.Errorf("#%d: Got:\n%s\nExpected:\n%s\n", i, s, tc.out) + continue + } + + // Check round-trip. + pb := new(pb.Strings) + if err := proto.UnmarshalText(s, pb); err != nil { + t.Errorf("#%d: UnmarshalText: %v", i, err) + continue + } + if !proto.Equal(pb, tc.in) { + t.Errorf("#%d: Round-trip failed:\nstart: %v\n end: %v", i, tc.in, pb) + } + } +} + +// A limitedWriter accepts some output before it fails. +// This is a proxy for something like a nearly-full or imminently-failing disk, +// or a network connection that is about to die. +type limitedWriter struct { + b bytes.Buffer + limit int +} + +var outOfSpace = errors.New("proto: insufficient space") + +func (w *limitedWriter) Write(p []byte) (n int, err error) { + var avail = w.limit - w.b.Len() + if avail <= 0 { + return 0, outOfSpace + } + if len(p) <= avail { + return w.b.Write(p) + } + n, _ = w.b.Write(p[:avail]) + return n, outOfSpace +} + +func TestMarshalTextFailing(t *testing.T) { + // Try lots of different sizes to exercise more error code-paths. + for lim := 0; lim < len(text); lim++ { + buf := new(limitedWriter) + buf.limit = lim + err := proto.MarshalText(buf, newTestMessage()) + // We expect a certain error, but also some partial results in the buffer. + if err != outOfSpace { + t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", err, outOfSpace) + } + s := buf.b.String() + x := text[:buf.limit] + if s != x { + t.Errorf("Got:\n===\n%v===\nExpected:\n===\n%v===\n", s, x) + } + } +} + +func TestFloats(t *testing.T) { + tests := []struct { + f float64 + want string + }{ + {0, "0"}, + {4.7, "4.7"}, + {math.Inf(1), "inf"}, + {math.Inf(-1), "-inf"}, + {math.NaN(), "nan"}, + } + for _, test := range tests { + msg := &pb.FloatingPoint{F: &test.f} + got := strings.TrimSpace(msg.String()) + want := `f:` + test.want + if got != want { + t.Errorf("f=%f: got %q, want %q", test.f, got, want) + } + } +} + +func TestRepeatedNilText(t *testing.T) { + m := &pb.MessageList{ + Message: []*pb.MessageList_Message{ + nil, + &pb.MessageList_Message{ + Name: proto.String("Horse"), + }, + nil, + }, + } + want := `Message +Message { + name: "Horse" +} +Message +` + if s := proto.MarshalTextString(m); s != want { + t.Errorf(" got: %s\nwant: %s", s, want) + } +} + +func TestProto3Text(t *testing.T) { + tests := []struct { + m proto.Message + want string + }{ + // zero message + {&proto3pb.Message{}, ``}, + // zero message except for an empty byte slice + {&proto3pb.Message{Data: []byte{}}, ``}, + // trivial case + {&proto3pb.Message{Name: "Rob", HeightInCm: 175}, `name:"Rob" height_in_cm:175`}, + // empty map + {&pb.MessageWithMap{}, ``}, + // non-empty map; map format is the same as a repeated struct, + // and they are sorted by key (numerically for numeric keys). + { + &pb.MessageWithMap{NameMapping: map[int32]string{ + -1: "Negatory", + 7: "Lucky", + 1234: "Feist", + 6345789: "Otis", + }}, + `name_mapping: ` + + `name_mapping: ` + + `name_mapping: ` + + `name_mapping:`, + }, + // map with nil value; not well-defined, but we shouldn't crash + { + &pb.MessageWithMap{MsgMapping: map[int64]*pb.FloatingPoint{7: nil}}, + `msg_mapping:`, + }, + } + for _, test := range tests { + got := strings.TrimSpace(test.m.String()) + if got != test.want { + t.Errorf("\n got %s\nwant %s", got, test.want) + } + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/log/.gitignore b/Godeps/_workspace/src/github.com/lunny/log/.gitignore new file mode 100644 index 0000000..b9dcb21 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +log.db +*.log \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/lunny/log/README.md b/Godeps/_workspace/src/github.com/lunny/log/README.md new file mode 100644 index 0000000..8ed4b8a --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/README.md @@ -0,0 +1,43 @@ +## log +[![GoDoc](https://godoc.org/github.com/lunny/log?status.png)](https://godoc.org/github.com/lunny/log) + +[简体中文](https://github.com/lunny/log/blob/master/README_CN.md) + +# Installation + +``` +go get github.com/lunny/log +``` + +# Features + +* Add color support for unix console +* Implemented dbwriter to save log to database +* Implemented FileWriter to save log to file by date or time. + +# Example + +For Single File: +```Go +f, _ := os.Create("my.log") +log.Std.SetOutput(f) +``` + +For Multiple Writer: +```Go +f, _ := os.Create("my.log") +log.Std.SetOutput(io.MultiWriter(f, os.Stdout)) +``` + +For log files by date or time: +```Go +w := log.NewFileWriter(log.FileOptions{ + ByType:log.ByDay, + Dir:"./logs", +}) +log.Std.SetOutput(w) +``` + +# About + +This repo is an extension of Golang log. \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/lunny/log/README_CN.md b/Godeps/_workspace/src/github.com/lunny/log/README_CN.md new file mode 100644 index 0000000..b31e087 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/README_CN.md @@ -0,0 +1,46 @@ +## log +[![GoDoc](https://godoc.org/github.com/lunny/log?status.png)](https://godoc.org/github.com/lunny/log) + +[English](https://github.com/lunny/log/blob/master/README.md) + +# 安装 + +``` +go get github.com/lunny/log +``` + +# 特性 + +* 对unix增加控制台颜色支持 +* 实现了保存log到数据库支持 +* 实现了保存log到按日期的文件支持 + +# 例子 + +保存到单个文件: + +```Go +f, _ := os.Create("my.log") +log.Std.SetOutput(f) +``` + +保存到数据库: + +```Go +f, _ := os.Create("my.log") +log.Std.SetOutput(io.MultiWriter(f, os.Stdout)) +``` + +保存到按时间分隔的文件: + +```Go +w := log.NewFileWriter(log.FileOptions{ + ByType:log.ByDay, + Dir:"./logs", +}) +log.Std.SetOutput(w) +``` + +# 关于 + +本 Log 是在 golang 的 log 之上的扩展 \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/lunny/log/dbwriter.go b/Godeps/_workspace/src/github.com/lunny/log/dbwriter.go new file mode 100644 index 0000000..e8ff00b --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/dbwriter.go @@ -0,0 +1,36 @@ +package log + +import ( + "database/sql" + "time" +) + +type DBWriter struct { + db *sql.DB + stmt *sql.Stmt + content chan []byte +} + +func NewDBWriter(db *sql.DB) (*DBWriter, error) { + _, err := db.Exec("CREATE TABLE IF NOT EXISTS log (id int, content text, created datetime)") + if err != nil { + return nil, err + } + stmt, err := db.Prepare("INSERT INTO log (content, created) values (?, ?)") + if err != nil { + return nil, err + } + return &DBWriter{db, stmt, make(chan []byte, 1000)}, nil +} + +func (w *DBWriter) Write(p []byte) (n int, err error) { + _, err = w.stmt.Exec(string(p), time.Now()) + if err == nil { + n = len(p) + } + return +} + +func (w *DBWriter) Close() { + w.stmt.Close() +} diff --git a/Godeps/_workspace/src/github.com/lunny/log/filewriter.go b/Godeps/_workspace/src/github.com/lunny/log/filewriter.go new file mode 100644 index 0000000..37ff05c --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/filewriter.go @@ -0,0 +1,103 @@ +package log + +import ( + "io" + "os" + "path/filepath" + "sync" + "time" +) + +var _ io.Writer = &Files{} + +type ByType int + +const ( + ByDay ByType = iota + ByHour + ByMonth +) + +var ( + formats = map[ByType]string{ + ByDay: "2006-01-02", + ByHour: "2006-01-02-15", + ByMonth: "2006-01", + } +) + +func SetFileFormat(t ByType, format string) { + formats[t] = format +} + +func (b ByType) Format() string { + return formats[b] +} + +type Files struct { + FileOptions + f *os.File + lastFormat string + lock sync.Mutex +} + +type FileOptions struct { + Dir string + ByType ByType +} + +func prePareFileOption(opts []FileOptions) FileOptions { + var opt FileOptions + if len(opts) > 0 { + opt = opts[0] + } + if opt.Dir == "" { + opt.Dir = "./" + } + return opt +} + +func NewFileWriter(opts ...FileOptions) *Files { + opt := prePareFileOption(opts) + return &Files{ + FileOptions: opt, + } +} + +func (f *Files) getFile() (*os.File, error) { + if f.f == nil { + f.lastFormat = time.Now().Format(f.ByType.Format()) + var err error + f.f, err = os.OpenFile(filepath.Join(f.Dir, f.lastFormat+".log"), + os.O_CREATE | os.O_APPEND | os.O_WRONLY, os.ModePerm) + return f.f, err + } + if f.lastFormat != time.Now().Format(f.ByType.Format()) { + f.f.Close() + f.lastFormat = time.Now().Format(f.ByType.Format()) + var err error + f.f, err = os.OpenFile(filepath.Join(f.Dir, f.lastFormat+".log"), + os.O_CREATE | os.O_APPEND | os.O_WRONLY, os.ModePerm) + return f.f, err + } + return f.f, nil +} + +func (f *Files) Write(bs []byte) (int, error) { + f.lock.Lock() + defer f.lock.Unlock() + + w, err := f.getFile() + if err != nil { + return 0, err + } + return w.Write(bs) +} + +func (f *Files) Close() { + if f.f != nil { + f.f.Close() + f.f = nil + } + f.lastFormat = "" +} diff --git a/Godeps/_workspace/src/github.com/lunny/log/filewriter_test.go b/Godeps/_workspace/src/github.com/lunny/log/filewriter_test.go new file mode 100644 index 0000000..20e2682 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/filewriter_test.go @@ -0,0 +1,15 @@ +package log + +import ( + "testing" +) + +func TestFiles(t *testing.T) { + w := NewFileWriter(FileOptions{ByType:ByDay}) + Std.SetOutput(w) + Std.SetFlags(Std.Flags() | ^Llongcolor | ^Lshortcolor) + Info("test") + Debug("ssss") + Warn("ssss") + w.Close() +} diff --git a/Godeps/_workspace/src/github.com/lunny/log/logext.go b/Godeps/_workspace/src/github.com/lunny/log/logext.go new file mode 100644 index 0000000..dccd55a --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/logext.go @@ -0,0 +1,567 @@ +package log + +import ( + "bytes" + "fmt" + "io" + "os" + "runtime" + "strings" + "sync" + "time" +) + +// These flags define which text to prefix to each log entry generated by the Logger. +const ( + // Bits or'ed together to control what's printed. There is no control over the + // order they appear (the order listed here) or the format they present (as + // described in the comments). A colon appears after these items: + // 2009/0123 01:23:23.123123 /a/b/c/d.go:23: message + Ldate = 1 << iota // the date: 2009/0123 + Ltime // the time: 01:23:23 + Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime. + Llongfile // full file name and line number: /a/b/c/d.go:23 + Lshortfile // final file name element and line number: d.go:23. overrides Llongfile + Lmodule // module name + Llevel // level: 0(Debug), 1(Info), 2(Warn), 3(Error), 4(Panic), 5(Fatal) + Llongcolor // color will start [info] end of line + Lshortcolor // color only include [info] + LstdFlags = Ldate | Ltime // initial values for the standard logger + //Ldefault = Llevel | LstdFlags | Lshortfile | Llongcolor +) // [prefix][time][level][module][shortfile|longfile] + +func Ldefault() int { + if runtime.GOOS == "windows" { + return Llevel | LstdFlags | Lshortfile + } + return Llevel | LstdFlags | Lshortfile | Llongcolor +} + +const ( + Lall = iota +) +const ( + Ldebug = iota + Linfo + Lwarn + Lerror + Lpanic + Lfatal + Lnone +) + +const ( + ForeBlack = iota + 30 //30 + ForeRed //31 + ForeGreen //32 + ForeYellow //33 + ForeBlue //34 + ForePurple //35 + ForeCyan //36 + ForeWhite //37 +) + +const ( + BackBlack = iota + 40 //40 + BackRed //41 + BackGreen //42 + BackYellow //43 + BackBlue //44 + BackPurple //45 + BackCyan //46 + BackWhite //47 +) + +var levels = []string{ + "[Debug]", + "[Info]", + "[Warn]", + "[Error]", + "[Panic]", + "[Fatal]", +} + +// MUST called before all logs +func SetLevels(lvs []string) { + levels = lvs +} + +var colors = []int{ + ForeCyan, + ForeGreen, + ForeYellow, + ForeRed, + ForePurple, + ForeBlue, +} + +// MUST called before all logs +func SetColors(cls []int) { + colors = cls +} + +// A Logger represents an active logging object that generates lines of +// output to an io.Writer. Each logging operation makes a single call to +// the Writer's Write method. A Logger can be used simultaneously from +// multiple goroutines; it guarantees to serialize access to the Writer. +type Logger struct { + mu sync.Mutex // ensures atomic writes; protects the following fields + prefix string // prefix to write at beginning of each line + flag int // properties + Level int + out io.Writer // destination for output + buf bytes.Buffer // for accumulating text to write + levelStats [6]int64 +} + +// New creates a new Logger. The out variable sets the +// destination to which log data will be written. +// The prefix appears at the beginning of each generated log line. +// The flag argument defines the logging properties. +func New(out io.Writer, prefix string, flag int) *Logger { + return &Logger{out: out, prefix: prefix, Level: 1, flag: flag} +} + +var Std = New(os.Stderr, "", Ldefault()) + +// Cheap integer to fixed-width decimal ASCII. Give a negative width to avoid zero-padding. +// Knows the buffer has capacity. +func itoa(buf *bytes.Buffer, i int, wid int) { + var u uint = uint(i) + if u == 0 && wid <= 1 { + buf.WriteByte('0') + return + } + + // Assemble decimal in reverse order. + var b [32]byte + bp := len(b) + for ; u > 0 || wid > 0; u /= 10 { + bp-- + wid-- + b[bp] = byte(u%10) + '0' + } + + // avoid slicing b to avoid an allocation. + for bp < len(b) { + buf.WriteByte(b[bp]) + bp++ + } +} + +func moduleOf(file string) string { + pos := strings.LastIndex(file, "/") + if pos != -1 { + pos1 := strings.LastIndex(file[:pos], "/src/") + if pos1 != -1 { + return file[pos1+5 : pos] + } + } + return "UNKNOWN" +} + +func (l *Logger) formatHeader(buf *bytes.Buffer, t time.Time, + file string, line int, lvl int, reqId string) { + if l.prefix != "" { + buf.WriteString(l.prefix) + } + if l.flag&(Ldate|Ltime|Lmicroseconds) != 0 { + if l.flag&Ldate != 0 { + year, month, day := t.Date() + itoa(buf, year, 4) + buf.WriteByte('/') + itoa(buf, int(month), 2) + buf.WriteByte('/') + itoa(buf, day, 2) + buf.WriteByte(' ') + } + if l.flag&(Ltime|Lmicroseconds) != 0 { + hour, min, sec := t.Clock() + itoa(buf, hour, 2) + buf.WriteByte(':') + itoa(buf, min, 2) + buf.WriteByte(':') + itoa(buf, sec, 2) + if l.flag&Lmicroseconds != 0 { + buf.WriteByte('.') + itoa(buf, t.Nanosecond()/1e3, 6) + } + buf.WriteByte(' ') + } + } + if reqId != "" { + buf.WriteByte('[') + buf.WriteString(reqId) + buf.WriteByte(']') + buf.WriteByte(' ') + } + + if l.flag&(Lshortcolor|Llongcolor) != 0 { + buf.WriteString(fmt.Sprintf("\033[1;%dm", colors[lvl])) + } + if l.flag&Llevel != 0 { + buf.WriteString(levels[lvl]) + buf.WriteByte(' ') + } + if l.flag&Lshortcolor != 0 { + buf.WriteString("\033[0m") + } + + if l.flag&Lmodule != 0 { + buf.WriteByte('[') + buf.WriteString(moduleOf(file)) + buf.WriteByte(']') + buf.WriteByte(' ') + } + if l.flag&(Lshortfile|Llongfile) != 0 { + if l.flag&Lshortfile != 0 { + short := file + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + file = short + } + buf.WriteString(file) + buf.WriteByte(':') + itoa(buf, line, -1) + buf.WriteByte(' ') + } +} + +// Output writes the output for a logging event. The string s contains +// the text to print after the prefix specified by the flags of the +// Logger. A newline is appended if the last character of s is not +// already a newline. Calldepth is used to recover the PC and is +// provided for generality, although at the moment on all pre-defined +// paths it will be 2. +func (l *Logger) Output(reqId string, lvl int, calldepth int, s string) error { + if lvl < l.Level { + return nil + } + now := time.Now() // get this early. + var file string + var line int + l.mu.Lock() + defer l.mu.Unlock() + if l.flag&(Lshortfile|Llongfile|Lmodule) != 0 { + // release lock while getting caller info - it's expensive. + l.mu.Unlock() + var ok bool + _, file, line, ok = runtime.Caller(calldepth) + if !ok { + file = "???" + line = 0 + } + l.mu.Lock() + } + l.levelStats[lvl]++ + l.buf.Reset() + l.formatHeader(&l.buf, now, file, line, lvl, reqId) + l.buf.WriteString(s) + if l.flag&Llongcolor != 0 { + l.buf.WriteString("\033[0m") + } + if len(s) > 0 && s[len(s)-1] != '\n' { + l.buf.WriteByte('\n') + } + _, err := l.out.Write(l.buf.Bytes()) + return err +} + +// ----------------------------------------- + +// Printf calls l.Output to print to the logger. +// Arguments are handled in the manner of fmt.Printf. +func (l *Logger) Printf(format string, v ...interface{}) { + l.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +// Print calls l.Output to print to the logger. +// Arguments are handled in the manner of fmt.Print. +func (l *Logger) Print(v ...interface{}) { + l.Output("", Linfo, 2, fmt.Sprint(v...)) +} + +// Println calls l.Output to print to the logger. +// Arguments are handled in the manner of fmt.Println. +func (l *Logger) Println(v ...interface{}) { + l.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func (l *Logger) Debugf(format string, v ...interface{}) { + l.Output("", Ldebug, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Debug(v ...interface{}) { + l.Output("", Ldebug, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- +func (l *Logger) Infof(format string, v ...interface{}) { + l.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Info(v ...interface{}) { + l.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- +func (l *Logger) Warnf(format string, v ...interface{}) { + l.Output("", Lwarn, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Warn(v ...interface{}) { + l.Output("", Lwarn, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func (l *Logger) Errorf(format string, v ...interface{}) { + l.Output("", Lerror, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Error(v ...interface{}) { + l.Output("", Lerror, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func (l *Logger) Fatal(v ...interface{}) { + l.Output("", Lfatal, 2, fmt.Sprintln(v...)) + os.Exit(1) +} + +// Fatalf is equivalent to l.Printf() followed by a call to os.Exit(1). +func (l *Logger) Fatalf(format string, v ...interface{}) { + l.Output("", Lfatal, 2, fmt.Sprintf(format, v...)) + os.Exit(1) +} + +// ----------------------------------------- +// Panic is equivalent to l.Print() followed by a call to panic(). +func (l *Logger) Panic(v ...interface{}) { + s := fmt.Sprintln(v...) + l.Output("", Lpanic, 2, s) + panic(s) +} + +// Panicf is equivalent to l.Printf() followed by a call to panic(). +func (l *Logger) Panicf(format string, v ...interface{}) { + s := fmt.Sprintf(format, v...) + l.Output("", Lpanic, 2, s) + panic(s) +} + +// ----------------------------------------- +func (l *Logger) Stack(v ...interface{}) { + s := fmt.Sprint(v...) + s += "\n" + buf := make([]byte, 1024*1024) + n := runtime.Stack(buf, true) + s += string(buf[:n]) + s += "\n" + l.Output("", Lerror, 2, s) +} + +// ----------------------------------------- +func (l *Logger) Stat() (stats []int64) { + l.mu.Lock() + v := l.levelStats + l.mu.Unlock() + return v[:] +} + +// Flags returns the output flags for the logger. +func (l *Logger) Flags() int { + l.mu.Lock() + defer l.mu.Unlock() + return l.flag +} + +func RmColorFlags(flag int) int { + // for un std out, it should not show color since almost them don't support + if flag&Llongcolor != 0 { + flag = flag ^ Llongcolor + } + if flag&Lshortcolor != 0 { + flag = flag ^ Lshortcolor + } + return flag +} + +// SetFlags sets the output flags for the logger. +func (l *Logger) SetFlags(flag int) { + l.mu.Lock() + defer l.mu.Unlock() + l.flag = flag +} + +// Prefix returns the output prefix for the logger. +func (l *Logger) Prefix() string { + l.mu.Lock() + defer l.mu.Unlock() + return l.prefix +} + +// SetPrefix sets the output prefix for the logger. +func (l *Logger) SetPrefix(prefix string) { + l.mu.Lock() + defer l.mu.Unlock() + l.prefix = prefix +} + +// SetOutputLevel sets the output level for the logger. +func (l *Logger) SetOutputLevel(lvl int) { + l.mu.Lock() + defer l.mu.Unlock() + l.Level = lvl +} + +func (l *Logger) OutputLevel() int { + return l.Level +} + +func (l *Logger) SetOutput(w io.Writer) { + l.mu.Lock() + defer l.mu.Unlock() + l.out = w + if w != os.Stdout { + l.flag = RmColorFlags(l.flag) + } +} + +// SetOutput sets the output destination for the standard logger. +func SetOutput(w io.Writer) { + Std.SetOutput(w) +} + +// Flags returns the output flags for the standard logger. +func Flags() int { + return Std.Flags() +} + +// SetFlags sets the output flags for the standard logger. +func SetFlags(flag int) { + Std.SetFlags(flag) +} + +// Prefix returns the output prefix for the standard logger. +func Prefix() string { + return Std.Prefix() +} + +// SetPrefix sets the output prefix for the standard logger. +func SetPrefix(prefix string) { + Std.SetPrefix(prefix) +} + +func SetOutputLevel(lvl int) { + Std.SetOutputLevel(lvl) +} + +func OutputLevel() int { + return Std.OutputLevel() +} + +// ----------------------------------------- + +// Print calls Output to print to the standard logger. +// Arguments are handled in the manner of fmt.Print. +func Print(v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// Printf calls Output to print to the standard logger. +// Arguments are handled in the manner of fmt.Printf. +func Printf(format string, v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +// Println calls Output to print to the standard logger. +// Arguments are handled in the manner of fmt.Println. +func Println(v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func Debugf(format string, v ...interface{}) { + Std.Output("", Ldebug, 2, fmt.Sprintf(format, v...)) +} + +func Debug(v ...interface{}) { + Std.Output("", Ldebug, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func Infof(format string, v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +func Info(v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func Warnf(format string, v ...interface{}) { + Std.Output("", Lwarn, 2, fmt.Sprintf(format, v...)) +} + +func Warn(v ...interface{}) { + Std.Output("", Lwarn, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func Errorf(format string, v ...interface{}) { + Std.Output("", Lerror, 2, fmt.Sprintf(format, v...)) +} + +func Error(v ...interface{}) { + Std.Output("", Lerror, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +// Fatal is equivalent to Print() followed by a call to os.Exit(1). +func Fatal(v ...interface{}) { + Std.Output("", Lfatal, 2, fmt.Sprintln(v...)) +} + +// Fatalf is equivalent to Printf() followed by a call to os.Exit(1). +func Fatalf(format string, v ...interface{}) { + Std.Output("", Lfatal, 2, fmt.Sprintf(format, v...)) +} + +// ----------------------------------------- + +// Panic is equivalent to Print() followed by a call to panic(). +func Panic(v ...interface{}) { + Std.Output("", Lpanic, 2, fmt.Sprintln(v...)) +} + +// Panicf is equivalent to Printf() followed by a call to panic(). +func Panicf(format string, v ...interface{}) { + Std.Output("", Lpanic, 2, fmt.Sprintf(format, v...)) +} + +// ----------------------------------------- + +func Stack(v ...interface{}) { + s := fmt.Sprint(v...) + s += "\n" + buf := make([]byte, 1024*1024) + n := runtime.Stack(buf, true) + s += string(buf[:n]) + s += "\n" + Std.Output("", Lerror, 2, s) +} + +// ----------------------------------------- diff --git a/Godeps/_workspace/src/github.com/lunny/log/logext_test.go b/Godeps/_workspace/src/github.com/lunny/log/logext_test.go new file mode 100644 index 0000000..c6a5bb0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/log/logext_test.go @@ -0,0 +1,75 @@ +package log + +import ( + "database/sql" + "fmt" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestLog(t *testing.T) { + + SetOutputLevel(Ldebug) + + Debugf("Debug: foo\n") + Debug("Debug: foo") + + Infof("Info: foo\n") + Info("Info: foo") + + Warnf("Warn: foo\n") + Warn("Warn: foo") + + Errorf("Error: foo\n") + Error("Error: foo") + + SetOutputLevel(Linfo) + + Debugf("Debug: foo\n") + Debug("Debug: foo") + + Infof("Info: foo\n") + Info("Info: foo") + + Warnf("Warn: foo\n") + Warn("Warn: foo") + + Errorf("Error: foo\n") + Error("Error: foo") + + db, err := sql.Open("sqlite3", "./log.db") + if err != nil { + t.Fatal(err) + } + + w, err := NewDBWriter(db) + if err != nil { + fmt.Println(err) + t.Fatal(err) + } + SetOutput(w) + SetFlags(RmColorFlags(Flags())) + // output all + SetOutputLevel(Lall) + Debugf("Debug: foo\n") + Debug("Debug: foo") + + Infof("Info: foo\n") + Info("Info: foo") + + Warnf("Warn: foo\n") + Warn("Warn: foo") + + Errorf("Error: foo\n") + Error("Error: foo") + + // output none + SetOutputLevel(Lnone) + Warnf("Warn: foo\n") + Warn("Warn: foo") + + Errorf("Error: foo\n") + Error("Error: foo") + +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/LICENSE b/Godeps/_workspace/src/github.com/lunny/tango/LICENSE new file mode 100644 index 0000000..b75ef9a --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2014 - 2015 The Tango Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/Godeps/_workspace/src/github.com/lunny/tango/README.md b/Godeps/_workspace/src/github.com/lunny/tango/README.md new file mode 100644 index 0000000..0b8712c --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/README.md @@ -0,0 +1,108 @@ +Tango [![Build Status](https://drone.io/github.com/lunny/tango/status.png)](https://drone.io/github.com/lunny/tango/latest) [![](http://gocover.io/_badge/github.com/lunny/tango)](http://gocover.io/github.com/lunny/tango) [简体中文](README_CN.md) +======================= + +[![Join the chat at https://gitter.im/lunny/tango](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/lunny/tango?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) + +![Tango Logo](logo.png) + +Package tango is a micro & pluggable web framework for Go. + +##### Current version: v0.4.6 [Version History](https://github.com/lunny/tango/releases) + +## Getting Started + +To install Tango: + + go get github.com/lunny/tango + +A classic usage of Tango below: + +```go +package main + +import ( + "errors" + "github.com/lunny/tango" +) + +type Action struct { + tango.Json +} + +func (Action) Get() interface{} { + if true { + return map[string]string{ + "say": "Hello tango!", + } + } + return errors.New("something error") +} + +func main() { + t := tango.Classic() + t.Get("/", new(Action)) + t.Run() +} +``` + +Then visit `http://localhost:8000` on your browser. You will get +``` +{"say":"Hello tango!"} +``` + +If you change `true` after `if` to `false`, then you will get +``` +{"err":"something error"} +``` + +This code will automatically convert returned map or error to a json because we has an embedded struct `tango.Json`. + +## Features + +- Powerful routing & Flexible routes combinations. +- Directly integrate with existing services. +- Easy to plugin features with modular design. +- High performance dependency injection embedded. + +## Middlewares + +Middlewares allow you easily plugin features for your Tango applications. + +There are already many [middlewares](https://github.com/tango-contrib) to simplify your work: + +- [recovery](https://github.com/lunny/tango/wiki/Recovery) - recover after panic +- [compress](https://github.com/lunny/tango/wiki/Compress) - Gzip & Deflate compression +- [static](https://github.com/lunny/tango/wiki/Static) - Serves static files +- [logger](https://github.com/lunny/tango/wiki/Logger) - Log the request & inject Logger to action struct +- [param](https://github.com/lunny/tango/wiki/Params) - get the router parameters +- [return](https://github.com/lunny/tango/wiki/Return) - Handle the returned value smartlly +- [context](https://github.com/lunny/tango/wiki/Context) - Inject context to action struct +- [session](https://github.com/tango-contrib/session) - [![Build Status](https://drone.io/github.com/tango-contrib/session/status.png)](https://drone.io/github.com/tango-contrib/session/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/session)](http://gocover.io/github.com/tango-contrib/session) Session manager, [session-redis](http://github.com/tango-contrib/session-redis), [session-nodb](http://github.com/tango-contrib/session-nodb), [session-ledis](http://github.com/tango-contrib/session-ledis) +- [xsrf](https://github.com/tango-contrib/xsrf) - [![Build Status](https://drone.io/github.com/tango-contrib/xsrf/status.png)](https://drone.io/github.com/tango-contrib/xsrf/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/xsrf)](http://gocover.io/github.com/tango-contrib/xsrf) Generates and validates csrf tokens +- [binding](https://github.com/tango-contrib/binding) - [![Build Status](https://drone.io/github.com/tango-contrib/binding/status.png)](https://drone.io/github.com/tango-contrib/binding/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/binding)](http://gocover.io/github.com/tango-contrib/binding) Bind and validates forms +- [renders](https://github.com/tango-contrib/renders) - [![Build Status](https://drone.io/github.com/tango-contrib/renders/status.png)](https://drone.io/github.com/tango-contrib/renders/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/renders)](http://gocover.io/github.com/tango-contrib/renders) Go template engine +- [dispatch](https://github.com/tango-contrib/dispatch) - [![Build Status](https://drone.io/github.com/tango-contrib/dispatch/status.png)](https://drone.io/github.com/tango-contrib/dispatch/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/dispatch)](http://gocover.io/github.com/tango-contrib/dispatch) Multiple Application support on one server +- [tpongo2](https://github.com/tango-contrib/tpongo2) - [![Build Status](https://drone.io/github.com/tango-contrib/tpongo2/status.png)](https://drone.io/github.com/tango-contrib/tpongo2/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/tpongo2)](http://gocover.io/github.com/tango-contrib/tpongo2) [Pongo2](https://github.com/flosch/pongo2) teamplte engine support +- [captcha](https://github.com/tango-contrib/captcha) - [![Build Status](https://drone.io/github.com/tango-contrib/captcha/status.png)](https://drone.io/github.com/tango-contrib/captcha/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/captcha)](http://gocover.io/github.com/tango-contrib/captcha) Captcha +- [events](https://github.com/tango-contrib/events) - [![Build Status](https://drone.io/github.com/tango-contrib/events/status.png)](https://drone.io/github.com/tango-contrib/events/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/events)](http://gocover.io/github.com/tango-contrib/events) Before and After +- [flash](https://github.com/tango-contrib/flash) - [![Build Status](https://drone.io/github.com/tango-contrib/flash/status.png)](https://drone.io/github.com/tango-contrib/flash/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/flash)](http://gocover.io/github.com/tango-contrib/flash) Share data between requests +- [debug](https://github.com/tango-contrib/debug) - [![Build Status](https://drone.io/github.com/tango-contrib/debug/status.png)](https://drone.io/github.com/tango-contrib/debug/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/debug)](http://gocover.io/github.com/tango-contrib/debug) show detail debug infomaton on log +- [basicauth](https://github.com/tango-contrib/basicauth) - [![Build Status](https://drone.io/github.com/tango-contrib/basicauth/status.png)](https://drone.io/github.com/tango-contrib/basicauth/latest) [![](http://gocover.io/_badge/github.com/tango-contrib/basicauth)](http://gocover.io/github.com/tango-contrib/basicauth) basicauth middleware + +## Getting Help + +- [Wiki](https://github.com/lunny/tango/wiki) +- [API Reference](https://gowalker.org/github.com/lunny/tango) +- [Discuss](https://groups.google.com/forum/#!forum/go-tango) + +## Cases + +- [Wego](https://github.com/go-tango/wego) +- [dbweb](https://github.com/go-xorm/dbweb) +- [Godaily](http://godaily.org) - [github](https://github.com/godaily/news) +- [FXH Blog](https://github.com/gofxh/blog) +- [Gos](https://github.com/go-tango/gos) + +## License + +This project is under BSD License. See the [LICENSE](LICENSE) file for the full license text. diff --git a/Godeps/_workspace/src/github.com/lunny/tango/README_CN.md b/Godeps/_workspace/src/github.com/lunny/tango/README_CN.md new file mode 100644 index 0000000..95b0c32 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/README_CN.md @@ -0,0 +1,37 @@ +Tango [![Build Status](https://drone.io/github.com/lunny/tango/status.png)](https://drone.io/github.com/lunny/tango/latest) [![](http://gocover.io/_badge/github.com/lunny/tango)](http://gocover.io/github.com/lunny/tango) [English](README.md) +======================= + +![Tango Logo](logo.png) + +Tango 是一个微内核的Go语言Web框架,采用模块化和注入式的设计理念。开发者可根据自身业务逻辑来选择性的装卸框架的功能,甚至利用丰富的中间件来搭建一个全栈式Web开发框架。 + +##### 当前版本: v0.4.6 [版本更新记录](https://github.com/lunny/tango/releases) + +## 特性 +- 强大而灵活的路由设计 +- 兼容已有的`http.Handler` +- 基于中间件的模块化设计,灵活定制框架功能 +- 高性能的依赖注入方式 + +## 安装Tango: + go get github.com/lunny/tango + +## 帮助文档 +- [快速入门](https://github.com/lunny/tango/wiki/QuickStart) +- [开发文档](https://github.com/lunny/tango/wiki/ZH_Home) +- [API文档](https://gowalker.org/github.com/lunny/tango) + +## 使用案例 +- [Wego](https://github.com/go-tango/wego) tango结合[xorm](http://www.xorm.io/)开发的论坛 +- [FXH Blog](https://github.com/gofxh/blog) 博客 +- [DBWeb](https://github.com/go-xorm/dbweb) 基于Web的数据库管理工具 +- [Godaily](http://godaily.org) - [github](https://github.com/godaily/news) RSS聚合工具 +- [Gos](https://github.com/go-tango/gos) 简易的Web静态文件服务端 + +## 交流讨论 +- QQ群:369240307 +- [中文论坛](https://groups.google.com/forum/#!forum/go-tango) +- [英文论坛](https://groups.google.com/forum/#!forum/go-tango) + +## License +This project is under BSD License. See the [LICENSE](LICENSE) file for the full license text. diff --git a/Godeps/_workspace/src/github.com/lunny/tango/RELEASE.md b/Godeps/_workspace/src/github.com/lunny/tango/RELEASE.md new file mode 100644 index 0000000..d59adfe --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/RELEASE.md @@ -0,0 +1,28 @@ +# Release Log + +## v0.2.9 + +* Add `Prefix` handler to wrap a handler filtered by request url prefix +* Add `Use` method for `Group` for supporting apply middleware with group and resolved #2 + +## v0.2.8 + +New Features: +* cookie support +* custom error handler +* custom routes' method + +Improvements on internal middlewares: +* return handler +* static handler added more options + +Added external middlewares +* events +* debug +* renders + +Added a new example [dbweb](github.com/go-xorm/dbweb) + +## v0.1 + +First version \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/lunny/tango/compress.go b/Godeps/_workspace/src/github.com/lunny/tango/compress.go new file mode 100644 index 0000000..1eead7e --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/compress.go @@ -0,0 +1,147 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bufio" + "compress/flate" + "compress/gzip" + "fmt" + "io" + "net" + "net/http" + "path" + "strings" +) + +const ( + HeaderAcceptEncoding = "Accept-Encoding" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLength = "Content-Length" + HeaderContentType = "Content-Type" + HeaderVary = "Vary" +) + +type Compresser interface { + CompressType() string +} + +type GZip struct{} + +func (GZip) CompressType() string { + return "gzip" +} + +type Deflate struct{} + +func (Deflate) CompressType() string { + return "deflate" +} + +type Compress struct{} + +func (Compress) CompressType() string { + return "auto" +} + +func Compresses(exts []string) HandlerFunc { + extsmap := make(map[string]bool) + for _, ext := range exts { + extsmap[strings.ToLower(ext)] = true + } + + return func(ctx *Context) { + ae := ctx.Req().Header.Get("Accept-Encoding") + if ae == "" { + ctx.Next() + return + } + + if len(extsmap) > 0 { + ext := strings.ToLower(path.Ext(ctx.Req().URL.Path)) + if _, ok := extsmap[ext]; ok { + compress(ctx, "auto") + return + } + } + + if action := ctx.Action(); action != nil { + if c, ok := action.(Compresser); ok { + compress(ctx, c.CompressType()) + return + } + } + + // if blank, then no compress + ctx.Next() + } +} + +func compress(ctx *Context, compressType string) { + ae := ctx.Req().Header.Get("Accept-Encoding") + acceptCompress := strings.SplitN(ae, ",", -1) + var writer io.Writer + var val string + + for _, val = range acceptCompress { + val = strings.TrimSpace(val) + if compressType == "auto" || val == compressType { + if val == "gzip" { + ctx.Header().Set("Content-Encoding", "gzip") + writer = gzip.NewWriter(ctx.ResponseWriter) + break + } else if val == "deflate" { + ctx.Header().Set("Content-Encoding", "deflate") + writer, _ = flate.NewWriter(ctx.ResponseWriter, flate.BestSpeed) + break + } + } + } + + // not supported compress method, then ignore + if writer == nil { + ctx.Next() + return + } + + // for cache server + ctx.Header().Add(HeaderVary, "Accept-Encoding") + + gzw := &compressWriter{writer, ctx.ResponseWriter} + ctx.ResponseWriter = gzw + + ctx.Next() + + // delete content length after we know we have been written to + gzw.Header().Del(HeaderContentLength) + ctx.ResponseWriter = gzw.ResponseWriter + + switch writer.(type) { + case *gzip.Writer: + writer.(*gzip.Writer).Close() + case *flate.Writer: + writer.(*flate.Writer).Close() + } +} + +type compressWriter struct { + w io.Writer + ResponseWriter +} + +func (grw *compressWriter) Write(p []byte) (int, error) { + if len(grw.Header().Get(HeaderContentType)) == 0 { + grw.Header().Set(HeaderContentType, http.DetectContentType(p)) + } + return grw.w.Write(p) +} + +func (grw *compressWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := grw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface") + } + return hijacker.Hijack() +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/compress_test.go b/Godeps/_workspace/src/github.com/lunny/tango/compress_test.go new file mode 100644 index 0000000..3dad5bf --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/compress_test.go @@ -0,0 +1,126 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +type CompressExample struct { + Compress // add this for ask compress according request accept-encoding +} + +func (CompressExample) Get() string { + return fmt.Sprintf("This is a auto compress text") +} + +type GZipExample struct { + GZip // add this for ask compress to GZip +} + +func (GZipExample) Get() string { + return fmt.Sprintf("This is a gzip compress text") +} + +type DeflateExample struct { + Deflate // add this for ask compress to Deflate, if not support then not compress +} + +func (DeflateExample) Get() string { + return fmt.Sprintf("This is a deflate compress text") +} + +type NoCompress struct { +} + +func (NoCompress) Get() string { + return fmt.Sprintf("This is a non-compress text") +} + +func TestCompressAuto(t *testing.T) { + o := Classic() + o.Get("/", new(CompressExample)) + testCompress(t, o, "http://localhost:8000/", + "This is a auto compress text", "gzip") +} + +func TestCompressGzip(t *testing.T) { + o := Classic() + o.Get("/", new(GZipExample)) + testCompress(t, o, "http://localhost:8000/", + "This is a gzip compress text", "gzip") +} + +func TestCompressDeflate(t *testing.T) { + o := Classic() + o.Get("/", new(DeflateExample)) + testCompress(t, o, "http://localhost:8000/", + "This is a deflate compress text", "deflate") +} + +func TestCompressNon(t *testing.T) { + o := Classic() + o.Get("/", new(NoCompress)) + testCompress(t, o, "http://localhost:8000/", + "This is a non-compress text", "") +} + +func TestCompressStatic(t *testing.T) { + o := New() + o.Use(Compresses([]string{".html"})) + o.Use(ClassicHandlers...) + testCompress(t, o, "http://localhost:8000/public/test.html", + "hello tango", "gzip") +} + +func testCompress(t *testing.T, o *Tango, url, content, enc string) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Error(err) + } + req.Header.Add("Accept-Encoding", "gzip, deflate") + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + + ce := recorder.Header().Get("Content-Encoding") + if ce == "gzip" { + r, err := gzip.NewReader(buff) + if err != nil { + t.Error(err) + } + defer r.Close() + + bs, err := ioutil.ReadAll(r) + if err != nil { + t.Error(err) + } + expect(t, string(bs), content) + } else if ce == "deflate" { + r := flate.NewReader(buff) + defer r.Close() + + bs, err := ioutil.ReadAll(r) + if err != nil { + t.Error(err) + } + expect(t, string(bs), content) + } else { + expect(t, buff.String(), content) + } + expect(t, enc, ce) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/context.go b/Godeps/_workspace/src/github.com/lunny/tango/context.go new file mode 100644 index 0000000..525ffae --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/context.go @@ -0,0 +1,275 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "reflect" +) + +type Handler interface { + Handle(*Context) +} + +type Context struct { + tan *Tango + Logger + + idx int + req *http.Request + ResponseWriter + route *Route + params Params + callArgs []reflect.Value + matched bool + + action interface{} + Result interface{} +} + +func (ctx *Context) reset(req *http.Request, resp ResponseWriter) { + ctx.req = req + ctx.ResponseWriter = resp + ctx.idx = 0 + ctx.route = nil + ctx.params = nil + ctx.callArgs = nil + ctx.matched = false + ctx.action = nil + ctx.Result = nil +} + +func (ctx *Context) HandleError() { + ctx.tan.ErrHandler.Handle(ctx) +} + +func (ctx *Context) Req() *http.Request { + return ctx.req +} + +func (ctx *Context) SecureCookies(secret string) Cookies { + return &secureCookies{ + (*cookies)(ctx), + secret, + } +} + +func (ctx *Context) Cookies() Cookies { + return (*cookies)(ctx) +} + +func (ctx *Context) Forms() *Forms { + return (*Forms)(ctx.req) +} + +func (ctx *Context) Route() *Route { + ctx.newAction() + return ctx.route +} + +func (ctx *Context) Params() *Params { + ctx.newAction() + return &ctx.params +} + +func (ctx *Context) Action() interface{} { + ctx.newAction() + return ctx.action +} + +func (ctx *Context) newAction() { + if !ctx.matched { + reqPath := removeStick(ctx.Req().URL.Path) + ctx.route, ctx.params = ctx.tan.Match(reqPath, ctx.Req().Method) + if ctx.route != nil { + vc := ctx.route.newAction() + ctx.action = vc.Interface() + switch ctx.route.routeType { + case StructPtrRoute: + ctx.callArgs = []reflect.Value{vc.Elem()} + case StructRoute: + ctx.callArgs = []reflect.Value{vc} + case FuncRoute: + ctx.callArgs = []reflect.Value{} + case FuncHttpRoute: + ctx.callArgs = []reflect.Value{reflect.ValueOf(ctx.ResponseWriter), + reflect.ValueOf(ctx.Req())} + case FuncReqRoute: + ctx.callArgs = []reflect.Value{reflect.ValueOf(ctx.Req())} + case FuncResponseRoute: + ctx.callArgs = []reflect.Value{reflect.ValueOf(ctx.ResponseWriter)} + case FuncCtxRoute: + ctx.callArgs = []reflect.Value{reflect.ValueOf(ctx)} + default: + panic("routeType error") + } + } + ctx.matched = true + } +} + +// WARNING: don't invoke this method on action +func (ctx *Context) Next() { + ctx.idx += 1 + ctx.Invoke() +} + +// WARNING: don't invoke this method on action +func (ctx *Context) Invoke() { + if ctx.idx < len(ctx.tan.handlers) { + ctx.tan.handlers[ctx.idx].Handle(ctx) + } else { + ctx.newAction() + // route is matched + if ctx.action != nil { + var ret []reflect.Value + switch fn := ctx.route.raw.(type) { + case func(*Context): + fn(ctx) + case func(*http.Request, http.ResponseWriter): + fn(ctx.req, ctx.ResponseWriter) + case func(): + fn() + case func(*http.Request): + fn(ctx.req) + case func(http.ResponseWriter): + fn(ctx.ResponseWriter) + default: + ret = ctx.route.method.Call(ctx.callArgs) + } + + if len(ret) == 1 { + ctx.Result = ret[0].Interface() + } else if len(ret) == 2 { + if code, ok := ret[0].Interface().(int); ok { + ctx.Result = &StatusResult{code, ret[1].Interface()} + } + } + // not route matched + } else { + if !ctx.Written() { + ctx.NotFound() + } + } + } +} + +func (ctx *Context) ServeFile(path string) error { + http.ServeFile(ctx, ctx.Req(), path) + return nil +} + +func (ctx *Context) ServeXml(obj interface{}) error { + encoder := xml.NewEncoder(ctx) + ctx.Header().Set("Content-Type", "application/xml; charset=UTF-8") + err := encoder.Encode(obj) + if err != nil { + ctx.Header().Del("Content-Type") + } + return err +} + +func (ctx *Context) ServeJson(obj interface{}) error { + encoder := json.NewEncoder(ctx) + ctx.Header().Set("Content-Type", "application/json; charset=UTF-8") + err := encoder.Encode(obj) + if err != nil { + ctx.Header().Del("Content-Type") + } + return err +} + +func (ctx *Context) Download(fpath string) error { + f, err := os.Open(fpath) + if err != nil { + return err + } + defer f.Close() + + fName := filepath.Base(fpath) + ctx.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%v\"", fName)) + _, err = io.Copy(ctx, f) + return err +} + +func (ctx *Context) SaveToFile(formName, savePath string) error { + file, _, err := ctx.Req().FormFile(formName) + if err != nil { + return err + } + defer file.Close() + + f, err := os.OpenFile(savePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, file) + return err +} + +func (ctx *Context) Redirect(url string, status ...int) { + s := http.StatusFound + if len(status) > 0 { + s = status[0] + } + http.Redirect(ctx.ResponseWriter, ctx.Req(), url, s) +} + +// Notmodified writes a 304 HTTP response +func (ctx *Context) NotModified() { + ctx.WriteHeader(http.StatusNotModified) +} + +func (ctx *Context) Unauthorized() { + ctx.Abort(http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)) +} + +// NotFound writes a 404 HTTP response +func (ctx *Context) NotFound(message ...string) { + if len(message) == 0 { + ctx.Abort(http.StatusNotFound, http.StatusText(http.StatusNotFound)) + return + } + ctx.Abort(http.StatusNotFound, message[0]) +} + +// Abort is a helper method that sends an HTTP header and an optional +// body. It is useful for returning 4xx or 5xx errors. +// Once it has been called, any return value from the handler will +// not be written to the response. +func (ctx *Context) Abort(status int, body string) { + ctx.Result = Abort(status, body) + ctx.HandleError() +} + +type Contexter interface { + SetContext(*Context) +} + +type Ctx struct { + *Context +} + +func (c *Ctx) SetContext(ctx *Context) { + c.Context = ctx +} + +func Contexts() HandlerFunc { + return func(ctx *Context) { + if action := ctx.Action(); action != nil { + if a, ok := action.(Contexter); ok { + a.SetContext(ctx) + } + } + ctx.Next() + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/context_test.go b/Godeps/_workspace/src/github.com/lunny/tango/context_test.go new file mode 100644 index 0000000..d1eb7d8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/context_test.go @@ -0,0 +1,282 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type CtxAction struct { + Ctx +} + +func (p *CtxAction) Get() { + p.Ctx.Write([]byte("context")) +} + +func TestContext1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(CtxAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "context") +} + +type CtxJsonAction struct { + Ctx +} + +func (p *CtxJsonAction) Get() error { + return p.Ctx.ServeJson(map[string]string{ + "get": "ctx", + }) +} + +func TestContext2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(CtxJsonAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, recorder.Header().Get("Content-Type"), "application/json; charset=UTF-8") + expect(t, strings.TrimSpace(buff.String()), `{"get":"ctx"}`) +} + +type CtxXmlAction struct { + Ctx +} + +type XmlStruct struct { + Content string +} + +func (p *CtxXmlAction) Get() error { + return p.Ctx.ServeXml(XmlStruct{"content"}) +} + +func TestContext3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(CtxXmlAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + expect(t, recorder.Header().Get("Content-Type"), "application/xml; charset=UTF-8") + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `content`) +} + +type CtxFileAction struct { + Ctx +} + +func (p *CtxFileAction) Get() error { + return p.Ctx.ServeFile("./public/index.html") +} + +func TestContext4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/", new(CtxFileAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `this is index.html`) +} + +func TestContext5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/2", func() string { + return "2" + }) + o.Any("/", func(ctx *Context) { + ctx.Redirect("/2") + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusFound) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `Found.`) +} + +type NotFoundAction struct { + Ctx +} + +func (n *NotFoundAction) Get() { + n.NotFound("not found") +} + +func TestContext6(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/", new(NotFoundAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), "not found") +} + +type NotModifidAction struct { + Ctx +} + +func (n *NotModifidAction) Get() { + n.NotModified() +} + +func TestContext7(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/", new(NotModifidAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotModified) + expect(t, len(buff.String()), 0) +} + +type UnauthorizedAction struct { + Ctx +} + +func (n *UnauthorizedAction) Get() { + n.Unauthorized() +} + +func TestContext8(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/", new(UnauthorizedAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusUnauthorized) + expect(t, buff.String(), http.StatusText(http.StatusUnauthorized)) +} + +type DownloadAction struct { + Ctx +} + +func (n *DownloadAction) Get() { + n.Download("./public/index.html") +} + +func TestContext9(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/", new(DownloadAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + + expect(t, recorder.Header().Get("Content-Disposition"), `attachment; filename="index.html"`) + expect(t, recorder.Code, http.StatusOK) + expect(t, buff.String(), "this is index.html") +} + +// check unsupported function will panic +func TestContext10(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + var ifPanic bool + defer func() { + if err := recover(); err != nil { + ifPanic = true + } + + expect(t, ifPanic, true) + }() + + o.Any("/", func(i int) { + fmt.Println(i) + }) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/cookie.go b/Godeps/_workspace/src/github.com/lunny/tango/cookie.go new file mode 100644 index 0000000..3cead40 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/cookie.go @@ -0,0 +1,748 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "io/ioutil" + "net/http" + "strconv" + "strings" + "time" + "html/template" +) + +func isValidCookieValue(p []byte) bool { + for _, b := range p { + if b <= ' ' || + b >= 127 || + b == '"' || + b == ',' || + b == ';' || + b == '\\' { + return false + } + } + return true +} + +func isValidCookieName(s string) bool { + for _, r := range s { + if r <= ' ' || + r >= 127 || + strings.ContainsRune(" \t\"(),/:;<=>?@[]\\{}", r) { + return false + } + } + return true +} + +type Set interface { + String(key string) (string, error) + Int(key string) (int, error) + Int32(key string) (int32, error) + Int64(key string) (int64, error) + Uint(key string) (uint, error) + Uint32(key string) (uint32, error) + Uint64(key string) (uint64, error) + Float32(key string) (float32, error) + Float64(key string) (float64, error) + Bool(key string) (bool, error) + + MustString(key string, defaults ...string) string + MustEscape(key string, defaults ...string) string + MustInt(key string, defaults ...int) int + MustInt32(key string, defaults ...int32) int32 + MustInt64(key string, defaults ...int64) int64 + MustUint(key string, defaults ...uint) uint + MustUint32(key string, defaults ...uint32) uint32 + MustUint64(key string, defaults ...uint64) uint64 + MustFloat32(key string, defaults ...float32) float32 + MustFloat64(key string, defaults ...float64) float64 + MustBool(key string, defaults ...bool) bool +} + +type Cookies interface { + Set + Get(string) *http.Cookie + Set(*http.Cookie) + Expire(string, time.Time) + Del(string) +} + +type cookies Context + +var _ Cookies = &cookies{} + +func NewCookie(name string, value string, age ...int64) *http.Cookie { + if !isValidCookieName(name) || !isValidCookieValue([]byte(value)) { + return nil + } + + var utctime time.Time + if len(age) == 0 { + // 2^31 - 1 seconds (roughly 2038) + utctime = time.Unix(2147483647, 0) + } else { + utctime = time.Unix(time.Now().Unix()+age[0], 0) + } + return &http.Cookie{Name: name, Value: value, Expires: utctime} +} + +func (c *cookies) Get(key string) *http.Cookie { + ck, err := c.req.Cookie(key) + if err != nil { + return nil + } + return ck +} + +func (c *cookies) Set(ck *http.Cookie) { + http.SetCookie(c.ResponseWriter, ck) +} + +func (c *cookies) Expire(key string, expire time.Time) { + ck := c.Get(key) + if ck != nil { + ck.Expires = expire + ck.MaxAge = int(expire.Sub(time.Now()).Seconds()) + c.Set(ck) + } +} + +func (c *cookies) Del(key string) { + c.Expire(key, time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local)) +} + +func (c *cookies) String(key string) (string, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return "", err + } + return ck.Value, nil +} + +func (c *cookies) Int(key string) (int, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + return strconv.Atoi(ck.Value) +} + +func (c *cookies) Int32(key string) (int32, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + v, err := strconv.ParseInt(ck.Value, 10, 32) + return int32(v), err +} + +func (c *cookies) Int64(key string) (int64, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + return strconv.ParseInt(ck.Value, 10, 64) +} + +func (c *cookies) Uint(key string) (uint, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + v, err := strconv.ParseUint(ck.Value, 10, 64) + return uint(v), err +} + +func (c *cookies) Uint32(key string) (uint32, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + v, err := strconv.ParseUint(ck.Value, 10, 32) + return uint32(v), err +} + +func (c *cookies) Uint64(key string) (uint64, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + return strconv.ParseUint(ck.Value, 10, 64) +} + +func (c *cookies) Float32(key string) (float32, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + v, err := strconv.ParseFloat(ck.Value, 32) + return float32(v), err +} + +func (c *cookies) Float64(key string) (float64, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + return strconv.ParseFloat(ck.Value, 32) +} + +func (c *cookies) Bool(key string) (bool, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return false, err + } + return strconv.ParseBool(ck.Value) +} + +func (c *cookies) MustString(key string, defaults ...string) string { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return "" + } + return ck.Value +} + +func (c *cookies) MustEscape(key string, defaults ...string) string { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return "" + } + return template.HTMLEscapeString(ck.Value) +} + +func (c *cookies) MustInt(key string, defaults ...int) int { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.Atoi(ck.Value) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *cookies) MustInt32(key string, defaults ...int32) int32 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.ParseInt(ck.Value, 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return int32(v) +} + +func (c *cookies) MustInt64(key string, defaults ...int64) int64 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.ParseInt(ck.Value, 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *cookies) MustUint(key string, defaults ...uint) uint { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.ParseUint(ck.Value, 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint(v) +} + +func (c *cookies) MustUint32(key string, defaults ...uint32) uint32 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.ParseUint(ck.Value, 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint32(v) +} + +func (c *cookies) MustUint64(key string, defaults ...uint64) uint64 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.ParseUint(ck.Value, 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *cookies) MustFloat32(key string, defaults ...float32) float32 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.ParseFloat(ck.Value, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return float32(v) +} + +func (c *cookies) MustFloat64(key string, defaults ...float64) float64 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + v, err := strconv.ParseFloat(ck.Value, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *cookies) MustBool(key string, defaults ...bool) bool { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return false + } + v, err := strconv.ParseBool(ck.Value) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (ctx *Context) Cookie(key string, defaults ...string) string { + return ctx.Cookies().MustString(key, defaults...) +} + +func (ctx *Context) CookieEscape(key string, defaults ...string) string { + return ctx.Cookies().MustEscape(key, defaults...) +} + +func (ctx *Context) CookieInt(key string, defaults ...int) int { + return ctx.Cookies().MustInt(key, defaults...) +} + +func (ctx *Context) CookieInt32(key string, defaults ...int32) int32 { + return ctx.Cookies().MustInt32(key, defaults...) +} + +func (ctx *Context) CookieInt64(key string, defaults ...int64) int64 { + return ctx.Cookies().MustInt64(key, defaults...) +} + +func (ctx *Context) CookieUint(key string, defaults ...uint) uint { + return ctx.Cookies().MustUint(key, defaults...) +} + +func (ctx *Context) CookieUint32(key string, defaults ...uint32) uint32 { + return ctx.Cookies().MustUint32(key, defaults...) +} + +func (ctx *Context) CookieUint64(key string, defaults ...uint64) uint64 { + return ctx.Cookies().MustUint64(key, defaults...) +} + +func (ctx *Context) CookieFloat32(key string, defaults ...float32) float32 { + return ctx.Cookies().MustFloat32(key, defaults...) +} + +func (ctx *Context) CookieFloat64(key string, defaults ...float64) float64 { + return ctx.Cookies().MustFloat64(key, defaults...) +} + +func (ctx *Context) CookieBool(key string, defaults ...bool) bool { + return ctx.Cookies().MustBool(key, defaults...) +} + +func getCookieSig(key string, val []byte, timestamp string) string { + hm := hmac.New(sha1.New, []byte(key)) + + hm.Write(val) + hm.Write([]byte(timestamp)) + + hex := fmt.Sprintf("%02x", hm.Sum(nil)) + return hex +} + +// secure cookies +type secureCookies struct { + *cookies + secret string +} + +var _ Cookies = &secureCookies{} + +func parseSecureCookie(secret string, value string) string { + parts := strings.SplitN(value, "|", 3) + val, timestamp, sig := parts[0], parts[1], parts[2] + if getCookieSig(secret, []byte(val), timestamp) != sig { + return "" + } + + ts, _ := strconv.ParseInt(timestamp, 0, 64) + if time.Now().Unix()-31*86400 > ts { + return "" + } + + buf := bytes.NewBufferString(val) + encoder := base64.NewDecoder(base64.StdEncoding, buf) + + res, _ := ioutil.ReadAll(encoder) + return string(res) +} + +func (c *secureCookies) Get(key string) *http.Cookie { + ck := c.cookies.Get(key) + if ck == nil { + return nil + } + + v := parseSecureCookie(c.secret, ck.Value) + if v == "" { + return nil + } + ck.Value = v + return ck +} + +func (c *secureCookies) String(key string) (string, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return "", err + } + v := parseSecureCookie(c.secret, ck.Value) + return v, nil +} + +func (c *secureCookies) Int(key string) (int, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + return strconv.Atoi(s) +} + +func (c *secureCookies) Int32(key string) (int32, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseInt(s, 10, 32) + return int32(v), err +} + +func (c *secureCookies) Int64(key string) (int64, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + return strconv.ParseInt(s, 10, 64) +} + +func (c *secureCookies) Uint(key string) (uint, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseUint(s, 10, 64) + return uint(v), err +} + +func (c *secureCookies) Uint32(key string) (uint32, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseUint(s, 10, 32) + return uint32(v), err +} + +func (c *secureCookies) Uint64(key string) (uint64, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + return strconv.ParseUint(s, 10, 64) +} + +func (c *secureCookies) Float32(key string) (float32, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseFloat(s, 32) + return float32(v), err +} + +func (c *secureCookies) Float64(key string) (float64, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return 0, err + } + s := parseSecureCookie(c.secret, ck.Value) + return strconv.ParseFloat(s, 32) +} + +func (c *secureCookies) Bool(key string) (bool, error) { + ck, err := c.req.Cookie(key) + if err != nil { + return false, err + } + s := parseSecureCookie(c.secret, ck.Value) + return strconv.ParseBool(s) +} + +func (c *secureCookies) MustString(key string, defaults ...string) string { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return "" + } + s := parseSecureCookie(c.secret, ck.Value) + return s +} + +func (c *secureCookies) MustEscape(key string, defaults ...string) string { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return "" + } + s := parseSecureCookie(c.secret, ck.Value) + return template.HTMLEscapeString(s) +} + +func (c *secureCookies) MustInt(key string, defaults ...int) int { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.Atoi(s) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *secureCookies) MustInt32(key string, defaults ...int32) int32 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseInt(s, 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return int32(v) +} + +func (c *secureCookies) MustInt64(key string, defaults ...int64) int64 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseInt(s, 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *secureCookies) MustUint(key string, defaults ...uint) uint { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseUint(s, 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint(v) +} + +func (c *secureCookies) MustUint32(key string, defaults ...uint32) uint32 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseUint(s, 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint32(v) +} + +func (c *secureCookies) MustUint64(key string, defaults ...uint64) uint64 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseUint(s, 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *secureCookies) MustFloat32(key string, defaults ...float32) float32 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseFloat(s, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return float32(v) +} + +func (c *secureCookies) MustFloat64(key string, defaults ...float64) float64 { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return 0 + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseFloat(s, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (c *secureCookies) MustBool(key string, defaults ...bool) bool { + ck, err := c.req.Cookie(key) + if err != nil { + if len(defaults) > 0 { + return defaults[0] + } + return false + } + s := parseSecureCookie(c.secret, ck.Value) + v, err := strconv.ParseBool(s) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func secCookieValue(secret string, vb []byte) string { + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + sig := getCookieSig(secret, vb, timestamp) + return strings.Join([]string{string(vb), timestamp, sig}, "|") +} + +func NewSecureCookie(secret, name string, val string, age ...int64) *http.Cookie { + var buf bytes.Buffer + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + encoder.Write([]byte(val)) + encoder.Close() + + cookie := secCookieValue(secret, buf.Bytes()) + return NewCookie(name, cookie, age...) +} + +func (c *secureCookies) Expire(key string, expire time.Time) { + ck := c.Get(key) + if ck != nil { + ck.Expires = expire + ck.Value = secCookieValue(c.secret, []byte(ck.Value)) + c.Set(ck) + } +} + +func (c *secureCookies) Del(key string) { + c.Expire(key, time.Date(1900, 1, 1, 0, 0, 0, 0, time.Local)) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/cookie_test.go b/Godeps/_workspace/src/github.com/lunny/tango/cookie_test.go new file mode 100644 index 0000000..1e1c87d --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/cookie_test.go @@ -0,0 +1,1673 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestCookie1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ck := ctx.Cookies().Get("name") + if ck != nil { + return ck.Value + } + return "" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("name", "test")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") +} + +func TestCookie2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ctx.Cookies().Set(NewCookie("name", "test")) + return "test" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") + expect(t, strings.Split(recorder.Header().Get("Set-Cookie"), ";")[0], "name=test") +} + +func TestCookie3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ctx.Cookies().Expire("expire", time.Now()) + return "test" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + req.AddCookie(NewCookie("expire", "test")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") + fmt.Println(recorder.Header().Get("Set-Cookie")) + expect(t, strings.Split(recorder.Header().Get("Set-Cookie"), ";")[0], "expire=test") +} + +func TestCookie4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ctx.Cookies().Del("ttttt") + return "test" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + req.AddCookie(NewCookie("ttttt", "test")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") + expect(t, recorder.Header().Get("Set-Cookie"), "ttttt=test; Max-Age=0") +} + +func TestCookie5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ck := ctx.SecureCookies("sssss").Get("name") + if ck != nil { + return ck.Value + } + return "" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie("sssss", "name", "test")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") +} + +func TestCookie6(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ctx.SecureCookies("ttttt").Set(NewSecureCookie("ttttt", "name", "test")) + return "test" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") + + r := strings.Split(recorder.Header().Get("Set-Cookie"), ";")[0] + s := strings.SplitN(r, "=", 2) + name, value := s[0], s[1] + expect(t, name, "name") + expect(t, parseSecureCookie("ttttt", value), "test") +} + +func TestCookie7(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ctx.SecureCookies("ttttt").Expire("expire", time.Now()) + return "test" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + req.AddCookie(NewSecureCookie("ttttt", "expire", "test")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") + expect(t, strings.Split(recorder.Header().Get("Set-Cookie"), "|")[0], "expire=test") +} + +func TestCookie8(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) string { + ctx.SecureCookies("ttttt").Del("ttttt") + return "test" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + req.AddCookie(NewSecureCookie("ttttt", "ttttt", "test")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "test") + expect(t, strings.Split(recorder.Header().Get("Set-Cookie"), "|")[0], "ttttt=test") +} + +type Cookie11Action struct { + Ctx +} + +func (a *Cookie11Action) Get() string { + v, _ := a.Cookies().Int("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie11(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie11Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie12Action struct { + Ctx +} + +func (a *Cookie12Action) Get() string { + v, _ := a.Cookies().Int32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie12(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie12Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie13Action struct { + Ctx +} + +func (a *Cookie13Action) Get() string { + v, _ := a.Cookies().Int64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie13(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie13Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie14Action struct { + Ctx +} + +func (a *Cookie14Action) Get() string { + v, _ := a.Cookies().Uint("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie14(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie14Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie15Action struct { + Ctx +} + +func (a *Cookie15Action) Get() string { + v, _ := a.Cookies().Uint32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie15(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie15Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie16Action struct { + Ctx +} + +func (a *Cookie16Action) Get() string { + v, _ := a.Cookies().Uint64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie16(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie16Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie17Action struct { + Ctx +} + +func (a *Cookie17Action) Get() string { + v, _ := a.Cookies().Float32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie17(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie17Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie18Action struct { + Ctx +} + +func (a *Cookie18Action) Get() string { + v, _ := a.Cookies().Float64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie18(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie18Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie19Action struct { + Ctx +} + +func (a *Cookie19Action) Get() string { + v, _ := a.Cookies().Bool("test") + return fmt.Sprintf("%v", v) +} + +func TestCookie19(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie19Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Cookie20Action struct { + Ctx +} + +func (a *Cookie20Action) Get() string { + v, _ := a.Cookies().String("test") + return fmt.Sprintf("%v", v) +} + +func TestCookie20(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie20Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie21Action struct { + Ctx +} + +func (a *Cookie21Action) Get() string { + v := a.Cookies().MustInt("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie21(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie21Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie22Action struct { + Ctx +} + +func (a *Cookie22Action) Get() string { + v := a.Cookies().MustInt32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie22(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie22Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie23Action struct { + Ctx +} + +func (a *Cookie23Action) Get() string { + v := a.Cookies().MustInt64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie23(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie23Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie24Action struct { + Ctx +} + +func (a *Cookie24Action) Get() string { + v := a.Cookies().MustUint("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie24(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie24Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie25Action struct { + Ctx +} + +func (a *Cookie25Action) Get() string { + v := a.Cookies().MustUint32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie25(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie25Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie26Action struct { + Ctx +} + +func (a *Cookie26Action) Get() string { + v := a.Cookies().MustUint64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie26(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie26Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie27Action struct { + Ctx +} + +func (a *Cookie27Action) Get() string { + v := a.Cookies().MustFloat32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie27(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie27Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie28Action struct { + Ctx +} + +func (a *Cookie28Action) Get() string { + v := a.Cookies().MustFloat64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie28(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie28Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie29Action struct { + Ctx +} + +func (a *Cookie29Action) Get() string { + v := a.Cookies().MustBool("test") + return fmt.Sprintf("%v", v) +} + +func TestCookie29(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie29Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Cookie30Action struct { + Ctx +} + +func (a *Cookie30Action) Get() string { + v := a.Cookies().MustString("test") + return fmt.Sprintf("%s", v) +} + +func TestCookie30(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie30Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +const ( + secrectCookie = "ssss" +) + +type Cookie31Action struct { + Ctx +} + +func (a *Cookie31Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Int("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie31(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie31Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie32Action struct { + Ctx +} + +func (a *Cookie32Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Int32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie32(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie32Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie33Action struct { + Ctx +} + +func (a *Cookie33Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Int64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie33(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie33Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie34Action struct { + Ctx +} + +func (a *Cookie34Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Uint("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie34(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie34Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie35Action struct { + Ctx +} + +func (a *Cookie35Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Uint32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie35(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie35Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie36Action struct { + Ctx +} + +func (a *Cookie36Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Uint64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie36(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie36Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie37Action struct { + Ctx +} + +func (a *Cookie37Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Float32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie37(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie37Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie38Action struct { + Ctx +} + +func (a *Cookie38Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Float64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie38(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie38Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie39Action struct { + Ctx +} + +func (a *Cookie39Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).Bool("test") + return fmt.Sprintf("%v", v) +} + +func TestCookie39(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie39Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Cookie40Action struct { + Ctx +} + +func (a *Cookie40Action) Get() string { + v, _ := a.SecureCookies(secrectCookie).String("test") + return fmt.Sprintf("%v", v) +} + +func TestCookie40(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie40Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie41Action struct { + Ctx +} + +func (a *Cookie41Action) Get() string { + v := a.SecureCookies(secrectCookie).MustInt("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie41(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie41Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie42Action struct { + Ctx +} + +func (a *Cookie42Action) Get() string { + v := a.SecureCookies(secrectCookie).MustInt32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie42(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie42Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie43Action struct { + Ctx +} + +func (a *Cookie43Action) Get() string { + v := a.SecureCookies(secrectCookie).MustInt64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie43(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie43Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie44Action struct { + Ctx +} + +func (a *Cookie44Action) Get() string { + v := a.SecureCookies(secrectCookie).MustUint("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie44(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie44Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie45Action struct { + Ctx +} + +func (a *Cookie45Action) Get() string { + v := a.SecureCookies(secrectCookie).MustUint32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie45(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie45Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie46Action struct { + Ctx +} + +func (a *Cookie46Action) Get() string { + v := a.SecureCookies(secrectCookie).MustUint64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie46(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie46Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie47Action struct { + Ctx +} + +func (a *Cookie47Action) Get() string { + v := a.SecureCookies(secrectCookie).MustFloat32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie47(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie47Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie48Action struct { + Ctx +} + +func (a *Cookie48Action) Get() string { + v := a.SecureCookies(secrectCookie).MustFloat64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie48(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie48Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie49Action struct { + Ctx +} + +func (a *Cookie49Action) Get() string { + v := a.SecureCookies(secrectCookie).MustBool("test") + return fmt.Sprintf("%v", v) +} + +func TestCookie49(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie49Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Cookie50Action struct { + Ctx +} + +func (a *Cookie50Action) Get() string { + v := a.SecureCookies(secrectCookie).MustString("test") + return fmt.Sprintf("%s", v) +} + +func TestCookie50(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie50Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewSecureCookie(secrectCookie, "test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie51Action struct { + Ctx +} + +func (a *Cookie51Action) Get() string { + v := a.CookieInt("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie51(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie51Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie52Action struct { + Ctx +} + +func (a *Cookie52Action) Get() string { + v := a.CookieInt32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie52(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie52Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie53Action struct { + Ctx +} + +func (a *Cookie53Action) Get() string { + v := a.CookieInt64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie53(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie53Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie54Action struct { + Ctx +} + +func (a *Cookie54Action) Get() string { + v := a.CookieUint("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie54(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie54Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie55Action struct { + Ctx +} + +func (a *Cookie55Action) Get() string { + v := a.CookieUint32("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie55(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie55Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie56Action struct { + Ctx +} + +func (a *Cookie56Action) Get() string { + v := a.CookieUint64("test") + return fmt.Sprintf("%d", v) +} + +func TestCookie56(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie56Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Cookie57Action struct { + Ctx +} + +func (a *Cookie57Action) Get() string { + v := a.CookieFloat32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie57(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie57Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie58Action struct { + Ctx +} + +func (a *Cookie58Action) Get() string { + v := a.CookieFloat64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestCookie58(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie58Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Cookie59Action struct { + Ctx +} + +func (a *Cookie59Action) Get() string { + v := a.CookieBool("test") + return fmt.Sprintf("%v", v) +} + +func TestCookie59(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie59Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Cookie60Action struct { + Ctx +} + +func (a *Cookie60Action) Get() string { + v := a.Cookie("test") + return fmt.Sprintf("%s", v) +} + +func TestCookie60(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Cookie60Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + req.AddCookie(NewCookie("test", "1")) + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/doc.go b/Godeps/_workspace/src/github.com/lunny/tango/doc.go new file mode 100644 index 0000000..4965325 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/doc.go @@ -0,0 +1,51 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Tango is a micro & pluggable web framework for Go language. + +// package main + +// import "github.com/lunny/tango" + +// type Action struct { +// } + +// func (Action) Get() string { +// return "Hello tango!" +// } + +// func main() { +// t := tango.Classic() +// t.Get("/", new(Action)) +// t.Run() +// } + +// Middlewares allow you easily plugin/unplugin features for your Tango applications. + +// There are already many [middlewares](https://github.com/tango-contrib) to simplify your work: + +// - recovery - recover after panic +// - compress - Gzip & Deflate compression +// - static - Serves static files +// - logger - Log the request & inject Logger to action struct +// - param - get the router parameters +// - return - Handle the returned value smartlly +// - ctx - Inject context to action struct + +// - [session](https://github.com/tango-contrib/session) - Session manager, with stores support: +// * Memory - memory as a session store +// * [Redis](https://github.com/tango-contrib/session-redis) - redis server as a session store +// * [nodb](https://github.com/tango-contrib/session-nodb) - nodb as a session store +// * [ledis](https://github.com/tango-contrib/session-ledis) - ledis server as a session store) +// - [xsrf](https://github.com/tango-contrib/xsrf) - Generates and validates csrf tokens +// - [binding](https://github.com/tango-contrib/binding) - Bind and validates forms +// - [renders](https://github.com/tango-contrib/renders) - Go template engine +// - [dispatch](https://github.com/tango-contrib/dispatch) - Multiple Application support on one server +// - [tpongo2](https://github.com/tango-contrib/tpongo2) - Pongo2 teamplte engine support +// - [captcha](https://github.com/tango-contrib/captcha) - Captcha +// - [events](https://github.com/tango-contrib/events) - Before and After +// - [flash](https://github.com/tango-contrib/flash) - Share data between requests +// - [debug](https://github.com/tango-contrib/debug) - Show detail debug infomaton on log + +package tango diff --git a/Godeps/_workspace/src/github.com/lunny/tango/error.go b/Godeps/_workspace/src/github.com/lunny/tango/error.go new file mode 100644 index 0000000..db05388 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/error.go @@ -0,0 +1,72 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "fmt" + "net/http" +) + +type AbortError interface { + error + Code() int +} + +type abortError struct { + code int + content string +} + +func (a *abortError) Code() int { + return a.code +} + +func (a *abortError) Error() string { + return fmt.Sprintf("%v", a.content) +} + +func Abort(code int, content ...string) AbortError { + if len(content) >= 1 { + return &abortError{code, content[0]} + } + return &abortError{code, http.StatusText(code)} +} + +func NotFound(content ...string) AbortError { + return Abort(http.StatusNotFound, content...) +} + +func NotSupported(content ...string) AbortError { + return Abort(http.StatusMethodNotAllowed, content...) +} + +func InternalServerError(content ...string) AbortError { + return Abort(http.StatusInternalServerError, content...) +} + +func Forbidden(content ...string) AbortError { + return Abort(http.StatusForbidden, content...) +} + +func Unauthorized(content ...string) AbortError { + return Abort(http.StatusUnauthorized, content...) +} + +// default errorhandler, you can use your self handler +func Errors() HandlerFunc { + return func(ctx *Context) { + switch res := ctx.Result.(type) { + case AbortError: + ctx.WriteHeader(res.Code()) + ctx.Write([]byte(res.Error())) + case error: + ctx.WriteHeader(http.StatusInternalServerError) + ctx.Write([]byte(res.Error())) + default: + ctx.WriteHeader(http.StatusInternalServerError) + ctx.Write([]byte(http.StatusText(http.StatusInternalServerError))) + } + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/error_test.go b/Godeps/_workspace/src/github.com/lunny/tango/error_test.go new file mode 100644 index 0000000..94a216a --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/error_test.go @@ -0,0 +1,202 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestError1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + l := NewLogger(os.Stdout) + o := Classic(l) + o.Get("/", func(ctx *Context) error { + return NotFound() + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + refute(t, len(buff.String()), 0) +} + +func TestError2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) error { + return NotSupported() + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusMethodNotAllowed) + refute(t, len(buff.String()), 0) +} + +func TestError3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) error { + return InternalServerError() + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusInternalServerError) + refute(t, len(buff.String()), 0) +} + +func TestError4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Patch("/", func(ctx *Context) error { + return Forbidden() + }) + + req, err := http.NewRequest("PATCH", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusForbidden) + refute(t, len(buff.String()), 0) +} + +func TestError5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Trace("/", func(ctx *Context) error { + return Unauthorized() + }) + + req, err := http.NewRequest("TRACE", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusUnauthorized) + refute(t, len(buff.String()), 0) +} + +var err500 = Abort(500, "error") + +func TestError6(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Head("/", func(ctx *Context) error { + return err500 + }) + + req, err := http.NewRequest("HEAD", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, err500.Code()) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), err500.Error()) +} + +func TestError7(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Head("/", func(ctx *Context) { + return + }) + + req, err := http.NewRequest("HEAD", "http://localhost:8000/11?==", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + refute(t, len(buff.String()), 0) +} + +var ( + prefix = "tango
" + suffix = fmt.Sprintf("
version: %s
", Version()) +) + +func TestError8(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.ErrHandler = HandlerFunc(func(ctx *Context) { + switch res := ctx.Result.(type) { + case AbortError: + ctx.WriteHeader(res.Code()) + ctx.Write([]byte(prefix)) + ctx.Write([]byte(res.Error())) + case error: + ctx.WriteHeader(http.StatusInternalServerError) + ctx.Write([]byte(prefix)) + ctx.Write([]byte(res.Error())) + default: + ctx.WriteHeader(http.StatusInternalServerError) + ctx.Write([]byte(prefix)) + ctx.Write([]byte(http.StatusText(http.StatusInternalServerError))) + } + ctx.Write([]byte(suffix)) + }) + + o.Get("/", func() error { + return NotFound() + }) + + req, err := http.NewRequest("HEAD", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), prefix+"Not Found"+suffix) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/file.go b/Godeps/_workspace/src/github.com/lunny/tango/file.go new file mode 100644 index 0000000..9881578 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/file.go @@ -0,0 +1,25 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import "path/filepath" + +func File(path string) func(ctx *Context) { + return func(ctx *Context) { + ctx.ServeFile(path) + } +} + +func Dir(dir string) func(ctx *Context) { + return func(ctx *Context) { + params := ctx.Params() + if len(*params) <= 0 { + ctx.Result = NotFound() + ctx.HandleError() + return + } + ctx.ServeFile(filepath.Join(dir, (*params)[0].Value)) + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/file_test.go b/Godeps/_workspace/src/github.com/lunny/tango/file_test.go new file mode 100644 index 0000000..d5039b8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/file_test.go @@ -0,0 +1,69 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" +) + +func TestDir1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Get("/:name", Dir("./public")) + + req, err := http.NewRequest("GET", "http://localhost:8000/test.html", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "hello tango") +} + +func TestDir2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Get("/", Dir("./public")) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), http.StatusText(http.StatusNotFound)) +} + +func TestFile1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Get("/test.html", File("./public/test.html")) + + req, err := http.NewRequest("GET", "http://localhost:8000/test.html", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "hello tango") +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/form.go b/Godeps/_workspace/src/github.com/lunny/tango/form.go new file mode 100644 index 0000000..045b469 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/form.go @@ -0,0 +1,229 @@ +package tango + +import ( + "errors" + "net/http" + "strconv" + "html/template" +) + +type Forms http.Request + +var _ Set = &Forms{} + +func (f *Forms) String(key string) (string, error) { + (*http.Request)(f).ParseForm() + if v, ok := (*http.Request)(f).Form[key]; ok { + return v[0], nil + } + return "", errors.New("not exist") +} + +func (f *Forms) Strings(key string) ([]string, error) { + (*http.Request)(f).ParseForm() + if v, ok := (*http.Request)(f).Form[key]; ok { + return v, nil + } + return nil, errors.New("not exist") +} + +func (f *Forms) Escape(key string) (string, error) { + (*http.Request)(f).ParseForm() + if v, ok := (*http.Request)(f).Form[key]; ok { + return template.HTMLEscapeString(v[0]), nil + } + return "", errors.New("not exist") +} + +func (f *Forms) Int(key string) (int, error) { + return strconv.Atoi((*http.Request)(f).FormValue(key)) +} + +func (f *Forms) Int32(key string) (int32, error) { + v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 32) + return int32(v), err +} + +func (f *Forms) Int64(key string) (int64, error) { + return strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 64) +} + +func (f *Forms) Uint(key string) (uint, error) { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) + return uint(v), err +} + +func (f *Forms) Uint32(key string) (uint32, error) { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 32) + return uint32(v), err +} + +func (f *Forms) Uint64(key string) (uint64, error) { + return strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) +} + +func (f *Forms) Bool(key string) (bool, error) { + return strconv.ParseBool((*http.Request)(f).FormValue(key)) +} + +func (f *Forms) Float32(key string) (float32, error) { + v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) + return float32(v), err +} + +func (f *Forms) Float64(key string) (float64, error) { + return strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) +} + +func (f *Forms) MustString(key string, defaults ...string) string { + (*http.Request)(f).ParseForm() + if v, ok := (*http.Request)(f).Form[key]; ok { + return v[0] + } + if len(defaults) > 0 { + return defaults[0] + } + return "" +} + +func (f *Forms) MustStrings(key string, defaults ...[]string) []string { + (*http.Request)(f).ParseForm() + if v, ok := (*http.Request)(f).Form[key]; ok { + return v + } + if len(defaults) > 0 { + return defaults[0] + } + return []string{} +} + +func (f *Forms) MustEscape(key string, defaults ...string) string { + (*http.Request)(f).ParseForm() + if v, ok := (*http.Request)(f).Form[key]; ok { + return template.HTMLEscapeString(v[0]) + } + if len(defaults) > 0 { + return defaults[0] + } + return "" +} + +func (f *Forms) MustInt(key string, defaults ...int) int { + v, err := strconv.Atoi((*http.Request)(f).FormValue(key)) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (f *Forms) MustInt32(key string, defaults ...int32) int32 { + v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return int32(v) +} + +func (f *Forms) MustInt64(key string, defaults ...int64) int64 { + v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (f *Forms) MustUint(key string, defaults ...uint) uint { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint(v) +} + +func (f *Forms) MustUint32(key string, defaults ...uint32) uint32 { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint32(v) +} + +func (f *Forms) MustUint64(key string, defaults ...uint64) uint64 { + v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (f *Forms) MustFloat32(key string, defaults ...float32) float32 { + v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return float32(v) +} + +func (f *Forms) MustFloat64(key string, defaults ...float64) float64 { + v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (f *Forms) MustBool(key string, defaults ...bool) bool { + v, err := strconv.ParseBool((*http.Request)(f).FormValue(key)) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (ctx *Context) Form(key string, defaults ...string) string { + return (*Forms)(ctx.req).MustString(key, defaults...) +} + +func (ctx *Context) FormStrings(key string, defaults ...[]string) []string { + return (*Forms)(ctx.req).MustStrings(key, defaults...) +} + +func (ctx *Context) FormEscape(key string, defaults ...string) string { + return (*Forms)(ctx.req).MustEscape(key, defaults...) +} + +func (ctx *Context) FormInt(key string, defaults ...int) int { + return (*Forms)(ctx.req).MustInt(key, defaults...) +} + +func (ctx *Context) FormInt32(key string, defaults ...int32) int32 { + return (*Forms)(ctx.req).MustInt32(key, defaults...) +} + +func (ctx *Context) FormInt64(key string, defaults ...int64) int64 { + return (*Forms)(ctx.req).MustInt64(key, defaults...) +} + +func (ctx *Context) FormUint(key string, defaults ...uint) uint { + return (*Forms)(ctx.req).MustUint(key, defaults...) +} + +func (ctx *Context) FormUint32(key string, defaults ...uint32) uint32 { + return (*Forms)(ctx.req).MustUint32(key, defaults...) +} + +func (ctx *Context) FormUint64(key string, defaults ...uint64) uint64 { + return (*Forms)(ctx.req).MustUint64(key, defaults...) +} + +func (ctx *Context) FormFloat32(key string, defaults ...float32) float32 { + return (*Forms)(ctx.req).MustFloat32(key, defaults...) +} + +func (ctx *Context) FormFloat64(key string, defaults ...float64) float64 { + return (*Forms)(ctx.req).MustFloat64(key, defaults...) +} + +func (ctx *Context) FormBool(key string, defaults ...bool) bool { + return (*Forms)(ctx.req).MustBool(key, defaults...) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/form_test.go b/Godeps/_workspace/src/github.com/lunny/tango/form_test.go new file mode 100644 index 0000000..ac93377 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/form_test.go @@ -0,0 +1,849 @@ +package tango + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +type Form1Action struct { + Ctx +} + +func (a *Form1Action) Get() string { + v, _ := a.Forms().Int("test") + return fmt.Sprintf("%d", v) +} + +func TestForm1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form1Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form2Action struct { + Ctx +} + +func (a *Form2Action) Get() string { + v, _ := a.Forms().Int32("test") + return fmt.Sprintf("%d", v) +} + +func TestForm2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form2Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form3Action struct { + Ctx +} + +func (a *Form3Action) Get() string { + v, _ := a.Forms().Int64("test") + return fmt.Sprintf("%d", v) +} + +func TestForm3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form3Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form4Action struct { + Ctx +} + +func (a *Form4Action) Get() string { + v, _ := a.Forms().Uint("test") + return fmt.Sprintf("%d", v) +} + +func TestForm4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form4Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form5Action struct { + Ctx +} + +func (a *Form5Action) Get() string { + v, _ := a.Forms().Uint32("test") + return fmt.Sprintf("%d", v) +} + +func TestForm5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form5Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form6Action struct { + Ctx +} + +func (a *Form6Action) Get() string { + v, _ := a.Forms().Uint64("test") + return fmt.Sprintf("%d", v) +} + +func TestForm6(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form6Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form7Action struct { + Ctx +} + +func (a *Form7Action) Get() string { + v, _ := a.Forms().Float32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestForm7(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form7Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Form8Action struct { + Ctx +} + +func (a *Form8Action) Get() string { + v, _ := a.Forms().Float64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestForm8(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form8Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Form9Action struct { + Ctx +} + +func (a *Form9Action) Get() string { + v, _ := a.Forms().Bool("test") + return fmt.Sprintf("%v", v) +} + +func TestForm9(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form9Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Form10Action struct { + Ctx +} + +func (a *Form10Action) Get() string { + v, _ := a.Forms().String("test") + return fmt.Sprintf("%v", v) +} + +func TestForm10(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form10Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form11Action struct { + Ctx +} + +func (a *Form11Action) Get() string { + v := a.Forms().MustInt("test") + return fmt.Sprintf("%d", v) +} + +func TestForm11(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form11Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form12Action struct { + Ctx +} + +func (a *Form12Action) Get() string { + v := a.Forms().MustInt32("test") + return fmt.Sprintf("%d", v) +} + +func TestForm12(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form12Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form13Action struct { + Ctx +} + +func (a *Form13Action) Get() string { + v := a.Forms().MustInt64("test") + return fmt.Sprintf("%d", v) +} + +func TestForm13(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form13Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form14Action struct { + Ctx +} + +func (a *Form14Action) Get() string { + v := a.Forms().MustUint("test") + return fmt.Sprintf("%d", v) +} + +func TestForm14(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form14Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form15Action struct { + Ctx +} + +func (a *Form15Action) Get() string { + v := a.Forms().MustUint32("test") + return fmt.Sprintf("%d", v) +} + +func TestForm15(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form15Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form16Action struct { + Ctx +} + +func (a *Form16Action) Get() string { + v := a.Forms().MustUint64("test") + return fmt.Sprintf("%d", v) +} + +func TestForm16(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form16Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form17Action struct { + Ctx +} + +func (a *Form17Action) Get() string { + v := a.Forms().MustFloat32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestForm17(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form17Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Form18Action struct { + Ctx +} + +func (a *Form18Action) Get() string { + v := a.Forms().MustFloat64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestForm18(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form18Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Form19Action struct { + Ctx +} + +func (a *Form19Action) Get() string { + v := a.Forms().MustBool("test") + return fmt.Sprintf("%v", v) +} + +func TestForm19(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form19Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Form20Action struct { + Ctx +} + +func (a *Form20Action) Get() string { + v := a.Forms().MustString("test") + return fmt.Sprintf("%s", v) +} + +func TestForm20(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form20Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form21Action struct { + Ctx +} + +func (a *Form21Action) Get() string { + v := a.FormInt("test") + return fmt.Sprintf("%d", v) +} + +func TestForm21(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form21Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form22Action struct { + Ctx +} + +func (a *Form22Action) Get() string { + v := a.FormInt32("test") + return fmt.Sprintf("%d", v) +} + +func TestForm22(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form22Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form23Action struct { + Ctx +} + +func (a *Form23Action) Get() string { + v := a.FormInt64("test") + return fmt.Sprintf("%d", v) +} + +func TestForm23(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form23Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form24Action struct { + Ctx +} + +func (a *Form24Action) Get() string { + v := a.FormUint("test") + return fmt.Sprintf("%d", v) +} + +func TestForm24(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form24Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form25Action struct { + Ctx +} + +func (a *Form25Action) Get() string { + v := a.FormUint32("test") + return fmt.Sprintf("%d", v) +} + +func TestForm25(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form25Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form26Action struct { + Ctx +} + +func (a *Form26Action) Get() string { + v := a.FormUint64("test") + return fmt.Sprintf("%d", v) +} + +func TestForm26(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form26Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Form27Action struct { + Ctx +} + +func (a *Form27Action) Get() string { + v := a.FormFloat32("test") + return fmt.Sprintf("%.2f", v) +} + +func TestForm27(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form27Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Form28Action struct { + Ctx +} + +func (a *Form28Action) Get() string { + v := a.FormFloat64("test") + return fmt.Sprintf("%.2f", v) +} + +func TestForm28(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form28Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Form29Action struct { + Ctx +} + +func (a *Form29Action) Get() string { + v := a.FormBool("test") + return fmt.Sprintf("%v", v) +} + +func TestForm29(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form29Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Form30Action struct { + Ctx +} + +func (a *Form30Action) Get() string { + v := a.Form("test") + return fmt.Sprintf("%s", v) +} + +func TestForm30(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(Form30Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/?test=1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/group.go b/Godeps/_workspace/src/github.com/lunny/tango/group.go new file mode 100644 index 0000000..62195c3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/group.go @@ -0,0 +1,108 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +type groupRouter struct { + methods []string + url string + c interface{} +} + +type Group struct { + routers []groupRouter + handlers []Handler +} + +func NewGroup() *Group { + return &Group{ + routers: make([]groupRouter, 0), + handlers: make([]Handler, 0), + } +} + +func (g *Group) Use(handlers ...Handler) { + g.handlers = append(g.handlers, handlers...) +} + +func (g *Group) Get(url string, c interface{}) { + g.Route([]string{"GET", "HEAD"}, url, c) +} + +func (g *Group) Post(url string, c interface{}) { + g.Route([]string{"POST"}, url, c) +} + +func (g *Group) Head(url string, c interface{}) { + g.Route([]string{"HEAD"}, url, c) +} + +func (g *Group) Options(url string, c interface{}) { + g.Route([]string{"OPTIONS"}, url, c) +} + +func (g *Group) Trace(url string, c interface{}) { + g.Route([]string{"TRACE"}, url, c) +} + +func (g *Group) Patch(url string, c interface{}) { + g.Route([]string{"PATCH"}, url, c) +} + +func (g *Group) Delete(url string, c interface{}) { + g.Route([]string{"DELETE"}, url, c) +} + +func (g *Group) Put(url string, c interface{}) { + g.Route([]string{"PUT"}, url, c) +} + +func (g *Group) Any(url string, c interface{}) { + g.Route(SupportMethods, url, c) +} + +func (g *Group) Route(methods []string, url string, c interface{}) { + g.routers = append(g.routers, groupRouter{methods, url, c}) +} + +func (g *Group) Group(p string, o interface{}) { + gr := getGroup(o) + for _, gchild := range gr.routers { + g.Route(gchild.methods, joinRoute(p, gchild.url), gchild.c) + } +} + +func getGroup(o interface{}) *Group { + var g *Group + var gf func(*Group) + var ok bool + if g, ok = o.(*Group); ok { + } else if gf, ok = o.(func(*Group)); ok { + g = NewGroup() + gf(g) + } else { + panic("not allowed group parameter") + } + return g +} + +func joinRoute(p, url string) string { + if len(p) == 0 || p == "/" { + return url + } + return p + url +} + +func (t *Tango) addGroup(p string, g *Group) { + for _, r := range g.routers { + t.Route(r.methods, joinRoute(p, r.url), r.c) + } + for _, h := range g.handlers { + t.Use(Prefix(p, h)) + } +} + +func (t *Tango) Group(p string, o interface{}) { + t.addGroup(p, getGroup(o)) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/group_test.go b/Godeps/_workspace/src/github.com/lunny/tango/group_test.go new file mode 100644 index 0000000..96762a4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/group_test.go @@ -0,0 +1,279 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGroup1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Group("/api", func(g *Group) { + g.Get("/1", func() string { + return "/1" + }) + g.Post("/2", func() string { + return "/2" + }) + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/api/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/1") + + buff.Reset() + req, err = http.NewRequest("POST", "http://localhost:8000/api/2", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/2") +} + +func TestGroup2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + g := NewGroup() + g.Any("/1", func() string { + return "/1" + }) + g.Options("/2", func() string { + return "/2" + }) + + o := Classic() + o.Group("/api", g) + + req, err := http.NewRequest("GET", "http://localhost:8000/api/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/1") + + buff.Reset() + req, err = http.NewRequest("OPTIONS", "http://localhost:8000/api/2", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/2") +} + +func TestGroup3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Group("/api", func(g *Group) { + g.Group("/v1", func(cg *Group) { + cg.Trace("/1", func() string { + return "/1" + }) + cg.Patch("/2", func() string { + return "/2" + }) + }) + }) + + req, err := http.NewRequest("TRACE", "http://localhost:8000/api/v1/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/1") + + buff.Reset() + req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/2", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/2") +} + +func TestGroup4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Group("", func(g *Group) { + g.Delete("/api/1", func() string { + return "/1" + }) + g.Head("/api/2", func() string { + return "/2" + }) + }) + + req, err := http.NewRequest("DELETE", "http://localhost:8000/api/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/1") + + buff.Reset() + req, err = http.NewRequest("HEAD", "http://localhost:8000/api/2", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/2") +} + +func TestGroup5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + var handlerGroup bool + o.Group("/api", func(g *Group) { + g.Use(HandlerFunc(func(ctx *Context) { + handlerGroup = true + ctx.Next() + })) + g.Put("/1", func() string { + return "/1" + }) + }) + o.Post("/2", func() string { + return "/2" + }) + + req, err := http.NewRequest("PUT", "http://localhost:8000/api/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/1") + expect(t, handlerGroup, true) + + handlerGroup = false + buff.Reset() + req, err = http.NewRequest("POST", "http://localhost:8000/2", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/2") + expect(t, handlerGroup, false) +} + +func TestGroup6(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + var handlerGroup bool + g := NewGroup() + g.Use(HandlerFunc(func(ctx *Context) { + handlerGroup = true + ctx.Next() + })) + g.Get("/1", func() string { + return "/1" + }) + + o := Classic() + o.Group("/api", g) + o.Post("/2", func() string { + return "/2" + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/api/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/1") + expect(t, handlerGroup, true) + + handlerGroup = false + buff.Reset() + req, err = http.NewRequest("POST", "http://localhost:8000/2", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "/2") + expect(t, handlerGroup, false) +} + +func TestGroup7(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + var isPanic bool + defer func() { + if err := recover(); err != nil { + isPanic = true + } + expect(t, isPanic, true) + }() + + o := Classic() + o.Group("/api", func() { + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/api/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/logger.go b/Godeps/_workspace/src/github.com/lunny/tango/logger.go new file mode 100644 index 0000000..d4a5a09 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/logger.go @@ -0,0 +1,76 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "io" + "time" + + "github.com/lunny/log" +) + +type Logger interface { + Debugf(format string, v ...interface{}) + Debug(v ...interface{}) + Infof(format string, v ...interface{}) + Info(v ...interface{}) + Warnf(format string, v ...interface{}) + Warn(v ...interface{}) + Errorf(format string, v ...interface{}) + Error(v ...interface{}) +} + +func NewLogger(out io.Writer) Logger { + l := log.New(out, "[tango] ", log.Ldefault()) + l.SetOutputLevel(log.Ldebug) + return l +} + +type LogInterface interface { + SetLogger(Logger) +} + +type Log struct { + Logger +} + +func (l *Log) SetLogger(log Logger) { + l.Logger = log +} + +func Logging() HandlerFunc { + return func(ctx *Context) { + start := time.Now() + p := ctx.Req().URL.Path + if len(ctx.Req().URL.RawQuery) > 0 { + p = p + "?" + ctx.Req().URL.RawQuery + } + + ctx.Debug("Started", ctx.Req().Method, p, "for", ctx.Req().RemoteAddr) + + if action := ctx.Action(); action != nil { + if l, ok := action.(LogInterface); ok { + l.SetLogger(ctx.Logger) + } + } + + ctx.Next() + + if !ctx.Written() { + if ctx.Result == nil { + ctx.Result = NotFound() + } + ctx.HandleError() + } + + statusCode := ctx.Status() + + if statusCode >= 200 && statusCode < 400 { + ctx.Info(ctx.Req().Method, statusCode, time.Since(start), p) + } else { + ctx.Error(ctx.Req().Method, statusCode, time.Since(start), p, ctx.Result) + } + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/logger_test.go b/Godeps/_workspace/src/github.com/lunny/tango/logger_test.go new file mode 100644 index 0000000..2214013 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/logger_test.go @@ -0,0 +1,61 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/lunny/log" +) + +func TestLogger1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + + n := NewWithLog(log.New(buff, "[tango] ", 0)) + n.Use(Logging()) + n.UseHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.WriteHeader(http.StatusNotFound) + })) + + req, err := http.NewRequest("GET", "http://localhost:3000/foobar", nil) + if err != nil { + t.Error(err) + } + + n.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + refute(t, len(buff.String()), 0) +} + +type LoggerAction struct { + Log +} + +func (l *LoggerAction) Get() string { + return "log" +} + +func TestLogger2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + n := Classic() + n.Get("/", new(LoggerAction)) + + req, err := http.NewRequest("GET", "http://localhost:3000/", nil) + if err != nil { + t.Error(err) + } + + n.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "log") +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/logo.png b/Godeps/_workspace/src/github.com/lunny/tango/logo.png new file mode 100644 index 0000000..eefc76f Binary files /dev/null and b/Godeps/_workspace/src/github.com/lunny/tango/logo.png differ diff --git a/Godeps/_workspace/src/github.com/lunny/tango/param.go b/Godeps/_workspace/src/github.com/lunny/tango/param.go new file mode 100644 index 0000000..7ded00e --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/param.go @@ -0,0 +1,349 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "errors" + "strconv" + "html/template" +) + +type ( + param struct { + Name string + Value string + } + Params []param +) + +var _ Set = &Params{} + +func (p *Params) Get(key string) string { + if len(key) == 0 { + return "" + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + for _, v := range *p { + if v.Name == key { + return v.Value + } + } + return "" +} + +func (p *Params) String(key string) (string, error) { + if len(key) == 0 { + return "", errors.New("not exist") + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + for _, v := range *p { + if v.Name == key { + return v.Value, nil + } + } + return "", errors.New("not exist") +} + +func (p *Params) Strings(key string) ([]string, error) { + if len(key) == 0 { + return nil, errors.New("not exist") + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + var s = make([]string, 0) + for _, v := range *p { + if v.Name == key { + s = append(s, v.Value) + } + } + if len(s) > 0 { + return s, nil + } + return nil, errors.New("not exist") +} + +func (p *Params) Escape(key string) (string, error) { + if len(key) == 0 { + return "", errors.New("not exist") + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + for _, v := range *p { + if v.Name == key { + return template.HTMLEscapeString(v.Value), nil + } + } + return "", errors.New("not exist") +} + +func (p *Params) Int(key string) (int, error) { + return strconv.Atoi(p.Get(key)) +} + +func (p *Params) Int32(key string) (int32, error) { + v, err := strconv.ParseInt(p.Get(key), 10, 32) + return int32(v), err +} + +func (p *Params) Int64(key string) (int64, error) { + return strconv.ParseInt(p.Get(key), 10, 64) +} + +func (p *Params) Uint(key string) (uint, error) { + v, err := strconv.ParseUint(p.Get(key), 10, 64) + return uint(v), err +} + +func (p *Params) Uint32(key string) (uint32, error) { + v, err := strconv.ParseUint(p.Get(key), 10, 32) + return uint32(v), err +} + +func (p *Params) Uint64(key string) (uint64, error) { + return strconv.ParseUint(p.Get(key), 10, 64) +} + +func (p *Params) Bool(key string) (bool, error) { + return strconv.ParseBool(p.Get(key)) +} + +func (p *Params) Float32(key string) (float32, error) { + v, err := strconv.ParseFloat(p.Get(key), 32) + return float32(v), err +} + +func (p *Params) Float64(key string) (float64, error) { + return strconv.ParseFloat(p.Get(key), 64) +} + +func (p *Params) MustString(key string, defaults ...string) string { + if len(key) == 0 { + return "" + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + for _, v := range *p { + if v.Name == key { + return v.Value + } + } + if len(defaults) > 0 { + return defaults[0] + } + return "" +} + +func (p *Params) MustStrings(key string, defaults ...[]string) []string { + if len(key) == 0 { + return []string{} + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + var s = make([]string, 0) + for _, v := range *p { + if v.Name == key { + s = append(s, v.Value) + } + } + if len(s) > 0 { + return s + } + if len(defaults) > 0 { + return defaults[0] + } + return []string{} +} + +func (p *Params) MustEscape(key string, defaults ...string) string { + if len(key) == 0 { + return "" + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + for _, v := range *p { + if v.Name == key { + return template.HTMLEscapeString(v.Value) + } + } + if len(defaults) > 0 { + return defaults[0] + } + return "" +} + +func (p *Params) MustInt(key string, defaults ...int) int { + v, err := strconv.Atoi(p.Get(key)) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return v +} + +func (p *Params) MustInt32(key string, defaults ...int32) int32 { + r, err := strconv.ParseInt(p.Get(key), 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + + return int32(r) +} + +func (p *Params) MustInt64(key string, defaults ...int64) int64 { + r, err := strconv.ParseInt(p.Get(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return r +} + +func (p *Params) MustUint(key string, defaults ...uint) uint { + v, err := strconv.ParseUint(p.Get(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return uint(v) +} + +func (p *Params) MustUint32(key string, defaults ...uint32) uint32 { + r, err := strconv.ParseUint(p.Get(key), 10, 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + + return uint32(r) +} + +func (p *Params) MustUint64(key string, defaults ...uint64) uint64 { + r, err := strconv.ParseUint(p.Get(key), 10, 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return r +} + +func (p *Params) MustFloat32(key string, defaults ...float32) float32 { + r, err := strconv.ParseFloat(p.Get(key), 32) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return float32(r) +} + +func (p *Params) MustFloat64(key string, defaults ...float64) float64 { + r, err := strconv.ParseFloat(p.Get(key), 64) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return r +} + +func (p *Params) MustBool(key string, defaults ...bool) bool { + r, err := strconv.ParseBool(p.Get(key)) + if len(defaults) > 0 && err != nil { + return defaults[0] + } + return r +} + +func (ctx *Context) Param(key string, defaults ...string) string { + return ctx.Params().MustString(key, defaults...) +} + +func (ctx *Context) ParamStrings(key string, defaults ...[]string) []string { + return ctx.Params().MustStrings(key, defaults...) +} + +func (ctx *Context) ParamEscape(key string, defaults ...string) string { + return ctx.Params().MustEscape(key, defaults...) +} + +func (ctx *Context) ParamInt(key string, defaults ...int) int { + return ctx.Params().MustInt(key, defaults...) +} + +func (ctx *Context) ParamInt32(key string, defaults ...int32) int32 { + return ctx.Params().MustInt32(key, defaults...) +} + +func (ctx *Context) ParamInt64(key string, defaults ...int64) int64 { + return ctx.Params().MustInt64(key, defaults...) +} + +func (ctx *Context) ParamUint(key string, defaults ...uint) uint { + return ctx.Params().MustUint(key, defaults...) +} + +func (ctx *Context) ParamUint32(key string, defaults ...uint32) uint32 { + return ctx.Params().MustUint32(key, defaults...) +} + +func (ctx *Context) ParamUint64(key string, defaults ...uint64) uint64 { + return ctx.Params().MustUint64(key, defaults...) +} + +func (ctx *Context) ParamFloat32(key string, defaults ...float32) float32 { + return ctx.Params().MustFloat32(key, defaults...) +} + +func (ctx *Context) ParamFloat64(key string, defaults ...float64) float64 { + return ctx.Params().MustFloat64(key, defaults...) +} + +func (ctx *Context) ParamBool(key string, defaults ...bool) bool { + return ctx.Params().MustBool(key, defaults...) +} + +func (p *Params) Set(key, value string) { + if len(key) == 0 { + return + } + if key[0] != ':' && key[0] != '*' { + key = ":"+ key + } + + for i, v := range *p { + if v.Name == key { + (*p)[i].Value = value + return + } + } + + *p = append(*p, param{key, value}) +} + +type Paramer interface { + SetParams([]param) +} + +func (p *Params) SetParams(params []param) { + *p = params +} + +func Param() HandlerFunc { + return func(ctx *Context) { + if action := ctx.Action(); action != nil { + if p, ok := action.(Paramer); ok { + p.SetParams(*ctx.Params()) + } + } + ctx.Next() + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/param_test.go b/Godeps/_workspace/src/github.com/lunny/tango/param_test.go new file mode 100644 index 0000000..67c0f78 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/param_test.go @@ -0,0 +1,1008 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +type ParamAction struct { + Params +} + +func (p *ParamAction) Get() string { + return p.Params.Get(":name") +} + +func TestParams1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/:name", new(ParamAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "foobar") +} + +type Param2Action struct { + Params +} + +func (p *Param2Action) Get() string { + return p.Params.Get(":name") +} + +func TestParams2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name.*)", new(Param2Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "foobar") +} + +type Param3Action struct { + Ctx +} + +func (p *Param3Action) Get() string { + fmt.Println(p.params) + p.Params().Set(":name", "name") + fmt.Println(p.params) + return p.Params().Get(":name") +} + +func TestParams3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name.*)", new(Param3Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "name") +} + +type Param4Action struct { + Params +} + +func (p *Param4Action) Get() string { + fmt.Println(p.Params) + p.Params.Set(":name", "name") + fmt.Println(p.Params) + return p.Params.Get(":name") +} + +func TestParams4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name.*)", new(Param4Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "name") +} + +func TestParams5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) { + ctx.Params().Set(":name", "test") + ctx.Write([]byte(ctx.Params().Get(":name"))) + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + expect(t, buff.String(), "test") +} + +func TestParams6(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func(ctx *Context) { + ctx.Write([]byte(ctx.Params().Get(":name"))) + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + expect(t, buff.String(), "") +} + +type Param5Action struct { + Params +} + +func (p *Param5Action) Get() string { + i, _ := p.Params.Int(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams7(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param5Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param6Action struct { + Params +} + +func (p *Param6Action) Get() string { + i, _ := p.Params.Int32(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams8(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param6Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param7Action struct { + Params +} + +func (p *Param7Action) Get() string { + i, _ := p.Params.Int64(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams9(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param7Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param8Action struct { + Params +} + +func (p *Param8Action) Get() string { + i, _ := p.Params.Float32(":name") + return fmt.Sprintf("%.2f", i) +} + +func TestParams10(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param8Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Param9Action struct { + Params +} + +func (p *Param9Action) Get() string { + i, _ := p.Params.Float64(":name") + return fmt.Sprintf("%.2f", i) +} + +func TestParams11(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param9Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Param10Action struct { + Params +} + +func (p *Param10Action) Get() string { + i, _ := p.Params.Uint(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams12(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param10Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param11Action struct { + Params +} + +func (p *Param11Action) Get() string { + i, _ := p.Params.Uint32(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams13(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param11Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param12Action struct { + Params +} + +func (p *Param12Action) Get() string { + i, _ := p.Params.Uint64(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams14(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param12Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param13Action struct { + Params +} + +func (p *Param13Action) Get() string { + i, _ := p.Params.Bool(":name") + return fmt.Sprintf("%v", i) +} + +func TestParams15(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param13Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Param14Action struct { + Params +} + +func (p *Param14Action) Get() string { + i := p.Params.MustInt(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams16(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param14Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param15Action struct { + Params +} + +func (p *Param15Action) Get() string { + i := p.Params.MustInt32(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams17(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param15Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param16Action struct { + Params +} + +func (p *Param16Action) Get() string { + i := p.Params.MustInt64(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams18(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param16Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param17Action struct { + Params +} + +func (p *Param17Action) Get() string { + i := p.Params.MustFloat32(":name") + return fmt.Sprintf("%.2f", i) +} + +func TestParams19(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param17Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Param18Action struct { + Params +} + +func (p *Param18Action) Get() string { + i := p.Params.MustFloat64(":name") + return fmt.Sprintf("%.2f", i) +} + +func TestParams20(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param18Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Param19Action struct { + Params +} + +func (p *Param19Action) Get() string { + i := p.Params.MustUint(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams21(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param19Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param20Action struct { + Params +} + +func (p *Param20Action) Get() string { + i := p.Params.MustUint32(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams22(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param20Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param21Action struct { + Params +} + +func (p *Param21Action) Get() string { + i := p.Params.MustUint64(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams23(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param21Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param22Action struct { + Params +} + +func (p *Param22Action) Get() string { + i := p.Params.MustBool(":name") + return fmt.Sprintf("%v", i) +} + +func TestParams24(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param22Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Param23Action struct { + Params +} + +func (p *Param23Action) Get() string { + i, _ := p.Params.String(":name") + return fmt.Sprintf("%v", i) +} + +func TestParams25(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param23Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param24Action struct { + Params +} + +func (p *Param24Action) Get() string { + i := p.Params.MustString(":name") + return fmt.Sprintf("%v", i) +} + +func TestParams26(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param24Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param25Action struct { + Ctx +} + +func (p *Param25Action) Get() string { + i := p.ParamInt(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams27(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param25Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param26Action struct { + Ctx +} + +func (p *Param26Action) Get() string { + i := p.ParamInt32(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams28(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param26Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param27Action struct { + Ctx +} + +func (p *Param27Action) Get() string { + i := p.ParamInt64(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams29(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param27Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param28Action struct { + Ctx +} + +func (p *Param28Action) Get() string { + i := p.ParamFloat32(":name") + return fmt.Sprintf("%.2f", i) +} + +func TestParams30(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param28Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Param29Action struct { + Ctx +} + +func (p *Param29Action) Get() string { + i := p.ParamFloat64(":name") + return fmt.Sprintf("%.2f", i) +} + +func TestParams31(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param29Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1.00") +} + +type Param30Action struct { + Ctx +} + +func (p *Param30Action) Get() string { + i := p.ParamUint(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams32(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param30Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param31Action struct { + Ctx +} + +func (p *Param31Action) Get() string { + i := p.ParamUint32(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams33(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param31Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param32Action struct { + Ctx +} + +func (p *Param32Action) Get() string { + i := p.ParamUint64(":name") + return fmt.Sprintf("%d", i) +} + +func TestParams34(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param32Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} + +type Param33Action struct { + Ctx +} + +func (p *Param33Action) Get() string { + i := p.ParamBool(":name") + return fmt.Sprintf("%v", i) +} + +func TestParams35(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param33Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "true") +} + +type Param34Action struct { + Ctx +} + +func (p *Param34Action) Get() string { + i := p.Param(":name") + return fmt.Sprintf("%v", i) +} + +func TestParams36(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[0-9]+)", new(Param34Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/1", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "1") +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/pool.go b/Godeps/_workspace/src/github.com/lunny/tango/pool.go new file mode 100644 index 0000000..5ac7b28 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/pool.go @@ -0,0 +1,41 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "reflect" + "sync" +) + +type pool struct { + size int + tp reflect.Type + pool reflect.Value + cur int + lock sync.Mutex +} + +func newPool(size int, tp reflect.Type) *pool { + return &pool{ + size: size, + cur: 0, + pool: reflect.MakeSlice(reflect.SliceOf(tp), size, size), + tp: reflect.SliceOf(tp), + } +} + +func (p *pool) New() reflect.Value { + p.lock.Lock() + defer func() { + p.cur++ + p.lock.Unlock() + }() + + if p.cur == p.pool.Len() { + p.pool = reflect.MakeSlice(p.tp, p.size, p.size) + p.cur = 0 + } + return p.pool.Index(p.cur).Addr() +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/pool_test.go b/Godeps/_workspace/src/github.com/lunny/tango/pool_test.go new file mode 100644 index 0000000..ef8468b --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/pool_test.go @@ -0,0 +1,47 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +type PoolAction struct { +} + +func (PoolAction) Get() string { + return "pool" +} + +func TestPool(t *testing.T) { + o := Classic() + o.Get("/", new(PoolAction)) + + var wg sync.WaitGroup + // default pool size is 800 + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "pool") + wg.Done() + }() + } + wg.Wait() +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/prefix.go b/Godeps/_workspace/src/github.com/lunny/tango/prefix.go new file mode 100644 index 0000000..a81807e --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/prefix.go @@ -0,0 +1,21 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "strings" +) + +// TODO: regex prefix +func Prefix(prefix string, handler Handler) HandlerFunc { + return func(ctx *Context) { + if strings.HasPrefix(ctx.Req().URL.Path, prefix) { + handler.Handle(ctx) + return + } + + ctx.Next() + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/prefix_test.go b/Godeps/_workspace/src/github.com/lunny/tango/prefix_test.go new file mode 100644 index 0000000..40d7c78 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/prefix_test.go @@ -0,0 +1,67 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" +) + +type PrefixAction struct { +} + +func (PrefixAction) Get() string { + return "Prefix" +} + +type NoPrefixAction struct { +} + +func (NoPrefixAction) Get() string { + return "NoPrefix" +} + +func TestPrefix(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + var isPrefix bool + + o := Classic() + o.Use(Prefix("/prefix", HandlerFunc(func(ctx *Context) { + isPrefix = true + ctx.Next() + }))) + o.Get("/prefix/t", new(PrefixAction)) + o.Get("/t", new(NoPrefixAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/prefix/t", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "Prefix") + expect(t, isPrefix, true) + + isPrefix = false + buff.Reset() + + req, err = http.NewRequest("GET", "http://localhost:8000/t", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "NoPrefix") + expect(t, isPrefix, false) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/public/cert.pem b/Godeps/_workspace/src/github.com/lunny/tango/public/cert.pem new file mode 100644 index 0000000..fa19d32 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/public/cert.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC8TCCAdugAwIBAgIRAIQ6ko7wQnT7LktA0f0iotYwCwYJKoZIhvcNAQELMBIx +EDAOBgNVBAoTB0FjbWUgQ28wHhcNMTUwNDE2MTUxNzQyWhcNMTYwNDE1MTUxNzQy +WjASMRAwDgYDVQQKEwdBY21lIENvMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEAn+nHPLXpiibd8vdizD3DmaRZdicUJ3Ftrjq8UvLWgGUsmYk05bESWT/I +vGRcOpQqhc9ugctZraFK7ihn6c+A7fGOYdXlY/wCNt12evAY8P9uxJQ4lQPY9qUs +hegJyCGD/HtUr/AdbGBzBLtmU5Z7ZOe73TzO6zmiVnLwasO2AoVvUCRCFXQkZkIk +HO/evHq03lo/z0PyBWbT6YKSM9mLFTYKUdAkC/gr5FKcGCqyMsMmewiupPtRSnln +ZrDpzp9+oGl92lkj885714BkROfP2KGTbTSlpbVfmtvcb6TQK8mfithnQzs/9jDw +LGSWBWQEdomTzPmKw07zNHYkXhcpqwIDAQABo0YwRDAOBgNVHQ8BAf8EBAMCAKAw +EwYDVR0lBAwwCgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAPBgNVHREECDAGhwR/ +AAABMAsGCSqGSIb3DQEBCwOCAQEAQSVmRqDi/UXXhF+qD5JxrUNG4ssBvFl91b0n +Po8W4InE4WaNqTfAxnsDwgglNMGijGp30cpp5ZdbpdHnpVEjPG9wTrSJBCoMwcnH +rnDau97FuVLff+UROghwtXHhMt1bdLMdle2rdHpyd69awNBuI4wY+Dg/BmVZq+JM +DTfGd4X5FND72JhOxx62+p2S23jmcb+0vnNvctT88ekaWVOicF8bs1XY0Pm1KzA/ +l5mIUiSl2zR6Met0b+LK3F7d4H+uhuZql0syPZ8HSAJweUAMX3L5KASB5QJFtXwO +DMWBDzqT/nEsT77Cb16tKNhVVeb6U9h2LezxB0M8+NfLL8alxw== +-----END CERTIFICATE----- diff --git a/Godeps/_workspace/src/github.com/lunny/tango/public/index.html b/Godeps/_workspace/src/github.com/lunny/tango/public/index.html new file mode 100644 index 0000000..0b32541 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/public/index.html @@ -0,0 +1 @@ +this is index.html \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/lunny/tango/public/key.pem b/Godeps/_workspace/src/github.com/lunny/tango/public/key.pem new file mode 100644 index 0000000..691f983 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/public/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAn+nHPLXpiibd8vdizD3DmaRZdicUJ3Ftrjq8UvLWgGUsmYk0 +5bESWT/IvGRcOpQqhc9ugctZraFK7ihn6c+A7fGOYdXlY/wCNt12evAY8P9uxJQ4 +lQPY9qUshegJyCGD/HtUr/AdbGBzBLtmU5Z7ZOe73TzO6zmiVnLwasO2AoVvUCRC +FXQkZkIkHO/evHq03lo/z0PyBWbT6YKSM9mLFTYKUdAkC/gr5FKcGCqyMsMmewiu +pPtRSnlnZrDpzp9+oGl92lkj885714BkROfP2KGTbTSlpbVfmtvcb6TQK8mfithn +Qzs/9jDwLGSWBWQEdomTzPmKw07zNHYkXhcpqwIDAQABAoIBAFWkA7m1yr7cFd17 +M4QiR9DOvcKTJy4AhzbZ6eWae9oDVSFc4+FnNWZqzHxoWyRcGXHUJ2CHoR1l1hU5 +unzzTh8gUJqAzPsBCcaMUFmCoDjg81d/8dWMW/Orfe6w2BxAJsle23nl5DwYY0DT +g/ecDbV6jZfsavx6v0ABClSDP8SVDMwpgqLzLgqaKdcLzx1TKNcJXfMAGO8knXcG +YHTmQuh/SVVbNd1yjsFUMfL/wrq2Wd18vWlUV8BmicARw2BxkariUoTOs1Ze1tPj +CvBPJ51dmdz8TEbP+VQpetfw8Sj70GHdUhvysVRf5gaI5BtDgbjTpfBfbaPsVJu4 +biGkqWECgYEAzWKLtY/rVcnIB0M2A5bTjD/97Sfq5Hf7qAbYnBpnw/D2utSdYow3 +2Qo6Ng6TXCGEDhpyWfUy7jjXmBIhurO5GCNdK650JNoRWELE+TWEKg3qfQigAxBI +IRQRiARKWJ0e2we9Bxqe09EMZaA70GClQYB7hWaXZhmf4RjNBhwc07kCgYEAx1J7 +e/2IzMKbmlcCCPG7tilwQdEY+++cpgfzh/Arp/ZLZ6u6qSlye2gZkiRcm6W6H7dS +SX9S+iJKBRTm6ysk0gagAu/x1ScsPrV3yJJ5T5qRG03/ZLlAmtJzc9h04Jqi8ErQ +qFCazvgZ+g2Ls4n+fPeA8Y9SNpxHKJKhar2aYoMCgYABwUXQV1p7cS30Ye6kOTW1 +jRZuYFjxetT7qpNPQiqA0h5Jmmd94BTaFexJafZ4YxDtzewMOLwmrPWqpv0Cy2ZZ +fnPdW7BCYFqllmx4dKycb2IBj4FOhWUYY0ODFgZMm4sX9Aj5dpDE3pRsieH49dpz +pNVpXmcMyEtFcSDPXI4igQKBgDaffAe2q06x5kKdpYkd9fstz/25d8dTGvLFKxAN +2WjmLjPy8+x310/Kb3eFT3u4JxGaA4rwwaSa0P4jhETeRfDor+EeMH/hhFaLFJB6 +05PlH+8DqQHJYtMK6WjN4PnMZurDFfuKW2Jsy3GjVK2XG47TpRqN1FHy8e1Egcfm +vfBRAoGAaMKW2EYGI98m/0UhnaBf5mCqcewpU3Cusht3+9orC2PFtiQwF2JkBfF/ +zcSgoO9VsTqpVWVmUI5LaVWG/1CscSIdl4tAVxkXiUbSniufYZko9xUXPDdB3a6/ +H4vpV8K07EtKToJnA+qZC/szyTwFZx3yuqtXAEtyW5hhZ873dbI= +-----END RSA PRIVATE KEY----- diff --git a/Godeps/_workspace/src/github.com/lunny/tango/public/test.html b/Godeps/_workspace/src/github.com/lunny/tango/public/test.html new file mode 100644 index 0000000..14920d0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/public/test.html @@ -0,0 +1 @@ +hello tango \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/lunny/tango/recovery.go b/Godeps/_workspace/src/github.com/lunny/tango/recovery.go new file mode 100644 index 0000000..b9c8d32 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/recovery.go @@ -0,0 +1,41 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "fmt" + "net/http" + "runtime" +) + +func Recovery(debug bool) HandlerFunc { + return func(ctx *Context) { + defer func() { + if e := recover(); e != nil { + content := fmt.Sprintf("Handler crashed with error: %v", e) + for i := 1; ; i += 1 { + _, file, line, ok := runtime.Caller(i) + if !ok { + break + } else { + content += "\n" + } + content += fmt.Sprintf("%v %v", file, line) + } + + ctx.Logger.Error(content) + + if !ctx.Written() { + if !debug { + content = http.StatusText(http.StatusInternalServerError) + } + ctx.Result = InternalServerError(content) + } + } + }() + + ctx.Next() + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/recovery_test.go b/Godeps/_workspace/src/github.com/lunny/tango/recovery_test.go new file mode 100644 index 0000000..fa7a1cf --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/recovery_test.go @@ -0,0 +1,35 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestRecovery(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + n := NewWithLog(NewLogger(os.Stdout)) + n.Use(Recovery(true)) + n.UseHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + panic("here is a panic!") + })) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + n.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusInternalServerError) + refute(t, recorder.Body.Len(), 0) + refute(t, len(buff.String()), 0) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/response.go b/Godeps/_workspace/src/github.com/lunny/tango/response.go new file mode 100644 index 0000000..2c1b952 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/response.go @@ -0,0 +1,85 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bufio" + "fmt" + "net" + "net/http" +) + +// ResponseWriter is a wrapper around http.ResponseWriter that provides extra information about +// the response. It is recommended that middleware handlers use this construct to wrap a responsewriter +// if the functionality calls for it. +type ResponseWriter interface { + http.ResponseWriter + http.Flusher + http.Hijacker + // Status returns the status code of the response or 0 if the response has not been written. + Status() int + // Written returns whether or not the ResponseWriter has been written. + Written() bool + // Size returns the size of the response body. + Size() int +} + +type responseWriter struct { + http.ResponseWriter + status int + size int +} + +func (rw *responseWriter) reset(w http.ResponseWriter) { + rw.ResponseWriter = w + rw.status = 0 + rw.size = 0 +} + +func (rw *responseWriter) WriteHeader(s int) { + rw.status = s + rw.ResponseWriter.WriteHeader(s) +} + +func (rw *responseWriter) Write(b []byte) (int, error) { + if !rw.Written() { + // The status will be StatusOK if WriteHeader has not been called yet + rw.WriteHeader(http.StatusOK) + } + size, err := rw.ResponseWriter.Write(b) + rw.size += size + return size, err +} + +func (rw *responseWriter) Status() int { + return rw.status +} + +func (rw *responseWriter) Size() int { + return rw.size +} + +func (rw *responseWriter) Written() bool { + return rw.status != 0 +} + +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := rw.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("the ResponseWriter doesn't support the Hijacker interface") + } + return hijacker.Hijack() +} + +func (rw *responseWriter) CloseNotify() <-chan bool { + return rw.ResponseWriter.(http.CloseNotifier).CloseNotify() +} + +func (rw *responseWriter) Flush() { + flusher, ok := rw.ResponseWriter.(http.Flusher) + if ok { + flusher.Flush() + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/return.go b/Godeps/_workspace/src/github.com/lunny/tango/return.go new file mode 100644 index 0000000..0f9f58d --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/return.go @@ -0,0 +1,165 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "encoding/json" + "encoding/xml" + "net/http" + "reflect" +) + +type StatusResult struct { + Code int + Result interface{} +} + +const ( + AutoResponse = iota + JsonResponse + XmlResponse +) + +type ResponseTyper interface { + ResponseType() int +} + +type Json struct{} + +func (Json) ResponseType() int { + return JsonResponse +} + +type Xml struct{} + +func (Xml) ResponseType() int { + return XmlResponse +} + +func isNil(a interface{}) bool { + if a == nil { + return true + } + aa := reflect.ValueOf(a) + return !aa.IsValid() || (aa.Type().Kind() == reflect.Ptr && aa.IsNil()) +} + +type XmlError struct { + XMLName xml.Name `xml:"err"` + Content string `xml:"content"` +} + +type XmlString struct { + XMLName xml.Name `xml:"string"` + Content string `xml:"content"` +} + +func Return() HandlerFunc { + return func(ctx *Context) { + var rt int + action := ctx.Action() + if action != nil { + if i, ok := action.(ResponseTyper); ok { + rt = i.ResponseType() + } + } + + ctx.Next() + + // if no route match or has been write, then return + if action == nil || ctx.Written() { + return + } + + // if there is no return value or return nil + if isNil(ctx.Result) { + // then we return blank page + ctx.Result = "" + } + + var result = ctx.Result + if res, ok := ctx.Result.(*StatusResult); ok { + ctx.WriteHeader(res.Code) + result = res.Result + } + + if rt == JsonResponse { + encoder := json.NewEncoder(ctx) + ctx.Header().Set("Content-Type", "application/json; charset=UTF-8") + + switch res := result.(type) { + case AbortError: + ctx.WriteHeader(res.Code()) + encoder.Encode(map[string]string{ + "err": res.Error(), + }) + case error: + encoder.Encode(map[string]string{ + "err": res.Error(), + }) + case string: + encoder.Encode(map[string]string{ + "content": res, + }) + case []byte: + encoder.Encode(map[string]string{ + "content": string(res), + }) + default: + err := encoder.Encode(result) + if err != nil { + ctx.Result = err + encoder.Encode(map[string]string{ + "err": err.Error(), + }) + } + } + + return + } else if rt == XmlResponse { + encoder := xml.NewEncoder(ctx) + ctx.Header().Set("Content-Type", "application/xml; charset=UTF-8") + switch res := result.(type) { + case AbortError: + ctx.WriteHeader(res.Code()) + encoder.Encode(XmlError{ + Content: res.Error(), + }) + case error: + encoder.Encode(XmlError{ + Content: res.Error(), + }) + case string: + encoder.Encode(XmlString{ + Content: res, + }) + case []byte: + encoder.Encode(XmlString{ + Content: string(res), + }) + default: + err := encoder.Encode(result) + if err != nil { + ctx.Result = err + encoder.Encode(XmlError{ + Content: err.Error(), + }) + } + } + return + } + + switch res := result.(type) { + case AbortError, error: + ctx.HandleError() + case []byte: + ctx.WriteHeader(http.StatusOK) + ctx.Write(res) + case string: + ctx.WriteHeader(http.StatusOK) + ctx.Write([]byte(res)) + } + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/return_test.go b/Godeps/_workspace/src/github.com/lunny/tango/return_test.go new file mode 100644 index 0000000..ffd65e5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/return_test.go @@ -0,0 +1,322 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "encoding/xml" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type MyReturn struct { +} + +func (m MyReturn) Get() string { + return "string return" +} + +func (m MyReturn) Post() []byte { + return []byte("bytes return") +} + +func (m MyReturn) Put() error { + return errors.New("error return") +} + +func TestReturn(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/", new(MyReturn)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "string return") + + buff.Reset() + req, err = http.NewRequest("POST", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "bytes return") +} + +func TestReturnPut(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Any("/", new(MyReturn)) + + req, err := http.NewRequest("PUT", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusInternalServerError) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "error return") +} + +type JsonReturn struct { + Json +} + +func (JsonReturn) Get() interface{} { + return map[string]interface{}{ + "test1": 1, + "test2": "2", + "test3": true, + } +} + +func TestReturnJson1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(JsonReturn)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `{"test1":1,"test2":"2","test3":true}`) +} + +type JsonErrReturn struct { + Json +} + +func (JsonErrReturn) Get() error { + return errors.New("error") +} + +func TestReturnJsonError(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(JsonErrReturn)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `{"err":"error"}`) +} + +type JsonErrReturn2 struct { + Json +} + +func (JsonErrReturn2) Get() error { + return Abort(http.StatusInternalServerError, "error") +} + +func TestReturnJsonError2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(JsonErrReturn2)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusInternalServerError) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `{"err":"error"}`) +} + +type JsonReturn1 struct { + Json +} + +func (JsonReturn1) Get() string { + return "return" +} + +func TestReturnJson2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(JsonReturn1)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `{"content":"return"}`) +} + +type JsonReturn2 struct { + Json +} + +func (JsonReturn2) Get() []byte { + return []byte("return") +} + +func TestReturnJson3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(JsonReturn2)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `{"content":"return"}`) +} + +type JsonReturn3 struct { + Json +} + +func (JsonReturn3) Get() (int, interface{}) { + if true { + return 201, map[string]string{ + "say": "Hello tango!", + } + } + return 500, errors.New("something error") +} + +func TestReturnJson4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(JsonReturn3)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, 201) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `{"say":"Hello tango!"}`) +} + +type XmlReturn struct { + Xml +} + +type Address struct { + City, State string +} +type Person struct { + XMLName xml.Name `xml:"person"` + Id int `xml:"id,attr"` + FirstName string `xml:"name>first"` + LastName string `xml:"name>last"` + Age int `xml:"age"` + Height float32 `xml:"height,omitempty"` + Married bool + Address + Comment string `xml:",comment"` +} + +func (XmlReturn) Get() interface{} { + v := &Person{Id: 13, FirstName: "John", LastName: "Doe", Age: 42} + v.Comment = " Need more details. " + v.Address = Address{"Hanga Roa", "Easter Island"} + return v +} + +func TestReturnXml(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(XmlReturn)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), `JohnDoe42falseHanga RoaEaster Island`) +} + +type XmlErrReturn struct { + Xml +} + +func (XmlErrReturn) Get() error { + return errors.New("error") +} + +func TestReturnXmlError(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", new(XmlErrReturn)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, strings.TrimSpace(buff.String()), `error`) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/router.go b/Godeps/_workspace/src/github.com/lunny/tango/router.go new file mode 100644 index 0000000..4f1516b --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/router.go @@ -0,0 +1,533 @@ +package tango + +import ( + "bytes" + "fmt" + "net/http" + "reflect" + "regexp" + "sort" + "strings" +) + +type RouteType byte + +const ( + FuncRoute RouteType = iota + 1 // func () + FuncHttpRoute // func (http.ResponseWriter, *http.Request) + FuncReqRoute // func (*http.Request) + FuncResponseRoute // func (http.ResponseWriter) + FuncCtxRoute // func (*tango.Context) + StructRoute // func (st) () + StructPtrRoute // func (*struct) () +) + +var ( + SupportMethods = []string{ + "GET", + "POST", + "HEAD", + "DELETE", + "PUT", + "OPTIONS", + "TRACE", + "PATCH", + } + + PoolSize = 800 +) + +// Route +type Route struct { + raw interface{} + method reflect.Value + routeType RouteType + pool *pool +} + +func NewRoute(v interface{}, t reflect.Type, + method reflect.Value, tp RouteType) *Route { + var pool *pool + if tp == StructRoute || tp == StructPtrRoute { + pool = newPool(PoolSize, t) + } + return &Route{ + raw: v, + routeType: tp, + method: method, + pool: pool, + } +} + +func (r *Route) Raw() interface{} { + return r.raw +} + +func (r *Route) Method() reflect.Value { + return r.method +} + +func (r *Route) RouteType() RouteType { + return r.routeType +} + +func (r *Route) IsStruct() bool { + return r.routeType == StructRoute || r.routeType == StructPtrRoute +} + +func (r *Route) newAction() reflect.Value { + if !r.IsStruct() { + return r.method + } + + return r.pool.New() +} + +type Router interface { + Route(methods interface{}, path string, handler interface{}) + Match(requestPath, method string) (*Route, Params) +} + +var specialBytes = []byte(`.\+*?|[]{}^$`) + +func isSpecial(ch byte) bool { + return bytes.IndexByte(specialBytes, ch) > -1 +} + +func isAlpha(ch byte) bool { + return 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z' || ch == '_' +} + +func isDigit(ch byte) bool { + return '0' <= ch && ch <= '9' +} + +func isAlnum(ch byte) bool { + return isAlpha(ch) || isDigit(ch) +} + +type ( + router struct { + trees map[string]*node + } + ntype byte + node struct { + tp ntype // Type of node it contains + handle *Route // executor + regexp *regexp.Regexp // regexp if tp is rnode + content string // static content or named + edges edges // children + path string // executor path + } + edges []*node +) + +func (e edges) Len() int { return len(e) } + +func (e edges) Swap(i, j int) { e[i], e[j] = e[j], e[i] } + +// static route will be put the first, so it will be match first. +// two static route, content longer is first. +func (e edges) Less(i, j int) bool { + if e[i].tp == snode { + if e[j].tp == snode { + return len(e[i].content) > len(e[j].content) + } + return true + } + if e[j].tp == snode { + return false + } + return i < j +} + +const ( + snode ntype = iota // static, should equal + nnode // named node, match a non-/ is ok + anode // catch-all node, match any + rnode // regex node, should match +) + +func (n *node) equal(o *node) bool { + if n.tp != o.tp || n.content != o.content { + return false + } + return true +} + +func NewRouter() (r *router) { + r = &router{ + trees: make(map[string]*node), + } + for _, m := range SupportMethods { + r.trees[m] = &node{ + edges: edges{}, + } + } + return +} + +// /:name1/:name2 /:name1-:name2 /(:name1)sss(:name2) +// /(*name) /(:name[0-9]+) /(:name[a-z]+) +func parseNodes(path string) []*node { + var i, j int + l := len(path) + var nodes = make([]*node, 0) + var bracket int + for ; i < l; i++ { + if path[i] == ':' { + nodes = append(nodes, &node{tp: snode, content: path[j : i-bracket]}) + j = i + var regex string + if bracket == 1 { + var start = -1 + for ; i < l && ')' != path[i]; i++ { + if start == -1 && isSpecial(path[i]) { + start = i + } + } + if path[i] != ')' { + panic("lack of )") + } + if start > -1 { + regex = path[start:i] + } + } else { + i = i + 1 + for ; i < l && isAlnum(path[i]); i++ { + } + } + + if len(regex) > 0 { + nodes = append(nodes, &node{tp: rnode, + regexp: regexp.MustCompile("(" + regex + ")"), + content: path[j : i-len(regex)]}) + } else { + nodes = append(nodes, &node{tp: nnode, content: path[j:i]}) + } + i = i + bracket + j = i + bracket = 0 + if i == l { + return nodes + } + } else if path[i] == '*' { + nodes = append(nodes, &node{tp: snode, content: path[j : i-bracket]}) + j = i + if bracket == 1 { + for ; i < l && ')' != path[i]; i++ { + } + } else { + i = i + 1 + for ; i < l && isAlnum(path[i]); i++ { + } + } + nodes = append(nodes, &node{tp: anode, content: path[j:i]}) + i = i + bracket + bracket = 0 + j = i + if i == l { + return nodes + } + } else if path[i] == '(' { + bracket = 1 + } else if path[i] == '/' { + if bracket == 0 && i > j { + nodes = append(nodes, &node{tp: snode, content: path[j:i]}) + j = i + } + } else { + bracket = 0 + } + } + + nodes = append(nodes, &node{ + tp: snode, + content: path[j:i], + }) + + return nodes +} + +func printNode(i int, node *node) { + for _, c := range node.edges { + for j := 0; j < i; j++ { + fmt.Print(" ") + } + if i > 1 { + fmt.Print("┗", " ") + } + + fmt.Print(c.content) + if c.handle != nil { + fmt.Print(" ", c.handle.method.Type()) + fmt.Printf(" %p", c.handle.method.Interface()) + } + fmt.Println() + printNode(i+1, c) + } +} + +func (r *router) printTrees() { + for _, method := range SupportMethods { + if len(r.trees[method].edges) > 0 { + fmt.Println(method) + printNode(1, r.trees[method]) + fmt.Println() + } + } +} + +func (r *router) addRoute(method, path string, h *Route) { + nodes := parseNodes(path) + nodes[len(nodes)-1].handle = h + nodes[len(nodes)-1].path = path + if !validNodes(nodes) { + panic(fmt.Sprintln("express", path, "is not supported")) + } + r.addnodes(method, nodes) + //r.printTrees() +} + +func (r *router) matchNode(n *node, url string, params Params) (*node, Params) { + if n.tp == snode { + if strings.HasPrefix(url, n.content) { + if len(url) == len(n.content) { + return n, params + } + for _, c := range n.edges { + e, newParams := r.matchNode(c, url[len(n.content):], params) + if e != nil { + return e, newParams + } + } + } + } else if n.tp == anode { + for _, c := range n.edges { + idx := strings.LastIndex(url, c.content) + if idx > -1 { + params = append(params, param{n.content, url[:idx]}) + return r.matchNode(c, url[idx:], params) + } + } + return n, append(params, param{n.content, url}) + } else if n.tp == nnode { + idx := strings.IndexByte(url, '/') + if idx > -1 { + for _, c := range n.edges { + h, newParams := r.matchNode(c, url[idx:], params) + if h != nil { + return h, append([]param{param{n.content, url[:idx]}}, newParams...) + } + } + return nil, params + } + + for _, c := range n.edges { + idx := strings.Index(url, c.content) + if idx > -1 { + params = append(params, param{n.content, url[:idx]}) + return r.matchNode(c, url[idx:], params) + } + } + params = append(params, param{n.content, url}) + return n, params + } else if n.tp == rnode { + idx := strings.IndexByte(url, '/') + if idx > -1 { + if n.regexp.MatchString(url[:idx]) { + for _, c := range n.edges { + h, newParams := r.matchNode(c, url[idx:], params) + if h != nil { + return h, append([]param{param{n.content, url[:idx]}}, newParams...) + } + } + } + return nil, params + } + + for _, c := range n.edges { + idx := strings.Index(url, c.content) + if idx > -1 && n.regexp.MatchString(url[:idx]) { + params = append(params, param{n.content, url[:idx]}) + return r.matchNode(c, url[idx:], params) + } + } + + if n.regexp.MatchString(url) { + params = append(params, param{n.content, url}) + return n, params + } + } + return nil, params +} + +func (r *router) Match(url, method string) (*Route, Params) { + cn := r.trees[method] + var params = make(Params, 0, strings.Count(url, "/")) + for _, n := range cn.edges { + e, newParams := r.matchNode(n, url, params) + if e != nil { + return e.handle, newParams + } + } + return nil, nil +} + +// add node nodes[i] to parent node p +func (r *router) addnode(p *node, nodes []*node, i int) *node { + if len(p.edges) == 0 { + p.edges = make([]*node, 0) + } + + for _, pc := range p.edges { + if pc.equal(nodes[i]) { + if i == len(nodes)-1 { + pc.handle = nodes[i].handle + } + return pc + } + } + + p.edges = append(p.edges, nodes[i]) + sort.Sort(p.edges) + return nodes[i] +} + +// validate parsed nodes, all non-static route should have static route children. +func validNodes(nodes []*node) bool { + if len(nodes) == 0 { + return false + } + var lastTp = nodes[0] + for _, node := range nodes[1:] { + if lastTp.tp != snode && node.tp != snode { + return false + } + lastTp = node + } + return true +} + +// add nodes to trees +func (r *router) addnodes(method string, nodes []*node) { + cn := r.trees[method] + var p *node = cn + + for i, _ := range nodes { + p = r.addnode(p, nodes, i) + } +} + +func removeStick(uri string) string { + uri = strings.TrimRight(uri, "/") + if uri == "" { + uri = "/" + } + return uri +} + +func (router *router) Route(ms interface{}, url string, c interface{}) { + vc := reflect.ValueOf(c) + if vc.Kind() == reflect.Func { + switch ms.(type) { + case string: + router.addFunc([]string{ms.(string)}, url, c) + case []string: + router.addFunc(ms.([]string), url, c) + default: + panic("unknow methods format") + } + } else if vc.Kind() == reflect.Ptr && vc.Elem().Kind() == reflect.Struct { + var methods = make(map[string]string) + switch ms.(type) { + case string: + s := strings.Split(ms.(string), ":") + if len(s) == 1 { + methods[s[0]] = strings.Title(strings.ToLower(s[0])) + } else if len(s) == 2 { + methods[s[0]] = strings.TrimSpace(s[1]) + } else { + panic("unknow methods format") + } + case []string: + for _, m := range ms.([]string) { + s := strings.Split(m, ":") + if len(s) == 1 { + methods[s[0]] = strings.Title(strings.ToLower(s[0])) + } else if len(s) == 2 { + methods[s[0]] = strings.TrimSpace(s[1]) + } else { + panic("unknow format") + } + } + case map[string]string: + methods = ms.(map[string]string) + default: + panic("unsupported methods") + } + + router.addStruct(methods, url, c) + } else { + panic("not support route type") + } +} + +/* + Tango supports 5 form funcs + + func() + func(*Context) + func(http.ResponseWriter, *http.Request) + func(http.ResponseWriter) + func(*http.Request) + + it can has or has not return value +*/ +func (router *router) addFunc(methods []string, url string, c interface{}) { + vc := reflect.ValueOf(c) + t := vc.Type() + var rt RouteType + + if t.NumIn() == 0 { + rt = FuncRoute + } else if t.NumIn() == 1 { + if t.In(0) == reflect.TypeOf(new(Context)) { + rt = FuncCtxRoute + } else if t.In(0) == reflect.TypeOf(new(http.Request)) { + rt = FuncReqRoute + } else if t.In(0).Kind() == reflect.Interface && t.In(0).Name() == "ResponseWriter" && + t.In(0).PkgPath() == "net/http" { + rt = FuncResponseRoute + } else { + panic(fmt.Sprintln("no support function type", methods, url, c)) + } + } else if t.NumIn() == 2 && + (t.In(0).Kind() == reflect.Interface && t.In(0).Name() == "ResponseWriter" && + t.In(0).PkgPath() == "net/http") && + t.In(1) == reflect.TypeOf(new(http.Request)) { + rt = FuncHttpRoute + } else { + panic(fmt.Sprintln("no support function type", methods, url, c)) + } + + var r = NewRoute(c, t, vc, rt) + url = removeStick(url) + for _, m := range methods { + router.addRoute(m, url, r) + } +} + +func (router *router) addStruct(methods map[string]string, url string, c interface{}) { + vc := reflect.ValueOf(c) + t := vc.Type().Elem() + + // added a default method Get, Post + for name, method := range methods { + if m, ok := t.MethodByName(method); ok { + router.addRoute(name, removeStick(url), NewRoute(c, t, m.Func, StructPtrRoute)) + } else if m, ok := vc.Type().MethodByName(method); ok { + router.addRoute(name, removeStick(url), NewRoute(c, t, m.Func, StructRoute)) + } + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/router_test.go b/Godeps/_workspace/src/github.com/lunny/tango/router_test.go new file mode 100644 index 0000000..c7026e9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/router_test.go @@ -0,0 +1,737 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "regexp" + "testing" +) + +type RouterNoMethodAction struct { +} + +func TestRouter1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/:name", new(RouterNoMethodAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + refute(t, len(buff.String()), 0) +} + +type RouterGetAction struct { +} + +func (a *RouterGetAction) Get() string { + return "get" +} + +func (a *RouterGetAction) Post() string { + return "post" +} + +func TestRouter2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/:name", new(RouterGetAction)) + o.Post("/:name", new(RouterGetAction)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "get") + + buff.Reset() + + req, err = http.NewRequest("POST", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "post") +} + +type RouterSpecAction struct { + a string +} + +func (RouterSpecAction) Method1() string { + return "1" +} + +func (r *RouterSpecAction) Method2() string { + return r.a +} + +func TestRouterFunc(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/", func() string { + return "func" + }) + o.Post("/", func(ctx *Context) { + ctx.Write([]byte("func(*Context)")) + }) + o.Put("/", func(resp http.ResponseWriter, req *http.Request) { + resp.Write([]byte("func(http.ResponseWriter, *http.Request)")) + }) + o.Options("/", func(resp http.ResponseWriter) { + resp.Write([]byte("func(http.ResponseWriter)")) + }) + o.Delete("/", func(req *http.Request) string { + return "func(*http.Request)" + }) + + // plain + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "func") + + // context + buff.Reset() + + req, err = http.NewRequest("POST", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "func(*Context)") + + // http + buff.Reset() + + req, err = http.NewRequest("PUT", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "func(http.ResponseWriter, *http.Request)") + + // response + buff.Reset() + + req, err = http.NewRequest("OPTIONS", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "func(http.ResponseWriter)") + + // req + buff.Reset() + + req, err = http.NewRequest("DELETE", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "func(*http.Request)") +} + +type Router4Action struct { + Params +} + +func (r *Router4Action) Get() string { + return r.Params.Get(":name1") + "-" + r.Params.Get(":name2") +} + +func TestRouter4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/:name1-:name2", new(Router4Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar-foobar2", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "foobar-foobar2") +} + +type Router5Action struct { +} + +func (r *Router5Action) Get() string { + return "router5" +} + +func TestRouter5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Route("GET", "/", new(Router5Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "router5") +} + +type Router6Action struct { +} + +func (r *Router6Action) MyMethod() string { + return "router6" +} + +func TestRouter6(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Route("GET:MyMethod", "/", new(Router6Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "router6") +} + +type Router7Action struct { +} + +func (r *Router7Action) MyGet() string { + return "router7-get" +} + +func (r *Router7Action) Post() string { + return "router7-post" +} + +func TestRouter7(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Route([]string{"GET:MyGet", "POST"}, "/", new(Router7Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "router7-get") + + buff.Reset() + + req, err = http.NewRequest("POST", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "router7-post") +} + +type Router8Action struct { +} + +func (r *Router8Action) MyGet() string { + return "router8-get" +} + +func (r *Router8Action) Post() string { + return "router8-post" +} + +func TestRouter8(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Route(map[string]string{"GET": "MyGet", "POST": "Post"}, "/", new(Router8Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "router8-get") + + buff.Reset() + + req, err = http.NewRequest("POST", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "router8-post") +} + +type Regex1Action struct { + Params +} + +func (r *Regex1Action) Get() string { + return r.Params.Get(":name") +} + +func TestRouter9(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + o := Classic() + o.Get("/(:name[a-zA-Z]+)", new(Regex1Action)) + + req, err := http.NewRequest("GET", "http://localhost:8000/foobar", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + expect(t, buff.String(), "foobar") + refute(t, len(buff.String()), 0) +} + +var ( + parsedResult = map[string][]*node{ + "/": []*node{ + {content: "/", tp: snode}, + }, + "/static/css/bootstrap.css": []*node{ + {content: "/static", tp: snode}, + {content: "/css", tp: snode}, + {content: "/bootstrap.css", tp: snode}, + }, + "/:name": []*node{ + {content: "/", tp: snode}, + {content: ":name", tp: nnode}, + }, + "/sss:name": []*node{ + {content: "/sss", tp: snode}, + {content: ":name", tp: nnode}, + }, + "/(:name)": []*node{ + {content: "/", tp: snode}, + {content: ":name", tp: nnode}, + }, + "/(:name)/sss": []*node{ + {content: "/", tp: snode}, + {content: ":name", tp: nnode}, + {content: "/sss", tp: snode}, + }, + "/:name-:value": []*node{ + {content: "/", tp: snode}, + {content: ":name", tp: nnode}, + {content: "-", tp: snode}, + {content: ":value", tp: nnode}, + }, + "/(:name)ssss(:value)": []*node{ + {content: "/", tp: snode}, + {content: ":name", tp: nnode}, + {content: "ssss", tp: snode}, + {content: ":value", tp: nnode}, + }, + "/(:name[0-9]+)": []*node{ + {content: "/", tp: snode}, + {content: ":name", tp: rnode, regexp: regexp.MustCompile("([0-9]+)")}, + }, + "/*name": []*node{ + {content: "/", tp: snode}, + {content: "*name", tp: anode}, + }, + "/*name/ssss": []*node{ + {content: "/", tp: snode}, + {content: "*name", tp: anode}, + {content: "/ssss", tp: snode}, + }, + "/(*name)ssss": []*node{ + {content: "/", tp: snode}, + {content: "*name", tp: anode}, + {content: "ssss", tp: snode}, + }, + "/:name-(:name2[a-z]+)": []*node{ + {content: "/", tp: snode}, + {content: ":name", tp: nnode}, + {content: "-", tp: snode}, + {content: ":name2", tp: rnode, regexp: regexp.MustCompile("([a-z]+)")}, + }, + } +) + +func TestParseNode(t *testing.T) { + for p, r := range parsedResult { + res := parseNodes(p) + if len(r) != len(res) { + t.Fatalf("%v 's result %v is not equal %v", p, r, res) + } + for i := 0; i < len(r); i++ { + if r[i].content != res[i].content || + r[i].tp != res[i].tp { + t.Fatalf("%v 's %d result %v is not equal %v, %v", p, i, r[i], res[i]) + } + + if r[i].tp != rnode { + if r[i].regexp != nil { + t.Fatalf("%v 's %d result %v is not equal %v, %v", p, i, r[i], res[i]) + } + } else { + if r[i].regexp == nil { + t.Fatalf("%v 's %d result %v is not equal %v, %v", p, i, r[i], res[i]) + } + } + } + } +} + +type result struct { + url string + match bool + params Params +} + +var ( + matchResult = map[string][]result{ + "/": []result{ + {"/", true, Params{}}, + {"/s", false, Params{}}, + {"/123", false, Params{}}, + }, + "/ss/tt": []result{ + {"/ss/tt", true, Params{}}, + {"/s", false, Params{}}, + {"/ss", false, Params{}}, + }, + "/:name": []result{ + {"/s", true, Params{{":name", "s"}}}, + {"/", false, Params{}}, + {"/123/s", false, Params{}}, + }, + "/:name1/:name2/:name3": []result{ + {"/1/2/3", true, Params{{":name1", "1"}, {":name2", "2"}, {":name3", "3"}}}, + {"/1/2", false, Params{}}, + {"/1/2/3/", false, Params{}}, + }, + "/*name": []result{ + {"/s", true, Params{{"*name", "s"}}}, + {"/123/s", true, Params{{"*name", "123/s"}}}, + {"/", false, Params{}}, + }, + "/(*name)ssss": []result{ + {"/sssss", true, Params{{"*name", "s"}}}, + {"/123/ssss", true, Params{{"*name", "123/"}}}, + {"/", false, Params{}}, + {"/ss", false, Params{}}, + }, + "/111(*name)ssss": []result{ + {"/111sssss", true, Params{{"*name", "s"}}}, + {"/111/123/ssss", true, Params{{"*name", "/123/"}}}, + {"/", false, Params{}}, + {"/ss", false, Params{}}, + }, + "/(:name[0-9]+)": []result{ + {"/123", true, Params{{":name", "123"}}}, + {"/sss", false, Params{}}, + }, + "/ss(:name[0-9]+)": []result{ + {"/ss123", true, Params{{":name", "123"}}}, + {"/sss", false, Params{}}, + }, + "/ss(:name[0-9]+)tt": []result{ + {"/ss123tt", true, Params{{":name", "123"}}}, + {"/sss", false, Params{}}, + }, + "/:name1-(:name2[0-9]+)": []result{ + {"/ss-123", true, Params{{":name1", "ss"}, {":name2", "123"}}}, + {"/sss", false, Params{}}, + }, + "/(:name1)00(:name2[0-9]+)": []result{ + {"/ss00123", true, Params{{":name1", "ss"}, {":name2", "123"}}}, + {"/sss", false, Params{}}, + }, + "/(:name1)!(:name2[0-9]+)!(:name3.*)": []result{ + {"/ss!123!456", true, Params{{":name1", "ss"}, {":name2", "123"}, {":name3", "456"}}}, + {"/sss", false, Params{}}, + }, + } +) + +type Action struct { +} + +func (Action) Get() string { + return "get" +} + +func TestRouterSingle(t *testing.T) { + for k, m := range matchResult { + r := New() + r.Route("GET", k, new(Action)) + + for _, res := range m { + handler, params := r.Match(res.url, "GET") + if res.match { + if handler == nil { + t.Fatal(k, res, "handler", handler, "should not be nil") + } + for i, v := range params { + if res.params[i].Name != v.Name { + t.Fatal(k, res, "params name", v, "not equal", res.params[i]) + } + if res.params[i].Value != v.Value { + t.Fatal(k, res, "params value", v, "not equal", res.params[i]) + } + } + } else { + if handler != nil { + t.Fatal(k, res, "handler", handler, "should be nil") + } + } + } + } +} + +type testCase struct { + routers []string + results []result +} + +var ( + matchResult2 = []testCase{ + { + []string{"/"}, + []result{ + {"/", true, Params{}}, + {"/s", false, Params{}}, + {"/123", false, Params{}}, + }, + }, + + { + []string{"/admin", "/:name"}, + []result{ + {"/", false, Params{}}, + {"/admin", true, Params{}}, + {"/s", true, Params{param{":name", "s"}}}, + {"/123", true, Params{param{":name", "123"}}}, + }, + }, + + { + []string{"/:name", "/admin"}, + []result{ + {"/", false, Params{}}, + {"/admin", true, Params{}}, + {"/s", true, Params{param{":name", "s"}}}, + {"/123", true, Params{param{":name", "123"}}}, + }, + }, + + { + []string{"/admin", "/*name"}, + []result{ + {"/", false, Params{}}, + {"/admin", true, Params{}}, + {"/s", true, Params{param{"*name", "s"}}}, + {"/123", true, Params{param{"*name", "123"}}}, + }, + }, + + { + []string{"/*name", "/admin"}, + []result{ + {"/", false, Params{}}, + {"/admin", true, Params{}}, + {"/s", true, Params{param{"*name", "s"}}}, + {"/123", true, Params{param{"*name", "123"}}}, + }, + }, + + { + []string{"/*name", "/:name"}, + []result{ + {"/", false, Params{}}, + {"/s", true, Params{param{"*name", "s"}}}, + {"/123", true, Params{param{"*name", "123"}}}, + }, + }, + + { + []string{"/:name", "/*name"}, + []result{ + {"/", false, Params{}}, + {"/s", true, Params{param{":name", "s"}}}, + {"/123", true, Params{param{":name", "123"}}}, + {"/123/1", true, Params{param{"*name", "123/1"}}}, + }, + }, + + { + []string{"/*name", "/*name/123"}, + []result{ + {"/", false, Params{}}, + {"/123", true, Params{param{"*name", "123"}}}, + {"/s", true, Params{param{"*name", "s"}}}, + {"/abc/123", true, Params{param{"*name", "abc"}}}, + {"/name1/name2/123", true, Params{param{"*name", "name1/name2"}}}, + }, + }, + + { + []string{"/admin/ui", "/*name", "/:name"}, + []result{ + {"/", false, Params{}}, + {"/admin/ui", true, Params{}}, + {"/s", true, Params{param{"*name", "s"}}}, + {"/123", true, Params{param{"*name", "123"}}}, + }, + }, + + { + []string{"/(:id[0-9]+)", "/(:id[0-9]+)/edit", "/(:id[0-9]+)/del"}, + []result{ + {"/1", true, Params{param{":id", "1"}}}, + {"/admin/ui", false, Params{}}, + {"/2/edit", true, Params{param{":id", "2"}}}, + {"/3/del", true, Params{param{":id", "3"}}}, + }, + }, + + { + []string{"/admin/ui", "/:name1/:name2"}, + []result{ + {"/", false, Params{}}, + {"/admin/ui", true, Params{}}, + {"/s", false, Params{}}, + {"/admin/ui2", true, Params{param{":name1", "admin"}, param{":name2", "ui2"}}}, + {"/123/s", true, Params{param{":name1", "123"}, param{":name2", "s"}}}, + }, + }, + } +) + +func TestRouterMultiple(t *testing.T) { + for _, kase := range matchResult2 { + r := New() + for _, k := range kase.routers { + r.Route("GET", k, new(Action)) + } + + for _, res := range kase.results { + handler, params := r.Match(res.url, "GET") + if res.match { + if handler == nil { + t.Fatal(kase.routers, res, "handler", handler, "should not be nil") + } + + if len(res.params) != len(params) { + t.Fatal(kase.routers, res, "params", params, "not equal", res.params) + } + for i, v := range params { + if res.params[i].Name != v.Name { + t.Fatal(kase.routers, res, "params name", v, "not equal", res.params[i]) + } + if res.params[i].Value != v.Value { + t.Fatal(kase.routers, res, "params value", v, "not equal", res.params[i]) + } + } + } else { + if handler != nil { + t.Fatal(kase.routers, res, "handler", handler, "should be nil") + } + } + } + } +} + +func TestRouter10(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + r := New() + r.Get("/", func(ctx *Context) { + ctx.Write([]byte("test")) + }) + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + r.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + expect(t, buff.String(), "test") + refute(t, len(buff.String()), 0) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/static.go b/Godeps/_workspace/src/github.com/lunny/tango/static.go new file mode 100644 index 0000000..a8d7577 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/static.go @@ -0,0 +1,165 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "os" + "path" + "path/filepath" + "strings" +) + +type StaticOptions struct { + RootPath string + Prefix string + IndexFiles []string + ListDir bool + FilterExts []string +} + +func prepareStaticOptions(options []StaticOptions) StaticOptions { + var opt StaticOptions + if len(options) > 0 { + opt = options[0] + } + + // Defaults + if len(opt.RootPath) == 0 { + opt.RootPath = "./public" + } + + if len(opt.Prefix) > 0 { + if opt.Prefix[0] != '/' { + opt.Prefix = "/" + opt.Prefix + } + } + + if len(opt.IndexFiles) == 0 { + opt.IndexFiles = []string{"index.html", "index.htm"} + } + + return opt +} + +func Static(opts ...StaticOptions) HandlerFunc { + return func(ctx *Context) { + if ctx.Req().Method != "GET" && ctx.Req().Method != "HEAD" { + ctx.Next() + return + } + + opt := prepareStaticOptions(opts) + + var rPath = ctx.Req().URL.Path + // if defined prefix, then only check prefix + if opt.Prefix != "" { + if !strings.HasPrefix(ctx.Req().URL.Path, opt.Prefix) { + ctx.Next() + return + } else { + if len(opt.Prefix) == len(ctx.Req().URL.Path) { + rPath = "" + } else { + rPath = ctx.Req().URL.Path[len(opt.Prefix):] + } + } + } + + fPath, _ := filepath.Abs(filepath.Join(opt.RootPath, rPath)) + finfo, err := os.Stat(fPath) + if err != nil { + if !os.IsNotExist(err) { + ctx.Result = InternalServerError(err.Error()) + ctx.HandleError() + return + } + } else if !finfo.IsDir() { + if len(opt.FilterExts) > 0 { + var matched bool + for _, ext := range opt.FilterExts { + if filepath.Ext(fPath) == ext { + matched = true + break + } + } + if !matched { + ctx.Next() + return + } + } + + err := ctx.ServeFile(fPath) + if err != nil { + ctx.Result = InternalServerError(err.Error()) + ctx.HandleError() + } + return + } else { + // try serving index.html or index.htm + if len(opt.IndexFiles) > 0 { + for _, index := range opt.IndexFiles { + nPath := filepath.Join(fPath, index) + finfo, err = os.Stat(nPath) + if err != nil { + if !os.IsNotExist(err) { + ctx.Result = InternalServerError(err.Error()) + ctx.HandleError() + return + } + } else if !finfo.IsDir() { + err = ctx.ServeFile(nPath) + if err != nil { + ctx.Result = InternalServerError(err.Error()) + ctx.HandleError() + } + return + } + } + } + + // list dir files + if opt.ListDir { + ctx.Header().Set("Content-Type", "text/html; charset=UTF-8") + ctx.Write([]byte(`
    `)) + rootPath, _ := filepath.Abs(opt.RootPath) + rPath, _ := filepath.Rel(rootPath, fPath) + if fPath != rootPath { + ctx.Write([]byte(`
  •     ..
  • `)) + } + err = filepath.Walk(fPath, func(p string, fi os.FileInfo, err error) error { + rPath, _ := filepath.Rel(fPath, p) + if rPath == "." || len(strings.Split(rPath, string(filepath.Separator))) > 1 { + return nil + } + rPath, _ = filepath.Rel(rootPath, p) + ps, _ := os.Stat(p) + if ps.IsDir() { + ctx.Write([]byte(`
  • ` + filepath.Base(p) + `
  • `)) + } else { + if len(opt.FilterExts) > 0 { + var matched bool + for _, ext := range opt.FilterExts { + if filepath.Ext(p) == ext { + matched = true + break + } + } + if !matched { + return nil + } + } + + ctx.Write([]byte(`
  •     ` + filepath.Base(p) + `
  • `)) + } + return nil + }) + ctx.Write([]byte("
")) + return + } + } + + ctx.Next() + } +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/static_test.go b/Godeps/_workspace/src/github.com/lunny/tango/static_test.go new file mode 100644 index 0000000..b567fe8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/static_test.go @@ -0,0 +1,125 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" +) + +func TestStatic(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Use(Static()) + + req, err := http.NewRequest("GET", "http://localhost:8000/test.html", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "hello tango") + + buff.Reset() + + req, err = http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), "this is index.html") +} + +func TestStatic2(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Use(Static()) + + req, err := http.NewRequest("GET", "http://localhost:8000/test.png", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) +} + +func TestStatic3(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Use(Static(StaticOptions{ + Prefix: "/public", + RootPath: "./public", + })) + + req, err := http.NewRequest("GET", "http://localhost:8000/public/test.html", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + expect(t, buff.String(), "hello tango") +} + +func TestStatic4(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Use(Static(StaticOptions{ + Prefix: "/public", + RootPath: "./public", + })) + + req, err := http.NewRequest("GET", "http://localhost:8000/public/t.html", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusNotFound) + expect(t, buff.String(), NotFound().Error()) +} + +func TestStatic5(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + tg := New() + tg.Use(Static(StaticOptions{ + Prefix: "/public", + RootPath: "./public", + ListDir: true, + IndexFiles: []string{"a.html"}, + })) + + req, err := http.NewRequest("GET", "http://localhost:8000/public/", nil) + if err != nil { + t.Error(err) + } + + tg.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + //expect(t, buff.String(), NotFound().Error()) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/tan.go b/Godeps/_workspace/src/github.com/lunny/tango/tan.go new file mode 100644 index 0000000..0406d36 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/tan.go @@ -0,0 +1,257 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "net/http" + "os" + "strconv" + "strings" + "sync" +) + +func Version() string { + return "0.4.6.0807" +} + +type Tango struct { + Router + handlers []Handler + logger Logger + ErrHandler Handler + ctxPool sync.Pool + respPool sync.Pool +} + +var ( + ClassicHandlers = []Handler{ + Logging(), + Recovery(false), + Compresses([]string{}), + Static(StaticOptions{Prefix: "public"}), + Return(), + Param(), + Contexts(), + } +) + +func (t *Tango) Logger() Logger { + return t.logger +} + +func (t *Tango) Get(url string, c interface{}) { + t.Route([]string{"GET", "HEAD"}, url, c) +} + +func (t *Tango) Post(url string, c interface{}) { + t.Route([]string{"POST"}, url, c) +} + +func (t *Tango) Head(url string, c interface{}) { + t.Route([]string{"HEAD"}, url, c) +} + +func (t *Tango) Options(url string, c interface{}) { + t.Route([]string{"OPTIONS"}, url, c) +} + +func (t *Tango) Trace(url string, c interface{}) { + t.Route([]string{"TRACE"}, url, c) +} + +func (t *Tango) Patch(url string, c interface{}) { + t.Route([]string{"PATCH"}, url, c) +} + +func (t *Tango) Delete(url string, c interface{}) { + t.Route([]string{"DELETE"}, url, c) +} + +func (t *Tango) Put(url string, c interface{}) { + t.Route([]string{"PUT"}, url, c) +} + +func (t *Tango) Any(url string, c interface{}) { + t.Route(SupportMethods, url, c) +} + +func (t *Tango) Use(handlers ...Handler) { + t.handlers = append(t.handlers, handlers...) +} + +func GetDefaultListenInfo() (string, int) { + host := os.Getenv("HOST") + if len(host) == 0 { + host = "0.0.0.0" + } + _port, _ := strconv.ParseInt(os.Getenv("PORT"), 10, 32) + port := int(_port) + if port == 0 { + port = 8000 + } + return host, port +} + +func GetAddress(args ...interface{}) string { + host, port := GetDefaultListenInfo() + + if len(args) == 1 { + switch arg := args[0].(type) { + case string: + addrs := strings.Split(args[0].(string), ":") + if len(addrs) == 1 { + host = addrs[0] + } else if len(addrs) >= 2 { + host = addrs[0] + _port, _ := strconv.ParseInt(addrs[1], 10, 0) + port = int(_port) + } + case int: + port = arg + } + } else if len(args) >= 2 { + if arg, ok := args[0].(string); ok { + host = arg + } + if arg, ok := args[1].(int); ok { + port = arg + } + } + + addr := host + ":" + strconv.FormatInt(int64(port), 10) + + return addr +} + +// Run the http server. Listening on os.GetEnv("PORT") or 8000 by default. +func (t *Tango) Run(args ...interface{}) { + addr := GetAddress(args...) + t.logger.Info("Listening on http", addr) + + err := http.ListenAndServe(addr, t) + if err != nil { + t.logger.Error(err) + } +} + +func (t *Tango) RunTLS(certFile, keyFile string, args ...interface{}) { + addr := GetAddress(args...) + + t.logger.Info("Listening on https", addr) + + err := http.ListenAndServeTLS(addr, certFile, keyFile, t) + if err != nil { + t.logger.Error(err) + } +} + +type HandlerFunc func(ctx *Context) + +func (h HandlerFunc) Handle(ctx *Context) { + h(ctx) +} + +func WrapBefore(handler http.Handler) HandlerFunc { + return func(ctx *Context) { + handler.ServeHTTP(ctx.ResponseWriter, ctx.Req()) + + ctx.Next() + } +} + +func WrapAfter(handler http.Handler) HandlerFunc { + return func(ctx *Context) { + ctx.Next() + + handler.ServeHTTP(ctx.ResponseWriter, ctx.Req()) + } +} + +func (t *Tango) UseHandler(handler http.Handler) { + t.Use(WrapBefore(handler)) +} + +func (t *Tango) ServeHTTP(w http.ResponseWriter, req *http.Request) { + resp := t.respPool.Get().(*responseWriter) + resp.reset(w) + + ctx := t.ctxPool.Get().(*Context) + ctx.tan = t + ctx.reset(req, resp) + + ctx.Invoke() + + // if there is no logging or error handle, so the last written check. + if !ctx.Written() { + p := req.URL.Path + if len(req.URL.RawQuery) > 0 { + p = p + "?" + req.URL.RawQuery + } + + if ctx.Route() != nil { + if ctx.Result == nil { + ctx.Write([]byte("")) + t.logger.Info(req.Method, ctx.Status(), p) + t.ctxPool.Put(ctx) + t.respPool.Put(resp) + return + } + panic("result should be handler before") + } + + if ctx.Result == nil { + ctx.Result = NotFound() + } + + ctx.HandleError() + + t.logger.Error(req.Method, ctx.Status(), p) + } + + t.ctxPool.Put(ctx) + t.respPool.Put(resp) +} + +func NewWithLog(logger Logger, handlers ...Handler) *Tango { + tan := &Tango{ + Router: NewRouter(), + logger: logger, + handlers: make([]Handler, 0), + ErrHandler: Errors(), + } + + tan.ctxPool.New = func() interface{} { + return &Context{ + tan: tan, + Logger: tan.logger, + } + } + + tan.respPool.New = func() interface{} { + return &responseWriter{} + } + + tan.Use(handlers...) + + return tan +} + +func New(handlers ...Handler) *Tango { + return NewWithLog(NewLogger(os.Stdout), handlers...) +} + +func Classic(l ...Logger) *Tango { + var logger Logger + if len(l) == 0 { + logger = NewLogger(os.Stdout) + } else { + logger = l[0] + } + + return NewWithLog( + logger, + ClassicHandlers..., + ) +} diff --git a/Godeps/_workspace/src/github.com/lunny/tango/tan_test.go b/Godeps/_workspace/src/github.com/lunny/tango/tan_test.go new file mode 100644 index 0000000..bac2ea9 --- /dev/null +++ b/Godeps/_workspace/src/github.com/lunny/tango/tan_test.go @@ -0,0 +1,119 @@ +// Copyright 2015 The Tango Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tango + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "reflect" + "testing" + "time" +) + +func TestTan1(t *testing.T) { + buff := bytes.NewBufferString("") + recorder := httptest.NewRecorder() + recorder.Body = buff + + l := NewLogger(os.Stdout) + o := Classic(l) + o.Get("/", func() string { + return Version() + }) + o.Logger().Debug("it's ok") + + req, err := http.NewRequest("GET", "http://localhost:8000/", nil) + if err != nil { + t.Error(err) + } + + o.ServeHTTP(recorder, req) + expect(t, recorder.Code, http.StatusOK) + refute(t, len(buff.String()), 0) + expect(t, buff.String(), Version()) +} + +func TestTan2(t *testing.T) { + o := Classic() + o.Get("/", func() string { + return Version() + }) + go o.Run() + + time.Sleep(100 * time.Millisecond) + + resp, err := http.Get("http://localhost:8000/") + if err != nil { + t.Error(err) + } + bs, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + + expect(t, resp.StatusCode, http.StatusOK) + expect(t, string(bs), Version()) +} + +func TestTan3(t *testing.T) { + o := Classic() + o.Get("/", func() string { + return Version() + }) + go o.Run(":4040") + + time.Sleep(100 * time.Millisecond) + + resp, err := http.Get("http://localhost:4040/") + if err != nil { + t.Error(err) + } + bs, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + + expect(t, resp.StatusCode, http.StatusOK) + expect(t, string(bs), Version()) +} + +/* +func TestTan4(t *testing.T) { + o := Classic() + o.Get("/", func() string { + return Version() + }) + go o.RunTLS("./public/cert.pem", "./public/key.pem", ":5050") + + time.Sleep(100 * time.Millisecond) + + resp, err := http.Get("https://localhost:5050/") + if err != nil { + t.Error(err) + } + bs, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Error(err) + } + + expect(t, resp.StatusCode, http.StatusOK) + expect(t, string(bs), Version()) +}*/ + +/* Test Helpers */ +func expect(t *testing.T, a interface{}, b interface{}) { + if a != b { + t.Errorf("Expected %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) + } +} + +func refute(t *testing.T, a interface{}, b interface{}) { + if a == b { + t.Errorf("Did not expect %v (type %v) - Got %v (type %v)", b, reflect.TypeOf(b), a, reflect.TypeOf(a)) + } +} diff --git a/Godeps/_workspace/src/github.com/qiniu/log/.gitignore b/Godeps/_workspace/src/github.com/qiniu/log/.gitignore new file mode 100644 index 0000000..0026861 --- /dev/null +++ b/Godeps/_workspace/src/github.com/qiniu/log/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/Godeps/_workspace/src/github.com/qiniu/log/README.md b/Godeps/_workspace/src/github.com/qiniu/log/README.md new file mode 100644 index 0000000..a5e831b --- /dev/null +++ b/Godeps/_workspace/src/github.com/qiniu/log/README.md @@ -0,0 +1,4 @@ +log +=== + +Extension module of golang logging \ No newline at end of file diff --git a/Godeps/_workspace/src/github.com/qiniu/log/logext.go b/Godeps/_workspace/src/github.com/qiniu/log/logext.go new file mode 100644 index 0000000..0b6c54a --- /dev/null +++ b/Godeps/_workspace/src/github.com/qiniu/log/logext.go @@ -0,0 +1,526 @@ +package log + +import ( + "bytes" + "fmt" + "io" + "os" + "runtime" + "strings" + "sync" + "time" +) + +// These flags define which text to prefix to each log entry generated by the Logger. +const ( + // Bits or'ed together to control what's printed. There is no control over the + // order they appear (the order listed here) or the format they present (as + // described in the comments). A colon appears after these items: + // 2009/0123 01:23:23.123123 /a/b/c/d.go:23: message + Ldate = 1 << iota // the date: 2009/0123 + Ltime // the time: 01:23:23 + Lmicroseconds // microsecond resolution: 01:23:23.123123. assumes Ltime. + Llongfile // full file name and line number: /a/b/c/d.go:23 + Lshortfile // final file name element and line number: d.go:23. overrides Llongfile + Lmodule // module name + Llevel // level: 0(Debug), 1(Info), 2(Warn), 3(Error), 4(Panic), 5(Fatal) + LstdFlags = Ldate | Ltime // initial values for the standard logger + Ldefault = Lmodule | Llevel | Lshortfile | LstdFlags +) // [prefix][time][level][module][shortfile|longfile] + +const ( + Ldebug = iota + Linfo + Lwarn + Lerror + Lpanic + Lfatal +) + +var levels = []string{ + "[DEBUG]", + "[INFO]", + "[WARN]", + "[ERROR]", + "[PANIC]", + "[FATAL]", +} + +// A Logger represents an active logging object that generates lines of +// output to an io.Writer. Each logging operation makes a single call to +// the Writer's Write method. A Logger can be used simultaneously from +// multiple goroutines; it guarantees to serialize access to the Writer. +type Logger struct { + mu sync.Mutex // ensures atomic writes; protects the following fields + prefix string // prefix to write at beginning of each line + flag int // properties + Level int + out io.Writer // destination for output + buf bytes.Buffer // for accumulating text to write + levelStats [6]int64 +} + +// New creates a new Logger. The out variable sets the +// destination to which log data will be written. +// The prefix appears at the beginning of each generated log line. +// The flag argument defines the logging properties. +func New(out io.Writer, prefix string, flag int) *Logger { + return &Logger{out: out, prefix: prefix, Level: 1, flag: flag} +} + +var Std = New(os.Stderr, "", Ldefault) + +// Cheap integer to fixed-width decimal ASCII. Give a negative width to avoid zero-padding. +// Knows the buffer has capacity. +func itoa(buf *bytes.Buffer, i int, wid int) { + var u uint = uint(i) + if u == 0 && wid <= 1 { + buf.WriteByte('0') + return + } + + // Assemble decimal in reverse order. + var b [32]byte + bp := len(b) + for ; u > 0 || wid > 0; u /= 10 { + bp-- + wid-- + b[bp] = byte(u%10) + '0' + } + + // avoid slicing b to avoid an allocation. + for bp < len(b) { + buf.WriteByte(b[bp]) + bp++ + } +} + +func moduleOf(file string) string { + pos := strings.LastIndex(file, "/") + if pos != -1 { + pos1 := strings.LastIndex(file[:pos], "/src/") + if pos1 != -1 { + return file[pos1+5 : pos] + } + } + return "UNKNOWN" +} + +func (l *Logger) formatHeader(buf *bytes.Buffer, t time.Time, file string, line int, lvl int, reqId string) { + if l.prefix != "" { + buf.WriteString(l.prefix) + } + if l.flag&(Ldate|Ltime|Lmicroseconds) != 0 { + if l.flag&Ldate != 0 { + year, month, day := t.Date() + itoa(buf, year, 4) + buf.WriteByte('/') + itoa(buf, int(month), 2) + buf.WriteByte('/') + itoa(buf, day, 2) + buf.WriteByte(' ') + } + if l.flag&(Ltime|Lmicroseconds) != 0 { + hour, min, sec := t.Clock() + itoa(buf, hour, 2) + buf.WriteByte(':') + itoa(buf, min, 2) + buf.WriteByte(':') + itoa(buf, sec, 2) + if l.flag&Lmicroseconds != 0 { + buf.WriteByte('.') + itoa(buf, t.Nanosecond()/1e3, 6) + } + buf.WriteByte(' ') + } + } + if reqId != "" { + buf.WriteByte('[') + buf.WriteString(reqId) + buf.WriteByte(']') + } + if l.flag&Llevel != 0 { + buf.WriteString(levels[lvl]) + } + if l.flag&Lmodule != 0 { + buf.WriteByte('[') + buf.WriteString(moduleOf(file)) + buf.WriteByte(']') + buf.WriteByte(' ') + } + if l.flag&(Lshortfile|Llongfile) != 0 { + if l.flag&Lshortfile != 0 { + short := file + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + file = short + } + buf.WriteString(file) + buf.WriteByte(':') + itoa(buf, line, -1) + buf.WriteString(": ") + } +} + +// Output writes the output for a logging event. The string s contains +// the text to print after the prefix specified by the flags of the +// Logger. A newline is appended if the last character of s is not +// already a newline. Calldepth is used to recover the PC and is +// provided for generality, although at the moment on all pre-defined +// paths it will be 2. +func (l *Logger) Output(reqId string, lvl int, calldepth int, s string) error { + if lvl < l.Level { + return nil + } + now := time.Now() // get this early. + var file string + var line int + l.mu.Lock() + defer l.mu.Unlock() + if l.flag&(Lshortfile|Llongfile|Lmodule) != 0 { + // release lock while getting caller info - it's expensive. + l.mu.Unlock() + var ok bool + _, file, line, ok = runtime.Caller(calldepth) + if !ok { + file = "???" + line = 0 + } + l.mu.Lock() + } + l.levelStats[lvl]++ + l.buf.Reset() + l.formatHeader(&l.buf, now, file, line, lvl, reqId) + l.buf.WriteString(s) + if len(s) > 0 && s[len(s)-1] != '\n' { + l.buf.WriteByte('\n') + } + _, err := l.out.Write(l.buf.Bytes()) + return err +} + +// ----------------------------------------- + +// Printf calls l.Output to print to the logger. +// Arguments are handled in the manner of fmt.Printf. +func (l *Logger) Printf(format string, v ...interface{}) { + l.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +// Print calls l.Output to print to the logger. +// Arguments are handled in the manner of fmt.Print. +func (l *Logger) Print(v ...interface{}) { l.Output("", Linfo, 2, fmt.Sprint(v...)) } + +// Println calls l.Output to print to the logger. +// Arguments are handled in the manner of fmt.Println. +func (l *Logger) Println(v ...interface{}) { l.Output("", Linfo, 2, fmt.Sprintln(v...)) } + +// ----------------------------------------- + +func (l *Logger) Debugf(format string, v ...interface{}) { + if Ldebug < l.Level { + return + } + l.Output("", Ldebug, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Debug(v ...interface{}) { + if Ldebug < l.Level { + return + } + l.Output("", Ldebug, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func (l *Logger) Infof(format string, v ...interface{}) { + if Linfo < l.Level { + return + } + l.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Info(v ...interface{}) { + if Linfo < l.Level { + return + } + l.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func (l *Logger) Warnf(format string, v ...interface{}) { + l.Output("", Lwarn, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Warn(v ...interface{}) { l.Output("", Lwarn, 2, fmt.Sprintln(v...)) } + +// ----------------------------------------- + +func (l *Logger) Errorf(format string, v ...interface{}) { + l.Output("", Lerror, 2, fmt.Sprintf(format, v...)) +} + +func (l *Logger) Error(v ...interface{}) { l.Output("", Lerror, 2, fmt.Sprintln(v...)) } + +// ----------------------------------------- + +func (l *Logger) Fatal(v ...interface{}) { + l.Output("", Lfatal, 2, fmt.Sprint(v...)) + os.Exit(1) +} + +// Fatalf is equivalent to l.Printf() followed by a call to os.Exit(1). +func (l *Logger) Fatalf(format string, v ...interface{}) { + l.Output("", Lfatal, 2, fmt.Sprintf(format, v...)) + os.Exit(1) +} + +// Fatalln is equivalent to l.Println() followed by a call to os.Exit(1). +func (l *Logger) Fatalln(v ...interface{}) { + l.Output("", Lfatal, 2, fmt.Sprintln(v...)) + os.Exit(1) +} + +// ----------------------------------------- + +// Panic is equivalent to l.Print() followed by a call to panic(). +func (l *Logger) Panic(v ...interface{}) { + s := fmt.Sprint(v...) + l.Output("", Lpanic, 2, s) + panic(s) +} + +// Panicf is equivalent to l.Printf() followed by a call to panic(). +func (l *Logger) Panicf(format string, v ...interface{}) { + s := fmt.Sprintf(format, v...) + l.Output("", Lpanic, 2, s) + panic(s) +} + +// Panicln is equivalent to l.Println() followed by a call to panic(). +func (l *Logger) Panicln(v ...interface{}) { + s := fmt.Sprintln(v...) + l.Output("", Lpanic, 2, s) + panic(s) +} + +// ----------------------------------------- + +func (l *Logger) Stack(v ...interface{}) { + s := fmt.Sprint(v...) + s += "\n" + buf := make([]byte, 1024*1024) + n := runtime.Stack(buf, true) + s += string(buf[:n]) + s += "\n" + l.Output("", Lerror, 2, s) +} + +// ----------------------------------------- + +func (l *Logger) Stat() (stats []int64) { + l.mu.Lock() + v := l.levelStats + l.mu.Unlock() + return v[:] +} + +// Flags returns the output flags for the logger. +func (l *Logger) Flags() int { + l.mu.Lock() + defer l.mu.Unlock() + return l.flag +} + +// SetFlags sets the output flags for the logger. +func (l *Logger) SetFlags(flag int) { + l.mu.Lock() + defer l.mu.Unlock() + l.flag = flag +} + +// Prefix returns the output prefix for the logger. +func (l *Logger) Prefix() string { + l.mu.Lock() + defer l.mu.Unlock() + return l.prefix +} + +// SetPrefix sets the output prefix for the logger. +func (l *Logger) SetPrefix(prefix string) { + l.mu.Lock() + defer l.mu.Unlock() + l.prefix = prefix +} + +// SetOutputLevel sets the output level for the logger. +func (l *Logger) SetOutputLevel(lvl int) { + l.mu.Lock() + defer l.mu.Unlock() + l.Level = lvl +} + +// SetOutput sets the output destination for the standard logger. +func SetOutput(w io.Writer) { + Std.mu.Lock() + defer Std.mu.Unlock() + Std.out = w +} + +// Flags returns the output flags for the standard logger. +func Flags() int { + return Std.Flags() +} + +// SetFlags sets the output flags for the standard logger. +func SetFlags(flag int) { + Std.SetFlags(flag) +} + +// Prefix returns the output prefix for the standard logger. +func Prefix() string { + return Std.Prefix() +} + +// SetPrefix sets the output prefix for the standard logger. +func SetPrefix(prefix string) { + Std.SetPrefix(prefix) +} + +func SetOutputLevel(lvl int) { + Std.SetOutputLevel(lvl) +} + +func GetOutputLevel() int { + return Std.Level +} + +// ----------------------------------------- + +// Print calls Output to print to the standard logger. +// Arguments are handled in the manner of fmt.Print. +func Print(v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprint(v...)) +} + +// Printf calls Output to print to the standard logger. +// Arguments are handled in the manner of fmt.Printf. +func Printf(format string, v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +// Println calls Output to print to the standard logger. +// Arguments are handled in the manner of fmt.Println. +func Println(v ...interface{}) { + Std.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func Debugf(format string, v ...interface{}) { + if Ldebug < Std.Level { + return + } + Std.Output("", Ldebug, 2, fmt.Sprintf(format, v...)) +} + +func Debug(v ...interface{}) { + if Ldebug < Std.Level { + return + } + Std.Output("", Ldebug, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func Infof(format string, v ...interface{}) { + if Linfo < Std.Level { + return + } + Std.Output("", Linfo, 2, fmt.Sprintf(format, v...)) +} + +func Info(v ...interface{}) { + if Linfo < Std.Level { + return + } + Std.Output("", Linfo, 2, fmt.Sprintln(v...)) +} + +// ----------------------------------------- + +func Warnf(format string, v ...interface{}) { + Std.Output("", Lwarn, 2, fmt.Sprintf(format, v...)) +} + +func Warn(v ...interface{}) { Std.Output("", Lwarn, 2, fmt.Sprintln(v...)) } + +// ----------------------------------------- + +func Errorf(format string, v ...interface{}) { + Std.Output("", Lerror, 2, fmt.Sprintf(format, v...)) +} + +func Error(v ...interface{}) { Std.Output("", Lerror, 2, fmt.Sprintln(v...)) } + +// ----------------------------------------- + +// Fatal is equivalent to Print() followed by a call to os.Exit(1). +func Fatal(v ...interface{}) { + Std.Output("", Lfatal, 2, fmt.Sprint(v...)) + os.Exit(1) +} + +// Fatalf is equivalent to Printf() followed by a call to os.Exit(1). +func Fatalf(format string, v ...interface{}) { + Std.Output("", Lfatal, 2, fmt.Sprintf(format, v...)) + os.Exit(1) +} + +// Fatalln is equivalent to Println() followed by a call to os.Exit(1). +func Fatalln(v ...interface{}) { + Std.Output("", Lfatal, 2, fmt.Sprintln(v...)) + os.Exit(1) +} + +// ----------------------------------------- + +// Panic is equivalent to Print() followed by a call to panic(). +func Panic(v ...interface{}) { + s := fmt.Sprint(v...) + Std.Output("", Lpanic, 2, s) + panic(s) +} + +// Panicf is equivalent to Printf() followed by a call to panic(). +func Panicf(format string, v ...interface{}) { + s := fmt.Sprintf(format, v...) + Std.Output("", Lpanic, 2, s) + panic(s) +} + +// Panicln is equivalent to Println() followed by a call to panic(). +func Panicln(v ...interface{}) { + s := fmt.Sprintln(v...) + Std.Output("", Lpanic, 2, s) + panic(s) +} + +// ----------------------------------------- + +func Stack(v ...interface{}) { + s := fmt.Sprint(v...) + s += "\n" + buf := make([]byte, 1024*1024) + n := runtime.Stack(buf, true) + s += string(buf[:n]) + s += "\n" + Std.Output("", Lerror, 2, s) +} + +// ----------------------------------------- diff --git a/Godeps/_workspace/src/github.com/qiniu/log/logext_test.go b/Godeps/_workspace/src/github.com/qiniu/log/logext_test.go new file mode 100644 index 0000000..6e6b6e8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/qiniu/log/logext_test.go @@ -0,0 +1,36 @@ +package log + +import ( + "testing" +) + +func TestLog(t *testing.T) { + + SetOutputLevel(Ldebug) + + Debugf("Debug: foo\n") + Debug("Debug: foo") + + Infof("Info: foo\n") + Info("Info: foo") + + Warnf("Warn: foo\n") + Warn("Warn: foo") + + Errorf("Error: foo\n") + Error("Error: foo") + + SetOutputLevel(Linfo) + + Debugf("Debug: foo\n") + Debug("Debug: foo") + + Infof("Info: foo\n") + Info("Info: foo") + + Warnf("Warn: foo\n") + Warn("Warn: foo") + + Errorf("Error: foo\n") + Error("Error: foo") +} diff --git a/Godeps/_workspace/src/golang.org/x/net/context/context.go b/Godeps/_workspace/src/golang.org/x/net/context/context.go new file mode 100644 index 0000000..ef2f3e8 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/context/context.go @@ -0,0 +1,447 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context defines the Context type, which carries deadlines, +// cancelation signals, and other request-scoped values across API boundaries +// and between processes. +// +// Incoming requests to a server should create a Context, and outgoing calls to +// servers should accept a Context. The chain of function calls between must +// propagate the Context, optionally replacing it with a modified copy created +// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// +// Programs that use Contexts should follow these rules to keep interfaces +// consistent across packages and enable static analysis tools to check context +// propagation: +// +// Do not store Contexts inside a struct type; instead, pass a Context +// explicitly to each function that needs it. The Context should be the first +// parameter, typically named ctx: +// +// func DoSomething(ctx context.Context, arg Arg) error { +// // ... use ctx ... +// } +// +// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// if you are unsure about which Context to use. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The same Context may be passed to functions running in different goroutines; +// Contexts are safe for simultaneous use by multiple goroutines. +// +// See http://blog.golang.org/context for example code for a server that uses +// Contexts. +package context + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// A Context carries a deadline, a cancelation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context interface { + // Deadline returns the time when work done on behalf of this context + // should be canceled. Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // + // WithCancel arranges for Done to be closed when cancel is called; + // WithDeadline arranges for Done to be closed when the deadline + // expires; WithTimeout arranges for Done to be closed when the timeout + // elapses. + // + // Done is provided for use in select statements: + // + // // Stream generates values with DoSomething and sends them to out + // // until DoSomething returns an error or ctx.Done is closed. + // func Stream(ctx context.Context, out <-chan Value) error { + // for { + // v, err := DoSomething(ctx) + // if err != nil { + // return err + // } + // select { + // case <-ctx.Done(): + // return ctx.Err() + // case out <- v: + // } + // } + // } + // + // See http://blog.golang.org/pipelines for more examples of how to use + // a Done channel for cancelation. + Done() <-chan struct{} + + // Err returns a non-nil error value after Done is closed. Err returns + // Canceled if the context was canceled or DeadlineExceeded if the + // context's deadline passed. No other values for Err are defined. + // After Done is closed, successive calls to Err return the same value. + Err() error + + // Value returns the value associated with this context for key, or nil + // if no value is associated with key. Successive calls to Value with + // the same key returns the same result. + // + // Use context values only for request-scoped data that transits + // processes and API boundaries, not for passing optional parameters to + // functions. + // + // A key identifies a specific value in a Context. Functions that wish + // to store values in Context typically allocate a key in a global + // variable then use that key as the argument to context.WithValue and + // Context.Value. A key can be any type that supports equality; + // packages should define keys as an unexported type to avoid + // collisions. + // + // Packages that define a Context key should provide type-safe accessors + // for the values stores using that key: + // + // // Package user defines a User type that's stored in Contexts. + // package user + // + // import "golang.org/x/net/context" + // + // // User is the type of value stored in the Contexts. + // type User struct {...} + // + // // key is an unexported type for keys defined in this package. + // // This prevents collisions with keys defined in other packages. + // type key int + // + // // userKey is the key for user.User values in Contexts. It is + // // unexported; clients use user.NewContext and user.FromContext + // // instead of using this key directly. + // var userKey key = 0 + // + // // NewContext returns a new Context that carries value u. + // func NewContext(ctx context.Context, u *User) context.Context { + // return context.WithValue(ctx, userKey, u) + // } + // + // // FromContext returns the User value stored in ctx, if any. + // func FromContext(ctx context.Context) (*User, bool) { + // u, ok := ctx.Value(userKey).(*User) + // return u, ok + // } + Value(key interface{}) interface{} +} + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = errors.New("context canceled") + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = errors.New("context deadline exceeded") + +// An emptyCtx is never canceled, has no values, and has no deadline. It is not +// struct{}, since vars of this type must have distinct addresses. +type emptyCtx int + +func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*emptyCtx) Done() <-chan struct{} { + return nil +} + +func (*emptyCtx) Err() error { + return nil +} + +func (*emptyCtx) Value(key interface{}) interface{} { + return nil +} + +func (e *emptyCtx) String() string { + switch e { + case background: + return "context.Background" + case todo: + return "context.TODO" + } + return "unknown empty Context" +} + +var ( + background = new(emptyCtx) + todo = new(emptyCtx) +) + +// Background returns a non-nil, empty Context. It is never canceled, has no +// values, and has no deadline. It is typically used by the main function, +// initialization, and tests, and as the top-level Context for incoming +// requests. +func Background() Context { + return background +} + +// TODO returns a non-nil, empty Context. Code should use context.TODO when +// it's unclear which Context to use or it's is not yet available (because the +// surrounding function has not yet been extended to accept a Context +// parameter). TODO is recognized by static analysis tools that determine +// whether Contexts are propagated correctly in a program. +func TODO() Context { + return todo +} + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc func() + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := newCancelCtx(parent) + propagateCancel(parent, &c) + return &c, func() { c.cancel(true, Canceled) } +} + +// newCancelCtx returns an initialized cancelCtx. +func newCancelCtx(parent Context) cancelCtx { + return cancelCtx{ + Context: parent, + done: make(chan struct{}), + } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent Context, child canceler) { + if parent.Done() == nil { + return // parent is never canceled + } + if p, ok := parentCancelCtx(parent); ok { + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err) + } else { + if p.children == nil { + p.children = make(map[canceler]bool) + } + p.children[child] = true + } + p.mu.Unlock() + } else { + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err()) + case <-child.Done(): + } + }() + } +} + +// parentCancelCtx follows a chain of parent references until it finds a +// *cancelCtx. This function understands how each of the concrete types in this +// package represents its parent. +func parentCancelCtx(parent Context) (*cancelCtx, bool) { + for { + switch c := parent.(type) { + case *cancelCtx: + return c, true + case *timerCtx: + return &c.cancelCtx, true + case *valueCtx: + parent = c.Context + default: + return nil, false + } + } +} + +// removeChild removes a context from its parent. +func removeChild(parent Context, child canceler) { + p, ok := parentCancelCtx(parent) + if !ok { + return + } + p.mu.Lock() + if p.children != nil { + delete(p.children, child) + } + p.mu.Unlock() +} + +// A canceler is a context type that can be canceled directly. The +// implementations are *cancelCtx and *timerCtx. +type canceler interface { + cancel(removeFromParent bool, err error) + Done() <-chan struct{} +} + +// A cancelCtx can be canceled. When canceled, it also cancels any children +// that implement canceler. +type cancelCtx struct { + Context + + done chan struct{} // closed by the first cancel call. + + mu sync.Mutex + children map[canceler]bool // set to nil by the first cancel call + err error // set to non-nil by the first cancel call +} + +func (c *cancelCtx) Done() <-chan struct{} { + return c.done +} + +func (c *cancelCtx) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *cancelCtx) String() string { + return fmt.Sprintf("%v.WithCancel", c.Context) +} + +// cancel closes c.done, cancels each of c's children, and, if +// removeFromParent is true, removes c from its parent's children. +func (c *cancelCtx) cancel(removeFromParent bool, err error) { + if err == nil { + panic("context: internal error: missing cancel error") + } + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return // already canceled + } + c.err = err + close(c.done) + for child := range c.children { + // NOTE: acquiring the child's lock while holding parent's lock. + child.cancel(false, err) + } + c.children = nil + c.mu.Unlock() + + if removeFromParent { + removeChild(c.Context, c) + } +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return WithCancel(parent) + } + c := &timerCtx{ + cancelCtx: newCancelCtx(parent), + deadline: deadline, + } + propagateCancel(parent, c) + d := deadline.Sub(time.Now()) + if d <= 0 { + c.cancel(true, DeadlineExceeded) // deadline has already passed + return c, func() { c.cancel(true, Canceled) } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.timer = time.AfterFunc(d, func() { + c.cancel(true, DeadlineExceeded) + }) + } + return c, func() { c.cancel(true, Canceled) } +} + +// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to +// implement Done and Err. It implements cancel by stopping its timer then +// delegating to cancelCtx.cancel. +type timerCtx struct { + cancelCtx + timer *time.Timer // Under cancelCtx.mu. + + deadline time.Time +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { + return c.deadline, true +} + +func (c *timerCtx) String() string { + return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) +} + +func (c *timerCtx) cancel(removeFromParent bool, err error) { + c.cancelCtx.cancel(false, err) + if removeFromParent { + // Remove this timerCtx from its parent cancelCtx's children. + removeChild(c.cancelCtx.Context, c) + } + c.mu.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.mu.Unlock() +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return &valueCtx{parent, key, val} +} + +// A valueCtx carries a key-value pair. It implements Value for that key and +// delegates all other calls to the embedded Context. +type valueCtx struct { + Context + key, val interface{} +} + +func (c *valueCtx) String() string { + return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) +} + +func (c *valueCtx) Value(key interface{}) interface{} { + if c.key == key { + return c.val + } + return c.Context.Value(key) +} diff --git a/Godeps/_workspace/src/golang.org/x/net/context/context_test.go b/Godeps/_workspace/src/golang.org/x/net/context/context_test.go new file mode 100644 index 0000000..e64afa6 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/context/context_test.go @@ -0,0 +1,575 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context + +import ( + "fmt" + "math/rand" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +// otherContext is a Context that's not one of the types defined in context.go. +// This lets us test code paths that differ based on the underlying type of the +// Context. +type otherContext struct { + Context +} + +func TestBackground(t *testing.T) { + c := Background() + if c == nil { + t.Fatalf("Background returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.Background"; got != want { + t.Errorf("Background().String() = %q want %q", got, want) + } +} + +func TestTODO(t *testing.T) { + c := TODO() + if c == nil { + t.Fatalf("TODO returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.TODO"; got != want { + t.Errorf("TODO().String() = %q want %q", got, want) + } +} + +func TestWithCancel(t *testing.T) { + c1, cancel := WithCancel(Background()) + + if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { + t.Errorf("c1.String() = %q want %q", got, want) + } + + o := otherContext{c1} + c2, _ := WithCancel(o) + contexts := []Context{c1, o, c2} + + for i, c := range contexts { + if d := c.Done(); d == nil { + t.Errorf("c[%d].Done() == %v want non-nil", i, d) + } + if e := c.Err(); e != nil { + t.Errorf("c[%d].Err() == %v want nil", i, e) + } + + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + } + + cancel() + time.Sleep(100 * time.Millisecond) // let cancelation propagate + + for i, c := range contexts { + select { + case <-c.Done(): + default: + t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) + } + if e := c.Err(); e != Canceled { + t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) + } + } +} + +func TestParentFinishesChild(t *testing.T) { + // Context tree: + // parent -> cancelChild + // parent -> valueChild -> timerChild + parent, cancel := WithCancel(Background()) + cancelChild, stop := WithCancel(parent) + defer stop() + valueChild := WithValue(parent, "key", "value") + timerChild, stop := WithTimeout(valueChild, 10000*time.Hour) + defer stop() + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-cancelChild.Done(): + t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) + case x := <-timerChild.Done(): + t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) + case x := <-valueChild.Done(): + t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) + default: + } + + // The parent's children should contain the two cancelable children. + pc := parent.(*cancelCtx) + cc := cancelChild.(*cancelCtx) + tc := timerChild.(*timerCtx) + pc.mu.Lock() + if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { + t.Errorf("bad linkage: pc.children = %v, want %v and %v", + pc.children, cc, tc) + } + pc.mu.Unlock() + + if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) + } + if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) + } + + cancel() + + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) + } + pc.mu.Unlock() + + // parent and children should all be finished. + check := func(ctx Context, name string) { + select { + case <-ctx.Done(): + default: + t.Errorf("<-%s.Done() blocked, but shouldn't have", name) + } + if e := ctx.Err(); e != Canceled { + t.Errorf("%s.Err() == %v want %v", name, e, Canceled) + } + } + check(parent, "parent") + check(cancelChild, "cancelChild") + check(valueChild, "valueChild") + check(timerChild, "timerChild") + + // WithCancel should return a canceled context on a canceled parent. + precanceledChild := WithValue(parent, "key", "value") + select { + case <-precanceledChild.Done(): + default: + t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") + } + if e := precanceledChild.Err(); e != Canceled { + t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) + } +} + +func TestChildFinishesFirst(t *testing.T) { + cancelable, stop := WithCancel(Background()) + defer stop() + for _, parent := range []Context{Background(), cancelable} { + child, cancel := WithCancel(parent) + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-child.Done(): + t.Errorf("<-child.Done() == %v want nothing (it should block)", x) + default: + } + + cc := child.(*cancelCtx) + pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background() + if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { + t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) + } + + if pcok { + pc.mu.Lock() + if len(pc.children) != 1 || !pc.children[cc] { + t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) + } + pc.mu.Unlock() + } + + cancel() + + if pcok { + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) + } + pc.mu.Unlock() + } + + // child should be finished. + select { + case <-child.Done(): + default: + t.Errorf("<-child.Done() blocked, but shouldn't have") + } + if e := child.Err(); e != Canceled { + t.Errorf("child.Err() == %v want %v", e, Canceled) + } + + // parent should not be finished. + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + default: + } + if e := parent.Err(); e != nil { + t.Errorf("parent.Err() == %v want nil", e) + } + } +} + +func testDeadline(c Context, wait time.Duration, t *testing.T) { + select { + case <-time.After(wait): + t.Fatalf("context should have timed out") + case <-c.Done(): + } + if e := c.Err(); e != DeadlineExceeded { + t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) + } +} + +func TestDeadline(t *testing.T) { + c, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, 200*time.Millisecond, t) + + c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + o := otherContext{c} + testDeadline(o, 200*time.Millisecond, t) + + c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + o = otherContext{c} + c, _ = WithDeadline(o, time.Now().Add(300*time.Millisecond)) + testDeadline(c, 200*time.Millisecond, t) +} + +func TestTimeout(t *testing.T) { + c, _ := WithTimeout(Background(), 100*time.Millisecond) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, 200*time.Millisecond, t) + + c, _ = WithTimeout(Background(), 100*time.Millisecond) + o := otherContext{c} + testDeadline(o, 200*time.Millisecond, t) + + c, _ = WithTimeout(Background(), 100*time.Millisecond) + o = otherContext{c} + c, _ = WithTimeout(o, 300*time.Millisecond) + testDeadline(c, 200*time.Millisecond, t) +} + +func TestCanceledTimeout(t *testing.T) { + c, _ := WithTimeout(Background(), 200*time.Millisecond) + o := otherContext{c} + c, cancel := WithTimeout(o, 400*time.Millisecond) + cancel() + time.Sleep(100 * time.Millisecond) // let cancelation propagate + select { + case <-c.Done(): + default: + t.Errorf("<-c.Done() blocked, but shouldn't have") + } + if e := c.Err(); e != Canceled { + t.Errorf("c.Err() == %v want %v", e, Canceled) + } +} + +type key1 int +type key2 int + +var k1 = key1(1) +var k2 = key2(1) // same int as k1, different type +var k3 = key2(3) // same type as k2, different int + +func TestValues(t *testing.T) { + check := func(c Context, nm, v1, v2, v3 string) { + if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { + t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) + } + if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { + t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) + } + if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { + t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) + } + } + + c0 := Background() + check(c0, "c0", "", "", "") + + c1 := WithValue(Background(), k1, "c1k1") + check(c1, "c1", "c1k1", "", "") + + if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want { + t.Errorf("c.String() = %q want %q", got, want) + } + + c2 := WithValue(c1, k2, "c2k2") + check(c2, "c2", "c1k1", "c2k2", "") + + c3 := WithValue(c2, k3, "c3k3") + check(c3, "c2", "c1k1", "c2k2", "c3k3") + + c4 := WithValue(c3, k1, nil) + check(c4, "c4", "", "c2k2", "c3k3") + + o0 := otherContext{Background()} + check(o0, "o0", "", "", "") + + o1 := otherContext{WithValue(Background(), k1, "c1k1")} + check(o1, "o1", "c1k1", "", "") + + o2 := WithValue(o1, k2, "o2k2") + check(o2, "o2", "c1k1", "o2k2", "") + + o3 := otherContext{c4} + check(o3, "o3", "", "c2k2", "c3k3") + + o4 := WithValue(o3, k3, nil) + check(o4, "o4", "", "c2k2", "") +} + +func TestAllocs(t *testing.T) { + bg := Background() + for _, test := range []struct { + desc string + f func() + limit float64 + gccgoLimit float64 + }{ + { + desc: "Background()", + f: func() { Background() }, + limit: 0, + gccgoLimit: 0, + }, + { + desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), + f: func() { + c := WithValue(bg, k1, nil) + c.Value(k1) + }, + limit: 3, + gccgoLimit: 3, + }, + { + desc: "WithTimeout(bg, 15*time.Millisecond)", + f: func() { + c, _ := WithTimeout(bg, 15*time.Millisecond) + <-c.Done() + }, + limit: 8, + gccgoLimit: 15, + }, + { + desc: "WithCancel(bg)", + f: func() { + c, cancel := WithCancel(bg) + cancel() + <-c.Done() + }, + limit: 5, + gccgoLimit: 8, + }, + { + desc: "WithTimeout(bg, 100*time.Millisecond)", + f: func() { + c, cancel := WithTimeout(bg, 100*time.Millisecond) + cancel() + <-c.Done() + }, + limit: 8, + gccgoLimit: 25, + }, + } { + limit := test.limit + if runtime.Compiler == "gccgo" { + // gccgo does not yet do escape analysis. + // TOOD(iant): Remove this when gccgo does do escape analysis. + limit = test.gccgoLimit + } + if n := testing.AllocsPerRun(100, test.f); n > limit { + t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) + } + } +} + +func TestSimultaneousCancels(t *testing.T) { + root, cancel := WithCancel(Background()) + m := map[Context]CancelFunc{root: cancel} + q := []Context{root} + // Create a tree of contexts. + for len(q) != 0 && len(m) < 100 { + parent := q[0] + q = q[1:] + for i := 0; i < 4; i++ { + ctx, cancel := WithCancel(parent) + m[ctx] = cancel + q = append(q, ctx) + } + } + // Start all the cancels in a random order. + var wg sync.WaitGroup + wg.Add(len(m)) + for _, cancel := range m { + go func(cancel CancelFunc) { + cancel() + wg.Done() + }(cancel) + } + // Wait on all the contexts in a random order. + for ctx := range m { + select { + case <-ctx.Done(): + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n]) + } + } + // Wait for all the cancel functions to return. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n]) + } +} + +func TestInterlockedCancels(t *testing.T) { + parent, cancelParent := WithCancel(Background()) + child, cancelChild := WithCancel(parent) + go func() { + parent.Done() + cancelChild() + }() + cancelParent() + select { + case <-child.Done(): + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n]) + } +} + +func TestLayersCancel(t *testing.T) { + testLayers(t, time.Now().UnixNano(), false) +} + +func TestLayersTimeout(t *testing.T) { + testLayers(t, time.Now().UnixNano(), true) +} + +func testLayers(t *testing.T, seed int64, testTimeout bool) { + rand.Seed(seed) + errorf := func(format string, a ...interface{}) { + t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) + } + const ( + timeout = 200 * time.Millisecond + minLayers = 30 + ) + type value int + var ( + vals []*value + cancels []CancelFunc + numTimers int + ctx = Background() + ) + for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { + switch rand.Intn(3) { + case 0: + v := new(value) + ctx = WithValue(ctx, v, v) + vals = append(vals, v) + case 1: + var cancel CancelFunc + ctx, cancel = WithCancel(ctx) + cancels = append(cancels, cancel) + case 2: + var cancel CancelFunc + ctx, cancel = WithTimeout(ctx, timeout) + cancels = append(cancels, cancel) + numTimers++ + } + } + checkValues := func(when string) { + for _, key := range vals { + if val := ctx.Value(key).(*value); key != val { + errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) + } + } + } + select { + case <-ctx.Done(): + errorf("ctx should not be canceled yet") + default: + } + if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { + t.Errorf("ctx.String() = %q want prefix %q", s, prefix) + } + t.Log(ctx) + checkValues("before cancel") + if testTimeout { + select { + case <-ctx.Done(): + case <-time.After(timeout + timeout/10): + errorf("ctx should have timed out") + } + checkValues("after timeout") + } else { + cancel := cancels[rand.Intn(len(cancels))] + cancel() + select { + case <-ctx.Done(): + default: + errorf("ctx should be canceled") + } + checkValues("after cancel") + } +} + +func TestCancelRemoves(t *testing.T) { + checkChildren := func(when string, ctx Context, want int) { + if got := len(ctx.(*cancelCtx).children); got != want { + t.Errorf("%s: context has %d children, want %d", when, got, want) + } + } + + ctx, _ := WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel := WithCancel(ctx) + checkChildren("with WithCancel child ", ctx, 1) + cancel() + checkChildren("after cancelling WithCancel child", ctx, 0) + + ctx, _ = WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel = WithTimeout(ctx, 60*time.Minute) + checkChildren("with WithTimeout child ", ctx, 1) + cancel() + checkChildren("after cancelling WithTimeout child", ctx, 0) +} diff --git a/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/cancelreq.go b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/cancelreq.go new file mode 100644 index 0000000..48610e3 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/cancelreq.go @@ -0,0 +1,18 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.5 + +package ctxhttp + +import "net/http" + +func canceler(client *http.Client, req *http.Request) func() { + ch := make(chan struct{}) + req.Cancel = ch + + return func() { + close(ch) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/cancelreq_go14.go b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/cancelreq_go14.go new file mode 100644 index 0000000..56bcbad --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/cancelreq_go14.go @@ -0,0 +1,23 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.5 + +package ctxhttp + +import "net/http" + +type requestCanceler interface { + CancelRequest(*http.Request) +} + +func canceler(client *http.Client, req *http.Request) func() { + rc, ok := client.Transport.(requestCanceler) + if !ok { + return func() {} + } + return func() { + rc.CancelRequest(req) + } +} diff --git a/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/ctxhttp.go b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/ctxhttp.go new file mode 100644 index 0000000..9f34888 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/ctxhttp.go @@ -0,0 +1,79 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ctxhttp provides helper functions for performing context-aware HTTP requests. +package ctxhttp + +import ( + "io" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/context" +) + +// Do sends an HTTP request with the provided http.Client and returns an HTTP response. +// If the client is nil, http.DefaultClient is used. +// If the context is canceled or times out, ctx.Err() will be returned. +func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { + if client == nil { + client = http.DefaultClient + } + + // Request cancelation changed in Go 1.5, see cancelreq.go and cancelreq_go14.go. + cancel := canceler(client, req) + + type responseAndError struct { + resp *http.Response + err error + } + result := make(chan responseAndError, 1) + + go func() { + resp, err := client.Do(req) + result <- responseAndError{resp, err} + }() + + select { + case <-ctx.Done(): + cancel() + return nil, ctx.Err() + case r := <-result: + return r.resp, r.err + } +} + +// Get issues a GET request via the Do function. +func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Head issues a HEAD request via the Do function. +func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Post issues a POST request via the Do function. +func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return Do(ctx, client, req) +} + +// PostForm issues a POST request via the Do function. +func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { + return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} diff --git a/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/ctxhttp_test.go b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/ctxhttp_test.go new file mode 100644 index 0000000..47b53d7 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/context/ctxhttp/ctxhttp_test.go @@ -0,0 +1,72 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ctxhttp + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/net/context" +) + +const ( + requestDuration = 100 * time.Millisecond + requestBody = "ok" +) + +func TestNoTimeout(t *testing.T) { + ctx := context.Background() + resp, err := doRequest(ctx) + + if resp == nil || err != nil { + t.Fatalf("error received from client: %v %v", err, resp) + } +} +func TestCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(requestDuration / 2) + cancel() + }() + + resp, err := doRequest(ctx) + + if resp != nil || err == nil { + t.Fatalf("expected error, didn't get one. resp: %v", resp) + } + if err != ctx.Err() { + t.Fatalf("expected error from context but got: %v", err) + } +} + +func TestCancelAfterRequest(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + resp, err := doRequest(ctx) + + // Cancel before reading the body. + // Request.Body should still be readable after the context is canceled. + cancel() + + b, err := ioutil.ReadAll(resp.Body) + if err != nil || string(b) != requestBody { + t.Fatalf("could not read body: %q %v", b, err) + } +} + +func doRequest(ctx context.Context) (*http.Response, error) { + var okHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(requestDuration) + w.Write([]byte(requestBody)) + }) + + serv := httptest.NewServer(okHandler) + defer serv.Close() + + return Get(ctx, nil, serv.URL) +} diff --git a/Godeps/_workspace/src/golang.org/x/net/context/withtimeout_test.go b/Godeps/_workspace/src/golang.org/x/net/context/withtimeout_test.go new file mode 100644 index 0000000..a6754dc --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/context/withtimeout_test.go @@ -0,0 +1,26 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + "fmt" + "time" + + "golang.org/x/net/context" +) + +func ExampleWithTimeout() { + // Pass a context with a timeout to tell a blocking function that it + // should abandon its work after the timeout elapses. + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + select { + case <-time.After(200 * time.Millisecond): + fmt.Println("overslept") + case <-ctx.Done(): + fmt.Println(ctx.Err()) // prints "context deadline exceeded" + } + // Output: + // context deadline exceeded +} diff --git a/Godeps/_workspace/src/golang.org/x/net/internal/timeseries/timeseries.go b/Godeps/_workspace/src/golang.org/x/net/internal/timeseries/timeseries.go new file mode 100644 index 0000000..3f90b73 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/internal/timeseries/timeseries.go @@ -0,0 +1,525 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package timeseries implements a time series structure for stats collection. +package timeseries + +import ( + "fmt" + "log" + "time" +) + +const ( + timeSeriesNumBuckets = 64 + minuteHourSeriesNumBuckets = 60 +) + +var timeSeriesResolutions = []time.Duration{ + 1 * time.Second, + 10 * time.Second, + 1 * time.Minute, + 10 * time.Minute, + 1 * time.Hour, + 6 * time.Hour, + 24 * time.Hour, // 1 day + 7 * 24 * time.Hour, // 1 week + 4 * 7 * 24 * time.Hour, // 4 weeks + 16 * 7 * 24 * time.Hour, // 16 weeks +} + +var minuteHourSeriesResolutions = []time.Duration{ + 1 * time.Second, + 1 * time.Minute, +} + +// An Observable is a kind of data that can be aggregated in a time series. +type Observable interface { + Multiply(ratio float64) // Multiplies the data in self by a given ratio + Add(other Observable) // Adds the data from a different observation to self + Clear() // Clears the observation so it can be reused. + CopyFrom(other Observable) // Copies the contents of a given observation to self +} + +// Float attaches the methods of Observable to a float64. +type Float float64 + +// NewFloat returns a Float. +func NewFloat() Observable { + f := Float(0) + return &f +} + +// String returns the float as a string. +func (f *Float) String() string { return fmt.Sprintf("%g", f.Value()) } + +// Value returns the float's value. +func (f *Float) Value() float64 { return float64(*f) } + +func (f *Float) Multiply(ratio float64) { *f *= Float(ratio) } + +func (f *Float) Add(other Observable) { + o := other.(*Float) + *f += *o +} + +func (f *Float) Clear() { *f = 0 } + +func (f *Float) CopyFrom(other Observable) { + o := other.(*Float) + *f = *o +} + +// A Clock tells the current time. +type Clock interface { + Time() time.Time +} + +type defaultClock int + +var defaultClockInstance defaultClock + +func (defaultClock) Time() time.Time { return time.Now() } + +// Information kept per level. Each level consists of a circular list of +// observations. The start of the level may be derived from end and the +// len(buckets) * sizeInMillis. +type tsLevel struct { + oldest int // index to oldest bucketed Observable + newest int // index to newest bucketed Observable + end time.Time // end timestamp for this level + size time.Duration // duration of the bucketed Observable + buckets []Observable // collections of observations + provider func() Observable // used for creating new Observable +} + +func (l *tsLevel) Clear() { + l.oldest = 0 + l.newest = len(l.buckets) - 1 + l.end = time.Time{} + for i := range l.buckets { + if l.buckets[i] != nil { + l.buckets[i].Clear() + l.buckets[i] = nil + } + } +} + +func (l *tsLevel) InitLevel(size time.Duration, numBuckets int, f func() Observable) { + l.size = size + l.provider = f + l.buckets = make([]Observable, numBuckets) +} + +// Keeps a sequence of levels. Each level is responsible for storing data at +// a given resolution. For example, the first level stores data at a one +// minute resolution while the second level stores data at a one hour +// resolution. + +// Each level is represented by a sequence of buckets. Each bucket spans an +// interval equal to the resolution of the level. New observations are added +// to the last bucket. +type timeSeries struct { + provider func() Observable // make more Observable + numBuckets int // number of buckets in each level + levels []*tsLevel // levels of bucketed Observable + lastAdd time.Time // time of last Observable tracked + total Observable // convenient aggregation of all Observable + clock Clock // Clock for getting current time + pending Observable // observations not yet bucketed + pendingTime time.Time // what time are we keeping in pending + dirty bool // if there are pending observations +} + +// init initializes a level according to the supplied criteria. +func (ts *timeSeries) init(resolutions []time.Duration, f func() Observable, numBuckets int, clock Clock) { + ts.provider = f + ts.numBuckets = numBuckets + ts.clock = clock + ts.levels = make([]*tsLevel, len(resolutions)) + + for i := range resolutions { + if i > 0 && resolutions[i-1] >= resolutions[i] { + log.Print("timeseries: resolutions must be monotonically increasing") + break + } + newLevel := new(tsLevel) + newLevel.InitLevel(resolutions[i], ts.numBuckets, ts.provider) + ts.levels[i] = newLevel + } + + ts.Clear() +} + +// Clear removes all observations from the time series. +func (ts *timeSeries) Clear() { + ts.lastAdd = time.Time{} + ts.total = ts.resetObservation(ts.total) + ts.pending = ts.resetObservation(ts.pending) + ts.pendingTime = time.Time{} + ts.dirty = false + + for i := range ts.levels { + ts.levels[i].Clear() + } +} + +// Add records an observation at the current time. +func (ts *timeSeries) Add(observation Observable) { + ts.AddWithTime(observation, ts.clock.Time()) +} + +// AddWithTime records an observation at the specified time. +func (ts *timeSeries) AddWithTime(observation Observable, t time.Time) { + + smallBucketDuration := ts.levels[0].size + + if t.After(ts.lastAdd) { + ts.lastAdd = t + } + + if t.After(ts.pendingTime) { + ts.advance(t) + ts.mergePendingUpdates() + ts.pendingTime = ts.levels[0].end + ts.pending.CopyFrom(observation) + ts.dirty = true + } else if t.After(ts.pendingTime.Add(-1 * smallBucketDuration)) { + // The observation is close enough to go into the pending bucket. + // This compensates for clock skewing and small scheduling delays + // by letting the update stay in the fast path. + ts.pending.Add(observation) + ts.dirty = true + } else { + ts.mergeValue(observation, t) + } +} + +// mergeValue inserts the observation at the specified time in the past into all levels. +func (ts *timeSeries) mergeValue(observation Observable, t time.Time) { + for _, level := range ts.levels { + index := (ts.numBuckets - 1) - int(level.end.Sub(t)/level.size) + if 0 <= index && index < ts.numBuckets { + bucketNumber := (level.oldest + index) % ts.numBuckets + if level.buckets[bucketNumber] == nil { + level.buckets[bucketNumber] = level.provider() + } + level.buckets[bucketNumber].Add(observation) + } + } + ts.total.Add(observation) +} + +// mergePendingUpdates applies the pending updates into all levels. +func (ts *timeSeries) mergePendingUpdates() { + if ts.dirty { + ts.mergeValue(ts.pending, ts.pendingTime) + ts.pending = ts.resetObservation(ts.pending) + ts.dirty = false + } +} + +// advance cycles the buckets at each level until the latest bucket in +// each level can hold the time specified. +func (ts *timeSeries) advance(t time.Time) { + if !t.After(ts.levels[0].end) { + return + } + for i := 0; i < len(ts.levels); i++ { + level := ts.levels[i] + if !level.end.Before(t) { + break + } + + // If the time is sufficiently far, just clear the level and advance + // directly. + if !t.Before(level.end.Add(level.size * time.Duration(ts.numBuckets))) { + for _, b := range level.buckets { + ts.resetObservation(b) + } + level.end = time.Unix(0, (t.UnixNano()/level.size.Nanoseconds())*level.size.Nanoseconds()) + } + + for t.After(level.end) { + level.end = level.end.Add(level.size) + level.newest = level.oldest + level.oldest = (level.oldest + 1) % ts.numBuckets + ts.resetObservation(level.buckets[level.newest]) + } + + t = level.end + } +} + +// Latest returns the sum of the num latest buckets from the level. +func (ts *timeSeries) Latest(level, num int) Observable { + now := ts.clock.Time() + if ts.levels[0].end.Before(now) { + ts.advance(now) + } + + ts.mergePendingUpdates() + + result := ts.provider() + l := ts.levels[level] + index := l.newest + + for i := 0; i < num; i++ { + if l.buckets[index] != nil { + result.Add(l.buckets[index]) + } + if index == 0 { + index = ts.numBuckets + } + index-- + } + + return result +} + +// LatestBuckets returns a copy of the num latest buckets from level. +func (ts *timeSeries) LatestBuckets(level, num int) []Observable { + if level < 0 || level > len(ts.levels) { + log.Print("timeseries: bad level argument: ", level) + return nil + } + if num < 0 || num >= ts.numBuckets { + log.Print("timeseries: bad num argument: ", num) + return nil + } + + results := make([]Observable, num) + now := ts.clock.Time() + if ts.levels[0].end.Before(now) { + ts.advance(now) + } + + ts.mergePendingUpdates() + + l := ts.levels[level] + index := l.newest + + for i := 0; i < num; i++ { + result := ts.provider() + results[i] = result + if l.buckets[index] != nil { + result.CopyFrom(l.buckets[index]) + } + + if index == 0 { + index = ts.numBuckets + } + index -= 1 + } + return results +} + +// ScaleBy updates observations by scaling by factor. +func (ts *timeSeries) ScaleBy(factor float64) { + for _, l := range ts.levels { + for i := 0; i < ts.numBuckets; i++ { + l.buckets[i].Multiply(factor) + } + } + + ts.total.Multiply(factor) + ts.pending.Multiply(factor) +} + +// Range returns the sum of observations added over the specified time range. +// If start or finish times don't fall on bucket boundaries of the same +// level, then return values are approximate answers. +func (ts *timeSeries) Range(start, finish time.Time) Observable { + return ts.ComputeRange(start, finish, 1)[0] +} + +// Recent returns the sum of observations from the last delta. +func (ts *timeSeries) Recent(delta time.Duration) Observable { + now := ts.clock.Time() + return ts.Range(now.Add(-delta), now) +} + +// Total returns the total of all observations. +func (ts *timeSeries) Total() Observable { + ts.mergePendingUpdates() + return ts.total +} + +// ComputeRange computes a specified number of values into a slice using +// the observations recorded over the specified time period. The return +// values are approximate if the start or finish times don't fall on the +// bucket boundaries at the same level or if the number of buckets spanning +// the range is not an integral multiple of num. +func (ts *timeSeries) ComputeRange(start, finish time.Time, num int) []Observable { + if start.After(finish) { + log.Printf("timeseries: start > finish, %v>%v", start, finish) + return nil + } + + if num < 0 { + log.Printf("timeseries: num < 0, %v", num) + return nil + } + + results := make([]Observable, num) + + for _, l := range ts.levels { + if !start.Before(l.end.Add(-l.size * time.Duration(ts.numBuckets))) { + ts.extract(l, start, finish, num, results) + return results + } + } + + // Failed to find a level that covers the desired range. So just + // extract from the last level, even if it doesn't cover the entire + // desired range. + ts.extract(ts.levels[len(ts.levels)-1], start, finish, num, results) + + return results +} + +// RecentList returns the specified number of values in slice over the most +// recent time period of the specified range. +func (ts *timeSeries) RecentList(delta time.Duration, num int) []Observable { + if delta < 0 { + return nil + } + now := ts.clock.Time() + return ts.ComputeRange(now.Add(-delta), now, num) +} + +// extract returns a slice of specified number of observations from a given +// level over a given range. +func (ts *timeSeries) extract(l *tsLevel, start, finish time.Time, num int, results []Observable) { + ts.mergePendingUpdates() + + srcInterval := l.size + dstInterval := finish.Sub(start) / time.Duration(num) + dstStart := start + srcStart := l.end.Add(-srcInterval * time.Duration(ts.numBuckets)) + + srcIndex := 0 + + // Where should scanning start? + if dstStart.After(srcStart) { + advance := dstStart.Sub(srcStart) / srcInterval + srcIndex += int(advance) + srcStart = srcStart.Add(advance * srcInterval) + } + + // The i'th value is computed as show below. + // interval = (finish/start)/num + // i'th value = sum of observation in range + // [ start + i * interval, + // start + (i + 1) * interval ) + for i := 0; i < num; i++ { + results[i] = ts.resetObservation(results[i]) + dstEnd := dstStart.Add(dstInterval) + for srcIndex < ts.numBuckets && srcStart.Before(dstEnd) { + srcEnd := srcStart.Add(srcInterval) + if srcEnd.After(ts.lastAdd) { + srcEnd = ts.lastAdd + } + + if !srcEnd.Before(dstStart) { + srcValue := l.buckets[(srcIndex+l.oldest)%ts.numBuckets] + if !srcStart.Before(dstStart) && !srcEnd.After(dstEnd) { + // dst completely contains src. + if srcValue != nil { + results[i].Add(srcValue) + } + } else { + // dst partially overlaps src. + overlapStart := maxTime(srcStart, dstStart) + overlapEnd := minTime(srcEnd, dstEnd) + base := srcEnd.Sub(srcStart) + fraction := overlapEnd.Sub(overlapStart).Seconds() / base.Seconds() + + used := ts.provider() + if srcValue != nil { + used.CopyFrom(srcValue) + } + used.Multiply(fraction) + results[i].Add(used) + } + + if srcEnd.After(dstEnd) { + break + } + } + srcIndex++ + srcStart = srcStart.Add(srcInterval) + } + dstStart = dstStart.Add(dstInterval) + } +} + +// resetObservation clears the content so the struct may be reused. +func (ts *timeSeries) resetObservation(observation Observable) Observable { + if observation == nil { + observation = ts.provider() + } else { + observation.Clear() + } + return observation +} + +// TimeSeries tracks data at granularities from 1 second to 16 weeks. +type TimeSeries struct { + timeSeries +} + +// NewTimeSeries creates a new TimeSeries using the function provided for creating new Observable. +func NewTimeSeries(f func() Observable) *TimeSeries { + return NewTimeSeriesWithClock(f, defaultClockInstance) +} + +// NewTimeSeriesWithClock creates a new TimeSeries using the function provided for creating new Observable and the clock for +// assigning timestamps. +func NewTimeSeriesWithClock(f func() Observable, clock Clock) *TimeSeries { + ts := new(TimeSeries) + ts.timeSeries.init(timeSeriesResolutions, f, timeSeriesNumBuckets, clock) + return ts +} + +// MinuteHourSeries tracks data at granularities of 1 minute and 1 hour. +type MinuteHourSeries struct { + timeSeries +} + +// NewMinuteHourSeries creates a new MinuteHourSeries using the function provided for creating new Observable. +func NewMinuteHourSeries(f func() Observable) *MinuteHourSeries { + return NewMinuteHourSeriesWithClock(f, defaultClockInstance) +} + +// NewMinuteHourSeriesWithClock creates a new MinuteHourSeries using the function provided for creating new Observable and the clock for +// assigning timestamps. +func NewMinuteHourSeriesWithClock(f func() Observable, clock Clock) *MinuteHourSeries { + ts := new(MinuteHourSeries) + ts.timeSeries.init(minuteHourSeriesResolutions, f, + minuteHourSeriesNumBuckets, clock) + return ts +} + +func (ts *MinuteHourSeries) Minute() Observable { + return ts.timeSeries.Latest(0, 60) +} + +func (ts *MinuteHourSeries) Hour() Observable { + return ts.timeSeries.Latest(1, 60) +} + +func minTime(a, b time.Time) time.Time { + if a.Before(b) { + return a + } + return b +} + +func maxTime(a, b time.Time) time.Time { + if a.After(b) { + return a + } + return b +} diff --git a/Godeps/_workspace/src/golang.org/x/net/internal/timeseries/timeseries_test.go b/Godeps/_workspace/src/golang.org/x/net/internal/timeseries/timeseries_test.go new file mode 100644 index 0000000..66325a9 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/internal/timeseries/timeseries_test.go @@ -0,0 +1,170 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package timeseries + +import ( + "math" + "testing" + "time" +) + +func isNear(x *Float, y float64, tolerance float64) bool { + return math.Abs(x.Value()-y) < tolerance +} + +func isApproximate(x *Float, y float64) bool { + return isNear(x, y, 1e-2) +} + +func checkApproximate(t *testing.T, o Observable, y float64) { + x := o.(*Float) + if !isApproximate(x, y) { + t.Errorf("Wanted %g, got %g", y, x.Value()) + } +} + +func checkNear(t *testing.T, o Observable, y, tolerance float64) { + x := o.(*Float) + if !isNear(x, y, tolerance) { + t.Errorf("Wanted %g +- %g, got %g", y, tolerance, x.Value()) + } +} + +var baseTime = time.Date(2013, 1, 1, 0, 0, 0, 0, time.UTC) + +func tu(s int64) time.Time { + return baseTime.Add(time.Duration(s) * time.Second) +} + +func tu2(s int64, ns int64) time.Time { + return baseTime.Add(time.Duration(s)*time.Second + time.Duration(ns)*time.Nanosecond) +} + +func TestBasicTimeSeries(t *testing.T) { + ts := NewTimeSeries(NewFloat) + fo := new(Float) + *fo = Float(10) + ts.AddWithTime(fo, tu(1)) + ts.AddWithTime(fo, tu(1)) + ts.AddWithTime(fo, tu(1)) + ts.AddWithTime(fo, tu(1)) + checkApproximate(t, ts.Range(tu(0), tu(1)), 40) + checkApproximate(t, ts.Total(), 40) + ts.AddWithTime(fo, tu(3)) + ts.AddWithTime(fo, tu(3)) + ts.AddWithTime(fo, tu(3)) + checkApproximate(t, ts.Range(tu(0), tu(2)), 40) + checkApproximate(t, ts.Range(tu(2), tu(4)), 30) + checkApproximate(t, ts.Total(), 70) + ts.AddWithTime(fo, tu(1)) + ts.AddWithTime(fo, tu(1)) + checkApproximate(t, ts.Range(tu(0), tu(2)), 60) + checkApproximate(t, ts.Range(tu(2), tu(4)), 30) + checkApproximate(t, ts.Total(), 90) + *fo = Float(100) + ts.AddWithTime(fo, tu(100)) + checkApproximate(t, ts.Range(tu(99), tu(100)), 100) + checkApproximate(t, ts.Range(tu(0), tu(4)), 36) + checkApproximate(t, ts.Total(), 190) + *fo = Float(10) + ts.AddWithTime(fo, tu(1)) + ts.AddWithTime(fo, tu(1)) + checkApproximate(t, ts.Range(tu(0), tu(4)), 44) + checkApproximate(t, ts.Range(tu(37), tu2(100, 100e6)), 100) + checkApproximate(t, ts.Range(tu(50), tu2(100, 100e6)), 100) + checkApproximate(t, ts.Range(tu(99), tu2(100, 100e6)), 100) + checkApproximate(t, ts.Total(), 210) + + for i, l := range ts.ComputeRange(tu(36), tu(100), 64) { + if i == 63 { + checkApproximate(t, l, 100) + } else { + checkApproximate(t, l, 0) + } + } + + checkApproximate(t, ts.Range(tu(0), tu(100)), 210) + checkApproximate(t, ts.Range(tu(10), tu(100)), 100) + + for i, l := range ts.ComputeRange(tu(0), tu(100), 100) { + if i < 10 { + checkApproximate(t, l, 11) + } else if i >= 90 { + checkApproximate(t, l, 10) + } else { + checkApproximate(t, l, 0) + } + } +} + +func TestFloat(t *testing.T) { + f := Float(1) + if g, w := f.String(), "1"; g != w { + t.Errorf("Float(1).String = %q; want %q", g, w) + } + f2 := Float(2) + var o Observable = &f2 + f.Add(o) + if g, w := f.Value(), 3.0; g != w { + t.Errorf("Float post-add = %v; want %v", g, w) + } + f.Multiply(2) + if g, w := f.Value(), 6.0; g != w { + t.Errorf("Float post-multiply = %v; want %v", g, w) + } + f.Clear() + if g, w := f.Value(), 0.0; g != w { + t.Errorf("Float post-clear = %v; want %v", g, w) + } + f.CopyFrom(&f2) + if g, w := f.Value(), 2.0; g != w { + t.Errorf("Float post-CopyFrom = %v; want %v", g, w) + } +} + +type mockClock struct { + time time.Time +} + +func (m *mockClock) Time() time.Time { return m.time } +func (m *mockClock) Set(t time.Time) { m.time = t } + +const buckets = 6 + +var testResolutions = []time.Duration{ + 10 * time.Second, // level holds one minute of observations + 100 * time.Second, // level holds ten minutes of observations + 10 * time.Minute, // level holds one hour of observations +} + +// TestTimeSeries uses a small number of buckets to force a higher +// error rate on approximations from the timeseries. +type TestTimeSeries struct { + timeSeries +} + +func TestExpectedErrorRate(t *testing.T) { + ts := new(TestTimeSeries) + fake := new(mockClock) + fake.Set(time.Now()) + ts.timeSeries.init(testResolutions, NewFloat, buckets, fake) + for i := 1; i <= 61*61; i++ { + fake.Set(fake.Time().Add(1 * time.Second)) + ob := Float(1) + ts.AddWithTime(&ob, fake.Time()) + + // The results should be accurate within one missing bucket (1/6) of the observations recorded. + checkNear(t, ts.Latest(0, buckets), min(float64(i), 60), 10) + checkNear(t, ts.Latest(1, buckets), min(float64(i), 600), 100) + checkNear(t, ts.Latest(2, buckets), min(float64(i), 3600), 600) + } +} + +func min(a, b float64) float64 { + if a < b { + return a + } + return b +} diff --git a/Godeps/_workspace/src/golang.org/x/net/trace/events.go b/Godeps/_workspace/src/golang.org/x/net/trace/events.go new file mode 100644 index 0000000..e66c7e3 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/trace/events.go @@ -0,0 +1,524 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package trace + +import ( + "bytes" + "fmt" + "html/template" + "io" + "log" + "net/http" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "text/tabwriter" + "time" +) + +var eventsTmpl = template.Must(template.New("events").Funcs(template.FuncMap{ + "elapsed": elapsed, + "trimSpace": strings.TrimSpace, +}).Parse(eventsHTML)) + +const maxEventsPerLog = 100 + +type bucket struct { + MaxErrAge time.Duration + String string +} + +var buckets = []bucket{ + {0, "total"}, + {10 * time.Second, "errs<10s"}, + {1 * time.Minute, "errs<1m"}, + {10 * time.Minute, "errs<10m"}, + {1 * time.Hour, "errs<1h"}, + {10 * time.Hour, "errs<10h"}, + {24000 * time.Hour, "errors"}, +} + +// RenderEvents renders the HTML page typically served at /debug/events. +// It does not do any auth checking; see AuthRequest for the default auth check +// used by the handler registered on http.DefaultServeMux. +// req may be nil. +func RenderEvents(w http.ResponseWriter, req *http.Request, sensitive bool) { + now := time.Now() + data := &struct { + Families []string // family names + Buckets []bucket + Counts [][]int // eventLog count per family/bucket + + // Set when a bucket has been selected. + Family string + Bucket int + EventLogs eventLogs + Expanded bool + }{ + Buckets: buckets, + } + + data.Families = make([]string, 0, len(families)) + famMu.RLock() + for name := range families { + data.Families = append(data.Families, name) + } + famMu.RUnlock() + sort.Strings(data.Families) + + // Count the number of eventLogs in each family for each error age. + data.Counts = make([][]int, len(data.Families)) + for i, name := range data.Families { + // TODO(sameer): move this loop under the family lock. + f := getEventFamily(name) + data.Counts[i] = make([]int, len(data.Buckets)) + for j, b := range data.Buckets { + data.Counts[i][j] = f.Count(now, b.MaxErrAge) + } + } + + if req != nil { + var ok bool + data.Family, data.Bucket, ok = parseEventsArgs(req) + if !ok { + // No-op + } else { + data.EventLogs = getEventFamily(data.Family).Copy(now, buckets[data.Bucket].MaxErrAge) + } + if data.EventLogs != nil { + defer data.EventLogs.Free() + sort.Sort(data.EventLogs) + } + if exp, err := strconv.ParseBool(req.FormValue("exp")); err == nil { + data.Expanded = exp + } + } + + famMu.RLock() + defer famMu.RUnlock() + if err := eventsTmpl.Execute(w, data); err != nil { + log.Printf("net/trace: Failed executing template: %v", err) + } +} + +func parseEventsArgs(req *http.Request) (fam string, b int, ok bool) { + fam, bStr := req.FormValue("fam"), req.FormValue("b") + if fam == "" || bStr == "" { + return "", 0, false + } + b, err := strconv.Atoi(bStr) + if err != nil || b < 0 || b >= len(buckets) { + return "", 0, false + } + return fam, b, true +} + +// An EventLog provides a log of events associated with a specific object. +type EventLog interface { + // Printf formats its arguments with fmt.Sprintf and adds the + // result to the event log. + Printf(format string, a ...interface{}) + + // Errorf is like Printf, but it marks this event as an error. + Errorf(format string, a ...interface{}) + + // Finish declares that this event log is complete. + // The event log should not be used after calling this method. + Finish() +} + +// NewEventLog returns a new EventLog with the specified family name +// and title. +func NewEventLog(family, title string) EventLog { + el := newEventLog() + el.ref() + el.Family, el.Title = family, title + el.Start = time.Now() + el.events = make([]logEntry, 0, maxEventsPerLog) + el.stack = make([]uintptr, 32) + n := runtime.Callers(2, el.stack) + el.stack = el.stack[:n] + + getEventFamily(family).add(el) + return el +} + +func (el *eventLog) Finish() { + getEventFamily(el.Family).remove(el) + el.unref() // matches ref in New +} + +var ( + famMu sync.RWMutex + families = make(map[string]*eventFamily) // family name => family +) + +func getEventFamily(fam string) *eventFamily { + famMu.Lock() + defer famMu.Unlock() + f := families[fam] + if f == nil { + f = &eventFamily{} + families[fam] = f + } + return f +} + +type eventFamily struct { + mu sync.RWMutex + eventLogs eventLogs +} + +func (f *eventFamily) add(el *eventLog) { + f.mu.Lock() + f.eventLogs = append(f.eventLogs, el) + f.mu.Unlock() +} + +func (f *eventFamily) remove(el *eventLog) { + f.mu.Lock() + defer f.mu.Unlock() + for i, el0 := range f.eventLogs { + if el == el0 { + copy(f.eventLogs[i:], f.eventLogs[i+1:]) + f.eventLogs = f.eventLogs[:len(f.eventLogs)-1] + return + } + } +} + +func (f *eventFamily) Count(now time.Time, maxErrAge time.Duration) (n int) { + f.mu.RLock() + defer f.mu.RUnlock() + for _, el := range f.eventLogs { + if el.hasRecentError(now, maxErrAge) { + n++ + } + } + return +} + +func (f *eventFamily) Copy(now time.Time, maxErrAge time.Duration) (els eventLogs) { + f.mu.RLock() + defer f.mu.RUnlock() + els = make(eventLogs, 0, len(f.eventLogs)) + for _, el := range f.eventLogs { + if el.hasRecentError(now, maxErrAge) { + el.ref() + els = append(els, el) + } + } + return +} + +type eventLogs []*eventLog + +// Free calls unref on each element of the list. +func (els eventLogs) Free() { + for _, el := range els { + el.unref() + } +} + +// eventLogs may be sorted in reverse chronological order. +func (els eventLogs) Len() int { return len(els) } +func (els eventLogs) Less(i, j int) bool { return els[i].Start.After(els[j].Start) } +func (els eventLogs) Swap(i, j int) { els[i], els[j] = els[j], els[i] } + +// A logEntry is a timestamped log entry in an event log. +type logEntry struct { + When time.Time + Elapsed time.Duration // since previous event in log + NewDay bool // whether this event is on a different day to the previous event + What string + IsErr bool +} + +// WhenString returns a string representation of the elapsed time of the event. +// It will include the date if midnight was crossed. +func (e logEntry) WhenString() string { + if e.NewDay { + return e.When.Format("2006/01/02 15:04:05.000000") + } + return e.When.Format("15:04:05.000000") +} + +// An eventLog represents an active event log. +type eventLog struct { + // Family is the top-level grouping of event logs to which this belongs. + Family string + + // Title is the title of this event log. + Title string + + // Timing information. + Start time.Time + + // Call stack where this event log was created. + stack []uintptr + + // Append-only sequence of events. + // + // TODO(sameer): change this to a ring buffer to avoid the array copy + // when we hit maxEventsPerLog. + mu sync.RWMutex + events []logEntry + LastErrorTime time.Time + discarded int + + refs int32 // how many buckets this is in +} + +func (el *eventLog) reset() { + // Clear all but the mutex. Mutexes may not be copied, even when unlocked. + el.Family = "" + el.Title = "" + el.Start = time.Time{} + el.stack = nil + el.events = nil + el.LastErrorTime = time.Time{} + el.discarded = 0 + el.refs = 0 +} + +func (el *eventLog) hasRecentError(now time.Time, maxErrAge time.Duration) bool { + if maxErrAge == 0 { + return true + } + el.mu.RLock() + defer el.mu.RUnlock() + return now.Sub(el.LastErrorTime) < maxErrAge +} + +// delta returns the elapsed time since the last event or the log start, +// and whether it spans midnight. +// L >= el.mu +func (el *eventLog) delta(t time.Time) (time.Duration, bool) { + if len(el.events) == 0 { + return t.Sub(el.Start), false + } + prev := el.events[len(el.events)-1].When + return t.Sub(prev), prev.Day() != t.Day() + +} + +func (el *eventLog) Printf(format string, a ...interface{}) { + el.printf(false, format, a...) +} + +func (el *eventLog) Errorf(format string, a ...interface{}) { + el.printf(true, format, a...) +} + +func (el *eventLog) printf(isErr bool, format string, a ...interface{}) { + e := logEntry{When: time.Now(), IsErr: isErr, What: fmt.Sprintf(format, a...)} + el.mu.Lock() + e.Elapsed, e.NewDay = el.delta(e.When) + if len(el.events) < maxEventsPerLog { + el.events = append(el.events, e) + } else { + // Discard the oldest event. + if el.discarded == 0 { + // el.discarded starts at two to count for the event it + // is replacing, plus the next one that we are about to + // drop. + el.discarded = 2 + } else { + el.discarded++ + } + // TODO(sameer): if this causes allocations on a critical path, + // change eventLog.What to be a fmt.Stringer, as in trace.go. + el.events[0].What = fmt.Sprintf("(%d events discarded)", el.discarded) + // The timestamp of the discarded meta-event should be + // the time of the last event it is representing. + el.events[0].When = el.events[1].When + copy(el.events[1:], el.events[2:]) + el.events[maxEventsPerLog-1] = e + } + if e.IsErr { + el.LastErrorTime = e.When + } + el.mu.Unlock() +} + +func (el *eventLog) ref() { + atomic.AddInt32(&el.refs, 1) +} + +func (el *eventLog) unref() { + if atomic.AddInt32(&el.refs, -1) == 0 { + freeEventLog(el) + } +} + +func (el *eventLog) When() string { + return el.Start.Format("2006/01/02 15:04:05.000000") +} + +func (el *eventLog) ElapsedTime() string { + elapsed := time.Since(el.Start) + return fmt.Sprintf("%.6f", elapsed.Seconds()) +} + +func (el *eventLog) Stack() string { + buf := new(bytes.Buffer) + tw := tabwriter.NewWriter(buf, 1, 8, 1, '\t', 0) + printStackRecord(tw, el.stack) + tw.Flush() + return buf.String() +} + +// printStackRecord prints the function + source line information +// for a single stack trace. +// Adapted from runtime/pprof/pprof.go. +func printStackRecord(w io.Writer, stk []uintptr) { + for _, pc := range stk { + f := runtime.FuncForPC(pc) + if f == nil { + continue + } + file, line := f.FileLine(pc) + name := f.Name() + // Hide runtime.goexit and any runtime functions at the beginning. + if strings.HasPrefix(name, "runtime.") { + continue + } + fmt.Fprintf(w, "# %s\t%s:%d\n", name, file, line) + } +} + +func (el *eventLog) Events() []logEntry { + el.mu.RLock() + defer el.mu.RUnlock() + return el.events +} + +// freeEventLogs is a freelist of *eventLog +var freeEventLogs = make(chan *eventLog, 1000) + +// newEventLog returns a event log ready to use. +func newEventLog() *eventLog { + select { + case el := <-freeEventLogs: + return el + default: + return new(eventLog) + } +} + +// freeEventLog adds el to freeEventLogs if there's room. +// This is non-blocking. +func freeEventLog(el *eventLog) { + el.reset() + select { + case freeEventLogs <- el: + default: + } +} + +const eventsHTML = ` + + + events + + + + +

/debug/events

+ + + {{range $i, $fam := .Families}} + + + + {{range $j, $bucket := $.Buckets}} + {{$n := index $.Counts $i $j}} + + {{end}} + + {{end}} +
{{$fam}} + {{if $n}}{{end}} + [{{$n}} {{$bucket.String}}] + {{if $n}}{{end}} +
+ +{{if $.EventLogs}} +
+

Family: {{$.Family}}

+ +{{if $.Expanded}}{{end}} +[Summary]{{if $.Expanded}}{{end}} + +{{if not $.Expanded}}{{end}} +[Expanded]{{if not $.Expanded}}{{end}} + + + + {{range $el := $.EventLogs}} + + + + + {{if $.Expanded}} + + + + + + {{range $el.Events}} + + + + + + {{end}} + {{end}} + {{end}} +
WhenElapsed
{{$el.When}}{{$el.ElapsedTime}}{{$el.Title}} +
{{$el.Stack|trimSpace}}
{{.WhenString}}{{elapsed .Elapsed}}.{{if .IsErr}}E{{else}}.{{end}}. {{.What}}
+{{end}} + + +` diff --git a/Godeps/_workspace/src/golang.org/x/net/trace/histogram.go b/Godeps/_workspace/src/golang.org/x/net/trace/histogram.go new file mode 100644 index 0000000..bb42aa5 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/trace/histogram.go @@ -0,0 +1,356 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package trace + +// This file implements histogramming for RPC statistics collection. + +import ( + "bytes" + "fmt" + "html/template" + "log" + "math" + + "golang.org/x/net/internal/timeseries" +) + +const ( + bucketCount = 38 +) + +// histogram keeps counts of values in buckets that are spaced +// out in powers of 2: 0-1, 2-3, 4-7... +// histogram implements timeseries.Observable +type histogram struct { + sum int64 // running total of measurements + sumOfSquares float64 // square of running total + buckets []int64 // bucketed values for histogram + value int // holds a single value as an optimization + valueCount int64 // number of values recorded for single value +} + +// AddMeasurement records a value measurement observation to the histogram. +func (h *histogram) addMeasurement(value int64) { + // TODO: assert invariant + h.sum += value + h.sumOfSquares += float64(value) * float64(value) + + bucketIndex := getBucket(value) + + if h.valueCount == 0 || (h.valueCount > 0 && h.value == bucketIndex) { + h.value = bucketIndex + h.valueCount++ + } else { + h.allocateBuckets() + h.buckets[bucketIndex]++ + } +} + +func (h *histogram) allocateBuckets() { + if h.buckets == nil { + h.buckets = make([]int64, bucketCount) + h.buckets[h.value] = h.valueCount + h.value = 0 + h.valueCount = -1 + } +} + +func log2(i int64) int { + n := 0 + for ; i >= 0x100; i >>= 8 { + n += 8 + } + for ; i > 0; i >>= 1 { + n += 1 + } + return n +} + +func getBucket(i int64) (index int) { + index = log2(i) - 1 + if index < 0 { + index = 0 + } + if index >= bucketCount { + index = bucketCount - 1 + } + return +} + +// Total returns the number of recorded observations. +func (h *histogram) total() (total int64) { + if h.valueCount >= 0 { + total = h.valueCount + } + for _, val := range h.buckets { + total += int64(val) + } + return +} + +// Average returns the average value of recorded observations. +func (h *histogram) average() float64 { + t := h.total() + if t == 0 { + return 0 + } + return float64(h.sum) / float64(t) +} + +// Variance returns the variance of recorded observations. +func (h *histogram) variance() float64 { + t := float64(h.total()) + if t == 0 { + return 0 + } + s := float64(h.sum) / t + return h.sumOfSquares/t - s*s +} + +// StandardDeviation returns the standard deviation of recorded observations. +func (h *histogram) standardDeviation() float64 { + return math.Sqrt(h.variance()) +} + +// PercentileBoundary estimates the value that the given fraction of recorded +// observations are less than. +func (h *histogram) percentileBoundary(percentile float64) int64 { + total := h.total() + + // Corner cases (make sure result is strictly less than Total()) + if total == 0 { + return 0 + } else if total == 1 { + return int64(h.average()) + } + + percentOfTotal := round(float64(total) * percentile) + var runningTotal int64 + + for i := range h.buckets { + value := h.buckets[i] + runningTotal += value + if runningTotal == percentOfTotal { + // We hit an exact bucket boundary. If the next bucket has data, it is a + // good estimate of the value. If the bucket is empty, we interpolate the + // midpoint between the next bucket's boundary and the next non-zero + // bucket. If the remaining buckets are all empty, then we use the + // boundary for the next bucket as the estimate. + j := uint8(i + 1) + min := bucketBoundary(j) + if runningTotal < total { + for h.buckets[j] == 0 { + j++ + } + } + max := bucketBoundary(j) + return min + round(float64(max-min)/2) + } else if runningTotal > percentOfTotal { + // The value is in this bucket. Interpolate the value. + delta := runningTotal - percentOfTotal + percentBucket := float64(value-delta) / float64(value) + bucketMin := bucketBoundary(uint8(i)) + nextBucketMin := bucketBoundary(uint8(i + 1)) + bucketSize := nextBucketMin - bucketMin + return bucketMin + round(percentBucket*float64(bucketSize)) + } + } + return bucketBoundary(bucketCount - 1) +} + +// Median returns the estimated median of the observed values. +func (h *histogram) median() int64 { + return h.percentileBoundary(0.5) +} + +// Add adds other to h. +func (h *histogram) Add(other timeseries.Observable) { + o := other.(*histogram) + if o.valueCount == 0 { + // Other histogram is empty + } else if h.valueCount >= 0 && o.valueCount > 0 && h.value == o.value { + // Both have a single bucketed value, aggregate them + h.valueCount += o.valueCount + } else { + // Two different values necessitate buckets in this histogram + h.allocateBuckets() + if o.valueCount >= 0 { + h.buckets[o.value] += o.valueCount + } else { + for i := range h.buckets { + h.buckets[i] += o.buckets[i] + } + } + } + h.sumOfSquares += o.sumOfSquares + h.sum += o.sum +} + +// Clear resets the histogram to an empty state, removing all observed values. +func (h *histogram) Clear() { + h.buckets = nil + h.value = 0 + h.valueCount = 0 + h.sum = 0 + h.sumOfSquares = 0 +} + +// CopyFrom copies from other, which must be a *histogram, into h. +func (h *histogram) CopyFrom(other timeseries.Observable) { + o := other.(*histogram) + if o.valueCount == -1 { + h.allocateBuckets() + copy(h.buckets, o.buckets) + } + h.sum = o.sum + h.sumOfSquares = o.sumOfSquares + h.value = o.value + h.valueCount = o.valueCount +} + +// Multiply scales the histogram by the specified ratio. +func (h *histogram) Multiply(ratio float64) { + if h.valueCount == -1 { + for i := range h.buckets { + h.buckets[i] = int64(float64(h.buckets[i]) * ratio) + } + } else { + h.valueCount = int64(float64(h.valueCount) * ratio) + } + h.sum = int64(float64(h.sum) * ratio) + h.sumOfSquares = h.sumOfSquares * ratio +} + +// New creates a new histogram. +func (h *histogram) New() timeseries.Observable { + r := new(histogram) + r.Clear() + return r +} + +func (h *histogram) String() string { + return fmt.Sprintf("%d, %f, %d, %d, %v", + h.sum, h.sumOfSquares, h.value, h.valueCount, h.buckets) +} + +// round returns the closest int64 to the argument +func round(in float64) int64 { + return int64(math.Floor(in + 0.5)) +} + +// bucketBoundary returns the first value in the bucket. +func bucketBoundary(bucket uint8) int64 { + if bucket == 0 { + return 0 + } + return 1 << bucket +} + +// bucketData holds data about a specific bucket for use in distTmpl. +type bucketData struct { + Lower, Upper int64 + N int64 + Pct, CumulativePct float64 + GraphWidth int +} + +// data holds data about a Distribution for use in distTmpl. +type data struct { + Buckets []*bucketData + Count, Median int64 + Mean, StandardDeviation float64 +} + +// maxHTMLBarWidth is the maximum width of the HTML bar for visualizing buckets. +const maxHTMLBarWidth = 350.0 + +// newData returns data representing h for use in distTmpl. +func (h *histogram) newData() *data { + // Force the allocation of buckets to simplify the rendering implementation + h.allocateBuckets() + // We scale the bars on the right so that the largest bar is + // maxHTMLBarWidth pixels in width. + maxBucket := int64(0) + for _, n := range h.buckets { + if n > maxBucket { + maxBucket = n + } + } + total := h.total() + barsizeMult := maxHTMLBarWidth / float64(maxBucket) + var pctMult float64 + if total == 0 { + pctMult = 1.0 + } else { + pctMult = 100.0 / float64(total) + } + + buckets := make([]*bucketData, len(h.buckets)) + runningTotal := int64(0) + for i, n := range h.buckets { + if n == 0 { + continue + } + runningTotal += n + var upperBound int64 + if i < bucketCount-1 { + upperBound = bucketBoundary(uint8(i + 1)) + } else { + upperBound = math.MaxInt64 + } + buckets[i] = &bucketData{ + Lower: bucketBoundary(uint8(i)), + Upper: upperBound, + N: n, + Pct: float64(n) * pctMult, + CumulativePct: float64(runningTotal) * pctMult, + GraphWidth: int(float64(n) * barsizeMult), + } + } + return &data{ + Buckets: buckets, + Count: total, + Median: h.median(), + Mean: h.average(), + StandardDeviation: h.standardDeviation(), + } +} + +func (h *histogram) html() template.HTML { + buf := new(bytes.Buffer) + if err := distTmpl.Execute(buf, h.newData()); err != nil { + buf.Reset() + log.Printf("net/trace: couldn't execute template: %v", err) + } + return template.HTML(buf.String()) +} + +// Input: data +var distTmpl = template.Must(template.New("distTmpl").Parse(` + + + + + + + +
Count: {{.Count}}Mean: {{printf "%.0f" .Mean}}StdDev: {{printf "%.0f" .StandardDeviation}}Median: {{.Median}}
+
+ +{{range $b := .Buckets}} +{{if $b}} + + + + + + + + + +{{end}} +{{end}} +
[{{.Lower}},{{.Upper}}){{.N}}{{printf "%#.3f" .Pct}}%{{printf "%#.3f" .CumulativePct}}%
+`)) diff --git a/Godeps/_workspace/src/golang.org/x/net/trace/histogram_test.go b/Godeps/_workspace/src/golang.org/x/net/trace/histogram_test.go new file mode 100644 index 0000000..d384b93 --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/trace/histogram_test.go @@ -0,0 +1,325 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package trace + +import ( + "math" + "testing" +) + +type sumTest struct { + value int64 + sum int64 + sumOfSquares float64 + total int64 +} + +var sumTests = []sumTest{ + {100, 100, 10000, 1}, + {50, 150, 12500, 2}, + {50, 200, 15000, 3}, + {50, 250, 17500, 4}, +} + +type bucketingTest struct { + in int64 + log int + bucket int +} + +var bucketingTests = []bucketingTest{ + {0, 0, 0}, + {1, 1, 0}, + {2, 2, 1}, + {3, 2, 1}, + {4, 3, 2}, + {1000, 10, 9}, + {1023, 10, 9}, + {1024, 11, 10}, + {1000000, 20, 19}, +} + +type multiplyTest struct { + in int64 + ratio float64 + expectedSum int64 + expectedTotal int64 + expectedSumOfSquares float64 +} + +var multiplyTests = []multiplyTest{ + {15, 2.5, 37, 2, 562.5}, + {128, 4.6, 758, 13, 77953.9}, +} + +type percentileTest struct { + fraction float64 + expected int64 +} + +var percentileTests = []percentileTest{ + {0.25, 48}, + {0.5, 96}, + {0.6, 109}, + {0.75, 128}, + {0.90, 205}, + {0.95, 230}, + {0.99, 256}, +} + +func TestSum(t *testing.T) { + var h histogram + + for _, test := range sumTests { + h.addMeasurement(test.value) + sum := h.sum + if sum != test.sum { + t.Errorf("h.Sum = %v WANT: %v", sum, test.sum) + } + + sumOfSquares := h.sumOfSquares + if sumOfSquares != test.sumOfSquares { + t.Errorf("h.SumOfSquares = %v WANT: %v", sumOfSquares, test.sumOfSquares) + } + + total := h.total() + if total != test.total { + t.Errorf("h.Total = %v WANT: %v", total, test.total) + } + } +} + +func TestMultiply(t *testing.T) { + var h histogram + for i, test := range multiplyTests { + h.addMeasurement(test.in) + h.Multiply(test.ratio) + if h.sum != test.expectedSum { + t.Errorf("#%v: h.sum = %v WANT: %v", i, h.sum, test.expectedSum) + } + if h.total() != test.expectedTotal { + t.Errorf("#%v: h.total = %v WANT: %v", i, h.total(), test.expectedTotal) + } + if h.sumOfSquares != test.expectedSumOfSquares { + t.Errorf("#%v: h.SumOfSquares = %v WANT: %v", i, test.expectedSumOfSquares, h.sumOfSquares) + } + } +} + +func TestBucketingFunctions(t *testing.T) { + for _, test := range bucketingTests { + log := log2(test.in) + if log != test.log { + t.Errorf("log2 = %v WANT: %v", log, test.log) + } + + bucket := getBucket(test.in) + if bucket != test.bucket { + t.Errorf("getBucket = %v WANT: %v", bucket, test.bucket) + } + } +} + +func TestAverage(t *testing.T) { + a := new(histogram) + average := a.average() + if average != 0 { + t.Errorf("Average of empty histogram was %v WANT: 0", average) + } + + a.addMeasurement(1) + a.addMeasurement(1) + a.addMeasurement(3) + const expected = float64(5) / float64(3) + average = a.average() + + if !isApproximate(average, expected) { + t.Errorf("Average = %g WANT: %v", average, expected) + } +} + +func TestStandardDeviation(t *testing.T) { + a := new(histogram) + add(a, 10, 1<<4) + add(a, 10, 1<<5) + add(a, 10, 1<<6) + stdDev := a.standardDeviation() + const expected = 19.95 + + if !isApproximate(stdDev, expected) { + t.Errorf("StandardDeviation = %v WANT: %v", stdDev, expected) + } + + // No values + a = new(histogram) + stdDev = a.standardDeviation() + + if !isApproximate(stdDev, 0) { + t.Errorf("StandardDeviation = %v WANT: 0", stdDev) + } + + add(a, 1, 1<<4) + if !isApproximate(stdDev, 0) { + t.Errorf("StandardDeviation = %v WANT: 0", stdDev) + } + + add(a, 10, 1<<4) + if !isApproximate(stdDev, 0) { + t.Errorf("StandardDeviation = %v WANT: 0", stdDev) + } +} + +func TestPercentileBoundary(t *testing.T) { + a := new(histogram) + add(a, 5, 1<<4) + add(a, 10, 1<<6) + add(a, 5, 1<<7) + + for _, test := range percentileTests { + percentile := a.percentileBoundary(test.fraction) + if percentile != test.expected { + t.Errorf("h.PercentileBoundary (fraction=%v) = %v WANT: %v", test.fraction, percentile, test.expected) + } + } +} + +func TestCopyFrom(t *testing.T) { + a := histogram{5, 25, []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38}, 4, -1} + b := histogram{6, 36, []int64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}, 5, -1} + + a.CopyFrom(&b) + + if a.String() != b.String() { + t.Errorf("a.String = %s WANT: %s", a.String(), b.String()) + } +} + +func TestClear(t *testing.T) { + a := histogram{5, 25, []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38}, 4, -1} + + a.Clear() + + expected := "0, 0.000000, 0, 0, []" + if a.String() != expected { + t.Errorf("a.String = %s WANT %s", a.String(), expected) + } +} + +func TestNew(t *testing.T) { + a := histogram{5, 25, []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38}, 4, -1} + b := a.New() + + expected := "0, 0.000000, 0, 0, []" + if b.(*histogram).String() != expected { + t.Errorf("b.(*histogram).String = %s WANT: %s", b.(*histogram).String(), expected) + } +} + +func TestAdd(t *testing.T) { + // The tests here depend on the associativity of addMeasurement and Add. + // Add empty observation + a := histogram{5, 25, []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38}, 4, -1} + b := a.New() + + expected := a.String() + a.Add(b) + if a.String() != expected { + t.Errorf("a.String = %s WANT: %s", a.String(), expected) + } + + // Add same bucketed value, no new buckets + c := new(histogram) + d := new(histogram) + e := new(histogram) + c.addMeasurement(12) + d.addMeasurement(11) + e.addMeasurement(12) + e.addMeasurement(11) + c.Add(d) + if c.String() != e.String() { + t.Errorf("c.String = %s WANT: %s", c.String(), e.String()) + } + + // Add bucketed values + f := new(histogram) + g := new(histogram) + h := new(histogram) + f.addMeasurement(4) + f.addMeasurement(12) + f.addMeasurement(100) + g.addMeasurement(18) + g.addMeasurement(36) + g.addMeasurement(255) + h.addMeasurement(4) + h.addMeasurement(12) + h.addMeasurement(100) + h.addMeasurement(18) + h.addMeasurement(36) + h.addMeasurement(255) + f.Add(g) + if f.String() != h.String() { + t.Errorf("f.String = %q WANT: %q", f.String(), h.String()) + } + + // add buckets to no buckets + i := new(histogram) + j := new(histogram) + k := new(histogram) + j.addMeasurement(18) + j.addMeasurement(36) + j.addMeasurement(255) + k.addMeasurement(18) + k.addMeasurement(36) + k.addMeasurement(255) + i.Add(j) + if i.String() != k.String() { + t.Errorf("i.String = %q WANT: %q", i.String(), k.String()) + } + + // add buckets to single value (no overlap) + l := new(histogram) + m := new(histogram) + n := new(histogram) + l.addMeasurement(0) + m.addMeasurement(18) + m.addMeasurement(36) + m.addMeasurement(255) + n.addMeasurement(0) + n.addMeasurement(18) + n.addMeasurement(36) + n.addMeasurement(255) + l.Add(m) + if l.String() != n.String() { + t.Errorf("l.String = %q WANT: %q", l.String(), n.String()) + } + + // mixed order + o := new(histogram) + p := new(histogram) + o.addMeasurement(0) + o.addMeasurement(2) + o.addMeasurement(0) + p.addMeasurement(0) + p.addMeasurement(0) + p.addMeasurement(2) + if o.String() != p.String() { + t.Errorf("o.String = %q WANT: %q", o.String(), p.String()) + } +} + +func add(h *histogram, times int, val int64) { + for i := 0; i < times; i++ { + h.addMeasurement(val) + } +} + +func isApproximate(x, y float64) bool { + return math.Abs(x-y) < 1e-2 +} diff --git a/Godeps/_workspace/src/golang.org/x/net/trace/trace.go b/Godeps/_workspace/src/golang.org/x/net/trace/trace.go new file mode 100644 index 0000000..c87290b --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/trace/trace.go @@ -0,0 +1,1057 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package trace implements tracing of requests and long-lived objects. +It exports HTTP interfaces on /debug/requests and /debug/events. + +A trace.Trace provides tracing for short-lived objects, usually requests. +A request handler might be implemented like this: + + func fooHandler(w http.ResponseWriter, req *http.Request) { + tr := trace.New("mypkg.Foo", req.URL.Path) + defer tr.Finish() + ... + tr.LazyPrintf("some event %q happened", str) + ... + if err := somethingImportant(); err != nil { + tr.LazyPrintf("somethingImportant failed: %v", err) + tr.SetError() + } + } + +The /debug/requests HTTP endpoint organizes the traces by family, +errors, and duration. It also provides histogram of request duration +for each family. + +A trace.EventLog provides tracing for long-lived objects, such as RPC +connections. + + // A Fetcher fetches URL paths for a single domain. + type Fetcher struct { + domain string + events trace.EventLog + } + + func NewFetcher(domain string) *Fetcher { + return &Fetcher{ + domain, + trace.NewEventLog("mypkg.Fetcher", domain), + } + } + + func (f *Fetcher) Fetch(path string) (string, error) { + resp, err := http.Get("http://" + f.domain + "/" + path) + if err != nil { + f.events.Errorf("Get(%q) = %v", path, err) + return "", err + } + f.events.Printf("Get(%q) = %s", path, resp.Status) + ... + } + + func (f *Fetcher) Close() error { + f.events.Finish() + return nil + } + +The /debug/events HTTP endpoint organizes the event logs by family and +by time since the last error. The expanded view displays recent log +entries and the log's call stack. +*/ +package trace + +import ( + "bytes" + "fmt" + "html/template" + "io" + "log" + "net" + "net/http" + "runtime" + "sort" + "strconv" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/context" + "golang.org/x/net/internal/timeseries" +) + +// DebugUseAfterFinish controls whether to debug uses of Trace values after finishing. +// FOR DEBUGGING ONLY. This will slow down the program. +var DebugUseAfterFinish = false + +// AuthRequest determines whether a specific request is permitted to load the +// /debug/requests or /debug/events pages. +// +// It returns two bools; the first indicates whether the page may be viewed at all, +// and the second indicates whether sensitive events will be shown. +// +// AuthRequest may be replaced by a program to customise its authorisation requirements. +// +// The default AuthRequest function returns (true, true) iff the request comes from localhost/127.0.0.1/[::1]. +var AuthRequest = func(req *http.Request) (any, sensitive bool) { + host, _, err := net.SplitHostPort(req.RemoteAddr) + switch { + case err != nil: // Badly formed address; fail closed. + return false, false + case host == "localhost" || host == "127.0.0.1" || host == "::1": + return true, true + default: + return false, false + } +} + +func init() { + http.HandleFunc("/debug/requests", func(w http.ResponseWriter, req *http.Request) { + any, sensitive := AuthRequest(req) + if !any { + http.Error(w, "not allowed", http.StatusUnauthorized) + return + } + Render(w, req, sensitive) + }) + http.HandleFunc("/debug/events", func(w http.ResponseWriter, req *http.Request) { + any, sensitive := AuthRequest(req) + if !any { + http.Error(w, "not allowed", http.StatusUnauthorized) + return + } + RenderEvents(w, req, sensitive) + }) +} + +// Render renders the HTML page typically served at /debug/requests. +// It does not do any auth checking; see AuthRequest for the default auth check +// used by the handler registered on http.DefaultServeMux. +// req may be nil. +func Render(w io.Writer, req *http.Request, sensitive bool) { + data := &struct { + Families []string + ActiveTraceCount map[string]int + CompletedTraces map[string]*family + + // Set when a bucket has been selected. + Traces traceList + Family string + Bucket int + Expanded bool + Traced bool + Active bool + ShowSensitive bool // whether to show sensitive events + + Histogram template.HTML + HistogramWindow string // e.g. "last minute", "last hour", "all time" + + // If non-zero, the set of traces is a partial set, + // and this is the total number. + Total int + }{ + CompletedTraces: completedTraces, + } + + data.ShowSensitive = sensitive + if req != nil { + // Allow show_sensitive=0 to force hiding of sensitive data for testing. + // This only goes one way; you can't use show_sensitive=1 to see things. + if req.FormValue("show_sensitive") == "0" { + data.ShowSensitive = false + } + + if exp, err := strconv.ParseBool(req.FormValue("exp")); err == nil { + data.Expanded = exp + } + if exp, err := strconv.ParseBool(req.FormValue("rtraced")); err == nil { + data.Traced = exp + } + } + + completedMu.RLock() + data.Families = make([]string, 0, len(completedTraces)) + for fam, _ := range completedTraces { + data.Families = append(data.Families, fam) + } + completedMu.RUnlock() + sort.Strings(data.Families) + + // We are careful here to minimize the time spent locking activeMu, + // since that lock is required every time an RPC starts and finishes. + data.ActiveTraceCount = make(map[string]int, len(data.Families)) + activeMu.RLock() + for fam, s := range activeTraces { + data.ActiveTraceCount[fam] = s.Len() + } + activeMu.RUnlock() + + var ok bool + data.Family, data.Bucket, ok = parseArgs(req) + switch { + case !ok: + // No-op + case data.Bucket == -1: + data.Active = true + n := data.ActiveTraceCount[data.Family] + data.Traces = getActiveTraces(data.Family) + if len(data.Traces) < n { + data.Total = n + } + case data.Bucket < bucketsPerFamily: + if b := lookupBucket(data.Family, data.Bucket); b != nil { + data.Traces = b.Copy(data.Traced) + } + default: + if f := getFamily(data.Family, false); f != nil { + var obs timeseries.Observable + f.LatencyMu.RLock() + switch o := data.Bucket - bucketsPerFamily; o { + case 0: + obs = f.Latency.Minute() + data.HistogramWindow = "last minute" + case 1: + obs = f.Latency.Hour() + data.HistogramWindow = "last hour" + case 2: + obs = f.Latency.Total() + data.HistogramWindow = "all time" + } + f.LatencyMu.RUnlock() + if obs != nil { + data.Histogram = obs.(*histogram).html() + } + } + } + + if data.Traces != nil { + defer data.Traces.Free() + sort.Sort(data.Traces) + } + + completedMu.RLock() + defer completedMu.RUnlock() + if err := pageTmpl.ExecuteTemplate(w, "Page", data); err != nil { + log.Printf("net/trace: Failed executing template: %v", err) + } +} + +func parseArgs(req *http.Request) (fam string, b int, ok bool) { + if req == nil { + return "", 0, false + } + fam, bStr := req.FormValue("fam"), req.FormValue("b") + if fam == "" || bStr == "" { + return "", 0, false + } + b, err := strconv.Atoi(bStr) + if err != nil || b < -1 { + return "", 0, false + } + + return fam, b, true +} + +func lookupBucket(fam string, b int) *traceBucket { + f := getFamily(fam, false) + if f == nil || b < 0 || b >= len(f.Buckets) { + return nil + } + return f.Buckets[b] +} + +type contextKeyT string + +var contextKey = contextKeyT("golang.org/x/net/trace.Trace") + +// NewContext returns a copy of the parent context +// and associates it with a Trace. +func NewContext(ctx context.Context, tr Trace) context.Context { + return context.WithValue(ctx, contextKey, tr) +} + +// FromContext returns the Trace bound to the context, if any. +func FromContext(ctx context.Context) (tr Trace, ok bool) { + tr, ok = ctx.Value(contextKey).(Trace) + return +} + +// Trace represents an active request. +type Trace interface { + // LazyLog adds x to the event log. It will be evaluated each time the + // /debug/requests page is rendered. Any memory referenced by x will be + // pinned until the trace is finished and later discarded. + LazyLog(x fmt.Stringer, sensitive bool) + + // LazyPrintf evaluates its arguments with fmt.Sprintf each time the + // /debug/requests page is rendered. Any memory referenced by a will be + // pinned until the trace is finished and later discarded. + LazyPrintf(format string, a ...interface{}) + + // SetError declares that this trace resulted in an error. + SetError() + + // SetRecycler sets a recycler for the trace. + // f will be called for each event passed to LazyLog at a time when + // it is no longer required, whether while the trace is still active + // and the event is discarded, or when a completed trace is discarded. + SetRecycler(f func(interface{})) + + // SetTraceInfo sets the trace info for the trace. + // This is currently unused. + SetTraceInfo(traceID, spanID uint64) + + // SetMaxEvents sets the maximum number of events that will be stored + // in the trace. This has no effect if any events have already been + // added to the trace. + SetMaxEvents(m int) + + // Finish declares that this trace is complete. + // The trace should not be used after calling this method. + Finish() +} + +type lazySprintf struct { + format string + a []interface{} +} + +func (l *lazySprintf) String() string { + return fmt.Sprintf(l.format, l.a...) +} + +// New returns a new Trace with the specified family and title. +func New(family, title string) Trace { + tr := newTrace() + tr.ref() + tr.Family, tr.Title = family, title + tr.Start = time.Now() + tr.events = make([]event, 0, maxEventsPerTrace) + + activeMu.RLock() + s := activeTraces[tr.Family] + activeMu.RUnlock() + if s == nil { + activeMu.Lock() + s = activeTraces[tr.Family] // check again + if s == nil { + s = new(traceSet) + activeTraces[tr.Family] = s + } + activeMu.Unlock() + } + s.Add(tr) + + // Trigger allocation of the completed trace structure for this family. + // This will cause the family to be present in the request page during + // the first trace of this family. We don't care about the return value, + // nor is there any need for this to run inline, so we execute it in its + // own goroutine, but only if the family isn't allocated yet. + completedMu.RLock() + if _, ok := completedTraces[tr.Family]; !ok { + go allocFamily(tr.Family) + } + completedMu.RUnlock() + + return tr +} + +func (tr *trace) Finish() { + tr.Elapsed = time.Now().Sub(tr.Start) + if DebugUseAfterFinish { + buf := make([]byte, 4<<10) // 4 KB should be enough + n := runtime.Stack(buf, false) + tr.finishStack = buf[:n] + } + + activeMu.RLock() + m := activeTraces[tr.Family] + activeMu.RUnlock() + m.Remove(tr) + + f := getFamily(tr.Family, true) + for _, b := range f.Buckets { + if b.Cond.match(tr) { + b.Add(tr) + } + } + // Add a sample of elapsed time as microseconds to the family's timeseries + h := new(histogram) + h.addMeasurement(tr.Elapsed.Nanoseconds() / 1e3) + f.LatencyMu.Lock() + f.Latency.Add(h) + f.LatencyMu.Unlock() + + tr.unref() // matches ref in New +} + +const ( + bucketsPerFamily = 9 + tracesPerBucket = 10 + maxActiveTraces = 20 // Maximum number of active traces to show. + maxEventsPerTrace = 10 + numHistogramBuckets = 38 +) + +var ( + // The active traces. + activeMu sync.RWMutex + activeTraces = make(map[string]*traceSet) // family -> traces + + // Families of completed traces. + completedMu sync.RWMutex + completedTraces = make(map[string]*family) // family -> traces +) + +type traceSet struct { + mu sync.RWMutex + m map[*trace]bool + + // We could avoid the entire map scan in FirstN by having a slice of all the traces + // ordered by start time, and an index into that from the trace struct, with a periodic + // repack of the slice after enough traces finish; we could also use a skip list or similar. + // However, that would shift some of the expense from /debug/requests time to RPC time, + // which is probably the wrong trade-off. +} + +func (ts *traceSet) Len() int { + ts.mu.RLock() + defer ts.mu.RUnlock() + return len(ts.m) +} + +func (ts *traceSet) Add(tr *trace) { + ts.mu.Lock() + if ts.m == nil { + ts.m = make(map[*trace]bool) + } + ts.m[tr] = true + ts.mu.Unlock() +} + +func (ts *traceSet) Remove(tr *trace) { + ts.mu.Lock() + delete(ts.m, tr) + ts.mu.Unlock() +} + +// FirstN returns the first n traces ordered by time. +func (ts *traceSet) FirstN(n int) traceList { + ts.mu.RLock() + defer ts.mu.RUnlock() + + if n > len(ts.m) { + n = len(ts.m) + } + trl := make(traceList, 0, n) + + // Fast path for when no selectivity is needed. + if n == len(ts.m) { + for tr := range ts.m { + tr.ref() + trl = append(trl, tr) + } + sort.Sort(trl) + return trl + } + + // Pick the oldest n traces. + // This is inefficient. See the comment in the traceSet struct. + for tr := range ts.m { + // Put the first n traces into trl in the order they occur. + // When we have n, sort trl, and thereafter maintain its order. + if len(trl) < n { + tr.ref() + trl = append(trl, tr) + if len(trl) == n { + // This is guaranteed to happen exactly once during this loop. + sort.Sort(trl) + } + continue + } + if tr.Start.After(trl[n-1].Start) { + continue + } + + // Find where to insert this one. + tr.ref() + i := sort.Search(n, func(i int) bool { return trl[i].Start.After(tr.Start) }) + trl[n-1].unref() + copy(trl[i+1:], trl[i:]) + trl[i] = tr + } + + return trl +} + +func getActiveTraces(fam string) traceList { + activeMu.RLock() + s := activeTraces[fam] + activeMu.RUnlock() + if s == nil { + return nil + } + return s.FirstN(maxActiveTraces) +} + +func getFamily(fam string, allocNew bool) *family { + completedMu.RLock() + f := completedTraces[fam] + completedMu.RUnlock() + if f == nil && allocNew { + f = allocFamily(fam) + } + return f +} + +func allocFamily(fam string) *family { + completedMu.Lock() + defer completedMu.Unlock() + f := completedTraces[fam] + if f == nil { + f = newFamily() + completedTraces[fam] = f + } + return f +} + +// family represents a set of trace buckets and associated latency information. +type family struct { + // traces may occur in multiple buckets. + Buckets [bucketsPerFamily]*traceBucket + + // latency time series + LatencyMu sync.RWMutex + Latency *timeseries.MinuteHourSeries +} + +func newFamily() *family { + return &family{ + Buckets: [bucketsPerFamily]*traceBucket{ + {Cond: minCond(0)}, + {Cond: minCond(50 * time.Millisecond)}, + {Cond: minCond(100 * time.Millisecond)}, + {Cond: minCond(200 * time.Millisecond)}, + {Cond: minCond(500 * time.Millisecond)}, + {Cond: minCond(1 * time.Second)}, + {Cond: minCond(10 * time.Second)}, + {Cond: minCond(100 * time.Second)}, + {Cond: errorCond{}}, + }, + Latency: timeseries.NewMinuteHourSeries(func() timeseries.Observable { return new(histogram) }), + } +} + +// traceBucket represents a size-capped bucket of historic traces, +// along with a condition for a trace to belong to the bucket. +type traceBucket struct { + Cond cond + + // Ring buffer implementation of a fixed-size FIFO queue. + mu sync.RWMutex + buf [tracesPerBucket]*trace + start int // < tracesPerBucket + length int // <= tracesPerBucket +} + +func (b *traceBucket) Add(tr *trace) { + b.mu.Lock() + defer b.mu.Unlock() + + i := b.start + b.length + if i >= tracesPerBucket { + i -= tracesPerBucket + } + if b.length == tracesPerBucket { + // "Remove" an element from the bucket. + b.buf[i].unref() + b.start++ + if b.start == tracesPerBucket { + b.start = 0 + } + } + b.buf[i] = tr + if b.length < tracesPerBucket { + b.length++ + } + tr.ref() +} + +// Copy returns a copy of the traces in the bucket. +// If tracedOnly is true, only the traces with trace information will be returned. +// The logs will be ref'd before returning; the caller should call +// the Free method when it is done with them. +// TODO(dsymonds): keep track of traced requests in separate buckets. +func (b *traceBucket) Copy(tracedOnly bool) traceList { + b.mu.RLock() + defer b.mu.RUnlock() + + trl := make(traceList, 0, b.length) + for i, x := 0, b.start; i < b.length; i++ { + tr := b.buf[x] + if !tracedOnly || tr.spanID != 0 { + tr.ref() + trl = append(trl, tr) + } + x++ + if x == b.length { + x = 0 + } + } + return trl +} + +func (b *traceBucket) Empty() bool { + b.mu.RLock() + defer b.mu.RUnlock() + return b.length == 0 +} + +// cond represents a condition on a trace. +type cond interface { + match(t *trace) bool + String() string +} + +type minCond time.Duration + +func (m minCond) match(t *trace) bool { return t.Elapsed >= time.Duration(m) } +func (m minCond) String() string { return fmt.Sprintf("≥%gs", time.Duration(m).Seconds()) } + +type errorCond struct{} + +func (e errorCond) match(t *trace) bool { return t.IsError } +func (e errorCond) String() string { return "errors" } + +type traceList []*trace + +// Free calls unref on each element of the list. +func (trl traceList) Free() { + for _, t := range trl { + t.unref() + } +} + +// traceList may be sorted in reverse chronological order. +func (trl traceList) Len() int { return len(trl) } +func (trl traceList) Less(i, j int) bool { return trl[i].Start.After(trl[j].Start) } +func (trl traceList) Swap(i, j int) { trl[i], trl[j] = trl[j], trl[i] } + +// An event is a timestamped log entry in a trace. +type event struct { + When time.Time + Elapsed time.Duration // since previous event in trace + NewDay bool // whether this event is on a different day to the previous event + Recyclable bool // whether this event was passed via LazyLog + What interface{} // string or fmt.Stringer + Sensitive bool // whether this event contains sensitive information +} + +// WhenString returns a string representation of the elapsed time of the event. +// It will include the date if midnight was crossed. +func (e event) WhenString() string { + if e.NewDay { + return e.When.Format("2006/01/02 15:04:05.000000") + } + return e.When.Format("15:04:05.000000") +} + +// discarded represents a number of discarded events. +// It is stored as *discarded to make it easier to update in-place. +type discarded int + +func (d *discarded) String() string { + return fmt.Sprintf("(%d events discarded)", int(*d)) +} + +// trace represents an active or complete request, +// either sent or received by this program. +type trace struct { + // Family is the top-level grouping of traces to which this belongs. + Family string + + // Title is the title of this trace. + Title string + + // Timing information. + Start time.Time + Elapsed time.Duration // zero while active + + // Trace information if non-zero. + traceID uint64 + spanID uint64 + + // Whether this trace resulted in an error. + IsError bool + + // Append-only sequence of events (modulo discards). + mu sync.RWMutex + events []event + + refs int32 // how many buckets this is in + recycler func(interface{}) + disc discarded // scratch space to avoid allocation + + finishStack []byte // where finish was called, if DebugUseAfterFinish is set +} + +func (tr *trace) reset() { + // Clear all but the mutex. Mutexes may not be copied, even when unlocked. + tr.Family = "" + tr.Title = "" + tr.Start = time.Time{} + tr.Elapsed = 0 + tr.traceID = 0 + tr.spanID = 0 + tr.IsError = false + tr.events = nil + tr.refs = 0 + tr.recycler = nil + tr.disc = 0 + tr.finishStack = nil +} + +// delta returns the elapsed time since the last event or the trace start, +// and whether it spans midnight. +// L >= tr.mu +func (tr *trace) delta(t time.Time) (time.Duration, bool) { + if len(tr.events) == 0 { + return t.Sub(tr.Start), false + } + prev := tr.events[len(tr.events)-1].When + return t.Sub(prev), prev.Day() != t.Day() +} + +func (tr *trace) addEvent(x interface{}, recyclable, sensitive bool) { + if DebugUseAfterFinish && tr.finishStack != nil { + buf := make([]byte, 4<<10) // 4 KB should be enough + n := runtime.Stack(buf, false) + log.Printf("net/trace: trace used after finish:\nFinished at:\n%s\nUsed at:\n%s", tr.finishStack, buf[:n]) + } + + /* + NOTE TO DEBUGGERS + + If you are here because your program panicked in this code, + it is almost definitely the fault of code using this package, + and very unlikely to be the fault of this code. + + The most likely scenario is that some code elsewhere is using + a requestz.Trace after its Finish method is called. + You can temporarily set the DebugUseAfterFinish var + to help discover where that is; do not leave that var set, + since it makes this package much less efficient. + */ + + e := event{When: time.Now(), What: x, Recyclable: recyclable, Sensitive: sensitive} + tr.mu.Lock() + e.Elapsed, e.NewDay = tr.delta(e.When) + if len(tr.events) < cap(tr.events) { + tr.events = append(tr.events, e) + } else { + // Discard the middle events. + di := int((cap(tr.events) - 1) / 2) + if d, ok := tr.events[di].What.(*discarded); ok { + (*d)++ + } else { + // disc starts at two to count for the event it is replacing, + // plus the next one that we are about to drop. + tr.disc = 2 + if tr.recycler != nil && tr.events[di].Recyclable { + go tr.recycler(tr.events[di].What) + } + tr.events[di].What = &tr.disc + } + // The timestamp of the discarded meta-event should be + // the time of the last event it is representing. + tr.events[di].When = tr.events[di+1].When + + if tr.recycler != nil && tr.events[di+1].Recyclable { + go tr.recycler(tr.events[di+1].What) + } + copy(tr.events[di+1:], tr.events[di+2:]) + tr.events[cap(tr.events)-1] = e + } + tr.mu.Unlock() +} + +func (tr *trace) LazyLog(x fmt.Stringer, sensitive bool) { + tr.addEvent(x, true, sensitive) +} + +func (tr *trace) LazyPrintf(format string, a ...interface{}) { + tr.addEvent(&lazySprintf{format, a}, false, false) +} + +func (tr *trace) SetError() { tr.IsError = true } + +func (tr *trace) SetRecycler(f func(interface{})) { + tr.recycler = f +} + +func (tr *trace) SetTraceInfo(traceID, spanID uint64) { + tr.traceID, tr.spanID = traceID, spanID +} + +func (tr *trace) SetMaxEvents(m int) { + // Always keep at least three events: first, discarded count, last. + if len(tr.events) == 0 && m > 3 { + tr.events = make([]event, 0, m) + } +} + +func (tr *trace) ref() { + atomic.AddInt32(&tr.refs, 1) +} + +func (tr *trace) unref() { + if atomic.AddInt32(&tr.refs, -1) == 0 { + if tr.recycler != nil { + // freeTrace clears tr, so we hold tr.recycler and tr.events here. + go func(f func(interface{}), es []event) { + for _, e := range es { + if e.Recyclable { + f(e.What) + } + } + }(tr.recycler, tr.events) + } + + freeTrace(tr) + } +} + +func (tr *trace) When() string { + return tr.Start.Format("2006/01/02 15:04:05.000000") +} + +func (tr *trace) ElapsedTime() string { + t := tr.Elapsed + if t == 0 { + // Active trace. + t = time.Since(tr.Start) + } + return fmt.Sprintf("%.6f", t.Seconds()) +} + +func (tr *trace) Events() []event { + tr.mu.RLock() + defer tr.mu.RUnlock() + return tr.events +} + +var traceFreeList = make(chan *trace, 1000) // TODO(dsymonds): Use sync.Pool? + +// newTrace returns a trace ready to use. +func newTrace() *trace { + select { + case tr := <-traceFreeList: + return tr + default: + return new(trace) + } +} + +// freeTrace adds tr to traceFreeList if there's room. +// This is non-blocking. +func freeTrace(tr *trace) { + if DebugUseAfterFinish { + return // never reuse + } + tr.reset() + select { + case traceFreeList <- tr: + default: + } +} + +func elapsed(d time.Duration) string { + b := []byte(fmt.Sprintf("%.6f", d.Seconds())) + + // For subsecond durations, blank all zeros before decimal point, + // and all zeros between the decimal point and the first non-zero digit. + if d < time.Second { + dot := bytes.IndexByte(b, '.') + for i := 0; i < dot; i++ { + b[i] = ' ' + } + for i := dot + 1; i < len(b); i++ { + if b[i] == '0' { + b[i] = ' ' + } else { + break + } + } + } + + return string(b) +} + +var pageTmpl = template.Must(template.New("Page").Funcs(template.FuncMap{ + "elapsed": elapsed, + "add": func(a, b int) int { return a + b }, +}).Parse(pageHTML)) + +const pageHTML = ` +{{template "Prolog" .}} +{{template "StatusTable" .}} +{{template "Epilog" .}} + +{{define "Prolog"}} + + + /debug/requests + + + + +

/debug/requests

+{{end}} {{/* end of Prolog */}} + +{{define "StatusTable"}} + + {{range $fam := .Families}} + + + + {{$n := index $.ActiveTraceCount $fam}} + + + {{$f := index $.CompletedTraces $fam}} + {{range $i, $b := $f.Buckets}} + {{$empty := $b.Empty}} + + {{end}} + + {{$nb := len $f.Buckets}} + + + + + + {{end}} +
{{$fam}} + {{if $n}}{{end}} + [{{$n}} active] + {{if $n}}{{end}} + + {{if not $empty}}{{end}} + [{{.Cond}}] + {{if not $empty}}{{end}} + + [minute] + + [hour] + + [total] +
+{{end}} {{/* end of StatusTable */}} + +{{define "Epilog"}} +{{if $.Traces}} +
+

Family: {{$.Family}}

+ +{{if or $.Expanded $.Traced}} + [Normal/Summary] +{{else}} + [Normal/Summary] +{{end}} + +{{if or (not $.Expanded) $.Traced}} + [Normal/Expanded] +{{else}} + [Normal/Expanded] +{{end}} + +{{if not $.Active}} + {{if or $.Expanded (not $.Traced)}} + [Traced/Summary] + {{else}} + [Traced/Summary] + {{end}} + {{if or (not $.Expanded) (not $.Traced)}} + [Traced/Expanded] + {{else}} + [Traced/Expanded] + {{end}} +{{end}} + +{{if $.Total}} +

Showing {{len $.Traces}} of {{$.Total}} traces.

+{{end}} + + + + + {{range $tr := $.Traces}} + + + + + {{/* TODO: include traceID/spanID */}} + + {{if $.Expanded}} + {{range $tr.Events}} + + + + + + {{end}} + {{end}} + {{end}} +
+ {{if $.Active}}Active{{else}}Completed{{end}} Requests +
WhenElapsed (s)
{{$tr.When}}{{$tr.ElapsedTime}}{{$tr.Title}}
{{.WhenString}}{{elapsed .Elapsed}}{{if or $.ShowSensitive (not .Sensitive)}}... {{.What}}{{else}}[redacted]{{end}}
+{{end}} {{/* if $.Traces */}} + +{{if $.Histogram}} +

Latency (µs) of {{$.Family}} over {{$.HistogramWindow}}

+{{$.Histogram}} +{{end}} {{/* if $.Histogram */}} + + + +{{end}} {{/* end of Epilog */}} +` diff --git a/Godeps/_workspace/src/golang.org/x/net/trace/trace_test.go b/Godeps/_workspace/src/golang.org/x/net/trace/trace_test.go new file mode 100644 index 0000000..c2f5fcb --- /dev/null +++ b/Godeps/_workspace/src/golang.org/x/net/trace/trace_test.go @@ -0,0 +1,46 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package trace + +import ( + "reflect" + "testing" +) + +type s struct{} + +func (s) String() string { return "lazy string" } + +// TestReset checks whether all the fields are zeroed after reset. +func TestReset(t *testing.T) { + tr := New("foo", "bar") + tr.LazyLog(s{}, false) + tr.LazyPrintf("%d", 1) + tr.SetRecycler(func(_ interface{}) {}) + tr.SetTraceInfo(3, 4) + tr.SetMaxEvents(100) + tr.SetError() + tr.Finish() + + tr.(*trace).reset() + + if !reflect.DeepEqual(tr, new(trace)) { + t.Errorf("reset didn't clear all fields: %+v", tr) + } +} + +// TestResetLog checks whether all the fields are zeroed after reset. +func TestResetLog(t *testing.T) { + el := NewEventLog("foo", "bar") + el.Printf("message") + el.Errorf("error") + el.Finish() + + el.(*eventLog).reset() + + if !reflect.DeepEqual(el, new(eventLog)) { + t.Errorf("reset didn't clear all fields: %+v", el) + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/.travis.yml b/Godeps/_workspace/src/google.golang.org/grpc/.travis.yml new file mode 100644 index 0000000..2503560 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/.travis.yml @@ -0,0 +1,4 @@ +language: go + +script: + - make test testrace diff --git a/Godeps/_workspace/src/google.golang.org/grpc/CONTRIBUTING.md b/Godeps/_workspace/src/google.golang.org/grpc/CONTRIBUTING.md new file mode 100644 index 0000000..407d384 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# How to contribute + +We definitely welcome patches and contribution to grpc! Here is some guideline +and information about how to do so. + +## Getting started + +### Legal requirements + +In order to protect both you and ourselves, you will need to sign the +[Contributor License Agreement](https://cla.developers.google.com/clas). + +### Filing Issues +When filing an issue, make sure to answer these five questions: + +1. What version of Go are you using (`go version`)? +2. What operating system and processor architecture are you using? +3. What did you do? +4. What did you expect to see? +5. What did you see instead? + +### Contributing code +Unless otherwise noted, the Go source files are distributed under the BSD-style license found in the LICENSE file. diff --git a/Godeps/_workspace/src/google.golang.org/grpc/LICENSE b/Godeps/_workspace/src/google.golang.org/grpc/LICENSE new file mode 100644 index 0000000..f4988b4 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/LICENSE @@ -0,0 +1,28 @@ +Copyright 2014, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/google.golang.org/grpc/Makefile b/Godeps/_workspace/src/google.golang.org/grpc/Makefile new file mode 100644 index 0000000..0dc225f --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/Makefile @@ -0,0 +1,47 @@ +.PHONY: \ + all \ + deps \ + updatedeps \ + testdeps \ + updatetestdeps \ + build \ + proto \ + test \ + testrace \ + clean \ + +all: test testrace + +deps: + go get -d -v google.golang.org/grpc/... + +updatedeps: + go get -d -v -u -f google.golang.org/grpc/... + +testdeps: + go get -d -v -t google.golang.org/grpc/... + +updatetestdeps: + go get -d -v -t -u -f google.golang.org/grpc/... + +build: deps + go build google.golang.org/grpc/... + +proto: + @ if ! which protoc > /dev/null; then \ + echo "error: protoc not installed" >&2; \ + exit 1; \ + fi + go get -v github.com/golang/protobuf/protoc-gen-go + for file in $$(git ls-files '*.proto'); do \ + protoc -I $$(dirname $$file) --go_out=plugins=grpc:$$(dirname $$file) $$file; \ + done + +test: testdeps + go test -v -cpu 1,4 google.golang.org/grpc/... + +testrace: testdeps + go test -v -race -cpu 1,4 google.golang.org/grpc/... + +clean: + go clean google.golang.org/grpc/... diff --git a/Godeps/_workspace/src/google.golang.org/grpc/PATENTS b/Godeps/_workspace/src/google.golang.org/grpc/PATENTS new file mode 100644 index 0000000..619f9db --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/PATENTS @@ -0,0 +1,22 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Google as part of the GRPC project. + +Google hereby grants to You a perpetual, worldwide, non-exclusive, +no-charge, royalty-free, irrevocable (except as stated in this section) +patent license to make, have made, use, offer to sell, sell, import, +transfer and otherwise run, modify and propagate the contents of this +implementation of GRPC, where such license applies only to those patent +claims, both currently owned or controlled by Google and acquired in +the future, licensable by Google that are necessarily infringed by this +implementation of GRPC. This grant does not include claims that would be +infringed only as a consequence of further modification of this +implementation. If you or your agent or exclusive licensee institute or +order or agree to the institution of patent litigation against any +entity (including a cross-claim or counterclaim in a lawsuit) alleging +that this implementation of GRPC or any code incorporated within this +implementation of GRPC constitutes direct or contributory patent +infringement, or inducement of patent infringement, then any patent +rights granted to you under this License for this implementation of GRPC +shall terminate as of the date such litigation is filed. diff --git a/Godeps/_workspace/src/google.golang.org/grpc/README.md b/Godeps/_workspace/src/google.golang.org/grpc/README.md new file mode 100644 index 0000000..f16d406 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/README.md @@ -0,0 +1,28 @@ +#gRPC-Go + +[![Build Status](https://travis-ci.org/grpc/grpc-go.svg)](https://travis-ci.org/grpc/grpc-go) [![GoDoc](https://godoc.org/google.golang.org/grpc?status.svg)](https://godoc.org/google.golang.org/grpc) + +The Go implementation of [gRPC](https://github.com/grpc/grpc) + +Installation +------------ + +To install this package, you need to install Go 1.4 and setup your Go workspace on your computer. The simplest way to install the library is to run: + +``` +$ go get google.golang.org/grpc +``` + +Prerequisites +------------- + +This requires Go 1.4 or above. + +Documentation +------------- +You can find more detailed documentation and examples in the [examples directory](examples/). + +Status +------ +Alpha - ready for early adopters. + diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/benchmark.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/benchmark.go new file mode 100644 index 0000000..7215d35 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/benchmark.go @@ -0,0 +1,147 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +/* +Package benchmark implements the building blocks to setup end-to-end gRPC benchmarks. +*/ +package benchmark + +import ( + "io" + "math" + "net" + + "golang.org/x/net/context" + "google.golang.org/grpc" + testpb "google.golang.org/grpc/benchmark/grpc_testing" + "google.golang.org/grpc/grpclog" +) + +func newPayload(t testpb.PayloadType, size int) *testpb.Payload { + if size < 0 { + grpclog.Fatalf("Requested a response with invalid length %d", size) + } + body := make([]byte, size) + switch t { + case testpb.PayloadType_COMPRESSABLE: + case testpb.PayloadType_UNCOMPRESSABLE: + grpclog.Fatalf("PayloadType UNCOMPRESSABLE is not supported") + default: + grpclog.Fatalf("Unsupported payload type: %d", t) + } + return &testpb.Payload{ + Type: t, + Body: body, + } +} + +type testServer struct { +} + +func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return &testpb.SimpleResponse{ + Payload: newPayload(in.ResponseType, int(in.ResponseSize)), + }, nil +} + +func (s *testServer) StreamingCall(stream testpb.TestService_StreamingCallServer) error { + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + return nil + } + if err != nil { + return err + } + if err := stream.Send(&testpb.SimpleResponse{ + Payload: newPayload(in.ResponseType, int(in.ResponseSize)), + }); err != nil { + return err + } + } +} + +// StartServer starts a gRPC server serving a benchmark service on the given +// address, which may be something like "localhost:0". It returns its listen +// address and a function to stop the server. +func StartServer(addr string) (string, func()) { + lis, err := net.Listen("tcp", addr) + if err != nil { + grpclog.Fatalf("Failed to listen: %v", err) + } + s := grpc.NewServer(grpc.MaxConcurrentStreams(math.MaxUint32)) + testpb.RegisterTestServiceServer(s, &testServer{}) + go s.Serve(lis) + return lis.Addr().String(), func() { + s.Stop() + } +} + +// DoUnaryCall performs an unary RPC with given stub and request and response sizes. +func DoUnaryCall(tc testpb.TestServiceClient, reqSize, respSize int) { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, reqSize) + req := &testpb.SimpleRequest{ + ResponseType: pl.Type, + ResponseSize: int32(respSize), + Payload: pl, + } + if _, err := tc.UnaryCall(context.Background(), req); err != nil { + grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) + } +} + +// DoStreamingRoundTrip performs a round trip for a single streaming rpc. +func DoStreamingRoundTrip(tc testpb.TestServiceClient, stream testpb.TestService_StreamingCallClient, reqSize, respSize int) { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, reqSize) + req := &testpb.SimpleRequest{ + ResponseType: pl.Type, + ResponseSize: int32(respSize), + Payload: pl, + } + if err := stream.Send(req); err != nil { + grpclog.Fatalf("StreamingCall(_).Send: %v", err) + } + if _, err := stream.Recv(); err != nil { + grpclog.Fatalf("StreamingCall(_).Recv: %v", err) + } +} + +// NewClientConn creates a gRPC client connection to addr. +func NewClientConn(addr string) *grpc.ClientConn { + conn, err := grpc.Dial(addr, grpc.WithInsecure()) + if err != nil { + grpclog.Fatalf("NewClientConn(%q) failed to create a ClientConn %v", addr, err) + } + return conn +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/benchmark_test.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/benchmark_test.go new file mode 100644 index 0000000..c3b4281 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/benchmark_test.go @@ -0,0 +1,198 @@ +package benchmark + +import ( + "os" + "sync" + "testing" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc" + testpb "google.golang.org/grpc/benchmark/grpc_testing" + "google.golang.org/grpc/benchmark/stats" + "google.golang.org/grpc/grpclog" +) + +func runUnary(b *testing.B, maxConcurrentCalls int) { + s := stats.AddStats(b, 38) + b.StopTimer() + target, stopper := StartServer("localhost:0") + defer stopper() + conn := NewClientConn(target) + tc := testpb.NewTestServiceClient(conn) + + // Warm up connection. + for i := 0; i < 10; i++ { + unaryCaller(tc) + } + ch := make(chan int, maxConcurrentCalls*4) + var ( + mu sync.Mutex + wg sync.WaitGroup + ) + wg.Add(maxConcurrentCalls) + + // Distribute the b.N calls over maxConcurrentCalls workers. + for i := 0; i < maxConcurrentCalls; i++ { + go func() { + for _ = range ch { + start := time.Now() + unaryCaller(tc) + elapse := time.Since(start) + mu.Lock() + s.Add(elapse) + mu.Unlock() + } + wg.Done() + }() + } + b.StartTimer() + for i := 0; i < b.N; i++ { + ch <- i + } + b.StopTimer() + close(ch) + wg.Wait() + conn.Close() +} + +func runStream(b *testing.B, maxConcurrentCalls int) { + s := stats.AddStats(b, 38) + b.StopTimer() + target, stopper := StartServer("localhost:0") + defer stopper() + conn := NewClientConn(target) + tc := testpb.NewTestServiceClient(conn) + + // Warm up connection. + stream, err := tc.StreamingCall(context.Background()) + if err != nil { + grpclog.Fatalf("%v.StreamingCall(_) = _, %v", tc, err) + } + for i := 0; i < 10; i++ { + streamCaller(tc, stream) + } + + ch := make(chan int, maxConcurrentCalls*4) + var ( + mu sync.Mutex + wg sync.WaitGroup + ) + wg.Add(maxConcurrentCalls) + + // Distribute the b.N calls over maxConcurrentCalls workers. + for i := 0; i < maxConcurrentCalls; i++ { + go func() { + stream, err := tc.StreamingCall(context.Background()) + if err != nil { + grpclog.Fatalf("%v.StreamingCall(_) = _, %v", tc, err) + } + for _ = range ch { + start := time.Now() + streamCaller(tc, stream) + elapse := time.Since(start) + mu.Lock() + s.Add(elapse) + mu.Unlock() + } + wg.Done() + }() + } + b.StartTimer() + for i := 0; i < b.N; i++ { + ch <- i + } + b.StopTimer() + close(ch) + wg.Wait() + conn.Close() +} +func unaryCaller(client testpb.TestServiceClient) { + DoUnaryCall(client, 1, 1) +} + +func streamCaller(client testpb.TestServiceClient, stream testpb.TestService_StreamingCallClient) { + DoStreamingRoundTrip(client, stream, 1, 1) +} + +func BenchmarkClientStreamc1(b *testing.B) { + grpc.EnableTracing = true + runStream(b, 1) +} + +func BenchmarkClientStreamc8(b *testing.B) { + grpc.EnableTracing = true + runStream(b, 8) +} + +func BenchmarkClientStreamc64(b *testing.B) { + grpc.EnableTracing = true + runStream(b, 64) +} + +func BenchmarkClientStreamc512(b *testing.B) { + grpc.EnableTracing = true + runStream(b, 512) +} +func BenchmarkClientUnaryc1(b *testing.B) { + grpc.EnableTracing = true + runUnary(b, 1) +} + +func BenchmarkClientUnaryc8(b *testing.B) { + grpc.EnableTracing = true + runUnary(b, 8) +} + +func BenchmarkClientUnaryc64(b *testing.B) { + grpc.EnableTracing = true + runUnary(b, 64) +} + +func BenchmarkClientUnaryc512(b *testing.B) { + grpc.EnableTracing = true + runUnary(b, 512) +} + +func BenchmarkClientStreamNoTracec1(b *testing.B) { + grpc.EnableTracing = false + runStream(b, 1) +} + +func BenchmarkClientStreamNoTracec8(b *testing.B) { + grpc.EnableTracing = false + runStream(b, 8) +} + +func BenchmarkClientStreamNoTracec64(b *testing.B) { + grpc.EnableTracing = false + runStream(b, 64) +} + +func BenchmarkClientStreamNoTracec512(b *testing.B) { + grpc.EnableTracing = false + runStream(b, 512) +} +func BenchmarkClientUnaryNoTracec1(b *testing.B) { + grpc.EnableTracing = false + runUnary(b, 1) +} + +func BenchmarkClientUnaryNoTracec8(b *testing.B) { + grpc.EnableTracing = false + runUnary(b, 8) +} + +func BenchmarkClientUnaryNoTracec64(b *testing.B) { + grpc.EnableTracing = false + runUnary(b, 64) +} + +func BenchmarkClientUnaryNoTracec512(b *testing.B) { + grpc.EnableTracing = false + runUnary(b, 512) +} + +func TestMain(m *testing.M) { + os.Exit(stats.RunTestMain(m)) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/client/main.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/client/main.go new file mode 100644 index 0000000..f9e2a83 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/client/main.go @@ -0,0 +1,161 @@ +package main + +import ( + "flag" + "math" + "net" + "net/http" + _ "net/http/pprof" + "sync" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/benchmark" + testpb "google.golang.org/grpc/benchmark/grpc_testing" + "google.golang.org/grpc/benchmark/stats" + "google.golang.org/grpc/grpclog" +) + +var ( + server = flag.String("server", "", "The server address") + maxConcurrentRPCs = flag.Int("max_concurrent_rpcs", 1, "The max number of concurrent RPCs") + duration = flag.Int("duration", math.MaxInt32, "The duration in seconds to run the benchmark client") + trace = flag.Bool("trace", true, "Whether tracing is on") + rpcType = flag.Int("rpc_type", 0, + `Configure different client rpc type. Valid options are: + 0 : unary call; + 1 : streaming call.`) +) + +func unaryCaller(client testpb.TestServiceClient) { + benchmark.DoUnaryCall(client, 1, 1) +} + +func streamCaller(client testpb.TestServiceClient, stream testpb.TestService_StreamingCallClient) { + benchmark.DoStreamingRoundTrip(client, stream, 1, 1) +} + +func buildConnection() (s *stats.Stats, conn *grpc.ClientConn, tc testpb.TestServiceClient) { + s = stats.NewStats(256) + conn = benchmark.NewClientConn(*server) + tc = testpb.NewTestServiceClient(conn) + return s, conn, tc +} + +func closeLoopUnary() { + s, conn, tc := buildConnection() + + for i := 0; i < 100; i++ { + unaryCaller(tc) + } + ch := make(chan int, *maxConcurrentRPCs*4) + var ( + mu sync.Mutex + wg sync.WaitGroup + ) + wg.Add(*maxConcurrentRPCs) + + for i := 0; i < *maxConcurrentRPCs; i++ { + go func() { + for _ = range ch { + start := time.Now() + unaryCaller(tc) + elapse := time.Since(start) + mu.Lock() + s.Add(elapse) + mu.Unlock() + } + wg.Done() + }() + } + // Stop the client when time is up. + done := make(chan struct{}) + go func() { + <-time.After(time.Duration(*duration) * time.Second) + close(done) + }() + ok := true + for ok { + select { + case ch <- 0: + case <-done: + ok = false + } + } + close(ch) + wg.Wait() + conn.Close() + grpclog.Println(s.String()) + +} + +func closeLoopStream() { + s, conn, tc := buildConnection() + stream, err := tc.StreamingCall(context.Background()) + if err != nil { + grpclog.Fatalf("%v.StreamingCall(_) = _, %v", tc, err) + } + for i := 0; i < 100; i++ { + streamCaller(tc, stream) + } + ch := make(chan int, *maxConcurrentRPCs*4) + var ( + mu sync.Mutex + wg sync.WaitGroup + ) + wg.Add(*maxConcurrentRPCs) + // Distribute RPCs over maxConcurrentCalls workers. + for i := 0; i < *maxConcurrentRPCs; i++ { + go func() { + for _ = range ch { + start := time.Now() + streamCaller(tc, stream) + elapse := time.Since(start) + mu.Lock() + s.Add(elapse) + mu.Unlock() + } + wg.Done() + }() + } + // Stop the client when time is up. + done := make(chan struct{}) + go func() { + <-time.After(time.Duration(*duration) * time.Second) + close(done) + }() + ok := true + for ok { + select { + case ch <- 0: + case <-done: + ok = false + } + } + close(ch) + wg.Wait() + conn.Close() + grpclog.Println(s.String()) +} + +func main() { + flag.Parse() + grpc.EnableTracing = *trace + go func() { + lis, err := net.Listen("tcp", ":0") + if err != nil { + grpclog.Fatalf("Failed to listen: %v", err) + } + grpclog.Println("Client profiling address: ", lis.Addr().String()) + if err := http.Serve(lis, nil); err != nil { + grpclog.Fatalf("Failed to serve: %v", err) + } + }() + switch *rpcType { + case 0: + closeLoopUnary() + case 1: + closeLoopStream() + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/grpc_testing/test.pb.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/grpc_testing/test.pb.go new file mode 100644 index 0000000..619c450 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/grpc_testing/test.pb.go @@ -0,0 +1,641 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package grpc_testing is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + StatsRequest + ServerStats + Payload + HistogramData + ClientConfig + Mark + ClientArgs + ClientStats + ClientStatus + ServerConfig + ServerArgs + ServerStatus + SimpleRequest + SimpleResponse +*/ +package grpc_testing + +import proto "github.com/golang/protobuf/proto" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal + +type PayloadType int32 + +const ( + // Compressable text format. + PayloadType_COMPRESSABLE PayloadType = 0 + // Uncompressable binary format. + PayloadType_UNCOMPRESSABLE PayloadType = 1 + // Randomly chosen from all other formats defined in this enum. + PayloadType_RANDOM PayloadType = 2 +) + +var PayloadType_name = map[int32]string{ + 0: "COMPRESSABLE", + 1: "UNCOMPRESSABLE", + 2: "RANDOM", +} +var PayloadType_value = map[string]int32{ + "COMPRESSABLE": 0, + "UNCOMPRESSABLE": 1, + "RANDOM": 2, +} + +func (x PayloadType) String() string { + return proto.EnumName(PayloadType_name, int32(x)) +} + +type ClientType int32 + +const ( + ClientType_SYNCHRONOUS_CLIENT ClientType = 0 + ClientType_ASYNC_CLIENT ClientType = 1 +) + +var ClientType_name = map[int32]string{ + 0: "SYNCHRONOUS_CLIENT", + 1: "ASYNC_CLIENT", +} +var ClientType_value = map[string]int32{ + "SYNCHRONOUS_CLIENT": 0, + "ASYNC_CLIENT": 1, +} + +func (x ClientType) String() string { + return proto.EnumName(ClientType_name, int32(x)) +} + +type ServerType int32 + +const ( + ServerType_SYNCHRONOUS_SERVER ServerType = 0 + ServerType_ASYNC_SERVER ServerType = 1 +) + +var ServerType_name = map[int32]string{ + 0: "SYNCHRONOUS_SERVER", + 1: "ASYNC_SERVER", +} +var ServerType_value = map[string]int32{ + "SYNCHRONOUS_SERVER": 0, + "ASYNC_SERVER": 1, +} + +func (x ServerType) String() string { + return proto.EnumName(ServerType_name, int32(x)) +} + +type RpcType int32 + +const ( + RpcType_UNARY RpcType = 0 + RpcType_STREAMING RpcType = 1 +) + +var RpcType_name = map[int32]string{ + 0: "UNARY", + 1: "STREAMING", +} +var RpcType_value = map[string]int32{ + "UNARY": 0, + "STREAMING": 1, +} + +func (x RpcType) String() string { + return proto.EnumName(RpcType_name, int32(x)) +} + +type StatsRequest struct { + // run number + TestNum int32 `protobuf:"varint,1,opt,name=test_num" json:"test_num,omitempty"` +} + +func (m *StatsRequest) Reset() { *m = StatsRequest{} } +func (m *StatsRequest) String() string { return proto.CompactTextString(m) } +func (*StatsRequest) ProtoMessage() {} + +type ServerStats struct { + // wall clock time + TimeElapsed float64 `protobuf:"fixed64,1,opt,name=time_elapsed" json:"time_elapsed,omitempty"` + // user time used by the server process and threads + TimeUser float64 `protobuf:"fixed64,2,opt,name=time_user" json:"time_user,omitempty"` + // server time used by the server process and all threads + TimeSystem float64 `protobuf:"fixed64,3,opt,name=time_system" json:"time_system,omitempty"` +} + +func (m *ServerStats) Reset() { *m = ServerStats{} } +func (m *ServerStats) String() string { return proto.CompactTextString(m) } +func (*ServerStats) ProtoMessage() {} + +type Payload struct { + // The type of data in body. + Type PayloadType `protobuf:"varint,1,opt,name=type,enum=grpc.testing.PayloadType" json:"type,omitempty"` + // Primary contents of payload. + Body []byte `protobuf:"bytes,2,opt,name=body,proto3" json:"body,omitempty"` +} + +func (m *Payload) Reset() { *m = Payload{} } +func (m *Payload) String() string { return proto.CompactTextString(m) } +func (*Payload) ProtoMessage() {} + +type HistogramData struct { + Bucket []uint32 `protobuf:"varint,1,rep,name=bucket" json:"bucket,omitempty"` + MinSeen float64 `protobuf:"fixed64,2,opt,name=min_seen" json:"min_seen,omitempty"` + MaxSeen float64 `protobuf:"fixed64,3,opt,name=max_seen" json:"max_seen,omitempty"` + Sum float64 `protobuf:"fixed64,4,opt,name=sum" json:"sum,omitempty"` + SumOfSquares float64 `protobuf:"fixed64,5,opt,name=sum_of_squares" json:"sum_of_squares,omitempty"` + Count float64 `protobuf:"fixed64,6,opt,name=count" json:"count,omitempty"` +} + +func (m *HistogramData) Reset() { *m = HistogramData{} } +func (m *HistogramData) String() string { return proto.CompactTextString(m) } +func (*HistogramData) ProtoMessage() {} + +type ClientConfig struct { + ServerTargets []string `protobuf:"bytes,1,rep,name=server_targets" json:"server_targets,omitempty"` + ClientType ClientType `protobuf:"varint,2,opt,name=client_type,enum=grpc.testing.ClientType" json:"client_type,omitempty"` + EnableSsl bool `protobuf:"varint,3,opt,name=enable_ssl" json:"enable_ssl,omitempty"` + OutstandingRpcsPerChannel int32 `protobuf:"varint,4,opt,name=outstanding_rpcs_per_channel" json:"outstanding_rpcs_per_channel,omitempty"` + ClientChannels int32 `protobuf:"varint,5,opt,name=client_channels" json:"client_channels,omitempty"` + PayloadSize int32 `protobuf:"varint,6,opt,name=payload_size" json:"payload_size,omitempty"` + // only for async client: + AsyncClientThreads int32 `protobuf:"varint,7,opt,name=async_client_threads" json:"async_client_threads,omitempty"` + RpcType RpcType `protobuf:"varint,8,opt,name=rpc_type,enum=grpc.testing.RpcType" json:"rpc_type,omitempty"` +} + +func (m *ClientConfig) Reset() { *m = ClientConfig{} } +func (m *ClientConfig) String() string { return proto.CompactTextString(m) } +func (*ClientConfig) ProtoMessage() {} + +// Request current stats +type Mark struct { +} + +func (m *Mark) Reset() { *m = Mark{} } +func (m *Mark) String() string { return proto.CompactTextString(m) } +func (*Mark) ProtoMessage() {} + +type ClientArgs struct { + Setup *ClientConfig `protobuf:"bytes,1,opt,name=setup" json:"setup,omitempty"` + Mark *Mark `protobuf:"bytes,2,opt,name=mark" json:"mark,omitempty"` +} + +func (m *ClientArgs) Reset() { *m = ClientArgs{} } +func (m *ClientArgs) String() string { return proto.CompactTextString(m) } +func (*ClientArgs) ProtoMessage() {} + +func (m *ClientArgs) GetSetup() *ClientConfig { + if m != nil { + return m.Setup + } + return nil +} + +func (m *ClientArgs) GetMark() *Mark { + if m != nil { + return m.Mark + } + return nil +} + +type ClientStats struct { + Latencies *HistogramData `protobuf:"bytes,1,opt,name=latencies" json:"latencies,omitempty"` + TimeElapsed float64 `protobuf:"fixed64,3,opt,name=time_elapsed" json:"time_elapsed,omitempty"` + TimeUser float64 `protobuf:"fixed64,4,opt,name=time_user" json:"time_user,omitempty"` + TimeSystem float64 `protobuf:"fixed64,5,opt,name=time_system" json:"time_system,omitempty"` +} + +func (m *ClientStats) Reset() { *m = ClientStats{} } +func (m *ClientStats) String() string { return proto.CompactTextString(m) } +func (*ClientStats) ProtoMessage() {} + +func (m *ClientStats) GetLatencies() *HistogramData { + if m != nil { + return m.Latencies + } + return nil +} + +type ClientStatus struct { + Stats *ClientStats `protobuf:"bytes,1,opt,name=stats" json:"stats,omitempty"` +} + +func (m *ClientStatus) Reset() { *m = ClientStatus{} } +func (m *ClientStatus) String() string { return proto.CompactTextString(m) } +func (*ClientStatus) ProtoMessage() {} + +func (m *ClientStatus) GetStats() *ClientStats { + if m != nil { + return m.Stats + } + return nil +} + +type ServerConfig struct { + ServerType ServerType `protobuf:"varint,1,opt,name=server_type,enum=grpc.testing.ServerType" json:"server_type,omitempty"` + Threads int32 `protobuf:"varint,2,opt,name=threads" json:"threads,omitempty"` + EnableSsl bool `protobuf:"varint,3,opt,name=enable_ssl" json:"enable_ssl,omitempty"` +} + +func (m *ServerConfig) Reset() { *m = ServerConfig{} } +func (m *ServerConfig) String() string { return proto.CompactTextString(m) } +func (*ServerConfig) ProtoMessage() {} + +type ServerArgs struct { + Setup *ServerConfig `protobuf:"bytes,1,opt,name=setup" json:"setup,omitempty"` + Mark *Mark `protobuf:"bytes,2,opt,name=mark" json:"mark,omitempty"` +} + +func (m *ServerArgs) Reset() { *m = ServerArgs{} } +func (m *ServerArgs) String() string { return proto.CompactTextString(m) } +func (*ServerArgs) ProtoMessage() {} + +func (m *ServerArgs) GetSetup() *ServerConfig { + if m != nil { + return m.Setup + } + return nil +} + +func (m *ServerArgs) GetMark() *Mark { + if m != nil { + return m.Mark + } + return nil +} + +type ServerStatus struct { + Stats *ServerStats `protobuf:"bytes,1,opt,name=stats" json:"stats,omitempty"` + Port int32 `protobuf:"varint,2,opt,name=port" json:"port,omitempty"` +} + +func (m *ServerStatus) Reset() { *m = ServerStatus{} } +func (m *ServerStatus) String() string { return proto.CompactTextString(m) } +func (*ServerStatus) ProtoMessage() {} + +func (m *ServerStatus) GetStats() *ServerStats { + if m != nil { + return m.Stats + } + return nil +} + +type SimpleRequest struct { + // Desired payload type in the response from the server. + // If response_type is RANDOM, server randomly chooses one from other formats. + ResponseType PayloadType `protobuf:"varint,1,opt,name=response_type,enum=grpc.testing.PayloadType" json:"response_type,omitempty"` + // Desired payload size in the response from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + ResponseSize int32 `protobuf:"varint,2,opt,name=response_size" json:"response_size,omitempty"` + // Optional input payload sent along with the request. + Payload *Payload `protobuf:"bytes,3,opt,name=payload" json:"payload,omitempty"` +} + +func (m *SimpleRequest) Reset() { *m = SimpleRequest{} } +func (m *SimpleRequest) String() string { return proto.CompactTextString(m) } +func (*SimpleRequest) ProtoMessage() {} + +func (m *SimpleRequest) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +type SimpleResponse struct { + Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"` +} + +func (m *SimpleResponse) Reset() { *m = SimpleResponse{} } +func (m *SimpleResponse) String() string { return proto.CompactTextString(m) } +func (*SimpleResponse) ProtoMessage() {} + +func (m *SimpleResponse) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +func init() { + proto.RegisterEnum("grpc.testing.PayloadType", PayloadType_name, PayloadType_value) + proto.RegisterEnum("grpc.testing.ClientType", ClientType_name, ClientType_value) + proto.RegisterEnum("grpc.testing.ServerType", ServerType_name, ServerType_value) + proto.RegisterEnum("grpc.testing.RpcType", RpcType_name, RpcType_value) +} + +// Client API for TestService service + +type TestServiceClient interface { + // One request followed by one response. + // The server returns the client payload as-is. + UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) + // One request followed by one response. + // The server returns the client payload as-is. + StreamingCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingCallClient, error) +} + +type testServiceClient struct { + cc *grpc.ClientConn +} + +func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient { + return &testServiceClient{cc} +} + +func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) { + out := new(SimpleResponse) + err := grpc.Invoke(ctx, "/grpc.testing.TestService/UnaryCall", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) StreamingCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/StreamingCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceStreamingCallClient{stream} + return x, nil +} + +type TestService_StreamingCallClient interface { + Send(*SimpleRequest) error + Recv() (*SimpleResponse, error) + grpc.ClientStream +} + +type testServiceStreamingCallClient struct { + grpc.ClientStream +} + +func (x *testServiceStreamingCallClient) Send(m *SimpleRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceStreamingCallClient) Recv() (*SimpleResponse, error) { + m := new(SimpleResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for TestService service + +type TestServiceServer interface { + // One request followed by one response. + // The server returns the client payload as-is. + UnaryCall(context.Context, *SimpleRequest) (*SimpleResponse, error) + // One request followed by one response. + // The server returns the client payload as-is. + StreamingCall(TestService_StreamingCallServer) error +} + +func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { + s.RegisterService(&_TestService_serviceDesc, srv) +} + +func _TestService_UnaryCall_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(SimpleRequest) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(TestServiceServer).UnaryCall(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +func _TestService_StreamingCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).StreamingCall(&testServiceStreamingCallServer{stream}) +} + +type TestService_StreamingCallServer interface { + Send(*SimpleResponse) error + Recv() (*SimpleRequest, error) + grpc.ServerStream +} + +type testServiceStreamingCallServer struct { + grpc.ServerStream +} + +func (x *testServiceStreamingCallServer) Send(m *SimpleResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceStreamingCallServer) Recv() (*SimpleRequest, error) { + m := new(SimpleRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _TestService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*TestServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "UnaryCall", + Handler: _TestService_UnaryCall_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "StreamingCall", + Handler: _TestService_StreamingCall_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, +} + +// Client API for Worker service + +type WorkerClient interface { + // Start test with specified workload + RunTest(ctx context.Context, opts ...grpc.CallOption) (Worker_RunTestClient, error) + // Start test with specified workload + RunServer(ctx context.Context, opts ...grpc.CallOption) (Worker_RunServerClient, error) +} + +type workerClient struct { + cc *grpc.ClientConn +} + +func NewWorkerClient(cc *grpc.ClientConn) WorkerClient { + return &workerClient{cc} +} + +func (c *workerClient) RunTest(ctx context.Context, opts ...grpc.CallOption) (Worker_RunTestClient, error) { + stream, err := grpc.NewClientStream(ctx, &_Worker_serviceDesc.Streams[0], c.cc, "/grpc.testing.Worker/RunTest", opts...) + if err != nil { + return nil, err + } + x := &workerRunTestClient{stream} + return x, nil +} + +type Worker_RunTestClient interface { + Send(*ClientArgs) error + Recv() (*ClientStatus, error) + grpc.ClientStream +} + +type workerRunTestClient struct { + grpc.ClientStream +} + +func (x *workerRunTestClient) Send(m *ClientArgs) error { + return x.ClientStream.SendMsg(m) +} + +func (x *workerRunTestClient) Recv() (*ClientStatus, error) { + m := new(ClientStatus) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *workerClient) RunServer(ctx context.Context, opts ...grpc.CallOption) (Worker_RunServerClient, error) { + stream, err := grpc.NewClientStream(ctx, &_Worker_serviceDesc.Streams[1], c.cc, "/grpc.testing.Worker/RunServer", opts...) + if err != nil { + return nil, err + } + x := &workerRunServerClient{stream} + return x, nil +} + +type Worker_RunServerClient interface { + Send(*ServerArgs) error + Recv() (*ServerStatus, error) + grpc.ClientStream +} + +type workerRunServerClient struct { + grpc.ClientStream +} + +func (x *workerRunServerClient) Send(m *ServerArgs) error { + return x.ClientStream.SendMsg(m) +} + +func (x *workerRunServerClient) Recv() (*ServerStatus, error) { + m := new(ServerStatus) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for Worker service + +type WorkerServer interface { + // Start test with specified workload + RunTest(Worker_RunTestServer) error + // Start test with specified workload + RunServer(Worker_RunServerServer) error +} + +func RegisterWorkerServer(s *grpc.Server, srv WorkerServer) { + s.RegisterService(&_Worker_serviceDesc, srv) +} + +func _Worker_RunTest_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(WorkerServer).RunTest(&workerRunTestServer{stream}) +} + +type Worker_RunTestServer interface { + Send(*ClientStatus) error + Recv() (*ClientArgs, error) + grpc.ServerStream +} + +type workerRunTestServer struct { + grpc.ServerStream +} + +func (x *workerRunTestServer) Send(m *ClientStatus) error { + return x.ServerStream.SendMsg(m) +} + +func (x *workerRunTestServer) Recv() (*ClientArgs, error) { + m := new(ClientArgs) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func _Worker_RunServer_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(WorkerServer).RunServer(&workerRunServerServer{stream}) +} + +type Worker_RunServerServer interface { + Send(*ServerStatus) error + Recv() (*ServerArgs, error) + grpc.ServerStream +} + +type workerRunServerServer struct { + grpc.ServerStream +} + +func (x *workerRunServerServer) Send(m *ServerStatus) error { + return x.ServerStream.SendMsg(m) +} + +func (x *workerRunServerServer) Recv() (*ServerArgs, error) { + m := new(ServerArgs) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _Worker_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.testing.Worker", + HandlerType: (*WorkerServer)(nil), + Methods: []grpc.MethodDesc{}, + Streams: []grpc.StreamDesc{ + { + StreamName: "RunTest", + Handler: _Worker_RunTest_Handler, + ServerStreams: true, + ClientStreams: true, + }, + { + StreamName: "RunServer", + Handler: _Worker_RunServer_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/grpc_testing/test.proto b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/grpc_testing/test.proto new file mode 100644 index 0000000..e3a27f8 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/grpc_testing/test.proto @@ -0,0 +1,148 @@ +// An integration test service that covers all the method signature permutations +// of unary/streaming requests/responses. +syntax = "proto3"; + +package grpc.testing; + +enum PayloadType { + // Compressable text format. + COMPRESSABLE = 0; + + // Uncompressable binary format. + UNCOMPRESSABLE = 1; + + // Randomly chosen from all other formats defined in this enum. + RANDOM = 2; +} + +message StatsRequest { + // run number + optional int32 test_num = 1; +} + +message ServerStats { + // wall clock time + double time_elapsed = 1; + + // user time used by the server process and threads + double time_user = 2; + + // server time used by the server process and all threads + double time_system = 3; +} + +message Payload { + // The type of data in body. + PayloadType type = 1; + // Primary contents of payload. + bytes body = 2; +} + +message HistogramData { + repeated uint32 bucket = 1; + double min_seen = 2; + double max_seen = 3; + double sum = 4; + double sum_of_squares = 5; + double count = 6; +} + +enum ClientType { + SYNCHRONOUS_CLIENT = 0; + ASYNC_CLIENT = 1; +} + +enum ServerType { + SYNCHRONOUS_SERVER = 0; + ASYNC_SERVER = 1; +} + +enum RpcType { + UNARY = 0; + STREAMING = 1; +} + +message ClientConfig { + repeated string server_targets = 1; + ClientType client_type = 2; + bool enable_ssl = 3; + int32 outstanding_rpcs_per_channel = 4; + int32 client_channels = 5; + int32 payload_size = 6; + // only for async client: + int32 async_client_threads = 7; + RpcType rpc_type = 8; +} + +// Request current stats +message Mark {} + +message ClientArgs { + oneof argtype { + ClientConfig setup = 1; + Mark mark = 2; + } +} + +message ClientStats { + HistogramData latencies = 1; + double time_elapsed = 3; + double time_user = 4; + double time_system = 5; +} + +message ClientStatus { + ClientStats stats = 1; +} + +message ServerConfig { + ServerType server_type = 1; + int32 threads = 2; + bool enable_ssl = 3; +} + +message ServerArgs { + oneof argtype { + ServerConfig setup = 1; + Mark mark = 2; + } +} + +message ServerStatus { + ServerStats stats = 1; + int32 port = 2; +} + +message SimpleRequest { + // Desired payload type in the response from the server. + // If response_type is RANDOM, server randomly chooses one from other formats. + PayloadType response_type = 1; + + // Desired payload size in the response from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + int32 response_size = 2; + + // Optional input payload sent along with the request. + Payload payload = 3; +} + +message SimpleResponse { + Payload payload = 1; +} + +service TestService { + // One request followed by one response. + // The server returns the client payload as-is. + rpc UnaryCall(SimpleRequest) returns (SimpleResponse); + + // One request followed by one response. + // The server returns the client payload as-is. + rpc StreamingCall(stream SimpleRequest) returns (stream SimpleResponse); +} + +service Worker { + // Start test with specified workload + rpc RunTest(stream ClientArgs) returns (stream ClientStatus); + // Start test with specified workload + rpc RunServer(stream ServerArgs) returns (stream ServerStatus); +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/server/main.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/server/main.go new file mode 100644 index 0000000..090f002 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/server/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "flag" + "math" + "net" + "net/http" + _ "net/http/pprof" + "time" + + "google.golang.org/grpc/benchmark" + "google.golang.org/grpc/grpclog" +) + +var ( + duration = flag.Int("duration", math.MaxInt32, "The duration in seconds to run the benchmark server") +) + +func main() { + flag.Parse() + go func() { + lis, err := net.Listen("tcp", ":0") + if err != nil { + grpclog.Fatalf("Failed to listen: %v", err) + } + grpclog.Println("Server profiling address: ", lis.Addr().String()) + if err := http.Serve(lis, nil); err != nil { + grpclog.Fatalf("Failed to serve: %v", err) + } + }() + addr, stopper := benchmark.StartServer(":0") // listen on all interfaces + grpclog.Println("Server Address: ", addr) + <-time.After(time.Duration(*duration) * time.Second) + stopper() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/counter.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/counter.go new file mode 100644 index 0000000..4389bae --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/counter.go @@ -0,0 +1,135 @@ +package stats + +import ( + "sync" + "time" +) + +var ( + // TimeNow is used for testing. + TimeNow = time.Now +) + +const ( + hour = 0 + tenminutes = 1 + minute = 2 +) + +// Counter is a counter that keeps track of its recent values over a given +// period of time, and with a given resolution. Use newCounter() to instantiate. +type Counter struct { + mu sync.RWMutex + ts [3]*timeseries + lastUpdate time.Time +} + +// newCounter returns a new Counter. +func newCounter() *Counter { + now := TimeNow() + c := &Counter{} + c.ts[hour] = newTimeSeries(now, time.Hour, time.Minute) + c.ts[tenminutes] = newTimeSeries(now, 10*time.Minute, 10*time.Second) + c.ts[minute] = newTimeSeries(now, time.Minute, time.Second) + return c +} + +func (c *Counter) advance() time.Time { + now := TimeNow() + for _, ts := range c.ts { + ts.advanceTime(now) + } + return now +} + +// Value returns the current value of the counter. +func (c *Counter) Value() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.ts[minute].headValue() +} + +// LastUpdate returns the last update time of the counter. +func (c *Counter) LastUpdate() time.Time { + c.mu.RLock() + defer c.mu.RUnlock() + return c.lastUpdate +} + +// Set updates the current value of the counter. +func (c *Counter) Set(value int64) { + c.mu.Lock() + defer c.mu.Unlock() + c.lastUpdate = c.advance() + for _, ts := range c.ts { + ts.set(value) + } +} + +// Incr increments the current value of the counter by 'delta'. +func (c *Counter) Incr(delta int64) { + c.mu.Lock() + defer c.mu.Unlock() + c.lastUpdate = c.advance() + for _, ts := range c.ts { + ts.incr(delta) + } +} + +// Delta1h returns the delta for the last hour. +func (c *Counter) Delta1h() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + c.advance() + return c.ts[hour].delta() +} + +// Delta10m returns the delta for the last 10 minutes. +func (c *Counter) Delta10m() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + c.advance() + return c.ts[tenminutes].delta() +} + +// Delta1m returns the delta for the last minute. +func (c *Counter) Delta1m() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + c.advance() + return c.ts[minute].delta() +} + +// Rate1h returns the rate of change of the counter in the last hour. +func (c *Counter) Rate1h() float64 { + c.mu.RLock() + defer c.mu.RUnlock() + c.advance() + return c.ts[hour].rate() +} + +// Rate10m returns the rate of change of the counter in the last 10 minutes. +func (c *Counter) Rate10m() float64 { + c.mu.RLock() + defer c.mu.RUnlock() + c.advance() + return c.ts[tenminutes].rate() +} + +// Rate1m returns the rate of change of the counter in the last minute. +func (c *Counter) Rate1m() float64 { + c.mu.RLock() + defer c.mu.RUnlock() + c.advance() + return c.ts[minute].rate() +} + +// Reset resets the counter to an empty state. +func (c *Counter) Reset() { + c.mu.Lock() + defer c.mu.Unlock() + now := TimeNow() + for _, ts := range c.ts { + ts.reset(now) + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/histogram.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/histogram.go new file mode 100644 index 0000000..727808c --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/histogram.go @@ -0,0 +1,255 @@ +package stats + +import ( + "bytes" + "fmt" + "io" + "strconv" + "strings" + "time" +) + +// HistogramValue is the value of Histogram objects. +type HistogramValue struct { + // Count is the total number of values added to the histogram. + Count int64 + // Sum is the sum of all the values added to the histogram. + Sum int64 + // Min is the minimum of all the values added to the histogram. + Min int64 + // Max is the maximum of all the values added to the histogram. + Max int64 + // Buckets contains all the buckets of the histogram. + Buckets []HistogramBucket +} + +// HistogramBucket is one histogram bucket. +type HistogramBucket struct { + // LowBound is the lower bound of the bucket. + LowBound int64 + // Count is the number of values in the bucket. + Count int64 +} + +// Print writes textual output of the histogram values. +func (v HistogramValue) Print(w io.Writer) { + avg := float64(v.Sum) / float64(v.Count) + fmt.Fprintf(w, "Count: %d Min: %d Max: %d Avg: %.2f\n", v.Count, v.Min, v.Max, avg) + fmt.Fprintf(w, "%s\n", strings.Repeat("-", 60)) + if v.Count <= 0 { + return + } + + maxBucketDigitLen := len(strconv.FormatInt(v.Buckets[len(v.Buckets)-1].LowBound, 10)) + if maxBucketDigitLen < 3 { + // For "inf". + maxBucketDigitLen = 3 + } + maxCountDigitLen := len(strconv.FormatInt(v.Count, 10)) + percentMulti := 100 / float64(v.Count) + + accCount := int64(0) + for i, b := range v.Buckets { + fmt.Fprintf(w, "[%*d, ", maxBucketDigitLen, b.LowBound) + if i+1 < len(v.Buckets) { + fmt.Fprintf(w, "%*d)", maxBucketDigitLen, v.Buckets[i+1].LowBound) + } else { + fmt.Fprintf(w, "%*s)", maxBucketDigitLen, "inf") + } + + accCount += b.Count + fmt.Fprintf(w, " %*d %5.1f%% %5.1f%%", maxCountDigitLen, b.Count, float64(b.Count)*percentMulti, float64(accCount)*percentMulti) + + const barScale = 0.1 + barLength := int(float64(b.Count)*percentMulti*barScale + 0.5) + fmt.Fprintf(w, " %s\n", strings.Repeat("#", barLength)) + } +} + +// String returns the textual output of the histogram values as string. +func (v HistogramValue) String() string { + var b bytes.Buffer + v.Print(&b) + return b.String() +} + +// A Histogram accumulates values in the form of a histogram. The type of the +// values is int64, which is suitable for keeping track of things like RPC +// latency in milliseconds. New histogram objects should be obtained via the +// New() function. +type Histogram struct { + opts HistogramOptions + buckets []bucketInternal + count *Counter + sum *Counter + tracker *Tracker +} + +// HistogramOptions contains the parameters that define the histogram's buckets. +type HistogramOptions struct { + // NumBuckets is the number of buckets. + NumBuckets int + // GrowthFactor is the growth factor of the buckets. A value of 0.1 + // indicates that bucket N+1 will be 10% larger than bucket N. + GrowthFactor float64 + // SmallestBucketSize is the size of the first bucket. Bucket sizes are + // rounded down to the nearest integer. + SmallestBucketSize float64 + // MinValue is the lower bound of the first bucket. + MinValue int64 +} + +// bucketInternal is the internal representation of a bucket, which includes a +// rate counter. +type bucketInternal struct { + lowBound int64 + count *Counter +} + +// NewHistogram returns a pointer to a new Histogram object that was created +// with the provided options. +func NewHistogram(opts HistogramOptions) *Histogram { + if opts.NumBuckets == 0 { + opts.NumBuckets = 32 + } + if opts.SmallestBucketSize == 0.0 { + opts.SmallestBucketSize = 1.0 + } + h := Histogram{ + opts: opts, + buckets: make([]bucketInternal, opts.NumBuckets), + count: newCounter(), + sum: newCounter(), + tracker: newTracker(), + } + low := opts.MinValue + delta := opts.SmallestBucketSize + for i := 0; i < opts.NumBuckets; i++ { + h.buckets[i].lowBound = low + h.buckets[i].count = newCounter() + low = low + int64(delta) + delta = delta * (1.0 + opts.GrowthFactor) + } + return &h +} + +// Opts returns a copy of the options used to create the Histogram. +func (h *Histogram) Opts() HistogramOptions { + return h.opts +} + +// Add adds a value to the histogram. +func (h *Histogram) Add(value int64) error { + bucket, err := h.findBucket(value) + if err != nil { + return err + } + h.buckets[bucket].count.Incr(1) + h.count.Incr(1) + h.sum.Incr(value) + h.tracker.Push(value) + return nil +} + +// LastUpdate returns the time at which the object was last updated. +func (h *Histogram) LastUpdate() time.Time { + return h.count.LastUpdate() +} + +// Value returns the accumulated state of the histogram since it was created. +func (h *Histogram) Value() HistogramValue { + b := make([]HistogramBucket, len(h.buckets)) + for i, v := range h.buckets { + b[i] = HistogramBucket{ + LowBound: v.lowBound, + Count: v.count.Value(), + } + } + + v := HistogramValue{ + Count: h.count.Value(), + Sum: h.sum.Value(), + Min: h.tracker.Min(), + Max: h.tracker.Max(), + Buckets: b, + } + return v +} + +// Delta1h returns the change in the last hour. +func (h *Histogram) Delta1h() HistogramValue { + b := make([]HistogramBucket, len(h.buckets)) + for i, v := range h.buckets { + b[i] = HistogramBucket{ + LowBound: v.lowBound, + Count: v.count.Delta1h(), + } + } + + v := HistogramValue{ + Count: h.count.Delta1h(), + Sum: h.sum.Delta1h(), + Min: h.tracker.Min1h(), + Max: h.tracker.Max1h(), + Buckets: b, + } + return v +} + +// Delta10m returns the change in the last 10 minutes. +func (h *Histogram) Delta10m() HistogramValue { + b := make([]HistogramBucket, len(h.buckets)) + for i, v := range h.buckets { + b[i] = HistogramBucket{ + LowBound: v.lowBound, + Count: v.count.Delta10m(), + } + } + + v := HistogramValue{ + Count: h.count.Delta10m(), + Sum: h.sum.Delta10m(), + Min: h.tracker.Min10m(), + Max: h.tracker.Max10m(), + Buckets: b, + } + return v +} + +// Delta1m returns the change in the last 10 minutes. +func (h *Histogram) Delta1m() HistogramValue { + b := make([]HistogramBucket, len(h.buckets)) + for i, v := range h.buckets { + b[i] = HistogramBucket{ + LowBound: v.lowBound, + Count: v.count.Delta1m(), + } + } + + v := HistogramValue{ + Count: h.count.Delta1m(), + Sum: h.sum.Delta1m(), + Min: h.tracker.Min1m(), + Max: h.tracker.Max1m(), + Buckets: b, + } + return v +} + +// findBucket does a binary search to find in which bucket the value goes. +func (h *Histogram) findBucket(value int64) (int, error) { + lastBucket := len(h.buckets) - 1 + min, max := 0, lastBucket + for max >= min { + b := (min + max) / 2 + if value >= h.buckets[b].lowBound && (b == lastBucket || value < h.buckets[b+1].lowBound) { + return b, nil + } + if value < h.buckets[b].lowBound { + max = b - 1 + continue + } + min = b + 1 + } + return 0, fmt.Errorf("no bucket for value: %d", value) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/stats.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/stats.go new file mode 100644 index 0000000..4290ad7 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/stats.go @@ -0,0 +1,116 @@ +package stats + +import ( + "bytes" + "fmt" + "io" + "math" + "time" +) + +// Stats is a simple helper for gathering additional statistics like histogram +// during benchmarks. This is not thread safe. +type Stats struct { + numBuckets int + unit time.Duration + min, max int64 + histogram *Histogram + + durations durationSlice + dirty bool +} + +type durationSlice []time.Duration + +// NewStats creates a new Stats instance. If numBuckets is not positive, +// the default value (16) will be used. +func NewStats(numBuckets int) *Stats { + if numBuckets <= 0 { + numBuckets = 16 + } + return &Stats{ + // Use one more bucket for the last unbounded bucket. + numBuckets: numBuckets + 1, + durations: make(durationSlice, 0, 100000), + } +} + +// Add adds an elapsed time per operation to the stats. +func (stats *Stats) Add(d time.Duration) { + stats.durations = append(stats.durations, d) + stats.dirty = true +} + +// Clear resets the stats, removing all values. +func (stats *Stats) Clear() { + stats.durations = stats.durations[:0] + stats.histogram = nil + stats.dirty = false +} + +// maybeUpdate updates internal stat data if there was any newly added +// stats since this was updated. +func (stats *Stats) maybeUpdate() { + if !stats.dirty { + return + } + + stats.min = math.MaxInt64 + stats.max = 0 + for _, d := range stats.durations { + if stats.min > int64(d) { + stats.min = int64(d) + } + if stats.max < int64(d) { + stats.max = int64(d) + } + } + + // Use the largest unit that can represent the minimum time duration. + stats.unit = time.Nanosecond + for _, u := range []time.Duration{time.Microsecond, time.Millisecond, time.Second} { + if stats.min <= int64(u) { + break + } + stats.unit = u + } + + // Adjust the min/max according to the new unit. + stats.min /= int64(stats.unit) + stats.max /= int64(stats.unit) + numBuckets := stats.numBuckets + if n := int(stats.max - stats.min + 1); n < numBuckets { + numBuckets = n + } + stats.histogram = NewHistogram(HistogramOptions{ + NumBuckets: numBuckets, + // max(i.e., Nth lower bound) = min + (1 + growthFactor)^(numBuckets-2). + GrowthFactor: math.Pow(float64(stats.max-stats.min), 1/float64(stats.numBuckets-2)) - 1, + SmallestBucketSize: 1.0, + MinValue: stats.min}) + + for _, d := range stats.durations { + stats.histogram.Add(int64(d / stats.unit)) + } + + stats.dirty = false +} + +// Print writes textual output of the Stats. +func (stats *Stats) Print(w io.Writer) { + stats.maybeUpdate() + + if stats.histogram == nil { + fmt.Fprint(w, "Histogram (empty)\n") + } else { + fmt.Fprintf(w, "Histogram (unit: %s)\n", fmt.Sprintf("%v", stats.unit)[1:]) + stats.histogram.Value().Print(w) + } +} + +// String returns the textual output of the Stats as string. +func (stats *Stats) String() string { + var b bytes.Buffer + stats.Print(&b) + return b.String() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/timeseries.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/timeseries.go new file mode 100644 index 0000000..2ba18a4 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/timeseries.go @@ -0,0 +1,154 @@ +package stats + +import ( + "math" + "time" +) + +// timeseries holds the history of a changing value over a predefined period of +// time. +type timeseries struct { + size int // The number of time slots. Equivalent to len(slots). + resolution time.Duration // The time resolution of each slot. + stepCount int64 // The number of intervals seen since creation. + head int // The position of the current time in slots. + time time.Time // The time at the beginning of the current time slot. + slots []int64 // A circular buffer of time slots. +} + +// newTimeSeries returns a newly allocated timeseries that covers the requested +// period with the given resolution. +func newTimeSeries(initialTime time.Time, period, resolution time.Duration) *timeseries { + size := int(period.Nanoseconds()/resolution.Nanoseconds()) + 1 + return ×eries{ + size: size, + resolution: resolution, + stepCount: 1, + time: initialTime, + slots: make([]int64, size), + } +} + +// advanceTimeWithFill moves the timeseries forward to time t and fills in any +// slots that get skipped in the process with the given value. Values older than +// the timeseries period are lost. +func (ts *timeseries) advanceTimeWithFill(t time.Time, value int64) { + advanceTo := t.Truncate(ts.resolution) + if !advanceTo.After(ts.time) { + // This is shortcut for the most common case of a busy counter + // where updates come in many times per ts.resolution. + ts.time = advanceTo + return + } + steps := int(advanceTo.Sub(ts.time).Nanoseconds() / ts.resolution.Nanoseconds()) + ts.stepCount += int64(steps) + if steps > ts.size { + steps = ts.size + } + for steps > 0 { + ts.head = (ts.head + 1) % ts.size + ts.slots[ts.head] = value + steps-- + } + ts.time = advanceTo +} + +// advanceTime moves the timeseries forward to time t and fills in any slots +// that get skipped in the process with the head value. Values older than the +// timeseries period are lost. +func (ts *timeseries) advanceTime(t time.Time) { + ts.advanceTimeWithFill(t, ts.slots[ts.head]) +} + +// set sets the current value of the timeseries. +func (ts *timeseries) set(value int64) { + ts.slots[ts.head] = value +} + +// incr sets the current value of the timeseries. +func (ts *timeseries) incr(delta int64) { + ts.slots[ts.head] += delta +} + +// headValue returns the latest value from the timeseries. +func (ts *timeseries) headValue() int64 { + return ts.slots[ts.head] +} + +// headTime returns the time of the latest value from the timeseries. +func (ts *timeseries) headTime() time.Time { + return ts.time +} + +// tailValue returns the oldest value from the timeseries. +func (ts *timeseries) tailValue() int64 { + if ts.stepCount < int64(ts.size) { + return 0 + } + return ts.slots[(ts.head+1)%ts.size] +} + +// tailTime returns the time of the oldest value from the timeseries. +func (ts *timeseries) tailTime() time.Time { + size := int64(ts.size) + if ts.stepCount < size { + size = ts.stepCount + } + return ts.time.Add(-time.Duration(size-1) * ts.resolution) +} + +// delta returns the difference between the newest and oldest values from the +// timeseries. +func (ts *timeseries) delta() int64 { + return ts.headValue() - ts.tailValue() +} + +// rate returns the rate of change between the oldest and newest values from +// the timeseries in units per second. +func (ts *timeseries) rate() float64 { + deltaTime := ts.headTime().Sub(ts.tailTime()).Seconds() + if deltaTime == 0 { + return 0 + } + return float64(ts.delta()) / deltaTime +} + +// min returns the smallest value from the timeseries. +func (ts *timeseries) min() int64 { + to := ts.size + if ts.stepCount < int64(ts.size) { + to = ts.head + 1 + } + tail := (ts.head + 1) % ts.size + min := int64(math.MaxInt64) + for b := 0; b < to; b++ { + if b != tail && ts.slots[b] < min { + min = ts.slots[b] + } + } + return min +} + +// max returns the largest value from the timeseries. +func (ts *timeseries) max() int64 { + to := ts.size + if ts.stepCount < int64(ts.size) { + to = ts.head + 1 + } + tail := (ts.head + 1) % ts.size + max := int64(math.MinInt64) + for b := 0; b < to; b++ { + if b != tail && ts.slots[b] > max { + max = ts.slots[b] + } + } + return max +} + +// reset resets the timeseries to an empty state. +func (ts *timeseries) reset(t time.Time) { + ts.head = 0 + ts.time = t + ts.stepCount = 1 + ts.slots = make([]int64, ts.size) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/tracker.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/tracker.go new file mode 100644 index 0000000..802f729 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/tracker.go @@ -0,0 +1,159 @@ +package stats + +import ( + "math" + "sync" + "time" +) + +// Tracker is a min/max value tracker that keeps track of its min/max values +// over a given period of time, and with a given resolution. The initial min +// and max values are math.MaxInt64 and math.MinInt64 respectively. +type Tracker struct { + mu sync.RWMutex + min, max int64 // All time min/max. + minTS, maxTS [3]*timeseries + lastUpdate time.Time +} + +// newTracker returns a new Tracker. +func newTracker() *Tracker { + now := TimeNow() + t := &Tracker{} + t.minTS[hour] = newTimeSeries(now, time.Hour, time.Minute) + t.minTS[tenminutes] = newTimeSeries(now, 10*time.Minute, 10*time.Second) + t.minTS[minute] = newTimeSeries(now, time.Minute, time.Second) + t.maxTS[hour] = newTimeSeries(now, time.Hour, time.Minute) + t.maxTS[tenminutes] = newTimeSeries(now, 10*time.Minute, 10*time.Second) + t.maxTS[minute] = newTimeSeries(now, time.Minute, time.Second) + t.init() + return t +} + +func (t *Tracker) init() { + t.min = math.MaxInt64 + t.max = math.MinInt64 + for _, ts := range t.minTS { + ts.set(math.MaxInt64) + } + for _, ts := range t.maxTS { + ts.set(math.MinInt64) + } +} + +func (t *Tracker) advance() time.Time { + now := TimeNow() + for _, ts := range t.minTS { + ts.advanceTimeWithFill(now, math.MaxInt64) + } + for _, ts := range t.maxTS { + ts.advanceTimeWithFill(now, math.MinInt64) + } + return now +} + +// LastUpdate returns the last update time of the range. +func (t *Tracker) LastUpdate() time.Time { + t.mu.RLock() + defer t.mu.RUnlock() + return t.lastUpdate +} + +// Push adds a new value if it is a new minimum or maximum. +func (t *Tracker) Push(value int64) { + t.mu.Lock() + defer t.mu.Unlock() + t.lastUpdate = t.advance() + if t.min > value { + t.min = value + } + if t.max < value { + t.max = value + } + for _, ts := range t.minTS { + if ts.headValue() > value { + ts.set(value) + } + } + for _, ts := range t.maxTS { + if ts.headValue() < value { + ts.set(value) + } + } +} + +// Min returns the minimum value of the tracker +func (t *Tracker) Min() int64 { + t.mu.RLock() + defer t.mu.RUnlock() + return t.min +} + +// Max returns the maximum value of the tracker. +func (t *Tracker) Max() int64 { + t.mu.RLock() + defer t.mu.RUnlock() + return t.max +} + +// Min1h returns the minimum value for the last hour. +func (t *Tracker) Min1h() int64 { + t.mu.Lock() + defer t.mu.Unlock() + t.advance() + return t.minTS[hour].min() +} + +// Max1h returns the maximum value for the last hour. +func (t *Tracker) Max1h() int64 { + t.mu.Lock() + defer t.mu.Unlock() + t.advance() + return t.maxTS[hour].max() +} + +// Min10m returns the minimum value for the last 10 minutes. +func (t *Tracker) Min10m() int64 { + t.mu.Lock() + defer t.mu.Unlock() + t.advance() + return t.minTS[tenminutes].min() +} + +// Max10m returns the maximum value for the last 10 minutes. +func (t *Tracker) Max10m() int64 { + t.mu.Lock() + defer t.mu.Unlock() + t.advance() + return t.maxTS[tenminutes].max() +} + +// Min1m returns the minimum value for the last 1 minute. +func (t *Tracker) Min1m() int64 { + t.mu.Lock() + defer t.mu.Unlock() + t.advance() + return t.minTS[minute].min() +} + +// Max1m returns the maximum value for the last 1 minute. +func (t *Tracker) Max1m() int64 { + t.mu.Lock() + defer t.mu.Unlock() + t.advance() + return t.maxTS[minute].max() +} + +// Reset resets the range to an empty state. +func (t *Tracker) Reset() { + t.mu.Lock() + defer t.mu.Unlock() + now := TimeNow() + for _, ts := range t.minTS { + ts.reset(now) + } + for _, ts := range t.maxTS { + ts.reset(now) + } + t.init() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/util.go b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/util.go new file mode 100644 index 0000000..a9922f9 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/benchmark/stats/util.go @@ -0,0 +1,191 @@ +package stats + +import ( + "bufio" + "bytes" + "fmt" + "os" + "runtime" + "sort" + "strings" + "sync" + "testing" +) + +var ( + curB *testing.B + curBenchName string + curStats map[string]*Stats + + orgStdout *os.File + nextOutPos int + + injectCond *sync.Cond + injectDone chan struct{} +) + +// AddStats adds a new unnamed Stats instance to the current benchmark. You need +// to run benchmarks by calling RunTestMain() to inject the stats to the +// benchmark results. If numBuckets is not positive, the default value (16) will +// be used. Please note that this calls b.ResetTimer() since it may be blocked +// until the previous benchmark stats is printed out. So AddStats() should +// typically be called at the very beginning of each benchmark function. +func AddStats(b *testing.B, numBuckets int) *Stats { + return AddStatsWithName(b, "", numBuckets) +} + +// AddStatsWithName adds a new named Stats instance to the current benchmark. +// With this, you can add multiple stats in a single benchmark. You need +// to run benchmarks by calling RunTestMain() to inject the stats to the +// benchmark results. If numBuckets is not positive, the default value (16) will +// be used. Please note that this calls b.ResetTimer() since it may be blocked +// until the previous benchmark stats is printed out. So AddStatsWithName() +// should typically be called at the very beginning of each benchmark function. +func AddStatsWithName(b *testing.B, name string, numBuckets int) *Stats { + var benchName string + for i := 1; ; i++ { + pc, _, _, ok := runtime.Caller(i) + if !ok { + panic("benchmark function not found") + } + p := strings.Split(runtime.FuncForPC(pc).Name(), ".") + benchName = p[len(p)-1] + if strings.HasPrefix(benchName, "Benchmark") { + break + } + } + procs := runtime.GOMAXPROCS(-1) + if procs != 1 { + benchName = fmt.Sprintf("%s-%d", benchName, procs) + } + + stats := NewStats(numBuckets) + + if injectCond != nil { + // We need to wait until the previous benchmark stats is printed out. + injectCond.L.Lock() + for curB != nil && curBenchName != benchName { + injectCond.Wait() + } + + curB = b + curBenchName = benchName + curStats[name] = stats + + injectCond.L.Unlock() + } + + b.ResetTimer() + return stats +} + +// RunTestMain runs the tests with enabling injection of benchmark stats. It +// returns an exit code to pass to os.Exit. +func RunTestMain(m *testing.M) int { + startStatsInjector() + defer stopStatsInjector() + return m.Run() +} + +// startStatsInjector starts stats injection to benchmark results. +func startStatsInjector() { + orgStdout = os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + nextOutPos = 0 + + resetCurBenchStats() + + injectCond = sync.NewCond(&sync.Mutex{}) + injectDone = make(chan struct{}) + go func() { + defer close(injectDone) + + scanner := bufio.NewScanner(r) + scanner.Split(splitLines) + for scanner.Scan() { + injectStatsIfFinished(scanner.Text()) + } + if err := scanner.Err(); err != nil { + panic(err) + } + }() +} + +// stopStatsInjector stops stats injection and restores os.Stdout. +func stopStatsInjector() { + os.Stdout.Close() + <-injectDone + injectCond = nil + os.Stdout = orgStdout +} + +// splitLines is a split function for a bufio.Scanner that returns each line +// of text, teeing texts to the original stdout even before each line ends. +func splitLines(data []byte, eof bool) (advance int, token []byte, err error) { + if eof && len(data) == 0 { + return 0, nil, nil + } + + if i := bytes.IndexByte(data, '\n'); i >= 0 { + orgStdout.Write(data[nextOutPos : i+1]) + nextOutPos = 0 + return i + 1, data[0:i], nil + } + + orgStdout.Write(data[nextOutPos:]) + nextOutPos = len(data) + + if eof { + // This is a final, non-terminated line. Return it. + return len(data), data, nil + } + + return 0, nil, nil +} + +// injectStatsIfFinished prints out the stats if the current benchmark finishes. +func injectStatsIfFinished(line string) { + injectCond.L.Lock() + defer injectCond.L.Unlock() + + // We assume that the benchmark results start with the benchmark name. + if curB == nil || !strings.HasPrefix(line, curBenchName) { + return + } + + if !curB.Failed() { + // Output all stats in alphabetical order. + names := make([]string, 0, len(curStats)) + for name := range curStats { + names = append(names, name) + } + sort.Strings(names) + for _, name := range names { + stats := curStats[name] + // The output of stats starts with a header like "Histogram (unit: ms)" + // followed by statistical properties and the buckets. Add the stats name + // if it is a named stats and indent them as Go testing outputs. + lines := strings.Split(stats.String(), "\n") + if n := len(lines); n > 0 { + if name != "" { + name = ": " + name + } + fmt.Fprintf(orgStdout, "--- %s%s\n", lines[0], name) + for _, line := range lines[1 : n-1] { + fmt.Fprintf(orgStdout, "\t%s\n", line) + } + } + } + } + + resetCurBenchStats() + injectCond.Signal() +} + +// resetCurBenchStats resets the current benchmark stats. +func resetCurBenchStats() { + curB = nil + curBenchName = "" + curStats = make(map[string]*Stats) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/call.go b/Godeps/_workspace/src/google.golang.org/grpc/call.go new file mode 100644 index 0000000..63b7966 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/call.go @@ -0,0 +1,194 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "io" + "time" + + "golang.org/x/net/context" + "golang.org/x/net/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/transport" +) + +// recvResponse receives and parses an RPC response. +// On error, it returns the error and indicates whether the call should be retried. +// +// TODO(zhaoq): Check whether the received message sequence is valid. +func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error { + // Try to acquire header metadata from the server if there is any. + var err error + c.headerMD, err = stream.Header() + if err != nil { + return err + } + p := &parser{s: stream} + for { + if err = recv(p, codec, reply); err != nil { + if err == io.EOF { + break + } + return err + } + } + c.trailerMD = stream.Trailer() + return nil +} + +// sendRequest writes out various information of an RPC such as Context and Message. +func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) { + stream, err := t.NewStream(ctx, callHdr) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + t.CloseStream(stream, err) + } + } + }() + // TODO(zhaoq): Support compression. + outBuf, err := encode(codec, args, compressionNone) + if err != nil { + return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err) + } + err = t.Write(stream, outBuf, opts) + if err != nil { + return nil, err + } + // Sent successfully. + return stream, nil +} + +// callInfo contains all related configuration and information about an RPC. +type callInfo struct { + failFast bool + headerMD metadata.MD + trailerMD metadata.MD + traceInfo traceInfo // in trace.go +} + +// Invoke is called by the generated code. It sends the RPC request on the +// wire and returns after response is received. +func Invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (err error) { + var c callInfo + for _, o := range opts { + if err := o.before(&c); err != nil { + return toRPCErr(err) + } + } + defer func() { + for _, o := range opts { + o.after(&c) + } + }() + + if EnableTracing { + c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) + defer c.traceInfo.tr.Finish() + c.traceInfo.firstLine.client = true + if deadline, ok := ctx.Deadline(); ok { + c.traceInfo.firstLine.deadline = deadline.Sub(time.Now()) + } + c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false) + // TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set. + defer func() { + if err != nil { + c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + c.traceInfo.tr.SetError() + } + }() + } + callHdr := &transport.CallHdr{ + Host: cc.authority, + Method: method, + } + topts := &transport.Options{ + Last: true, + Delay: false, + } + var ( + ts int // track the transport sequence number + lastErr error // record the error that happened + ) + for { + var ( + err error + t transport.ClientTransport + stream *transport.Stream + ) + // TODO(zhaoq): Need a formal spec of retry strategy for non-failfast rpcs. + if lastErr != nil && c.failFast { + return toRPCErr(lastErr) + } + t, ts, err = cc.wait(ctx, ts) + if err != nil { + if lastErr != nil { + // This was a retry; return the error from the last attempt. + return toRPCErr(lastErr) + } + return toRPCErr(err) + } + if c.traceInfo.tr != nil { + c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) + } + stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts) + if err != nil { + if _, ok := err.(transport.ConnectionError); ok { + lastErr = err + continue + } + if lastErr != nil { + return toRPCErr(lastErr) + } + return toRPCErr(err) + } + // Receive the response + lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply) + if _, ok := lastErr.(transport.ConnectionError); ok { + continue + } + if c.traceInfo.tr != nil { + c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) + } + t.CloseStream(stream, lastErr) + if lastErr != nil { + return toRPCErr(lastErr) + } + return Errorf(stream.StatusCode(), stream.StatusDesc()) + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/clientconn.go b/Godeps/_workspace/src/google.golang.org/grpc/clientconn.go new file mode 100644 index 0000000..aacba5b --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/clientconn.go @@ -0,0 +1,458 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "errors" + "fmt" + "net" + "strings" + "sync" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/transport" +) + +var ( + // ErrUnspecTarget indicates that the target address is unspecified. + ErrUnspecTarget = errors.New("grpc: target is unspecified") + // ErrNoTransportSecurity indicates that there is no transport security + // being set for ClientConn. Users should either set one or explicityly + // call WithInsecure DialOption to disable security. + ErrNoTransportSecurity = errors.New("grpc: no transport security set (use grpc.WithInsecure() explicitly or set credentials)") + // ErrCredentialsMisuse indicates that users want to transmit security infomation + // (e.g., oauth2 token) which requires secure connection on an insecure + // connection. + ErrCredentialsMisuse = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportAuthenticator() to set)") + // ErrClientConnClosing indicates that the operation is illegal because + // the session is closing. + ErrClientConnClosing = errors.New("grpc: the client connection is closing") + // ErrClientConnTimeout indicates that the connection could not be + // established or re-established within the specified timeout. + ErrClientConnTimeout = errors.New("grpc: timed out trying to connect") + // minimum time to give a connection to complete + minConnectTimeout = 20 * time.Second +) + +// dialOptions configure a Dial call. dialOptions are set by the DialOption +// values passed to Dial. +type dialOptions struct { + codec Codec + block bool + insecure bool + copts transport.ConnectOptions +} + +// DialOption configures how we set up the connection. +type DialOption func(*dialOptions) + +// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling. +func WithCodec(c Codec) DialOption { + return func(o *dialOptions) { + o.codec = c + } +} + +// WithBlock returns a DialOption which makes caller of Dial blocks until the underlying +// connection is up. Without this, Dial returns immediately and connecting the server +// happens in background. +func WithBlock() DialOption { + return func(o *dialOptions) { + o.block = true + } +} + +func WithInsecure() DialOption { + return func(o *dialOptions) { + o.insecure = true + } +} + +// WithTransportCredentials returns a DialOption which configures a +// connection level security credentials (e.g., TLS/SSL). +func WithTransportCredentials(creds credentials.TransportAuthenticator) DialOption { + return func(o *dialOptions) { + o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + } +} + +// WithPerRPCCredentials returns a DialOption which sets +// credentials which will place auth state on each outbound RPC. +func WithPerRPCCredentials(creds credentials.Credentials) DialOption { + return func(o *dialOptions) { + o.copts.AuthOptions = append(o.copts.AuthOptions, creds) + } +} + +// WithTimeout returns a DialOption that configures a timeout for dialing a client connection. +func WithTimeout(d time.Duration) DialOption { + return func(o *dialOptions) { + o.copts.Timeout = d + } +} + +// WithDialer returns a DialOption that specifies a function to use for dialing network addresses. +func WithDialer(f func(addr string, timeout time.Duration) (net.Conn, error)) DialOption { + return func(o *dialOptions) { + o.copts.Dialer = f + } +} + +// WithUserAgent returns a DialOption that specifies a user agent string for all the RPCs. +func WithUserAgent(s string) DialOption { + return func(o *dialOptions) { + o.copts.UserAgent = s + } +} + +// Dial creates a client connection the given target. +func Dial(target string, opts ...DialOption) (*ClientConn, error) { + if target == "" { + return nil, ErrUnspecTarget + } + cc := &ClientConn{ + target: target, + shutdownChan: make(chan struct{}), + } + for _, opt := range opts { + opt(&cc.dopts) + } + if !cc.dopts.insecure { + var ok bool + for _, c := range cc.dopts.copts.AuthOptions { + if _, ok := c.(credentials.TransportAuthenticator); !ok { + continue + } + ok = true + } + if !ok { + return nil, ErrNoTransportSecurity + } + } else { + for _, c := range cc.dopts.copts.AuthOptions { + if c.RequireTransportSecurity() { + return nil, ErrCredentialsMisuse + } + } + } + colonPos := strings.LastIndex(target, ":") + if colonPos == -1 { + colonPos = len(target) + } + cc.authority = target[:colonPos] + if cc.dopts.codec == nil { + // Set the default codec. + cc.dopts.codec = protoCodec{} + } + cc.stateCV = sync.NewCond(&cc.mu) + if cc.dopts.block { + if err := cc.resetTransport(false); err != nil { + cc.Close() + return nil, err + } + // Start to monitor the error status of transport. + go cc.transportMonitor() + } else { + // Start a goroutine connecting to the server asynchronously. + go func() { + if err := cc.resetTransport(false); err != nil { + grpclog.Printf("Failed to dial %s: %v; please retry.", target, err) + cc.Close() + return + } + go cc.transportMonitor() + }() + } + return cc, nil +} + +// ConnectivityState indicates the state of a client connection. +type ConnectivityState int + +const ( + // Idle indicates the ClientConn is idle. + Idle ConnectivityState = iota + // Connecting indicates the ClienConn is connecting. + Connecting + // Ready indicates the ClientConn is ready for work. + Ready + // TransientFailure indicates the ClientConn has seen a failure but expects to recover. + TransientFailure + // Shutdown indicates the ClientConn has stated shutting down. + Shutdown +) + +func (s ConnectivityState) String() string { + switch s { + case Idle: + return "IDLE" + case Connecting: + return "CONNECTING" + case Ready: + return "READY" + case TransientFailure: + return "TRANSIENT_FAILURE" + case Shutdown: + return "SHUTDOWN" + default: + panic(fmt.Sprintf("unknown connectivity state: %d", s)) + } +} + +// ClientConn represents a client connection to an RPC service. +type ClientConn struct { + target string + authority string + dopts dialOptions + shutdownChan chan struct{} + + mu sync.Mutex + state ConnectivityState + stateCV *sync.Cond + // ready is closed and becomes nil when a new transport is up or failed + // due to timeout. + ready chan struct{} + // Every time a new transport is created, this is incremented by 1. Used + // to avoid trying to recreate a transport while the new one is already + // under construction. + transportSeq int + transport transport.ClientTransport +} + +// State returns the connectivity state of the ClientConn +func (cc *ClientConn) State() ConnectivityState { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.state +} + +// WaitForStateChange blocks until the state changes to something other than the sourceState +// or timeout fires. It returns false if timeout fires and true otherwise. +func (cc *ClientConn) WaitForStateChange(timeout time.Duration, sourceState ConnectivityState) bool { + start := time.Now() + cc.mu.Lock() + defer cc.mu.Unlock() + if sourceState != cc.state { + return true + } + expired := timeout <= time.Since(start) + if expired { + return false + } + done := make(chan struct{}) + go func() { + select { + case <-time.After(timeout - time.Since(start)): + cc.mu.Lock() + expired = true + cc.stateCV.Broadcast() + cc.mu.Unlock() + case <-done: + } + }() + defer close(done) + for sourceState == cc.state { + cc.stateCV.Wait() + if expired { + return false + } + } + return true +} + +func (cc *ClientConn) resetTransport(closeTransport bool) error { + var retries int + start := time.Now() + for { + cc.mu.Lock() + cc.state = Connecting + cc.stateCV.Broadcast() + t := cc.transport + ts := cc.transportSeq + // Avoid wait() picking up a dying transport unnecessarily. + cc.transportSeq = 0 + if cc.state == Shutdown { + cc.mu.Unlock() + return ErrClientConnClosing + } + cc.mu.Unlock() + if closeTransport { + t.Close() + } + // Adjust timeout for the current try. + copts := cc.dopts.copts + if copts.Timeout < 0 { + cc.Close() + return ErrClientConnTimeout + } + if copts.Timeout > 0 { + copts.Timeout -= time.Since(start) + if copts.Timeout <= 0 { + cc.Close() + return ErrClientConnTimeout + } + } + sleepTime := backoff(retries) + timeout := sleepTime + if timeout < minConnectTimeout { + timeout = minConnectTimeout + } + if copts.Timeout == 0 || copts.Timeout > timeout { + copts.Timeout = timeout + } + connectTime := time.Now() + newTransport, err := transport.NewClientTransport(cc.target, &copts) + if err != nil { + cc.mu.Lock() + cc.state = TransientFailure + cc.stateCV.Broadcast() + cc.mu.Unlock() + sleepTime -= time.Since(connectTime) + if sleepTime < 0 { + sleepTime = 0 + } + // Fail early before falling into sleep. + if cc.dopts.copts.Timeout > 0 && cc.dopts.copts.Timeout < sleepTime+time.Since(start) { + cc.Close() + return ErrClientConnTimeout + } + closeTransport = false + time.Sleep(sleepTime) + retries++ + grpclog.Printf("grpc: ClientConn.resetTransport failed to create client transport: %v; Reconnecting to %q", err, cc.target) + continue + } + cc.mu.Lock() + if cc.state == Shutdown { + // cc.Close() has been invoked. + cc.mu.Unlock() + newTransport.Close() + return ErrClientConnClosing + } + cc.state = Ready + cc.stateCV.Broadcast() + cc.transport = newTransport + cc.transportSeq = ts + 1 + if cc.ready != nil { + close(cc.ready) + cc.ready = nil + } + cc.mu.Unlock() + return nil + } +} + +// Run in a goroutine to track the error in transport and create the +// new transport if an error happens. It returns when the channel is closing. +func (cc *ClientConn) transportMonitor() { + for { + select { + // shutdownChan is needed to detect the teardown when + // the ClientConn is idle (i.e., no RPC in flight). + case <-cc.shutdownChan: + return + case <-cc.transport.Error(): + cc.mu.Lock() + cc.state = TransientFailure + cc.stateCV.Broadcast() + cc.mu.Unlock() + if err := cc.resetTransport(true); err != nil { + // The ClientConn is closing. + grpclog.Printf("grpc: ClientConn.transportMonitor exits due to: %v", err) + return + } + continue + } + } +} + +// When wait returns, either the new transport is up or ClientConn is +// closing. Used to avoid working on a dying transport. It updates and +// returns the transport and its version when there is no error. +func (cc *ClientConn) wait(ctx context.Context, ts int) (transport.ClientTransport, int, error) { + for { + cc.mu.Lock() + switch { + case cc.state == Shutdown: + cc.mu.Unlock() + return nil, 0, ErrClientConnClosing + case ts < cc.transportSeq: + // Worked on a dying transport. Try the new one immediately. + defer cc.mu.Unlock() + return cc.transport, cc.transportSeq, nil + default: + ready := cc.ready + if ready == nil { + ready = make(chan struct{}) + cc.ready = ready + } + cc.mu.Unlock() + select { + case <-ctx.Done(): + return nil, 0, transport.ContextErr(ctx.Err()) + // Wait until the new transport is ready or failed. + case <-ready: + } + } + } +} + +// Close starts to tear down the ClientConn. Returns ErrClientConnClosing if +// it has been closed (mostly due to dial time-out). +// TODO(zhaoq): Make this synchronous to avoid unbounded memory consumption in +// some edge cases (e.g., the caller opens and closes many ClientConn's in a +// tight loop. +func (cc *ClientConn) Close() error { + cc.mu.Lock() + defer cc.mu.Unlock() + if cc.state == Shutdown { + return ErrClientConnClosing + } + cc.state = Shutdown + cc.stateCV.Broadcast() + if cc.ready != nil { + close(cc.ready) + cc.ready = nil + } + if cc.transport != nil { + cc.transport.Close() + } + if cc.shutdownChan != nil { + close(cc.shutdownChan) + } + return nil +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/codegen.sh b/Godeps/_workspace/src/google.golang.org/grpc/codegen.sh new file mode 100644 index 0000000..b009488 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/codegen.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# This script serves as an example to demonstrate how to generate the gRPC-Go +# interface and the related messages from .proto file. +# +# It assumes the installation of i) Google proto buffer compiler at +# https://github.com/google/protobuf (after v2.6.1) and ii) the Go codegen +# plugin at https://github.com/golang/protobuf (after 2015-02-20). If you have +# not, please install them first. +# +# We recommend running this script at $GOPATH/src. +# +# If this is not what you need, feel free to make your own scripts. Again, this +# script is for demonstration purpose. +# +proto=$1 +protoc --go_out=plugins=grpc:. $proto diff --git a/Godeps/_workspace/src/google.golang.org/grpc/codes/code_string.go b/Godeps/_workspace/src/google.golang.org/grpc/codes/code_string.go new file mode 100644 index 0000000..e6762d0 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/codes/code_string.go @@ -0,0 +1,16 @@ +// generated by stringer -type=Code; DO NOT EDIT + +package codes + +import "fmt" + +const _Code_name = "OKCanceledUnknownInvalidArgumentDeadlineExceededNotFoundAlreadyExistsPermissionDeniedResourceExhaustedFailedPreconditionAbortedOutOfRangeUnimplementedInternalUnavailableDataLossUnauthenticated" + +var _Code_index = [...]uint8{0, 2, 10, 17, 32, 48, 56, 69, 85, 102, 120, 127, 137, 150, 158, 169, 177, 192} + +func (i Code) String() string { + if i+1 >= Code(len(_Code_index)) { + return fmt.Sprintf("Code(%d)", i) + } + return _Code_name[_Code_index[i]:_Code_index[i+1]] +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/codes/codes.go b/Godeps/_workspace/src/google.golang.org/grpc/codes/codes.go new file mode 100644 index 0000000..37c5b86 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/codes/codes.go @@ -0,0 +1,159 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package codes defines the canonical error codes used by gRPC. It is +// consistent across various languages. +package codes + +// A Code is an unsigned 32-bit error code as defined in the gRPC spec. +type Code uint32 + +//go:generate stringer -type=Code + +const ( + // OK is returned on success. + OK Code = 0 + + // Canceled indicates the operation was cancelled (typically by the caller). + Canceled Code = 1 + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + Unknown Code = 2 + + // InvalidArgument indicates client specified an invalid argument. + // Note that this differs from FailedPrecondition. It indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + InvalidArgument Code = 3 + + // DeadlineExceeded means operation expired before completion. + // For operations that change the state of the system, this error may be + // returned even if the operation has completed successfully. For + // example, a successful response from a server could have been delayed + // long enough for the deadline to expire. + DeadlineExceeded Code = 4 + + // NotFound means some requested entity (e.g., file or directory) was + // not found. + NotFound Code = 5 + + // AlreadyExists means an attempt to create an entity failed because one + // already exists. + AlreadyExists Code = 6 + + // PermissionDenied indicates the caller does not have permission to + // execute the specified operation. It must not be used for rejections + // caused by exhausting some resource (use ResourceExhausted + // instead for those errors). It must not be + // used if the caller cannot be identified (use Unauthenticated + // instead for those errors). + PermissionDenied Code = 7 + + // Unauthenticated indicates the request does not have valid + // authentication credentials for the operation. + Unauthenticated Code = 16 + + // ResourceExhausted indicates some resource has been exhausted, perhaps + // a per-user quota, or perhaps the entire file system is out of space. + ResourceExhausted Code = 8 + + // FailedPrecondition indicates operation was rejected because the + // system is not in a state required for the operation's execution. + // For example, directory to be deleted may be non-empty, an rmdir + // operation is applied to a non-directory, etc. + // + // A litmus test that may help a service implementor in deciding + // between FailedPrecondition, Aborted, and Unavailable: + // (a) Use Unavailable if the client can retry just the failing call. + // (b) Use Aborted if the client should retry at a higher-level + // (e.g., restarting a read-modify-write sequence). + // (c) Use FailedPrecondition if the client should not retry until + // the system state has been explicitly fixed. E.g., if an "rmdir" + // fails because the directory is non-empty, FailedPrecondition + // should be returned since the client should not retry unless + // they have first fixed up the directory by deleting files from it. + // (d) Use FailedPrecondition if the client performs conditional + // REST Get/Update/Delete on a resource and the resource on the + // server does not match the condition. E.g., conflicting + // read-modify-write on the same resource. + FailedPrecondition Code = 9 + + // Aborted indicates the operation was aborted, typically due to a + // concurrency issue like sequencer check failures, transaction aborts, + // etc. + // + // See litmus test above for deciding between FailedPrecondition, + // Aborted, and Unavailable. + Aborted Code = 10 + + // OutOfRange means operation was attempted past the valid range. + // E.g., seeking or reading past end of file. + // + // Unlike InvalidArgument, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate InvalidArgument if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // OutOfRange if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between FailedPrecondition and + // OutOfRange. We recommend using OutOfRange (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an OutOfRange error to detect when + // they are done. + OutOfRange Code = 11 + + // Unimplemented indicates operation is not implemented or not + // supported/enabled in this service. + Unimplemented Code = 12 + + // Internal errors. Means some invariants expected by underlying + // system has been broken. If you see one of these errors, + // something is very broken. + Internal Code = 13 + + // Unavailable indicates the service is currently unavailable. + // This is a most likely a transient condition and may be corrected + // by retrying with a backoff. + // + // See litmus test above for deciding between FailedPrecondition, + // Aborted, and Unavailable. + Unavailable Code = 14 + + // DataLoss indicates unrecoverable data loss or corruption. + DataLoss Code = 15 +) diff --git a/Godeps/_workspace/src/google.golang.org/grpc/credentials/credentials.go b/Godeps/_workspace/src/google.golang.org/grpc/credentials/credentials.go new file mode 100644 index 0000000..3cec7e4 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/credentials/credentials.go @@ -0,0 +1,239 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package credentials implements various credentials supported by gRPC library, +// which encapsulate all the state needed by a client to authenticate with a +// server and make various assertions, e.g., about the client's identity, role, +// or whether it is authorized to make a particular call. +package credentials + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "strings" + "time" + + "golang.org/x/net/context" +) + +var ( + // alpnProtoStr are the specified application level protocols for gRPC. + alpnProtoStr = []string{"h2"} +) + +// Credentials defines the common interface all supported credentials must +// implement. +type Credentials interface { + // GetRequestMetadata gets the current request metadata, refreshing + // tokens if required. This should be called by the transport layer on + // each request, and the data should be populated in headers or other + // context. uri is the URI of the entry point for the request. When + // supported by the underlying implementation, ctx can be used for + // timeout and cancellation. + // TODO(zhaoq): Define the set of the qualified keys instead of leaving + // it as an arbitrary string. + GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) + // RequireTransportSecurity indicates whether the credentails requires + // transport security. + RequireTransportSecurity() bool +} + +// ProtocolInfo provides information regarding the gRPC wire protocol version, +// security protocol, security protocol version in use, etc. +type ProtocolInfo struct { + // ProtocolVersion is the gRPC wire protocol version. + ProtocolVersion string + // SecurityProtocol is the security protocol in use. + SecurityProtocol string + // SecurityVersion is the security protocol version. + SecurityVersion string +} + +// AuthInfo defines the common interface for the auth information the users are interested in. +type AuthInfo interface { + AuthType() string +} + +type authInfoKey struct{} + +// NewContext creates a new context with authInfo attached. +func NewContext(ctx context.Context, authInfo AuthInfo) context.Context { + return context.WithValue(ctx, authInfoKey{}, authInfo) +} + +// FromContext returns the authInfo in ctx if it exists. +func FromContext(ctx context.Context) (authInfo AuthInfo, ok bool) { + authInfo, ok = ctx.Value(authInfoKey{}).(AuthInfo) + return +} + +// TransportAuthenticator defines the common interface for all the live gRPC wire +// protocols and supported transport security protocols (e.g., TLS, SSL). +type TransportAuthenticator interface { + // ClientHandshake does the authentication handshake specified by the corresponding + // authentication protocol on rawConn for clients. It returns the authenticated + // connection and the corresponding auth information about the connection. + ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, AuthInfo, error) + // ServerHandshake does the authentication handshake for servers. It returns + // the authenticated connection and the corresponding auth information about + // the connection. + ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) + // Info provides the ProtocolInfo of this TransportAuthenticator. + Info() ProtocolInfo + Credentials +} + +// TLSInfo contains the auth information for a TLS authenticated connection. +// It implements the AuthInfo interface. +type TLSInfo struct { + state tls.ConnectionState +} + +func (t TLSInfo) AuthType() string { + return "tls" +} + +// tlsCreds is the credentials required for authenticating a connection using TLS. +type tlsCreds struct { + // TLS configuration + config tls.Config +} + +func (c tlsCreds) Info() ProtocolInfo { + return ProtocolInfo{ + SecurityProtocol: "tls", + SecurityVersion: "1.2", + } +} + +// GetRequestMetadata returns nil, nil since TLS credentials does not have +// metadata. +func (c *tlsCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return nil, nil +} + +func (c *tlsCreds) RequireTransportSecurity() bool { + return true +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "credentials: Dial timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ AuthInfo, err error) { + // borrow some code from tls.DialWithDialer + var errChannel chan error + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- timeoutError{} + }) + } + if c.config.ServerName == "" { + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + c.config.ServerName = addr[:colonPos] + } + conn := tls.Client(rawConn, &c.config) + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + err = <-errChannel + } + if err != nil { + rawConn.Close() + return nil, nil, err + } + // TODO(zhaoq): Omit the auth info for client now. It is more for + // information than anything else. + return conn, nil, nil +} + +func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, AuthInfo, error) { + conn := tls.Server(rawConn, &c.config) + if err := conn.Handshake(); err != nil { + rawConn.Close() + return nil, nil, err + } + return conn, TLSInfo{conn.ConnectionState()}, nil +} + +// NewTLS uses c to construct a TransportAuthenticator based on TLS. +func NewTLS(c *tls.Config) TransportAuthenticator { + tc := &tlsCreds{*c} + tc.config.NextProtos = alpnProtoStr + return tc +} + +// NewClientTLSFromCert constructs a TLS from the input certificate for client. +func NewClientTLSFromCert(cp *x509.CertPool, serverName string) TransportAuthenticator { + return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}) +} + +// NewClientTLSFromFile constructs a TLS from the input certificate file for client. +func NewClientTLSFromFile(certFile, serverName string) (TransportAuthenticator, error) { + b, err := ioutil.ReadFile(certFile) + if err != nil { + return nil, err + } + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(b) { + return nil, fmt.Errorf("credentials: failed to append certificates") + } + return NewTLS(&tls.Config{ServerName: serverName, RootCAs: cp}), nil +} + +// NewServerTLSFromCert constructs a TLS from the input certificate for server. +func NewServerTLSFromCert(cert *tls.Certificate) TransportAuthenticator { + return NewTLS(&tls.Config{Certificates: []tls.Certificate{*cert}}) +} + +// NewServerTLSFromFile constructs a TLS from the input certificate file and key +// file for server. +func NewServerTLSFromFile(certFile, keyFile string) (TransportAuthenticator, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + return NewTLS(&tls.Config{Certificates: []tls.Certificate{cert}}), nil +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/credentials/oauth/oauth.go b/Godeps/_workspace/src/google.golang.org/grpc/credentials/oauth/oauth.go new file mode 100644 index 0000000..04943fd --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/credentials/oauth/oauth.go @@ -0,0 +1,177 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package oauth implements gRPC credentials using OAuth. +package oauth + +import ( + "fmt" + "io/ioutil" + + "golang.org/x/net/context" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "golang.org/x/oauth2/jwt" + "google.golang.org/grpc/credentials" +) + +// TokenSource supplies credentials from an oauth2.TokenSource. +type TokenSource struct { + oauth2.TokenSource +} + +// GetRequestMetadata gets the request metadata as a map from a TokenSource. +func (ts TokenSource) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + token, err := ts.Token() + if err != nil { + return nil, err + } + return map[string]string{ + "authorization": token.TokenType + " " + token.AccessToken, + }, nil +} + +func (ts TokenSource) RequireTransportSecurity() bool { + return true +} + +type jwtAccess struct { + jsonKey []byte +} + +func NewJWTAccessFromFile(keyFile string) (credentials.Credentials, error) { + jsonKey, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) + } + return NewJWTAccessFromKey(jsonKey) +} + +func NewJWTAccessFromKey(jsonKey []byte) (credentials.Credentials, error) { + return jwtAccess{jsonKey}, nil +} + +func (j jwtAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + ts, err := google.JWTAccessTokenSourceFromJSON(j.jsonKey, uri[0]) + if err != nil { + return nil, err + } + token, err := ts.Token() + if err != nil { + return nil, err + } + return map[string]string{ + "authorization": token.TokenType + " " + token.AccessToken, + }, nil +} + +func (j jwtAccess) RequireTransportSecurity() bool { + return true +} + +// oauthAccess supplies credentials from a given token. +type oauthAccess struct { + token oauth2.Token +} + +// NewOauthAccess constructs the credentials using a given token. +func NewOauthAccess(token *oauth2.Token) credentials.Credentials { + return oauthAccess{token: *token} +} + +func (oa oauthAccess) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + return map[string]string{ + "authorization": oa.token.TokenType + " " + oa.token.AccessToken, + }, nil +} + +func (oa oauthAccess) RequireTransportSecurity() bool { + return true +} + +// NewComputeEngine constructs the credentials that fetches access tokens from +// Google Compute Engine (GCE)'s metadata server. It is only valid to use this +// if your program is running on a GCE instance. +// TODO(dsymonds): Deprecate and remove this. +func NewComputeEngine() credentials.Credentials { + return TokenSource{google.ComputeTokenSource("")} +} + +// serviceAccount represents credentials via JWT signing key. +type serviceAccount struct { + config *jwt.Config +} + +func (s serviceAccount) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { + token, err := s.config.TokenSource(ctx).Token() + if err != nil { + return nil, err + } + return map[string]string{ + "authorization": token.TokenType + " " + token.AccessToken, + }, nil +} + +func (s serviceAccount) RequireTransportSecurity() bool { + return true +} + +// NewServiceAccountFromKey constructs the credentials using the JSON key slice +// from a Google Developers service account. +func NewServiceAccountFromKey(jsonKey []byte, scope ...string) (credentials.Credentials, error) { + config, err := google.JWTConfigFromJSON(jsonKey, scope...) + if err != nil { + return nil, err + } + return serviceAccount{config: config}, nil +} + +// NewServiceAccountFromFile constructs the credentials using the JSON key file +// of a Google Developers service account. +func NewServiceAccountFromFile(keyFile string, scope ...string) (credentials.Credentials, error) { + jsonKey, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("credentials: failed to read the service account key file: %v", err) + } + return NewServiceAccountFromKey(jsonKey, scope...) +} + +// NewApplicationDefault returns "Application Default Credentials". For more +// detail, see https://developers.google.com/accounts/docs/application-default-credentials. +func NewApplicationDefault(ctx context.Context, scope ...string) (credentials.Credentials, error) { + t, err := google.DefaultTokenSource(ctx, scope...) + if err != nil { + return nil, err + } + return TokenSource{t}, nil +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/doc.go b/Godeps/_workspace/src/google.golang.org/grpc/doc.go new file mode 100644 index 0000000..c0f54f7 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/doc.go @@ -0,0 +1,6 @@ +/* +Package grpc implements an RPC system called gRPC. + +See https://github.com/grpc/grpc for more information about gRPC. +*/ +package grpc diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/README.md b/Godeps/_workspace/src/google.golang.org/grpc/examples/README.md new file mode 100644 index 0000000..e5c03c3 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/README.md @@ -0,0 +1,53 @@ +gRPC in 3 minutes (Go) +====================== + +BACKGROUND +------------- +For this sample, we've already generated the server and client stubs from [helloworld.proto](examples/helloworld/proto/helloworld.proto). + +PREREQUISITES +------------- + +- This requires Go 1.4 +- Requires that [GOPATH is set](https://golang.org/doc/code.html#GOPATH) +```sh +$ go help gopath +$ # ensure the PATH contains $GOPATH/bin +$ export PATH=$PATH:$GOPATH/bin +``` + +INSTALL +------- + +```sh +$ go get -u github.com/grpc/grpc-go/examples/greeter_client +$ go get -u github.com/grpc/grpc-go/examples/greeter_server +``` + +TRY IT! +------- + +- Run the server +```sh +$ greeter_server & +``` + +- Run the client +```sh +$ greeter_client +``` + +OPTIONAL - Rebuilding the generated code +---------------------------------------- + +1 First [install protoc](https://github.com/google/protobuf/blob/master/INSTALL.txt) + - For now, this needs to be installed from source + - This is will change once proto3 is officially released + +2 Install the protoc Go plugin. +```sh +$ go get -a github.com/golang/protobuf/protoc-gen-go +$ +$ # from this dir; invoke protoc +$ protoc -I ./helloworld/proto/ ./helloworld/proto/helloworld.proto --go_out=plugins=grpc:helloworld +``` diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/gotutorial.md b/Godeps/_workspace/src/google.golang.org/grpc/examples/gotutorial.md new file mode 100644 index 0000000..8df15a4 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/gotutorial.md @@ -0,0 +1,431 @@ +#gRPC Basics: Go + +This tutorial provides a basic Go programmer's introduction to working with gRPC. By walking through this example you'll learn how to: + +- Define a service in a .proto file. +- Generate server and client code using the protocol buffer compiler. +- Use the Go gRPC API to write a simple client and server for your service. + +It assumes that you have read the [Getting started](https://github.com/grpc/grpc/tree/master/examples) guide and are familiar with [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). Note that the example in this tutorial uses the proto3 version of the protocol buffers language, which is currently in alpha release:you can find out more in the [proto3 language guide](https://developers.google.com/protocol-buffers/docs/proto3) and see the [release notes](https://github.com/google/protobuf/releases) for the new version in the protocol buffers Github repository. + +This isn't a comprehensive guide to using gRPC in Go: more reference documentation is coming soon. + +## Why use gRPC? + +Our example is a simple route mapping application that lets clients get information about features on their route, create a summary of their route, and exchange route information such as traffic updates with the server and other clients. + +With gRPC we can define our service once in a .proto file and implement clients and servers in any of gRPC's supported languages, which in turn can be run in environments ranging from servers inside Google to your own tablet - all the complexity of communication between different languages and environments is handled for you by gRPC. We also get all the advantages of working with protocol buffers, including efficient serialization, a simple IDL, and easy interface updating. + +## Example code and setup + +The example code for our tutorial is in [grpc/grpc-go/examples/route_guide](https://github.com/grpc/grpc-go/tree/master/examples/route_guide). To download the example, clone the `grpc-go` repository by running the following command: +```shell +$ go get google.golang.org/grpc +``` + +Then change your current directory to `grpc-go/examples/route_guide`: +```shell +$ cd $GOPATH/src/google.golang.org/grpc/examples/route_guide +``` + +You also should have the relevant tools installed to generate the server and client interface code - if you don't already, follow the setup instructions in [the Go quick start guide](examples/). + + +## Defining the service + +Our first step (as you'll know from [Getting started](https://github.com/grpc/grpc/tree/master/examples)) is to define the gRPC *service* and the method *request* and *response* types using [protocol buffers] (https://developers.google.com/protocol-buffers/docs/overview). You can see the complete .proto file in [`examples/route_guide/proto/route_guide.proto`](examples/route_guide/proto/route_guide.proto). + +To define a service, you specify a named `service` in your .proto file: + +```proto +service RouteGuide { + ... +} +``` + +Then you define `rpc` methods inside your service definition, specifying their request and response types. gRPC lets you define four kinds of service method, all of which are used in the `RouteGuide` service: + +- A *simple RPC* where the client sends a request to the server using the stub and waits for a response to come back, just like a normal function call. +```proto + // Obtains the feature at a given position. + rpc GetFeature(Point) returns (Feature) {} +``` + +- A *server-side streaming RPC* where the client sends a request to the server and gets a stream to read a sequence of messages back. The client reads from the returned stream until there are no more messages. As you can see in our example, you specify a server-side streaming method by placing the `stream` keyword before the *response* type. +```proto + // Obtains the Features available within the given Rectangle. Results are + // streamed rather than returned at once (e.g. in a response message with a + // repeated field), as the rectangle may cover a large area and contain a + // huge number of features. + rpc ListFeatures(Rectangle) returns (stream Feature) {} +``` + +- A *client-side streaming RPC* where the client writes a sequence of messages and sends them to the server, again using a provided stream. Once the client has finished writing the messages, it waits for the server to read them all and return its response. You specify a client-side streaming method by placing the `stream` keyword before the *request* type. +```proto + // Accepts a stream of Points on a route being traversed, returning a + // RouteSummary when traversal is completed. + rpc RecordRoute(stream Point) returns (RouteSummary) {} +``` + +- A *bidirectional streaming RPC* where both sides send a sequence of messages using a read-write stream. The two streams operate independently, so clients and servers can read and write in whatever order they like: for example, the server could wait to receive all the client messages before writing its responses, or it could alternately read a message then write a message, or some other combination of reads and writes. The order of messages in each stream is preserved. You specify this type of method by placing the `stream` keyword before both the request and the response. +```proto + // Accepts a stream of RouteNotes sent while a route is being traversed, + // while receiving other RouteNotes (e.g. from other users). + rpc RouteChat(stream RouteNote) returns (stream RouteNote) {} +``` + +Our .proto file also contains protocol buffer message type definitions for all the request and response types used in our service methods - for example, here's the `Point` message type: +```proto +// Points are represented as latitude-longitude pairs in the E7 representation +// (degrees multiplied by 10**7 and rounded to the nearest integer). +// Latitudes should be in the range +/- 90 degrees and longitude should be in +// the range +/- 180 degrees (inclusive). +message Point { + int32 latitude = 1; + int32 longitude = 2; +} +``` + + +## Generating client and server code + +Next we need to generate the gRPC client and server interfaces from our .proto service definition. We do this using the protocol buffer compiler `protoc` with a special gRPC Go plugin. + +For simplicity, we've provided a [bash script](https://github.com/grpc/grpc-go/blob/master/codegen.sh) that runs `protoc` for you with the appropriate plugin, input, and output (if you want to run this by yourself, make sure you've installed protoc and followed the gRPC-Go [installation instructions](https://github.com/grpc/grpc-go/blob/master/README.md) first): + +```shell +$ codegen.sh route_guide.proto +``` + +which actually runs: + +```shell +$ protoc --go_out=plugins=grpc:. route_guide.proto +``` + +Running this command generates the following file in your current directory: +- `route_guide.pb.go` + +This contains: +- All the protocol buffer code to populate, serialize, and retrieve our request and response message types +- An interface type (or *stub*) for clients to call with the methods defined in the `RouteGuide` service. +- An interface type for servers to implement, also with the methods defined in the `RouteGuide` service. + + + +## Creating the server + +First let's look at how we create a `RouteGuide` server. If you're only interested in creating gRPC clients, you can skip this section and go straight to [Creating the client](#client) (though you might find it interesting anyway!). + +There are two parts to making our `RouteGuide` service do its job: +- Implementing the service interface generated from our service definition: doing the actual "work" of our service. +- Running a gRPC server to listen for requests from clients and dispatch them to the right service implementation. + +You can find our example `RouteGuide` server in [grpc-go/examples/route_guide/server/server.go](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/server/server.go). Let's take a closer look at how it works. + +### Implementing RouteGuide + +As you can see, our server has a `routeGuideServer` struct type that implements the generated `RouteGuideServer` interface: + +```go +type routeGuideServer struct { + ... +} +... + +func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb.Feature, error) { + ... +} +... + +func (s *routeGuideServer) ListFeatures(rect *pb.Rectangle, stream pb.RouteGuide_ListFeaturesServer) error { + ... +} +... + +func (s *routeGuideServer) RecordRoute(stream pb.RouteGuide_RecordRouteServer) error { + ... +} +... + +func (s *routeGuideServer) RouteChat(stream pb.RouteGuide_RouteChatServer) error { + ... +} +... +``` + +#### Simple RPC +`routeGuideServer` implements all our service methods. Let's look at the simplest type first, `GetFeature`, which just gets a `Point` from the client and returns the corresponding feature information from its database in a `Feature`. + +```go +func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb.Feature, error) { + for _, feature := range s.savedFeatures { + if proto.Equal(feature.Location, point) { + return feature, nil + } + } + // No feature was found, return an unnamed feature + return &pb.Feature{"", point}, nil +} +``` + +The method is passed a context object for the RPC and the client's `Point` protocol buffer request. It returns a `Feature` protocol buffer object with the response information and an `error`. In the method we populate the `Feature` with the appropriate information, and then `return` it along with an `nil` error to tell gRPC that we've finished dealing with the RPC and that the `Feature` can be returned to the client. + +#### Server-side streaming RPC +Now let's look at one of our streaming RPCs. `ListFeatures` is a server-side streaming RPC, so we need to send back multiple `Feature`s to our client. + +```go +func (s *routeGuideServer) ListFeatures(rect *pb.Rectangle, stream pb.RouteGuide_ListFeaturesServer) error { + for _, feature := range s.savedFeatures { + if inRange(feature.Location, rect) { + if err := stream.Send(feature); err != nil { + return err + } + } + } + return nil +} +``` + +As you can see, instead of getting simple request and response objects in our method parameters, this time we get a request object (the `Rectangle` in which our client wants to find `Feature`s) and a special `RouteGuide_ListFeaturesServer` object to write our responses. + +In the method, we populate as many `Feature` objects as we need to return, writing them to the `RouteGuide_ListFeaturesServer` using its `Send()` method. Finally, as in our simple RPC, we return a `nil` error to tell gRPC that we've finished writing responses. Should any error happen in this call, we return a non-`nil` error; the gRPC layer will translate it into an appropriate RPC status to be sent on the wire. + +#### Client-side streaming RPC +Now let's look at something a little more complicated: the client-side streaming method `RecordRoute`, where we get a stream of `Point`s from the client and return a single `RouteSummary` with information about their trip. As you can see, this time the method doesn't have a request parameter at all. Instead, it gets a `RouteGuide_RecordRouteServer` stream, which the server can use to both read *and* write messages - it can receive client messages using its `Recv()` method and return its single response using its `SendAndClose()` method. + +```go +func (s *routeGuideServer) RecordRoute(stream pb.RouteGuide_RecordRouteServer) error { + var pointCount, featureCount, distance int32 + var lastPoint *pb.Point + startTime := time.Now() + for { + point, err := stream.Recv() + if err == io.EOF { + endTime := time.Now() + return stream.SendAndClose(&pb.RouteSummary{ + PointCount: pointCount, + FeatureCount: featureCount, + Distance: distance, + ElapsedTime: int32(endTime.Sub(startTime).Seconds()), + }) + } + if err != nil { + return err + } + pointCount++ + for _, feature := range s.savedFeatures { + if proto.Equal(feature.Location, point) { + featureCount++ + } + } + if lastPoint != nil { + distance += calcDistance(lastPoint, point) + } + lastPoint = point + } +} +``` + +In the method body we use the `RouteGuide_RecordRouteServer`s `Recv()` method to repeatedly read in our client's requests to a request object (in this case a `Point`) until there are no more messages: the server needs to check the the error returned from `Read()` after each call. If this is `nil`, the stream is still good and it can continue reading; if it's `io.EOF` the message stream has ended and the server can return its `RouteSummary`. If it has any other value, we return the error "as is" so that it'll be translated to an RPC status by the gRPC layer. + +#### Bidirectional streaming RPC +Finally, let's look at our bidirectional streaming RPC `RouteChat()`. + +```go +func (s *routeGuideServer) RouteChat(stream pb.RouteGuide_RouteChatServer) error { + for { + in, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + key := serialize(in.Location) + ... // look for notes to be sent to client + for _, note := range s.routeNotes[key] { + if err := stream.Send(note); err != nil { + return err + } + } + } +} +``` + +This time we get a `RouteGuide_RouteChatServer` stream that, as in our client-side streaming example, can be used to read and write messages. However, this time we return values via our method's stream while the client is still writing messages to *their* message stream. + +The syntax for reading and writing here is very similar to our client-streaming method, except the server uses the stream's `Send()` method rather than `SendAndClose()` because it's writing multiple responses. Although each side will always get the other's messages in the order they were written, both the client and server can read and write in any order — the streams operate completely independently. + +### Starting the server + +Once we've implemented all our methods, we also need to start up a gRPC server so that clients can actually use our service. The following snippet shows how we do this for our `RouteGuide` service: + +```go +flag.Parse() +lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) +if err != nil { + log.Fatalf("failed to listen: %v", err) +} +grpcServer := grpc.NewServer() +pb.RegisterRouteGuideServer(grpcServer, &routeGuideServer{}) +... // determine whether to use TLS +grpcServer.Serve(lis) +``` +To build and start a server, we: + +1. Specify the port we want to use to listen for client requests using `lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))`. +2. Create an instance of the gRPC server using `grpc.NewServer()`. +3. Register our service implementation with the gRPC server. +4. Call `Serve()` on the server with our port details to do a blocking wait until the process is killed or `Stop()` is called. + + +## Creating the client + +In this section, we'll look at creating a Go client for our `RouteGuide` service. You can see our complete example client code in [grpc-go/examples/route_guide/client/client.go](https://github.com/grpc/grpc-go/tree/master/examples/route_guide/client/client.go). + +### Creating a stub + +To call service methods, we first need to create a gRPC *channel* to communicate with the server. We create this by passing the server address and port number to `grpc.Dial()` as follows: + +```go +conn, err := grpc.Dial(*serverAddr) +if err != nil { + ... +} +defer conn.Close() +``` + +You can use `DialOptions` to set the auth credentials (e.g., TLS, GCE credentials, JWT credentials) in `grpc.Dial` if the service you request requires that - however, we don't need to do this for our `RouteGuide` service. + +Once the gRPC *channel* is setup, we need a client *stub* to perform RPCs. We get this using the `NewRouteGuideClient` method provided in the `pb` package we generated from our .proto. + +```go +client := pb.NewRouteGuideClient(conn) +``` + +### Calling service methods + +Now let's look at how we call our service methods. Note that in gRPC-Go, RPCs operate in a blocking/synchronous mode, which means that the RPC call waits for the server to respond, and will either return a response or an error. + +#### Simple RPC + +Calling the simple RPC `GetFeature` is nearly as straightforward as calling a local method. + +```go +feature, err := client.GetFeature(context.Background(), &pb.Point{409146138, -746188906}) +if err != nil { + ... +} +``` + +As you can see, we call the method on the stub we got earlier. In our method parameters we create and populate a request protocol buffer object (in our case `Point`). We also pass a `context.Context` object which lets us change our RPC's behaviour if necessary, such as time-out/cancel an RPC in flight. If the call doesn't return an error, then we can read the response information from the server from the first return value. + +```go +log.Println(feature) +``` + +#### Server-side streaming RPC + +Here's where we call the server-side streaming method `ListFeatures`, which returns a stream of geographical `Feature`s. If you've already read [Creating the server](#server) some of this may look very familiar - streaming RPCs are implemented in a similar way on both sides. + +```go +rect := &pb.Rectangle{ ... } // initialize a pb.Rectangle +stream, err := client.ListFeatures(context.Background(), rect) +if err != nil { + ... +} +for { + feature, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Fatalf("%v.ListFeatures(_) = _, %v", client, err) + } + log.Println(feature) +} +``` + +As in the simple RPC, we pass the method a context and a request. However, instead of getting a response object back, we get back an instance of `RouteGuide_ListFeaturesClient`. The client can use the `RouteGuide_ListFeaturesClient` stream to read the server's responses. + +We use the `RouteGuide_ListFeaturesClient`'s `Recv()` method to repeatedly read in the server's responses to a response protocol buffer object (in this case a `Feature`) until there are no more messages: the client needs to check the error `err` returned from `Recv()` after each call. If `nil`, the stream is still good and it can continue reading; if it's `io.EOF` then the message stream has ended; otherwise there must be an RPC error, which is passed over through `err`. + +#### Client-side streaming RPC + +The client-side streaming method `RecordRoute` is similar to the server-side method, except that we only pass the method a context and get a `RouteGuide_RecordRouteClient` stream back, which we can use to both write *and* read messages. + +```go +// Create a random number of random points +r := rand.New(rand.NewSource(time.Now().UnixNano())) +pointCount := int(r.Int31n(100)) + 2 // Traverse at least two points +var points []*pb.Point +for i := 0; i < pointCount; i++ { + points = append(points, randomPoint(r)) +} +log.Printf("Traversing %d points.", len(points)) +stream, err := client.RecordRoute(context.Background()) +if err != nil { + log.Fatalf("%v.RecordRoute(_) = _, %v", client, err) +} +for _, point := range points { + if err := stream.Send(point); err != nil { + log.Fatalf("%v.Send(%v) = %v", stream, point, err) + } +} +reply, err := stream.CloseAndRecv() +if err != nil { + log.Fatalf("%v.CloseAndRecv() got error %v, want %v", stream, err, nil) +} +log.Printf("Route summary: %v", reply) +``` + +The `RouteGuide_RecordRouteClient` has a `Send()` method that we can use to send requests to the server. Once we've finished writing our client's requests to the stream using `Send()`, we need to call `CloseAndRecv()` on the stream to let gRPC know that we've finished writing and are expecting to receive a response. We get our RPC status from the `err` returned from `CloseAndRecv()`. If the status is `nil`, then the first return value from `CloseAndRecv()` will be a valid server response. + +#### Bidirectional streaming RPC + +Finally, let's look at our bidirectional streaming RPC `RouteChat()`. As in the case of `RecordRoute`, we only pass the method a context object and get back a stream that we can use to both write and read messages. However, this time we return values via our method's stream while the server is still writing messages to *their* message stream. + +```go +stream, err := client.RouteChat(context.Background()) +waitc := make(chan struct{}) +go func() { + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + close(waitc) + return + } + if err != nil { + log.Fatalf("Failed to receive a note : %v", err) + } + log.Printf("Got message %s at point(%d, %d)", in.Message, in.Location.Latitude, in.Location.Longitude) + } +}() +for _, note := range notes { + if err := stream.Send(note); err != nil { + log.Fatalf("Failed to send a note: %v", err) + } +} +stream.CloseSend() +<-waitc +``` + +The syntax for reading and writing here is very similar to our client-side streaming method, except we use the stream's `CloseSend()` method once we've finished our call. Although each side will always get the other's messages in the order they were written, both the client and server can read and write in any order — the streams operate completely independently. + +## Try it out! + +To compile and run the server, assuming you are in the folder +`$GOPATH/src/google.golang.org/grpc/examples/route_guide`, simply: + +```sh +$ go run server/server.go +``` + +Likewise, to run the client: + +```sh +$ go run client/client.go +``` + diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/greeter_client/main.go b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/greeter_client/main.go new file mode 100644 index 0000000..1e02ab1 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/greeter_client/main.go @@ -0,0 +1,69 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package main + +import ( + "log" + "os" + + pb "google.golang.org/grpc/examples/helloworld/helloworld" + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +const ( + address = "localhost:50051" + defaultName = "world" +) + +func main() { + // Set up a connection to the server. + conn, err := grpc.Dial(address, grpc.WithInsecure()) + if err != nil { + log.Fatalf("did not connect: %v", err) + } + defer conn.Close() + c := pb.NewGreeterClient(conn) + + // Contact the server and print out its response. + name := defaultName + if len(os.Args) > 1 { + name = os.Args[1] + } + r, err := c.SayHello(context.Background(), &pb.HelloRequest{Name: name}) + if err != nil { + log.Fatalf("could not greet: %v", err) + } + log.Printf("Greeting: %s", r.Message) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/greeter_server/main.go b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/greeter_server/main.go new file mode 100644 index 0000000..ba985df --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/greeter_server/main.go @@ -0,0 +1,65 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package main + +import ( + "log" + "net" + + pb "google.golang.org/grpc/examples/helloworld/helloworld" + "golang.org/x/net/context" + "google.golang.org/grpc" +) + +const ( + port = ":50051" +) + +// server is used to implement hellowrld.GreeterServer. +type server struct{} + +// SayHello implements helloworld.GreeterServer +func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { + return &pb.HelloReply{Message: "Hello " + in.Name}, nil +} + +func main() { + lis, err := net.Listen("tcp", port) + if err != nil { + log.Fatalf("failed to listen: %v", err) + } + s := grpc.NewServer() + pb.RegisterGreeterServer(s, &server{}) + s.Serve(lis) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/helloworld/helloworld.pb.go b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/helloworld/helloworld.pb.go new file mode 100644 index 0000000..1ff931a --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/helloworld/helloworld.pb.go @@ -0,0 +1,109 @@ +// Code generated by protoc-gen-go. +// source: helloworld.proto +// DO NOT EDIT! + +/* +Package helloworld is a generated protocol buffer package. + +It is generated from these files: + helloworld.proto + +It has these top-level messages: + HelloRequest + HelloReply +*/ +package helloworld + +import proto "github.com/golang/protobuf/proto" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal + +// The request message containing the user's name. +type HelloRequest struct { + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` +} + +func (m *HelloRequest) Reset() { *m = HelloRequest{} } +func (m *HelloRequest) String() string { return proto.CompactTextString(m) } +func (*HelloRequest) ProtoMessage() {} + +// The response message containing the greetings +type HelloReply struct { + Message string `protobuf:"bytes,1,opt,name=message" json:"message,omitempty"` +} + +func (m *HelloReply) Reset() { *m = HelloReply{} } +func (m *HelloReply) String() string { return proto.CompactTextString(m) } +func (*HelloReply) ProtoMessage() {} + +func init() { +} + +// Client API for Greeter service + +type GreeterClient interface { + // Sends a greeting + SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) +} + +type greeterClient struct { + cc *grpc.ClientConn +} + +func NewGreeterClient(cc *grpc.ClientConn) GreeterClient { + return &greeterClient{cc} +} + +func (c *greeterClient) SayHello(ctx context.Context, in *HelloRequest, opts ...grpc.CallOption) (*HelloReply, error) { + out := new(HelloReply) + err := grpc.Invoke(ctx, "/helloworld.Greeter/SayHello", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for Greeter service + +type GreeterServer interface { + // Sends a greeting + SayHello(context.Context, *HelloRequest) (*HelloReply, error) +} + +func RegisterGreeterServer(s *grpc.Server, srv GreeterServer) { + s.RegisterService(&_Greeter_serviceDesc, srv) +} + +func _Greeter_SayHello_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(HelloRequest) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(GreeterServer).SayHello(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +var _Greeter_serviceDesc = grpc.ServiceDesc{ + ServiceName: "helloworld.Greeter", + HandlerType: (*GreeterServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "SayHello", + Handler: _Greeter_SayHello_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/helloworld/helloworld.proto b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/helloworld/helloworld.proto new file mode 100644 index 0000000..7d58870 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/helloworld/helloworld/helloworld.proto @@ -0,0 +1,51 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto3"; + +option java_package = "io.grpc.examples"; +option objc_class_prefix = "HLW"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/README.md b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/README.md new file mode 100644 index 0000000..7571621 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/README.md @@ -0,0 +1,35 @@ +# Description +The route guide server and client demonstrate how to use grpc go libraries to +perform unary, client streaming, server streaming and full duplex RPCs. + +Please refer to [Getting Started Guide for Go] (examples/gotutorial.md) for more information. + +See the definition of the route guide service in proto/route_guide.proto. + +# Run the sample code +To compile and run the server, assuming you are in the root of the route_guide +folder, i.e., .../examples/route_guide/, simply: + +```sh +$ go run server/server.go +``` + +Likewise, to run the client: + +```sh +$ go run client/client.go +``` + +# Optional command line flags +The server and client both take optional command line flags. For example, the +client and server run without TLS by default. To enable TLS: + +```sh +$ go run server/server.go -tls=true +``` + +and + +```sh +$ go run client/client.go -tls=true +``` diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/client/client.go b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/client/client.go new file mode 100644 index 0000000..a96c030 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/client/client.go @@ -0,0 +1,202 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package main implements a simple gRPC client that demonstrates how to use gRPC-Go libraries +// to perform unary, client streaming, server streaming and full duplex RPCs. +// +// It interacts with the route guide service whose definition can be found in proto/route_guide.proto. +package main + +import ( + "flag" + "io" + "math/rand" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + pb "google.golang.org/grpc/examples/route_guide/routeguide" + "google.golang.org/grpc/grpclog" +) + +var ( + tls = flag.Bool("tls", false, "Connection uses TLS if true, else plain TCP") + caFile = flag.String("ca_file", "testdata/ca.pem", "The file containning the CA root cert file") + serverAddr = flag.String("server_addr", "127.0.0.1:10000", "The server address in the format of host:port") + serverHostOverride = flag.String("server_host_override", "x.test.youtube.com", "The server name use to verify the hostname returned by TLS handshake") +) + +// printFeature gets the feature for the given point. +func printFeature(client pb.RouteGuideClient, point *pb.Point) { + grpclog.Printf("Getting feature for point (%d, %d)", point.Latitude, point.Longitude) + feature, err := client.GetFeature(context.Background(), point) + if err != nil { + grpclog.Fatalf("%v.GetFeatures(_) = _, %v: ", client, err) + } + grpclog.Println(feature) +} + +// printFeatures lists all the features within the given bounding Rectangle. +func printFeatures(client pb.RouteGuideClient, rect *pb.Rectangle) { + grpclog.Printf("Looking for features within %v", rect) + stream, err := client.ListFeatures(context.Background(), rect) + if err != nil { + grpclog.Fatalf("%v.ListFeatures(_) = _, %v", client, err) + } + for { + feature, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + grpclog.Fatalf("%v.ListFeatures(_) = _, %v", client, err) + } + grpclog.Println(feature) + } +} + +// runRecordRoute sends a sequence of points to server and expects to get a RouteSummary from server. +func runRecordRoute(client pb.RouteGuideClient) { + // Create a random number of random points + r := rand.New(rand.NewSource(time.Now().UnixNano())) + pointCount := int(r.Int31n(100)) + 2 // Traverse at least two points + var points []*pb.Point + for i := 0; i < pointCount; i++ { + points = append(points, randomPoint(r)) + } + grpclog.Printf("Traversing %d points.", len(points)) + stream, err := client.RecordRoute(context.Background()) + if err != nil { + grpclog.Fatalf("%v.RecordRoute(_) = _, %v", client, err) + } + for _, point := range points { + if err := stream.Send(point); err != nil { + grpclog.Fatalf("%v.Send(%v) = %v", stream, point, err) + } + } + reply, err := stream.CloseAndRecv() + if err != nil { + grpclog.Fatalf("%v.CloseAndRecv() got error %v, want %v", stream, err, nil) + } + grpclog.Printf("Route summary: %v", reply) +} + +// runRouteChat receives a sequence of route notes, while sending notes for various locations. +func runRouteChat(client pb.RouteGuideClient) { + notes := []*pb.RouteNote{ + {&pb.Point{0, 1}, "First message"}, + {&pb.Point{0, 2}, "Second message"}, + {&pb.Point{0, 3}, "Third message"}, + {&pb.Point{0, 1}, "Fourth message"}, + {&pb.Point{0, 2}, "Fifth message"}, + {&pb.Point{0, 3}, "Sixth message"}, + } + stream, err := client.RouteChat(context.Background()) + if err != nil { + grpclog.Fatalf("%v.RouteChat(_) = _, %v", client, err) + } + waitc := make(chan struct{}) + go func() { + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + close(waitc) + return + } + if err != nil { + grpclog.Fatalf("Failed to receive a note : %v", err) + } + grpclog.Printf("Got message %s at point(%d, %d)", in.Message, in.Location.Latitude, in.Location.Longitude) + } + }() + for _, note := range notes { + if err := stream.Send(note); err != nil { + grpclog.Fatalf("Failed to send a note: %v", err) + } + } + stream.CloseSend() + <-waitc +} + +func randomPoint(r *rand.Rand) *pb.Point { + lat := (r.Int31n(180) - 90) * 1e7 + long := (r.Int31n(360) - 180) * 1e7 + return &pb.Point{lat, long} +} + +func main() { + flag.Parse() + var opts []grpc.DialOption + if *tls { + var sn string + if *serverHostOverride != "" { + sn = *serverHostOverride + } + var creds credentials.TransportAuthenticator + if *caFile != "" { + var err error + creds, err = credentials.NewClientTLSFromFile(*caFile, sn) + if err != nil { + grpclog.Fatalf("Failed to create TLS credentials %v", err) + } + } else { + creds = credentials.NewClientTLSFromCert(nil, sn) + } + opts = append(opts, grpc.WithTransportCredentials(creds)) + } else { + opts = append(opts, grpc.WithInsecure()) + } + conn, err := grpc.Dial(*serverAddr, opts...) + if err != nil { + grpclog.Fatalf("fail to dial: %v", err) + } + defer conn.Close() + client := pb.NewRouteGuideClient(conn) + + // Looking for a valid feature + printFeature(client, &pb.Point{409146138, -746188906}) + + // Feature missing. + printFeature(client, &pb.Point{0, 0}) + + // Looking for features between 40, -75 and 42, -73. + printFeatures(client, &pb.Rectangle{&pb.Point{400000000, -750000000}, &pb.Point{420000000, -730000000}}) + + // RecordRoute + runRecordRoute(client) + + // RouteChat + runRouteChat(client) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/routeguide/route_guide.pb.go b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/routeguide/route_guide.pb.go new file mode 100644 index 0000000..fcf5c74 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/routeguide/route_guide.pb.go @@ -0,0 +1,425 @@ +// Code generated by protoc-gen-go. +// source: route_guide.proto +// DO NOT EDIT! + +/* +Package proto is a generated protocol buffer package. + +It is generated from these files: + route_guide.proto + +It has these top-level messages: + Point + Rectangle + Feature + RouteNote + RouteSummary +*/ +package routeguide + +import proto1 "github.com/golang/protobuf/proto" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto1.Marshal + +// Points are represented as latitude-longitude pairs in the E7 representation +// (degrees multiplied by 10**7 and rounded to the nearest integer). +// Latitudes should be in the range +/- 90 degrees and longitude should be in +// the range +/- 180 degrees (inclusive). +type Point struct { + Latitude int32 `protobuf:"varint,1,opt,name=latitude" json:"latitude,omitempty"` + Longitude int32 `protobuf:"varint,2,opt,name=longitude" json:"longitude,omitempty"` +} + +func (m *Point) Reset() { *m = Point{} } +func (m *Point) String() string { return proto1.CompactTextString(m) } +func (*Point) ProtoMessage() {} + +// A latitude-longitude rectangle, represented as two diagonally opposite +// points "lo" and "hi". +type Rectangle struct { + // One corner of the rectangle. + Lo *Point `protobuf:"bytes,1,opt,name=lo" json:"lo,omitempty"` + // The other corner of the rectangle. + Hi *Point `protobuf:"bytes,2,opt,name=hi" json:"hi,omitempty"` +} + +func (m *Rectangle) Reset() { *m = Rectangle{} } +func (m *Rectangle) String() string { return proto1.CompactTextString(m) } +func (*Rectangle) ProtoMessage() {} + +func (m *Rectangle) GetLo() *Point { + if m != nil { + return m.Lo + } + return nil +} + +func (m *Rectangle) GetHi() *Point { + if m != nil { + return m.Hi + } + return nil +} + +// A feature names something at a given point. +// +// If a feature could not be named, the name is empty. +type Feature struct { + // The name of the feature. + Name string `protobuf:"bytes,1,opt,name=name" json:"name,omitempty"` + // The point where the feature is detected. + Location *Point `protobuf:"bytes,2,opt,name=location" json:"location,omitempty"` +} + +func (m *Feature) Reset() { *m = Feature{} } +func (m *Feature) String() string { return proto1.CompactTextString(m) } +func (*Feature) ProtoMessage() {} + +func (m *Feature) GetLocation() *Point { + if m != nil { + return m.Location + } + return nil +} + +// A RouteNote is a message sent while at a given point. +type RouteNote struct { + // The location from which the message is sent. + Location *Point `protobuf:"bytes,1,opt,name=location" json:"location,omitempty"` + // The message to be sent. + Message string `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` +} + +func (m *RouteNote) Reset() { *m = RouteNote{} } +func (m *RouteNote) String() string { return proto1.CompactTextString(m) } +func (*RouteNote) ProtoMessage() {} + +func (m *RouteNote) GetLocation() *Point { + if m != nil { + return m.Location + } + return nil +} + +// A RouteSummary is received in response to a RecordRoute rpc. +// +// It contains the number of individual points received, the number of +// detected features, and the total distance covered as the cumulative sum of +// the distance between each point. +type RouteSummary struct { + // The number of points received. + PointCount int32 `protobuf:"varint,1,opt,name=point_count" json:"point_count,omitempty"` + // The number of known features passed while traversing the route. + FeatureCount int32 `protobuf:"varint,2,opt,name=feature_count" json:"feature_count,omitempty"` + // The distance covered in metres. + Distance int32 `protobuf:"varint,3,opt,name=distance" json:"distance,omitempty"` + // The duration of the traversal in seconds. + ElapsedTime int32 `protobuf:"varint,4,opt,name=elapsed_time" json:"elapsed_time,omitempty"` +} + +func (m *RouteSummary) Reset() { *m = RouteSummary{} } +func (m *RouteSummary) String() string { return proto1.CompactTextString(m) } +func (*RouteSummary) ProtoMessage() {} + +func init() { +} + +// Client API for RouteGuide service + +type RouteGuideClient interface { + // A simple RPC. + // + // Obtains the feature at a given position. + // + // If no feature is found for the given point, a feature with an empty name + // should be returned. + GetFeature(ctx context.Context, in *Point, opts ...grpc.CallOption) (*Feature, error) + // A server-to-client streaming RPC. + // + // Obtains the Features available within the given Rectangle. Results are + // streamed rather than returned at once (e.g. in a response message with a + // repeated field), as the rectangle may cover a large area and contain a + // huge number of features. + ListFeatures(ctx context.Context, in *Rectangle, opts ...grpc.CallOption) (RouteGuide_ListFeaturesClient, error) + // A client-to-server streaming RPC. + // + // Accepts a stream of Points on a route being traversed, returning a + // RouteSummary when traversal is completed. + RecordRoute(ctx context.Context, opts ...grpc.CallOption) (RouteGuide_RecordRouteClient, error) + // A Bidirectional streaming RPC. + // + // Accepts a stream of RouteNotes sent while a route is being traversed, + // while receiving other RouteNotes (e.g. from other users). + RouteChat(ctx context.Context, opts ...grpc.CallOption) (RouteGuide_RouteChatClient, error) +} + +type routeGuideClient struct { + cc *grpc.ClientConn +} + +func NewRouteGuideClient(cc *grpc.ClientConn) RouteGuideClient { + return &routeGuideClient{cc} +} + +func (c *routeGuideClient) GetFeature(ctx context.Context, in *Point, opts ...grpc.CallOption) (*Feature, error) { + out := new(Feature) + err := grpc.Invoke(ctx, "/proto.RouteGuide/GetFeature", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *routeGuideClient) ListFeatures(ctx context.Context, in *Rectangle, opts ...grpc.CallOption) (RouteGuide_ListFeaturesClient, error) { + stream, err := grpc.NewClientStream(ctx, &_RouteGuide_serviceDesc.Streams[0], c.cc, "/proto.RouteGuide/ListFeatures", opts...) + if err != nil { + return nil, err + } + x := &routeGuideListFeaturesClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type RouteGuide_ListFeaturesClient interface { + Recv() (*Feature, error) + grpc.ClientStream +} + +type routeGuideListFeaturesClient struct { + grpc.ClientStream +} + +func (x *routeGuideListFeaturesClient) Recv() (*Feature, error) { + m := new(Feature) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *routeGuideClient) RecordRoute(ctx context.Context, opts ...grpc.CallOption) (RouteGuide_RecordRouteClient, error) { + stream, err := grpc.NewClientStream(ctx, &_RouteGuide_serviceDesc.Streams[1], c.cc, "/proto.RouteGuide/RecordRoute", opts...) + if err != nil { + return nil, err + } + x := &routeGuideRecordRouteClient{stream} + return x, nil +} + +type RouteGuide_RecordRouteClient interface { + Send(*Point) error + CloseAndRecv() (*RouteSummary, error) + grpc.ClientStream +} + +type routeGuideRecordRouteClient struct { + grpc.ClientStream +} + +func (x *routeGuideRecordRouteClient) Send(m *Point) error { + return x.ClientStream.SendMsg(m) +} + +func (x *routeGuideRecordRouteClient) CloseAndRecv() (*RouteSummary, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(RouteSummary) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *routeGuideClient) RouteChat(ctx context.Context, opts ...grpc.CallOption) (RouteGuide_RouteChatClient, error) { + stream, err := grpc.NewClientStream(ctx, &_RouteGuide_serviceDesc.Streams[2], c.cc, "/proto.RouteGuide/RouteChat", opts...) + if err != nil { + return nil, err + } + x := &routeGuideRouteChatClient{stream} + return x, nil +} + +type RouteGuide_RouteChatClient interface { + Send(*RouteNote) error + Recv() (*RouteNote, error) + grpc.ClientStream +} + +type routeGuideRouteChatClient struct { + grpc.ClientStream +} + +func (x *routeGuideRouteChatClient) Send(m *RouteNote) error { + return x.ClientStream.SendMsg(m) +} + +func (x *routeGuideRouteChatClient) Recv() (*RouteNote, error) { + m := new(RouteNote) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for RouteGuide service + +type RouteGuideServer interface { + // A simple RPC. + // + // Obtains the feature at a given position. + // + // If no feature is found for the given point, a feature with an empty name + // should be returned. + GetFeature(context.Context, *Point) (*Feature, error) + // A server-to-client streaming RPC. + // + // Obtains the Features available within the given Rectangle. Results are + // streamed rather than returned at once (e.g. in a response message with a + // repeated field), as the rectangle may cover a large area and contain a + // huge number of features. + ListFeatures(*Rectangle, RouteGuide_ListFeaturesServer) error + // A client-to-server streaming RPC. + // + // Accepts a stream of Points on a route being traversed, returning a + // RouteSummary when traversal is completed. + RecordRoute(RouteGuide_RecordRouteServer) error + // A Bidirectional streaming RPC. + // + // Accepts a stream of RouteNotes sent while a route is being traversed, + // while receiving other RouteNotes (e.g. from other users). + RouteChat(RouteGuide_RouteChatServer) error +} + +func RegisterRouteGuideServer(s *grpc.Server, srv RouteGuideServer) { + s.RegisterService(&_RouteGuide_serviceDesc, srv) +} + +func _RouteGuide_GetFeature_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(Point) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(RouteGuideServer).GetFeature(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +func _RouteGuide_ListFeatures_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(Rectangle) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(RouteGuideServer).ListFeatures(m, &routeGuideListFeaturesServer{stream}) +} + +type RouteGuide_ListFeaturesServer interface { + Send(*Feature) error + grpc.ServerStream +} + +type routeGuideListFeaturesServer struct { + grpc.ServerStream +} + +func (x *routeGuideListFeaturesServer) Send(m *Feature) error { + return x.ServerStream.SendMsg(m) +} + +func _RouteGuide_RecordRoute_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(RouteGuideServer).RecordRoute(&routeGuideRecordRouteServer{stream}) +} + +type RouteGuide_RecordRouteServer interface { + SendAndClose(*RouteSummary) error + Recv() (*Point, error) + grpc.ServerStream +} + +type routeGuideRecordRouteServer struct { + grpc.ServerStream +} + +func (x *routeGuideRecordRouteServer) SendAndClose(m *RouteSummary) error { + return x.ServerStream.SendMsg(m) +} + +func (x *routeGuideRecordRouteServer) Recv() (*Point, error) { + m := new(Point) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func _RouteGuide_RouteChat_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(RouteGuideServer).RouteChat(&routeGuideRouteChatServer{stream}) +} + +type RouteGuide_RouteChatServer interface { + Send(*RouteNote) error + Recv() (*RouteNote, error) + grpc.ServerStream +} + +type routeGuideRouteChatServer struct { + grpc.ServerStream +} + +func (x *routeGuideRouteChatServer) Send(m *RouteNote) error { + return x.ServerStream.SendMsg(m) +} + +func (x *routeGuideRouteChatServer) Recv() (*RouteNote, error) { + m := new(RouteNote) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _RouteGuide_serviceDesc = grpc.ServiceDesc{ + ServiceName: "proto.RouteGuide", + HandlerType: (*RouteGuideServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetFeature", + Handler: _RouteGuide_GetFeature_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "ListFeatures", + Handler: _RouteGuide_ListFeatures_Handler, + ServerStreams: true, + }, + { + StreamName: "RecordRoute", + Handler: _RouteGuide_RecordRoute_Handler, + ClientStreams: true, + }, + { + StreamName: "RouteChat", + Handler: _RouteGuide_RouteChat_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/routeguide/route_guide.proto b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/routeguide/route_guide.proto new file mode 100644 index 0000000..bee7ac5 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/routeguide/route_guide.proto @@ -0,0 +1,121 @@ +// Copyright 2015, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto3"; + +package routeguide; + +// Interface exported by the server. +service RouteGuide { + // A simple RPC. + // + // Obtains the feature at a given position. + // + // If no feature is found for the given point, a feature with an empty name + // should be returned. + rpc GetFeature(Point) returns (Feature) {} + + // A server-to-client streaming RPC. + // + // Obtains the Features available within the given Rectangle. Results are + // streamed rather than returned at once (e.g. in a response message with a + // repeated field), as the rectangle may cover a large area and contain a + // huge number of features. + rpc ListFeatures(Rectangle) returns (stream Feature) {} + + // A client-to-server streaming RPC. + // + // Accepts a stream of Points on a route being traversed, returning a + // RouteSummary when traversal is completed. + rpc RecordRoute(stream Point) returns (RouteSummary) {} + + // A Bidirectional streaming RPC. + // + // Accepts a stream of RouteNotes sent while a route is being traversed, + // while receiving other RouteNotes (e.g. from other users). + rpc RouteChat(stream RouteNote) returns (stream RouteNote) {} +} + +// Points are represented as latitude-longitude pairs in the E7 representation +// (degrees multiplied by 10**7 and rounded to the nearest integer). +// Latitudes should be in the range +/- 90 degrees and longitude should be in +// the range +/- 180 degrees (inclusive). +message Point { + int32 latitude = 1; + int32 longitude = 2; +} + +// A latitude-longitude rectangle, represented as two diagonally opposite +// points "lo" and "hi". +message Rectangle { + // One corner of the rectangle. + Point lo = 1; + + // The other corner of the rectangle. + Point hi = 2; +} + +// A feature names something at a given point. +// +// If a feature could not be named, the name is empty. +message Feature { + // The name of the feature. + string name = 1; + + // The point where the feature is detected. + Point location = 2; +} + +// A RouteNote is a message sent while at a given point. +message RouteNote { + // The location from which the message is sent. + Point location = 1; + + // The message to be sent. + string message = 2; +} + +// A RouteSummary is received in response to a RecordRoute rpc. +// +// It contains the number of individual points received, the number of +// detected features, and the total distance covered as the cumulative sum of +// the distance between each point. +message RouteSummary { + // The number of points received. + int32 point_count = 1; + + // The number of known features passed while traversing the route. + int32 feature_count = 2; + + // The distance covered in metres. + int32 distance = 3; + + // The duration of the traversal in seconds. + int32 elapsed_time = 4; +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/server/server.go b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/server/server.go new file mode 100644 index 0000000..09b3942 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/examples/route_guide/server/server.go @@ -0,0 +1,239 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package main implements a simple gRPC server that demonstrates how to use gRPC-Go libraries +// to perform unary, client streaming, server streaming and full duplex RPCs. +// +// It implements the route guide service whose definition can be found in proto/route_guide.proto. +package main + +import ( + "encoding/json" + "flag" + "fmt" + "io" + "io/ioutil" + "math" + "net" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + + "github.com/golang/protobuf/proto" + + pb "google.golang.org/grpc/examples/route_guide/routeguide" +) + +var ( + tls = flag.Bool("tls", false, "Connection uses TLS if true, else plain TCP") + certFile = flag.String("cert_file", "testdata/server1.pem", "The TLS cert file") + keyFile = flag.String("key_file", "testdata/server1.key", "The TLS key file") + jsonDBFile = flag.String("json_db_file", "testdata/route_guide_db.json", "A json file containing a list of features") + port = flag.Int("port", 10000, "The server port") +) + +type routeGuideServer struct { + savedFeatures []*pb.Feature + routeNotes map[string][]*pb.RouteNote +} + +// GetFeature returns the feature at the given point. +func (s *routeGuideServer) GetFeature(ctx context.Context, point *pb.Point) (*pb.Feature, error) { + for _, feature := range s.savedFeatures { + if proto.Equal(feature.Location, point) { + return feature, nil + } + } + // No feature was found, return an unnamed feature + return &pb.Feature{"", point}, nil +} + +// ListFeatures lists all features comtained within the given bounding Rectangle. +func (s *routeGuideServer) ListFeatures(rect *pb.Rectangle, stream pb.RouteGuide_ListFeaturesServer) error { + for _, feature := range s.savedFeatures { + if inRange(feature.Location, rect) { + if err := stream.Send(feature); err != nil { + return err + } + } + } + return nil +} + +// RecordRoute records a route composited of a sequence of points. +// +// It gets a stream of points, and responds with statistics about the "trip": +// number of points, number of known features visited, total distance traveled, and +// total time spent. +func (s *routeGuideServer) RecordRoute(stream pb.RouteGuide_RecordRouteServer) error { + var pointCount, featureCount, distance int32 + var lastPoint *pb.Point + startTime := time.Now() + for { + point, err := stream.Recv() + if err == io.EOF { + endTime := time.Now() + return stream.SendAndClose(&pb.RouteSummary{ + PointCount: pointCount, + FeatureCount: featureCount, + Distance: distance, + ElapsedTime: int32(endTime.Sub(startTime).Seconds()), + }) + } + if err != nil { + return err + } + pointCount++ + for _, feature := range s.savedFeatures { + if proto.Equal(feature.Location, point) { + featureCount++ + } + } + if lastPoint != nil { + distance += calcDistance(lastPoint, point) + } + lastPoint = point + } +} + +// RouteChat receives a stream of message/location pairs, and responds with a stream of all +// previous messages at each of those locations. +func (s *routeGuideServer) RouteChat(stream pb.RouteGuide_RouteChatServer) error { + for { + in, err := stream.Recv() + if err == io.EOF { + return nil + } + if err != nil { + return err + } + key := serialize(in.Location) + if _, present := s.routeNotes[key]; !present { + s.routeNotes[key] = []*pb.RouteNote{in} + } else { + s.routeNotes[key] = append(s.routeNotes[key], in) + } + for _, note := range s.routeNotes[key] { + if err := stream.Send(note); err != nil { + return err + } + } + } +} + +// loadFeatures loads features from a JSON file. +func (s *routeGuideServer) loadFeatures(filePath string) { + file, err := ioutil.ReadFile(filePath) + if err != nil { + grpclog.Fatalf("Failed to load default features: %v", err) + } + if err := json.Unmarshal(file, &s.savedFeatures); err != nil { + grpclog.Fatalf("Failed to load default features: %v", err) + } +} + +func toRadians(num float64) float64 { + return num * math.Pi / float64(180) +} + +// calcDistance calculates the distance between two points using the "haversine" formula. +// This code was taken from http://www.movable-type.co.uk/scripts/latlong.html. +func calcDistance(p1 *pb.Point, p2 *pb.Point) int32 { + const CordFactor float64 = 1e7 + const R float64 = float64(6371000) // metres + lat1 := float64(p1.Latitude) / CordFactor + lat2 := float64(p2.Latitude) / CordFactor + lng1 := float64(p1.Longitude) / CordFactor + lng2 := float64(p2.Longitude) / CordFactor + φ1 := toRadians(lat1) + φ2 := toRadians(lat2) + Δφ := toRadians(lat2 - lat1) + Δλ := toRadians(lng2 - lng1) + + a := math.Sin(Δφ/2)*math.Sin(Δφ/2) + + math.Cos(φ1)*math.Cos(φ2)* + math.Sin(Δλ/2)*math.Sin(Δλ/2) + c := 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a)) + + distance := R * c + return int32(distance) +} + +func inRange(point *pb.Point, rect *pb.Rectangle) bool { + left := math.Min(float64(rect.Lo.Longitude), float64(rect.Hi.Longitude)) + right := math.Max(float64(rect.Lo.Longitude), float64(rect.Hi.Longitude)) + top := math.Max(float64(rect.Lo.Latitude), float64(rect.Hi.Latitude)) + bottom := math.Min(float64(rect.Lo.Latitude), float64(rect.Hi.Latitude)) + + if float64(point.Longitude) >= left && + float64(point.Longitude) <= right && + float64(point.Latitude) >= bottom && + float64(point.Latitude) <= top { + return true + } + return false +} + +func serialize(point *pb.Point) string { + return fmt.Sprintf("%d %d", point.Latitude, point.Longitude) +} + +func newServer() *routeGuideServer { + s := new(routeGuideServer) + s.loadFeatures(*jsonDBFile) + s.routeNotes = make(map[string][]*pb.RouteNote) + return s +} + +func main() { + flag.Parse() + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) + if err != nil { + grpclog.Fatalf("failed to listen: %v", err) + } + var opts []grpc.ServerOption + if *tls { + creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) + if err != nil { + grpclog.Fatalf("Failed to generate credentials %v", err) + } + opts = []grpc.ServerOption{grpc.Creds(creds)} + } + grpcServer := grpc.NewServer(opts...) + pb.RegisterRouteGuideServer(grpcServer, newServer()) + grpcServer.Serve(lis) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/grpc-auth-support.md b/Godeps/_workspace/src/google.golang.org/grpc/grpc-auth-support.md new file mode 100644 index 0000000..f80cbbd --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/grpc-auth-support.md @@ -0,0 +1,41 @@ +# Authentication + +As outlined here gRPC supports a number of different mechanisms for asserting identity between an client and server. We'll present some code-samples here demonstrating how to provide TLS support encryption and identity assertions as well as passing OAuth2 tokens to services that support it. + +# Enabling TLS on a gRPC client + +```Go +conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, "")) +``` + +# Enabling TLS on a gRPC server + +```Go +creds, err := credentials.NewServerTLSFromFile(certFile, keyFile) +if err != nil { + log.Fatalf("Failed to generate credentials %v", err) +} +lis, err := net.Listen("tcp", ":0") +server := grpc.NewServer(grpc.Creds(creds)) +... +server.Serve(lis) +``` + +# Authenticating with Google + +## Google Compute Engine (GCE) + +```Go +conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""), grpc.WithPerRPCCredentials(oauth.NewComputeEngine()))) +``` + +## JWT + +```Go +jwtCreds, err := oauth.NewServiceAccountFromFile(*serviceAccountKeyFile, *oauthScope) +if err != nil { + log.Fatalf("Failed to create JWT credentials: %v", err) +} +conn, err := grpc.Dial(serverAddr, grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(nil, ""), grpc.WithPerRPCCredentials(jwtCreds))) +``` + diff --git a/Godeps/_workspace/src/google.golang.org/grpc/grpclog/glogger/glogger.go b/Godeps/_workspace/src/google.golang.org/grpc/grpclog/glogger/glogger.go new file mode 100644 index 0000000..53e3c53 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/grpclog/glogger/glogger.go @@ -0,0 +1,72 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +/* +Package glogger defines glog-based logging for grpc. +*/ +package glogger + +import ( + "github.com/golang/glog" + "google.golang.org/grpc/grpclog" +) + +func init() { + grpclog.SetLogger(&glogger{}) +} + +type glogger struct{} + +func (g *glogger) Fatal(args ...interface{}) { + glog.Fatal(args...) +} + +func (g *glogger) Fatalf(format string, args ...interface{}) { + glog.Fatalf(format, args...) +} + +func (g *glogger) Fatalln(args ...interface{}) { + glog.Fatalln(args...) +} + +func (g *glogger) Print(args ...interface{}) { + glog.Info(args...) +} + +func (g *glogger) Printf(format string, args ...interface{}) { + glog.Infof(format, args...) +} + +func (g *glogger) Println(args ...interface{}) { + glog.Infoln(args...) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/grpclog/logger.go b/Godeps/_workspace/src/google.golang.org/grpc/grpclog/logger.go new file mode 100644 index 0000000..cc6e27c --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/grpclog/logger.go @@ -0,0 +1,90 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +/* +Package grpclog defines logging for grpc. +*/ +package grpclog + +import ( + "log" + "os" +) + +// Use golang's standard logger by default. +var logger Logger = log.New(os.Stderr, "", log.LstdFlags) + +// Logger mimics golang's standard Logger as an interface. +type Logger interface { + Fatal(args ...interface{}) + Fatalf(format string, args ...interface{}) + Fatalln(args ...interface{}) + Print(args ...interface{}) + Printf(format string, args ...interface{}) + Println(args ...interface{}) +} + +// SetLogger sets the logger that is used in grpc. +func SetLogger(l Logger) { + logger = l +} + +// Fatal is equivalent to Print() followed by a call to os.Exit() with a non-zero exit code. +func Fatal(args ...interface{}) { + logger.Fatal(args...) +} + +// Fatalf is equivalent to Printf() followed by a call to os.Exit() with a non-zero exit code. +func Fatalf(format string, args ...interface{}) { + logger.Fatalf(format, args...) +} + +// Fatalln is equivalent to Println() followed by a call to os.Exit()) with a non-zero exit code. +func Fatalln(args ...interface{}) { + logger.Fatalln(args...) +} + +// Print prints to the logger. Arguments are handled in the manner of fmt.Print. +func Print(args ...interface{}) { + logger.Print(args...) +} + +// Printf prints to the logger. Arguments are handled in the manner of fmt.Printf. +func Printf(format string, args ...interface{}) { + logger.Printf(format, args...) +} + +// Println prints to the logger. Arguments are handled in the manner of fmt.Println. +func Println(args ...interface{}) { + logger.Println(args...) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/health/grpc_health_v1alpha/health.pb.go b/Godeps/_workspace/src/google.golang.org/grpc/health/grpc_health_v1alpha/health.pb.go new file mode 100644 index 0000000..6bfbe49 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/health/grpc_health_v1alpha/health.pb.go @@ -0,0 +1,129 @@ +// Code generated by protoc-gen-go. +// source: health.proto +// DO NOT EDIT! + +/* +Package grpc_health_v1alpha is a generated protocol buffer package. + +It is generated from these files: + health.proto + +It has these top-level messages: + HealthCheckRequest + HealthCheckResponse +*/ +package grpc_health_v1alpha + +import proto "github.com/golang/protobuf/proto" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal + +type HealthCheckResponse_ServingStatus int32 + +const ( + HealthCheckResponse_UNKNOWN HealthCheckResponse_ServingStatus = 0 + HealthCheckResponse_SERVING HealthCheckResponse_ServingStatus = 1 + HealthCheckResponse_NOT_SERVING HealthCheckResponse_ServingStatus = 2 +) + +var HealthCheckResponse_ServingStatus_name = map[int32]string{ + 0: "UNKNOWN", + 1: "SERVING", + 2: "NOT_SERVING", +} +var HealthCheckResponse_ServingStatus_value = map[string]int32{ + "UNKNOWN": 0, + "SERVING": 1, + "NOT_SERVING": 2, +} + +func (x HealthCheckResponse_ServingStatus) String() string { + return proto.EnumName(HealthCheckResponse_ServingStatus_name, int32(x)) +} + +type HealthCheckRequest struct { + Service string `protobuf:"bytes,2,opt,name=service" json:"service,omitempty"` +} + +func (m *HealthCheckRequest) Reset() { *m = HealthCheckRequest{} } +func (m *HealthCheckRequest) String() string { return proto.CompactTextString(m) } +func (*HealthCheckRequest) ProtoMessage() {} + +type HealthCheckResponse struct { + Status HealthCheckResponse_ServingStatus `protobuf:"varint,1,opt,name=status,enum=grpc.health.v1alpha.HealthCheckResponse_ServingStatus" json:"status,omitempty"` +} + +func (m *HealthCheckResponse) Reset() { *m = HealthCheckResponse{} } +func (m *HealthCheckResponse) String() string { return proto.CompactTextString(m) } +func (*HealthCheckResponse) ProtoMessage() {} + +func init() { + proto.RegisterEnum("grpc.health.v1alpha.HealthCheckResponse_ServingStatus", HealthCheckResponse_ServingStatus_name, HealthCheckResponse_ServingStatus_value) +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// Client API for HealthCheck service + +type HealthCheckClient interface { + Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) +} + +type healthCheckClient struct { + cc *grpc.ClientConn +} + +func NewHealthCheckClient(cc *grpc.ClientConn) HealthCheckClient { + return &healthCheckClient{cc} +} + +func (c *healthCheckClient) Check(ctx context.Context, in *HealthCheckRequest, opts ...grpc.CallOption) (*HealthCheckResponse, error) { + out := new(HealthCheckResponse) + err := grpc.Invoke(ctx, "/grpc.health.v1alpha.HealthCheck/Check", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// Server API for HealthCheck service + +type HealthCheckServer interface { + Check(context.Context, *HealthCheckRequest) (*HealthCheckResponse, error) +} + +func RegisterHealthCheckServer(s *grpc.Server, srv HealthCheckServer) { + s.RegisterService(&_HealthCheck_serviceDesc, srv) +} + +func _HealthCheck_Check_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(HealthCheckRequest) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(HealthCheckServer).Check(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +var _HealthCheck_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.health.v1alpha.HealthCheck", + HandlerType: (*HealthCheckServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Check", + Handler: _HealthCheck_Check_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/health/grpc_health_v1alpha/health.proto b/Godeps/_workspace/src/google.golang.org/grpc/health/grpc_health_v1alpha/health.proto new file mode 100644 index 0000000..1ca5bbc --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/health/grpc_health_v1alpha/health.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package grpc.health.v1alpha; + +message HealthCheckRequest { + string service = 2; +} + +message HealthCheckResponse { + enum ServingStatus { + UNKNOWN = 0; + SERVING = 1; + NOT_SERVING = 2; + } + ServingStatus status = 1; +} + +service HealthCheck{ + rpc Check(HealthCheckRequest) returns (HealthCheckResponse); +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/health/health.go b/Godeps/_workspace/src/google.golang.org/grpc/health/health.go new file mode 100644 index 0000000..7930bde --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/health/health.go @@ -0,0 +1,49 @@ +// Package health provides some utility functions to health-check a server. The implementation +// is based on protobuf. Users need to write their own implementations if other IDLs are used. +package health + +import ( + "sync" + + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + healthpb "google.golang.org/grpc/health/grpc_health_v1alpha" +) + +type HealthServer struct { + mu sync.Mutex + // statusMap stores the serving status of the services this HealthServer monitors. + statusMap map[string]healthpb.HealthCheckResponse_ServingStatus +} + +func NewHealthServer() *HealthServer { + return &HealthServer{ + statusMap: make(map[string]healthpb.HealthCheckResponse_ServingStatus), + } +} + +func (s *HealthServer) Check(ctx context.Context, in *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + s.mu.Lock() + defer s.mu.Unlock() + if in.Service == "" { + // check the server overall health status. + return &healthpb.HealthCheckResponse{ + Status: healthpb.HealthCheckResponse_SERVING, + }, nil + } + if status, ok := s.statusMap[in.Service]; ok { + return &healthpb.HealthCheckResponse{ + Status: status, + }, nil + } + return nil, grpc.Errorf(codes.NotFound, "unknown service") +} + +// SetServingStatus is called when need to reset the serving status of a service +// or insert a new service entry into the statusMap. +func (s *HealthServer) SetServingStatus(service string, status healthpb.HealthCheckResponse_ServingStatus) { + s.mu.Lock() + s.statusMap[service] = status + s.mu.Unlock() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/interop/client/client.go b/Godeps/_workspace/src/google.golang.org/grpc/interop/client/client.go new file mode 100644 index 0000000..4f715d3 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/interop/client/client.go @@ -0,0 +1,581 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package main + +import ( + "flag" + "io" + "io/ioutil" + "net" + "strconv" + "strings" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/oauth" + "google.golang.org/grpc/grpclog" + testpb "google.golang.org/grpc/interop/grpc_testing" + "google.golang.org/grpc/metadata" +) + +var ( + useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") + caFile = flag.String("tls_ca_file", "testdata/ca.pem", "The file containning the CA root cert file") + serviceAccountKeyFile = flag.String("service_account_key_file", "", "Path to service account json key file") + oauthScope = flag.String("oauth_scope", "", "The scope for OAuth2 tokens") + defaultServiceAccount = flag.String("default_service_account", "", "Email of GCE default service account") + serverHost = flag.String("server_host", "127.0.0.1", "The server host name") + serverPort = flag.Int("server_port", 10000, "The server port number") + tlsServerName = flag.String("server_host_override", "x.test.youtube.com", "The server name use to verify the hostname returned by TLS handshake if it is not empty. Otherwise, --server_host is used.") + testCase = flag.String("test_case", "large_unary", + `Configure different test cases. Valid options are: + empty_unary : empty (zero bytes) request and response; + large_unary : single request and (large) response; + client_streaming : request streaming with single response; + server_streaming : single request with response streaming; + ping_pong : full-duplex streaming; + empty_stream : full-duplex streaming with zero message; + timeout_on_sleeping_server: fullduplex streaming; + compute_engine_creds: large_unary with compute engine auth; + service_account_creds: large_unary with service account auth; + jwt_token_creds: large_unary with jwt token auth; + per_rpc_creds: large_unary with per rpc token; + oauth2_auth_token: large_unary with oauth2 token auth; + cancel_after_begin: cancellation after metadata has been sent but before payloads are sent; + cancel_after_first_response: cancellation after receiving 1st message from the server.`) +) + +var ( + reqSizes = []int{27182, 8, 1828, 45904} + respSizes = []int{31415, 9, 2653, 58979} + largeReqSize = 271828 + largeRespSize = 314159 +) + +func newPayload(t testpb.PayloadType, size int) *testpb.Payload { + if size < 0 { + grpclog.Fatalf("Requested a response with invalid length %d", size) + } + body := make([]byte, size) + switch t { + case testpb.PayloadType_COMPRESSABLE: + case testpb.PayloadType_UNCOMPRESSABLE: + grpclog.Fatalf("PayloadType UNCOMPRESSABLE is not supported") + default: + grpclog.Fatalf("Unsupported payload type: %d", t) + } + return &testpb.Payload{ + Type: t.Enum(), + Body: body, + } +} + +func doEmptyUnaryCall(tc testpb.TestServiceClient) { + reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}) + if err != nil { + grpclog.Fatal("/TestService/EmptyCall RPC failed: ", err) + } + if !proto.Equal(&testpb.Empty{}, reply) { + grpclog.Fatalf("/TestService/EmptyCall receives %v, want %v", reply, testpb.Empty{}) + } + grpclog.Println("EmptyUnaryCall done") +} + +func doLargeUnaryCall(tc testpb.TestServiceClient) { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeRespSize)), + Payload: pl, + } + reply, err := tc.UnaryCall(context.Background(), req) + if err != nil { + grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) + } + t := reply.GetPayload().GetType() + s := len(reply.GetPayload().GetBody()) + if t != testpb.PayloadType_COMPRESSABLE || s != largeRespSize { + grpclog.Fatalf("Got the reply with type %d len %d; want %d, %d", t, s, testpb.PayloadType_COMPRESSABLE, largeRespSize) + } + grpclog.Println("LargeUnaryCall done") +} + +func doClientStreaming(tc testpb.TestServiceClient) { + stream, err := tc.StreamingInputCall(context.Background()) + if err != nil { + grpclog.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err) + } + var sum int + for _, s := range reqSizes { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, s) + req := &testpb.StreamingInputCallRequest{ + Payload: pl, + } + if err := stream.Send(req); err != nil { + grpclog.Fatalf("%v.Send(%v) = %v", stream, req, err) + } + sum += s + grpclog.Printf("Sent a request of size %d, aggregated size %d", s, sum) + + } + reply, err := stream.CloseAndRecv() + if err != nil { + grpclog.Fatalf("%v.CloseAndRecv() got error %v, want %v", stream, err, nil) + } + if reply.GetAggregatedPayloadSize() != int32(sum) { + grpclog.Fatalf("%v.CloseAndRecv().GetAggregatePayloadSize() = %v; want %v", stream, reply.GetAggregatedPayloadSize(), sum) + } + grpclog.Println("ClientStreaming done") +} + +func doServerStreaming(tc testpb.TestServiceClient) { + respParam := make([]*testpb.ResponseParameters, len(respSizes)) + for i, s := range respSizes { + respParam[i] = &testpb.ResponseParameters{ + Size: proto.Int32(int32(s)), + } + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + } + stream, err := tc.StreamingOutputCall(context.Background(), req) + if err != nil { + grpclog.Fatalf("%v.StreamingOutputCall(_) = _, %v", tc, err) + } + var rpcStatus error + var respCnt int + var index int + for { + reply, err := stream.Recv() + if err != nil { + rpcStatus = err + break + } + t := reply.GetPayload().GetType() + if t != testpb.PayloadType_COMPRESSABLE { + grpclog.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE) + } + size := len(reply.GetPayload().GetBody()) + if size != int(respSizes[index]) { + grpclog.Fatalf("Got reply body of length %d, want %d", size, respSizes[index]) + } + index++ + respCnt++ + } + if rpcStatus != io.EOF { + grpclog.Fatalf("Failed to finish the server streaming rpc: %v", err) + } + if respCnt != len(respSizes) { + grpclog.Fatalf("Got %d reply, want %d", len(respSizes), respCnt) + } + grpclog.Println("ServerStreaming done") +} + +func doPingPong(tc testpb.TestServiceClient) { + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + grpclog.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err) + } + var index int + for index < len(reqSizes) { + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(respSizes[index])), + }, + } + pl := newPayload(testpb.PayloadType_COMPRESSABLE, reqSizes[index]) + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: pl, + } + if err := stream.Send(req); err != nil { + grpclog.Fatalf("%v.Send(%v) = %v", stream, req, err) + } + reply, err := stream.Recv() + if err != nil { + grpclog.Fatalf("%v.Recv() = %v", stream, err) + } + t := reply.GetPayload().GetType() + if t != testpb.PayloadType_COMPRESSABLE { + grpclog.Fatalf("Got the reply of type %d, want %d", t, testpb.PayloadType_COMPRESSABLE) + } + size := len(reply.GetPayload().GetBody()) + if size != int(respSizes[index]) { + grpclog.Fatalf("Got reply body of length %d, want %d", size, respSizes[index]) + } + index++ + } + if err := stream.CloseSend(); err != nil { + grpclog.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil) + } + if _, err := stream.Recv(); err != io.EOF { + grpclog.Fatalf("%v failed to complele the ping pong test: %v", stream, err) + } + grpclog.Println("Pingpong done") +} + +func doEmptyStream(tc testpb.TestServiceClient) { + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + grpclog.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err) + } + if err := stream.CloseSend(); err != nil { + grpclog.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil) + } + if _, err := stream.Recv(); err != io.EOF { + grpclog.Fatalf("%v failed to complete the empty stream test: %v", stream, err) + } + grpclog.Println("Emptystream done") +} + +func doTimeoutOnSleepingServer(tc testpb.TestServiceClient) { + ctx, _ := context.WithTimeout(context.Background(), 1*time.Millisecond) + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + if grpc.Code(err) == codes.DeadlineExceeded { + grpclog.Println("TimeoutOnSleepingServer done") + return + } + grpclog.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err) + } + pl := newPayload(testpb.PayloadType_COMPRESSABLE, 27182) + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + Payload: pl, + } + if err := stream.Send(req); err != nil { + grpclog.Fatalf("%v.Send(%v) = %v", stream, req, err) + } + if _, err := stream.Recv(); grpc.Code(err) != codes.DeadlineExceeded { + grpclog.Fatalf("%v.Recv() = _, %v, want error code %d", stream, err, codes.DeadlineExceeded) + } + grpclog.Println("TimeoutOnSleepingServer done") +} + +func doComputeEngineCreds(tc testpb.TestServiceClient) { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeRespSize)), + Payload: pl, + FillUsername: proto.Bool(true), + FillOauthScope: proto.Bool(true), + } + reply, err := tc.UnaryCall(context.Background(), req) + if err != nil { + grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) + } + user := reply.GetUsername() + scope := reply.GetOauthScope() + if user != *defaultServiceAccount { + grpclog.Fatalf("Got user name %q, want %q.", user, *defaultServiceAccount) + } + if !strings.Contains(*oauthScope, scope) { + grpclog.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, *oauthScope) + } + grpclog.Println("ComputeEngineCreds done") +} + +func getServiceAccountJSONKey() []byte { + jsonKey, err := ioutil.ReadFile(*serviceAccountKeyFile) + if err != nil { + grpclog.Fatalf("Failed to read the service account key file: %v", err) + } + return jsonKey +} + +func doServiceAccountCreds(tc testpb.TestServiceClient) { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeRespSize)), + Payload: pl, + FillUsername: proto.Bool(true), + FillOauthScope: proto.Bool(true), + } + reply, err := tc.UnaryCall(context.Background(), req) + if err != nil { + grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) + } + jsonKey := getServiceAccountJSONKey() + user := reply.GetUsername() + scope := reply.GetOauthScope() + if !strings.Contains(string(jsonKey), user) { + grpclog.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey) + } + if !strings.Contains(*oauthScope, scope) { + grpclog.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, *oauthScope) + } + grpclog.Println("ServiceAccountCreds done") +} + +func doJWTTokenCreds(tc testpb.TestServiceClient) { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeRespSize)), + Payload: pl, + FillUsername: proto.Bool(true), + } + reply, err := tc.UnaryCall(context.Background(), req) + if err != nil { + grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) + } + jsonKey := getServiceAccountJSONKey() + user := reply.GetUsername() + if !strings.Contains(string(jsonKey), user) { + grpclog.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey) + } + grpclog.Println("JWTtokenCreds done") +} + +func getToken() *oauth2.Token { + jsonKey := getServiceAccountJSONKey() + config, err := google.JWTConfigFromJSON(jsonKey, *oauthScope) + if err != nil { + grpclog.Fatalf("Failed to get the config: %v", err) + } + token, err := config.TokenSource(context.Background()).Token() + if err != nil { + grpclog.Fatalf("Failed to get the token: %v", err) + } + return token +} + +func doOauth2TokenCreds(tc testpb.TestServiceClient) { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeRespSize)), + Payload: pl, + FillUsername: proto.Bool(true), + FillOauthScope: proto.Bool(true), + } + reply, err := tc.UnaryCall(context.Background(), req) + if err != nil { + grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) + } + jsonKey := getServiceAccountJSONKey() + user := reply.GetUsername() + scope := reply.GetOauthScope() + if !strings.Contains(string(jsonKey), user) { + grpclog.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey) + } + if !strings.Contains(*oauthScope, scope) { + grpclog.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, *oauthScope) + } + grpclog.Println("Oauth2TokenCreds done") +} + +func doPerRPCCreds(tc testpb.TestServiceClient) { + jsonKey := getServiceAccountJSONKey() + pl := newPayload(testpb.PayloadType_COMPRESSABLE, largeReqSize) + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(largeRespSize)), + Payload: pl, + FillUsername: proto.Bool(true), + FillOauthScope: proto.Bool(true), + } + token := getToken() + kv := map[string]string{"authorization": token.TokenType + " " + token.AccessToken} + ctx := metadata.NewContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}}) + reply, err := tc.UnaryCall(ctx, req) + if err != nil { + grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) + } + user := reply.GetUsername() + scope := reply.GetOauthScope() + if !strings.Contains(string(jsonKey), user) { + grpclog.Fatalf("Got user name %q which is NOT a substring of %q.", user, jsonKey) + } + if !strings.Contains(*oauthScope, scope) { + grpclog.Fatalf("Got OAuth scope %q which is NOT a substring of %q.", scope, *oauthScope) + } + grpclog.Println("PerRPCCreds done") +} + +var ( + testMetadata = metadata.MD{ + "key1": []string{"value1"}, + "key2": []string{"value2"}, + } +) + +func doCancelAfterBegin(tc testpb.TestServiceClient) { + ctx, cancel := context.WithCancel(metadata.NewContext(context.Background(), testMetadata)) + stream, err := tc.StreamingInputCall(ctx) + if err != nil { + grpclog.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err) + } + cancel() + _, err = stream.CloseAndRecv() + if grpc.Code(err) != codes.Canceled { + grpclog.Fatalf("%v.CloseAndRecv() got error code %d, want %d", stream, grpc.Code(err), codes.Canceled) + } + grpclog.Println("CancelAfterBegin done") +} + +func doCancelAfterFirstResponse(tc testpb.TestServiceClient) { + ctx, cancel := context.WithCancel(context.Background()) + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + grpclog.Fatalf("%v.FullDuplexCall(_) = _, %v", tc, err) + } + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(31415), + }, + } + pl := newPayload(testpb.PayloadType_COMPRESSABLE, 27182) + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: pl, + } + if err := stream.Send(req); err != nil { + grpclog.Fatalf("%v.Send(%v) = %v", stream, req, err) + } + if _, err := stream.Recv(); err != nil { + grpclog.Fatalf("%v.Recv() = %v", stream, err) + } + cancel() + if _, err := stream.Recv(); grpc.Code(err) != codes.Canceled { + grpclog.Fatalf("%v compleled with error code %d, want %d", stream, grpc.Code(err), codes.Canceled) + } + grpclog.Println("CancelAfterFirstResponse done") +} + +func main() { + flag.Parse() + serverAddr := net.JoinHostPort(*serverHost, strconv.Itoa(*serverPort)) + var opts []grpc.DialOption + if *useTLS { + var sn string + if *tlsServerName != "" { + sn = *tlsServerName + } + var creds credentials.TransportAuthenticator + if *caFile != "" { + var err error + creds, err = credentials.NewClientTLSFromFile(*caFile, sn) + if err != nil { + grpclog.Fatalf("Failed to create TLS credentials %v", err) + } + } else { + creds = credentials.NewClientTLSFromCert(nil, sn) + } + opts = append(opts, grpc.WithTransportCredentials(creds)) + if *testCase == "compute_engine_creds" { + opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewComputeEngine())) + } else if *testCase == "service_account_creds" { + jwtCreds, err := oauth.NewServiceAccountFromFile(*serviceAccountKeyFile, *oauthScope) + if err != nil { + grpclog.Fatalf("Failed to create JWT credentials: %v", err) + } + opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds)) + } else if *testCase == "jwt_token_creds" { + jwtCreds, err := oauth.NewJWTAccessFromFile(*serviceAccountKeyFile) + if err != nil { + grpclog.Fatalf("Failed to create JWT credentials: %v", err) + } + opts = append(opts, grpc.WithPerRPCCredentials(jwtCreds)) + } else if *testCase == "oauth2_auth_token" { + opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(getToken()))) + } + } else { + opts = append(opts, grpc.WithInsecure()) + } + conn, err := grpc.Dial(serverAddr, opts...) + if err != nil { + grpclog.Fatalf("Fail to dial: %v", err) + } + defer conn.Close() + tc := testpb.NewTestServiceClient(conn) + switch *testCase { + case "empty_unary": + doEmptyUnaryCall(tc) + case "large_unary": + doLargeUnaryCall(tc) + case "client_streaming": + doClientStreaming(tc) + case "server_streaming": + doServerStreaming(tc) + case "ping_pong": + doPingPong(tc) + case "empty_stream": + doEmptyStream(tc) + case "timeout_on_sleeping_server": + doTimeoutOnSleepingServer(tc) + case "compute_engine_creds": + if !*useTLS { + grpclog.Fatalf("TLS is not enabled. TLS is required to execute compute_engine_creds test case.") + } + doComputeEngineCreds(tc) + case "service_account_creds": + if !*useTLS { + grpclog.Fatalf("TLS is not enabled. TLS is required to execute service_account_creds test case.") + } + doServiceAccountCreds(tc) + case "jwt_token_creds": + if !*useTLS { + grpclog.Fatalf("TLS is not enabled. TLS is required to execute jwt_token_creds test case.") + } + doJWTTokenCreds(tc) + case "per_rpc_creds": + if !*useTLS { + grpclog.Fatalf("TLS is not enabled. TLS is required to execute per_rpc_creds test case.") + } + doPerRPCCreds(tc) + case "oauth2_auth_token": + if !*useTLS { + grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_auth_token test case.") + } + doOauth2TokenCreds(tc) + case "cancel_after_begin": + doCancelAfterBegin(tc) + case "cancel_after_first_response": + doCancelAfterFirstResponse(tc) + default: + grpclog.Fatal("Unsupported test case: ", *testCase) + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/interop/grpc_testing/test.pb.go b/Godeps/_workspace/src/google.golang.org/grpc/interop/grpc_testing/test.pb.go new file mode 100644 index 0000000..b25e98b --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/interop/grpc_testing/test.pb.go @@ -0,0 +1,702 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package grpc_testing is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + Empty + Payload + SimpleRequest + SimpleResponse + StreamingInputCallRequest + StreamingInputCallResponse + ResponseParameters + StreamingOutputCallRequest + StreamingOutputCallResponse +*/ +package grpc_testing + +import proto "github.com/golang/protobuf/proto" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = math.Inf + +// The type of payload that should be returned. +type PayloadType int32 + +const ( + // Compressable text format. + PayloadType_COMPRESSABLE PayloadType = 0 + // Uncompressable binary format. + PayloadType_UNCOMPRESSABLE PayloadType = 1 + // Randomly chosen from all other formats defined in this enum. + PayloadType_RANDOM PayloadType = 2 +) + +var PayloadType_name = map[int32]string{ + 0: "COMPRESSABLE", + 1: "UNCOMPRESSABLE", + 2: "RANDOM", +} +var PayloadType_value = map[string]int32{ + "COMPRESSABLE": 0, + "UNCOMPRESSABLE": 1, + "RANDOM": 2, +} + +func (x PayloadType) Enum() *PayloadType { + p := new(PayloadType) + *p = x + return p +} +func (x PayloadType) String() string { + return proto.EnumName(PayloadType_name, int32(x)) +} +func (x *PayloadType) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(PayloadType_value, data, "PayloadType") + if err != nil { + return err + } + *x = PayloadType(value) + return nil +} + +type Empty struct { + XXX_unrecognized []byte `json:"-"` +} + +func (m *Empty) Reset() { *m = Empty{} } +func (m *Empty) String() string { return proto.CompactTextString(m) } +func (*Empty) ProtoMessage() {} + +// A block of data, to simply increase gRPC message size. +type Payload struct { + // The type of data in body. + Type *PayloadType `protobuf:"varint,1,opt,name=type,enum=grpc.testing.PayloadType" json:"type,omitempty"` + // Primary contents of payload. + Body []byte `protobuf:"bytes,2,opt,name=body" json:"body,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Payload) Reset() { *m = Payload{} } +func (m *Payload) String() string { return proto.CompactTextString(m) } +func (*Payload) ProtoMessage() {} + +func (m *Payload) GetType() PayloadType { + if m != nil && m.Type != nil { + return *m.Type + } + return PayloadType_COMPRESSABLE +} + +func (m *Payload) GetBody() []byte { + if m != nil { + return m.Body + } + return nil +} + +// Unary request. +type SimpleRequest struct { + // Desired payload type in the response from the server. + // If response_type is RANDOM, server randomly chooses one from other formats. + ResponseType *PayloadType `protobuf:"varint,1,opt,name=response_type,enum=grpc.testing.PayloadType" json:"response_type,omitempty"` + // Desired payload size in the response from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + ResponseSize *int32 `protobuf:"varint,2,opt,name=response_size" json:"response_size,omitempty"` + // Optional input payload sent along with the request. + Payload *Payload `protobuf:"bytes,3,opt,name=payload" json:"payload,omitempty"` + // Whether SimpleResponse should include username. + FillUsername *bool `protobuf:"varint,4,opt,name=fill_username" json:"fill_username,omitempty"` + // Whether SimpleResponse should include OAuth scope. + FillOauthScope *bool `protobuf:"varint,5,opt,name=fill_oauth_scope" json:"fill_oauth_scope,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SimpleRequest) Reset() { *m = SimpleRequest{} } +func (m *SimpleRequest) String() string { return proto.CompactTextString(m) } +func (*SimpleRequest) ProtoMessage() {} + +func (m *SimpleRequest) GetResponseType() PayloadType { + if m != nil && m.ResponseType != nil { + return *m.ResponseType + } + return PayloadType_COMPRESSABLE +} + +func (m *SimpleRequest) GetResponseSize() int32 { + if m != nil && m.ResponseSize != nil { + return *m.ResponseSize + } + return 0 +} + +func (m *SimpleRequest) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +func (m *SimpleRequest) GetFillUsername() bool { + if m != nil && m.FillUsername != nil { + return *m.FillUsername + } + return false +} + +func (m *SimpleRequest) GetFillOauthScope() bool { + if m != nil && m.FillOauthScope != nil { + return *m.FillOauthScope + } + return false +} + +// Unary response, as configured by the request. +type SimpleResponse struct { + // Payload to increase message size. + Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"` + // The user the request came from, for verifying authentication was + // successful when the client expected it. + Username *string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` + // OAuth scope. + OauthScope *string `protobuf:"bytes,3,opt,name=oauth_scope" json:"oauth_scope,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SimpleResponse) Reset() { *m = SimpleResponse{} } +func (m *SimpleResponse) String() string { return proto.CompactTextString(m) } +func (*SimpleResponse) ProtoMessage() {} + +func (m *SimpleResponse) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +func (m *SimpleResponse) GetUsername() string { + if m != nil && m.Username != nil { + return *m.Username + } + return "" +} + +func (m *SimpleResponse) GetOauthScope() string { + if m != nil && m.OauthScope != nil { + return *m.OauthScope + } + return "" +} + +// Client-streaming request. +type StreamingInputCallRequest struct { + // Optional input payload sent along with the request. + Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingInputCallRequest) Reset() { *m = StreamingInputCallRequest{} } +func (m *StreamingInputCallRequest) String() string { return proto.CompactTextString(m) } +func (*StreamingInputCallRequest) ProtoMessage() {} + +func (m *StreamingInputCallRequest) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +// Client-streaming response. +type StreamingInputCallResponse struct { + // Aggregated size of payloads received from the client. + AggregatedPayloadSize *int32 `protobuf:"varint,1,opt,name=aggregated_payload_size" json:"aggregated_payload_size,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingInputCallResponse) Reset() { *m = StreamingInputCallResponse{} } +func (m *StreamingInputCallResponse) String() string { return proto.CompactTextString(m) } +func (*StreamingInputCallResponse) ProtoMessage() {} + +func (m *StreamingInputCallResponse) GetAggregatedPayloadSize() int32 { + if m != nil && m.AggregatedPayloadSize != nil { + return *m.AggregatedPayloadSize + } + return 0 +} + +// Configuration for a particular response. +type ResponseParameters struct { + // Desired payload sizes in responses from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + Size *int32 `protobuf:"varint,1,opt,name=size" json:"size,omitempty"` + // Desired interval between consecutive responses in the response stream in + // microseconds. + IntervalUs *int32 `protobuf:"varint,2,opt,name=interval_us" json:"interval_us,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *ResponseParameters) Reset() { *m = ResponseParameters{} } +func (m *ResponseParameters) String() string { return proto.CompactTextString(m) } +func (*ResponseParameters) ProtoMessage() {} + +func (m *ResponseParameters) GetSize() int32 { + if m != nil && m.Size != nil { + return *m.Size + } + return 0 +} + +func (m *ResponseParameters) GetIntervalUs() int32 { + if m != nil && m.IntervalUs != nil { + return *m.IntervalUs + } + return 0 +} + +// Server-streaming request. +type StreamingOutputCallRequest struct { + // Desired payload type in the response from the server. + // If response_type is RANDOM, the payload from each response in the stream + // might be of different types. This is to simulate a mixed type of payload + // stream. + ResponseType *PayloadType `protobuf:"varint,1,opt,name=response_type,enum=grpc.testing.PayloadType" json:"response_type,omitempty"` + // Configuration for each expected response message. + ResponseParameters []*ResponseParameters `protobuf:"bytes,2,rep,name=response_parameters" json:"response_parameters,omitempty"` + // Optional input payload sent along with the request. + Payload *Payload `protobuf:"bytes,3,opt,name=payload" json:"payload,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingOutputCallRequest) Reset() { *m = StreamingOutputCallRequest{} } +func (m *StreamingOutputCallRequest) String() string { return proto.CompactTextString(m) } +func (*StreamingOutputCallRequest) ProtoMessage() {} + +func (m *StreamingOutputCallRequest) GetResponseType() PayloadType { + if m != nil && m.ResponseType != nil { + return *m.ResponseType + } + return PayloadType_COMPRESSABLE +} + +func (m *StreamingOutputCallRequest) GetResponseParameters() []*ResponseParameters { + if m != nil { + return m.ResponseParameters + } + return nil +} + +func (m *StreamingOutputCallRequest) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +// Server-streaming response, as configured by the request and parameters. +type StreamingOutputCallResponse struct { + // Payload to increase response size. + Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingOutputCallResponse) Reset() { *m = StreamingOutputCallResponse{} } +func (m *StreamingOutputCallResponse) String() string { return proto.CompactTextString(m) } +func (*StreamingOutputCallResponse) ProtoMessage() {} + +func (m *StreamingOutputCallResponse) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +func init() { + proto.RegisterEnum("grpc.testing.PayloadType", PayloadType_name, PayloadType_value) +} + +// Client API for TestService service + +type TestServiceClient interface { + // One empty request followed by one empty response. + EmptyCall(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) + // One request followed by one response. + // The server returns the client payload as-is. + UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) + // One request followed by a sequence of responses (streamed download). + // The server returns the payload with client desired type and sizes. + StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) + // A sequence of requests followed by one response (streamed upload). + // The server returns the aggregated size of client payload as the result. + StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) + // A sequence of requests followed by a sequence of responses. + // The server buffers all the client requests and then serves them in order. A + // stream of responses are returned to the client when the server starts with + // first request. + HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) +} + +type testServiceClient struct { + cc *grpc.ClientConn +} + +func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient { + return &testServiceClient{cc} +} + +func (c *testServiceClient) EmptyCall(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/grpc.testing.TestService/EmptyCall", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) { + out := new(SimpleResponse) + err := grpc.Invoke(ctx, "/grpc.testing.TestService/UnaryCall", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceStreamingOutputCallClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type TestService_StreamingOutputCallClient interface { + Recv() (*StreamingOutputCallResponse, error) + grpc.ClientStream +} + +type testServiceStreamingOutputCallClient struct { + grpc.ClientStream +} + +func (x *testServiceStreamingOutputCallClient) Recv() (*StreamingOutputCallResponse, error) { + m := new(StreamingOutputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceStreamingInputCallClient{stream} + return x, nil +} + +type TestService_StreamingInputCallClient interface { + Send(*StreamingInputCallRequest) error + CloseAndRecv() (*StreamingInputCallResponse, error) + grpc.ClientStream +} + +type testServiceStreamingInputCallClient struct { + grpc.ClientStream +} + +func (x *testServiceStreamingInputCallClient) Send(m *StreamingInputCallRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceStreamingInputCallClient) CloseAndRecv() (*StreamingInputCallResponse, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(StreamingInputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[2], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceFullDuplexCallClient{stream} + return x, nil +} + +type TestService_FullDuplexCallClient interface { + Send(*StreamingOutputCallRequest) error + Recv() (*StreamingOutputCallResponse, error) + grpc.ClientStream +} + +type testServiceFullDuplexCallClient struct { + grpc.ClientStream +} + +func (x *testServiceFullDuplexCallClient) Send(m *StreamingOutputCallRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceFullDuplexCallClient) Recv() (*StreamingOutputCallResponse, error) { + m := new(StreamingOutputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[3], c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceHalfDuplexCallClient{stream} + return x, nil +} + +type TestService_HalfDuplexCallClient interface { + Send(*StreamingOutputCallRequest) error + Recv() (*StreamingOutputCallResponse, error) + grpc.ClientStream +} + +type testServiceHalfDuplexCallClient struct { + grpc.ClientStream +} + +func (x *testServiceHalfDuplexCallClient) Send(m *StreamingOutputCallRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceHalfDuplexCallClient) Recv() (*StreamingOutputCallResponse, error) { + m := new(StreamingOutputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for TestService service + +type TestServiceServer interface { + // One empty request followed by one empty response. + EmptyCall(context.Context, *Empty) (*Empty, error) + // One request followed by one response. + // The server returns the client payload as-is. + UnaryCall(context.Context, *SimpleRequest) (*SimpleResponse, error) + // One request followed by a sequence of responses (streamed download). + // The server returns the payload with client desired type and sizes. + StreamingOutputCall(*StreamingOutputCallRequest, TestService_StreamingOutputCallServer) error + // A sequence of requests followed by one response (streamed upload). + // The server returns the aggregated size of client payload as the result. + StreamingInputCall(TestService_StreamingInputCallServer) error + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + FullDuplexCall(TestService_FullDuplexCallServer) error + // A sequence of requests followed by a sequence of responses. + // The server buffers all the client requests and then serves them in order. A + // stream of responses are returned to the client when the server starts with + // first request. + HalfDuplexCall(TestService_HalfDuplexCallServer) error +} + +func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { + s.RegisterService(&_TestService_serviceDesc, srv) +} + +func _TestService_EmptyCall_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(Empty) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(TestServiceServer).EmptyCall(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +func _TestService_UnaryCall_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(SimpleRequest) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(TestServiceServer).UnaryCall(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +func _TestService_StreamingOutputCall_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(StreamingOutputCallRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(TestServiceServer).StreamingOutputCall(m, &testServiceStreamingOutputCallServer{stream}) +} + +type TestService_StreamingOutputCallServer interface { + Send(*StreamingOutputCallResponse) error + grpc.ServerStream +} + +type testServiceStreamingOutputCallServer struct { + grpc.ServerStream +} + +func (x *testServiceStreamingOutputCallServer) Send(m *StreamingOutputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func _TestService_StreamingInputCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).StreamingInputCall(&testServiceStreamingInputCallServer{stream}) +} + +type TestService_StreamingInputCallServer interface { + SendAndClose(*StreamingInputCallResponse) error + Recv() (*StreamingInputCallRequest, error) + grpc.ServerStream +} + +type testServiceStreamingInputCallServer struct { + grpc.ServerStream +} + +func (x *testServiceStreamingInputCallServer) SendAndClose(m *StreamingInputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceStreamingInputCallServer) Recv() (*StreamingInputCallRequest, error) { + m := new(StreamingInputCallRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func _TestService_FullDuplexCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).FullDuplexCall(&testServiceFullDuplexCallServer{stream}) +} + +type TestService_FullDuplexCallServer interface { + Send(*StreamingOutputCallResponse) error + Recv() (*StreamingOutputCallRequest, error) + grpc.ServerStream +} + +type testServiceFullDuplexCallServer struct { + grpc.ServerStream +} + +func (x *testServiceFullDuplexCallServer) Send(m *StreamingOutputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceFullDuplexCallServer) Recv() (*StreamingOutputCallRequest, error) { + m := new(StreamingOutputCallRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func _TestService_HalfDuplexCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).HalfDuplexCall(&testServiceHalfDuplexCallServer{stream}) +} + +type TestService_HalfDuplexCallServer interface { + Send(*StreamingOutputCallResponse) error + Recv() (*StreamingOutputCallRequest, error) + grpc.ServerStream +} + +type testServiceHalfDuplexCallServer struct { + grpc.ServerStream +} + +func (x *testServiceHalfDuplexCallServer) Send(m *StreamingOutputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceHalfDuplexCallServer) Recv() (*StreamingOutputCallRequest, error) { + m := new(StreamingOutputCallRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _TestService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*TestServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "EmptyCall", + Handler: _TestService_EmptyCall_Handler, + }, + { + MethodName: "UnaryCall", + Handler: _TestService_UnaryCall_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "StreamingOutputCall", + Handler: _TestService_StreamingOutputCall_Handler, + ServerStreams: true, + }, + { + StreamName: "StreamingInputCall", + Handler: _TestService_StreamingInputCall_Handler, + ClientStreams: true, + }, + { + StreamName: "FullDuplexCall", + Handler: _TestService_FullDuplexCall_Handler, + ServerStreams: true, + ClientStreams: true, + }, + { + StreamName: "HalfDuplexCall", + Handler: _TestService_HalfDuplexCall_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/interop/grpc_testing/test.proto b/Godeps/_workspace/src/google.golang.org/grpc/interop/grpc_testing/test.proto new file mode 100644 index 0000000..b5bfe05 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/interop/grpc_testing/test.proto @@ -0,0 +1,140 @@ +// An integration test service that covers all the method signature permutations +// of unary/streaming requests/responses. +syntax = "proto2"; + +package grpc.testing; + +message Empty {} + +// The type of payload that should be returned. +enum PayloadType { + // Compressable text format. + COMPRESSABLE = 0; + + // Uncompressable binary format. + UNCOMPRESSABLE = 1; + + // Randomly chosen from all other formats defined in this enum. + RANDOM = 2; +} + +// A block of data, to simply increase gRPC message size. +message Payload { + // The type of data in body. + optional PayloadType type = 1; + // Primary contents of payload. + optional bytes body = 2; +} + +// Unary request. +message SimpleRequest { + // Desired payload type in the response from the server. + // If response_type is RANDOM, server randomly chooses one from other formats. + optional PayloadType response_type = 1; + + // Desired payload size in the response from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + optional int32 response_size = 2; + + // Optional input payload sent along with the request. + optional Payload payload = 3; + + // Whether SimpleResponse should include username. + optional bool fill_username = 4; + + // Whether SimpleResponse should include OAuth scope. + optional bool fill_oauth_scope = 5; +} + +// Unary response, as configured by the request. +message SimpleResponse { + // Payload to increase message size. + optional Payload payload = 1; + + // The user the request came from, for verifying authentication was + // successful when the client expected it. + optional string username = 2; + + // OAuth scope. + optional string oauth_scope = 3; +} + +// Client-streaming request. +message StreamingInputCallRequest { + // Optional input payload sent along with the request. + optional Payload payload = 1; + + // Not expecting any payload from the response. +} + +// Client-streaming response. +message StreamingInputCallResponse { + // Aggregated size of payloads received from the client. + optional int32 aggregated_payload_size = 1; +} + +// Configuration for a particular response. +message ResponseParameters { + // Desired payload sizes in responses from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + optional int32 size = 1; + + // Desired interval between consecutive responses in the response stream in + // microseconds. + optional int32 interval_us = 2; +} + +// Server-streaming request. +message StreamingOutputCallRequest { + // Desired payload type in the response from the server. + // If response_type is RANDOM, the payload from each response in the stream + // might be of different types. This is to simulate a mixed type of payload + // stream. + optional PayloadType response_type = 1; + + // Configuration for each expected response message. + repeated ResponseParameters response_parameters = 2; + + // Optional input payload sent along with the request. + optional Payload payload = 3; +} + +// Server-streaming response, as configured by the request and parameters. +message StreamingOutputCallResponse { + // Payload to increase response size. + optional Payload payload = 1; +} + +// A simple service to test the various types of RPCs and experiment with +// performance with various types of payload. +service TestService { + // One empty request followed by one empty response. + rpc EmptyCall(Empty) returns (Empty); + + // One request followed by one response. + // The server returns the client payload as-is. + rpc UnaryCall(SimpleRequest) returns (SimpleResponse); + + // One request followed by a sequence of responses (streamed download). + // The server returns the payload with client desired type and sizes. + rpc StreamingOutputCall(StreamingOutputCallRequest) + returns (stream StreamingOutputCallResponse); + + // A sequence of requests followed by one response (streamed upload). + // The server returns the aggregated size of client payload as the result. + rpc StreamingInputCall(stream StreamingInputCallRequest) + returns (StreamingInputCallResponse); + + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + rpc FullDuplexCall(stream StreamingOutputCallRequest) + returns (stream StreamingOutputCallResponse); + + // A sequence of requests followed by a sequence of responses. + // The server buffers all the client requests and then serves them in order. A + // stream of responses are returned to the client when the server starts with + // first request. + rpc HalfDuplexCall(stream StreamingOutputCallRequest) + returns (stream StreamingOutputCallResponse); +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/interop/server/server.go b/Godeps/_workspace/src/google.golang.org/grpc/interop/server/server.go new file mode 100644 index 0000000..f781c9d --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/interop/server/server.go @@ -0,0 +1,209 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package main + +import ( + "flag" + "fmt" + "io" + "net" + "strconv" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + testpb "google.golang.org/grpc/interop/grpc_testing" +) + +var ( + useTLS = flag.Bool("use_tls", false, "Connection uses TLS if true, else plain TCP") + certFile = flag.String("tls_cert_file", "testdata/server1.pem", "The TLS cert file") + keyFile = flag.String("tls_key_file", "testdata/server1.key", "The TLS key file") + port = flag.Int("port", 10000, "The server port") +) + +type testServer struct { +} + +func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return new(testpb.Empty), nil +} + +func newPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) { + if size < 0 { + return nil, fmt.Errorf("requested a response with invalid length %d", size) + } + body := make([]byte, size) + switch t { + case testpb.PayloadType_COMPRESSABLE: + case testpb.PayloadType_UNCOMPRESSABLE: + return nil, fmt.Errorf("payloadType UNCOMPRESSABLE is not supported") + default: + return nil, fmt.Errorf("unsupported payload type: %d", t) + } + return &testpb.Payload{ + Type: t.Enum(), + Body: body, + }, nil +} + +func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + pl, err := newPayload(in.GetResponseType(), in.GetResponseSize()) + if err != nil { + return nil, err + } + return &testpb.SimpleResponse{ + Payload: pl, + }, nil +} + +func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { + cs := args.GetResponseParameters() + for _, c := range cs { + if us := c.GetIntervalUs(); us > 0 { + time.Sleep(time.Duration(us) * time.Microsecond) + } + pl, err := newPayload(args.GetResponseType(), c.GetSize()) + if err != nil { + return err + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: pl, + }); err != nil { + return err + } + } + return nil +} + +func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInputCallServer) error { + var sum int + for { + in, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&testpb.StreamingInputCallResponse{ + AggregatedPayloadSize: proto.Int32(int32(sum)), + }) + } + if err != nil { + return err + } + p := in.GetPayload().GetBody() + sum += len(p) + } +} + +func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + return nil + } + if err != nil { + return err + } + cs := in.GetResponseParameters() + for _, c := range cs { + if us := c.GetIntervalUs(); us > 0 { + time.Sleep(time.Duration(us) * time.Microsecond) + } + pl, err := newPayload(in.GetResponseType(), c.GetSize()) + if err != nil { + return err + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: pl, + }); err != nil { + return err + } + } + } +} + +func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServer) error { + var msgBuf []*testpb.StreamingOutputCallRequest + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + break + } + if err != nil { + return err + } + msgBuf = append(msgBuf, in) + } + for _, m := range msgBuf { + cs := m.GetResponseParameters() + for _, c := range cs { + if us := c.GetIntervalUs(); us > 0 { + time.Sleep(time.Duration(us) * time.Microsecond) + } + pl, err := newPayload(m.GetResponseType(), c.GetSize()) + if err != nil { + return err + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: pl, + }); err != nil { + return err + } + } + } + return nil +} + +func main() { + flag.Parse() + p := strconv.Itoa(*port) + lis, err := net.Listen("tcp", ":"+p) + if err != nil { + grpclog.Fatalf("failed to listen: %v", err) + } + var opts []grpc.ServerOption + if *useTLS { + creds, err := credentials.NewServerTLSFromFile(*certFile, *keyFile) + if err != nil { + grpclog.Fatalf("Failed to generate credentials %v", err) + } + opts = []grpc.ServerOption{grpc.Creds(creds)} + } + server := grpc.NewServer(opts...) + testpb.RegisterTestServiceServer(server, &testServer{}) + server.Serve(lis) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/metadata/metadata.go b/Godeps/_workspace/src/google.golang.org/grpc/metadata/metadata.go new file mode 100644 index 0000000..5f26aba --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/metadata/metadata.go @@ -0,0 +1,146 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +// Package metadata define the structure of the metadata supported by gRPC library. +package metadata + +import ( + "encoding/base64" + "fmt" + "strings" + + "golang.org/x/net/context" +) + +const ( + binHdrSuffix = "-bin" +) + +// grpc-http2 requires ASCII header key and value (more detail can be found in +// "Requests" subsection in go/grpc-http2). +func isASCII(s string) bool { + for _, c := range s { + if c > 127 { + return false + } + } + return true +} + +// encodeKeyValue encodes key and value qualified for transmission via gRPC. +// Transmitting binary headers violates HTTP/2 spec. +// TODO(zhaoq): Maybe check if k is ASCII also. +func encodeKeyValue(k, v string) (string, string) { + if isASCII(v) { + return k, v + } + key := strings.ToLower(k + binHdrSuffix) + val := base64.StdEncoding.EncodeToString([]byte(v)) + return key, string(val) +} + +// DecodeKeyValue returns the original key and value corresponding to the +// encoded data in k, v. +func DecodeKeyValue(k, v string) (string, string, error) { + if !strings.HasSuffix(k, binHdrSuffix) { + return k, v, nil + } + key := k[:len(k)-len(binHdrSuffix)] + val, err := base64.StdEncoding.DecodeString(v) + if err != nil { + return "", "", err + } + return key, string(val), nil +} + +// MD is a mapping from metadata keys to values. Users should use the following +// two convenience functions New and Pairs to generate MD. +type MD map[string][]string + +// New creates a MD from given key-value map. +func New(m map[string]string) MD { + md := MD{} + for k, v := range m { + key, val := encodeKeyValue(k, v) + md[key] = append(md[key], val) + } + return md +} + +// Pairs returns an MD formed by the mapping of key, value ... +// Pairs panics if len(kv) is odd. +func Pairs(kv ...string) MD { + if len(kv)%2 == 1 { + panic(fmt.Sprintf("metadata: Pairs got the odd number of input pairs for metadata: %d", len(kv))) + } + md := MD{} + var k string + for i, s := range kv { + if i%2 == 0 { + k = s + continue + } + key, val := encodeKeyValue(k, s) + md[key] = append(md[key], val) + } + return md +} + +// Len returns the number of items in md. +func (md MD) Len() int { + return len(md) +} + +// Copy returns a copy of md. +func (md MD) Copy() MD { + out := MD{} + for k, v := range md { + for _, i := range v { + out[k] = append(out[k], i) + } + } + return out +} + +type mdKey struct{} + +// NewContext creates a new context with md attached. +func NewContext(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, mdKey{}, md) +} + +// FromContext returns the MD in ctx if it exists. +func FromContext(ctx context.Context) (md MD, ok bool) { + md, ok = ctx.Value(mdKey{}).(MD) + return +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/metadata/metadata_test.go b/Godeps/_workspace/src/google.golang.org/grpc/metadata/metadata_test.go new file mode 100644 index 0000000..ef3a1b8 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/metadata/metadata_test.go @@ -0,0 +1,82 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package metadata + +import ( + "reflect" + "testing" +) + +const binaryValue = string(128) + +func TestDecodeKeyValue(t *testing.T) { + for _, test := range []struct { + // input + kin string + vin string + // output + kout string + vout string + err error + }{ + {"a", "abc", "a", "abc", nil}, + {"key-bin", "Zm9vAGJhcg==", "key", "foo\x00bar", nil}, + {"key-bin", "woA=", "key", binaryValue, nil}, + } { + k, v, err := DecodeKeyValue(test.kin, test.vin) + if k != test.kout || !reflect.DeepEqual(v, test.vout) || !reflect.DeepEqual(err, test.err) { + t.Fatalf("DecodeKeyValue(%q, %q) = %q, %q, %v, want %q, %q, %v", test.kin, test.vin, k, v, err, test.kout, test.vout, test.err) + } + } +} + +func TestPairsMD(t *testing.T) { + for _, test := range []struct { + // input + kv []string + // output + md MD + }{ + {[]string{}, MD{}}, + {[]string{"k1", "v1", "k2", binaryValue}, New(map[string]string{ + "k1": "v1", + "k2-bin": "woA=", + })}, + } { + md := Pairs(test.kv...) + if !reflect.DeepEqual(md, test.md) { + t.Fatalf("Pairs(%v) = %v, want %v", test.kv, md, test.md) + } + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/naming/etcd/etcd.go b/Godeps/_workspace/src/google.golang.org/grpc/naming/etcd/etcd.go new file mode 100644 index 0000000..915e227 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/naming/etcd/etcd.go @@ -0,0 +1,145 @@ +package etcd + +import ( + "log" + "sync" + + etcdcl "github.com/coreos/etcd/client" + "golang.org/x/net/context" + "google.golang.org/grpc/naming" +) + +type kv struct { + key, value string +} + +// recvBuffer is an unbounded channel of *kv to record all the pending changes from etcd server. +type recvBuffer struct { + c chan *kv + mu sync.Mutex + stopping bool + backlog []*kv +} + +func newRecvBuffer() *recvBuffer { + b := &recvBuffer{ + c: make(chan *kv, 1), + } + return b +} + +func (b *recvBuffer) put(r *kv) { + b.mu.Lock() + defer b.mu.Unlock() + if b.stopping { + return + } + b.backlog = append(b.backlog, r) + select { + case b.c <- b.backlog[0]: + b.backlog = b.backlog[1:] + default: + } +} + +func (b *recvBuffer) load() { + b.mu.Lock() + defer b.mu.Unlock() + if b.stopping || len(b.backlog) == 0 { + return + } + select { + case b.c <- b.backlog[0]: + b.backlog = b.backlog[1:] + default: + } +} + +func (b *recvBuffer) get() <-chan *kv { + return b.c +} + +// stop terminates the recvBuffer. After it is called, the recvBuffer is not usable any more. +func (b *recvBuffer) stop() { + b.mu.Lock() + b.stopping = true + close(b.c) + b.mu.Unlock() +} + +type etcdNR struct { + kAPI etcdcl.KeysAPI + recv *recvBuffer + ctx context.Context + cancel context.CancelFunc +} + +// NewETCDNR creates an etcd NameResolver. +func NewETCDNR(cfg etcdcl.Config) (naming.Resolver, error) { + c, err := etcdcl.New(cfg) + if err != nil { + return nil, err + } + kAPI := etcdcl.NewKeysAPI(c) + ctx, cancel := context.WithCancel(context.Background()) + return &etcdNR{ + kAPI: kAPI, + recv: newRecvBuffer(), + ctx: ctx, + cancel: cancel, + }, nil +} + +// getNode builds the resulting key-value map starting from node recursively. +func getNode(node *etcdcl.Node, res map[string]string) { + if !node.Dir { + res[node.Key] = node.Value + return + } + for _, val := range node.Nodes { + getNode(val, res) + } +} + +func (nr *etcdNR) Get(target string) map[string]string { + resp, err := nr.kAPI.Get(nr.ctx, target, &etcdcl.GetOptions{Recursive: true, Sort: true}) + if err != nil { + log.Printf("etcdNR.Get(_) stopped: %v", err) + return nil + } + res := make(map[string]string) + getNode(resp.Node, res) + return res +} + +func (nr *etcdNR) Watch(target string) { + watcher := nr.kAPI.Watcher(target, &etcdcl.WatcherOptions{Recursive: true}) + for { + resp, err := watcher.Next(nr.ctx) + if err != nil { + log.Printf("etcdNR.Watch(_) stopped: %v", err) + break + } + if resp.Node.Dir { + continue + } + entry := &kv{key: resp.Node.Key, value: resp.Node.Value} + nr.recv.put(entry) + } +} + +func (nr *etcdNR) GetUpdate() (string, string) { + i := <-nr.recv.get() + nr.recv.load() + if i == nil { + return "", "" + } + // returns key and the corresponding value of the updated kv pair + return i.key, i.value + +} + +func (nr *etcdNR) Stop() { + nr.recv.stop() + nr.cancel() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/naming/naming.go b/Godeps/_workspace/src/google.golang.org/grpc/naming/naming.go new file mode 100644 index 0000000..a6a319f --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/naming/naming.go @@ -0,0 +1,13 @@ +package naming + +// Resolver dose name resolution and watches for the resolution changes. +type Resolver interface { + // Get gets a snapshot of the current name resolution results for target. + Get(target string) map[string]string + // Watch watches for the name resolution changes on target. It blocks until Stop() is invoked. The watch results are obtained via GetUpdate(). + Watch(target string) + // GetUpdate returns a name resolution change when watch is triggered. It blocks until it observes a change. The caller needs to call it again to get the next change. + GetUpdate() (string, string) + // Stop shuts down the NameResolver. + Stop() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/rpc_util.go b/Godeps/_workspace/src/google.golang.org/grpc/rpc_util.go new file mode 100644 index 0000000..a0f0b48 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/rpc_util.go @@ -0,0 +1,307 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math/rand" + "os" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/transport" +) + +// Codec defines the interface gRPC uses to encode and decode messages. +type Codec interface { + // Marshal returns the wire format of v. + Marshal(v interface{}) ([]byte, error) + // Unmarshal parses the wire format into v. + Unmarshal(data []byte, v interface{}) error + // String returns the name of the Codec implementation. The returned + // string will be used as part of content type in transmission. + String() string +} + +// protoCodec is a Codec implemetation with protobuf. It is the default codec for gRPC. +type protoCodec struct{} + +func (protoCodec) Marshal(v interface{}) ([]byte, error) { + return proto.Marshal(v.(proto.Message)) +} + +func (protoCodec) Unmarshal(data []byte, v interface{}) error { + return proto.Unmarshal(data, v.(proto.Message)) +} + +func (protoCodec) String() string { + return "proto" +} + +// CallOption configures a Call before it starts or extracts information from +// a Call after it completes. +type CallOption interface { + // before is called before the call is sent to any server. If before + // returns a non-nil error, the RPC fails with that error. + before(*callInfo) error + + // after is called after the call has completed. after cannot return an + // error, so any failures should be reported via output parameters. + after(*callInfo) +} + +type beforeCall func(c *callInfo) error + +func (o beforeCall) before(c *callInfo) error { return o(c) } +func (o beforeCall) after(c *callInfo) {} + +type afterCall func(c *callInfo) + +func (o afterCall) before(c *callInfo) error { return nil } +func (o afterCall) after(c *callInfo) { o(c) } + +// Header returns a CallOptions that retrieves the header metadata +// for a unary RPC. +func Header(md *metadata.MD) CallOption { + return afterCall(func(c *callInfo) { + *md = c.headerMD + }) +} + +// Trailer returns a CallOptions that retrieves the trailer metadata +// for a unary RPC. +func Trailer(md *metadata.MD) CallOption { + return afterCall(func(c *callInfo) { + *md = c.trailerMD + }) +} + +// The format of the payload: compressed or not? +type payloadFormat uint8 + +const ( + compressionNone payloadFormat = iota // no compression + compressionFlate + // More formats +) + +// parser reads complelete gRPC messages from the underlying reader. +type parser struct { + s io.Reader +} + +// msgFixedHeader defines the header of a gRPC message (go/grpc-wirefmt). +type msgFixedHeader struct { + T payloadFormat + Length uint32 +} + +// recvMsg is to read a complete gRPC message from the stream. It is blocking if +// the message has not been complete yet. It returns the message and its type, +// EOF is returned with nil msg and 0 pf if the entire stream is done. Other +// non-nil error is returned if something is wrong on reading. +func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) { + var hdr msgFixedHeader + if err := binary.Read(p.s, binary.BigEndian, &hdr); err != nil { + return 0, nil, err + } + if hdr.Length == 0 { + return hdr.T, nil, nil + } + msg = make([]byte, int(hdr.Length)) + if _, err := io.ReadFull(p.s, msg); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return 0, nil, err + } + return hdr.T, msg, nil +} + +// encode serializes msg and prepends the message header. If msg is nil, it +// generates the message header of 0 message length. +func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) { + var buf bytes.Buffer + // Write message fixed header. + buf.WriteByte(uint8(pf)) + var b []byte + var length uint32 + if msg != nil { + var err error + // TODO(zhaoq): optimize to reduce memory alloc and copying. + b, err = c.Marshal(msg) + if err != nil { + return nil, err + } + length = uint32(len(b)) + } + var szHdr [4]byte + binary.BigEndian.PutUint32(szHdr[:], length) + buf.Write(szHdr[:]) + buf.Write(b) + return buf.Bytes(), nil +} + +func recv(p *parser, c Codec, m interface{}) error { + pf, d, err := p.recvMsg() + if err != nil { + return err + } + switch pf { + case compressionNone: + if err := c.Unmarshal(d, m); err != nil { + return Errorf(codes.Internal, "grpc: %v", err) + } + default: + return Errorf(codes.Internal, "gprc: compression is not supported yet.") + } + return nil +} + +// rpcError defines the status from an RPC. +type rpcError struct { + code codes.Code + desc string +} + +func (e rpcError) Error() string { + return fmt.Sprintf("rpc error: code = %d desc = %q", e.code, e.desc) +} + +// Code returns the error code for err if it was produced by the rpc system. +// Otherwise, it returns codes.Unknown. +func Code(err error) codes.Code { + if err == nil { + return codes.OK + } + if e, ok := err.(rpcError); ok { + return e.code + } + return codes.Unknown +} + +// Errorf returns an error containing an error code and a description; +// Errorf returns nil if c is OK. +func Errorf(c codes.Code, format string, a ...interface{}) error { + if c == codes.OK { + return nil + } + return rpcError{ + code: c, + desc: fmt.Sprintf(format, a...), + } +} + +// toRPCErr converts an error into a rpcError. +func toRPCErr(err error) error { + switch e := err.(type) { + case transport.StreamError: + return rpcError{ + code: e.Code, + desc: e.Desc, + } + case transport.ConnectionError: + return rpcError{ + code: codes.Internal, + desc: e.Desc, + } + } + return Errorf(codes.Unknown, "%v", err) +} + +// convertCode converts a standard Go error into its canonical code. Note that +// this is only used to translate the error returned by the server applications. +func convertCode(err error) codes.Code { + switch err { + case nil: + return codes.OK + case io.EOF: + return codes.OutOfRange + case io.ErrClosedPipe, io.ErrNoProgress, io.ErrShortBuffer, io.ErrShortWrite, io.ErrUnexpectedEOF: + return codes.FailedPrecondition + case os.ErrInvalid: + return codes.InvalidArgument + case context.Canceled: + return codes.Canceled + case context.DeadlineExceeded: + return codes.DeadlineExceeded + } + switch { + case os.IsExist(err): + return codes.AlreadyExists + case os.IsNotExist(err): + return codes.NotFound + case os.IsPermission(err): + return codes.PermissionDenied + } + return codes.Unknown +} + +const ( + // how long to wait after the first failure before retrying + baseDelay = 1.0 * time.Second + // upper bound of backoff delay + maxDelay = 120 * time.Second + // backoff increases by this factor on each retry + backoffFactor = 1.6 + // backoff is randomized downwards by this factor + backoffJitter = 0.2 +) + +func backoff(retries int) (t time.Duration) { + if retries == 0 { + return baseDelay + } + backoff, max := float64(baseDelay), float64(maxDelay) + for backoff < max && retries > 0 { + backoff *= backoffFactor + retries-- + } + if backoff > max { + backoff = max + } + // Randomize backoff delays so that if a cluster of requests start at + // the same time, they won't operate in lockstep. + backoff *= 1 + backoffJitter*(rand.Float64()*2-1) + if backoff < 0 { + return 0 + } + return time.Duration(backoff) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/rpc_util_test.go b/Godeps/_workspace/src/google.golang.org/grpc/rpc_util_test.go new file mode 100644 index 0000000..1a29684 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/rpc_util_test.go @@ -0,0 +1,190 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "bytes" + "io" + "reflect" + "testing" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + perfpb "google.golang.org/grpc/test/codec_perf" + "google.golang.org/grpc/transport" +) + +func TestSimpleParsing(t *testing.T) { + for _, test := range []struct { + // input + p []byte + // outputs + err error + b []byte + pt payloadFormat + }{ + {nil, io.EOF, nil, compressionNone}, + {[]byte{0, 0, 0, 0, 0}, nil, nil, compressionNone}, + {[]byte{0, 0, 0, 0, 1, 'a'}, nil, []byte{'a'}, compressionNone}, + {[]byte{1, 0}, io.ErrUnexpectedEOF, nil, compressionNone}, + {[]byte{0, 0, 0, 0, 10, 'a'}, io.ErrUnexpectedEOF, nil, compressionNone}, + } { + buf := bytes.NewReader(test.p) + parser := &parser{buf} + pt, b, err := parser.recvMsg() + if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt { + t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err) + } + } +} + +func TestMultipleParsing(t *testing.T) { + // Set a byte stream consists of 3 messages with their headers. + p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'} + b := bytes.NewReader(p) + parser := &parser{b} + + wantRecvs := []struct { + pt payloadFormat + data []byte + }{ + {compressionNone, []byte("a")}, + {compressionNone, []byte("bc")}, + {compressionNone, []byte("d")}, + } + for i, want := range wantRecvs { + pt, data, err := parser.recvMsg() + if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) { + t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, ", + i, p, pt, data, err, want.pt, want.data) + } + } + + pt, data, err := parser.recvMsg() + if err != io.EOF { + t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v", + len(wantRecvs), p, pt, data, err, io.EOF) + } +} + +func TestEncode(t *testing.T) { + for _, test := range []struct { + // input + msg proto.Message + pt payloadFormat + // outputs + b []byte + err error + }{ + {nil, compressionNone, []byte{0, 0, 0, 0, 0}, nil}, + } { + b, err := encode(protoCodec{}, test.msg, test.pt) + if err != test.err || !bytes.Equal(b, test.b) { + t.Fatalf("encode(_, _, %d) = %v, %v\nwant %v, %v", test.pt, b, err, test.b, test.err) + } + } +} + +func TestToRPCErr(t *testing.T) { + for _, test := range []struct { + // input + errIn error + // outputs + errOut error + }{ + {transport.StreamErrorf(codes.Unknown, ""), Errorf(codes.Unknown, "")}, + {transport.ErrConnClosing, Errorf(codes.Internal, transport.ErrConnClosing.Desc)}, + } { + err := toRPCErr(test.errIn) + if err != test.errOut { + t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) + } + } +} + +func TestContextErr(t *testing.T) { + for _, test := range []struct { + // input + errIn error + // outputs + errOut transport.StreamError + }{ + {context.DeadlineExceeded, transport.StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)}, + {context.Canceled, transport.StreamErrorf(codes.Canceled, "%v", context.Canceled)}, + } { + err := transport.ContextErr(test.errIn) + if err != test.errOut { + t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) + } + } +} + +// bmEncode benchmarks encoding a Protocol Buffer message containing mSize +// bytes. +func bmEncode(b *testing.B, mSize int) { + msg := &perfpb.Buffer{Body: make([]byte, mSize)} + encoded, _ := encode(protoCodec{}, msg, compressionNone) + encodedSz := int64(len(encoded)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + encode(protoCodec{}, msg, compressionNone) + } + b.SetBytes(encodedSz) +} + +func BenchmarkEncode1B(b *testing.B) { + bmEncode(b, 1) +} + +func BenchmarkEncode1KiB(b *testing.B) { + bmEncode(b, 1024) +} + +func BenchmarkEncode8KiB(b *testing.B) { + bmEncode(b, 8*1024) +} + +func BenchmarkEncode64KiB(b *testing.B) { + bmEncode(b, 64*1024) +} + +func BenchmarkEncode512KiB(b *testing.B) { + bmEncode(b, 512*1024) +} + +func BenchmarkEncode1MiB(b *testing.B) { + bmEncode(b, 1024*1024) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/server.go b/Godeps/_workspace/src/google.golang.org/grpc/server.go new file mode 100644 index 0000000..feb2988 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/server.go @@ -0,0 +1,463 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "errors" + "fmt" + "io" + "net" + "reflect" + "strings" + "sync" + + "golang.org/x/net/context" + "golang.org/x/net/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/transport" +) + +type methodHandler func(srv interface{}, ctx context.Context, codec Codec, buf []byte) (interface{}, error) + +// MethodDesc represents an RPC service's method specification. +type MethodDesc struct { + MethodName string + Handler methodHandler +} + +// ServiceDesc represents an RPC service's specification. +type ServiceDesc struct { + ServiceName string + // The pointer to the service interface. Used to check whether the user + // provided implementation satisfies the interface requirements. + HandlerType interface{} + Methods []MethodDesc + Streams []StreamDesc +} + +// service consists of the information of the server serving this service and +// the methods in this service. +type service struct { + server interface{} // the server for service methods + md map[string]*MethodDesc + sd map[string]*StreamDesc +} + +// Server is a gRPC server to serve RPC requests. +type Server struct { + opts options + mu sync.Mutex + lis map[net.Listener]bool + conns map[transport.ServerTransport]bool + m map[string]*service // service name -> service info +} + +type options struct { + creds credentials.Credentials + codec Codec + maxConcurrentStreams uint32 +} + +// A ServerOption sets options. +type ServerOption func(*options) + +// CustomCodec returns a ServerOption that sets a codec for message marshaling and unmarshaling. +func CustomCodec(codec Codec) ServerOption { + return func(o *options) { + o.codec = codec + } +} + +// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number +// of concurrent streams to each ServerTransport. +func MaxConcurrentStreams(n uint32) ServerOption { + return func(o *options) { + o.maxConcurrentStreams = n + } +} + +// Creds returns a ServerOption that sets credentials for server connections. +func Creds(c credentials.Credentials) ServerOption { + return func(o *options) { + o.creds = c + } +} + +// NewServer creates a gRPC server which has no service registered and has not +// started to accept requests yet. +func NewServer(opt ...ServerOption) *Server { + var opts options + for _, o := range opt { + o(&opts) + } + if opts.codec == nil { + // Set the default codec. + opts.codec = protoCodec{} + } + return &Server{ + lis: make(map[net.Listener]bool), + opts: opts, + conns: make(map[transport.ServerTransport]bool), + m: make(map[string]*service), + } +} + +// RegisterService register a service and its implementation to the gRPC +// server. Called from the IDL generated code. This must be called before +// invoking Serve. +func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) { + ht := reflect.TypeOf(sd.HandlerType).Elem() + st := reflect.TypeOf(ss) + if !st.Implements(ht) { + grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht) + } + s.register(sd, ss) +} + +func (s *Server) register(sd *ServiceDesc, ss interface{}) { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.m[sd.ServiceName]; ok { + grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName) + } + srv := &service{ + server: ss, + md: make(map[string]*MethodDesc), + sd: make(map[string]*StreamDesc), + } + for i := range sd.Methods { + d := &sd.Methods[i] + srv.md[d.MethodName] = d + } + for i := range sd.Streams { + d := &sd.Streams[i] + srv.sd[d.StreamName] = d + } + s.m[sd.ServiceName] = srv +} + +var ( + // ErrServerStopped indicates that the operation is now illegal because of + // the server being stopped. + ErrServerStopped = errors.New("grpc: the server has been stopped") +) + +// Serve accepts incoming connections on the listener lis, creating a new +// ServerTransport and service goroutine for each. The service goroutines +// read gRPC request and then call the registered handlers to reply to them. +// Service returns when lis.Accept fails. +func (s *Server) Serve(lis net.Listener) error { + s.mu.Lock() + if s.lis == nil { + s.mu.Unlock() + return ErrServerStopped + } + s.lis[lis] = true + s.mu.Unlock() + defer func() { + lis.Close() + s.mu.Lock() + delete(s.lis, lis) + s.mu.Unlock() + }() + for { + c, err := lis.Accept() + if err != nil { + return err + } + var authInfo credentials.AuthInfo + if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok { + c, authInfo, err = creds.ServerHandshake(c) + if err != nil { + grpclog.Println("grpc: Server.Serve failed to complete security handshake.") + continue + } + } + s.mu.Lock() + if s.conns == nil { + s.mu.Unlock() + c.Close() + return nil + } + st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo) + if err != nil { + s.mu.Unlock() + c.Close() + grpclog.Println("grpc: Server.Serve failed to create ServerTransport: ", err) + continue + } + s.conns[st] = true + s.mu.Unlock() + + go func() { + st.HandleStreams(func(stream *transport.Stream) { + s.handleStream(st, stream) + }) + s.mu.Lock() + delete(s.conns, st) + s.mu.Unlock() + }() + } +} + +func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, pf payloadFormat, opts *transport.Options) error { + p, err := encode(s.opts.codec, msg, pf) + if err != nil { + // This typically indicates a fatal issue (e.g., memory + // corruption or hardware faults) the application program + // cannot handle. + // + // TODO(zhaoq): There exist other options also such as only closing the + // faulty stream locally and remotely (Other streams can keep going). Find + // the optimal option. + grpclog.Fatalf("grpc: Server failed to encode response %v", err) + } + return t.Write(stream, p, opts) +} + +func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc) (err error) { + var traceInfo traceInfo + if EnableTracing { + traceInfo.tr = trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) + defer traceInfo.tr.Finish() + traceInfo.firstLine.client = false + traceInfo.tr.LazyLog(&traceInfo.firstLine, false) + defer func() { + if err != nil && err != io.EOF { + traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + traceInfo.tr.SetError() + } + }() + } + p := &parser{s: stream} + for { + pf, req, err := p.recvMsg() + if err == io.EOF { + // The entire stream is done (for unary RPC only). + return err + } + if err != nil { + switch err := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + if err := t.WriteStatus(stream, err.Code, err.Desc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + } + default: + panic(fmt.Sprintf("grpc: Unexpected error (%T) from recvMsg: %v", err, err)) + } + return err + } + if traceInfo.tr != nil { + traceInfo.tr.LazyLog(&payload{sent: false, msg: req}, true) + } + switch pf { + case compressionNone: + statusCode := codes.OK + statusDesc := "" + reply, appErr := md.Handler(srv.server, stream.Context(), s.opts.codec, req) + if appErr != nil { + if err, ok := appErr.(rpcError); ok { + statusCode = err.code + statusDesc = err.desc + } else { + statusCode = convertCode(appErr) + statusDesc = appErr.Error() + } + if err := t.WriteStatus(stream, statusCode, statusDesc); err != nil { + grpclog.Printf("grpc: Server.processUnaryRPC failed to write status: %v", err) + return err + } + return nil + } + opts := &transport.Options{ + Last: true, + Delay: false, + } + if err := s.sendResponse(t, stream, reply, compressionNone, opts); err != nil { + switch err := err.(type) { + case transport.ConnectionError: + // Nothing to do here. + case transport.StreamError: + statusCode = err.Code + statusDesc = err.Desc + default: + statusCode = codes.Unknown + statusDesc = err.Error() + } + return err + } + if traceInfo.tr != nil { + traceInfo.tr.LazyLog(&payload{sent: true, msg: reply}, true) + } + return t.WriteStatus(stream, statusCode, statusDesc) + default: + panic(fmt.Sprintf("payload format to be supported: %d", pf)) + } + } +} + +func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, sd *StreamDesc) (err error) { + ss := &serverStream{ + t: t, + s: stream, + p: &parser{s: stream}, + codec: s.opts.codec, + tracing: EnableTracing, + } + if ss.tracing { + ss.traceInfo.tr = trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) + ss.traceInfo.firstLine.client = false + ss.traceInfo.tr.LazyLog(&ss.traceInfo.firstLine, false) + defer func() { + ss.mu.Lock() + if err != nil && err != io.EOF { + ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + ss.traceInfo.tr.SetError() + } + ss.traceInfo.tr.Finish() + ss.traceInfo.tr = nil + ss.mu.Unlock() + }() + } + if appErr := sd.Handler(srv.server, ss); appErr != nil { + if err, ok := appErr.(rpcError); ok { + ss.statusCode = err.code + ss.statusDesc = err.desc + } else { + ss.statusCode = convertCode(appErr) + ss.statusDesc = appErr.Error() + } + } + return t.WriteStatus(ss.s, ss.statusCode, ss.statusDesc) + +} + +func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { + sm := stream.Method() + if sm != "" && sm[0] == '/' { + sm = sm[1:] + } + pos := strings.LastIndex(sm, "/") + if pos == -1 { + if err := t.WriteStatus(stream, codes.InvalidArgument, fmt.Sprintf("malformed method name: %q", stream.Method())); err != nil { + grpclog.Printf("grpc: Server.handleStream failed to write status: %v", err) + } + return + } + service := sm[:pos] + method := sm[pos+1:] + srv, ok := s.m[service] + if !ok { + if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown service %v", service)); err != nil { + grpclog.Printf("grpc: Server.handleStream failed to write status: %v", err) + } + return + } + // Unary RPC or Streaming RPC? + if md, ok := srv.md[method]; ok { + s.processUnaryRPC(t, stream, srv, md) + return + } + if sd, ok := srv.sd[method]; ok { + s.processStreamingRPC(t, stream, srv, sd) + return + } + if err := t.WriteStatus(stream, codes.Unimplemented, fmt.Sprintf("unknown method %v", method)); err != nil { + grpclog.Printf("grpc: Server.handleStream failed to write status: %v", err) + } +} + +// Stop stops the gRPC server. Once Stop returns, the server stops accepting +// connection requests and closes all the connected connections. +func (s *Server) Stop() { + s.mu.Lock() + listeners := s.lis + s.lis = nil + cs := s.conns + s.conns = nil + s.mu.Unlock() + for lis := range listeners { + lis.Close() + } + for c := range cs { + c.Close() + } +} + +// TestingCloseConns closes all exiting transports but keeps s.lis accepting new +// connections. This is for test only now. +func (s *Server) TestingCloseConns() { + s.mu.Lock() + for c := range s.conns { + c.Close() + } + s.conns = make(map[transport.ServerTransport]bool) + s.mu.Unlock() +} + +// SendHeader sends header metadata. It may be called at most once from a unary +// RPC handler. The ctx is the RPC handler's Context or one derived from it. +func SendHeader(ctx context.Context, md metadata.MD) error { + if md.Len() == 0 { + return nil + } + stream, ok := transport.StreamFromContext(ctx) + if !ok { + return fmt.Errorf("grpc: failed to fetch the stream from the context %v", ctx) + } + t := stream.ServerTransport() + if t == nil { + grpclog.Fatalf("grpc: SendHeader: %v has no ServerTransport to send header metadata.", stream) + } + return t.WriteHeader(stream, md) +} + +// SetTrailer sets the trailer metadata that will be sent when an RPC returns. +// It may be called at most once from a unary RPC handler. The ctx is the RPC +// handler's Context or one derived from it. +func SetTrailer(ctx context.Context, md metadata.MD) error { + if md.Len() == 0 { + return nil + } + stream, ok := transport.StreamFromContext(ctx) + if !ok { + return fmt.Errorf("grpc: failed to fetch the stream from the context %v", ctx) + } + return stream.SetTrailer(md) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/stream.go b/Godeps/_workspace/src/google.golang.org/grpc/stream.go new file mode 100644 index 0000000..5c99bff --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/stream.go @@ -0,0 +1,346 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "errors" + "io" + "sync" + "time" + + "golang.org/x/net/context" + "golang.org/x/net/trace" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/transport" +) + +type streamHandler func(srv interface{}, stream ServerStream) error + +// StreamDesc represents a streaming RPC service's method specification. +type StreamDesc struct { + StreamName string + Handler streamHandler + + // At least one of these is true. + ServerStreams bool + ClientStreams bool +} + +// Stream defines the common interface a client or server stream has to satisfy. +type Stream interface { + // Context returns the context for this stream. + Context() context.Context + // SendMsg blocks until it sends m, the stream is done or the stream + // breaks. + // On error, it aborts the stream and returns an RPC status on client + // side. On server side, it simply returns the error to the caller. + // SendMsg is called by generated code. + SendMsg(m interface{}) error + // RecvMsg blocks until it receives a message or the stream is + // done. On client side, it returns io.EOF when the stream is done. On + // any other error, it aborts the streama nd returns an RPC status. On + // server side, it simply returns the error to the caller. + RecvMsg(m interface{}) error +} + +// ClientStream defines the interface a client stream has to satify. +type ClientStream interface { + // Header returns the header metedata received from the server if there + // is any. It blocks if the metadata is not ready to read. + Header() (metadata.MD, error) + // Trailer returns the trailer metadata from the server. It must be called + // after stream.Recv() returns non-nil error (including io.EOF) for + // bi-directional streaming and server streaming or stream.CloseAndRecv() + // returns for client streaming in order to receive trailer metadata if + // present. Otherwise, it could returns an empty MD even though trailer + // is present. + Trailer() metadata.MD + // CloseSend closes the send direction of the stream. It closes the stream + // when non-nil error is met. + CloseSend() error + Stream +} + +// NewClientStream creates a new Stream for the client side. This is called +// by generated code. +func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) { + // TODO(zhaoq): CallOption is omitted. Add support when it is needed. + callHdr := &transport.CallHdr{ + Host: cc.authority, + Method: method, + } + cs := &clientStream{ + desc: desc, + codec: cc.dopts.codec, + tracing: EnableTracing, + } + if cs.tracing { + cs.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) + cs.traceInfo.firstLine.client = true + if deadline, ok := ctx.Deadline(); ok { + cs.traceInfo.firstLine.deadline = deadline.Sub(time.Now()) + } + cs.traceInfo.tr.LazyLog(&cs.traceInfo.firstLine, false) + } + t, _, err := cc.wait(ctx, 0) + if err != nil { + return nil, toRPCErr(err) + } + s, err := t.NewStream(ctx, callHdr) + if err != nil { + return nil, toRPCErr(err) + } + cs.t = t + cs.s = s + cs.p = &parser{s: s} + return cs, nil +} + +// clientStream implements a client side Stream. +type clientStream struct { + t transport.ClientTransport + s *transport.Stream + p *parser + desc *StreamDesc + codec Codec + + tracing bool // set to EnableTracing when the clientStream is created. + + mu sync.Mutex // protects traceInfo + // traceInfo.tr is set when the clientStream is created (if EnableTracing is true), + // and is set to nil when the clientStream's finish method is called. + traceInfo traceInfo +} + +func (cs *clientStream) Context() context.Context { + return cs.s.Context() +} + +func (cs *clientStream) Header() (metadata.MD, error) { + m, err := cs.s.Header() + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + cs.t.CloseStream(cs.s, err) + } + } + return m, err +} + +func (cs *clientStream) Trailer() metadata.MD { + return cs.s.Trailer() +} + +func (cs *clientStream) SendMsg(m interface{}) (err error) { + if cs.tracing { + cs.mu.Lock() + if cs.traceInfo.tr != nil { + cs.traceInfo.tr.LazyLog(&payload{sent: true, msg: m}, true) + } + cs.mu.Unlock() + } + defer func() { + if err == nil || err == io.EOF { + return + } + if _, ok := err.(transport.ConnectionError); !ok { + cs.t.CloseStream(cs.s, err) + } + err = toRPCErr(err) + }() + out, err := encode(cs.codec, m, compressionNone) + if err != nil { + return transport.StreamErrorf(codes.Internal, "grpc: %v", err) + } + return cs.t.Write(cs.s, out, &transport.Options{Last: false}) +} + +func (cs *clientStream) RecvMsg(m interface{}) (err error) { + err = recv(cs.p, cs.codec, m) + defer func() { + // err != nil indicates the termination of the stream. + if err != nil { + cs.finish(err) + } + }() + if err == nil { + if cs.tracing { + cs.mu.Lock() + if cs.traceInfo.tr != nil { + cs.traceInfo.tr.LazyLog(&payload{sent: false, msg: m}, true) + } + cs.mu.Unlock() + } + if !cs.desc.ClientStreams || cs.desc.ServerStreams { + return + } + // Special handling for client streaming rpc. + err = recv(cs.p, cs.codec, m) + cs.t.CloseStream(cs.s, err) + if err == nil { + return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) + } + if err == io.EOF { + if cs.s.StatusCode() == codes.OK { + return nil + } + return Errorf(cs.s.StatusCode(), cs.s.StatusDesc()) + } + return toRPCErr(err) + } + if _, ok := err.(transport.ConnectionError); !ok { + cs.t.CloseStream(cs.s, err) + } + if err == io.EOF { + if cs.s.StatusCode() == codes.OK { + // Returns io.EOF to indicate the end of the stream. + return + } + return Errorf(cs.s.StatusCode(), cs.s.StatusDesc()) + } + return toRPCErr(err) +} + +func (cs *clientStream) CloseSend() (err error) { + err = cs.t.Write(cs.s, nil, &transport.Options{Last: true}) + if err == nil || err == io.EOF { + return + } + if _, ok := err.(transport.ConnectionError); !ok { + cs.t.CloseStream(cs.s, err) + } + err = toRPCErr(err) + return +} + +func (cs *clientStream) finish(err error) { + if !cs.tracing { + return + } + cs.mu.Lock() + defer cs.mu.Unlock() + if cs.traceInfo.tr != nil { + if err == nil || err == io.EOF { + cs.traceInfo.tr.LazyPrintf("RPC: [OK]") + } else { + cs.traceInfo.tr.LazyPrintf("RPC: [%v]", err) + cs.traceInfo.tr.SetError() + } + cs.traceInfo.tr.Finish() + cs.traceInfo.tr = nil + } +} + +// ServerStream defines the interface a server stream has to satisfy. +type ServerStream interface { + // SendHeader sends the header metadata. It should not be called + // after SendProto. It fails if called multiple times or if + // called after SendProto. + SendHeader(metadata.MD) error + // SetTrailer sets the trailer metadata which will be sent with the + // RPC status. + SetTrailer(metadata.MD) + Stream +} + +// serverStream implements a server side Stream. +type serverStream struct { + t transport.ServerTransport + s *transport.Stream + p *parser + codec Codec + statusCode codes.Code + statusDesc string + + tracing bool // set to EnableTracing when the serverStream is created. + + mu sync.Mutex // protects traceInfo + // traceInfo.tr is set when the serverStream is created (if EnableTracing is true), + // and is set to nil when the serverStream's finish method is called. + traceInfo traceInfo +} + +func (ss *serverStream) Context() context.Context { + return ss.s.Context() +} + +func (ss *serverStream) SendHeader(md metadata.MD) error { + return ss.t.WriteHeader(ss.s, md) +} + +func (ss *serverStream) SetTrailer(md metadata.MD) { + if md.Len() == 0 { + return + } + ss.s.SetTrailer(md) + return +} + +func (ss *serverStream) SendMsg(m interface{}) (err error) { + defer func() { + if ss.tracing { + ss.mu.Lock() + if err == nil { + ss.traceInfo.tr.LazyLog(&payload{sent: true, msg: m}, true) + } else { + ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + ss.traceInfo.tr.SetError() + } + + ss.mu.Unlock() + } + }() + out, err := encode(ss.codec, m, compressionNone) + if err != nil { + err = transport.StreamErrorf(codes.Internal, "grpc: %v", err) + return err + } + return ss.t.Write(ss.s, out, &transport.Options{Last: false}) +} + +func (ss *serverStream) RecvMsg(m interface{}) (err error) { + defer func() { + if ss.tracing { + ss.mu.Lock() + if err == nil { + ss.traceInfo.tr.LazyLog(&payload{sent: false, msg: m}, true) + } else if err != io.EOF { + ss.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true) + ss.traceInfo.tr.SetError() + } + ss.mu.Unlock() + } + }() + return recv(ss.p, ss.codec, m) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/test/codec_perf/perf.pb.go b/Godeps/_workspace/src/google.golang.org/grpc/test/codec_perf/perf.pb.go new file mode 100644 index 0000000..5f72c55 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/test/codec_perf/perf.pb.go @@ -0,0 +1,42 @@ +// Code generated by protoc-gen-go. +// source: perf.proto +// DO NOT EDIT! + +/* +Package codec_perf is a generated protocol buffer package. + +It is generated from these files: + perf.proto + +It has these top-level messages: + Buffer +*/ +package codec_perf + +import proto "github.com/golang/protobuf/proto" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = math.Inf + +// Buffer is a message that contains a body of bytes that is used to exercise +// encoding and decoding overheads. +type Buffer struct { + Body []byte `protobuf:"bytes,1,opt,name=body" json:"body,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Buffer) Reset() { *m = Buffer{} } +func (m *Buffer) String() string { return proto.CompactTextString(m) } +func (*Buffer) ProtoMessage() {} + +func (m *Buffer) GetBody() []byte { + if m != nil { + return m.Body + } + return nil +} + +func init() { +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/test/codec_perf/perf.proto b/Godeps/_workspace/src/google.golang.org/grpc/test/codec_perf/perf.proto new file mode 100644 index 0000000..4cf561f --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/test/codec_perf/perf.proto @@ -0,0 +1,11 @@ +// Messages used for performance tests that may not reference grpc directly for +// reasons of import cycles. +syntax = "proto2"; + +package codec.perf; + +// Buffer is a message that contains a body of bytes that is used to exercise +// encoding and decoding overheads. +message Buffer { + optional bytes body = 1; +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/test/end2end_test.go b/Godeps/_workspace/src/google.golang.org/grpc/test/end2end_test.go new file mode 100644 index 0000000..b1862b0 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/test/end2end_test.go @@ -0,0 +1,1012 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc_test + +import ( + "fmt" + "io" + "math" + "net" + "reflect" + "runtime" + "sync" + "syscall" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1alpha" + "google.golang.org/grpc/metadata" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +var ( + testMetadata = metadata.MD{ + "key1": []string{"value1"}, + "key2": []string{"value2"}, + } + testAppUA = "myApp1/1.0 myApp2/0.9" +) + +type testServer struct { + security string // indicate the authentication protocol used by this server. +} + +func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if md, ok := metadata.FromContext(ctx); ok { + // For testing purpose, returns an error if there is attached metadata other than + // the user agent set by the client application. + if _, ok := md["user-agent"]; !ok { + return nil, grpc.Errorf(codes.DataLoss, "got extra metadata") + } + var str []string + for _, entry := range md["user-agent"] { + str = append(str, "ua", entry) + } + grpc.SendHeader(ctx, metadata.Pairs(str...)) + } + return new(testpb.Empty), nil +} + +func newPayload(t testpb.PayloadType, size int32) *testpb.Payload { + if size < 0 { + grpclog.Fatalf("Requested a response with invalid length %d", size) + } + body := make([]byte, size) + switch t { + case testpb.PayloadType_COMPRESSABLE: + case testpb.PayloadType_UNCOMPRESSABLE: + grpclog.Fatalf("PayloadType UNCOMPRESSABLE is not supported") + default: + grpclog.Fatalf("Unsupported payload type: %d", t) + } + return &testpb.Payload{ + Type: t.Enum(), + Body: body, + } +} + +func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + md, ok := metadata.FromContext(ctx) + if ok { + if err := grpc.SendHeader(ctx, md); err != nil { + grpclog.Fatalf("grpc.SendHeader(%v, %v) = %v, want %v", ctx, md, err, nil) + } + grpc.SetTrailer(ctx, md) + } + if s.security != "" { + // Check Auth info + authInfo, ok := credentials.FromContext(ctx) + if !ok { + grpclog.Fatalf("Failed to get AuthInfo from ctx.") + } + var authType string + switch info := authInfo.(type) { + case credentials.TLSInfo: + authType = info.AuthType() + default: + grpclog.Fatalf("Unknown AuthInfo type") + } + if authType != s.security { + grpclog.Fatalf("Wrong auth type: got %q, want %q", authType, s.security) + } + } + + // Simulate some service delay. + time.Sleep(time.Second) + return &testpb.SimpleResponse{ + Payload: newPayload(in.GetResponseType(), in.GetResponseSize()), + }, nil +} + +func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { + if md, ok := metadata.FromContext(stream.Context()); ok { + // For testing purpose, returns an error if there is attached metadata. + if len(md) > 0 { + return grpc.Errorf(codes.DataLoss, "got extra metadata") + } + } + cs := args.GetResponseParameters() + for _, c := range cs { + if us := c.GetIntervalUs(); us > 0 { + time.Sleep(time.Duration(us) * time.Microsecond) + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: newPayload(args.GetResponseType(), c.GetSize()), + }); err != nil { + return err + } + } + return nil +} + +func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInputCallServer) error { + var sum int + for { + in, err := stream.Recv() + if err == io.EOF { + return stream.SendAndClose(&testpb.StreamingInputCallResponse{ + AggregatedPayloadSize: proto.Int32(int32(sum)), + }) + } + if err != nil { + return err + } + p := in.GetPayload().GetBody() + sum += len(p) + } +} + +func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { + md, ok := metadata.FromContext(stream.Context()) + if ok { + if err := stream.SendHeader(md); err != nil { + grpclog.Fatalf("%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) + } + stream.SetTrailer(md) + } + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + return nil + } + if err != nil { + return err + } + cs := in.GetResponseParameters() + for _, c := range cs { + if us := c.GetIntervalUs(); us > 0 { + time.Sleep(time.Duration(us) * time.Microsecond) + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: newPayload(in.GetResponseType(), c.GetSize()), + }); err != nil { + return err + } + } + } +} + +func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServer) error { + var msgBuf []*testpb.StreamingOutputCallRequest + for { + in, err := stream.Recv() + if err == io.EOF { + // read done. + break + } + if err != nil { + return err + } + msgBuf = append(msgBuf, in) + } + for _, m := range msgBuf { + cs := m.GetResponseParameters() + for _, c := range cs { + if us := c.GetIntervalUs(); us > 0 { + time.Sleep(time.Duration(us) * time.Microsecond) + } + if err := stream.Send(&testpb.StreamingOutputCallResponse{ + Payload: newPayload(m.GetResponseType(), c.GetSize()), + }); err != nil { + return err + } + } + } + return nil +} + +const tlsDir = "testdata/" + +func TestDialTimeout(t *testing.T) { + conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTimeout(time.Millisecond), grpc.WithBlock(), grpc.WithInsecure()) + if err == nil { + conn.Close() + } + if err != grpc.ErrClientConnTimeout { + t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, grpc.ErrClientConnTimeout) + } +} + +func TestTLSDialTimeout(t *testing.T) { + creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + if err != nil { + t.Fatalf("Failed to create credentials %v", err) + } + conn, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock()) + if err == nil { + conn.Close() + } + if err != grpc.ErrClientConnTimeout { + t.Fatalf("grpc.Dial(_, _) = %v, %v, want %v", conn, err, grpc.ErrClientConnTimeout) + } +} + +func TestCredentialsMisuse(t *testing.T) { + creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + if err != nil { + t.Fatalf("Failed to create credentials %v", err) + } + // Two conflicting credential configurations + if _, err := grpc.Dial("Non-Existent.Server:80", grpc.WithTransportCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock(), grpc.WithInsecure()); err != grpc.ErrCredentialsMisuse { + t.Fatalf("grpc.Dial(_, _) = _, %v, want _, %v", err, grpc.ErrCredentialsMisuse) + } + // security info on insecure connection + if _, err := grpc.Dial("Non-Existent.Server:80", grpc.WithPerRPCCredentials(creds), grpc.WithTimeout(time.Millisecond), grpc.WithBlock(), grpc.WithInsecure()); err != grpc.ErrCredentialsMisuse { + t.Fatalf("grpc.Dial(_, _) = _, %v, want _, %v", err, grpc.ErrCredentialsMisuse) + } +} + +func TestReconnectTimeout(t *testing.T) { + lis, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + _, port, err := net.SplitHostPort(lis.Addr().String()) + if err != nil { + t.Fatalf("Failed to parse listener address: %v", err) + } + addr := "localhost:" + port + conn, err := grpc.Dial(addr, grpc.WithTimeout(5*time.Second), grpc.WithBlock(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("Failed to dial to the server %q: %v", addr, err) + } + // Close unaccepted connection (i.e., conn). + lis.Close() + tc := testpb.NewTestServiceClient(conn) + waitC := make(chan struct{}) + go func() { + defer close(waitC) + argSize := 271828 + respSize := 314159 + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)), + } + if _, err := tc.UnaryCall(context.Background(), req); err == nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, , want _, non-nil") + } + }() + // Block untill reconnect times out. + <-waitC + if err := conn.Close(); err != grpc.ErrClientConnClosing { + t.Fatalf("%v.Close() = %v, want %v", conn, err, grpc.ErrClientConnClosing) + } +} + +func unixDialer(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) +} + +type env struct { + network string // The type of network such as tcp, unix, etc. + dialer func(addr string, timeout time.Duration) (net.Conn, error) + security string // The security protocol such as TLS, SSH, etc. +} + +func listTestEnv() []env { + if runtime.GOOS == "windows" { + return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}} + } + return []env{env{"tcp", nil, ""}, env{"tcp", nil, "tls"}, env{"unix", unixDialer, ""}, env{"unix", unixDialer, "tls"}} +} + +func setUp(hs *health.HealthServer, maxStream uint32, ua string, e env) (s *grpc.Server, cc *grpc.ClientConn) { + sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(maxStream)} + la := ":0" + switch e.network { + case "unix": + la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now()) + syscall.Unlink(la) + } + lis, err := net.Listen(e.network, la) + if err != nil { + grpclog.Fatalf("Failed to listen: %v", err) + } + if e.security == "tls" { + creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") + if err != nil { + grpclog.Fatalf("Failed to generate credentials %v", err) + } + sopts = append(sopts, grpc.Creds(creds)) + } + s = grpc.NewServer(sopts...) + if hs != nil { + healthpb.RegisterHealthCheckServer(s, hs) + } + testpb.RegisterTestServiceServer(s, &testServer{security: e.security}) + go s.Serve(lis) + addr := la + switch e.network { + case "unix": + default: + _, port, err := net.SplitHostPort(lis.Addr().String()) + if err != nil { + grpclog.Fatalf("Failed to parse listener address: %v", err) + } + addr = "localhost:" + port + } + if e.security == "tls" { + creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com") + if err != nil { + grpclog.Fatalf("Failed to create credentials %v", err) + } + cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithDialer(e.dialer), grpc.WithUserAgent(ua)) + } else { + cc, err = grpc.Dial(addr, grpc.WithDialer(e.dialer), grpc.WithInsecure(), grpc.WithUserAgent(ua)) + } + if err != nil { + grpclog.Fatalf("Dial(%q) = %v", addr, err) + } + return +} + +func tearDown(s *grpc.Server, cc *grpc.ClientConn) { + cc.Close() + s.Stop() +} + +func TestTimeoutOnDeadServer(t *testing.T) { + for _, e := range listTestEnv() { + testTimeoutOnDeadServer(t, e) + } +} + +func testTimeoutOnDeadServer(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + if ok := cc.WaitForStateChange(time.Second, grpc.Idle); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Idle, ok) + } + if ok := cc.WaitForStateChange(time.Second, grpc.Connecting); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Connecting, ok) + } + if cc.State() != grpc.Ready { + t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Ready) + } + if ok := cc.WaitForStateChange(time.Millisecond, grpc.Ready); ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want false", grpc.Ready, ok) + } + s.Stop() + // Set -1 as the timeout to make sure if transportMonitor gets error + // notification in time the failure path of the 1st invoke of + // ClientConn.wait hits the deadline exceeded error. + ctx, _ := context.WithTimeout(context.Background(), -1) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf("TestService/EmptyCall(%v, _) = _, error %v, want _, error code: %d", ctx, err, codes.DeadlineExceeded) + } + if ok := cc.WaitForStateChange(time.Second, grpc.Ready); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Ready, ok) + } + state := cc.State() + if state != grpc.Connecting && state != grpc.TransientFailure { + t.Fatalf("cc.State() = %s, want %s or %s", state, grpc.Connecting, grpc.TransientFailure) + } + cc.Close() +} + +func healthCheck(t time.Duration, cc *grpc.ClientConn, serviceName string) (*healthpb.HealthCheckResponse, error) { + ctx, _ := context.WithTimeout(context.Background(), t) + hc := healthpb.NewHealthCheckClient(cc) + req := &healthpb.HealthCheckRequest{ + Service: serviceName, + } + return hc.Check(ctx, req) +} + +func TestHealthCheckOnSuccess(t *testing.T) { + for _, e := range listTestEnv() { + testHealthCheckOnSuccess(t, e) + } +} + +func testHealthCheckOnSuccess(t *testing.T, e env) { + hs := health.NewHealthServer() + hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", 1) + s, cc := setUp(hs, math.MaxUint32, "", e) + defer tearDown(s, cc) + if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1alpha.HealthCheck"); err != nil { + t.Fatalf("HealthCheck/Check(_, _) = _, %v, want _, ", err) + } +} + +func TestHealthCheckOnFailure(t *testing.T) { + for _, e := range listTestEnv() { + testHealthCheckOnFailure(t, e) + } +} + +func testHealthCheckOnFailure(t *testing.T, e env) { + hs := health.NewHealthServer() + hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", 1) + s, cc := setUp(hs, math.MaxUint32, "", e) + defer tearDown(s, cc) + if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1alpha.HealthCheck"); err != grpc.Errorf(codes.DeadlineExceeded, "context deadline exceeded") { + t.Fatalf("HealthCheck/Check(_, _) = _, %v, want _, error code %d", err, codes.DeadlineExceeded) + } +} + +func TestHealthCheckOff(t *testing.T) { + for _, e := range listTestEnv() { + testHealthCheckOff(t, e) + } +} + +func testHealthCheckOff(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + defer tearDown(s, cc) + if _, err := healthCheck(1*time.Second, cc, ""); err != grpc.Errorf(codes.Unimplemented, "unknown service grpc.health.v1alpha.HealthCheck") { + t.Fatalf("HealthCheck/Check(_, _) = _, %v, want _, error code %d", err, codes.Unimplemented) + } +} + +func TestHealthCheckServingStatus(t *testing.T) { + for _, e := range listTestEnv() { + testHealthCheckServingStatus(t, e) + } +} + +func testHealthCheckServingStatus(t *testing.T, e env) { + hs := health.NewHealthServer() + s, cc := setUp(hs, math.MaxUint32, "", e) + defer tearDown(s, cc) + out, err := healthCheck(1*time.Second, cc, "") + if err != nil { + t.Fatalf("HealthCheck/Check(_, _) = _, %v, want _, ", err) + } + if out.Status != healthpb.HealthCheckResponse_SERVING { + t.Fatalf("Got the serving status %v, want SERVING", out.Status) + } + if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1alpha.HealthCheck"); err != grpc.Errorf(codes.NotFound, "unknown service") { + t.Fatalf("HealthCheck/Check(_, _) = _, %v, want _, error code %d", err, codes.NotFound) + } + hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", healthpb.HealthCheckResponse_SERVING) + out, err = healthCheck(1*time.Second, cc, "grpc.health.v1alpha.HealthCheck") + if err != nil { + t.Fatalf("HealthCheck/Check(_, _) = _, %v, want _, ", err) + } + if out.Status != healthpb.HealthCheckResponse_SERVING { + t.Fatalf("Got the serving status %v, want SERVING", out.Status) + } + hs.SetServingStatus("grpc.health.v1alpha.HealthCheck", healthpb.HealthCheckResponse_NOT_SERVING) + out, err = healthCheck(1*time.Second, cc, "grpc.health.v1alpha.HealthCheck") + if err != nil { + t.Fatalf("HealthCheck/Check(_, _) = _, %v, want _, ", err) + } + if out.Status != healthpb.HealthCheckResponse_NOT_SERVING { + t.Fatalf("Got the serving status %v, want NOT_SERVING", out.Status) + } + +} + +func TestEmptyUnaryWithUserAgent(t *testing.T) { + for _, e := range listTestEnv() { + testEmptyUnaryWithUserAgent(t, e) + } +} + +func testEmptyUnaryWithUserAgent(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, testAppUA, e) + // Wait until cc is connected. + if ok := cc.WaitForStateChange(time.Second, grpc.Idle); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Idle, ok) + } + if ok := cc.WaitForStateChange(10*time.Second, grpc.Connecting); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Connecting, ok) + } + if cc.State() != grpc.Ready { + t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Ready) + } + if ok := cc.WaitForStateChange(time.Second, grpc.Ready); ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want false", grpc.Ready, ok) + } + tc := testpb.NewTestServiceClient(cc) + var header metadata.MD + reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Header(&header)) + if err != nil || !proto.Equal(&testpb.Empty{}, reply) { + t.Fatalf("TestService/EmptyCall(_, _) = %v, %v, want %v, ", reply, err, &testpb.Empty{}) + } + if v, ok := header["ua"]; !ok || v[0] != testAppUA { + t.Fatalf("header[\"ua\"] = %q, %t, want %q, true", v, ok, testAppUA) + } + tearDown(s, cc) + if ok := cc.WaitForStateChange(5*time.Second, grpc.Ready); !ok { + t.Fatalf("cc.WaitForStateChange(_, %s) = %t, want true", grpc.Ready, ok) + } + if cc.State() != grpc.Shutdown { + t.Fatalf("cc.State() = %s, want %s", cc.State(), grpc.Shutdown) + } +} + +func TestFailedEmptyUnary(t *testing.T) { + for _, e := range listTestEnv() { + testFailedEmptyUnary(t, e) + } +} + +func testFailedEmptyUnary(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + ctx := metadata.NewContext(context.Background(), testMetadata) + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, grpc.Errorf(codes.DataLoss, "got extra metadata")) + } +} + +func TestLargeUnary(t *testing.T) { + for _, e := range listTestEnv() { + testLargeUnary(t, e) + } +} + +func testLargeUnary(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + argSize := 271828 + respSize := 314159 + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)), + } + reply, err := tc.UnaryCall(context.Background(), req) + if err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) + } + pt := reply.GetPayload().GetType() + ps := len(reply.GetPayload().GetBody()) + if pt != testpb.PayloadType_COMPRESSABLE || ps != respSize { + t.Fatalf("Got the reply with type %d len %d; want %d, %d", pt, ps, testpb.PayloadType_COMPRESSABLE, respSize) + } +} + +func TestMetadataUnaryRPC(t *testing.T) { + for _, e := range listTestEnv() { + testMetadataUnaryRPC(t, e) + } +} + +func testMetadataUnaryRPC(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + argSize := 2718 + respSize := 314 + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)), + } + var header, trailer metadata.MD + ctx := metadata.NewContext(context.Background(), testMetadata) + _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.Trailer(&trailer)) + if err != nil { + t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) + } + if !reflect.DeepEqual(testMetadata, header) { + t.Fatalf("Received header metadata %v, want %v", header, testMetadata) + } + if !reflect.DeepEqual(testMetadata, trailer) { + t.Fatalf("Received trailer metadata %v, want %v", trailer, testMetadata) + } +} + +func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup) { + argSize := 2718 + respSize := 314 + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)), + } + reply, err := tc.UnaryCall(context.Background(), req) + if err != nil { + t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) + } + pt := reply.GetPayload().GetType() + ps := len(reply.GetPayload().GetBody()) + if pt != testpb.PayloadType_COMPRESSABLE || ps != respSize { + t.Fatalf("Got the reply with type %d len %d; want %d, %d", pt, ps, testpb.PayloadType_COMPRESSABLE, respSize) + } + wg.Done() +} + +func TestRetry(t *testing.T) { + for _, e := range listTestEnv() { + testRetry(t, e) + } +} + +// This test mimics a user who sends 1000 RPCs concurrently on a faulty transport. +// TODO(zhaoq): Refactor to make this clearer and add more cases to test racy +// and error-prone paths. +func testRetry(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + var wg sync.WaitGroup + wg.Add(1) + go func() { + time.Sleep(1 * time.Second) + // The server shuts down the network connection to make a + // transport error which will be detected by the client side + // code. + s.TestingCloseConns() + wg.Done() + }() + // All these RPCs should succeed eventually. + for i := 0; i < 1000; i++ { + time.Sleep(2 * time.Millisecond) + wg.Add(1) + go performOneRPC(t, tc, &wg) + } + wg.Wait() +} + +func TestRPCTimeout(t *testing.T) { + for _, e := range listTestEnv() { + testRPCTimeout(t, e) + } +} + +// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism. +func testRPCTimeout(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + argSize := 2718 + respSize := 314 + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)), + } + for i := -1; i <= 10; i++ { + ctx, _ := context.WithTimeout(context.Background(), time.Duration(i)*time.Millisecond) + reply, err := tc.UnaryCall(ctx, req) + if grpc.Code(err) != codes.DeadlineExceeded { + t.Fatalf(`TestService/UnaryCallv(_, _) = %v, %v; want , error code: %d`, reply, err, codes.DeadlineExceeded) + } + } +} + +func TestCancel(t *testing.T) { + for _, e := range listTestEnv() { + testCancel(t, e) + } +} + +func testCancel(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + argSize := 2718 + respSize := 314 + req := &testpb.SimpleRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseSize: proto.Int32(int32(respSize)), + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(argSize)), + } + ctx, cancel := context.WithCancel(context.Background()) + time.AfterFunc(1*time.Millisecond, cancel) + reply, err := tc.UnaryCall(ctx, req) + if grpc.Code(err) != codes.Canceled { + t.Fatalf(`TestService/UnaryCall(_, _) = %v, %v; want , error code: %d`, reply, err, codes.Canceled) + } +} + +// The following tests the gRPC streaming RPC implementations. +// TODO(zhaoq): Have better coverage on error cases. +var ( + reqSizes = []int{27182, 8, 1828, 45904} + respSizes = []int{31415, 9, 2653, 58979} +) + +func TestPingPong(t *testing.T) { + for _, e := range listTestEnv() { + testPingPong(t, e) + } +} + +func testPingPong(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + stream, err := tc.FullDuplexCall(context.Background()) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + var index int + for index < len(reqSizes) { + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(respSizes[index])), + }, + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(reqSizes[index])), + } + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + reply, err := stream.Recv() + if err != nil { + t.Fatalf("%v.Recv() = %v, want ", stream, err) + } + pt := reply.GetPayload().GetType() + if pt != testpb.PayloadType_COMPRESSABLE { + t.Fatalf("Got the reply of type %d, want %d", pt, testpb.PayloadType_COMPRESSABLE) + } + size := len(reply.GetPayload().GetBody()) + if size != int(respSizes[index]) { + t.Fatalf("Got reply body of length %d, want %d", size, respSizes[index]) + } + index++ + } + if err := stream.CloseSend(); err != nil { + t.Fatalf("%v.CloseSend() got %v, want %v", stream, err, nil) + } + if _, err := stream.Recv(); err != io.EOF { + t.Fatalf("%v failed to complele the ping pong test: %v", stream, err) + } +} + +func TestMetadataStreamingRPC(t *testing.T) { + for _, e := range listTestEnv() { + testMetadataStreamingRPC(t, e) + } +} + +func testMetadataStreamingRPC(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + ctx := metadata.NewContext(context.Background(), testMetadata) + stream, err := tc.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) + } + go func() { + headerMD, err := stream.Header() + if e.security == "tls" { + delete(headerMD, "transport_security_type") + } + if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { + t.Errorf("#1 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) + } + // test the cached value. + headerMD, err = stream.Header() + if err != nil || !reflect.DeepEqual(testMetadata, headerMD) { + t.Errorf("#2 %v.Header() = %v, %v, want %v, ", stream, headerMD, err, testMetadata) + } + var index int + for index < len(reqSizes) { + respParam := []*testpb.ResponseParameters{ + { + Size: proto.Int32(int32(respSizes[index])), + }, + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + Payload: newPayload(testpb.PayloadType_COMPRESSABLE, int32(reqSizes[index])), + } + if err := stream.Send(req); err != nil { + t.Errorf("%v.Send(%v) = %v, want ", stream, req, err) + return + } + index++ + } + // Tell the server we're done sending args. + stream.CloseSend() + }() + for { + if _, err := stream.Recv(); err != nil { + break + } + } + trailerMD := stream.Trailer() + if !reflect.DeepEqual(testMetadata, trailerMD) { + t.Fatalf("%v.Trailer() = %v, want %v", stream, trailerMD, testMetadata) + } +} + +func TestServerStreaming(t *testing.T) { + for _, e := range listTestEnv() { + testServerStreaming(t, e) + } +} + +func testServerStreaming(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + respParam := make([]*testpb.ResponseParameters, len(respSizes)) + for i, s := range respSizes { + respParam[i] = &testpb.ResponseParameters{ + Size: proto.Int32(int32(s)), + } + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + } + stream, err := tc.StreamingOutputCall(context.Background(), req) + if err != nil { + t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want ", tc, err) + } + var rpcStatus error + var respCnt int + var index int + for { + reply, err := stream.Recv() + if err != nil { + rpcStatus = err + break + } + pt := reply.GetPayload().GetType() + if pt != testpb.PayloadType_COMPRESSABLE { + t.Fatalf("Got the reply of type %d, want %d", pt, testpb.PayloadType_COMPRESSABLE) + } + size := len(reply.GetPayload().GetBody()) + if size != int(respSizes[index]) { + t.Fatalf("Got reply body of length %d, want %d", size, respSizes[index]) + } + index++ + respCnt++ + } + if rpcStatus != io.EOF { + t.Fatalf("Failed to finish the server streaming rpc: %v, want ", rpcStatus) + } + if respCnt != len(respSizes) { + t.Fatalf("Got %d reply, want %d", len(respSizes), respCnt) + } +} + +func TestFailedServerStreaming(t *testing.T) { + for _, e := range listTestEnv() { + testFailedServerStreaming(t, e) + } +} + +func testFailedServerStreaming(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + respParam := make([]*testpb.ResponseParameters, len(respSizes)) + for i, s := range respSizes { + respParam[i] = &testpb.ResponseParameters{ + Size: proto.Int32(int32(s)), + } + } + req := &testpb.StreamingOutputCallRequest{ + ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), + ResponseParameters: respParam, + } + ctx := metadata.NewContext(context.Background(), testMetadata) + stream, err := tc.StreamingOutputCall(ctx, req) + if err != nil { + t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want ", tc, err) + } + if _, err := stream.Recv(); err != grpc.Errorf(codes.DataLoss, "got extra metadata") { + t.Fatalf("%v.Recv() = _, %v, want _, %v", stream, err, grpc.Errorf(codes.DataLoss, "got extra metadata")) + } +} + +func TestClientStreaming(t *testing.T) { + for _, e := range listTestEnv() { + testClientStreaming(t, e) + } +} + +func testClientStreaming(t *testing.T, e env) { + s, cc := setUp(nil, math.MaxUint32, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + stream, err := tc.StreamingInputCall(context.Background()) + if err != nil { + t.Fatalf("%v.StreamingInputCall(_) = _, %v, want ", tc, err) + } + var sum int + for _, s := range reqSizes { + pl := newPayload(testpb.PayloadType_COMPRESSABLE, int32(s)) + req := &testpb.StreamingInputCallRequest{ + Payload: pl, + } + if err := stream.Send(req); err != nil { + t.Fatalf("%v.Send(%v) = %v, want ", stream, req, err) + } + sum += s + } + reply, err := stream.CloseAndRecv() + if err != nil { + t.Fatalf("%v.CloseAndRecv() got error %v, want %v", stream, err, nil) + } + if reply.GetAggregatedPayloadSize() != int32(sum) { + t.Fatalf("%v.CloseAndRecv().GetAggregatePayloadSize() = %v; want %v", stream, reply.GetAggregatedPayloadSize(), sum) + } +} + +func TestExceedMaxStreamsLimit(t *testing.T) { + for _, e := range listTestEnv() { + testExceedMaxStreamsLimit(t, e) + } +} + +func testExceedMaxStreamsLimit(t *testing.T, e env) { + // Only allows 1 live stream per server transport. + s, cc := setUp(nil, 1, "", e) + tc := testpb.NewTestServiceClient(cc) + defer tearDown(s, cc) + done := make(chan struct{}) + ch := make(chan int) + go func() { + for { + select { + case <-time.After(5 * time.Millisecond): + ch <- 0 + case <-time.After(5 * time.Second): + close(done) + return + } + } + }() + // Loop until a stream creation hangs due to the new max stream setting. + for { + select { + case <-ch: + ctx, _ := context.WithTimeout(context.Background(), time.Second) + if _, err := tc.StreamingInputCall(ctx); err != nil { + if grpc.Code(err) == codes.DeadlineExceeded { + return + } + t.Fatalf("%v.StreamingInputCall(_) = %v, want ", tc, err) + } + case <-done: + t.Fatalf("Client has not received the max stream setting in 5 seconds.") + } + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/test/grpc_testing/test.pb.go b/Godeps/_workspace/src/google.golang.org/grpc/test/grpc_testing/test.pb.go new file mode 100644 index 0000000..b25e98b --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/test/grpc_testing/test.pb.go @@ -0,0 +1,702 @@ +// Code generated by protoc-gen-go. +// source: test.proto +// DO NOT EDIT! + +/* +Package grpc_testing is a generated protocol buffer package. + +It is generated from these files: + test.proto + +It has these top-level messages: + Empty + Payload + SimpleRequest + SimpleResponse + StreamingInputCallRequest + StreamingInputCallResponse + ResponseParameters + StreamingOutputCallRequest + StreamingOutputCallResponse +*/ +package grpc_testing + +import proto "github.com/golang/protobuf/proto" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = math.Inf + +// The type of payload that should be returned. +type PayloadType int32 + +const ( + // Compressable text format. + PayloadType_COMPRESSABLE PayloadType = 0 + // Uncompressable binary format. + PayloadType_UNCOMPRESSABLE PayloadType = 1 + // Randomly chosen from all other formats defined in this enum. + PayloadType_RANDOM PayloadType = 2 +) + +var PayloadType_name = map[int32]string{ + 0: "COMPRESSABLE", + 1: "UNCOMPRESSABLE", + 2: "RANDOM", +} +var PayloadType_value = map[string]int32{ + "COMPRESSABLE": 0, + "UNCOMPRESSABLE": 1, + "RANDOM": 2, +} + +func (x PayloadType) Enum() *PayloadType { + p := new(PayloadType) + *p = x + return p +} +func (x PayloadType) String() string { + return proto.EnumName(PayloadType_name, int32(x)) +} +func (x *PayloadType) UnmarshalJSON(data []byte) error { + value, err := proto.UnmarshalJSONEnum(PayloadType_value, data, "PayloadType") + if err != nil { + return err + } + *x = PayloadType(value) + return nil +} + +type Empty struct { + XXX_unrecognized []byte `json:"-"` +} + +func (m *Empty) Reset() { *m = Empty{} } +func (m *Empty) String() string { return proto.CompactTextString(m) } +func (*Empty) ProtoMessage() {} + +// A block of data, to simply increase gRPC message size. +type Payload struct { + // The type of data in body. + Type *PayloadType `protobuf:"varint,1,opt,name=type,enum=grpc.testing.PayloadType" json:"type,omitempty"` + // Primary contents of payload. + Body []byte `protobuf:"bytes,2,opt,name=body" json:"body,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *Payload) Reset() { *m = Payload{} } +func (m *Payload) String() string { return proto.CompactTextString(m) } +func (*Payload) ProtoMessage() {} + +func (m *Payload) GetType() PayloadType { + if m != nil && m.Type != nil { + return *m.Type + } + return PayloadType_COMPRESSABLE +} + +func (m *Payload) GetBody() []byte { + if m != nil { + return m.Body + } + return nil +} + +// Unary request. +type SimpleRequest struct { + // Desired payload type in the response from the server. + // If response_type is RANDOM, server randomly chooses one from other formats. + ResponseType *PayloadType `protobuf:"varint,1,opt,name=response_type,enum=grpc.testing.PayloadType" json:"response_type,omitempty"` + // Desired payload size in the response from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + ResponseSize *int32 `protobuf:"varint,2,opt,name=response_size" json:"response_size,omitempty"` + // Optional input payload sent along with the request. + Payload *Payload `protobuf:"bytes,3,opt,name=payload" json:"payload,omitempty"` + // Whether SimpleResponse should include username. + FillUsername *bool `protobuf:"varint,4,opt,name=fill_username" json:"fill_username,omitempty"` + // Whether SimpleResponse should include OAuth scope. + FillOauthScope *bool `protobuf:"varint,5,opt,name=fill_oauth_scope" json:"fill_oauth_scope,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SimpleRequest) Reset() { *m = SimpleRequest{} } +func (m *SimpleRequest) String() string { return proto.CompactTextString(m) } +func (*SimpleRequest) ProtoMessage() {} + +func (m *SimpleRequest) GetResponseType() PayloadType { + if m != nil && m.ResponseType != nil { + return *m.ResponseType + } + return PayloadType_COMPRESSABLE +} + +func (m *SimpleRequest) GetResponseSize() int32 { + if m != nil && m.ResponseSize != nil { + return *m.ResponseSize + } + return 0 +} + +func (m *SimpleRequest) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +func (m *SimpleRequest) GetFillUsername() bool { + if m != nil && m.FillUsername != nil { + return *m.FillUsername + } + return false +} + +func (m *SimpleRequest) GetFillOauthScope() bool { + if m != nil && m.FillOauthScope != nil { + return *m.FillOauthScope + } + return false +} + +// Unary response, as configured by the request. +type SimpleResponse struct { + // Payload to increase message size. + Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"` + // The user the request came from, for verifying authentication was + // successful when the client expected it. + Username *string `protobuf:"bytes,2,opt,name=username" json:"username,omitempty"` + // OAuth scope. + OauthScope *string `protobuf:"bytes,3,opt,name=oauth_scope" json:"oauth_scope,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *SimpleResponse) Reset() { *m = SimpleResponse{} } +func (m *SimpleResponse) String() string { return proto.CompactTextString(m) } +func (*SimpleResponse) ProtoMessage() {} + +func (m *SimpleResponse) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +func (m *SimpleResponse) GetUsername() string { + if m != nil && m.Username != nil { + return *m.Username + } + return "" +} + +func (m *SimpleResponse) GetOauthScope() string { + if m != nil && m.OauthScope != nil { + return *m.OauthScope + } + return "" +} + +// Client-streaming request. +type StreamingInputCallRequest struct { + // Optional input payload sent along with the request. + Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingInputCallRequest) Reset() { *m = StreamingInputCallRequest{} } +func (m *StreamingInputCallRequest) String() string { return proto.CompactTextString(m) } +func (*StreamingInputCallRequest) ProtoMessage() {} + +func (m *StreamingInputCallRequest) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +// Client-streaming response. +type StreamingInputCallResponse struct { + // Aggregated size of payloads received from the client. + AggregatedPayloadSize *int32 `protobuf:"varint,1,opt,name=aggregated_payload_size" json:"aggregated_payload_size,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingInputCallResponse) Reset() { *m = StreamingInputCallResponse{} } +func (m *StreamingInputCallResponse) String() string { return proto.CompactTextString(m) } +func (*StreamingInputCallResponse) ProtoMessage() {} + +func (m *StreamingInputCallResponse) GetAggregatedPayloadSize() int32 { + if m != nil && m.AggregatedPayloadSize != nil { + return *m.AggregatedPayloadSize + } + return 0 +} + +// Configuration for a particular response. +type ResponseParameters struct { + // Desired payload sizes in responses from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + Size *int32 `protobuf:"varint,1,opt,name=size" json:"size,omitempty"` + // Desired interval between consecutive responses in the response stream in + // microseconds. + IntervalUs *int32 `protobuf:"varint,2,opt,name=interval_us" json:"interval_us,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *ResponseParameters) Reset() { *m = ResponseParameters{} } +func (m *ResponseParameters) String() string { return proto.CompactTextString(m) } +func (*ResponseParameters) ProtoMessage() {} + +func (m *ResponseParameters) GetSize() int32 { + if m != nil && m.Size != nil { + return *m.Size + } + return 0 +} + +func (m *ResponseParameters) GetIntervalUs() int32 { + if m != nil && m.IntervalUs != nil { + return *m.IntervalUs + } + return 0 +} + +// Server-streaming request. +type StreamingOutputCallRequest struct { + // Desired payload type in the response from the server. + // If response_type is RANDOM, the payload from each response in the stream + // might be of different types. This is to simulate a mixed type of payload + // stream. + ResponseType *PayloadType `protobuf:"varint,1,opt,name=response_type,enum=grpc.testing.PayloadType" json:"response_type,omitempty"` + // Configuration for each expected response message. + ResponseParameters []*ResponseParameters `protobuf:"bytes,2,rep,name=response_parameters" json:"response_parameters,omitempty"` + // Optional input payload sent along with the request. + Payload *Payload `protobuf:"bytes,3,opt,name=payload" json:"payload,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingOutputCallRequest) Reset() { *m = StreamingOutputCallRequest{} } +func (m *StreamingOutputCallRequest) String() string { return proto.CompactTextString(m) } +func (*StreamingOutputCallRequest) ProtoMessage() {} + +func (m *StreamingOutputCallRequest) GetResponseType() PayloadType { + if m != nil && m.ResponseType != nil { + return *m.ResponseType + } + return PayloadType_COMPRESSABLE +} + +func (m *StreamingOutputCallRequest) GetResponseParameters() []*ResponseParameters { + if m != nil { + return m.ResponseParameters + } + return nil +} + +func (m *StreamingOutputCallRequest) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +// Server-streaming response, as configured by the request and parameters. +type StreamingOutputCallResponse struct { + // Payload to increase response size. + Payload *Payload `protobuf:"bytes,1,opt,name=payload" json:"payload,omitempty"` + XXX_unrecognized []byte `json:"-"` +} + +func (m *StreamingOutputCallResponse) Reset() { *m = StreamingOutputCallResponse{} } +func (m *StreamingOutputCallResponse) String() string { return proto.CompactTextString(m) } +func (*StreamingOutputCallResponse) ProtoMessage() {} + +func (m *StreamingOutputCallResponse) GetPayload() *Payload { + if m != nil { + return m.Payload + } + return nil +} + +func init() { + proto.RegisterEnum("grpc.testing.PayloadType", PayloadType_name, PayloadType_value) +} + +// Client API for TestService service + +type TestServiceClient interface { + // One empty request followed by one empty response. + EmptyCall(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) + // One request followed by one response. + // The server returns the client payload as-is. + UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) + // One request followed by a sequence of responses (streamed download). + // The server returns the payload with client desired type and sizes. + StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) + // A sequence of requests followed by one response (streamed upload). + // The server returns the aggregated size of client payload as the result. + StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) + // A sequence of requests followed by a sequence of responses. + // The server buffers all the client requests and then serves them in order. A + // stream of responses are returned to the client when the server starts with + // first request. + HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) +} + +type testServiceClient struct { + cc *grpc.ClientConn +} + +func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient { + return &testServiceClient{cc} +} + +func (c *testServiceClient) EmptyCall(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) { + out := new(Empty) + err := grpc.Invoke(ctx, "/grpc.testing.TestService/EmptyCall", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) UnaryCall(ctx context.Context, in *SimpleRequest, opts ...grpc.CallOption) (*SimpleResponse, error) { + out := new(SimpleResponse) + err := grpc.Invoke(ctx, "/grpc.testing.TestService/UnaryCall", in, out, c.cc, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *testServiceClient) StreamingOutputCall(ctx context.Context, in *StreamingOutputCallRequest, opts ...grpc.CallOption) (TestService_StreamingOutputCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[0], c.cc, "/grpc.testing.TestService/StreamingOutputCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceStreamingOutputCallClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type TestService_StreamingOutputCallClient interface { + Recv() (*StreamingOutputCallResponse, error) + grpc.ClientStream +} + +type testServiceStreamingOutputCallClient struct { + grpc.ClientStream +} + +func (x *testServiceStreamingOutputCallClient) Recv() (*StreamingOutputCallResponse, error) { + m := new(StreamingOutputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *testServiceClient) StreamingInputCall(ctx context.Context, opts ...grpc.CallOption) (TestService_StreamingInputCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/grpc.testing.TestService/StreamingInputCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceStreamingInputCallClient{stream} + return x, nil +} + +type TestService_StreamingInputCallClient interface { + Send(*StreamingInputCallRequest) error + CloseAndRecv() (*StreamingInputCallResponse, error) + grpc.ClientStream +} + +type testServiceStreamingInputCallClient struct { + grpc.ClientStream +} + +func (x *testServiceStreamingInputCallClient) Send(m *StreamingInputCallRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceStreamingInputCallClient) CloseAndRecv() (*StreamingInputCallResponse, error) { + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + m := new(StreamingInputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *testServiceClient) FullDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_FullDuplexCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[2], c.cc, "/grpc.testing.TestService/FullDuplexCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceFullDuplexCallClient{stream} + return x, nil +} + +type TestService_FullDuplexCallClient interface { + Send(*StreamingOutputCallRequest) error + Recv() (*StreamingOutputCallResponse, error) + grpc.ClientStream +} + +type testServiceFullDuplexCallClient struct { + grpc.ClientStream +} + +func (x *testServiceFullDuplexCallClient) Send(m *StreamingOutputCallRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceFullDuplexCallClient) Recv() (*StreamingOutputCallResponse, error) { + m := new(StreamingOutputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *testServiceClient) HalfDuplexCall(ctx context.Context, opts ...grpc.CallOption) (TestService_HalfDuplexCallClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[3], c.cc, "/grpc.testing.TestService/HalfDuplexCall", opts...) + if err != nil { + return nil, err + } + x := &testServiceHalfDuplexCallClient{stream} + return x, nil +} + +type TestService_HalfDuplexCallClient interface { + Send(*StreamingOutputCallRequest) error + Recv() (*StreamingOutputCallResponse, error) + grpc.ClientStream +} + +type testServiceHalfDuplexCallClient struct { + grpc.ClientStream +} + +func (x *testServiceHalfDuplexCallClient) Send(m *StreamingOutputCallRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServiceHalfDuplexCallClient) Recv() (*StreamingOutputCallResponse, error) { + m := new(StreamingOutputCallResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// Server API for TestService service + +type TestServiceServer interface { + // One empty request followed by one empty response. + EmptyCall(context.Context, *Empty) (*Empty, error) + // One request followed by one response. + // The server returns the client payload as-is. + UnaryCall(context.Context, *SimpleRequest) (*SimpleResponse, error) + // One request followed by a sequence of responses (streamed download). + // The server returns the payload with client desired type and sizes. + StreamingOutputCall(*StreamingOutputCallRequest, TestService_StreamingOutputCallServer) error + // A sequence of requests followed by one response (streamed upload). + // The server returns the aggregated size of client payload as the result. + StreamingInputCall(TestService_StreamingInputCallServer) error + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + FullDuplexCall(TestService_FullDuplexCallServer) error + // A sequence of requests followed by a sequence of responses. + // The server buffers all the client requests and then serves them in order. A + // stream of responses are returned to the client when the server starts with + // first request. + HalfDuplexCall(TestService_HalfDuplexCallServer) error +} + +func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { + s.RegisterService(&_TestService_serviceDesc, srv) +} + +func _TestService_EmptyCall_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(Empty) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(TestServiceServer).EmptyCall(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +func _TestService_UnaryCall_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) { + in := new(SimpleRequest) + if err := codec.Unmarshal(buf, in); err != nil { + return nil, err + } + out, err := srv.(TestServiceServer).UnaryCall(ctx, in) + if err != nil { + return nil, err + } + return out, nil +} + +func _TestService_StreamingOutputCall_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(StreamingOutputCallRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(TestServiceServer).StreamingOutputCall(m, &testServiceStreamingOutputCallServer{stream}) +} + +type TestService_StreamingOutputCallServer interface { + Send(*StreamingOutputCallResponse) error + grpc.ServerStream +} + +type testServiceStreamingOutputCallServer struct { + grpc.ServerStream +} + +func (x *testServiceStreamingOutputCallServer) Send(m *StreamingOutputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func _TestService_StreamingInputCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).StreamingInputCall(&testServiceStreamingInputCallServer{stream}) +} + +type TestService_StreamingInputCallServer interface { + SendAndClose(*StreamingInputCallResponse) error + Recv() (*StreamingInputCallRequest, error) + grpc.ServerStream +} + +type testServiceStreamingInputCallServer struct { + grpc.ServerStream +} + +func (x *testServiceStreamingInputCallServer) SendAndClose(m *StreamingInputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceStreamingInputCallServer) Recv() (*StreamingInputCallRequest, error) { + m := new(StreamingInputCallRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func _TestService_FullDuplexCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).FullDuplexCall(&testServiceFullDuplexCallServer{stream}) +} + +type TestService_FullDuplexCallServer interface { + Send(*StreamingOutputCallResponse) error + Recv() (*StreamingOutputCallRequest, error) + grpc.ServerStream +} + +type testServiceFullDuplexCallServer struct { + grpc.ServerStream +} + +func (x *testServiceFullDuplexCallServer) Send(m *StreamingOutputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceFullDuplexCallServer) Recv() (*StreamingOutputCallRequest, error) { + m := new(StreamingOutputCallRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func _TestService_HalfDuplexCall_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).HalfDuplexCall(&testServiceHalfDuplexCallServer{stream}) +} + +type TestService_HalfDuplexCallServer interface { + Send(*StreamingOutputCallResponse) error + Recv() (*StreamingOutputCallRequest, error) + grpc.ServerStream +} + +type testServiceHalfDuplexCallServer struct { + grpc.ServerStream +} + +func (x *testServiceHalfDuplexCallServer) Send(m *StreamingOutputCallResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServiceHalfDuplexCallServer) Recv() (*StreamingOutputCallRequest, error) { + m := new(StreamingOutputCallRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +var _TestService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "grpc.testing.TestService", + HandlerType: (*TestServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "EmptyCall", + Handler: _TestService_EmptyCall_Handler, + }, + { + MethodName: "UnaryCall", + Handler: _TestService_UnaryCall_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "StreamingOutputCall", + Handler: _TestService_StreamingOutputCall_Handler, + ServerStreams: true, + }, + { + StreamName: "StreamingInputCall", + Handler: _TestService_StreamingInputCall_Handler, + ClientStreams: true, + }, + { + StreamName: "FullDuplexCall", + Handler: _TestService_FullDuplexCall_Handler, + ServerStreams: true, + ClientStreams: true, + }, + { + StreamName: "HalfDuplexCall", + Handler: _TestService_HalfDuplexCall_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/test/grpc_testing/test.proto b/Godeps/_workspace/src/google.golang.org/grpc/test/grpc_testing/test.proto new file mode 100644 index 0000000..b5bfe05 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/test/grpc_testing/test.proto @@ -0,0 +1,140 @@ +// An integration test service that covers all the method signature permutations +// of unary/streaming requests/responses. +syntax = "proto2"; + +package grpc.testing; + +message Empty {} + +// The type of payload that should be returned. +enum PayloadType { + // Compressable text format. + COMPRESSABLE = 0; + + // Uncompressable binary format. + UNCOMPRESSABLE = 1; + + // Randomly chosen from all other formats defined in this enum. + RANDOM = 2; +} + +// A block of data, to simply increase gRPC message size. +message Payload { + // The type of data in body. + optional PayloadType type = 1; + // Primary contents of payload. + optional bytes body = 2; +} + +// Unary request. +message SimpleRequest { + // Desired payload type in the response from the server. + // If response_type is RANDOM, server randomly chooses one from other formats. + optional PayloadType response_type = 1; + + // Desired payload size in the response from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + optional int32 response_size = 2; + + // Optional input payload sent along with the request. + optional Payload payload = 3; + + // Whether SimpleResponse should include username. + optional bool fill_username = 4; + + // Whether SimpleResponse should include OAuth scope. + optional bool fill_oauth_scope = 5; +} + +// Unary response, as configured by the request. +message SimpleResponse { + // Payload to increase message size. + optional Payload payload = 1; + + // The user the request came from, for verifying authentication was + // successful when the client expected it. + optional string username = 2; + + // OAuth scope. + optional string oauth_scope = 3; +} + +// Client-streaming request. +message StreamingInputCallRequest { + // Optional input payload sent along with the request. + optional Payload payload = 1; + + // Not expecting any payload from the response. +} + +// Client-streaming response. +message StreamingInputCallResponse { + // Aggregated size of payloads received from the client. + optional int32 aggregated_payload_size = 1; +} + +// Configuration for a particular response. +message ResponseParameters { + // Desired payload sizes in responses from the server. + // If response_type is COMPRESSABLE, this denotes the size before compression. + optional int32 size = 1; + + // Desired interval between consecutive responses in the response stream in + // microseconds. + optional int32 interval_us = 2; +} + +// Server-streaming request. +message StreamingOutputCallRequest { + // Desired payload type in the response from the server. + // If response_type is RANDOM, the payload from each response in the stream + // might be of different types. This is to simulate a mixed type of payload + // stream. + optional PayloadType response_type = 1; + + // Configuration for each expected response message. + repeated ResponseParameters response_parameters = 2; + + // Optional input payload sent along with the request. + optional Payload payload = 3; +} + +// Server-streaming response, as configured by the request and parameters. +message StreamingOutputCallResponse { + // Payload to increase response size. + optional Payload payload = 1; +} + +// A simple service to test the various types of RPCs and experiment with +// performance with various types of payload. +service TestService { + // One empty request followed by one empty response. + rpc EmptyCall(Empty) returns (Empty); + + // One request followed by one response. + // The server returns the client payload as-is. + rpc UnaryCall(SimpleRequest) returns (SimpleResponse); + + // One request followed by a sequence of responses (streamed download). + // The server returns the payload with client desired type and sizes. + rpc StreamingOutputCall(StreamingOutputCallRequest) + returns (stream StreamingOutputCallResponse); + + // A sequence of requests followed by one response (streamed upload). + // The server returns the aggregated size of client payload as the result. + rpc StreamingInputCall(stream StreamingInputCallRequest) + returns (StreamingInputCallResponse); + + // A sequence of requests with each request served by the server immediately. + // As one request could lead to multiple responses, this interface + // demonstrates the idea of full duplexing. + rpc FullDuplexCall(stream StreamingOutputCallRequest) + returns (stream StreamingOutputCallResponse); + + // A sequence of requests followed by a sequence of responses. + // The server buffers all the client requests and then serves them in order. A + // stream of responses are returned to the client when the server starts with + // first request. + rpc HalfDuplexCall(stream StreamingOutputCallRequest) + returns (stream StreamingOutputCallResponse); +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/trace.go b/Godeps/_workspace/src/google.golang.org/grpc/trace.go new file mode 100644 index 0000000..2463574 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/trace.go @@ -0,0 +1,116 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package grpc + +import ( + "bytes" + "fmt" + "io" + "net" + "strings" + "time" + + "golang.org/x/net/trace" +) + +// EnableTracing controls whether to trace RPCs using the golang.org/x/net/trace package. +// This should only be set before any RPCs are sent or received by this program. +var EnableTracing = true + +// methodFamily returns the trace family for the given method. +// It turns "/pkg.Service/GetFoo" into "pkg.Service". +func methodFamily(m string) string { + m = strings.TrimPrefix(m, "/") // remove leading slash + if i := strings.Index(m, "/"); i >= 0 { + m = m[:i] // remove everything from second slash + } + if i := strings.LastIndex(m, "."); i >= 0 { + m = m[i+1:] // cut down to last dotted component + } + return m +} + +// traceInfo contains tracing information for an RPC. +type traceInfo struct { + tr trace.Trace + firstLine firstLine +} + +// firstLine is the first line of an RPC trace. +type firstLine struct { + client bool // whether this is a client (outgoing) RPC + remoteAddr net.Addr + deadline time.Duration // may be zero +} + +func (f *firstLine) String() string { + var line bytes.Buffer + io.WriteString(&line, "RPC: ") + if f.client { + io.WriteString(&line, "to") + } else { + io.WriteString(&line, "from") + } + fmt.Fprintf(&line, " %v deadline:", f.remoteAddr) + if f.deadline != 0 { + fmt.Fprint(&line, f.deadline) + } else { + io.WriteString(&line, "none") + } + return line.String() +} + +// payload represents an RPC request or response payload. +type payload struct { + sent bool // whether this is an outgoing payload + msg interface{} // e.g. a proto.Message + // TODO(dsymonds): add stringifying info to codec, and limit how much we hold here? +} + +func (p payload) String() string { + if p.sent { + return fmt.Sprintf("sent: %v", p.msg) + } else { + return fmt.Sprintf("recv: %v", p.msg) + } +} + +type fmtStringer struct { + format string + a []interface{} +} + +func (f *fmtStringer) String() string { + return fmt.Sprintf(f.format, f.a...) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/transport/control.go b/Godeps/_workspace/src/google.golang.org/grpc/transport/control.go new file mode 100644 index 0000000..464bf10 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/transport/control.go @@ -0,0 +1,259 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "fmt" + "sync" + + "github.com/bradfitz/http2" +) + +const ( + // The default value of flow control window size in HTTP2 spec. + defaultWindowSize = 65535 + // The initial window size for flow control. + initialWindowSize = defaultWindowSize // for an RPC + initialConnWindowSize = defaultWindowSize * 16 // for a connection +) + +// The following defines various control items which could flow through +// the control buffer of transport. They represent different aspects of +// control tasks, e.g., flow control, settings, streaming resetting, etc. +type windowUpdate struct { + streamID uint32 + increment uint32 +} + +func (windowUpdate) isItem() bool { + return true +} + +type settings struct { + ack bool + ss []http2.Setting +} + +func (settings) isItem() bool { + return true +} + +type resetStream struct { + streamID uint32 + code http2.ErrCode +} + +func (resetStream) isItem() bool { + return true +} + +type flushIO struct { +} + +func (flushIO) isItem() bool { + return true +} + +type ping struct { + ack bool +} + +func (ping) isItem() bool { + return true +} + +// quotaPool is a pool which accumulates the quota and sends it to acquire() +// when it is available. +type quotaPool struct { + c chan int + + mu sync.Mutex + quota int +} + +// newQuotaPool creates a quotaPool which has quota q available to consume. +func newQuotaPool(q int) *quotaPool { + qb := "aPool{ + c: make(chan int, 1), + } + if q > 0 { + qb.c <- q + } else { + qb.quota = q + } + return qb +} + +// add adds n to the available quota and tries to send it on acquire. +func (qb *quotaPool) add(n int) { + qb.mu.Lock() + defer qb.mu.Unlock() + qb.quota += n + if qb.quota <= 0 { + return + } + select { + case qb.c <- qb.quota: + qb.quota = 0 + default: + } +} + +// cancel cancels the pending quota sent on acquire, if any. +func (qb *quotaPool) cancel() { + qb.mu.Lock() + defer qb.mu.Unlock() + select { + case n := <-qb.c: + qb.quota += n + default: + } +} + +// reset cancels the pending quota sent on acquired, incremented by v and sends +// it back on acquire. +func (qb *quotaPool) reset(v int) { + qb.mu.Lock() + defer qb.mu.Unlock() + select { + case n := <-qb.c: + qb.quota += n + default: + } + qb.quota += v + if qb.quota <= 0 { + return + } + select { + case qb.c <- qb.quota: + qb.quota = 0 + default: + } +} + +// acquire returns the channel on which available quota amounts are sent. +func (qb *quotaPool) acquire() <-chan int { + return qb.c +} + +// inFlow deals with inbound flow control +type inFlow struct { + // The inbound flow control limit for pending data. + limit uint32 + // conn points to the shared connection-level inFlow that is shared + // by all streams on that conn. It is nil for the inFlow on the conn + // directly. + conn *inFlow + + mu sync.Mutex + // pendingData is the overall data which have been received but not been + // consumed by applications. + pendingData uint32 + // The amount of data the application has consumed but grpc has not sent + // window update for them. Used to reduce window update frequency. + pendingUpdate uint32 +} + +// onData is invoked when some data frame is received. It increments not only its +// own pendingData but also that of the associated connection-level flow. +func (f *inFlow) onData(n uint32) error { + if n == 0 { + return nil + } + f.mu.Lock() + defer f.mu.Unlock() + if f.pendingData+f.pendingUpdate+n > f.limit { + return fmt.Errorf("recieved %d-bytes data exceeding the limit %d bytes", f.pendingData+f.pendingUpdate+n, f.limit) + } + if f.conn != nil { + if err := f.conn.onData(n); err != nil { + return ConnectionErrorf("%v", err) + } + } + f.pendingData += n + return nil +} + +// connOnRead updates the connection level states when the application consumes data. +func (f *inFlow) connOnRead(n uint32) uint32 { + if n == 0 || f.conn != nil { + return 0 + } + f.mu.Lock() + defer f.mu.Unlock() + f.pendingData -= n + f.pendingUpdate += n + if f.pendingUpdate >= f.limit/4 { + ret := f.pendingUpdate + f.pendingUpdate = 0 + return ret + } + return 0 +} + +// onRead is invoked when the application reads the data. It returns the window updates +// for both stream and connection level. +func (f *inFlow) onRead(n uint32) (swu, cwu uint32) { + if n == 0 { + return + } + f.mu.Lock() + defer f.mu.Unlock() + if f.pendingData == 0 { + // pendingData has been adjusted by restoreConn. + return + } + f.pendingData -= n + f.pendingUpdate += n + if f.pendingUpdate >= f.limit/4 { + swu = f.pendingUpdate + f.pendingUpdate = 0 + } + cwu = f.conn.connOnRead(n) + return +} + +// restoreConn is invoked when a stream is terminated. It removes its stake in +// the connection-level flow and resets its own state. +func (f *inFlow) restoreConn() uint32 { + if f.conn == nil { + return 0 + } + f.mu.Lock() + defer f.mu.Unlock() + n := f.pendingData + f.pendingData = 0 + f.pendingUpdate = 0 + return f.conn.connOnRead(n) +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/transport/http2_client.go b/Godeps/_workspace/src/google.golang.org/grpc/transport/http2_client.go new file mode 100644 index 0000000..40b7640 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/transport/http2_client.go @@ -0,0 +1,849 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "bytes" + "errors" + "io" + "math" + "net" + "strings" + "sync" + "time" + + "github.com/bradfitz/http2" + "github.com/bradfitz/http2/hpack" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" +) + +// http2Client implements the ClientTransport interface with HTTP2. +type http2Client struct { + target string // server name/addr + userAgent string + conn net.Conn // underlying communication channel + authInfo credentials.AuthInfo // auth info about the connection + nextID uint32 // the next stream ID to be used + + // writableChan synchronizes write access to the transport. + // A writer acquires the write lock by sending a value on writableChan + // and releases it by receiving from writableChan. + writableChan chan int + // shutdownChan is closed when Close is called. + // Blocking operations should select on shutdownChan to avoid + // blocking forever after Close. + // TODO(zhaoq): Maybe have a channel context? + shutdownChan chan struct{} + // errorChan is closed to notify the I/O error to the caller. + errorChan chan struct{} + + framer *framer + hBuf *bytes.Buffer // the buffer for HPACK encoding + hEnc *hpack.Encoder // HPACK encoder + + // controlBuf delivers all the control related tasks (e.g., window + // updates, reset streams, and various settings) to the controller. + controlBuf *recvBuffer + fc *inFlow + // sendQuotaPool provides flow control to outbound message. + sendQuotaPool *quotaPool + // streamsQuota limits the max number of concurrent streams. + streamsQuota *quotaPool + + // The scheme used: https if TLS is on, http otherwise. + scheme string + + authCreds []credentials.Credentials + + mu sync.Mutex // guard the following variables + state transportState // the state of underlying connection + activeStreams map[uint32]*Stream + // The max number of concurrent streams + maxStreams int + // the per-stream outbound flow control window size set by the peer. + streamSendQuota uint32 +} + +// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2 +// and starts to receive messages on it. Non-nil error returns if construction +// fails. +func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err error) { + if opts.Dialer == nil { + // Set the default Dialer. + opts.Dialer = func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("tcp", addr, timeout) + } + } + scheme := "http" + startT := time.Now() + timeout := opts.Timeout + conn, connErr := opts.Dialer(addr, timeout) + if connErr != nil { + return nil, ConnectionErrorf("transport: %v", connErr) + } + var authInfo credentials.AuthInfo + for _, c := range opts.AuthOptions { + if ccreds, ok := c.(credentials.TransportAuthenticator); ok { + scheme = "https" + // TODO(zhaoq): Now the first TransportAuthenticator is used if there are + // multiple ones provided. Revisit this if it is not appropriate. Probably + // place the ClientTransport construction into a separate function to make + // things clear. + if timeout > 0 { + timeout -= time.Since(startT) + } + conn, authInfo, connErr = ccreds.ClientHandshake(addr, conn, timeout) + break + } + } + if connErr != nil { + return nil, ConnectionErrorf("transport: %v", connErr) + } + defer func() { + if err != nil { + conn.Close() + } + }() + // Send connection preface to server. + n, err := conn.Write(clientPreface) + if err != nil { + return nil, ConnectionErrorf("transport: %v", err) + } + if n != len(clientPreface) { + return nil, ConnectionErrorf("transport: preface mismatch, wrote %d bytes; want %d", n, len(clientPreface)) + } + framer := newFramer(conn) + if initialWindowSize != defaultWindowSize { + err = framer.writeSettings(true, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) + } else { + err = framer.writeSettings(true) + } + if err != nil { + return nil, ConnectionErrorf("transport: %v", err) + } + // Adjust the connection flow control window if needed. + if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { + if err := framer.writeWindowUpdate(true, 0, delta); err != nil { + return nil, ConnectionErrorf("transport: %v", err) + } + } + ua := primaryUA + if opts.UserAgent != "" { + ua = opts.UserAgent + " " + ua + } + var buf bytes.Buffer + t := &http2Client{ + target: addr, + userAgent: ua, + conn: conn, + authInfo: authInfo, + // The client initiated stream id is odd starting from 1. + nextID: 1, + writableChan: make(chan int, 1), + shutdownChan: make(chan struct{}), + errorChan: make(chan struct{}), + framer: framer, + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), + controlBuf: newRecvBuffer(), + fc: &inFlow{limit: initialConnWindowSize}, + sendQuotaPool: newQuotaPool(defaultWindowSize), + scheme: scheme, + state: reachable, + activeStreams: make(map[uint32]*Stream), + authCreds: opts.AuthOptions, + maxStreams: math.MaxInt32, + streamSendQuota: defaultWindowSize, + } + go t.controller() + t.writableChan <- 0 + // Start the reader goroutine for incoming message. The threading model + // on receiving is that each transport has a dedicated goroutine which + // reads HTTP2 frame from network. Then it dispatches the frame to the + // corresponding stream entity. + go t.reader() + return t, nil +} + +func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *Stream { + fc := &inFlow{ + limit: initialWindowSize, + conn: t.fc, + } + // TODO(zhaoq): Handle uint32 overflow of Stream.id. + s := &Stream{ + id: t.nextID, + method: callHdr.Method, + buf: newRecvBuffer(), + fc: fc, + sendQuotaPool: newQuotaPool(int(t.streamSendQuota)), + headerChan: make(chan struct{}), + } + t.nextID += 2 + s.windowHandler = func(n int) { + t.updateWindow(s, uint32(n)) + } + // Make a stream be able to cancel the pending operations by itself. + s.ctx, s.cancel = context.WithCancel(ctx) + s.dec = &recvBufferReader{ + ctx: s.ctx, + recv: s.buf, + } + return s +} + +// NewStream creates a stream and register it into the transport as "active" +// streams. +func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) { + // Record the timeout value on the context. + var timeout time.Duration + if dl, ok := ctx.Deadline(); ok { + timeout = dl.Sub(time.Now()) + if timeout <= 0 { + return nil, ContextErr(context.DeadlineExceeded) + } + } + // Attach Auth info if there is any. + if t.authInfo != nil { + ctx = credentials.NewContext(ctx, t.authInfo) + } + authData := make(map[string]string) + for _, c := range t.authCreds { + // Construct URI required to get auth request metadata. + var port string + if pos := strings.LastIndex(t.target, ":"); pos != -1 { + // Omit port if it is the default one. + if t.target[pos+1:] != "443" { + port = ":" + t.target[pos+1:] + } + } + pos := strings.LastIndex(callHdr.Method, "/") + if pos == -1 { + return nil, StreamErrorf(codes.InvalidArgument, "transport: malformed method name: %q", callHdr.Method) + } + audience := "https://" + callHdr.Host + port + callHdr.Method[:pos] + data, err := c.GetRequestMetadata(ctx, audience) + if err != nil { + return nil, StreamErrorf(codes.InvalidArgument, "transport: %v", err) + } + for k, v := range data { + authData[k] = v + } + } + t.mu.Lock() + if t.state != reachable { + t.mu.Unlock() + return nil, ErrConnClosing + } + checkStreamsQuota := t.streamsQuota != nil + t.mu.Unlock() + if checkStreamsQuota { + sq, err := wait(ctx, t.shutdownChan, t.streamsQuota.acquire()) + if err != nil { + return nil, err + } + // Returns the quota balance back. + if sq > 1 { + t.streamsQuota.add(sq - 1) + } + } + if _, err := wait(ctx, t.shutdownChan, t.writableChan); err != nil { + // t.streamsQuota will be updated when t.CloseStream is invoked. + return nil, err + } + t.mu.Lock() + if t.state != reachable { + t.mu.Unlock() + return nil, ErrConnClosing + } + s := t.newStream(ctx, callHdr) + t.activeStreams[s.id] = s + t.mu.Unlock() + // HPACK encodes various headers. Note that once WriteField(...) is + // called, the corresponding headers/continuation frame has to be sent + // because hpack.Encoder is stateful. + t.hBuf.Reset() + t.hEnc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}) + t.hEnc.WriteField(hpack.HeaderField{Name: ":scheme", Value: t.scheme}) + t.hEnc.WriteField(hpack.HeaderField{Name: ":path", Value: callHdr.Method}) + t.hEnc.WriteField(hpack.HeaderField{Name: ":authority", Value: callHdr.Host}) + t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + t.hEnc.WriteField(hpack.HeaderField{Name: "user-agent", Value: t.userAgent}) + t.hEnc.WriteField(hpack.HeaderField{Name: "te", Value: "trailers"}) + + if timeout > 0 { + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)}) + } + for k, v := range authData { + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: v}) + } + var ( + hasMD bool + endHeaders bool + ) + if md, ok := metadata.FromContext(ctx); ok { + hasMD = true + for k, v := range md { + for _, entry := range v { + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) + } + } + } + first := true + // Sends the headers in a single batch even when they span multiple frames. + for !endHeaders { + size := t.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + if first { + // Sends a HeadersFrame to server to start a new stream. + p := http2.HeadersFrameParam{ + StreamID: s.id, + BlockFragment: t.hBuf.Next(size), + EndStream: false, + EndHeaders: endHeaders, + } + // Do a force flush for the buffered frames iff it is the last headers frame + // and there is header metadata to be sent. Otherwise, there is flushing until + // the corresponding data frame is written. + err = t.framer.writeHeaders(hasMD && endHeaders, p) + first = false + } else { + // Sends Continuation frames for the leftover headers. + err = t.framer.writeContinuation(hasMD && endHeaders, s.id, endHeaders, t.hBuf.Next(size)) + } + if err != nil { + t.notifyError(err) + return nil, ConnectionErrorf("transport: %v", err) + } + } + t.writableChan <- 0 + return s, nil +} + +// CloseStream clears the footprint of a stream when the stream is not needed any more. +// This must not be executed in reader's goroutine. +func (t *http2Client) CloseStream(s *Stream, err error) { + var updateStreams bool + t.mu.Lock() + if t.streamsQuota != nil { + updateStreams = true + } + delete(t.activeStreams, s.id) + t.mu.Unlock() + if updateStreams { + t.streamsQuota.add(1) + } + s.mu.Lock() + if q := s.fc.restoreConn(); q > 0 { + t.controlBuf.put(&windowUpdate{0, q}) + } + if s.state == streamDone { + s.mu.Unlock() + return + } + if !s.headerDone { + close(s.headerChan) + s.headerDone = true + } + s.state = streamDone + s.mu.Unlock() + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), the caller needs + // to call cancel on the stream to interrupt the blocking on + // other goroutines. + s.cancel() + if _, ok := err.(StreamError); ok { + t.controlBuf.put(&resetStream{s.id, http2.ErrCodeCancel}) + } +} + +// Close kicks off the shutdown process of the transport. This should be called +// only once on a transport. Once it is called, the transport should not be +// accessed any more. +func (t *http2Client) Close() (err error) { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return errors.New("transport: Close() was already called") + } + t.state = closing + t.mu.Unlock() + close(t.shutdownChan) + err = t.conn.Close() + t.mu.Lock() + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + // Notify all active streams. + for _, s := range streams { + s.mu.Lock() + if !s.headerDone { + close(s.headerChan) + s.headerDone = true + } + s.mu.Unlock() + s.write(recvMsg{err: ErrConnClosing}) + } + return +} + +// Write formats the data into HTTP2 data frame(s) and sends it out. The caller +// should proceed only if Write returns nil. +// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later +// if it improves the performance. +func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error { + r := bytes.NewBuffer(data) + for { + var p []byte + if r.Len() > 0 { + size := http2MaxFrameLen + s.sendQuotaPool.add(0) + // Wait until the stream has some quota to send the data. + sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) + if err != nil { + return err + } + t.sendQuotaPool.add(0) + // Wait until the transport has some quota to send the data. + tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) + if err != nil { + if _, ok := err.(StreamError); ok { + t.sendQuotaPool.cancel() + } + return err + } + if sq < size { + size = sq + } + if tq < size { + size = tq + } + p = r.Next(size) + ps := len(p) + if ps < sq { + // Overbooked stream quota. Return it back. + s.sendQuotaPool.add(sq - ps) + } + if ps < tq { + // Overbooked transport quota. Return it back. + t.sendQuotaPool.add(tq - ps) + } + } + var ( + endStream bool + forceFlush bool + ) + if opts.Last && r.Len() == 0 { + endStream = true + } + // Indicate there is a writer who is about to write a data frame. + t.framer.adjustNumWriters(1) + // Got some quota. Try to acquire writing privilege on the transport. + if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if t.framer.adjustNumWriters(-1) == 0 { + // This writer is the last one in this batch and has the + // responsibility to flush the buffered frames. It queues + // a flush request to controlBuf instead of flushing directly + // in order to avoid the race with other writing or flushing. + t.controlBuf.put(&flushIO{}) + } + return err + } + if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 { + // Do a force flush iff this is last frame for the entire gRPC message + // and the caller is the only writer at this moment. + forceFlush = true + } + // If WriteData fails, all the pending streams will be handled + // by http2Client.Close(). No explicit CloseStream() needs to be + // invoked. + if err := t.framer.writeData(forceFlush, s.id, endStream, p); err != nil { + t.notifyError(err) + return ConnectionErrorf("transport: %v", err) + } + if t.framer.adjustNumWriters(-1) == 0 { + t.framer.flushWrite() + } + t.writableChan <- 0 + if r.Len() == 0 { + break + } + } + if !opts.Last { + return nil + } + s.mu.Lock() + if s.state != streamDone { + if s.state == streamReadDone { + s.state = streamDone + } else { + s.state = streamWriteDone + } + } + s.mu.Unlock() + return nil +} + +func (t *http2Client) getStream(f http2.Frame) (*Stream, bool) { + t.mu.Lock() + defer t.mu.Unlock() + if t.activeStreams == nil { + // The transport is closing. + return nil, false + } + if s, ok := t.activeStreams[f.Header().StreamID]; ok { + return s, true + } + return nil, false +} + +// updateWindow adjusts the inbound quota for the stream and the transport. +// Window updates will deliver to the controller for sending when +// the cumulative quota exceeds the corresponding threshold. +func (t *http2Client) updateWindow(s *Stream, n uint32) { + swu, cwu := s.fc.onRead(n) + if swu > 0 { + t.controlBuf.put(&windowUpdate{s.id, swu}) + } + if cwu > 0 { + t.controlBuf.put(&windowUpdate{0, cwu}) + } +} + +func (t *http2Client) handleData(f *http2.DataFrame) { + // Select the right stream to dispatch. + s, ok := t.getStream(f) + if !ok { + return + } + size := len(f.Data()) + if size > 0 { + if err := s.fc.onData(uint32(size)); err != nil { + if _, ok := err.(ConnectionError); ok { + t.notifyError(err) + return + } + s.mu.Lock() + if s.state == streamDone { + s.mu.Unlock() + return + } + s.state = streamDone + s.statusCode = codes.Internal + s.statusDesc = err.Error() + s.mu.Unlock() + s.write(recvMsg{err: io.EOF}) + t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) + return + } + // TODO(bradfitz, zhaoq): A copy is required here because there is no + // guarantee f.Data() is consumed before the arrival of next frame. + // Can this copy be eliminated? + data := make([]byte, size) + copy(data, f.Data()) + s.write(recvMsg{data: data}) + } + // The server has closed the stream without sending trailers. Record that + // the read direction is closed, and set the status appropriately. + if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) { + s.mu.Lock() + if s.state == streamWriteDone { + s.state = streamDone + } else { + s.state = streamReadDone + } + s.statusCode = codes.Internal + s.statusDesc = "server closed the stream without sending trailers" + s.mu.Unlock() + s.write(recvMsg{err: io.EOF}) + } +} + +func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) { + s, ok := t.getStream(f) + if !ok { + return + } + s.mu.Lock() + if s.state == streamDone { + s.mu.Unlock() + return + } + s.state = streamDone + if !s.headerDone { + close(s.headerChan) + s.headerDone = true + } + s.statusCode, ok = http2RSTErrConvTab[http2.ErrCode(f.ErrCode)] + if !ok { + grpclog.Println("transport: http2Client.handleRSTStream found no mapped gRPC status for the received http2 error ", f.ErrCode) + } + s.mu.Unlock() + s.write(recvMsg{err: io.EOF}) +} + +func (t *http2Client) handleSettings(f *http2.SettingsFrame) { + if f.IsAck() { + return + } + var ss []http2.Setting + f.ForeachSetting(func(s http2.Setting) error { + ss = append(ss, s) + return nil + }) + // The settings will be applied once the ack is sent. + t.controlBuf.put(&settings{ack: true, ss: ss}) +} + +func (t *http2Client) handlePing(f *http2.PingFrame) { + t.controlBuf.put(&ping{true}) +} + +func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) { + // TODO(zhaoq): GoAwayFrame handler to be implemented +} + +func (t *http2Client) handleWindowUpdate(f *http2.WindowUpdateFrame) { + id := f.Header().StreamID + incr := f.Increment + if id == 0 { + t.sendQuotaPool.add(int(incr)) + return + } + if s, ok := t.getStream(f); ok { + s.sendQuotaPool.add(int(incr)) + } +} + +// operateHeader takes action on the decoded headers. It returns the current +// stream if there are remaining headers on the wire (in the following +// Continuation frame). +func (t *http2Client) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool) (pendingStream *Stream) { + defer func() { + if pendingStream == nil { + hDec.state = decodeState{} + } + }() + endHeaders, err := hDec.decodeClientHTTP2Headers(frame) + if s == nil { + // s has been closed. + return nil + } + if err != nil { + s.write(recvMsg{err: err}) + // Something wrong. Stops reading even when there is remaining. + return nil + } + if !endHeaders { + return s + } + + s.mu.Lock() + if !s.headerDone { + if !endStream && len(hDec.state.mdata) > 0 { + s.header = hDec.state.mdata + } + close(s.headerChan) + s.headerDone = true + } + if !endStream || s.state == streamDone { + s.mu.Unlock() + return nil + } + + if len(hDec.state.mdata) > 0 { + s.trailer = hDec.state.mdata + } + s.state = streamDone + s.statusCode = hDec.state.statusCode + s.statusDesc = hDec.state.statusDesc + s.mu.Unlock() + + s.write(recvMsg{err: io.EOF}) + return nil +} + +// reader runs as a separate goroutine in charge of reading data from network +// connection. +// +// TODO(zhaoq): currently one reader per transport. Investigate whether this is +// optimal. +// TODO(zhaoq): Check the validity of the incoming frame sequence. +func (t *http2Client) reader() { + // Check the validity of server preface. + frame, err := t.framer.readFrame() + if err != nil { + t.notifyError(err) + return + } + sf, ok := frame.(*http2.SettingsFrame) + if !ok { + t.notifyError(err) + return + } + t.handleSettings(sf) + + hDec := newHPACKDecoder() + var curStream *Stream + // loop to keep reading incoming messages on this transport. + for { + frame, err := t.framer.readFrame() + if err != nil { + t.notifyError(err) + return + } + switch frame := frame.(type) { + case *http2.HeadersFrame: + // operateHeaders has to be invoked regardless the value of curStream + // because the HPACK decoder needs to be updated using the received + // headers. + curStream, _ = t.getStream(frame) + endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream) + curStream = t.operateHeaders(hDec, curStream, frame, endStream) + case *http2.ContinuationFrame: + curStream = t.operateHeaders(hDec, curStream, frame, false) + case *http2.DataFrame: + t.handleData(frame) + case *http2.RSTStreamFrame: + t.handleRSTStream(frame) + case *http2.SettingsFrame: + t.handleSettings(frame) + case *http2.PingFrame: + t.handlePing(frame) + case *http2.GoAwayFrame: + t.handleGoAway(frame) + case *http2.WindowUpdateFrame: + t.handleWindowUpdate(frame) + default: + grpclog.Printf("transport: http2Client.reader got unhandled frame type %v.", frame) + } + } +} + +func (t *http2Client) applySettings(ss []http2.Setting) { + for _, s := range ss { + switch s.ID { + case http2.SettingMaxConcurrentStreams: + // TODO(zhaoq): This is a hack to avoid significant refactoring of the + // code to deal with the unrealistic int32 overflow. Probably will try + // to find a better way to handle this later. + if s.Val > math.MaxInt32 { + s.Val = math.MaxInt32 + } + t.mu.Lock() + reset := t.streamsQuota != nil + if !reset { + t.streamsQuota = newQuotaPool(int(s.Val) - len(t.activeStreams)) + } + ms := t.maxStreams + t.maxStreams = int(s.Val) + t.mu.Unlock() + if reset { + t.streamsQuota.reset(int(s.Val) - ms) + } + case http2.SettingInitialWindowSize: + t.mu.Lock() + for _, stream := range t.activeStreams { + // Adjust the sending quota for each stream. + stream.sendQuotaPool.reset(int(s.Val - t.streamSendQuota)) + } + t.streamSendQuota = s.Val + t.mu.Unlock() + } + } +} + +// controller running in a separate goroutine takes charge of sending control +// frames (e.g., window update, reset stream, setting, etc.) to the server. +func (t *http2Client) controller() { + for { + select { + case i := <-t.controlBuf.get(): + t.controlBuf.load() + select { + case <-t.writableChan: + switch i := i.(type) { + case *windowUpdate: + t.framer.writeWindowUpdate(true, i.streamID, i.increment) + case *settings: + if i.ack { + t.framer.writeSettingsAck(true) + t.applySettings(i.ss) + } else { + t.framer.writeSettings(true, i.ss...) + } + case *resetStream: + t.framer.writeRSTStream(true, i.streamID, i.code) + case *flushIO: + t.framer.flushWrite() + case *ping: + // TODO(zhaoq): Ack with all-0 data now. will change to some + // meaningful content when this is actually in use. + t.framer.writePing(true, i.ack, [8]byte{}) + default: + grpclog.Printf("transport: http2Client.controller got unexpected item type %v\n", i) + } + t.writableChan <- 0 + continue + case <-t.shutdownChan: + return + } + case <-t.shutdownChan: + return + } + } +} + +func (t *http2Client) Error() <-chan struct{} { + return t.errorChan +} + +func (t *http2Client) notifyError(err error) { + t.mu.Lock() + defer t.mu.Unlock() + // make sure t.errorChan is closed only once. + if t.state == reachable { + t.state = unreachable + close(t.errorChan) + grpclog.Printf("transport: http2Client.notifyError got notified that the client transport was broken %v.", err) + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/transport/http2_server.go b/Godeps/_workspace/src/google.golang.org/grpc/transport/http2_server.go new file mode 100644 index 0000000..8856d7f --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/transport/http2_server.go @@ -0,0 +1,691 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "bytes" + "errors" + "io" + "math" + "net" + "strconv" + "sync" + + "github.com/bradfitz/http2" + "github.com/bradfitz/http2/hpack" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" +) + +// ErrIllegalHeaderWrite indicates that setting header is illegal because of +// the stream's state. +var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHeader was already called") + +// http2Server implements the ServerTransport interface with HTTP2. +type http2Server struct { + conn net.Conn + maxStreamID uint32 // max stream ID ever seen + authInfo credentials.AuthInfo // auth info about the connection + // writableChan synchronizes write access to the transport. + // A writer acquires the write lock by sending a value on writableChan + // and releases it by receiving from writableChan. + writableChan chan int + // shutdownChan is closed when Close is called. + // Blocking operations should select on shutdownChan to avoid + // blocking forever after Close. + shutdownChan chan struct{} + framer *framer + hBuf *bytes.Buffer // the buffer for HPACK encoding + hEnc *hpack.Encoder // HPACK encoder + + // The max number of concurrent streams. + maxStreams uint32 + // controlBuf delivers all the control related tasks (e.g., window + // updates, reset streams, and various settings) to the controller. + controlBuf *recvBuffer + fc *inFlow + // sendQuotaPool provides flow control to outbound message. + sendQuotaPool *quotaPool + + mu sync.Mutex // guard the following + state transportState + activeStreams map[uint32]*Stream + // the per-stream outbound flow control window size set by the peer. + streamSendQuota uint32 +} + +// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is +// returned if something goes wrong. +func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (_ ServerTransport, err error) { + framer := newFramer(conn) + // Send initial settings as connection preface to client. + var settings []http2.Setting + // TODO(zhaoq): Have a better way to signal "no limit" because 0 is + // permitted in the HTTP2 spec. + if maxStreams == 0 { + maxStreams = math.MaxUint32 + } else { + settings = append(settings, http2.Setting{http2.SettingMaxConcurrentStreams, maxStreams}) + } + if initialWindowSize != defaultWindowSize { + settings = append(settings, http2.Setting{http2.SettingInitialWindowSize, uint32(initialWindowSize)}) + } + if err := framer.writeSettings(true, settings...); err != nil { + return nil, ConnectionErrorf("transport: %v", err) + } + // Adjust the connection flow control window if needed. + if delta := uint32(initialConnWindowSize - defaultWindowSize); delta > 0 { + if err := framer.writeWindowUpdate(true, 0, delta); err != nil { + return nil, ConnectionErrorf("transport: %v", err) + } + } + var buf bytes.Buffer + t := &http2Server{ + conn: conn, + authInfo: authInfo, + framer: framer, + hBuf: &buf, + hEnc: hpack.NewEncoder(&buf), + maxStreams: maxStreams, + controlBuf: newRecvBuffer(), + fc: &inFlow{limit: initialConnWindowSize}, + sendQuotaPool: newQuotaPool(defaultWindowSize), + state: reachable, + writableChan: make(chan int, 1), + shutdownChan: make(chan struct{}), + activeStreams: make(map[uint32]*Stream), + streamSendQuota: defaultWindowSize, + } + go t.controller() + t.writableChan <- 0 + return t, nil +} + +// operateHeader takes action on the decoded headers. It returns the current +// stream if there are remaining headers on the wire (in the following +// Continuation frame). +func (t *http2Server) operateHeaders(hDec *hpackDecoder, s *Stream, frame headerFrame, endStream bool, handle func(*Stream), wg *sync.WaitGroup) (pendingStream *Stream) { + defer func() { + if pendingStream == nil { + hDec.state = decodeState{} + } + }() + endHeaders, err := hDec.decodeServerHTTP2Headers(frame) + if s == nil { + // s has been closed. + return nil + } + if err != nil { + grpclog.Printf("transport: http2Server.operateHeader found %v", err) + if se, ok := err.(StreamError); ok { + t.controlBuf.put(&resetStream{s.id, statusCodeConvTab[se.Code]}) + } + return nil + } + if endStream { + // s is just created by the caller. No lock needed. + s.state = streamReadDone + } + if !endHeaders { + return s + } + t.mu.Lock() + if t.state != reachable { + t.mu.Unlock() + return nil + } + if uint32(len(t.activeStreams)) >= t.maxStreams { + t.mu.Unlock() + t.controlBuf.put(&resetStream{s.id, http2.ErrCodeRefusedStream}) + return nil + } + s.sendQuotaPool = newQuotaPool(int(t.streamSendQuota)) + t.activeStreams[s.id] = s + t.mu.Unlock() + s.windowHandler = func(n int) { + t.updateWindow(s, uint32(n)) + } + if hDec.state.timeoutSet { + s.ctx, s.cancel = context.WithTimeout(context.TODO(), hDec.state.timeout) + } else { + s.ctx, s.cancel = context.WithCancel(context.TODO()) + } + // Attach Auth info if there is any. + if t.authInfo != nil { + s.ctx = credentials.NewContext(s.ctx, t.authInfo) + } + // Cache the current stream to the context so that the server application + // can find out. Required when the server wants to send some metadata + // back to the client (unary call only). + s.ctx = newContextWithStream(s.ctx, s) + // Attach the received metadata to the context. + if len(hDec.state.mdata) > 0 { + s.ctx = metadata.NewContext(s.ctx, hDec.state.mdata) + } + + s.dec = &recvBufferReader{ + ctx: s.ctx, + recv: s.buf, + } + s.method = hDec.state.method + + wg.Add(1) + go func() { + handle(s) + wg.Done() + }() + return nil +} + +// HandleStreams receives incoming streams using the given handler. This is +// typically run in a separate goroutine. +func (t *http2Server) HandleStreams(handle func(*Stream)) { + // Check the validity of client preface. + preface := make([]byte, len(clientPreface)) + if _, err := io.ReadFull(t.conn, preface); err != nil { + grpclog.Printf("transport: http2Server.HandleStreams failed to receive the preface from client: %v", err) + t.Close() + return + } + if !bytes.Equal(preface, clientPreface) { + grpclog.Printf("transport: http2Server.HandleStreams received bogus greeting from client: %q", preface) + t.Close() + return + } + + frame, err := t.framer.readFrame() + if err != nil { + grpclog.Printf("transport: http2Server.HandleStreams failed to read frame: %v", err) + t.Close() + return + } + sf, ok := frame.(*http2.SettingsFrame) + if !ok { + grpclog.Printf("transport: http2Server.HandleStreams saw invalid preface type %T from client", frame) + t.Close() + return + } + t.handleSettings(sf) + + hDec := newHPACKDecoder() + var curStream *Stream + var wg sync.WaitGroup + defer wg.Wait() + for { + frame, err := t.framer.readFrame() + if err != nil { + t.Close() + return + } + switch frame := frame.(type) { + case *http2.HeadersFrame: + id := frame.Header().StreamID + if id%2 != 1 || id <= t.maxStreamID { + // illegal gRPC stream id. + grpclog.Println("transport: http2Server.HandleStreams received an illegal stream id: ", id) + t.Close() + break + } + t.maxStreamID = id + buf := newRecvBuffer() + fc := &inFlow{ + limit: initialWindowSize, + conn: t.fc, + } + curStream = &Stream{ + id: frame.Header().StreamID, + st: t, + buf: buf, + fc: fc, + } + endStream := frame.Header().Flags.Has(http2.FlagHeadersEndStream) + curStream = t.operateHeaders(hDec, curStream, frame, endStream, handle, &wg) + case *http2.ContinuationFrame: + curStream = t.operateHeaders(hDec, curStream, frame, false, handle, &wg) + case *http2.DataFrame: + t.handleData(frame) + case *http2.RSTStreamFrame: + t.handleRSTStream(frame) + case *http2.SettingsFrame: + t.handleSettings(frame) + case *http2.PingFrame: + t.handlePing(frame) + case *http2.WindowUpdateFrame: + t.handleWindowUpdate(frame) + case *http2.GoAwayFrame: + break + default: + grpclog.Printf("transport: http2Server.HandleStreams found unhandled frame type %v.", frame) + } + } +} + +func (t *http2Server) getStream(f http2.Frame) (*Stream, bool) { + t.mu.Lock() + defer t.mu.Unlock() + if t.activeStreams == nil { + // The transport is closing. + return nil, false + } + s, ok := t.activeStreams[f.Header().StreamID] + if !ok { + // The stream is already done. + return nil, false + } + return s, true +} + +// updateWindow adjusts the inbound quota for the stream and the transport. +// Window updates will deliver to the controller for sending when +// the cumulative quota exceeds the corresponding threshold. +func (t *http2Server) updateWindow(s *Stream, n uint32) { + swu, cwu := s.fc.onRead(n) + if swu > 0 { + t.controlBuf.put(&windowUpdate{s.id, swu}) + } + if cwu > 0 { + t.controlBuf.put(&windowUpdate{0, cwu}) + } +} + +func (t *http2Server) handleData(f *http2.DataFrame) { + // Select the right stream to dispatch. + s, ok := t.getStream(f) + if !ok { + return + } + size := len(f.Data()) + if size > 0 { + if err := s.fc.onData(uint32(size)); err != nil { + if _, ok := err.(ConnectionError); ok { + grpclog.Printf("transport: http2Server %v", err) + t.Close() + return + } + t.closeStream(s) + t.controlBuf.put(&resetStream{s.id, http2.ErrCodeFlowControl}) + return + } + // TODO(bradfitz, zhaoq): A copy is required here because there is no + // guarantee f.Data() is consumed before the arrival of next frame. + // Can this copy be eliminated? + data := make([]byte, size) + copy(data, f.Data()) + s.write(recvMsg{data: data}) + } + if f.Header().Flags.Has(http2.FlagDataEndStream) { + // Received the end of stream from the client. + s.mu.Lock() + if s.state != streamDone { + if s.state == streamWriteDone { + s.state = streamDone + } else { + s.state = streamReadDone + } + } + s.mu.Unlock() + s.write(recvMsg{err: io.EOF}) + } +} + +func (t *http2Server) handleRSTStream(f *http2.RSTStreamFrame) { + s, ok := t.getStream(f) + if !ok { + return + } + t.closeStream(s) +} + +func (t *http2Server) handleSettings(f *http2.SettingsFrame) { + if f.IsAck() { + return + } + var ss []http2.Setting + f.ForeachSetting(func(s http2.Setting) error { + ss = append(ss, s) + return nil + }) + // The settings will be applied once the ack is sent. + t.controlBuf.put(&settings{ack: true, ss: ss}) +} + +func (t *http2Server) handlePing(f *http2.PingFrame) { + t.controlBuf.put(&ping{true}) +} + +func (t *http2Server) handleWindowUpdate(f *http2.WindowUpdateFrame) { + id := f.Header().StreamID + incr := f.Increment + if id == 0 { + t.sendQuotaPool.add(int(incr)) + return + } + if s, ok := t.getStream(f); ok { + s.sendQuotaPool.add(int(incr)) + } +} + +func (t *http2Server) writeHeaders(s *Stream, b *bytes.Buffer, endStream bool) error { + first := true + endHeaders := false + var err error + // Sends the headers in a single batch. + for !endHeaders { + size := t.hBuf.Len() + if size > http2MaxFrameLen { + size = http2MaxFrameLen + } else { + endHeaders = true + } + if first { + p := http2.HeadersFrameParam{ + StreamID: s.id, + BlockFragment: b.Next(size), + EndStream: endStream, + EndHeaders: endHeaders, + } + err = t.framer.writeHeaders(endHeaders, p) + first = false + } else { + err = t.framer.writeContinuation(endHeaders, s.id, endHeaders, b.Next(size)) + } + if err != nil { + t.Close() + return ConnectionErrorf("transport: %v", err) + } + } + return nil +} + +// WriteHeader sends the header metedata md back to the client. +func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error { + s.mu.Lock() + if s.headerOk || s.state == streamDone { + s.mu.Unlock() + return ErrIllegalHeaderWrite + } + s.headerOk = true + s.mu.Unlock() + if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + return err + } + t.hBuf.Reset() + t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + for k, v := range md { + for _, entry := range v { + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) + } + } + if err := t.writeHeaders(s, t.hBuf, false); err != nil { + return err + } + t.writableChan <- 0 + return nil +} + +// WriteStatus sends stream status to the client and terminates the stream. +// There is no further I/O operations being able to perform on this stream. +// TODO(zhaoq): Now it indicates the end of entire stream. Revisit if early +// OK is adopted. +func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error { + s.mu.RLock() + if s.state == streamDone { + s.mu.RUnlock() + return nil + } + s.mu.RUnlock() + if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + return err + } + t.hBuf.Reset() + t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + t.hEnc.WriteField( + hpack.HeaderField{ + Name: "grpc-status", + Value: strconv.Itoa(int(statusCode)), + }) + t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) + // Attach the trailer metadata. + for k, v := range s.trailer { + for _, entry := range v { + t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry}) + } + } + if err := t.writeHeaders(s, t.hBuf, true); err != nil { + t.Close() + return err + } + t.closeStream(s) + t.writableChan <- 0 + return nil +} + +// Write converts the data into HTTP2 data frame and sends it out. Non-nil error +// is returns if it fails (e.g., framing error, transport error). +func (t *http2Server) Write(s *Stream, data []byte, opts *Options) error { + // TODO(zhaoq): Support multi-writers for a single stream. + var writeHeaderFrame bool + s.mu.Lock() + if !s.headerOk { + writeHeaderFrame = true + s.headerOk = true + } + s.mu.Unlock() + if writeHeaderFrame { + if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + return err + } + t.hBuf.Reset() + t.hEnc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"}) + t.hEnc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}) + p := http2.HeadersFrameParam{ + StreamID: s.id, + BlockFragment: t.hBuf.Bytes(), + EndHeaders: true, + } + if err := t.framer.writeHeaders(false, p); err != nil { + t.Close() + return ConnectionErrorf("transport: %v", err) + } + t.writableChan <- 0 + } + r := bytes.NewBuffer(data) + for { + if r.Len() == 0 { + return nil + } + size := http2MaxFrameLen + s.sendQuotaPool.add(0) + // Wait until the stream has some quota to send the data. + sq, err := wait(s.ctx, t.shutdownChan, s.sendQuotaPool.acquire()) + if err != nil { + return err + } + t.sendQuotaPool.add(0) + // Wait until the transport has some quota to send the data. + tq, err := wait(s.ctx, t.shutdownChan, t.sendQuotaPool.acquire()) + if err != nil { + if _, ok := err.(StreamError); ok { + t.sendQuotaPool.cancel() + } + return err + } + if sq < size { + size = sq + } + if tq < size { + size = tq + } + p := r.Next(size) + ps := len(p) + if ps < sq { + // Overbooked stream quota. Return it back. + s.sendQuotaPool.add(sq - ps) + } + if ps < tq { + // Overbooked transport quota. Return it back. + t.sendQuotaPool.add(tq - ps) + } + t.framer.adjustNumWriters(1) + // Got some quota. Try to acquire writing privilege on the + // transport. + if _, err := wait(s.ctx, t.shutdownChan, t.writableChan); err != nil { + if t.framer.adjustNumWriters(-1) == 0 { + // This writer is the last one in this batch and has the + // responsibility to flush the buffered frames. It queues + // a flush request to controlBuf instead of flushing directly + // in order to avoid the race with other writing or flushing. + t.controlBuf.put(&flushIO{}) + } + return err + } + var forceFlush bool + if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last { + forceFlush = true + } + if err := t.framer.writeData(forceFlush, s.id, false, p); err != nil { + t.Close() + return ConnectionErrorf("transport: %v", err) + } + if t.framer.adjustNumWriters(-1) == 0 { + t.framer.flushWrite() + } + t.writableChan <- 0 + } + +} + +func (t *http2Server) applySettings(ss []http2.Setting) { + for _, s := range ss { + if s.ID == http2.SettingInitialWindowSize { + t.mu.Lock() + defer t.mu.Unlock() + for _, stream := range t.activeStreams { + stream.sendQuotaPool.reset(int(s.Val - t.streamSendQuota)) + } + t.streamSendQuota = s.Val + } + + } +} + +// controller running in a separate goroutine takes charge of sending control +// frames (e.g., window update, reset stream, setting, etc.) to the server. +func (t *http2Server) controller() { + for { + select { + case i := <-t.controlBuf.get(): + t.controlBuf.load() + select { + case <-t.writableChan: + switch i := i.(type) { + case *windowUpdate: + t.framer.writeWindowUpdate(true, i.streamID, i.increment) + case *settings: + if i.ack { + t.framer.writeSettingsAck(true) + t.applySettings(i.ss) + } else { + t.framer.writeSettings(true, i.ss...) + } + case *resetStream: + t.framer.writeRSTStream(true, i.streamID, i.code) + case *flushIO: + t.framer.flushWrite() + case *ping: + // TODO(zhaoq): Ack with all-0 data now. will change to some + // meaningful content when this is actually in use. + t.framer.writePing(true, i.ack, [8]byte{}) + default: + grpclog.Printf("transport: http2Server.controller got unexpected item type %v\n", i) + } + t.writableChan <- 0 + continue + case <-t.shutdownChan: + return + } + case <-t.shutdownChan: + return + } + } +} + +// Close starts shutting down the http2Server transport. +// TODO(zhaoq): Now the destruction is not blocked on any pending streams. This +// could cause some resource issue. Revisit this later. +func (t *http2Server) Close() (err error) { + t.mu.Lock() + if t.state == closing { + t.mu.Unlock() + return errors.New("transport: Close() was already called") + } + t.state = closing + streams := t.activeStreams + t.activeStreams = nil + t.mu.Unlock() + close(t.shutdownChan) + err = t.conn.Close() + // Notify all active streams. + for _, s := range streams { + s.write(recvMsg{err: ErrConnClosing}) + } + return +} + +// closeStream clears the footprint of a stream when the stream is not needed +// any more. +func (t *http2Server) closeStream(s *Stream) { + t.mu.Lock() + delete(t.activeStreams, s.id) + t.mu.Unlock() + if q := s.fc.restoreConn(); q > 0 { + t.controlBuf.put(&windowUpdate{0, q}) + } + s.mu.Lock() + if s.state == streamDone { + s.mu.Unlock() + return + } + s.state = streamDone + s.mu.Unlock() + // In case stream sending and receiving are invoked in separate + // goroutines (e.g., bi-directional streaming), the caller needs + // to call cancel on the stream to interrupt the blocking on + // other goroutines. + s.cancel() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/transport/http_util.go b/Godeps/_workspace/src/google.golang.org/grpc/transport/http_util.go new file mode 100644 index 0000000..3370cc4 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/transport/http_util.go @@ -0,0 +1,454 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "bufio" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/bradfitz/http2" + "github.com/bradfitz/http2/hpack" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/metadata" +) + +const ( + // The primary user agent + primaryUA = "grpc-go/0.7" + // http2MaxFrameLen specifies the max length of a HTTP2 frame. + http2MaxFrameLen = 16384 // 16KB frame + // http://http2.github.io/http2-spec/#SettingValues + http2InitHeaderTableSize = 4096 + // http2IOBufSize specifies the buffer size for sending frames. + http2IOBufSize = 32 * 1024 +) + +var ( + clientPreface = []byte(http2.ClientPreface) + http2RSTErrConvTab = map[http2.ErrCode]codes.Code{ + http2.ErrCodeNo: codes.Internal, + http2.ErrCodeProtocol: codes.Internal, + http2.ErrCodeInternal: codes.Internal, + http2.ErrCodeFlowControl: codes.ResourceExhausted, + http2.ErrCodeSettingsTimeout: codes.Internal, + http2.ErrCodeFrameSize: codes.Internal, + http2.ErrCodeRefusedStream: codes.Unavailable, + http2.ErrCodeCancel: codes.Canceled, + http2.ErrCodeCompression: codes.Internal, + http2.ErrCodeConnect: codes.Internal, + http2.ErrCodeEnhanceYourCalm: codes.ResourceExhausted, + http2.ErrCodeInadequateSecurity: codes.PermissionDenied, + } + statusCodeConvTab = map[codes.Code]http2.ErrCode{ + codes.Internal: http2.ErrCodeInternal, + codes.Canceled: http2.ErrCodeCancel, + codes.Unavailable: http2.ErrCodeRefusedStream, + codes.ResourceExhausted: http2.ErrCodeEnhanceYourCalm, + codes.PermissionDenied: http2.ErrCodeInadequateSecurity, + } +) + +// Records the states during HPACK decoding. Must be reset once the +// decoding of the entire headers are finished. +type decodeState struct { + // statusCode caches the stream status received from the trailer + // the server sent. Client side only. + statusCode codes.Code + statusDesc string + // Server side only fields. + timeoutSet bool + timeout time.Duration + method string + // key-value metadata map from the peer. + mdata map[string][]string +} + +// An hpackDecoder decodes HTTP2 headers which may span multiple frames. +type hpackDecoder struct { + h *hpack.Decoder + state decodeState + err error // The err when decoding +} + +// A headerFrame is either a http2.HeaderFrame or http2.ContinuationFrame. +type headerFrame interface { + Header() http2.FrameHeader + HeaderBlockFragment() []byte + HeadersEnded() bool +} + +// isReservedHeader checks whether hdr belongs to HTTP2 headers +// reserved by gRPC protocol. Any other headers are classified as the +// user-specified metadata. +func isReservedHeader(hdr string) bool { + if hdr[0] == ':' { + return true + } + switch hdr { + case "content-type", + "grpc-message-type", + "grpc-encoding", + "grpc-message", + "grpc-status", + "grpc-timeout", + "te": + return true + default: + return false + } +} + +func newHPACKDecoder() *hpackDecoder { + d := &hpackDecoder{} + d.h = hpack.NewDecoder(http2InitHeaderTableSize, func(f hpack.HeaderField) { + switch f.Name { + case "content-type": + // TODO(zhaoq): Tentatively disable the check until a bug is fixed. + /* + if !strings.Contains(f.Value, "application/grpc") { + d.err = StreamErrorf(codes.FailedPrecondition, "transport: received the unexpected header") + return + } + */ + case "grpc-status": + code, err := strconv.Atoi(f.Value) + if err != nil { + d.err = StreamErrorf(codes.Internal, "transport: malformed grpc-status: %v", err) + return + } + d.state.statusCode = codes.Code(code) + case "grpc-message": + d.state.statusDesc = f.Value + case "grpc-timeout": + d.state.timeoutSet = true + var err error + d.state.timeout, err = timeoutDecode(f.Value) + if err != nil { + d.err = StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err) + return + } + case ":path": + d.state.method = f.Value + default: + if !isReservedHeader(f.Name) { + if f.Name == "user-agent" { + i := strings.LastIndex(f.Value, " ") + if i == -1 { + // There is no application user agent string being set. + return + } + // Extract the application user agent string. + f.Value = f.Value[:i] + } + if d.state.mdata == nil { + d.state.mdata = make(map[string][]string) + } + k, v, err := metadata.DecodeKeyValue(f.Name, f.Value) + if err != nil { + grpclog.Printf("Failed to decode (%q, %q): %v", f.Name, f.Value, err) + return + } + d.state.mdata[k] = append(d.state.mdata[k], v) + } + } + }) + return d +} + +func (d *hpackDecoder) decodeClientHTTP2Headers(frame headerFrame) (endHeaders bool, err error) { + d.err = nil + _, err = d.h.Write(frame.HeaderBlockFragment()) + if err != nil { + err = StreamErrorf(codes.Internal, "transport: HPACK header decode error: %v", err) + } + + if frame.HeadersEnded() { + if closeErr := d.h.Close(); closeErr != nil && err == nil { + err = StreamErrorf(codes.Internal, "transport: HPACK decoder close error: %v", closeErr) + } + endHeaders = true + } + + if err == nil && d.err != nil { + err = d.err + } + return +} + +func (d *hpackDecoder) decodeServerHTTP2Headers(frame headerFrame) (endHeaders bool, err error) { + d.err = nil + _, err = d.h.Write(frame.HeaderBlockFragment()) + if err != nil { + err = StreamErrorf(codes.Internal, "transport: HPACK header decode error: %v", err) + } + + if frame.HeadersEnded() { + if closeErr := d.h.Close(); closeErr != nil && err == nil { + err = StreamErrorf(codes.Internal, "transport: HPACK decoder close error: %v", closeErr) + } + endHeaders = true + } + + if err == nil && d.err != nil { + err = d.err + } + return +} + +type timeoutUnit uint8 + +const ( + hour timeoutUnit = 'H' + minute timeoutUnit = 'M' + second timeoutUnit = 'S' + millisecond timeoutUnit = 'm' + microsecond timeoutUnit = 'u' + nanosecond timeoutUnit = 'n' +) + +func timeoutUnitToDuration(u timeoutUnit) (d time.Duration, ok bool) { + switch u { + case hour: + return time.Hour, true + case minute: + return time.Minute, true + case second: + return time.Second, true + case millisecond: + return time.Millisecond, true + case microsecond: + return time.Microsecond, true + case nanosecond: + return time.Nanosecond, true + default: + } + return +} + +const maxTimeoutValue int64 = 100000000 - 1 + +// div does integer division and round-up the result. Note that this is +// equivalent to (d+r-1)/r but has less chance to overflow. +func div(d, r time.Duration) int64 { + if m := d % r; m > 0 { + return int64(d/r + 1) + } + return int64(d / r) +} + +// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it. +func timeoutEncode(t time.Duration) string { + if d := div(t, time.Nanosecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "n" + } + if d := div(t, time.Microsecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "u" + } + if d := div(t, time.Millisecond); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "m" + } + if d := div(t, time.Second); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "S" + } + if d := div(t, time.Minute); d <= maxTimeoutValue { + return strconv.FormatInt(d, 10) + "M" + } + // Note that maxTimeoutValue * time.Hour > MaxInt64. + return strconv.FormatInt(div(t, time.Hour), 10) + "H" +} + +func timeoutDecode(s string) (time.Duration, error) { + size := len(s) + if size < 2 { + return 0, fmt.Errorf("transport: timeout string is too short: %q", s) + } + unit := timeoutUnit(s[size-1]) + d, ok := timeoutUnitToDuration(unit) + if !ok { + return 0, fmt.Errorf("transport: timeout unit is not recognized: %q", s) + } + t, err := strconv.ParseInt(s[:size-1], 10, 64) + if err != nil { + return 0, err + } + return d * time.Duration(t), nil +} + +type framer struct { + numWriters int32 + reader io.Reader + writer *bufio.Writer + fr *http2.Framer +} + +func newFramer(conn net.Conn) *framer { + f := &framer{ + reader: conn, + writer: bufio.NewWriterSize(conn, http2IOBufSize), + } + f.fr = http2.NewFramer(f.writer, f.reader) + return f +} + +func (f *framer) adjustNumWriters(i int32) int32 { + return atomic.AddInt32(&f.numWriters, i) +} + +// The following writeXXX functions can only be called when the caller gets +// unblocked from writableChan channel (i.e., owns the privilege to write). + +func (f *framer) writeContinuation(forceFlush bool, streamID uint32, endHeaders bool, headerBlockFragment []byte) error { + if err := f.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeData(forceFlush bool, streamID uint32, endStream bool, data []byte) error { + if err := f.fr.WriteData(streamID, endStream, data); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeGoAway(forceFlush bool, maxStreamID uint32, code http2.ErrCode, debugData []byte) error { + if err := f.fr.WriteGoAway(maxStreamID, code, debugData); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeHeaders(forceFlush bool, p http2.HeadersFrameParam) error { + if err := f.fr.WriteHeaders(p); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writePing(forceFlush, ack bool, data [8]byte) error { + if err := f.fr.WritePing(ack, data); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writePriority(forceFlush bool, streamID uint32, p http2.PriorityParam) error { + if err := f.fr.WritePriority(streamID, p); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writePushPromise(forceFlush bool, p http2.PushPromiseParam) error { + if err := f.fr.WritePushPromise(p); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeRSTStream(forceFlush bool, streamID uint32, code http2.ErrCode) error { + if err := f.fr.WriteRSTStream(streamID, code); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeSettings(forceFlush bool, settings ...http2.Setting) error { + if err := f.fr.WriteSettings(settings...); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeSettingsAck(forceFlush bool) error { + if err := f.fr.WriteSettingsAck(); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) writeWindowUpdate(forceFlush bool, streamID, incr uint32) error { + if err := f.fr.WriteWindowUpdate(streamID, incr); err != nil { + return err + } + if forceFlush { + return f.writer.Flush() + } + return nil +} + +func (f *framer) flushWrite() error { + return f.writer.Flush() +} + +func (f *framer) readFrame() (http2.Frame, error) { + return f.fr.ReadFrame() +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/transport/http_util_test.go b/Godeps/_workspace/src/google.golang.org/grpc/transport/http_util_test.go new file mode 100644 index 0000000..b5b18bf --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/transport/http_util_test.go @@ -0,0 +1,87 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "fmt" + "testing" + "time" +) + +func TestTimeoutEncode(t *testing.T) { + for _, test := range []struct { + in string + out string + }{ + {"12345678ns", "12345678n"}, + {"123456789ns", "123457u"}, + {"12345678us", "12345678u"}, + {"123456789us", "123457m"}, + {"12345678ms", "12345678m"}, + {"123456789ms", "123457S"}, + {"12345678s", "12345678S"}, + {"123456789s", "2057614M"}, + {"12345678m", "12345678M"}, + {"123456789m", "2057614H"}, + } { + d, err := time.ParseDuration(test.in) + if err != nil { + t.Fatalf("failed to parse duration string %s: %v", test.in, err) + } + out := timeoutEncode(d) + if out != test.out { + t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out) + } + } +} + +func TestTimeoutDecode(t *testing.T) { + for _, test := range []struct { + // input + s string + // output + d time.Duration + err error + }{ + {"1234S", time.Second * 1234, nil}, + {"1234x", 0, fmt.Errorf("transport: timeout unit is not recognized: %q", "1234x")}, + {"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")}, + {"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")}, + } { + d, err := timeoutDecode(test.s) + if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) { + t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err) + } + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/transport/transport.go b/Godeps/_workspace/src/google.golang.org/grpc/transport/transport.go new file mode 100644 index 0000000..e20a6d9 --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/transport/transport.go @@ -0,0 +1,457 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +/* +Package transport defines and implements message oriented communication channel +to complete various transactions (e.g., an RPC). +*/ +package transport + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" +) + +// recvMsg represents the received msg from the transport. All transport +// protocol specific info has been removed. +type recvMsg struct { + data []byte + // nil: received some data + // io.EOF: stream is completed. data is nil. + // other non-nil error: transport failure. data is nil. + err error +} + +func (recvMsg) isItem() bool { + return true +} + +// All items in an out of a recvBuffer should be the same type. +type item interface { + isItem() bool +} + +// recvBuffer is an unbounded channel of item. +type recvBuffer struct { + c chan item + mu sync.Mutex + backlog []item +} + +func newRecvBuffer() *recvBuffer { + b := &recvBuffer{ + c: make(chan item, 1), + } + return b +} + +func (b *recvBuffer) put(r item) { + b.mu.Lock() + defer b.mu.Unlock() + b.backlog = append(b.backlog, r) + select { + case b.c <- b.backlog[0]: + b.backlog = b.backlog[1:] + default: + } +} + +func (b *recvBuffer) load() { + b.mu.Lock() + defer b.mu.Unlock() + if len(b.backlog) > 0 { + select { + case b.c <- b.backlog[0]: + b.backlog = b.backlog[1:] + default: + } + } +} + +// get returns the channel that receives an item in the buffer. +// +// Upon receipt of an item, the caller should call load to send another +// item onto the channel if there is any. +func (b *recvBuffer) get() <-chan item { + return b.c +} + +// recvBufferReader implements io.Reader interface to read the data from +// recvBuffer. +type recvBufferReader struct { + ctx context.Context + recv *recvBuffer + last *bytes.Reader // Stores the remaining data in the previous calls. + err error +} + +// Read reads the next len(p) bytes from last. If last is drained, it tries to +// read additional data from recv. It blocks if there no additional data available +// in recv. If Read returns any non-nil error, it will continue to return that error. +func (r *recvBufferReader) Read(p []byte) (n int, err error) { + if r.err != nil { + return 0, r.err + } + defer func() { r.err = err }() + if r.last != nil && r.last.Len() > 0 { + // Read remaining data left in last call. + return r.last.Read(p) + } + select { + case <-r.ctx.Done(): + return 0, ContextErr(r.ctx.Err()) + case i := <-r.recv.get(): + r.recv.load() + m := i.(*recvMsg) + if m.err != nil { + return 0, m.err + } + r.last = bytes.NewReader(m.data) + return r.last.Read(p) + } +} + +type streamState uint8 + +const ( + streamActive streamState = iota + streamWriteDone // EndStream sent + streamReadDone // EndStream received + streamDone // sendDone and recvDone or RSTStreamFrame is sent or received. +) + +// Stream represents an RPC in the transport layer. +type Stream struct { + id uint32 + // nil for client side Stream. + st ServerTransport + // ctx is the associated context of the stream. + ctx context.Context + cancel context.CancelFunc + // method records the associated RPC method of the stream. + method string + buf *recvBuffer + dec io.Reader + fc *inFlow + recvQuota uint32 + // The accumulated inbound quota pending for window update. + updateQuota uint32 + // The handler to control the window update procedure for both this + // particular stream and the associated transport. + windowHandler func(int) + + sendQuotaPool *quotaPool + // Close headerChan to indicate the end of reception of header metadata. + headerChan chan struct{} + // header caches the received header metadata. + header metadata.MD + // The key-value map of trailer metadata. + trailer metadata.MD + + mu sync.RWMutex // guard the following + // headerOK becomes true from the first header is about to send. + headerOk bool + state streamState + // true iff headerChan is closed. Used to avoid closing headerChan + // multiple times. + headerDone bool + // the status received from the server. + statusCode codes.Code + statusDesc string +} + +// Header acquires the key-value pairs of header metadata once it +// is available. It blocks until i) the metadata is ready or ii) there is no +// header metadata or iii) the stream is cancelled/expired. +func (s *Stream) Header() (metadata.MD, error) { + select { + case <-s.ctx.Done(): + return nil, ContextErr(s.ctx.Err()) + case <-s.headerChan: + return s.header.Copy(), nil + } +} + +// Trailer returns the cached trailer metedata. Note that if it is not called +// after the entire stream is done, it could return an empty MD. Client +// side only. +func (s *Stream) Trailer() metadata.MD { + s.mu.RLock() + defer s.mu.RUnlock() + return s.trailer.Copy() +} + +// ServerTransport returns the underlying ServerTransport for the stream. +// The client side stream always returns nil. +func (s *Stream) ServerTransport() ServerTransport { + return s.st +} + +// Context returns the context of the stream. +func (s *Stream) Context() context.Context { + return s.ctx +} + +// Method returns the method for the stream. +func (s *Stream) Method() string { + return s.method +} + +// StatusCode returns statusCode received from the server. +func (s *Stream) StatusCode() codes.Code { + return s.statusCode +} + +// StatusDesc returns statusDesc received from the server. +func (s *Stream) StatusDesc() string { + return s.statusDesc +} + +// ErrIllegalTrailerSet indicates that the trailer has already been set or it +// is too late to do so. +var ErrIllegalTrailerSet = errors.New("transport: trailer has been set") + +// SetTrailer sets the trailer metadata which will be sent with the RPC status +// by the server. This can only be called at most once. Server side only. +func (s *Stream) SetTrailer(md metadata.MD) error { + s.mu.Lock() + defer s.mu.Unlock() + if s.trailer != nil { + return ErrIllegalTrailerSet + } + s.trailer = md.Copy() + return nil +} + +func (s *Stream) write(m recvMsg) { + s.buf.put(&m) +} + +// Read reads all the data available for this Stream from the transport and +// passes them into the decoder, which converts them into a gRPC message stream. +// The error is io.EOF when the stream is done or another non-nil error if +// the stream broke. +func (s *Stream) Read(p []byte) (n int, err error) { + n, err = s.dec.Read(p) + if err != nil { + return + } + s.windowHandler(n) + return +} + +type key int + +// The key to save transport.Stream in the context. +const streamKey = key(0) + +// newContextWithStream creates a new context from ctx and attaches stream +// to it. +func newContextWithStream(ctx context.Context, stream *Stream) context.Context { + return context.WithValue(ctx, streamKey, stream) +} + +// StreamFromContext returns the stream saved in ctx. +func StreamFromContext(ctx context.Context) (s *Stream, ok bool) { + s, ok = ctx.Value(streamKey).(*Stream) + return +} + +// state of transport +type transportState int + +const ( + reachable transportState = iota + unreachable + closing +) + +// NewServerTransport creates a ServerTransport with conn or non-nil error +// if it fails. +func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo credentials.AuthInfo) (ServerTransport, error) { + return newHTTP2Server(conn, maxStreams, authInfo) +} + +// ConnectOptions covers all relevant options for dialing a server. +type ConnectOptions struct { + // UserAgent is the application user agent. + UserAgent string + // Dialer specifies how to dial a network address. + Dialer func(string, time.Duration) (net.Conn, error) + // AuthOptions stores the credentials required to setup a client connection and/or issue RPCs. + AuthOptions []credentials.Credentials + // Timeout specifies the timeout for dialing a client connection. + Timeout time.Duration +} + +// NewClientTransport establishes the transport with the required ConnectOptions +// and returns it to the caller. +func NewClientTransport(target string, opts *ConnectOptions) (ClientTransport, error) { + return newHTTP2Client(target, opts) +} + +// Options provides additional hints and information for message +// transmission. +type Options struct { + // Indicate whether it is the last piece for this stream. + Last bool + // The hint to transport impl whether the data could be buffered for + // batching write. Transport impl can feel free to ignore it. + Delay bool +} + +// CallHdr carries the information of a particular RPC. +type CallHdr struct { + Host string // peer host + Method string // the operation to perform on the specified host +} + +// ClientTransport is the common interface for all gRPC client side transport +// implementations. +type ClientTransport interface { + // Close tears down this transport. Once it returns, the transport + // should not be accessed any more. The caller must make sure this + // is called only once. + Close() error + + // Write sends the data for the given stream. A nil stream indicates + // the write is to be performed on the transport as a whole. + Write(s *Stream, data []byte, opts *Options) error + + // NewStream creates a Stream for an RPC. + NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error) + + // CloseStream clears the footprint of a stream when the stream is + // not needed any more. The err indicates the error incurred when + // CloseStream is called. Must be called when a stream is finished + // unless the associated transport is closing. + CloseStream(stream *Stream, err error) + + // Error returns a channel that is closed when some I/O error + // happens. Typically the caller should have a goroutine to monitor + // this in order to take action (e.g., close the current transport + // and create a new one) in error case. It should not return nil + // once the transport is initiated. + Error() <-chan struct{} +} + +// ServerTransport is the common interface for all gRPC server side transport +// implementations. +type ServerTransport interface { + // WriteStatus sends the status of a stream to the client. + WriteStatus(s *Stream, statusCode codes.Code, statusDesc string) error + // Write sends the data for the given stream. + Write(s *Stream, data []byte, opts *Options) error + // WriteHeader sends the header metedata for the given stream. + WriteHeader(s *Stream, md metadata.MD) error + // HandleStreams receives incoming streams using the given handler. + HandleStreams(func(*Stream)) + // Close tears down the transport. Once it is called, the transport + // should not be accessed any more. All the pending streams and their + // handlers will be terminated asynchronously. + Close() error +} + +// StreamErrorf creates an StreamError with the specified error code and description. +func StreamErrorf(c codes.Code, format string, a ...interface{}) StreamError { + return StreamError{ + Code: c, + Desc: fmt.Sprintf(format, a...), + } +} + +// ConnectionErrorf creates an ConnectionError with the specified error description. +func ConnectionErrorf(format string, a ...interface{}) ConnectionError { + return ConnectionError{ + Desc: fmt.Sprintf(format, a...), + } +} + +// ConnectionError is an error that results in the termination of the +// entire connection and the retry of all the active streams. +type ConnectionError struct { + Desc string +} + +func (e ConnectionError) Error() string { + return fmt.Sprintf("connection error: desc = %q", e.Desc) +} + +// Define some common ConnectionErrors. +var ErrConnClosing = ConnectionError{Desc: "transport is closing"} + +// StreamError is an error that only affects one stream within a connection. +type StreamError struct { + Code codes.Code + Desc string +} + +func (e StreamError) Error() string { + return fmt.Sprintf("stream error: code = %d desc = %q", e.Code, e.Desc) +} + +// ContextErr converts the error from context package into a StreamError. +func ContextErr(err error) StreamError { + switch err { + case context.DeadlineExceeded: + return StreamErrorf(codes.DeadlineExceeded, "%v", err) + case context.Canceled: + return StreamErrorf(codes.Canceled, "%v", err) + } + panic(fmt.Sprintf("Unexpected error from context packet: %v", err)) +} + +// wait blocks until it can receive from ctx.Done, closing, or proceed. +// If it receives from ctx.Done, it returns 0, the StreamError for ctx.Err. +// If it receives from closing, it returns 0, ErrConnClosing. +// If it receives from proceed, it returns the received integer, nil. +func wait(ctx context.Context, closing <-chan struct{}, proceed <-chan int) (int, error) { + select { + case <-ctx.Done(): + return 0, ContextErr(ctx.Err()) + case <-closing: + return 0, ErrConnClosing + case i := <-proceed: + return i, nil + } +} diff --git a/Godeps/_workspace/src/google.golang.org/grpc/transport/transport_test.go b/Godeps/_workspace/src/google.golang.org/grpc/transport/transport_test.go new file mode 100644 index 0000000..f0c01ec --- /dev/null +++ b/Godeps/_workspace/src/google.golang.org/grpc/transport/transport_test.go @@ -0,0 +1,591 @@ +/* + * + * Copyright 2014, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +package transport + +import ( + "bytes" + "io" + "math" + "net" + "reflect" + "strconv" + "sync" + "testing" + "time" + + "github.com/bradfitz/http2" + "golang.org/x/net/context" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/grpclog" +) + +type server struct { + lis net.Listener + port string + // channel to signal server is ready to serve. + readyChan chan bool + mu sync.Mutex + conns map[ServerTransport]bool +} + +var ( + expectedRequest = []byte("ping") + expectedResponse = []byte("pong") + expectedRequestLarge = make([]byte, initialWindowSize*2) + expectedResponseLarge = make([]byte, initialWindowSize*2) +) + +type testStreamHandler struct { + t ServerTransport +} + +type hType int + +const ( + normal hType = iota + suspended + misbehaved +) + +func (h *testStreamHandler) handleStream(s *Stream) { + req := expectedRequest + resp := expectedResponse + if s.Method() == "foo.Large" { + req = expectedRequestLarge + resp = expectedResponseLarge + } + p := make([]byte, len(req)) + _, err := io.ReadFull(s, p) + if err != nil || !bytes.Equal(p, req) { + if err == ErrConnClosing { + return + } + grpclog.Fatalf("handleStream got error: %v, want ; result: %v, want %v", err, p, req) + } + // send a response back to the client. + h.t.Write(s, resp, &Options{}) + // send the trailer to end the stream. + h.t.WriteStatus(s, codes.OK, "") +} + +// handleStreamSuspension blocks until s.ctx is canceled. +func (h *testStreamHandler) handleStreamSuspension(s *Stream) { + <-s.ctx.Done() +} + +func (h *testStreamHandler) handleStreamMisbehave(s *Stream) { + conn, ok := s.ServerTransport().(*http2Server) + if !ok { + grpclog.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport()) + } + size := 1 + if s.Method() == "foo.MaxFrame" { + size = http2MaxFrameLen + } + // Drain the client side stream flow control window. + var sent int + for sent <= initialWindowSize { + <-conn.writableChan + if err := conn.framer.writeData(true, s.id, false, make([]byte, size)); err != nil { + conn.writableChan <- 0 + break + } + conn.writableChan <- 0 + sent += size + } +} + +// start starts server. Other goroutines should block on s.readyChan for futher operations. +func (s *server) start(port int, maxStreams uint32, ht hType) { + var err error + if port == 0 { + s.lis, err = net.Listen("tcp", ":0") + } else { + s.lis, err = net.Listen("tcp", ":"+strconv.Itoa(port)) + } + if err != nil { + grpclog.Fatalf("failed to listen: %v", err) + } + _, p, err := net.SplitHostPort(s.lis.Addr().String()) + if err != nil { + grpclog.Fatalf("failed to parse listener address: %v", err) + } + s.port = p + s.conns = make(map[ServerTransport]bool) + if s.readyChan != nil { + close(s.readyChan) + } + for { + conn, err := s.lis.Accept() + if err != nil { + return + } + t, err := NewServerTransport("http2", conn, maxStreams, nil) + if err != nil { + return + } + s.mu.Lock() + if s.conns == nil { + s.mu.Unlock() + t.Close() + return + } + s.conns[t] = true + s.mu.Unlock() + h := &testStreamHandler{t} + switch ht { + case suspended: + go t.HandleStreams(h.handleStreamSuspension) + case misbehaved: + go t.HandleStreams(h.handleStreamMisbehave) + default: + go t.HandleStreams(h.handleStream) + } + } +} + +func (s *server) wait(t *testing.T, timeout time.Duration) { + select { + case <-s.readyChan: + case <-time.After(timeout): + t.Fatalf("Timed out after %v waiting for server to be ready", timeout) + } +} + +func (s *server) stop() { + s.lis.Close() + s.mu.Lock() + for c := range s.conns { + c.Close() + } + s.conns = nil + s.mu.Unlock() +} + +func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { + server := &server{readyChan: make(chan bool)} + go server.start(port, maxStreams, ht) + server.wait(t, 2*time.Second) + addr := "localhost:" + server.port + var ( + ct ClientTransport + connErr error + ) + ct, connErr = NewClientTransport(addr, &ConnectOptions{}) + if connErr != nil { + t.Fatalf("failed to create transport: %v", connErr) + } + return server, ct +} + +func TestClientSendAndReceive(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, normal) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo.Small", + } + s1, err1 := ct.NewStream(context.Background(), callHdr) + if err1 != nil { + t.Fatalf("failed to open stream: %v", err1) + } + if s1.id != 1 { + t.Fatalf("wrong stream id: %d", s1.id) + } + s2, err2 := ct.NewStream(context.Background(), callHdr) + if err2 != nil { + t.Fatalf("failed to open stream: %v", err2) + } + if s2.id != 3 { + t.Fatalf("wrong stream id: %d", s2.id) + } + opts := Options{ + Last: true, + Delay: false, + } + if err := ct.Write(s1, expectedRequest, &opts); err != nil { + t.Fatalf("failed to send data: %v", err) + } + p := make([]byte, len(expectedResponse)) + _, recvErr := io.ReadFull(s1, p) + if recvErr != nil || !bytes.Equal(p, expectedResponse) { + t.Fatalf("Error: %v, want ; Result: %v, want %v", recvErr, p, expectedResponse) + } + _, recvErr = io.ReadFull(s1, p) + if recvErr != io.EOF { + t.Fatalf("Error: %v; want ", recvErr) + } + ct.Close() + server.stop() +} + +func TestClientErrorNotify(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, normal) + go server.stop() + // ct.reader should detect the error and activate ct.Error(). + <-ct.Error() + ct.Close() +} + +func performOneRPC(ct ClientTransport) { + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo.Small", + } + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + return + } + opts := Options{ + Last: true, + Delay: false, + } + if err := ct.Write(s, expectedRequest, &opts); err == nil { + time.Sleep(5 * time.Millisecond) + // The following s.Recv()'s could error out because the + // underlying transport is gone. + // + // Read response + p := make([]byte, len(expectedResponse)) + io.ReadFull(s, p) + // Read io.EOF + io.ReadFull(s, p) + } +} + +func TestClientMix(t *testing.T) { + s, ct := setUp(t, 0, math.MaxUint32, normal) + go func(s *server) { + time.Sleep(5 * time.Second) + s.stop() + }(s) + go func(ct ClientTransport) { + <-ct.Error() + ct.Close() + }(ct) + for i := 0; i < 1000; i++ { + time.Sleep(10 * time.Millisecond) + go performOneRPC(ct) + } +} + +func TestLargeMessage(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, normal) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo.Large", + } + var wg sync.WaitGroup + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + t.Errorf("failed to open stream: %v", err) + } + if err := ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil { + t.Errorf("failed to send data: %v", err) + } + p := make([]byte, len(expectedResponseLarge)) + _, recvErr := io.ReadFull(s, p) + if recvErr != nil || !bytes.Equal(p, expectedResponseLarge) { + t.Errorf("Error: %v, want ; Result len: %d, want len %d", recvErr, len(p), len(expectedResponseLarge)) + } + _, recvErr = io.ReadFull(s, p) + if recvErr != io.EOF { + t.Errorf("Error: %v; want ", recvErr) + } + }() + } + wg.Wait() + ct.Close() + server.stop() +} + +func TestLargeMessageSuspension(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, suspended) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo.Large", + } + // Set a long enough timeout for writing a large message out. + ctx, _ := context.WithTimeout(context.Background(), time.Second) + s, err := ct.NewStream(ctx, callHdr) + if err != nil { + t.Fatalf("failed to open stream: %v", err) + } + // Write should not be done successfully due to flow control. + err = ct.Write(s, expectedRequestLarge, &Options{Last: true, Delay: false}) + expectedErr := StreamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) + if err == nil || err != expectedErr { + t.Fatalf("Write got %v, want %v", err, expectedErr) + } + ct.Close() + server.stop() +} + +func TestMaxStreams(t *testing.T) { + server, ct := setUp(t, 0, 1, suspended) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo.Large", + } + // Have a pending stream which takes all streams quota. + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + cc, ok := ct.(*http2Client) + if !ok { + t.Fatalf("Failed to convert %v to *http2Client", ct) + } + done := make(chan struct{}) + ch := make(chan int) + go func() { + for { + select { + case <-time.After(5 * time.Millisecond): + ch <- 0 + case <-time.After(5 * time.Second): + close(done) + return + } + } + }() + for { + select { + case <-ch: + case <-done: + t.Fatalf("Client has not received the max stream setting in 5 seconds.") + } + cc.mu.Lock() + // cc.streamsQuota should be initialized once receiving the 1st setting frame from + // the server. + if cc.streamsQuota != nil { + cc.mu.Unlock() + select { + case <-cc.streamsQuota.acquire(): + t.Fatalf("streamsQuota.acquire() becomes readable mistakenly.") + default: + if cc.streamsQuota.quota != 0 { + t.Fatalf("streamsQuota.quota got non-zero quota mistakenly.") + } + } + break + } + cc.mu.Unlock() + } + // Close the pending stream so that the streams quota becomes available for the next new stream. + ct.CloseStream(s, nil) + select { + case i := <-cc.streamsQuota.acquire(): + if i != 1 { + t.Fatalf("streamsQuota.acquire() got %d quota, want 1.", i) + } + cc.streamsQuota.add(i) + default: + t.Fatalf("streamsQuota.acquire() is not readable.") + } + if _, err := ct.NewStream(context.Background(), callHdr); err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + ct.Close() + server.stop() +} + +func TestServerWithMisbehavedClient(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, suspended) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo", + } + var sc *http2Server + // Wait until the server transport is setup. + for { + server.mu.Lock() + if len(server.conns) == 0 { + server.mu.Unlock() + time.Sleep(time.Millisecond) + continue + } + for k := range server.conns { + var ok bool + sc, ok = k.(*http2Server) + if !ok { + t.Fatalf("Failed to convert %v to *http2Server", k) + } + } + server.mu.Unlock() + break + } + cc, ok := ct.(*http2Client) + if !ok { + t.Fatalf("Failed to convert %v to *http2Client", ct) + } + // Test server behavior for violation of stream flow control window size restriction. + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + var sent int + // Drain the stream flow control window + <-cc.writableChan + if err = cc.framer.writeData(true, s.id, false, make([]byte, http2MaxFrameLen)); err != nil { + t.Fatalf("Failed to write data: %v", err) + } + cc.writableChan <- 0 + sent += http2MaxFrameLen + // Wait until the server creates the corresponding stream and receive some data. + var ss *Stream + for { + time.Sleep(time.Millisecond) + sc.mu.Lock() + if len(sc.activeStreams) == 0 { + sc.mu.Unlock() + continue + } + ss = sc.activeStreams[s.id] + sc.mu.Unlock() + ss.fc.mu.Lock() + if ss.fc.pendingData > 0 { + ss.fc.mu.Unlock() + break + } + ss.fc.mu.Unlock() + } + if ss.fc.pendingData != http2MaxFrameLen || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != http2MaxFrameLen || sc.fc.pendingUpdate != 0 { + t.Fatalf("Server mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, http2MaxFrameLen, 0, http2MaxFrameLen, 0) + } + // Keep sending until the server inbound window is drained for that stream. + for sent <= initialWindowSize { + <-cc.writableChan + if err = cc.framer.writeData(true, s.id, false, make([]byte, 1)); err != nil { + t.Fatalf("Failed to write data: %v", err) + } + cc.writableChan <- 0 + sent++ + } + // Server sent a resetStream for s already. + code := http2RSTErrConvTab[http2.ErrCodeFlowControl] + if _, err := io.ReadFull(s, make([]byte, 1)); err != io.EOF || s.statusCode != code { + t.Fatalf("%v got err %v with statusCode %d, want err with statusCode %d", s, err, s.statusCode, code) + } + + if ss.fc.pendingData != 0 || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate != initialWindowSize { + t.Fatalf("Server mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, %d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, initialWindowSize) + } + ct.CloseStream(s, nil) + // Test server behavior for violation of connection flow control window size restriction. + // + // Keep creating new streams until the connection window is drained on the server and + // the server tears down the connection. + for { + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + // The server tears down the connection. + break + } + <-cc.writableChan + cc.framer.writeData(true, s.id, true, make([]byte, http2MaxFrameLen)) + cc.writableChan <- 0 + } + ct.Close() + server.stop() +} + +func TestClientWithMisbehavedServer(t *testing.T) { + server, ct := setUp(t, 0, math.MaxUint32, misbehaved) + callHdr := &CallHdr{ + Host: "localhost", + Method: "foo", + } + conn, ok := ct.(*http2Client) + if !ok { + t.Fatalf("Failed to convert %v to *http2Client", ct) + } + // Test the logic for the violation of stream flow control window size restriction. + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil { + t.Fatalf("Failed to write: %v", err) + } + // Read without window update. + for { + p := make([]byte, http2MaxFrameLen) + if _, err = s.dec.Read(p); err != nil { + break + } + } + if s.fc.pendingData != initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData != initialWindowSize || conn.fc.pendingUpdate != 0 { + t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, initialWindowSize, 0) + } + if err != io.EOF || s.statusCode != codes.Internal { + t.Fatalf("Got err %v and the status code %d, want and the code %d", err, s.statusCode, codes.Internal) + } + conn.CloseStream(s, err) + if s.fc.pendingData != 0 || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate != initialWindowSize { + t.Fatalf("Client mistakenly resets inbound flow control params: got %d, %d, %d, %d; want 0, 0, 0, %d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize) + } + // Test the logic for the violation of the connection flow control window size restriction. + // + // Generate enough streams to drain the connection window. + callHdr = &CallHdr{ + Host: "localhost", + Method: "foo.MaxFrame", + } + for i := 0; i < int(initialConnWindowSize/initialWindowSize+10); i++ { + s, err := ct.NewStream(context.Background(), callHdr) + if err != nil { + break + } + if err := ct.Write(s, expectedRequest, &Options{Last: true, Delay: false}); err != nil { + break + } + } + // http2Client.errChan is closed due to connection flow control window size violation. + <-conn.Error() + ct.Close() + server.stop() +} + +func TestStreamContext(t *testing.T) { + expectedStream := Stream{} + ctx := newContextWithStream(context.Background(), &expectedStream) + s, ok := StreamFromContext(ctx) + if !ok || !reflect.DeepEqual(expectedStream, *s) { + t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, *s, ok, expectedStream) + } +}