You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
2.7 KiB
105 lines
2.7 KiB
package http
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/aws/smithy-go"
|
|
"github.com/aws/smithy-go/middleware"
|
|
smithyhttp "github.com/aws/smithy-go/transport/http"
|
|
)
|
|
|
|
// readResult carries the outcome of a single Read call on the wrapped reader
// from the goroutine performing the read back to timeoutReadCloser.Read.
type readResult struct {
	n   int
	err error
}

// ResponseTimeoutError is an error when a read from the response body is
// delayed longer than the timeout the reader was configured with.
type ResponseTimeoutError struct {
	// TimeoutDur is the per-read timeout that was exceeded.
	TimeoutDur time.Duration
}

// Timeout reports that the error was caused by a timeout and can be retried.
// It always returns true.
func (*ResponseTimeoutError) Timeout() bool { return true }

// Error implements the error interface, including the exceeded timeout in
// the message.
func (e *ResponseTimeoutError) Error() string {
	return fmt.Sprintf("read on body reach timeout limit, %v", e.TimeoutDur)
}

// timeoutReadCloser wraps a response body so that each individual Read call
// taking longer than duration fails with a *ResponseTimeoutError.
// A non-positive duration disables the timeout: reads pass straight through
// to the wrapped reader.
type timeoutReadCloser struct {
	reader   io.ReadCloser
	duration time.Duration
}

// Read calls the wrapped reader's Read in a separate goroutine and waits for
// either that call to finish or the timer to fire, whichever happens first.
// On timeout it returns 0 and a *ResponseTimeoutError.
//
// NOTE: when a timeout is reported, the spawned goroutine is not cancelled;
// it keeps running until the underlying Read returns and may still write into
// b after this method has returned. After a timeout the body should be
// treated as unusable (presumably closed by the caller) rather than read
// again with the same buffer.
func (r *timeoutReadCloser) Read(b []byte) (int, error) {
	// A zero or negative timer would fire immediately and fail nearly every
	// read, so treat it as "no timeout".
	if r.duration <= 0 {
		return r.reader.Read(b)
	}

	timer := time.NewTimer(r.duration)
	defer timer.Stop()

	// Buffered so the reading goroutine can always deliver its result and
	// exit, even if nobody is waiting for it any more after a timeout.
	results := make(chan readResult, 1)
	go func() {
		n, err := r.reader.Read(b)
		results <- readResult{n: n, err: err}
	}()

	select {
	case res := <-results:
		return res.n, res.err
	case <-timer.C:
		// The read may have finished at the same instant the timer fired.
		// select chooses randomly among ready cases, so prefer a completed
		// read over reporting a timeout and silently dropping bytes already
		// consumed from the underlying reader.
		select {
		case res := <-results:
			return res.n, res.err
		default:
			return 0, &ResponseTimeoutError{TimeoutDur: r.duration}
		}
	}
}

// Close closes the wrapped reader and returns its error.
func (r *timeoutReadCloser) Close() error {
	return r.reader.Close()
}
|
|
|
|
// AddResponseReadTimeoutMiddleware adds a middleware to the stack that wraps the
|
|
// response body so that a read that takes too long will return an error.
|
|
func AddResponseReadTimeoutMiddleware(stack *middleware.Stack, duration time.Duration) error {
|
|
return stack.Deserialize.Add(&readTimeout{duration: duration}, middleware.After)
|
|
}
|
|
|
|
// readTimeout wraps the response body with a timeoutReadCloser
|
|
type readTimeout struct {
|
|
duration time.Duration
|
|
}
|
|
|
|
// ID returns the id of the middleware
|
|
func (*readTimeout) ID() string {
|
|
return "ReadResponseTimeout"
|
|
}
|
|
|
|
// HandleDeserialize implements the DeserializeMiddleware interface
|
|
func (m *readTimeout) HandleDeserialize(
|
|
ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
|
|
) (
|
|
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
|
|
) {
|
|
out, metadata, err = next.HandleDeserialize(ctx, in)
|
|
if err != nil {
|
|
return out, metadata, err
|
|
}
|
|
|
|
response, ok := out.RawResponse.(*smithyhttp.Response)
|
|
if !ok {
|
|
return out, metadata, &smithy.DeserializationError{Err: fmt.Errorf("unknown transport type %T", out.RawResponse)}
|
|
}
|
|
|
|
response.Body = &timeoutReadCloser{
|
|
reader: response.Body,
|
|
duration: m.duration,
|
|
}
|
|
out.RawResponse = response
|
|
|
|
return out, metadata, err
|
|
}
|