153 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			153 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package callbacks
 | 
						|
 | 
						|
import (
 | 
						|
	"reflect"
 | 
						|
	"sort"
 | 
						|
 | 
						|
	"gorm.io/gorm"
 | 
						|
	"gorm.io/gorm/clause"
 | 
						|
)
 | 
						|
 | 
						|
// ConvertMapToValuesForCreate convert map to values
 | 
						|
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
 | 
						|
	values.Columns = make([]clause.Column, 0, len(mapValue))
 | 
						|
	selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
 | 
						|
 | 
						|
	keys := make([]string, 0, len(mapValue))
 | 
						|
	for k := range mapValue {
 | 
						|
		keys = append(keys, k)
 | 
						|
	}
 | 
						|
	sort.Strings(keys)
 | 
						|
 | 
						|
	for _, k := range keys {
 | 
						|
		value := mapValue[k]
 | 
						|
		if stmt.Schema != nil {
 | 
						|
			if field := stmt.Schema.LookUpField(k); field != nil {
 | 
						|
				k = field.DBName
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
 | 
						|
			values.Columns = append(values.Columns, clause.Column{Name: k})
 | 
						|
			if len(values.Values) == 0 {
 | 
						|
				values.Values = [][]interface{}{{}}
 | 
						|
			}
 | 
						|
 | 
						|
			values.Values[0] = append(values.Values[0], value)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
// ConvertSliceOfMapToValuesForCreate convert slice of map to values
 | 
						|
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
 | 
						|
	columns := make([]string, 0, len(mapValues))
 | 
						|
 | 
						|
	// when the length of mapValues is zero,return directly here
 | 
						|
	// no need to call stmt.SelectAndOmitColumns method
 | 
						|
	if len(mapValues) == 0 {
 | 
						|
		stmt.AddError(gorm.ErrEmptySlice)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	var (
 | 
						|
		result                    = make(map[string][]interface{}, len(mapValues))
 | 
						|
		selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
 | 
						|
	)
 | 
						|
 | 
						|
	for idx, mapValue := range mapValues {
 | 
						|
		for k, v := range mapValue {
 | 
						|
			if stmt.Schema != nil {
 | 
						|
				if field := stmt.Schema.LookUpField(k); field != nil {
 | 
						|
					k = field.DBName
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			if _, ok := result[k]; !ok {
 | 
						|
				if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
 | 
						|
					result[k] = make([]interface{}, len(mapValues))
 | 
						|
					columns = append(columns, k)
 | 
						|
				} else {
 | 
						|
					continue
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			result[k][idx] = v
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	sort.Strings(columns)
 | 
						|
	values.Values = make([][]interface{}, len(mapValues))
 | 
						|
	values.Columns = make([]clause.Column, len(columns))
 | 
						|
	for idx, column := range columns {
 | 
						|
		values.Columns[idx] = clause.Column{Name: column}
 | 
						|
 | 
						|
		for i, v := range result[column] {
 | 
						|
			if len(values.Values[i]) == 0 {
 | 
						|
				values.Values[i] = make([]interface{}, len(columns))
 | 
						|
			}
 | 
						|
 | 
						|
			values.Values[i][idx] = v
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return
 | 
						|
}
 | 
						|
 | 
						|
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
 | 
						|
	if supportReturning {
 | 
						|
		if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
 | 
						|
			returning, _ := c.Expression.(clause.Returning)
 | 
						|
			if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
 | 
						|
				return true, 0
 | 
						|
			}
 | 
						|
			return true, gorm.ScanUpdate
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return false, 0
 | 
						|
}
 | 
						|
 | 
						|
func checkMissingWhereConditions(db *gorm.DB) {
 | 
						|
	if !db.AllowGlobalUpdate && db.Error == nil {
 | 
						|
		where, withCondition := db.Statement.Clauses["WHERE"]
 | 
						|
		if withCondition {
 | 
						|
			if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
 | 
						|
				whereClause, _ := where.Expression.(clause.Where)
 | 
						|
				withCondition = len(whereClause.Exprs) > 1
 | 
						|
			}
 | 
						|
		}
 | 
						|
		if !withCondition {
 | 
						|
			db.AddError(gorm.ErrMissingWhereClause)
 | 
						|
		}
 | 
						|
		return
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// visitMap records the addresses of values that have already been visited.
type visitMap = map[reflect.Value]bool

// loadOrStoreVisitMap reports whether v has been seen before, recording it in
// visitMap when it has not. It is used to detect circular values.
//
// A pointer is dereferenced once before inspection. For an addressable struct
// or interface value, the value's address is looked up in the map: true is
// returned if it is already present, otherwise it is stored and false is
// returned. For a slice or array, every element is processed and the result is
// true only if all elements were already loaded (an empty slice or array
// therefore yields true). Any other kind, and non-addressable values, yield
// false and leave the map untouched.
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}

	switch v.Kind() {
	case reflect.Struct, reflect.Interface:
		if !v.CanAddr() {
			return false
		}

		seen := *visitMap
		addr := v.Addr()
		if _, visited := seen[addr]; visited {
			return true
		}
		seen[addr] = true
		return false

	case reflect.Slice, reflect.Array:
		allLoaded := true
		// Do not stop at the first new element: every element must be
		// recorded in the map.
		for i, n := 0, v.Len(); i < n; i++ {
			if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
				allLoaded = false
			}
		}
		return allLoaded
	}

	return false
}
 |