Compare commits

...

31 Commits

3
.gitignore vendored

@ -7,4 +7,5 @@
*_test.go
*.txt
*.text
gomod.sh
go_mod.sh
/vendor/

@ -1,8 +1,8 @@
<h1>
<a href="https://www.dtapp.net/">Golang Dorm</a>
<a href="https://www.dtapp.net/">Golang</a>
</h1>
📦 Golang Dorm
📦 Golang 数据库
[comment]: <> (go)
[![godoc](https://pkg.go.dev/badge/go.dtapp.net/dorm?status.svg)](https://pkg.go.dev/go.dtapp.net/dorm)
@ -13,5 +13,5 @@
#### 安装
```shell
go get -v -u go.dtapp.net/dorm
go get -v -u go.dtapp.net/dorm@v1.0.59
```

@ -1,9 +0,0 @@
package dorm
import (
"github.com/beego/beego/v2/client/orm"
)
type BeegoClient struct {
Db orm.Ormer // 驱动
}

@ -1,19 +0,0 @@
package dorm
import (
"github.com/beego/beego/v2/client/orm"
_ "github.com/go-sql-driver/mysql"
)
func NewBeegoMysqlClient(dns string) (*BeegoClient, error) {
c := &BeegoClient{}
err := orm.RegisterDataBase("default", "mysql", dns)
if err != nil {
return nil, err
}
c.Db = orm.NewOrm()
return c, err
}

@ -1,19 +0,0 @@
package dorm
import (
"github.com/beego/beego/v2/client/orm"
_ "github.com/lib/pq"
)
func NewBeegoPostgresClient(dns string) (*BeegoClient, error) {
c := &BeegoClient{}
err := orm.RegisterDataBase("default", "postgres", dns)
if err != nil {
return nil, err
}
c.Db = orm.NewOrm()
return c, err
}

@ -1,16 +0,0 @@
package dorm
import (
"github.com/uptrace/bun"
)
type ConfigBunClient struct {
Dns string // 地址
}
// BunClient
// https://bun.uptrace.dev/
type BunClient struct {
Db *bun.DB // 驱动
config *ConfigBunClient // 配置
}

@ -1,10 +0,0 @@
package dorm
import (
"github.com/uptrace/bun"
)
// GetDb 获取驱动
func (c *BunClient) GetDb() *bun.DB {
return c.Db
}

@ -1,25 +0,0 @@
package dorm
import (
"database/sql"
"errors"
"fmt"
_ "github.com/go-sql-driver/mysql"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/mysqldialect"
)
func NewBunMysqlClient(config *ConfigBunClient) (*BunClient, error) {
var err error
c := &BunClient{config: config}
sqlDb, err := sql.Open("mysql", c.config.Dns)
if err != nil {
return nil, errors.New(fmt.Sprintf("加载驱动失败:%v", err))
}
c.Db = bun.NewDB(sqlDb, mysqldialect.New())
return c, nil
}

@ -1,19 +0,0 @@
package dorm
import (
"database/sql"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
func NewBunPgsqlClient(config *ConfigBunClient) (*BunClient, error) {
c := &BunClient{config: config}
sqlDb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(c.config.Dns)))
c.Db = bun.NewDB(sqlDb, pgdialect.New())
return c, nil
}

@ -1,3 +0,0 @@
package dorm
const Version = "1.0.33"

@ -1 +0,0 @@
package dorm

@ -1 +0,0 @@
package dorm

110
go.mod

@ -1,87 +1,57 @@
module go.dtapp.net/dorm
go 1.19
go 1.22.1
require (
github.com/basgys/goxml2json v1.1.0
github.com/beego/beego/v2 v2.0.5
github.com/go-redis/redis/v9 v9.0.0-beta.2
github.com/go-sql-driver/mysql v1.6.0
github.com/kamva/mgm/v3 v3.4.1
github.com/lib/pq v1.10.7
github.com/qiniu/qmgo v1.1.2
github.com/rs/xid v1.4.0
github.com/sirupsen/logrus v1.9.0
github.com/upper/db/v4 v4.6.0
github.com/uptrace/bun v1.1.8
github.com/uptrace/bun/dialect/mysqldialect v1.1.8
github.com/uptrace/bun/dialect/pgdialect v1.1.8
github.com/uptrace/bun/driver/pgdriver v1.1.8
go.dtapp.net/gotime v1.0.5
go.dtapp.net/goxml v1.0.1
go.mongodb.org/mongo-driver v1.10.2
gorm.io/driver/mysql v1.3.6
gorm.io/driver/postgres v1.3.9
gorm.io/gorm v1.23.8
xorm.io/xorm v1.3.2
github.com/redis/go-redis/v9 v9.5.1
github.com/sirupsen/logrus v1.9.3
go.dtapp.net/gojson v1.0.4
go.dtapp.net/gotime v1.0.8
go.mongodb.org/mongo-driver v1.15.0
gorm.io/driver/mysql v1.5.6
gorm.io/driver/postgres v1.5.7
gorm.io/gen v0.3.26
gorm.io/gorm v1.25.10
)
require (
github.com/bitly/go-simplejson v0.5.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/basgys/goxml2json v1.1.0 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/go-playground/locales v0.14.0 // indirect
github.com/go-playground/universal-translator v0.18.0 // indirect
github.com/go-playground/validator/v10 v10.11.0 // indirect
github.com/goccy/go-json v0.9.11 // indirect
github.com/go-sql-driver/mysql v1.8.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.13.0 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.1 // indirect
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b // indirect
github.com/jackc/pgtype v1.12.0 // indirect
github.com/jackc/pgx/v4 v4.17.2 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.15.9 // indirect
github.com/leodido/go-urn v1.2.1 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/klauspost/compress v1.17.8 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/montanaflynn/stats v0.6.6 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/segmentio/fasthash v1.0.3 // indirect
github.com/shiena/ansicolor v0.0.0-20200904210342-c7312218db18 // indirect
github.com/stretchr/testify v1.8.0 // indirect
github.com/syndtr/goleveldb v1.0.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/montanaflynn/stats v0.7.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.1 // indirect
github.com/xdg-go/stringprep v1.0.3 // indirect
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/net v0.0.0-20220909164309-bea034e7d591 // indirect
golang.org/x/sync v0.0.0-20220907140024-f12130a52804 // indirect
golang.org/x/sys v0.0.0-20220913175220-63ea55921009 // indirect
golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/uint128 v1.2.0 // indirect
mellium.im/sasl v0.3.0 // indirect
modernc.org/ccgo/v3 v3.16.7 // indirect
modernc.org/libc v1.16.14 // indirect
modernc.org/opt v0.1.3 // indirect
modernc.org/sqlite v1.17.3 // indirect
modernc.org/strutil v1.1.2 // indirect
xorm.io/builder v0.3.12 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20240424034433-3c2c7870ae76 // indirect
golang.org/x/arch v0.7.0 // indirect
golang.org/x/crypto v0.22.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.20.0 // indirect
gorm.io/datatypes v1.2.0 // indirect
gorm.io/hints v1.1.2 // indirect
gorm.io/plugin/dbresolver v1.5.1 // indirect
)

966
go.sum

File diff suppressed because it is too large Load Diff

@ -1 +0,0 @@
package dorm

@ -0,0 +1,18 @@
package dorm
import "database/sql"
// Ping ping
func (c *GormClient) Ping() error {
return c.sqlDd.Ping()
}
// Close 关闭
func (c *GormClient) Close() error {
return c.sqlDd.Close()
}
// Stats 返回数据库统计信息
func (c *GormClient) Stats() sql.DBStats {
return c.sqlDd.Stats()
}

@ -1,6 +1,7 @@
package dorm
import (
"database/sql"
"fmt"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
@ -9,37 +10,45 @@ import (
"time"
)
type ConfigGormClient struct {
// GormClientFun *GormClient 驱动
type GormClientFun func() *GormClient
// GormClientTableFun
// *GormClient 驱动
// string 表名
type GormClientTableFun func() (*GormClient, string)
// GormClientConfig 配置
type GormClientConfig struct {
Dns string // 地址
LogStatus bool // 日志 - 状态
LogPath string // 日志 - 路径
LogSlow int64 // 日志 - 慢SQL阈值
LogLevel string // 日志 - 级别
LogNotFoundError bool // 日志 - 忽略ErrRecordNotFound记录未找到错误
LogColorful bool // 日志 - 禁用彩色打印
ConnSetMaxIdle int // 连接 - 设置空闲连接池中连接的最大数量
ConnSetMaxOpen int // 连接 - 设置打开数据库连接的最大数量
ConnSetConnMaxLifetime int64 // 连接 - 设置了连接可复用的最大时间
}
// GormClient
// https://gorm.io/
// https://gorm.io/zh_CN/docs/index.html
type GormClient struct {
Db *gorm.DB // 驱动
config *ConfigGormClient // 配置
db *gorm.DB // 驱动
config *GormClientConfig // 配置
sqlDd *sql.DB
}
type writer struct{}
// 日志路径
var logsUrl = ""
var logsURL = ""
func (w writer) Printf(format string, args ...interface{}) {
now := time.Now()
logFilePath := ""
if dir, err := os.Getwd(); err == nil {
logFilePath = dir + logsUrl
logFilePath = dir + logsURL
}
if err := os.MkdirAll(logFilePath, 0777); err != nil {
fmt.Println(err.Error())

@ -0,0 +1,29 @@
package dorm
import (
"gorm.io/gen"
"gorm.io/gorm"
)
// GormGenClientFun *GormClient 驱动
type GormGenClientFun func() *GormGenClient
// GormGenClientTableFun
// *GormClient 驱动
// string 表名
type GormGenClientTableFun func() (*GormGenClient, string)
// GormGenClientConfig 配置
type GormGenClientConfig struct {
Dns string // dns地址
Db *gorm.DB // db驱动
Config gen.Config // gen配置
}
// GormGenClient
// https://gorm.io/zh_CN/gen/index.html
type GormGenClient struct {
db *gorm.DB // 驱动
generator *gen.Generator // 驱动
config *GormGenClientConfig // 配置
}

@ -0,0 +1,16 @@
package dorm
import (
"gorm.io/gen"
"gorm.io/gorm"
)
// GetDb 获取驱动
func (c *GormGenClient) GetDb() *gorm.DB {
return c.db
}
// GetGenerator 获取驱动
func (c *GormGenClient) GetGenerator() *gen.Generator {
return c.generator
}

@ -0,0 +1,25 @@
package dorm
import (
"context"
"gorm.io/driver/mysql"
"gorm.io/gen"
"gorm.io/gorm"
)
// NewGormGenMysqlClient 创建GormGenClient实例 mysql
func NewGormGenMysqlClient(ctx context.Context, config *GormGenClientConfig) (*GormGenClient, error) {
c := &GormGenClient{config: config}
c.generator = gen.NewGenerator(config.Config)
if c.config.Dns != "" {
c.db, _ = gorm.Open(mysql.Open(c.config.Dns), &gorm.Config{})
c.generator.UseDB(c.db)
} else {
c.generator.UseDB(c.config.Db)
}
return c, nil
}

@ -0,0 +1,25 @@
package dorm
import (
"context"
"gorm.io/driver/postgres"
"gorm.io/gen"
"gorm.io/gorm"
)
// NewGormGenPostgresClient 创建GormGenClient实例 postgres
func NewGormGenPostgresClient(ctx context.Context, config *GormGenClientConfig) (*GormGenClient, error) {
c := &GormGenClient{config: config}
c.generator = gen.NewGenerator(config.Config)
if c.config.Dns != "" {
c.db, _ = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{})
c.generator.UseDB(c.db)
} else {
c.generator.UseDB(c.config.Db)
}
return c, nil
}

@ -0,0 +1,25 @@
package dorm
import (
"context"
"gorm.io/driver/postgres"
"gorm.io/gen"
"gorm.io/gorm"
)
// NewGormGenPostgresqlClient 创建GormGenClient实例 postgresql
func NewGormGenPostgresqlClient(ctx context.Context, config *GormGenClientConfig) (*GormGenClient, error) {
c := &GormGenClient{config: config}
c.generator = gen.NewGenerator(config.Config)
if c.config.Dns != "" {
c.db, _ = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{})
c.generator.UseDB(c.db)
} else {
c.generator.UseDB(c.config.Db)
}
return c, nil
}

@ -1,8 +1,16 @@
package dorm
import "gorm.io/gorm"
import (
"database/sql"
"gorm.io/gorm"
)
// GetDb 获取驱动
func (c *GormClient) GetDb() *gorm.DB {
return c.Db
return c.db
}
// GetSqlDb 获取驱动
func (c *GormClient) GetSqlDb() *sql.DB {
return c.sqlDd
}

@ -1,6 +1,7 @@
package dorm
import (
"context"
"errors"
"fmt"
"go.dtapp.net/gotime"
@ -10,16 +11,17 @@ import (
"time"
)
func NewGormMysqlClient(config *ConfigGormClient) (*GormClient, error) {
// NewGormMysqlClient 创建GormClient实例 mysql
func NewGormMysqlClient(ctx context.Context, config *GormClientConfig) (*GormClient, error) {
var err error
c := &GormClient{config: config}
// 判断路径
if c.config.LogPath == "" {
logsUrl = "/logs/mysql"
logsURL = "/logs/mysql"
} else {
logsUrl = c.config.LogPath
logsURL = c.config.LogPath
}
if c.config.LogStatus {
@ -37,14 +39,14 @@ func NewGormMysqlClient(config *ConfigGormClient) (*GormClient, error) {
} else {
logLevel = logger.Info
}
c.Db, err = gorm.Open(mysql.Open(c.config.Dns), &gorm.Config{
c.db, err = gorm.Open(mysql.Open(c.config.Dns), &gorm.Config{
Logger: logger.New(
writer{},
logger.Config{
SlowThreshold: slowThreshold, // 慢SQL阈值
LogLevel: logLevel, // 日志级别
IgnoreRecordNotFoundError: c.config.LogNotFoundError, // 忽略ErrRecordNotFound记录未找到错误
Colorful: c.config.LogColorful, // 禁用彩色打印
SlowThreshold: slowThreshold, // 慢SQL阈值
LogLevel: logLevel, // 日志级别
IgnoreRecordNotFoundError: true, // 忽略ErrRecordNotFound记录未找到错误
Colorful: true, // 禁用彩色打印
},
),
NowFunc: func() time.Time {
@ -52,37 +54,37 @@ func NewGormMysqlClient(config *ConfigGormClient) (*GormClient, error) {
},
})
} else {
c.Db, err = gorm.Open(mysql.Open(c.config.Dns), &gorm.Config{})
c.db, err = gorm.Open(mysql.Open(c.config.Dns), &gorm.Config{})
}
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
sqlDB, err := c.Db.DB()
c.sqlDd, err = c.db.DB()
if err != nil {
return nil, errors.New(fmt.Sprintf("检查连接失败:%v", err))
return nil, errors.New(fmt.Sprintf("获取通用数据库对象失败:%v", err))
}
// 设置空闲连接池中连接的最大数量
if c.config.ConnSetMaxIdle == 0 {
sqlDB.SetMaxIdleConns(10)
c.sqlDd.SetMaxIdleConns(10)
} else {
sqlDB.SetMaxIdleConns(c.config.ConnSetMaxIdle)
c.sqlDd.SetMaxIdleConns(c.config.ConnSetMaxIdle)
}
// 设置打开数据库连接的最大数量
if c.config.ConnSetMaxOpen == 0 {
sqlDB.SetMaxOpenConns(100)
c.sqlDd.SetMaxOpenConns(100)
} else {
sqlDB.SetMaxOpenConns(c.config.ConnSetMaxOpen)
c.sqlDd.SetMaxOpenConns(c.config.ConnSetMaxOpen)
}
// 设置了连接可复用的最大时间
if c.config.ConnSetConnMaxLifetime == 0 {
sqlDB.SetConnMaxLifetime(time.Second * 600)
c.sqlDd.SetConnMaxLifetime(time.Second * 600)
} else {
sqlDB.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime))
c.sqlDd.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime))
}
return c, nil

@ -0,0 +1,92 @@
package dorm
import (
"context"
"errors"
"fmt"
"go.dtapp.net/gotime"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"time"
)
// NewGormPostgresClient 创建GormClient实例 postgres
func NewGormPostgresClient(ctx context.Context, config *GormClientConfig) (*GormClient, error) {
var err error
c := &GormClient{}
c.config = config
// 判断路径
if c.config.LogPath == "" {
logsURL = "/logs/postgresql"
} else {
logsURL = c.config.LogPath
}
if c.config.LogStatus {
var slowThreshold time.Duration
var logLevel logger.LogLevel
if c.config.LogSlow == 0 {
slowThreshold = 100 * time.Millisecond
} else {
slowThreshold = time.Duration(c.config.LogSlow)
}
if c.config.LogLevel == "Error" {
logLevel = logger.Error
} else if c.config.LogLevel == "Warn" {
logLevel = logger.Warn
} else {
logLevel = logger.Info
}
c.db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{
Logger: logger.New(
writer{},
logger.Config{
SlowThreshold: slowThreshold, // 慢SQL阈值
LogLevel: logLevel, // 日志级别
IgnoreRecordNotFoundError: true, // 忽略ErrRecordNotFound记录未找到错误
Colorful: true, // 禁用彩色打印
},
),
NowFunc: func() time.Time {
return gotime.Current().Now().Local()
},
})
} else {
c.db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{})
}
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
c.sqlDd, err = c.db.DB()
if err != nil {
return nil, errors.New(fmt.Sprintf("获取通用数据库对象失败:%v", err))
}
// 设置空闲连接池中连接的最大数量
if c.config.ConnSetMaxIdle == 0 {
c.sqlDd.SetMaxIdleConns(10)
} else {
c.sqlDd.SetMaxIdleConns(c.config.ConnSetMaxIdle)
}
// 设置打开数据库连接的最大数量
if c.config.ConnSetMaxOpen == 0 {
c.sqlDd.SetMaxOpenConns(100)
} else {
c.sqlDd.SetMaxOpenConns(c.config.ConnSetMaxOpen)
}
// 设置了连接可复用的最大时间
if c.config.ConnSetConnMaxLifetime == 0 {
c.sqlDd.SetConnMaxLifetime(time.Hour)
} else {
c.sqlDd.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime))
}
return c, nil
}

@ -1,6 +1,7 @@
package dorm
import (
"context"
"errors"
"fmt"
"go.dtapp.net/gotime"
@ -10,7 +11,8 @@ import (
"time"
)
func NewGormPostgresClient(config *ConfigGormClient) (*GormClient, error) {
// NewGormPostgresqlClient 创建GormClient实例 postgresql
func NewGormPostgresqlClient(ctx context.Context, config *GormClientConfig) (*GormClient, error) {
var err error
c := &GormClient{}
@ -18,9 +20,9 @@ func NewGormPostgresClient(config *ConfigGormClient) (*GormClient, error) {
// 判断路径
if c.config.LogPath == "" {
logsUrl = "/logs/postgresql"
logsURL = "/logs/postgresql"
} else {
logsUrl = c.config.LogPath
logsURL = c.config.LogPath
}
if c.config.LogStatus {
@ -38,14 +40,14 @@ func NewGormPostgresClient(config *ConfigGormClient) (*GormClient, error) {
} else {
logLevel = logger.Info
}
c.Db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{
c.db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{
Logger: logger.New(
writer{},
logger.Config{
SlowThreshold: slowThreshold, // 慢SQL阈值
LogLevel: logLevel, // 日志级别
IgnoreRecordNotFoundError: c.config.LogNotFoundError, // 忽略ErrRecordNotFound记录未找到错误
Colorful: c.config.LogColorful, // 禁用彩色打印
SlowThreshold: slowThreshold, // 慢SQL阈值
LogLevel: logLevel, // 日志级别
IgnoreRecordNotFoundError: true, // 忽略ErrRecordNotFound记录未找到错误
Colorful: true, // 禁用彩色打印
},
),
NowFunc: func() time.Time {
@ -53,116 +55,37 @@ func NewGormPostgresClient(config *ConfigGormClient) (*GormClient, error) {
},
})
} else {
c.Db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{})
c.db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{})
}
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
sqlDB, err := c.Db.DB()
c.sqlDd, err = c.db.DB()
if err != nil {
return nil, errors.New(fmt.Sprintf("检查连接失败:%v", err))
return nil, errors.New(fmt.Sprintf("获取通用数据库对象失败:%v", err))
}
// 设置空闲连接池中连接的最大数量
if c.config.ConnSetMaxIdle == 0 {
sqlDB.SetMaxIdleConns(10)
c.sqlDd.SetMaxIdleConns(10)
} else {
sqlDB.SetMaxIdleConns(c.config.ConnSetMaxIdle)
c.sqlDd.SetMaxIdleConns(c.config.ConnSetMaxIdle)
}
// 设置打开数据库连接的最大数量
if c.config.ConnSetMaxOpen == 0 {
sqlDB.SetMaxOpenConns(100)
c.sqlDd.SetMaxOpenConns(100)
} else {
sqlDB.SetMaxOpenConns(c.config.ConnSetMaxOpen)
c.sqlDd.SetMaxOpenConns(c.config.ConnSetMaxOpen)
}
// 设置了连接可复用的最大时间
if c.config.ConnSetConnMaxLifetime == 0 {
sqlDB.SetConnMaxLifetime(time.Hour)
c.sqlDd.SetConnMaxLifetime(time.Hour)
} else {
sqlDB.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime))
}
return c, nil
}
func NewGormPostgresqlClient(config *ConfigGormClient) (*GormClient, error) {
var err error
c := &GormClient{}
c.config = config
// 判断路径
if c.config.LogPath == "" {
logsUrl = "/logs/postgresql"
} else {
logsUrl = c.config.LogPath
}
if c.config.LogStatus {
var slowThreshold time.Duration
var logLevel logger.LogLevel
if c.config.LogSlow == 0 {
slowThreshold = 100 * time.Millisecond
} else {
slowThreshold = time.Duration(c.config.LogSlow)
}
if c.config.LogLevel == "Error" {
logLevel = logger.Error
} else if c.config.LogLevel == "Warn" {
logLevel = logger.Warn
} else {
logLevel = logger.Info
}
c.Db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{
Logger: logger.New(
writer{},
logger.Config{
SlowThreshold: slowThreshold, // 慢SQL阈值
LogLevel: logLevel, // 日志级别
IgnoreRecordNotFoundError: c.config.LogNotFoundError, // 忽略ErrRecordNotFound记录未找到错误
Colorful: c.config.LogColorful, // 禁用彩色打印
},
),
NowFunc: func() time.Time {
return gotime.Current().Now().Local()
},
})
} else {
c.Db, err = gorm.Open(postgres.Open(c.config.Dns), &gorm.Config{})
}
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
sqlDB, err := c.Db.DB()
if err != nil {
return nil, errors.New(fmt.Sprintf("检查连接失败:%v", err))
}
// 设置空闲连接池中连接的最大数量
if c.config.ConnSetMaxIdle == 0 {
sqlDB.SetMaxIdleConns(10)
} else {
sqlDB.SetMaxIdleConns(c.config.ConnSetMaxIdle)
}
// 设置打开数据库连接的最大数量
if c.config.ConnSetMaxOpen == 0 {
sqlDB.SetMaxOpenConns(100)
} else {
sqlDB.SetMaxOpenConns(c.config.ConnSetMaxOpen)
}
// 设置了连接可复用的最大时间
if c.config.ConnSetConnMaxLifetime == 0 {
sqlDB.SetConnMaxLifetime(time.Second * 600)
} else {
sqlDB.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime))
c.sqlDd.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime))
}
return c, nil

