- update mongo

master v1.0.34
李光春 2 years ago
parent 7c835284e1
commit f420131a37

@ -1,3 +1,3 @@
package dorm package dorm
const Version = "1.0.33" const Version = "1.0.34"

@ -22,7 +22,7 @@ require (
go.mongodb.org/mongo-driver v1.10.2 go.mongodb.org/mongo-driver v1.10.2
gorm.io/driver/mysql v1.3.6 gorm.io/driver/mysql v1.3.6
gorm.io/driver/postgres v1.3.9 gorm.io/driver/postgres v1.3.9
gorm.io/gorm v1.23.8 gorm.io/gorm v1.23.9
xorm.io/xorm v1.3.2 xorm.io/xorm v1.3.2
) )
@ -48,7 +48,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.15.9 // indirect github.com/klauspost/compress v1.15.10 // indirect
github.com/leodido/go-urn v1.2.1 // indirect github.com/leodido/go-urn v1.2.1 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
@ -72,7 +72,7 @@ require (
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
golang.org/x/net v0.0.0-20220909164309-bea034e7d591 // indirect golang.org/x/net v0.0.0-20220909164309-bea034e7d591 // indirect
golang.org/x/sync v0.0.0-20220907140024-f12130a52804 // indirect golang.org/x/sync v0.0.0-20220907140024-f12130a52804 // indirect
golang.org/x/sys v0.0.0-20220913175220-63ea55921009 // indirect golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41 // indirect
golang.org/x/text v0.3.7 // indirect golang.org/x/text v0.3.7 // indirect
golang.org/x/tools v0.1.11 // indirect golang.org/x/tools v0.1.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect

@ -297,8 +297,8 @@ github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvW
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.9.5/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.9.5/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= github.com/klauspost/compress v1.15.10 h1:Ai8UzuomSCDw90e1qNMtb15msBXsNpH6gzkkENQNcJo=
github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/klauspost/compress v1.15.10/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
@ -654,8 +654,8 @@ golang.org/x/sys v0.0.0-20210902050250-f475640dd07b/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220913175220-63ea55921009 h1:PuvuRMeLWqsf/ZdT1UUZz0syhioyv1mzuFZsXs4fvhw= golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41 h1:ohgcoMbSofXygzo6AD2I1kz3BFmW1QArPYTtwEM3UXc=
golang.org/x/sys v0.0.0-20220913175220-63ea55921009/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220915200043-7b5979e65e41/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -745,8 +745,9 @@ gorm.io/driver/mysql v1.3.6/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10
gorm.io/driver/postgres v1.3.9 h1:lWGiVt5CijhQAg0PWB7Od1RNcBw/jS4d2cAScBcSDXg= gorm.io/driver/postgres v1.3.9 h1:lWGiVt5CijhQAg0PWB7Od1RNcBw/jS4d2cAScBcSDXg=
gorm.io/driver/postgres v1.3.9/go.mod h1:qw/FeqjxmYqW5dBcYNBsnhQULIApQdk7YuuDPktVi1U= gorm.io/driver/postgres v1.3.9/go.mod h1:qw/FeqjxmYqW5dBcYNBsnhQULIApQdk7YuuDPktVi1U=
gorm.io/gorm v1.23.7/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.7/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.23.8 h1:h8sGJ+biDgBA1AD1Ha9gFCx7h8npU7AsLdlkX0n2TpE=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.23.9 h1:NSHG021i+MCznokeXR3udGaNyFyBQJW8MbjrJMVCfGw=
gorm.io/gorm v1.23.9/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

@ -160,7 +160,7 @@ func NewGormPostgresqlClient(config *ConfigGormClient) (*GormClient, error) {
// 设置了连接可复用的最大时间 // 设置了连接可复用的最大时间
if c.config.ConnSetConnMaxLifetime == 0 { if c.config.ConnSetConnMaxLifetime == 0 {
sqlDB.SetConnMaxLifetime(time.Second * 600) sqlDB.SetConnMaxLifetime(time.Hour)
} else { } else {
sqlDB.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime)) sqlDB.SetConnMaxLifetime(time.Duration(c.config.ConnSetConnMaxLifetime))
} }

@ -1 +0,0 @@
package dorm

@ -0,0 +1,56 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type MongoCollectionOptions struct {
dbCollection *mongo.Collection
}
// Collection 选择集合
func (cd *MongoDatabaseOptions) Collection(name string, opts ...*options.CollectionOptions) *MongoCollectionOptions {
return &MongoCollectionOptions{
dbCollection: cd.dbDatabase.Collection(name, opts...),
}
}
// CreateOneIndexes 创建一个索引
func (cc *MongoCollectionOptions) CreateOneIndexes(ctx context.Context, key string, value string) (string, error) {
return cc.dbCollection.Indexes().CreateOne(ctx, mongo.IndexModel{
Keys: bson.D{{
Key: key,
Value: value,
}},
})
}
// CreateOneUniqueIndexes 创建一个唯一索引
func (cc *MongoCollectionOptions) CreateOneUniqueIndexes(ctx context.Context, key string, value string) (string, error) {
return cc.dbCollection.Indexes().CreateOne(ctx, mongo.IndexModel{
Keys: bson.D{{
Key: key,
Value: value,
}},
Options: options.Index().SetUnique(true),
})
}
// CreateOneUniqueIndexesOpts 创建一个索引
func (cc *MongoCollectionOptions) CreateOneUniqueIndexesOpts(ctx context.Context, key string, value string, opts *options.IndexOptions) (string, error) {
return cc.dbCollection.Indexes().CreateOne(ctx, mongo.IndexModel{
Keys: bson.D{{
Key: key,
Value: value,
}},
Options: opts,
})
}
// CreateManyIndexes 创建多个索引
func (cc *MongoCollectionOptions) CreateManyIndexes(ctx context.Context, models []mongo.IndexModel) ([]string, error) {
return cc.dbCollection.Indexes().CreateMany(ctx, models)
}

