增加websocket支持

This commit is contained in:
2025-09-08 13:47:13 +08:00
parent 9caedd697d
commit 7e0fd53dd3
336 changed files with 9020 additions and 20356 deletions

View File

@@ -8,12 +8,10 @@ import (
"github.com/jackc/pgx/v5/pgconn"
)
// The error codes to map PostgreSQL errors to gorm errors, here is the PostgreSQL error codes reference https://www.postgresql.org/docs/current/errcodes-appendix.html.
var errCodes = map[string]error{
"23505": gorm.ErrDuplicatedKey,
"23503": gorm.ErrForeignKeyViolated,
"42703": gorm.ErrInvalidField,
"23514": gorm.ErrCheckConstraintViolated,
}
type ErrMessage struct {

View File

@@ -38,7 +38,6 @@ WHERE
`
var typeAliasMap = map[string][]string{
"int": {"integer"},
"int2": {"smallint"},
"int4": {"integer"},
"int8": {"bigint"},
@@ -49,35 +48,14 @@ var typeAliasMap = map[string][]string{
"numeric": {"decimal"},
"timestamptz": {"timestamp with time zone"},
"timestamp with time zone": {"timestamptz"},
"bool": {"boolean"},
"boolean": {"bool"},
"serial2": {"smallserial"},
"serial4": {"serial"},
"serial8": {"bigserial"},
"varbit": {"bit varying"},
"char": {"character"},
"varchar": {"character varying"},
"float4": {"real"},
"float8": {"double precision"},
"timetz": {"time with time zone"},
}
type Migrator struct {
migrator.Migrator
}
// select querys ignore dryrun
func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) {
queryTx := m.DB
if m.DB.DryRun {
queryTx = m.DB.Session(&gorm.Session{})
queryTx.DryRun = false
}
return queryTx.Raw(sql, values...)
}
func (m Migrator) CurrentDatabase() (name string) {
m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name)
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name)
return
}
@@ -109,7 +87,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
}
}
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.queryRaw(
return m.DB.Raw(
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
).Scan(&count).Error
})
@@ -177,7 +155,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
func (m Migrator) GetTables() (tableList []string, err error) {
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
}
func (m Migrator) CreateTable(values ...interface{}) (err error) {
@@ -211,7 +189,7 @@ func (m Migrator) HasTable(value interface{}) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
})
return count > 0
}
@@ -263,7 +241,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
}
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.queryRaw(
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentSchema, curTable, name,
).Scan(&count).Error
@@ -288,7 +266,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
m.queryRaw(checkSQL, values...).Scan(&description)
m.DB.Raw(checkSQL, values...).Scan(&description)
comment := strings.Trim(field.Comment, "'")
comment = strings.Trim(comment, `"`)
@@ -322,7 +300,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
// check for typeName and SQL name
isSameType := true
if !strings.EqualFold(fieldColumnType.DatabaseTypeName(), fileType.SQL) {
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
isSameType = false
// if different, also check for aliases
aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
@@ -436,7 +414,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
}
currentSchema, curTable := m.CurrentSchema(stmt, table)
return m.queryRaw(
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
currentSchema, curTable, name,
).Scan(&count).Error
@@ -451,7 +429,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
var (
currentDatabase = m.DB.Migrator().CurrentDatabase()
currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
columns, err = m.queryRaw(
columns, err = m.DB.Raw(
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
currentDatabase, currentSchema, table).Rows()
)
@@ -525,7 +503,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
// check primary, unique field
{
columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
if err != nil {
return err
}
@@ -537,7 +515,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
}
columnTypeRows.Close()
columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
if err != nil {
return err
}
@@ -564,7 +542,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
// check column type
{
dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
WHERE a.attnum > 0 -- hide internal columns
AND NOT a.attisdropped -- hide deleted columns
@@ -722,7 +700,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
result := make([]*Index, 0)
scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error
scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error
if scanErr != nil {
return scanErr
}

View File

@@ -48,25 +48,19 @@ func (dialector Dialector) Name() string {
}
func (dialector Dialector) Apply(config *gorm.Config) error {
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{
IdentifierMaxLength: defaultIdentifierLength,
}
return nil
}
var namingStartegy *schema.NamingStrategy
switch v := config.NamingStrategy.(type) {
case *schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
}
namingStartegy = v
case schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
config.NamingStrategy = v
}
namingStartegy = &v
case nil:
namingStartegy = &schema.NamingStrategy{}
}
if namingStartegy.IdentifierMaxLength <= 0 {
namingStartegy.IdentifierMaxLength = defaultIdentifierLength
}
config.NamingStrategy = namingStartegy
return nil
}

