349 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			349 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package gorm
 | 
						|
 | 
						|
import (
 | 
						|
	"database/sql"
 | 
						|
	"database/sql/driver"
 | 
						|
	"reflect"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"gorm.io/gorm/schema"
 | 
						|
	"gorm.io/gorm/utils"
 | 
						|
)
 | 
						|
 | 
						|
// prepareValues prepare values slice
 | 
						|
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
 | 
						|
	if db.Statement.Schema != nil {
 | 
						|
		for idx, name := range columns {
 | 
						|
			if field := db.Statement.Schema.LookUpField(name); field != nil {
 | 
						|
				values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
 | 
						|
				continue
 | 
						|
			}
 | 
						|
			values[idx] = new(interface{})
 | 
						|
		}
 | 
						|
	} else if len(columnTypes) > 0 {
 | 
						|
		for idx, columnType := range columnTypes {
 | 
						|
			if columnType.ScanType() != nil {
 | 
						|
				values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
 | 
						|
			} else {
 | 
						|
				values[idx] = new(interface{})
 | 
						|
			}
 | 
						|
		}
 | 
						|
	} else {
 | 
						|
		for idx := range columns {
 | 
						|
			values[idx] = new(interface{})
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// scanIntoMap stores the scanned values into mapValue, keyed by column name.
//
// Each destination in values is dereferenced through up to two pointer levels.
// A destination that ends up invalid (a nil entry or a nil pointer) is stored
// as nil. A value implementing driver.Valuer is replaced by the result of its
// Value method (the error is discarded), and sql.RawBytes is copied into a
// string. Everything else is stored as-is.
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
	for i, name := range columns {
		scanned := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[i])))
		if !scanned.IsValid() {
			mapValue[name] = nil
			continue
		}

		raw := scanned.Interface()
		mapValue[name] = raw

		switch v := raw.(type) {
		case driver.Valuer:
			mapValue[name], _ = v.Value()
		case sql.RawBytes:
			mapValue[name] = string(v)
		}
	}
}
 | 
						|
 | 
						|
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
 | 
						|
	for idx, field := range fields {
 | 
						|
		if field != nil {
 | 
						|
			values[idx] = field.NewValuePool.Get()
 | 
						|
		} else if len(fields) == 1 {
 | 
						|
			if reflectValue.CanAddr() {
 | 
						|
				values[idx] = reflectValue.Addr().Interface()
 | 
						|
			} else {
 | 
						|
				values[idx] = reflectValue.Interface()
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	db.RowsAffected++
 | 
						|
	db.AddError(rows.Scan(values...))
 | 
						|
	joinedNestedSchemaMap := make(map[string]interface{})
 | 
						|
	for idx, field := range fields {
 | 
						|
		if field == nil {
 | 
						|
			continue
 | 
						|
		}
 | 
						|
 | 
						|
		if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
 | 
						|
			db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
 | 
						|
		} else { // joinFields count is larger than 2 when using join
 | 
						|
			var isNilPtrValue bool
 | 
						|
			var relValue reflect.Value
 | 
						|
			// does not contain raw dbname
 | 
						|
			nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
 | 
						|
			// current reflect value
 | 
						|
			currentReflectValue := reflectValue
 | 
						|
			fullRels := make([]string, 0, len(nestedJoinSchemas))
 | 
						|
			for _, joinSchema := range nestedJoinSchemas {
 | 
						|
				fullRels = append(fullRels, joinSchema.Name)
 | 
						|
				relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
 | 
						|
				if relValue.Kind() == reflect.Ptr {
 | 
						|
					fullRelsName := utils.JoinNestedRelationNames(fullRels)
 | 
						|
					// same nested structure
 | 
						|
					if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
 | 
						|
						if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
 | 
						|
							isNilPtrValue = true
 | 
						|
							break
 | 
						|
						}
 | 
						|
 | 
						|
						relValue.Set(reflect.New(relValue.Type().Elem()))
 | 
						|
						joinedNestedSchemaMap[fullRelsName] = nil
 | 
						|
					}
 | 
						|
				}
 | 
						|
				currentReflectValue = relValue
 | 
						|
			}
 | 
						|
 | 
						|
			if !isNilPtrValue { // ignore if value is nil
 | 
						|
				f := joinFields[idx][len(joinFields[idx])-1]
 | 
						|
				db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		// release data to pool
 | 
						|
		field.NewValuePool.Put(values[idx])
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
// ScanMode is a set of bit flags passed to Scan to adjust how rows are read.
type ScanMode uint8

// Scan mode flags. Each flag occupies its own bit, so flags can be combined
// with a bitwise OR and tested with a bitwise AND.
const (
	ScanInitialized         ScanMode = 1 << iota // 1
	ScanUpdate                                   // 2
	ScanOnConflictDoNothing                      // 4
)
 | 
						|
 | 
						|
// Scan scan rows into db statement
 | 
						|
func Scan(rows Rows, db *DB, mode ScanMode) {
 | 
						|
	var (
 | 
						|
		columns, _          = rows.Columns()
 | 
						|
		values              = make([]interface{}, len(columns))
 | 
						|
		initialized         = mode&ScanInitialized != 0
 | 
						|
		update              = mode&ScanUpdate != 0
 | 
						|
		onConflictDonothing = mode&ScanOnConflictDoNothing != 0
 | 
						|
	)
 | 
						|
 | 
						|
	db.RowsAffected = 0
 | 
						|
 | 
						|
	switch dest := db.Statement.Dest.(type) {
 | 
						|
	case map[string]interface{}, *map[string]interface{}:
 | 
						|
		if initialized || rows.Next() {
 | 
						|
			columnTypes, _ := rows.ColumnTypes()
 | 
						|
			prepareValues(values, db, columnTypes, columns)
 | 
						|
 | 
						|
			db.RowsAffected++
 | 
						|
			db.AddError(rows.Scan(values...))
 | 
						|
 | 
						|
			mapValue, ok := dest.(map[string]interface{})
 | 
						|
			if !ok {
 | 
						|
				if v, ok := dest.(*map[string]interface{}); ok {
 | 
						|
					if *v == nil {
 | 
						|
						*v = map[string]interface{}{}
 | 
						|
					}
 | 
						|
					mapValue = *v
 | 
						|
				}
 | 
						|
			}
 | 
						|
			scanIntoMap(mapValue, values, columns)
 | 
						|
		}
 | 
						|
	case *[]map[string]interface{}:
 | 
						|
		columnTypes, _ := rows.ColumnTypes()
 | 
						|
		for initialized || rows.Next() {
 | 
						|
			prepareValues(values, db, columnTypes, columns)
 | 
						|
 | 
						|
			initialized = false
 | 
						|
			db.RowsAffected++
 | 
						|
			db.AddError(rows.Scan(values...))
 | 
						|
 | 
						|
			mapValue := map[string]interface{}{}
 | 
						|
			scanIntoMap(mapValue, values, columns)
 | 
						|
			*dest = append(*dest, mapValue)
 | 
						|
		}
 | 
						|
	case *int, *int8, *int16, *int32, *int64,
 | 
						|
		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
 | 
						|
		*float32, *float64,
 | 
						|
		*bool, *string, *time.Time,
 | 
						|
		*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
 | 
						|
		*sql.NullBool, *sql.NullString, *sql.NullTime:
 | 
						|
		for initialized || rows.Next() {
 | 
						|
			initialized = false
 | 
						|
			db.RowsAffected++
 | 
						|
			db.AddError(rows.Scan(dest))
 | 
						|
		}
 | 
						|
	default:
 | 
						|
		var (
 | 
						|
			fields       = make([]*schema.Field, len(columns))
 | 
						|
			joinFields   [][]*schema.Field
 | 
						|
			sch          = db.Statement.Schema
 | 
						|
			reflectValue = db.Statement.ReflectValue
 | 
						|
		)
 | 
						|
 | 
						|
		if reflectValue.Kind() == reflect.Interface {
 | 
						|
			reflectValue = reflectValue.Elem()
 | 
						|
		}
 | 
						|
 | 
						|
		reflectValueType := reflectValue.Type()
 | 
						|
		switch reflectValueType.Kind() {
 | 
						|
		case reflect.Array, reflect.Slice:
 | 
						|
			reflectValueType = reflectValueType.Elem()
 | 
						|
		}
 | 
						|
		isPtr := reflectValueType.Kind() == reflect.Ptr
 | 
						|
		if isPtr {
 | 
						|
			reflectValueType = reflectValueType.Elem()
 | 
						|
		}
 | 
						|
 | 
						|
		if sch != nil {
 | 
						|
			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
 | 
						|
				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
 | 
						|
			}
 | 
						|
 | 
						|
			if len(columns) == 1 {
 | 
						|
				// Is Pluck
 | 
						|
				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
 | 
						|
					reflectValueType.Kind() != reflect.Struct || // is not struct
 | 
						|
					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
 | 
						|
					sch = nil
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			// Not Pluck
 | 
						|
			if sch != nil {
 | 
						|
				matchedFieldCount := make(map[string]int, len(columns))
 | 
						|
				for idx, column := range columns {
 | 
						|
					if field := sch.LookUpField(column); field != nil && field.Readable {
 | 
						|
						fields[idx] = field
 | 
						|
						if count, ok := matchedFieldCount[column]; ok {
 | 
						|
							// handle duplicate fields
 | 
						|
							for _, selectField := range sch.Fields {
 | 
						|
								if selectField.DBName == column && selectField.Readable {
 | 
						|
									if count == 0 {
 | 
						|
										matchedFieldCount[column]++
 | 
						|
										fields[idx] = selectField
 | 
						|
										break
 | 
						|
									}
 | 
						|
									count--
 | 
						|
								}
 | 
						|
							}
 | 
						|
						} else {
 | 
						|
							matchedFieldCount[column] = 1
 | 
						|
						}
 | 
						|
					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
 | 
						|
						if rel, ok := sch.Relationships.Relations[names[0]]; ok {
 | 
						|
							subNameCount := len(names)
 | 
						|
							// nested relation fields
 | 
						|
							relFields := make([]*schema.Field, 0, subNameCount-1)
 | 
						|
							relFields = append(relFields, rel.Field)
 | 
						|
							for _, name := range names[1 : subNameCount-1] {
 | 
						|
								rel = rel.FieldSchema.Relationships.Relations[name]
 | 
						|
								relFields = append(relFields, rel.Field)
 | 
						|
							}
 | 
						|
							// lastest name is raw dbname
 | 
						|
							dbName := names[subNameCount-1]
 | 
						|
							if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
 | 
						|
								fields[idx] = field
 | 
						|
 | 
						|
								if len(joinFields) == 0 {
 | 
						|
									joinFields = make([][]*schema.Field, len(columns))
 | 
						|
								}
 | 
						|
								relFields = append(relFields, field)
 | 
						|
								joinFields[idx] = relFields
 | 
						|
								continue
 | 
						|
							}
 | 
						|
						}
 | 
						|
						var val interface{}
 | 
						|
						values[idx] = &val
 | 
						|
					} else {
 | 
						|
						var val interface{}
 | 
						|
						values[idx] = &val
 | 
						|
					}
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		switch reflectValue.Kind() {
 | 
						|
		case reflect.Slice, reflect.Array:
 | 
						|
			var (
 | 
						|
				elem        reflect.Value
 | 
						|
				isArrayKind = reflectValue.Kind() == reflect.Array
 | 
						|
			)
 | 
						|
 | 
						|
			if !update || reflectValue.Len() == 0 {
 | 
						|
				update = false
 | 
						|
				if isArrayKind {
 | 
						|
					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
 | 
						|
				} else {
 | 
						|
					// if the slice cap is externally initialized, the externally initialized slice is directly used here
 | 
						|
					if reflectValue.Cap() == 0 {
 | 
						|
						db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
 | 
						|
					} else {
 | 
						|
						reflectValue.SetLen(0)
 | 
						|
						db.Statement.ReflectValue.Set(reflectValue)
 | 
						|
					}
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			for initialized || rows.Next() {
 | 
						|
			BEGIN:
 | 
						|
				initialized = false
 | 
						|
 | 
						|
				if update {
 | 
						|
					if int(db.RowsAffected) >= reflectValue.Len() {
 | 
						|
						return
 | 
						|
					}
 | 
						|
					elem = reflectValue.Index(int(db.RowsAffected))
 | 
						|
					if onConflictDonothing {
 | 
						|
						for _, field := range fields {
 | 
						|
							if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
 | 
						|
								db.RowsAffected++
 | 
						|
								goto BEGIN
 | 
						|
							}
 | 
						|
						}
 | 
						|
					}
 | 
						|
				} else {
 | 
						|
					elem = reflect.New(reflectValueType)
 | 
						|
				}
 | 
						|
 | 
						|
				db.scanIntoStruct(rows, elem, values, fields, joinFields)
 | 
						|
 | 
						|
				if !update {
 | 
						|
					if !isPtr {
 | 
						|
						elem = elem.Elem()
 | 
						|
					}
 | 
						|
					if isArrayKind {
 | 
						|
						if reflectValue.Len() >= int(db.RowsAffected) {
 | 
						|
							reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
 | 
						|
						}
 | 
						|
					} else {
 | 
						|
						reflectValue = reflect.Append(reflectValue, elem)
 | 
						|
					}
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			if !update {
 | 
						|
				db.Statement.ReflectValue.Set(reflectValue)
 | 
						|
			}
 | 
						|
		case reflect.Struct, reflect.Ptr:
 | 
						|
			if initialized || rows.Next() {
 | 
						|
				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
 | 
						|
			}
 | 
						|
		default:
 | 
						|
			db.AddError(rows.Scan(dest))
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if err := rows.Err(); err != nil && err != db.Error {
 | 
						|
		db.AddError(err)
 | 
						|
	}
 | 
						|
 | 
						|
	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
 | 
						|
		db.AddError(ErrRecordNotFound)
 | 
						|
	}
 | 
						|
}
 |