// Copyright 2016 The Mellium Contributors. // Use of this source code is governed by the BSD 2-clause license that can be // found in the LICENSE file. package sasl import ( "bytes" "crypto/hmac" "encoding/base64" "errors" "hash" "strconv" "strings" "golang.org/x/crypto/pbkdf2" ) const ( gs2HeaderCBSupport = "p=tls-unique," gs2HeaderNoServerCBSupport = "y," gs2HeaderNoCBSupport = "n," ) var ( clientKeyInput = []byte("Client Key") serverKeyInput = []byte("Server Key") ) // The number of random bytes to generate for a nonce. const noncerandlen = 16 func getGS2Header(name string, n *Negotiator) (gs2Header []byte) { _, _, identity := n.Credentials() switch { case n.TLSState() == nil || !strings.HasSuffix(name, "-PLUS"): // We do not support channel binding gs2Header = []byte(gs2HeaderNoCBSupport) case n.State()&RemoteCB == RemoteCB: // We support channel binding and the server does too gs2Header = []byte(gs2HeaderCBSupport) case n.State()&RemoteCB != RemoteCB: // We support channel binding but the server does not gs2Header = []byte(gs2HeaderNoServerCBSupport) } if len(identity) > 0 { gs2Header = append(gs2Header, []byte(`a=`)...) gs2Header = append(gs2Header, identity...) } gs2Header = append(gs2Header, ',') return } func scram(name string, fn func() hash.Hash) Mechanism { // BUG(ssw): We need a way to cache the SCRAM client and server key // calculations. return Mechanism{ Name: name, Start: func(m *Negotiator) (bool, []byte, interface{}, error) { user, _, _ := m.Credentials() // Escape "=" and ",". This is mostly the same as bytes.Replace but // faster because we can do both replacements in a single pass. n := bytes.Count(user, []byte{'='}) + bytes.Count(user, []byte{','}) username := make([]byte, len(user)+(n*2)) w := 0 start := 0 for i := 0; i < n; i++ { j := start j += bytes.IndexAny(user[start:], "=,") w += copy(username[w:], user[start:j]) switch user[j] { case '=': w += copy(username[w:], "=3D") case ',': w += copy(username[w:], "=2C") } start = j + 1 } copy(username[w:], user[start:]) clientFirstMessage := make([]byte, 5+len(m.Nonce())+len(username)) copy(clientFirstMessage, "n=") copy(clientFirstMessage[2:], username) copy(clientFirstMessage[2+len(username):], ",r=") copy(clientFirstMessage[5+len(username):], m.Nonce()) return true, append(getGS2Header(name, m), clientFirstMessage...), clientFirstMessage, nil }, Next: func(m *Negotiator, challenge []byte, data interface{}) (more bool, resp []byte, cache interface{}, err error) { if challenge == nil || len(challenge) == 0 { return more, resp, cache, ErrInvalidChallenge } if m.State()&Receiving == Receiving { panic("not yet implemented") } return scramClientNext(name, fn, m, challenge, data) }, } } func scramClientNext(name string, fn func() hash.Hash, m *Negotiator, challenge []byte, data interface{}) (more bool, resp []byte, cache interface{}, err error) { _, password, _ := m.Credentials() state := m.State() switch state & StepMask { case AuthTextSent: iter := -1 var salt, nonce []byte for _, field := range bytes.Split(challenge, []byte{','}) { if len(field) < 3 || (len(field) >= 2 && field[1] != '=') { continue } switch field[0] { case 'i': ival := string(bytes.TrimRight(field[2:], "\x00")) if iter, err = strconv.Atoi(ival); err != nil { return } case 's': salt = make([]byte, base64.StdEncoding.DecodedLen(len(field)-2)) var n int n, err = base64.StdEncoding.Decode(salt, field[2:]) salt = salt[:n] if err != nil { return } case 'r': nonce = field[2:] case 'm': // RFC 5802: // m: This attribute is reserved for future extensibility. In this // version of SCRAM, its presence in a client or a server message // MUST cause authentication failure when the attribute is parsed by // the other end. err = errors.New("Server sent reserved attribute `m'") return } } switch { case iter < 0: err = errors.New("Iteration count is missing") return case iter < 0: err = errors.New("Iteration count is invalid") return case nonce == nil || !bytes.HasPrefix(nonce, m.Nonce()): err = errors.New("Server nonce does not match client nonce") return case salt == nil: err = errors.New("Server sent empty salt") return } gs2Header := getGS2Header(name, m) tlsState := m.TLSState() var channelBinding []byte if tlsState != nil && strings.HasSuffix(name, "-PLUS") { channelBinding = make( []byte, 2+base64.StdEncoding.EncodedLen(len(gs2Header)+len(tlsState.TLSUnique)), ) base64.StdEncoding.Encode(channelBinding[2:], append(gs2Header, tlsState.TLSUnique...)) channelBinding[0] = 'c' channelBinding[1] = '=' } else { channelBinding = make( []byte, 2+base64.StdEncoding.EncodedLen(len(gs2Header)), ) base64.StdEncoding.Encode(channelBinding[2:], gs2Header) channelBinding[0] = 'c' channelBinding[1] = '=' } clientFinalMessageWithoutProof := append(channelBinding, []byte(",r=")...) clientFinalMessageWithoutProof = append(clientFinalMessageWithoutProof, nonce...) clientFirstMessage := data.([]byte) authMessage := append(clientFirstMessage, ',') authMessage = append(authMessage, challenge...) authMessage = append(authMessage, ',') authMessage = append(authMessage, clientFinalMessageWithoutProof...) saltedPassword := pbkdf2.Key(password, salt, iter, fn().Size(), fn) h := hmac.New(fn, saltedPassword) h.Write(serverKeyInput) serverKey := h.Sum(nil) h.Reset() h.Write(clientKeyInput) clientKey := h.Sum(nil) h = hmac.New(fn, serverKey) h.Write(authMessage) serverSignature := h.Sum(nil) h = fn() h.Write(clientKey) storedKey := h.Sum(nil) h = hmac.New(fn, storedKey) h.Write(authMessage) clientSignature := h.Sum(nil) clientProof := make([]byte, len(clientKey)) xorBytes(clientProof, clientKey, clientSignature) encodedClientProof := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof))) base64.StdEncoding.Encode(encodedClientProof, clientProof) clientFinalMessage := append(clientFinalMessageWithoutProof, []byte(",p=")...) clientFinalMessage = append(clientFinalMessage, encodedClientProof...) return true, clientFinalMessage, serverSignature, nil case ResponseSent: clientCalculatedServerFinalMessage := "v=" + base64.StdEncoding.EncodeToString(data.([]byte)) if clientCalculatedServerFinalMessage != string(challenge) { err = ErrAuthn return } // Success! return false, nil, nil, nil } err = ErrInvalidState return }