2
vendor/gorm.io/gorm/README.md generated vendored
View File

@@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly.
© Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)

19
vendor/gorm.io/gorm/callbacks.go generated vendored
View File

@@ -187,18 +187,10 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
func (p *processor) compile() (err error) {
var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removedMap[callback.name] = true
}
}
if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks
@@ -347,14 +339,3 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
return
}
func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}

View File

@@ -103,62 +103,13 @@ func Create(config *Config) func(db *gorm.DB) {
}
db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
if !supportReturning {
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil &&
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
}
return
}
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
// append @id column with value for auto-increment primary key
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
if config.LastInsertIDReversed {
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}
@@ -171,10 +122,10 @@ func Create(config *Config) func(db *gorm.DB) {
break
}
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
} else {
@@ -184,16 +135,16 @@ func Create(config *Config) func(db *gorm.DB) {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
@@ -302,15 +253,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
for field, vs := range defaultValueFieldsHavingValue {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
@@ -333,7 +282,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
@@ -362,7 +311,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixMilli()
assignment.Value = curTime.UnixNano() / 1e6
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}

View File

@@ -3,7 +3,6 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
@@ -75,7 +74,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
@@ -83,103 +82,27 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
return names
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
if relationships == nil {
return nil
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
joinNames := strings.SplitN(join, ".", 2)
if len(joinNames) == 2 {
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
joined = true
nestedJoins = append(nestedJoins, joinNames[1])
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
preloadMap := parsePreloadMap(s, preloads)
for name := range preloadMap {
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
reflectValue := rel.FieldSchema.MakeSlice().Elem()
reflectValue.SetLen(rv.Len())
for i := 0; i < rv.Len(); i++ {
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
if frv.Kind() != reflect.Ptr {
reflectValue.Index(i).Set(frv.Addr())
} else {
reflectValue.Index(i).Set(frv)
}
}
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
return err
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var (
reflectValue = tx.Statement.ReflectValue

View File

@@ -3,6 +3,7 @@ package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
@@ -253,6 +254,7 @@ func BuildQuerySQL(db *gorm.DB) {
}
db.Statement.AddClause(fromClause)
db.Statement.Joins = nil
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
@@ -270,23 +272,38 @@ func Preload(db *gorm.DB) {
return
}
joins := make([]string, 0, len(db.Statement.Joins))
for _, join := range db.Statement.Joins {
joins = append(joins, join.Name)
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
if tx.Error != nil {
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
preloadDB.Statement.Settings.Store(k, v)
return true
})
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
return
}
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
preloadDB.Statement.Unscoped = db.Statement.Unscoped
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
for _, name := range preloadNames {
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
} else {
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
}
}
}
}
func AfterQuery(db *gorm.DB) {
// clear the joins after query because preload need it
db.Statement.Joins = nil
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {

View File

@@ -234,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
@@ -268,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli()
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {

38
vendor/gorm.io/gorm/chainable_api.go generated vendored
View File

@@ -367,12 +367,33 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
}
func (db *DB) executeScopes() (tx *DB) {
tx = db.getInstance()
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
if len(scopes) == 0 {
return tx
}
return db
tx.Statement.scopes = nil
conditions := make([]clause.Interface, 0, 4)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
for _, scope := range scopes {
tx = scope(tx)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
}
for _, condition := range conditions {
tx.Statement.AddClause(condition)
}
return tx
}
// Preload preload associations with given conditions
@@ -429,15 +450,6 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
return
}
// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true

View File

@@ -1,5 +1,7 @@
package clause
import "strconv"
// Limit limit clause
type Limit struct {
Limit *int
@@ -15,14 +17,14 @@ func (limit Limit) Name() string {
func (limit Limit) Build(builder Builder) {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ")
builder.AddVar(builder, *limit.Limit)
builder.WriteString(strconv.Itoa(*limit.Limit))
}
if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.AddVar(builder, limit.Offset)
builder.WriteString(strconv.Itoa(limit.Offset))
}
}

View File

@@ -1,12 +1,5 @@
package clause
const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type Locking struct {
Strength string
Table Table

79
vendor/gorm.io/gorm/clause/where.go generated vendored
View File

@@ -21,12 +21,6 @@ func (where Where) Name() string {
// Build build where clause
func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
// Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
@@ -153,11 +147,6 @@ func Not(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs}
}
@@ -166,63 +155,19 @@ type NotConditions struct {
}
func (not NotConditions) Build(builder Builder) {
anyNegationBuilder := false
for _, c := range not.Exprs {
if _, ok := c.(NegationExpressionBuilder); ok {
anyNegationBuilder = true
break
}
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
if anyNegationBuilder {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
} else {
builder.WriteString("NOT ")
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
switch c.(type) {
case OrConditions:
builder.WriteString(OrWithSpace)
default:
builder.WriteString(AndWithSpace)
}
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
@@ -237,9 +182,9 @@ func (not NotConditions) Build(builder Builder) {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}

2
vendor/gorm.io/gorm/errors.go generated vendored
View File

@@ -49,6 +49,4 @@ var (
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
)

View File

@@ -376,12 +376,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for i := 0; i < len(exprs); i++ {
expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value

22
vendor/gorm.io/gorm/logger/logger.go generated vendored
View File

@@ -69,7 +69,7 @@ type Interface interface {
}
var (
// Discard logger will print any log to io.Discard
// Discard Discard logger will print any log to io.Discard
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
@@ -78,7 +78,7 @@ var (
IgnoreRecordNotFoundError: false,
Colorful: true,
})
// Recorder logger records running SQL into a recorder instance
// Recorder Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
)
@@ -129,30 +129,28 @@ func (l *logger) LogMode(level LogLevel) Interface {
}
// Info print info
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Warn print warn messages
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Error print error messages
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Trace print sql message
//
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent {
return
}
@@ -184,8 +182,8 @@ func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string,
}
}
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
// Trace print sql message
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
@@ -200,8 +198,8 @@ type traceRecorder struct {
Err error
}
// New trace recorder
func (l *traceRecorder) New() *traceRecorder {
// New new trace recorder
func (l traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
}

29
vendor/gorm.io/gorm/logger/sql.go generated vendored
View File

@@ -34,19 +34,6 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
// RegEx matches only numeric values
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
func isNumeric(k reflect.Kind) bool {
switch k {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return true
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
case reflect.Float32, reflect.Float64:
return true
default:
return false
}
}
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var (
@@ -92,17 +79,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
} else {
vars[idx] = nullStr
}
}
case []byte:
if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
} else {
vars[idx] = escaper + "<binary>" + escaper
}
@@ -113,7 +100,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string:
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
default:
rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
@@ -123,12 +110,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else if isNumeric(rv.Kind()) {
if rv.CanInt() || rv.CanUint() {
vars[idx] = fmt.Sprintf("%d", rv.Interface())
} else {
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
}
} else {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) {
@@ -136,7 +117,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
return
}
}
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
}
}
}

