You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-library/vendor/github.com/tencentyun/cos-go-sdk-v5/auth.go

547 lines
15 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package cos
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/json"
"fmt"
"hash"
"io/ioutil"
math_rand "math/rand"
"net"
"net/http"
"net/url"
"regexp"
"sort"
"strings"
"sync"
"time"
)
const (
	// defaultAuthExpire is the signature validity window the transports in
	// this file pass to NewAuthTime.
	defaultAuthExpire = time.Hour
	// privateHeaderPrefix marks COS private headers; isSignHeader signs any
	// header carrying it.
	privateHeaderPrefix = "x-cos-"
	// sha1SignAlgorithm is written into q-sign-algorithm and selects the
	// digest used by the signing helpers.
	sha1SignAlgorithm = "sha1"
)
var (
	// internalHost matches COS internal-network endpoints of the form
	// "...cos-internal.<region>.tencentcos.cn". The region character class
	// accepts lowercase letters, '-' and '1'. Requests to these hosts go
	// through DNSScatterTransport when no explicit Transport is configured.
	internalHost = regexp.MustCompile(`^.*cos-internal\.[a-z-1]+\.tencentcos\.cn$`)

	// Location of the CVM instance metadata service that hands out temporary
	// CAM credentials.
	defaultCVMSchema   = "http"
	defaultCVMMetaHost = "metadata.tencentyun.com"
	defaultCVMCredURI  = "latest/meta-data/cam/security-credentials"

	// defaultCVMAuthExpire is how many seconds before expiry the cached CVM
	// credential is refreshed.
	defaultCVMAuthExpire = int64(600)
)
var (
	// DNSScatterDialContext is the dial function installed into
	// DNSScatterTransport. The transport copies this value when the package
	// initializes, so reassigning the variable later does not affect it.
	DNSScatterDialContext = DNSScatterDialContextFunc

	// DNSScatterTransport dials through DNSScatterDialContext, which starts at
	// a randomly chosen resolved IP of the target host. It is used for COS
	// internal-network hosts when no explicit Transport is configured.
	DNSScatterTransport = &http.Transport{
		DialContext:           DNSScatterDialContext,
		Proxy:                 http.ProxyFromEnvironment,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
		IdleConnTimeout:       90 * time.Second,
		MaxIdleConns:          100,
	}
)
// init seeds the global math/rand source from the wall clock so that the
// random starting IP chosen by DNSScatterDialContextFunc varies per process.
// Before Go 1.20 the global source is otherwise deterministic.
func init() {
	seed := time.Now().UnixNano()
	math_rand.Seed(seed)
}
// DNSScatterDialContextFunc resolves the host in addr and dials its IPs in
// round-robin order, starting from a randomly chosen one. This spreads
// connections across all addresses of a host ("DNS scattering").
//
// It returns the first successful connection. If every IP fails, it returns
// the last dial error. It stops early once ctx is done. A lookup that yields
// no addresses is reported as a *net.DNSError instead of panicking in
// math/rand.Intn.
func DNSScatterDialContextFunc(ctx context.Context, network string, addr string) (conn net.Conn, err error) {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
	if err != nil {
		return nil, err
	}
	if len(ips) == 0 {
		// math_rand.Intn(0) would panic; report it like a failed lookup.
		return nil, &net.DNSError{Err: "no such host", Name: host}
	}
	// DualStack is intentionally not set: it is deprecated, and it only
	// affects hostname dialing, while we always dial IP literals.
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	start := math_rand.Intn(len(ips))
	for i := 0; i < len(ips); i++ {
		ip := ips[(start+i)%len(ips)].IP.String()
		conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
		if err == nil {
			return conn, nil
		}
		// Once the context is canceled or past its deadline, every remaining
		// attempt would fail immediately; stop and report the last error.
		if ctx.Err() != nil {
			break
		}
	}
	return nil, err
}
// NeedSignHeaders is the set of (lowercased) request headers that are
// included in the signature when present on a request. An entry mapped to
// false is ignored. Headers starting with "x-cos-" are signed even when they
// are not listed here.
var NeedSignHeaders = map[string]bool{
	// Standard HTTP headers.
	"cache-control":       true,
	"content-disposition": true,
	"content-encoding":    true,
	"content-length":      true,
	"content-md5":         true,
	"content-type":        true,
	"expect":              true,
	"expires":             true,
	"host":                true,
	"if-match":            true,
	"if-modified-since":   true,
	"if-none-match":       true,
	"if-unmodified-since": true,
	"range":               true,
	"transfer-encoding":   true,

	// CORS-related headers.
	"access-control-request-headers": true,
	"access-control-request-method":  true,
	"origin":                         true,

	// COS-specific headers.
	"x-cos-acl":                true,
	"x-cos-content-sha1":       true,
	"x-cos-grant-full-control": true,
	"x-cos-grant-read":         true,
	"x-cos-grant-write":        true,
	"x-cos-object-type":        true,
	"x-cos-storage-class":      true,
}
// SetNeedSignHeaders enables (val=true) or disables (val=false) signing of the
// header named key.
//
// The key is stored lowercased. isSignHeader always looks headers up by their
// lowercased name, so a mixed-case entry such as "X-Custom" would otherwise
// never match. Headers with the "x-cos-" prefix are signed regardless of this
// setting.
//
// Not thread-safe: NeedSignHeaders is a plain package-level map. Call this
// only during process initialization, not when a Client is created or in use.
func SetNeedSignHeaders(key string, val bool) {
	NeedSignHeaders[strings.ToLower(key)] = val
}
// safeURLReplacer percent-encodes the characters ! ' ( ) * in a single pass.
// It is applied to the output of encodeURIComponent, which presumably leaves
// these characters unescaped (as JavaScript's encodeURIComponent does).
// TODO confirm against its definition. A Replacer is safe for concurrent use.
var safeURLReplacer = strings.NewReplacer(
	"!", "%21",
	"'", "%27",
	"(", "%28",
	")", "%29",
	"*", "%2A",
)

