1. Implement configuration file parsing

2. Implement database connection
2025-09-07 17:50:04 +08:00
parent b8229846d1
commit cf00a74008
365 changed files with 181226 additions and 0 deletions

1
vendor/gorm.io/driver/postgres/.gitignore generated vendored Normal file

@@ -0,0 +1 @@
.idea

21
vendor/gorm.io/driver/postgres/License generated vendored Normal file

@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

31
vendor/gorm.io/driver/postgres/README.md generated vendored Normal file

@@ -0,0 +1,31 @@
# GORM PostgreSQL Driver
## Quick Start
```go
import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// https://github.com/jackc/pgx
dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai"
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
```
## Configuration
```go
import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
db, err := gorm.Open(postgres.New(postgres.Config{
DSN: "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable TimeZone=Asia/Shanghai", // data source name, refer https://github.com/jackc/pgx
PreferSimpleProtocol: true, // disables implicit prepared statement usage. By default pgx automatically uses the extended protocol
}), &gorm.Config{})
```
Check out [https://gorm.io](https://gorm.io) for details.
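Editorial note, not part of the vendored README: besides a DSN, the `Config` struct added in `postgres.go` below also accepts an existing connection through its `Conn` field. A minimal sketch, assuming the pgx `database/sql` driver is registered and using a placeholder DSN:

```go
package main

import (
	"database/sql"
	"log"

	_ "github.com/jackc/pgx/v5/stdlib" // registers the "pgx" database/sql driver
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

func main() {
	// Open a plain *sql.DB first (e.g. to tune the pool yourself), then hand it
	// to GORM through Config.Conn instead of a DSN. The DSN below is a placeholder.
	sqlDB, err := sql.Open("pgx", "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	sqlDB.SetMaxOpenConns(10)

	db, err := gorm.Open(postgres.New(postgres.Config{Conn: sqlDB}), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}
	_ = db
}
```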

48
vendor/gorm.io/driver/postgres/error_translator.go generated vendored Normal file

@@ -0,0 +1,48 @@
package postgres
import (
"encoding/json"
"gorm.io/gorm"
"github.com/jackc/pgx/v5/pgconn"
)
var errCodes = map[string]error{
"23505": gorm.ErrDuplicatedKey,
"23503": gorm.ErrForeignKeyViolated,
"42703": gorm.ErrInvalidField,
}
type ErrMessage struct {
Code string
Severity string
Message string
}
// Translate converts a driver error into the corresponding native gorm error.
// Since gorm currently supports both the pgx and pg drivers, checking only for the pgx PgError type is not enough to translate errors, so there is an additional fallback that marshals the error to JSON and reads the code from it.
func (dialector Dialector) Translate(err error) error {
if pgErr, ok := err.(*pgconn.PgError); ok {
if translatedErr, found := errCodes[pgErr.Code]; found {
return translatedErr
}
return err
}
parsedErr, marshalErr := json.Marshal(err)
if marshalErr != nil {
return err
}
var errMsg ErrMessage
unmarshalErr := json.Unmarshal(parsedErr, &errMsg)
if unmarshalErr != nil {
return err
}
if translatedErr, found := errCodes[errMsg.Code]; found {
return translatedErr
}
return err
}
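Editorial illustration, not part of the vendored file: the error codes mapped above surface to callers as gorm sentinel errors once translation is enabled with `gorm.Config{TranslateError: true}`. A minimal sketch, assuming a hypothetical `User` model and a placeholder DSN:

```go
package main

import (
	"errors"
	"log"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

// User is a hypothetical model; the unique index triggers SQLSTATE 23505 on duplicates.
type User struct {
	ID    uint
	Email string `gorm:"uniqueIndex"`
}

func main() {
	dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable"
	// TranslateError routes driver errors through Dialector.Translate above.
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{TranslateError: true})
	if err != nil {
		log.Fatal(err)
	}
	if err := db.AutoMigrate(&User{}); err != nil {
		log.Fatal(err)
	}

	db.Create(&User{Email: "a@example.com"})
	// The second insert violates the unique index; Translate maps 23505 to gorm.ErrDuplicatedKey.
	if err := db.Create(&User{Email: "a@example.com"}).Error; errors.Is(err, gorm.ErrDuplicatedKey) {
		log.Println("duplicate key detected")
	}
}
```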

794
vendor/gorm.io/driver/postgres/migrator.go generated vendored Normal file

@@ -0,0 +1,794 @@
package postgres
import (
"database/sql"
"fmt"
"regexp"
"strings"
"github.com/jackc/pgx/v5"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
// See https://stackoverflow.com/questions/2204058/list-columns-with-indexes-in-postgresql
// Here are some changes:
// - use `LEFT JOIN` instead of `CROSS JOIN`
// - exclude indexes used to support constraints (they are auto-generated)
const indexSql = `
SELECT
ct.relname AS table_name,
ci.relname AS index_name,
i.indisunique AS non_unique,
i.indisprimary AS primary,
a.attname AS column_name
FROM
pg_index i
LEFT JOIN pg_class ct ON ct.oid = i.indrelid
LEFT JOIN pg_class ci ON ci.oid = i.indexrelid
LEFT JOIN pg_attribute a ON a.attrelid = ct.oid
LEFT JOIN pg_constraint con ON con.conindid = i.indexrelid
WHERE
a.attnum = ANY(i.indkey)
AND con.oid IS NULL
AND ct.relkind = 'r'
AND ct.relname = ?
`
var typeAliasMap = map[string][]string{
"int2": {"smallint"},
"int4": {"integer"},
"int8": {"bigint"},
"smallint": {"int2"},
"integer": {"int4"},
"bigint": {"int8"},
"decimal": {"numeric"},
"numeric": {"decimal"},
"timestamptz": {"timestamp with time zone"},
"timestamp with time zone": {"timestamptz"},
"bool": {"boolean"},
"boolean": {"bool"},
}
type Migrator struct {
migrator.Migrator
}
// select queries ignore DryRun
func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) {
queryTx := m.DB
if m.DB.DryRun {
queryTx = m.DB.Session(&gorm.Session{})
queryTx.DryRun = false
}
return queryTx.Raw(sql, values...)
}
func (m Migrator) CurrentDatabase() (name string) {
m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
}
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.queryRaw(
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
).Scan(&count).Error
})
return count > 0
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX "
if strings.TrimSpace(strings.ToUpper(idx.Option)) == "CONCURRENTLY" {
createIndexSQL += "CONCURRENTLY "
}
createIndexSQL += "IF NOT EXISTS ? ON ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type + "(?)"
} else {
createIndexSQL += " ?"
}
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
}
}
return fmt.Errorf("failed to create index with name %v", name)
})
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
"ALTER INDEX ? RENAME TO ?",
clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
})
}
func (m Migrator) GetTables() (tableList []string, err error) {
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
}
func (m Migrator) CreateTable(values ...interface{}) (err error) {
if err = m.Migrator.CreateTable(values...); err != nil {
return
}
for _, value := range m.ReorderModels(values, false) {
if err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
for _, fieldName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[fieldName]
if field.Comment != "" {
if err := m.DB.Exec(
"COMMENT ON COLUMN ?.? IS ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
}
}
}
}
return nil
}); err != nil {
return
}
}
return
}
func (m Migrator) HasTable(value interface{}) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
})
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ? CASCADE", m.CurrentTable(stmt)).Error
}); err != nil {
return err
}
}
return nil
}
func (m Migrator) AddColumn(value interface{}, field string) error {
if err := m.Migrator.AddColumn(value, field); err != nil {
return err
}
m.resetPreparedStmts()
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
if field.Comment != "" {
if err := m.DB.Exec(
"COMMENT ON COLUMN ?.? IS ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
}
}
}
}
return nil
})
}
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
name := field
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
}
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
return m.queryRaw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentSchema, curTable, name,
).Scan(&count).Error
})
return count > 0
}
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
// skip primary field
if !field.PrimaryKey {
if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil {
return err
}
}
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var description string
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
values := []interface{}{currentSchema, curTable, field.DBName, stmt.Table, currentSchema}
checkSQL := "SELECT description FROM pg_catalog.pg_description "
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
m.queryRaw(checkSQL, values...).Scan(&description)
comment := strings.Trim(field.Comment, "'")
comment = strings.Trim(comment, `"`)
if field.Comment != "" && comment != description {
if err := m.DB.Exec(
"COMMENT ON COLUMN ?.? IS ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, gorm.Expr(m.Migrator.Dialector.Explain("$1", field.Comment)),
).Error; err != nil {
return err
}
}
return nil
})
}
// AlterColumn alters value's `field` column type based on the schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(field); field != nil {
var (
columnTypes, _ = m.DB.Migrator().ColumnTypes(value)
fieldColumnType *migrator.ColumnType
)
for _, columnType := range columnTypes {
if columnType.Name() == field.DBName {
fieldColumnType, _ = columnType.(*migrator.ColumnType)
}
}
fileType := clause.Expr{SQL: m.DataTypeOf(field)}
// check for typeName and SQL name
isSameType := true
if fieldColumnType.DatabaseTypeName() != fileType.SQL {
isSameType = false
// if different, also check for aliases
aliases := m.GetTypeAliases(fieldColumnType.DatabaseTypeName())
for _, alias := range aliases {
if strings.HasPrefix(fileType.SQL, alias) {
isSameType = true
break
}
}
}
// not same, migrate
if !isSameType {
filedColumnAutoIncrement, _ := fieldColumnType.AutoIncrement()
if field.AutoIncrement && filedColumnAutoIncrement { // update
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if t, _ := fieldColumnType.ColumnType(); t != serialDatabaseType {
if err := m.UpdateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
return err
}
}
} else if field.AutoIncrement && !filedColumnAutoIncrement { // create
serialDatabaseType, _ := getSerialDatabaseType(fileType.SQL)
if err := m.CreateSequence(m.DB, stmt, field, serialDatabaseType); err != nil {
return err
}
} else if !field.AutoIncrement && filedColumnAutoIncrement { // delete
if err := m.DeleteSequence(m.DB, stmt, field, fileType); err != nil {
return err
}
} else {
if err := m.modifyColumn(stmt, field, fileType, fieldColumnType); err != nil {
return err
}
}
}
if null, _ := fieldColumnType.Nullable(); null == field.NotNull {
if field.NotNull {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP NOT NULL", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
}
if v, ok := fieldColumnType.DefaultValue(); (field.DefaultValueInterface == nil && ok) || v != field.DefaultValue {
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)}).Error; err != nil {
return err
}
} else if field.DefaultValue != "(-)" {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
} else {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, clause.Expr{SQL: field.DefaultValue}).Error; err != nil {
return err
}
}
}
}
return nil
}
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
if err != nil {
return err
}
m.resetPreparedStmts()
return nil
}
func (m Migrator) modifyColumn(stmt *gorm.Statement, field *schema.Field, targetType clause.Expr, existingColumn *migrator.ColumnType) error {
alterSQL := "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::?"
isUncastableDefaultValue := false
if targetType.SQL == "boolean" {
switch existingColumn.DatabaseTypeName() {
case "int2", "int8", "numeric":
alterSQL = "ALTER TABLE ? ALTER COLUMN ? TYPE ? USING ?::int::?"
}
isUncastableDefaultValue = true
}
if dv, _ := existingColumn.DefaultValue(); dv != "" && isUncastableDefaultValue {
if err := m.DB.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT", m.CurrentTable(stmt), clause.Column{Name: field.DBName}).Error; err != nil {
return err
}
}
if err := m.DB.Exec(alterSQL, m.CurrentTable(stmt), clause.Column{Name: field.DBName}, targetType, clause.Column{Name: field.DBName}, targetType).Error; err != nil {
return err
}
return nil
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
}
currentSchema, curTable := m.CurrentSchema(stmt, table)
return m.queryRaw(
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
currentSchema, curTable, name,
).Scan(&count).Error
})
return count > 0
}
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
columnTypes = make([]gorm.ColumnType, 0)
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
var (
currentDatabase = m.DB.Migrator().CurrentDatabase()
currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
columns, err = m.queryRaw(
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
currentDatabase, currentSchema, table).Rows()
)
if err != nil {
return err
}
for columns.Next() {
var (
column = &migrator.ColumnType{
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
}
datetimePrecision sql.NullInt64
radixValue sql.NullInt64
typeLenValue sql.NullInt64
identityIncrement sql.NullString
)
err = columns.Scan(
&column.NameValue, &column.NullableValue, &column.DataTypeValue, &column.LengthValue, &column.DecimalSizeValue,
&radixValue, &column.ScaleValue, &datetimePrecision, &typeLenValue, &column.DefaultValueValue, &column.CommentValue, &identityIncrement,
)
if err != nil {
return err
}
if typeLenValue.Valid && typeLenValue.Int64 > 0 {
column.LengthValue = typeLenValue
}
if (strings.HasPrefix(column.DefaultValueValue.String, "nextval('") &&
strings.HasSuffix(column.DefaultValueValue.String, "seq'::regclass)")) || (identityIncrement.Valid && identityIncrement.String != "") {
column.AutoIncrementValue = sql.NullBool{Bool: true, Valid: true}
column.DefaultValueValue = sql.NullString{}
}
if column.DefaultValueValue.Valid {
column.DefaultValueValue.String = parseDefaultValueValue(column.DefaultValueValue.String)
}
if datetimePrecision.Valid {
column.DecimalSizeValue = datetimePrecision
}
columnTypes = append(columnTypes, column)
}
columns.Close()
// assign sql column type
{
rows, rowsErr := m.GetRows(currentSchema, table)
if rowsErr != nil {
return rowsErr
}
rawColumnTypes, err := rows.ColumnTypes()
if err != nil {
return err
}
for _, columnType := range columnTypes {
for _, c := range rawColumnTypes {
if c.Name() == columnType.Name() {
columnType.(*migrator.ColumnType).SQLColumnType = c
break
}
}
}
rows.Close()
}
// check primary, unique field
{
columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
if err != nil {
return err
}
uniqueContraints := map[string]int{}
for columnTypeRows.Next() {
var constraintName string
columnTypeRows.Scan(&constraintName)
uniqueContraints[constraintName]++
}
columnTypeRows.Close()
columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
if err != nil {
return err
}
for columnTypeRows.Next() {
var name, constraintName, columnType string
columnTypeRows.Scan(&name, &constraintName, &columnType)
for _, c := range columnTypes {
mc := c.(*migrator.ColumnType)
if mc.NameValue.String == name {
switch columnType {
case "PRIMARY KEY":
mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
case "UNIQUE":
if uniqueContraints[constraintName] == 1 {
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
}
break
}
}
}
columnTypeRows.Close()
}
// check column type
{
dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
WHERE a.attnum > 0 -- hide internal columns
AND NOT a.attisdropped -- hide deleted columns
AND b.relname = ?`, currentSchema, table).Rows()
if err != nil {
return err
}
for dataTypeRows.Next() {
var name, dataType string
dataTypeRows.Scan(&name, &dataType)
for _, c := range columnTypes {
mc := c.(*migrator.ColumnType)
if mc.NameValue.String == name {
mc.ColumnTypeValue = sql.NullString{String: dataType, Valid: true}
// Handle array type: _text -> text[] , _int4 -> integer[]
// Array size limits are not supported because:
// https://www.postgresql.org/docs/current/arrays.html#ARRAYS-DECLARATION
if strings.HasPrefix(mc.DataTypeValue.String, "_") {
mc.DataTypeValue = sql.NullString{String: dataType, Valid: true}
}
break
}
}
}
dataTypeRows.Close()
}
return err
})
return
}
func (m Migrator) GetRows(currentSchema interface{}, table interface{}) (*sql.Rows, error) {
name := table.(string)
if _, ok := currentSchema.(string); ok {
name = fmt.Sprintf("%v.%v", currentSchema, table)
}
return m.DB.Session(&gorm.Session{}).Table(name).Limit(1).Scopes(func(d *gorm.DB) *gorm.DB {
dialector, _ := m.Dialector.(Dialector)
// use simple protocol
if !m.DB.PrepareStmt && (dialector.Config != nil && (dialector.Config.DriverName == "" || dialector.Config.DriverName == "pgx")) {
d.Statement.Vars = append([]interface{}{pgx.QueryExecModeSimpleProtocol}, d.Statement.Vars...)
}
return d
}).Rows()
}
func (m Migrator) CurrentSchema(stmt *gorm.Statement, table string) (interface{}, interface{}) {
if strings.Contains(table, ".") {
if tables := strings.Split(table, `.`); len(tables) == 2 {
return tables[0], tables[1]
}
}
if stmt.TableExpr != nil {
if tables := strings.Split(stmt.TableExpr.SQL, `"."`); len(tables) == 2 {
return strings.TrimPrefix(tables[0], `"`), table
}
}
return clause.Expr{SQL: "CURRENT_SCHEMA()"}, table
}
func (m Migrator) CreateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
serialDatabaseType string) (err error) {
_, table := m.CurrentSchema(stmt, stmt.Table)
tableName := table.(string)
sequenceName := strings.Join([]string{tableName, field.DBName, "seq"}, "_")
if err = tx.Exec(`CREATE SEQUENCE IF NOT EXISTS ? AS ?`, clause.Expr{SQL: sequenceName},
clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
return err
}
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? SET DEFAULT nextval('?')",
clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}, clause.Expr{SQL: sequenceName}).Error; err != nil {
return err
}
if err := tx.Exec("ALTER SEQUENCE ? OWNED BY ?.?",
clause.Expr{SQL: sequenceName}, clause.Expr{SQL: tableName}, clause.Expr{SQL: field.DBName}).Error; err != nil {
return err
}
return
}
func (m Migrator) UpdateSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
serialDatabaseType string) (err error) {
sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
if err != nil {
return err
}
if err = tx.Exec(`ALTER SEQUENCE IF EXISTS ? AS ?`, clause.Expr{SQL: sequenceName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
return err
}
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}, clause.Expr{SQL: serialDatabaseType}).Error; err != nil {
return err
}
return
}
func (m Migrator) DeleteSequence(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field,
fileType clause.Expr) (err error) {
sequenceName, err := m.getColumnSequenceName(tx, stmt, field)
if err != nil {
return err
}
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType).Error; err != nil {
return err
}
if err := tx.Exec("ALTER TABLE ? ALTER COLUMN ? DROP DEFAULT",
m.CurrentTable(stmt), clause.Expr{SQL: field.DBName}).Error; err != nil {
return err
}
if err = tx.Exec(`DROP SEQUENCE IF EXISTS ?`, clause.Expr{SQL: sequenceName}).Error; err != nil {
return err
}
return
}
func (m Migrator) getColumnSequenceName(tx *gorm.DB, stmt *gorm.Statement, field *schema.Field) (
sequenceName string, err error) {
_, table := m.CurrentSchema(stmt, stmt.Table)
// DefaultValueValue is reset by ColumnTypes, search again.
var columnDefault string
err = tx.Raw(
`SELECT column_default FROM information_schema.columns WHERE table_name = ? AND column_name = ?`,
table, field.DBName).Scan(&columnDefault).Error
if err != nil {
return
}
sequenceName = strings.TrimSuffix(
strings.TrimPrefix(columnDefault, `nextval('`),
`'::regclass)`,
)
return
}
func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
indexes := make([]gorm.Index, 0)
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
result := make([]*Index, 0)
scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error
if scanErr != nil {
return scanErr
}
indexMap := groupByIndexName(result)
for _, idx := range indexMap {
tempIdx := &migrator.Index{
TableName: idx[0].TableName,
NameValue: idx[0].IndexName,
PrimaryKeyValue: sql.NullBool{
Bool: idx[0].Primary,
Valid: true,
},
UniqueValue: sql.NullBool{
Bool: idx[0].NonUnique,
Valid: true,
},
}
for _, x := range idx {
tempIdx.ColumnList = append(tempIdx.ColumnList, x.ColumnName)
}
indexes = append(indexes, tempIdx)
}
return nil
})
return indexes, err
}
// Index table index info
type Index struct {
TableName string `gorm:"column:table_name"`
ColumnName string `gorm:"column:column_name"`
IndexName string `gorm:"column:index_name"`
NonUnique bool `gorm:"column:non_unique"`
Primary bool `gorm:"column:primary"`
}
func groupByIndexName(indexList []*Index) map[string][]*Index {
columnIndexMap := make(map[string][]*Index, len(indexList))
for _, idx := range indexList {
columnIndexMap[idx.IndexName] = append(columnIndexMap[idx.IndexName], idx)
}
return columnIndexMap
}
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return typeAliasMap[databaseTypeName]
}
// should reset prepared stmts when table changed
func (m Migrator) resetPreparedStmts() {
if m.DB.PrepareStmt {
if pdb, ok := m.DB.ConnPool.(*gorm.PreparedStmtDB); ok {
pdb.Reset()
}
}
}
func (m Migrator) DropColumn(dst interface{}, field string) error {
if err := m.Migrator.DropColumn(dst, field); err != nil {
return err
}
m.resetPreparedStmts()
return nil
}
func (m Migrator) RenameColumn(dst interface{}, oldName, field string) error {
if err := m.Migrator.RenameColumn(dst, oldName, field); err != nil {
return err
}
m.resetPreparedStmts()
return nil
}
func parseDefaultValueValue(defaultValue string) string {
value := regexp.MustCompile(`^(.*?)(?:::.*)?$`).ReplaceAllString(defaultValue, "$1")
return strings.Trim(value, "'")
}
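Editorial illustration, not part of the vendored file: the methods above back GORM's public `db.Migrator()` interface. A minimal sketch of typical migrator calls, assuming a hypothetical `Product` model and a placeholder DSN:

```go
package main

import (
	"fmt"
	"log"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

// Product is a hypothetical model used only for this sketch.
type Product struct {
	ID   uint
	Code string `gorm:"index:idx_products_code"`
	Name string `gorm:"comment:display name"` // applied via COMMENT ON COLUMN in CreateTable
}

func main() {
	dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable"
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}

	m := db.Migrator()
	if !m.HasTable(&Product{}) {
		if err := m.CreateTable(&Product{}); err != nil {
			log.Fatal(err)
		}
	}
	fmt.Println(m.HasIndex(&Product{}, "idx_products_code")) // resolved against pg_indexes

	cols, err := m.ColumnTypes(&Product{})
	if err != nil {
		log.Fatal(err)
	}
	for _, c := range cols {
		t, _ := c.ColumnType()
		fmt.Println(c.Name(), t)
	}
}
```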

289
vendor/gorm.io/driver/postgres/postgres.go generated vendored Normal file

@@ -0,0 +1,289 @@
package postgres
import (
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/stdlib"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Dialector struct {
*Config
}
type Config struct {
DriverName string
DSN string
WithoutQuotingCheck bool
PreferSimpleProtocol bool
WithoutReturning bool
Conn gorm.ConnPool
}
var (
timeZoneMatcher = regexp.MustCompile("(time_zone|TimeZone)=(.*?)($|&| )")
defaultIdentifierLength = 63 //maximum identifier length for postgres
)
func Open(dsn string) gorm.Dialector {
return &Dialector{&Config{DSN: dsn}}
}
func New(config Config) gorm.Dialector {
return &Dialector{Config: &config}
}
func (dialector Dialector) Name() string {
return "postgres"
}
func (dialector Dialector) Apply(config *gorm.Config) error {
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{
IdentifierMaxLength: defaultIdentifierLength,
}
return nil
}
switch v := config.NamingStrategy.(type) {
case *schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
}
case schema.NamingStrategy:
if v.IdentifierMaxLength <= 0 {
v.IdentifierMaxLength = defaultIdentifierLength
config.NamingStrategy = v
}
}
return nil
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
callbackConfig := &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"},
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE"},
}
// register callbacks
if !dialector.WithoutReturning {
callbackConfig.CreateClauses = append(callbackConfig.CreateClauses, "RETURNING")
callbackConfig.UpdateClauses = append(callbackConfig.UpdateClauses, "RETURNING")
callbackConfig.DeleteClauses = append(callbackConfig.DeleteClauses, "RETURNING")
}
callbacks.RegisterDefaultCallbacks(db, callbackConfig)
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else if dialector.DriverName != "" {
db.ConnPool, err = sql.Open(dialector.DriverName, dialector.Config.DSN)
} else {
var config *pgx.ConnConfig
config, err = pgx.ParseConfig(dialector.Config.DSN)
if err != nil {
return
}
if dialector.Config.PreferSimpleProtocol {
config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
}
result := timeZoneMatcher.FindStringSubmatch(dialector.Config.DSN)
if len(result) > 2 {
config.RuntimeParams["timezone"] = result[2]
}
db.ConnPool = stdlib.OpenDB(*config)
}
return
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
return clause.Expr{SQL: "DEFAULT"}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('$')
index := 0
varLen := len(stmt.Vars)
if varLen > 0 {
switch stmt.Vars[0].(type) {
case pgx.QueryExecMode:
index++
}
}
writer.WriteString(strconv.Itoa(varLen - index))
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
if dialector.WithoutQuotingCheck {
writer.WriteString(str)
return
}
var (
underQuoted, selfQuoted bool
continuousBacktick int8
shiftDelimiter int8
)
for _, v := range []byte(str) {
switch v {
case '"':
continuousBacktick++
if continuousBacktick == 2 {
writer.WriteString(`""`)
continuousBacktick = 0
}
case '.':
if continuousBacktick > 0 || !selfQuoted {
shiftDelimiter = 0
underQuoted = false
continuousBacktick = 0
writer.WriteByte('"')
}
writer.WriteByte(v)
continue
default:
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
writer.WriteByte('"')
underQuoted = true
if selfQuoted = continuousBacktick > 0; selfQuoted {
continuousBacktick -= 1
}
}
for ; continuousBacktick > 0; continuousBacktick -= 1 {
writer.WriteString(`""`)
}
writer.WriteByte(v)
}
shiftDelimiter++
}
if continuousBacktick > 0 && !selfQuoted {
writer.WriteString(`""`)
}
writer.WriteByte('"')
}
var numericPlaceholder = regexp.MustCompile(`\$(\d+)`)
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, numericPlaceholder, `'`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "boolean"
case schema.Int, schema.Uint:
size := field.Size
if field.DataType == schema.Uint {
size++
}
if field.AutoIncrement {
switch {
case size <= 16:
return "smallserial"
case size <= 32:
return "serial"
default:
return "bigserial"
}
} else {
switch {
case size <= 16:
return "smallint"
case size <= 32:
return "integer"
default:
return "bigint"
}
}
case schema.Float:
if field.Precision > 0 {
if field.Scale > 0 {
return fmt.Sprintf("numeric(%d, %d)", field.Precision, field.Scale)
}
return fmt.Sprintf("numeric(%d)", field.Precision)
}
return "decimal"
case schema.String:
if field.Size > 0 {
return fmt.Sprintf("varchar(%d)", field.Size)
}
return "text"
case schema.Time:
if field.Precision > 0 {
return fmt.Sprintf("timestamptz(%d)", field.Precision)
}
return "timestamptz"
case schema.Bytes:
return "bytea"
default:
return dialector.getSchemaCustomType(field)
}
}
func (dialector Dialector) getSchemaCustomType(field *schema.Field) string {
sqlType := string(field.DataType)
if field.AutoIncrement && !strings.Contains(strings.ToLower(sqlType), "serial") {
size := field.Size
if field.GORMDataType == schema.Uint {
size++
}
switch {
case size <= 16:
sqlType = "smallserial"
case size <= 32:
sqlType = "serial"
default:
sqlType = "bigserial"
}
}
return sqlType
}
func (dialector Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return nil
}
func (dialector Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}
func getSerialDatabaseType(s string) (dbType string, ok bool) {
switch s {
case "smallserial":
return "smallint", true
case "serial":
return "integer", true
case "bigserial":
return "bigint", true
default:
return "", false
}
}
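Editorial illustration, not part of the vendored file: `DataTypeOf` above decides the PostgreSQL column type from the parsed schema field. A minimal sketch that prints the mapping for a hypothetical `Sample` model (expected types are noted in comments; integer sizes assume a 64-bit platform):

```go
package main

import (
	"fmt"
	"log"
	"sync"

	"gorm.io/driver/postgres"
	"gorm.io/gorm/schema"
)

// Sample is a hypothetical model whose fields exercise the switch in DataTypeOf.
type Sample struct {
	ID      uint    // auto-increment primary key -> bigserial
	Flag    bool    // -> boolean
	Small   int16   // size 16 -> smallint
	Count   int     // -> bigint on 64-bit platforms
	Price   float64 `gorm:"precision:10;scale:2"` // -> numeric(10, 2)
	Title   string  `gorm:"size:255"`             // -> varchar(255)
	Body    string  // no size -> text
	Payload []byte  // -> bytea
}

func main() {
	s, err := schema.Parse(&Sample{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		log.Fatal(err)
	}
	d := postgres.Dialector{Config: &postgres.Config{}}
	for _, f := range s.Fields {
		fmt.Printf("%-8s -> %s\n", f.DBName, d.DataTypeOf(f))
	}
}
```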

7
vendor/gorm.io/gorm/.gitignore generated vendored Normal file

@@ -0,0 +1,7 @@
TODO*
documents
coverage.txt
_book
.idea
vendor
.vscode

20
vendor/gorm.io/gorm/.golangci.yml generated vendored Normal file

@@ -0,0 +1,20 @@
linters:
enable:
- cyclop
- exportloopref
- gocritic
- gosec
- ineffassign
- misspell
- prealloc
- unconvert
- unparam
- goimports
- whitespace
linters-settings:
whitespace:
multi-func: true
goimports:
local-prefixes: gorm.io/gorm

21
vendor/gorm.io/gorm/LICENSE generated vendored Normal file

@@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

44
vendor/gorm.io/gorm/README.md generated vendored Normal file

@@ -0,0 +1,44 @@
# GORM
The fantastic ORM library for Golang, aiming to be developer friendly.
[![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm)
[![test status](https://github.com/go-gorm/gorm/workflows/tests/badge.svg?branch=master "test status")](https://github.com/go-gorm/gorm/actions)
[![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT)
[![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc)
## Overview
* Full-Featured ORM
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
* Hooks (Before/After Create/Save/Update/Delete/Find)
* Eager loading with `Preload`, `Joins`
* Transactions, Nested Transactions, Save Point, RollbackTo Saved Point
* Context, Prepared Statement Mode, DryRun Mode
* Batch Insert, FindInBatches, Find To Map
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
* Composite Primary Key
* Auto Migrations
* Logger
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
* Every feature comes with tests
* Developer Friendly
## Getting Started
* GORM Guides [https://gorm.io](https://gorm.io)
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
## Contributing
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
## Contributors
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
## License
© Jinzhu, 2013~time.Now
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
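Editorial illustration, not part of the vendored README: a minimal end-to-end sketch of the workflow the feature list above describes (auto migration plus basic CRUD), using the postgres driver added in this commit, a placeholder DSN, and a hypothetical `User` model:

```go
package main

import (
	"log"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

// User is a hypothetical model for this sketch.
type User struct {
	ID   uint
	Name string
	Age  int
}

func main() {
	dsn := "host=localhost user=gorm password=gorm dbname=gorm port=9920 sslmode=disable"
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		log.Fatal(err)
	}

	// Auto Migration keeps the users table in sync with the struct.
	if err := db.AutoMigrate(&User{}); err != nil {
		log.Fatal(err)
	}

	db.Create(&User{Name: "jinzhu", Age: 18}) // INSERT

	var user User
	db.First(&user, "name = ?", "jinzhu") // SELECT ... LIMIT 1

	db.Model(&user).Update("Age", 20) // UPDATE
	db.Delete(&user)                  // hard delete here; soft delete requires a gorm.DeletedAt field
}
```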

579
vendor/gorm.io/gorm/association.go generated vendored Normal file

@@ -0,0 +1,579 @@
package gorm
import (
"fmt"
"reflect"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// Association Mode contains helper methods to handle relationships easily.
type Association struct {
DB *DB
Relationship *schema.Relationship
Unscope bool
Error error
}
func (db *DB) Association(column string) *Association {
association := &Association{DB: db}
table := db.Statement.Table
if err := db.Statement.Parse(db.Statement.Model); err == nil {
db.Statement.Table = table
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
if association.Relationship == nil {
association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
}
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
}
} else {
association.Error = err
}
return association
}
func (association *Association) Unscoped() *Association {
return &Association{
DB: association.DB,
Relationship: association.Relationship,
Error: association.Error,
Unscope: true,
}
}
func (association *Association) Find(out interface{}, conds ...interface{}) error {
if association.Error == nil {
association.Error = association.buildCondition().Find(out, conds...).Error
}
return association.Error
}
func (association *Association) Append(values ...interface{}) error {
if association.Error == nil {
switch association.Relationship.Type {
case schema.HasOne, schema.BelongsTo:
if len(values) > 0 {
association.Error = association.Replace(values...)
}
default:
association.saveAssociation( /*clear*/ false, values...)
}
}
return association.Error
}
func (association *Association) Replace(values ...interface{}) error {
if association.Error == nil {
reflectValue := association.DB.Statement.ReflectValue
rel := association.Relationship
var oldBelongsToExpr clause.Expression
// we have to record the old BelongsTo value
if association.Unscope && rel.Type == schema.BelongsTo {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
oldBelongsToExpr = clause.IN{Column: column, Values: values}
}
}
// save associations
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
return association.Error
}
// set old associations' foreign keys to null
switch rel.Type {
case schema.BelongsTo:
if len(values) == 0 {
updateMap := map[string]interface{}{}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
}
case reflect.Struct:
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
}
for _, ref := range rel.References {
updateMap[ref.ForeignKey.DBName] = nil
}
association.Error = association.DB.UpdateColumns(updateMap).Error
}
if association.Unscope && oldBelongsToExpr != nil {
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
case schema.HasOne, schema.HasMany:
var (
primaryFields []*schema.Field
foreignKeys []string
updateMap = map[string]interface{}{}
relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values})
}
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateMap[ref.ForeignKey.DBName] = nil
} else if ref.PrimaryValue != "" {
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
if association.Unscope {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
} else {
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
}
case schema.Many2Many:
var (
primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}
} else {
tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}
association.Error = tx.Delete(modelValue).Error
}
}
return association.Error
}
func (association *Association) Delete(values ...interface{}) error {
if association.Error == nil {
var (
reflectValue = association.DB.Statement.ReflectValue
rel = association.Relationship
primaryFields []*schema.Field
foreignKeys []string
updateAttrs = map[string]interface{}{}
conds []clause.Expression
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
primaryFields = append(primaryFields, ref.PrimaryKey)
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
updateAttrs[ref.ForeignKey.DBName] = nil
} else {
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
switch rel.Type {
case schema.BelongsTo:
associationDB := association.DB.Session(&Session{})
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
if association.Unscope {
var foreignFields []*schema.Field
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
}
}
case schema.HasOne, schema.HasMany:
model := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := association.DB.Model(model)
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
if association.Unscope {
association.Error = tx.Clauses(conds...).Delete(model).Error
} else {
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
}
case schema.Many2Many:
var (
primaryFields, relPrimaryFields []*schema.Field
joinPrimaryKeys, joinRelPrimaryKeys []string
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
)
for _, ref := range rel.References {
if ref.PrimaryValue == "" {
if ref.OwnPrimaryKey {
primaryFields = append(primaryFields, ref.PrimaryKey)
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
} else {
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
}
} else {
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
}
}
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
} else {
return ErrPrimaryKeyRequired
}
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
}
if association.Error == nil {
// clean up deleted values' foreign keys
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
cleanUpDeletedRelations := func(data reflect.Value) {
if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
switch fieldValue.Kind() {
case reflect.Slice, reflect.Array:
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
for i := 0; i < fieldValue.Len(); i++ {
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
}
}
association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
case reflect.Struct:
for idx, field := range rel.FieldSchema.PrimaryFields {
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
}
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
break
}
if rel.JoinTable == nil {
for _, ref := range rel.References {
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
} else {
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
}
}
}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
}
case reflect.Struct:
cleanUpDeletedRelations(reflectValue)
}
}
}
return association.Error
}
func (association *Association) Clear() error {
return association.Replace()
}
func (association *Association) Count() (count int64) {
if association.Error == nil {
association.Error = association.buildCondition().Count(&count).Error
}
return
}
type assignBack struct {
Source reflect.Value
Index int
Dest reflect.Value
}
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
var (
reflectValue = association.DB.Statement.ReflectValue
assignBacks []assignBack // assign association values back to arguments after save
)
appendToRelations := func(source, rv reflect.Value, clear bool) {
switch association.Relationship.Type {
case schema.HasOne, schema.BelongsTo:
switch rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() > 0 {
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
}
}
case reflect.Struct:
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
}
}
case schema.HasMany, schema.Many2Many:
elemType := association.Relationship.Field.IndirectFieldType.Elem()
oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
var fieldValue reflect.Value
if clear {
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
} else {
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
reflect.Copy(fieldValue, oldFieldValue)
}
appendToFieldValues := func(ev reflect.Value) {
if ev.Type().AssignableTo(elemType) {
fieldValue = reflect.Append(fieldValue, ev)
} else if ev.Type().Elem().AssignableTo(elemType) {
fieldValue = reflect.Append(fieldValue, ev.Elem())
} else {
association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
}
if elemType.Kind() == reflect.Struct {
assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
}
}
switch rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
}
case reflect.Struct:
appendToFieldValues(rv.Addr())
}
if association.Error == nil {
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
}
}
}
selectedSaveColumns := []string{association.Relationship.Name}
omitColumns := []string{}
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, association.Relationship.Name) {
if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
columnName = name
}
} else if strings.HasPrefix(name, clause.Associations) {
columnName = name
}
if columnName != "" {
if ok {
selectedSaveColumns = append(selectedSaveColumns, columnName)
} else {
omitColumns = append(omitColumns, columnName)
}
}
}
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey {
selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
}
}
associationDB := association.DB.Session(&Session{}).Model(nil)
if !association.DB.FullSaveAssociations {
associationDB.Select(selectedSaveColumns)
}
if len(omitColumns) > 0 {
associationDB.Omit(omitColumns...)
}
associationDB = associationDB.Session(&Session{})
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(values) != reflectValue.Len() {
// clear old data
if clear && len(values) == 0 {
for i := 0; i < reflectValue.Len(); i++ {
if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
association.Error = err
break
}
if association.Relationship.JoinTable == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
association.Error = err
break
}
}
}
}
}
break
}
association.Error = ErrInvalidValueOfLength
return
}
for i := 0; i < reflectValue.Len(); i++ {
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
// TODO support save slice data, sql with case?
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
}
case reflect.Struct:
// clear old data
if clear && len(values) == 0 {
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
if association.Relationship.JoinTable == nil && association.Error == nil {
for _, ref := range association.Relationship.References {
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
}
}
}
}
for idx, value := range values {
rv := reflect.Indirect(reflect.ValueOf(value))
appendToRelations(reflectValue, rv, clear && idx == 0)
}
if len(values) > 0 {
association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
}
}
for _, assignBack := range assignBacks {
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
if assignBack.Index > 0 {
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
} else {
reflect.Indirect(assignBack.Dest).Set(fieldValue)
}
}
}
func (association *Association) buildCondition() *DB {
var (
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
tx = association.DB.Model(modelValue)
)
if association.Relationship.JoinTable != nil {
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
joinStmt.AddClause(queryClause)
}
joinStmt.Build("WHERE")
if len(joinStmt.SQL.String()) > 0 {
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
}
}
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
ON: clause.Where{Exprs: queryConds},
}}})
} else {
tx.Clauses(clause.Where{Exprs: queryConds})
}
return tx
}
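For orientation, the association-saving switch and `buildCondition` above back GORM's public `Association` mode. Below is a minimal, non-authoritative sketch of that public API; the `User` and `Language` models are assumptions made up for illustration only.

```go
package examples

import "gorm.io/gorm"

// Hypothetical models used only for this sketch.
type Language struct {
	ID   uint
	Name string
}

type User struct {
	ID        uint
	Name      string
	Languages []Language `gorm:"many2many:user_languages"`
}

// replaceLanguages exercises the association-saving path above:
// Replace persists the new values and unlinks the previous ones.
func replaceLanguages(db *gorm.DB, user *User, langs []Language) error {
	return db.Model(user).Association("Languages").Replace(langs)
}
```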

341
vendor/gorm.io/gorm/callbacks.go generated vendored Normal file
View File

@@ -0,0 +1,341 @@
package gorm
import (
"context"
"errors"
"fmt"
"reflect"
"sort"
"time"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": {db: db},
"query": {db: db},
"update": {db: db},
"delete": {db: db},
"row": {db: db},
"raw": {db: db},
},
}
}
// callbacks gorm callbacks manager
type callbacks struct {
processors map[string]*processor
}
type processor struct {
db *DB
Clauses []string
fns []func(*DB)
callbacks []*callback
}
type callback struct {
name string
before string
after string
remove bool
replace bool
match func(*DB) bool
handler func(*DB)
processor *processor
}
func (cs *callbacks) Create() *processor {
return cs.processors["create"]
}
func (cs *callbacks) Query() *processor {
return cs.processors["query"]
}
func (cs *callbacks) Update() *processor {
return cs.processors["update"]
}
func (cs *callbacks) Delete() *processor {
return cs.processors["delete"]
}
func (cs *callbacks) Row() *processor {
return cs.processors["row"]
}
func (cs *callbacks) Raw() *processor {
return cs.processors["raw"]
}
func (p *processor) Execute(db *DB) *DB {
// call scopes
for len(db.Statement.scopes) > 0 {
db = db.executeScopes()
}
var (
curTime = time.Now()
stmt = db.Statement
resetBuildClauses bool
)
if len(stmt.BuildClauses) == 0 {
stmt.BuildClauses = p.Clauses
resetBuildClauses = true
}
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
}
// assign model values
if stmt.Model == nil {
stmt.Model = stmt.Dest
} else if stmt.Dest == nil {
stmt.Dest = stmt.Model
}
// parse model values
if stmt.Model != nil {
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
} else {
db.AddError(err)
}
}
}
// assign stmt.ReflectValue
if stmt.Dest != nil {
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
for stmt.ReflectValue.Kind() == reflect.Ptr {
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
}
stmt.ReflectValue = stmt.ReflectValue.Elem()
}
if !stmt.ReflectValue.IsValid() {
db.AddError(ErrInvalidValue)
}
}
for _, f := range p.fns {
f(db)
}
if stmt.SQL.Len() > 0 {
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
sql, vars := stmt.SQL.String(), stmt.Vars
if filter, ok := db.Logger.(ParamsFilter); ok {
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
}
return db.Dialector.Explain(sql, vars...), db.RowsAffected
}, db.Error)
}
if !stmt.DB.DryRun {
stmt.SQL.Reset()
stmt.Vars = nil
}
if resetBuildClauses {
stmt.BuildClauses = nil
}
return db
}
func (p *processor) Get(name string) func(*DB) {
for i := len(p.callbacks) - 1; i >= 0; i-- {
if v := p.callbacks[i]; v.name == name && !v.remove {
return v.handler
}
}
return nil
}
func (p *processor) Before(name string) *callback {
return &callback{before: name, processor: p}
}
func (p *processor) After(name string) *callback {
return &callback{after: name, processor: p}
}
func (p *processor) Match(fc func(*DB) bool) *callback {
return &callback{match: fc, processor: p}
}
func (p *processor) Register(name string, fn func(*DB)) error {
return (&callback{processor: p}).Register(name, fn)
}
func (p *processor) Remove(name string) error {
return (&callback{processor: p}).Remove(name)
}
func (p *processor) Replace(name string, fn func(*DB)) error {
return (&callback{processor: p}).Replace(name, fn)
}
func (p *processor) compile() (err error) {
var callbacks []*callback
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
}
p.callbacks = callbacks
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
}
return
}
func (c *callback) Before(name string) *callback {
c.before = name
return c
}
func (c *callback) After(name string) *callback {
c.after = name
return c
}
func (c *callback) Register(name string, fn func(*DB)) error {
c.name = name
c.handler = fn
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile()
}
func (c *callback) Remove(name string) error {
c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile()
}
func (c *callback) Replace(name string, fn func(*DB)) error {
c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
c.name = name
c.handler = fn
c.replace = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile()
}
// getRIndex returns the rightmost index of str in the string slice, or -1 if absent
func getRIndex(strs []string, str string) int {
for i := len(strs) - 1; i >= 0; i-- {
if strs[i] == str {
return i
}
}
return -1
}
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
var (
names, sorted []string
sortCallback func(*callback) error
)
sort.SliceStable(cs, func(i, j int) bool {
if cs[j].before == "*" && cs[i].before != "*" {
return true
}
if cs[j].after == "*" && cs[i].after != "*" {
return true
}
return false
})
for _, c := range cs {
// show a warning message when the callback name already exists
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
}
names = append(names, c.name)
}
sortCallback = func(c *callback) error {
if c.before != "" { // if defined before callback
if c.before == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append([]string{c.name}, sorted...)
}
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if before callback already sorted, append current callback just after it
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
} else if curIdx > sortedIdx {
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
}
} else if idx := getRIndex(names, c.before); idx != -1 {
// if before callback exists
cs[idx].after = c.name
}
}
if c.after != "" { // if defined after callback
if c.after == "*" && len(sorted) > 0 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
sorted = append(sorted, c.name)
}
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
// if after callback sorted, append current callback to last
sorted = append(sorted, c.name)
} else if curIdx < sortedIdx {
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
}
} else if idx := getRIndex(names, c.after); idx != -1 {
// if the after callback exists but hasn't been sorted yet,
// set after callback's before callback to current callback
after := cs[idx]
if after.before == "" {
after.before = c.name
}
if err := sortCallback(after); err != nil {
return err
}
if err := sortCallback(c); err != nil {
return err
}
}
}
// if the current callback hasn't been sorted yet, append it at the end
if getRIndex(sorted, c.name) == -1 {
sorted = append(sorted, c.name)
}
return nil
}
for _, c := range cs {
if err = sortCallback(c); err != nil {
return
}
}
for _, name := range sorted {
if idx := getRIndex(names, name); !cs[idx].remove {
fns = append(fns, cs[idx].handler)
}
}
return
}
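As context for the registration and sorting machinery above, here is a short sketch of how it is typically driven from user code; the callback name `app:audit_create` is an assumption chosen for illustration, while `gorm:create` is the built-in create callback registered later in callbacks/callbacks.go.

```go
package examples

import (
	"log"

	"gorm.io/gorm"
)

// registerAuditCallback adds a callback that sortCallbacks orders
// before the built-in gorm:create callback.
func registerAuditCallback(db *gorm.DB) error {
	return db.Callback().Create().Before("gorm:create").
		Register("app:audit_create", func(tx *gorm.DB) {
			log.Println("inserting into table:", tx.Statement.Table)
		})
}
```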

453
vendor/gorm.io/gorm/callbacks/associations.go generated vendored Normal file
View File

@@ -0,0 +1,453 @@
package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Belongs To associations
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
setupReferences := func(obj reflect.Value, elem reflect.Value) {
for _, ref := range rel.References {
if !ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
dest[ref.ForeignKey.DBName] = pv
if _, ok := dest[rel.Name]; ok {
dest[rel.Name] = elem.Interface()
}
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
rValLen = db.Statement.ReflectValue.Len()
objs = make([]reflect.Value, 0, rValLen)
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
for i := 0; i < rValLen; i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() != reflect.Struct {
break
}
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
if !isPtr {
rv = rv.Addr()
}
objs = append(objs, obj)
elems = reflect.Append(elems, rv)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, rv)
}
}
}
if elems.Len() > 0 {
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
}
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
}
}
}
}
}
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
return func(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
// Save Has One associations
for _, rel := range db.Statement.Schema.Relationships.HasOne {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
fieldType = rel.Field.FieldType
isPtr = fieldType.Kind() == reflect.Ptr
)
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
if rv.Kind() != reflect.Ptr {
rv = rv.Addr()
}
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
}
}
elems = reflect.Append(elems, rv)
}
}
}
if elems.Len() > 0 {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
if f.Kind() != reflect.Ptr {
f = f.Addr()
}
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
}
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
}
}
}
// Save Has Many associations
for _, rel := range db.Statement.Schema.Relationships.HasMany {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
}
}
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
if isPtr {
elems = reflect.Append(elems, elem)
} else {
elems = reflect.Append(elems, elem.Addr())
}
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
if elems.Len() > 0 {
assignmentColumns := make([]string, 0, len(rel.References))
for _, ref := range rel.References {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
}
}
// Save Many2Many associations
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
continue
}
fieldType := rel.Field.IndirectFieldType.Elem()
isPtr := fieldType.Kind() == reflect.Ptr
if !isPtr {
fieldType = reflect.PtrTo(fieldType)
}
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
objs := []reflect.Value{}
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
joinValue := reflect.New(rel.JoinTable.ModelType)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
} else if ref.PrimaryValue != "" {
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
} else {
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
}
}
joins = reflect.Append(joins, joinValue)
}
identityMap := map[string]bool{}
appendToElems := func(v reflect.Value) {
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
for i := 0; i < f.Len(); i++ {
elem := f.Index(i)
if !isPtr {
elem = elem.Addr()
}
objs = append(objs, v)
elems = reflect.Append(elems, elem)
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
for _, pf := range rel.FieldSchema.PrimaryFields {
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
relPrimaryValues = append(relPrimaryValues, pfv)
}
}
cacheKey := utils.ToStringKey(relPrimaryValues...)
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
if cacheKey != "" { // has primary fields
identityMap[cacheKey] = true
}
distinctElems = reflect.Append(distinctElems, elem)
}
}
}
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
obj := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(obj).Kind() == reflect.Struct {
appendToElems(obj)
}
}
case reflect.Struct:
appendToElems(db.Statement.ReflectValue)
}
// optimize elems of reflect value length
if elemLen := elems.Len(); elemLen > 0 {
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
}
for i := 0; i < elemLen; i++ {
appendToJoins(objs[i], elems.Index(i))
}
}
if joins.Len() > 0 {
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
}).Create(joins.Interface()).Error)
}
}
}
}
}
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
for _, dbName := range s.PrimaryFieldDBNames {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
}
onConflict.UpdateAll = stmt.DB.FullSaveAssociations
if !onConflict.UpdateAll {
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
}
} else {
onConflict.DoNothing = true
}
return
}
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
// stop save association loop
if checkAssociationsSaved(db, rValues) {
return nil
}
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
refName = rel.Name + "."
values = rValues.Interface()
)
for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, refName) {
columnName = strings.TrimPrefix(name, refName)
}
if columnName != "" {
if ok {
selects = append(selects, columnName)
} else {
omits = append(omits, columnName)
}
}
}
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
FullSaveAssociations: db.FullSaveAssociations,
SkipHooks: db.Statement.SkipHooks,
DisableNestedTransaction: true,
})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if tx.Statement.FullSaveAssociations {
tx = tx.Set("gorm:update_track_time", true)
}
if len(selects) > 0 {
tx = tx.Select(selects)
} else if restricted && len(omits) == 0 {
tx = tx.Omit(clause.Associations)
}
if len(omits) > 0 {
tx = tx.Omit(omits...)
}
return db.AddError(tx.Create(values).Error)
}
// checkAssociationsSaved reports whether the association values have already been saved:
// if the values kind is Struct, it checks that single value,
// if the values kind is Slice/Array, it checks that all items have been saved
var visitMapStoreKey = "gorm:saved_association_map"
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
if visit, ok := db.Get(visitMapStoreKey); ok {
if v, ok := visit.(*visitMap); ok {
if loadOrStoreVisitMap(v, values) {
return true
}
}
} else {
newVisitMap := make(visitMap)
loadOrStoreVisitMap(&newVisitMap, values)
db.Set(visitMapStoreKey, &newVisitMap)
}
return false
}
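A brief sketch of how the callbacks above are influenced from user code; the `User`/`CreditCard` has-one pair is a hypothetical model, not part of this code.

```go
package examples

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Hypothetical models used only for this sketch.
type CreditCard struct {
	ID     uint
	Number string
	UserID uint
}

type User struct {
	ID         uint
	Name       string
	CreditCard CreditCard
}

func createUser(db *gorm.DB, user *User) {
	// SaveAfterAssociations also saves the has-one CreditCard
	// through saveAssociations.
	db.Create(user)

	// Skip every association handled by the callbacks above.
	db.Omit(clause.Associations).Create(user)

	// Upsert associated rows instead of ignoring conflicts.
	db.Session(&gorm.Session{FullSaveAssociations: true}).Updates(user)
}
```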

83
vendor/gorm.io/gorm/callbacks/callbacks.go generated vendored Normal file
View File

@@ -0,0 +1,83 @@
package callbacks
import (
"gorm.io/gorm"
)
var (
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
updateClauses = []string{"UPDATE", "SET", "WHERE"}
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)
type Config struct {
LastInsertIDReversed bool
CreateClauses []string
QueryClauses []string
UpdateClauses []string
DeleteClauses []string
}
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
enableTransaction := func(db *gorm.DB) bool {
return !db.SkipDefaultTransaction
}
if len(config.CreateClauses) == 0 {
config.CreateClauses = createClauses
}
if len(config.QueryClauses) == 0 {
config.QueryClauses = queryClauses
}
if len(config.DeleteClauses) == 0 {
config.DeleteClauses = deleteClauses
}
if len(config.UpdateClauses) == 0 {
config.UpdateClauses = updateClauses
}
createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
createCallback.Register("gorm:create", Create(config))
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
createCallback.Clauses = config.CreateClauses
queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", Query)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)
queryCallback.Clauses = config.QueryClauses
deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
updateCallback.Register("gorm:update", Update(config))
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
updateCallback.Clauses = config.UpdateClauses
rowCallback := db.Callback().Row()
rowCallback.Register("gorm:row", RowQuery)
rowCallback.Clauses = config.QueryClauses
rawCallback := db.Callback().Raw()
rawCallback.Register("gorm:raw", RawExec)
rawCallback.Clauses = config.QueryClauses
}
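Two ways the default pipeline registered above can be adjusted from user code, shown as a sketch; the wrapping of `gorm:query` is only one possible use of Replace.

```go
package examples

import "gorm.io/gorm"

func tunePipeline(db *gorm.DB) error {
	// gorm:begin_transaction is matched conditionally above, so a
	// session can opt out of the wrapping transaction entirely.
	noTx := db.Session(&gorm.Session{SkipDefaultTransaction: true})
	_ = noTx

	// Default callbacks can be looked up and replaced by name; here the
	// original gorm:query handler is wrapped rather than discarded.
	if fn := db.Callback().Query().Get("gorm:query"); fn != nil {
		return db.Callback().Query().Replace("gorm:query", func(tx *gorm.DB) {
			fn(tx)
		})
	}
	return nil
}
```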

32
vendor/gorm.io/gorm/callbacks/callmethod.go generated vendored Normal file
View File

@@ -0,0 +1,32 @@
package callbacks
import (
"reflect"
"gorm.io/gorm"
)
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{NewDB: true})
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
db.Statement.CurDestIndex = 0
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
fc(value.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
return
}
db.Statement.CurDestIndex++
}
case reflect.Struct:
if db.Statement.ReflectValue.CanAddr() {
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
} else {
db.AddError(gorm.ErrInvalidValue)
}
}
}
}

385
vendor/gorm.io/gorm/callbacks/create.go generated vendored Normal file
View File

@@ -0,0 +1,385 @@
package callbacks
import (
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// BeforeCreate before create hooks
func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeCreate {
if i, ok := value.(BeforeCreateInterface); ok {
called = true
db.AddError(i.BeforeCreate(tx))
}
}
return called
})
}
}
// Create create hook
func Create(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
if !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.CreateClauses {
db.Statement.AddClause(c)
}
}
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
}
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
}
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
isDryRun := !db.DryRun && db.Error == nil
if !isDryRun {
return
}
ok, mode := hasReturning(db, supportReturning)
if ok {
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
mode |= gorm.ScanOnConflictDoNothing
}
}
rows, err := db.Statement.ConnPool.QueryContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if db.AddError(err) == nil {
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)
}
return
}
result, err := db.Statement.ConnPool.ExecContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
if err != nil {
db.AddError(err)
return
}
db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected == 0 {
return
}
var (
pkField *schema.Field
pkFieldName = "@id"
)
if db.Statement.Schema != nil {
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
return
}
pkField = db.Statement.Schema.PrioritizedPrimaryField
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
}
insertID, err := result.LastInsertId()
insertOk := err == nil && insertID > 0
if !insertOk {
db.AddError(err)
return
}
// append the @id column with its value for the auto-increment primary key
// the @id value is only correct when: 1. the auto-increment primary key is not set explicitly, 2. the database AutoIncrementIncrement equals 1
switch values := db.Statement.Dest.(type) {
case map[string]interface{}:
values[pkFieldName] = insertID
case *map[string]interface{}:
(*values)[pkFieldName] = insertID
case []map[string]interface{}, *[]map[string]interface{}:
mapValues, ok := values.([]map[string]interface{})
if !ok {
if v, ok := values.(*[]map[string]interface{}); ok {
if *v != nil {
mapValues = *v
}
}
}
for _, mapValue := range mapValues {
if mapValue != nil {
mapValue[pkFieldName] = insertID
}
insertID += schema.DefaultAutoIncrementIncrement
}
default:
if pkField == nil {
return
}
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if config.LastInsertIDReversed {
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID -= pkField.AutoIncrementIncrement
}
}
} else {
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
rv := db.Statement.ReflectValue.Index(i)
if reflect.Indirect(rv).Kind() != reflect.Struct {
break
}
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
insertID += pkField.AutoIncrementIncrement
}
}
}
case reflect.Struct:
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
}
}
// AfterCreate after create hooks
func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterCreate {
if i, ok := value.(AfterCreateInterface); ok {
called = true
db.AddError(i.AfterCreate(tx))
}
}
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
return called
})
}
}
// ConvertToCreateValues converts the statement's destination into create values
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
curTime := stmt.DB.NowFunc()
switch value := stmt.Dest.(type) {
case map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, value)
case *map[string]interface{}:
values = ConvertMapToValuesForCreate(stmt, *value)
case []map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
case *[]map[string]interface{}:
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
default:
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
_, updateTrackTime = stmt.Get("gorm:update_track_time")
isZero bool
)
stmt.Settings.Delete("gorm:update_track_time")
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
values.Columns = append(values.Columns, clause.Column{Name: db})
}
}
}
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
rValLen := stmt.ReflectValue.Len()
if rValLen == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
stmt.SQL.Grow(rValLen * 18)
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
values.Values = make([][]interface{}, rValLen)
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
for i := 0; i < rValLen; i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
if !rv.IsValid() {
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
return
}
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
if field.DefaultValueInterface != nil {
values.Values[i][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, rv, curTime))
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
if len(defaultValueFieldsHavingValue[field]) == 0 {
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
}
defaultValueFieldsHavingValue[field][i] = rvOfvalue
}
}
}
}
for field, vs := range defaultValueFieldsHavingValue {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
for idx := range values.Values {
if vs[idx] == nil {
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
} else {
values.Values[idx] = append(values.Values[idx], vs[idx])
}
}
}
case reflect.Struct:
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
values.Values[0] = append(values.Values[0], rvOfvalue)
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) >= 1 {
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
if field.AutoUpdateTime > 0 {
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
switch field.AutoUpdateTime {
case schema.UnixNanosecond:
assignment.Value = curTime.UnixNano()
case schema.UnixMillisecond:
assignment.Value = curTime.UnixNano() / 1e6
case schema.UnixSecond:
assignment.Value = curTime.Unix()
}
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
} else {
columns = append(columns, column.Name)
}
}
}
}
}
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
if len(onConflict.DoUpdates) == 0 {
onConflict.DoNothing = true
}
// use primary fields as default OnConflict columns
if len(onConflict.Columns) == 0 {
for _, field := range stmt.Schema.PrimaryFields {
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
}
}
stmt.AddClause(onConflict)
}
}
}
return values
}
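The ON CONFLICT handling at the end of ConvertToCreateValues is driven by a user-supplied clause; a minimal upsert sketch follows, with a hypothetical `Product` model assumed for illustration.

```go
package examples

import (
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Hypothetical model used only for this sketch.
type Product struct {
	ID    uint
	Code  string `gorm:"uniqueIndex"`
	Price uint
}

func upsertProduct(db *gorm.DB, p *Product) error {
	// On a code conflict only the price column is updated;
	// UpdateAll would instead expand to roughly all non-primary columns,
	// as handled in ConvertToCreateValues above.
	return db.Clauses(clause.OnConflict{
		Columns:   []clause.Column{{Name: "code"}},
		DoUpdates: clause.AssignmentColumns([]string{"price"}),
	}).Create(p).Error
}
```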

185
vendor/gorm.io/gorm/callbacks/delete.go generated vendored Normal file
View File

@@ -0,0 +1,185 @@
package callbacks
import (
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx))
return true
}
return false
})
}
}
func DeleteBeforeAssociations(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
if !restricted {
return
}
for column, v := range selectColumns {
if !v {
continue
}
rel, ok := db.Statement.Schema.Relationships.Relations[column]
if !ok {
continue
}
switch rel.Type {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false
if db.Statement.Unscoped {
tx = tx.Unscoped()
}
if len(db.Statement.Selects) > 0 {
selects := make([]string, 0, len(db.Statement.Selects))
for _, s := range db.Statement.Selects {
if s == clause.Associations {
selects = append(selects, s)
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
}
}
if len(selects) > 0 {
tx = tx.Select(selects)
}
}
for _, cond := range queryConds {
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
withoutConditions = true
break
}
}
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
case schema.Many2Many:
var (
queryConds = make([]clause.Expression, 0, len(rel.References))
foreignFields = make([]*schema.Field, 0, len(rel.References))
relForeignKeys = make([]string, 0, len(rel.References))
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
queryConds = append(queryConds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
}
}
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
return
}
}
}
}
}
func Delete(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
db.Statement.AddClauseIfNotExists(clause.Delete{})
if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
}
db.Statement.AddClauseIfNotExists(clause.From{})
db.Statement.Build(db.Statement.BuildClauses...)
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
return
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
db.AddError(rows.Close())
}
}
}
}
func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx))
return true
}
return false
})
}
}
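A short sketch of the behaviour enforced above, using a hypothetical `User` model: checkMissingWhereConditions blocks unconditioned deletes, and DeleteBeforeAssociations is reached through a Select on associations.

```go
package examples

import (
	"errors"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// Hypothetical model used only for this sketch.
type User struct {
	ID   uint
	Name string
}

func deleteUsers(db *gorm.DB) {
	// Rejected by checkMissingWhereConditions: no conditions at all.
	if err := db.Delete(&User{}).Error; errors.Is(err, gorm.ErrMissingWhereClause) {
		// expected
	}

	// With Select, associations (if the model declared any) are removed
	// first by DeleteBeforeAssociations, then the row itself is deleted.
	db.Select(clause.Associations).Delete(&User{ID: 1})
}
```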

152
vendor/gorm.io/gorm/callbacks/helper.go generated vendored Normal file
View File

@@ -0,0 +1,152 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// ConvertMapToValuesForCreate converts a map to values for create
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
values.Columns = make([]clause.Column, 0, len(mapValue))
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
keys := make([]string, 0, len(mapValue))
for k := range mapValue {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
value := mapValue[k]
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
values.Columns = append(values.Columns, clause.Column{Name: k})
if len(values.Values) == 0 {
values.Values = [][]interface{}{{}}
}
values.Values[0] = append(values.Values[0], value)
}
}
return
}
// ConvertSliceOfMapToValuesForCreate converts a slice of maps to values for create
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
columns := make([]string, 0, len(mapValues))
// when the length of mapValues is zero, return directly;
// there is no need to call the stmt.SelectAndOmitColumns method
if len(mapValues) == 0 {
stmt.AddError(gorm.ErrEmptySlice)
return
}
var (
result = make(map[string][]interface{}, len(mapValues))
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
)
for idx, mapValue := range mapValues {
for k, v := range mapValue {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
k = field.DBName
}
}
if _, ok := result[k]; !ok {
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
result[k] = make([]interface{}, len(mapValues))
columns = append(columns, k)
} else {
continue
}
}
result[k][idx] = v
}
}
sort.Strings(columns)
values.Values = make([][]interface{}, len(mapValues))
values.Columns = make([]clause.Column, len(columns))
for idx, column := range columns {
values.Columns[idx] = clause.Column{Name: column}
for i, v := range result[column] {
if len(values.Values[i]) == 0 {
values.Values[i] = make([]interface{}, len(columns))
}
values.Values[i][idx] = v
}
}
return
}
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
if supportReturning {
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
returning, _ := c.Expression.(clause.Returning)
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
return true, 0
}
return true, gorm.ScanUpdate
}
}
return false, 0
}
func checkMissingWhereConditions(db *gorm.DB) {
if !db.AllowGlobalUpdate && db.Error == nil {
where, withCondition := db.Statement.Clauses["WHERE"]
if withCondition {
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
whereClause, _ := where.Expression.(clause.Where)
withCondition = len(whereClause.Exprs) > 1
}
}
if !withCondition {
db.AddError(gorm.ErrMissingWhereClause)
}
return
}
}
type visitMap = map[reflect.Value]bool
// loadOrStoreVisitMap guards against circular values; it returns true if the value has already been visited
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
loaded = true
for i := 0; i < v.Len(); i++ {
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
loaded = false
}
}
case reflect.Struct, reflect.Interface:
if v.CanAddr() {
p := v.Addr()
if _, ok := (*visitMap)[p]; ok {
return true
}
(*visitMap)[p] = true
}
}
return
}
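A sketch of the inputs the two map converters above handle; the `users` table and its column names are assumptions for illustration.

```go
package examples

import "gorm.io/gorm"

func createFromMaps(db *gorm.DB) error {
	// A single map is converted by ConvertMapToValuesForCreate.
	if err := db.Table("users").Create(map[string]interface{}{
		"name": "jinzhu", "age": 18,
	}).Error; err != nil {
		return err
	}

	// A slice of maps is converted by ConvertSliceOfMapToValuesForCreate.
	return db.Table("users").Create([]map[string]interface{}{
		{"name": "foo", "age": 20},
		{"name": "bar", "age": 30},
	}).Error
}
```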

39
vendor/gorm.io/gorm/callbacks/interfaces.go generated vendored Normal file
View File

@@ -0,0 +1,39 @@
package callbacks
import "gorm.io/gorm"
type BeforeCreateInterface interface {
BeforeCreate(*gorm.DB) error
}
type AfterCreateInterface interface {
AfterCreate(*gorm.DB) error
}
type BeforeUpdateInterface interface {
BeforeUpdate(*gorm.DB) error
}
type AfterUpdateInterface interface {
AfterUpdate(*gorm.DB) error
}
type BeforeSaveInterface interface {
BeforeSave(*gorm.DB) error
}
type AfterSaveInterface interface {
AfterSave(*gorm.DB) error
}
type BeforeDeleteInterface interface {
BeforeDelete(*gorm.DB) error
}
type AfterDeleteInterface interface {
AfterDelete(*gorm.DB) error
}
type AfterFindInterface interface {
AfterFind(*gorm.DB) error
}
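A sketch of a model satisfying two of the hook interfaces above; the `User` model and its default role are assumptions made for illustration.

```go
package examples

import (
	"errors"

	"gorm.io/gorm"
)

// Hypothetical model implementing BeforeCreateInterface and AfterFindInterface.
type User struct {
	ID   uint
	Name string
	Role string
}

func (u *User) BeforeCreate(tx *gorm.DB) error {
	if u.Name == "" {
		return errors.New("name must not be empty")
	}
	return nil
}

func (u *User) AfterFind(tx *gorm.DB) error {
	if u.Role == "" {
		u.Role = "user" // apply a default after loading
	}
	return nil
}
```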

320
vendor/gorm.io/gorm/callbacks/preload.go generated vendored Normal file
View File

@@ -0,0 +1,320 @@
package callbacks
import (
"fmt"
"reflect"
"sort"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// parsePreloadMap extracts nested preloads. e.g.
//
// // schema has a "k0" relation and a "k7.k8" embedded relation
// parsePreloadMap(schema, map[string][]interface{}{
// clause.Associations: {"arg1"},
// "k1": {"arg2"},
// "k2.k3": {"arg3"},
// "k4.k5.k6": {"arg4"},
// })
// // preloadMap is
// map[string]map[string][]interface{}{
// "k0": {},
// "k7": {
// "k8": {},
// },
// "k1": {},
// "k2": {
// "k3": {"arg3"},
// },
// "k4": {
// "k5.k6": {"arg4"},
// },
// }
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
preloadMap := map[string]map[string][]interface{}{}
setPreloadMap := func(name, value string, args []interface{}) {
if _, ok := preloadMap[name]; !ok {
preloadMap[name] = map[string][]interface{}{}
}
if value != "" {
preloadMap[name][value] = args
}
}
for name, args := range preloads {
preloadFields := strings.Split(name, ".")
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
if preloadFields[0] == clause.Associations {
for _, relation := range s.Relationships.Relations {
if relation.Schema == s {
setPreloadMap(relation.Name, value, args)
}
}
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
for _, value := range embeddedValues(embeddedRelations) {
setPreloadMap(embedded, value, args)
}
}
} else {
setPreloadMap(preloadFields[0], value, args)
}
}
return preloadMap
}
func embeddedValues(embeddedRelations *schema.Relationships) []string {
if embeddedRelations == nil {
return nil
}
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
for _, relation := range embeddedRelations.Relations {
// skip first struct name
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
}
for _, relations := range embeddedRelations.EmbeddedRelations {
names = append(names, embeddedValues(relations)...)
}
return names
}
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
// If the current relationship is embedded or joined, current query will be ignored.
//
//nolint:cyclop
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
// avoid random traversal of the map
preloadNames := make([]string, 0, len(preloadMap))
for key := range preloadMap {
preloadNames = append(preloadNames, key)
}
sort.Strings(preloadNames)
isJoined := func(name string) (joined bool, nestedJoins []string) {
for _, join := range joins {
if _, ok := relationships.Relations[join]; ok && name == join {
joined = true
continue
}
joinNames := strings.SplitN(join, ".", 2)
if len(joinNames) == 2 {
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
joined = true
nestedJoins = append(nestedJoins, joinNames[1])
}
}
}
return joined, nestedJoins
}
for _, name := range preloadNames {
if relations := relationships.EmbeddedRelations[name]; relations != nil {
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
return err
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
tx.Statement.ReflectValue = db.Statement.ReflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
return err
}
}
} else {
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
}
}
return nil
}
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
db.Statement.Settings.Range(func(k, v interface{}) bool {
tx.Statement.Settings.Store(k, v)
return true
})
if err := tx.Statement.Parse(dest); err != nil {
tx.AddError(err)
return tx
}
tx.Statement.ReflectValue = reflectValue
tx.Statement.Unscoped = db.Statement.Unscoped
return tx
}
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
var (
reflectValue = tx.Statement.ReflectValue
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
foreignValues [][]interface{}
identityMap = map[string][]reflect.Value{}
inlineConds []interface{}
)
if rel.JoinTable != nil {
var (
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
joinForeignKeys = make([]string, 0, len(rel.References))
)
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
} else {
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
}
}
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(joinForeignValues) == 0 {
return nil
}
joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
return err
}
// convert join identity map to relation identity map
fieldValues := make([]interface{}, len(joinForeignFields))
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
for i := 0; i < joinResults.Len(); i++ {
joinIndexValue := joinResults.Index(i)
for idx, field := range joinForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
for idx, field := range joinRelForeignFields {
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
}
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
joinKey := utils.ToStringKey(joinFieldValues...)
identityMap[joinKey] = append(identityMap[joinKey], results...)
}
}
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
relForeignFields = append(relForeignFields, ref.ForeignKey)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
} else {
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
relForeignFields = append(relForeignFields, ref.PrimaryKey)
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
if len(foreignValues) == 0 {
return nil
}
}
// nested preload
for p, pvs := range preloads {
tx = tx.Preload(p, pvs...)
}
reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
if len(values) != 0 {
for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
} else {
inlineConds = append(inlineConds, cond)
}
}
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
return err
}
}
fieldValues := make([]interface{}, len(relForeignFields))
// clean up old values before preloading
switch reflectValue.Kind() {
case reflect.Struct:
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
switch rel.Type {
case schema.HasMany, schema.Many2Many:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
default:
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
}
}
}
for i := 0; i < reflectResults.Len(); i++ {
elem := reflectResults.Index(i)
for idx, field := range relForeignFields {
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
}
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
if !ok {
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exist", elem.Interface())
}
for _, data := range datas {
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
}
reflectFieldValue = reflect.Indirect(reflectFieldValue)
switch reflectFieldValue.Kind() {
case reflect.Struct:
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
case reflect.Slice, reflect.Array:
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
} else {
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
}
}
}
}
return tx.Error
}

299
vendor/gorm.io/gorm/callbacks/query.go generated vendored Normal file

@@ -0,0 +1,299 @@
package callbacks
import (
"fmt"
"reflect"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
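// Query runs the SELECT built by BuildQuerySQL and scans the result rows into the destination.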
func Query(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if !db.DryRun && db.Error == nil {
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
defer func() {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)
}
}
}
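// BuildQuerySQL assembles the SELECT statement from the schema's query clauses, the selected/omitted columns and any inline joins.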
func BuildQuerySQL(db *gorm.DB) {
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.QueryClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(100)
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
var conds []clause.Expression
for _, primaryField := range db.Statement.Schema.PrimaryFields {
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
}
}
if len(conds) > 0 {
db.Statement.AddClause(clause.Where{Exprs: conds})
}
}
if len(db.Statement.Selects) > 0 {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
for idx, name := range db.Statement.Selects {
if db.Statement.Schema == nil {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
} else {
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
}
}
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
for _, dbName := range db.Statement.Schema.DBNames {
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
}
}
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
queryFields := db.QueryFields
if !queryFields {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
}
}
if queryFields {
stmt := gorm.Statement{DB: db}
// the destination is a different (usually smaller) struct, so select only its columns
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
for idx, dbName := range stmt.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
}
}
// inline joins
fromClause := clause.From{}
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
fromClause = v
}
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
for idx, dbName := range db.Statement.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
specifiedRelationsName := make(map[string]interface{})
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
var relations []*schema.Relationship
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
if ok {
isRelations = true
relations = append(relations, relation)
} else {
// handle nested join like "Manager.Company"
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// if any segment does not match a relationship, the whole join name is treated as raw SQL
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
break
}
}
if isNestedJoin {
isRelations = true
relations = gussNestedRelations
}
}
}
if isRelations {
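// genJoinClause builds the JOIN clause for one relationship: it selects the relation's columns under a (possibly nested) alias and builds the ON conditions from the references and any extra join conditions.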
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
}
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
for _, s := range relation.FieldSchema.DBNames {
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
Table: tableAliasName,
Name: s,
Alias: utils.NestedRelationName(tableAliasName, s),
})
}
}
exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
}
} else {
if ref.PrimaryValue == "" {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
}
} else {
exprs[idx] = clause.Eq{
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
}
}
}
}
{
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
for _, c := range relation.FieldSchema.QueryClauses {
onStmt.AddClause(c)
}
if join.On != nil {
onStmt.AddClause(join.On)
}
if cs, ok := onStmt.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
where.Build(&onStmt)
if onSQL := onStmt.SQL.String(); onSQL != "" {
vars := onStmt.Vars
for idx, v := range vars {
bindvar := strings.Builder{}
onStmt.Vars = vars[0 : idx+1]
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
}
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
}
}
}
}
return clause.Join{
Type: joinType,
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
ON: clause.Where{Exprs: exprs},
}
}
parentTableName := clause.CurrentTable
for _, rel := range relations {
// join table aliases look like "Manager", "Company" and "Manager__Company" for nested relations
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
}
if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
}
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
})
}
}
db.Statement.AddClause(fromClause)
} else {
db.Statement.AddClauseIfNotExists(clause.From{})
}
db.Statement.AddClauseIfNotExists(clauseSelect)
db.Statement.Build(db.Statement.BuildClauses...)
}
}
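// Preload loads the registered preload associations for the query results after the main query has run.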
func Preload(db *gorm.DB) {
if db.Error == nil && len(db.Statement.Preloads) > 0 {
if db.Statement.Schema == nil {
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
return
}
joins := make([]string, 0, len(db.Statement.Joins))
for _, join := range db.Statement.Joins {
joins = append(joins, join.Name)
}
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
if tx.Error != nil {
return
}
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
}
}
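// AfterQuery resets the statement's joins and calls AfterFind hooks on the returned records.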
func AfterQuery(db *gorm.DB) {
// the joins are cleared only here, after the query, because the preload callback still needs them
db.Statement.Joins = nil
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx))
return true
}
return false
})
}
}

17
vendor/gorm.io/gorm/callbacks/raw.go generated vendored Normal file

@@ -0,0 +1,17 @@
package callbacks
import (
"gorm.io/gorm"
)
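// RawExec executes the raw SQL on the statement and records the number of affected rows.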
func RawExec(db *gorm.DB) {
if db.Error == nil && !db.DryRun {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if err != nil {
db.AddError(err)
return
}
db.RowsAffected, _ = result.RowsAffected()
}
}

23
vendor/gorm.io/gorm/callbacks/row.go generated vendored Normal file

@@ -0,0 +1,23 @@
package callbacks
import (
"gorm.io/gorm"
)
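// RowQuery builds the query SQL and returns *sql.Rows or a *sql.Row depending on whether the "rows" setting is set.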
func RowQuery(db *gorm.DB) {
if db.Error == nil {
BuildQuerySQL(db)
if db.DryRun || db.Error != nil {
return
}
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
db.Statement.Settings.Delete("rows")
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
} else {
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
}
db.RowsAffected = -1
}
}

32
vendor/gorm.io/gorm/callbacks/transaction.go generated vendored Normal file

@@ -0,0 +1,32 @@
package callbacks
import (
"gorm.io/gorm"
)
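// BeginTransaction starts a transaction for the callback chain unless SkipDefaultTransaction is enabled.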
func BeginTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction && db.Error == nil {
if tx := db.Begin(); tx.Error == nil {
db.Statement.ConnPool = tx.Statement.ConnPool
db.InstanceSet("gorm:started_transaction", true)
} else if tx.Error == gorm.ErrInvalidTransaction {
tx.Error = nil
} else {
db.Error = tx.Error
}
}
}
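// CommitOrRollbackTransaction commits the transaction started by BeginTransaction, or rolls it back if an error occurred.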
func CommitOrRollbackTransaction(db *gorm.DB) {
if !db.Config.SkipDefaultTransaction {
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
if db.Error != nil {
db.Rollback()
} else {
db.Commit()
}
db.Statement.ConnPool = db.ConnPool
}
}
}

304
vendor/gorm.io/gorm/callbacks/update.go generated vendored Normal file

@@ -0,0 +1,304 @@
package callbacks
import (
"reflect"
"sort"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
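// SetupUpdateReflectValue points the statement's reflect value at the model when updating with a separate destination, and copies belongs-to associations from a map destination.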
func SetupUpdateReflectValue(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil {
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
}
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
if _, ok := dest[rel.Name]; ok {
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
}
}
}
}
}
}
// BeforeUpdate before update hooks
func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
called = true
db.AddError(i.BeforeSave(tx))
}
}
if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(BeforeUpdateInterface); ok {
called = true
db.AddError(i.BeforeUpdate(tx))
}
}
return called
})
}
}
// Update returns the callback that builds and executes the UPDATE statement
func Update(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
return func(db *gorm.DB) {
if db.Error != nil {
return
}
if db.Statement.Schema != nil {
for _, c := range db.Statement.Schema.UpdateClauses {
db.Statement.AddClause(c)
}
}
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Update{})
if _, ok := db.Statement.Clauses["SET"]; !ok {
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
defer delete(db.Statement.Clauses, "SET")
db.Statement.AddClause(set)
} else {
return
}
}
db.Statement.Build(db.Statement.BuildClauses...)
}
checkMissingWhereConditions(db)
if !db.DryRun && db.Error == nil {
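// when the dialect supports RETURNING and it is requested, scan the updated rows back into the model; otherwise execute the statement and record RowsAffected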
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
dest := db.Statement.Dest
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
}
}
}
}
// AfterUpdate after update hooks
func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterUpdate {
if i, ok := value.(AfterUpdateInterface); ok {
called = true
db.AddError(i.AfterUpdate(tx))
}
}
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
called = true
db.AddError(i.AfterSave(tx))
}
}
return called
})
}
}
// ConvertToAssignments convert to update assignments
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
var (
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
assignValue func(field *schema.Field, value interface{})
)
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
assignValue = func(field *schema.Field, value interface{}) {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
}
}
}
case reflect.Struct:
assignValue = func(field *schema.Field, value interface{}) {
if stmt.ReflectValue.CanAddr() {
field.Set(stmt.Context, stmt.ReflectValue, value)
}
}
default:
assignValue = func(field *schema.Field, value interface{}) {
}
}
updatingValue := reflect.ValueOf(stmt.Dest)
for updatingValue.Kind() == reflect.Ptr {
updatingValue = updatingValue.Elem()
}
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if size := stmt.ReflectValue.Len(); size > 0 {
var isZero bool
for i := 0; i < size; i++ {
for _, field := range stmt.Schema.PrimaryFields {
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
if !isZero {
break
}
}
}
if !isZero {
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
case reflect.Struct:
for _, field := range stmt.Schema.PrimaryFields {
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
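// build the SET assignments either from a map of column/value pairs or from the fields of a struct destination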
switch value := updatingValue.Interface().(type) {
case map[string]interface{}:
set = make([]clause.Assignment, 0, len(value))
keys := make([]string, 0, len(value))
for k := range value {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
kv := value[k]
if _, ok := kv.(*gorm.DB); ok {
kv = []interface{}{kv}
}
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(k); field != nil {
if field.DBName != "" {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
assignValue(field, value[k])
}
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
assignValue(field, value[k])
}
continue
}
}
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
}
}
if !stmt.SkipHooks && stmt.Schema != nil {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
now := stmt.DB.NowFunc()
assignValue(field, now)
if field.AutoUpdateTime == schema.UnixNanosecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
} else if field.AutoUpdateTime == schema.UnixMillisecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
} else if field.AutoUpdateTime == schema.UnixSecond {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
} else {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
}
}
}
}
}
default:
updatingSchema := stmt.Schema
var isDiffSchema bool
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
// different schema
updatingStmt := &gorm.Statement{DB: stmt.DB}
if err := updatingStmt.Parse(stmt.Dest); err == nil {
updatingSchema = updatingStmt.Schema
isDiffSchema = true
}
}
switch updatingValue.Kind() {
case reflect.Struct:
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
for _, dbName := range stmt.Schema.DBNames {
if field := updatingSchema.LookUpField(dbName); field != nil {
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
value, isZero := field.ValueOf(stmt.Context, updatingValue)
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
} else if field.AutoUpdateTime == schema.UnixMillisecond {
value = stmt.DB.NowFunc().UnixNano() / 1e6
} else if field.AutoUpdateTime == schema.UnixSecond {
value = stmt.DB.NowFunc().Unix()
} else {
value = stmt.DB.NowFunc()
}
isZero = false
}
if (ok || !isZero) && field.Updatable {
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
assignField := field
if isDiffSchema {
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
assignField = originField
}
}
assignValue(assignField, value)
}
}
} else {
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
}
}
}
}
default:
stmt.AddError(gorm.ErrInvalidData)
}
}
return
}

448
vendor/gorm.io/gorm/chainable_api.go generated vendored Normal file

@@ -0,0 +1,448 @@
package gorm
import (
"fmt"
"regexp"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)
// Model specify the model you would like to run db operations on
//
// // update all users' name to `hello`
// db.Model(&User{}).Update("name", "hello")
// // if the user's primary key is non-blank, it will be used as the condition, and only that user's name will be updated to `hello`
// db.Model(&user).Update("name", "hello")
func (db *DB) Model(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Model = value
return
}
// Clauses add clauses
//
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
// advanced techniques like specifying lock strength and optimizer hints. See the
// [docs] for more depth.
//
// // add a simple limit clause
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
// // tell the optimizer to use the `idx_user_name` index
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
// // specify the lock strength to UPDATE
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
//
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
tx = db.getInstance()
var whereConds []interface{}
for _, cond := range conds {
if c, ok := cond.(clause.Interface); ok {
tx.Statement.AddClause(c)
} else if optimizer, ok := cond.(StatementModifier); ok {
optimizer.ModifyStatement(tx.Statement)
} else {
whereConds = append(whereConds, cond)
}
}
if len(whereConds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
}
return
}
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
// Table specify the table you would like to run db operations on
//
// // Get a user
// db.Table("users").Take(&result)
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
if results[1] != "" {
tx.Statement.Table = results[1]
} else {
tx.Statement.Table = results[2]
}
}
} else if tables := strings.Split(name, "."); len(tables) == 2 {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = tables[1]
} else if name != "" {
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
tx.Statement.Table = name
} else {
tx.Statement.TableExpr = nil
tx.Statement.Table = ""
}
return
}
// Distinct specify distinct fields that you want when querying
//
// // Select distinct names of users
// db.Distinct("name").Find(&results)
// // Select distinct name/age pairs from users
// db.Distinct("name", "age").Find(&results)
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Distinct = true
if len(args) > 0 {
tx = tx.Select(args[0], args[1:]...)
}
return
}
// Select specify fields that you want when querying, creating, updating
//
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
// Select accepts both string arguments and arrays.
//
// // Select name and age of user using multiple arguments
// db.Select("name", "age").Find(&users)
// // Select name and age of user using an array
// db.Select([]string{"name", "age"}).Find(&users)
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
switch v := query.(type) {
case []string:
tx.Statement.Selects = v
for _, arg := range args {
switch arg := arg.(type) {
case string:
tx.Statement.Selects = append(tx.Statement.Selects, arg)
case []string:
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
return
}
}
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
case string:
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
} else if strings.Count(v, "@") > 0 && len(args) > 0 {
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.NamedExpr{SQL: v, Vars: args},
})
} else {
tx.Statement.Selects = []string{v}
for _, arg := range args {
switch arg := arg.(type) {
case string:
tx.Statement.Selects = append(tx.Statement.Selects, arg)
case []string:
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
default:
tx.Statement.AddClause(clause.Select{
Distinct: db.Statement.Distinct,
Expression: clause.Expr{SQL: v, Vars: args},
})
return
}
}
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
clause.Expression = nil
tx.Statement.Clauses["SELECT"] = clause
}
}
default:
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
}
return
}
// Omit specify fields that you want to ignore when creating, updating and querying
func (db *DB) Omit(columns ...string) (tx *DB) {
tx = db.getInstance()
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
} else {
tx.Statement.Omits = columns
}
return
}
// Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
// // Find the first user with name jinzhu
// db.Where("name = ?", "jinzhu").First(&user)
// // Find the first user with name jinzhu and age 20
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
// // Find the first user with name jinzhu and age not equal to 20
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: conds})
}
return
}
// Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
//
// // Find the first user with name not equal to jinzhu
// db.Not("name = ?", "jinzhu").First(&user)
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
}
return
}
// Or add OR conditions
//
// Or is used to chain together queries with an OR.
//
// // Find the first user with name equal to jinzhu or john
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
}
return
}
// Joins specify join conditions
//
// db.Joins("Account").Find(&user)
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.LeftJoin, query, args...)
}
// InnerJoins specify inner join conditions
// db.InnerJoins("Account").Find(&user)
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
return joins(db, clause.InnerJoin, query, args...)
}
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(args) == 1 {
if db, ok := args[0].(*DB); ok {
j := join{
Name: query, Conds: args, Selects: db.Statement.Selects,
Omits: db.Statement.Omits, JoinType: joinType,
}
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
j.On = &where
}
tx.Statement.Joins = append(tx.Statement.Joins, j)
return
}
}
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
return
}
// Group specify the group method on the find
//
// // Select the sum age of users with given names
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
func (db *DB) Group(name string) (tx *DB) {
tx = db.getInstance()
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
tx.Statement.AddClause(clause.GroupBy{
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
})
return
}
// Having specify HAVING conditions for GROUP BY
//
// // Select the sum age of users with name jinzhu
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.GroupBy{
Having: tx.Statement.BuildCondition(query, args...),
})
return
}
// Order specify order when retrieving records from database
//
// db.Order("name DESC")
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
func (db *DB) Order(value interface{}) (tx *DB) {
tx = db.getInstance()
switch v := value.(type) {
case clause.OrderByColumn:
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{v},
})
case string:
if v != "" {
tx.Statement.AddClause(clause.OrderBy{
Columns: []clause.OrderByColumn{{
Column: clause.Column{Name: v, Raw: true},
}},
})
}
}
return
}
// Limit specify the number of records to be retrieved
//
// Limit conditions can be cancelled by using `Limit(-1)`.
//
// // retrieve 3 users
// db.Limit(3).Find(&users)
// // retrieve 3 users into users1, and all users into users2
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
func (db *DB) Limit(limit int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Limit: &limit})
return
}
// Offset specify the number of records to skip before starting to return the records
//
// Offset conditions can be cancelled by using `Offset(-1)`.
//
// // select the third user
// db.Offset(2).First(&user)
// // select the first user by cancelling an earlier chained offset
// db.Offset(5).Offset(-1).First(&user)
func (db *DB) Offset(offset int) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Limit{Offset: offset})
return
}
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
//
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
// return db.Where("amount > ?", 1000)
// }
//
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
// return func (db *gorm.DB) *gorm.DB {
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
// }
// }
//
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
tx = db.getInstance()
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
return tx
}
func (db *DB) executeScopes() (tx *DB) {
scopes := db.Statement.scopes
db.Statement.scopes = nil
for _, scope := range scopes {
db = scope(db)
}
return db
}
// Preload preload associations with given conditions
//
// // get all users, and preload all non-cancelled orders
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Preloads == nil {
tx.Statement.Preloads = map[string][]interface{}{}
}
tx.Statement.Preloads[query] = args
return
}
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Attrs only adds attributes if the record is not found.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email if the record is not found, otherwise ignore provided email
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.attrs = attrs
return
}
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
//
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
// records will be updated even if they are found.
//
// // assign an email even if the record is not found
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign an email even if the record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
//
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.assigns = attrs
return
}
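// Unscoped marks the statement as unscoped, which callbacks use to skip default conditions such as soft delete filtering.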
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true
return
}
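// Raw sets the statement SQL from a raw query string, using named parameters when the SQL contains '@' and ordinary placeholders otherwise.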
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
return
}

89
vendor/gorm.io/gorm/clause/clause.go generated vendored Normal file

@@ -0,0 +1,89 @@
package clause
// Interface clause interface
type Interface interface {
Name() string
Build(Builder)
MergeClause(*Clause)
}
// ClauseBuilder clause builder, allows to customize how to build clause
type ClauseBuilder func(Clause, Builder)
type Writer interface {
WriteByte(byte) error
WriteString(string) (int, error)
}
// Builder builder interface
type Builder interface {
Writer
WriteQuoted(field interface{})
AddVar(Writer, ...interface{})
AddError(error) error
}
// Clause a SQL clause with its name, expressions and an optional custom builder
type Clause struct {
Name string // WHERE
BeforeExpression Expression
AfterNameExpression Expression
AfterExpression Expression
Expression Expression
Builder ClauseBuilder
}
// Build build clause
func (c Clause) Build(builder Builder) {
if c.Builder != nil {
c.Builder(c, builder)
} else if c.Expression != nil {
if c.BeforeExpression != nil {
c.BeforeExpression.Build(builder)
builder.WriteByte(' ')
}
if c.Name != "" {
builder.WriteString(c.Name)
builder.WriteByte(' ')
}
if c.AfterNameExpression != nil {
c.AfterNameExpression.Build(builder)
builder.WriteByte(' ')
}
c.Expression.Build(builder)
if c.AfterExpression != nil {
builder.WriteByte(' ')
c.AfterExpression.Build(builder)
}
}
}
const (
PrimaryKey string = "~~~py~~~" // primary key
CurrentTable string = "~~~ct~~~" // current table
Associations string = "~~~as~~~" // associations
)
var (
currentTable = Table{Name: CurrentTable}
PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey}
)
// Column quote with name
type Column struct {
Table string
Name string
Alias string
Raw bool
}
// Table quote with name
type Table struct {
Name string
Alias string
Raw bool
}

23
vendor/gorm.io/gorm/clause/delete.go generated vendored Normal file

@@ -0,0 +1,23 @@
package clause
type Delete struct {
Modifier string
}
func (d Delete) Name() string {
return "DELETE"
}
func (d Delete) Build(builder Builder) {
builder.WriteString("DELETE")
if d.Modifier != "" {
builder.WriteByte(' ')
builder.WriteString(d.Modifier)
}
}
func (d Delete) MergeClause(clause *Clause) {
clause.Name = ""
clause.Expression = d
}

385
vendor/gorm.io/gorm/clause/expression.go generated vendored Normal file

@@ -0,0 +1,385 @@
package clause
import (
"database/sql"
"database/sql/driver"
"go/ast"
"reflect"
)
// Expression expression interface
type Expression interface {
Build(builder Builder)
}
// NegationExpressionBuilder negation expression builder
type NegationExpressionBuilder interface {
NegationBuild(builder Builder)
}
// Expr raw expression
type Expr struct {
SQL string
Vars []interface{}
WithoutParentheses bool
}
// Build build raw expression
func (expr Expr) Build(builder Builder) {
var (
afterParenthesis bool
idx int
)
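// replace each '?' with a bound variable; a slice value is expanded element by element when the placeholder follows '(' or WithoutParentheses is set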
for _, v := range []byte(expr.SQL) {
if v == '?' && len(expr.Vars) > idx {
if afterParenthesis || expr.WithoutParentheses {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else {
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
if idx < len(expr.Vars) {
for _, v := range expr.Vars[idx:] {
builder.AddVar(builder, sql.NamedArg{Value: v})
}
}
}
// NamedExpr raw expression for named expr
type NamedExpr struct {
SQL string
Vars []interface{}
}
// Build build named expression
func (expr NamedExpr) Build(builder Builder) {
var (
idx int
inName bool
afterParenthesis bool
namedMap = make(map[string]interface{}, len(expr.Vars))
)
for _, v := range expr.Vars {
switch value := v.(type) {
case sql.NamedArg:
namedMap[value.Name] = value.Value
case map[string]interface{}:
for k, v := range value {
namedMap[k] = v
}
default:
var appendFieldsToMap func(reflect.Value)
appendFieldsToMap = func(reflectValue reflect.Value) {
reflectValue = reflect.Indirect(reflectValue)
switch reflectValue.Kind() {
case reflect.Struct:
modelType := reflectValue.Type()
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
if fieldStruct.Anonymous {
appendFieldsToMap(reflectValue.Field(i))
}
}
}
}
}
appendFieldsToMap(reflect.ValueOf(value))
}
}
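// named values were collected above from sql.NamedArg, map and exported struct field arguments; now scan the SQL for @name references and bind them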
name := make([]byte, 0, 10)
for _, v := range []byte(expr.SQL) {
if v == '@' && !inName {
inName = true
name = name[:0]
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
inName = false
}
afterParenthesis = false
builder.WriteByte(v)
} else if v == '?' && len(expr.Vars) > idx {
if afterParenthesis {
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
builder.AddVar(builder, expr.Vars[idx])
} else {
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
builder.AddVar(builder, nil)
} else {
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
}
default:
builder.AddVar(builder, expr.Vars[idx])
}
}
} else {
builder.AddVar(builder, expr.Vars[idx])
}
idx++
} else if inName {
name = append(name, v)
} else {
if v == '(' {
afterParenthesis = true
} else {
afterParenthesis = false
}
builder.WriteByte(v)
}
}
if inName {
if nv, ok := namedMap[string(name)]; ok {
builder.AddVar(builder, nv)
} else {
builder.WriteByte('@')
builder.WriteString(string(name))
}
}
}
// IN Whether a value is within a set of values
type IN struct {
Column interface{}
Values []interface{}
}
func (in IN) Build(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.WriteString(" IN (NULL)")
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" = ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
default:
builder.WriteString(" IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
}
}
func (in IN) NegationBuild(builder Builder) {
builder.WriteQuoted(in.Column)
switch len(in.Values) {
case 0:
builder.WriteString(" IS NOT NULL")
case 1:
if _, ok := in.Values[0].([]interface{}); !ok {
builder.WriteString(" <> ")
builder.AddVar(builder, in.Values[0])
break
}
fallthrough
default:
builder.WriteString(" NOT IN (")
builder.AddVar(builder, in.Values...)
builder.WriteByte(')')
}
}
// Eq equal to for where
type Eq struct {
Column interface{}
Value interface{}
}
func (eq Eq) Build(builder Builder) {
builder.WriteQuoted(eq.Column)
switch eq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
rv := reflect.ValueOf(eq.Value)
if rv.Len() == 0 {
builder.WriteString(" IN (NULL)")
} else {
builder.WriteString(" IN (")
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
}
default:
if eqNil(eq.Value) {
builder.WriteString(" IS NULL")
} else {
builder.WriteString(" = ")
builder.AddVar(builder, eq.Value)
}
}
}
func (eq Eq) NegationBuild(builder Builder) {
Neq(eq).Build(builder)
}
// Neq not equal to for where
type Neq Eq
func (neq Neq) Build(builder Builder) {
builder.WriteQuoted(neq.Column)
switch neq.Value.(type) {
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
builder.WriteString(" NOT IN (")
rv := reflect.ValueOf(neq.Value)
for i := 0; i < rv.Len(); i++ {
if i > 0 {
builder.WriteByte(',')
}
builder.AddVar(builder, rv.Index(i).Interface())
}
builder.WriteByte(')')
default:
if eqNil(neq.Value) {
builder.WriteString(" IS NOT NULL")
} else {
builder.WriteString(" <> ")
builder.AddVar(builder, neq.Value)
}
}
}
func (neq Neq) NegationBuild(builder Builder) {
Eq(neq).Build(builder)
}
// Gt greater than for where
type Gt Eq
func (gt Gt) Build(builder Builder) {
builder.WriteQuoted(gt.Column)
builder.WriteString(" > ")
builder.AddVar(builder, gt.Value)
}
func (gt Gt) NegationBuild(builder Builder) {
Lte(gt).Build(builder)
}
// Gte greater than or equal to for where
type Gte Eq
func (gte Gte) Build(builder Builder) {
builder.WriteQuoted(gte.Column)
builder.WriteString(" >= ")
builder.AddVar(builder, gte.Value)
}
func (gte Gte) NegationBuild(builder Builder) {
Lt(gte).Build(builder)
}
// Lt less than for where
type Lt Eq
func (lt Lt) Build(builder Builder) {
builder.WriteQuoted(lt.Column)
builder.WriteString(" < ")
builder.AddVar(builder, lt.Value)
}
func (lt Lt) NegationBuild(builder Builder) {
Gte(lt).Build(builder)
}
// Lte less than or equal to for where
type Lte Eq
func (lte Lte) Build(builder Builder) {
builder.WriteQuoted(lte.Column)
builder.WriteString(" <= ")
builder.AddVar(builder, lte.Value)
}
func (lte Lte) NegationBuild(builder Builder) {
Gt(lte).Build(builder)
}
// Like whether the string matches the given pattern (SQL LIKE)
type Like Eq
func (like Like) Build(builder Builder) {
builder.WriteQuoted(like.Column)
builder.WriteString(" LIKE ")
builder.AddVar(builder, like.Value)
}
func (like Like) NegationBuild(builder Builder) {
builder.WriteQuoted(like.Column)
builder.WriteString(" NOT LIKE ")
builder.AddVar(builder, like.Value)
}
func eqNil(value interface{}) bool {
if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) {
value, _ = valuer.Value()
}
return value == nil || eqNilReflect(value)
}
func eqNilReflect(value interface{}) bool {
reflectValue := reflect.ValueOf(value)
return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
}

37
vendor/gorm.io/gorm/clause/from.go generated vendored Normal file

@@ -0,0 +1,37 @@
package clause
// From from clause
type From struct {
Tables []Table
Joins []Join
}
// Name from clause name
func (from From) Name() string {
return "FROM"
}
// Build build from clause
func (from From) Build(builder Builder) {
if len(from.Tables) > 0 {
for idx, table := range from.Tables {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(table)
}
} else {
builder.WriteQuoted(currentTable)
}
for _, join := range from.Joins {
builder.WriteByte(' ')
join.Build(builder)
}
}
// MergeClause merge from clause
func (from From) MergeClause(clause *Clause) {
clause.Expression = from
}

48
vendor/gorm.io/gorm/clause/group_by.go generated vendored Normal file

@@ -0,0 +1,48 @@
package clause
// GroupBy group by clause
type GroupBy struct {
Columns []Column
Having []Expression
}
// Name group by clause name
func (groupBy GroupBy) Name() string {
return "GROUP BY"
}
// Build build group by clause
func (groupBy GroupBy) Build(builder Builder) {
for idx, column := range groupBy.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
if len(groupBy.Having) > 0 {
builder.WriteString(" HAVING ")
Where{Exprs: groupBy.Having}.Build(builder)
}
}
// MergeClause merge group by clause
func (groupBy GroupBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(GroupBy); ok {
copiedColumns := make([]Column, len(v.Columns))
copy(copiedColumns, v.Columns)
groupBy.Columns = append(copiedColumns, groupBy.Columns...)
copiedHaving := make([]Expression, len(v.Having))
copy(copiedHaving, v.Having)
groupBy.Having = append(copiedHaving, groupBy.Having...)
}
clause.Expression = groupBy
if len(groupBy.Columns) == 0 {
clause.Name = ""
} else {
clause.Name = groupBy.Name()
}
}

39
vendor/gorm.io/gorm/clause/insert.go generated vendored Normal file

@@ -0,0 +1,39 @@
package clause
type Insert struct {
Table Table
Modifier string
}
// Name insert clause name
func (insert Insert) Name() string {
return "INSERT"
}
// Build build insert clause
func (insert Insert) Build(builder Builder) {
if insert.Modifier != "" {
builder.WriteString(insert.Modifier)
builder.WriteByte(' ')
}
builder.WriteString("INTO ")
if insert.Table.Name == "" {
builder.WriteQuoted(currentTable)
} else {
builder.WriteQuoted(insert.Table)
}
}
// MergeClause merge insert clause
func (insert Insert) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Insert); ok {
if insert.Modifier == "" {
insert.Modifier = v.Modifier
}
if insert.Table.Name == "" {
insert.Table = v.Table
}
}
clause.Expression = insert
}

47
vendor/gorm.io/gorm/clause/joins.go generated vendored Normal file

@@ -0,0 +1,47 @@
package clause
type JoinType string
const (
CrossJoin JoinType = "CROSS"
InnerJoin JoinType = "INNER"
LeftJoin JoinType = "LEFT"
RightJoin JoinType = "RIGHT"
)
// Join clause for from
type Join struct {
Type JoinType
Table Table
ON Where
Using []string
Expression Expression
}
func (join Join) Build(builder Builder) {
if join.Expression != nil {
join.Expression.Build(builder)
} else {
if join.Type != "" {
builder.WriteString(string(join.Type))
builder.WriteByte(' ')
}
builder.WriteString("JOIN ")
builder.WriteQuoted(join.Table)
if len(join.ON.Exprs) > 0 {
builder.WriteString(" ON ")
join.ON.Build(builder)
} else if len(join.Using) > 0 {
builder.WriteString(" USING (")
for idx, c := range join.Using {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(c)
}
builder.WriteByte(')')
}
}
}

46
vendor/gorm.io/gorm/clause/limit.go generated vendored Normal file

@@ -0,0 +1,46 @@
package clause
// Limit limit clause
type Limit struct {
Limit *int
Offset int
}
// Name limit clause name
func (limit Limit) Name() string {
return "LIMIT"
}
// Build build limit clause
func (limit Limit) Build(builder Builder) {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteString("LIMIT ")
builder.AddVar(builder, *limit.Limit)
}
if limit.Offset > 0 {
if limit.Limit != nil && *limit.Limit >= 0 {
builder.WriteByte(' ')
}
builder.WriteString("OFFSET ")
builder.AddVar(builder, limit.Offset)
}
}
// MergeClause merge limit clauses
func (limit Limit) MergeClause(clause *Clause) {
clause.Name = ""
if v, ok := clause.Expression.(Limit); ok {
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
limit.Limit = v.Limit
}
if limit.Offset == 0 && v.Offset > 0 {
limit.Offset = v.Offset
} else if limit.Offset < 0 {
limit.Offset = 0
}
}
clause.Expression = limit
}

38
vendor/gorm.io/gorm/clause/locking.go generated vendored Normal file

@@ -0,0 +1,38 @@
package clause
const (
LockingStrengthUpdate = "UPDATE"
LockingStrengthShare = "SHARE"
LockingOptionsSkipLocked = "SKIP LOCKED"
LockingOptionsNoWait = "NOWAIT"
)
type Locking struct {
Strength string
Table Table
Options string
}
// Name locking clause name
func (locking Locking) Name() string {
return "FOR"
}
// Build build locking clause
func (locking Locking) Build(builder Builder) {
builder.WriteString(locking.Strength)
if locking.Table.Name != "" {
builder.WriteString(" OF ")
builder.WriteQuoted(locking.Table)
}
if locking.Options != "" {
builder.WriteByte(' ')
builder.WriteString(locking.Options)
}
}
// MergeClause merge locking clauses
func (locking Locking) MergeClause(clause *Clause) {
clause.Expression = locking
}

59
vendor/gorm.io/gorm/clause/on_conflict.go generated vendored Normal file

@@ -0,0 +1,59 @@
package clause
type OnConflict struct {
Columns []Column
Where Where
TargetWhere Where
OnConstraint string
DoNothing bool
DoUpdates Set
UpdateAll bool
}
func (OnConflict) Name() string {
return "ON CONFLICT"
}
// Build build onConflict clause
func (onConflict OnConflict) Build(builder Builder) {
if onConflict.OnConstraint != "" {
builder.WriteString("ON CONSTRAINT ")
builder.WriteString(onConflict.OnConstraint)
builder.WriteByte(' ')
} else {
if len(onConflict.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range onConflict.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteString(`) `)
}
if len(onConflict.TargetWhere.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.TargetWhere.Build(builder)
builder.WriteByte(' ')
}
}
if onConflict.DoNothing {
builder.WriteString("DO NOTHING")
} else {
builder.WriteString("DO UPDATE SET ")
onConflict.DoUpdates.Build(builder)
}
if len(onConflict.Where.Exprs) > 0 {
builder.WriteString(" WHERE ")
onConflict.Where.Build(builder)
builder.WriteByte(' ')
}
}
// MergeClause merge onConflict clauses
func (onConflict OnConflict) MergeClause(clause *Clause) {
clause.Expression = onConflict
}

54
vendor/gorm.io/gorm/clause/order_by.go generated vendored Normal file

@@ -0,0 +1,54 @@
package clause
type OrderByColumn struct {
Column Column
Desc bool
Reorder bool
}
type OrderBy struct {
Columns []OrderByColumn
Expression Expression
}
// Name order by clause name
func (orderBy OrderBy) Name() string {
return "ORDER BY"
}
// Build build order by clause
func (orderBy OrderBy) Build(builder Builder) {
if orderBy.Expression != nil {
orderBy.Expression.Build(builder)
} else {
for idx, column := range orderBy.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column.Column)
if column.Desc {
builder.WriteString(" DESC")
}
}
}
}
// MergeClause merge order by clauses
func (orderBy OrderBy) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(OrderBy); ok {
for i := len(orderBy.Columns) - 1; i >= 0; i-- {
if orderBy.Columns[i].Reorder {
orderBy.Columns = orderBy.Columns[i:]
clause.Expression = orderBy
return
}
}
copiedColumns := make([]OrderByColumn, len(v.Columns))
copy(copiedColumns, v.Columns)
orderBy.Columns = append(copiedColumns, orderBy.Columns...)
}
clause.Expression = orderBy
}

34
vendor/gorm.io/gorm/clause/returning.go generated vendored Normal file

@@ -0,0 +1,34 @@
package clause
type Returning struct {
Columns []Column
}
// Name returning clause name
func (returning Returning) Name() string {
return "RETURNING"
}
// Build build returning clause
func (returning Returning) Build(builder Builder) {
if len(returning.Columns) > 0 {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
} else {
builder.WriteByte('*')
}
}
// MergeClause merge returning clauses
func (returning Returning) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Returning); ok {
returning.Columns = append(v.Columns, returning.Columns...)
}
clause.Expression = returning
}

59
vendor/gorm.io/gorm/clause/select.go generated vendored Normal file

@@ -0,0 +1,59 @@
package clause
// Select select attrs when querying, updating, creating
type Select struct {
Distinct bool
Columns []Column
Expression Expression
}
func (s Select) Name() string {
return "SELECT"
}
func (s Select) Build(builder Builder) {
if len(s.Columns) > 0 {
if s.Distinct {
builder.WriteString("DISTINCT ")
}
for idx, column := range s.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
} else {
builder.WriteByte('*')
}
}
func (s Select) MergeClause(clause *Clause) {
if s.Expression != nil {
if s.Distinct {
if expr, ok := s.Expression.(Expr); ok {
expr.SQL = "DISTINCT " + expr.SQL
clause.Expression = expr
return
}
}
clause.Expression = s.Expression
} else {
clause.Expression = s
}
}
// CommaExpression represents a group of expressions separated by commas.
type CommaExpression struct {
Exprs []Expression
}
func (comma CommaExpression) Build(builder Builder) {
for idx, expr := range comma.Exprs {
if idx > 0 {
_, _ = builder.WriteString(", ")
}
expr.Build(builder)
}
}

60
vendor/gorm.io/gorm/clause/set.go generated vendored Normal file

@@ -0,0 +1,60 @@
package clause
import "sort"
type Set []Assignment
type Assignment struct {
Column Column
Value interface{}
}
func (set Set) Name() string {
return "SET"
}
func (set Set) Build(builder Builder) {
if len(set) > 0 {
for idx, assignment := range set {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(assignment.Column)
builder.WriteByte('=')
builder.AddVar(builder, assignment.Value)
}
} else {
builder.WriteQuoted(Column{Name: PrimaryKey})
builder.WriteByte('=')
builder.WriteQuoted(Column{Name: PrimaryKey})
}
}
// MergeClause merge assignments clauses
func (set Set) MergeClause(clause *Clause) {
copiedAssignments := make([]Assignment, len(set))
copy(copiedAssignments, set)
clause.Expression = Set(copiedAssignments)
}
func Assignments(values map[string]interface{}) Set {
keys := make([]string, 0, len(values))
for key := range values {
keys = append(keys, key)
}
sort.Strings(keys)
assignments := make([]Assignment, len(keys))
for idx, key := range keys {
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
}
return assignments
}
func AssignmentColumns(values []string) Set {
assignments := make([]Assignment, len(values))
for idx, value := range values {
assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
}
return assignments
}

38
vendor/gorm.io/gorm/clause/update.go generated vendored Normal file

@@ -0,0 +1,38 @@
package clause
type Update struct {
Modifier string
Table Table
}
// Name update clause name
func (update Update) Name() string {
return "UPDATE"
}
// Build build update clause
func (update Update) Build(builder Builder) {
if update.Modifier != "" {
builder.WriteString(update.Modifier)
builder.WriteByte(' ')
}
if update.Table.Name == "" {
builder.WriteQuoted(currentTable)
} else {
builder.WriteQuoted(update.Table)
}
}
// MergeClause merge update clause
func (update Update) MergeClause(clause *Clause) {
if v, ok := clause.Expression.(Update); ok {
if update.Modifier == "" {
update.Modifier = v.Modifier
}
if update.Table.Name == "" {
update.Table = v.Table
}
}
clause.Expression = update
}

45
vendor/gorm.io/gorm/clause/values.go generated vendored Normal file

@@ -0,0 +1,45 @@
package clause
type Values struct {
Columns []Column
Values [][]interface{}
}
// Name values clause name
func (Values) Name() string {
return "VALUES"
}
// Build build values clause
func (values Values) Build(builder Builder) {
if len(values.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range values.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteByte(')')
builder.WriteString(" VALUES ")
for idx, value := range values.Values {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteByte('(')
builder.AddVar(builder, value...)
builder.WriteByte(')')
}
} else {
builder.WriteString("DEFAULT VALUES")
}
}
// MergeClause merge values clauses
func (values Values) MergeClause(clause *Clause) {
clause.Name = ""
clause.Expression = values
}

201
vendor/gorm.io/gorm/clause/where.go generated vendored Normal file

@@ -0,0 +1,201 @@
package clause
import (
"strings"
)
const (
AndWithSpace = " AND "
OrWithSpace = " OR "
)
// Where where clause
type Where struct {
Exprs []Expression
}
// Name where clause name
func (where Where) Name() string {
return "WHERE"
}
// Build build where clause
func (where Where) Build(builder Builder) {
if len(where.Exprs) == 1 {
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
where.Exprs = andCondition.Exprs
}
}
// Switch position if the first query expression is a single Or condition
for idx, expr := range where.Exprs {
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
if idx != 0 {
where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
}
break
}
}
buildExprs(where.Exprs, builder, AndWithSpace)
}
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
wrapInParentheses := false
for idx, expr := range exprs {
if idx > 0 {
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
builder.WriteString(OrWithSpace)
} else {
builder.WriteString(joinCond)
}
}
if len(exprs) > 1 {
switch v := expr.(type) {
case OrConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case AndConditions:
if len(v.Exprs) == 1 {
if e, ok := v.Exprs[0].(Expr); ok {
sql := strings.ToUpper(e.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
case Expr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
case NamedExpr:
sql := strings.ToUpper(v.SQL)
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
}
}
if wrapInParentheses {
builder.WriteByte('(')
expr.Build(builder)
builder.WriteByte(')')
wrapInParentheses = false
} else {
expr.Build(builder)
}
}
}
// MergeClause merge where clauses
func (where Where) MergeClause(clause *Clause) {
if w, ok := clause.Expression.(Where); ok {
exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
copy(exprs, w.Exprs)
copy(exprs[len(w.Exprs):], where.Exprs)
where.Exprs = exprs
}
clause.Expression = where
}
func And(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
if len(exprs) == 1 {
if _, ok := exprs[0].(OrConditions); !ok {
return exprs[0]
}
}
return AndConditions{Exprs: exprs}
}
type AndConditions struct {
Exprs []Expression
}
func (and AndConditions) Build(builder Builder) {
if len(and.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(and.Exprs, builder, AndWithSpace)
builder.WriteByte(')')
} else {
buildExprs(and.Exprs, builder, AndWithSpace)
}
}
func Or(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
return OrConditions{Exprs: exprs}
}
type OrConditions struct {
Exprs []Expression
}
func (or OrConditions) Build(builder Builder) {
if len(or.Exprs) > 1 {
builder.WriteByte('(')
buildExprs(or.Exprs, builder, OrWithSpace)
builder.WriteByte(')')
} else {
buildExprs(or.Exprs, builder, OrWithSpace)
}
}
func Not(exprs ...Expression) Expression {
if len(exprs) == 0 {
return nil
}
if len(exprs) == 1 {
if andCondition, ok := exprs[0].(AndConditions); ok {
exprs = andCondition.Exprs
}
}
return NotConditions{Exprs: exprs}
}
type NotConditions struct {
Exprs []Expression
}
func (not NotConditions) Build(builder Builder) {
if len(not.Exprs) > 1 {
builder.WriteByte('(')
}
for idx, c := range not.Exprs {
if idx > 0 {
builder.WriteString(AndWithSpace)
}
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
negationBuilder.NegationBuild(builder)
} else {
builder.WriteString("NOT ")
e, wrapInParentheses := c.(Expr)
if wrapInParentheses {
sql := strings.ToUpper(e.SQL)
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
builder.WriteByte('(')
}
}
c.Build(builder)
if wrapInParentheses {
builder.WriteByte(')')
}
}
}
if len(not.Exprs) > 1 {
builder.WriteByte(')')
}
}
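
A short sketch of how the combinators above compose into a WHERE clause. The `User` model is an illustrative assumption and the resulting SQL is only indicative; `DryRun` keeps the statement from being executed:

```go
package example

import (
	"fmt"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// User is a hypothetical model used only to build a statement.
type User struct {
	ID       uint
	Role     string
	Age      int
	Verified bool
}

// adminsOrAdults builds roughly: WHERE role = ? AND (age > ? OR verified = ?)
func adminsOrAdults(db *gorm.DB) string {
	var users []User
	stmt := db.Session(&gorm.Session{DryRun: true}).Model(&User{}).Clauses(clause.Where{
		Exprs: []clause.Expression{
			clause.And(
				clause.Eq{Column: clause.Column{Name: "role"}, Value: "admin"},
				clause.Or(
					clause.Gt{Column: clause.Column{Name: "age"}, Value: 18},
					clause.Eq{Column: clause.Column{Name: "verified"}, Value: true},
				),
			),
		},
	}).Find(&users).Statement
	return fmt.Sprintf("%s %v", stmt.SQL.String(), stmt.Vars)
}
```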

3
vendor/gorm.io/gorm/clause/with.go generated vendored Normal file

@@ -0,0 +1,3 @@
package clause
type With struct{}

52
vendor/gorm.io/gorm/errors.go generated vendored Normal file

@@ -0,0 +1,52 @@
package gorm
import (
"errors"
"gorm.io/gorm/logger"
)
var (
// ErrRecordNotFound record not found error
ErrRecordNotFound = logger.ErrRecordNotFound
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
ErrInvalidTransaction = errors.New("invalid transaction")
// ErrNotImplemented not implemented
ErrNotImplemented = errors.New("not implemented")
// ErrMissingWhereClause missing where clause
ErrMissingWhereClause = errors.New("WHERE conditions required")
// ErrUnsupportedRelation unsupported relations
ErrUnsupportedRelation = errors.New("unsupported relations")
// ErrPrimaryKeyRequired primary keys required
ErrPrimaryKeyRequired = errors.New("primary key required")
// ErrModelValueRequired model value required
ErrModelValueRequired = errors.New("model value required")
// ErrModelAccessibleFieldsRequired model accessible fields required
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
// ErrSubQueryRequired sub query required
ErrSubQueryRequired = errors.New("sub query required")
// ErrInvalidData unsupported data
ErrInvalidData = errors.New("unsupported data")
// ErrUnsupportedDriver unsupported driver
ErrUnsupportedDriver = errors.New("unsupported driver")
// ErrRegistered registered
ErrRegistered = errors.New("registered")
// ErrInvalidField invalid field
ErrInvalidField = errors.New("invalid field")
// ErrEmptySlice empty slice found
ErrEmptySlice = errors.New("empty slice found")
// ErrDryRunModeUnsupported dry run mode unsupported
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
// ErrInvalidDB invalid db
ErrInvalidDB = errors.New("invalid db")
// ErrInvalidValue invalid value
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
// ErrInvalidValueOfLength invalid values do not match length
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
// ErrPreloadNotAllowed preload is not allowed when count is used
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
// ErrDuplicatedKey occurs when there is a unique key constraint violation
ErrDuplicatedKey = errors.New("duplicated key not allowed")
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
)
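
A sketch of how these sentinel errors are typically consumed with `errors.Is`. Translation of driver errors into `ErrDuplicatedKey` / `ErrForeignKeyViolated` assumes `gorm.Config{TranslateError: true}` and a dialector implementing `ErrorTranslator` (such as the postgres driver in this commit); the `User` model is hypothetical:

```go
package example

import (
	"errors"
	"fmt"

	"gorm.io/gorm"
)

// User is a hypothetical model with a unique name.
type User struct {
	ID   uint
	Name string `gorm:"uniqueIndex"`
}

func createUser(db *gorm.DB, name string) error {
	err := db.Create(&User{Name: name}).Error
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrDuplicatedKey):
		// Only reached when error translation is enabled (TranslateError: true).
		return fmt.Errorf("user %q already exists: %w", name, err)
	default:
		return err
	}
}

func findUser(db *gorm.DB, id uint) (*User, error) {
	var u User
	if err := db.First(&u, id).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil // treat "not found" as an empty result here
		}
		return nil, err
	}
	return &u, nil
}
```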

770
vendor/gorm.io/gorm/finisher_api.go generated vendored Normal file

@@ -0,0 +1,770 @@
package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}
tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
}
// CreateInBatches inserts value in batches of batchSize
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var rowsAffected int64
tx = db.getInstance()
// cache the reflection length so it is not recomputed on every iteration
reflectLen := reflectValue.Len()
callFc := func(tx *DB) error {
for i := 0; i < reflectLen; i += batchSize {
ends := i + batchSize
if ends > reflectLen {
ends = reflectLen
}
subtx := tx.getInstance()
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
subtx.callbacks.Create().Execute(subtx)
if subtx.Error != nil {
return subtx.Error
}
rowsAffected += subtx.RowsAffected
}
return nil
}
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
tx.AddError(callFc(tx.Session(&Session{})))
} else {
tx.AddError(tx.Transaction(callFc))
}
tx.RowsAffected = rowsAffected
default:
tx = db.getInstance()
tx.Statement.Dest = value
tx = tx.callbacks.Create().Execute(tx)
}
return
}
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
reflectValue := reflect.Indirect(reflect.ValueOf(value))
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
reflectValue = reflect.Indirect(reflectValue)
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
}
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
case reflect.Struct:
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
for _, pf := range tx.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
return tx.callbacks.Create().Execute(tx)
}
}
}
fallthrough
default:
selectedUpdate := len(tx.Statement.Selects) != 0
// when updating, use all fields including those zero-value fields
if !selectedUpdate {
tx.Statement.Selects = append(tx.Statement.Selects, "*")
}
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
}
return updateTx
}
return
}
// First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
}
// Take finds the first record returned by the database in no specified order, matching given conditions conds
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1)
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
}
// Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Desc: true,
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.RaiseErrorOnNotFound = true
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
}
// Find finds all records matching given conditions conds
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
}
// FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
var (
tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
}).Session(&Session{})
queryDB = tx
rowsAffected int64
batch int
)
// user specified offset or limit
var totalSize int
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
if limit, ok := c.Expression.(clause.Limit); ok {
if limit.Limit != nil {
totalSize = *limit.Limit
}
if totalSize > 0 && batchSize > totalSize {
batchSize = totalSize
}
// reset the offset to 0 for the next batch
tx = tx.Offset(-1).Session(&Session{})
}
}
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
batch++
if result.Error == nil && result.RowsAffected != 0 {
fcTx := result.Session(&Session{NewDB: true})
fcTx.RowsAffected = result.RowsAffected
tx.AddError(fc(fcTx, batch))
} else if result.Error != nil {
tx.AddError(result.Error)
}
if tx.Error != nil || int(result.RowsAffected) < batchSize {
break
}
if totalSize > 0 {
if totalSize <= int(rowsAffected) {
break
}
if totalSize/batchSize == batch {
batchSize = totalSize % batchSize
}
}
// advance pagination by primary key; stop if it cannot be determined
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
tx.AddError(ErrPrimaryKeyRequired)
break
}
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
tx.RowsAffected = rowsAffected
return tx
}
func (db *DB) assignInterfacesToValue(values ...interface{}) {
for _, value := range values {
switch v := value.(type) {
case []clause.Expression:
for _, expr := range v {
if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
if field := db.Statement.Schema.LookUpField(column); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
case clause.Column:
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
}
}
} else if andCond, ok := expr.(clause.AndConditions); ok {
db.assignInterfacesToValue(andCond.Exprs)
}
}
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
default:
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
reflectValue := reflect.Indirect(reflect.ValueOf(value))
switch reflectValue.Kind() {
case reflect.Struct:
for _, f := range s.Fields {
if f.Readable {
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
}
}
}
}
}
} else if len(values) > 0 {
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
db.assignInterfacesToValue(exprs)
}
return
}
}
}
}
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
// Each conds must be a struct or map.
//
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
//
// // assign an email if the record is not found
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
//
// // assign email regardless of if record is found
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
tx.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(tx.Statement.attrs) > 0 {
tx.assignInterfacesToValue(tx.Statement.attrs...)
}
}
// initialize with attrs, conds
if len(tx.Statement.assigns) > 0 {
tx.assignInterfacesToValue(tx.Statement.assigns...)
}
return
}
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
// Each conds must be a struct or map.
//
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
//
// // assign an email if the record is not found
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
// // result.RowsAffected -> 1
//
// // assign email regardless of if record is found
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
// // result.RowsAffected -> 1
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
})
result := queryTx.Find(dest, conds...)
if result.Error != nil {
tx.Error = result.Error
return tx
}
if result.RowsAffected == 0 {
if c, ok := result.Statement.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok {
result.assignInterfacesToValue(where.Exprs)
}
}
// initialize with attrs, conds
if len(db.Statement.attrs) > 0 {
result.assignInterfacesToValue(db.Statement.attrs...)
}
// initialize with attrs, conds
if len(db.Statement.assigns) > 0 {
result.assignInterfacesToValue(db.Statement.assigns...)
}
return tx.Create(dest)
} else if len(db.Statement.assigns) > 0 {
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
assigns := map[string]interface{}{}
for i := 0; i < len(exprs); i++ {
expr := exprs[i]
if eq, ok := expr.(clause.AndConditions); ok {
exprs = append(exprs, eq.Exprs...)
} else if eq, ok := expr.(clause.Eq); ok {
switch column := eq.Column.(type) {
case string:
assigns[column] = eq.Value
case clause.Column:
assigns[column.Name] = eq.Value
}
}
}
return tx.Model(dest).Updates(assigns)
}
return tx
}
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Update(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
return tx.callbacks.Update().Execute(tx)
}
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
func (db *DB) Updates(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
return tx.callbacks.Update().Execute(tx)
}
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = map[string]interface{}{column: value}
tx.Statement.SkipHooks = true
return tx.callbacks.Update().Execute(tx)
}
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = values
tx.Statement.SkipHooks = true
return tx.callbacks.Update().Execute(tx)
}
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
// time if null.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
tx.Statement.AddClause(clause.Where{Exprs: exprs})
}
}
tx.Statement.Dest = value
return tx.callbacks.Delete().Execute(tx)
}
func (db *DB) Count(count *int64) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model == nil {
tx.Statement.Model = tx.Statement.Dest
defer func() {
tx.Statement.Model = nil
}()
}
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
defer func() {
tx.Statement.Clauses["SELECT"] = selectClause
}()
} else {
defer delete(tx.Statement.Clauses, "SELECT")
}
if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
expr := clause.Expr{SQL: "count(*)"}
if len(tx.Statement.Selects) == 1 {
dbName := tx.Statement.Selects[0]
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
dbName = f.DBName
}
}
if tx.Statement.Distinct {
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
} else if dbName != "*" {
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
}
}
}
tx.Statement.AddClause(clause.Select{Expression: expr})
}
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
delete(tx.Statement.Clauses, "ORDER BY")
defer func() {
tx.Statement.Clauses["ORDER BY"] = orderByClause
}()
}
}
tx.Statement.Dest = count
tx = tx.callbacks.Query().Execute(tx)
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
*count = tx.RowsAffected
}
return
}
func (db *DB) Row() *sql.Row {
tx := db.getInstance().Set("rows", false)
tx = tx.callbacks.Row().Execute(tx)
row, ok := tx.Statement.Dest.(*sql.Row)
if !ok && tx.DryRun {
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
}
return row
}
func (db *DB) Rows() (*sql.Rows, error) {
tx := db.getInstance().Set("rows", true)
tx = tx.callbacks.Row().Execute(tx)
rows, ok := tx.Statement.Dest.(*sql.Rows)
if !ok && tx.DryRun && tx.Error == nil {
tx.Error = ErrDryRunModeUnsupported
}
return rows, tx.Error
}
// Scan scans selected value to the struct dest
func (db *DB) Scan(dest interface{}) (tx *DB) {
config := *db.Config
currentLogger, newLogger := config.Logger, logger.Recorder.New()
config.Logger = newLogger
tx = db.getInstance()
tx.Config = &config
if rows, err := tx.Rows(); err == nil {
if rows.Next() {
tx.ScanRows(rows, dest)
} else {
tx.RowsAffected = 0
tx.AddError(rows.Err())
}
tx.AddError(rows.Close())
}
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
return newLogger.SQL, tx.RowsAffected
}, tx.Error)
tx.Logger = currentLogger
return
}
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
//
// var ages []int64
// db.Model(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
if tx.Statement.Model != nil {
if tx.Statement.Parse(tx.Statement.Model) == nil {
if f := tx.Statement.Schema.LookUpField(column); f != nil {
column = f.DBName
}
}
}
if len(tx.Statement.Selects) != 1 {
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
tx.Statement.AddClauseIfNotExists(clause.Select{
Distinct: tx.Statement.Distinct,
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
})
}
tx.Statement.Dest = dest
return tx.callbacks.Query().Execute(tx)
}
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
tx := db.getInstance()
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
tx.AddError(err)
}
tx.Statement.Dest = dest
tx.Statement.ReflectValue = reflect.ValueOf(dest)
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
elem := tx.Statement.ReflectValue.Elem()
if !elem.IsValid() {
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
tx.Statement.ReflectValue.Set(elem)
}
tx.Statement.ReflectValue = elem
}
Scan(rows, tx, ScanInitialized)
return tx.Error
}
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
// returned to the connection pool.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
}
tx := db.getInstance()
sqlDB, err := tx.DB()
if err != nil {
return
}
conn, err := sqlDB.Conn(tx.Statement.Context)
if err != nil {
return
}
defer conn.Close()
tx.Statement.ConnPool = conn
return fc(tx)
}
// Transaction starts a transaction as a block: if fc returns an error the transaction is rolled back, otherwise it is
// committed. Transaction executes an arbitrary number of commands in fc within a transaction. On success the changes
// are committed; if an error occurs they are rolled back.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
if !db.DisableNestedTransaction {
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
if err != nil {
return
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()
}
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
} else {
tx := db.Begin(opts...)
if tx.Error != nil {
return tx.Error
}
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()
if err = fc(tx); err == nil {
panicked = false
return tx.Commit().Error
}
}
panicked = false
return
}
// Begin begins a transaction with any transaction options opts
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
var (
// clone statement
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
opt *sql.TxOptions
err error
)
if len(opts) > 0 {
opt = opts[0]
}
switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
default:
err = ErrInvalidTransaction
}
if err != nil {
tx.AddError(err)
}
return tx
}
// Commit commits the changes in a transaction
func (db *DB) Commit() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Commit())
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
// Rollback rolls back the changes in a transaction
func (db *DB) Rollback() *DB {
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
if !reflect.ValueOf(committer).IsNil() {
db.AddError(committer.Rollback())
}
} else {
db.AddError(ErrInvalidTransaction)
}
return db
}
func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close the prepared statement, because SavePoint does not support prepared statements.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// temporarily swap the prepared-statement ConnPool for the underlying Tx before issuing SAVEPOINT
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.SavePoint(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
// close the prepared statement, because RollbackTo does not support prepared statements.
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
var (
preparedStmtTx *PreparedStmtTX
isPreparedStmtTx bool
)
// temporarily swap the prepared-statement ConnPool for the underlying Tx before issuing ROLLBACK TO SAVEPOINT
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx.Tx
}
db.AddError(savePointer.RollbackTo(db, name))
// restore prepared statement
if isPreparedStmtTx {
db.Statement.ConnPool = preparedStmtTx
}
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}
// Exec executes raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.SQL = strings.Builder{}
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
} else {
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
}
return tx.callbacks.Raw().Execute(tx)
}
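
A compact sketch tying a few of these finishers together: `Transaction` for atomic writes and `FindInBatches` for paginated reads. `Account` and its columns are hypothetical:

```go
package example

import (
	"gorm.io/gorm"
)

// Account is a hypothetical model for the sketch below.
type Account struct {
	ID      uint
	Balance int64
}

// transferCredits moves amount between two accounts inside one transaction:
// returning an error (or panicking) from the closure rolls everything back.
func transferCredits(db *gorm.DB, fromID, toID uint, amount int64) error {
	return db.Transaction(func(tx *gorm.DB) error {
		if err := tx.Model(&Account{}).Where("id = ?", fromID).
			Update("balance", gorm.Expr("balance - ?", amount)).Error; err != nil {
			return err
		}
		return tx.Model(&Account{}).Where("id = ?", toID).
			Update("balance", gorm.Expr("balance + ?", amount)).Error
	})
}

// sumBalances walks all accounts in batches of 100 using FindInBatches.
func sumBalances(db *gorm.DB) (total int64, err error) {
	var batch []Account
	err = db.FindInBatches(&batch, 100, func(tx *gorm.DB, _ int) error {
		for _, a := range batch {
			total += a.Balance
		}
		return nil
	}).Error
	return total, err
}
```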

506
vendor/gorm.io/gorm/gorm.go generated vendored Normal file

@@ -0,0 +1,506 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"reflect"
"sort"
"sync"
"time"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// preparedStmtDBKey is the Config.cacheStore key under which the shared PreparedStmtDB is stored
const preparedStmtDBKey = "preparedStmt"
// Config GORM config
type Config struct {
// GORM performs single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
// DryRun generates SQL without executing it
DryRun bool
// PrepareStmt caches prepared statements and reuses them for subsequent queries
PrepareStmt bool
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
// CreateBatchSize default create batch size
CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
Dialector
// Plugins registered plugins
Plugins map[string]Plugin
callbacks *callbacks
cacheStore *sync.Map
}
// Apply update config to new config
func (c *Config) Apply(config *Config) error {
if config != c {
*config = *c
}
return nil
}
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error {
if db != nil {
for _, plugin := range c.Plugins {
if err := plugin.Initialize(db); err != nil {
return err
}
}
}
return nil
}
// Option gorm option interface
type Option interface {
Apply(*Config) error
AfterInitialize(*DB) error
}
// DB GORM DB definition
type DB struct {
*Config
Error error
RowsAffected int64
Statement *Statement
clone int
}
// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
}
// Open initialize db session based on dialector
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
config := &Config{}
sort.Slice(opts, func(i, j int) bool {
_, isConfig := opts[i].(*Config)
_, isConfig2 := opts[j].(*Config)
return isConfig && !isConfig2
})
for _, opt := range opts {
if opt != nil {
if applyErr := opt.Apply(config); applyErr != nil {
return nil, applyErr
}
defer func(opt Option) {
if errr := opt.AfterInitialize(db); errr != nil {
err = errr
}
}(opt)
}
}
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
if err = d.Apply(config); err != nil {
return
}
}
if config.NamingStrategy == nil {
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
}
if config.Logger == nil {
config.Logger = logger.Default
}
if config.NowFunc == nil {
config.NowFunc = func() time.Time { return time.Now().Local() }
}
if dialector != nil {
config.Dialector = dialector
}
if config.Plugins == nil {
config.Plugins = map[string]Plugin{}
}
if config.cacheStore == nil {
config.cacheStore = &sync.Map{}
}
db = &DB{Config: config, clone: 1}
db.callbacks = initializeCallbacks(db)
if config.ClauseBuilders == nil {
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
}
if config.Dialector != nil {
err = config.Dialector.Initialize(db)
if err != nil {
if db, _ := db.DB(); db != nil {
_ = db.Close()
}
}
}
if config.PrepareStmt {
preparedStmt := NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
db.ConnPool = preparedStmt
}
db.Statement = &Statement{
DB: db,
ConnPool: db.ConnPool,
Context: context.Background(),
Clauses: map[string]clause.Clause{},
}
if err == nil && !config.DisableAutomaticPing {
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
err = pinger.Ping()
}
}
if err != nil {
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
}
return
}
// Session create new db session
func (db *DB) Session(config *Session) *DB {
var (
txConfig = *db.Config
tx = &DB{
Config: &txConfig,
Statement: db.Statement,
Error: db.Error,
clone: 1,
}
)
if config.CreateBatchSize > 0 {
tx.Config.CreateBatchSize = config.CreateBatchSize
}
if config.SkipDefaultTransaction {
tx.Config.SkipDefaultTransaction = true
}
if config.AllowGlobalUpdate {
txConfig.AllowGlobalUpdate = true
}
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
}
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
}
if config.Context != nil {
tx.Statement.Context = config.Context
}
if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
}
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
}
default:
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
}
}
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
}
if config.SkipHooks {
tx.Statement.SkipHooks = true
}
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
}
if !config.NewDB {
tx.clone = 2
}
if config.DryRun {
tx.Config.DryRun = true
}
if config.QueryFields {
tx.Config.QueryFields = true
}
if config.Logger != nil {
tx.Config.Logger = config.Logger
}
if config.NowFunc != nil {
tx.Config.NowFunc = config.NowFunc
}
if config.Initialized {
tx = tx.getInstance()
}
return tx
}
// WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB {
return db.Session(&Session{Context: ctx})
}
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
tx = db.getInstance()
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info),
})
}
// Set store value with key into current db instance's context
func (db *DB) Set(key string, value interface{}) *DB {
tx := db.getInstance()
tx.Statement.Settings.Store(key, value)
return tx
}
// Get get value with key from current db instance's context
func (db *DB) Get(key string) (interface{}, bool) {
return db.Statement.Settings.Load(key)
}
// InstanceSet store value with key into current db instance's context
func (db *DB) InstanceSet(key string, value interface{}) *DB {
tx := db.getInstance()
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
return tx
}
// InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) {
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}
// Callback returns callback manager
func (db *DB) Callback() *callbacks {
return db.callbacks
}
// AddError add error to db
func (db *DB) AddError(err error) error {
if err != nil {
if db.Config.TranslateError {
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
err = errTranslator.Translate(err)
}
}
if db.Error == nil {
db.Error = err
} else {
db.Error = fmt.Errorf("%v; %w", db.Error, err)
}
}
return db.Error
}
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
}
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
}
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
}
}
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil
}
return nil, ErrInvalidDB
}
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
}
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
}
return tx
}
return db
}
// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args}
}
// SetupJoinTable setup join table schema
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
var (
tx = db.getInstance()
stmt = tx.Statement
modelSchema, joinSchema *schema.Schema
)
err := stmt.Parse(model)
if err != nil {
return err
}
modelSchema = stmt.Schema
err = stmt.Parse(joinTable)
if err != nil {
return err
}
joinSchema = stmt.Schema
relation, ok := modelSchema.Relationships.Relations[field]
isRelation := ok && relation.JoinTable != nil
if !isRelation {
return fmt.Errorf("failed to find relation: %s", field)
}
for _, ref := range relation.References {
f := joinSchema.LookUpField(ref.ForeignKey.DBName)
if f == nil {
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
}
f.DataType = ref.ForeignKey.DataType
f.GORMDataType = ref.ForeignKey.GORMDataType
if f.Size == 0 {
f.Size = ref.ForeignKey.Size
}
ref.ForeignKey = f
}
for name, rel := range relation.JoinTable.Relationships.Relations {
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
rel.Schema = joinSchema
joinSchema.Relationships.Relations[name] = rel
}
}
relation.JoinTable = joinSchema
return nil
}
// Use use plugin
func (db *DB) Use(plugin Plugin) error {
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
return ErrRegistered
}
if err := plugin.Initialize(db); err != nil {
return err
}
db.Plugins[name] = plugin
return nil
}
// ToSQL generates the SQL string for the query built in queryFn without executing it.
//
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5)
// .Order("name ASC")
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
}
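
A sketch of deriving request-scoped sessions from a shared `*gorm.DB` via `Session` and `WithContext` above; the `User` model and the chosen session options are assumptions for illustration:

```go
package example

import (
	"context"

	"gorm.io/gorm"
)

// User is a hypothetical model.
type User struct {
	ID     uint
	Active bool
}

// activeUsers derives a request-scoped session: the context flows into every
// query, and the session-level options affect only this derived handle.
func activeUsers(ctx context.Context, db *gorm.DB) ([]User, error) {
	var users []User
	tx := db.WithContext(ctx).Session(&gorm.Session{
		PrepareStmt: true, // reuse cached prepared statements for this session
		QueryFields: true, // SELECT the model's fields explicitly instead of *
	})
	err := tx.Where("active = ?", true).Find(&users).Error
	return users, err
}
```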

92
vendor/gorm.io/gorm/interfaces.go generated vendored Normal file

@@ -0,0 +1,92 @@
package gorm
import (
"context"
"database/sql"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// Dialector GORM database dialector
type Dialector interface {
Name() string
Initialize(*DB) error
Migrator(db *DB) Migrator
DataTypeOf(*schema.Field) string
DefaultValueOf(*schema.Field) clause.Expression
BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
QuoteTo(clause.Writer, string)
Explain(sql string, vars ...interface{}) string
}
// Plugin GORM plugin interface
type Plugin interface {
Name() string
Initialize(*DB) error
}
type ParamsFilter interface {
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}
// ConnPool db conns pool interface
type ConnPool interface {
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
// SavePointerDialectorInterface save pointer interface
type SavePointerDialectorInterface interface {
SavePoint(tx *DB, name string) error
RollbackTo(tx *DB, name string) error
}
// TxBeginner tx beginner
type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}
// TxCommitter tx committer
type TxCommitter interface {
Commit() error
Rollback() error
}
// Tx sql.Tx interface
type Tx interface {
ConnPool
TxCommitter
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}
// Valuer gorm valuer interface
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}
// GetDBConnector SQL db connector
type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}
// Rows rows interface
type Rows interface {
Columns() ([]string, error)
ColumnTypes() ([]*sql.ColumnType, error)
Next() bool
Scan(dest ...interface{}) error
Err() error
Close() error
}
type ErrorTranslator interface {
Translate(err error) error
}
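
A minimal sketch of implementing the `Plugin` interface above. The `metrics:*` callback name and plugin are made up for illustration; `gorm:query` is GORM's built-in query callback:

```go
package example

import (
	"log"

	"gorm.io/gorm"
)

// metricsPlugin is a hypothetical Plugin that logs row counts after every query.
type metricsPlugin struct{}

func (metricsPlugin) Name() string { return "metrics" }

func (metricsPlugin) Initialize(db *gorm.DB) error {
	// Hook in after GORM's built-in query callback.
	return db.Callback().Query().After("gorm:query").
		Register("metrics:after_query", func(tx *gorm.DB) {
			log.Printf("table=%s rows=%d err=%v", tx.Statement.Table, tx.RowsAffected, tx.Error)
		})
}

// Usage: db.Use(metricsPlugin{}); registering the same plugin name twice returns ErrRegistered.
```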

213
vendor/gorm.io/gorm/logger/logger.go generated vendored Normal file

@@ -0,0 +1,213 @@
package logger
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"time"
"gorm.io/gorm/utils"
)
// ErrRecordNotFound record not found error
var ErrRecordNotFound = errors.New("record not found")
// Colors
const (
Reset = "\033[0m"
Red = "\033[31m"
Green = "\033[32m"
Yellow = "\033[33m"
Blue = "\033[34m"
Magenta = "\033[35m"
Cyan = "\033[36m"
White = "\033[37m"
BlueBold = "\033[34;1m"
MagentaBold = "\033[35;1m"
RedBold = "\033[31;1m"
YellowBold = "\033[33;1m"
)
// LogLevel log level
type LogLevel int
const (
// Silent silent log level
Silent LogLevel = iota + 1
// Error error log level
Error
// Warn warn log level
Warn
// Info info log level
Info
)
// Writer log writer interface
type Writer interface {
Printf(string, ...interface{})
}
// Config logger config
type Config struct {
SlowThreshold time.Duration
Colorful bool
IgnoreRecordNotFoundError bool
ParameterizedQueries bool
LogLevel LogLevel
}
// Interface logger interface
type Interface interface {
LogMode(LogLevel) Interface
Info(context.Context, string, ...interface{})
Warn(context.Context, string, ...interface{})
Error(context.Context, string, ...interface{})
Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
}
var (
// Discard logger discards all logs (it writes to io.Discard)
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
// Default Default logger
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
SlowThreshold: 200 * time.Millisecond,
LogLevel: Warn,
IgnoreRecordNotFoundError: false,
Colorful: true,
})
// Recorder logger records running SQL into a recorder instance
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
)
// New initialize logger
func New(writer Writer, config Config) Interface {
var (
infoStr = "%s\n[info] "
warnStr = "%s\n[warn] "
errStr = "%s\n[error] "
traceStr = "%s\n[%.3fms] [rows:%v] %s"
traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s"
traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
)
if config.Colorful {
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
}
return &logger{
Writer: writer,
Config: config,
infoStr: infoStr,
warnStr: warnStr,
errStr: errStr,
traceStr: traceStr,
traceWarnStr: traceWarnStr,
traceErrStr: traceErrStr,
}
}
type logger struct {
Writer
Config
infoStr, warnStr, errStr string
traceStr, traceErrStr, traceWarnStr string
}
// LogMode log mode
func (l *logger) LogMode(level LogLevel) Interface {
newlogger := *l
newlogger.LogLevel = level
return &newlogger
}
// Info print info
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Info {
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Warn print warn messages
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Warn {
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Error print error messages
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
if l.LogLevel >= Error {
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}
// Trace print sql message
//
//nolint:cyclop
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
if l.LogLevel <= Silent {
return
}
elapsed := time.Since(begin)
switch {
case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
sql, rows := fc()
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
if rows == -1 {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
case l.LogLevel == Info:
sql, rows := fc()
if rows == -1 {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
} else {
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
}
}
// ParamsFilter filter params
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
if l.Config.ParameterizedQueries {
return sql, nil
}
return sql, params
}
type traceRecorder struct {
Interface
BeginAt time.Time
SQL string
RowsAffected int64
Err error
}
// New trace recorder
func (l *traceRecorder) New() *traceRecorder {
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
}
// Trace implement logger interface
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
l.BeginAt = begin
l.SQL, l.RowsAffected = fc()
l.Err = err
}
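
A sketch of building a custom logger from the `New` constructor above and attaching it to a derived session; the writer, threshold, and option values are illustrative choices:

```go
package example

import (
	"log"
	"os"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// newVerboseLogger prints every SQL statement and flags anything slower than
// 300ms, without colored output.
func newVerboseLogger() logger.Interface {
	return logger.New(
		log.New(os.Stdout, "\r\n", log.LstdFlags), // any Writer with Printf works
		logger.Config{
			SlowThreshold:             300 * time.Millisecond,
			LogLevel:                  logger.Info,
			IgnoreRecordNotFoundError: true,
			Colorful:                  false,
		},
	)
}

// withVerboseLogging attaches the logger to a derived session only,
// leaving the shared *gorm.DB untouched.
func withVerboseLogging(db *gorm.DB) *gorm.DB {
	return db.Session(&gorm.Session{Logger: newVerboseLogger()})
}
```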

162
vendor/gorm.io/gorm/logger/sql.go generated vendored Normal file

@@ -0,0 +1,162 @@
package logger
import (
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"unicode"
"gorm.io/gorm/utils"
)
const (
tmFmtWithMS = "2006-01-02 15:04:05.999"
tmFmtZero = "0000-00-00 00:00:00"
nullStr = "NULL"
)
func isPrintable(s string) bool {
for _, r := range s {
if !unicode.IsPrint(r) {
return false
}
}
return true
}
// A list of Go types that should be converted to SQL primitives
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
// numericPlaceholderRe matches the intermediate $n$ tokens produced while rewriting numeric placeholders (e.g. $1, $2)
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
// ExplainSQL generates a SQL string with the given parameters interpolated. The generated SQL is intended for logging only; executing it might introduce a SQL injection vulnerability.
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
var (
convertParams func(interface{}, int)
vars = make([]string, len(avars))
)
convertParams = func(v interface{}, idx int) {
switch v := v.(type) {
case bool:
vars[idx] = strconv.FormatBool(v)
case time.Time:
if v.IsZero() {
vars[idx] = escaper + tmFmtZero + escaper
} else {
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
}
case *time.Time:
if v != nil {
if v.IsZero() {
vars[idx] = escaper + tmFmtZero + escaper
} else {
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
}
} else {
vars[idx] = nullStr
}
case driver.Valuer:
reflectValue := reflect.ValueOf(v)
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
r, _ := v.Value()
convertParams(r, idx)
} else {
vars[idx] = nullStr
}
case fmt.Stringer:
reflectValue := reflect.ValueOf(v)
switch reflectValue.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
case reflect.Float32, reflect.Float64:
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
case reflect.Bool:
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
case reflect.String:
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
default:
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
} else {
vars[idx] = nullStr
}
}
case []byte:
if s := string(v); isPrintable(s) {
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
} else {
vars[idx] = escaper + "<binary>" + escaper
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
vars[idx] = utils.ToString(v)
case float32:
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
case string:
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
default:
rv := reflect.ValueOf(v)
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
vars[idx] = nullStr
} else if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
convertParams(v, idx)
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
convertParams(reflect.Indirect(rv).Interface(), idx)
} else {
for _, t := range convertibleTypes {
if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx)
return
}
}
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
}
}
}
for idx, v := range avars {
convertParams(v, idx)
}
if numericPlaceholder == nil {
var idx int
var newSQL strings.Builder
for _, v := range []byte(sql) {
if v == '?' {
if len(vars) > idx {
newSQL.WriteString(vars[idx])
idx++
continue
}
}
newSQL.WriteByte(v)
}
sql = newSQL.String()
} else {
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
num := v[1 : len(v)-1]
n, _ := strconv.Atoi(num)
// positional vars start from 1 ($1, $2)
n -= 1
if n >= 0 && n <= len(vars)-1 {
return vars[n]
}
return v
})
}
return sql
}
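
For reference, a small sketch of how `ExplainSQL` interpolates values; the output shown in the comment is indicative, and per the doc comment above the result is for logging only:

```go
package main

import (
	"fmt"
	"time"

	"gorm.io/gorm/logger"
)

func main() {
	out := logger.ExplainSQL(
		"SELECT * FROM users WHERE name = ? AND age > ? AND created_at < ?",
		nil, // nil numericPlaceholder: substitute '?' placeholders positionally
		`'`, // escaper used to quote strings and times
		"jinzhu", 18, time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC),
	)
	fmt.Println(out)
	// SELECT * FROM users WHERE name = 'jinzhu' AND age > 18 AND created_at < '2024-01-02 03:04:05'
}
```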

111
vendor/gorm.io/gorm/migrator.go generated vendored Normal file

@@ -0,0 +1,111 @@
package gorm
import (
"reflect"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
// Migrator returns migrator
func (db *DB) Migrator() Migrator {
tx := db.getInstance()
// apply scopes to migrator
for len(tx.Statement.scopes) > 0 {
tx = tx.executeScopes()
}
return tx.Dialector.Migrator(tx.Session(&Session{}))
}
// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
return db.Migrator().AutoMigrate(dst...)
}
// ViewOption view option
type ViewOption struct {
Replace bool // If true, exec `CREATE OR REPLACE`. If false, exec `CREATE`
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
Query *DB // required subquery.
}
// ColumnType column type interface
type ColumnType interface {
Name() string
DatabaseTypeName() string // varchar
ColumnType() (columnType string, ok bool) // varchar(64)
PrimaryKey() (isPrimaryKey bool, ok bool)
AutoIncrement() (isAutoIncrement bool, ok bool)
Length() (length int64, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Nullable() (nullable bool, ok bool)
Unique() (unique bool, ok bool)
ScanType() reflect.Type
Comment() (value string, ok bool)
DefaultValue() (value string, ok bool)
}
type Index interface {
Table() string
Name() string
Columns() []string
PrimaryKey() (isPrimaryKey bool, ok bool)
Unique() (unique bool, ok bool)
Option() string
}
// TableType table type interface
type TableType interface {
Schema() string
Name() string
Type() string
Comment() (comment string, ok bool)
}
// Migrator migrator interface
type Migrator interface {
// AutoMigrate
AutoMigrate(dst ...interface{}) error
// Database
CurrentDatabase() string
FullDataTypeOf(*schema.Field) clause.Expr
GetTypeAliases(databaseTypeName string) []string
// Tables
CreateTable(dst ...interface{}) error
DropTable(dst ...interface{}) error
HasTable(dst interface{}) bool
RenameTable(oldName, newName interface{}) error
GetTables() (tableList []string, err error)
TableType(dst interface{}) (TableType, error)
// Columns
AddColumn(dst interface{}, field string) error
DropColumn(dst interface{}, field string) error
AlterColumn(dst interface{}, field string) error
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
HasColumn(dst interface{}, field string) bool
RenameColumn(dst interface{}, oldName, field string) error
ColumnTypes(dst interface{}) ([]ColumnType, error)
// Views
CreateView(name string, option ViewOption) error
DropView(name string) error
// Constraints
CreateConstraint(dst interface{}, name string) error
DropConstraint(dst interface{}, name string) error
HasConstraint(dst interface{}, name string) bool
// Indexes
CreateIndex(dst interface{}, name string) error
DropIndex(dst interface{}, name string) error
HasIndex(dst interface{}, name string) bool
RenameIndex(dst interface{}, oldName, newName string) error
GetIndexes(dst interface{}) ([]Index, error)
}
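
A sketch of typical calls against the `Migrator` interface above; the `User` model and the legacy `nickname` column are assumptions for illustration:

```go
package example

import (
	"gorm.io/gorm"
)

// User is a hypothetical model being migrated.
type User struct {
	ID   uint
	Name string `gorm:"size:64"`
}

func migrate(db *gorm.DB) error {
	// Create or update tables, columns, indexes and constraints for the models.
	if err := db.AutoMigrate(&User{}); err != nil {
		return err
	}

	m := db.Migrator()
	if m.HasColumn(&User{}, "nickname") {
		// Rename a legacy column if it is still present.
		if err := m.RenameColumn(&User{}, "nickname", "name"); err != nil {
			return err
		}
	}

	// Inspect what the database currently reports for the table's columns.
	if _, err := m.ColumnTypes(&User{}); err != nil {
		return err
	}
	return nil
}
```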

107
vendor/gorm.io/gorm/migrator/column_type.go generated vendored Normal file

@@ -0,0 +1,107 @@
package migrator
import (
"database/sql"
"reflect"
)
// ColumnType column type implements ColumnType interface
type ColumnType struct {
SQLColumnType *sql.ColumnType
NameValue sql.NullString
DataTypeValue sql.NullString
ColumnTypeValue sql.NullString
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
AutoIncrementValue sql.NullBool
LengthValue sql.NullInt64
DecimalSizeValue sql.NullInt64
ScaleValue sql.NullInt64
NullableValue sql.NullBool
ScanTypeValue reflect.Type
CommentValue sql.NullString
DefaultValueValue sql.NullString
}
// Name returns the name or alias of the column.
func (ct ColumnType) Name() string {
if ct.NameValue.Valid {
return ct.NameValue.String
}
return ct.SQLColumnType.Name()
}
// DatabaseTypeName returns the database system name of the column type. If an empty
// string is returned, then the driver type name is not supported.
// Consult your driver documentation for a list of driver data types. Length specifiers
// are not included.
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
// "INT", and "BIGINT".
func (ct ColumnType) DatabaseTypeName() string {
if ct.DataTypeValue.Valid {
return ct.DataTypeValue.String
}
return ct.SQLColumnType.DatabaseTypeName()
}
// ColumnType returns the database type of the column, e.g. `varchar(16)`
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
}
// PrimaryKey returns whether the column is a primary key.
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
}
// AutoIncrement returns whether the column is auto-incrementing.
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
}
// Length returns the column type length for variable length column types
func (ct ColumnType) Length() (length int64, ok bool) {
if ct.LengthValue.Valid {
return ct.LengthValue.Int64, true
}
return ct.SQLColumnType.Length()
}
// DecimalSize returns the scale and precision of a decimal type.
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
if ct.DecimalSizeValue.Valid {
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
}
return ct.SQLColumnType.DecimalSize()
}
// Nullable reports whether the column may be null.
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
if ct.NullableValue.Valid {
return ct.NullableValue.Bool, true
}
return ct.SQLColumnType.Nullable()
}
// Unique reports whether the column is unique.
func (ct ColumnType) Unique() (unique bool, ok bool) {
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
}
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
func (ct ColumnType) ScanType() reflect.Type {
if ct.ScanTypeValue != nil {
return ct.ScanTypeValue
}
return ct.SQLColumnType.ScanType()
}
// Comment returns the comment of current column.
func (ct ColumnType) Comment() (value string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}
// DefaultValue returns the default value of current column.
func (ct ColumnType) DefaultValue() (value string, ok bool) {
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
}

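The `migrator.ColumnType` above prefers its explicit `*Value` fields and only falls back to the wrapped `*sql.ColumnType` when a value is not set. A minimal sketch of that behaviour (illustrative only, not part of the vendored file; the column attributes are made up):

```go
package main

import (
	"database/sql"
	"fmt"

	"gorm.io/gorm/migrator"
)

func main() {
	// A dialect would normally fill these override values from information_schema
	// and keep the raw *sql.ColumnType around as a fallback.
	ct := migrator.ColumnType{
		NameValue:     sql.NullString{String: "email", Valid: true},
		DataTypeValue: sql.NullString{String: "varchar", Valid: true},
		LengthValue:   sql.NullInt64{Int64: 255, Valid: true},
		NullableValue: sql.NullBool{Bool: true, Valid: true},
	}

	fmt.Println(ct.Name())             // email
	fmt.Println(ct.DatabaseTypeName()) // varchar
	if length, ok := ct.Length(); ok {
		fmt.Println(length) // 255
	}
}
```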
43
vendor/gorm.io/gorm/migrator/index.go generated vendored Normal file

@@ -0,0 +1,43 @@
package migrator
import "database/sql"
// Index implements gorm.Index interface
type Index struct {
TableName string
NameValue string
ColumnList []string
PrimaryKeyValue sql.NullBool
UniqueValue sql.NullBool
OptionValue string
}
// Table returns the table name of the index.
func (idx Index) Table() string {
return idx.TableName
}
// Name returns the name of the index.
func (idx Index) Name() string {
return idx.NameValue
}
// Columns returns the columns of the index.
func (idx Index) Columns() []string {
return idx.ColumnList
}
// PrimaryKey reports whether the index is a primary key.
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
}
// Unique reports whether the index is unique.
func (idx Index) Unique() (unique bool, ok bool) {
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
}
// Option returns the optional attribute of the index.
func (idx Index) Option() string {
return idx.OptionValue
}

989
vendor/gorm.io/gorm/migrator/migrator.go generated vendored Normal file

@@ -0,0 +1,989 @@
package migrator
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
// with a possible trailing non-digit character (\D?).
// For example, values that can pass this regular expression are:
// - "123"
// - "abc456"
// - "%$#@789"
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
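// Editor's note (illustrative, not upstream code): given a declared type such as
// "varchar(10)", regFullDataType captures the first run of digits:
//
//	matches := regFullDataType.FindAllStringSubmatch("varchar(10)", -1)
//	// len(matches) == 1 && matches[0][1] == "10"
//
// MigrateColumn below compares that captured size against the length reported by
// the database to decide whether a column needs to be altered.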
// TODO:? Create const vars for raw sql queries ?
var _ gorm.Migrator = (*Migrator)(nil)
// Migrator is the default migrator implementation
type Migrator struct {
Config
}
// Config is the migrator configuration
type Config struct {
CreateIndexAfterCreateTable bool
DB *gorm.DB
gorm.Dialector
}
type printSQLLogger struct {
logger.Interface
}
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
sql, _ := fc()
fmt.Println(sql + ";")
l.Interface.Trace(ctx, begin, fc, err)
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDBDataType(*gorm.DB, *schema.Field) string
}
// RunWithValue runs the given migration callback with a statement parsed from value
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
stmt := &gorm.Statement{DB: m.DB}
if m.DB.Statement != nil {
stmt.Table = m.DB.Statement.Table
stmt.TableExpr = m.DB.Statement.TableExpr
}
if table, ok := value.(string); ok {
stmt.Table = table
} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
return err
}
return fc(stmt)
}
// DataTypeOf returns the field's db data type
func (m Migrator) DataTypeOf(field *schema.Field) string {
fieldValue := reflect.New(field.IndirectFieldType)
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" {
return dataType
}
}
return m.Dialector.DataTypeOf(field)
}
// FullDataTypeOf returns the field's full db data type
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field)
if field.NotNull {
expr.SQL += " NOT NULL"
}
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
} else if field.DefaultValue != "(-)" {
expr.SQL += " DEFAULT " + field.DefaultValue
}
}
return
}
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
queryTx = m.DB.Session(&gorm.Session{})
execTx = queryTx
if m.DB.DryRun {
queryTx.DryRun = false
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
}
return queryTx, execTx
}
// AutoMigrate automatically migrates the given values
func (m Migrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
queryTx, execTx := m.GetQueryAndExecTx()
if !queryTx.Migrator().HasTable(value) {
if err := execTx.Migrator().CreateTable(value); err != nil {
return err
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
if err != nil {
return err
}
var (
parseIndexes = stmt.Schema.ParseIndexes()
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
)
for _, dbName := range stmt.Schema.DBNames {
var foundColumn gorm.ColumnType
for _, columnType := range columnTypes {
if columnType.Name() == dbName {
foundColumn = columnType
break
}
}
if foundColumn == nil {
// not found, add column
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
return err
}
} else {
// found, smartly migrate
field := stmt.Schema.FieldsByDBName[dbName]
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
return err
}
}
}
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil &&
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err
}
}
}
}
for _, chk := range parseCheckConstraints {
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err
}
}
}
for _, idx := range parseIndexes {
if !queryTx.Migrator().HasIndex(value, idx.Name) {
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
return nil
}); err != nil {
return err
}
}
}
return nil
}
// GetTables returns the tables of the current database
func (m Migrator) GetTables() (tableList []string, err error) {
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
Scan(&tableList).Error
return
}
// CreateTable creates tables in the database for the given values
func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)}
hasPrimaryKeyInDataType bool
)
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
if !field.IgnoreMigration {
createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += ","
}
}
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
}
values = append(values, primaryKeys)
}
for _, idx := range stmt.Schema.ParseIndexes() {
if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) {
if err == nil {
err = tx.Migrator().CreateIndex(value, name)
}
}(value, idx.Name)
} else {
if idx.Class != "" {
createTableSQL += idx.Class + " "
}
createTableSQL += "INDEX ? ?"
if idx.Comment != "" {
createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
if idx.Option != "" {
createTableSQL += " " + idx.Option
}
createTableSQL += ","
values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
}
}
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range stmt.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := constraint.Build()
createTableSQL += sql + ","
values = append(values, vars...)
}
}
}
}
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
}
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
createTableSQL += ")"
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
createTableSQL += fmt.Sprint(tableOption)
}
err = tx.Exec(createTableSQL, values...).Error
return err
}); err != nil {
return err
}
}
return nil
}
// DropTable drops the tables for the given values
func (m Migrator) DropTable(values ...interface{}) error {
values = m.ReorderModels(values, false)
for i := len(values) - 1; i >= 0; i-- {
tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error
}); err != nil {
return err
}
}
return nil
}
// HasTable reports whether the table exists for value; value can be a struct or a table-name string
func (m Migrator) HasTable(value interface{}) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
})
return count > 0
}
// RenameTable renames the table from oldName to newName
func (m Migrator) RenameTable(oldName, newName interface{}) error {
var oldTable, newTable interface{}
if v, ok := oldName.(string); ok {
oldTable = clause.Table{Name: v}
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(oldName); err == nil {
oldTable = m.CurrentTable(stmt)
} else {
return err
}
}
if v, ok := newName.(string); ok {
newTable = clause.Table{Name: v}
} else {
stmt := &gorm.Statement{DB: m.DB}
if err := stmt.Parse(newName); err == nil {
newTable = m.CurrentTable(stmt)
} else {
return err
}
}
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
}
// AddColumn creates the `name` column for value
func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// look up the schema field by name to avoid adding a column for an unknown field
f := stmt.Schema.LookUpField(name)
if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name)
}
if !f.IgnoreMigration {
return m.DB.Exec(
"ALTER TABLE ? ADD ? ?",
m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
).Error
}
return nil
})
}
// DropColumn drops value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
return m.DB.Exec(
"ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name},
).Error
})
}
// AlterColumn alters the type of value's `field` column based on the schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil {
fieldType := m.FullDataTypeOf(field)
return m.DB.Exec(
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fieldType,
).Error
}
return fmt.Errorf("failed to look up field with name: %s", field)
})
}
// HasColumn reports whether value has the column `field`
func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field
if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
currentDatabase, stmt.Table, name,
).Row().Scan(&count)
})
return count > 0
}
// RenameColumn renames value's column from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil {
oldName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName
}
return m.DB.Exec(
"ALTER TABLE ? RENAME COLUMN ? TO ?",
m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
})
}
// MigrateColumn migrates the column to match the schema definition
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
if field.IgnoreMigration {
return nil
}
// found, smart migrate
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
realDataType := strings.ToLower(columnType.DatabaseTypeName())
var (
alterColumn bool
isSameType = fullDataType == realDataType
)
if !field.PrimaryKey {
// check type
if !strings.HasPrefix(fullDataType, realDataType) {
// check type aliases
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
for _, alias := range aliases {
if strings.HasPrefix(fullDataType, alias) {
isSameType = true
break
}
}
if !isSameType {
alterColumn = true
}
}
}
if !isSameType {
// check size
if length, ok := columnType.Length(); length != int64(field.Size) {
if length > 0 && field.Size > 0 {
alterColumn = true
} else {
// the data type declares a size that differs from the reported length
// this code runs frequently inside the migration loop, so the precompiled regFullDataType is used here as an optimization
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if !field.PrimaryKey &&
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
alterColumn = true
}
}
}
// check precision
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true
}
}
}
// check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
// not primary key & database is nullable
if !field.PrimaryKey && nullable {
alterColumn = true
}
}
// check default value
if !field.PrimaryKey {
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
dv, dvNotNull := columnType.DefaultValue()
if dvNotNull && !currentDefaultNotNull {
// default value -> null
alterColumn = true
} else if !dvNotNull && currentDefaultNotNull {
// null -> default value
alterColumn = true
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
// default value not equal
// not both null
if currentDefaultNotNull || dvNotNull {
alterColumn = true
}
}
}
// check comment
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
// not primary key
if !field.PrimaryKey {
alterColumn = true
}
}
if alterColumn {
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
return err
}
}
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
return err
}
return nil
}
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
unique, ok := columnType.Unique()
if !ok || field.PrimaryKey {
return nil // skip primary key
}
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// We're currently only receiving boolean values on `Unique` tag,
// so the UniqueConstraint name is fixed
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
if unique && !field.Unique {
return m.DB.Migrator().DropConstraint(value, constraint)
}
if !unique && field.Unique {
return m.DB.Migrator().CreateConstraint(value, constraint)
}
return nil
})
}
// ColumnTypes returns the value's column types ([]gorm.ColumnType) and any execution error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err != nil {
return err
}
defer func() {
err = rows.Close()
}()
var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
}
return
})
return columnTypes, execErr
}
// CreateView creates a view from the Query in gorm.ViewOption.
// Query in gorm.ViewOption is a [subquery]
//
// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
// q := DB.Model(&User{}).Where("age > ?", 20)
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
//
// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
// q := DB.Model(&User{})
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
//
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
if option.Query == nil {
return gorm.ErrSubQueryRequired
}
sql := new(strings.Builder)
sql.WriteString("CREATE ")
if option.Replace {
sql.WriteString("OR REPLACE ")
}
sql.WriteString("VIEW ")
m.QuoteTo(sql, name)
sql.WriteString(" AS ")
m.DB.Statement.AddVar(sql, option.Query)
if option.CheckOption != "" {
sql.WriteString(" ")
sql.WriteString(option.CheckOption)
}
return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
}
// DropView drops the view
func (m Migrator) DropView(name string) error {
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
}
// GuessConstraintAndTable guesses the statement's constraint and its table based on name
//
// Deprecated: use GuessConstraintInterfaceAndTable instead.
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
switch c := constraint.(type) {
case *schema.Constraint:
return c, nil, table
case *schema.CheckConstraint:
return nil, c, table
default:
return nil, nil, table
}
}
// GuessConstraintInterfaceAndTable guesses the statement's constraint and its table based on name
// nolint:cyclop
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
if stmt.Schema == nil {
return nil, stmt.Table
}
checkConstraints := stmt.Schema.ParseCheckConstraints()
if chk, ok := checkConstraints[name]; ok {
return &chk, stmt.Table
}
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
if uni, ok := uniqueConstraints[name]; ok {
return &uni, stmt.Table
}
getTable := func(rel *schema.Relationship) string {
switch rel.Type {
case schema.HasOne, schema.HasMany:
return rel.FieldSchema.Table
case schema.Many2Many:
return rel.JoinTable.Table
}
return stmt.Table
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
return constraint, getTable(rel)
}
}
if field := stmt.Schema.LookUpField(name); field != nil {
for k := range checkConstraints {
if checkConstraints[k].Field == field {
v := checkConstraints[k]
return &v, stmt.Table
}
}
for k := range uniqueConstraints {
if uniqueConstraints[k].Field == field {
v := uniqueConstraints[k]
return &v, stmt.Table
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
return constraint, getTable(rel)
}
}
}
return nil, stmt.Schema.Table
}
// CreateConstraint creates the constraint `name`
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr
}
sql, values := constraint.Build()
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
}
return nil
})
}
// DropConstraint drops the constraint `name`
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
}
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
})
}
// HasConstraint reports whether the constraint `name` exists
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
}
return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
currentDatabase, table, name,
).Row().Scan(&count)
})
return count > 0
}
// BuildIndexOptions builds index options
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
} else if opt.Length > 0 {
str += fmt.Sprintf("(%d)", opt.Length)
}
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
// BuildIndexOptionsInterface is the interface for building index options
type BuildIndexOptionsInterface interface {
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
}
// CreateIndex creates the index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
if idx.Comment != "" {
createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
if idx.Option != "" {
createIndexSQL += " " + idx.Option
}
return m.DB.Exec(createIndexSQL, values...).Error
}
return fmt.Errorf("failed to create index with name %s", name)
})
}
// DropIndex drops the index `name`
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
})
}
// HasIndex reports whether the index `name` exists
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
return m.DB.Raw(
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
currentDatabase, stmt.Table, name,
).Row().Scan(&count)
})
return count > 0
}
// RenameIndex renames the index from oldName to newName
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Exec(
"ALTER TABLE ? RENAME INDEX ? TO ?",
m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
).Error
})
}
// CurrentDatabase returns current database name
func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
return
}
// ReorderModels reorders models according to constraint dependencies
func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
type Dependency struct {
*gorm.Statement
Depends []*schema.Schema
}
var (
modelNames, orderedModelNames []string
orderedModelNamesMap = map[string]bool{}
parsedSchemas = map[*schema.Schema]bool{}
valuesMap = map[string]Dependency{}
insertIntoOrderedList func(name string)
parseDependence func(value interface{}, addToList bool)
)
parseDependence = func(value interface{}, addToList bool) {
dep := Dependency{
Statement: &gorm.Statement{DB: m.DB, Dest: value},
}
beDependedOn := map[*schema.Schema]bool{}
// support for special table name
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
}
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
return
}
parsedSchemas[dep.Statement.Schema] = true
if !m.DB.IgnoreRelationshipsWhenMigrating {
for _, rel := range dep.Schema.Relationships.Relations {
if rel.Field.IgnoreMigration {
continue
}
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
dep.Depends = append(dep.Depends, c.ReferenceSchema)
}
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
beDependedOn[rel.FieldSchema] = true
}
if rel.JoinTable != nil {
// append join value
defer func(rel *schema.Relationship, joinValue interface{}) {
if !beDependedOn[rel.FieldSchema] {
dep.Depends = append(dep.Depends, rel.FieldSchema)
} else {
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
parseDependence(fieldValue, autoAdd)
}
parseDependence(joinValue, autoAdd)
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
}
}
}
valuesMap[dep.Schema.Table] = dep
if addToList {
modelNames = append(modelNames, dep.Schema.Table)
}
}
insertIntoOrderedList = func(name string) {
if _, ok := orderedModelNamesMap[name]; ok {
return // avoid loop
}
orderedModelNamesMap[name] = true
if autoAdd {
dep := valuesMap[name]
for _, d := range dep.Depends {
if _, ok := valuesMap[d.Table]; ok {
insertIntoOrderedList(d.Table)
} else {
parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
insertIntoOrderedList(d.Table)
}
}
}
orderedModelNames = append(orderedModelNames, name)
}
for _, value := range values {
if v, ok := value.(string); ok {
results = append(results, v)
} else {
parseDependence(value, true)
}
}
for _, name := range modelNames {
insertIntoOrderedList(name)
}
for _, name := range orderedModelNames {
results = append(results, valuesMap[name].Statement.Dest)
}
return
}
// CurrentTable returns current statement's table expression
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
if stmt.TableExpr != nil {
return *stmt.TableExpr
}
return clause.Table{Name: stmt.Table}
}
// GetIndexes returns the value's indexes ([]gorm.Index) and any execution error
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
return nil, errors.New("not support")
}
// GetTypeAliases returns database type aliases for databaseTypeName
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
return nil
}
// TableType returns the value's table type (gorm.TableType) and any execution error
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
return nil, errors.New("not support")
}

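For context, this is roughly how an application exercises the migrator above through GORM's public API; a hedged sketch only, where the `User` model, its tags and the DSN are placeholders rather than part of this commit:

```go
package main

import (
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

type User struct {
	gorm.Model
	Name  string `gorm:"size:64"`
	Email string `gorm:"uniqueIndex"`
}

func main() {
	dsn := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable" // placeholder
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		panic(err)
	}

	// AutoMigrate creates missing tables, then adds missing columns, constraints
	// and indexes to existing ones, as implemented by Migrator.AutoMigrate above.
	if err := db.AutoMigrate(&User{}); err != nil {
		panic(err)
	}
}
```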
33
vendor/gorm.io/gorm/migrator/table_type.go generated vendored Normal file

@@ -0,0 +1,33 @@
package migrator
import (
"database/sql"
)
// TableType implements the gorm.TableType interface
type TableType struct {
SchemaValue string
NameValue string
TypeValue string
CommentValue sql.NullString
}
// Schema returns the schema of the table.
func (ct TableType) Schema() string {
return ct.SchemaValue
}
// Name returns the name of the table.
func (ct TableType) Name() string {
return ct.NameValue
}
// Type returns the type of the table.
func (ct TableType) Type() string {
return ct.TypeValue
}
// Comment returns the comment of current table.
func (ct TableType) Comment() (comment string, ok bool) {
return ct.CommentValue.String, ct.CommentValue.Valid
}

16
vendor/gorm.io/gorm/model.go generated vendored Normal file

@@ -0,0 +1,16 @@
package gorm
import "time"
// Model is a basic Go struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt.
// It may be embedded into your model, or you may build your own model without it
//
// type User struct {
// gorm.Model
// }
type Model struct {
ID uint `gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt DeletedAt `gorm:"index"`
}

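In practice, embedding `gorm.Model` also enables soft deletes via the indexed `DeletedAt` field. A brief sketch under that assumption; the `Note` model, DSN and queries are hypothetical examples, not part of the vendored code:

```go
package main

import (
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

type Note struct {
	gorm.Model // adds ID, CreatedAt, UpdatedAt and an indexed DeletedAt
	Body string
}

func main() {
	dsn := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable" // placeholder
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	db.AutoMigrate(&Note{})

	db.Create(&Note{Body: "hello"})
	db.Delete(&Note{}, 1) // soft delete: sets deleted_at instead of removing the row

	var notes []Note
	db.Unscoped().Find(&notes) // Unscoped also returns soft-deleted rows
}
```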
242
vendor/gorm.io/gorm/prepare_stmt.go generated vendored Normal file

@@ -0,0 +1,242 @@
package gorm
import (
"context"
"database/sql"
"reflect"
"sync"
)
type Stmt struct {
*sql.Stmt
Transaction bool
prepared chan struct{}
prepareErr error
}
type PreparedStmtDB struct {
Stmts map[string]*Stmt
PreparedSQL []string
Mux *sync.RWMutex
ConnPool
}
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
return &PreparedStmtDB{
ConnPool: connPool,
Stmts: make(map[string]*Stmt),
Mux: &sync.RWMutex{},
PreparedSQL: make([]string, 0, 100),
}
}
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}
return nil, ErrInvalidDB
}
func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
defer db.Mux.Unlock()
for _, query := range db.PreparedSQL {
if stmt, ok := db.Stmts[query]; ok {
delete(db.Stmts, query)
go stmt.Close()
}
}
}
func (sdb *PreparedStmtDB) Reset() {
sdb.Mux.Lock()
defer sdb.Mux.Unlock()
for _, stmt := range sdb.Stmts {
go stmt.Close()
}
sdb.PreparedSQL = make([]string, 0, 100)
sdb.Stmts = make(map[string]*Stmt)
}
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
db.Mux.RLock()
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.RUnlock()
// wait until another goroutine finishes preparing the statement
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
db.Mux.RUnlock()
db.Mux.Lock()
// double check
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
db.Mux.Unlock()
// wait until another goroutine finishes preparing the statement
<-stmt.prepared
if stmt.prepareErr != nil {
return Stmt{}, stmt.prepareErr
}
return *stmt, nil
}
// cache preparing stmt first
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
db.Stmts[query] = &cacheStmt
db.Mux.Unlock()
// prepare completed
defer close(cacheStmt.prepared)
// Reason why we cannot hold the lock around conn.PrepareContext:
// suppose maxOpen is 1, g1 is creating a record and g2 is querying a record.
// 1. g1 begins a tx and is requeued waiting on a system call; `db.ConnPool` now has db.numOpen == 1.
// 2. g2 takes the lock and calls `conn.PrepareContext(ctx, query)`; now db.numOpen == db.maxOpen, so it waits for a connection to be released.
// 3. g1's tx tries to exec the insert, but it would have to wait for the lock around `conn.PrepareContext(ctx, query)` to finish the tx and release its connection.
stmt, err := conn.PrepareContext(ctx, query)
if err != nil {
cacheStmt.prepareErr = err
db.Mux.Lock()
delete(db.Stmts, query)
db.Mux.Unlock()
return Stmt{}, err
}
db.Mux.Lock()
cacheStmt.Stmt = stmt
db.PreparedSQL = append(db.PreparedSQL, query)
db.Mux.Unlock()
return cacheStmt, nil
}
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
if beginner, ok := db.ConnPool.(TxBeginner); ok {
tx, err := beginner.BeginTx(ctx, opt)
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
}
beginner, ok := db.ConnPool.(ConnPoolBeginner)
if !ok {
return nil, ErrInvalidTransaction
}
connPool, err := beginner.BeginTx(ctx, opt)
if err != nil {
return nil, err
}
if tx, ok := connPool.(Tx); ok {
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
}
return nil, ErrInvalidTransaction
}
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
result, err = stmt.ExecContext(ctx, args...)
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
}
}
return result, err
}
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
rows, err = stmt.QueryContext(ctx, args...)
if err != nil {
db.Mux.Lock()
defer db.Mux.Unlock()
go stmt.Close()
delete(db.Stmts, query)
}
}
return rows, err
}
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
if err == nil {
return stmt.QueryRowContext(ctx, args...)
}
return &sql.Row{}
}
type PreparedStmtTX struct {
Tx
PreparedStmtDB *PreparedStmtDB
}
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
return db.PreparedStmtDB.GetDBConn()
}
func (tx *PreparedStmtTX) Commit() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Commit()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) Rollback() error {
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
return tx.Tx.Rollback()
}
return ErrInvalidTransaction
}
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
}
}
return result, err
}
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
if err != nil {
tx.PreparedStmtDB.Mux.Lock()
defer tx.PreparedStmtDB.Mux.Unlock()
go stmt.Close()
delete(tx.PreparedStmtDB.Stmts, query)
}
}
return rows, err
}
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
if err == nil {
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
}
return &sql.Row{}
}

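The `PreparedStmtDB` above is what GORM wraps the connection pool with when `PrepareStmt` is enabled, so each distinct SQL string is prepared once and its `*sql.Stmt` is cached. A minimal sketch of turning it on (the DSN is a placeholder):

```go
package main

import (
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

func main() {
	dsn := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable" // placeholder
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{PrepareStmt: true})
	if err != nil {
		panic(err)
	}

	var n int64
	db.Raw("SELECT 1").Scan(&n) // first call prepares and caches "SELECT 1"
	db.Raw("SELECT 1").Scan(&n) // second call reuses the cached statement
}
```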
342
vendor/gorm.io/gorm/scan.go generated vendored Normal file

@@ -0,0 +1,342 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"reflect"
"time"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// prepareValues prepares the values slice
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
if db.Statement.Schema != nil {
for idx, name := range columns {
if field := db.Statement.Schema.LookUpField(name); field != nil {
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
continue
}
values[idx] = new(interface{})
}
} else if len(columnTypes) > 0 {
for idx, columnType := range columnTypes {
if columnType.ScanType() != nil {
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
} else {
values[idx] = new(interface{})
}
}
} else {
for idx := range columns {
values[idx] = new(interface{})
}
}
}
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
for idx, column := range columns {
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
mapValue[column] = reflectValue.Interface()
if valuer, ok := mapValue[column].(driver.Valuer); ok {
mapValue[column], _ = valuer.Value()
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
mapValue[column] = string(b)
}
} else {
mapValue[column] = nil
}
}
}
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
for idx, field := range fields {
if field != nil {
values[idx] = field.NewValuePool.Get()
} else if len(fields) == 1 {
if reflectValue.CanAddr() {
values[idx] = reflectValue.Addr().Interface()
} else {
values[idx] = reflectValue.Interface()
}
}
}
db.RowsAffected++
db.AddError(rows.Scan(values...))
joinedNestedSchemaMap := make(map[string]interface{})
for idx, field := range fields {
if field == nil {
continue
}
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
} else { // joinFields[idx] has at least two entries when using joins
var isNilPtrValue bool
var relValue reflect.Value
// nestedJoinSchemas excludes the trailing raw db name
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
// current reflect value
currentReflectValue := reflectValue
fullRels := make([]string, 0, len(nestedJoinSchemas))
for _, joinSchema := range nestedJoinSchemas {
fullRels = append(fullRels, joinSchema.Name)
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
if relValue.Kind() == reflect.Ptr {
fullRelsName := utils.JoinNestedRelationNames(fullRels)
// same nested structure
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
isNilPtrValue = true
break
}
relValue.Set(reflect.New(relValue.Type().Elem()))
joinedNestedSchemaMap[fullRelsName] = nil
}
}
currentReflectValue = relValue
}
if !isNilPtrValue { // ignore if value is nil
f := joinFields[idx][len(joinFields[idx])-1]
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
}
}
// release data to pool
field.NewValuePool.Put(values[idx])
}
}
// ScanMode is the data scanning mode
type ScanMode uint8
// scan modes
const (
ScanInitialized ScanMode = 1 << 0 // 1
ScanUpdate ScanMode = 1 << 1 // 2
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)
// Scan scans rows into the db statement's destination
func Scan(rows Rows, db *DB, mode ScanMode) {
var (
columns, _ = rows.Columns()
values = make([]interface{}, len(columns))
initialized = mode&ScanInitialized != 0
update = mode&ScanUpdate != 0
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
)
db.RowsAffected = 0
switch dest := db.Statement.Dest.(type) {
case map[string]interface{}, *map[string]interface{}:
if initialized || rows.Next() {
columnTypes, _ := rows.ColumnTypes()
prepareValues(values, db, columnTypes, columns)
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue, ok := dest.(map[string]interface{})
if !ok {
if v, ok := dest.(*map[string]interface{}); ok {
if *v == nil {
*v = map[string]interface{}{}
}
mapValue = *v
}
}
scanIntoMap(mapValue, values, columns)
}
case *[]map[string]interface{}:
columnTypes, _ := rows.ColumnTypes()
for initialized || rows.Next() {
prepareValues(values, db, columnTypes, columns)
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(values...))
mapValue := map[string]interface{}{}
scanIntoMap(mapValue, values, columns)
*dest = append(*dest, mapValue)
}
case *int, *int8, *int16, *int32, *int64,
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
*float32, *float64,
*bool, *string, *time.Time,
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
*sql.NullBool, *sql.NullString, *sql.NullTime:
for initialized || rows.Next() {
initialized = false
db.RowsAffected++
db.AddError(rows.Scan(dest))
}
default:
var (
fields = make([]*schema.Field, len(columns))
joinFields [][]*schema.Field
sch = db.Statement.Schema
reflectValue = db.Statement.ReflectValue
)
if reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
reflectValueType := reflectValue.Type()
switch reflectValueType.Kind() {
case reflect.Array, reflect.Slice:
reflectValueType = reflectValueType.Elem()
}
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}
if sch != nil {
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
}
if len(columns) == 1 {
// Is Pluck
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
reflectValueType.Kind() != reflect.Struct || // is not struct
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
sch = nil
}
}
// Not Pluck
if sch != nil {
matchedFieldCount := make(map[string]int, len(columns))
for idx, column := range columns {
if field := sch.LookUpField(column); field != nil && field.Readable {
fields[idx] = field
if count, ok := matchedFieldCount[column]; ok {
// handle duplicate fields
for _, selectField := range sch.Fields {
if selectField.DBName == column && selectField.Readable {
if count == 0 {
matchedFieldCount[column]++
fields[idx] = selectField
break
}
count--
}
}
} else {
matchedFieldCount[column] = 1
}
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
subNameCount := len(names)
// nested relation fields
relFields := make([]*schema.Field, 0, subNameCount-1)
relFields = append(relFields, rel.Field)
for _, name := range names[1 : subNameCount-1] {
rel = rel.FieldSchema.Relationships.Relations[name]
relFields = append(relFields, rel.Field)
}
// the last name is the raw db name
dbName := names[subNameCount-1]
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
fields[idx] = field
if len(joinFields) == 0 {
joinFields = make([][]*schema.Field, len(columns))
}
relFields = append(relFields, field)
joinFields[idx] = relFields
continue
}
}
values[idx] = &sql.RawBytes{}
} else {
values[idx] = &sql.RawBytes{}
}
}
}
}
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
var (
elem reflect.Value
isArrayKind = reflectValue.Kind() == reflect.Array
)
if !update || reflectValue.Len() == 0 {
update = false
// if the slice capacity was initialized externally, the externally initialized slice is reused directly here
if reflectValue.Cap() == 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
} else if !isArrayKind {
reflectValue.SetLen(0)
db.Statement.ReflectValue.Set(reflectValue)
}
}
for initialized || rows.Next() {
BEGIN:
initialized = false
if update {
if int(db.RowsAffected) >= reflectValue.Len() {
return
}
elem = reflectValue.Index(int(db.RowsAffected))
if onConflictDonothing {
for _, field := range fields {
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
db.RowsAffected++
goto BEGIN
}
}
}
} else {
elem = reflect.New(reflectValueType)
}
db.scanIntoStruct(rows, elem, values, fields, joinFields)
if !update {
if !isPtr {
elem = elem.Elem()
}
if isArrayKind {
if reflectValue.Len() >= int(db.RowsAffected) {
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
}
} else {
reflectValue = reflect.Append(reflectValue, elem)
}
}
}
if !update {
db.Statement.ReflectValue.Set(reflectValue)
}
case reflect.Struct, reflect.Ptr:
if initialized || rows.Next() {
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
}
default:
db.AddError(rows.Scan(dest))
}
}
if err := rows.Err(); err != nil && err != db.Error {
db.AddError(err)
}
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
db.AddError(ErrRecordNotFound)
}
}

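The destinations handled by `Scan` above map directly onto everyday query calls. A rough sketch of two of those branches, scanning into a map slice and into a plain scalar (the DSN and queries are placeholders):

```go
package main

import (
	"fmt"

	"gorm.io/driver/postgres"
	"gorm.io/gorm"
)

func main() {
	dsn := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable" // placeholder
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		panic(err)
	}

	// *[]map[string]interface{} destination: Scan builds one map per row.
	var rows []map[string]interface{}
	db.Raw("SELECT 1 AS id, 'a' AS name UNION ALL SELECT 2, 'b'").Scan(&rows)

	// Scalar destination: Scan writes each row straight into *int64.
	var count int64
	db.Raw("SELECT count(*) FROM information_schema.tables").Scan(&count)

	fmt.Println(rows, count)
}
```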
66
vendor/gorm.io/gorm/schema/constraint.go generated vendored Normal file

@@ -0,0 +1,66 @@
package schema
import (
"regexp"
"strings"
"gorm.io/gorm/clause"
)
// regEnLetterAndMidline matches English letters, hyphens and underscores
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
type CheckConstraint struct {
Name string
Constraint string // length(phone) >= 10
*Field
}
func (chk *CheckConstraint) GetName() string { return chk.Name }
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}
// ParseCheckConstraints parses the schema's check constraints
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
checks := map[string]CheckConstraint{}
for _, field := range schema.FieldsByDBName {
if chk := field.TagSettings["CHECK"]; chk != "" {
names := strings.Split(chk, ",")
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
} else {
if names[0] == "" {
chk = strings.Join(names[1:], ",")
}
name := schema.namer.CheckerName(schema.Table, field.DBName)
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
}
}
}
return checks
}
type UniqueConstraint struct {
Name string
Field *Field
}
func (uni *UniqueConstraint) GetName() string { return uni.Name }
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}
// ParseUniqueConstraints parses the schema's unique constraints
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
uniques := make(map[string]UniqueConstraint)
for _, field := range schema.Fields {
if field.Unique {
name := schema.namer.UniqueName(schema.Table, field.DBName)
uniques[name] = UniqueConstraint{Name: name, Field: field}
}
}
return uniques
}

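The two parsers above can be exercised directly with `schema.Parse`. A small sketch; the `Product` struct and its tags are hypothetical, and the generated constraint names come from the naming strategy:

```go
package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

// Product uses the two tag forms the parsers understand:
// a named CHECK constraint and a plain UNIQUE column.
type Product struct {
	ID    uint   `gorm:"primarykey"`
	Code  string `gorm:"unique"`
	Price int    `gorm:"check:chk_products_price,price >= 0"`
}

func main() {
	s, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}

	for name := range s.ParseCheckConstraints() {
		fmt.Println("check:", name) // chk_products_price
	}
	for name := range s.ParseUniqueConstraints() {
		fmt.Println("unique:", name) // name generated by NamingStrategy.UniqueName
	}
}
```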
996
vendor/gorm.io/gorm/schema/field.go generated vendored Normal file

@@ -0,0 +1,996 @@
package schema
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/jinzhu/now"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)
// special types' reflect type
var (
TimeReflectType = reflect.TypeOf(time.Time{})
TimePtrReflectType = reflect.TypeOf(&time.Time{})
ByteReflectType = reflect.TypeOf(uint8(0))
)
type (
// DataType GORM data type
DataType string
// TimeType GORM time type
TimeType int64
)
// GORM time types
const (
UnixTime TimeType = 1
UnixSecond TimeType = 2
UnixMillisecond TimeType = 3
UnixNanosecond TimeType = 4
)
// GORM field types
const (
Bool DataType = "bool"
Int DataType = "int"
Uint DataType = "uint"
Float DataType = "float"
String DataType = "string"
Time DataType = "time"
Bytes DataType = "bytes"
)
const DefaultAutoIncrementIncrement int64 = 1
// Field is the representation of model schema's field
type Field struct {
Name string
DBName string
BindNames []string
DataType DataType
GORMDataType DataType
PrimaryKey bool
AutoIncrement bool
AutoIncrementIncrement int64
Creatable bool
Updatable bool
Readable bool
AutoCreateTime TimeType
AutoUpdateTime TimeType
HasDefaultValue bool
DefaultValue string
DefaultValueInterface interface{}
NotNull bool
Unique bool
Comment string
Size int
Precision int
Scale int
IgnoreMigration bool
FieldType reflect.Type
IndirectFieldType reflect.Type
StructField reflect.StructField
Tag reflect.StructTag
TagSettings map[string]string
Schema *Schema
EmbeddedSchema *Schema
OwnerSchema *Schema
ReflectValueOf func(context.Context, reflect.Value) reflect.Value
ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
Set func(context.Context, reflect.Value, interface{}) error
Serializer SerializerInterface
NewValuePool FieldNewValuePool
// In some databases (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
// When a column has a (non-Mul) UniqueIndex, Migrator always reports its gorm.ColumnType as Unique.
// That causes unnecessary field migrations.
// Therefore, we record the UniqueIndex on this column (excluding Mul UniqueIndex) for MigrateColumnUnique.
UniqueIndex string
}
func (field *Field) BindName() string {
return strings.Join(field.BindNames, ".")
}
// ParseField parses reflect.StructField to Field
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var (
err error
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
)
field := &Field{
Name: fieldStruct.Name,
DBName: tagSetting["COLUMN"],
BindNames: []string{fieldStruct.Name},
FieldType: fieldStruct.Type,
IndirectFieldType: fieldStruct.Type,
StructField: fieldStruct,
Tag: fieldStruct.Tag,
TagSettings: tagSetting,
Schema: schema,
Creatable: true,
Updatable: true,
Readable: true,
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
Comment: tagSetting["COMMENT"],
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
}
for field.IndirectFieldType.Kind() == reflect.Ptr {
field.IndirectFieldType = field.IndirectFieldType.Elem()
}
fieldValue := reflect.New(field.IndirectFieldType)
// if the field is a valuer, use its value or its first field as the data type
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
if isValuer {
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
fieldValue = reflect.ValueOf(v)
}
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
var getRealFieldValue func(reflect.Value)
getRealFieldValue = func(v reflect.Value) {
var (
rv = reflect.Indirect(v)
rvType = rv.Type()
)
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
for i := 0; i < rvType.NumField(); i++ {
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
if _, ok := field.TagSettings[key]; !ok {
field.TagSettings[key] = value
}
}
}
for i := 0; i < rvType.NumField(); i++ {
newFieldType := rvType.Field(i).Type
for newFieldType.Kind() == reflect.Ptr {
newFieldType = newFieldType.Elem()
}
fieldValue = reflect.New(newFieldType)
if rvType != reflect.Indirect(fieldValue).Type() {
getRealFieldValue(fieldValue)
}
if fieldValue.IsValid() {
return
}
}
}
}
getRealFieldValue(fieldValue)
}
}
if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
field.DataType = String
field.Serializer = v
} else {
serializerName := field.TagSettings["JSON"]
if serializerName == "" {
serializerName = field.TagSettings["SERIALIZER"]
}
if serializerName != "" {
if serializer, ok := GetSerializer(serializerName); ok {
// Set default data type to string for serializer
field.DataType = String
field.Serializer = serializer
} else {
schema.err = fmt.Errorf("invalid serializer type %v", serializerName)
}
}
}
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
}
if v, ok := field.TagSettings["DEFAULT"]; ok {
field.HasDefaultValue = true
field.DefaultValue = v
}
if num, ok := field.TagSettings["SIZE"]; ok {
if field.Size, err = strconv.Atoi(num); err != nil {
field.Size = -1
}
}
if p, ok := field.TagSettings["PRECISION"]; ok {
field.Precision, _ = strconv.Atoi(p)
}
if s, ok := field.TagSettings["SCALE"]; ok {
field.Scale, _ = strconv.Atoi(s)
}
// default value is function or null or blank (primary keys)
field.DefaultValue = strings.TrimSpace(field.DefaultValue)
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Bool:
field.DataType = Bool
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err)
}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.DataType = Int
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err)
}
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.DataType = Uint
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err)
}
}
case reflect.Float32, reflect.Float64:
field.DataType = Float
if field.HasDefaultValue && !skipParseDefaultValue {
if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err)
}
}
case reflect.String:
field.DataType = String
if field.HasDefaultValue && !skipParseDefaultValue {
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
field.DefaultValueInterface = field.DefaultValue
}
case reflect.Struct:
if _, ok := fieldValue.Interface().(*time.Time); ok {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
field.DataType = Time
} else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
field.DataType = Time
}
if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time {
if t, err := now.Parse(field.DefaultValue); err == nil {
field.DefaultValueInterface = t
}
}
case reflect.Array, reflect.Slice:
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
field.DataType = Bytes
}
}
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
field.DataType = DataType(dataTyper.GormDataType())
}
if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoCreateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoCreateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
field.AutoCreateTime = UnixMillisecond
} else {
field.AutoCreateTime = UnixSecond
}
}
if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
if field.DataType == Time {
field.AutoUpdateTime = UnixTime
} else if strings.ToUpper(v) == "NANO" {
field.AutoUpdateTime = UnixNanosecond
} else if strings.ToUpper(v) == "MILLI" {
field.AutoUpdateTime = UnixMillisecond
} else {
field.AutoUpdateTime = UnixSecond
}
}
if field.GORMDataType == "" {
field.GORMDataType = field.DataType
}
if val, ok := field.TagSettings["TYPE"]; ok {
switch DataType(strings.ToLower(val)) {
case Bool, Int, Uint, Float, String, Time, Bytes:
field.DataType = DataType(strings.ToLower(val))
default:
field.DataType = DataType(val)
}
}
if field.Size == 0 {
switch reflect.Indirect(fieldValue).Kind() {
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
field.Size = 64
case reflect.Int8, reflect.Uint8:
field.Size = 8
case reflect.Int16, reflect.Uint16:
field.Size = 16
case reflect.Int32, reflect.Uint32, reflect.Float32:
field.Size = 32
}
}
// setup permission
if val, ok := field.TagSettings["-"]; ok {
val = strings.ToLower(strings.TrimSpace(val))
switch val {
case "-":
field.Creatable = false
field.Updatable = false
field.Readable = false
field.DataType = ""
case "all":
field.Creatable = false
field.Updatable = false
field.Readable = false
field.DataType = ""
field.IgnoreMigration = true
case "migration":
field.IgnoreMigration = true
}
}
if v, ok := field.TagSettings["->"]; ok {
field.Creatable = false
field.Updatable = false
if strings.ToLower(v) == "false" {
field.Readable = false
} else {
field.Readable = true
}
}
if v, ok := field.TagSettings["<-"]; ok {
field.Creatable = true
field.Updatable = true
if v != "<-" {
if !strings.Contains(v, "create") {
field.Creatable = false
}
if !strings.Contains(v, "update") {
field.Updatable = false
}
}
}
// A normal anonymous field, or a field carrying the `EMBEDDED` tag
if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
kind := reflect.Indirect(fieldValue).Kind()
switch kind {
case reflect.Struct:
var err error
field.Creatable = false
field.Updatable = false
field.Readable = false
cacheStore := &sync.Map{}
cacheStore.Store(embeddedCacheKey, true)
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
schema.err = err
}
for _, ef := range field.EmbeddedSchema.Fields {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
// a negative index means the field is accessed through a pointer
if field.FieldType.Kind() == reflect.Struct {
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
} else {
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
}
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
ef.DBName = prefix + ef.DBName
}
if ef.PrimaryKey {
if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
ef.PrimaryKey = false
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
ef.AutoIncrement = false
}
if !ef.AutoIncrement && ef.DefaultValue == "" {
ef.HasDefaultValue = false
}
}
}
for k, v := range field.TagSettings {
ef.TagSettings[k] = v
}
}
case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128:
schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
}
}
return field
}
// setupValuerAndSetter creates the field's valuer and setter while the struct is being parsed
func (field *Field) setupValuerAndSetter() {
// Setup NewValuePool
field.setupNewValuePool()
// ValueOf returns the field's value and whether it is the zero value
fieldIndex := field.StructField.Index[0]
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
fieldValue := reflect.Indirect(value).Field(fieldIndex)
return fieldValue.Interface(), fieldValue.IsZero()
}
default:
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
v = reflect.Indirect(v)
for _, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
if !v.IsNil() {
v = v.Elem()
} else {
return nil, true
}
}
}
fv, zero := v.Interface(), v.IsZero()
return fv, zero
}
}
if field.Serializer != nil {
oldValuerOf := field.ValueOf
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
value, zero := oldValuerOf(ctx, v)
s, ok := value.(SerializerValuerInterface)
if !ok {
s = field.Serializer
}
return &serializer{
Field: field,
SerializeValuer: s,
Destination: v,
Context: ctx,
fieldValue: value,
}, zero
}
}
// ReflectValueOf returns field's reflect value
switch {
case len(field.StructField.Index) == 1 && fieldIndex > 0:
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
return reflect.Indirect(value).Field(fieldIndex)
}
default:
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
for idx, fieldIdx := range field.StructField.Index {
if fieldIdx >= 0 {
v = v.Field(fieldIdx)
} else {
v = v.Field(-fieldIdx - 1)
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
if idx < len(field.StructField.Index)-1 {
v = v.Elem()
}
}
}
return v
}
}
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
if v == nil {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else {
reflectV := reflect.ValueOf(v)
// resolve the concrete type of v to decide whether it can be assigned or converted
reflectValType := reflectV.Type()
if reflectValType.AssignableTo(field.FieldType) {
if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
reflectV = reflect.Indirect(reflectV)
}
field.ReflectValueOf(ctx, value).Set(reflectV)
return
} else if reflectValType.ConvertibleTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
return
} else if field.FieldType.Kind() == reflect.Ptr {
fieldValue := field.ReflectValueOf(ctx, value)
fieldType := field.FieldType.Elem()
if reflectValType.AssignableTo(fieldType) {
if !fieldValue.IsValid() {
fieldValue = reflect.New(fieldType)
} else if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldType))
}
fieldValue.Elem().Set(reflectV)
return
} else if reflectValType.ConvertibleTo(fieldType) {
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(fieldType))
}
fieldValue.Elem().Set(reflectV.Convert(fieldType))
return
}
}
if reflectV.Kind() == reflect.Ptr {
if reflectV.IsNil() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Type().Elem().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV.Elem())
return
} else {
err = setter(ctx, value, reflectV.Elem().Interface())
}
} else if valuer, ok := v.(driver.Valuer); ok {
if v, err = valuer.Value(); err == nil {
err = setter(ctx, value, v)
}
} else if _, ok := v.(clause.Expr); !ok {
return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
}
}
return
}
// Set
switch field.FieldType.Kind() {
case reflect.Bool:
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) {
case **bool:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetBool(**data)
}
case bool:
field.ReflectValueOf(ctx, value).SetBool(data)
case int64:
field.ReflectValueOf(ctx, value).SetBool(data > 0)
case string:
b, _ := strconv.ParseBool(data)
field.ReflectValueOf(ctx, value).SetBool(b)
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **int64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(**data)
}
case **int:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int8:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int16:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case **int32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
}
case int64:
field.ReflectValueOf(ctx, value).SetInt(data)
case int:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int8:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int16:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case int32:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint8:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint16:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint32:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case uint64:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float32:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case float64:
field.ReflectValueOf(ctx, value).SetInt(int64(data))
case []byte:
return field.Set(ctx, value, string(data))
case string:
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
field.ReflectValueOf(ctx, value).SetInt(i)
} else {
return err
}
case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
case *time.Time:
if data != nil {
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
} else {
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
}
} else {
field.ReflectValueOf(ctx, value).SetInt(0)
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **uint64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(**data)
}
case **uint:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint8:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint16:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case **uint32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
}
case uint64:
field.ReflectValueOf(ctx, value).SetUint(data)
case uint:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint8:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint16:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case uint32:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int64:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int8:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int16:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case int32:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float32:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case float64:
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
case []byte:
return field.Set(ctx, value, string(data))
case time.Time:
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
} else {
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
}
case string:
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
field.ReflectValueOf(ctx, value).SetUint(i)
} else {
return err
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.Float32, reflect.Float64:
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **float64:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(**data)
}
case **float32:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
}
case float64:
field.ReflectValueOf(ctx, value).SetFloat(data)
case float32:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int64:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int8:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int16:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case int32:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint8:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint16:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint32:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case uint64:
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
case []byte:
return field.Set(ctx, value, string(data))
case string:
if i, err := strconv.ParseFloat(data, 64); err == nil {
field.ReflectValueOf(ctx, value).SetFloat(i)
} else {
return err
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
case reflect.String:
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
switch data := v.(type) {
case **string:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).SetString(**data)
}
case string:
field.ReflectValueOf(ctx, value).SetString(data)
case []byte:
field.ReflectValueOf(ctx, value).SetString(string(data))
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
case float64, float32:
field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return err
}
default:
fieldValue := reflect.New(field.FieldType)
switch fieldValue.Elem().Interface().(type) {
case time.Time:
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) {
case **time.Time:
if data != nil && *data != nil {
field.Set(ctx, value, *data)
}
case time.Time:
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case *time.Time:
if data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
} else {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
}
case string:
if t, err := now.Parse(data); err == nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
case *time.Time:
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
switch data := v.(type) {
case **time.Time:
if data != nil && *data != nil {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
}
case time.Time:
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(v))
case *time.Time:
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
case string:
if t, err := now.Parse(data); err == nil {
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
if v == "" {
return nil
}
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
fieldValue.Elem().Set(reflect.ValueOf(t))
} else {
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
}
default:
return fallbackSetter(ctx, value, v, field.Set)
}
return nil
}
default:
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
// pointer scanner
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
return field.Set(ctx, value, reflectV.Elem().Interface())
} else {
fieldValue := field.ReflectValueOf(ctx, value)
if fieldValue.IsNil() {
fieldValue.Set(reflect.New(field.FieldType.Elem()))
}
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = fieldValue.Interface().(sql.Scanner).Scan(v)
}
return
}
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
// struct scanner
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
reflectV := reflect.ValueOf(v)
if !reflectV.IsValid() {
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
return
} else if reflectV.Type().AssignableTo(field.FieldType) {
field.ReflectValueOf(ctx, value).Set(reflectV)
} else if reflectV.Kind() == reflect.Ptr {
return field.Set(ctx, value, reflectV.Elem().Interface())
} else {
if valuer, ok := v.(driver.Valuer); ok {
v, _ = valuer.Value()
}
err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
}
return
}
} else {
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
return fallbackSetter(ctx, value, v, field.Set)
}
}
}
}
if field.Serializer != nil {
var (
oldFieldSetter = field.Set
sameElemType bool
sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type()
)
if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr {
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
}
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
serializerType := serializerValue.Type()
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
if s, ok := v.(*serializer); ok {
if s.fieldValue != nil {
err = oldFieldSetter(ctx, value, s.fieldValue)
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
if sameElemType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
} else if sameType {
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
}
si := reflect.New(serializerType)
si.Elem().Set(serializerValue)
s.Serializer = si.Interface().(SerializerInterface)
}
} else {
err = oldFieldSetter(ctx, value, v)
}
return
}
}
}
func (field *Field) setupNewValuePool() {
if field.Serializer != nil {
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
serializerType := serializerValue.Type()
field.NewValuePool = &sync.Pool{
New: func() interface{} {
si := reflect.New(serializerType)
si.Elem().Set(serializerValue)
return &serializer{
Field: field,
Serializer: si.Interface().(SerializerInterface),
}
},
}
}
if field.NewValuePool == nil {
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
}
}
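The struct-scanner branch above is the path most custom column types take. Below is a minimal sketch, assuming a hypothetical `JSONText` type, of what such a type looks like from the application side; it relies only on the standard `database/sql` interfaces, not on unexported schema internals.

```go
package main

import (
	"database/sql/driver"
	"errors"
	"fmt"
)

// JSONText is a hypothetical column type storing raw JSON bytes.
type JSONText []byte

// Scan implements sql.Scanner; the generated field setter falls back to it
// for values that are not directly assignable or convertible.
func (j *JSONText) Scan(src interface{}) error {
	switch v := src.(type) {
	case nil:
		*j = nil
	case []byte:
		*j = append((*j)[:0], v...)
	case string:
		*j = JSONText(v)
	default:
		return errors.New("JSONText: unsupported scan source")
	}
	return nil
}

// Value implements driver.Valuer, so the value can be bound as a query argument.
func (j JSONText) Value() (driver.Value, error) {
	if len(j) == 0 {
		return nil, nil
	}
	return []byte(j), nil
}

func main() {
	var j JSONText
	_ = j.Scan(`{"ok":true}`)
	fmt.Println(string(j)) // {"ok":true}
}
```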

166
vendor/gorm.io/gorm/schema/index.go generated vendored Normal file
View File

@@ -0,0 +1,166 @@
package schema
import (
"fmt"
"sort"
"strconv"
"strings"
)
type Index struct {
Name string
Class string // UNIQUE | FULLTEXT | SPATIAL
Type string // btree, hash, gist, spgist, gin, and brin
Where string
Comment string
Option string // WITH PARSER parser_name
Fields []IndexOption // Note: different IndexOptions may reference the same Field
}
type IndexOption struct {
*Field
Expression string
Sort string // DESC, ASC
Collate string
Length int
priority int
}
// ParseIndexes parses the schema's index definitions from field tags
func (schema *Schema) ParseIndexes() map[string]Index {
indexes := map[string]Index{}
for _, field := range schema.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
fieldIndexes, err := parseFieldIndexes(field)
if err != nil {
schema.err = err
break
}
for _, index := range fieldIndexes {
idx := indexes[index.Name]
idx.Name = index.Name
if idx.Class == "" {
idx.Class = index.Class
}
if idx.Type == "" {
idx.Type = index.Type
}
if idx.Where == "" {
idx.Where = index.Where
}
if idx.Comment == "" {
idx.Comment = index.Comment
}
if idx.Option == "" {
idx.Option = index.Option
}
idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool {
return idx.Fields[i].priority < idx.Fields[j].priority
})
indexes[index.Name] = idx
}
}
}
for _, index := range indexes {
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
index.Fields[0].Field.UniqueIndex = index.Name
}
}
return indexes
}
func (schema *Schema) LookIndex(name string) *Index {
if schema != nil {
indexes := schema.ParseIndexes()
for _, index := range indexes {
if index.Name == name {
return &index
}
for _, field := range index.Fields {
if field.Name == name {
return &index
}
}
}
}
return nil
}
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUEINDEX" {
var (
name string
tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",")
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
settings = ParseTagSetting(tagSetting, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
)
if idx == -1 {
idx = len(tag)
}
if idx != -1 {
name = tag[0:idx]
}
if name == "" {
subName := field.Name
const key = "COMPOSITE"
if composite, found := settings[key]; found {
if len(composite) == 0 || composite == key {
err = fmt.Errorf(
"The composite tag of %s.%s cannot be empty",
field.Schema.Name,
field.Name)
return
}
subName = composite
}
name = field.Schema.namer.IndexName(
field.Schema.Table, subName)
}
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
settings["CLASS"] = "UNIQUE"
}
priority, err := strconv.Atoi(settings["PRIORITY"])
if err != nil {
priority = 10
}
indexes = append(indexes, Index{
Name: name,
Class: settings["CLASS"],
Type: settings["TYPE"],
Where: settings["WHERE"],
Comment: settings["COMMENT"],
Option: settings["OPTION"],
Fields: []IndexOption{{
Field: field,
Expression: settings["EXPRESSION"],
Sort: settings["SORT"],
Collate: settings["COLLATE"],
Length: length,
priority: priority,
}},
})
}
}
}
err = nil
return
}
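For reference, a sketch of the tag forms the parser above understands; the model and index names here are invented for illustration, following GORM's documented index tag syntax (composite indexes ordered by `priority`, `uniqueIndex` mapping to the UNIQUE class).

```go
package main

import "fmt"

// User is an illustrative model; ParseIndexes would derive two indexes from it:
// a composite idx_member (Member then Name, ordered by priority) and a unique
// index on Email whose name is generated via Namer.IndexName.
type User struct {
	ID     uint   `gorm:"primaryKey"`
	Member string `gorm:"index:idx_member,priority:1"`
	Name   string `gorm:"index:idx_member,priority:2"`
	Email  string `gorm:"uniqueIndex"`
}

func main() {
	fmt.Println("the tags above are read by schema.ParseIndexes when the schema is built")
}
```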

42
vendor/gorm.io/gorm/schema/interfaces.go generated vendored Normal file
View File

@@ -0,0 +1,42 @@
package schema
import (
"gorm.io/gorm/clause"
)
// ConstraintInterface database constraint interface
type ConstraintInterface interface {
GetName() string
Build() (sql string, vars []interface{})
}
// GormDataTypeInterface gorm data type interface
type GormDataTypeInterface interface {
GormDataType() string
}
// FieldNewValuePool field new scan value pool
type FieldNewValuePool interface {
Get() interface{}
Put(interface{})
}
// CreateClausesInterface create clauses interface
type CreateClausesInterface interface {
CreateClauses(*Field) []clause.Interface
}
// QueryClausesInterface query clauses interface
type QueryClausesInterface interface {
QueryClauses(*Field) []clause.Interface
}
// UpdateClausesInterface update clauses interface
type UpdateClausesInterface interface {
UpdateClauses(*Field) []clause.Interface
}
// DeleteClausesInterface delete clauses interface
type DeleteClausesInterface interface {
DeleteClauses(*Field) []clause.Interface
}
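As a sketch of how `GormDataTypeInterface` is typically satisfied (the `JSON` type below is hypothetical; GORM's own `datatypes` package follows the same pattern):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// JSON is a hypothetical raw-JSON column type.
type JSON json.RawMessage

// GormDataType satisfies GormDataTypeInterface, so the schema parser records
// "json" as the field's general data type instead of inferring it via reflection.
func (JSON) GormDataType() string {
	return "json"
}

func main() {
	var j JSON
	fmt.Println(j.GormDataType()) // json
}
```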

194
vendor/gorm.io/gorm/schema/naming.go generated vendored Normal file
View File

@@ -0,0 +1,194 @@
package schema
import (
"crypto/sha1"
"encoding/hex"
"regexp"
"strings"
"unicode/utf8"
"github.com/jinzhu/inflection"
)
// Namer namer interface
type Namer interface {
TableName(table string) string
SchemaName(table string) string
ColumnName(table, column string) string
JoinTableName(joinTable string) string
RelationshipFKName(Relationship) string
CheckerName(table, column string) string
IndexName(table, column string) string
UniqueName(table, column string) string
}
// Replacer replacer interface like strings.Replacer
type Replacer interface {
Replace(name string) string
}
var _ Namer = (*NamingStrategy)(nil)
// NamingStrategy tables, columns naming strategy
type NamingStrategy struct {
TablePrefix string
SingularTable bool
NameReplacer Replacer
NoLowerCase bool
IdentifierMaxLength int
}
// TableName convert string to table name
func (ns NamingStrategy) TableName(str string) string {
if ns.SingularTable {
return ns.TablePrefix + ns.toDBName(str)
}
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// SchemaName generates a schema name from a table name; it is not guaranteed to be the inverse of TableName
func (ns NamingStrategy) SchemaName(table string) string {
table = strings.TrimPrefix(table, ns.TablePrefix)
if ns.SingularTable {
return ns.toSchemaName(table)
}
return ns.toSchemaName(inflection.Singular(table))
}
// ColumnName convert string to column name
func (ns NamingStrategy) ColumnName(table, column string) string {
return ns.toDBName(column)
}
// JoinTableName convert string to join table name
func (ns NamingStrategy) JoinTableName(str string) string {
if !ns.NoLowerCase && strings.ToLower(str) == str {
return ns.TablePrefix + str
}
if ns.SingularTable {
return ns.TablePrefix + ns.toDBName(str)
}
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
}
// RelationshipFKName generate fk name for relation
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name))
}
// CheckerName generate checker name
func (ns NamingStrategy) CheckerName(table, column string) string {
return ns.formatName("chk", table, column)
}
// IndexName generate index name
func (ns NamingStrategy) IndexName(table, column string) string {
return ns.formatName("idx", table, ns.toDBName(column))
}
// UniqueName generate unique constraint name
func (ns NamingStrategy) UniqueName(table, column string) string {
return ns.formatName("uni", table, ns.toDBName(column))
}
func (ns NamingStrategy) formatName(prefix, table, name string) string {
formattedName := strings.ReplaceAll(strings.Join([]string{
prefix, table, name,
}, "_"), ".", "_")
if ns.IdentifierMaxLength == 0 {
ns.IdentifierMaxLength = 64
}
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
h := sha1.New()
h.Write([]byte(formattedName))
bs := h.Sum(nil)
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
}
return formattedName
}
var (
// https://github.com/golang/lint/blob/master/lint.go#L770
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
commonInitialismsReplacer *strings.Replacer
)
func init() {
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
for _, initialism := range commonInitialisms {
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
}
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
}
func (ns NamingStrategy) toDBName(name string) string {
if name == "" {
return ""
}
if ns.NameReplacer != nil {
tmpName := ns.NameReplacer.Replace(name)
if tmpName == "" {
return name
}
name = tmpName
}
if ns.NoLowerCase {
return name
}
var (
value = commonInitialismsReplacer.Replace(name)
buf strings.Builder
lastCase, nextCase, nextNumber bool // upper case == true
curCase = value[0] <= 'Z' && value[0] >= 'A'
)
for i, v := range value[:len(value)-1] {
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
if curCase {
if lastCase && (nextCase || nextNumber) {
buf.WriteRune(v + 32)
} else {
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
buf.WriteByte('_')
}
buf.WriteRune(v + 32)
}
} else {
buf.WriteRune(v)
}
lastCase = curCase
curCase = nextCase
}
if curCase {
if !lastCase && len(value) > 1 {
buf.WriteByte('_')
}
buf.WriteByte(value[len(value)-1] + 32)
} else {
buf.WriteByte(value[len(value)-1])
}
ret := buf.String()
return ret
}
func (ns NamingStrategy) toSchemaName(name string) string {
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
for _, initialism := range commonInitialisms {
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
}
return result
}
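A quick sketch of the naming behaviour implied by the code above (expected outputs shown in the comments; the prefix and names are made up):

```go
package main

import (
	"fmt"

	"gorm.io/gorm/schema"
)

func main() {
	ns := schema.NamingStrategy{TablePrefix: "t_", SingularTable: true}

	fmt.Println(ns.TableName("UserProfile"))    // t_user_profile
	fmt.Println(ns.ColumnName("", "CreatedAt")) // created_at
	fmt.Println(ns.IndexName("users", "Email")) // idx_users_email
}
```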

19
vendor/gorm.io/gorm/schema/pool.go generated vendored Normal file
View File

@@ -0,0 +1,19 @@
package schema
import (
"reflect"
"sync"
)
// sync pools
var (
normalPool sync.Map
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
New: func() interface{} {
return reflect.New(reflectType).Interface()
},
})
return v.(FieldNewValuePool)
}
)
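The pool above simply hands out pointers to fresh zero values per field type, which are reused as scan targets. A standalone sketch of the same idea using only the standard library:

```go
package main

import (
	"fmt"
	"reflect"
	"sync"
)

func main() {
	var pools sync.Map // one pool per reflect.Type, mirroring normalPool

	t := reflect.TypeOf(int64(0))
	v, _ := pools.LoadOrStore(t, &sync.Pool{
		New: func() interface{} { return reflect.New(t).Interface() },
	})
	pool := v.(*sync.Pool)

	target := pool.Get().(*int64) // e.g. handed to rows.Scan as a destination
	*target = 42
	fmt.Println(*target)
	pool.Put(target) // return it so later scans can reuse the allocation
}
```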

764
vendor/gorm.io/gorm/schema/relationship.go generated vendored Normal file
View File

@@ -0,0 +1,764 @@
package schema
import (
"context"
"fmt"
"reflect"
"strings"
"github.com/jinzhu/inflection"
"gorm.io/gorm/clause"
)
// RelationshipType relationship type
type RelationshipType string
const (
HasOne RelationshipType = "has_one" // HasOneRel has one relationship
HasMany RelationshipType = "has_many" // HasManyRel has many relationship
BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship
Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship
has RelationshipType = "has"
)
type Relationships struct {
HasOne []*Relationship
BelongsTo []*Relationship
HasMany []*Relationship
Many2Many []*Relationship
Relations map[string]*Relationship
EmbeddedRelations map[string]*Relationships
}
type Relationship struct {
Name string
Type RelationshipType
Field *Field
Polymorphic *Polymorphic
References []*Reference
Schema *Schema
FieldSchema *Schema
JoinTable *Schema
foreignKeys, primaryKeys []string
}
type Polymorphic struct {
PolymorphicID *Field
PolymorphicType *Field
Value string
}
type Reference struct {
PrimaryKey *Field
PrimaryValue string
ForeignKey *Field
OwnPrimaryKey bool
}
func (schema *Schema) parseRelation(field *Field) *Relationship {
var (
err error
fieldValue = reflect.New(field.IndirectFieldType).Interface()
relation = &Relationship{
Name: field.Name,
Field: field,
Schema: schema,
foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
primaryKeys: toColumns(field.TagSettings["REFERENCES"]),
}
)
cacheStore := schema.cacheStore
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
schema.err = err
return nil
}
if hasPolymorphicRelation(field.TagSettings) {
schema.buildPolymorphicRelation(relation, field)
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
schema.buildMany2ManyRelation(relation, field, many2many)
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
schema.guessRelation(relation, field, guessBelongs)
} else {
switch field.IndirectFieldType.Kind() {
case reflect.Struct:
schema.guessRelation(relation, field, guessGuess)
case reflect.Slice:
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
field.Name)
}
}
if relation.Type == has {
// don't add relations to embedded schema, which might be shared
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
}
switch field.IndirectFieldType.Kind() {
case reflect.Struct:
relation.Type = HasOne
case reflect.Slice:
relation.Type = HasMany
}
}
if schema.err == nil {
schema.setRelation(relation)
switch relation.Type {
case HasOne:
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
case HasMany:
schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation)
case BelongsTo:
schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation)
case Many2Many:
schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
}
}
return relation
}
// hasPolymorphicRelation checks whether the tag settings define a polymorphic relation:
// 1. `POLYMORPHIC` tag
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
func hasPolymorphicRelation(tagSettings map[string]string) bool {
if _, ok := tagSettings["POLYMORPHIC"]; ok {
return true
}
_, hasType := tagSettings["POLYMORPHICTYPE"]
_, hasId := tagSettings["POLYMORPHICID"]
return hasType && hasId
}
func (schema *Schema) setRelation(relation *Relationship) {
// set non-embedded relation
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
if len(rel.Field.BindNames) > 1 {
schema.Relationships.Relations[relation.Name] = relation
}
} else {
schema.Relationships.Relations[relation.Name] = relation
}
// set embedded relation
if len(relation.Field.BindNames) <= 1 {
return
}
relationships := &schema.Relationships
for i, name := range relation.Field.BindNames {
if i < len(relation.Field.BindNames)-1 {
if relationships.EmbeddedRelations == nil {
relationships.EmbeddedRelations = map[string]*Relationships{}
}
if r := relationships.EmbeddedRelations[name]; r == nil {
relationships.EmbeddedRelations[name] = &Relationships{}
}
relationships = relationships.EmbeddedRelations[name]
} else {
if relationships.Relations == nil {
relationships.Relations = map[string]*Relationship{}
}
relationships.Relations[relation.Name] = relation
}
}
}
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
//
// type User struct {
// Toys []Toy `gorm:"polymorphic:Owner;"`
// }
// type Pet struct {
// Toy Toy `gorm:"polymorphic:Owner;"`
// }
// type Toy struct {
// OwnerID int
// OwnerType string
// }
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
polymorphic := field.TagSettings["POLYMORPHIC"]
relation.Polymorphic = &Polymorphic{
Value: schema.Table,
}
var (
typeName = polymorphic + "Type"
typeId = polymorphic + "ID"
)
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
typeName = strings.TrimSpace(value)
}
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
typeId = strings.TrimSpace(value)
}
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
relation.Polymorphic.Value = strings.TrimSpace(value)
}
if relation.Polymorphic.PolymorphicType == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
}
if relation.Polymorphic.PolymorphicID == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
}
if schema.err == nil {
relation.References = append(relation.References, &Reference{
PrimaryValue: relation.Polymorphic.Value,
ForeignKey: relation.Polymorphic.PolymorphicType,
})
primaryKeyField := schema.PrioritizedPrimaryField
if len(relation.foreignKeys) > 0 {
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
schema, field.Name)
}
}
if primaryKeyField == nil {
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
relation.FieldSchema, schema, field.Name)
return
}
// use same data type for foreign keys
if copyableDataType(primaryKeyField.DataType) {
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
}
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
if relation.Polymorphic.PolymorphicID.Size == 0 {
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryKeyField,
ForeignKey: relation.Polymorphic.PolymorphicID,
OwnPrimaryKey: true,
})
}
relation.Type = has
}
func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
relation.Type = Many2Many
var (
err error
joinTableFields []reflect.StructField
fieldsMap = map[string]*Field{}
ownFieldsMap = map[string]*Field{} // fix self join many2many
referFieldsMap = map[string]*Field{}
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
)
ownForeignFields := schema.PrimaryFields
refForeignFields := relation.FieldSchema.PrimaryFields
if len(relation.foreignKeys) > 0 {
ownForeignFields = []*Field{}
for _, foreignKey := range relation.foreignKeys {
if field := schema.LookUpField(foreignKey); field != nil {
ownForeignFields = append(ownForeignFields, field)
} else {
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return
}
}
}
if len(relation.primaryKeys) > 0 {
refForeignFields = []*Field{}
for _, foreignKey := range relation.primaryKeys {
if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
refForeignFields = append(refForeignFields, field)
} else {
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
return
}
}
}
for idx, ownField := range ownForeignFields {
joinFieldName := strings.Title(schema.Name) + ownField.Name
if len(joinForeignKeys) > idx {
joinFieldName = strings.Title(joinForeignKeys[idx])
}
ownFieldsMap[joinFieldName] = ownField
fieldsMap[joinFieldName] = ownField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: ownField.StructField.PkgPath,
Type: ownField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
for idx, relField := range refForeignFields {
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
if _, ok := ownFieldsMap[joinFieldName]; ok {
if field.Name != relation.FieldSchema.Name {
joinFieldName = inflection.Singular(field.Name) + relField.Name
} else {
joinFieldName += "Reference"
}
}
if len(joinReferences) > idx {
joinFieldName = strings.Title(joinReferences[idx])
}
referFieldsMap[joinFieldName] = relField
if _, ok := fieldsMap[joinFieldName]; !ok {
fieldsMap[joinFieldName] = relField
joinTableFields = append(joinTableFields, reflect.StructField{
Name: joinFieldName,
PkgPath: relField.StructField.PkgPath,
Type: relField.StructField.Type,
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
"column", "autoincrement", "index", "unique", "uniqueindex"),
})
}
}
joinTableFields = append(joinTableFields, reflect.StructField{
Name: strings.Title(schema.Name) + field.Name,
Type: schema.ModelType,
Tag: `gorm:"-"`,
})
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
relName := relation.Schema.Name
relRefName := relation.FieldSchema.Name
if relName == relRefName {
relRefName = relation.Field.Name
}
if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
relation.JoinTable.Relationships.Relations[relName] = &Relationship{
Name: relName,
Type: BelongsTo,
Schema: relation.JoinTable,
FieldSchema: relation.Schema,
}
} else {
relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
}
if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
Name: relRefName,
Type: BelongsTo,
Schema: relation.JoinTable,
FieldSchema: relation.FieldSchema,
}
} else {
relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
}
// build references
for _, f := range relation.JoinTable.Fields {
if f.Creatable || f.Readable || f.Updatable {
// use same data type for foreign keys
if copyableDataType(fieldsMap[f.Name].DataType) {
f.DataType = fieldsMap[f.Name].DataType
}
f.GORMDataType = fieldsMap[f.Name].GORMDataType
if f.Size == 0 {
f.Size = fieldsMap[f.Name].Size
}
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
if of, ok := ownFieldsMap[f.Name]; ok {
joinRel := relation.JoinTable.Relationships.Relations[relName]
joinRel.Field = relation.Field
joinRel.References = append(joinRel.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: of,
ForeignKey: f,
OwnPrimaryKey: true,
})
}
if rf, ok := referFieldsMap[f.Name]; ok {
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
if joinRefRel.Field == nil {
joinRefRel.Field = relation.Field
}
joinRefRel.References = append(joinRefRel.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
relation.References = append(relation.References, &Reference{
PrimaryKey: rf,
ForeignKey: f,
})
}
}
}
}
type guessLevel int
const (
guessGuess guessLevel = iota
guessBelongs
guessEmbeddedBelongs
guessHas
guessEmbeddedHas
)
func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
var (
primaryFields, foreignFields []*Field
primarySchema, foreignSchema = schema, relation.FieldSchema
gl = cgl
)
if gl == guessGuess {
if field.Schema == relation.FieldSchema {
gl = guessBelongs
} else {
gl = guessHas
}
}
reguessOrErr := func() {
switch cgl {
case guessGuess:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
case guessEmbeddedBelongs:
schema.guessRelation(relation, field, guessHas)
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
// case guessEmbeddedHas:
default:
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
schema, field.Name)
}
}
switch gl {
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
case guessHas:
case guessEmbeddedHas:
if field.OwnerSchema == nil {
reguessOrErr()
return
}
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
}
if len(relation.foreignKeys) > 0 {
for _, foreignKey := range relation.foreignKeys {
f := foreignSchema.LookUpField(foreignKey)
if f == nil {
reguessOrErr()
return
}
foreignFields = append(foreignFields, f)
}
} else {
primarySchemaName := primarySchema.Name
if primarySchemaName == "" {
primarySchemaName = relation.FieldSchema.Name
}
if len(relation.primaryKeys) > 0 {
for _, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
primaryFields = append(primaryFields, f)
}
}
} else {
primaryFields = primarySchema.PrimaryFields
}
primaryFieldLoop:
for _, primaryField := range primaryFields {
lookUpName := primarySchemaName + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}
lookUpNames := []string{lookUpName}
if len(primaryFields) == 1 {
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
for _, name := range lookUpNames {
if f := foreignSchema.LookUpField(name); f != nil {
foreignFields = append(foreignFields, f)
primaryFields = append(primaryFields, primaryField)
continue primaryFieldLoop
}
}
}
}
switch {
case len(foreignFields) == 0:
reguessOrErr()
return
case len(relation.primaryKeys) > 0:
for idx, primaryKey := range relation.primaryKeys {
if f := primarySchema.LookUpField(primaryKey); f != nil {
if len(primaryFields) < idx+1 {
primaryFields = append(primaryFields, f)
} else if f != primaryFields[idx] {
reguessOrErr()
return
}
} else {
reguessOrErr()
return
}
}
case len(primaryFields) == 0:
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
} else {
reguessOrErr()
return
}
}
// build references
for idx, foreignField := range foreignFields {
// use same data type for foreign keys
if copyableDataType(primaryFields[idx].DataType) {
foreignField.DataType = primaryFields[idx].DataType
}
foreignField.GORMDataType = primaryFields[idx].GORMDataType
if foreignField.Size == 0 {
foreignField.Size = primaryFields[idx].Size
}
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx],
ForeignKey: foreignField,
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
})
}
if gl == guessHas || gl == guessEmbeddedHas {
relation.Type = has
} else {
relation.Type = BelongsTo
}
}
// Constraint is ForeignKey Constraint
type Constraint struct {
Name string
Field *Field
Schema *Schema
ForeignKeys []*Field
ReferenceSchema *Schema
References []*Field
OnDelete string
OnUpdate string
}
func (constraint *Constraint) GetName() string { return constraint.Name }
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
references := make([]interface{}, 0, len(constraint.References))
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (rel *Relationship) ParseConstraint() *Constraint {
str := rel.Field.TagSettings["CONSTRAINT"]
if str == "-" {
return nil
}
if rel.Type == BelongsTo {
for _, r := range rel.FieldSchema.Relationships.Relations {
if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
matched := true
for idx, ref := range r.References {
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
matched = false
}
}
if matched {
return nil
}
}
}
}
var (
name string
idx = strings.Index(str, ",")
settings = ParseTagSetting(str, ",")
)
// regEnLetterAndMidline matches English letters and dashes.
// It is evaluated inside loops, so the regular expression is compiled once at
// package level to avoid the cost of recompiling it on every call.
if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) {
name = str[0:idx]
} else {
name = rel.Schema.namer.RelationshipFKName(*rel)
}
constraint := Constraint{
Name: name,
Field: rel.Field,
OnUpdate: settings["ONUPDATE"],
OnDelete: settings["ONDELETE"],
}
for _, ref := range rel.References {
if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) {
constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
constraint.References = append(constraint.References, ref.PrimaryKey)
if ref.OwnPrimaryKey {
constraint.Schema = ref.ForeignKey.Schema
constraint.ReferenceSchema = rel.Schema
} else {
constraint.Schema = rel.Schema
constraint.ReferenceSchema = ref.PrimaryKey.Schema
}
}
}
return &constraint
}
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{}
relForeignKeys := []string{}
if rel.JoinTable != nil {
table = rel.JoinTable.Table
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
} else if ref.PrimaryValue != "" {
conds = append(conds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
} else {
conds = append(conds, clause.Eq{
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName},
})
}
}
} else {
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
foreignFields = append(foreignFields, ref.PrimaryKey)
} else if ref.PrimaryValue != "" {
conds = append(conds, clause.Eq{
Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName},
Value: ref.PrimaryValue,
})
} else {
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
foreignFields = append(foreignFields, ref.ForeignKey)
}
}
}
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
conds = append(conds, clause.IN{Column: column, Values: values})
return
}
func copyableDataType(str DataType) bool {
for _, s := range []string{"auto_increment", "primary key"} {
if strings.Contains(strings.ToLower(string(str)), s) {
return false
}
}
return true
}
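Illustrative models only (the names are invented here) showing which paths above they exercise: an untagged struct field plus a `CompanyID` column is resolved by guessRelation as belongs-to, a slice field with a `UserID` column on the other side becomes has-many, and the `many2many` tag goes through buildMany2ManyRelation.

```go
package main

type Company struct {
	ID   uint
	Name string
}

type Pet struct {
	ID     uint
	UserID uint // foreign key guessed for User.Pets (has many)
	Name   string
}

type Language struct {
	Code string `gorm:"primaryKey"`
	Name string
}

type User struct {
	ID        uint
	CompanyID uint       // foreign key guessed for the Company field (belongs to)
	Company   Company    // struct field + FK on this side => BelongsTo
	Pets      []Pet      // slice field + FK on Pet => HasMany
	Languages []Language `gorm:"many2many:user_languages"` // join table built by buildMany2ManyRelation
}

func main() {}
```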

423
vendor/gorm.io/gorm/schema/schema.go generated vendored Normal file
View File

@@ -0,0 +1,423 @@
package schema
import (
"context"
"errors"
"fmt"
"go/ast"
"reflect"
"strings"
"sync"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
)
type callbackType string
const (
callbackTypeBeforeCreate callbackType = "BeforeCreate"
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
callbackTypeAfterCreate callbackType = "AfterCreate"
callbackTypeAfterUpdate callbackType = "AfterUpdate"
callbackTypeBeforeSave callbackType = "BeforeSave"
callbackTypeAfterSave callbackType = "AfterSave"
callbackTypeBeforeDelete callbackType = "BeforeDelete"
callbackTypeAfterDelete callbackType = "AfterDelete"
callbackTypeAfterFind callbackType = "AfterFind"
)
// ErrUnsupportedDataType unsupported data type
var ErrUnsupportedDataType = errors.New("unsupported data type")
type Schema struct {
Name string
ModelType reflect.Type
Table string
PrioritizedPrimaryField *Field
DBNames []string
PrimaryFields []*Field
PrimaryFieldDBNames []string
Fields []*Field
FieldsByName map[string]*Field
FieldsByBindName map[string]*Field // embedded fields are keyed as 'Embed.Field'
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
CreateClauses []clause.Interface
QueryClauses []clause.Interface
UpdateClauses []clause.Interface
DeleteClauses []clause.Interface
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
err error
initialized chan struct{}
namer Namer
cacheStore *sync.Map
}
func (schema Schema) String() string {
if schema.ModelType.Name() == "" {
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
}
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
}
func (schema Schema) MakeSlice() reflect.Value {
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
results := reflect.New(slice.Type())
results.Elem().Set(slice)
return results
}
func (schema Schema) LookUpField(name string) *Field {
if field, ok := schema.FieldsByDBName[name]; ok {
return field
}
if field, ok := schema.FieldsByName[name]; ok {
return field
}
return nil
}
// LookUpFieldByBindName looks for the closest field in the embedded struct.
//
// type Struct struct {
// Embedded struct {
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
// }
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
// }
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
if len(bindNames) == 0 {
return nil
}
for i := len(bindNames) - 1; i >= 0; i-- {
find := strings.Join(bindNames[:i], ".") + "." + name
if field, ok := schema.FieldsByBindName[find]; ok {
return field
}
}
return nil
}
type Tabler interface {
TableName() string
}
type TablerWithNamer interface {
TableName(Namer) string
}
// Parse parses the given model into a *Schema, using the shared cacheStore and namer
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
}
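// Example (illustrative only, not part of the upstream file; User is a
// placeholder model type): parsing a model once with a fresh cache and the
// default naming strategy.
//
//	s, err := Parse(&User{}, &sync.Map{}, NamingStrategy{})
//	if err == nil {
//		fmt.Println(s.Table, s.PrimaryFieldDBNames)
//	}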
// ParseWithSpecialTableName parses the given model into a *Schema, using specialTableName as the table name when it is non-empty
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
if dest == nil {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
value := reflect.ValueOf(dest)
if value.Kind() == reflect.Ptr && value.IsNil() {
value = reflect.New(value.Type().Elem())
}
modelType := reflect.Indirect(value).Type()
if modelType.Kind() == reflect.Interface {
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
}
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
// Cache the Schema for performance,
// using the modelType, or modelType + specialTableName (if present), as the cache key.
var schemaCacheKey interface{}
if specialTableName != "" {
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
} else {
schemaCacheKey = modelType
}
// Load the existing schema from the cache and return it if present
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
modelValue := reflect.New(modelType)
tableName := namer.TableName(modelType.Name())
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
schema := &Schema{
Name: modelType.Name(),
ModelType: modelType,
Table: tableName,
FieldsByName: map[string]*Field{},
FieldsByBindName: map[string]*Field{},
FieldsByDBName: map[string]*Field{},
Relationships: Relationships{Relations: map[string]*Relationship{}},
cacheStore: cacheStore,
namer: namer,
initialized: make(chan struct{}),
}
// When the schema initialization is completed, the channel will be closed
defer close(schema.initialized)
// Load the existing schema from the cache and return it if present
if v, ok := cacheStore.Load(schemaCacheKey); ok {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
} else {
schema.Fields = append(schema.Fields, field)
}
}
}
for _, field := range schema.Fields {
if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
bindName := field.BindName()
if field.DBName != "" {
// prefer the field that is not registered yet, has the shortest bind path, or appeared first, provided it has create/update/read permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
schema.FieldsByDBName[field.DBName] = field
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[bindName] = field
if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields {
if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
}
}
}
if field.PrimaryKey {
schema.PrimaryFields = append(schema.PrimaryFields, field)
}
}
}
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field
}
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByBindName[bindName] = field
}
field.setupValuerAndSetter()
}
prioritizedPrimaryField := schema.LookUpField("id")
if prioritizedPrimaryField == nil {
prioritizedPrimaryField = schema.LookUpField("ID")
}
if prioritizedPrimaryField != nil {
if prioritizedPrimaryField.PrimaryKey {
schema.PrioritizedPrimaryField = prioritizedPrimaryField
} else if len(schema.PrimaryFields) == 0 {
prioritizedPrimaryField.PrimaryKey = true
schema.PrioritizedPrimaryField = prioritizedPrimaryField
schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
}
}
if schema.PrioritizedPrimaryField == nil {
if len(schema.PrimaryFields) == 1 {
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
} else if len(schema.PrimaryFields) > 1 {
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
for _, field := range schema.PrimaryFields {
if field.AutoIncrement {
schema.PrioritizedPrimaryField = field
break
}
}
}
}
for _, field := range schema.PrimaryFields {
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
}
for _, field := range schema.Fields {
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
}
if field := schema.PrioritizedPrimaryField; field != nil {
switch field.GORMDataType {
case Int, Uint:
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
}
field.HasDefaultValue = true
field.AutoIncrement = true
}
}
}
callbackTypes := []callbackType{
callbackTypeBeforeCreate, callbackTypeAfterCreate,
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
callbackTypeBeforeSave, callbackTypeAfterSave,
callbackTypeBeforeDelete, callbackTypeAfterDelete,
callbackTypeAfterFind,
}
for _, cbName := range callbackTypes {
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
switch methodValue.Type().String() {
case "func(*gorm.DB) error": // TODO hack
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
default:
logger.Default.Warn(context.Background(), "Model %v doesn't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
}
}
}
// Cache the schema
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
s := v.(*Schema)
// Wait for the initialization of other goroutines to complete
<-s.initialized
return s, s.err
}
defer func() {
if schema.err != nil {
logger.Default.Error(context.Background(), schema.err.Error())
cacheStore.Delete(modelType)
}
}()
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
for _, field := range schema.Fields {
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
} else {
schema.FieldsByName[field.Name] = field
schema.FieldsByBindName[field.BindName()] = field
}
}
fieldValue := reflect.New(field.IndirectFieldType)
fieldInterface := fieldValue.Interface()
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
}
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
}
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
}
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
}
}
}
return schema, schema.err
}
// This unrolling is needed to show to the compiler the exact set of methods
// that can be used on the modelType.
// Prior to go1.22 any use of MethodByName would cause the linker to
// abandon dead code elimination for the entire binary.
// As of go1.22 the compiler supports one special case of a string constant
// being passed to MethodByName. For enterprise customers or those building
// large binaries, this gives a significant reduction in binary size.
// https://github.com/golang/go/issues/62257
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
switch cbType {
case callbackTypeBeforeCreate:
return modelType.MethodByName(string(callbackTypeBeforeCreate))
case callbackTypeAfterCreate:
return modelType.MethodByName(string(callbackTypeAfterCreate))
case callbackTypeBeforeUpdate:
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
case callbackTypeAfterUpdate:
return modelType.MethodByName(string(callbackTypeAfterUpdate))
case callbackTypeBeforeSave:
return modelType.MethodByName(string(callbackTypeBeforeSave))
case callbackTypeAfterSave:
return modelType.MethodByName(string(callbackTypeAfterSave))
case callbackTypeBeforeDelete:
return modelType.MethodByName(string(callbackTypeBeforeDelete))
case callbackTypeAfterDelete:
return modelType.MethodByName(string(callbackTypeAfterDelete))
case callbackTypeAfterFind:
return modelType.MethodByName(string(callbackTypeAfterFind))
default:
return reflect.ValueOf(nil)
}
}
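The unrolled lookup above only ever resolves hook methods with the `func(*gorm.DB) error` shape checked by the parser. Below is a minimal sketch of a model hook that would be detected this way; the `User` type and the validation inside it are illustrative, not part of the vendored code:

```go
package main

import (
	"errors"

	"gorm.io/gorm"
)

type User struct {
	ID   uint
	Name string
}

// BeforeCreate matches the expected `func(*gorm.DB) error` signature, so schema
// parsing sets the BeforeCreate flag on the parsed Schema and gorm calls it
// before inserting a User.
func (u *User) BeforeCreate(tx *gorm.DB) error {
	if u.Name == "" {
		return errors.New("name must not be empty")
	}
	return nil
}

func main() {
	_ = User{} // the hook is invoked by gorm during Create; nothing to run here
}
```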
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}
if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil
}
return Parse(dest, cacheStore, namer)
}
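`getOrParse` answers from the cache keyed by the model type and only falls back to a full parse. A hedged sketch of the exported `Parse` entry point follows; the `Blog` model is made up, and the zero-value `NamingStrategy` is assumed to be sufficient for a model this simple:

```go
package main

import (
	"fmt"
	"sync"

	"gorm.io/gorm/schema"
)

type Blog struct {
	ID    uint
	Title string
}

func main() {
	cache := &sync.Map{}
	s, err := schema.Parse(&Blog{}, cache, schema.NamingStrategy{})
	if err != nil {
		panic(err)
	}
	// Subsequent calls for the same model type are served from the cache
	// populated by the first Parse.
	fmt.Println(s.Table, s.PrimaryFieldDBNames) // e.g. "blogs" ["id"]
}
```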

170
vendor/gorm.io/gorm/schema/serializer.go generated vendored Normal file
View File

@@ -0,0 +1,170 @@
package schema
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/gob"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"time"
)
var serializerMap = sync.Map{}
// RegisterSerializer registers a serializer under the given (case-insensitive) name
func RegisterSerializer(name string, serializer SerializerInterface) {
serializerMap.Store(strings.ToLower(name), serializer)
}
// GetSerializer looks up a registered serializer by its (case-insensitive) name
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
v, ok := serializerMap.Load(strings.ToLower(name))
if ok {
serializer, ok = v.(SerializerInterface)
}
return serializer, ok
}
func init() {
RegisterSerializer("json", JSONSerializer{})
RegisterSerializer("unixtime", UnixSecondSerializer{})
RegisterSerializer("gob", GobSerializer{})
}
// serializer wraps a field value together with its serializer so it can act as a sql.Scanner / driver.Valuer
type serializer struct {
Field *Field
Serializer SerializerInterface
SerializeValuer SerializerValuerInterface
Destination reflect.Value
Context context.Context
value interface{}
fieldValue interface{}
}
// Scan implements sql.Scanner interface
func (s *serializer) Scan(value interface{}) error {
s.value = value
return nil
}
// Value implements driver.Valuer interface
func (s serializer) Value() (driver.Value, error) {
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}
// SerializerInterface serializer interface
type SerializerInterface interface {
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
SerializerValuerInterface
}
// SerializerValuerInterface serializer valuer interface
type SerializerValuerInterface interface {
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}
// JSONSerializer json serializer
type JSONSerializer struct{}
// Scan implements serializer interface
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
}
if len(bytes) > 0 {
err = json.Unmarshal(bytes, fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
result, err := json.Marshal(fieldValue)
if string(result) == "null" {
if field.TagSettings["NOT NULL"] != "" {
return "", nil
}
return nil, err
}
return string(result), err
}
// UnixSecondSerializer unix second serializer
type UnixSecondSerializer struct{}
// Scan implements serializer interface
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
t := sql.NullTime{}
if err = t.Scan(dbValue); err == nil && t.Valid {
err = field.Set(ctx, dst, t.Time.Unix())
}
return
}
// Value implements serializer interface
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
rv := reflect.ValueOf(fieldValue)
switch v := fieldValue.(type) {
case int64, int, uint, uint64, int32, uint32, int16, uint16:
result = time.Unix(reflect.Indirect(rv).Int(), 0)
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
if rv.IsZero() {
return nil, nil
}
result = time.Unix(reflect.Indirect(rv).Int(), 0)
default:
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
}
return
}
// GobSerializer gob serializer
type GobSerializer struct{}
// Scan implements serializer interface
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
fieldValue := reflect.New(field.FieldType)
if dbValue != nil {
var bytesValue []byte
switch v := dbValue.(type) {
case []byte:
bytesValue = v
default:
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
}
if len(bytesValue) > 0 {
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
err = decoder.Decode(fieldValue.Interface())
}
}
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
return
}
// Value implements serializer interface
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
buf := new(bytes.Buffer)
err := gob.NewEncoder(buf).Encode(fieldValue)
return buf.Bytes(), err
}
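The serializers registered in `init` are selected per field through struct tags. A minimal sketch, assuming the usual `serializer` tag key; the `Profile` model and its fields are illustrative:

```go
package main

type Profile struct {
	ID uint
	// Marshalled to JSON by JSONSerializer on write and unmarshalled on scan.
	Settings map[string]string `gorm:"serializer:json"`
	// Kept as unix seconds in Go; UnixSecondSerializer converts to a time
	// value when writing and back from the database time when scanning.
	LastSeen int64 `gorm:"serializer:unixtime;type:time"`
	// Encoded and decoded with encoding/gob by GobSerializer.
	Snapshot map[string]int `gorm:"serializer:gob"`
}

func main() {
	_ = Profile{} // serialization happens inside gorm's create/query flow
}
```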

213
vendor/gorm.io/gorm/schema/utils.go generated vendored Normal file
View File

@@ -0,0 +1,213 @@
package schema
import (
"context"
"fmt"
"reflect"
"regexp"
"strings"
"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)
var embeddedCacheKey = "embedded_cache_store"
func ParseTagSetting(str string, sep string) map[string]string {
settings := map[string]string{}
names := strings.Split(str, sep)
for i := 0; i < len(names); i++ {
j := i
if len(names[j]) > 0 {
for {
if names[j][len(names[j])-1] == '\\' {
i++
names[j] = names[j][0:len(names[j])-1] + sep + names[i]
names[i] = ""
} else {
break
}
}
}
values := strings.Split(names[j], ":")
k := strings.TrimSpace(strings.ToUpper(values[0]))
if len(values) >= 2 {
settings[k] = strings.Join(values[1:], ":")
} else if k != "" {
settings[k] = k
}
}
return settings
}
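`ParseTagSetting` splits a tag value on the separator, upper-cases the keys, maps value-less flags to themselves, and treats a trailing backslash as an escape for the separator. A small runnable sketch using the exported helper:

```go
package main

import (
	"fmt"

	"gorm.io/gorm/schema"
)

func main() {
	settings := schema.ParseTagSetting("column:user_name;size:256;not null", ";")
	// Keys are upper-cased and a flag without a value maps to itself:
	// map[COLUMN:user_name NOT NULL:NOT NULL SIZE:256]
	fmt.Println(settings)
}
```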
func toColumns(val string) (results []string) {
if val != "" {
for _, v := range strings.Split(val, ",") {
results = append(results, strings.TrimSpace(v))
}
}
return
}
func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
for _, name := range names {
tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
}
return tag
}
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
t := tag.Get("gorm")
if strings.Contains(t, value) {
return tag
}
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
}
// GetRelationsValues gets relations' values from a reflect value
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
for _, rel := range rels {
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
appendToResults := func(value reflect.Value) {
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
switch result.Kind() {
case reflect.Struct:
reflectResults = reflect.Append(reflectResults, result.Addr())
case reflect.Slice, reflect.Array:
for i := 0; i < result.Len(); i++ {
if elem := result.Index(i); elem.Kind() == reflect.Ptr {
reflectResults = reflect.Append(reflectResults, elem)
} else {
reflectResults = reflect.Append(reflectResults, elem.Addr())
}
}
}
}
}
switch reflectValue.Kind() {
case reflect.Struct:
appendToResults(reflectValue)
case reflect.Slice:
for i := 0; i < reflectValue.Len(); i++ {
appendToResults(reflectValue.Index(i))
}
}
reflectValue = reflectResults
}
return
}
// GetIdentityFieldValuesMap get identity map from fields
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
var (
results = [][]interface{}{}
dataResults = map[string][]reflect.Value{}
loaded = map[interface{}]bool{}
notZero, zero bool
)
if reflectValue.Kind() == reflect.Ptr ||
reflectValue.Kind() == reflect.Interface {
reflectValue = reflectValue.Elem()
}
switch reflectValue.Kind() {
case reflect.Struct:
results = [][]interface{}{make([]interface{}, len(fields))}
for idx, field := range fields {
results[0][idx], zero = field.ValueOf(ctx, reflectValue)
notZero = notZero || !zero
}
if !notZero {
return nil, nil
}
dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
elem := reflectValue.Index(i)
elemKey := elem.Interface()
if elem.Kind() != reflect.Ptr && elem.CanAddr() {
elemKey = elem.Addr().Interface()
}
if _, ok := loaded[elemKey]; ok {
continue
}
loaded[elemKey] = true
fieldValues := make([]interface{}, len(fields))
notZero = false
for idx, field := range fields {
fieldValues[idx], zero = field.ValueOf(ctx, elem)
notZero = notZero || !zero
}
if notZero {
dataKey := utils.ToStringKey(fieldValues...)
if _, ok := dataResults[dataKey]; !ok {
results = append(results, fieldValues)
dataResults[dataKey] = []reflect.Value{elem}
} else {
dataResults[dataKey] = append(dataResults[dataKey], elem)
}
}
}
}
return dataResults, results
}
// GetIdentityFieldValuesMapFromValues get identity map from fields
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
resultsMap := map[string][]reflect.Value{}
results := [][]interface{}{}
for _, v := range values {
rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
for k, v := range rm {
resultsMap[k] = append(resultsMap[k], v...)
}
results = append(results, rs...)
}
return resultsMap, results
}
// ToQueryValues converts foreign key names and their value sets into query column(s) and values
func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
queryValues := make([]interface{}, len(foreignValues))
if len(foreignKeys) == 1 {
for idx, r := range foreignValues {
queryValues[idx] = r[0]
}
return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
}
columns := make([]clause.Column, len(foreignKeys))
for idx, key := range foreignKeys {
columns[idx] = clause.Column{Table: table, Name: key}
}
for idx, r := range foreignValues {
queryValues[idx] = r
}
return columns, queryValues
}
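`ToQueryValues` flattens foreign-key values into something a `clause.IN` can consume: a single column with scalar values for one key, or a column list with one row slice per value set for a composite key. A short sketch with the exported helper (table and key names are illustrative):

```go
package main

import (
	"fmt"

	"gorm.io/gorm/schema"
)

func main() {
	// Single foreign key: one clause.Column plus the flattened scalar values.
	col, vals := schema.ToQueryValues("users", []string{"id"},
		[][]interface{}{{1}, {2}, {3}})
	fmt.Println(col, vals)

	// Composite key: a []clause.Column plus one row slice per value set.
	cols, rows := schema.ToQueryValues("orders", []string{"user_id", "order_no"},
		[][]interface{}{{1, "A"}, {2, "B"}})
	fmt.Println(cols, rows)
}
```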
type embeddedNamer struct {
Table string
Namer
}

170
vendor/gorm.io/gorm/soft_delete.go generated vendored Normal file
View File

@@ -0,0 +1,170 @@
package gorm
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"reflect"
"github.com/jinzhu/now"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)
type DeletedAt sql.NullTime
// Scan implements the Scanner interface.
func (n *DeletedAt) Scan(value interface{}) error {
return (*sql.NullTime)(n).Scan(value)
}
// Value implements the driver Valuer interface.
func (n DeletedAt) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
func (n DeletedAt) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.Time)
}
return json.Marshal(nil)
}
func (n *DeletedAt) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
n.Valid = false
return nil
}
err := json.Unmarshal(b, &n.Time)
if err == nil {
n.Valid = true
}
return err
}
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
func parseZeroValueTag(f *schema.Field) sql.NullString {
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
if _, err := now.Parse(v); err == nil {
return sql.NullString{String: v, Valid: true}
}
}
return sql.NullString{Valid: false}
}
type SoftDeleteQueryClause struct {
ZeroValue sql.NullString
Field *schema.Field
}
func (sd SoftDeleteQueryClause) Name() string {
return ""
}
func (sd SoftDeleteQueryClause) Build(clause.Builder) {
}
func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped {
if c, ok := stmt.Clauses["WHERE"]; ok {
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 {
for _, expr := range where.Exprs {
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
c.Expression = where
stmt.Clauses["WHERE"] = c
break
}
}
}
}
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
}})
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
}
}
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
type SoftDeleteUpdateClause struct {
ZeroValue sql.NullString
Field *schema.Field
}
func (sd SoftDeleteUpdateClause) Name() string {
return ""
}
func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
}
func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
}
}
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}
type SoftDeleteDeleteClause struct {
ZeroValue sql.NullString
Field *schema.Field
}
func (sd SoftDeleteDeleteClause) Name() string {
return ""
}
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
}
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
}
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
curTime := stmt.DB.NowFunc()
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
stmt.SetColumn(sd.Field.DBName, curTime, true)
if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}
}
}
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build(stmt.DB.Callback().Update().Clauses...)
}
}
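Taken together, these clauses turn `Delete` into an `UPDATE ... SET deleted_at = <now>` and make ordinary queries filter on `deleted_at IS NULL` (or the configured zero value), while `Unscoped` bypasses both. A hedged usage sketch; the `Order` model is illustrative and `db` is assumed to be an already opened `*gorm.DB`:

```go
package main

import "gorm.io/gorm"

type Order struct {
	ID        uint
	Amount    int
	DeletedAt gorm.DeletedAt // presence of this field enables the soft-delete clauses
}

// softDeleteExamples sketches the resulting statements; db is assumed open.
func softDeleteExamples(db *gorm.DB) {
	var orders []Order

	db.Delete(&Order{}, 1)      // roughly: UPDATE orders SET deleted_at=? WHERE id = 1 AND deleted_at IS NULL
	db.Find(&orders)            // roughly: SELECT * FROM orders WHERE deleted_at IS NULL
	db.Unscoped().Find(&orders) // includes soft-deleted rows
	db.Unscoped().Delete(&Order{}, 1) // issues a real DELETE
}

func main() {}
```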

742
vendor/gorm.io/gorm/statement.go generated vendored Normal file
View File

@@ -0,0 +1,742 @@
package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// Statement statement
type Statement struct {
*DB
TableExpr *clause.Expr
Table string
Model interface{}
Unscoped bool
Dest interface{}
ReflectValue reflect.Value
Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool
Selects []string // selected columns
Omits []string // omit columns
Joins []join
Preloads map[string][]interface{}
Settings sync.Map
ConnPool ConnPool
Schema *schema.Schema
Context context.Context
RaiseErrorOnNotFound bool
SkipHooks bool
SQL strings.Builder
Vars []interface{}
CurDestIndex int
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
}
type join struct {
Name string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
JoinType clause.JoinType
}
// StatementModifier statement modifier interface
type StatementModifier interface {
ModifyStatement(*Statement)
}
// WriteString write string
func (stmt *Statement) WriteString(str string) (int, error) {
return stmt.SQL.WriteString(str)
}
// WriteByte write byte
func (stmt *Statement) WriteByte(c byte) error {
return stmt.SQL.WriteByte(c)
}
// WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(value interface{}) {
stmt.QuoteTo(&stmt.SQL, value)
}
// QuoteTo write quoted value to writer
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write := func(raw bool, str string) {
if raw {
writer.WriteString(str)
} else {
stmt.DB.Dialector.QuoteTo(writer, str)
}
}
switch v := field.(type) {
case clause.Table:
if v.Name == clause.CurrentTable {
if stmt.TableExpr != nil {
stmt.TableExpr.Build(stmt)
} else {
write(v.Raw, stmt.Table)
}
} else {
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteByte(' ')
write(v.Raw, v.Alias)
}
case clause.Column:
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
if v.Name == clause.PrimaryKey {
if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
}
} else {
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteString(" AS ")
write(v.Raw, v.Alias)
}
case []clause.Column:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteByte(',')
}
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case clause.Expr:
v.Build(stmt)
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
case []string:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteByte(',')
}
stmt.DB.Dialector.QuoteTo(writer, d)
}
writer.WriteByte(')')
default:
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
}
}
// Quote returns quoted value
func (stmt *Statement) Quote(field interface{}) string {
var builder strings.Builder
stmt.QuoteTo(&builder, field)
return builder.String()
}
// AddVar add var
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
for idx, v := range vars {
if idx > 0 {
writer.WriteByte(',')
}
switch v := v.(type) {
case sql.NamedArg:
stmt.Vars = append(stmt.Vars, v.Value)
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case Valuer:
reflectValue := reflect.ValueOf(v)
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
stmt.AddVar(writer, nil)
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []byte:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []interface{}:
if len(v) > 0 {
writer.WriteByte('(')
stmt.AddVar(writer, v...)
writer.WriteByte(')')
} else {
writer.WriteString("(NULL)")
}
case *DB:
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
var (
vars = subdb.Statement.Vars
sql = v.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
writer.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
default:
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
writer.WriteString("(NULL)")
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
} else {
writer.WriteByte('(')
for i := 0; i < rv.Len(); i++ {
if i > 0 {
writer.WriteByte(',')
}
stmt.AddVar(writer, rv.Index(i).Interface())
}
writer.WriteByte(')')
}
default:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
}
}
}
}
// AddClause add clause
func (stmt *Statement) AddClause(v clause.Interface) {
if optimizer, ok := v.(StatementModifier); ok {
optimizer.ModifyStatement(stmt)
} else {
name := v.Name()
c := stmt.Clauses[name]
c.Name = name
v.MergeClause(&c)
stmt.Clauses[name] = c
}
}
// AddClauseIfNotExists add clause if not exists
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
stmt.AddClause(v)
}
}
// BuildCondition build condition
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
if s, ok := query.(string); ok {
// a purely numeric string is treated as a primary key further below; anything else is parsed as a condition here
if _, err := strconv.Atoi(s); err != nil {
if s == "" && len(args) == 0 {
return nil
}
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) > 0 && strings.Contains(s, "@") {
// looks like a named query
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
}
if strings.Contains(strings.TrimSpace(s), " ") {
// looks like a where condition
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
}
if len(args) == 1 {
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
}
}
}
conds := make([]clause.Expression, 0, 4)
args = append([]interface{}{query}, args...)
for idx, arg := range args {
if arg == nil {
continue
}
if valuer, ok := arg.(driver.Valuer); ok {
arg, _ = valuer.Value()
}
switch v := arg.(type) {
case clause.Expression:
conds = append(conds, v)
case *DB:
v.executeScopes()
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
if where, ok := cs.Expression.(clause.Where); ok {
if len(where.Exprs) == 1 {
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
where.Exprs[0] = clause.AndConditions(orConds)
}
}
conds = append(conds, clause.And(where.Exprs...))
} else if cs.Expression != nil {
conds = append(conds, cs.Expression)
}
}
case map[interface{}]interface{}:
for i, j := range v {
conds = append(conds, clause.Eq{Column: i, Value: j})
}
case map[string]string:
keys := make([]string, 0, len(v))
for i := range v {
keys = append(keys, i)
}
sort.Strings(keys)
for _, key := range keys {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
}
case map[string]interface{}:
keys := make([]string, 0, len(v))
for i := range v {
keys = append(keys, i)
}
sort.Strings(keys)
for _, key := range keys {
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
if _, ok := v[key].(driver.Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else if _, ok := v[key].(Valuer); ok {
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
} else {
// optimize reflect value length
valueLen := reflectValue.Len()
values := make([]interface{}, valueLen)
for i := 0; i < valueLen; i++ {
values[i] = reflectValue.Index(i).Interface()
}
conds = append(conds, clause.IN{Column: key, Values: values})
}
default:
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
}
}
default:
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
for reflectValue.Kind() == reflect.Ptr {
reflectValue = reflectValue.Elem()
}
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
selectedColumns := map[string]bool{}
if idx == 0 {
for _, v := range args[1:] {
if vs, ok := v.(string); ok {
selectedColumns[vs] = true
}
}
}
restricted := len(selectedColumns) != 0
switch reflectValue.Kind() {
case reflect.Struct:
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
}
case reflect.Slice, reflect.Array:
for i := 0; i < reflectValue.Len(); i++ {
for _, field := range s.Fields {
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
if selected || (!restricted && field.Readable) {
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
if field.DBName != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
} else if field.DataType != "" {
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
}
}
}
}
}
}
if restricted {
break
}
} else if !reflectValue.IsValid() {
stmt.AddError(ErrInvalidData)
} else if len(conds) == 0 {
if len(args) == 1 {
switch reflectValue.Kind() {
case reflect.Slice, reflect.Array:
// optimize reflect value length
valueLen := reflectValue.Len()
values := make([]interface{}, valueLen)
for i := 0; i < valueLen; i++ {
values[i] = reflectValue.Index(i).Interface()
}
if len(values) > 0 {
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
return []clause.Expression{clause.And(conds...)}
}
return nil
}
}
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
}
}
}
if len(conds) > 0 {
return []clause.Expression{clause.And(conds...)}
}
return nil
}
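`BuildCondition` is what gives `Where`, `First`, and `Find` their flexible argument forms: a SQL string with `?` placeholders, a named expression, a map, a struct, or bare primary-key values. A hedged sketch of those shapes; `db` is assumed to be an opened `*gorm.DB` and `User` is an illustrative model:

```go
package main

import "gorm.io/gorm"

type User struct {
	ID   uint
	Name string
	Age  int
}

// conditionExamples shows the argument shapes distinguished above.
func conditionExamples(db *gorm.DB) {
	var users []User

	// String containing "?": used as a plain where expression.
	db.Where("name = ?", "jinzhu").Find(&users)

	// String containing "@": used as a named expression.
	db.Where("name = @name", map[string]interface{}{"name": "jinzhu"}).Find(&users)

	// Map: one equality condition per key (keys are sorted for stable SQL).
	db.Where(map[string]interface{}{"name": "jinzhu", "age": 20}).Find(&users)

	// Struct: non-zero readable fields become equality conditions.
	db.Where(User{Name: "jinzhu"}).Find(&users)

	// Bare values: matched against the primary key via an IN clause.
	db.Find(&users, 1, 2, 3)
}

func main() {}
```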
// Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
if firstClauseWritten {
stmt.WriteByte(' ')
}
firstClauseWritten = true
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
b(c, stmt)
} else {
c.Build(stmt)
}
}
}
}
func (stmt *Statement) Parse(value interface{}) (err error) {
return stmt.ParseWithSpecialTableName(value, "")
}
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
stmt.Table = tables[1]
return
}
stmt.Table = stmt.Schema.Table
}
return err
}
func (stmt *Statement) clone() *Statement {
newStmt := &Statement{
TableExpr: stmt.TableExpr,
Table: stmt.Table,
Model: stmt.Model,
Unscoped: stmt.Unscoped,
Dest: stmt.Dest,
ReflectValue: stmt.ReflectValue,
Clauses: map[string]clause.Clause{},
Distinct: stmt.Distinct,
Selects: stmt.Selects,
Omits: stmt.Omits,
Preloads: map[string][]interface{}{},
ConnPool: stmt.ConnPool,
Schema: stmt.Schema,
Context: stmt.Context,
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
SkipHooks: stmt.SkipHooks,
}
if stmt.SQL.Len() > 0 {
newStmt.SQL.WriteString(stmt.SQL.String())
newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
}
for k, c := range stmt.Clauses {
newStmt.Clauses[k] = c
}
for k, p := range stmt.Preloads {
newStmt.Preloads[k] = p
}
if len(stmt.Joins) > 0 {
newStmt.Joins = make([]join, len(stmt.Joins))
copy(newStmt.Joins, stmt.Joins)
}
if len(stmt.scopes) > 0 {
newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
copy(newStmt.scopes, stmt.scopes)
}
stmt.Settings.Range(func(k, v interface{}) bool {
newStmt.Settings.Store(k, v)
return true
})
return newStmt
}
// SetColumn set column's value
//
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
for _, m := range v {
m[name] = value
}
} else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
if stmt.ReflectValue != destValue {
if !destValue.CanAddr() {
destValueCanAddr := reflect.New(destValue.Type())
destValueCanAddr.Elem().Set(destValue)
stmt.Dest = destValueCanAddr.Interface()
destValue = destValueCanAddr.Elem()
}
switch destValue.Kind() {
case reflect.Struct:
stmt.AddError(field.Set(stmt.Context, destValue, value))
default:
stmt.AddError(ErrInvalidData)
}
}
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
}
} else {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
}
case reflect.Struct:
if !stmt.ReflectValue.CanAddr() {
stmt.AddError(ErrInvalidValue)
return
}
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
}
} else {
stmt.AddError(ErrInvalidField)
}
} else {
stmt.AddError(ErrInvalidField)
}
}
// Changed reports whether the model's selected fields changed when updating
func (stmt *Statement) Changed(fields ...string) bool {
modelValue := stmt.ReflectValue
switch modelValue.Kind() {
case reflect.Slice, reflect.Array:
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
}
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if mv, mok := stmt.Dest.(map[string]interface{}); mok {
if fv, ok := mv[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := mv[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue)
}
} else {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}
return false
}
if len(fields) == 0 {
for _, field := range stmt.Schema.FieldsByDBName {
if changed(field) {
return true
}
}
} else {
for _, name := range fields {
if field := stmt.Schema.LookUpField(name); field != nil {
if changed(field) {
return true
}
}
}
}
return false
}
var matchName = func() func(tableColumn string) (table, column string) {
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
return func(tableColumn string) (table, column string) {
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
table = matches[1]
star := matches[2]
columnName := matches[3]
if star != "" {
return table, star
}
return table, columnName
}
return "", ""
}
}()
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
processColumn := func(column string, result bool) {
if stmt.Schema == nil {
results[column] = result
} else if column == "*" {
notRestricted = result
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = result
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = result
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
if col == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
}
} else {
results[column] = result
}
}
// select columns
for _, column := range stmt.Selects {
processColumn(column, true)
}
// omit columns
for _, column := range stmt.Omits {
processColumn(column, false)
}
if stmt.Schema != nil {
for _, field := range stmt.Schema.FieldsByName {
name := field.DBName
if name == "" {
name = field.Name
}
if requireCreate && !field.Creatable {
results[name] = false
} else if requireUpdate && !field.Updatable {
results[name] = false
}
}
}
return results, !notRestricted && len(stmt.Selects) > 0
}

160
vendor/gorm.io/gorm/utils/utils.go generated vendored Normal file
View File

@@ -0,0 +1,160 @@
package utils
import (
"database/sql/driver"
"fmt"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"unicode"
)
var gormSourceDir string
func init() {
_, file, _, _ := runtime.Caller(0)
// portable way to determine the gorm source directory across operating systems
gormSourceDir = sourceDir(file)
}
func sourceDir(file string) string {
dir := filepath.Dir(file)
dir = filepath.Dir(dir)
s := filepath.Dir(dir)
if filepath.Base(s) != "gorm.io" {
s = dir
}
return filepath.ToSlash(s) + "/"
}
// FileWithLineNum returns the file name and line number of the first caller outside gorm's source tree
func FileWithLineNum() string {
// the first couple of callers are usually gorm internals, so start i at 2
for i := 2; i < 15; i++ {
_, file, line, ok := runtime.Caller(i)
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) &&
!strings.HasSuffix(file, ".gen.go") {
return file + ":" + strconv.FormatInt(int64(line), 10)
}
}
return ""
}
func IsValidDBNameChar(c rune) bool {
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
}
// CheckTruth reports whether any of the given strings is non-empty and not equal to "false" (case-insensitive)
func CheckTruth(vals ...string) bool {
for _, val := range vals {
if val != "" && !strings.EqualFold(val, "false") {
return true
}
}
return false
}
func ToStringKey(values ...interface{}) string {
results := make([]string, len(values))
for idx, value := range values {
if valuer, ok := value.(driver.Valuer); ok {
value, _ = valuer.Value()
}
switch v := value.(type) {
case string:
results[idx] = v
case []byte:
results[idx] = string(v)
case uint:
results[idx] = strconv.FormatUint(uint64(v), 10)
default:
results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface())
}
}
return strings.Join(results, "_")
}
func Contains(elems []string, elem string) bool {
for _, e := range elems {
if elem == e {
return true
}
}
return false
}
func AssertEqual(x, y interface{}) bool {
if reflect.DeepEqual(x, y) {
return true
}
if x == nil || y == nil {
return false
}
xval := reflect.ValueOf(x)
yval := reflect.ValueOf(y)
if xval.Kind() == reflect.Ptr && xval.IsNil() ||
yval.Kind() == reflect.Ptr && yval.IsNil() {
return false
}
if valuer, ok := x.(driver.Valuer); ok {
x, _ = valuer.Value()
}
if valuer, ok := y.(driver.Valuer); ok {
y, _ = valuer.Value()
}
return reflect.DeepEqual(x, y)
}
func ToString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case int:
return strconv.FormatInt(int64(v), 10)
case int8:
return strconv.FormatInt(int64(v), 10)
case int16:
return strconv.FormatInt(int64(v), 10)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(v, 10)
case uint:
return strconv.FormatUint(uint64(v), 10)
case uint8:
return strconv.FormatUint(uint64(v), 10)
case uint16:
return strconv.FormatUint(uint64(v), 10)
case uint32:
return strconv.FormatUint(uint64(v), 10)
case uint64:
return strconv.FormatUint(v, 10)
}
return ""
}
const nestedRelationSplit = "__"
// NestedRelationName nested relationships like `Manager__Company`
func NestedRelationName(prefix, name string) string {
return prefix + nestedRelationSplit + name
}
// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}`
func SplitNestedRelationName(name string) []string {
return strings.Split(name, nestedRelationSplit)
}
// JoinNestedRelationNames nested relationships like `Manager__Company`
func JoinNestedRelationNames(relationNames []string) string {
return strings.Join(relationNames, nestedRelationSplit)
}
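The nested-relation helpers simply join and split on the `__` separator. A tiny runnable sketch:

```go
package main

import (
	"fmt"

	"gorm.io/gorm/utils"
)

func main() {
	name := utils.NestedRelationName("Manager", "Company")
	fmt.Println(name)                                // Manager__Company
	fmt.Println(utils.SplitNestedRelationName(name)) // [Manager Company]
	fmt.Println(utils.JoinNestedRelationNames([]string{"Manager", "Company", "Owner"}))
}
```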