2
vendor/gorm.io/gorm/migrator.go generated vendored
View File

@@ -87,8 +87,6 @@ type Migrator interface {
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error)

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
@@ -28,8 +27,6 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
// TODO:? Create const vars for raw sql queries ?
var _ gorm.Migrator = (*Migrator)(nil)
// Migrator m struct
type Migrator struct {
Config
@@ -94,6 +91,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL += " NOT NULL"
}
if field.Unique {
expr.SQL += " UNIQUE"
}
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
@@ -107,31 +108,21 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
return
}
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
queryTx = m.DB.Session(&gorm.Session{})
execTx = queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
return queryTx, execTx
}
// AutoMigrate auto migrate values
func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
queryTx, execTx := m.GetQueryAndExecTx()
queryTx := m.DB.Session(&gorm.Session{})
execTx := queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
if !queryTx.Migrator().HasTable(value) {
if err := execTx.Migrator().CreateTable(value); err != nil {
return err
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
@@ -216,11 +207,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)}
@@ -280,7 +266,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := constraint.Build()
sql, vars := buildConstraint(constraint)
createTableSQL += sql + ","
values = append(values, vars...)
}
@@ -288,11 +274,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
}
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
@@ -373,9 +354,6 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
f := stmt.Schema.LookUpField(name)
if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name)
@@ -395,10 +373,8 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
// DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
return m.DB.Exec(
@@ -410,15 +386,13 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
// AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
fileType := m.FullDataTypeOf(field)
return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
).Error
if field := stmt.Schema.LookUpField(field); field != nil {
fileType := m.FullDataTypeOf(field)
return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
).Error
}
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
@@ -430,10 +404,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
return m.DB.Raw(
@@ -448,14 +420,12 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
// RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(oldName); field != nil {
oldName = field.DBName
}
if field := stmt.Schema.LookUpField(oldName); field != nil {
oldName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName
}
return m.DB.Exec(
@@ -467,10 +437,6 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
// MigrateColumn migrate column
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}
// found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName())
@@ -530,6 +496,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}
// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
// check default value
if !field.PrimaryKey {
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
@@ -540,18 +514,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if currentDefaultNotNull || dvNotNull {
switch field.GORMDataType {
case schema.Time:
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
alterColumn = true
}
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
// default value not equal
// not both null
if currentDefaultNotNull || dvNotNull {
alterColumn = true
}
}
}
@@ -564,39 +532,13 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}
if alterColumn {
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
return err
if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.DBName)
}
return nil
}
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// We're currently only receiving boolean values on `Unique` tag,
// so the UniqueConstraint name is fixed
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
if unique && !field.Unique {
return m.DB.Migrator().DropConstraint(value, constraint)
}
if !unique && field.Unique {
return m.DB.Migrator().CreateConstraint(value, constraint)
}
return nil
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
@@ -666,36 +608,37 @@ func (m Migrator) DropView(name string) error {
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
}
// GuessConstraintAndTable guess statement's constraint and it's table based on name
//
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
// GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
if stmt.Schema == nil {
return nil, stmt.Table
return nil, nil, stmt.Table
}
checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return &chk, stmt.Table
}
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
if uni, ok := uniqueConstraints[name]; ok {
return &uni, stmt.Table
return nil, &chk, stmt.Table
}
getTable := func(rel *schema.Relationship) string {
@@ -710,7 +653,7 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, getTable(rel)
return constraint, nil, getTable(rel)
}
}
@@ -718,39 +661,40 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
for k := range checkConstraints {
if checkConstraints[k].Field == field {
v := checkConstraints[k]
return &v, stmt.Table
}
}
for k := range uniqueConstraints {
if uniqueConstraints[k].Field == field {
v := uniqueConstraints[k]
return &v, stmt.Table
return nil, &v, stmt.Table
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, getTable(rel)
return constraint, nil, getTable(rel)
}
}
}
return nil, stmt.Schema.Table
return nil, nil, stmt.Schema.Table
}
// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if chk != nil {
return m.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}
if constraint != nil {
vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr
}
sql, values := constraint.Build()
sql, values := buildConstraint(constraint)
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
}
return nil
})
}
@@ -758,9 +702,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
})
@@ -771,9 +717,11 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.DB.Raw(
@@ -815,9 +763,6 @@ type BuildIndexOptionsInterface interface {
// CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
@@ -850,10 +795,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
// DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
@@ -865,10 +808,8 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Raw(

26
vendor/gorm.io/gorm/prepare_stmt.go generated vendored
View File

@@ -3,8 +3,6 @@ package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
)
@@ -149,7 +147,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
@@ -163,7 +161,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
@@ -182,14 +180,6 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
return &sql.Row{}
}
func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}
type PreparedStmtTX struct {
Tx
PreparedStmtDB *PreparedStmtDB
@@ -217,7 +207,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
@@ -232,7 +222,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
@@ -250,11 +240,3 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
}
return &sql.Row{}
}
func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

