Add WebSocket support
2  vendor/gorm.io/driver/postgres/error_translator.go  generated  vendored
@@ -8,12 +8,10 @@ import (
"github.com/jackc/pgx/v5/pgconn"
)

// The error codes to map PostgreSQL errors to gorm errors, here is the PostgreSQL error codes reference https://www.postgresql.org/docs/current/errcodes-appendix.html.
var errCodes = map[string]error{
"23505": gorm.ErrDuplicatedKey,
"23503": gorm.ErrForeignKeyViolated,
"42703": gorm.ErrInvalidField,
"23514": gorm.ErrCheckConstraintViolated,
}

type ErrMessage struct {
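Caller-side, these SQLSTATE mappings surface as gorm's exported errors. A minimal sketch of checking for a translated duplicate-key error, assuming a gorm version that exposes gorm.Config.TranslateError (imports errors, log, gorm.io/gorm, gorm.io/driver/postgres); the User model and dsn are placeholders and this code is not part of the commit:

    db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{TranslateError: true})
    if err != nil {
        log.Fatal(err)
    }
    // A unique violation (SQLSTATE 23505) comes back as gorm.ErrDuplicatedKey.
    if err := db.Create(&User{Email: "a@b.c"}).Error; errors.Is(err, gorm.ErrDuplicatedKey) {
        log.Println("email already taken")
    }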
48  vendor/gorm.io/driver/postgres/migrator.go  generated  vendored
@@ -38,7 +38,6 @@ WHERE
|
||||
`
|
||||
|
||||
var typeAliasMap = map[string][]string{
|
||||
"int": {"integer"},
|
||||
"int2": {"smallint"},
|
||||
"int4": {"integer"},
|
||||
"int8": {"bigint"},
|
||||
@@ -49,35 +48,14 @@ var typeAliasMap = map[string][]string{
|
||||
"numeric": {"decimal"},
|
||||
"timestamptz": {"timestamp with time zone"},
|
||||
"timestamp with time zone": {"timestamptz"},
|
||||
"bool": {"boolean"},
|
||||
"boolean": {"bool"},
|
||||
"serial2": {"smallserial"},
|
||||
"serial4": {"serial"},
|
||||
"serial8": {"bigserial"},
|
||||
"varbit": {"bit varying"},
|
||||
"char": {"character"},
|
||||
"varchar": {"character varying"},
|
||||
"float4": {"real"},
|
||||
"float8": {"double precision"},
|
||||
"timetz": {"time with time zone"},
|
||||
}
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
// select querys ignore dryrun
|
||||
func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) {
|
||||
queryTx := m.DB
|
||||
if m.DB.DryRun {
|
||||
queryTx = m.DB.Session(&gorm.Session{})
|
||||
queryTx.DryRun = false
|
||||
}
|
||||
return queryTx.Raw(sql, values...)
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,7 +87,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
}
|
||||
}
|
||||
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
|
||||
return m.queryRaw(
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
|
||||
).Scan(&count).Error
|
||||
})
|
||||
@@ -177,7 +155,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
|
||||
return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
|
||||
return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
func (m Migrator) CreateTable(values ...interface{}) (err error) {
|
||||
@@ -211,7 +189,7 @@ func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
|
||||
return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
|
||||
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
@@ -263,7 +241,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
}
|
||||
|
||||
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
|
||||
return m.queryRaw(
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
|
||||
currentSchema, curTable, name,
|
||||
).Scan(&count).Error
|
||||
@@ -288,7 +266,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
||||
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
|
||||
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
|
||||
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
|
||||
m.queryRaw(checkSQL, values...).Scan(&description)
|
||||
m.DB.Raw(checkSQL, values...).Scan(&description)
|
||||
|
||||
comment := strings.Trim(field.Comment, "'")
|
||||
comment = strings.Trim(comment, `"`)
|
||||
@@ -322,7 +300,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
|
||||
// check for typeName and SQL name
|
||||
isSameType := true
|
||||
if !strings.EqualFold(fieldColumnType.DatabaseTypeName(), fileType.SQL) {
|
||||
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
|
||||
isSameType = false
|
||||
// if different, also check for aliases
|
||||
aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
|
||||
@@ -436,7 +414,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
}
|
||||
currentSchema, curTable := m.CurrentSchema(stmt, table)
|
||||
|
||||
return m.queryRaw(
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
|
||||
currentSchema, curTable, name,
|
||||
).Scan(&count).Error
|
||||
@@ -451,7 +429,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
|
||||
var (
|
||||
currentDatabase = m.DB.Migrator().CurrentDatabase()
|
||||
currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
|
||||
columns, err = m.queryRaw(
|
||||
columns, err = m.DB.Raw(
|
||||
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
|
||||
currentDatabase, currentSchema, table).Rows()
|
||||
)
|
||||
@@ -525,7 +503,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
|
||||
|
||||
// check primary, unique field
|
||||
{
|
||||
columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
|
||||
columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -537,7 +515,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
|
||||
}
|
||||
columnTypeRows.Close()
|
||||
|
||||
columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
|
||||
columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -564,7 +542,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
|
||||
|
||||
// check column type
|
||||
{
|
||||
dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
|
||||
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
|
||||
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
|
||||
WHERE a.attnum > 0 -- hide internal columns
|
||||
AND NOT a.attisdropped -- hide deleted columns
|
||||
@@ -722,7 +700,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
|
||||
|
||||
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
result := make([]*Index, 0)
|
||||
scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error
|
||||
scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error
|
||||
if scanErr != nil {
|
||||
return scanErr
|
||||
}
|
||||
|
||||
24  vendor/gorm.io/driver/postgres/postgres.go  generated  vendored
@@ -48,25 +48,19 @@ func (dialector Dialector) Name() string {
}

func (dialector Dialector) Apply(config *gorm.Config) error {
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{
IdentifierMaxLength: defaultIdentifierLength,
}
return nil
}

var namingStartegy *schema.NamingStrategy
switch v := config.NamingStrategy.(type) {
case *schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
}
namingStartegy = v
case schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
config.NamingStrategy = v
}
namingStartegy = &v
case nil:
namingStartegy = &schema.NamingStrategy{}
}

if namingStartegy.IdentifierMaxLength <= 0 {
namingStartegy.IdentifierMaxLength = defaultIdentifierLength
}
config.NamingStrategy = namingStartegy
return nil
}
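The Apply hunk is about defaulting IdentifierMaxLength (PostgreSQL truncates identifiers at 63 bytes). A sketch of supplying an explicit naming strategy instead of relying on that default; standard gorm API, not code from this commit, with dsn as a placeholder and gorm.io/gorm/schema imported:

    db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
        NamingStrategy: schema.NamingStrategy{
            TablePrefix:         "app_",
            IdentifierMaxLength: 63, // what the driver would otherwise set via defaultIdentifierLength
        },
    })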
2  vendor/gorm.io/gorm/README.md  generated  vendored
@@ -41,4 +41,4 @@ The fantastic ORM library for Golang, aims to be developer friendly.

© Jinzhu, 2013~time.Now

Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/License)
19  vendor/gorm.io/gorm/callbacks.go  generated  vendored
@@ -187,18 +187,10 @@ func (p *processor) Replace(name string, fn func(*DB)) error {

func (p *processor) compile() (err error) {
var callbacks []*callback
removedMap := map[string]bool{}
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
if callback.remove {
removedMap[callback.name] = true
}
}

if len(removedMap) > 0 {
callbacks = removeCallbacks(callbacks, removedMap)
}
p.callbacks = callbacks

@@ -347,14 +339,3 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {

return
}

func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback {
callbacks := make([]*callback, 0, len(cs))
for _, callback := range cs {
if nameMap[callback.name] {
continue
}
callbacks = append(callbacks, callback)
}
return callbacks
}
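compile and removeCallbacks are the plumbing behind gorm's public callback registration API; a short usage sketch with standard gorm calls (the plugin name is made up, and this code is not part of the commit):

    // Register a callback that runs before gorm's own create callback.
    if err := db.Callback().Create().Before("gorm:create").Register("my_plugin:audit", func(tx *gorm.DB) {
        // inspect or adjust tx.Statement here
    }); err != nil {
        log.Fatal(err)
    }
    // Removing it sets callback.remove, so compile() drops it via removedMap.
    _ = db.Callback().Create().Remove("my_plugin:audit")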
97  vendor/gorm.io/gorm/callbacks/create.go  generated  vendored
@@ -103,62 +103,13 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
|
||||
if !insertOk {
|
||||
if !supportReturning {
|
||||
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField != nil &&
|
||||
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
if !insertOk {
|
||||
db.AddError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||
}
|
||||
|
||||
// append @id column with value for auto-increment primary key
|
||||
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||
switch values := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values[pkFieldName] = insertID
|
||||
case *map[string]interface{}:
|
||||
(*values)[pkFieldName] = insertID
|
||||
case []map[string]interface{}, *[]map[string]interface{}:
|
||||
mapValues, ok := values.([]map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||
if *v != nil {
|
||||
mapValues = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.LastInsertIDReversed {
|
||||
insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
}
|
||||
insertID += schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
default:
|
||||
if pkField == nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -171,10 +122,10 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= pkField.AutoIncrementIncrement
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -184,16 +135,16 @@ func Create(config *Config) func(db *gorm.DB) {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += pkField.AutoIncrementIncrement
|
||||
if _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += db.Statement.Schema.PrioritizedPrimaryField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -302,15 +253,13 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if vs, ok := defaultValueFieldsHavingValue[field]; ok {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
for field, vs := range defaultValueFieldsHavingValue {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -333,7 +282,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||
@@ -362,7 +311,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixMilli()
|
||||
assignment.Value = curTime.UnixNano() / 1e6
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
||||
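The Create hunk above deals with writing auto-increment primary keys back into the destination after an insert (via LastInsertId or RETURNING). The observable behaviour, sketched with a hypothetical Product model against gorm's public API, not code from this commit:

    type Product struct {
        ID   uint // auto-increment primary key, backfilled after insert
        Code string
    }

    products := []Product{{Code: "A"}, {Code: "B"}}
    if err := db.Create(&products).Error; err == nil {
        // The callback walks the slice and assigns the returned IDs,
        // stepping by the schema's AutoIncrementIncrement for batch inserts.
        log.Println(products[0].ID, products[1].ID)
    }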
99  vendor/gorm.io/gorm/callbacks/preload.go  generated  vendored
@@ -3,7 +3,6 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -75,7 +74,7 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
|
||||
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
@@ -83,103 +82,27 @@ func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||
// If the current relationship is embedded or joined, current query will be ignored.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||
|
||||
// avoid random traversal of the map
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error {
|
||||
if relationships == nil {
|
||||
return nil
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||
for _, join := range joins {
|
||||
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||
joined = true
|
||||
continue
|
||||
}
|
||||
joinNames := strings.SplitN(join, ".", 2)
|
||||
if len(joinNames) == 2 {
|
||||
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
|
||||
joined = true
|
||||
nestedJoins = append(nestedJoins, joinNames[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
return joined, nestedJoins
|
||||
}
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||
preloadMap := parsePreloadMap(s, preloads)
|
||||
for name := range preloadMap {
|
||||
if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil {
|
||||
if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
if joined, nestedJoins := isJoined(name); joined {
|
||||
switch rv := db.Statement.ReflectValue; rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() > 0 {
|
||||
reflectValue := rel.FieldSchema.MakeSlice().Elem()
|
||||
reflectValue.SetLen(rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
|
||||
if frv.Kind() != reflect.Ptr {
|
||||
reflectValue.Index(i).Set(frv.Addr())
|
||||
} else {
|
||||
reflectValue.Index(i).Set(frv)
|
||||
}
|
||||
}
|
||||
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return gorm.ErrInvalidData
|
||||
}
|
||||
} else {
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
|
||||
return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
|
||||
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := tx.Statement.Parse(dest); err != nil {
|
||||
tx.AddError(err)
|
||||
return tx
|
||||
}
|
||||
tx.Statement.ReflectValue = reflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
return tx
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
|
||||
33  vendor/gorm.io/gorm/callbacks/query.go  generated  vendored
@@ -3,6 +3,7 @@ package callbacks
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -253,6 +254,7 @@ func BuildQuerySQL(db *gorm.DB) {
|
||||
}
|
||||
|
||||
db.Statement.AddClause(fromClause)
|
||||
db.Statement.Joins = nil
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
@@ -270,23 +272,38 @@ func Preload(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
joins := make([]string, 0, len(db.Statement.Joins))
|
||||
for _, join := range db.Statement.Joins {
|
||||
joins = append(joins, join.Name)
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads)
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
|
||||
if tx.Error != nil {
|
||||
preloadDB := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
preloadDB.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := preloadDB.Statement.Parse(db.Statement.Dest); err != nil {
|
||||
return
|
||||
}
|
||||
preloadDB.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
preloadDB.Statement.Unscoped = db.Statement.Unscoped
|
||||
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||
for _, name := range preloadNames {
|
||||
if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil {
|
||||
db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations]))
|
||||
} else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil {
|
||||
db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]))
|
||||
} else {
|
||||
db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// clear the joins after query because preload need it
|
||||
db.Statement.Joins = nil
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
|
||||
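The Preload hunk changes how preload names are resolved internally (sorted names, join awareness); the caller-facing API is unchanged. A standard usage sketch with hypothetical User/Order models, not part of this commit:

    var users []User
    // Preload a has-many association with an extra condition, plus a nested association.
    err := db.Preload("Orders", "state = ?", "paid").
        Preload("Orders.Items").
        Find(&users).Error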
4  vendor/gorm.io/gorm/callbacks/update.go  generated  vendored
@@ -234,7 +234,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
@@ -268,7 +268,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixMilli()
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
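Both hunks only change how the millisecond timestamp is computed (UnixMilli versus UnixNano()/1e6). The struct tag that selects this code path, shown on a hypothetical model as a sketch, not code from this commit:

    type Post struct {
        ID        uint
        Title     string
        UpdatedAt int64 `gorm:"autoUpdateTime:milli"` // stored as Unix milliseconds on every update
    }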
38  vendor/gorm.io/gorm/chainable_api.go  generated  vendored
@@ -367,12 +367,33 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
}

func (db *DB) executeScopes() (tx *DB) {
tx = db.getInstance()
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
if len(scopes) == 0 {
return tx
}
return db
tx.Statement.scopes = nil

conditions := make([]clause.Interface, 0, 4)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}

for _, scope := range scopes {
tx = scope(tx)
if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
conditions = append(conditions, cs.Expression.(clause.Interface))
cs.Expression = nil
tx.Statement.Clauses["WHERE"] = cs
}
}

for _, condition := range conditions {
tx.Statement.AddClause(condition)
}
return tx
}

// Preload preload associations with given conditions
@@ -429,15 +450,6 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
return
}

// Unscoped disables the global scope of soft deletion in a query.
// By default, GORM uses soft deletion, marking records as "deleted"
// by setting a timestamp on a specific field (e.g., `deleted_at`).
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true
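executeScopes above runs the functions registered with Scopes, and the Unscoped doc comment appears in the same hunk. A usage sketch against the public API, with a hypothetical Order model and not part of the commit:

    paid := func(tx *gorm.DB) *gorm.DB { return tx.Where("state = ?", "paid") }
    recent := func(tx *gorm.DB) *gorm.DB { return tx.Where("created_at > ?", time.Now().AddDate(0, -1, 0)) }

    var orders []Order
    db.Scopes(paid, recent).Find(&orders) // both scope conditions end up in the WHERE clause
    db.Unscoped().Find(&orders)           // includes soft-deleted rows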
6  vendor/gorm.io/gorm/clause/limit.go  generated  vendored
@@ -1,5 +1,7 @@
package clause

import "strconv"

// Limit limit clause
type Limit struct {
Limit *int
@@ -15,14 +17,14 @@ func (limit Limit) Name() string {
func (limit Limit) Build(builder Builder) {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ")
builder.AddVar(builder, *limit.Limit)
builder.WriteString(strconv.Itoa(*limit.Limit))
}
if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.AddVar(builder, limit.Offset)
builder.WriteString(strconv.Itoa(limit.Offset))
}
}
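This hunk switches LIMIT/OFFSET between bind variables (AddVar) and literal integers (strconv.Itoa). Callers rarely touch the clause directly; both forms below produce it (standard gorm usage, hypothetical users slice, not code from this commit):

    limit := 10
    db.Clauses(clause.Limit{Limit: &limit, Offset: 20}).Find(&users)
    // equivalent chainable form
    db.Limit(10).Offset(20).Find(&users)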
7  vendor/gorm.io/gorm/clause/locking.go  generated  vendored
@@ -1,12 +1,5 @@
package clause

const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)

type Locking struct {
Strength string
Table Table
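Only one side of this diff carries the locking-strength constants, so the sketch below uses plain strings, which work either way (standard gorm clause usage, hypothetical users slice, not part of this commit):

    // SELECT ... FOR UPDATE
    db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)

    // SELECT ... FOR SHARE, restricted to the current table
    db.Clauses(clause.Locking{
        Strength: "SHARE",
        Table:    clause.Table{Name: clause.CurrentTable},
    }).Find(&users)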
79  vendor/gorm.io/gorm/clause/where.go  generated  vendored
@@ -21,12 +21,6 @@ func (where Where) Name() string {
|
||||
|
||||
// Build build where clause
|
||||
func (where Where) Build(builder Builder) {
|
||||
if len(where.Exprs) == 1 {
|
||||
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
|
||||
where.Exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
|
||||
// Switch position if the first query expression is a single Or condition
|
||||
for idx, expr := range where.Exprs {
|
||||
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
||||
@@ -153,11 +147,6 @@ func Not(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(exprs) == 1 {
|
||||
if andCondition, ok := exprs[0].(AndConditions); ok {
|
||||
exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
return NotConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
@@ -166,63 +155,19 @@ type NotConditions struct {
|
||||
}
|
||||
|
||||
func (not NotConditions) Build(builder Builder) {
|
||||
anyNegationBuilder := false
|
||||
for _, c := range not.Exprs {
|
||||
if _, ok := c.(NegationExpressionBuilder); ok {
|
||||
anyNegationBuilder = true
|
||||
break
|
||||
}
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
if anyNegationBuilder {
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
switch c.(type) {
|
||||
case OrConditions:
|
||||
builder.WriteString(OrWithSpace)
|
||||
default:
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
@@ -237,9 +182,9 @@ func (not NotConditions) Build(builder Builder) {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
2  vendor/gorm.io/gorm/errors.go  generated  vendored
@@ -49,6 +49,4 @@ var (
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
// ErrCheckConstraintViolated occurs when there is a check constraint violation
ErrCheckConstraintViolated = errors.New("violates check constraint")
)
8  vendor/gorm.io/gorm/finisher_api.go  generated  vendored
@@ -376,12 +376,8 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for i := 0; i < len(exprs); i++ {
expr := exprs[i]

if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
for _, expr := range exprs {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
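The FirstOrCreate hunk changes how Assign/Attrs expressions are flattened into the assigns map; from the caller's side the behaviour stays the documented one. A sketch with a hypothetical User model, not code from this commit:

    var user User
    // Looks up a user named "non_existing"; creates it if missing.
    // Assign values are applied whether the row was found or created.
    db.Where(User{Name: "non_existing"}).Assign(User{Age: 20}).FirstOrCreate(&user)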
22  vendor/gorm.io/gorm/logger/logger.go  generated  vendored
@@ -69,7 +69,7 @@ type Interface interface {
|
||||
}
|
||||
|
||||
var (
|
||||
// Discard logger will print any log to io.Discard
|
||||
// Discard Discard logger will print any log to io.Discard
|
||||
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
|
||||
// Default Default logger
|
||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||
@@ -78,7 +78,7 @@ var (
|
||||
IgnoreRecordNotFoundError: false,
|
||||
Colorful: true,
|
||||
})
|
||||
// Recorder logger records running SQL into a recorder instance
|
||||
// Recorder Recorder logger records running SQL into a recorder instance
|
||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||
)
|
||||
|
||||
@@ -129,30 +129,28 @@ func (l *logger) LogMode(level LogLevel) Interface {
|
||||
}
|
||||
|
||||
// Info print info
|
||||
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Info {
|
||||
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Warn {
|
||||
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
func (l logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Error {
|
||||
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Trace print sql message
|
||||
//
|
||||
//nolint:cyclop
|
||||
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
func (l logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
if l.LogLevel <= Silent {
|
||||
return
|
||||
}
|
||||
@@ -184,8 +182,8 @@ func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string,
|
||||
}
|
||||
}
|
||||
|
||||
// ParamsFilter filter params
|
||||
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
// Trace print sql message
|
||||
func (l logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if l.Config.ParameterizedQueries {
|
||||
return sql, nil
|
||||
}
|
||||
@@ -200,8 +198,8 @@ type traceRecorder struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// New trace recorder
|
||||
func (l *traceRecorder) New() *traceRecorder {
|
||||
// New new trace recorder
|
||||
func (l traceRecorder) New() *traceRecorder {
|
||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||
}
|
||||
|
||||
|
||||
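The logger hunk mostly flips methods between pointer and value receivers and reverts doc comments; building a logger through the same package is unchanged. A standard configuration sketch, assuming imports log, os, time and gorm.io/gorm/logger, and not part of this commit:

    newLogger := logger.New(
        log.New(os.Stdout, "\r\n", log.LstdFlags), // log writer
        logger.Config{
            SlowThreshold:             200 * time.Millisecond,
            LogLevel:                  logger.Warn,
            IgnoreRecordNotFoundError: true,
            Colorful:                  false,
        },
    )
    db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{Logger: newLogger})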
29  vendor/gorm.io/gorm/logger/sql.go  generated  vendored
@@ -34,19 +34,6 @@ var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeO
|
||||
// RegEx matches only numeric values
|
||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||
|
||||
func isNumeric(k reflect.Kind) bool {
|
||||
switch k {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return true
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return true
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||
var (
|
||||
@@ -92,17 +79,17 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case reflect.Bool:
|
||||
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||
case reflect.String:
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
default:
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, "\\"+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
}
|
||||
case []byte:
|
||||
if s := string(v); isPrintable(s) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, "\\"+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + "<binary>" + escaper
|
||||
}
|
||||
@@ -113,7 +100,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
case float64:
|
||||
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case string:
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
@@ -123,12 +110,6 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
convertParams(v, idx)
|
||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||
} else if isNumeric(rv.Kind()) {
|
||||
if rv.CanInt() || rv.CanUint() {
|
||||
vars[idx] = fmt.Sprintf("%d", rv.Interface())
|
||||
} else {
|
||||
vars[idx] = fmt.Sprintf("%.6f", rv.Interface())
|
||||
}
|
||||
} else {
|
||||
for _, t := range convertibleTypes {
|
||||
if rv.Type().ConvertibleTo(t) {
|
||||
@@ -136,7 +117,7 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a
|
||||
return
|
||||
}
|
||||
}
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, "\\"+escaper) + escaper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2  vendor/gorm.io/gorm/migrator.go  generated  vendored
@@ -87,8 +87,6 @@ type Migrator interface {
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error)
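The interface hunk adds or removes MigrateColumnUnique depending on the vendored gorm version; the rest of the Migrator interface is reached through db.Migrator(). A usage sketch with a hypothetical User model, not code from this commit:

    m := db.Migrator()
    if !m.HasTable(&User{}) {
        _ = m.CreateTable(&User{})
    }
    if m.HasColumn(&User{}, "Email") {
        _ = m.RenameColumn(&User{}, "Email", "EmailAddress")
    }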
241  vendor/gorm.io/gorm/migrator/migrator.go  generated  vendored
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -28,8 +27,6 @@ var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||
|
||||
// TODO:? Create const vars for raw sql queries ?
|
||||
|
||||
var _ gorm.Migrator = (*Migrator)(nil)
|
||||
|
||||
// Migrator m struct
|
||||
type Migrator struct {
|
||||
Config
|
||||
@@ -94,6 +91,10 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
expr.SQL += " NOT NULL"
|
||||
}
|
||||
|
||||
if field.Unique {
|
||||
expr.SQL += " UNIQUE"
|
||||
}
|
||||
|
||||
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
||||
if field.DefaultValueInterface != nil {
|
||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||
@@ -107,31 +108,21 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
|
||||
queryTx = m.DB.Session(&gorm.Session{})
|
||||
execTx = queryTx
|
||||
if m.DB.DryRun {
|
||||
queryTx.DryRun = false
|
||||
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
|
||||
}
|
||||
return queryTx, execTx
|
||||
}
|
||||
|
||||
// AutoMigrate auto migrate values
|
||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
queryTx, execTx := m.GetQueryAndExecTx()
|
||||
queryTx := m.DB.Session(&gorm.Session{})
|
||||
execTx := queryTx
|
||||
if m.DB.DryRun {
|
||||
queryTx.DryRun = false
|
||||
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
|
||||
}
|
||||
if !queryTx.Migrator().HasTable(value) {
|
||||
if err := execTx.Migrator().CreateTable(value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
|
||||
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -216,11 +207,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, false) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
|
||||
var (
|
||||
createTableSQL = "CREATE TABLE ? ("
|
||||
values = []interface{}{m.CurrentTable(stmt)}
|
||||
@@ -280,7 +266,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
sql, vars := constraint.Build()
|
||||
sql, vars := buildConstraint(constraint)
|
||||
createTableSQL += sql + ","
|
||||
values = append(values, vars...)
|
||||
}
|
||||
@@ -288,11 +274,6 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
|
||||
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||
@@ -373,9 +354,6 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
||||
func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// avoid using the same name field
|
||||
if stmt.Schema == nil {
|
||||
return errors.New("failed to get schema")
|
||||
}
|
||||
f := stmt.Schema.LookUpField(name)
|
||||
if f == nil {
|
||||
return fmt.Errorf("failed to look up field with name: %s", name)
|
||||
@@ -395,10 +373,8 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
// DropColumn drop value's `name` column
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
@@ -410,15 +386,13 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
// AlterColumn alter value's `field` column' type based on schema definition
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := m.FullDataTypeOf(field)
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := m.FullDataTypeOf(field)
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
@@ -430,10 +404,8 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
name := field
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
@@ -448,14 +420,12 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
// RenameColumn rename value's field name from oldName to newName
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
}
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
@@ -467,10 +437,6 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
||||
|
||||
// MigrateColumn migrate column
|
||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
if field.IgnoreMigration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// found, smart migrate
|
||||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
@@ -530,6 +496,14 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}

// check unique
if unique, ok := columnType.Unique(); ok && unique != field.Unique {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}

// check default value
if !field.PrimaryKey {
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
@@ -540,18 +514,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if currentDefaultNotNull || dvNotNull {
switch field.GORMDataType {
case schema.Time:
if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) {
alterColumn = true
}
case schema.Bool:
v1, _ := strconv.ParseBool(dv)
v2, _ := strconv.ParseBool(field.DefaultValue)
alterColumn = v1 != v2
default:
alterColumn = dv != field.DefaultValue
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
// default value not equal
// not both null
if currentDefaultNotNull || dvNotNull {
alterColumn = true
}
}
}
@@ -564,39 +532,13 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}
}

if alterColumn {
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}

if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
return err
if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.DBName)
}

return nil
}

func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// We're currently only receiving boolean values on `Unique` tag,
// so the UniqueConstraint name is fixed
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
if unique && !field.Unique {
return m.DB.Migrator().DropConstraint(value, constraint)
}
if !unique && field.Unique {
return m.DB.Migrator().CreateConstraint(value, constraint)
}
return nil
})
}

// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
@@ -666,36 +608,37 @@ func (m Migrator) DropView(name string) error {
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
}

// GuessConstraintAndTable guess statement's constraint and it's table based on name
//
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}

if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}

var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}

for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}

// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
// GuessConstraintAndTable guess statement's constraint and it's table based on name
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
if stmt.Schema == nil {
return nil, stmt.Table
return nil, nil, stmt.Table
}

checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return &chk, stmt.Table
}

uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
if uni, ok := uniqueConstraints[name]; ok {
return &uni, stmt.Table
return nil, &chk, stmt.Table
}

getTable := func(rel *schema.Relationship) string {
@@ -710,7 +653,7 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st

for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, getTable(rel)
return constraint, nil, getTable(rel)
}
}

@@ -718,39 +661,40 @@ func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name st
for k := range checkConstraints {
if checkConstraints[k].Field == field {
v := checkConstraints[k]
return &v, stmt.Table
}
}

for k := range uniqueConstraints {
if uniqueConstraints[k].Field == field {
v := uniqueConstraints[k]
return &v, stmt.Table
return nil, &v, stmt.Table
}
}

for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, getTable(rel)
return constraint, nil, getTable(rel)
}
}
}

return nil, stmt.Schema.Table
return nil, nil, stmt.Schema.Table
}

// CreateConstraint create constraint
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if chk != nil {
return m.DB.Exec(
"ALTER TABLE ? ADD CONSTRAINT ? CHECK (?)",
m.CurrentTable(stmt), clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint},
).Error
}

if constraint != nil {
vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr
}
sql, values := constraint.Build()
sql, values := buildConstraint(constraint)
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
}

return nil
})
}
@@ -758,9 +702,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
// DropConstraint drop constraint
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
})
@@ -771,9 +717,11 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
name = constraint.Name
} else if chk != nil {
name = chk.Name
}

return m.DB.Raw(
@@ -815,9 +763,6 @@ type BuildIndexOptionsInterface interface {
// CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema == nil {
return errors.New("failed to get schema")
}
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
@@ -850,10 +795,8 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
// DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}

return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
@@ -865,10 +808,8 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}

return m.DB.Raw(

26
vendor/gorm.io/gorm/prepare_stmt.go
generated
vendored
26
vendor/gorm.io/gorm/prepare_stmt.go
generated
vendored
@@ -3,8 +3,6 @@ package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"sync"
)
@@ -149,7 +147,7 @@ func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ..
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
@@ -163,7 +161,7 @@ func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args .
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()

@@ -182,14 +180,6 @@ func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, arg
return &sql.Row{}
}

func (db *PreparedStmtDB) Ping() error {
conn, err := db.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

type PreparedStmtTX struct {
Tx
PreparedStmtDB *PreparedStmtDB
@@ -217,7 +207,7 @@ func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ..
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

@@ -232,7 +222,7 @@ func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args .
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if errors.Is(err, driver.ErrBadConn) {
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()

@@ -250,11 +240,3 @@ func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, arg
}
return &sql.Row{}
}

func (tx *PreparedStmtTX) Ping() error {
conn, err := tx.GetDBConn()
if err != nil {
return err
}
return conn.Ping()
}

22
vendor/gorm.io/gorm/scan.go
generated
vendored
22
vendor/gorm.io/gorm/scan.go
generated
vendored
@@ -257,11 +257,9 @@ func Scan(rows Rows, db *DB, mode ScanMode) {
continue
}
}
var val interface{}
values[idx] = &val
values[idx] = &sql.RawBytes{}
} else {
var val interface{}
values[idx] = &val
values[idx] = &sql.RawBytes{}
}
}
}
@@ -276,16 +274,12 @@ func Scan(rows Rows, db *DB, mode ScanMode) {

if !update || reflectValue.Len() == 0 {
update = false
if isArrayKind {
db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
} else {
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
// if the slice cap is externally initialized, the externally initialized slice is directly used here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}

66
vendor/gorm.io/gorm/schema/constraint.go
generated
vendored
66
vendor/gorm.io/gorm/schema/constraint.go
generated
vendored
@@ -1,66 +0,0 @@
package schema

import (
"regexp"
"strings"

"gorm.io/gorm/clause"
)

// reg match english letters and midline
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)

type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}

func (chk *CheckConstraint) GetName() string { return chk.Name }

func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}

// ParseCheckConstraints parse schema check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}

type UniqueConstraint struct {
Name string
Field *Field
}

func (uni *UniqueConstraint) GetName() string { return uni.Name }

func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}

// ParseUniqueConstraints parse schema unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}
21
vendor/gorm.io/gorm/schema/field.go
generated
vendored
21
vendor/gorm.io/gorm/schema/field.go
generated
vendored
@@ -49,14 +49,11 @@ const (
Bytes DataType = "bytes"
)

const DefaultAutoIncrementIncrement int64 = 1

// Field is the representation of model schema's field
type Field struct {
Name string
DBName string
BindNames []string
EmbeddedBindNames []string
DataType DataType
GORMDataType DataType
PrimaryKey bool
@@ -90,12 +87,6 @@ type Field struct {
Set func(context.Context, reflect.Value, interface{}) error
Serializer SerializerInterface
NewValuePool FieldNewValuePool

// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
// It causes field unnecessarily migration.
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
UniqueIndex string
}

func (field *Field) BindName() string {
@@ -113,7 +104,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
EmbeddedBindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
@@ -129,7 +119,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
AutoIncrementIncrement: 1,
}

for field.IndirectFieldType.Kind() == reflect.Ptr {
@@ -405,9 +395,6 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
}
// index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
@@ -669,7 +656,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
@@ -678,7 +665,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
@@ -743,7 +730,7 @@ func (field *Field) setupValuerAndSetter() {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
}

6
vendor/gorm.io/gorm/schema/index.go
generated
vendored
6
vendor/gorm.io/gorm/schema/index.go
generated
vendored
@@ -13,8 +13,8 @@ type Index struct {
Type string // btree, hash, gist, spgist, gin, and brin
Where string
Comment string
Option string // WITH PARSER parser_name
Fields []IndexOption // Note: IndexOption's Field maybe the same
Option string // WITH PARSER parser_name
Fields []IndexOption
}

type IndexOption struct {
@@ -67,7 +67,7 @@ func (schema *Schema) ParseIndexes() map[string]Index {
}
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
index.Fields[0].Field.Unique = true
}
}
return indexes

6
vendor/gorm.io/gorm/schema/interfaces.go
generated
vendored
6
vendor/gorm.io/gorm/schema/interfaces.go
generated
vendored
@@ -4,12 +4,6 @@ import (
"gorm.io/gorm/clause"
)

// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}

// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string

8
vendor/gorm.io/gorm/schema/naming.go
generated
vendored
8
vendor/gorm.io/gorm/schema/naming.go
generated
vendored
@@ -19,7 +19,6 @@ type Namer interface {
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
UniqueName(table, column string) string
}

// Replacer replacer interface like strings.Replacer
@@ -27,8 +26,6 @@ type Replacer interface {
Replace(name string) string
}

var _ Namer = (*NamingStrategy)(nil)

// NamingStrategy tables, columns naming strategy
type NamingStrategy struct {
TablePrefix string
@@ -88,11 +85,6 @@ func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column))
}

// UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}

func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,

99
vendor/gorm.io/gorm/schema/relationship.go
generated
vendored
99
vendor/gorm.io/gorm/schema/relationship.go
generated
vendored
@@ -76,8 +76,8 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return nil
}

if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
schema.buildPolymorphicRelation(relation, field, polymorphic)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
@@ -89,8 +89,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name)
}
}

@@ -125,20 +124,6 @@ func (schema *Schema) parseRelation(field *Field) *Relationship {
return relation
}

// hasPolymorphicRelation check if has polymorphic relation
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}

_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]

return hasType && hasId
}

func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
@@ -150,12 +135,12 @@ func (schema *Schema) setRelation(relation *Relationship) {
}

// set embedded relation
if len(relation.Field.EmbeddedBindNames) <= 1 {
if len(relation.Field.BindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.EmbeddedBindNames {
if i < len(relation.Field.EmbeddedBindNames)-1 {
for i, name := range relation.Field.BindNames {
if i < len(relation.Field.BindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
@@ -184,41 +169,23 @@ func (schema *Schema) setRelation(relation *Relationship) {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]

func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) {
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
Value: schema.Table,
PolymorphicType: relation.FieldSchema.FieldsByName[polymorphic+"Type"],
PolymorphicID: relation.FieldSchema.FieldsByName[polymorphic+"ID"],
}

var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)

if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}

if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}

relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]

if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}

if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}

if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}

if schema.err == nil {
@@ -230,14 +197,12 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name)
}
}

if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name)
return
}

@@ -352,8 +317,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
Tag: `gorm:"-"`,
})

if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
@@ -472,8 +436,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name)
}
}

@@ -529,9 +492,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu

lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}

for _, name := range lookUpNames {
@@ -605,7 +566,6 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu
}
}

// Constraint is ForeignKey Constraint
type Constraint struct {
Name string
Field *Field
@@ -617,31 +577,6 @@ type Constraint struct {
OnUpdate string
}

func (constraint *Constraint) GetName() string { return constraint.Name }

func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}

if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}

foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}

references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}

func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {

4
vendor/gorm.io/gorm/schema/serializer.go
generated
vendored
4
vendor/gorm.io/gorm/schema/serializer.go
generated
vendored
@@ -126,12 +126,12 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
result = time.Unix(reflect.Indirect(rv).Int(), 0)
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
result = time.Unix(reflect.Indirect(rv).Int(), 0)
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}

38
vendor/gorm.io/gorm/statement.go
generated
vendored
38
vendor/gorm.io/gorm/statement.go
generated
vendored
@@ -326,7 +326,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
case *DB:
v.executeScopes()

if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
@@ -334,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
} else {
conds = append(conds, cs.Expression)
}
if v.Statement == stmt {
cs.Expression = nil
stmt.Statement.Clauses["WHERE"] = cs
}
}
case map[interface{}]interface{}:
for i, j := range v {
@@ -447,9 +451,8 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []

if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
return []clause.Expression{clause.And(conds...)}
}
return nil
return conds
}
}

@@ -458,10 +461,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []
}
}

if len(conds) > 0 {
return []clause.Expression{clause.And(conds...)}
}
return nil
return conds
}

// Build build sql with clauses names
@@ -665,21 +665,7 @@ func (stmt *Statement) Changed(fields ...string) bool {
return false
}

var matchName = func() func(tableColumn string) (table, column string) {
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
return func(tableColumn string) (table, column string) {
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
table = matches[1]
star := matches[2]
columnName := matches[3]
if star != "" {
return table, star
}
return table, columnName
}
return "", ""
}
}()
var nameMatcher = regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?\W?(\w+?)\W?$`)

// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
@@ -700,13 +686,13 @@ func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = result
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
if col == "*" {
} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 3 && (matches[1] == stmt.Table || matches[1] == "") {
if matches[2] == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
results[matches[2]] = result
}
} else {
results[column] = result

21
vendor/gorm.io/gorm/utils/utils.go
generated
vendored
21
vendor/gorm.io/gorm/utils/utils.go
generated
vendored
@@ -32,16 +32,11 @@ func sourceDir(file string) string {

// FileWithLineNum return the file name and line number of the current file
func FileWithLineNum() string {
pcs := [13]uintptr{}
// the third caller usually from gorm internal
len := runtime.Callers(3, pcs[:])
frames := runtime.CallersFrames(pcs[:len])
for i := 0; i < len; i++ {
// second return value is "more", not "ok"
frame, _ := frames.Next()
if (!strings.HasPrefix(frame.File, gormSourceDir) ||
strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") {
return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10))
// the second caller usually from gorm internal, so set i start from 2
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) {
return file + ":" + strconv.FormatInt(int64(line), 10)
}
}

@@ -78,11 +73,7 @@ func ToStringKey(values ...interface{}) string {
case uint:
results[idx] = strconv.FormatUint(uint64(v), 10)
default:
results[idx] = "nil"
vv := reflect.ValueOf(v)
if vv.IsValid() && !vv.IsZero() {
results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface())
}
results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface())
}
}