@ -0,0 +1,47 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// InsertOne 插入一个文档
func (cc *MongoCollectionOptions) InsertOne(ctx context.Context, document interface{}, opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) {
return cc.dbCollection.InsertOne(ctx, document, opts...)
}
// InsertMany 插入多个文档
func (cc *MongoCollectionOptions) InsertMany(ctx context.Context, document []interface{}, opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) {
return cc.dbCollection.InsertMany(ctx, document, opts...)
}
// DeleteOne 删除一个文档
func (cc *MongoCollectionOptions) DeleteOne(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
return cc.dbCollection.DeleteOne(ctx, filter, opts...)
}
// DeleteMany 删除多个文档
func (cc *MongoCollectionOptions) DeleteMany(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) {
return cc.dbCollection.DeleteMany(ctx, filter, opts...)
}
// UpdateOne 更新一个文档
func (cc *MongoCollectionOptions) UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return cc.dbCollection.UpdateOne(ctx, filter, update, opts...)
}
// UpdateMany 更新多个文档
func (cc *MongoCollectionOptions) UpdateMany(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) {
return cc.dbCollection.UpdateMany(ctx, filter, update, opts...)
}
// FindOne 查询一个文档
func (cc *MongoCollectionOptions) FindOne(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult {
return cc.dbCollection.FindOne(ctx, filter, opts...)
}
// Find 查询多个文档
func (cc *MongoCollectionOptions) Find(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) {
return cc.dbCollection.Find(ctx, filter, opts...)
}

@ -1,185 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
// InsertOne 插入一个文档
func (c *MongoClient) InsertOne(document interface{}) (result *mongo.InsertOneResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.InsertOne(context.TODO(), document)
return
}
// InsertMany 插入多个文档
func (c *MongoClient) InsertMany(documents []interface{}) (result *mongo.InsertManyResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.InsertMany(context.TODO(), documents)
return
}
// Delete 删除文档
func (c *MongoClient) Delete(filter interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.DeleteOne(context.TODO(), filter)
return
}
// DeleteId 删除文档
func (c *MongoClient) DeleteId(id interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.DeleteOne(context.TODO(), bson.M{"_id": id})
return
}
// DeleteMany 删除多个文档
func (c *MongoClient) DeleteMany(filter interface{}) (result *mongo.DeleteResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.DeleteMany(context.TODO(), filter)
return
}
// UpdateOne 更新单个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (c *MongoClient) UpdateOne(filter interface{}, update interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.UpdateOne(context.TODO(), filter, update)
return
}
// UpdateId 更新单个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (c *MongoClient) UpdateId(id interface{}, update interface{}) (err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
_, err = collection.UpdateOne(context.TODO(), bson.M{"_id": id}, update)
return
}
// UpdateMany 更新多个文档
// 修改字段的值($set)
// 字段增加值 inc($inc)
// 从数组中增加一个元素 push($push)
// 从数组中删除一个元素 pull($pull)
func (c *MongoClient) UpdateMany(filter interface{}, update interface{}) (result *mongo.UpdateResult, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.UpdateMany(context.TODO(), filter, update)
return
}
type FindResultI interface {
Many(result interface{}) error
}
// Find 查询
func (c *MongoClient) Find(filter interface{}) (*mongo.Cursor, error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
return collection.Find(context.TODO(), filter)
}
type FindOneResultI interface {
One(result interface{}) error
}
// FindOne 查询单个文档
func (c *MongoClient) FindOne(filter interface{}) *mongo.SingleResult {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
return collection.FindOne(context.TODO(), filter)
}
type FindManyResultI interface {
Many(result interface{}) error
}
// FindMany 查询多个文档
func (c *MongoClient) FindMany(filter interface{}) (*mongo.Cursor, error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
return collection.Find(context.TODO(), filter)
}
// FindManyByFilters 多条件查询
func (c *MongoClient) FindManyByFilters(filter interface{}) (result *mongo.Cursor, err error) {
collection, err := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName).Clone()
result, err = collection.Find(context.TODO(), bson.M{"$and": filter})
return result, err
}
// FindManyByFiltersSort 多条件查询支持排序
func (c *MongoClient) FindManyByFiltersSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) {
collection, err := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName).Clone()
findOptions := options.Find()
findOptions.SetSort(Sort)
result, err = collection.Find(context.TODO(), filter, findOptions)
return result, err
}
// FindCollection 查询集合文档
func (c *MongoClient) FindCollection(Limit int64) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
findOptions := options.Find()
findOptions.SetLimit(Limit)
result, err = collection.Find(context.TODO(), bson.D{{}}, findOptions)
return result, err
}
// FindCollectionSort 查询集合文档支持排序
func (c *MongoClient) FindCollectionSort(Sort interface{}, Limit int64) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
findOptions := options.Find()
findOptions.SetSort(Sort)
findOptions.SetLimit(Limit)
result, err = collection.Find(context.TODO(), bson.D{{}}, findOptions)
return result, err
}
// FindManyCollectionSort 查询集合文档支持排序支持条件
func (c *MongoClient) FindManyCollectionSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
findOptions := options.Find()
findOptions.SetSort(Sort)
result, err = collection.Find(context.TODO(), filter, findOptions)
return result, err
}
// CollectionCount 查询集合里有多少数据
func (c *MongoClient) CollectionCount() (name string, size int64) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
name = collection.Name()
size, _ = collection.EstimatedDocumentCount(context.TODO())
return name, size
}
// CollectionDocuments 按选项查询集合
// Skip 跳过
// Limit 读取数量
// sort 1 -1 . 1 为升序 -1 为降序
func (c *MongoClient) CollectionDocuments(Skip, Limit int64, sort int, key string, value interface{}) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
SORT := bson.D{{"_id", sort}}
filter := bson.D{{key, value}}
findOptions := options.Find().SetSort(SORT).SetLimit(Limit).SetSkip(Skip)
result, err = collection.Find(context.TODO(), filter, findOptions)
return result, err
}
// AggregateByFiltersSort 统计分析
func (c *MongoClient) AggregateByFiltersSort(pipeline interface{}) (result *mongo.Cursor, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
result, err = collection.Aggregate(context.TODO(), pipeline)
return result, err
}
// CountDocumentsByFilters 统计数量
func (c *MongoClient) CountDocumentsByFilters(filter interface{}) (count int64, err error) {
collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName)
count, err = collection.CountDocuments(context.TODO(), filter)
return count, err
}

@ -1,15 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
)
type FindResult struct {
cursor *mongo.Cursor
err error
}
func (f *FindResult) Many(result interface{}) error {
return f.cursor.All(context.TODO(), result)
}

@ -1,15 +0,0 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
)
type FindManyResult struct {
cursor *mongo.Cursor
err error
}
func (f *FindManyResult) Many(result interface{}) error {
return f.cursor.All(context.TODO(), result)
}

@ -1,11 +0,0 @@
package dorm
import "go.mongodb.org/mongo-driver/mongo"
type FindOneResult struct {
singleResult *mongo.SingleResult
}
func (f *FindOneResult) One(result interface{}) error {
return f.singleResult.Decode(result)
}

