You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
go-library/vendor/gorm.io/gen/internal/generate/section.go

677 lines
15 KiB

package generate
import (
"fmt"
"strconv"
"strings"
"gorm.io/gen/internal/model"
)
// NewSection returns an empty Section whose WHERE and SET clause counters
// are explicitly initialised to zero.
func NewSection() *Section {
	counters := make(map[model.Status]int, 2)
	counters[model.WHERE] = 0
	counters[model.SET] = 0
	return &Section{ClauseTotal: counters}
}
// Section holds a SQL template split into chunks together with the state of
// the parser that walks over those chunks.
type Section struct {
// members is the ordered list of chunks the parser iterates over.
members []section
// Tmpls collects the generated template lines, in the order they are
// appended by appendTmpl.
Tmpls []string
// currentIndex is the position in members of the chunk being processed.
currentIndex int
// ClauseTotal counts, per clause type, how many names GetName has handed
// out; the count is used as the numeric suffix of the next name.
ClauseTotal map[model.Status]int
// forValue is a stack of the for ranges currently being parsed: parseFor
// pushes on entry and pops when it reaches the matching end marker.
forValue []ForRange
}
// next advances the cursor by one and returns the section it now points at.
// When the cursor already sits on the last section it is left untouched and a
// synthetic END section is returned instead.
func (s *Section) next() section {
	if s.currentIndex >= len(s.members)-1 {
		return section{Type: model.END}
	}
	s.currentIndex++
	return s.members[s.currentIndex]
}
// SubIndex moves the cursor back by one section, so that a following call to
// next returns the section the cursor pointed at before SubIndex was called.
// No bounds check is performed.
func (s *Section) SubIndex() {
s.currentIndex--
}
// HasMore reports whether at least one section follows the current one.
func (s *Section) HasMore() bool {
	return s.currentIndex+1 < len(s.members)
}
// IsNull reports whether the Section contains no members at all.
func (s *Section) IsNull() bool {
	return len(s.members) < 1
}
// current returns a copy of the section under the cursor. It panics if the
// cursor is outside the bounds of members.
func (s *Section) current() section {
	idx := s.currentIndex
	return s.members[idx]
}
// appendTmpl appends one generated template line to Tmpls.
func (s *Section) appendTmpl(value string) {
s.Tmpls = append(s.Tmpls, value)
}
// hasSameName reports whether any FOR section already registered in members
// uses value as the name of its range value variable.
func (s *Section) hasSameName(value string) bool {
	for i := range s.members {
		member := &s.members[i]
		if member.Type != model.FOR {
			continue
		}
		if member.ForRange.value == value {
			return true
		}
	}
	return false
}
// BuildSQL walks the sections starting at the current cursor position, turns
// them into clauses and appends the matching generated lines to Tmpls.
//
// It returns the top-level clauses in the order they were encountered. An
// error is returned when the Section has no members, when one of the nested
// parsers fails, or when a section type that is not handled at the top level
// (for example an else) is found.
func (s *Section) BuildSQL() ([]Clause, error) {
	if s.IsNull() {
		return nil, fmt.Errorf("sql is null")
	}
	name := "generateSQL"
	res := make([]Clause, 0, len(s.members))
	for {
		c := s.current()
		switch c.Type {
		case model.SQL, model.DATA, model.VARIABLE:
			sqlClause := s.parseSQL(name)
			res = append(res, sqlClause)
			s.appendTmpl(sqlClause.Finish())
		case model.IF:
			ifClause, err := s.parseIF(name)
			if err != nil {
				return nil, err
			}
			res = append(res, ifClause)
			s.appendTmpl(ifClause.Finish())
		case model.WHERE:
			whereClause, err := s.parseWhere()
			if err != nil {
				return nil, err
			}
			res = append(res, whereClause)
			s.appendTmpl(whereClause.Finish(name))
		case model.SET:
			setClause, err := s.parseSet()
			if err != nil {
				return nil, err
			}
			res = append(res, setClause)
			s.appendTmpl(setClause.Finish(name))
		case model.TRIM:
			trimClause, err := s.parseTrim()
			if err != nil {
				return nil, err
			}
			res = append(res, trimClause)
			s.appendTmpl(trimClause.Finish(name))
		case model.FOR:
			forClause, err := s.parseFor(name)
			if err != nil {
				return nil, err
			}
			res = append(res, forClause)
			s.appendTmpl(forClause.Finish())
		case model.END:
			// A top-level end marker produces no clause of its own.
		default:
			return nil, fmt.Errorf("unknow clause:%s", c.Value)
		}
		if !s.HasMore() {
			return res, nil
		}
		// Only the cursor advance matters here: the next iteration re-reads
		// the section through current().
		s.next()
	}
}
// parseIF parses the if clause that starts at the current section.
//
// name is the builder variable name passed on to nested SQL, if, else and for
// parsers and used to finish nested where/set/trim clauses. The opening line
// of the clause is appended to Tmpls immediately; the closing line is left to
// the caller. Parsing stops successfully at the matching end marker.
//
// An error is returned when a nested parser fails, when a section type is not
// allowed inside an if, or when the sections run out before the end marker is
// found. An if that is the very last section is returned as is, without
// error.
func (s *Section) parseIF(name string) (res IfClause, err error) {
	c := s.current()
	res.slice = c
	s.appendTmpl(res.Create())
	if !s.HasMore() {
		return
	}
	c = s.next()
	for {
		switch c.Type {
		case model.SQL, model.DATA, model.VARIABLE:
			sqlClause := s.parseSQL(name)
			res.Value = append(res.Value, sqlClause)
			s.appendTmpl(sqlClause.Finish())
		case model.IF:
			var ifClause IfClause
			ifClause, err = s.parseIF(name)
			if err != nil {
				return
			}
			res.Value = append(res.Value, ifClause)
			s.appendTmpl(ifClause.Finish())
		case model.WHERE:
			var whereClause WhereClause
			whereClause, err = s.parseWhere()
			if err != nil {
				return
			}
			res.Value = append(res.Value, whereClause)
			s.appendTmpl(whereClause.Finish(name))
		case model.SET:
			var setClause SetClause
			setClause, err = s.parseSet()
			if err != nil {
				return
			}
			res.Value = append(res.Value, setClause)
			s.appendTmpl(setClause.Finish(name))
		case model.ELSE:
			// The else branch shares this clause's closing line, so nothing
			// is appended to Tmpls for it here.
			var elseClause ElseClause
			elseClause, err = s.parseElSE(name)
			if err != nil {
				return
			}
			res.Value = append(res.Value, elseClause)
		case model.FOR:
			var forClause ForClause
			forClause, err = s.parseFor(name)
			if err != nil {
				return
			}
			res.Value = append(res.Value, forClause)
			// Close the nested loop with its own terminator, as every other
			// parser does, rather than with the enclosing if's.
			s.appendTmpl(forClause.Finish())
		case model.TRIM:
			var trimClause TrimClause
			trimClause, err = s.parseTrim()
			if err != nil {
				return
			}
			res.Value = append(res.Value, trimClause)
			s.appendTmpl(trimClause.Finish(name))
		case model.END:
			return
		default:
			err = fmt.Errorf("unknow clause : %s", c.Value)
			return
		}
		if !s.HasMore() {
			// The end marker returns from inside the switch, so running out
			// of sections here means the if was never closed.
			err = fmt.Errorf("incomplete SQL,if not end")
			return
		}
		c = s.next()
	}
}
// parseElSE parses the else branch that starts at the current section.
//
// Nested SQL, if, where, set, else, for and trim sections are collected into
// the returned clause. Any other section type (including the end marker)
// terminates the branch: the cursor is stepped back by one so the caller sees
// that section again. Errors from nested parsers are returned unchanged.
func (s *Section) parseElSE(name string) (res ElseClause, err error) {
	res.slice = s.current()
	s.appendTmpl(res.Create())
	if !s.HasMore() {
		return res, nil
	}
	// The post statement only runs after the HasMore check at the bottom of
	// the body has passed, so next() never runs past the last section.
	for cur := s.next(); ; cur = s.next() {
		switch cur.Type {
		case model.SQL, model.DATA, model.VARIABLE:
			sqlPart := s.parseSQL(name)
			res.Value = append(res.Value, sqlPart)
			s.appendTmpl(sqlPart.Create())
		case model.IF:
			nested, perr := s.parseIF(name)
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish())
		case model.WHERE:
			nested, perr := s.parseWhere()
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish(name))
		case model.SET:
			nested, perr := s.parseSet()
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish(name))
		case model.ELSE:
			nested, perr := s.parseElSE(name)
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
		case model.FOR:
			nested, perr := s.parseFor(name)
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish())
		case model.TRIM:
			nested, perr := s.parseTrim()
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish(name))
		default:
			// Not part of the else branch: hand the section back.
			s.SubIndex()
			return res, nil
		}
		if !s.HasMore() {
			return res, nil
		}
	}
}
// parseWhere parses the where clause that starts at the current section.
//
// A fresh builder name is obtained from GetName and used for every nested
// SQL, if and for parser as well as for finishing nested trim clauses.
// Parsing stops successfully at the matching end marker. An error is returned
// when a nested parser fails, when a section type is not allowed inside a
// where, or when the sections run out before the end marker is found. A where
// that is the very last section is returned as is, without error.
func (s *Section) parseWhere() (res WhereClause, err error) {
	head := s.current()
	res.VarName = s.GetName(head.Type)
	s.appendTmpl(res.Create())
	res.Type = head.Type
	if !s.HasMore() {
		return res, nil
	}
	// The post statement only runs after the HasMore check at the bottom of
	// the body has passed, so next() never runs past the last section.
	for cur := s.next(); ; cur = s.next() {
		switch cur.Type {
		case model.SQL, model.DATA, model.VARIABLE:
			sqlPart := s.parseSQL(res.VarName)
			res.Value = append(res.Value, sqlPart)
			s.appendTmpl(sqlPart.Finish())
		case model.IF:
			nested, perr := s.parseIF(res.VarName)
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish())
		case model.FOR:
			nested, perr := s.parseFor(res.VarName)
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish())
		case model.TRIM:
			nested, perr := s.parseTrim()
			if perr != nil {
				return res, perr
			}
			res.Value = append(res.Value, nested)
			s.appendTmpl(nested.Finish(res.VarName))
		case model.END:
			return res, nil
		default:
			return res, fmt.Errorf("unknow clause : %s", cur.Value)
		}
		if !s.HasMore() {
			// The end marker returns above, so reaching the last section
			// here means the where was never closed.
			return res, fmt.Errorf("incomplete SQL,where not end")
		}
	}
}
// parseSet parses the set clause that starts at the current section.
//
// A fresh builder name is obtained from GetName and used for every nested
// SQL, if and for parser as well as for finishing nested trim clauses.
// Parsing stops successfully at the matching end marker. An error is returned
// when a nested parser fails, when a section type is not allowed inside a
// set, or when the sections run out before the end marker is found. A set
// that is the very last section is returned as is, without error.
func (s *Section) parseSet() (res SetClause, err error) {
	c := s.current()
	res.VarName = s.GetName(c.Type)
	s.appendTmpl(res.Create())
	// Record the clause's own type before the cursor moves on to its first
	// child, the same way parseWhere does.
	res.Type = c.Type
	if !s.HasMore() {
		return
	}
	c = s.next()
	for {
		switch c.Type {
		case model.SQL, model.DATA, model.VARIABLE:
			sqlClause := s.parseSQL(res.VarName)
			res.Value = append(res.Value, sqlClause)
			s.appendTmpl(sqlClause.Finish())
		case model.IF:
			var ifClause IfClause
			ifClause, err = s.parseIF(res.VarName)
			if err != nil {
				return
			}
			res.Value = append(res.Value, ifClause)
			s.appendTmpl(ifClause.Finish())
		case model.FOR:
			var forClause ForClause
			forClause, err = s.parseFor(res.VarName)
			if err != nil {
				return
			}
			res.Value = append(res.Value, forClause)
			s.appendTmpl(forClause.Finish())
		case model.TRIM:
			var trimClause TrimClause
			trimClause, err = s.parseTrim()
			if err != nil {
				return
			}
			res.Value = append(res.Value, trimClause)
			s.appendTmpl(trimClause.Finish(res.VarName))
		case model.END:
			return
		default:
			err = fmt.Errorf("unknow clause : %s", c.Value)
			return
		}
		if !s.HasMore() {
			// The end marker returns from inside the switch, so running out
			// of sections here means the set was never closed.
			err = fmt.Errorf("incomplete SQL,set not end")
			return
		}
		c = s.next()
	}
}
// parseTrim parses the trim clause that starts at the current section.
//
// A fresh builder name is obtained from GetName and used for every nested
// SQL, if and for parser. Parsing stops successfully at the matching end
// marker. An error is returned when a nested parser fails, when a section
// type is not allowed inside a trim, or when the sections run out before the
// end marker is found. A trim that is the very last section is returned as
// is, without error.
func (s *Section) parseTrim() (res TrimClause, err error) {
	c := s.current()
	res.VarName = s.GetName(c.Type)
	s.appendTmpl(res.Create())
	// Record the clause's own type before the cursor moves on to its first
	// child, the same way parseWhere does.
	res.Type = c.Type
	if !s.HasMore() {
		return
	}
	c = s.next()
	for {
		switch c.Type {
		case model.SQL, model.DATA, model.VARIABLE:
			sqlClause := s.parseSQL(res.VarName)
			res.Value = append(res.Value, sqlClause)
			s.appendTmpl(sqlClause.Finish())
		case model.IF:
			var ifClause IfClause
			ifClause, err = s.parseIF(res.VarName)
			if err != nil {
				return
			}
			res.Value = append(res.Value, ifClause)
			s.appendTmpl(ifClause.Finish())
		case model.FOR:
			var forClause ForClause
			forClause, err = s.parseFor(res.VarName)
			if err != nil {
				return
			}
			res.Value = append(res.Value, forClause)
			s.appendTmpl(forClause.Finish())
		case model.END:
			return
		default:
			err = fmt.Errorf("unknow clause : %s", c.Value)
			return
		}
		if !s.HasMore() {
			// The end marker returns from inside the switch, so running out
			// of sections here means the trim was never closed.
			err = fmt.Errorf("incomplete SQL,trim not end")
			return
		}
		c = s.next()
	}
}
// parseFor parses the for clause that starts at the current section.
//
// name is the builder variable name that nested SQL is written to and that is
// passed on to nested if and for parsers and used to finish nested trim
// clauses. The clause's range is pushed onto forValue on entry and popped
// again when the matching end marker is reached.
//
// An error is returned when a nested parser fails, when a section type is not
// allowed inside a for, or when the sections run out before the end marker is
// found. A for that is the very last section is returned as is, without
// error.
func (s *Section) parseFor(name string) (res ForClause, err error) {
	c := s.current()
	res.forSlice = c
	s.appendTmpl(res.Create())
	s.forValue = append(s.forValue, res.forSlice.ForRange)
	if !s.HasMore() {
		return
	}
	c = s.next()
	for {
		switch c.Type {
		case model.SQL, model.DATA, model.VARIABLE:
			strClause := s.parseSQL(name)
			res.Value = append(res.Value, strClause)
			s.appendTmpl(fmt.Sprintf("%s.WriteString(%s)", name, strClause.String()))
		case model.IF:
			var ifClause IfClause
			ifClause, err = s.parseIF(name)
			if err != nil {
				return
			}
			res.Value = append(res.Value, ifClause)
			s.appendTmpl(ifClause.Finish())
		case model.FOR:
			var forClause ForClause
			forClause, err = s.parseFor(name)
			if err != nil {
				return
			}
			res.Value = append(res.Value, forClause)
			s.appendTmpl(forClause.Finish())
		case model.TRIM:
			var trimClause TrimClause
			trimClause, err = s.parseTrim()
			if err != nil {
				return
			}
			res.Value = append(res.Value, trimClause)
			s.appendTmpl(trimClause.Finish(name))
		case model.END:
			// Leaving the loop: drop its range from the stack.
			s.forValue = s.forValue[:len(s.forValue)-1]
			return
		default:
			err = fmt.Errorf("unknow clause : %s", c.Value)
			return
		}
		if !s.HasMore() {
			// The end marker returns from inside the switch, so running out
			// of sections here means the for was never closed.
			err = fmt.Errorf("incomplete SQL,for not end")
			return
		}
		c = s.next()
	}
}
// parseSQL gathers the run of consecutive SQL, VARIABLE and DATA sections that
// starts at the current section into a single SQL clause named after name.
//
// SQL and VARIABLE values are appended verbatim. For a DATA section a line
// that appends the value to params is added to Tmpls and a quoted "?"
// placeholder is appended to the clause instead. The run ends at the last
// section or at the first section of another type; in the latter case the
// cursor is stepped back by one.
func (s *Section) parseSQL(name string) (res SQLClause) {
	res.VarName, res.Type = name, model.SQL
	for {
		switch cur := s.current(); cur.Type {
		case model.SQL, model.VARIABLE:
			res.Value = append(res.Value, cur.Value)
		case model.DATA:
			s.appendTmpl(fmt.Sprintf("params = append(params,%s)", cur.Value))
			res.Value = append(res.Value, "\"?\"")
		default:
			s.SubIndex()
			return res
		}
		if !s.HasMore() {
			return res
		}
		// Advance the cursor; the section is re-read at the top of the loop.
		s.next()
	}
}
// checkSQLVar converts a template parameter into a section.
//
// A VARIABLE named "table" becomes a SQL section holding the quoted table
// name of method. Any other VARIABLE is wrapped in a Quote call on method.S.
// A DATA parameter sets method.HasForParams. In all remaining cases the
// section carries status and param unchanged. The returned error is always
// nil.
func (s *Section) checkSQLVar(param string, status model.Status, method *InterfaceMethod) (result section, err error) {
	switch status {
	case model.VARIABLE:
		if param == "table" {
			return section{Type: model.SQL, Value: strconv.Quote(method.Table)}, nil
		}
		param = fmt.Sprintf("%s.Quote(%s)", method.S, param)
	case model.DATA:
		method.HasForParams = true
	}
	return section{Type: status, Value: param}, nil
}
// GetName returns the builder variable name for a clause of the given type.
//
// For WHERE, SET and TRIM the name is a fixed prefix followed by the current
// counter for that type in ClauseTotal, and the counter is then incremented.
// Every other type yields "generateSQL" and leaves the counters untouched.
func (s *Section) GetName(status model.Status) string {
	var prefix string
	switch status {
	case model.WHERE:
		prefix = "whereSQL"
	case model.SET:
		prefix = "setSQL"
	case model.TRIM:
		prefix = "trimSQL"
	default:
		return "generateSQL"
	}
	seq := s.ClauseTotal[status]
	s.ClauseTotal[status] = seq + 1
	return fmt.Sprintf("%s%d", prefix, seq)
}
// checkTemplate builds a section from the template text tmpl, links it to s,
// splits it into fields and validates it. The section is returned together
// with the validation error, if any.
func (s *Section) checkTemplate(tmpl string) (part section, err error) {
	part = section{Value: tmpl, SQLSlice: s}
	part.splitTemplate()
	// Validate first, then return: checkTemplate mutates part.
	err = part.checkTemplate()
	return part, err
}
// section is one chunk of a SQL template: either a piece of SQL/data/variable
// text or a template keyword such as if, for or end.
type section struct {
// Type is the kind of chunk.
Type model.Status
// Value is the chunk's text; for template chunks it is the template
// string handed to checkTemplate.
Value string
// ForRange holds the loop variables of a FOR chunk; it is filled in by
// checkTemplate.
ForRange ForRange
// SQLSlice points at the owning Section; checkTemplate uses it to detect
// for loops that reuse a value name.
SQLSlice *Section
// splitList is Value split into fields by splitTemplate.
splitList []string
}
// isEnd reports whether the section is an end marker.
func (s *section) isEnd() bool {
	switch s.Type {
	case model.END:
		return true
	default:
		return false
	}
}
// String returns the for-range header for FOR sections and the raw Value for
// every other section type.
func (s *section) String() string {
	if s.Type != model.FOR {
		return s.Value
	}
	return s.ForRange.String()
}
// splitTemplate trims surrounding white space from Value and splits it into
// fields on ':', ' ', '=' and ','. The result is stored in splitList.
func (s *section) splitTemplate() {
	isSeparator := func(r rune) bool {
		switch r {
		case ':', ' ', '=', ',':
			return true
		}
		return false
	}
	trimmed := strings.TrimSpace(s.Value)
	s.splitList = strings.FieldsFunc(trimmed, isSeparator)
}
// checkTemplate validates a template section that has already been split by
// splitTemplate and sets its Type from the first field.
//
// It fails when the template has no fields, when Value is a gen keyword, or
// when the first field is not a known template keyword. A for template must
// consist of exactly five fields and must not reuse the value name of another
// for loop in the owning Section; on success its index, value and range list
// are taken from fields 1, 2 and 4.
func (s *section) checkTemplate() error {
	if len(s.splitList) == 0 {
		return fmt.Errorf("template is null")
	}
	if model.GenKeywords.Contain(s.Value) {
		return fmt.Errorf("template can not use gen keywords")
	}
	if err := s.sectionType(s.splitList[0]); err != nil {
		return err
	}
	if s.Type != model.FOR {
		return nil
	}

	fields := s.splitList
	if len(fields) != 5 {
		return fmt.Errorf("for range syntax error: %s", s.Value)
	}
	if s.SQLSlice.hasSameName(fields[2]) {
		return fmt.Errorf("cannot use the same value name in different for loops")
	}
	s.ForRange.index = fields[1]
	s.ForRange.value = fields[2]
	s.ForRange.rangeList = fields[4]
	return nil
}
// sectionType sets Type according to the template keyword str. It returns an
// error, and leaves Type untouched, when str is not a known keyword.
func (s *section) sectionType(str string) error {
	keywords := map[string]model.Status{
		"if":    model.IF,
		"else":  model.ELSE,
		"for":   model.FOR,
		"where": model.WHERE,
		"set":   model.SET,
		"end":   model.END,
		"trim":  model.TRIM,
	}
	kind, known := keywords[str]
	if !known {
		return fmt.Errorf("unknown syntax: %s", str)
	}
	s.Type = kind
	return nil
}
// SQLParamName returns Value with every "." removed.
func (s *section) SQLParamName() string {
	return strings.ReplaceAll(s.Value, ".", "")
}
// ForRange describes the header of a for-range loop used in a custom (DIY)
// method template.
type ForRange struct {
// index is the name of the loop's index variable.
index string
// value is the name of the loop's value variable.
value string
// suffix is not read or written anywhere in the visible code.
suffix string
// rangeList is the expression being ranged over.
rangeList string
}
// String renders the loop header as Go source, without the opening brace.
func (f *ForRange) String() string {
	return "for " + f.index + ", " + f.value + " := range " + f.rangeList
}