261 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			261 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package gorm
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"database/sql"
 | 
						|
	"database/sql/driver"
 | 
						|
	"errors"
 | 
						|
	"reflect"
 | 
						|
	"sync"
 | 
						|
)
 | 
						|
 | 
						|
type Stmt struct {
 | 
						|
	*sql.Stmt
 | 
						|
	Transaction bool
 | 
						|
	prepared    chan struct{}
 | 
						|
	prepareErr  error
 | 
						|
}
 | 
						|
 | 
						|
type PreparedStmtDB struct {
 | 
						|
	Stmts       map[string]*Stmt
 | 
						|
	PreparedSQL []string
 | 
						|
	Mux         *sync.RWMutex
 | 
						|
	ConnPool
 | 
						|
}
 | 
						|
 | 
						|
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
 | 
						|
	return &PreparedStmtDB{
 | 
						|
		ConnPool:    connPool,
 | 
						|
		Stmts:       make(map[string]*Stmt),
 | 
						|
		Mux:         &sync.RWMutex{},
 | 
						|
		PreparedSQL: make([]string, 0, 100),
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
 | 
						|
	if sqldb, ok := db.ConnPool.(*sql.DB); ok {
 | 
						|
		return sqldb, nil
 | 
						|
	}
 | 
						|
 | 
						|
	if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
 | 
						|
		return dbConnector.GetDBConn()
 | 
						|
	}
 | 
						|
 | 
						|
	return nil, ErrInvalidDB
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) Close() {
 | 
						|
	db.Mux.Lock()
 | 
						|
	defer db.Mux.Unlock()
 | 
						|
 | 
						|
	for _, query := range db.PreparedSQL {
 | 
						|
		if stmt, ok := db.Stmts[query]; ok {
 | 
						|
			delete(db.Stmts, query)
 | 
						|
			go stmt.Close()
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func (sdb *PreparedStmtDB) Reset() {
 | 
						|
	sdb.Mux.Lock()
 | 
						|
	defer sdb.Mux.Unlock()
 | 
						|
 | 
						|
	for _, stmt := range sdb.Stmts {
 | 
						|
		go stmt.Close()
 | 
						|
	}
 | 
						|
	sdb.PreparedSQL = make([]string, 0, 100)
 | 
						|
	sdb.Stmts = make(map[string]*Stmt)
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
 | 
						|
	db.Mux.RLock()
 | 
						|
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
 | 
						|
		db.Mux.RUnlock()
 | 
						|
		// wait for other goroutines prepared
 | 
						|
		<-stmt.prepared
 | 
						|
		if stmt.prepareErr != nil {
 | 
						|
			return Stmt{}, stmt.prepareErr
 | 
						|
		}
 | 
						|
 | 
						|
		return *stmt, nil
 | 
						|
	}
 | 
						|
	db.Mux.RUnlock()
 | 
						|
 | 
						|
	db.Mux.Lock()
 | 
						|
	// double check
 | 
						|
	if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
 | 
						|
		db.Mux.Unlock()
 | 
						|
		// wait for other goroutines prepared
 | 
						|
		<-stmt.prepared
 | 
						|
		if stmt.prepareErr != nil {
 | 
						|
			return Stmt{}, stmt.prepareErr
 | 
						|
		}
 | 
						|
 | 
						|
		return *stmt, nil
 | 
						|
	}
 | 
						|
 | 
						|
	// cache preparing stmt first
 | 
						|
	cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
 | 
						|
	db.Stmts[query] = &cacheStmt
 | 
						|
	db.Mux.Unlock()
 | 
						|
 | 
						|
	// prepare completed
 | 
						|
	defer close(cacheStmt.prepared)
 | 
						|
 | 
						|
	// Reason why cannot lock conn.PrepareContext
 | 
						|
	// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
 | 
						|
	// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
 | 
						|
	// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
 | 
						|
	// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
 | 
						|
	stmt, err := conn.PrepareContext(ctx, query)
 | 
						|
	if err != nil {
 | 
						|
		cacheStmt.prepareErr = err
 | 
						|
		db.Mux.Lock()
 | 
						|
		delete(db.Stmts, query)
 | 
						|
		db.Mux.Unlock()
 | 
						|
		return Stmt{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	db.Mux.Lock()
 | 
						|
	cacheStmt.Stmt = stmt
 | 
						|
	db.PreparedSQL = append(db.PreparedSQL, query)
 | 
						|
	db.Mux.Unlock()
 | 
						|
 | 
						|
	return cacheStmt, nil
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
 | 
						|
	if beginner, ok := db.ConnPool.(TxBeginner); ok {
 | 
						|
		tx, err := beginner.BeginTx(ctx, opt)
 | 
						|
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
 | 
						|
	}
 | 
						|
 | 
						|
	beginner, ok := db.ConnPool.(ConnPoolBeginner)
 | 
						|
	if !ok {
 | 
						|
		return nil, ErrInvalidTransaction
 | 
						|
	}
 | 
						|
 | 
						|
	connPool, err := beginner.BeginTx(ctx, opt)
 | 
						|
	if err != nil {
 | 
						|
		return nil, err
 | 
						|
	}
 | 
						|
	if tx, ok := connPool.(Tx); ok {
 | 
						|
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
 | 
						|
	}
 | 
						|
	return nil, ErrInvalidTransaction
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
 | 
						|
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
						|
	if err == nil {
 | 
						|
		result, err = stmt.ExecContext(ctx, args...)
 | 
						|
		if errors.Is(err, driver.ErrBadConn) {
 | 
						|
			db.Mux.Lock()
 | 
						|
			defer db.Mux.Unlock()
 | 
						|
			go stmt.Close()
 | 
						|
			delete(db.Stmts, query)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return result, err
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
						|
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
						|
	if err == nil {
 | 
						|
		rows, err = stmt.QueryContext(ctx, args...)
 | 
						|
		if errors.Is(err, driver.ErrBadConn) {
 | 
						|
			db.Mux.Lock()
 | 
						|
			defer db.Mux.Unlock()
 | 
						|
 | 
						|
			go stmt.Close()
 | 
						|
			delete(db.Stmts, query)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return rows, err
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
						|
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
 | 
						|
	if err == nil {
 | 
						|
		return stmt.QueryRowContext(ctx, args...)
 | 
						|
	}
 | 
						|
	return &sql.Row{}
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtDB) Ping() error {
 | 
						|
	conn, err := db.GetDBConn()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return conn.Ping()
 | 
						|
}
 | 
						|
 | 
						|
type PreparedStmtTX struct {
 | 
						|
	Tx
 | 
						|
	PreparedStmtDB *PreparedStmtDB
 | 
						|
}
 | 
						|
 | 
						|
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
 | 
						|
	return db.PreparedStmtDB.GetDBConn()
 | 
						|
}
 | 
						|
 | 
						|
func (tx *PreparedStmtTX) Commit() error {
 | 
						|
	if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
 | 
						|
		return tx.Tx.Commit()
 | 
						|
	}
 | 
						|
	return ErrInvalidTransaction
 | 
						|
}
 | 
						|
 | 
						|
func (tx *PreparedStmtTX) Rollback() error {
 | 
						|
	if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
 | 
						|
		return tx.Tx.Rollback()
 | 
						|
	}
 | 
						|
	return ErrInvalidTransaction
 | 
						|
}
 | 
						|
 | 
						|
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
 | 
						|
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
						|
	if err == nil {
 | 
						|
		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
 | 
						|
		if errors.Is(err, driver.ErrBadConn) {
 | 
						|
			tx.PreparedStmtDB.Mux.Lock()
 | 
						|
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
						|
 | 
						|
			go stmt.Close()
 | 
						|
			delete(tx.PreparedStmtDB.Stmts, query)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return result, err
 | 
						|
}
 | 
						|
 | 
						|
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
 | 
						|
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
						|
	if err == nil {
 | 
						|
		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
 | 
						|
		if errors.Is(err, driver.ErrBadConn) {
 | 
						|
			tx.PreparedStmtDB.Mux.Lock()
 | 
						|
			defer tx.PreparedStmtDB.Mux.Unlock()
 | 
						|
 | 
						|
			go stmt.Close()
 | 
						|
			delete(tx.PreparedStmtDB.Stmts, query)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return rows, err
 | 
						|
}
 | 
						|
 | 
						|
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
 | 
						|
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
 | 
						|
	if err == nil {
 | 
						|
		return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
 | 
						|
	}
 | 
						|
	return &sql.Row{}
 | 
						|
}
 | 
						|
 | 
						|
func (tx *PreparedStmtTX) Ping() error {
 | 
						|
	conn, err := tx.GetDBConn()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return conn.Ping()
 | 
						|
}
 |