You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
238 lines
6.2 KiB
238 lines
6.2 KiB
package imds
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
smithy "github.com/aws/smithy-go"
|
|
"github.com/aws/smithy-go/middleware"
|
|
smithyhttp "github.com/aws/smithy-go/transport/http"
|
|
)
|
|
|
|
// tokenHeader is the HTTP request header that carries the API token value on
// operation calls.
const tokenHeader = "x-aws-ec2-metadata-token"

// defaultTokenTTL is a five minute time-to-live for API tokens. Its consumers
// are outside this section of the file.
const defaultTokenTTL = 5 * time.Minute
|
|
|
|
type tokenProvider struct {
|
|
client *Client
|
|
tokenTTL time.Duration
|
|
|
|
token *apiToken
|
|
tokenMux sync.RWMutex
|
|
|
|
disabled uint32 // Atomic updated
|
|
}
|
|
|
|
func newTokenProvider(client *Client, ttl time.Duration) *tokenProvider {
|
|
return &tokenProvider{
|
|
client: client,
|
|
tokenTTL: ttl,
|
|
}
|
|
}
|
|
|
|
// apiToken is the API token attached to operation calls made to the EC2
// Instance Metadata Service, together with the time at which it expires.
type apiToken struct {
	token   string
	expires time.Time
}

// timeNow returns the current time. It is held in a variable rather than
// calling time.Now directly, presumably so tests can substitute the clock.
var timeNow = time.Now

// Expired reports whether the token's expiry time lies strictly before the
// current time.
func (t *apiToken) Expired() bool {
	// Round(0) strips the monotonic clock reading from the current time, so
	// the comparison below is always made on wall-clock time.
	wallNow := timeNow().Round(0)
	return t.expires.Before(wallNow)
}
|
|
|
|
// ID returns the fixed identifier of the token provider middleware.
func (t *tokenProvider) ID() string {
	const middlewareID = "APITokenProvider"
	return middlewareID
}
|
|
|
|
// HandleFinalize is the finalize stack middleware, that if the token provider is
|
|
// enabled, will attempt to add the cached API token to the request. If the API
|
|
// token is not cached, it will be retrieved in a separate API call, getToken.
|
|
//
|
|
// For retry attempts, handler must be added after attempt retryer.
|
|
//
|
|
// If request for getToken fails the token provider may be disabled from future
|
|
// requests, depending on the response status code.
|
|
func (t *tokenProvider) HandleFinalize(
|
|
ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
|
|
) (
|
|
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
|
) {
|
|
if !t.enabled() {
|
|
// short-circuits to insecure data flow if token provider is disabled.
|
|
return next.HandleFinalize(ctx, input)
|
|
}
|
|
|
|
req, ok := input.Request.(*smithyhttp.Request)
|
|
if !ok {
|
|
return out, metadata, fmt.Errorf("unexpected transport request type %T", input.Request)
|
|
}
|
|
|
|
tok, err := t.getToken(ctx)
|
|
if err != nil {
|
|
// If the error allows the token to downgrade to insecure flow allow that.
|
|
var bypassErr *bypassTokenRetrievalError
|
|
if errors.As(err, &bypassErr) {
|
|
return next.HandleFinalize(ctx, input)
|
|
}
|
|
|
|
return out, metadata, fmt.Errorf("failed to get API token, %w", err)
|
|
}
|
|
|
|
req.Header.Set(tokenHeader, tok.token)
|
|
|
|
return next.HandleFinalize(ctx, input)
|
|
}
|
|
|
|
// HandleDeserialize is the deserialize stack middleware for determining if the
|
|
// operation the token provider is decorating failed because of a 401
|
|
// unauthorized status code. If the operation failed for that reason the token
|
|
// provider needs to be re-enabled so that it can start adding the API token to
|
|
// operation calls.
|
|
func (t *tokenProvider) HandleDeserialize(
|
|
ctx context.Context, input middleware.DeserializeInput, next middleware.DeserializeHandler,
|
|
) (
|
|
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
|
|
) {
|
|
out, metadata, err = next.HandleDeserialize(ctx, input)
|
|
if err == nil {
|
|
return out, metadata, err
|
|
}
|
|
|
|
resp, ok := out.RawResponse.(*smithyhttp.Response)
|
|
if !ok {
|
|
return out, metadata, fmt.Errorf("expect HTTP transport, got %T", out.RawResponse)
|
|
}
|
|
|
|
if resp.StatusCode == http.StatusUnauthorized { // unauthorized
|
|
err = &retryableError{Err: err}
|
|
t.enable()
|
|
}
|
|
|
|
return out, metadata, err
|
|
}
|
|
|
|
// retryableError wraps an error and marks it as retryable through its
// RetryableError method. Err must be non-nil; Error calls through to it.
type retryableError struct {
	Err error
}

