You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gomongo/curd.go

154 lines
4.2 KiB

2 years ago
package gomongo
import (
"context"
2 years ago
"errors"
2 years ago
"go.mongodb.org/mongo-driver/bson"
2 years ago
"log"
2 years ago
"reflect"
)
// Database 设置库名
2 years ago
func (c *Client) Database(database string) *Client {
c.DatabaseName = database
return c
2 years ago
}
// Collection 设置表名
2 years ago
func (c *Client) Collection(collection string) *Client {
c.collectionName = collection
return c
2 years ago
}
// Model 传入模型自动获取库名和表名
2 years ago
func (c *Client) Model(value interface{}) *Client {
2 years ago
// https://studygolang.com/articles/896
val := reflect.ValueOf(value)
if methodValue := val.MethodByName("Database"); methodValue.IsValid() {
2 years ago
c.DatabaseName = methodValue.Call(nil)[0].String()
2 years ago
}
if methodValue := val.MethodByName("TableName"); methodValue.IsValid() {
2 years ago
c.collectionName = methodValue.Call(nil)[0].String()
2 years ago
}
2 years ago
return c
2 years ago
}
2 years ago
// CreateResult carries the IDs reported for the documents written by Create.
// At most one of the two fields is populated, depending on whether Create
// performed a single-document or a multi-document insert.
type CreateResult struct {
	// InsertedID is the ID reported for a single-document insert.
	InsertedID interface{}
	// InsertedIDs are the IDs reported for a multi-document insert.
	InsertedIDs []interface{}
}
// Create 创建数据
func (c *Client) Create(values ...interface{}) (CreateResult, error) {
collection := c.db.Database(c.DatabaseName).Collection(c.collectionName)
const (
insertTypeOne = "one"
insertTypeMany = "many"
)
var (
insertType string
insertDataOne interface{}
insertDataMany []interface{}
)
for _, value := range values {
switch v := value.(type) {
case map[string]interface{}:
case []map[string]interface{}:
case map[string]string:
case []map[string]string:
default:
sliceValue := reflect.Indirect(reflect.ValueOf(value))
if sliceValue.Kind() == reflect.Slice {
insertType = insertTypeMany
size := sliceValue.Len()
for i := 0; i < size; i++ {
sv := sliceValue.Index(i) // 取出第i个元素
elemValue := sv.Interface() // 原始数据
insertDataMany = append(insertDataMany, elemValue) // 加入到数组中
}
} else {
insertType = insertTypeOne
insertDataOne = v
}
}
}
2 years ago
2 years ago
if insertType == insertTypeOne {
result, err := collection.InsertOne(context.TODO(), insertDataOne)
return CreateResult{InsertedID: result.InsertedID}, err
} else if insertType == insertTypeMany {
result, err := collection.InsertMany(context.TODO(), insertDataMany)
return CreateResult{InsertedIDs: result.InsertedIDs}, err
} else {
return CreateResult{}, errors.New("values is empty")
}
2 years ago
}
2 years ago
// queryFilter is one key/value query condition recorded by Where.
type queryFilter struct {
	Key   string      // field name
	Value interface{} // value the field is matched against
}
// Where 条件
func (c *Client) Where(key string, value interface{}) *Client {
log.Println("key", key)
log.Println("value", value)
c.filterArr = append(c.filterArr, queryFilter{key, value})
c.filter = bson.D{{key, value}}
return c
2 years ago
}
2 years ago
// QueryResult is returned by the query methods (First, Take, Last, Find).
type QueryResult struct {
	// RowsAffected is the number of records the query reports as found.
	RowsAffected int
	// Error is the error produced by the query, or nil on success.
	Error error
}
// First 获取第一条记录(主键升序)
func (c *Client) First() *QueryResult {
return &QueryResult{}
2 years ago
}
2 years ago
// Take 获取一条记录,没有指定排序字段
func (c *Client) Take(v interface{}) *QueryResult {
collection := c.db.Database(c.DatabaseName).Collection(c.collectionName)
//log.Printf("c.filterArr%s\n", c.filterArr)
//log.Printf("c.filterArr%v\n", c.filterArr)
//log.Printf("c.filterArr%+v\n", c.filterArr)
//log.Printf("c.filter%s\n", c.filter)
//log.Printf("c.filter%v\n", c.filter)
//log.Printf("c.filter%+v\n", c.filter)
err := collection.FindOne(context.TODO(), c.filter).Decode(v)
return &QueryResult{1, err}
2 years ago
}
2 years ago
// Last 获取最后一条记录(主键降序)
func (c *Client) Last() *QueryResult {
return &QueryResult{}
2 years ago
}
2 years ago
// Find 获取多条记录
func (c *Client) Find(v interface{}) *QueryResult {
collection := c.db.Database(c.DatabaseName).Collection(c.collectionName)
log.Printf("c.filterArr%s\n", c.filterArr)
log.Printf("c.filterArr%v\n", c.filterArr)
log.Printf("c.filterArr%+v\n", c.filterArr)
log.Printf("c.filter%s\n", c.filter)
log.Printf("c.filter%v\n", c.filter)
log.Printf("c.filter%+v\n", c.filter)
cursor, err := collection.Find(context.TODO(), c.filter)
if err != nil {
return &QueryResult{0, err}
}
2 years ago
2 years ago
// 结果遍历和赋值
err = cursor.All(context.TODO(), v)
2 years ago
2 years ago
return &QueryResult{cursor.RemainingBatchLength(), err}
2 years ago
}