@ -4,15 +4,15 @@ import "gorm.io/gorm"
// Begin 开始事务,不需要创建 Session 对象
func (c *GormClient) Begin() *gorm.DB {
return c.Db.Begin()
return c.db.Begin()
}
// Rollback 回滚事务
func (c *GormClient) Rollback() *gorm.DB {
return c.Db.Rollback()
return c.db.Rollback()
}
// Commit 提交事务
func (c *GormClient) Commit() *gorm.DB {
return c.Db.Commit()
return c.db.Commit()
}

@ -1,16 +0,0 @@
package dorm
import "encoding/json"
// JsonDecodeNoError json字符串转结构体不报错
func JsonDecodeNoError(b []byte) map[string]interface{} {
var data map[string]interface{}
_ = json.Unmarshal(b, &data)
return data
}
// JsonEncodeNoError 结构体转json字符串不报错
func JsonEncodeNoError(data interface{}) string {
jsons, _ := json.Marshal(data)
return string(jsons)
}

@ -4,48 +4,56 @@ import (
"context"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type ConfigMongoClient struct {
// MongoClientFun *MongoClient 驱动
// string 库名
type MongoClientFun func() (*MongoClient, string)
// MongoClientCollectionFun
// *MongoClient 驱动
// string 库名
// string 集合
type MongoClientCollectionFun func() (*MongoClient, string, string)
// MongoClientConfig 实例配置
type MongoClientConfig struct {
Dns string // 地址
Opts *options.ClientOptions
DatabaseName string // 库名
}
// MongoClient 实例
type MongoClient struct {
Db *mongo.Client // 驱动
config *ConfigMongoClient // 配置
databaseName string // 库名
collectionName string // 表名
//filterArr []queryFilter // 查询条件数组
filter bson.D // 查询条件
db *mongo.Client // 驱动
configDatabaseName string // 库名
}
func NewMongoClient(config *ConfigMongoClient) (*MongoClient, error) {
// NewMongoClient 创建实例
func NewMongoClient(ctx context.Context, config *MongoClientConfig) (*MongoClient, error) {
var err error
c := &MongoClient{config: config}
c := &MongoClient{}
c.databaseName = c.config.DatabaseName
c.configDatabaseName = config.DatabaseName
// 连接到MongoDB
if c.config.Dns != "" {
c.Db, err = mongo.Connect(context.Background(), options.Client().ApplyURI(c.config.Dns))
if config.Dns != "" {
c.db, err = mongo.Connect(ctx, options.Client().ApplyURI(config.Dns))
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
} else {
c.Db, err = mongo.Connect(context.Background(), c.config.Opts)
c.db, err = mongo.Connect(ctx, config.Opts)
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
}
// 检查连接
err = c.Db.Ping(context.TODO(), nil)
err = c.db.Ping(ctx, nil)
if err != nil {
return nil, errors.New(fmt.Sprintf("检查连接失败:%v", err))
}
@ -54,10 +62,6 @@ func NewMongoClient(config *ConfigMongoClient) (*MongoClient, error) {
}
// Close 关闭
func (c *MongoClient) Close() error {
err := c.Db.Disconnect(context.TODO())
if err != nil {
return errors.New(fmt.Sprintf("关闭失败:%v", err))
}
return nil
func (c *MongoClient) Close(ctx context.Context) error {
return c.db.Disconnect(ctx)
}

@ -1 +0,0 @@
package dorm

@ -0,0 +1,60 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type MongoCollectionOptions struct {
db *mongo.Client // 驱动
configDatabaseName string // 库名
dbCollection *mongo.Collection // 集合
}
// Collection 选择集合
func (cd *MongoDatabaseOptions) Collection(name string, opts ...*options.CollectionOptions) *MongoCollectionOptions {
return &MongoCollectionOptions{
db: cd.db, // 驱动
configDatabaseName: cd.configDatabaseName, // 库名
dbCollection: cd.dbDatabase.Collection(name, opts...), // 集合
}
}
// CreateOneIndexes 创建一个索引
func (cc *MongoCollectionOptions) CreateOneIndexes(ctx context.Context, key string, value string) (string, error) {
return cc.dbCollection.Indexes().CreateOne(ctx, mongo.IndexModel{
Keys: bson.D{{
Key: key,
Value: value,
}},
})
}
// CreateOneUniqueIndexes 创建一个唯一索引
func (cc *MongoCollectionOptions) CreateOneUniqueIndexes(ctx context.Context, key string, value string) (string, error) {
return cc.dbCollection.Indexes().CreateOne(ctx, mongo.IndexModel{
Keys: bson.D{{
Key: key,
Value: value,
}},
Options: options.Index().SetUnique(true),
})
}
// CreateOneUniqueIndexesOpts 创建一个索引
func (cc *MongoCollectionOptions) CreateOneUniqueIndexesOpts(ctx context.Context, key string, value string, opts *options.IndexOptions) (string, error) {
return cc.dbCollection.Indexes().CreateOne(ctx, mongo.IndexModel{
Keys: bson.D{{
Key: key,
Value: value,
}},
Options: opts,
})
}
// CreateManyIndexes 创建多个索引
func (cc *MongoCollectionOptions) CreateManyIndexes(ctx context.Context, models []mongo.IndexModel) ([]string, error) {
return cc.dbCollection.Indexes().CreateMany(ctx, models)
}

@ -0,0 +1,100 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// InsertOne 插入一个文档
func (cc *MongoCollectionOptions) InsertOne(ctx context.Context, document interface{}, opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) {
return cc.dbCollection.InsertOne(ctx, document, opts...)
}
// InsertMany 插入多个文档
func (cc *MongoCollectionOptions) InsertMany(ctx context.Context, document []interface{}, opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) {
return cc.dbCollection.InsertMany(ctx, document, opts...)
}
// DeleteOne 删除一个文档
func (cc *MongoCollectionOptions) DeleteOne(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
return cc.dbCollection.DeleteOne(ctx, filter, opts...)
}
// DeleteMany 删除多个文档
func (cc *MongoCollectionOptions) DeleteMany(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
return cc.dbCollection.DeleteMany(ctx, filter, opts...)
}
// UpdateByID 按ID更新
func (cc *MongoCollectionOptions) UpdateByID(ctx context.Context, id interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return cc.dbCollection.UpdateByID(ctx, id, update, opts...)
}
// UpdateOne 更新一个文档
func (cc *MongoCollectionOptions) UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return cc.dbCollection.UpdateOne(ctx, filter, update, opts...)
}
// UpdateMany 更新多个文档
func (cc *MongoCollectionOptions) UpdateMany(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return cc.dbCollection.UpdateMany(ctx, filter, update, opts...)
}
// ReplaceOne 替换一个文档
func (cc *MongoCollectionOptions) ReplaceOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.ReplaceOptions) (*mongo.UpdateResult, error) {
return cc.dbCollection.ReplaceOne(ctx, filter, update, opts...)
}
// Aggregate 统计分析
func (cc *MongoCollectionOptions) Aggregate(ctx context.Context, pipeline interface{}, opts ...*options.AggregateOptions) (*mongo.Cursor, error) {
return cc.dbCollection.Aggregate(ctx, pipeline, opts...)
}
// CountDocuments 计数文档
func (cc *MongoCollectionOptions) CountDocuments(ctx context.Context, filter interface{}, opts ...*options.CountOptions) (int64, error) {
return cc.dbCollection.CountDocuments(ctx, filter, opts...)
}
// EstimatedDocumentCount 估计文档计数
func (cc *MongoCollectionOptions) EstimatedDocumentCount(ctx context.Context, opts ...*options.EstimatedDocumentCountOptions) (int64, error) {
return cc.dbCollection.EstimatedDocumentCount(ctx, opts...)
}
func (cc *MongoCollectionOptions) Distinct(ctx context.Context, fieldName string, filter interface{}, opts ...*options.DistinctOptions) ([]interface{}, error) {
return cc.dbCollection.Distinct(ctx, fieldName, filter, opts...)
}
// Find 查询多个文档
func (cc *MongoCollectionOptions) Find(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) {
return cc.dbCollection.Find(ctx, filter, opts...)
}
// FindOne 查询一个文档
func (cc *MongoCollectionOptions) FindOne(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult {
return cc.dbCollection.FindOne(ctx, filter, opts...)
}
func (cc *MongoCollectionOptions) FindOneAndDelete(ctx context.Context, filter interface{}, opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult {
return cc.dbCollection.FindOneAndDelete(ctx, filter, opts...)
}
func (cc *MongoCollectionOptions) FindOneAndReplace(ctx context.Context, filter interface{}, replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *mongo.SingleResult {
return cc.dbCollection.FindOneAndReplace(ctx, filter, replacement, opts...)
}
func (cc *MongoCollectionOptions) FindOneAndUpdate(ctx context.Context, filter interface{}, replacement interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult {
return cc.dbCollection.FindOneAndUpdate(ctx, filter, replacement, opts...)
}
func (cc *MongoCollectionOptions) Watch(ctx context.Context, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) {
return cc.dbCollection.Watch(ctx, pipeline, opts...)
}
func (cc *MongoCollectionOptions) Indexes(ctx context.Context) mongo.IndexView {
return cc.dbCollection.Indexes()
}
func (cc *MongoCollectionOptions) Drop(ctx context.Context) error {
return cc.dbCollection.Drop(ctx)
}

@ -1,185 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// InsertOne 插入一个文档
func (c *MongoClient) InsertOne(document interface{}) (result *mongo.InsertOneResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.InsertOne(context.TODO(), document)
return
}
// InsertMany 插入多个文档
func (c *MongoClient) InsertMany(documents []interface{}) (result *mongo.InsertManyResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.InsertMany(context.TODO(), documents)
return
}
// Delete 删除文档
func (c *MongoClient) Delete(filter interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.DeleteOne(context.TODO(), filter)
return
}
// DeleteId 删除文档
func (c *MongoClient) DeleteId(id interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.DeleteOne(context.TODO(), bson.M{"_id": id})
return
}
// DeleteMany 删除多个文档
func (c *MongoClient) DeleteMany(filter interface{}) (result *mongo.DeleteResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.DeleteMany(context.TODO(), filter)
return
}
// UpdateOne 更新单个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (c *MongoClient) UpdateOne(filter interface{}, update interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.UpdateOne(context.TODO(), filter, update)
return
}
// UpdateId 更新单个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (c *MongoClient) UpdateId(id interface{}, update interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.UpdateOne(context.TODO(), bson.M{"_id": id}, update)
return
}
// UpdateMany 更新多个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (c *MongoClient) UpdateMany(filter interface{}, update interface{}) (result *mongo.UpdateResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.UpdateMany(context.TODO(), filter, update)
return
}
type FindResultI interface {
Many(result interface{}) error
}
// Find 查询
func (c *MongoClient) Find(filter interface{}) (*mongo.Cursor, error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
return collection.Find(context.TODO(), filter)
}
type FindOneResultI interface {
One(result interface{}) error
}
// FindOne 查询单个文档
func (c *MongoClient) FindOne(filter interface{}) *mongo.SingleResult {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
return collection.FindOne(context.TODO(), filter)
}
type FindManyResultI interface {
Many(result interface{}) error
}
// FindMany 查询多个文档
func (c *MongoClient) FindMany(filter interface{}) (*mongo.Cursor, error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
return collection.Find(context.TODO(), filter)
}
// FindManyByFilters 多条件查询
func (c *MongoClient) FindManyByFilters(filter interface{}) (result *mongo.Cursor, err error) {
collection, err := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName).Clone()
result, err = collection.Find(context.TODO(), bson.M{"$and": filter})
return result, err
}
// FindManyByFiltersSort 多条件查询支持排序
func (c *MongoClient) FindManyByFiltersSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) {
collection, err := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName).Clone()
findOptions := options.Find()
findOptions.SetSort(Sort)
result, err = collection.Find(context.TODO(), filter, findOptions)
return result, err
}
// FindCollection 查询集合文档
func (c *MongoClient) FindCollection(Limit int64) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
findOptions := options.Find()
findOptions.SetLimit(Limit)
result, err = collection.Find(context.TODO(), bson.D{{}}, findOptions)
return result, err
}
// FindCollectionSort 查询集合文档支持排序
func (c *MongoClient) FindCollectionSort(Sort interface{}, Limit int64) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
findOptions := options.Find()
findOptions.SetSort(Sort)
findOptions.SetLimit(Limit)
result, err = collection.Find(context.TODO(), bson.D{{}}, findOptions)
return result, err
}
// FindManyCollectionSort 查询集合文档支持排序支持条件
func (c *MongoClient) FindManyCollectionSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
findOptions := options.Find()
findOptions.SetSort(Sort)
result, err = collection.Find(context.TODO(), filter, findOptions)
return result, err
}
// CollectionCount 查询集合里有多少数据
func (c *MongoClient) CollectionCount() (name string, size int64) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
name = collection.Name()
size, _ = collection.EstimatedDocumentCount(context.TODO())
return name, size
}
// CollectionDocuments 按选项查询集合
// Skip 跳过
// Limit 读取数量
// sort 1 -1 . 1 为升序 -1 为降序
func (c *MongoClient) CollectionDocuments(Skip, Limit int64, sort int, key string, value interface{}) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
SORT := bson.D{{"_id", sort}}
filter := bson.D{{key, value}}
findOptions := options.Find().SetSort(SORT).SetLimit(Limit).SetSkip(Skip)
result, err = collection.Find(context.TODO(), filter, findOptions)
return result, err
}
// AggregateByFiltersSort 统计分析
func (c *MongoClient) AggregateByFiltersSort(pipeline interface{}) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.Aggregate(context.TODO(), pipeline)
return result, err
}
// CountDocumentsByFilters 统计数量
func (c *MongoClient) CountDocumentsByFilters(filter interface{}) (count int64, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
count, err = collection.CountDocuments(context.TODO(), filter)
return count, err
}

