- update vendor

master
李光春 1 year ago
parent 4aedb95281
commit f509f0087f

@ -1,34 +0,0 @@
# idea ignore
.idea/
*.ipr
*.iml
*.iws
.vscode/
*.swp
# temp ignore
*.log
*.cache
*.diff
*.exe
*.exe~
*.patch
*.tmp
*debug.test
debug.test
go.sum
# system ignore
.DS_Store
Thumbs.db
# project
*.cert
*.key
.test
iprepo.txt
_output

@ -1,286 +0,0 @@
v1.6.6
- 感谢 @encircles 的pr,使用FuncWrapFieldTagName函数自定义Tag列名
- 简化查询逻辑,统一reBindSQL,不覆盖finder参数值,提升性能
- 修复获取自增主键异常
- 完善文档,注释
v1.6.5
- TDengineInsertsColumnName TDengine批量insert语句中是否有列名.默认false没有列名,插入值和数据库列顺序保持一致,减少语句长度
- 调整FuncGlobalTransaction函数返回值,支持seata-go
- 完善文档,注释
v1.6.4
- 感谢@haifengat 的场景反馈,完善NUMBER类型的数据接收
- RegisterCustomDriverValueConver函数的 dialectColumnType 参数修改为 Dialect.字段类型 ,例如:dm.TEXT
- 增加FuncDecimalValue函数,设置decimal类型接收值,复写函数自定义decimal实现
- NewSelectFinder方法参数strs取值第一个字符串
- 感谢@soldier_of_love 的场景反馈,error日志记录执行的sql和参数值
- 清理无效的代码和注释
- 完善文档,注释
v1.6.3
- 感谢@rebens 的场景反馈,增加InsertEntityMapSlice函数,批量保存EntityMap
- 感谢@haifengat 的场景反馈,ICustomDriverValueConver增加structFieldType *reflect.Type入参
- 感谢@zhou-a-xing 调整匿名结构体字段顺序
- 感谢@rebens 反馈的问题,避免IEntityMap默认实现IEntityStruct接口
- 感谢@cucuy 对www.zorm.cn官网的修改
- 完善文档,注释
v1.6.2
- 捕获panic,赋值给err,避免程序崩溃
- 增加sqlserver 和 oracle 分页默认order by
- 录制视频教程:https://www.bilibili.com/video/BV1L24y1976U/
- 完善文档,注释
v1.6.1
- 使用RegisterCustomDriverValueConver函数替代CustomDriverValueMap变量,将 ```zorm.CustomDriverValueMap["*dm.DmClob"] = CustomDMText{}```修改为```zorm.RegisterCustomDriverValueConver("TEXT", CustomDMText{})```,达梦数据库重新复制示例代码,重新复制!!重新复制!!!
- 重写sqlRowsValues函数,支持查询单个字段,Struct类型接收
- 简化自增序列的实现,使用string代替map[string]string
- 使用OverrideFunc重写zorm的函数,暴露WrapUpdateStructFinder函数
- 去掉kingbase列的大写转换,修改字符串拼接方式,提升性能
- BindContextDisableTransaction 用在不使用事务更新数据库的场景,强烈建议不要使用这个方法,更新数据库必须有事务!!!
- 增加查询没有返回列的判断,特殊情况可以使用Query执行更新语句,绕过事务检查(不建议)
- 更新官网 https://zorm.cn
- 完善文档,注释
v1.6.0
- 更新漂亮的logo
- 增加db2数据支持,依赖Limit分页语法
- DBType即将废弃,更名为Dialect,方便gorm和xorm迁移
- FuncReadWriteStrategy和GetGTXID函数增加error返回值
- 修改日志格式,统一加上 -> 符号
- 曾经偷的懒还是还上吧,类型转换加上err返回值.去掉无用的日期格式转换,驱动获取的并不是[]byte
- 修复Finder.Append和GetSQL为nil的bug
- 完善文档,注释
v1.5.9
- hptx已合并@小口天的pr, [hptx代理模式zorm使用示例](https://github.com/CECTC/hptx-samples/tree/main/http_proxy_zorm) 和 [zorm事务托管hptx示例](https://github.com/CECTC/hptx-samples/tree/main/http_zorm)
- 增加IsInTransaction(ctx)函数,检查是否在事务内
- 扩展函数统一加上ctx入参,方便场景自定义扩展
- 取消PrintSQL参数,使用SlowSQLMillis控制输出慢sql语句
- 完善文档,注释
v1.5.8
更新内容:
- 感谢 @zhou-a-xing 编写TDengine的测试用例,不允许手动拼接 '?' 单引号,强制使用?,书写统一
- 感谢 @小口天 反馈的bug和编写hptx测试用例,修改全局事务接口方法名,避免和gtx方法名一致造成递归调用
- 取消自动开启全局事务,必须手动zorm.BindContextEnableGlobalTransaction(ctx)开启全局事务
- 重构 reBindSQL 函数,在SQL最后执行前统一处理
- 吐槽很久的switch代替if else
- 完善文档,注释
v1.5.7
更新内容:
- 感谢 @小口天 的辛苦付出,https://gitee.com/wuxiangege/zorm-examples 测试用例已经非常完善.
- 按照反射获取的Struct属性顺序,生成insert语句和update语句
- 支持TDengine数据库,因TDengine驱动不支持事务,需要设置DisableTransaction=true
- 增加hptx和dbpack分布式事务的支持,细粒度控制是否使用全局事务
- DisableTransaction用于全局禁用数据库事务,用于不支持事务的数据库驱动.
- 完善文档,注释
v1.5.6
更新内容:
- 感谢@无泪发现Transaction方法返回值为nil的bug,已修复
- 感谢社区贡献,https://zorm.cn 官网上线,很丑的logo上线 :).
- 支持已经存在的数据库连接
- 修改panic的异常记录和主键零值判断,用于支持基础类型扩展的主键
- 完善文档,注释
v1.5.5
更新内容:
- 增加CloseDB函数,关闭数据库连接池
- 完善文档,注释
v1.5.4
更新内容:
- QueryRow如果查询一个字段,而且这个字段数据库为null,会有异常,没有赋为默认值
- reflect.Type 类型的参数,修改为 *reflect.Type 指针,包括CustomDriverValueConver接口的参数
- 完善文档,注释
v1.5.3
更新内容:
- 感谢@Howard.TSE的建议,判断配置是否为空
- 感谢@haming123反馈性能问题.zorm 1.2.x 版本实现了基础功能,读性能比gorm和xorm快一倍.随着功能持续增加,造成性能下降,目前读性能只快了50%.
- 性能优化,去掉不必要的反射
- 完善文档,注释
v1.5.2
更新内容:
- 感谢奔跑(@zeqjone)提供的正则,排除不在括号内的from,已经满足绝大部分场景
- 感谢奔跑(@zeqjone) pr,修复 金仓数据库模型定义中tag数据库列标签与数据库内置关键词冲突时,加双引号处理
- 升级 decimal 到1.3.1
- 完善文档,注释
v1.5.1
更新内容:
- 完善文档,注释
- 注释未使用的代码
- 先判断error,再执行defer rows.Close()
- 增加微信社区支持(负责人是八块腹肌的单身小伙 @zhou-a-xing)
v1.5.0
更新内容:
- 完善文档,注释
- 支持clickhouse,更新,删除语句使用SQL92标准语法
- ID默认使用时间戳+随机数,代替UUID实现
- 优化SQL提取的正则表达式
- 集成seata-golang,支持全局托管,不修改业务代码,零侵入分布式事务
v1.4.9
更新内容:
- 完善文档,注释
- 摊牌了,不装了,就是修改注释,刷刷版本活跃度
v1.4.8
更新内容:
- 完善文档,注释
- 数据库字段和实体类额外映射时,支持 _ 下划线转驼峰
v1.4.7
更新内容:
- 情人节版本,返回map时,如果无法正常转换值类型,就返回原值,而不是nil
v1.4.6
更新内容:
- 完善文档,注释
- 千行代码,胜他十万,牛气冲天,zorm零依赖.(uuid和decimal这两个工具包竟然有1700行代码)
- 在涉密内网开发环境中,零依赖能减少很多麻烦,做不到请不要说没必要......
v1.4.5
更新内容:
- 增强自定义类型转换的功能
- 完善文档,注释
- 非常感谢 @anxuanzi 完善代码生成器
- 非常感谢 @chien_tung 增加changelog,以后版本发布都会记录changelog
v1.4.4
更新内容:
- 如果查询的字段在column tag中没有找到,就会根据名称(不区分大小写)映射到struct的属性上
- 给QueryRow方法增加 has 的返回值,标识数据库是有一条记录的,各位已经使用的大佬,升级时注意修改代码,非常抱歉*3
v1.4.3
更新内容:
- 正式支持南大通用(gbase)数据库,完成国产四库的适配
- 增加设置全局事务隔离级别和单个事务的隔离级别
- 修复触发器自增主键的逻辑bug
- 文档完善和细节调整
v1.4.2
更新内容:
- 正式支持神州通用(shentong)数据库
- 完善pgsql和kingbase的自增主键返回值支持
- 七家公司的同学建议查询和golang sql方法命名保持统一.做了一个艰难的决定,修改zorm的部分方法名.全局依次替换字符串即可.
zorm.Query( 替换为 zorm.QueryRow(
zorm.QuerySlice( 替换为 zorm.Query(
zorm.QueryMap( 替换为 zorm.QueryRowMap(
zorm.QueryMapSlice( 替换为 zorm.QueryMap(
v1.4.1
更新内容:
- 支持自定义扩展字段映射逻辑
v1.4.0
更新内容:
- 修改多条数据的判断逻辑
v1.3.9
更新内容:
- 支持自定义数据类型,包括json/jsonb
- 非常感谢 @chien_tung 同学反馈的问题, QuerySlice方法支持*[]*struct类型,简化从xorm迁移
- 其他代码细节优化.
v1.3.7
更新内容:
- 非常感谢 @zhou- a- xing 同学(八块腹肌的单身少年)的英文翻译,zorm的核心代码注释已经是中英双语了.
- 非常感谢 @chien_tung 同学反馈的问题,修复主键自增int和init64类型的兼容性.
- 其他代码细节优化.
v1.3.6
更新内容:
- 完善注释文档
- 修复Delete方法的参数类型错误
- 其他代码细节优化.
v1.3.5
更新内容:
- 完善注释文档
- 兼容处理数据库为null时,基本类型取默认值,感谢@fastabler的pr
- 修复批量保存方法的一个bug:如果slice的长度为1,在pgsql和oracle会出现异常
- 其他代码细节优化.
v1.3.4
更新内容:
- 完善注释文档
- 取消分页语句必须有order by的限制
- 支持人大金仓数据库
- 人大金仓驱动说明: https://help.kingbase.com.cn/doc- view- 8108.html
- 人大金仓kingbase 8核心是基于postgresql 9.6,可以使用 https://github.com/lib/pq 进行测试,生产环境建议使用官方驱动
v1.3.3
更新内容:
- 完善注释文档
- 增加批量保存Struct对象方法
- 正式支持达梦数据库
- 基于达梦官方驱动,发布go mod项目 https://gitee.com/chunanyong/dm
v1.3.2
更新内容:
- 增加达梦数据的分页适配
- 完善调整代码注释
- 增加存储过程和函数的调用示例
v1.3.1
更新内容:
- 修改方法名称,和gorm和xorm保持相似,降低迁移和学习成本
- 更新测试用例文档
v1.3.0
更新内容:
- 去掉zap日志依赖,通过复写 FuncLogError FuncLogPanic FuncPrintSQL 实现自定义日志
- golang版本依赖调整为v1.13
- 迁移测试用到readygo,zorm项目不依赖任何数据库驱动包
v1.2.9
更新内容:
- IEntityMap支持主键自增或主键序列
- 更新方法返回影响的行数affected
- 修复 查询IEntityMap时数据库无记录出现异常的bug
- 测试用例即文档 https://gitee.com/chunanyong/readygo/blob/master/test/testzorm/BaseDao_test.go
v1.2.8
更新内容:
- 暴露FuncGenerateStringID函数,方便自定义扩展字符串主键ID
- Finder.Append 默认加一个空格,避免手误出现语法错误
- 缓存字段信息时,使用map代替sync.Map,提高性能
- 第三方性能压测结果
v1.2.6
更新内容:
- DataSourceConfig 配置区分 DriverName 和 DBType兼容一种数据库的多个驱动包
- 不再显示依赖数据库驱动,由使用者确定依赖的数据库驱动包
v1.2.5
更新内容:
- 分页语句必须有明确的order by,避免数据库迁移时出现分页语法不兼容.
- 修复列表查询时,page对象为nil的bug
v1.2.3
更新内容:
- 完善数据库支持,目前支持MySQL,SQLServer,Oracle,PostgreSQL,SQLite3
- 简化数据库读写分离实现,暴露zorm.FuncReadWriteBaseDao函数属性,用于自定义读写分离策略
- 精简zorm.DataSourceConfig属性,增加PrintSQL属性
v1.2.2
更新内容:
- 修改NewPage()返回Page对象指针,传递时少写一个 & 符号
- 取消GetDBConnection()方法,使用BindContextConnection()方法进行多个数据库库绑定
- 隐藏DBConnection对象,不再对外暴露数据库对象,避免手动初始化造成的异常
v1.1.8
更新内容:
- 修复UUID支持
- 数据库连接和事务隐藏到context.Context为统一参数,符合golang规范,更好的性能
- 封装logger实现,方便更换log包
- 增加zorm.UpdateStructNotZeroValue 方法,只更新不为零值的字段
- 完善测试用例

File diff suppressed because it is too large Load Diff

@ -1,184 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import (
"errors"
"strings"
)
// Finder 查询数据库的载体,所有的sql语句都要通过Finder执行.
// Finder To query the database carrier, all SQL statements must be executed through Finder
type Finder struct {
// 拼接SQL
// Splicing SQL.
sqlBuilder strings.Builder
// SQL的参数值
// SQL parameter values.
values []interface{}
// 注入检查,默认true 不允许SQL注入的 ' 单引号
// Injection check, default true does not allow SQL injection single quote
InjectionCheck bool
// CountFinder 自定义的查询总条数'Finder',使用指针默认为nil.主要是为了在'group by'等复杂情况下,为了性能,手动编写总条数语句
// CountFinder The total number of custom queries is'Finder', and the pointer is nil by default. It is mainly used to manually write the total number of statements for performance in complex situations such as'group by'
CountFinder *Finder
// 是否自动查询总条数,默认true.同时需要Page不为nil,才查询总条数
// Whether to automatically query the total number of entries, the default is true. At the same time, the Page is not nil to query the total number of entries
SelectTotalCount bool
// SQL语句
// SQL statement
sqlstr string
}
// NewFinder 初始化一个Finder,生成一个空的Finder
// NewFinder Initialize a Finder and generate an empty Finder
func NewFinder() *Finder {
finder := Finder{}
finder.sqlBuilder.Grow(stringBuilderGrowLen)
finder.SelectTotalCount = true
finder.InjectionCheck = true
// slice扩容会生成新的slice,最后要值复制接收.问:为什么cap是3?答:经验
finder.values = make([]interface{}, 0, 3)
return &finder
}
// NewSelectFinder 根据表名初始化查询的Finder,strs 只取第一个字符串,用数组类型是为了可以不传入,默认为 * | Finder that initializes the query based on the table name
// NewSelectFinder("tableName") SELECT * FROM tableName
// NewSelectFinder("tableName", "id,name") SELECT id,name FROM tableName
func NewSelectFinder(tableName string, strs ...string) *Finder {
strsLen := len(strs)
if strsLen > 1 { // 不支持多个参数
return nil
}
finder := NewFinder()
finder.sqlBuilder.WriteString("SELECT ")
if strsLen == 1 { // 只取值第一个字符串
finder.sqlBuilder.WriteString(strs[0])
} else {
finder.sqlBuilder.WriteByte('*')
}
finder.sqlBuilder.WriteString(" FROM ")
finder.sqlBuilder.WriteString(tableName)
return finder
}
// NewUpdateFinder 根据表名初始化更新的Finder, UPDATE tableName SET
// NewUpdateFinder Initialize the updated Finder according to the table name, UPDATE tableName SET
func NewUpdateFinder(tableName string) *Finder {
finder := NewFinder()
finder.sqlBuilder.WriteString("UPDATE ")
finder.sqlBuilder.WriteString(tableName)
finder.sqlBuilder.WriteString(" SET ")
return finder
}
// NewDeleteFinder 根据表名初始化删除的'Finder', DELETE FROM tableName
// NewDeleteFinder Finder for initial deletion based on table name. DELETE FROM tableName
func NewDeleteFinder(tableName string) *Finder {
finder := NewFinder()
finder.sqlBuilder.WriteString("DELETE FROM ")
finder.sqlBuilder.WriteString(tableName)
// 所有的 WHERE 都不加,规则统一,好记
// No WHERE is added, the rules are unified, easy to remember
// finder.sqlBuilder.WriteString(" WHERE ")
return finder
}
// Append 添加SQL和参数的值,第一个参数是语句,后面的参数[可选]是参数的值,顺序要正确
// 例如: finder.Append(" and id=? and name=? ",23123,"abc")
// 只拼接SQL,例如: finder.Append(" and name=123 ")
// Append:Add SQL and parameter values, the first parameter is the statement, and the following parameter (optional) is the value of the parameter, in the correct order
// E.g: finder.Append(" and id=? and name=? ",23123,"abc")
// Only splice SQL, E.g: finder.Append(" and name=123 ")
func (finder *Finder) Append(s string, values ...interface{}) *Finder {
// 不要自己构建finder,使用NewFinder()方法
// Don't build finder by yourself, use NewFinder() method
if finder == nil || finder.values == nil {
return nil
}
if s != "" {
if finder.sqlstr != "" {
finder.sqlstr = ""
}
// 默认加一个空格,避免手误两个字符串连接再一起
// A space is added by default to avoid hand mistakes when connecting two strings together
finder.sqlBuilder.WriteByte(' ')
finder.sqlBuilder.WriteString(s)
}
if values == nil || len(values) < 1 {
return finder
}
finder.values = append(finder.values, values...)
return finder
}
// AppendFinder 添加另一个Finder finder.AppendFinder(f)
// AppendFinder Add another Finder . finder.AppendFinder(f)
func (finder *Finder) AppendFinder(f *Finder) (*Finder, error) {
if finder == nil {
return finder, errors.New("->finder-->AppendFinder()finder对象为nil")
}
if f == nil {
return finder, errors.New("->finder-->AppendFinder()参数是nil")
}
// 不要自己构建finder,使用NewFinder()方法
// Don't build finder by yourself, use NewFinder() method
if finder.values == nil {
return finder, errors.New("->finder-->AppendFinder()不要自己构建finder,使用NewFinder()方法")
}
// 添加f的SQL
// SQL to add f
sqlstr, err := f.GetSQL()
if err != nil {
return finder, err
}
finder.sqlstr = ""
finder.sqlBuilder.WriteString(sqlstr)
// 添加f的值
// Add the value of f
finder.values = append(finder.values, f.values...)
return finder, nil
}
// GetSQL 返回Finder封装的SQL语句
// GetSQL Return the SQL statement encapsulated by the Finder
func (finder *Finder) GetSQL() (string, error) {
// 不要自己构建finder,使用NewFinder方法
// Don't build finder by yourself, use NewFinder method
if finder == nil || finder.values == nil {
return "", errors.New("->finder-->GetSQL()不要自己构建finder,使用NewFinder()方法")
}
if len(finder.sqlstr) > 0 {
return finder.sqlstr, nil
}
sqlstr := finder.sqlBuilder.String()
// 包含单引号,属于非法字符串
// Contains single quotes, which are illegal strings
if finder.InjectionCheck && (strings.Contains(sqlstr, "'")) {
return "", errors.New(`->finder-->GetSQL()SQL语句请不要直接拼接字符串参数,容易注入!!!请使用问号占位符,例如 finder.Append("and id=?","stringId"),如果必须拼接字符串,请设置 finder.InjectionCheck = false `)
}
finder.sqlstr = sqlstr
return sqlstr, nil
}

@ -1,141 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"strings"
)
// customDriverValueMap 用于配置数据库字段类型的处理关系,key是 Dialect.字段类型,例如 dm.TEXT
var customDriverValueMap = make(map[string]ICustomDriverValueConver)
// iscdvm 是否有自定义的DriverValueMap
var iscdvm bool
// ICustomDriverValueConver 自定义类型转化接口,用于解决 类似达梦 text --> dm.DmClob --> string类型接收的问题
type ICustomDriverValueConver interface {
// GetDriverValue 根据数据库列类型,返回driver.Value的实例,struct属性类型
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
GetDriverValue(ctx context.Context, columnType *sql.ColumnType, structFieldType *reflect.Type) (driver.Value, error)
// ConverDriverValue 数据库列类型,GetDriverValue返回的driver.Value的临时接收值,struct属性类型
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
// 返回符合接收类型值的指针,指针,指针!!!!
ConverDriverValue(ctx context.Context, columnType *sql.ColumnType, tempDriverValue driver.Value, structFieldType *reflect.Type) (interface{}, error)
}
// RegisterCustomDriverValueConver 注册自定义的字段处理逻辑,用于驱动无法直接转换的场景,例如达梦的 TEXT 无法直接转化成 string
// dialectColumnType 值是 Dialect.字段类型,例如: dm.TEXT
// 一般是放到init方法里进行注册
func RegisterCustomDriverValueConver(dialectColumnType string, customDriverValueConver ICustomDriverValueConver) error {
if len(dialectColumnType) < 1 {
return errors.New("->RegisterCustomDriverValueConver-->dialectColumnType为空")
}
dialectColumnTypes := strings.Split(dialectColumnType, ".")
if len(dialectColumnTypes) < 2 {
customDriverValueMap[strings.ToUpper(dialectColumnType)] = customDriverValueConver
err := errors.New("->RegisterCustomDriverValueConver-->dialectColumnType 值是 Dialect.字段类型,例如: dm.TEXT ,本次正常运行,请尽快修改")
FuncLogError(nil, err)
} else {
customDriverValueMap[strings.ToLower(dialectColumnTypes[0])+"."+strings.ToUpper(dialectColumnTypes[1])] = customDriverValueConver
}
iscdvm = true
return nil
}
type driverValueInfo struct {
customDriverValueConver ICustomDriverValueConver
columnType *sql.ColumnType
tempDriverValue interface{}
structFieldType *reflect.Type
}
/**
import (
// 00.引入数据库驱动
"gitee.com/chunanyong/dm"
"io"
)
// CustomDMText 实现ICustomDriverValueConver接口,扩展自定义类型,例如 达梦数据库TEXT类型,映射出来的是dm.DmClob类型,无法使用string类型直接接收
type CustomDMText struct{}
// GetDriverValue 根据数据库列类型,返回driver.Value的实例,struct属性类型
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
func (dmtext CustomDMText) GetDriverValue(ctx context.Context, columnType *sql.ColumnType, structFieldType *reflect.Type) (driver.Value, error) {
// 如果需要使用structFieldType,需要先判断是否为nil
// if structFieldType != nil {
// }
return &dm.DmClob{}, nil
}
// ConverDriverValue 数据库列类型,GetDriverValue返回的driver.Value的临时接收值,struct属性类型
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
// 返回符合接收类型值的指针,指针,指针!!!!
func (dmtext CustomDMText) ConverDriverValue(ctx context.Context, columnType *sql.ColumnType, tempDriverValue driver.Value, structFieldType *reflect.Type) (interface{}, error) {
// 如果需要使用structFieldType,需要先判断是否为nil
// if structFieldType != nil {
// }
// 类型转换
dmClob, isok := tempDriverValue.(*dm.DmClob)
if !isok {
return tempDriverValue, errors.New("->ConverDriverValue-->转换至*dm.DmClob类型失败")
}
if dmClob == nil || !dmClob.Valid {
return new(string), nil
}
// 获取长度
dmlen, errLength := dmClob.GetLength()
if errLength != nil {
return dmClob, errLength
}
// int64转成int类型
strInt64 := strconv.FormatInt(dmlen, 10)
dmlenInt, errAtoi := strconv.Atoi(strInt64)
if errAtoi != nil {
return dmClob, errAtoi
}
// 读取字符串
str, errReadString := dmClob.ReadString(1, dmlenInt)
// 处理空字符串或NULL造成的EOF错误
if errReadString == io.EOF {
return new(string), nil
}
return &str, errReadString
}
// RegisterCustomDriverValueConver 注册自定义的字段处理逻辑,用于驱动无法直接转换的场景,例如达梦的 TEXT 无法直接转化成 string
// 一般是放到init方法里进行注册
func init() {
// dialectColumnType 值是 Dialect.字段类型 ,例如 dm.TEXT
zorm.RegisterCustomDriverValueConver("dm.TEXT", CustomDMText{})
}
**/

@ -1,160 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
// IEntityStruct "struct"实体类的接口,所有的struct实体类都要实现这个接口
// IEntityStruct The interface of the "struct" entity class, all struct entity classes must implement this interface
type IEntityStruct interface {
// 获取表名称
// Get the table name.
GetTableName() string
// 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
// Get the primary key field name of the database table. Because it is compatible with Map, it can only be the field name of the database
GetPKColumnName() string
// GetPkSequence 主键序列
// GetPkSequence Primary key sequence
GetPkSequence() string
}
// IEntityMap 使用Map保存数据,用于不方便使用struct的场景,如果主键是自增或者序列,不要"entityMap.Set"主键的值
// IEntityMap Use Map to save data for scenarios where it is not convenient to use struct
// If the primary key is auto-increment or sequence, do not "entity Map.Set" the value of the primary key
type IEntityMap interface {
// 获取表名称
// Get the table name
GetTableName() string
// 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称.
// Get the primary key field name of the database table. Because it is compatible with Map, it can only be the field name of the database.
GetPKColumnName() string
// GetEntityMapPkSequence 主键序列,不能使用GetPkSequence方法名,避免默认实现了IEntityStruct接口
// GetEntityMapPkSequence primary key sequence, you cannot use the GetPkSequence method name, to avoid the default implementation of IEntityStruct interface
GetEntityMapPkSequence() string
// GetDBFieldMap 针对Map类型,记录数据库字段
// GetDBFieldMap For Map type, record database fields.
GetDBFieldMap() map[string]interface{}
// GetDBFieldMapKey 按照Set的先后顺序记录key值,也就是数据库字段,用于SQL排序
// GetDBFieldMapKey records the key value, that is, the database field, in the order of the Set, which is used for SQL sorting
GetDBFieldMapKey() []string
// 设置数据库字段的值
// Set the value of a database field.
Set(key string, value interface{}) map[string]interface{}
}
// EntityStruct "IBaseEntity" 的基础实现,所有的实体类都匿名注入.这样就类似实现继承了,如果接口增加方法,调整这个默认实现即可
// EntityStruct The basic implementation of "IBaseEntity", all entity classes are injected anonymously
// This is similar to implementation inheritance. If the interface adds methods, adjust the default implementation
type EntityStruct struct{}
// 默认数据库的主键列名
// Primary key column name of the default database
const defaultPkName = "id"
//GetTableName 获取表名称,必须有具体的Struct实现,类似java的抽象方法,避免手误忘记写表名.如果有扩展需求,建议使用接口进行扩展,不要默认实现GetTableName
/*
func (entity *EntityStruct) GetTableName() string {
return ""
}
*/
// GetPKColumnName 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
// GetPKColumnName Get the primary key field name of the database table
// Because it is compatible with Map, it can only be the field name of the database
func (entity *EntityStruct) GetPKColumnName() string {
return defaultPkName
}
// var defaultPkSequence = make(map[string]string, 0)
// GetPkSequence 主键序列
// GetPkSequence Primary key sequence
func (entity *EntityStruct) GetPkSequence() string {
return ""
}
//-------------------------------------------------------------------------//
// EntityMap IEntityMap的基础实现,可以直接使用或者匿名注入
type EntityMap struct {
// 表名
tableName string
// 主键列名
PkColumnName string
// 主键序列,如果有值,优先级最高
PkSequence string
// 数据库字段,不暴露外部
dbFieldMap map[string]interface{}
// 列名,记录顺序
dbFieldMapKey []string
}
// NewEntityMap 初始化Map,必须传入表名称
func NewEntityMap(tbName string) *EntityMap {
entityMap := EntityMap{}
entityMap.dbFieldMap = map[string]interface{}{}
entityMap.tableName = tbName
entityMap.PkColumnName = defaultPkName
entityMap.dbFieldMapKey = make([]string, 0)
return &entityMap
}
// GetTableName 获取表名称
func (entity *EntityMap) GetTableName() string {
return entity.tableName
}
// GetPKColumnName 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
func (entity *EntityMap) GetPKColumnName() string {
return entity.PkColumnName
}
// GetEntityMapPkSequence 主键序列,不能使用GetPkSequence方法名,避免默认实现了IEntityStruct接口
// GetEntityMapPkSequence primary key sequence, you cannot use the GetPkSequence method name, to avoid the default implementation of IEntityStruct interface
func (entity *EntityMap) GetEntityMapPkSequence() string {
return entity.PkSequence
}
// GetDBFieldMap 针对Map类型,记录数据库字段
// GetDBFieldMap For Map type, record database fields
func (entity *EntityMap) GetDBFieldMap() map[string]interface{} {
return entity.dbFieldMap
}
// GetDBFieldMapKey 按照Set的先后顺序记录key值,也就是数据库字段,用于SQL排序
// GetDBFieldMapKey records the key value, that is, the database field, in the order of the Set, which is used for SQL sorting
func (entity *EntityMap) GetDBFieldMapKey() []string {
return entity.dbFieldMapKey
}
// Set 设置数据库字段
// Set Set database fields
func (entity *EntityMap) Set(key string, value interface{}) map[string]interface{} {
_, ok := entity.dbFieldMap[key]
if !ok { // 如果不存在
entity.dbFieldMapKey = append(entity.dbFieldMapKey, key)
}
entity.dbFieldMap[key] = value
return entity.dbFieldMap
}

@ -1,41 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import "context"
// IGlobalTransaction 托管全局分布式事务接口
type IGlobalTransaction interface {
// BeginGTX 开启全局分布式事务
BeginGTX(ctx context.Context, globalRootContext context.Context) error
// CommitGTX 提交全局分布式事务.不能命名为 Commit,不然就和gtx的Commit一致了,就递归调用自己了.......
CommitGTX(ctx context.Context, globalRootContext context.Context) error
// RollbackGTX 回滚全局分布式事务
RollbackGTX(ctx context.Context, globalRootContext context.Context) error
// GetGTXID 获取全局分布式事务的XID
GetGTXID(ctx context.Context, globalRootContext context.Context) (string, error)
// 重新包装为 seata/hptx 的context.RootContext
// context.RootContext 如果后续使用了 context.WithValue,类型就是context.valueCtx 就会造成无法再类型断言为 context.RootContext
// 所以DBDao里使用了 globalRootContext变量,区分业务的ctx和分布式事务的RootContext
// NewRootContext(ctx context.Context) context.Context
}

@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@ -1,63 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import (
"context"
"fmt"
"log"
)
func init() {
// 设置默认的日志显示信息,显示文件和行号
// Set the default log display information, display file and line number.
log.SetFlags(log.Llongfile | log.LstdFlags)
}
// LogCallDepth 记录日志调用层级,用于定位到业务层代码
// Log Call Depth Record the log call level, used to locate the business layer code
var LogCallDepth = 4
// FuncLogError 记录error日志.NewDBDao方法里的异常,ctx为nil,扩展时请注意
// FuncLogError Record error log
var FuncLogError func(ctx context.Context, err error) = defaultLogError
// FuncLogPanic 记录panic日志,默认使用"defaultLogError"实现
// FuncLogPanic Record panic log, using "defaultLogError" by default
var FuncLogPanic func(ctx context.Context, err error) = defaultLogPanic
// FuncPrintSQL 打印sql语句,参数和执行时间,小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
// FuncPrintSQL Print sql statement and parameters
var FuncPrintSQL func(ctx context.Context, sqlstr string, args []interface{}, execSQLMillis int64) = defaultPrintSQL
func defaultLogError(ctx context.Context, err error) {
log.Output(LogCallDepth, fmt.Sprintln(err))
}
func defaultLogPanic(ctx context.Context, err error) {
defaultLogError(ctx, err)
}
func defaultPrintSQL(ctx context.Context, sqlstr string, args []interface{}, execSQLMillis int64) {
if args != nil {
log.Output(LogCallDepth, fmt.Sprintln("sql:", sqlstr, ",args:", args, ",execSQLMillis:", execSQLMillis))
} else {
log.Output(LogCallDepth, fmt.Sprintln("sql:", sqlstr, ",args: [] ", ",execSQLMillis:", execSQLMillis))
}
}

@ -1,81 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
// Page 分页对象
// Page Pagination object
type Page struct {
// 当前页码,从1开始
// Current page number, starting from 1
PageNo int
// 每页多少条,默认20条
// How many items per page, 20 items by default
PageSize int
// 数据总条数
// Total number of data
TotalCount int
// 共多少页
// How many pages
PageCount int
// 是否是第一页
// Is it the first page
FirstPage bool
// 是否有上一页
// Whether there is a previous page
HasPrev bool
// 是否有下一页
// Is there a next page
HasNext bool
// 是否是最后一页
// Is it the last page
LastPage bool
}
// NewPage 创建Page对象
// NewPage Create Page object
func NewPage() *Page {
page := Page{}
page.PageNo = 1
page.PageSize = 20
return &page
}
// setTotalCount 设置总条数,计算其他值
// setTotalCount Set the total number of bars, calculate other values
func (page *Page) setTotalCount(total int) {
page.TotalCount = total
page.PageCount = (page.TotalCount + page.PageSize - 1) / page.PageSize
if page.PageNo >= page.PageCount {
page.LastPage = true
} else {
page.HasNext = true
}
if page.PageNo > 1 {
page.HasPrev = true
} else {
page.FirstPage = true
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -1,318 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
)
// dataSorce对象,隔离sql原生对象
// dataSorce Isolate sql native objects
type dataSource struct {
*sql.DB
// config *DataSourceConfig
}
// newDataSource 创建一个新的datasource,内部调用,避免外部直接使用datasource
// newDAtaSource Create a new datasource and call it internally to avoid direct external use of the datasource
func newDataSource(config *DataSourceConfig) (*dataSource, error) {
if config == nil {
return nil, errors.New("->newDataSource-->config cannot be nil")
}
if config.DriverName == "" {
return nil, errors.New("->newDataSource-->DriverName cannot be empty")
}
// 兼容处理,DBType即将废弃,请使用Dialect属性
if config.DBType != "" && config.Dialect == "" {
FuncLogError(nil, errors.New("->newDataSource-->DataSourceConfig的DBType即将废弃,请使用Dialect属性"))
config.Dialect = config.DBType
}
if config.Dialect == "" {
return nil, errors.New("->newDataSource-->Dialect cannot be empty")
}
var db *sql.DB
var errSQLOpen error
if config.SQLDB == nil { // 没有已经存在的数据库连接,使用DSN初始化
if config.DSN == "" {
return nil, errors.New("->newDataSource-->DSN cannot be empty")
}
db, errSQLOpen = sql.Open(config.DriverName, config.DSN)
if errSQLOpen != nil {
errSQLOpen = fmt.Errorf("->newDataSource-->open数据库打开失败:%w", errSQLOpen)
FuncLogError(nil, errSQLOpen)
return nil, errSQLOpen
}
} else { // 使用已经存在的数据库连接
db = config.SQLDB
}
if config.MaxOpenConns == 0 {
config.MaxOpenConns = 50
}
if config.MaxIdleConns == 0 {
config.MaxIdleConns = 50
}
if config.ConnMaxLifetimeSecond == 0 {
config.ConnMaxLifetimeSecond = 600
}
// 设置数据库最大连接数
// Set the maximum number of database connections
db.SetMaxOpenConns(config.MaxOpenConns)
// 设置数据库最大空闲连接数
// Set the maximum number of free connections to the database
db.SetMaxIdleConns(config.MaxIdleConns)
//连接存活秒时间. 默认600(10分钟)后连接被销毁重建.避免数据库主动断开连接,造成死连接.MySQL默认wait_timeout 28800秒(8小时)
//(Connection survival time in seconds) Destroy and rebuild the connection after the default 600 seconds (10 minutes)
//Prevent the database from actively disconnecting and causing dead connections. MySQL Default wait_timeout 28800 seconds
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetimeSecond))
// 验证连接
if pingerr := db.Ping(); pingerr != nil {
pingerr = fmt.Errorf("->newDataSource-->ping数据库失败:%w", pingerr)
FuncLogError(nil, pingerr)
db.Close()
return nil, pingerr
}
return &dataSource{db}, nil
}
// 事务参照:https://www.jianshu.com/p/2a144332c3db
// Transaction reference: https://www.jianshu.com/p/2a144332c3db
// dataBaseConnection 数据库dbConnection会话,可以原生查询或者事务
// dataBaseConnection Database session, native query or transaction.
type dataBaseConnection struct {
// 原生db
// native db
db *sql.DB
// 原生事务
// native transaction
tx *sql.Tx
// 数据库配置
config *DataSourceConfig
}
// beginTx 开启事务
// beginTx Open transaction
func (dbConnection *dataBaseConnection) beginTx(ctx context.Context) error {
if dbConnection.tx != nil {
return nil
}
// 设置事务配置,主要是隔离级别
var txOptions *sql.TxOptions
contextTxOptions := ctx.Value(contextTxOptionsKey)
if contextTxOptions != nil {
txOptions, _ = contextTxOptions.(*sql.TxOptions)
} else {
txOptions = dbConnection.config.DefaultTxOptions
}
tx, err := dbConnection.db.BeginTx(ctx, txOptions)
if err != nil {
err = fmt.Errorf("->beginTx事务开启失败:%w", err)
return err
}
dbConnection.tx = tx
return nil
}
// rollback 回滚事务
// rollback Rollback transaction
func (dbConnection *dataBaseConnection) rollback() error {
if dbConnection.tx == nil {
return nil
}
err := dbConnection.tx.Rollback()
if err != nil {
err = fmt.Errorf("->rollback事务回滚失败:%w", err)
return err
}
dbConnection.tx = nil
return nil
}
// commit 提交事务
// commit Commit transaction
func (dbConnection *dataBaseConnection) commit() error {
if dbConnection.tx == nil {
return errors.New("->dbConnection.commit()事务为空")
}
err := dbConnection.tx.Commit()
if err != nil {
err = fmt.Errorf("->dbConnection.commit()事务提交失败:%w", err)
return err
}
dbConnection.tx = nil
return nil
}
// execContext 执行sql语句,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
// execContext Execute sql statement,If the transaction has been opened,it will be executed in transaction mode, if the transaction is not opened,it will be executed in non-transactional mode
func (dbConnection *dataBaseConnection) execContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Result, error) {
// reBindSQL 重新处理参数代入方式
execsql, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
if err != nil {
return nil, err
}
// 更新语句处理ClickHouse特殊语法
err = reUpdateSQL(dbConnection.config.Dialect, execsql)
if err != nil {
return nil, err
}
// 执行前加入 hint
err = wrapSQLHint(ctx, execsql)
if err != nil {
return nil, err
}
var start *time.Time
var res sql.Result
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
slowSQLMillis := dbConnection.config.SlowSQLMillis
if slowSQLMillis == 0 {
FuncPrintSQL(ctx, *execsql, *args, 0)
} else if slowSQLMillis > 0 {
now := time.Now() // 获取当前时间
start = &now
}
if dbConnection.tx != nil {
res, err = dbConnection.tx.ExecContext(ctx, *execsql, *args...)
} else {
res, err = dbConnection.db.ExecContext(ctx, *execsql, *args...)
}
if slowSQLMillis > 0 {
slow := time.Since(*start).Milliseconds()
if slow-int64(slowSQLMillis) >= 0 {
FuncPrintSQL(ctx, *execsql, *args, slow)
}
}
if err != nil {
err = fmt.Errorf("->execContext执行错误:%w,-->zormErrorExecSQL:%s,-->zormErrorSQLValues:%v", err, *execsql, *args)
}
return &res, err
}
// queryRowContext 如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
func (dbConnection *dataBaseConnection) queryRowContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Row, error) {
// reBindSQL 重新处理参数代入方式
query, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
if err != nil {
return nil, err
}
// 执行前加入 hint
err = wrapSQLHint(ctx, query)
if err != nil {
return nil, err
}
var start *time.Time
var row *sql.Row
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
slowSQLMillis := dbConnection.config.SlowSQLMillis
if slowSQLMillis == 0 {
FuncPrintSQL(ctx, *query, *args, 0)
} else if slowSQLMillis > 0 {
now := time.Now() // 获取当前时间
start = &now
}
if dbConnection.tx != nil {
row = dbConnection.tx.QueryRowContext(ctx, *query, *args...)
} else {
row = dbConnection.db.QueryRowContext(ctx, *query, *args...)
}
if slowSQLMillis > 0 {
slow := time.Since(*start).Milliseconds()
if slow-int64(slowSQLMillis) >= 0 {
FuncPrintSQL(ctx, *query, *args, slow)
}
}
return row, nil
}
// queryContext 查询数据,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
// queryRowContext Execute sql row statement,If the transaction has been opened,it will be executed in transaction mode, if the transaction is not opened,it will be executed in non-transactional mode
func (dbConnection *dataBaseConnection) queryContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Rows, error) {
// reBindSQL 重新处理参数代入方式
query, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
if err != nil {
return nil, err
}
// 执行前加入 hint
err = wrapSQLHint(ctx, query)
if err != nil {
return nil, err
}
var start *time.Time
var rows *sql.Rows
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
slowSQLMillis := dbConnection.config.SlowSQLMillis
if slowSQLMillis == 0 {
FuncPrintSQL(ctx, *query, *args, 0)
} else if slowSQLMillis > 0 {
now := time.Now() // 获取当前时间
start = &now
}
if dbConnection.tx != nil {
rows, err = dbConnection.tx.QueryContext(ctx, *query, *args...)
} else {
rows, err = dbConnection.db.QueryContext(ctx, *query, *args...)
}
if slowSQLMillis > 0 {
slow := time.Since(*start).Milliseconds()
if slow-int64(slowSQLMillis) >= 0 {
FuncPrintSQL(ctx, *query, *args, slow)
}
}
if err != nil {
err = fmt.Errorf("->queryContext执行错误:%w,-->zormErrorExecSQL:%s,-->zormErrorSQLValues:%v", err, *query, *args)
}
return rows, err
}
/*
// prepareContext 预执行,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
// prepareContext Pre-execution,If the transaction has been opened,it will be executed in transaction mode,if the transaction is not opened,it will be executed in non-transactional mode
func (dbConnection *dataBaseConnection) prepareContext(ctx context.Context, query *string) (*sql.Stmt, error) {
//打印SQL
//print SQL
if dbConnection.config.PrintSQL {
//logger.Info("printSQL", logger.String("sql", query))
FuncPrintSQL(ctx,*query, nil)
}
if dbConnection.tx != nil {
return dbConnection.tx.PrepareContext(ctx, *query)
}
return dbConnection.db.PrepareContext(ctx, *query)
}
*/

@ -1,415 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Multiprecision decimal numbers.
// For floating-point formatting only; not general purpose.
// Only operations are assign and (binary) left/right shift.
// Can do binary floating point in multiprecision decimal precisely
// because 2 divides 10; cannot do decimal floating point
// in multiprecision binary precisely.
package decimal
type decimal struct {
d [800]byte // digits, big-endian representation
nd int // number of digits used
dp int // decimal point
neg bool // negative flag
trunc bool // discarded nonzero digits beyond d[:nd]
}
func (a *decimal) String() string {
n := 10 + a.nd
if a.dp > 0 {
n += a.dp
}
if a.dp < 0 {
n += -a.dp
}
buf := make([]byte, n)
w := 0
switch {
case a.nd == 0:
return "0"
case a.dp <= 0:
// zeros fill space between decimal point and digits
buf[w] = '0'
w++
buf[w] = '.'
w++
w += digitZero(buf[w : w+-a.dp])
w += copy(buf[w:], a.d[0:a.nd])
case a.dp < a.nd:
// decimal point in middle of digits
w += copy(buf[w:], a.d[0:a.dp])
buf[w] = '.'
w++
w += copy(buf[w:], a.d[a.dp:a.nd])
default:
// zeros fill space between digits and decimal point
w += copy(buf[w:], a.d[0:a.nd])
w += digitZero(buf[w : w+a.dp-a.nd])
}
return string(buf[0:w])
}
func digitZero(dst []byte) int {
for i := range dst {
dst[i] = '0'
}
return len(dst)
}
// trim trailing zeros from number.
// (They are meaningless; the decimal point is tracked
// independent of the number of digits.)
func trim(a *decimal) {
for a.nd > 0 && a.d[a.nd-1] == '0' {
a.nd--
}
if a.nd == 0 {
a.dp = 0
}
}
// Assign v to a.
func (a *decimal) Assign(v uint64) {
var buf [24]byte
// Write reversed decimal in buf.
n := 0
for v > 0 {
v1 := v / 10
v -= 10 * v1
buf[n] = byte(v + '0')
n++
v = v1
}
// Reverse again to produce forward decimal in a.d.
a.nd = 0
for n--; n >= 0; n-- {
a.d[a.nd] = buf[n]
a.nd++
}
a.dp = a.nd
trim(a)
}
// Maximum shift that we can do in one pass without overflow.
// A uint has 32 or 64 bits, and we have to be able to accommodate 9<<k.
const uintSize = 32 << (^uint(0) >> 63)
const maxShift = uintSize - 4
// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow.
func rightShift(a *decimal, k uint) {
r := 0 // read pointer
w := 0 // write pointer
// Pick up enough leading digits to cover first shift.
var n uint
for ; n>>k == 0; r++ {
if r >= a.nd {
if n == 0 {
// a == 0; shouldn't get here, but handle anyway.
a.nd = 0
return
}
for n>>k == 0 {
n = n * 10
r++
}
break
}
c := uint(a.d[r])
n = n*10 + c - '0'
}
a.dp -= r - 1
var mask uint = (1 << k) - 1
// Pick up a digit, put down a digit.
for ; r < a.nd; r++ {
c := uint(a.d[r])
dig := n >> k
n &= mask
a.d[w] = byte(dig + '0')
w++
n = n*10 + c - '0'
}
// Put down extra digits.
for n > 0 {
dig := n >> k
n &= mask
if w < len(a.d) {
a.d[w] = byte(dig + '0')
w++
} else if dig > 0 {
a.trunc = true
}
n = n * 10
}
a.nd = w
trim(a)
}
// Cheat sheet for left shift: table indexed by shift count giving
// number of new digits that will be introduced by that shift.
//
// For example, leftcheats[4] = {2, "625"}. That means that
// if we are shifting by 4 (multiplying by 16), it will add 2 digits
// when the string prefix is "625" through "999", and one fewer digit
// if the string prefix is "000" through "624".
//
// Credit for this trick goes to Ken.
type leftCheat struct {
delta int // number of new digits
cutoff string // minus one digit if original < a.
}
var leftcheats = []leftCheat{
// Leading digits of 1/2^i = 5^i.
// 5^23 is not an exact 64-bit floating point number,
// so have to use bc for the math.
// Go up to 60 to be large enough for 32bit and 64bit platforms.
/*
seq 60 | sed 's/^/5^/' | bc |
awk 'BEGIN{ print "\t{ 0, \"\" }," }
{
log2 = log(2)/log(10)
printf("\t{ %d, \"%s\" },\t// * %d\n",
int(log2*NR+1), $0, 2**NR)
}'
*/
{0, ""},
{1, "5"}, // * 2
{1, "25"}, // * 4
{1, "125"}, // * 8
{2, "625"}, // * 16
{2, "3125"}, // * 32
{2, "15625"}, // * 64
{3, "78125"}, // * 128
{3, "390625"}, // * 256
{3, "1953125"}, // * 512
{4, "9765625"}, // * 1024
{4, "48828125"}, // * 2048
{4, "244140625"}, // * 4096
{4, "1220703125"}, // * 8192
{5, "6103515625"}, // * 16384
{5, "30517578125"}, // * 32768
{5, "152587890625"}, // * 65536
{6, "762939453125"}, // * 131072
{6, "3814697265625"}, // * 262144
{6, "19073486328125"}, // * 524288
{7, "95367431640625"}, // * 1048576
{7, "476837158203125"}, // * 2097152
{7, "2384185791015625"}, // * 4194304
{7, "11920928955078125"}, // * 8388608
{8, "59604644775390625"}, // * 16777216
{8, "298023223876953125"}, // * 33554432
{8, "1490116119384765625"}, // * 67108864
{9, "7450580596923828125"}, // * 134217728
{9, "37252902984619140625"}, // * 268435456
{9, "186264514923095703125"}, // * 536870912
{10, "931322574615478515625"}, // * 1073741824
{10, "4656612873077392578125"}, // * 2147483648
{10, "23283064365386962890625"}, // * 4294967296
{10, "116415321826934814453125"}, // * 8589934592
{11, "582076609134674072265625"}, // * 17179869184
{11, "2910383045673370361328125"}, // * 34359738368
{11, "14551915228366851806640625"}, // * 68719476736
{12, "72759576141834259033203125"}, // * 137438953472
{12, "363797880709171295166015625"}, // * 274877906944
{12, "1818989403545856475830078125"}, // * 549755813888
{13, "9094947017729282379150390625"}, // * 1099511627776
{13, "45474735088646411895751953125"}, // * 2199023255552
{13, "227373675443232059478759765625"}, // * 4398046511104
{13, "1136868377216160297393798828125"}, // * 8796093022208
{14, "5684341886080801486968994140625"}, // * 17592186044416
{14, "28421709430404007434844970703125"}, // * 35184372088832
{14, "142108547152020037174224853515625"}, // * 70368744177664
{15, "710542735760100185871124267578125"}, // * 140737488355328
{15, "3552713678800500929355621337890625"}, // * 281474976710656
{15, "17763568394002504646778106689453125"}, // * 562949953421312
{16, "88817841970012523233890533447265625"}, // * 1125899906842624
{16, "444089209850062616169452667236328125"}, // * 2251799813685248
{16, "2220446049250313080847263336181640625"}, // * 4503599627370496
{16, "11102230246251565404236316680908203125"}, // * 9007199254740992
{17, "55511151231257827021181583404541015625"}, // * 18014398509481984
{17, "277555756156289135105907917022705078125"}, // * 36028797018963968
{17, "1387778780781445675529539585113525390625"}, // * 72057594037927936
{18, "6938893903907228377647697925567626953125"}, // * 144115188075855872
{18, "34694469519536141888238489627838134765625"}, // * 288230376151711744
{18, "173472347597680709441192448139190673828125"}, // * 576460752303423488
{19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976
}
// Is the leading prefix of b lexicographically less than s?
func prefixIsLessThan(b []byte, s string) bool {
for i := 0; i < len(s); i++ {
if i >= len(b) {
return true
}
if b[i] != s[i] {
return b[i] < s[i]
}
}
return false
}
// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow.
func leftShift(a *decimal, k uint) {
delta := leftcheats[k].delta
if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) {
delta--
}
r := a.nd // read index
w := a.nd + delta // write index
// Pick up a digit, put down a digit.
var n uint
for r--; r >= 0; r-- {
n += (uint(a.d[r]) - '0') << k
quo := n / 10
rem := n - 10*quo
w--
if w < len(a.d) {
a.d[w] = byte(rem + '0')
} else if rem != 0 {
a.trunc = true
}
n = quo
}
// Put down extra digits.
for n > 0 {
quo := n / 10
rem := n - 10*quo
w--
if w < len(a.d) {
a.d[w] = byte(rem + '0')
} else if rem != 0 {
a.trunc = true
}
n = quo
}
a.nd += delta
if a.nd >= len(a.d) {
a.nd = len(a.d)
}
a.dp += delta
trim(a)
}
// Binary shift left (k > 0) or right (k < 0).
func (a *decimal) Shift(k int) {
switch {
case a.nd == 0:
// nothing to do: a == 0
case k > 0:
for k > maxShift {
leftShift(a, maxShift)
k -= maxShift
}
leftShift(a, uint(k))
case k < 0:
for k < -maxShift {
rightShift(a, maxShift)
k += maxShift
}
rightShift(a, uint(-k))
}
}
// If we chop a at nd digits, should we round up?
func shouldRoundUp(a *decimal, nd int) bool {
if nd < 0 || nd >= a.nd {
return false
}
if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even
// if we truncated, a little higher than what's recorded - always round up
if a.trunc {
return true
}
return nd > 0 && (a.d[nd-1]-'0')%2 != 0
}
// not halfway - digit tells all
return a.d[nd] >= '5'
}
// Round a to nd digits (or fewer).
// If nd is zero, it means we're rounding
// just to the left of the digits, as in
// 0.09 -> 0.1.
func (a *decimal) Round(nd int) {
if nd < 0 || nd >= a.nd {
return
}
if shouldRoundUp(a, nd) {
a.RoundUp(nd)
} else {
a.RoundDown(nd)
}
}
// Round a down to nd digits (or fewer).
func (a *decimal) RoundDown(nd int) {
if nd < 0 || nd >= a.nd {
return
}
a.nd = nd
trim(a)
}
// Round a up to nd digits (or fewer).
func (a *decimal) RoundUp(nd int) {
if nd < 0 || nd >= a.nd {
return
}
// round up
for i := nd - 1; i >= 0; i-- {
c := a.d[i]
if c < '9' { // can stop after this digit
a.d[i]++
a.nd = i + 1
return
}
}
// Number is all 9s.
// Change to single 1 with adjusted decimal point.
a.d[0] = '1'
a.nd = 1
a.dp++
}
// Extract integer part, rounded appropriately.
// No guarantees about overflow.
func (a *decimal) RoundedInteger() uint64 {
if a.dp > 20 {
return 0xFFFFFFFFFFFFFFFF
}
var i int
n := uint64(0)
for i = 0; i < a.dp && i < a.nd; i++ {
n = n*10 + uint64(a.d[i]-'0')
}
for ; i < a.dp; i++ {
n *= 10
}
if shouldRoundUp(a, a.dp) {
n++
}
return n
}

File diff suppressed because it is too large Load Diff

@ -1,160 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Multiprecision decimal numbers.
// For floating-point formatting only; not general purpose.
// Only operations are assign and (binary) left/right shift.
// Can do binary floating point in multiprecision decimal precisely
// because 2 divides 10; cannot do decimal floating point
// in multiprecision binary precisely.
package decimal
type floatInfo struct {
mantbits uint
expbits uint
bias int
}
var float32info = floatInfo{23, 8, -127}
var float64info = floatInfo{52, 11, -1023}
// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits
// that will let the original floating point value be precisely reconstructed.
func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
// If mantissa is zero, the number is zero; stop now.
if mant == 0 {
d.nd = 0
return
}
// Compute upper and lower such that any decimal number
// between upper and lower (possibly inclusive)
// will round to the original floating point number.
// We may see at once that the number is already shortest.
//
// Suppose d is not denormal, so that 2^exp <= d < 10^dp.
// The closest shorter number is at least 10^(dp-nd) away.
// The lower/upper bounds computed below are at distance
// at most 2^(exp-mantbits).
//
// So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits),
// or equivalently log2(10)*(dp-nd) > exp-mantbits.
// It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32).
minexp := flt.bias + 1 // minimum possible exponent
if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) {
// The number is already shortest.
return
}
// d = mant << (exp - mantbits)
// Next highest floating point number is mant+1 << exp-mantbits.
// Our upper bound is halfway between, mant*2+1 << exp-mantbits-1.
upper := new(decimal)
upper.Assign(mant*2 + 1)
upper.Shift(exp - int(flt.mantbits) - 1)
// d = mant << (exp - mantbits)
// Next lowest floating point number is mant-1 << exp-mantbits,
// unless mant-1 drops the significant bit and exp is not the minimum exp,
// in which case the next lowest is mant*2-1 << exp-mantbits-1.
// Either way, call it mantlo << explo-mantbits.
// Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1.
var mantlo uint64
var explo int
if mant > 1<<flt.mantbits || exp == minexp {
mantlo = mant - 1
explo = exp
} else {
mantlo = mant*2 - 1
explo = exp - 1
}
lower := new(decimal)
lower.Assign(mantlo*2 + 1)
lower.Shift(explo - int(flt.mantbits) - 1)
// The upper and lower bounds are possible outputs only if
// the original mantissa is even, so that IEEE round-to-even
// would round to the original mantissa and not the neighbors.
inclusive := mant%2 == 0
// As we walk the digits we want to know whether rounding up would fall
// within the upper bound. This is tracked by upperdelta:
//
// If upperdelta == 0, the digits of d and upper are the same so far.
//
// If upperdelta == 1, we saw a difference of 1 between d and upper on a
// previous digit and subsequently only 9s for d and 0s for upper.
// (Thus rounding up may fall outside the bound, if it is exclusive.)
//
// If upperdelta == 2, then the difference is greater than 1
// and we know that rounding up falls within the bound.
var upperdelta uint8
// Now we can figure out the minimum number of digits required.
// Walk along until d has distinguished itself from upper and lower.
for ui := 0; ; ui++ {
// lower, d, and upper may have the decimal points at different
// places. In this case upper is the longest, so we iterate from
// ui==0 and start li and mi at (possibly) -1.
mi := ui - upper.dp + d.dp
if mi >= d.nd {
break
}
li := ui - upper.dp + lower.dp
l := byte('0') // lower digit
if li >= 0 && li < lower.nd {
l = lower.d[li]
}
m := byte('0') // middle digit
if mi >= 0 {
m = d.d[mi]
}
u := byte('0') // upper digit
if ui < upper.nd {
u = upper.d[ui]
}
// Okay to round down (truncate) if lower has a different digit
// or if lower is inclusive and is exactly the result of rounding
// down (i.e., and we have reached the final digit of lower).
okdown := l != m || inclusive && li+1 == lower.nd
switch {
case upperdelta == 0 && m+1 < u:
// Example:
// m = 12345xxx
// u = 12347xxx
upperdelta = 2
case upperdelta == 0 && m != u:
// Example:
// m = 12345xxx
// u = 12346xxx
upperdelta = 1
case upperdelta == 1 && (m != '9' || u != '0'):
// Example:
// m = 1234598x
// u = 1234600x
upperdelta = 2
}
// Okay to round up if upper has a different digit and either upper
// is inclusive or upper is bigger than the result of rounding up.
okup := upperdelta > 0 && (inclusive || upperdelta > 1 || ui+1 < upper.nd)
// If it's okay to do either, then round to the nearest one.
// If it's okay to do only one, do it.
switch {
case okdown && okup:
d.Round(mi + 1)
return
case okdown:
d.RoundDown(mi + 1)
return
case okup:
d.RoundUp(mi + 1)
return
}
}
}

File diff suppressed because it is too large Load Diff

@ -1,564 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import (
"context"
"database/sql"
"errors"
"fmt"
"go/ast"
"reflect"
"strings"
"sync"
)
const (
// tag标签的名称
tagColumnName = "column"
// 输出字段 缓存的前缀
exportPrefix = "_exportStructFields_"
// 私有字段 缓存的前缀
privatePrefix = "_privateStructFields_"
// 数据库列名 缓存的前缀
dbColumnNamePrefix = "_dbColumnName_"
// 数据库所有列名,经过排序 缓存的前缀
dbColumnNameSlicePrefix = "_dbColumnNameSlice_"
// field对应的column的tag值 缓存的前缀
// structFieldTagPrefix = "_structFieldTag_"
// 数据库主键 缓存的前缀
// dbPKNamePrefix = "_dbPKName_"
)
// cacheStructFieldInfoMap 用于缓存反射的信息,sync.Map内部处理了并发锁
var cacheStructFieldInfoMap *sync.Map = &sync.Map{}
// var cacheStructFieldInfoMap = make(map[string]map[string]reflect.StructField)
// 用于缓存field对应的column的tag值
// var cacheStructFieldTagInfoMap = make(map[string]map[string]string)
// structFieldInfo 获取StructField的信息.只对struct或者*struct判断,如果是指针,返回指针下实际的struct类型
// 第一个返回值是可以输出的字段(首字母大写),第二个是不能输出的字段(首字母小写)
func structFieldInfo(typeOf *reflect.Type) error {
if typeOf == nil {
return errors.New("->structFieldInfo数据为空")
}
entityName := (*typeOf).String()
// 缓存的key
// 所有输出的属性,包含数据库字段,key是struct属性的名称,不区分大小写
exportCacheKey := exportPrefix + entityName
// 所有私有变量的属性,key是struct属性的名称,不区分大小写
privateCacheKey := privatePrefix + entityName
// 所有数据库的属性,key是数据库的字段名称,不区分大小写
dbColumnCacheKey := dbColumnNamePrefix + entityName
// 所有数据库字段名称的slice,经过排序,不区分大小写
dbColumnNameSliceCacheKey := dbColumnNameSlicePrefix + entityName
// structFieldTagCacheKey := structFieldTagPrefix + entityName
// dbPKNameCacheKey := dbPKNamePrefix + entityName
// 缓存的数据库主键值
_, exportOk := cacheStructFieldInfoMap.Load(exportCacheKey)
//_, exportOk := cacheStructFieldInfoMap[exportCacheKey]
//如果存在值,认为缓存中有所有的信息,不再处理
if exportOk {
return nil
}
// 获取字段长度
fieldNum := (*typeOf).NumField()
// 如果没有字段
if fieldNum < 1 {
return errors.New("->structFieldInfo-->NumField entity没有属性")
}
// 声明所有字段的载体
var allFieldMap *sync.Map = &sync.Map{}
// anonymous := make([]reflect.StructField, 0)
// 缓存的数据
exportStructFieldMap := make(map[string]reflect.StructField)
privateStructFieldMap := make(map[string]reflect.StructField)
dbColumnFieldMap := make(map[string]reflect.StructField)
// structFieldTagMap := make(map[string]string)
dbColumnFieldNameSlice := make([]string, 0)
// 遍历sync.Map,要求输入一个func作为参数
// 这个函数的入参、出参的类型都已经固定,不能修改
// 可以在函数体内编写自己的代码,调用map中的k,v
// var funcMapKV func(k, v interface{}) bool
funcMapKV := func(k, v interface{}) bool {
field := v.(reflect.StructField)
fieldName := field.Name
if ast.IsExported(fieldName) { // 如果是可以输出的,不区分大小写
exportStructFieldMap[strings.ToLower(fieldName)] = field
// 如果是数据库字段
tagColumnValue := field.Tag.Get(tagColumnName)
if len(tagColumnValue) > 0 {
// dbColumnFieldMap[tagColumnValue] = field
// 使用数据库字段的小写,处理oracle和达梦数据库的sql返回值大写
tagColumnValueLower := strings.ToLower(tagColumnValue)
dbColumnFieldMap[tagColumnValueLower] = field
dbColumnFieldNameSlice = append(dbColumnFieldNameSlice, tagColumnValueLower)
// structFieldTagMap[fieldName] = tagColumnValue
}
} else { // 私有属性
privateStructFieldMap[strings.ToLower(fieldName)] = field
}
return true
}
// 并发锁,用于处理slice并发append
var lock sync.Mutex
// funcRecursiveAnonymous 递归调用struct的匿名属性,就近覆盖属性
var funcRecursiveAnonymous func(allFieldMap *sync.Map, anonymous *reflect.StructField)
funcRecursiveAnonymous = func(allFieldMap *sync.Map, anonymous *reflect.StructField) {
// 字段类型
anonymousTypeOf := anonymous.Type
if anonymousTypeOf.Kind() == reflect.Ptr {
// 获取指针下的Struct类型
anonymousTypeOf = anonymousTypeOf.Elem()
}
// 只处理Struct类型
if anonymousTypeOf.Kind() != reflect.Struct {
return
}
// 获取字段长度
fieldNum := anonymousTypeOf.NumField()
// 如果没有字段
if fieldNum < 1 {
return
}
// 遍历所有字段
for i := 0; i < fieldNum; i++ {
anonymousField := anonymousTypeOf.Field(i)
if anonymousField.Anonymous { // 匿名struct里自身又有匿名struct
funcRecursiveAnonymous(allFieldMap, &anonymousField)
} else if _, ok := allFieldMap.Load(anonymousField.Name); !ok { // 普通命名字段,而且没有记录过
allFieldMap.Store(anonymousField.Name, anonymousField)
lock.Lock()
funcMapKV(anonymousField.Name, anonymousField)
lock.Unlock()
}
}
}
// 遍历所有字段,记录匿名属性
for i := 0; i < fieldNum; i++ {
field := (*typeOf).Field(i)
if field.Anonymous { // 如果是匿名的
funcRecursiveAnonymous(allFieldMap, &field)
} else if _, ok := allFieldMap.Load(field.Name); !ok { // 普通命名字段,而且没有记录过
allFieldMap.Store(field.Name, field)
lock.Lock()
funcMapKV(field.Name, field)
lock.Unlock()
}
}
// allFieldMap.Range(f)
// 加入缓存
cacheStructFieldInfoMap.Store(exportCacheKey, exportStructFieldMap)
cacheStructFieldInfoMap.Store(privateCacheKey, privateStructFieldMap)
cacheStructFieldInfoMap.Store(dbColumnCacheKey, dbColumnFieldMap)
// cacheStructFieldInfoMap[exportCacheKey] = exportStructFieldMap
// cacheStructFieldInfoMap[privateCacheKey] = privateStructFieldMap
// cacheStructFieldInfoMap[dbColumnCacheKey] = dbColumnFieldMap
// cacheStructFieldTagInfoMap[structFieldTagCacheKey] = structFieldTagMap
// 不按照字母顺序,按照反射获取的Struct属性顺序,生成insert语句和update语句
// sort.Strings(dbColumnFieldNameSlice)
cacheStructFieldInfoMap.Store(dbColumnNameSliceCacheKey, dbColumnFieldNameSlice)
return nil
}
// setFieldValueByColumnName 根据数据库的字段名,找到struct映射的字段,并赋值
func setFieldValueByColumnName(entity interface{}, columnName string, value interface{}) error {
// 先从本地缓存中查找
typeOf := reflect.TypeOf(entity)
valueOf := reflect.ValueOf(entity)
if typeOf.Kind() == reflect.Ptr { // 如果是指针
typeOf = typeOf.Elem()
valueOf = valueOf.Elem()
}
dbMap, err := getDBColumnFieldMap(&typeOf)
if err != nil {
return err
}
f, ok := dbMap[strings.ToLower(columnName)]
if ok { // 给主键赋值
valueOf.FieldByName(f.Name).Set(reflect.ValueOf(value))
}
return nil
}
// structFieldValue 获取指定字段的值
func structFieldValue(s interface{}, fieldName string) (interface{}, error) {
if s == nil || len(fieldName) < 1 {
return nil, errors.New("->structFieldValue数据为空")
}
// entity的s类型
valueOf := reflect.ValueOf(s)
kind := valueOf.Kind()
if !(kind == reflect.Ptr || kind == reflect.Struct) {
return nil, errors.New("->structFieldValue必须是Struct或者*Struct类型")
}
if kind == reflect.Ptr {
// 获取指针下的Struct类型
valueOf = valueOf.Elem()
if valueOf.Kind() != reflect.Struct {
return nil, errors.New("->structFieldValue必须是Struct或者*Struct类型")
}
}
// FieldByName方法返回的是reflect.Value类型,调用Interface()方法,返回原始类型的数据值
value := valueOf.FieldByName(fieldName).Interface()
return value, nil
}
// getDBColumnExportFieldMap 获取实体类的数据库字段,key是数据库的字段名称.同时返回所有的字段属性的map,key是实体类的属性.不区分大小写
func getDBColumnExportFieldMap(typeOf *reflect.Type) (map[string]reflect.StructField, map[string]reflect.StructField, error) {
dbColumnFieldMap, err := getCacheStructFieldInfoMap(typeOf, dbColumnNamePrefix)
if err != nil {
return nil, nil, err
}
exportFieldMap, err := getCacheStructFieldInfoMap(typeOf, exportPrefix)
return dbColumnFieldMap, exportFieldMap, err
}
// getDBColumnFieldMap 获取实体类的数据库字段,key是数据库的字段名称.不区分大小写
func getDBColumnFieldMap(typeOf *reflect.Type) (map[string]reflect.StructField, error) {
return getCacheStructFieldInfoMap(typeOf, dbColumnNamePrefix)
}
// getDBColumnFieldNameSlice 获取实体类的数据库字段,经过排序,key是数据库的字段名称.不区分大小写,
func getDBColumnFieldNameSlice(typeOf *reflect.Type) ([]string, error) {
dbColumnFieldSlice, dbmapErr := getCacheStructFieldInfo(typeOf, dbColumnNameSlicePrefix)
if dbmapErr != nil {
return nil, fmt.Errorf("->getDBColumnFieldNameSlice-->getCacheStructFieldInfo()取值错误:%w", dbmapErr)
}
dbcfSlice, efOK := dbColumnFieldSlice.([]string)
if !efOK {
return dbcfSlice, errors.New("->getDBColumnFieldNameSlice-->dbColumnFieldSlice取值转[]string类型异常")
}
return dbcfSlice, nil
}
// getCacheStructFieldInfo 根据类型和key,获取缓存的数据字段信息slice,已经排序
func getCacheStructFieldInfo(typeOf *reflect.Type, keyPrefix string) (interface{}, error) {
if typeOf == nil {
return nil, errors.New("->getCacheStructFieldInfo-->typeOf不能为空")
}
key := keyPrefix + (*typeOf).String()
dbColumnFieldMap, dbOk := cacheStructFieldInfoMap.Load(key)
// dbColumnFieldMap, dbOk := cacheStructFieldInfoMap[key]
if !dbOk { // 缓存不存在
// 获取实体类的输出字段和私有 字段
err := structFieldInfo(typeOf)
if err != nil {
return nil, err
}
dbColumnFieldMap, dbOk = cacheStructFieldInfoMap.Load(key)
// dbColumnFieldMap, dbOk = cacheStructFieldInfoMap[key]
if !dbOk {
return nil, errors.New("->getCacheStructFieldInfo-->cacheStructFieldInfoMap.Load()获取数据库字段dbColumnFieldMap异常")
}
}
return dbColumnFieldMap, nil
// return dbColumnFieldMap, nil
}
// getCacheStructFieldInfoMap 根据类型和key,获取缓存的字段信息
func getCacheStructFieldInfoMap(typeOf *reflect.Type, keyPrefix string) (map[string]reflect.StructField, error) {
dbColumnFieldMap, dbmapErr := getCacheStructFieldInfo(typeOf, keyPrefix)
if dbmapErr != nil {
return nil, fmt.Errorf("->getCacheStructFieldInfoMap-->getCacheStructFieldInfo()取值错误:%w", dbmapErr)
}
dbcfMap, efOK := dbColumnFieldMap.(map[string]reflect.StructField)
if !efOK {
return dbcfMap, errors.New("->getCacheStructFieldInfoMap-->dbColumnFieldMap取值转map[string]reflect.StructField类型异常")
}
return dbcfMap, nil
// return dbColumnFieldMap, nil
}
// columnAndValue 根据保存的对象,返回插入的语句,需要插入的字段,字段的值
func columnAndValue(entity interface{}) (reflect.Type, []reflect.StructField, []interface{}, error) {
typeOf, checkerr := checkEntityKind(entity)
if checkerr != nil {
return typeOf, nil, nil, checkerr
}
// 获取实体类的反射,指针下的struct
valueOf := reflect.ValueOf(entity).Elem()
// reflect.Indirect
// 先从本地缓存中查找
// typeOf := reflect.TypeOf(entity).Elem()
dbMap, err := getDBColumnFieldMap(&typeOf)
if err != nil {
return typeOf, nil, nil, err
}
dbSlice, err := getDBColumnFieldNameSlice(&typeOf)
if err != nil {
return typeOf, nil, nil, err
}
// 实体类公开字段的长度
fLen := len(dbMap)
// 长度不一致
if fLen-len(dbSlice) != 0 {
return typeOf, nil, nil, errors.New("->columnAndValue-->缓存的数据库字段和实体类字段不对应")
}
// 接收列的数组,这里是做一个副本,避免外部更改掉原始的列信息
columns := make([]reflect.StructField, 0, fLen)
// 接收值的数组
values := make([]interface{}, 0, fLen)
// 遍历所有数据库属性
for _, fieldName := range dbSlice {
//获取字段类型的Kind
// fieldKind := field.Type.Kind()
//if !allowTypeMap[fieldKind] { //不允许的类型
// continue
//}
field := dbMap[fieldName]
columns = append(columns, field)
// FieldByName方法返回的是reflect.Value类型,调用Interface()方法,返回原始类型的数据值.字段不会重名,不使用FieldByIndex()函数
value := valueOf.FieldByName(field.Name).Interface()
// 添加到记录值的数组
values = append(values, value)
}
// 缓存数据库的列
return typeOf, columns, values, nil
}
// entityPKFieldName 获取实体类主键属性名称
func entityPKFieldName(entity IEntityStruct, typeOf *reflect.Type) (string, error) {
//检查是否是指针对象
//typeOf, checkerr := checkEntityKind(entity)
//if checkerr != nil {
// return "", checkerr
//}
// 缓存的key,TypeOf和ValueOf的String()方法,返回值不一样
// typeOf := reflect.TypeOf(entity).Elem()
dbMap, err := getDBColumnFieldMap(typeOf)
if err != nil {
return "", err
}
field := dbMap[strings.ToLower(entity.GetPKColumnName())]
return field.Name, nil
}
// checkEntityKind 检查entity类型必须是*struct类型或者基础类型的指针
func checkEntityKind(entity interface{}) (reflect.Type, error) {
if entity == nil {
return nil, errors.New("->checkEntityKind参数不能为空,必须是*struct类型或者基础类型的指针")
}
typeOf := reflect.TypeOf(entity)
if typeOf.Kind() != reflect.Ptr { // 如果不是指针
return nil, errors.New("->checkEntityKind必须是*struct类型或者基础类型的指针")
}
typeOf = typeOf.Elem()
//if !(typeOf.Kind() == reflect.Struct || allowBaseTypeMap[typeOf.Kind()]) { //如果不是指针
// return nil, errors.New("checkEntityKind必须是*struct类型或者基础类型的指针")
//}
return typeOf, nil
}
// sqlRowsValues 包装接收sqlRows的Values数组,反射rows屏蔽数据库null值,兼容单个字段查询和Struct映射
// fix:converting NULL to int is unsupported
// 当读取数据库的值为NULL时,由于基本类型不支持为NULL,通过反射将未知driver.Value改为interface{},不再映射到struct实体类
// 感谢@fastabler提交的pr
// oneColumnScanner 只有一个字段,而且可以直接Scan,例如string或者[]string,不需要反射StructType进行处理
func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error {
if entity == nil && valueOf == nil {
return errors.New("->sqlRowsValues-->valueOfElem为nil")
}
var valueOfElem reflect.Value
if entity == nil && valueOf != nil {
valueOfElem = valueOf.Elem()
}
ctLen := len(columnTypes)
// 声明载体数组,用于存放struct的属性指针
// Declare a carrier array to store the attribute pointer of the struct
values := make([]interface{}, ctLen)
// 记录需要类型转换的字段信息
var fieldTempDriverValueMap map[*sql.ColumnType]*driverValueInfo
if iscdvm {
fieldTempDriverValueMap = make(map[*sql.ColumnType]*driverValueInfo)
}
var err error
var customDriverValueConver ICustomDriverValueConver
var converOK bool
for i, columnType := range columnTypes {
if iscdvm {
databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName())
// 根据接收的类型,获取到类型转换的接口实现,优先匹配指定的数据库类型
customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName]
if !converOK {
customDriverValueConver, converOK = customDriverValueMap[databaseTypeName]
}
}
dv := driverValue.Index(i)
if dv.IsValid() && dv.InterfaceData()[0] == 0 { // 该字段的数据库值是null,取默认值
values[i] = new(interface{})
continue
} else if converOK { // 如果是需要转换的字段
// 获取字段类型
var structFieldType *reflect.Type
if entity != nil { // 查询一个字段,并且可以直接接收
structFieldType = typeOf
} else { // 如果是struct类型
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
if err != nil {
return err
}
if field != nil { // 存在这个字段
vtype := field.Type
structFieldType = &vtype
}
}
tempDriverValue, err := customDriverValueConver.GetDriverValue(ctx, columnType, structFieldType)
if err != nil {
return err
}
if tempDriverValue == nil {
return errors.New("->sqlRowsValues-->customDriverValueConver.GetDriverValue返回的driver.Value不能为nil")
}
values[i] = tempDriverValue
// 如果需要类型转换
dvinfo := driverValueInfo{}
dvinfo.customDriverValueConver = customDriverValueConver
// dvinfo.columnType = columnType
dvinfo.structFieldType = structFieldType
dvinfo.tempDriverValue = tempDriverValue
fieldTempDriverValueMap[columnType] = &dvinfo
continue
} else if entity != nil { // 查询一个字段,并且可以直接接收
values[i] = entity
continue
} else {
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
if err != nil {
return err
}
if field == nil { // 如果不存在这个字段
values[i] = new(interface{})
} else {
// fieldType := refPV.FieldByName(field.Name).Type()
// v := reflect.New(field.Type).Interface()
// 字段的反射值
fieldValue := valueOfElem.FieldByName(field.Name)
v := fieldValue.Addr().Interface()
// v := new(interface{})
values[i] = v
}
}
}
err = rows.Scan(values...)
if err != nil {
return err
}
if len(fieldTempDriverValueMap) < 1 {
return err
}
// 循环需要替换的值
for columnType, driverValueInfo := range fieldTempDriverValueMap {
// 根据列名,字段类型,新值 返回符合接收类型值的指针,返回值是个指针,指针,指针!!!!
// typeOf := fieldValue.Type()
rightValue, errConverDriverValue := driverValueInfo.customDriverValueConver.ConverDriverValue(ctx, columnType, driverValueInfo.tempDriverValue, driverValueInfo.structFieldType)
if errConverDriverValue != nil {
errConverDriverValue = fmt.Errorf("->sqlRowsValues-->customDriverValueConver.ConverDriverValue错误:%w", errConverDriverValue)
FuncLogError(ctx, errConverDriverValue)
return errConverDriverValue
}
if entity != nil { // 查询一个字段,并且可以直接接收
// entity = rightValue
// valueOfElem.Set(reflect.ValueOf(rightValue).Elem())
reflect.ValueOf(entity).Elem().Set(reflect.ValueOf(rightValue).Elem())
continue
} else { // 如果是Struct类型接收
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
if err != nil {
return err
}
if field != nil { // 如果存在这个字段
// 字段的反射值
fieldValue := valueOfElem.FieldByName(field.Name)
// 给字段赋值
fieldValue.Set(reflect.ValueOf(rightValue).Elem())
}
}
}
return err
}
// getStructFieldByColumnType 根据ColumnType获取StructField对象,兼容驼峰
func getStructFieldByColumnType(columnType *sql.ColumnType, dbColumnFieldMap *map[string]reflect.StructField, exportFieldMap *map[string]reflect.StructField) (*reflect.StructField, error) {
columnName := strings.ToLower(columnType.Name())
// columnName := "test"
// 从缓存中获取列名的field字段
// Get the field field of the column name from the cache
field, fok := (*dbColumnFieldMap)[columnName]
if !fok {
field, fok = (*exportFieldMap)[columnName]
if !fok {
// 尝试驼峰
cname := strings.ReplaceAll(columnName, "_", "")
field, fok = (*exportFieldMap)[cname]
}
}
if fok {
return &field, nil
}
return nil, nil
}