// RetryableError always reports true.
func (*retryableError) RetryableError() bool { return true }

// Error returns the wrapped error's message unchanged.
func (e *retryableError) Error() string { return e.Err.Error() }

// Unwrap returns the wrapped error so that errors.Is and errors.As can see
// through the retryable marker to the underlying cause.
func (e *retryableError) Unwrap() error { return e.Err }
|
|
|
|
func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
|
|
if !t.enabled() {
|
|
return nil, &bypassTokenRetrievalError{
|
|
Err: fmt.Errorf("cannot get API token, provider disabled"),
|
|
}
|
|
}
|
|
|
|
t.tokenMux.RLock()
|
|
tok = t.token
|
|
t.tokenMux.RUnlock()
|
|
|
|
if tok != nil && !tok.Expired() {
|
|
return tok, nil
|
|
}
|
|
|
|
tok, err = t.updateToken(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot get API token, %w", err)
|
|
}
|
|
|
|
return tok, nil
|
|
}
|
|
|
|
func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
|
|
t.tokenMux.Lock()
|
|
defer t.tokenMux.Unlock()
|
|
|
|
// Prevent multiple requests to update retrieving the token.
|
|
if t.token != nil && !t.token.Expired() {
|
|
tok := t.token
|
|
return tok, nil
|
|
}
|
|
|
|
result, err := t.client.getToken(ctx, &getTokenInput{
|
|
TokenTTL: t.tokenTTL,
|
|
})
|
|
if err != nil {
|
|
// change the disabled flag on token provider to true, when error is request timeout error.
|
|
var statusErr interface{ HTTPStatusCode() int }
|
|
if errors.As(err, &statusErr) {
|
|
switch statusErr.HTTPStatusCode() {
|
|
|
|
// Disable get token if failed because of 403, 404, or 405
|
|
case http.StatusForbidden,
|
|
http.StatusNotFound,
|
|
http.StatusMethodNotAllowed:
|
|
|
|
t.disable()
|
|
|
|
// 400 errors are terminal, and need to be upstreamed
|
|
case http.StatusBadRequest:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Disable if request send failed or timed out getting response
|
|
var re *smithyhttp.RequestSendError
|
|
var ce *smithy.CanceledError
|
|
if errors.As(err, &re) || errors.As(err, &ce) {
|
|
atomic.StoreUint32(&t.disabled, 1)
|
|
}
|
|
|
|
// Token couldn't be retrieved, but bypass this, and allow the
|
|
// request to continue.
|
|
return nil, &bypassTokenRetrievalError{Err: err}
|
|
}
|
|
|
|
tok := &apiToken{
|
|
token: result.Token,
|
|
expires: timeNow().Add(result.TokenTTL),
|
|
}
|
|
t.token = tok
|
|
|
|
return tok, nil
|
|
}
|
|
|
|
// bypassTokenRetrievalError marks a token retrieval failure that the caller
// may ignore, letting the request continue without an API token.
type bypassTokenRetrievalError struct {
	Err error
}

// Error returns the wrapped error's text prefixed with a note that token
// retrieval is being bypassed.
func (b *bypassTokenRetrievalError) Error() string {
	cause := b.Err
	return fmt.Sprintf("bypass token retrieval, %v", cause)
}

// Unwrap returns the error that caused token retrieval to be bypassed.
func (b *bypassTokenRetrievalError) Unwrap() error {
	return b.Err
}
|
|
|
|
// enabled returns if the token provider is current enabled or not.
|
|
func (t *tokenProvider) enabled() bool {
|
|
return atomic.LoadUint32(&t.disabled) == 0
|
|
}
|
|
|
|
// disable disables the token provider and it will no longer attempt to inject
|
|
// the token, nor request updates.
|
|
func (t *tokenProvider) disable() {
|
|
atomic.StoreUint32(&t.disabled, 1)
|
|
}
|
|
|
|
// enable enables the token provide to start refreshing tokens, and adding them
|
|
// to the pending request.
|
|
func (t *tokenProvider) enable() {
|
|
t.tokenMux.Lock()
|
|
t.token = nil
|
|
t.tokenMux.Unlock()
|
|
atomic.StoreUint32(&t.disabled, 0)
|
|
}
|