@ -1,15 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
)
type FindResult struct {
cursor *mongo.Cursor
err error
}
func (f *FindResult) Many(result interface{}) error {
return f.cursor.All(context.TODO(), result)
}

@ -1,15 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
)
type FindManyResult struct {
cursor *mongo.Cursor
err error
}
func (f *FindManyResult) Many(result interface{}) error {
return f.cursor.All(context.TODO(), result)
}

@ -1,11 +0,0 @@
package dorm
import "go.mongodb.org/mongo-driver/mongo"
type FindOneResult struct {
singleResult *mongo.SingleResult
}
func (f *FindOneResult) One(result interface{}) error {
return f.singleResult.Decode(result)
}

@ -0,0 +1,69 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"reflect"
)
type MongoDatabaseOptions struct {
db *mongo.Client // 驱动
configDatabaseName string // 库名
dbDatabase *mongo.Database // 数据库
}
// Database 选择数据库
func (c *MongoClient) Database(name string, opts ...*options.DatabaseOptions) *MongoDatabaseOptions {
return &MongoDatabaseOptions{
db: c.db, // 驱动
configDatabaseName: c.configDatabaseName, // 库名
dbDatabase: c.db.Database(name, opts...), // 数据库
}
}
// CreateCollection 创建集合
func (cd *MongoDatabaseOptions) CreateCollection(ctx context.Context, name string, opts ...*options.CreateCollectionOptions) error {
return cd.dbDatabase.CreateCollection(ctx, name, opts...)
}
// CreateTimeSeriesCollection 创建时间序列集合
func (cd *MongoDatabaseOptions) CreateTimeSeriesCollection(ctx context.Context, name string, timeField string) error {
return cd.dbDatabase.CreateCollection(ctx, name, options.CreateCollection().SetTimeSeriesOptions(options.TimeSeries().SetTimeField(timeField)))
}
// Model 传入模型自动获取库名和表名
// https://studygolang.com/articles/896
// DatabaseName 库名
// CollectionName 集合名
func (c *MongoClient) Model(value interface{}) *MongoCollectionOptions {
var databaseOptions *MongoDatabaseOptions
var collectionOptions *MongoCollectionOptions
val := reflect.ValueOf(value)
methodDatabaseNameValue := val.MethodByName("DatabaseName")
if methodDatabaseNameValue.IsValid() {
databaseName := methodDatabaseNameValue.Call(nil)[0].String()
databaseOptions = c.Database(databaseName)
} else {
databaseOptions = c.Database(c.configDatabaseName)
}
methodCollectionNameValue := val.MethodByName("CollectionName")
if methodCollectionNameValue.IsValid() {
collectionName := methodCollectionNameValue.Call(nil)[0].String()
collectionOptions = databaseOptions.Collection(collectionName)
} else {
methodTableNameValue := val.MethodByName("TableName")
if methodTableNameValue.IsValid() {
collectionName := methodTableNameValue.Call(nil)[0].String()
collectionOptions = databaseOptions.Collection(collectionName)
} else {
panic(NoConfigCollectionName)
}
}
return collectionOptions
}

@ -0,0 +1,8 @@
package dorm
import "errors"
var (
NoConfigDatabaseName = errors.New("没有配置库名")
NoConfigCollectionName = errors.New("没有配置集合名")
)

@ -4,15 +4,5 @@ import "go.mongodb.org/mongo-driver/mongo"
// GetDb 获取驱动
func (c *MongoClient) GetDb() *mongo.Client {
return c.Db
}
// 获取库名
func (c *MongoClient) getDatabaseName() string {
return c.databaseName
}
// 获取表名
func (c *MongoClient) getCollectionName() string {
return c.collectionName
return c.db
}

@ -1 +0,0 @@
package dorm

@ -0,0 +1,59 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
"log"
)
type MongoSessionOptions struct {
db *mongo.Client // 驱动
configDatabaseName string // 库名
session mongo.Session // 会话
sessionContext mongo.SessionContext // 会话上下文
}
// Begin 开始事务,会同时创建开始会话需要在退出时关闭会话
func (c *MongoClient) Begin() *MongoSessionOptions {
var ctx = context.TODO()
var err error
ms := &MongoSessionOptions{}
ms.db = c.GetDb()
ms.configDatabaseName = c.configDatabaseName
// 开始会话
ms.session, err = ms.db.StartSession()
if err != nil {
log.Println("开始会话异常:", err)
}
// 会话上下文
ms.sessionContext = mongo.NewSessionContext(ctx, ms.session)
// 会话开启事务
err = ms.session.StartTransaction()
return ms
}
// Rollback 回滚事务
func (cs *MongoSessionOptions) Rollback() {
var ctx = context.TODO()
err := cs.session.AbortTransaction(ctx)
if err != nil {
log.Println("回滚事务异常:", err)
}
cs.session.EndSession(ctx)
}
// Commit 提交事务
func (cs *MongoSessionOptions) Commit() {
var ctx = context.TODO()
err := cs.session.CommitTransaction(ctx)
if err != nil {
log.Println("提交事务异常:", err)
}
cs.session.EndSession(ctx)
}

@ -0,0 +1,25 @@
package dorm
import (
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type MongoSessionCollectionOptions struct {
db *mongo.Client // 驱动
configDatabaseName string // 库名
session mongo.Session // 会话
sessionContext mongo.SessionContext // 会话上下文
dbCollection *mongo.Collection // 集合
}
// Collection 选择集合
func (csd *MongoSessionDatabaseOptions) Collection(name string, opts ...*options.CollectionOptions) *MongoSessionCollectionOptions {
return &MongoSessionCollectionOptions{
db: csd.db, // 驱动
configDatabaseName: csd.configDatabaseName, // 库名
session: csd.session, // 会话
sessionContext: csd.sessionContext, // 会话上下文
dbCollection: csd.dbDatabase.Collection(name, opts...), // 集合
}
}

@ -0,0 +1,99 @@
package dorm
import (
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// InsertOne 插入一个文档
func (csc *MongoSessionCollectionOptions) InsertOne(document interface{}, opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) {
return csc.dbCollection.InsertOne(csc.sessionContext, document, opts...)
}
// InsertMany 插入多个文档
func (csc *MongoSessionCollectionOptions) InsertMany(document []interface{}, opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) {
return csc.dbCollection.InsertMany(csc.sessionContext, document, opts...)
}
// DeleteOne 删除一个文档
func (csc *MongoSessionCollectionOptions) DeleteOne(filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
return csc.dbCollection.DeleteOne(csc.sessionContext, filter, opts...)
}
// DeleteMany 删除多个文档
func (csc *MongoSessionCollectionOptions) DeleteMany(filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
return csc.dbCollection.DeleteMany(csc.sessionContext, filter, opts...)
}
// UpdateByID 按ID更新
func (csc *MongoSessionCollectionOptions) UpdateByID(id interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return csc.dbCollection.UpdateByID(csc.sessionContext, id, update, opts...)
}
// UpdateOne 更新一个文档
func (csc *MongoSessionCollectionOptions) UpdateOne(filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return csc.dbCollection.UpdateOne(csc.sessionContext, filter, update, opts...)
}
// UpdateMany 更新多个文档
func (csc *MongoSessionCollectionOptions) UpdateMany(filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return csc.dbCollection.UpdateMany(csc.sessionContext, filter, update, opts...)
}
// ReplaceOne 替换一个文档
func (csc *MongoSessionCollectionOptions) ReplaceOne(filter interface{}, update interface{}, opts ...*options.ReplaceOptions) (*mongo.UpdateResult, error) {
return csc.dbCollection.ReplaceOne(csc.sessionContext, filter, update, opts...)
}
// Aggregate 统计分析
func (csc *MongoSessionCollectionOptions) Aggregate(pipeline interface{}, opts ...*options.AggregateOptions) (*mongo.Cursor, error) {
return csc.dbCollection.Aggregate(csc.sessionContext, pipeline, opts...)
}
// CountDocuments 计数文档
func (csc *MongoSessionCollectionOptions) CountDocuments(filter interface{}, opts ...*options.CountOptions) (int64, error) {
return csc.dbCollection.CountDocuments(csc.sessionContext, filter, opts...)
}
// EstimatedDocumentCount 估计文档计数
func (csc *MongoSessionCollectionOptions) EstimatedDocumentCount(opts ...*options.EstimatedDocumentCountOptions) (int64, error) {
return csc.dbCollection.EstimatedDocumentCount(csc.sessionContext, opts...)
}
func (csc *MongoSessionCollectionOptions) Distinct(fieldName string, filter interface{}, opts ...*options.DistinctOptions) ([]interface{}, error) {
return csc.dbCollection.Distinct(csc.sessionContext, fieldName, filter, opts...)
}
// Find 查询多个文档
func (csc *MongoSessionCollectionOptions) Find(filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) {
return csc.dbCollection.Find(csc.sessionContext, filter, opts...)
}
// FindOne 查询一个文档
func (csc *MongoSessionCollectionOptions) FindOne(filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult {
return csc.dbCollection.FindOne(csc.sessionContext, filter, opts...)
}
func (csc *MongoSessionCollectionOptions) FindOneAndDelete(filter interface{}, opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult {
return csc.dbCollection.FindOneAndDelete(csc.sessionContext, filter, opts...)
}
func (csc *MongoSessionCollectionOptions) FindOneAndReplace(filter interface{}, replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *mongo.SingleResult {
return csc.dbCollection.FindOneAndReplace(csc.sessionContext, filter, replacement, opts...)
}
func (csc *MongoSessionCollectionOptions) FindOneAndUpdate(filter interface{}, replacement interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult {
return csc.dbCollection.FindOneAndUpdate(csc.sessionContext, filter, replacement, opts...)
}
func (csc *MongoSessionCollectionOptions) Watch(pipeline interface{}, opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) {
return csc.dbCollection.Watch(csc.sessionContext, pipeline, opts...)
}
func (csc *MongoSessionCollectionOptions) Indexes() mongo.IndexView {
return csc.dbCollection.Indexes()
}
func (csc *MongoSessionCollectionOptions) Drop() error {
return csc.dbCollection.Drop(csc.sessionContext)
}

@ -0,0 +1,62 @@
package dorm
import (
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"reflect"
)
type MongoSessionDatabaseOptions struct {
db *mongo.Client // 驱动
configDatabaseName string // 库名
session mongo.Session // 会话
sessionContext mongo.SessionContext // 会话上下文
dbDatabase *mongo.Database // 数据库
}
// Database 选择数据库
func (cs *MongoSessionOptions) Database(name string, opts ...*options.DatabaseOptions) *MongoSessionDatabaseOptions {
return &MongoSessionDatabaseOptions{
db: cs.db, // 驱动
configDatabaseName: cs.configDatabaseName, // 库名
session: cs.session, // 会话
sessionContext: cs.sessionContext, // 会话上下文
dbDatabase: cs.db.Database(name, opts...), // 数据库
}
}
// Model 传入模型自动获取库名和表名
// https://studygolang.com/articles/896
// DatabaseName 库名
// CollectionName 集合名
func (cs *MongoSessionOptions) Model(value interface{}) *MongoSessionCollectionOptions {
var sessionDatabaseOptions *MongoSessionDatabaseOptions
var sessionCollectionOptions *MongoSessionCollectionOptions
val := reflect.ValueOf(value)
methodDatabaseNameValue := val.MethodByName("DatabaseName")
if methodDatabaseNameValue.IsValid() {
databaseName := methodDatabaseNameValue.Call(nil)[0].String()
sessionDatabaseOptions = cs.Database(databaseName)
} else {
sessionDatabaseOptions = cs.Database(cs.configDatabaseName)
}
methodCollectionNameValue := val.MethodByName("CollectionName")
if methodCollectionNameValue.IsValid() {
collectionName := methodCollectionNameValue.Call(nil)[0].String()
sessionCollectionOptions = sessionDatabaseOptions.Collection(collectionName)
} else {
methodTableNameValue := val.MethodByName("TableName")
if methodTableNameValue.IsValid() {
collectionName := methodTableNameValue.Call(nil)[0].String()
sessionCollectionOptions = sessionDatabaseOptions.Collection(collectionName)
} else {
panic(NoConfigCollectionName)
}
}
return sessionCollectionOptions
}

@ -0,0 +1,18 @@
package dorm
import "go.mongodb.org/mongo-driver/mongo"
// GetDb 获取驱动
func (cs *MongoSessionOptions) GetDb() *mongo.Client {
return cs.db
}
// GetSession 获取会话
func (cs *MongoSessionOptions) GetSession() mongo.Session {
return cs.session
}
// GetSessionContext 获取会话上下文
func (cs *MongoSessionOptions) GetSessionContext() mongo.SessionContext {
return cs.sessionContext
}

@ -1,28 +0,0 @@
package dorm
import "reflect"
// Database 设置库名
func (c *MongoClient) Database(databaseName string) *MongoClient {
c.databaseName = databaseName
return c
}
// Collection 设置表名
func (c *MongoClient) Collection(collectionName string) *MongoClient {
c.collectionName = collectionName
return c
}
// Model 传入模型自动获取库名和表名
func (c *MongoClient) Model(value interface{}) *MongoClient {
// https://studygolang.com/articles/896
val := reflect.ValueOf(value)
if methodValue := val.MethodByName("Database"); methodValue.IsValid() {
c.databaseName = methodValue.Call(nil)[0].String()
}
if methodValue := val.MethodByName("TableName"); methodValue.IsValid() {
c.collectionName = methodValue.Call(nil)[0].String()
}
return c
}

@ -7,44 +7,79 @@ import (
"time"
)
// BsonTime 类型
// BsonTime 时间类型
type BsonTime time.Time
// Value 时间类型
func (t BsonTime) Value() string {
return gotime.SetCurrent(time.Time(t)).Bson()
}
// MarshalJSON 实现json序列化
func (t BsonTime) MarshalJSON() ([]byte, error) {
//log.Println("MarshalJSON")
func (bt BsonTime) MarshalJSON() ([]byte, error) {
b := make([]byte, 0)
b = append(b, gotime.SetCurrent(time.Time(t)).Bson()...)
b = append(b, '"')
b = append(b, gotime.SetCurrent(time.Time(bt)).Bson()...)
b = append(b, '"')
return b, nil
}
// UnmarshalJSON 实现json反序列化
func (t *BsonTime) UnmarshalJSON(data []byte) (err error) {
//log.Println("UnmarshalJSON")
t1 := gotime.SetCurrentParse(string(data))
*t = BsonTime(t1.Time)
return
func (bt *BsonTime) UnmarshalJSON(data []byte) (err error) {
if string(data) == "null" {
return nil
}
bsonTime := gotime.SetCurrentParse(string(data))
*bt = BsonTime(bsonTime.Time)
return nil
}
// Time 转时间
func (bt BsonTime) Time() time.Time {
return gotime.SetCurrent(time.Time(bt)).Time
}
// Format 转时间字符串
func (bt BsonTime) Format() string {
return gotime.SetCurrent(time.Time(bt)).Format()
}
// TimePro 转时间操作
func (bt BsonTime) TimePro() gotime.Pro {
return gotime.SetCurrent(time.Time(bt))
}
// NewBsonTimeCurrent 创建当前时间
func NewBsonTimeCurrent() BsonTime {
return BsonTime(gotime.Current().Time)
}
// NewBsonTimeFromTime 创建某个时间
func NewBsonTimeFromTime(t time.Time) BsonTime {
return BsonTime(t)
}
// NewBsonTimeFromString 创建某个时间 字符串
func NewBsonTimeFromString(t string) BsonTime {
return BsonTime(gotime.SetCurrentParse(t).Time)
}
// Value 时间类型
func (bt BsonTime) Value() string {
return gotime.SetCurrent(time.Time(bt)).Bson()
}
// MarshalBSONValue 实现bson序列化
func (t BsonTime) MarshalBSONValue() (bsontype.Type, []byte, error) {
//log.Println("MarshalBSONValue")
targetTime := gotime.SetCurrent(time.Time(t)).Bson()
return bson.MarshalValue(targetTime)
func (bt BsonTime) MarshalBSONValue() (bsontype.Type, []byte, error) {
return bson.MarshalValue(gotime.SetCurrent(time.Time(bt)).Bson())
}
// UnmarshalBSONValue 实现bson反序列化
func (t *BsonTime) UnmarshalBSONValue(t2 bsontype.Type, data []byte) error {
//log.Println("UnmarshalBSONValue")
func (bt *BsonTime) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
t1 := gotime.SetCurrentParse(string(data))
//if string(data) == "" {
// return errors.New(fmt.Sprintf("%s, %s, %s", "读取数据失败:", t2, data))
// return errors.New(fmt.Sprintf("%s, %s, %s", "读取数据失败:", t, data))
//}
*t = BsonTime(t1.Time)
*bt = BsonTime(t1.Time)
return nil
}

@ -1,48 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
)
type MongoTransaction struct {
startSession mongo.Session
Session mongo.SessionContext
db *mongo.Client // 驱动
databaseName string // 库名
collectionName string // 表名
}
// Begin 开始事务,会同时创建开始会话需要在退出时关闭会话
func (c *MongoClient) Begin() (ms MongoTransaction, err error) {
ms.db = c.Db
// 开始会话
ms.startSession, err = ms.db.StartSession()
if err != nil {
panic(err)
}
// 会话上下文
ms.Session = mongo.NewSessionContext(context.Background(), ms.startSession)
// 会话开启事务
err = ms.startSession.StartTransaction()
return ms, err
}
// Close 关闭会话
func (ms *MongoTransaction) Close() {
ms.startSession.EndSession(context.TODO())
}
// Rollback 回滚事务
func (ms *MongoTransaction) Rollback() error {
return ms.startSession.AbortTransaction(context.Background())
}
// Commit 提交事务
func (ms *MongoTransaction) Commit() error {
return ms.startSession.CommitTransaction(context.Background())
}

@ -1,174 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// InsertOne 插入一个文档
func (ms *MongoTransaction) InsertOne(document interface{}) (result *mongo.InsertOneResult, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
result, err = collection.InsertOne(ms.Session, document)
return
}
// InsertMany 插入多个文档
func (ms *MongoTransaction) InsertMany(documents []interface{}) (result *mongo.InsertManyResult, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
result, err = collection.InsertMany(ms.Session, documents)
return
}
// Delete 删除文档
func (ms *MongoTransaction) Delete(filter interface{}) (err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
_, err = collection.DeleteOne(ms.Session, filter)
return
}
// DeleteId 删除文档
func (ms *MongoTransaction) DeleteId(id interface{}) (err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
_, err = collection.DeleteOne(ms.Session, bson.M{"_id": id})
return
}
// DeleteMany 删除多个文档
func (ms *MongoTransaction) DeleteMany(key string, value interface{}) (result *mongo.DeleteResult, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
filter := bson.D{{key, value}}
result, err = collection.DeleteMany(ms.Session, filter)
return
}
// UpdateOne 更新单个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (ms *MongoTransaction) UpdateOne(filter interface{}, update interface{}) (err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
_, err = collection.UpdateOne(ms.Session, filter, update)
return
}
// UpdateId 更新单个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (ms *MongoTransaction) UpdateId(id interface{}, update interface{}) (err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
_, err = collection.UpdateOne(context.TODO(), bson.M{"_id": id}, update)
return
}
// UpdateMany 更新多个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (ms *MongoTransaction) UpdateMany(filter interface{}, update interface{}) (result *mongo.UpdateResult, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
result, err = collection.UpdateMany(ms.Session, filter, update)
return
}
// Find 查询
func (ms *MongoTransaction) Find(filter interface{}) (*mongo.Cursor, error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
return collection.Find(ms.Session, filter)
}
// FindOne 查询单个文档
func (ms *MongoTransaction) FindOne(filter interface{}) *mongo.SingleResult {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
return collection.FindOne(ms.Session, filter)
}
// FindMany 查询多个文档
func (ms *MongoTransaction) FindMany(filter interface{}) (*mongo.Cursor, error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
return collection.Find(ms.Session, filter)
}
// FindManyByFilters 多条件查询
func (ms *MongoTransaction) FindManyByFilters(filter interface{}) (result *mongo.Cursor, err error) {
collection, err := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName).Clone()
result, err = collection.Find(ms.Session, bson.M{"$and": filter})
return result, err
}
// FindManyByFiltersSort 多条件查询支持排序
func (ms *MongoTransaction) FindManyByFiltersSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) {
collection, err := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName).Clone()
findOptions := options.Find()
findOptions.SetSort(Sort)
result, err = collection.Find(ms.Session, filter, findOptions)
return result, err
}
// FindCollection 查询集合文档
func (ms *MongoTransaction) FindCollection(Limit int64) (result *mongo.Cursor, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
findOptions := options.Find()
findOptions.SetLimit(Limit)
result, err = collection.Find(ms.Session, bson.D{{}}, findOptions)
return result, err
}
// FindCollectionSort 查询集合文档支持排序
func (ms *MongoTransaction) FindCollectionSort(Sort interface{}, Limit int64) (result *mongo.Cursor, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
findOptions := options.Find()
findOptions.SetSort(Sort)
findOptions.SetLimit(Limit)
result, err = collection.Find(ms.Session, bson.D{{}}, findOptions)
return result, err
}
// FindManyCollectionSort 查询集合文档支持排序支持条件
func (ms *MongoTransaction) FindManyCollectionSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
findOptions := options.Find()
findOptions.SetSort(Sort)
result, err = collection.Find(ms.Session, filter, findOptions)
return result, err
}
// CollectionCount 查询集合里有多少数据
func (ms *MongoTransaction) CollectionCount(ctx context.Context) (name string, size int64) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
name = collection.Name()
size, _ = collection.EstimatedDocumentCount(ctx)
return name, size
}
// CollectionDocuments 按选项查询集合
// Skip 跳过
// Limit 读取数量
// sort 1 -1 . 1 为升序 -1 为降序
func (ms *MongoTransaction) CollectionDocuments(Skip, Limit int64, sort int, key string, value interface{}) (result *mongo.Cursor, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
SORT := bson.D{{"_id", sort}}
filter := bson.D{{key, value}}
findOptions := options.Find().SetSort(SORT).SetLimit(Limit).SetSkip(Skip)
result, err = collection.Find(ms.Session, filter, findOptions)
return result, err
}
// AggregateByFiltersSort 统计分析
func (ms *MongoTransaction) AggregateByFiltersSort(pipeline interface{}) (result *mongo.Cursor, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
result, err = collection.Aggregate(ms.Session, pipeline)
return result, err
}
// CountDocumentsByFilters 统计数量
func (ms *MongoTransaction) CountDocumentsByFilters(filter interface{}) (count int64, err error) {
collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName)
count, err = collection.CountDocuments(ms.Session, filter)
return count, err
}

