You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
2.6 KiB
85 lines
2.6 KiB
package http
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/aws/smithy-go/middleware"
|
|
)
|
|
|
|
// ComputeContentLength provides a middleware to set the content-length
// header for the length of a serialized request body.
type ComputeContentLength struct{}
|
|
|
|
// AddComputeContentLengthMiddleware adds ComputeContentLength to the middleware
|
|
// stack's Build step.
|
|
func AddComputeContentLengthMiddleware(stack *middleware.Stack) error {
|
|
return stack.Build.Add(&ComputeContentLength{}, middleware.After)
|
|
}
|
|
|
|
// ID returns the identifier for the ComputeContentLength.
|
|
func (m *ComputeContentLength) ID() string { return "ComputeContentLength" }
|
|
|
|
// HandleBuild adds the length of the serialized request to the HTTP header
|
|
// if the length can be determined.
|
|
func (m *ComputeContentLength) HandleBuild(
|
|
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
|
|
) (
|
|
out middleware.BuildOutput, metadata middleware.Metadata, err error,
|
|
) {
|
|
req, ok := in.Request.(*Request)
|
|
if !ok {
|
|
return out, metadata, fmt.Errorf("unknown request type %T", req)
|
|
}
|
|
|
|
// do nothing if request content-length was set to 0 or above.
|
|
if req.ContentLength >= 0 {
|
|
return next.HandleBuild(ctx, in)
|
|
}
|
|
|
|
// attempt to compute stream length
|
|
if n, ok, err := req.StreamLength(); err != nil {
|
|
return out, metadata, fmt.Errorf(
|
|
"failed getting length of request stream, %w", err)
|
|
} else if ok {
|
|
req.ContentLength = n
|
|
}
|
|
|
|
return next.HandleBuild(ctx, in)
|
|
}
|
|
|
|
// validateContentLength provides a middleware that validates that the
// content-length of the serialized request payload has been set to a value
// of zero or greater.
type validateContentLength struct{}
|
|
|
|
// ValidateContentLengthHeader adds middleware that validates request content-length
|
|
// is set to value greater than zero.
|
|
func ValidateContentLengthHeader(stack *middleware.Stack) error {
|
|
return stack.Build.Add(&validateContentLength{}, middleware.After)
|
|
}
|
|
|
|
// ID returns the identifier for the ComputeContentLength.
|
|
func (m *validateContentLength) ID() string { return "ValidateContentLength" }
|
|
|
|
// HandleBuild adds the length of the serialized request to the HTTP header
|
|
// if the length can be determined.
|
|
func (m *validateContentLength) HandleBuild(
|
|
ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler,
|
|
) (
|
|
out middleware.BuildOutput, metadata middleware.Metadata, err error,
|
|
) {
|
|
req, ok := in.Request.(*Request)
|
|
if !ok {
|
|
return out, metadata, fmt.Errorf("unknown request type %T", req)
|
|
}
|
|
|
|
// if request content-length was set to less than 0, return an error
|
|
if req.ContentLength < 0 {
|
|
return out, metadata, fmt.Errorf(
|
|
"content length for payload is required and must be at least 0")
|
|
}
|
|
|
|
return next.HandleBuild(ctx, in)
|
|
}
|