@ -1,611 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import (
"context"
"errors"
"strconv"
"gitee.com/chunanyong/zorm/decimal"
)
// FuncDecimalValue 设置decimal类型接收值,复写函数自定义decimal实现,例如github.com/shopspring/decimal,返回的是指针
var FuncDecimalValue = func(ctx context.Context, dialect string) interface{} {
return &decimal.Decimal{}
}
// OverrideFunc 重写ZORM的函数,用于风险监控,只要查看这个函数的调用,就知道哪些地方重写了函数,避免项目混乱.当你使用这个函数时,你必须知道自己在做什么
// funcName 是需要重写的方法命,funcObject是对应的函数. 返回值bool是否重写成功,interface{}是重写前的函数
// 一般是在init里调用重写
func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, error) {
if funcName == "" {
return false, nil, errors.New("->OverrideFunc-->funcName不能为空")
}
// oldFunc 老的函数
var oldFunc interface{} = nil
switch funcName {
case "Transaction":
newFunc, ok := funcObject.(func(ctx context.Context, doTransaction func(ctx context.Context) (interface{}, error)) (interface{}, error))
if ok {
oldFunc = transaction
transaction = newFunc
}
case "QueryRow":
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, entity interface{}) (bool, error))
if ok {
oldFunc = queryRow
queryRow = newFunc
}
case "Query":
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, rowsSlicePtr interface{}, page *Page) error)
if ok {
oldFunc = query
query = newFunc
}
case "QueryRowMap":
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (map[string]interface{}, error))
if ok {
oldFunc = queryRowMap
queryRowMap = newFunc
}
case "QueryMap":
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, page *Page) ([]map[string]interface{}, error))
if ok {
oldFunc = queryMap
queryMap = newFunc
}
case "UpdateFinder":
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (int, error))
if ok {
oldFunc = updateFinder
updateFinder = newFunc
}
case "Insert":
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
if ok {
oldFunc = insert
insert = newFunc
}
case "InsertSlice":
newFunc, ok := funcObject.(func(ctx context.Context, entityStructSlice []IEntityStruct) (int, error))
if ok {
oldFunc = insertSlice
insertSlice = newFunc
}
case "Update":
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
if ok {
oldFunc = update
update = newFunc
}
case "UpdateNotZeroValue":
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
if ok {
oldFunc = updateNotZeroValue
updateNotZeroValue = newFunc
}
case "Delete":
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
if ok {
oldFunc = delete
delete = newFunc
}
case "InsertEntityMap":
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (int, error))
if ok {
oldFunc = insertEntityMap
insertEntityMap = newFunc
}
case "InsertEntityMapSlice":
newFunc, ok := funcObject.(func(ctx context.Context, entity []IEntityMap) (int, error))
if ok {
oldFunc = insertEntityMapSlice
insertEntityMapSlice = newFunc
}
case "UpdateEntityMap":
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (int, error))
if ok {
oldFunc = updateEntityMap
updateEntityMap = newFunc
}
default:
return false, oldFunc, errors.New("->OverrideFunc-->函数" + funcName + "暂不支持重写或不存在")
}
if oldFunc == nil {
return false, oldFunc, errors.New("->OverrideFunc-->请检查传入的" + funcName + "函数实现,断言转换失败.")
}
return true, oldFunc, nil
}
// typeConvertInt64toInt int64 转 int
func typeConvertInt64toInt(from int64) (int, error) {
strInt64 := strconv.FormatInt(from, 10)
return strconv.Atoi(strInt64)
}
/*
func typeConvertFloat32(i interface{}) (float32, error) {
if i == nil {
return 0, nil
}
if v, ok := i.(float32); ok {
return v, nil
}
v, err := typeConvertString(i)
if err != nil {
return 0, err
}
vf, err := strconv.ParseFloat(v, 32)
return float32(vf), err
}
func typeConvertFloat64(i interface{}) (float64, error) {
if i == nil {
return 0, nil
}
if v, ok := i.(float64); ok {
return v, nil
}
v, err := typeConvertString(i)
if err != nil {
return 0, err
}
return strconv.ParseFloat(v, 64)
}
func typeConvertDecimal(i interface{}) (decimal.Decimal, error) {
if i == nil {
return decimal.Zero, nil
}
if v, ok := i.(decimal.Decimal); ok {
return v, nil
}
v, err := typeConvertString(i)
if err != nil {
return decimal.Zero, err
}
return decimal.NewFromString(v)
}
func typeConvertInt64(i interface{}) (int64, error) {
if i == nil {
return 0, nil
}
if v, ok := i.(int64); ok {
return v, nil
}
v, err := typeConvertInt(i)
if err != nil {
return 0, err
}
return int64(v), err
}
func typeConvertString(i interface{}) (string, error) {
if i == nil {
return "", nil
}
switch value := i.(type) {
case int:
return strconv.Itoa(value), nil
case int8:
return strconv.Itoa(int(value)), nil
case int16:
return strconv.Itoa(int(value)), nil
case int32:
return strconv.Itoa(int(value)), nil
case int64:
return strconv.Itoa(int(value)), nil
case uint:
return strconv.FormatUint(uint64(value), 10), nil
case uint8:
return strconv.FormatUint(uint64(value), 10), nil
case uint16:
return strconv.FormatUint(uint64(value), 10), nil
case uint32:
return strconv.FormatUint(uint64(value), 10), nil
case uint64:
return strconv.FormatUint(uint64(value), 10), nil
case float32:
return strconv.FormatFloat(float64(value), 'f', -1, 32), nil
case float64:
return strconv.FormatFloat(value, 'f', -1, 64), nil
case bool:
return strconv.FormatBool(value), nil
case string:
return value, nil
case []byte:
return string(value), nil
default:
return fmt.Sprintf("%v", value), nil
}
}
//false: "", 0, false, off
func typeConvertBool(i interface{}) (bool, error) {
if i == nil {
return false, nil
}
if v, ok := i.(bool); ok {
return v, nil
}
s, err := typeConvertString(i)
if err != nil {
return false, err
}
if s != "" && s != "0" && s != "false" && s != "off" {
return true, err
}
return false, err
}
func typeConvertInt(i interface{}) (int, error) {
if i == nil {
return 0, nil
}
switch value := i.(type) {
case int:
return value, nil
case int8:
return int(value), nil
case int16:
return int(value), nil
case int32:
return int(value), nil
case int64:
return int(value), nil
case uint:
return int(value), nil
case uint8:
return int(value), nil
case uint16:
return int(value), nil
case uint32:
return int(value), nil
case uint64:
return int(value), nil
case float32:
return int(value), nil
case float64:
return int(value), nil
case bool:
if value {
return 1, nil
}
return 0, nil
default:
v, err := typeConvertString(value)
if err != nil {
return 0, err
}
return strconv.Atoi(v)
}
}
func typeConvertTime(i interface{}, format string, TZLocation ...*time.Location) (time.Time, error) {
s, err := typeConvertString(i)
if err != nil {
return time.Time{}, err
}
return typeConvertStrToTime(s, format, TZLocation...)
}
func typeConvertStrToTime(str string, format string, TZLocation ...*time.Location) (time.Time, error) {
if len(TZLocation) > 0 {
return time.ParseInLocation(format, str, TZLocation[0])
}
return time.ParseInLocation(format, str, time.Local)
}
func encodeString(s string) []byte {
return []byte(s)
}
func decodeToString(b []byte) string {
return string(b)
}
func encodeBool(b bool) []byte {
if b {
return []byte{1}
}
return []byte{0}
}
func encodeInt(i int) []byte {
if i <= math.MaxInt8 {
return encodeInt8(int8(i))
} else if i <= math.MaxInt16 {
return encodeInt16(int16(i))
} else if i <= math.MaxInt32 {
return encodeInt32(int32(i))
} else {
return encodeInt64(int64(i))
}
}
func encodeUint(i uint) []byte {
if i <= math.MaxUint8 {
return encodeUint8(uint8(i))
} else if i <= math.MaxUint16 {
return encodeUint16(uint16(i))
} else if i <= math.MaxUint32 {
return encodeUint32(uint32(i))
} else {
return encodeUint64(uint64(i))
}
}
func encodeInt8(i int8) []byte {
return []byte{byte(i)}
}
func encodeUint8(i uint8) []byte {
return []byte{byte(i)}
}
func encodeInt16(i int16) []byte {
bytes := make([]byte, 2)
binary.LittleEndian.PutUint16(bytes, uint16(i))
return bytes
}
func encodeUint16(i uint16) []byte {
bytes := make([]byte, 2)
binary.LittleEndian.PutUint16(bytes, i)
return bytes
}
func encodeInt32(i int32) []byte {
bytes := make([]byte, 4)
binary.LittleEndian.PutUint32(bytes, uint32(i))
return bytes
}
func encodeUint32(i uint32) []byte {
bytes := make([]byte, 4)
binary.LittleEndian.PutUint32(bytes, i)
return bytes
}
func encodeInt64(i int64) []byte {
bytes := make([]byte, 8)
binary.LittleEndian.PutUint64(bytes, uint64(i))
return bytes
}
func encodeUint64(i uint64) []byte {
bytes := make([]byte, 8)
binary.LittleEndian.PutUint64(bytes, i)
return bytes
}
func encodeFloat32(f float32) []byte {
bits := math.Float32bits(f)
bytes := make([]byte, 4)
binary.LittleEndian.PutUint32(bytes, bits)
return bytes
}
func encodeFloat64(f float64) []byte {
bits := math.Float64bits(f)
bytes := make([]byte, 8)
binary.LittleEndian.PutUint64(bytes, bits)
return bytes
}
func encode(vs ...interface{}) []byte {
buf := new(bytes.Buffer)
for i := 0; i < len(vs); i++ {
switch value := vs[i].(type) {
case int:
buf.Write(encodeInt(value))
case int8:
buf.Write(encodeInt8(value))
case int16:
buf.Write(encodeInt16(value))
case int32:
buf.Write(encodeInt32(value))
case int64:
buf.Write(encodeInt64(value))
case uint:
buf.Write(encodeUint(value))
case uint8:
buf.Write(encodeUint8(value))
case uint16:
buf.Write(encodeUint16(value))
case uint32:
buf.Write(encodeUint32(value))
case uint64:
buf.Write(encodeUint64(value))
case bool:
buf.Write(encodeBool(value))
case string:
buf.Write(encodeString(value))
case []byte:
buf.Write(value)
case float32:
buf.Write(encodeFloat32(value))
case float64:
buf.Write(encodeFloat64(value))
default:
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
buf.Write(encodeString(fmt.Sprintf("%v", value)))
}
}
}
return buf.Bytes()
}
func isNumeric(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < byte('0') || s[i] > byte('9') {
return false
}
}
return true
}
func typeConvertTimeDuration(i interface{}) time.Duration {
return time.Duration(typeConvertInt64(i))
}
func typeConvertBytes(i interface{}) []byte {
if i == nil {
return nil
}
if r, ok := i.([]byte); ok {
return r
}
return encode(i)
}
func typeConvertStrings(i interface{}) []string {
if i == nil {
return nil
}
if r, ok := i.([]string); ok {
return r
} else if r, ok := i.([]interface{}); ok {
strs := make([]string, len(r))
for k, v := range r {
strs[k] = typeConvertString(v)
}
return strs
}
return []string{fmt.Sprintf("%v", i)}
}
func typeConvertInt8(i interface{}) int8 {
if i == nil {
return 0
}
if v, ok := i.(int8); ok {
return v
}
return int8(typeConvertInt(i))
}
func typeConvertInt16(i interface{}) int16 {
if i == nil {
return 0
}
if v, ok := i.(int16); ok {
return v
}
return int16(typeConvertInt(i))
}
func typeConvertInt32(i interface{}) int32 {
if i == nil {
return 0
}
if v, ok := i.(int32); ok {
return v
}
return int32(typeConvertInt(i))
}
func typeConvertUint(i interface{}) uint {
if i == nil {
return 0
}
switch value := i.(type) {
case int:
return uint(value)
case int8:
return uint(value)
case int16:
return uint(value)
case int32:
return uint(value)
case int64:
return uint(value)
case uint:
return value
case uint8:
return uint(value)
case uint16:
return uint(value)
case uint32:
return uint(value)
case uint64:
return uint(value)
case float32:
return uint(value)
case float64:
return uint(value)
case bool:
if value {
return 1
}
return 0
default:
v, _ := strconv.ParseUint(typeConvertString(value), 10, 64)
return uint(v)
}
}
func typeConvertUint8(i interface{}) uint8 {
if i == nil {
return 0
}
if v, ok := i.(uint8); ok {
return v
}
return uint8(typeConvertUint(i))
}
func typeConvertUint16(i interface{}) uint16 {
if i == nil {
return 0
}
if v, ok := i.(uint16); ok {
return v
}
return uint16(typeConvertUint(i))
}
func typeConvertUint32(i interface{}) uint32 {
if i == nil {
return 0
}
if v, ok := i.(uint32); ok {
return v
}
return uint32(typeConvertUint(i))
}
func typeConvertUint64(i interface{}) uint64 {
if i == nil {
return 0
}
if v, ok := i.(uint64); ok {
return v
}
return uint64(typeConvertUint(i))
}
*/

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.8 KiB

@ -1,4 +0,0 @@
.db
*.test
*~
*.swp

@ -1,7 +0,0 @@
Copyright (C) 2011 Blake Mizerany
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -1,99 +0,0 @@
# pq - A pure Go postgres driver for Go's database/sql package
**This package is now deprecated. The up to date version is at
[github.com/lib/pq](https://github.com/lib/pq).**
## Install
go get github.com/bmizerany/pq
## Docs
<http://go.pkgdoc.org/github.com/bmizerany/pq>
## Use
package main
import (
_ "github.com/bmizerany/pq"
"database/sql"
)
func main() {
db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full")
// ...
}
**Connection String Parameters**
These are a subset of the libpq connection parameters. In addition, a
number of the [environment
variables](http://www.postgresql.org/docs/9.1/static/libpq-envars.html)
supported by libpq are also supported. Just like libpq, these have
lower precedence than explicitly provided connection parameters.
See http://www.postgresql.org/docs/9.1/static/libpq-connect.html.
* `dbname` - The name of the database to connect to
* `user` - The user to sign in as
* `password` - The user's password
* `host` - The host to connect to. Values that start with `/` are for unix domain sockets. (default is `localhost`)
* `port` - The port to bind to. (default is `5432`)
* `sslmode` - Whether or not to use SSL (default is `require`, this is not the default for libpq)
Valid values are:
* `disable` - No SSL
* `require` - Always SSL (skip verification)
* `verify-full` - Always SSL (require verification)
See http://golang.org/pkg/database/sql to learn how to use with `pq` through the `database/sql` package.
## Tests
`go test` is used for testing. A running PostgreSQL server is
required, with the ability to log in. The default database to connect
to test with is "pqgotest," but it can be overridden using environment
variables.
Example:
PGHOST=/var/run/postgresql go test pq
## Features
* SSL
* Handles bad connections for `database/sql`
* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`)
* Scan binary blobs correctly (i.e. `bytea`)
* pq.ParseURL for converting urls to connection strings for sql.Open.
* Many libpq compatible environment variables
* Unix socket support
## Future / Things you can help with
* Notifications: `LISTEN`/`NOTIFY`
* `hstore` sugar (i.e. handling hstore in `rows.Scan`)
## Thank you (alphabetical)
Some of these contributors are from the original library `bmizerany/pq.go` whose
code still exists in here.
* Andy Balholm (andybalholm)
* Ben Berkert (benburkert)
* Bjørn Madsen (aeons)
* Blake Gentry (bgentry)
* Brad Fitzpatrick (bradfitz)
* Daniel Farina (fdr)
* Everyone at The Go Team
* Federico Romero (federomero)
* Heroku (heroku)
* John Gallagher (jgallagher)
* Kamil Kisiel (kisielk)
* Keith Rarick (kr)
* Marc Brinkmann (mbr)
* Martin Olsen (martinolsen)
* Mike Lewis (mikelikespie)
* Ryan Smith (ryandotsmith)
* Samuel Stauffer (samuel)
* notedit (notedit)

@ -1,80 +0,0 @@
package pq
import (
"bytes"
"encoding/binary"
)
type readBuf []byte
func (b *readBuf) int32() (n int) {
n = int(int32(binary.BigEndian.Uint32(*b)))
*b = (*b)[4:]
return
}
func (b *readBuf) oid() (n oid) {
n = oid(binary.BigEndian.Uint32(*b))
*b = (*b)[4:]
return
}
func (b *readBuf) int16() (n int) {
n = int(binary.BigEndian.Uint16(*b))
*b = (*b)[2:]
return
}
var stringTerm = []byte{0}
func (b *readBuf) string() string {
i := bytes.Index(*b, stringTerm)
if i < 0 {
errorf("invalid message format; expected string terminator")
}
s := (*b)[:i]
*b = (*b)[i+1:]
return string(s)
}
func (b *readBuf) next(n int) (v []byte) {
v = (*b)[:n]
*b = (*b)[n:]
return
}
func (b *readBuf) byte() byte {
return b.next(1)[0]
}
type writeBuf []byte
func newWriteBuf(c byte) *writeBuf {
b := make(writeBuf, 5)
b[0] = c
return &b
}
func (b *writeBuf) int32(n int) {
x := make([]byte, 4)
binary.BigEndian.PutUint32(x, uint32(n))
*b = append(*b, x...)
}
func (b *writeBuf) int16(n int) {
x := make([]byte, 2)
binary.BigEndian.PutUint16(x, uint16(n))
*b = append(*b, x...)
}
func (b *writeBuf) string(s string) {
*b = append(*b, (s + "\000")...)
}
func (b *writeBuf) byte(c byte) {
*b = append(*b, c)
}
func (b *writeBuf) bytes(v []byte) {
*b = append(*b, v...)
}

@ -1,678 +0,0 @@
package pq
import (
"bufio"
"crypto/md5"
"crypto/tls"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"os/user"
"path"
"strconv"
"strings"
)
var (
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
ErrNotSupported = errors.New("pq: invalid command")
)
type drv struct{}
func (d *drv) Open(name string) (driver.Conn, error) {
return Open(name)
}
func init() {
sql.Register("postgres", &drv{})
}
type conn struct {
c net.Conn
buf *bufio.Reader
namei int
}
func Open(name string) (_ driver.Conn, err error) {
defer errRecover(&err)
defer errRecoverWithPGReason(&err)
o := make(Values)
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o.Set("host", "localhost")
o.Set("port", "5432")
// Default the username, but ignore errors, because a user
// passed in via environment variable or connection string
// would be okay. This can result in connections failing
// *sometimes* if the client relies on being able to determine
// the current username and there are intermittent problems.
u, err := user.Current()
if err == nil {
o.Set("user", u.Username)
}
for k, v := range parseEnviron(os.Environ()) {
o.Set(k, v)
}
parseOpts(name, o)
c, err := net.Dial(network(o))
if err != nil {
return nil, err
}
cn := &conn{c: c}
cn.ssl(o)
cn.buf = bufio.NewReader(cn.c)
cn.startup(o)
return cn, nil
}
func network(o Values) (string, string) {
host := o.Get("host")
if strings.HasPrefix(host, "/") {
sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
return "unix", sockPath
}
return "tcp", host + ":" + o.Get("port")
}
type Values map[string]string
func (vs Values) Set(k, v string) {
vs[k] = v
}
func (vs Values) Get(k string) (v string) {
v, _ = vs[k]
return
}
func parseOpts(name string, o Values) {
if len(name) == 0 {
return
}
ps := strings.Split(name, " ")
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
}
func (cn *conn) Begin() (driver.Tx, error) {
_, err := cn.Exec("BEGIN", nil)
if err != nil {
return nil, err
}
return cn, err
}
func (cn *conn) Commit() error {
_, err := cn.Exec("COMMIT", nil)
return err
}
func (cn *conn) Rollback() error {
_, err := cn.Exec("ROLLBACK", nil)
return err
}
func (cn *conn) gname() string {
cn.namei++
return strconv.FormatInt(int64(cn.namei), 10)
}
func (cn *conn) simpleQuery(q string) (res driver.Result, err error) {
defer errRecover(&err)
b := newWriteBuf('Q')
b.string(q)
cn.send(b)
for {
t, r := cn.recv1()
switch t {
case 'C':
res = parseComplete(r.string())
case 'Z':
// done
return
case 'E':
err = parseError(r)
case 'T', 'N', 'S':
// ignore
default:
errorf("unknown response for simple query: %q", t)
}
}
panic("not reached")
}
func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
defer errRecover(&err)
st := &stmt{cn: cn, name: stmtName, query: q}
b := newWriteBuf('P')
b.string(st.name)
b.string(q)
b.int16(0)
cn.send(b)
b = newWriteBuf('D')
b.byte('S')
b.string(st.name)
cn.send(b)
cn.send(newWriteBuf('S'))
for {
t, r := cn.recv1()
switch t {
case '1', '2', 'N':
case 't':
st.nparams = int(r.int16())
st.paramTyps = make([]oid, st.nparams, st.nparams)
for i := 0; i < st.nparams; i += 1 {
st.paramTyps[i] = r.oid()
}
case 'T':
n := r.int16()
st.cols = make([]string, n)
st.rowTyps = make([]oid, n)
for i := range st.cols {
st.cols[i] = r.string()
r.next(6)
st.rowTyps[i] = r.oid()
r.next(8)
}
case 'n':
// no data
case 'Z':
return st, err
case 'E':
err = parseError(r)
default:
errorf("unexpected describe rows response: %q", t)
}
}
panic("not reached")
}
func (cn *conn) Prepare(q string) (driver.Stmt, error) {
return cn.prepareTo(q, cn.gname())
}
func (cn *conn) Close() (err error) {
defer errRecover(&err)
cn.send(newWriteBuf('X'))
return cn.c.Close()
}
// Implement the optional "Execer" interface for one-shot queries
func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
defer errRecover(&err)
// Check to see if we can use the "simpleQuery" interface, which is
// *much* faster than going through prepare/exec
if len(args) == 0 {
return cn.simpleQuery(query)
}
// Use the unnamed statement to defer planning until bind
// time, or else value-based selectivity estimates cannot be
// used.
st, err := cn.prepareTo(query, "")
if err != nil {
panic(err)
}
r, err := st.Exec(args)
if err != nil {
panic(err)
}
return r, err
}
// Assumes len(*m) is > 5
func (cn *conn) send(m *writeBuf) {
b := (*m)[1:]
binary.BigEndian.PutUint32(b, uint32(len(b)))
if (*m)[0] == 0 {
*m = b
}
_, err := cn.c.Write(*m)
if err != nil {
panic(err)
}
}
func (cn *conn) recv() (t byte, r *readBuf) {
for {
t, r = cn.recv1()
switch t {
case 'E':
panic(parseError(r))
case 'N':
// ignore
default:
return
}
}
panic("not reached")
}
func (cn *conn) recv1() (byte, *readBuf) {
x := make([]byte, 5)
_, err := io.ReadFull(cn.buf, x)
if err != nil {
panic(err)
}
b := readBuf(x[1:])
y := make([]byte, b.int32()-4)
_, err = io.ReadFull(cn.buf, y)
if err != nil {
panic(err)
}
return x[0], (*readBuf)(&y)
}
func (cn *conn) ssl(o Values) {
tlsConf := tls.Config{}
switch mode := o.Get("sslmode"); mode {
case "require", "":
tlsConf.InsecureSkipVerify = true
case "verify-full":
// fall out
case "disable":
return
default:
errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode)
}
w := newWriteBuf(0)
w.int32(80877103)
cn.send(w)
b := make([]byte, 1)
_, err := io.ReadFull(cn.c, b)
if err != nil {
panic(err)
}
if b[0] != 'S' {
panic(ErrSSLNotSupported)
}
cn.c = tls.Client(cn.c, &tlsConf)
}
func (cn *conn) startup(o Values) {
w := newWriteBuf(0)
w.int32(196608)
w.string("user")
w.string(o.Get("user"))
w.string("database")
w.string(o.Get("dbname"))
w.string("")
cn.send(w)
for {
t, r := cn.recv()
switch t {
case 'K', 'S':
case 'R':
cn.auth(r, o)
case 'Z':
return
default:
errorf("unknown response for startup: %q", t)
}
}
}
func (cn *conn) auth(r *readBuf, o Values) {
switch code := r.int32(); code {
case 0:
// OK
case 3:
w := newWriteBuf('p')
w.string(o.Get("password"))
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
}
case 5:
s := string(r.next(4))
w := newWriteBuf('p')
w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 0 {
errorf("unexpected authentication resoonse: %q", t)
}
default:
errorf("unknown authentication response: %d", code)
}
}
type stmt struct {
cn *conn
name string
query string
cols []string
nparams int
rowTyps []oid
paramTyps []oid
closed bool
}
func (st *stmt) Close() (err error) {
if st.closed {
return nil
}
defer errRecover(&err)
w := newWriteBuf('C')
w.byte('S')
w.string(st.name)
st.cn.send(w)
st.cn.send(newWriteBuf('S'))
t, _ := st.cn.recv()
if t != '3' {
errorf("unexpected close response: %q", t)
}
st.closed = true
t, _ = st.cn.recv()
if t != 'Z' {
errorf("expected ready for query, but got: %q", t)
}
return nil
}
func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) {
defer errRecover(&err)
st.exec(v)
return &rows{st: st}, nil
}
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
defer errRecover(&err)
if len(v) == 0 {
return st.cn.simpleQuery(st.query)
}
st.exec(v)
for {
t, r := st.cn.recv1()
switch t {
case 'E':
err = parseError(r)
case 'C':
res = parseComplete(r.string())
case 'Z':
// done
return
case 'D':
errorf("unexpected data row returned in Exec; check your query")
case 'S', 'N':
// Ignore
default:
errorf("unknown exec response: %q", t)
}
}
panic("not reached")
}
func (st *stmt) exec(v []driver.Value) {
w := newWriteBuf('B')
w.string("")
w.string(st.name)
w.int16(0)
w.int16(len(v))
for i, x := range v {
if x == nil {
w.int32(-1)
} else {
b := encode(x, st.paramTyps[i])
w.int32(len(b))
w.bytes(b)
}
}
w.int16(0)
st.cn.send(w)
w = newWriteBuf('E')
w.string("")
w.int32(0)
st.cn.send(w)
st.cn.send(newWriteBuf('S'))
var err error
for {
t, r := st.cn.recv1()
switch t {
case 'E':
err = parseError(r)
case '2':
if err != nil {
panic(err)
}
return
case 'Z':
if err != nil {
panic(err)
}
return
case 'N':
// ignore
default:
errorf("unexpected bind response: %q", t)
}
}
}
func (st *stmt) NumInput() int {
return st.nparams
}
type result int64
func (i result) RowsAffected() (int64, error) {
return int64(i), nil
}
func (i result) LastInsertId() (int64, error) {
return 0, ErrNotSupported
}
func parseComplete(s string) driver.Result {
parts := strings.Split(s, " ")
n, _ := strconv.ParseInt(parts[len(parts)-1], 10, 64)
return result(n)
}
type rows struct {
st *stmt
done bool
}
func (rs *rows) Close() error {
for {
err := rs.Next(nil)
switch err {
case nil:
case io.EOF:
return nil
default:
return err
}
}
panic("not reached")
}
func (rs *rows) Columns() []string {
return rs.st.cols
}
func (rs *rows) Next(dest []driver.Value) (err error) {
if rs.done {
return io.EOF
}
defer errRecover(&err)
for {
t, r := rs.st.cn.recv1()
switch t {
case 'E':
err = parseError(r)
case 'C', 'S', 'N':
continue
case 'Z':
rs.done = true
if err != nil {
return err
}
return io.EOF
case 'D':
n := r.int16()
for i := 0; i < len(dest) && i < n; i++ {
l := r.int32()
if l == -1 {
dest[i] = nil
continue
}
dest[i] = decode(r.next(l), rs.st.rowTyps[i])
}
return
default:
errorf("unexpected message after execute: %q", t)
}
}
panic("not reached")
}
func md5s(s string) string {
h := md5.New()
h.Write([]byte(s))
return fmt.Sprintf("%x", h.Sum(nil))
}
// parseEnviron tries to mimic some of libpq's environment handling
//
// To ease testing, it does not directly reference os.Environ, but is
// designed to accept its output.
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
func parseEnviron(env []string) (out map[string]string) {
out = make(map[string]string)
for _, v := range env {
parts := strings.SplitN(v, "=", 2)
accrue := func(keyname string) {
out[keyname] = parts[1]
}
// The order of these is the same as is seen in the
// PostgreSQL 9.1 manual, with omissions briefly
// noted.
switch parts[0] {
case "PGHOST":
accrue("host")
case "PGHOSTADDR":
accrue("hostaddr")
case "PGPORT":
accrue("port")
case "PGDATABASE":
accrue("dbname")
case "PGUSER":
accrue("user")
case "PGPASSWORD":
accrue("password")
// skip PGPASSFILE, PGSERVICE, PGSERVICEFILE,
// PGREALM
case "PGOPTIONS":
accrue("options")
case "PGAPPNAME":
accrue("application_name")
case "PGSSLMODE":
accrue("sslmode")
case "PGREQUIRESSL":
accrue("requiressl")
case "PGSSLCERT":
accrue("sslcert")
case "PGSSLKEY":
accrue("sslkey")
case "PGSSLROOTCERT":
accrue("sslrootcert")
case "PGSSLCRL":
accrue("sslcrl")
case "PGREQUIREPEER":
accrue("requirepeer")
case "PGKRBSRVNAME":
accrue("krbsrvname")
case "PGGSSLIB":
accrue("gsslib")
case "PGCONNECT_TIMEOUT":
accrue("connect_timeout")
case "PGCLIENTENCODING":
accrue("client_encoding")
// skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR,
// PGLOCALEDIR
}
}
return out
}

@ -1,121 +0,0 @@
package pq
import (
"database/sql/driver"
"encoding/hex"
"fmt"
"strconv"
"time"
)
func encode(x interface{}, pgtypoid oid) []byte {
switch v := x.(type) {
case int64:
return []byte(fmt.Sprintf("%d", v))
case float32, float64:
return []byte(fmt.Sprintf("%f", v))
case []byte:
if pgtypoid == t_bytea {
return []byte(fmt.Sprintf("\\x%x", v))
}
return v
case string:
if pgtypoid == t_bytea {
return []byte(fmt.Sprintf("\\x%x", v))
}
return []byte(v)
case bool:
return []byte(fmt.Sprintf("%t", v))
case time.Time:
return []byte(v.Format(time.RFC3339Nano))
default:
errorf("encode: unknown type for %T", v)
}
panic("not reached")
}
func decode(s []byte, typ oid) interface{} {
switch typ {
case t_bytea:
s = s[2:] // trim off "\\x"
d := make([]byte, hex.DecodedLen(len(s)))
_, err := hex.Decode(d, s)
if err != nil {
errorf("%s", err)
}
return d
case t_timestamptz:
return mustParse("2006-01-02 15:04:05-07", typ, s)
case t_timestamp:
return mustParse("2006-01-02 15:04:05", typ, s)
case t_time:
return mustParse("15:04:05", typ, s)
case t_timetz:
return mustParse("15:04:05-07", typ, s)
case t_date:
return mustParse("2006-01-02", typ, s)
case t_bool:
return s[0] == 't'
case t_int8, t_int2, t_int4:
i, err := strconv.ParseInt(string(s), 10, 64)
if err != nil {
errorf("%s", err)
}
return i
case t_float4, t_float8:
bits := 64
if typ == t_float4 {
bits = 32
}
f, err := strconv.ParseFloat(string(s), bits)
if err != nil {
errorf("%s", err)
}
return f
}
return s
}
func mustParse(f string, typ oid, s []byte) time.Time {
str := string(s)
// Special case until time.Parse bug is fixed:
// http://code.google.com/p/go/issues/detail?id=3487
if str[len(str)-2] == '.' {
str += "0"
}
// check for a 30-minute-offset timezone
if (typ == t_timestamptz || typ == t_timetz) &&
str[len(str)-3] == ':' {
f += ":00"
}
t, err := time.Parse(f, str)
if err != nil {
errorf("decode: %s", err)
}
return t
}
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error {
nt.Time, nt.Valid = value.(time.Time)
return nil
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}

@ -1,108 +0,0 @@
package pq
import (
"database/sql/driver"
"fmt"
"io"
"net"
"runtime"
)
const (
Efatal = "FATAL"
Epanic = "PANIC"
Ewarning = "WARNING"
Enotice = "NOTICE"
Edebug = "DEBUG"
Einfo = "INFO"
Elog = "LOG"
)
type Error error
type PGError interface {
Error() string
Fatal() bool
Get(k byte) (v string)
}
type pgError struct {
c map[byte]string
}
func parseError(r *readBuf) *pgError {
err := &pgError{make(map[byte]string)}
for t := r.byte(); t != 0; t = r.byte() {
err.c[t] = r.string()
}
return err
}
func (err *pgError) Get(k byte) (v string) {
v, _ = err.c[k]
return
}
func (err *pgError) Fatal() bool {
return err.Get('S') == Efatal
}
func (err *pgError) Error() string {
var s string
for k, v := range err.c {
s += fmt.Sprintf(" %c:%q", k, v)
}
return "pq: " + s[1:]
}
func errorf(s string, args ...interface{}) {
panic(Error(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))))
}
type SimplePGError struct {
pgError
}
func (err *SimplePGError) Error() string {
return "pq: " + err.Get('M')
}
func errRecoverWithPGReason(err *error) {
e := recover()
switch v := e.(type) {
case nil:
// Do nothing
case *pgError:
// Return a SimplePGError in place
*err = &SimplePGError{*v}
default:
// Otherwise re-panic
panic(e)
}
}
func errRecover(err *error) {
e := recover()
switch v := e.(type) {
case nil:
// Do nothing
case runtime.Error:
panic(v)
case *pgError:
if v.Fatal() {
*err = driver.ErrBadConn
} else {
*err = v
}
case *net.OpError:
*err = driver.ErrBadConn
case error:
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
*err = driver.ErrBadConn
} else {
*err = v
}
default:
panic(fmt.Sprintf("unknown error: %#v", e))
}
}

@ -1,319 +0,0 @@
package pq
type oid uint32
const (
t_bool oid = 16
t_bytea = 17
t_char = 18
t_name = 19
t_int8 = 20
t_int2 = 21
t_int2vector = 22
t_int4 = 23
t_regproc = 24
t_text = 25
t_oid = 26
t_tid = 27
t_xid = 28
t_cid = 29
t_oidvector = 30
t_pg_type = 71
t_pg_attribute = 75
t_pg_proc = 81
t_pg_class = 83
t_xml = 142
t__xml = 143
t_pg_node_tree = 194
t_smgr = 210
t_point = 600
t_lseg = 601
t_path = 602
t_box = 603
t_polygon = 604
t_line = 628
t__line = 629
t_float4 = 700
t_float8 = 701
t_abstime = 702
t_reltime = 703
t_tinterval = 704
t_unknown = 705
t_circle = 718
t__circle = 719
t_money = 790
t__money = 791
t_macaddr = 829
t_inet = 869
t_cidr = 650
t__bool = 1000
t__bytea = 1001
t__char = 1002
t__name = 1003
t__int2 = 1005
t__int2vector = 1006
t__int4 = 1007
t__regproc = 1008
t__text = 1009
t__oid = 1028
t__tid = 1010
t__xid = 1011
t__cid = 1012
t__oidvector = 1013
t__bpchar = 1014
t__varchar = 1015
t__int8 = 1016
t__point = 1017
t__lseg = 1018
t__path = 1019
t__box = 1020
t__float4 = 1021
t__float8 = 1022
t__abstime = 1023
t__reltime = 1024
t__tinterval = 1025
t__polygon = 1027
t_aclitem = 1033
t__aclitem = 1034
t__macaddr = 1040
t__inet = 1041
t__cidr = 651
t__cstring = 1263
t_bpchar = 1042
t_varchar = 1043
t_date = 1082
t_time = 1083
t_timestamp = 1114
t__timestamp = 1115
t__date = 1182
t__time = 1183
t_timestamptz = 1184
t__timestamptz = 1185
t_interval = 1186
t__interval = 1187
t__numeric = 1231
t_timetz = 1266
t__timetz = 1270
t_bit = 1560
t__bit = 1561
t_varbit = 1562
t__varbit = 1563
t_numeric = 1700
t_refcursor = 1790
t__refcursor = 2201
t_regprocedure = 2202
t_regoper = 2203
t_regoperator = 2204
t_regclass = 2205
t_regtype = 2206
t__regprocedure = 2207
t__regoper = 2208
t__regoperator = 2209
t__regclass = 2210
t__regtype = 2211
t_uuid = 2950
t__uuid = 2951
t_tsvector = 3614
t_gtsvector = 3642
t_tsquery = 3615
t_regconfig = 3734
t_regdictionary = 3769
t__tsvector = 3643
t__gtsvector = 3644
t__tsquery = 3645
t__regconfig = 3735
t__regdictionary = 3770
t_txid_snapshot = 2970
t__txid_snapshot = 2949
t_record = 2249
t__record = 2287
t_cstring = 2275
t_any = 2276
t_anyarray = 2277
t_void = 2278
t_trigger = 2279
t_language_handler = 2280
t_internal = 2281
t_opaque = 2282
t_anyelement = 2283
t_anynonarray = 2776
t_anyenum = 3500
t_fdw_handler = 3115
t_pg_attrdef = 10000
t_pg_constraint = 10001
t_pg_inherits = 10002
t_pg_index = 10003
t_pg_operator = 10004
t_pg_opfamily = 10005
t_pg_opclass = 10006
t_pg_am = 10117
t_pg_amop = 10118
t_pg_amproc = 10478
t_pg_language = 10731
t_pg_largeobject_metadata = 10732
t_pg_largeobject = 10733
t_pg_aggregate = 10734
t_pg_statistic = 10735
t_pg_rewrite = 10736
t_pg_trigger = 10737
t_pg_description = 10738
t_pg_cast = 10739
t_pg_enum = 10936
t_pg_namespace = 10937
t_pg_conversion = 10938
t_pg_depend = 10939
t_pg_database = 1248
t_pg_db_role_setting = 10940
t_pg_tablespace = 10941
t_pg_pltemplate = 10942
t_pg_authid = 2842
t_pg_auth_members = 2843
t_pg_shdepend = 10943
t_pg_shdescription = 10944
t_pg_ts_config = 10945
t_pg_ts_config_map = 10946
t_pg_ts_dict = 10947
t_pg_ts_parser = 10948
t_pg_ts_template = 10949
t_pg_extension = 10950
t_pg_foreign_data_wrapper = 10951
t_pg_foreign_server = 10952
t_pg_user_mapping = 10953
t_pg_foreign_table = 10954
t_pg_default_acl = 10955
t_pg_seclabel = 10956
t_pg_collation = 10957
t_pg_toast_2604 = 10958
t_pg_toast_2606 = 10959
t_pg_toast_2609 = 10960
t_pg_toast_1255 = 10961
t_pg_toast_2618 = 10962
t_pg_toast_3596 = 10963
t_pg_toast_2619 = 10964
t_pg_toast_2620 = 10965
t_pg_toast_1262 = 10966
t_pg_toast_2396 = 10967
t_pg_toast_2964 = 10968
t_pg_roles = 10970
t_pg_shadow = 10973
t_pg_group = 10976
t_pg_user = 10979
t_pg_rules = 10982
t_pg_views = 10986
t_pg_tables = 10989
t_pg_indexes = 10993
t_pg_stats = 10997
t_pg_locks = 11001
t_pg_cursors = 11004
t_pg_available_extensions = 11007
t_pg_available_extension_versions = 11010
t_pg_prepared_xacts = 11013
t_pg_prepared_statements = 11017
t_pg_seclabels = 11020
t_pg_settings = 11024
t_pg_timezone_abbrevs = 11029
t_pg_timezone_names = 11032
t_pg_stat_all_tables = 11035
t_pg_stat_xact_all_tables = 11039
t_pg_stat_sys_tables = 11043
t_pg_stat_xact_sys_tables = 11047
t_pg_stat_user_tables = 11050
t_pg_stat_xact_user_tables = 11054
t_pg_statio_all_tables = 11057
t_pg_statio_sys_tables = 11061
t_pg_statio_user_tables = 11064
t_pg_stat_all_indexes = 11067
t_pg_stat_sys_indexes = 11071
t_pg_stat_user_indexes = 11074
t_pg_statio_all_indexes = 11077
t_pg_statio_sys_indexes = 11081
t_pg_statio_user_indexes = 11084
t_pg_statio_all_sequences = 11087
t_pg_statio_sys_sequences = 11090
t_pg_statio_user_sequences = 11093
t_pg_stat_activity = 11096
t_pg_stat_replication = 11099
t_pg_stat_database = 11102
t_pg_stat_database_conflicts = 11105
t_pg_stat_user_functions = 11108
t_pg_stat_xact_user_functions = 11112
t_pg_stat_bgwriter = 11116
t_pg_user_mappings = 11119
t_cardinal_number = 11669
t_character_data = 11671
t_sql_identifier = 11672
t_information_schema_catalog_name = 11674
t_time_stamp = 11676
t_yes_or_no = 11677
t_applicable_roles = 11680
t_administrable_role_authorizations = 11684
t_attributes = 11687
t_character_sets = 11691
t_check_constraint_routine_usage = 11695
t_check_constraints = 11699
t_collations = 11703
t_collation_character_set_applicability = 11706
t_column_domain_usage = 11709
t_column_privileges = 11713
t_column_udt_usage = 11717
t_columns = 11721
t_constraint_column_usage = 11725
t_constraint_table_usage = 11729
t_domain_constraints = 11733
t_domain_udt_usage = 11737
t_domains = 11740
t_enabled_roles = 11744
t_key_column_usage = 11747
t_parameters = 11751
t_referential_constraints = 11755
t_role_column_grants = 11759
t_routine_privileges = 11762
t_role_routine_grants = 11766
t_routines = 11769
t_schemata = 11773
t_sequences = 11776
t_sql_features = 11780
t_pg_toast_11779 = 11782
t_sql_implementation_info = 11785
t_pg_toast_11784 = 11787
t_sql_languages = 11790
t_pg_toast_11789 = 11792
t_sql_packages = 11795
t_pg_toast_11794 = 11797
t_sql_parts = 11800
t_pg_toast_11799 = 11802
t_sql_sizing = 11805
t_pg_toast_11804 = 11807
t_sql_sizing_profiles = 11810
t_pg_toast_11809 = 11812
t_table_constraints = 11815
t_table_privileges = 11819
t_role_table_grants = 11823
t_tables = 11826
t_triggered_update_columns = 11830
t_triggers = 11834
t_usage_privileges = 11838
t_role_usage_grants = 11842
t_view_column_usage = 11845
t_view_routine_usage = 11849
t_view_table_usage = 11853
t_views = 11857
t_data_type_privileges = 11861
t_element_types = 11865
t__pg_foreign_data_wrappers = 11869
t_foreign_data_wrapper_options = 11872
t_foreign_data_wrappers = 11875
t__pg_foreign_servers = 11878
t_foreign_server_options = 11882
t_foreign_servers = 11885
t__pg_foreign_tables = 11888
t_foreign_table_options = 11892
t_foreign_tables = 11895
t__pg_user_mappings = 11898
t_user_mapping_options = 11901
t_user_mappings = 11905
t_t = 16806
t__t = 16805
t_temp = 16810
t__temp = 16809
)

@ -1,68 +0,0 @@
package pq
import (
"fmt"
nurl "net/url"
"sort"
"strings"
)
// ParseURL converts url to a connection string for driver.Open.
// Example:
//
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
//
// converts to:
//
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
//
// A minimal example:
//
// "postgres://"
//
// This will be blank, causing driver.Open to use all of the defaults
func ParseURL(url string) (string, error) {
u, err := nurl.Parse(url)
if err != nil {
return "", err
}
if u.Scheme != "postgres" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+v)
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" {
accrue("dbname", u.Path[1:])
}
q := u.Query()
for k, _ := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
}

@ -1,9 +0,0 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

@ -1,22 +0,0 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -1,8 +0,0 @@
[![](https://godoc.org/github.com/jackc/chunkreader?status.svg)](https://godoc.org/github.com/jackc/chunkreader)
[![Build Status](https://travis-ci.org/jackc/chunkreader.svg)](https://travis-ci.org/jackc/chunkreader)
# chunkreader
Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
Extracted from original implementation in https://github.com/jackc/pgx.

@ -1,104 +0,0 @@
// Package chunkreader provides an io.Reader wrapper that minimizes IO reads and memory allocations.
package chunkreader
import (
"io"
)
// ChunkReader is a io.Reader wrapper that minimizes IO reads and memory allocations. It allocates memory in chunks and
// will read as much as will fit in the current buffer in a single call regardless of how large a read is actually
// requested. The memory returned via Next is owned by the caller. This avoids the need for an additional copy.
//
// The downside of this approach is that a large buffer can be pinned in memory even if only a small slice is
// referenced. For example, an entire 4096 byte block could be pinned in memory by even a 1 byte slice. In these rare
// cases it would be advantageous to copy the bytes to another slice.
type ChunkReader struct {
r io.Reader
buf []byte
rp, wp int // buf read position and write position
config Config
}
// Config contains configuration parameters for ChunkReader.
type Config struct {
MinBufLen int // Minimum buffer length
}
// New creates and returns a new ChunkReader for r with default configuration.
func New(r io.Reader) *ChunkReader {
cr, err := NewConfig(r, Config{})
if err != nil {
panic("default config can't be bad")
}
return cr
}
// NewConfig creates and a new ChunkReader for r configured by config.
func NewConfig(r io.Reader, config Config) (*ChunkReader, error) {
if config.MinBufLen == 0 {
// By historical reasons Postgres currently has 8KB send buffer inside,
// so here we want to have at least the same size buffer.
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
config.MinBufLen = 8192
}
return &ChunkReader{
r: r,
buf: make([]byte, config.MinBufLen),
config: config,
}, nil
}
// Next returns buf filled with the next n bytes. The caller gains ownership of buf. It is not necessary to make a copy
// of buf. If an error occurs, buf will be nil.
func (r *ChunkReader) Next(n int) (buf []byte, err error) {
// n bytes already in buf
if (r.wp - r.rp) >= n {
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, err
}
// available space in buf is less than n
if len(r.buf) < n {
r.copyBufContents(r.newBuf(n))
}
// buf is large enough, but need to shift filled area to start to make enough contiguous space
minReadCount := n - (r.wp - r.rp)
if (len(r.buf) - r.wp) < minReadCount {
newBuf := r.newBuf(n)
r.copyBufContents(newBuf)
}
if err := r.appendAtLeast(minReadCount); err != nil {
return nil, err
}
buf = r.buf[r.rp : r.rp+n]
r.rp += n
return buf, nil
}
func (r *ChunkReader) appendAtLeast(fillLen int) error {
n, err := io.ReadAtLeast(r.r, r.buf[r.wp:], fillLen)
r.wp += n
return err
}
func (r *ChunkReader) newBuf(size int) []byte {
if size < r.config.MinBufLen {
size = r.config.MinBufLen
}
return make([]byte, size)
}
func (r *ChunkReader) copyBufContents(dest []byte) {
r.wp = copy(dest, r.buf[r.rp:r.wp])
r.rp = 0
r.buf = dest
}

@ -1,3 +0,0 @@
.envrc
vendor/
.vscode

@ -1,161 +0,0 @@
# 1.14.0 (February 11, 2023)
* Fix: each connection attempt to new node gets own timeout (Nathan Giardina)
* Set SNI for SSL connections (Stas Kelvich)
* Fix: CopyFrom I/O race (Tommy Reilly)
* Minor dependency upgrades
# 1.13.0 (August 6, 2022)
* Add sslpassword support (Eric McCormack and yun.xu)
* Add prefer-standby target_session_attrs support (sergey.bashilov)
* Fix GSS ErrorResponse handling (Oliver Tan)
# 1.12.1 (May 7, 2022)
* Fix: setting krbspn and krbsrvname in connection string (sireax)
* Add support for Unix sockets on Windows (Eno Compton)
* Stop ignoring ErrorResponse during SCRAM auth (Rafi Shamim)
# 1.12.0 (April 21, 2022)
* Add pluggable GSSAPI support (Oliver Tan)
* Fix: Consider any "0A000" error a possible cached plan changed error due to locale
* Better match psql fallback behavior with multiple hosts
# 1.11.0 (February 7, 2022)
* Support port in ip from LookupFunc to override config (James Hartig)
* Fix TLS connection timeout (Blake Embrey)
* Add support for read-only, primary, standby, prefer-standby target_session_attributes (Oscar)
* Fix connect when receiving NoticeResponse
# 1.10.1 (November 20, 2021)
* Close without waiting for response (Kei Kamikawa)
* Save waiting for network round-trip in CopyFrom (Rueian)
* Fix concurrency issue with ContextWatcher
* LRU.Get always checks context for cancellation / expiration (Georges Varouchas)
# 1.10.0 (July 24, 2021)
* net.Timeout errors are no longer returned when a query is canceled via context. A wrapped context error is returned.
# 1.9.0 (July 10, 2021)
* pgconn.Timeout only is true for errors originating in pgconn (Michael Darr)
* Add defaults for sslcert, sslkey, and sslrootcert (Joshua Brindle)
* Solve issue with 'sslmode=verify-full' when there are multiple hosts (mgoddard)
* Fix default host when parsing URL without host but with port
* Allow dbname query parameter in URL conn string
* Update underlying dependencies
# 1.8.1 (March 25, 2021)
* Better connection string sanitization (ip.novikov)
* Use proper pgpass location on Windows (Moshe Katz)
* Use errors instead of golang.org/x/xerrors
* Resume fallback on server error in Connect (Andrey Borodin)
# 1.8.0 (December 3, 2020)
* Add StatementErrored method to stmtcache.Cache. This allows the cache to purge invalidated prepared statements. (Ethan Pailes)
# 1.7.2 (November 3, 2020)
* Fix data value slices into work buffer with capacities larger than length.
# 1.7.1 (October 31, 2020)
* Do not asyncClose after receiving FATAL error from PostgreSQL server
# 1.7.0 (September 26, 2020)
* Exec(Params|Prepared) return ResultReader with FieldDescriptions loaded
* Add ReceiveResults (Sebastiaan Mannem)
* Fix parsing DSN connection with bad backslash
* Add PgConn.CleanupDone so connection pools can determine when async close is complete
# 1.6.4 (July 29, 2020)
* Fix deadlock on error after CommandComplete but before ReadyForQuery
* Fix panic on parsing DSN with trailing '='
# 1.6.3 (July 22, 2020)
* Fix error message after AppendCertsFromPEM failure (vahid-sohrabloo)
# 1.6.2 (July 14, 2020)
* Update pgservicefile library
# 1.6.1 (June 27, 2020)
* Update golang.org/x/crypto to latest
* Update golang.org/x/text to 0.3.3
* Fix error handling for bad PGSERVICE definition
* Redact passwords in ParseConfig errors (Lukas Vogel)
# 1.6.0 (June 6, 2020)
* Fix panic when closing conn during cancellable query
* Fix behavior of sslmode=require with sslrootcert present (Petr Jediný)
* Fix field descriptions available after command concluded (Tobias Salzmann)
* Support connect_timeout (georgysavva)
* Handle IPv6 in connection URLs (Lukas Vogel)
* Fix ValidateConnect with cancelable context
* Improve CopyFrom performance
* Add Config.Copy (georgysavva)
# 1.5.0 (March 30, 2020)
* Update golang.org/x/crypto for security fix
* Implement "verify-ca" SSL mode (Greg Curtis)
# 1.4.0 (March 7, 2020)
* Fix ExecParams and ExecPrepared handling of empty query.
* Support reading config from PostgreSQL service files.
# 1.3.2 (February 14, 2020)
* Update chunkreader to v2.0.1 for optimized default buffer size.
# 1.3.1 (February 5, 2020)
* Fix CopyFrom deadlock when multiple NoticeResponse received during copy
# 1.3.0 (January 23, 2020)
* Add Hijack and Construct.
* Update pgproto3 to v2.0.1.
# 1.2.1 (January 13, 2020)
* Fix data race in context cancellation introduced in v1.2.0.
# 1.2.0 (January 11, 2020)
## Features
* Add Insert(), Update(), Delete(), and Select() statement type query methods to CommandTag.
* Add PgError.SQLState method. This could be used for compatibility with other drivers and databases.
## Performance
* Improve performance when context.Background() is used. (bakape)
* CommandTag.RowsAffected is faster and does not allocate.
## Fixes
* Try to cancel any in-progress query when a conn is closed by ctx cancel.
* Handle NoticeResponse during CopyFrom.
* Ignore errors sending Terminate message while closing connection. This mimics the behavior of libpq PGfinish.
# 1.1.0 (October 12, 2019)
* Add PgConn.IsBusy() method.
# 1.0.1 (September 19, 2019)
* Fix statement cache not properly cleaning discarded statements.

@ -1,22 +0,0 @@
Copyright (c) 2019-2021 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -1,62 +0,0 @@
[![](https://godoc.org/github.com/jackc/pgconn?status.svg)](https://godoc.org/github.com/jackc/pgconn)
![CI](https://github.com/jackc/pgconn/workflows/CI/badge.svg)
---
This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository.
---
# pgconn
Package pgconn is a low-level PostgreSQL database driver. It operates at nearly the same level as the C library libpq.
It is primarily intended to serve as the foundation for higher level libraries such as https://github.com/jackc/pgx.
Applications should handle normal queries with a higher level library and only use pgconn directly when required for
low-level access to PostgreSQL functionality.
## Example Usage
```go
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("DATABASE_URL"))
if err != nil {
log.Fatalln("pgconn failed to connect:", err)
}
defer pgConn.Close(context.Background())
result := pgConn.ExecParams(context.Background(), "SELECT email FROM users WHERE id=$1", [][]byte{[]byte("123")}, nil, nil, nil)
for result.NextRow() {
fmt.Println("User 123 has email:", string(result.Values()[0]))
}
_, err = result.Close()
if err != nil {
log.Fatalln("failed reading result:", err)
}
```
## Testing
The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
environment variable handling.
### Example Test Environment
Connect to your PostgreSQL server and run:
```
create database pgx_test;
```
Now you can run the tests:
```bash
PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
```
### Connection and Authentication Tests
Pgconn supports multiple connection types and means of authentication. These tests are optional. They
will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
authentication code.

@ -1,270 +0,0 @@
// SCRAM-SHA-256 authentication
//
// Resources:
// https://tools.ietf.org/html/rfc5802
// https://tools.ietf.org/html/rfc8265
// https://www.postgresql.org/docs/current/sasl-authentication.html
//
// Inspiration drawn from other implementations:
// https://github.com/lib/pq/pull/608
// https://github.com/lib/pq/pull/788
// https://github.com/lib/pq/pull/833
package pgconn
import (
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strconv"
"github.com/jackc/pgproto3/v2"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/text/secure/precis"
)
const clientNonceLen = 18
// Perform SCRAM authentication.
func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
sc, err := newScramClient(serverAuthMechanisms, c.config.Password)
if err != nil {
return err
}
// Send client-first-message in a SASLInitialResponse
saslInitialResponse := &pgproto3.SASLInitialResponse{
AuthMechanism: "SCRAM-SHA-256",
Data: sc.clientFirstMessage(),
}
_, err = c.conn.Write(saslInitialResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-first-message payload in a AuthenticationSASLContinue.
saslContinue, err := c.rxSASLContinue()
if err != nil {
return err
}
err = sc.recvServerFirstMessage(saslContinue.Data)
if err != nil {
return err
}
// Send client-final-message in a SASLResponse
saslResponse := &pgproto3.SASLResponse{
Data: []byte(sc.clientFinalMessage()),
}
_, err = c.conn.Write(saslResponse.Encode(nil))
if err != nil {
return err
}
// Receive server-final-message payload in a AuthenticationSASLFinal.
saslFinal, err := c.rxSASLFinal()
if err != nil {
return err
}
return sc.recvServerFinalMessage(saslFinal.Data)
}
func (c *PgConn) rxSASLContinue() (*pgproto3.AuthenticationSASLContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationSASLContinue message but received unexpected message %T", msg)
}
func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationSASLFinal:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationSASLFinal message but received unexpected message %T", msg)
}
type scramClient struct {
serverAuthMechanisms []string
password []byte
clientNonce []byte
clientFirstMessageBare []byte
serverFirstMessage []byte
clientAndServerNonce []byte
salt []byte
iterations int
saltedPassword []byte
authMessage []byte
}
func newScramClient(serverAuthMechanisms []string, password string) (*scramClient, error) {
sc := &scramClient{
serverAuthMechanisms: serverAuthMechanisms,
}
// Ensure server supports SCRAM-SHA-256
hasScramSHA256 := false
for _, mech := range sc.serverAuthMechanisms {
if mech == "SCRAM-SHA-256" {
hasScramSHA256 = true
break
}
}
if !hasScramSHA256 {
return nil, errors.New("server does not support SCRAM-SHA-256")
}
// precis.OpaqueString is equivalent to SASLprep for password.
var err error
sc.password, err = precis.OpaqueString.Bytes([]byte(password))
if err != nil {
// PostgreSQL allows passwords invalid according to SCRAM / SASLprep.
sc.password = []byte(password)
}
buf := make([]byte, clientNonceLen)
_, err = rand.Read(buf)
if err != nil {
return nil, err
}
sc.clientNonce = make([]byte, base64.RawStdEncoding.EncodedLen(len(buf)))
base64.RawStdEncoding.Encode(sc.clientNonce, buf)
return sc, nil
}
func (sc *scramClient) clientFirstMessage() []byte {
sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce))
return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare))
}
func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error {
sc.serverFirstMessage = serverFirstMessage
buf := serverFirstMessage
if !bytes.HasPrefix(buf, []byte("r=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include r=")
}
buf = buf[2:]
idx := bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
sc.clientAndServerNonce = buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("s=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include s=")
}
buf = buf[2:]
idx = bytes.IndexByte(buf, ',')
if idx == -1 {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
saltStr := buf[:idx]
buf = buf[idx+1:]
if !bytes.HasPrefix(buf, []byte("i=")) {
return errors.New("invalid SCRAM server-first-message received from server: did not include i=")
}
buf = buf[2:]
iterationsStr := buf
var err error
sc.salt, err = base64.StdEncoding.DecodeString(string(saltStr))
if err != nil {
return fmt.Errorf("invalid SCRAM salt received from server: %w", err)
}
sc.iterations, err = strconv.Atoi(string(iterationsStr))
if err != nil || sc.iterations <= 0 {
return fmt.Errorf("invalid SCRAM iteration count received from server: %w", err)
}
if !bytes.HasPrefix(sc.clientAndServerNonce, sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not start with client nonce")
}
if len(sc.clientAndServerNonce) <= len(sc.clientNonce) {
return errors.New("invalid SCRAM nonce: did not include server nonce")
}
return nil
}
func (sc *scramClient) clientFinalMessage() string {
clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce))
sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New)
sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(","))
clientProof := computeClientProof(sc.saltedPassword, sc.authMessage)
return fmt.Sprintf("%s,p=%s", clientFinalMessageWithoutProof, clientProof)
}
func (sc *scramClient) recvServerFinalMessage(serverFinalMessage []byte) error {
if !bytes.HasPrefix(serverFinalMessage, []byte("v=")) {
return errors.New("invalid SCRAM server-final-message received from server")
}
serverSignature := serverFinalMessage[2:]
if !hmac.Equal(serverSignature, computeServerSignature(sc.saltedPassword, sc.authMessage)) {
return errors.New("invalid SCRAM ServerSignature received from server")
}
return nil
}
func computeHMAC(key, msg []byte) []byte {
mac := hmac.New(sha256.New, key)
mac.Write(msg)
return mac.Sum(nil)
}
func computeClientProof(saltedPassword, authMessage []byte) []byte {
clientKey := computeHMAC(saltedPassword, []byte("Client Key"))
storedKey := sha256.Sum256(clientKey)
clientSignature := computeHMAC(storedKey[:], authMessage)
clientProof := make([]byte, len(clientSignature))
for i := 0; i < len(clientSignature); i++ {
clientProof[i] = clientKey[i] ^ clientSignature[i]
}
buf := make([]byte, base64.StdEncoding.EncodedLen(len(clientProof)))
base64.StdEncoding.Encode(buf, clientProof)
return buf
}
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
serverSignature := computeHMAC(serverKey, authMessage)
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
base64.StdEncoding.Encode(buf, serverSignature)
return buf
}

@ -1,905 +0,0 @@
package pgconn
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/jackc/chunkreader/v2"
"github.com/jackc/pgpassfile"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgservicefile"
)
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
type GetSSLPasswordFunc func(ctx context.Context) string
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A
// manually initialized Config will cause ConnectConfig to panic.
type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
Port uint16
Database string
User string
Password string
TLSConfig *tls.Config // nil disables TLS
ConnectTimeout time.Duration
DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string
KerberosSpn string
Fallbacks []*FallbackConfig
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
ValidateConnect ValidateConnectFunc
// AfterConnect is called after ValidateConnect. It can be used to set up the connection (e.g. Set session variables
// or prepare statements). If this returns an error the connection attempt fails.
AfterConnect AfterConnectFunc
// OnNotice is a callback function called when a notice response is received.
OnNotice NoticeHandler
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function
// PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc
}
// Copy returns a deep copy of the config that is safe to use and modify.
// The only exception is the TLSConfig field:
// according to the tls.Config docs it must not be modified after creation.
func (c *Config) Copy() *Config {
newConf := new(Config)
*newConf = *c
if newConf.TLSConfig != nil {
newConf.TLSConfig = c.TLSConfig.Clone()
}
if newConf.RuntimeParams != nil {
newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams))
for k, v := range c.RuntimeParams {
newConf.RuntimeParams[k] = v
}
}
if newConf.Fallbacks != nil {
newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks))
for i, fallback := range c.Fallbacks {
newFallback := new(FallbackConfig)
*newFallback = *fallback
if newFallback.TLSConfig != nil {
newFallback.TLSConfig = fallback.TLSConfig.Clone()
}
newConf.Fallbacks[i] = newFallback
}
}
return newConf
}
// FallbackConfig is additional settings to attempt a connection with when the primary Config fails to establish a
// network connection. It is used for TLS fallback such as sslmode=prefer and high availability (HA) connections.
type FallbackConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16
TLSConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
func isAbsolutePath(path string) bool {
isWindowsPath := func(p string) bool {
if len(p) < 3 {
return false
}
drive := p[0]
colon := p[1]
backslash := p[2]
if drive >= 'A' && drive <= 'Z' && colon == ':' && backslash == '\\' {
return true
}
return false
}
return strings.HasPrefix(path, "/") || isWindowsPath(path)
}
// NetworkAddress converts a PostgreSQL host and port into network and address suitable for use with
// net.Dial.
func NetworkAddress(host string, port uint16) (network, address string) {
if isAbsolutePath(host) {
network = "unix"
address = filepath.Join(host, ".s.PGSQL.") + strconv.FormatInt(int64(port), 10)
} else {
network = "tcp"
address = net.JoinHostPort(host, strconv.Itoa(int(port)))
}
return network, address
}
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style).
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
//
// # Example DSN
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
//
// # Example URL
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
//
// The returned *Config may be modified. However, it is strongly recommended that any configuration that can be done
// through the connection string be done there. In particular the fields Host, Port, TLSConfig, and Fallbacks can be
// interdependent (e.g. TLSConfig needs knowledge of the host to validate the server certificate). These fields should
// not be modified individually. They should all be modified or all left unchanged.
//
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
// values that will be tried in order. This can be used as part of a high availability system. See
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
//
// # Example URL
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
//
// ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed
// via database URL or DSN:
//
// PGHOST
// PGPORT
// PGDATABASE
// PGUSER
// PGPASSWORD
// PGPASSFILE
// PGSERVICE
// PGSERVICEFILE
// PGSSLMODE
// PGSSLCERT
// PGSSLKEY
// PGSSLROOTCERT
// PGSSLPASSWORD
// PGAPPNAME
// PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS
//
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
//
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
// usually but not always the environment variable name downcased and without the "PG" prefix.
//
// Important Security Notes:
//
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
// not set.
//
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
// security each sslmode provides.
//
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
// the Config struct. If TLSConfig is manually changed it will not affect the fallbacks. For example, in the case of
// sslmode "prefer" this means it will first try the main Config settings which use TLS, then it will try the fallback
// which does not use TLS. This can lead to an unexpected unencrypted connection if the main TLS config is manually
// changed later but the unencrypted fallback is present. Ensure there are no stale fallbacks when manually setting
// TLSConfig.
//
// Other known differences with libpq:
//
// When multiple hosts are specified, libpq allows them to have different passwords set via the .pgpass file. pgconn
// does not.
//
// In addition, ParseConfig accepts the following options:
//
// min_read_buffer_size
// The minimum size of the internal read buffer. Default 8192.
// servicefile
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string.
func ParseConfig(connString string) (*Config, error) {
var parseConfigOptions ParseConfigOptions
return ParseConfigWithOptions(connString, parseConfigOptions)
}
// ParseConfigWithOptions builds a *Config from connString and options with similar behavior to the PostgreSQL standard
// C library libpq. options contains settings that cannot be specified in a connString such as providing a function to
// get the SSL password.
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Config, error) {
defaultSettings := defaultSettings()
envSettings := parseEnvSettings()
connStringSettings := make(map[string]string)
if connString != "" {
var err error
// connString may be a database URL or a DSN
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err}
}
} else {
connStringSettings, err = parseDSNSettings(connString)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err}
}
}
}
settings := mergeSettings(defaultSettings, envSettings, connStringSettings)
if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err}
}
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
}
minReadBufferSize, err := strconv.ParseInt(settings["min_read_buffer_size"], 10, 32)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "cannot parse min_read_buffer_size", err: err}
}
config := &Config{
createdByParseConfig: true,
Database: settings["database"],
User: settings["user"],
Password: settings["password"],
RuntimeParams: make(map[string]string),
BuildFrontend: makeDefaultBuildFrontendFunc(int(minReadBufferSize)),
}
if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err}
}
config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
} else {
defaultDialer := makeDefaultDialer()
config.DialFunc = defaultDialer.DialContext
}
config.LookupFunc = makeDefaultResolver().LookupHost
notRuntimeParams := map[string]struct{}{
"host": {},
"port": {},
"database": {},
"user": {},
"password": {},
"passfile": {},
"connect_timeout": {},
"sslmode": {},
"sslkey": {},
"sslcert": {},
"sslrootcert": {},
"sslpassword": {},
"sslsni": {},
"krbspn": {},
"krbsrvname": {},
"target_session_attrs": {},
"min_read_buffer_size": {},
"service": {},
"servicefile": {},
}
// Adding kerberos configuration
if _, present := settings["krbsrvname"]; present {
config.KerberosSrvName = settings["krbsrvname"]
}
if _, present := settings["krbspn"]; present {
config.KerberosSpn = settings["krbspn"]
}
for k, v := range settings {
if _, present := notRuntimeParams[k]; present {
continue
}
config.RuntimeParams[k] = v
}
fallbacks := []*FallbackConfig{}
hosts := strings.Split(settings["host"], ",")
ports := strings.Split(settings["port"], ",")
for i, host := range hosts {
var portStr string
if i < len(ports) {
portStr = ports[i]
} else {
portStr = ports[0]
}
port, err := parsePort(portStr)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err}
}
var tlsConfigs []*tls.Config
// Ignore TLS settings if Unix domain socket like libpq
if network, _ := NetworkAddress(host, port); network == "unix" {
tlsConfigs = append(tlsConfigs, nil)
} else {
var err error
tlsConfigs, err = configTLS(settings, host, options)
if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err}
}
}
for _, tlsConfig := range tlsConfigs {
fallbacks = append(fallbacks, &FallbackConfig{
Host: host,
Port: port,
TLSConfig: tlsConfig,
})
}
}
config.Host = fallbacks[0].Host
config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil {
if config.Password == "" {
host := config.Host
if network, _ := NetworkAddress(config.Host, config.Port); network == "unix" {
host = "localhost"
}
config.Password = passfile.FindPassword(host, strconv.Itoa(int(config.Port)), config.Database, config.User)
}
}
switch tsa := settings["target_session_attrs"]; tsa {
case "read-write":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
case "read-only":
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
case "primary":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
case "standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
case "prefer-standby":
config.ValidateConnect = ValidateConnectTargetSessionAttrsPreferStandby
case "any":
// do nothing
default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
}
return config, nil
}
func mergeSettings(settingSets ...map[string]string) map[string]string {
settings := make(map[string]string)
for _, s2 := range settingSets {
for k, v := range s2 {
settings[k] = v
}
}
return settings
}
func parseEnvSettings() map[string]string {
settings := make(map[string]string)
nameMap := map[string]string{
"PGHOST": "host",
"PGPORT": "port",
"PGDATABASE": "database",
"PGUSER": "user",
"PGPASSWORD": "password",
"PGPASSFILE": "passfile",
"PGAPPNAME": "application_name",
"PGCONNECT_TIMEOUT": "connect_timeout",
"PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert",
"PGSSLSNI": "sslsni",
"PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword",
"PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service",
"PGSERVICEFILE": "servicefile",
}
for envname, realname := range nameMap {
value := os.Getenv(envname)
if value != "" {
settings[realname] = value
}
}
return settings
}
func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string)
url, err := url.Parse(connString)
if err != nil {
return nil, err
}
if url.User != nil {
settings["user"] = url.User.Username()
if password, present := url.User.Password(); present {
settings["password"] = password
}
}
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string
var ports []string
for _, host := range strings.Split(url.Host, ",") {
if host == "" {
continue
}
if isIPOnly(host) {
hosts = append(hosts, strings.Trim(host, "[]"))
continue
}
h, p, err := net.SplitHostPort(host)
if err != nil {
return nil, fmt.Errorf("failed to split host:port in '%s', err: %w", host, err)
}
if h != "" {
hosts = append(hosts, h)
}
if p != "" {
ports = append(ports, p)
}
}
if len(hosts) > 0 {
settings["host"] = strings.Join(hosts, ",")
}
if len(ports) > 0 {
settings["port"] = strings.Join(ports, ",")
}
database := strings.TrimLeft(url.Path, "/")
if database != "" {
settings["database"] = database
}
nameMap := map[string]string{
"dbname": "database",
}
for k, v := range url.Query() {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v[0]
}
return settings, nil
}
func isIPOnly(host string) bool {
return net.ParseIP(strings.Trim(host, "[]")) != nil || !strings.Contains(host, ":")
}
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseDSNSettings(s string) (map[string]string, error) {
settings := make(map[string]string)
nameMap := map[string]string{
"dbname": "database",
}
for len(s) > 0 {
var key, val string
eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 {
return nil, errors.New("invalid dsn")
}
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
s = strings.TrimLeft(s[eqIdx+1:], " \t\n\r\v\f")
if len(s) == 0 {
} else if s[0] != '\'' {
end := 0
for ; end < len(s); end++ {
if asciiSpace[s[end]] == 1 {
break
}
if s[end] == '\\' {
end++
if end == len(s) {
return nil, errors.New("invalid backslash")
}
}
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
} else { // quoted string
s = s[1:]
end := 0
for ; end < len(s); end++ {
if s[end] == '\'' {
break
}
if s[end] == '\\' {
end++
}
}
if end == len(s) {
return nil, errors.New("unterminated quoted string in connection info string")
}
val = strings.Replace(strings.Replace(s[:end], "\\\\", "\\", -1), "\\'", "'", -1)
if end == len(s) {
s = ""
} else {
s = s[end+1:]
}
}
if k, ok := nameMap[key]; ok {
key = k
}
if key == "" {
return nil, errors.New("invalid dsn")
}
settings[key] = val
}
return settings, nil
}
func parseServiceSettings(servicefilePath, serviceName string) (map[string]string, error) {
servicefile, err := pgservicefile.ReadServicefile(servicefilePath)
if err != nil {
return nil, fmt.Errorf("failed to read service file: %v", servicefilePath)
}
service, err := servicefile.GetService(serviceName)
if err != nil {
return nil, fmt.Errorf("unable to find service: %v", serviceName)
}
nameMap := map[string]string{
"dbname": "database",
}
settings := make(map[string]string, len(service.Settings))
for k, v := range service.Settings {
if k2, present := nameMap[k]; present {
k = k2
}
settings[k] = v
}
return settings, nil
}
// configTLS uses libpq's TLS parameters to construct []*tls.Config. It is
// necessary to allow returning multiple TLS configs as sslmode "allow" and
// "prefer" allow fallback.
func configTLS(settings map[string]string, thisHost string, parseConfigOptions ParseConfigOptions) ([]*tls.Config, error) {
host := thisHost
sslmode := settings["sslmode"]
sslrootcert := settings["sslrootcert"]
sslcert := settings["sslcert"]
sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"]
sslsni := settings["sslsni"]
// Match libpq default behavior
if sslmode == "" {
sslmode = "prefer"
}
if sslsni == "" {
sslsni = "1"
}
tlsConfig := &tls.Config{}
switch sslmode {
case "disable":
return []*tls.Config{nil}, nil
case "allow", "prefer":
tlsConfig.InsecureSkipVerify = true
case "require":
// According to PostgreSQL documentation, if a root CA file exists,
// the behavior of sslmode=require should be the same as that of verify-ca
//
// See https://www.postgresql.org/docs/12/libpq-ssl.html
if sslrootcert != "" {
goto nextCase
}
tlsConfig.InsecureSkipVerify = true
break
nextCase:
fallthrough
case "verify-ca":
// Don't perform the default certificate verification because it
// will verify the hostname. Instead, verify the server's
// certificate chain ourselves in VerifyPeerCertificate and
// ignore the server name. This emulates libpq's verify-ca
// behavior.
//
// See https://github.com/golang/go/issues/21971#issuecomment-332693931
// and https://pkg.go.dev/crypto/tls?tab=doc#example-Config-VerifyPeerCertificate
// for more info.
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyPeerCertificate = func(certificates [][]byte, _ [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(certificates))
for i, asn1Data := range certificates {
cert, err := x509.ParseCertificate(asn1Data)
if err != nil {
return errors.New("failed to parse certificate from server: " + err.Error())
}
certs[i] = cert
}
// Leave DNSName empty to skip hostname verification.
opts := x509.VerifyOptions{
Roots: tlsConfig.RootCAs,
Intermediates: x509.NewCertPool(),
}
// Skip the first cert because it's the leaf. All others
// are intermediates.
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)
return err
}
case "verify-full":
tlsConfig.ServerName = host
default:
return nil, errors.New("sslmode is invalid")
}
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
caPath := sslrootcert
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`)
}
if sslcert != "" && sslkey != "" {
buf, err := ioutil.ReadFile(sslkey)
if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err)
}
block, _ := pem.Decode(buf)
var pemKey []byte
var decryptedKey []byte
var decryptedError error
// If PEM is encrypted, attempt to decrypt using pass phrase
if x509.IsEncryptedPEMBlock(block) {
// Attempt decryption with pass phrase
// NOTE: only supports RSA (PKCS#1)
if sslpassword != "" {
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
}
//if sslpassword not provided or has decryption error when use it
//try to find sslpassword with callback function
if sslpassword == "" || decryptedError != nil {
if parseConfigOptions.GetSSLPassword != nil {
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
}
if sslpassword == "" {
return nil, fmt.Errorf("unable to find sslpassword")
}
}
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
// Should we also provide warning for PKCS#1 needed?
if decryptedError != nil {
return nil, fmt.Errorf("unable to decrypt key: %w", err)
}
pemBytes := pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: decryptedKey,
}
pemKey = pem.EncodeToMemory(&pemBytes)
} else {
pemKey = pem.EncodeToMemory(block)
}
certfile, err := ioutil.ReadFile(sslcert)
if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err)
}
cert, err := tls.X509KeyPair(certfile, pemKey)
if err != nil {
return nil, fmt.Errorf("unable to load cert: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
// Set Server Name Indication (SNI), if enabled by connection parameters.
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
// or IPv6).
if sslsni == "1" && net.ParseIP(host) == nil {
tlsConfig.ServerName = host
}
switch sslmode {
case "allow":
return []*tls.Config{nil, tlsConfig}, nil
case "prefer":
return []*tls.Config{tlsConfig, nil}, nil
case "require", "verify-ca", "verify-full":
return []*tls.Config{tlsConfig}, nil
default:
panic("BUG: bad sslmode should already have been caught")
}
}
func parsePort(s string) (uint16, error) {
port, err := strconv.ParseUint(s, 10, 16)
if err != nil {
return 0, err
}
if port < 1 || port > math.MaxUint16 {
return 0, errors.New("outside range")
}
return uint16(port), nil
}
func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute}
}
func makeDefaultResolver() *net.Resolver {
return net.DefaultResolver
}
func makeDefaultBuildFrontendFunc(minBufferLen int) BuildFrontendFunc {
return func(r io.Reader, w io.Writer) Frontend {
cr, err := chunkreader.NewConfig(r, chunkreader.Config{MinBufLen: minBufferLen})
if err != nil {
panic(fmt.Sprintf("BUG: chunkreader.NewConfig failed: %v", err))
}
frontend := pgproto3.NewFrontend(cr, w)
return frontend
}
}
func parseConnectTimeoutSetting(s string) (time.Duration, error) {
timeout, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, err
}
if timeout < 0 {
return 0, errors.New("negative timeout")
}
return time.Duration(timeout) * time.Second, nil
}
func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
d := makeDefaultDialer()
d.Timeout = timeout
return d.DialContext
}
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) == "on" {
return errors.New("read only connection")
}
return nil
}
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "on" {
return errors.New("connection is not read only")
}
return nil
}
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "t" {
return errors.New("server is not in hot standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) == "t" {
return errors.New("server is in standby mode")
}
return nil
}
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
if result.Err != nil {
return result.Err
}
if string(result.Rows[0][0]) != "t" {
return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
}
return nil
}

@ -1,65 +0,0 @@
//go:build !windows
// +build !windows
package pgconn
import (
"os"
"os/user"
"path/filepath"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
if err == nil {
settings["user"] = user.Username
settings["passfile"] = filepath.Join(user.HomeDir, ".pgpass")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
sslkey := filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(user.HomeDir, ".postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
candidatePaths := []string{
"/var/run/postgresql", // Debian
"/private/tmp", // OSX - homebrew
"/tmp", // standard PostgreSQL
}
for _, path := range candidatePaths {
if _, err := os.Stat(path); err == nil {
return path
}
}
return "localhost"
}

@ -1,59 +0,0 @@
package pgconn
import (
"os"
"os/user"
"path/filepath"
"strings"
)
func defaultSettings() map[string]string {
settings := make(map[string]string)
settings["host"] = defaultHost()
settings["port"] = "5432"
// Default to the OS user name. Purposely ignoring err getting user name from
// OS. The client application will simply have to specify the user in that
// case (which they typically will be doing anyway).
user, err := user.Current()
appData := os.Getenv("APPDATA")
if err == nil {
// Windows gives us the username here as `DOMAIN\user` or `LOCALPCNAME\user`,
// but the libpq default is just the `user` portion, so we strip off the first part.
username := user.Username
if strings.Contains(username, "\\") {
username = username[strings.LastIndex(username, "\\")+1:]
}
settings["user"] = username
settings["passfile"] = filepath.Join(appData, "postgresql", "pgpass.conf")
settings["servicefile"] = filepath.Join(user.HomeDir, ".pg_service.conf")
sslcert := filepath.Join(appData, "postgresql", "postgresql.crt")
sslkey := filepath.Join(appData, "postgresql", "postgresql.key")
if _, err := os.Stat(sslcert); err == nil {
if _, err := os.Stat(sslkey); err == nil {
// Both the cert and key must be present to use them, or do not use either
settings["sslcert"] = sslcert
settings["sslkey"] = sslkey
}
}
sslrootcert := filepath.Join(appData, "postgresql", "root.crt")
if _, err := os.Stat(sslrootcert); err == nil {
settings["sslrootcert"] = sslrootcert
}
}
settings["target_session_attrs"] = "any"
settings["min_read_buffer_size"] = "8192"
return settings
}
// defaultHost attempts to mimic libpq's default host. libpq uses the default unix socket location on *nix and localhost
// on Windows. The default socket location is compiled into libpq. Since pgx does not have access to that default it
// checks the existence of common locations.
func defaultHost() string {
return "localhost"
}

@ -1,29 +0,0 @@
// Package pgconn is a low-level PostgreSQL database driver.
/*
pgconn provides lower level access to a PostgreSQL connection than a database/sql or pgx connection. It operates at
nearly the same level is the C library libpq.
Establishing a Connection
Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
libpq style environment variables.
Executing a Query
ExecParams and ExecPrepared execute a single query. They return readers that iterate over each row. The Read method
reads all rows into memory.
Executing Multiple Queries in a Single Round Trip
Exec and ExecBatch can execute multiple queries in a single round trip. They return readers that iterate over each query
result. The ReadAll method reads all query results into memory.
Context Support
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/
package pgconn

@ -1,238 +0,0 @@
package pgconn
import (
"context"
"errors"
"fmt"
"net"
"net/url"
"regexp"
"strings"
)
// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
if e, ok := err.(interface{ SafeToRetry() bool }); ok {
return e.SafeToRetry()
}
return false
}
// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
var timeoutErr *errTimeout
return errors.As(err, &timeoutErr)
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe *PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// SQLState returns the SQLState of the error.
func (pe *PgError) SQLState() string {
return pe.Code
}
type connectError struct {
config *Config
msg string
err error
}
func (e *connectError) Error() string {
sb := &strings.Builder{}
fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
if e.err != nil {
fmt.Fprintf(sb, " (%s)", e.err.Error())
}
return sb.String()
}
func (e *connectError) Unwrap() error {
return e.err
}
type connLockError struct {
status string
}
func (e *connLockError) SafeToRetry() bool {
return true // a lock failure by definition happens before the connection is used.
}
func (e *connLockError) Error() string {
return e.status
}
type parseConfigError struct {
connString string
msg string
err error
}
func (e *parseConfigError) Error() string {
connString := redactPW(e.connString)
if e.err == nil {
return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
}
return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}
func (e *parseConfigError) Unwrap() error {
return e.err
}
// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
// true. Otherwise returns err.
func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
return &errTimeout{err: ctx.Err()}
}
return err
}
type pgconnError struct {
msg string
err error
safeToRetry bool
}
func (e *pgconnError) Error() string {
if e.msg == "" {
return e.err.Error()
}
if e.err == nil {
return e.msg
}
return fmt.Sprintf("%s: %s", e.msg, e.err.Error())
}
func (e *pgconnError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *pgconnError) Unwrap() error {
return e.err
}
// errTimeout occurs when an error was caused by a timeout. Specifically, it wraps an error which is
// context.Canceled, context.DeadlineExceeded, or an implementer of net.Error where Timeout() is true.
type errTimeout struct {
err error
}
func (e *errTimeout) Error() string {
return fmt.Sprintf("timeout: %s", e.err.Error())
}
func (e *errTimeout) SafeToRetry() bool {
return SafeToRetry(e.err)
}
func (e *errTimeout) Unwrap() error {
return e.err
}
type contextAlreadyDoneError struct {
err error
}
func (e *contextAlreadyDoneError) Error() string {
return fmt.Sprintf("context already done: %s", e.err.Error())
}
func (e *contextAlreadyDoneError) SafeToRetry() bool {
return true
}
func (e *contextAlreadyDoneError) Unwrap() error {
return e.err
}
// newContextAlreadyDoneError double-wraps a context error in `contextAlreadyDoneError` and `errTimeout`.
func newContextAlreadyDoneError(ctx context.Context) (err error) {
return &errTimeout{&contextAlreadyDoneError{err: ctx.Err()}}
}
type writeError struct {
err error
safeToRetry bool
}
func (e *writeError) Error() string {
return fmt.Sprintf("write failed: %s", e.err.Error())
}
func (e *writeError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *writeError) Unwrap() error {
return e.err
}
func redactPW(connString string) string {
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
if u, err := url.Parse(connString); err == nil {
return redactURL(u)
}
}
quotedDSN := regexp.MustCompile(`password='[^']*'`)
connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
plainDSN := regexp.MustCompile(`password=[^ ]*`)
connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
brokenURL := regexp.MustCompile(`:[^:@]+?@`)
connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
return connString
}
func redactURL(u *url.URL) string {
if u == nil {
return ""
}
if _, pwSet := u.User.Password(); pwSet {
u.User = url.UserPassword(u.User.Username(), "xxxxx")
}
return u.String()
}
type NotPreferredError struct {
err error
safeToRetry bool
}
func (e *NotPreferredError) Error() string {
return fmt.Sprintf("standby server not found: %s", e.err.Error())
}
func (e *NotPreferredError) SafeToRetry() bool {
return e.safeToRetry
}
func (e *NotPreferredError) Unwrap() error {
return e.err
}

@ -1,73 +0,0 @@
package ctxwatch
import (
"context"
"sync"
)
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time.
type ContextWatcher struct {
onCancel func()
onUnwatchAfterCancel func()
unwatchChan chan struct{}
lock sync.Mutex
watchInProgress bool
onCancelWasCalled bool
}
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher {
cw := &ContextWatcher{
onCancel: onCancel,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}),
}
return cw
}
// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called.
func (cw *ContextWatcher) Watch(ctx context.Context) {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
panic("Watch already in progress")
}
cw.onCancelWasCalled = false
if ctx.Done() != nil {
cw.watchInProgress = true
go func() {
select {
case <-ctx.Done():
cw.onCancel()
cw.onCancelWasCalled = true
<-cw.unwatchChan
case <-cw.unwatchChan:
}
}()
} else {
cw.watchInProgress = false
}
}
// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was
// called then onUnwatchAfterCancel will also be called.
func (cw *ContextWatcher) Unwatch() {
cw.lock.Lock()
defer cw.lock.Unlock()
if cw.watchInProgress {
cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled {
cw.onUnwatchAfterCancel()
}
cw.watchInProgress = false
}
}

@ -1,99 +0,0 @@
package pgconn
import (
"errors"
"fmt"
"github.com/jackc/pgproto3/v2"
)
// NewGSSFunc creates a GSS authentication provider, for use with
// RegisterGSSProvider.
type NewGSSFunc func() (GSS, error)
var newGSS NewGSSFunc
// RegisterGSSProvider registers a GSS authentication provider. For example, if
// you need to use Kerberos to authenticate with your server, add this to your
// main package:
//
// import "github.com/otan/gopgkrb5"
//
// func init() {
// pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() })
// }
func RegisterGSSProvider(newGSSArg NewGSSFunc) {
newGSS = newGSSArg
}
// GSS provides GSSAPI authentication (e.g., Kerberos).
type GSS interface {
GetInitToken(host string, service string) ([]byte, error)
GetInitTokenFromSPN(spn string) ([]byte, error)
Continue(inToken []byte) (done bool, outToken []byte, err error)
}
func (c *PgConn) gssAuth() error {
if newGSS == nil {
return errors.New("kerberos error: no GSSAPI provider registered, see https://github.com/otan/gopgkrb5")
}
cli, err := newGSS()
if err != nil {
return err
}
var nextData []byte
if c.config.KerberosSpn != "" {
// Use the supplied SPN if provided.
nextData, err = cli.GetInitTokenFromSPN(c.config.KerberosSpn)
} else {
// Allow the kerberos service name to be overridden
service := "postgres"
if c.config.KerberosSrvName != "" {
service = c.config.KerberosSrvName
}
nextData, err = cli.GetInitToken(c.config.Host, service)
}
if err != nil {
return err
}
for {
gssResponse := &pgproto3.GSSResponse{
Data: nextData,
}
_, err = c.conn.Write(gssResponse.Encode(nil))
if err != nil {
return err
}
resp, err := c.rxGSSContinue()
if err != nil {
return err
}
var done bool
done, nextData, err = cli.Continue(resp.Data)
if err != nil {
return err
}
if done {
break
}
}
return nil
}
func (c *PgConn) rxGSSContinue() (*pgproto3.AuthenticationGSSContinue, error) {
msg, err := c.receiveMessage()
if err != nil {
return nil, err
}
switch m := msg.(type) {
case *pgproto3.AuthenticationGSSContinue:
return m, nil
case *pgproto3.ErrorResponse:
return nil, ErrorResponseToPgError(m)
}
return nil, fmt.Errorf("expected AuthenticationGSSContinue message but received unexpected message %T", msg)
}

File diff suppressed because it is too large Load Diff

@ -1,169 +0,0 @@
package stmtcache
import (
"container/list"
"context"
"fmt"
"sync/atomic"
"github.com/jackc/pgconn"
)
var lruCount uint64
// LRU implements Cache with a Least Recently Used (LRU) cache.
type LRU struct {
conn *pgconn.PgConn
mode int
cap int
prepareCount int
m map[string]*list.Element
l *list.List
psNamePrefix string
stmtsToClear []string
}
// NewLRU creates a new LRU. mode is either ModePrepare or ModeDescribe. cap is the maximum size of the cache.
func NewLRU(conn *pgconn.PgConn, mode int, cap int) *LRU {
mustBeValidMode(mode)
mustBeValidCap(cap)
n := atomic.AddUint64(&lruCount, 1)
return &LRU{
conn: conn,
mode: mode,
cap: cap,
m: make(map[string]*list.Element),
l: list.New(),
psNamePrefix: fmt.Sprintf("lrupsc_%d", n),
}
}
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
func (c *LRU) Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
if ctx != context.Background() {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// flush an outstanding bad statements
txStatus := c.conn.TxStatus()
if (txStatus == 'I' || txStatus == 'T') && len(c.stmtsToClear) > 0 {
for _, stmt := range c.stmtsToClear {
err := c.clearStmt(ctx, stmt)
if err != nil {
return nil, err
}
}
}
if el, ok := c.m[sql]; ok {
c.l.MoveToFront(el)
return el.Value.(*pgconn.StatementDescription), nil
}
if c.l.Len() == c.cap {
err := c.removeOldest(ctx)
if err != nil {
return nil, err
}
}
psd, err := c.prepare(ctx, sql)
if err != nil {
return nil, err
}
el := c.l.PushFront(psd)
c.m[sql] = el
return psd, nil
}
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
func (c *LRU) Clear(ctx context.Context) error {
for c.l.Len() > 0 {
err := c.removeOldest(ctx)
if err != nil {
return err
}
}
return nil
}
func (c *LRU) StatementErrored(sql string, err error) {
pgErr, ok := err.(*pgconn.PgError)
if !ok {
return
}
// https://github.com/jackc/pgx/issues/1162
//
// We used to look for the message "cached plan must not change result type". However, that message can be localized.
// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
// have so it should be safe.
possibleInvalidCachedPlanError := pgErr.Code == "0A000"
if possibleInvalidCachedPlanError {
c.stmtsToClear = append(c.stmtsToClear, sql)
}
}
func (c *LRU) clearStmt(ctx context.Context, sql string) error {
elem, inMap := c.m[sql]
if !inMap {
// The statement probably fell off the back of the list. In that case, we've
// ensured that it isn't in the cache, so we can declare victory.
return nil
}
c.l.Remove(elem)
psd := elem.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}
// Len returns the number of cached prepared statement descriptions.
func (c *LRU) Len() int {
return c.l.Len()
}
// Cap returns the maximum number of cached prepared statement descriptions.
func (c *LRU) Cap() int {
return c.cap
}
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
func (c *LRU) Mode() int {
return c.mode
}
func (c *LRU) prepare(ctx context.Context, sql string) (*pgconn.StatementDescription, error) {
var name string
if c.mode == ModePrepare {
name = fmt.Sprintf("%s_%d", c.psNamePrefix, c.prepareCount)
c.prepareCount += 1
}
return c.conn.Prepare(ctx, name, sql, nil)
}
func (c *LRU) removeOldest(ctx context.Context) error {
oldest := c.l.Back()
c.l.Remove(oldest)
psd := oldest.Value.(*pgconn.StatementDescription)
delete(c.m, psd.SQL)
if c.mode == ModePrepare {
return c.conn.Exec(ctx, fmt.Sprintf("deallocate %s", psd.Name)).Close()
}
return nil
}

@ -1,58 +0,0 @@
// Package stmtcache is a cache that can be used to implement lazy prepared statements.
package stmtcache
import (
"context"
"github.com/jackc/pgconn"
)
const (
ModePrepare = iota // Cache should prepare named statements.
ModeDescribe // Cache should prepare the anonymous prepared statement to only fetch the description of the statement.
)
// Cache prepares and caches prepared statement descriptions.
type Cache interface {
// Get returns the prepared statement description for sql preparing or describing the sql on the server as needed.
Get(ctx context.Context, sql string) (*pgconn.StatementDescription, error)
// Clear removes all entries in the cache. Any prepared statements will be deallocated from the PostgreSQL session.
Clear(ctx context.Context) error
// StatementErrored informs the cache that the given statement resulted in an error when it
// was last used against the database. In some cases, this will cause the cache to maer that
// statement as bad. The bad statement will instead be flushed during the next call to Get
// that occurs outside of a failed transaction.
StatementErrored(sql string, err error)
// Len returns the number of cached prepared statement descriptions.
Len() int
// Cap returns the maximum number of cached prepared statement descriptions.
Cap() int
// Mode returns the mode of the cache (ModePrepare or ModeDescribe)
Mode() int
}
// New returns the preferred cache implementation for mode and cap. mode is either ModePrepare or ModeDescribe. cap is
// the maximum size of the cache.
func New(conn *pgconn.PgConn, mode int, cap int) Cache {
mustBeValidMode(mode)
mustBeValidCap(cap)
return NewLRU(conn, mode, cap)
}
func mustBeValidMode(mode int) {
if mode != ModePrepare && mode != ModeDescribe {
panic("mode must be ModePrepare or ModeDescribe")
}
}
func mustBeValidCap(cap int) {
if cap < 1 {
panic("cache must have cap of >= 1")
}
}

@ -1,9 +0,0 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

@ -1,22 +0,0 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -1,11 +0,0 @@
[![](https://godoc.org/github.com/jackc/pgio?status.svg)](https://godoc.org/github.com/jackc/pgio)
[![Build Status](https://travis-ci.org/jackc/pgio.svg)](https://travis-ci.org/jackc/pgio)
# pgio
Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
Extracted from original implementation in https://github.com/jackc/pgx.

@ -1,6 +0,0 @@
// Package pgio is a low-level toolkit building messages in the PostgreSQL wire protocol.
/*
pgio provides functions for appending integers to a []byte while doing byte
order conversion.
*/
package pgio

@ -1,40 +0,0 @@
package pgio
import "encoding/binary"
func AppendUint16(buf []byte, n uint16) []byte {
wp := len(buf)
buf = append(buf, 0, 0)
binary.BigEndian.PutUint16(buf[wp:], n)
return buf
}
func AppendUint32(buf []byte, n uint32) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0)
binary.BigEndian.PutUint32(buf[wp:], n)
return buf
}
func AppendUint64(buf []byte, n uint64) []byte {
wp := len(buf)
buf = append(buf, 0, 0, 0, 0, 0, 0, 0, 0)
binary.BigEndian.PutUint64(buf[wp:], n)
return buf
}
func AppendInt16(buf []byte, n int16) []byte {
return AppendUint16(buf, uint16(n))
}
func AppendInt32(buf []byte, n int32) []byte {
return AppendUint32(buf, uint32(n))
}
func AppendInt64(buf []byte, n int64) []byte {
return AppendUint64(buf, uint64(n))
}
func SetInt32(buf []byte, n int32) {
binary.BigEndian.PutUint32(buf, uint32(n))
}

@ -1,9 +0,0 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip

@ -1,22 +0,0 @@
Copyright (c) 2019 Jack Christensen
MIT License
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -1,18 +0,0 @@
[![](https://godoc.org/github.com/jackc/pgproto3?status.svg)](https://godoc.org/github.com/jackc/pgproto3)
[![Build Status](https://travis-ci.org/jackc/pgproto3.svg)](https://travis-ci.org/jackc/pgproto3)
---
This version is used with pgx `v4`. In pgx `v5` it is part of the https://github.com/jackc/pgx repository.
---
# pgproto3
Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.
See example/pgfortune for a playful example of a fake PostgreSQL server.
Extracted from original implementation in https://github.com/jackc/pgx.

@ -1,52 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
type AuthenticationCleartextPassword struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationCleartextPassword) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeCleartextPassword {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationCleartextPassword) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationCleartextPassword",
})
}

@ -1,58 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type AuthenticationGSS struct{}
func (a *AuthenticationGSS) Backend() {}
func (a *AuthenticationGSS) AuthenticationResponse() {}
func (a *AuthenticationGSS) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSS {
return errors.New("bad auth type")
}
return nil
}
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
}
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSS",
})
}
func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
return nil
}

@ -1,67 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type AuthenticationGSSContinue struct {
Data []byte
}
func (a *AuthenticationGSSContinue) Backend() {}
func (a *AuthenticationGSSContinue) AuthenticationResponse() {}
func (a *AuthenticationGSSContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeGSSCont {
return errors.New("bad auth type")
}
a.Data = src[4:]
return nil
}
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
}
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSSContinue",
Data: a.Data,
})
}
func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
a.Data = msg.Data
return nil
}

@ -1,77 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationMD5Password is a message sent from the backend indicating that an MD5 hashed password is required.
type AuthenticationMD5Password struct {
Salt [4]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationMD5Password) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationMD5Password) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationMD5Password) Decode(src []byte) error {
if len(src) != 8 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeMD5Password {
return errors.New("bad auth type")
}
copy(dst.Salt[:], src[4:8])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 12)
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
dst = append(dst, src.Salt[:]...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationMD5Password) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Salt [4]byte
}{
Type: "AuthenticationMD5Password",
Salt: src.Salt,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationMD5Password) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Salt [4]byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Salt = msg.Salt
return nil
}

@ -1,52 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
type AuthenticationOk struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationOk) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationOk) Decode(src []byte) error {
if len(src) != 4 {
return errors.New("bad authentication message size")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeOk {
return errors.New("bad auth type")
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationOk) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendUint32(dst, AuthTypeOk)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationOk) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "AuthenticationOK",
})
}

@ -1,75 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASL is a message sent from the backend indicating that SASL authentication is required.
type AuthenticationSASL struct {
AuthMechanisms []string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASL) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASL) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASL) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASL {
return errors.New("bad auth type")
}
authMechanisms := src[4:]
for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 {
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASL)
for _, s := range src.AuthMechanisms {
dst = append(dst, []byte(s)...)
dst = append(dst, 0)
}
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASL) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
AuthMechanisms []string
}{
Type: "AuthenticationSASL",
AuthMechanisms: src.AuthMechanisms,
})
}

@ -1,81 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLContinue is a message sent from the backend containing a SASL challenge.
type AuthenticationSASLContinue struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLContinue) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLContinue) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLContinue {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src AuthenticationSASLContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLContinue",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

@ -1,81 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
// AuthenticationSASLFinal is a message sent from the backend indicating a SASL authentication has completed.
type AuthenticationSASLFinal struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationSASLFinal) Backend() {}
// Backend identifies this message as an authentication response.
func (*AuthenticationSASLFinal) AuthenticationResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}
authType := binary.BigEndian.Uint32(src)
if authType != AuthTypeSASLFinal {
return errors.New("bad auth type")
}
dst.Data = src[4:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
dst = append(dst, 'R')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
dst = append(dst, src.Data...)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Unmarshaler.
func (src AuthenticationSASLFinal) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "AuthenticationSASLFinal",
Data: string(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *AuthenticationSASLFinal) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

@ -1,213 +0,0 @@
package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// Backend acts as a server for the PostgreSQL wire protocol version 3.
type Backend struct {
cr ChunkReader
w io.Writer
// Frontend message flyweights
bind Bind
cancelRequest CancelRequest
_close Close
copyFail CopyFail
copyData CopyData
copyDone CopyDone
describe Describe
execute Execute
flush Flush
functionCall FunctionCall
gssEncRequest GSSEncRequest
parse Parse
query Query
sslRequest SSLRequest
startupMessage StartupMessage
sync Sync
terminate Terminate
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
const (
minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code.
maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source.
)
// NewBackend creates a new Backend.
func NewBackend(cr ChunkReader, w io.Writer) *Backend {
return &Backend{cr: cr, w: w}
}
// Send sends a message to the frontend.
func (b *Backend) Send(msg BackendMessage) error {
_, err := b.w.Write(msg.Encode(nil))
return err
}
// ReceiveStartupMessage receives the initial connection message. This method is used of the normal Receive method
// because the initial connection message is "special" and does not include the message type as the first byte. This
// will return either a StartupMessage, SSLRequest, GSSEncRequest, or CancelRequest.
func (b *Backend) ReceiveStartupMessage() (FrontendMessage, error) {
buf, err := b.cr.Next(4)
if err != nil {
return nil, err
}
msgSize := int(binary.BigEndian.Uint32(buf) - 4)
if msgSize < minStartupPacketLen || msgSize > maxStartupPacketLen {
return nil, fmt.Errorf("invalid length of startup packet: %d", msgSize)
}
buf, err = b.cr.Next(msgSize)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
code := binary.BigEndian.Uint32(buf)
switch code {
case ProtocolVersionNumber:
err = b.startupMessage.Decode(buf)
if err != nil {
return nil, err
}
return &b.startupMessage, nil
case sslRequestNumber:
err = b.sslRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.sslRequest, nil
case cancelRequestCode:
err = b.cancelRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.cancelRequest, nil
case gssEncReqNumber:
err = b.gssEncRequest.Decode(buf)
if err != nil {
return nil, err
}
return &b.gssEncRequest, nil
default:
return nil, fmt.Errorf("unknown startup message code: %d", code)
}
}
// Receive receives a message from the frontend. The returned message is only valid until the next call to Receive.
func (b *Backend) Receive() (FrontendMessage, error) {
if !b.partialMsg {
header, err := b.cr.Next(5)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.msgType = header[0]
b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
b.partialMsg = true
if b.bodyLen < 0 {
return nil, errors.New("invalid message with negative body length received")
}
}
var msg FrontendMessage
switch b.msgType {
case 'B':
msg = &b.bind
case 'C':
msg = &b._close
case 'D':
msg = &b.describe
case 'E':
msg = &b.execute
case 'F':
msg = &b.functionCall
case 'f':
msg = &b.copyFail
case 'd':
msg = &b.copyData
case 'c':
msg = &b.copyDone
case 'H':
msg = &b.flush
case 'P':
msg = &b.parse
case 'p':
switch b.authType {
case AuthTypeSASL:
msg = &SASLInitialResponse{}
case AuthTypeSASLContinue:
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeGSS, AuthTypeGSSCont:
msg = &GSSResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
// to maintain backwards compatability
msg = &PasswordMessage{}
}
case 'Q':
msg = &b.query
case 'S':
msg = &b.sync
case 'X':
msg = &b.terminate
default:
return nil, fmt.Errorf("unknown message type: %c", b.msgType)
}
msgBody, err := b.cr.Next(b.bodyLen)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
b.partialMsg = false
err = msg.Decode(msgBody)
return msg, err
}
// SetAuthType sets the authentication type in the backend.
// Since multiple message types can start with 'p', SetAuthType allows
// contextual identification of FrontendMessages. For example, in the
// PG message flow documentation for PasswordMessage:
//
// Byte1('p')
//
// Identifies the message as a password response. Note that this is also used for
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
// the context.
//
// Since the Frontend does not know about the state of a backend, it is important
// to call SetAuthType() after an authentication request is received by the Frontend.
func (b *Backend) SetAuthType(authType uint32) error {
switch authType {
case AuthTypeOk,
AuthTypeCleartextPassword,
AuthTypeMD5Password,
AuthTypeSCMCreds,
AuthTypeGSS,
AuthTypeGSSCont,
AuthTypeSSPI,
AuthTypeSASL,
AuthTypeSASLContinue,
AuthTypeSASLFinal:
b.authType = authType
default:
return fmt.Errorf("authType not recognized: %d", authType)
}
return nil
}

@ -1,51 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type BackendKeyData struct {
ProcessID uint32
SecretKey uint32
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*BackendKeyData) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *BackendKeyData) Decode(src []byte) error {
if len(src) != 8 {
return &invalidMessageLenErr{messageType: "BackendKeyData", expectedLen: 8, actualLen: len(src)}
}
dst.ProcessID = binary.BigEndian.Uint32(src[:4])
dst.SecretKey = binary.BigEndian.Uint32(src[4:])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BackendKeyData) Encode(dst []byte) []byte {
dst = append(dst, 'K')
dst = pgio.AppendUint32(dst, 12)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src BackendKeyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "BackendKeyData",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

@ -1,37 +0,0 @@
package pgproto3
import (
"encoding/binary"
)
type BigEndianBuf [8]byte
func (b BigEndianBuf) Int16(n int16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, uint16(n))
return buf
}
func (b BigEndianBuf) Uint16(n uint16) []byte {
buf := b[0:2]
binary.BigEndian.PutUint16(buf, n)
return buf
}
func (b BigEndianBuf) Int32(n int32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, uint32(n))
return buf
}
func (b BigEndianBuf) Uint32(n uint32) []byte {
buf := b[0:4]
binary.BigEndian.PutUint32(buf, n)
return buf
}
func (b BigEndianBuf) Int64(n int64) []byte {
buf := b[0:8]
binary.BigEndian.PutUint64(buf, uint64(n))
return buf
}

@ -1,216 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/jackc/pgio"
)
type Bind struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters [][]byte
ResultFormatCodes []int16
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Bind) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Bind) Decode(src []byte) error {
*dst = Bind{}
idx := bytes.IndexByte(src, 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.DestinationPortal = string(src[:idx])
rp := idx + 1
idx = bytes.IndexByte(src[rp:], 0)
if idx < 0 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.PreparedStatement = string(src[rp : rp+idx])
rp += idx + 1
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
parameterFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if parameterFormatCodeCount > 0 {
dst.ParameterFormatCodes = make([]int16, parameterFormatCodeCount)
if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
for i := 0; i < parameterFormatCodeCount; i++ {
dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
}
}
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
parameterCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
if parameterCount > 0 {
dst.Parameters = make([][]byte, parameterCount)
for i := 0; i < parameterCount; i++ {
if len(src[rp:]) < 4 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
// null
if msgSize == -1 {
continue
}
if len(src[rp:]) < msgSize {
return &invalidMessageFormatErr{messageType: "Bind"}
}
dst.Parameters[i] = src[rp : rp+msgSize]
rp += msgSize
}
}
if len(src[rp:]) < 2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
resultFormatCodeCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
dst.ResultFormatCodes = make([]int16, resultFormatCodeCount)
if len(src[rp:]) < len(dst.ResultFormatCodes)*2 {
return &invalidMessageFormatErr{messageType: "Bind"}
}
for i := 0; i < resultFormatCodeCount; i++ {
dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:]))
rp += 2
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Bind) Encode(dst []byte) []byte {
dst = append(dst, 'B')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.DestinationPortal...)
dst = append(dst, 0)
dst = append(dst, src.PreparedStatement...)
dst = append(dst, 0)
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
for _, fc := range src.ParameterFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
for _, p := range src.Parameters {
if p == nil {
dst = pgio.AppendInt32(dst, -1)
continue
}
dst = pgio.AppendInt32(dst, int32(len(p)))
dst = append(dst, p...)
}
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
for _, fc := range src.ResultFormatCodes {
dst = pgio.AppendInt16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Bind) MarshalJSON() ([]byte, error) {
formattedParameters := make([]map[string]string, len(src.Parameters))
for i, p := range src.Parameters {
if p == nil {
continue
}
textFormat := true
if len(src.ParameterFormatCodes) == 1 {
textFormat = src.ParameterFormatCodes[0] == 0
} else if len(src.ParameterFormatCodes) > 1 {
textFormat = src.ParameterFormatCodes[i] == 0
}
if textFormat {
formattedParameters[i] = map[string]string{"text": string(p)}
} else {
formattedParameters[i] = map[string]string{"binary": hex.EncodeToString(p)}
}
}
return json.Marshal(struct {
Type string
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}{
Type: "Bind",
DestinationPortal: src.DestinationPortal,
PreparedStatement: src.PreparedStatement,
ParameterFormatCodes: src.ParameterFormatCodes,
Parameters: formattedParameters,
ResultFormatCodes: src.ResultFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Bind) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
DestinationPortal string
PreparedStatement string
ParameterFormatCodes []int16
Parameters []map[string]string
ResultFormatCodes []int16
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.DestinationPortal = msg.DestinationPortal
dst.PreparedStatement = msg.PreparedStatement
dst.ParameterFormatCodes = msg.ParameterFormatCodes
dst.Parameters = make([][]byte, len(msg.Parameters))
dst.ResultFormatCodes = msg.ResultFormatCodes
for n, parameter := range msg.Parameters {
dst.Parameters[n], err = getValueFromJSON(parameter)
if err != nil {
return fmt.Errorf("cannot get param %d: %w", n, err)
}
}
return nil
}

@ -1,34 +0,0 @@
package pgproto3
import (
"encoding/json"
)
type BindComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*BindComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *BindComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "BindComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *BindComplete) Encode(dst []byte) []byte {
return append(dst, '2', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src BindComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "BindComplete",
})
}

@ -1,58 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
const cancelRequestCode = 80877102
type CancelRequest struct {
ProcessID uint32
SecretKey uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CancelRequest) Frontend() {}
func (dst *CancelRequest) Decode(src []byte) error {
if len(src) != 12 {
return errors.New("bad cancel request size")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != cancelRequestCode {
return errors.New("bad cancel request code")
}
dst.ProcessID = binary.BigEndian.Uint32(src[4:])
dst.SecretKey = binary.BigEndian.Uint32(src[8:])
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *CancelRequest) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, 16)
dst = pgio.AppendInt32(dst, cancelRequestCode)
dst = pgio.AppendUint32(dst, src.ProcessID)
dst = pgio.AppendUint32(dst, src.SecretKey)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CancelRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProcessID uint32
SecretKey uint32
}{
Type: "CancelRequest",
ProcessID: src.ProcessID,
SecretKey: src.SecretKey,
})
}

@ -1,19 +0,0 @@
package pgproto3
import (
"io"
"github.com/jackc/chunkreader/v2"
)
// ChunkReader is an interface to decouple github.com/jackc/chunkreader from this package.
type ChunkReader interface {
// Next returns buf filled with the next n bytes. If an error (including a partial read) occurs,
// buf must be nil. Next must preserve any partially read data. Next must not reuse buf.
Next(n int) (buf []byte, err error)
}
// NewChunkReader creates and returns a new default ChunkReader.
func NewChunkReader(r io.Reader) ChunkReader {
return chunkreader.New(r)
}

@ -1,89 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type Close struct {
ObjectType byte // 'S' = prepared statement, 'P' = portal
Name string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Close) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Close) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.ObjectType = src[0]
rp := 1
idx := bytes.IndexByte(src[rp:], 0)
if idx != len(src[rp:])-1 {
return &invalidMessageFormatErr{messageType: "Close"}
}
dst.Name = string(src[rp : len(src)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Close) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Close) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ObjectType string
Name string
}{
Type: "Close",
ObjectType: string(src.ObjectType),
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Close) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Close.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}

@ -1,34 +0,0 @@
package pgproto3
import (
"encoding/json"
)
type CloseComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CloseComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CloseComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CloseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CloseComplete) Encode(dst []byte) []byte {
return append(dst, '3', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CloseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CloseComplete",
})
}

@ -1,71 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type CommandComplete struct {
CommandTag []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CommandComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CommandComplete) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0)
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CommandComplete"}
}
dst.CommandTag = src[:idx]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CommandComplete) Encode(dst []byte) []byte {
dst = append(dst, 'C')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.CommandTag...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CommandComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
CommandTag string
}{
Type: "CommandComplete",
CommandTag: string(src.CommandTag),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CommandComplete) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
CommandTag string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.CommandTag = []byte(msg.CommandTag)
return nil
}

@ -1,95 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type CopyBothResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyBothResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyBothResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyBothResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyBothResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyBothResponse) Encode(dst []byte) []byte {
dst = append(dst, 'W')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyBothResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyBothResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyBothResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyBothResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

@ -1,62 +0,0 @@
package pgproto3
import (
"encoding/hex"
"encoding/json"
"github.com/jackc/pgio"
)
type CopyData struct {
Data []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyData) Backend() {}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyData) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyData) Decode(src []byte) error {
dst.Data = src
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyData) Encode(dst []byte) []byte {
dst = append(dst, 'd')
dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
dst = append(dst, src.Data...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data string
}{
Type: "CopyData",
Data: hex.EncodeToString(src.Data),
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyData) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Data string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Data = []byte(msg.Data)
return nil
}

@ -1,38 +0,0 @@
package pgproto3
import (
"encoding/json"
)
type CopyDone struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyDone) Backend() {}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyDone) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyDone) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "CopyDone", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyDone) Encode(dst []byte) []byte {
return append(dst, 'c', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyDone) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "CopyDone",
})
}

@ -1,53 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type CopyFail struct {
Message string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*CopyFail) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyFail) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0)
if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CopyFail"}
}
dst.Message = string(src[:idx])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyFail) Encode(dst []byte) []byte {
dst = append(dst, 'f')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Message...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyFail) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Message string
}{
Type: "CopyFail",
Message: src.Message,
})
}

@ -1,96 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type CopyInResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyInResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyInResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyInResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyInResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyInResponse) Encode(dst []byte) []byte {
dst = append(dst, 'G')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyInResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyInResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyInResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyInResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

@ -1,96 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type CopyOutResponse struct {
OverallFormat byte
ColumnFormatCodes []uint16
}
func (*CopyOutResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *CopyOutResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 3 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
overallFormat := buf.Next(1)[0]
columnCount := int(binary.BigEndian.Uint16(buf.Next(2)))
if buf.Len() != columnCount*2 {
return &invalidMessageFormatErr{messageType: "CopyOutResponse"}
}
columnFormatCodes := make([]uint16, columnCount)
for i := 0; i < columnCount; i++ {
columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2))
}
*dst = CopyOutResponse{OverallFormat: overallFormat, ColumnFormatCodes: columnFormatCodes}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *CopyOutResponse) Encode(dst []byte) []byte {
dst = append(dst, 'H')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.OverallFormat)
dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
for _, fc := range src.ColumnFormatCodes {
dst = pgio.AppendUint16(dst, fc)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src CopyOutResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ColumnFormatCodes []uint16
}{
Type: "CopyOutResponse",
ColumnFormatCodes: src.ColumnFormatCodes,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *CopyOutResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
OverallFormat string
ColumnFormatCodes []uint16
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.OverallFormat) != 1 {
return errors.New("invalid length for CopyOutResponse.OverallFormat")
}
dst.OverallFormat = msg.OverallFormat[0]
dst.ColumnFormatCodes = msg.ColumnFormatCodes
return nil
}

@ -1,142 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgio"
)
type DataRow struct {
Values [][]byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*DataRow) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *DataRow) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
rp := 0
fieldCount := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
// If the capacity of the values slice is too small OR substantially too
// large reallocate. This is too avoid one row with many columns from
// permanently allocating memory.
if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 {
newCap := 32
if newCap < fieldCount {
newCap = fieldCount
}
dst.Values = make([][]byte, fieldCount, newCap)
} else {
dst.Values = dst.Values[:fieldCount]
}
for i := 0; i < fieldCount; i++ {
if len(src[rp:]) < 4 {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4
// null
if msgSize == -1 {
dst.Values[i] = nil
} else {
if len(src[rp:]) < msgSize {
return &invalidMessageFormatErr{messageType: "DataRow"}
}
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize]
rp += msgSize
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *DataRow) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
for _, v := range src.Values {
if v == nil {
dst = pgio.AppendInt32(dst, -1)
continue
}
dst = pgio.AppendInt32(dst, int32(len(v)))
dst = append(dst, v...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src DataRow) MarshalJSON() ([]byte, error) {
formattedValues := make([]map[string]string, len(src.Values))
for i, v := range src.Values {
if v == nil {
continue
}
var hasNonPrintable bool
for _, b := range v {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValues[i] = map[string]string{"binary": hex.EncodeToString(v)}
} else {
formattedValues[i] = map[string]string{"text": string(v)}
}
}
return json.Marshal(struct {
Type string
Values []map[string]string
}{
Type: "DataRow",
Values: formattedValues,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *DataRow) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Values []map[string]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Values = make([][]byte, len(msg.Values))
for n, parameter := range msg.Values {
var err error
dst.Values[n], err = getValueFromJSON(parameter)
if err != nil {
return err
}
}
return nil
}

@ -1,88 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
type Describe struct {
ObjectType byte // 'S' = prepared statement, 'P' = portal
Name string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Describe) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Describe) Decode(src []byte) error {
if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "Describe"}
}
dst.ObjectType = src[0]
rp := 1
idx := bytes.IndexByte(src[rp:], 0)
if idx != len(src[rp:])-1 {
return &invalidMessageFormatErr{messageType: "Describe"}
}
dst.Name = string(src[rp : len(src)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Describe) Encode(dst []byte) []byte {
dst = append(dst, 'D')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.ObjectType)
dst = append(dst, src.Name...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Describe) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ObjectType string
Name string
}{
Type: "Describe",
ObjectType: string(src.ObjectType),
Name: src.Name,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *Describe) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
ObjectType string
Name string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
if len(msg.ObjectType) != 1 {
return errors.New("invalid length for Describe.ObjectType")
}
dst.ObjectType = byte(msg.ObjectType[0])
dst.Name = msg.Name
return nil
}

@ -1,4 +0,0 @@
// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
//
// See https://www.postgresql.org/docs/current/protocol-message-formats.html for meanings of the different messages.
package pgproto3

@ -1,34 +0,0 @@
package pgproto3
import (
"encoding/json"
)
type EmptyQueryResponse struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*EmptyQueryResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *EmptyQueryResponse) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "EmptyQueryResponse", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
return append(dst, 'I', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src EmptyQueryResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "EmptyQueryResponse",
})
}

@ -1,334 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"strconv"
)
type ErrorResponse struct {
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ErrorResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ErrorResponse) Decode(src []byte) error {
*dst = ErrorResponse{}
buf := bytes.NewBuffer(src)
for {
k, err := buf.ReadByte()
if err != nil {
return err
}
if k == 0 {
break
}
vb, err := buf.ReadBytes(0)
if err != nil {
return err
}
v := string(vb[:len(vb)-1])
switch k {
case 'S':
dst.Severity = v
case 'V':
dst.SeverityUnlocalized = v
case 'C':
dst.Code = v
case 'M':
dst.Message = v
case 'D':
dst.Detail = v
case 'H':
dst.Hint = v
case 'P':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Position = int32(n)
case 'p':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.InternalPosition = int32(n)
case 'q':
dst.InternalQuery = v
case 'W':
dst.Where = v
case 's':
dst.SchemaName = v
case 't':
dst.TableName = v
case 'c':
dst.ColumnName = v
case 'd':
dst.DataTypeName = v
case 'n':
dst.ConstraintName = v
case 'F':
dst.File = v
case 'L':
s := v
n, _ := strconv.ParseInt(s, 10, 32)
dst.Line = int32(n)
case 'R':
dst.Routine = v
default:
if dst.UnknownFields == nil {
dst.UnknownFields = make(map[byte]string)
}
dst.UnknownFields[k] = v
}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ErrorResponse) Encode(dst []byte) []byte {
return append(dst, src.marshalBinary('E')...)
}
func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
var bigEndian BigEndianBuf
buf := &bytes.Buffer{}
buf.WriteByte(typeByte)
buf.Write(bigEndian.Uint32(0))
if src.Severity != "" {
buf.WriteByte('S')
buf.WriteString(src.Severity)
buf.WriteByte(0)
}
if src.SeverityUnlocalized != "" {
buf.WriteByte('V')
buf.WriteString(src.SeverityUnlocalized)
buf.WriteByte(0)
}
if src.Code != "" {
buf.WriteByte('C')
buf.WriteString(src.Code)
buf.WriteByte(0)
}
if src.Message != "" {
buf.WriteByte('M')
buf.WriteString(src.Message)
buf.WriteByte(0)
}
if src.Detail != "" {
buf.WriteByte('D')
buf.WriteString(src.Detail)
buf.WriteByte(0)
}
if src.Hint != "" {
buf.WriteByte('H')
buf.WriteString(src.Hint)
buf.WriteByte(0)
}
if src.Position != 0 {
buf.WriteByte('P')
buf.WriteString(strconv.Itoa(int(src.Position)))
buf.WriteByte(0)
}
if src.InternalPosition != 0 {
buf.WriteByte('p')
buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
buf.WriteByte(0)
}
if src.InternalQuery != "" {
buf.WriteByte('q')
buf.WriteString(src.InternalQuery)
buf.WriteByte(0)
}
if src.Where != "" {
buf.WriteByte('W')
buf.WriteString(src.Where)
buf.WriteByte(0)
}
if src.SchemaName != "" {
buf.WriteByte('s')
buf.WriteString(src.SchemaName)
buf.WriteByte(0)
}
if src.TableName != "" {
buf.WriteByte('t')
buf.WriteString(src.TableName)
buf.WriteByte(0)
}
if src.ColumnName != "" {
buf.WriteByte('c')
buf.WriteString(src.ColumnName)
buf.WriteByte(0)
}
if src.DataTypeName != "" {
buf.WriteByte('d')
buf.WriteString(src.DataTypeName)
buf.WriteByte(0)
}
if src.ConstraintName != "" {
buf.WriteByte('n')
buf.WriteString(src.ConstraintName)
buf.WriteByte(0)
}
if src.File != "" {
buf.WriteByte('F')
buf.WriteString(src.File)
buf.WriteByte(0)
}
if src.Line != 0 {
buf.WriteByte('L')
buf.WriteString(strconv.Itoa(int(src.Line)))
buf.WriteByte(0)
}
if src.Routine != "" {
buf.WriteByte('R')
buf.WriteString(src.Routine)
buf.WriteByte(0)
}
for k, v := range src.UnknownFields {
buf.WriteByte(k)
buf.WriteByte(0)
buf.WriteString(v)
buf.WriteByte(0)
}
buf.WriteByte(0)
binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
return buf.Bytes()
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ErrorResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}{
Type: "ErrorResponse",
Severity: src.Severity,
SeverityUnlocalized: src.SeverityUnlocalized,
Code: src.Code,
Message: src.Message,
Detail: src.Detail,
Hint: src.Hint,
Position: src.Position,
InternalPosition: src.InternalPosition,
InternalQuery: src.InternalQuery,
Where: src.Where,
SchemaName: src.SchemaName,
TableName: src.TableName,
ColumnName: src.ColumnName,
DataTypeName: src.DataTypeName,
ConstraintName: src.ConstraintName,
File: src.File,
Line: src.Line,
Routine: src.Routine,
UnknownFields: src.UnknownFields,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *ErrorResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Type string
Severity string
SeverityUnlocalized string // only in 9.6 and greater
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
UnknownFields map[byte]string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
dst.Severity = msg.Severity
dst.SeverityUnlocalized = msg.SeverityUnlocalized
dst.Code = msg.Code
dst.Message = msg.Message
dst.Detail = msg.Detail
dst.Hint = msg.Hint
dst.Position = msg.Position
dst.InternalPosition = msg.InternalPosition
dst.InternalQuery = msg.InternalQuery
dst.Where = msg.Where
dst.SchemaName = msg.SchemaName
dst.TableName = msg.TableName
dst.ColumnName = msg.ColumnName
dst.DataTypeName = msg.DataTypeName
dst.ConstraintName = msg.ConstraintName
dst.File = msg.File
dst.Line = msg.Line
dst.Routine = msg.Routine
dst.UnknownFields = msg.UnknownFields
return nil
}

@ -1,65 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type Execute struct {
Portal string
MaxRows uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Execute) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Execute) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Portal = string(b[:len(b)-1])
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "Execute"}
}
dst.MaxRows = binary.BigEndian.Uint32(buf.Next(4))
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Execute) Encode(dst []byte) []byte {
dst = append(dst, 'E')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Portal...)
dst = append(dst, 0)
dst = pgio.AppendUint32(dst, src.MaxRows)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Execute) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Portal string
MaxRows uint32
}{
Type: "Execute",
Portal: src.Portal,
MaxRows: src.MaxRows,
})
}

@ -1,34 +0,0 @@
package pgproto3
import (
"encoding/json"
)
type Flush struct{}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Flush) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Flush) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "Flush", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Flush) Encode(dst []byte) []byte {
return append(dst, 'H', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Flush) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "Flush",
})
}

@ -1,206 +0,0 @@
package pgproto3
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// Frontend acts as a client for the PostgreSQL wire protocol version 3.
type Frontend struct {
cr ChunkReader
w io.Writer
// Backend message flyweights
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationGSS AuthenticationGSS
authenticationGSSContinue AuthenticationGSSContinue
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
backendKeyData BackendKeyData
bindComplete BindComplete
closeComplete CloseComplete
commandComplete CommandComplete
copyBothResponse CopyBothResponse
copyData CopyData
copyInResponse CopyInResponse
copyOutResponse CopyOutResponse
copyDone CopyDone
dataRow DataRow
emptyQueryResponse EmptyQueryResponse
errorResponse ErrorResponse
functionCallResponse FunctionCallResponse
noData NoData
noticeResponse NoticeResponse
notificationResponse NotificationResponse
parameterDescription ParameterDescription
parameterStatus ParameterStatus
parseComplete ParseComplete
readyForQuery ReadyForQuery
rowDescription RowDescription
portalSuspended PortalSuspended
bodyLen int
msgType byte
partialMsg bool
authType uint32
}
// NewFrontend creates a new Frontend.
func NewFrontend(cr ChunkReader, w io.Writer) *Frontend {
return &Frontend{cr: cr, w: w}
}
// Send sends a message to the backend.
func (f *Frontend) Send(msg FrontendMessage) error {
_, err := f.w.Write(msg.Encode(nil))
return err
}
func translateEOFtoErrUnexpectedEOF(err error) error {
if err == io.EOF {
return io.ErrUnexpectedEOF
}
return err
}
// Receive receives a message from the backend. The returned message is only valid until the next call to Receive.
func (f *Frontend) Receive() (BackendMessage, error) {
if !f.partialMsg {
header, err := f.cr.Next(5)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
f.msgType = header[0]
f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
f.partialMsg = true
if f.bodyLen < 0 {
return nil, errors.New("invalid message with negative body length received")
}
}
msgBody, err := f.cr.Next(f.bodyLen)
if err != nil {
return nil, translateEOFtoErrUnexpectedEOF(err)
}
f.partialMsg = false
var msg BackendMessage
switch f.msgType {
case '1':
msg = &f.parseComplete
case '2':
msg = &f.bindComplete
case '3':
msg = &f.closeComplete
case 'A':
msg = &f.notificationResponse
case 'c':
msg = &f.copyDone
case 'C':
msg = &f.commandComplete
case 'd':
msg = &f.copyData
case 'D':
msg = &f.dataRow
case 'E':
msg = &f.errorResponse
case 'G':
msg = &f.copyInResponse
case 'H':
msg = &f.copyOutResponse
case 'I':
msg = &f.emptyQueryResponse
case 'K':
msg = &f.backendKeyData
case 'n':
msg = &f.noData
case 'N':
msg = &f.noticeResponse
case 'R':
var err error
msg, err = f.findAuthenticationMessageType(msgBody)
if err != nil {
return nil, err
}
case 's':
msg = &f.portalSuspended
case 'S':
msg = &f.parameterStatus
case 't':
msg = &f.parameterDescription
case 'T':
msg = &f.rowDescription
case 'V':
msg = &f.functionCallResponse
case 'W':
msg = &f.copyBothResponse
case 'Z':
msg = &f.readyForQuery
default:
return nil, fmt.Errorf("unknown message type: %c", f.msgType)
}
err = msg.Decode(msgBody)
return msg, err
}
// Authentication message type constants.
// See src/include/libpq/pqcomm.h for all
// constants.
const (
AuthTypeOk = 0
AuthTypeCleartextPassword = 3
AuthTypeMD5Password = 5
AuthTypeSCMCreds = 6
AuthTypeGSS = 7
AuthTypeGSSCont = 8
AuthTypeSSPI = 9
AuthTypeSASL = 10
AuthTypeSASLContinue = 11
AuthTypeSASLFinal = 12
)
func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) {
if len(src) < 4 {
return nil, errors.New("authentication message too short")
}
f.authType = binary.BigEndian.Uint32(src[:4])
switch f.authType {
case AuthTypeOk:
return &f.authenticationOk, nil
case AuthTypeCleartextPassword:
return &f.authenticationCleartextPassword, nil
case AuthTypeMD5Password:
return &f.authenticationMD5Password, nil
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return &f.authenticationGSS, nil
case AuthTypeGSSCont:
return &f.authenticationGSSContinue, nil
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:
return &f.authenticationSASL, nil
case AuthTypeSASLContinue:
return &f.authenticationSASLContinue, nil
case AuthTypeSASLFinal:
return &f.authenticationSASLFinal, nil
default:
return nil, fmt.Errorf("unknown authentication type: %d", f.authType)
}
}
// GetAuthType returns the authType used in the current state of the frontend.
// See SetAuthType for more information.
func (f *Frontend) GetAuthType() uint32 {
return f.authType
}

@ -1,94 +0,0 @@
package pgproto3
import (
"encoding/binary"
"github.com/jackc/pgio"
)
type FunctionCall struct {
Function uint32
ArgFormatCodes []uint16
Arguments [][]byte
ResultFormatCode uint16
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*FunctionCall) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *FunctionCall) Decode(src []byte) error {
*dst = FunctionCall{}
rp := 0
// Specifies the object ID of the function to call.
dst.Function = binary.BigEndian.Uint32(src[rp:])
rp += 4
// The number of argument format codes that follow (denoted C below).
// This can be zero to indicate that there are no arguments or that the arguments all use the default format (text);
// or one, in which case the specified format code is applied to all arguments;
// or it can equal the actual number of arguments.
nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
argumentCodes := make([]uint16, nArgumentCodes)
for i := 0; i < nArgumentCodes; i++ {
// The argument format codes. Each must presently be zero (text) or one (binary).
ac := binary.BigEndian.Uint16(src[rp:])
if ac != 0 && ac != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
argumentCodes[i] = ac
rp += 2
}
dst.ArgFormatCodes = argumentCodes
// Specifies the number of arguments being supplied to the function.
nArguments := int(binary.BigEndian.Uint16(src[rp:]))
rp += 2
arguments := make([][]byte, nArguments)
for i := 0; i < nArguments; i++ {
// The length of the argument value, in bytes (this count does not include itself). Can be zero.
// As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case.
argumentLength := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if argumentLength == -1 {
arguments[i] = nil
} else {
// The value of the argument, in the format indicated by the associated format code. n is the above length.
argumentValue := src[rp : rp+argumentLength]
rp += argumentLength
arguments[i] = argumentValue
}
}
dst.Arguments = arguments
// The format code for the function result. Must presently be zero (text) or one (binary).
resultFormatCode := binary.BigEndian.Uint16(src[rp:])
if resultFormatCode != 0 && resultFormatCode != 1 {
return &invalidMessageFormatErr{messageType: "FunctionCall"}
}
dst.ResultFormatCode = resultFormatCode
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCall) Encode(dst []byte) []byte {
dst = append(dst, 'F')
sp := len(dst)
dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end
dst = pgio.AppendUint32(dst, src.Function)
dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes)))
for _, argFormatCode := range src.ArgFormatCodes {
dst = pgio.AppendUint16(dst, argFormatCode)
}
dst = pgio.AppendUint16(dst, uint16(len(src.Arguments)))
for _, argument := range src.Arguments {
if argument == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(argument)))
dst = append(dst, argument...)
}
}
dst = pgio.AppendUint16(dst, src.ResultFormatCode)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}

@ -1,101 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"github.com/jackc/pgio"
)
type FunctionCallResponse struct {
Result []byte
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*FunctionCallResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *FunctionCallResponse) Decode(src []byte) error {
if len(src) < 4 {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
rp := 0
resultSize := int(binary.BigEndian.Uint32(src[rp:]))
rp += 4
if resultSize == -1 {
dst.Result = nil
return nil
}
if len(src[rp:]) != resultSize {
return &invalidMessageFormatErr{messageType: "FunctionCallResponse"}
}
dst.Result = src[rp:]
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *FunctionCallResponse) Encode(dst []byte) []byte {
dst = append(dst, 'V')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
if src.Result == nil {
dst = pgio.AppendInt32(dst, -1)
} else {
dst = pgio.AppendInt32(dst, int32(len(src.Result)))
dst = append(dst, src.Result...)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src FunctionCallResponse) MarshalJSON() ([]byte, error) {
var formattedValue map[string]string
var hasNonPrintable bool
for _, b := range src.Result {
if b < 32 {
hasNonPrintable = true
break
}
}
if hasNonPrintable {
formattedValue = map[string]string{"binary": hex.EncodeToString(src.Result)}
} else {
formattedValue = map[string]string{"text": string(src.Result)}
}
return json.Marshal(struct {
Type string
Result map[string]string
}{
Type: "FunctionCallResponse",
Result: formattedValue,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (dst *FunctionCallResponse) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}
var msg struct {
Result map[string]string
}
err := json.Unmarshal(data, &msg)
if err != nil {
return err
}
dst.Result, err = getValueFromJSON(msg.Result)
return err
}

@ -1,49 +0,0 @@
package pgproto3
import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)
const gssEncReqNumber = 80877104
type GSSEncRequest struct {
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*GSSEncRequest) Frontend() {}
func (dst *GSSEncRequest) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("gss encoding request too short")
}
requestCode := binary.BigEndian.Uint32(src)
if requestCode != gssEncReqNumber {
return errors.New("bad gss encoding request code")
}
return nil
}
// Encode encodes src into dst. dst will include the 4 byte message length.
func (src *GSSEncRequest) Encode(dst []byte) []byte {
dst = pgio.AppendInt32(dst, 8)
dst = pgio.AppendInt32(dst, gssEncReqNumber)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src GSSEncRequest) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ProtocolVersion uint32
Parameters map[string]string
}{
Type: "GSSEncRequest",
})
}

@ -1,48 +0,0 @@
package pgproto3
import (
"encoding/json"
"github.com/jackc/pgio"
)
type GSSResponse struct {
Data []byte
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (g *GSSResponse) Frontend() {}
func (g *GSSResponse) Decode(data []byte) error {
g.Data = data
return nil
}
func (g *GSSResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
dst = append(dst, g.Data...)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "GSSResponse",
Data: g.Data,
})
}
// UnmarshalJSON implements encoding/json.Unmarshaler.
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
g.Data = msg.Data
return nil
}

@ -1,34 +0,0 @@
package pgproto3
import (
"encoding/json"
)
type NoData struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NoData) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NoData) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "NoData", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoData) Encode(dst []byte) []byte {
return append(dst, 'n', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src NoData) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "NoData",
})
}

@ -1,17 +0,0 @@
package pgproto3
type NoticeResponse ErrorResponse
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NoticeResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NoticeResponse) Decode(src []byte) error {
return (*ErrorResponse)(dst).Decode(src)
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NoticeResponse) Encode(dst []byte) []byte {
return append(dst, (*ErrorResponse)(src).marshalBinary('N')...)
}

@ -1,73 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type NotificationResponse struct {
PID uint32
Channel string
Payload string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*NotificationResponse) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *NotificationResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
channel := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
payload := string(b[:len(b)-1])
*dst = NotificationResponse{PID: pid, Channel: channel, Payload: payload}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *NotificationResponse) Encode(dst []byte) []byte {
dst = append(dst, 'A')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint32(dst, src.PID)
dst = append(dst, src.Channel...)
dst = append(dst, 0)
dst = append(dst, src.Payload...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src NotificationResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
PID uint32
Channel string
Payload string
}{
Type: "NotificationResponse",
PID: src.PID,
Channel: src.Channel,
Payload: src.Payload,
})
}

@ -1,66 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type ParameterDescription struct {
ParameterOIDs []uint32
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParameterDescription) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParameterDescription) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "ParameterDescription"}
}
// Reported parameter count will be incorrect when number of args is greater than uint16
buf.Next(2)
// Instead infer parameter count by remaining size of message
parameterCount := buf.Len() / 4
*dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)}
for i := 0; i < parameterCount; i++ {
dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4))
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterDescription) Encode(dst []byte) []byte {
dst = append(dst, 't')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ParameterDescription) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
ParameterOIDs []uint32
}{
Type: "ParameterDescription",
ParameterOIDs: src.ParameterOIDs,
})
}

@ -1,66 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type ParameterStatus struct {
Name string
Value string
}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParameterStatus) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParameterStatus) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
name := string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
value := string(b[:len(b)-1])
*dst = ParameterStatus{Name: name, Value: value}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParameterStatus) Encode(dst []byte) []byte {
dst = append(dst, 'S')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Value...)
dst = append(dst, 0)
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (ps ParameterStatus) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Value string
}{
Type: "ParameterStatus",
Name: ps.Name,
Value: ps.Value,
})
}

@ -1,88 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/binary"
"encoding/json"
"github.com/jackc/pgio"
)
type Parse struct {
Name string
Query string
ParameterOIDs []uint32
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*Parse) Frontend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *Parse) Decode(src []byte) error {
*dst = Parse{}
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Name = string(b[:len(b)-1])
b, err = buf.ReadBytes(0)
if err != nil {
return err
}
dst.Query = string(b[:len(b)-1])
if buf.Len() < 2 {
return &invalidMessageFormatErr{messageType: "Parse"}
}
parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2)))
for i := 0; i < parameterOIDCount; i++ {
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "Parse"}
}
dst.ParameterOIDs = append(dst.ParameterOIDs, binary.BigEndian.Uint32(buf.Next(4)))
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *Parse) Encode(dst []byte) []byte {
dst = append(dst, 'P')
sp := len(dst)
dst = pgio.AppendInt32(dst, -1)
dst = append(dst, src.Name...)
dst = append(dst, 0)
dst = append(dst, src.Query...)
dst = append(dst, 0)
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs)))
for _, oid := range src.ParameterOIDs {
dst = pgio.AppendUint32(dst, oid)
}
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src Parse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Name string
Query string
ParameterOIDs []uint32
}{
Type: "Parse",
Name: src.Name,
Query: src.Query,
ParameterOIDs: src.ParameterOIDs,
})
}

@ -1,34 +0,0 @@
package pgproto3
import (
"encoding/json"
)
type ParseComplete struct{}
// Backend identifies this message as sendable by the PostgreSQL backend.
func (*ParseComplete) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *ParseComplete) Decode(src []byte) error {
if len(src) != 0 {
return &invalidMessageLenErr{messageType: "ParseComplete", expectedLen: 0, actualLen: len(src)}
}
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *ParseComplete) Encode(dst []byte) []byte {
return append(dst, '1', 0, 0, 0, 4)
}
// MarshalJSON implements encoding/json.Marshaler.
func (src ParseComplete) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
}{
Type: "ParseComplete",
})
}

@ -1,54 +0,0 @@
package pgproto3
import (
"bytes"
"encoding/json"
"github.com/jackc/pgio"
)
type PasswordMessage struct {
Password string
}
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*PasswordMessage) Frontend() {}
// Frontend identifies this message as an authentication response.
func (*PasswordMessage) InitialResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length.
func (dst *PasswordMessage) Decode(src []byte) error {
buf := bytes.NewBuffer(src)
b, err := buf.ReadBytes(0)
if err != nil {
return err
}
dst.Password = string(b[:len(b)-1])
return nil
}
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
func (src *PasswordMessage) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1))
dst = append(dst, src.Password...)
dst = append(dst, 0)
return dst
}
// MarshalJSON implements encoding/json.Marshaler.
func (src PasswordMessage) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Password string
}{
Type: "PasswordMessage",
Password: src.Password,
})
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save