You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-library/vendor/github.com/uptrace/bun/driver/pgdriver/copy.go

250 lines
4.8 KiB

package pgdriver
import (
"bufio"
"context"
"database/sql"
"fmt"
"io"
"github.com/uptrace/bun"
)
// CopyFrom copies data from the reader to the query destination.
//
// The query (presumably a COPY ... FROM STDIN statement) is formatted with
// args and then run on the driver connection obtained through conn.Raw:
//   - the query is sent;
//   - the function waits for the server to enter copy-in mode;
//   - r is streamed as CopyData messages;
//   - CopyDone is sent;
//   - the result reported by the server is returned.
//
// If conn is not backed by a pgdriver *Conn, an error is returned instead of
// panicking on the type assertion.
//
// NOTE(review): if reading r fails mid-stream, no CopyFail message is sent,
// so the server may still be in copy-in mode on this connection afterwards.
// Confirm that callers discard the connection in that case.
func CopyFrom(
	ctx context.Context, conn bun.Conn, r io.Reader, query string, args ...interface{},
) (res sql.Result, err error) {
	query, err = formatQueryArgs(query, args)
	if err != nil {
		return nil, err
	}

	if err := conn.Raw(func(driverConn interface{}) error {
		cn, ok := driverConn.(*Conn)
		if !ok {
			return fmt.Errorf("pgdriver: CopyFrom: unsupported driver connection %T", driverConn)
		}

		if err := writeQuery(ctx, cn, query); err != nil {
			return err
		}
		if err := readCopyIn(ctx, cn); err != nil {
			return err
		}
		if err := writeCopyData(ctx, cn, r); err != nil {
			return err
		}
		if err := writeCopyDone(ctx, cn); err != nil {
			return err
		}

		// Assigns the outer named results; the short declaration in the
		// enclosing if statement is not in scope inside this closure.
		res, err = readQuery(ctx, cn)
		return err
	}); err != nil {
		return nil, err
	}

	return res, nil
}
// readCopyIn reads the server's reply to a COPY FROM query.
//
// It stops at the first CopyInResponse (the server is ready to receive data)
// or ReadyForQuery (the query ended without entering copy-in mode).
//   - The first ErrorResponse seen before that point is returned once one of
//     those messages arrives.
//   - NoticeResponse and ParameterStatus messages are skipped.
//   - Any other message type is reported as an error.
func readCopyIn(ctx context.Context, cn *Conn) error {
	rd := cn.reader(ctx, -1)

	var serverErr error
	for {
		msgType, msgLen, err := readMessageType(rd)
		if err != nil {
			return err
		}

		if msgType == errorResponseMsg {
			pgErr, err := readError(rd)
			if err != nil {
				return err
			}
			if serverErr == nil {
				serverErr = pgErr
			}
			continue
		}

		switch msgType {
		case copyInResponseMsg, readyForQueryMsg:
			// Either message ends this phase; its body is not needed.
			if err := rd.Discard(msgLen); err != nil {
				return err
			}
			return serverErr
		case noticeResponseMsg, parameterStatusMsg:
			if err := rd.Discard(msgLen); err != nil {
				return err
			}
		default:
			return fmt.Errorf("pgdriver: readCopyIn: unexpected message %q", msgType)
		}
	}
}
// writeCopyData streams r to the server as CopyData messages until r reports
// io.EOF. Each chunk returned by wb.ReadFrom is sent as one message.
//
// The io.Reader contract allows a read to return its final bytes together
// with io.EOF, so a non-empty chunk is sent even when it arrives with EOF.
// Any other read error is returned without sending the partial chunk.
func writeCopyData(ctx context.Context, cn *Conn, r io.Reader) error {
	wb := getWriteBuffer()
	defer putWriteBuffer(wb)

	for {
		wb.StartMessage(copyDataMsg)
		n, err := wb.ReadFrom(r)
		if err != nil && err != io.EOF {
			return err
		}

		// Send whenever the read succeeded, including the n == 0 case (as
		// before), so the started message header never lingers in wb. Also
		// send the final bytes that arrived together with io.EOF.
		if err == nil || n > 0 {
			wb.FinishMessage()
			if werr := cn.write(ctx, wb); werr != nil {
				return werr
			}
		}

		if err == io.EOF {
			return nil
		}
	}
}
// writeCopyDone sends a CopyDone message to tell the server that no more
// CopyData messages follow.
func writeCopyDone(ctx context.Context, cn *Conn) error {
	buf := getWriteBuffer()
	defer putWriteBuffer(buf)

	// CopyDone carries no payload: nothing is written between starting and
	// finishing the message.
	buf.StartMessage(copyDoneMsg)
	buf.FinishMessage()

	return cn.write(ctx, buf)
}
//------------------------------------------------------------------------------
// CopyTo copies data from the query source to the writer.
//
// The query (presumably a COPY ... TO STDOUT statement) is formatted with
// args and sent on the driver connection obtained through conn.Raw.
//   - Every CopyData payload the server returns is written to w.
//   - The returned result is parsed from the server's CommandComplete
//     message.
//
// If conn is not backed by a pgdriver *Conn, an error is returned instead of
// panicking on the type assertion.
func CopyTo(
	ctx context.Context, conn bun.Conn, w io.Writer, query string, args ...interface{},
) (res sql.Result, err error) {
	query, err = formatQueryArgs(query, args)
	if err != nil {
		return nil, err
	}

	if err := conn.Raw(func(driverConn interface{}) error {
		cn, ok := driverConn.(*Conn)
		if !ok {
			return fmt.Errorf("pgdriver: CopyTo: unsupported driver connection %T", driverConn)
		}

		if err := writeQuery(ctx, cn, query); err != nil {
			return err
		}
		if err := readCopyOut(ctx, cn); err != nil {
			return err
		}

		// Assigns the outer named results.
		res, err = readCopyData(ctx, cn, w)
		return err
	}); err != nil {
		return nil, err
	}

	return res, nil
}
// readCopyOut reads the server's reply to a COPY TO query.
//
// It stops at the first CopyOutResponse (the server is about to stream data)
// or ReadyForQuery (the query ended without entering copy-out mode).
//   - On ReadyForQuery, the first ErrorResponse seen is returned.
//   - On CopyOutResponse, nil is returned.
//   - NoticeResponse and ParameterStatus messages are skipped.
//   - Any other message type is reported as an error.
//
// NOTE(review): unlike readCopyIn, an ErrorResponse that precedes
// CopyOutResponse is dropped. Confirm this asymmetry is intended.
func readCopyOut(ctx context.Context, cn *Conn) error {
	rd := cn.reader(ctx, -1)

	var serverErr error
	for {
		msgType, msgLen, err := readMessageType(rd)
		if err != nil {
			return err
		}

		if msgType == errorResponseMsg {
			pgErr, err := readError(rd)
			if err != nil {
				return err
			}
			if serverErr == nil {
				serverErr = pgErr
			}
			continue
		}

		switch msgType {
		case copyOutResponseMsg:
			if err := rd.Discard(msgLen); err != nil {
				return err
			}
			return nil
		case readyForQueryMsg:
			if err := rd.Discard(msgLen); err != nil {
				return err
			}
			return serverErr
		case noticeResponseMsg, parameterStatusMsg:
			if err := rd.Discard(msgLen); err != nil {
				return err
			}
		default:
			return fmt.Errorf("pgdriver: readCopyOut: unexpected message %q", msgType)
		}
	}
}
// readCopyData writes the CopyData payloads that follow CopyOutResponse to w,
// and returns the result parsed from the CommandComplete message.
//
// Messages are read until ReadyForQuery, so the connection stays in step with
// the server even after an error.
//   - Recorded errors (a server ErrorResponse, a failed write to w, or an
//     unparsable CommandComplete tag) do not stop the loop. The first one is
//     returned together with any parsed result once ReadyForQuery arrives.
//   - Errors reading from the connection itself are returned immediately,
//     because the stream position can no longer be trusted.
func readCopyData(ctx context.Context, cn *Conn, w io.Writer) (sql.Result, error) {
	rd := cn.reader(ctx, -1)

	var res sql.Result
	var firstErr error
	for {
		c, msgLen, err := readMessageType(rd)
		if err != nil {
			return nil, err
		}

		switch c {
		case errorResponseMsg:
			e, err := readError(rd)
			if err != nil {
				return nil, err
			}
			if firstErr == nil {
				firstErr = e
			}

		case copyDataMsg:
			// Once an error is recorded, stop forwarding data to w, but keep
			// consuming messages so the protocol stays in sync.
			if firstErr != nil {
				if err := rd.Discard(msgLen); err != nil {
					return nil, err
				}
				break
			}
			// A payload larger than the read buffer comes back in pieces (with
			// bufio.ErrBufferFull); each ReadTemp call consumes len(b) bytes.
			for msgLen > 0 {
				b, err := rd.ReadTemp(msgLen)
				if err != nil && err != bufio.ErrBufferFull {
					return nil, err
				}
				msgLen -= len(b)
				if _, err := w.Write(b); err != nil {
					firstErr = err
					// Drop the unread tail of this message. Otherwise the next
					// readMessageType would start in the middle of the payload.
					if err := rd.Discard(msgLen); err != nil {
						return nil, err
					}
					break
				}
			}

		case copyDoneMsg:
			if err := rd.Discard(msgLen); err != nil {
				return nil, err
			}

		case commandCompleteMsg:
			tmp, err := rd.ReadTemp(msgLen)
			if err != nil {
				// The message was not fully consumed, so continuing would
				// misparse whatever follows.
				return nil, err
			}
			r, err := parseResult(tmp)
			if err != nil {
				if firstErr == nil {
					firstErr = err
				}
				break
			}
			res = r

		case readyForQueryMsg:
			if err := rd.Discard(msgLen); err != nil {
				return nil, err
			}
			return res, firstErr

		case noticeResponseMsg, parameterStatusMsg:
			if err := rd.Discard(msgLen); err != nil {
				return nil, err
			}

		default:
			return nil, fmt.Errorf("pgdriver: readCopyData: unexpected message %q", c)
		}
	}
}