290 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			290 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package postgres
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"regexp"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 
 | |
| 	"github.com/jackc/pgx/v5"
 | |
| 	"github.com/jackc/pgx/v5/stdlib"
 | |
| 	"gorm.io/gorm"
 | |
| 	"gorm.io/gorm/callbacks"
 | |
| 	"gorm.io/gorm/clause"
 | |
| 	"gorm.io/gorm/logger"
 | |
| 	"gorm.io/gorm/migrator"
 | |
| 	"gorm.io/gorm/schema"
 | |
| )
 | |
| 
 | |
| type Dialector struct {
 | |
| 	*Config
 | |
| }
 | |
| 
 | |
| type Config struct {
 | |
| 	DriverName           string
 | |
| 	DSN                  string
 | |
| 	WithoutQuotingCheck  bool
 | |
| 	PreferSimpleProtocol bool
 | |
| 	WithoutReturning     bool
 | |
| 	Conn                 gorm.ConnPool
 | |
| }
 | |
| 
 | |
// timeZoneMatcher finds a time_zone/TimeZone setting inside a DSN. Capture
// group 2 is the value; it ends at the end of the string, at '&' or at a space.
var timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )")

// defaultIdentifierLength is the maximum identifier length for postgres.
var defaultIdentifierLength = 63
 | |
| 
 | |
| func Open(dsn string) gorm.Dialector {
 | |
| 	return &Dialector{&Config{DSN: dsn}}
 | |
| }
 | |
| 
 | |