22
vendor/gorm.io/gorm/scan.go generated vendored
View File

@@ -257,11 +257,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
continue
}
}
var val interface{}
values[idx] = &val
values[idx] = &sql.RawBytes{}
} else {
var val interface{}
values[idx] = &val
values[idx] = &sql.RawBytes{}
}
}
}
@@ -276,16 +274,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
if !update || reflectValue.Len() == 0 {
update = false
if isArrayKind {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}

View File

@@ -1,66 +0,0 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}

21
vendor/gorm.io/gorm/schema/field.go generated vendored
View File

@@ -49,14 +49,11 @@ const (
Bytes DataType = "bytes"
)
const DefaultAutoIncrementIncrement int64 = 1
// Field is the representation of model schema's field
type Field struct {
Name string
DBName string
BindNames []string
EmbeddedBindNames []string
DataType DataType
GORMDataType DataType
PrimaryKey bool
@@ -90,12 +87,6 @@ type Field struct {
Set func(context.Context, reflect.Value, interface{}) error
Serializer SerializerInterface
NewValuePool FieldNewValuePool
// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
// It causes field unnecessarily migration.
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
UniqueIndex string
}
func (field *Field) BindName() string {
@@ -113,7 +104,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
EmbeddedBindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
@@ -129,7 +119,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
AutoIncrementIncrement: 1,
}
for field.IndirectFieldType.Kind() == reflect.Ptr {
@@ -405,9 +395,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
}
// index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
@@ -669,7 +656,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
@@ -678,7 +665,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
@@ -743,7 +730,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
}

View File

@@ -13,8 +13,8 @@ type Index struct {
Type string // btree, hash, gist, spgist, gin, and brin
Where string
Comment string
Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field maybe the same
Option string // WITH PARSER parser_name
Fields []IndexOption
}
type IndexOption struct {
@@ -67,7 +67,7 @@ func (schema *Schema) ParseIndexes() map[string]Index {
}
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
index.Fields[0].Field.Unique = true
}
}
return indexes

View File

@@ -4,12 +4,6 @@ import (
"gorm.io/gorm/clause"
)
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string

View File

@@ -19,7 +19,6 @@ type Namer interface {
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
UniqueName(table, column string) string
}
// Replacer replacer interface like strings.Replacer
@@ -27,8 +26,6 @@ type Replacer interface {
Replace(name string) string
}
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy
type NamingStrategy struct {
TablePrefix string
@@ -88,11 +85,6 @@ func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column))
}
// UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,

View File

@@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return nil
}
if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
schema.buildPolymorphicRelation(relation, field, polymorphic)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
@@ -89,8 +89,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
}
}
@@ -125,20 +124,6 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation
}
// hasPolymorphicRelation check if has polymorphic relation
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
@@ -150,12 +135,12 @@ func (schema *Schema) setRelation(relation *Relationship) {
}
// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
if len(relation.Field.BindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
for i, name := range relation.Field.BindNames {
if i < len(relation.Field.BindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
@@ -184,41 +169,23 @@ func (schema *Schema) setRelation(relation *Relationship) {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
}
var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}
if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}
if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}
if schema.err == nil {
@@ -230,14 +197,12 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
}
}
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
return
}
@@ -352,8 +317,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
@@ -472,8 +436,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
}
}
@@ -529,9 +492,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
for _, name := range lookUpNames {
@@ -605,7 +566,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
}
}
// Constraint is ForeignKey Constraint
type Constraint struct {
Name string
Field *Field
@@ -617,31 +577,6 @@ type Constraint struct {
OnUpdate string
}
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {

View File

@@ -126,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
result = time.Unix(reflect.Indirect(rv).Int(), 0)
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
result = time.Unix(reflect.Indirect(rv).Int(), 0)
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}

38
vendor/gorm.io/gorm/statement.go generated vendored
View File

@@ -326,7 +326,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case *DB:
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
@@ -334,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
} else {
conds = append(conds, cs.Expression)
}
if v.Statement == stmt {
cs.Expression = nil
stmt.Statement.Clauses["WHERE"] = cs
}
}
case map[interface{}]interface{}:
for i, j := range v {
@@ -447,9 +451,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
return []clause.Expression{clause.And(conds...)}
}
return nil
return conds
}
}
@@ -458,10 +461,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
}
if len(conds) > 0 {
return []clause.Expression{clause.And(conds...)}
}
return nil
return conds
}
// Build build sql with clauses names
@@ -665,21 +665,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false
}
var matchName = func() func(tableColumn string) (table, column string) {
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
return func(tableColumn string) (table, column string) {
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
table = matches[1]
star := matches[2]
columnName := matches[3]
if star != "" {
return table, star
}
return table, columnName
}
return "", ""
}
}()
var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
@@ -700,13 +686,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = result
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
if col == "*" {
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
if matches[2] == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
results[matches[2]] = result
}
} else {
results[column] = result

21
vendor/gorm.io/gorm/utils/utils.go generated vendored
View File

@@ -32,16 +32,11 @@ func sourceDir(file string) string {
// FileWithLineNum return the file name and line number of the current file
func FileWithLineNum() string {
pcs := [13]uintptr{}
// the third caller usually from gorm internal
len := runtime.Callers(3, pcs[:])
frames := runtime.CallersFrames(pcs[:len])
for i := 0; i < len; i++ {
// second return value is "more", not "ok"
frame, _ := frames.Next()
if (!strings.HasPrefix(frame.File, gormSourceDir) ||
strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") {
return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10))
// the second caller usually from gorm internal, so set i start from 2
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) {
return file + ":" + strconv.FormatInt(int64(line), 10)
}
}
@@ -78,11 +73,7 @@ func ToStringKey(values ...interface{}) string {
case uint:
results[idx] = strconv.FormatUint(uint64(v), 10)
default:
results[idx] = "nil"
vv := reflect.ValueOf(v)
if vv.IsValid() && !vv.IsZero() {
results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface())
}
results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface())
}
}