424 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			424 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package schema
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"go/ast"
 | 
						|
	"reflect"
 | 
						|
	"strings"
 | 
						|
	"sync"
 | 
						|
 | 
						|
	"gorm.io/gorm/clause"
 | 
						|
	"gorm.io/gorm/logger"
 | 
						|
)
 | 
						|
 | 
						|
// callbackType is the name of a hook method that schema parsing looks for on
// a model.
type callbackType string

// Hook method names, grouped by the operation they surround. Each value is
// also the name of the matching boolean flag on Schema.
const (
	// Create hooks.
	callbackTypeBeforeCreate callbackType = "BeforeCreate"
	callbackTypeAfterCreate  callbackType = "AfterCreate"

	// Update hooks.
	callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
	callbackTypeAfterUpdate  callbackType = "AfterUpdate"

	// Save hooks.
	callbackTypeBeforeSave callbackType = "BeforeSave"
	callbackTypeAfterSave  callbackType = "AfterSave"

	// Delete hooks.
	callbackTypeBeforeDelete callbackType = "BeforeDelete"
	callbackTypeAfterDelete  callbackType = "AfterDelete"

	// Query hook.
	callbackTypeAfterFind callbackType = "AfterFind"
)

// ErrUnsupportedDataType is the sentinel wrapped into the error returned when
// the value handed to the parser is nil or does not resolve to a struct type.
var ErrUnsupportedDataType = errors.New("unsupported data type")
 | 
						|
 | 
						|
type Schema struct {
 | 
						|
	Name                      string
 | 
						|
	ModelType                 reflect.Type
 | 
						|
	Table                     string
 | 
						|
	PrioritizedPrimaryField   *Field
 | 
						|
	DBNames                   []string
 | 
						|
	PrimaryFields             []*Field
 | 
						|
	PrimaryFieldDBNames       []string
 | 
						|
	Fields                    []*Field
 | 
						|
	FieldsByName              map[string]*Field
 | 
						|
	FieldsByBindName          map[string]*Field // embedded fields is 'Embed.Field'
 | 
						|
	FieldsByDBName            map[string]*Field
 | 
						|
	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
 | 
						|
	Relationships             Relationships
 | 
						|
	CreateClauses             []clause.Interface
 | 
						|
	QueryClauses              []clause.Interface
 | 
						|
	UpdateClauses             []clause.Interface
 | 
						|
	DeleteClauses             []clause.Interface
 | 
						|
	BeforeCreate, AfterCreate bool
 | 
						|
	BeforeUpdate, AfterUpdate bool
 | 
						|
	BeforeDelete, AfterDelete bool
 | 
						|
	BeforeSave, AfterSave     bool
 | 
						|
	AfterFind                 bool
 | 
						|
	err                       error
 | 
						|
	initialized               chan struct{}
 | 
						|
	namer                     Namer
 | 
						|
	cacheStore                *sync.Map
 | 
						|
}
 | 
						|
 | 
						|
func (schema Schema) String() string {
 | 
						|
	if schema.ModelType.Name() == "" {
 | 
						|
		return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
 | 
						|
	}
 | 
						|
	return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
 | 
						|
}
 | 
						|
 | 
						|
func (schema Schema) MakeSlice() reflect.Value {
 | 
						|
	slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
 | 
						|
	results := reflect.New(slice.Type())
 | 
						|
	results.Elem().Set(slice)
 | 
						|
	return results
 | 
						|
}
 | 
						|
 | 
						|