// safeURLEncode encodes s with encodeURIComponent and then also
// percent-encodes ! ' ( ) *, producing the form used when building signatures.
func safeURLEncode(s string) string {
	return safeURLReplacer.Replace(encodeURIComponent(s))
}
// valuesSignMap collects key/value pairs for signing. Keys are stored in
// their URL-encoded, lowercased form. Values are stored raw and URL-encoded
// only when the map is rendered by Encode.
type valuesSignMap map[string][]string

// Add appends value under the URL-encoded, lowercased form of key.
func (vs valuesSignMap) Add(key, value string) {
	normalized := strings.ToLower(safeURLEncode(key))
	vs[normalized] = append(vs[normalized], value)
}

// Encode renders the map as "key=value" pairs joined by "&". Keys appear in
// sorted order. Within a key, values are sorted (in place, so the map's value
// slices are reordered) and then URL-encoded.
func (vs valuesSignMap) Encode() string {
	keys := make([]string, 0, len(vs))
	for k := range vs {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var out strings.Builder
	for _, k := range keys {
		vals := vs[k]
		sort.Strings(vals)
		for _, v := range vals {
			// Every pair contains '=', so a non-empty builder means a pair
			// was already written and needs a separator.
			if out.Len() > 0 {
				out.WriteByte('&')
			}
			out.WriteString(k)
			out.WriteByte('=')
			out.WriteString(safeURLEncode(v))
		}
	}
	return out.String()
}
// AuthTime holds the validity windows used when signing a request.
// SignStartTime/SignEndTime produce the q-sign-time value and
// KeyStartTime/KeyEndTime produce the q-key-time value.
type AuthTime struct {
	SignStartTime, SignEndTime time.Time
	KeyStartTime, KeyEndTime   time.Time
}

// NewAuthTime returns an AuthTime whose sign and key windows both start now
// and end after expire.
func NewAuthTime(expire time.Duration) *AuthTime {
	now := time.Now()
	until := now.Add(expire)
	return &AuthTime{
		SignStartTime: now,
		SignEndTime:   until,
		KeyStartTime:  now,
		KeyEndTime:    until,
	}
}

// signString renders the sign window as "<start-unix>;<end-unix>" (q-sign-time).
func (a *AuthTime) signString() string {
	start, end := a.SignStartTime.Unix(), a.SignEndTime.Unix()
	return fmt.Sprintf("%d;%d", start, end)
}

// keyString renders the key window as "<start-unix>;<end-unix>" (q-key-time).
func (a *AuthTime) keyString() string {
	start, end := a.KeyStartTime.Unix(), a.KeyEndTime.Unix()
	return fmt.Sprintf("%d;%d", start, end)
}
// newAuthorization computes the value of the Authorization header for req.
// When signHost is true it first copies req.Host into the "Host" header so the
// host is covered by the signature (this mutates req.Header).
//
// NOTE: the key time, not the sign time, is embedded in StringToSign. The two
// are identical when authTime comes from NewAuthTime.
func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime, signHost bool) string {
	keyTime := authTime.keyString()
	if signHost {
		req.Header.Set("Host", req.Host)
	}
	headers, headerList := genFormatHeaders(req.Header)
	params, paramList := genFormatParameters(req.URL.Query())
	httpString := genFormatString(req.Method, *req.URL, params, headers)

	signature := calSignature(
		calSignKey(secretKey, keyTime),
		calStringToSign(sha1SignAlgorithm, keyTime, httpString),
	)
	return genAuthorization(
		secretID, authTime.signString(), keyTime, signature,
		headerList, paramList,
	)
}
// AddAuthorizationHeader signs req in place by setting its Authorization
// header and, when sessionToken is non-empty, the x-cos-security-token header.
// An empty secretID means an anonymous request: req is left untouched.
func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
	if len(secretID) == 0 {
		return
	}
	authorization := newAuthorization(secretID, secretKey, req, authTime, true)
	// The token header is set after the signature is computed, so it is only
	// covered by the signature if the request already carried it.
	if sessionToken != "" {
		req.Header.Set("x-cos-security-token", sessionToken)
	}
	req.Header.Set("Authorization", authorization)
}
// calSignKey derives SignKey as the lowercase hex HMAC of keyTime keyed by
// secretKey, using the sha1SignAlgorithm digest.
func calSignKey(secretKey, keyTime string) string {
	return fmt.Sprintf("%x", calHMACDigest(secretKey, keyTime, sha1SignAlgorithm))
}
// calStringToSign builds StringToSign as three newline-terminated lines:
// the algorithm name, the time string, and the lowercase hex SHA-1 of
// formatString.
func calStringToSign(signAlgorithm, signTime, formatString string) string {
	digest := sha1.Sum([]byte(formatString))
	return signAlgorithm + "\n" + signTime + "\n" + fmt.Sprintf("%x", digest[:]) + "\n"
}
// calSignature computes the final signature: the lowercase hex HMAC of
// stringToSign keyed by signKey, using the sha1SignAlgorithm digest.
func calSignature(signKey, stringToSign string) string {
	return fmt.Sprintf("%x", calHMACDigest(signKey, stringToSign, sha1SignAlgorithm))
}
// genAuthorization assembles the Authorization header value as "&"-joined
// q-* fields. The header and URL-parameter name lists are ";"-joined.
func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
	fields := []struct{ name, value string }{
		{"q-sign-algorithm", sha1SignAlgorithm},
		{"q-ak", secretID},
		{"q-sign-time", signTime},
		{"q-key-time", keyTime},
		{"q-header-list", strings.Join(signedHeaderList, ";")},
		{"q-url-param-list", strings.Join(signedParameterList, ";")},
		{"q-signature", signature},
	}
	var sb strings.Builder
	for i, f := range fields {
		if i > 0 {
			sb.WriteByte('&')
		}
		sb.WriteString(f.name)
		sb.WriteByte('=')
		sb.WriteString(f.value)
	}
	return sb.String()
}
// genFormatString builds the HttpString that gets hashed into StringToSign.
// It has four newline-terminated lines: the lowercased method, the URL path,
// the formatted parameters and the formatted headers.
func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
	lines := []string{strings.ToLower(method), uri.Path, formatParameters, formatHeaders}
	return strings.Join(lines, "\n") + "\n"
}
// genFormatParameters renders the URL query for signing and lists the
// signed parameter names. Names are URL-encoded and lowercased, sorted, and
// repeated once per value. A valuesSignMap is used instead of url.Values so
// the encoding matches the signing rules.
func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
	signMap := valuesSignMap{}
	for name, vals := range parameters {
		encodedName := strings.ToLower(safeURLEncode(name))
		for _, v := range vals {
			signMap.Add(name, v)
			signedParameterList = append(signedParameterList, encodedName)
		}
	}
	sort.Strings(signedParameterList)
	return signMap.Encode(), signedParameterList
}
// genFormatHeaders renders the signable headers (per isSignHeader) for
// signing and lists their names. Names are URL-encoded and lowercased,
// sorted, and repeated once per value.
func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
	signMap := valuesSignMap{}
	for name, vals := range headers {
		if !isSignHeader(strings.ToLower(name)) {
			continue
		}
		encodedName := strings.ToLower(safeURLEncode(name))
		for _, v := range vals {
			signMap.Add(name, v)
			signedHeaderList = append(signedHeaderList, encodedName)
		}
	}
	sort.Strings(signedHeaderList)
	return signMap.Encode(), signedHeaderList
}
// calHMACDigest returns the raw HMAC of msg keyed by key.
// signMethod has no effect on the result: "sha1" and every unrecognized
// value both select SHA-1.
func calHMACDigest(key, msg, signMethod string) []byte {
	var newHash func() hash.Hash = sha1.New
	mac := hmac.New(newHash, []byte(key))
	mac.Write([]byte(msg))
	return mac.Sum(nil)
}
// isSignHeader reports whether the (already lowercased) header name key must
// be signed. That is the case when NeedSignHeaders enables it, or when it
// carries the "x-cos-" private prefix.
//
// A direct map index replaces the former scan over every entry. A missing
// key reads as false, so the result is unchanged.
func isSignHeader(key string) bool {
	return NeedSignHeaders[key] || strings.HasPrefix(key, privateHeaderPrefix)
}
// AuthorizationTransport is an http.RoundTripper that adds an Authorization
// header, built from a SecretID/SecretKey and an optional session token, to
// every request before forwarding it. Use SetCredential to rotate the
// credential at runtime. Writing the exported fields directly is not
// synchronized with in-flight requests.
type AuthorizationTransport struct {
	SecretID     string
	SecretKey    string
	SessionToken string
	rwLocker     sync.RWMutex
	// Expire is how long each generated signature stays valid. Zero or a
	// negative value means defaultAuthExpire.
	Expire time.Duration

	Transport http.RoundTripper
}

