diff --git a/utils/dorm/mongo_official.go b/utils/dorm/mongo_official.go index 6fde4feb..e1c211d5 100644 --- a/utils/dorm/mongo_official.go +++ b/utils/dorm/mongo_official.go @@ -8,31 +8,31 @@ import ( ) // InsertOne 插入单个文档 -func (c *MongoClient) InsertOne(ctx context.Context, document interface{}, opts ...*options.InsertOneOptions) (result *mongo.InsertOneResult, err error) { +func (c *MongoClient) InsertOne(document interface{}) (result *mongo.InsertOneResult, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - result, err = collection.InsertOne(ctx, document, opts...) + result, err = collection.InsertOne(context.TODO(), document) return result, err } // InsertMany 插入多个文档 -func (c *MongoClient) InsertMany(ctx context.Context, documents []interface{}, opts ...*options.InsertManyOptions) (result *mongo.InsertManyResult, err error) { +func (c *MongoClient) InsertMany(documents []interface{}) (result *mongo.InsertManyResult, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - result, err = collection.InsertMany(ctx, documents, opts...) + result, err = collection.InsertMany(context.TODO(), documents) return result, err } // Delete 删除文档 -func (c *MongoClient) Delete(ctx context.Context, filter interface{}, opts ...*options.DeleteOptions) (int64, error) { +func (c *MongoClient) Delete(filter interface{}) (int64, error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - count, err := collection.DeleteOne(ctx, filter, opts...) + count, err := collection.DeleteOne(context.TODO(), filter) return count.DeletedCount, err } // DeleteMany 删除多个文档 -func (c *MongoClient) DeleteMany(ctx context.Context, key string, value interface{}) (int64, error) { +func (c *MongoClient) DeleteMany(key string, value interface{}) (int64, error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) filter := bson.D{{key, value}} - count, err := collection.DeleteMany(ctx, filter) + count, err := collection.DeleteMany(context.TODO(), filter) return count.DeletedCount, err } @@ -41,9 +41,9 @@ func (c *MongoClient) DeleteMany(ctx context.Context, key string, value interfac // 字段增加值 inc($inc) // 从数组中增加一个元素 push($push) // 从数组中删除一个元素 pull($pull) -func (c *MongoClient) UpdateOne(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (int64, error) { +func (c *MongoClient) UpdateOne(filter interface{}, update interface{}) (int64, error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - result, err := collection.UpdateOne(ctx, filter, update, opts...) + result, err := collection.UpdateOne(context.TODO(), filter, update) return result.UpsertedCount, err } @@ -52,82 +52,82 @@ func (c *MongoClient) UpdateOne(ctx context.Context, filter interface{}, update // 字段增加值 inc($inc) // 从数组中增加一个元素 push($push) // 从数组中删除一个元素 pull($pull) -func (c *MongoClient) UpdateMany(ctx context.Context, filter interface{}, update interface{}, opts ...*options.UpdateOptions) (int64, error) { +func (c *MongoClient) UpdateMany(filter interface{}, update interface{}) (int64, error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - result, err := collection.UpdateMany(ctx, filter, update, opts...) + result, err := collection.UpdateMany(context.TODO(), filter, update) return result.UpsertedCount, err } // Find 查询 -//func (c *MongoClient) Find(filter interface{}, opts ...*options.FindOptions) (result *mongo.Cursor, err error) { +//func (c *MongoClient) Find(filter interface{}) (result *mongo.Cursor, err error) { // collection := c.Db.Database(c.DatabaseName).Collection(c.collectionName) -// result, err = collection.Find(context.TODO(), filter, opts...) +// result, err = collection.Find(context.TODO(), filter) // return result, err //} // FindOne 查询单个文档 -func (c *MongoClient) FindOne(ctx context.Context, filter interface{}, opts ...*options.FindOneOptions) (result *mongo.SingleResult) { +func (c *MongoClient) FindOne(filter interface{}) (result *mongo.SingleResult) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - result = collection.FindOne(ctx, filter, opts...) + result = collection.FindOne(context.TODO(), filter) return result } // FindMany 查询多个文档 -func (c *MongoClient) FindMany(ctx context.Context, filter interface{}, opts ...*options.FindOptions) (result *mongo.Cursor, err error) { +func (c *MongoClient) FindMany(filter interface{}) (result *mongo.Cursor, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - result, err = collection.Find(ctx, filter, opts...) + result, err = collection.Find(context.TODO(), filter) return result, err } // FindManyByFilters 多条件查询 -func (c *MongoClient) FindManyByFilters(ctx context.Context, filter interface{}) (result *mongo.Cursor, err error) { +func (c *MongoClient) FindManyByFilters(filter interface{}) (result *mongo.Cursor, err error) { collection, err := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName).Clone() - result, err = collection.Find(ctx, bson.M{"$and": filter}) + result, err = collection.Find(context.TODO(), bson.M{"$and": filter}) return result, err } // FindManyByFiltersSort 多条件查询支持排序 -func (c *MongoClient) FindManyByFiltersSort(ctx context.Context, filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) { +func (c *MongoClient) FindManyByFiltersSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) { collection, err := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName).Clone() findOptions := options.Find() findOptions.SetSort(Sort) - result, err = collection.Find(ctx, filter, findOptions) + result, err = collection.Find(context.TODO(), filter, findOptions) return result, err } // FindCollection 查询集合文档 -func (c *MongoClient) FindCollection(ctx context.Context, Limit int64) (result *mongo.Cursor, err error) { +func (c *MongoClient) FindCollection(Limit int64) (result *mongo.Cursor, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) findOptions := options.Find() findOptions.SetLimit(Limit) - result, err = collection.Find(ctx, bson.D{{}}, findOptions) + result, err = collection.Find(context.TODO(), bson.D{{}}, findOptions) return result, err } // FindCollectionSort 查询集合文档支持排序 -func (c *MongoClient) FindCollectionSort(ctx context.Context, Sort interface{}, Limit int64) (result *mongo.Cursor, err error) { +func (c *MongoClient) FindCollectionSort(Sort interface{}, Limit int64) (result *mongo.Cursor, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) findOptions := options.Find() findOptions.SetSort(Sort) findOptions.SetLimit(Limit) - result, err = collection.Find(ctx, bson.D{{}}, findOptions) + result, err = collection.Find(context.TODO(), bson.D{{}}, findOptions) return result, err } // FindManyCollectionSort 查询集合文档支持排序支持条件 -func (c *MongoClient) FindManyCollectionSort(ctx context.Context, filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) { +func (c *MongoClient) FindManyCollectionSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) findOptions := options.Find() findOptions.SetSort(Sort) - result, err = collection.Find(ctx, filter, findOptions) + result, err = collection.Find(context.TODO(), filter, findOptions) return result, err } // CollectionCount 查询集合里有多少数据 -func (c *MongoClient) CollectionCount(ctx context.Context) (name string, size int64) { +func (c *MongoClient) CollectionCount() (name string, size int64) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) name = collection.Name() - size, _ = collection.EstimatedDocumentCount(ctx) + size, _ = collection.EstimatedDocumentCount(context.TODO()) return name, size } @@ -135,25 +135,25 @@ func (c *MongoClient) CollectionCount(ctx context.Context) (name string, size in // Skip 跳过 // Limit 读取数量 // sort 1 ,-1 . 1 为升序 , -1 为降序 -func (c *MongoClient) CollectionDocuments(ctx context.Context, Skip, Limit int64, sort int, key string, value interface{}) (result *mongo.Cursor, err error) { +func (c *MongoClient) CollectionDocuments(Skip, Limit int64, sort int, key string, value interface{}) (result *mongo.Cursor, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) SORT := bson.D{{"_id", sort}} filter := bson.D{{key, value}} findOptions := options.Find().SetSort(SORT).SetLimit(Limit).SetSkip(Skip) - result, err = collection.Find(ctx, filter, findOptions) + result, err = collection.Find(context.TODO(), filter, findOptions) return result, err } // AggregateByFiltersSort 统计分析 -func (c *MongoClient) AggregateByFiltersSort(ctx context.Context, pipeline interface{}, opts ...*options.AggregateOptions) (result *mongo.Cursor, err error) { +func (c *MongoClient) AggregateByFiltersSort(pipeline interface{}) (result *mongo.Cursor, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - result, err = collection.Aggregate(ctx, pipeline, opts...) + result, err = collection.Aggregate(context.TODO(), pipeline) return result, err } // CountDocumentsByFilters 统计数量 -func (c *MongoClient) CountDocumentsByFilters(ctx context.Context, filter interface{}) (count int64, err error) { +func (c *MongoClient) CountDocumentsByFilters(filter interface{}) (count int64, err error) { collection := c.Db.Database(c.getDatabaseName()).Collection(c.collectionName) - count, err = collection.CountDocuments(ctx, filter) + count, err = collection.CountDocuments(context.TODO(), filter) return count, err } diff --git a/utils/dorm/mongo_transaction.go b/utils/dorm/mongo_transaction.go index d59198bb..c5f9d185 100644 --- a/utils/dorm/mongo_transaction.go +++ b/utils/dorm/mongo_transaction.go @@ -2,19 +2,27 @@ package dorm import ( "context" + "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "reflect" ) type MongoTransaction struct { - startSession mongo.Session - Session mongo.SessionContext + startSession mongo.Session + Session mongo.SessionContext + db *mongo.Client // 驱动 + databaseName string // 库名 + collectionName string // 表名 } // Begin 开始事务,会同时创建开始会话需要在退出时关闭会话 func (c *MongoClient) Begin() (ms MongoTransaction, err error) { + ms.db = c.Db + // 开始会话 - ms.startSession, err = c.Db.StartSession() + ms.startSession, err = ms.db.StartSession() if err != nil { panic(err) } @@ -41,3 +49,187 @@ func (ms *MongoTransaction) Rollback() error { func (ms *MongoTransaction) Commit() error { return ms.startSession.CommitTransaction(context.Background()) } + +// Database 设置库名 +func (ms *MongoTransaction) Database(databaseName string) *MongoTransaction { + ms.databaseName = databaseName + return ms +} + +// Collection 设置表名 +func (ms *MongoTransaction) Collection(collectionName string) *MongoTransaction { + ms.collectionName = collectionName + return ms +} + +// Model 传入模型自动获取库名和表名 +func (ms *MongoTransaction) Model(value interface{}) *MongoTransaction { + // https://studygolang.com/articles/896 + val := reflect.ValueOf(value) + if methodValue := val.MethodByName("Database"); methodValue.IsValid() { + ms.databaseName = methodValue.Call(nil)[0].String() + } + if methodValue := val.MethodByName("TableName"); methodValue.IsValid() { + ms.collectionName = methodValue.Call(nil)[0].String() + } + return ms +} + +func (ms *MongoTransaction) getDatabaseName() string { + return ms.databaseName +} + +func (ms *MongoTransaction) getCollectionName() string { + return ms.collectionName +} + +// InsertOne 插入单个文档 +func (ms *MongoTransaction) InsertOne(document interface{}) (result *mongo.InsertOneResult, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result, err = collection.InsertOne(ms.Session, document) + return result, err +} + +// InsertMany 插入多个文档 +func (ms *MongoTransaction) InsertMany(documents []interface{}) (result *mongo.InsertManyResult, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result, err = collection.InsertMany(ms.Session, documents) + return result, err +} + +// Delete 删除文档 +func (ms *MongoTransaction) Delete(filter interface{}) (int64, error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + count, err := collection.DeleteOne(ms.Session, filter) + return count.DeletedCount, err +} + +// DeleteMany 删除多个文档 +func (ms *MongoTransaction) DeleteMany(key string, value interface{}) (int64, error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + filter := bson.D{{key, value}} + count, err := collection.DeleteMany(ms.Session, filter) + return count.DeletedCount, err +} + +// UpdateOne 更新单个文档 +// 修改字段的值($set) +// 字段增加值 inc($inc) +// 从数组中增加一个元素 push($push) +// 从数组中删除一个元素 pull($pull) +func (ms *MongoTransaction) UpdateOne(filter interface{}, update interface{}) (int64, error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result, err := collection.UpdateOne(ms.Session, filter, update) + return result.UpsertedCount, err +} + +// UpdateMany 更新多个文档 +// 修改字段的值($set) +// 字段增加值 inc($inc) +// 从数组中增加一个元素 push($push) +// 从数组中删除一个元素 pull($pull) +func (ms *MongoTransaction) UpdateMany(filter interface{}, update interface{}) (int64, error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result, err := collection.UpdateMany(ms.Session, filter, update) + return result.UpsertedCount, err +} + +// Find 查询 +func (ms *MongoTransaction) Find(filter interface{}) (result *mongo.Cursor, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result, err = collection.Find(ms.Session, filter) + return result, err +} + +// FindOne 查询单个文档 +func (ms *MongoTransaction) FindOne(filter interface{}) (result *mongo.SingleResult) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result = collection.FindOne(ms.Session, filter) + return result +} + +// FindMany 查询多个文档 +func (ms *MongoTransaction) FindMany(filter interface{}) (result *mongo.Cursor, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result, err = collection.Find(ms.Session, filter) + return result, err +} + +// FindManyByFilters 多条件查询 +func (ms *MongoTransaction) FindManyByFilters(filter interface{}) (result *mongo.Cursor, err error) { + collection, err := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName).Clone() + result, err = collection.Find(ms.Session, bson.M{"$and": filter}) + return result, err +} + +// FindManyByFiltersSort 多条件查询支持排序 +func (ms *MongoTransaction) FindManyByFiltersSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) { + collection, err := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName).Clone() + findOptions := options.Find() + findOptions.SetSort(Sort) + result, err = collection.Find(ms.Session, filter, findOptions) + return result, err +} + +// FindCollection 查询集合文档 +func (ms *MongoTransaction) FindCollection(Limit int64) (result *mongo.Cursor, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + findOptions := options.Find() + findOptions.SetLimit(Limit) + result, err = collection.Find(ms.Session, bson.D{{}}, findOptions) + return result, err +} + +// FindCollectionSort 查询集合文档支持排序 +func (ms *MongoTransaction) FindCollectionSort(Sort interface{}, Limit int64) (result *mongo.Cursor, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + findOptions := options.Find() + findOptions.SetSort(Sort) + findOptions.SetLimit(Limit) + result, err = collection.Find(ms.Session, bson.D{{}}, findOptions) + return result, err +} + +// FindManyCollectionSort 查询集合文档支持排序支持条件 +func (ms *MongoTransaction) FindManyCollectionSort(filter interface{}, Sort interface{}) (result *mongo.Cursor, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + findOptions := options.Find() + findOptions.SetSort(Sort) + result, err = collection.Find(ms.Session, filter, findOptions) + return result, err +} + +// CollectionCount 查询集合里有多少数据 +func (ms *MongoTransaction) CollectionCount(ctx context.Context) (name string, size int64) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + name = collection.Name() + size, _ = collection.EstimatedDocumentCount(ctx) + return name, size +} + +// CollectionDocuments 按选项查询集合 +// Skip 跳过 +// Limit 读取数量 +// sort 1 ,-1 . 1 为升序 , -1 为降序 +func (ms *MongoTransaction) CollectionDocuments(Skip, Limit int64, sort int, key string, value interface{}) (result *mongo.Cursor, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + SORT := bson.D{{"_id", sort}} + filter := bson.D{{key, value}} + findOptions := options.Find().SetSort(SORT).SetLimit(Limit).SetSkip(Skip) + result, err = collection.Find(ms.Session, filter, findOptions) + return result, err +} + +// AggregateByFiltersSort 统计分析 +func (ms *MongoTransaction) AggregateByFiltersSort(pipeline interface{}) (result *mongo.Cursor, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + result, err = collection.Aggregate(ms.Session, pipeline) + return result, err +} + +// CountDocumentsByFilters 统计数量 +func (ms *MongoTransaction) CountDocumentsByFilters(filter interface{}) (count int64, err error) { + collection := ms.db.Database(ms.getDatabaseName()).Collection(ms.collectionName) + count, err = collection.CountDocuments(ms.Session, filter) + return count, err +}