package pgtype import ( "bytes" "encoding/binary" "fmt" "io" "reflect" "strconv" "strings" "unicode" "github.com/jackc/pgx/v5/internal/pgio" ) // Information on the internals of PostgreSQL arrays can be found in // src/include/utils/array.h and src/backend/utils/adt/arrayfuncs.c. Of // particular interest is the array_send function. type arrayHeader struct { ContainsNull bool ElementOID uint32 Dimensions []ArrayDimension } type ArrayDimension struct { Length int32 LowerBound int32 } // cardinality returns the number of elements in an array of dimensions size. func cardinality(dimensions []ArrayDimension) int { if len(dimensions) == 0 { return 0 } elementCount := int(dimensions[0].Length) for _, d := range dimensions[1:] { elementCount *= int(d.Length) } return elementCount } func (dst *arrayHeader) DecodeBinary(m *Map, src []byte) (int, error) { if len(src) < 12 { return 0, fmt.Errorf("array header too short: %d", len(src)) } rp := 0 numDims := int(binary.BigEndian.Uint32(src[rp:])) rp += 4 dst.ContainsNull = binary.BigEndian.Uint32(src[rp:]) == 1 rp += 4 dst.ElementOID = binary.BigEndian.Uint32(src[rp:]) rp += 4 dst.Dimensions = make([]ArrayDimension, numDims) if len(src) < 12+numDims*8 { return 0, fmt.Errorf("array header too short for %d dimensions: %d", numDims, len(src)) } for i := range dst.Dimensions { dst.Dimensions[i].Length = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 dst.Dimensions[i].LowerBound = int32(binary.BigEndian.Uint32(src[rp:])) rp += 4 } return rp, nil } func (src arrayHeader) EncodeBinary(buf []byte) []byte { buf = pgio.AppendInt32(buf, int32(len(src.Dimensions))) var containsNull int32 if src.ContainsNull { containsNull = 1 } buf = pgio.AppendInt32(buf, containsNull) buf = pgio.AppendUint32(buf, src.ElementOID) for i := range src.Dimensions { buf = pgio.AppendInt32(buf, src.Dimensions[i].Length) buf = pgio.AppendInt32(buf, src.Dimensions[i].LowerBound) } return buf } type untypedTextArray struct { Elements []string Quoted []bool Dimensions []ArrayDimension } func parseUntypedTextArray(src string) (*untypedTextArray, error) { dst := &untypedTextArray{ Elements: []string{}, Quoted: []bool{}, Dimensions: []ArrayDimension{}, } buf := bytes.NewBufferString(src) skipWhitespace(buf) r, _, err := buf.ReadRune() if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } var explicitDimensions []ArrayDimension // Array has explicit dimensions if r == '[' { buf.UnreadRune() for { r, _, err = buf.ReadRune() if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } if r == '=' { break } else if r != '[' { return nil, fmt.Errorf("invalid array, expected '[' or '=' got %v", r) } lower, err := arrayParseInteger(buf) if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } if r != ':' { return nil, fmt.Errorf("invalid array, expected ':' got %v", r) } upper, err := arrayParseInteger(buf) if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } r, _, err = buf.ReadRune() if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } if r != ']' { return nil, fmt.Errorf("invalid array, expected ']' got %v", r) } explicitDimensions = append(explicitDimensions, ArrayDimension{LowerBound: lower, Length: upper - lower + 1}) } r, _, err = buf.ReadRune() if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } } if r != '{' { return nil, fmt.Errorf("invalid array, expected '{': %v", err) } implicitDimensions := []ArrayDimension{{LowerBound: 1, Length: 0}} // Consume all initial opening brackets. This provides number of dimensions. for { r, _, err = buf.ReadRune() if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } if r == '{' { implicitDimensions[len(implicitDimensions)-1].Length = 1 implicitDimensions = append(implicitDimensions, ArrayDimension{LowerBound: 1}) } else { buf.UnreadRune() break } } currentDim := len(implicitDimensions) - 1 counterDim := currentDim for { r, _, err = buf.ReadRune() if err != nil { return nil, fmt.Errorf("invalid array: %v", err) } switch r { case '{': if currentDim == counterDim { implicitDimensions[currentDim].Length++ } currentDim++ case ',': case '}': currentDim-- if currentDim < counterDim { counterDim = currentDim } default: buf.UnreadRune() value, quoted, err := arrayParseValue(buf) if err != nil { return nil, fmt.Errorf("invalid array value: %v", err) } if currentDim == counterDim { implicitDimensions[currentDim].Length++ } dst.Quoted = append(dst.Quoted, quoted) dst.Elements = append(dst.Elements, value) } if currentDim < 0 { break } } skipWhitespace(buf) if buf.Len() > 0 { return nil, fmt.Errorf("unexpected trailing data: %v", buf.String()) } if len(dst.Elements) == 0 { } else if len(explicitDimensions) > 0 { dst.Dimensions = explicitDimensions } else { dst.Dimensions = implicitDimensions } return dst, nil } func skipWhitespace(buf *bytes.Buffer) { var r rune var err error for r, _, _ = buf.ReadRune(); unicode.IsSpace(r); r, _, _ = buf.ReadRune() { } if err != io.EOF { buf.UnreadRune() } } func arrayParseValue(buf *bytes.Buffer) (string, bool, error) { r, _, err := buf.ReadRune() if err != nil { return "", false, err } if r == '"' { return arrayParseQuotedValue(buf) } buf.UnreadRune() s := &bytes.Buffer{} for { r, _, err := buf.ReadRune() if err != nil { return "", false, err } switch r { case ',', '}': buf.UnreadRune() return s.String(), false, nil } s.WriteRune(r) } } func arrayParseQuotedValue(buf *bytes.Buffer) (string, bool, error) { s := &bytes.Buffer{} for { r, _, err := buf.ReadRune() if err != nil { return "", false, err } switch r { case '\\': r, _, err = buf.ReadRune() if err != nil { return "", false, err } case '"': r, _, err = buf.ReadRune() if err != nil { return "", false, err } buf.UnreadRune() return s.String(), true, nil } s.WriteRune(r) } } func arrayParseInteger(buf *bytes.Buffer) (int32, error) { s := &bytes.Buffer{} for { r, _, err := buf.ReadRune() if err != nil { return 0, err } if ('0' <= r && r <= '9') || r == '-' { s.WriteRune(r) } else { buf.UnreadRune() n, err := strconv.ParseInt(s.String(), 10, 32) if err != nil { return 0, err } return int32(n), nil } } } func encodeTextArrayDimensions(buf []byte, dimensions []ArrayDimension) []byte { var customDimensions bool for _, dim := range dimensions { if dim.LowerBound != 1 { customDimensions = true } } if !customDimensions { return buf } for _, dim := range dimensions { buf = append(buf, '[') buf = append(buf, strconv.FormatInt(int64(dim.LowerBound), 10)...) buf = append(buf, ':') buf = append(buf, strconv.FormatInt(int64(dim.LowerBound+dim.Length-1), 10)...) buf = append(buf, ']') } return append(buf, '=') } var quoteArrayReplacer = strings.NewReplacer(`\`, `\\`, `"`, `\"`) func quoteArrayElement(src string) string { return `"` + quoteArrayReplacer.Replace(src) + `"` } func isSpace(ch byte) bool { // see array_isspace: // https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c return ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' || ch == '\v' || ch == '\f' } func quoteArrayElementIfNeeded(src string) string { if src == "" || (len(src) == 4 && strings.EqualFold(src, "null")) || isSpace(src[0]) || isSpace(src[len(src)-1]) || strings.ContainsAny(src, `{},"\`) { return quoteArrayElement(src) } return src } func findDimensionsFromValue(value reflect.Value, dimensions []ArrayDimension, elementsLength int) ([]ArrayDimension, int, bool) { switch value.Kind() { case reflect.Array: fallthrough case reflect.Slice: length := value.Len() if 0 == elementsLength { elementsLength = length } else { elementsLength *= length } dimensions = append(dimensions, ArrayDimension{Length: int32(length), LowerBound: 1}) for i := 0; i < length; i++ { if d, l, ok := findDimensionsFromValue(value.Index(i), dimensions, elementsLength); ok { return d, l, true } } } return dimensions, elementsLength, true } // Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves // PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. type Array[T any] struct { Elements []T Dims []ArrayDimension Valid bool } func (a Array[T]) Dimensions() []ArrayDimension { return a.Dims } func (a Array[T]) Index(i int) any { return a.Elements[i] } func (a Array[T]) IndexType() any { var el T return el } func (a *Array[T]) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { *a = Array[T]{} return nil } elementCount := cardinality(dimensions) *a = Array[T]{ Elements: make([]T, elementCount), Dims: dimensions, Valid: true, } return nil } func (a Array[T]) ScanIndex(i int) any { return &a.Elements[i] } func (a Array[T]) ScanIndexType() any { return new(T) } // FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions // and custom lower bounds. Use Array to preserve these. type FlatArray[T any] []T func (a FlatArray[T]) Dimensions() []ArrayDimension { if a == nil { return nil } return []ArrayDimension{{Length: int32(len(a)), LowerBound: 1}} } func (a FlatArray[T]) Index(i int) any { return a[i] } func (a FlatArray[T]) IndexType() any { var el T return el } func (a *FlatArray[T]) SetDimensions(dimensions []ArrayDimension) error { if dimensions == nil { *a = nil return nil } elementCount := cardinality(dimensions) *a = make(FlatArray[T], elementCount) return nil } func (a FlatArray[T]) ScanIndex(i int) any { return &a[i] } func (a FlatArray[T]) ScanIndexType() any { return new(T) }