// SetCredential replaces the SecretID (ak), SecretKey (sk) and session token
// under the write lock.
func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
	t.rwLocker.Lock()
	defer t.rwLocker.Unlock()
	t.SecretID = ak
	t.SecretKey = sk
	t.SessionToken = token
}

// GetCredential returns the SecretID, SecretKey and session token read under
// the read lock.
func (t *AuthorizationTransport) GetCredential() (string, string, string) {
	t.rwLocker.RLock()
	defer t.rwLocker.RUnlock()
	return t.SecretID, t.SecretKey, t.SessionToken
}

// RoundTrip implements http.RoundTripper. It signs a copy of req produced by
// cloneRequest and sends it through t.transport. It fails without sending
// anything when the SecretID or SecretKey has leading or trailing whitespace.
func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// The http.RoundTripper contract requires closing the request body even
	// when RoundTrip returns an error. The Close error is irrelevant because
	// the request is being abandoned.
	reject := func(err error) (*http.Response, error) {
		if req.Body != nil {
			req.Body.Close()
		}
		return nil, err
	}
	ak, sk, token := t.GetCredential()
	// Any surrounding whitespace (spaces, tabs, a trailing newline from a
	// config file) either corrupts the Authorization header or silently
	// produces a wrong signature.
	if strings.TrimSpace(ak) != ak {
		return reject(fmt.Errorf("SecretID is invalid"))
	}
	if strings.TrimSpace(sk) != sk {
		return reject(fmt.Errorf("SecretKey is invalid"))
	}
	req = cloneRequest(req) // per RoundTrip contract

	// Honor the configured Expire; fall back to the package default.
	expire := t.Expire
	if expire <= 0 {
		expire = defaultAuthExpire
	}
	AddAuthorizationHeader(ak, sk, token, req, NewAuthTime(expire))
	return t.transport(req).RoundTrip(req)
}