@ -1,11 +0,0 @@
package dorm
// 获取库名
func (ms *MongoTransaction) getDatabaseName() string {
return ms.databaseName
}
// 获取表名
func (ms *MongoTransaction) getCollectionName() string {
return ms.collectionName
}

@ -1,28 +0,0 @@
package dorm
import "reflect"
// Database 设置库名
func (ms *MongoTransaction) Database(databaseName string) *MongoTransaction {
ms.databaseName = databaseName
return ms
}
// Collection 设置表名
func (ms *MongoTransaction) Collection(collectionName string) *MongoTransaction {
ms.collectionName = collectionName
return ms
}
// Model 传入模型自动获取库名和表名
func (ms *MongoTransaction) Model(value interface{}) *MongoTransaction {
// https://studygolang.com/articles/896
val := reflect.ValueOf(value)
if methodValue := val.MethodByName("Database"); methodValue.IsValid() {
ms.databaseName = methodValue.Call(nil)[0].String()
}
if methodValue := val.MethodByName("TableName"); methodValue.IsValid() {
ms.collectionName = methodValue.Call(nil)[0].String()
}
return ms
}

@ -1 +0,0 @@
package dorm

@ -1 +0,0 @@
package dorm

@ -1 +0,0 @@
package dorm

@ -4,11 +4,14 @@ import (
"context"
"errors"
"fmt"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
"time"
)
type ConfigRedisClient struct {
// RedisClientFun *RedisClient 驱动
type RedisClientFun func() *RedisClient
type RedisClientConfig struct {
Addr string // 地址
Password string // 密码
DB int // 数据库
@ -19,30 +22,55 @@ type ConfigRedisClient struct {
// RedisClient
// https://redis.uptrace.dev/
type RedisClient struct {
Db *redis.Client // 驱动
config *ConfigRedisClient // 配置
db *redis.Client // 驱动
}
func NewRedisClient(config *ConfigRedisClient) (*RedisClient, error) {
// NewRedisClient 创建实例
func NewRedisClient(config *RedisClientConfig) (*RedisClient, error) {
c := &RedisClient{}
c.config = config
if c.config.PoolSize == 0 {
c.config.PoolSize = 100
if config.PoolSize == 0 {
config.PoolSize = 100
}
c.Db = redis.NewClient(&redis.Options{
Addr: c.config.Addr, // 地址
Password: c.config.Password, // 密码
DB: c.config.DB, // 数据库
PoolSize: c.config.PoolSize, // 连接池大小
ReadTimeout: c.config.ReadTimeout, // 读取超时
c.db = redis.NewClient(&redis.Options{
Addr: config.Addr, // 地址
Password: config.Password, // 密码
DB: config.DB, // 数据库
PoolSize: config.PoolSize, // 连接池大小
ReadTimeout: config.ReadTimeout, // 读取超时
})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := c.Db.Ping(ctx).Result()
// 检测 Redis 连接是否正常连接
_, err := c.db.Ping(ctx).Result()
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
return c, nil
}
// NewRedisClientURL 创建实例
func NewRedisClientURL(redisURL string) (*RedisClient, error) {
c := &RedisClient{}
opt, err := redis.ParseURL(redisURL)
if err != nil {
return c, nil
}
// 创建 Redis 客户端
c.db = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 检测 Redis 连接是否正常连接
_, err = c.db.Ping(ctx).Result()
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}

@ -2,7 +2,7 @@ package dorm
import (
"context"
"encoding/json"
"go.dtapp.net/gojson"
"time"
)
@ -62,7 +62,7 @@ func (rc *RedisClientCache) GetString(ctx context.Context, key string) (ret stri
func (rc *RedisClientCache) GetInterface(ctx context.Context, key string, result interface{}) {
f := func() string {
marshal, _ := json.Marshal(rc.GetterInterface())
marshal, _ := gojson.Marshal(rc.GetterInterface())
return string(marshal)
}
@ -74,7 +74,7 @@ func (rc *RedisClientCache) GetInterface(ctx context.Context, key string, result
ret, _ = rc.operation.Get(ctx, key).Result()
}
err = json.Unmarshal([]byte(ret), result)
err = gojson.Unmarshal([]byte(ret), result)
return
}
@ -85,12 +85,12 @@ func (rc *RedisClientCache) GetInterfaceKey(ctx context.Context, key string, res
if err != nil {
return err
}
err = json.Unmarshal([]byte(ret), result)
err = gojson.Unmarshal([]byte(ret), result)
return nil
}
// SetInterfaceKey 设置key值
func (rc *RedisClientCache) SetInterfaceKey(ctx context.Context, key string, value interface{}) (string, error) {
marshal, _ := json.Marshal(value)
marshal, _ := gojson.Marshal(value)
return rc.operation.Set(ctx, key, marshal, rc.defaultExpiration).Result()
}

@ -2,69 +2,69 @@ package dorm
import (
"context"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
"time"
)
// Set 设置一个key的值
func (r *RedisClient) Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd {
return r.Db.Set(ctx, key, value, expiration)
return r.db.Set(ctx, key, value, expiration)
}
// Get 查询key的值
func (r *RedisClient) Get(ctx context.Context, key string) *redis.StringCmd {
return r.Db.Get(ctx, key)
return r.db.Get(ctx, key)
}
// GetSet 设置一个key的值并返回这个key的旧值
func (r *RedisClient) GetSet(ctx context.Context, key string, value interface{}) *redis.StringCmd {
return r.Db.GetSet(ctx, key, value)
return r.db.GetSet(ctx, key, value)
}
// SetNX 如果key不存在则设置这个key的值
func (r *RedisClient) SetNX(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.BoolCmd {
return r.Db.SetNX(ctx, key, value, expiration)
return r.db.SetNX(ctx, key, value, expiration)
}
// MGet 批量查询key的值
func (r *RedisClient) MGet(ctx context.Context, keys ...string) *redis.SliceCmd {
return r.Db.MGet(ctx, keys...)
return r.db.MGet(ctx, keys...)
}
// MSet 批量设置key的值
// MSet(map[string]interface{}{"key1": "value1", "key2": "value2"})
func (r *RedisClient) MSet(ctx context.Context, values map[string]interface{}) *redis.StatusCmd {
return r.Db.MSet(ctx, values)
return r.db.MSet(ctx, values)
}
// Incr 针对一个key的数值进行递增操作
func (r *RedisClient) Incr(ctx context.Context, key string) *redis.IntCmd {
return r.Db.Incr(ctx, key)
return r.db.Incr(ctx, key)
}
// IncrBy 针对一个key的数值进行递增操作指定每次递增多少
func (r *RedisClient) IncrBy(ctx context.Context, key string, value int64) *redis.IntCmd {
return r.Db.IncrBy(ctx, key, value)
return r.db.IncrBy(ctx, key, value)
}
// Decr 针对一个key的数值进行递减操作
func (r *RedisClient) Decr(ctx context.Context, key string) *redis.IntCmd {
return r.Db.Decr(ctx, key)
return r.db.Decr(ctx, key)
}
// DecrBy 针对一个key的数值进行递减操作指定每次递减多少
func (r *RedisClient) DecrBy(ctx context.Context, key string, value int64) *redis.IntCmd {
return r.Db.DecrBy(ctx, key, value)
return r.db.DecrBy(ctx, key, value)
}
// Del 删除key操作支持批量删除
func (r *RedisClient) Del(ctx context.Context, keys ...string) *redis.IntCmd {
return r.Db.Del(ctx, keys...)
return r.db.Del(ctx, keys...)
}
// Keys 按前缀获取所有key名
func (r *RedisClient) Keys(ctx context.Context, prefix string) []string {
values, _ := r.Db.Keys(ctx, prefix).Result()
values, _ := r.db.Keys(ctx, prefix).Result()
keys := make([]string, 0, len(values))
if len(values) <= 0 {
return keys
@ -77,7 +77,7 @@ func (r *RedisClient) Keys(ctx context.Context, prefix string) []string {
// KeysValue 按前缀获取所有key值
func (r *RedisClient) KeysValue(ctx context.Context, prefix string) *redis.SliceCmd {
values, _ := r.Db.Keys(ctx, prefix).Result()
values, _ := r.db.Keys(ctx, prefix).Result()
if len(values) <= 0 {
return &redis.SliceCmd{}
}
@ -87,3 +87,72 @@ func (r *RedisClient) KeysValue(ctx context.Context, prefix string) *redis.Slice
}
return r.MGet(ctx, keys...)
}
// AddKeyToSet 将值添加到集合
func (r *RedisClient) AddKeyToSet(ctx context.Context, setKey string, value interface{}) error {
return r.db.SAdd(ctx, setKey, value).Err()
}
// GetAllKeysInSet 获取集合的所有元素
func (r *RedisClient) GetAllKeysInSet(ctx context.Context, setKey string) ([]string, error) {
return r.db.SMembers(ctx, setKey).Result()
}
// DoesKeyExistInSet 检查值是否存在于集合中
func (r *RedisClient) DoesKeyExistInSet(ctx context.Context, setKey string, targetKey interface{}) (bool, error) {
return r.db.SIsMember(ctx, setKey, targetKey).Result()
}
// RemoveKeyFromSet 从集合中删除指定的元素
func (r *RedisClient) RemoveKeyFromSet(ctx context.Context, setKey string, targetKey interface{}) error {
return r.db.SRem(ctx, setKey, targetKey).Err()
}
// DeleteKeysWithPrefix 根据前缀删除key
func (r *RedisClient) DeleteKeysWithPrefix(ctx context.Context, prefix string) error {
// 获取所有符合给定模式的键
keys, err := r.db.Keys(ctx, prefix+"*").Result()
if err != nil {
return err
}
// 删除所有匹配的键
if len(keys) > 0 {
_, err := r.db.Del(ctx, keys...).Result()
if err != nil {
return err
}
}
return nil
}
// DeleteScanWithPrefix 根据前缀删除key
func (r *RedisClient) DeleteScanWithPrefix(ctx context.Context, prefix string) error {
var cursor uint64
for {
// 使用 SCAN 迭代获取匹配模式的键
keys, nextCursor, err := r.db.Scan(ctx, cursor, prefix+"*", 10).Result()
if err != nil {
return err
}
// 删除当前迭代返回的键
if len(keys) > 0 {
_, err := r.db.Del(ctx, keys...).Result()
if err != nil {
return err
}
}
// 更新游标,继续迭代
cursor = nextCursor
// 如果迭代结束,退出循环
if cursor == 0 {
break
}
}
return nil
}

@ -2,30 +2,30 @@ package dorm
import (
"context"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
)
// Subscribe 订阅channel
func (r *RedisClient) Subscribe(ctx context.Context, channels ...string) *redis.PubSub {
return r.Db.Subscribe(ctx, channels...)
return r.db.Subscribe(ctx, channels...)
}
// PSubscribe 订阅channel支持通配符匹配
func (r *RedisClient) PSubscribe(ctx context.Context, channels ...string) *redis.PubSub {
return r.Db.PSubscribe(ctx, channels...)
return r.db.PSubscribe(ctx, channels...)
}
// Publish 将信息发送到指定的channel
func (r *RedisClient) Publish(ctx context.Context, channel string, message interface{}) *redis.IntCmd {
return r.Db.Publish(ctx, channel, message)
return r.db.Publish(ctx, channel, message)
}
// PubSubChannels 查询活跃的channel
func (r *RedisClient) PubSubChannels(ctx context.Context, pattern string) *redis.StringSliceCmd {
return r.Db.PubSubChannels(ctx, pattern)
return r.db.PubSubChannels(ctx, pattern)
}
// PubSubNumSub 查询指定的channel有多少个订阅者
func (r *RedisClient) PubSubNumSub(ctx context.Context, channels ...string) *redis.StringIntMapCmd {
return r.Db.PubSubNumSub(ctx, channels...)
func (r *RedisClient) PubSubNumSub(ctx context.Context, channels ...string) *redis.MapStringIntCmd {
return r.db.PubSubNumSub(ctx, channels...)
}

@ -1,8 +1,8 @@
package dorm
import "github.com/go-redis/redis/v9"
import "github.com/redis/go-redis/v9"
// GetDb 获取驱动
func (r *RedisClient) GetDb() *redis.Client {
return r.Db
return r.db
}

@ -2,9 +2,10 @@ package dorm
import (
"context"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
)
// HashOperation Hash类型数据操作
type HashOperation struct {
db *redis.Client
}

@ -2,9 +2,10 @@ package dorm
import (
"context"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
)
// ListOperation 列表(list)类型数据操作
type ListOperation struct {
db *redis.Client
ctx context.Context
@ -12,7 +13,7 @@ type ListOperation struct {
// NewListOperation 列表(list)类型数据操作 https://www.tizi365.com/archives/299.html
func (r *RedisClient) NewListOperation() *ListOperation {
return &ListOperation{db: r.Db, ctx: context.Background()}
return &ListOperation{db: r.db, ctx: context.Background()}
}
// LPush 从列表左边插入数据

@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
"time"
)

@ -2,7 +2,7 @@ package dorm
import (
"context"
"encoding/json"
"go.dtapp.net/gojson"
"time"
)
@ -43,7 +43,7 @@ func (c *SimpleCache) GetCache(ctx context.Context, key string) (ret interface{}
if c.Serializer == SerializerJson {
f := func() string {
obj := c.JsonGetter()
b, err := json.Marshal(obj)
b, err := gojson.Marshal(obj)
if err != nil {
return ""
}

@ -2,7 +2,7 @@ package dorm
import (
"context"
"encoding/json"
"go.dtapp.net/gojson"
"time"
)
@ -32,7 +32,7 @@ func (c *SimpleJsonCache) SetCache(ctx context.Context, key string, value interf
func (c *SimpleJsonCache) GetCache(ctx context.Context, key string) (ret interface{}) {
f := func() string {
obj := c.DBGetter()
b, err := json.Marshal(obj)
b, err := gojson.Marshal(obj)
if err != nil {
return ""
}

@ -2,7 +2,7 @@ package dorm
import (
"context"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
"time"
)
@ -11,7 +11,7 @@ type SimpleOperation struct {
}
func (r *RedisClient) NewSimpleOperation() *SimpleOperation {
return &SimpleOperation{db: r.Db}
return &SimpleOperation{db: r.db}
}
// Set 设置

@ -2,7 +2,7 @@ package dorm
import (
"context"
"github.com/go-redis/redis/v9"
"github.com/redis/go-redis/v9"
"time"
)
@ -11,7 +11,7 @@ type StringOperation struct {
}
func (r *RedisClient) NewStringOperation() *StringOperation {
return &StringOperation{db: r.Db}
return &StringOperation{db: r.db}
}
// Set 设置

@ -1 +0,0 @@
package dorm

@ -1 +0,0 @@
package dorm

@ -1 +0,0 @@
package dorm

@ -1 +0,0 @@
package dorm

@ -1 +0,0 @@
package dorm

@ -0,0 +1,76 @@
package dorm
// Bool 复制 bool 对象,并返回复制体
func Bool(b bool) bool {
return b
}
// Uint8 复制 int8 对象,并返回复制体
func Uint8(ui uint8) uint8 {
return ui
}
// Uint16 复制 uint16 对象,并返回复制体
func Uint16(ui uint16) uint16 {
return ui
}
// Uint32 复制 uint32 对象,并返回复制体
func Uint32(ui uint32) uint32 {
return ui
}
// Uint64 复制 uint64 对象,并返回复制体
func Uint64(ui uint64) uint64 {
return ui
}
// Int8 复制 int8 对象,并返回复制体
func Int8(i int8) int8 {
return i
}
// Int16 复制 int16 对象,并返回复制体
func Int16(i int16) int16 {
return i
}
// Int32 复制 int64 对象,并返回复制体
func Int32(i int32) int32 {
return i
}
// Int64 复制 int64 对象,并返回复制体
func Int64(i int64) int64 {
return i
}
// Float32 复制 float32 对象,并返回复制体
func Float32(f float32) float32 {
return f
}
// Float64 复制 float64 对象,并返回复制体
func Float64(f float64) float64 {
return f
}
// String 复制 string 对象,并返回复制体
func String(s string) string {
return s
}
// Int 复制 int 对象,并返回复制体
func Int(i int) int {
return i
}
// Uint 复制 uint 对象,并返回复制体
func Uint(ui uint) uint {
return ui
}
// Any 复制 any 对象,并返回复制体
func Any(a any) any {
return a
}

@ -1,9 +0,0 @@
package dorm
import "github.com/upper/db/v4"
// UpperClient
// https://upper.io/
type UpperClient struct {
Db *db.Session // 驱动
}

@ -1,10 +0,0 @@
package dorm
import (
"github.com/upper/db/v4"
)
// GetDb 获取驱动
func (c *UpperClient) GetDb() *db.Session {
return c.Db
}

@ -1,23 +0,0 @@
package dorm
import (
"errors"
"fmt"
"github.com/upper/db/v4/adapter/mysql"
)
func NewUpperMysqlClient(settings mysql.ConnectionURL) (*UpperClient, error) {
var err error
c := &UpperClient{}
sess, err := mysql.Open(settings)
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
defer sess.Close()
c.Db = &sess
return c, nil
}

@ -1,23 +0,0 @@
package dorm
import (
"errors"
"fmt"
"github.com/upper/db/v4/adapter/postgresql"
)
func NewUpperPostgresqlClient(settings postgresql.ConnectionURL) (*UpperClient, error) {
var err error
c := &UpperClient{}
sess, err := postgresql.Open(settings)
if err != nil {
return nil, errors.New(fmt.Sprintf("连接失败:%v", err))
}
defer sess.Close()
c.Db = &sess
return c, nil
}

@ -1,25 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
/.tags

@ -1,21 +0,0 @@
The MIT License (MIT)
Copyright (c) 2016 Bastien Gysler
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -1,82 +0,0 @@
# goxml2json [![CircleCI](https://circleci.com/gh/basgys/goxml2json.svg?style=svg)](https://circleci.com/gh/basgys/goxml2json)
Go package that converts XML to JSON
### Install
go get -u github.com/basgys/goxml2json
### Importing
import github.com/basgys/goxml2json
### Usage
**Code example**
```go
package main
import (
"fmt"
"strings"
xj "github.com/basgys/goxml2json"
)
func main() {
// xml is an io.Reader
xml := strings.NewReader(`<?xml version="1.0" encoding="UTF-8"?><hello>world</hello>`)
json, err := xj.Convert(xml)
if err != nil {
panic("That's embarrassing...")
}
fmt.Println(json.String())
// {"hello": "world"}
}
```
**Input**
```xml
<?xml version="1.0" encoding="UTF-8"?>
<osm version="0.6" generator="CGImap 0.0.2">
<bounds minlat="54.0889580" minlon="12.2487570" maxlat="54.0913900" maxlon="12.2524800"/>
<foo>bar</foo>
</osm>
```
**Output**
```json
{
"osm": {
"-version": "0.6",
"-generator": "CGImap 0.0.2",
"bounds": {
"-minlat": "54.0889580",
"-minlon": "12.2487570",
"-maxlat": "54.0913900",
"-maxlon": "12.2524800"
},
"foo": "bar"
}
}
```
### Contributing
Feel free to contribute to this project if you want to fix/extend/improve it.
### Contributors
- [DirectX](https://github.com/directx)
- [samuelhug](https://github.com/samuelhug)
### TODO
* Extract data types in JSON (numbers, boolean, ...)
* Categorise errors
* Option to prettify the JSON output
* Benchmark

@ -1,25 +0,0 @@
package xml2json
import (
"bytes"
"io"
)
// Convert converts the given XML document to JSON
func Convert(r io.Reader) (*bytes.Buffer, error) {
// Decode XML document
root := &Node{}
err := NewDecoder(r).Decode(root)
if err != nil {
return nil, err
}
// Then encode it in JSON
buf := new(bytes.Buffer)
err = NewEncoder(buf).Encode(root)
if err != nil {
return nil, err
}
return buf, nil
}

@ -1,140 +0,0 @@
package xml2json
import (
"encoding/xml"
"io"
"unicode"
"golang.org/x/net/html/charset"
)
const (
attrPrefix = "-"
contentPrefix = "#"
)
// A Decoder reads and decodes XML objects from an input stream.
type Decoder struct {
r io.Reader
err error
attributePrefix string
contentPrefix string
}
type element struct {
parent *element
n *Node
label string
}
func (dec *Decoder) SetAttributePrefix(prefix string) {
dec.attributePrefix = prefix
}
func (dec *Decoder) SetContentPrefix(prefix string) {
dec.contentPrefix = prefix
}
func (dec *Decoder) DecodeWithCustomPrefixes(root *Node, contentPrefix string, attributePrefix string) error {
dec.contentPrefix = contentPrefix
dec.attributePrefix = attributePrefix
return dec.Decode(root)
}
// NewDecoder returns a new decoder that reads from r.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
func (dec *Decoder) Decode(root *Node) error {
if dec.contentPrefix == "" {
dec.contentPrefix = contentPrefix
}
if dec.attributePrefix == "" {
dec.attributePrefix = attrPrefix
}
xmlDec := xml.NewDecoder(dec.r)
// That will convert the charset if the provided XML is non-UTF-8
xmlDec.CharsetReader = charset.NewReaderLabel
// Create first element from the root node
elem := &element{
parent: nil,
n: root,
}
for {
t, _ := xmlDec.Token()
if t == nil {
break
}
switch se := t.(type) {
case xml.StartElement:
// Build new a new current element and link it to its parent
elem = &element{
parent: elem,
n: &Node{},
label: se.Name.Local,
}
// Extract attributes as children
for _, a := range se.Attr {
elem.n.AddChild(dec.attributePrefix+a.Name.Local, &Node{Data: a.Value})
}
case xml.CharData:
// Extract XML data (if any)
elem.n.Data = trimNonGraphic(string(xml.CharData(se)))
case xml.EndElement:
// And add it to its parent list
if elem.parent != nil {
elem.parent.n.AddChild(elem.label, elem.n)
}
// Then change the current element to its parent
elem = elem.parent
}
}
return nil
}
// trimNonGraphic returns a slice of the string s, with all leading and trailing
// non graphic characters and spaces removed.
//
// Graphic characters include letters, marks, numbers, punctuation, symbols,
// and spaces, from categories L, M, N, P, S, Zs.
// Spacing characters are set by category Z and property Pattern_White_Space.
func trimNonGraphic(s string) string {
if s == "" {
return s
}
var first *int
var last int
for i, r := range []rune(s) {
if !unicode.IsGraphic(r) || unicode.IsSpace(r) {
continue
}
if first == nil {
f := i // copy i
first = &f
last = i
} else {
last = i
}
}
// If first is nil, it means there are no graphic characters
if first == nil {
return ""
}
return string([]rune(s)[*first : last+1])
}

@ -1,2 +0,0 @@
// Package xml2json is an XML to JSON converter
package xml2json

@ -1,197 +0,0 @@
package xml2json
import (
"bytes"
"io"
"unicode/utf8"
)
// An Encoder writes JSON objects to an output stream.
type Encoder struct {
w io.Writer
err error
contentPrefix string
attributePrefix string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w}
}
func (enc *Encoder) SetAttributePrefix(prefix string) {
enc.attributePrefix = prefix
}
func (enc *Encoder) SetContentPrefix(prefix string) {
enc.contentPrefix = prefix
}
func (enc *Encoder) EncodeWithCustomPrefixes(root *Node, contentPrefix string, attributePrefix string) error {
enc.contentPrefix = contentPrefix
enc.attributePrefix = attributePrefix
return enc.Encode(root)
}
// Encode writes the JSON encoding of v to the stream
func (enc *Encoder) Encode(root *Node) error {
if enc.err != nil {
return enc.err
}
if root == nil {
return nil
}
if enc.contentPrefix == "" {
enc.contentPrefix = contentPrefix
}
if enc.attributePrefix == "" {
enc.attributePrefix = attrPrefix
}
enc.err = enc.format(root, 0)
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
enc.write("\n")
return enc.err
}
func (enc *Encoder) format(n *Node, lvl int) error {
if n.IsComplex() {
enc.write("{")
// Add data as an additional attibute (if any)
if len(n.Data) > 0 {
enc.write("\"")
enc.write(enc.contentPrefix)
enc.write("content")
enc.write("\": ")
enc.write(sanitiseString(n.Data))
enc.write(", ")
}
i := 0
tot := len(n.Children)
for label, children := range n.Children {
enc.write("\"")
enc.write(label)
enc.write("\": ")
if len(children) > 1 {
// Array
enc.write("[")
for j, c := range children {
enc.format(c, lvl+1)
if j < len(children)-1 {
enc.write(", ")
}
}
enc.write("]")
} else {
// Map
enc.format(children[0], lvl+1)
}
if i < tot-1 {
enc.write(", ")
}
i++
}
enc.write("}")
} else {
// TODO : Extract data type
enc.write(sanitiseString(n.Data))
}
return nil
}
func (enc *Encoder) write(s string) {
enc.w.Write([]byte(s))
}
// https://golang.org/src/encoding/json/encode.go?s=5584:5627#L788
var hex = "0123456789abcdef"
func sanitiseString(s string) string {
var buf bytes.Buffer
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
default:
// This encodes bytes < 0x20 except for \n and \r,
// as well as <, > and &. The latter are escaped because they
// can lead to security holes when user-controlled strings
// are rendered into JSON and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hex[b>>4])
buf.WriteByte(hex[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hex[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
return buf.String()
}

@ -1,25 +0,0 @@
package xml2json
// Node is a data element on a tree
type Node struct {
Children map[string]Nodes
Data string
}
// Nodes is a list of nodes
type Nodes []*Node
// AddChild appends a node to the list of children
func (n *Node) AddChild(s string, c *Node) {
// Lazy lazy
if n.Children == nil {
n.Children = map[string]Nodes{}
}
n.Children[s] = append(n.Children[s], c)
}
// IsComplex returns whether it is a complex type (has children)
func (n *Node) IsComplex() bool {
return len(n.Children) > 0
}

@ -1,13 +0,0 @@
Copyright 2014 astaxie
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

@ -1,159 +0,0 @@
# beego orm
[![Build Status](https://drone.io/github.com/beego/beego/v2/status.png)](https://drone.io/github.com/beego/beego/v2/latest)
A powerful orm framework for go.
It is heavily influenced by Django ORM, SQLAlchemy.
**Support Database:**
* MySQL: [github.com/go-sql-driver/mysql](https://github.com/go-sql-driver/mysql)
* PostgreSQL: [github.com/lib/pq](https://github.com/lib/pq)
* Sqlite3: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3)
Passed all test, but need more feedback.
**Features:**
* full go type support
* easy for usage, simple CRUD operation
* auto join with relation table
* cross DataBase compatible query
* Raw SQL query / mapper without orm model
* full test keep stable and strong
more features please read the docs
**Install:**
go get github.com/beego/beego/v2/client/orm
## Changelog
* 2013-08-19: support table auto create
* 2013-08-13: update test for database types
* 2013-08-13: go type support, such as int8, uint8, byte, rune
* 2013-08-13: date / datetime timezone support very well
## Quick Start
#### Simple Usage
```go
package main
import (
"fmt"
"github.com/beego/beego/v2/client/orm"
_ "github.com/go-sql-driver/mysql" // import your used driver
)
// Model Struct
type User struct {
Id int `orm:"auto"`
Name string `orm:"size(100)"`
}
func init() {
// register model
orm.RegisterModel(new(User))
// set default database
orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
// create table
orm.RunSyncdb("default", false, true)
}
func main() {
o := orm.NewOrm()
user := User{Name: "slene"}
// insert
id, err := o.Insert(&user)
// update
user.Name = "astaxie"
num, err := o.Update(&user)
// read one
u := User{Id: user.Id}
err = o.Read(&u)
// delete
num, err = o.Delete(&u)
}
```
#### Next with relation
```go
type Post struct {
Id int `orm:"auto"`
Title string `orm:"size(100)"`
User *User `orm:"rel(fk)"`
}
var posts []*Post
qs := o.QueryTable("post")
num, err := qs.Filter("User__Name", "slene").All(&posts)
```
#### Use Raw sql
If you don't like ORMuse Raw SQL to query / mapping without ORM setting
```go
var maps []Params
num, err := o.Raw("SELECT id FROM user WHERE name = ?", "slene").Values(&maps)
if num > 0 {
fmt.Println(maps[0]["id"])
}
```
#### Transaction
```go
o.Begin()
...
user := User{Name: "slene"}
id, err := o.Insert(&user)
if err == nil {
o.Commit()
} else {
o.Rollback()
}
```
#### Debug Log Queries
In development env, you can simple use
```go
func main() {
orm.Debug = true
...
```
enable log queries.
output include all queries, such as exec / prepare / transaction.
like this:
```go
[ORM] - 2013-08-09 13:18:16 - [Queries/default] - [ db.Exec / 0.4ms] - [INSERT INTO `user` (`name`) VALUES (?)] - `slene`
...
```
note: not recommend use this in product env.
## Docs
more details and examples in docs and test
[documents](http://beego.vip/docs/mvc/model/overview.md)

@ -1,6 +0,0 @@
package clauses
const (
ExprSep = "__"
ExprDot = "."
)

@ -1,104 +0,0 @@
package order_clause
import (
"strings"
"github.com/beego/beego/v2/client/orm/clauses"
)
type Sort int8
const (
None Sort = 0
Ascending Sort = 1
Descending Sort = 2
)
type Option func(order *Order)
type Order struct {
column string
sort Sort
isRaw bool
}
func Clause(options ...Option) *Order {
o := &Order{}
for _, option := range options {
option(o)
}
return o
}
func (o *Order) GetColumn() string {
return o.column
}
func (o *Order) GetSort() Sort {
return o.sort
}
func (o *Order) SortString() string {
switch o.GetSort() {
case Ascending:
return "ASC"
case Descending:
return "DESC"
}
return ``
}
func (o *Order) IsRaw() bool {
return o.isRaw
}
func ParseOrder(expressions ...string) []*Order {
var orders []*Order
for _, expression := range expressions {
sort := Ascending
column := strings.ReplaceAll(expression, clauses.ExprSep, clauses.ExprDot)
if column[0] == '-' {
sort = Descending
column = column[1:]
}
orders = append(orders, &Order{
column: column,
sort: sort,
})
}
return orders
}
func Column(column string) Option {
return func(order *Order) {
order.column = strings.ReplaceAll(column, clauses.ExprSep, clauses.ExprDot)
}
}
func sort(sort Sort) Option {
return func(order *Order) {
order.sort = sort
}
}
func SortAscending() Option {
return sort(Ascending)
}
func SortDescending() Option {
return sort(Descending)
}
func SortNone() Option {
return sort(None)
}
func Raw() Option {
return func(order *Order) {
order.isRaw = true
}
}

@ -1,299 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"flag"
"fmt"
"os"
"strings"
)
type commander interface {
Parse([]string)
Run() error
}
var commands = make(map[string]commander)
// print help.
func printHelp(errs ...string) {
content := `orm command usage:
syncdb - auto create tables
sqlall - print sql of create tables
help - print this help
`
if len(errs) > 0 {
fmt.Println(errs[0])
}
fmt.Println(content)
os.Exit(2)
}
// RunCommand listens for orm command and runs if command arguments have been passed.
func RunCommand() {
if len(os.Args) < 2 || os.Args[1] != "orm" {
return
}
BootStrap()
args := argString(os.Args[2:])
name := args.Get(0)
if name == "help" {
printHelp()
}
if cmd, ok := commands[name]; ok {
cmd.Parse(os.Args[3:])
cmd.Run()
os.Exit(0)
} else {
if name == "" {
printHelp()
} else {
printHelp(fmt.Sprintf("unknown command %s", name))
}
}
}
// sync database struct command interface.
type commandSyncDb struct {
al *alias
force bool
verbose bool
noInfo bool
rtOnError bool
}
// Parse the orm command line arguments.
func (d *commandSyncDb) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: syncdb", flag.ExitOnError)
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
flagSet.BoolVar(&d.force, "force", false, "drop tables before create")
flagSet.BoolVar(&d.verbose, "v", false, "verbose info")
flagSet.Parse(args)
d.al = getDbAlias(name)
}
// Run orm line command.
func (d *commandSyncDb) Run() error {
var drops []string
var err error
if d.force {
drops, err = defaultModelCache.getDbDropSQL(d.al)
if err != nil {
return err
}
}
db := d.al.DB
if d.force && len(drops) > 0 {
for i, mi := range defaultModelCache.allOrdered() {
query := drops[i]
if !d.noInfo {
fmt.Printf("drop table `%s`\n", mi.table)
}
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
}
createQueries, indexes, err := defaultModelCache.getDbCreateSQL(d.al)
if err != nil {
return err
}
tables, err := d.al.DbBaser.GetTables(db)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
ctx := context.Background()
for i, mi := range defaultModelCache.allOrdered() {
if !isApplicableTableForDB(mi.addrField, d.al.Name) {
fmt.Printf("table `%s` is not applicable to database '%s'\n", mi.table, d.al.Name)
continue
}
if tables[mi.table] {
if !d.noInfo {
fmt.Printf("table `%s` already exists, skip\n", mi.table)
}
var fields []*fieldInfo
columns, err := d.al.DbBaser.GetColumns(ctx, db, mi.table)
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
for _, fi := range mi.fields.fieldsDB {
if _, ok := columns[fi.column]; !ok {
fields = append(fields, fi)
}
}
for _, fi := range fields {
query := getColumnAddQuery(d.al, fi)
if !d.noInfo {
fmt.Printf("add column `%s` for table `%s`\n", fi.fullName, mi.table)
}
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
for _, idx := range indexes[mi.table] {
if !d.al.DbBaser.IndexExists(ctx, db, idx.Table, idx.Name) {
if !d.noInfo {
fmt.Printf("create index `%s` for table `%s`\n", idx.Name, idx.Table)
}
query := idx.SQL
_, err := db.Exec(query)
if d.verbose {
fmt.Printf(" %s\n", query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
}
continue
}
if !d.noInfo {
fmt.Printf("create table `%s` \n", mi.table)
}
queries := []string{createQueries[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.SQL)
}
for _, query := range queries {
_, err := db.Exec(query)
if d.verbose {
query = " " + strings.Join(strings.Split(query, "\n"), "\n ")
fmt.Println(query)
}
if err != nil {
if d.rtOnError {
return err
}
fmt.Printf(" %s\n", err.Error())
}
}
if d.verbose {
fmt.Println("")
}
}
return nil
}
// database creation commander interface implement.
type commandSQLAll struct {
al *alias
}
// Parse orm command line arguments.
func (d *commandSQLAll) Parse(args []string) {
var name string
flagSet := flag.NewFlagSet("orm command: sqlall", flag.ExitOnError)
flagSet.StringVar(&name, "db", "default", "DataBase alias name")
flagSet.Parse(args)
d.al = getDbAlias(name)
}
// Run orm line command.
func (d *commandSQLAll) Run() error {
createQueries, indexes, err := defaultModelCache.getDbCreateSQL(d.al)
if err != nil {
return err
}
var all []string
for i, mi := range defaultModelCache.allOrdered() {
queries := []string{createQueries[i]}
for _, idx := range indexes[mi.table] {
queries = append(queries, idx.SQL)
}
sql := strings.Join(queries, "\n")
all = append(all, sql)
}
fmt.Println(strings.Join(all, "\n\n"))
return nil
}
func init() {
commands["syncdb"] = new(commandSyncDb)
commands["sqlall"] = new(commandSQLAll)
}
// RunSyncdb run syncdb command line.
// name: Table's alias name (default is "default")
// force: Run the next sql command even if the current gave an error
// verbose: Print all information, useful for debugging
func RunSyncdb(name string, force bool, verbose bool) error {
BootStrap()
al := getDbAlias(name)
cmd := new(commandSyncDb)
cmd.al = al
cmd.force = force
cmd.noInfo = !verbose
cmd.verbose = verbose
cmd.rtOnError = true
return cmd.Run()
}

@ -1,169 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strings"
)
type dbIndex struct {
Table string
Name string
SQL string
}
// get database column type string.
func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
T := al.DbBaser.DbTypes()
fieldType := fi.fieldType
fieldSize := fi.size
checkColumn:
switch fieldType {
case TypeBooleanField:
col = T["bool"]
case TypeVarCharField:
if al.Driver == DRPostgres && fi.toText {
col = T["string-text"]
} else {
col = fmt.Sprintf(T["string"], fieldSize)
}
case TypeCharField:
col = fmt.Sprintf(T["string-char"], fieldSize)
case TypeTextField:
col = T["string-text"]
case TypeTimeField:
col = T["time.Time-clock"]
case TypeDateField:
col = T["time.Time-date"]
case TypeDateTimeField:
// the precision of sqlite is not implemented
if al.Driver == 2 || fi.timePrecision == nil {
col = T["time.Time"]
} else {
s := T["time.Time-precision"]
col = fmt.Sprintf(s, *fi.timePrecision)
}
case TypeBitField:
col = T["int8"]
case TypeSmallIntegerField:
col = T["int16"]
case TypeIntegerField:
col = T["int32"]
case TypeBigIntegerField:
if al.Driver == DRSqlite {
fieldType = TypeIntegerField
goto checkColumn
}
col = T["int64"]
case TypePositiveBitField:
col = T["uint8"]
case TypePositiveSmallIntegerField:
col = T["uint16"]
case TypePositiveIntegerField:
col = T["uint32"]
case TypePositiveBigIntegerField:
col = T["uint64"]
case TypeFloatField:
col = T["float64"]
case TypeDecimalField:
s := T["float64-decimal"]
if !strings.Contains(s, "%d") {
col = s
} else {
col = fmt.Sprintf(s, fi.digits, fi.decimals)
}
case TypeJSONField:
if al.Driver != DRPostgres {
fieldType = TypeVarCharField
goto checkColumn
}
col = T["json"]
case TypeJsonbField:
if al.Driver != DRPostgres {
fieldType = TypeVarCharField
goto checkColumn
}
col = T["jsonb"]
case RelForeignKey, RelOneToOne:
fieldType = fi.relModelInfo.fields.pk.fieldType
fieldSize = fi.relModelInfo.fields.pk.size
goto checkColumn
}
return
}
// create alter sql string.
func getColumnAddQuery(al *alias, fi *fieldInfo) string {
Q := al.DbBaser.TableQuote()
typ := getColumnTyp(al, fi)
if !fi.null {
typ += " " + "NOT NULL"
}
return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
Q, fi.mi.table, Q,
Q, fi.column, Q,
typ, getColumnDefault(fi),
)
}
// Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
func getColumnDefault(fi *fieldInfo) string {
var v, t, d string
// Skip default attribute if field is in relations
if fi.rel || fi.reverse {
return v
}
t = " DEFAULT '%s' "
// These defaults will be useful if there no config value orm:"default" and NOT NULL is on
switch fi.fieldType {
case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
return v
case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
TypeDecimalField:
t = " DEFAULT %s "
d = "0"
case TypeBooleanField:
t = " DEFAULT %s "
d = "FALSE"
case TypeJSONField, TypeJsonbField:
d = "{}"
}
if fi.colDefault {
if !fi.initial.Exist() {
v = fmt.Sprintf(t, "")
} else {
v = fmt.Sprintf(t, fi.initial.String())
}
} else {
if !fi.null {
v = fmt.Sprintf(t, d)
}
}
return v
}

File diff suppressed because it is too large Load Diff

@ -1,599 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"fmt"
"sync"
"time"
lru "github.com/hashicorp/golang-lru"
)
// DriverType database driver constant int.
type DriverType int
// Enum the Database driver
const (
_ DriverType = iota // int enum type
DRMySQL // mysql
DRSqlite // sqlite
DROracle // oracle
DRPostgres // pgsql
DRTiDB // TiDB
)
// database driver string.
type driver string
// get type constant int of current driver..
func (d driver) Type() DriverType {
a, _ := dataBaseCache.get(string(d))
return a.Driver
}
// get name of current driver
func (d driver) Name() string {
return string(d)
}
// check driver iis implemented Driver interface or not.
var _ Driver = new(driver)
var (
dataBaseCache = &_dbCache{cache: make(map[string]*alias)}
drivers = map[string]DriverType{
"mysql": DRMySQL,
"postgres": DRPostgres,
"sqlite3": DRSqlite,
"tidb": DRTiDB,
"oracle": DROracle,
"oci8": DROracle, // github.com/mattn/go-oci8
"ora": DROracle, // https://github.com/rana/ora
}
dbBasers = map[DriverType]dbBaser{
DRMySQL: newdbBaseMysql(),
DRSqlite: newdbBaseSqlite(),
DROracle: newdbBaseOracle(),
DRPostgres: newdbBasePostgres(),
DRTiDB: newdbBaseTidb(),
}
)
// database alias cacher.
type _dbCache struct {
mux sync.RWMutex
cache map[string]*alias
}
// add database alias with original name.
func (ac *_dbCache) add(name string, al *alias) (added bool) {
ac.mux.Lock()
defer ac.mux.Unlock()
if _, ok := ac.cache[name]; !ok {
ac.cache[name] = al
added = true
}
return
}
// get database alias if cached.
func (ac *_dbCache) get(name string) (al *alias, ok bool) {
ac.mux.RLock()
defer ac.mux.RUnlock()
al, ok = ac.cache[name]
return
}
// get default alias.
func (ac *_dbCache) getDefault() (al *alias) {
al, _ = ac.get("default")
return
}
type DB struct {
*sync.RWMutex
DB *sql.DB
stmtDecorators *lru.Cache
stmtDecoratorsLimit int
}
var (
_ dbQuerier = new(DB)
_ txer = new(DB)
)
func (d *DB) Begin() (*sql.Tx, error) {
return d.DB.Begin()
}
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return d.DB.BeginTx(ctx, opts)
}
// su must call release to release *sql.Stmt after using
func (d *DB) getStmtDecorator(query string) (*stmtDecorator, error) {
d.RLock()
c, ok := d.stmtDecorators.Get(query)
if ok {
c.(*stmtDecorator).acquire()
d.RUnlock()
return c.(*stmtDecorator), nil
}
d.RUnlock()
d.Lock()
c, ok = d.stmtDecorators.Get(query)
if ok {
c.(*stmtDecorator).acquire()
d.Unlock()
return c.(*stmtDecorator), nil
}
stmt, err := d.Prepare(query)
if err != nil {
d.Unlock()
return nil, err
}
sd := newStmtDecorator(stmt)
sd.acquire()
d.stmtDecorators.Add(query, sd)
d.Unlock()
return sd, nil
}
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
return d.DB.Prepare(query)
}
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return d.DB.PrepareContext(ctx, query)
}
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
return d.ExecContext(context.Background(), query, args...)
}
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if d.stmtDecorators == nil {
return d.DB.ExecContext(ctx, query, args...)
}
sd, err := d.getStmtDecorator(query)
if err != nil {
return nil, err
}
stmt := sd.getStmt()
defer sd.release()
return stmt.ExecContext(ctx, args...)
}
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
return d.QueryContext(context.Background(), query, args...)
}
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
if d.stmtDecorators == nil {
return d.DB.QueryContext(ctx, query, args...)
}
sd, err := d.getStmtDecorator(query)
if err != nil {
return nil, err
}
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryContext(ctx, args...)
}
func (d *DB) QueryRow(query string, args ...interface{}) *sql.Row {
return d.QueryRowContext(context.Background(), query, args...)
}
func (d *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
if d.stmtDecorators == nil {
return d.DB.QueryRowContext(ctx, query, args...)
}
sd, err := d.getStmtDecorator(query)
if err != nil {
panic(err)
}
stmt := sd.getStmt()
defer sd.release()
return stmt.QueryRowContext(ctx, args...)
}
type TxDB struct {
tx *sql.Tx
}
var (
_ dbQuerier = new(TxDB)
_ txEnder = new(TxDB)
)
func (t *TxDB) Commit() error {
return t.tx.Commit()
}
func (t *TxDB) Rollback() error {
return t.tx.Rollback()
}
func (t *TxDB) RollbackUnlessCommit() error {
err := t.tx.Rollback()
if err != sql.ErrTxDone {
return err
}
return nil
}
var (
_ dbQuerier = new(TxDB)
_ txEnder = new(TxDB)
)
func (t *TxDB) Prepare(query string) (*sql.Stmt, error) {
return t.PrepareContext(context.Background(), query)
}
func (t *TxDB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
return t.tx.PrepareContext(ctx, query)
}
func (t *TxDB) Exec(query string, args ...interface{}) (sql.Result, error) {
return t.ExecContext(context.Background(), query, args...)
}
func (t *TxDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return t.tx.ExecContext(ctx, query, args...)
}
func (t *TxDB) Query(query string, args ...interface{}) (*sql.Rows, error) {
return t.QueryContext(context.Background(), query, args...)
}
func (t *TxDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return t.tx.QueryContext(ctx, query, args...)
}
func (t *TxDB) QueryRow(query string, args ...interface{}) *sql.Row {
return t.QueryRowContext(context.Background(), query, args...)
}
func (t *TxDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
return t.tx.QueryRowContext(ctx, query, args...)
}
type alias struct {
Name string
Driver DriverType
DriverName string
DataSource string
MaxIdleConns int
MaxOpenConns int
ConnMaxLifetime time.Duration
StmtCacheSize int
DB *DB
DbBaser dbBaser
TZ *time.Location
Engine string
}
func detectTZ(al *alias) {
// orm timezone system match database
// default use Local
al.TZ = DefaultTimeLoc
if al.DriverName == "sphinx" {
return
}
switch al.Driver {
case DRMySQL:
row := al.DB.QueryRow("SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP)")
var tz string
row.Scan(&tz)
if len(tz) >= 8 {
if tz[0] != '-' {
tz = "+" + tz
}
t, err := time.Parse("-07:00:00", tz)
if err == nil {
if t.Location().String() != "" {
al.TZ = t.Location()
}
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
// get default engine from current database
row = al.DB.QueryRow("SELECT ENGINE, TRANSACTIONS FROM information_schema.engines WHERE SUPPORT = 'DEFAULT'")
var engine string
var tx bool
row.Scan(&engine, &tx)
if engine != "" {
al.Engine = engine
} else {
al.Engine = "INNODB"
}
case DRSqlite, DROracle:
al.TZ = time.UTC
case DRPostgres:
row := al.DB.QueryRow("SELECT current_setting('TIMEZONE')")
var tz string
row.Scan(&tz)
loc, err := time.LoadLocation(tz)
if err == nil {
al.TZ = loc
} else {
DebugLog.Printf("Detect DB timezone: %s %s\n", tz, err.Error())
}
}
}
func addAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) {
existErr := fmt.Errorf("DataBase alias name `%s` already registered, cannot reuse", aliasName)
if _, ok := dataBaseCache.get(aliasName); ok {
return nil, existErr
}
al, err := newAliasWithDb(aliasName, driverName, db, params...)
if err != nil {
return nil, err
}
if !dataBaseCache.add(aliasName, al) {
return nil, existErr
}
return al, nil
}
func newAliasWithDb(aliasName, driverName string, db *sql.DB, params ...DBOption) (*alias, error) {
al := &alias{}
al.DB = &DB{
RWMutex: new(sync.RWMutex),
DB: db,
}
for _, p := range params {
p(al)
}
var stmtCache *lru.Cache
var stmtCacheSize int
if al.StmtCacheSize > 0 {
_stmtCache, errC := newStmtDecoratorLruWithEvict(al.StmtCacheSize)
if errC != nil {
return nil, errC
} else {
stmtCache = _stmtCache
stmtCacheSize = al.StmtCacheSize
}
}
al.Name = aliasName
al.DriverName = driverName
al.DB.stmtDecorators = stmtCache
al.DB.stmtDecoratorsLimit = stmtCacheSize
if dr, ok := drivers[driverName]; ok {
al.DbBaser = dbBasers[dr]
al.Driver = dr
} else {
return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
}
err := db.Ping()
if err != nil {
return nil, fmt.Errorf("register db Ping `%s`, %s", aliasName, err.Error())
}
detectTZ(al)
return al, nil
}
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
// Deprecated you should not use this, we will remove it in the future
func SetMaxIdleConns(aliasName string, maxIdleConns int) {
al := getDbAlias(aliasName)
al.SetMaxIdleConns(maxIdleConns)
}
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
// Deprecated you should not use this, we will remove it in the future
func SetMaxOpenConns(aliasName string, maxOpenConns int) {
al := getDbAlias(aliasName)
al.SetMaxOpenConns(maxOpenConns)
}
// SetMaxIdleConns Change the max idle conns for *sql.DB, use specify database alias name
func (al *alias) SetMaxIdleConns(maxIdleConns int) {
al.MaxIdleConns = maxIdleConns
al.DB.DB.SetMaxIdleConns(maxIdleConns)
}
// SetMaxOpenConns Change the max open conns for *sql.DB, use specify database alias name
func (al *alias) SetMaxOpenConns(maxOpenConns int) {
al.MaxOpenConns = maxOpenConns
al.DB.DB.SetMaxOpenConns(maxOpenConns)
}
func (al *alias) SetConnMaxLifetime(lifeTime time.Duration) {
al.ConnMaxLifetime = lifeTime
al.DB.DB.SetConnMaxLifetime(lifeTime)
}
// AddAliasWthDB add a aliasName for the drivename
func AddAliasWthDB(aliasName, driverName string, db *sql.DB, params ...DBOption) error {
_, err := addAliasWthDB(aliasName, driverName, db, params...)
return err
}
// RegisterDataBase Setting the database connect params. Use the database driver self dataSource args.
func RegisterDataBase(aliasName, driverName, dataSource string, params ...DBOption) error {
var (
err error
db *sql.DB
al *alias
)
db, err = sql.Open(driverName, dataSource)
if err != nil {
err = fmt.Errorf("register db `%s`, %s", aliasName, err.Error())
goto end
}
al, err = addAliasWthDB(aliasName, driverName, db, params...)
if err != nil {
goto end
}
al.DataSource = dataSource
end:
if err != nil {
if db != nil {
db.Close()
}
DebugLog.Println(err.Error())
}
return err
}
// RegisterDriver Register a database driver use specify driver name, this can be definition the driver is which database type.
func RegisterDriver(driverName string, typ DriverType) error {
if t, ok := drivers[driverName]; !ok {
drivers[driverName] = typ
} else {
if t != typ {
return fmt.Errorf("driverName `%s` db driver already registered and is other type", driverName)
}
}
return nil
}
// SetDataBaseTZ Change the database default used timezone
func SetDataBaseTZ(aliasName string, tz *time.Location) error {
if al, ok := dataBaseCache.get(aliasName); ok {
al.TZ = tz
} else {
return fmt.Errorf("DataBase alias name `%s` not registered", aliasName)
}
return nil
}
// GetDB Get *sql.DB from registered database by db alias name.
// Use "default" as alias name if you not set.
func GetDB(aliasNames ...string) (*sql.DB, error) {
var name string
if len(aliasNames) > 0 {
name = aliasNames[0]
} else {
name = "default"
}
al, ok := dataBaseCache.get(name)
if ok {
return al.DB.DB, nil
}
return nil, fmt.Errorf("DataBase of alias name `%s` not found", name)
}
type stmtDecorator struct {
wg sync.WaitGroup
stmt *sql.Stmt
}
func (s *stmtDecorator) getStmt() *sql.Stmt {
return s.stmt
}
// acquire will add one
// since this method will be used inside read lock scope,
// so we can not do more things here
// we should think about refactor this
func (s *stmtDecorator) acquire() {
s.wg.Add(1)
}
func (s *stmtDecorator) release() {
s.wg.Done()
}
// garbage recycle for stmt
func (s *stmtDecorator) destroy() {
go func() {
s.wg.Wait()
_ = s.stmt.Close()
}()
}
func newStmtDecorator(sqlStmt *sql.Stmt) *stmtDecorator {
return &stmtDecorator{
stmt: sqlStmt,
}
}
func newStmtDecoratorLruWithEvict(cacheSize int) (*lru.Cache, error) {
cache, err := lru.NewWithEvict(cacheSize, func(key interface{}, value interface{}) {
value.(*stmtDecorator).destroy()
})
if err != nil {
return nil, err
}
return cache, nil
}
type DBOption func(al *alias)
// MaxIdleConnections return a hint about MaxIdleConnections
func MaxIdleConnections(maxIdleConn int) DBOption {
return func(al *alias) {
al.SetMaxIdleConns(maxIdleConn)
}
}
// MaxOpenConnections return a hint about MaxOpenConnections
func MaxOpenConnections(maxOpenConn int) DBOption {
return func(al *alias) {
al.SetMaxOpenConns(maxOpenConn)
}
}
// ConnMaxLifetime return a hint about ConnMaxLifetime
func ConnMaxLifetime(v time.Duration) DBOption {
return func(al *alias) {
al.SetConnMaxLifetime(v)
}
}
// MaxStmtCacheSize return a hint about MaxStmtCacheSize
func MaxStmtCacheSize(v int) DBOption {
return func(al *alias) {
al.StmtCacheSize = v
}
}

@ -1,192 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"fmt"
"reflect"
"strings"
)
// mysql operators.
var mysqlOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ?",
"strictexact": "= BINARY ?",
"contains": "LIKE BINARY ?",
"icontains": "LIKE ?",
// "regex": "REGEXP BINARY ?",
// "iregex": "REGEXP ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE BINARY ?",
"endswith": "LIKE BINARY ?",
"istartswith": "LIKE ?",
"iendswith": "LIKE ?",
}
// mysql column field types.
var mysqlTypes = map[string]string{
"auto": "AUTO_INCREMENT NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "longtext",
"time.Time-date": "date",
"time.Time": "datetime",
"int8": "tinyint",
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": "tinyint unsigned",
"uint16": "smallint unsigned",
"uint32": "integer unsigned",
"uint64": "bigint unsigned",
"float64": "double precision",
"float64-decimal": "numeric(%d, %d)",
"time.Time-precision": "datetime(%d)",
}
// mysql dbBaser implementation.
type dbBaseMysql struct {
dbBase
}
var _ dbBaser = new(dbBaseMysql)
// get mysql operator.
func (d *dbBaseMysql) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseMysql) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseMysql) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseMysql) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseMysql) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// InsertOrUpdate a row
// If your primary key or unique column conflict will update
// If no will insert
// Add "`" for mysql sql building
func (d *dbBaseMysql) InsertOrUpdate(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, a *alias, args ...string) (int64, error) {
var iouStr string
argsMap := map[string]string{}
iouStr = "ON DUPLICATE KEY UPDATE"
// Get on the key-value pairs
for _, v := range args {
kv := strings.Split(v, "=")
if len(kv) == 2 {
argsMap[strings.ToLower(kv[0])] = kv[1]
}
}
isMulti := false
names := make([]string, 0, len(mi.fields.dbcols)-1)
Q := d.ins.TableQuote()
values, _, err := d.collectValues(mi, ind, mi.fields.dbcols, true, true, &names, a.TZ)
if err != nil {
return 0, err
}
marks := make([]string, len(names))
updateValues := make([]interface{}, 0)
updates := make([]string, len(names))
for i, v := range names {
marks[i] = "?"
valueStr := argsMap[strings.ToLower(v)]
if valueStr != "" {
updates[i] = "`" + v + "`" + "=" + valueStr
} else {
updates[i] = "`" + v + "`" + "=?"
updateValues = append(updateValues, values[i])
}
}
values = append(values, updateValues...)
sep := fmt.Sprintf("%s, %s", Q, Q)
qmarks := strings.Join(marks, ", ")
qupdates := strings.Join(updates, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
// conflitValue maybe is a int,can`t use fmt.Sprintf
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s) %s "+qupdates, Q, mi.table, Q, Q, columns, Q, qmarks, iouStr)
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
lastInsertId, err := res.LastInsertId()
if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable
} else {
return lastInsertId, nil
}
}
return 0, err
}
row := q.QueryRowContext(ctx, query, values...)
var id int64
err = row.Scan(&id)
return id, err
}
// create new mysql dbBaser.
func newdbBaseMysql() dbBaser {
b := new(dbBaseMysql)
b.ins = b
return b
}

@ -1,171 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"fmt"
"strings"
"github.com/beego/beego/v2/client/orm/hints"
)
// oracle operators.
var oracleOperators = map[string]string{
"exact": "= ?",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"//iendswith": "LIKE ?",
}
// oracle column field types.
var oracleTypes = map[string]string{
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "VARCHAR2(%d)",
"string-char": "CHAR(%d)",
"string-text": "VARCHAR2(%d)",
"time.Time-date": "DATE",
"time.Time": "TIMESTAMP",
"int8": "INTEGER",
"int16": "INTEGER",
"int32": "INTEGER",
"int64": "INTEGER",
"uint8": "INTEGER",
"uint16": "INTEGER",
"uint32": "INTEGER",
"uint64": "INTEGER",
"float64": "NUMBER",
"float64-decimal": "NUMBER(%d, %d)",
"time.Time-precision": "TIMESTAMP(%d)",
}
// oracle dbBaser
type dbBaseOracle struct {
dbBase
}
var _ dbBaser = new(dbBaseOracle)
// create oracle dbBaser.
func newdbBaseOracle() dbBaser {
b := new(dbBaseOracle)
b.ins = b
return b
}
// OperatorSQL get oracle operator.
func (d *dbBaseOracle) OperatorSQL(operator string) string {
return oracleOperators[operator]
}
// DbTypes get oracle table field types.
func (d *dbBaseOracle) DbTypes() map[string]string {
return oracleTypes
}
// ShowTablesQuery show all the tables in database
func (d *dbBaseOracle) ShowTablesQuery() string {
return "SELECT TABLE_NAME FROM USER_TABLES"
}
// Oracle
func (d *dbBaseOracle) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME FROM ALL_TAB_COLUMNS "+
"WHERE TABLE_NAME ='%s'", strings.ToUpper(table))
}
// check index is exist
func (d *dbBaseOracle) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM USER_IND_COLUMNS, USER_INDEXES "+
"WHERE USER_IND_COLUMNS.INDEX_NAME = USER_INDEXES.INDEX_NAME "+
"AND USER_IND_COLUMNS.TABLE_NAME = ? AND USER_IND_COLUMNS.INDEX_NAME = ?", strings.ToUpper(table), strings.ToUpper(name))
var cnt int
row.Scan(&cnt)
return cnt > 0
}
func (d *dbBaseOracle) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
var hint string
switch useIndex {
case hints.KeyUseIndex, hints.KeyForceIndex:
hint = `INDEX`
case hints.KeyIgnoreIndex:
hint = `NO_INDEX`
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
return fmt.Sprintf(` /*+ %s(%s %s)*/ `, hint, tableName, strings.Join(s, `,`))
}
// execute insert sql with given struct and given values.
// insert the given values, not the field values in struct.
func (d *dbBaseOracle) InsertValue(ctx context.Context, q dbQuerier, mi *modelInfo, isMulti bool, names []string, values []interface{}) (int64, error) {
Q := d.ins.TableQuote()
marks := make([]string, len(names))
for i := range marks {
marks[i] = ":" + names[i]
}
sep := fmt.Sprintf("%s, %s", Q, Q)
qmarks := strings.Join(marks, ", ")
columns := strings.Join(names, sep)
multi := len(values) / len(names)
if isMulti {
qmarks = strings.Repeat(qmarks+"), (", multi-1) + qmarks
}
query := fmt.Sprintf("INSERT INTO %s%s%s (%s%s%s) VALUES (%s)", Q, mi.table, Q, Q, columns, Q, qmarks)
d.ins.ReplaceMarks(&query)
if isMulti || !d.ins.HasReturningID(mi, &query) {
res, err := q.ExecContext(ctx, query, values...)
if err == nil {
if isMulti {
return res.RowsAffected()
}
lastInsertId, err := res.LastInsertId()
if err != nil {
DebugLog.Println(ErrLastInsertIdUnavailable, ':', err)
return lastInsertId, ErrLastInsertIdUnavailable
} else {
return lastInsertId, nil
}
}
return 0, err
}
row := q.QueryRowContext(ctx, query, values...)
var id int64
err := row.Scan(&id)
return id, err
}

@ -1,197 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"fmt"
"strconv"
)
// postgresql operators.
var postgresOperators = map[string]string{
"exact": "= ?",
"iexact": "= UPPER(?)",
"contains": "LIKE ?",
"icontains": "LIKE UPPER(?)",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE ?",
"endswith": "LIKE ?",
"istartswith": "LIKE UPPER(?)",
"iendswith": "LIKE UPPER(?)",
}
// postgresql column field types.
var postgresTypes = map[string]string{
"auto": "serial NOT NULL PRIMARY KEY",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "char(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "timestamp with time zone",
"int8": `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`,
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`,
"uint16": `integer CHECK("%COL%" >= 0)`,
"uint32": `bigint CHECK("%COL%" >= 0)`,
"uint64": `bigint CHECK("%COL%" >= 0)`,
"float64": "double precision",
"float64-decimal": "numeric(%d, %d)",
"json": "json",
"jsonb": "jsonb",
"time.Time-precision": "timestamp(%d) with time zone",
}
// postgresql dbBaser.
type dbBasePostgres struct {
dbBase
}
var _ dbBaser = new(dbBasePostgres)
// get postgresql operator.
func (d *dbBasePostgres) OperatorSQL(operator string) string {
return postgresOperators[operator]
}
// generate functioned sql string, such as contains(text).
func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
switch operator {
case "contains", "startswith", "endswith":
*leftCol = fmt.Sprintf("%s::text", *leftCol)
case "iexact", "icontains", "istartswith", "iendswith":
*leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol)
}
}
// postgresql unsupports updating joined record.
func (d *dbBasePostgres) SupportUpdateJoin() bool {
return false
}
func (d *dbBasePostgres) MaxLimit() uint64 {
return 0
}
// postgresql quote is ".
func (d *dbBasePostgres) TableQuote() string {
return `"`
}
// postgresql value placeholder is $n.
// replace default ? to $n.
func (d *dbBasePostgres) ReplaceMarks(query *string) {
q := *query
num := 0
for _, c := range q {
if c == '?' {
num++
}
}
if num == 0 {
return
}
data := make([]byte, 0, len(q)+num)
num = 1
for i := 0; i < len(q); i++ {
c := q[i]
if c == '?' {
data = append(data, '$')
data = append(data, []byte(strconv.Itoa(num))...)
num++
} else {
data = append(data, c)
}
}
*query = string(data)
}
// make returning sql support for postgresql.
func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
fi := mi.fields.pk
if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
return false
}
if query != nil {
*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
}
return true
}
// sync auto key
func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
if len(autoFields) == 0 {
return nil
}
Q := d.ins.TableQuote()
for _, name := range autoFields {
query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
mi.table, name,
Q, name, Q,
Q, mi.table, Q)
if _, err := db.ExecContext(ctx, query); err != nil {
return err
}
}
return nil
}
// show table sql for postgresql.
func (d *dbBasePostgres) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
}
// show table columns sql for postgresql.
func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
}
// get column types of postgresql.
func (d *dbBasePostgres) DbTypes() map[string]string {
return postgresTypes
}
// check index exist in postgresql.
func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
row := db.QueryRowContext(ctx, query)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
return ``
}
// create new postgresql dbBaser.
func newdbBasePostgres() dbBaser {
b := new(dbBasePostgres)
b.ins = b
return b
}

