507 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			507 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package gorm
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"database/sql"
 | 
						|
	"fmt"
 | 
						|
	"reflect"
 | 
						|
	"sort"
 | 
						|
	"sync"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"gorm.io/gorm/clause"
 | 
						|
	"gorm.io/gorm/logger"
 | 
						|
	"gorm.io/gorm/schema"
 | 
						|
)
 | 
						|
 | 
						|
const (
	// preparedStmtDBKey is the key under which the shared *PreparedStmtDB is
	// kept in Config.cacheStore.
	preparedStmtDBKey = "preparedStmt"
)
 | 
						|
 | 
						|
// Config holds the settings used to open a GORM DB. A *Config is itself an
// Option (it has Apply and AfterInitialize methods), so it can be passed
// directly to Open, which applies *Config options before any other option.
type Config struct {
	// SkipDefaultTransaction disables the transaction GORM wraps around single
	// create, update and delete operations by default to protect data
	// integrity.
	SkipDefaultTransaction bool
	// NamingStrategy controls how table and column names are derived.
	// Open defaults it to schema.NamingStrategy{IdentifierMaxLength: 64}.
	NamingStrategy schema.Namer
	// FullSaveAssociations saves associations in full.
	FullSaveAssociations bool
	// Logger receives GORM's log output. Open defaults it to logger.Default.
	Logger logger.Interface
	// NowFunc is the function used when creating a new timestamp.
	// Open defaults it to a function returning time.Now().Local().
	NowFunc func() time.Time
	// DryRun generates SQL without executing it.
	DryRun bool
	// PrepareStmt executes queries through cached prepared statements. When it
	// is set, Open wraps ConnPool in a PreparedStmtDB.
	PrepareStmt bool
	// DisableAutomaticPing skips the Ping that Open otherwise performs on
	// ConnPool after the dialector has been initialized.
	DisableAutomaticPing bool
	// DisableForeignKeyConstraintWhenMigrating — presumably stops the migrator
	// from creating foreign key constraints (not read in this file; confirm in
	// the migrator).
	DisableForeignKeyConstraintWhenMigrating bool
	// IgnoreRelationshipsWhenMigrating — presumably makes the migrator ignore
	// relationships (not read in this file; confirm in the migrator).
	IgnoreRelationshipsWhenMigrating bool
	// DisableNestedTransaction disables nested transactions.
	DisableNestedTransaction bool
	// AllowGlobalUpdate allows global updates (by name: updates without
	// conditions — NOTE(review): confirm against the update/delete callbacks).
	AllowGlobalUpdate bool
	// QueryFields executes SQL queries listing all fields of the table.
	QueryFields bool
	// CreateBatchSize is the default batch size for create operations.
	CreateBatchSize int
	// TranslateError makes DB.AddError pass errors through the Dialector's
	// Translate method when the Dialector implements ErrorTranslator.
	TranslateError bool

	// ClauseBuilders holds custom clause builders (presumably keyed by clause
	// name). Open defaults it to an empty map.
	ClauseBuilders map[string]clause.ClauseBuilder
	// ConnPool is the connection pool statements run against. Open reads it
	// after Dialector.Initialize, which presumably populates it.
	ConnPool ConnPool
	// Dialector is the database dialector. Open sets it from its dialector
	// argument when that argument is non-nil.
	Dialector
	// Plugins maps Plugin.Name() to registered plugins. DB.Use adds entries and
	// Config.AfterInitialize initializes them. Open defaults it to an empty map.
	Plugins map[string]Plugin

	// callbacks is the callback manager Open builds with initializeCallbacks;
	// DB.Callback returns it.
	callbacks  *callbacks
	// cacheStore is a shared cache created by Open when nil; the prepared
	// statement DB is stored in it under preparedStmtDBKey.
	cacheStore *sync.Map
}
 | 
						|
 | 
						|
// Apply update config to new config
 | 
						|
func (c *Config) Apply(config *Config) error {
 | 
						|
	if config != c {
 | 
						|
		*config = *c
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// AfterInitialize initialize plugins after db connected
 | 
						|
func (c *Config) AfterInitialize(db *DB) error {
 | 
						|
	if db != nil {
 | 
						|
		for _, plugin := range c.Plugins {
 | 
						|
			if err := plugin.Initialize(db); err != nil {
 | 
						|
				return err
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Option gorm option interface
 | 
						|
type Option interface {
 | 
						|
	Apply(*Config) error
 | 
						|
	AfterInitialize(*DB) error
 | 
						|
}
 | 
						|
 | 
						|
// DB is the GORM database handle. Its methods build up Statement and report
// failures through Error.
type DB struct {
	// Config is embedded, so its fields and methods are promoted onto DB.
	*Config
	// Error accumulates errors reported through AddError.
	Error        error
	// RowsAffected — presumably set by the executing callbacks (not written in
	// this file).
	RowsAffected int64
	// Statement is the statement currently being built.
	Statement    *Statement
	// clone controls getInstance: 0 reuses this DB as-is, 1 derives a new DB
	// with a fresh Statement, and 2 or more derives a new DB with a clone of
	// the current Statement. Open sets 1; Session sets 2 unless NewDB is true.
	clone        int
}
 | 
						|
 | 
						|
// Session configures the *DB returned by DB.Session. Boolean fields can only
// switch a feature on for the new session; leaving one false keeps the
// setting inherited from the parent DB.
type Session struct {
	// DryRun generates SQL without executing it.
	DryRun                   bool
	// PrepareStmt runs the session's statements through the prepared
	// statement cache shared via the DB's cache store.
	PrepareStmt              bool
	// NewDB makes later chains on the session start from a fresh Statement
	// instead of carrying over the parent's statement state.
	NewDB                    bool
	// Initialized returns an already-instantiated DB (via getInstance).
	Initialized              bool
	// SkipHooks sets SkipHooks on the session's (cloned) Statement.
	SkipHooks                bool
	// SkipDefaultTransaction enables Config.SkipDefaultTransaction.
	SkipDefaultTransaction   bool
	// DisableNestedTransaction enables Config.DisableNestedTransaction.
	DisableNestedTransaction bool
	// AllowGlobalUpdate enables Config.AllowGlobalUpdate.
	AllowGlobalUpdate        bool
	// FullSaveAssociations enables Config.FullSaveAssociations.
	FullSaveAssociations     bool
	// QueryFields enables Config.QueryFields.
	QueryFields              bool
	// Context, when non-nil, replaces the context of the session's Statement.
	Context                  context.Context
	// Logger, when non-nil, replaces the session's logger.
	Logger                   logger.Interface
	// NowFunc, when non-nil, replaces the session's timestamp function.
	NowFunc                  func() time.Time
	// CreateBatchSize, when positive, overrides Config.CreateBatchSize.
	CreateBatchSize          int
}
 | 
						|
 | 
						|
// Open initializes a *DB for dialector, configured by opts.
//
// Options are applied in a fixed order. Every *Config option comes first,
// then the remaining options in the order they were passed; the caller's
// slice itself is not reordered. If the dialector also implements
// Apply(*Config) error, it is applied after all options. Unset fields then
// receive defaults:
//   - NamingStrategy: IdentifierMaxLength 64
//   - Logger: logger.Default
//   - NowFunc: local time
//   - Plugins, ClauseBuilders and the internal cache store: empty values
//
// Each option's AfterInitialize hook is deferred, so the hooks run after
// Open's body has finished, in reverse order of registration. A hook runs
// only while Open is still succeeding. Once any step has failed, the
// remaining hooks are skipped. The steps are: an Apply, dialector
// initialization, the automatic ping, or an earlier hook. This way hooks
// never see a nil or half-initialized db and never replace the error being
// returned.
//
// If the dialector fails to initialize, the underlying *sql.DB is closed when
// one can be obtained. Unless DisableAutomaticPing is set, the connection pool
// is pinged when it implements Ping. Initialization and ping errors are logged
// through config.Logger and returned together with db.
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
	config := &Config{}

	// Sort a private copy so the caller's variadic slice is left untouched, and
	// sort stably so non-Config options keep their relative order.
	opts = append([]Option(nil), opts...)
	sort.SliceStable(opts, func(i, j int) bool {
		_, isConfig := opts[i].(*Config)
		_, isConfig2 := opts[j].(*Config)
		return isConfig && !isConfig2
	})

	for _, opt := range opts {
		if opt == nil {
			continue
		}
		if applyErr := opt.Apply(config); applyErr != nil {
			return nil, applyErr
		}
		defer func(opt Option) {
			// err already holds the result being returned. Do not run the
			// hook against a failed Open, and do not overwrite that error.
			if err != nil {
				return
			}
			if hookErr := opt.AfterInitialize(db); hookErr != nil {
				err = hookErr
			}
		}(opt)
	}

	if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
		if err = d.Apply(config); err != nil {
			return nil, err
		}
	}

	if config.NamingStrategy == nil {
		config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
	}

	if config.Logger == nil {
		config.Logger = logger.Default
	}

	if config.NowFunc == nil {
		config.NowFunc = func() time.Time { return time.Now().Local() }
	}

	if dialector != nil {
		config.Dialector = dialector
	}

	if config.Plugins == nil {
		config.Plugins = map[string]Plugin{}
	}

	if config.cacheStore == nil {
		config.cacheStore = &sync.Map{}
	}

	db = &DB{Config: config, clone: 1}

	db.callbacks = initializeCallbacks(db)

	if config.ClauseBuilders == nil {
		config.ClauseBuilders = map[string]clause.ClauseBuilder{}
	}

	if config.Dialector != nil {
		err = config.Dialector.Initialize(db)

		if err != nil {
			// Best effort: release whatever connection pool the dialector
			// managed to open before failing.
			if sqlDB, _ := db.DB(); sqlDB != nil {
				_ = sqlDB.Close()
			}
		}
	}

	if config.PrepareStmt {
		preparedStmt := NewPreparedStmtDB(db.ConnPool)
		db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		db.ConnPool = preparedStmt
	}

	db.Statement = &Statement{
		DB:       db,
		ConnPool: db.ConnPool,
		Context:  context.Background(),
		Clauses:  map[string]clause.Clause{},
	}

	if err == nil && !config.DisableAutomaticPing {
		if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
			err = pinger.Ping()
		}
	}

	if err != nil {
		config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
	}

	return db, err
}
 | 
						|
 | 
						|
// Session create new db session
 | 
						|
func (db *DB) Session(config *Session) *DB {
 | 
						|
	var (
 | 
						|
		txConfig = *db.Config
 | 
						|
		tx       = &DB{
 | 
						|
			Config:    &txConfig,
 | 
						|
			Statement: db.Statement,
 | 
						|
			Error:     db.Error,
 | 
						|
			clone:     1,
 | 
						|
		}
 | 
						|
	)
 | 
						|
	if config.CreateBatchSize > 0 {
 | 
						|
		tx.Config.CreateBatchSize = config.CreateBatchSize
 | 
						|
	}
 | 
						|
 | 
						|
	if config.SkipDefaultTransaction {
 | 
						|
		tx.Config.SkipDefaultTransaction = true
 | 
						|
	}
 | 
						|
 | 
						|
	if config.AllowGlobalUpdate {
 | 
						|
		txConfig.AllowGlobalUpdate = true
 | 
						|
	}
 | 
						|
 | 
						|
	if config.FullSaveAssociations {
 | 
						|
		txConfig.FullSaveAssociations = true
 | 
						|
	}
 | 
						|
 | 
						|
	if config.Context != nil || config.PrepareStmt || config.SkipHooks {
 | 
						|
		tx.Statement = tx.Statement.clone()
 | 
						|
		tx.Statement.DB = tx
 | 
						|
	}
 | 
						|
 | 
						|
	if config.Context != nil {
 | 
						|
		tx.Statement.Context = config.Context
 | 
						|
	}
 | 
						|
 | 
						|
	if config.PrepareStmt {
 | 
						|
		var preparedStmt *PreparedStmtDB
 | 
						|
 | 
						|
		if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
 | 
						|
			preparedStmt = v.(*PreparedStmtDB)
 | 
						|
		} else {
 | 
						|
			preparedStmt = NewPreparedStmtDB(db.ConnPool)
 | 
						|
			db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
 | 
						|
		}
 | 
						|
 | 
						|
		switch t := tx.Statement.ConnPool.(type) {
 | 
						|
		case Tx:
 | 
						|
			tx.Statement.ConnPool = &PreparedStmtTX{
 | 
						|
				Tx:             t,
 | 
						|
				PreparedStmtDB: preparedStmt,
 | 
						|
			}
 | 
						|
		default:
 | 
						|
			tx.Statement.ConnPool = &PreparedStmtDB{
 | 
						|
				ConnPool: db.Config.ConnPool,
 | 
						|
				Mux:      preparedStmt.Mux,
 | 
						|
				Stmts:    preparedStmt.Stmts,
 | 
						|
			}
 | 
						|
		}
 | 
						|
		txConfig.ConnPool = tx.Statement.ConnPool
 | 
						|
		txConfig.PrepareStmt = true
 | 
						|
	}
 | 
						|
 | 
						|
	if config.SkipHooks {
 | 
						|
		tx.Statement.SkipHooks = true
 | 
						|
	}
 | 
						|
 | 
						|
	if config.DisableNestedTransaction {
 | 
						|
		txConfig.DisableNestedTransaction = true
 | 
						|
	}
 | 
						|
 | 
						|
	if !config.NewDB {
 | 
						|
		tx.clone = 2
 | 
						|
	}
 | 
						|
 | 
						|
	if config.DryRun {
 | 
						|
		tx.Config.DryRun = true
 | 
						|
	}
 | 
						|
 | 
						|
	if config.QueryFields {
 | 
						|
		tx.Config.QueryFields = true
 | 
						|
	}
 | 
						|
 | 
						|
	if config.Logger != nil {
 | 
						|
		tx.Config.Logger = config.Logger
 | 
						|
	}
 | 
						|
 | 
						|
	if config.NowFunc != nil {
 | 
						|
		tx.Config.NowFunc = config.NowFunc
 | 
						|
	}
 | 
						|
 | 
						|
	if config.Initialized {
 | 
						|
		tx = tx.getInstance()
 | 
						|
	}
 | 
						|
 | 
						|
	return tx
 | 
						|
}
 | 
						|
 | 
						|
// WithContext returns a session of db whose Statement runs with ctx.
func (db *DB) WithContext(ctx context.Context) *DB {
	opts := &Session{Context: ctx}
	return db.Session(opts)
}
 | 
						|
 | 
						|
// Debug start debug mode
 | 
						|
func (db *DB) Debug() (tx *DB) {
 | 
						|
	tx = db.getInstance()
 | 
						|
	return tx.Session(&Session{
 | 
						|
		Logger: db.Logger.LogMode(logger.Info),
 | 
						|
	})
 | 
						|
}
 | 
						|
 | 
						|
// Set store value with key into current db instance's context
 | 
						|
func (db *DB) Set(key string, value interface{}) *DB {
 | 
						|
	tx := db.getInstance()
 | 
						|
	tx.Statement.Settings.Store(key, value)
 | 
						|
	return tx
 | 
						|
}
 | 
						|
 | 
						|
// Get get value with key from current db instance's context
 | 
						|
func (db *DB) Get(key string) (interface{}, bool) {
 | 
						|
	return db.Statement.Settings.Load(key)
 | 
						|
}
 | 
						|
 | 
						|
// InstanceSet stores value under key scoped to the instance's Statement. The
// key is prefixed with the Statement's address, so InstanceGet only finds it
// through that same Statement. It returns the instance for chaining.
func (db *DB) InstanceSet(key string, value interface{}) *DB {
	instance := db.getInstance()
	scopedKey := fmt.Sprintf("%p", instance.Statement) + key
	instance.Statement.Settings.Store(scopedKey, value)
	return instance
}
 | 
						|
 | 
						|
// InstanceGet looks up a value stored by InstanceSet for the current
// Statement (the key is prefixed with the Statement's address) and reports
// whether it was found.
func (db *DB) InstanceGet(key string) (interface{}, bool) {
	scopedKey := fmt.Sprintf("%p", db.Statement) + key
	return db.Statement.Settings.Load(scopedKey)
}
 | 
						|
 | 
						|
// Callback returns callback manager
 | 
						|
func (db *DB) Callback() *callbacks {
 | 
						|
	return db.callbacks
 | 
						|
}
 | 
						|
 | 
						|
// AddError records err on db and returns the accumulated db.Error.
//
// A nil err leaves db.Error untouched. When TranslateError is enabled and the
// Dialector implements ErrorTranslator, err is translated first. A first
// error is stored as-is. Later errors are chained as "<previous>; <new>",
// wrapping only the new one with %w.
func (db *DB) AddError(err error) error {
	if err == nil {
		return db.Error
	}

	if db.Config.TranslateError {
		if translator, ok := db.Dialector.(ErrorTranslator); ok {
			err = translator.Translate(err)
		}
	}

	if db.Error != nil {
		db.Error = fmt.Errorf("%v; %w", db.Error, err)
	} else {
		db.Error = err
	}
	return db.Error
}
 | 
						|
 | 
						|
// DB returns `*sql.DB`
 | 
						|
func (db *DB) DB() (*sql.DB, error) {
 | 
						|
	connPool := db.ConnPool
 | 
						|
	if db.Statement != nil && db.Statement.ConnPool != nil {
 | 
						|
		connPool = db.Statement.ConnPool
 | 
						|
	}
 | 
						|
	if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
 | 
						|
		return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
 | 
						|
	}
 | 
						|
 | 
						|
	if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
 | 
						|
		if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
 | 
						|
			return sqldb, err
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
 | 
						|
		return sqldb, nil
 | 
						|
	}
 | 
						|
 | 
						|
	return nil, ErrInvalidDB
 | 
						|
}
 | 
						|
 | 
						|
func (db *DB) getInstance() *DB {
 | 
						|
	if db.clone > 0 {
 | 
						|
		tx := &DB{Config: db.Config, Error: db.Error}
 | 
						|
 | 
						|
		if db.clone == 1 {
 | 
						|
			// clone with new statement
 | 
						|
			tx.Statement = &Statement{
 | 
						|
				DB:        tx,
 | 
						|
				ConnPool:  db.Statement.ConnPool,
 | 
						|
				Context:   db.Statement.Context,
 | 
						|
				Clauses:   map[string]clause.Clause{},
 | 
						|
				Vars:      make([]interface{}, 0, 8),
 | 
						|
				SkipHooks: db.Statement.SkipHooks,
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			// with clone statement
 | 
						|
			tx.Statement = db.Statement.clone()
 | 
						|
			tx.Statement.DB = tx
 | 
						|
		}
 | 
						|
 | 
						|
		return tx
 | 
						|
	}
 | 
						|
 | 
						|
	return db
 | 
						|
}
 | 
						|
 | 
						|
// Expr returns clause.Expr, which can be used to pass SQL expression as params
 | 
						|
func Expr(expr string, args ...interface{}) clause.Expr {
 | 
						|
	return clause.Expr{SQL: expr, Vars: args}
 | 
						|
}
 | 
						|
 | 
						|
// SetupJoinTable setup join table schema
 | 
						|
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
 | 
						|
	var (
 | 
						|
		tx                      = db.getInstance()
 | 
						|
		stmt                    = tx.Statement
 | 
						|
		modelSchema, joinSchema *schema.Schema
 | 
						|
	)
 | 
						|
 | 
						|
	err := stmt.Parse(model)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	modelSchema = stmt.Schema
 | 
						|
 | 
						|
	err = stmt.Parse(joinTable)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	joinSchema = stmt.Schema
 | 
						|
 | 
						|
	relation, ok := modelSchema.Relationships.Relations[field]
 | 
						|
	isRelation := ok && relation.JoinTable != nil
 | 
						|
	if !isRelation {
 | 
						|
		return fmt.Errorf("failed to find relation: %s", field)
 | 
						|
	}
 | 
						|
 | 
						|
	for _, ref := range relation.References {
 | 
						|
		f := joinSchema.LookUpField(ref.ForeignKey.DBName)
 | 
						|
		if f == nil {
 | 
						|
			return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
 | 
						|
		}
 | 
						|
 | 
						|
		f.DataType = ref.ForeignKey.DataType
 | 
						|
		f.GORMDataType = ref.ForeignKey.GORMDataType
 | 
						|
		if f.Size == 0 {
 | 
						|
			f.Size = ref.ForeignKey.Size
 | 
						|
		}
 | 
						|
		ref.ForeignKey = f
 | 
						|
	}
 | 
						|
 | 
						|
	for name, rel := range relation.JoinTable.Relationships.Relations {
 | 
						|
		if _, ok := joinSchema.Relationships.Relations[name]; !ok {
 | 
						|
			rel.Schema = joinSchema
 | 
						|
			joinSchema.Relationships.Relations[name] = rel
 | 
						|
		}
 | 
						|
	}
 | 
						|
	relation.JoinTable = joinSchema
 | 
						|
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Use use plugin
 | 
						|
func (db *DB) Use(plugin Plugin) error {
 | 
						|
	name := plugin.Name()
 | 
						|
	if _, ok := db.Plugins[name]; ok {
 | 
						|
		return ErrRegistered
 | 
						|
	}
 | 
						|
	if err := plugin.Initialize(db); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	db.Plugins[name] = plugin
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// ToSQL for generate SQL string.
 | 
						|
//
 | 
						|
//	db.ToSQL(func(tx *gorm.DB) *gorm.DB {
 | 
						|
//			return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
 | 
						|
//				.Limit(10).Offset(5)
 | 
						|
//				.Order("name ASC")
 | 
						|
//				.First(&User{})
 | 
						|
//	})
 | 
						|
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
 | 
						|
	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
 | 
						|
	stmt := tx.Statement
 | 
						|
 | 
						|
	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
 | 
						|
}
 |