// transport selects the RoundTripper for req. It uses the configured
// Transport if set, otherwise DNSScatterTransport for COS internal-network
// hosts, otherwise http.DefaultTransport.
func (t *AuthorizationTransport) transport(req *http.Request) http.RoundTripper {
	if t.Transport != nil {
		return t.Transport
	}
	if internalHost.MatchString(req.URL.Hostname()) {
		return DNSScatterTransport
	}
	return http.DefaultTransport
}
// CVMSecurityCredentials is the JSON body returned by the CVM metadata
// service's security-credentials endpoint for a CAM role.
// CVMCredentialTransport.UpdateCredential decodes into it.
type CVMSecurityCredentials struct {
	// Temporary SecretID used to sign requests.
	TmpSecretId string `json:",omitempty"`
	// Temporary SecretKey used to sign requests.
	TmpSecretKey string `json:",omitempty"`
	// Expiry as a Unix timestamp in seconds; it is compared against
	// time.Now().Unix() to decide when to refresh.
	ExpiredTime int64 `json:",omitempty"`
	// Not read in this file; presumably a human-readable form of the expiry
	// (TODO confirm against the metadata-service docs).
	Expiration string `json:",omitempty"`
	// Session token sent as x-cos-security-token.
	Token string `json:",omitempty"`
	// Result code; the credential is accepted only when it equals "Success".
	Code string `json:",omitempty"`
}
// cvmMetadataClient is used for CVM metadata-service calls. Unlike
// http.DefaultClient it has an overall timeout. UpdateCredential holds the
// transport's write lock while it talks to the metadata service, so a hung
// request would otherwise block every concurrent RoundTrip indefinitely.
var cvmMetadataClient = &http.Client{Timeout: 10 * time.Second}

