454 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			454 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package callbacks
 | |
| 
 | |
| import (
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 
 | |
| 	"gorm.io/gorm"
 | |
| 	"gorm.io/gorm/clause"
 | |
| 	"gorm.io/gorm/schema"
 | |
| 	"gorm.io/gorm/utils"
 | |
| )
 | |
| 
 | |
| func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
 | |
| 	return func(db *gorm.DB) {
 | |
| 		if db.Error == nil && db.Statement.Schema != nil {
 | |
| 			selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
 | |
| 
 | |
| 			// Save Belongs To associations
 | |
| 			for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
 | |
| 				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				setupReferences := func(obj reflect.Value, elem reflect.Value) {
 | |
| 					for _, ref := range rel.References {
 | |
| 						if !ref.OwnPrimaryKey {
 | |
| 							pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
 | |
| 							db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
 | |
| 
 | |
| 							if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
 | |
| 								dest[ref.ForeignKey.DBName] = pv
 | |
| 								if _, ok := dest[rel.Name]; ok {
 | |
| 									dest[rel.Name] = elem.Interface()
 | |
| 								}
 | |
| 							}
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				switch db.Statement.ReflectValue.Kind() {
 | |
| 				case reflect.Slice, reflect.Array:
 | |
| 					var (
 | |
| 						rValLen   = db.Statement.ReflectValue.Len()
 | |
| 						objs      = make([]reflect.Value, 0, rValLen)
 | |
| 						fieldType = rel.Field.FieldType
 | |
| 						isPtr     = fieldType.Kind() == reflect.Ptr
 | |
| 					)
 | |
| 
 | |
| 					if !isPtr {
 | |
| 						fieldType = reflect.PtrTo(fieldType)
 | |
| 					}
 | |
| 
 | |
| 					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | |
| 					distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | |
| 					identityMap := map[string]bool{}
 | |
| 					for i := 0; i < rValLen; i++ {
 | |
| 						obj := db.Statement.ReflectValue.Index(i)
 | |
| 						if reflect.Indirect(obj).Kind() != reflect.Struct {
 | |
| 							break
 | |
| 						}
 | |
| 						if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
 | |
| 							rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
 | |
| 							if !isPtr {
 | |
| 								rv = rv.Addr()
 | |
| 							}
 | |
| 							objs = append(objs, obj)
 | |
| 							elems = reflect.Append(elems, rv)
 | |
| 
 | |
| 							relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
 | |
| 							for _, pf := range rel.FieldSchema.PrimaryFields {
 | |
| 								if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
 | |
| 									relPrimaryValues = append(relPrimaryValues, pfv)
 | |
| 								}
 | |
| 							}
 | |
| 							cacheKey := utils.ToStringKey(relPrimaryValues...)
 | |
| 							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
 | |
| 								if cacheKey != "" { // has primary fields
 | |
| 									identityMap[cacheKey] = true
 | |
| 								}
 | |
| 
 | |
| 								distinctElems = reflect.Append(distinctElems, rv)
 | |
| 							}
 | |
| 						}
 | |
| 					}
 | |
| 
 | |
| 					if elems.Len() > 0 {
 | |
| 						if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
 | |
| 							for i := 0; i < elems.Len(); i++ {
 | |
| 								setupReferences(objs[i], elems.Index(i))
 | |
| 							}
 | |
| 						}
 | |
| 					}
 | |
| 				case reflect.Struct:
 | |
| 					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
 | |
| 						rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
 | |
| 						if rv.Kind() != reflect.Ptr {
 | |
| 							rv = rv.Addr()
 | |
| 						}
 | |
| 
 | |
| 						if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
 | |
| 							setupReferences(db.Statement.ReflectValue, rv)
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func SaveAfterAssociations(create bool) func(db *gorm.DB) {
 | |
| 	return func(db *gorm.DB) {
 | |
| 		if db.Error == nil && db.Statement.Schema != nil {
 | |
| 			selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
 | |
| 
 | |
| 			// Save Has One associations
 | |
| 			for _, rel := range db.Statement.Schema.Relationships.HasOne {
 | |
| 				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				switch db.Statement.ReflectValue.Kind() {
 | |
| 				case reflect.Slice, reflect.Array:
 | |
| 					var (
 | |
| 						fieldType = rel.Field.FieldType
 | |
| 						isPtr     = fieldType.Kind() == reflect.Ptr
 | |
| 					)
 | |
| 
 | |
| 					if !isPtr {
 | |
| 						fieldType = reflect.PtrTo(fieldType)
 | |
| 					}
 | |
| 
 | |
| 					elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | |
| 
 | |
| 					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
 | |
| 						obj := db.Statement.ReflectValue.Index(i)
 | |
| 
 | |
| 						if reflect.Indirect(obj).Kind() == reflect.Struct {
 | |
| 							if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
 | |
| 								rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
 | |
| 								if rv.Kind() != reflect.Ptr {
 | |
| 									rv = rv.Addr()
 | |
| 								}
 | |
| 
 | |
| 								for _, ref := range rel.References {
 | |
| 									if ref.OwnPrimaryKey {
 | |
| 										fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
 | |
| 										db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
 | |
| 									} else if ref.PrimaryValue != "" {
 | |
| 										db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
 | |
| 									}
 | |
| 								}
 | |
| 
 | |
| 								elems = reflect.Append(elems, rv)
 | |
| 							}
 | |
| 						}
 | |
| 					}
 | |
| 
 | |
| 					if elems.Len() > 0 {
 | |
| 						assignmentColumns := make([]string, 0, len(rel.References))
 | |
| 						for _, ref := range rel.References {
 | |
| 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
 | |
| 						}
 | |
| 
 | |
| 						saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
 | |
| 					}
 | |
| 				case reflect.Struct:
 | |
| 					if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
 | |
| 						f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
 | |
| 						if f.Kind() != reflect.Ptr {
 | |
| 							f = f.Addr()
 | |
| 						}
 | |
| 
 | |
| 						assignmentColumns := make([]string, 0, len(rel.References))
 | |
| 						for _, ref := range rel.References {
 | |
| 							if ref.OwnPrimaryKey {
 | |
| 								fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
 | |
| 								db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
 | |
| 							} else if ref.PrimaryValue != "" {
 | |
| 								db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
 | |
| 							}
 | |
| 							assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
 | |
| 						}
 | |
| 
 | |
| 						saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			// Save Has Many associations
 | |
| 			for _, rel := range db.Statement.Schema.Relationships.HasMany {
 | |
| 				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				fieldType := rel.Field.IndirectFieldType.Elem()
 | |
| 				isPtr := fieldType.Kind() == reflect.Ptr
 | |
| 				if !isPtr {
 | |
| 					fieldType = reflect.PtrTo(fieldType)
 | |
| 				}
 | |
| 				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | |
| 				identityMap := map[string]bool{}
 | |
| 				appendToElems := func(v reflect.Value) {
 | |
| 					if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
 | |
| 						f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
 | |
| 
 | |
| 						for i := 0; i < f.Len(); i++ {
 | |
| 							elem := f.Index(i)
 | |
| 							for _, ref := range rel.References {
 | |
| 								if ref.OwnPrimaryKey {
 | |
| 									pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
 | |
| 									db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
 | |
| 								} else if ref.PrimaryValue != "" {
 | |
| 									db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
 | |
| 								}
 | |
| 							}
 | |
| 
 | |
| 							relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
 | |
| 							for _, pf := range rel.FieldSchema.PrimaryFields {
 | |
| 								if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
 | |
| 									relPrimaryValues = append(relPrimaryValues, pfv)
 | |
| 								}
 | |
| 							}
 | |
| 
 | |
| 							cacheKey := utils.ToStringKey(relPrimaryValues...)
 | |
| 							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
 | |
| 								if cacheKey != "" { // has primary fields
 | |
| 									identityMap[cacheKey] = true
 | |
| 								}
 | |
| 
 | |
| 								if isPtr {
 | |
| 									elems = reflect.Append(elems, elem)
 | |
| 								} else {
 | |
| 									elems = reflect.Append(elems, elem.Addr())
 | |
| 								}
 | |
| 							}
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				switch db.Statement.ReflectValue.Kind() {
 | |
| 				case reflect.Slice, reflect.Array:
 | |
| 					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
 | |
| 						obj := db.Statement.ReflectValue.Index(i)
 | |
| 						if reflect.Indirect(obj).Kind() == reflect.Struct {
 | |
| 							appendToElems(obj)
 | |
| 						}
 | |
| 					}
 | |
| 				case reflect.Struct:
 | |
| 					appendToElems(db.Statement.ReflectValue)
 | |
| 				}
 | |
| 
 | |
| 				if elems.Len() > 0 {
 | |
| 					assignmentColumns := make([]string, 0, len(rel.References))
 | |
| 					for _, ref := range rel.References {
 | |
| 						assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
 | |
| 					}
 | |
| 
 | |
| 					saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			// Save Many2Many associations
 | |
| 			for _, rel := range db.Statement.Schema.Relationships.Many2Many {
 | |
| 				if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
 | |
| 					continue
 | |
| 				}
 | |
| 
 | |
| 				fieldType := rel.Field.IndirectFieldType.Elem()
 | |
| 				isPtr := fieldType.Kind() == reflect.Ptr
 | |
| 				if !isPtr {
 | |
| 					fieldType = reflect.PtrTo(fieldType)
 | |
| 				}
 | |
| 				elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | |
| 				distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
 | |
| 				joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
 | |
| 				objs := []reflect.Value{}
 | |
| 
 | |
| 				appendToJoins := func(obj reflect.Value, elem reflect.Value) {
 | |
| 					joinValue := reflect.New(rel.JoinTable.ModelType)
 | |
| 					for _, ref := range rel.References {
 | |
| 						if ref.OwnPrimaryKey {
 | |
| 							fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
 | |
| 							db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
 | |
| 						} else if ref.PrimaryValue != "" {
 | |
| 							db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
 | |
| 						} else {
 | |
| 							fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
 | |
| 							db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
 | |
| 						}
 | |
| 					}
 | |
| 					joins = reflect.Append(joins, joinValue)
 | |
| 				}
 | |
| 
 | |
| 				identityMap := map[string]bool{}
 | |
| 				appendToElems := func(v reflect.Value) {
 | |
| 					if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
 | |
| 						f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
 | |
| 						for i := 0; i < f.Len(); i++ {
 | |
| 							elem := f.Index(i)
 | |
| 							if !isPtr {
 | |
| 								elem = elem.Addr()
 | |
| 							}
 | |
| 							objs = append(objs, v)
 | |
| 							elems = reflect.Append(elems, elem)
 | |
| 
 | |
| 							relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
 | |
| 							for _, pf := range rel.FieldSchema.PrimaryFields {
 | |
| 								if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
 | |
| 									relPrimaryValues = append(relPrimaryValues, pfv)
 | |
| 								}
 | |
| 							}
 | |
| 
 | |
| 							cacheKey := utils.ToStringKey(relPrimaryValues...)
 | |
| 							if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
 | |
| 								if cacheKey != "" { // has primary fields
 | |
| 									identityMap[cacheKey] = true
 | |
| 								}
 | |
| 
 | |
| 								distinctElems = reflect.Append(distinctElems, elem)
 | |
| 							}
 | |
| 
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				switch db.Statement.ReflectValue.Kind() {
 | |
| 				case reflect.Slice, reflect.Array:
 | |
| 					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
 | |
| 						obj := db.Statement.ReflectValue.Index(i)
 | |
| 						if reflect.Indirect(obj).Kind() == reflect.Struct {
 | |
| 							appendToElems(obj)
 | |
| 						}
 | |
| 					}
 | |
| 				case reflect.Struct:
 | |
| 					appendToElems(db.Statement.ReflectValue)
 | |
| 				}
 | |
| 
 | |
| 				// optimize elems of reflect value length
 | |
| 				if elemLen := elems.Len(); elemLen > 0 {
 | |
| 					if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
 | |
| 						saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
 | |
| 					}
 | |
| 
 | |
| 					for i := 0; i < elemLen; i++ {
 | |
| 						appendToJoins(objs[i], elems.Index(i))
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				if joins.Len() > 0 {
 | |
| 					db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
 | |
| 						SkipHooks:                db.Statement.SkipHooks,
 | |
| 						DisableNestedTransaction: true,
 | |
| 					}).Create(joins.Interface()).Error)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
 | |
| 	if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
 | |
| 		onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
 | |
| 		for _, dbName := range s.PrimaryFieldDBNames {
 | |
| 			onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
 | |
| 		}
 | |
| 
 | |
| 		onConflict.UpdateAll = stmt.DB.FullSaveAssociations
 | |
| 		if !onConflict.UpdateAll {
 | |
| 			onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
 | |
| 		}
 | |
| 	} else {
 | |
| 		onConflict.DoNothing = true
 | |
| 	}
 | |
| 
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
 | |
| 	// stop save association loop
 | |
| 	if checkAssociationsSaved(db, rValues) {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	var (
 | |
| 		selects, omits []string
 | |
| 		onConflict     = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
 | |
| 		refName        = rel.Name + "."
 | |
| 		values         = rValues.Interface()
 | |
| 	)
 | |
| 
 | |
| 	for name, ok := range selectColumns {
 | |
| 		columnName := ""
 | |
| 		if strings.HasPrefix(name, refName) {
 | |
| 			columnName = strings.TrimPrefix(name, refName)
 | |
| 		}
 | |
| 
 | |
| 		if columnName != "" {
 | |
| 			if ok {
 | |
| 				selects = append(selects, columnName)
 | |
| 			} else {
 | |
| 				omits = append(omits, columnName)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
 | |
| 		FullSaveAssociations:     db.FullSaveAssociations,
 | |
| 		SkipHooks:                db.Statement.SkipHooks,
 | |
| 		DisableNestedTransaction: true,
 | |
| 	})
 | |
| 
 | |
| 	db.Statement.Settings.Range(func(k, v interface{}) bool {
 | |
| 		tx.Statement.Settings.Store(k, v)
 | |
| 		return true
 | |
| 	})
 | |
| 
 | |
| 	if tx.Statement.FullSaveAssociations {
 | |
| 		tx = tx.Set("gorm:update_track_time", true)
 | |
| 	}
 | |
| 
 | |
| 	if len(selects) > 0 {
 | |
| 		tx = tx.Select(selects)
 | |
| 	} else if restricted && len(omits) == 0 {
 | |
| 		tx = tx.Omit(clause.Associations)
 | |
| 	}
 | |
| 
 | |
| 	if len(omits) > 0 {
 | |
| 		tx = tx.Omit(omits...)
 | |
| 	}
 | |
| 
 | |
| 	return db.AddError(tx.Create(values).Error)
 | |
| }
 | |
| 
 | |
| // check association values has been saved
 | |
| // if values kind is Struct, check it has been saved
 | |
| // if values kind is Slice/Array, check all items have been saved
 | |
| var visitMapStoreKey = "gorm:saved_association_map"
 | |
| 
 | |
| func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
 | |
| 	if visit, ok := db.Get(visitMapStoreKey); ok {
 | |
| 		if v, ok := visit.(*visitMap); ok {
 | |
| 			if loadOrStoreVisitMap(v, values) {
 | |
| 				return true
 | |
| 			}
 | |
| 		}
 | |
| 	} else {
 | |
| 		vistMap := make(visitMap)
 | |
| 		loadOrStoreVisitMap(&vistMap, values)
 | |
| 		db.Set(visitMapStoreKey, &vistMap)
 | |
| 	}
 | |
| 
 | |
| 	return false
 | |
| }
 |