You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
843 lines
26 KiB
843 lines
26 KiB
// Copyright (C) MongoDB, Inc. 2017-present.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
// not use this file except in compliance with the License. You may obtain
|
|
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
package topology
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"go.mongodb.org/mongo-driver/internal"
|
|
"go.mongodb.org/mongo-driver/mongo/address"
|
|
"go.mongodb.org/mongo-driver/mongo/description"
|
|
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
|
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
|
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
|
|
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
|
|
)
|
|
|
|
// Connection state constants.
|
|
const (
|
|
connDisconnected int64 = iota
|
|
connConnected
|
|
connInitialized
|
|
)
|
|
|
|
var globalConnectionID uint64 = 1
|
|
|
|
var (
|
|
defaultMaxMessageSize uint32 = 48000000
|
|
errResponseTooLarge = errors.New("length of read message too large")
|
|
errLoadBalancedStateMismatch = errors.New("driver attempted to initialize in load balancing mode, but the server does not support this mode")
|
|
)
|
|
|
|
func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
|
|
|
|
type connection struct {
|
|
// state must be accessed using the atomic package and should be at the beginning of the struct.
|
|
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
|
|
// - suggested layout: https://go101.org/article/memory-layout.html
|
|
state int64
|
|
|
|
id string
|
|
nc net.Conn // When nil, the connection is closed.
|
|
addr address.Address
|
|
idleTimeout time.Duration
|
|
idleDeadline atomic.Value // Stores a time.Time
|
|
readTimeout time.Duration
|
|
writeTimeout time.Duration
|
|
desc description.Server
|
|
helloRTT time.Duration
|
|
compressor wiremessage.CompressorID
|
|
zliblevel int
|
|
zstdLevel int
|
|
connectDone chan struct{}
|
|
config *connectionConfig
|
|
cancelConnectContext context.CancelFunc
|
|
connectContextMade chan struct{}
|
|
canStream bool
|
|
currentlyStreaming bool
|
|
connectContextMutex sync.Mutex
|
|
cancellationListener cancellationListener
|
|
serverConnectionID *int64 // the server's ID for this client's connection
|
|
|
|
// pool related fields
|
|
pool *pool
|
|
|
|
// TODO(GODRIVER-2824): change driverConnectionID type to int64.
|
|
driverConnectionID uint64
|
|
generation uint64
|
|
}
|
|
|
|
// newConnection handles the creation of a connection. It does not connect the connection.
|
|
func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
|
|
cfg := newConnectionConfig(opts...)
|
|
|
|
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
|
|
|
|
c := &connection{
|
|
id: id,
|
|
addr: addr,
|
|
idleTimeout: cfg.idleTimeout,
|
|
readTimeout: cfg.readTimeout,
|
|
writeTimeout: cfg.writeTimeout,
|
|
connectDone: make(chan struct{}),
|
|
config: cfg,
|
|
connectContextMade: make(chan struct{}),
|
|
cancellationListener: internal.NewCancellationListener(),
|
|
}
|
|
// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
|
|
// at any point during connection establishment can be processed without the connection being considered stale.
|
|
if !c.config.loadBalanced {
|
|
c.setGenerationNumber()
|
|
}
|
|
atomic.StoreInt64(&c.state, connInitialized)
|
|
|
|
return c
|
|
}
|
|
|
|
// DriverConnectionID returns the driver connection ID.
|
|
// TODO(GODRIVER-2824): change return type to int64.
|
|
func (c *connection) DriverConnectionID() uint64 {
|
|
return c.driverConnectionID
|
|
}
|
|
|
|
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
|
|
// configuration.
|
|
func (c *connection) setGenerationNumber() {
|
|
if c.config.getGenerationFn != nil {
|
|
c.generation = c.config.getGenerationFn(c.desc.ServiceID)
|
|
}
|
|
}
|
|
|
|
// hasGenerationNumber returns true if the connection has set its generation number. If so, this indicates that the
|
|
// generationNumberFn provided via the connection options has been called exactly once.
|
|
func (c *connection) hasGenerationNumber() bool {
|
|
if !c.config.loadBalanced {
|
|
// The generation is known for all non-LB clusters once the connection object has been created.
|
|
return true
|
|
}
|
|
|
|
// For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection
|
|
// description has been updated to reflect that it's behind an LB.
|
|
return c.desc.LoadBalanced()
|
|
}
|
|
|
|
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
|
|
// handshakes. All errors returned by connect are considered "before the handshake completes" and
|
|
// must be handled by calling the appropriate SDAM handshake error handler.
|
|
func (c *connection) connect(ctx context.Context) (err error) {
|
|
if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
|
|
return nil
|
|
}
|
|
|
|
defer close(c.connectDone)
|
|
|
|
// If connect returns an error, set the connection status as disconnected and close the
|
|
// underlying net.Conn if it was created.
|
|
defer func() {
|
|
if err != nil {
|
|
atomic.StoreInt64(&c.state, connDisconnected)
|
|
|
|
if c.nc != nil {
|
|
_ = c.nc.Close()
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
|
|
//
|
|
// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
|
|
// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
|
|
// longer required. This is done in lock because it accesses the shared cancelConnectContext field.
|
|
//
|
|
// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
|
|
// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
|
|
// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
|
|
// holding the lock longer than necessary.
|
|
c.connectContextMutex.Lock()
|
|
var handshakeCtx context.Context
|
|
handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
|
|
c.connectContextMutex.Unlock()
|
|
|
|
dialCtx := handshakeCtx
|
|
var dialCancel context.CancelFunc
|
|
if c.config.connectTimeout != 0 {
|
|
dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
|
|
defer dialCancel()
|
|
}
|
|
|
|
defer func() {
|
|
var cancelFn context.CancelFunc
|
|
|
|
c.connectContextMutex.Lock()
|
|
cancelFn = c.cancelConnectContext
|
|
c.cancelConnectContext = nil
|
|
c.connectContextMutex.Unlock()
|
|
|
|
if cancelFn != nil {
|
|
cancelFn()
|
|
}
|
|
}()
|
|
|
|
close(c.connectContextMade)
|
|
|
|
// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
|
|
tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
|
|
if err != nil {
|
|
return ConnectionError{Wrapped: err, init: true}
|
|
}
|
|
c.nc = tempNc
|
|
|
|
if c.config.tlsConfig != nil {
|
|
tlsConfig := c.config.tlsConfig.Clone()
|
|
|
|
// store the result of configureTLS in a separate variable than c.nc to avoid overwriting c.nc with nil in
|
|
// error cases.
|
|
ocspOpts := &ocsp.VerifyOptions{
|
|
Cache: c.config.ocspCache,
|
|
DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
|
|
HTTPClient: c.config.httpClient,
|
|
}
|
|
tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
|
|
if err != nil {
|
|
return ConnectionError{Wrapped: err, init: true}
|
|
}
|
|
c.nc = tlsNc
|
|
}
|
|
|
|
// running hello and authentication is handled by a handshaker on the configuration instance.
|
|
handshaker := c.config.handshaker
|
|
if handshaker == nil {
|
|
return nil
|
|
}
|
|
|
|
var handshakeInfo driver.HandshakeInformation
|
|
handshakeStartTime := time.Now()
|
|
handshakeConn := initConnection{c}
|
|
handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
|
|
if err == nil {
|
|
// We only need to retain the Description field as the connection's description. The authentication-related
|
|
// fields in handshakeInfo are tracked by the handshaker if necessary.
|
|
c.desc = handshakeInfo.Description
|
|
c.serverConnectionID = handshakeInfo.ServerConnectionID
|
|
c.helloRTT = time.Since(handshakeStartTime)
|
|
|
|
// If the application has indicated that the cluster is load balanced, ensure the server has included serviceId
|
|
// in its handshake response to signal that it knows it's behind an LB as well.
|
|
if c.config.loadBalanced && c.desc.ServiceID == nil {
|
|
err = errLoadBalancedStateMismatch
|
|
}
|
|
}
|
|
if err == nil {
|
|
// For load-balanced connections, the generation number depends on the service ID, which isn't known until the
|
|
// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
|
|
// number unless GetHandshakeInformation succeeds.
|
|
if c.config.loadBalanced {
|
|
c.setGenerationNumber()
|
|
}
|
|
|
|
// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
|
|
// the handshake.
|
|
err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
|
|
}
|
|
|
|
// We have a failed handshake here
|
|
if err != nil {
|
|
return ConnectionError{Wrapped: err, init: true}
|
|
}
|
|
|
|
if len(c.desc.Compression) > 0 {
|
|
clientMethodLoop:
|
|
for _, method := range c.config.compressors {
|
|
for _, serverMethod := range c.desc.Compression {
|
|
if method != serverMethod {
|
|
continue
|
|
}
|
|
|
|
switch strings.ToLower(method) {
|
|
case "snappy":
|
|
c.compressor = wiremessage.CompressorSnappy
|
|
case "zlib":
|
|
c.compressor = wiremessage.CompressorZLib
|
|
c.zliblevel = wiremessage.DefaultZlibLevel
|
|
if c.config.zlibLevel != nil {
|
|
c.zliblevel = *c.config.zlibLevel
|
|
}
|
|
case "zstd":
|
|
c.compressor = wiremessage.CompressorZstd
|
|
c.zstdLevel = wiremessage.DefaultZstdLevel
|
|
if c.config.zstdLevel != nil {
|
|
c.zstdLevel = *c.config.zstdLevel
|
|
}
|
|
}
|
|
break clientMethodLoop
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *connection) wait() {
|
|
if c.connectDone != nil {
|
|
<-c.connectDone
|
|
}
|
|
}
|
|
|
|
func (c *connection) closeConnectContext() {
|
|
<-c.connectContextMade
|
|
var cancelFn context.CancelFunc
|
|
|
|
c.connectContextMutex.Lock()
|
|
cancelFn = c.cancelConnectContext
|
|
c.cancelConnectContext = nil
|
|
c.connectContextMutex.Unlock()
|
|
|
|
if cancelFn != nil {
|
|
cancelFn()
|
|
}
|
|
}
|
|
|
|
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
|
|
if originalError == nil {
|
|
return nil
|
|
}
|
|
|
|
// If there was an error and the context was cancelled, we assume it happened due to the cancellation.
|
|
if ctx.Err() == context.Canceled {
|
|
return context.Canceled
|
|
}
|
|
|
|
// If there was a timeout error and the context deadline was used, we convert the error into
|
|
// context.DeadlineExceeded.
|
|
if !contextDeadlineUsed {
|
|
return originalError
|
|
}
|
|
if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() {
|
|
return context.DeadlineExceeded
|
|
}
|
|
|
|
return originalError
|
|
}
|
|
|
|
func (c *connection) cancellationListenerCallback() {
|
|
_ = c.close()
|
|
}
|
|
|
|
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
|
|
var err error
|
|
if atomic.LoadInt64(&c.state) != connConnected {
|
|
return ConnectionError{ConnectionID: c.id, message: "connection is closed"}
|
|
}
|
|
|
|
var deadline time.Time
|
|
if c.writeTimeout != 0 {
|
|
deadline = time.Now().Add(c.writeTimeout)
|
|
}
|
|
|
|
var contextDeadlineUsed bool
|
|
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
|
|
contextDeadlineUsed = true
|
|
deadline = dl
|
|
}
|
|
|
|
if err := c.nc.SetWriteDeadline(deadline); err != nil {
|
|
return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
|
|
}
|
|
|
|
err = c.write(ctx, wm)
|
|
if err != nil {
|
|
c.close()
|
|
return ConnectionError{
|
|
ConnectionID: c.id,
|
|
Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
|
|
message: "unable to write wire message to network",
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *connection) write(ctx context.Context, wm []byte) (err error) {
|
|
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
|
|
defer func() {
|
|
// There is a race condition between Write and StopListening. If the context is cancelled after c.nc.Write
|
|
// succeeds, the cancellation listener could fire and close the connection. In this case, the connection has
|
|
// been invalidated but the error is nil. To account for this, overwrite the error to context.Cancelled if
|
|
// the abortedForCancellation flag was set.
|
|
|
|
if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
|
|
err = context.Canceled
|
|
}
|
|
}()
|
|
|
|
_, err = c.nc.Write(wm)
|
|
return err
|
|
}
|
|
|
|
// readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
|
|
func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
|
|
if atomic.LoadInt64(&c.state) != connConnected {
|
|
return nil, ConnectionError{ConnectionID: c.id, message: "connection is closed"}
|
|
}
|
|
|
|
var deadline time.Time
|
|
if c.readTimeout != 0 {
|
|
deadline = time.Now().Add(c.readTimeout)
|
|
}
|
|
|
|
var contextDeadlineUsed bool
|
|
if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
|
|
contextDeadlineUsed = true
|
|
deadline = dl
|
|
}
|
|
|
|
if err := c.nc.SetReadDeadline(deadline); err != nil {
|
|
return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
|
|
}
|
|
|
|
dst, errMsg, err := c.read(ctx)
|
|
if err != nil {
|
|
// We closeConnection the connection because we don't know if there are other bytes left to read.
|
|
c.close()
|
|
message := errMsg
|
|
if err == io.EOF {
|
|
message = "socket was unexpectedly closed"
|
|
}
|
|
return nil, ConnectionError{
|
|
ConnectionID: c.id,
|
|
Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
|
|
message: message,
|
|
}
|
|
}
|
|
|
|
return dst, nil
|
|
}
|
|
|
|
func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
|
|
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
|
|
defer func() {
|
|
// If the context is cancelled after we finish reading the server response, the cancellation listener could fire
|
|
// even though the socket reads succeed. To account for this, we overwrite err to be context.Canceled if the
|
|
// abortedForCancellation flag is set.
|
|
|
|
if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
|
|
errMsg = "unable to read server response"
|
|
err = context.Canceled
|
|
}
|
|
}()
|
|
|
|
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
|
|
// reslice dst once instead of twice.
|
|
var sizeBuf [4]byte
|
|
|
|
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
|
|
// because there might be more than one wire message waiting to be read, for example when
|
|
// reading messages from an exhaust cursor.
|
|
_, err = io.ReadFull(c.nc, sizeBuf[:])
|
|
if err != nil {
|
|
return nil, "incomplete read of message header", err
|
|
}
|
|
|
|
// read the length as an int32
|
|
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
|
|
|
|
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
|
|
// defaultMaxMessageSize instead.
|
|
maxMessageSize := c.desc.MaxMessageSize
|
|
if maxMessageSize == 0 {
|
|
maxMessageSize = defaultMaxMessageSize
|
|
}
|
|
if uint32(size) > maxMessageSize {
|
|
return nil, errResponseTooLarge.Error(), errResponseTooLarge
|
|
}
|
|
|
|
dst := make([]byte, size)
|
|
copy(dst, sizeBuf[:])
|
|
|
|
_, err = io.ReadFull(c.nc, dst[4:])
|
|
if err != nil {
|
|
return dst, "incomplete read of full message", err
|
|
}
|
|
|
|
return dst, "", nil
|
|
}
|
|
|
|
func (c *connection) close() error {
|
|
// Overwrite the connection state as the first step so only the first close call will execute.
|
|
if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
|
|
return nil
|
|
}
|
|
|
|
var err error
|
|
if c.nc != nil {
|
|
err = c.nc.Close()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (c *connection) closed() bool {
|
|
return atomic.LoadInt64(&c.state) == connDisconnected
|
|
}
|
|
|
|
func (c *connection) idleTimeoutExpired() bool {
|
|
now := time.Now()
|
|
if c.idleTimeout > 0 {
|
|
idleDeadline, ok := c.idleDeadline.Load().(time.Time)
|
|
if ok && now.After(idleDeadline) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (c *connection) bumpIdleDeadline() {
|
|
if c.idleTimeout > 0 {
|
|
c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
|
|
}
|
|
}
|
|
|
|
func (c *connection) setCanStream(canStream bool) {
|
|
c.canStream = canStream
|
|
}
|
|
|
|
func (c initConnection) supportsStreaming() bool {
|
|
return c.canStream
|
|
}
|
|
|
|
func (c *connection) setStreaming(streaming bool) {
|
|
c.currentlyStreaming = streaming
|
|
}
|
|
|
|
func (c *connection) getCurrentlyStreaming() bool {
|
|
return c.currentlyStreaming
|
|
}
|
|
|
|
func (c *connection) setSocketTimeout(timeout time.Duration) {
|
|
c.readTimeout = timeout
|
|
c.writeTimeout = timeout
|
|
}
|
|
|
|
func (c *connection) ID() string {
|
|
return c.id
|
|
}
|
|
|
|
func (c *connection) ServerConnectionID() *int64 {
|
|
return c.serverConnectionID
|
|
}
|
|
|
|
// initConnection is an adapter used during connection initialization. It has the minimum
|
|
// functionality necessary to implement the driver.Connection interface, which is required to pass a
|
|
// *connection to a Handshaker.
|
|
type initConnection struct{ *connection }
|
|
|
|
var _ driver.Connection = initConnection{}
|
|
var _ driver.StreamerConnection = initConnection{}
|
|
|
|
func (c initConnection) Description() description.Server {
|
|
if c.connection == nil {
|
|
return description.Server{}
|
|
}
|
|
return c.connection.desc
|
|
}
|
|
func (c initConnection) Close() error { return nil }
|
|
func (c initConnection) ID() string { return c.id }
|
|
func (c initConnection) Address() address.Address { return c.addr }
|
|
func (c initConnection) Stale() bool { return false }
|
|
func (c initConnection) LocalAddress() address.Address {
|
|
if c.connection == nil || c.nc == nil {
|
|
return address.Address("0.0.0.0")
|
|
}
|
|
return address.Address(c.nc.LocalAddr().String())
|
|
}
|
|
func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
|
|
return c.writeWireMessage(ctx, wm)
|
|
}
|
|
func (c initConnection) ReadWireMessage(ctx context.Context) ([]byte, error) {
|
|
return c.readWireMessage(ctx)
|
|
}
|
|
func (c initConnection) SetStreaming(streaming bool) {
|
|
c.setStreaming(streaming)
|
|
}
|
|
func (c initConnection) CurrentlyStreaming() bool {
|
|
return c.getCurrentlyStreaming()
|
|
}
|
|
func (c initConnection) SupportsStreaming() bool {
|
|
return c.supportsStreaming()
|
|
}
|
|
|
|
// Connection implements the driver.Connection interface to allow reading and writing wire
|
|
// messages and the driver.Expirable interface to allow expiring. It wraps an underlying
|
|
// topology.connection to make it more goroutine-safe and nil-safe.
|
|
type Connection struct {
|
|
connection *connection
|
|
refCount int
|
|
cleanupPoolFn func()
|
|
|
|
// cleanupServerFn resets the server state when a connection is returned to the connection pool
|
|
// via Close() or expired via Expire().
|
|
cleanupServerFn func()
|
|
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
var _ driver.Connection = (*Connection)(nil)
|
|
var _ driver.Expirable = (*Connection)(nil)
|
|
var _ driver.PinnedConnection = (*Connection)(nil)
|
|
|
|
// WriteWireMessage handles writing a wire message to the underlying connection.
|
|
func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
if c.connection == nil {
|
|
return ErrConnectionClosed
|
|
}
|
|
return c.connection.writeWireMessage(ctx, wm)
|
|
}
|
|
|
|
// ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
|
|
// will be overwritten with the new wire message.
|
|
func (c *Connection) ReadWireMessage(ctx context.Context) ([]byte, error) {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
if c.connection == nil {
|
|
return nil, ErrConnectionClosed
|
|
}
|
|
return c.connection.readWireMessage(ctx)
|
|
}
|
|
|
|
// CompressWireMessage handles compressing the provided wire message using the underlying
|
|
// connection's compressor. The dst parameter will be overwritten with the new wire message. If
|
|
// there is no compressor set on the underlying connection, then no compression will be performed.
|
|
func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
if c.connection == nil {
|
|
return dst, ErrConnectionClosed
|
|
}
|
|
if c.connection.compressor == wiremessage.CompressorNoOp {
|
|
return append(dst, src...), nil
|
|
}
|
|
_, reqid, respto, origcode, rem, ok := wiremessage.ReadHeader(src)
|
|
if !ok {
|
|
return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
|
|
}
|
|
idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
|
|
dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode)
|
|
dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem)))
|
|
dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor)
|
|
opts := driver.CompressionOpts{
|
|
Compressor: c.connection.compressor,
|
|
ZlibLevel: c.connection.zliblevel,
|
|
ZstdLevel: c.connection.zstdLevel,
|
|
}
|
|
compressed, err := driver.CompressPayload(rem, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed)
|
|
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
|
|
}
|
|
|
|
// Description returns the server description of the server this connection is connected to.
|
|
func (c *Connection) Description() description.Server {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
if c.connection == nil {
|
|
return description.Server{}
|
|
}
|
|
return c.connection.desc
|
|
}
|
|
|
|
// Close returns this connection to the connection pool. This method may not closeConnection the underlying
|
|
// socket.
|
|
func (c *Connection) Close() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.connection == nil || c.refCount > 0 {
|
|
return nil
|
|
}
|
|
|
|
return c.cleanupReferences()
|
|
}
|
|
|
|
// Expire closes this connection and will closeConnection the underlying socket.
|
|
func (c *Connection) Expire() error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.connection == nil {
|
|
return nil
|
|
}
|
|
|
|
_ = c.connection.close()
|
|
return c.cleanupReferences()
|
|
}
|
|
|
|
func (c *Connection) cleanupReferences() error {
|
|
err := c.connection.pool.checkIn(c.connection)
|
|
if c.cleanupPoolFn != nil {
|
|
c.cleanupPoolFn()
|
|
c.cleanupPoolFn = nil
|
|
}
|
|
if c.cleanupServerFn != nil {
|
|
c.cleanupServerFn()
|
|
c.cleanupServerFn = nil
|
|
}
|
|
c.connection = nil
|
|
return err
|
|
}
|
|
|
|
// Alive returns if the connection is still alive.
|
|
func (c *Connection) Alive() bool {
|
|
return c.connection != nil
|
|
}
|
|
|
|
// ID returns the ID of this connection.
|
|
func (c *Connection) ID() string {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
if c.connection == nil {
|
|
return "<closed>"
|
|
}
|
|
return c.connection.id
|
|
}
|
|
|
|
// ServerConnectionID returns the server connection ID of this connection.
|
|
func (c *Connection) ServerConnectionID() *int64 {
|
|
if c.connection == nil {
|
|
return nil
|
|
}
|
|
return c.connection.serverConnectionID
|
|
}
|
|
|
|
// Stale returns if the connection is stale.
|
|
func (c *Connection) Stale() bool {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
return c.connection.pool.stale(c.connection)
|
|
}
|
|
|
|
// Address returns the address of this connection.
|
|
func (c *Connection) Address() address.Address {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
if c.connection == nil {
|
|
return address.Address("0.0.0.0")
|
|
}
|
|
return c.connection.addr
|
|
}
|
|
|
|
// LocalAddress returns the local address of the connection
|
|
func (c *Connection) LocalAddress() address.Address {
|
|
c.mu.RLock()
|
|
defer c.mu.RUnlock()
|
|
if c.connection == nil || c.connection.nc == nil {
|
|
return address.Address("0.0.0.0")
|
|
}
|
|
return address.Address(c.connection.nc.LocalAddr().String())
|
|
}
|
|
|
|
// PinToCursor updates this connection to reflect that it is pinned to a cursor.
|
|
func (c *Connection) PinToCursor() error {
|
|
return c.pin("cursor", c.connection.pool.pinConnectionToCursor, c.connection.pool.unpinConnectionFromCursor)
|
|
}
|
|
|
|
// PinToTransaction updates this connection to reflect that it is pinned to a transaction.
|
|
func (c *Connection) PinToTransaction() error {
|
|
return c.pin("transaction", c.connection.pool.pinConnectionToTransaction, c.connection.pool.unpinConnectionFromTransaction)
|
|
}
|
|
|
|
func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.connection == nil {
|
|
return fmt.Errorf("attempted to pin a connection for a %s, but the connection has already been returned to the pool", reason)
|
|
}
|
|
|
|
// Only use the provided callbacks for the first reference to avoid double-counting pinned connection statistics
|
|
// in the pool.
|
|
if c.refCount == 0 {
|
|
updatePoolFn()
|
|
c.cleanupPoolFn = cleanupPoolFn
|
|
}
|
|
c.refCount++
|
|
return nil
|
|
}
|
|
|
|
// UnpinFromCursor updates this connection to reflect that it is no longer pinned to a cursor.
|
|
func (c *Connection) UnpinFromCursor() error {
|
|
return c.unpin("cursor")
|
|
}
|
|
|
|
// UnpinFromTransaction updates this connection to reflect that it is no longer pinned to a transaction.
|
|
func (c *Connection) UnpinFromTransaction() error {
|
|
return c.unpin("transaction")
|
|
}
|
|
|
|
func (c *Connection) unpin(reason string) error {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.connection == nil {
|
|
// We don't error here because the resource could have been forcefully closed via Expire.
|
|
return nil
|
|
}
|
|
if c.refCount == 0 {
|
|
return fmt.Errorf("attempted to unpin a connection from a %s, but the connection is not pinned by any resources", reason)
|
|
}
|
|
|
|
c.refCount--
|
|
return nil
|
|
}
|
|
|
|
// DriverConnectionID returns the driver connection ID.
|
|
// TODO(GODRIVER-2824): change return type to int64.
|
|
func (c *Connection) DriverConnectionID() uint64 {
|
|
return c.connection.DriverConnectionID()
|
|
}
|
|
|
|
func configureTLS(ctx context.Context,
|
|
tlsConnSource tlsConnectionSource,
|
|
nc net.Conn,
|
|
addr address.Address,
|
|
config *tls.Config,
|
|
ocspOpts *ocsp.VerifyOptions,
|
|
) (net.Conn, error) {
|
|
// Ensure config.ServerName is always set for SNI.
|
|
if config.ServerName == "" {
|
|
hostname := addr.String()
|
|
colonPos := strings.LastIndex(hostname, ":")
|
|
if colonPos == -1 {
|
|
colonPos = len(hostname)
|
|
}
|
|
|
|
hostname = hostname[:colonPos]
|
|
config.ServerName = hostname
|
|
}
|
|
|
|
client := tlsConnSource.Client(nc, config)
|
|
if err := clientHandshake(ctx, client); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Only do OCSP verification if TLS verification is requested.
|
|
if !config.InsecureSkipVerify {
|
|
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
|
|
return nil, ocspErr
|
|
}
|
|
}
|
|
return client, nil
|
|
}
|