743 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			743 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package gorm
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"database/sql"
 | 
						|
	"database/sql/driver"
 | 
						|
	"fmt"
 | 
						|
	"reflect"
 | 
						|
	"regexp"
 | 
						|
	"sort"
 | 
						|
	"strconv"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
 | 
						|
	"gorm.io/gorm/clause"
 | 
						|
	"gorm.io/gorm/logger"
 | 
						|
	"gorm.io/gorm/schema"
 | 
						|
	"gorm.io/gorm/utils"
 | 
						|
)
 | 
						|
 | 
						|
// Statement statement
 | 
						|
type Statement struct {
 | 
						|
	*DB
 | 
						|
	TableExpr            *clause.Expr
 | 
						|
	Table                string
 | 
						|
	Model                interface{}
 | 
						|
	Unscoped             bool
 | 
						|
	Dest                 interface{}
 | 
						|
	ReflectValue         reflect.Value
 | 
						|
	Clauses              map[string]clause.Clause
 | 
						|
	BuildClauses         []string
 | 
						|
	Distinct             bool
 | 
						|
	Selects              []string // selected columns
 | 
						|
	Omits                []string // omit columns
 | 
						|
	Joins                []join
 | 
						|
	Preloads             map[string][]interface{}
 | 
						|
	Settings             sync.Map
 | 
						|
	ConnPool             ConnPool
 | 
						|
	Schema               *schema.Schema
 | 
						|
	Context              context.Context
 | 
						|
	RaiseErrorOnNotFound bool
 | 
						|
	SkipHooks            bool
 | 
						|
	SQL                  strings.Builder
 | 
						|
	Vars                 []interface{}
 | 
						|
	CurDestIndex         int
 | 
						|
	attrs                []interface{}
 | 
						|
	assigns              []interface{}
 | 
						|
	scopes               []func(*DB) *DB
 | 
						|
}
 | 
						|
 | 
						|
type join struct {
 | 
						|
	Name     string
 | 
						|
	Conds    []interface{}
 | 
						|
	On       *clause.Where
 | 
						|
	Selects  []string
 | 
						|
	Omits    []string
 | 
						|
	JoinType clause.JoinType
 | 
						|
}
 | 
						|
 | 
						|
// StatementModifier statement modifier interface
 | 
						|
type StatementModifier interface {
 | 
						|
	ModifyStatement(*Statement)
 | 
						|
}
 | 
						|
 | 
						|
// WriteString write string
 | 
						|
func (stmt *Statement) WriteString(str string) (int, error) {
 | 
						|
	return stmt.SQL.WriteString(str)
 | 
						|
}
 | 
						|
 | 
						|
// WriteByte write byte
 | 
						|
func (stmt *Statement) WriteByte(c byte) error {
 | 
						|
	return stmt.SQL.WriteByte(c)
 | 
						|
}
 | 
						|
 | 
						|
// WriteQuoted write quoted value
 | 
						|
func (stmt *Statement) WriteQuoted(value interface{}) {
 | 
						|
	stmt.QuoteTo(&stmt.SQL, value)
 | 
						|
}
 | 
						|
 | 
						|