@ -0,0 +1,28 @@
package dorm
import (
"context"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type MongoDatabaseOptions struct {
dbDatabase *mongo.Database
}
// Database 选择数据库
func (c *MongoClient) Database(name string, opts ...*options.DatabaseOptions) *MongoDatabaseOptions {
return &MongoDatabaseOptions{
dbDatabase: c.Db.Database(name, opts...),
}
}
// CreateCollection 创建集合
func (cd *MongoDatabaseOptions) CreateCollection(ctx context.Context, name string, opts ...*options.CreateCollectionOptions) error {
return cd.dbDatabase.CreateCollection(ctx, name, opts...)
}
// CreateTimeSeriesCollection 创建时间序列集合
func (cd *MongoDatabaseOptions) CreateTimeSeriesCollection(ctx context.Context, name string, timeField string) error {
return cd.dbDatabase.CreateCollection(ctx, name, options.CreateCollection().SetTimeSeriesOptions(options.TimeSeries().SetTimeField(timeField)))
}

@ -1 +0,0 @@
package dorm

@ -1,28 +0,0 @@
package dorm
import "reflect"
// Database 设置库名
func (c *MongoClient) Database(databaseName string) *MongoClient {
c.databaseName = databaseName
return c
}
// Collection 设置表名
func (c *MongoClient) Collection(collectionName string) *MongoClient {
c.collectionName = collectionName
return c
}
// Model 传入模型自动获取库名和表名
func (c *MongoClient) Model(value interface{}) *MongoClient {
// https://studygolang.com/articles/896
val := reflect.ValueOf(value)
if methodValue := val.MethodByName("Database"); methodValue.IsValid() {
c.databaseName = methodValue.Call(nil)[0].String()
}
if methodValue := val.MethodByName("TableName"); methodValue.IsValid() {
c.collectionName = methodValue.Call(nil)[0].String()
}
return c
}

@ -17,6 +17,12 @@ This package provides various compression algorithms.
# changelog # changelog
* July 21, 2022 (v1.15.9)
* zstd: Fix decoder crash on amd64 (no BMI) on invalid input https://github.com/klauspost/compress/pull/645
* zstd: Disable decoder extended memory copies (amd64) due to possible crashes https://github.com/klauspost/compress/pull/644
* zstd: Allow single segments up to "max decoded size" by @klauspost in https://github.com/klauspost/compress/pull/643
* July 13, 2022 (v1.15.8) * July 13, 2022 (v1.15.8)
* gzip: fix stack exhaustion bug in Reader.Read https://github.com/klauspost/compress/pull/641 * gzip: fix stack exhaustion bug in Reader.Read https://github.com/klauspost/compress/pull/641

@ -763,17 +763,20 @@ func (d *Decoder) decompress4X8bit(dst, src []byte) ([]byte, error) {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 1") return nil, errors.New("corruption detected: stream overrun 1")
} }
copy(out, buf[0][:])
copy(out[dstEvery:], buf[1][:])
copy(out[dstEvery*2:], buf[2][:])
copy(out[dstEvery*3:], buf[3][:])
out = out[bufoff:]
decoded += bufoff * 4
// There must at least be 3 buffers left. // There must at least be 3 buffers left.
if len(out) < dstEvery*3 { if len(out)-bufoff < dstEvery*3 {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 2") return nil, errors.New("corruption detected: stream overrun 2")
} }
//copy(out, buf[0][:])
//copy(out[dstEvery:], buf[1][:])
//copy(out[dstEvery*2:], buf[2][:])
*(*[bufoff]byte)(out) = buf[0]
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
out = out[bufoff:]
decoded += bufoff * 4
} }
} }
if off > 0 { if off > 0 {
@ -997,17 +1000,22 @@ func (d *Decoder) decompress4X8bitExactly(dst, src []byte) ([]byte, error) {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 1") return nil, errors.New("corruption detected: stream overrun 1")
} }
copy(out, buf[0][:])
copy(out[dstEvery:], buf[1][:])
copy(out[dstEvery*2:], buf[2][:])
copy(out[dstEvery*3:], buf[3][:])
out = out[bufoff:]
decoded += bufoff * 4
// There must at least be 3 buffers left. // There must at least be 3 buffers left.
if len(out) < dstEvery*3 { if len(out)-bufoff < dstEvery*3 {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 2") return nil, errors.New("corruption detected: stream overrun 2")
} }
//copy(out, buf[0][:])
//copy(out[dstEvery:], buf[1][:])
//copy(out[dstEvery*2:], buf[2][:])
// copy(out[dstEvery*3:], buf[3][:])
*(*[bufoff]byte)(out) = buf[0]
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
out = out[bufoff:]
decoded += bufoff * 4
} }
} }
if off > 0 { if off > 0 {

@ -14,12 +14,14 @@ import (
// decompress4x_main_loop_x86 is an x86 assembler implementation // decompress4x_main_loop_x86 is an x86 assembler implementation
// of Decompress4X when tablelog > 8. // of Decompress4X when tablelog > 8.
//
//go:noescape //go:noescape
func decompress4x_main_loop_amd64(ctx *decompress4xContext) func decompress4x_main_loop_amd64(ctx *decompress4xContext)
// decompress4x_8b_loop_x86 is an x86 assembler implementation // decompress4x_8b_loop_x86 is an x86 assembler implementation
// of Decompress4X when tablelog <= 8 which decodes 4 entries // of Decompress4X when tablelog <= 8 which decodes 4 entries
// per loop. // per loop.
//
//go:noescape //go:noescape
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext)
@ -145,11 +147,13 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
// decompress4x_main_loop_x86 is an x86 assembler implementation // decompress4x_main_loop_x86 is an x86 assembler implementation
// of Decompress1X when tablelog > 8. // of Decompress1X when tablelog > 8.
//
//go:noescape //go:noescape
func decompress1x_main_loop_amd64(ctx *decompress1xContext) func decompress1x_main_loop_amd64(ctx *decompress1xContext)
// decompress4x_main_loop_x86 is an x86 with BMI2 assembler implementation // decompress4x_main_loop_x86 is an x86 with BMI2 assembler implementation
// of Decompress1X when tablelog > 8. // of Decompress1X when tablelog > 8.
//
//go:noescape //go:noescape
func decompress1x_main_loop_bmi2(ctx *decompress1xContext) func decompress1x_main_loop_bmi2(ctx *decompress1xContext)

@ -1,7 +1,6 @@
// Code generated by command: go run gen.go -out ../decompress_amd64.s -pkg=huff0. DO NOT EDIT. // Code generated by command: go run gen.go -out ../decompress_amd64.s -pkg=huff0. DO NOT EDIT.
//go:build amd64 && !appengine && !noasm && gc //go:build amd64 && !appengine && !noasm && gc
// +build amd64,!appengine,!noasm,gc
// func decompress4x_main_loop_amd64(ctx *decompress4xContext) // func decompress4x_main_loop_amd64(ctx *decompress4xContext)
TEXT ·decompress4x_main_loop_amd64(SB), $0-8 TEXT ·decompress4x_main_loop_amd64(SB), $0-8

@ -122,17 +122,21 @@ func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 1") return nil, errors.New("corruption detected: stream overrun 1")
} }
copy(out, buf[0][:])
copy(out[dstEvery:], buf[1][:])
copy(out[dstEvery*2:], buf[2][:])
copy(out[dstEvery*3:], buf[3][:])
out = out[bufoff:]
decoded += bufoff * 4
// There must at least be 3 buffers left. // There must at least be 3 buffers left.
if len(out) < dstEvery*3 { if len(out)-bufoff < dstEvery*3 {
d.bufs.Put(buf) d.bufs.Put(buf)
return nil, errors.New("corruption detected: stream overrun 2") return nil, errors.New("corruption detected: stream overrun 2")
} }
//copy(out, buf[0][:])
//copy(out[dstEvery:], buf[1][:])
//copy(out[dstEvery*2:], buf[2][:])
//copy(out[dstEvery*3:], buf[3][:])
*(*[bufoff]byte)(out) = buf[0]
*(*[bufoff]byte)(out[dstEvery:]) = buf[1]
*(*[bufoff]byte)(out[dstEvery*2:]) = buf[2]
*(*[bufoff]byte)(out[dstEvery*3:]) = buf[3]
out = out[bufoff:]
decoded += bufoff * 4
} }
} }
if off > 0 { if off > 0 {

@ -18,6 +18,7 @@ func load64(b []byte, i int) uint64 {
// emitLiteral writes a literal chunk and returns the number of bytes written. // emitLiteral writes a literal chunk and returns the number of bytes written.
// //
// It assumes that: // It assumes that:
//
// dst is long enough to hold the encoded bytes // dst is long enough to hold the encoded bytes
// 1 <= len(lit) && len(lit) <= 65536 // 1 <= len(lit) && len(lit) <= 65536
func emitLiteral(dst, lit []byte) int { func emitLiteral(dst, lit []byte) int {
@ -42,6 +43,7 @@ func emitLiteral(dst, lit []byte) int {
// emitCopy writes a copy chunk and returns the number of bytes written. // emitCopy writes a copy chunk and returns the number of bytes written.
// //
// It assumes that: // It assumes that:
//
// dst is long enough to hold the encoded bytes // dst is long enough to hold the encoded bytes
// 1 <= offset && offset <= 65535 // 1 <= offset && offset <= 65535
// 4 <= length && length <= 65535 // 4 <= length && length <= 65535
@ -89,6 +91,7 @@ func emitCopy(dst []byte, offset, length int) int {
// src[i:i+k-j] and src[j:k] have the same contents. // src[i:i+k-j] and src[j:k] have the same contents.
// //
// It assumes that: // It assumes that:
//
// 0 <= i && i < j && j <= len(src) // 0 <= i && i < j && j <= len(src)
func extendMatch(src []byte, i, j int) int { func extendMatch(src []byte, i, j int) int {
for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 { for ; j < len(src) && src[i] == src[j]; i, j = i+1, j+1 {
@ -105,8 +108,9 @@ func hash(u, shift uint32) uint32 {
// been written. // been written.
// //
// It also assumes that: // It also assumes that:
//
// len(dst) >= MaxEncodedLen(len(src)) && // len(dst) >= MaxEncodedLen(len(src)) &&
// minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize // minNonLiteralBlockSize <= len(src) && len(src) <= maxBlockSize
func encodeBlock(dst, src []byte) (d int) { func encodeBlock(dst, src []byte) (d int) {
// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive. // Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
// The table element type is uint16, as s < sLimit and sLimit < len(src) // The table element type is uint16, as s < sLimit and sLimit < len(src)

@ -12,6 +12,8 @@ The `zstd` package is provided as open source software using a Go standard licen
Currently the package is heavily optimized for 64 bit processors and will be significantly slower on 32 bit processors. Currently the package is heavily optimized for 64 bit processors and will be significantly slower on 32 bit processors.
For seekable zstd streams, see [this excellent package](https://github.com/SaveTheRbtz/zstd-seekable-format-go).
## Installation ## Installation
Install using `go get -u github.com/klauspost/compress`. The package is located in `github.com/klauspost/compress/zstd`. Install using `go get -u github.com/klauspost/compress`. The package is located in `github.com/klauspost/compress/zstd`.

@ -10,7 +10,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@ -651,7 +650,7 @@ func (b *blockDec) prepareSequences(in []byte, hist *history) (err error) {
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse)) fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.matchLengths.fse))
fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse)) fatalErr(binary.Write(&buf, binary.LittleEndian, hist.decoders.offsets.fse))
buf.Write(in) buf.Write(in)
ioutil.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm) os.WriteFile(filepath.Join("testdata", "seqs", fn), buf.Bytes(), os.ModePerm)
} }
return nil return nil

@ -7,7 +7,6 @@ package zstd
import ( import (
"fmt" "fmt"
"io" "io"
"io/ioutil"
) )
type byteBuffer interface { type byteBuffer interface {
@ -124,7 +123,7 @@ func (r *readerWrapper) readByte() (byte, error) {
} }
func (r *readerWrapper) skipN(n int64) error { func (r *readerWrapper) skipN(n int64) error {
n2, err := io.CopyN(ioutil.Discard, r.r, n) n2, err := io.CopyN(io.Discard, r.r, n)
if n2 != n { if n2 != n {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF
} }

@ -312,6 +312,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
// Grab a block decoder and frame decoder. // Grab a block decoder and frame decoder.
block := <-d.decoders block := <-d.decoders
frame := block.localFrame frame := block.localFrame
initialSize := len(dst)
defer func() { defer func() {
if debugDecoder { if debugDecoder {
printf("re-adding decoder: %p", block) printf("re-adding decoder: %p", block)
@ -354,7 +355,16 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
return dst, ErrWindowSizeExceeded return dst, ErrWindowSizeExceeded
} }
if frame.FrameContentSize != fcsUnknown { if frame.FrameContentSize != fcsUnknown {
if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) { if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)-initialSize) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> mcs:", d.o.maxDecodedSize-uint64(len(dst)-initialSize), "len:", len(dst))
}
return dst, ErrDecoderSizeExceeded
}
if d.o.limitToCap && frame.FrameContentSize > uint64(cap(dst)-len(dst)) {
if debugDecoder {
println("decoder size exceeded; fcs:", frame.FrameContentSize, "> (cap-len)", cap(dst)-len(dst))
}
return dst, ErrDecoderSizeExceeded return dst, ErrDecoderSizeExceeded
} }
if cap(dst)-len(dst) < int(frame.FrameContentSize) { if cap(dst)-len(dst) < int(frame.FrameContentSize) {
@ -364,7 +374,7 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
} }
} }
if cap(dst) == 0 { if cap(dst) == 0 && !d.o.limitToCap {
// Allocate len(input) * 2 by default if nothing is provided // Allocate len(input) * 2 by default if nothing is provided
// and we didn't get frame content size. // and we didn't get frame content size.
size := len(input) * 2 size := len(input) * 2
@ -382,6 +392,9 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
if err != nil { if err != nil {
return dst, err return dst, err
} }
if uint64(len(dst)-initialSize) > d.o.maxDecodedSize {
return dst, ErrDecoderSizeExceeded
}
if len(frame.bBuf) == 0 { if len(frame.bBuf) == 0 {
if debugDecoder { if debugDecoder {
println("frame dbuf empty") println("frame dbuf empty")
@ -852,6 +865,10 @@ decodeStream:
} }
} }
if err == nil && d.frame.WindowSize > d.o.maxWindowSize { if err == nil && d.frame.WindowSize > d.o.maxWindowSize {
if debugDecoder {
println("decoder size exceeded, fws:", d.frame.WindowSize, "> mws:", d.o.maxWindowSize)
}
err = ErrDecoderSizeExceeded err = ErrDecoderSizeExceeded
} }
if err != nil { if err != nil {

@ -20,6 +20,7 @@ type decoderOptions struct {
maxWindowSize uint64 maxWindowSize uint64
dicts []dict dicts []dict
ignoreChecksum bool ignoreChecksum bool
limitToCap bool
} }
func (o *decoderOptions) setDefault() { func (o *decoderOptions) setDefault() {
@ -114,6 +115,17 @@ func WithDecoderMaxWindow(size uint64) DOption {
} }
} }
// WithDecodeAllCapLimit will limit DecodeAll to decoding cap(dst)-len(dst) bytes,
// or any size set in WithDecoderMaxMemory.
// This can be used to limit decoding to a specific maximum output size.
// Disabled by default.
func WithDecodeAllCapLimit(b bool) DOption {
return func(o *decoderOptions) error {
o.limitToCap = b
return nil
}
}
// IgnoreChecksum allows to forcibly ignore checksum checking. // IgnoreChecksum allows to forcibly ignore checksum checking.
func IgnoreChecksum(b bool) DOption { func IgnoreChecksum(b bool) DOption {
return func(o *decoderOptions) error { return func(o *decoderOptions) error {

@ -416,15 +416,23 @@ encodeLoop:
// Try to find a better match by searching for a long match at the end of the current best match // Try to find a better match by searching for a long match at the end of the current best match
if s+matched < sLimit { if s+matched < sLimit {
// Allow some bytes at the beginning to mismatch.
// Sweet spot is around 3 bytes, but depends on input.
// The skipped bytes are tested in Extend backwards,
// and still picked up as part of the match if they do.
const skipBeginning = 3
nextHashL := hashLen(load6432(src, s+matched), betterLongTableBits, betterLongLen) nextHashL := hashLen(load6432(src, s+matched), betterLongTableBits, betterLongLen)
cv := load3232(src, s) s2 := s + skipBeginning
cv := load3232(src, s2)
candidateL := e.longTable[nextHashL] candidateL := e.longTable[nextHashL]
coffsetL := candidateL.offset - e.cur - matched coffsetL := candidateL.offset - e.cur - matched + skipBeginning
if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
// Found a long match, at least 4 bytes. // Found a long match, at least 4 bytes.
matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4 matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4
if matchedNext > matched { if matchedNext > matched {
t = coffsetL t = coffsetL
s = s2
matched = matchedNext matched = matchedNext
if debugMatches { if debugMatches {
println("long match at end-of-match") println("long match at end-of-match")
@ -434,12 +442,13 @@ encodeLoop:
// Check prev long... // Check prev long...
if true { if true {
coffsetL = candidateL.prev - e.cur - matched coffsetL = candidateL.prev - e.cur - matched + skipBeginning
if coffsetL >= 0 && coffsetL < s && s-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) { if coffsetL >= 0 && coffsetL < s2 && s2-coffsetL < e.maxMatchOff && cv == load3232(src, coffsetL) {
// Found a long match, at least 4 bytes. // Found a long match, at least 4 bytes.
matchedNext := e.matchlen(s+4, coffsetL+4, src) + 4 matchedNext := e.matchlen(s2+4, coffsetL+4, src) + 4
if matchedNext > matched { if matchedNext > matched {
t = coffsetL t = coffsetL
s = s2
matched = matchedNext matched = matchedNext
if debugMatches { if debugMatches {
println("prev long match at end-of-match") println("prev long match at end-of-match")

@ -1103,7 +1103,8 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
} }
if allDirty || dirtyShardCnt > dLongTableShardCnt/2 { if allDirty || dirtyShardCnt > dLongTableShardCnt/2 {
copy(e.longTable[:], e.dictLongTable) //copy(e.longTable[:], e.dictLongTable)
e.longTable = *(*[dFastLongTableSize]tableEntry)(e.dictLongTable)
for i := range e.longTableShardDirty { for i := range e.longTableShardDirty {
e.longTableShardDirty[i] = false e.longTableShardDirty[i] = false
} }
@ -1114,7 +1115,9 @@ func (e *doubleFastEncoderDict) Reset(d *dict, singleBlock bool) {
continue continue
} }
copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize]) // copy(e.longTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize], e.dictLongTable[i*dLongTableShardSize:(i+1)*dLongTableShardSize])
*(*[dLongTableShardSize]tableEntry)(e.longTable[i*dLongTableShardSize:]) = *(*[dLongTableShardSize]tableEntry)(e.dictLongTable[i*dLongTableShardSize:])
e.longTableShardDirty[i] = false e.longTableShardDirty[i] = false
} }
} }

@ -871,7 +871,8 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
const shardCnt = tableShardCnt const shardCnt = tableShardCnt
const shardSize = tableShardSize const shardSize = tableShardSize
if e.allDirty || dirtyShardCnt > shardCnt*4/6 { if e.allDirty || dirtyShardCnt > shardCnt*4/6 {
copy(e.table[:], e.dictTable) //copy(e.table[:], e.dictTable)
e.table = *(*[tableSize]tableEntry)(e.dictTable)
for i := range e.tableShardDirty { for i := range e.tableShardDirty {
e.tableShardDirty[i] = false e.tableShardDirty[i] = false
} }
@ -883,7 +884,8 @@ func (e *fastEncoderDict) Reset(d *dict, singleBlock bool) {
continue continue
} }
copy(e.table[i*shardSize:(i+1)*shardSize], e.dictTable[i*shardSize:(i+1)*shardSize]) //copy(e.table[i*shardSize:(i+1)*shardSize], e.dictTable[i*shardSize:(i+1)*shardSize])
*(*[shardSize]tableEntry)(e.table[i*shardSize:]) = *(*[shardSize]tableEntry)(e.dictTable[i*shardSize:])
e.tableShardDirty[i] = false e.tableShardDirty[i] = false
} }
e.allDirty = false e.allDirty = false

@ -353,12 +353,23 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
// Store input length, so we only check new data. // Store input length, so we only check new data.
crcStart := len(dst) crcStart := len(dst)
d.history.decoders.maxSyncLen = 0 d.history.decoders.maxSyncLen = 0
if d.o.limitToCap {
d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
}
if d.FrameContentSize != fcsUnknown { if d.FrameContentSize != fcsUnknown {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst)) if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
}
if d.history.decoders.maxSyncLen > d.o.maxDecodedSize { if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
}
return dst, ErrDecoderSizeExceeded return dst, ErrDecoderSizeExceeded
} }
if uint64(cap(dst)) < d.history.decoders.maxSyncLen { if debugDecoder {
println("maxSyncLen:", d.history.decoders.maxSyncLen)
}
if !d.o.limitToCap && uint64(cap(dst)-len(dst)) < d.history.decoders.maxSyncLen {
// Alloc for output // Alloc for output
dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc) dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
copy(dst2, dst) copy(dst2, dst)
@ -378,7 +389,13 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
if err != nil { if err != nil {
break break
} }
if uint64(len(d.history.b)) > d.o.maxDecodedSize { if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize {
println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize)
err = ErrDecoderSizeExceeded
break
}
if d.o.limitToCap && len(d.history.b) > cap(dst) {
println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst))
err = ErrDecoderSizeExceeded err = ErrDecoderSizeExceeded
break break
} }

@ -1,7 +1,6 @@
// Code generated by command: go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd. DO NOT EDIT. // Code generated by command: go run gen_fse.go -out ../fse_decoder_amd64.s -pkg=zstd. DO NOT EDIT.
//go:build !appengine && !noasm && gc && !noasm //go:build !appengine && !noasm && gc && !noasm
// +build !appengine,!noasm,gc,!noasm
// func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int // func buildDtable_asm(s *fseDecoder, ctx *buildDtableAsmContext) int
TEXT ·buildDtable_asm(SB), $0-24 TEXT ·buildDtable_asm(SB), $0-24

@ -32,18 +32,22 @@ type decodeSyncAsmContext struct {
// sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm. // sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions. // sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_safe_amd64 does the same as above, but does not write more than output buffer. // sequenceDecs_decodeSync_safe_amd64 does the same as above, but does not write more than output buffer.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
// sequenceDecs_decodeSync_safe_bmi2 does the same as above, but does not write more than output buffer. // sequenceDecs_decodeSync_safe_bmi2 does the same as above, but does not write more than output buffer.
//
//go:noescape //go:noescape
func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
@ -201,20 +205,24 @@ const errorNotEnoughSpace = 5
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
//
//go:noescape //go:noescape
func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions. // sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm with BMI2 extensions.
//
//go:noescape //go:noescape
func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
@ -308,10 +316,12 @@ type executeAsmContext struct {
// Returns false if a match offset is too big. // Returns false if a match offset is too big.
// //
// Please refer to seqdec_generic.go for the reference implementation. // Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape //go:noescape
func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool
// Same as above, but with safe memcopies // Same as above, but with safe memcopies
//
//go:noescape //go:noescape
func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool

@ -1,7 +1,6 @@
// Code generated by command: go run gen.go -out ../seqdec_amd64.s -pkg=zstd. DO NOT EDIT. // Code generated by command: go run gen.go -out ../seqdec_amd64.s -pkg=zstd. DO NOT EDIT.
//go:build !appengine && !noasm && gc && !noasm //go:build !appengine && !noasm && gc && !noasm
// +build !appengine,!noasm,gc,!noasm
// func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int // func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
// Requires: CMOV // Requires: CMOV

@ -214,11 +214,6 @@ esac
if [ "$GOOSARCH" == "aix_ppc64" ]; then if [ "$GOOSARCH" == "aix_ppc64" ]; then
# aix/ppc64 script generates files instead of writing to stdin. # aix/ppc64 script generates files instead of writing to stdin.
echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in && gofmt -w zsyscall_$GOOSARCH.go && gofmt -w zsyscall_"$GOOSARCH"_gccgo.go && gofmt -w zsyscall_"$GOOSARCH"_gc.go " ; echo "$mksyscall -tags $GOOS,$GOARCH $syscall_goos $GOOSARCH_in && gofmt -w zsyscall_$GOOSARCH.go && gofmt -w zsyscall_"$GOOSARCH"_gccgo.go && gofmt -w zsyscall_"$GOOSARCH"_gc.go " ;
elif [ "$GOOS" == "darwin" ]; then
# 1.12 and later, syscalls via libSystem
echo "$mksyscall -tags $GOOS,$GOARCH,go1.12 $syscall_goos $GOOSARCH_in |gofmt >zsyscall_$GOOSARCH.go";
# 1.13 and later, syscalls via libSystem (including syscallPtr)
echo "$mksyscall -tags $GOOS,$GOARCH,go1.13 syscall_darwin.1_13.go |gofmt >zsyscall_$GOOSARCH.1_13.go";
elif [ "$GOOS" == "illumos" ]; then elif [ "$GOOS" == "illumos" ]; then
# illumos code generation requires a --illumos switch # illumos code generation requires a --illumos switch
echo "$mksyscall -illumos -tags illumos,$GOARCH syscall_illumos.go |gofmt > zsyscall_illumos_$GOARCH.go"; echo "$mksyscall -illumos -tags illumos,$GOARCH syscall_illumos.go |gofmt > zsyscall_illumos_$GOARCH.go";

@ -1,32 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.12 && !go1.13
// +build darwin,go1.12,!go1.13
package unix
import (
"unsafe"
)
const _SYS_GETDIRENTRIES64 = 344
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// To implement this using libSystem we'd need syscall_syscallPtr for
// fdopendir. However, syscallPtr was only added in Go 1.13, so we fall
// back to raw syscalls for this func on Go 1.12.
var p unsafe.Pointer
if len(buf) > 0 {
p = unsafe.Pointer(&buf[0])
} else {
p = unsafe.Pointer(&_zero)
}
r0, _, e1 := Syscall6(_SYS_GETDIRENTRIES64, uintptr(fd), uintptr(p), uintptr(len(buf)), uintptr(unsafe.Pointer(basep)), 0, 0)
n = int(r0)
if e1 != 0 {
return n, errnoErr(e1)
}
return n, nil
}

@ -1,100 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build darwin && go1.13
// +build darwin,go1.13
package unix
import "unsafe"
//sys closedir(dir uintptr) (err error)
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
func fdopendir(fd int) (dir uintptr, err error) {
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
dir = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_fdopendir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
// We store the number of entries to skip in the seek
// offset of fd. See issue #31368.
// It's not the full required semantics, but should handle the case
// of calling Getdirentries or ReadDirent repeatedly.
// It won't handle assigning the results of lseek to *basep, or handle
// the directory being edited underfoot.
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
if err != nil {
return 0, err
}
// We need to duplicate the incoming file descriptor
// because the caller expects to retain control of it, but
// fdopendir expects to take control of its argument.
// Just Dup'ing the file descriptor is not enough, as the
// result shares underlying state. Use Openat to make a really
// new file descriptor referring to the same directory.
fd2, err := Openat(fd, ".", O_RDONLY, 0)
if err != nil {
return 0, err
}
d, err := fdopendir(fd2)
if err != nil {
Close(fd2)
return 0, err
}
defer closedir(d)
var cnt int64
for {
var entry Dirent
var entryp *Dirent
e := readdir_r(d, &entry, &entryp)
if e != 0 {
return n, errnoErr(e)
}
if entryp == nil {
break
}
if skip > 0 {
skip--
cnt++
continue
}
reclen := int(entry.Reclen)
if reclen > len(buf) {
// Not enough room. Return for now.
// The counter will let us know where we should start up again.
// Note: this strategy for suspending in the middle and
// restarting is O(n^2) in the length of the directory. Oh well.
break
}
// Copy entry into return buffer.
s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen)
copy(buf, s)
buf = buf[reclen:]
n += reclen
cnt++
}
// Set the seek offset of the input fd to record
// how many files we've already returned.
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
if err != nil {
return n, err
}
return n, nil
}

@ -19,6 +19,96 @@ import (
"unsafe" "unsafe"
) )
//sys closedir(dir uintptr) (err error)
//sys readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno)
func fdopendir(fd int) (dir uintptr, err error) {
r0, _, e1 := syscall_syscallPtr(libc_fdopendir_trampoline_addr, uintptr(fd), 0, 0)
dir = uintptr(r0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_fdopendir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_fdopendir fdopendir "/usr/lib/libSystem.B.dylib"
func Getdirentries(fd int, buf []byte, basep *uintptr) (n int, err error) {
// Simulate Getdirentries using fdopendir/readdir_r/closedir.
// We store the number of entries to skip in the seek
// offset of fd. See issue #31368.
// It's not the full required semantics, but should handle the case
// of calling Getdirentries or ReadDirent repeatedly.
// It won't handle assigning the results of lseek to *basep, or handle
// the directory being edited underfoot.
skip, err := Seek(fd, 0, 1 /* SEEK_CUR */)
if err != nil {
return 0, err
}
// We need to duplicate the incoming file descriptor
// because the caller expects to retain control of it, but
// fdopendir expects to take control of its argument.
// Just Dup'ing the file descriptor is not enough, as the
// result shares underlying state. Use Openat to make a really
// new file descriptor referring to the same directory.
fd2, err := Openat(fd, ".", O_RDONLY, 0)
if err != nil {
return 0, err
}
d, err := fdopendir(fd2)
if err != nil {
Close(fd2)
return 0, err
}
defer closedir(d)
var cnt int64
for {
var entry Dirent
var entryp *Dirent
e := readdir_r(d, &entry, &entryp)
if e != 0 {
return n, errnoErr(e)
}
if entryp == nil {
break
}
if skip > 0 {
skip--
cnt++
continue
}
reclen := int(entry.Reclen)
if reclen > len(buf) {
// Not enough room. Return for now.
// The counter will let us know where we should start up again.
// Note: this strategy for suspending in the middle and
// restarting is O(n^2) in the length of the directory. Oh well.
break
}
// Copy entry into return buffer.
s := unsafe.Slice((*byte)(unsafe.Pointer(&entry)), reclen)
copy(buf, s)
buf = buf[reclen:]
n += reclen
cnt++
}
// Set the seek offset of the input fd to record
// how many files we've already returned.
_, err = Seek(fd, cnt, 0 /* SEEK_SET */)
if err != nil {
return n, err
}
return n, nil
}
// SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets. // SockaddrDatalink implements the Sockaddr interface for AF_LINK type sockets.
type SockaddrDatalink struct { type SockaddrDatalink struct {
Len uint8 Len uint8

@ -1,40 +0,0 @@
// go run mksyscall.go -tags darwin,amd64,go1.13 syscall_darwin.1_13.go
// Code generated by the command above; see README.md. DO NOT EDIT.
//go:build darwin && amd64 && go1.13
// +build darwin,amd64,go1.13
package unix
import (
"syscall"
"unsafe"
)
var _ syscall.Errno
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func closedir(dir uintptr) (err error) {
_, _, e1 := syscall_syscall(libc_closedir_trampoline_addr, uintptr(dir), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_closedir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_closedir closedir "/usr/lib/libSystem.B.dylib"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno) {
r0, _, _ := syscall_syscall(libc_readdir_r_trampoline_addr, uintptr(dir), uintptr(unsafe.Pointer(entry)), uintptr(unsafe.Pointer(result)))
res = Errno(r0)
return
}
var libc_readdir_r_trampoline_addr uintptr
//go:cgo_import_dynamic libc_readdir_r readdir_r "/usr/lib/libSystem.B.dylib"

@ -1,25 +0,0 @@
// go run mkasm.go darwin amd64
// Code generated by the command above; DO NOT EDIT.
//go:build go1.13
// +build go1.13
#include "textflag.h"
TEXT libc_fdopendir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_fdopendir(SB)
GLOBL ·libc_fdopendir_trampoline_addr(SB), RODATA, $8
DATA ·libc_fdopendir_trampoline_addr(SB)/8, $libc_fdopendir_trampoline<>(SB)
TEXT libc_closedir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_closedir(SB)
GLOBL ·libc_closedir_trampoline_addr(SB), RODATA, $8
DATA ·libc_closedir_trampoline_addr(SB)/8, $libc_closedir_trampoline<>(SB)
TEXT libc_readdir_r_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_readdir_r(SB)
GLOBL ·libc_readdir_r_trampoline_addr(SB), RODATA, $8
DATA ·libc_readdir_r_trampoline_addr(SB)/8, $libc_readdir_r_trampoline<>(SB)

@ -1,8 +1,8 @@
// go run mksyscall.go -tags darwin,amd64,go1.12 syscall_bsd.go syscall_darwin.go syscall_darwin_amd64.go // go run mksyscall.go -tags darwin,amd64 syscall_bsd.go syscall_darwin.go syscall_darwin_amd64.go
// Code generated by the command above; see README.md. DO NOT EDIT. // Code generated by the command above; see README.md. DO NOT EDIT.
//go:build darwin && amd64 && go1.12 //go:build darwin && amd64
// +build darwin,amd64,go1.12 // +build darwin,amd64
package unix package unix
@ -463,6 +463,32 @@ var libc_munlockall_trampoline_addr uintptr
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func closedir(dir uintptr) (err error) {
_, _, e1 := syscall_syscall(libc_closedir_trampoline_addr, uintptr(dir), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_closedir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_closedir closedir "/usr/lib/libSystem.B.dylib"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno) {
r0, _, _ := syscall_syscall(libc_readdir_r_trampoline_addr, uintptr(dir), uintptr(unsafe.Pointer(entry)), uintptr(unsafe.Pointer(result)))
res = Errno(r0)
return
}
var libc_readdir_r_trampoline_addr uintptr
//go:cgo_import_dynamic libc_readdir_r readdir_r "/usr/lib/libSystem.B.dylib"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pipe(p *[2]int32) (err error) { func pipe(p *[2]int32) (err error) {
_, _, e1 := syscall_rawSyscall(libc_pipe_trampoline_addr, uintptr(unsafe.Pointer(p)), 0, 0) _, _, e1 := syscall_rawSyscall(libc_pipe_trampoline_addr, uintptr(unsafe.Pointer(p)), 0, 0)
if e1 != 0 { if e1 != 0 {

@ -1,11 +1,14 @@
// go run mkasm.go darwin amd64 // go run mkasm.go darwin amd64
// Code generated by the command above; DO NOT EDIT. // Code generated by the command above; DO NOT EDIT.
//go:build go1.12
// +build go1.12
#include "textflag.h" #include "textflag.h"
TEXT libc_fdopendir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_fdopendir(SB)
GLOBL ·libc_fdopendir_trampoline_addr(SB), RODATA, $8
DATA ·libc_fdopendir_trampoline_addr(SB)/8, $libc_fdopendir_trampoline<>(SB)
TEXT libc_getgroups_trampoline<>(SB),NOSPLIT,$0-0 TEXT libc_getgroups_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_getgroups(SB) JMP libc_getgroups(SB)
@ -174,6 +177,18 @@ TEXT libc_munlockall_trampoline<>(SB),NOSPLIT,$0-0
GLOBL ·libc_munlockall_trampoline_addr(SB), RODATA, $8 GLOBL ·libc_munlockall_trampoline_addr(SB), RODATA, $8
DATA ·libc_munlockall_trampoline_addr(SB)/8, $libc_munlockall_trampoline<>(SB) DATA ·libc_munlockall_trampoline_addr(SB)/8, $libc_munlockall_trampoline<>(SB)
TEXT libc_closedir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_closedir(SB)
GLOBL ·libc_closedir_trampoline_addr(SB), RODATA, $8
DATA ·libc_closedir_trampoline_addr(SB)/8, $libc_closedir_trampoline<>(SB)
TEXT libc_readdir_r_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_readdir_r(SB)
GLOBL ·libc_readdir_r_trampoline_addr(SB), RODATA, $8
DATA ·libc_readdir_r_trampoline_addr(SB)/8, $libc_readdir_r_trampoline<>(SB)
TEXT libc_pipe_trampoline<>(SB),NOSPLIT,$0-0 TEXT libc_pipe_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_pipe(SB) JMP libc_pipe(SB)

@ -1,40 +0,0 @@
// go run mksyscall.go -tags darwin,arm64,go1.13 syscall_darwin.1_13.go
// Code generated by the command above; see README.md. DO NOT EDIT.
//go:build darwin && arm64 && go1.13
// +build darwin,arm64,go1.13
package unix
import (
"syscall"
"unsafe"
)
var _ syscall.Errno
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func closedir(dir uintptr) (err error) {
_, _, e1 := syscall_syscall(libc_closedir_trampoline_addr, uintptr(dir), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_closedir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_closedir closedir "/usr/lib/libSystem.B.dylib"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno) {
r0, _, _ := syscall_syscall(libc_readdir_r_trampoline_addr, uintptr(dir), uintptr(unsafe.Pointer(entry)), uintptr(unsafe.Pointer(result)))
res = Errno(r0)
return
}
var libc_readdir_r_trampoline_addr uintptr
//go:cgo_import_dynamic libc_readdir_r readdir_r "/usr/lib/libSystem.B.dylib"

@ -1,25 +0,0 @@
// go run mkasm.go darwin arm64
// Code generated by the command above; DO NOT EDIT.
//go:build go1.13
// +build go1.13
#include "textflag.h"
TEXT libc_fdopendir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_fdopendir(SB)
GLOBL ·libc_fdopendir_trampoline_addr(SB), RODATA, $8
DATA ·libc_fdopendir_trampoline_addr(SB)/8, $libc_fdopendir_trampoline<>(SB)
TEXT libc_closedir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_closedir(SB)
GLOBL ·libc_closedir_trampoline_addr(SB), RODATA, $8
DATA ·libc_closedir_trampoline_addr(SB)/8, $libc_closedir_trampoline<>(SB)
TEXT libc_readdir_r_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_readdir_r(SB)
GLOBL ·libc_readdir_r_trampoline_addr(SB), RODATA, $8
DATA ·libc_readdir_r_trampoline_addr(SB)/8, $libc_readdir_r_trampoline<>(SB)

@ -1,8 +1,8 @@
// go run mksyscall.go -tags darwin,arm64,go1.12 syscall_bsd.go syscall_darwin.go syscall_darwin_arm64.go // go run mksyscall.go -tags darwin,arm64 syscall_bsd.go syscall_darwin.go syscall_darwin_arm64.go
// Code generated by the command above; see README.md. DO NOT EDIT. // Code generated by the command above; see README.md. DO NOT EDIT.
//go:build darwin && arm64 && go1.12 //go:build darwin && arm64
// +build darwin,arm64,go1.12 // +build darwin,arm64
package unix package unix
@ -463,6 +463,32 @@ var libc_munlockall_trampoline_addr uintptr
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func closedir(dir uintptr) (err error) {
_, _, e1 := syscall_syscall(libc_closedir_trampoline_addr, uintptr(dir), 0, 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
var libc_closedir_trampoline_addr uintptr
//go:cgo_import_dynamic libc_closedir closedir "/usr/lib/libSystem.B.dylib"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func readdir_r(dir uintptr, entry *Dirent, result **Dirent) (res Errno) {
r0, _, _ := syscall_syscall(libc_readdir_r_trampoline_addr, uintptr(dir), uintptr(unsafe.Pointer(entry)), uintptr(unsafe.Pointer(result)))
res = Errno(r0)
return
}
var libc_readdir_r_trampoline_addr uintptr
//go:cgo_import_dynamic libc_readdir_r readdir_r "/usr/lib/libSystem.B.dylib"
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func pipe(p *[2]int32) (err error) { func pipe(p *[2]int32) (err error) {
_, _, e1 := syscall_rawSyscall(libc_pipe_trampoline_addr, uintptr(unsafe.Pointer(p)), 0, 0) _, _, e1 := syscall_rawSyscall(libc_pipe_trampoline_addr, uintptr(unsafe.Pointer(p)), 0, 0)
if e1 != 0 { if e1 != 0 {

@ -1,11 +1,14 @@
// go run mkasm.go darwin arm64 // go run mkasm.go darwin arm64
// Code generated by the command above; DO NOT EDIT. // Code generated by the command above; DO NOT EDIT.
//go:build go1.12
// +build go1.12
#include "textflag.h" #include "textflag.h"
TEXT libc_fdopendir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_fdopendir(SB)
GLOBL ·libc_fdopendir_trampoline_addr(SB), RODATA, $8
DATA ·libc_fdopendir_trampoline_addr(SB)/8, $libc_fdopendir_trampoline<>(SB)
TEXT libc_getgroups_trampoline<>(SB),NOSPLIT,$0-0 TEXT libc_getgroups_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_getgroups(SB) JMP libc_getgroups(SB)
@ -174,6 +177,18 @@ TEXT libc_munlockall_trampoline<>(SB),NOSPLIT,$0-0
GLOBL ·libc_munlockall_trampoline_addr(SB), RODATA, $8 GLOBL ·libc_munlockall_trampoline_addr(SB), RODATA, $8
DATA ·libc_munlockall_trampoline_addr(SB)/8, $libc_munlockall_trampoline<>(SB) DATA ·libc_munlockall_trampoline_addr(SB)/8, $libc_munlockall_trampoline<>(SB)
TEXT libc_closedir_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_closedir(SB)
GLOBL ·libc_closedir_trampoline_addr(SB), RODATA, $8
DATA ·libc_closedir_trampoline_addr(SB)/8, $libc_closedir_trampoline<>(SB)
TEXT libc_readdir_r_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_readdir_r(SB)
GLOBL ·libc_readdir_r_trampoline_addr(SB), RODATA, $8
DATA ·libc_readdir_r_trampoline_addr(SB)/8, $libc_readdir_r_trampoline<>(SB)
TEXT libc_pipe_trampoline<>(SB),NOSPLIT,$0-0 TEXT libc_pipe_trampoline<>(SB),NOSPLIT,$0-0
JMP libc_pipe(SB) JMP libc_pipe(SB)

3
vendor/gorm.io/gorm/.gitignore generated vendored

@ -3,4 +3,5 @@ documents
coverage.txt coverage.txt
_book _book
.idea .idea
vendor vendor
.vscode

@ -507,7 +507,9 @@ func (association *Association) buildCondition() *DB {
joinStmt.AddClause(queryClause) joinStmt.AddClause(queryClause)
} }
joinStmt.Build("WHERE") joinStmt.Build("WHERE")
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
} }
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{

@ -206,7 +206,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
} }
} }
cacheKey := utils.ToStringKey(relPrimaryValues) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true identityMap[cacheKey] = true
if isPtr { if isPtr {
@ -292,7 +292,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
} }
} }
cacheKey := utils.ToStringKey(relPrimaryValues) cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
identityMap[cacheKey] = true identityMap[cacheKey] = true
distinctElems = reflect.Append(distinctElems, elem) distinctElems = reflect.Append(distinctElems, elem)

@ -70,10 +70,12 @@ func Update(config *Config) func(db *gorm.DB) {
if db.Statement.SQL.Len() == 0 { if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180) db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{}) db.Statement.AddClauseIfNotExists(clause.Update{})
if set := ConvertToAssignments(db.Statement); len(set) != 0 { if _, ok := db.Statement.Clauses["SET"]; !ok {
db.Statement.AddClause(set) if set := ConvertToAssignments(db.Statement); len(set) != 0 {
} else if _, ok := db.Statement.Clauses["SET"]; !ok { db.Statement.AddClause(set)
return } else {
return
}
} }
db.Statement.Build(db.Statement.BuildClauses...) db.Statement.Build(db.Statement.BuildClauses...)
@ -158,21 +160,21 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
switch stmt.ReflectValue.Kind() { switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 { if size := stmt.ReflectValue.Len(); size > 0 {
var primaryKeyExprs []clause.Expression var isZero bool
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
exprs := make([]clause.Expression, len(stmt.Schema.PrimaryFields)) for _, field := range stmt.Schema.PrimaryFields {
var notZero bool _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
for idx, field := range stmt.Schema.PrimaryFields { if !isZero {
value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) break
exprs[idx] = clause.Eq{Column: field.DBName, Value: value} }
notZero = notZero || !isZero
}
if notZero {
primaryKeyExprs = append(primaryKeyExprs, clause.And(exprs...))
} }
} }
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.And(clause.Or(primaryKeyExprs...))}}) if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
} }
case reflect.Struct: case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields { for _, field := range stmt.Schema.PrimaryFields {

@ -13,7 +13,7 @@ import (
"gorm.io/gorm/utils" "gorm.io/gorm/utils"
) )
// Create insert the value into database // Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) { func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 { if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize) return db.CreateInBatches(value, db.CreateBatchSize)
@ -24,7 +24,7 @@ func (db *DB) Create(value interface{}) (tx *DB) {
return tx.callbacks.Create().Execute(tx) return tx.callbacks.Create().Execute(tx)
} }
// CreateInBatches insert the value in batches into database // CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value)) reflectValue := reflect.Indirect(reflect.ValueOf(value))
@ -68,7 +68,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
return return
} }
// Save update value in database, if the value doesn't have primary key, will insert it // Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) { func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = value tx.Statement.Dest = value
@ -114,7 +114,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
return return
} }
// First find first record that match given conditions, order by primary key // First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -129,7 +129,7 @@ func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Take return a record that match given conditions, the order will depend on the database implementation // Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1) tx = db.Limit(1)
if len(conds) > 0 { if len(conds) > 0 {
@ -142,7 +142,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Last find last record that match given conditions, order by primary key // Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{ tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -158,7 +158,7 @@ func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// Find find records that match given conditions // Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
@ -170,7 +170,7 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
return tx.callbacks.Query().Execute(tx) return tx.callbacks.Query().Execute(tx)
} }
// FindInBatches find records in batches // FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var ( var (
tx = db.Order(clause.OrderByColumn{ tx = db.Order(clause.OrderByColumn{
@ -202,7 +202,9 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
batch++ batch++
if result.Error == nil && result.RowsAffected != 0 { if result.Error == nil && result.RowsAffected != 0 {
tx.AddError(fc(result, batch)) fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil { } else if result.Error != nil {
tx.AddError(result.Error) tx.AddError(result.Error)
} }
@ -284,7 +286,8 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
} }
} }
// FirstOrInit gets the first matched record or initialize a new instance with given conditions (only works with struct or map conditions) // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{ queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
@ -310,7 +313,8 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
return return
} }
// FirstOrCreate gets the first matched record or create a new one with given conditions (only works with struct, map conditions) // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
@ -358,14 +362,14 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
return tx return tx
} }
// Update update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields // Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) { func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.Dest = map[string]interface{}{column: value}
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Updates update attributes with callbacks, refer: https://gorm.io/docs/update.html#Update-Changed-Fields // Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.Dest = values tx.Statement.Dest = values
@ -386,7 +390,9 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return tx.callbacks.Update().Execute(tx) return tx.callbacks.Update().Execute(tx)
} }
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition // Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) { func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
if len(conds) > 0 { if len(conds) > 0 {
@ -480,7 +486,7 @@ func (db *DB) Rows() (*sql.Rows, error) {
return rows, tx.Error return rows, tx.Error
} }
// Scan scan value to a struct // Scan scans selected value to the struct dest
func (db *DB) Scan(dest interface{}) (tx *DB) { func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New() currentLogger, newLogger := config.Logger, logger.Recorder.New()
@ -505,7 +511,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
return return
} }
// Pluck used to query single column from a model as a map // Pluck queries a single column from a model, returning in the slice dest. E.g.:
// var ages []int64 // var ages []int64
// db.Model(&users).Pluck("age", &ages) // db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
@ -548,7 +554,8 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
return tx.Error return tx.Error
} }
// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed. // Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) { func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil { if db.Error != nil {
return db.Error return db.Error
@ -570,7 +577,9 @@ func (db *DB) Connection(fc func(tx *DB) error) (err error) {
return fc(tx) return fc(tx)
} }
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
// they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true panicked := true
@ -613,7 +622,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
return return
} }
// Begin begins a transaction // Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) *DB { func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var ( var (
// clone statement // clone statement
@ -642,7 +651,7 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
return tx return tx
} }
// Commit commit a transaction // Commit commits the changes in a transaction
func (db *DB) Commit() *DB { func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit()) db.AddError(committer.Commit())
@ -652,7 +661,7 @@ func (db *DB) Commit() *DB {
return db return db
} }
// Rollback rollback a transaction // Rollback rollbacks the changes in a transaction
func (db *DB) Rollback() *DB { func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() { if !reflect.ValueOf(committer).IsNil() {
@ -682,7 +691,7 @@ func (db *DB) RollbackTo(name string) *DB {
return db return db
} }
// Exec execute raw sql // Exec executes raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance() tx = db.getInstance()
tx.Statement.SQL = strings.Builder{} tx.Statement.SQL = strings.Builder{}

5
vendor/gorm.io/gorm/gorm.go generated vendored

@ -300,7 +300,8 @@ func (db *DB) WithContext(ctx context.Context) *DB {
// Debug start debug mode // Debug start debug mode
func (db *DB) Debug() (tx *DB) { func (db *DB) Debug() (tx *DB) {
return db.Session(&Session{ tx = db.getInstance()
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info), Logger: db.Logger.LogMode(logger.Info),
}) })
} }
@ -412,7 +413,7 @@ func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interfac
relation, ok := modelSchema.Relationships.Relations[field] relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil isRelation := ok && relation.JoinTable != nil
if !isRelation { if !isRelation {
return fmt.Errorf("failed to found relation: %s", field) return fmt.Errorf("failed to find relation: %s", field)
} }
for _, ref := range relation.References { for _, ref := range relation.References {

@ -4,7 +4,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io"
"log" "log"
"os" "os"
"time" "time"
@ -68,8 +68,8 @@ type Interface interface {
} }
var ( var (
// Discard Discard logger will print any log to ioutil.Discard // Discard Discard logger will print any log to io.Discard
Discard = New(log.New(ioutil.Discard, "", log.LstdFlags), Config{}) Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger // Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond, SlowThreshold: 200 * time.Millisecond,

17
vendor/gorm.io/gorm/logger/sql.go generated vendored

@ -30,6 +30,8 @@ func isPrintable(s string) bool {
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var ( var (
@ -138,9 +140,18 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
sql = newSQL.String() sql = newSQL.String()
} else { } else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
for idx, v := range vars {
sql = strings.Replace(sql, "$"+strconv.Itoa(idx+1)+"$", v, 1) sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
} num := v[1 : len(v)-1]
n, _ := strconv.Atoi(num)
// position var start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
}
return v
})
} }
return sql return sql

