parent
4aedb95281
commit
f509f0087f
@ -1,34 +0,0 @@
|
||||
# idea ignore
|
||||
.idea/
|
||||
*.ipr
|
||||
*.iml
|
||||
*.iws
|
||||
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
# temp ignore
|
||||
*.log
|
||||
*.cache
|
||||
*.diff
|
||||
*.exe
|
||||
*.exe~
|
||||
*.patch
|
||||
*.tmp
|
||||
*debug.test
|
||||
debug.test
|
||||
go.sum
|
||||
|
||||
# system ignore
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# project
|
||||
*.cert
|
||||
*.key
|
||||
.test
|
||||
iprepo.txt
|
||||
|
||||
|
||||
_output
|
File diff suppressed because it is too large
Load Diff
@ -1,184 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Finder 查询数据库的载体,所有的sql语句都要通过Finder执行.
|
||||
// Finder To query the database carrier, all SQL statements must be executed through Finder
|
||||
type Finder struct {
|
||||
// 拼接SQL
|
||||
// Splicing SQL.
|
||||
sqlBuilder strings.Builder
|
||||
// SQL的参数值
|
||||
// SQL parameter values.
|
||||
values []interface{}
|
||||
// 注入检查,默认true 不允许SQL注入的 ' 单引号
|
||||
// Injection check, default true does not allow SQL injection single quote
|
||||
InjectionCheck bool
|
||||
// CountFinder 自定义的查询总条数'Finder',使用指针默认为nil.主要是为了在'group by'等复杂情况下,为了性能,手动编写总条数语句
|
||||
// CountFinder The total number of custom queries is'Finder', and the pointer is nil by default. It is mainly used to manually write the total number of statements for performance in complex situations such as'group by'
|
||||
CountFinder *Finder
|
||||
// 是否自动查询总条数,默认true.同时需要Page不为nil,才查询总条数
|
||||
// Whether to automatically query the total number of entries, the default is true. At the same time, the Page is not nil to query the total number of entries
|
||||
SelectTotalCount bool
|
||||
// SQL语句
|
||||
// SQL statement
|
||||
sqlstr string
|
||||
}
|
||||
|
||||
// NewFinder 初始化一个Finder,生成一个空的Finder
|
||||
// NewFinder Initialize a Finder and generate an empty Finder
|
||||
func NewFinder() *Finder {
|
||||
finder := Finder{}
|
||||
finder.sqlBuilder.Grow(stringBuilderGrowLen)
|
||||
finder.SelectTotalCount = true
|
||||
finder.InjectionCheck = true
|
||||
// slice扩容会生成新的slice,最后要值复制接收.问:为什么cap是3?答:经验
|
||||
finder.values = make([]interface{}, 0, 3)
|
||||
return &finder
|
||||
}
|
||||
|
||||
// NewSelectFinder 根据表名初始化查询的Finder,strs 只取第一个字符串,用数组类型是为了可以不传入,默认为 * | Finder that initializes the query based on the table name
|
||||
// NewSelectFinder("tableName") SELECT * FROM tableName
|
||||
// NewSelectFinder("tableName", "id,name") SELECT id,name FROM tableName
|
||||
func NewSelectFinder(tableName string, strs ...string) *Finder {
|
||||
strsLen := len(strs)
|
||||
if strsLen > 1 { // 不支持多个参数
|
||||
return nil
|
||||
}
|
||||
finder := NewFinder()
|
||||
finder.sqlBuilder.WriteString("SELECT ")
|
||||
if strsLen == 1 { // 只取值第一个字符串
|
||||
finder.sqlBuilder.WriteString(strs[0])
|
||||
} else {
|
||||
finder.sqlBuilder.WriteByte('*')
|
||||
}
|
||||
finder.sqlBuilder.WriteString(" FROM ")
|
||||
finder.sqlBuilder.WriteString(tableName)
|
||||
return finder
|
||||
}
|
||||
|
||||
// NewUpdateFinder 根据表名初始化更新的Finder, UPDATE tableName SET
|
||||
// NewUpdateFinder Initialize the updated Finder according to the table name, UPDATE tableName SET
|
||||
func NewUpdateFinder(tableName string) *Finder {
|
||||
finder := NewFinder()
|
||||
finder.sqlBuilder.WriteString("UPDATE ")
|
||||
finder.sqlBuilder.WriteString(tableName)
|
||||
finder.sqlBuilder.WriteString(" SET ")
|
||||
return finder
|
||||
}
|
||||
|
||||
// NewDeleteFinder 根据表名初始化删除的'Finder', DELETE FROM tableName
|
||||
// NewDeleteFinder Finder for initial deletion based on table name. DELETE FROM tableName
|
||||
func NewDeleteFinder(tableName string) *Finder {
|
||||
finder := NewFinder()
|
||||
finder.sqlBuilder.WriteString("DELETE FROM ")
|
||||
finder.sqlBuilder.WriteString(tableName)
|
||||
// 所有的 WHERE 都不加,规则统一,好记
|
||||
// No WHERE is added, the rules are unified, easy to remember
|
||||
// finder.sqlBuilder.WriteString(" WHERE ")
|
||||
return finder
|
||||
}
|
||||
|
||||
// Append 添加SQL和参数的值,第一个参数是语句,后面的参数[可选]是参数的值,顺序要正确
|
||||
// 例如: finder.Append(" and id=? and name=? ",23123,"abc")
|
||||
// 只拼接SQL,例如: finder.Append(" and name=123 ")
|
||||
// Append:Add SQL and parameter values, the first parameter is the statement, and the following parameter (optional) is the value of the parameter, in the correct order
|
||||
// E.g: finder.Append(" and id=? and name=? ",23123,"abc")
|
||||
// Only splice SQL, E.g: finder.Append(" and name=123 ")
|
||||
func (finder *Finder) Append(s string, values ...interface{}) *Finder {
|
||||
// 不要自己构建finder,使用NewFinder()方法
|
||||
// Don't build finder by yourself, use NewFinder() method
|
||||
if finder == nil || finder.values == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s != "" {
|
||||
if finder.sqlstr != "" {
|
||||
finder.sqlstr = ""
|
||||
}
|
||||
// 默认加一个空格,避免手误两个字符串连接再一起
|
||||
// A space is added by default to avoid hand mistakes when connecting two strings together
|
||||
finder.sqlBuilder.WriteByte(' ')
|
||||
|
||||
finder.sqlBuilder.WriteString(s)
|
||||
|
||||
}
|
||||
if values == nil || len(values) < 1 {
|
||||
return finder
|
||||
}
|
||||
|
||||
finder.values = append(finder.values, values...)
|
||||
return finder
|
||||
}
|
||||
|
||||
// AppendFinder 添加另一个Finder finder.AppendFinder(f)
|
||||
// AppendFinder Add another Finder . finder.AppendFinder(f)
|
||||
func (finder *Finder) AppendFinder(f *Finder) (*Finder, error) {
|
||||
if finder == nil {
|
||||
return finder, errors.New("->finder-->AppendFinder()finder对象为nil")
|
||||
}
|
||||
if f == nil {
|
||||
return finder, errors.New("->finder-->AppendFinder()参数是nil")
|
||||
}
|
||||
|
||||
// 不要自己构建finder,使用NewFinder()方法
|
||||
// Don't build finder by yourself, use NewFinder() method
|
||||
if finder.values == nil {
|
||||
return finder, errors.New("->finder-->AppendFinder()不要自己构建finder,使用NewFinder()方法")
|
||||
}
|
||||
|
||||
// 添加f的SQL
|
||||
// SQL to add f
|
||||
sqlstr, err := f.GetSQL()
|
||||
if err != nil {
|
||||
return finder, err
|
||||
}
|
||||
finder.sqlstr = ""
|
||||
finder.sqlBuilder.WriteString(sqlstr)
|
||||
// 添加f的值
|
||||
// Add the value of f
|
||||
finder.values = append(finder.values, f.values...)
|
||||
return finder, nil
|
||||
}
|
||||
|
||||
// GetSQL 返回Finder封装的SQL语句
|
||||
// GetSQL Return the SQL statement encapsulated by the Finder
|
||||
func (finder *Finder) GetSQL() (string, error) {
|
||||
// 不要自己构建finder,使用NewFinder方法
|
||||
// Don't build finder by yourself, use NewFinder method
|
||||
if finder == nil || finder.values == nil {
|
||||
return "", errors.New("->finder-->GetSQL()不要自己构建finder,使用NewFinder()方法")
|
||||
}
|
||||
if len(finder.sqlstr) > 0 {
|
||||
return finder.sqlstr, nil
|
||||
}
|
||||
sqlstr := finder.sqlBuilder.String()
|
||||
// 包含单引号,属于非法字符串
|
||||
// Contains single quotes, which are illegal strings
|
||||
if finder.InjectionCheck && (strings.Contains(sqlstr, "'")) {
|
||||
return "", errors.New(`->finder-->GetSQL()SQL语句请不要直接拼接字符串参数,容易注入!!!请使用问号占位符,例如 finder.Append("and id=?","stringId"),如果必须拼接字符串,请设置 finder.InjectionCheck = false `)
|
||||
}
|
||||
finder.sqlstr = sqlstr
|
||||
return sqlstr, nil
|
||||
}
|
@ -1,141 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// customDriverValueMap 用于配置数据库字段类型的处理关系,key是 Dialect.字段类型,例如 dm.TEXT
|
||||
var customDriverValueMap = make(map[string]ICustomDriverValueConver)
|
||||
|
||||
// iscdvm 是否有自定义的DriverValueMap
|
||||
var iscdvm bool
|
||||
|
||||
// ICustomDriverValueConver 自定义类型转化接口,用于解决 类似达梦 text --> dm.DmClob --> string类型接收的问题
|
||||
type ICustomDriverValueConver interface {
|
||||
// GetDriverValue 根据数据库列类型,返回driver.Value的实例,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
GetDriverValue(ctx context.Context, columnType *sql.ColumnType, structFieldType *reflect.Type) (driver.Value, error)
|
||||
|
||||
// ConverDriverValue 数据库列类型,GetDriverValue返回的driver.Value的临时接收值,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
// 返回符合接收类型值的指针,指针,指针!!!!
|
||||
ConverDriverValue(ctx context.Context, columnType *sql.ColumnType, tempDriverValue driver.Value, structFieldType *reflect.Type) (interface{}, error)
|
||||
}
|
||||
|
||||
// RegisterCustomDriverValueConver 注册自定义的字段处理逻辑,用于驱动无法直接转换的场景,例如达梦的 TEXT 无法直接转化成 string
|
||||
// dialectColumnType 值是 Dialect.字段类型,例如: dm.TEXT
|
||||
// 一般是放到init方法里进行注册
|
||||
func RegisterCustomDriverValueConver(dialectColumnType string, customDriverValueConver ICustomDriverValueConver) error {
|
||||
if len(dialectColumnType) < 1 {
|
||||
return errors.New("->RegisterCustomDriverValueConver-->dialectColumnType为空")
|
||||
}
|
||||
dialectColumnTypes := strings.Split(dialectColumnType, ".")
|
||||
if len(dialectColumnTypes) < 2 {
|
||||
customDriverValueMap[strings.ToUpper(dialectColumnType)] = customDriverValueConver
|
||||
err := errors.New("->RegisterCustomDriverValueConver-->dialectColumnType 值是 Dialect.字段类型,例如: dm.TEXT ,本次正常运行,请尽快修改")
|
||||
FuncLogError(nil, err)
|
||||
} else {
|
||||
customDriverValueMap[strings.ToLower(dialectColumnTypes[0])+"."+strings.ToUpper(dialectColumnTypes[1])] = customDriverValueConver
|
||||
}
|
||||
iscdvm = true
|
||||
return nil
|
||||
}
|
||||
|
||||
type driverValueInfo struct {
|
||||
customDriverValueConver ICustomDriverValueConver
|
||||
columnType *sql.ColumnType
|
||||
tempDriverValue interface{}
|
||||
structFieldType *reflect.Type
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
import (
|
||||
// 00.引入数据库驱动
|
||||
"gitee.com/chunanyong/dm"
|
||||
"io"
|
||||
)
|
||||
|
||||
// CustomDMText 实现ICustomDriverValueConver接口,扩展自定义类型,例如 达梦数据库TEXT类型,映射出来的是dm.DmClob类型,无法使用string类型直接接收
|
||||
type CustomDMText struct{}
|
||||
|
||||
// GetDriverValue 根据数据库列类型,返回driver.Value的实例,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
func (dmtext CustomDMText) GetDriverValue(ctx context.Context, columnType *sql.ColumnType, structFieldType *reflect.Type) (driver.Value, error) {
|
||||
// 如果需要使用structFieldType,需要先判断是否为nil
|
||||
// if structFieldType != nil {
|
||||
// }
|
||||
|
||||
return &dm.DmClob{}, nil
|
||||
}
|
||||
|
||||
// ConverDriverValue 数据库列类型,GetDriverValue返回的driver.Value的临时接收值,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
// 返回符合接收类型值的指针,指针,指针!!!!
|
||||
func (dmtext CustomDMText) ConverDriverValue(ctx context.Context, columnType *sql.ColumnType, tempDriverValue driver.Value, structFieldType *reflect.Type) (interface{}, error) {
|
||||
// 如果需要使用structFieldType,需要先判断是否为nil
|
||||
// if structFieldType != nil {
|
||||
// }
|
||||
|
||||
// 类型转换
|
||||
dmClob, isok := tempDriverValue.(*dm.DmClob)
|
||||
if !isok {
|
||||
return tempDriverValue, errors.New("->ConverDriverValue-->转换至*dm.DmClob类型失败")
|
||||
}
|
||||
if dmClob == nil || !dmClob.Valid {
|
||||
return new(string), nil
|
||||
}
|
||||
// 获取长度
|
||||
dmlen, errLength := dmClob.GetLength()
|
||||
if errLength != nil {
|
||||
return dmClob, errLength
|
||||
}
|
||||
|
||||
// int64转成int类型
|
||||
strInt64 := strconv.FormatInt(dmlen, 10)
|
||||
dmlenInt, errAtoi := strconv.Atoi(strInt64)
|
||||
if errAtoi != nil {
|
||||
return dmClob, errAtoi
|
||||
}
|
||||
|
||||
// 读取字符串
|
||||
str, errReadString := dmClob.ReadString(1, dmlenInt)
|
||||
|
||||
// 处理空字符串或NULL造成的EOF错误
|
||||
if errReadString == io.EOF {
|
||||
return new(string), nil
|
||||
}
|
||||
|
||||
return &str, errReadString
|
||||
}
|
||||
// RegisterCustomDriverValueConver 注册自定义的字段处理逻辑,用于驱动无法直接转换的场景,例如达梦的 TEXT 无法直接转化成 string
|
||||
// 一般是放到init方法里进行注册
|
||||
func init() {
|
||||
// dialectColumnType 值是 Dialect.字段类型 ,例如 dm.TEXT
|
||||
zorm.RegisterCustomDriverValueConver("dm.TEXT", CustomDMText{})
|
||||
}
|
||||
|
||||
**/
|
@ -1,160 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
// IEntityStruct "struct"实体类的接口,所有的struct实体类都要实现这个接口
|
||||
// IEntityStruct The interface of the "struct" entity class, all struct entity classes must implement this interface
|
||||
type IEntityStruct interface {
|
||||
// 获取表名称
|
||||
// Get the table name.
|
||||
GetTableName() string
|
||||
|
||||
// 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
|
||||
// Get the primary key field name of the database table. Because it is compatible with Map, it can only be the field name of the database
|
||||
GetPKColumnName() string
|
||||
|
||||
// GetPkSequence 主键序列
|
||||
// GetPkSequence Primary key sequence
|
||||
GetPkSequence() string
|
||||
}
|
||||
|
||||
// IEntityMap 使用Map保存数据,用于不方便使用struct的场景,如果主键是自增或者序列,不要"entityMap.Set"主键的值
|
||||
// IEntityMap Use Map to save data for scenarios where it is not convenient to use struct
|
||||
// If the primary key is auto-increment or sequence, do not "entity Map.Set" the value of the primary key
|
||||
type IEntityMap interface {
|
||||
// 获取表名称
|
||||
// Get the table name
|
||||
GetTableName() string
|
||||
|
||||
// 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称.
|
||||
// Get the primary key field name of the database table. Because it is compatible with Map, it can only be the field name of the database.
|
||||
GetPKColumnName() string
|
||||
|
||||
// GetEntityMapPkSequence 主键序列,不能使用GetPkSequence方法名,避免默认实现了IEntityStruct接口
|
||||
// GetEntityMapPkSequence primary key sequence, you cannot use the GetPkSequence method name, to avoid the default implementation of IEntityStruct interface
|
||||
GetEntityMapPkSequence() string
|
||||
|
||||
// GetDBFieldMap 针对Map类型,记录数据库字段
|
||||
// GetDBFieldMap For Map type, record database fields.
|
||||
GetDBFieldMap() map[string]interface{}
|
||||
|
||||
// GetDBFieldMapKey 按照Set的先后顺序记录key值,也就是数据库字段,用于SQL排序
|
||||
// GetDBFieldMapKey records the key value, that is, the database field, in the order of the Set, which is used for SQL sorting
|
||||
GetDBFieldMapKey() []string
|
||||
// 设置数据库字段的值
|
||||
// Set the value of a database field.
|
||||
Set(key string, value interface{}) map[string]interface{}
|
||||
}
|
||||
|
||||
// EntityStruct "IBaseEntity" 的基础实现,所有的实体类都匿名注入.这样就类似实现继承了,如果接口增加方法,调整这个默认实现即可
|
||||
// EntityStruct The basic implementation of "IBaseEntity", all entity classes are injected anonymously
|
||||
// This is similar to implementation inheritance. If the interface adds methods, adjust the default implementation
|
||||
type EntityStruct struct{}
|
||||
|
||||
// 默认数据库的主键列名
|
||||
// Primary key column name of the default database
|
||||
const defaultPkName = "id"
|
||||
|
||||
//GetTableName 获取表名称,必须有具体的Struct实现,类似java的抽象方法,避免手误忘记写表名.如果有扩展需求,建议使用接口进行扩展,不要默认实现GetTableName
|
||||
/*
|
||||
func (entity *EntityStruct) GetTableName() string {
|
||||
return ""
|
||||
}
|
||||
*/
|
||||
|
||||
// GetPKColumnName 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
|
||||
// GetPKColumnName Get the primary key field name of the database table
|
||||
// Because it is compatible with Map, it can only be the field name of the database
|
||||
func (entity *EntityStruct) GetPKColumnName() string {
|
||||
return defaultPkName
|
||||
}
|
||||
|
||||
// var defaultPkSequence = make(map[string]string, 0)
|
||||
|
||||
// GetPkSequence 主键序列
|
||||
// GetPkSequence Primary key sequence
|
||||
func (entity *EntityStruct) GetPkSequence() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------//
|
||||
|
||||
// EntityMap IEntityMap的基础实现,可以直接使用或者匿名注入
|
||||
type EntityMap struct {
|
||||
// 表名
|
||||
tableName string
|
||||
// 主键列名
|
||||
PkColumnName string
|
||||
// 主键序列,如果有值,优先级最高
|
||||
PkSequence string
|
||||
// 数据库字段,不暴露外部
|
||||
dbFieldMap map[string]interface{}
|
||||
// 列名,记录顺序
|
||||
dbFieldMapKey []string
|
||||
}
|
||||
|
||||
// NewEntityMap 初始化Map,必须传入表名称
|
||||
func NewEntityMap(tbName string) *EntityMap {
|
||||
entityMap := EntityMap{}
|
||||
entityMap.dbFieldMap = map[string]interface{}{}
|
||||
entityMap.tableName = tbName
|
||||
entityMap.PkColumnName = defaultPkName
|
||||
entityMap.dbFieldMapKey = make([]string, 0)
|
||||
return &entityMap
|
||||
}
|
||||
|
||||
// GetTableName 获取表名称
|
||||
func (entity *EntityMap) GetTableName() string {
|
||||
return entity.tableName
|
||||
}
|
||||
|
||||
// GetPKColumnName 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
|
||||
func (entity *EntityMap) GetPKColumnName() string {
|
||||
return entity.PkColumnName
|
||||
}
|
||||
|
||||
// GetEntityMapPkSequence 主键序列,不能使用GetPkSequence方法名,避免默认实现了IEntityStruct接口
|
||||
// GetEntityMapPkSequence primary key sequence, you cannot use the GetPkSequence method name, to avoid the default implementation of IEntityStruct interface
|
||||
func (entity *EntityMap) GetEntityMapPkSequence() string {
|
||||
return entity.PkSequence
|
||||
}
|
||||
|
||||
// GetDBFieldMap 针对Map类型,记录数据库字段
|
||||
// GetDBFieldMap For Map type, record database fields
|
||||
func (entity *EntityMap) GetDBFieldMap() map[string]interface{} {
|
||||
return entity.dbFieldMap
|
||||
}
|
||||
|
||||
// GetDBFieldMapKey 按照Set的先后顺序记录key值,也就是数据库字段,用于SQL排序
|
||||
// GetDBFieldMapKey records the key value, that is, the database field, in the order of the Set, which is used for SQL sorting
|
||||
func (entity *EntityMap) GetDBFieldMapKey() []string {
|
||||
return entity.dbFieldMapKey
|
||||
}
|
||||
|
||||
// Set 设置数据库字段
|
||||
// Set Set database fields
|
||||
func (entity *EntityMap) Set(key string, value interface{}) map[string]interface{} {
|
||||
_, ok := entity.dbFieldMap[key]
|
||||
if !ok { // 如果不存在
|
||||
entity.dbFieldMapKey = append(entity.dbFieldMapKey, key)
|
||||
}
|
||||
entity.dbFieldMap[key] = value
|
||||
|
||||
return entity.dbFieldMap
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import "context"
|
||||
|
||||
// IGlobalTransaction 托管全局分布式事务接口
|
||||
type IGlobalTransaction interface {
|
||||
// BeginGTX 开启全局分布式事务
|
||||
BeginGTX(ctx context.Context, globalRootContext context.Context) error
|
||||
|
||||
// CommitGTX 提交全局分布式事务.不能命名为 Commit,不然就和gtx的Commit一致了,就递归调用自己了.......
|
||||
CommitGTX(ctx context.Context, globalRootContext context.Context) error
|
||||
|
||||
// RollbackGTX 回滚全局分布式事务
|
||||
RollbackGTX(ctx context.Context, globalRootContext context.Context) error
|
||||
|
||||
// GetGTXID 获取全局分布式事务的XID
|
||||
GetGTXID(ctx context.Context, globalRootContext context.Context) (string, error)
|
||||
|
||||
// 重新包装为 seata/hptx 的context.RootContext
|
||||
// context.RootContext 如果后续使用了 context.WithValue,类型就是context.valueCtx 就会造成无法再类型断言为 context.RootContext
|
||||
// 所以DBDao里使用了 globalRootContext变量,区分业务的ctx和分布式事务的RootContext
|
||||
// NewRootContext(ctx context.Context) context.Context
|
||||
}
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,63 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 设置默认的日志显示信息,显示文件和行号
|
||||
// Set the default log display information, display file and line number.
|
||||
log.SetFlags(log.Llongfile | log.LstdFlags)
|
||||
}
|
||||
|
||||
// LogCallDepth 记录日志调用层级,用于定位到业务层代码
|
||||
// Log Call Depth Record the log call level, used to locate the business layer code
|
||||
var LogCallDepth = 4
|
||||
|
||||
// FuncLogError 记录error日志.NewDBDao方法里的异常,ctx为nil,扩展时请注意
|
||||
// FuncLogError Record error log
|
||||
var FuncLogError func(ctx context.Context, err error) = defaultLogError
|
||||
|
||||
// FuncLogPanic 记录panic日志,默认使用"defaultLogError"实现
|
||||
// FuncLogPanic Record panic log, using "defaultLogError" by default
|
||||
var FuncLogPanic func(ctx context.Context, err error) = defaultLogPanic
|
||||
|
||||
// FuncPrintSQL 打印sql语句,参数和执行时间,小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
// FuncPrintSQL Print sql statement and parameters
|
||||
var FuncPrintSQL func(ctx context.Context, sqlstr string, args []interface{}, execSQLMillis int64) = defaultPrintSQL
|
||||
|
||||
func defaultLogError(ctx context.Context, err error) {
|
||||
log.Output(LogCallDepth, fmt.Sprintln(err))
|
||||
}
|
||||
|
||||
func defaultLogPanic(ctx context.Context, err error) {
|
||||
defaultLogError(ctx, err)
|
||||
}
|
||||
|
||||
func defaultPrintSQL(ctx context.Context, sqlstr string, args []interface{}, execSQLMillis int64) {
|
||||
if args != nil {
|
||||
log.Output(LogCallDepth, fmt.Sprintln("sql:", sqlstr, ",args:", args, ",execSQLMillis:", execSQLMillis))
|
||||
} else {
|
||||
log.Output(LogCallDepth, fmt.Sprintln("sql:", sqlstr, ",args: [] ", ",execSQLMillis:", execSQLMillis))
|
||||
}
|
||||
}
|
@ -1,81 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
// Page 分页对象
|
||||
// Page Pagination object
|
||||
type Page struct {
|
||||
// 当前页码,从1开始
|
||||
// Current page number, starting from 1
|
||||
PageNo int
|
||||
|
||||
// 每页多少条,默认20条
|
||||
// How many items per page, 20 items by default
|
||||
PageSize int
|
||||
|
||||
// 数据总条数
|
||||
// Total number of data
|
||||
TotalCount int
|
||||
|
||||
// 共多少页
|
||||
// How many pages
|
||||
PageCount int
|
||||
|
||||
// 是否是第一页
|
||||
// Is it the first page
|
||||
FirstPage bool
|
||||
|
||||
// 是否有上一页
|
||||
// Whether there is a previous page
|
||||
HasPrev bool
|
||||
|
||||
// 是否有下一页
|
||||
// Is there a next page
|
||||
HasNext bool
|
||||
|
||||
// 是否是最后一页
|
||||
// Is it the last page
|
||||
LastPage bool
|
||||
}
|
||||
|
||||
// NewPage 创建Page对象
|
||||
// NewPage Create Page object
|
||||
func NewPage() *Page {
|
||||
page := Page{}
|
||||
page.PageNo = 1
|
||||
page.PageSize = 20
|
||||
return &page
|
||||
}
|
||||
|
||||
// setTotalCount 设置总条数,计算其他值
|
||||
// setTotalCount Set the total number of bars, calculate other values
|
||||
func (page *Page) setTotalCount(total int) {
|
||||
page.TotalCount = total
|
||||
page.PageCount = (page.TotalCount + page.PageSize - 1) / page.PageSize
|
||||
if page.PageNo >= page.PageCount {
|
||||
page.LastPage = true
|
||||
} else {
|
||||
page.HasNext = true
|
||||
}
|
||||
if page.PageNo > 1 {
|
||||
page.HasPrev = true
|
||||
} else {
|
||||
page.FirstPage = true
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,318 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// dataSorce对象,隔离sql原生对象
|
||||
// dataSorce Isolate sql native objects
|
||||
type dataSource struct {
|
||||
*sql.DB
|
||||
// config *DataSourceConfig
|
||||
}
|
||||
|
||||
// newDataSource 创建一个新的datasource,内部调用,避免外部直接使用datasource
|
||||
// newDAtaSource Create a new datasource and call it internally to avoid direct external use of the datasource
|
||||
func newDataSource(config *DataSourceConfig) (*dataSource, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("->newDataSource-->config cannot be nil")
|
||||
}
|
||||
|
||||
if config.DriverName == "" {
|
||||
return nil, errors.New("->newDataSource-->DriverName cannot be empty")
|
||||
}
|
||||
// 兼容处理,DBType即将废弃,请使用Dialect属性
|
||||
if config.DBType != "" && config.Dialect == "" {
|
||||
FuncLogError(nil, errors.New("->newDataSource-->DataSourceConfig的DBType即将废弃,请使用Dialect属性"))
|
||||
config.Dialect = config.DBType
|
||||
}
|
||||
if config.Dialect == "" {
|
||||
return nil, errors.New("->newDataSource-->Dialect cannot be empty")
|
||||
}
|
||||
var db *sql.DB
|
||||
var errSQLOpen error
|
||||
|
||||
if config.SQLDB == nil { // 没有已经存在的数据库连接,使用DSN初始化
|
||||
if config.DSN == "" {
|
||||
return nil, errors.New("->newDataSource-->DSN cannot be empty")
|
||||
}
|
||||
db, errSQLOpen = sql.Open(config.DriverName, config.DSN)
|
||||
if errSQLOpen != nil {
|
||||
errSQLOpen = fmt.Errorf("->newDataSource-->open数据库打开失败:%w", errSQLOpen)
|
||||
FuncLogError(nil, errSQLOpen)
|
||||
return nil, errSQLOpen
|
||||
}
|
||||
} else { // 使用已经存在的数据库连接
|
||||
db = config.SQLDB
|
||||
}
|
||||
|
||||
if config.MaxOpenConns == 0 {
|
||||
config.MaxOpenConns = 50
|
||||
}
|
||||
if config.MaxIdleConns == 0 {
|
||||
config.MaxIdleConns = 50
|
||||
}
|
||||
|
||||
if config.ConnMaxLifetimeSecond == 0 {
|
||||
config.ConnMaxLifetimeSecond = 600
|
||||
}
|
||||
|
||||
// 设置数据库最大连接数
|
||||
// Set the maximum number of database connections
|
||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
||||
// 设置数据库最大空闲连接数
|
||||
// Set the maximum number of free connections to the database
|
||||
db.SetMaxIdleConns(config.MaxIdleConns)
|
||||
//连接存活秒时间. 默认600(10分钟)后连接被销毁重建.避免数据库主动断开连接,造成死连接.MySQL默认wait_timeout 28800秒(8小时)
|
||||
//(Connection survival time in seconds) Destroy and rebuild the connection after the default 600 seconds (10 minutes)
|
||||
//Prevent the database from actively disconnecting and causing dead connections. MySQL Default wait_timeout 28800 seconds
|
||||
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetimeSecond))
|
||||
|
||||
// 验证连接
|
||||
if pingerr := db.Ping(); pingerr != nil {
|
||||
pingerr = fmt.Errorf("->newDataSource-->ping数据库失败:%w", pingerr)
|
||||
FuncLogError(nil, pingerr)
|
||||
db.Close()
|
||||
return nil, pingerr
|
||||
}
|
||||
|
||||
return &dataSource{db}, nil
|
||||
}
|
||||
|
||||
// 事务参照:https://www.jianshu.com/p/2a144332c3db
|
||||
// Transaction reference: https://www.jianshu.com/p/2a144332c3db
|
||||
|
||||
// dataBaseConnection 数据库dbConnection会话,可以原生查询或者事务
|
||||
// dataBaseConnection Database session, native query or transaction.
|
||||
type dataBaseConnection struct {
|
||||
// 原生db
|
||||
// native db
|
||||
db *sql.DB
|
||||
|
||||
// 原生事务
|
||||
// native transaction
|
||||
tx *sql.Tx
|
||||
|
||||
// 数据库配置
|
||||
config *DataSourceConfig
|
||||
}
|
||||
|
||||
// beginTx 开启事务
|
||||
// beginTx Open transaction
|
||||
func (dbConnection *dataBaseConnection) beginTx(ctx context.Context) error {
|
||||
if dbConnection.tx != nil {
|
||||
return nil
|
||||
}
|
||||
// 设置事务配置,主要是隔离级别
|
||||
var txOptions *sql.TxOptions
|
||||
contextTxOptions := ctx.Value(contextTxOptionsKey)
|
||||
if contextTxOptions != nil {
|
||||
txOptions, _ = contextTxOptions.(*sql.TxOptions)
|
||||
} else {
|
||||
txOptions = dbConnection.config.DefaultTxOptions
|
||||
}
|
||||
|
||||
tx, err := dbConnection.db.BeginTx(ctx, txOptions)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->beginTx事务开启失败:%w", err)
|
||||
return err
|
||||
}
|
||||
dbConnection.tx = tx
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollback 回滚事务
|
||||
// rollback Rollback transaction
|
||||
func (dbConnection *dataBaseConnection) rollback() error {
|
||||
if dbConnection.tx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := dbConnection.tx.Rollback()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->rollback事务回滚失败:%w", err)
|
||||
return err
|
||||
}
|
||||
dbConnection.tx = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// commit 提交事务
|
||||
// commit Commit transaction
|
||||
func (dbConnection *dataBaseConnection) commit() error {
|
||||
if dbConnection.tx == nil {
|
||||
return errors.New("->dbConnection.commit()事务为空")
|
||||
}
|
||||
|
||||
err := dbConnection.tx.Commit()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->dbConnection.commit()事务提交失败:%w", err)
|
||||
return err
|
||||
}
|
||||
dbConnection.tx = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// execContext 执行sql语句,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
// execContext Execute sql statement,If the transaction has been opened,it will be executed in transaction mode, if the transaction is not opened,it will be executed in non-transactional mode
|
||||
func (dbConnection *dataBaseConnection) execContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Result, error) {
|
||||
// reBindSQL 重新处理参数代入方式
|
||||
execsql, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 更新语句处理ClickHouse特殊语法
|
||||
err = reUpdateSQL(dbConnection.config.Dialect, execsql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 执行前加入 hint
|
||||
err = wrapSQLHint(ctx, execsql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var start *time.Time
|
||||
var res sql.Result
|
||||
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
slowSQLMillis := dbConnection.config.SlowSQLMillis
|
||||
if slowSQLMillis == 0 {
|
||||
FuncPrintSQL(ctx, *execsql, *args, 0)
|
||||
} else if slowSQLMillis > 0 {
|
||||
now := time.Now() // 获取当前时间
|
||||
start = &now
|
||||
}
|
||||
if dbConnection.tx != nil {
|
||||
res, err = dbConnection.tx.ExecContext(ctx, *execsql, *args...)
|
||||
} else {
|
||||
res, err = dbConnection.db.ExecContext(ctx, *execsql, *args...)
|
||||
}
|
||||
if slowSQLMillis > 0 {
|
||||
slow := time.Since(*start).Milliseconds()
|
||||
if slow-int64(slowSQLMillis) >= 0 {
|
||||
FuncPrintSQL(ctx, *execsql, *args, slow)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->execContext执行错误:%w,-->zormErrorExecSQL:%s,-->zormErrorSQLValues:%v", err, *execsql, *args)
|
||||
}
|
||||
return &res, err
|
||||
}
|
||||
|
||||
// queryRowContext 如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
func (dbConnection *dataBaseConnection) queryRowContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Row, error) {
|
||||
// reBindSQL 重新处理参数代入方式
|
||||
query, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 执行前加入 hint
|
||||
err = wrapSQLHint(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var start *time.Time
|
||||
var row *sql.Row
|
||||
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
slowSQLMillis := dbConnection.config.SlowSQLMillis
|
||||
if slowSQLMillis == 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, 0)
|
||||
} else if slowSQLMillis > 0 {
|
||||
now := time.Now() // 获取当前时间
|
||||
start = &now
|
||||
}
|
||||
|
||||
if dbConnection.tx != nil {
|
||||
row = dbConnection.tx.QueryRowContext(ctx, *query, *args...)
|
||||
} else {
|
||||
row = dbConnection.db.QueryRowContext(ctx, *query, *args...)
|
||||
}
|
||||
if slowSQLMillis > 0 {
|
||||
slow := time.Since(*start).Milliseconds()
|
||||
if slow-int64(slowSQLMillis) >= 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, slow)
|
||||
}
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
|
||||
// queryContext 查询数据,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
// queryRowContext Execute sql row statement,If the transaction has been opened,it will be executed in transaction mode, if the transaction is not opened,it will be executed in non-transactional mode
|
||||
func (dbConnection *dataBaseConnection) queryContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Rows, error) {
|
||||
// reBindSQL 重新处理参数代入方式
|
||||
query, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 执行前加入 hint
|
||||
err = wrapSQLHint(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var start *time.Time
|
||||
var rows *sql.Rows
|
||||
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
slowSQLMillis := dbConnection.config.SlowSQLMillis
|
||||
if slowSQLMillis == 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, 0)
|
||||
} else if slowSQLMillis > 0 {
|
||||
now := time.Now() // 获取当前时间
|
||||
start = &now
|
||||
}
|
||||
|
||||
if dbConnection.tx != nil {
|
||||
rows, err = dbConnection.tx.QueryContext(ctx, *query, *args...)
|
||||
} else {
|
||||
rows, err = dbConnection.db.QueryContext(ctx, *query, *args...)
|
||||
}
|
||||
if slowSQLMillis > 0 {
|
||||
slow := time.Since(*start).Milliseconds()
|
||||
if slow-int64(slowSQLMillis) >= 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, slow)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->queryContext执行错误:%w,-->zormErrorExecSQL:%s,-->zormErrorSQLValues:%v", err, *query, *args)
|
||||
}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
/*
|
||||
// prepareContext 预执行,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
// prepareContext Pre-execution,If the transaction has been opened,it will be executed in transaction mode,if the transaction is not opened,it will be executed in non-transactional mode
|
||||
func (dbConnection *dataBaseConnection) prepareContext(ctx context.Context, query *string) (*sql.Stmt, error) {
|
||||
//打印SQL
|
||||
//print SQL
|
||||
if dbConnection.config.PrintSQL {
|
||||
//logger.Info("printSQL", logger.String("sql", query))
|
||||
FuncPrintSQL(ctx,*query, nil)
|
||||
}
|
||||
|
||||
if dbConnection.tx != nil {
|
||||
return dbConnection.tx.PrepareContext(ctx, *query)
|
||||
}
|
||||
|
||||
return dbConnection.db.PrepareContext(ctx, *query)
|
||||
}
|
||||
*/
|
@ -1,415 +0,0 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Multiprecision decimal numbers.
|
||||
// For floating-point formatting only; not general purpose.
|
||||
// Only operations are assign and (binary) left/right shift.
|
||||
// Can do binary floating point in multiprecision decimal precisely
|
||||
// because 2 divides 10; cannot do decimal floating point
|
||||
// in multiprecision binary precisely.
|
||||
|
||||
package decimal
|
||||
|
||||
type decimal struct {
|
||||
d [800]byte // digits, big-endian representation
|
||||
nd int // number of digits used
|
||||
dp int // decimal point
|
||||
neg bool // negative flag
|
||||
trunc bool // discarded nonzero digits beyond d[:nd]
|
||||
}
|
||||
|
||||
func (a *decimal) String() string {
|
||||
n := 10 + a.nd
|
||||
if a.dp > 0 {
|
||||
n += a.dp
|
||||
}
|
||||
if a.dp < 0 {
|
||||
n += -a.dp
|
||||
}
|
||||
|
||||
buf := make([]byte, n)
|
||||
w := 0
|
||||
switch {
|
||||
case a.nd == 0:
|
||||
return "0"
|
||||
|
||||
case a.dp <= 0:
|
||||
// zeros fill space between decimal point and digits
|
||||
buf[w] = '0'
|
||||
w++
|
||||
buf[w] = '.'
|
||||
w++
|
||||
w += digitZero(buf[w : w+-a.dp])
|
||||
w += copy(buf[w:], a.d[0:a.nd])
|
||||
|
||||
case a.dp < a.nd:
|
||||
// decimal point in middle of digits
|
||||
w += copy(buf[w:], a.d[0:a.dp])
|
||||
buf[w] = '.'
|
||||
w++
|
||||
w += copy(buf[w:], a.d[a.dp:a.nd])
|
||||
|
||||
default:
|
||||
// zeros fill space between digits and decimal point
|
||||
w += copy(buf[w:], a.d[0:a.nd])
|
||||
w += digitZero(buf[w : w+a.dp-a.nd])
|
||||
}
|
||||
return string(buf[0:w])
|
||||
}
|
||||
|
||||
func digitZero(dst []byte) int {
|
||||
for i := range dst {
|
||||
dst[i] = '0'
|
||||
}
|
||||
return len(dst)
|
||||
}
|
||||
|
||||
// trim trailing zeros from number.
|
||||
// (They are meaningless; the decimal point is tracked
|
||||
// independent of the number of digits.)
|
||||
func trim(a *decimal) {
|
||||
for a.nd > 0 && a.d[a.nd-1] == '0' {
|
||||
a.nd--
|
||||
}
|
||||
if a.nd == 0 {
|
||||
a.dp = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Assign v to a.
|
||||
func (a *decimal) Assign(v uint64) {
|
||||
var buf [24]byte
|
||||
|
||||
// Write reversed decimal in buf.
|
||||
n := 0
|
||||
for v > 0 {
|
||||
v1 := v / 10
|
||||
v -= 10 * v1
|
||||
buf[n] = byte(v + '0')
|
||||
n++
|
||||
v = v1
|
||||
}
|
||||
|
||||
// Reverse again to produce forward decimal in a.d.
|
||||
a.nd = 0
|
||||
for n--; n >= 0; n-- {
|
||||
a.d[a.nd] = buf[n]
|
||||
a.nd++
|
||||
}
|
||||
a.dp = a.nd
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Maximum shift that we can do in one pass without overflow.
|
||||
// A uint has 32 or 64 bits, and we have to be able to accommodate 9<<k.
|
||||
const uintSize = 32 << (^uint(0) >> 63)
|
||||
const maxShift = uintSize - 4
|
||||
|
||||
// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow.
|
||||
func rightShift(a *decimal, k uint) {
|
||||
r := 0 // read pointer
|
||||
w := 0 // write pointer
|
||||
|
||||
// Pick up enough leading digits to cover first shift.
|
||||
var n uint
|
||||
for ; n>>k == 0; r++ {
|
||||
if r >= a.nd {
|
||||
if n == 0 {
|
||||
// a == 0; shouldn't get here, but handle anyway.
|
||||
a.nd = 0
|
||||
return
|
||||
}
|
||||
for n>>k == 0 {
|
||||
n = n * 10
|
||||
r++
|
||||
}
|
||||
break
|
||||
}
|
||||
c := uint(a.d[r])
|
||||
n = n*10 + c - '0'
|
||||
}
|
||||
a.dp -= r - 1
|
||||
|
||||
var mask uint = (1 << k) - 1
|
||||
|
||||
// Pick up a digit, put down a digit.
|
||||
for ; r < a.nd; r++ {
|
||||
c := uint(a.d[r])
|
||||
dig := n >> k
|
||||
n &= mask
|
||||
a.d[w] = byte(dig + '0')
|
||||
w++
|
||||
n = n*10 + c - '0'
|
||||
}
|
||||
|
||||
// Put down extra digits.
|
||||
for n > 0 {
|
||||
dig := n >> k
|
||||
n &= mask
|
||||
if w < len(a.d) {
|
||||
a.d[w] = byte(dig + '0')
|
||||
w++
|
||||
} else if dig > 0 {
|
||||
a.trunc = true
|
||||
}
|
||||
n = n * 10
|
||||
}
|
||||
|
||||
a.nd = w
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Cheat sheet for left shift: table indexed by shift count giving
|
||||
// number of new digits that will be introduced by that shift.
|
||||
//
|
||||
// For example, leftcheats[4] = {2, "625"}. That means that
|
||||
// if we are shifting by 4 (multiplying by 16), it will add 2 digits
|
||||
// when the string prefix is "625" through "999", and one fewer digit
|
||||
// if the string prefix is "000" through "624".
|
||||
//
|
||||
// Credit for this trick goes to Ken.
|
||||
|
||||
type leftCheat struct {
|
||||
delta int // number of new digits
|
||||
cutoff string // minus one digit if original < a.
|
||||
}
|
||||
|
||||
var leftcheats = []leftCheat{
|
||||
// Leading digits of 1/2^i = 5^i.
|
||||
// 5^23 is not an exact 64-bit floating point number,
|
||||
// so have to use bc for the math.
|
||||
// Go up to 60 to be large enough for 32bit and 64bit platforms.
|
||||
/*
|
||||
seq 60 | sed 's/^/5^/' | bc |
|
||||
awk 'BEGIN{ print "\t{ 0, \"\" }," }
|
||||
{
|
||||
log2 = log(2)/log(10)
|
||||
printf("\t{ %d, \"%s\" },\t// * %d\n",
|
||||
int(log2*NR+1), $0, 2**NR)
|
||||
}'
|
||||
*/
|
||||
{0, ""},
|
||||
{1, "5"}, // * 2
|
||||
{1, "25"}, // * 4
|
||||
{1, "125"}, // * 8
|
||||
{2, "625"}, // * 16
|
||||
{2, "3125"}, // * 32
|
||||
{2, "15625"}, // * 64
|
||||
{3, "78125"}, // * 128
|
||||
{3, "390625"}, // * 256
|
||||
{3, "1953125"}, // * 512
|
||||
{4, "9765625"}, // * 1024
|
||||
{4, "48828125"}, // * 2048
|
||||
{4, "244140625"}, // * 4096
|
||||
{4, "1220703125"}, // * 8192
|
||||
{5, "6103515625"}, // * 16384
|
||||
{5, "30517578125"}, // * 32768
|
||||
{5, "152587890625"}, // * 65536
|
||||
{6, "762939453125"}, // * 131072
|
||||
{6, "3814697265625"}, // * 262144
|
||||
{6, "19073486328125"}, // * 524288
|
||||
{7, "95367431640625"}, // * 1048576
|
||||
{7, "476837158203125"}, // * 2097152
|
||||
{7, "2384185791015625"}, // * 4194304
|
||||
{7, "11920928955078125"}, // * 8388608
|
||||
{8, "59604644775390625"}, // * 16777216
|
||||
{8, "298023223876953125"}, // * 33554432
|
||||
{8, "1490116119384765625"}, // * 67108864
|
||||
{9, "7450580596923828125"}, // * 134217728
|
||||
{9, "37252902984619140625"}, // * 268435456
|
||||
{9, "186264514923095703125"}, // * 536870912
|
||||
{10, "931322574615478515625"}, // * 1073741824
|
||||
{10, "4656612873077392578125"}, // * 2147483648
|
||||
{10, "23283064365386962890625"}, // * 4294967296
|
||||
{10, "116415321826934814453125"}, // * 8589934592
|
||||
{11, "582076609134674072265625"}, // * 17179869184
|
||||
{11, "2910383045673370361328125"}, // * 34359738368
|
||||
{11, "14551915228366851806640625"}, // * 68719476736
|
||||
{12, "72759576141834259033203125"}, // * 137438953472
|
||||
{12, "363797880709171295166015625"}, // * 274877906944
|
||||
{12, "1818989403545856475830078125"}, // * 549755813888
|
||||
{13, "9094947017729282379150390625"}, // * 1099511627776
|
||||
{13, "45474735088646411895751953125"}, // * 2199023255552
|
||||
{13, "227373675443232059478759765625"}, // * 4398046511104
|
||||
{13, "1136868377216160297393798828125"}, // * 8796093022208
|
||||
{14, "5684341886080801486968994140625"}, // * 17592186044416
|
||||
{14, "28421709430404007434844970703125"}, // * 35184372088832
|
||||
{14, "142108547152020037174224853515625"}, // * 70368744177664
|
||||
{15, "710542735760100185871124267578125"}, // * 140737488355328
|
||||
{15, "3552713678800500929355621337890625"}, // * 281474976710656
|
||||
{15, "17763568394002504646778106689453125"}, // * 562949953421312
|
||||
{16, "88817841970012523233890533447265625"}, // * 1125899906842624
|
||||
{16, "444089209850062616169452667236328125"}, // * 2251799813685248
|
||||
{16, "2220446049250313080847263336181640625"}, // * 4503599627370496
|
||||
{16, "11102230246251565404236316680908203125"}, // * 9007199254740992
|
||||
{17, "55511151231257827021181583404541015625"}, // * 18014398509481984
|
||||
{17, "277555756156289135105907917022705078125"}, // * 36028797018963968
|
||||
{17, "1387778780781445675529539585113525390625"}, // * 72057594037927936
|
||||
{18, "6938893903907228377647697925567626953125"}, // * 144115188075855872
|
||||
{18, "34694469519536141888238489627838134765625"}, // * 288230376151711744
|
||||
{18, "173472347597680709441192448139190673828125"}, // * 576460752303423488
|
||||
{19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976
|
||||
}
|
||||
|
||||
// Is the leading prefix of b lexicographically less than s?
|
||||
func prefixIsLessThan(b []byte, s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if i >= len(b) {
|
||||
return true
|
||||
}
|
||||
if b[i] != s[i] {
|
||||
return b[i] < s[i]
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow.
|
||||
func leftShift(a *decimal, k uint) {
|
||||
delta := leftcheats[k].delta
|
||||
if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) {
|
||||
delta--
|
||||
}
|
||||
|
||||
r := a.nd // read index
|
||||
w := a.nd + delta // write index
|
||||
|
||||
// Pick up a digit, put down a digit.
|
||||
var n uint
|
||||
for r--; r >= 0; r-- {
|
||||
n += (uint(a.d[r]) - '0') << k
|
||||
quo := n / 10
|
||||
rem := n - 10*quo
|
||||
w--
|
||||
if w < len(a.d) {
|
||||
a.d[w] = byte(rem + '0')
|
||||
} else if rem != 0 {
|
||||
a.trunc = true
|
||||
}
|
||||
n = quo
|
||||
}
|
||||
|
||||
// Put down extra digits.
|
||||
for n > 0 {
|
||||
quo := n / 10
|
||||
rem := n - 10*quo
|
||||
w--
|
||||
if w < len(a.d) {
|
||||
a.d[w] = byte(rem + '0')
|
||||
} else if rem != 0 {
|
||||
a.trunc = true
|
||||
}
|
||||
n = quo
|
||||
}
|
||||
|
||||
a.nd += delta
|
||||
if a.nd >= len(a.d) {
|
||||
a.nd = len(a.d)
|
||||
}
|
||||
a.dp += delta
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Binary shift left (k > 0) or right (k < 0).
|
||||
func (a *decimal) Shift(k int) {
|
||||
switch {
|
||||
case a.nd == 0:
|
||||
// nothing to do: a == 0
|
||||
case k > 0:
|
||||
for k > maxShift {
|
||||
leftShift(a, maxShift)
|
||||
k -= maxShift
|
||||
}
|
||||
leftShift(a, uint(k))
|
||||
case k < 0:
|
||||
for k < -maxShift {
|
||||
rightShift(a, maxShift)
|
||||
k += maxShift
|
||||
}
|
||||
rightShift(a, uint(-k))
|
||||
}
|
||||
}
|
||||
|
||||
// If we chop a at nd digits, should we round up?
|
||||
func shouldRoundUp(a *decimal, nd int) bool {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return false
|
||||
}
|
||||
if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even
|
||||
// if we truncated, a little higher than what's recorded - always round up
|
||||
if a.trunc {
|
||||
return true
|
||||
}
|
||||
return nd > 0 && (a.d[nd-1]-'0')%2 != 0
|
||||
}
|
||||
// not halfway - digit tells all
|
||||
return a.d[nd] >= '5'
|
||||
}
|
||||
|
||||
// Round a to nd digits (or fewer).
|
||||
// If nd is zero, it means we're rounding
|
||||
// just to the left of the digits, as in
|
||||
// 0.09 -> 0.1.
|
||||
func (a *decimal) Round(nd int) {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return
|
||||
}
|
||||
if shouldRoundUp(a, nd) {
|
||||
a.RoundUp(nd)
|
||||
} else {
|
||||
a.RoundDown(nd)
|
||||
}
|
||||
}
|
||||
|
||||
// Round a down to nd digits (or fewer).
|
||||
func (a *decimal) RoundDown(nd int) {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return
|
||||
}
|
||||
a.nd = nd
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Round a up to nd digits (or fewer).
|
||||
func (a *decimal) RoundUp(nd int) {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return
|
||||
}
|
||||
|
||||
// round up
|
||||
for i := nd - 1; i >= 0; i-- {
|
||||
c := a.d[i]
|
||||
if c < '9' { // can stop after this digit
|
||||
a.d[i]++
|
||||
a.nd = i + 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Number is all 9s.
|
||||
// Change to single 1 with adjusted decimal point.
|
||||
a.d[0] = '1'
|
||||
a.nd = 1
|
||||
a.dp++
|
||||
}
|
||||
|
||||
// Extract integer part, rounded appropriately.
|
||||
// No guarantees about overflow.
|
||||
func (a *decimal) RoundedInteger() uint64 {
|
||||
if a.dp > 20 {
|
||||
return 0xFFFFFFFFFFFFFFFF
|
||||
}
|
||||
var i int
|
||||
n := uint64(0)
|
||||
for i = 0; i < a.dp && i < a.nd; i++ {
|
||||
n = n*10 + uint64(a.d[i]-'0')
|
||||
}
|
||||
for ; i < a.dp; i++ {
|
||||
n *= 10
|
||||
}
|
||||
if shouldRoundUp(a, a.dp) {
|
||||
n++
|
||||
}
|
||||
return n
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,160 +0,0 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Multiprecision decimal numbers.
|
||||
// For floating-point formatting only; not general purpose.
|
||||
// Only operations are assign and (binary) left/right shift.
|
||||
// Can do binary floating point in multiprecision decimal precisely
|
||||
// because 2 divides 10; cannot do decimal floating point
|
||||
// in multiprecision binary precisely.
|
||||
|
||||
package decimal
|
||||
|
||||
type floatInfo struct {
|
||||
mantbits uint
|
||||
expbits uint
|
||||
bias int
|
||||
}
|
||||
|
||||
var float32info = floatInfo{23, 8, -127}
|
||||
var float64info = floatInfo{52, 11, -1023}
|
||||
|
||||
// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits
|
||||
// that will let the original floating point value be precisely reconstructed.
|
||||
func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
|
||||
// If mantissa is zero, the number is zero; stop now.
|
||||
if mant == 0 {
|
||||
d.nd = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Compute upper and lower such that any decimal number
|
||||
// between upper and lower (possibly inclusive)
|
||||
// will round to the original floating point number.
|
||||
|
||||
// We may see at once that the number is already shortest.
|
||||
//
|
||||
// Suppose d is not denormal, so that 2^exp <= d < 10^dp.
|
||||
// The closest shorter number is at least 10^(dp-nd) away.
|
||||
// The lower/upper bounds computed below are at distance
|
||||
// at most 2^(exp-mantbits).
|
||||
//
|
||||
// So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits),
|
||||
// or equivalently log2(10)*(dp-nd) > exp-mantbits.
|
||||
// It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32).
|
||||
minexp := flt.bias + 1 // minimum possible exponent
|
||||
if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) {
|
||||
// The number is already shortest.
|
||||
return
|
||||
}
|
||||
|
||||
// d = mant << (exp - mantbits)
|
||||
// Next highest floating point number is mant+1 << exp-mantbits.
|
||||
// Our upper bound is halfway between, mant*2+1 << exp-mantbits-1.
|
||||
upper := new(decimal)
|
||||
upper.Assign(mant*2 + 1)
|
||||
upper.Shift(exp - int(flt.mantbits) - 1)
|
||||
|
||||
// d = mant << (exp - mantbits)
|
||||
// Next lowest floating point number is mant-1 << exp-mantbits,
|
||||
// unless mant-1 drops the significant bit and exp is not the minimum exp,
|
||||
// in which case the next lowest is mant*2-1 << exp-mantbits-1.
|
||||
// Either way, call it mantlo << explo-mantbits.
|
||||
// Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1.
|
||||
var mantlo uint64
|
||||
var explo int
|
||||
if mant > 1<<flt.mantbits || exp == minexp {
|
||||
mantlo = mant - 1
|
||||
explo = exp
|
||||
} else {
|
||||
mantlo = mant*2 - 1
|
||||
explo = exp - 1
|
||||
}
|
||||
lower := new(decimal)
|
||||
lower.Assign(mantlo*2 + 1)
|
||||
lower.Shift(explo - int(flt.mantbits) - 1)
|
||||
|
||||
// The upper and lower bounds are possible outputs only if
|
||||
// the original mantissa is even, so that IEEE round-to-even
|
||||
// would round to the original mantissa and not the neighbors.
|
||||
inclusive := mant%2 == 0
|
||||
|
||||
// As we walk the digits we want to know whether rounding up would fall
|
||||
// within the upper bound. This is tracked by upperdelta:
|
||||
//
|
||||
// If upperdelta == 0, the digits of d and upper are the same so far.
|
||||
//
|
||||
// If upperdelta == 1, we saw a difference of 1 between d and upper on a
|
||||
// previous digit and subsequently only 9s for d and 0s for upper.
|
||||
// (Thus rounding up may fall outside the bound, if it is exclusive.)
|
||||
//
|
||||
// If upperdelta == 2, then the difference is greater than 1
|
||||
// and we know that rounding up falls within the bound.
|
||||
var upperdelta uint8
|
||||
|
||||
// Now we can figure out the minimum number of digits required.
|
||||
// Walk along until d has distinguished itself from upper and lower.
|
||||
for ui := 0; ; ui++ {
|
||||
// lower, d, and upper may have the decimal points at different
|
||||
// places. In this case upper is the longest, so we iterate from
|
||||
// ui==0 and start li and mi at (possibly) -1.
|
||||
mi := ui - upper.dp + d.dp
|
||||
if mi >= d.nd {
|
||||
break
|
||||
}
|
||||
li := ui - upper.dp + lower.dp
|
||||
l := byte('0') // lower digit
|
||||
if li >= 0 && li < lower.nd {
|
||||
l = lower.d[li]
|
||||
}
|
||||
m := byte('0') // middle digit
|
||||
if mi >= 0 {
|
||||
m = d.d[mi]
|
||||
}
|
||||
u := byte('0') // upper digit
|
||||
if ui < upper.nd {
|
||||
u = upper.d[ui]
|
||||
}
|
||||
|
||||
// Okay to round down (truncate) if lower has a different digit
|
||||
// or if lower is inclusive and is exactly the result of rounding
|
||||
// down (i.e., and we have reached the final digit of lower).
|
||||
okdown := l != m || inclusive && li+1 == lower.nd
|
||||
|
||||
switch {
|
||||
case upperdelta == 0 && m+1 < u:
|
||||
// Example:
|
||||
// m = 12345xxx
|
||||
// u = 12347xxx
|
||||
upperdelta = 2
|
||||
case upperdelta == 0 && m != u:
|
||||
// Example:
|
||||
// m = 12345xxx
|
||||
// u = 12346xxx
|
||||
upperdelta = 1
|
||||
case upperdelta == 1 && (m != '9' || u != '0'):
|
||||
// Example:
|
||||
// m = 1234598x
|
||||
// u = 1234600x
|
||||
upperdelta = 2
|
||||
}
|
||||
// Okay to round up if upper has a different digit and either upper
|
||||
// is inclusive or upper is bigger than the result of rounding up.
|
||||
okup := upperdelta > 0 && (inclusive || upperdelta > 1 || ui+1 < upper.nd)
|
||||
|
||||
// If it's okay to do either, then round to the nearest one.
|
||||
// If it's okay to do only one, do it.
|
||||
switch {
|
||||
case okdown && okup:
|
||||
d.Round(mi + 1)
|
||||
return
|
||||
case okdown:
|
||||
d.RoundDown(mi + 1)
|
||||
return
|
||||
case okup:
|
||||
d.RoundUp(mi + 1)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,564 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// tag标签的名称
|
||||
tagColumnName = "column"
|
||||
|
||||
// 输出字段 缓存的前缀
|
||||
exportPrefix = "_exportStructFields_"
|
||||
// 私有字段 缓存的前缀
|
||||
privatePrefix = "_privateStructFields_"
|
||||
// 数据库列名 缓存的前缀
|
||||
dbColumnNamePrefix = "_dbColumnName_"
|
||||
|
||||
// 数据库所有列名,经过排序 缓存的前缀
|
||||
dbColumnNameSlicePrefix = "_dbColumnNameSlice_"
|
||||
|
||||
// field对应的column的tag值 缓存的前缀
|
||||
// structFieldTagPrefix = "_structFieldTag_"
|
||||
// 数据库主键 缓存的前缀
|
||||
// dbPKNamePrefix = "_dbPKName_"
|
||||
)
|
||||
|
||||
// cacheStructFieldInfoMap 用于缓存反射的信息,sync.Map内部处理了并发锁
|
||||
var cacheStructFieldInfoMap *sync.Map = &sync.Map{}
|
||||
|
||||
// var cacheStructFieldInfoMap = make(map[string]map[string]reflect.StructField)
|
||||
|
||||
// 用于缓存field对应的column的tag值
|
||||
// var cacheStructFieldTagInfoMap = make(map[string]map[string]string)
|
||||
|
||||
// structFieldInfo 获取StructField的信息.只对struct或者*struct判断,如果是指针,返回指针下实际的struct类型
|
||||
// 第一个返回值是可以输出的字段(首字母大写),第二个是不能输出的字段(首字母小写)
|
||||
func structFieldInfo(typeOf *reflect.Type) error {
|
||||
if typeOf == nil {
|
||||
return errors.New("->structFieldInfo数据为空")
|
||||
}
|
||||
|
||||
entityName := (*typeOf).String()
|
||||
|
||||
// 缓存的key
|
||||
// 所有输出的属性,包含数据库字段,key是struct属性的名称,不区分大小写
|
||||
exportCacheKey := exportPrefix + entityName
|
||||
// 所有私有变量的属性,key是struct属性的名称,不区分大小写
|
||||
privateCacheKey := privatePrefix + entityName
|
||||
// 所有数据库的属性,key是数据库的字段名称,不区分大小写
|
||||
dbColumnCacheKey := dbColumnNamePrefix + entityName
|
||||
// 所有数据库字段名称的slice,经过排序,不区分大小写
|
||||
dbColumnNameSliceCacheKey := dbColumnNameSlicePrefix + entityName
|
||||
|
||||
// structFieldTagCacheKey := structFieldTagPrefix + entityName
|
||||
// dbPKNameCacheKey := dbPKNamePrefix + entityName
|
||||
// 缓存的数据库主键值
|
||||
_, exportOk := cacheStructFieldInfoMap.Load(exportCacheKey)
|
||||
//_, exportOk := cacheStructFieldInfoMap[exportCacheKey]
|
||||
//如果存在值,认为缓存中有所有的信息,不再处理
|
||||
if exportOk {
|
||||
return nil
|
||||
}
|
||||
// 获取字段长度
|
||||
fieldNum := (*typeOf).NumField()
|
||||
// 如果没有字段
|
||||
if fieldNum < 1 {
|
||||
return errors.New("->structFieldInfo-->NumField entity没有属性")
|
||||
}
|
||||
|
||||
// 声明所有字段的载体
|
||||
var allFieldMap *sync.Map = &sync.Map{}
|
||||
// anonymous := make([]reflect.StructField, 0)
|
||||
|
||||
// 缓存的数据
|
||||
exportStructFieldMap := make(map[string]reflect.StructField)
|
||||
privateStructFieldMap := make(map[string]reflect.StructField)
|
||||
dbColumnFieldMap := make(map[string]reflect.StructField)
|
||||
|
||||
// structFieldTagMap := make(map[string]string)
|
||||
dbColumnFieldNameSlice := make([]string, 0)
|
||||
|
||||
// 遍历sync.Map,要求输入一个func作为参数
|
||||
// 这个函数的入参、出参的类型都已经固定,不能修改
|
||||
// 可以在函数体内编写自己的代码,调用map中的k,v
|
||||
// var funcMapKV func(k, v interface{}) bool
|
||||
funcMapKV := func(k, v interface{}) bool {
|
||||
field := v.(reflect.StructField)
|
||||
fieldName := field.Name
|
||||
if ast.IsExported(fieldName) { // 如果是可以输出的,不区分大小写
|
||||
exportStructFieldMap[strings.ToLower(fieldName)] = field
|
||||
// 如果是数据库字段
|
||||
tagColumnValue := field.Tag.Get(tagColumnName)
|
||||
if len(tagColumnValue) > 0 {
|
||||
// dbColumnFieldMap[tagColumnValue] = field
|
||||
// 使用数据库字段的小写,处理oracle和达梦数据库的sql返回值大写
|
||||
tagColumnValueLower := strings.ToLower(tagColumnValue)
|
||||
dbColumnFieldMap[tagColumnValueLower] = field
|
||||
dbColumnFieldNameSlice = append(dbColumnFieldNameSlice, tagColumnValueLower)
|
||||
// structFieldTagMap[fieldName] = tagColumnValue
|
||||
}
|
||||
|
||||
} else { // 私有属性
|
||||
privateStructFieldMap[strings.ToLower(fieldName)] = field
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
// 并发锁,用于处理slice并发append
|
||||
var lock sync.Mutex
|
||||
// funcRecursiveAnonymous 递归调用struct的匿名属性,就近覆盖属性
|
||||
var funcRecursiveAnonymous func(allFieldMap *sync.Map, anonymous *reflect.StructField)
|
||||
funcRecursiveAnonymous = func(allFieldMap *sync.Map, anonymous *reflect.StructField) {
|
||||
// 字段类型
|
||||
anonymousTypeOf := anonymous.Type
|
||||
if anonymousTypeOf.Kind() == reflect.Ptr {
|
||||
// 获取指针下的Struct类型
|
||||
anonymousTypeOf = anonymousTypeOf.Elem()
|
||||
}
|
||||
|
||||
// 只处理Struct类型
|
||||
if anonymousTypeOf.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取字段长度
|
||||
fieldNum := anonymousTypeOf.NumField()
|
||||
// 如果没有字段
|
||||
if fieldNum < 1 {
|
||||
return
|
||||
}
|
||||
// 遍历所有字段
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
anonymousField := anonymousTypeOf.Field(i)
|
||||
if anonymousField.Anonymous { // 匿名struct里自身又有匿名struct
|
||||
funcRecursiveAnonymous(allFieldMap, &anonymousField)
|
||||
} else if _, ok := allFieldMap.Load(anonymousField.Name); !ok { // 普通命名字段,而且没有记录过
|
||||
allFieldMap.Store(anonymousField.Name, anonymousField)
|
||||
lock.Lock()
|
||||
funcMapKV(anonymousField.Name, anonymousField)
|
||||
lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历所有字段,记录匿名属性
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
field := (*typeOf).Field(i)
|
||||
if field.Anonymous { // 如果是匿名的
|
||||
funcRecursiveAnonymous(allFieldMap, &field)
|
||||
} else if _, ok := allFieldMap.Load(field.Name); !ok { // 普通命名字段,而且没有记录过
|
||||
allFieldMap.Store(field.Name, field)
|
||||
lock.Lock()
|
||||
funcMapKV(field.Name, field)
|
||||
lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// allFieldMap.Range(f)
|
||||
|
||||
// 加入缓存
|
||||
cacheStructFieldInfoMap.Store(exportCacheKey, exportStructFieldMap)
|
||||
cacheStructFieldInfoMap.Store(privateCacheKey, privateStructFieldMap)
|
||||
cacheStructFieldInfoMap.Store(dbColumnCacheKey, dbColumnFieldMap)
|
||||
// cacheStructFieldInfoMap[exportCacheKey] = exportStructFieldMap
|
||||
// cacheStructFieldInfoMap[privateCacheKey] = privateStructFieldMap
|
||||
// cacheStructFieldInfoMap[dbColumnCacheKey] = dbColumnFieldMap
|
||||
|
||||
// cacheStructFieldTagInfoMap[structFieldTagCacheKey] = structFieldTagMap
|
||||
|
||||
// 不按照字母顺序,按照反射获取的Struct属性顺序,生成insert语句和update语句
|
||||
// sort.Strings(dbColumnFieldNameSlice)
|
||||
cacheStructFieldInfoMap.Store(dbColumnNameSliceCacheKey, dbColumnFieldNameSlice)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setFieldValueByColumnName 根据数据库的字段名,找到struct映射的字段,并赋值
|
||||
func setFieldValueByColumnName(entity interface{}, columnName string, value interface{}) error {
|
||||
// 先从本地缓存中查找
|
||||
typeOf := reflect.TypeOf(entity)
|
||||
valueOf := reflect.ValueOf(entity)
|
||||
if typeOf.Kind() == reflect.Ptr { // 如果是指针
|
||||
typeOf = typeOf.Elem()
|
||||
valueOf = valueOf.Elem()
|
||||
}
|
||||
|
||||
dbMap, err := getDBColumnFieldMap(&typeOf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, ok := dbMap[strings.ToLower(columnName)]
|
||||
if ok { // 给主键赋值
|
||||
valueOf.FieldByName(f.Name).Set(reflect.ValueOf(value))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// structFieldValue 获取指定字段的值
|
||||
func structFieldValue(s interface{}, fieldName string) (interface{}, error) {
|
||||
if s == nil || len(fieldName) < 1 {
|
||||
return nil, errors.New("->structFieldValue数据为空")
|
||||
}
|
||||
// entity的s类型
|
||||
valueOf := reflect.ValueOf(s)
|
||||
|
||||
kind := valueOf.Kind()
|
||||
if !(kind == reflect.Ptr || kind == reflect.Struct) {
|
||||
return nil, errors.New("->structFieldValue必须是Struct或者*Struct类型")
|
||||
}
|
||||
|
||||
if kind == reflect.Ptr {
|
||||
// 获取指针下的Struct类型
|
||||
valueOf = valueOf.Elem()
|
||||
if valueOf.Kind() != reflect.Struct {
|
||||
return nil, errors.New("->structFieldValue必须是Struct或者*Struct类型")
|
||||
}
|
||||
}
|
||||
|
||||
// FieldByName方法返回的是reflect.Value类型,调用Interface()方法,返回原始类型的数据值
|
||||
value := valueOf.FieldByName(fieldName).Interface()
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// getDBColumnExportFieldMap 获取实体类的数据库字段,key是数据库的字段名称.同时返回所有的字段属性的map,key是实体类的属性.不区分大小写
|
||||
func getDBColumnExportFieldMap(typeOf *reflect.Type) (map[string]reflect.StructField, map[string]reflect.StructField, error) {
|
||||
dbColumnFieldMap, err := getCacheStructFieldInfoMap(typeOf, dbColumnNamePrefix)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
exportFieldMap, err := getCacheStructFieldInfoMap(typeOf, exportPrefix)
|
||||
return dbColumnFieldMap, exportFieldMap, err
|
||||
}
|
||||
|
||||
// getDBColumnFieldMap 获取实体类的数据库字段,key是数据库的字段名称.不区分大小写
|
||||
func getDBColumnFieldMap(typeOf *reflect.Type) (map[string]reflect.StructField, error) {
|
||||
return getCacheStructFieldInfoMap(typeOf, dbColumnNamePrefix)
|
||||
}
|
||||
|
||||
// getDBColumnFieldNameSlice 获取实体类的数据库字段,经过排序,key是数据库的字段名称.不区分大小写,
|
||||
func getDBColumnFieldNameSlice(typeOf *reflect.Type) ([]string, error) {
|
||||
dbColumnFieldSlice, dbmapErr := getCacheStructFieldInfo(typeOf, dbColumnNameSlicePrefix)
|
||||
if dbmapErr != nil {
|
||||
return nil, fmt.Errorf("->getDBColumnFieldNameSlice-->getCacheStructFieldInfo()取值错误:%w", dbmapErr)
|
||||
}
|
||||
dbcfSlice, efOK := dbColumnFieldSlice.([]string)
|
||||
if !efOK {
|
||||
return dbcfSlice, errors.New("->getDBColumnFieldNameSlice-->dbColumnFieldSlice取值转[]string类型异常")
|
||||
}
|
||||
return dbcfSlice, nil
|
||||
}
|
||||
|
||||
// getCacheStructFieldInfo 根据类型和key,获取缓存的数据字段信息slice,已经排序
|
||||
func getCacheStructFieldInfo(typeOf *reflect.Type, keyPrefix string) (interface{}, error) {
|
||||
if typeOf == nil {
|
||||
return nil, errors.New("->getCacheStructFieldInfo-->typeOf不能为空")
|
||||
}
|
||||
key := keyPrefix + (*typeOf).String()
|
||||
dbColumnFieldMap, dbOk := cacheStructFieldInfoMap.Load(key)
|
||||
// dbColumnFieldMap, dbOk := cacheStructFieldInfoMap[key]
|
||||
if !dbOk { // 缓存不存在
|
||||
// 获取实体类的输出字段和私有 字段
|
||||
err := structFieldInfo(typeOf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbColumnFieldMap, dbOk = cacheStructFieldInfoMap.Load(key)
|
||||
// dbColumnFieldMap, dbOk = cacheStructFieldInfoMap[key]
|
||||
if !dbOk {
|
||||
return nil, errors.New("->getCacheStructFieldInfo-->cacheStructFieldInfoMap.Load()获取数据库字段dbColumnFieldMap异常")
|
||||
}
|
||||
}
|
||||
|
||||
return dbColumnFieldMap, nil
|
||||
|
||||
// return dbColumnFieldMap, nil
|
||||
}
|
||||
|
||||
// getCacheStructFieldInfoMap 根据类型和key,获取缓存的字段信息
|
||||
func getCacheStructFieldInfoMap(typeOf *reflect.Type, keyPrefix string) (map[string]reflect.StructField, error) {
|
||||
dbColumnFieldMap, dbmapErr := getCacheStructFieldInfo(typeOf, keyPrefix)
|
||||
if dbmapErr != nil {
|
||||
return nil, fmt.Errorf("->getCacheStructFieldInfoMap-->getCacheStructFieldInfo()取值错误:%w", dbmapErr)
|
||||
}
|
||||
dbcfMap, efOK := dbColumnFieldMap.(map[string]reflect.StructField)
|
||||
if !efOK {
|
||||
return dbcfMap, errors.New("->getCacheStructFieldInfoMap-->dbColumnFieldMap取值转map[string]reflect.StructField类型异常")
|
||||
}
|
||||
return dbcfMap, nil
|
||||
|
||||
// return dbColumnFieldMap, nil
|
||||
}
|
||||
|
||||
// columnAndValue 根据保存的对象,返回插入的语句,需要插入的字段,字段的值
|
||||
func columnAndValue(entity interface{}) (reflect.Type, []reflect.StructField, []interface{}, error) {
|
||||
typeOf, checkerr := checkEntityKind(entity)
|
||||
if checkerr != nil {
|
||||
return typeOf, nil, nil, checkerr
|
||||
}
|
||||
// 获取实体类的反射,指针下的struct
|
||||
valueOf := reflect.ValueOf(entity).Elem()
|
||||
// reflect.Indirect
|
||||
|
||||
// 先从本地缓存中查找
|
||||
// typeOf := reflect.TypeOf(entity).Elem()
|
||||
|
||||
dbMap, err := getDBColumnFieldMap(&typeOf)
|
||||
if err != nil {
|
||||
return typeOf, nil, nil, err
|
||||
}
|
||||
dbSlice, err := getDBColumnFieldNameSlice(&typeOf)
|
||||
if err != nil {
|
||||
return typeOf, nil, nil, err
|
||||
}
|
||||
// 实体类公开字段的长度
|
||||
fLen := len(dbMap)
|
||||
// 长度不一致
|
||||
if fLen-len(dbSlice) != 0 {
|
||||
return typeOf, nil, nil, errors.New("->columnAndValue-->缓存的数据库字段和实体类字段不对应")
|
||||
}
|
||||
// 接收列的数组,这里是做一个副本,避免外部更改掉原始的列信息
|
||||
columns := make([]reflect.StructField, 0, fLen)
|
||||
// 接收值的数组
|
||||
values := make([]interface{}, 0, fLen)
|
||||
|
||||
// 遍历所有数据库属性
|
||||
for _, fieldName := range dbSlice {
|
||||
//获取字段类型的Kind
|
||||
// fieldKind := field.Type.Kind()
|
||||
//if !allowTypeMap[fieldKind] { //不允许的类型
|
||||
// continue
|
||||
//}
|
||||
field := dbMap[fieldName]
|
||||
columns = append(columns, field)
|
||||
// FieldByName方法返回的是reflect.Value类型,调用Interface()方法,返回原始类型的数据值.字段不会重名,不使用FieldByIndex()函数
|
||||
value := valueOf.FieldByName(field.Name).Interface()
|
||||
// 添加到记录值的数组
|
||||
values = append(values, value)
|
||||
|
||||
}
|
||||
|
||||
// 缓存数据库的列
|
||||
return typeOf, columns, values, nil
|
||||
}
|
||||
|
||||
// entityPKFieldName 获取实体类主键属性名称
|
||||
func entityPKFieldName(entity IEntityStruct, typeOf *reflect.Type) (string, error) {
|
||||
//检查是否是指针对象
|
||||
//typeOf, checkerr := checkEntityKind(entity)
|
||||
//if checkerr != nil {
|
||||
// return "", checkerr
|
||||
//}
|
||||
|
||||
// 缓存的key,TypeOf和ValueOf的String()方法,返回值不一样
|
||||
// typeOf := reflect.TypeOf(entity).Elem()
|
||||
|
||||
dbMap, err := getDBColumnFieldMap(typeOf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
field := dbMap[strings.ToLower(entity.GetPKColumnName())]
|
||||
return field.Name, nil
|
||||
}
|
||||
|
||||
// checkEntityKind 检查entity类型必须是*struct类型或者基础类型的指针
|
||||
func checkEntityKind(entity interface{}) (reflect.Type, error) {
|
||||
if entity == nil {
|
||||
return nil, errors.New("->checkEntityKind参数不能为空,必须是*struct类型或者基础类型的指针")
|
||||
}
|
||||
typeOf := reflect.TypeOf(entity)
|
||||
if typeOf.Kind() != reflect.Ptr { // 如果不是指针
|
||||
return nil, errors.New("->checkEntityKind必须是*struct类型或者基础类型的指针")
|
||||
}
|
||||
typeOf = typeOf.Elem()
|
||||
//if !(typeOf.Kind() == reflect.Struct || allowBaseTypeMap[typeOf.Kind()]) { //如果不是指针
|
||||
// return nil, errors.New("checkEntityKind必须是*struct类型或者基础类型的指针")
|
||||
//}
|
||||
return typeOf, nil
|
||||
}
|
||||
|
||||
// sqlRowsValues 包装接收sqlRows的Values数组,反射rows屏蔽数据库null值,兼容单个字段查询和Struct映射
|
||||
// fix:converting NULL to int is unsupported
|
||||
// 当读取数据库的值为NULL时,由于基本类型不支持为NULL,通过反射将未知driver.Value改为interface{},不再映射到struct实体类
|
||||
// 感谢@fastabler提交的pr
|
||||
// oneColumnScanner 只有一个字段,而且可以直接Scan,例如string或者[]string,不需要反射StructType进行处理
|
||||
func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error {
|
||||
if entity == nil && valueOf == nil {
|
||||
return errors.New("->sqlRowsValues-->valueOfElem为nil")
|
||||
}
|
||||
|
||||
var valueOfElem reflect.Value
|
||||
if entity == nil && valueOf != nil {
|
||||
valueOfElem = valueOf.Elem()
|
||||
}
|
||||
|
||||
ctLen := len(columnTypes)
|
||||
// 声明载体数组,用于存放struct的属性指针
|
||||
// Declare a carrier array to store the attribute pointer of the struct
|
||||
values := make([]interface{}, ctLen)
|
||||
// 记录需要类型转换的字段信息
|
||||
var fieldTempDriverValueMap map[*sql.ColumnType]*driverValueInfo
|
||||
if iscdvm {
|
||||
fieldTempDriverValueMap = make(map[*sql.ColumnType]*driverValueInfo)
|
||||
}
|
||||
var err error
|
||||
var customDriverValueConver ICustomDriverValueConver
|
||||
var converOK bool
|
||||
|
||||
for i, columnType := range columnTypes {
|
||||
if iscdvm {
|
||||
databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName())
|
||||
// 根据接收的类型,获取到类型转换的接口实现,优先匹配指定的数据库类型
|
||||
customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName]
|
||||
if !converOK {
|
||||
customDriverValueConver, converOK = customDriverValueMap[databaseTypeName]
|
||||
}
|
||||
}
|
||||
dv := driverValue.Index(i)
|
||||
if dv.IsValid() && dv.InterfaceData()[0] == 0 { // 该字段的数据库值是null,取默认值
|
||||
values[i] = new(interface{})
|
||||
continue
|
||||
} else if converOK { // 如果是需要转换的字段
|
||||
// 获取字段类型
|
||||
var structFieldType *reflect.Type
|
||||
if entity != nil { // 查询一个字段,并且可以直接接收
|
||||
structFieldType = typeOf
|
||||
} else { // 如果是struct类型
|
||||
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if field != nil { // 存在这个字段
|
||||
vtype := field.Type
|
||||
structFieldType = &vtype
|
||||
}
|
||||
}
|
||||
tempDriverValue, err := customDriverValueConver.GetDriverValue(ctx, columnType, structFieldType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tempDriverValue == nil {
|
||||
return errors.New("->sqlRowsValues-->customDriverValueConver.GetDriverValue返回的driver.Value不能为nil")
|
||||
}
|
||||
values[i] = tempDriverValue
|
||||
|
||||
// 如果需要类型转换
|
||||
dvinfo := driverValueInfo{}
|
||||
dvinfo.customDriverValueConver = customDriverValueConver
|
||||
// dvinfo.columnType = columnType
|
||||
dvinfo.structFieldType = structFieldType
|
||||
dvinfo.tempDriverValue = tempDriverValue
|
||||
fieldTempDriverValueMap[columnType] = &dvinfo
|
||||
continue
|
||||
|
||||
} else if entity != nil { // 查询一个字段,并且可以直接接收
|
||||
values[i] = entity
|
||||
continue
|
||||
} else {
|
||||
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if field == nil { // 如果不存在这个字段
|
||||
values[i] = new(interface{})
|
||||
} else {
|
||||
// fieldType := refPV.FieldByName(field.Name).Type()
|
||||
// v := reflect.New(field.Type).Interface()
|
||||
// 字段的反射值
|
||||
fieldValue := valueOfElem.FieldByName(field.Name)
|
||||
v := fieldValue.Addr().Interface()
|
||||
// v := new(interface{})
|
||||
values[i] = v
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
err = rows.Scan(values...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(fieldTempDriverValueMap) < 1 {
|
||||
return err
|
||||
}
|
||||
|
||||
// 循环需要替换的值
|
||||
for columnType, driverValueInfo := range fieldTempDriverValueMap {
|
||||
// 根据列名,字段类型,新值 返回符合接收类型值的指针,返回值是个指针,指针,指针!!!!
|
||||
// typeOf := fieldValue.Type()
|
||||
rightValue, errConverDriverValue := driverValueInfo.customDriverValueConver.ConverDriverValue(ctx, columnType, driverValueInfo.tempDriverValue, driverValueInfo.structFieldType)
|
||||
if errConverDriverValue != nil {
|
||||
errConverDriverValue = fmt.Errorf("->sqlRowsValues-->customDriverValueConver.ConverDriverValue错误:%w", errConverDriverValue)
|
||||
FuncLogError(ctx, errConverDriverValue)
|
||||
return errConverDriverValue
|
||||
}
|
||||
if entity != nil { // 查询一个字段,并且可以直接接收
|
||||
// entity = rightValue
|
||||
// valueOfElem.Set(reflect.ValueOf(rightValue).Elem())
|
||||
reflect.ValueOf(entity).Elem().Set(reflect.ValueOf(rightValue).Elem())
|
||||
continue
|
||||
} else { // 如果是Struct类型接收
|
||||
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if field != nil { // 如果存在这个字段
|
||||
// 字段的反射值
|
||||
fieldValue := valueOfElem.FieldByName(field.Name)
|
||||
// 给字段赋值
|
||||
fieldValue.Set(reflect.ValueOf(rightValue).Elem())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// getStructFieldByColumnType 根据ColumnType获取StructField对象,兼容驼峰
|
||||
func getStructFieldByColumnType(columnType *sql.ColumnType, dbColumnFieldMap *map[string]reflect.StructField, exportFieldMap *map[string]reflect.StructField) (*reflect.StructField, error) {
|
||||
columnName := strings.ToLower(columnType.Name())
|
||||
// columnName := "test"
|
||||
// 从缓存中获取列名的field字段
|
||||
// Get the field field of the column name from the cache
|
||||
field, fok := (*dbColumnFieldMap)[columnName]
|
||||
if !fok {
|
||||
field, fok = (*exportFieldMap)[columnName]
|
||||
if !fok {
|
||||
// 尝试驼峰
|
||||
cname := strings.ReplaceAll(columnName, "_", "")
|
||||
field, fok = (*exportFieldMap)[cname]
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
if fok {
|
||||
return &field, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
@ -1,611 +0,0 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"gitee.com/chunanyong/zorm/decimal"
|
||||
)
|
||||
|
||||
// FuncDecimalValue 设置decimal类型接收值,复写函数自定义decimal实现,例如github.com/shopspring/decimal,返回的是指针
|
||||
var FuncDecimalValue = func(ctx context.Context, dialect string) interface{} {
|
||||
return &decimal.Decimal{}
|
||||
}
|
||||
|
||||
// OverrideFunc 重写ZORM的函数,用于风险监控,只要查看这个函数的调用,就知道哪些地方重写了函数,避免项目混乱.当你使用这个函数时,你必须知道自己在做什么
|
||||
// funcName 是需要重写的方法命,funcObject是对应的函数. 返回值bool是否重写成功,interface{}是重写前的函数
|
||||
// 一般是在init里调用重写
|
||||
func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, error) {
|
||||
if funcName == "" {
|
||||
return false, nil, errors.New("->OverrideFunc-->funcName不能为空")
|
||||
}
|
||||
|
||||
// oldFunc 老的函数
|
||||
var oldFunc interface{} = nil
|
||||
switch funcName {
|
||||
case "Transaction":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, doTransaction func(ctx context.Context) (interface{}, error)) (interface{}, error))
|
||||
if ok {
|
||||
oldFunc = transaction
|
||||
transaction = newFunc
|
||||
}
|
||||
case "QueryRow":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, entity interface{}) (bool, error))
|
||||
if ok {
|
||||
oldFunc = queryRow
|
||||
queryRow = newFunc
|
||||
}
|
||||
case "Query":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, rowsSlicePtr interface{}, page *Page) error)
|
||||
if ok {
|
||||
oldFunc = query
|
||||
query = newFunc
|
||||
}
|
||||
|
||||
case "QueryRowMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (map[string]interface{}, error))
|
||||
if ok {
|
||||
oldFunc = queryRowMap
|
||||
queryRowMap = newFunc
|
||||
}
|
||||
case "QueryMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, page *Page) ([]map[string]interface{}, error))
|
||||
if ok {
|
||||
oldFunc = queryMap
|
||||
queryMap = newFunc
|
||||
}
|
||||
case "UpdateFinder":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (int, error))
|
||||
if ok {
|
||||
oldFunc = updateFinder
|
||||
updateFinder = newFunc
|
||||
}
|
||||
case "Insert":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = insert
|
||||
insert = newFunc
|
||||
}
|
||||
case "InsertSlice":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entityStructSlice []IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = insertSlice
|
||||
insertSlice = newFunc
|
||||
}
|
||||
case "Update":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = update
|
||||
update = newFunc
|
||||
}
|
||||
case "UpdateNotZeroValue":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = updateNotZeroValue
|
||||
updateNotZeroValue = newFunc
|
||||
}
|
||||
case "Delete":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = delete
|
||||
delete = newFunc
|
||||
}
|
||||
|
||||
case "InsertEntityMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (int, error))
|
||||
if ok {
|
||||
oldFunc = insertEntityMap
|
||||
insertEntityMap = newFunc
|
||||
}
|
||||
case "InsertEntityMapSlice":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity []IEntityMap) (int, error))
|
||||
if ok {
|
||||
oldFunc = insertEntityMapSlice
|
||||
insertEntityMapSlice = newFunc
|
||||
}
|
||||
case "UpdateEntityMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (int, error))
|
||||
if ok {
|
||||
oldFunc = updateEntityMap
|
||||
updateEntityMap = newFunc
|
||||
}
|
||||
default:
|
||||
return false, oldFunc, errors.New("->OverrideFunc-->函数" + funcName + "暂不支持重写或不存在")
|
||||
}
|
||||
if oldFunc == nil {
|
||||
return false, oldFunc, errors.New("->OverrideFunc-->请检查传入的" + funcName + "函数实现,断言转换失败.")
|
||||
}
|
||||
return true, oldFunc, nil
|
||||
}
|
||||
|
||||
// typeConvertInt64toInt int64 转 int
|
||||
func typeConvertInt64toInt(from int64) (int, error) {
|
||||
strInt64 := strconv.FormatInt(from, 10)
|
||||
return strconv.Atoi(strInt64)
|
||||
}
|
||||
|
||||
/*
|
||||
func typeConvertFloat32(i interface{}) (float32, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if v, ok := i.(float32); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
vf, err := strconv.ParseFloat(v, 32)
|
||||
return float32(vf), err
|
||||
}
|
||||
|
||||
func typeConvertFloat64(i interface{}) (float64, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if v, ok := i.(float64); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(v, 64)
|
||||
}
|
||||
|
||||
func typeConvertDecimal(i interface{}) (decimal.Decimal, error) {
|
||||
if i == nil {
|
||||
return decimal.Zero, nil
|
||||
}
|
||||
if v, ok := i.(decimal.Decimal); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return decimal.Zero, err
|
||||
}
|
||||
return decimal.NewFromString(v)
|
||||
}
|
||||
|
||||
func typeConvertInt64(i interface{}) (int64, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if v, ok := i.(int64); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertInt(i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
func typeConvertString(i interface{}) (string, error) {
|
||||
if i == nil {
|
||||
return "", nil
|
||||
}
|
||||
switch value := i.(type) {
|
||||
case int:
|
||||
return strconv.Itoa(value), nil
|
||||
case int8:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case int16:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case int32:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case int64:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(value), 'f', -1, 32), nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(value, 'f', -1, 64), nil
|
||||
case bool:
|
||||
return strconv.FormatBool(value), nil
|
||||
case string:
|
||||
return value, nil
|
||||
case []byte:
|
||||
return string(value), nil
|
||||
default:
|
||||
return fmt.Sprintf("%v", value), nil
|
||||
}
|
||||
}
|
||||
|
||||
//false: "", 0, false, off
|
||||
func typeConvertBool(i interface{}) (bool, error) {
|
||||
if i == nil {
|
||||
return false, nil
|
||||
}
|
||||
if v, ok := i.(bool); ok {
|
||||
return v, nil
|
||||
}
|
||||
s, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if s != "" && s != "0" && s != "false" && s != "off" {
|
||||
return true, err
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func typeConvertInt(i interface{}) (int, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
switch value := i.(type) {
|
||||
case int:
|
||||
return value, nil
|
||||
case int8:
|
||||
return int(value), nil
|
||||
case int16:
|
||||
return int(value), nil
|
||||
case int32:
|
||||
return int(value), nil
|
||||
case int64:
|
||||
return int(value), nil
|
||||
case uint:
|
||||
return int(value), nil
|
||||
case uint8:
|
||||
return int(value), nil
|
||||
case uint16:
|
||||
return int(value), nil
|
||||
case uint32:
|
||||
return int(value), nil
|
||||
case uint64:
|
||||
return int(value), nil
|
||||
case float32:
|
||||
return int(value), nil
|
||||
case float64:
|
||||
return int(value), nil
|
||||
case bool:
|
||||
if value {
|
||||
return 1, nil
|
||||
}
|
||||
return 0, nil
|
||||
default:
|
||||
v, err := typeConvertString(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.Atoi(v)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
func typeConvertTime(i interface{}, format string, TZLocation ...*time.Location) (time.Time, error) {
|
||||
s, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return typeConvertStrToTime(s, format, TZLocation...)
|
||||
}
|
||||
|
||||
func typeConvertStrToTime(str string, format string, TZLocation ...*time.Location) (time.Time, error) {
|
||||
if len(TZLocation) > 0 {
|
||||
return time.ParseInLocation(format, str, TZLocation[0])
|
||||
}
|
||||
return time.ParseInLocation(format, str, time.Local)
|
||||
}
|
||||
|
||||
func encodeString(s string) []byte {
|
||||
return []byte(s)
|
||||
}
|
||||
|
||||
func decodeToString(b []byte) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func encodeBool(b bool) []byte {
|
||||
if b {
|
||||
return []byte{1}
|
||||
}
|
||||
return []byte{0}
|
||||
|
||||
}
|
||||
|
||||
func encodeInt(i int) []byte {
|
||||
if i <= math.MaxInt8 {
|
||||
return encodeInt8(int8(i))
|
||||
} else if i <= math.MaxInt16 {
|
||||
return encodeInt16(int16(i))
|
||||
} else if i <= math.MaxInt32 {
|
||||
return encodeInt32(int32(i))
|
||||
} else {
|
||||
return encodeInt64(int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func encodeUint(i uint) []byte {
|
||||
if i <= math.MaxUint8 {
|
||||
return encodeUint8(uint8(i))
|
||||
} else if i <= math.MaxUint16 {
|
||||
return encodeUint16(uint16(i))
|
||||
} else if i <= math.MaxUint32 {
|
||||
return encodeUint32(uint32(i))
|
||||
} else {
|
||||
return encodeUint64(uint64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func encodeInt8(i int8) []byte {
|
||||
return []byte{byte(i)}
|
||||
}
|
||||
|
||||
func encodeUint8(i uint8) []byte {
|
||||
return []byte{byte(i)}
|
||||
}
|
||||
|
||||
func encodeInt16(i int16) []byte {
|
||||
bytes := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(bytes, uint16(i))
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeUint16(i uint16) []byte {
|
||||
bytes := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(bytes, i)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeInt32(i int32) []byte {
|
||||
bytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bytes, uint32(i))
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeUint32(i uint32) []byte {
|
||||
bytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bytes, i)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeInt64(i int64) []byte {
|
||||
bytes := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(bytes, uint64(i))
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeUint64(i uint64) []byte {
|
||||
bytes := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(bytes, i)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeFloat32(f float32) []byte {
|
||||
bits := math.Float32bits(f)
|
||||
bytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bytes, bits)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeFloat64(f float64) []byte {
|
||||
bits := math.Float64bits(f)
|
||||
bytes := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(bytes, bits)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encode(vs ...interface{}) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
for i := 0; i < len(vs); i++ {
|
||||
switch value := vs[i].(type) {
|
||||
case int:
|
||||
buf.Write(encodeInt(value))
|
||||
case int8:
|
||||
buf.Write(encodeInt8(value))
|
||||
case int16:
|
||||
buf.Write(encodeInt16(value))
|
||||
case int32:
|
||||
buf.Write(encodeInt32(value))
|
||||
case int64:
|
||||
buf.Write(encodeInt64(value))
|
||||
case uint:
|
||||
buf.Write(encodeUint(value))
|
||||
case uint8:
|
||||
buf.Write(encodeUint8(value))
|
||||
case uint16:
|
||||
buf.Write(encodeUint16(value))
|
||||
case uint32:
|
||||
buf.Write(encodeUint32(value))
|
||||
case uint64:
|
||||
buf.Write(encodeUint64(value))
|
||||
case bool:
|
||||
buf.Write(encodeBool(value))
|
||||
case string:
|
||||
buf.Write(encodeString(value))
|
||||
case []byte:
|
||||
buf.Write(value)
|
||||
case float32:
|
||||
buf.Write(encodeFloat32(value))
|
||||
case float64:
|
||||
buf.Write(encodeFloat64(value))
|
||||
default:
|
||||
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
|
||||
buf.Write(encodeString(fmt.Sprintf("%v", value)))
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func isNumeric(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] < byte('0') || s[i] > byte('9') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
func typeConvertTimeDuration(i interface{}) time.Duration {
|
||||
return time.Duration(typeConvertInt64(i))
|
||||
}
|
||||
|
||||
func typeConvertBytes(i interface{}) []byte {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
if r, ok := i.([]byte); ok {
|
||||
return r
|
||||
}
|
||||
return encode(i)
|
||||
|
||||
}
|
||||
|
||||
func typeConvertStrings(i interface{}) []string {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
if r, ok := i.([]string); ok {
|
||||
return r
|
||||
} else if r, ok := i.([]interface{}); ok {
|
||||
strs := make([]string, len(r))
|
||||
for k, v := range r {
|
||||
strs[k] = typeConvertString(v)
|
||||
}
|
||||
return strs
|
||||
}
|
||||
return []string{fmt.Sprintf("%v", i)}
|
||||
}
|
||||
|
||||
func typeConvertInt8(i interface{}) int8 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(int8); ok {
|
||||
return v
|
||||
}
|
||||
return int8(typeConvertInt(i))
|
||||
}
|
||||
|
||||
func typeConvertInt16(i interface{}) int16 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(int16); ok {
|
||||
return v
|
||||
}
|
||||
return int16(typeConvertInt(i))
|
||||
}
|
||||
|
||||
func typeConvertInt32(i interface{}) int32 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(int32); ok {
|
||||
return v
|
||||
}
|
||||
return int32(typeConvertInt(i))
|
||||
}
|
||||
|
||||
func typeConvertUint(i interface{}) uint {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
switch value := i.(type) {
|
||||
case int:
|
||||
return uint(value)
|
||||
case int8:
|
||||
return uint(value)
|
||||
case int16:
|
||||
return uint(value)
|
||||
case int32:
|
||||
return uint(value)
|
||||
case int64:
|
||||
return uint(value)
|
||||
case uint:
|
||||
return value
|
||||
case uint8:
|
||||
return uint(value)
|
||||
case uint16:
|
||||
return uint(value)
|
||||
case uint32:
|
||||
return uint(value)
|
||||
case uint64:
|
||||
return uint(value)
|
||||
case float32:
|
||||
return uint(value)
|
||||
case float64:
|
||||
return uint(value)
|
||||
case bool:
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
v, _ := strconv.ParseUint(typeConvertString(value), 10, 64)
|
||||
return uint(v)
|
||||
}
|
||||
}
|
||||
|
||||
func typeConvertUint8(i interface{}) uint8 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint8); ok {
|
||||
return v
|
||||
}
|
||||
return uint8(typeConvertUint(i))
|
||||
}
|
||||
|
||||
func typeConvertUint16(i interface{}) uint16 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint16); ok {
|
||||
return v
|
||||
}
|
||||
return uint16(typeConvertUint(i))
|
||||
}
|
||||
|
||||
func typeConvertUint32(i interface{}) uint32 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint32); ok {
|
||||
return v
|
||||
}
|
||||
return uint32(typeConvertUint(i))
|
||||
}
|
||||
|
||||
func typeConvertUint64(i interface{}) uint64 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint64); ok {
|
||||
return v
|
||||
}
|
||||
return uint64(typeConvertUint(i))
|
||||
}
|
||||
*/
|
Before Width: | Height: | Size: 7.8 KiB |
@ -1,4 +0,0 @@
|
||||
.db
|
||||
*.test
|
||||
*~
|
||||
*.swp
|
@ -1,7 +0,0 @@
|
||||
Copyright (C) 2011 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -1,99 +0,0 @@
|
||||
# pq - A pure Go postgres driver for Go's database/sql package
|
||||
|
||||
**This package is now deprecated. The up to date version is at
|
||||
[github.com/lib/pq](https://github.com/lib/pq).**
|
||||
|
||||
## Install
|
||||
|
||||
go get github.com/bmizerany/pq
|
||||
|
||||
## Docs
|
||||
|
||||
<http://go.pkgdoc.org/github.com/bmizerany/pq>
|
||||
|
||||
## Use
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "github.com/bmizerany/pq"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full")
|
||||
// ...
|
||||
}
|
||||
|
||||
**Connection String Parameters**
|
||||
|
||||
These are a subset of the libpq connection parameters. In addition, a
|
||||
number of the [environment
|
||||
variables](http://www.postgresql.org/docs/9.1/static/libpq-envars.html)
|
||||
supported by libpq are also supported. Just like libpq, these have
|
||||
lower precedence than explicitly provided connection parameters.
|
||||
|
||||
See http://www.postgresql.org/docs/9.1/static/libpq-connect.html.
|
||||
|
||||
* `dbname` - The name of the database to connect to
|
||||
* `user` - The user to sign in as
|
||||
* `password` - The user's password
|
||||
* `host` - The host to connect to. Values that start with `/` are for unix domain sockets. (default is `localhost`)
|
||||
* `port` - The port to bind to. (default is `5432`)
|
||||
* `sslmode` - Whether or not to use SSL (default is `require`, this is not the default for libpq)
|
||||
Valid values are:
|
||||
* `disable` - No SSL
|
||||
* `require` - Always SSL (skip verification)
|
||||
* `verify-full` - Always SSL (require verification)
|
||||
|
||||
See http://golang.org/pkg/database/sql to learn how to use with `pq` through the `database/sql` package.
|
||||
|
||||
## Tests
|
||||
|
||||
`go test` is used for testing. A running PostgreSQL server is
|
||||
required, with the ability to log in. The default database to connect
|
||||
to test with is "pqgotest," but it can be overridden using environment
|
||||
variables.
|
||||
|
||||
Example:
|
||||
|
||||
PGHOST=/var/run/postgresql go test pq
|
||||
|
||||
## Features
|
||||
|
||||
* SSL
|
||||
* Handles bad connections for `database/sql`
|
||||
* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`)
|
||||
* Scan binary blobs correctly (i.e. `bytea`)
|
||||
* pq.ParseURL for converting urls to connection strings for sql.Open.
|
||||
* Many libpq compatible environment variables
|
||||
* Unix socket support
|
||||
|
||||
## Future / Things you can help with
|
||||
|
||||
* Notifications: `LISTEN`/`NOTIFY`
|
||||
* `hstore` sugar (i.e. handling hstore in `rows.Scan`)
|
||||
|
||||
## Thank you (alphabetical)
|
||||
|
||||
Some of these contributors are from the original library `bmizerany/pq.go` whose
|
||||
code still exists in here.
|
||||
|
||||
* Andy Balholm (andybalholm)
|
||||
* Ben Berkert (benburkert)
|
||||
* Bjørn Madsen (aeons)
|
||||
* Blake Gentry (bgentry)
|
||||
* Brad Fitzpatrick (bradfitz)
|
||||
* Daniel Farina (fdr)
|
||||
* Everyone at The Go Team
|
||||
* Federico Romero (federomero)
|
||||
* Heroku (heroku)
|
||||
* John Gallagher (jgallagher)
|
||||
* Kamil Kisiel (kisielk)
|
||||
* Keith Rarick (kr)
|
||||
* Marc Brinkmann (mbr)
|
||||
* Martin Olsen (martinolsen)
|
||||
* Mike Lewis (mikelikespie)
|
||||
* Ryan Smith (ryandotsmith)
|
||||
* Samuel Stauffer (samuel)
|
||||
* notedit (notedit)
|
@ -1,80 +0,0 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type readBuf []byte
|
||||
|
||||
func (b *readBuf) int32() (n int) {
|
||||
n = int(int32(binary.BigEndian.Uint32(*b)))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) oid() (n oid) {
|
||||
n = oid(binary.BigEndian.Uint32(*b))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) int16() (n int) {
|
||||
n = int(binary.BigEndian.Uint16(*b))
|
||||
*b = (*b)[2:]
|
||||
return
|
||||
}
|
||||
|
||||
var stringTerm = []byte{0}
|
||||
|
||||
func (b *readBuf) string() string {
|
||||
i := bytes.Index(*b, stringTerm)
|
||||
if i < 0 {
|
||||
errorf("invalid message format; expected string terminator")
|
||||
}
|
||||
s := (*b)[:i]
|
||||
*b = (*b)[i+1:]
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func (b *readBuf) next(n int) (v []byte) {
|
||||
v = (*b)[:n]
|
||||
*b = (*b)[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) byte() byte {
|
||||
return b.next(1)[0]
|
||||
}
|
||||
|
||||
type writeBuf []byte
|
||||
|
||||
func newWriteBuf(c byte) *writeBuf {
|
||||
b := make(writeBuf, 5)
|
||||
b[0] = c
|
||||
return &b
|
||||
}
|
||||
|
||||
func (b *writeBuf) int32(n int) {
|
||||
x := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(x, uint32(n))
|
||||
*b = append(*b, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) int16(n int) {
|
||||
x := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(x, uint16(n))
|
||||
*b = append(*b, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) string(s string) {
|
||||
*b = append(*b, (s + "\000")...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) byte(c byte) {
|
||||
*b = append(*b, c)
|
||||
}
|
||||
|
||||
func (b *writeBuf) bytes(v []byte) {
|
||||
*b = append(*b, v...)
|
||||
}
|
@ -1,678 +0,0 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
|
||||
ErrNotSupported = errors.New("pq: invalid command")
|
||||
)
|
||||
|
||||
type drv struct{}
|
||||
|
||||
func (d *drv) Open(name string) (driver.Conn, error) {
|
||||
return Open(name)
|
||||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("postgres", &drv{})
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
c net.Conn
|
||||
buf *bufio.Reader
|
||||
namei int
|
||||
}
|
||||
|
||||
func Open(name string) (_ driver.Conn, err error) {
|
||||
defer errRecover(&err)
|
||||
defer errRecoverWithPGReason(&err)
|
||||
|
||||
o := make(Values)
|
||||
|
||||
// A number of defaults are applied here, in this order:
|
||||
//
|
||||
// * Very low precedence defaults applied in every situation
|
||||
// * Environment variables
|
||||
// * Explicitly passed connection information
|
||||
o.Set("host", "localhost")
|
||||
o.Set("port", "5432")
|
||||
|
||||
// Default the username, but ignore errors, because a user
|
||||
// passed in via environment variable or connection string
|
||||
// would be okay. This can result in connections failing
|
||||
// *sometimes* if the client relies on being able to determine
|
||||
// the current username and there are intermittent problems.
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
o.Set("user", u.Username)
|
||||
}
|
||||
|
||||
for k, v := range parseEnviron(os.Environ()) {
|
||||
o.Set(k, v)
|
||||
}
|
||||
|
||||
parseOpts(name, o)
|
||||
|
||||
c, err := net.Dial(network(o))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cn := &conn{c: c}
|
||||
cn.ssl(o)
|
||||
cn.buf = bufio.NewReader(cn.c)
|
||||
cn.startup(o)
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func network(o Values) (string, string) {
|
||||
host := o.Get("host")
|
||||
|
||||
if strings.HasPrefix(host, "/") {
|
||||
sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
|
||||
return "unix", sockPath
|
||||
}
|
||||
|
||||
return "tcp", host + ":" + o.Get("port")
|
||||
}
|
||||
|
||||
type Values map[string]string
|
||||
|
||||
func (vs Values) Set(k, v string) {
|
||||
vs[k] = v
|
||||
}
|
||||
|
||||
func (vs Values) Get(k string) (v string) {
|
||||
v, _ = vs[k]
|
||||
return
|
||||
}
|
||||
|
||||
func parseOpts(name string, o Values) {
|
||||
if len(name) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ps := strings.Split(name, " ")
|
||||
for _, p := range ps {
|
||||
kv := strings.Split(p, "=")
|
||||
if len(kv) < 2 {
|
||||
errorf("invalid option: %q", p)
|
||||
}
|
||||
o.Set(kv[0], kv[1])
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) Begin() (driver.Tx, error) {
|
||||
_, err := cn.Exec("BEGIN", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cn, err
|
||||
}
|
||||
|
||||
func (cn *conn) Commit() error {
|
||||
_, err := cn.Exec("COMMIT", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cn *conn) Rollback() error {
|
||||
_, err := cn.Exec("ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cn *conn) gname() string {
|
||||
cn.namei++
|
||||
return strconv.FormatInt(int64(cn.namei), 10)
|
||||
}
|
||||
|
||||
func (cn *conn) simpleQuery(q string) (res driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
b := newWriteBuf('Q')
|
||||
b.string(q)
|
||||
cn.send(b)
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case 'C':
|
||||
res = parseComplete(r.string())
|
||||
case 'Z':
|
||||
// done
|
||||
return
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'T', 'N', 'S':
|
||||
// ignore
|
||||
default:
|
||||
errorf("unknown response for simple query: %q", t)
|
||||
}
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
st := &stmt{cn: cn, name: stmtName, query: q}
|
||||
|
||||
b := newWriteBuf('P')
|
||||
b.string(st.name)
|
||||
b.string(q)
|
||||
b.int16(0)
|
||||
cn.send(b)
|
||||
|
||||
b = newWriteBuf('D')
|
||||
b.byte('S')
|
||||
b.string(st.name)
|
||||
cn.send(b)
|
||||
|
||||
cn.send(newWriteBuf('S'))
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case '1', '2', 'N':
|
||||
case 't':
|
||||
st.nparams = int(r.int16())
|
||||
st.paramTyps = make([]oid, st.nparams, st.nparams)
|
||||
|
||||
for i := 0; i < st.nparams; i += 1 {
|
||||
st.paramTyps[i] = r.oid()
|
||||
}
|
||||
case 'T':
|
||||
n := r.int16()
|
||||
st.cols = make([]string, n)
|
||||
st.rowTyps = make([]oid, n)
|
||||
for i := range st.cols {
|
||||
st.cols[i] = r.string()
|
||||
r.next(6)
|
||||
st.rowTyps[i] = r.oid()
|
||||
r.next(8)
|
||||
}
|
||||
case 'n':
|
||||
// no data
|
||||
case 'Z':
|
||||
return st, err
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
default:
|
||||
errorf("unexpected describe rows response: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) Prepare(q string) (driver.Stmt, error) {
|
||||
return cn.prepareTo(q, cn.gname())
|
||||
}
|
||||
|
||||
func (cn *conn) Close() (err error) {
|
||||
defer errRecover(&err)
|
||||
cn.send(newWriteBuf('X'))
|
||||
|
||||
return cn.c.Close()
|
||||
}
|
||||
|
||||
// Implement the optional "Execer" interface for one-shot queries
|
||||
func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
// Check to see if we can use the "simpleQuery" interface, which is
|
||||
// *much* faster than going through prepare/exec
|
||||
if len(args) == 0 {
|
||||
return cn.simpleQuery(query)
|
||||
}
|
||||
|
||||
// Use the unnamed statement to defer planning until bind
|
||||
// time, or else value-based selectivity estimates cannot be
|
||||
// used.
|
||||
st, err := cn.prepareTo(query, "")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r, err := st.Exec(args)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
||||
|
||||
// Assumes len(*m) is > 5
|
||||
func (cn *conn) send(m *writeBuf) {
|
||||
b := (*m)[1:]
|
||||
binary.BigEndian.PutUint32(b, uint32(len(b)))
|
||||
|
||||
if (*m)[0] == 0 {
|
||||
*m = b
|
||||
}
|
||||
|
||||
_, err := cn.c.Write(*m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) recv() (t byte, r *readBuf) {
|
||||
for {
|
||||
t, r = cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
panic(parseError(r))
|
||||
case 'N':
|
||||
// ignore
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) recv1() (byte, *readBuf) {
|
||||
x := make([]byte, 5)
|
||||
_, err := io.ReadFull(cn.buf, x)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b := readBuf(x[1:])
|
||||
y := make([]byte, b.int32()-4)
|
||||
_, err = io.ReadFull(cn.buf, y)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return x[0], (*readBuf)(&y)
|
||||
}
|
||||
|
||||
func (cn *conn) ssl(o Values) {
|
||||
tlsConf := tls.Config{}
|
||||
switch mode := o.Get("sslmode"); mode {
|
||||
case "require", "":
|
||||
tlsConf.InsecureSkipVerify = true
|
||||
case "verify-full":
|
||||
// fall out
|
||||
case "disable":
|
||||
return
|
||||
default:
|
||||
errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode)
|
||||
}
|
||||
|
||||
w := newWriteBuf(0)
|
||||
w.int32(80877103)
|
||||
cn.send(w)
|
||||
|
||||
b := make([]byte, 1)
|
||||
_, err := io.ReadFull(cn.c, b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if b[0] != 'S' {
|
||||
panic(ErrSSLNotSupported)
|
||||
}
|
||||
|
||||
cn.c = tls.Client(cn.c, &tlsConf)
|
||||
}
|
||||
|
||||
func (cn *conn) startup(o Values) {
|
||||
w := newWriteBuf(0)
|
||||
w.int32(196608)
|
||||
w.string("user")
|
||||
w.string(o.Get("user"))
|
||||
w.string("database")
|
||||
w.string(o.Get("dbname"))
|
||||
w.string("")
|
||||
cn.send(w)
|
||||
|
||||
for {
|
||||
t, r := cn.recv()
|
||||
switch t {
|
||||
case 'K', 'S':
|
||||
case 'R':
|
||||
cn.auth(r, o)
|
||||
case 'Z':
|
||||
return
|
||||
default:
|
||||
errorf("unknown response for startup: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) auth(r *readBuf, o Values) {
|
||||
switch code := r.int32(); code {
|
||||
case 0:
|
||||
// OK
|
||||
case 3:
|
||||
w := newWriteBuf('p')
|
||||
w.string(o.Get("password"))
|
||||
cn.send(w)
|
||||
|
||||
t, r := cn.recv()
|
||||
if t != 'R' {
|
||||
errorf("unexpected password response: %q", t)
|
||||
}
|
||||
|
||||
if r.int32() != 0 {
|
||||
errorf("unexpected authentication response: %q", t)
|
||||
}
|
||||
case 5:
|
||||
s := string(r.next(4))
|
||||
w := newWriteBuf('p')
|
||||
w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
|
||||
cn.send(w)
|
||||
|
||||
t, r := cn.recv()
|
||||
if t != 'R' {
|
||||
errorf("unexpected password response: %q", t)
|
||||
}
|
||||
|
||||
if r.int32() != 0 {
|
||||
errorf("unexpected authentication resoonse: %q", t)
|
||||
}
|
||||
default:
|
||||
errorf("unknown authentication response: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
cn *conn
|
||||
name string
|
||||
query string
|
||||
cols []string
|
||||
nparams int
|
||||
rowTyps []oid
|
||||
paramTyps []oid
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (st *stmt) Close() (err error) {
|
||||
if st.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer errRecover(&err)
|
||||
|
||||
w := newWriteBuf('C')
|
||||
w.byte('S')
|
||||
w.string(st.name)
|
||||
st.cn.send(w)
|
||||
|
||||
st.cn.send(newWriteBuf('S'))
|
||||
|
||||
t, _ := st.cn.recv()
|
||||
if t != '3' {
|
||||
errorf("unexpected close response: %q", t)
|
||||
}
|
||||
st.closed = true
|
||||
|
||||
t, _ = st.cn.recv()
|
||||
if t != 'Z' {
|
||||
errorf("expected ready for query, but got: %q", t)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) {
|
||||
defer errRecover(&err)
|
||||
st.exec(v)
|
||||
return &rows{st: st}, nil
|
||||
}
|
||||
|
||||
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
if len(v) == 0 {
|
||||
return st.cn.simpleQuery(st.query)
|
||||
}
|
||||
st.exec(v)
|
||||
|
||||
for {
|
||||
t, r := st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'C':
|
||||
res = parseComplete(r.string())
|
||||
case 'Z':
|
||||
// done
|
||||
return
|
||||
case 'D':
|
||||
errorf("unexpected data row returned in Exec; check your query")
|
||||
case 'S', 'N':
|
||||
// Ignore
|
||||
default:
|
||||
errorf("unknown exec response: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (st *stmt) exec(v []driver.Value) {
|
||||
w := newWriteBuf('B')
|
||||
w.string("")
|
||||
w.string(st.name)
|
||||
w.int16(0)
|
||||
w.int16(len(v))
|
||||
for i, x := range v {
|
||||
if x == nil {
|
||||
w.int32(-1)
|
||||
} else {
|
||||
b := encode(x, st.paramTyps[i])
|
||||
w.int32(len(b))
|
||||
w.bytes(b)
|
||||
}
|
||||
}
|
||||
w.int16(0)
|
||||
st.cn.send(w)
|
||||
|
||||
w = newWriteBuf('E')
|
||||
w.string("")
|
||||
w.int32(0)
|
||||
st.cn.send(w)
|
||||
|
||||
st.cn.send(newWriteBuf('S'))
|
||||
|
||||
var err error
|
||||
for {
|
||||
t, r := st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case '2':
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
case 'Z':
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
case 'N':
|
||||
// ignore
|
||||
default:
|
||||
errorf("unexpected bind response: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (st *stmt) NumInput() int {
|
||||
return st.nparams
|
||||
}
|
||||
|
||||
type result int64
|
||||
|
||||
func (i result) RowsAffected() (int64, error) {
|
||||
return int64(i), nil
|
||||
}
|
||||
|
||||
func (i result) LastInsertId() (int64, error) {
|
||||
return 0, ErrNotSupported
|
||||
}
|
||||
|
||||
func parseComplete(s string) driver.Result {
|
||||
parts := strings.Split(s, " ")
|
||||
n, _ := strconv.ParseInt(parts[len(parts)-1], 10, 64)
|
||||
return result(n)
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
st *stmt
|
||||
done bool
|
||||
}
|
||||
|
||||
func (rs *rows) Close() error {
|
||||
for {
|
||||
err := rs.Next(nil)
|
||||
switch err {
|
||||
case nil:
|
||||
case io.EOF:
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (rs *rows) Columns() []string {
|
||||
return rs.st.cols
|
||||
}
|
||||
|
||||
func (rs *rows) Next(dest []driver.Value) (err error) {
|
||||
if rs.done {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
defer errRecover(&err)
|
||||
|
||||
for {
|
||||
t, r := rs.st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'C', 'S', 'N':
|
||||
continue
|
||||
case 'Z':
|
||||
rs.done = true
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
case 'D':
|
||||
n := r.int16()
|
||||
for i := 0; i < len(dest) && i < n; i++ {
|
||||
l := r.int32()
|
||||
if l == -1 {
|
||||
dest[i] = nil
|
||||
continue
|
||||
}
|
||||
dest[i] = decode(r.next(l), rs.st.rowTyps[i])
|
||||
}
|
||||
return
|
||||
default:
|
||||
errorf("unexpected message after execute: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func md5s(s string) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(s))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// parseEnviron tries to mimic some of libpq's environment handling
|
||||
//
|
||||
// To ease testing, it does not directly reference os.Environ, but is
|
||||
// designed to accept its output.
|
||||
//
|
||||
// Environment-set connection information is intended to have a higher
|
||||
// precedence than a library default but lower than any explicitly
|
||||
// passed information (such as in the URL or connection string).
|
||||
func parseEnviron(env []string) (out map[string]string) {
|
||||
out = make(map[string]string)
|
||||
|
||||
for _, v := range env {
|
||||
parts := strings.SplitN(v, "=", 2)
|
||||
|
||||
accrue := func(keyname string) {
|
||||
out[keyname] = parts[1]
|
||||
}
|
||||
|
||||
// The order of these is the same as is seen in the
|
||||
// PostgreSQL 9.1 manual, with omissions briefly
|
||||
// noted.
|
||||
switch parts[0] {
|
||||
case "PGHOST":
|
||||
accrue("host")
|
||||
case "PGHOSTADDR":
|
||||
accrue("hostaddr")
|
||||
case "PGPORT":
|
||||
accrue("port")
|
||||
case "PGDATABASE":
|
||||
accrue("dbname")
|
||||
case "PGUSER":
|
||||
accrue("user")
|
||||
case "PGPASSWORD":
|
||||
accrue("password")
|
||||
// skip PGPASSFILE, PGSERVICE, PGSERVICEFILE,
|
||||
// PGREALM
|
||||
case "PGOPTIONS":
|
||||
accrue("options")
|
||||
case "PGAPPNAME":
|
||||
accrue("application_name")
|
||||
case "PGSSLMODE":
|
||||
accrue("sslmode")
|
||||
case "PGREQUIRESSL":
|
||||
accrue("requiressl")
|
||||
case "PGSSLCERT":
|
||||
accrue("sslcert")
|
||||
case "PGSSLKEY":
|
||||
accrue("sslkey")
|
||||
case "PGSSLROOTCERT":
|
||||
accrue("sslrootcert")
|
||||
case "PGSSLCRL":
|
||||
accrue("sslcrl")
|
||||
case "PGREQUIREPEER":
|
||||
accrue("requirepeer")
|
||||
case "PGKRBSRVNAME":
|
||||
accrue("krbsrvname")
|
||||
case "PGGSSLIB":
|
||||
accrue("gsslib")
|
||||
case "PGCONNECT_TIMEOUT":
|
||||
accrue("connect_timeout")
|
||||
case "PGCLIENTENCODING":
|
||||
accrue("client_encoding")
|
||||
// skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR,
|
||||
// PGLOCALEDIR
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
@ -1,121 +0,0 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func encode(x interface{}, pgtypoid oid) []byte {
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return []byte(fmt.Sprintf("%d", v))
|
||||
case float32, float64:
|
||||
return []byte(fmt.Sprintf("%f", v))
|
||||
case []byte:
|
||||
if pgtypoid == t_bytea {
|
||||
return []byte(fmt.Sprintf("\\x%x", v))
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
if pgtypoid == t_bytea {
|
||||
return []byte(fmt.Sprintf("\\x%x", v))
|
||||
}
|
||||
|
||||
return []byte(v)
|
||||
case bool:
|
||||
return []byte(fmt.Sprintf("%t", v))
|
||||
case time.Time:
|
||||
return []byte(v.Format(time.RFC3339Nano))
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func decode(s []byte, typ oid) interface{} {
|
||||
switch typ {
|
||||
case t_bytea:
|
||||
s = s[2:] // trim off "\\x"
|
||||
d := make([]byte, hex.DecodedLen(len(s)))
|
||||
_, err := hex.Decode(d, s)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return d
|
||||
case t_timestamptz:
|
||||
return mustParse("2006-01-02 15:04:05-07", typ, s)
|
||||
case t_timestamp:
|
||||
return mustParse("2006-01-02 15:04:05", typ, s)
|
||||
case t_time:
|
||||
return mustParse("15:04:05", typ, s)
|
||||
case t_timetz:
|
||||
return mustParse("15:04:05-07", typ, s)
|
||||
case t_date:
|
||||
return mustParse("2006-01-02", typ, s)
|
||||
case t_bool:
|
||||
return s[0] == 't'
|
||||
case t_int8, t_int2, t_int4:
|
||||
i, err := strconv.ParseInt(string(s), 10, 64)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return i
|
||||
case t_float4, t_float8:
|
||||
bits := 64
|
||||
if typ == t_float4 {
|
||||
bits = 32
|
||||
}
|
||||
f, err := strconv.ParseFloat(string(s), bits)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func mustParse(f string, typ oid, s []byte) time.Time {
|
||||
str := string(s)
|
||||
|
||||
// Special case until time.Parse bug is fixed:
|
||||
// http://code.google.com/p/go/issues/detail?id=3487
|
||||
if str[len(str)-2] == '.' {
|
||||
str += "0"
|
||||
}
|
||||
|
||||
// check for a 30-minute-offset timezone
|
||||
if (typ == t_timestamptz || typ == t_timetz) &&
|
||||
str[len(str)-3] == ':' {
|
||||
f += ":00"
|
||||
}
|
||||
t, err := time.Parse(f, str)
|
||||
if err != nil {
|
||||
errorf("decode: %s", err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
nt.Time, nt.Valid = value.(time.Time)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
@ -1,108 +0,0 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
Efatal = "FATAL"
|
||||
Epanic = "PANIC"
|
||||
Ewarning = "WARNING"
|
||||
Enotice = "NOTICE"
|
||||
Edebug = "DEBUG"
|
||||
Einfo = "INFO"
|
||||
Elog = "LOG"
|
||||
)
|
||||
|
||||
type Error error
|
||||
|
||||
type PGError interface {
|
||||
Error() string
|
||||
Fatal() bool
|
||||
Get(k byte) (v string)
|
||||
}
|
||||
type pgError struct {
|
||||
c map[byte]string
|
||||
}
|
||||
|
||||
func parseError(r *readBuf) *pgError {
|
||||
err := &pgError{make(map[byte]string)}
|
||||
for t := r.byte(); t != 0; t = r.byte() {
|
||||
err.c[t] = r.string()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (err *pgError) Get(k byte) (v string) {
|
||||
v, _ = err.c[k]
|
||||
return
|
||||
}
|
||||
|
||||
func (err *pgError) Fatal() bool {
|
||||
return err.Get('S') == Efatal
|
||||
}
|
||||
|
||||
func (err *pgError) Error() string {
|
||||
var s string
|
||||
for k, v := range err.c {
|
||||
s += fmt.Sprintf(" %c:%q", k, v)
|
||||
}
|
||||
return "pq: " + s[1:]
|
||||
}
|
||||
|
||||
func errorf(s string, args ...interface{}) {
|
||||
panic(Error(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))))
|
||||
}
|
||||
|
||||
type SimplePGError struct {
|
||||
pgError
|
||||
}
|
||||
|
||||
func (err *SimplePGError) Error() string {
|
||||
return "pq: " + err.Get('M')
|
||||
}
|
||||
|
||||
func errRecoverWithPGReason(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case *pgError:
|
||||
// Return a SimplePGError in place
|
||||
*err = &SimplePGError{*v}
|
||||
default:
|
||||
// Otherwise re-panic
|
||||
panic(e)
|
||||
}
|
||||
}
|
||||
|
||||
func errRecover(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case runtime.Error:
|
||||
panic(v)
|
||||
case *pgError:
|
||||
if v.Fatal() {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
case *net.OpError:
|
||||
*err = driver.ErrBadConn
|
||||
case error:
|
||||
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown error: %#v", e))
|
||||
}
|
||||
}
|
@ -1,319 +0,0 @@
|
||||
package pq
|
||||
|
||||
type oid uint32
|
||||
|
||||
const (
|
||||
t_bool oid = 16
|
||||
t_bytea = 17
|
||||
t_char = 18
|
||||
t_name = 19
|
||||
t_int8 = 20
|
||||
t_int2 = 21
|
||||
t_int2vector = 22
|
||||
t_int4 = 23
|
||||
t_regproc = 24
|
||||
t_text = 25
|
||||
t_oid = 26
|
||||
t_tid = 27
|
||||
t_xid = 28
|
||||
t_cid = 29
|
||||
t_oidvector = 30
|
||||
t_pg_type = 71
|
||||
t_pg_attribute = 75
|
||||
t_pg_proc = 81
|
||||
t_pg_class = 83
|
||||
t_xml = 142
|
||||
t__xml = 143
|
||||
t_pg_node_tree = 194
|
||||
t_smgr = 210
|
||||
t_point = 600
|
||||
t_lseg = 601
|
||||
t_path = 602
|
||||
t_box = 603
|
||||
t_polygon = 604
|
||||
t_line = 628
|
||||
t__line = 629
|
||||
t_float4 = 700
|
||||
t_float8 = 701
|
||||
t_abstime = 702
|
||||
t_reltime = 703
|
||||
t_tinterval = 704
|
||||
t_unknown = 705
|
||||
t_circle = 718
|
||||
t__circle = 719
|
||||
t_money = 790
|
||||
t__money = 791
|
||||
t_macaddr = 829
|
||||
t_inet = 869
|
||||
t_cidr = 650
|
||||
t__bool = 1000
|
||||
t__bytea = 1001
|
||||
t__char = 1002
|
||||
t__name = 1003
|
||||
t__int2 = 1005
|
||||
t__int2vector = 1006
|
||||
t__int4 = 1007
|
||||
t__regproc = 1008
|
||||
t__text = 1009
|
||||
t__oid = 1028
|
||||
t__tid = 1010
|
||||
t__xid = 1011
|
||||
t__cid = 1012
|
||||
t__oidvector = 1013
|
||||
t__bpchar = 1014
|
||||
t__varchar = 1015
|
||||
t__int8 = 1016
|
||||
t__point = 1017
|
||||
t__lseg = 1018
|
||||
t__path = 1019
|
||||
t__box = 1020
|
||||
t__float4 = 1021
|
||||
t__float8 = 1022
|
||||
t__abstime = 1023
|
||||
t__reltime = 1024
|
||||
t__tinterval = 1025
|
||||
t__polygon = 1027
|
||||
t_aclitem = 1033
|
||||
t__aclitem = 1034
|
||||
t__macaddr = 1040
|
||||
t__inet = 1041
|
||||
t__cidr = 651
|
||||
t__cstring = 1263
|
||||
t_bpchar = 1042
|
||||
t_varchar = 1043
|
||||
t_date = 1082
|
||||
t_time = 1083
|
||||
t_timestamp = 1114
|
||||
t__timestamp = 1115
|
||||
t__date = 1182
|
||||
t__time = 1183
|
||||
t_timestamptz = 1184
|
||||
t__timestamptz = 1185
|
||||
t_interval = 1186
|
||||
t__interval = 1187
|
||||
t__numeric = 1231
|
||||
t_timetz = 1266
|
||||
t__timetz = 1270
|
||||
t_bit = 1560
|
||||
t__bit = 1561
|
||||
t_varbit = 1562
|
||||
t__varbit = 1563
|
||||
t_numeric = 1700
|
||||
t_refcursor = 1790
|
||||
t__refcursor = 2201
|
||||
t_regprocedure = 2202
|
||||
t_regoper = 2203
|
||||
t_regoperator = 2204
|
||||
t_regclass = 2205
|
||||
t_regtype = 2206
|
||||
t__regprocedure = 2207
|
||||
t__regoper = 2208
|
||||
t__regoperator = 2209
|
||||
t__regclass = 2210
|
||||
t__regtype = 2211
|
||||
t_uuid = 2950
|
||||
t__uuid = 2951
|
||||
t_tsvector = 3614
|
||||
t_gtsvector = 3642
|
||||
t_tsquery = 3615
|
||||
t_regconfig = 3734
|
||||
t_regdictionary = 3769
|
||||
t__tsvector = 3643
|
||||
t__gtsvector = 3644
|
||||
t__tsquery = 3645
|
||||
t__regconfig = 3735
|
||||
t__regdictionary = 3770
|
||||
t_txid_snapshot = 2970
|
||||
t__txid_snapshot = 2949
|
||||
t_record = 2249
|
||||
t__record = 2287
|
||||
t_cstring = 2275
|
||||
t_any = 2276
|
||||
t_anyarray = 2277
|
||||
t_void = 2278
|
||||
t_trigger = 2279
|
||||
t_language_handler = 2280
|
||||
t_internal = 2281
|
||||
t_opaque = 2282
|
||||
t_anyelement = 2283
|
||||
t_anynonarray = 2776
|
||||
t_anyenum = 3500
|
||||
t_fdw_handler = 3115
|
||||
t_pg_attrdef = 10000
|
||||
t_pg_constraint = 10001
|
||||
t_pg_inherits = 10002
|
||||
t_pg_index = 10003
|
||||
t_pg_operator = 10004
|
||||
t_pg_opfamily = 10005
|
||||
t_pg_opclass = 10006
|
||||
t_pg_am = 10117
|
||||
t_pg_amop = 10118
|
||||
t_pg_amproc = 10478
|
||||
t_pg_language = 10731
|
||||
t_pg_largeobject_metadata = 10732
|
||||
t_pg_largeobject = 10733
|
||||
t_pg_aggregate = 10734
|
||||
t_pg_statistic = 10735
|
||||
t_pg_rewrite = 10736
|
||||
t_pg_trigger = 10737
|
||||
t_pg_description = 10738
|
||||
t_pg_cast = 10739
|
||||
t_pg_enum = 10936
|
||||
t_pg_namespace = 10937
|
||||
t_pg_conversion = 10938
|
||||
t_pg_depend = 10939
|
||||
t_pg_database = 1248
|
||||
t_pg_db_role_setting = 10940
|
||||
t_pg_tablespace = 10941
|
||||
t_pg_pltemplate = 10942
|
||||
t_pg_authid = 2842
|
||||
t_pg_auth_members = 2843
|
||||
t_pg_shdepend = 10943
|
||||
t_pg_shdescription = 10944
|
||||
t_pg_ts_config = 10945
|
||||
t_pg_ts_config_map = 10946
|
||||
t_pg_ts_dict = 10947
|
||||
t_pg_ts_parser = 10948
|
||||
t_pg_ts_template = 10949
|
||||
t_pg_extension = 10950
|
||||
t_pg_foreign_data_wrapper = 10951
|
||||
t_pg_foreign_server = 10952
|
||||
t_pg_user_mapping = 10953
|
||||
t_pg_foreign_table = 10954
|
||||
t_pg_default_acl = 10955
|
||||
t_pg_seclabel = 10956
|
||||
t_pg_collation = 10957
|
||||
t_pg_toast_2604 = 10958
|
||||
t_pg_toast_2606 = 10959
|
||||
t_pg_toast_2609 = 10960
|
||||
t_pg_toast_1255 = 10961
|
||||
t_pg_toast_2618 = 10962
|
||||
t_pg_toast_3596 = 10963
|
||||
t_pg_toast_2619 = 10964
|
||||
t_pg_toast_2620 = 10965
|
||||
t_pg_toast_1262 = 10966
|
||||
t_pg_toast_2396 = 10967
|
||||
t_pg_toast_2964 = 10968
|
||||
t_pg_roles = 10970
|
||||
t_pg_shadow = 10973
|
||||
t_pg_group = 10976
|
||||
t_pg_user = 10979
|
||||
t_pg_rules = 10982
|
||||
t_pg_views = 10986
|
||||
t_pg_tables = 10989
|
||||
t_pg_indexes = 10993
|
||||
t_pg_stats = 10997
|
||||
t_pg_locks = 11001
|
||||
t_pg_cursors = 11004
|
||||
t_pg_available_extensions = 11007
|
||||
t_pg_available_extension_versions = 11010
|
||||
t_pg_prepared_xacts = 11013
|
||||
t_pg_prepared_statements = 11017
|
||||
t_pg_seclabels = 11020
|
||||
t_pg_settings = 11024
|
||||
t_pg_timezone_abbrevs = 11029
|
||||
t_pg_timezone_names = 11032
|
||||
t_pg_stat_all_tables = 11035
|
||||
t_pg_stat_xact_all_tables = 11039
|
||||
t_pg_stat_sys_tables = 11043
|
||||
t_pg_stat_xact_sys_tables = 11047
|
||||
t_pg_stat_user_tables = 11050
|
||||
t_pg_stat_xact_user_tables = 11054
|
||||
t_pg_statio_all_tables = 11057
|
||||
t_pg_statio_sys_tables = 11061
|
||||
t_pg_statio_user_tables = 11064
|
||||
t_pg_stat_all_indexes = 11067
|
||||
t_pg_stat_sys_indexes = 11071
|
||||
t_pg_stat_user_indexes = 11074
|
||||
t_pg_statio_all_indexes = 11077
|
||||
t_pg_statio_sys_indexes = 11081
|
||||
t_pg_statio_user_indexes = 11084
|
||||
t_pg_statio_all_sequences = 11087
|
||||
t_pg_statio_sys_sequences = 11090
|
||||
t_pg_statio_user_sequences = 11093
|
||||
t_pg_stat_activity = 11096
|
||||
t_pg_stat_replication = 11099
|
||||
t_pg_stat_database = 11102
|
||||
t_pg_stat_database_conflicts = 11105
|
||||
t_pg_stat_user_functions = 11108
|
||||
t_pg_stat_xact_user_functions = 11112
|
||||
t_pg_stat_bgwriter = 11116
|
||||
t_pg_user_mappings = 11119
|
||||
t_cardinal_number = 11669
|
||||
t_character_data = 11671
|
||||
t_sql_identifier = 11672
|
||||
t_information_schema_catalog_name = 11674
|
||||
t_time_stamp = 11676
|
||||
t_yes_or_no = 11677
|
||||
t_applicable_roles = 11680
|
||||
t_administrable_role_authorizations = 11684
|
||||
t_attributes = 11687
|
||||
t_character_sets = 11691
|
||||
t_check_constraint_routine_usage = 11695
|
||||
t_check_constraints = 11699
|
||||
t_collations = 11703
|
||||
t_collation_character_set_applicability = 11706
|
||||
t_column_domain_usage = 11709
|
||||
t_column_privileges = 11713
|
||||
t_column_udt_usage = 11717
|
||||
t_columns = 11721
|
||||
t_constraint_column_usage = 11725
|
||||
t_constraint_table_usage = 11729
|
||||
t_domain_constraints = 11733
|
||||
t_domain_udt_usage = 11737
|
||||
t_domains = 11740
|
||||
t_enabled_roles = 11744
|
||||
t_key_column_usage = 11747
|
||||
t_parameters = 11751
|
||||
t_referential_constraints = 11755
|
||||
t_role_column_grants = 11759
|
||||
t_routine_privileges = 11762
|
||||
t_role_routine_grants = 11766
|
||||
t_routines = 11769
|
||||
t_schemata = 11773
|
||||
t_sequences = 11776
|
||||
t_sql_features = 11780
|
||||
t_pg_toast_11779 = 11782
|
||||
t_sql_implementation_info = 11785
|
||||
t_pg_toast_11784 = 11787
|
||||
t_sql_languages = 11790
|
||||
t_pg_toast_11789 = 11792
|
||||
t_sql_packages = 11795
|
||||
t_pg_toast_11794 = 11797
|
||||
t_sql_parts = 11800
|
||||
t_pg_toast_11799 = 11802
|
||||
t_sql_sizing = 11805
|
||||
t_pg_toast_11804 = 11807
|
||||
t_sql_sizing_profiles = 11810
|
||||
t_pg_toast_11809 = 11812
|
||||
t_table_constraints = 11815
|
||||
t_table_privileges = 11819
|
||||
t_role_table_grants = 11823
|
||||
t_tables = 11826
|
||||
t_triggered_update_columns = 11830
|
||||
t_triggers = 11834
|
||||
t_usage_privileges = 11838
|
||||
t_role_usage_grants = 11842
|
||||
t_view_column_usage = 11845
|
||||
t_view_routine_usage = 11849
|
||||
t_view_table_usage = 11853
|
||||
t_views = 11857
|
||||
t_data_type_privileges = 11861
|
||||
t_element_types = 11865
|
||||
t__pg_foreign_data_wrappers = 11869
|
||||
t_foreign_data_wrapper_options = 11872
|
||||
t_foreign_data_wrappers = 11875
|
||||
t__pg_foreign_servers = 11878
|
||||
t_foreign_server_options = 11882
|
||||
t_foreign_servers = 11885
|
||||
t__pg_foreign_tables = 11888
|
||||
t_foreign_table_options = 11892
|
||||
t_foreign_tables = 11895
|
||||
t__pg_user_mappings = 11898
|
||||
t_user_mapping_options = 11901
|
||||
t_user_mappings = 11905
|
||||
t_t = 16806
|
||||
t__t = 16805
|
||||
t_temp = 16810
|
||||
t__temp = 16809
|
||||
)
|
@ -1,68 +0,0 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
nurl "net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseURL converts url to a connection string for driver.Open.
|
||||
// Example:
|
||||
//
|
||||
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
|
||||
//
|
||||
// converts to:
|
||||
//
|
||||
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
|
||||
//
|
||||
// A minimal example:
|
||||
//
|
||||
// "postgres://"
|
||||
//
|
||||
// This will be blank, causing driver.Open to use all of the defaults
|
||||
func ParseURL(url string) (string, error) {
|
||||
u, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if u.Scheme != "postgres" {
|
||||
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
|
||||
}
|
||||
|
||||
var kvs []string
|
||||
accrue := func(k, v string) {
|
||||
if v != "" {
|
||||
kvs = append(kvs, k+"="+v)
|
||||
}
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
v := u.User.Username()
|
||||
accrue("user", v)
|
||||
|
||||
v, _ = u.User.Password()
|
||||
accrue("password", v)
|
||||
}
|
||||
|
||||
i := strings.Index(u.Host, ":")
|
||||
if i < 0 {
|
||||
accrue("host", u.Host)
|
||||
} else {
|
||||
accrue("host", u.Host[:i])
|
||||
accrue("port", u.Host[i+1:])
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
accrue("dbname", u.Path[1:])
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
for k, _ := range q {
|
||||
accrue(k, q.Get(k))
|
||||
}
|
||||
|
||||
sort.Strings(kvs) // Makes testing easier (not a performance concern)
|
||||
return strings.Join(kvs, " "), nil
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.x
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
@ -1,22 +0,0 @@
|
||||
Copyright (c) 2019 Jack Christensen
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -1,8 +0,0 @@
|
||||
[![](https://godoc.org/github.com/jackc/chunkreader?status.svg)](https://godoc.org/github.com/jackc/chunkreader)
|
||||
[![Build Status](https://travis-ci.org/jackc/chunkreader.svg)](https://travis-ci.org/jackc/chunkreader)
|
||||
|
||||
# chunkreader
|
||||
|
||||
Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
|
||||
|
||||
Extracted from original implementation in https://github.com/jackc/pgx.
|
@ -1,104 +0,0 @@
|
||||
// Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
|
||||
package chunkreader
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// ChunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
|
||||
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
|
||||
// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy.
|
||||
//
|
||||
// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is
|
||||
// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare
|
||||
// cases it would be advantageous to copy the bytes to another slice.
|
||||
type ChunkReader struct {
|
||||
r io.Reader
|
||||
|
||||
buf []byte
|
||||
rp, wp int // buf read position and write position
|
||||
|
||||
config Config
|
||||
}
|
||||
|
||||
// Config contains configuration parameters for ChunkReader.
|
||||
type Config struct {
|
||||
MinBufLen int // Minimum buffer length
|
||||
}
|
||||
|
||||
// New creates and returns a new ChunkReader for r with default configuration.
|
||||
func New(r io.Reader) *ChunkReader {
|
||||
cr, err := NewConfig(r, Config{})
|
||||
if err != nil {
|
||||
panic("default config can't be bad")
|
||||
}
|
||||
|
||||
return cr
|
||||
}
|
||||
|
||||
// NewConfig creates and a new ChunkReader for r configured by config.
|
||||
func NewConfig(r io.Reader, config Config) (*ChunkReader, error) {
|
||||
if config.MinBufLen == 0 {
|
||||
// By historical reasons Postgres currently has 8KB send buffer inside,
|
||||
// so here we want to have at least the same size buffer.
|
||||
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
|
||||
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
|
||||
config.MinBufLen = 8192
|
||||
}
|
||||
|
||||
return &ChunkReader{
|
||||
r: r,
|
||||
buf: make([]byte, config.MinBufLen),
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy
|
||||
// of buf. If an error occurs, buf will be nil.
|
||||
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
|
||||
// n bytes already in buf
|
||||
if (r.wp - r.rp) >= n {
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, err
|
||||
}
|
||||
|
||||
// available space in buf is less than n
|
||||
if len(r.buf) < n {
|
||||
r.copyBufContents(r.newBuf(n))
|
||||
}
|
||||
|
||||
// buf is large enough, but need to shift filled area to start to make enough contiguous space
|
||||
minReadCount := n - (r.wp - r.rp)
|
||||
if (len(r.buf) - r.wp) < minReadCount {
|
||||
newBuf := r.newBuf(n)
|
||||
r.copyBufContents(newBuf)
|
||||
}
|
||||
|
||||
if err := r.appendAtLeast(minReadCount); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf = r.buf[r.rp : r.rp+n]
|
||||
r.rp += n
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func (r *ChunkReader) appendAtLeast(fillLen int) error {
|
||||
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
|
||||
r.wp += n
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *ChunkReader) newBuf(size int) []byte {
|
||||
if size < r.config.MinBufLen {
|
||||
size = r.config.MinBufLen
|
||||
}
|
||||
return make([]byte, size)
|
||||
}
|
||||
|
||||
func (r *ChunkReader) copyBufContents(dest []byte) {
|
||||
r.wp = copy(dest, r.buf[r.rp:r.wp])
|
||||
r.rp = 0
|
||||
r.buf = dest
|
||||
}
|
@ -1,3 +0,0 @@
|
||||
.envrc
|
||||
vendor/
|
||||
.vscode
|
@ -1,161 +0,0 @@
|
||||
# 1.14.0 (February 11, 2023)
|
||||
|
||||
* Fix: each connection attempt to new node gets own timeout (Nathan Giardina)
|
||||
* Set SNI for SSL connections (Stas Kelvich)
|
||||
* Fix: CopyFrom I/O race (Tommy Reilly)
|
||||
* Minor dependency upgrades
|
||||
|
||||
# 1.13.0 (August 6, 2022)
|
||||
|
||||
* Add sslpassword support (Eric McCormack and yun.xu)
|
||||
* Add prefer-standby target_session_attrs support (sergey.bashilov)
|
||||
* Fix GSS ErrorResponse handling (Oliver Tan)
|
||||
|
||||
# 1.12.1 (May 7, 2022)
|
||||
|
||||
* Fix: setting krbspn and krbsrvname in connection string (sireax)
|
||||
* Add support for Unix sockets on Windows (Eno Compton)
|
||||
* Stop ignoring ErrorResponse during SCRAM auth (Rafi Shamim)
|
||||
|
||||
# 1.12.0 (April 21, 2022)
|
||||
|
||||
* Add pluggable GSSAPI support (Oliver Tan)
|
||||
* Fix: Consider any "0A000" error a possible cached plan changed error due to locale
|
||||
* Better match psql fallback behavior with multiple hosts
|
||||
|
||||
# 1.11.0 (February 7, 2022)
|
||||
|
||||
* Support port in ip from LookupFunc to override config (James Hartig)
|
||||
* Fix TLS connection timeout (Blake Embrey)
|
||||
* Add support for read-only, primary, standby, prefer-standby target_session_attributes (Oscar)
|
||||
* Fix connect when receiving NoticeResponse
|
||||
|
||||
# 1.10.1 (November 20, 2021)
|
||||
|
||||
* Close without waiting for response (Kei Kamikawa)
|
||||
* Save waiting for network round-trip in CopyFrom (Rueian)
|
||||
* Fix concurrency issue with ContextWatcher
|
||||
* LRU.Get always checks context for cancellation / expiration (Georges Varouchas)
|
||||
|
||||
# 1.10.0 (July 24, 2021)
|
||||
|
||||
* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned.
|
||||
|
||||
# 1.9.0 (July 10, 2021)
|
||||
|
||||
* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr)
|
||||
* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle)
|
||||
* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard)
|
||||
* Fix default host when parsing URL without host but with port
|
||||
* Allow dbname query parameter in URL conn string
|
||||
* Update underlying dependencies
|
||||
|
||||
# 1.8.1 (March 25, 2021)
|
||||
|
||||
* Better connection string sanitization (ip.novikov)
|
||||
* Use proper pgpass location on Windows (Moshe Katz)
|
||||
* Use errors instead of golang.org/x/xerrors
|
||||
* Resume fallback on server error in Connect (Andrey Borodin)
|
||||
|
||||
# 1.8.0 (December 3, 2020)
|
||||
|
||||
* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes)
|
||||
|
||||
# 1.7.2 (November 3, 2020)
|
||||
|
||||
* Fix data value slices into work buffer with capacities larger than length.
|
||||
|
||||
# 1.7.1 (October 31, 2020)
|
||||
|
||||
* Do not asyncClose after receiving FATAL error from PostgreSQL server
|
||||
|
||||
# 1.7.0 (September 26, 2020)
|
||||
|
||||
* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
|
||||
* Add ReceiveResults (Sebastiaan Mannem)
|
||||
* Fix parsing DSN connection with bad backslash
|
||||
* Add PgConn.CleanupDone so connection pools can determine when async close is complete
|
||||
|
||||
# 1.6.4 (July 29, 2020)
|
||||
|
||||
* Fix deadlock on error after CommandComplete but before ReadyForQuery
|
||||
* Fix panic on parsing DSN with trailing '='
|
||||
|
||||
# 1.6.3 (July 22, 2020)
|
||||
|
||||
* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo)
|
||||
|
||||
# 1.6.2 (July 14, 2020)
|
||||
|
||||
* Update pgservicefile library
|
||||
|
||||
# 1.6.1 (June 27, 2020)
|
||||
|
||||
* Update golang.org/x/crypto to latest
|
||||
* Update golang.org/x/text to 0.3.3
|
||||
* Fix error handling for bad PGSERVICE definition
|
||||
* Redact passwords in ParseConfig errors (Lukas Vogel)
|
||||
|
||||
# 1.6.0 (June 6, 2020)
|
||||
|
||||
* Fix panic when closing conn during cancellable query
|
||||
* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný)
|
||||
* Fix field descriptions available after command concluded (Tobias Salzmann)
|
||||
* Support connect_timeout (georgysavva)
|
||||
* Handle IPv6 in connection URLs (Lukas Vogel)
|
||||
* Fix ValidateConnect with cancelable context
|
||||
* Improve CopyFrom performance
|
||||
* Add Config.Copy (georgysavva)
|
||||
|
||||
# 1.5.0 (March 30, 2020)
|
||||
|
||||
* Update golang.org/x/crypto for security fix
|
||||
* Implement "verify-ca" SSL mode (Greg Curtis)
|
||||
|
||||
# 1.4.0 (March 7, 2020)
|
||||
|
||||
* Fix ExecParams and ExecPrepared handling of empty query.
|
||||
* Support reading config from PostgreSQL service files.
|
||||
|
||||
# 1.3.2 (February 14, 2020)
|
||||
|
||||
* Update chunkreader to v2.0.1 for optimized default buffer size.
|
||||
|
||||
# 1.3.1 (February 5, 2020)
|
||||
|
||||
* Fix CopyFrom deadlock when multiple NoticeResponse received during copy
|
||||
|
||||
# 1.3.0 (January 23, 2020)
|
||||
|
||||
* Add Hijack and Construct.
|
||||
* Update pgproto3 to v2.0.1.
|
||||
|
||||
# 1.2.1 (January 13, 2020)
|
||||
|
||||
* Fix data race in context cancellation introduced in v1.2.0.
|
||||
|
||||
# 1.2.0 (January 11, 2020)
|
||||
|
||||
## Features
|
||||
|
||||
* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag.
|
||||
* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases.
|
||||
|
||||
## Performance
|
||||
|
||||
* Improve performance when context.Background() is used. (bakape)
|
||||
* CommandTag.RowsAffected is faster and does not allocate.
|
||||
|
||||
## Fixes
|
||||
|
||||
* Try to cancel any in-progress query when a conn is closed by ctx cancel.
|
||||
* Handle NoticeResponse during CopyFrom.
|
||||
* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish.
|
||||
|
||||
# 1.1.0 (October 12, 2019)
|
||||
|
||||
* Add PgConn.IsBusy() method.
|
||||
|
||||
# 1.0.1 (September 19, 2019)
|
||||
|
||||
* Fix statement cache not properly cleaning discarded statements.
|
@ -1,22 +0,0 @@
|
||||
Copyright (c) 2019-2021 Jack Christensen
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -1,62 +0,0 @@
|
||||
[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn)
|
||||
![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg)
|
||||
|
||||
---
|
||||
|
||||
This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository.
|
||||
|
||||
---
|
||||
|
||||
# pgconn
|
||||
|
||||
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
|
||||
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
|
||||
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
|
||||
low-level access to PostgreSQL functionality.
|
||||
|
||||
## Example Usage
|
||||
|
||||
```go
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
|
||||
if err != nil {
|
||||
log.Fatalln("pgconn failed to connect:", err)
|
||||
}
|
||||
defer pgConn.Close(context.Background())
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
|
||||
for result.NextRow() {
|
||||
fmt.Println("User 123 has email:", string(result.Values()[0]))
|
||||
}
|
||||
_, err = result.Close()
|
||||
if err != nil {
|
||||
log.Fatalln("failed reading result:", err)
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
|
||||
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
|
||||
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
|
||||
environment variable handling.
|
||||
|
||||
### Example Test Environment
|
||||
|
||||
Connect to your PostgreSQL server and run:
|
||||
|
||||
```
|
||||
create database pgx_test;
|
||||
```
|
||||
|
||||
Now you can run the tests:
|
||||
|
||||
```bash
|
||||
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
|
||||
```
|
||||
|
||||
### Connection and Authentication Tests
|
||||
|
||||
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
|
||||
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
|
||||
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
|
||||
authentication code.
|
@ -1,270 +0,0 @@
|
||||
// SCRAM-SHA-256 authentication
|
||||
//
|
||||
// Resources:
|
||||
// https://tools.ietf.org/html/rfc5802
|
||||
// https://tools.ietf.org/html/rfc8265
|
||||
// https://www.postgresql.org/docs/current/sasl-authentication.html
|
||||
//
|
||||
// Inspiration drawn from other implementations:
|
||||
// https://github.com/lib/pq/pull/608
|
||||
// https://github.com/lib/pq/pull/788
|
||||
// https://github.com/lib/pq/pull/833
|
||||
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/text/secure/precis"
|
||||
)
|
||||
|
||||
const clientNonceLen = 18
|
||||
|
||||
// Perform SCRAM authentication.
|
||||
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
|
||||
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send client-first-message in a SASLInitialResponse
|
||||
saslInitialResponse := &pgproto3.SASLInitialResponse{
|
||||
AuthMechanism: "SCRAM-SHA-256",
|
||||
Data: sc.clientFirstMessage(),
|
||||
}
|
||||
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-first-message payload in a AuthenticationSASLContinue.
|
||||
saslContinue, err := c.rxSASLContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = sc.recvServerFirstMessage(saslContinue.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send client-final-message in a SASLResponse
|
||||
saslResponse := &pgproto3.SASLResponse{
|
||||
Data: []byte(sc.clientFinalMessage()),
|
||||
}
|
||||
_, err = c.conn.Write(saslResponse.Encode(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive server-final-message payload in a AuthenticationSASLFinal.
|
||||
saslFinal, err := c.rxSASLFinal()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sc.recvServerFinalMessage(saslFinal.Data)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
|
||||
}
|
||||
|
||||
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationSASLFinal:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
|
||||
}
|
||||
|
||||
type scramClient struct {
|
||||
serverAuthMechanisms []string
|
||||
password []byte
|
||||
clientNonce []byte
|
||||
|
||||
clientFirstMessageBare []byte
|
||||
|
||||
serverFirstMessage []byte
|
||||
clientAndServerNonce []byte
|
||||
salt []byte
|
||||
iterations int
|
||||
|
||||
saltedPassword []byte
|
||||
authMessage []byte
|
||||
}
|
||||
|
||||
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
|
||||
sc := &scramClient{
|
||||
serverAuthMechanisms: serverAuthMechanisms,
|
||||
}
|
||||
|
||||
// Ensure server supports SCRAM-SHA-256
|
||||
hasScramSHA256 := false
|
||||
for _, mech := range sc.serverAuthMechanisms {
|
||||
if mech == "SCRAM-SHA-256" {
|
||||
hasScramSHA256 = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasScramSHA256 {
|
||||
return nil, errors.New("server does not support SCRAM-SHA-256")
|
||||
}
|
||||
|
||||
// precis.OpaqueString is equivalent to SASLprep for password.
|
||||
var err error
|
||||
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
|
||||
if err != nil {
|
||||
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
|
||||
sc.password = []byte(password)
|
||||
}
|
||||
|
||||
buf := make([]byte, clientNonceLen)
|
||||
_, err = rand.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
|
||||
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
|
||||
|
||||
return sc, nil
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFirstMessage() []byte {
|
||||
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
|
||||
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
|
||||
}
|
||||
|
||||
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
|
||||
sc.serverFirstMessage = serverFirstMessage
|
||||
buf := serverFirstMessage
|
||||
if !bytes.HasPrefix(buf, []byte("r=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
|
||||
idx := bytes.IndexByte(buf, ',')
|
||||
if idx == -1 {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||
}
|
||||
sc.clientAndServerNonce = buf[:idx]
|
||||
buf = buf[idx+1:]
|
||||
|
||||
if !bytes.HasPrefix(buf, []byte("s=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
|
||||
idx = bytes.IndexByte(buf, ',')
|
||||
if idx == -1 {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||
}
|
||||
saltStr := buf[:idx]
|
||||
buf = buf[idx+1:]
|
||||
|
||||
if !bytes.HasPrefix(buf, []byte("i=")) {
|
||||
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
|
||||
}
|
||||
buf = buf[2:]
|
||||
iterationsStr := buf
|
||||
|
||||
var err error
|
||||
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
|
||||
}
|
||||
|
||||
sc.iterations, err = strconv.Atoi(string(iterationsStr))
|
||||
if err != nil || sc.iterations <= 0 {
|
||||
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
|
||||
return errors.New("invalid SCRAM nonce: did not start with client nonce")
|
||||
}
|
||||
|
||||
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
|
||||
return errors.New("invalid SCRAM nonce: did not include server nonce")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sc *scramClient) clientFinalMessage() string {
|
||||
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
|
||||
|
||||
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
|
||||
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
|
||||
|
||||
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
|
||||
|
||||
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
|
||||
}
|
||||
|
||||
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
|
||||
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
|
||||
return errors.New("invalid SCRAM server-final-message received from server")
|
||||
}
|
||||
|
||||
serverSignature := serverFinalMessage[2:]
|
||||
|
||||
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
|
||||
return errors.New("invalid SCRAM ServerSignature received from server")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeHMAC(key, msg []byte) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write(msg)
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
||||
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
|
||||
storedKey := sha256.Sum256(clientKey)
|
||||
clientSignature := computeHMAC(storedKey[:], authMessage)
|
||||
|
||||
clientProof := make([]byte, len(clientSignature))
|
||||
for i := 0; i < len(clientSignature); i++ {
|
||||
clientProof[i] = clientKey[i] ^ clientSignature[i]
|
||||
}
|
||||
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
|
||||
base64.StdEncoding.Encode(buf, clientProof)
|
||||
return buf
|
||||
}
|
||||
|
||||
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||
serverSignature := computeHMAC(serverKey, authMessage)
|
||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||
base64.StdEncoding.Encode(buf, serverSignature)
|
||||
return buf
|
||||
}
|
@ -1,905 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/chunkreader/v2"
|
||||
"github.com/jackc/pgpassfile"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
"github.com/jackc/pgservicefile"
|
||||
)
|
||||
|
||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||
type GetSSLPasswordFunc func(ctx context.Context) string
|
||||
|
||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
|
||||
// manually initialized Config will cause ConnectConfig to panic.
|
||||
type Config struct {
|
||||
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
ConnectTimeout time.Duration
|
||||
DialFunc DialFunc // e.g. net.Dialer.DialContext
|
||||
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
|
||||
BuildFrontend BuildFrontendFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
|
||||
KerberosSrvName string
|
||||
KerberosSpn string
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
ValidateConnect ValidateConnectFunc
|
||||
|
||||
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
|
||||
// or prepare statements). If this returns an error the connection attempt fails.
|
||||
AfterConnect AfterConnectFunc
|
||||
|
||||
// OnNotice is a callback function called when a notice response is received.
|
||||
OnNotice NoticeHandler
|
||||
|
||||
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
|
||||
OnNotification NotificationHandler
|
||||
|
||||
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
||||
}
|
||||
|
||||
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
|
||||
type ParseConfigOptions struct {
|
||||
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
|
||||
// PQsetSSLKeyPassHook_OpenSSL.
|
||||
GetSSLPassword GetSSLPasswordFunc
|
||||
}
|
||||
|
||||
// Copy returns a deep copy of the config that is safe to use and modify.
|
||||
// The only exception is the TLSConfig field:
|
||||
// according to the tls.Config docs it must not be modified after creation.
|
||||
func (c *Config) Copy() *Config {
|
||||
newConf := new(Config)
|
||||
*newConf = *c
|
||||
if newConf.TLSConfig != nil {
|
||||
newConf.TLSConfig = c.TLSConfig.Clone()
|
||||
}
|
||||
if newConf.RuntimeParams != nil {
|
||||
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
|
||||
for k, v := range c.RuntimeParams {
|
||||
newConf.RuntimeParams[k] = v
|
||||
}
|
||||
}
|
||||
if newConf.Fallbacks != nil {
|
||||
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
|
||||
for i, fallback := range c.Fallbacks {
|
||||
newFallback := new(FallbackConfig)
|
||||
*newFallback = *fallback
|
||||
if newFallback.TLSConfig != nil {
|
||||
newFallback.TLSConfig = fallback.TLSConfig.Clone()
|
||||
}
|
||||
newConf.Fallbacks[i] = newFallback
|
||||
}
|
||||
}
|
||||
return newConf
|
||||
}
|
||||
|
||||
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
|
||||
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
|
||||
type FallbackConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16
|
||||
TLSConfig *tls.Config // nil disables TLS
|
||||
}
|
||||
|
||||
// isAbsolutePath checks if the provided value is an absolute path either
|
||||
// beginning with a forward slash (as on Linux-based systems) or with a capital
|
||||
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
|
||||
func isAbsolutePath(path string) bool {
|
||||
isWindowsPath := func(p string) bool {
|
||||
if len(p) < 3 {
|
||||
return false
|
||||
}
|
||||
drive := p[0]
|
||||
colon := p[1]
|
||||
backslash := p[2]
|
||||
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(path, "/") || isWindowsPath(path)
|
||||
}
|
||||
|
||||
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
|
||||
// net.Dial.
|
||||
func NetworkAddress(host string, port uint16) (network, address string) {
|
||||
if isAbsolutePath(host) {
|
||||
network = "unix"
|
||||
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
|
||||
} else {
|
||||
network = "tcp"
|
||||
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
}
|
||||
return network, address
|
||||
}
|
||||
|
||||
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
|
||||
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
|
||||
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
|
||||
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
|
||||
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
|
||||
//
|
||||
// # Example DSN
|
||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||
//
|
||||
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
|
||||
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
|
||||
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
|
||||
// not be modified individually. They should all be modified or all left unchanged.
|
||||
//
|
||||
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||
//
|
||||
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
|
||||
// via database URL or DSN:
|
||||
//
|
||||
// PGHOST
|
||||
// PGPORT
|
||||
// PGDATABASE
|
||||
// PGUSER
|
||||
// PGPASSWORD
|
||||
// PGPASSFILE
|
||||
// PGSERVICE
|
||||
// PGSERVICEFILE
|
||||
// PGSSLMODE
|
||||
// PGSSLCERT
|
||||
// PGSSLKEY
|
||||
// PGSSLROOTCERT
|
||||
// PGSSLPASSWORD
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
||||
//
|
||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||
//
|
||||
// Important Security Notes:
|
||||
//
|
||||
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||
// not set.
|
||||
//
|
||||
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||
// security each sslmode provides.
|
||||
//
|
||||
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
||||
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
|
||||
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
|
||||
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
|
||||
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
|
||||
// TLSConfig.
|
||||
//
|
||||
// Other known differences with libpq:
|
||||
//
|
||||
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
|
||||
// does not.
|
||||
//
|
||||
// In addition, ParseConfig accepts the following options:
|
||||
//
|
||||
// min_read_buffer_size
|
||||
// The minimum size of the internal read buffer. Default 8192.
|
||||
// servicefile
|
||||
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
|
||||
// part of the connection string.
|
||||
func ParseConfig(connString string) (*Config, error) {
|
||||
var parseConfigOptions ParseConfigOptions
|
||||
return ParseConfigWithOptions(connString, parseConfigOptions)
|
||||
}
|
||||
|
||||
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
|
||||
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
|
||||
// get the SSL password.
|
||||
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
|
||||
defaultSettings := defaultSettings()
|
||||
envSettings := parseEnvSettings()
|
||||
|
||||
connStringSettings := make(map[string]string)
|
||||
if connString != "" {
|
||||
var err error
|
||||
// connString may be a database URL or a DSN
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
connStringSettings, err = parseURLSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
|
||||
}
|
||||
} else {
|
||||
connStringSettings, err = parseDSNSettings(connString)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
|
||||
if service, present := settings["service"]; present {
|
||||
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
|
||||
}
|
||||
|
||||
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
|
||||
}
|
||||
|
||||
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
createdByParseConfig: true,
|
||||
Database: settings["database"],
|
||||
User: settings["user"],
|
||||
Password: settings["password"],
|
||||
RuntimeParams: make(map[string]string),
|
||||
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
|
||||
}
|
||||
|
||||
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
|
||||
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
|
||||
}
|
||||
config.ConnectTimeout = connectTimeout
|
||||
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
|
||||
} else {
|
||||
defaultDialer := makeDefaultDialer()
|
||||
config.DialFunc = defaultDialer.DialContext
|
||||
}
|
||||
|
||||
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||
|
||||
notRuntimeParams := map[string]struct{}{
|
||||
"host": {},
|
||||
"port": {},
|
||||
"database": {},
|
||||
"user": {},
|
||||
"password": {},
|
||||
"passfile": {},
|
||||
"connect_timeout": {},
|
||||
"sslmode": {},
|
||||
"sslkey": {},
|
||||
"sslcert": {},
|
||||
"sslrootcert": {},
|
||||
"sslpassword": {},
|
||||
"sslsni": {},
|
||||
"krbspn": {},
|
||||
"krbsrvname": {},
|
||||
"target_session_attrs": {},
|
||||
"min_read_buffer_size": {},
|
||||
"service": {},
|
||||
"servicefile": {},
|
||||
}
|
||||
|
||||
// Adding kerberos configuration
|
||||
if _, present := settings["krbsrvname"]; present {
|
||||
config.KerberosSrvName = settings["krbsrvname"]
|
||||
}
|
||||
if _, present := settings["krbspn"]; present {
|
||||
config.KerberosSpn = settings["krbspn"]
|
||||
}
|
||||
|
||||
for k, v := range settings {
|
||||
if _, present := notRuntimeParams[k]; present {
|
||||
continue
|
||||
}
|
||||
config.RuntimeParams[k] = v
|
||||
}
|
||||
|
||||
fallbacks := []*FallbackConfig{}
|
||||
|
||||
hosts := strings.Split(settings["host"], ",")
|
||||
ports := strings.Split(settings["port"], ",")
|
||||
|
||||
for i, host := range hosts {
|
||||
var portStr string
|
||||
if i < len(ports) {
|
||||
portStr = ports[i]
|
||||
} else {
|
||||
portStr = ports[0]
|
||||
}
|
||||
|
||||
port, err := parsePort(portStr)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
|
||||
}
|
||||
|
||||
var tlsConfigs []*tls.Config
|
||||
|
||||
// Ignore TLS settings if Unix domain socket like libpq
|
||||
if network, _ := NetworkAddress(host, port); network == "unix" {
|
||||
tlsConfigs = append(tlsConfigs, nil)
|
||||
} else {
|
||||
var err error
|
||||
tlsConfigs, err = configTLS(settings, host, options)
|
||||
if err != nil {
|
||||
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
|
||||
}
|
||||
}
|
||||
|
||||
for _, tlsConfig := range tlsConfigs {
|
||||
fallbacks = append(fallbacks, &FallbackConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
TLSConfig: tlsConfig,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
config.Host = fallbacks[0].Host
|
||||
config.Port = fallbacks[0].Port
|
||||
config.TLSConfig = fallbacks[0].TLSConfig
|
||||
config.Fallbacks = fallbacks[1:]
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
if config.Password == "" {
|
||||
host := config.Host
|
||||
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
|
||||
host = "localhost"
|
||||
}
|
||||
|
||||
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
|
||||
}
|
||||
}
|
||||
|
||||
switch tsa := settings["target_session_attrs"]; tsa {
|
||||
case "read-write":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||
case "read-only":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
|
||||
case "primary":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||
case "standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||
case "prefer-standby":
|
||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
|
||||
case "any":
|
||||
// do nothing
|
||||
default:
|
||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func mergeSettings(settingSets ...map[string]string) map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
for _, s2 := range settingSets {
|
||||
for k, v := range s2 {
|
||||
settings[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func parseEnvSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"PGHOST": "host",
|
||||
"PGPORT": "port",
|
||||
"PGDATABASE": "database",
|
||||
"PGUSER": "user",
|
||||
"PGPASSWORD": "password",
|
||||
"PGPASSFILE": "passfile",
|
||||
"PGAPPNAME": "application_name",
|
||||
"PGCONNECT_TIMEOUT": "connect_timeout",
|
||||
"PGSSLMODE": "sslmode",
|
||||
"PGSSLKEY": "sslkey",
|
||||
"PGSSLCERT": "sslcert",
|
||||
"PGSSLSNI": "sslsni",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGSSLPASSWORD": "sslpassword",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
"PGSERVICE": "service",
|
||||
"PGSERVICEFILE": "servicefile",
|
||||
}
|
||||
|
||||
for envname, realname := range nameMap {
|
||||
value := os.Getenv(envname)
|
||||
if value != "" {
|
||||
settings[realname] = value
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func parseURLSettings(connString string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
url, err := url.Parse(connString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if url.User != nil {
|
||||
settings["user"] = url.User.Username()
|
||||
if password, present := url.User.Password(); present {
|
||||
settings["password"] = password
|
||||
}
|
||||
}
|
||||
|
||||
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
|
||||
var hosts []string
|
||||
var ports []string
|
||||
for _, host := range strings.Split(url.Host, ",") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
if isIPOnly(host) {
|
||||
hosts = append(hosts, strings.Trim(host, "[]"))
|
||||
continue
|
||||
}
|
||||
h, p, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
|
||||
}
|
||||
if h != "" {
|
||||
hosts = append(hosts, h)
|
||||
}
|
||||
if p != "" {
|
||||
ports = append(ports, p)
|
||||
}
|
||||
}
|
||||
if len(hosts) > 0 {
|
||||
settings["host"] = strings.Join(hosts, ",")
|
||||
}
|
||||
if len(ports) > 0 {
|
||||
settings["port"] = strings.Join(ports, ",")
|
||||
}
|
||||
|
||||
database := strings.TrimLeft(url.Path, "/")
|
||||
if database != "" {
|
||||
settings["database"] = database
|
||||
}
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
for k, v := range url.Query() {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
|
||||
settings[k] = v[0]
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func isIPOnly(host string) bool {
|
||||
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
|
||||
}
|
||||
|
||||
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
|
||||
|
||||
func parseDSNSettings(s string) (map[string]string, error) {
|
||||
settings := make(map[string]string)
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
for len(s) > 0 {
|
||||
var key, val string
|
||||
eqIdx := strings.IndexRune(s, '=')
|
||||
if eqIdx < 0 {
|
||||
return nil, errors.New("invalid dsn")
|
||||
}
|
||||
|
||||
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
|
||||
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
|
||||
if len(s) == 0 {
|
||||
} else if s[0] != '\'' {
|
||||
end := 0
|
||||
for ; end < len(s); end++ {
|
||||
if asciiSpace[s[end]] == 1 {
|
||||
break
|
||||
}
|
||||
if s[end] == '\\' {
|
||||
end++
|
||||
if end == len(s) {
|
||||
return nil, errors.New("invalid backslash")
|
||||
}
|
||||
}
|
||||
}
|
||||
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||
if end == len(s) {
|
||||
s = ""
|
||||
} else {
|
||||
s = s[end+1:]
|
||||
}
|
||||
} else { // quoted string
|
||||
s = s[1:]
|
||||
end := 0
|
||||
for ; end < len(s); end++ {
|
||||
if s[end] == '\'' {
|
||||
break
|
||||
}
|
||||
if s[end] == '\\' {
|
||||
end++
|
||||
}
|
||||
}
|
||||
if end == len(s) {
|
||||
return nil, errors.New("unterminated quoted string in connection info string")
|
||||
}
|
||||
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
|
||||
if end == len(s) {
|
||||
s = ""
|
||||
} else {
|
||||
s = s[end+1:]
|
||||
}
|
||||
}
|
||||
|
||||
if k, ok := nameMap[key]; ok {
|
||||
key = k
|
||||
}
|
||||
|
||||
if key == "" {
|
||||
return nil, errors.New("invalid dsn")
|
||||
}
|
||||
|
||||
settings[key] = val
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
|
||||
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
|
||||
}
|
||||
|
||||
service, err := servicefile.GetService(serviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to find service: %v", serviceName)
|
||||
}
|
||||
|
||||
nameMap := map[string]string{
|
||||
"dbname": "database",
|
||||
}
|
||||
|
||||
settings := make(map[string]string, len(service.Settings))
|
||||
for k, v := range service.Settings {
|
||||
if k2, present := nameMap[k]; present {
|
||||
k = k2
|
||||
}
|
||||
settings[k] = v
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
|
||||
// necessary to allow returning multiple TLS configs as sslmode "allow" and
|
||||
// "prefer" allow fallback.
|
||||
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
|
||||
host := thisHost
|
||||
sslmode := settings["sslmode"]
|
||||
sslrootcert := settings["sslrootcert"]
|
||||
sslcert := settings["sslcert"]
|
||||
sslkey := settings["sslkey"]
|
||||
sslpassword := settings["sslpassword"]
|
||||
sslsni := settings["sslsni"]
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
sslmode = "prefer"
|
||||
}
|
||||
if sslsni == "" {
|
||||
sslsni = "1"
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
switch sslmode {
|
||||
case "disable":
|
||||
return []*tls.Config{nil}, nil
|
||||
case "allow", "prefer":
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
case "require":
|
||||
// According to PostgreSQL documentation, if a root CA file exists,
|
||||
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||
//
|
||||
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
||||
if sslrootcert != "" {
|
||||
goto nextCase
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
break
|
||||
nextCase:
|
||||
fallthrough
|
||||
case "verify-ca":
|
||||
// Don't perform the default certificate verification because it
|
||||
// will verify the hostname. Instead, verify the server's
|
||||
// certificate chain ourselves in VerifyPeerCertificate and
|
||||
// ignore the server name. This emulates libpq's verify-ca
|
||||
// behavior.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
|
||||
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
|
||||
// for more info.
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
|
||||
certs := make([]*x509.Certificate, len(certificates))
|
||||
for i, asn1Data := range certificates {
|
||||
cert, err := x509.ParseCertificate(asn1Data)
|
||||
if err != nil {
|
||||
return errors.New("failed to parse certificate from server: " + err.Error())
|
||||
}
|
||||
certs[i] = cert
|
||||
}
|
||||
|
||||
// Leave DNSName empty to skip hostname verification.
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: tlsConfig.RootCAs,
|
||||
Intermediates: x509.NewCertPool(),
|
||||
}
|
||||
// Skip the first cert because it's the leaf. All others
|
||||
// are intermediates.
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, err := certs[0].Verify(opts)
|
||||
return err
|
||||
}
|
||||
case "verify-full":
|
||||
tlsConfig.ServerName = host
|
||||
default:
|
||||
return nil, errors.New("sslmode is invalid")
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
caCertPool := x509.NewCertPool()
|
||||
|
||||
caPath := sslrootcert
|
||||
caCert, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read CA file: %w", err)
|
||||
}
|
||||
|
||||
if !caCertPool.AppendCertsFromPEM(caCert) {
|
||||
return nil, errors.New("unable to add CA to cert pool")
|
||||
}
|
||||
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
|
||||
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
|
||||
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
|
||||
}
|
||||
|
||||
if sslcert != "" && sslkey != "" {
|
||||
buf, err := ioutil.ReadFile(sslkey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read sslkey: %w", err)
|
||||
}
|
||||
block, _ := pem.Decode(buf)
|
||||
var pemKey []byte
|
||||
var decryptedKey []byte
|
||||
var decryptedError error
|
||||
// If PEM is encrypted, attempt to decrypt using pass phrase
|
||||
if x509.IsEncryptedPEMBlock(block) {
|
||||
// Attempt decryption with pass phrase
|
||||
// NOTE: only supports RSA (PKCS#1)
|
||||
if sslpassword != "" {
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
}
|
||||
//if sslpassword not provided or has decryption error when use it
|
||||
//try to find sslpassword with callback function
|
||||
if sslpassword == "" || decryptedError != nil {
|
||||
if parseConfigOptions.GetSSLPassword != nil {
|
||||
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||
}
|
||||
if sslpassword == "" {
|
||||
return nil, fmt.Errorf("unable to find sslpassword")
|
||||
}
|
||||
}
|
||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||
// Should we also provide warning for PKCS#1 needed?
|
||||
if decryptedError != nil {
|
||||
return nil, fmt.Errorf("unable to decrypt key: %w", err)
|
||||
}
|
||||
|
||||
pemBytes := pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: decryptedKey,
|
||||
}
|
||||
pemKey = pem.EncodeToMemory(&pemBytes)
|
||||
} else {
|
||||
pemKey = pem.EncodeToMemory(block)
|
||||
}
|
||||
certfile, err := ioutil.ReadFile(sslcert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read cert: %w", err)
|
||||
}
|
||||
cert, err := tls.X509KeyPair(certfile, pemKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load cert: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
// Set Server Name Indication (SNI), if enabled by connection parameters.
|
||||
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
|
||||
// or IPv6).
|
||||
if sslsni == "1" && net.ParseIP(host) == nil {
|
||||
tlsConfig.ServerName = host
|
||||
}
|
||||
|
||||
switch sslmode {
|
||||
case "allow":
|
||||
return []*tls.Config{nil, tlsConfig}, nil
|
||||
case "prefer":
|
||||
return []*tls.Config{tlsConfig, nil}, nil
|
||||
case "require", "verify-ca", "verify-full":
|
||||
return []*tls.Config{tlsConfig}, nil
|
||||
default:
|
||||
panic("BUG: bad sslmode should already have been caught")
|
||||
}
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if port < 1 || port > math.MaxUint16 {
|
||||
return 0, errors.New("outside range")
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
func makeDefaultDialer() *net.Dialer {
|
||||
return &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
}
|
||||
|
||||
func makeDefaultResolver() *net.Resolver {
|
||||
return net.DefaultResolver
|
||||
}
|
||||
|
||||
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
|
||||
return func(r io.Reader, w io.Writer) Frontend {
|
||||
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
|
||||
}
|
||||
frontend := pgproto3.NewFrontend(cr, w)
|
||||
|
||||
return frontend
|
||||
}
|
||||
}
|
||||
|
||||
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
|
||||
timeout, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
return 0, errors.New("negative timeout")
|
||||
}
|
||||
return time.Duration(timeout) * time.Second, nil
|
||||
}
|
||||
|
||||
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
|
||||
d := makeDefaultDialer()
|
||||
d.Timeout = timeout
|
||||
return d.DialContext
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-write.
|
||||
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) == "on" {
|
||||
return errors.New("read only connection")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=read-only.
|
||||
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "on" {
|
||||
return errors.New("connection is not read only")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=standby.
|
||||
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "t" {
|
||||
return errors.New("server is not in hot standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=primary.
|
||||
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) == "t" {
|
||||
return errors.New("server is in standby mode")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
|
||||
// target_session_attrs=prefer-standby.
|
||||
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
|
||||
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||
if result.Err != nil {
|
||||
return result.Err
|
||||
}
|
||||
|
||||
if string(result.Rows[0][0]) != "t" {
|
||||
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,65 +0,0 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
if err == nil {
|
||||
settings["user"] = user.Username
|
||||
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
|
||||
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
|
||||
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
|
||||
if _, err := os.Stat(sslcert); err == nil {
|
||||
if _, err := os.Stat(sslkey); err == nil {
|
||||
// Both the cert and key must be present to use them, or do not use either
|
||||
settings["sslcert"] = sslcert
|
||||
settings["sslkey"] = sslkey
|
||||
}
|
||||
}
|
||||
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
|
||||
if _, err := os.Stat(sslrootcert); err == nil {
|
||||
settings["sslrootcert"] = sslrootcert
|
||||
}
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
settings["min_read_buffer_size"] = "8192"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
candidatePaths := []string{
|
||||
"/var/run/postgresql", // Debian
|
||||
"/private/tmp", // OSX - homebrew
|
||||
"/tmp", // standard PostgreSQL
|
||||
}
|
||||
|
||||
for _, path := range candidatePaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return path
|
||||
}
|
||||
}
|
||||
|
||||
return "localhost"
|
||||
}
|
@ -1,59 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func defaultSettings() map[string]string {
|
||||
settings := make(map[string]string)
|
||||
|
||||
settings["host"] = defaultHost()
|
||||
settings["port"] = "5432"
|
||||
|
||||
// Default to the OS user name. Purposely ignoring err getting user name from
|
||||
// OS. The client application will simply have to specify the user in that
|
||||
// case (which they typically will be doing anyway).
|
||||
user, err := user.Current()
|
||||
appData := os.Getenv("APPDATA")
|
||||
if err == nil {
|
||||
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
|
||||
// but the libpq default is just the `user` portion, so we strip off the first part.
|
||||
username := user.Username
|
||||
if strings.Contains(username, "\\") {
|
||||
username = username[strings.LastIndex(username, "\\")+1:]
|
||||
}
|
||||
|
||||
settings["user"] = username
|
||||
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
|
||||
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
|
||||
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
|
||||
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
|
||||
if _, err := os.Stat(sslcert); err == nil {
|
||||
if _, err := os.Stat(sslkey); err == nil {
|
||||
// Both the cert and key must be present to use them, or do not use either
|
||||
settings["sslcert"] = sslcert
|
||||
settings["sslkey"] = sslkey
|
||||
}
|
||||
}
|
||||
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
|
||||
if _, err := os.Stat(sslrootcert); err == nil {
|
||||
settings["sslrootcert"] = sslrootcert
|
||||
}
|
||||
}
|
||||
|
||||
settings["target_session_attrs"] = "any"
|
||||
|
||||
settings["min_read_buffer_size"] = "8192"
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
|
||||
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
|
||||
// checks the existence of common locations.
|
||||
func defaultHost() string {
|
||||
return "localhost"
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
// Package pgconn is a low-level PostgreSQL database driver.
|
||||
/*
|
||||
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
|
||||
nearly the same level is the C library libpq.
|
||||
|
||||
Establishing a Connection
|
||||
|
||||
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
|
||||
libpq style environment variables.
|
||||
|
||||
Executing a Query
|
||||
|
||||
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
|
||||
reads all rows into memory.
|
||||
|
||||
Executing Multiple Queries in a Single Round Trip
|
||||
|
||||
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
|
||||
result. The ReadAll method reads all query results into memory.
|
||||
|
||||
Context Support
|
||||
|
||||
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
|
||||
method immediately returns. In most circumstances, this will close the underlying connection.
|
||||
|
||||
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
|
||||
client to abort.
|
||||
*/
|
||||
package pgconn
|
@ -1,238 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
|
||||
func SafeToRetry(err error) bool {
|
||||
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
|
||||
return e.SafeToRetry()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
|
||||
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
|
||||
func Timeout(err error) bool {
|
||||
var timeoutErr *errTimeout
|
||||
return errors.As(err, &timeoutErr)
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
}
|
||||
|
||||
func (pe *PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
// SQLState returns the SQLState of the error.
|
||||
func (pe *PgError) SQLState() string {
|
||||
return pe.Code
|
||||
}
|
||||
|
||||
type connectError struct {
|
||||
config *Config
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *connectError) Error() string {
|
||||
sb := &strings.Builder{}
|
||||
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
|
||||
if e.err != nil {
|
||||
fmt.Fprintf(sb, " (%s)", e.err.Error())
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (e *connectError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type connLockError struct {
|
||||
status string
|
||||
}
|
||||
|
||||
func (e *connLockError) SafeToRetry() bool {
|
||||
return true // a lock failure by definition happens before the connection is used.
|
||||
}
|
||||
|
||||
func (e *connLockError) Error() string {
|
||||
return e.status
|
||||
}
|
||||
|
||||
type parseConfigError struct {
|
||||
connString string
|
||||
msg string
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Error() string {
|
||||
connString := redactPW(e.connString)
|
||||
if e.err == nil {
|
||||
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
|
||||
}
|
||||
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *parseConfigError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
|
||||
// true. Otherwise returns err.
|
||||
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
|
||||
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
|
||||
return &errTimeout{err: ctx.Err()}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type pgconnError struct {
|
||||
msg string
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *pgconnError) Error() string {
|
||||
if e.msg == "" {
|
||||
return e.err.Error()
|
||||
}
|
||||
if e.err == nil {
|
||||
return e.msg
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
|
||||
}
|
||||
|
||||
func (e *pgconnError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *pgconnError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
|
||||
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
|
||||
type errTimeout struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *errTimeout) Error() string {
|
||||
return fmt.Sprintf("timeout: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *errTimeout) SafeToRetry() bool {
|
||||
return SafeToRetry(e.err)
|
||||
}
|
||||
|
||||
func (e *errTimeout) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
type contextAlreadyDoneError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Error() string {
|
||||
return fmt.Sprintf("context already done: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) SafeToRetry() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *contextAlreadyDoneError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
|
||||
func newContextAlreadyDoneError(ctx context.Context) (err error) {
|
||||
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
|
||||
}
|
||||
|
||||
type writeError struct {
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *writeError) Error() string {
|
||||
return fmt.Sprintf("write failed: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *writeError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *writeError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func redactPW(connString string) string {
|
||||
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
|
||||
if u, err := url.Parse(connString); err == nil {
|
||||
return redactURL(u)
|
||||
}
|
||||
}
|
||||
quotedDSN := regexp.MustCompile(`password='[^']*'`)
|
||||
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
plainDSN := regexp.MustCompile(`password=[^ ]*`)
|
||||
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
|
||||
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
|
||||
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
|
||||
return connString
|
||||
}
|
||||
|
||||
func redactURL(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
if _, pwSet := u.User.Password(); pwSet {
|
||||
u.User = url.UserPassword(u.User.Username(), "xxxxx")
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
type NotPreferredError struct {
|
||||
err error
|
||||
safeToRetry bool
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Error() string {
|
||||
return fmt.Sprintf("standby server not found: %s", e.err.Error())
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) SafeToRetry() bool {
|
||||
return e.safeToRetry
|
||||
}
|
||||
|
||||
func (e *NotPreferredError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
package ctxwatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
|
||||
// time.
|
||||
type ContextWatcher struct {
|
||||
onCancel func()
|
||||
onUnwatchAfterCancel func()
|
||||
unwatchChan chan struct{}
|
||||
|
||||
lock sync.Mutex
|
||||
watchInProgress bool
|
||||
onCancelWasCalled bool
|
||||
}
|
||||
|
||||
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
|
||||
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
|
||||
// onCancel called.
|
||||
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
|
||||
cw := &ContextWatcher{
|
||||
onCancel: onCancel,
|
||||
onUnwatchAfterCancel: onUnwatchAfterCancel,
|
||||
unwatchChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
return cw
|
||||
}
|
||||
|
||||
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
|
||||
func (cw *ContextWatcher) Watch(ctx context.Context) {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
panic("Watch already in progress")
|
||||
}
|
||||
|
||||
cw.onCancelWasCalled = false
|
||||
|
||||
if ctx.Done() != nil {
|
||||
cw.watchInProgress = true
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cw.onCancel()
|
||||
cw.onCancelWasCalled = true
|
||||
<-cw.unwatchChan
|
||||
case <-cw.unwatchChan:
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
||||
|
||||
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
|
||||
// called then onUnwatchAfterCancel will also be called.
|
||||
func (cw *ContextWatcher) Unwatch() {
|
||||
cw.lock.Lock()
|
||||
defer cw.lock.Unlock()
|
||||
|
||||
if cw.watchInProgress {
|
||||
cw.unwatchChan <- struct{}{}
|
||||
if cw.onCancelWasCalled {
|
||||
cw.onUnwatchAfterCancel()
|
||||
}
|
||||
cw.watchInProgress = false
|
||||
}
|
||||
}
|
@ -1,99 +0,0 @@
|
||||
package pgconn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
)
|
||||
|
||||
// NewGSSFunc creates a GSS authentication provider, for use with
|
||||
// RegisterGSSProvider.
|
||||
type NewGSSFunc func() (GSS, error)
|
||||
|
||||
var newGSS NewGSSFunc
|
||||
|
||||
// RegisterGSSProvider registers a GSS authentication provider. For example, if
|
||||
// you need to use Kerberos to authenticate with your server, add this to your
|
||||
// main package:
|
||||
//
|
||||
// import "github.com/otan/gopgkrb5"
|
||||
//
|
||||
// func init() {
|
||||
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
|
||||
// }
|
||||
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
||||
newGSS = newGSSArg
|
||||
}
|
||||
|
||||
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
||||
type GSS interface {
|
||||
GetInitToken(host string, service string) ([]byte, error)
|
||||
GetInitTokenFromSPN(spn string) ([]byte, error)
|
||||
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
||||
}
|
||||
|
||||
func (c *PgConn) gssAuth() error {
|
||||
if newGSS == nil {
|
||||
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
|
||||
}
|
||||
cli, err := newGSS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var nextData []byte
|
||||
if c.config.KerberosSpn != "" {
|
||||
// Use the supplied SPN if provided.
|
||||
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
|
||||
} else {
|
||||
// Allow the kerberos service name to be overridden
|
||||
service := "postgres"
|
||||
if c.config.KerberosSrvName != "" {
|
||||
service = c.config.KerberosSrvName
|
||||
}
|
||||
nextData, err = cli.GetInitToken(c.config.Host, service)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
gssResponse := &pgproto3.GSSResponse{
|
||||
Data: nextData,
|
||||
}
|
||||
_, err = c.conn.Write(gssResponse.Encode(nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.rxGSSContinue()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var done bool
|
||||
done, nextData, err = cli.Continue(resp.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
|
||||
msg, err := c.receiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch m := msg.(type) {
|
||||
case *pgproto3.AuthenticationGSSContinue:
|
||||
return m, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, ErrorResponseToPgError(m)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,169 +0,0 @@
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
)
|
||||
|
||||
var lruCount uint64
|
||||
|
||||
// LRU implements Cache with a Least Recently Used (LRU) cache.
|
||||
type LRU struct {
|
||||
conn *pgconn.PgConn
|
||||
mode int
|
||||
cap int
|
||||
prepareCount int
|
||||
m map[string]*list.Element
|
||||
l *list.List
|
||||
psNamePrefix string
|
||||
stmtsToClear []string
|
||||
}
|
||||
|
||||
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
|
||||
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
n := atomic.AddUint64(&lruCount, 1)
|
||||
|
||||
return &LRU{
|
||||
conn: conn,
|
||||
mode: mode,
|
||||
cap: cap,
|
||||
m: make(map[string]*list.Element),
|
||||
l: list.New(),
|
||||
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
if ctx != context.Background() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// flush an outstanding bad statements
|
||||
txStatus := c.conn.TxStatus()
|
||||
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
|
||||
for _, stmt := range c.stmtsToClear {
|
||||
err := c.clearStmt(ctx, stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if el, ok := c.m[sql]; ok {
|
||||
c.l.MoveToFront(el)
|
||||
return el.Value.(*pgconn.StatementDescription), nil
|
||||
}
|
||||
|
||||
if c.l.Len() == c.cap {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
psd, err := c.prepare(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
el := c.l.PushFront(psd)
|
||||
c.m[sql] = el
|
||||
|
||||
return psd, nil
|
||||
}
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
func (c *LRU) Clear(ctx context.Context) error {
|
||||
for c.l.Len() > 0 {
|
||||
err := c.removeOldest(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LRU) StatementErrored(sql string, err error) {
|
||||
pgErr, ok := err.(*pgconn.PgError)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1162
|
||||
//
|
||||
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
|
||||
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
|
||||
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
|
||||
// have so it should be safe.
|
||||
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
|
||||
if possibleInvalidCachedPlanError {
|
||||
c.stmtsToClear = append(c.stmtsToClear, sql)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
|
||||
elem, inMap := c.m[sql]
|
||||
if !inMap {
|
||||
// The statement probably fell off the back of the list. In that case, we've
|
||||
// ensured that it isn't in the cache, so we can declare victory.
|
||||
return nil
|
||||
}
|
||||
|
||||
c.l.Remove(elem)
|
||||
|
||||
psd := elem.Value.(*pgconn.StatementDescription)
|
||||
delete(c.m, psd.SQL)
|
||||
if c.mode == ModePrepare {
|
||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
func (c *LRU) Len() int {
|
||||
return c.l.Len()
|
||||
}
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
func (c *LRU) Cap() int {
|
||||
return c.cap
|
||||
}
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
func (c *LRU) Mode() int {
|
||||
return c.mode
|
||||
}
|
||||
|
||||
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
|
||||
var name string
|
||||
if c.mode == ModePrepare {
|
||||
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
|
||||
c.prepareCount += 1
|
||||
}
|
||||
|
||||
return c.conn.Prepare(ctx, name, sql, nil)
|
||||
}
|
||||
|
||||
func (c *LRU) removeOldest(ctx context.Context) error {
|
||||
oldest := c.l.Back()
|
||||
c.l.Remove(oldest)
|
||||
psd := oldest.Value.(*pgconn.StatementDescription)
|
||||
delete(c.m, psd.SQL)
|
||||
if c.mode == ModePrepare {
|
||||
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
|
||||
package stmtcache
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
)
|
||||
|
||||
const (
|
||||
ModePrepare = iota // Cache should prepare named statements.
|
||||
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
|
||||
)
|
||||
|
||||
// Cache prepares and caches prepared statement descriptions.
|
||||
type Cache interface {
|
||||
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
|
||||
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
|
||||
|
||||
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// StatementErrored informs the cache that the given statement resulted in an error when it
|
||||
// was last used against the database. In some cases, this will cause the cache to maer that
|
||||
// statement as bad. The bad statement will instead be flushed during the next call to Get
|
||||
// that occurs outside of a failed transaction.
|
||||
StatementErrored(sql string, err error)
|
||||
|
||||
// Len returns the number of cached prepared statement descriptions.
|
||||
Len() int
|
||||
|
||||
// Cap returns the maximum number of cached prepared statement descriptions.
|
||||
Cap() int
|
||||
|
||||
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
|
||||
Mode() int
|
||||
}
|
||||
|
||||
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
|
||||
// the maximum size of the cache.
|
||||
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
|
||||
mustBeValidMode(mode)
|
||||
mustBeValidCap(cap)
|
||||
|
||||
return NewLRU(conn, mode, cap)
|
||||
}
|
||||
|
||||
func mustBeValidMode(mode int) {
|
||||
if mode != ModePrepare && mode != ModeDescribe {
|
||||
panic("mode must be ModePrepare or ModeDescribe")
|
||||
}
|
||||
}
|
||||
|
||||
func mustBeValidCap(cap int) {
|
||||
if cap < 1 {
|
||||
panic("cache must have cap of >= 1")
|
||||
}
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.x
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
@ -1,22 +0,0 @@
|
||||
Copyright (c) 2019 Jack Christensen
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -1,11 +0,0 @@
|
||||
[![](https://godoc.org/github.com/jackc/pgio?status.svg)](https://godoc.org/github.com/jackc/pgio)
|
||||
[![Build Status](https://travis-ci.org/jackc/pgio.svg)](https://travis-ci.org/jackc/pgio)
|
||||
|
||||
# pgio
|
||||
|
||||
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||
|
||||
pgio provides functions for appending integers to a []byte while doing byte
|
||||
order conversion.
|
||||
|
||||
Extracted from original implementation in https://github.com/jackc/pgx.
|
@ -1,6 +0,0 @@
|
||||
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
|
||||
/*
|
||||
pgio provides functions for appending integers to a []byte while doing byte
|
||||
order conversion.
|
||||
*/
|
||||
package pgio
|
@ -1,40 +0,0 @@
|
||||
package pgio
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func AppendUint16(buf []byte, n uint16) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0)
|
||||
binary.BigEndian.PutUint16(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint32(buf []byte, n uint32) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint32(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendUint64(buf []byte, n uint64) []byte {
|
||||
wp := len(buf)
|
||||
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
|
||||
binary.BigEndian.PutUint64(buf[wp:], n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func AppendInt16(buf []byte, n int16) []byte {
|
||||
return AppendUint16(buf, uint16(n))
|
||||
}
|
||||
|
||||
func AppendInt32(buf []byte, n int32) []byte {
|
||||
return AppendUint32(buf, uint32(n))
|
||||
}
|
||||
|
||||
func AppendInt64(buf []byte, n int64) []byte {
|
||||
return AppendUint64(buf, uint64(n))
|
||||
}
|
||||
|
||||
func SetInt32(buf []byte, n int32) {
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.x
|
||||
- tip
|
||||
|
||||
matrix:
|
||||
allow_failures:
|
||||
- go: tip
|
@ -1,22 +0,0 @@
|
||||
Copyright (c) 2019 Jack Christensen
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -1,18 +0,0 @@
|
||||
[![](https://godoc.org/github.com/jackc/pgproto3?status.svg)](https://godoc.org/github.com/jackc/pgproto3)
|
||||
[![Build Status](https://travis-ci.org/jackc/pgproto3.svg)](https://travis-ci.org/jackc/pgproto3)
|
||||
|
||||
---
|
||||
|
||||
This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository.
|
||||
|
||||
---
|
||||
|
||||
# pgproto3
|
||||
|
||||
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||
|
||||
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
|
||||
|
||||
See example/pgfortune for a playful example of a fake PostgreSQL server.
|
||||
|
||||
Extracted from original implementation in https://github.com/jackc/pgx.
|
@ -1,52 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
||||
type AuthenticationCleartextPassword struct {
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationCleartextPassword) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
|
||||
if len(src) != 4 {
|
||||
return errors.New("bad authentication message size")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeCleartextPassword {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "AuthenticationCleartextPassword",
|
||||
})
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type AuthenticationGSS struct{}
|
||||
|
||||
func (a *AuthenticationGSS) Backend() {}
|
||||
|
||||
func (a *AuthenticationGSS) AuthenticationResponse() {}
|
||||
|
||||
func (a *AuthenticationGSS) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeGSS {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 4)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSS)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "AuthenticationGSS",
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,67 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type AuthenticationGSSContinue struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Backend() {}
|
||||
|
||||
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeGSSCont {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
a.Data = src[4:]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
|
||||
dst = append(dst, a.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "AuthenticationGSSContinue",
|
||||
Data: a.Data,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a.Data = msg.Data
|
||||
return nil
|
||||
}
|
@ -1,77 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
|
||||
type AuthenticationMD5Password struct {
|
||||
Salt [4]byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationMD5Password) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationMD5Password) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return errors.New("bad authentication message size")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeMD5Password {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
copy(dst.Salt[:], src[4:8])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 12)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
|
||||
dst = append(dst, src.Salt[:]...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Salt [4]byte
|
||||
}{
|
||||
Type: "AuthenticationMD5Password",
|
||||
Salt: src.Salt,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
Salt [4]byte
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Salt = msg.Salt
|
||||
return nil
|
||||
}
|
@ -1,52 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
||||
type AuthenticationOk struct {
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationOk) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationOk) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationOk) Decode(src []byte) error {
|
||||
if len(src) != 4 {
|
||||
return errors.New("bad authentication message size")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeOk {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationOk) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeOk)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "AuthenticationOK",
|
||||
})
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
|
||||
type AuthenticationSASL struct {
|
||||
AuthMechanisms []string
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationSASL) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationSASL) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationSASL) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeSASL {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
authMechanisms := src[4:]
|
||||
for len(authMechanisms) > 1 {
|
||||
idx := bytes.IndexByte(authMechanisms, 0)
|
||||
if idx > 0 {
|
||||
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||
authMechanisms = authMechanisms[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASL)
|
||||
|
||||
for _, s := range src.AuthMechanisms {
|
||||
dst = append(dst, []byte(s)...)
|
||||
dst = append(dst, 0)
|
||||
}
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
AuthMechanisms []string
|
||||
}{
|
||||
Type: "AuthenticationSASL",
|
||||
AuthMechanisms: src.AuthMechanisms,
|
||||
})
|
||||
}
|
@ -1,81 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
|
||||
type AuthenticationSASLContinue struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationSASLContinue) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeSASLContinue {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
dst.Data = src[4:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "AuthenticationSASLContinue",
|
||||
Data: string(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
@ -1,81 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
|
||||
type AuthenticationSASLFinal struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*AuthenticationSASLFinal) Backend() {}
|
||||
|
||||
// Backend identifies this message as an authentication response.
|
||||
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("authentication message too short")
|
||||
}
|
||||
|
||||
authType := binary.BigEndian.Uint32(src)
|
||||
|
||||
if authType != AuthTypeSASLFinal {
|
||||
return errors.New("bad auth type")
|
||||
}
|
||||
|
||||
dst.Data = src[4:]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'R')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
|
||||
|
||||
dst = append(dst, src.Data...)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "AuthenticationSASLFinal",
|
||||
Data: string(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
@ -1,213 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Backend acts as a server for the PostgreSQL wire protocol version 3.
|
||||
type Backend struct {
|
||||
cr ChunkReader
|
||||
w io.Writer
|
||||
|
||||
// Frontend message flyweights
|
||||
bind Bind
|
||||
cancelRequest CancelRequest
|
||||
_close Close
|
||||
copyFail CopyFail
|
||||
copyData CopyData
|
||||
copyDone CopyDone
|
||||
describe Describe
|
||||
execute Execute
|
||||
flush Flush
|
||||
functionCall FunctionCall
|
||||
gssEncRequest GSSEncRequest
|
||||
parse Parse
|
||||
query Query
|
||||
sslRequest SSLRequest
|
||||
startupMessage StartupMessage
|
||||
sync Sync
|
||||
terminate Terminate
|
||||
|
||||
bodyLen int
|
||||
msgType byte
|
||||
partialMsg bool
|
||||
authType uint32
|
||||
}
|
||||
|
||||
const (
|
||||
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
|
||||
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
|
||||
)
|
||||
|
||||
// NewBackend creates a new Backend.
|
||||
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
|
||||
return &Backend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the frontend.
|
||||
func (b *Backend) Send(msg BackendMessage) error {
|
||||
_, err := b.w.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
|
||||
// because the initial connection message is "special" and does not include the message type as the first byte. This
|
||||
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
|
||||
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
|
||||
buf, err := b.cr.Next(4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
|
||||
|
||||
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
|
||||
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
|
||||
}
|
||||
|
||||
buf, err = b.cr.Next(msgSize)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
code := binary.BigEndian.Uint32(buf)
|
||||
|
||||
switch code {
|
||||
case ProtocolVersionNumber:
|
||||
err = b.startupMessage.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.startupMessage, nil
|
||||
case sslRequestNumber:
|
||||
err = b.sslRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.sslRequest, nil
|
||||
case cancelRequestCode:
|
||||
err = b.cancelRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.cancelRequest, nil
|
||||
case gssEncReqNumber:
|
||||
err = b.gssEncRequest.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &b.gssEncRequest, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown startup message code: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
|
||||
func (b *Backend) Receive() (FrontendMessage, error) {
|
||||
if !b.partialMsg {
|
||||
header, err := b.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
b.msgType = header[0]
|
||||
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
b.partialMsg = true
|
||||
if b.bodyLen < 0 {
|
||||
return nil, errors.New("invalid message with negative body length received")
|
||||
}
|
||||
}
|
||||
|
||||
var msg FrontendMessage
|
||||
switch b.msgType {
|
||||
case 'B':
|
||||
msg = &b.bind
|
||||
case 'C':
|
||||
msg = &b._close
|
||||
case 'D':
|
||||
msg = &b.describe
|
||||
case 'E':
|
||||
msg = &b.execute
|
||||
case 'F':
|
||||
msg = &b.functionCall
|
||||
case 'f':
|
||||
msg = &b.copyFail
|
||||
case 'd':
|
||||
msg = &b.copyData
|
||||
case 'c':
|
||||
msg = &b.copyDone
|
||||
case 'H':
|
||||
msg = &b.flush
|
||||
case 'P':
|
||||
msg = &b.parse
|
||||
case 'p':
|
||||
switch b.authType {
|
||||
case AuthTypeSASL:
|
||||
msg = &SASLInitialResponse{}
|
||||
case AuthTypeSASLContinue:
|
||||
msg = &SASLResponse{}
|
||||
case AuthTypeSASLFinal:
|
||||
msg = &SASLResponse{}
|
||||
case AuthTypeGSS, AuthTypeGSSCont:
|
||||
msg = &GSSResponse{}
|
||||
case AuthTypeCleartextPassword, AuthTypeMD5Password:
|
||||
fallthrough
|
||||
default:
|
||||
// to maintain backwards compatability
|
||||
msg = &PasswordMessage{}
|
||||
}
|
||||
case 'Q':
|
||||
msg = &b.query
|
||||
case 'S':
|
||||
msg = &b.sync
|
||||
case 'X':
|
||||
msg = &b.terminate
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
|
||||
}
|
||||
|
||||
msgBody, err := b.cr.Next(b.bodyLen)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
b.partialMsg = false
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// SetAuthType sets the authentication type in the backend.
|
||||
// Since multiple message types can start with 'p', SetAuthType allows
|
||||
// contextual identification of FrontendMessages. For example, in the
|
||||
// PG message flow documentation for PasswordMessage:
|
||||
//
|
||||
// Byte1('p')
|
||||
//
|
||||
// Identifies the message as a password response. Note that this is also used for
|
||||
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
|
||||
// the context.
|
||||
//
|
||||
// Since the Frontend does not know about the state of a backend, it is important
|
||||
// to call SetAuthType() after an authentication request is received by the Frontend.
|
||||
func (b *Backend) SetAuthType(authType uint32) error {
|
||||
switch authType {
|
||||
case AuthTypeOk,
|
||||
AuthTypeCleartextPassword,
|
||||
AuthTypeMD5Password,
|
||||
AuthTypeSCMCreds,
|
||||
AuthTypeGSS,
|
||||
AuthTypeGSSCont,
|
||||
AuthTypeSSPI,
|
||||
AuthTypeSASL,
|
||||
AuthTypeSASLContinue,
|
||||
AuthTypeSASLFinal:
|
||||
b.authType = authType
|
||||
default:
|
||||
return fmt.Errorf("authType not recognized: %d", authType)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,51 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type BackendKeyData struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*BackendKeyData) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *BackendKeyData) Decode(src []byte) error {
|
||||
if len(src) != 8 {
|
||||
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BackendKeyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'K')
|
||||
dst = pgio.AppendUint32(dst, 12)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "BackendKeyData",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
@ -1,37 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type BigEndianBuf [8]byte
|
||||
|
||||
func (b BigEndianBuf) Int16(n int16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint16(n uint16) []byte {
|
||||
buf := b[0:2]
|
||||
binary.BigEndian.PutUint16(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int32(n int32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, uint32(n))
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Uint32(n uint32) []byte {
|
||||
buf := b[0:4]
|
||||
binary.BigEndian.PutUint32(buf, n)
|
||||
return buf
|
||||
}
|
||||
|
||||
func (b BigEndianBuf) Int64(n int64) []byte {
|
||||
buf := b[0:8]
|
||||
binary.BigEndian.PutUint64(buf, uint64(n))
|
||||
return buf
|
||||
}
|
@ -1,216 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type Bind struct {
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters [][]byte
|
||||
ResultFormatCodes []int16
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Bind) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Bind) Decode(src []byte) error {
|
||||
*dst = Bind{}
|
||||
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.DestinationPortal = string(src[:idx])
|
||||
rp := idx + 1
|
||||
|
||||
idx = bytes.IndexByte(src[rp:], 0)
|
||||
if idx < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
dst.PreparedStatement = string(src[rp : rp+idx])
|
||||
rp += idx + 1
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterFormatCodeCount > 0 {
|
||||
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
|
||||
|
||||
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < parameterFormatCodeCount; i++ {
|
||||
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
if parameterCount > 0 {
|
||||
dst.Parameters = make([][]byte, parameterCount)
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
|
||||
dst.Parameters[i] = src[rp : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(src[rp:]) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
|
||||
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
|
||||
return &invalidMessageFormatErr{messageType: "Bind"}
|
||||
}
|
||||
for i := 0; i < resultFormatCodeCount; i++ {
|
||||
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Bind) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'B')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.DestinationPortal...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.PreparedStatement...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
|
||||
for _, fc := range src.ParameterFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
|
||||
for _, p := range src.Parameters {
|
||||
if p == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(p)))
|
||||
dst = append(dst, p...)
|
||||
}
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
|
||||
for _, fc := range src.ResultFormatCodes {
|
||||
dst = pgio.AppendInt16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Bind) MarshalJSON() ([]byte, error) {
|
||||
formattedParameters := make([]map[string]string, len(src.Parameters))
|
||||
for i, p := range src.Parameters {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
textFormat := true
|
||||
if len(src.ParameterFormatCodes) == 1 {
|
||||
textFormat = src.ParameterFormatCodes[0] == 0
|
||||
} else if len(src.ParameterFormatCodes) > 1 {
|
||||
textFormat = src.ParameterFormatCodes[i] == 0
|
||||
}
|
||||
|
||||
if textFormat {
|
||||
formattedParameters[i] = map[string]string{"text": string(p)}
|
||||
} else {
|
||||
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters []map[string]string
|
||||
ResultFormatCodes []int16
|
||||
}{
|
||||
Type: "Bind",
|
||||
DestinationPortal: src.DestinationPortal,
|
||||
PreparedStatement: src.PreparedStatement,
|
||||
ParameterFormatCodes: src.ParameterFormatCodes,
|
||||
Parameters: formattedParameters,
|
||||
ResultFormatCodes: src.ResultFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Bind) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
DestinationPortal string
|
||||
PreparedStatement string
|
||||
ParameterFormatCodes []int16
|
||||
Parameters []map[string]string
|
||||
ResultFormatCodes []int16
|
||||
}
|
||||
err := json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.DestinationPortal = msg.DestinationPortal
|
||||
dst.PreparedStatement = msg.PreparedStatement
|
||||
dst.ParameterFormatCodes = msg.ParameterFormatCodes
|
||||
dst.Parameters = make([][]byte, len(msg.Parameters))
|
||||
dst.ResultFormatCodes = msg.ResultFormatCodes
|
||||
for n, parameter := range msg.Parameters {
|
||||
dst.Parameters[n], err = getValueFromJSON(parameter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot get param %d: %w", n, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type BindComplete struct{}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*BindComplete) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *BindComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *BindComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '2', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src BindComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "BindComplete",
|
||||
})
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
const cancelRequestCode = 80877102
|
||||
|
||||
type CancelRequest struct {
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CancelRequest) Frontend() {}
|
||||
|
||||
func (dst *CancelRequest) Decode(src []byte) error {
|
||||
if len(src) != 12 {
|
||||
return errors.New("bad cancel request size")
|
||||
}
|
||||
|
||||
requestCode := binary.BigEndian.Uint32(src)
|
||||
|
||||
if requestCode != cancelRequestCode {
|
||||
return errors.New("bad cancel request code")
|
||||
}
|
||||
|
||||
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
|
||||
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *CancelRequest) Encode(dst []byte) []byte {
|
||||
dst = pgio.AppendInt32(dst, 16)
|
||||
dst = pgio.AppendInt32(dst, cancelRequestCode)
|
||||
dst = pgio.AppendUint32(dst, src.ProcessID)
|
||||
dst = pgio.AppendUint32(dst, src.SecretKey)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CancelRequest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProcessID uint32
|
||||
SecretKey uint32
|
||||
}{
|
||||
Type: "CancelRequest",
|
||||
ProcessID: src.ProcessID,
|
||||
SecretKey: src.SecretKey,
|
||||
})
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/jackc/chunkreader/v2"
|
||||
)
|
||||
|
||||
// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package.
|
||||
type ChunkReader interface {
|
||||
// Next returns buf filled with the next n bytes. If an error (including a partial read) occurs,
|
||||
// buf must be nil. Next must preserve any partially read data. Next must not reuse buf.
|
||||
Next(n int) (buf []byte, err error)
|
||||
}
|
||||
|
||||
// NewChunkReader creates and returns a new default ChunkReader.
|
||||
func NewChunkReader(r io.Reader) ChunkReader {
|
||||
return chunkreader.New(r)
|
||||
}
|
@ -1,89 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type Close struct {
|
||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||
Name string
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Close) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Close) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.ObjectType = src[0]
|
||||
rp := 1
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx != len(src[rp:])-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Close"}
|
||||
}
|
||||
|
||||
dst.Name = string(src[rp : len(src)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Close) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Close) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ObjectType string
|
||||
Name string
|
||||
}{
|
||||
Type: "Close",
|
||||
ObjectType: string(src.ObjectType),
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Close) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ObjectType string
|
||||
Name string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.ObjectType) != 1 {
|
||||
return errors.New("invalid length for Close.ObjectType")
|
||||
}
|
||||
|
||||
dst.ObjectType = byte(msg.ObjectType[0])
|
||||
dst.Name = msg.Name
|
||||
return nil
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CloseComplete struct{}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CloseComplete) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CloseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CloseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '3', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CloseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "CloseComplete",
|
||||
})
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type CommandComplete struct {
|
||||
CommandTag []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CommandComplete) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CommandComplete) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
||||
}
|
||||
|
||||
dst.CommandTag = src[:idx]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CommandComplete) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'C')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.CommandTag...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CommandComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
CommandTag string
|
||||
}{
|
||||
Type: "CommandComplete",
|
||||
CommandTag: string(src.CommandTag),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
CommandTag string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.CommandTag = []byte(msg.CommandTag)
|
||||
return nil
|
||||
}
|
@ -1,95 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type CopyBothResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CopyBothResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyBothResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyBothResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'W')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
dst = append(dst, src.OverallFormat)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyBothResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
OverallFormat string
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.OverallFormat) != 1 {
|
||||
return errors.New("invalid length for CopyBothResponse.OverallFormat")
|
||||
}
|
||||
|
||||
dst.OverallFormat = msg.OverallFormat[0]
|
||||
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||
return nil
|
||||
}
|
@ -1,62 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type CopyData struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CopyData) Backend() {}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CopyData) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyData) Decode(src []byte) error {
|
||||
dst.Data = src
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyData) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'd')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
|
||||
dst = append(dst, src.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data string
|
||||
}{
|
||||
Type: "CopyData",
|
||||
Data: hex.EncodeToString(src.Data),
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyData) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Data string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Data = []byte(msg.Data)
|
||||
return nil
|
||||
}
|
@ -1,38 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type CopyDone struct {
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CopyDone) Backend() {}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CopyDone) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyDone) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyDone) Encode(dst []byte) []byte {
|
||||
return append(dst, 'c', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyDone) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "CopyDone",
|
||||
})
|
||||
}
|
@ -1,53 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type CopyFail struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*CopyFail) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyFail) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyFail"}
|
||||
}
|
||||
|
||||
dst.Message = string(src[:idx])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyFail) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'f')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Message...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyFail) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Message string
|
||||
}{
|
||||
Type: "CopyFail",
|
||||
Message: src.Message,
|
||||
})
|
||||
}
|
@ -1,96 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type CopyInResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*CopyInResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyInResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyInResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'G')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyInResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyInResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
OverallFormat string
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.OverallFormat) != 1 {
|
||||
return errors.New("invalid length for CopyInResponse.OverallFormat")
|
||||
}
|
||||
|
||||
dst.OverallFormat = msg.OverallFormat[0]
|
||||
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||
return nil
|
||||
}
|
@ -1,96 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type CopyOutResponse struct {
|
||||
OverallFormat byte
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
|
||||
func (*CopyOutResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *CopyOutResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 3 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
overallFormat := buf.Next(1)[0]
|
||||
|
||||
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
if buf.Len() != columnCount*2 {
|
||||
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
|
||||
}
|
||||
|
||||
columnFormatCodes := make([]uint16, columnCount)
|
||||
for i := 0; i < columnCount; i++ {
|
||||
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
|
||||
}
|
||||
|
||||
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *CopyOutResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'H')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.OverallFormat)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
|
||||
for _, fc := range src.ColumnFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, fc)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ColumnFormatCodes []uint16
|
||||
}{
|
||||
Type: "CopyOutResponse",
|
||||
ColumnFormatCodes: src.ColumnFormatCodes,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
OverallFormat string
|
||||
ColumnFormatCodes []uint16
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msg.OverallFormat) != 1 {
|
||||
return errors.New("invalid length for CopyOutResponse.OverallFormat")
|
||||
}
|
||||
|
||||
dst.OverallFormat = msg.OverallFormat[0]
|
||||
dst.ColumnFormatCodes = msg.ColumnFormatCodes
|
||||
return nil
|
||||
}
|
@ -1,142 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type DataRow struct {
|
||||
Values [][]byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*DataRow) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *DataRow) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
rp := 0
|
||||
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
|
||||
// If the capacity of the values slice is too small OR substantially too
|
||||
// large reallocate. This is too avoid one row with many columns from
|
||||
// permanently allocating memory.
|
||||
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
|
||||
newCap := 32
|
||||
if newCap < fieldCount {
|
||||
newCap = fieldCount
|
||||
}
|
||||
dst.Values = make([][]byte, fieldCount, newCap)
|
||||
} else {
|
||||
dst.Values = dst.Values[:fieldCount]
|
||||
}
|
||||
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
if len(src[rp:]) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
dst.Values[i] = nil
|
||||
} else {
|
||||
if len(src[rp:]) < msgSize {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize]
|
||||
rp += msgSize
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *DataRow) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
|
||||
for _, v := range src.Values {
|
||||
if v == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
continue
|
||||
}
|
||||
|
||||
dst = pgio.AppendInt32(dst, int32(len(v)))
|
||||
dst = append(dst, v...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src DataRow) MarshalJSON() ([]byte, error) {
|
||||
formattedValues := make([]map[string]string, len(src.Values))
|
||||
for i, v := range src.Values {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var hasNonPrintable bool
|
||||
for _, b := range v {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
|
||||
} else {
|
||||
formattedValues[i] = map[string]string{"text": string(v)}
|
||||
}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Values []map[string]string
|
||||
}{
|
||||
Type: "DataRow",
|
||||
Values: formattedValues,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *DataRow) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Values []map[string]string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Values = make([][]byte, len(msg.Values))
|
||||
for n, parameter := range msg.Values {
|
||||
var err error
|
||||
dst.Values[n], err = getValueFromJSON(parameter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,88 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type Describe struct {
|
||||
ObjectType byte // 'S' = prepared statement, 'P' = portal
|
||||
Name string
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Describe) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Describe) Decode(src []byte) error {
|
||||
if len(src) < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||
}
|
||||
|
||||
dst.ObjectType = src[0]
|
||||
rp := 1
|
||||
|
||||
idx := bytes.IndexByte(src[rp:], 0)
|
||||
if idx != len(src[rp:])-1 {
|
||||
return &invalidMessageFormatErr{messageType: "Describe"}
|
||||
}
|
||||
|
||||
dst.Name = string(src[rp : len(src)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Describe) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'D')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.ObjectType)
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Describe) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ObjectType string
|
||||
Name string
|
||||
}{
|
||||
Type: "Describe",
|
||||
ObjectType: string(src.ObjectType),
|
||||
Name: src.Name,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *Describe) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
ObjectType string
|
||||
Name string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(msg.ObjectType) != 1 {
|
||||
return errors.New("invalid length for Describe.ObjectType")
|
||||
}
|
||||
|
||||
dst.ObjectType = byte(msg.ObjectType[0])
|
||||
dst.Name = msg.Name
|
||||
return nil
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
|
||||
//
|
||||
// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages.
|
||||
package pgproto3
|
@ -1,34 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type EmptyQueryResponse struct{}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*EmptyQueryResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *EmptyQueryResponse) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, 'I', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "EmptyQueryResponse",
|
||||
})
|
||||
}
|
@ -1,334 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Severity string
|
||||
SeverityUnlocalized string // only in 9.6 and greater
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
|
||||
UnknownFields map[byte]string
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*ErrorResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *ErrorResponse) Decode(src []byte) error {
|
||||
*dst = ErrorResponse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
for {
|
||||
k, err := buf.ReadByte()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if k == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
vb, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
v := string(vb[:len(vb)-1])
|
||||
|
||||
switch k {
|
||||
case 'S':
|
||||
dst.Severity = v
|
||||
case 'V':
|
||||
dst.SeverityUnlocalized = v
|
||||
case 'C':
|
||||
dst.Code = v
|
||||
case 'M':
|
||||
dst.Message = v
|
||||
case 'D':
|
||||
dst.Detail = v
|
||||
case 'H':
|
||||
dst.Hint = v
|
||||
case 'P':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Position = int32(n)
|
||||
case 'p':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.InternalPosition = int32(n)
|
||||
case 'q':
|
||||
dst.InternalQuery = v
|
||||
case 'W':
|
||||
dst.Where = v
|
||||
case 's':
|
||||
dst.SchemaName = v
|
||||
case 't':
|
||||
dst.TableName = v
|
||||
case 'c':
|
||||
dst.ColumnName = v
|
||||
case 'd':
|
||||
dst.DataTypeName = v
|
||||
case 'n':
|
||||
dst.ConstraintName = v
|
||||
case 'F':
|
||||
dst.File = v
|
||||
case 'L':
|
||||
s := v
|
||||
n, _ := strconv.ParseInt(s, 10, 32)
|
||||
dst.Line = int32(n)
|
||||
case 'R':
|
||||
dst.Routine = v
|
||||
|
||||
default:
|
||||
if dst.UnknownFields == nil {
|
||||
dst.UnknownFields = make(map[byte]string)
|
||||
}
|
||||
dst.UnknownFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ErrorResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, src.marshalBinary('E')...)
|
||||
}
|
||||
|
||||
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
|
||||
var bigEndian BigEndianBuf
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(typeByte)
|
||||
buf.Write(bigEndian.Uint32(0))
|
||||
|
||||
if src.Severity != "" {
|
||||
buf.WriteByte('S')
|
||||
buf.WriteString(src.Severity)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.SeverityUnlocalized != "" {
|
||||
buf.WriteByte('V')
|
||||
buf.WriteString(src.SeverityUnlocalized)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Code != "" {
|
||||
buf.WriteByte('C')
|
||||
buf.WriteString(src.Code)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Message != "" {
|
||||
buf.WriteByte('M')
|
||||
buf.WriteString(src.Message)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Detail != "" {
|
||||
buf.WriteByte('D')
|
||||
buf.WriteString(src.Detail)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Hint != "" {
|
||||
buf.WriteByte('H')
|
||||
buf.WriteString(src.Hint)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Position != 0 {
|
||||
buf.WriteByte('P')
|
||||
buf.WriteString(strconv.Itoa(int(src.Position)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalPosition != 0 {
|
||||
buf.WriteByte('p')
|
||||
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.InternalQuery != "" {
|
||||
buf.WriteByte('q')
|
||||
buf.WriteString(src.InternalQuery)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Where != "" {
|
||||
buf.WriteByte('W')
|
||||
buf.WriteString(src.Where)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.SchemaName != "" {
|
||||
buf.WriteByte('s')
|
||||
buf.WriteString(src.SchemaName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.TableName != "" {
|
||||
buf.WriteByte('t')
|
||||
buf.WriteString(src.TableName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ColumnName != "" {
|
||||
buf.WriteByte('c')
|
||||
buf.WriteString(src.ColumnName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.DataTypeName != "" {
|
||||
buf.WriteByte('d')
|
||||
buf.WriteString(src.DataTypeName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.ConstraintName != "" {
|
||||
buf.WriteByte('n')
|
||||
buf.WriteString(src.ConstraintName)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.File != "" {
|
||||
buf.WriteByte('F')
|
||||
buf.WriteString(src.File)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Line != 0 {
|
||||
buf.WriteByte('L')
|
||||
buf.WriteString(strconv.Itoa(int(src.Line)))
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
if src.Routine != "" {
|
||||
buf.WriteByte('R')
|
||||
buf.WriteString(src.Routine)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
|
||||
for k, v := range src.UnknownFields {
|
||||
buf.WriteByte(k)
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte(0)
|
||||
}
|
||||
|
||||
buf.WriteByte(0)
|
||||
|
||||
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src ErrorResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Severity string
|
||||
SeverityUnlocalized string // only in 9.6 and greater
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
|
||||
UnknownFields map[byte]string
|
||||
}{
|
||||
Type: "ErrorResponse",
|
||||
Severity: src.Severity,
|
||||
SeverityUnlocalized: src.SeverityUnlocalized,
|
||||
Code: src.Code,
|
||||
Message: src.Message,
|
||||
Detail: src.Detail,
|
||||
Hint: src.Hint,
|
||||
Position: src.Position,
|
||||
InternalPosition: src.InternalPosition,
|
||||
InternalQuery: src.InternalQuery,
|
||||
Where: src.Where,
|
||||
SchemaName: src.SchemaName,
|
||||
TableName: src.TableName,
|
||||
ColumnName: src.ColumnName,
|
||||
DataTypeName: src.DataTypeName,
|
||||
ConstraintName: src.ConstraintName,
|
||||
File: src.File,
|
||||
Line: src.Line,
|
||||
Routine: src.Routine,
|
||||
UnknownFields: src.UnknownFields,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Type string
|
||||
Severity string
|
||||
SeverityUnlocalized string // only in 9.6 and greater
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
|
||||
UnknownFields map[byte]string
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dst.Severity = msg.Severity
|
||||
dst.SeverityUnlocalized = msg.SeverityUnlocalized
|
||||
dst.Code = msg.Code
|
||||
dst.Message = msg.Message
|
||||
dst.Detail = msg.Detail
|
||||
dst.Hint = msg.Hint
|
||||
dst.Position = msg.Position
|
||||
dst.InternalPosition = msg.InternalPosition
|
||||
dst.InternalQuery = msg.InternalQuery
|
||||
dst.Where = msg.Where
|
||||
dst.SchemaName = msg.SchemaName
|
||||
dst.TableName = msg.TableName
|
||||
dst.ColumnName = msg.ColumnName
|
||||
dst.DataTypeName = msg.DataTypeName
|
||||
dst.ConstraintName = msg.ConstraintName
|
||||
dst.File = msg.File
|
||||
dst.Line = msg.Line
|
||||
dst.Routine = msg.Routine
|
||||
|
||||
dst.UnknownFields = msg.UnknownFields
|
||||
|
||||
return nil
|
||||
}
|
@ -1,65 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type Execute struct {
|
||||
Portal string
|
||||
MaxRows uint32
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Execute) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Execute) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Portal = string(b[:len(b)-1])
|
||||
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Execute"}
|
||||
}
|
||||
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Execute) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'E')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Portal...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.MaxRows)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Execute) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Portal string
|
||||
MaxRows uint32
|
||||
}{
|
||||
Type: "Execute",
|
||||
Portal: src.Portal,
|
||||
MaxRows: src.MaxRows,
|
||||
})
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Flush struct{}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Flush) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Flush) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Flush) Encode(dst []byte) []byte {
|
||||
return append(dst, 'H', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Flush) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "Flush",
|
||||
})
|
||||
}
|
@ -1,206 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
|
||||
type Frontend struct {
|
||||
cr ChunkReader
|
||||
w io.Writer
|
||||
|
||||
// Backend message flyweights
|
||||
authenticationOk AuthenticationOk
|
||||
authenticationCleartextPassword AuthenticationCleartextPassword
|
||||
authenticationMD5Password AuthenticationMD5Password
|
||||
authenticationGSS AuthenticationGSS
|
||||
authenticationGSSContinue AuthenticationGSSContinue
|
||||
authenticationSASL AuthenticationSASL
|
||||
authenticationSASLContinue AuthenticationSASLContinue
|
||||
authenticationSASLFinal AuthenticationSASLFinal
|
||||
backendKeyData BackendKeyData
|
||||
bindComplete BindComplete
|
||||
closeComplete CloseComplete
|
||||
commandComplete CommandComplete
|
||||
copyBothResponse CopyBothResponse
|
||||
copyData CopyData
|
||||
copyInResponse CopyInResponse
|
||||
copyOutResponse CopyOutResponse
|
||||
copyDone CopyDone
|
||||
dataRow DataRow
|
||||
emptyQueryResponse EmptyQueryResponse
|
||||
errorResponse ErrorResponse
|
||||
functionCallResponse FunctionCallResponse
|
||||
noData NoData
|
||||
noticeResponse NoticeResponse
|
||||
notificationResponse NotificationResponse
|
||||
parameterDescription ParameterDescription
|
||||
parameterStatus ParameterStatus
|
||||
parseComplete ParseComplete
|
||||
readyForQuery ReadyForQuery
|
||||
rowDescription RowDescription
|
||||
portalSuspended PortalSuspended
|
||||
|
||||
bodyLen int
|
||||
msgType byte
|
||||
partialMsg bool
|
||||
authType uint32
|
||||
}
|
||||
|
||||
// NewFrontend creates a new Frontend.
|
||||
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
|
||||
return &Frontend{cr: cr, w: w}
|
||||
}
|
||||
|
||||
// Send sends a message to the backend.
|
||||
func (f *Frontend) Send(msg FrontendMessage) error {
|
||||
_, err := f.w.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
func translateEOFtoErrUnexpectedEOF(err error) error {
|
||||
if err == io.EOF {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
|
||||
func (f *Frontend) Receive() (BackendMessage, error) {
|
||||
if !f.partialMsg {
|
||||
header, err := f.cr.Next(5)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
f.msgType = header[0]
|
||||
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
|
||||
f.partialMsg = true
|
||||
if f.bodyLen < 0 {
|
||||
return nil, errors.New("invalid message with negative body length received")
|
||||
}
|
||||
}
|
||||
|
||||
msgBody, err := f.cr.Next(f.bodyLen)
|
||||
if err != nil {
|
||||
return nil, translateEOFtoErrUnexpectedEOF(err)
|
||||
}
|
||||
|
||||
f.partialMsg = false
|
||||
|
||||
var msg BackendMessage
|
||||
switch f.msgType {
|
||||
case '1':
|
||||
msg = &f.parseComplete
|
||||
case '2':
|
||||
msg = &f.bindComplete
|
||||
case '3':
|
||||
msg = &f.closeComplete
|
||||
case 'A':
|
||||
msg = &f.notificationResponse
|
||||
case 'c':
|
||||
msg = &f.copyDone
|
||||
case 'C':
|
||||
msg = &f.commandComplete
|
||||
case 'd':
|
||||
msg = &f.copyData
|
||||
case 'D':
|
||||
msg = &f.dataRow
|
||||
case 'E':
|
||||
msg = &f.errorResponse
|
||||
case 'G':
|
||||
msg = &f.copyInResponse
|
||||
case 'H':
|
||||
msg = &f.copyOutResponse
|
||||
case 'I':
|
||||
msg = &f.emptyQueryResponse
|
||||
case 'K':
|
||||
msg = &f.backendKeyData
|
||||
case 'n':
|
||||
msg = &f.noData
|
||||
case 'N':
|
||||
msg = &f.noticeResponse
|
||||
case 'R':
|
||||
var err error
|
||||
msg, err = f.findAuthenticationMessageType(msgBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case 's':
|
||||
msg = &f.portalSuspended
|
||||
case 'S':
|
||||
msg = &f.parameterStatus
|
||||
case 't':
|
||||
msg = &f.parameterDescription
|
||||
case 'T':
|
||||
msg = &f.rowDescription
|
||||
case 'V':
|
||||
msg = &f.functionCallResponse
|
||||
case 'W':
|
||||
msg = &f.copyBothResponse
|
||||
case 'Z':
|
||||
msg = &f.readyForQuery
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
|
||||
}
|
||||
|
||||
err = msg.Decode(msgBody)
|
||||
return msg, err
|
||||
}
|
||||
|
||||
// Authentication message type constants.
|
||||
// See src/include/libpq/pqcomm.h for all
|
||||
// constants.
|
||||
const (
|
||||
AuthTypeOk = 0
|
||||
AuthTypeCleartextPassword = 3
|
||||
AuthTypeMD5Password = 5
|
||||
AuthTypeSCMCreds = 6
|
||||
AuthTypeGSS = 7
|
||||
AuthTypeGSSCont = 8
|
||||
AuthTypeSSPI = 9
|
||||
AuthTypeSASL = 10
|
||||
AuthTypeSASLContinue = 11
|
||||
AuthTypeSASLFinal = 12
|
||||
)
|
||||
|
||||
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
|
||||
if len(src) < 4 {
|
||||
return nil, errors.New("authentication message too short")
|
||||
}
|
||||
f.authType = binary.BigEndian.Uint32(src[:4])
|
||||
|
||||
switch f.authType {
|
||||
case AuthTypeOk:
|
||||
return &f.authenticationOk, nil
|
||||
case AuthTypeCleartextPassword:
|
||||
return &f.authenticationCleartextPassword, nil
|
||||
case AuthTypeMD5Password:
|
||||
return &f.authenticationMD5Password, nil
|
||||
case AuthTypeSCMCreds:
|
||||
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
|
||||
case AuthTypeGSS:
|
||||
return &f.authenticationGSS, nil
|
||||
case AuthTypeGSSCont:
|
||||
return &f.authenticationGSSContinue, nil
|
||||
case AuthTypeSSPI:
|
||||
return nil, errors.New("AuthTypeSSPI is unimplemented")
|
||||
case AuthTypeSASL:
|
||||
return &f.authenticationSASL, nil
|
||||
case AuthTypeSASLContinue:
|
||||
return &f.authenticationSASLContinue, nil
|
||||
case AuthTypeSASLFinal:
|
||||
return &f.authenticationSASLFinal, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthType returns the authType used in the current state of the frontend.
|
||||
// See SetAuthType for more information.
|
||||
func (f *Frontend) GetAuthType() uint32 {
|
||||
return f.authType
|
||||
}
|
@ -1,94 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type FunctionCall struct {
|
||||
Function uint32
|
||||
ArgFormatCodes []uint16
|
||||
Arguments [][]byte
|
||||
ResultFormatCode uint16
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*FunctionCall) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *FunctionCall) Decode(src []byte) error {
|
||||
*dst = FunctionCall{}
|
||||
rp := 0
|
||||
// Specifies the object ID of the function to call.
|
||||
dst.Function = binary.BigEndian.Uint32(src[rp:])
|
||||
rp += 4
|
||||
// The number of argument format codes that follow (denoted C below).
|
||||
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
|
||||
// or one, in which case the specified format code is applied to all arguments;
|
||||
// or it can equal the actual number of arguments.
|
||||
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
argumentCodes := make([]uint16, nArgumentCodes)
|
||||
for i := 0; i < nArgumentCodes; i++ {
|
||||
// The argument format codes. Each must presently be zero (text) or one (binary).
|
||||
ac := binary.BigEndian.Uint16(src[rp:])
|
||||
if ac != 0 && ac != 1 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
}
|
||||
argumentCodes[i] = ac
|
||||
rp += 2
|
||||
}
|
||||
dst.ArgFormatCodes = argumentCodes
|
||||
|
||||
// Specifies the number of arguments being supplied to the function.
|
||||
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
|
||||
rp += 2
|
||||
arguments := make([][]byte, nArguments)
|
||||
for i := 0; i < nArguments; i++ {
|
||||
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
|
||||
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
|
||||
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
if argumentLength == -1 {
|
||||
arguments[i] = nil
|
||||
} else {
|
||||
// The value of the argument, in the format indicated by the associated format code. n is the above length.
|
||||
argumentValue := src[rp : rp+argumentLength]
|
||||
rp += argumentLength
|
||||
arguments[i] = argumentValue
|
||||
}
|
||||
}
|
||||
dst.Arguments = arguments
|
||||
// The format code for the function result. Must presently be zero (text) or one (binary).
|
||||
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
|
||||
if resultFormatCode != 0 && resultFormatCode != 1 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCall"}
|
||||
}
|
||||
dst.ResultFormatCode = resultFormatCode
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *FunctionCall) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'F')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
|
||||
dst = pgio.AppendUint32(dst, src.Function)
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
|
||||
for _, argFormatCode := range src.ArgFormatCodes {
|
||||
dst = pgio.AppendUint16(dst, argFormatCode)
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
|
||||
for _, argument := range src.Arguments {
|
||||
if argument == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
} else {
|
||||
dst = pgio.AppendInt32(dst, int32(len(argument)))
|
||||
dst = append(dst, argument...)
|
||||
}
|
||||
}
|
||||
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
return dst
|
||||
}
|
@ -1,101 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type FunctionCallResponse struct {
|
||||
Result []byte
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*FunctionCallResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *FunctionCallResponse) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
rp := 0
|
||||
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
|
||||
rp += 4
|
||||
|
||||
if resultSize == -1 {
|
||||
dst.Result = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(src[rp:]) != resultSize {
|
||||
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
|
||||
}
|
||||
|
||||
dst.Result = src[rp:]
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'V')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
if src.Result == nil {
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
} else {
|
||||
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
|
||||
dst = append(dst, src.Result...)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
|
||||
var formattedValue map[string]string
|
||||
var hasNonPrintable bool
|
||||
for _, b := range src.Result {
|
||||
if b < 32 {
|
||||
hasNonPrintable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if hasNonPrintable {
|
||||
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
|
||||
} else {
|
||||
formattedValue = map[string]string{"text": string(src.Result)}
|
||||
}
|
||||
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Result map[string]string
|
||||
}{
|
||||
Type: "FunctionCallResponse",
|
||||
Result: formattedValue,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
|
||||
// Ignore null, like in the main JSON package.
|
||||
if string(data) == "null" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var msg struct {
|
||||
Result map[string]string
|
||||
}
|
||||
err := json.Unmarshal(data, &msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Result, err = getValueFromJSON(msg.Result)
|
||||
return err
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
const gssEncReqNumber = 80877104
|
||||
|
||||
type GSSEncRequest struct {
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*GSSEncRequest) Frontend() {}
|
||||
|
||||
func (dst *GSSEncRequest) Decode(src []byte) error {
|
||||
if len(src) < 4 {
|
||||
return errors.New("gss encoding request too short")
|
||||
}
|
||||
|
||||
requestCode := binary.BigEndian.Uint32(src)
|
||||
|
||||
if requestCode != gssEncReqNumber {
|
||||
return errors.New("bad gss encoding request code")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 4 byte message length.
|
||||
func (src *GSSEncRequest) Encode(dst []byte) []byte {
|
||||
dst = pgio.AppendInt32(dst, 8)
|
||||
dst = pgio.AppendInt32(dst, gssEncReqNumber)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src GSSEncRequest) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ProtocolVersion uint32
|
||||
Parameters map[string]string
|
||||
}{
|
||||
Type: "GSSEncRequest",
|
||||
})
|
||||
}
|
@ -1,48 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type GSSResponse struct {
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (g *GSSResponse) Frontend() {}
|
||||
|
||||
func (g *GSSResponse) Decode(data []byte) error {
|
||||
g.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GSSResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
|
||||
dst = append(dst, g.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Data []byte
|
||||
}{
|
||||
Type: "GSSResponse",
|
||||
Data: g.Data,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements encoding/json.Unmarshaler.
|
||||
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
|
||||
var msg struct {
|
||||
Data []byte
|
||||
}
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
g.Data = msg.Data
|
||||
return nil
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type NoData struct{}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*NoData) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *NoData) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NoData) Encode(dst []byte) []byte {
|
||||
return append(dst, 'n', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src NoData) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "NoData",
|
||||
})
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
type NoticeResponse ErrorResponse
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*NoticeResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *NoticeResponse) Decode(src []byte) error {
|
||||
return (*ErrorResponse)(dst).Decode(src)
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NoticeResponse) Encode(dst []byte) []byte {
|
||||
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type NotificationResponse struct {
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*NotificationResponse) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *NotificationResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
channel := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := string(b[:len(b)-1])
|
||||
|
||||
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *NotificationResponse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'A')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint32(dst, src.PID)
|
||||
dst = append(dst, src.Channel...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Payload...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src NotificationResponse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
PID uint32
|
||||
Channel string
|
||||
Payload string
|
||||
}{
|
||||
Type: "NotificationResponse",
|
||||
PID: src.PID,
|
||||
Channel: src.Channel,
|
||||
Payload: src.Payload,
|
||||
})
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type ParameterDescription struct {
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*ParameterDescription) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *ParameterDescription) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
|
||||
}
|
||||
|
||||
// Reported parameter count will be incorrect when number of args is greater than uint16
|
||||
buf.Next(2)
|
||||
// Instead infer parameter count by remaining size of message
|
||||
parameterCount := buf.Len() / 4
|
||||
|
||||
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
|
||||
|
||||
for i := 0; i < parameterCount; i++ {
|
||||
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParameterDescription) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 't')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src ParameterDescription) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
ParameterOIDs []uint32
|
||||
}{
|
||||
Type: "ParameterDescription",
|
||||
ParameterOIDs: src.ParameterOIDs,
|
||||
})
|
||||
}
|
@ -1,66 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type ParameterStatus struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*ParameterStatus) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *ParameterStatus) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
name := string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
value := string(b[:len(b)-1])
|
||||
|
||||
*dst = ParameterStatus{Name: name, Value: value}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParameterStatus) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'S')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Value...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (ps ParameterStatus) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Name string
|
||||
Value string
|
||||
}{
|
||||
Type: "ParameterStatus",
|
||||
Name: ps.Name,
|
||||
Value: ps.Value,
|
||||
})
|
||||
}
|
@ -1,88 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type Parse struct {
|
||||
Name string
|
||||
Query string
|
||||
ParameterOIDs []uint32
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*Parse) Frontend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *Parse) Decode(src []byte) error {
|
||||
*dst = Parse{}
|
||||
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Name = string(b[:len(b)-1])
|
||||
|
||||
b, err = buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Query = string(b[:len(b)-1])
|
||||
|
||||
if buf.Len() < 2 {
|
||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||
}
|
||||
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
|
||||
|
||||
for i := 0; i < parameterOIDCount; i++ {
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "Parse"}
|
||||
}
|
||||
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *Parse) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'P')
|
||||
sp := len(dst)
|
||||
dst = pgio.AppendInt32(dst, -1)
|
||||
|
||||
dst = append(dst, src.Name...)
|
||||
dst = append(dst, 0)
|
||||
dst = append(dst, src.Query...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
|
||||
for _, oid := range src.ParameterOIDs {
|
||||
dst = pgio.AppendUint32(dst, oid)
|
||||
}
|
||||
|
||||
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src Parse) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Name string
|
||||
Query string
|
||||
ParameterOIDs []uint32
|
||||
}{
|
||||
Type: "Parse",
|
||||
Name: src.Name,
|
||||
Query: src.Query,
|
||||
ParameterOIDs: src.ParameterOIDs,
|
||||
})
|
||||
}
|
@ -1,34 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type ParseComplete struct{}
|
||||
|
||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||
func (*ParseComplete) Backend() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *ParseComplete) Decode(src []byte) error {
|
||||
if len(src) != 0 {
|
||||
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *ParseComplete) Encode(dst []byte) []byte {
|
||||
return append(dst, '1', 0, 0, 0, 4)
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src ParseComplete) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
}{
|
||||
Type: "ParseComplete",
|
||||
})
|
||||
}
|
@ -1,54 +0,0 @@
|
||||
package pgproto3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/jackc/pgio"
|
||||
)
|
||||
|
||||
type PasswordMessage struct {
|
||||
Password string
|
||||
}
|
||||
|
||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*PasswordMessage) Frontend() {}
|
||||
|
||||
// Frontend identifies this message as an authentication response.
|
||||
func (*PasswordMessage) InitialResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
// type identifier and 4 byte message length.
|
||||
func (dst *PasswordMessage) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dst.Password = string(b[:len(b)-1])
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
|
||||
func (src *PasswordMessage) Encode(dst []byte) []byte {
|
||||
dst = append(dst, 'p')
|
||||
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
|
||||
|
||||
dst = append(dst, src.Password...)
|
||||
dst = append(dst, 0)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
// MarshalJSON implements encoding/json.Marshaler.
|
||||
func (src PasswordMessage) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
Type string
|
||||
Password string
|
||||
}{
|
||||
Type: "PasswordMessage",
|
||||
Password: src.Password,
|
||||
})
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue