424 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			424 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package schema
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"go/ast"
 | |
| 	"reflect"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 
 | |
| 	"gorm.io/gorm/clause"
 | |
| 	"gorm.io/gorm/logger"
 | |
| )
 | |
| 
 | |
// callbackType is the name of a model hook method, e.g. "BeforeCreate".
type callbackType string

// Names of the hook methods that are probed on a model while its schema is
// parsed. Each value is also the name of the boolean Schema field that records
// whether the model implements that hook.
const (
	callbackTypeBeforeCreate = callbackType("BeforeCreate")
	callbackTypeBeforeUpdate = callbackType("BeforeUpdate")
	callbackTypeAfterCreate  = callbackType("AfterCreate")
	callbackTypeAfterUpdate  = callbackType("AfterUpdate")
	callbackTypeBeforeSave   = callbackType("BeforeSave")
	callbackTypeAfterSave    = callbackType("AfterSave")
	callbackTypeBeforeDelete = callbackType("BeforeDelete")
	callbackTypeAfterDelete  = callbackType("AfterDelete")
	callbackTypeAfterFind    = callbackType("AfterFind")
)

// ErrUnsupportedDataType is the sentinel wrapped into the error returned when
// a value cannot be parsed into a schema because it is nil or does not resolve
// to a struct type. Match it with errors.Is.
var ErrUnsupportedDataType = errors.New("unsupported data type")
 | |
| 
 | |
| type Schema struct {
 | |
| 	Name                      string
 | |
| 	ModelType                 reflect.Type
 | |
| 	Table                     string
 | |
| 	PrioritizedPrimaryField   *Field
 | |
| 	DBNames                   []string
 | |
| 	PrimaryFields             []*Field
 | |
| 	PrimaryFieldDBNames       []string
 | |
| 	Fields                    []*Field
 | |
| 	FieldsByName              map[string]*Field
 | |
| 	FieldsByBindName          map[string]*Field // embedded fields is 'Embed.Field'
 | |
| 	FieldsByDBName            map[string]*Field
 | |
| 	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
 | |
| 	Relationships             Relationships
 | |
| 	CreateClauses             []clause.Interface
 | |
| 	QueryClauses              []clause.Interface
 | |
| 	UpdateClauses             []clause.Interface
 | |
| 	DeleteClauses             []clause.Interface
 | |
| 	BeforeCreate, AfterCreate bool
 | |
| 	BeforeUpdate, AfterUpdate bool
 | |
| 	BeforeDelete, AfterDelete bool
 | |
| 	BeforeSave, AfterSave     bool
 | |
| 	AfterFind                 bool
 | |
| 	err                       error
 | |
| 	initialized               chan struct{}
 | |
| 	namer                     Namer
 | |
| 	cacheStore                *sync.Map
 | |
| }
 | |
| 
 | |
| func (schema Schema) String() string {
 | |
| 	if schema.ModelType.Name() == "" {
 | |
| 		return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
 | |
| 	}
 | |
| 	return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
 | |
| }
 | |
| 
 | |
| func (schema Schema) MakeSlice() reflect.Value {
 | |
| 	slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
 | |
| 	results := reflect.New(slice.Type())
 | |
| 	results.Elem().Set(slice)
 | |
| 	return results
 | |
| }
 | |
| 
 | |
| func (schema Schema) LookUpField(name string) *Field {
 | |
| 	if field, ok := schema.FieldsByDBName[name]; ok {
 | |
| 		return field
 | |
| 	}
 | |
| 	if field, ok := schema.FieldsByName[name]; ok {
 | |
| 		return field
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // LookUpFieldByBindName looks for the closest field in the embedded struct.
 | |
| //
 | |
| //	type Struct struct {
 | |
| //		Embedded struct {
 | |
| //			ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
 | |
| //		}
 | |
| //		ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
 | |
| //	}
 | |
| func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
 | |
| 	if len(bindNames) == 0 {
 | |
| 		return nil
 | |
| 	}
 | |
| 	for i := len(bindNames) - 1; i >= 0; i-- {
 | |
| 		find := strings.Join(bindNames[:i], ".") + "." + name
 | |
| 		if field, ok := schema.FieldsByBindName[find]; ok {
 | |
| 			return field
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| type Tabler interface {
 | |
| 	TableName() string
 | |
| }
 | |
| 
 | |
| type TablerWithNamer interface {
 | |
| 	TableName(Namer) string
 | |
| }
 | |
| 
 | |
| // Parse get data type from dialector
 | |
| func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 | |
| 	return ParseWithSpecialTableName(dest, cacheStore, namer, "")
 | |
| }
 | |
| 
 | |
| // ParseWithSpecialTableName get data type from dialector with extra schema table
 | |
| func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
 | |
| 	if dest == nil {
 | |
| 		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | |
| 	}
 | |
| 
 | |
| 	value := reflect.ValueOf(dest)
 | |
| 	if value.Kind() == reflect.Ptr && value.IsNil() {
 | |
| 		value = reflect.New(value.Type().Elem())
 | |
| 	}
 | |
| 	modelType := reflect.Indirect(value).Type()
 | |
| 
 | |
| 	if modelType.Kind() == reflect.Interface {
 | |
| 		modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
 | |
| 	}
 | |
| 
 | |
| 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
 | |
| 		modelType = modelType.Elem()
 | |
| 	}
 | |
| 
 | |
| 	if modelType.Kind() != reflect.Struct {
 | |
| 		if modelType.PkgPath() == "" {
 | |
| 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | |
| 		}
 | |
| 		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | |
| 	}
 | |
| 
 | |
| 	// Cache the Schema for performance,
 | |
| 	// Use the modelType or modelType + schemaTable (if it present) as cache key.
 | |
| 	var schemaCacheKey interface{}
 | |
| 	if specialTableName != "" {
 | |
| 		schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
 | |
| 	} else {
 | |
| 		schemaCacheKey = modelType
 | |
| 	}
 | |
| 
 | |
| 	// Load exist schema cache, return if exists
 | |
| 	if v, ok := cacheStore.Load(schemaCacheKey); ok {
 | |
| 		s := v.(*Schema)
 | |
| 		// Wait for the initialization of other goroutines to complete
 | |
| 		<-s.initialized
 | |
| 		return s, s.err
 | |
| 	}
 | |
| 
 | |
| 	modelValue := reflect.New(modelType)
 | |
| 	tableName := namer.TableName(modelType.Name())
 | |
| 	if tabler, ok := modelValue.Interface().(Tabler); ok {
 | |
| 		tableName = tabler.TableName()
 | |
| 	}
 | |
| 	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
 | |
| 		tableName = tabler.TableName(namer)
 | |
| 	}
 | |
| 	if en, ok := namer.(embeddedNamer); ok {
 | |
| 		tableName = en.Table
 | |
| 	}
 | |
| 	if specialTableName != "" && specialTableName != tableName {
 | |
| 		tableName = specialTableName
 | |
| 	}
 | |
| 
 | |
| 	schema := &Schema{
 | |
| 		Name:             modelType.Name(),
 | |
| 		ModelType:        modelType,
 | |
| 		Table:            tableName,
 | |
| 		FieldsByName:     map[string]*Field{},
 | |
| 		FieldsByBindName: map[string]*Field{},
 | |
| 		FieldsByDBName:   map[string]*Field{},
 | |
| 		Relationships:    Relationships{Relations: map[string]*Relationship{}},
 | |
| 		cacheStore:       cacheStore,
 | |
| 		namer:            namer,
 | |
| 		initialized:      make(chan struct{}),
 | |
| 	}
 | |
| 	// When the schema initialization is completed, the channel will be closed
 | |
| 	defer close(schema.initialized)
 | |
| 
 | |
| 	// Load exist schema cache, return if exists
 | |
| 	if v, ok := cacheStore.Load(schemaCacheKey); ok {
 | |
| 		s := v.(*Schema)
 | |
| 		// Wait for the initialization of other goroutines to complete
 | |
| 		<-s.initialized
 | |
| 		return s, s.err
 | |
| 	}
 | |
| 
 | |
| 	for i := 0; i < modelType.NumField(); i++ {
 | |
| 		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
 | |
| 			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
 | |
| 				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
 | |
| 			} else {
 | |
| 				schema.Fields = append(schema.Fields, field)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	for _, field := range schema.Fields {
 | |
| 		if field.DBName == "" && field.DataType != "" {
 | |
| 			field.DBName = namer.ColumnName(schema.Table, field.Name)
 | |
| 		}
 | |
| 
 | |
| 		bindName := field.BindName()
 | |
| 		if field.DBName != "" {
 | |
| 			// nonexistence or shortest path or first appear prioritized if has permission
 | |
| 			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
 | |
| 				if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
 | |
| 					schema.DBNames = append(schema.DBNames, field.DBName)
 | |
| 				}
 | |
| 				schema.FieldsByDBName[field.DBName] = field
 | |
| 				schema.FieldsByName[field.Name] = field
 | |
| 				schema.FieldsByBindName[bindName] = field
 | |
| 
 | |
| 				if v != nil && v.PrimaryKey {
 | |
| 					for idx, f := range schema.PrimaryFields {
 | |
| 						if f == v {
 | |
| 							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 
 | |
| 				if field.PrimaryKey {
 | |
| 					schema.PrimaryFields = append(schema.PrimaryFields, field)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
 | |
| 			schema.FieldsByName[field.Name] = field
 | |
| 		}
 | |
| 		if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
 | |
| 			schema.FieldsByBindName[bindName] = field
 | |
| 		}
 | |
| 
 | |
| 		field.setupValuerAndSetter()
 | |
| 	}
 | |
| 
 | |
| 	prioritizedPrimaryField := schema.LookUpField("id")
 | |
| 	if prioritizedPrimaryField == nil {
 | |
| 		prioritizedPrimaryField = schema.LookUpField("ID")
 | |
| 	}
 | |
| 
 | |
| 	if prioritizedPrimaryField != nil {
 | |
| 		if prioritizedPrimaryField.PrimaryKey {
 | |
| 			schema.PrioritizedPrimaryField = prioritizedPrimaryField
 | |
| 		} else if len(schema.PrimaryFields) == 0 {
 | |
| 			prioritizedPrimaryField.PrimaryKey = true
 | |
| 			schema.PrioritizedPrimaryField = prioritizedPrimaryField
 | |
| 			schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if schema.PrioritizedPrimaryField == nil {
 | |
| 		if len(schema.PrimaryFields) == 1 {
 | |
| 			schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
 | |
| 		} else if len(schema.PrimaryFields) > 1 {
 | |
| 			// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
 | |
| 			for _, field := range schema.PrimaryFields {
 | |
| 				if field.AutoIncrement {
 | |
| 					schema.PrioritizedPrimaryField = field
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	for _, field := range schema.PrimaryFields {
 | |
| 		schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
 | |
| 	}
 | |
| 
 | |
| 	for _, field := range schema.Fields {
 | |
| 		if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
 | |
| 			schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if field := schema.PrioritizedPrimaryField; field != nil {
 | |
| 		switch field.GORMDataType {
 | |
| 		case Int, Uint:
 | |
| 			if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
 | |
| 				if !field.HasDefaultValue || field.DefaultValueInterface != nil {
 | |
| 					schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
 | |
| 				}
 | |
| 
 | |
| 				field.HasDefaultValue = true
 | |
| 				field.AutoIncrement = true
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	callbackTypes := []callbackType{
 | |
| 		callbackTypeBeforeCreate, callbackTypeAfterCreate,
 | |
| 		callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
 | |
| 		callbackTypeBeforeSave, callbackTypeAfterSave,
 | |
| 		callbackTypeBeforeDelete, callbackTypeAfterDelete,
 | |
| 		callbackTypeAfterFind,
 | |
| 	}
 | |
| 	for _, cbName := range callbackTypes {
 | |
| 		if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
 | |
| 			switch methodValue.Type().String() {
 | |
| 			case "func(*gorm.DB) error": // TODO hack
 | |
| 				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
 | |
| 			default:
 | |
| 				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Cache the schema
 | |
| 	if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
 | |
| 		s := v.(*Schema)
 | |
| 		// Wait for the initialization of other goroutines to complete
 | |
| 		<-s.initialized
 | |
| 		return s, s.err
 | |
| 	}
 | |
| 
 | |
| 	defer func() {
 | |
| 		if schema.err != nil {
 | |
| 			logger.Default.Error(context.Background(), schema.err.Error())
 | |
| 			cacheStore.Delete(modelType)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
 | |
| 		for _, field := range schema.Fields {
 | |
| 			if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
 | |
| 				if schema.parseRelation(field); schema.err != nil {
 | |
| 					return schema, schema.err
 | |
| 				} else {
 | |
| 					schema.FieldsByName[field.Name] = field
 | |
| 					schema.FieldsByBindName[field.BindName()] = field
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			fieldValue := reflect.New(field.IndirectFieldType)
 | |
| 			fieldInterface := fieldValue.Interface()
 | |
| 			if fc, ok := fieldInterface.(CreateClausesInterface); ok {
 | |
| 				field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
 | |
| 			}
 | |
| 
 | |
| 			if fc, ok := fieldInterface.(QueryClausesInterface); ok {
 | |
| 				field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
 | |
| 			}
 | |
| 
 | |
| 			if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
 | |
| 				field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
 | |
| 			}
 | |
| 
 | |
| 			if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
 | |
| 				field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return schema, schema.err
 | |
| }
 | |
| 
 | |
| // This unrolling is needed to show to the compiler the exact set of methods
 | |
| // that can be used on the modelType.
 | |
| // Prior to go1.22 any use of MethodByName would cause the linker to
 | |
| // abandon dead code elimination for the entire binary.
 | |
| // As of go1.22 the compiler supports one special case of a string constant
 | |
| // being passed to MethodByName. For enterprise customers or those building
 | |
| // large binaries, this gives a significant reduction in binary size.
 | |
| // https://github.com/golang/go/issues/62257
 | |
| func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
 | |
| 	switch cbType {
 | |
| 	case callbackTypeBeforeCreate:
 | |
| 		return modelType.MethodByName(string(callbackTypeBeforeCreate))
 | |
| 	case callbackTypeAfterCreate:
 | |
| 		return modelType.MethodByName(string(callbackTypeAfterCreate))
 | |
| 	case callbackTypeBeforeUpdate:
 | |
| 		return modelType.MethodByName(string(callbackTypeBeforeUpdate))
 | |
| 	case callbackTypeAfterUpdate:
 | |
| 		return modelType.MethodByName(string(callbackTypeAfterUpdate))
 | |
| 	case callbackTypeBeforeSave:
 | |
| 		return modelType.MethodByName(string(callbackTypeBeforeSave))
 | |
| 	case callbackTypeAfterSave:
 | |
| 		return modelType.MethodByName(string(callbackTypeAfterSave))
 | |
| 	case callbackTypeBeforeDelete:
 | |
| 		return modelType.MethodByName(string(callbackTypeBeforeDelete))
 | |
| 	case callbackTypeAfterDelete:
 | |
| 		return modelType.MethodByName(string(callbackTypeAfterDelete))
 | |
| 	case callbackTypeAfterFind:
 | |
| 		return modelType.MethodByName(string(callbackTypeAfterFind))
 | |
| 	default:
 | |
| 		return reflect.ValueOf(nil)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
 | |
| 	modelType := reflect.ValueOf(dest).Type()
 | |
| 	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
 | |
| 		modelType = modelType.Elem()
 | |
| 	}
 | |
| 
 | |
| 	if modelType.Kind() != reflect.Struct {
 | |
| 		if modelType.PkgPath() == "" {
 | |
| 			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
 | |
| 		}
 | |
| 		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
 | |
| 	}
 | |
| 
 | |
| 	if v, ok := cacheStore.Load(modelType); ok {
 | |
| 		return v.(*Schema), nil
 | |
| 	}
 | |
| 
 | |
| 	return Parse(dest, cacheStore, namer)
 | |
| }
 |