You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-library/vendor/github.com/uptrace/bun/driver/pgdriver/listener.go

381 lines
7.4 KiB

package pgdriver
import (
"context"
"errors"
"strconv"
"sync"
"time"
"github.com/uptrace/bun"
)
const pingChannel = "bun:ping"
var (
errListenerClosed = errors.New("bun: listener is closed")
errPingTimeout = errors.New("bun: ping timeout")
)
// Notify sends a notification on the channel using `NOTIFY` command.
func Notify(ctx context.Context, db *bun.DB, channel, payload string) error {
_, err := db.ExecContext(ctx, "NOTIFY ?, ?", bun.Ident(channel), payload)
return err
}
type Listener struct {
db *bun.DB
driver *Connector
channels []string
mu sync.Mutex
cn *Conn
closed bool
exit chan struct{}
}
func NewListener(db *bun.DB) *Listener {
return &Listener{
db: db,
driver: db.Driver().(Driver).connector,
exit: make(chan struct{}),
}
}
// Close closes the listener, releasing any open resources.
func (ln *Listener) Close() error {
return ln.withLock(func() error {
if ln.closed {
return errListenerClosed
}
ln.closed = true
close(ln.exit)
return ln.closeConn(errListenerClosed)
})
}
func (ln *Listener) withLock(fn func() error) error {
ln.mu.Lock()
defer ln.mu.Unlock()
return fn()
}
func (ln *Listener) conn(ctx context.Context) (*Conn, error) {
if ln.closed {
return nil, errListenerClosed
}
if ln.cn != nil {
return ln.cn, nil
}
cn, err := ln._conn(ctx)
if err != nil {
return nil, err
}
ln.cn = cn
return cn, nil
}
func (ln *Listener) _conn(ctx context.Context) (*Conn, error) {
driverConn, err := ln.driver.Connect(ctx)
if err != nil {
return nil, err
}
cn := driverConn.(*Conn)
if len(ln.channels) > 0 {
err := ln.listen(ctx, cn, ln.channels...)
if err != nil {
_ = cn.Close()
return nil, err
}
}
return cn, nil
}
func (ln *Listener) checkConn(ctx context.Context, cn *Conn, err error, allowTimeout bool) {
_ = ln.withLock(func() error {
if ln.closed || ln.cn != cn {
return nil
}
if isBadConn(err, allowTimeout) {
ln.reconnect(ctx, err)
}
return nil
})
}
func (ln *Listener) reconnect(ctx context.Context, reason error) {
if ln.cn != nil {
Logger.Printf(ctx, "bun: discarding bad listener connection: %s", reason)
_ = ln.closeConn(reason)
}
_, _ = ln.conn(ctx)
}
func (ln *Listener) closeConn(reason error) error {
if ln.cn == nil {
return nil
}
err := ln.cn.Close()
ln.cn = nil
return err
}
// Listen starts listening for notifications on channels.
func (ln *Listener) Listen(ctx context.Context, channels ...string) error {
var cn *Conn
if err := ln.withLock(func() error {
ln.channels = appendIfNotExists(ln.channels, channels...)
var err error
cn, err = ln.conn(ctx)
return err
}); err != nil {
return err
}
if err := ln.listen(ctx, cn, channels...); err != nil {
ln.checkConn(ctx, cn, err, false)
return err
}
return nil
}
func (ln *Listener) listen(ctx context.Context, cn *Conn, channels ...string) error {
for _, channel := range channels {
if err := writeQuery(ctx, cn, "LISTEN "+strconv.Quote(channel)); err != nil {
return err
}
}
return nil
}
// Unlisten stops listening for notifications on channels.
func (ln *Listener) Unlisten(ctx context.Context, channels ...string) error {
var cn *Conn
if err := ln.withLock(func() error {
ln.channels = removeIfExists(ln.channels, channels...)
var err error
cn, err = ln.conn(ctx)
return err
}); err != nil {
return err
}
if err := ln.unlisten(ctx, cn, channels...); err != nil {
ln.checkConn(ctx, cn, err, false)
return err
}
return nil
}
func (ln *Listener) unlisten(ctx context.Context, cn *Conn, channels ...string) error {
for _, channel := range channels {
if err := writeQuery(ctx, cn, "UNLISTEN "+strconv.Quote(channel)); err != nil {
return err
}
}
return nil
}
// Receive indefinitely waits for a notification. This is low-level API
// and in most cases Channel should be used instead.
func (ln *Listener) Receive(ctx context.Context) (channel string, payload string, err error) {
return ln.ReceiveTimeout(ctx, 0)
}
// ReceiveTimeout waits for a notification until timeout is reached.
// This is low-level API and in most cases Channel should be used instead.
func (ln *Listener) ReceiveTimeout(
ctx context.Context, timeout time.Duration,
) (channel, payload string, err error) {
var cn *Conn
if err := ln.withLock(func() error {
var err error
cn, err = ln.conn(ctx)
return err
}); err != nil {
return "", "", err
}
rd := cn.reader(ctx, timeout)
channel, payload, err = readNotification(ctx, rd)
if err != nil {
ln.checkConn(ctx, cn, err, timeout > 0)
return "", "", err
}
return channel, payload, nil
}
// Channel returns a channel for concurrently receiving notifications.
// It periodically sends Ping notification to test connection health.
//
// The channel is closed with Listener. Receive* APIs can not be used
// after channel is created.
func (ln *Listener) Channel(opts ...ChannelOption) <-chan Notification {
return newChannel(ln, opts).ch
}
//------------------------------------------------------------------------------
// Notification received with LISTEN command.
type Notification struct {
Channel string
Payload string
}
type ChannelOption func(c *channel)
func WithChannelSize(size int) ChannelOption {
return func(c *channel) {
c.size = size
}
}
type channel struct {
ctx context.Context
ln *Listener
size int
pingTimeout time.Duration
ch chan Notification
pingCh chan struct{}
}
func newChannel(ln *Listener, opts []ChannelOption) *channel {
c := &channel{
ctx: context.TODO(),
ln: ln,
size: 1000,
pingTimeout: 5 * time.Second,
}
for _, opt := range opts {
opt(c)
}
c.ch = make(chan Notification, c.size)
c.pingCh = make(chan struct{}, 1)
_ = c.ln.Listen(c.ctx, pingChannel)
go c.startReceive()
go c.startPing()
return c
}
func (c *channel) startReceive() {
var errCount int
for {
channel, payload, err := c.ln.Receive(c.ctx)
if err != nil {
if err == errListenerClosed {
close(c.ch)
return
}
if errCount > 0 {
time.Sleep(500 * time.Millisecond)
}
errCount++
continue
}
errCount = 0
// Any notification is as good as a ping.
select {
case c.pingCh <- struct{}{}:
default:
}
switch channel {
case pingChannel:
// ignore
default:
select {
case c.ch <- Notification{channel, payload}:
default:
Logger.Printf(c.ctx, "pgdriver: Listener buffer is full (message is dropped)")
}
}
}
}
func (c *channel) startPing() {
timer := time.NewTimer(time.Minute)
timer.Stop()
healthy := true
for {
timer.Reset(c.pingTimeout)
select {
case <-c.pingCh:
healthy = true
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
pingErr := c.ping(c.ctx)
if healthy {
healthy = false
} else {
if pingErr == nil {
pingErr = errPingTimeout
}
_ = c.ln.withLock(func() error {
c.ln.reconnect(c.ctx, pingErr)
return nil
})
}
case <-c.ln.exit:
return
}
}
}
func (c *channel) ping(ctx context.Context) error {
_, err := c.ln.db.ExecContext(ctx, "NOTIFY "+strconv.Quote(pingChannel))
return err
}
func appendIfNotExists(ss []string, es ...string) []string {
loop:
for _, e := range es {
for _, s := range ss {
if s == e {
continue loop
}
}
ss = append(ss, e)
}
return ss
}
func removeIfExists(ss []string, es ...string) []string {
for _, e := range es {
for i, s := range ss {
if s == e {
last := len(ss) - 1
ss[i] = ss[last]
ss = ss[:last]
break
}
}
}
return ss
}