parent
0dce07f1ee
commit
7371dad756
@ -1 +0,0 @@
|
||||
* text=auto eol=lf
|
@ -1,10 +0,0 @@
|
||||
.vscode/
|
||||
|
||||
*.exe
|
||||
|
||||
# testing
|
||||
testdata
|
||||
|
||||
# go workspaces
|
||||
go.work
|
||||
go.work.sum
|
@ -1,144 +0,0 @@
|
||||
run:
|
||||
skip-dirs:
|
||||
- pkg/etw/sample
|
||||
|
||||
linters:
|
||||
enable:
|
||||
# style
|
||||
- containedctx # struct contains a context
|
||||
- dupl # duplicate code
|
||||
- errname # erorrs are named correctly
|
||||
- goconst # strings that should be constants
|
||||
- godot # comments end in a period
|
||||
- misspell
|
||||
- nolintlint # "//nolint" directives are properly explained
|
||||
- revive # golint replacement
|
||||
- stylecheck # golint replacement, less configurable than revive
|
||||
- unconvert # unnecessary conversions
|
||||
- wastedassign
|
||||
|
||||
# bugs, performance, unused, etc ...
|
||||
- contextcheck # function uses a non-inherited context
|
||||
- errorlint # errors not wrapped for 1.13
|
||||
- exhaustive # check exhaustiveness of enum switch statements
|
||||
- gofmt # files are gofmt'ed
|
||||
- gosec # security
|
||||
- nestif # deeply nested ifs
|
||||
- nilerr # returns nil even with non-nil error
|
||||
- prealloc # slices that can be pre-allocated
|
||||
- structcheck # unused struct fields
|
||||
- unparam # unused function params
|
||||
|
||||
issues:
|
||||
exclude-rules:
|
||||
# err is very often shadowed in nested scopes
|
||||
- linters:
|
||||
- govet
|
||||
text: '^shadow: declaration of "err" shadows declaration'
|
||||
|
||||
# ignore long lines for skip autogen directives
|
||||
- linters:
|
||||
- revive
|
||||
text: "^line-length-limit: "
|
||||
source: "^//(go:generate|sys) "
|
||||
|
||||
# allow unjustified ignores of error checks in defer statements
|
||||
- linters:
|
||||
- nolintlint
|
||||
text: "^directive `//nolint:errcheck` should provide explanation"
|
||||
source: '^\s*defer '
|
||||
|
||||
# allow unjustified ignores of error lints for io.EOF
|
||||
- linters:
|
||||
- nolintlint
|
||||
text: "^directive `//nolint:errorlint` should provide explanation"
|
||||
source: '[=|!]= io.EOF'
|
||||
|
||||
|
||||
linters-settings:
|
||||
govet:
|
||||
enable-all: true
|
||||
disable:
|
||||
# struct order is often for Win32 compat
|
||||
# also, ignore pointer bytes/GC issues for now until performance becomes an issue
|
||||
- fieldalignment
|
||||
check-shadowing: true
|
||||
nolintlint:
|
||||
allow-leading-space: false
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
revive:
|
||||
# revive is more configurable than static check, so likely the preferred alternative to static-check
|
||||
# (once the perf issue is solved: https://github.com/golangci/golangci-lint/issues/2997)
|
||||
enable-all-rules:
|
||||
true
|
||||
# https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md
|
||||
rules:
|
||||
# rules with required arguments
|
||||
- name: argument-limit
|
||||
disabled: true
|
||||
- name: banned-characters
|
||||
disabled: true
|
||||
- name: cognitive-complexity
|
||||
disabled: true
|
||||
- name: cyclomatic
|
||||
disabled: true
|
||||
- name: file-header
|
||||
disabled: true
|
||||
- name: function-length
|
||||
disabled: true
|
||||
- name: function-result-limit
|
||||
disabled: true
|
||||
- name: max-public-structs
|
||||
disabled: true
|
||||
# geneally annoying rules
|
||||
- name: add-constant # complains about any and all strings and integers
|
||||
disabled: true
|
||||
- name: confusing-naming # we frequently use "Foo()" and "foo()" together
|
||||
disabled: true
|
||||
- name: flag-parameter # excessive, and a common idiom we use
|
||||
disabled: true
|
||||
# general config
|
||||
- name: line-length-limit
|
||||
arguments:
|
||||
- 140
|
||||
- name: var-naming
|
||||
arguments:
|
||||
- []
|
||||
- - CID
|
||||
- CRI
|
||||
- CTRD
|
||||
- DACL
|
||||
- DLL
|
||||
- DOS
|
||||
- ETW
|
||||
- FSCTL
|
||||
- GCS
|
||||
- GMSA
|
||||
- HCS
|
||||
- HV
|
||||
- IO
|
||||
- LCOW
|
||||
- LDAP
|
||||
- LPAC
|
||||
- LTSC
|
||||
- MMIO
|
||||
- NT
|
||||
- OCI
|
||||
- PMEM
|
||||
- PWSH
|
||||
- RX
|
||||
- SACl
|
||||
- SID
|
||||
- SMB
|
||||
- TX
|
||||
- VHD
|
||||
- VHDX
|
||||
- VMID
|
||||
- VPCI
|
||||
- WCOW
|
||||
- WIM
|
||||
stylecheck:
|
||||
checks:
|
||||
- "all"
|
||||
- "-ST1003" # use revive's var naming
|
@ -1 +0,0 @@
|
||||
* @microsoft/containerplat
|
@ -1,22 +0,0 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015 Microsoft
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
@ -1,89 +0,0 @@
|
||||
# go-winio [![Build Status](https://github.com/microsoft/go-winio/actions/workflows/ci.yml/badge.svg)](https://github.com/microsoft/go-winio/actions/workflows/ci.yml)
|
||||
|
||||
This repository contains utilities for efficiently performing Win32 IO operations in
|
||||
Go. Currently, this is focused on accessing named pipes and other file handles, and
|
||||
for using named pipes as a net transport.
|
||||
|
||||
This code relies on IO completion ports to avoid blocking IO on system threads, allowing Go
|
||||
to reuse the thread to schedule another goroutine. This limits support to Windows Vista and
|
||||
newer operating systems. This is similar to the implementation of network sockets in Go's net
|
||||
package.
|
||||
|
||||
Please see the LICENSE file for licensing information.
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions.
|
||||
Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that
|
||||
you have the right to, and actually do, grant us the rights to use your contribution.
|
||||
For details, visit [Microsoft CLA](https://cla.microsoft.com).
|
||||
|
||||
When you submit a pull request, a CLA-bot will automatically determine whether you need to
|
||||
provide a CLA and decorate the PR appropriately (e.g., label, comment).
|
||||
Simply follow the instructions provided by the bot.
|
||||
You will only need to do this once across all repos using our CLA.
|
||||
|
||||
Additionally, the pull request pipeline requires the following steps to be performed before
|
||||
mergining.
|
||||
|
||||
### Code Sign-Off
|
||||
|
||||
We require that contributors sign their commits using [`git commit --signoff`][git-commit-s]
|
||||
to certify they either authored the work themselves or otherwise have permission to use it in this project.
|
||||
|
||||
A range of commits can be signed off using [`git rebase --signoff`][git-rebase-s].
|
||||
|
||||
Please see [the developer certificate](https://developercertificate.org) for more info,
|
||||
as well as to make sure that you can attest to the rules listed.
|
||||
Our CI uses the DCO Github app to ensure that all commits in a given PR are signed-off.
|
||||
|
||||
### Linting
|
||||
|
||||
Code must pass a linting stage, which uses [`golangci-lint`][lint].
|
||||
The linting settings are stored in [`.golangci.yaml`](./.golangci.yaml), and can be run
|
||||
automatically with VSCode by adding the following to your workspace or folder settings:
|
||||
|
||||
```json
|
||||
"go.lintTool": "golangci-lint",
|
||||
"go.lintOnSave": "package",
|
||||
```
|
||||
|
||||
Additional editor [integrations options are also available][lint-ide].
|
||||
|
||||
Alternatively, `golangci-lint` can be [installed locally][lint-install] and run from the repo root:
|
||||
|
||||
```shell
|
||||
# use . or specify a path to only lint a package
|
||||
# to show all lint errors, use flags "--max-issues-per-linter=0 --max-same-issues=0"
|
||||
> golangci-lint run ./...
|
||||
```
|
||||
|
||||
### Go Generate
|
||||
|
||||
The pipeline checks that auto-generated code, via `go generate`, are up to date.
|
||||
|
||||
This can be done for the entire repo:
|
||||
|
||||
```shell
|
||||
> go generate ./...
|
||||
```
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
## Special Thanks
|
||||
|
||||
Thanks to [natefinch][natefinch] for the inspiration for this library.
|
||||
See [npipe](https://github.com/natefinch/npipe) for another named pipe implementation.
|
||||
|
||||
[lint]: https://golangci-lint.run/
|
||||
[lint-ide]: https://golangci-lint.run/usage/integrations/#editor-integration
|
||||
[lint-install]: https://golangci-lint.run/usage/install/#local-installation
|
||||
|
||||
[git-commit-s]: https://git-scm.com/docs/git-commit#Documentation/git-commit.txt--s
|
||||
[git-rebase-s]: https://git-scm.com/docs/git-rebase#Documentation/git-rebase.txt---signoff
|
||||
|
||||
[natefinch]: https://github.com/natefinch
|
@ -1,41 +0,0 @@
|
||||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->
|
||||
|
||||
## Security
|
||||
|
||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||
|
||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||
|
||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
|
||||
|
||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
|
||||
|
||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
|
||||
|
||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||
|
||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||
* Full paths of source file(s) related to the manifestation of the issue
|
||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||
* Any special configuration required to reproduce the issue
|
||||
* Step-by-step instructions to reproduce the issue
|
||||
* Proof-of-concept or exploit code (if possible)
|
||||
* Impact of the issue, including how an attacker might exploit the issue
|
||||
|
||||
This information will help us triage your report more quickly.
|
||||
|
||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
|
||||
|
||||
## Preferred Languages
|
||||
|
||||
We prefer all communications to be in English.
|
||||
|
||||
## Policy
|
||||
|
||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
|
||||
|
||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
@ -1,290 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
|
||||
//sys backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite
|
||||
|
||||
const (
|
||||
BackupData = uint32(iota + 1)
|
||||
BackupEaData
|
||||
BackupSecurity
|
||||
BackupAlternateData
|
||||
BackupLink
|
||||
BackupPropertyData
|
||||
BackupObjectId //revive:disable-line:var-naming ID, not Id
|
||||
BackupReparseData
|
||||
BackupSparseBlock
|
||||
BackupTxfsData
|
||||
)
|
||||
|
||||
const (
|
||||
StreamSparseAttributes = uint32(8)
|
||||
)
|
||||
|
||||
//nolint:revive // var-naming: ALL_CAPS
|
||||
const (
|
||||
WRITE_DAC = windows.WRITE_DAC
|
||||
WRITE_OWNER = windows.WRITE_OWNER
|
||||
ACCESS_SYSTEM_SECURITY = windows.ACCESS_SYSTEM_SECURITY
|
||||
)
|
||||
|
||||
// BackupHeader represents a backup stream of a file.
|
||||
type BackupHeader struct {
|
||||
//revive:disable-next-line:var-naming ID, not Id
|
||||
Id uint32 // The backup stream ID
|
||||
Attributes uint32 // Stream attributes
|
||||
Size int64 // The size of the stream in bytes
|
||||
Name string // The name of the stream (for BackupAlternateData only).
|
||||
Offset int64 // The offset of the stream in the file (for BackupSparseBlock only).
|
||||
}
|
||||
|
||||
type win32StreamID struct {
|
||||
StreamID uint32
|
||||
Attributes uint32
|
||||
Size uint64
|
||||
NameSize uint32
|
||||
}
|
||||
|
||||
// BackupStreamReader reads from a stream produced by the BackupRead Win32 API and produces a series
|
||||
// of BackupHeader values.
|
||||
type BackupStreamReader struct {
|
||||
r io.Reader
|
||||
bytesLeft int64
|
||||
}
|
||||
|
||||
// NewBackupStreamReader produces a BackupStreamReader from any io.Reader.
|
||||
func NewBackupStreamReader(r io.Reader) *BackupStreamReader {
|
||||
return &BackupStreamReader{r, 0}
|
||||
}
|
||||
|
||||
// Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if
|
||||
// it was not completely read.
|
||||
func (r *BackupStreamReader) Next() (*BackupHeader, error) {
|
||||
if r.bytesLeft > 0 { //nolint:nestif // todo: flatten this
|
||||
if s, ok := r.r.(io.Seeker); ok {
|
||||
// Make sure Seek on io.SeekCurrent sometimes succeeds
|
||||
// before trying the actual seek.
|
||||
if _, err := s.Seek(0, io.SeekCurrent); err == nil {
|
||||
if _, err = s.Seek(r.bytesLeft, io.SeekCurrent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.bytesLeft = 0
|
||||
}
|
||||
}
|
||||
if _, err := io.Copy(io.Discard, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
var wsi win32StreamID
|
||||
if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr := &BackupHeader{
|
||||
Id: wsi.StreamID,
|
||||
Attributes: wsi.Attributes,
|
||||
Size: int64(wsi.Size),
|
||||
}
|
||||
if wsi.NameSize != 0 {
|
||||
name := make([]uint16, int(wsi.NameSize/2))
|
||||
if err := binary.Read(r.r, binary.LittleEndian, name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr.Name = syscall.UTF16ToString(name)
|
||||
}
|
||||
if wsi.StreamID == BackupSparseBlock {
|
||||
if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hdr.Size -= 8
|
||||
}
|
||||
r.bytesLeft = hdr.Size
|
||||
return hdr, nil
|
||||
}
|
||||
|
||||
// Read reads from the current backup stream.
|
||||
func (r *BackupStreamReader) Read(b []byte) (int, error) {
|
||||
if r.bytesLeft == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if int64(len(b)) > r.bytesLeft {
|
||||
b = b[:r.bytesLeft]
|
||||
}
|
||||
n, err := r.r.Read(b)
|
||||
r.bytesLeft -= int64(n)
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
} else if r.bytesLeft == 0 && err == nil {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// BackupStreamWriter writes a stream compatible with the BackupWrite Win32 API.
|
||||
type BackupStreamWriter struct {
|
||||
w io.Writer
|
||||
bytesLeft int64
|
||||
}
|
||||
|
||||
// NewBackupStreamWriter produces a BackupStreamWriter on top of an io.Writer.
|
||||
func NewBackupStreamWriter(w io.Writer) *BackupStreamWriter {
|
||||
return &BackupStreamWriter{w, 0}
|
||||
}
|
||||
|
||||
// WriteHeader writes the next backup stream header and prepares for calls to Write().
|
||||
func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error {
|
||||
if w.bytesLeft != 0 {
|
||||
return fmt.Errorf("missing %d bytes", w.bytesLeft)
|
||||
}
|
||||
name := utf16.Encode([]rune(hdr.Name))
|
||||
wsi := win32StreamID{
|
||||
StreamID: hdr.Id,
|
||||
Attributes: hdr.Attributes,
|
||||
Size: uint64(hdr.Size),
|
||||
NameSize: uint32(len(name) * 2),
|
||||
}
|
||||
if hdr.Id == BackupSparseBlock {
|
||||
// Include space for the int64 block offset
|
||||
wsi.Size += 8
|
||||
}
|
||||
if err := binary.Write(w.w, binary.LittleEndian, &wsi); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(name) != 0 {
|
||||
if err := binary.Write(w.w, binary.LittleEndian, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if hdr.Id == BackupSparseBlock {
|
||||
if err := binary.Write(w.w, binary.LittleEndian, hdr.Offset); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.bytesLeft = hdr.Size
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes to the current backup stream.
|
||||
func (w *BackupStreamWriter) Write(b []byte) (int, error) {
|
||||
if w.bytesLeft < int64(len(b)) {
|
||||
return 0, fmt.Errorf("too many bytes by %d", int64(len(b))-w.bytesLeft)
|
||||
}
|
||||
n, err := w.w.Write(b)
|
||||
w.bytesLeft -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// BackupFileReader provides an io.ReadCloser interface on top of the BackupRead Win32 API.
|
||||
type BackupFileReader struct {
|
||||
f *os.File
|
||||
includeSecurity bool
|
||||
ctx uintptr
|
||||
}
|
||||
|
||||
// NewBackupFileReader returns a new BackupFileReader from a file handle. If includeSecurity is true,
|
||||
// Read will attempt to read the security descriptor of the file.
|
||||
func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader {
|
||||
r := &BackupFileReader{f, includeSecurity, 0}
|
||||
return r
|
||||
}
|
||||
|
||||
// Read reads a backup stream from the file by calling the Win32 API BackupRead().
|
||||
func (r *BackupFileReader) Read(b []byte) (int, error) {
|
||||
var bytesRead uint32
|
||||
err := backupRead(syscall.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(r.f)
|
||||
if bytesRead == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return int(bytesRead), nil
|
||||
}
|
||||
|
||||
// Close frees Win32 resources associated with the BackupFileReader. It does not close
|
||||
// the underlying file.
|
||||
func (r *BackupFileReader) Close() error {
|
||||
if r.ctx != 0 {
|
||||
_ = backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
|
||||
runtime.KeepAlive(r.f)
|
||||
r.ctx = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BackupFileWriter provides an io.WriteCloser interface on top of the BackupWrite Win32 API.
|
||||
type BackupFileWriter struct {
|
||||
f *os.File
|
||||
includeSecurity bool
|
||||
ctx uintptr
|
||||
}
|
||||
|
||||
// NewBackupFileWriter returns a new BackupFileWriter from a file handle. If includeSecurity is true,
|
||||
// Write() will attempt to restore the security descriptor from the stream.
|
||||
func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter {
|
||||
w := &BackupFileWriter{f, includeSecurity, 0}
|
||||
return w
|
||||
}
|
||||
|
||||
// Write restores a portion of the file using the provided backup stream.
|
||||
func (w *BackupFileWriter) Write(b []byte) (int, error) {
|
||||
var bytesWritten uint32
|
||||
err := backupWrite(syscall.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(w.f)
|
||||
if int(bytesWritten) != len(b) {
|
||||
return int(bytesWritten), errors.New("not all bytes could be written")
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Close frees Win32 resources associated with the BackupFileWriter. It does not
|
||||
// close the underlying file.
|
||||
func (w *BackupFileWriter) Close() error {
|
||||
if w.ctx != 0 {
|
||||
_ = backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
|
||||
runtime.KeepAlive(w.f)
|
||||
w.ctx = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenForBackup opens a file or directory, potentially skipping access checks if the backup
|
||||
// or restore privileges have been acquired.
|
||||
//
|
||||
// If the file opened was a directory, it cannot be used with Readdir().
|
||||
func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) {
|
||||
winPath, err := syscall.UTF16FromString(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h, err := syscall.CreateFile(&winPath[0],
|
||||
access,
|
||||
share,
|
||||
nil,
|
||||
createmode,
|
||||
syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT,
|
||||
0)
|
||||
if err != nil {
|
||||
err = &os.PathError{Op: "open", Path: path, Err: err}
|
||||
return nil, err
|
||||
}
|
||||
return os.NewFile(uintptr(h), path), nil
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
// This package provides utilities for efficiently performing Win32 IO operations in Go.
|
||||
// Currently, this package is provides support for genreal IO and management of
|
||||
// - named pipes
|
||||
// - files
|
||||
// - [Hyper-V sockets]
|
||||
//
|
||||
// This code is similar to Go's [net] package, and uses IO completion ports to avoid
|
||||
// blocking IO on system threads, allowing Go to reuse the thread to schedule other goroutines.
|
||||
//
|
||||
// This limits support to Windows Vista and newer operating systems.
|
||||
//
|
||||
// Additionally, this package provides support for:
|
||||
// - creating and managing GUIDs
|
||||
// - writing to [ETW]
|
||||
// - opening and manageing VHDs
|
||||
// - parsing [Windows Image files]
|
||||
// - auto-generating Win32 API code
|
||||
//
|
||||
// [Hyper-V sockets]: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service
|
||||
// [ETW]: https://docs.microsoft.com/en-us/windows-hardware/drivers/devtest/event-tracing-for-windows--etw-
|
||||
// [Windows Image files]: https://docs.microsoft.com/en-us/windows-hardware/manufacture/desktop/work-with-windows-images
|
||||
package winio
|
@ -1,137 +0,0 @@
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type fileFullEaInformation struct {
|
||||
NextEntryOffset uint32
|
||||
Flags uint8
|
||||
NameLength uint8
|
||||
ValueLength uint16
|
||||
}
|
||||
|
||||
var (
|
||||
fileFullEaInformationSize = binary.Size(&fileFullEaInformation{})
|
||||
|
||||
errInvalidEaBuffer = errors.New("invalid extended attribute buffer")
|
||||
errEaNameTooLarge = errors.New("extended attribute name too large")
|
||||
errEaValueTooLarge = errors.New("extended attribute value too large")
|
||||
)
|
||||
|
||||
// ExtendedAttribute represents a single Windows EA.
|
||||
type ExtendedAttribute struct {
|
||||
Name string
|
||||
Value []byte
|
||||
Flags uint8
|
||||
}
|
||||
|
||||
func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) {
|
||||
var info fileFullEaInformation
|
||||
err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
err = errInvalidEaBuffer
|
||||
return ea, nb, err
|
||||
}
|
||||
|
||||
nameOffset := fileFullEaInformationSize
|
||||
nameLen := int(info.NameLength)
|
||||
valueOffset := nameOffset + int(info.NameLength) + 1
|
||||
valueLen := int(info.ValueLength)
|
||||
nextOffset := int(info.NextEntryOffset)
|
||||
if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) {
|
||||
err = errInvalidEaBuffer
|
||||
return ea, nb, err
|
||||
}
|
||||
|
||||
ea.Name = string(b[nameOffset : nameOffset+nameLen])
|
||||
ea.Value = b[valueOffset : valueOffset+valueLen]
|
||||
ea.Flags = info.Flags
|
||||
if info.NextEntryOffset != 0 {
|
||||
nb = b[info.NextEntryOffset:]
|
||||
}
|
||||
return ea, nb, err
|
||||
}
|
||||
|
||||
// DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION
|
||||
// buffer retrieved from BackupRead, ZwQueryEaFile, etc.
|
||||
func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) {
|
||||
for len(b) != 0 {
|
||||
ea, nb, err := parseEa(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eas = append(eas, ea)
|
||||
b = nb
|
||||
}
|
||||
return eas, err
|
||||
}
|
||||
|
||||
func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error {
|
||||
if int(uint8(len(ea.Name))) != len(ea.Name) {
|
||||
return errEaNameTooLarge
|
||||
}
|
||||
if int(uint16(len(ea.Value))) != len(ea.Value) {
|
||||
return errEaValueTooLarge
|
||||
}
|
||||
entrySize := uint32(fileFullEaInformationSize + len(ea.Name) + 1 + len(ea.Value))
|
||||
withPadding := (entrySize + 3) &^ 3
|
||||
nextOffset := uint32(0)
|
||||
if !last {
|
||||
nextOffset = withPadding
|
||||
}
|
||||
info := fileFullEaInformation{
|
||||
NextEntryOffset: nextOffset,
|
||||
Flags: ea.Flags,
|
||||
NameLength: uint8(len(ea.Name)),
|
||||
ValueLength: uint16(len(ea.Value)),
|
||||
}
|
||||
|
||||
err := binary.Write(buf, binary.LittleEndian, &info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte(ea.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = buf.WriteByte(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write(ea.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = buf.Write([]byte{0, 0, 0}[0 : withPadding-entrySize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeExtendedAttributes encodes a list of EAs into a FILE_FULL_EA_INFORMATION
|
||||
// buffer for use with BackupWrite, ZwSetEaFile, etc.
|
||||
func EncodeExtendedAttributes(eas []ExtendedAttribute) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
for i := range eas {
|
||||
last := false
|
||||
if i == len(eas)-1 {
|
||||
last = true
|
||||
}
|
||||
|
||||
err := writeEa(&buf, &eas[i], last)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
@ -1,331 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
|
||||
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
|
||||
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||
|
||||
type atomicBool int32
|
||||
|
||||
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
||||
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||
|
||||
//revive:disable-next-line:predeclared Keep "new" to maintain consistency with "atomic" pkg
|
||||
func (b *atomicBool) swap(new bool) bool {
|
||||
var newInt int32
|
||||
if new {
|
||||
newInt = 1
|
||||
}
|
||||
return atomic.SwapInt32((*int32)(b), newInt) == 1
|
||||
}
|
||||
|
||||
var (
|
||||
ErrFileClosed = errors.New("file has already been closed")
|
||||
ErrTimeout = &timeoutError{}
|
||||
)
|
||||
|
||||
type timeoutError struct{}
|
||||
|
||||
func (*timeoutError) Error() string { return "i/o timeout" }
|
||||
func (*timeoutError) Timeout() bool { return true }
|
||||
func (*timeoutError) Temporary() bool { return true }
|
||||
|
||||
type timeoutChan chan struct{}
|
||||
|
||||
var ioInitOnce sync.Once
|
||||
var ioCompletionPort syscall.Handle
|
||||
|
||||
// ioResult contains the result of an asynchronous IO operation.
|
||||
type ioResult struct {
|
||||
bytes uint32
|
||||
err error
|
||||
}
|
||||
|
||||
// ioOperation represents an outstanding asynchronous Win32 IO.
|
||||
type ioOperation struct {
|
||||
o syscall.Overlapped
|
||||
ch chan ioResult
|
||||
}
|
||||
|
||||
func initIO() {
|
||||
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ioCompletionPort = h
|
||||
go ioCompletionProcessor(h)
|
||||
}
|
||||
|
||||
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||
type win32File struct {
|
||||
handle syscall.Handle
|
||||
wg sync.WaitGroup
|
||||
wgLock sync.RWMutex
|
||||
closing atomicBool
|
||||
socket bool
|
||||
readDeadline deadlineHandler
|
||||
writeDeadline deadlineHandler
|
||||
}
|
||||
|
||||
type deadlineHandler struct {
|
||||
setLock sync.Mutex
|
||||
channel timeoutChan
|
||||
channelLock sync.RWMutex
|
||||
timer *time.Timer
|
||||
timedout atomicBool
|
||||
}
|
||||
|
||||
// makeWin32File makes a new win32File from an existing file handle.
|
||||
func makeWin32File(h syscall.Handle) (*win32File, error) {
|
||||
f := &win32File{handle: h}
|
||||
ioInitOnce.Do(initIO)
|
||||
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = setFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f.readDeadline.channel = make(timeoutChan)
|
||||
f.writeDeadline.channel = make(timeoutChan)
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
||||
// If we return the result of makeWin32File directly, it can result in an
|
||||
// interface-wrapped nil, rather than a nil interface value.
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// closeHandle closes the resources associated with a Win32 handle.
|
||||
func (f *win32File) closeHandle() {
|
||||
f.wgLock.Lock()
|
||||
// Atomically set that we are closing, releasing the resources only once.
|
||||
if !f.closing.swap(true) {
|
||||
f.wgLock.Unlock()
|
||||
// cancel all IO and wait for it to complete
|
||||
_ = cancelIoEx(f.handle, nil)
|
||||
f.wg.Wait()
|
||||
// at this point, no new IO can start
|
||||
syscall.Close(f.handle)
|
||||
f.handle = 0
|
||||
} else {
|
||||
f.wgLock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes a win32File.
|
||||
func (f *win32File) Close() error {
|
||||
f.closeHandle()
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsClosed checks if the file has been closed.
|
||||
func (f *win32File) IsClosed() bool {
|
||||
return f.closing.isSet()
|
||||
}
|
||||
|
||||
// prepareIO prepares for a new IO operation.
|
||||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||
func (f *win32File) prepareIO() (*ioOperation, error) {
|
||||
f.wgLock.RLock()
|
||||
if f.closing.isSet() {
|
||||
f.wgLock.RUnlock()
|
||||
return nil, ErrFileClosed
|
||||
}
|
||||
f.wg.Add(1)
|
||||
f.wgLock.RUnlock()
|
||||
c := &ioOperation{}
|
||||
c.ch = make(chan ioResult)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// ioCompletionProcessor processes completed async IOs forever.
|
||||
func ioCompletionProcessor(h syscall.Handle) {
|
||||
for {
|
||||
var bytes uint32
|
||||
var key uintptr
|
||||
var op *ioOperation
|
||||
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
|
||||
if op == nil {
|
||||
panic(err)
|
||||
}
|
||||
op.ch <- ioResult{bytes, err}
|
||||
}
|
||||
}
|
||||
|
||||
// todo: helsaawy - create an asyncIO version that takes a context
|
||||
|
||||
// asyncIO processes the return value from ReadFile or WriteFile, blocking until
|
||||
// the operation has actually completed.
|
||||
func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||
if err != syscall.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
|
||||
return int(bytes), err
|
||||
}
|
||||
|
||||
if f.closing.isSet() {
|
||||
_ = cancelIoEx(f.handle, &c.o)
|
||||
}
|
||||
|
||||
var timeout timeoutChan
|
||||
if d != nil {
|
||||
d.channelLock.Lock()
|
||||
timeout = d.channel
|
||||
d.channelLock.Unlock()
|
||||
}
|
||||
|
||||
var r ioResult
|
||||
select {
|
||||
case r = <-c.ch:
|
||||
err = r.err
|
||||
if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
|
||||
if f.closing.isSet() {
|
||||
err = ErrFileClosed
|
||||
}
|
||||
} else if err != nil && f.socket {
|
||||
// err is from Win32. Query the overlapped structure to get the winsock error.
|
||||
var bytes, flags uint32
|
||||
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
||||
}
|
||||
case <-timeout:
|
||||
_ = cancelIoEx(f.handle, &c.o)
|
||||
r = <-c.ch
|
||||
err = r.err
|
||||
if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
|
||||
err = ErrTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// runtime.KeepAlive is needed, as c is passed via native
|
||||
// code to ioCompletionProcessor, c must remain alive
|
||||
// until the channel read is complete.
|
||||
// todo: (de)allocate *ioOperation via win32 heap functions, instead of needing to KeepAlive?
|
||||
runtime.KeepAlive(c)
|
||||
return int(r.bytes), err
|
||||
}
|
||||
|
||||
// Read reads from a file handle.
|
||||
func (f *win32File) Read(b []byte) (int, error) {
|
||||
c, err := f.prepareIO()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.readDeadline.timedout.isSet() {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
|
||||
var bytes uint32
|
||||
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
|
||||
n, err := f.asyncIO(c, &f.readDeadline, bytes, err)
|
||||
runtime.KeepAlive(b)
|
||||
|
||||
// Handle EOF conditions.
|
||||
if err == nil && n == 0 && len(b) != 0 {
|
||||
return 0, io.EOF
|
||||
} else if err == syscall.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
|
||||
return 0, io.EOF
|
||||
} else {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
// Write writes to a file handle.
|
||||
func (f *win32File) Write(b []byte) (int, error) {
|
||||
c, err := f.prepareIO()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer f.wg.Done()
|
||||
|
||||
if f.writeDeadline.timedout.isSet() {
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
|
||||
var bytes uint32
|
||||
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
|
||||
n, err := f.asyncIO(c, &f.writeDeadline, bytes, err)
|
||||
runtime.KeepAlive(b)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (f *win32File) SetReadDeadline(deadline time.Time) error {
|
||||
return f.readDeadline.set(deadline)
|
||||
}
|
||||
|
||||
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
|
||||
return f.writeDeadline.set(deadline)
|
||||
}
|
||||
|
||||
func (f *win32File) Flush() error {
|
||||
return syscall.FlushFileBuffers(f.handle)
|
||||
}
|
||||
|
||||
func (f *win32File) Fd() uintptr {
|
||||
return uintptr(f.handle)
|
||||
}
|
||||
|
||||
func (d *deadlineHandler) set(deadline time.Time) error {
|
||||
d.setLock.Lock()
|
||||
defer d.setLock.Unlock()
|
||||
|
||||
if d.timer != nil {
|
||||
if !d.timer.Stop() {
|
||||
<-d.channel
|
||||
}
|
||||
d.timer = nil
|
||||
}
|
||||
d.timedout.setFalse()
|
||||
|
||||
select {
|
||||
case <-d.channel:
|
||||
d.channelLock.Lock()
|
||||
d.channel = make(chan struct{})
|
||||
d.channelLock.Unlock()
|
||||
default:
|
||||
}
|
||||
|
||||
if deadline.IsZero() {
|
||||
return nil
|
||||
}
|
||||
|
||||
timeoutIO := func() {
|
||||
d.timedout.setTrue()
|
||||
close(d.channel)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
duration := deadline.Sub(now)
|
||||
if deadline.After(now) {
|
||||
// Deadline is in the future, set a timer to wait
|
||||
d.timer = time.AfterFunc(duration, timeoutIO)
|
||||
} else {
|
||||
// Deadline is in the past. Cancel all pending IO now.
|
||||
timeoutIO()
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,92 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// FileBasicInfo contains file access time and file attributes information.
|
||||
type FileBasicInfo struct {
|
||||
CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime
|
||||
FileAttributes uint32
|
||||
_ uint32 // padding
|
||||
}
|
||||
|
||||
// GetFileBasicInfo retrieves times and attributes for a file.
|
||||
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
|
||||
bi := &FileBasicInfo{}
|
||||
if err := windows.GetFileInformationByHandleEx(
|
||||
windows.Handle(f.Fd()),
|
||||
windows.FileBasicInfo,
|
||||
(*byte)(unsafe.Pointer(bi)),
|
||||
uint32(unsafe.Sizeof(*bi)),
|
||||
); err != nil {
|
||||
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return bi, nil
|
||||
}
|
||||
|
||||
// SetFileBasicInfo sets times and attributes for a file.
|
||||
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
|
||||
if err := windows.SetFileInformationByHandle(
|
||||
windows.Handle(f.Fd()),
|
||||
windows.FileBasicInfo,
|
||||
(*byte)(unsafe.Pointer(bi)),
|
||||
uint32(unsafe.Sizeof(*bi)),
|
||||
); err != nil {
|
||||
return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileStandardInfo contains extended information for the file.
|
||||
// FILE_STANDARD_INFO in WinBase.h
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_standard_info
|
||||
type FileStandardInfo struct {
|
||||
AllocationSize, EndOfFile int64
|
||||
NumberOfLinks uint32
|
||||
DeletePending, Directory bool
|
||||
}
|
||||
|
||||
// GetFileStandardInfo retrieves ended information for the file.
|
||||
func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) {
|
||||
si := &FileStandardInfo{}
|
||||
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()),
|
||||
windows.FileStandardInfo,
|
||||
(*byte)(unsafe.Pointer(si)),
|
||||
uint32(unsafe.Sizeof(*si))); err != nil {
|
||||
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return si, nil
|
||||
}
|
||||
|
||||
// FileIDInfo contains the volume serial number and file ID for a file. This pair should be
|
||||
// unique on a system.
|
||||
type FileIDInfo struct {
|
||||
VolumeSerialNumber uint64
|
||||
FileID [16]byte
|
||||
}
|
||||
|
||||
// GetFileID retrieves the unique (volume, file ID) pair for a file.
|
||||
func GetFileID(f *os.File) (*FileIDInfo, error) {
|
||||
fileID := &FileIDInfo{}
|
||||
if err := windows.GetFileInformationByHandleEx(
|
||||
windows.Handle(f.Fd()),
|
||||
windows.FileIdInfo,
|
||||
(*byte)(unsafe.Pointer(fileID)),
|
||||
uint32(unsafe.Sizeof(*fileID)),
|
||||
); err != nil {
|
||||
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
|
||||
}
|
||||
runtime.KeepAlive(f)
|
||||
return fileID, nil
|
||||
}
|
@ -1,575 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/Microsoft/go-winio/internal/socket"
|
||||
"github.com/Microsoft/go-winio/pkg/guid"
|
||||
)
|
||||
|
||||
const afHVSock = 34 // AF_HYPERV
|
||||
|
||||
// Well known Service and VM IDs
|
||||
//https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
|
||||
|
||||
// HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
|
||||
func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
|
||||
return guid.GUID{}
|
||||
}
|
||||
|
||||
// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
|
||||
func HvsockGUIDBroadcast() guid.GUID { //ffffffff-ffff-ffff-ffff-ffffffffffff
|
||||
return guid.GUID{
|
||||
Data1: 0xffffffff,
|
||||
Data2: 0xffff,
|
||||
Data3: 0xffff,
|
||||
Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
|
||||
func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
|
||||
return guid.GUID{
|
||||
Data1: 0xe0e16197,
|
||||
Data2: 0xdd56,
|
||||
Data3: 0x4a10,
|
||||
Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDSiloHost is the address of a silo's host partition:
|
||||
// - The silo host of a hosted silo is the utility VM.
|
||||
// - The silo host of a silo on a physical host is the physical host.
|
||||
func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
|
||||
return guid.GUID{
|
||||
Data1: 0x36bd0c5c,
|
||||
Data2: 0x7276,
|
||||
Data3: 0x4223,
|
||||
Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
|
||||
func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
|
||||
return guid.GUID{
|
||||
Data1: 0x90db8b89,
|
||||
Data2: 0xd35,
|
||||
Data3: 0x4f79,
|
||||
Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
|
||||
}
|
||||
}
|
||||
|
||||
// HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
|
||||
// Listening on this VmId accepts connection from:
|
||||
// - Inside silos: silo host partition.
|
||||
// - Inside hosted silo: host of the VM.
|
||||
// - Inside VM: VM host.
|
||||
// - Physical host: Not supported.
|
||||
func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
|
||||
return guid.GUID{
|
||||
Data1: 0xa42e7cda,
|
||||
Data2: 0xd03f,
|
||||
Data3: 0x480c,
|
||||
Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
|
||||
}
|
||||
}
|
||||
|
||||
// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
|
||||
func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
|
||||
return guid.GUID{
|
||||
Data2: 0xfacb,
|
||||
Data3: 0x11e6,
|
||||
Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
|
||||
}
|
||||
}
|
||||
|
||||
// An HvsockAddr is an address for a AF_HYPERV socket.
|
||||
type HvsockAddr struct {
|
||||
VMID guid.GUID
|
||||
ServiceID guid.GUID
|
||||
}
|
||||
|
||||
type rawHvsockAddr struct {
|
||||
Family uint16
|
||||
_ uint16
|
||||
VMID guid.GUID
|
||||
ServiceID guid.GUID
|
||||
}
|
||||
|
||||
var _ socket.RawSockaddr = &rawHvsockAddr{}
|
||||
|
||||
// Network returns the address's network name, "hvsock".
|
||||
func (*HvsockAddr) Network() string {
|
||||
return "hvsock"
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) String() string {
|
||||
return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
|
||||
}
|
||||
|
||||
// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
|
||||
func VsockServiceID(port uint32) guid.GUID {
|
||||
g := hvsockVsockServiceTemplate() // make a copy
|
||||
g.Data1 = port
|
||||
return g
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) raw() rawHvsockAddr {
|
||||
return rawHvsockAddr{
|
||||
Family: afHVSock,
|
||||
VMID: addr.VMID,
|
||||
ServiceID: addr.ServiceID,
|
||||
}
|
||||
}
|
||||
|
||||
func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
|
||||
addr.VMID = raw.VMID
|
||||
addr.ServiceID = raw.ServiceID
|
||||
}
|
||||
|
||||
// Sockaddr returns a pointer to and the size of this struct.
|
||||
//
|
||||
// Implements the [socket.RawSockaddr] interface, and allows use in
|
||||
// [socket.Bind] and [socket.ConnectEx].
|
||||
func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
|
||||
return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
|
||||
}
|
||||
|
||||
// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
|
||||
func (r *rawHvsockAddr) FromBytes(b []byte) error {
|
||||
n := int(unsafe.Sizeof(rawHvsockAddr{}))
|
||||
|
||||
if len(b) < n {
|
||||
return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
|
||||
}
|
||||
|
||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
|
||||
if r.Family != afHVSock {
|
||||
return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HvsockListener is a socket listener for the AF_HYPERV address family.
|
||||
type HvsockListener struct {
|
||||
sock *win32File
|
||||
addr HvsockAddr
|
||||
}
|
||||
|
||||
var _ net.Listener = &HvsockListener{}
|
||||
|
||||
// HvsockConn is a connected socket of the AF_HYPERV address family.
|
||||
type HvsockConn struct {
|
||||
sock *win32File
|
||||
local, remote HvsockAddr
|
||||
}
|
||||
|
||||
var _ net.Conn = &HvsockConn{}
|
||||
|
||||
func newHVSocket() (*win32File, error) {
|
||||
fd, err := syscall.Socket(afHVSock, syscall.SOCK_STREAM, 1)
|
||||
if err != nil {
|
||||
return nil, os.NewSyscallError("socket", err)
|
||||
}
|
||||
f, err := makeWin32File(fd)
|
||||
if err != nil {
|
||||
syscall.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
f.socket = true
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// ListenHvsock listens for connections on the specified hvsock address.
|
||||
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
|
||||
l := &HvsockListener{addr: *addr}
|
||||
sock, err := newHVSocket()
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", err)
|
||||
}
|
||||
sa := addr.raw()
|
||||
err = socket.Bind(windows.Handle(sock.handle), &sa)
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", os.NewSyscallError("socket", err))
|
||||
}
|
||||
err = syscall.Listen(sock.handle, 16)
|
||||
if err != nil {
|
||||
return nil, l.opErr("listen", os.NewSyscallError("listen", err))
|
||||
}
|
||||
return &HvsockListener{sock: sock, addr: *addr}, nil
|
||||
}
|
||||
|
||||
func (l *HvsockListener) opErr(op string, err error) error {
|
||||
return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *HvsockListener) Addr() net.Addr {
|
||||
return &l.addr
|
||||
}
|
||||
|
||||
// Accept waits for the next connection and returns it.
|
||||
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
|
||||
sock, err := newHVSocket()
|
||||
if err != nil {
|
||||
return nil, l.opErr("accept", err)
|
||||
}
|
||||
defer func() {
|
||||
if sock != nil {
|
||||
sock.Close()
|
||||
}
|
||||
}()
|
||||
c, err := l.sock.prepareIO()
|
||||
if err != nil {
|
||||
return nil, l.opErr("accept", err)
|
||||
}
|
||||
defer l.sock.wg.Done()
|
||||
|
||||
// AcceptEx, per documentation, requires an extra 16 bytes per address.
|
||||
//
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
|
||||
const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
|
||||
var addrbuf [addrlen * 2]byte
|
||||
|
||||
var bytes uint32
|
||||
err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /*rxdatalen*/, addrlen, addrlen, &bytes, &c.o)
|
||||
if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
|
||||
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
|
||||
}
|
||||
|
||||
conn := &HvsockConn{
|
||||
sock: sock,
|
||||
}
|
||||
// The local address returned in the AcceptEx buffer is the same as the Listener socket's
|
||||
// address. However, the service GUID reported by GetSockName is different from the Listeners
|
||||
// socket, and is sometimes the same as the local address of the socket that dialed the
|
||||
// address, with the service GUID.Data1 incremented, but othertimes is different.
|
||||
// todo: does the local address matter? is the listener's address or the actual address appropriate?
|
||||
conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
|
||||
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
|
||||
|
||||
// initialize the accepted socket and update its properties with those of the listening socket
|
||||
if err = windows.Setsockopt(windows.Handle(sock.handle),
|
||||
windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
|
||||
(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
|
||||
return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
|
||||
}
|
||||
|
||||
sock = nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close closes the listener, causing any pending Accept calls to fail.
|
||||
func (l *HvsockListener) Close() error {
|
||||
return l.sock.Close()
|
||||
}
|
||||
|
||||
// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
|
||||
type HvsockDialer struct {
|
||||
// Deadline is the time the Dial operation must connect before erroring.
|
||||
Deadline time.Time
|
||||
|
||||
// Retries is the number of additional connects to try if the connection times out, is refused,
|
||||
// or the host is unreachable
|
||||
Retries uint
|
||||
|
||||
// RetryWait is the time to wait after a connection error to retry
|
||||
RetryWait time.Duration
|
||||
|
||||
rt *time.Timer // redial wait timer
|
||||
}
|
||||
|
||||
// Dial the Hyper-V socket at addr.
|
||||
//
|
||||
// See [HvsockDialer.Dial] for more information.
|
||||
func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
|
||||
return (&HvsockDialer{}).Dial(ctx, addr)
|
||||
}
|
||||
|
||||
// Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
|
||||
// Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
|
||||
// retries.
|
||||
//
|
||||
// Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
|
||||
func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
|
||||
op := "dial"
|
||||
// create the conn early to use opErr()
|
||||
conn = &HvsockConn{
|
||||
remote: *addr,
|
||||
}
|
||||
|
||||
if !d.Deadline.IsZero() {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithDeadline(ctx, d.Deadline)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// preemptive timeout/cancellation check
|
||||
if err = ctx.Err(); err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
|
||||
sock, err := newHVSocket()
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
defer func() {
|
||||
if sock != nil {
|
||||
sock.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
sa := addr.raw()
|
||||
err = socket.Bind(windows.Handle(sock.handle), &sa)
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("bind", err))
|
||||
}
|
||||
|
||||
c, err := sock.prepareIO()
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
defer sock.wg.Done()
|
||||
var bytes uint32
|
||||
for i := uint(0); i <= d.Retries; i++ {
|
||||
err = socket.ConnectEx(
|
||||
windows.Handle(sock.handle),
|
||||
&sa,
|
||||
nil, // sendBuf
|
||||
0, // sendDataLen
|
||||
&bytes,
|
||||
(*windows.Overlapped)(unsafe.Pointer(&c.o)))
|
||||
_, err = sock.asyncIO(c, nil, bytes, err)
|
||||
if i < d.Retries && canRedial(err) {
|
||||
if err = d.redialWait(ctx); err == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
|
||||
}
|
||||
|
||||
// update the connection properties, so shutdown can be used
|
||||
if err = windows.Setsockopt(
|
||||
windows.Handle(sock.handle),
|
||||
windows.SOL_SOCKET,
|
||||
windows.SO_UPDATE_CONNECT_CONTEXT,
|
||||
nil, // optvalue
|
||||
0, // optlen
|
||||
); err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
|
||||
}
|
||||
|
||||
// get the local name
|
||||
var sal rawHvsockAddr
|
||||
err = socket.GetSockName(windows.Handle(sock.handle), &sal)
|
||||
if err != nil {
|
||||
return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
|
||||
}
|
||||
conn.local.fromRaw(&sal)
|
||||
|
||||
// one last check for timeout, since asyncIO doesn't check the context
|
||||
if err = ctx.Err(); err != nil {
|
||||
return nil, conn.opErr(op, err)
|
||||
}
|
||||
|
||||
conn.sock = sock
|
||||
sock = nil
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// redialWait waits before attempting to redial, resetting the timer as appropriate.
|
||||
func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
|
||||
if d.RetryWait == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if d.rt == nil {
|
||||
d.rt = time.NewTimer(d.RetryWait)
|
||||
} else {
|
||||
// should already be stopped and drained
|
||||
d.rt.Reset(d.RetryWait)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-d.rt.C:
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop and drain the timer
|
||||
if !d.rt.Stop() {
|
||||
<-d.rt.C
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// assumes error is a plain, unwrapped syscall.Errno provided by direct syscall.
|
||||
func canRedial(err error) bool {
|
||||
//nolint:errorlint // guaranteed to be an Errno
|
||||
switch err {
|
||||
case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
|
||||
windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) opErr(op string, err error) error {
|
||||
// translate from "file closed" to "socket closed"
|
||||
if errors.Is(err, ErrFileClosed) {
|
||||
err = socket.ErrSocketClosed
|
||||
}
|
||||
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) Read(b []byte) (int, error) {
|
||||
c, err := conn.sock.prepareIO()
|
||||
if err != nil {
|
||||
return 0, conn.opErr("read", err)
|
||||
}
|
||||
defer conn.sock.wg.Done()
|
||||
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||
var flags, bytes uint32
|
||||
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
|
||||
n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
|
||||
if err != nil {
|
||||
var eno windows.Errno
|
||||
if errors.As(err, &eno) {
|
||||
err = os.NewSyscallError("wsarecv", eno)
|
||||
}
|
||||
return 0, conn.opErr("read", err)
|
||||
} else if n == 0 {
|
||||
err = io.EOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) Write(b []byte) (int, error) {
|
||||
t := 0
|
||||
for len(b) != 0 {
|
||||
n, err := conn.write(b)
|
||||
if err != nil {
|
||||
return t + n, err
|
||||
}
|
||||
t += n
|
||||
b = b[n:]
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) write(b []byte) (int, error) {
|
||||
c, err := conn.sock.prepareIO()
|
||||
if err != nil {
|
||||
return 0, conn.opErr("write", err)
|
||||
}
|
||||
defer conn.sock.wg.Done()
|
||||
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
||||
var bytes uint32
|
||||
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
|
||||
n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
|
||||
if err != nil {
|
||||
var eno windows.Errno
|
||||
if errors.As(err, &eno) {
|
||||
err = os.NewSyscallError("wsasend", eno)
|
||||
}
|
||||
return 0, conn.opErr("write", err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close closes the socket connection, failing any pending read or write calls.
|
||||
func (conn *HvsockConn) Close() error {
|
||||
return conn.sock.Close()
|
||||
}
|
||||
|
||||
func (conn *HvsockConn) IsClosed() bool {
|
||||
return conn.sock.IsClosed()
|
||||
}
|
||||
|
||||
// shutdown disables sending or receiving on a socket.
|
||||
func (conn *HvsockConn) shutdown(how int) error {
|
||||
if conn.IsClosed() {
|
||||
return socket.ErrSocketClosed
|
||||
}
|
||||
|
||||
err := syscall.Shutdown(conn.sock.handle, how)
|
||||
if err != nil {
|
||||
// If the connection was closed, shutdowns fail with "not connected"
|
||||
if errors.Is(err, windows.WSAENOTCONN) ||
|
||||
errors.Is(err, windows.WSAESHUTDOWN) {
|
||||
err = socket.ErrSocketClosed
|
||||
}
|
||||
return os.NewSyscallError("shutdown", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseRead shuts down the read end of the socket, preventing future read operations.
|
||||
func (conn *HvsockConn) CloseRead() error {
|
||||
err := conn.shutdown(syscall.SHUT_RD)
|
||||
if err != nil {
|
||||
return conn.opErr("closeread", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseWrite shuts down the write end of the socket, preventing future write operations and
|
||||
// notifying the other endpoint that no more data will be written.
|
||||
func (conn *HvsockConn) CloseWrite() error {
|
||||
err := conn.shutdown(syscall.SHUT_WR)
|
||||
if err != nil {
|
||||
return conn.opErr("closewrite", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address of the connection.
|
||||
func (conn *HvsockConn) LocalAddr() net.Addr {
|
||||
return &conn.local
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address of the connection.
|
||||
func (conn *HvsockConn) RemoteAddr() net.Addr {
|
||||
return &conn.remote
|
||||
}
|
||||
|
||||
// SetDeadline implements the net.Conn SetDeadline method.
|
||||
func (conn *HvsockConn) SetDeadline(t time.Time) error {
|
||||
// todo: implement `SetDeadline` for `win32File`
|
||||
if err := conn.SetReadDeadline(t); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
if err := conn.SetWriteDeadline(t); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline implements the net.Conn SetReadDeadline method.
|
||||
func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
|
||||
return conn.sock.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
|
||||
func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
|
||||
return conn.sock.SetWriteDeadline(t)
|
||||
}
|
@ -1,20 +0,0 @@
|
||||
package socket
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// RawSockaddr allows structs to be used with [Bind] and [ConnectEx]. The
|
||||
// struct must meet the Win32 sockaddr requirements specified here:
|
||||
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
|
||||
//
|
||||
// Specifically, the struct size must be least larger than an int16 (unsigned short)
|
||||
// for the address family.
|
||||
type RawSockaddr interface {
|
||||
// Sockaddr returns a pointer to the RawSockaddr and its struct size, allowing
|
||||
// for the RawSockaddr's data to be overwritten by syscalls (if necessary).
|
||||
//
|
||||
// It is the callers responsibility to validate that the values are valid; invalid
|
||||
// pointers or size can cause a panic.
|
||||
Sockaddr() (unsafe.Pointer, int32, error)
|
||||
}
|
@ -1,179 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/Microsoft/go-winio/pkg/guid"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
|
||||
|
||||
//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
|
||||
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
|
||||
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
|
||||
|
||||
const socketError = uintptr(^uint32(0))
|
||||
|
||||
var (
|
||||
// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?
|
||||
|
||||
ErrBufferSize = errors.New("buffer size")
|
||||
ErrAddrFamily = errors.New("address family")
|
||||
ErrInvalidPointer = errors.New("invalid pointer")
|
||||
ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
|
||||
)
|
||||
|
||||
// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
|
||||
|
||||
// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
|
||||
// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
|
||||
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
|
||||
ptr, l, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||
}
|
||||
|
||||
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
|
||||
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
|
||||
return getsockname(s, ptr, &l)
|
||||
}
|
||||
|
||||
// GetPeerName returns the remote address the socket is connected to.
|
||||
//
|
||||
// See [GetSockName] for more information.
|
||||
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
|
||||
ptr, l, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||
}
|
||||
|
||||
return getpeername(s, ptr, &l)
|
||||
}
|
||||
|
||||
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
|
||||
ptr, l, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
||||
}
|
||||
|
||||
return bind(s, ptr, l)
|
||||
}
|
||||
|
||||
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of the
|
||||
// their sockaddr interface, so they cannot be used with HvsockAddr
|
||||
// Replicate functionality here from
|
||||
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
|
||||
|
||||
// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
|
||||
// runtime via a WSAIoctl call:
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
|
||||
|
||||
type runtimeFunc struct {
|
||||
id guid.GUID
|
||||
once sync.Once
|
||||
addr uintptr
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *runtimeFunc) Load() error {
|
||||
f.once.Do(func() {
|
||||
var s windows.Handle
|
||||
s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
|
||||
if f.err != nil {
|
||||
return
|
||||
}
|
||||
defer windows.CloseHandle(s) //nolint:errcheck
|
||||
|
||||
var n uint32
|
||||
f.err = windows.WSAIoctl(s,
|
||||
windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
|
||||
(*byte)(unsafe.Pointer(&f.id)),
|
||||
uint32(unsafe.Sizeof(f.id)),
|
||||
(*byte)(unsafe.Pointer(&f.addr)),
|
||||
uint32(unsafe.Sizeof(f.addr)),
|
||||
&n,
|
||||
nil, //overlapped
|
||||
0, //completionRoutine
|
||||
)
|
||||
})
|
||||
return f.err
|
||||
}
|
||||
|
||||
var (
|
||||
// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
|
||||
WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
|
||||
Data1: 0x25a207b9,
|
||||
Data2: 0xddf3,
|
||||
Data3: 0x4660,
|
||||
Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
|
||||
}
|
||||
|
||||
connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
|
||||
)
|
||||
|
||||
func ConnectEx(
|
||||
fd windows.Handle,
|
||||
rsa RawSockaddr,
|
||||
sendBuf *byte,
|
||||
sendDataLen uint32,
|
||||
bytesSent *uint32,
|
||||
overlapped *windows.Overlapped,
|
||||
) error {
|
||||
if err := connectExFunc.Load(); err != nil {
|
||||
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
|
||||
}
|
||||
ptr, n, err := rsa.Sockaddr()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
|
||||
}
|
||||
|
||||
// BOOL LpfnConnectex(
|
||||
// [in] SOCKET s,
|
||||
// [in] const sockaddr *name,
|
||||
// [in] int namelen,
|
||||
// [in, optional] PVOID lpSendBuffer,
|
||||
// [in] DWORD dwSendDataLength,
|
||||
// [out] LPDWORD lpdwBytesSent,
|
||||
// [in] LPOVERLAPPED lpOverlapped
|
||||
// )
|
||||
|
||||
func connectEx(
|
||||
s windows.Handle,
|
||||
name unsafe.Pointer,
|
||||
namelen int32,
|
||||
sendBuf *byte,
|
||||
sendDataLen uint32,
|
||||
bytesSent *uint32,
|
||||
overlapped *windows.Overlapped,
|
||||
) (err error) {
|
||||
// todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN
|
||||
r1, _, e1 := syscall.Syscall9(connectExFunc.addr,
|
||||
7,
|
||||
uintptr(s),
|
||||
uintptr(name),
|
||||
uintptr(namelen),
|
||||
uintptr(unsafe.Pointer(sendBuf)),
|
||||
uintptr(sendDataLen),
|
||||
uintptr(unsafe.Pointer(bytesSent)),
|
||||
uintptr(unsafe.Pointer(overlapped)),
|
||||
0,
|
||||
0)
|
||||
if r1 == 0 {
|
||||
if e1 != 0 {
|
||||
err = error(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
procbind = modws2_32.NewProc("bind")
|
||||
procgetpeername = modws2_32.NewProc("getpeername")
|
||||
procgetsockname = modws2_32.NewProc("getsockname")
|
||||
)
|
||||
|
||||
func bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen))
|
||||
if r1 == socketError {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procgetpeername.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
|
||||
if r1 == socketError {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procgetsockname.Addr(), 3, uintptr(s), uintptr(name), uintptr(unsafe.Pointer(namelen)))
|
||||
if r1 == socketError {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
@ -1,521 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
|
||||
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
|
||||
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
|
||||
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
||||
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
||||
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
||||
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile
|
||||
//sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
||||
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U
|
||||
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl
|
||||
|
||||
type ioStatusBlock struct {
|
||||
Status, Information uintptr
|
||||
}
|
||||
|
||||
type objectAttributes struct {
|
||||
Length uintptr
|
||||
RootDirectory uintptr
|
||||
ObjectName *unicodeString
|
||||
Attributes uintptr
|
||||
SecurityDescriptor *securityDescriptor
|
||||
SecurityQoS uintptr
|
||||
}
|
||||
|
||||
type unicodeString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer uintptr
|
||||
}
|
||||
|
||||
type securityDescriptor struct {
|
||||
Revision byte
|
||||
Sbz1 byte
|
||||
Control uint16
|
||||
Owner uintptr
|
||||
Group uintptr
|
||||
Sacl uintptr //revive:disable-line:var-naming SACL, not Sacl
|
||||
Dacl uintptr //revive:disable-line:var-naming DACL, not Dacl
|
||||
}
|
||||
|
||||
type ntStatus int32
|
||||
|
||||
func (status ntStatus) Err() error {
|
||||
if status >= 0 {
|
||||
return nil
|
||||
}
|
||||
return rtlNtStatusToDosError(status)
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
|
||||
ErrPipeListenerClosed = net.ErrClosed
|
||||
|
||||
errPipeWriteClosed = errors.New("pipe has been closed for write")
|
||||
)
|
||||
|
||||
type win32Pipe struct {
|
||||
*win32File
|
||||
path string
|
||||
}
|
||||
|
||||
type win32MessageBytePipe struct {
|
||||
win32Pipe
|
||||
writeClosed bool
|
||||
readEOF bool
|
||||
}
|
||||
|
||||
type pipeAddress string
|
||||
|
||||
func (f *win32Pipe) LocalAddr() net.Addr {
|
||||
return pipeAddress(f.path)
|
||||
}
|
||||
|
||||
func (f *win32Pipe) RemoteAddr() net.Addr {
|
||||
return pipeAddress(f.path)
|
||||
}
|
||||
|
||||
func (f *win32Pipe) SetDeadline(t time.Time) error {
|
||||
if err := f.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||
func (f *win32MessageBytePipe) CloseWrite() error {
|
||||
if f.writeClosed {
|
||||
return errPipeWriteClosed
|
||||
}
|
||||
err := f.win32File.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = f.win32File.Write(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.writeClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||
// they are used to implement CloseWrite().
|
||||
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
|
||||
if f.writeClosed {
|
||||
return 0, errPipeWriteClosed
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return f.win32File.Write(b)
|
||||
}
|
||||
|
||||
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
|
||||
// mode pipe will return io.EOF, as will all subsequent reads.
|
||||
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
||||
if f.readEOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err := f.win32File.Read(b)
|
||||
if err == io.EOF { //nolint:errorlint
|
||||
// If this was the result of a zero-byte read, then
|
||||
// it is possible that the read was due to a zero-size
|
||||
// message. Since we are simulating CloseWrite with a
|
||||
// zero-byte message, ensure that all future Read() calls
|
||||
// also return EOF.
|
||||
f.readEOF = true
|
||||
} else if err == syscall.ERROR_MORE_DATA { //nolint:errorlint // err is Errno
|
||||
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||
// and the message still has more bytes. Treat this as a success, since
|
||||
// this package presents all named pipes as byte streams.
|
||||
err = nil
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (pipeAddress) Network() string {
|
||||
return "pipe"
|
||||
}
|
||||
|
||||
func (s pipeAddress) String() string {
|
||||
return string(s)
|
||||
}
|
||||
|
||||
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
||||
func tryDialPipe(ctx context.Context, path *string, access uint32) (syscall.Handle, error) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return syscall.Handle(0), ctx.Err()
|
||||
default:
|
||||
h, err := createFile(*path,
|
||||
access,
|
||||
0,
|
||||
nil,
|
||||
syscall.OPEN_EXISTING,
|
||||
windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS,
|
||||
0)
|
||||
if err == nil {
|
||||
return h, nil
|
||||
}
|
||||
if err != windows.ERROR_PIPE_BUSY { //nolint:errorlint // err is Errno
|
||||
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||
}
|
||||
// Wait 10 msec and try again. This is a rather simplistic
|
||||
// view, as we always try each 10 milliseconds.
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||
// takes longer than the specified duration. If timeout is nil, then we use
|
||||
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
|
||||
var absTimeout time.Time
|
||||
if timeout != nil {
|
||||
absTimeout = time.Now().Add(*timeout)
|
||||
} else {
|
||||
absTimeout = time.Now().Add(2 * time.Second)
|
||||
}
|
||||
ctx, cancel := context.WithDeadline(context.Background(), absTimeout)
|
||||
defer cancel()
|
||||
conn, err := DialPipeContext(ctx, path)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil, ErrTimeout
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||
// cancellation or timeout.
|
||||
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
|
||||
return DialPipeAccess(ctx, path, syscall.GENERIC_READ|syscall.GENERIC_WRITE)
|
||||
}
|
||||
|
||||
// DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx`
|
||||
// cancellation or timeout.
|
||||
func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) {
|
||||
var err error
|
||||
var h syscall.Handle
|
||||
h, err = tryDialPipe(ctx, &path, access)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var flags uint32
|
||||
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
syscall.Close(h)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the pipe is in message mode, return a message byte pipe, which
|
||||
// supports CloseWrite().
|
||||
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
|
||||
return &win32MessageBytePipe{
|
||||
win32Pipe: win32Pipe{win32File: f, path: path},
|
||||
}, nil
|
||||
}
|
||||
return &win32Pipe{win32File: f, path: path}, nil
|
||||
}
|
||||
|
||||
type acceptResponse struct {
|
||||
f *win32File
|
||||
err error
|
||||
}
|
||||
|
||||
type win32PipeListener struct {
|
||||
firstHandle syscall.Handle
|
||||
path string
|
||||
config PipeConfig
|
||||
acceptCh chan (chan acceptResponse)
|
||||
closeCh chan int
|
||||
doneCh chan int
|
||||
}
|
||||
|
||||
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
|
||||
path16, err := syscall.UTF16FromString(path)
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
|
||||
var oa objectAttributes
|
||||
oa.Length = unsafe.Sizeof(oa)
|
||||
|
||||
var ntPath unicodeString
|
||||
if err := rtlDosPathNameToNtPathName(&path16[0],
|
||||
&ntPath,
|
||||
0,
|
||||
0,
|
||||
).Err(); err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
defer localFree(ntPath.Buffer)
|
||||
oa.ObjectName = &ntPath
|
||||
|
||||
// The security descriptor is only needed for the first pipe.
|
||||
if first {
|
||||
if sd != nil {
|
||||
l := uint32(len(sd))
|
||||
sdb := localAlloc(0, l)
|
||||
defer localFree(sdb)
|
||||
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
|
||||
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
|
||||
} else {
|
||||
// Construct the default named pipe security descriptor.
|
||||
var dacl uintptr
|
||||
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
||||
return 0, fmt.Errorf("getting default named pipe ACL: %w", err)
|
||||
}
|
||||
defer localFree(dacl)
|
||||
|
||||
sdb := &securityDescriptor{
|
||||
Revision: 1,
|
||||
Control: windows.SE_DACL_PRESENT,
|
||||
Dacl: dacl,
|
||||
}
|
||||
oa.SecurityDescriptor = sdb
|
||||
}
|
||||
}
|
||||
|
||||
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
|
||||
if c.MessageMode {
|
||||
typ |= windows.FILE_PIPE_MESSAGE_TYPE
|
||||
}
|
||||
|
||||
disposition := uint32(windows.FILE_OPEN)
|
||||
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
|
||||
if first {
|
||||
disposition = windows.FILE_CREATE
|
||||
// By not asking for read or write access, the named pipe file system
|
||||
// will put this pipe into an initially disconnected state, blocking
|
||||
// client connections until the next call with first == false.
|
||||
access = syscall.SYNCHRONIZE
|
||||
}
|
||||
|
||||
timeout := int64(-50 * 10000) // 50ms
|
||||
|
||||
var (
|
||||
h syscall.Handle
|
||||
iosb ioStatusBlock
|
||||
)
|
||||
err = ntCreateNamedPipeFile(&h,
|
||||
access,
|
||||
&oa,
|
||||
&iosb,
|
||||
syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE,
|
||||
disposition,
|
||||
0,
|
||||
typ,
|
||||
0,
|
||||
0,
|
||||
0xffffffff,
|
||||
uint32(c.InputBufferSize),
|
||||
uint32(c.OutputBufferSize),
|
||||
&timeout).Err()
|
||||
if err != nil {
|
||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||
}
|
||||
|
||||
runtime.KeepAlive(ntPath)
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
||||
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, err := makeWin32File(h)
|
||||
if err != nil {
|
||||
syscall.Close(h)
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
|
||||
p, err := l.makeServerPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait for the client to connect.
|
||||
ch := make(chan error)
|
||||
go func(p *win32File) {
|
||||
ch <- connectPipe(p)
|
||||
}(p)
|
||||
|
||||
select {
|
||||
case err = <-ch:
|
||||
if err != nil {
|
||||
p.Close()
|
||||
p = nil
|
||||
}
|
||||
case <-l.closeCh:
|
||||
// Abort the connect request by closing the handle.
|
||||
p.Close()
|
||||
p = nil
|
||||
err = <-ch
|
||||
if err == nil || err == ErrFileClosed { //nolint:errorlint // err is Errno
|
||||
err = ErrPipeListenerClosed
|
||||
}
|
||||
}
|
||||
return p, err
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) listenerRoutine() {
|
||||
closed := false
|
||||
for !closed {
|
||||
select {
|
||||
case <-l.closeCh:
|
||||
closed = true
|
||||
case responseCh := <-l.acceptCh:
|
||||
var (
|
||||
p *win32File
|
||||
err error
|
||||
)
|
||||
for {
|
||||
p, err = l.makeConnectedServerPipe()
|
||||
// If the connection was immediately closed by the client, try
|
||||
// again.
|
||||
if err != windows.ERROR_NO_DATA { //nolint:errorlint // err is Errno
|
||||
break
|
||||
}
|
||||
}
|
||||
responseCh <- acceptResponse{p, err}
|
||||
closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno
|
||||
}
|
||||
}
|
||||
syscall.Close(l.firstHandle)
|
||||
l.firstHandle = 0
|
||||
// Notify Close() and Accept() callers that the handle has been closed.
|
||||
close(l.doneCh)
|
||||
}
|
||||
|
||||
// PipeConfig contain configuration for the pipe listener.
|
||||
type PipeConfig struct {
|
||||
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
|
||||
SecurityDescriptor string
|
||||
|
||||
// MessageMode determines whether the pipe is in byte or message mode. In either
|
||||
// case the pipe is read in byte mode by default. The only practical difference in
|
||||
// this implementation is that CloseWrite() is only supported for message mode pipes;
|
||||
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
|
||||
// transferred to the reader (and returned as io.EOF in this implementation)
|
||||
// when the pipe is in message mode.
|
||||
MessageMode bool
|
||||
|
||||
// InputBufferSize specifies the size of the input buffer, in bytes.
|
||||
InputBufferSize int32
|
||||
|
||||
// OutputBufferSize specifies the size of the output buffer, in bytes.
|
||||
OutputBufferSize int32
|
||||
}
|
||||
|
||||
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
||||
// The pipe must not already exist.
|
||||
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
||||
var (
|
||||
sd []byte
|
||||
err error
|
||||
)
|
||||
if c == nil {
|
||||
c = &PipeConfig{}
|
||||
}
|
||||
if c.SecurityDescriptor != "" {
|
||||
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
h, err := makeServerPipeHandle(path, sd, c, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l := &win32PipeListener{
|
||||
firstHandle: h,
|
||||
path: path,
|
||||
config: *c,
|
||||
acceptCh: make(chan (chan acceptResponse)),
|
||||
closeCh: make(chan int),
|
||||
doneCh: make(chan int),
|
||||
}
|
||||
go l.listenerRoutine()
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func connectPipe(p *win32File) error {
|
||||
c, err := p.prepareIO()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer p.wg.Done()
|
||||
|
||||
err = connectNamedPipe(p.handle, &c.o)
|
||||
_, err = p.asyncIO(c, nil, 0, err)
|
||||
if err != nil && err != windows.ERROR_PIPE_CONNECTED { //nolint:errorlint // err is Errno
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) Accept() (net.Conn, error) {
|
||||
ch := make(chan acceptResponse)
|
||||
select {
|
||||
case l.acceptCh <- ch:
|
||||
response := <-ch
|
||||
err := response.err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if l.config.MessageMode {
|
||||
return &win32MessageBytePipe{
|
||||
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
|
||||
}, nil
|
||||
}
|
||||
return &win32Pipe{win32File: response.f, path: l.path}, nil
|
||||
case <-l.doneCh:
|
||||
return nil, ErrPipeListenerClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) Close() error {
|
||||
select {
|
||||
case l.closeCh <- 1:
|
||||
<-l.doneCh
|
||||
case <-l.doneCh:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *win32PipeListener) Addr() net.Addr {
|
||||
return pipeAddress(l.path)
|
||||
}
|
@ -1,232 +0,0 @@
|
||||
// Package guid provides a GUID type. The backing structure for a GUID is
|
||||
// identical to that used by the golang.org/x/sys/windows GUID type.
|
||||
// There are two main binary encodings used for a GUID, the big-endian encoding,
|
||||
// and the Windows (mixed-endian) encoding. See here for details:
|
||||
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Encoding
|
||||
package guid
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha1" //nolint:gosec // not used for secure application
|
||||
"encoding"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
//go:generate go run golang.org/x/tools/cmd/stringer -type=Variant -trimprefix=Variant -linecomment
|
||||
|
||||
// Variant specifies which GUID variant (or "type") of the GUID. It determines
|
||||
// how the entirety of the rest of the GUID is interpreted.
|
||||
type Variant uint8
|
||||
|
||||
// The variants specified by RFC 4122 section 4.1.1.
|
||||
const (
|
||||
// VariantUnknown specifies a GUID variant which does not conform to one of
|
||||
// the variant encodings specified in RFC 4122.
|
||||
VariantUnknown Variant = iota
|
||||
VariantNCS
|
||||
VariantRFC4122 // RFC 4122
|
||||
VariantMicrosoft
|
||||
VariantFuture
|
||||
)
|
||||
|
||||
// Version specifies how the bits in the GUID were generated. For instance, a
|
||||
// version 4 GUID is randomly generated, and a version 5 is generated from the
|
||||
// hash of an input string.
|
||||
type Version uint8
|
||||
|
||||
func (v Version) String() string {
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
}
|
||||
|
||||
var _ = (encoding.TextMarshaler)(GUID{})
|
||||
var _ = (encoding.TextUnmarshaler)(&GUID{})
|
||||
|
||||
// NewV4 returns a new version 4 (pseudorandom) GUID, as defined by RFC 4122.
|
||||
func NewV4() (GUID, error) {
|
||||
var b [16]byte
|
||||
if _, err := rand.Read(b[:]); err != nil {
|
||||
return GUID{}, err
|
||||
}
|
||||
|
||||
g := FromArray(b)
|
||||
g.setVersion(4) // Version 4 means randomly generated.
|
||||
g.setVariant(VariantRFC4122)
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// NewV5 returns a new version 5 (generated from a string via SHA-1 hashing)
|
||||
// GUID, as defined by RFC 4122. The RFC is unclear on the encoding of the name,
|
||||
// and the sample code treats it as a series of bytes, so we do the same here.
|
||||
//
|
||||
// Some implementations, such as those found on Windows, treat the name as a
|
||||
// big-endian UTF16 stream of bytes. If that is desired, the string can be
|
||||
// encoded as such before being passed to this function.
|
||||
func NewV5(namespace GUID, name []byte) (GUID, error) {
|
||||
b := sha1.New() //nolint:gosec // not used for secure application
|
||||
namespaceBytes := namespace.ToArray()
|
||||
b.Write(namespaceBytes[:])
|
||||
b.Write(name)
|
||||
|
||||
a := [16]byte{}
|
||||
copy(a[:], b.Sum(nil))
|
||||
|
||||
g := FromArray(a)
|
||||
g.setVersion(5) // Version 5 means generated from a string.
|
||||
g.setVariant(VariantRFC4122)
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func fromArray(b [16]byte, order binary.ByteOrder) GUID {
|
||||
var g GUID
|
||||
g.Data1 = order.Uint32(b[0:4])
|
||||
g.Data2 = order.Uint16(b[4:6])
|
||||
g.Data3 = order.Uint16(b[6:8])
|
||||
copy(g.Data4[:], b[8:16])
|
||||
return g
|
||||
}
|
||||
|
||||
func (g GUID) toArray(order binary.ByteOrder) [16]byte {
|
||||
b := [16]byte{}
|
||||
order.PutUint32(b[0:4], g.Data1)
|
||||
order.PutUint16(b[4:6], g.Data2)
|
||||
order.PutUint16(b[6:8], g.Data3)
|
||||
copy(b[8:16], g.Data4[:])
|
||||
return b
|
||||
}
|
||||
|
||||
// FromArray constructs a GUID from a big-endian encoding array of 16 bytes.
|
||||
func FromArray(b [16]byte) GUID {
|
||||
return fromArray(b, binary.BigEndian)
|
||||
}
|
||||
|
||||
// ToArray returns an array of 16 bytes representing the GUID in big-endian
|
||||
// encoding.
|
||||
func (g GUID) ToArray() [16]byte {
|
||||
return g.toArray(binary.BigEndian)
|
||||
}
|
||||
|
||||
// FromWindowsArray constructs a GUID from a Windows encoding array of bytes.
|
||||
func FromWindowsArray(b [16]byte) GUID {
|
||||
return fromArray(b, binary.LittleEndian)
|
||||
}
|
||||
|
||||
// ToWindowsArray returns an array of 16 bytes representing the GUID in Windows
|
||||
// encoding.
|
||||
func (g GUID) ToWindowsArray() [16]byte {
|
||||
return g.toArray(binary.LittleEndian)
|
||||
}
|
||||
|
||||
func (g GUID) String() string {
|
||||
return fmt.Sprintf(
|
||||
"%08x-%04x-%04x-%04x-%012x",
|
||||
g.Data1,
|
||||
g.Data2,
|
||||
g.Data3,
|
||||
g.Data4[:2],
|
||||
g.Data4[2:])
|
||||
}
|
||||
|
||||
// FromString parses a string containing a GUID and returns the GUID. The only
|
||||
// format currently supported is the `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`
|
||||
// format.
|
||||
func FromString(s string) (GUID, error) {
|
||||
if len(s) != 36 {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
|
||||
var g GUID
|
||||
|
||||
data1, err := strconv.ParseUint(s[0:8], 16, 32)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data1 = uint32(data1)
|
||||
|
||||
data2, err := strconv.ParseUint(s[9:13], 16, 16)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data2 = uint16(data2)
|
||||
|
||||
data3, err := strconv.ParseUint(s[14:18], 16, 16)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data3 = uint16(data3)
|
||||
|
||||
for i, x := range []int{19, 21, 24, 26, 28, 30, 32, 34} {
|
||||
v, err := strconv.ParseUint(s[x:x+2], 16, 8)
|
||||
if err != nil {
|
||||
return GUID{}, fmt.Errorf("invalid GUID %q", s)
|
||||
}
|
||||
g.Data4[i] = uint8(v)
|
||||
}
|
||||
|
||||
return g, nil
|
||||
}
|
||||
|
||||
func (g *GUID) setVariant(v Variant) {
|
||||
d := g.Data4[0]
|
||||
switch v {
|
||||
case VariantNCS:
|
||||
d = (d & 0x7f)
|
||||
case VariantRFC4122:
|
||||
d = (d & 0x3f) | 0x80
|
||||
case VariantMicrosoft:
|
||||
d = (d & 0x1f) | 0xc0
|
||||
case VariantFuture:
|
||||
d = (d & 0x0f) | 0xe0
|
||||
case VariantUnknown:
|
||||
fallthrough
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid variant: %d", v))
|
||||
}
|
||||
g.Data4[0] = d
|
||||
}
|
||||
|
||||
// Variant returns the GUID variant, as defined in RFC 4122.
|
||||
func (g GUID) Variant() Variant {
|
||||
b := g.Data4[0]
|
||||
if b&0x80 == 0 {
|
||||
return VariantNCS
|
||||
} else if b&0xc0 == 0x80 {
|
||||
return VariantRFC4122
|
||||
} else if b&0xe0 == 0xc0 {
|
||||
return VariantMicrosoft
|
||||
} else if b&0xe0 == 0xe0 {
|
||||
return VariantFuture
|
||||
}
|
||||
return VariantUnknown
|
||||
}
|
||||
|
||||
func (g *GUID) setVersion(v Version) {
|
||||
g.Data3 = (g.Data3 & 0x0fff) | (uint16(v) << 12)
|
||||
}
|
||||
|
||||
// Version returns the GUID version, as defined in RFC 4122.
|
||||
func (g GUID) Version() Version {
|
||||
return Version((g.Data3 & 0xF000) >> 12)
|
||||
}
|
||||
|
||||
// MarshalText returns the textual representation of the GUID.
|
||||
func (g GUID) MarshalText() ([]byte, error) {
|
||||
return []byte(g.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText takes the textual representation of a GUID, and unmarhals it
|
||||
// into this GUID.
|
||||
func (g *GUID) UnmarshalText(text []byte) error {
|
||||
g2, err := FromString(string(text))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*g = g2
|
||||
return nil
|
||||
}
|
@ -1,16 +0,0 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package guid
|
||||
|
||||
// GUID represents a GUID/UUID. It has the same structure as
|
||||
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
|
||||
// that type. It is defined as its own type as that is only available to builds
|
||||
// targeted at `windows`. The representation matches that used by native Windows
|
||||
// code.
|
||||
type GUID struct {
|
||||
Data1 uint32
|
||||
Data2 uint16
|
||||
Data3 uint16
|
||||
Data4 [8]byte
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package guid
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
// GUID represents a GUID/UUID. It has the same structure as
|
||||
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
|
||||
// that type. It is defined as its own type so that stringification and
|
||||
// marshaling can be supported. The representation matches that used by native
|
||||
// Windows code.
|
||||
type GUID windows.GUID
|
@ -1,27 +0,0 @@
|
||||
// Code generated by "stringer -type=Variant -trimprefix=Variant -linecomment"; DO NOT EDIT.
|
||||
|
||||
package guid
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[VariantUnknown-0]
|
||||
_ = x[VariantNCS-1]
|
||||
_ = x[VariantRFC4122-2]
|
||||
_ = x[VariantMicrosoft-3]
|
||||
_ = x[VariantFuture-4]
|
||||
}
|
||||
|
||||
const _Variant_name = "UnknownNCSRFC 4122MicrosoftFuture"
|
||||
|
||||
var _Variant_index = [...]uint8{0, 7, 10, 18, 27, 33}
|
||||
|
||||
func (i Variant) String() string {
|
||||
if i >= Variant(len(_Variant_index)-1) {
|
||||
return "Variant(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _Variant_name[_Variant_index[i]:_Variant_index[i+1]]
|
||||
}
|
@ -1,197 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unicode/utf16"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) [true] = advapi32.AdjustTokenPrivileges
|
||||
//sys impersonateSelf(level uint32) (err error) = advapi32.ImpersonateSelf
|
||||
//sys revertToSelf() (err error) = advapi32.RevertToSelf
|
||||
//sys openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) = advapi32.OpenThreadToken
|
||||
//sys getCurrentThread() (h syscall.Handle) = GetCurrentThread
|
||||
//sys lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) = advapi32.LookupPrivilegeValueW
|
||||
//sys lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) = advapi32.LookupPrivilegeNameW
|
||||
//sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW
|
||||
|
||||
const (
|
||||
//revive:disable-next-line:var-naming ALL_CAPS
|
||||
SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED
|
||||
|
||||
//revive:disable-next-line:var-naming ALL_CAPS
|
||||
ERROR_NOT_ALL_ASSIGNED syscall.Errno = windows.ERROR_NOT_ALL_ASSIGNED
|
||||
|
||||
SeBackupPrivilege = "SeBackupPrivilege"
|
||||
SeRestorePrivilege = "SeRestorePrivilege"
|
||||
SeSecurityPrivilege = "SeSecurityPrivilege"
|
||||
)
|
||||
|
||||
var (
|
||||
privNames = make(map[string]uint64)
|
||||
privNameMutex sync.Mutex
|
||||
)
|
||||
|
||||
// PrivilegeError represents an error enabling privileges.
|
||||
type PrivilegeError struct {
|
||||
privileges []uint64
|
||||
}
|
||||
|
||||
func (e *PrivilegeError) Error() string {
|
||||
s := "Could not enable privilege "
|
||||
if len(e.privileges) > 1 {
|
||||
s = "Could not enable privileges "
|
||||
}
|
||||
for i, p := range e.privileges {
|
||||
if i != 0 {
|
||||
s += ", "
|
||||
}
|
||||
s += `"`
|
||||
s += getPrivilegeName(p)
|
||||
s += `"`
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// RunWithPrivilege enables a single privilege for a function call.
|
||||
func RunWithPrivilege(name string, fn func() error) error {
|
||||
return RunWithPrivileges([]string{name}, fn)
|
||||
}
|
||||
|
||||
// RunWithPrivileges enables privileges for a function call.
|
||||
func RunWithPrivileges(names []string, fn func() error) error {
|
||||
privileges, err := mapPrivileges(names)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
token, err := newThreadToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer releaseThreadToken(token)
|
||||
err = adjustPrivileges(token, privileges, SE_PRIVILEGE_ENABLED)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fn()
|
||||
}
|
||||
|
||||
func mapPrivileges(names []string) ([]uint64, error) {
|
||||
privileges := make([]uint64, 0, len(names))
|
||||
privNameMutex.Lock()
|
||||
defer privNameMutex.Unlock()
|
||||
for _, name := range names {
|
||||
p, ok := privNames[name]
|
||||
if !ok {
|
||||
err := lookupPrivilegeValue("", name, &p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
privNames[name] = p
|
||||
}
|
||||
privileges = append(privileges, p)
|
||||
}
|
||||
return privileges, nil
|
||||
}
|
||||
|
||||
// EnableProcessPrivileges enables privileges globally for the process.
|
||||
func EnableProcessPrivileges(names []string) error {
|
||||
return enableDisableProcessPrivilege(names, SE_PRIVILEGE_ENABLED)
|
||||
}
|
||||
|
||||
// DisableProcessPrivileges disables privileges globally for the process.
|
||||
func DisableProcessPrivileges(names []string) error {
|
||||
return enableDisableProcessPrivilege(names, 0)
|
||||
}
|
||||
|
||||
func enableDisableProcessPrivilege(names []string, action uint32) error {
|
||||
privileges, err := mapPrivileges(names)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := windows.CurrentProcess()
|
||||
var token windows.Token
|
||||
err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer token.Close()
|
||||
return adjustPrivileges(token, privileges, action)
|
||||
}
|
||||
|
||||
func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error {
|
||||
var b bytes.Buffer
|
||||
_ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges)))
|
||||
for _, p := range privileges {
|
||||
_ = binary.Write(&b, binary.LittleEndian, p)
|
||||
_ = binary.Write(&b, binary.LittleEndian, action)
|
||||
}
|
||||
prevState := make([]byte, b.Len())
|
||||
reqSize := uint32(0)
|
||||
success, err := adjustTokenPrivileges(token, false, &b.Bytes()[0], uint32(len(prevState)), &prevState[0], &reqSize)
|
||||
if !success {
|
||||
return err
|
||||
}
|
||||
if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno
|
||||
return &PrivilegeError{privileges}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getPrivilegeName(luid uint64) string {
|
||||
var nameBuffer [256]uint16
|
||||
bufSize := uint32(len(nameBuffer))
|
||||
err := lookupPrivilegeName("", &luid, &nameBuffer[0], &bufSize)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("<unknown privilege %d>", luid)
|
||||
}
|
||||
|
||||
var displayNameBuffer [256]uint16
|
||||
displayBufSize := uint32(len(displayNameBuffer))
|
||||
var langID uint32
|
||||
err = lookupPrivilegeDisplayName("", &nameBuffer[0], &displayNameBuffer[0], &displayBufSize, &langID)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("<unknown privilege %s>", string(utf16.Decode(nameBuffer[:bufSize])))
|
||||
}
|
||||
|
||||
return string(utf16.Decode(displayNameBuffer[:displayBufSize]))
|
||||
}
|
||||
|
||||
func newThreadToken() (windows.Token, error) {
|
||||
err := impersonateSelf(windows.SecurityImpersonation)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var token windows.Token
|
||||
err = openThreadToken(getCurrentThread(), syscall.TOKEN_ADJUST_PRIVILEGES|syscall.TOKEN_QUERY, false, &token)
|
||||
if err != nil {
|
||||
rerr := revertToSelf()
|
||||
if rerr != nil {
|
||||
panic(rerr)
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func releaseThreadToken(h windows.Token) {
|
||||
err := revertToSelf()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
h.Close()
|
||||
}
|
@ -1,131 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode/utf16"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
reparseTagMountPoint = 0xA0000003
|
||||
reparseTagSymlink = 0xA000000C
|
||||
)
|
||||
|
||||
type reparseDataBuffer struct {
|
||||
ReparseTag uint32
|
||||
ReparseDataLength uint16
|
||||
Reserved uint16
|
||||
SubstituteNameOffset uint16
|
||||
SubstituteNameLength uint16
|
||||
PrintNameOffset uint16
|
||||
PrintNameLength uint16
|
||||
}
|
||||
|
||||
// ReparsePoint describes a Win32 symlink or mount point.
|
||||
type ReparsePoint struct {
|
||||
Target string
|
||||
IsMountPoint bool
|
||||
}
|
||||
|
||||
// UnsupportedReparsePointError is returned when trying to decode a non-symlink or
|
||||
// mount point reparse point.
|
||||
type UnsupportedReparsePointError struct {
|
||||
Tag uint32
|
||||
}
|
||||
|
||||
func (e *UnsupportedReparsePointError) Error() string {
|
||||
return fmt.Sprintf("unsupported reparse point %x", e.Tag)
|
||||
}
|
||||
|
||||
// DecodeReparsePoint decodes a Win32 REPARSE_DATA_BUFFER structure containing either a symlink
|
||||
// or a mount point.
|
||||
func DecodeReparsePoint(b []byte) (*ReparsePoint, error) {
|
||||
tag := binary.LittleEndian.Uint32(b[0:4])
|
||||
return DecodeReparsePointData(tag, b[8:])
|
||||
}
|
||||
|
||||
func DecodeReparsePointData(tag uint32, b []byte) (*ReparsePoint, error) {
|
||||
isMountPoint := false
|
||||
switch tag {
|
||||
case reparseTagMountPoint:
|
||||
isMountPoint = true
|
||||
case reparseTagSymlink:
|
||||
default:
|
||||
return nil, &UnsupportedReparsePointError{tag}
|
||||
}
|
||||
nameOffset := 8 + binary.LittleEndian.Uint16(b[4:6])
|
||||
if !isMountPoint {
|
||||
nameOffset += 4
|
||||
}
|
||||
nameLength := binary.LittleEndian.Uint16(b[6:8])
|
||||
name := make([]uint16, nameLength/2)
|
||||
err := binary.Read(bytes.NewReader(b[nameOffset:nameOffset+nameLength]), binary.LittleEndian, &name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ReparsePoint{string(utf16.Decode(name)), isMountPoint}, nil
|
||||
}
|
||||
|
||||
func isDriveLetter(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')
|
||||
}
|
||||
|
||||
// EncodeReparsePoint encodes a Win32 REPARSE_DATA_BUFFER structure describing a symlink or
|
||||
// mount point.
|
||||
func EncodeReparsePoint(rp *ReparsePoint) []byte {
|
||||
// Generate an NT path and determine if this is a relative path.
|
||||
var ntTarget string
|
||||
relative := false
|
||||
if strings.HasPrefix(rp.Target, `\\?\`) {
|
||||
ntTarget = `\??\` + rp.Target[4:]
|
||||
} else if strings.HasPrefix(rp.Target, `\\`) {
|
||||
ntTarget = `\??\UNC\` + rp.Target[2:]
|
||||
} else if len(rp.Target) >= 2 && isDriveLetter(rp.Target[0]) && rp.Target[1] == ':' {
|
||||
ntTarget = `\??\` + rp.Target
|
||||
} else {
|
||||
ntTarget = rp.Target
|
||||
relative = true
|
||||
}
|
||||
|
||||
// The paths must be NUL-terminated even though they are counted strings.
|
||||
target16 := utf16.Encode([]rune(rp.Target + "\x00"))
|
||||
ntTarget16 := utf16.Encode([]rune(ntTarget + "\x00"))
|
||||
|
||||
size := int(unsafe.Sizeof(reparseDataBuffer{})) - 8
|
||||
size += len(ntTarget16)*2 + len(target16)*2
|
||||
|
||||
tag := uint32(reparseTagMountPoint)
|
||||
if !rp.IsMountPoint {
|
||||
tag = reparseTagSymlink
|
||||
size += 4 // Add room for symlink flags
|
||||
}
|
||||
|
||||
data := reparseDataBuffer{
|
||||
ReparseTag: tag,
|
||||
ReparseDataLength: uint16(size),
|
||||
SubstituteNameOffset: 0,
|
||||
SubstituteNameLength: uint16((len(ntTarget16) - 1) * 2),
|
||||
PrintNameOffset: uint16(len(ntTarget16) * 2),
|
||||
PrintNameLength: uint16((len(target16) - 1) * 2),
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
_ = binary.Write(&b, binary.LittleEndian, &data)
|
||||
if !rp.IsMountPoint {
|
||||
flags := uint32(0)
|
||||
if relative {
|
||||
flags |= 1
|
||||
}
|
||||
_ = binary.Write(&b, binary.LittleEndian, flags)
|
||||
}
|
||||
|
||||
_ = binary.Write(&b, binary.LittleEndian, ntTarget16)
|
||||
_ = binary.Write(&b, binary.LittleEndian, target16)
|
||||
return b.Bytes()
|
||||
}
|
@ -1,144 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
//sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW
|
||||
//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW
|
||||
//sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW
|
||||
//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW
|
||||
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
|
||||
//sys convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) = advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW
|
||||
//sys localFree(mem uintptr) = LocalFree
|
||||
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
|
||||
|
||||
type AccountLookupError struct {
|
||||
Name string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *AccountLookupError) Error() string {
|
||||
if e.Name == "" {
|
||||
return "lookup account: empty account name specified"
|
||||
}
|
||||
var s string
|
||||
switch {
|
||||
case errors.Is(e.Err, windows.ERROR_INVALID_SID):
|
||||
s = "the security ID structure is invalid"
|
||||
case errors.Is(e.Err, windows.ERROR_NONE_MAPPED):
|
||||
s = "not found"
|
||||
default:
|
||||
s = e.Err.Error()
|
||||
}
|
||||
return "lookup account " + e.Name + ": " + s
|
||||
}
|
||||
|
||||
func (e *AccountLookupError) Unwrap() error { return e.Err }
|
||||
|
||||
type SddlConversionError struct {
|
||||
Sddl string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *SddlConversionError) Error() string {
|
||||
return "convert " + e.Sddl + ": " + e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *SddlConversionError) Unwrap() error { return e.Err }
|
||||
|
||||
// LookupSidByName looks up the SID of an account by name
|
||||
//
|
||||
//revive:disable-next-line:var-naming SID, not Sid
|
||||
func LookupSidByName(name string) (sid string, err error) {
|
||||
if name == "" {
|
||||
return "", &AccountLookupError{name, windows.ERROR_NONE_MAPPED}
|
||||
}
|
||||
|
||||
var sidSize, sidNameUse, refDomainSize uint32
|
||||
err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse)
|
||||
if err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
|
||||
return "", &AccountLookupError{name, err}
|
||||
}
|
||||
sidBuffer := make([]byte, sidSize)
|
||||
refDomainBuffer := make([]uint16, refDomainSize)
|
||||
err = lookupAccountName(nil, name, &sidBuffer[0], &sidSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{name, err}
|
||||
}
|
||||
var strBuffer *uint16
|
||||
err = convertSidToStringSid(&sidBuffer[0], &strBuffer)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{name, err}
|
||||
}
|
||||
sid = syscall.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(strBuffer))[:])
|
||||
localFree(uintptr(unsafe.Pointer(strBuffer)))
|
||||
return sid, nil
|
||||
}
|
||||
|
||||
// LookupNameBySid looks up the name of an account by SID
|
||||
//
|
||||
//revive:disable-next-line:var-naming SID, not Sid
|
||||
func LookupNameBySid(sid string) (name string, err error) {
|
||||
if sid == "" {
|
||||
return "", &AccountLookupError{sid, windows.ERROR_NONE_MAPPED}
|
||||
}
|
||||
|
||||
sidBuffer, err := windows.UTF16PtrFromString(sid)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
|
||||
var sidPtr *byte
|
||||
if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil {
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
defer localFree(uintptr(unsafe.Pointer(sidPtr)))
|
||||
|
||||
var nameSize, refDomainSize, sidNameUse uint32
|
||||
err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse)
|
||||
if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
|
||||
nameBuffer := make([]uint16, nameSize)
|
||||
refDomainBuffer := make([]uint16, refDomainSize)
|
||||
err = lookupAccountSid(nil, sidPtr, &nameBuffer[0], &nameSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse)
|
||||
if err != nil {
|
||||
return "", &AccountLookupError{sid, err}
|
||||
}
|
||||
|
||||
name = windows.UTF16ToString(nameBuffer)
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
||||
var sdBuffer uintptr
|
||||
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
|
||||
if err != nil {
|
||||
return nil, &SddlConversionError{sddl, err}
|
||||
}
|
||||
defer localFree(sdBuffer)
|
||||
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
|
||||
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
|
||||
return sd, nil
|
||||
}
|
||||
|
||||
func SecurityDescriptorToSddl(sd []byte) (string, error) {
|
||||
var sddl *uint16
|
||||
// The returned string length seems to include an arbitrary number of terminating NULs.
|
||||
// Don't use it.
|
||||
err := convertSecurityDescriptorToStringSecurityDescriptor(&sd[0], 1, 0xff, &sddl, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer localFree(uintptr(unsafe.Pointer(sddl)))
|
||||
return syscall.UTF16ToString((*[0xffff]uint16)(unsafe.Pointer(sddl))[:]), nil
|
||||
}
|
@ -1,5 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package winio
|
||||
|
||||
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go
|
@ -1,5 +0,0 @@
|
||||
//go:build tools
|
||||
|
||||
package winio
|
||||
|
||||
import _ "golang.org/x/tools/cmd/stringer"
|
@ -1,438 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
// Code generated by 'go generate' using "github.com/Microsoft/go-winio/tools/mkwinsyscall"; DO NOT EDIT.
|
||||
|
||||
package winio
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||
|
||||
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
|
||||
procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW")
|
||||
procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW")
|
||||
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||
procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW")
|
||||
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
|
||||
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
|
||||
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
|
||||
procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW")
|
||||
procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
|
||||
procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW")
|
||||
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
|
||||
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
|
||||
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
|
||||
procBackupRead = modkernel32.NewProc("BackupRead")
|
||||
procBackupWrite = modkernel32.NewProc("BackupWrite")
|
||||
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
||||
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
||||
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
||||
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
|
||||
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
||||
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
||||
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
|
||||
procLocalFree = modkernel32.NewProc("LocalFree")
|
||||
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
||||
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
||||
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
||||
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
||||
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
||||
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
||||
)
|
||||
|
||||
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
|
||||
var _p0 uint32
|
||||
if releaseAll {
|
||||
_p0 = 1
|
||||
}
|
||||
r0, _, e1 := syscall.Syscall6(procAdjustTokenPrivileges.Addr(), 6, uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
|
||||
success = r0 != 0
|
||||
if true {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procConvertSecurityDescriptorToStringSecurityDescriptorW.Addr(), 5, uintptr(unsafe.Pointer(sd)), uintptr(revision), uintptr(secInfo), uintptr(unsafe.Pointer(sddl)), uintptr(unsafe.Pointer(sddlSize)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertSidToStringSid(sid *byte, str **uint16) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procConvertSidToStringSidW.Addr(), 2, uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(str)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
|
||||
}
|
||||
|
||||
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertStringSidToSid(str *uint16, sid **byte) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procConvertStringSidToSidW.Addr(), 2, uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
|
||||
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
|
||||
len = uint32(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func impersonateSelf(level uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procImpersonateSelf.Addr(), 1, uintptr(level), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(accountName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupAccountName(systemName, _p0, sid, sidSize, refDomain, refDomainSize, sidNameUse)
|
||||
}
|
||||
|
||||
func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall9(procLookupAccountNameW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall9(procLookupAccountSidW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupPrivilegeDisplayName(_p0, name, buffer, size, languageId)
|
||||
}
|
||||
|
||||
func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeDisplayNameW.Addr(), 5, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupPrivilegeName(_p0, luid, buffer, size)
|
||||
}
|
||||
|
||||
func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeNameW.Addr(), 4, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func lookupPrivilegeValue(systemName string, name string, luid *uint64) (err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(systemName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var _p1 *uint16
|
||||
_p1, err = syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _lookupPrivilegeValue(_p0, _p1, luid)
|
||||
}
|
||||
|
||||
func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procLookupPrivilegeValueW.Addr(), 3, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid)))
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) {
|
||||
var _p0 uint32
|
||||
if openAsSelf {
|
||||
_p0 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall6(procOpenThreadToken.Addr(), 4, uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func revertToSelf() (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procRevertToSelf.Addr(), 0, 0, 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
|
||||
var _p0 *byte
|
||||
if len(b) > 0 {
|
||||
_p0 = &b[0]
|
||||
}
|
||||
var _p1 uint32
|
||||
if abort {
|
||||
_p1 = 1
|
||||
}
|
||||
var _p2 uint32
|
||||
if processSecurity {
|
||||
_p2 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall9(procBackupRead.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) {
|
||||
var _p0 *byte
|
||||
if len(b) > 0 {
|
||||
_p0 = &b[0]
|
||||
}
|
||||
var _p1 uint32
|
||||
if abort {
|
||||
_p1 = 1
|
||||
}
|
||||
var _p2 uint32
|
||||
if processSecurity {
|
||||
_p2 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall9(procBackupWrite.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
||||
}
|
||||
|
||||
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
||||
handle = syscall.Handle(r0)
|
||||
if handle == syscall.InvalidHandle {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
||||
newport = syscall.Handle(r0)
|
||||
if newport == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
||||
var _p0 *uint16
|
||||
_p0, err = syscall.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
||||
}
|
||||
|
||||
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
||||
handle = syscall.Handle(r0)
|
||||
if handle == syscall.InvalidHandle {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getCurrentThread() (h syscall.Handle) {
|
||||
r0, _, _ := syscall.Syscall(procGetCurrentThread.Addr(), 0, 0, 0, 0)
|
||||
h = syscall.Handle(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
|
||||
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
|
||||
ptr = uintptr(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func localFree(mem uintptr) {
|
||||
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) {
|
||||
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
||||
status = ntStatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) {
|
||||
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
|
||||
status = ntStatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) {
|
||||
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
|
||||
status = ntStatus(r0)
|
||||
return
|
||||
}
|
||||
|
||||
func rtlNtStatusToDosError(status ntStatus) (winerr error) {
|
||||
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
|
||||
if r0 != 0 {
|
||||
winerr = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
||||
var _p0 uint32
|
||||
if wait {
|
||||
_p0 = 1
|
||||
}
|
||||
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
|
||||
if r1 == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
dist
|
||||
/doc
|
||||
/doc-staging
|
||||
.yardoc
|
||||
Gemfile.lock
|
||||
/internal/awstesting/integration/smoke/**/importmarker__.go
|
||||
/internal/awstesting/integration/smoke/_test/
|
||||
/vendor
|
||||
/private/model/cli/gen-api/gen-api
|
||||
.gradle/
|
||||
build/
|
||||
.idea/
|
||||
bin/
|
||||
.vscode/
|
@ -0,0 +1,27 @@
|
||||
[run]
|
||||
concurrency = 4
|
||||
timeout = "1m"
|
||||
issues-exit-code = 0
|
||||
modules-download-mode = "readonly"
|
||||
allow-parallel-runners = true
|
||||
skip-dirs = ["internal/repotools"]
|
||||
skip-dirs-use-default = true
|
||||
skip-files = ["service/transcribestreaming/eventstream_test.go"]
|
||||
[output]
|
||||
format = "github-actions"
|
||||
|
||||
[linters-settings.cyclop]
|
||||
skip-tests = false
|
||||
|
||||
[linters-settings.errcheck]
|
||||
check-blank = true
|
||||
|
||||
[linters]
|
||||
disable-all = true
|
||||
enable = ["errcheck"]
|
||||
fast = false
|
||||
|
||||
[issues]
|
||||
exclude-use-default = false
|
||||
|
||||
# Refer config definitions at https://golangci-lint.run/usage/configuration/#config-file
|
@ -0,0 +1,31 @@
|
||||
language: go
|
||||
sudo: true
|
||||
dist: bionic
|
||||
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
# Travis doesn't work with windows and Go tip
|
||||
#- windows
|
||||
|
||||
go:
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
before_install:
|
||||
- if [ "$TRAVIS_OS_NAME" = "windows" ]; then choco install make; fi
|
||||
- (cd /tmp/; go get golang.org/x/lint/golint)
|
||||
|
||||
env:
|
||||
- EACHMODULE_CONCURRENCY=4
|
||||
|
||||
script:
|
||||
- make ci-test-no-generate;
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,4 @@
|
||||
## Code of Conduct
|
||||
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
|
||||
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
|
||||
opensource-codeofconduct@amazon.com with any additional questions or comments.
|
@ -0,0 +1,178 @@
|
||||
# Contributing to the AWS SDK for Go
|
||||
|
||||
Thank you for your interest in contributing to the AWS SDK for Go!
|
||||
We work hard to provide a high-quality and useful SDK, and we greatly value
|
||||
feedback and contributions from our community. Whether it's a bug report,
|
||||
new feature, correction, or additional documentation, we welcome your issues
|
||||
and pull requests. Please read through this document before submitting any
|
||||
[issues] or [pull requests][pr] to ensure we have all the necessary information to
|
||||
effectively respond to your bug report or contribution.
|
||||
|
||||
Jump To:
|
||||
|
||||
* [Bug Reports](#bug-reports)
|
||||
* [Feature Requests](#feature-requests)
|
||||
* [Code Contributions](#code-contributions)
|
||||
|
||||
|
||||
## How to contribute
|
||||
|
||||
*Before you send us a pull request, please be sure that:*
|
||||
|
||||
1. You're working from the latest source on the master branch.
|
||||
2. You check existing open, and recently closed, pull requests to be sure
|
||||
that someone else hasn't already addressed the problem.
|
||||
3. You create an issue before working on a contribution that will take a
|
||||
significant amount of your time.
|
||||
|
||||
*Creating a Pull Request*
|
||||
|
||||
1. Fork the repository.
|
||||
2. In your fork, make your change in a branch that's based on this repo's master branch.
|
||||
3. Commit the change to your fork, using a clear and descriptive commit message.
|
||||
4. Create a pull request, answering any questions in the pull request form.
|
||||
|
||||
For contributions that will take a significant amount of time, open a new
|
||||
issue to pitch your idea before you get started. Explain the problem and
|
||||
describe the content you want to see added to the documentation. Let us know
|
||||
if you'll write it yourself or if you'd like us to help. We'll discuss your
|
||||
proposal with you and let you know whether we're likely to accept it.
|
||||
|
||||
## Bug Reports
|
||||
|
||||
You can file bug reports against the SDK on the [GitHub issues][issues] page.
|
||||
|
||||
If you are filing a report for a bug or regression in the SDK, it's extremely
|
||||
helpful to provide as much information as possible when opening the original
|
||||
issue. This helps us reproduce and investigate the possible bug without having
|
||||
to wait for this extra information to be provided. Please read the following
|
||||
guidelines prior to filing a bug report.
|
||||
|
||||
1. Search through existing [issues][] to ensure that your specific issue has
|
||||
not yet been reported. If it is a common issue, it is likely there is
|
||||
already a bug report for your problem.
|
||||
|
||||
2. Ensure that you have tested the latest version of the SDK. Although you
|
||||
may have an issue against an older version of the SDK, we cannot provide
|
||||
bug fixes for old versions. It's also possible that the bug may have been
|
||||
fixed in the latest release.
|
||||
|
||||
3. Provide as much information about your environment, SDK version, and
|
||||
relevant dependencies as possible. For example, let us know what version
|
||||
of Go you are using, which and version of the operating system, and the
|
||||
the environment your code is running in. e.g Container.
|
||||
|
||||
4. Provide a minimal test case that reproduces your issue or any error
|
||||
information you related to your problem. We can provide feedback much
|
||||
more quickly if we know what operations you are calling in the SDK. If
|
||||
you cannot provide a full test case, provide as much code as you can
|
||||
to help us diagnose the problem. Any relevant information should be provided
|
||||
as well, like whether this is a persistent issue, or if it only occurs
|
||||
some of the time.
|
||||
|
||||
## Feature Requests
|
||||
|
||||
Open an [issue][issues] with the following:
|
||||
|
||||
* A short, descriptive title. Ideally, other community members should be able
|
||||
to get a good idea of the feature just from reading the title.
|
||||
* A detailed description of the the proposed feature.
|
||||
* Why it should be added to the SDK.
|
||||
* If possible, example code to illustrate how it should work.
|
||||
* Use Markdown to make the request easier to read;
|
||||
* If you intend to implement this feature, indicate that you'd like to the issue to be assigned to you.
|
||||
|
||||
## Code Contributions
|
||||
|
||||
We are always happy to receive code and documentation contributions to the SDK.
|
||||
Please be aware of the following notes prior to opening a pull request:
|
||||
|
||||
1. The SDK is released under the [Apache license][license]. Any code you submit
|
||||
will be released under that license. For substantial contributions, we may
|
||||
ask you to sign a [Contributor License Agreement (CLA)][cla].
|
||||
|
||||
2. If you would like to implement support for a significant feature that is not
|
||||
yet available in the SDK, please talk to us beforehand to avoid any
|
||||
duplication of effort.
|
||||
|
||||
3. Wherever possible, pull requests should contain tests as appropriate.
|
||||
Bugfixes should contain tests that exercise the corrected behavior (i.e., the
|
||||
test should fail without the bugfix and pass with it), and new features
|
||||
should be accompanied by tests exercising the feature.
|
||||
|
||||
4. Pull requests that contain failing tests will not be merged until the test
|
||||
failures are addressed. Pull requests that cause a significant drop in the
|
||||
SDK's test coverage percentage are unlikely to be merged until tests have
|
||||
been added.
|
||||
|
||||
5. The JSON files under the SDK's `models` folder are sourced from outside the SDK.
|
||||
Such as `models/apis/ec2/2016-11-15/api.json`. We will not accept pull requests
|
||||
directly on these models. If you discover an issue with the models please
|
||||
create a [GitHub issue][issues] describing the issue.
|
||||
|
||||
### Testing
|
||||
|
||||
To run the tests locally, running the `make unit` command will `go get` the
|
||||
SDK's testing dependencies, and run vet, link and unit tests for the SDK.
|
||||
|
||||
```
|
||||
make unit
|
||||
```
|
||||
|
||||
Standard go testing functionality is supported as well. To test SDK code that
|
||||
is tagged with `codegen` you'll need to set the build tag in the go test
|
||||
command. The `make unit` command will do this automatically.
|
||||
|
||||
```
|
||||
go test -tags codegen ./private/...
|
||||
```
|
||||
|
||||
See the `Makefile` for additional testing tags that can be used in testing.
|
||||
|
||||
To test on multiple platform the SDK includes several DockerFiles under the
|
||||
`awstesting/sandbox` folder, and associated make recipes to to execute
|
||||
unit testing within environments configured for specific Go versions.
|
||||
|
||||
```
|
||||
make sandbox-test-go18
|
||||
```
|
||||
|
||||
To run all sandbox environments use the following make recipe
|
||||
|
||||
```
|
||||
# Optionally update the Go tip that will be used during the batch testing
|
||||
make update-aws-golang-tip
|
||||
|
||||
# Run all SDK tests for supported Go versions in sandboxes
|
||||
make sandbox-test
|
||||
```
|
||||
|
||||
In addition the sandbox environment include make recipes for interactive modes
|
||||
so you can run command within the Docker container and context of the SDK.
|
||||
|
||||
```
|
||||
make sandbox-go18
|
||||
```
|
||||
|
||||
### Changelog Documents
|
||||
|
||||
You can see all release changes in the `CHANGELOG.md` file at the root of the
|
||||
repository. The release notes added to this file will contain service client
|
||||
updates, and major SDK changes. When submitting a pull request please include an entry in `CHANGELOG_PENDING.md` under the appropriate changelog type so your changelog entry is included on the following release.
|
||||
|
||||
#### Changelog Types
|
||||
|
||||
* `SDK Features` - For major additive features, internal changes that have
|
||||
outward impact, or updates to the SDK foundations. This will result in a minor
|
||||
version change.
|
||||
* `SDK Enhancements` - For minor additive features or incremental sized changes.
|
||||
This will result in a patch version change.
|
||||
* `SDK Bugs` - For minor changes that resolve an issue. This will result in a
|
||||
patch version change.
|
||||
|
||||
[issues]: https://github.com/aws/aws-sdk-go/issues
|
||||
[pr]: https://github.com/aws/aws-sdk-go/pulls
|
||||
[license]: http://aws.amazon.com/apache2.0/
|
||||
[cla]: http://en.wikipedia.org/wiki/Contributor_License_Agreement
|
||||
[releasenotes]: https://github.com/aws/aws-sdk-go/releases
|
||||
|
@ -0,0 +1,15 @@
|
||||
Open Discussions
|
||||
---
|
||||
The following issues are currently open for community feedback.
|
||||
All discourse must adhere to the [Code of Conduct] policy.
|
||||
|
||||
* [Refactoring API Client Paginators](https://github.com/aws/aws-sdk-go-v2/issues/439)
|
||||
* [Refactoring API Client Waiters](https://github.com/aws/aws-sdk-go-v2/issues/442)
|
||||
* [Refactoring API Client Enums and Types to Discrete Packages](https://github.com/aws/aws-sdk-go-v2/issues/445)
|
||||
* [SDK Modularization](https://github.com/aws/aws-sdk-go-v2/issues/444)
|
||||
|
||||
Past Discussions
|
||||
---
|
||||
The issues listed here are for documentation purposes, and is used to capture issues and their associated discussions.
|
||||
|
||||
[Code of Conduct]: https://github.com/aws/aws-sdk-go-v2/blob/master/CODE_OF_CONDUCT.md
|
@ -0,0 +1,520 @@
|
||||
# Lint rules to ignore
|
||||
LINTIGNORESINGLEFIGHT='internal/sync/singleflight/singleflight.go:.+error should be the last type'
|
||||
LINT_IGNORE_S3MANAGER_INPUT='feature/s3/manager/upload.go:.+struct field SSEKMSKeyId should be SSEKMSKeyID'
|
||||
|
||||
UNIT_TEST_TAGS=
|
||||
BUILD_TAGS=-tags "example,codegen,integration,ec2env,perftest"
|
||||
|
||||
SMITHY_GO_SRC ?= $(shell pwd)/../smithy-go
|
||||
|
||||
SDK_MIN_GO_VERSION ?= 1.15
|
||||
|
||||
EACHMODULE_FAILFAST ?= true
|
||||
EACHMODULE_FAILFAST_FLAG=-fail-fast=${EACHMODULE_FAILFAST}
|
||||
|
||||
EACHMODULE_CONCURRENCY ?= 1
|
||||
EACHMODULE_CONCURRENCY_FLAG=-c ${EACHMODULE_CONCURRENCY}
|
||||
|
||||
EACHMODULE_SKIP ?=
|
||||
EACHMODULE_SKIP_FLAG=-skip="${EACHMODULE_SKIP}"
|
||||
|
||||
EACHMODULE_FLAGS=${EACHMODULE_CONCURRENCY_FLAG} ${EACHMODULE_FAILFAST_FLAG} ${EACHMODULE_SKIP_FLAG}
|
||||
|
||||
# SDK's Core and client packages that are compatible with Go 1.9+.
|
||||
SDK_CORE_PKGS=./aws/... ./internal/...
|
||||
SDK_CLIENT_PKGS=./service/...
|
||||
SDK_COMPA_PKGS=${SDK_CORE_PKGS} ${SDK_CLIENT_PKGS}
|
||||
|
||||
# SDK additional packages that are used for development of the SDK.
|
||||
SDK_EXAMPLES_PKGS=
|
||||
SDK_ALL_PKGS=${SDK_COMPA_PKGS} ${SDK_EXAMPLES_PKGS}
|
||||
|
||||
RUN_NONE=-run NONE
|
||||
RUN_INTEG=-run '^TestInteg_'
|
||||
|
||||
CODEGEN_RESOURCES_PATH=$(shell pwd)/codegen/smithy-aws-go-codegen/src/main/resources/software/amazon/smithy/aws/go/codegen
|
||||
CODEGEN_API_MODELS_PATH=$(shell pwd)/codegen/sdk-codegen/aws-models
|
||||
ENDPOINTS_JSON=${CODEGEN_RESOURCES_PATH}/endpoints.json
|
||||
ENDPOINT_PREFIX_JSON=${CODEGEN_RESOURCES_PATH}/endpoint-prefix.json
|
||||
|
||||
LICENSE_FILE=$(shell pwd)/LICENSE.txt
|
||||
|
||||
SMITHY_GO_VERSION ?=
|
||||
PRE_RELEASE_VERSION ?=
|
||||
RELEASE_MANIFEST_FILE ?=
|
||||
RELEASE_CHGLOG_DESC_FILE ?=
|
||||
|
||||
REPOTOOLS_VERSION ?= latest
|
||||
REPOTOOLS_MODULE = github.com/awslabs/aws-go-multi-module-repository-tools
|
||||
REPOTOOLS_CMD_ANNOTATE_STABLE_GEN = ${REPOTOOLS_MODULE}/cmd/annotatestablegen@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_MAKE_RELATIVE = ${REPOTOOLS_MODULE}/cmd/makerelative@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_CALCULATE_RELEASE = ${REPOTOOLS_MODULE}/cmd/calculaterelease@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_UPDATE_REQUIRES = ${REPOTOOLS_MODULE}/cmd/updaterequires@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_UPDATE_MODULE_METADATA = ${REPOTOOLS_MODULE}/cmd/updatemodulemeta@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_GENERATE_CHANGELOG = ${REPOTOOLS_MODULE}/cmd/generatechangelog@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_CHANGELOG = ${REPOTOOLS_MODULE}/cmd/changelog@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_TAG_RELEASE = ${REPOTOOLS_MODULE}/cmd/tagrelease@${REPOTOOLS_VERSION}
|
||||
REPOTOOLS_CMD_EDIT_MODULE_DEPENDENCY = ${REPOTOOLS_MODULE}/cmd/editmoduledependency@${REPOTOOLS_VERSION}
|
||||
|
||||
REPOTOOLS_CALCULATE_RELEASE_VERBOSE ?= false
|
||||
REPOTOOLS_CALCULATE_RELEASE_VERBOSE_FLAG=-v=${REPOTOOLS_CALCULATE_RELEASE_VERBOSE}
|
||||
|
||||
REPOTOOLS_CALCULATE_RELEASE_ADDITIONAL_ARGS ?=
|
||||
|
||||
ifneq ($(PRE_RELEASE_VERSION),)
|
||||
REPOTOOLS_CALCULATE_RELEASE_ADDITIONAL_ARGS += -preview=${PRE_RELEASE_VERSION}
|
||||
endif
|
||||
|
||||
.PHONY: all
|
||||
all: generate unit
|
||||
|
||||
###################
|
||||
# Code Generation #
|
||||
###################
|
||||
.PHONY: generate smithy-generate smithy-build smithy-build-% smithy-clean smithy-go-publish-local format \
|
||||
gen-config-asserts gen-repo-mod-replace gen-mod-replace-smithy gen-mod-dropreplace-smithy-% gen-aws-ptrs tidy-modules-% \
|
||||
add-module-license-files sync-models sync-endpoints-model sync-endpoints.json clone-v1-models gen-internal-codegen \
|
||||
sync-api-models copy-attributevalue-feature min-go-version-% update-requires smithy-annotate-stable \
|
||||
update-module-metadata download-modules-%
|
||||
|
||||
generate: smithy-generate update-requires gen-repo-mod-replace update-module-metadata smithy-annotate-stable \
|
||||
gen-config-asserts gen-internal-codegen copy-attributevalue-feature gen-mod-dropreplace-smithy-. min-go-version-. \
|
||||
tidy-modules-. add-module-license-files gen-aws-ptrs format
|
||||
|
||||
smithy-generate:
|
||||
cd codegen && ./gradlew clean build -Plog-tests && ./gradlew clean
|
||||
|
||||
smithy-build:
|
||||
cd codegen && ./gradlew clean build -Plog-tests
|
||||
|
||||
smithy-build-%:
|
||||
@# smithy-build- command that uses the pattern to define build filter that
|
||||
@# the smithy API model service id starts with. Strips off the
|
||||
@# "smithy-build-".
|
||||
@#
|
||||
@# e.g. smithy-build-com.amazonaws.rds
|
||||
@# e.g. smithy-build-com.amazonaws.rds#AmazonRDSv19
|
||||
cd codegen && \
|
||||
SMITHY_GO_BUILD_API="$(subst smithy-build-,,$@)" ./gradlew clean build -Plog-tests
|
||||
|
||||
smithy-annotate-stable:
|
||||
go run ${REPOTOOLS_CMD_ANNOTATE_STABLE_GEN}
|
||||
|
||||
smithy-clean:
|
||||
cd codegen && ./gradlew clean
|
||||
|
||||
smithy-go-publish-local:
|
||||
rm -rf /tmp/smithy-go-local
|
||||
git clone https://github.com/aws/smithy-go /tmp/smithy-go-local
|
||||
make -C /tmp/smithy-go-local smithy-clean smithy-publish-local
|
||||
|
||||
format:
|
||||
gofmt -w -s .
|
||||
|
||||
gen-config-asserts:
|
||||
@echo "Generating SDK config package implementor assertions"
|
||||
cd config \
|
||||
&& go mod tidy \
|
||||
&& go generate
|
||||
|
||||
gen-internal-codegen:
|
||||
@echo "Generating internal/codegen"
|
||||
cd internal/codegen \
|
||||
&& go mod tidy \
|
||||
&& go generate
|
||||
|
||||
gen-repo-mod-replace:
|
||||
@echo "Generating go.mod replace for repo modules"
|
||||
go run ${REPOTOOLS_CMD_MAKE_RELATIVE}
|
||||
|
||||
gen-mod-replace-smithy-%:
|
||||
@# gen-mod-replace-smithy- command that uses the pattern to define build filter that
|
||||
@# for modules to add replace to. Strips off the "gen-mod-replace-smithy-".
|
||||
@#
|
||||
@# SMITHY_GO_SRC environment variable is the path to add replace to
|
||||
@#
|
||||
@# e.g. gen-mod-replace-smithy-service_ssooidc
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst gen-mod-replace-smithy-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go mod edit -replace github.com/aws/smithy-go=${SMITHY_GO_SRC}"
|
||||
|
||||
gen-mod-dropreplace-smithy-%:
|
||||
@# gen-mod-dropreplace-smithy- command that uses the pattern to define build filter that
|
||||
@# for modules to add replace to. Strips off the "gen-mod-dropreplace-smithy-".
|
||||
@#
|
||||
@# e.g. gen-mod-dropreplace-smithy-service_ssooidc
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst gen-mod-dropreplace-smithy-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go mod edit -dropreplace github.com/aws/smithy-go"
|
||||
|
||||
gen-aws-ptrs:
|
||||
cd aws && go generate
|
||||
|
||||
tidy-modules-%:
|
||||
@# tidy command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "tidy-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. tidy-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst tidy-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go mod tidy"
|
||||
|
||||
download-modules-%:
|
||||
@# download command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "download-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. download-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst download-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go mod download all"
|
||||
|
||||
add-module-license-files:
|
||||
cd internal/repotools/cmd/eachmodule && \
|
||||
go run . -skip-root \
|
||||
"cp $(LICENSE_FILE) ."
|
||||
|
||||
sync-models: sync-endpoints-model sync-api-models
|
||||
|
||||
sync-endpoints-model: sync-endpoints.json
|
||||
|
||||
sync-endpoints.json:
|
||||
[[ ! -z "${ENDPOINTS_MODEL}" ]] && cp ${ENDPOINTS_MODEL} ${ENDPOINTS_JSON} || echo "ENDPOINTS_MODEL not set, must not be empty"
|
||||
|
||||
clone-v1-models:
|
||||
rm -rf /tmp/aws-sdk-go-model-sync
|
||||
git clone https://github.com/aws/aws-sdk-go.git --depth 1 /tmp/aws-sdk-go-model-sync
|
||||
|
||||
sync-api-models:
|
||||
cd internal/repotools/cmd/syncAPIModels && \
|
||||
go run . \
|
||||
-m ${API_MODELS} \
|
||||
-o ${CODEGEN_API_MODELS_PATH}
|
||||
|
||||
copy-attributevalue-feature:
|
||||
cd ./feature/dynamodbstreams/attributevalue && \
|
||||
find . -name "*.go" | grep -v "doc.go" | xargs -I % rm % && \
|
||||
find ../../dynamodb/attributevalue -name "*.go" | grep -v "doc.go" | xargs -I % cp % . && \
|
||||
ls *.go | grep -v "convert.go" | grep -v "doc.go" | \
|
||||
xargs -I % sed -i.bk -E 's:github.com/aws/aws-sdk-go-v2/(service|feature)/dynamodb:github.com/aws/aws-sdk-go-v2/\1/dynamodbstreams:g' % && \
|
||||
ls *.go | grep -v "convert.go" | grep -v "doc.go" | \
|
||||
xargs -I % sed -i.bk 's:DynamoDB:DynamoDBStreams:g' % && \
|
||||
ls *.go | grep -v "doc.go" | \
|
||||
xargs -I % sed -i.bk 's:dynamodb\.:dynamodbstreams.:g' % && \
|
||||
sed -i.bk 's:streams\.:ddbtypes.:g' "convert.go" && \
|
||||
sed -i.bk 's:ddb\.:streams.:g' "convert.go" && \
|
||||
sed -i.bk 's:ddbtypes\.:ddb.:g' "convert.go" &&\
|
||||
sed -i.bk 's:Streams::g' "convert.go" && \
|
||||
rm -rf ./*.bk && \
|
||||
go mod tidy && \
|
||||
gofmt -w -s . && \
|
||||
go test .
|
||||
|
||||
min-go-version-%:
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst min-go-version-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go mod edit -go=${SDK_MIN_GO_VERSION}"
|
||||
|
||||
update-requires:
|
||||
go run ${REPOTOOLS_CMD_UPDATE_REQUIRES}
|
||||
|
||||
update-module-metadata:
|
||||
go run ${REPOTOOLS_CMD_UPDATE_MODULE_METADATA}
|
||||
|
||||
################
|
||||
# Unit Testing #
|
||||
################
|
||||
.PHONY: unit unit-race unit-test unit-race-test unit-race-modules-% unit-modules-% build build-modules-% \
|
||||
go-build-modules-% test test-race-modules-% test-modules-% cachedep cachedep-modules-% api-diff-modules-%
|
||||
|
||||
unit: lint unit-modules-.
|
||||
unit-race: lint unit-race-modules-.
|
||||
|
||||
unit-test: test-modules-.
|
||||
unit-race-test: test-race-modules-.
|
||||
|
||||
unit-race-modules-%:
|
||||
@# unit command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "unit-race-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. unit-race-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst unit-race-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go vet ${BUILD_TAGS} --all ./..." \
|
||||
"go test ${BUILD_TAGS} ${RUN_NONE} ./..." \
|
||||
"go test -timeout=1m ${UNIT_TEST_TAGS} -race -cpu=4 ./..."
|
||||
|
||||
|
||||
unit-modules-%:
|
||||
@# unit command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "unit-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. unit-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst unit-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go vet ${BUILD_TAGS} --all ./..." \
|
||||
"go test ${BUILD_TAGS} ${RUN_NONE} ./..." \
|
||||
"go test -timeout=1m ${UNIT_TEST_TAGS} ./..."
|
||||
|
||||
build: build-modules-.
|
||||
|
||||
build-modules-%:
|
||||
@# build command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "build-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. build-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst build-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go test ${BUILD_TAGS} ${RUN_NONE} ./..."
|
||||
|
||||
go-build-modules-%:
|
||||
@# build command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "build-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# Validates that all modules in the repo have buildable Go files.
|
||||
@#
|
||||
@# e.g. go-build-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst go-build-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go build ${BUILD_TAGS} ./..."
|
||||
|
||||
test: test-modules-.
|
||||
|
||||
test-race-modules-%:
|
||||
@# Test command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "test-race-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. test-race-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst test-race-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go test -timeout=1m ${UNIT_TEST_TAGS} -race -cpu=4 ./..."
|
||||
|
||||
test-modules-%:
|
||||
@# Test command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "test-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. test-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst test-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go test -timeout=1m ${UNIT_TEST_TAGS} ./..."
|
||||
|
||||
cachedep: cachedep-modules-.
|
||||
|
||||
cachedep-modules-%:
|
||||
@# build command that uses the pattern to define the root path that the
|
||||
@# module caching will start from. Strips off the "cachedep-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. cachedep-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst cachedep-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go mod download"
|
||||
|
||||
api-diff-modules-%:
|
||||
@# Command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "api-diff-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# Requires golang.org/x/exp/cmd/gorelease to be available in the GOPATH.
|
||||
@#
|
||||
@# e.g. api-diff-modules-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst api-diff-modules-,,$@)) \
|
||||
-fail-fast=true \
|
||||
-c 1 \
|
||||
-skip="internal/repotools" \
|
||||
"$$(go env GOPATH)/bin/gorelease"
|
||||
|
||||
##############
|
||||
# CI Testing #
|
||||
##############
|
||||
.PHONY: ci-test ci-test-no-generate ci-test-generate-validate
|
||||
|
||||
ci-test: generate unit-race ci-test-generate-validate
|
||||
ci-test-no-generate: unit-race
|
||||
|
||||
ci-test-generate-validate:
|
||||
@echo "CI test validate no generated code changes"
|
||||
git update-index --assume-unchanged go.mod go.sum
|
||||
git add . -A
|
||||
gitstatus=`git diff --cached --ignore-space-change`; \
|
||||
echo "$$gitstatus"; \
|
||||
if [ "$$gitstatus" != "" ] && [ "$$gitstatus" != "skipping validation" ]; then echo "$$gitstatus"; exit 1; fi
|
||||
git update-index --no-assume-unchanged go.mod go.sum
|
||||
|
||||
ci-lint: ci-lint-.
|
||||
|
||||
ci-lint-%:
|
||||
@# Run golangci-lint command that uses the pattern to define the root path that the
|
||||
@# module check will start from. Strips off the "ci-lint-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. ci-lint-internal_protocoltest
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst ci-lint-,,$@)) \
|
||||
-fail-fast=false \
|
||||
-c 1 \
|
||||
-skip="internal/repotools" \
|
||||
"golangci-lint run"
|
||||
|
||||
ci-lint-install:
|
||||
@# Installs golangci-lint at GoPATH.
|
||||
@# This should be used to run golangci-lint locally.
|
||||
@#
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
#######################
|
||||
# Integration Testing #
|
||||
#######################
|
||||
.PHONY: integration integ-modules-% cleanup-integ-buckets
|
||||
|
||||
integration: integ-modules-service
|
||||
|
||||
integ-modules-%:
|
||||
@# integration command that uses the pattern to define the root path that
|
||||
@# the module testing will start from. Strips off the "integ-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. test-modules-service_dynamodb
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst integ-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go test -timeout=10m -tags "integration" -v ${RUN_INTEG} -count 1 ./..."
|
||||
|
||||
cleanup-integ-buckets:
|
||||
@echo "Cleaning up SDK integration resources"
|
||||
go run -tags "integration" ./internal/awstesting/cmd/bucket_cleanup/main.go "aws-sdk-go-integration"
|
||||
|
||||
##############
|
||||
# Benchmarks #
|
||||
##############
|
||||
.PHONY: bench bench-modules-%
|
||||
|
||||
bench: bench-modules-.
|
||||
|
||||
bench-modules-%:
|
||||
@# benchmark command that uses the pattern to define the root path that
|
||||
@# the module testing will start from. Strips off the "bench-modules-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# e.g. bench-modules-service_dynamodb
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst bench-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go test -timeout=10m -bench . --benchmem ${BUILD_TAGS} ${RUN_NONE} ./..."
|
||||
|
||||
|
||||
#####################
|
||||
# Release Process #
|
||||
#####################
|
||||
.PHONY: preview-release pre-release-validation release
|
||||
|
||||
ls-changes:
|
||||
go run ${REPOTOOLS_CMD_CHANGELOG} ls
|
||||
|
||||
preview-release:
|
||||
go run ${REPOTOOLS_CMD_CALCULATE_RELEASE} ${REPOTOOLS_CALCULATE_RELEASE_VERBOSE_FLAG} ${REPOTOOLS_CALCULATE_RELEASE_ADDITIONAL_ARGS}
|
||||
|
||||
pre-release-validation:
|
||||
@if [[ -z "${RELEASE_MANIFEST_FILE}" ]]; then \
|
||||
echo "RELEASE_MANIFEST_FILE is required to specify the file to write the release manifest" && false; \
|
||||
fi
|
||||
@if [[ -z "${RELEASE_CHGLOG_DESC_FILE}" ]]; then \
|
||||
echo "RELEASE_CHGLOG_DESC_FILE is required to specify the file to write the release notes" && false; \
|
||||
fi
|
||||
|
||||
release: pre-release-validation
|
||||
go run ${REPOTOOLS_CMD_CALCULATE_RELEASE} -o ${RELEASE_MANIFEST_FILE} ${REPOTOOLS_CALCULATE_RELEASE_VERBOSE_FLAG} ${REPOTOOLS_CALCULATE_RELEASE_ADDITIONAL_ARGS}
|
||||
go run ${REPOTOOLS_CMD_UPDATE_REQUIRES} -release ${RELEASE_MANIFEST_FILE}
|
||||
go run ${REPOTOOLS_CMD_UPDATE_MODULE_METADATA} -release ${RELEASE_MANIFEST_FILE}
|
||||
go run ${REPOTOOLS_CMD_GENERATE_CHANGELOG} -release ${RELEASE_MANIFEST_FILE} -o ${RELEASE_CHGLOG_DESC_FILE}
|
||||
go run ${REPOTOOLS_CMD_CHANGELOG} rm -all
|
||||
go run ${REPOTOOLS_CMD_TAG_RELEASE} -release ${RELEASE_MANIFEST_FILE}
|
||||
|
||||
##############
|
||||
# Repo Tools #
|
||||
##############
|
||||
.PHONY: install-repotools
|
||||
|
||||
install-repotools:
|
||||
go install ${REPOTOOLS_MODULE}/cmd/changelog@${REPOTOOLS_VERSION}
|
||||
|
||||
set-smithy-go-version:
|
||||
@if [[ -z "${SMITHY_GO_VERSION}" ]]; then \
|
||||
echo "SMITHY_GO_VERSION is required to update SDK's smithy-go module dependency version" && false; \
|
||||
fi
|
||||
go run ${REPOTOOLS_CMD_EDIT_MODULE_DEPENDENCY} -s "github.com/aws/smithy-go" -v "${SMITHY_GO_VERSION}"
|
||||
|
||||
##################
|
||||
# Linting/Verify #
|
||||
##################
|
||||
.PHONY: verify lint vet vet-modules-% sdkv1check
|
||||
|
||||
verify: lint vet sdkv1check
|
||||
|
||||
lint:
|
||||
@echo "go lint SDK and vendor packages"
|
||||
@lint=`golint ./...`; \
|
||||
dolint=`echo "$$lint" | grep -E -v \
|
||||
-e ${LINT_IGNORE_S3MANAGER_INPUT} \
|
||||
-e ${LINTIGNORESINGLEFIGHT}`; \
|
||||
echo "$$dolint"; \
|
||||
if [ "$$dolint" != "" ]; then exit 1; fi
|
||||
|
||||
vet: vet-modules-.
|
||||
|
||||
vet-modules-%:
|
||||
cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst vet-modules-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go vet ${BUILD_TAGS} --all ./..."
|
||||
|
||||
sdkv1check:
|
||||
@echo "Checking for usage of AWS SDK for Go v1"
|
||||
@sdkv1usage=`go list -test -f '''{{ if not .Standard }}{{ range $$_, $$name := .Imports }} * {{ $$.ImportPath }} -> {{ $$name }}{{ print "\n" }}{{ end }}{{ range $$_, $$name := .TestImports }} *: {{ $$.ImportPath }} -> {{ $$name }}{{ print "\n" }}{{ end }}{{ end}}''' ./... | sort -u | grep '''/aws-sdk-go/'''`; \
|
||||
echo "$$sdkv1usage"; \
|
||||
if [ "$$sdkv1usage" != "" ]; then exit 1; fi
|
||||
|
||||
list-deps: list-deps-.
|
||||
|
||||
list-deps-%:
|
||||
@# command that uses the pattern to define the root path that the
|
||||
@# module testing will start from. Strips off the "list-deps-" and
|
||||
@# replaces all "_" with "/".
|
||||
@#
|
||||
@# Trim output to only include stdout for list of dependencies only.
|
||||
@# make list-deps 2>&-
|
||||
@#
|
||||
@# e.g. list-deps-internal_protocoltest
|
||||
@cd ./internal/repotools/cmd/eachmodule \
|
||||
&& go run . -p $(subst _,/,$(subst list-deps-,,$@)) ${EACHMODULE_FLAGS} \
|
||||
"go list -m all | grep -v 'github.com/aws/aws-sdk-go-v2'" | sort -u
|
||||
|
||||
###################
|
||||
# Sandbox Testing #
|
||||
###################
|
||||
.PHONY: sandbox-tests sandbox-build-% sandbox-run-% sandbox-test-% update-aws-golang-tip
|
||||
|
||||
sandbox-tests: sandbox-test-go1.15 sandbox-test-go1.16 sandbox-test-go1.17 sandbox-test-gotip
|
||||
|
||||
sandbox-build-%:
|
||||
@# sandbox-build-go1.17
|
||||
@# sandbox-build-gotip
|
||||
docker build \
|
||||
-f ./internal/awstesting/sandbox/Dockerfile.test.$(subst sandbox-build-,,$@) \
|
||||
-t "aws-sdk-go-$(subst sandbox-build-,,$@)" .
|
||||
sandbox-run-%: sandbox-build-%
|
||||
@# sandbox-run-go1.17
|
||||
@# sandbox-run-gotip
|
||||
docker run -i -t "aws-sdk-go-$(subst sandbox-run-,,$@)" bash
|
||||
sandbox-test-%: sandbox-build-%
|
||||
@# sandbox-test-go1.17
|
||||
@# sandbox-test-gotip
|
||||
docker run -t "aws-sdk-go-$(subst sandbox-test-,,$@)"
|
||||
|
||||
update-aws-golang-tip:
|
||||
docker build --no-cache=true -f ./internal/awstesting/sandbox/Dockerfile.golang-tip -t "aws-golang:tip" .
|
@ -0,0 +1,3 @@
|
||||
AWS SDK for Go
|
||||
Copyright 2015 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Copyright 2014-2015 Stripe, Inc.
|
@ -0,0 +1,92 @@
|
||||
// Package arn provides a parser for interacting with Amazon Resource Names.
|
||||
package arn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
arnDelimiter = ":"
|
||||
arnSections = 6
|
||||
arnPrefix = "arn:"
|
||||
|
||||
// zero-indexed
|
||||
sectionPartition = 1
|
||||
sectionService = 2
|
||||
sectionRegion = 3
|
||||
sectionAccountID = 4
|
||||
sectionResource = 5
|
||||
|
||||
// errors
|
||||
invalidPrefix = "arn: invalid prefix"
|
||||
invalidSections = "arn: not enough sections"
|
||||
)
|
||||
|
||||
// ARN captures the individual fields of an Amazon Resource Name.
|
||||
// See http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html for more information.
|
||||
type ARN struct {
|
||||
// The partition that the resource is in. For standard AWS regions, the partition is "aws". If you have resources in
|
||||
// other partitions, the partition is "aws-partitionname". For example, the partition for resources in the China
|
||||
// (Beijing) region is "aws-cn".
|
||||
Partition string
|
||||
|
||||
// The service namespace that identifies the AWS product (for example, Amazon S3, IAM, or Amazon RDS). For a list of
|
||||
// namespaces, see
|
||||
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#genref-aws-service-namespaces.
|
||||
Service string
|
||||
|
||||
// The region the resource resides in. Note that the ARNs for some resources do not require a region, so this
|
||||
// component might be omitted.
|
||||
Region string
|
||||
|
||||
// The ID of the AWS account that owns the resource, without the hyphens. For example, 123456789012. Note that the
|
||||
// ARNs for some resources don't require an account number, so this component might be omitted.
|
||||
AccountID string
|
||||
|
||||
// The content of this part of the ARN varies by service. It often includes an indicator of the type of resource —
|
||||
// for example, an IAM user or Amazon RDS database - followed by a slash (/) or a colon (:), followed by the
|
||||
// resource name itself. Some services allows paths for resource names, as described in
|
||||
// http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html#arns-paths.
|
||||
Resource string
|
||||
}
|
||||
|
||||
// Parse parses an ARN into its constituent parts.
|
||||
//
|
||||
// Some example ARNs:
|
||||
// arn:aws:elasticbeanstalk:us-east-1:123456789012:environment/My App/MyEnvironment
|
||||
// arn:aws:iam::123456789012:user/David
|
||||
// arn:aws:rds:eu-west-1:123456789012:db:mysql-db
|
||||
// arn:aws:s3:::my_corporate_bucket/exampleobject.png
|
||||
func Parse(arn string) (ARN, error) {
|
||||
if !strings.HasPrefix(arn, arnPrefix) {
|
||||
return ARN{}, errors.New(invalidPrefix)
|
||||
}
|
||||
sections := strings.SplitN(arn, arnDelimiter, arnSections)
|
||||
if len(sections) != arnSections {
|
||||
return ARN{}, errors.New(invalidSections)
|
||||
}
|
||||
return ARN{
|
||||
Partition: sections[sectionPartition],
|
||||
Service: sections[sectionService],
|
||||
Region: sections[sectionRegion],
|
||||
AccountID: sections[sectionAccountID],
|
||||
Resource: sections[sectionResource],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsARN returns whether the given string is an arn
|
||||
// by looking for whether the string starts with arn:
|
||||
func IsARN(arn string) bool {
|
||||
return strings.HasPrefix(arn, arnPrefix) && strings.Count(arn, ":") >= arnSections-1
|
||||
}
|
||||
|
||||
// String returns the canonical representation of the ARN
|
||||
func (arn ARN) String() string {
|
||||
return arnPrefix +
|
||||
arn.Partition + arnDelimiter +
|
||||
arn.Service + arnDelimiter +
|
||||
arn.Region + arnDelimiter +
|
||||
arn.AccountID + arnDelimiter +
|
||||
arn.Resource
|
||||
}
|
@ -0,0 +1,179 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
smithybearer "github.com/aws/smithy-go/auth/bearer"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
)
|
||||
|
||||
// HTTPClient provides the interface to provide custom HTTPClients. Generally
|
||||
// *http.Client is sufficient for most use cases. The HTTPClient should not
|
||||
// follow 301 or 302 redirects.
|
||||
type HTTPClient interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// A Config provides service configuration for service clients.
|
||||
type Config struct {
|
||||
// The region to send requests to. This parameter is required and must
|
||||
// be configured globally or on a per-client basis unless otherwise
|
||||
// noted. A full list of regions is found in the "Regions and Endpoints"
|
||||
// document.
|
||||
//
|
||||
// See http://docs.aws.amazon.com/general/latest/gr/rande.html for
|
||||
// information on AWS regions.
|
||||
Region string
|
||||
|
||||
// The credentials object to use when signing requests.
|
||||
// Use the LoadDefaultConfig to load configuration from all the SDK's supported
|
||||
// sources, and resolve credentials using the SDK's default credential chain.
|
||||
Credentials CredentialsProvider
|
||||
|
||||
// The Bearer Authentication token provider to use for authenticating API
|
||||
// operation calls with a Bearer Authentication token. The API clients and
|
||||
// operation must support Bearer Authentication scheme in order for the
|
||||
// token provider to be used. API clients created with NewFromConfig will
|
||||
// automatically be configured with this option, if the API client support
|
||||
// Bearer Authentication.
|
||||
//
|
||||
// The SDK's config.LoadDefaultConfig can automatically populate this
|
||||
// option for external configuration options such as SSO session.
|
||||
// https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-sso.html
|
||||
BearerAuthTokenProvider smithybearer.TokenProvider
|
||||
|
||||
// The HTTP Client the SDK's API clients will use to invoke HTTP requests.
|
||||
// The SDK defaults to a BuildableClient allowing API clients to create
|
||||
// copies of the HTTP Client for service specific customizations.
|
||||
//
|
||||
// Use a (*http.Client) for custom behavior. Using a custom http.Client
|
||||
// will prevent the SDK from modifying the HTTP client.
|
||||
HTTPClient HTTPClient
|
||||
|
||||
// An endpoint resolver that can be used to provide or override an endpoint
|
||||
// for the given service and region.
|
||||
//
|
||||
// See the `aws.EndpointResolver` documentation for additional usage
|
||||
// information.
|
||||
//
|
||||
// Deprecated: See Config.EndpointResolverWithOptions
|
||||
EndpointResolver EndpointResolver
|
||||
|
||||
// An endpoint resolver that can be used to provide or override an endpoint
|
||||
// for the given service and region.
|
||||
//
|
||||
// When EndpointResolverWithOptions is specified, it will be used by a
|
||||
// service client rather than using EndpointResolver if also specified.
|
||||
//
|
||||
// See the `aws.EndpointResolverWithOptions` documentation for additional
|
||||
// usage information.
|
||||
EndpointResolverWithOptions EndpointResolverWithOptions
|
||||
|
||||
// RetryMaxAttempts specifies the maximum number attempts an API client
|
||||
// will call an operation that fails with a retryable error.
|
||||
//
|
||||
// API Clients will only use this value to construct a retryer if the
|
||||
// Config.Retryer member is not nil. This value will be ignored if
|
||||
// Retryer is not nil.
|
||||
RetryMaxAttempts int
|
||||
|
||||
// RetryMode specifies the retry model the API client will be created with.
|
||||
//
|
||||
// API Clients will only use this value to construct a retryer if the
|
||||
// Config.Retryer member is not nil. This value will be ignored if
|
||||
// Retryer is not nil.
|
||||
RetryMode RetryMode
|
||||
|
||||
// Retryer is a function that provides a Retryer implementation. A Retryer
|
||||
// guides how HTTP requests should be retried in case of recoverable
|
||||
// failures. When nil the API client will use a default retryer.
|
||||
//
|
||||
// In general, the provider function should return a new instance of a
|
||||
// Retryer if you are attempting to provide a consistent Retryer
|
||||
// configuration across all clients. This will ensure that each client will
|
||||
// be provided a new instance of the Retryer implementation, and will avoid
|
||||
// issues such as sharing the same retry token bucket across services.
|
||||
//
|
||||
// If not nil, RetryMaxAttempts, and RetryMode will be ignored by API
|
||||
// clients.
|
||||
Retryer func() Retryer
|
||||
|
||||
// ConfigSources are the sources that were used to construct the Config.
|
||||
// Allows for additional configuration to be loaded by clients.
|
||||
ConfigSources []interface{}
|
||||
|
||||
// APIOptions provides the set of middleware mutations modify how the API
|
||||
// client requests will be handled. This is useful for adding additional
|
||||
// tracing data to a request, or changing behavior of the SDK's client.
|
||||
APIOptions []func(*middleware.Stack) error
|
||||
|
||||
// The logger writer interface to write logging messages to. Defaults to
|
||||
// standard error.
|
||||
Logger logging.Logger
|
||||
|
||||
// Configures the events that will be sent to the configured logger. This
|
||||
// can be used to configure the logging of signing, retries, request, and
|
||||
// responses of the SDK clients.
|
||||
//
|
||||
// See the ClientLogMode type documentation for the complete set of logging
|
||||
// modes and available configuration.
|
||||
ClientLogMode ClientLogMode
|
||||
|
||||
// The configured DefaultsMode. If not specified, service clients will
|
||||
// default to legacy.
|
||||
//
|
||||
// Supported modes are: auto, cross-region, in-region, legacy, mobile,
|
||||
// standard
|
||||
DefaultsMode DefaultsMode
|
||||
|
||||
// The RuntimeEnvironment configuration, only populated if the DefaultsMode
|
||||
// is set to DefaultsModeAuto and is initialized by
|
||||
// `config.LoadDefaultConfig`. You should not populate this structure
|
||||
// programmatically, or rely on the values here within your applications.
|
||||
RuntimeEnvironment RuntimeEnvironment
|
||||
}
|
||||
|
||||
// NewConfig returns a new Config pointer that can be chained with builder
|
||||
// methods to set multiple configuration values inline without using pointers.
|
||||
func NewConfig() *Config {
|
||||
return &Config{}
|
||||
}
|
||||
|
||||
// Copy will return a shallow copy of the Config object. If any additional
|
||||
// configurations are provided they will be merged into the new config returned.
|
||||
func (c Config) Copy() Config {
|
||||
cp := c
|
||||
return cp
|
||||
}
|
||||
|
||||
// EndpointDiscoveryEnableState indicates if endpoint discovery is
|
||||
// enabled, disabled, auto or unset state.
|
||||
//
|
||||
// Default behavior (Auto or Unset) indicates operations that require endpoint
|
||||
// discovery will use Endpoint Discovery by default. Operations that
|
||||
// optionally use Endpoint Discovery will not use Endpoint Discovery
|
||||
// unless EndpointDiscovery is explicitly enabled.
|
||||
type EndpointDiscoveryEnableState uint
|
||||
|
||||
// Enumeration values for EndpointDiscoveryEnableState
|
||||
const (
|
||||
// EndpointDiscoveryUnset represents EndpointDiscoveryEnableState is unset.
|
||||
// Users do not need to use this value explicitly. The behavior for unset
|
||||
// is the same as for EndpointDiscoveryAuto.
|
||||
EndpointDiscoveryUnset EndpointDiscoveryEnableState = iota
|
||||
|
||||
// EndpointDiscoveryAuto represents an AUTO state that allows endpoint
|
||||
// discovery only when required by the api. This is the default
|
||||
// configuration resolved by the client if endpoint discovery is neither
|
||||
// enabled or disabled.
|
||||
EndpointDiscoveryAuto // default state
|
||||
|
||||
// EndpointDiscoveryDisabled indicates client MUST not perform endpoint
|
||||
// discovery even when required.
|
||||
EndpointDiscoveryDisabled
|
||||
|
||||
// EndpointDiscoveryEnabled indicates client MUST always perform endpoint
|
||||
// discovery if supported for the operation.
|
||||
EndpointDiscoveryEnabled
|
||||
)
|
@ -0,0 +1,22 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type suppressedContext struct {
|
||||
context.Context
|
||||
}
|
||||
|
||||
func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func (s *suppressedContext) Done() <-chan struct{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *suppressedContext) Err() error {
|
||||
return nil
|
||||
}
|
@ -0,0 +1,224 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
|
||||
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
|
||||
)
|
||||
|
||||
// CredentialsCacheOptions are the options
|
||||
type CredentialsCacheOptions struct {
|
||||
|
||||
// ExpiryWindow will allow the credentials to trigger refreshing prior to
|
||||
// the credentials actually expiring. This is beneficial so race conditions
|
||||
// with expiring credentials do not cause request to fail unexpectedly
|
||||
// due to ExpiredTokenException exceptions.
|
||||
//
|
||||
// An ExpiryWindow of 10s would cause calls to IsExpired() to return true
|
||||
// 10 seconds before the credentials are actually expired. This can cause an
|
||||
// increased number of requests to refresh the credentials to occur.
|
||||
//
|
||||
// If ExpiryWindow is 0 or less it will be ignored.
|
||||
ExpiryWindow time.Duration
|
||||
|
||||
// ExpiryWindowJitterFrac provides a mechanism for randomizing the
|
||||
// expiration of credentials within the configured ExpiryWindow by a random
|
||||
// percentage. Valid values are between 0.0 and 1.0.
|
||||
//
|
||||
// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac
|
||||
// is 0.5 then credentials will be set to expire between 30 to 60 seconds
|
||||
// prior to their actual expiration time.
|
||||
//
|
||||
// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
|
||||
// If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
|
||||
// If ExpiryWindowJitterFrac < 0 the value will be treated as 0.
|
||||
// If ExpiryWindowJitterFrac > 1 the value will be treated as 1.
|
||||
ExpiryWindowJitterFrac float64
|
||||
}
|
||||
|
||||
// CredentialsCache provides caching and concurrency safe credentials retrieval
|
||||
// via the provider's retrieve method.
|
||||
//
|
||||
// CredentialsCache will look for optional interfaces on the Provider to adjust
|
||||
// how the credential cache handles credentials caching.
|
||||
//
|
||||
// - HandleFailRefreshCredentialsCacheStrategy - Allows provider to handle
|
||||
// credential refresh failures. This could return an updated Credentials
|
||||
// value, or attempt another means of retrieving credentials.
|
||||
//
|
||||
// - AdjustExpiresByCredentialsCacheStrategy - Allows provider to adjust how
|
||||
// credentials Expires is modified. This could modify how the Credentials
|
||||
// Expires is adjusted based on the CredentialsCache ExpiryWindow option.
|
||||
// Such as providing a floor not to reduce the Expires below.
|
||||
type CredentialsCache struct {
|
||||
provider CredentialsProvider
|
||||
|
||||
options CredentialsCacheOptions
|
||||
creds atomic.Value
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider
|
||||
// is expected to not be nil. A variadic list of one or more functions can be
|
||||
// provided to modify the CredentialsCache configuration. This allows for
|
||||
// configuration of credential expiry window and jitter.
|
||||
func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
|
||||
options := CredentialsCacheOptions{}
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
|
||||
if options.ExpiryWindow < 0 {
|
||||
options.ExpiryWindow = 0
|
||||
}
|
||||
|
||||
if options.ExpiryWindowJitterFrac < 0 {
|
||||
options.ExpiryWindowJitterFrac = 0
|
||||
} else if options.ExpiryWindowJitterFrac > 1 {
|
||||
options.ExpiryWindowJitterFrac = 1
|
||||
}
|
||||
|
||||
return &CredentialsCache{
|
||||
provider: provider,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve returns the credentials. If the credentials have already been
|
||||
// retrieved, and not expired the cached credentials will be returned. If the
|
||||
// credentials have not been retrieved yet, or expired the provider's Retrieve
|
||||
// method will be called.
|
||||
//
|
||||
// Returns and error if the provider's retrieve method returns an error.
|
||||
func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
|
||||
if creds, ok := p.getCreds(); ok && !creds.Expired() {
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
resCh := p.sf.DoChan("", func() (interface{}, error) {
|
||||
return p.singleRetrieve(&suppressedContext{ctx})
|
||||
})
|
||||
select {
|
||||
case res := <-resCh:
|
||||
return res.Val.(Credentials), res.Err
|
||||
case <-ctx.Done():
|
||||
return Credentials{}, &RequestCanceledError{Err: ctx.Err()}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
|
||||
currCreds, ok := p.getCreds()
|
||||
if ok && !currCreds.Expired() {
|
||||
return currCreds, nil
|
||||
}
|
||||
|
||||
newCreds, err := p.provider.Retrieve(ctx)
|
||||
if err != nil {
|
||||
handleFailToRefresh := defaultHandleFailToRefresh
|
||||
if cs, ok := p.provider.(HandleFailRefreshCredentialsCacheStrategy); ok {
|
||||
handleFailToRefresh = cs.HandleFailToRefresh
|
||||
}
|
||||
newCreds, err = handleFailToRefresh(ctx, currCreds, err)
|
||||
if err != nil {
|
||||
return Credentials{}, fmt.Errorf("failed to refresh cached credentials, %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if newCreds.CanExpire && p.options.ExpiryWindow > 0 {
|
||||
adjustExpiresBy := defaultAdjustExpiresBy
|
||||
if cs, ok := p.provider.(AdjustExpiresByCredentialsCacheStrategy); ok {
|
||||
adjustExpiresBy = cs.AdjustExpiresBy
|
||||
}
|
||||
|
||||
randFloat64, err := sdkrand.CryptoRandFloat64()
|
||||
if err != nil {
|
||||
return Credentials{}, fmt.Errorf("failed to get random provider, %w", err)
|
||||
}
|
||||
|
||||
var jitter time.Duration
|
||||
if p.options.ExpiryWindowJitterFrac > 0 {
|
||||
jitter = time.Duration(randFloat64 *
|
||||
p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
|
||||
}
|
||||
|
||||
newCreds, err = adjustExpiresBy(newCreds, -(p.options.ExpiryWindow - jitter))
|
||||
if err != nil {
|
||||
return Credentials{}, fmt.Errorf("failed to adjust credentials expires, %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
p.creds.Store(&newCreds)
|
||||
return newCreds, nil
|
||||
}
|
||||
|
||||
// getCreds returns the currently stored credentials and true. Returning false
|
||||
// if no credentials were stored.
|
||||
func (p *CredentialsCache) getCreds() (Credentials, bool) {
|
||||
v := p.creds.Load()
|
||||
if v == nil {
|
||||
return Credentials{}, false
|
||||
}
|
||||
|
||||
c := v.(*Credentials)
|
||||
if c == nil || !c.HasKeys() {
|
||||
return Credentials{}, false
|
||||
}
|
||||
|
||||
return *c, true
|
||||
}
|
||||
|
||||
// Invalidate will invalidate the cached credentials. The next call to Retrieve
|
||||
// will cause the provider's Retrieve method to be called.
|
||||
func (p *CredentialsCache) Invalidate() {
|
||||
p.creds.Store((*Credentials)(nil))
|
||||
}
|
||||
|
||||
// IsCredentialsProvider returns whether credential provider wrapped by CredentialsCache
|
||||
// matches the target provider type.
|
||||
func (p *CredentialsCache) IsCredentialsProvider(target CredentialsProvider) bool {
|
||||
return IsCredentialsProvider(p.provider, target)
|
||||
}
|
||||
|
||||
// HandleFailRefreshCredentialsCacheStrategy is an interface for
|
||||
// CredentialsCache to allow CredentialsProvider how failed to refresh
|
||||
// credentials is handled.
|
||||
type HandleFailRefreshCredentialsCacheStrategy interface {
|
||||
// Given the previously cached Credentials, if any, and refresh error, may
|
||||
// returns new or modified set of Credentials, or error.
|
||||
//
|
||||
// Credential caches may use default implementation if nil.
|
||||
HandleFailToRefresh(context.Context, Credentials, error) (Credentials, error)
|
||||
}
|
||||
|
||||
// defaultHandleFailToRefresh returns the passed in error.
|
||||
func defaultHandleFailToRefresh(ctx context.Context, _ Credentials, err error) (Credentials, error) {
|
||||
return Credentials{}, err
|
||||
}
|
||||
|
||||
// AdjustExpiresByCredentialsCacheStrategy is an interface for CredentialCache
|
||||
// to allow CredentialsProvider to intercept adjustments to Credentials expiry
|
||||
// based on expectations and use cases of CredentialsProvider.
|
||||
//
|
||||
// Credential caches may use default implementation if nil.
|
||||
type AdjustExpiresByCredentialsCacheStrategy interface {
|
||||
// Given a Credentials as input, applying any mutations and
|
||||
// returning the potentially updated Credentials, or error.
|
||||
AdjustExpiresBy(Credentials, time.Duration) (Credentials, error)
|
||||
}
|
||||
|
||||
// defaultAdjustExpiresBy adds the duration to the passed in credentials Expires,
|
||||
// and returns the updated credentials value. If Credentials value's CanExpire
|
||||
// is false, the passed in credentials are returned unchanged.
|
||||
func defaultAdjustExpiresBy(creds Credentials, dur time.Duration) (Credentials, error) {
|
||||
if !creds.CanExpire {
|
||||
return creds, nil
|
||||
}
|
||||
|
||||
creds.Expires = creds.Expires.Add(dur)
|
||||
return creds, nil
|
||||
}
|
@ -0,0 +1,170 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/internal/sdk"
|
||||
)
|
||||
|
||||
// AnonymousCredentials provides a sentinel CredentialsProvider that should be
|
||||
// used to instruct the SDK's signing middleware to not sign the request.
|
||||
//
|
||||
// Using `nil` credentials when configuring an API client will achieve the same
|
||||
// result. The AnonymousCredentials type allows you to configure the SDK's
|
||||
// external config loading to not attempt to source credentials from the shared
|
||||
// config or environment.
|
||||
//
|
||||
// For example you can use this CredentialsProvider with an API client's
|
||||
// Options to instruct the client not to sign a request for accessing public
|
||||
// S3 bucket objects.
|
||||
//
|
||||
// The following example demonstrates using the AnonymousCredentials to prevent
|
||||
// SDK's external config loading attempt to resolve credentials.
|
||||
//
|
||||
// cfg, err := config.LoadDefaultConfig(context.TODO(),
|
||||
// config.WithCredentialsProvider(aws.AnonymousCredentials{}),
|
||||
// )
|
||||
// if err != nil {
|
||||
// log.Fatalf("failed to load config, %v", err)
|
||||
// }
|
||||
//
|
||||
// client := s3.NewFromConfig(cfg)
|
||||
//
|
||||
// Alternatively you can leave the API client Option's `Credential` member to
|
||||
// nil. If using the `NewFromConfig` constructor you'll need to explicitly set
|
||||
// the `Credentials` member to nil, if the external config resolved a
|
||||
// credential provider.
|
||||
//
|
||||
// client := s3.New(s3.Options{
|
||||
// // Credentials defaults to a nil value.
|
||||
// })
|
||||
//
|
||||
// This can also be configured for specific operations calls too.
|
||||
//
|
||||
// cfg, err := config.LoadDefaultConfig(context.TODO())
|
||||
// if err != nil {
|
||||
// log.Fatalf("failed to load config, %v", err)
|
||||
// }
|
||||
//
|
||||
// client := s3.NewFromConfig(config)
|
||||
//
|
||||
// result, err := client.GetObject(context.TODO(), s3.GetObject{
|
||||
// Bucket: aws.String("example-bucket"),
|
||||
// Key: aws.String("example-key"),
|
||||
// }, func(o *s3.Options) {
|
||||
// o.Credentials = nil
|
||||
// // Or
|
||||
// o.Credentials = aws.AnonymousCredentials{}
|
||||
// })
|
||||
type AnonymousCredentials struct{}
|
||||
|
||||
// Retrieve implements the CredentialsProvider interface, but will always
|
||||
// return error, and cannot be used to sign a request. The AnonymousCredentials
|
||||
// type is used as a sentinel type instructing the AWS request signing
|
||||
// middleware to not sign a request.
|
||||
func (AnonymousCredentials) Retrieve(context.Context) (Credentials, error) {
|
||||
return Credentials{Source: "AnonymousCredentials"},
|
||||
fmt.Errorf("the AnonymousCredentials is not a valid credential provider, and cannot be used to sign AWS requests with")
|
||||
}
|
||||
|
||||
// A Credentials is the AWS credentials value for individual credential fields.
|
||||
type Credentials struct {
|
||||
// AWS Access key ID
|
||||
AccessKeyID string
|
||||
|
||||
// AWS Secret Access Key
|
||||
SecretAccessKey string
|
||||
|
||||
// AWS Session Token
|
||||
SessionToken string
|
||||
|
||||
// Source of the credentials
|
||||
Source string
|
||||
|
||||
// States if the credentials can expire or not.
|
||||
CanExpire bool
|
||||
|
||||
// The time the credentials will expire at. Should be ignored if CanExpire
|
||||
// is false.
|
||||
Expires time.Time
|
||||
}
|
||||
|
||||
// Expired returns if the credentials have expired.
|
||||
func (v Credentials) Expired() bool {
|
||||
if v.CanExpire {
|
||||
// Calling Round(0) on the current time will truncate the monotonic
|
||||
// reading only. Ensures credential expiry time is always based on
|
||||
// reported wall-clock time.
|
||||
return !v.Expires.After(sdk.NowTime().Round(0))
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// HasKeys returns if the credentials keys are set.
|
||||
func (v Credentials) HasKeys() bool {
|
||||
return len(v.AccessKeyID) > 0 && len(v.SecretAccessKey) > 0
|
||||
}
|
||||
|
||||
// A CredentialsProvider is the interface for any component which will provide
|
||||
// credentials Credentials. A CredentialsProvider is required to manage its own
|
||||
// Expired state, and what to be expired means.
|
||||
//
|
||||
// A credentials provider implementation can be wrapped with a CredentialCache
|
||||
// to cache the credential value retrieved. Without the cache the SDK will
|
||||
// attempt to retrieve the credentials for every request.
|
||||
type CredentialsProvider interface {
|
||||
// Retrieve returns nil if it successfully retrieved the value.
|
||||
// Error is returned if the value were not obtainable, or empty.
|
||||
Retrieve(ctx context.Context) (Credentials, error)
|
||||
}
|
||||
|
||||
// CredentialsProviderFunc provides a helper wrapping a function value to
|
||||
// satisfy the CredentialsProvider interface.
|
||||
type CredentialsProviderFunc func(context.Context) (Credentials, error)
|
||||
|
||||
// Retrieve delegates to the function value the CredentialsProviderFunc wraps.
|
||||
func (fn CredentialsProviderFunc) Retrieve(ctx context.Context) (Credentials, error) {
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
type isCredentialsProvider interface {
|
||||
IsCredentialsProvider(CredentialsProvider) bool
|
||||
}
|
||||
|
||||
// IsCredentialsProvider returns whether the target CredentialProvider is the same type as provider when comparing the
|
||||
// implementation type.
|
||||
//
|
||||
// If provider has a method IsCredentialsProvider(CredentialsProvider) bool it will be responsible for validating
|
||||
// whether target matches the credential provider type.
|
||||
//
|
||||
// When comparing the CredentialProvider implementations provider and target for equality, the following rules are used:
|
||||
//
|
||||
// If provider is of type T and target is of type V, true if type *T is the same as type *V, otherwise false
|
||||
// If provider is of type *T and target is of type V, true if type *T is the same as type *V, otherwise false
|
||||
// If provider is of type T and target is of type *V, true if type *T is the same as type *V, otherwise false
|
||||
// If provider is of type *T and target is of type *V,true if type *T is the same as type *V, otherwise false
|
||||
func IsCredentialsProvider(provider, target CredentialsProvider) bool {
|
||||
if target == nil || provider == nil {
|
||||
return provider == target
|
||||
}
|
||||
|
||||
if x, ok := provider.(isCredentialsProvider); ok {
|
||||
return x.IsCredentialsProvider(target)
|
||||
}
|
||||
|
||||
targetType := reflect.TypeOf(target)
|
||||
if targetType.Kind() != reflect.Ptr {
|
||||
targetType = reflect.PtrTo(targetType)
|
||||
}
|
||||
|
||||
providerType := reflect.TypeOf(provider)
|
||||
if providerType.Kind() != reflect.Ptr {
|
||||
providerType = reflect.PtrTo(providerType)
|
||||
}
|
||||
|
||||
return targetType.AssignableTo(providerType)
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
package defaults
|
||||
|
||||
import (
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var getGOOS = func() string {
|
||||
return runtime.GOOS
|
||||
}
|
||||
|
||||
// ResolveDefaultsModeAuto is used to determine the effective aws.DefaultsMode when the mode
|
||||
// is set to aws.DefaultsModeAuto.
|
||||
func ResolveDefaultsModeAuto(region string, environment aws.RuntimeEnvironment) aws.DefaultsMode {
|
||||
goos := getGOOS()
|
||||
if goos == "android" || goos == "ios" {
|
||||
return aws.DefaultsModeMobile
|
||||
}
|
||||
|
||||
var currentRegion string
|
||||
if len(environment.EnvironmentIdentifier) > 0 {
|
||||
currentRegion = environment.Region
|
||||
}
|
||||
|
||||
if len(currentRegion) == 0 && len(environment.EC2InstanceMetadataRegion) > 0 {
|
||||
currentRegion = environment.EC2InstanceMetadataRegion
|
||||
}
|
||||
|
||||
if len(region) > 0 && len(currentRegion) > 0 {
|
||||
if strings.EqualFold(region, currentRegion) {
|
||||
return aws.DefaultsModeInRegion
|
||||
}
|
||||
return aws.DefaultsModeCrossRegion
|
||||
}
|
||||
|
||||
return aws.DefaultsModeStandard
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
package defaults
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
)
|
||||
|
||||
// Configuration is the set of SDK configuration options that are determined based
|
||||
// on the configured DefaultsMode.
|
||||
type Configuration struct {
|
||||
// RetryMode is the configuration's default retry mode API clients should
|
||||
// use for constructing a Retryer.
|
||||
RetryMode aws.RetryMode
|
||||
|
||||
// ConnectTimeout is the maximum amount of time a dial will wait for
|
||||
// a connect to complete.
|
||||
//
|
||||
// See https://pkg.go.dev/net#Dialer.Timeout
|
||||
ConnectTimeout *time.Duration
|
||||
|
||||
// TLSNegotiationTimeout specifies the maximum amount of time waiting to
|
||||
// wait for a TLS handshake.
|
||||
//
|
||||
// See https://pkg.go.dev/net/http#Transport.TLSHandshakeTimeout
|
||||
TLSNegotiationTimeout *time.Duration
|
||||
}
|
||||
|
||||
// GetConnectTimeout returns the ConnectTimeout value, returns false if the value is not set.
|
||||
func (c *Configuration) GetConnectTimeout() (time.Duration, bool) {
|
||||
if c.ConnectTimeout == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *c.ConnectTimeout, true
|
||||
}
|
||||
|
||||
// GetTLSNegotiationTimeout returns the TLSNegotiationTimeout value, returns false if the value is not set.
|
||||
func (c *Configuration) GetTLSNegotiationTimeout() (time.Duration, bool) {
|
||||
if c.TLSNegotiationTimeout == nil {
|
||||
return 0, false
|
||||
}
|
||||
return *c.TLSNegotiationTimeout, true
|
||||
}
|
@ -0,0 +1,50 @@
|
||||
// Code generated by github.com/aws/aws-sdk-go-v2/internal/codegen/cmd/defaultsconfig. DO NOT EDIT.
|
||||
|
||||
package defaults
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GetModeConfiguration returns the default Configuration descriptor for the given mode.
|
||||
//
|
||||
// Supports the following modes: cross-region, in-region, mobile, standard
|
||||
func GetModeConfiguration(mode aws.DefaultsMode) (Configuration, error) {
|
||||
var mv aws.DefaultsMode
|
||||
mv.SetFromString(string(mode))
|
||||
|
||||
switch mv {
|
||||
case aws.DefaultsModeCrossRegion:
|
||||
settings := Configuration{
|
||||
ConnectTimeout: aws.Duration(3100 * time.Millisecond),
|
||||
RetryMode: aws.RetryMode("standard"),
|
||||
TLSNegotiationTimeout: aws.Duration(3100 * time.Millisecond),
|
||||
}
|
||||
return settings, nil
|
||||
case aws.DefaultsModeInRegion:
|
||||
settings := Configuration{
|
||||
ConnectTimeout: aws.Duration(1100 * time.Millisecond),
|
||||
RetryMode: aws.RetryMode("standard"),
|
||||
TLSNegotiationTimeout: aws.Duration(1100 * time.Millisecond),
|
||||
}
|
||||
return settings, nil
|
||||
case aws.DefaultsModeMobile:
|
||||
settings := Configuration{
|
||||
ConnectTimeout: aws.Duration(30000 * time.Millisecond),
|
||||
RetryMode: aws.RetryMode("standard"),
|
||||
TLSNegotiationTimeout: aws.Duration(30000 * time.Millisecond),
|
||||
}
|
||||
return settings, nil
|
||||
case aws.DefaultsModeStandard:
|
||||
settings := Configuration{
|
||||
ConnectTimeout: aws.Duration(3100 * time.Millisecond),
|
||||
RetryMode: aws.RetryMode("standard"),
|
||||
TLSNegotiationTimeout: aws.Duration(3100 * time.Millisecond),
|
||||
}
|
||||
return settings, nil
|
||||
default:
|
||||
return Configuration{}, fmt.Errorf("unsupported defaults mode: %v", mode)
|
||||
}
|
||||
}
|
@ -0,0 +1,2 @@
|
||||
// Package defaults provides recommended configuration values for AWS SDKs and CLIs.
|
||||
package defaults
|
@ -0,0 +1,95 @@
|
||||
// Code generated by github.com/aws/aws-sdk-go-v2/internal/codegen/cmd/defaultsmode. DO NOT EDIT.
|
||||
|
||||
package aws
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DefaultsMode is the SDK defaults mode setting.
|
||||
type DefaultsMode string
|
||||
|
||||
// The DefaultsMode constants.
|
||||
const (
|
||||
// DefaultsModeAuto is an experimental mode that builds on the standard mode.
|
||||
// The SDK will attempt to discover the execution environment to determine the
|
||||
// appropriate settings automatically.
|
||||
//
|
||||
// Note that the auto detection is heuristics-based and does not guarantee 100%
|
||||
// accuracy. STANDARD mode will be used if the execution environment cannot
|
||||
// be determined. The auto detection might query EC2 Instance Metadata service
|
||||
// (https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html),
|
||||
// which might introduce latency. Therefore we recommend choosing an explicit
|
||||
// defaults_mode instead if startup latency is critical to your application
|
||||
DefaultsModeAuto DefaultsMode = "auto"
|
||||
|
||||
// DefaultsModeCrossRegion builds on the standard mode and includes optimization
|
||||
// tailored for applications which call AWS services in a different region
|
||||
//
|
||||
// Note that the default values vended from this mode might change as best practices
|
||||
// may evolve. As a result, it is encouraged to perform tests when upgrading
|
||||
// the SDK
|
||||
DefaultsModeCrossRegion DefaultsMode = "cross-region"
|
||||
|
||||
// DefaultsModeInRegion builds on the standard mode and includes optimization
|
||||
// tailored for applications which call AWS services from within the same AWS
|
||||
// region
|
||||
//
|
||||
// Note that the default values vended from this mode might change as best practices
|
||||
// may evolve. As a result, it is encouraged to perform tests when upgrading
|
||||
// the SDK
|
||||
DefaultsModeInRegion DefaultsMode = "in-region"
|
||||
|
||||
// DefaultsModeLegacy provides default settings that vary per SDK and were used
|
||||
// prior to establishment of defaults_mode
|
||||
DefaultsModeLegacy DefaultsMode = "legacy"
|
||||
|
||||
// DefaultsModeMobile builds on the standard mode and includes optimization
|
||||
// tailored for mobile applications
|
||||
//
|
||||
// Note that the default values vended from this mode might change as best practices
|
||||
// may evolve. As a result, it is encouraged to perform tests when upgrading
|
||||
// the SDK
|
||||
DefaultsModeMobile DefaultsMode = "mobile"
|
||||
|
||||
// DefaultsModeStandard provides the latest recommended default values that
|
||||
// should be safe to run in most scenarios
|
||||
//
|
||||
// Note that the default values vended from this mode might change as best practices
|
||||
// may evolve. As a result, it is encouraged to perform tests when upgrading
|
||||
// the SDK
|
||||
DefaultsModeStandard DefaultsMode = "standard"
|
||||
)
|
||||
|
||||
// SetFromString sets the DefaultsMode value to one of the pre-defined constants that matches
|
||||
// the provided string when compared using EqualFold. If the value does not match a known
|
||||
// constant it will be set to as-is and the function will return false. As a special case, if the
|
||||
// provided value is a zero-length string, the mode will be set to LegacyDefaultsMode.
|
||||
func (d *DefaultsMode) SetFromString(v string) (ok bool) {
|
||||
switch {
|
||||
case strings.EqualFold(v, string(DefaultsModeAuto)):
|
||||
*d = DefaultsModeAuto
|
||||
ok = true
|
||||
case strings.EqualFold(v, string(DefaultsModeCrossRegion)):
|
||||
*d = DefaultsModeCrossRegion
|
||||
ok = true
|
||||
case strings.EqualFold(v, string(DefaultsModeInRegion)):
|
||||
*d = DefaultsModeInRegion
|
||||
ok = true
|
||||
case strings.EqualFold(v, string(DefaultsModeLegacy)):
|
||||
*d = DefaultsModeLegacy
|
||||
ok = true
|
||||
case strings.EqualFold(v, string(DefaultsModeMobile)):
|
||||
*d = DefaultsModeMobile
|
||||
ok = true
|
||||
case strings.EqualFold(v, string(DefaultsModeStandard)):
|
||||
*d = DefaultsModeStandard
|
||||
ok = true
|
||||
case len(v) == 0:
|
||||
*d = DefaultsModeLegacy
|
||||
ok = true
|
||||
default:
|
||||
*d = DefaultsMode(v)
|
||||
}
|
||||
return ok
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
// Package aws provides the core SDK's utilities and shared types. Use this package's
|
||||
// utilities to simplify setting and reading API operations parameters.
|
||||
//
|
||||
// # Value and Pointer Conversion Utilities
|
||||
//
|
||||
// This package includes a helper conversion utility for each scalar type the SDK's
|
||||
// API use. These utilities make getting a pointer of the scalar, and dereferencing
|
||||
// a pointer easier.
|
||||
//
|
||||
// Each conversion utility comes in two forms. Value to Pointer and Pointer to Value.
|
||||
// The Pointer to value will safely dereference the pointer and return its value.
|
||||
// If the pointer was nil, the scalar's zero value will be returned.
|
||||
//
|
||||
// The value to pointer functions will be named after the scalar type. So get a
|
||||
// *string from a string value use the "String" function. This makes it easy to
|
||||
// to get pointer of a literal string value, because getting the address of a
|
||||
// literal requires assigning the value to a variable first.
|
||||
//
|
||||
// var strPtr *string
|
||||
//
|
||||
// // Without the SDK's conversion functions
|
||||
// str := "my string"
|
||||
// strPtr = &str
|
||||
//
|
||||
// // With the SDK's conversion functions
|
||||
// strPtr = aws.String("my string")
|
||||
//
|
||||
// // Convert *string to string value
|
||||
// str = aws.ToString(strPtr)
|
||||
//
|
||||
// In addition to scalars the aws package also includes conversion utilities for
|
||||
// map and slice for commonly types used in API parameters. The map and slice
|
||||
// conversion functions use similar naming pattern as the scalar conversion
|
||||
// functions.
|
||||
//
|
||||
// var strPtrs []*string
|
||||
// var strs []string = []string{"Go", "Gophers", "Go"}
|
||||
//
|
||||
// // Convert []string to []*string
|
||||
// strPtrs = aws.StringSlice(strs)
|
||||
//
|
||||
// // Convert []*string to []string
|
||||
// strs = aws.ToStringSlice(strPtrs)
|
||||
//
|
||||
// # SDK Default HTTP Client
|
||||
//
|
||||
// The SDK will use the http.DefaultClient if a HTTP client is not provided to
|
||||
// the SDK's Session, or service client constructor. This means that if the
|
||||
// http.DefaultClient is modified by other components of your application the
|
||||
// modifications will be picked up by the SDK as well.
|
||||
//
|
||||
// In some cases this might be intended, but it is a better practice to create
|
||||
// a custom HTTP Client to share explicitly through your application. You can
|
||||
// configure the SDK to use the custom HTTP Client by setting the HTTPClient
|
||||
// value of the SDK's Config type when creating a Session or service client.
|
||||
package aws
|
||||
|
||||
// generate.go uses a build tag of "ignore", go run doesn't need to specify
|
||||
// this because go run ignores all build flags when running a go file directly.
|
||||
//go:generate go run -tags codegen generate.go
|
||||
//go:generate go run -tags codegen logging_generate.go
|
||||
//go:generate gofmt -w -s .
|
@ -0,0 +1,229 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// DualStackEndpointState is a constant to describe the dual-stack endpoint resolution behavior.
|
||||
type DualStackEndpointState uint
|
||||
|
||||
const (
|
||||
// DualStackEndpointStateUnset is the default value behavior for dual-stack endpoint resolution.
|
||||
DualStackEndpointStateUnset DualStackEndpointState = iota
|
||||
|
||||
// DualStackEndpointStateEnabled enables dual-stack endpoint resolution for service endpoints.
|
||||
DualStackEndpointStateEnabled
|
||||
|
||||
// DualStackEndpointStateDisabled disables dual-stack endpoint resolution for endpoints.
|
||||
DualStackEndpointStateDisabled
|
||||
)
|
||||
|
||||
// GetUseDualStackEndpoint takes a service's EndpointResolverOptions and returns the UseDualStackEndpoint value.
|
||||
// Returns boolean false if the provided options does not have a method to retrieve the DualStackEndpointState.
|
||||
func GetUseDualStackEndpoint(options ...interface{}) (value DualStackEndpointState, found bool) {
|
||||
type iface interface {
|
||||
GetUseDualStackEndpoint() DualStackEndpointState
|
||||
}
|
||||
for _, option := range options {
|
||||
if i, ok := option.(iface); ok {
|
||||
value = i.GetUseDualStackEndpoint()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return value, found
|
||||
}
|
||||
|
||||
// FIPSEndpointState is a constant to describe the FIPS endpoint resolution behavior.
|
||||
type FIPSEndpointState uint
|
||||
|
||||
const (
|
||||
// FIPSEndpointStateUnset is the default value behavior for FIPS endpoint resolution.
|
||||
FIPSEndpointStateUnset FIPSEndpointState = iota
|
||||
|
||||
// FIPSEndpointStateEnabled enables FIPS endpoint resolution for service endpoints.
|
||||
FIPSEndpointStateEnabled
|
||||
|
||||
// FIPSEndpointStateDisabled disables FIPS endpoint resolution for endpoints.
|
||||
FIPSEndpointStateDisabled
|
||||
)
|
||||
|
||||
// GetUseFIPSEndpoint takes a service's EndpointResolverOptions and returns the UseDualStackEndpoint value.
|
||||
// Returns boolean false if the provided options does not have a method to retrieve the DualStackEndpointState.
|
||||
func GetUseFIPSEndpoint(options ...interface{}) (value FIPSEndpointState, found bool) {
|
||||
type iface interface {
|
||||
GetUseFIPSEndpoint() FIPSEndpointState
|
||||
}
|
||||
for _, option := range options {
|
||||
if i, ok := option.(iface); ok {
|
||||
value = i.GetUseFIPSEndpoint()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return value, found
|
||||
}
|
||||
|
||||
// Endpoint represents the endpoint a service client should make API operation
|
||||
// calls to.
|
||||
//
|
||||
// The SDK will automatically resolve these endpoints per API client using an
|
||||
// internal endpoint resolvers. If you'd like to provide custom endpoint
|
||||
// resolving behavior you can implement the EndpointResolver interface.
|
||||
type Endpoint struct {
|
||||
// The base URL endpoint the SDK API clients will use to make API calls to.
|
||||
// The SDK will suffix URI path and query elements to this endpoint.
|
||||
URL string
|
||||
|
||||
// Specifies if the endpoint's hostname can be modified by the SDK's API
|
||||
// client.
|
||||
//
|
||||
// If the hostname is mutable the SDK API clients may modify any part of
|
||||
// the hostname based on the requirements of the API, (e.g. adding, or
|
||||
// removing content in the hostname). Such as, Amazon S3 API client
|
||||
// prefixing "bucketname" to the hostname, or changing the
|
||||
// hostname service name component from "s3." to "s3-accesspoint.dualstack."
|
||||
// for the dualstack endpoint of an S3 Accesspoint resource.
|
||||
//
|
||||
// Care should be taken when providing a custom endpoint for an API. If the
|
||||
// endpoint hostname is mutable, and the client cannot modify the endpoint
|
||||
// correctly, the operation call will most likely fail, or have undefined
|
||||
// behavior.
|
||||
//
|
||||
// If hostname is immutable, the SDK API clients will not modify the
|
||||
// hostname of the URL. This may cause the API client not to function
|
||||
// correctly if the API requires the operation specific hostname values
|
||||
// to be used by the client.
|
||||
//
|
||||
// This flag does not modify the API client's behavior if this endpoint
|
||||
// will be used instead of Endpoint Discovery, or if the endpoint will be
|
||||
// used to perform Endpoint Discovery. That behavior is configured via the
|
||||
// API Client's Options.
|
||||
HostnameImmutable bool
|
||||
|
||||
// The AWS partition the endpoint belongs to.
|
||||
PartitionID string
|
||||
|
||||
// The service name that should be used for signing the requests to the
|
||||
// endpoint.
|
||||
SigningName string
|
||||
|
||||
// The region that should be used for signing the request to the endpoint.
|
||||
SigningRegion string
|
||||
|
||||
// The signing method that should be used for signing the requests to the
|
||||
// endpoint.
|
||||
SigningMethod string
|
||||
|
||||
// The source of the Endpoint. By default, this will be EndpointSourceServiceMetadata.
|
||||
// When providing a custom endpoint, you should set the source as EndpointSourceCustom.
|
||||
// If source is not provided when providing a custom endpoint, the SDK may not
|
||||
// perform required host mutations correctly. Source should be used along with
|
||||
// HostnameImmutable property as per the usage requirement.
|
||||
Source EndpointSource
|
||||
}
|
||||
|
||||
// EndpointSource is the endpoint source type.
|
||||
type EndpointSource int
|
||||
|
||||
const (
|
||||
// EndpointSourceServiceMetadata denotes service modeled endpoint metadata is used as Endpoint Source.
|
||||
EndpointSourceServiceMetadata EndpointSource = iota
|
||||
|
||||
// EndpointSourceCustom denotes endpoint is a custom endpoint. This source should be used when
|
||||
// user provides a custom endpoint to be used by the SDK.
|
||||
EndpointSourceCustom
|
||||
)
|
||||
|
||||
// EndpointNotFoundError is a sentinel error to indicate that the
|
||||
// EndpointResolver implementation was unable to resolve an endpoint for the
|
||||
// given service and region. Resolvers should use this to indicate that an API
|
||||
// client should fallback and attempt to use it's internal default resolver to
|
||||
// resolve the endpoint.
|
||||
type EndpointNotFoundError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error is the error message.
|
||||
func (e *EndpointNotFoundError) Error() string {
|
||||
return fmt.Sprintf("endpoint not found, %v", e.Err)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error.
|
||||
func (e *EndpointNotFoundError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// EndpointResolver is an endpoint resolver that can be used to provide or
|
||||
// override an endpoint for the given service and region. API clients will
|
||||
// attempt to use the EndpointResolver first to resolve an endpoint if
|
||||
// available. If the EndpointResolver returns an EndpointNotFoundError error,
|
||||
// API clients will fallback to attempting to resolve the endpoint using its
|
||||
// internal default endpoint resolver.
|
||||
//
|
||||
// Deprecated: See EndpointResolverWithOptions
|
||||
type EndpointResolver interface {
|
||||
ResolveEndpoint(service, region string) (Endpoint, error)
|
||||
}
|
||||
|
||||
// EndpointResolverFunc wraps a function to satisfy the EndpointResolver interface.
|
||||
//
|
||||
// Deprecated: See EndpointResolverWithOptionsFunc
|
||||
type EndpointResolverFunc func(service, region string) (Endpoint, error)
|
||||
|
||||
// ResolveEndpoint calls the wrapped function and returns the results.
|
||||
//
|
||||
// Deprecated: See EndpointResolverWithOptions.ResolveEndpoint
|
||||
func (e EndpointResolverFunc) ResolveEndpoint(service, region string) (Endpoint, error) {
|
||||
return e(service, region)
|
||||
}
|
||||
|
||||
// EndpointResolverWithOptions is an endpoint resolver that can be used to provide or
|
||||
// override an endpoint for the given service, region, and the service client's EndpointOptions. API clients will
|
||||
// attempt to use the EndpointResolverWithOptions first to resolve an endpoint if
|
||||
// available. If the EndpointResolverWithOptions returns an EndpointNotFoundError error,
|
||||
// API clients will fallback to attempting to resolve the endpoint using its
|
||||
// internal default endpoint resolver.
|
||||
type EndpointResolverWithOptions interface {
|
||||
ResolveEndpoint(service, region string, options ...interface{}) (Endpoint, error)
|
||||
}
|
||||
|
||||
// EndpointResolverWithOptionsFunc wraps a function to satisfy the EndpointResolverWithOptions interface.
|
||||
type EndpointResolverWithOptionsFunc func(service, region string, options ...interface{}) (Endpoint, error)
|
||||
|
||||
// ResolveEndpoint calls the wrapped function and returns the results.
|
||||
func (e EndpointResolverWithOptionsFunc) ResolveEndpoint(service, region string, options ...interface{}) (Endpoint, error) {
|
||||
return e(service, region, options...)
|
||||
}
|
||||
|
||||
// GetDisableHTTPS takes a service's EndpointResolverOptions and returns the DisableHTTPS value.
|
||||
// Returns boolean false if the provided options does not have a method to retrieve the DisableHTTPS.
|
||||
func GetDisableHTTPS(options ...interface{}) (value bool, found bool) {
|
||||
type iface interface {
|
||||
GetDisableHTTPS() bool
|
||||
}
|
||||
for _, option := range options {
|
||||
if i, ok := option.(iface); ok {
|
||||
value = i.GetDisableHTTPS()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return value, found
|
||||
}
|
||||
|
||||
// GetResolvedRegion takes a service's EndpointResolverOptions and returns the ResolvedRegion value.
|
||||
// Returns boolean false if the provided options does not have a method to retrieve the ResolvedRegion.
|
||||
func GetResolvedRegion(options ...interface{}) (value string, found bool) {
|
||||
type iface interface {
|
||||
GetResolvedRegion() string
|
||||
}
|
||||
for _, option := range options {
|
||||
if i, ok := option.(iface); ok {
|
||||
value = i.GetResolvedRegion()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return value, found
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
package aws
|
||||
|
||||
// MissingRegionError is an error that is returned if region configuration
|
||||
// value was not found.
|
||||
type MissingRegionError struct{}
|
||||
|
||||
func (*MissingRegionError) Error() string {
|
||||
return "an AWS region is required, but was not found"
|
||||
}
|
@ -0,0 +1,365 @@
|
||||
// Code generated by aws/generate.go DO NOT EDIT.
|
||||
|
||||
package aws
|
||||
|
||||
import (
|
||||
"github.com/aws/smithy-go/ptr"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ToBool returns bool value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a bool zero value if the
|
||||
// pointer was nil.
|
||||
func ToBool(p *bool) (v bool) {
|
||||
return ptr.ToBool(p)
|
||||
}
|
||||
|
||||
// ToBoolSlice returns a slice of bool values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a bool
|
||||
// zero value if the pointer was nil.
|
||||
func ToBoolSlice(vs []*bool) []bool {
|
||||
return ptr.ToBoolSlice(vs)
|
||||
}
|
||||
|
||||
// ToBoolMap returns a map of bool values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The bool
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToBoolMap(vs map[string]*bool) map[string]bool {
|
||||
return ptr.ToBoolMap(vs)
|
||||
}
|
||||
|
||||
// ToByte returns byte value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a byte zero value if the
|
||||
// pointer was nil.
|
||||
func ToByte(p *byte) (v byte) {
|
||||
return ptr.ToByte(p)
|
||||
}
|
||||
|
||||
// ToByteSlice returns a slice of byte values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a byte
|
||||
// zero value if the pointer was nil.
|
||||
func ToByteSlice(vs []*byte) []byte {
|
||||
return ptr.ToByteSlice(vs)
|
||||
}
|
||||
|
||||
// ToByteMap returns a map of byte values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The byte
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToByteMap(vs map[string]*byte) map[string]byte {
|
||||
return ptr.ToByteMap(vs)
|
||||
}
|
||||
|
||||
// ToString returns string value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a string zero value if the
|
||||
// pointer was nil.
|
||||
func ToString(p *string) (v string) {
|
||||
return ptr.ToString(p)
|
||||
}
|
||||
|
||||
// ToStringSlice returns a slice of string values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a string
|
||||
// zero value if the pointer was nil.
|
||||
func ToStringSlice(vs []*string) []string {
|
||||
return ptr.ToStringSlice(vs)
|
||||
}
|
||||
|
||||
// ToStringMap returns a map of string values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The string
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToStringMap(vs map[string]*string) map[string]string {
|
||||
return ptr.ToStringMap(vs)
|
||||
}
|
||||
|
||||
// ToInt returns int value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a int zero value if the
|
||||
// pointer was nil.
|
||||
func ToInt(p *int) (v int) {
|
||||
return ptr.ToInt(p)
|
||||
}
|
||||
|
||||
// ToIntSlice returns a slice of int values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a int
|
||||
// zero value if the pointer was nil.
|
||||
func ToIntSlice(vs []*int) []int {
|
||||
return ptr.ToIntSlice(vs)
|
||||
}
|
||||
|
||||
// ToIntMap returns a map of int values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The int
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToIntMap(vs map[string]*int) map[string]int {
|
||||
return ptr.ToIntMap(vs)
|
||||
}
|
||||
|
||||
// ToInt8 returns int8 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a int8 zero value if the
|
||||
// pointer was nil.
|
||||
func ToInt8(p *int8) (v int8) {
|
||||
return ptr.ToInt8(p)
|
||||
}
|
||||
|
||||
// ToInt8Slice returns a slice of int8 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a int8
|
||||
// zero value if the pointer was nil.
|
||||
func ToInt8Slice(vs []*int8) []int8 {
|
||||
return ptr.ToInt8Slice(vs)
|
||||
}
|
||||
|
||||
// ToInt8Map returns a map of int8 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The int8
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToInt8Map(vs map[string]*int8) map[string]int8 {
|
||||
return ptr.ToInt8Map(vs)
|
||||
}
|
||||
|
||||
// ToInt16 returns int16 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a int16 zero value if the
|
||||
// pointer was nil.
|
||||
func ToInt16(p *int16) (v int16) {
|
||||
return ptr.ToInt16(p)
|
||||
}
|
||||
|
||||
// ToInt16Slice returns a slice of int16 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a int16
|
||||
// zero value if the pointer was nil.
|
||||
func ToInt16Slice(vs []*int16) []int16 {
|
||||
return ptr.ToInt16Slice(vs)
|
||||
}
|
||||
|
||||
// ToInt16Map returns a map of int16 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The int16
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToInt16Map(vs map[string]*int16) map[string]int16 {
|
||||
return ptr.ToInt16Map(vs)
|
||||
}
|
||||
|
||||
// ToInt32 returns int32 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a int32 zero value if the
|
||||
// pointer was nil.
|
||||
func ToInt32(p *int32) (v int32) {
|
||||
return ptr.ToInt32(p)
|
||||
}
|
||||
|
||||
// ToInt32Slice returns a slice of int32 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a int32
|
||||
// zero value if the pointer was nil.
|
||||
func ToInt32Slice(vs []*int32) []int32 {
|
||||
return ptr.ToInt32Slice(vs)
|
||||
}
|
||||
|
||||
// ToInt32Map returns a map of int32 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The int32
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToInt32Map(vs map[string]*int32) map[string]int32 {
|
||||
return ptr.ToInt32Map(vs)
|
||||
}
|
||||
|
||||
// ToInt64 returns int64 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a int64 zero value if the
|
||||
// pointer was nil.
|
||||
func ToInt64(p *int64) (v int64) {
|
||||
return ptr.ToInt64(p)
|
||||
}
|
||||
|
||||
// ToInt64Slice returns a slice of int64 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a int64
|
||||
// zero value if the pointer was nil.
|
||||
func ToInt64Slice(vs []*int64) []int64 {
|
||||
return ptr.ToInt64Slice(vs)
|
||||
}
|
||||
|
||||
// ToInt64Map returns a map of int64 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The int64
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToInt64Map(vs map[string]*int64) map[string]int64 {
|
||||
return ptr.ToInt64Map(vs)
|
||||
}
|
||||
|
||||
// ToUint returns uint value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a uint zero value if the
|
||||
// pointer was nil.
|
||||
func ToUint(p *uint) (v uint) {
|
||||
return ptr.ToUint(p)
|
||||
}
|
||||
|
||||
// ToUintSlice returns a slice of uint values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a uint
|
||||
// zero value if the pointer was nil.
|
||||
func ToUintSlice(vs []*uint) []uint {
|
||||
return ptr.ToUintSlice(vs)
|
||||
}
|
||||
|
||||
// ToUintMap returns a map of uint values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The uint
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToUintMap(vs map[string]*uint) map[string]uint {
|
||||
return ptr.ToUintMap(vs)
|
||||
}
|
||||
|
||||
// ToUint8 returns uint8 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a uint8 zero value if the
|
||||
// pointer was nil.
|
||||
func ToUint8(p *uint8) (v uint8) {
|
||||
return ptr.ToUint8(p)
|
||||
}
|
||||
|
||||
// ToUint8Slice returns a slice of uint8 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a uint8
|
||||
// zero value if the pointer was nil.
|
||||
func ToUint8Slice(vs []*uint8) []uint8 {
|
||||
return ptr.ToUint8Slice(vs)
|
||||
}
|
||||
|
||||
// ToUint8Map returns a map of uint8 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The uint8
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToUint8Map(vs map[string]*uint8) map[string]uint8 {
|
||||
return ptr.ToUint8Map(vs)
|
||||
}
|
||||
|
||||
// ToUint16 returns uint16 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a uint16 zero value if the
|
||||
// pointer was nil.
|
||||
func ToUint16(p *uint16) (v uint16) {
|
||||
return ptr.ToUint16(p)
|
||||
}
|
||||
|
||||
// ToUint16Slice returns a slice of uint16 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a uint16
|
||||
// zero value if the pointer was nil.
|
||||
func ToUint16Slice(vs []*uint16) []uint16 {
|
||||
return ptr.ToUint16Slice(vs)
|
||||
}
|
||||
|
||||
// ToUint16Map returns a map of uint16 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The uint16
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToUint16Map(vs map[string]*uint16) map[string]uint16 {
|
||||
return ptr.ToUint16Map(vs)
|
||||
}
|
||||
|
||||
// ToUint32 returns uint32 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a uint32 zero value if the
|
||||
// pointer was nil.
|
||||
func ToUint32(p *uint32) (v uint32) {
|
||||
return ptr.ToUint32(p)
|
||||
}
|
||||
|
||||
// ToUint32Slice returns a slice of uint32 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a uint32
|
||||
// zero value if the pointer was nil.
|
||||
func ToUint32Slice(vs []*uint32) []uint32 {
|
||||
return ptr.ToUint32Slice(vs)
|
||||
}
|
||||
|
||||
// ToUint32Map returns a map of uint32 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The uint32
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToUint32Map(vs map[string]*uint32) map[string]uint32 {
|
||||
return ptr.ToUint32Map(vs)
|
||||
}
|
||||
|
||||
// ToUint64 returns uint64 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a uint64 zero value if the
|
||||
// pointer was nil.
|
||||
func ToUint64(p *uint64) (v uint64) {
|
||||
return ptr.ToUint64(p)
|
||||
}
|
||||
|
||||
// ToUint64Slice returns a slice of uint64 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a uint64
|
||||
// zero value if the pointer was nil.
|
||||
func ToUint64Slice(vs []*uint64) []uint64 {
|
||||
return ptr.ToUint64Slice(vs)
|
||||
}
|
||||
|
||||
// ToUint64Map returns a map of uint64 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The uint64
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToUint64Map(vs map[string]*uint64) map[string]uint64 {
|
||||
return ptr.ToUint64Map(vs)
|
||||
}
|
||||
|
||||
// ToFloat32 returns float32 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a float32 zero value if the
|
||||
// pointer was nil.
|
||||
func ToFloat32(p *float32) (v float32) {
|
||||
return ptr.ToFloat32(p)
|
||||
}
|
||||
|
||||
// ToFloat32Slice returns a slice of float32 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a float32
|
||||
// zero value if the pointer was nil.
|
||||
func ToFloat32Slice(vs []*float32) []float32 {
|
||||
return ptr.ToFloat32Slice(vs)
|
||||
}
|
||||
|
||||
// ToFloat32Map returns a map of float32 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The float32
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToFloat32Map(vs map[string]*float32) map[string]float32 {
|
||||
return ptr.ToFloat32Map(vs)
|
||||
}
|
||||
|
||||
// ToFloat64 returns float64 value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a float64 zero value if the
|
||||
// pointer was nil.
|
||||
func ToFloat64(p *float64) (v float64) {
|
||||
return ptr.ToFloat64(p)
|
||||
}
|
||||
|
||||
// ToFloat64Slice returns a slice of float64 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a float64
|
||||
// zero value if the pointer was nil.
|
||||
func ToFloat64Slice(vs []*float64) []float64 {
|
||||
return ptr.ToFloat64Slice(vs)
|
||||
}
|
||||
|
||||
// ToFloat64Map returns a map of float64 values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The float64
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToFloat64Map(vs map[string]*float64) map[string]float64 {
|
||||
return ptr.ToFloat64Map(vs)
|
||||
}
|
||||
|
||||
// ToTime returns time.Time value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a time.Time zero value if the
|
||||
// pointer was nil.
|
||||
func ToTime(p *time.Time) (v time.Time) {
|
||||
return ptr.ToTime(p)
|
||||
}
|
||||
|
||||
// ToTimeSlice returns a slice of time.Time values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a time.Time
|
||||
// zero value if the pointer was nil.
|
||||
func ToTimeSlice(vs []*time.Time) []time.Time {
|
||||
return ptr.ToTimeSlice(vs)
|
||||
}
|
||||
|
||||
// ToTimeMap returns a map of time.Time values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The time.Time
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToTimeMap(vs map[string]*time.Time) map[string]time.Time {
|
||||
return ptr.ToTimeMap(vs)
|
||||
}
|
||||
|
||||
// ToDuration returns time.Duration value dereferenced if the passed
|
||||
// in pointer was not nil. Returns a time.Duration zero value if the
|
||||
// pointer was nil.
|
||||
func ToDuration(p *time.Duration) (v time.Duration) {
|
||||
return ptr.ToDuration(p)
|
||||
}
|
||||
|
||||
// ToDurationSlice returns a slice of time.Duration values, that are
|
||||
// dereferenced if the passed in pointer was not nil. Returns a time.Duration
|
||||
// zero value if the pointer was nil.
|
||||
func ToDurationSlice(vs []*time.Duration) []time.Duration {
|
||||
return ptr.ToDurationSlice(vs)
|
||||
}
|
||||
|
||||
// ToDurationMap returns a map of time.Duration values, that are
|
||||
// dereferenced if the passed in pointer was not nil. The time.Duration
|
||||
// zero value is used if the pointer was nil.
|
||||
func ToDurationMap(vs map[string]*time.Duration) map[string]time.Duration {
|
||||
return ptr.ToDurationMap(vs)
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT.
|
||||
|
||||
package aws
|
||||
|
||||
// goModuleVersion is the tagged release for this module
|
||||
const goModuleVersion = "1.17.3"
|
@ -0,0 +1,119 @@
|
||||
// Code generated by aws/logging_generate.go DO NOT EDIT.
|
||||
|
||||
package aws
|
||||
|
||||
// ClientLogMode represents the logging mode of SDK clients. The client logging mode is a bit-field where
|
||||
// each bit is a flag that describes the logging behavior for one or more client components.
|
||||
// The entire 64-bit group is reserved for later expansion by the SDK.
|
||||
//
|
||||
// Example: Setting ClientLogMode to enable logging of retries and requests
|
||||
//
|
||||
// clientLogMode := aws.LogRetries | aws.LogRequest
|
||||
//
|
||||
// Example: Adding an additional log mode to an existing ClientLogMode value
|
||||
//
|
||||
// clientLogMode |= aws.LogResponse
|
||||
type ClientLogMode uint64
|
||||
|
||||
// Supported ClientLogMode bits that can be configured to toggle logging of specific SDK events.
|
||||
const (
|
||||
LogSigning ClientLogMode = 1 << (64 - 1 - iota)
|
||||
LogRetries
|
||||
LogRequest
|
||||
LogRequestWithBody
|
||||
LogResponse
|
||||
LogResponseWithBody
|
||||
LogDeprecatedUsage
|
||||
LogRequestEventMessage
|
||||
LogResponseEventMessage
|
||||
)
|
||||
|
||||
// IsSigning returns whether the Signing logging mode bit is set
|
||||
func (m ClientLogMode) IsSigning() bool {
|
||||
return m&LogSigning != 0
|
||||
}
|
||||
|
||||
// IsRetries returns whether the Retries logging mode bit is set
|
||||
func (m ClientLogMode) IsRetries() bool {
|
||||
return m&LogRetries != 0
|
||||
}
|
||||
|
||||
// IsRequest returns whether the Request logging mode bit is set
|
||||
func (m ClientLogMode) IsRequest() bool {
|
||||
return m&LogRequest != 0
|
||||
}
|
||||
|
||||
// IsRequestWithBody returns whether the RequestWithBody logging mode bit is set
|
||||
func (m ClientLogMode) IsRequestWithBody() bool {
|
||||
return m&LogRequestWithBody != 0
|
||||
}
|
||||
|
||||
// IsResponse returns whether the Response logging mode bit is set
|
||||
func (m ClientLogMode) IsResponse() bool {
|
||||
return m&LogResponse != 0
|
||||
}
|
||||
|
||||
// IsResponseWithBody returns whether the ResponseWithBody logging mode bit is set
|
||||
func (m ClientLogMode) IsResponseWithBody() bool {
|
||||
return m&LogResponseWithBody != 0
|
||||
}
|
||||
|
||||
// IsDeprecatedUsage returns whether the DeprecatedUsage logging mode bit is set
|
||||
func (m ClientLogMode) IsDeprecatedUsage() bool {
|
||||
return m&LogDeprecatedUsage != 0
|
||||
}
|
||||
|
||||
// IsRequestEventMessage returns whether the RequestEventMessage logging mode bit is set
|
||||
func (m ClientLogMode) IsRequestEventMessage() bool {
|
||||
return m&LogRequestEventMessage != 0
|
||||
}
|
||||
|
||||
// IsResponseEventMessage returns whether the ResponseEventMessage logging mode bit is set
|
||||
func (m ClientLogMode) IsResponseEventMessage() bool {
|
||||
return m&LogResponseEventMessage != 0
|
||||
}
|
||||
|
||||
// ClearSigning clears the Signing logging mode bit
|
||||
func (m *ClientLogMode) ClearSigning() {
|
||||
*m &^= LogSigning
|
||||
}
|
||||
|
||||
// ClearRetries clears the Retries logging mode bit
|
||||
func (m *ClientLogMode) ClearRetries() {
|
||||
*m &^= LogRetries
|
||||
}
|
||||
|
||||
// ClearRequest clears the Request logging mode bit
|
||||
func (m *ClientLogMode) ClearRequest() {
|
||||
*m &^= LogRequest
|
||||
}
|
||||
|
||||
// ClearRequestWithBody clears the RequestWithBody logging mode bit
|
||||
func (m *ClientLogMode) ClearRequestWithBody() {
|
||||
*m &^= LogRequestWithBody
|
||||
}
|
||||
|
||||
// ClearResponse clears the Response logging mode bit
|
||||
func (m *ClientLogMode) ClearResponse() {
|
||||
*m &^= LogResponse
|
||||
}
|
||||
|
||||
// ClearResponseWithBody clears the ResponseWithBody logging mode bit
|
||||
func (m *ClientLogMode) ClearResponseWithBody() {
|
||||
*m &^= LogResponseWithBody
|
||||
}
|
||||
|
||||
// ClearDeprecatedUsage clears the DeprecatedUsage logging mode bit
|
||||
func (m *ClientLogMode) ClearDeprecatedUsage() {
|
||||
*m &^= LogDeprecatedUsage
|
||||
}
|
||||
|
||||
// ClearRequestEventMessage clears the RequestEventMessage logging mode bit
|
||||
func (m *ClientLogMode) ClearRequestEventMessage() {
|
||||
*m &^= LogRequestEventMessage
|
||||
}
|
||||
|
||||
// ClearResponseEventMessage clears the ResponseEventMessage logging mode bit
|
||||
func (m *ClientLogMode) ClearResponseEventMessage() {
|
||||
*m &^= LogResponseEventMessage
|
||||
}
|
@ -0,0 +1,95 @@
|
||||
//go:build clientlogmode
|
||||
// +build clientlogmode
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
var config = struct {
|
||||
ModeBits []string
|
||||
}{
|
||||
// Items should be appended only to keep bit-flag positions stable
|
||||
ModeBits: []string{
|
||||
"Signing",
|
||||
"Retries",
|
||||
"Request",
|
||||
"RequestWithBody",
|
||||
"Response",
|
||||
"ResponseWithBody",
|
||||
"DeprecatedUsage",
|
||||
"RequestEventMessage",
|
||||
"ResponseEventMessage",
|
||||
},
|
||||
}
|
||||
|
||||
func bitName(name string) string {
|
||||
return strings.ToUpper(name[:1]) + name[1:]
|
||||
}
|
||||
|
||||
var tmpl = template.Must(template.New("ClientLogMode").Funcs(map[string]interface{}{
|
||||
"symbolName": func(name string) string {
|
||||
return "Log" + bitName(name)
|
||||
},
|
||||
"bitName": bitName,
|
||||
}).Parse(`// Code generated by aws/logging_generate.go DO NOT EDIT.
|
||||
|
||||
package aws
|
||||
|
||||
// ClientLogMode represents the logging mode of SDK clients. The client logging mode is a bit-field where
|
||||
// each bit is a flag that describes the logging behavior for one or more client components.
|
||||
// The entire 64-bit group is reserved for later expansion by the SDK.
|
||||
//
|
||||
// Example: Setting ClientLogMode to enable logging of retries and requests
|
||||
// clientLogMode := aws.LogRetries | aws.LogRequest
|
||||
//
|
||||
// Example: Adding an additional log mode to an existing ClientLogMode value
|
||||
// clientLogMode |= aws.LogResponse
|
||||
type ClientLogMode uint64
|
||||
|
||||
// Supported ClientLogMode bits that can be configured to toggle logging of specific SDK events.
|
||||
const (
|
||||
{{- range $index, $field := .ModeBits }}
|
||||
{{ (symbolName $field) }}{{- if (eq 0 $index) }} ClientLogMode = 1 << (64 - 1 - iota){{- end }}
|
||||
{{- end }}
|
||||
)
|
||||
{{ range $_, $field := .ModeBits }}
|
||||
// Is{{- bitName $field }} returns whether the {{ bitName $field }} logging mode bit is set
|
||||
func (m ClientLogMode) Is{{- bitName $field }}() bool {
|
||||
return m&{{- (symbolName $field) }} != 0
|
||||
}
|
||||
{{ end }}
|
||||
{{- range $_, $field := .ModeBits }}
|
||||
// Clear{{- bitName $field }} clears the {{ bitName $field }} logging mode bit
|
||||
func (m *ClientLogMode) Clear{{- bitName $field }}() {
|
||||
*m &^= {{ (symbolName $field) }}
|
||||
}
|
||||
{{ end -}}
|
||||
`))
|
||||
|
||||
func main() {
|
||||
uniqueBitFields := make(map[string]struct{})
|
||||
|
||||
for _, bitName := range config.ModeBits {
|
||||
if _, ok := uniqueBitFields[strings.ToLower(bitName)]; ok {
|
||||
panic(fmt.Sprintf("duplicate bit field: %s", bitName))
|
||||
}
|
||||
uniqueBitFields[bitName] = struct{}{}
|
||||
}
|
||||
|
||||
file, err := os.Create("logging.go")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
err = tmpl.Execute(file, config)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
@ -0,0 +1,180 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
)
|
||||
|
||||
// RegisterServiceMetadata registers metadata about the service and operation into the middleware context
|
||||
// so that it is available at runtime for other middleware to introspect.
|
||||
type RegisterServiceMetadata struct {
|
||||
ServiceID string
|
||||
SigningName string
|
||||
Region string
|
||||
OperationName string
|
||||
}
|
||||
|
||||
// ID returns the middleware identifier.
|
||||
func (s *RegisterServiceMetadata) ID() string {
|
||||
return "RegisterServiceMetadata"
|
||||
}
|
||||
|
||||
// HandleInitialize registers service metadata information into the middleware context, allowing for introspection.
|
||||
func (s RegisterServiceMetadata) HandleInitialize(
|
||||
ctx context.Context, in middleware.InitializeInput, next middleware.InitializeHandler,
|
||||
) (out middleware.InitializeOutput, metadata middleware.Metadata, err error) {
|
||||
if len(s.ServiceID) > 0 {
|
||||
ctx = SetServiceID(ctx, s.ServiceID)
|
||||
}
|
||||
if len(s.SigningName) > 0 {
|
||||
ctx = SetSigningName(ctx, s.SigningName)
|
||||
}
|
||||
if len(s.Region) > 0 {
|
||||
ctx = setRegion(ctx, s.Region)
|
||||
}
|
||||
if len(s.OperationName) > 0 {
|
||||
ctx = setOperationName(ctx, s.OperationName)
|
||||
}
|
||||
return next.HandleInitialize(ctx, in)
|
||||
}
|
||||
|
||||
// service metadata keys for storing and lookup of runtime stack information.
|
||||
type (
|
||||
serviceIDKey struct{}
|
||||
signingNameKey struct{}
|
||||
signingRegionKey struct{}
|
||||
regionKey struct{}
|
||||
operationNameKey struct{}
|
||||
partitionIDKey struct{}
|
||||
)
|
||||
|
||||
// GetServiceID retrieves the service id from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetServiceID(ctx context.Context) (v string) {
|
||||
v, _ = middleware.GetStackValue(ctx, serviceIDKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// GetSigningName retrieves the service signing name from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetSigningName(ctx context.Context) (v string) {
|
||||
v, _ = middleware.GetStackValue(ctx, signingNameKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// GetSigningRegion retrieves the region from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetSigningRegion(ctx context.Context) (v string) {
|
||||
v, _ = middleware.GetStackValue(ctx, signingRegionKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// GetRegion retrieves the endpoint region from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetRegion(ctx context.Context) (v string) {
|
||||
v, _ = middleware.GetStackValue(ctx, regionKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// GetOperationName retrieves the service operation metadata from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetOperationName(ctx context.Context) (v string) {
|
||||
v, _ = middleware.GetStackValue(ctx, operationNameKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// GetPartitionID retrieves the endpoint partition id from the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func GetPartitionID(ctx context.Context) string {
|
||||
v, _ := middleware.GetStackValue(ctx, partitionIDKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
// SetSigningName set or modifies the signing name on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetSigningName(ctx context.Context, value string) context.Context {
|
||||
return middleware.WithStackValue(ctx, signingNameKey{}, value)
|
||||
}
|
||||
|
||||
// SetSigningRegion sets or modifies the region on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetSigningRegion(ctx context.Context, value string) context.Context {
|
||||
return middleware.WithStackValue(ctx, signingRegionKey{}, value)
|
||||
}
|
||||
|
||||
// SetServiceID sets the service id on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetServiceID(ctx context.Context, value string) context.Context {
|
||||
return middleware.WithStackValue(ctx, serviceIDKey{}, value)
|
||||
}
|
||||
|
||||
// setRegion sets the endpoint region on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func setRegion(ctx context.Context, value string) context.Context {
|
||||
return middleware.WithStackValue(ctx, regionKey{}, value)
|
||||
}
|
||||
|
||||
// setOperationName sets the service operation on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func setOperationName(ctx context.Context, value string) context.Context {
|
||||
return middleware.WithStackValue(ctx, operationNameKey{}, value)
|
||||
}
|
||||
|
||||
// SetPartitionID sets the partition id of a resolved region on the context
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func SetPartitionID(ctx context.Context, value string) context.Context {
|
||||
return middleware.WithStackValue(ctx, partitionIDKey{}, value)
|
||||
}
|
||||
|
||||
// EndpointSource key
|
||||
type endpointSourceKey struct{}
|
||||
|
||||
// GetEndpointSource returns an endpoint source if set on context
|
||||
func GetEndpointSource(ctx context.Context) (v aws.EndpointSource) {
|
||||
v, _ = middleware.GetStackValue(ctx, endpointSourceKey{}).(aws.EndpointSource)
|
||||
return v
|
||||
}
|
||||
|
||||
// SetEndpointSource sets endpoint source on context
|
||||
func SetEndpointSource(ctx context.Context, value aws.EndpointSource) context.Context {
|
||||
return middleware.WithStackValue(ctx, endpointSourceKey{}, value)
|
||||
}
|
||||
|
||||
type signingCredentialsKey struct{}
|
||||
|
||||
// GetSigningCredentials returns the credentials that were used for signing if set on context.
|
||||
func GetSigningCredentials(ctx context.Context) (v aws.Credentials) {
|
||||
v, _ = middleware.GetStackValue(ctx, signingCredentialsKey{}).(aws.Credentials)
|
||||
return v
|
||||
}
|
||||
|
||||
// SetSigningCredentials sets the credentails used for signing on the context.
|
||||
func SetSigningCredentials(ctx context.Context, value aws.Credentials) context.Context {
|
||||
return middleware.WithStackValue(ctx, signingCredentialsKey{}, value)
|
||||
}
|
@ -0,0 +1,168 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/internal/rand"
|
||||
"github.com/aws/aws-sdk-go-v2/internal/sdk"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyrand "github.com/aws/smithy-go/rand"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// ClientRequestID is a Smithy BuildMiddleware that will generate a unique ID for logical API operation
|
||||
// invocation.
|
||||
type ClientRequestID struct{}
|
||||
|
||||
// ID the identifier for the ClientRequestID
|
||||
func (r *ClientRequestID) ID() string {
|
||||
return "ClientRequestID"
|
||||
}
|
||||
|
||||
// HandleBuild attaches a unique operation invocation id for the operation to the request
|
||||
func (r ClientRequestID) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (
|
||||
out middleware.BuildOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, fmt.Errorf("unknown transport type %T", req)
|
||||
}
|
||||
|
||||
invocationID, err := smithyrand.NewUUID(rand.Reader).GetUUID()
|
||||
if err != nil {
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
const invocationIDHeader = "Amz-Sdk-Invocation-Id"
|
||||
req.Header[invocationIDHeader] = append(req.Header[invocationIDHeader][:0], invocationID)
|
||||
|
||||
return next.HandleBuild(ctx, in)
|
||||
}
|
||||
|
||||
// RecordResponseTiming records the response timing for the SDK client requests.
|
||||
type RecordResponseTiming struct{}
|
||||
|
||||
// ID is the middleware identifier
|
||||
func (a *RecordResponseTiming) ID() string {
|
||||
return "RecordResponseTiming"
|
||||
}
|
||||
|
||||
// HandleDeserialize calculates response metadata and clock skew
|
||||
func (a RecordResponseTiming) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
|
||||
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
out, metadata, err = next.HandleDeserialize(ctx, in)
|
||||
responseAt := sdk.NowTime()
|
||||
setResponseAt(&metadata, responseAt)
|
||||
|
||||
var serverTime time.Time
|
||||
|
||||
switch resp := out.RawResponse.(type) {
|
||||
case *smithyhttp.Response:
|
||||
respDateHeader := resp.Header.Get("Date")
|
||||
if len(respDateHeader) == 0 {
|
||||
break
|
||||
}
|
||||
var parseErr error
|
||||
serverTime, parseErr = smithyhttp.ParseTime(respDateHeader)
|
||||
if parseErr != nil {
|
||||
logger := middleware.GetLogger(ctx)
|
||||
logger.Logf(logging.Warn, "failed to parse response Date header value, got %v",
|
||||
parseErr.Error())
|
||||
break
|
||||
}
|
||||
setServerTime(&metadata, serverTime)
|
||||
}
|
||||
|
||||
if !serverTime.IsZero() {
|
||||
attemptSkew := serverTime.Sub(responseAt)
|
||||
setAttemptSkew(&metadata, attemptSkew)
|
||||
}
|
||||
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
type responseAtKey struct{}
|
||||
|
||||
// GetResponseAt returns the time response was received at.
|
||||
func GetResponseAt(metadata middleware.Metadata) (v time.Time, ok bool) {
|
||||
v, ok = metadata.Get(responseAtKey{}).(time.Time)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// setResponseAt sets the response time on the metadata.
|
||||
func setResponseAt(metadata *middleware.Metadata, v time.Time) {
|
||||
metadata.Set(responseAtKey{}, v)
|
||||
}
|
||||
|
||||
type serverTimeKey struct{}
|
||||
|
||||
// GetServerTime returns the server time for response.
|
||||
func GetServerTime(metadata middleware.Metadata) (v time.Time, ok bool) {
|
||||
v, ok = metadata.Get(serverTimeKey{}).(time.Time)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// setServerTime sets the server time on the metadata.
|
||||
func setServerTime(metadata *middleware.Metadata, v time.Time) {
|
||||
metadata.Set(serverTimeKey{}, v)
|
||||
}
|
||||
|
||||
type attemptSkewKey struct{}
|
||||
|
||||
// GetAttemptSkew returns Attempt clock skew for response from metadata.
|
||||
func GetAttemptSkew(metadata middleware.Metadata) (v time.Duration, ok bool) {
|
||||
v, ok = metadata.Get(attemptSkewKey{}).(time.Duration)
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// setAttemptSkew sets the attempt clock skew on the metadata.
|
||||
func setAttemptSkew(metadata *middleware.Metadata, v time.Duration) {
|
||||
metadata.Set(attemptSkewKey{}, v)
|
||||
}
|
||||
|
||||
// AddClientRequestIDMiddleware adds ClientRequestID to the middleware stack
|
||||
func AddClientRequestIDMiddleware(stack *middleware.Stack) error {
|
||||
return stack.Build.Add(&ClientRequestID{}, middleware.After)
|
||||
}
|
||||
|
||||
// AddRecordResponseTiming adds RecordResponseTiming middleware to the
|
||||
// middleware stack.
|
||||
func AddRecordResponseTiming(stack *middleware.Stack) error {
|
||||
return stack.Deserialize.Add(&RecordResponseTiming{}, middleware.After)
|
||||
}
|
||||
|
||||
// rawResponseKey is the accessor key used to store and access the
|
||||
// raw response within the response metadata.
|
||||
type rawResponseKey struct{}
|
||||
|
||||
// addRawResponse middleware adds raw response on to the metadata
|
||||
type addRawResponse struct{}
|
||||
|
||||
// ID the identifier for the ClientRequestID
|
||||
func (m *addRawResponse) ID() string {
|
||||
return "AddRawResponseToMetadata"
|
||||
}
|
||||
|
||||
// HandleDeserialize adds raw response on the middleware metadata
|
||||
func (m addRawResponse) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
|
||||
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
out, metadata, err = next.HandleDeserialize(ctx, in)
|
||||
metadata.Set(rawResponseKey{}, out.RawResponse)
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
// AddRawResponseToMetadata adds middleware to the middleware stack that
|
||||
// store raw response on to the metadata.
|
||||
func AddRawResponseToMetadata(stack *middleware.Stack) error {
|
||||
return stack.Deserialize.Add(&addRawResponse{}, middleware.Before)
|
||||
}
|
||||
|
||||
// GetRawResponse returns raw response set on metadata
|
||||
func GetRawResponse(metadata middleware.Metadata) interface{} {
|
||||
return metadata.Get(rawResponseKey{})
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
//go:build go1.16
|
||||
// +build go1.16
|
||||
|
||||
package middleware
|
||||
|
||||
import "runtime"
|
||||
|
||||
func getNormalizedOSName() (os string) {
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
os = "android"
|
||||
case "linux":
|
||||
os = "linux"
|
||||
case "windows":
|
||||
os = "windows"
|
||||
case "darwin":
|
||||
os = "macos"
|
||||
case "ios":
|
||||
os = "ios"
|
||||
default:
|
||||
os = "other"
|
||||
}
|
||||
return os
|
||||
}
|
@ -0,0 +1,24 @@
|
||||
//go:build !go1.16
|
||||
// +build !go1.16
|
||||
|
||||
package middleware
|
||||
|
||||
import "runtime"
|
||||
|
||||
func getNormalizedOSName() (os string) {
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
os = "android"
|
||||
case "linux":
|
||||
os = "linux"
|
||||
case "windows":
|
||||
os = "windows"
|
||||
case "darwin":
|
||||
// Due to Apple M1 we can't distinguish between macOS and iOS when GOOS/GOARCH is darwin/amd64
|
||||
// For now declare this as "other" until we have a better detection mechanism.
|
||||
fallthrough
|
||||
default:
|
||||
os = "other"
|
||||
}
|
||||
return os
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
)
|
||||
|
||||
// requestIDKey is used to retrieve request id from response metadata
|
||||
type requestIDKey struct{}
|
||||
|
||||
// SetRequestIDMetadata sets the provided request id over middleware metadata
|
||||
func SetRequestIDMetadata(metadata *middleware.Metadata, id string) {
|
||||
metadata.Set(requestIDKey{}, id)
|
||||
}
|
||||
|
||||
// GetRequestIDMetadata retrieves the request id from middleware metadata
|
||||
// returns string and bool indicating value of request id, whether request id was set.
|
||||
func GetRequestIDMetadata(metadata middleware.Metadata) (string, bool) {
|
||||
if !metadata.Has(requestIDKey{}) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
v, ok := metadata.Get(requestIDKey{}).(string)
|
||||
if !ok {
|
||||
return "", true
|
||||
}
|
||||
return v, true
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// AddRequestIDRetrieverMiddleware adds request id retriever middleware
|
||||
func AddRequestIDRetrieverMiddleware(stack *middleware.Stack) error {
|
||||
// add error wrapper middleware before operation deserializers so that it can wrap the error response
|
||||
// returned by operation deserializers
|
||||
return stack.Deserialize.Insert(&requestIDRetriever{}, "OperationDeserializer", middleware.Before)
|
||||
}
|
||||
|
||||
type requestIDRetriever struct {
|
||||
}
|
||||
|
||||
// ID returns the middleware identifier
|
||||
func (m *requestIDRetriever) ID() string {
|
||||
return "RequestIDRetriever"
|
||||
}
|
||||
|
||||
func (m *requestIDRetriever) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) (
|
||||
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
out, metadata, err = next.HandleDeserialize(ctx, in)
|
||||
|
||||
resp, ok := out.RawResponse.(*smithyhttp.Response)
|
||||
if !ok {
|
||||
// No raw response to wrap with.
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
// Different header which can map to request id
|
||||
requestIDHeaderList := []string{"X-Amzn-Requestid", "X-Amz-RequestId"}
|
||||
|
||||
for _, h := range requestIDHeaderList {
|
||||
// check for headers known to contain Request id
|
||||
if v := resp.Header.Get(h); len(v) != 0 {
|
||||
// set reqID on metadata for successful responses.
|
||||
SetRequestIDMetadata(&metadata, v)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return out, metadata, err
|
||||
}
|
@ -0,0 +1,243 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
var languageVersion = strings.TrimPrefix(runtime.Version(), "go")
|
||||
|
||||
// SDKAgentKeyType is the metadata type to add to the SDK agent string
|
||||
type SDKAgentKeyType int
|
||||
|
||||
// The set of valid SDKAgentKeyType constants. If an unknown value is assigned for SDKAgentKeyType it will
|
||||
// be mapped to AdditionalMetadata.
|
||||
const (
|
||||
_ SDKAgentKeyType = iota
|
||||
APIMetadata
|
||||
OperatingSystemMetadata
|
||||
LanguageMetadata
|
||||
EnvironmentMetadata
|
||||
FeatureMetadata
|
||||
ConfigMetadata
|
||||
FrameworkMetadata
|
||||
AdditionalMetadata
|
||||
ApplicationIdentifier
|
||||
)
|
||||
|
||||
func (k SDKAgentKeyType) string() string {
|
||||
switch k {
|
||||
case APIMetadata:
|
||||
return "api"
|
||||
case OperatingSystemMetadata:
|
||||
return "os"
|
||||
case LanguageMetadata:
|
||||
return "lang"
|
||||
case EnvironmentMetadata:
|
||||
return "exec-env"
|
||||
case FeatureMetadata:
|
||||
return "ft"
|
||||
case ConfigMetadata:
|
||||
return "cfg"
|
||||
case FrameworkMetadata:
|
||||
return "lib"
|
||||
case ApplicationIdentifier:
|
||||
return "app"
|
||||
case AdditionalMetadata:
|
||||
fallthrough
|
||||
default:
|
||||
return "md"
|
||||
}
|
||||
}
|
||||
|
||||
const execEnvVar = `AWS_EXECUTION_ENV`
|
||||
|
||||
// requestUserAgent is a build middleware that set the User-Agent for the request.
|
||||
type requestUserAgent struct {
|
||||
sdkAgent, userAgent *smithyhttp.UserAgentBuilder
|
||||
}
|
||||
|
||||
// newRequestUserAgent returns a new requestUserAgent which will set the User-Agent and X-Amz-User-Agent for the
|
||||
// request.
|
||||
//
|
||||
// User-Agent example:
|
||||
//
|
||||
// aws-sdk-go-v2/1.2.3
|
||||
//
|
||||
// X-Amz-User-Agent example:
|
||||
//
|
||||
// aws-sdk-go-v2/1.2.3 md/GOOS/linux md/GOARCH/amd64 lang/go/1.15
|
||||
func newRequestUserAgent() *requestUserAgent {
|
||||
userAgent, sdkAgent := smithyhttp.NewUserAgentBuilder(), smithyhttp.NewUserAgentBuilder()
|
||||
addProductName(userAgent)
|
||||
addProductName(sdkAgent)
|
||||
|
||||
r := &requestUserAgent{
|
||||
sdkAgent: sdkAgent,
|
||||
userAgent: userAgent,
|
||||
}
|
||||
|
||||
addSDKMetadata(r)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func addSDKMetadata(r *requestUserAgent) {
|
||||
r.AddSDKAgentKey(OperatingSystemMetadata, getNormalizedOSName())
|
||||
r.AddSDKAgentKeyValue(LanguageMetadata, "go", languageVersion)
|
||||
r.AddSDKAgentKeyValue(AdditionalMetadata, "GOOS", runtime.GOOS)
|
||||
r.AddSDKAgentKeyValue(AdditionalMetadata, "GOARCH", runtime.GOARCH)
|
||||
if ev := os.Getenv(execEnvVar); len(ev) > 0 {
|
||||
r.AddSDKAgentKey(EnvironmentMetadata, ev)
|
||||
}
|
||||
}
|
||||
|
||||
func addProductName(builder *smithyhttp.UserAgentBuilder) {
|
||||
builder.AddKeyValue(aws.SDKName, aws.SDKVersion)
|
||||
}
|
||||
|
||||
// AddUserAgentKey retrieves a requestUserAgent from the provided stack, or initializes one.
|
||||
func AddUserAgentKey(key string) func(*middleware.Stack) error {
|
||||
return func(stack *middleware.Stack) error {
|
||||
requestUserAgent, err := getOrAddRequestUserAgent(stack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
requestUserAgent.AddUserAgentKey(key)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddUserAgentKeyValue retrieves a requestUserAgent from the provided stack, or initializes one.
|
||||
func AddUserAgentKeyValue(key, value string) func(*middleware.Stack) error {
|
||||
return func(stack *middleware.Stack) error {
|
||||
requestUserAgent, err := getOrAddRequestUserAgent(stack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
requestUserAgent.AddUserAgentKeyValue(key, value)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddSDKAgentKey retrieves a requestUserAgent from the provided stack, or initializes one.
|
||||
func AddSDKAgentKey(keyType SDKAgentKeyType, key string) func(*middleware.Stack) error {
|
||||
return func(stack *middleware.Stack) error {
|
||||
requestUserAgent, err := getOrAddRequestUserAgent(stack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
requestUserAgent.AddSDKAgentKey(keyType, key)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddSDKAgentKeyValue retrieves a requestUserAgent from the provided stack, or initializes one.
|
||||
func AddSDKAgentKeyValue(keyType SDKAgentKeyType, key, value string) func(*middleware.Stack) error {
|
||||
return func(stack *middleware.Stack) error {
|
||||
requestUserAgent, err := getOrAddRequestUserAgent(stack)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
requestUserAgent.AddSDKAgentKeyValue(keyType, key, value)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddRequestUserAgentMiddleware registers a requestUserAgent middleware on the stack if not present.
|
||||
func AddRequestUserAgentMiddleware(stack *middleware.Stack) error {
|
||||
_, err := getOrAddRequestUserAgent(stack)
|
||||
return err
|
||||
}
|
||||
|
||||
func getOrAddRequestUserAgent(stack *middleware.Stack) (*requestUserAgent, error) {
|
||||
id := (*requestUserAgent)(nil).ID()
|
||||
bm, ok := stack.Build.Get(id)
|
||||
if !ok {
|
||||
bm = newRequestUserAgent()
|
||||
err := stack.Build.Add(bm, middleware.After)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
requestUserAgent, ok := bm.(*requestUserAgent)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%T for %s middleware did not match expected type", bm, id)
|
||||
}
|
||||
|
||||
return requestUserAgent, nil
|
||||
}
|
||||
|
||||
// AddUserAgentKey adds the component identified by name to the User-Agent string.
|
||||
func (u *requestUserAgent) AddUserAgentKey(key string) {
|
||||
u.userAgent.AddKey(key)
|
||||
}
|
||||
|
||||
// AddUserAgentKeyValue adds the key identified by the given name and value to the User-Agent string.
|
||||
func (u *requestUserAgent) AddUserAgentKeyValue(key, value string) {
|
||||
u.userAgent.AddKeyValue(key, value)
|
||||
}
|
||||
|
||||
// AddUserAgentKey adds the component identified by name to the User-Agent string.
|
||||
func (u *requestUserAgent) AddSDKAgentKey(keyType SDKAgentKeyType, key string) {
|
||||
// TODO: should target sdkAgent
|
||||
u.userAgent.AddKey(keyType.string() + "/" + key)
|
||||
}
|
||||
|
||||
// AddUserAgentKeyValue adds the key identified by the given name and value to the User-Agent string.
|
||||
func (u *requestUserAgent) AddSDKAgentKeyValue(keyType SDKAgentKeyType, key, value string) {
|
||||
// TODO: should target sdkAgent
|
||||
u.userAgent.AddKeyValue(keyType.string()+"/"+key, value)
|
||||
}
|
||||
|
||||
// ID the name of the middleware.
|
||||
func (u *requestUserAgent) ID() string {
|
||||
return "UserAgent"
|
||||
}
|
||||
|
||||
// HandleBuild adds or appends the constructed user agent to the request.
|
||||
func (u *requestUserAgent) HandleBuild(ctx context.Context, in middleware.BuildInput, next middleware.BuildHandler) (
|
||||
out middleware.BuildOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
switch req := in.Request.(type) {
|
||||
case *smithyhttp.Request:
|
||||
u.addHTTPUserAgent(req)
|
||||
// TODO: To be re-enabled
|
||||
// u.addHTTPSDKAgent(req)
|
||||
default:
|
||||
return out, metadata, fmt.Errorf("unknown transport type %T", in)
|
||||
}
|
||||
|
||||
return next.HandleBuild(ctx, in)
|
||||
}
|
||||
|
||||
func (u *requestUserAgent) addHTTPUserAgent(request *smithyhttp.Request) {
|
||||
const userAgent = "User-Agent"
|
||||
updateHTTPHeader(request, userAgent, u.userAgent.Build())
|
||||
}
|
||||
|
||||
func (u *requestUserAgent) addHTTPSDKAgent(request *smithyhttp.Request) {
|
||||
const sdkAgent = "X-Amz-User-Agent"
|
||||
updateHTTPHeader(request, sdkAgent, u.sdkAgent.Build())
|
||||
}
|
||||
|
||||
func updateHTTPHeader(request *smithyhttp.Request, header string, value string) {
|
||||
var current string
|
||||
if v := request.Header[header]; len(v) > 0 {
|
||||
current = v[0]
|
||||
}
|
||||
if len(current) > 0 {
|
||||
current = value + " " + current
|
||||
} else {
|
||||
current = value
|
||||
}
|
||||
request.Header[header] = append(request.Header[header][:0], current)
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
# v1.4.10 (2022-12-02)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.9 (2022-10-24)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.8 (2022-09-14)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.7 (2022-09-02)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.6 (2022-08-31)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.5 (2022-08-29)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.4 (2022-08-09)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.3 (2022-06-29)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.2 (2022-06-07)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.1 (2022-03-24)
|
||||
|
||||
* No change notes available for this release.
|
||||
|
||||
# v1.4.0 (2022-03-08)
|
||||
|
||||
* **Feature**: Updated `github.com/aws/smithy-go` to latest version
|
||||
|
||||
# v1.3.0 (2022-02-24)
|
||||
|
||||
* **Feature**: Updated `github.com/aws/smithy-go` to latest version
|
||||
|
||||
# v1.2.0 (2022-01-14)
|
||||
|
||||
* **Feature**: Updated `github.com/aws/smithy-go` to latest version
|
||||
|
||||
# v1.1.0 (2022-01-07)
|
||||
|
||||
* **Feature**: Updated `github.com/aws/smithy-go` to latest version
|
||||
|
||||
# v1.0.0 (2021-11-06)
|
||||
|
||||
* **Announcement**: Support has been added for AWS EventStream APIs for Kinesis, S3, and Transcribe Streaming. Support for the Lex Runtime V2 EventStream API will be added in a future release.
|
||||
* **Release**: Protocol support has been added for AWS event stream.
|
||||
* **Feature**: Updated `github.com/aws/smithy-go` to latest version
|
||||
|
@ -0,0 +1,144 @@
|
||||
package eventstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type decodedMessage struct {
|
||||
rawMessage
|
||||
Headers decodedHeaders `json:"headers"`
|
||||
}
|
||||
type jsonMessage struct {
|
||||
Length json.Number `json:"total_length"`
|
||||
HeadersLen json.Number `json:"headers_length"`
|
||||
PreludeCRC json.Number `json:"prelude_crc"`
|
||||
Headers decodedHeaders `json:"headers"`
|
||||
Payload []byte `json:"payload"`
|
||||
CRC json.Number `json:"message_crc"`
|
||||
}
|
||||
|
||||
func (d *decodedMessage) UnmarshalJSON(b []byte) (err error) {
|
||||
var jsonMsg jsonMessage
|
||||
if err = json.Unmarshal(b, &jsonMsg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.Length, err = numAsUint32(jsonMsg.Length)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.HeadersLen, err = numAsUint32(jsonMsg.HeadersLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.PreludeCRC, err = numAsUint32(jsonMsg.PreludeCRC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.Headers = jsonMsg.Headers
|
||||
d.Payload = jsonMsg.Payload
|
||||
d.CRC, err = numAsUint32(jsonMsg.CRC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *decodedMessage) MarshalJSON() ([]byte, error) {
|
||||
jsonMsg := jsonMessage{
|
||||
Length: json.Number(strconv.Itoa(int(d.Length))),
|
||||
HeadersLen: json.Number(strconv.Itoa(int(d.HeadersLen))),
|
||||
PreludeCRC: json.Number(strconv.Itoa(int(d.PreludeCRC))),
|
||||
Headers: d.Headers,
|
||||
Payload: d.Payload,
|
||||
CRC: json.Number(strconv.Itoa(int(d.CRC))),
|
||||
}
|
||||
|
||||
return json.Marshal(jsonMsg)
|
||||
}
|
||||
|
||||
func numAsUint32(n json.Number) (uint32, error) {
|
||||
v, err := n.Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get int64 json number, %v", err)
|
||||
}
|
||||
|
||||
return uint32(v), nil
|
||||
}
|
||||
|
||||
func (d decodedMessage) Message() Message {
|
||||
return Message{
|
||||
Headers: Headers(d.Headers),
|
||||
Payload: d.Payload,
|
||||
}
|
||||
}
|
||||
|
||||
type decodedHeaders Headers
|
||||
|
||||
func (hs *decodedHeaders) UnmarshalJSON(b []byte) error {
|
||||
var jsonHeaders []struct {
|
||||
Name string `json:"name"`
|
||||
Type valueType `json:"type"`
|
||||
Value interface{} `json:"value"`
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(bytes.NewReader(b))
|
||||
decoder.UseNumber()
|
||||
if err := decoder.Decode(&jsonHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var headers Headers
|
||||
for _, h := range jsonHeaders {
|
||||
value, err := valueFromType(h.Type, h.Value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
headers.Set(h.Name, value)
|
||||
}
|
||||
*hs = decodedHeaders(headers)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func valueFromType(typ valueType, val interface{}) (Value, error) {
|
||||
switch typ {
|
||||
case trueValueType:
|
||||
return BoolValue(true), nil
|
||||
case falseValueType:
|
||||
return BoolValue(false), nil
|
||||
case int8ValueType:
|
||||
v, err := val.(json.Number).Int64()
|
||||
return Int8Value(int8(v)), err
|
||||
case int16ValueType:
|
||||
v, err := val.(json.Number).Int64()
|
||||
return Int16Value(int16(v)), err
|
||||
case int32ValueType:
|
||||
v, err := val.(json.Number).Int64()
|
||||
return Int32Value(int32(v)), err
|
||||
case int64ValueType:
|
||||
v, err := val.(json.Number).Int64()
|
||||
return Int64Value(v), err
|
||||
case bytesValueType:
|
||||
v, err := base64.StdEncoding.DecodeString(val.(string))
|
||||
return BytesValue(v), err
|
||||
case stringValueType:
|
||||
v, err := base64.StdEncoding.DecodeString(val.(string))
|
||||
return StringValue(string(v)), err
|
||||
case timestampValueType:
|
||||
v, err := val.(json.Number).Int64()
|
||||
return TimestampValue(timeFromEpochMilli(v)), err
|
||||
case uuidValueType:
|
||||
v, err := base64.StdEncoding.DecodeString(val.(string))
|
||||
var tv UUIDValue
|
||||
copy(tv[:], v)
|
||||
return tv, err
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown type, %s, %T", typ.String(), val))
|
||||
}
|
||||
}
|
@ -0,0 +1,218 @@
|
||||
package eventstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"hash"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
)
|
||||
|
||||
// DecoderOptions is the Decoder configuration options.
|
||||
type DecoderOptions struct {
|
||||
Logger logging.Logger
|
||||
LogMessages bool
|
||||
}
|
||||
|
||||
// Decoder provides decoding of an Event Stream messages.
|
||||
type Decoder struct {
|
||||
options DecoderOptions
|
||||
}
|
||||
|
||||
// NewDecoder initializes and returns a Decoder for decoding event
|
||||
// stream messages from the reader provided.
|
||||
func NewDecoder(optFns ...func(*DecoderOptions)) *Decoder {
|
||||
options := DecoderOptions{}
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&options)
|
||||
}
|
||||
|
||||
return &Decoder{
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
// Decode attempts to decode a single message from the event stream reader.
|
||||
// Will return the event stream message, or error if decodeMessage fails to read
|
||||
// the message from the stream.
|
||||
//
|
||||
// payloadBuf is a byte slice that will be used in the returned Message.Payload. Callers
|
||||
// must ensure that the Message.Payload from a previous decode has been consumed before passing in the same underlying
|
||||
// payloadBuf byte slice.
|
||||
func (d *Decoder) Decode(reader io.Reader, payloadBuf []byte) (m Message, err error) {
|
||||
if d.options.Logger != nil && d.options.LogMessages {
|
||||
debugMsgBuf := bytes.NewBuffer(nil)
|
||||
reader = io.TeeReader(reader, debugMsgBuf)
|
||||
defer func() {
|
||||
logMessageDecode(d.options.Logger, debugMsgBuf, m, err)
|
||||
}()
|
||||
}
|
||||
|
||||
m, err = decodeMessage(reader, payloadBuf)
|
||||
|
||||
return m, err
|
||||
}
|
||||
|
||||
// decodeMessage attempts to decode a single message from the event stream reader.
|
||||
// Will return the event stream message, or error if decodeMessage fails to read
|
||||
// the message from the reader.
|
||||
func decodeMessage(reader io.Reader, payloadBuf []byte) (m Message, err error) {
|
||||
crc := crc32.New(crc32IEEETable)
|
||||
hashReader := io.TeeReader(reader, crc)
|
||||
|
||||
prelude, err := decodePrelude(hashReader, crc)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
|
||||
if prelude.HeadersLen > 0 {
|
||||
lr := io.LimitReader(hashReader, int64(prelude.HeadersLen))
|
||||
m.Headers, err = decodeHeaders(lr)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if payloadLen := prelude.PayloadLen(); payloadLen > 0 {
|
||||
buf, err := decodePayload(payloadBuf, io.LimitReader(hashReader, int64(payloadLen)))
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
m.Payload = buf
|
||||
}
|
||||
|
||||
msgCRC := crc.Sum32()
|
||||
if err := validateCRC(reader, msgCRC); err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func logMessageDecode(logger logging.Logger, msgBuf *bytes.Buffer, msg Message, decodeErr error) {
|
||||
w := bytes.NewBuffer(nil)
|
||||
defer func() { logger.Logf(logging.Debug, w.String()) }()
|
||||
|
||||
fmt.Fprintf(w, "Raw message:\n%s\n",
|
||||
hex.Dump(msgBuf.Bytes()))
|
||||
|
||||
if decodeErr != nil {
|
||||
fmt.Fprintf(w, "decodeMessage error: %v\n", decodeErr)
|
||||
return
|
||||
}
|
||||
|
||||
rawMsg, err := msg.rawMessage()
|
||||
if err != nil {
|
||||
fmt.Fprintf(w, "failed to create raw message, %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
decodedMsg := decodedMessage{
|
||||
rawMessage: rawMsg,
|
||||
Headers: decodedHeaders(msg.Headers),
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "Decoded message:\n")
|
||||
encoder := json.NewEncoder(w)
|
||||
if err := encoder.Encode(decodedMsg); err != nil {
|
||||
fmt.Fprintf(w, "failed to generate decoded message, %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func decodePrelude(r io.Reader, crc hash.Hash32) (messagePrelude, error) {
|
||||
var p messagePrelude
|
||||
|
||||
var err error
|
||||
p.Length, err = decodeUint32(r)
|
||||
if err != nil {
|
||||
return messagePrelude{}, err
|
||||
}
|
||||
|
||||
p.HeadersLen, err = decodeUint32(r)
|
||||
if err != nil {
|
||||
return messagePrelude{}, err
|
||||
}
|
||||
|
||||
if err := p.ValidateLens(); err != nil {
|
||||
return messagePrelude{}, err
|
||||
}
|
||||
|
||||
preludeCRC := crc.Sum32()
|
||||
if err := validateCRC(r, preludeCRC); err != nil {
|
||||
return messagePrelude{}, err
|
||||
}
|
||||
|
||||
p.PreludeCRC = preludeCRC
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func decodePayload(buf []byte, r io.Reader) ([]byte, error) {
|
||||
w := bytes.NewBuffer(buf[0:0])
|
||||
|
||||
_, err := io.Copy(w, r)
|
||||
return w.Bytes(), err
|
||||
}
|
||||
|
||||
func decodeUint8(r io.Reader) (uint8, error) {
|
||||
type byteReader interface {
|
||||
ReadByte() (byte, error)
|
||||
}
|
||||
|
||||
if br, ok := r.(byteReader); ok {
|
||||
v, err := br.ReadByte()
|
||||
return v, err
|
||||
}
|
||||
|
||||
var b [1]byte
|
||||
_, err := io.ReadFull(r, b[:])
|
||||
return b[0], err
|
||||
}
|
||||
|
||||
func decodeUint16(r io.Reader) (uint16, error) {
|
||||
var b [2]byte
|
||||
bs := b[:]
|
||||
_, err := io.ReadFull(r, bs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return binary.BigEndian.Uint16(bs), nil
|
||||
}
|
||||
|
||||
func decodeUint32(r io.Reader) (uint32, error) {
|
||||
var b [4]byte
|
||||
bs := b[:]
|
||||
_, err := io.ReadFull(r, bs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return binary.BigEndian.Uint32(bs), nil
|
||||
}
|
||||
|
||||
func decodeUint64(r io.Reader) (uint64, error) {
|
||||
var b [8]byte
|
||||
bs := b[:]
|
||||
_, err := io.ReadFull(r, bs)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return binary.BigEndian.Uint64(bs), nil
|
||||
}
|
||||
|
||||
func validateCRC(r io.Reader, expect uint32) error {
|
||||
msgCRC, err := decodeUint32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msgCRC != expect {
|
||||
return ChecksumError{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -0,0 +1,167 @@
|
||||
package eventstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"hash"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
)
|
||||
|
||||
// EncoderOptions is the configuration options for Encoder.
|
||||
type EncoderOptions struct {
|
||||
Logger logging.Logger
|
||||
LogMessages bool
|
||||
}
|
||||
|
||||
// Encoder provides EventStream message encoding.
|
||||
type Encoder struct {
|
||||
options EncoderOptions
|
||||
|
||||
headersBuf *bytes.Buffer
|
||||
messageBuf *bytes.Buffer
|
||||
}
|
||||
|
||||
// NewEncoder initializes and returns an Encoder to encode Event Stream
|
||||
// messages.
|
||||
func NewEncoder(optFns ...func(*EncoderOptions)) *Encoder {
|
||||
o := EncoderOptions{}
|
||||
|
||||
for _, fn := range optFns {
|
||||
fn(&o)
|
||||
}
|
||||
|
||||
return &Encoder{
|
||||
options: o,
|
||||
headersBuf: bytes.NewBuffer(nil),
|
||||
messageBuf: bytes.NewBuffer(nil),
|
||||
}
|
||||
}
|
||||
|
||||
// Encode encodes a single EventStream message to the io.Writer the Encoder
|
||||
// was created with. An error is returned if writing the message fails.
|
||||
func (e *Encoder) Encode(w io.Writer, msg Message) (err error) {
|
||||
e.headersBuf.Reset()
|
||||
e.messageBuf.Reset()
|
||||
|
||||
var writer io.Writer = e.messageBuf
|
||||
if e.options.Logger != nil && e.options.LogMessages {
|
||||
encodeMsgBuf := bytes.NewBuffer(nil)
|
||||
writer = io.MultiWriter(writer, encodeMsgBuf)
|
||||
defer func() {
|
||||
logMessageEncode(e.options.Logger, encodeMsgBuf, msg, err)
|
||||
}()
|
||||
}
|
||||
|
||||
if err = EncodeHeaders(e.headersBuf, msg.Headers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
crc := crc32.New(crc32IEEETable)
|
||||
hashWriter := io.MultiWriter(writer, crc)
|
||||
|
||||
headersLen := uint32(e.headersBuf.Len())
|
||||
payloadLen := uint32(len(msg.Payload))
|
||||
|
||||
if err = encodePrelude(hashWriter, crc, headersLen, payloadLen); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if headersLen > 0 {
|
||||
if _, err = io.Copy(hashWriter, e.headersBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if payloadLen > 0 {
|
||||
if _, err = hashWriter.Write(msg.Payload); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
msgCRC := crc.Sum32()
|
||||
if err := binary.Write(writer, binary.BigEndian, msgCRC); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(w, e.messageBuf)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func logMessageEncode(logger logging.Logger, msgBuf *bytes.Buffer, msg Message, encodeErr error) {
|
||||
w := bytes.NewBuffer(nil)
|
||||
defer func() { logger.Logf(logging.Debug, w.String()) }()
|
||||
|
||||
fmt.Fprintf(w, "Message to encode:\n")
|
||||
encoder := json.NewEncoder(w)
|
||||
if err := encoder.Encode(msg); err != nil {
|
||||
fmt.Fprintf(w, "Failed to get encoded message, %v\n", err)
|
||||
}
|
||||
|
||||
if encodeErr != nil {
|
||||
fmt.Fprintf(w, "Encode error: %v\n", encodeErr)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "Raw message:\n%s\n", hex.Dump(msgBuf.Bytes()))
|
||||
}
|
||||
|
||||
func encodePrelude(w io.Writer, crc hash.Hash32, headersLen, payloadLen uint32) error {
|
||||
p := messagePrelude{
|
||||
Length: minMsgLen + headersLen + payloadLen,
|
||||
HeadersLen: headersLen,
|
||||
}
|
||||
if err := p.ValidateLens(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := binaryWriteFields(w, binary.BigEndian,
|
||||
p.Length,
|
||||
p.HeadersLen,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.PreludeCRC = crc.Sum32()
|
||||
err = binary.Write(w, binary.BigEndian, p.PreludeCRC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeHeaders writes the header values to the writer encoded in the event
|
||||
// stream format. Returns an error if a header fails to encode.
|
||||
func EncodeHeaders(w io.Writer, headers Headers) error {
|
||||
for _, h := range headers {
|
||||
hn := headerName{
|
||||
Len: uint8(len(h.Name)),
|
||||
}
|
||||
copy(hn.Name[:hn.Len], h.Name)
|
||||
if err := hn.encode(w); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.Value.encode(w); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func binaryWriteFields(w io.Writer, order binary.ByteOrder, vs ...interface{}) error {
|
||||
for _, v := range vs {
|
||||
if err := binary.Write(w, order, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package eventstream
|
||||
|
||||
import "fmt"
|
||||
|
||||
// LengthError provides the error for items being larger than a maximum length.
|
||||
type LengthError struct {
|
||||
Part string
|
||||
Want int
|
||||
Have int
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (e LengthError) Error() string {
|
||||
return fmt.Sprintf("%s length invalid, %d/%d, %v",
|
||||
e.Part, e.Want, e.Have, e.Value)
|
||||
}
|
||||
|
||||
// ChecksumError provides the error for message checksum invalidation errors.
|
||||
type ChecksumError struct{}
|
||||
|
||||
func (e ChecksumError) Error() string {
|
||||
return "message checksum mismatch"
|
||||
}
|
24
vendor/github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi/headers.go
generated
vendored
24
vendor/github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi/headers.go
generated
vendored
@ -0,0 +1,24 @@
|
||||
package eventstreamapi
|
||||
|
||||
// EventStream headers with specific meaning to async API functionality.
|
||||
const (
|
||||
ChunkSignatureHeader = `:chunk-signature` // chunk signature for message
|
||||
DateHeader = `:date` // Date header for signature
|
||||
ContentTypeHeader = ":content-type" // message payload content-type
|
||||
|
||||
// Message header and values
|
||||
MessageTypeHeader = `:message-type` // Identifies type of message.
|
||||
EventMessageType = `event`
|
||||
ErrorMessageType = `error`
|
||||
ExceptionMessageType = `exception`
|
||||
|
||||
// Message Events
|
||||
EventTypeHeader = `:event-type` // Identifies message event type e.g. "Stats".
|
||||
|
||||
// Message Error
|
||||
ErrorCodeHeader = `:error-code`
|
||||
ErrorMessageHeader = `:error-message`
|
||||
|
||||
// Message Exception
|
||||
ExceptionTypeHeader = `:exception-type`
|
||||
)
|
@ -0,0 +1,71 @@
|
||||
package eventstreamapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
"io"
|
||||
)
|
||||
|
||||
type eventStreamWriterKey struct{}
|
||||
|
||||
// GetInputStreamWriter returns EventTypeHeader io.PipeWriter used for the operation's input event stream.
|
||||
func GetInputStreamWriter(ctx context.Context) io.WriteCloser {
|
||||
writeCloser, _ := middleware.GetStackValue(ctx, eventStreamWriterKey{}).(io.WriteCloser)
|
||||
return writeCloser
|
||||
}
|
||||
|
||||
func setInputStreamWriter(ctx context.Context, writeCloser io.WriteCloser) context.Context {
|
||||
return middleware.WithStackValue(ctx, eventStreamWriterKey{}, writeCloser)
|
||||
}
|
||||
|
||||
// InitializeStreamWriter is a Finalize middleware initializes an in-memory pipe for sending event stream messages
|
||||
// via the HTTP request body.
|
||||
type InitializeStreamWriter struct{}
|
||||
|
||||
// AddInitializeStreamWriter adds the InitializeStreamWriter middleware to the provided stack.
|
||||
func AddInitializeStreamWriter(stack *middleware.Stack) error {
|
||||
return stack.Finalize.Add(&InitializeStreamWriter{}, middleware.After)
|
||||
}
|
||||
|
||||
// ID returns the identifier for the middleware.
|
||||
func (i *InitializeStreamWriter) ID() string {
|
||||
return "InitializeStreamWriter"
|
||||
}
|
||||
|
||||
// HandleFinalize is the middleware implementation.
|
||||
func (i *InitializeStreamWriter) HandleFinalize(
|
||||
ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler,
|
||||
) (
|
||||
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
request, ok := in.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, fmt.Errorf("unknown transport type: %T", in.Request)
|
||||
}
|
||||
|
||||
inputReader, inputWriter := io.Pipe()
|
||||
defer func() {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
_ = inputReader.Close()
|
||||
_ = inputWriter.Close()
|
||||
}()
|
||||
|
||||
request, err = request.SetStream(inputReader)
|
||||
if err != nil {
|
||||
return out, metadata, err
|
||||
}
|
||||
in.Request = request
|
||||
|
||||
ctx = setInputStreamWriter(ctx, inputWriter)
|
||||
|
||||
out, metadata, err = next.HandleFinalize(ctx, in)
|
||||
if err != nil {
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
return out, metadata, err
|
||||
}
|
13
vendor/github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi/transport.go
generated
vendored
13
vendor/github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/eventstreamapi/transport.go
generated
vendored
@ -0,0 +1,13 @@
|
||||
//go:build go1.18
|
||||
// +build go1.18
|
||||
|
||||
package eventstreamapi
|
||||
|
||||
import smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
|
||||
// ApplyHTTPTransportFixes applies fixes to the HTTP request for proper event stream functionality.
|
||||
//
|
||||
// This operation is a no-op for Go 1.18 and above.
|
||||
func ApplyHTTPTransportFixes(r *smithyhttp.Request) error {
|
||||
return nil
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
//go:build !go1.18
|
||||
// +build !go1.18
|
||||
|
||||
package eventstreamapi
|
||||
|
||||
import smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
|
||||
// ApplyHTTPTransportFixes applies fixes to the HTTP request for proper event stream functionality.
|
||||
func ApplyHTTPTransportFixes(r *smithyhttp.Request) error {
|
||||
r.Header.Set("Expect", "100-continue")
|
||||
return nil
|
||||
}
|
6
vendor/github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/go_module_metadata.go
generated
vendored
6
vendor/github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream/go_module_metadata.go
generated
vendored
@ -0,0 +1,6 @@
|
||||
// Code generated by internal/repotools/cmd/updatemodulemeta DO NOT EDIT.
|
||||
|
||||
package eventstream
|
||||
|
||||
// goModuleVersion is the tagged release for this module
|
||||
const goModuleVersion = "1.4.10"
|
@ -0,0 +1,175 @@
|
||||
package eventstream
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Headers are a collection of EventStream header values.
|
||||
type Headers []Header
|
||||
|
||||
// Header is a single EventStream Key Value header pair.
|
||||
type Header struct {
|
||||
Name string
|
||||
Value Value
|
||||
}
|
||||
|
||||
// Set associates the name with a value. If the header name already exists in
|
||||
// the Headers the value will be replaced with the new one.
|
||||
func (hs *Headers) Set(name string, value Value) {
|
||||
var i int
|
||||
for ; i < len(*hs); i++ {
|
||||
if (*hs)[i].Name == name {
|
||||
(*hs)[i].Value = value
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
*hs = append(*hs, Header{
|
||||
Name: name, Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
// Get returns the Value associated with the header. Nil is returned if the
|
||||
// value does not exist.
|
||||
func (hs Headers) Get(name string) Value {
|
||||
for i := 0; i < len(hs); i++ {
|
||||
if h := hs[i]; h.Name == name {
|
||||
return h.Value
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Del deletes the value in the Headers if it exists.
|
||||
func (hs *Headers) Del(name string) {
|
||||
for i := 0; i < len(*hs); i++ {
|
||||
if (*hs)[i].Name == name {
|
||||
copy((*hs)[i:], (*hs)[i+1:])
|
||||
(*hs) = (*hs)[:len(*hs)-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of the headers
|
||||
func (hs Headers) Clone() Headers {
|
||||
o := make(Headers, 0, len(hs))
|
||||
for _, h := range hs {
|
||||
o.Set(h.Name, h.Value)
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
func decodeHeaders(r io.Reader) (Headers, error) {
|
||||
hs := Headers{}
|
||||
|
||||
for {
|
||||
name, err := decodeHeaderName(r)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
// EOF while getting header name means no more headers
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
value, err := decodeHeaderValue(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hs.Set(name, value)
|
||||
}
|
||||
|
||||
return hs, nil
|
||||
}
|
||||
|
||||
func decodeHeaderName(r io.Reader) (string, error) {
|
||||
var n headerName
|
||||
|
||||
var err error
|
||||
n.Len, err = decodeUint8(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
name := n.Name[:n.Len]
|
||||
if _, err := io.ReadFull(r, name); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(name), nil
|
||||
}
|
||||
|
||||
func decodeHeaderValue(r io.Reader) (Value, error) {
|
||||
var raw rawValue
|
||||
|
||||
typ, err := decodeUint8(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw.Type = valueType(typ)
|
||||
|
||||
var v Value
|
||||
|
||||
switch raw.Type {
|
||||
case trueValueType:
|
||||
v = BoolValue(true)
|
||||
case falseValueType:
|
||||
v = BoolValue(false)
|
||||
case int8ValueType:
|
||||
var tv Int8Value
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
case int16ValueType:
|
||||
var tv Int16Value
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
case int32ValueType:
|
||||
var tv Int32Value
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
case int64ValueType:
|
||||
var tv Int64Value
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
case bytesValueType:
|
||||
var tv BytesValue
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
case stringValueType:
|
||||
var tv StringValue
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
case timestampValueType:
|
||||
var tv TimestampValue
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
case uuidValueType:
|
||||
var tv UUIDValue
|
||||
err = tv.decode(r)
|
||||
v = tv
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown value type %d", raw.Type))
|
||||
}
|
||||
|
||||
// Error could be EOF, let caller deal with it
|
||||
return v, err
|
||||
}
|
||||
|
||||
const maxHeaderNameLen = 255
|
||||
|
||||
type headerName struct {
|
||||
Len uint8
|
||||
Name [maxHeaderNameLen]byte
|
||||
}
|
||||
|
||||
func (v headerName) encode(w io.Writer) error {
|
||||
if err := binary.Write(w, binary.BigEndian, v.Len); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := w.Write(v.Name[:v.Len])
|
||||
return err
|
||||
}
|
@ -0,0 +1,521 @@
|
||||
package eventstream
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const maxHeaderValueLen = 1<<15 - 1 // 2^15-1 or 32KB - 1
|
||||
|
||||
// valueType is the EventStream header value type.
|
||||
type valueType uint8
|
||||
|
||||
// Header value types
|
||||
const (
|
||||
trueValueType valueType = iota
|
||||
falseValueType
|
||||
int8ValueType // Byte
|
||||
int16ValueType // Short
|
||||
int32ValueType // Integer
|
||||
int64ValueType // Long
|
||||
bytesValueType
|
||||
stringValueType
|
||||
timestampValueType
|
||||
uuidValueType
|
||||
)
|
||||
|
||||
func (t valueType) String() string {
|
||||
switch t {
|
||||
case trueValueType:
|
||||
return "bool"
|
||||
case falseValueType:
|
||||
return "bool"
|
||||
case int8ValueType:
|
||||
return "int8"
|
||||
case int16ValueType:
|
||||
return "int16"
|
||||
case int32ValueType:
|
||||
return "int32"
|
||||
case int64ValueType:
|
||||
return "int64"
|
||||
case bytesValueType:
|
||||
return "byte_array"
|
||||
case stringValueType:
|
||||
return "string"
|
||||
case timestampValueType:
|
||||
return "timestamp"
|
||||
case uuidValueType:
|
||||
return "uuid"
|
||||
default:
|
||||
return fmt.Sprintf("unknown value type %d", uint8(t))
|
||||
}
|
||||
}
|
||||
|
||||
type rawValue struct {
|
||||
Type valueType
|
||||
Len uint16 // Only set for variable length slices
|
||||
Value []byte // byte representation of value, BigEndian encoding.
|
||||
}
|
||||
|
||||
func (r rawValue) encodeScalar(w io.Writer, v interface{}) error {
|
||||
return binaryWriteFields(w, binary.BigEndian,
|
||||
r.Type,
|
||||
v,
|
||||
)
|
||||
}
|
||||
|
||||
func (r rawValue) encodeFixedSlice(w io.Writer, v []byte) error {
|
||||
binary.Write(w, binary.BigEndian, r.Type)
|
||||
|
||||
_, err := w.Write(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r rawValue) encodeBytes(w io.Writer, v []byte) error {
|
||||
if len(v) > maxHeaderValueLen {
|
||||
return LengthError{
|
||||
Part: "header value",
|
||||
Want: maxHeaderValueLen, Have: len(v),
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
r.Len = uint16(len(v))
|
||||
|
||||
err := binaryWriteFields(w, binary.BigEndian,
|
||||
r.Type,
|
||||
r.Len,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = w.Write(v)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r rawValue) encodeString(w io.Writer, v string) error {
|
||||
if len(v) > maxHeaderValueLen {
|
||||
return LengthError{
|
||||
Part: "header value",
|
||||
Want: maxHeaderValueLen, Have: len(v),
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
r.Len = uint16(len(v))
|
||||
|
||||
type stringWriter interface {
|
||||
WriteString(string) (int, error)
|
||||
}
|
||||
|
||||
err := binaryWriteFields(w, binary.BigEndian,
|
||||
r.Type,
|
||||
r.Len,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sw, ok := w.(stringWriter); ok {
|
||||
_, err = sw.WriteString(v)
|
||||
} else {
|
||||
_, err = w.Write([]byte(v))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func decodeFixedBytesValue(r io.Reader, buf []byte) error {
|
||||
_, err := io.ReadFull(r, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func decodeBytesValue(r io.Reader) ([]byte, error) {
|
||||
var raw rawValue
|
||||
var err error
|
||||
raw.Len, err = decodeUint16(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := make([]byte, raw.Len)
|
||||
_, err = io.ReadFull(r, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func decodeStringValue(r io.Reader) (string, error) {
|
||||
v, err := decodeBytesValue(r)
|
||||
return string(v), err
|
||||
}
|
||||
|
||||
// Value represents the abstract header value.
|
||||
type Value interface {
|
||||
Get() interface{}
|
||||
String() string
|
||||
valueType() valueType
|
||||
encode(io.Writer) error
|
||||
}
|
||||
|
||||
// An BoolValue provides eventstream encoding, and representation
|
||||
// of a Go bool value.
|
||||
type BoolValue bool
|
||||
|
||||
// Get returns the underlying type
|
||||
func (v BoolValue) Get() interface{} {
|
||||
return bool(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (v BoolValue) valueType() valueType {
|
||||
if v {
|
||||
return trueValueType
|
||||
}
|
||||
return falseValueType
|
||||
}
|
||||
|
||||
func (v BoolValue) String() string {
|
||||
return strconv.FormatBool(bool(v))
|
||||
}
|
||||
|
||||
// encode encodes the BoolValue into an eventstream binary value
|
||||
// representation.
|
||||
func (v BoolValue) encode(w io.Writer) error {
|
||||
return binary.Write(w, binary.BigEndian, v.valueType())
|
||||
}
|
||||
|
||||
// An Int8Value provides eventstream encoding, and representation of a Go
|
||||
// int8 value.
|
||||
type Int8Value int8
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v Int8Value) Get() interface{} {
|
||||
return int8(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (Int8Value) valueType() valueType {
|
||||
return int8ValueType
|
||||
}
|
||||
|
||||
func (v Int8Value) String() string {
|
||||
return fmt.Sprintf("0x%02x", int8(v))
|
||||
}
|
||||
|
||||
// encode encodes the Int8Value into an eventstream binary value
|
||||
// representation.
|
||||
func (v Int8Value) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
|
||||
return raw.encodeScalar(w, v)
|
||||
}
|
||||
|
||||
func (v *Int8Value) decode(r io.Reader) error {
|
||||
n, err := decodeUint8(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = Int8Value(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
// An Int16Value provides eventstream encoding, and representation of a Go
|
||||
// int16 value.
|
||||
type Int16Value int16
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v Int16Value) Get() interface{} {
|
||||
return int16(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (Int16Value) valueType() valueType {
|
||||
return int16ValueType
|
||||
}
|
||||
|
||||
func (v Int16Value) String() string {
|
||||
return fmt.Sprintf("0x%04x", int16(v))
|
||||
}
|
||||
|
||||
// encode encodes the Int16Value into an eventstream binary value
|
||||
// representation.
|
||||
func (v Int16Value) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
return raw.encodeScalar(w, v)
|
||||
}
|
||||
|
||||
func (v *Int16Value) decode(r io.Reader) error {
|
||||
n, err := decodeUint16(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = Int16Value(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
// An Int32Value provides eventstream encoding, and representation of a Go
|
||||
// int32 value.
|
||||
type Int32Value int32
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v Int32Value) Get() interface{} {
|
||||
return int32(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (Int32Value) valueType() valueType {
|
||||
return int32ValueType
|
||||
}
|
||||
|
||||
func (v Int32Value) String() string {
|
||||
return fmt.Sprintf("0x%08x", int32(v))
|
||||
}
|
||||
|
||||
// encode encodes the Int32Value into an eventstream binary value
|
||||
// representation.
|
||||
func (v Int32Value) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
return raw.encodeScalar(w, v)
|
||||
}
|
||||
|
||||
func (v *Int32Value) decode(r io.Reader) error {
|
||||
n, err := decodeUint32(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = Int32Value(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
// An Int64Value provides eventstream encoding, and representation of a Go
|
||||
// int64 value.
|
||||
type Int64Value int64
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v Int64Value) Get() interface{} {
|
||||
return int64(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (Int64Value) valueType() valueType {
|
||||
return int64ValueType
|
||||
}
|
||||
|
||||
func (v Int64Value) String() string {
|
||||
return fmt.Sprintf("0x%016x", int64(v))
|
||||
}
|
||||
|
||||
// encode encodes the Int64Value into an eventstream binary value
|
||||
// representation.
|
||||
func (v Int64Value) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
return raw.encodeScalar(w, v)
|
||||
}
|
||||
|
||||
func (v *Int64Value) decode(r io.Reader) error {
|
||||
n, err := decodeUint64(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = Int64Value(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
// An BytesValue provides eventstream encoding, and representation of a Go
|
||||
// byte slice.
|
||||
type BytesValue []byte
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v BytesValue) Get() interface{} {
|
||||
return []byte(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (BytesValue) valueType() valueType {
|
||||
return bytesValueType
|
||||
}
|
||||
|
||||
func (v BytesValue) String() string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(v))
|
||||
}
|
||||
|
||||
// encode encodes the BytesValue into an eventstream binary value
|
||||
// representation.
|
||||
func (v BytesValue) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
|
||||
return raw.encodeBytes(w, []byte(v))
|
||||
}
|
||||
|
||||
func (v *BytesValue) decode(r io.Reader) error {
|
||||
buf, err := decodeBytesValue(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = BytesValue(buf)
|
||||
return nil
|
||||
}
|
||||
|
||||
// An StringValue provides eventstream encoding, and representation of a Go
|
||||
// string.
|
||||
type StringValue string
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v StringValue) Get() interface{} {
|
||||
return string(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (StringValue) valueType() valueType {
|
||||
return stringValueType
|
||||
}
|
||||
|
||||
func (v StringValue) String() string {
|
||||
return string(v)
|
||||
}
|
||||
|
||||
// encode encodes the StringValue into an eventstream binary value
|
||||
// representation.
|
||||
func (v StringValue) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
|
||||
return raw.encodeString(w, string(v))
|
||||
}
|
||||
|
||||
func (v *StringValue) decode(r io.Reader) error {
|
||||
s, err := decodeStringValue(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = StringValue(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// An TimestampValue provides eventstream encoding, and representation of a Go
|
||||
// timestamp.
|
||||
type TimestampValue time.Time
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v TimestampValue) Get() interface{} {
|
||||
return time.Time(v)
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (TimestampValue) valueType() valueType {
|
||||
return timestampValueType
|
||||
}
|
||||
|
||||
func (v TimestampValue) epochMilli() int64 {
|
||||
nano := time.Time(v).UnixNano()
|
||||
msec := nano / int64(time.Millisecond)
|
||||
return msec
|
||||
}
|
||||
|
||||
func (v TimestampValue) String() string {
|
||||
msec := v.epochMilli()
|
||||
return strconv.FormatInt(msec, 10)
|
||||
}
|
||||
|
||||
// encode encodes the TimestampValue into an eventstream binary value
|
||||
// representation.
|
||||
func (v TimestampValue) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
|
||||
msec := v.epochMilli()
|
||||
return raw.encodeScalar(w, msec)
|
||||
}
|
||||
|
||||
func (v *TimestampValue) decode(r io.Reader) error {
|
||||
n, err := decodeUint64(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*v = TimestampValue(timeFromEpochMilli(int64(n)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface
|
||||
func (v TimestampValue) MarshalJSON() ([]byte, error) {
|
||||
return []byte(v.String()), nil
|
||||
}
|
||||
|
||||
func timeFromEpochMilli(t int64) time.Time {
|
||||
secs := t / 1e3
|
||||
msec := t % 1e3
|
||||
return time.Unix(secs, msec*int64(time.Millisecond)).UTC()
|
||||
}
|
||||
|
||||
// An UUIDValue provides eventstream encoding, and representation of a UUID
|
||||
// value.
|
||||
type UUIDValue [16]byte
|
||||
|
||||
// Get returns the underlying value.
|
||||
func (v UUIDValue) Get() interface{} {
|
||||
return v[:]
|
||||
}
|
||||
|
||||
// valueType returns the EventStream header value type value.
|
||||
func (UUIDValue) valueType() valueType {
|
||||
return uuidValueType
|
||||
}
|
||||
|
||||
func (v UUIDValue) String() string {
|
||||
var scratch [36]byte
|
||||
|
||||
const dash = '-'
|
||||
|
||||
hex.Encode(scratch[:8], v[0:4])
|
||||
scratch[8] = dash
|
||||
hex.Encode(scratch[9:13], v[4:6])
|
||||
scratch[13] = dash
|
||||
hex.Encode(scratch[14:18], v[6:8])
|
||||
scratch[18] = dash
|
||||
hex.Encode(scratch[19:23], v[8:10])
|
||||
scratch[23] = dash
|
||||
hex.Encode(scratch[24:], v[10:])
|
||||
|
||||
return string(scratch[:])
|
||||
}
|
||||
|
||||
// encode encodes the UUIDValue into an eventstream binary value
|
||||
// representation.
|
||||
func (v UUIDValue) encode(w io.Writer) error {
|
||||
raw := rawValue{
|
||||
Type: v.valueType(),
|
||||
}
|
||||
|
||||
return raw.encodeFixedSlice(w, v[:])
|
||||
}
|
||||
|
||||
func (v *UUIDValue) decode(r io.Reader) error {
|
||||
tv := (*v)[:]
|
||||
return decodeFixedBytesValue(r, tv)
|
||||
}
|
@ -0,0 +1,117 @@
|
||||
package eventstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
)
|
||||
|
||||
const preludeLen = 8
|
||||
const preludeCRCLen = 4
|
||||
const msgCRCLen = 4
|
||||
const minMsgLen = preludeLen + preludeCRCLen + msgCRCLen
|
||||
const maxPayloadLen = 1024 * 1024 * 16 // 16MB
|
||||
const maxHeadersLen = 1024 * 128 // 128KB
|
||||
const maxMsgLen = minMsgLen + maxHeadersLen + maxPayloadLen
|
||||
|
||||
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
|
||||
|
||||
// A Message provides the eventstream message representation.
|
||||
type Message struct {
|
||||
Headers Headers
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func (m *Message) rawMessage() (rawMessage, error) {
|
||||
var raw rawMessage
|
||||
|
||||
if len(m.Headers) > 0 {
|
||||
var headers bytes.Buffer
|
||||
if err := EncodeHeaders(&headers, m.Headers); err != nil {
|
||||
return rawMessage{}, err
|
||||
}
|
||||
raw.Headers = headers.Bytes()
|
||||
raw.HeadersLen = uint32(len(raw.Headers))
|
||||
}
|
||||
|
||||
raw.Length = raw.HeadersLen + uint32(len(m.Payload)) + minMsgLen
|
||||
|
||||
hash := crc32.New(crc32IEEETable)
|
||||
binaryWriteFields(hash, binary.BigEndian, raw.Length, raw.HeadersLen)
|
||||
raw.PreludeCRC = hash.Sum32()
|
||||
|
||||
binaryWriteFields(hash, binary.BigEndian, raw.PreludeCRC)
|
||||
|
||||
if raw.HeadersLen > 0 {
|
||||
hash.Write(raw.Headers)
|
||||
}
|
||||
|
||||
// Read payload bytes and update hash for it as well.
|
||||
if len(m.Payload) > 0 {
|
||||
raw.Payload = m.Payload
|
||||
hash.Write(raw.Payload)
|
||||
}
|
||||
|
||||
raw.CRC = hash.Sum32()
|
||||
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of the message.
|
||||
func (m Message) Clone() Message {
|
||||
var payload []byte
|
||||
if m.Payload != nil {
|
||||
payload = make([]byte, len(m.Payload))
|
||||
copy(payload, m.Payload)
|
||||
}
|
||||
|
||||
return Message{
|
||||
Headers: m.Headers.Clone(),
|
||||
Payload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
type messagePrelude struct {
|
||||
Length uint32
|
||||
HeadersLen uint32
|
||||
PreludeCRC uint32
|
||||
}
|
||||
|
||||
func (p messagePrelude) PayloadLen() uint32 {
|
||||
return p.Length - p.HeadersLen - minMsgLen
|
||||
}
|
||||
|
||||
func (p messagePrelude) ValidateLens() error {
|
||||
if p.Length == 0 || p.Length > maxMsgLen {
|
||||
return LengthError{
|
||||
Part: "message prelude",
|
||||
Want: maxMsgLen,
|
||||
Have: int(p.Length),
|
||||
}
|
||||
}
|
||||
if p.HeadersLen > maxHeadersLen {
|
||||
return LengthError{
|
||||
Part: "message headers",
|
||||
Want: maxHeadersLen,
|
||||
Have: int(p.HeadersLen),
|
||||
}
|
||||
}
|
||||
if payloadLen := p.PayloadLen(); payloadLen > maxPayloadLen {
|
||||
return LengthError{
|
||||
Part: "message payload",
|
||||
Want: maxPayloadLen,
|
||||
Have: int(payloadLen),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type rawMessage struct {
|
||||
messagePrelude
|
||||
|
||||
Headers []byte
|
||||
Payload []byte
|
||||
|
||||
CRC uint32
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Array represents the encoding of Query lists and sets. A Query array is a
|
||||
// representation of a list of values of a fixed type. A serialized array might
|
||||
// look like the following:
|
||||
//
|
||||
// ListName.member.1=foo
|
||||
// &ListName.member.2=bar
|
||||
// &Listname.member.3=baz
|
||||
type Array struct {
|
||||
// The query values to add the array to.
|
||||
values url.Values
|
||||
// The array's prefix, which includes the names of all parent structures
|
||||
// and ends with the name of the list. For example, the prefix might be
|
||||
// "ParentStructure.ListName". This prefix will be used to form the full
|
||||
// keys for each element in the list. For example, an entry might have the
|
||||
// key "ParentStructure.ListName.member.MemberName.1".
|
||||
//
|
||||
// While this is currently represented as a string that gets added to, it
|
||||
// could also be represented as a stack that only gets condensed into a
|
||||
// string when a finalized key is created. This could potentially reduce
|
||||
// allocations.
|
||||
prefix string
|
||||
// Whether the list is flat or not. A list that is not flat will produce the
|
||||
// following entry to the url.Values for a given entry:
|
||||
// ListName.MemberName.1=value
|
||||
// A list that is flat will produce the following:
|
||||
// ListName.1=value
|
||||
flat bool
|
||||
// The location name of the member. In most cases this should be "member".
|
||||
memberName string
|
||||
// Elements are stored in values, so we keep track of the list size here.
|
||||
size int32
|
||||
}
|
||||
|
||||
func newArray(values url.Values, prefix string, flat bool, memberName string) *Array {
|
||||
return &Array{
|
||||
values: values,
|
||||
prefix: prefix,
|
||||
flat: flat,
|
||||
memberName: memberName,
|
||||
}
|
||||
}
|
||||
|
||||
// Value adds a new element to the Query Array. Returns a Value type used to
|
||||
// encode the array element.
|
||||
func (a *Array) Value() Value {
|
||||
// Query lists start a 1, so adjust the size first
|
||||
a.size++
|
||||
prefix := a.prefix
|
||||
if !a.flat {
|
||||
prefix = fmt.Sprintf("%s.%s", prefix, a.memberName)
|
||||
}
|
||||
// Lists can't have flat members
|
||||
return newValue(a.values, fmt.Sprintf("%s.%d", prefix, a.size), false)
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/url"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Encoder is a Query encoder that supports construction of Query body
|
||||
// values using methods.
|
||||
type Encoder struct {
|
||||
// The query values that will be built up to manage encoding.
|
||||
values url.Values
|
||||
// The writer that the encoded body will be written to.
|
||||
writer io.Writer
|
||||
Value
|
||||
}
|
||||
|
||||
// NewEncoder returns a new Query body encoder
|
||||
func NewEncoder(writer io.Writer) *Encoder {
|
||||
values := url.Values{}
|
||||
return &Encoder{
|
||||
values: values,
|
||||
writer: writer,
|
||||
Value: newBaseValue(values),
|
||||
}
|
||||
}
|
||||
|
||||
// Encode returns the []byte slice representing the current
|
||||
// state of the Query encoder.
|
||||
func (e Encoder) Encode() error {
|
||||
ws, ok := e.writer.(interface{ WriteString(string) (int, error) })
|
||||
if !ok {
|
||||
// Fall back to less optimal byte slice casting if WriteString isn't available.
|
||||
ws = &wrapWriteString{writer: e.writer}
|
||||
}
|
||||
|
||||
// Get the keys and sort them to have a stable output
|
||||
keys := make([]string, 0, len(e.values))
|
||||
for k := range e.values {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
isFirstEntry := true
|
||||
for _, key := range keys {
|
||||
queryValues := e.values[key]
|
||||
escapedKey := url.QueryEscape(key)
|
||||
for _, value := range queryValues {
|
||||
if !isFirstEntry {
|
||||
if _, err := ws.WriteString(`&`); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
isFirstEntry = false
|
||||
}
|
||||
if _, err := ws.WriteString(escapedKey); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ws.WriteString(`=`); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := ws.WriteString(url.QueryEscape(value)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapWriteString wraps an io.Writer to provide a WriteString method
|
||||
// where one is not available.
|
||||
type wrapWriteString struct {
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
// WriteString writes a string to the wrapped writer by casting it to
|
||||
// a byte array first.
|
||||
func (w wrapWriteString) WriteString(v string) (int, error) {
|
||||
return w.writer.Write([]byte(v))
|
||||
}
|
@ -0,0 +1,78 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Map represents the encoding of Query maps. A Query map is a representation
|
||||
// of a mapping of arbitrary string keys to arbitrary values of a fixed type.
|
||||
// A Map differs from an Object in that the set of keys is not fixed, in that
|
||||
// the values must all be of the same type, and that map entries are ordered.
|
||||
// A serialized map might look like the following:
|
||||
//
|
||||
// MapName.entry.1.key=Foo
|
||||
// &MapName.entry.1.value=spam
|
||||
// &MapName.entry.2.key=Bar
|
||||
// &MapName.entry.2.value=eggs
|
||||
type Map struct {
|
||||
// The query values to add the map to.
|
||||
values url.Values
|
||||
// The map's prefix, which includes the names of all parent structures
|
||||
// and ends with the name of the object. For example, the prefix might be
|
||||
// "ParentStructure.MapName". This prefix will be used to form the full
|
||||
// keys for each key-value pair of the map. For example, a value might have
|
||||
// the key "ParentStructure.MapName.1.value".
|
||||
//
|
||||
// While this is currently represented as a string that gets added to, it
|
||||
// could also be represented as a stack that only gets condensed into a
|
||||
// string when a finalized key is created. This could potentially reduce
|
||||
// allocations.
|
||||
prefix string
|
||||
// Whether the map is flat or not. A map that is not flat will produce the
|
||||
// following entries to the url.Values for a given key-value pair:
|
||||
// MapName.entry.1.KeyLocationName=mykey
|
||||
// MapName.entry.1.ValueLocationName=myvalue
|
||||
// A map that is flat will produce the following:
|
||||
// MapName.1.KeyLocationName=mykey
|
||||
// MapName.1.ValueLocationName=myvalue
|
||||
flat bool
|
||||
// The location name of the key. In most cases this should be "key".
|
||||
keyLocationName string
|
||||
// The location name of the value. In most cases this should be "value".
|
||||
valueLocationName string
|
||||
// Elements are stored in values, so we keep track of the list size here.
|
||||
size int32
|
||||
}
|
||||
|
||||
func newMap(values url.Values, prefix string, flat bool, keyLocationName string, valueLocationName string) *Map {
|
||||
return &Map{
|
||||
values: values,
|
||||
prefix: prefix,
|
||||
flat: flat,
|
||||
keyLocationName: keyLocationName,
|
||||
valueLocationName: valueLocationName,
|
||||
}
|
||||
}
|
||||
|
||||
// Key adds the given named key to the Query map.
|
||||
// Returns a Value encoder that should be used to encode a Query value type.
|
||||
func (m *Map) Key(name string) Value {
|
||||
// Query lists start a 1, so adjust the size first
|
||||
m.size++
|
||||
var key string
|
||||
var value string
|
||||
if m.flat {
|
||||
key = fmt.Sprintf("%s.%d.%s", m.prefix, m.size, m.keyLocationName)
|
||||
value = fmt.Sprintf("%s.%d.%s", m.prefix, m.size, m.valueLocationName)
|
||||
} else {
|
||||
key = fmt.Sprintf("%s.entry.%d.%s", m.prefix, m.size, m.keyLocationName)
|
||||
value = fmt.Sprintf("%s.entry.%d.%s", m.prefix, m.size, m.valueLocationName)
|
||||
}
|
||||
|
||||
// The key can only be a string, so we just go ahead and set it here
|
||||
newValue(m.values, key, false).String(name)
|
||||
|
||||
// Maps can't have flat members
|
||||
return newValue(m.values, value, false)
|
||||
}
|
@ -0,0 +1,62 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithyhttp "github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// AddAsGetRequestMiddleware adds a middleware to the Serialize stack after the
|
||||
// operation serializer that will convert the query request body to a GET
|
||||
// operation with the query message in the HTTP request querystring.
|
||||
func AddAsGetRequestMiddleware(stack *middleware.Stack) error {
|
||||
return stack.Serialize.Insert(&asGetRequest{}, "OperationSerializer", middleware.After)
|
||||
}
|
||||
|
||||
type asGetRequest struct{}
|
||||
|
||||
func (*asGetRequest) ID() string { return "Query:AsGetRequest" }
|
||||
|
||||
func (m *asGetRequest) HandleSerialize(
|
||||
ctx context.Context, input middleware.SerializeInput, next middleware.SerializeHandler,
|
||||
) (
|
||||
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
|
||||
) {
|
||||
req, ok := input.Request.(*smithyhttp.Request)
|
||||
if !ok {
|
||||
return out, metadata, fmt.Errorf("expect smithy HTTP Request, got %T", input.Request)
|
||||
}
|
||||
|
||||
req.Method = "GET"
|
||||
|
||||
// If the stream is not set, nothing else to do.
|
||||
stream := req.GetStream()
|
||||
if stream == nil {
|
||||
return next.HandleSerialize(ctx, input)
|
||||
}
|
||||
|
||||
// Clear the stream since there will not be any body.
|
||||
req.Header.Del("Content-Type")
|
||||
req, err = req.SetStream(nil)
|
||||
if err != nil {
|
||||
return out, metadata, fmt.Errorf("unable update request body %w", err)
|
||||
}
|
||||
input.Request = req
|
||||
|
||||
// Update request query with the body's query string value.
|
||||
delim := ""
|
||||
if len(req.URL.RawQuery) != 0 {
|
||||
delim = "&"
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(stream)
|
||||
if err != nil {
|
||||
return out, metadata, fmt.Errorf("unable to get request body %w", err)
|
||||
}
|
||||
req.URL.RawQuery += delim + string(b)
|
||||
|
||||
return next.HandleSerialize(ctx, input)
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Object represents the encoding of Query structures and unions. A Query
|
||||
// object is a representation of a mapping of string keys to arbitrary
|
||||
// values where there is a fixed set of keys whose values each have their
|
||||
// own known type. A serialized object might look like the following:
|
||||
//
|
||||
// ObjectName.Foo=value
|
||||
// &ObjectName.Bar=5
|
||||
type Object struct {
|
||||
// The query values to add the object to.
|
||||
values url.Values
|
||||
// The object's prefix, which includes the names of all parent structures
|
||||
// and ends with the name of the object. For example, the prefix might be
|
||||
// "ParentStructure.ObjectName". This prefix will be used to form the full
|
||||
// keys for each member of the object. For example, a member might have the
|
||||
// key "ParentStructure.ObjectName.MemberName".
|
||||
//
|
||||
// While this is currently represented as a string that gets added to, it
|
||||
// could also be represented as a stack that only gets condensed into a
|
||||
// string when a finalized key is created. This could potentially reduce
|
||||
// allocations.
|
||||
prefix string
|
||||
}
|
||||
|
||||
func newObject(values url.Values, prefix string) *Object {
|
||||
return &Object{
|
||||
values: values,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Key adds the given named key to the Query object.
|
||||
// Returns a Value encoder that should be used to encode a Query value type.
|
||||
func (o *Object) Key(name string) Value {
|
||||
return o.key(name, false)
|
||||
}
|
||||
|
||||
// FlatKey adds the given named key to the Query object.
|
||||
// Returns a Value encoder that should be used to encode a Query value type. The
|
||||
// value will be flattened if it is a map or array.
|
||||
func (o *Object) FlatKey(name string) Value {
|
||||
return o.key(name, true)
|
||||
}
|
||||
|
||||
func (o *Object) key(name string, flatValue bool) Value {
|
||||
if o.prefix != "" {
|
||||
return newValue(o.values, fmt.Sprintf("%s.%s", o.prefix, name), flatValue)
|
||||
}
|
||||
return newValue(o.values, name, flatValue)
|
||||
}
|
@ -0,0 +1,106 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"net/url"
|
||||
|
||||
"github.com/aws/smithy-go/encoding/httpbinding"
|
||||
)
|
||||
|
||||
// Value represents a Query Value type.
|
||||
type Value struct {
|
||||
// The query values to add the value to.
|
||||
values url.Values
|
||||
// The value's key, which will form the prefix for complex types.
|
||||
key string
|
||||
// Whether the value should be flattened or not if it's a flattenable type.
|
||||
flat bool
|
||||
queryValue httpbinding.QueryValue
|
||||
}
|
||||
|
||||
func newValue(values url.Values, key string, flat bool) Value {
|
||||
return Value{
|
||||
values: values,
|
||||
key: key,
|
||||
flat: flat,
|
||||
queryValue: httpbinding.NewQueryValue(values, key, false),
|
||||
}
|
||||
}
|
||||
|
||||
func newBaseValue(values url.Values) Value {
|
||||
return Value{
|
||||
values: values,
|
||||
queryValue: httpbinding.NewQueryValue(nil, "", false),
|
||||
}
|
||||
}
|
||||
|
||||
// Array returns a new Array encoder.
|
||||
func (qv Value) Array(locationName string) *Array {
|
||||
return newArray(qv.values, qv.key, qv.flat, locationName)
|
||||
}
|
||||
|
||||
// Object returns a new Object encoder.
|
||||
func (qv Value) Object() *Object {
|
||||
return newObject(qv.values, qv.key)
|
||||
}
|
||||
|
||||
// Map returns a new Map encoder.
|
||||
func (qv Value) Map(keyLocationName string, valueLocationName string) *Map {
|
||||
return newMap(qv.values, qv.key, qv.flat, keyLocationName, valueLocationName)
|
||||
}
|
||||
|
||||
// Base64EncodeBytes encodes v as a base64 query string value.
|
||||
// This is intended to enable compatibility with the JSON encoder.
|
||||
func (qv Value) Base64EncodeBytes(v []byte) {
|
||||
qv.queryValue.Blob(v)
|
||||
}
|
||||
|
||||
// Boolean encodes v as a query string value
|
||||
func (qv Value) Boolean(v bool) {
|
||||
qv.queryValue.Boolean(v)
|
||||
}
|
||||
|
||||
// String encodes v as a query string value
|
||||
func (qv Value) String(v string) {
|
||||
qv.queryValue.String(v)
|
||||
}
|
||||
|
||||
// Byte encodes v as a query string value
|
||||
func (qv Value) Byte(v int8) {
|
||||
qv.queryValue.Byte(v)
|
||||
}
|
||||
|
||||
// Short encodes v as a query string value
|
||||
func (qv Value) Short(v int16) {
|
||||
qv.queryValue.Short(v)
|
||||
}
|
||||
|
||||
// Integer encodes v as a query string value
|
||||
func (qv Value) Integer(v int32) {
|
||||
qv.queryValue.Integer(v)
|
||||
}
|
||||
|
||||
// Long encodes v as a query string value
|
||||
func (qv Value) Long(v int64) {
|
||||
qv.queryValue.Long(v)
|
||||
}
|
||||
|
||||
// Float encodes v as a query string value
|
||||
func (qv Value) Float(v float32) {
|
||||
qv.queryValue.Float(v)
|
||||
}
|
||||
|
||||
// Double encodes v as a query string value
|
||||
func (qv Value) Double(v float64) {
|
||||
qv.queryValue.Double(v)
|
||||
}
|
||||
|
||||
// BigInteger encodes v as a query string value
|
||||
func (qv Value) BigInteger(v *big.Int) {
|
||||
qv.queryValue.BigInteger(v)
|
||||
}
|
||||
|
||||
// BigDecimal encodes v as a query string value
|
||||
func (qv Value) BigDecimal(v *big.Float) {
|
||||
qv.queryValue.BigDecimal(v)
|
||||
}
|
@ -0,0 +1,85 @@
|
||||
package restjson
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/smithy-go"
|
||||
)
|
||||
|
||||
// GetErrorInfo util looks for code, __type, and message members in the
|
||||
// json body. These members are optionally available, and the function
|
||||
// returns the value of member if it is available. This function is useful to
|
||||
// identify the error code, msg in a REST JSON error response.
|
||||
func GetErrorInfo(decoder *json.Decoder) (errorType string, message string, err error) {
|
||||
var errInfo struct {
|
||||
Code string
|
||||
Type string `json:"__type"`
|
||||
Message string
|
||||
}
|
||||
|
||||
err = decoder.Decode(&errInfo)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return errorType, message, nil
|
||||
}
|
||||
return errorType, message, err
|
||||
}
|
||||
|
||||
// assign error type
|
||||
if len(errInfo.Code) != 0 {
|
||||
errorType = errInfo.Code
|
||||
} else if len(errInfo.Type) != 0 {
|
||||
errorType = errInfo.Type
|
||||
}
|
||||
|
||||
// assign error message
|
||||
if len(errInfo.Message) != 0 {
|
||||
message = errInfo.Message
|
||||
}
|
||||
|
||||
// sanitize error
|
||||
if len(errorType) != 0 {
|
||||
errorType = SanitizeErrorCode(errorType)
|
||||
}
|
||||
|
||||
return errorType, message, nil
|
||||
}
|
||||
|
||||
// SanitizeErrorCode sanitizes the errorCode string .
|
||||
// The rule for sanitizing is if a `:` character is present, then take only the
|
||||
// contents before the first : character in the value.
|
||||
// If a # character is present, then take only the contents after the
|
||||
// first # character in the value.
|
||||
func SanitizeErrorCode(errorCode string) string {
|
||||
if strings.ContainsAny(errorCode, ":") {
|
||||
errorCode = strings.SplitN(errorCode, ":", 2)[0]
|
||||
}
|
||||
|
||||
if strings.ContainsAny(errorCode, "#") {
|
||||
errorCode = strings.SplitN(errorCode, "#", 2)[1]
|
||||
}
|
||||
|
||||
return errorCode
|
||||
}
|
||||
|
||||
// GetSmithyGenericAPIError returns smithy generic api error and an error interface.
|
||||
// Takes in json decoder, and error Code string as args. The function retrieves error message
|
||||
// and error code from the decoder body. If errorCode of length greater than 0 is passed in as
|
||||
// an argument, it is used instead.
|
||||
func GetSmithyGenericAPIError(decoder *json.Decoder, errorCode string) (*smithy.GenericAPIError, error) {
|
||||
errorType, message, err := GetErrorInfo(decoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(errorCode) == 0 {
|
||||
errorCode = errorType
|
||||
}
|
||||
|
||||
return &smithy.GenericAPIError{
|
||||
Code: errorCode,
|
||||
Message: message,
|
||||
}, nil
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
package xml
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ErrorComponents represents the error response fields
|
||||
// that will be deserialized from an xml error response body
|
||||
type ErrorComponents struct {
|
||||
Code string
|
||||
Message string
|
||||
RequestID string
|
||||
}
|
||||
|
||||
// GetErrorResponseComponents returns the error fields from an xml error response body
|
||||
func GetErrorResponseComponents(r io.Reader, noErrorWrapping bool) (ErrorComponents, error) {
|
||||
if noErrorWrapping {
|
||||
var errResponse noWrappedErrorResponse
|
||||
if err := xml.NewDecoder(r).Decode(&errResponse); err != nil && err != io.EOF {
|
||||
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response: %w", err)
|
||||
}
|
||||
return ErrorComponents{
|
||||
Code: errResponse.Code,
|
||||
Message: errResponse.Message,
|
||||
RequestID: errResponse.RequestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var errResponse wrappedErrorResponse
|
||||
if err := xml.NewDecoder(r).Decode(&errResponse); err != nil && err != io.EOF {
|
||||
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response: %w", err)
|
||||
}
|
||||
return ErrorComponents{
|
||||
Code: errResponse.Code,
|
||||
Message: errResponse.Message,
|
||||
RequestID: errResponse.RequestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// noWrappedErrorResponse represents the error response body with
|
||||
// no internal <Error></Error wrapping
|
||||
type noWrappedErrorResponse struct {
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
RequestID string `xml:"RequestId"`
|
||||
}
|
||||
|
||||
// wrappedErrorResponse represents the error response body
|
||||
// wrapped within <Error>...</Error>
|
||||
type wrappedErrorResponse struct {
|
||||
Code string `xml:"Error>Code"`
|
||||
Message string `xml:"Error>Message"`
|
||||
RequestID string `xml:"RequestId"`
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TokenBucket provides a concurrency safe utility for adding and removing
|
||||
// tokens from the available token bucket.
|
||||
type TokenBucket struct {
|
||||
remainingTokens uint
|
||||
maxCapacity uint
|
||||
minCapacity uint
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTokenBucket returns an initialized TokenBucket with the capacity
|
||||
// specified.
|
||||
func NewTokenBucket(i uint) *TokenBucket {
|
||||
return &TokenBucket{
|
||||
remainingTokens: i,
|
||||
maxCapacity: i,
|
||||
minCapacity: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve attempts to reduce the available tokens by the amount requested. If
|
||||
// there are tokens available true will be returned along with the number of
|
||||
// available tokens remaining. If amount requested is larger than the available
|
||||
// capacity, false will be returned along with the available capacity. If the
|
||||
// amount is less than the available capacity, the capacity will be reduced by
|
||||
// that amount, and the remaining capacity and true will be returned.
|
||||
func (t *TokenBucket) Retrieve(amount uint) (available uint, retrieved bool) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if amount > t.remainingTokens {
|
||||
return t.remainingTokens, false
|
||||
}
|
||||
|
||||
t.remainingTokens -= amount
|
||||
return t.remainingTokens, true
|
||||
}
|
||||
|
||||
// Refund returns the amount of tokens back to the available token bucket, up
|
||||
// to the initial capacity.
|
||||
func (t *TokenBucket) Refund(amount uint) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Capacity cannot exceed max capacity.
|
||||
t.remainingTokens = uintMin(t.remainingTokens+amount, t.maxCapacity)
|
||||
}
|
||||
|
||||
// Capacity returns the maximum capacity of tokens that the bucket could
|
||||
// contain.
|
||||
func (t *TokenBucket) Capacity() uint {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
return t.maxCapacity
|
||||
}
|
||||
|
||||
// Remaining returns the number of tokens that remaining in the bucket.
|
||||
func (t *TokenBucket) Remaining() uint {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
return t.remainingTokens
|
||||
}
|
||||
|
||||
// Resize adjusts the size of the token bucket. Returns the capacity remaining.
|
||||
func (t *TokenBucket) Resize(size uint) uint {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.maxCapacity = uintMax(size, t.minCapacity)
|
||||
|
||||
// Capacity needs to be capped at max capacity, if max size reduced.
|
||||
t.remainingTokens = uintMin(t.remainingTokens, t.maxCapacity)
|
||||
|
||||
return t.remainingTokens
|
||||
}
|
||||
|
||||
func uintMin(a, b uint) uint {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func uintMax(a, b uint) uint {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type rateToken struct {
|
||||
tokenCost uint
|
||||
bucket *TokenBucket
|
||||
}
|
||||
|
||||
func (t rateToken) release() error {
|
||||
t.bucket.Refund(t.tokenCost)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TokenRateLimit provides a Token Bucket RateLimiter implementation
|
||||
// that limits the overall number of retry attempts that can be made across
|
||||
// operation invocations.
|
||||
type TokenRateLimit struct {
|
||||
bucket *TokenBucket
|
||||
}
|
||||
|
||||
// NewTokenRateLimit returns an TokenRateLimit with default values.
|
||||
// Functional options can configure the retry rate limiter.
|
||||
func NewTokenRateLimit(tokens uint) *TokenRateLimit {
|
||||
return &TokenRateLimit{
|
||||
bucket: NewTokenBucket(tokens),
|
||||
}
|
||||
}
|
||||
|
||||
func isTimeoutError(error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type canceledError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (c canceledError) CanceledError() bool { return true }
|
||||
func (c canceledError) Unwrap() error { return c.Err }
|
||||
func (c canceledError) Error() string {
|
||||
return fmt.Sprintf("canceled, %v", c.Err)
|
||||
}
|
||||
|
||||
// GetToken may cause a available pool of retry quota to be
|
||||
// decremented. Will return an error if the decremented value can not be
|
||||
// reduced from the retry quota.
|
||||
func (l *TokenRateLimit) GetToken(ctx context.Context, cost uint) (func() error, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, canceledError{Err: ctx.Err()}
|
||||
default:
|
||||
}
|
||||
if avail, ok := l.bucket.Retrieve(cost); !ok {
|
||||
return nil, QuotaExceededError{Available: avail, Requested: cost}
|
||||
}
|
||||
|
||||
return rateToken{
|
||||
tokenCost: cost,
|
||||
bucket: l.bucket,
|
||||
}.release, nil
|
||||
}
|
||||
|
||||
// AddTokens increments the token bucket by a fixed amount.
|
||||
func (l *TokenRateLimit) AddTokens(v uint) error {
|
||||
l.bucket.Refund(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remaining returns the number of remaining tokens in the bucket.
|
||||
func (l *TokenRateLimit) Remaining() uint {
|
||||
return l.bucket.Remaining()
|
||||
}
|
||||
|
||||
// QuotaExceededError provides the SDK error when the retries for a given
|
||||
// token bucket have been exhausted.
|
||||
type QuotaExceededError struct {
|
||||
Available uint
|
||||
Requested uint
|
||||
}
|
||||
|
||||
func (e QuotaExceededError) Error() string {
|
||||
return fmt.Sprintf("retry quota exceeded, %d available, %d requested",
|
||||
e.Available, e.Requested)
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// TODO remove replace with smithy.CanceledError
|
||||
|
||||
// RequestCanceledError is the error that will be returned by an API request
|
||||
// that was canceled. Requests given a Context may return this error when
|
||||
// canceled.
|
||||
type RequestCanceledError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// CanceledError returns true to satisfy interfaces checking for canceled errors.
|
||||
func (*RequestCanceledError) CanceledError() bool { return true }
|
||||
|
||||
// Unwrap returns the underlying error, if there was one.
|
||||
func (e *RequestCanceledError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
func (e *RequestCanceledError) Error() string {
|
||||
return fmt.Sprintf("request canceled, %v", e.Err)
|
||||
}
|
@ -0,0 +1,156 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/internal/sdk"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultRequestCost is the cost of a single request from the adaptive
|
||||
// rate limited token bucket.
|
||||
DefaultRequestCost uint = 1
|
||||
)
|
||||
|
||||
// DefaultThrottles provides the set of errors considered throttle errors that
|
||||
// are checked by default.
|
||||
var DefaultThrottles = []IsErrorThrottle{
|
||||
ThrottleErrorCode{
|
||||
Codes: DefaultThrottleErrorCodes,
|
||||
},
|
||||
}
|
||||
|
||||
// AdaptiveModeOptions provides the functional options for configuring the
|
||||
// adaptive retry mode, and delay behavior.
|
||||
type AdaptiveModeOptions struct {
|
||||
// If the adaptive token bucket is empty, when an attempt will be made
|
||||
// AdaptiveMode will sleep until a token is available. This can occur when
|
||||
// attempts fail with throttle errors. Use this option to disable the sleep
|
||||
// until token is available, and return error immediately.
|
||||
FailOnNoAttemptTokens bool
|
||||
|
||||
// The cost of an attempt from the AdaptiveMode's adaptive token bucket.
|
||||
RequestCost uint
|
||||
|
||||
// Set of strategies to determine if the attempt failed due to a throttle
|
||||
// error.
|
||||
//
|
||||
// It is safe to append to this list in NewAdaptiveMode's functional options.
|
||||
Throttles []IsErrorThrottle
|
||||
|
||||
// Set of options for standard retry mode that AdaptiveMode is built on top
|
||||
// of. AdaptiveMode may apply its own defaults to Standard retry mode that
|
||||
// are different than the defaults of NewStandard. Use these options to
|
||||
// override the default options.
|
||||
StandardOptions []func(*StandardOptions)
|
||||
}
|
||||
|
||||
// AdaptiveMode provides an experimental retry strategy that expands on the
|
||||
// Standard retry strategy, adding client attempt rate limits. The attempt rate
|
||||
// limit is initially unrestricted, but becomes restricted when the attempt
|
||||
// fails with for a throttle error. When restricted AdaptiveMode may need to
|
||||
// sleep before an attempt is made, if too many throttles have been received.
|
||||
// AdaptiveMode's sleep can be canceled with context cancel. Set
|
||||
// AdaptiveModeOptions FailOnNoAttemptTokens to change the behavior from sleep,
|
||||
// to fail fast.
|
||||
//
|
||||
// Eventually unrestricted attempt rate limit will be restored once attempts no
|
||||
// longer are failing due to throttle errors.
|
||||
type AdaptiveMode struct {
|
||||
options AdaptiveModeOptions
|
||||
throttles IsErrorThrottles
|
||||
|
||||
retryer aws.RetryerV2
|
||||
rateLimit *adaptiveRateLimit
|
||||
}
|
||||
|
||||
// NewAdaptiveMode returns an initialized AdaptiveMode retry strategy.
|
||||
func NewAdaptiveMode(optFns ...func(*AdaptiveModeOptions)) *AdaptiveMode {
|
||||
o := AdaptiveModeOptions{
|
||||
RequestCost: DefaultRequestCost,
|
||||
Throttles: append([]IsErrorThrottle{}, DefaultThrottles...),
|
||||
}
|
||||
for _, fn := range optFns {
|
||||
fn(&o)
|
||||
}
|
||||
|
||||
return &AdaptiveMode{
|
||||
options: o,
|
||||
throttles: IsErrorThrottles(o.Throttles),
|
||||
retryer: NewStandard(o.StandardOptions...),
|
||||
rateLimit: newAdaptiveRateLimit(),
|
||||
}
|
||||
}
|
||||
|
||||
// IsErrorRetryable returns if the failed attempt is retryable. This check
|
||||
// should determine if the error can be retried, or if the error is
|
||||
// terminal.
|
||||
func (a *AdaptiveMode) IsErrorRetryable(err error) bool {
|
||||
return a.retryer.IsErrorRetryable(err)
|
||||
}
|
||||
|
||||
// MaxAttempts returns the maximum number of attempts that can be made for
|
||||
// an attempt before failing. A value of 0 implies that the attempt should
|
||||
// be retried until it succeeds if the errors are retryable.
|
||||
func (a *AdaptiveMode) MaxAttempts() int {
|
||||
return a.retryer.MaxAttempts()
|
||||
}
|
||||
|
||||
// RetryDelay returns the delay that should be used before retrying the
|
||||
// attempt. Will return error if the if the delay could not be determined.
|
||||
func (a *AdaptiveMode) RetryDelay(attempt int, opErr error) (
|
||||
time.Duration, error,
|
||||
) {
|
||||
return a.retryer.RetryDelay(attempt, opErr)
|
||||
}
|
||||
|
||||
// GetRetryToken attempts to deduct the retry cost from the retry token pool.
|
||||
// Returning the token release function, or error.
|
||||
func (a *AdaptiveMode) GetRetryToken(ctx context.Context, opErr error) (
|
||||
releaseToken func(error) error, err error,
|
||||
) {
|
||||
return a.retryer.GetRetryToken(ctx, opErr)
|
||||
}
|
||||
|
||||
// GetInitialToken returns the initial attempt token that can increment the
|
||||
// retry token pool if the attempt is successful.
|
||||
//
|
||||
// Deprecated: This method does not provide a way to block using Context,
|
||||
// nor can it return an error. Use RetryerV2, and GetAttemptToken instead. Only
|
||||
// present to implement Retryer interface.
|
||||
func (a *AdaptiveMode) GetInitialToken() (releaseToken func(error) error) {
|
||||
return nopRelease
|
||||
}
|
||||
|
||||
// GetAttemptToken returns the attempt token that can be used to rate limit
|
||||
// attempt calls. Will be used by the SDK's retry package's Attempt
|
||||
// middleware to get an attempt token prior to calling the temp and releasing
|
||||
// the attempt token after the attempt has been made.
|
||||
func (a *AdaptiveMode) GetAttemptToken(ctx context.Context) (func(error) error, error) {
|
||||
for {
|
||||
acquiredToken, waitTryAgain := a.rateLimit.AcquireToken(a.options.RequestCost)
|
||||
if acquiredToken {
|
||||
break
|
||||
}
|
||||
if a.options.FailOnNoAttemptTokens {
|
||||
return nil, fmt.Errorf(
|
||||
"unable to get attempt token, and FailOnNoAttemptTokens enables")
|
||||
}
|
||||
|
||||
if err := sdk.SleepWithContext(ctx, waitTryAgain); err != nil {
|
||||
return nil, fmt.Errorf("failed to wait for token to be available, %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return a.handleResponse, nil
|
||||
}
|
||||
|
||||
func (a *AdaptiveMode) handleResponse(opErr error) error {
|
||||
throttled := a.throttles.IsErrorThrottle(opErr).Bool()
|
||||
|
||||
a.rateLimit.Update(throttled)
|
||||
return nil
|
||||
}
|
@ -0,0 +1,158 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/internal/sdk"
|
||||
)
|
||||
|
||||
type adaptiveRateLimit struct {
|
||||
tokenBucketEnabled bool
|
||||
|
||||
smooth float64
|
||||
beta float64
|
||||
scaleConstant float64
|
||||
minFillRate float64
|
||||
|
||||
fillRate float64
|
||||
calculatedRate float64
|
||||
lastRefilled time.Time
|
||||
measuredTxRate float64
|
||||
lastTxRateBucket float64
|
||||
requestCount int64
|
||||
lastMaxRate float64
|
||||
lastThrottleTime time.Time
|
||||
timeWindow float64
|
||||
|
||||
tokenBucket *adaptiveTokenBucket
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newAdaptiveRateLimit() *adaptiveRateLimit {
|
||||
now := sdk.NowTime()
|
||||
return &adaptiveRateLimit{
|
||||
smooth: 0.8,
|
||||
beta: 0.7,
|
||||
scaleConstant: 0.4,
|
||||
|
||||
minFillRate: 0.5,
|
||||
|
||||
lastTxRateBucket: math.Floor(timeFloat64Seconds(now)),
|
||||
lastThrottleTime: now,
|
||||
|
||||
tokenBucket: newAdaptiveTokenBucket(0),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) Enable(v bool) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.tokenBucketEnabled = v
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) AcquireToken(amount uint) (
|
||||
tokenAcquired bool, waitTryAgain time.Duration,
|
||||
) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if !a.tokenBucketEnabled {
|
||||
return true, 0
|
||||
}
|
||||
|
||||
a.tokenBucketRefill()
|
||||
|
||||
available, ok := a.tokenBucket.Retrieve(float64(amount))
|
||||
if !ok {
|
||||
waitDur := float64Seconds((float64(amount) - available) / a.fillRate)
|
||||
return false, waitDur
|
||||
}
|
||||
|
||||
return true, 0
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) Update(throttled bool) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
a.updateMeasuredRate()
|
||||
|
||||
if throttled {
|
||||
rateToUse := a.measuredTxRate
|
||||
if a.tokenBucketEnabled {
|
||||
rateToUse = math.Min(a.measuredTxRate, a.fillRate)
|
||||
}
|
||||
|
||||
a.lastMaxRate = rateToUse
|
||||
a.calculateTimeWindow()
|
||||
a.lastThrottleTime = sdk.NowTime()
|
||||
a.calculatedRate = a.cubicThrottle(rateToUse)
|
||||
a.tokenBucketEnabled = true
|
||||
} else {
|
||||
a.calculateTimeWindow()
|
||||
a.calculatedRate = a.cubicSuccess(sdk.NowTime())
|
||||
}
|
||||
|
||||
newRate := math.Min(a.calculatedRate, 2*a.measuredTxRate)
|
||||
a.tokenBucketUpdateRate(newRate)
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) cubicSuccess(t time.Time) float64 {
|
||||
dt := secondsFloat64(t.Sub(a.lastThrottleTime))
|
||||
return (a.scaleConstant * math.Pow(dt-a.timeWindow, 3)) + a.lastMaxRate
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) cubicThrottle(rateToUse float64) float64 {
|
||||
return rateToUse * a.beta
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) calculateTimeWindow() {
|
||||
a.timeWindow = math.Pow((a.lastMaxRate*(1.-a.beta))/a.scaleConstant, 1./3.)
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) tokenBucketUpdateRate(newRPS float64) {
|
||||
a.tokenBucketRefill()
|
||||
a.fillRate = math.Max(newRPS, a.minFillRate)
|
||||
a.tokenBucket.Resize(newRPS)
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) updateMeasuredRate() {
|
||||
now := sdk.NowTime()
|
||||
timeBucket := math.Floor(timeFloat64Seconds(now)*2.) / 2.
|
||||
a.requestCount++
|
||||
|
||||
if timeBucket > a.lastTxRateBucket {
|
||||
currentRate := float64(a.requestCount) / (timeBucket - a.lastTxRateBucket)
|
||||
a.measuredTxRate = (currentRate * a.smooth) + (a.measuredTxRate * (1. - a.smooth))
|
||||
a.requestCount = 0
|
||||
a.lastTxRateBucket = timeBucket
|
||||
}
|
||||
}
|
||||
|
||||
func (a *adaptiveRateLimit) tokenBucketRefill() {
|
||||
now := sdk.NowTime()
|
||||
if a.lastRefilled.IsZero() {
|
||||
a.lastRefilled = now
|
||||
return
|
||||
}
|
||||
|
||||
fillAmount := secondsFloat64(now.Sub(a.lastRefilled)) * a.fillRate
|
||||
a.tokenBucket.Refund(fillAmount)
|
||||
a.lastRefilled = now
|
||||
}
|
||||
|
||||
func float64Seconds(v float64) time.Duration {
|
||||
return time.Duration(v * float64(time.Second))
|
||||
}
|
||||
|
||||
func secondsFloat64(v time.Duration) float64 {
|
||||
return float64(v) / float64(time.Second)
|
||||
}
|
||||
|
||||
func timeFloat64Seconds(v time.Time) float64 {
|
||||
return float64(v.UnixNano()) / float64(time.Second)
|
||||
}
|
@ -0,0 +1,83 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// adaptiveTokenBucket provides a concurrency safe utility for adding and
|
||||
// removing tokens from the available token bucket.
|
||||
type adaptiveTokenBucket struct {
|
||||
remainingTokens float64
|
||||
maxCapacity float64
|
||||
minCapacity float64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// newAdaptiveTokenBucket returns an initialized adaptiveTokenBucket with the
|
||||
// capacity specified.
|
||||
func newAdaptiveTokenBucket(i float64) *adaptiveTokenBucket {
|
||||
return &adaptiveTokenBucket{
|
||||
remainingTokens: i,
|
||||
maxCapacity: i,
|
||||
minCapacity: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve attempts to reduce the available tokens by the amount requested. If
|
||||
// there are tokens available true will be returned along with the number of
|
||||
// available tokens remaining. If amount requested is larger than the available
|
||||
// capacity, false will be returned along with the available capacity. If the
|
||||
// amount is less than the available capacity, the capacity will be reduced by
|
||||
// that amount, and the remaining capacity and true will be returned.
|
||||
func (t *adaptiveTokenBucket) Retrieve(amount float64) (available float64, retrieved bool) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if amount > t.remainingTokens {
|
||||
return t.remainingTokens, false
|
||||
}
|
||||
|
||||
t.remainingTokens -= amount
|
||||
return t.remainingTokens, true
|
||||
}
|
||||
|
||||
// Refund returns the amount of tokens back to the available token bucket, up
|
||||
// to the initial capacity.
|
||||
func (t *adaptiveTokenBucket) Refund(amount float64) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
// Capacity cannot exceed max capacity.
|
||||
t.remainingTokens = math.Min(t.remainingTokens+amount, t.maxCapacity)
|
||||
}
|
||||
|
||||
// Capacity returns the maximum capacity of tokens that the bucket could
|
||||
// contain.
|
||||
func (t *adaptiveTokenBucket) Capacity() float64 {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
return t.maxCapacity
|
||||
}
|
||||
|
||||
// Remaining returns the number of tokens that remaining in the bucket.
|
||||
func (t *adaptiveTokenBucket) Remaining() float64 {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
return t.remainingTokens
|
||||
}
|
||||
|
||||
// Resize adjusts the size of the token bucket. Returns the capacity remaining.
|
||||
func (t *adaptiveTokenBucket) Resize(size float64) float64 {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.maxCapacity = math.Max(size, t.minCapacity)
|
||||
|
||||
// Capacity needs to be capped at max capacity, if max size reduced.
|
||||
t.remainingTokens = math.Min(t.remainingTokens, t.maxCapacity)
|
||||
|
||||
return t.remainingTokens
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
// Package retry provides interfaces and implementations for SDK request retry behavior.
|
||||
//
|
||||
// # Retryer Interface and Implementations
|
||||
//
|
||||
// This package defines Retryer interface that is used to either implement custom retry behavior
|
||||
// or to extend the existing retry implementations provided by the SDK. This package provides a single
|
||||
// retry implementation: Standard.
|
||||
//
|
||||
// # Standard
|
||||
//
|
||||
// Standard is the default retryer implementation used by service clients. The standard retryer is a rate limited
|
||||
// retryer that has a configurable max attempts to limit the number of retry attempts when a retryable error occurs.
|
||||
// In addition, the retryer uses a configurable token bucket to rate limit the retry attempts across the client,
|
||||
// and uses an additional delay policy to limit the time between a requests subsequent attempts.
|
||||
//
|
||||
// By default the standard retryer uses the DefaultRetryables slice of IsErrorRetryable types to determine whether
|
||||
// a given error is retryable. By default this list of retryables includes the following:
|
||||
// - Retrying errors that implement the RetryableError method, and return true.
|
||||
// - Connection Errors
|
||||
// - Errors that implement a ConnectionError, Temporary, or Timeout method that return true.
|
||||
// - Connection Reset Errors.
|
||||
// - net.OpErr types that are dialing errors or are temporary.
|
||||
// - HTTP Status Codes: 500, 502, 503, and 504.
|
||||
// - API Error Codes
|
||||
// - RequestTimeout, RequestTimeoutException
|
||||
// - Throttling, ThrottlingException, ThrottledException, RequestThrottledException, TooManyRequestsException,
|
||||
// RequestThrottled, SlowDown, EC2ThrottledException
|
||||
// - ProvisionedThroughputExceededException, RequestLimitExceeded, BandwidthLimitExceeded, LimitExceededException
|
||||
// - TransactionInProgressException, PriorRequestNotComplete
|
||||
//
|
||||
// The standard retryer will not retry a request in the event if the context associated with the request
|
||||
// has been cancelled. Applications must handle this case explicitly if they wish to retry with a different context
|
||||
// value.
|
||||
//
|
||||
// You can configure the standard retryer implementation to fit your applications by constructing a standard retryer
|
||||
// using the NewStandard function, and providing one more functional argument that mutate the StandardOptions
|
||||
// structure. StandardOptions provides the ability to modify the token bucket rate limiter, retryable error conditions,
|
||||
// and the retry delay policy.
|
||||
//
|
||||
// For example to modify the default retry attempts for the standard retryer:
|
||||
//
|
||||
// // configure the custom retryer
|
||||
// customRetry := retry.NewStandard(func(o *retry.StandardOptions) {
|
||||
// o.MaxAttempts = 5
|
||||
// })
|
||||
//
|
||||
// // create a service client with the retryer
|
||||
// s3.NewFromConfig(cfg, func(o *s3.Options) {
|
||||
// o.Retryer = customRetry
|
||||
// })
|
||||
//
|
||||
// # Utilities
|
||||
//
|
||||
// A number of package functions have been provided to easily wrap retryer implementations in an implementation agnostic
|
||||
// way. These are:
|
||||
//
|
||||
// AddWithErrorCodes - Provides the ability to add additional API error codes that should be considered retryable
|
||||
// in addition to those considered retryable by the provided retryer.
|
||||
//
|
||||
// AddWithMaxAttempts - Provides the ability to set the max number of attempts for retrying a request by wrapping
|
||||
// a retryer implementation.
|
||||
//
|
||||
// AddWithMaxBackoffDelay - Provides the ability to set the max back off delay that can occur before retrying a
|
||||
// request by wrapping a retryer implementation.
|
||||
//
|
||||
// The following package functions have been provided to easily satisfy different retry interfaces to further customize
|
||||
// a given retryer's behavior:
|
||||
//
|
||||
// BackoffDelayerFunc - Can be used to wrap a function to satisfy the BackoffDelayer interface. For example,
|
||||
// you can use this method to easily create custom back off policies to be used with the
|
||||
// standard retryer.
|
||||
//
|
||||
// IsErrorRetryableFunc - Can be used to wrap a function to satisfy the IsErrorRetryable interface. For example,
|
||||
// this can be used to extend the standard retryer to add additional logic to determine if an
|
||||
// error should be retried.
|
||||
//
|
||||
// IsErrorTimeoutFunc - Can be used to wrap a function to satisfy IsErrorTimeout interface. For example,
|
||||
// this can be used to extend the standard retryer to add additional logic to determine if an
|
||||
// error should be considered a timeout.
|
||||
package retry
|
@ -0,0 +1,20 @@
|
||||
package retry
|
||||
|
||||
import "fmt"
|
||||
|
||||
// MaxAttemptsError provides the error when the maximum number of attempts have
|
||||
// been exceeded.
|
||||
type MaxAttemptsError struct {
|
||||
Attempt int
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *MaxAttemptsError) Error() string {
|
||||
return fmt.Sprintf("exceeded maximum number of attempts, %d, %v", e.Attempt, e.Err)
|
||||
}
|
||||
|
||||
// Unwrap returns the nested error causing the max attempts error. Provides the
|
||||
// implementation for errors.Is and errors.As to unwrap nested errors.
|
||||
func (e *MaxAttemptsError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/internal/rand"
|
||||
"github.com/aws/aws-sdk-go-v2/internal/timeconv"
|
||||
)
|
||||
|
||||
// ExponentialJitterBackoff provides backoff delays with jitter based on the
|
||||
// number of attempts.
|
||||
type ExponentialJitterBackoff struct {
|
||||
maxBackoff time.Duration
|
||||
// precomputed number of attempts needed to reach max backoff.
|
||||
maxBackoffAttempts float64
|
||||
|
||||
randFloat64 func() (float64, error)
|
||||
}
|
||||
|
||||
// NewExponentialJitterBackoff returns an ExponentialJitterBackoff configured
|
||||
// for the max backoff.
|
||||
func NewExponentialJitterBackoff(maxBackoff time.Duration) *ExponentialJitterBackoff {
|
||||
return &ExponentialJitterBackoff{
|
||||
maxBackoff: maxBackoff,
|
||||
maxBackoffAttempts: math.Log2(
|
||||
float64(maxBackoff) / float64(time.Second)),
|
||||
randFloat64: rand.CryptoRandFloat64,
|
||||
}
|
||||
}
|
||||
|
||||
// BackoffDelay returns the duration to wait before the next attempt should be
|
||||
// made. Returns an error if unable get a duration.
|
||||
func (j *ExponentialJitterBackoff) BackoffDelay(attempt int, err error) (time.Duration, error) {
|
||||
if attempt > int(j.maxBackoffAttempts) {
|
||||
return j.maxBackoff, nil
|
||||
}
|
||||
|
||||
b, err := j.randFloat64()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// [0.0, 1.0) * 2 ^ attempts
|
||||
ri := int64(1 << uint64(attempt))
|
||||
delaySeconds := b * float64(ri)
|
||||
|
||||
return timeconv.FloatSecondsDur(delaySeconds), nil
|
||||
}
|
@ -0,0 +1,52 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
awsmiddle "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
)
|
||||
|
||||
// attemptResultsKey is a metadata accessor key to retrieve metadata
|
||||
// for all request attempts.
|
||||
type attemptResultsKey struct {
|
||||
}
|
||||
|
||||
// GetAttemptResults retrieves attempts results from middleware metadata.
|
||||
func GetAttemptResults(metadata middleware.Metadata) (AttemptResults, bool) {
|
||||
m, ok := metadata.Get(attemptResultsKey{}).(AttemptResults)
|
||||
return m, ok
|
||||
}
|
||||
|
||||
// AttemptResults represents struct containing metadata returned by all request attempts.
|
||||
type AttemptResults struct {
|
||||
|
||||
// Results is a slice consisting attempt result from all request attempts.
|
||||
// Results are stored in order request attempt is made.
|
||||
Results []AttemptResult
|
||||
}
|
||||
|
||||
// AttemptResult represents attempt result returned by a single request attempt.
|
||||
type AttemptResult struct {
|
||||
|
||||
// Err is the error if received for the request attempt.
|
||||
Err error
|
||||
|
||||
// Retryable denotes if request may be retried. This states if an
|
||||
// error is considered retryable.
|
||||
Retryable bool
|
||||
|
||||
// Retried indicates if this request was retried.
|
||||
Retried bool
|
||||
|
||||
// ResponseMetadata is any existing metadata passed via the response middlewares.
|
||||
ResponseMetadata middleware.Metadata
|
||||
}
|
||||
|
||||
// addAttemptResults adds attempt results to middleware metadata
|
||||
func addAttemptResults(metadata *middleware.Metadata, v AttemptResults) {
|
||||
metadata.Set(attemptResultsKey{}, v)
|
||||
}
|
||||
|
||||
// GetRawResponse returns raw response recorded for the attempt result
|
||||
func (a AttemptResult) GetRawResponse() interface{} {
|
||||
return awsmiddle.GetRawResponse(a.ResponseMetadata)
|
||||
}
|
@ -0,0 +1,331 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
awsmiddle "github.com/aws/aws-sdk-go-v2/aws/middleware"
|
||||
"github.com/aws/aws-sdk-go-v2/internal/sdk"
|
||||
"github.com/aws/smithy-go/logging"
|
||||
"github.com/aws/smithy-go/middleware"
|
||||
smithymiddle "github.com/aws/smithy-go/middleware"
|
||||
"github.com/aws/smithy-go/transport/http"
|
||||
)
|
||||
|
||||
// RequestCloner is a function that can take an input request type and clone
|
||||
// the request for use in a subsequent retry attempt.
|
||||
type RequestCloner func(interface{}) interface{}
|
||||
|
||||
type retryMetadata struct {
|
||||
AttemptNum int
|
||||
AttemptTime time.Time
|
||||
MaxAttempts int
|
||||
AttemptClockSkew time.Duration
|
||||
}
|
||||
|
||||
// Attempt is a Smithy Finalize middleware that handles retry attempts using
|
||||
// the provided Retryer implementation.
|
||||
type Attempt struct {
|
||||
// Enable the logging of retry attempts performed by the SDK. This will
|
||||
// include logging retry attempts, unretryable errors, and when max
|
||||
// attempts are reached.
|
||||
LogAttempts bool
|
||||
|
||||
retryer aws.RetryerV2
|
||||
requestCloner RequestCloner
|
||||
}
|
||||
|
||||
// NewAttemptMiddleware returns a new Attempt retry middleware.
|
||||
func NewAttemptMiddleware(retryer aws.Retryer, requestCloner RequestCloner, optFns ...func(*Attempt)) *Attempt {
|
||||
m := &Attempt{
|
||||
retryer: wrapAsRetryerV2(retryer),
|
||||
requestCloner: requestCloner,
|
||||
}
|
||||
for _, fn := range optFns {
|
||||
fn(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ID returns the middleware identifier
|
||||
func (r *Attempt) ID() string { return "Retry" }
|
||||
|
||||
func (r Attempt) logf(logger logging.Logger, classification logging.Classification, format string, v ...interface{}) {
|
||||
if !r.LogAttempts {
|
||||
return
|
||||
}
|
||||
logger.Logf(classification, format, v...)
|
||||
}
|
||||
|
||||
// HandleFinalize utilizes the provider Retryer implementation to attempt
|
||||
// retries over the next handler
|
||||
func (r *Attempt) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
|
||||
out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error,
|
||||
) {
|
||||
var attemptNum int
|
||||
var attemptClockSkew time.Duration
|
||||
var attemptResults AttemptResults
|
||||
|
||||
maxAttempts := r.retryer.MaxAttempts()
|
||||
releaseRetryToken := nopRelease
|
||||
|
||||
for {
|
||||
attemptNum++
|
||||
attemptInput := in
|
||||
attemptInput.Request = r.requestCloner(attemptInput.Request)
|
||||
|
||||
// Record the metadata for the for attempt being started.
|
||||
attemptCtx := setRetryMetadata(ctx, retryMetadata{
|
||||
AttemptNum: attemptNum,
|
||||
AttemptTime: sdk.NowTime().UTC(),
|
||||
MaxAttempts: maxAttempts,
|
||||
AttemptClockSkew: attemptClockSkew,
|
||||
})
|
||||
|
||||
var attemptResult AttemptResult
|
||||
out, attemptResult, releaseRetryToken, err = r.handleAttempt(attemptCtx, attemptInput, releaseRetryToken, next)
|
||||
attemptClockSkew, _ = awsmiddle.GetAttemptSkew(attemptResult.ResponseMetadata)
|
||||
|
||||
// AttemptResult Retried states that the attempt was not successful, and
|
||||
// should be retried.
|
||||
shouldRetry := attemptResult.Retried
|
||||
|
||||
// Add attempt metadata to list of all attempt metadata
|
||||
attemptResults.Results = append(attemptResults.Results, attemptResult)
|
||||
|
||||
if !shouldRetry {
|
||||
// Ensure the last response's metadata is used as the bases for result
|
||||
// metadata returned by the stack. The Slice of attempt results
|
||||
// will be added to this cloned metadata.
|
||||
metadata = attemptResult.ResponseMetadata.Clone()
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
addAttemptResults(&metadata, attemptResults)
|
||||
return out, metadata, err
|
||||
}
|
||||
|
||||
// handleAttempt handles an individual request attempt.
|
||||
func (r *Attempt) handleAttempt(
|
||||
ctx context.Context, in smithymiddle.FinalizeInput, releaseRetryToken func(error) error, next smithymiddle.FinalizeHandler,
|
||||
) (
|
||||
out smithymiddle.FinalizeOutput, attemptResult AttemptResult, _ func(error) error, err error,
|
||||
) {
|
||||
defer func() {
|
||||
attemptResult.Err = err
|
||||
}()
|
||||
|
||||
// Short circuit if this attempt never can succeed because the context is
|
||||
// canceled. This reduces the chance of token pools being modified for
|
||||
// attempts that will not be made
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return out, attemptResult, nopRelease, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
//------------------------------
|
||||
// Get Attempt Token
|
||||
//------------------------------
|
||||
releaseAttemptToken, err := r.retryer.GetAttemptToken(ctx)
|
||||
if err != nil {
|
||||
return out, attemptResult, nopRelease, fmt.Errorf(
|
||||
"failed to get retry Send token, %w", err)
|
||||
}
|
||||
|
||||
//------------------------------
|
||||
// Send Attempt
|
||||
//------------------------------
|
||||
logger := smithymiddle.GetLogger(ctx)
|
||||
service, operation := awsmiddle.GetServiceID(ctx), awsmiddle.GetOperationName(ctx)
|
||||
retryMetadata, _ := getRetryMetadata(ctx)
|
||||
attemptNum := retryMetadata.AttemptNum
|
||||
maxAttempts := retryMetadata.MaxAttempts
|
||||
|
||||
// Following attempts must ensure the request payload stream starts in a
|
||||
// rewound state.
|
||||
if attemptNum > 1 {
|
||||
if rewindable, ok := in.Request.(interface{ RewindStream() error }); ok {
|
||||
if rewindErr := rewindable.RewindStream(); rewindErr != nil {
|
||||
return out, attemptResult, nopRelease, fmt.Errorf(
|
||||
"failed to rewind transport stream for retry, %w", rewindErr)
|
||||
}
|
||||
}
|
||||
|
||||
r.logf(logger, logging.Debug, "retrying request %s/%s, attempt %d",
|
||||
service, operation, attemptNum)
|
||||
}
|
||||
|
||||
var metadata smithymiddle.Metadata
|
||||
out, metadata, err = next.HandleFinalize(ctx, in)
|
||||
attemptResult.ResponseMetadata = metadata
|
||||
|
||||
//------------------------------
|
||||
// Bookkeeping
|
||||
//------------------------------
|
||||
// Release the retry token based on the state of the attempt's error (if any).
|
||||
if releaseError := releaseRetryToken(err); releaseError != nil && err != nil {
|
||||
return out, attemptResult, nopRelease, fmt.Errorf(
|
||||
"failed to release retry token after request error, %w", err)
|
||||
}
|
||||
// Release the attempt token based on the state of the attempt's error (if any).
|
||||
if releaseError := releaseAttemptToken(err); releaseError != nil && err != nil {
|
||||
return out, attemptResult, nopRelease, fmt.Errorf(
|
||||
"failed to release initial token after request error, %w", err)
|
||||
}
|
||||
// If there was no error making the attempt, nothing further to do. There
|
||||
// will be nothing to retry.
|
||||
if err == nil {
|
||||
return out, attemptResult, nopRelease, err
|
||||
}
|
||||
|
||||
//------------------------------
|
||||
// Is Retryable and Should Retry
|
||||
//------------------------------
|
||||
// If the attempt failed with an unretryable error, nothing further to do
|
||||
// but return, and inform the caller about the terminal failure.
|
||||
retryable := r.retryer.IsErrorRetryable(err)
|
||||
if !retryable {
|
||||
r.logf(logger, logging.Debug, "request failed with unretryable error %v", err)
|
||||
return out, attemptResult, nopRelease, err
|
||||
}
|
||||
|
||||
// set retryable to true
|
||||
attemptResult.Retryable = true
|
||||
|
||||
// Once the maximum number of attempts have been exhausted there is nothing
|
||||
// further to do other than inform the caller about the terminal failure.
|
||||
if maxAttempts > 0 && attemptNum >= maxAttempts {
|
||||
r.logf(logger, logging.Debug, "max retry attempts exhausted, max %d", maxAttempts)
|
||||
err = &MaxAttemptsError{
|
||||
Attempt: attemptNum,
|
||||
Err: err,
|
||||
}
|
||||
return out, attemptResult, nopRelease, err
|
||||
}
|
||||
|
||||
//------------------------------
|
||||
// Get Retry (aka Retry Quota) Token
|
||||
//------------------------------
|
||||
// Get a retry token that will be released after the
|
||||
releaseRetryToken, retryTokenErr := r.retryer.GetRetryToken(ctx, err)
|
||||
if retryTokenErr != nil {
|
||||
return out, attemptResult, nopRelease, retryTokenErr
|
||||
}
|
||||
|
||||
//------------------------------
|
||||
// Retry Delay and Sleep
|
||||
//------------------------------
|
||||
// Get the retry delay before another attempt can be made, and sleep for
|
||||
// that time. Potentially early exist if the sleep is canceled via the
|
||||
// context.
|
||||
retryDelay, reqErr := r.retryer.RetryDelay(attemptNum, err)
|
||||
if reqErr != nil {
|
||||
return out, attemptResult, releaseRetryToken, reqErr
|
||||
}
|
||||
if reqErr = sdk.SleepWithContext(ctx, retryDelay); reqErr != nil {
|
||||
err = &aws.RequestCanceledError{Err: reqErr}
|
||||
return out, attemptResult, releaseRetryToken, err
|
||||
}
|
||||
|
||||
// The request should be re-attempted.
|
||||
attemptResult.Retried = true
|
||||
|
||||
return out, attemptResult, releaseRetryToken, err
|
||||
}
|
||||
|
||||
// MetricsHeader attaches SDK request metric header for retries to the transport
|
||||
type MetricsHeader struct{}
|
||||
|
||||
// ID returns the middleware identifier
|
||||
func (r *MetricsHeader) ID() string {
|
||||
return "RetryMetricsHeader"
|
||||
}
|
||||
|
||||
// HandleFinalize attaches the SDK request metric header to the transport layer
|
||||
func (r MetricsHeader) HandleFinalize(ctx context.Context, in smithymiddle.FinalizeInput, next smithymiddle.FinalizeHandler) (
|
||||
out smithymiddle.FinalizeOutput, metadata smithymiddle.Metadata, err error,
|
||||
) {
|
||||
retryMetadata, _ := getRetryMetadata(ctx)
|
||||
|
||||
const retryMetricHeader = "Amz-Sdk-Request"
|
||||
var parts []string
|
||||
|
||||
parts = append(parts, "attempt="+strconv.Itoa(retryMetadata.AttemptNum))
|
||||
if retryMetadata.MaxAttempts != 0 {
|
||||
parts = append(parts, "max="+strconv.Itoa(retryMetadata.MaxAttempts))
|
||||
}
|
||||
|
||||
var ttl time.Time
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
ttl = deadline
|
||||
}
|
||||
|
||||
// Only append the TTL if it can be determined.
|
||||
if !ttl.IsZero() && retryMetadata.AttemptClockSkew > 0 {
|
||||
const unixTimeFormat = "20060102T150405Z"
|
||||
ttl = ttl.Add(retryMetadata.AttemptClockSkew)
|
||||
parts = append(parts, "ttl="+ttl.Format(unixTimeFormat))
|
||||
}
|
||||
|
||||
switch req := in.Request.(type) {
|
||||
case *http.Request:
|
||||
req.Header[retryMetricHeader] = append(req.Header[retryMetricHeader][:0], strings.Join(parts, "; "))
|
||||
default:
|
||||
return out, metadata, fmt.Errorf("unknown transport type %T", req)
|
||||
}
|
||||
|
||||
return next.HandleFinalize(ctx, in)
|
||||
}
|
||||
|
||||
type retryMetadataKey struct{}
|
||||
|
||||
// getRetryMetadata retrieves retryMetadata from the context and a bool
|
||||
// indicating if it was set.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func getRetryMetadata(ctx context.Context) (metadata retryMetadata, ok bool) {
|
||||
metadata, ok = middleware.GetStackValue(ctx, retryMetadataKey{}).(retryMetadata)
|
||||
return metadata, ok
|
||||
}
|
||||
|
||||
// setRetryMetadata sets the retryMetadata on the context.
|
||||
//
|
||||
// Scoped to stack values. Use github.com/aws/smithy-go/middleware#ClearStackValues
|
||||
// to clear all stack values.
|
||||
func setRetryMetadata(ctx context.Context, metadata retryMetadata) context.Context {
|
||||
return middleware.WithStackValue(ctx, retryMetadataKey{}, metadata)
|
||||
}
|
||||
|
||||
// AddRetryMiddlewaresOptions is the set of options that can be passed to
|
||||
// AddRetryMiddlewares for configuring retry associated middleware.
|
||||
type AddRetryMiddlewaresOptions struct {
|
||||
Retryer aws.Retryer
|
||||
|
||||
// Enable the logging of retry attempts performed by the SDK. This will
|
||||
// include logging retry attempts, unretryable errors, and when max
|
||||
// attempts are reached.
|
||||
LogRetryAttempts bool
|
||||
}
|
||||
|
||||
// AddRetryMiddlewares adds retry middleware to operation middleware stack
|
||||
func AddRetryMiddlewares(stack *smithymiddle.Stack, options AddRetryMiddlewaresOptions) error {
|
||||
attempt := NewAttemptMiddleware(options.Retryer, http.RequestCloner, func(middleware *Attempt) {
|
||||
middleware.LogAttempts = options.LogRetryAttempts
|
||||
})
|
||||
|
||||
if err := stack.Finalize.Add(attempt, smithymiddle.After); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := stack.Finalize.Add(&MetricsHeader{}, smithymiddle.After); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,90 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
)
|
||||
|
||||
// AddWithErrorCodes returns a Retryer with additional error codes considered
|
||||
// for determining if the error should be retried.
|
||||
func AddWithErrorCodes(r aws.Retryer, codes ...string) aws.Retryer {
|
||||
retryable := &RetryableErrorCode{
|
||||
Codes: map[string]struct{}{},
|
||||
}
|
||||
for _, c := range codes {
|
||||
retryable.Codes[c] = struct{}{}
|
||||
}
|
||||
|
||||
return &withIsErrorRetryable{
|
||||
RetryerV2: wrapAsRetryerV2(r),
|
||||
Retryable: retryable,
|
||||
}
|
||||
}
|
||||
|
||||
type withIsErrorRetryable struct {
|
||||
aws.RetryerV2
|
||||
Retryable IsErrorRetryable
|
||||
}
|
||||
|
||||
func (r *withIsErrorRetryable) IsErrorRetryable(err error) bool {
|
||||
if v := r.Retryable.IsErrorRetryable(err); v != aws.UnknownTernary {
|
||||
return v.Bool()
|
||||
}
|
||||
return r.RetryerV2.IsErrorRetryable(err)
|
||||
}
|
||||
|
||||
// AddWithMaxAttempts returns a Retryer with MaxAttempts set to the value
|
||||
// specified.
|
||||
func AddWithMaxAttempts(r aws.Retryer, max int) aws.Retryer {
|
||||
return &withMaxAttempts{
|
||||
RetryerV2: wrapAsRetryerV2(r),
|
||||
Max: max,
|
||||
}
|
||||
}
|
||||
|
||||
type withMaxAttempts struct {
|
||||
aws.RetryerV2
|
||||
Max int
|
||||
}
|
||||
|
||||
func (w *withMaxAttempts) MaxAttempts() int {
|
||||
return w.Max
|
||||
}
|
||||
|
||||
// AddWithMaxBackoffDelay returns a retryer wrapping the passed in retryer
|
||||
// overriding the RetryDelay behavior for a alternate minimum initial backoff
|
||||
// delay.
|
||||
func AddWithMaxBackoffDelay(r aws.Retryer, delay time.Duration) aws.Retryer {
|
||||
return &withMaxBackoffDelay{
|
||||
RetryerV2: wrapAsRetryerV2(r),
|
||||
backoff: NewExponentialJitterBackoff(delay),
|
||||
}
|
||||
}
|
||||
|
||||
type withMaxBackoffDelay struct {
|
||||
aws.RetryerV2
|
||||
backoff *ExponentialJitterBackoff
|
||||
}
|
||||
|
||||
func (r *withMaxBackoffDelay) RetryDelay(attempt int, err error) (time.Duration, error) {
|
||||
return r.backoff.BackoffDelay(attempt, err)
|
||||
}
|
||||
|
||||
type wrappedAsRetryerV2 struct {
|
||||
aws.Retryer
|
||||
}
|
||||
|
||||
func wrapAsRetryerV2(r aws.Retryer) aws.RetryerV2 {
|
||||
v, ok := r.(aws.RetryerV2)
|
||||
if !ok {
|
||||
v = wrappedAsRetryerV2{Retryer: r}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (w wrappedAsRetryerV2) GetAttemptToken(context.Context) (func(error) error, error) {
|
||||
return w.Retryer.GetInitialToken(), nil
|
||||
}
|
@ -0,0 +1,186 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
)
|
||||
|
||||
// IsErrorRetryable provides the interface of an implementation to determine if
|
||||
// a error as the result of an operation is retryable.
|
||||
type IsErrorRetryable interface {
|
||||
IsErrorRetryable(error) aws.Ternary
|
||||
}
|
||||
|
||||
// IsErrorRetryables is a collection of checks to determine of the error is
|
||||
// retryable. Iterates through the checks and returns the state of retryable
|
||||
// if any check returns something other than unknown.
|
||||
type IsErrorRetryables []IsErrorRetryable
|
||||
|
||||
// IsErrorRetryable returns if the error is retryable if any of the checks in
|
||||
// the list return a value other than unknown.
|
||||
func (r IsErrorRetryables) IsErrorRetryable(err error) aws.Ternary {
|
||||
for _, re := range r {
|
||||
if v := re.IsErrorRetryable(err); v != aws.UnknownTernary {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
// IsErrorRetryableFunc wraps a function with the IsErrorRetryable interface.
|
||||
type IsErrorRetryableFunc func(error) aws.Ternary
|
||||
|
||||
// IsErrorRetryable returns if the error is retryable.
|
||||
func (fn IsErrorRetryableFunc) IsErrorRetryable(err error) aws.Ternary {
|
||||
return fn(err)
|
||||
}
|
||||
|
||||
// RetryableError is an IsErrorRetryable implementation which uses the
|
||||
// optional interface Retryable on the error value to determine if the error is
|
||||
// retryable.
|
||||
type RetryableError struct{}
|
||||
|
||||
// IsErrorRetryable returns if the error is retryable if it satisfies the
|
||||
// Retryable interface, and returns if the attempt should be retried.
|
||||
func (RetryableError) IsErrorRetryable(err error) aws.Ternary {
|
||||
var v interface{ RetryableError() bool }
|
||||
|
||||
if !errors.As(err, &v) {
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
return aws.BoolTernary(v.RetryableError())
|
||||
}
|
||||
|
||||
// NoRetryCanceledError detects if the error was an request canceled error and
|
||||
// returns if so.
|
||||
type NoRetryCanceledError struct{}
|
||||
|
||||
// IsErrorRetryable returns the error is not retryable if the request was
|
||||
// canceled.
|
||||
func (NoRetryCanceledError) IsErrorRetryable(err error) aws.Ternary {
|
||||
var v interface{ CanceledError() bool }
|
||||
|
||||
if !errors.As(err, &v) {
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
if v.CanceledError() {
|
||||
return aws.FalseTernary
|
||||
}
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
// RetryableConnectionError determines if the underlying error is an HTTP
|
||||
// connection and returns if it should be retried.
|
||||
//
|
||||
// Includes errors such as connection reset, connection refused, net dial,
|
||||
// temporary, and timeout errors.
|
||||
type RetryableConnectionError struct{}
|
||||
|
||||
// IsErrorRetryable returns if the error is caused by and HTTP connection
|
||||
// error, and should be retried.
|
||||
func (r RetryableConnectionError) IsErrorRetryable(err error) aws.Ternary {
|
||||
if err == nil {
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
var retryable bool
|
||||
|
||||
var conErr interface{ ConnectionError() bool }
|
||||
var tempErr interface{ Temporary() bool }
|
||||
var timeoutErr interface{ Timeout() bool }
|
||||
var urlErr *url.Error
|
||||
var netOpErr *net.OpError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &conErr) && conErr.ConnectionError():
|
||||
retryable = true
|
||||
|
||||
case strings.Contains(err.Error(), "connection reset"):
|
||||
retryable = true
|
||||
|
||||
case errors.As(err, &urlErr):
|
||||
// Refused connections should be retried as the service may not yet be
|
||||
// running on the port. Go TCP dial considers refused connections as
|
||||
// not temporary.
|
||||
if strings.Contains(urlErr.Error(), "connection refused") {
|
||||
retryable = true
|
||||
} else {
|
||||
return r.IsErrorRetryable(errors.Unwrap(urlErr))
|
||||
}
|
||||
|
||||
case errors.As(err, &netOpErr):
|
||||
// Network dial, or temporary network errors are always retryable.
|
||||
if strings.EqualFold(netOpErr.Op, "dial") || netOpErr.Temporary() {
|
||||
retryable = true
|
||||
} else {
|
||||
return r.IsErrorRetryable(errors.Unwrap(netOpErr))
|
||||
}
|
||||
|
||||
case errors.As(err, &tempErr) && tempErr.Temporary():
|
||||
// Fallback to the generic temporary check, with temporary errors
|
||||
// retryable.
|
||||
retryable = true
|
||||
|
||||
case errors.As(err, &timeoutErr) && timeoutErr.Timeout():
|
||||
// Fallback to the generic timeout check, with timeout errors
|
||||
// retryable.
|
||||
retryable = true
|
||||
|
||||
default:
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
return aws.BoolTernary(retryable)
|
||||
|
||||
}
|
||||
|
||||
// RetryableHTTPStatusCode provides a IsErrorRetryable based on HTTP status
|
||||
// codes.
|
||||
type RetryableHTTPStatusCode struct {
|
||||
Codes map[int]struct{}
|
||||
}
|
||||
|
||||
// IsErrorRetryable return if the passed in error is retryable based on the
|
||||
// HTTP status code.
|
||||
func (r RetryableHTTPStatusCode) IsErrorRetryable(err error) aws.Ternary {
|
||||
var v interface{ HTTPStatusCode() int }
|
||||
|
||||
if !errors.As(err, &v) {
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
_, ok := r.Codes[v.HTTPStatusCode()]
|
||||
if !ok {
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
return aws.TrueTernary
|
||||
}
|
||||
|
||||
// RetryableErrorCode determines if an attempt should be retried based on the
|
||||
// API error code.
|
||||
type RetryableErrorCode struct {
|
||||
Codes map[string]struct{}
|
||||
}
|
||||
|
||||
// IsErrorRetryable return if the error is retryable based on the error codes.
|
||||
// Returns unknown if the error doesn't have a code or it is unknown.
|
||||
func (r RetryableErrorCode) IsErrorRetryable(err error) aws.Ternary {
|
||||
var v interface{ ErrorCode() string }
|
||||
|
||||
if !errors.As(err, &v) {
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
_, ok := r.Codes[v.ErrorCode()]
|
||||
if !ok {
|
||||
return aws.UnknownTernary
|
||||
}
|
||||
|
||||
return aws.TrueTernary
|
||||
}
|
@ -0,0 +1,258 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws/ratelimit"
|
||||
)
|
||||
|
||||
// BackoffDelayer provides the interface for determining the delay to before
|
||||
// another request attempt, that previously failed.
|
||||
type BackoffDelayer interface {
|
||||
BackoffDelay(attempt int, err error) (time.Duration, error)
|
||||
}
|
||||
|
||||
// BackoffDelayerFunc provides a wrapper around a function to determine the
|
||||
// backoff delay of an attempt retry.
|
||||
type BackoffDelayerFunc func(int, error) (time.Duration, error)
|
||||
|
||||
// BackoffDelay returns the delay before attempt to retry a request.
|
||||
func (fn BackoffDelayerFunc) BackoffDelay(attempt int, err error) (time.Duration, error) {
|
||||
return fn(attempt, err)
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultMaxAttempts is the maximum of attempts for an API request
|
||||
DefaultMaxAttempts int = 3
|
||||
|
||||
// DefaultMaxBackoff is the maximum back off delay between attempts
|
||||
DefaultMaxBackoff time.Duration = 20 * time.Second
|
||||
)
|
||||
|
||||
// Default retry token quota values.
|
||||
const (
|
||||
DefaultRetryRateTokens uint = 500
|
||||
DefaultRetryCost uint = 5
|
||||
DefaultRetryTimeoutCost uint = 10
|
||||
DefaultNoRetryIncrement uint = 1
|
||||
)
|
||||
|
||||
// DefaultRetryableHTTPStatusCodes is the default set of HTTP status codes the SDK
|
||||
// should consider as retryable errors.
|
||||
var DefaultRetryableHTTPStatusCodes = map[int]struct{}{
|
||||
500: {},
|
||||
502: {},
|
||||
503: {},
|
||||
504: {},
|
||||
}
|
||||
|
||||
// DefaultRetryableErrorCodes provides the set of API error codes that should
|
||||
// be retried.
|
||||
var DefaultRetryableErrorCodes = map[string]struct{}{
|
||||
"RequestTimeout": {},
|
||||
"RequestTimeoutException": {},
|
||||
}
|
||||
|
||||
// DefaultThrottleErrorCodes provides the set of API error codes that are
|
||||
// considered throttle errors.
|
||||
var DefaultThrottleErrorCodes = map[string]struct{}{
|
||||
"Throttling": {},
|
||||
"ThrottlingException": {},
|
||||
"ThrottledException": {},
|
||||
"RequestThrottledException": {},
|
||||
"TooManyRequestsException": {},
|
||||
"ProvisionedThroughputExceededException": {},
|
||||
"TransactionInProgressException": {},
|
||||
"RequestLimitExceeded": {},
|
||||
"BandwidthLimitExceeded": {},
|
||||
"LimitExceededException": {},
|
||||
"RequestThrottled": {},
|
||||
"SlowDown": {},
|
||||
"PriorRequestNotComplete": {},
|
||||
"EC2ThrottledException": {},
|
||||
}
|
||||
|
||||
// DefaultRetryables provides the set of retryable checks that are used by
|
||||
// default.
|
||||
var DefaultRetryables = []IsErrorRetryable{
|
||||
NoRetryCanceledError{},
|
||||
RetryableError{},
|
||||
RetryableConnectionError{},
|
||||
RetryableHTTPStatusCode{
|
||||
Codes: DefaultRetryableHTTPStatusCodes,
|
||||
},
|
||||
RetryableErrorCode{
|
||||
Codes: DefaultRetryableErrorCodes,
|
||||
},
|
||||
RetryableErrorCode{
|
||||
Codes: DefaultThrottleErrorCodes,
|
||||
},
|
||||
}
|
||||
|
||||
// DefaultTimeouts provides the set of timeout checks that are used by default.
|
||||
var DefaultTimeouts = []IsErrorTimeout{
|
||||
TimeouterError{},
|
||||
}
|
||||
|
||||
// StandardOptions provides the functional options for configuring the standard
|
||||
// retryable, and delay behavior.
|
||||
type StandardOptions struct {
|
||||
// Maximum number of attempts that should be made.
|
||||
MaxAttempts int
|
||||
|
||||
// MaxBackoff duration between retried attempts.
|
||||
MaxBackoff time.Duration
|
||||
|
||||
// Provides the backoff strategy the retryer will use to determine the
|
||||
// delay between retry attempts.
|
||||
Backoff BackoffDelayer
|
||||
|
||||
// Set of strategies to determine if the attempt should be retried based on
|
||||
// the error response received.
|
||||
//
|
||||
// It is safe to append to this list in NewStandard's functional options.
|
||||
Retryables []IsErrorRetryable
|
||||
|
||||
// Set of strategies to determine if the attempt failed due to a timeout
|
||||
// error.
|
||||
//
|
||||
// It is safe to append to this list in NewStandard's functional options.
|
||||
Timeouts []IsErrorTimeout
|
||||
|
||||
// Provides the rate limiting strategy for rate limiting attempt retries
|
||||
// across all attempts the retryer is being used with.
|
||||
RateLimiter RateLimiter
|
||||
|
||||
// The cost to deduct from the RateLimiter's token bucket per retry.
|
||||
RetryCost uint
|
||||
|
||||
// The cost to deduct from the RateLimiter's token bucket per retry caused
|
||||
// by timeout error.
|
||||
RetryTimeoutCost uint
|
||||
|
||||
// The cost to payback to the RateLimiter's token bucket for successful
|
||||
// attempts.
|
||||
NoRetryIncrement uint
|
||||
}
|
||||
|
||||
// RateLimiter provides the interface for limiting the rate of attempt retries
|
||||
// allowed by the retryer.
|
||||
type RateLimiter interface {
|
||||
GetToken(ctx context.Context, cost uint) (releaseToken func() error, err error)
|
||||
AddTokens(uint) error
|
||||
}
|
||||
|
||||
// Standard is the standard retry pattern for the SDK. It uses a set of
|
||||
// retryable checks to determine of the failed attempt should be retried, and
|
||||
// what retry delay should be used.
|
||||
type Standard struct {
|
||||
options StandardOptions
|
||||
|
||||
timeout IsErrorTimeout
|
||||
retryable IsErrorRetryable
|
||||
backoff BackoffDelayer
|
||||
}
|
||||
|
||||
// NewStandard initializes a standard retry behavior with defaults that can be
|
||||
// overridden via functional options.
|
||||
func NewStandard(fnOpts ...func(*StandardOptions)) *Standard {
|
||||
o := StandardOptions{
|
||||
MaxAttempts: DefaultMaxAttempts,
|
||||
MaxBackoff: DefaultMaxBackoff,
|
||||
Retryables: append([]IsErrorRetryable{}, DefaultRetryables...),
|
||||
Timeouts: append([]IsErrorTimeout{}, DefaultTimeouts...),
|
||||
|
||||
RateLimiter: ratelimit.NewTokenRateLimit(DefaultRetryRateTokens),
|
||||
RetryCost: DefaultRetryCost,
|
||||
RetryTimeoutCost: DefaultRetryTimeoutCost,
|
||||
NoRetryIncrement: DefaultNoRetryIncrement,
|
||||
}
|
||||
for _, fn := range fnOpts {
|
||||
fn(&o)
|
||||
}
|
||||
if o.MaxAttempts <= 0 {
|
||||
o.MaxAttempts = DefaultMaxAttempts
|
||||
}
|
||||
|
||||
backoff := o.Backoff
|
||||
if backoff == nil {
|
||||
backoff = NewExponentialJitterBackoff(o.MaxBackoff)
|
||||
}
|
||||
|
||||
return &Standard{
|
||||
options: o,
|
||||
backoff: backoff,
|
||||
retryable: IsErrorRetryables(o.Retryables),
|
||||
timeout: IsErrorTimeouts(o.Timeouts),
|
||||
}
|
||||
}
|
||||
|
||||
// MaxAttempts returns the maximum number of attempts that can be made for a
|
||||
// request before failing.
|
||||
func (s *Standard) MaxAttempts() int {
|
||||
return s.options.MaxAttempts
|
||||
}
|
||||
|
||||
// IsErrorRetryable returns if the error is can be retried or not. Should not
|
||||
// consider the number of attempts made.
|
||||
func (s *Standard) IsErrorRetryable(err error) bool {
|
||||
return s.retryable.IsErrorRetryable(err).Bool()
|
||||
}
|
||||
|
||||
// RetryDelay returns the delay to use before another request attempt is made.
|
||||
func (s *Standard) RetryDelay(attempt int, err error) (time.Duration, error) {
|
||||
return s.backoff.BackoffDelay(attempt, err)
|
||||
}
|
||||
|
||||
// GetAttemptToken returns the token to be released after then attempt completes.
|
||||
// The release token will add NoRetryIncrement to the RateLimiter token pool if
|
||||
// the attempt was successful. If the attempt failed, nothing will be done.
|
||||
func (s *Standard) GetAttemptToken(context.Context) (func(error) error, error) {
|
||||
return s.GetInitialToken(), nil
|
||||
}
|
||||
|
||||
// GetInitialToken returns a token for adding the NoRetryIncrement to the
|
||||
// RateLimiter token if the attempt completed successfully without error.
|
||||
//
|
||||
// InitialToken applies to result of the each attempt, including the first.
|
||||
// Whereas the RetryToken applies to the result of subsequent attempts.
|
||||
//
|
||||
// Deprecated: use GetAttemptToken instead.
|
||||
func (s *Standard) GetInitialToken() func(error) error {
|
||||
return releaseToken(s.noRetryIncrement).release
|
||||
}
|
||||
|
||||
func (s *Standard) noRetryIncrement() error {
|
||||
return s.options.RateLimiter.AddTokens(s.options.NoRetryIncrement)
|
||||
}
|
||||
|
||||
// GetRetryToken attempts to deduct the retry cost from the retry token pool.
|
||||
// Returning the token release function, or error.
|
||||
func (s *Standard) GetRetryToken(ctx context.Context, opErr error) (func(error) error, error) {
|
||||
cost := s.options.RetryCost
|
||||
|
||||
if s.timeout.IsErrorTimeout(opErr).Bool() {
|
||||
cost = s.options.RetryTimeoutCost
|
||||
}
|
||||
|
||||
fn, err := s.options.RateLimiter.GetToken(ctx, cost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rate limit token, %w", err)
|
||||
}
|
||||
|
||||
return releaseToken(fn).release, nil
|
||||
}
|
||||
|
||||
func nopRelease(error) error { return nil }
|
||||
|
||||
type releaseToken func() error
|
||||
|
||||
func (f releaseToken) release(err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return f()
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue