parent
bcb246b16f
commit
cc3c269200
@ -0,0 +1,17 @@
|
||||
package dorm
|
||||
|
||||
import (
|
||||
"gitee.com/chunanyong/zorm"
|
||||
)
|
||||
|
||||
type ConfigZormClient struct {
|
||||
Dns string // 地址
|
||||
}
|
||||
|
||||
// ZormClient
|
||||
// https://zorm.cn/
|
||||
// https://www.yuque.com/u27016943/nrgi00
|
||||
type ZormClient struct {
|
||||
Db *zorm.DBDao // 驱动
|
||||
config *ConfigZormClient // 配置
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package dorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitee.com/chunanyong/zorm"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func NewZormMysqlClient(config *ConfigZormClient) (*ZormClient, error) {
|
||||
|
||||
var err error
|
||||
c := &ZormClient{config: config}
|
||||
|
||||
c.Db, err = zorm.NewDBDao(&zorm.DataSourceConfig{
|
||||
DSN: c.config.Dns,
|
||||
DriverName: "mysql", // 数据库驱动名称
|
||||
DBType: "mysql", // 数据库类型
|
||||
PrintSQL: true, // 是否打印sql
|
||||
MaxOpenConns: 0, // 数据库最大连接数,默认50
|
||||
MaxIdleConns: 0, // 数据库最大空闲连接数,默认50
|
||||
ConnMaxLifetimeSecond: 0, // 连接存活秒时间. 默认600(10分钟)后连接被销毁重建.
|
||||
// 避免数据库主动断开连接,造成死连接.MySQL默认wait_timeout 28800秒(8小时)
|
||||
DefaultTxOptions: nil, // 事务隔离级别的默认配置,默认为nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
package dorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gitee.com/chunanyong/zorm"
|
||||
_ "github.com/bmizerany/pq"
|
||||
)
|
||||
|
||||
func NewZormPostgresqlClient(config *ConfigZormClient) (*ZormClient, error) {
|
||||
|
||||
var err error
|
||||
c := &ZormClient{config: config}
|
||||
|
||||
c.Db, err = zorm.NewDBDao(&zorm.DataSourceConfig{
|
||||
DSN: c.config.Dns,
|
||||
DriverName: "postgres", // 数据库驱动名称
|
||||
DBType: "postgresql", // 数据库类型
|
||||
PrintSQL: true, // 是否打印sql
|
||||
MaxOpenConns: 0, // 数据库最大连接数,默认50
|
||||
MaxIdleConns: 0, // 数据库最大空闲连接数,默认50
|
||||
ConnMaxLifetimeSecond: 0, // 连接存活秒时间. 默认600(10分钟)后连接被销毁重建.
|
||||
// 避免数据库主动断开连接,造成死连接.MySQL默认wait_timeout 28800秒(8小时)
|
||||
DefaultTxOptions: nil, // 事务隔离级别的默认配置,默认为nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
# idea ignore
|
||||
.idea/
|
||||
*.ipr
|
||||
*.iml
|
||||
*.iws
|
||||
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
# temp ignore
|
||||
*.log
|
||||
*.cache
|
||||
*.diff
|
||||
*.exe
|
||||
*.exe~
|
||||
*.patch
|
||||
*.tmp
|
||||
*debug.test
|
||||
debug.test
|
||||
go.sum
|
||||
|
||||
# system ignore
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# project
|
||||
*.cert
|
||||
*.key
|
||||
.test
|
||||
iprepo.txt
|
||||
|
||||
|
||||
_output
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,184 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Finder 查询数据库的载体,所有的sql语句都要通过Finder执行.
|
||||
// Finder To query the database carrier, all SQL statements must be executed through Finder
|
||||
type Finder struct {
|
||||
// 拼接SQL
|
||||
// Splicing SQL.
|
||||
sqlBuilder strings.Builder
|
||||
// SQL的参数值
|
||||
// SQL parameter values.
|
||||
values []interface{}
|
||||
// 注入检查,默认true 不允许SQL注入的 ' 单引号
|
||||
// Injection check, default true does not allow SQL injection single quote
|
||||
InjectionCheck bool
|
||||
// CountFinder 自定义的查询总条数'Finder',使用指针默认为nil.主要是为了在'group by'等复杂情况下,为了性能,手动编写总条数语句
|
||||
// CountFinder The total number of custom queries is'Finder', and the pointer is nil by default. It is mainly used to manually write the total number of statements for performance in complex situations such as'group by'
|
||||
CountFinder *Finder
|
||||
// 是否自动查询总条数,默认true.同时需要Page不为nil,才查询总条数
|
||||
// Whether to automatically query the total number of entries, the default is true. At the same time, the Page is not nil to query the total number of entries
|
||||
SelectTotalCount bool
|
||||
// SQL语句
|
||||
// SQL statement
|
||||
sqlstr string
|
||||
}
|
||||
|
||||
// NewFinder 初始化一个Finder,生成一个空的Finder
|
||||
// NewFinder Initialize a Finder and generate an empty Finder
|
||||
func NewFinder() *Finder {
|
||||
finder := Finder{}
|
||||
finder.sqlBuilder.Grow(stringBuilderGrowLen)
|
||||
finder.SelectTotalCount = true
|
||||
finder.InjectionCheck = true
|
||||
// slice扩容会生成新的slice,最后要值复制接收.问:为什么cap是3?答:经验
|
||||
finder.values = make([]interface{}, 0, 3)
|
||||
return &finder
|
||||
}
|
||||
|
||||
// NewSelectFinder 根据表名初始化查询的Finder,strs 只取第一个字符串,用数组类型是为了可以不传入,默认为 * | Finder that initializes the query based on the table name
|
||||
// NewSelectFinder("tableName") SELECT * FROM tableName
|
||||
// NewSelectFinder("tableName", "id,name") SELECT id,name FROM tableName
|
||||
func NewSelectFinder(tableName string, strs ...string) *Finder {
|
||||
strsLen := len(strs)
|
||||
if strsLen > 1 { // 不支持多个参数
|
||||
return nil
|
||||
}
|
||||
finder := NewFinder()
|
||||
finder.sqlBuilder.WriteString("SELECT ")
|
||||
if strsLen == 1 { // 只取值第一个字符串
|
||||
finder.sqlBuilder.WriteString(strs[0])
|
||||
} else {
|
||||
finder.sqlBuilder.WriteByte('*')
|
||||
}
|
||||
finder.sqlBuilder.WriteString(" FROM ")
|
||||
finder.sqlBuilder.WriteString(tableName)
|
||||
return finder
|
||||
}
|
||||
|
||||
// NewUpdateFinder 根据表名初始化更新的Finder, UPDATE tableName SET
|
||||
// NewUpdateFinder Initialize the updated Finder according to the table name, UPDATE tableName SET
|
||||
func NewUpdateFinder(tableName string) *Finder {
|
||||
finder := NewFinder()
|
||||
finder.sqlBuilder.WriteString("UPDATE ")
|
||||
finder.sqlBuilder.WriteString(tableName)
|
||||
finder.sqlBuilder.WriteString(" SET ")
|
||||
return finder
|
||||
}
|
||||
|
||||
// NewDeleteFinder 根据表名初始化删除的'Finder', DELETE FROM tableName
|
||||
// NewDeleteFinder Finder for initial deletion based on table name. DELETE FROM tableName
|
||||
func NewDeleteFinder(tableName string) *Finder {
|
||||
finder := NewFinder()
|
||||
finder.sqlBuilder.WriteString("DELETE FROM ")
|
||||
finder.sqlBuilder.WriteString(tableName)
|
||||
// 所有的 WHERE 都不加,规则统一,好记
|
||||
// No WHERE is added, the rules are unified, easy to remember
|
||||
// finder.sqlBuilder.WriteString(" WHERE ")
|
||||
return finder
|
||||
}
|
||||
|
||||
// Append 添加SQL和参数的值,第一个参数是语句,后面的参数[可选]是参数的值,顺序要正确
|
||||
// 例如: finder.Append(" and id=? and name=? ",23123,"abc")
|
||||
// 只拼接SQL,例如: finder.Append(" and name=123 ")
|
||||
// Append:Add SQL and parameter values, the first parameter is the statement, and the following parameter (optional) is the value of the parameter, in the correct order
|
||||
// E.g: finder.Append(" and id=? and name=? ",23123,"abc")
|
||||
// Only splice SQL, E.g: finder.Append(" and name=123 ")
|
||||
func (finder *Finder) Append(s string, values ...interface{}) *Finder {
|
||||
// 不要自己构建finder,使用NewFinder()方法
|
||||
// Don't build finder by yourself, use NewFinder() method
|
||||
if finder == nil || finder.values == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s != "" {
|
||||
if finder.sqlstr != "" {
|
||||
finder.sqlstr = ""
|
||||
}
|
||||
// 默认加一个空格,避免手误两个字符串连接再一起
|
||||
// A space is added by default to avoid hand mistakes when connecting two strings together
|
||||
finder.sqlBuilder.WriteByte(' ')
|
||||
|
||||
finder.sqlBuilder.WriteString(s)
|
||||
|
||||
}
|
||||
if values == nil || len(values) < 1 {
|
||||
return finder
|
||||
}
|
||||
|
||||
finder.values = append(finder.values, values...)
|
||||
return finder
|
||||
}
|
||||
|
||||
// AppendFinder 添加另一个Finder finder.AppendFinder(f)
|
||||
// AppendFinder Add another Finder . finder.AppendFinder(f)
|
||||
func (finder *Finder) AppendFinder(f *Finder) (*Finder, error) {
|
||||
if finder == nil {
|
||||
return finder, errors.New("->finder-->AppendFinder()finder对象为nil")
|
||||
}
|
||||
if f == nil {
|
||||
return finder, errors.New("->finder-->AppendFinder()参数是nil")
|
||||
}
|
||||
|
||||
// 不要自己构建finder,使用NewFinder()方法
|
||||
// Don't build finder by yourself, use NewFinder() method
|
||||
if finder.values == nil {
|
||||
return finder, errors.New("->finder-->AppendFinder()不要自己构建finder,使用NewFinder()方法")
|
||||
}
|
||||
|
||||
// 添加f的SQL
|
||||
// SQL to add f
|
||||
sqlstr, err := f.GetSQL()
|
||||
if err != nil {
|
||||
return finder, err
|
||||
}
|
||||
finder.sqlstr = ""
|
||||
finder.sqlBuilder.WriteString(sqlstr)
|
||||
// 添加f的值
|
||||
// Add the value of f
|
||||
finder.values = append(finder.values, f.values...)
|
||||
return finder, nil
|
||||
}
|
||||
|
||||
// GetSQL 返回Finder封装的SQL语句
|
||||
// GetSQL Return the SQL statement encapsulated by the Finder
|
||||
func (finder *Finder) GetSQL() (string, error) {
|
||||
// 不要自己构建finder,使用NewFinder方法
|
||||
// Don't build finder by yourself, use NewFinder method
|
||||
if finder == nil || finder.values == nil {
|
||||
return "", errors.New("->finder-->GetSQL()不要自己构建finder,使用NewFinder()方法")
|
||||
}
|
||||
if len(finder.sqlstr) > 0 {
|
||||
return finder.sqlstr, nil
|
||||
}
|
||||
sqlstr := finder.sqlBuilder.String()
|
||||
// 包含单引号,属于非法字符串
|
||||
// Contains single quotes, which are illegal strings
|
||||
if finder.InjectionCheck && (strings.Contains(sqlstr, "'")) {
|
||||
return "", errors.New(`->finder-->GetSQL()SQL语句请不要直接拼接字符串参数,容易注入!!!请使用问号占位符,例如 finder.Append("and id=?","stringId"),如果必须拼接字符串,请设置 finder.InjectionCheck = false `)
|
||||
}
|
||||
finder.sqlstr = sqlstr
|
||||
return sqlstr, nil
|
||||
}
|
@ -0,0 +1,141 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// customDriverValueMap 用于配置数据库字段类型的处理关系,key是 Dialect.字段类型,例如 dm.TEXT
|
||||
var customDriverValueMap = make(map[string]ICustomDriverValueConver)
|
||||
|
||||
// iscdvm 是否有自定义的DriverValueMap
|
||||
var iscdvm bool
|
||||
|
||||
// ICustomDriverValueConver 自定义类型转化接口,用于解决 类似达梦 text --> dm.DmClob --> string类型接收的问题
|
||||
type ICustomDriverValueConver interface {
|
||||
// GetDriverValue 根据数据库列类型,返回driver.Value的实例,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
GetDriverValue(ctx context.Context, columnType *sql.ColumnType, structFieldType *reflect.Type) (driver.Value, error)
|
||||
|
||||
// ConverDriverValue 数据库列类型,GetDriverValue返回的driver.Value的临时接收值,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
// 返回符合接收类型值的指针,指针,指针!!!!
|
||||
ConverDriverValue(ctx context.Context, columnType *sql.ColumnType, tempDriverValue driver.Value, structFieldType *reflect.Type) (interface{}, error)
|
||||
}
|
||||
|
||||
// RegisterCustomDriverValueConver 注册自定义的字段处理逻辑,用于驱动无法直接转换的场景,例如达梦的 TEXT 无法直接转化成 string
|
||||
// dialectColumnType 值是 Dialect.字段类型,例如: dm.TEXT
|
||||
// 一般是放到init方法里进行注册
|
||||
func RegisterCustomDriverValueConver(dialectColumnType string, customDriverValueConver ICustomDriverValueConver) error {
|
||||
if len(dialectColumnType) < 1 {
|
||||
return errors.New("->RegisterCustomDriverValueConver-->dialectColumnType为空")
|
||||
}
|
||||
dialectColumnTypes := strings.Split(dialectColumnType, ".")
|
||||
if len(dialectColumnTypes) < 2 {
|
||||
customDriverValueMap[strings.ToUpper(dialectColumnType)] = customDriverValueConver
|
||||
err := errors.New("->RegisterCustomDriverValueConver-->dialectColumnType 值是 Dialect.字段类型,例如: dm.TEXT ,本次正常运行,请尽快修改")
|
||||
FuncLogError(nil, err)
|
||||
} else {
|
||||
customDriverValueMap[strings.ToLower(dialectColumnTypes[0])+"."+strings.ToUpper(dialectColumnTypes[1])] = customDriverValueConver
|
||||
}
|
||||
iscdvm = true
|
||||
return nil
|
||||
}
|
||||
|
||||
type driverValueInfo struct {
|
||||
customDriverValueConver ICustomDriverValueConver
|
||||
columnType *sql.ColumnType
|
||||
tempDriverValue interface{}
|
||||
structFieldType *reflect.Type
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
import (
|
||||
// 00.引入数据库驱动
|
||||
"gitee.com/chunanyong/dm"
|
||||
"io"
|
||||
)
|
||||
|
||||
// CustomDMText 实现ICustomDriverValueConver接口,扩展自定义类型,例如 达梦数据库TEXT类型,映射出来的是dm.DmClob类型,无法使用string类型直接接收
|
||||
type CustomDMText struct{}
|
||||
|
||||
// GetDriverValue 根据数据库列类型,返回driver.Value的实例,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
func (dmtext CustomDMText) GetDriverValue(ctx context.Context, columnType *sql.ColumnType, structFieldType *reflect.Type) (driver.Value, error) {
|
||||
// 如果需要使用structFieldType,需要先判断是否为nil
|
||||
// if structFieldType != nil {
|
||||
// }
|
||||
|
||||
return &dm.DmClob{}, nil
|
||||
}
|
||||
|
||||
// ConverDriverValue 数据库列类型,GetDriverValue返回的driver.Value的临时接收值,struct属性类型
|
||||
// map接收或者字段不存在,无法获取到structFieldType,会传入nil
|
||||
// 返回符合接收类型值的指针,指针,指针!!!!
|
||||
func (dmtext CustomDMText) ConverDriverValue(ctx context.Context, columnType *sql.ColumnType, tempDriverValue driver.Value, structFieldType *reflect.Type) (interface{}, error) {
|
||||
// 如果需要使用structFieldType,需要先判断是否为nil
|
||||
// if structFieldType != nil {
|
||||
// }
|
||||
|
||||
// 类型转换
|
||||
dmClob, isok := tempDriverValue.(*dm.DmClob)
|
||||
if !isok {
|
||||
return tempDriverValue, errors.New("->ConverDriverValue-->转换至*dm.DmClob类型失败")
|
||||
}
|
||||
if dmClob == nil || !dmClob.Valid {
|
||||
return new(string), nil
|
||||
}
|
||||
// 获取长度
|
||||
dmlen, errLength := dmClob.GetLength()
|
||||
if errLength != nil {
|
||||
return dmClob, errLength
|
||||
}
|
||||
|
||||
// int64转成int类型
|
||||
strInt64 := strconv.FormatInt(dmlen, 10)
|
||||
dmlenInt, errAtoi := strconv.Atoi(strInt64)
|
||||
if errAtoi != nil {
|
||||
return dmClob, errAtoi
|
||||
}
|
||||
|
||||
// 读取字符串
|
||||
str, errReadString := dmClob.ReadString(1, dmlenInt)
|
||||
|
||||
// 处理空字符串或NULL造成的EOF错误
|
||||
if errReadString == io.EOF {
|
||||
return new(string), nil
|
||||
}
|
||||
|
||||
return &str, errReadString
|
||||
}
|
||||
// RegisterCustomDriverValueConver 注册自定义的字段处理逻辑,用于驱动无法直接转换的场景,例如达梦的 TEXT 无法直接转化成 string
|
||||
// 一般是放到init方法里进行注册
|
||||
func init() {
|
||||
// dialectColumnType 值是 Dialect.字段类型 ,例如 dm.TEXT
|
||||
zorm.RegisterCustomDriverValueConver("dm.TEXT", CustomDMText{})
|
||||
}
|
||||
|
||||
**/
|
@ -0,0 +1,160 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
// IEntityStruct "struct"实体类的接口,所有的struct实体类都要实现这个接口
|
||||
// IEntityStruct The interface of the "struct" entity class, all struct entity classes must implement this interface
|
||||
type IEntityStruct interface {
|
||||
// 获取表名称
|
||||
// Get the table name.
|
||||
GetTableName() string
|
||||
|
||||
// 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
|
||||
// Get the primary key field name of the database table. Because it is compatible with Map, it can only be the field name of the database
|
||||
GetPKColumnName() string
|
||||
|
||||
// GetPkSequence 主键序列
|
||||
// GetPkSequence Primary key sequence
|
||||
GetPkSequence() string
|
||||
}
|
||||
|
||||
// IEntityMap 使用Map保存数据,用于不方便使用struct的场景,如果主键是自增或者序列,不要"entityMap.Set"主键的值
|
||||
// IEntityMap Use Map to save data for scenarios where it is not convenient to use struct
|
||||
// If the primary key is auto-increment or sequence, do not "entity Map.Set" the value of the primary key
|
||||
type IEntityMap interface {
|
||||
// 获取表名称
|
||||
// Get the table name
|
||||
GetTableName() string
|
||||
|
||||
// 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称.
|
||||
// Get the primary key field name of the database table. Because it is compatible with Map, it can only be the field name of the database.
|
||||
GetPKColumnName() string
|
||||
|
||||
// GetEntityMapPkSequence 主键序列,不能使用GetPkSequence方法名,避免默认实现了IEntityStruct接口
|
||||
// GetEntityMapPkSequence primary key sequence, you cannot use the GetPkSequence method name, to avoid the default implementation of IEntityStruct interface
|
||||
GetEntityMapPkSequence() string
|
||||
|
||||
// GetDBFieldMap 针对Map类型,记录数据库字段
|
||||
// GetDBFieldMap For Map type, record database fields.
|
||||
GetDBFieldMap() map[string]interface{}
|
||||
|
||||
// GetDBFieldMapKey 按照Set的先后顺序记录key值,也就是数据库字段,用于SQL排序
|
||||
// GetDBFieldMapKey records the key value, that is, the database field, in the order of the Set, which is used for SQL sorting
|
||||
GetDBFieldMapKey() []string
|
||||
// 设置数据库字段的值
|
||||
// Set the value of a database field.
|
||||
Set(key string, value interface{}) map[string]interface{}
|
||||
}
|
||||
|
||||
// EntityStruct "IBaseEntity" 的基础实现,所有的实体类都匿名注入.这样就类似实现继承了,如果接口增加方法,调整这个默认实现即可
|
||||
// EntityStruct The basic implementation of "IBaseEntity", all entity classes are injected anonymously
|
||||
// This is similar to implementation inheritance. If the interface adds methods, adjust the default implementation
|
||||
type EntityStruct struct{}
|
||||
|
||||
// 默认数据库的主键列名
|
||||
// Primary key column name of the default database
|
||||
const defaultPkName = "id"
|
||||
|
||||
//GetTableName 获取表名称,必须有具体的Struct实现,类似java的抽象方法,避免手误忘记写表名.如果有扩展需求,建议使用接口进行扩展,不要默认实现GetTableName
|
||||
/*
|
||||
func (entity *EntityStruct) GetTableName() string {
|
||||
return ""
|
||||
}
|
||||
*/
|
||||
|
||||
// GetPKColumnName 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
|
||||
// GetPKColumnName Get the primary key field name of the database table
|
||||
// Because it is compatible with Map, it can only be the field name of the database
|
||||
func (entity *EntityStruct) GetPKColumnName() string {
|
||||
return defaultPkName
|
||||
}
|
||||
|
||||
// var defaultPkSequence = make(map[string]string, 0)
|
||||
|
||||
// GetPkSequence 主键序列
|
||||
// GetPkSequence Primary key sequence
|
||||
func (entity *EntityStruct) GetPkSequence() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
//-------------------------------------------------------------------------//
|
||||
|
||||
// EntityMap IEntityMap的基础实现,可以直接使用或者匿名注入
|
||||
type EntityMap struct {
|
||||
// 表名
|
||||
tableName string
|
||||
// 主键列名
|
||||
PkColumnName string
|
||||
// 主键序列,如果有值,优先级最高
|
||||
PkSequence string
|
||||
// 数据库字段,不暴露外部
|
||||
dbFieldMap map[string]interface{}
|
||||
// 列名,记录顺序
|
||||
dbFieldMapKey []string
|
||||
}
|
||||
|
||||
// NewEntityMap 初始化Map,必须传入表名称
|
||||
func NewEntityMap(tbName string) *EntityMap {
|
||||
entityMap := EntityMap{}
|
||||
entityMap.dbFieldMap = map[string]interface{}{}
|
||||
entityMap.tableName = tbName
|
||||
entityMap.PkColumnName = defaultPkName
|
||||
entityMap.dbFieldMapKey = make([]string, 0)
|
||||
return &entityMap
|
||||
}
|
||||
|
||||
// GetTableName 获取表名称
|
||||
func (entity *EntityMap) GetTableName() string {
|
||||
return entity.tableName
|
||||
}
|
||||
|
||||
// GetPKColumnName 获取数据库表的主键字段名称.因为要兼容Map,只能是数据库的字段名称
|
||||
func (entity *EntityMap) GetPKColumnName() string {
|
||||
return entity.PkColumnName
|
||||
}
|
||||
|
||||
// GetEntityMapPkSequence 主键序列,不能使用GetPkSequence方法名,避免默认实现了IEntityStruct接口
|
||||
// GetEntityMapPkSequence primary key sequence, you cannot use the GetPkSequence method name, to avoid the default implementation of IEntityStruct interface
|
||||
func (entity *EntityMap) GetEntityMapPkSequence() string {
|
||||
return entity.PkSequence
|
||||
}
|
||||
|
||||
// GetDBFieldMap 针对Map类型,记录数据库字段
|
||||
// GetDBFieldMap For Map type, record database fields
|
||||
func (entity *EntityMap) GetDBFieldMap() map[string]interface{} {
|
||||
return entity.dbFieldMap
|
||||
}
|
||||
|
||||
// GetDBFieldMapKey 按照Set的先后顺序记录key值,也就是数据库字段,用于SQL排序
|
||||
// GetDBFieldMapKey records the key value, that is, the database field, in the order of the Set, which is used for SQL sorting
|
||||
func (entity *EntityMap) GetDBFieldMapKey() []string {
|
||||
return entity.dbFieldMapKey
|
||||
}
|
||||
|
||||
// Set 设置数据库字段
|
||||
// Set Set database fields
|
||||
func (entity *EntityMap) Set(key string, value interface{}) map[string]interface{} {
|
||||
_, ok := entity.dbFieldMap[key]
|
||||
if !ok { // 如果不存在
|
||||
entity.dbFieldMapKey = append(entity.dbFieldMapKey, key)
|
||||
}
|
||||
entity.dbFieldMap[key] = value
|
||||
|
||||
return entity.dbFieldMap
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import "context"
|
||||
|
||||
// IGlobalTransaction 托管全局分布式事务接口
|
||||
type IGlobalTransaction interface {
|
||||
// BeginGTX 开启全局分布式事务
|
||||
BeginGTX(ctx context.Context, globalRootContext context.Context) error
|
||||
|
||||
// CommitGTX 提交全局分布式事务.不能命名为 Commit,不然就和gtx的Commit一致了,就递归调用自己了.......
|
||||
CommitGTX(ctx context.Context, globalRootContext context.Context) error
|
||||
|
||||
// RollbackGTX 回滚全局分布式事务
|
||||
RollbackGTX(ctx context.Context, globalRootContext context.Context) error
|
||||
|
||||
// GetGTXID 获取全局分布式事务的XID
|
||||
GetGTXID(ctx context.Context, globalRootContext context.Context) (string, error)
|
||||
|
||||
// 重新包装为 seata/hptx 的context.RootContext
|
||||
// context.RootContext 如果后续使用了 context.WithValue,类型就是context.valueCtx 就会造成无法再类型断言为 context.RootContext
|
||||
// 所以DBDao里使用了 globalRootContext变量,区分业务的ctx和分布式事务的RootContext
|
||||
// NewRootContext(ctx context.Context) context.Context
|
||||
}
|
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// 设置默认的日志显示信息,显示文件和行号
|
||||
// Set the default log display information, display file and line number.
|
||||
log.SetFlags(log.Llongfile | log.LstdFlags)
|
||||
}
|
||||
|
||||
// LogCallDepth 记录日志调用层级,用于定位到业务层代码
|
||||
// Log Call Depth Record the log call level, used to locate the business layer code
|
||||
var LogCallDepth = 4
|
||||
|
||||
// FuncLogError 记录error日志.NewDBDao方法里的异常,ctx为nil,扩展时请注意
|
||||
// FuncLogError Record error log
|
||||
var FuncLogError func(ctx context.Context, err error) = defaultLogError
|
||||
|
||||
// FuncLogPanic 记录panic日志,默认使用"defaultLogError"实现
|
||||
// FuncLogPanic Record panic log, using "defaultLogError" by default
|
||||
var FuncLogPanic func(ctx context.Context, err error) = defaultLogPanic
|
||||
|
||||
// FuncPrintSQL 打印sql语句,参数和执行时间,小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
// FuncPrintSQL Print sql statement and parameters
|
||||
var FuncPrintSQL func(ctx context.Context, sqlstr string, args []interface{}, execSQLMillis int64) = defaultPrintSQL
|
||||
|
||||
func defaultLogError(ctx context.Context, err error) {
|
||||
log.Output(LogCallDepth, fmt.Sprintln(err))
|
||||
}
|
||||
|
||||
func defaultLogPanic(ctx context.Context, err error) {
|
||||
defaultLogError(ctx, err)
|
||||
}
|
||||
|
||||
func defaultPrintSQL(ctx context.Context, sqlstr string, args []interface{}, execSQLMillis int64) {
|
||||
if args != nil {
|
||||
log.Output(LogCallDepth, fmt.Sprintln("sql:", sqlstr, ",args:", args, ",execSQLMillis:", execSQLMillis))
|
||||
} else {
|
||||
log.Output(LogCallDepth, fmt.Sprintln("sql:", sqlstr, ",args: [] ", ",execSQLMillis:", execSQLMillis))
|
||||
}
|
||||
}
|
@ -0,0 +1,81 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
// Page 分页对象
|
||||
// Page Pagination object
|
||||
type Page struct {
|
||||
// 当前页码,从1开始
|
||||
// Current page number, starting from 1
|
||||
PageNo int
|
||||
|
||||
// 每页多少条,默认20条
|
||||
// How many items per page, 20 items by default
|
||||
PageSize int
|
||||
|
||||
// 数据总条数
|
||||
// Total number of data
|
||||
TotalCount int
|
||||
|
||||
// 共多少页
|
||||
// How many pages
|
||||
PageCount int
|
||||
|
||||
// 是否是第一页
|
||||
// Is it the first page
|
||||
FirstPage bool
|
||||
|
||||
// 是否有上一页
|
||||
// Whether there is a previous page
|
||||
HasPrev bool
|
||||
|
||||
// 是否有下一页
|
||||
// Is there a next page
|
||||
HasNext bool
|
||||
|
||||
// 是否是最后一页
|
||||
// Is it the last page
|
||||
LastPage bool
|
||||
}
|
||||
|
||||
// NewPage 创建Page对象
|
||||
// NewPage Create Page object
|
||||
func NewPage() *Page {
|
||||
page := Page{}
|
||||
page.PageNo = 1
|
||||
page.PageSize = 20
|
||||
return &page
|
||||
}
|
||||
|
||||
// setTotalCount 设置总条数,计算其他值
|
||||
// setTotalCount Set the total number of bars, calculate other values
|
||||
func (page *Page) setTotalCount(total int) {
|
||||
page.TotalCount = total
|
||||
page.PageCount = (page.TotalCount + page.PageSize - 1) / page.PageSize
|
||||
if page.PageNo >= page.PageCount {
|
||||
page.LastPage = true
|
||||
} else {
|
||||
page.HasNext = true
|
||||
}
|
||||
if page.PageNo > 1 {
|
||||
page.HasPrev = true
|
||||
} else {
|
||||
page.FirstPage = true
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,318 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// dataSorce对象,隔离sql原生对象
|
||||
// dataSorce Isolate sql native objects
|
||||
type dataSource struct {
|
||||
*sql.DB
|
||||
// config *DataSourceConfig
|
||||
}
|
||||
|
||||
// newDataSource 创建一个新的datasource,内部调用,避免外部直接使用datasource
|
||||
// newDAtaSource Create a new datasource and call it internally to avoid direct external use of the datasource
|
||||
func newDataSource(config *DataSourceConfig) (*dataSource, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("->newDataSource-->config cannot be nil")
|
||||
}
|
||||
|
||||
if config.DriverName == "" {
|
||||
return nil, errors.New("->newDataSource-->DriverName cannot be empty")
|
||||
}
|
||||
// 兼容处理,DBType即将废弃,请使用Dialect属性
|
||||
if config.DBType != "" && config.Dialect == "" {
|
||||
FuncLogError(nil, errors.New("->newDataSource-->DataSourceConfig的DBType即将废弃,请使用Dialect属性"))
|
||||
config.Dialect = config.DBType
|
||||
}
|
||||
if config.Dialect == "" {
|
||||
return nil, errors.New("->newDataSource-->Dialect cannot be empty")
|
||||
}
|
||||
var db *sql.DB
|
||||
var errSQLOpen error
|
||||
|
||||
if config.SQLDB == nil { // 没有已经存在的数据库连接,使用DSN初始化
|
||||
if config.DSN == "" {
|
||||
return nil, errors.New("->newDataSource-->DSN cannot be empty")
|
||||
}
|
||||
db, errSQLOpen = sql.Open(config.DriverName, config.DSN)
|
||||
if errSQLOpen != nil {
|
||||
errSQLOpen = fmt.Errorf("->newDataSource-->open数据库打开失败:%w", errSQLOpen)
|
||||
FuncLogError(nil, errSQLOpen)
|
||||
return nil, errSQLOpen
|
||||
}
|
||||
} else { // 使用已经存在的数据库连接
|
||||
db = config.SQLDB
|
||||
}
|
||||
|
||||
if config.MaxOpenConns == 0 {
|
||||
config.MaxOpenConns = 50
|
||||
}
|
||||
if config.MaxIdleConns == 0 {
|
||||
config.MaxIdleConns = 50
|
||||
}
|
||||
|
||||
if config.ConnMaxLifetimeSecond == 0 {
|
||||
config.ConnMaxLifetimeSecond = 600
|
||||
}
|
||||
|
||||
// 设置数据库最大连接数
|
||||
// Set the maximum number of database connections
|
||||
db.SetMaxOpenConns(config.MaxOpenConns)
|
||||
// 设置数据库最大空闲连接数
|
||||
// Set the maximum number of free connections to the database
|
||||
db.SetMaxIdleConns(config.MaxIdleConns)
|
||||
//连接存活秒时间. 默认600(10分钟)后连接被销毁重建.避免数据库主动断开连接,造成死连接.MySQL默认wait_timeout 28800秒(8小时)
|
||||
//(Connection survival time in seconds) Destroy and rebuild the connection after the default 600 seconds (10 minutes)
|
||||
//Prevent the database from actively disconnecting and causing dead connections. MySQL Default wait_timeout 28800 seconds
|
||||
db.SetConnMaxLifetime(time.Second * time.Duration(config.ConnMaxLifetimeSecond))
|
||||
|
||||
// 验证连接
|
||||
if pingerr := db.Ping(); pingerr != nil {
|
||||
pingerr = fmt.Errorf("->newDataSource-->ping数据库失败:%w", pingerr)
|
||||
FuncLogError(nil, pingerr)
|
||||
db.Close()
|
||||
return nil, pingerr
|
||||
}
|
||||
|
||||
return &dataSource{db}, nil
|
||||
}
|
||||
|
||||
// 事务参照:https://www.jianshu.com/p/2a144332c3db
|
||||
// Transaction reference: https://www.jianshu.com/p/2a144332c3db
|
||||
|
||||
// dataBaseConnection 数据库dbConnection会话,可以原生查询或者事务
|
||||
// dataBaseConnection Database session, native query or transaction.
|
||||
type dataBaseConnection struct {
|
||||
// 原生db
|
||||
// native db
|
||||
db *sql.DB
|
||||
|
||||
// 原生事务
|
||||
// native transaction
|
||||
tx *sql.Tx
|
||||
|
||||
// 数据库配置
|
||||
config *DataSourceConfig
|
||||
}
|
||||
|
||||
// beginTx 开启事务
|
||||
// beginTx Open transaction
|
||||
func (dbConnection *dataBaseConnection) beginTx(ctx context.Context) error {
|
||||
if dbConnection.tx != nil {
|
||||
return nil
|
||||
}
|
||||
// 设置事务配置,主要是隔离级别
|
||||
var txOptions *sql.TxOptions
|
||||
contextTxOptions := ctx.Value(contextTxOptionsKey)
|
||||
if contextTxOptions != nil {
|
||||
txOptions, _ = contextTxOptions.(*sql.TxOptions)
|
||||
} else {
|
||||
txOptions = dbConnection.config.DefaultTxOptions
|
||||
}
|
||||
|
||||
tx, err := dbConnection.db.BeginTx(ctx, txOptions)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->beginTx事务开启失败:%w", err)
|
||||
return err
|
||||
}
|
||||
dbConnection.tx = tx
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollback 回滚事务
|
||||
// rollback Rollback transaction
|
||||
func (dbConnection *dataBaseConnection) rollback() error {
|
||||
if dbConnection.tx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := dbConnection.tx.Rollback()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->rollback事务回滚失败:%w", err)
|
||||
return err
|
||||
}
|
||||
dbConnection.tx = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// commit 提交事务
|
||||
// commit Commit transaction
|
||||
func (dbConnection *dataBaseConnection) commit() error {
|
||||
if dbConnection.tx == nil {
|
||||
return errors.New("->dbConnection.commit()事务为空")
|
||||
}
|
||||
|
||||
err := dbConnection.tx.Commit()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->dbConnection.commit()事务提交失败:%w", err)
|
||||
return err
|
||||
}
|
||||
dbConnection.tx = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// execContext 执行sql语句,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
// execContext Execute sql statement,If the transaction has been opened,it will be executed in transaction mode, if the transaction is not opened,it will be executed in non-transactional mode
|
||||
func (dbConnection *dataBaseConnection) execContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Result, error) {
|
||||
// reBindSQL 重新处理参数代入方式
|
||||
execsql, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 更新语句处理ClickHouse特殊语法
|
||||
err = reUpdateSQL(dbConnection.config.Dialect, execsql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 执行前加入 hint
|
||||
err = wrapSQLHint(ctx, execsql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var start *time.Time
|
||||
var res sql.Result
|
||||
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
slowSQLMillis := dbConnection.config.SlowSQLMillis
|
||||
if slowSQLMillis == 0 {
|
||||
FuncPrintSQL(ctx, *execsql, *args, 0)
|
||||
} else if slowSQLMillis > 0 {
|
||||
now := time.Now() // 获取当前时间
|
||||
start = &now
|
||||
}
|
||||
if dbConnection.tx != nil {
|
||||
res, err = dbConnection.tx.ExecContext(ctx, *execsql, *args...)
|
||||
} else {
|
||||
res, err = dbConnection.db.ExecContext(ctx, *execsql, *args...)
|
||||
}
|
||||
if slowSQLMillis > 0 {
|
||||
slow := time.Since(*start).Milliseconds()
|
||||
if slow-int64(slowSQLMillis) >= 0 {
|
||||
FuncPrintSQL(ctx, *execsql, *args, slow)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->execContext执行错误:%w,-->zormErrorExecSQL:%s,-->zormErrorSQLValues:%v", err, *execsql, *args)
|
||||
}
|
||||
return &res, err
|
||||
}
|
||||
|
||||
// queryRowContext 如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
func (dbConnection *dataBaseConnection) queryRowContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Row, error) {
|
||||
// reBindSQL 重新处理参数代入方式
|
||||
query, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 执行前加入 hint
|
||||
err = wrapSQLHint(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var start *time.Time
|
||||
var row *sql.Row
|
||||
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
slowSQLMillis := dbConnection.config.SlowSQLMillis
|
||||
if slowSQLMillis == 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, 0)
|
||||
} else if slowSQLMillis > 0 {
|
||||
now := time.Now() // 获取当前时间
|
||||
start = &now
|
||||
}
|
||||
|
||||
if dbConnection.tx != nil {
|
||||
row = dbConnection.tx.QueryRowContext(ctx, *query, *args...)
|
||||
} else {
|
||||
row = dbConnection.db.QueryRowContext(ctx, *query, *args...)
|
||||
}
|
||||
if slowSQLMillis > 0 {
|
||||
slow := time.Since(*start).Milliseconds()
|
||||
if slow-int64(slowSQLMillis) >= 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, slow)
|
||||
}
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
|
||||
// queryContext 查询数据,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
// queryRowContext Execute sql row statement,If the transaction has been opened,it will be executed in transaction mode, if the transaction is not opened,it will be executed in non-transactional mode
|
||||
func (dbConnection *dataBaseConnection) queryContext(ctx context.Context, sqlstr *string, argsValues *[]interface{}) (*sql.Rows, error) {
|
||||
// reBindSQL 重新处理参数代入方式
|
||||
query, args, err := reBindSQL(dbConnection.config.Dialect, sqlstr, argsValues)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 执行前加入 hint
|
||||
err = wrapSQLHint(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var start *time.Time
|
||||
var rows *sql.Rows
|
||||
// 小于0是禁用日志输出;等于0是只输出日志,不计算SQ执行时间;大于0是计算执行时间,并且大于指定值
|
||||
slowSQLMillis := dbConnection.config.SlowSQLMillis
|
||||
if slowSQLMillis == 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, 0)
|
||||
} else if slowSQLMillis > 0 {
|
||||
now := time.Now() // 获取当前时间
|
||||
start = &now
|
||||
}
|
||||
|
||||
if dbConnection.tx != nil {
|
||||
rows, err = dbConnection.tx.QueryContext(ctx, *query, *args...)
|
||||
} else {
|
||||
rows, err = dbConnection.db.QueryContext(ctx, *query, *args...)
|
||||
}
|
||||
if slowSQLMillis > 0 {
|
||||
slow := time.Since(*start).Milliseconds()
|
||||
if slow-int64(slowSQLMillis) >= 0 {
|
||||
FuncPrintSQL(ctx, *query, *args, slow)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
err = fmt.Errorf("->queryContext执行错误:%w,-->zormErrorExecSQL:%s,-->zormErrorSQLValues:%v", err, *query, *args)
|
||||
}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
/*
|
||||
// prepareContext 预执行,如果已经开启事务,就以事务方式执行,如果没有开启事务,就以非事务方式执行
|
||||
// prepareContext Pre-execution,If the transaction has been opened,it will be executed in transaction mode,if the transaction is not opened,it will be executed in non-transactional mode
|
||||
func (dbConnection *dataBaseConnection) prepareContext(ctx context.Context, query *string) (*sql.Stmt, error) {
|
||||
//打印SQL
|
||||
//print SQL
|
||||
if dbConnection.config.PrintSQL {
|
||||
//logger.Info("printSQL", logger.String("sql", query))
|
||||
FuncPrintSQL(ctx,*query, nil)
|
||||
}
|
||||
|
||||
if dbConnection.tx != nil {
|
||||
return dbConnection.tx.PrepareContext(ctx, *query)
|
||||
}
|
||||
|
||||
return dbConnection.db.PrepareContext(ctx, *query)
|
||||
}
|
||||
*/
|
@ -0,0 +1,415 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Multiprecision decimal numbers.
|
||||
// For floating-point formatting only; not general purpose.
|
||||
// Only operations are assign and (binary) left/right shift.
|
||||
// Can do binary floating point in multiprecision decimal precisely
|
||||
// because 2 divides 10; cannot do decimal floating point
|
||||
// in multiprecision binary precisely.
|
||||
|
||||
package decimal
|
||||
|
||||
type decimal struct {
|
||||
d [800]byte // digits, big-endian representation
|
||||
nd int // number of digits used
|
||||
dp int // decimal point
|
||||
neg bool // negative flag
|
||||
trunc bool // discarded nonzero digits beyond d[:nd]
|
||||
}
|
||||
|
||||
func (a *decimal) String() string {
|
||||
n := 10 + a.nd
|
||||
if a.dp > 0 {
|
||||
n += a.dp
|
||||
}
|
||||
if a.dp < 0 {
|
||||
n += -a.dp
|
||||
}
|
||||
|
||||
buf := make([]byte, n)
|
||||
w := 0
|
||||
switch {
|
||||
case a.nd == 0:
|
||||
return "0"
|
||||
|
||||
case a.dp <= 0:
|
||||
// zeros fill space between decimal point and digits
|
||||
buf[w] = '0'
|
||||
w++
|
||||
buf[w] = '.'
|
||||
w++
|
||||
w += digitZero(buf[w : w+-a.dp])
|
||||
w += copy(buf[w:], a.d[0:a.nd])
|
||||
|
||||
case a.dp < a.nd:
|
||||
// decimal point in middle of digits
|
||||
w += copy(buf[w:], a.d[0:a.dp])
|
||||
buf[w] = '.'
|
||||
w++
|
||||
w += copy(buf[w:], a.d[a.dp:a.nd])
|
||||
|
||||
default:
|
||||
// zeros fill space between digits and decimal point
|
||||
w += copy(buf[w:], a.d[0:a.nd])
|
||||
w += digitZero(buf[w : w+a.dp-a.nd])
|
||||
}
|
||||
return string(buf[0:w])
|
||||
}
|
||||
|
||||
func digitZero(dst []byte) int {
|
||||
for i := range dst {
|
||||
dst[i] = '0'
|
||||
}
|
||||
return len(dst)
|
||||
}
|
||||
|
||||
// trim trailing zeros from number.
|
||||
// (They are meaningless; the decimal point is tracked
|
||||
// independent of the number of digits.)
|
||||
func trim(a *decimal) {
|
||||
for a.nd > 0 && a.d[a.nd-1] == '0' {
|
||||
a.nd--
|
||||
}
|
||||
if a.nd == 0 {
|
||||
a.dp = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Assign v to a.
|
||||
func (a *decimal) Assign(v uint64) {
|
||||
var buf [24]byte
|
||||
|
||||
// Write reversed decimal in buf.
|
||||
n := 0
|
||||
for v > 0 {
|
||||
v1 := v / 10
|
||||
v -= 10 * v1
|
||||
buf[n] = byte(v + '0')
|
||||
n++
|
||||
v = v1
|
||||
}
|
||||
|
||||
// Reverse again to produce forward decimal in a.d.
|
||||
a.nd = 0
|
||||
for n--; n >= 0; n-- {
|
||||
a.d[a.nd] = buf[n]
|
||||
a.nd++
|
||||
}
|
||||
a.dp = a.nd
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Maximum shift that we can do in one pass without overflow.
|
||||
// A uint has 32 or 64 bits, and we have to be able to accommodate 9<<k.
|
||||
const uintSize = 32 << (^uint(0) >> 63)
|
||||
const maxShift = uintSize - 4
|
||||
|
||||
// Binary shift right (/ 2) by k bits. k <= maxShift to avoid overflow.
|
||||
func rightShift(a *decimal, k uint) {
|
||||
r := 0 // read pointer
|
||||
w := 0 // write pointer
|
||||
|
||||
// Pick up enough leading digits to cover first shift.
|
||||
var n uint
|
||||
for ; n>>k == 0; r++ {
|
||||
if r >= a.nd {
|
||||
if n == 0 {
|
||||
// a == 0; shouldn't get here, but handle anyway.
|
||||
a.nd = 0
|
||||
return
|
||||
}
|
||||
for n>>k == 0 {
|
||||
n = n * 10
|
||||
r++
|
||||
}
|
||||
break
|
||||
}
|
||||
c := uint(a.d[r])
|
||||
n = n*10 + c - '0'
|
||||
}
|
||||
a.dp -= r - 1
|
||||
|
||||
var mask uint = (1 << k) - 1
|
||||
|
||||
// Pick up a digit, put down a digit.
|
||||
for ; r < a.nd; r++ {
|
||||
c := uint(a.d[r])
|
||||
dig := n >> k
|
||||
n &= mask
|
||||
a.d[w] = byte(dig + '0')
|
||||
w++
|
||||
n = n*10 + c - '0'
|
||||
}
|
||||
|
||||
// Put down extra digits.
|
||||
for n > 0 {
|
||||
dig := n >> k
|
||||
n &= mask
|
||||
if w < len(a.d) {
|
||||
a.d[w] = byte(dig + '0')
|
||||
w++
|
||||
} else if dig > 0 {
|
||||
a.trunc = true
|
||||
}
|
||||
n = n * 10
|
||||
}
|
||||
|
||||
a.nd = w
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Cheat sheet for left shift: table indexed by shift count giving
|
||||
// number of new digits that will be introduced by that shift.
|
||||
//
|
||||
// For example, leftcheats[4] = {2, "625"}. That means that
|
||||
// if we are shifting by 4 (multiplying by 16), it will add 2 digits
|
||||
// when the string prefix is "625" through "999", and one fewer digit
|
||||
// if the string prefix is "000" through "624".
|
||||
//
|
||||
// Credit for this trick goes to Ken.
|
||||
|
||||
type leftCheat struct {
|
||||
delta int // number of new digits
|
||||
cutoff string // minus one digit if original < a.
|
||||
}
|
||||
|
||||
var leftcheats = []leftCheat{
|
||||
// Leading digits of 1/2^i = 5^i.
|
||||
// 5^23 is not an exact 64-bit floating point number,
|
||||
// so have to use bc for the math.
|
||||
// Go up to 60 to be large enough for 32bit and 64bit platforms.
|
||||
/*
|
||||
seq 60 | sed 's/^/5^/' | bc |
|
||||
awk 'BEGIN{ print "\t{ 0, \"\" }," }
|
||||
{
|
||||
log2 = log(2)/log(10)
|
||||
printf("\t{ %d, \"%s\" },\t// * %d\n",
|
||||
int(log2*NR+1), $0, 2**NR)
|
||||
}'
|
||||
*/
|
||||
{0, ""},
|
||||
{1, "5"}, // * 2
|
||||
{1, "25"}, // * 4
|
||||
{1, "125"}, // * 8
|
||||
{2, "625"}, // * 16
|
||||
{2, "3125"}, // * 32
|
||||
{2, "15625"}, // * 64
|
||||
{3, "78125"}, // * 128
|
||||
{3, "390625"}, // * 256
|
||||
{3, "1953125"}, // * 512
|
||||
{4, "9765625"}, // * 1024
|
||||
{4, "48828125"}, // * 2048
|
||||
{4, "244140625"}, // * 4096
|
||||
{4, "1220703125"}, // * 8192
|
||||
{5, "6103515625"}, // * 16384
|
||||
{5, "30517578125"}, // * 32768
|
||||
{5, "152587890625"}, // * 65536
|
||||
{6, "762939453125"}, // * 131072
|
||||
{6, "3814697265625"}, // * 262144
|
||||
{6, "19073486328125"}, // * 524288
|
||||
{7, "95367431640625"}, // * 1048576
|
||||
{7, "476837158203125"}, // * 2097152
|
||||
{7, "2384185791015625"}, // * 4194304
|
||||
{7, "11920928955078125"}, // * 8388608
|
||||
{8, "59604644775390625"}, // * 16777216
|
||||
{8, "298023223876953125"}, // * 33554432
|
||||
{8, "1490116119384765625"}, // * 67108864
|
||||
{9, "7450580596923828125"}, // * 134217728
|
||||
{9, "37252902984619140625"}, // * 268435456
|
||||
{9, "186264514923095703125"}, // * 536870912
|
||||
{10, "931322574615478515625"}, // * 1073741824
|
||||
{10, "4656612873077392578125"}, // * 2147483648
|
||||
{10, "23283064365386962890625"}, // * 4294967296
|
||||
{10, "116415321826934814453125"}, // * 8589934592
|
||||
{11, "582076609134674072265625"}, // * 17179869184
|
||||
{11, "2910383045673370361328125"}, // * 34359738368
|
||||
{11, "14551915228366851806640625"}, // * 68719476736
|
||||
{12, "72759576141834259033203125"}, // * 137438953472
|
||||
{12, "363797880709171295166015625"}, // * 274877906944
|
||||
{12, "1818989403545856475830078125"}, // * 549755813888
|
||||
{13, "9094947017729282379150390625"}, // * 1099511627776
|
||||
{13, "45474735088646411895751953125"}, // * 2199023255552
|
||||
{13, "227373675443232059478759765625"}, // * 4398046511104
|
||||
{13, "1136868377216160297393798828125"}, // * 8796093022208
|
||||
{14, "5684341886080801486968994140625"}, // * 17592186044416
|
||||
{14, "28421709430404007434844970703125"}, // * 35184372088832
|
||||
{14, "142108547152020037174224853515625"}, // * 70368744177664
|
||||
{15, "710542735760100185871124267578125"}, // * 140737488355328
|
||||
{15, "3552713678800500929355621337890625"}, // * 281474976710656
|
||||
{15, "17763568394002504646778106689453125"}, // * 562949953421312
|
||||
{16, "88817841970012523233890533447265625"}, // * 1125899906842624
|
||||
{16, "444089209850062616169452667236328125"}, // * 2251799813685248
|
||||
{16, "2220446049250313080847263336181640625"}, // * 4503599627370496
|
||||
{16, "11102230246251565404236316680908203125"}, // * 9007199254740992
|
||||
{17, "55511151231257827021181583404541015625"}, // * 18014398509481984
|
||||
{17, "277555756156289135105907917022705078125"}, // * 36028797018963968
|
||||
{17, "1387778780781445675529539585113525390625"}, // * 72057594037927936
|
||||
{18, "6938893903907228377647697925567626953125"}, // * 144115188075855872
|
||||
{18, "34694469519536141888238489627838134765625"}, // * 288230376151711744
|
||||
{18, "173472347597680709441192448139190673828125"}, // * 576460752303423488
|
||||
{19, "867361737988403547205962240695953369140625"}, // * 1152921504606846976
|
||||
}
|
||||
|
||||
// Is the leading prefix of b lexicographically less than s?
|
||||
func prefixIsLessThan(b []byte, s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if i >= len(b) {
|
||||
return true
|
||||
}
|
||||
if b[i] != s[i] {
|
||||
return b[i] < s[i]
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Binary shift left (* 2) by k bits. k <= maxShift to avoid overflow.
|
||||
func leftShift(a *decimal, k uint) {
|
||||
delta := leftcheats[k].delta
|
||||
if prefixIsLessThan(a.d[0:a.nd], leftcheats[k].cutoff) {
|
||||
delta--
|
||||
}
|
||||
|
||||
r := a.nd // read index
|
||||
w := a.nd + delta // write index
|
||||
|
||||
// Pick up a digit, put down a digit.
|
||||
var n uint
|
||||
for r--; r >= 0; r-- {
|
||||
n += (uint(a.d[r]) - '0') << k
|
||||
quo := n / 10
|
||||
rem := n - 10*quo
|
||||
w--
|
||||
if w < len(a.d) {
|
||||
a.d[w] = byte(rem + '0')
|
||||
} else if rem != 0 {
|
||||
a.trunc = true
|
||||
}
|
||||
n = quo
|
||||
}
|
||||
|
||||
// Put down extra digits.
|
||||
for n > 0 {
|
||||
quo := n / 10
|
||||
rem := n - 10*quo
|
||||
w--
|
||||
if w < len(a.d) {
|
||||
a.d[w] = byte(rem + '0')
|
||||
} else if rem != 0 {
|
||||
a.trunc = true
|
||||
}
|
||||
n = quo
|
||||
}
|
||||
|
||||
a.nd += delta
|
||||
if a.nd >= len(a.d) {
|
||||
a.nd = len(a.d)
|
||||
}
|
||||
a.dp += delta
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Binary shift left (k > 0) or right (k < 0).
|
||||
func (a *decimal) Shift(k int) {
|
||||
switch {
|
||||
case a.nd == 0:
|
||||
// nothing to do: a == 0
|
||||
case k > 0:
|
||||
for k > maxShift {
|
||||
leftShift(a, maxShift)
|
||||
k -= maxShift
|
||||
}
|
||||
leftShift(a, uint(k))
|
||||
case k < 0:
|
||||
for k < -maxShift {
|
||||
rightShift(a, maxShift)
|
||||
k += maxShift
|
||||
}
|
||||
rightShift(a, uint(-k))
|
||||
}
|
||||
}
|
||||
|
||||
// If we chop a at nd digits, should we round up?
|
||||
func shouldRoundUp(a *decimal, nd int) bool {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return false
|
||||
}
|
||||
if a.d[nd] == '5' && nd+1 == a.nd { // exactly halfway - round to even
|
||||
// if we truncated, a little higher than what's recorded - always round up
|
||||
if a.trunc {
|
||||
return true
|
||||
}
|
||||
return nd > 0 && (a.d[nd-1]-'0')%2 != 0
|
||||
}
|
||||
// not halfway - digit tells all
|
||||
return a.d[nd] >= '5'
|
||||
}
|
||||
|
||||
// Round a to nd digits (or fewer).
|
||||
// If nd is zero, it means we're rounding
|
||||
// just to the left of the digits, as in
|
||||
// 0.09 -> 0.1.
|
||||
func (a *decimal) Round(nd int) {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return
|
||||
}
|
||||
if shouldRoundUp(a, nd) {
|
||||
a.RoundUp(nd)
|
||||
} else {
|
||||
a.RoundDown(nd)
|
||||
}
|
||||
}
|
||||
|
||||
// Round a down to nd digits (or fewer).
|
||||
func (a *decimal) RoundDown(nd int) {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return
|
||||
}
|
||||
a.nd = nd
|
||||
trim(a)
|
||||
}
|
||||
|
||||
// Round a up to nd digits (or fewer).
|
||||
func (a *decimal) RoundUp(nd int) {
|
||||
if nd < 0 || nd >= a.nd {
|
||||
return
|
||||
}
|
||||
|
||||
// round up
|
||||
for i := nd - 1; i >= 0; i-- {
|
||||
c := a.d[i]
|
||||
if c < '9' { // can stop after this digit
|
||||
a.d[i]++
|
||||
a.nd = i + 1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Number is all 9s.
|
||||
// Change to single 1 with adjusted decimal point.
|
||||
a.d[0] = '1'
|
||||
a.nd = 1
|
||||
a.dp++
|
||||
}
|
||||
|
||||
// Extract integer part, rounded appropriately.
|
||||
// No guarantees about overflow.
|
||||
func (a *decimal) RoundedInteger() uint64 {
|
||||
if a.dp > 20 {
|
||||
return 0xFFFFFFFFFFFFFFFF
|
||||
}
|
||||
var i int
|
||||
n := uint64(0)
|
||||
for i = 0; i < a.dp && i < a.nd; i++ {
|
||||
n = n*10 + uint64(a.d[i]-'0')
|
||||
}
|
||||
for ; i < a.dp; i++ {
|
||||
n *= 10
|
||||
}
|
||||
if shouldRoundUp(a, a.dp) {
|
||||
n++
|
||||
}
|
||||
return n
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,160 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Multiprecision decimal numbers.
|
||||
// For floating-point formatting only; not general purpose.
|
||||
// Only operations are assign and (binary) left/right shift.
|
||||
// Can do binary floating point in multiprecision decimal precisely
|
||||
// because 2 divides 10; cannot do decimal floating point
|
||||
// in multiprecision binary precisely.
|
||||
|
||||
package decimal
|
||||
|
||||
type floatInfo struct {
|
||||
mantbits uint
|
||||
expbits uint
|
||||
bias int
|
||||
}
|
||||
|
||||
var float32info = floatInfo{23, 8, -127}
|
||||
var float64info = floatInfo{52, 11, -1023}
|
||||
|
||||
// roundShortest rounds d (= mant * 2^exp) to the shortest number of digits
|
||||
// that will let the original floating point value be precisely reconstructed.
|
||||
func roundShortest(d *decimal, mant uint64, exp int, flt *floatInfo) {
|
||||
// If mantissa is zero, the number is zero; stop now.
|
||||
if mant == 0 {
|
||||
d.nd = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Compute upper and lower such that any decimal number
|
||||
// between upper and lower (possibly inclusive)
|
||||
// will round to the original floating point number.
|
||||
|
||||
// We may see at once that the number is already shortest.
|
||||
//
|
||||
// Suppose d is not denormal, so that 2^exp <= d < 10^dp.
|
||||
// The closest shorter number is at least 10^(dp-nd) away.
|
||||
// The lower/upper bounds computed below are at distance
|
||||
// at most 2^(exp-mantbits).
|
||||
//
|
||||
// So the number is already shortest if 10^(dp-nd) > 2^(exp-mantbits),
|
||||
// or equivalently log2(10)*(dp-nd) > exp-mantbits.
|
||||
// It is true if 332/100*(dp-nd) >= exp-mantbits (log2(10) > 3.32).
|
||||
minexp := flt.bias + 1 // minimum possible exponent
|
||||
if exp > minexp && 332*(d.dp-d.nd) >= 100*(exp-int(flt.mantbits)) {
|
||||
// The number is already shortest.
|
||||
return
|
||||
}
|
||||
|
||||
// d = mant << (exp - mantbits)
|
||||
// Next highest floating point number is mant+1 << exp-mantbits.
|
||||
// Our upper bound is halfway between, mant*2+1 << exp-mantbits-1.
|
||||
upper := new(decimal)
|
||||
upper.Assign(mant*2 + 1)
|
||||
upper.Shift(exp - int(flt.mantbits) - 1)
|
||||
|
||||
// d = mant << (exp - mantbits)
|
||||
// Next lowest floating point number is mant-1 << exp-mantbits,
|
||||
// unless mant-1 drops the significant bit and exp is not the minimum exp,
|
||||
// in which case the next lowest is mant*2-1 << exp-mantbits-1.
|
||||
// Either way, call it mantlo << explo-mantbits.
|
||||
// Our lower bound is halfway between, mantlo*2+1 << explo-mantbits-1.
|
||||
var mantlo uint64
|
||||
var explo int
|
||||
if mant > 1<<flt.mantbits || exp == minexp {
|
||||
mantlo = mant - 1
|
||||
explo = exp
|
||||
} else {
|
||||
mantlo = mant*2 - 1
|
||||
explo = exp - 1
|
||||
}
|
||||
lower := new(decimal)
|
||||
lower.Assign(mantlo*2 + 1)
|
||||
lower.Shift(explo - int(flt.mantbits) - 1)
|
||||
|
||||
// The upper and lower bounds are possible outputs only if
|
||||
// the original mantissa is even, so that IEEE round-to-even
|
||||
// would round to the original mantissa and not the neighbors.
|
||||
inclusive := mant%2 == 0
|
||||
|
||||
// As we walk the digits we want to know whether rounding up would fall
|
||||
// within the upper bound. This is tracked by upperdelta:
|
||||
//
|
||||
// If upperdelta == 0, the digits of d and upper are the same so far.
|
||||
//
|
||||
// If upperdelta == 1, we saw a difference of 1 between d and upper on a
|
||||
// previous digit and subsequently only 9s for d and 0s for upper.
|
||||
// (Thus rounding up may fall outside the bound, if it is exclusive.)
|
||||
//
|
||||
// If upperdelta == 2, then the difference is greater than 1
|
||||
// and we know that rounding up falls within the bound.
|
||||
var upperdelta uint8
|
||||
|
||||
// Now we can figure out the minimum number of digits required.
|
||||
// Walk along until d has distinguished itself from upper and lower.
|
||||
for ui := 0; ; ui++ {
|
||||
// lower, d, and upper may have the decimal points at different
|
||||
// places. In this case upper is the longest, so we iterate from
|
||||
// ui==0 and start li and mi at (possibly) -1.
|
||||
mi := ui - upper.dp + d.dp
|
||||
if mi >= d.nd {
|
||||
break
|
||||
}
|
||||
li := ui - upper.dp + lower.dp
|
||||
l := byte('0') // lower digit
|
||||
if li >= 0 && li < lower.nd {
|
||||
l = lower.d[li]
|
||||
}
|
||||
m := byte('0') // middle digit
|
||||
if mi >= 0 {
|
||||
m = d.d[mi]
|
||||
}
|
||||
u := byte('0') // upper digit
|
||||
if ui < upper.nd {
|
||||
u = upper.d[ui]
|
||||
}
|
||||
|
||||
// Okay to round down (truncate) if lower has a different digit
|
||||
// or if lower is inclusive and is exactly the result of rounding
|
||||
// down (i.e., and we have reached the final digit of lower).
|
||||
okdown := l != m || inclusive && li+1 == lower.nd
|
||||
|
||||
switch {
|
||||
case upperdelta == 0 && m+1 < u:
|
||||
// Example:
|
||||
// m = 12345xxx
|
||||
// u = 12347xxx
|
||||
upperdelta = 2
|
||||
case upperdelta == 0 && m != u:
|
||||
// Example:
|
||||
// m = 12345xxx
|
||||
// u = 12346xxx
|
||||
upperdelta = 1
|
||||
case upperdelta == 1 && (m != '9' || u != '0'):
|
||||
// Example:
|
||||
// m = 1234598x
|
||||
// u = 1234600x
|
||||
upperdelta = 2
|
||||
}
|
||||
// Okay to round up if upper has a different digit and either upper
|
||||
// is inclusive or upper is bigger than the result of rounding up.
|
||||
okup := upperdelta > 0 && (inclusive || upperdelta > 1 || ui+1 < upper.nd)
|
||||
|
||||
// If it's okay to do either, then round to the nearest one.
|
||||
// If it's okay to do only one, do it.
|
||||
switch {
|
||||
case okdown && okup:
|
||||
d.Round(mi + 1)
|
||||
return
|
||||
case okdown:
|
||||
d.RoundDown(mi + 1)
|
||||
return
|
||||
case okup:
|
||||
d.RoundUp(mi + 1)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,564 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// tag标签的名称
|
||||
tagColumnName = "column"
|
||||
|
||||
// 输出字段 缓存的前缀
|
||||
exportPrefix = "_exportStructFields_"
|
||||
// 私有字段 缓存的前缀
|
||||
privatePrefix = "_privateStructFields_"
|
||||
// 数据库列名 缓存的前缀
|
||||
dbColumnNamePrefix = "_dbColumnName_"
|
||||
|
||||
// 数据库所有列名,经过排序 缓存的前缀
|
||||
dbColumnNameSlicePrefix = "_dbColumnNameSlice_"
|
||||
|
||||
// field对应的column的tag值 缓存的前缀
|
||||
// structFieldTagPrefix = "_structFieldTag_"
|
||||
// 数据库主键 缓存的前缀
|
||||
// dbPKNamePrefix = "_dbPKName_"
|
||||
)
|
||||
|
||||
// cacheStructFieldInfoMap 用于缓存反射的信息,sync.Map内部处理了并发锁
|
||||
var cacheStructFieldInfoMap *sync.Map = &sync.Map{}
|
||||
|
||||
// var cacheStructFieldInfoMap = make(map[string]map[string]reflect.StructField)
|
||||
|
||||
// 用于缓存field对应的column的tag值
|
||||
// var cacheStructFieldTagInfoMap = make(map[string]map[string]string)
|
||||
|
||||
// structFieldInfo 获取StructField的信息.只对struct或者*struct判断,如果是指针,返回指针下实际的struct类型
|
||||
// 第一个返回值是可以输出的字段(首字母大写),第二个是不能输出的字段(首字母小写)
|
||||
func structFieldInfo(typeOf *reflect.Type) error {
|
||||
if typeOf == nil {
|
||||
return errors.New("->structFieldInfo数据为空")
|
||||
}
|
||||
|
||||
entityName := (*typeOf).String()
|
||||
|
||||
// 缓存的key
|
||||
// 所有输出的属性,包含数据库字段,key是struct属性的名称,不区分大小写
|
||||
exportCacheKey := exportPrefix + entityName
|
||||
// 所有私有变量的属性,key是struct属性的名称,不区分大小写
|
||||
privateCacheKey := privatePrefix + entityName
|
||||
// 所有数据库的属性,key是数据库的字段名称,不区分大小写
|
||||
dbColumnCacheKey := dbColumnNamePrefix + entityName
|
||||
// 所有数据库字段名称的slice,经过排序,不区分大小写
|
||||
dbColumnNameSliceCacheKey := dbColumnNameSlicePrefix + entityName
|
||||
|
||||
// structFieldTagCacheKey := structFieldTagPrefix + entityName
|
||||
// dbPKNameCacheKey := dbPKNamePrefix + entityName
|
||||
// 缓存的数据库主键值
|
||||
_, exportOk := cacheStructFieldInfoMap.Load(exportCacheKey)
|
||||
//_, exportOk := cacheStructFieldInfoMap[exportCacheKey]
|
||||
//如果存在值,认为缓存中有所有的信息,不再处理
|
||||
if exportOk {
|
||||
return nil
|
||||
}
|
||||
// 获取字段长度
|
||||
fieldNum := (*typeOf).NumField()
|
||||
// 如果没有字段
|
||||
if fieldNum < 1 {
|
||||
return errors.New("->structFieldInfo-->NumField entity没有属性")
|
||||
}
|
||||
|
||||
// 声明所有字段的载体
|
||||
var allFieldMap *sync.Map = &sync.Map{}
|
||||
// anonymous := make([]reflect.StructField, 0)
|
||||
|
||||
// 缓存的数据
|
||||
exportStructFieldMap := make(map[string]reflect.StructField)
|
||||
privateStructFieldMap := make(map[string]reflect.StructField)
|
||||
dbColumnFieldMap := make(map[string]reflect.StructField)
|
||||
|
||||
// structFieldTagMap := make(map[string]string)
|
||||
dbColumnFieldNameSlice := make([]string, 0)
|
||||
|
||||
// 遍历sync.Map,要求输入一个func作为参数
|
||||
// 这个函数的入参、出参的类型都已经固定,不能修改
|
||||
// 可以在函数体内编写自己的代码,调用map中的k,v
|
||||
// var funcMapKV func(k, v interface{}) bool
|
||||
funcMapKV := func(k, v interface{}) bool {
|
||||
field := v.(reflect.StructField)
|
||||
fieldName := field.Name
|
||||
if ast.IsExported(fieldName) { // 如果是可以输出的,不区分大小写
|
||||
exportStructFieldMap[strings.ToLower(fieldName)] = field
|
||||
// 如果是数据库字段
|
||||
tagColumnValue := field.Tag.Get(tagColumnName)
|
||||
if len(tagColumnValue) > 0 {
|
||||
// dbColumnFieldMap[tagColumnValue] = field
|
||||
// 使用数据库字段的小写,处理oracle和达梦数据库的sql返回值大写
|
||||
tagColumnValueLower := strings.ToLower(tagColumnValue)
|
||||
dbColumnFieldMap[tagColumnValueLower] = field
|
||||
dbColumnFieldNameSlice = append(dbColumnFieldNameSlice, tagColumnValueLower)
|
||||
// structFieldTagMap[fieldName] = tagColumnValue
|
||||
}
|
||||
|
||||
} else { // 私有属性
|
||||
privateStructFieldMap[strings.ToLower(fieldName)] = field
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
// 并发锁,用于处理slice并发append
|
||||
var lock sync.Mutex
|
||||
// funcRecursiveAnonymous 递归调用struct的匿名属性,就近覆盖属性
|
||||
var funcRecursiveAnonymous func(allFieldMap *sync.Map, anonymous *reflect.StructField)
|
||||
funcRecursiveAnonymous = func(allFieldMap *sync.Map, anonymous *reflect.StructField) {
|
||||
// 字段类型
|
||||
anonymousTypeOf := anonymous.Type
|
||||
if anonymousTypeOf.Kind() == reflect.Ptr {
|
||||
// 获取指针下的Struct类型
|
||||
anonymousTypeOf = anonymousTypeOf.Elem()
|
||||
}
|
||||
|
||||
// 只处理Struct类型
|
||||
if anonymousTypeOf.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
|
||||
// 获取字段长度
|
||||
fieldNum := anonymousTypeOf.NumField()
|
||||
// 如果没有字段
|
||||
if fieldNum < 1 {
|
||||
return
|
||||
}
|
||||
// 遍历所有字段
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
anonymousField := anonymousTypeOf.Field(i)
|
||||
if anonymousField.Anonymous { // 匿名struct里自身又有匿名struct
|
||||
funcRecursiveAnonymous(allFieldMap, &anonymousField)
|
||||
} else if _, ok := allFieldMap.Load(anonymousField.Name); !ok { // 普通命名字段,而且没有记录过
|
||||
allFieldMap.Store(anonymousField.Name, anonymousField)
|
||||
lock.Lock()
|
||||
funcMapKV(anonymousField.Name, anonymousField)
|
||||
lock.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 遍历所有字段,记录匿名属性
|
||||
for i := 0; i < fieldNum; i++ {
|
||||
field := (*typeOf).Field(i)
|
||||
if field.Anonymous { // 如果是匿名的
|
||||
funcRecursiveAnonymous(allFieldMap, &field)
|
||||
} else if _, ok := allFieldMap.Load(field.Name); !ok { // 普通命名字段,而且没有记录过
|
||||
allFieldMap.Store(field.Name, field)
|
||||
lock.Lock()
|
||||
funcMapKV(field.Name, field)
|
||||
lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// allFieldMap.Range(f)
|
||||
|
||||
// 加入缓存
|
||||
cacheStructFieldInfoMap.Store(exportCacheKey, exportStructFieldMap)
|
||||
cacheStructFieldInfoMap.Store(privateCacheKey, privateStructFieldMap)
|
||||
cacheStructFieldInfoMap.Store(dbColumnCacheKey, dbColumnFieldMap)
|
||||
// cacheStructFieldInfoMap[exportCacheKey] = exportStructFieldMap
|
||||
// cacheStructFieldInfoMap[privateCacheKey] = privateStructFieldMap
|
||||
// cacheStructFieldInfoMap[dbColumnCacheKey] = dbColumnFieldMap
|
||||
|
||||
// cacheStructFieldTagInfoMap[structFieldTagCacheKey] = structFieldTagMap
|
||||
|
||||
// 不按照字母顺序,按照反射获取的Struct属性顺序,生成insert语句和update语句
|
||||
// sort.Strings(dbColumnFieldNameSlice)
|
||||
cacheStructFieldInfoMap.Store(dbColumnNameSliceCacheKey, dbColumnFieldNameSlice)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setFieldValueByColumnName 根据数据库的字段名,找到struct映射的字段,并赋值
|
||||
func setFieldValueByColumnName(entity interface{}, columnName string, value interface{}) error {
|
||||
// 先从本地缓存中查找
|
||||
typeOf := reflect.TypeOf(entity)
|
||||
valueOf := reflect.ValueOf(entity)
|
||||
if typeOf.Kind() == reflect.Ptr { // 如果是指针
|
||||
typeOf = typeOf.Elem()
|
||||
valueOf = valueOf.Elem()
|
||||
}
|
||||
|
||||
dbMap, err := getDBColumnFieldMap(&typeOf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f, ok := dbMap[strings.ToLower(columnName)]
|
||||
if ok { // 给主键赋值
|
||||
valueOf.FieldByName(f.Name).Set(reflect.ValueOf(value))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// structFieldValue 获取指定字段的值
|
||||
func structFieldValue(s interface{}, fieldName string) (interface{}, error) {
|
||||
if s == nil || len(fieldName) < 1 {
|
||||
return nil, errors.New("->structFieldValue数据为空")
|
||||
}
|
||||
// entity的s类型
|
||||
valueOf := reflect.ValueOf(s)
|
||||
|
||||
kind := valueOf.Kind()
|
||||
if !(kind == reflect.Ptr || kind == reflect.Struct) {
|
||||
return nil, errors.New("->structFieldValue必须是Struct或者*Struct类型")
|
||||
}
|
||||
|
||||
if kind == reflect.Ptr {
|
||||
// 获取指针下的Struct类型
|
||||
valueOf = valueOf.Elem()
|
||||
if valueOf.Kind() != reflect.Struct {
|
||||
return nil, errors.New("->structFieldValue必须是Struct或者*Struct类型")
|
||||
}
|
||||
}
|
||||
|
||||
// FieldByName方法返回的是reflect.Value类型,调用Interface()方法,返回原始类型的数据值
|
||||
value := valueOf.FieldByName(fieldName).Interface()
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// getDBColumnExportFieldMap 获取实体类的数据库字段,key是数据库的字段名称.同时返回所有的字段属性的map,key是实体类的属性.不区分大小写
|
||||
func getDBColumnExportFieldMap(typeOf *reflect.Type) (map[string]reflect.StructField, map[string]reflect.StructField, error) {
|
||||
dbColumnFieldMap, err := getCacheStructFieldInfoMap(typeOf, dbColumnNamePrefix)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
exportFieldMap, err := getCacheStructFieldInfoMap(typeOf, exportPrefix)
|
||||
return dbColumnFieldMap, exportFieldMap, err
|
||||
}
|
||||
|
||||
// getDBColumnFieldMap 获取实体类的数据库字段,key是数据库的字段名称.不区分大小写
|
||||
func getDBColumnFieldMap(typeOf *reflect.Type) (map[string]reflect.StructField, error) {
|
||||
return getCacheStructFieldInfoMap(typeOf, dbColumnNamePrefix)
|
||||
}
|
||||
|
||||
// getDBColumnFieldNameSlice 获取实体类的数据库字段,经过排序,key是数据库的字段名称.不区分大小写,
|
||||
func getDBColumnFieldNameSlice(typeOf *reflect.Type) ([]string, error) {
|
||||
dbColumnFieldSlice, dbmapErr := getCacheStructFieldInfo(typeOf, dbColumnNameSlicePrefix)
|
||||
if dbmapErr != nil {
|
||||
return nil, fmt.Errorf("->getDBColumnFieldNameSlice-->getCacheStructFieldInfo()取值错误:%w", dbmapErr)
|
||||
}
|
||||
dbcfSlice, efOK := dbColumnFieldSlice.([]string)
|
||||
if !efOK {
|
||||
return dbcfSlice, errors.New("->getDBColumnFieldNameSlice-->dbColumnFieldSlice取值转[]string类型异常")
|
||||
}
|
||||
return dbcfSlice, nil
|
||||
}
|
||||
|
||||
// getCacheStructFieldInfo 根据类型和key,获取缓存的数据字段信息slice,已经排序
|
||||
func getCacheStructFieldInfo(typeOf *reflect.Type, keyPrefix string) (interface{}, error) {
|
||||
if typeOf == nil {
|
||||
return nil, errors.New("->getCacheStructFieldInfo-->typeOf不能为空")
|
||||
}
|
||||
key := keyPrefix + (*typeOf).String()
|
||||
dbColumnFieldMap, dbOk := cacheStructFieldInfoMap.Load(key)
|
||||
// dbColumnFieldMap, dbOk := cacheStructFieldInfoMap[key]
|
||||
if !dbOk { // 缓存不存在
|
||||
// 获取实体类的输出字段和私有 字段
|
||||
err := structFieldInfo(typeOf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbColumnFieldMap, dbOk = cacheStructFieldInfoMap.Load(key)
|
||||
// dbColumnFieldMap, dbOk = cacheStructFieldInfoMap[key]
|
||||
if !dbOk {
|
||||
return nil, errors.New("->getCacheStructFieldInfo-->cacheStructFieldInfoMap.Load()获取数据库字段dbColumnFieldMap异常")
|
||||
}
|
||||
}
|
||||
|
||||
return dbColumnFieldMap, nil
|
||||
|
||||
// return dbColumnFieldMap, nil
|
||||
}
|
||||
|
||||
// getCacheStructFieldInfoMap 根据类型和key,获取缓存的字段信息
|
||||
func getCacheStructFieldInfoMap(typeOf *reflect.Type, keyPrefix string) (map[string]reflect.StructField, error) {
|
||||
dbColumnFieldMap, dbmapErr := getCacheStructFieldInfo(typeOf, keyPrefix)
|
||||
if dbmapErr != nil {
|
||||
return nil, fmt.Errorf("->getCacheStructFieldInfoMap-->getCacheStructFieldInfo()取值错误:%w", dbmapErr)
|
||||
}
|
||||
dbcfMap, efOK := dbColumnFieldMap.(map[string]reflect.StructField)
|
||||
if !efOK {
|
||||
return dbcfMap, errors.New("->getCacheStructFieldInfoMap-->dbColumnFieldMap取值转map[string]reflect.StructField类型异常")
|
||||
}
|
||||
return dbcfMap, nil
|
||||
|
||||
// return dbColumnFieldMap, nil
|
||||
}
|
||||
|
||||
// columnAndValue 根据保存的对象,返回插入的语句,需要插入的字段,字段的值
|
||||
func columnAndValue(entity interface{}) (reflect.Type, []reflect.StructField, []interface{}, error) {
|
||||
typeOf, checkerr := checkEntityKind(entity)
|
||||
if checkerr != nil {
|
||||
return typeOf, nil, nil, checkerr
|
||||
}
|
||||
// 获取实体类的反射,指针下的struct
|
||||
valueOf := reflect.ValueOf(entity).Elem()
|
||||
// reflect.Indirect
|
||||
|
||||
// 先从本地缓存中查找
|
||||
// typeOf := reflect.TypeOf(entity).Elem()
|
||||
|
||||
dbMap, err := getDBColumnFieldMap(&typeOf)
|
||||
if err != nil {
|
||||
return typeOf, nil, nil, err
|
||||
}
|
||||
dbSlice, err := getDBColumnFieldNameSlice(&typeOf)
|
||||
if err != nil {
|
||||
return typeOf, nil, nil, err
|
||||
}
|
||||
// 实体类公开字段的长度
|
||||
fLen := len(dbMap)
|
||||
// 长度不一致
|
||||
if fLen-len(dbSlice) != 0 {
|
||||
return typeOf, nil, nil, errors.New("->columnAndValue-->缓存的数据库字段和实体类字段不对应")
|
||||
}
|
||||
// 接收列的数组,这里是做一个副本,避免外部更改掉原始的列信息
|
||||
columns := make([]reflect.StructField, 0, fLen)
|
||||
// 接收值的数组
|
||||
values := make([]interface{}, 0, fLen)
|
||||
|
||||
// 遍历所有数据库属性
|
||||
for _, fieldName := range dbSlice {
|
||||
//获取字段类型的Kind
|
||||
// fieldKind := field.Type.Kind()
|
||||
//if !allowTypeMap[fieldKind] { //不允许的类型
|
||||
// continue
|
||||
//}
|
||||
field := dbMap[fieldName]
|
||||
columns = append(columns, field)
|
||||
// FieldByName方法返回的是reflect.Value类型,调用Interface()方法,返回原始类型的数据值.字段不会重名,不使用FieldByIndex()函数
|
||||
value := valueOf.FieldByName(field.Name).Interface()
|
||||
// 添加到记录值的数组
|
||||
values = append(values, value)
|
||||
|
||||
}
|
||||
|
||||
// 缓存数据库的列
|
||||
return typeOf, columns, values, nil
|
||||
}
|
||||
|
||||
// entityPKFieldName 获取实体类主键属性名称
|
||||
func entityPKFieldName(entity IEntityStruct, typeOf *reflect.Type) (string, error) {
|
||||
//检查是否是指针对象
|
||||
//typeOf, checkerr := checkEntityKind(entity)
|
||||
//if checkerr != nil {
|
||||
// return "", checkerr
|
||||
//}
|
||||
|
||||
// 缓存的key,TypeOf和ValueOf的String()方法,返回值不一样
|
||||
// typeOf := reflect.TypeOf(entity).Elem()
|
||||
|
||||
dbMap, err := getDBColumnFieldMap(typeOf)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
field := dbMap[strings.ToLower(entity.GetPKColumnName())]
|
||||
return field.Name, nil
|
||||
}
|
||||
|
||||
// checkEntityKind 检查entity类型必须是*struct类型或者基础类型的指针
|
||||
func checkEntityKind(entity interface{}) (reflect.Type, error) {
|
||||
if entity == nil {
|
||||
return nil, errors.New("->checkEntityKind参数不能为空,必须是*struct类型或者基础类型的指针")
|
||||
}
|
||||
typeOf := reflect.TypeOf(entity)
|
||||
if typeOf.Kind() != reflect.Ptr { // 如果不是指针
|
||||
return nil, errors.New("->checkEntityKind必须是*struct类型或者基础类型的指针")
|
||||
}
|
||||
typeOf = typeOf.Elem()
|
||||
//if !(typeOf.Kind() == reflect.Struct || allowBaseTypeMap[typeOf.Kind()]) { //如果不是指针
|
||||
// return nil, errors.New("checkEntityKind必须是*struct类型或者基础类型的指针")
|
||||
//}
|
||||
return typeOf, nil
|
||||
}
|
||||
|
||||
// sqlRowsValues 包装接收sqlRows的Values数组,反射rows屏蔽数据库null值,兼容单个字段查询和Struct映射
|
||||
// fix:converting NULL to int is unsupported
|
||||
// 当读取数据库的值为NULL时,由于基本类型不支持为NULL,通过反射将未知driver.Value改为interface{},不再映射到struct实体类
|
||||
// 感谢@fastabler提交的pr
|
||||
// oneColumnScanner 只有一个字段,而且可以直接Scan,例如string或者[]string,不需要反射StructType进行处理
|
||||
func sqlRowsValues(ctx context.Context, dialect string, valueOf *reflect.Value, typeOf *reflect.Type, rows *sql.Rows, driverValue *reflect.Value, columnTypes []*sql.ColumnType, entity interface{}, dbColumnFieldMap, exportFieldMap *map[string]reflect.StructField) error {
|
||||
if entity == nil && valueOf == nil {
|
||||
return errors.New("->sqlRowsValues-->valueOfElem为nil")
|
||||
}
|
||||
|
||||
var valueOfElem reflect.Value
|
||||
if entity == nil && valueOf != nil {
|
||||
valueOfElem = valueOf.Elem()
|
||||
}
|
||||
|
||||
ctLen := len(columnTypes)
|
||||
// 声明载体数组,用于存放struct的属性指针
|
||||
// Declare a carrier array to store the attribute pointer of the struct
|
||||
values := make([]interface{}, ctLen)
|
||||
// 记录需要类型转换的字段信息
|
||||
var fieldTempDriverValueMap map[*sql.ColumnType]*driverValueInfo
|
||||
if iscdvm {
|
||||
fieldTempDriverValueMap = make(map[*sql.ColumnType]*driverValueInfo)
|
||||
}
|
||||
var err error
|
||||
var customDriverValueConver ICustomDriverValueConver
|
||||
var converOK bool
|
||||
|
||||
for i, columnType := range columnTypes {
|
||||
if iscdvm {
|
||||
databaseTypeName := strings.ToUpper(columnType.DatabaseTypeName())
|
||||
// 根据接收的类型,获取到类型转换的接口实现,优先匹配指定的数据库类型
|
||||
customDriverValueConver, converOK = customDriverValueMap[dialect+"."+databaseTypeName]
|
||||
if !converOK {
|
||||
customDriverValueConver, converOK = customDriverValueMap[databaseTypeName]
|
||||
}
|
||||
}
|
||||
dv := driverValue.Index(i)
|
||||
if dv.IsValid() && dv.InterfaceData()[0] == 0 { // 该字段的数据库值是null,取默认值
|
||||
values[i] = new(interface{})
|
||||
continue
|
||||
} else if converOK { // 如果是需要转换的字段
|
||||
// 获取字段类型
|
||||
var structFieldType *reflect.Type
|
||||
if entity != nil { // 查询一个字段,并且可以直接接收
|
||||
structFieldType = typeOf
|
||||
} else { // 如果是struct类型
|
||||
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if field != nil { // 存在这个字段
|
||||
vtype := field.Type
|
||||
structFieldType = &vtype
|
||||
}
|
||||
}
|
||||
tempDriverValue, err := customDriverValueConver.GetDriverValue(ctx, columnType, structFieldType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tempDriverValue == nil {
|
||||
return errors.New("->sqlRowsValues-->customDriverValueConver.GetDriverValue返回的driver.Value不能为nil")
|
||||
}
|
||||
values[i] = tempDriverValue
|
||||
|
||||
// 如果需要类型转换
|
||||
dvinfo := driverValueInfo{}
|
||||
dvinfo.customDriverValueConver = customDriverValueConver
|
||||
// dvinfo.columnType = columnType
|
||||
dvinfo.structFieldType = structFieldType
|
||||
dvinfo.tempDriverValue = tempDriverValue
|
||||
fieldTempDriverValueMap[columnType] = &dvinfo
|
||||
continue
|
||||
|
||||
} else if entity != nil { // 查询一个字段,并且可以直接接收
|
||||
values[i] = entity
|
||||
continue
|
||||
} else {
|
||||
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if field == nil { // 如果不存在这个字段
|
||||
values[i] = new(interface{})
|
||||
} else {
|
||||
// fieldType := refPV.FieldByName(field.Name).Type()
|
||||
// v := reflect.New(field.Type).Interface()
|
||||
// 字段的反射值
|
||||
fieldValue := valueOfElem.FieldByName(field.Name)
|
||||
v := fieldValue.Addr().Interface()
|
||||
// v := new(interface{})
|
||||
values[i] = v
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
err = rows.Scan(values...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(fieldTempDriverValueMap) < 1 {
|
||||
return err
|
||||
}
|
||||
|
||||
// 循环需要替换的值
|
||||
for columnType, driverValueInfo := range fieldTempDriverValueMap {
|
||||
// 根据列名,字段类型,新值 返回符合接收类型值的指针,返回值是个指针,指针,指针!!!!
|
||||
// typeOf := fieldValue.Type()
|
||||
rightValue, errConverDriverValue := driverValueInfo.customDriverValueConver.ConverDriverValue(ctx, columnType, driverValueInfo.tempDriverValue, driverValueInfo.structFieldType)
|
||||
if errConverDriverValue != nil {
|
||||
errConverDriverValue = fmt.Errorf("->sqlRowsValues-->customDriverValueConver.ConverDriverValue错误:%w", errConverDriverValue)
|
||||
FuncLogError(ctx, errConverDriverValue)
|
||||
return errConverDriverValue
|
||||
}
|
||||
if entity != nil { // 查询一个字段,并且可以直接接收
|
||||
// entity = rightValue
|
||||
// valueOfElem.Set(reflect.ValueOf(rightValue).Elem())
|
||||
reflect.ValueOf(entity).Elem().Set(reflect.ValueOf(rightValue).Elem())
|
||||
continue
|
||||
} else { // 如果是Struct类型接收
|
||||
field, err := getStructFieldByColumnType(columnType, dbColumnFieldMap, exportFieldMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if field != nil { // 如果存在这个字段
|
||||
// 字段的反射值
|
||||
fieldValue := valueOfElem.FieldByName(field.Name)
|
||||
// 给字段赋值
|
||||
fieldValue.Set(reflect.ValueOf(rightValue).Elem())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// getStructFieldByColumnType 根据ColumnType获取StructField对象,兼容驼峰
|
||||
func getStructFieldByColumnType(columnType *sql.ColumnType, dbColumnFieldMap *map[string]reflect.StructField, exportFieldMap *map[string]reflect.StructField) (*reflect.StructField, error) {
|
||||
columnName := strings.ToLower(columnType.Name())
|
||||
// columnName := "test"
|
||||
// 从缓存中获取列名的field字段
|
||||
// Get the field field of the column name from the cache
|
||||
field, fok := (*dbColumnFieldMap)[columnName]
|
||||
if !fok {
|
||||
field, fok = (*exportFieldMap)[columnName]
|
||||
if !fok {
|
||||
// 尝试驼峰
|
||||
cname := strings.ReplaceAll(columnName, "_", "")
|
||||
field, fok = (*exportFieldMap)[cname]
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
if fok {
|
||||
return &field, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
@ -0,0 +1,611 @@
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package zorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
|
||||
"gitee.com/chunanyong/zorm/decimal"
|
||||
)
|
||||
|
||||
// FuncDecimalValue 设置decimal类型接收值,复写函数自定义decimal实现,例如github.com/shopspring/decimal,返回的是指针
|
||||
var FuncDecimalValue = func(ctx context.Context, dialect string) interface{} {
|
||||
return &decimal.Decimal{}
|
||||
}
|
||||
|
||||
// OverrideFunc 重写ZORM的函数,用于风险监控,只要查看这个函数的调用,就知道哪些地方重写了函数,避免项目混乱.当你使用这个函数时,你必须知道自己在做什么
|
||||
// funcName 是需要重写的方法命,funcObject是对应的函数. 返回值bool是否重写成功,interface{}是重写前的函数
|
||||
// 一般是在init里调用重写
|
||||
func OverrideFunc(funcName string, funcObject interface{}) (bool, interface{}, error) {
|
||||
if funcName == "" {
|
||||
return false, nil, errors.New("->OverrideFunc-->funcName不能为空")
|
||||
}
|
||||
|
||||
// oldFunc 老的函数
|
||||
var oldFunc interface{} = nil
|
||||
switch funcName {
|
||||
case "Transaction":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, doTransaction func(ctx context.Context) (interface{}, error)) (interface{}, error))
|
||||
if ok {
|
||||
oldFunc = transaction
|
||||
transaction = newFunc
|
||||
}
|
||||
case "QueryRow":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, entity interface{}) (bool, error))
|
||||
if ok {
|
||||
oldFunc = queryRow
|
||||
queryRow = newFunc
|
||||
}
|
||||
case "Query":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, rowsSlicePtr interface{}, page *Page) error)
|
||||
if ok {
|
||||
oldFunc = query
|
||||
query = newFunc
|
||||
}
|
||||
|
||||
case "QueryRowMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (map[string]interface{}, error))
|
||||
if ok {
|
||||
oldFunc = queryRowMap
|
||||
queryRowMap = newFunc
|
||||
}
|
||||
case "QueryMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder, page *Page) ([]map[string]interface{}, error))
|
||||
if ok {
|
||||
oldFunc = queryMap
|
||||
queryMap = newFunc
|
||||
}
|
||||
case "UpdateFinder":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, finder *Finder) (int, error))
|
||||
if ok {
|
||||
oldFunc = updateFinder
|
||||
updateFinder = newFunc
|
||||
}
|
||||
case "Insert":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = insert
|
||||
insert = newFunc
|
||||
}
|
||||
case "InsertSlice":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entityStructSlice []IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = insertSlice
|
||||
insertSlice = newFunc
|
||||
}
|
||||
case "Update":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = update
|
||||
update = newFunc
|
||||
}
|
||||
case "UpdateNotZeroValue":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = updateNotZeroValue
|
||||
updateNotZeroValue = newFunc
|
||||
}
|
||||
case "Delete":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityStruct) (int, error))
|
||||
if ok {
|
||||
oldFunc = delete
|
||||
delete = newFunc
|
||||
}
|
||||
|
||||
case "InsertEntityMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (int, error))
|
||||
if ok {
|
||||
oldFunc = insertEntityMap
|
||||
insertEntityMap = newFunc
|
||||
}
|
||||
case "InsertEntityMapSlice":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity []IEntityMap) (int, error))
|
||||
if ok {
|
||||
oldFunc = insertEntityMapSlice
|
||||
insertEntityMapSlice = newFunc
|
||||
}
|
||||
case "UpdateEntityMap":
|
||||
newFunc, ok := funcObject.(func(ctx context.Context, entity IEntityMap) (int, error))
|
||||
if ok {
|
||||
oldFunc = updateEntityMap
|
||||
updateEntityMap = newFunc
|
||||
}
|
||||
default:
|
||||
return false, oldFunc, errors.New("->OverrideFunc-->函数" + funcName + "暂不支持重写或不存在")
|
||||
}
|
||||
if oldFunc == nil {
|
||||
return false, oldFunc, errors.New("->OverrideFunc-->请检查传入的" + funcName + "函数实现,断言转换失败.")
|
||||
}
|
||||
return true, oldFunc, nil
|
||||
}
|
||||
|
||||
// typeConvertInt64toInt int64 转 int
|
||||
func typeConvertInt64toInt(from int64) (int, error) {
|
||||
strInt64 := strconv.FormatInt(from, 10)
|
||||
return strconv.Atoi(strInt64)
|
||||
}
|
||||
|
||||
/*
|
||||
func typeConvertFloat32(i interface{}) (float32, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if v, ok := i.(float32); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
vf, err := strconv.ParseFloat(v, 32)
|
||||
return float32(vf), err
|
||||
}
|
||||
|
||||
func typeConvertFloat64(i interface{}) (float64, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if v, ok := i.(float64); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(v, 64)
|
||||
}
|
||||
|
||||
func typeConvertDecimal(i interface{}) (decimal.Decimal, error) {
|
||||
if i == nil {
|
||||
return decimal.Zero, nil
|
||||
}
|
||||
if v, ok := i.(decimal.Decimal); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return decimal.Zero, err
|
||||
}
|
||||
return decimal.NewFromString(v)
|
||||
}
|
||||
|
||||
func typeConvertInt64(i interface{}) (int64, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if v, ok := i.(int64); ok {
|
||||
return v, nil
|
||||
}
|
||||
v, err := typeConvertInt(i)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(v), err
|
||||
}
|
||||
|
||||
func typeConvertString(i interface{}) (string, error) {
|
||||
if i == nil {
|
||||
return "", nil
|
||||
}
|
||||
switch value := i.(type) {
|
||||
case int:
|
||||
return strconv.Itoa(value), nil
|
||||
case int8:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case int16:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case int32:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case int64:
|
||||
return strconv.Itoa(int(value)), nil
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(value), 10), nil
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(value), 'f', -1, 32), nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(value, 'f', -1, 64), nil
|
||||
case bool:
|
||||
return strconv.FormatBool(value), nil
|
||||
case string:
|
||||
return value, nil
|
||||
case []byte:
|
||||
return string(value), nil
|
||||
default:
|
||||
return fmt.Sprintf("%v", value), nil
|
||||
}
|
||||
}
|
||||
|
||||
//false: "", 0, false, off
|
||||
func typeConvertBool(i interface{}) (bool, error) {
|
||||
if i == nil {
|
||||
return false, nil
|
||||
}
|
||||
if v, ok := i.(bool); ok {
|
||||
return v, nil
|
||||
}
|
||||
s, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if s != "" && s != "0" && s != "false" && s != "off" {
|
||||
return true, err
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func typeConvertInt(i interface{}) (int, error) {
|
||||
if i == nil {
|
||||
return 0, nil
|
||||
}
|
||||
switch value := i.(type) {
|
||||
case int:
|
||||
return value, nil
|
||||
case int8:
|
||||
return int(value), nil
|
||||
case int16:
|
||||
return int(value), nil
|
||||
case int32:
|
||||
return int(value), nil
|
||||
case int64:
|
||||
return int(value), nil
|
||||
case uint:
|
||||
return int(value), nil
|
||||
case uint8:
|
||||
return int(value), nil
|
||||
case uint16:
|
||||
return int(value), nil
|
||||
case uint32:
|
||||
return int(value), nil
|
||||
case uint64:
|
||||
return int(value), nil
|
||||
case float32:
|
||||
return int(value), nil
|
||||
case float64:
|
||||
return int(value), nil
|
||||
case bool:
|
||||
if value {
|
||||
return 1, nil
|
||||
}
|
||||
return 0, nil
|
||||
default:
|
||||
v, err := typeConvertString(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.Atoi(v)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
func typeConvertTime(i interface{}, format string, TZLocation ...*time.Location) (time.Time, error) {
|
||||
s, err := typeConvertString(i)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return typeConvertStrToTime(s, format, TZLocation...)
|
||||
}
|
||||
|
||||
func typeConvertStrToTime(str string, format string, TZLocation ...*time.Location) (time.Time, error) {
|
||||
if len(TZLocation) > 0 {
|
||||
return time.ParseInLocation(format, str, TZLocation[0])
|
||||
}
|
||||
return time.ParseInLocation(format, str, time.Local)
|
||||
}
|
||||
|
||||
func encodeString(s string) []byte {
|
||||
return []byte(s)
|
||||
}
|
||||
|
||||
func decodeToString(b []byte) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func encodeBool(b bool) []byte {
|
||||
if b {
|
||||
return []byte{1}
|
||||
}
|
||||
return []byte{0}
|
||||
|
||||
}
|
||||
|
||||
func encodeInt(i int) []byte {
|
||||
if i <= math.MaxInt8 {
|
||||
return encodeInt8(int8(i))
|
||||
} else if i <= math.MaxInt16 {
|
||||
return encodeInt16(int16(i))
|
||||
} else if i <= math.MaxInt32 {
|
||||
return encodeInt32(int32(i))
|
||||
} else {
|
||||
return encodeInt64(int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func encodeUint(i uint) []byte {
|
||||
if i <= math.MaxUint8 {
|
||||
return encodeUint8(uint8(i))
|
||||
} else if i <= math.MaxUint16 {
|
||||
return encodeUint16(uint16(i))
|
||||
} else if i <= math.MaxUint32 {
|
||||
return encodeUint32(uint32(i))
|
||||
} else {
|
||||
return encodeUint64(uint64(i))
|
||||
}
|
||||
}
|
||||
|
||||
func encodeInt8(i int8) []byte {
|
||||
return []byte{byte(i)}
|
||||
}
|
||||
|
||||
func encodeUint8(i uint8) []byte {
|
||||
return []byte{byte(i)}
|
||||
}
|
||||
|
||||
func encodeInt16(i int16) []byte {
|
||||
bytes := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(bytes, uint16(i))
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeUint16(i uint16) []byte {
|
||||
bytes := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(bytes, i)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeInt32(i int32) []byte {
|
||||
bytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bytes, uint32(i))
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeUint32(i uint32) []byte {
|
||||
bytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bytes, i)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeInt64(i int64) []byte {
|
||||
bytes := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(bytes, uint64(i))
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeUint64(i uint64) []byte {
|
||||
bytes := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(bytes, i)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeFloat32(f float32) []byte {
|
||||
bits := math.Float32bits(f)
|
||||
bytes := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(bytes, bits)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encodeFloat64(f float64) []byte {
|
||||
bits := math.Float64bits(f)
|
||||
bytes := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(bytes, bits)
|
||||
return bytes
|
||||
}
|
||||
|
||||
func encode(vs ...interface{}) []byte {
|
||||
buf := new(bytes.Buffer)
|
||||
for i := 0; i < len(vs); i++ {
|
||||
switch value := vs[i].(type) {
|
||||
case int:
|
||||
buf.Write(encodeInt(value))
|
||||
case int8:
|
||||
buf.Write(encodeInt8(value))
|
||||
case int16:
|
||||
buf.Write(encodeInt16(value))
|
||||
case int32:
|
||||
buf.Write(encodeInt32(value))
|
||||
case int64:
|
||||
buf.Write(encodeInt64(value))
|
||||
case uint:
|
||||
buf.Write(encodeUint(value))
|
||||
case uint8:
|
||||
buf.Write(encodeUint8(value))
|
||||
case uint16:
|
||||
buf.Write(encodeUint16(value))
|
||||
case uint32:
|
||||
buf.Write(encodeUint32(value))
|
||||
case uint64:
|
||||
buf.Write(encodeUint64(value))
|
||||
case bool:
|
||||
buf.Write(encodeBool(value))
|
||||
case string:
|
||||
buf.Write(encodeString(value))
|
||||
case []byte:
|
||||
buf.Write(value)
|
||||
case float32:
|
||||
buf.Write(encodeFloat32(value))
|
||||
case float64:
|
||||
buf.Write(encodeFloat64(value))
|
||||
default:
|
||||
if err := binary.Write(buf, binary.LittleEndian, value); err != nil {
|
||||
buf.Write(encodeString(fmt.Sprintf("%v", value)))
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func isNumeric(s string) bool {
|
||||
for i := 0; i < len(s); i++ {
|
||||
if s[i] < byte('0') || s[i] > byte('9') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
func typeConvertTimeDuration(i interface{}) time.Duration {
|
||||
return time.Duration(typeConvertInt64(i))
|
||||
}
|
||||
|
||||
func typeConvertBytes(i interface{}) []byte {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
if r, ok := i.([]byte); ok {
|
||||
return r
|
||||
}
|
||||
return encode(i)
|
||||
|
||||
}
|
||||
|
||||
func typeConvertStrings(i interface{}) []string {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
if r, ok := i.([]string); ok {
|
||||
return r
|
||||
} else if r, ok := i.([]interface{}); ok {
|
||||
strs := make([]string, len(r))
|
||||
for k, v := range r {
|
||||
strs[k] = typeConvertString(v)
|
||||
}
|
||||
return strs
|
||||
}
|
||||
return []string{fmt.Sprintf("%v", i)}
|
||||
}
|
||||
|
||||
func typeConvertInt8(i interface{}) int8 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(int8); ok {
|
||||
return v
|
||||
}
|
||||
return int8(typeConvertInt(i))
|
||||
}
|
||||
|
||||
func typeConvertInt16(i interface{}) int16 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(int16); ok {
|
||||
return v
|
||||
}
|
||||
return int16(typeConvertInt(i))
|
||||
}
|
||||
|
||||
func typeConvertInt32(i interface{}) int32 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(int32); ok {
|
||||
return v
|
||||
}
|
||||
return int32(typeConvertInt(i))
|
||||
}
|
||||
|
||||
func typeConvertUint(i interface{}) uint {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
switch value := i.(type) {
|
||||
case int:
|
||||
return uint(value)
|
||||
case int8:
|
||||
return uint(value)
|
||||
case int16:
|
||||
return uint(value)
|
||||
case int32:
|
||||
return uint(value)
|
||||
case int64:
|
||||
return uint(value)
|
||||
case uint:
|
||||
return value
|
||||
case uint8:
|
||||
return uint(value)
|
||||
case uint16:
|
||||
return uint(value)
|
||||
case uint32:
|
||||
return uint(value)
|
||||
case uint64:
|
||||
return uint(value)
|
||||
case float32:
|
||||
return uint(value)
|
||||
case float64:
|
||||
return uint(value)
|
||||
case bool:
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
v, _ := strconv.ParseUint(typeConvertString(value), 10, 64)
|
||||
return uint(v)
|
||||
}
|
||||
}
|
||||
|
||||
func typeConvertUint8(i interface{}) uint8 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint8); ok {
|
||||
return v
|
||||
}
|
||||
return uint8(typeConvertUint(i))
|
||||
}
|
||||
|
||||
func typeConvertUint16(i interface{}) uint16 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint16); ok {
|
||||
return v
|
||||
}
|
||||
return uint16(typeConvertUint(i))
|
||||
}
|
||||
|
||||
func typeConvertUint32(i interface{}) uint32 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint32); ok {
|
||||
return v
|
||||
}
|
||||
return uint32(typeConvertUint(i))
|
||||
}
|
||||
|
||||
func typeConvertUint64(i interface{}) uint64 {
|
||||
if i == nil {
|
||||
return 0
|
||||
}
|
||||
if v, ok := i.(uint64); ok {
|
||||
return v
|
||||
}
|
||||
return uint64(typeConvertUint(i))
|
||||
}
|
||||
*/
|
After Width: | Height: | Size: 7.8 KiB |
@ -0,0 +1,4 @@
|
||||
.db
|
||||
*.test
|
||||
*~
|
||||
*.swp
|
@ -0,0 +1,7 @@
|
||||
Copyright (C) 2011 Blake Mizerany
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
@ -0,0 +1,99 @@
|
||||
# pq - A pure Go postgres driver for Go's database/sql package
|
||||
|
||||
**This package is now deprecated. The up to date version is at
|
||||
[github.com/lib/pq](https://github.com/lib/pq).**
|
||||
|
||||
## Install
|
||||
|
||||
go get github.com/bmizerany/pq
|
||||
|
||||
## Docs
|
||||
|
||||
<http://go.pkgdoc.org/github.com/bmizerany/pq>
|
||||
|
||||
## Use
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "github.com/bmizerany/pq"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full")
|
||||
// ...
|
||||
}
|
||||
|
||||
**Connection String Parameters**
|
||||
|
||||
These are a subset of the libpq connection parameters. In addition, a
|
||||
number of the [environment
|
||||
variables](http://www.postgresql.org/docs/9.1/static/libpq-envars.html)
|
||||
supported by libpq are also supported. Just like libpq, these have
|
||||
lower precedence than explicitly provided connection parameters.
|
||||
|
||||
See http://www.postgresql.org/docs/9.1/static/libpq-connect.html.
|
||||
|
||||
* `dbname` - The name of the database to connect to
|
||||
* `user` - The user to sign in as
|
||||
* `password` - The user's password
|
||||
* `host` - The host to connect to. Values that start with `/` are for unix domain sockets. (default is `localhost`)
|
||||
* `port` - The port to bind to. (default is `5432`)
|
||||
* `sslmode` - Whether or not to use SSL (default is `require`, this is not the default for libpq)
|
||||
Valid values are:
|
||||
* `disable` - No SSL
|
||||
* `require` - Always SSL (skip verification)
|
||||
* `verify-full` - Always SSL (require verification)
|
||||
|
||||
See http://golang.org/pkg/database/sql to learn how to use with `pq` through the `database/sql` package.
|
||||
|
||||
## Tests
|
||||
|
||||
`go test` is used for testing. A running PostgreSQL server is
|
||||
required, with the ability to log in. The default database to connect
|
||||
to test with is "pqgotest," but it can be overridden using environment
|
||||
variables.
|
||||
|
||||
Example:
|
||||
|
||||
PGHOST=/var/run/postgresql go test pq
|
||||
|
||||
## Features
|
||||
|
||||
* SSL
|
||||
* Handles bad connections for `database/sql`
|
||||
* Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`)
|
||||
* Scan binary blobs correctly (i.e. `bytea`)
|
||||
* pq.ParseURL for converting urls to connection strings for sql.Open.
|
||||
* Many libpq compatible environment variables
|
||||
* Unix socket support
|
||||
|
||||
## Future / Things you can help with
|
||||
|
||||
* Notifications: `LISTEN`/`NOTIFY`
|
||||
* `hstore` sugar (i.e. handling hstore in `rows.Scan`)
|
||||
|
||||
## Thank you (alphabetical)
|
||||
|
||||
Some of these contributors are from the original library `bmizerany/pq.go` whose
|
||||
code still exists in here.
|
||||
|
||||
* Andy Balholm (andybalholm)
|
||||
* Ben Berkert (benburkert)
|
||||
* Bjørn Madsen (aeons)
|
||||
* Blake Gentry (bgentry)
|
||||
* Brad Fitzpatrick (bradfitz)
|
||||
* Daniel Farina (fdr)
|
||||
* Everyone at The Go Team
|
||||
* Federico Romero (federomero)
|
||||
* Heroku (heroku)
|
||||
* John Gallagher (jgallagher)
|
||||
* Kamil Kisiel (kisielk)
|
||||
* Keith Rarick (kr)
|
||||
* Marc Brinkmann (mbr)
|
||||
* Martin Olsen (martinolsen)
|
||||
* Mike Lewis (mikelikespie)
|
||||
* Ryan Smith (ryandotsmith)
|
||||
* Samuel Stauffer (samuel)
|
||||
* notedit (notedit)
|
@ -0,0 +1,80 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type readBuf []byte
|
||||
|
||||
func (b *readBuf) int32() (n int) {
|
||||
n = int(int32(binary.BigEndian.Uint32(*b)))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) oid() (n oid) {
|
||||
n = oid(binary.BigEndian.Uint32(*b))
|
||||
*b = (*b)[4:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) int16() (n int) {
|
||||
n = int(binary.BigEndian.Uint16(*b))
|
||||
*b = (*b)[2:]
|
||||
return
|
||||
}
|
||||
|
||||
var stringTerm = []byte{0}
|
||||
|
||||
func (b *readBuf) string() string {
|
||||
i := bytes.Index(*b, stringTerm)
|
||||
if i < 0 {
|
||||
errorf("invalid message format; expected string terminator")
|
||||
}
|
||||
s := (*b)[:i]
|
||||
*b = (*b)[i+1:]
|
||||
return string(s)
|
||||
}
|
||||
|
||||
func (b *readBuf) next(n int) (v []byte) {
|
||||
v = (*b)[:n]
|
||||
*b = (*b)[n:]
|
||||
return
|
||||
}
|
||||
|
||||
func (b *readBuf) byte() byte {
|
||||
return b.next(1)[0]
|
||||
}
|
||||
|
||||
type writeBuf []byte
|
||||
|
||||
func newWriteBuf(c byte) *writeBuf {
|
||||
b := make(writeBuf, 5)
|
||||
b[0] = c
|
||||
return &b
|
||||
}
|
||||
|
||||
func (b *writeBuf) int32(n int) {
|
||||
x := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(x, uint32(n))
|
||||
*b = append(*b, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) int16(n int) {
|
||||
x := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(x, uint16(n))
|
||||
*b = append(*b, x...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) string(s string) {
|
||||
*b = append(*b, (s + "\000")...)
|
||||
}
|
||||
|
||||
func (b *writeBuf) byte(c byte) {
|
||||
*b = append(*b, c)
|
||||
}
|
||||
|
||||
func (b *writeBuf) bytes(v []byte) {
|
||||
*b = append(*b, v...)
|
||||
}
|
@ -0,0 +1,678 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
|
||||
ErrNotSupported = errors.New("pq: invalid command")
|
||||
)
|
||||
|
||||
type drv struct{}
|
||||
|
||||
func (d *drv) Open(name string) (driver.Conn, error) {
|
||||
return Open(name)
|
||||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("postgres", &drv{})
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
c net.Conn
|
||||
buf *bufio.Reader
|
||||
namei int
|
||||
}
|
||||
|
||||
func Open(name string) (_ driver.Conn, err error) {
|
||||
defer errRecover(&err)
|
||||
defer errRecoverWithPGReason(&err)
|
||||
|
||||
o := make(Values)
|
||||
|
||||
// A number of defaults are applied here, in this order:
|
||||
//
|
||||
// * Very low precedence defaults applied in every situation
|
||||
// * Environment variables
|
||||
// * Explicitly passed connection information
|
||||
o.Set("host", "localhost")
|
||||
o.Set("port", "5432")
|
||||
|
||||
// Default the username, but ignore errors, because a user
|
||||
// passed in via environment variable or connection string
|
||||
// would be okay. This can result in connections failing
|
||||
// *sometimes* if the client relies on being able to determine
|
||||
// the current username and there are intermittent problems.
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
o.Set("user", u.Username)
|
||||
}
|
||||
|
||||
for k, v := range parseEnviron(os.Environ()) {
|
||||
o.Set(k, v)
|
||||
}
|
||||
|
||||
parseOpts(name, o)
|
||||
|
||||
c, err := net.Dial(network(o))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cn := &conn{c: c}
|
||||
cn.ssl(o)
|
||||
cn.buf = bufio.NewReader(cn.c)
|
||||
cn.startup(o)
|
||||
return cn, nil
|
||||
}
|
||||
|
||||
func network(o Values) (string, string) {
|
||||
host := o.Get("host")
|
||||
|
||||
if strings.HasPrefix(host, "/") {
|
||||
sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
|
||||
return "unix", sockPath
|
||||
}
|
||||
|
||||
return "tcp", host + ":" + o.Get("port")
|
||||
}
|
||||
|
||||
type Values map[string]string
|
||||
|
||||
func (vs Values) Set(k, v string) {
|
||||
vs[k] = v
|
||||
}
|
||||
|
||||
func (vs Values) Get(k string) (v string) {
|
||||
v, _ = vs[k]
|
||||
return
|
||||
}
|
||||
|
||||
func parseOpts(name string, o Values) {
|
||||
if len(name) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ps := strings.Split(name, " ")
|
||||
for _, p := range ps {
|
||||
kv := strings.Split(p, "=")
|
||||
if len(kv) < 2 {
|
||||
errorf("invalid option: %q", p)
|
||||
}
|
||||
o.Set(kv[0], kv[1])
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) Begin() (driver.Tx, error) {
|
||||
_, err := cn.Exec("BEGIN", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cn, err
|
||||
}
|
||||
|
||||
func (cn *conn) Commit() error {
|
||||
_, err := cn.Exec("COMMIT", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cn *conn) Rollback() error {
|
||||
_, err := cn.Exec("ROLLBACK", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (cn *conn) gname() string {
|
||||
cn.namei++
|
||||
return strconv.FormatInt(int64(cn.namei), 10)
|
||||
}
|
||||
|
||||
func (cn *conn) simpleQuery(q string) (res driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
b := newWriteBuf('Q')
|
||||
b.string(q)
|
||||
cn.send(b)
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case 'C':
|
||||
res = parseComplete(r.string())
|
||||
case 'Z':
|
||||
// done
|
||||
return
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'T', 'N', 'S':
|
||||
// ignore
|
||||
default:
|
||||
errorf("unknown response for simple query: %q", t)
|
||||
}
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
st := &stmt{cn: cn, name: stmtName, query: q}
|
||||
|
||||
b := newWriteBuf('P')
|
||||
b.string(st.name)
|
||||
b.string(q)
|
||||
b.int16(0)
|
||||
cn.send(b)
|
||||
|
||||
b = newWriteBuf('D')
|
||||
b.byte('S')
|
||||
b.string(st.name)
|
||||
cn.send(b)
|
||||
|
||||
cn.send(newWriteBuf('S'))
|
||||
|
||||
for {
|
||||
t, r := cn.recv1()
|
||||
switch t {
|
||||
case '1', '2', 'N':
|
||||
case 't':
|
||||
st.nparams = int(r.int16())
|
||||
st.paramTyps = make([]oid, st.nparams, st.nparams)
|
||||
|
||||
for i := 0; i < st.nparams; i += 1 {
|
||||
st.paramTyps[i] = r.oid()
|
||||
}
|
||||
case 'T':
|
||||
n := r.int16()
|
||||
st.cols = make([]string, n)
|
||||
st.rowTyps = make([]oid, n)
|
||||
for i := range st.cols {
|
||||
st.cols[i] = r.string()
|
||||
r.next(6)
|
||||
st.rowTyps[i] = r.oid()
|
||||
r.next(8)
|
||||
}
|
||||
case 'n':
|
||||
// no data
|
||||
case 'Z':
|
||||
return st, err
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
default:
|
||||
errorf("unexpected describe rows response: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) Prepare(q string) (driver.Stmt, error) {
|
||||
return cn.prepareTo(q, cn.gname())
|
||||
}
|
||||
|
||||
func (cn *conn) Close() (err error) {
|
||||
defer errRecover(&err)
|
||||
cn.send(newWriteBuf('X'))
|
||||
|
||||
return cn.c.Close()
|
||||
}
|
||||
|
||||
// Implement the optional "Execer" interface for one-shot queries
|
||||
func (cn *conn) Exec(query string, args []driver.Value) (_ driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
// Check to see if we can use the "simpleQuery" interface, which is
|
||||
// *much* faster than going through prepare/exec
|
||||
if len(args) == 0 {
|
||||
return cn.simpleQuery(query)
|
||||
}
|
||||
|
||||
// Use the unnamed statement to defer planning until bind
|
||||
// time, or else value-based selectivity estimates cannot be
|
||||
// used.
|
||||
st, err := cn.prepareTo(query, "")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
r, err := st.Exec(args)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
||||
|
||||
// Assumes len(*m) is > 5
|
||||
func (cn *conn) send(m *writeBuf) {
|
||||
b := (*m)[1:]
|
||||
binary.BigEndian.PutUint32(b, uint32(len(b)))
|
||||
|
||||
if (*m)[0] == 0 {
|
||||
*m = b
|
||||
}
|
||||
|
||||
_, err := cn.c.Write(*m)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) recv() (t byte, r *readBuf) {
|
||||
for {
|
||||
t, r = cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
panic(parseError(r))
|
||||
case 'N':
|
||||
// ignore
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (cn *conn) recv1() (byte, *readBuf) {
|
||||
x := make([]byte, 5)
|
||||
_, err := io.ReadFull(cn.buf, x)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
b := readBuf(x[1:])
|
||||
y := make([]byte, b.int32()-4)
|
||||
_, err = io.ReadFull(cn.buf, y)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return x[0], (*readBuf)(&y)
|
||||
}
|
||||
|
||||
func (cn *conn) ssl(o Values) {
|
||||
tlsConf := tls.Config{}
|
||||
switch mode := o.Get("sslmode"); mode {
|
||||
case "require", "":
|
||||
tlsConf.InsecureSkipVerify = true
|
||||
case "verify-full":
|
||||
// fall out
|
||||
case "disable":
|
||||
return
|
||||
default:
|
||||
errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode)
|
||||
}
|
||||
|
||||
w := newWriteBuf(0)
|
||||
w.int32(80877103)
|
||||
cn.send(w)
|
||||
|
||||
b := make([]byte, 1)
|
||||
_, err := io.ReadFull(cn.c, b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if b[0] != 'S' {
|
||||
panic(ErrSSLNotSupported)
|
||||
}
|
||||
|
||||
cn.c = tls.Client(cn.c, &tlsConf)
|
||||
}
|
||||
|
||||
func (cn *conn) startup(o Values) {
|
||||
w := newWriteBuf(0)
|
||||
w.int32(196608)
|
||||
w.string("user")
|
||||
w.string(o.Get("user"))
|
||||
w.string("database")
|
||||
w.string(o.Get("dbname"))
|
||||
w.string("")
|
||||
cn.send(w)
|
||||
|
||||
for {
|
||||
t, r := cn.recv()
|
||||
switch t {
|
||||
case 'K', 'S':
|
||||
case 'R':
|
||||
cn.auth(r, o)
|
||||
case 'Z':
|
||||
return
|
||||
default:
|
||||
errorf("unknown response for startup: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *conn) auth(r *readBuf, o Values) {
|
||||
switch code := r.int32(); code {
|
||||
case 0:
|
||||
// OK
|
||||
case 3:
|
||||
w := newWriteBuf('p')
|
||||
w.string(o.Get("password"))
|
||||
cn.send(w)
|
||||
|
||||
t, r := cn.recv()
|
||||
if t != 'R' {
|
||||
errorf("unexpected password response: %q", t)
|
||||
}
|
||||
|
||||
if r.int32() != 0 {
|
||||
errorf("unexpected authentication response: %q", t)
|
||||
}
|
||||
case 5:
|
||||
s := string(r.next(4))
|
||||
w := newWriteBuf('p')
|
||||
w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
|
||||
cn.send(w)
|
||||
|
||||
t, r := cn.recv()
|
||||
if t != 'R' {
|
||||
errorf("unexpected password response: %q", t)
|
||||
}
|
||||
|
||||
if r.int32() != 0 {
|
||||
errorf("unexpected authentication resoonse: %q", t)
|
||||
}
|
||||
default:
|
||||
errorf("unknown authentication response: %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
cn *conn
|
||||
name string
|
||||
query string
|
||||
cols []string
|
||||
nparams int
|
||||
rowTyps []oid
|
||||
paramTyps []oid
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (st *stmt) Close() (err error) {
|
||||
if st.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer errRecover(&err)
|
||||
|
||||
w := newWriteBuf('C')
|
||||
w.byte('S')
|
||||
w.string(st.name)
|
||||
st.cn.send(w)
|
||||
|
||||
st.cn.send(newWriteBuf('S'))
|
||||
|
||||
t, _ := st.cn.recv()
|
||||
if t != '3' {
|
||||
errorf("unexpected close response: %q", t)
|
||||
}
|
||||
st.closed = true
|
||||
|
||||
t, _ = st.cn.recv()
|
||||
if t != 'Z' {
|
||||
errorf("expected ready for query, but got: %q", t)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *stmt) Query(v []driver.Value) (_ driver.Rows, err error) {
|
||||
defer errRecover(&err)
|
||||
st.exec(v)
|
||||
return &rows{st: st}, nil
|
||||
}
|
||||
|
||||
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
|
||||
defer errRecover(&err)
|
||||
|
||||
if len(v) == 0 {
|
||||
return st.cn.simpleQuery(st.query)
|
||||
}
|
||||
st.exec(v)
|
||||
|
||||
for {
|
||||
t, r := st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'C':
|
||||
res = parseComplete(r.string())
|
||||
case 'Z':
|
||||
// done
|
||||
return
|
||||
case 'D':
|
||||
errorf("unexpected data row returned in Exec; check your query")
|
||||
case 'S', 'N':
|
||||
// Ignore
|
||||
default:
|
||||
errorf("unknown exec response: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (st *stmt) exec(v []driver.Value) {
|
||||
w := newWriteBuf('B')
|
||||
w.string("")
|
||||
w.string(st.name)
|
||||
w.int16(0)
|
||||
w.int16(len(v))
|
||||
for i, x := range v {
|
||||
if x == nil {
|
||||
w.int32(-1)
|
||||
} else {
|
||||
b := encode(x, st.paramTyps[i])
|
||||
w.int32(len(b))
|
||||
w.bytes(b)
|
||||
}
|
||||
}
|
||||
w.int16(0)
|
||||
st.cn.send(w)
|
||||
|
||||
w = newWriteBuf('E')
|
||||
w.string("")
|
||||
w.int32(0)
|
||||
st.cn.send(w)
|
||||
|
||||
st.cn.send(newWriteBuf('S'))
|
||||
|
||||
var err error
|
||||
for {
|
||||
t, r := st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case '2':
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
case 'Z':
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return
|
||||
case 'N':
|
||||
// ignore
|
||||
default:
|
||||
errorf("unexpected bind response: %q", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (st *stmt) NumInput() int {
|
||||
return st.nparams
|
||||
}
|
||||
|
||||
type result int64
|
||||
|
||||
func (i result) RowsAffected() (int64, error) {
|
||||
return int64(i), nil
|
||||
}
|
||||
|
||||
func (i result) LastInsertId() (int64, error) {
|
||||
return 0, ErrNotSupported
|
||||
}
|
||||
|
||||
func parseComplete(s string) driver.Result {
|
||||
parts := strings.Split(s, " ")
|
||||
n, _ := strconv.ParseInt(parts[len(parts)-1], 10, 64)
|
||||
return result(n)
|
||||
}
|
||||
|
||||
type rows struct {
|
||||
st *stmt
|
||||
done bool
|
||||
}
|
||||
|
||||
func (rs *rows) Close() error {
|
||||
for {
|
||||
err := rs.Next(nil)
|
||||
switch err {
|
||||
case nil:
|
||||
case io.EOF:
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func (rs *rows) Columns() []string {
|
||||
return rs.st.cols
|
||||
}
|
||||
|
||||
func (rs *rows) Next(dest []driver.Value) (err error) {
|
||||
if rs.done {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
defer errRecover(&err)
|
||||
|
||||
for {
|
||||
t, r := rs.st.cn.recv1()
|
||||
switch t {
|
||||
case 'E':
|
||||
err = parseError(r)
|
||||
case 'C', 'S', 'N':
|
||||
continue
|
||||
case 'Z':
|
||||
rs.done = true
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
case 'D':
|
||||
n := r.int16()
|
||||
for i := 0; i < len(dest) && i < n; i++ {
|
||||
l := r.int32()
|
||||
if l == -1 {
|
||||
dest[i] = nil
|
||||
continue
|
||||
}
|
||||
dest[i] = decode(r.next(l), rs.st.rowTyps[i])
|
||||
}
|
||||
return
|
||||
default:
|
||||
errorf("unexpected message after execute: %q", t)
|
||||
}
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func md5s(s string) string {
|
||||
h := md5.New()
|
||||
h.Write([]byte(s))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
|
||||
// parseEnviron tries to mimic some of libpq's environment handling
|
||||
//
|
||||
// To ease testing, it does not directly reference os.Environ, but is
|
||||
// designed to accept its output.
|
||||
//
|
||||
// Environment-set connection information is intended to have a higher
|
||||
// precedence than a library default but lower than any explicitly
|
||||
// passed information (such as in the URL or connection string).
|
||||
func parseEnviron(env []string) (out map[string]string) {
|
||||
out = make(map[string]string)
|
||||
|
||||
for _, v := range env {
|
||||
parts := strings.SplitN(v, "=", 2)
|
||||
|
||||
accrue := func(keyname string) {
|
||||
out[keyname] = parts[1]
|
||||
}
|
||||
|
||||
// The order of these is the same as is seen in the
|
||||
// PostgreSQL 9.1 manual, with omissions briefly
|
||||
// noted.
|
||||
switch parts[0] {
|
||||
case "PGHOST":
|
||||
accrue("host")
|
||||
case "PGHOSTADDR":
|
||||
accrue("hostaddr")
|
||||
case "PGPORT":
|
||||
accrue("port")
|
||||
case "PGDATABASE":
|
||||
accrue("dbname")
|
||||
case "PGUSER":
|
||||
accrue("user")
|
||||
case "PGPASSWORD":
|
||||
accrue("password")
|
||||
// skip PGPASSFILE, PGSERVICE, PGSERVICEFILE,
|
||||
// PGREALM
|
||||
case "PGOPTIONS":
|
||||
accrue("options")
|
||||
case "PGAPPNAME":
|
||||
accrue("application_name")
|
||||
case "PGSSLMODE":
|
||||
accrue("sslmode")
|
||||
case "PGREQUIRESSL":
|
||||
accrue("requiressl")
|
||||
case "PGSSLCERT":
|
||||
accrue("sslcert")
|
||||
case "PGSSLKEY":
|
||||
accrue("sslkey")
|
||||
case "PGSSLROOTCERT":
|
||||
accrue("sslrootcert")
|
||||
case "PGSSLCRL":
|
||||
accrue("sslcrl")
|
||||
case "PGREQUIREPEER":
|
||||
accrue("requirepeer")
|
||||
case "PGKRBSRVNAME":
|
||||
accrue("krbsrvname")
|
||||
case "PGGSSLIB":
|
||||
accrue("gsslib")
|
||||
case "PGCONNECT_TIMEOUT":
|
||||
accrue("connect_timeout")
|
||||
case "PGCLIENTENCODING":
|
||||
accrue("client_encoding")
|
||||
// skip PGDATESTYLE, PGTZ, PGGEQO, PGSYSCONFDIR,
|
||||
// PGLOCALEDIR
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
@ -0,0 +1,121 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func encode(x interface{}, pgtypoid oid) []byte {
|
||||
switch v := x.(type) {
|
||||
case int64:
|
||||
return []byte(fmt.Sprintf("%d", v))
|
||||
case float32, float64:
|
||||
return []byte(fmt.Sprintf("%f", v))
|
||||
case []byte:
|
||||
if pgtypoid == t_bytea {
|
||||
return []byte(fmt.Sprintf("\\x%x", v))
|
||||
}
|
||||
|
||||
return v
|
||||
case string:
|
||||
if pgtypoid == t_bytea {
|
||||
return []byte(fmt.Sprintf("\\x%x", v))
|
||||
}
|
||||
|
||||
return []byte(v)
|
||||
case bool:
|
||||
return []byte(fmt.Sprintf("%t", v))
|
||||
case time.Time:
|
||||
return []byte(v.Format(time.RFC3339Nano))
|
||||
default:
|
||||
errorf("encode: unknown type for %T", v)
|
||||
}
|
||||
|
||||
panic("not reached")
|
||||
}
|
||||
|
||||
func decode(s []byte, typ oid) interface{} {
|
||||
switch typ {
|
||||
case t_bytea:
|
||||
s = s[2:] // trim off "\\x"
|
||||
d := make([]byte, hex.DecodedLen(len(s)))
|
||||
_, err := hex.Decode(d, s)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return d
|
||||
case t_timestamptz:
|
||||
return mustParse("2006-01-02 15:04:05-07", typ, s)
|
||||
case t_timestamp:
|
||||
return mustParse("2006-01-02 15:04:05", typ, s)
|
||||
case t_time:
|
||||
return mustParse("15:04:05", typ, s)
|
||||
case t_timetz:
|
||||
return mustParse("15:04:05-07", typ, s)
|
||||
case t_date:
|
||||
return mustParse("2006-01-02", typ, s)
|
||||
case t_bool:
|
||||
return s[0] == 't'
|
||||
case t_int8, t_int2, t_int4:
|
||||
i, err := strconv.ParseInt(string(s), 10, 64)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return i
|
||||
case t_float4, t_float8:
|
||||
bits := 64
|
||||
if typ == t_float4 {
|
||||
bits = 32
|
||||
}
|
||||
f, err := strconv.ParseFloat(string(s), bits)
|
||||
if err != nil {
|
||||
errorf("%s", err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func mustParse(f string, typ oid, s []byte) time.Time {
|
||||
str := string(s)
|
||||
|
||||
// Special case until time.Parse bug is fixed:
|
||||
// http://code.google.com/p/go/issues/detail?id=3487
|
||||
if str[len(str)-2] == '.' {
|
||||
str += "0"
|
||||
}
|
||||
|
||||
// check for a 30-minute-offset timezone
|
||||
if (typ == t_timestamptz || typ == t_timetz) &&
|
||||
str[len(str)-3] == ':' {
|
||||
f += ":00"
|
||||
}
|
||||
t, err := time.Parse(f, str)
|
||||
if err != nil {
|
||||
errorf("decode: %s", err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
nt.Time, nt.Valid = value.(time.Time)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
@ -0,0 +1,108 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
const (
|
||||
Efatal = "FATAL"
|
||||
Epanic = "PANIC"
|
||||
Ewarning = "WARNING"
|
||||
Enotice = "NOTICE"
|
||||
Edebug = "DEBUG"
|
||||
Einfo = "INFO"
|
||||
Elog = "LOG"
|
||||
)
|
||||
|
||||
type Error error
|
||||
|
||||
type PGError interface {
|
||||
Error() string
|
||||
Fatal() bool
|
||||
Get(k byte) (v string)
|
||||
}
|
||||
type pgError struct {
|
||||
c map[byte]string
|
||||
}
|
||||
|
||||
func parseError(r *readBuf) *pgError {
|
||||
err := &pgError{make(map[byte]string)}
|
||||
for t := r.byte(); t != 0; t = r.byte() {
|
||||
err.c[t] = r.string()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (err *pgError) Get(k byte) (v string) {
|
||||
v, _ = err.c[k]
|
||||
return
|
||||
}
|
||||
|
||||
func (err *pgError) Fatal() bool {
|
||||
return err.Get('S') == Efatal
|
||||
}
|
||||
|
||||
func (err *pgError) Error() string {
|
||||
var s string
|
||||
for k, v := range err.c {
|
||||
s += fmt.Sprintf(" %c:%q", k, v)
|
||||
}
|
||||
return "pq: " + s[1:]
|
||||
}
|
||||
|
||||
func errorf(s string, args ...interface{}) {
|
||||
panic(Error(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))))
|
||||
}
|
||||
|
||||
type SimplePGError struct {
|
||||
pgError
|
||||
}
|
||||
|
||||
func (err *SimplePGError) Error() string {
|
||||
return "pq: " + err.Get('M')
|
||||
}
|
||||
|
||||
func errRecoverWithPGReason(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case *pgError:
|
||||
// Return a SimplePGError in place
|
||||
*err = &SimplePGError{*v}
|
||||
default:
|
||||
// Otherwise re-panic
|
||||
panic(e)
|
||||
}
|
||||
}
|
||||
|
||||
func errRecover(err *error) {
|
||||
e := recover()
|
||||
switch v := e.(type) {
|
||||
case nil:
|
||||
// Do nothing
|
||||
case runtime.Error:
|
||||
panic(v)
|
||||
case *pgError:
|
||||
if v.Fatal() {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
case *net.OpError:
|
||||
*err = driver.ErrBadConn
|
||||
case error:
|
||||
if v == io.EOF || v.(error).Error() == "remote error: handshake failure" {
|
||||
*err = driver.ErrBadConn
|
||||
} else {
|
||||
*err = v
|
||||
}
|
||||
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown error: %#v", e))
|
||||
}
|
||||
}
|
@ -0,0 +1,319 @@
|
||||
package pq
|
||||
|
||||
type oid uint32
|
||||
|
||||
const (
|
||||
t_bool oid = 16
|
||||
t_bytea = 17
|
||||
t_char = 18
|
||||
t_name = 19
|
||||
t_int8 = 20
|
||||
t_int2 = 21
|
||||
t_int2vector = 22
|
||||
t_int4 = 23
|
||||
t_regproc = 24
|
||||
t_text = 25
|
||||
t_oid = 26
|
||||
t_tid = 27
|
||||
t_xid = 28
|
||||
t_cid = 29
|
||||
t_oidvector = 30
|
||||
t_pg_type = 71
|
||||
t_pg_attribute = 75
|
||||
t_pg_proc = 81
|
||||
t_pg_class = 83
|
||||
t_xml = 142
|
||||
t__xml = 143
|
||||
t_pg_node_tree = 194
|
||||
t_smgr = 210
|
||||
t_point = 600
|
||||
t_lseg = 601
|
||||
t_path = 602
|
||||
t_box = 603
|
||||
t_polygon = 604
|
||||
t_line = 628
|
||||
t__line = 629
|
||||
t_float4 = 700
|
||||
t_float8 = 701
|
||||
t_abstime = 702
|
||||
t_reltime = 703
|
||||
t_tinterval = 704
|
||||
t_unknown = 705
|
||||
t_circle = 718
|
||||
t__circle = 719
|
||||
t_money = 790
|
||||
t__money = 791
|
||||
t_macaddr = 829
|
||||
t_inet = 869
|
||||
t_cidr = 650
|
||||
t__bool = 1000
|
||||
t__bytea = 1001
|
||||
t__char = 1002
|
||||
t__name = 1003
|
||||
t__int2 = 1005
|
||||
t__int2vector = 1006
|
||||
t__int4 = 1007
|
||||
t__regproc = 1008
|
||||
t__text = 1009
|
||||
t__oid = 1028
|
||||
t__tid = 1010
|
||||
t__xid = 1011
|
||||
t__cid = 1012
|
||||
t__oidvector = 1013
|
||||
t__bpchar = 1014
|
||||
t__varchar = 1015
|
||||
t__int8 = 1016
|
||||
t__point = 1017
|
||||
t__lseg = 1018
|
||||
t__path = 1019
|
||||
t__box = 1020
|
||||
t__float4 = 1021
|
||||
t__float8 = 1022
|
||||
t__abstime = 1023
|
||||
t__reltime = 1024
|
||||
t__tinterval = 1025
|
||||
t__polygon = 1027
|
||||
t_aclitem = 1033
|
||||
t__aclitem = 1034
|
||||
t__macaddr = 1040
|
||||
t__inet = 1041
|
||||
t__cidr = 651
|
||||
t__cstring = 1263
|
||||
t_bpchar = 1042
|
||||
t_varchar = 1043
|
||||
t_date = 1082
|
||||
t_time = 1083
|
||||
t_timestamp = 1114
|
||||
t__timestamp = 1115
|
||||
t__date = 1182
|
||||
t__time = 1183
|
||||
t_timestamptz = 1184
|
||||
t__timestamptz = 1185
|
||||
t_interval = 1186
|
||||
t__interval = 1187
|
||||
t__numeric = 1231
|
||||
t_timetz = 1266
|
||||
t__timetz = 1270
|
||||
t_bit = 1560
|
||||
t__bit = 1561
|
||||
t_varbit = 1562
|
||||
t__varbit = 1563
|
||||
t_numeric = 1700
|
||||
t_refcursor = 1790
|
||||
t__refcursor = 2201
|
||||
t_regprocedure = 2202
|
||||
t_regoper = 2203
|
||||
t_regoperator = 2204
|
||||
t_regclass = 2205
|
||||
t_regtype = 2206
|
||||
t__regprocedure = 2207
|
||||
t__regoper = 2208
|
||||
t__regoperator = 2209
|
||||
t__regclass = 2210
|
||||
t__regtype = 2211
|
||||
t_uuid = 2950
|
||||
t__uuid = 2951
|
||||
t_tsvector = 3614
|
||||
t_gtsvector = 3642
|
||||
t_tsquery = 3615
|
||||
t_regconfig = 3734
|
||||
t_regdictionary = 3769
|
||||
t__tsvector = 3643
|
||||
t__gtsvector = 3644
|
||||
t__tsquery = 3645
|
||||
t__regconfig = 3735
|
||||
t__regdictionary = 3770
|
||||
t_txid_snapshot = 2970
|
||||
t__txid_snapshot = 2949
|
||||
t_record = 2249
|
||||
t__record = 2287
|
||||
t_cstring = 2275
|
||||
t_any = 2276
|
||||
t_anyarray = 2277
|
||||
t_void = 2278
|
||||
t_trigger = 2279
|
||||
t_language_handler = 2280
|
||||
t_internal = 2281
|
||||
t_opaque = 2282
|
||||
t_anyelement = 2283
|
||||
t_anynonarray = 2776
|
||||
t_anyenum = 3500
|
||||
t_fdw_handler = 3115
|
||||
t_pg_attrdef = 10000
|
||||
t_pg_constraint = 10001
|
||||
t_pg_inherits = 10002
|
||||
t_pg_index = 10003
|
||||
t_pg_operator = 10004
|
||||
t_pg_opfamily = 10005
|
||||
t_pg_opclass = 10006
|
||||
t_pg_am = 10117
|
||||
t_pg_amop = 10118
|
||||
t_pg_amproc = 10478
|
||||
t_pg_language = 10731
|
||||
t_pg_largeobject_metadata = 10732
|
||||
t_pg_largeobject = 10733
|
||||
t_pg_aggregate = 10734
|
||||
t_pg_statistic = 10735
|
||||
t_pg_rewrite = 10736
|
||||
t_pg_trigger = 10737
|
||||
t_pg_description = 10738
|
||||
t_pg_cast = 10739
|
||||
t_pg_enum = 10936
|
||||
t_pg_namespace = 10937
|
||||
t_pg_conversion = 10938
|
||||
t_pg_depend = 10939
|
||||
t_pg_database = 1248
|
||||
t_pg_db_role_setting = 10940
|
||||
t_pg_tablespace = 10941
|
||||
t_pg_pltemplate = 10942
|
||||
t_pg_authid = 2842
|
||||
t_pg_auth_members = 2843
|
||||
t_pg_shdepend = 10943
|
||||
t_pg_shdescription = 10944
|
||||
t_pg_ts_config = 10945
|
||||
t_pg_ts_config_map = 10946
|
||||
t_pg_ts_dict = 10947
|
||||
t_pg_ts_parser = 10948
|
||||
t_pg_ts_template = 10949
|
||||
t_pg_extension = 10950
|
||||
t_pg_foreign_data_wrapper = 10951
|
||||
t_pg_foreign_server = 10952
|
||||
t_pg_user_mapping = 10953
|
||||
t_pg_foreign_table = 10954
|
||||
t_pg_default_acl = 10955
|
||||
t_pg_seclabel = 10956
|
||||
t_pg_collation = 10957
|
||||
t_pg_toast_2604 = 10958
|
||||
t_pg_toast_2606 = 10959
|
||||
t_pg_toast_2609 = 10960
|
||||
t_pg_toast_1255 = 10961
|
||||
t_pg_toast_2618 = 10962
|
||||
t_pg_toast_3596 = 10963
|
||||
t_pg_toast_2619 = 10964
|
||||
t_pg_toast_2620 = 10965
|
||||
t_pg_toast_1262 = 10966
|
||||
t_pg_toast_2396 = 10967
|
||||
t_pg_toast_2964 = 10968
|
||||
t_pg_roles = 10970
|
||||
t_pg_shadow = 10973
|
||||
t_pg_group = 10976
|
||||
t_pg_user = 10979
|
||||
t_pg_rules = 10982
|
||||
t_pg_views = 10986
|
||||
t_pg_tables = 10989
|
||||
t_pg_indexes = 10993
|
||||
t_pg_stats = 10997
|
||||
t_pg_locks = 11001
|
||||
t_pg_cursors = 11004
|
||||
t_pg_available_extensions = 11007
|
||||
t_pg_available_extension_versions = 11010
|
||||
t_pg_prepared_xacts = 11013
|
||||
t_pg_prepared_statements = 11017
|
||||
t_pg_seclabels = 11020
|
||||
t_pg_settings = 11024
|
||||
t_pg_timezone_abbrevs = 11029
|
||||
t_pg_timezone_names = 11032
|
||||
t_pg_stat_all_tables = 11035
|
||||
t_pg_stat_xact_all_tables = 11039
|
||||
t_pg_stat_sys_tables = 11043
|
||||
t_pg_stat_xact_sys_tables = 11047
|
||||
t_pg_stat_user_tables = 11050
|
||||
t_pg_stat_xact_user_tables = 11054
|
||||
t_pg_statio_all_tables = 11057
|
||||
t_pg_statio_sys_tables = 11061
|
||||
t_pg_statio_user_tables = 11064
|
||||
t_pg_stat_all_indexes = 11067
|
||||
t_pg_stat_sys_indexes = 11071
|
||||
t_pg_stat_user_indexes = 11074
|
||||
t_pg_statio_all_indexes = 11077
|
||||
t_pg_statio_sys_indexes = 11081
|
||||
t_pg_statio_user_indexes = 11084
|
||||
t_pg_statio_all_sequences = 11087
|
||||
t_pg_statio_sys_sequences = 11090
|
||||
t_pg_statio_user_sequences = 11093
|
||||
t_pg_stat_activity = 11096
|
||||
t_pg_stat_replication = 11099
|
||||
t_pg_stat_database = 11102
|
||||
t_pg_stat_database_conflicts = 11105
|
||||
t_pg_stat_user_functions = 11108
|
||||
t_pg_stat_xact_user_functions = 11112
|
||||
t_pg_stat_bgwriter = 11116
|
||||
t_pg_user_mappings = 11119
|
||||
t_cardinal_number = 11669
|
||||
t_character_data = 11671
|
||||
t_sql_identifier = 11672
|
||||
t_information_schema_catalog_name = 11674
|
||||
t_time_stamp = 11676
|
||||
t_yes_or_no = 11677
|
||||
t_applicable_roles = 11680
|
||||
t_administrable_role_authorizations = 11684
|
||||
t_attributes = 11687
|
||||
t_character_sets = 11691
|
||||
t_check_constraint_routine_usage = 11695
|
||||
t_check_constraints = 11699
|
||||
t_collations = 11703
|
||||
t_collation_character_set_applicability = 11706
|
||||
t_column_domain_usage = 11709
|
||||
t_column_privileges = 11713
|
||||
t_column_udt_usage = 11717
|
||||
t_columns = 11721
|
||||
t_constraint_column_usage = 11725
|
||||
t_constraint_table_usage = 11729
|
||||
t_domain_constraints = 11733
|
||||
t_domain_udt_usage = 11737
|
||||
t_domains = 11740
|
||||
t_enabled_roles = 11744
|
||||
t_key_column_usage = 11747
|
||||
t_parameters = 11751
|
||||
t_referential_constraints = 11755
|
||||
t_role_column_grants = 11759
|
||||
t_routine_privileges = 11762
|
||||
t_role_routine_grants = 11766
|
||||
t_routines = 11769
|
||||
t_schemata = 11773
|
||||
t_sequences = 11776
|
||||
t_sql_features = 11780
|
||||
t_pg_toast_11779 = 11782
|
||||
t_sql_implementation_info = 11785
|
||||
t_pg_toast_11784 = 11787
|
||||
t_sql_languages = 11790
|
||||
t_pg_toast_11789 = 11792
|
||||
t_sql_packages = 11795
|
||||
t_pg_toast_11794 = 11797
|
||||
t_sql_parts = 11800
|
||||
t_pg_toast_11799 = 11802
|
||||
t_sql_sizing = 11805
|
||||
t_pg_toast_11804 = 11807
|
||||
t_sql_sizing_profiles = 11810
|
||||
t_pg_toast_11809 = 11812
|
||||
t_table_constraints = 11815
|
||||
t_table_privileges = 11819
|
||||
t_role_table_grants = 11823
|
||||
t_tables = 11826
|
||||
t_triggered_update_columns = 11830
|
||||
t_triggers = 11834
|
||||
t_usage_privileges = 11838
|
||||
t_role_usage_grants = 11842
|
||||
t_view_column_usage = 11845
|
||||
t_view_routine_usage = 11849
|
||||
t_view_table_usage = 11853
|
||||
t_views = 11857
|
||||
t_data_type_privileges = 11861
|
||||
t_element_types = 11865
|
||||
t__pg_foreign_data_wrappers = 11869
|
||||
t_foreign_data_wrapper_options = 11872
|
||||
t_foreign_data_wrappers = 11875
|
||||
t__pg_foreign_servers = 11878
|
||||
t_foreign_server_options = 11882
|
||||
t_foreign_servers = 11885
|
||||
t__pg_foreign_tables = 11888
|
||||
t_foreign_table_options = 11892
|
||||
t_foreign_tables = 11895
|
||||
t__pg_user_mappings = 11898
|
||||
t_user_mapping_options = 11901
|
||||
t_user_mappings = 11905
|
||||
t_t = 16806
|
||||
t__t = 16805
|
||||
t_temp = 16810
|
||||
t__temp = 16809
|
||||
)
|
@ -0,0 +1,68 @@
|
||||
package pq
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
nurl "net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseURL converts url to a connection string for driver.Open.
|
||||
// Example:
|
||||
//
|
||||
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
|
||||
//
|
||||
// converts to:
|
||||
//
|
||||
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
|
||||
//
|
||||
// A minimal example:
|
||||
//
|
||||
// "postgres://"
|
||||
//
|
||||
// This will be blank, causing driver.Open to use all of the defaults
|
||||
func ParseURL(url string) (string, error) {
|
||||
u, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if u.Scheme != "postgres" {
|
||||
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
|
||||
}
|
||||
|
||||
var kvs []string
|
||||
accrue := func(k, v string) {
|
||||
if v != "" {
|
||||
kvs = append(kvs, k+"="+v)
|
||||
}
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
v := u.User.Username()
|
||||
accrue("user", v)
|
||||
|
||||
v, _ = u.User.Password()
|
||||
accrue("password", v)
|
||||
}
|
||||
|
||||
i := strings.Index(u.Host, ":")
|
||||
if i < 0 {
|
||||
accrue("host", u.Host)
|
||||
} else {
|
||||
accrue("host", u.Host[:i])
|
||||
accrue("port", u.Host[i+1:])
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
accrue("dbname", u.Path[1:])
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
for k, _ := range q {
|
||||
accrue(k, q.Get(k))
|
||||
}
|
||||
|
||||
sort.Strings(kvs) // Makes testing easier (not a performance concern)
|
||||
return strings.Join(kvs, " "), nil
|
||||
}
|
Loading…
Reference in new issue