// CVMCredentialTransport is an http.RoundTripper that signs requests with
// temporary credentials fetched from the CVM instance metadata service and
// caches them until shortly before they expire.
type CVMCredentialTransport struct {
	// RoleName is the CAM role whose credentials are fetched. When empty, the
	// first role reported by the metadata service is used.
	RoleName string
	// Transport sends the signed request; http.DefaultTransport when nil.
	Transport http.RoundTripper

	// Cached temporary credential, guarded by rwLocker.
	secretID     string
	secretKey    string
	sessionToken string
	expiredTime  int64
	rwLocker     sync.RWMutex
}

// GetRoles lists the CAM roles bound to this CVM instance, as reported by the
// metadata service (one role per line). Blank lines and surrounding
// whitespace, including '\r', are ignored. An error is returned when no role
// remains.
func (t *CVMCredentialTransport) GetRoles() ([]string, error) {
	urlname := fmt.Sprintf("%s://%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI)
	resp, err := cvmMetadataClient.Get(urlname)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Best effort: the body only enriches the error message.
		bs, _ := ioutil.ReadAll(resp.Body)
		return nil, fmt.Errorf("get cvm security-credentials role failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
	}
	bs, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	// Filter per line: strings.Split never returns an empty slice (an empty
	// body yields [""]), so the emptiness check below needs real filtering.
	var roles []string
	for _, line := range strings.Split(string(bs), "\n") {
		if role := strings.TrimSpace(line); role != "" {
			roles = append(roles, role)
		}
	}
	if len(roles) == 0 {
		return nil, fmt.Errorf("get cvm security-credentials role failed, No valid cam role was found")
	}
	return roles, nil
}

// UpdateCredential refreshes the cached temporary credential from the
// metadata service unless it is still valid for more than
// defaultCVMAuthExpire seconds after now (a Unix timestamp). On failure it
// returns the previously cached values together with the error.
//
// See https://cloud.tencent.com/document/product/213/4934
func (t *CVMCredentialTransport) UpdateCredential(now int64) (string, string, string, error) {
	t.rwLocker.Lock()
	defer t.rwLocker.Unlock()
	// Another caller may have refreshed while we waited for the lock.
	if t.expiredTime > now+defaultCVMAuthExpire {
		return t.secretID, t.secretKey, t.sessionToken, nil
	}
	roleName := t.RoleName
	if roleName == "" {
		roles, err := t.GetRoles()
		if err != nil {
			return t.secretID, t.secretKey, t.sessionToken, err
		}
		roleName = roles[0] // GetRoles guarantees at least one non-empty role.
	}
	urlname := fmt.Sprintf("%s://%s/%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI, roleName)
	resp, err := cvmMetadataClient.Get(urlname)
	if err != nil {
		return t.secretID, t.secretKey, t.sessionToken, err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		// Best effort: the body only enriches the error message.
		bs, _ := ioutil.ReadAll(resp.Body)
		return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
	}
	var cred CVMSecurityCredentials
	err = json.NewDecoder(resp.Body).Decode(&cred)
	if err != nil {
		return t.secretID, t.secretKey, t.sessionToken, err
	}
	if cred.Code != "Success" {
		return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, Code:%v", cred.Code)
	}
	t.secretID, t.secretKey, t.sessionToken, t.expiredTime = cred.TmpSecretId, cred.TmpSecretKey, cred.Token, cred.ExpiredTime
	return t.secretID, t.secretKey, t.sessionToken, nil
}

// GetCredential returns the cached temporary credential. It refreshes the
// credential when it expires within defaultCVMAuthExpire seconds. A refresh
// error is ignored, and the old credential returned, as long as the old
// credential has not yet actually expired.
func (t *CVMCredentialTransport) GetCredential() (string, string, string, error) {
	now := time.Now().Unix()
	t.rwLocker.RLock()
	// Refresh ahead of expiry by defaultCVMAuthExpire seconds.
	if t.expiredTime <= now+defaultCVMAuthExpire {
		expiredTime := t.expiredTime
		t.rwLocker.RUnlock()
		secretID, secretKey, secretToken, err := t.UpdateCredential(now)
		// Refresh failed, but the old credential is still valid: keep using it.
		if err != nil && now < expiredTime {
			err = nil
		}
		return secretID, secretKey, secretToken, err
	}
	defer t.rwLocker.RUnlock()
	return t.secretID, t.secretKey, t.sessionToken, nil
}

// RoundTrip implements http.RoundTripper. It signs a copy of req (made by
// cloneRequest) with the current temporary credential and sends it through
// t.transport.
func (t *CVMCredentialTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	ak, sk, token, err := t.GetCredential()
	if err != nil {
		// The http.RoundTripper contract requires closing the request body
		// even on error. The Close error is irrelevant here.
		if req.Body != nil {
			req.Body.Close()
		}
		return nil, err
	}
	req = cloneRequest(req)
	AddAuthorizationHeader(ak, sk, token, req, NewAuthTime(defaultAuthExpire))
	return t.transport().RoundTrip(req)
}

// transport returns the configured Transport, or http.DefaultTransport.
func (t *CVMCredentialTransport) transport() http.RoundTripper {
	if t.Transport != nil {
		return t.Transport
	}
	return http.DefaultTransport
}
// CredentialTransport is an http.RoundTripper that signs each request with
// the credential Credential reports at the moment of the request.
type CredentialTransport struct {
	Transport  http.RoundTripper
	Credential CredentialIface
}

// RoundTrip implements http.RoundTripper. It signs a copy of req (made by
// cloneRequest) and sends it through t.transport. If Credential is nil, it
// returns an error (and closes the request body, as the RoundTripper contract
// requires) instead of panicking.
func (t *CredentialTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if t.Credential == nil {
		if req.Body != nil {
			req.Body.Close()
		}
		return nil, fmt.Errorf("CredentialTransport: Credential is nil")
	}
	ak, sk, token := t.Credential.GetSecretId(), t.Credential.GetSecretKey(), t.Credential.GetToken()
	req = cloneRequest(req)
	AddAuthorizationHeader(ak, sk, token, req, NewAuthTime(defaultAuthExpire))
	return t.transport().RoundTrip(req)
}

// transport returns the configured Transport, or http.DefaultTransport.
func (t *CredentialTransport) transport() http.RoundTripper {
	if t.Transport != nil {
		return t.Transport
	}
	return http.DefaultTransport
}
// CredentialIface supplies the credential that CredentialTransport uses to
// sign each request.
type CredentialIface interface {
	// GetSecretId returns the SecretID placed in q-ak.
	GetSecretId() string
	// GetSecretKey returns the SecretKey used as the HMAC signing secret.
	GetSecretKey() string
	// GetToken returns the session token. An empty token means no
	// x-cos-security-token header is sent.
	GetToken() string
}
// NewTokenCredential returns a Credential holding the given SecretID,
// SecretKey and session token. All other fields keep their zero values.
func NewTokenCredential(secretId, secretKey, token string) *Credential {
	c := new(Credential)
	c.SecretID, c.SecretKey, c.SessionToken = secretId, secretKey, token
	return c
}

// GetSecretKey returns the stored SecretKey (implements CredentialIface).
func (c *Credential) GetSecretKey() string {
	key := c.SecretKey
	return key
}

// GetSecretId returns the stored SecretID (implements CredentialIface).
func (c *Credential) GetSecretId() string {
	id := c.SecretID
	return id
}

// GetToken returns the stored session token (implements CredentialIface).
func (c *Credential) GetToken() string {
	token := c.SessionToken
	return token
}