go-library/vendor/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go

// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package topology

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"go.mongodb.org/mongo-driver/internal"
	"go.mongodb.org/mongo-driver/mongo/address"
	"go.mongodb.org/mongo-driver/mongo/description"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
	"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// Connection state constants.
const (
connDisconnected int64 = iota
connConnected
connInitialized
)
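// A connection moves through these states in one direction: newConnection leaves it in
// connInitialized, connect transitions it to connConnected via compare-and-swap (or to
// connDisconnected if establishment fails), and close moves a connected connection to
// connDisconnected. A disconnected connection is never reused.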
var globalConnectionID uint64 = 1
var (
defaultMaxMessageSize uint32 = 48000000
errResponseTooLarge = errors.New("length of read message too large")
errLoadBalancedStateMismatch = errors.New("driver attempted to initialize in load balancing mode, but the server does not support this mode")
)
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
type connection struct {
// state must be accessed using the atomic package and should be at the beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
id string
nc net.Conn // When nil, the connection is closed.
addr address.Address
idleTimeout time.Duration
idleDeadline atomic.Value // Stores a time.Time
readTimeout time.Duration
writeTimeout time.Duration
desc description.Server
helloRTT time.Duration
compressor wiremessage.CompressorID
zliblevel int
zstdLevel int
connectDone chan struct{}
config *connectionConfig
cancelConnectContext context.CancelFunc
connectContextMade chan struct{}
canStream bool
currentlyStreaming bool
connectContextMutex sync.Mutex
cancellationListener cancellationListener
serverConnectionID *int64 // the server's ID for this client's connection
// pool related fields
pool *pool
// TODO(GODRIVER-2824): change driverConnectionID type to int64.
driverConnectionID uint64
generation uint64
}
// newConnection handles the creation of a connection. It does not connect the connection.
func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
cfg := newConnectionConfig(opts...)
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
c := &connection{
id: id,
addr: addr,
idleTimeout: cfg.idleTimeout,
readTimeout: cfg.readTimeout,
writeTimeout: cfg.writeTimeout,
connectDone: make(chan struct{}),
config: cfg,
connectContextMade: make(chan struct{}),
cancellationListener: internal.NewCancellationListener(),
}
// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
// at any point during connection establishment can be processed without the connection being considered stale.
if !c.config.loadBalanced {
c.setGenerationNumber()
}
atomic.StoreInt64(&c.state, connInitialized)
return c
}
// DriverConnectionID returns the driver connection ID.
// TODO(GODRIVER-2824): change return type to int64.
func (c *connection) DriverConnectionID() uint64 {
return c.driverConnectionID
}
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
if c.config.getGenerationFn != nil {
c.generation = c.config.getGenerationFn(c.desc.ServiceID)
}
}
// hasGenerationNumber returns true if the connection has set its generation number. If so, this indicates that the
// generationNumberFn provided via the connection options has been called exactly once.
func (c *connection) hasGenerationNumber() bool {
if !c.config.loadBalanced {
// The generation is known for all non-LB clusters once the connection object has been created.
return true
}
// For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection
// description has been updated to reflect that it's behind an LB.
return c.desc.LoadBalanced()
}
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
func (c *connection) connect(ctx context.Context) (err error) {
if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
return nil
}
defer close(c.connectDone)
// If connect returns an error, set the connection status as disconnected and close the
// underlying net.Conn if it was created.
defer func() {
if err != nil {
atomic.StoreInt64(&c.state, connDisconnected)
if c.nc != nil {
_ = c.nc.Close()
}
}
}()
// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
//
// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
longer required. This is done while holding connectContextMutex because it accesses the shared cancelConnectContext field.
//
// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
// holding the lock longer than necessary.
c.connectContextMutex.Lock()
var handshakeCtx context.Context
handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
c.connectContextMutex.Unlock()
dialCtx := handshakeCtx
var dialCancel context.CancelFunc
if c.config.connectTimeout != 0 {
dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
defer dialCancel()
}
defer func() {
var cancelFn context.CancelFunc
c.connectContextMutex.Lock()
cancelFn = c.cancelConnectContext
c.cancelConnectContext = nil
c.connectContextMutex.Unlock()
if cancelFn != nil {
cancelFn()
}
}()
close(c.connectContextMade)
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
if err != nil {
return ConnectionError{Wrapped: err, init: true}
}
c.nc = tempNc
if c.config.tlsConfig != nil {
tlsConfig := c.config.tlsConfig.Clone()
// Store the result of configureTLS in a separate variable from c.nc to avoid overwriting c.nc with nil in
// error cases.
ocspOpts := &ocsp.VerifyOptions{
Cache: c.config.ocspCache,
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
HTTPClient: c.config.httpClient,
}
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
if err != nil {
return ConnectionError{Wrapped: err, init: true}
}
c.nc = tlsNc
}
// The hello and authentication handshakes are handled by a handshaker on the configuration instance, if one was provided.
handshaker := c.config.handshaker
if handshaker == nil {
return nil
}
var handshakeInfo driver.HandshakeInformation
handshakeStartTime := time.Now()
handshakeConn := initConnection{c}
handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
if err == nil {
// We only need to retain the Description field as the connection's description. The authentication-related
// fields in handshakeInfo are tracked by the handshaker if necessary.
c.desc = handshakeInfo.Description
c.serverConnectionID = handshakeInfo.ServerConnectionID
c.helloRTT = time.Since(handshakeStartTime)
// If the application has indicated that the cluster is load balanced, ensure the server has included serviceId
// in its handshake response to signal that it knows it's behind an LB as well.
if c.config.loadBalanced && c.desc.ServiceID == nil {
err = errLoadBalancedStateMismatch
}
}
if err == nil {
// For load-balanced connections, the generation number depends on the service ID, which isn't known until the
// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
// number unless GetHandshakeInformation succeeds.
if c.config.loadBalanced {
c.setGenerationNumber()
}
// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
// the handshake.
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
}
// If either handshake stage failed, wrap the error so it is handled as a pre-handshake establishment error.
if err != nil {
return ConnectionError{Wrapped: err, init: true}
}
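// Negotiate the wire protocol compressor: pick the first compressor configured on the client that
// the server also advertised in its handshake response, and record the corresponding compression
// level for zlib or zstd.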
if len(c.desc.Compression) > 0 {
clientMethodLoop:
for _, method := range c.config.compressors {
for _, serverMethod := range c.desc.Compression {
if method != serverMethod {
continue
}
switch strings.ToLower(method) {
case "snappy":
c.compressor = wiremessage.CompressorSnappy
case "zlib":
c.compressor = wiremessage.CompressorZLib
c.zliblevel = wiremessage.DefaultZlibLevel
if c.config.zlibLevel != nil {
c.zliblevel = *c.config.zlibLevel
}
case "zstd":
c.compressor = wiremessage.CompressorZstd
c.zstdLevel = wiremessage.DefaultZstdLevel
if c.config.zstdLevel != nil {
c.zstdLevel = *c.config.zstdLevel
}
}
break clientMethodLoop
}
}
}
return nil
}
func (c *connection) wait() {
if c.connectDone != nil {
<-c.connectDone
}
}
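// closeConnectContext cancels an in-progress connect by invoking the cancel function stored during
// connect. It blocks until connect has created that context.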
func (c *connection) closeConnectContext() {
<-c.connectContextMade
var cancelFn context.CancelFunc
c.connectContextMutex.Lock()
cancelFn = c.cancelConnectContext
c.cancelConnectContext = nil
c.connectContextMutex.Unlock()
if cancelFn != nil {
cancelFn()
}
}
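// transformNetworkError converts errors from a network read or write into context errors where
// those better describe the cause: a cancelled context always surfaces as context.Canceled, and a
// net.Error timeout becomes context.DeadlineExceeded only when the deadline in force came from the
// caller's context (contextDeadlineUsed) rather than from the connection's read/write timeout.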
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
if originalError == nil {
return nil
}
// If there was an error and the context was cancelled, we assume it happened due to the cancellation.
if ctx.Err() == context.Canceled {
return context.Canceled
}
// If there was a timeout error and the context deadline was used, we convert the error into
// context.DeadlineExceeded.
if !contextDeadlineUsed {
return originalError
}
if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() {
return context.DeadlineExceeded
}
return originalError
}
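// cancellationListenerCallback is invoked by the cancellation listener when the context for an
// in-progress read or write is cancelled; closing the connection unblocks the pending network call.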
func (c *connection) cancellationListenerCallback() {
_ = c.close()
}
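// writeWireMessage writes an already-encoded wire message to the connection, using the earlier of
// the connection's write timeout and the context deadline as the socket write deadline. Any write
// error closes the connection.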
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if atomic.LoadInt64(&c.state) != connConnected {
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
var deadline time.Time
if c.writeTimeout != 0 {
deadline = time.Now().Add(c.writeTimeout)
}
var contextDeadlineUsed bool
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
contextDeadlineUsed = true
deadline = dl
}
if err := c.nc.SetWriteDeadline(deadline); err != nil {
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
}
err = c.write(ctx, wm)
if err != nil {
c.close()
return ConnectionError{
ConnectionID: c.id,
Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
message: "unable to write wire message to network",
}
}
return nil
}
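// write starts the cancellation listener for ctx, writes the wire message to the underlying
// net.Conn, and reports context.Canceled if the listener closed the connection after an otherwise
// successful write.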
func (c *connection) write(ctx context.Context, wm []byte) (err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
// There is a race condition between Write and StopListening. If the context is cancelled after c.nc.Write
// succeeds, the cancellation listener could fire and close the connection. In this case, the connection has
been invalidated but the error is nil. To account for this, overwrite the error to context.Canceled if
// the abortedForCancellation flag was set.
if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
err = context.Canceled
}
}()
_, err = c.nc.Write(wm)
return err
}
// readWireMessage reads a wire message from the connection and returns it as a new byte slice.
func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
if atomic.LoadInt64(&c.state) != connConnected {
return nil, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
}
var deadline time.Time
if c.readTimeout != 0 {
deadline = time.Now().Add(c.readTimeout)
}
var contextDeadlineUsed bool
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
contextDeadlineUsed = true
deadline = dl
}
if err := c.nc.SetReadDeadline(deadline); err != nil {
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
}
dst, errMsg, err := c.read(ctx)
if err != nil {
// We close the connection because we don't know if there are other bytes left to read.
c.close()
message := errMsg
if err == io.EOF {
message = "socket was unexpectedly closed"
}
return nil, ConnectionError{
ConnectionID: c.id,
Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
message: message,
}
}
return dst, nil
}
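// read reads a single wire message: the 4-byte little-endian length prefix first, then the
// remainder of the message into a buffer of exactly that size. Messages larger than the server's
// advertised maximum (or defaultMaxMessageSize before the handshake) are rejected. The returned
// errMsg describes which stage of the read failed.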
func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
// If the context is cancelled after we finish reading the server response, the cancellation listener could fire
// even though the socket reads succeed. To account for this, we overwrite err to be context.Canceled if the
// abortedForCancellation flag is set.
if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
errMsg = "unable to read server response"
err = context.Canceled
}
}()
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err = io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
return nil, "incomplete read of message header", err
}
// Read the message length as a little-endian int32, per the MongoDB wire protocol header.
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return nil, errResponseTooLarge.Error(), errResponseTooLarge
}
dst := make([]byte, size)
copy(dst, sizeBuf[:])
_, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
return dst, "incomplete read of full message", err
}
return dst, "", nil
}
func (c *connection) close() error {
// Overwrite the connection state as the first step so only the first close call will execute.
if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
return nil
}
var err error
if c.nc != nil {
err = c.nc.Close()
}
return err
}
func (c *connection) closed() bool {
return atomic.LoadInt64(&c.state) == connDisconnected
}
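// idleTimeoutExpired reports whether the connection has been idle longer than the configured idle
// timeout, based on the deadline most recently recorded by bumpIdleDeadline.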
func (c *connection) idleTimeoutExpired() bool {
now := time.Now()
if c.idleTimeout > 0 {
idleDeadline, ok := c.idleDeadline.Load().(time.Time)
if ok && now.After(idleDeadline) {
return true
}
}
return false
}
func (c *connection) bumpIdleDeadline() {
if c.idleTimeout > 0 {
c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
}
}
func (c *connection) setCanStream(canStream bool) {
c.canStream = canStream
}
func (c initConnection) supportsStreaming() bool {
return c.canStream
}
func (c *connection) setStreaming(streaming bool) {
c.currentlyStreaming = streaming
}
func (c *connection) getCurrentlyStreaming() bool {
return c.currentlyStreaming
}
func (c *connection) setSocketTimeout(timeout time.Duration) {
c.readTimeout = timeout
c.writeTimeout = timeout
}
func (c *connection) ID() string {
return c.id
}
func (c *connection) ServerConnectionID() *int64 {
return c.serverConnectionID
}
// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
type initConnection struct{ *connection }
var _ driver.Connection = initConnection{}
var _ driver.StreamerConnection = initConnection{}
func (c initConnection) Description() description.Server {
if c.connection == nil {
return description.Server{}
}
return c.connection.desc
}
func (c initConnection) Close() error { return nil }
func (c initConnection) ID() string { return c.id }
func (c initConnection) Address() address.Address { return c.addr }
func (c initConnection) Stale() bool { return false }
func (c initConnection) LocalAddress() address.Address {
if c.connection == nil || c.nc == nil {
return address.Address("0.0.0.0")
}
return address.Address(c.nc.LocalAddr().String())
}
func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
return c.writeWireMessage(ctx, wm)
}
func (c initConnection) ReadWireMessage(ctx context.Context) ([]byte, error) {
return c.readWireMessage(ctx)
}
func (c initConnection) SetStreaming(streaming bool) {
c.setStreaming(streaming)
}
func (c initConnection) CurrentlyStreaming() bool {
return c.getCurrentlyStreaming()
}
func (c initConnection) SupportsStreaming() bool {
return c.supportsStreaming()
}
// Connection implements the driver.Connection interface to allow reading and writing wire
// messages and the driver.Expirable interface to allow expiring. It wraps an underlying
// topology.connection to make it more goroutine-safe and nil-safe.
type Connection struct {
connection *connection
refCount int
cleanupPoolFn func()
// cleanupServerFn resets the server state when a connection is returned to the connection pool
// via Close() or expired via Expire().
cleanupServerFn func()
mu sync.RWMutex
}
var _ driver.Connection = (*Connection)(nil)
var _ driver.Expirable = (*Connection)(nil)
var _ driver.PinnedConnection = (*Connection)(nil)
// WriteWireMessage handles writing a wire message to the underlying connection.
func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return ErrConnectionClosed
}
return c.connection.writeWireMessage(ctx, wm)
}
// ReadWireMessage handles reading a wire message from the underlying connection and returns it as
// a new byte slice.
func (c *Connection) ReadWireMessage(ctx context.Context) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return nil, ErrConnectionClosed
}
return c.connection.readWireMessage(ctx)
}
// CompressWireMessage handles compressing the provided wire message using the underlying
// connection's compressor. The compressed message is appended to dst and the result returned. If
// there is no compressor set on the underlying connection, then no compression will be performed.
func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return dst, ErrConnectionClosed
}
if c.connection.compressor == wiremessage.CompressorNoOp {
return append(dst, src...), nil
}
_, reqid, respto, origcode, rem, ok := wiremessage.ReadHeader(src)
if !ok {
return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
}
idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode)
dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem)))
dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor)
opts := driver.CompressionOpts{
Compressor: c.connection.compressor,
ZlibLevel: c.connection.zliblevel,
ZstdLevel: c.connection.zstdLevel,
}
compressed, err := driver.CompressPayload(rem, opts)
if err != nil {
return nil, err
}
dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed)
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
}
// Description returns the server description of the server this connection is connected to.
func (c *Connection) Description() description.Server {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return description.Server{}
}
return c.connection.desc
}
// Close returns this connection to the connection pool. This method may not close the underlying
// socket.
func (c *Connection) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil || c.refCount > 0 {
return nil
}
return c.cleanupReferences()
}
// Expire closes this connection and will close the underlying socket.
func (c *Connection) Expire() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
return nil
}
_ = c.connection.close()
return c.cleanupReferences()
}
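// cleanupReferences checks the wrapped connection back into the pool, runs the registered pool and
// server cleanup callbacks at most once each, and clears the reference so later calls observe this
// Connection as closed.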
func (c *Connection) cleanupReferences() error {
err := c.connection.pool.checkIn(c.connection)
if c.cleanupPoolFn != nil {
c.cleanupPoolFn()
c.cleanupPoolFn = nil
}
if c.cleanupServerFn != nil {
c.cleanupServerFn()
c.cleanupServerFn = nil
}
c.connection = nil
return err
}
// Alive returns whether the connection is still alive.
func (c *Connection) Alive() bool {
return c.connection != nil
}
// ID returns the ID of this connection.
func (c *Connection) ID() string {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return "<closed>"
}
return c.connection.id
}
// ServerConnectionID returns the server connection ID of this connection.
func (c *Connection) ServerConnectionID() *int64 {
if c.connection == nil {
return nil
}
return c.connection.serverConnectionID
}
// Stale returns whether the connection is stale.
func (c *Connection) Stale() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.connection.pool.stale(c.connection)
}
// Address returns the address of this connection.
func (c *Connection) Address() address.Address {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil {
return address.Address("0.0.0.0")
}
return c.connection.addr
}
// LocalAddress returns the local address of the connection.
func (c *Connection) LocalAddress() address.Address {
c.mu.RLock()
defer c.mu.RUnlock()
if c.connection == nil || c.connection.nc == nil {
return address.Address("0.0.0.0")
}
return address.Address(c.connection.nc.LocalAddr().String())
}
// PinToCursor updates this connection to reflect that it is pinned to a cursor.
func (c *Connection) PinToCursor() error {
return c.pin("cursor", c.connection.pool.pinConnectionToCursor, c.connection.pool.unpinConnectionFromCursor)
}
// PinToTransaction updates this connection to reflect that it is pinned to a transaction.
func (c *Connection) PinToTransaction() error {
return c.pin("transaction", c.connection.pool.pinConnectionToTransaction, c.connection.pool.unpinConnectionFromTransaction)
}
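// pin records an additional reference to this connection from a cursor or transaction. The pool
// statistics callbacks are only invoked for the first reference, so a connection pinned by both a
// cursor and a transaction is counted once.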
func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
return fmt.Errorf("attempted to pin a connection for a %s, but the connection has already been returned to the pool", reason)
}
// Only use the provided callbacks for the first reference to avoid double-counting pinned connection statistics
// in the pool.
if c.refCount == 0 {
updatePoolFn()
c.cleanupPoolFn = cleanupPoolFn
}
c.refCount++
return nil
}
// UnpinFromCursor updates this connection to reflect that it is no longer pinned to a cursor.
func (c *Connection) UnpinFromCursor() error {
return c.unpin("cursor")
}
// UnpinFromTransaction updates this connection to reflect that it is no longer pinned to a transaction.
func (c *Connection) UnpinFromTransaction() error {
return c.unpin("transaction")
}
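// unpin releases one pinned reference. A Connection that has already been returned to the pool is
// not an error here because Expire may have closed it out from under the pinning resource.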
func (c *Connection) unpin(reason string) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connection == nil {
// We don't error here because the resource could have been forcefully closed via Expire.
return nil
}
if c.refCount == 0 {
return fmt.Errorf("attempted to unpin a connection from a %s, but the connection is not pinned by any resources", reason)
}
c.refCount--
return nil
}
// DriverConnectionID returns the driver connection ID.
// TODO(GODRIVER-2824): change return type to int64.
func (c *Connection) DriverConnectionID() uint64 {
return c.connection.DriverConnectionID()
}
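// configureTLS wraps nc in a TLS client connection using the provided connection source, derives
// ServerName from the address when the config does not set one, performs the TLS handshake under
// ctx, and runs OCSP verification unless InsecureSkipVerify is set.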
func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}
hostname = hostname[:colonPos]
config.ServerName = hostname
}
client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}
// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}