@ -1,184 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"github.com/beego/beego/v2/client/orm/hints"
)
// sqlite operators.
var sqliteOperators = map[string]string{
"exact": "= ?",
"iexact": "LIKE ? ESCAPE '\\'",
"contains": "LIKE ? ESCAPE '\\'",
"icontains": "LIKE ? ESCAPE '\\'",
"gt": "> ?",
"gte": ">= ?",
"lt": "< ?",
"lte": "<= ?",
"eq": "= ?",
"ne": "!= ?",
"startswith": "LIKE ? ESCAPE '\\'",
"endswith": "LIKE ? ESCAPE '\\'",
"istartswith": "LIKE ? ESCAPE '\\'",
"iendswith": "LIKE ? ESCAPE '\\'",
}
// sqlite column types.
var sqliteTypes = map[string]string{
"auto": "integer NOT NULL PRIMARY KEY AUTOINCREMENT",
"pk": "NOT NULL PRIMARY KEY",
"bool": "bool",
"string": "varchar(%d)",
"string-char": "character(%d)",
"string-text": "text",
"time.Time-date": "date",
"time.Time": "datetime",
"time.Time-precision": "datetime(%d)",
"int8": "tinyint",
"int16": "smallint",
"int32": "integer",
"int64": "bigint",
"uint8": "tinyint unsigned",
"uint16": "smallint unsigned",
"uint32": "integer unsigned",
"uint64": "bigint unsigned",
"float64": "real",
"float64-decimal": "decimal",
}
// sqlite dbBaser.
type dbBaseSqlite struct {
dbBase
}
var _ dbBaser = new(dbBaseSqlite)
// override base db read for update behavior as SQlite does not support syntax
func (d *dbBaseSqlite) Read(ctx context.Context, q dbQuerier, mi *modelInfo, ind reflect.Value, tz *time.Location, cols []string, isForUpdate bool) error {
if isForUpdate {
DebugLog.Println("[WARN] SQLite does not support SELECT FOR UPDATE query, isForUpdate param is ignored and always as false to do the work")
}
return d.dbBase.Read(ctx, q, mi, ind, tz, cols, false)
}
// get sqlite operator.
func (d *dbBaseSqlite) OperatorSQL(operator string) string {
return sqliteOperators[operator]
}
// generate functioned sql for sqlite.
// only support DATE(text).
func (d *dbBaseSqlite) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
if fi.fieldType == TypeDateField {
*leftCol = fmt.Sprintf("DATE(%s)", *leftCol)
}
}
// unable updating joined record in sqlite.
func (d *dbBaseSqlite) SupportUpdateJoin() bool {
return false
}
// max int in sqlite.
func (d *dbBaseSqlite) MaxLimit() uint64 {
return 9223372036854775807
}
// get column types in sqlite.
func (d *dbBaseSqlite) DbTypes() map[string]string {
return sqliteTypes
}
// get show tables sql in sqlite.
func (d *dbBaseSqlite) ShowTablesQuery() string {
return "SELECT name FROM sqlite_master WHERE type = 'table'"
}
// get columns in sqlite.
func (d *dbBaseSqlite) GetColumns(ctx context.Context, db dbQuerier, table string) (map[string][3]string, error) {
query := d.ins.ShowColumnsQuery(table)
rows, err := db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
columns := make(map[string][3]string)
for rows.Next() {
var tmp, name, typ, null sql.NullString
err := rows.Scan(&tmp, &name, &typ, &null, &tmp, &tmp)
if err != nil {
return nil, err
}
columns[name.String] = [3]string{name.String, typ.String, null.String}
}
return columns, nil
}
// get show columns sql in sqlite.
func (d *dbBaseSqlite) ShowColumnsQuery(table string) string {
return fmt.Sprintf("pragma table_info('%s')", table)
}
// check index exist in sqlite.
func (d *dbBaseSqlite) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
query := fmt.Sprintf("PRAGMA index_list('%s')", table)
rows, err := db.QueryContext(ctx, query)
if err != nil {
panic(err)
}
defer rows.Close()
for rows.Next() {
var tmp, index sql.NullString
rows.Scan(&tmp, &index, &tmp, &tmp, &tmp)
if name == index.String {
return true
}
}
return false
}
// GenerateSpecifyIndex return a specifying index clause
func (d *dbBaseSqlite) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
var s []string
Q := d.TableQuote()
for _, index := range indexes {
tmp := fmt.Sprintf(`%s%s%s`, Q, index, Q)
s = append(s, tmp)
}
switch useIndex {
case hints.KeyUseIndex, hints.KeyForceIndex:
return fmt.Sprintf(` INDEXED BY %s `, strings.Join(s, `,`))
default:
DebugLog.Println("[WARN] Not a valid specifying action, so that action is ignored")
return ``
}
}
// create new sqlite dbBaser.
func newdbBaseSqlite() dbBaser {
b := new(dbBaseSqlite)
b.ins = b
return b
}

