300 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			300 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package callbacks
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 
 | |
| 	"gorm.io/gorm"
 | |
| 	"gorm.io/gorm/clause"
 | |
| 	"gorm.io/gorm/schema"
 | |
| 	"gorm.io/gorm/utils"
 | |
| )
 | |
| 
 | |
| func Query(db *gorm.DB) {
 | |
| 	if db.Error == nil {
 | |
| 		BuildQuerySQL(db)
 | |
| 
 | |
| 		if !db.DryRun && db.Error == nil {
 | |
| 			rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
 | |
| 			if err != nil {
 | |
| 				db.AddError(err)
 | |
| 				return
 | |
| 			}
 | |
| 			defer func() {
 | |
| 				db.AddError(rows.Close())
 | |
| 			}()
 | |
| 			gorm.Scan(rows, db, 0)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func BuildQuerySQL(db *gorm.DB) {
 | |
| 	if db.Statement.Schema != nil {
 | |
| 		for _, c := range db.Statement.Schema.QueryClauses {
 | |
| 			db.Statement.AddClause(c)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if db.Statement.SQL.Len() == 0 {
 | |
| 		db.Statement.SQL.Grow(100)
 | |
| 		clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
 | |
| 
 | |
| 		if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
 | |
| 			var conds []clause.Expression
 | |
| 			for _, primaryField := range db.Statement.Schema.PrimaryFields {
 | |
| 				if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
 | |
| 					conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			if len(conds) > 0 {
 | |
| 				db.Statement.AddClause(clause.Where{Exprs: conds})
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if len(db.Statement.Selects) > 0 {
 | |
| 			clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
 | |
| 			for idx, name := range db.Statement.Selects {
 | |
| 				if db.Statement.Schema == nil {
 | |
| 					clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
 | |
| 				} else if f := db.Statement.Schema.LookUpField(name); f != nil {
 | |
| 					clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
 | |
| 				} else {
 | |
| 					clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
 | |
| 				}
 | |
| 			}
 | |
| 		} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
 | |
| 			selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
 | |
| 			clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
 | |
| 			for _, dbName := range db.Statement.Schema.DBNames {
 | |
| 				if v, ok := selectColumns[dbName]; (ok && v) || !ok {
 | |
| 					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
 | |
| 				}
 | |
| 			}
 | |
| 		} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
 | |
| 			queryFields := db.QueryFields
 | |
| 			if !queryFields {
 | |
| 				switch db.Statement.ReflectValue.Kind() {
 | |
| 				case reflect.Struct:
 | |
| 					queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
 | |
| 				case reflect.Slice:
 | |
| 					queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			if queryFields {
 | |
| 				stmt := gorm.Statement{DB: db}
 | |
| 				// smaller struct
 | |
| 				if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
 | |
| 					clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
 | |
| 
 | |
| 					for idx, dbName := range stmt.Schema.DBNames {
 | |
| 						clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// inline joins
 | |
| 		fromClause := clause.From{}
 | |
| 		if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
 | |
| 			fromClause = v
 | |
| 		}
 | |
| 
 | |
| 		if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
 | |
| 			if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
 | |
| 				clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
 | |
| 				for idx, dbName := range db.Statement.Schema.DBNames {
 | |
| 					clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			specifiedRelationsName := make(map[string]interface{})
 | |
| 			for _, join := range db.Statement.Joins {
 | |
| 				if db.Statement.Schema != nil {
 | |
| 					var isRelations bool // is relations or raw sql
 | |
| 					var relations []*schema.Relationship
 | |
| 					relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
 | |
| 					if ok {
 | |
| 						isRelations = true
 | |
| 						relations = append(relations, relation)
 | |
| 					} else {
 | |
| 						// handle nested join like "Manager.Company"
 | |
| 						nestedJoinNames := strings.Split(join.Name, ".")
 | |
| 						if len(nestedJoinNames) > 1 {
 | |
| 							isNestedJoin := true
 | |
| 							gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
 | |
| 							currentRelations := db.Statement.Schema.Relationships.Relations
 | |
| 							for _, relname := range nestedJoinNames {
 | |
| 								// incomplete match, only treated as raw sql
 | |
| 								if relation, ok = currentRelations[relname]; ok {
 | |
| 									gussNestedRelations = append(gussNestedRelations, relation)
 | |
| 									currentRelations = relation.FieldSchema.Relationships.Relations
 | |
| 								} else {
 | |
| 									isNestedJoin = false
 | |
| 									break
 | |
| 								}
 | |
| 							}
 | |
| 
 | |
| 							if isNestedJoin {
 | |
| 								isRelations = true
 | |
| 								relations = gussNestedRelations
 | |
| 							}
 | |
| 						}
 | |
| 					}
 | |
| 
 | |
| 					if isRelations {
 | |
| 						genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
 | |
| 							tableAliasName := relation.Name
 | |
| 							if parentTableName != clause.CurrentTable {
 | |
| 								tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
 | |
| 							}
 | |
| 
 | |
| 							columnStmt := gorm.Statement{
 | |
| 								Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
 | |
| 								Selects: join.Selects, Omits: join.Omits,
 | |
| 							}
 | |
| 
 | |
| 							selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
 | |
| 							for _, s := range relation.FieldSchema.DBNames {
 | |
| 								if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
 | |
| 									clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
 | |
| 										Table: tableAliasName,
 | |
| 										Name:  s,
 | |
| 										Alias: utils.NestedRelationName(tableAliasName, s),
 | |
| 									})
 | |
| 								}
 | |
| 							}
 | |
| 
 | |
| 							exprs := make([]clause.Expression, len(relation.References))
 | |
| 							for idx, ref := range relation.References {
 | |
| 								if ref.OwnPrimaryKey {
 | |
| 									exprs[idx] = clause.Eq{
 | |
| 										Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
 | |
| 										Value:  clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
 | |
| 									}
 | |
| 								} else {
 | |
| 									if ref.PrimaryValue == "" {
 | |
| 										exprs[idx] = clause.Eq{
 | |
| 											Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
 | |
| 											Value:  clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
 | |
| 										}
 | |
| 									} else {
 | |
| 										exprs[idx] = clause.Eq{
 | |
| 											Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
 | |
| 											Value:  ref.PrimaryValue,
 | |
| 										}
 | |
| 									}
 | |
| 								}
 | |
| 							}
 | |
| 
 | |
| 							{
 | |
| 								onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
 | |
| 								for _, c := range relation.FieldSchema.QueryClauses {
 | |
| 									onStmt.AddClause(c)
 | |
| 								}
 | |
| 
 | |
| 								if join.On != nil {
 | |
| 									onStmt.AddClause(join.On)
 | |
| 								}
 | |
| 
 | |
| 								if cs, ok := onStmt.Clauses["WHERE"]; ok {
 | |
| 									if where, ok := cs.Expression.(clause.Where); ok {
 | |
| 										where.Build(&onStmt)
 | |
| 
 | |
| 										if onSQL := onStmt.SQL.String(); onSQL != "" {
 | |
| 											vars := onStmt.Vars
 | |
| 											for idx, v := range vars {
 | |
| 												bindvar := strings.Builder{}
 | |
| 												onStmt.Vars = vars[0 : idx+1]
 | |
| 												db.Dialector.BindVarTo(&bindvar, &onStmt, v)
 | |
| 												onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
 | |
| 											}
 | |
| 
 | |
| 											exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
 | |
| 										}
 | |
| 									}
 | |
| 								}
 | |
| 							}
 | |
| 
 | |
| 							return clause.Join{
 | |
| 								Type:  joinType,
 | |
| 								Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
 | |
| 								ON:    clause.Where{Exprs: exprs},
 | |
| 							}
 | |
| 						}
 | |
| 
 | |
| 						parentTableName := clause.CurrentTable
 | |
| 						for _, rel := range relations {
 | |
| 							// joins table alias like "Manager, Company, Manager__Company"
 | |
| 							nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
 | |
| 							if _, ok := specifiedRelationsName[nestedAlias]; !ok {
 | |
| 								fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
 | |
| 								specifiedRelationsName[nestedAlias] = nil
 | |
| 							}
 | |
| 
 | |
| 							if parentTableName != clause.CurrentTable {
 | |
| 								parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
 | |
| 							} else {
 | |
| 								parentTableName = rel.Name
 | |
| 							}
 | |
| 						}
 | |
| 					} else {
 | |
| 						fromClause.Joins = append(fromClause.Joins, clause.Join{
 | |
| 							Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
 | |
| 						})
 | |
| 					}
 | |
| 				} else {
 | |
| 					fromClause.Joins = append(fromClause.Joins, clause.Join{
 | |
| 						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
 | |
| 					})
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			db.Statement.AddClause(fromClause)
 | |
| 		} else {
 | |
| 			db.Statement.AddClauseIfNotExists(clause.From{})
 | |
| 		}
 | |
| 
 | |
| 		db.Statement.AddClauseIfNotExists(clauseSelect)
 | |
| 
 | |
| 		db.Statement.Build(db.Statement.BuildClauses...)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func Preload(db *gorm.DB) {
 | |
| 	if db.Error == nil && len(db.Statement.Preloads) > 0 {
 | |
| 		if db.Statement.Schema == nil {
 | |
| 			db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		joins := make([]string, 0, len(db.Statement.Joins))
 | |
| 		for _, join := range db.Statement.Joins {
 | |
| 			joins = append(joins, join.Name)
 | |
| 		}
 | |
| 
 | |
| 		tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
 | |
| 		if tx.Error != nil {
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func AfterQuery(db *gorm.DB) {
 | |
| 	// clear the joins after query because preload need it
 | |
| 	db.Statement.Joins = nil
 | |
| 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
 | |
| 		callMethod(db, func(value interface{}, tx *gorm.DB) bool {
 | |
| 			if i, ok := value.(AfterFindInterface); ok {
 | |
| 				db.AddError(i.AfterFind(tx))
 | |
| 				return true
 | |
| 			}
 | |
| 			return false
 | |
| 		})
 | |
| 	}
 | |
| }
 |