@ -15,7 +15,7 @@ import (
) )
var ( var (
regFullDataType = regexp.MustCompile(`[^\d]*(\d+)[^\d]?`) regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
) )
// Migrator m struct // Migrator m struct
@ -135,12 +135,12 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
} }
} }
}
for _, chk := range stmt.Schema.ParseCheckConstraints() { for _, chk := range stmt.Schema.ParseCheckConstraints() {
if !tx.Migrator().HasConstraint(value, chk.Name) { if !tx.Migrator().HasConstraint(value, chk.Name) {
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err return err
}
} }
} }
} }

38
vendor/gorm.io/gorm/scan.go generated vendored

@ -66,30 +66,32 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int
db.RowsAffected++ db.RowsAffected++
db.AddError(rows.Scan(values...)) db.AddError(rows.Scan(values...))
joinedSchemaMap := make(map[*schema.Field]interface{}, 0) joinedSchemaMap := make(map[*schema.Field]interface{})
for idx, field := range fields { for idx, field := range fields {
if field != nil { if field == nil {
if len(joinFields) == 0 || joinFields[idx][0] == nil { continue
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) }
} else {
joinSchema := joinFields[idx][0]
relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr {
if _, ok := joinedSchemaMap[joinSchema]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
continue
}
relValue.Set(reflect.New(relValue.Type().Elem())) if len(joinFields) == 0 || joinFields[idx][0] == nil {
joinedSchemaMap[joinSchema] = nil db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else {
joinSchema := joinFields[idx][0]
relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue)
if relValue.Kind() == reflect.Ptr {
if _, ok := joinedSchemaMap[joinSchema]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
continue
} }
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedSchemaMap[joinSchema] = nil
} }
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
} }
db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx]))
// release data to pool
field.NewValuePool.Put(values[idx])
} }
// release data to pool
field.NewValuePool.Put(values[idx])
} }
} }