// QuoteTo write quoted value to writer
 | 
						|
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
 | 
						|
	write := func(raw bool, str string) {
 | 
						|
		if raw {
 | 
						|
			writer.WriteString(str)
 | 
						|
		} else {
 | 
						|
			stmt.DB.Dialector.QuoteTo(writer, str)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	switch v := field.(type) {
 | 
						|
	case clause.Table:
 | 
						|
		if v.Name == clause.CurrentTable {
 | 
						|
			if stmt.TableExpr != nil {
 | 
						|
				stmt.TableExpr.Build(stmt)
 | 
						|
			} else {
 | 
						|
				write(v.Raw, stmt.Table)
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			write(v.Raw, v.Name)
 | 
						|
		}
 | 
						|
 | 
						|
		if v.Alias != "" {
 | 
						|
			writer.WriteByte(' ')
 | 
						|
			write(v.Raw, v.Alias)
 | 
						|
		}
 | 
						|
	case clause.Column:
 | 
						|
		if v.Table != "" {
 | 
						|
			if v.Table == clause.CurrentTable {
 | 
						|
				write(v.Raw, stmt.Table)
 | 
						|
			} else {
 | 
						|
				write(v.Raw, v.Table)
 | 
						|
			}
 | 
						|
			writer.WriteByte('.')
 | 
						|
		}
 | 
						|
 | 
						|
		if v.Name == clause.PrimaryKey {
 | 
						|
			if stmt.Schema == nil {
 | 
						|
				stmt.DB.AddError(ErrModelValueRequired)
 | 
						|
			} else if stmt.Schema.PrioritizedPrimaryField != nil {
 | 
						|
				write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
 | 
						|
			} else if len(stmt.Schema.DBNames) > 0 {
 | 
						|
				write(v.Raw, stmt.Schema.DBNames[0])
 | 
						|
			} else {
 | 
						|
				stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			write(v.Raw, v.Name)
 | 
						|
		}
 | 
						|
 | 
						|
		if v.Alias != "" {
 | 
						|
			writer.WriteString(" AS ")
 | 
						|
			write(v.Raw, v.Alias)
 | 
						|
		}
 | 
						|
	case []clause.Column:
 | 
						|
		writer.WriteByte('(')
 | 
						|
		for idx, d := range v {
 | 
						|
			if idx > 0 {
 | 
						|
				writer.WriteByte(',')
 | 
						|
			}
 | 
						|
			stmt.QuoteTo(writer, d)
 | 
						|
		}
 | 
						|
		writer.WriteByte(')')
 | 
						|
	case clause.Expr:
 | 
						|
		v.Build(stmt)
 | 
						|
	case string:
 | 
						|
		stmt.DB.Dialector.QuoteTo(writer, v)
 | 
						|
	case []string:
 | 
						|
		writer.WriteByte('(')
 | 
						|
		for idx, d := range v {
 | 
						|
			if idx > 0 {
 | 
						|
				writer.WriteByte(',')
 | 
						|
			}
 | 
						|
			stmt.DB.Dialector.QuoteTo(writer, d)
 | 
						|
		}
 | 
						|
		writer.WriteByte(')')
 | 
						|
	default:
 | 
						|
		stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Quote returns quoted value
 | 
						|
func (stmt *Statement) Quote(field interface{}) string {
 | 
						|
	var builder strings.Builder
 | 
						|
	stmt.QuoteTo(&builder, field)
 | 
						|
	return builder.String()
 | 
						|
}
 | 
						|
 | 
						|
// AddVar add var
 | 
						|
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
 | 
						|
	for idx, v := range vars {
 | 
						|
		if idx > 0 {
 | 
						|
			writer.WriteByte(',')
 | 
						|
		}
 | 
						|
 | 
						|
		switch v := v.(type) {
 | 
						|
		case sql.NamedArg:
 | 
						|
			stmt.Vars = append(stmt.Vars, v.Value)
 | 
						|
		case clause.Column, clause.Table:
 | 
						|
			stmt.QuoteTo(writer, v)
 | 
						|
		case Valuer:
 | 
						|
			reflectValue := reflect.ValueOf(v)
 | 
						|
			if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
 | 
						|
				stmt.AddVar(writer, nil)
 | 
						|
			} else {
 | 
						|
				stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
 | 
						|
			}
 | 
						|
		case clause.Interface:
 | 
						|
			c := clause.Clause{Name: v.Name()}
 | 
						|
			v.MergeClause(&c)
 | 
						|
			c.Build(stmt)
 | 
						|
		case clause.Expression:
 | 
						|
			v.Build(stmt)
 | 
						|
		case driver.Valuer:
 | 
						|
			stmt.Vars = append(stmt.Vars, v)
 | 
						|
			stmt.DB.Dialector.BindVarTo(writer, stmt, v)
 | 
						|
		case []byte:
 | 
						|
			stmt.Vars = append(stmt.Vars, v)
 | 
						|
			stmt.DB.Dialector.BindVarTo(writer, stmt, v)
 | 
						|
		case []interface{}:
 | 
						|
			if len(v) > 0 {
 | 
						|
				writer.WriteByte('(')
 | 
						|
				stmt.AddVar(writer, v...)
 | 
						|
				writer.WriteByte(')')
 | 
						|
			} else {
 | 
						|
				writer.WriteString("(NULL)")
 | 
						|
			}
 | 
						|
		case *DB:
 | 
						|
			subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
 | 
						|
			if v.Statement.SQL.Len() > 0 {
 | 
						|
				var (
 | 
						|
					vars = subdb.Statement.Vars
 | 
						|
					sql  = v.Statement.SQL.String()
 | 
						|
				)
 | 
						|
 | 
						|
				subdb.Statement.Vars = make([]interface{}, 0, len(vars))
 | 
						|
				for _, vv := range vars {
 | 
						|
					subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
 | 
						|
					bindvar := strings.Builder{}
 | 
						|
					v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
 | 
						|
					sql = strings.Replace(sql, bindvar.String(), "?", 1)
 | 
						|
				}
 | 
						|
 | 
						|
				subdb.Statement.SQL.Reset()
 | 
						|
				subdb.Statement.Vars = stmt.Vars
 | 
						|
				if strings.Contains(sql, "@") {
 | 
						|
					clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
 | 
						|
				} else {
 | 
						|
					clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
 | 
						|
				}
 | 
						|
			} else {
 | 
						|
				subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
 | 
						|
				subdb.callbacks.Query().Execute(subdb)
 | 
						|
			}
 | 
						|
 | 
						|
			writer.WriteString(subdb.Statement.SQL.String())
 | 
						|
			stmt.Vars = subdb.Statement.Vars
 | 
						|
		default:
 | 
						|
			switch rv := reflect.ValueOf(v); rv.Kind() {
 | 
						|
			case reflect.Slice, reflect.Array:
 | 
						|
				if rv.Len() == 0 {
 | 
						|
					writer.WriteString("(NULL)")
 | 
						|
				} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
 | 
						|
					stmt.Vars = append(stmt.Vars, v)
 | 
						|
					stmt.DB.Dialector.BindVarTo(writer, stmt, v)
 | 
						|
				} else {
 | 
						|
					writer.WriteByte('(')
 | 
						|
					for i := 0; i < rv.Len(); i++ {
 | 
						|
						if i > 0 {
 | 
						|
							writer.WriteByte(',')
 | 
						|
						}
 | 
						|
						stmt.AddVar(writer, rv.Index(i).Interface())
 | 
						|
					}
 | 
						|
					writer.WriteByte(')')
 | 
						|
				}
 | 
						|
			default:
 | 
						|
				stmt.Vars = append(stmt.Vars, v)
 | 
						|
				stmt.DB.Dialector.BindVarTo(writer, stmt, v)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// AddClause add clause
 | 
						|
func (stmt *Statement) AddClause(v clause.Interface) {
 | 
						|
	if optimizer, ok := v.(StatementModifier); ok {
 | 
						|
		optimizer.ModifyStatement(stmt)
 | 
						|
	} else {
 | 
						|
		name := v.Name()
 | 
						|
		c := stmt.Clauses[name]
 | 
						|
		c.Name = name
 | 
						|
		v.MergeClause(&c)
 | 
						|
		stmt.Clauses[name] = c
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// AddClauseIfNotExists add clause if not exists
 | 
						|
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
 | 
						|
	if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
 | 
						|
		stmt.AddClause(v)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// BuildCondition build condition
 | 
						|
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
 | 
						|
	if s, ok := query.(string); ok {
 | 
						|
		// if it is a number, then treats it as primary key
 | 
						|
		if _, err := strconv.Atoi(s); err != nil {
 | 
						|
			if s == "" && len(args) == 0 {
 | 
						|
				return nil
 | 
						|
			}
 | 
						|
 | 
						|
			if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
 | 
						|
				// looks like a where condition
 | 
						|
				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
 | 
						|
			}
 | 
						|
 | 
						|
			if len(args) > 0 && strings.Contains(s, "@") {
 | 
						|
				// looks like a named query
 | 
						|
				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
 | 
						|
			}
 | 
						|
 | 
						|
			if strings.Contains(strings.TrimSpace(s), " ") {
 | 
						|
				// looks like a where condition
 | 
						|
				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
 | 
						|
			}
 | 
						|
 | 
						|
			if len(args) == 1 {
 | 
						|
				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	conds := make([]clause.Expression, 0, 4)
 | 
						|
	args = append([]interface{}{query}, args...)
 | 
						|
	for idx, arg := range args {
 | 
						|
		if arg == nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
		if valuer, ok := arg.(driver.Valuer); ok {
 | 
						|
			arg, _ = valuer.Value()
 | 
						|
		}
 | 
						|
 | 
						|
		switch v := arg.(type) {
 | 
						|
		case clause.Expression:
 | 
						|
			conds = append(conds, v)
 | 
						|
		case *DB:
 | 
						|
			v.executeScopes()
 | 
						|
 | 
						|
			if cs, ok := v.Statement.Clauses["WHERE"]; ok {
 | 
						|
				if where, ok := cs.Expression.(clause.Where); ok {
 | 
						|
					if len(where.Exprs) == 1 {
 | 
						|
						if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
 | 
						|
							where.Exprs[0] = clause.AndConditions(orConds)
 | 
						|
						}
 | 
						|
					}
 | 
						|
					conds = append(conds, clause.And(where.Exprs...))
 | 
						|
				} else if cs.Expression != nil {
 | 
						|
					conds = append(conds, cs.Expression)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		case map[interface{}]interface{}:
 | 
						|
			for i, j := range v {
 | 
						|
				conds = append(conds, clause.Eq{Column: i, Value: j})
 | 
						|
			}
 | 
						|
		case map[string]string:
 | 
						|
			keys := make([]string, 0, len(v))
 | 
						|
			for i := range v {
 | 
						|
				keys = append(keys, i)
 | 
						|
			}
 | 
						|
			sort.Strings(keys)
 | 
						|
 | 
						|
			for _, key := range keys {
 | 
						|
				conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
						|
			}
 | 
						|
		case map[string]interface{}:
 | 
						|
			keys := make([]string, 0, len(v))
 | 
						|
			for i := range v {
 | 
						|
				keys = append(keys, i)
 | 
						|
			}
 | 
						|
			sort.Strings(keys)
 | 
						|
 | 
						|
			for _, key := range keys {
 | 
						|
				reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
 | 
						|
				switch reflectValue.Kind() {
 | 
						|
				case reflect.Slice, reflect.Array:
 | 
						|
					if _, ok := v[key].(driver.Valuer); ok {
 | 
						|
						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
						|
					} else if _, ok := v[key].(Valuer); ok {
 | 
						|
						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
						|
					} else {
 | 
						|
						// optimize reflect value length
 | 
						|
						valueLen := reflectValue.Len()
 | 
						|
						values := make([]interface{}, valueLen)
 | 
						|
						for i := 0; i < valueLen; i++ {
 | 
						|
							values[i] = reflectValue.Index(i).Interface()
 | 
						|
						}
 | 
						|
 | 
						|
						conds = append(conds, clause.IN{Column: key, Values: values})
 | 
						|
					}
 | 
						|
				default:
 | 
						|
					conds = append(conds, clause.Eq{Column: key, Value: v[key]})
 | 
						|
				}
 | 
						|
			}
 | 
						|
		default:
 | 
						|
			reflectValue := reflect.Indirect(reflect.ValueOf(arg))
 | 
						|
			for reflectValue.Kind() == reflect.Ptr {
 | 
						|
				reflectValue = reflectValue.Elem()
 | 
						|
			}
 | 
						|
 | 
						|
			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
 | 
						|
				selectedColumns := map[string]bool{}
 | 
						|
				if idx == 0 {
 | 
						|
					for _, v := range args[1:] {
 | 
						|
						if vs, ok := v.(string); ok {
 | 
						|
							selectedColumns[vs] = true
 | 
						|
						}
 | 
						|
					}
 | 
						|
				}
 | 
						|
				restricted := len(selectedColumns) != 0
 | 
						|
 | 
						|
				switch reflectValue.Kind() {
 | 
						|
				case reflect.Struct:
 | 
						|
					for _, field := range s.Fields {
 | 
						|
						selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
 | 
						|
						if selected || (!restricted && field.Readable) {
 | 
						|
							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
 | 
						|
								if field.DBName != "" {
 | 
						|
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
 | 
						|
								} else if field.DataType != "" {
 | 
						|
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
 | 
						|
								}
 | 
						|
							}
 | 
						|
						}
 | 
						|
					}
 | 
						|
				case reflect.Slice, reflect.Array:
 | 
						|
					for i := 0; i < reflectValue.Len(); i++ {
 | 
						|
						for _, field := range s.Fields {
 | 
						|
							selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
 | 
						|
							if selected || (!restricted && field.Readable) {
 | 
						|
								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
 | 
						|
									if field.DBName != "" {
 | 
						|
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
 | 
						|
									} else if field.DataType != "" {
 | 
						|
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
 | 
						|
									}
 | 
						|
								}
 | 
						|
							}
 | 
						|
						}
 | 
						|
					}
 | 
						|
				}
 | 
						|
 | 
						|
				if restricted {
 | 
						|
					break
 | 
						|
				}
 | 
						|
			} else if !reflectValue.IsValid() {
 | 
						|
				stmt.AddError(ErrInvalidData)
 | 
						|
			} else if len(conds) == 0 {
 | 
						|
				if len(args) == 1 {
 | 
						|
					switch reflectValue.Kind() {
 | 
						|
					case reflect.Slice, reflect.Array:
 | 
						|
						// optimize reflect value length
 | 
						|
						valueLen := reflectValue.Len()
 | 
						|
						values := make([]interface{}, valueLen)
 | 
						|
						for i := 0; i < valueLen; i++ {
 | 
						|
							values[i] = reflectValue.Index(i).Interface()
 | 
						|
						}
 | 
						|
 | 
						|
						if len(values) > 0 {
 | 
						|
							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
 | 
						|
							return []clause.Expression{clause.And(conds...)}
 | 
						|
						}
 | 
						|
						return nil
 | 
						|
					}
 | 
						|
				}
 | 
						|
 | 
						|
				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if len(conds) > 0 {
 | 
						|
		return []clause.Expression{clause.And(conds...)}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Build build sql with clauses names
 | 
						|
func (stmt *Statement) Build(clauses ...string) {
 | 
						|
	var firstClauseWritten bool
 | 
						|
 | 
						|
	for _, name := range clauses {
 | 
						|
		if c, ok := stmt.Clauses[name]; ok {
 | 
						|
			if firstClauseWritten {
 | 
						|
				stmt.WriteByte(' ')
 | 
						|
			}
 | 
						|
 | 
						|
			firstClauseWritten = true
 | 
						|
			if b, ok := stmt.DB.ClauseBuilders[name]; ok {
 | 
						|
				b(c, stmt)
 | 
						|
			} else {
 | 
						|
				c.Build(stmt)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (stmt *Statement) Parse(value interface{}) (err error) {
 | 
						|
	return stmt.ParseWithSpecialTableName(value, "")
 | 
						|
}
 | 
						|
 | 
						|
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
 | 
						|
	if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
 | 
						|
		if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
 | 
						|
			stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
 | 
						|
			stmt.Table = tables[1]
 | 
						|
			return
 | 
						|
		}
 | 
						|
 | 
						|
		stmt.Table = stmt.Schema.Table
 | 
						|
	}
 | 
						|
	return err
 | 
						|
}
 | 
						|
 | 
						|
func (stmt *Statement) clone() *Statement {
 | 
						|
	newStmt := &Statement{
 | 
						|
		TableExpr:            stmt.TableExpr,
 | 
						|
		Table:                stmt.Table,
 | 
						|
		Model:                stmt.Model,
 | 
						|
		Unscoped:             stmt.Unscoped,
 | 
						|
		Dest:                 stmt.Dest,
 | 
						|
		ReflectValue:         stmt.ReflectValue,
 | 
						|
		Clauses:              map[string]clause.Clause{},
 | 
						|
		Distinct:             stmt.Distinct,
 | 
						|
		Selects:              stmt.Selects,
 | 
						|
		Omits:                stmt.Omits,
 | 
						|
		Preloads:             map[string][]interface{}{},
 | 
						|
		ConnPool:             stmt.ConnPool,
 | 
						|
		Schema:               stmt.Schema,
 | 
						|
		Context:              stmt.Context,
 | 
						|
		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
 | 
						|
		SkipHooks:            stmt.SkipHooks,
 | 
						|
	}
 | 
						|
 | 
						|
	if stmt.SQL.Len() > 0 {
 | 
						|
		newStmt.SQL.WriteString(stmt.SQL.String())
 | 
						|
		newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
 | 
						|
		newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
 | 
						|
	}
 | 
						|
 | 
						|
	for k, c := range stmt.Clauses {
 | 
						|
		newStmt.Clauses[k] = c
 | 
						|
	}
 | 
						|
 | 
						|
	for k, p := range stmt.Preloads {
 | 
						|
		newStmt.Preloads[k] = p
 | 
						|
	}
 | 
						|
 | 
						|
	if len(stmt.Joins) > 0 {
 | 
						|
		newStmt.Joins = make([]join, len(stmt.Joins))
 | 
						|
		copy(newStmt.Joins, stmt.Joins)
 | 
						|
	}
 | 
						|
 | 
						|
	if len(stmt.scopes) > 0 {
 | 
						|
		newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
 | 
						|
		copy(newStmt.scopes, stmt.scopes)
 | 
						|
	}
 | 
						|
 | 
						|
	stmt.Settings.Range(func(k, v interface{}) bool {
 | 
						|
		newStmt.Settings.Store(k, v)
 | 
						|
		return true
 | 
						|
	})
 | 
						|
 | 
						|
	return newStmt
 | 
						|
}
 | 
						|
 | 
						|
// SetColumn set column's value
 | 
						|
//
 | 
						|
//	stmt.SetColumn("Name", "jinzhu") // Hooks Method
 | 
						|
//	stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
 | 
						|
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
 | 
						|
	if v, ok := stmt.Dest.(map[string]interface{}); ok {
 | 
						|
		v[name] = value
 | 
						|
	} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
 | 
						|
		for _, m := range v {
 | 
						|
			m[name] = value
 | 
						|
		}
 | 
						|
	} else if stmt.Schema != nil {
 | 
						|
		if field := stmt.Schema.LookUpField(name); field != nil {
 | 
						|
			destValue := reflect.ValueOf(stmt.Dest)
 | 
						|
			for destValue.Kind() == reflect.Ptr {
 | 
						|
				destValue = destValue.Elem()
 | 
						|
			}
 | 
						|
 | 
						|
			if stmt.ReflectValue != destValue {
 | 
						|
				if !destValue.CanAddr() {
 | 
						|
					destValueCanAddr := reflect.New(destValue.Type())
 | 
						|
					destValueCanAddr.Elem().Set(destValue)
 | 
						|
					stmt.Dest = destValueCanAddr.Interface()
 | 
						|
					destValue = destValueCanAddr.Elem()
 | 
						|
				}
 | 
						|
 | 
						|
				switch destValue.Kind() {
 | 
						|
				case reflect.Struct:
 | 
						|
					stmt.AddError(field.Set(stmt.Context, destValue, value))
 | 
						|
				default:
 | 
						|
					stmt.AddError(ErrInvalidData)
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			switch stmt.ReflectValue.Kind() {
 | 
						|
			case reflect.Slice, reflect.Array:
 | 
						|
				if len(fromCallbacks) > 0 {
 | 
						|
					for i := 0; i < stmt.ReflectValue.Len(); i++ {
 | 
						|
						stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
 | 
						|
					}
 | 
						|
				} else {
 | 
						|
					stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
 | 
						|
				}
 | 
						|
			case reflect.Struct:
 | 
						|
				if !stmt.ReflectValue.CanAddr() {
 | 
						|
					stmt.AddError(ErrInvalidValue)
 | 
						|
					return
 | 
						|
				}
 | 
						|
 | 
						|
				stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			stmt.AddError(ErrInvalidField)
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		stmt.AddError(ErrInvalidField)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// Changed check model changed or not when updating
 | 
						|
func (stmt *Statement) Changed(fields ...string) bool {
 | 
						|
	modelValue := stmt.ReflectValue
 | 
						|
	switch modelValue.Kind() {
 | 
						|
	case reflect.Slice, reflect.Array:
 | 
						|
		modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
 | 
						|
	}
 | 
						|
 | 
						|
	selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
 | 
						|
	changed := func(field *schema.Field) bool {
 | 
						|
		fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
 | 
						|
		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
 | 
						|
			if mv, mok := stmt.Dest.(map[string]interface{}); mok {
 | 
						|
				if fv, ok := mv[field.Name]; ok {
 | 
						|
					return !utils.AssertEqual(fv, fieldValue)
 | 
						|
				} else if fv, ok := mv[field.DBName]; ok {
 | 
						|
					return !utils.AssertEqual(fv, fieldValue)
 | 
						|
				}
 | 
						|
			} else {
 | 
						|
				destValue := reflect.ValueOf(stmt.Dest)
 | 
						|
				for destValue.Kind() == reflect.Ptr {
 | 
						|
					destValue = destValue.Elem()
 | 
						|
				}
 | 
						|
 | 
						|
				changedValue, zero := field.ValueOf(stmt.Context, destValue)
 | 
						|
				if v {
 | 
						|
					return !utils.AssertEqual(changedValue, fieldValue)
 | 
						|
				}
 | 
						|
				return !zero && !utils.AssertEqual(changedValue, fieldValue)
 | 
						|
			}
 | 
						|
		}
 | 
						|
		return false
 | 
						|
	}
 | 
						|
 | 
						|
	if len(fields) == 0 {
 | 
						|
		for _, field := range stmt.Schema.FieldsByDBName {
 | 
						|
			if changed(field) {
 | 
						|
				return true
 | 
						|
			}
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		for _, name := range fields {
 | 
						|
			if field := stmt.Schema.LookUpField(name); field != nil {
 | 
						|
				if changed(field) {
 | 
						|
					return true
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return false
 | 
						|
}
 | 
						|
 | 
						|
// matchName splits a column reference such as "name", "users.name",
// "`users`.`name`" or "users.*" into its table and column parts. Each part
// may be wrapped in a single non-word character on either side (for example
// a quote). The table is empty when the reference has no table prefix, and
// both results are empty when the input does not have this shape.
var matchName = func() func(tableColumn string) (table, column string) {
	// Compiled once; the returned closure reuses it on every call.
	pattern := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)

	return func(tableColumn string) (table, column string) {
		groups := pattern.FindStringSubmatch(tableColumn)
		if len(groups) != 4 {
			return "", ""
		}

		// groups[1] is the optional table, groups[2] the "*" wildcard and
		// groups[3] a regular column name; the last two are exclusive.
		if wildcard := groups[2]; wildcard != "" {
			return groups[1], wildcard
		}
		return groups[1], groups[3]
	}
}()
 | 
						|
 | 
						|
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
 | 
						|
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
 | 
						|
	results := map[string]bool{}
 | 
						|
	notRestricted := false
 | 
						|
 | 
						|
	processColumn := func(column string, result bool) {
 | 
						|
		if stmt.Schema == nil {
 | 
						|
			results[column] = result
 | 
						|
		} else if column == "*" {
 | 
						|
			notRestricted = result
 | 
						|
			for _, dbName := range stmt.Schema.DBNames {
 | 
						|
				results[dbName] = result
 | 
						|
			}
 | 
						|
		} else if column == clause.Associations {
 | 
						|
			for _, rel := range stmt.Schema.Relationships.Relations {
 | 
						|
				results[rel.Name] = result
 | 
						|
			}
 | 
						|
		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
 | 
						|
			results[field.DBName] = result
 | 
						|
		} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
 | 
						|
			if col == "*" {
 | 
						|
				for _, dbName := range stmt.Schema.DBNames {
 | 
						|
					results[dbName] = result
 | 
						|
				}
 | 
						|
			} else {
 | 
						|
				results[col] = result
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			results[column] = result
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// select columns
 | 
						|
	for _, column := range stmt.Selects {
 | 
						|
		processColumn(column, true)
 | 
						|
	}
 | 
						|
 | 
						|
	// omit columns
 | 
						|
	for _, column := range stmt.Omits {
 | 
						|
		processColumn(column, false)
 | 
						|
	}
 | 
						|
 | 
						|
	if stmt.Schema != nil {
 | 
						|
		for _, field := range stmt.Schema.FieldsByName {
 | 
						|
			name := field.DBName
 | 
						|
			if name == "" {
 | 
						|
				name = field.Name
 | 
						|
			}
 | 
						|
 | 
						|
			if requireCreate && !field.Creatable {
 | 
						|
				results[name] = false
 | 
						|
			} else if requireUpdate && !field.Updatable {
 | 
						|
				results[name] = false
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return results, !notRestricted && len(stmt.Selects) > 0
 | 
						|
}
 |