func (schema Schema) LookUpField(name string) *Field {
 | 
						|
	if field, ok := schema.FieldsByDBName[name]; ok {
 | 
						|
		return field
 | 
						|
	}
 | 
						|
	if field, ok := schema.FieldsByName[name]; ok {
 | 
						|
		return field
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// LookUpFieldByBindName looks for the closest field in the embedded struct.
 | 
						|
//
 | 
						|
//	type Struct struct {
 | 
						|
//		Embedded struct {
 | 
						|
//			ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
 | 
						|
//		}
 | 
						|
//		ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
 | 
						|
//	}
 | 
						|
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
 | 
						|
	if len(bindNames) == 0 {
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	for i := len(bindNames) - 1; i >= 0; i-- {
 | 
						|
		find := strings.Join(bindNames[:i], ".") + "." + name
 | 
						|
		if field, ok := schema.FieldsByBindName[find]; ok {
 | 
						|
			return field
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Tabler is implemented by models that name their own table. When a model
// satisfies it, schema parsing uses TableName's result in place of the name
// produced by the Namer.
type Tabler interface {
	// TableName returns the table name to use for the model.
	TableName() string
}
 | 
						|
 | 
						|
type TablerWithNamer interface {
 | 
						|
	TableName(Namer) string
 | 
						|
}
 | 
						|
 | 
						|
// Parse get data type from dialector
 | 
						|
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 | 
						|
	return ParseWithSpecialTableName(dest, cacheStore, namer, "")
 | 
						|
}
 | 
						|
 | 
						|
// ParseWithSpecialTableName get data type from dialector with extra schema table
 | 
						|
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 | 
						|
	if dest == nil {
 | 
						|
		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | 
						|
	}
 | 
						|
 | 
						|
	value := reflect.ValueOf(dest)
 | 
						|
	if value.Kind() == reflect.Ptr && value.IsNil() {
 | 
						|
		value = reflect.New(value.Type().Elem())
 | 
						|
	}
 | 
						|
	modelType := reflect.Indirect(value).Type()
 | 
						|
 | 
						|
	if modelType.Kind() == reflect.Interface {
 | 
						|
		modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
 | 
						|
	}
 | 
						|
 | 
						|
	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
 | 
						|
		modelType = modelType.Elem()
 | 
						|
	}
 | 
						|
 | 
						|
	if modelType.Kind() != reflect.Struct {
 | 
						|
		if modelType.PkgPath() == "" {
 | 
						|
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | 
						|
		}
 | 
						|
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | 
						|
	}
 | 
						|
 | 
						|
	// Cache the Schema for performance,
 | 
						|
	// Use the modelType or modelType + schemaTable (if it present) as cache key.
 | 
						|
	var schemaCacheKey interface{}
 | 
						|
	if specialTableName != "" {
 | 
						|
		schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
 | 
						|
	} else {
 | 
						|
		schemaCacheKey = modelType
 | 
						|
	}
 | 
						|
 | 
						|
	// Load exist schema cache, return if exists
 | 
						|
	if v, ok := cacheStore.Load(schemaCacheKey); ok {
 | 
						|
		s := v.(*Schema)
 | 
						|
		// Wait for the initialization of other goroutines to complete
 | 
						|
		<-s.initialized
 | 
						|
		return s, s.err
 | 
						|
	}
 | 
						|
 | 
						|
	modelValue := reflect.New(modelType)
 | 
						|
	tableName := namer.TableName(modelType.Name())
 | 
						|
	if tabler, ok := modelValue.Interface().(Tabler); ok {
 | 
						|
		tableName = tabler.TableName()
 | 
						|
	}
 | 
						|
	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
 | 
						|
		tableName = tabler.TableName(namer)
 | 
						|
	}
 | 
						|
	if en, ok := namer.(embeddedNamer); ok {
 | 
						|
		tableName = en.Table
 | 
						|
	}
 | 
						|
	if specialTableName != "" && specialTableName != tableName {
 | 
						|
		tableName = specialTableName
 | 
						|
	}
 | 
						|
 | 
						|
	schema := &Schema{
 | 
						|
		Name:             modelType.Name(),
 | 
						|
		ModelType:        modelType,
 | 
						|
		Table:            tableName,
 | 
						|
		FieldsByName:     map[string]*Field{},
 | 
						|
		FieldsByBindName: map[string]*Field{},
 | 
						|
		FieldsByDBName:   map[string]*Field{},
 | 
						|
		Relationships:    Relationships{Relations: map[string]*Relationship{}},
 | 
						|
		cacheStore:       cacheStore,
 | 
						|
		namer:            namer,
 | 
						|
		initialized:      make(chan struct{}),
 | 
						|
	}
 | 
						|
	// When the schema initialization is completed, the channel will be closed
 | 
						|
	defer close(schema.initialized)
 | 
						|
 | 
						|
	// Load exist schema cache, return if exists
 | 
						|
	if v, ok := cacheStore.Load(schemaCacheKey); ok {
 | 
						|
		s := v.(*Schema)
 | 
						|
		// Wait for the initialization of other goroutines to complete
 | 
						|
		<-s.initialized
 | 
						|
		return s, s.err
 | 
						|
	}
 | 
						|
 | 
						|
	for i := 0; i < modelType.NumField(); i++ {
 | 
						|
		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
 | 
						|
			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
 | 
						|
				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
 | 
						|
			} else {
 | 
						|
				schema.Fields = append(schema.Fields, field)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	for _, field := range schema.Fields {
 | 
						|
		if field.DBName == "" && field.DataType != "" {
 | 
						|
			field.DBName = namer.ColumnName(schema.Table, field.Name)
 | 
						|
		}
 | 
						|
 | 
						|
		bindName := field.BindName()
 | 
						|
		if field.DBName != "" {
 | 
						|
			// nonexistence or shortest path or first appear prioritized if has permission
 | 
						|
			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
 | 
						|
				if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
 | 
						|
					schema.DBNames = append(schema.DBNames, field.DBName)
 | 
						|
				}
 | 
						|
				schema.FieldsByDBName[field.DBName] = field
 | 
						|
				schema.FieldsByName[field.Name] = field
 | 
						|
				schema.FieldsByBindName[bindName] = field
 | 
						|
 | 
						|
				if v != nil && v.PrimaryKey {
 | 
						|
					for idx, f := range schema.PrimaryFields {
 | 
						|
						if f == v {
 | 
						|
							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
 | 
						|
						}
 | 
						|
					}
 | 
						|
				}
 | 
						|
 | 
						|
				if field.PrimaryKey {
 | 
						|
					schema.PrimaryFields = append(schema.PrimaryFields, field)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
 | 
						|
		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
 | 
						|
			schema.FieldsByName[field.Name] = field
 | 
						|
		}
 | 
						|
		if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
 | 
						|
			schema.FieldsByBindName[bindName] = field
 | 
						|
		}
 | 
						|
 | 
						|
		field.setupValuerAndSetter()
 | 
						|
	}
 | 
						|
 | 
						|
	prioritizedPrimaryField := schema.LookUpField("id")
 | 
						|
	if prioritizedPrimaryField == nil {
 | 
						|
		prioritizedPrimaryField = schema.LookUpField("ID")
 | 
						|
	}
 | 
						|
 | 
						|
	if prioritizedPrimaryField != nil {
 | 
						|
		if prioritizedPrimaryField.PrimaryKey {
 | 
						|
			schema.PrioritizedPrimaryField = prioritizedPrimaryField
 | 
						|
		} else if len(schema.PrimaryFields) == 0 {
 | 
						|
			prioritizedPrimaryField.PrimaryKey = true
 | 
						|
			schema.PrioritizedPrimaryField = prioritizedPrimaryField
 | 
						|
			schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if schema.PrioritizedPrimaryField == nil {
 | 
						|
		if len(schema.PrimaryFields) == 1 {
 | 
						|
			schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
 | 
						|
		} else if len(schema.PrimaryFields) > 1 {
 | 
						|
			// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
 | 
						|
			for _, field := range schema.PrimaryFields {
 | 
						|
				if field.AutoIncrement {
 | 
						|
					schema.PrioritizedPrimaryField = field
 | 
						|
					break
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	for _, field := range schema.PrimaryFields {
 | 
						|
		schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
 | 
						|
	}
 | 
						|
 | 
						|
	for _, field := range schema.Fields {
 | 
						|
		if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
 | 
						|
			schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if field := schema.PrioritizedPrimaryField; field != nil {
 | 
						|
		switch field.GORMDataType {
 | 
						|
		case Int, Uint:
 | 
						|
			if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
 | 
						|
				if !field.HasDefaultValue || field.DefaultValueInterface != nil {
 | 
						|
					schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
 | 
						|
				}
 | 
						|
 | 
						|
				field.HasDefaultValue = true
 | 
						|
				field.AutoIncrement = true
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	callbackTypes := []callbackType{
 | 
						|
		callbackTypeBeforeCreate, callbackTypeAfterCreate,
 | 
						|
		callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
 | 
						|
		callbackTypeBeforeSave, callbackTypeAfterSave,
 | 
						|
		callbackTypeBeforeDelete, callbackTypeAfterDelete,
 | 
						|
		callbackTypeAfterFind,
 | 
						|
	}
 | 
						|
	for _, cbName := range callbackTypes {
 | 
						|
		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
 | 
						|
			switch methodValue.Type().String() {
 | 
						|
			case "func(*gorm.DB) error": // TODO hack
 | 
						|
				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
 | 
						|
			default:
 | 
						|
				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// Cache the schema
 | 
						|
	if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
 | 
						|
		s := v.(*Schema)
 | 
						|
		// Wait for the initialization of other goroutines to complete
 | 
						|
		<-s.initialized
 | 
						|
		return s, s.err
 | 
						|
	}
 | 
						|
 | 
						|
	defer func() {
 | 
						|
		if schema.err != nil {
 | 
						|
			logger.Default.Error(context.Background(), schema.err.Error())
 | 
						|
			cacheStore.Delete(modelType)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
 | 
						|
		for _, field := range schema.Fields {
 | 
						|
			if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
 | 
						|
				if schema.parseRelation(field); schema.err != nil {
 | 
						|
					return schema, schema.err
 | 
						|
				} else {
 | 
						|
					schema.FieldsByName[field.Name] = field
 | 
						|
					schema.FieldsByBindName[field.BindName()] = field
 | 
						|
				}
 | 
						|
			}
 | 
						|
 | 
						|
			fieldValue := reflect.New(field.IndirectFieldType)
 | 
						|
			fieldInterface := fieldValue.Interface()
 | 
						|
			if fc, ok := fieldInterface.(CreateClausesInterface); ok {
 | 
						|
				field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
 | 
						|
			}
 | 
						|
 | 
						|
			if fc, ok := fieldInterface.(QueryClausesInterface); ok {
 | 
						|
				field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
 | 
						|
			}
 | 
						|
 | 
						|
			if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
 | 
						|
				field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
 | 
						|
			}
 | 
						|
 | 
						|
			if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
 | 
						|
				field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	return schema, schema.err
 | 
						|
}
 | 
						|
 | 
						|
// This unrolling is needed to show to the compiler the exact set of methods
 | 
						|
// that can be used on the modelType.
 | 
						|
// Prior to go1.22 any use of MethodByName would cause the linker to
 | 
						|
// abandon dead code elimination for the entire binary.
 | 
						|
// As of go1.22 the compiler supports one special case of a string constant
 | 
						|
// being passed to MethodByName. For enterprise customers or those building
 | 
						|
// large binaries, this gives a significant reduction in binary size.
 | 
						|
// https://github.com/golang/go/issues/62257
 | 
						|
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
 | 
						|
	switch cbType {
 | 
						|
	case callbackTypeBeforeCreate:
 | 
						|
		return modelType.MethodByName(string(callbackTypeBeforeCreate))
 | 
						|
	case callbackTypeAfterCreate:
 | 
						|
		return modelType.MethodByName(string(callbackTypeAfterCreate))
 | 
						|
	case callbackTypeBeforeUpdate:
 | 
						|
		return modelType.MethodByName(string(callbackTypeBeforeUpdate))
 | 
						|
	case callbackTypeAfterUpdate:
 | 
						|
		return modelType.MethodByName(string(callbackTypeAfterUpdate))
 | 
						|
	case callbackTypeBeforeSave:
 | 
						|
		return modelType.MethodByName(string(callbackTypeBeforeSave))
 | 
						|
	case callbackTypeAfterSave:
 | 
						|
		return modelType.MethodByName(string(callbackTypeAfterSave))
 | 
						|
	case callbackTypeBeforeDelete:
 | 
						|
		return modelType.MethodByName(string(callbackTypeBeforeDelete))
 | 
						|
	case callbackTypeAfterDelete:
 | 
						|
		return modelType.MethodByName(string(callbackTypeAfterDelete))
 | 
						|
	case callbackTypeAfterFind:
 | 
						|
		return modelType.MethodByName(string(callbackTypeAfterFind))
 | 
						|
	default:
 | 
						|
		return reflect.ValueOf(nil)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 | 
						|
	modelType := reflect.ValueOf(dest).Type()
 | 
						|
	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
 | 
						|
		modelType = modelType.Elem()
 | 
						|
	}
 | 
						|
 | 
						|
	if modelType.Kind() != reflect.Struct {
 | 
						|
		if modelType.PkgPath() == "" {
 | 
						|
			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | 
						|
		}
 | 
						|
		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | 
						|
	}
 | 
						|
 | 
						|
	if v, ok := cacheStore.Load(modelType); ok {
 | 
						|
		return v.(*Schema), nil
 | 
						|
	}
 | 
						|
 | 
						|
	return Parse(dest, cacheStore, namer)
 | 
						|
}
 |