@ -403,18 +403,14 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
} }
if ef.PrimaryKey { if ef.PrimaryKey {
if val, ok := ef.TagSettings["PRIMARYKEY"]; ok && utils.CheckTruth(val) { if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
ef.PrimaryKey = true
} else if val, ok := ef.TagSettings["PRIMARY_KEY"]; ok && utils.CheckTruth(val) {
ef.PrimaryKey = true
} else {
ef.PrimaryKey = false ef.PrimaryKey = false
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) { if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
ef.AutoIncrement = false ef.AutoIncrement = false
} }
if ef.DefaultValue == "" { if !ef.AutoIncrement && ef.DefaultValue == "" {
ef.HasDefaultValue = false ef.HasDefaultValue = false
} }
} }
@ -472,9 +468,6 @@ func (field *Field) setupValuerAndSetter() {
oldValuerOf := field.ValueOf oldValuerOf := field.ValueOf
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) { field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
value, zero := oldValuerOf(ctx, v) value, zero := oldValuerOf(ctx, v)
if zero {
return value, zero
}
s, ok := value.(SerializerValuerInterface) s, ok := value.(SerializerValuerInterface)
if !ok { if !ok {
@ -487,7 +480,7 @@ func (field *Field) setupValuerAndSetter() {
Destination: v, Destination: v,
Context: ctx, Context: ctx,
fieldValue: value, fieldValue: value,
}, false }, zero
} }
} }

