You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
4.8 KiB
250 lines
4.8 KiB
package pgdriver
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"io"
|
|
|
|
"github.com/uptrace/bun"
|
|
)
|
|
|
|
// CopyFrom copies data from the reader to the query destination.
|
|
func CopyFrom(
|
|
ctx context.Context, conn bun.Conn, r io.Reader, query string, args ...interface{},
|
|
) (res sql.Result, err error) {
|
|
query, err = formatQueryArgs(query, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := conn.Raw(func(driverConn interface{}) error {
|
|
cn := driverConn.(*Conn)
|
|
|
|
if err := writeQuery(ctx, cn, query); err != nil {
|
|
return err
|
|
}
|
|
if err := readCopyIn(ctx, cn); err != nil {
|
|
return err
|
|
}
|
|
if err := writeCopyData(ctx, cn, r); err != nil {
|
|
return err
|
|
}
|
|
if err := writeCopyDone(ctx, cn); err != nil {
|
|
return err
|
|
}
|
|
|
|
res, err = readQuery(ctx, cn)
|
|
return err
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func readCopyIn(ctx context.Context, cn *Conn) error {
|
|
rd := cn.reader(ctx, -1)
|
|
var firstErr error
|
|
for {
|
|
c, msgLen, err := readMessageType(rd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch c {
|
|
case errorResponseMsg:
|
|
e, err := readError(rd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = e
|
|
}
|
|
case readyForQueryMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return err
|
|
}
|
|
return firstErr
|
|
case copyInResponseMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return err
|
|
}
|
|
return firstErr
|
|
case noticeResponseMsg, parameterStatusMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return fmt.Errorf("pgdriver: readCopyIn: unexpected message %q", c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func writeCopyData(ctx context.Context, cn *Conn, r io.Reader) error {
|
|
wb := getWriteBuffer()
|
|
defer putWriteBuffer(wb)
|
|
|
|
for {
|
|
wb.StartMessage(copyDataMsg)
|
|
if _, err := wb.ReadFrom(r); err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
return err
|
|
}
|
|
wb.FinishMessage()
|
|
|
|
if err := cn.write(ctx, wb); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func writeCopyDone(ctx context.Context, cn *Conn) error {
|
|
wb := getWriteBuffer()
|
|
defer putWriteBuffer(wb)
|
|
|
|
wb.StartMessage(copyDoneMsg)
|
|
wb.FinishMessage()
|
|
|
|
return cn.write(ctx, wb)
|
|
}
|
|
|
|
//------------------------------------------------------------------------------
|
|
|
|
// CopyTo copies data from the query source to the writer.
|
|
func CopyTo(
|
|
ctx context.Context, conn bun.Conn, w io.Writer, query string, args ...interface{},
|
|
) (res sql.Result, err error) {
|
|
query, err = formatQueryArgs(query, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := conn.Raw(func(driverConn interface{}) error {
|
|
cn := driverConn.(*Conn)
|
|
|
|
if err := writeQuery(ctx, cn, query); err != nil {
|
|
return err
|
|
}
|
|
if err := readCopyOut(ctx, cn); err != nil {
|
|
return err
|
|
}
|
|
|
|
res, err = readCopyData(ctx, cn, w)
|
|
return err
|
|
}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func readCopyOut(ctx context.Context, cn *Conn) error {
|
|
rd := cn.reader(ctx, -1)
|
|
var firstErr error
|
|
for {
|
|
c, msgLen, err := readMessageType(rd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch c {
|
|
case errorResponseMsg:
|
|
e, err := readError(rd)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = e
|
|
}
|
|
case readyForQueryMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return err
|
|
}
|
|
return firstErr
|
|
case copyOutResponseMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
case noticeResponseMsg, parameterStatusMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return fmt.Errorf("pgdriver: readCopyOut: unexpected message %q", c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func readCopyData(ctx context.Context, cn *Conn, w io.Writer) (res sql.Result, err error) {
|
|
rd := cn.reader(ctx, -1)
|
|
var firstErr error
|
|
for {
|
|
c, msgLen, err := readMessageType(rd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch c {
|
|
case errorResponseMsg:
|
|
e, err := readError(rd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if firstErr == nil {
|
|
firstErr = e
|
|
}
|
|
case copyDataMsg:
|
|
for msgLen > 0 {
|
|
b, err := rd.ReadTemp(msgLen)
|
|
if err != nil && err != bufio.ErrBufferFull {
|
|
return nil, err
|
|
}
|
|
|
|
if _, err := w.Write(b); err != nil {
|
|
if firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
break
|
|
}
|
|
|
|
msgLen -= len(b)
|
|
}
|
|
case copyDoneMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return nil, err
|
|
}
|
|
case commandCompleteMsg:
|
|
tmp, err := rd.ReadTemp(msgLen)
|
|
if err != nil {
|
|
firstErr = err
|
|
break
|
|
}
|
|
|
|
r, err := parseResult(tmp)
|
|
if err != nil {
|
|
firstErr = err
|
|
} else {
|
|
res = r
|
|
}
|
|
case readyForQueryMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return nil, err
|
|
}
|
|
return res, firstErr
|
|
case noticeResponseMsg, parameterStatusMsg:
|
|
if err := rd.Discard(msgLen); err != nil {
|
|
return nil, err
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("pgdriver: readCopyData: unexpected message %q", c)
|
|
}
|
|
}
|
|
}
|