@ -1,499 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"strings"
"time"
"github.com/beego/beego/v2/client/orm/clauses"
"github.com/beego/beego/v2/client/orm/clauses/order_clause"
)
// table info struct.
type dbTable struct {
id int
index string
name string
names []string
sel bool
inner bool
mi *modelInfo
fi *fieldInfo
jtl *dbTable
}
// tables collection struct, contains some tables.
type dbTables struct {
tablesM map[string]*dbTable
tables []*dbTable
mi *modelInfo
base dbBaser
skipEnd bool
}
// set table info to collection.
// if not exist, create new.
func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
name := strings.Join(names, ExprSep)
if j, ok := t.tablesM[name]; ok {
j.name = name
j.mi = mi
j.fi = fi
j.inner = inner
} else {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
}
return t.tablesM[name]
}
// add table info to collection.
func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
name := strings.Join(names, ExprSep)
if _, ok := t.tablesM[name]; !ok {
i := len(t.tables) + 1
jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
t.tablesM[name] = jt
t.tables = append(t.tables, jt)
return jt, true
}
return t.tablesM[name], false
}
// get table info in collection.
func (t *dbTables) get(name string) (*dbTable, bool) {
j, ok := t.tablesM[name]
return j, ok
}
// get related fields info in recursive depth loop.
// loop once, depth decreases one.
func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
if depth < 0 || fi.fieldType == RelManyToMany {
return related
}
if prefix == "" {
prefix = fi.name
} else {
prefix = prefix + ExprSep + fi.name
}
related = append(related, prefix)
depth--
for _, fi := range fi.relModelInfo.fields.fieldsRel {
related = t.loopDepth(depth, prefix, fi, related)
}
return related
}
// parse related fields.
func (t *dbTables) parseRelated(rels []string, depth int) {
relsNum := len(rels)
related := make([]string, relsNum)
copy(related, rels)
relDepth := depth
if relsNum != 0 {
relDepth = 0
}
relDepth--
for _, fi := range t.mi.fields.fieldsRel {
related = t.loopDepth(relDepth, "", fi, related)
}
for i, s := range related {
var (
exs = strings.Split(s, ExprSep)
names = make([]string, 0, len(exs))
mmi = t.mi
cancel = true
jtl *dbTable
)
inner := true
for _, ex := range exs {
if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
names = append(names, fi.name)
mmi = fi.relModelInfo
if fi.null || t.skipEnd {
inner = false
}
jt := t.set(names, mmi, fi, inner)
jt.jtl = jtl
if fi.reverse {
cancel = false
}
if cancel {
jt.sel = depth > 0
if i < relsNum {
jt.sel = true
}
}
jtl = jt
} else {
panic(fmt.Errorf("unknown model/table name `%s`", ex))
}
}
}
}
// generate join string.
func (t *dbTables) getJoinSQL() (join string) {
Q := t.base.TableQuote()
for _, jt := range t.tables {
if jt.inner {
join += "INNER JOIN "
} else {
join += "LEFT OUTER JOIN "
}
var (
table string
t1, t2 string
c1, c2 string
)
t1 = "T0"
if jt.jtl != nil {
t1 = jt.jtl.index
}
t2 = jt.index
table = jt.mi.table
switch {
case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
c1 = jt.fi.mi.fields.pk.column
for _, ffi := range jt.mi.fields.fieldsRel {
if jt.fi.mi == ffi.relModelInfo {
c2 = ffi.column
break
}
}
default:
c1 = jt.fi.column
c2 = jt.fi.relModelInfo.fields.pk.column
if jt.fi.reverse {
c1 = jt.mi.fields.pk.column
c2 = jt.fi.reverseFieldInfo.column
}
}
join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
t2, Q, c2, Q, t1, Q, c1, Q)
}
return
}
// parse orm model struct field tag expression.
func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
var (
jtl *dbTable
fi *fieldInfo
fiN *fieldInfo
mmi = mi
)
num := len(exprs) - 1
var names []string
inner := true
loopFor:
for i, ex := range exprs {
var ok, okN bool
if fiN != nil {
fi = fiN
ok = true
fiN = nil
}
if i == 0 {
fi, ok = mmi.fields.GetByAny(ex)
}
_ = okN
if ok {
isRel := fi.rel || fi.reverse
names = append(names, fi.name)
switch {
case fi.rel:
mmi = fi.relModelInfo
if fi.fieldType == RelManyToMany {
mmi = fi.relThroughModelInfo
}
case fi.reverse:
mmi = fi.reverseFieldInfo.mi
}
if i < num {
fiN, okN = mmi.fields.GetByAny(exprs[i+1])
}
if isRel && (!fi.mi.isThrough || num != i) {
if fi.null || t.skipEnd {
inner = false
}
if t.skipEnd && okN || !t.skipEnd {
if t.skipEnd && okN && fiN.pk {
goto loopEnd
}
jt, _ := t.add(names, mmi, fi, inner)
jt.jtl = jtl
jtl = jt
}
}
if num != i {
continue
}
loopEnd:
if i == 0 || jtl == nil {
index = "T0"
} else {
index = jtl.index
}
info = fi
if jtl == nil {
name = fi.name
} else {
name = jtl.name + ExprSep + fi.name
}
switch {
case fi.rel:
case fi.reverse:
switch fi.reverseFieldInfo.fieldType {
case RelOneToOne, RelForeignKey:
index = jtl.index
info = fi.reverseFieldInfo.mi.fields.pk
name = info.name
}
}
break loopFor
} else {
index = ""
name = ""
info = nil
success = false
return
}
}
success = index != "" && info != nil
return
}
// generate condition sql.
func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
if cond == nil || cond.IsEmpty() {
return
}
Q := t.base.TableQuote()
mi := t.mi
for i, p := range cond.params {
if i > 0 {
if p.isOr {
where += "OR "
} else {
where += "AND "
}
}
if p.isNot {
where += "NOT "
}
if p.isCond {
w, ps := t.getCondSQL(p.cond, true, tz)
if w != "" {
w = fmt.Sprintf("( %s) ", w)
}
where += w
params = append(params, ps...)
} else {
exprs := p.exprs
num := len(exprs) - 1
operator := ""
if operators[exprs[num]] {
operator = exprs[num]
exprs = exprs[:num]
}
index, _, fi, suc := t.parseExprs(mi, exprs)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
}
if operator == "" {
operator = "exact"
}
var operSQL string
var args []interface{}
if p.isRaw {
operSQL = p.sql
} else {
operSQL, args = t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
}
leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
where += fmt.Sprintf("%s %s ", leftCol, operSQL)
params = append(params, args...)
}
}
if !sub && where != "" {
where = "WHERE " + where
}
return
}
// generate group sql.
func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
if len(groups) == 0 {
return
}
Q := t.base.TableQuote()
groupSqls := make([]string, 0, len(groups))
for _, group := range groups {
exprs := strings.Split(group, ExprSep)
index, _, fi, suc := t.parseExprs(t.mi, exprs)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
}
groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
}
groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
return
}
// generate order sql.
func (t *dbTables) getOrderSQL(orders []*order_clause.Order) (orderSQL string) {
if len(orders) == 0 {
return
}
Q := t.base.TableQuote()
orderSqls := make([]string, 0, len(orders))
for _, order := range orders {
column := order.GetColumn()
clause := strings.Split(column, clauses.ExprDot)
if order.IsRaw() {
if len(clause) == 2 {
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", clause[0], Q, clause[1], Q, order.SortString()))
} else if len(clause) == 1 {
orderSqls = append(orderSqls, fmt.Sprintf("%s%s%s %s", Q, clause[0], Q, order.SortString()))
} else {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
}
} else {
index, _, fi, suc := t.parseExprs(t.mi, clause)
if !suc {
panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(clause, ExprSep)))
}
orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, order.SortString()))
}
}
orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
return
}
// generate limit sql.
func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
if limit == 0 {
limit = int64(DefaultRowsLimit)
}
if limit < 0 {
// no limit
if offset > 0 {
maxLimit := t.base.MaxLimit()
if maxLimit == 0 {
limits = fmt.Sprintf("OFFSET %d", offset)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
}
}
} else if offset <= 0 {
limits = fmt.Sprintf("LIMIT %d", limit)
} else {
limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
}
return
}
// getIndexSql generate index sql.
func (t *dbTables) getIndexSql(tableName string, useIndex int, indexes []string) (clause string) {
if len(indexes) == 0 {
return
}
return t.base.GenerateSpecifyIndex(tableName, useIndex, indexes)
}
// crete new tables collection.
func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
tables := &dbTables{}
tables.tablesM = make(map[string]*dbTable)
tables.mi = mi
tables.base = base
return tables
}