@ -112,7 +112,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
schemaCacheKey = modelType schemaCacheKey = modelType
} }
// Load exist schmema cache, return if exists // Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok { if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete
@ -146,7 +146,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam
// When the schema initialization is completed, the channel will be closed // When the schema initialization is completed, the channel will be closed
defer close(schema.initialized) defer close(schema.initialized)
// Load exist schmema cache, return if exists // Load exist schema cache, return if exists
if v, ok := cacheStore.Load(schemaCacheKey); ok { if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema) s := v.(*Schema)
// Wait for the initialization of other goroutines to complete // Wait for the initialization of other goroutines to complete

@ -88,7 +88,9 @@ func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue) return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
} }
err = json.Unmarshal(bytes, fieldValue.Interface()) if len(bytes) > 0 {
err = json.Unmarshal(bytes, fieldValue.Interface())
}
} }
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
@ -117,9 +119,15 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.
// Value implements serializer interface // Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) { func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) { switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16, *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16: case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(reflect.ValueOf(v)).Int(), 0) result = time.Unix(reflect.Indirect(rv).Int(), 0)
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0)
default: default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v) err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
} }
@ -142,8 +150,10 @@ func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value,
default: default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue) return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
} }
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue)) if len(bytesValue) > 0 {
err = decoder.Decode(fieldValue.Interface()) decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
} }
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return return

6
vendor/gorm.io/gorm/statement.go generated vendored

@ -650,7 +650,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false return false
} }
var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_0-9]+?)[\W]?\.[\W]?([a-z_0-9]+?)[\W]?$`) var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
@ -672,8 +672,8 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
} }
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = true results[field.DBName] = true
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 { } else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
results[matches[1]] = true results[matches[2]] = true
} else { } else {
results[column] = true results[column] = true
} }

Loading…
Cancel
Save