| func New(config Config) gorm.Dialector {
 | |
| 	return &Dialector{Config: &config}
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) Name() string {
 | |
| 	return "postgres"
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) Apply(config *gorm.Config) error {
 | |
| 	if config.NamingStrategy == nil {
 | |
| 		config.NamingStrategy = schema.NamingStrategy{
 | |
| 			IdentifierMaxLength: defaultIdentifierLength,
 | |
| 		}
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	switch v := config.NamingStrategy.(type) {
 | |
| 	case *schema.NamingStrategy:
 | |
| 		if v.IdentifierMaxLength <= 0 {
 | |
| 			v.IdentifierMaxLength = defaultIdentifierLength
 | |
| 		}
 | |
| 	case schema.NamingStrategy:
 | |
| 		if v.IdentifierMaxLength <= 0 {
 | |
| 			v.IdentifierMaxLength = defaultIdentifierLength
 | |
| 			config.NamingStrategy = v
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
 | |
| 	callbackConfig := &callbacks.Config{
 | |
| 		CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"},
 | |
| 		UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"},
 | |
| 		DeleteClauses: []string{"DELETE", "FROM", "WHERE"},
 | |
| 	}
 | |
| 	// register callbacks
 | |
| 	if !dialector.WithoutReturning {
 | |
| 		callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
 | |
| 		callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
 | |
| 		callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
 | |
| 	}
 | |
| 	callbacks.RegisterDefaultCallbacks(db, callbackConfig)
 | |
| 
 | |
| 	if dialector.Conn != nil {
 | |
| 		db.ConnPool = dialector.Conn
 | |
| 	} else if dialector.DriverName != "" {
 | |
| 		db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN)
 | |
| 	} else {
 | |
| 		var config *pgx.ConnConfig
 | |
| 
 | |
| 		config, err = pgx.ParseConfig(dialector.Config.DSN)
 | |
| 		if err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		if dialector.Config.PreferSimpleProtocol {
 | |
| 			config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
 | |
| 		}
 | |
| 		result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN)
 | |
| 		if len(result) > 2 {
 | |
| 			config.RuntimeParams["timezone"] = result[2]
 | |
| 		}
 | |
| 		db.ConnPool = stdlib.OpenDB(*config)
 | |
| 	}
 | |
| 	return
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
 | |
| 	return Migrator{migrator.Migrator{Config: migrator.Config{
 | |
| 		DB:                          db,
 | |
| 		Dialector:                   dialector,
 | |
| 		CreateIndexAfterCreateTable: true,
 | |
| 	}}}
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
 | |
| 	return clause.Expr{SQL: "DEFAULT"}
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
 | |
| 	writer.WriteByte('$')
 | |
| 	index := 0
 | |
| 	varLen := len(stmt.Vars)
 | |
| 	if varLen > 0 {
 | |
| 		switch stmt.Vars[0].(type) {
 | |
| 		case pgx.QueryExecMode:
 | |
| 			index++
 | |
| 		}
 | |
| 	}
 | |
| 	writer.WriteString(strconv.Itoa(varLen - index))
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
 | |
| 	if dialector.WithoutQuotingCheck {
 | |
| 		writer.WriteString(str)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	var (
 | |
| 		underQuoted, selfQuoted bool
 | |
| 		continuousBacktick      int8
 | |
| 		shiftDelimiter          int8
 | |
| 	)
 | |
| 
 | |
| 	for _, v := range []byte(str) {
 | |
| 		switch v {
 | |
| 		case '"':
 | |
| 			continuousBacktick++
 | |
| 			if continuousBacktick == 2 {
 | |
| 				writer.WriteString(`""`)
 | |
| 				continuousBacktick = 0
 | |
| 			}
 | |
| 		case '.':
 | |
| 			if continuousBacktick > 0 || !selfQuoted {
 | |
| 				shiftDelimiter = 0
 | |
| 				underQuoted = false
 | |
| 				continuousBacktick = 0
 | |
| 				writer.WriteByte('"')
 | |
| 			}
 | |
| 			writer.WriteByte(v)
 | |
| 			continue
 | |
| 		default:
 | |
| 			if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
 | |
| 				writer.WriteByte('"')
 | |
| 				underQuoted = true
 | |
| 				if selfQuoted = continuousBacktick > 0; selfQuoted {
 | |
| 					continuousBacktick -= 1
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			for ; continuousBacktick > 0; continuousBacktick -= 1 {
 | |
| 				writer.WriteString(`""`)
 | |
| 			}
 | |
| 
 | |
| 			writer.WriteByte(v)
 | |
| 		}
 | |
| 		shiftDelimiter++
 | |
| 	}
 | |
| 
 | |
| 	if continuousBacktick > 0 && !selfQuoted {
 | |
| 		writer.WriteString(`""`)
 | |
| 	}
 | |
| 	writer.WriteByte('"')
 | |
| }
 | |
| 
 | |
// numericPlaceholder matches a positional placeholder such as $1 or $23 and
// captures its digits. Explain passes it to logger.ExplainSQL.
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
 | |
| 
 | |
| func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
 | |
| 	return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) DataTypeOf(field *schema.Field) string {
 | |
| 	switch field.DataType {
 | |
| 	case schema.Bool:
 | |
| 		return "boolean"
 | |
| 	case schema.Int, schema.Uint:
 | |
| 		size := field.Size
 | |
| 		if field.DataType == schema.Uint {
 | |
| 			size++
 | |
| 		}
 | |
| 		if field.AutoIncrement {
 | |
| 			switch {
 | |
| 			case size <= 16:
 | |
| 				return "smallserial"
 | |
| 			case size <= 32:
 | |
| 				return "serial"
 | |
| 			default:
 | |
| 				return "bigserial"
 | |
| 			}
 | |
| 		} else {
 | |
| 			switch {
 | |
| 			case size <= 16:
 | |
| 				return "smallint"
 | |
| 			case size <= 32:
 | |
| 				return "integer"
 | |
| 			default:
 | |
| 				return "bigint"
 | |
| 			}
 | |
| 		}
 | |
| 	case schema.Float:
 | |
| 		if field.Precision > 0 {
 | |
| 			if field.Scale > 0 {
 | |
| 				return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
 | |
| 			}
 | |
| 			return fmt.Sprintf("numeric(%d)", field.Precision)
 | |
| 		}
 | |
| 		return "decimal"
 | |
| 	case schema.String:
 | |
| 		if field.Size > 0 {
 | |
| 			return fmt.Sprintf("varchar(%d)", field.Size)
 | |
| 		}
 | |
| 		return "text"
 | |
| 	case schema.Time:
 | |
| 		if field.Precision > 0 {
 | |
| 			return fmt.Sprintf("timestamptz(%d)", field.Precision)
 | |
| 		}
 | |
| 		return "timestamptz"
 | |
| 	case schema.Bytes:
 | |
| 		return "bytea"
 | |
| 	default:
 | |
| 		return dialector.getSchemaCustomType(field)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
 | |
| 	sqlType := string(field.DataType)
 | |
| 
 | |
| 	if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") {
 | |
| 		size := field.Size
 | |
| 		if field.GORMDataType == schema.Uint {
 | |
| 			size++
 | |
| 		}
 | |
| 		switch {
 | |
| 		case size <= 16:
 | |
| 			sqlType = "smallserial"
 | |
| 		case size <= 32:
 | |
| 			sqlType = "serial"
 | |
| 		default:
 | |
| 			sqlType = "bigserial"
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return sqlType
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
 | |
| 	tx.Exec("SAVEPOINT " + name)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
 | |
| 	tx.Exec("ROLLBACK TO SAVEPOINT " + name)
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// getSerialDatabaseType maps a serial pseudo-type name to the integer type
// name paired with it: smallserial -> smallint, serial -> integer,
// bigserial -> bigint. The match is exact. For any other input it returns
// "" and false.
func getSerialDatabaseType(s string) (dbType string, ok bool) {
	switch s {
	case "smallserial":
		dbType = "smallint"
	case "serial":
		dbType = "integer"
	case "bigserial":
		dbType = "bigint"
	}
	return dbType, dbType != ""
}
 |