You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
281 lines
6.6 KiB
281 lines
6.6 KiB
package datatypes
|
|
|
|
import (
|
|
"context"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
"gorm.io/gorm/schema"
|
|
)
|
|
|
|
// JSON holds a raw JSON document as bytes. It implements driver.Valuer and
// sql.Scanner for database I/O, and json.Marshaler / json.Unmarshaler so the
// document is embedded verbatim when encoding/json processes it.
type JSON json.RawMessage

// Value implements driver.Valuer. A JSON with no bytes is written as SQL
// NULL; otherwise the document text is handed to the driver as a string.
func (j JSON) Value() (driver.Value, error) {
	if len(j) > 0 {
		return string(j), nil
	}
	return nil, nil
}
|
|
|
|
// Scan scan value into Jsonb, implements sql.Scanner interface
|
|
func (j *JSON) Scan(value interface{}) error {
|
|
if value == nil {
|
|
*j = JSON("null")
|
|
return nil
|
|
}
|
|
var bytes []byte
|
|
switch v := value.(type) {
|
|
case []byte:
|
|
if len(v) > 0 {
|
|
bytes = make([]byte, len(v))
|
|
copy(bytes, v)
|
|
}
|
|
case string:
|
|
bytes = []byte(v)
|
|
default:
|
|
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
|
|
}
|
|
|
|
result := json.RawMessage(bytes)
|
|
*j = JSON(result)
|
|
return nil
|
|
}
|
|
|
|
// MarshalJSON to output non base64 encoded []byte
|
|
func (j JSON) MarshalJSON() ([]byte, error) {
|
|
return json.RawMessage(j).MarshalJSON()
|
|
}
|
|
|
|
// UnmarshalJSON to deserialize []byte
|
|
func (j *JSON) UnmarshalJSON(b []byte) error {
|
|
result := json.RawMessage{}
|
|
err := result.UnmarshalJSON(b)
|
|
*j = JSON(result)
|
|
return err
|
|
}
|
|
|
|
func (j JSON) String() string {
|
|
return string(j)
|
|
}
|
|
|
|
// GormDataType gorm common data type
|
|
func (JSON) GormDataType() string {
|
|
return "json"
|
|
}
|
|
|
|
// GormDBDataType gorm db data type
|
|
func (JSON) GormDBDataType(db *gorm.DB, field *schema.Field) string {
|
|
switch db.Dialector.Name() {
|
|
case "sqlite":
|
|
return "JSON"
|
|
case "mysql":
|
|
return "JSON"
|
|
case "postgres":
|
|
return "JSONB"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (js JSON) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
|
if len(js) == 0 {
|
|
return gorm.Expr("NULL")
|
|
}
|
|
|
|
data, _ := js.MarshalJSON()
|
|
|
|
switch db.Dialector.Name() {
|
|
case "mysql":
|
|
if v, ok := db.Dialector.(*mysql.Dialector); ok && !strings.Contains(v.ServerVersion, "MariaDB") {
|
|
return gorm.Expr("CAST(? AS JSON)", string(data))
|
|
}
|
|
}
|
|
|
|
return gorm.Expr("?", string(data))
|
|
}
|
|
|
|
// JSONQueryExpression describes a JSON extraction or condition on one column.
// It implements clause.Expression through its Build method, so it can be used
// as a query condition. Configure it with Extract, HasKey or Equals.
type JSONQueryExpression struct {
	column      string      // column holding the JSON document
	keys        []string    // nested key path used by HasKey and Equals
	hasKeys     bool        // set by HasKey
	equals      bool        // set by Equals
	equalsValue interface{} // value compared against by Equals
	extract     bool        // set by Extract
	path        string      // JSON path used by Extract
}

// JSONQuery starts a JSON query on the given column. No operation is
// configured until Extract, HasKey or Equals is called.
func JSONQuery(column string) *JSONQueryExpression {
	q := new(JSONQueryExpression)
	q.column = column
	return q
}
|
|
|
|
// Extract extract json with path
|
|
func (jsonQuery *JSONQueryExpression) Extract(path string) *JSONQueryExpression {
|
|
jsonQuery.extract = true
|
|
jsonQuery.path = path
|
|
return jsonQuery
|
|
}
|
|
|
|
// HasKey returns clause.Expression
|
|
func (jsonQuery *JSONQueryExpression) HasKey(keys ...string) *JSONQueryExpression {
|
|
jsonQuery.keys = keys
|
|
jsonQuery.hasKeys = true
|
|
return jsonQuery
|
|
}
|
|
|
|
// Keys returns clause.Expression
|
|
func (jsonQuery *JSONQueryExpression) Equals(value interface{}, keys ...string) *JSONQueryExpression {
|
|
jsonQuery.keys = keys
|
|
jsonQuery.equals = true
|
|
jsonQuery.equalsValue = value
|
|
return jsonQuery
|
|
}
|
|
|
|
// Build implements clause.Expression
|
|
func (jsonQuery *JSONQueryExpression) Build(builder clause.Builder) {
|
|
if stmt, ok := builder.(*gorm.Statement); ok {
|
|
switch stmt.Dialector.Name() {
|
|
case "mysql", "sqlite":
|
|
switch {
|
|
case jsonQuery.extract:
|
|
builder.WriteString("JSON_EXTRACT(")
|
|
builder.WriteQuoted(jsonQuery.column)
|
|
builder.WriteByte(',')
|
|
builder.AddVar(stmt, jsonQuery.path)
|
|
builder.WriteString(")")
|
|
case jsonQuery.hasKeys:
|
|
if len(jsonQuery.keys) > 0 {
|
|
builder.WriteString("JSON_EXTRACT(")
|
|
builder.WriteQuoted(jsonQuery.column)
|
|
builder.WriteByte(',')
|
|
builder.AddVar(stmt, jsonQueryJoin(jsonQuery.keys))
|
|
builder.WriteString(") IS NOT NULL")
|
|
}
|
|
case jsonQuery.equals:
|
|
if len(jsonQuery.keys) > 0 {
|
|
builder.WriteString("JSON_EXTRACT(")
|
|
builder.WriteQuoted(jsonQuery.column)
|
|
builder.WriteByte(',')
|
|
builder.AddVar(stmt, jsonQueryJoin(jsonQuery.keys))
|
|
builder.WriteString(") = ")
|
|
if value, ok := jsonQuery.equalsValue.(bool); ok {
|
|
builder.WriteString(strconv.FormatBool(value))
|
|
} else {
|
|
stmt.AddVar(builder, jsonQuery.equalsValue)
|
|
}
|
|
}
|
|
}
|
|
case "postgres":
|
|
switch {
|
|
case jsonQuery.hasKeys:
|
|
if len(jsonQuery.keys) > 0 {
|
|
stmt.WriteQuoted(jsonQuery.column)
|
|
stmt.WriteString("::jsonb")
|
|
for _, key := range jsonQuery.keys[0 : len(jsonQuery.keys)-1] {
|
|
stmt.WriteString(" -> ")
|
|
stmt.AddVar(builder, key)
|
|
}
|
|
|
|
stmt.WriteString(" ? ")
|
|
stmt.AddVar(builder, jsonQuery.keys[len(jsonQuery.keys)-1])
|
|
}
|
|
case jsonQuery.equals:
|
|
if len(jsonQuery.keys) > 0 {
|
|
builder.WriteString(fmt.Sprintf("json_extract_path_text(%v::json,", stmt.Quote(jsonQuery.column)))
|
|
|
|
for idx, key := range jsonQuery.keys {
|
|
if idx > 0 {
|
|
builder.WriteByte(',')
|
|
}
|
|
stmt.AddVar(builder, key)
|
|
}
|
|
builder.WriteString(") = ")
|
|
|
|
if _, ok := jsonQuery.equalsValue.(string); ok {
|
|
stmt.AddVar(builder, jsonQuery.equalsValue)
|
|
} else {
|
|
stmt.AddVar(builder, fmt.Sprint(jsonQuery.equalsValue))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// JSONOverlapsExpression JSON_OVERLAPS expression, implements clause.Expression interface to use as querier
|
|
type JSONOverlapsExpression struct {
|
|
column clause.Expression
|
|
val string
|
|
}
|
|
|
|
// JSONOverlaps query column as json
|
|
func JSONOverlaps(column clause.Expression, value string) *JSONOverlapsExpression {
|
|
return &JSONOverlapsExpression{
|
|
column: column,
|
|
val: value,
|
|
}
|
|
}
|
|
|
|
// Build implements clause.Expression
|
|
// only mysql support JSON_OVERLAPS
|
|
func (json *JSONOverlapsExpression) Build(builder clause.Builder) {
|
|
if stmt, ok := builder.(*gorm.Statement); ok {
|
|
switch stmt.Dialector.Name() {
|
|
case "mysql":
|
|
builder.WriteString("JSON_OVERLAPS(")
|
|
json.column.Build(builder)
|
|
builder.WriteString(",")
|
|
builder.AddVar(stmt, json.val)
|
|
builder.WriteString(")")
|
|
}
|
|
}
|
|
}
|
|
|
|
// columnExpression is a column name used as a clause.Expression; its Build
// method renders it as a quoted identifier.
type columnExpression string

// Column wraps a column name so it can be passed where a clause.Expression
// is expected, for example as the column argument of JSONOverlaps.
func Column(col string) columnExpression {
	expr := columnExpression(col)
	return expr
}
|
|
|
|
func (col columnExpression) Build(builder clause.Builder) {
|
|
if stmt, ok := builder.(*gorm.Statement); ok {
|
|
switch stmt.Dialector.Name() {
|
|
case "mysql", "sqlite", "postgres":
|
|
builder.WriteString(stmt.Quote(string(col)))
|
|
}
|
|
}
|
|
}
|
|
|
|
// prefix is the root selector that every path built by jsonQueryJoin starts with.
const prefix = "$."

// jsonQueryJoin turns a nested key path into a JSON path by appending the keys
// to prefix, separated by ".". For example, ["a", "b"] becomes "$.a.b".
//
// An empty keys slice returns prefix alone instead of panicking; the previous
// hand-rolled builder indexed keys[0] unconditionally. The callers in this
// file only pass non-empty slices.
//
// NOTE(review): keys are inserted verbatim, with no quoting or escaping. A key
// containing "." or other path syntax therefore changes the path's meaning;
// confirm whether callers need escaping.
func jsonQueryJoin(keys []string) string {
	return prefix + strings.Join(keys, ".")
}
|