@ -1,64 +0,0 @@
// Copyright 2015 TiDB Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"fmt"
)
// mysql dbBaser implementation.
type dbBaseTidb struct {
dbBase
}
var _ dbBaser = new(dbBaseTidb)
// get mysql operator.
func (d *dbBaseTidb) OperatorSQL(operator string) string {
return mysqlOperators[operator]
}
// get mysql table field types.
func (d *dbBaseTidb) DbTypes() map[string]string {
return mysqlTypes
}
// show table sql for mysql.
func (d *dbBaseTidb) ShowTablesQuery() string {
return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = DATABASE()"
}
// show columns sql of table for mysql.
func (d *dbBaseTidb) ShowColumnsQuery(table string) string {
return fmt.Sprintf("SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE FROM information_schema.columns "+
"WHERE table_schema = DATABASE() AND table_name = '%s'", table)
}
// execute sql to check index exist.
func (d *dbBaseTidb) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
row := db.QueryRowContext(ctx, "SELECT count(*) FROM information_schema.statistics "+
"WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", table, name)
var cnt int
row.Scan(&cnt)
return cnt > 0
}
// create new mysql dbBaser.
func newdbBaseTidb() dbBaser {
b := new(dbBaseTidb)
b.ins = b
return b
}

@ -1,175 +0,0 @@
// Copyright 2014 beego Author. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"fmt"
"reflect"
"time"
)
// get table alias.
func getDbAlias(name string) *alias {
if al, ok := dataBaseCache.get(name); ok {
return al
}
panic(fmt.Errorf("unknown DataBase alias name %s", name))
}
// get pk column info.
func getExistPk(mi *modelInfo, ind reflect.Value) (column string, value interface{}, exist bool) {
fi := mi.fields.pk
v := ind.FieldByIndex(fi.fieldIndex)
if fi.fieldType&IsPositiveIntegerField > 0 {
vu := v.Uint()
exist = vu > 0
value = vu
} else if fi.fieldType&IsIntegerField > 0 {
vu := v.Int()
exist = true
value = vu
} else if fi.fieldType&IsRelField > 0 {
_, value, exist = getExistPk(fi.relModelInfo, reflect.Indirect(v))
} else {
vu := v.String()
exist = vu != ""
value = vu
}
column = fi.column
return
}
// get fields description as flatted string.
func getFlatParams(fi *fieldInfo, args []interface{}, tz *time.Location) (params []interface{}) {
outFor:
for _, arg := range args {
if arg == nil {
params = append(params, arg)
continue
}
val := reflect.ValueOf(arg)
kind := val.Kind()
if kind == reflect.Ptr {
val = val.Elem()
kind = val.Kind()
arg = val.Interface()
}
switch kind {
case reflect.String:
v := val.String()
if fi != nil {
if fi.fieldType == TypeTimeField || fi.fieldType == TypeDateField || fi.fieldType == TypeDateTimeField {
var t time.Time
var err error
if len(v) >= 19 {
s := v[:19]
t, err = time.ParseInLocation(formatDateTime, s, DefaultTimeLoc)
} else if len(v) >= 10 {
s := v
if len(v) > 10 {
s = v[:10]
}
t, err = time.ParseInLocation(formatDate, s, tz)
} else {
s := v
if len(s) > 8 {
s = v[:8]
}
t, err = time.ParseInLocation(formatTime, s, tz)
}
if err == nil {
if fi.fieldType == TypeDateField {
v = t.In(tz).Format(formatDate)
} else if fi.fieldType == TypeDateTimeField {
v = t.In(tz).Format(formatDateTime)
} else {
v = t.In(tz).Format(formatTime)
}
}
}
}
arg = v
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
arg = val.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
arg = val.Uint()
case reflect.Float32:
arg, _ = StrTo(ToStr(arg)).Float64()
case reflect.Float64:
arg = val.Float()
case reflect.Bool:
arg = val.Bool()
case reflect.Slice, reflect.Array:
if _, ok := arg.([]byte); ok {
continue outFor
}
var args []interface{}
for i := 0; i < val.Len(); i++ {
v := val.Index(i)
var vu interface{}
if v.CanInterface() {
vu = v.Interface()
}
if vu == nil {
continue
}
args = append(args, vu)
}
if len(args) > 0 {
p := getFlatParams(fi, args, tz)
params = append(params, p...)
}
continue outFor
case reflect.Struct:
if v, ok := arg.(time.Time); ok {
if fi != nil && fi.fieldType == TypeDateField {
arg = v.In(tz).Format(formatDate)
} else if fi != nil && fi.fieldType == TypeDateTimeField {
arg = v.In(tz).Format(formatDateTime)
} else if fi != nil && fi.fieldType == TypeTimeField {
arg = v.In(tz).Format(formatTime)
} else {
arg = v.In(tz).Format(formatDateTime)
}
} else {
typ := val.Type()
name := getFullName(typ)
var value interface{}
if mmi, ok := defaultModelCache.getByFullName(name); ok {
if _, vu, exist := getExistPk(mmi, val); exist {
value = vu
}
}
arg = value
if arg == nil {
panic(fmt.Errorf("need a valid args value, unknown table or value `%s`", name))
}
}
}
params = append(params, arg)
}
return
}

@ -1,181 +0,0 @@
// Copyright 2020 beego
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"database/sql"
"github.com/beego/beego/v2/core/utils"
)
// DoNothingOrm won't do anything, usually you use this to custom your mock Ormer implementation
// I think golang mocking interface is hard to use
// this may help you to integrate with Ormer
var _ Ormer = new(DoNothingOrm)
type DoNothingOrm struct{}
func (d *DoNothingOrm) Read(md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadWithCtx(ctx context.Context, md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadForUpdate(md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadForUpdateWithCtx(ctx context.Context, md interface{}, cols ...string) error {
return nil
}
func (d *DoNothingOrm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
return false, 0, nil
}
func (d *DoNothingOrm) ReadOrCreateWithCtx(ctx context.Context, md interface{}, col1 string, cols ...string) (bool, int64, error) {
return false, 0, nil
}
func (d *DoNothingOrm) LoadRelated(md interface{}, name string, args ...utils.KV) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) LoadRelatedWithCtx(ctx context.Context, md interface{}, name string, args ...utils.KV) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) QueryM2M(md interface{}, name string) QueryM2Mer {
return nil
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryM2MWithCtx(ctx context.Context, md interface{}, name string) QueryM2Mer {
return nil
}
func (d *DoNothingOrm) QueryTable(ptrStructOrTableName interface{}) QuerySeter {
return nil
}
// NOTE: this method is deprecated, context parameter will not take effect.
func (d *DoNothingOrm) QueryTableWithCtx(ctx context.Context, ptrStructOrTableName interface{}) QuerySeter {
return nil
}
func (d *DoNothingOrm) DBStats() *sql.DBStats {
return nil
}
func (d *DoNothingOrm) Insert(md interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertWithCtx(ctx context.Context, md interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertOrUpdateWithCtx(ctx context.Context, md interface{}, colConflitAndArgs ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertMulti(bulk int, mds interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) InsertMultiWithCtx(ctx context.Context, bulk int, mds interface{}) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) Update(md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) UpdateWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) Delete(md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) DeleteWithCtx(ctx context.Context, md interface{}, cols ...string) (int64, error) {
return 0, nil
}
func (d *DoNothingOrm) Raw(query string, args ...interface{}) RawSeter {
return nil
}
func (d *DoNothingOrm) RawWithCtx(ctx context.Context, query string, args ...interface{}) RawSeter {
return nil
}
func (d *DoNothingOrm) Driver() Driver {
return nil
}
func (d *DoNothingOrm) Begin() (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) BeginWithCtx(ctx context.Context) (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) BeginWithOpts(opts *sql.TxOptions) (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) BeginWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions) (TxOrmer, error) {
return nil, nil
}
func (d *DoNothingOrm) DoTx(task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil
}
func (d *DoNothingOrm) DoTxWithCtx(ctx context.Context, task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil
}
func (d *DoNothingOrm) DoTxWithOpts(opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil
}
func (d *DoNothingOrm) DoTxWithCtxAndOpts(ctx context.Context, opts *sql.TxOptions, task func(ctx context.Context, txOrm TxOrmer) error) error {
return nil
}
// DoNothingTxOrm is similar with DoNothingOrm, usually you use it to test
type DoNothingTxOrm struct {
DoNothingOrm
}
func (d *DoNothingTxOrm) Commit() error {
return nil
}
func (d *DoNothingTxOrm) Rollback() error {
return nil
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save