1. 实现配置文件解析
2. 实现数据库连接
This commit is contained in:
7
vendor/gorm.io/gorm/.gitignore
generated
vendored
Normal file
7
vendor/gorm.io/gorm/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
TODO*
|
||||
documents
|
||||
coverage.txt
|
||||
_book
|
||||
.idea
|
||||
vendor
|
||||
.vscode
|
||||
20
vendor/gorm.io/gorm/.golangci.yml
generated
vendored
Normal file
20
vendor/gorm.io/gorm/.golangci.yml
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
linters:
|
||||
enable:
|
||||
- cyclop
|
||||
- exportloopref
|
||||
- gocritic
|
||||
- gosec
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
- unconvert
|
||||
- unparam
|
||||
- goimports
|
||||
- whitespace
|
||||
|
||||
linters-settings:
|
||||
whitespace:
|
||||
multi-func: true
|
||||
goimports:
|
||||
local-prefixes: gorm.io/gorm
|
||||
|
||||
21
vendor/gorm.io/gorm/LICENSE
generated
vendored
Normal file
21
vendor/gorm.io/gorm/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2013-NOW Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
44
vendor/gorm.io/gorm/README.md
generated
vendored
Normal file
44
vendor/gorm.io/gorm/README.md
generated
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
# GORM
|
||||
|
||||
The fantastic ORM library for Golang, aims to be developer friendly.
|
||||
|
||||
[](https://goreportcard.com/report/github.com/go-gorm/gorm)
|
||||
[](https://github.com/go-gorm/gorm/actions)
|
||||
[](https://opensource.org/licenses/MIT)
|
||||
[](https://pkg.go.dev/gorm.io/gorm?tab=doc)
|
||||
|
||||
## Overview
|
||||
|
||||
* Full-Featured ORM
|
||||
* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance)
|
||||
* Hooks (Before/After Create/Save/Update/Delete/Find)
|
||||
* Eager loading with `Preload`, `Joins`
|
||||
* Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point
|
||||
* Context, Prepared Statement Mode, DryRun Mode
|
||||
* Batch Insert, FindInBatches, Find To Map
|
||||
* SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr
|
||||
* Composite Primary Key
|
||||
* Auto Migrations
|
||||
* Logger
|
||||
* Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus…
|
||||
* Every feature comes with tests
|
||||
* Developer Friendly
|
||||
|
||||
## Getting Started
|
||||
|
||||
* GORM Guides [https://gorm.io](https://gorm.io)
|
||||
* Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html)
|
||||
|
||||
## Contributing
|
||||
|
||||
[You can help to deliver a better GORM, check out things you can do](https://gorm.io/contribute.html)
|
||||
|
||||
## Contributors
|
||||
|
||||
[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework!
|
||||
|
||||
## License
|
||||
|
||||
© Jinzhu, 2013~time.Now
|
||||
|
||||
Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE)
|
||||
579
vendor/gorm.io/gorm/association.go
generated
vendored
Normal file
579
vendor/gorm.io/gorm/association.go
generated
vendored
Normal file
@@ -0,0 +1,579 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Association Mode contains some helper methods to handle relationship things easily.
|
||||
type Association struct {
|
||||
DB *DB
|
||||
Relationship *schema.Relationship
|
||||
Unscope bool
|
||||
Error error
|
||||
}
|
||||
|
||||
func (db *DB) Association(column string) *Association {
|
||||
association := &Association{DB: db}
|
||||
table := db.Statement.Table
|
||||
|
||||
if err := db.Statement.Parse(db.Statement.Model); err == nil {
|
||||
db.Statement.Table = table
|
||||
association.Relationship = db.Statement.Schema.Relationships.Relations[column]
|
||||
|
||||
if association.Relationship == nil {
|
||||
association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column)
|
||||
}
|
||||
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
|
||||
}
|
||||
} else {
|
||||
association.Error = err
|
||||
}
|
||||
|
||||
return association
|
||||
}
|
||||
|
||||
func (association *Association) Unscoped() *Association {
|
||||
return &Association{
|
||||
DB: association.DB,
|
||||
Relationship: association.Relationship,
|
||||
Error: association.Error,
|
||||
Unscope: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (association *Association) Find(out interface{}, conds ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Find(out, conds...).Error
|
||||
}
|
||||
return association.Error
|
||||
}
|
||||
|
||||
func (association *Association) Append(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
switch association.Relationship.Type {
|
||||
case schema.HasOne, schema.BelongsTo:
|
||||
if len(values) > 0 {
|
||||
association.Error = association.Replace(values...)
|
||||
}
|
||||
default:
|
||||
association.saveAssociation( /*clear*/ false, values...)
|
||||
}
|
||||
}
|
||||
|
||||
return association.Error
|
||||
}
|
||||
|
||||
func (association *Association) Replace(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
reflectValue := association.DB.Statement.ReflectValue
|
||||
rel := association.Relationship
|
||||
|
||||
var oldBelongsToExpr clause.Expression
|
||||
// we have to record the old BelongsTo value
|
||||
if association.Unscope && rel.Type == schema.BelongsTo {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
oldBelongsToExpr = clause.IN{Column: column, Values: values}
|
||||
}
|
||||
}
|
||||
|
||||
// save associations
|
||||
if association.saveAssociation( /*clear*/ true, values...); association.Error != nil {
|
||||
return association.Error
|
||||
}
|
||||
|
||||
// set old associations's foreign key to null
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
if len(values) == 0 {
|
||||
updateMap := map[string]interface{}{}
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
case reflect.Struct:
|
||||
association.Error = rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface())
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
updateMap[ref.ForeignKey.DBName] = nil
|
||||
}
|
||||
|
||||
association.Error = association.DB.UpdateColumns(updateMap).Error
|
||||
}
|
||||
if association.Unscope && oldBelongsToExpr != nil {
|
||||
association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
var (
|
||||
primaryFields []*schema.Field
|
||||
foreignKeys []string
|
||||
updateMap = map[string]interface{}{}
|
||||
relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel})
|
||||
modelValue = reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
|
||||
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
|
||||
tx.Not(clause.IN{Column: column, Values: values})
|
||||
}
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||
updateMap[ref.ForeignKey.DBName] = nil
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
|
||||
if association.Unscope {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error
|
||||
} else {
|
||||
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
|
||||
}
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.PrimaryValue == "" {
|
||||
if ref.OwnPrimaryKey {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
|
||||
} else {
|
||||
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
||||
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
|
||||
}
|
||||
} else {
|
||||
tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
|
||||
tx.Where(clause.IN{Column: column, Values: values})
|
||||
} else {
|
||||
return ErrPrimaryKeyRequired
|
||||
}
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
|
||||
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
|
||||
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
|
||||
}
|
||||
|
||||
association.Error = tx.Delete(modelValue).Error
|
||||
}
|
||||
}
|
||||
return association.Error
|
||||
}
|
||||
|
||||
func (association *Association) Delete(values ...interface{}) error {
|
||||
if association.Error == nil {
|
||||
var (
|
||||
reflectValue = association.DB.Statement.ReflectValue
|
||||
rel = association.Relationship
|
||||
primaryFields []*schema.Field
|
||||
foreignKeys []string
|
||||
updateAttrs = map[string]interface{}{}
|
||||
conds []clause.Expression
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.PrimaryValue == "" {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
foreignKeys = append(foreignKeys, ref.ForeignKey.DBName)
|
||||
updateAttrs[ref.ForeignKey.DBName] = nil
|
||||
} else {
|
||||
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
switch rel.Type {
|
||||
case schema.BelongsTo:
|
||||
associationDB := association.DB.Session(&Session{})
|
||||
tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface())
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 {
|
||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||
} else {
|
||||
return ErrPrimaryKeyRequired
|
||||
}
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields)
|
||||
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
if association.Unscope {
|
||||
var foreignFields []*schema.Field
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 {
|
||||
column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs)
|
||||
association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error
|
||||
}
|
||||
}
|
||||
case schema.HasOne, schema.HasMany:
|
||||
model := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := association.DB.Model(model)
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 {
|
||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||
} else {
|
||||
return ErrPrimaryKeyRequired
|
||||
}
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
|
||||
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
if association.Unscope {
|
||||
association.Error = tx.Clauses(conds...).Delete(model).Error
|
||||
} else {
|
||||
association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
primaryFields, relPrimaryFields []*schema.Field
|
||||
joinPrimaryKeys, joinRelPrimaryKeys []string
|
||||
joinValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.PrimaryValue == "" {
|
||||
if ref.OwnPrimaryKey {
|
||||
primaryFields = append(primaryFields, ref.PrimaryKey)
|
||||
joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName)
|
||||
} else {
|
||||
relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey)
|
||||
joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName)
|
||||
}
|
||||
} else {
|
||||
conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
}
|
||||
}
|
||||
|
||||
_, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields)
|
||||
if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 {
|
||||
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})
|
||||
} else {
|
||||
return ErrPrimaryKeyRequired
|
||||
}
|
||||
|
||||
_, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields)
|
||||
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
|
||||
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})
|
||||
|
||||
association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error
|
||||
}
|
||||
|
||||
if association.Error == nil {
|
||||
// clean up deleted values's foreign key
|
||||
relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields)
|
||||
|
||||
cleanUpDeletedRelations := func(data reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero {
|
||||
fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data))
|
||||
primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields))
|
||||
|
||||
switch fieldValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
validFieldValues := reflect.Zero(rel.Field.IndirectFieldType)
|
||||
for i := 0; i < fieldValue.Len(); i++ {
|
||||
for idx, field := range rel.FieldSchema.PrimaryFields {
|
||||
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i))
|
||||
}
|
||||
|
||||
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok {
|
||||
validFieldValues = reflect.Append(validFieldValues, fieldValue.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface())
|
||||
case reflect.Struct:
|
||||
for idx, field := range rel.FieldSchema.PrimaryFields {
|
||||
primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue)
|
||||
}
|
||||
|
||||
if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok {
|
||||
if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if rel.JoinTable == nil {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey || ref.PrimaryValue != "" {
|
||||
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
} else {
|
||||
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i)))
|
||||
}
|
||||
case reflect.Struct:
|
||||
cleanUpDeletedRelations(reflectValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return association.Error
|
||||
}
|
||||
|
||||
func (association *Association) Clear() error {
|
||||
return association.Replace()
|
||||
}
|
||||
|
||||
func (association *Association) Count() (count int64) {
|
||||
if association.Error == nil {
|
||||
association.Error = association.buildCondition().Count(&count).Error
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type assignBack struct {
|
||||
Source reflect.Value
|
||||
Index int
|
||||
Dest reflect.Value
|
||||
}
|
||||
|
||||
func (association *Association) saveAssociation(clear bool, values ...interface{}) {
|
||||
var (
|
||||
reflectValue = association.DB.Statement.ReflectValue
|
||||
assignBacks []assignBack // assign association values back to arguments after save
|
||||
)
|
||||
|
||||
appendToRelations := func(source, rv reflect.Value, clear bool) {
|
||||
switch association.Relationship.Type {
|
||||
case schema.HasOne, schema.BelongsTo:
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() > 0 {
|
||||
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface())
|
||||
|
||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)})
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface())
|
||||
|
||||
if association.Relationship.Field.FieldType.Kind() == reflect.Struct {
|
||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv})
|
||||
}
|
||||
}
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
elemType := association.Relationship.Field.IndirectFieldType.Elem()
|
||||
oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source))
|
||||
var fieldValue reflect.Value
|
||||
if clear {
|
||||
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap())
|
||||
} else {
|
||||
fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap())
|
||||
reflect.Copy(fieldValue, oldFieldValue)
|
||||
}
|
||||
|
||||
appendToFieldValues := func(ev reflect.Value) {
|
||||
if ev.Type().AssignableTo(elemType) {
|
||||
fieldValue = reflect.Append(fieldValue, ev)
|
||||
} else if ev.Type().Elem().AssignableTo(elemType) {
|
||||
fieldValue = reflect.Append(fieldValue, ev.Elem())
|
||||
} else {
|
||||
association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name)
|
||||
}
|
||||
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()})
|
||||
}
|
||||
}
|
||||
|
||||
switch rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
appendToFieldValues(reflect.Indirect(rv.Index(i)).Addr())
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToFieldValues(rv.Addr())
|
||||
}
|
||||
|
||||
if association.Error == nil {
|
||||
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
selectedSaveColumns := []string{association.Relationship.Name}
|
||||
omitColumns := []string{}
|
||||
selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false)
|
||||
for name, ok := range selectColumns {
|
||||
columnName := ""
|
||||
if strings.HasPrefix(name, association.Relationship.Name) {
|
||||
if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" {
|
||||
columnName = name
|
||||
}
|
||||
} else if strings.HasPrefix(name, clause.Associations) {
|
||||
columnName = name
|
||||
}
|
||||
|
||||
if columnName != "" {
|
||||
if ok {
|
||||
selectedSaveColumns = append(selectedSaveColumns, columnName)
|
||||
} else {
|
||||
omitColumns = append(omitColumns, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name)
|
||||
}
|
||||
}
|
||||
|
||||
associationDB := association.DB.Session(&Session{}).Model(nil)
|
||||
if !association.DB.FullSaveAssociations {
|
||||
associationDB.Select(selectedSaveColumns)
|
||||
}
|
||||
if len(omitColumns) > 0 {
|
||||
associationDB.Omit(omitColumns...)
|
||||
}
|
||||
associationDB = associationDB.Session(&Session{})
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if len(values) != reflectValue.Len() {
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil {
|
||||
association.Error = err
|
||||
break
|
||||
}
|
||||
|
||||
if association.Relationship.JoinTable == nil {
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
||||
if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil {
|
||||
association.Error = err
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
association.Error = ErrInvalidValueOfLength
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)
|
||||
|
||||
// TODO support save slice data, sql with case?
|
||||
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error
|
||||
}
|
||||
case reflect.Struct:
|
||||
// clear old data
|
||||
if clear && len(values) == 0 {
|
||||
association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface())
|
||||
|
||||
if association.Relationship.JoinTable == nil && association.Error == nil {
|
||||
for _, ref := range association.Relationship.References {
|
||||
if !ref.OwnPrimaryKey && ref.PrimaryValue == "" {
|
||||
association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for idx, value := range values {
|
||||
rv := reflect.Indirect(reflect.ValueOf(value))
|
||||
appendToRelations(reflectValue, rv, clear && idx == 0)
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error
|
||||
}
|
||||
}
|
||||
|
||||
for _, assignBack := range assignBacks {
|
||||
fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source))
|
||||
if assignBack.Index > 0 {
|
||||
reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1))
|
||||
} else {
|
||||
reflect.Indirect(assignBack.Dest).Set(fieldValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (association *Association) buildCondition() *DB {
|
||||
var (
|
||||
queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue)
|
||||
modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface()
|
||||
tx = association.DB.Model(modelValue)
|
||||
)
|
||||
|
||||
if association.Relationship.JoinTable != nil {
|
||||
if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 {
|
||||
joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}}
|
||||
for _, queryClause := range association.Relationship.JoinTable.QueryClauses {
|
||||
joinStmt.AddClause(queryClause)
|
||||
}
|
||||
joinStmt.Build("WHERE")
|
||||
if len(joinStmt.SQL.String()) > 0 {
|
||||
tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars})
|
||||
}
|
||||
}
|
||||
|
||||
tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{
|
||||
Table: clause.Table{Name: association.Relationship.JoinTable.Table},
|
||||
ON: clause.Where{Exprs: queryConds},
|
||||
}}})
|
||||
} else {
|
||||
tx.Clauses(clause.Where{Exprs: queryConds})
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
341
vendor/gorm.io/gorm/callbacks.go
generated
vendored
Normal file
341
vendor/gorm.io/gorm/callbacks.go
generated
vendored
Normal file
@@ -0,0 +1,341 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func initializeCallbacks(db *DB) *callbacks {
|
||||
return &callbacks{
|
||||
processors: map[string]*processor{
|
||||
"create": {db: db},
|
||||
"query": {db: db},
|
||||
"update": {db: db},
|
||||
"delete": {db: db},
|
||||
"row": {db: db},
|
||||
"raw": {db: db},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// callbacks gorm callbacks manager
|
||||
type callbacks struct {
|
||||
processors map[string]*processor
|
||||
}
|
||||
|
||||
type processor struct {
|
||||
db *DB
|
||||
Clauses []string
|
||||
fns []func(*DB)
|
||||
callbacks []*callback
|
||||
}
|
||||
|
||||
type callback struct {
|
||||
name string
|
||||
before string
|
||||
after string
|
||||
remove bool
|
||||
replace bool
|
||||
match func(*DB) bool
|
||||
handler func(*DB)
|
||||
processor *processor
|
||||
}
|
||||
|
||||
func (cs *callbacks) Create() *processor {
|
||||
return cs.processors["create"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Query() *processor {
|
||||
return cs.processors["query"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Update() *processor {
|
||||
return cs.processors["update"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Delete() *processor {
|
||||
return cs.processors["delete"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Row() *processor {
|
||||
return cs.processors["row"]
|
||||
}
|
||||
|
||||
func (cs *callbacks) Raw() *processor {
|
||||
return cs.processors["raw"]
|
||||
}
|
||||
|
||||
func (p *processor) Execute(db *DB) *DB {
|
||||
// call scopes
|
||||
for len(db.Statement.scopes) > 0 {
|
||||
db = db.executeScopes()
|
||||
}
|
||||
|
||||
var (
|
||||
curTime = time.Now()
|
||||
stmt = db.Statement
|
||||
resetBuildClauses bool
|
||||
)
|
||||
|
||||
if len(stmt.BuildClauses) == 0 {
|
||||
stmt.BuildClauses = p.Clauses
|
||||
resetBuildClauses = true
|
||||
}
|
||||
|
||||
if optimizer, ok := db.Statement.Dest.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(stmt)
|
||||
}
|
||||
|
||||
// assign model values
|
||||
if stmt.Model == nil {
|
||||
stmt.Model = stmt.Dest
|
||||
} else if stmt.Dest == nil {
|
||||
stmt.Dest = stmt.Model
|
||||
}
|
||||
|
||||
// parse model values
|
||||
if stmt.Model != nil {
|
||||
if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) {
|
||||
if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil {
|
||||
db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err))
|
||||
} else {
|
||||
db.AddError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assign stmt.ReflectValue
|
||||
if stmt.Dest != nil {
|
||||
stmt.ReflectValue = reflect.ValueOf(stmt.Dest)
|
||||
for stmt.ReflectValue.Kind() == reflect.Ptr {
|
||||
if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() {
|
||||
stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem()))
|
||||
}
|
||||
|
||||
stmt.ReflectValue = stmt.ReflectValue.Elem()
|
||||
}
|
||||
if !stmt.ReflectValue.IsValid() {
|
||||
db.AddError(ErrInvalidValue)
|
||||
}
|
||||
}
|
||||
|
||||
for _, f := range p.fns {
|
||||
f(db)
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
db.Logger.Trace(stmt.Context, curTime, func() (string, int64) {
|
||||
sql, vars := stmt.SQL.String(), stmt.Vars
|
||||
if filter, ok := db.Logger.(ParamsFilter); ok {
|
||||
sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...)
|
||||
}
|
||||
return db.Dialector.Explain(sql, vars...), db.RowsAffected
|
||||
}, db.Error)
|
||||
}
|
||||
|
||||
if !stmt.DB.DryRun {
|
||||
stmt.SQL.Reset()
|
||||
stmt.Vars = nil
|
||||
}
|
||||
|
||||
if resetBuildClauses {
|
||||
stmt.BuildClauses = nil
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (p *processor) Get(name string) func(*DB) {
|
||||
for i := len(p.callbacks) - 1; i >= 0; i-- {
|
||||
if v := p.callbacks[i]; v.name == name && !v.remove {
|
||||
return v.handler
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *processor) Before(name string) *callback {
|
||||
return &callback{before: name, processor: p}
|
||||
}
|
||||
|
||||
func (p *processor) After(name string) *callback {
|
||||
return &callback{after: name, processor: p}
|
||||
}
|
||||
|
||||
func (p *processor) Match(fc func(*DB) bool) *callback {
|
||||
return &callback{match: fc, processor: p}
|
||||
}
|
||||
|
||||
func (p *processor) Register(name string, fn func(*DB)) error {
|
||||
return (&callback{processor: p}).Register(name, fn)
|
||||
}
|
||||
|
||||
func (p *processor) Remove(name string) error {
|
||||
return (&callback{processor: p}).Remove(name)
|
||||
}
|
||||
|
||||
func (p *processor) Replace(name string, fn func(*DB)) error {
|
||||
return (&callback{processor: p}).Replace(name, fn)
|
||||
}
|
||||
|
||||
func (p *processor) compile() (err error) {
|
||||
var callbacks []*callback
|
||||
for _, callback := range p.callbacks {
|
||||
if callback.match == nil || callback.match(p.db) {
|
||||
callbacks = append(callbacks, callback)
|
||||
}
|
||||
}
|
||||
p.callbacks = callbacks
|
||||
|
||||
if p.fns, err = sortCallbacks(p.callbacks); err != nil {
|
||||
p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *callback) Before(name string) *callback {
|
||||
c.before = name
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *callback) After(name string) *callback {
|
||||
c.after = name
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *callback) Register(name string, fn func(*DB)) error {
|
||||
c.name = name
|
||||
c.handler = fn
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
return c.processor.compile()
|
||||
}
|
||||
|
||||
func (c *callback) Remove(name string) error {
|
||||
c.processor.db.Logger.Warn(context.Background(), "removing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.remove = true
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
return c.processor.compile()
|
||||
}
|
||||
|
||||
func (c *callback) Replace(name string, fn func(*DB)) error {
|
||||
c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum())
|
||||
c.name = name
|
||||
c.handler = fn
|
||||
c.replace = true
|
||||
c.processor.callbacks = append(c.processor.callbacks, c)
|
||||
return c.processor.compile()
|
||||
}
|
||||
|
||||
// getRIndex get right index from string slice
|
||||
func getRIndex(strs []string, str string) int {
|
||||
for i := len(strs) - 1; i >= 0; i-- {
|
||||
if strs[i] == str {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func sortCallbacks(cs []*callback) (fns []func(*DB), err error) {
|
||||
var (
|
||||
names, sorted []string
|
||||
sortCallback func(*callback) error
|
||||
)
|
||||
sort.SliceStable(cs, func(i, j int) bool {
|
||||
if cs[j].before == "*" && cs[i].before != "*" {
|
||||
return true
|
||||
}
|
||||
if cs[j].after == "*" && cs[i].after != "*" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
|
||||
for _, c := range cs {
|
||||
// show warning message the callback name already exists
|
||||
if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove {
|
||||
c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum())
|
||||
}
|
||||
names = append(names, c.name)
|
||||
}
|
||||
|
||||
sortCallback = func(c *callback) error {
|
||||
if c.before != "" { // if defined before callback
|
||||
if c.before == "*" && len(sorted) > 0 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
sorted = append([]string{c.name}, sorted...)
|
||||
}
|
||||
} else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
// if before callback already sorted, append current callback just after it
|
||||
sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...)
|
||||
} else if curIdx > sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.before); idx != -1 {
|
||||
// if before callback exists
|
||||
cs[idx].after = c.name
|
||||
}
|
||||
}
|
||||
|
||||
if c.after != "" { // if defined after callback
|
||||
if c.after == "*" && len(sorted) > 0 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
sorted = append(sorted, c.name)
|
||||
}
|
||||
} else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 {
|
||||
if curIdx := getRIndex(sorted, c.name); curIdx == -1 {
|
||||
// if after callback sorted, append current callback to last
|
||||
sorted = append(sorted, c.name)
|
||||
} else if curIdx < sortedIdx {
|
||||
return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after)
|
||||
}
|
||||
} else if idx := getRIndex(names, c.after); idx != -1 {
|
||||
// if after callback exists but haven't sorted
|
||||
// set after callback's before callback to current callback
|
||||
after := cs[idx]
|
||||
|
||||
if after.before == "" {
|
||||
after.before = c.name
|
||||
}
|
||||
|
||||
if err := sortCallback(after); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := sortCallback(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if current callback haven't been sorted, append it to last
|
||||
if getRIndex(sorted, c.name) == -1 {
|
||||
sorted = append(sorted, c.name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, c := range cs {
|
||||
if err = sortCallback(c); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range sorted {
|
||||
if idx := getRIndex(names, name); !cs[idx].remove {
|
||||
fns = append(fns, cs[idx].handler)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
453
vendor/gorm.io/gorm/callbacks/associations.go
generated
vendored
Normal file
453
vendor/gorm.io/gorm/callbacks/associations.go
generated
vendored
Normal file
@@ -0,0 +1,453 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
|
||||
|
||||
// Save Belongs To associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
setupReferences := func(obj reflect.Value, elem reflect.Value) {
|
||||
for _, ref := range rel.References {
|
||||
if !ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
dest[ref.ForeignKey.DBName] = pv
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
dest[rel.Name] = elem.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
rValLen = db.Statement.ReflectValue.Len()
|
||||
objs = make([]reflect.Value, 0, rValLen)
|
||||
fieldType = rel.Field.FieldType
|
||||
isPtr = fieldType.Kind() == reflect.Ptr
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
|
||||
if !isPtr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
objs = append(objs, obj)
|
||||
elems = reflect.Append(elems, rv)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
|
||||
setupReferences(db.Statement.ReflectValue, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
|
||||
|
||||
// Save Has One associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasOne {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
fieldType = rel.Field.FieldType
|
||||
isPtr = fieldType.Kind() == reflect.Ptr
|
||||
)
|
||||
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
|
||||
rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
elems = reflect.Append(elems, rv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||
f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if f.Kind() != reflect.Ptr {
|
||||
f = f.Addr()
|
||||
}
|
||||
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
|
||||
}
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save Has Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.HasMany {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
|
||||
}
|
||||
}
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
if isPtr {
|
||||
elems = reflect.Append(elems, elem)
|
||||
} else {
|
||||
elems = reflect.Append(elems, elem.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
assignmentColumns := make([]string, 0, len(rel.References))
|
||||
for _, ref := range rel.References {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
|
||||
// Save Many2Many associations
|
||||
for _, rel := range db.Statement.Schema.Relationships.Many2Many {
|
||||
if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldType := rel.Field.IndirectFieldType.Elem()
|
||||
isPtr := fieldType.Kind() == reflect.Ptr
|
||||
if !isPtr {
|
||||
fieldType = reflect.PtrTo(fieldType)
|
||||
}
|
||||
elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
|
||||
joins := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.JoinTable.ModelType)), 0, 10)
|
||||
objs := []reflect.Value{}
|
||||
|
||||
appendToJoins := func(obj reflect.Value, elem reflect.Value) {
|
||||
joinValue := reflect.New(rel.JoinTable.ModelType)
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||
} else if ref.PrimaryValue != "" {
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
|
||||
} else {
|
||||
fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
|
||||
db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
|
||||
}
|
||||
}
|
||||
joins = reflect.Append(joins, joinValue)
|
||||
}
|
||||
|
||||
identityMap := map[string]bool{}
|
||||
appendToElems := func(v reflect.Value) {
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
|
||||
f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
|
||||
for i := 0; i < f.Len(); i++ {
|
||||
elem := f.Index(i)
|
||||
if !isPtr {
|
||||
elem = elem.Addr()
|
||||
}
|
||||
objs = append(objs, v)
|
||||
elems = reflect.Append(elems, elem)
|
||||
|
||||
relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
|
||||
for _, pf := range rel.FieldSchema.PrimaryFields {
|
||||
if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
|
||||
relPrimaryValues = append(relPrimaryValues, pfv)
|
||||
}
|
||||
}
|
||||
|
||||
cacheKey := utils.ToStringKey(relPrimaryValues...)
|
||||
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
|
||||
if cacheKey != "" { // has primary fields
|
||||
identityMap[cacheKey] = true
|
||||
}
|
||||
|
||||
distinctElems = reflect.Append(distinctElems, elem)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
obj := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(obj).Kind() == reflect.Struct {
|
||||
appendToElems(obj)
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
appendToElems(db.Statement.ReflectValue)
|
||||
}
|
||||
|
||||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
appendToJoins(objs[i], elems.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
if joins.Len() > 0 {
|
||||
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
}).Create(joins.Interface()).Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
|
||||
if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
|
||||
onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
|
||||
for _, dbName := range s.PrimaryFieldDBNames {
|
||||
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
|
||||
}
|
||||
|
||||
onConflict.UpdateAll = stmt.DB.FullSaveAssociations
|
||||
if !onConflict.UpdateAll {
|
||||
onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
|
||||
}
|
||||
} else {
|
||||
onConflict.DoNothing = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
// stop save association loop
|
||||
if checkAssociationsSaved(db, rValues) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
selects, omits []string
|
||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
|
||||
refName = rel.Name + "."
|
||||
values = rValues.Interface()
|
||||
)
|
||||
|
||||
for name, ok := range selectColumns {
|
||||
columnName := ""
|
||||
if strings.HasPrefix(name, refName) {
|
||||
columnName = strings.TrimPrefix(name, refName)
|
||||
}
|
||||
|
||||
if columnName != "" {
|
||||
if ok {
|
||||
selects = append(selects, columnName)
|
||||
} else {
|
||||
omits = append(omits, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
|
||||
FullSaveAssociations: db.FullSaveAssociations,
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
DisableNestedTransaction: true,
|
||||
})
|
||||
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if tx.Statement.FullSaveAssociations {
|
||||
tx = tx.Set("gorm:update_track_time", true)
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
} else if restricted && len(omits) == 0 {
|
||||
tx = tx.Omit(clause.Associations)
|
||||
}
|
||||
|
||||
if len(omits) > 0 {
|
||||
tx = tx.Omit(omits...)
|
||||
}
|
||||
|
||||
return db.AddError(tx.Create(values).Error)
|
||||
}
|
||||
|
||||
// check association values has been saved
|
||||
// if values kind is Struct, check it has been saved
|
||||
// if values kind is Slice/Array, check all items have been saved
|
||||
var visitMapStoreKey = "gorm:saved_association_map"
|
||||
|
||||
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
|
||||
if visit, ok := db.Get(visitMapStoreKey); ok {
|
||||
if v, ok := visit.(*visitMap); ok {
|
||||
if loadOrStoreVisitMap(v, values) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vistMap := make(visitMap)
|
||||
loadOrStoreVisitMap(&vistMap, values)
|
||||
db.Set(visitMapStoreKey, &vistMap)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
83
vendor/gorm.io/gorm/callbacks/callbacks.go
generated
vendored
Normal file
83
vendor/gorm.io/gorm/callbacks/callbacks.go
generated
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
|
||||
queryClauses = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
|
||||
updateClauses = []string{"UPDATE", "SET", "WHERE"}
|
||||
deleteClauses = []string{"DELETE", "FROM", "WHERE"}
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
LastInsertIDReversed bool
|
||||
CreateClauses []string
|
||||
QueryClauses []string
|
||||
UpdateClauses []string
|
||||
DeleteClauses []string
|
||||
}
|
||||
|
||||
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
|
||||
enableTransaction := func(db *gorm.DB) bool {
|
||||
return !db.SkipDefaultTransaction
|
||||
}
|
||||
|
||||
if len(config.CreateClauses) == 0 {
|
||||
config.CreateClauses = createClauses
|
||||
}
|
||||
if len(config.QueryClauses) == 0 {
|
||||
config.QueryClauses = queryClauses
|
||||
}
|
||||
if len(config.DeleteClauses) == 0 {
|
||||
config.DeleteClauses = deleteClauses
|
||||
}
|
||||
if len(config.UpdateClauses) == 0 {
|
||||
config.UpdateClauses = updateClauses
|
||||
}
|
||||
|
||||
createCallback := db.Callback().Create()
|
||||
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
createCallback.Register("gorm:before_create", BeforeCreate)
|
||||
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
|
||||
createCallback.Register("gorm:create", Create(config))
|
||||
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
|
||||
createCallback.Register("gorm:after_create", AfterCreate)
|
||||
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
createCallback.Clauses = config.CreateClauses
|
||||
|
||||
queryCallback := db.Callback().Query()
|
||||
queryCallback.Register("gorm:query", Query)
|
||||
queryCallback.Register("gorm:preload", Preload)
|
||||
queryCallback.Register("gorm:after_query", AfterQuery)
|
||||
queryCallback.Clauses = config.QueryClauses
|
||||
|
||||
deleteCallback := db.Callback().Delete()
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
deleteCallback.Register("gorm:before_delete", BeforeDelete)
|
||||
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
|
||||
deleteCallback.Register("gorm:delete", Delete(config))
|
||||
deleteCallback.Register("gorm:after_delete", AfterDelete)
|
||||
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
deleteCallback.Clauses = config.DeleteClauses
|
||||
|
||||
updateCallback := db.Callback().Update()
|
||||
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
|
||||
updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
|
||||
updateCallback.Register("gorm:before_update", BeforeUpdate)
|
||||
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
|
||||
updateCallback.Register("gorm:update", Update(config))
|
||||
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
|
||||
updateCallback.Register("gorm:after_update", AfterUpdate)
|
||||
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
|
||||
updateCallback.Clauses = config.UpdateClauses
|
||||
|
||||
rowCallback := db.Callback().Row()
|
||||
rowCallback.Register("gorm:row", RowQuery)
|
||||
rowCallback.Clauses = config.QueryClauses
|
||||
|
||||
rawCallback := db.Callback().Raw()
|
||||
rawCallback.Register("gorm:raw", RawExec)
|
||||
rawCallback.Clauses = config.QueryClauses
|
||||
}
|
||||
32
vendor/gorm.io/gorm/callbacks/callmethod.go
generated
vendored
Normal file
32
vendor/gorm.io/gorm/callbacks/callmethod.go
generated
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
|
||||
tx := db.Session(&gorm.Session{NewDB: true})
|
||||
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
db.Statement.CurDestIndex = 0
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
|
||||
fc(value.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
return
|
||||
}
|
||||
db.Statement.CurDestIndex++
|
||||
}
|
||||
case reflect.Struct:
|
||||
if db.Statement.ReflectValue.CanAddr() {
|
||||
fc(db.Statement.ReflectValue.Addr().Interface(), tx)
|
||||
} else {
|
||||
db.AddError(gorm.ErrInvalidValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
385
vendor/gorm.io/gorm/callbacks/create.go
generated
vendored
Normal file
385
vendor/gorm.io/gorm/callbacks/create.go
generated
vendored
Normal file
@@ -0,0 +1,385 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// BeforeCreate before create hooks
|
||||
func BeforeCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeCreate {
|
||||
if i, ok := value.(BeforeCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeCreate(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Create create hook
|
||||
func Create(config *Config) func(db *gorm.DB) {
|
||||
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
if !db.Statement.Unscoped {
|
||||
for _, c := range db.Statement.Schema.CreateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
|
||||
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
|
||||
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
|
||||
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
|
||||
fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
db.Statement.AddClause(clause.Returning{Columns: fromColumns})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Insert{})
|
||||
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
isDryRun := !db.DryRun && db.Error == nil
|
||||
if !isDryRun {
|
||||
return
|
||||
}
|
||||
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
if ok {
|
||||
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
|
||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
|
||||
mode |= gorm.ScanOnConflictDoNothing
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := db.Statement.ConnPool.QueryContext(
|
||||
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
||||
)
|
||||
if db.AddError(err) == nil {
|
||||
defer func() {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, mode)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result, err := db.Statement.ConnPool.ExecContext(
|
||||
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
|
||||
)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
if db.RowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
pkField *schema.Field
|
||||
pkFieldName = "@id"
|
||||
)
|
||||
if db.Statement.Schema != nil {
|
||||
if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
|
||||
return
|
||||
}
|
||||
pkField = db.Statement.Schema.PrioritizedPrimaryField
|
||||
pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
|
||||
}
|
||||
|
||||
insertID, err := result.LastInsertId()
|
||||
insertOk := err == nil && insertID > 0
|
||||
if !insertOk {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// append @id column with value for auto-increment primary key
|
||||
// the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
|
||||
switch values := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values[pkFieldName] = insertID
|
||||
case *map[string]interface{}:
|
||||
(*values)[pkFieldName] = insertID
|
||||
case []map[string]interface{}, *[]map[string]interface{}:
|
||||
mapValues, ok := values.([]map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := values.(*[]map[string]interface{}); ok {
|
||||
if *v != nil {
|
||||
mapValues = *v
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, mapValue := range mapValues {
|
||||
if mapValue != nil {
|
||||
mapValue[pkFieldName] = insertID
|
||||
}
|
||||
insertID += schema.DefaultAutoIncrementIncrement
|
||||
}
|
||||
default:
|
||||
if pkField == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if config.LastInsertIDReversed {
|
||||
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, rv)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID -= pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
|
||||
rv := db.Statement.ReflectValue.Index(i)
|
||||
if reflect.Indirect(rv).Kind() != reflect.Struct {
|
||||
break
|
||||
}
|
||||
|
||||
if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
|
||||
insertID += pkField.AutoIncrementIncrement
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
_, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
if isZero {
|
||||
db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AfterCreate after create hooks
|
||||
func AfterCreate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterCreate {
|
||||
if i, ok := value.(AfterCreateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterCreate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToCreateValues convert to create values
|
||||
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
|
||||
switch value := stmt.Dest.(type) {
|
||||
case map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, value)
|
||||
case *map[string]interface{}:
|
||||
values = ConvertMapToValuesForCreate(stmt, *value)
|
||||
case []map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, value)
|
||||
case *[]map[string]interface{}:
|
||||
values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
|
||||
default:
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
_, updateTrackTime = stmt.Get("gorm:update_track_time")
|
||||
isZero bool
|
||||
)
|
||||
stmt.Settings.Delete("gorm:update_track_time")
|
||||
|
||||
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
|
||||
|
||||
for _, db := range stmt.Schema.DBNames {
|
||||
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: db})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
rValLen := stmt.ReflectValue.Len()
|
||||
if rValLen == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.SQL.Grow(rValLen * 18)
|
||||
stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
|
||||
values.Values = make([][]interface{}, rValLen)
|
||||
|
||||
defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
|
||||
for i := 0; i < rValLen; i++ {
|
||||
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
|
||||
if !rv.IsValid() {
|
||||
stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
|
||||
return
|
||||
}
|
||||
|
||||
values.Values[i] = make([]interface{}, len(values.Columns))
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[i][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, rv, curTime))
|
||||
values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
|
||||
if len(defaultValueFieldsHavingValue[field]) == 0 {
|
||||
defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
|
||||
}
|
||||
defaultValueFieldsHavingValue[field][i] = rvOfvalue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for field, vs := range defaultValueFieldsHavingValue {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
for idx := range values.Values {
|
||||
if vs[idx] == nil {
|
||||
values.Values[idx] = append(values.Values[idx], stmt.Dialector.DefaultValueOf(field))
|
||||
} else {
|
||||
values.Values[idx] = append(values.Values[idx], vs[idx])
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
|
||||
for idx, column := range values.Columns {
|
||||
field := stmt.Schema.FieldsByDBName[column.Name]
|
||||
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
|
||||
if field.DefaultValueInterface != nil {
|
||||
values.Values[0][idx] = field.DefaultValueInterface
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
|
||||
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
} else if field.AutoUpdateTime > 0 && updateTrackTime {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
|
||||
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
|
||||
values.Values[0] = append(values.Values[0], rvOfvalue)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
|
||||
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
|
||||
if stmt.Schema != nil && len(values.Columns) >= 1 {
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)
|
||||
|
||||
columns := make([]string, 0, len(values.Columns)-1)
|
||||
for _, column := range values.Columns {
|
||||
if field := stmt.Schema.LookUpField(column.Name); field != nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
|
||||
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
|
||||
if field.AutoUpdateTime > 0 {
|
||||
assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
|
||||
switch field.AutoUpdateTime {
|
||||
case schema.UnixNanosecond:
|
||||
assignment.Value = curTime.UnixNano()
|
||||
case schema.UnixMillisecond:
|
||||
assignment.Value = curTime.UnixNano() / 1e6
|
||||
case schema.UnixSecond:
|
||||
assignment.Value = curTime.Unix()
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
|
||||
} else {
|
||||
columns = append(columns, column.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
|
||||
if len(onConflict.DoUpdates) == 0 {
|
||||
onConflict.DoNothing = true
|
||||
}
|
||||
|
||||
// use primary fields as default OnConflict columns
|
||||
if len(onConflict.Columns) == 0 {
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
|
||||
}
|
||||
}
|
||||
stmt.AddClause(onConflict)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
185
vendor/gorm.io/gorm/callbacks/delete.go
generated
vendored
Normal file
185
vendor/gorm.io/gorm/callbacks/delete.go
generated
vendored
Normal file
@@ -0,0 +1,185 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func BeforeDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(BeforeDeleteInterface); ok {
|
||||
db.AddError(i.BeforeDelete(tx))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteBeforeAssociations(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
|
||||
if !restricted {
|
||||
return
|
||||
}
|
||||
|
||||
for column, v := range selectColumns {
|
||||
if !v {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, ok := db.Statement.Schema.Relationships.Relations[column]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch rel.Type {
|
||||
case schema.HasOne, schema.HasMany:
|
||||
queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue)
|
||||
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
|
||||
withoutConditions := false
|
||||
if db.Statement.Unscoped {
|
||||
tx = tx.Unscoped()
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
selects := make([]string, 0, len(db.Statement.Selects))
|
||||
for _, s := range db.Statement.Selects {
|
||||
if s == clause.Associations {
|
||||
selects = append(selects, s)
|
||||
} else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) {
|
||||
selects = append(selects, strings.TrimPrefix(s, columnPrefix))
|
||||
}
|
||||
}
|
||||
|
||||
if len(selects) > 0 {
|
||||
tx = tx.Select(selects)
|
||||
}
|
||||
}
|
||||
|
||||
for _, cond := range queryConds {
|
||||
if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
|
||||
withoutConditions = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
case schema.Many2Many:
|
||||
var (
|
||||
queryConds = make([]clause.Expression, 0, len(rel.References))
|
||||
foreignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
relForeignKeys = make([]string, 0, len(rel.References))
|
||||
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
|
||||
table = rel.JoinTable.Table
|
||||
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
queryConds = append(queryConds, clause.Eq{
|
||||
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields)
|
||||
column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
|
||||
queryConds = append(queryConds, clause.IN{Column: column, Values: values})
|
||||
|
||||
if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func Delete(config *Config) func(db *gorm.DB) {
|
||||
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.DeleteClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(100)
|
||||
db.Statement.AddClauseIfNotExists(clause.Delete{})
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
ok, mode := hasReturning(db, supportReturning)
|
||||
if !ok {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
gorm.Scan(rows, db, mode)
|
||||
db.AddError(rows.Close())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func AfterDelete(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterDeleteInterface); ok {
|
||||
db.AddError(i.AfterDelete(tx))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
152
vendor/gorm.io/gorm/callbacks/helper.go
generated
vendored
Normal file
152
vendor/gorm.io/gorm/callbacks/helper.go
generated
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConvertMapToValuesForCreate convert map to values
|
||||
func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
|
||||
values.Columns = make([]clause.Column, 0, len(mapValue))
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
|
||||
|
||||
keys := make([]string, 0, len(mapValue))
|
||||
for k := range mapValue {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
value := mapValue[k]
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
values.Columns = append(values.Columns, clause.Column{Name: k})
|
||||
if len(values.Values) == 0 {
|
||||
values.Values = [][]interface{}{{}}
|
||||
}
|
||||
|
||||
values.Values[0] = append(values.Values[0], value)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ConvertSliceOfMapToValuesForCreate convert slice of map to values
|
||||
func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
|
||||
columns := make([]string, 0, len(mapValues))
|
||||
|
||||
// when the length of mapValues is zero,return directly here
|
||||
// no need to call stmt.SelectAndOmitColumns method
|
||||
if len(mapValues) == 0 {
|
||||
stmt.AddError(gorm.ErrEmptySlice)
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
result = make(map[string][]interface{}, len(mapValues))
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
|
||||
)
|
||||
|
||||
for idx, mapValue := range mapValues {
|
||||
for k, v := range mapValue {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
k = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := result[k]; !ok {
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
result[k] = make([]interface{}, len(mapValues))
|
||||
columns = append(columns, k)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
result[k][idx] = v
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(columns)
|
||||
values.Values = make([][]interface{}, len(mapValues))
|
||||
values.Columns = make([]clause.Column, len(columns))
|
||||
for idx, column := range columns {
|
||||
values.Columns[idx] = clause.Column{Name: column}
|
||||
|
||||
for i, v := range result[column] {
|
||||
if len(values.Values[i]) == 0 {
|
||||
values.Values[i] = make([]interface{}, len(columns))
|
||||
}
|
||||
|
||||
values.Values[i][idx] = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
|
||||
if supportReturning {
|
||||
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
|
||||
returning, _ := c.Expression.(clause.Returning)
|
||||
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
|
||||
return true, 0
|
||||
}
|
||||
return true, gorm.ScanUpdate
|
||||
}
|
||||
}
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func checkMissingWhereConditions(db *gorm.DB) {
|
||||
if !db.AllowGlobalUpdate && db.Error == nil {
|
||||
where, withCondition := db.Statement.Clauses["WHERE"]
|
||||
if withCondition {
|
||||
if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
|
||||
whereClause, _ := where.Expression.(clause.Where)
|
||||
withCondition = len(whereClause.Exprs) > 1
|
||||
}
|
||||
}
|
||||
if !withCondition {
|
||||
db.AddError(gorm.ErrMissingWhereClause)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type visitMap = map[reflect.Value]bool
|
||||
|
||||
// Check if circular values, return true if loaded
|
||||
func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
loaded = true
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
|
||||
loaded = false
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Interface:
|
||||
if v.CanAddr() {
|
||||
p := v.Addr()
|
||||
if _, ok := (*visitMap)[p]; ok {
|
||||
return true
|
||||
}
|
||||
(*visitMap)[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
39
vendor/gorm.io/gorm/callbacks/interfaces.go
generated
vendored
Normal file
39
vendor/gorm.io/gorm/callbacks/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package callbacks
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type BeforeCreateInterface interface {
|
||||
BeforeCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterCreateInterface interface {
|
||||
AfterCreate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeUpdateInterface interface {
|
||||
BeforeUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterUpdateInterface interface {
|
||||
AfterUpdate(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeSaveInterface interface {
|
||||
BeforeSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterSaveInterface interface {
|
||||
AfterSave(*gorm.DB) error
|
||||
}
|
||||
|
||||
type BeforeDeleteInterface interface {
|
||||
BeforeDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterDeleteInterface interface {
|
||||
AfterDelete(*gorm.DB) error
|
||||
}
|
||||
|
||||
type AfterFindInterface interface {
|
||||
AfterFind(*gorm.DB) error
|
||||
}
|
||||
320
vendor/gorm.io/gorm/callbacks/preload.go
generated
vendored
Normal file
320
vendor/gorm.io/gorm/callbacks/preload.go
generated
vendored
Normal file
@@ -0,0 +1,320 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// parsePreloadMap extracts nested preloads. e.g.
|
||||
//
|
||||
// // schema has a "k0" relation and a "k7.k8" embedded relation
|
||||
// parsePreloadMap(schema, map[string][]interface{}{
|
||||
// clause.Associations: {"arg1"},
|
||||
// "k1": {"arg2"},
|
||||
// "k2.k3": {"arg3"},
|
||||
// "k4.k5.k6": {"arg4"},
|
||||
// })
|
||||
// // preloadMap is
|
||||
// map[string]map[string][]interface{}{
|
||||
// "k0": {},
|
||||
// "k7": {
|
||||
// "k8": {},
|
||||
// },
|
||||
// "k1": {},
|
||||
// "k2": {
|
||||
// "k3": {"arg3"},
|
||||
// },
|
||||
// "k4": {
|
||||
// "k5.k6": {"arg4"},
|
||||
// },
|
||||
// }
|
||||
func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
|
||||
preloadMap := map[string]map[string][]interface{}{}
|
||||
setPreloadMap := func(name, value string, args []interface{}) {
|
||||
if _, ok := preloadMap[name]; !ok {
|
||||
preloadMap[name] = map[string][]interface{}{}
|
||||
}
|
||||
if value != "" {
|
||||
preloadMap[name][value] = args
|
||||
}
|
||||
}
|
||||
|
||||
for name, args := range preloads {
|
||||
preloadFields := strings.Split(name, ".")
|
||||
value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
|
||||
if preloadFields[0] == clause.Associations {
|
||||
for _, relation := range s.Relationships.Relations {
|
||||
if relation.Schema == s {
|
||||
setPreloadMap(relation.Name, value, args)
|
||||
}
|
||||
}
|
||||
|
||||
for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
|
||||
for _, value := range embeddedValues(embeddedRelations) {
|
||||
setPreloadMap(embedded, value, args)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
setPreloadMap(preloadFields[0], value, args)
|
||||
}
|
||||
}
|
||||
return preloadMap
|
||||
}
|
||||
|
||||
func embeddedValues(embeddedRelations *schema.Relationships) []string {
|
||||
if embeddedRelations == nil {
|
||||
return nil
|
||||
}
|
||||
names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
|
||||
for _, relation := range embeddedRelations.Relations {
|
||||
// skip first struct name
|
||||
names = append(names, strings.Join(relation.Field.BindNames[1:], "."))
|
||||
}
|
||||
for _, relations := range embeddedRelations.EmbeddedRelations {
|
||||
names = append(names, embeddedValues(relations)...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
|
||||
// If the current relationship is embedded or joined, current query will be ignored.
|
||||
//
|
||||
//nolint:cyclop
|
||||
func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
|
||||
preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
|
||||
|
||||
// avoid random traversal of the map
|
||||
preloadNames := make([]string, 0, len(preloadMap))
|
||||
for key := range preloadMap {
|
||||
preloadNames = append(preloadNames, key)
|
||||
}
|
||||
sort.Strings(preloadNames)
|
||||
|
||||
isJoined := func(name string) (joined bool, nestedJoins []string) {
|
||||
for _, join := range joins {
|
||||
if _, ok := relationships.Relations[join]; ok && name == join {
|
||||
joined = true
|
||||
continue
|
||||
}
|
||||
joinNames := strings.SplitN(join, ".", 2)
|
||||
if len(joinNames) == 2 {
|
||||
if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
|
||||
joined = true
|
||||
nestedJoins = append(nestedJoins, joinNames[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
return joined, nestedJoins
|
||||
}
|
||||
|
||||
for _, name := range preloadNames {
|
||||
if relations := relationships.EmbeddedRelations[name]; relations != nil {
|
||||
if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if rel := relationships.Relations[name]; rel != nil {
|
||||
if joined, nestedJoins := isJoined(name); joined {
|
||||
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
|
||||
tx := preloadDB(db, reflectValue, reflectValue.Interface())
|
||||
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
|
||||
tx.Statement.ReflectValue = db.Statement.ReflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
|
||||
tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
|
||||
db.Statement.Settings.Range(func(k, v interface{}) bool {
|
||||
tx.Statement.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
if err := tx.Statement.Parse(dest); err != nil {
|
||||
tx.AddError(err)
|
||||
return tx
|
||||
}
|
||||
tx.Statement.ReflectValue = reflectValue
|
||||
tx.Statement.Unscoped = db.Statement.Unscoped
|
||||
return tx
|
||||
}
|
||||
|
||||
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
|
||||
var (
|
||||
reflectValue = tx.Statement.ReflectValue
|
||||
relForeignKeys []string
|
||||
relForeignFields []*schema.Field
|
||||
foreignFields []*schema.Field
|
||||
foreignValues [][]interface{}
|
||||
identityMap = map[string][]reflect.Value{}
|
||||
inlineConds []interface{}
|
||||
)
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
var (
|
||||
joinForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
|
||||
joinForeignKeys = make([]string, 0, len(rel.References))
|
||||
)
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
|
||||
joinForeignFields = append(joinForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
}
|
||||
}
|
||||
|
||||
joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||
if len(joinForeignValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
joinResults := rel.JoinTable.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// convert join identity map to relation identity map
|
||||
fieldValues := make([]interface{}, len(joinForeignFields))
|
||||
joinFieldValues := make([]interface{}, len(joinRelForeignFields))
|
||||
for i := 0; i < joinResults.Len(); i++ {
|
||||
joinIndexValue := joinResults.Index(i)
|
||||
for idx, field := range joinForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||
}
|
||||
|
||||
for idx, field := range joinRelForeignFields {
|
||||
joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
|
||||
}
|
||||
|
||||
if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
|
||||
joinKey := utils.ToStringKey(joinFieldValues...)
|
||||
identityMap[joinKey] = append(identityMap[joinKey], results...)
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
|
||||
} else {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.ForeignKey)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
|
||||
} else {
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
relForeignFields = append(relForeignFields, ref.PrimaryKey)
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
|
||||
if len(foreignValues) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// nested preload
|
||||
for p, pvs := range preloads {
|
||||
tx = tx.Preload(p, pvs...)
|
||||
}
|
||||
|
||||
reflectResults := rel.FieldSchema.MakeSlice().Elem()
|
||||
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
|
||||
|
||||
if len(values) != 0 {
|
||||
for _, cond := range conds {
|
||||
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
|
||||
tx = fc(tx)
|
||||
} else {
|
||||
inlineConds = append(inlineConds, cond)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fieldValues := make([]interface{}, len(relForeignFields))
|
||||
|
||||
// clean up old values before preloading
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
|
||||
default:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
switch rel.Type {
|
||||
case schema.HasMany, schema.Many2Many:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
|
||||
default:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < reflectResults.Len(); i++ {
|
||||
elem := reflectResults.Index(i)
|
||||
for idx, field := range relForeignFields {
|
||||
fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
|
||||
}
|
||||
|
||||
datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
|
||||
}
|
||||
|
||||
for _, data := range datas {
|
||||
reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
|
||||
if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
|
||||
reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
|
||||
}
|
||||
|
||||
reflectFieldValue = reflect.Indirect(reflectFieldValue)
|
||||
switch reflectFieldValue.Kind() {
|
||||
case reflect.Struct:
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
|
||||
case reflect.Slice, reflect.Array:
|
||||
if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
|
||||
} else {
|
||||
tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Error
|
||||
}
|
||||
299
vendor/gorm.io/gorm/callbacks/query.go
generated
vendored
Normal file
299
vendor/gorm.io/gorm/callbacks/query.go
generated
vendored
Normal file
@@ -0,0 +1,299 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func Query(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
db.AddError(rows.Close())
|
||||
}()
|
||||
gorm.Scan(rows, db, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BuildQuerySQL(db *gorm.DB) {
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.QueryClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(100)
|
||||
clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
|
||||
|
||||
if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
|
||||
var conds []clause.Expression
|
||||
for _, primaryField := range db.Statement.Schema.PrimaryFields {
|
||||
if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
db.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
}
|
||||
|
||||
if len(db.Statement.Selects) > 0 {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
|
||||
for idx, name := range db.Statement.Selects {
|
||||
if db.Statement.Schema == nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
} else if f := db.Statement.Schema.LookUpField(name); f != nil {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
|
||||
} else {
|
||||
clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
|
||||
selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
|
||||
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
|
||||
for _, dbName := range db.Statement.Schema.DBNames {
|
||||
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
|
||||
}
|
||||
}
|
||||
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
|
||||
queryFields := db.QueryFields
|
||||
if !queryFields {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
|
||||
case reflect.Slice:
|
||||
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
|
||||
}
|
||||
}
|
||||
|
||||
if queryFields {
|
||||
stmt := gorm.Statement{DB: db}
|
||||
// smaller struct
|
||||
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
|
||||
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
|
||||
|
||||
for idx, dbName := range stmt.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inline joins
|
||||
fromClause := clause.From{}
|
||||
if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
|
||||
fromClause = v
|
||||
}
|
||||
|
||||
if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 {
|
||||
if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil {
|
||||
clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
|
||||
for idx, dbName := range db.Statement.Schema.DBNames {
|
||||
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
|
||||
}
|
||||
}
|
||||
|
||||
specifiedRelationsName := make(map[string]interface{})
|
||||
for _, join := range db.Statement.Joins {
|
||||
if db.Statement.Schema != nil {
|
||||
var isRelations bool // is relations or raw sql
|
||||
var relations []*schema.Relationship
|
||||
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]
|
||||
if ok {
|
||||
isRelations = true
|
||||
relations = append(relations, relation)
|
||||
} else {
|
||||
// handle nested join like "Manager.Company"
|
||||
nestedJoinNames := strings.Split(join.Name, ".")
|
||||
if len(nestedJoinNames) > 1 {
|
||||
isNestedJoin := true
|
||||
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
|
||||
currentRelations := db.Statement.Schema.Relationships.Relations
|
||||
for _, relname := range nestedJoinNames {
|
||||
// incomplete match, only treated as raw sql
|
||||
if relation, ok = currentRelations[relname]; ok {
|
||||
gussNestedRelations = append(gussNestedRelations, relation)
|
||||
currentRelations = relation.FieldSchema.Relationships.Relations
|
||||
} else {
|
||||
isNestedJoin = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if isNestedJoin {
|
||||
isRelations = true
|
||||
relations = gussNestedRelations
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if isRelations {
|
||||
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
|
||||
tableAliasName := relation.Name
|
||||
if parentTableName != clause.CurrentTable {
|
||||
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
|
||||
}
|
||||
|
||||
columnStmt := gorm.Statement{
|
||||
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
|
||||
Selects: join.Selects, Omits: join.Omits,
|
||||
}
|
||||
|
||||
selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false)
|
||||
for _, s := range relation.FieldSchema.DBNames {
|
||||
if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) {
|
||||
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
|
||||
Table: tableAliasName,
|
||||
Name: s,
|
||||
Alias: utils.NestedRelationName(tableAliasName, s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, len(relation.References))
|
||||
for idx, ref := range relation.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
}
|
||||
} else {
|
||||
if ref.PrimaryValue == "" {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
|
||||
}
|
||||
} else {
|
||||
exprs[idx] = clause.Eq{
|
||||
Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}}
|
||||
for _, c := range relation.FieldSchema.QueryClauses {
|
||||
onStmt.AddClause(c)
|
||||
}
|
||||
|
||||
if join.On != nil {
|
||||
onStmt.AddClause(join.On)
|
||||
}
|
||||
|
||||
if cs, ok := onStmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
where.Build(&onStmt)
|
||||
|
||||
if onSQL := onStmt.SQL.String(); onSQL != "" {
|
||||
vars := onStmt.Vars
|
||||
for idx, v := range vars {
|
||||
bindvar := strings.Builder{}
|
||||
onStmt.Vars = vars[0 : idx+1]
|
||||
db.Dialector.BindVarTo(&bindvar, &onStmt, v)
|
||||
onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clause.Join{
|
||||
Type: joinType,
|
||||
Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
|
||||
ON: clause.Where{Exprs: exprs},
|
||||
}
|
||||
}
|
||||
|
||||
parentTableName := clause.CurrentTable
|
||||
for _, rel := range relations {
|
||||
// joins table alias like "Manager, Company, Manager__Company"
|
||||
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
|
||||
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
|
||||
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
|
||||
specifiedRelationsName[nestedAlias] = nil
|
||||
}
|
||||
|
||||
if parentTableName != clause.CurrentTable {
|
||||
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
|
||||
} else {
|
||||
parentTableName = rel.Name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
fromClause.Joins = append(fromClause.Joins, clause.Join{
|
||||
Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.AddClause(fromClause)
|
||||
} else {
|
||||
db.Statement.AddClauseIfNotExists(clause.From{})
|
||||
}
|
||||
|
||||
db.Statement.AddClauseIfNotExists(clauseSelect)
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
}
|
||||
|
||||
func Preload(db *gorm.DB) {
|
||||
if db.Error == nil && len(db.Statement.Preloads) > 0 {
|
||||
if db.Statement.Schema == nil {
|
||||
db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired))
|
||||
return
|
||||
}
|
||||
|
||||
joins := make([]string, 0, len(db.Statement.Joins))
|
||||
for _, join := range db.Statement.Joins {
|
||||
joins = append(joins, join.Name)
|
||||
}
|
||||
|
||||
tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest)
|
||||
if tx.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations]))
|
||||
}
|
||||
}
|
||||
|
||||
func AfterQuery(db *gorm.DB) {
|
||||
// clear the joins after query because preload need it
|
||||
db.Statement.Joins = nil
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
|
||||
if i, ok := value.(AfterFindInterface); ok {
|
||||
db.AddError(i.AfterFind(tx))
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
}
|
||||
17
vendor/gorm.io/gorm/callbacks/raw.go
generated
vendored
Normal file
17
vendor/gorm.io/gorm/callbacks/raw.go
generated
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RawExec(db *gorm.DB) {
|
||||
if db.Error == nil && !db.DryRun {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
if err != nil {
|
||||
db.AddError(err)
|
||||
return
|
||||
}
|
||||
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
}
|
||||
23
vendor/gorm.io/gorm/callbacks/row.go
generated
vendored
Normal file
23
vendor/gorm.io/gorm/callbacks/row.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RowQuery(db *gorm.DB) {
|
||||
if db.Error == nil {
|
||||
BuildQuerySQL(db)
|
||||
if db.DryRun || db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if isRows, ok := db.Get("rows"); ok && isRows.(bool) {
|
||||
db.Statement.Settings.Delete("rows")
|
||||
db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
} else {
|
||||
db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
}
|
||||
|
||||
db.RowsAffected = -1
|
||||
}
|
||||
}
|
||||
32
vendor/gorm.io/gorm/callbacks/transaction.go
generated
vendored
Normal file
32
vendor/gorm.io/gorm/callbacks/transaction.go
generated
vendored
Normal file
@@ -0,0 +1,32 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func BeginTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction && db.Error == nil {
|
||||
if tx := db.Begin(); tx.Error == nil {
|
||||
db.Statement.ConnPool = tx.Statement.ConnPool
|
||||
db.InstanceSet("gorm:started_transaction", true)
|
||||
} else if tx.Error == gorm.ErrInvalidTransaction {
|
||||
tx.Error = nil
|
||||
} else {
|
||||
db.Error = tx.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func CommitOrRollbackTransaction(db *gorm.DB) {
|
||||
if !db.Config.SkipDefaultTransaction {
|
||||
if _, ok := db.InstanceGet("gorm:started_transaction"); ok {
|
||||
if db.Error != nil {
|
||||
db.Rollback()
|
||||
} else {
|
||||
db.Commit()
|
||||
}
|
||||
|
||||
db.Statement.ConnPool = db.ConnPool
|
||||
}
|
||||
}
|
||||
}
|
||||
304
vendor/gorm.io/gorm/callbacks/update.go
generated
vendored
Normal file
304
vendor/gorm.io/gorm/callbacks/update.go
generated
vendored
Normal file
@@ -0,0 +1,304 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
func SetupUpdateReflectValue(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil {
|
||||
if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
|
||||
db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
|
||||
for db.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
|
||||
}
|
||||
|
||||
if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
|
||||
for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
|
||||
if _, ok := dest[rel.Name]; ok {
|
||||
db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeUpdate before update hooks
|
||||
func BeforeUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.BeforeSave {
|
||||
if i, ok := value.(BeforeSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.BeforeUpdate {
|
||||
if i, ok := value.(BeforeUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.BeforeUpdate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Update update hook
|
||||
func Update(config *Config) func(db *gorm.DB) {
|
||||
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
|
||||
|
||||
return func(db *gorm.DB) {
|
||||
if db.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil {
|
||||
for _, c := range db.Statement.Schema.UpdateClauses {
|
||||
db.Statement.AddClause(c)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.SQL.Len() == 0 {
|
||||
db.Statement.SQL.Grow(180)
|
||||
db.Statement.AddClauseIfNotExists(clause.Update{})
|
||||
if _, ok := db.Statement.Clauses["SET"]; !ok {
|
||||
if set := ConvertToAssignments(db.Statement); len(set) != 0 {
|
||||
defer delete(db.Statement.Clauses, "SET")
|
||||
db.Statement.AddClause(set)
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
db.Statement.Build(db.Statement.BuildClauses...)
|
||||
}
|
||||
|
||||
checkMissingWhereConditions(db)
|
||||
|
||||
if !db.DryRun && db.Error == nil {
|
||||
if ok, mode := hasReturning(db, supportReturning); ok {
|
||||
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
|
||||
dest := db.Statement.Dest
|
||||
db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
|
||||
gorm.Scan(rows, db, mode)
|
||||
db.Statement.Dest = dest
|
||||
db.AddError(rows.Close())
|
||||
}
|
||||
} else {
|
||||
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
if db.AddError(err) == nil {
|
||||
db.RowsAffected, _ = result.RowsAffected()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AfterUpdate after update hooks
|
||||
func AfterUpdate(db *gorm.DB) {
|
||||
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
|
||||
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
|
||||
if db.Statement.Schema.AfterUpdate {
|
||||
if i, ok := value.(AfterUpdateInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterUpdate(tx))
|
||||
}
|
||||
}
|
||||
|
||||
if db.Statement.Schema.AfterSave {
|
||||
if i, ok := value.(AfterSaveInterface); ok {
|
||||
called = true
|
||||
db.AddError(i.AfterSave(tx))
|
||||
}
|
||||
}
|
||||
|
||||
return called
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertToAssignments convert to update assignments
|
||||
func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
|
||||
var (
|
||||
selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
|
||||
assignValue func(field *schema.Field, value interface{})
|
||||
)
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
if stmt.ReflectValue.CanAddr() {
|
||||
field.Set(stmt.Context, stmt.ReflectValue, value)
|
||||
}
|
||||
}
|
||||
default:
|
||||
assignValue = func(field *schema.Field, value interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
updatingValue := reflect.ValueOf(stmt.Dest)
|
||||
for updatingValue.Kind() == reflect.Ptr {
|
||||
updatingValue = updatingValue.Elem()
|
||||
}
|
||||
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if size := stmt.ReflectValue.Len(); size > 0 {
|
||||
var isZero bool
|
||||
for i := 0; i < size; i++ {
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
_, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
|
||||
if !isZero {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isZero {
|
||||
_, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch value := updatingValue.Interface().(type) {
|
||||
case map[string]interface{}:
|
||||
set = make([]clause.Assignment, 0, len(value))
|
||||
|
||||
keys := make([]string, 0, len(value))
|
||||
for k := range value {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, k := range keys {
|
||||
kv := value[k]
|
||||
if _, ok := kv.(*gorm.DB); ok {
|
||||
kv = []interface{}{kv}
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(k); field != nil {
|
||||
if field.DBName != "" {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
} else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
|
||||
assignValue(field, value[k])
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
|
||||
}
|
||||
}
|
||||
|
||||
if !stmt.SkipHooks && stmt.Schema != nil {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.LookUpField(dbName)
|
||||
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
|
||||
now := stmt.DB.NowFunc()
|
||||
assignValue(field, now)
|
||||
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano() / 1e6})
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
|
||||
} else {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
updatingSchema := stmt.Schema
|
||||
var isDiffSchema bool
|
||||
if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
// different schema
|
||||
updatingStmt := &gorm.Statement{DB: stmt.DB}
|
||||
if err := updatingStmt.Parse(stmt.Dest); err == nil {
|
||||
updatingSchema = updatingStmt.Schema
|
||||
isDiffSchema = true
|
||||
}
|
||||
}
|
||||
|
||||
switch updatingValue.Kind() {
|
||||
case reflect.Struct:
|
||||
set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
if field := updatingSchema.LookUpField(dbName); field != nil {
|
||||
if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
|
||||
value, isZero := field.ValueOf(stmt.Context, updatingValue)
|
||||
if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
|
||||
if field.AutoUpdateTime == schema.UnixNanosecond {
|
||||
value = stmt.DB.NowFunc().UnixNano()
|
||||
} else if field.AutoUpdateTime == schema.UnixMillisecond {
|
||||
value = stmt.DB.NowFunc().UnixNano() / 1e6
|
||||
} else if field.AutoUpdateTime == schema.UnixSecond {
|
||||
value = stmt.DB.NowFunc().Unix()
|
||||
} else {
|
||||
value = stmt.DB.NowFunc()
|
||||
}
|
||||
isZero = false
|
||||
}
|
||||
|
||||
if (ok || !isZero) && field.Updatable {
|
||||
set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
|
||||
assignField := field
|
||||
if isDiffSchema {
|
||||
if originField := stmt.Schema.LookUpField(dbName); originField != nil {
|
||||
assignField = originField
|
||||
}
|
||||
}
|
||||
assignValue(assignField, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
stmt.AddError(gorm.ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
448
vendor/gorm.io/gorm/chainable_api.go
generated
vendored
Normal file
448
vendor/gorm.io/gorm/chainable_api.go
generated
vendored
Normal file
@@ -0,0 +1,448 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Model specify the model you would like to run db operations
|
||||
//
|
||||
// // update all users's name to `hello`
|
||||
// db.Model(&User{}).Update("name", "hello")
|
||||
// // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello`
|
||||
// db.Model(&user).Update("name", "hello")
|
||||
func (db *DB) Model(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Model = value
|
||||
return
|
||||
}
|
||||
|
||||
// Clauses Add clauses
|
||||
//
|
||||
// This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more
|
||||
// advanced techniques like specifying lock strength and optimizer hints. See the
|
||||
// [docs] for more depth.
|
||||
//
|
||||
// // add a simple limit clause
|
||||
// db.Clauses(clause.Limit{Limit: 1}).Find(&User{})
|
||||
// // tell the optimizer to use the `idx_user_name` index
|
||||
// db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{})
|
||||
// // specify the lock strength to UPDATE
|
||||
// db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/sql_builder.html#Clauses
|
||||
func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
var whereConds []interface{}
|
||||
|
||||
for _, cond := range conds {
|
||||
if c, ok := cond.(clause.Interface); ok {
|
||||
tx.Statement.AddClause(c)
|
||||
} else if optimizer, ok := cond.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(tx.Statement)
|
||||
} else {
|
||||
whereConds = append(whereConds, cond)
|
||||
}
|
||||
}
|
||||
|
||||
if len(whereConds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`)
|
||||
|
||||
// Table specify the table you would like to run db operations
|
||||
//
|
||||
// // Get a user
|
||||
// db.Table("users").Take(&result)
|
||||
func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
|
||||
if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 {
|
||||
if results[1] != "" {
|
||||
tx.Statement.Table = results[1]
|
||||
} else {
|
||||
tx.Statement.Table = results[2]
|
||||
}
|
||||
}
|
||||
} else if tables := strings.Split(name, "."); len(tables) == 2 {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
tx.Statement.Table = tables[1]
|
||||
} else if name != "" {
|
||||
tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
|
||||
tx.Statement.Table = name
|
||||
} else {
|
||||
tx.Statement.TableExpr = nil
|
||||
tx.Statement.Table = ""
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Distinct specify distinct fields that you want querying
|
||||
//
|
||||
// // Select distinct names of users
|
||||
// db.Distinct("name").Find(&results)
|
||||
// // Select distinct name/age pairs from users
|
||||
// db.Distinct("name", "age").Find(&results)
|
||||
func (db *DB) Distinct(args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Distinct = true
|
||||
if len(args) > 0 {
|
||||
tx = tx.Select(args[0], args[1:]...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Select specify fields that you want when querying, creating, updating
|
||||
//
|
||||
// Use Select when you only want a subset of the fields. By default, GORM will select all fields.
|
||||
// Select accepts both string arguments and arrays.
|
||||
//
|
||||
// // Select name and age of user using multiple arguments
|
||||
// db.Select("name", "age").Find(&users)
|
||||
// // Select name and age of user using an array
|
||||
// db.Select([]string{"name", "age"}).Find(&users)
|
||||
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
switch v := query.(type) {
|
||||
case []string:
|
||||
tx.Statement.Selects = v
|
||||
|
||||
for _, arg := range args {
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg)
|
||||
case []string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
|
||||
default:
|
||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
|
||||
clause.Expression = nil
|
||||
tx.Statement.Clauses["SELECT"] = clause
|
||||
}
|
||||
case string:
|
||||
if strings.Count(v, "?") >= len(args) && len(args) > 0 {
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.Expr{SQL: v, Vars: args},
|
||||
})
|
||||
} else if strings.Count(v, "@") > 0 && len(args) > 0 {
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.NamedExpr{SQL: v, Vars: args},
|
||||
})
|
||||
} else {
|
||||
tx.Statement.Selects = []string{v}
|
||||
|
||||
for _, arg := range args {
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg)
|
||||
case []string:
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, arg...)
|
||||
default:
|
||||
tx.Statement.AddClause(clause.Select{
|
||||
Distinct: db.Statement.Distinct,
|
||||
Expression: clause.Expr{SQL: v, Vars: args},
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
|
||||
clause.Expression = nil
|
||||
tx.Statement.Clauses["SELECT"] = clause
|
||||
}
|
||||
}
|
||||
default:
|
||||
tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Omit specify fields that you want to ignore when creating, updating and querying
|
||||
func (db *DB) Omit(columns ...string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
|
||||
tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
|
||||
} else {
|
||||
tx.Statement.Omits = columns
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Where add conditions
|
||||
//
|
||||
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
|
||||
//
|
||||
// // Find the first user with name jinzhu
|
||||
// db.Where("name = ?", "jinzhu").First(&user)
|
||||
// // Find the first user with name jinzhu and age 20
|
||||
// db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
|
||||
// // Find the first user with name jinzhu and age not equal to 20
|
||||
// db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
|
||||
//
|
||||
// [docs]: https://gorm.io/docs/query.html#Conditions
|
||||
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: conds})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Not add NOT conditions
|
||||
//
|
||||
// Not works similarly to where, and has the same syntax.
|
||||
//
|
||||
// // Find the first user with name not equal to jinzhu
|
||||
// db.Not("name = ?", "jinzhu").First(&user)
|
||||
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Or add OR conditions
|
||||
//
|
||||
// Or is used to chain together queries with an OR.
|
||||
//
|
||||
// // Find the first user with name equal to jinzhu or john
|
||||
// db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user)
|
||||
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Joins specify Joins conditions
|
||||
//
|
||||
// db.Joins("Account").Find(&user)
|
||||
// db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
|
||||
// db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
|
||||
func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.LeftJoin, query, args...)
|
||||
}
|
||||
|
||||
// InnerJoins specify inner joins conditions
|
||||
// db.InnerJoins("Account").Find(&user)
|
||||
func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) {
|
||||
return joins(db, clause.InnerJoin, query, args...)
|
||||
}
|
||||
|
||||
func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
if len(args) == 1 {
|
||||
if db, ok := args[0].(*DB); ok {
|
||||
j := join{
|
||||
Name: query, Conds: args, Selects: db.Statement.Selects,
|
||||
Omits: db.Statement.Omits, JoinType: joinType,
|
||||
}
|
||||
if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
|
||||
j.On = &where
|
||||
}
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, j)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType})
|
||||
return
|
||||
}
|
||||
|
||||
// Group specify the group method on the find
|
||||
//
|
||||
// // Select the sum age of users with given names
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results)
|
||||
func (db *DB) Group(name string) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Having specify HAVING conditions for GROUP BY
|
||||
//
|
||||
// // Select the sum age of users with name jinzhu
|
||||
// db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result)
|
||||
func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.GroupBy{
|
||||
Having: tx.Statement.BuildCondition(query, args...),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Order specify order when retrieving records from database
|
||||
//
|
||||
// db.Order("name DESC")
|
||||
// db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
|
||||
func (db *DB) Order(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
|
||||
switch v := value.(type) {
|
||||
case clause.OrderByColumn:
|
||||
tx.Statement.AddClause(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{v},
|
||||
})
|
||||
case string:
|
||||
if v != "" {
|
||||
tx.Statement.AddClause(clause.OrderBy{
|
||||
Columns: []clause.OrderByColumn{{
|
||||
Column: clause.Column{Name: v, Raw: true},
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Limit specify the number of records to be retrieved
|
||||
//
|
||||
// Limit conditions can be cancelled by using `Limit(-1)`.
|
||||
//
|
||||
// // retrieve 3 users
|
||||
// db.Limit(3).Find(&users)
|
||||
// // retrieve 3 users into users1, and all users into users2
|
||||
// db.Limit(3).Find(&users1).Limit(-1).Find(&users2)
|
||||
func (db *DB) Limit(limit int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Limit: &limit})
|
||||
return
|
||||
}
|
||||
|
||||
// Offset specify the number of records to skip before starting to return the records
|
||||
//
|
||||
// Offset conditions can be cancelled by using `Offset(-1)`.
|
||||
//
|
||||
// // select the third user
|
||||
// db.Offset(2).First(&user)
|
||||
// // select the first user by cancelling an earlier chained offset
|
||||
// db.Offset(5).Offset(-1).First(&user)
|
||||
func (db *DB) Offset(offset int) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.AddClause(clause.Limit{Offset: offset})
|
||||
return
|
||||
}
|
||||
|
||||
// Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
|
||||
//
|
||||
// func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
|
||||
// return db.Where("amount > ?", 1000)
|
||||
// }
|
||||
//
|
||||
// func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
|
||||
// return func (db *gorm.DB) *gorm.DB {
|
||||
// return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
|
||||
func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
|
||||
return tx
|
||||
}
|
||||
|
||||
func (db *DB) executeScopes() (tx *DB) {
|
||||
scopes := db.Statement.scopes
|
||||
db.Statement.scopes = nil
|
||||
for _, scope := range scopes {
|
||||
db = scope(db)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Preload preload associations with given conditions
|
||||
//
|
||||
// // get all users, and preload all non-cancelled orders
|
||||
// db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
|
||||
func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Preloads == nil {
|
||||
tx.Statement.Preloads = map[string][]interface{}{}
|
||||
}
|
||||
tx.Statement.Preloads[query] = args
|
||||
return
|
||||
}
|
||||
|
||||
// Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Attrs only adds attributes if the record is not found.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign an email if the record is not found, otherwise ignore provided email
|
||||
// db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.attrs = attrs
|
||||
return
|
||||
}
|
||||
|
||||
// Assign provide attributes used in [FirstOrCreate] or [FirstOrInit]
|
||||
//
|
||||
// Assign adds attributes even if the record is found. If using FirstOrCreate, this means that
|
||||
// records will be updated even if they are found.
|
||||
//
|
||||
// // assign an email regardless of if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
//
|
||||
// [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate
|
||||
// [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit
|
||||
func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.assigns = attrs
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Unscoped() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Unscoped = true
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
return
|
||||
}
|
||||
89
vendor/gorm.io/gorm/clause/clause.go
generated
vendored
Normal file
89
vendor/gorm.io/gorm/clause/clause.go
generated
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
package clause
|
||||
|
||||
// Interface clause interface
|
||||
type Interface interface {
|
||||
Name() string
|
||||
Build(Builder)
|
||||
MergeClause(*Clause)
|
||||
}
|
||||
|
||||
// ClauseBuilder clause builder, allows to customize how to build clause
|
||||
type ClauseBuilder func(Clause, Builder)
|
||||
|
||||
type Writer interface {
|
||||
WriteByte(byte) error
|
||||
WriteString(string) (int, error)
|
||||
}
|
||||
|
||||
// Builder builder interface
|
||||
type Builder interface {
|
||||
Writer
|
||||
WriteQuoted(field interface{})
|
||||
AddVar(Writer, ...interface{})
|
||||
AddError(error) error
|
||||
}
|
||||
|
||||
// Clause
|
||||
type Clause struct {
|
||||
Name string // WHERE
|
||||
BeforeExpression Expression
|
||||
AfterNameExpression Expression
|
||||
AfterExpression Expression
|
||||
Expression Expression
|
||||
Builder ClauseBuilder
|
||||
}
|
||||
|
||||
// Build build clause
|
||||
func (c Clause) Build(builder Builder) {
|
||||
if c.Builder != nil {
|
||||
c.Builder(c, builder)
|
||||
} else if c.Expression != nil {
|
||||
if c.BeforeExpression != nil {
|
||||
c.BeforeExpression.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if c.Name != "" {
|
||||
builder.WriteString(c.Name)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if c.AfterNameExpression != nil {
|
||||
c.AfterNameExpression.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
c.Expression.Build(builder)
|
||||
|
||||
if c.AfterExpression != nil {
|
||||
builder.WriteByte(' ')
|
||||
c.AfterExpression.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
PrimaryKey string = "~~~py~~~" // primary key
|
||||
CurrentTable string = "~~~ct~~~" // current table
|
||||
Associations string = "~~~as~~~" // associations
|
||||
)
|
||||
|
||||
var (
|
||||
currentTable = Table{Name: CurrentTable}
|
||||
PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey}
|
||||
)
|
||||
|
||||
// Column quote with name
|
||||
type Column struct {
|
||||
Table string
|
||||
Name string
|
||||
Alias string
|
||||
Raw bool
|
||||
}
|
||||
|
||||
// Table quote with name
|
||||
type Table struct {
|
||||
Name string
|
||||
Alias string
|
||||
Raw bool
|
||||
}
|
||||
23
vendor/gorm.io/gorm/clause/delete.go
generated
vendored
Normal file
23
vendor/gorm.io/gorm/clause/delete.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package clause
|
||||
|
||||
type Delete struct {
|
||||
Modifier string
|
||||
}
|
||||
|
||||
func (d Delete) Name() string {
|
||||
return "DELETE"
|
||||
}
|
||||
|
||||
func (d Delete) Build(builder Builder) {
|
||||
builder.WriteString("DELETE")
|
||||
|
||||
if d.Modifier != "" {
|
||||
builder.WriteByte(' ')
|
||||
builder.WriteString(d.Modifier)
|
||||
}
|
||||
}
|
||||
|
||||
func (d Delete) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
clause.Expression = d
|
||||
}
|
||||
385
vendor/gorm.io/gorm/clause/expression.go
generated
vendored
Normal file
385
vendor/gorm.io/gorm/clause/expression.go
generated
vendored
Normal file
@@ -0,0 +1,385 @@
|
||||
package clause
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Expression expression interface
|
||||
type Expression interface {
|
||||
Build(builder Builder)
|
||||
}
|
||||
|
||||
// NegationExpressionBuilder negation expression builder
|
||||
type NegationExpressionBuilder interface {
|
||||
NegationBuild(builder Builder)
|
||||
}
|
||||
|
||||
// Expr raw expression
|
||||
type Expr struct {
|
||||
SQL string
|
||||
Vars []interface{}
|
||||
WithoutParentheses bool
|
||||
}
|
||||
|
||||
// Build build raw expression
|
||||
func (expr Expr) Build(builder Builder) {
|
||||
var (
|
||||
afterParenthesis bool
|
||||
idx int
|
||||
)
|
||||
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '?' && len(expr.Vars) > idx {
|
||||
if afterParenthesis || expr.WithoutParentheses {
|
||||
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
} else {
|
||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() == 0 {
|
||||
builder.AddVar(builder, nil)
|
||||
} else {
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
}
|
||||
default:
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
|
||||
idx++
|
||||
} else {
|
||||
if v == '(' {
|
||||
afterParenthesis = true
|
||||
} else {
|
||||
afterParenthesis = false
|
||||
}
|
||||
builder.WriteByte(v)
|
||||
}
|
||||
}
|
||||
|
||||
if idx < len(expr.Vars) {
|
||||
for _, v := range expr.Vars[idx:] {
|
||||
builder.AddVar(builder, sql.NamedArg{Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NamedExpr raw expression for named expr
|
||||
type NamedExpr struct {
|
||||
SQL string
|
||||
Vars []interface{}
|
||||
}
|
||||
|
||||
// Build build raw expression
|
||||
func (expr NamedExpr) Build(builder Builder) {
|
||||
var (
|
||||
idx int
|
||||
inName bool
|
||||
afterParenthesis bool
|
||||
namedMap = make(map[string]interface{}, len(expr.Vars))
|
||||
)
|
||||
|
||||
for _, v := range expr.Vars {
|
||||
switch value := v.(type) {
|
||||
case sql.NamedArg:
|
||||
namedMap[value.Name] = value.Value
|
||||
case map[string]interface{}:
|
||||
for k, v := range value {
|
||||
namedMap[k] = v
|
||||
}
|
||||
default:
|
||||
var appendFieldsToMap func(reflect.Value)
|
||||
appendFieldsToMap = func(reflectValue reflect.Value) {
|
||||
reflectValue = reflect.Indirect(reflectValue)
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
modelType := reflectValue.Type()
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
|
||||
|
||||
if fieldStruct.Anonymous {
|
||||
appendFieldsToMap(reflectValue.Field(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
appendFieldsToMap(reflect.ValueOf(value))
|
||||
}
|
||||
}
|
||||
|
||||
name := make([]byte, 0, 10)
|
||||
|
||||
for _, v := range []byte(expr.SQL) {
|
||||
if v == '@' && !inName {
|
||||
inName = true
|
||||
name = name[:0]
|
||||
} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' {
|
||||
if inName {
|
||||
if nv, ok := namedMap[string(name)]; ok {
|
||||
builder.AddVar(builder, nv)
|
||||
} else {
|
||||
builder.WriteByte('@')
|
||||
builder.WriteString(string(name))
|
||||
}
|
||||
inName = false
|
||||
}
|
||||
|
||||
afterParenthesis = false
|
||||
builder.WriteByte(v)
|
||||
} else if v == '?' && len(expr.Vars) > idx {
|
||||
if afterParenthesis {
|
||||
if _, ok := expr.Vars[idx].(driver.Valuer); ok {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
} else {
|
||||
switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() == 0 {
|
||||
builder.AddVar(builder, nil)
|
||||
} else {
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
}
|
||||
default:
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
builder.AddVar(builder, expr.Vars[idx])
|
||||
}
|
||||
|
||||
idx++
|
||||
} else if inName {
|
||||
name = append(name, v)
|
||||
} else {
|
||||
if v == '(' {
|
||||
afterParenthesis = true
|
||||
} else {
|
||||
afterParenthesis = false
|
||||
}
|
||||
builder.WriteByte(v)
|
||||
}
|
||||
}
|
||||
|
||||
if inName {
|
||||
if nv, ok := namedMap[string(name)]; ok {
|
||||
builder.AddVar(builder, nv)
|
||||
} else {
|
||||
builder.WriteByte('@')
|
||||
builder.WriteString(string(name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IN Whether a value is within a set of values
|
||||
type IN struct {
|
||||
Column interface{}
|
||||
Values []interface{}
|
||||
}
|
||||
|
||||
func (in IN) Build(builder Builder) {
|
||||
builder.WriteQuoted(in.Column)
|
||||
|
||||
switch len(in.Values) {
|
||||
case 0:
|
||||
builder.WriteString(" IN (NULL)")
|
||||
case 1:
|
||||
if _, ok := in.Values[0].([]interface{}); !ok {
|
||||
builder.WriteString(" = ")
|
||||
builder.AddVar(builder, in.Values[0])
|
||||
break
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
builder.WriteString(" IN (")
|
||||
builder.AddVar(builder, in.Values...)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
func (in IN) NegationBuild(builder Builder) {
|
||||
builder.WriteQuoted(in.Column)
|
||||
switch len(in.Values) {
|
||||
case 0:
|
||||
builder.WriteString(" IS NOT NULL")
|
||||
case 1:
|
||||
if _, ok := in.Values[0].([]interface{}); !ok {
|
||||
builder.WriteString(" <> ")
|
||||
builder.AddVar(builder, in.Values[0])
|
||||
break
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
builder.WriteString(" NOT IN (")
|
||||
builder.AddVar(builder, in.Values...)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
// Eq equal to for where
|
||||
type Eq struct {
|
||||
Column interface{}
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (eq Eq) Build(builder Builder) {
|
||||
builder.WriteQuoted(eq.Column)
|
||||
|
||||
switch eq.Value.(type) {
|
||||
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
||||
rv := reflect.ValueOf(eq.Value)
|
||||
if rv.Len() == 0 {
|
||||
builder.WriteString(" IN (NULL)")
|
||||
} else {
|
||||
builder.WriteString(" IN (")
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
default:
|
||||
if eqNil(eq.Value) {
|
||||
builder.WriteString(" IS NULL")
|
||||
} else {
|
||||
builder.WriteString(" = ")
|
||||
builder.AddVar(builder, eq.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (eq Eq) NegationBuild(builder Builder) {
|
||||
Neq(eq).Build(builder)
|
||||
}
|
||||
|
||||
// Neq not equal to for where
|
||||
type Neq Eq
|
||||
|
||||
func (neq Neq) Build(builder Builder) {
|
||||
builder.WriteQuoted(neq.Column)
|
||||
|
||||
switch neq.Value.(type) {
|
||||
case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
|
||||
builder.WriteString(" NOT IN (")
|
||||
rv := reflect.ValueOf(neq.Value)
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.AddVar(builder, rv.Index(i).Interface())
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
default:
|
||||
if eqNil(neq.Value) {
|
||||
builder.WriteString(" IS NOT NULL")
|
||||
} else {
|
||||
builder.WriteString(" <> ")
|
||||
builder.AddVar(builder, neq.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (neq Neq) NegationBuild(builder Builder) {
|
||||
Eq(neq).Build(builder)
|
||||
}
|
||||
|
||||
// Gt greater than for where
|
||||
type Gt Eq
|
||||
|
||||
func (gt Gt) Build(builder Builder) {
|
||||
builder.WriteQuoted(gt.Column)
|
||||
builder.WriteString(" > ")
|
||||
builder.AddVar(builder, gt.Value)
|
||||
}
|
||||
|
||||
func (gt Gt) NegationBuild(builder Builder) {
|
||||
Lte(gt).Build(builder)
|
||||
}
|
||||
|
||||
// Gte greater than or equal to for where
|
||||
type Gte Eq
|
||||
|
||||
func (gte Gte) Build(builder Builder) {
|
||||
builder.WriteQuoted(gte.Column)
|
||||
builder.WriteString(" >= ")
|
||||
builder.AddVar(builder, gte.Value)
|
||||
}
|
||||
|
||||
func (gte Gte) NegationBuild(builder Builder) {
|
||||
Lt(gte).Build(builder)
|
||||
}
|
||||
|
||||
// Lt less than for where
|
||||
type Lt Eq
|
||||
|
||||
func (lt Lt) Build(builder Builder) {
|
||||
builder.WriteQuoted(lt.Column)
|
||||
builder.WriteString(" < ")
|
||||
builder.AddVar(builder, lt.Value)
|
||||
}
|
||||
|
||||
func (lt Lt) NegationBuild(builder Builder) {
|
||||
Gte(lt).Build(builder)
|
||||
}
|
||||
|
||||
// Lte less than or equal to for where
|
||||
type Lte Eq
|
||||
|
||||
func (lte Lte) Build(builder Builder) {
|
||||
builder.WriteQuoted(lte.Column)
|
||||
builder.WriteString(" <= ")
|
||||
builder.AddVar(builder, lte.Value)
|
||||
}
|
||||
|
||||
func (lte Lte) NegationBuild(builder Builder) {
|
||||
Gt(lte).Build(builder)
|
||||
}
|
||||
|
||||
// Like whether string matches regular expression
|
||||
type Like Eq
|
||||
|
||||
func (like Like) Build(builder Builder) {
|
||||
builder.WriteQuoted(like.Column)
|
||||
builder.WriteString(" LIKE ")
|
||||
builder.AddVar(builder, like.Value)
|
||||
}
|
||||
|
||||
func (like Like) NegationBuild(builder Builder) {
|
||||
builder.WriteQuoted(like.Column)
|
||||
builder.WriteString(" NOT LIKE ")
|
||||
builder.AddVar(builder, like.Value)
|
||||
}
|
||||
|
||||
func eqNil(value interface{}) bool {
|
||||
if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) {
|
||||
value, _ = valuer.Value()
|
||||
}
|
||||
|
||||
return value == nil || eqNilReflect(value)
|
||||
}
|
||||
|
||||
func eqNilReflect(value interface{}) bool {
|
||||
reflectValue := reflect.ValueOf(value)
|
||||
return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
|
||||
}
|
||||
37
vendor/gorm.io/gorm/clause/from.go
generated
vendored
Normal file
37
vendor/gorm.io/gorm/clause/from.go
generated
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
package clause
|
||||
|
||||
// From from clause
|
||||
type From struct {
|
||||
Tables []Table
|
||||
Joins []Join
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (from From) Name() string {
|
||||
return "FROM"
|
||||
}
|
||||
|
||||
// Build build from clause
|
||||
func (from From) Build(builder Builder) {
|
||||
if len(from.Tables) > 0 {
|
||||
for idx, table := range from.Tables {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(table)
|
||||
}
|
||||
} else {
|
||||
builder.WriteQuoted(currentTable)
|
||||
}
|
||||
|
||||
for _, join := range from.Joins {
|
||||
builder.WriteByte(' ')
|
||||
join.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge from clause
|
||||
func (from From) MergeClause(clause *Clause) {
|
||||
clause.Expression = from
|
||||
}
|
||||
48
vendor/gorm.io/gorm/clause/group_by.go
generated
vendored
Normal file
48
vendor/gorm.io/gorm/clause/group_by.go
generated
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
package clause
|
||||
|
||||
// GroupBy group by clause
|
||||
type GroupBy struct {
|
||||
Columns []Column
|
||||
Having []Expression
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (groupBy GroupBy) Name() string {
|
||||
return "GROUP BY"
|
||||
}
|
||||
|
||||
// Build build group by clause
|
||||
func (groupBy GroupBy) Build(builder Builder) {
|
||||
for idx, column := range groupBy.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
|
||||
if len(groupBy.Having) > 0 {
|
||||
builder.WriteString(" HAVING ")
|
||||
Where{Exprs: groupBy.Having}.Build(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge group by clause
|
||||
func (groupBy GroupBy) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(GroupBy); ok {
|
||||
copiedColumns := make([]Column, len(v.Columns))
|
||||
copy(copiedColumns, v.Columns)
|
||||
groupBy.Columns = append(copiedColumns, groupBy.Columns...)
|
||||
|
||||
copiedHaving := make([]Expression, len(v.Having))
|
||||
copy(copiedHaving, v.Having)
|
||||
groupBy.Having = append(copiedHaving, groupBy.Having...)
|
||||
}
|
||||
clause.Expression = groupBy
|
||||
|
||||
if len(groupBy.Columns) == 0 {
|
||||
clause.Name = ""
|
||||
} else {
|
||||
clause.Name = groupBy.Name()
|
||||
}
|
||||
}
|
||||
39
vendor/gorm.io/gorm/clause/insert.go
generated
vendored
Normal file
39
vendor/gorm.io/gorm/clause/insert.go
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
package clause
|
||||
|
||||
type Insert struct {
|
||||
Table Table
|
||||
Modifier string
|
||||
}
|
||||
|
||||
// Name insert clause name
|
||||
func (insert Insert) Name() string {
|
||||
return "INSERT"
|
||||
}
|
||||
|
||||
// Build build insert clause
|
||||
func (insert Insert) Build(builder Builder) {
|
||||
if insert.Modifier != "" {
|
||||
builder.WriteString(insert.Modifier)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
builder.WriteString("INTO ")
|
||||
if insert.Table.Name == "" {
|
||||
builder.WriteQuoted(currentTable)
|
||||
} else {
|
||||
builder.WriteQuoted(insert.Table)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge insert clause
|
||||
func (insert Insert) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Insert); ok {
|
||||
if insert.Modifier == "" {
|
||||
insert.Modifier = v.Modifier
|
||||
}
|
||||
if insert.Table.Name == "" {
|
||||
insert.Table = v.Table
|
||||
}
|
||||
}
|
||||
clause.Expression = insert
|
||||
}
|
||||
47
vendor/gorm.io/gorm/clause/joins.go
generated
vendored
Normal file
47
vendor/gorm.io/gorm/clause/joins.go
generated
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
package clause
|
||||
|
||||
type JoinType string
|
||||
|
||||
const (
|
||||
CrossJoin JoinType = "CROSS"
|
||||
InnerJoin JoinType = "INNER"
|
||||
LeftJoin JoinType = "LEFT"
|
||||
RightJoin JoinType = "RIGHT"
|
||||
)
|
||||
|
||||
// Join clause for from
|
||||
type Join struct {
|
||||
Type JoinType
|
||||
Table Table
|
||||
ON Where
|
||||
Using []string
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
func (join Join) Build(builder Builder) {
|
||||
if join.Expression != nil {
|
||||
join.Expression.Build(builder)
|
||||
} else {
|
||||
if join.Type != "" {
|
||||
builder.WriteString(string(join.Type))
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
builder.WriteString("JOIN ")
|
||||
builder.WriteQuoted(join.Table)
|
||||
|
||||
if len(join.ON.Exprs) > 0 {
|
||||
builder.WriteString(" ON ")
|
||||
join.ON.Build(builder)
|
||||
} else if len(join.Using) > 0 {
|
||||
builder.WriteString(" USING (")
|
||||
for idx, c := range join.Using {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(c)
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
46
vendor/gorm.io/gorm/clause/limit.go
generated
vendored
Normal file
46
vendor/gorm.io/gorm/clause/limit.go
generated
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
package clause
|
||||
|
||||
// Limit limit clause
|
||||
type Limit struct {
|
||||
Limit *int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (limit Limit) Name() string {
|
||||
return "LIMIT"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (limit Limit) Build(builder Builder) {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.AddVar(builder, *limit.Limit)
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
builder.WriteString("OFFSET ")
|
||||
builder.AddVar(builder, limit.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (limit Limit) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
|
||||
if v, ok := clause.Expression.(Limit); ok {
|
||||
if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil {
|
||||
limit.Limit = v.Limit
|
||||
}
|
||||
|
||||
if limit.Offset == 0 && v.Offset > 0 {
|
||||
limit.Offset = v.Offset
|
||||
} else if limit.Offset < 0 {
|
||||
limit.Offset = 0
|
||||
}
|
||||
}
|
||||
|
||||
clause.Expression = limit
|
||||
}
|
||||
38
vendor/gorm.io/gorm/clause/locking.go
generated
vendored
Normal file
38
vendor/gorm.io/gorm/clause/locking.go
generated
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
package clause
|
||||
|
||||
const (
|
||||
LockingStrengthUpdate = "UPDATE"
|
||||
LockingStrengthShare = "SHARE"
|
||||
LockingOptionsSkipLocked = "SKIP LOCKED"
|
||||
LockingOptionsNoWait = "NOWAIT"
|
||||
)
|
||||
|
||||
type Locking struct {
|
||||
Strength string
|
||||
Table Table
|
||||
Options string
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (locking Locking) Name() string {
|
||||
return "FOR"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (locking Locking) Build(builder Builder) {
|
||||
builder.WriteString(locking.Strength)
|
||||
if locking.Table.Name != "" {
|
||||
builder.WriteString(" OF ")
|
||||
builder.WriteQuoted(locking.Table)
|
||||
}
|
||||
|
||||
if locking.Options != "" {
|
||||
builder.WriteByte(' ')
|
||||
builder.WriteString(locking.Options)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (locking Locking) MergeClause(clause *Clause) {
|
||||
clause.Expression = locking
|
||||
}
|
||||
59
vendor/gorm.io/gorm/clause/on_conflict.go
generated
vendored
Normal file
59
vendor/gorm.io/gorm/clause/on_conflict.go
generated
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
package clause
|
||||
|
||||
type OnConflict struct {
|
||||
Columns []Column
|
||||
Where Where
|
||||
TargetWhere Where
|
||||
OnConstraint string
|
||||
DoNothing bool
|
||||
DoUpdates Set
|
||||
UpdateAll bool
|
||||
}
|
||||
|
||||
func (OnConflict) Name() string {
|
||||
return "ON CONFLICT"
|
||||
}
|
||||
|
||||
// Build build onConflict clause
|
||||
func (onConflict OnConflict) Build(builder Builder) {
|
||||
if onConflict.OnConstraint != "" {
|
||||
builder.WriteString("ON CONSTRAINT ")
|
||||
builder.WriteString(onConflict.OnConstraint)
|
||||
builder.WriteByte(' ')
|
||||
} else {
|
||||
if len(onConflict.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range onConflict.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteString(`) `)
|
||||
}
|
||||
|
||||
if len(onConflict.TargetWhere.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.TargetWhere.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
if onConflict.DoNothing {
|
||||
builder.WriteString("DO NOTHING")
|
||||
} else {
|
||||
builder.WriteString("DO UPDATE SET ")
|
||||
onConflict.DoUpdates.Build(builder)
|
||||
}
|
||||
|
||||
if len(onConflict.Where.Exprs) > 0 {
|
||||
builder.WriteString(" WHERE ")
|
||||
onConflict.Where.Build(builder)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge onConflict clauses
|
||||
func (onConflict OnConflict) MergeClause(clause *Clause) {
|
||||
clause.Expression = onConflict
|
||||
}
|
||||
54
vendor/gorm.io/gorm/clause/order_by.go
generated
vendored
Normal file
54
vendor/gorm.io/gorm/clause/order_by.go
generated
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
package clause
|
||||
|
||||
type OrderByColumn struct {
|
||||
Column Column
|
||||
Desc bool
|
||||
Reorder bool
|
||||
}
|
||||
|
||||
type OrderBy struct {
|
||||
Columns []OrderByColumn
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (orderBy OrderBy) Name() string {
|
||||
return "ORDER BY"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (orderBy OrderBy) Build(builder Builder) {
|
||||
if orderBy.Expression != nil {
|
||||
orderBy.Expression.Build(builder)
|
||||
} else {
|
||||
for idx, column := range orderBy.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(column.Column)
|
||||
if column.Desc {
|
||||
builder.WriteString(" DESC")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (orderBy OrderBy) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(OrderBy); ok {
|
||||
for i := len(orderBy.Columns) - 1; i >= 0; i-- {
|
||||
if orderBy.Columns[i].Reorder {
|
||||
orderBy.Columns = orderBy.Columns[i:]
|
||||
clause.Expression = orderBy
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
copiedColumns := make([]OrderByColumn, len(v.Columns))
|
||||
copy(copiedColumns, v.Columns)
|
||||
orderBy.Columns = append(copiedColumns, orderBy.Columns...)
|
||||
}
|
||||
|
||||
clause.Expression = orderBy
|
||||
}
|
||||
34
vendor/gorm.io/gorm/clause/returning.go
generated
vendored
Normal file
34
vendor/gorm.io/gorm/clause/returning.go
generated
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
package clause
|
||||
|
||||
type Returning struct {
|
||||
Columns []Column
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (returning Returning) Name() string {
|
||||
return "RETURNING"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (returning Returning) Build(builder Builder) {
|
||||
if len(returning.Columns) > 0 {
|
||||
for idx, column := range returning.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
} else {
|
||||
builder.WriteByte('*')
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge order by clauses
|
||||
func (returning Returning) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Returning); ok {
|
||||
returning.Columns = append(v.Columns, returning.Columns...)
|
||||
}
|
||||
|
||||
clause.Expression = returning
|
||||
}
|
||||
59
vendor/gorm.io/gorm/clause/select.go
generated
vendored
Normal file
59
vendor/gorm.io/gorm/clause/select.go
generated
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
package clause
|
||||
|
||||
// Select select attrs when querying, updating, creating
|
||||
type Select struct {
|
||||
Distinct bool
|
||||
Columns []Column
|
||||
Expression Expression
|
||||
}
|
||||
|
||||
func (s Select) Name() string {
|
||||
return "SELECT"
|
||||
}
|
||||
|
||||
func (s Select) Build(builder Builder) {
|
||||
if len(s.Columns) > 0 {
|
||||
if s.Distinct {
|
||||
builder.WriteString("DISTINCT ")
|
||||
}
|
||||
|
||||
for idx, column := range s.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
} else {
|
||||
builder.WriteByte('*')
|
||||
}
|
||||
}
|
||||
|
||||
func (s Select) MergeClause(clause *Clause) {
|
||||
if s.Expression != nil {
|
||||
if s.Distinct {
|
||||
if expr, ok := s.Expression.(Expr); ok {
|
||||
expr.SQL = "DISTINCT " + expr.SQL
|
||||
clause.Expression = expr
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
clause.Expression = s.Expression
|
||||
} else {
|
||||
clause.Expression = s
|
||||
}
|
||||
}
|
||||
|
||||
// CommaExpression represents a group of expressions separated by commas.
|
||||
type CommaExpression struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
func (comma CommaExpression) Build(builder Builder) {
|
||||
for idx, expr := range comma.Exprs {
|
||||
if idx > 0 {
|
||||
_, _ = builder.WriteString(", ")
|
||||
}
|
||||
expr.Build(builder)
|
||||
}
|
||||
}
|
||||
60
vendor/gorm.io/gorm/clause/set.go
generated
vendored
Normal file
60
vendor/gorm.io/gorm/clause/set.go
generated
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
package clause
|
||||
|
||||
import "sort"
|
||||
|
||||
type Set []Assignment
|
||||
|
||||
type Assignment struct {
|
||||
Column Column
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (set Set) Name() string {
|
||||
return "SET"
|
||||
}
|
||||
|
||||
func (set Set) Build(builder Builder) {
|
||||
if len(set) > 0 {
|
||||
for idx, assignment := range set {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(assignment.Column)
|
||||
builder.WriteByte('=')
|
||||
builder.AddVar(builder, assignment.Value)
|
||||
}
|
||||
} else {
|
||||
builder.WriteQuoted(Column{Name: PrimaryKey})
|
||||
builder.WriteByte('=')
|
||||
builder.WriteQuoted(Column{Name: PrimaryKey})
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge assignments clauses
|
||||
func (set Set) MergeClause(clause *Clause) {
|
||||
copiedAssignments := make([]Assignment, len(set))
|
||||
copy(copiedAssignments, set)
|
||||
clause.Expression = Set(copiedAssignments)
|
||||
}
|
||||
|
||||
func Assignments(values map[string]interface{}) Set {
|
||||
keys := make([]string, 0, len(values))
|
||||
for key := range values {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
assignments := make([]Assignment, len(keys))
|
||||
for idx, key := range keys {
|
||||
assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
|
||||
}
|
||||
return assignments
|
||||
}
|
||||
|
||||
func AssignmentColumns(values []string) Set {
|
||||
assignments := make([]Assignment, len(values))
|
||||
for idx, value := range values {
|
||||
assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
|
||||
}
|
||||
return assignments
|
||||
}
|
||||
38
vendor/gorm.io/gorm/clause/update.go
generated
vendored
Normal file
38
vendor/gorm.io/gorm/clause/update.go
generated
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
package clause
|
||||
|
||||
type Update struct {
|
||||
Modifier string
|
||||
Table Table
|
||||
}
|
||||
|
||||
// Name update clause name
|
||||
func (update Update) Name() string {
|
||||
return "UPDATE"
|
||||
}
|
||||
|
||||
// Build build update clause
|
||||
func (update Update) Build(builder Builder) {
|
||||
if update.Modifier != "" {
|
||||
builder.WriteString(update.Modifier)
|
||||
builder.WriteByte(' ')
|
||||
}
|
||||
|
||||
if update.Table.Name == "" {
|
||||
builder.WriteQuoted(currentTable)
|
||||
} else {
|
||||
builder.WriteQuoted(update.Table)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge update clause
|
||||
func (update Update) MergeClause(clause *Clause) {
|
||||
if v, ok := clause.Expression.(Update); ok {
|
||||
if update.Modifier == "" {
|
||||
update.Modifier = v.Modifier
|
||||
}
|
||||
if update.Table.Name == "" {
|
||||
update.Table = v.Table
|
||||
}
|
||||
}
|
||||
clause.Expression = update
|
||||
}
|
||||
45
vendor/gorm.io/gorm/clause/values.go
generated
vendored
Normal file
45
vendor/gorm.io/gorm/clause/values.go
generated
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
package clause
|
||||
|
||||
type Values struct {
|
||||
Columns []Column
|
||||
Values [][]interface{}
|
||||
}
|
||||
|
||||
// Name from clause name
|
||||
func (Values) Name() string {
|
||||
return "VALUES"
|
||||
}
|
||||
|
||||
// Build build from clause
|
||||
func (values Values) Build(builder Builder) {
|
||||
if len(values.Columns) > 0 {
|
||||
builder.WriteByte('(')
|
||||
for idx, column := range values.Columns {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
builder.WriteQuoted(column)
|
||||
}
|
||||
builder.WriteByte(')')
|
||||
|
||||
builder.WriteString(" VALUES ")
|
||||
|
||||
for idx, value := range values.Values {
|
||||
if idx > 0 {
|
||||
builder.WriteByte(',')
|
||||
}
|
||||
|
||||
builder.WriteByte('(')
|
||||
builder.AddVar(builder, value...)
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
} else {
|
||||
builder.WriteString("DEFAULT VALUES")
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge values clauses
|
||||
func (values Values) MergeClause(clause *Clause) {
|
||||
clause.Name = ""
|
||||
clause.Expression = values
|
||||
}
|
||||
201
vendor/gorm.io/gorm/clause/where.go
generated
vendored
Normal file
201
vendor/gorm.io/gorm/clause/where.go
generated
vendored
Normal file
@@ -0,0 +1,201 @@
|
||||
package clause
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
AndWithSpace = " AND "
|
||||
OrWithSpace = " OR "
|
||||
)
|
||||
|
||||
// Where where clause
|
||||
type Where struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
// Name where clause name
|
||||
func (where Where) Name() string {
|
||||
return "WHERE"
|
||||
}
|
||||
|
||||
// Build build where clause
|
||||
func (where Where) Build(builder Builder) {
|
||||
if len(where.Exprs) == 1 {
|
||||
if andCondition, ok := where.Exprs[0].(AndConditions); ok {
|
||||
where.Exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
|
||||
// Switch position if the first query expression is a single Or condition
|
||||
for idx, expr := range where.Exprs {
|
||||
if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
|
||||
if idx != 0 {
|
||||
where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
buildExprs(where.Exprs, builder, AndWithSpace)
|
||||
}
|
||||
|
||||
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
|
||||
wrapInParentheses := false
|
||||
|
||||
for idx, expr := range exprs {
|
||||
if idx > 0 {
|
||||
if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
|
||||
builder.WriteString(OrWithSpace)
|
||||
} else {
|
||||
builder.WriteString(joinCond)
|
||||
}
|
||||
}
|
||||
|
||||
if len(exprs) > 1 {
|
||||
switch v := expr.(type) {
|
||||
case OrConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||
}
|
||||
}
|
||||
case AndConditions:
|
||||
if len(v.Exprs) == 1 {
|
||||
if e, ok := v.Exprs[0].(Expr); ok {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||
}
|
||||
}
|
||||
case Expr:
|
||||
sql := strings.ToUpper(v.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||
case NamedExpr:
|
||||
sql := strings.ToUpper(v.SQL)
|
||||
wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
|
||||
}
|
||||
}
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
expr.Build(builder)
|
||||
builder.WriteByte(')')
|
||||
wrapInParentheses = false
|
||||
} else {
|
||||
expr.Build(builder)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MergeClause merge where clauses
|
||||
func (where Where) MergeClause(clause *Clause) {
|
||||
if w, ok := clause.Expression.(Where); ok {
|
||||
exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
|
||||
copy(exprs, w.Exprs)
|
||||
copy(exprs[len(w.Exprs):], where.Exprs)
|
||||
where.Exprs = exprs
|
||||
}
|
||||
|
||||
clause.Expression = where
|
||||
}
|
||||
|
||||
func And(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(exprs) == 1 {
|
||||
if _, ok := exprs[0].(OrConditions); !ok {
|
||||
return exprs[0]
|
||||
}
|
||||
}
|
||||
|
||||
return AndConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
type AndConditions struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
func (and AndConditions) Build(builder Builder) {
|
||||
if len(and.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
buildExprs(and.Exprs, builder, AndWithSpace)
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
buildExprs(and.Exprs, builder, AndWithSpace)
|
||||
}
|
||||
}
|
||||
|
||||
func Or(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return OrConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
type OrConditions struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
func (or OrConditions) Build(builder Builder) {
|
||||
if len(or.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
buildExprs(or.Exprs, builder, OrWithSpace)
|
||||
builder.WriteByte(')')
|
||||
} else {
|
||||
buildExprs(or.Exprs, builder, OrWithSpace)
|
||||
}
|
||||
}
|
||||
|
||||
func Not(exprs ...Expression) Expression {
|
||||
if len(exprs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(exprs) == 1 {
|
||||
if andCondition, ok := exprs[0].(AndConditions); ok {
|
||||
exprs = andCondition.Exprs
|
||||
}
|
||||
}
|
||||
return NotConditions{Exprs: exprs}
|
||||
}
|
||||
|
||||
type NotConditions struct {
|
||||
Exprs []Expression
|
||||
}
|
||||
|
||||
func (not NotConditions) Build(builder Builder) {
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
|
||||
for idx, c := range not.Exprs {
|
||||
if idx > 0 {
|
||||
builder.WriteString(AndWithSpace)
|
||||
}
|
||||
|
||||
if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
|
||||
negationBuilder.NegationBuild(builder)
|
||||
} else {
|
||||
builder.WriteString("NOT ")
|
||||
e, wrapInParentheses := c.(Expr)
|
||||
if wrapInParentheses {
|
||||
sql := strings.ToUpper(e.SQL)
|
||||
if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
|
||||
builder.WriteByte('(')
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
|
||||
if wrapInParentheses {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(not.Exprs) > 1 {
|
||||
builder.WriteByte(')')
|
||||
}
|
||||
}
|
||||
3
vendor/gorm.io/gorm/clause/with.go
generated
vendored
Normal file
3
vendor/gorm.io/gorm/clause/with.go
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
package clause
|
||||
|
||||
type With struct{}
|
||||
52
vendor/gorm.io/gorm/errors.go
generated
vendored
Normal file
52
vendor/gorm.io/gorm/errors.go
generated
vendored
Normal file
@@ -0,0 +1,52 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrRecordNotFound record not found error
|
||||
ErrRecordNotFound = logger.ErrRecordNotFound
|
||||
// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
|
||||
ErrInvalidTransaction = errors.New("invalid transaction")
|
||||
// ErrNotImplemented not implemented
|
||||
ErrNotImplemented = errors.New("not implemented")
|
||||
// ErrMissingWhereClause missing where clause
|
||||
ErrMissingWhereClause = errors.New("WHERE conditions required")
|
||||
// ErrUnsupportedRelation unsupported relations
|
||||
ErrUnsupportedRelation = errors.New("unsupported relations")
|
||||
// ErrPrimaryKeyRequired primary keys required
|
||||
ErrPrimaryKeyRequired = errors.New("primary key required")
|
||||
// ErrModelValueRequired model value required
|
||||
ErrModelValueRequired = errors.New("model value required")
|
||||
// ErrModelAccessibleFieldsRequired model accessible fields required
|
||||
ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
|
||||
// ErrSubQueryRequired sub query required
|
||||
ErrSubQueryRequired = errors.New("sub query required")
|
||||
// ErrInvalidData unsupported data
|
||||
ErrInvalidData = errors.New("unsupported data")
|
||||
// ErrUnsupportedDriver unsupported driver
|
||||
ErrUnsupportedDriver = errors.New("unsupported driver")
|
||||
// ErrRegistered registered
|
||||
ErrRegistered = errors.New("registered")
|
||||
// ErrInvalidField invalid field
|
||||
ErrInvalidField = errors.New("invalid field")
|
||||
// ErrEmptySlice empty slice found
|
||||
ErrEmptySlice = errors.New("empty slice found")
|
||||
// ErrDryRunModeUnsupported dry run mode unsupported
|
||||
ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
|
||||
// ErrInvalidDB invalid db
|
||||
ErrInvalidDB = errors.New("invalid db")
|
||||
// ErrInvalidValue invalid value
|
||||
ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
|
||||
// ErrInvalidValueOfLength invalid values do not match length
|
||||
ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
|
||||
// ErrPreloadNotAllowed preload is not allowed when count is used
|
||||
ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
|
||||
// ErrDuplicatedKey occurs when there is a unique key constraint violation
|
||||
ErrDuplicatedKey = errors.New("duplicated key not allowed")
|
||||
// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
|
||||
ErrForeignKeyViolated = errors.New("violates foreign key constraint")
|
||||
)
|
||||
770
vendor/gorm.io/gorm/finisher_api.go
generated
vendored
Normal file
770
vendor/gorm.io/gorm/finisher_api.go
generated
vendored
Normal file
@@ -0,0 +1,770 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Create inserts value, returning the inserted data's primary key in value's id
|
||||
func (db *DB) Create(value interface{}) (tx *DB) {
|
||||
if db.CreateBatchSize > 0 {
|
||||
return db.CreateInBatches(value, db.CreateBatchSize)
|
||||
}
|
||||
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
return tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
|
||||
// CreateInBatches inserts value in batches of batchSize
|
||||
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var rowsAffected int64
|
||||
tx = db.getInstance()
|
||||
|
||||
// the reflection length judgment of the optimized value
|
||||
reflectLen := reflectValue.Len()
|
||||
|
||||
callFc := func(tx *DB) error {
|
||||
for i := 0; i < reflectLen; i += batchSize {
|
||||
ends := i + batchSize
|
||||
if ends > reflectLen {
|
||||
ends = reflectLen
|
||||
}
|
||||
|
||||
subtx := tx.getInstance()
|
||||
subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
|
||||
subtx.callbacks.Create().Execute(subtx)
|
||||
if subtx.Error != nil {
|
||||
return subtx.Error
|
||||
}
|
||||
rowsAffected += subtx.RowsAffected
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if tx.SkipDefaultTransaction || reflectLen <= batchSize {
|
||||
tx.AddError(callFc(tx.Session(&Session{})))
|
||||
} else {
|
||||
tx.AddError(tx.Transaction(callFc))
|
||||
}
|
||||
|
||||
tx.RowsAffected = rowsAffected
|
||||
default:
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
tx = tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
|
||||
func (db *DB) Save(value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = value
|
||||
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflect.Indirect(reflectValue)
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
|
||||
tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
|
||||
}
|
||||
tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
|
||||
case reflect.Struct:
|
||||
if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
|
||||
for _, pf := range tx.Statement.Schema.PrimaryFields {
|
||||
if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
|
||||
return tx.callbacks.Create().Execute(tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fallthrough
|
||||
default:
|
||||
selectedUpdate := len(tx.Statement.Selects) != 0
|
||||
// when updating, use all fields including those zero-value fields
|
||||
if !selectedUpdate {
|
||||
tx.Statement.Selects = append(tx.Statement.Selects, "*")
|
||||
}
|
||||
|
||||
updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
|
||||
|
||||
if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
|
||||
return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
|
||||
}
|
||||
|
||||
return updateTx
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// First finds the first record ordered by primary key, matching given conditions conds
|
||||
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Take finds the first record returned by the database in no specified order, matching given conditions conds
|
||||
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1)
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Last finds the last record ordered by primary key, matching given conditions conds
|
||||
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
Desc: true,
|
||||
})
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.RaiseErrorOnNotFound = true
|
||||
tx.Statement.Dest = dest
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// Find finds all records matching given conditions conds
|
||||
func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
// FindInBatches finds all records in batches of batchSize
|
||||
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
|
||||
var (
|
||||
tx = db.Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
}).Session(&Session{})
|
||||
queryDB = tx
|
||||
rowsAffected int64
|
||||
batch int
|
||||
)
|
||||
|
||||
// user specified offset or limit
|
||||
var totalSize int
|
||||
if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
if limit.Limit != nil {
|
||||
totalSize = *limit.Limit
|
||||
}
|
||||
|
||||
if totalSize > 0 && batchSize > totalSize {
|
||||
batchSize = totalSize
|
||||
}
|
||||
|
||||
// reset to offset to 0 in next batch
|
||||
tx = tx.Offset(-1).Session(&Session{})
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
result := queryDB.Limit(batchSize).Find(dest)
|
||||
rowsAffected += result.RowsAffected
|
||||
batch++
|
||||
|
||||
if result.Error == nil && result.RowsAffected != 0 {
|
||||
fcTx := result.Session(&Session{NewDB: true})
|
||||
fcTx.RowsAffected = result.RowsAffected
|
||||
tx.AddError(fc(fcTx, batch))
|
||||
} else if result.Error != nil {
|
||||
tx.AddError(result.Error)
|
||||
}
|
||||
|
||||
if tx.Error != nil || int(result.RowsAffected) < batchSize {
|
||||
break
|
||||
}
|
||||
|
||||
if totalSize > 0 {
|
||||
if totalSize <= int(rowsAffected) {
|
||||
break
|
||||
}
|
||||
if totalSize/batchSize == batch {
|
||||
batchSize = totalSize % batchSize
|
||||
}
|
||||
}
|
||||
|
||||
// Optimize for-break
|
||||
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
|
||||
if result.Statement.Schema.PrioritizedPrimaryField == nil {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
}
|
||||
|
||||
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
|
||||
if zero {
|
||||
tx.AddError(ErrPrimaryKeyRequired)
|
||||
break
|
||||
}
|
||||
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
|
||||
}
|
||||
|
||||
tx.RowsAffected = rowsAffected
|
||||
return tx
|
||||
}
|
||||
|
||||
func (db *DB) assignInterfacesToValue(values ...interface{}) {
|
||||
for _, value := range values {
|
||||
switch v := value.(type) {
|
||||
case []clause.Expression:
|
||||
for _, expr := range v {
|
||||
if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
if field := db.Statement.Schema.LookUpField(column); field != nil {
|
||||
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
case clause.Column:
|
||||
if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
|
||||
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
|
||||
}
|
||||
}
|
||||
} else if andCond, ok := expr.(clause.AndConditions); ok {
|
||||
db.assignInterfacesToValue(andCond.Exprs)
|
||||
}
|
||||
}
|
||||
case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
|
||||
if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
|
||||
db.assignInterfacesToValue(exprs)
|
||||
}
|
||||
default:
|
||||
if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(value))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, f := range s.Fields {
|
||||
if f.Readable {
|
||||
if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
|
||||
if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
|
||||
db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if len(values) > 0 {
|
||||
if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
|
||||
db.assignInterfacesToValue(exprs)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// FirstOrInit never modifies the database. It is often used with Assign and Attrs.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
queryTx := db.Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
|
||||
if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
|
||||
if c, ok := tx.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
tx.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.attrs) > 0 {
|
||||
tx.assignInterfacesToValue(tx.Statement.attrs...)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(tx.Statement.assigns) > 0 {
|
||||
tx.assignInterfacesToValue(tx.Statement.assigns...)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
|
||||
// Each conds must be a struct or map.
|
||||
//
|
||||
// Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
|
||||
//
|
||||
// // assign an email if the record is not found
|
||||
// result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "non_existing", Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
//
|
||||
// // assign email regardless of if record is found
|
||||
// result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
|
||||
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
|
||||
// // result.RowsAffected -> 1
|
||||
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
|
||||
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
|
||||
})
|
||||
|
||||
result := queryTx.Find(dest, conds...)
|
||||
if result.Error != nil {
|
||||
tx.Error = result.Error
|
||||
return tx
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
if c, ok := result.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok {
|
||||
result.assignInterfacesToValue(where.Exprs)
|
||||
}
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.attrs) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.attrs...)
|
||||
}
|
||||
|
||||
// initialize with attrs, conds
|
||||
if len(db.Statement.assigns) > 0 {
|
||||
result.assignInterfacesToValue(db.Statement.assigns...)
|
||||
}
|
||||
|
||||
return tx.Create(dest)
|
||||
} else if len(db.Statement.assigns) > 0 {
|
||||
exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
|
||||
assigns := map[string]interface{}{}
|
||||
for i := 0; i < len(exprs); i++ {
|
||||
expr := exprs[i]
|
||||
|
||||
if eq, ok := expr.(clause.AndConditions); ok {
|
||||
exprs = append(exprs, eq.Exprs...)
|
||||
} else if eq, ok := expr.(clause.Eq); ok {
|
||||
switch column := eq.Column.(type) {
|
||||
case string:
|
||||
assigns[column] = eq.Value
|
||||
case clause.Column:
|
||||
assigns[column.Name] = eq.Value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Model(dest).Updates(assigns)
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (db *DB) Update(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
// Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
|
||||
func (db *DB) Updates(values interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = values
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = map[string]interface{}{column: value}
|
||||
tx.Statement.SkipHooks = true
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.Dest = values
|
||||
tx.Statement.SkipHooks = true
|
||||
return tx.callbacks.Update().Execute(tx)
|
||||
}
|
||||
|
||||
// Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
|
||||
// value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
|
||||
// time if null.
|
||||
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if len(conds) > 0 {
|
||||
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
|
||||
tx.Statement.AddClause(clause.Where{Exprs: exprs})
|
||||
}
|
||||
}
|
||||
tx.Statement.Dest = value
|
||||
return tx.callbacks.Delete().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) Count(count *int64) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model == nil {
|
||||
tx.Statement.Model = tx.Statement.Dest
|
||||
defer func() {
|
||||
tx.Statement.Model = nil
|
||||
}()
|
||||
}
|
||||
|
||||
if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
|
||||
defer func() {
|
||||
tx.Statement.Clauses["SELECT"] = selectClause
|
||||
}()
|
||||
} else {
|
||||
defer delete(tx.Statement.Clauses, "SELECT")
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) == 0 {
|
||||
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
|
||||
} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
|
||||
expr := clause.Expr{SQL: "count(*)"}
|
||||
|
||||
if len(tx.Statement.Selects) == 1 {
|
||||
dbName := tx.Statement.Selects[0]
|
||||
fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
|
||||
if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
|
||||
dbName = f.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if tx.Statement.Distinct {
|
||||
expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
} else if dbName != "*" {
|
||||
expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.AddClause(clause.Select{Expression: expr})
|
||||
}
|
||||
|
||||
if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
|
||||
if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
|
||||
delete(tx.Statement.Clauses, "ORDER BY")
|
||||
defer func() {
|
||||
tx.Statement.Clauses["ORDER BY"] = orderByClause
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
tx.Statement.Dest = count
|
||||
tx = tx.callbacks.Query().Execute(tx)
|
||||
|
||||
if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
|
||||
*count = tx.RowsAffected
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (db *DB) Row() *sql.Row {
|
||||
tx := db.getInstance().Set("rows", false)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
row, ok := tx.Statement.Dest.(*sql.Row)
|
||||
if !ok && tx.DryRun {
|
||||
db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
func (db *DB) Rows() (*sql.Rows, error) {
|
||||
tx := db.getInstance().Set("rows", true)
|
||||
tx = tx.callbacks.Row().Execute(tx)
|
||||
rows, ok := tx.Statement.Dest.(*sql.Rows)
|
||||
if !ok && tx.DryRun && tx.Error == nil {
|
||||
tx.Error = ErrDryRunModeUnsupported
|
||||
}
|
||||
return rows, tx.Error
|
||||
}
|
||||
|
||||
// Scan scans selected value to the struct dest
|
||||
func (db *DB) Scan(dest interface{}) (tx *DB) {
|
||||
config := *db.Config
|
||||
currentLogger, newLogger := config.Logger, logger.Recorder.New()
|
||||
config.Logger = newLogger
|
||||
|
||||
tx = db.getInstance()
|
||||
tx.Config = &config
|
||||
|
||||
if rows, err := tx.Rows(); err == nil {
|
||||
if rows.Next() {
|
||||
tx.ScanRows(rows, dest)
|
||||
} else {
|
||||
tx.RowsAffected = 0
|
||||
tx.AddError(rows.Err())
|
||||
}
|
||||
tx.AddError(rows.Close())
|
||||
}
|
||||
|
||||
currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
|
||||
return newLogger.SQL, tx.RowsAffected
|
||||
}, tx.Error)
|
||||
tx.Logger = currentLogger
|
||||
return
|
||||
}
|
||||
|
||||
// Pluck queries a single column from a model, returning in the slice dest. E.g.:
|
||||
//
|
||||
// var ages []int64
|
||||
// db.Model(&users).Pluck("age", &ages)
|
||||
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
if tx.Statement.Model != nil {
|
||||
if tx.Statement.Parse(tx.Statement.Model) == nil {
|
||||
if f := tx.Statement.Schema.LookUpField(column); f != nil {
|
||||
column = f.DBName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(tx.Statement.Selects) != 1 {
|
||||
fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
|
||||
tx.Statement.AddClauseIfNotExists(clause.Select{
|
||||
Distinct: tx.Statement.Distinct,
|
||||
Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
|
||||
})
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
return tx.callbacks.Query().Execute(tx)
|
||||
}
|
||||
|
||||
func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
|
||||
tx := db.getInstance()
|
||||
if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
|
||||
tx.AddError(err)
|
||||
}
|
||||
tx.Statement.Dest = dest
|
||||
tx.Statement.ReflectValue = reflect.ValueOf(dest)
|
||||
for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
|
||||
elem := tx.Statement.ReflectValue.Elem()
|
||||
if !elem.IsValid() {
|
||||
elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
|
||||
tx.Statement.ReflectValue.Set(elem)
|
||||
}
|
||||
tx.Statement.ReflectValue = elem
|
||||
}
|
||||
Scan(rows, tx, ScanInitialized)
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
|
||||
// returned to the connection pool.
|
||||
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
|
||||
if db.Error != nil {
|
||||
return db.Error
|
||||
}
|
||||
|
||||
tx := db.getInstance()
|
||||
sqlDB, err := tx.DB()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := sqlDB.Conn(tx.Statement.Context)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
tx.Statement.ConnPool = conn
|
||||
return fc(tx)
|
||||
}
|
||||
|
||||
// Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
|
||||
// arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
|
||||
// they are rolled back.
|
||||
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
|
||||
panicked := true
|
||||
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
// nested transaction
|
||||
if !db.DisableNestedTransaction {
|
||||
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
db.RollbackTo(fmt.Sprintf("sp%p", fc))
|
||||
}
|
||||
}()
|
||||
}
|
||||
err = fc(db.Session(&Session{NewDB: db.clone == 1}))
|
||||
} else {
|
||||
tx := db.Begin(opts...)
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if panicked || err != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = fc(tx); err == nil {
|
||||
panicked = false
|
||||
return tx.Commit().Error
|
||||
}
|
||||
}
|
||||
|
||||
panicked = false
|
||||
return
|
||||
}
|
||||
|
||||
// Begin begins a transaction with any transaction options opts
|
||||
func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
|
||||
var (
|
||||
// clone statement
|
||||
tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
|
||||
opt *sql.TxOptions
|
||||
err error
|
||||
)
|
||||
|
||||
if len(opts) > 0 {
|
||||
opt = opts[0]
|
||||
}
|
||||
|
||||
switch beginner := tx.Statement.ConnPool.(type) {
|
||||
case TxBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
case ConnPoolBeginner:
|
||||
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
|
||||
default:
|
||||
err = ErrInvalidTransaction
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
tx.AddError(err)
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// Commit commits the changes in a transaction
|
||||
func (db *DB) Commit() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Commit())
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Rollback rollbacks the changes in a transaction
|
||||
func (db *DB) Rollback() *DB {
|
||||
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
|
||||
if !reflect.ValueOf(committer).IsNil() {
|
||||
db.AddError(committer.Rollback())
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrInvalidTransaction)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) SavePoint(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.SavePoint(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) RollbackTo(name string) *DB {
|
||||
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
|
||||
// close prepared statement, because RollbackTo not support prepared statement.
|
||||
// e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
|
||||
var (
|
||||
preparedStmtTx *PreparedStmtTX
|
||||
isPreparedStmtTx bool
|
||||
)
|
||||
// close prepared statement, because SavePoint not support prepared statement.
|
||||
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx.Tx
|
||||
}
|
||||
db.AddError(savePointer.RollbackTo(db, name))
|
||||
// restore prepared statement
|
||||
if isPreparedStmtTx {
|
||||
db.Statement.ConnPool = preparedStmtTx
|
||||
}
|
||||
} else {
|
||||
db.AddError(ErrUnsupportedDriver)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// Exec executes raw sql
|
||||
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
tx.Statement.SQL = strings.Builder{}
|
||||
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
|
||||
}
|
||||
|
||||
return tx.callbacks.Raw().Execute(tx)
|
||||
}
|
||||
506
vendor/gorm.io/gorm/gorm.go
generated
vendored
Normal file
506
vendor/gorm.io/gorm/gorm.go
generated
vendored
Normal file
@@ -0,0 +1,506 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// for Config.cacheStore store PreparedStmtDB key
|
||||
const preparedStmtDBKey = "preparedStmt"
|
||||
|
||||
// Config GORM config
|
||||
type Config struct {
|
||||
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
|
||||
// You can disable it by setting `SkipDefaultTransaction` to true
|
||||
SkipDefaultTransaction bool
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
NamingStrategy schema.Namer
|
||||
// FullSaveAssociations full save associations
|
||||
FullSaveAssociations bool
|
||||
// Logger
|
||||
Logger logger.Interface
|
||||
// NowFunc the function to be used when creating a new timestamp
|
||||
NowFunc func() time.Time
|
||||
// DryRun generate sql without execute
|
||||
DryRun bool
|
||||
// PrepareStmt executes the given query in cached statement
|
||||
PrepareStmt bool
|
||||
// DisableAutomaticPing
|
||||
DisableAutomaticPing bool
|
||||
// DisableForeignKeyConstraintWhenMigrating
|
||||
DisableForeignKeyConstraintWhenMigrating bool
|
||||
// IgnoreRelationshipsWhenMigrating
|
||||
IgnoreRelationshipsWhenMigrating bool
|
||||
// DisableNestedTransaction disable nested transaction
|
||||
DisableNestedTransaction bool
|
||||
// AllowGlobalUpdate allow global update
|
||||
AllowGlobalUpdate bool
|
||||
// QueryFields executes the SQL query with all fields of the table
|
||||
QueryFields bool
|
||||
// CreateBatchSize default create batch size
|
||||
CreateBatchSize int
|
||||
// TranslateError enabling error translation
|
||||
TranslateError bool
|
||||
|
||||
// ClauseBuilders clause builder
|
||||
ClauseBuilders map[string]clause.ClauseBuilder
|
||||
// ConnPool db conn pool
|
||||
ConnPool ConnPool
|
||||
// Dialector database dialector
|
||||
Dialector
|
||||
// Plugins registered plugins
|
||||
Plugins map[string]Plugin
|
||||
|
||||
callbacks *callbacks
|
||||
cacheStore *sync.Map
|
||||
}
|
||||
|
||||
// Apply update config to new config
|
||||
func (c *Config) Apply(config *Config) error {
|
||||
if config != c {
|
||||
*config = *c
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterInitialize initialize plugins after db connected
|
||||
func (c *Config) AfterInitialize(db *DB) error {
|
||||
if db != nil {
|
||||
for _, plugin := range c.Plugins {
|
||||
if err := plugin.Initialize(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Option gorm option interface
|
||||
type Option interface {
|
||||
Apply(*Config) error
|
||||
AfterInitialize(*DB) error
|
||||
}
|
||||
|
||||
// DB GORM DB definition
|
||||
type DB struct {
|
||||
*Config
|
||||
Error error
|
||||
RowsAffected int64
|
||||
Statement *Statement
|
||||
clone int
|
||||
}
|
||||
|
||||
// Session session config when create session with Session() method
|
||||
type Session struct {
|
||||
DryRun bool
|
||||
PrepareStmt bool
|
||||
NewDB bool
|
||||
Initialized bool
|
||||
SkipHooks bool
|
||||
SkipDefaultTransaction bool
|
||||
DisableNestedTransaction bool
|
||||
AllowGlobalUpdate bool
|
||||
FullSaveAssociations bool
|
||||
QueryFields bool
|
||||
Context context.Context
|
||||
Logger logger.Interface
|
||||
NowFunc func() time.Time
|
||||
CreateBatchSize int
|
||||
}
|
||||
|
||||
// Open initialize db session based on dialector
|
||||
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
|
||||
config := &Config{}
|
||||
|
||||
sort.Slice(opts, func(i, j int) bool {
|
||||
_, isConfig := opts[i].(*Config)
|
||||
_, isConfig2 := opts[j].(*Config)
|
||||
return isConfig && !isConfig2
|
||||
})
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
if applyErr := opt.Apply(config); applyErr != nil {
|
||||
return nil, applyErr
|
||||
}
|
||||
defer func(opt Option) {
|
||||
if errr := opt.AfterInitialize(db); errr != nil {
|
||||
err = errr
|
||||
}
|
||||
}(opt)
|
||||
}
|
||||
}
|
||||
|
||||
if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
|
||||
if err = d.Apply(config); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if config.NamingStrategy == nil {
|
||||
config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = logger.Default
|
||||
}
|
||||
|
||||
if config.NowFunc == nil {
|
||||
config.NowFunc = func() time.Time { return time.Now().Local() }
|
||||
}
|
||||
|
||||
if dialector != nil {
|
||||
config.Dialector = dialector
|
||||
}
|
||||
|
||||
if config.Plugins == nil {
|
||||
config.Plugins = map[string]Plugin{}
|
||||
}
|
||||
|
||||
if config.cacheStore == nil {
|
||||
config.cacheStore = &sync.Map{}
|
||||
}
|
||||
|
||||
db = &DB{Config: config, clone: 1}
|
||||
|
||||
db.callbacks = initializeCallbacks(db)
|
||||
|
||||
if config.ClauseBuilders == nil {
|
||||
config.ClauseBuilders = map[string]clause.ClauseBuilder{}
|
||||
}
|
||||
|
||||
if config.Dialector != nil {
|
||||
err = config.Dialector.Initialize(db)
|
||||
|
||||
if err != nil {
|
||||
if db, _ := db.DB(); db != nil {
|
||||
_ = db.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
preparedStmt := NewPreparedStmtDB(db.ConnPool)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
db.ConnPool = preparedStmt
|
||||
}
|
||||
|
||||
db.Statement = &Statement{
|
||||
DB: db,
|
||||
ConnPool: db.ConnPool,
|
||||
Context: context.Background(),
|
||||
Clauses: map[string]clause.Clause{},
|
||||
}
|
||||
|
||||
if err == nil && !config.DisableAutomaticPing {
|
||||
if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
|
||||
err = pinger.Ping()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Session create new db session
|
||||
func (db *DB) Session(config *Session) *DB {
|
||||
var (
|
||||
txConfig = *db.Config
|
||||
tx = &DB{
|
||||
Config: &txConfig,
|
||||
Statement: db.Statement,
|
||||
Error: db.Error,
|
||||
clone: 1,
|
||||
}
|
||||
)
|
||||
if config.CreateBatchSize > 0 {
|
||||
tx.Config.CreateBatchSize = config.CreateBatchSize
|
||||
}
|
||||
|
||||
if config.SkipDefaultTransaction {
|
||||
tx.Config.SkipDefaultTransaction = true
|
||||
}
|
||||
|
||||
if config.AllowGlobalUpdate {
|
||||
txConfig.AllowGlobalUpdate = true
|
||||
}
|
||||
|
||||
if config.FullSaveAssociations {
|
||||
txConfig.FullSaveAssociations = true
|
||||
}
|
||||
|
||||
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
|
||||
tx.Statement = tx.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
|
||||
if config.Context != nil {
|
||||
tx.Statement.Context = config.Context
|
||||
}
|
||||
|
||||
if config.PrepareStmt {
|
||||
var preparedStmt *PreparedStmtDB
|
||||
|
||||
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
|
||||
preparedStmt = v.(*PreparedStmtDB)
|
||||
} else {
|
||||
preparedStmt = NewPreparedStmtDB(db.ConnPool)
|
||||
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
|
||||
}
|
||||
|
||||
switch t := tx.Statement.ConnPool.(type) {
|
||||
case Tx:
|
||||
tx.Statement.ConnPool = &PreparedStmtTX{
|
||||
Tx: t,
|
||||
PreparedStmtDB: preparedStmt,
|
||||
}
|
||||
default:
|
||||
tx.Statement.ConnPool = &PreparedStmtDB{
|
||||
ConnPool: db.Config.ConnPool,
|
||||
Mux: preparedStmt.Mux,
|
||||
Stmts: preparedStmt.Stmts,
|
||||
}
|
||||
}
|
||||
txConfig.ConnPool = tx.Statement.ConnPool
|
||||
txConfig.PrepareStmt = true
|
||||
}
|
||||
|
||||
if config.SkipHooks {
|
||||
tx.Statement.SkipHooks = true
|
||||
}
|
||||
|
||||
if config.DisableNestedTransaction {
|
||||
txConfig.DisableNestedTransaction = true
|
||||
}
|
||||
|
||||
if !config.NewDB {
|
||||
tx.clone = 2
|
||||
}
|
||||
|
||||
if config.DryRun {
|
||||
tx.Config.DryRun = true
|
||||
}
|
||||
|
||||
if config.QueryFields {
|
||||
tx.Config.QueryFields = true
|
||||
}
|
||||
|
||||
if config.Logger != nil {
|
||||
tx.Config.Logger = config.Logger
|
||||
}
|
||||
|
||||
if config.NowFunc != nil {
|
||||
tx.Config.NowFunc = config.NowFunc
|
||||
}
|
||||
|
||||
if config.Initialized {
|
||||
tx = tx.getInstance()
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
// WithContext change current instance db's context to ctx
|
||||
func (db *DB) WithContext(ctx context.Context) *DB {
|
||||
return db.Session(&Session{Context: ctx})
|
||||
}
|
||||
|
||||
// Debug start debug mode
|
||||
func (db *DB) Debug() (tx *DB) {
|
||||
tx = db.getInstance()
|
||||
return tx.Session(&Session{
|
||||
Logger: db.Logger.LogMode(logger.Info),
|
||||
})
|
||||
}
|
||||
|
||||
// Set store value with key into current db instance's context
|
||||
func (db *DB) Set(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.Settings.Store(key, value)
|
||||
return tx
|
||||
}
|
||||
|
||||
// Get get value with key from current db instance's context
|
||||
func (db *DB) Get(key string) (interface{}, bool) {
|
||||
return db.Statement.Settings.Load(key)
|
||||
}
|
||||
|
||||
// InstanceSet store value with key into current db instance's context
|
||||
func (db *DB) InstanceSet(key string, value interface{}) *DB {
|
||||
tx := db.getInstance()
|
||||
tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
|
||||
return tx
|
||||
}
|
||||
|
||||
// InstanceGet get value with key from current db instance's context
|
||||
func (db *DB) InstanceGet(key string) (interface{}, bool) {
|
||||
return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
|
||||
}
|
||||
|
||||
// Callback returns callback manager
|
||||
func (db *DB) Callback() *callbacks {
|
||||
return db.callbacks
|
||||
}
|
||||
|
||||
// AddError add error to db
|
||||
func (db *DB) AddError(err error) error {
|
||||
if err != nil {
|
||||
if db.Config.TranslateError {
|
||||
if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
|
||||
err = errTranslator.Translate(err)
|
||||
}
|
||||
}
|
||||
|
||||
if db.Error == nil {
|
||||
db.Error = err
|
||||
} else {
|
||||
db.Error = fmt.Errorf("%v; %w", db.Error, err)
|
||||
}
|
||||
}
|
||||
return db.Error
|
||||
}
|
||||
|
||||
// DB returns `*sql.DB`
|
||||
func (db *DB) DB() (*sql.DB, error) {
|
||||
connPool := db.ConnPool
|
||||
if db.Statement != nil && db.Statement.ConnPool != nil {
|
||||
connPool = db.Statement.ConnPool
|
||||
}
|
||||
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
|
||||
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
|
||||
return sqldb, err
|
||||
}
|
||||
}
|
||||
|
||||
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
func (db *DB) getInstance() *DB {
|
||||
if db.clone > 0 {
|
||||
tx := &DB{Config: db.Config, Error: db.Error}
|
||||
|
||||
if db.clone == 1 {
|
||||
// clone with new statement
|
||||
tx.Statement = &Statement{
|
||||
DB: tx,
|
||||
ConnPool: db.Statement.ConnPool,
|
||||
Context: db.Statement.Context,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Vars: make([]interface{}, 0, 8),
|
||||
SkipHooks: db.Statement.SkipHooks,
|
||||
}
|
||||
} else {
|
||||
// with clone statement
|
||||
tx.Statement = db.Statement.clone()
|
||||
tx.Statement.DB = tx
|
||||
}
|
||||
|
||||
return tx
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// Expr returns clause.Expr, which can be used to pass SQL expression as params
|
||||
func Expr(expr string, args ...interface{}) clause.Expr {
|
||||
return clause.Expr{SQL: expr, Vars: args}
|
||||
}
|
||||
|
||||
// SetupJoinTable setup join table schema
|
||||
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
|
||||
var (
|
||||
tx = db.getInstance()
|
||||
stmt = tx.Statement
|
||||
modelSchema, joinSchema *schema.Schema
|
||||
)
|
||||
|
||||
err := stmt.Parse(model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelSchema = stmt.Schema
|
||||
|
||||
err = stmt.Parse(joinTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
joinSchema = stmt.Schema
|
||||
|
||||
relation, ok := modelSchema.Relationships.Relations[field]
|
||||
isRelation := ok && relation.JoinTable != nil
|
||||
if !isRelation {
|
||||
return fmt.Errorf("failed to find relation: %s", field)
|
||||
}
|
||||
|
||||
for _, ref := range relation.References {
|
||||
f := joinSchema.LookUpField(ref.ForeignKey.DBName)
|
||||
if f == nil {
|
||||
return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
f.DataType = ref.ForeignKey.DataType
|
||||
f.GORMDataType = ref.ForeignKey.GORMDataType
|
||||
if f.Size == 0 {
|
||||
f.Size = ref.ForeignKey.Size
|
||||
}
|
||||
ref.ForeignKey = f
|
||||
}
|
||||
|
||||
for name, rel := range relation.JoinTable.Relationships.Relations {
|
||||
if _, ok := joinSchema.Relationships.Relations[name]; !ok {
|
||||
rel.Schema = joinSchema
|
||||
joinSchema.Relationships.Relations[name] = rel
|
||||
}
|
||||
}
|
||||
relation.JoinTable = joinSchema
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use use plugin
|
||||
func (db *DB) Use(plugin Plugin) error {
|
||||
name := plugin.Name()
|
||||
if _, ok := db.Plugins[name]; ok {
|
||||
return ErrRegistered
|
||||
}
|
||||
if err := plugin.Initialize(db); err != nil {
|
||||
return err
|
||||
}
|
||||
db.Plugins[name] = plugin
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToSQL for generate SQL string.
|
||||
//
|
||||
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
|
||||
// .Limit(10).Offset(5)
|
||||
// .Order("name ASC")
|
||||
// .First(&User{})
|
||||
// })
|
||||
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
|
||||
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
|
||||
stmt := tx.Statement
|
||||
|
||||
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
|
||||
}
|
||||
92
vendor/gorm.io/gorm/interfaces.go
generated
vendored
Normal file
92
vendor/gorm.io/gorm/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Dialector GORM database dialector
|
||||
type Dialector interface {
|
||||
Name() string
|
||||
Initialize(*DB) error
|
||||
Migrator(db *DB) Migrator
|
||||
DataTypeOf(*schema.Field) string
|
||||
DefaultValueOf(*schema.Field) clause.Expression
|
||||
BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
|
||||
QuoteTo(clause.Writer, string)
|
||||
Explain(sql string, vars ...interface{}) string
|
||||
}
|
||||
|
||||
// Plugin GORM plugin interface
|
||||
type Plugin interface {
|
||||
Name() string
|
||||
Initialize(*DB) error
|
||||
}
|
||||
|
||||
type ParamsFilter interface {
|
||||
ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
|
||||
}
|
||||
|
||||
// ConnPool db conns pool interface
|
||||
type ConnPool interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// SavePointerDialectorInterface save pointer interface
|
||||
type SavePointerDialectorInterface interface {
|
||||
SavePoint(tx *DB, name string) error
|
||||
RollbackTo(tx *DB, name string) error
|
||||
}
|
||||
|
||||
// TxBeginner tx beginner
|
||||
type TxBeginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// ConnPoolBeginner conn pool beginner
|
||||
type ConnPoolBeginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
|
||||
}
|
||||
|
||||
// TxCommitter tx committer
|
||||
type TxCommitter interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// Tx sql.Tx interface
|
||||
type Tx interface {
|
||||
ConnPool
|
||||
TxCommitter
|
||||
StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
|
||||
}
|
||||
|
||||
// Valuer gorm valuer interface
|
||||
type Valuer interface {
|
||||
GormValue(context.Context, *DB) clause.Expr
|
||||
}
|
||||
|
||||
// GetDBConnector SQL db connector
|
||||
type GetDBConnector interface {
|
||||
GetDBConn() (*sql.DB, error)
|
||||
}
|
||||
|
||||
// Rows rows interface
|
||||
type Rows interface {
|
||||
Columns() ([]string, error)
|
||||
ColumnTypes() ([]*sql.ColumnType, error)
|
||||
Next() bool
|
||||
Scan(dest ...interface{}) error
|
||||
Err() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type ErrorTranslator interface {
|
||||
Translate(err error) error
|
||||
}
|
||||
213
vendor/gorm.io/gorm/logger/logger.go
generated
vendored
Normal file
213
vendor/gorm.io/gorm/logger/logger.go
generated
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// ErrRecordNotFound record not found error
|
||||
var ErrRecordNotFound = errors.New("record not found")
|
||||
|
||||
// Colors
|
||||
const (
|
||||
Reset = "\033[0m"
|
||||
Red = "\033[31m"
|
||||
Green = "\033[32m"
|
||||
Yellow = "\033[33m"
|
||||
Blue = "\033[34m"
|
||||
Magenta = "\033[35m"
|
||||
Cyan = "\033[36m"
|
||||
White = "\033[37m"
|
||||
BlueBold = "\033[34;1m"
|
||||
MagentaBold = "\033[35;1m"
|
||||
RedBold = "\033[31;1m"
|
||||
YellowBold = "\033[33;1m"
|
||||
)
|
||||
|
||||
// LogLevel log level
|
||||
type LogLevel int
|
||||
|
||||
const (
|
||||
// Silent silent log level
|
||||
Silent LogLevel = iota + 1
|
||||
// Error error log level
|
||||
Error
|
||||
// Warn warn log level
|
||||
Warn
|
||||
// Info info log level
|
||||
Info
|
||||
)
|
||||
|
||||
// Writer log writer interface
|
||||
type Writer interface {
|
||||
Printf(string, ...interface{})
|
||||
}
|
||||
|
||||
// Config logger config
|
||||
type Config struct {
|
||||
SlowThreshold time.Duration
|
||||
Colorful bool
|
||||
IgnoreRecordNotFoundError bool
|
||||
ParameterizedQueries bool
|
||||
LogLevel LogLevel
|
||||
}
|
||||
|
||||
// Interface logger interface
|
||||
type Interface interface {
|
||||
LogMode(LogLevel) Interface
|
||||
Info(context.Context, string, ...interface{})
|
||||
Warn(context.Context, string, ...interface{})
|
||||
Error(context.Context, string, ...interface{})
|
||||
Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error)
|
||||
}
|
||||
|
||||
var (
|
||||
// Discard logger will print any log to io.Discard
|
||||
Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{})
|
||||
// Default Default logger
|
||||
Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
LogLevel: Warn,
|
||||
IgnoreRecordNotFoundError: false,
|
||||
Colorful: true,
|
||||
})
|
||||
// Recorder logger records running SQL into a recorder instance
|
||||
Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()}
|
||||
)
|
||||
|
||||
// New initialize logger
|
||||
func New(writer Writer, config Config) Interface {
|
||||
var (
|
||||
infoStr = "%s\n[info] "
|
||||
warnStr = "%s\n[warn] "
|
||||
errStr = "%s\n[error] "
|
||||
traceStr = "%s\n[%.3fms] [rows:%v] %s"
|
||||
traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s"
|
||||
traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
|
||||
)
|
||||
|
||||
if config.Colorful {
|
||||
infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
|
||||
warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
|
||||
errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
|
||||
traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||
traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
|
||||
traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
|
||||
}
|
||||
|
||||
return &logger{
|
||||
Writer: writer,
|
||||
Config: config,
|
||||
infoStr: infoStr,
|
||||
warnStr: warnStr,
|
||||
errStr: errStr,
|
||||
traceStr: traceStr,
|
||||
traceWarnStr: traceWarnStr,
|
||||
traceErrStr: traceErrStr,
|
||||
}
|
||||
}
|
||||
|
||||
type logger struct {
|
||||
Writer
|
||||
Config
|
||||
infoStr, warnStr, errStr string
|
||||
traceStr, traceErrStr, traceWarnStr string
|
||||
}
|
||||
|
||||
// LogMode log mode
|
||||
func (l *logger) LogMode(level LogLevel) Interface {
|
||||
newlogger := *l
|
||||
newlogger.LogLevel = level
|
||||
return &newlogger
|
||||
}
|
||||
|
||||
// Info print info
|
||||
func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Info {
|
||||
l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Warn print warn messages
|
||||
func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Warn {
|
||||
l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Error print error messages
|
||||
func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= Error {
|
||||
l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Trace print sql message
|
||||
//
|
||||
//nolint:cyclop
|
||||
func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
if l.LogLevel <= Silent {
|
||||
return
|
||||
}
|
||||
|
||||
elapsed := time.Since(begin)
|
||||
switch {
|
||||
case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn:
|
||||
sql, rows := fc()
|
||||
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
case l.LogLevel == Info:
|
||||
sql, rows := fc()
|
||||
if rows == -1 {
|
||||
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql)
|
||||
} else {
|
||||
l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ParamsFilter filter params
|
||||
func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) {
|
||||
if l.Config.ParameterizedQueries {
|
||||
return sql, nil
|
||||
}
|
||||
return sql, params
|
||||
}
|
||||
|
||||
type traceRecorder struct {
|
||||
Interface
|
||||
BeginAt time.Time
|
||||
SQL string
|
||||
RowsAffected int64
|
||||
Err error
|
||||
}
|
||||
|
||||
// New trace recorder
|
||||
func (l *traceRecorder) New() *traceRecorder {
|
||||
return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()}
|
||||
}
|
||||
|
||||
// Trace implement logger interface
|
||||
func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||||
l.BeginAt = begin
|
||||
l.SQL, l.RowsAffected = fc()
|
||||
l.Err = err
|
||||
}
|
||||
162
vendor/gorm.io/gorm/logger/sql.go
generated
vendored
Normal file
162
vendor/gorm.io/gorm/logger/sql.go
generated
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
tmFmtWithMS = "2006-01-02 15:04:05.999"
|
||||
tmFmtZero = "0000-00-00 00:00:00"
|
||||
nullStr = "NULL"
|
||||
)
|
||||
|
||||
func isPrintable(s string) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsPrint(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// A list of Go types that should be converted to SQL primitives
|
||||
var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
|
||||
|
||||
// RegEx matches only numeric values
|
||||
var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
|
||||
|
||||
// ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability
|
||||
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
|
||||
var (
|
||||
convertParams func(interface{}, int)
|
||||
vars = make([]string, len(avars))
|
||||
)
|
||||
|
||||
convertParams = func(v interface{}, idx int) {
|
||||
switch v := v.(type) {
|
||||
case bool:
|
||||
vars[idx] = strconv.FormatBool(v)
|
||||
case time.Time:
|
||||
if v.IsZero() {
|
||||
vars[idx] = escaper + tmFmtZero + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
|
||||
}
|
||||
case *time.Time:
|
||||
if v != nil {
|
||||
if v.IsZero() {
|
||||
vars[idx] = escaper + tmFmtZero + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper
|
||||
}
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
case driver.Valuer:
|
||||
reflectValue := reflect.ValueOf(v)
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
r, _ := v.Value()
|
||||
convertParams(r, idx)
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
case fmt.Stringer:
|
||||
reflectValue := reflect.ValueOf(v)
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
vars[idx] = fmt.Sprintf("%d", reflectValue.Interface())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface())
|
||||
case reflect.Bool:
|
||||
vars[idx] = fmt.Sprintf("%t", reflectValue.Interface())
|
||||
case reflect.String:
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
default:
|
||||
if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = nullStr
|
||||
}
|
||||
}
|
||||
case []byte:
|
||||
if s := string(v); isPrintable(s) {
|
||||
vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
|
||||
} else {
|
||||
vars[idx] = escaper + "<binary>" + escaper
|
||||
}
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
vars[idx] = utils.ToString(v)
|
||||
case float32:
|
||||
vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
|
||||
case float64:
|
||||
vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case string:
|
||||
vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper
|
||||
default:
|
||||
rv := reflect.ValueOf(v)
|
||||
if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
vars[idx] = nullStr
|
||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
convertParams(v, idx)
|
||||
} else if rv.Kind() == reflect.Ptr && !rv.IsZero() {
|
||||
convertParams(reflect.Indirect(rv).Interface(), idx)
|
||||
} else {
|
||||
for _, t := range convertibleTypes {
|
||||
if rv.Type().ConvertibleTo(t) {
|
||||
convertParams(rv.Convert(t).Interface(), idx)
|
||||
return
|
||||
}
|
||||
}
|
||||
vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for idx, v := range avars {
|
||||
convertParams(v, idx)
|
||||
}
|
||||
|
||||
if numericPlaceholder == nil {
|
||||
var idx int
|
||||
var newSQL strings.Builder
|
||||
|
||||
for _, v := range []byte(sql) {
|
||||
if v == '?' {
|
||||
if len(vars) > idx {
|
||||
newSQL.WriteString(vars[idx])
|
||||
idx++
|
||||
continue
|
||||
}
|
||||
}
|
||||
newSQL.WriteByte(v)
|
||||
}
|
||||
|
||||
sql = newSQL.String()
|
||||
} else {
|
||||
sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$")
|
||||
|
||||
sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string {
|
||||
num := v[1 : len(v)-1]
|
||||
n, _ := strconv.Atoi(num)
|
||||
|
||||
// position var start from 1 ($1, $2)
|
||||
n -= 1
|
||||
if n >= 0 && n <= len(vars)-1 {
|
||||
return vars[n]
|
||||
}
|
||||
return v
|
||||
})
|
||||
}
|
||||
|
||||
return sql
|
||||
}
|
||||
111
vendor/gorm.io/gorm/migrator.go
generated
vendored
Normal file
111
vendor/gorm.io/gorm/migrator.go
generated
vendored
Normal file
@@ -0,0 +1,111 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Migrator returns migrator
|
||||
func (db *DB) Migrator() Migrator {
|
||||
tx := db.getInstance()
|
||||
|
||||
// apply scopes to migrator
|
||||
for len(tx.Statement.scopes) > 0 {
|
||||
tx = tx.executeScopes()
|
||||
}
|
||||
|
||||
return tx.Dialector.Migrator(tx.Session(&Session{}))
|
||||
}
|
||||
|
||||
// AutoMigrate run auto migration for given models
|
||||
func (db *DB) AutoMigrate(dst ...interface{}) error {
|
||||
return db.Migrator().AutoMigrate(dst...)
|
||||
}
|
||||
|
||||
// ViewOption view option
|
||||
type ViewOption struct {
|
||||
Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE`
|
||||
CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
|
||||
Query *DB // required subquery.
|
||||
}
|
||||
|
||||
// ColumnType column type interface
|
||||
type ColumnType interface {
|
||||
Name() string
|
||||
DatabaseTypeName() string // varchar
|
||||
ColumnType() (columnType string, ok bool) // varchar(64)
|
||||
PrimaryKey() (isPrimaryKey bool, ok bool)
|
||||
AutoIncrement() (isAutoIncrement bool, ok bool)
|
||||
Length() (length int64, ok bool)
|
||||
DecimalSize() (precision int64, scale int64, ok bool)
|
||||
Nullable() (nullable bool, ok bool)
|
||||
Unique() (unique bool, ok bool)
|
||||
ScanType() reflect.Type
|
||||
Comment() (value string, ok bool)
|
||||
DefaultValue() (value string, ok bool)
|
||||
}
|
||||
|
||||
type Index interface {
|
||||
Table() string
|
||||
Name() string
|
||||
Columns() []string
|
||||
PrimaryKey() (isPrimaryKey bool, ok bool)
|
||||
Unique() (unique bool, ok bool)
|
||||
Option() string
|
||||
}
|
||||
|
||||
// TableType table type interface
|
||||
type TableType interface {
|
||||
Schema() string
|
||||
Name() string
|
||||
Type() string
|
||||
Comment() (comment string, ok bool)
|
||||
}
|
||||
|
||||
// Migrator migrator interface
|
||||
type Migrator interface {
|
||||
// AutoMigrate
|
||||
AutoMigrate(dst ...interface{}) error
|
||||
|
||||
// Database
|
||||
CurrentDatabase() string
|
||||
FullDataTypeOf(*schema.Field) clause.Expr
|
||||
GetTypeAliases(databaseTypeName string) []string
|
||||
|
||||
// Tables
|
||||
CreateTable(dst ...interface{}) error
|
||||
DropTable(dst ...interface{}) error
|
||||
HasTable(dst interface{}) bool
|
||||
RenameTable(oldName, newName interface{}) error
|
||||
GetTables() (tableList []string, err error)
|
||||
TableType(dst interface{}) (TableType, error)
|
||||
|
||||
// Columns
|
||||
AddColumn(dst interface{}, field string) error
|
||||
DropColumn(dst interface{}, field string) error
|
||||
AlterColumn(dst interface{}, field string) error
|
||||
MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||
// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
|
||||
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
|
||||
HasColumn(dst interface{}, field string) bool
|
||||
RenameColumn(dst interface{}, oldName, field string) error
|
||||
ColumnTypes(dst interface{}) ([]ColumnType, error)
|
||||
|
||||
// Views
|
||||
CreateView(name string, option ViewOption) error
|
||||
DropView(name string) error
|
||||
|
||||
// Constraints
|
||||
CreateConstraint(dst interface{}, name string) error
|
||||
DropConstraint(dst interface{}, name string) error
|
||||
HasConstraint(dst interface{}, name string) bool
|
||||
|
||||
// Indexes
|
||||
CreateIndex(dst interface{}, name string) error
|
||||
DropIndex(dst interface{}, name string) error
|
||||
HasIndex(dst interface{}, name string) bool
|
||||
RenameIndex(dst interface{}, oldName, newName string) error
|
||||
GetIndexes(dst interface{}) ([]Index, error)
|
||||
}
|
||||
107
vendor/gorm.io/gorm/migrator/column_type.go
generated
vendored
Normal file
107
vendor/gorm.io/gorm/migrator/column_type.go
generated
vendored
Normal file
@@ -0,0 +1,107 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// ColumnType column type implements ColumnType interface
|
||||
type ColumnType struct {
|
||||
SQLColumnType *sql.ColumnType
|
||||
NameValue sql.NullString
|
||||
DataTypeValue sql.NullString
|
||||
ColumnTypeValue sql.NullString
|
||||
PrimaryKeyValue sql.NullBool
|
||||
UniqueValue sql.NullBool
|
||||
AutoIncrementValue sql.NullBool
|
||||
LengthValue sql.NullInt64
|
||||
DecimalSizeValue sql.NullInt64
|
||||
ScaleValue sql.NullInt64
|
||||
NullableValue sql.NullBool
|
||||
ScanTypeValue reflect.Type
|
||||
CommentValue sql.NullString
|
||||
DefaultValueValue sql.NullString
|
||||
}
|
||||
|
||||
// Name returns the name or alias of the column.
|
||||
func (ct ColumnType) Name() string {
|
||||
if ct.NameValue.Valid {
|
||||
return ct.NameValue.String
|
||||
}
|
||||
return ct.SQLColumnType.Name()
|
||||
}
|
||||
|
||||
// DatabaseTypeName returns the database system name of the column type. If an empty
|
||||
// string is returned, then the driver type name is not supported.
|
||||
// Consult your driver documentation for a list of driver data types. Length specifiers
|
||||
// are not included.
|
||||
// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
|
||||
// "INT", and "BIGINT".
|
||||
func (ct ColumnType) DatabaseTypeName() string {
|
||||
if ct.DataTypeValue.Valid {
|
||||
return ct.DataTypeValue.String
|
||||
}
|
||||
return ct.SQLColumnType.DatabaseTypeName()
|
||||
}
|
||||
|
||||
// ColumnType returns the database type of the column. like `varchar(16)`
|
||||
func (ct ColumnType) ColumnType() (columnType string, ok bool) {
|
||||
return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
|
||||
}
|
||||
|
||||
// PrimaryKey returns the column is primary key or not.
|
||||
func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
|
||||
return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid
|
||||
}
|
||||
|
||||
// AutoIncrement returns the column is auto increment or not.
|
||||
func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
|
||||
return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
|
||||
}
|
||||
|
||||
// Length returns the column type length for variable length column types
|
||||
func (ct ColumnType) Length() (length int64, ok bool) {
|
||||
if ct.LengthValue.Valid {
|
||||
return ct.LengthValue.Int64, true
|
||||
}
|
||||
return ct.SQLColumnType.Length()
|
||||
}
|
||||
|
||||
// DecimalSize returns the scale and precision of a decimal type.
|
||||
func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
|
||||
if ct.DecimalSizeValue.Valid {
|
||||
return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
|
||||
}
|
||||
return ct.SQLColumnType.DecimalSize()
|
||||
}
|
||||
|
||||
// Nullable reports whether the column may be null.
|
||||
func (ct ColumnType) Nullable() (nullable bool, ok bool) {
|
||||
if ct.NullableValue.Valid {
|
||||
return ct.NullableValue.Bool, true
|
||||
}
|
||||
return ct.SQLColumnType.Nullable()
|
||||
}
|
||||
|
||||
// Unique reports whether the column may be unique.
|
||||
func (ct ColumnType) Unique() (unique bool, ok bool) {
|
||||
return ct.UniqueValue.Bool, ct.UniqueValue.Valid
|
||||
}
|
||||
|
||||
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
|
||||
func (ct ColumnType) ScanType() reflect.Type {
|
||||
if ct.ScanTypeValue != nil {
|
||||
return ct.ScanTypeValue
|
||||
}
|
||||
return ct.SQLColumnType.ScanType()
|
||||
}
|
||||
|
||||
// Comment returns the comment of current column.
|
||||
func (ct ColumnType) Comment() (value string, ok bool) {
|
||||
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||
}
|
||||
|
||||
// DefaultValue returns the default value of current column.
|
||||
func (ct ColumnType) DefaultValue() (value string, ok bool) {
|
||||
return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
|
||||
}
|
||||
43
vendor/gorm.io/gorm/migrator/index.go
generated
vendored
Normal file
43
vendor/gorm.io/gorm/migrator/index.go
generated
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
package migrator
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// Index implements gorm.Index interface
|
||||
type Index struct {
|
||||
TableName string
|
||||
NameValue string
|
||||
ColumnList []string
|
||||
PrimaryKeyValue sql.NullBool
|
||||
UniqueValue sql.NullBool
|
||||
OptionValue string
|
||||
}
|
||||
|
||||
// Table return the table name of the index.
|
||||
func (idx Index) Table() string {
|
||||
return idx.TableName
|
||||
}
|
||||
|
||||
// Name return the name of the index.
|
||||
func (idx Index) Name() string {
|
||||
return idx.NameValue
|
||||
}
|
||||
|
||||
// Columns return the columns of the index
|
||||
func (idx Index) Columns() []string {
|
||||
return idx.ColumnList
|
||||
}
|
||||
|
||||
// PrimaryKey returns the index is primary key or not.
|
||||
func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) {
|
||||
return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid
|
||||
}
|
||||
|
||||
// Unique returns whether the index is unique or not.
|
||||
func (idx Index) Unique() (unique bool, ok bool) {
|
||||
return idx.UniqueValue.Bool, idx.UniqueValue.Valid
|
||||
}
|
||||
|
||||
// Option return the optional attribute of the index
|
||||
func (idx Index) Option() string {
|
||||
return idx.OptionValue
|
||||
}
|
||||
989
vendor/gorm.io/gorm/migrator/migrator.go
generated
vendored
Normal file
989
vendor/gorm.io/gorm/migrator/migrator.go
generated
vendored
Normal file
@@ -0,0 +1,989 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*),
|
||||
// with a possible trailing non-digit character (\D?).
|
||||
|
||||
// For example, values that can pass this regular expression are:
|
||||
// - "123"
|
||||
// - "abc456"
|
||||
// -"%$#@789"
|
||||
var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`)
|
||||
|
||||
// TODO:? Create const vars for raw sql queries ?
|
||||
|
||||
var _ gorm.Migrator = (*Migrator)(nil)
|
||||
|
||||
// Migrator m struct
|
||||
type Migrator struct {
|
||||
Config
|
||||
}
|
||||
|
||||
// Config schema config
|
||||
type Config struct {
|
||||
CreateIndexAfterCreateTable bool
|
||||
DB *gorm.DB
|
||||
gorm.Dialector
|
||||
}
|
||||
|
||||
type printSQLLogger struct {
|
||||
logger.Interface
|
||||
}
|
||||
|
||||
func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
sql, _ := fc()
|
||||
fmt.Println(sql + ";")
|
||||
l.Interface.Trace(ctx, begin, fc, err)
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDBDataType(*gorm.DB, *schema.Field) string
|
||||
}
|
||||
|
||||
// RunWithValue run migration with statement value
|
||||
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if m.DB.Statement != nil {
|
||||
stmt.Table = m.DB.Statement.Table
|
||||
stmt.TableExpr = m.DB.Statement.TableExpr
|
||||
}
|
||||
|
||||
if table, ok := value.(string); ok {
|
||||
stmt.Table = table
|
||||
} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fc(stmt)
|
||||
}
|
||||
|
||||
// DataTypeOf return field's db data type
|
||||
func (m Migrator) DataTypeOf(field *schema.Field) string {
|
||||
fieldValue := reflect.New(field.IndirectFieldType)
|
||||
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
|
||||
if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" {
|
||||
return dataType
|
||||
}
|
||||
}
|
||||
|
||||
return m.Dialector.DataTypeOf(field)
|
||||
}
|
||||
|
||||
// FullDataTypeOf returns field's db full data type
|
||||
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
|
||||
expr.SQL = m.DataTypeOf(field)
|
||||
|
||||
if field.NotNull {
|
||||
expr.SQL += " NOT NULL"
|
||||
}
|
||||
|
||||
if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
|
||||
if field.DefaultValueInterface != nil {
|
||||
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
|
||||
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
|
||||
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
|
||||
} else if field.DefaultValue != "(-)" {
|
||||
expr.SQL += " DEFAULT " + field.DefaultValue
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) {
|
||||
queryTx = m.DB.Session(&gorm.Session{})
|
||||
execTx = queryTx
|
||||
if m.DB.DryRun {
|
||||
queryTx.DryRun = false
|
||||
execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}})
|
||||
}
|
||||
return queryTx, execTx
|
||||
}
|
||||
|
||||
// AutoMigrate auto migrate values
|
||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
queryTx, execTx := m.GetQueryAndExecTx()
|
||||
if !queryTx.Migrator().HasTable(value) {
|
||||
if err := execTx.Migrator().CreateTable(value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
columnTypes, err := queryTx.Migrator().ColumnTypes(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
parseIndexes = stmt.Schema.ParseIndexes()
|
||||
parseCheckConstraints = stmt.Schema.ParseCheckConstraints()
|
||||
)
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
var foundColumn gorm.ColumnType
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
if columnType.Name() == dbName {
|
||||
foundColumn = columnType
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err = execTx.Migrator().AddColumn(value, dbName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// found, smartly migrate
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil &&
|
||||
constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range parseCheckConstraints {
|
||||
if !queryTx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range parseIndexes {
|
||||
if !queryTx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTables returns tables
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
|
||||
Scan(&tableList).Error
|
||||
return
|
||||
}
|
||||
|
||||
// CreateTable create table in database for values
|
||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, false) {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
createTableSQL = "CREATE TABLE ? ("
|
||||
values = []interface{}{m.CurrentTable(stmt)}
|
||||
hasPrimaryKeyInDataType bool
|
||||
)
|
||||
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
if !field.IgnoreMigration {
|
||||
createTableSQL += "? ?"
|
||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY")
|
||||
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
||||
createTableSQL += ","
|
||||
}
|
||||
}
|
||||
|
||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||
createTableSQL += "PRIMARY KEY ?,"
|
||||
primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields))
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
values = append(values, primaryKeys)
|
||||
}
|
||||
|
||||
for _, idx := range stmt.Schema.ParseIndexes() {
|
||||
if m.CreateIndexAfterCreateTable {
|
||||
defer func(value interface{}, name string) {
|
||||
if err == nil {
|
||||
err = tx.Migrator().CreateIndex(value, name)
|
||||
}
|
||||
}(value, idx.Name)
|
||||
} else {
|
||||
if idx.Class != "" {
|
||||
createTableSQL += idx.Class + " "
|
||||
}
|
||||
createTableSQL += "INDEX ? ?"
|
||||
|
||||
if idx.Comment != "" {
|
||||
createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
|
||||
}
|
||||
|
||||
if idx.Option != "" {
|
||||
createTableSQL += " " + idx.Option
|
||||
}
|
||||
|
||||
createTableSQL += ","
|
||||
values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||
}
|
||||
}
|
||||
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
sql, vars := constraint.Build()
|
||||
createTableSQL += sql + ","
|
||||
values = append(values, vars...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, uni := range stmt.Schema.ParseUniqueConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? UNIQUE (?),"
|
||||
values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)})
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||
}
|
||||
|
||||
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
|
||||
|
||||
createTableSQL += ")"
|
||||
|
||||
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
|
||||
createTableSQL += fmt.Sprint(tableOption)
|
||||
}
|
||||
|
||||
err = tx.Exec(createTableSQL, values...).Error
|
||||
return err
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropTable drop table for values
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
values = m.ReorderModels(values, false)
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasTable returns table exists or not for value, value could be a struct or string
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int64
|
||||
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// RenameTable rename table from oldName to newName
|
||||
func (m Migrator) RenameTable(oldName, newName interface{}) error {
|
||||
var oldTable, newTable interface{}
|
||||
if v, ok := oldName.(string); ok {
|
||||
oldTable = clause.Table{Name: v}
|
||||
} else {
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if err := stmt.Parse(oldName); err == nil {
|
||||
oldTable = m.CurrentTable(stmt)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := newName.(string); ok {
|
||||
newTable = clause.Table{Name: v}
|
||||
} else {
|
||||
stmt := &gorm.Statement{DB: m.DB}
|
||||
if err := stmt.Parse(newName); err == nil {
|
||||
newTable = m.CurrentTable(stmt)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
|
||||
}
|
||||
|
||||
// AddColumn create `name` column for value
|
||||
func (m Migrator) AddColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// avoid using the same name field
|
||||
f := stmt.Schema.LookUpField(name)
|
||||
if f == nil {
|
||||
return fmt.Errorf("failed to look up field with name: %s", name)
|
||||
}
|
||||
|
||||
if !f.IgnoreMigration {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ADD ? ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f),
|
||||
).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DropColumn drop value's `name` column
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
// AlterColumn alter value's `field` column' type based on schema definition
|
||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
fileType := m.FullDataTypeOf(field)
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? ALTER COLUMN ? TYPE ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
|
||||
).Error
|
||||
|
||||
}
|
||||
return fmt.Errorf("failed to look up field with name: %s", field)
|
||||
})
|
||||
}
|
||||
|
||||
// HasColumn check has column `field` for value or not
|
||||
func (m Migrator) HasColumn(value interface{}, field string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
name := field
|
||||
if field := stmt.Schema.LookUpField(field); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
|
||||
currentDatabase, stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// RenameColumn rename value's field name from oldName to newName
|
||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if field := stmt.Schema.LookUpField(oldName); field != nil {
|
||||
oldName = field.DBName
|
||||
}
|
||||
|
||||
if field := stmt.Schema.LookUpField(newName); field != nil {
|
||||
newName = field.DBName
|
||||
}
|
||||
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? RENAME COLUMN ? TO ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
// MigrateColumn migrate column
|
||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
if field.IgnoreMigration {
|
||||
return nil
|
||||
}
|
||||
|
||||
// found, smart migrate
|
||||
fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL))
|
||||
realDataType := strings.ToLower(columnType.DatabaseTypeName())
|
||||
|
||||
var (
|
||||
alterColumn bool
|
||||
isSameType = fullDataType == realDataType
|
||||
)
|
||||
|
||||
if !field.PrimaryKey {
|
||||
// check type
|
||||
if !strings.HasPrefix(fullDataType, realDataType) {
|
||||
// check type aliases
|
||||
aliases := m.DB.Migrator().GetTypeAliases(realDataType)
|
||||
for _, alias := range aliases {
|
||||
if strings.HasPrefix(fullDataType, alias) {
|
||||
isSameType = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !isSameType {
|
||||
// check size
|
||||
if length, ok := columnType.Length(); length != int64(field.Size) {
|
||||
if length > 0 && field.Size > 0 {
|
||||
alterColumn = true
|
||||
} else {
|
||||
// has size in data type and not equal
|
||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||
if !field.PrimaryKey &&
|
||||
(len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check precision
|
||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check nullable
|
||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
||||
// not primary key & database is nullable
|
||||
if !field.PrimaryKey && nullable {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
// check default value
|
||||
if !field.PrimaryKey {
|
||||
currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL"))
|
||||
dv, dvNotNull := columnType.DefaultValue()
|
||||
if dvNotNull && !currentDefaultNotNull {
|
||||
// default value -> null
|
||||
alterColumn = true
|
||||
} else if !dvNotNull && currentDefaultNotNull {
|
||||
// null -> default value
|
||||
alterColumn = true
|
||||
} else if (field.GORMDataType != schema.Time && dv != field.DefaultValue) ||
|
||||
(field.GORMDataType == schema.Time && !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()"))) {
|
||||
// default value not equal
|
||||
// not both null
|
||||
if currentDefaultNotNull || dvNotNull {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check comment
|
||||
if comment, ok := columnType.Comment(); ok && comment != field.Comment {
|
||||
// not primary key
|
||||
if !field.PrimaryKey {
|
||||
alterColumn = true
|
||||
}
|
||||
}
|
||||
|
||||
if alterColumn {
|
||||
if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
|
||||
unique, ok := columnType.Unique()
|
||||
if !ok || field.PrimaryKey {
|
||||
return nil // skip primary key
|
||||
}
|
||||
// By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex.
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
// We're currently only receiving boolean values on `Unique` tag,
|
||||
// so the UniqueConstraint name is fixed
|
||||
constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
|
||||
if unique && !field.Unique {
|
||||
return m.DB.Migrator().DropConstraint(value, constraint)
|
||||
}
|
||||
if !unique && field.Unique {
|
||||
return m.DB.Migrator().CreateConstraint(value, constraint)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = rows.Close()
|
||||
}()
|
||||
|
||||
var rawColumnTypes []*sql.ColumnType
|
||||
rawColumnTypes, err = rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
|
||||
}
|
||||
|
||||
return
|
||||
})
|
||||
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
// CreateView create view from Query in gorm.ViewOption.
|
||||
// Query in gorm.ViewOption is a [subquery]
|
||||
//
|
||||
// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20
|
||||
// q := DB.Model(&User{}).Where("age > ?", 20)
|
||||
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q})
|
||||
//
|
||||
// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION
|
||||
// q := DB.Model(&User{})
|
||||
// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"})
|
||||
//
|
||||
// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery
|
||||
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
||||
if option.Query == nil {
|
||||
return gorm.ErrSubQueryRequired
|
||||
}
|
||||
|
||||
sql := new(strings.Builder)
|
||||
sql.WriteString("CREATE ")
|
||||
if option.Replace {
|
||||
sql.WriteString("OR REPLACE ")
|
||||
}
|
||||
sql.WriteString("VIEW ")
|
||||
m.QuoteTo(sql, name)
|
||||
sql.WriteString(" AS ")
|
||||
|
||||
m.DB.Statement.AddVar(sql, option.Query)
|
||||
|
||||
if option.CheckOption != "" {
|
||||
sql.WriteString(" ")
|
||||
sql.WriteString(option.CheckOption)
|
||||
}
|
||||
return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error
|
||||
}
|
||||
|
||||
// DropView drop view
|
||||
func (m Migrator) DropView(name string) error {
|
||||
return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error
|
||||
}
|
||||
|
||||
// GuessConstraintAndTable guess statement's constraint and it's table based on name
|
||||
//
|
||||
// Deprecated: use GuessConstraintInterfaceAndTable instead.
|
||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
switch c := constraint.(type) {
|
||||
case *schema.Constraint:
|
||||
return c, nil, table
|
||||
case *schema.CheckConstraint:
|
||||
return nil, c, table
|
||||
default:
|
||||
return nil, nil, table
|
||||
}
|
||||
}
|
||||
|
||||
// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name
|
||||
// nolint:cyclop
|
||||
func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) {
|
||||
if stmt.Schema == nil {
|
||||
return nil, stmt.Table
|
||||
}
|
||||
|
||||
checkConstraints := stmt.Schema.ParseCheckConstraints()
|
||||
if chk, ok := checkConstraints[name]; ok {
|
||||
return &chk, stmt.Table
|
||||
}
|
||||
|
||||
uniqueConstraints := stmt.Schema.ParseUniqueConstraints()
|
||||
if uni, ok := uniqueConstraints[name]; ok {
|
||||
return &uni, stmt.Table
|
||||
}
|
||||
|
||||
getTable := func(rel *schema.Relationship) string {
|
||||
switch rel.Type {
|
||||
case schema.HasOne, schema.HasMany:
|
||||
return rel.FieldSchema.Table
|
||||
case schema.Many2Many:
|
||||
return rel.JoinTable.Table
|
||||
}
|
||||
return stmt.Table
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name {
|
||||
return constraint, getTable(rel)
|
||||
}
|
||||
}
|
||||
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
for k := range checkConstraints {
|
||||
if checkConstraints[k].Field == field {
|
||||
v := checkConstraints[k]
|
||||
return &v, stmt.Table
|
||||
}
|
||||
}
|
||||
|
||||
for k := range uniqueConstraints {
|
||||
if uniqueConstraints[k].Field == field {
|
||||
v := uniqueConstraints[k]
|
||||
return &v, stmt.Table
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field {
|
||||
return constraint, getTable(rel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, stmt.Schema.Table
|
||||
}
|
||||
|
||||
// CreateConstraint create constraint
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
vars := []interface{}{clause.Table{Name: table}}
|
||||
if stmt.TableExpr != nil {
|
||||
vars[0] = stmt.TableExpr
|
||||
}
|
||||
sql, values := constraint.Build()
|
||||
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DropConstraint drop constraint
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.GetName()
|
||||
}
|
||||
return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error
|
||||
})
|
||||
}
|
||||
|
||||
// HasConstraint check has constraint or not
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.GetName()
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? AND constraint_name = ?",
|
||||
currentDatabase, table, name,
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// BuildIndexOptions build index options
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
} else if opt.Length > 0 {
|
||||
str += fmt.Sprintf("(%d)", opt.Length)
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// BuildIndexOptionsInterface build index options interface
|
||||
type BuildIndexOptionsInterface interface {
|
||||
BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
|
||||
}
|
||||
|
||||
// CreateIndex create index `name`
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ? ON ??"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
|
||||
if idx.Comment != "" {
|
||||
createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
|
||||
}
|
||||
|
||||
if idx.Option != "" {
|
||||
createIndexSQL += " " + idx.Option
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to create index with name %s", name)
|
||||
})
|
||||
}
|
||||
|
||||
// DropIndex drop index `name`
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
|
||||
})
|
||||
}
|
||||
|
||||
// HasIndex check has index `name` or not
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
currentDatabase := m.DB.Migrator().CurrentDatabase()
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
|
||||
return m.DB.Raw(
|
||||
"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
|
||||
currentDatabase, stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// RenameIndex rename index from oldName to newName
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Exec(
|
||||
"ALTER TABLE ? RENAME INDEX ? TO ?",
|
||||
m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
|
||||
).Error
|
||||
})
|
||||
}
|
||||
|
||||
// CurrentDatabase returns current database name
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
|
||||
return
|
||||
}
|
||||
|
||||
// ReorderModels reorder models according to constraint dependencies
|
||||
func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
|
||||
type Dependency struct {
|
||||
*gorm.Statement
|
||||
Depends []*schema.Schema
|
||||
}
|
||||
|
||||
var (
|
||||
modelNames, orderedModelNames []string
|
||||
orderedModelNamesMap = map[string]bool{}
|
||||
parsedSchemas = map[*schema.Schema]bool{}
|
||||
valuesMap = map[string]Dependency{}
|
||||
insertIntoOrderedList func(name string)
|
||||
parseDependence func(value interface{}, addToList bool)
|
||||
)
|
||||
|
||||
parseDependence = func(value interface{}, addToList bool) {
|
||||
dep := Dependency{
|
||||
Statement: &gorm.Statement{DB: m.DB, Dest: value},
|
||||
}
|
||||
beDependedOn := map[*schema.Schema]bool{}
|
||||
// support for special table name
|
||||
if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
|
||||
m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
|
||||
}
|
||||
if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
|
||||
return
|
||||
}
|
||||
parsedSchemas[dep.Statement.Schema] = true
|
||||
|
||||
if !m.DB.IgnoreRelationshipsWhenMigrating {
|
||||
for _, rel := range dep.Schema.Relationships.Relations {
|
||||
if rel.Field.IgnoreMigration {
|
||||
continue
|
||||
}
|
||||
if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
|
||||
dep.Depends = append(dep.Depends, c.ReferenceSchema)
|
||||
}
|
||||
|
||||
if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
|
||||
beDependedOn[rel.FieldSchema] = true
|
||||
}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
// append join value
|
||||
defer func(rel *schema.Relationship, joinValue interface{}) {
|
||||
if !beDependedOn[rel.FieldSchema] {
|
||||
dep.Depends = append(dep.Depends, rel.FieldSchema)
|
||||
} else {
|
||||
fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
|
||||
parseDependence(fieldValue, autoAdd)
|
||||
}
|
||||
parseDependence(joinValue, autoAdd)
|
||||
}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
valuesMap[dep.Schema.Table] = dep
|
||||
|
||||
if addToList {
|
||||
modelNames = append(modelNames, dep.Schema.Table)
|
||||
}
|
||||
}
|
||||
|
||||
insertIntoOrderedList = func(name string) {
|
||||
if _, ok := orderedModelNamesMap[name]; ok {
|
||||
return // avoid loop
|
||||
}
|
||||
orderedModelNamesMap[name] = true
|
||||
|
||||
if autoAdd {
|
||||
dep := valuesMap[name]
|
||||
for _, d := range dep.Depends {
|
||||
if _, ok := valuesMap[d.Table]; ok {
|
||||
insertIntoOrderedList(d.Table)
|
||||
} else {
|
||||
parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
|
||||
insertIntoOrderedList(d.Table)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
orderedModelNames = append(orderedModelNames, name)
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
if v, ok := value.(string); ok {
|
||||
results = append(results, v)
|
||||
} else {
|
||||
parseDependence(value, true)
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range modelNames {
|
||||
insertIntoOrderedList(name)
|
||||
}
|
||||
|
||||
for _, name := range orderedModelNames {
|
||||
results = append(results, valuesMap[name].Statement.Dest)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CurrentTable returns current statement's table expression
|
||||
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
|
||||
if stmt.TableExpr != nil {
|
||||
return *stmt.TableExpr
|
||||
}
|
||||
return clause.Table{Name: stmt.Table}
|
||||
}
|
||||
|
||||
// GetIndexes return Indexes []gorm.Index and execErr error
|
||||
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
|
||||
// GetTypeAliases return database type aliases
|
||||
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableType return tableType gorm.TableType and execErr error
|
||||
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
|
||||
return nil, errors.New("not support")
|
||||
}
|
||||
33
vendor/gorm.io/gorm/migrator/table_type.go
generated
vendored
Normal file
33
vendor/gorm.io/gorm/migrator/table_type.go
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
package migrator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// TableType table type implements TableType interface
|
||||
type TableType struct {
|
||||
SchemaValue string
|
||||
NameValue string
|
||||
TypeValue string
|
||||
CommentValue sql.NullString
|
||||
}
|
||||
|
||||
// Schema returns the schema of the table.
|
||||
func (ct TableType) Schema() string {
|
||||
return ct.SchemaValue
|
||||
}
|
||||
|
||||
// Name returns the name of the table.
|
||||
func (ct TableType) Name() string {
|
||||
return ct.NameValue
|
||||
}
|
||||
|
||||
// Type returns the type of the table.
|
||||
func (ct TableType) Type() string {
|
||||
return ct.TypeValue
|
||||
}
|
||||
|
||||
// Comment returns the comment of current table.
|
||||
func (ct TableType) Comment() (comment string, ok bool) {
|
||||
return ct.CommentValue.String, ct.CommentValue.Valid
|
||||
}
|
||||
16
vendor/gorm.io/gorm/model.go
generated
vendored
Normal file
16
vendor/gorm.io/gorm/model.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
package gorm
|
||||
|
||||
import "time"
|
||||
|
||||
// Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt
|
||||
// It may be embedded into your model or you may build your own model without it
|
||||
//
|
||||
// type User struct {
|
||||
// gorm.Model
|
||||
// }
|
||||
type Model struct {
|
||||
ID uint `gorm:"primarykey"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt DeletedAt `gorm:"index"`
|
||||
}
|
||||
242
vendor/gorm.io/gorm/prepare_stmt.go
generated
vendored
Normal file
242
vendor/gorm.io/gorm/prepare_stmt.go
generated
vendored
Normal file
@@ -0,0 +1,242 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Stmt struct {
|
||||
*sql.Stmt
|
||||
Transaction bool
|
||||
prepared chan struct{}
|
||||
prepareErr error
|
||||
}
|
||||
|
||||
type PreparedStmtDB struct {
|
||||
Stmts map[string]*Stmt
|
||||
PreparedSQL []string
|
||||
Mux *sync.RWMutex
|
||||
ConnPool
|
||||
}
|
||||
|
||||
func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB {
|
||||
return &PreparedStmtDB{
|
||||
ConnPool: connPool,
|
||||
Stmts: make(map[string]*Stmt),
|
||||
Mux: &sync.RWMutex{},
|
||||
PreparedSQL: make([]string, 0, 100),
|
||||
}
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
|
||||
if sqldb, ok := db.ConnPool.(*sql.DB); ok {
|
||||
return sqldb, nil
|
||||
}
|
||||
|
||||
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
|
||||
return dbConnector.GetDBConn()
|
||||
}
|
||||
|
||||
return nil, ErrInvalidDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) Close() {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
for _, query := range db.PreparedSQL {
|
||||
if stmt, ok := db.Stmts[query]; ok {
|
||||
delete(db.Stmts, query)
|
||||
go stmt.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (sdb *PreparedStmtDB) Reset() {
|
||||
sdb.Mux.Lock()
|
||||
defer sdb.Mux.Unlock()
|
||||
|
||||
for _, stmt := range sdb.Stmts {
|
||||
go stmt.Close()
|
||||
}
|
||||
sdb.PreparedSQL = make([]string, 0, 100)
|
||||
sdb.Stmts = make(map[string]*Stmt)
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) {
|
||||
db.Mux.RLock()
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.RUnlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
db.Mux.RUnlock()
|
||||
|
||||
db.Mux.Lock()
|
||||
// double check
|
||||
if stmt, ok := db.Stmts[query]; ok && (!stmt.Transaction || isTransaction) {
|
||||
db.Mux.Unlock()
|
||||
// wait for other goroutines prepared
|
||||
<-stmt.prepared
|
||||
if stmt.prepareErr != nil {
|
||||
return Stmt{}, stmt.prepareErr
|
||||
}
|
||||
|
||||
return *stmt, nil
|
||||
}
|
||||
|
||||
// cache preparing stmt first
|
||||
cacheStmt := Stmt{Transaction: isTransaction, prepared: make(chan struct{})}
|
||||
db.Stmts[query] = &cacheStmt
|
||||
db.Mux.Unlock()
|
||||
|
||||
// prepare completed
|
||||
defer close(cacheStmt.prepared)
|
||||
|
||||
// Reason why cannot lock conn.PrepareContext
|
||||
// suppose the maxopen is 1, g1 is creating record and g2 is querying record.
|
||||
// 1. g1 begin tx, g1 is requeue because of waiting for the system call, now `db.ConnPool` db.numOpen == 1.
|
||||
// 2. g2 select lock `conn.PrepareContext(ctx, query)`, now db.numOpen == db.maxOpen , wait for release.
|
||||
// 3. g1 tx exec insert, wait for unlock `conn.PrepareContext(ctx, query)` to finish tx and release.
|
||||
stmt, err := conn.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
cacheStmt.prepareErr = err
|
||||
db.Mux.Lock()
|
||||
delete(db.Stmts, query)
|
||||
db.Mux.Unlock()
|
||||
return Stmt{}, err
|
||||
}
|
||||
|
||||
db.Mux.Lock()
|
||||
cacheStmt.Stmt = stmt
|
||||
db.PreparedSQL = append(db.PreparedSQL, query)
|
||||
db.Mux.Unlock()
|
||||
|
||||
return cacheStmt, nil
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
|
||||
if beginner, ok := db.ConnPool.(TxBeginner); ok {
|
||||
tx, err := beginner.BeginTx(ctx, opt)
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
|
||||
}
|
||||
|
||||
beginner, ok := db.ConnPool.(ConnPoolBeginner)
|
||||
if !ok {
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
connPool, err := beginner.BeginTx(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tx, ok := connPool.(Tx); ok {
|
||||
return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
|
||||
}
|
||||
return nil, ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
result, err = stmt.ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
rows, err = stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
db.Mux.Lock()
|
||||
defer db.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(db.Stmts, query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := db.prepare(ctx, db.ConnPool, false, query)
|
||||
if err == nil {
|
||||
return stmt.QueryRowContext(ctx, args...)
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
|
||||
type PreparedStmtTX struct {
|
||||
Tx
|
||||
PreparedStmtDB *PreparedStmtDB
|
||||
}
|
||||
|
||||
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
|
||||
return db.PreparedStmtDB.GetDBConn()
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Commit() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) Rollback() error {
|
||||
if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
|
||||
return tx.Tx.Rollback()
|
||||
}
|
||||
return ErrInvalidTransaction
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
}
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
tx.PreparedStmtDB.Mux.Lock()
|
||||
defer tx.PreparedStmtDB.Mux.Unlock()
|
||||
|
||||
go stmt.Close()
|
||||
delete(tx.PreparedStmtDB.Stmts, query)
|
||||
}
|
||||
}
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
|
||||
if err == nil {
|
||||
return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
|
||||
}
|
||||
return &sql.Row{}
|
||||
}
|
||||
342
vendor/gorm.io/gorm/scan.go
generated
vendored
Normal file
342
vendor/gorm.io/gorm/scan.go
generated
vendored
Normal file
@@ -0,0 +1,342 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// prepareValues prepare values slice
|
||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
|
||||
if db.Statement.Schema != nil {
|
||||
for idx, name := range columns {
|
||||
if field := db.Statement.Schema.LookUpField(name); field != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(field.FieldType)).Interface()
|
||||
continue
|
||||
}
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
} else if len(columnTypes) > 0 {
|
||||
for idx, columnType := range columnTypes {
|
||||
if columnType.ScanType() != nil {
|
||||
values[idx] = reflect.New(reflect.PtrTo(columnType.ScanType())).Interface()
|
||||
} else {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for idx := range columns {
|
||||
values[idx] = new(interface{})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
|
||||
for idx, column := range columns {
|
||||
if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
|
||||
mapValue[column] = reflectValue.Interface()
|
||||
if valuer, ok := mapValue[column].(driver.Valuer); ok {
|
||||
mapValue[column], _ = valuer.Value()
|
||||
} else if b, ok := mapValue[column].(sql.RawBytes); ok {
|
||||
mapValue[column] = string(b)
|
||||
}
|
||||
} else {
|
||||
mapValue[column] = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
|
||||
for idx, field := range fields {
|
||||
if field != nil {
|
||||
values[idx] = field.NewValuePool.Get()
|
||||
} else if len(fields) == 1 {
|
||||
if reflectValue.CanAddr() {
|
||||
values[idx] = reflectValue.Addr().Interface()
|
||||
} else {
|
||||
values[idx] = reflectValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
joinedNestedSchemaMap := make(map[string]interface{})
|
||||
for idx, field := range fields {
|
||||
if field == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
|
||||
db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
|
||||
} else { // joinFields count is larger than 2 when using join
|
||||
var isNilPtrValue bool
|
||||
var relValue reflect.Value
|
||||
// does not contain raw dbname
|
||||
nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
|
||||
// current reflect value
|
||||
currentReflectValue := reflectValue
|
||||
fullRels := make([]string, 0, len(nestedJoinSchemas))
|
||||
for _, joinSchema := range nestedJoinSchemas {
|
||||
fullRels = append(fullRels, joinSchema.Name)
|
||||
relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
|
||||
if relValue.Kind() == reflect.Ptr {
|
||||
fullRelsName := utils.JoinNestedRelationNames(fullRels)
|
||||
// same nested structure
|
||||
if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
|
||||
if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
isNilPtrValue = true
|
||||
break
|
||||
}
|
||||
|
||||
relValue.Set(reflect.New(relValue.Type().Elem()))
|
||||
joinedNestedSchemaMap[fullRelsName] = nil
|
||||
}
|
||||
}
|
||||
currentReflectValue = relValue
|
||||
}
|
||||
|
||||
if !isNilPtrValue { // ignore if value is nil
|
||||
f := joinFields[idx][len(joinFields[idx])-1]
|
||||
db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
|
||||
}
|
||||
}
|
||||
|
||||
// release data to pool
|
||||
field.NewValuePool.Put(values[idx])
|
||||
}
|
||||
}
|
||||
|
||||
// ScanMode scan data mode
|
||||
type ScanMode uint8
|
||||
|
||||
// scan modes
|
||||
const (
|
||||
ScanInitialized ScanMode = 1 << 0 // 1
|
||||
ScanUpdate ScanMode = 1 << 1 // 2
|
||||
ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
|
||||
)
|
||||
|
||||
// Scan scan rows into db statement
|
||||
func Scan(rows Rows, db *DB, mode ScanMode) {
|
||||
var (
|
||||
columns, _ = rows.Columns()
|
||||
values = make([]interface{}, len(columns))
|
||||
initialized = mode&ScanInitialized != 0
|
||||
update = mode&ScanUpdate != 0
|
||||
onConflictDonothing = mode&ScanOnConflictDoNothing != 0
|
||||
)
|
||||
|
||||
db.RowsAffected = 0
|
||||
|
||||
switch dest := db.Statement.Dest.(type) {
|
||||
case map[string]interface{}, *map[string]interface{}:
|
||||
if initialized || rows.Next() {
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue, ok := dest.(map[string]interface{})
|
||||
if !ok {
|
||||
if v, ok := dest.(*map[string]interface{}); ok {
|
||||
if *v == nil {
|
||||
*v = map[string]interface{}{}
|
||||
}
|
||||
mapValue = *v
|
||||
}
|
||||
}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
}
|
||||
case *[]map[string]interface{}:
|
||||
columnTypes, _ := rows.ColumnTypes()
|
||||
for initialized || rows.Next() {
|
||||
prepareValues(values, db, columnTypes, columns)
|
||||
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(values...))
|
||||
|
||||
mapValue := map[string]interface{}{}
|
||||
scanIntoMap(mapValue, values, columns)
|
||||
*dest = append(*dest, mapValue)
|
||||
}
|
||||
case *int, *int8, *int16, *int32, *int64,
|
||||
*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
|
||||
*float32, *float64,
|
||||
*bool, *string, *time.Time,
|
||||
*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
|
||||
*sql.NullBool, *sql.NullString, *sql.NullTime:
|
||||
for initialized || rows.Next() {
|
||||
initialized = false
|
||||
db.RowsAffected++
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
default:
|
||||
var (
|
||||
fields = make([]*schema.Field, len(columns))
|
||||
joinFields [][]*schema.Field
|
||||
sch = db.Statement.Schema
|
||||
reflectValue = db.Statement.ReflectValue
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
reflectValueType := reflectValue.Type()
|
||||
switch reflectValueType.Kind() {
|
||||
case reflect.Array, reflect.Slice:
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
isPtr := reflectValueType.Kind() == reflect.Ptr
|
||||
if isPtr {
|
||||
reflectValueType = reflectValueType.Elem()
|
||||
}
|
||||
|
||||
if sch != nil {
|
||||
if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
|
||||
sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
|
||||
}
|
||||
|
||||
if len(columns) == 1 {
|
||||
// Is Pluck
|
||||
if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
|
||||
reflectValueType.Kind() != reflect.Struct || // is not struct
|
||||
sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
|
||||
sch = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Not Pluck
|
||||
if sch != nil {
|
||||
matchedFieldCount := make(map[string]int, len(columns))
|
||||
for idx, column := range columns {
|
||||
if field := sch.LookUpField(column); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
if count, ok := matchedFieldCount[column]; ok {
|
||||
// handle duplicate fields
|
||||
for _, selectField := range sch.Fields {
|
||||
if selectField.DBName == column && selectField.Readable {
|
||||
if count == 0 {
|
||||
matchedFieldCount[column]++
|
||||
fields[idx] = selectField
|
||||
break
|
||||
}
|
||||
count--
|
||||
}
|
||||
}
|
||||
} else {
|
||||
matchedFieldCount[column] = 1
|
||||
}
|
||||
} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
|
||||
if rel, ok := sch.Relationships.Relations[names[0]]; ok {
|
||||
subNameCount := len(names)
|
||||
// nested relation fields
|
||||
relFields := make([]*schema.Field, 0, subNameCount-1)
|
||||
relFields = append(relFields, rel.Field)
|
||||
for _, name := range names[1 : subNameCount-1] {
|
||||
rel = rel.FieldSchema.Relationships.Relations[name]
|
||||
relFields = append(relFields, rel.Field)
|
||||
}
|
||||
// lastest name is raw dbname
|
||||
dbName := names[subNameCount-1]
|
||||
if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
|
||||
fields[idx] = field
|
||||
|
||||
if len(joinFields) == 0 {
|
||||
joinFields = make([][]*schema.Field, len(columns))
|
||||
}
|
||||
relFields = append(relFields, field)
|
||||
joinFields[idx] = relFields
|
||||
continue
|
||||
}
|
||||
}
|
||||
values[idx] = &sql.RawBytes{}
|
||||
} else {
|
||||
values[idx] = &sql.RawBytes{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
var (
|
||||
elem reflect.Value
|
||||
isArrayKind = reflectValue.Kind() == reflect.Array
|
||||
)
|
||||
|
||||
if !update || reflectValue.Len() == 0 {
|
||||
update = false
|
||||
// if the slice cap is externally initialized, the externally initialized slice is directly used here
|
||||
if reflectValue.Cap() == 0 {
|
||||
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
|
||||
} else if !isArrayKind {
|
||||
reflectValue.SetLen(0)
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
}
|
||||
|
||||
for initialized || rows.Next() {
|
||||
BEGIN:
|
||||
initialized = false
|
||||
|
||||
if update {
|
||||
if int(db.RowsAffected) >= reflectValue.Len() {
|
||||
return
|
||||
}
|
||||
elem = reflectValue.Index(int(db.RowsAffected))
|
||||
if onConflictDonothing {
|
||||
for _, field := range fields {
|
||||
if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
|
||||
db.RowsAffected++
|
||||
goto BEGIN
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
elem = reflect.New(reflectValueType)
|
||||
}
|
||||
|
||||
db.scanIntoStruct(rows, elem, values, fields, joinFields)
|
||||
|
||||
if !update {
|
||||
if !isPtr {
|
||||
elem = elem.Elem()
|
||||
}
|
||||
if isArrayKind {
|
||||
if reflectValue.Len() >= int(db.RowsAffected) {
|
||||
reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
|
||||
}
|
||||
} else {
|
||||
reflectValue = reflect.Append(reflectValue, elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !update {
|
||||
db.Statement.ReflectValue.Set(reflectValue)
|
||||
}
|
||||
case reflect.Struct, reflect.Ptr:
|
||||
if initialized || rows.Next() {
|
||||
db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
|
||||
}
|
||||
default:
|
||||
db.AddError(rows.Scan(dest))
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil && err != db.Error {
|
||||
db.AddError(err)
|
||||
}
|
||||
|
||||
if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
|
||||
db.AddError(ErrRecordNotFound)
|
||||
}
|
||||
}
|
||||
66
vendor/gorm.io/gorm/schema/constraint.go
generated
vendored
Normal file
66
vendor/gorm.io/gorm/schema/constraint.go
generated
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// reg match english letters and midline
|
||||
var regEnLetterAndMidline = regexp.MustCompile("^[A-Za-z-_]+$")
|
||||
|
||||
type CheckConstraint struct {
|
||||
Name string
|
||||
Constraint string // length(phone) >= 10
|
||||
*Field
|
||||
}
|
||||
|
||||
func (chk *CheckConstraint) GetName() string { return chk.Name }
|
||||
|
||||
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
}
|
||||
|
||||
// ParseCheckConstraints parse schema check constraints
|
||||
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
|
||||
checks := map[string]CheckConstraint{}
|
||||
for _, field := range schema.FieldsByDBName {
|
||||
if chk := field.TagSettings["CHECK"]; chk != "" {
|
||||
names := strings.Split(chk, ",")
|
||||
if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
|
||||
checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
|
||||
} else {
|
||||
if names[0] == "" {
|
||||
chk = strings.Join(names[1:], ",")
|
||||
}
|
||||
name := schema.namer.CheckerName(schema.Table, field.DBName)
|
||||
checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
|
||||
}
|
||||
}
|
||||
}
|
||||
return checks
|
||||
}
|
||||
|
||||
type UniqueConstraint struct {
|
||||
Name string
|
||||
Field *Field
|
||||
}
|
||||
|
||||
func (uni *UniqueConstraint) GetName() string { return uni.Name }
|
||||
|
||||
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
|
||||
return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
|
||||
}
|
||||
|
||||
// ParseUniqueConstraints parse schema unique constraints
|
||||
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
|
||||
uniques := make(map[string]UniqueConstraint)
|
||||
for _, field := range schema.Fields {
|
||||
if field.Unique {
|
||||
name := schema.namer.UniqueName(schema.Table, field.DBName)
|
||||
uniques[name] = UniqueConstraint{Name: name, Field: field}
|
||||
}
|
||||
}
|
||||
return uniques
|
||||
}
|
||||
996
vendor/gorm.io/gorm/schema/field.go
generated
vendored
Normal file
996
vendor/gorm.io/gorm/schema/field.go
generated
vendored
Normal file
@@ -0,0 +1,996 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// special types' reflect type
|
||||
var (
|
||||
TimeReflectType = reflect.TypeOf(time.Time{})
|
||||
TimePtrReflectType = reflect.TypeOf(&time.Time{})
|
||||
ByteReflectType = reflect.TypeOf(uint8(0))
|
||||
)
|
||||
|
||||
type (
|
||||
// DataType GORM data type
|
||||
DataType string
|
||||
// TimeType GORM time type
|
||||
TimeType int64
|
||||
)
|
||||
|
||||
// GORM time types
|
||||
const (
|
||||
UnixTime TimeType = 1
|
||||
UnixSecond TimeType = 2
|
||||
UnixMillisecond TimeType = 3
|
||||
UnixNanosecond TimeType = 4
|
||||
)
|
||||
|
||||
// GORM fields types
|
||||
const (
|
||||
Bool DataType = "bool"
|
||||
Int DataType = "int"
|
||||
Uint DataType = "uint"
|
||||
Float DataType = "float"
|
||||
String DataType = "string"
|
||||
Time DataType = "time"
|
||||
Bytes DataType = "bytes"
|
||||
)
|
||||
|
||||
const DefaultAutoIncrementIncrement int64 = 1
|
||||
|
||||
// Field is the representation of model schema's field
|
||||
type Field struct {
|
||||
Name string
|
||||
DBName string
|
||||
BindNames []string
|
||||
DataType DataType
|
||||
GORMDataType DataType
|
||||
PrimaryKey bool
|
||||
AutoIncrement bool
|
||||
AutoIncrementIncrement int64
|
||||
Creatable bool
|
||||
Updatable bool
|
||||
Readable bool
|
||||
AutoCreateTime TimeType
|
||||
AutoUpdateTime TimeType
|
||||
HasDefaultValue bool
|
||||
DefaultValue string
|
||||
DefaultValueInterface interface{}
|
||||
NotNull bool
|
||||
Unique bool
|
||||
Comment string
|
||||
Size int
|
||||
Precision int
|
||||
Scale int
|
||||
IgnoreMigration bool
|
||||
FieldType reflect.Type
|
||||
IndirectFieldType reflect.Type
|
||||
StructField reflect.StructField
|
||||
Tag reflect.StructTag
|
||||
TagSettings map[string]string
|
||||
Schema *Schema
|
||||
EmbeddedSchema *Schema
|
||||
OwnerSchema *Schema
|
||||
ReflectValueOf func(context.Context, reflect.Value) reflect.Value
|
||||
ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
|
||||
Set func(context.Context, reflect.Value, interface{}) error
|
||||
Serializer SerializerInterface
|
||||
NewValuePool FieldNewValuePool
|
||||
|
||||
// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
|
||||
// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
|
||||
// It causes field unnecessarily migration.
|
||||
// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
|
||||
UniqueIndex string
|
||||
}
|
||||
|
||||
func (field *Field) BindName() string {
|
||||
return strings.Join(field.BindNames, ".")
|
||||
}
|
||||
|
||||
// ParseField parses reflect.StructField to Field
|
||||
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
|
||||
var (
|
||||
err error
|
||||
tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
|
||||
)
|
||||
|
||||
field := &Field{
|
||||
Name: fieldStruct.Name,
|
||||
DBName: tagSetting["COLUMN"],
|
||||
BindNames: []string{fieldStruct.Name},
|
||||
FieldType: fieldStruct.Type,
|
||||
IndirectFieldType: fieldStruct.Type,
|
||||
StructField: fieldStruct,
|
||||
Tag: fieldStruct.Tag,
|
||||
TagSettings: tagSetting,
|
||||
Schema: schema,
|
||||
Creatable: true,
|
||||
Updatable: true,
|
||||
Readable: true,
|
||||
PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
|
||||
AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
|
||||
HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
|
||||
NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
|
||||
Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
|
||||
Comment: tagSetting["COMMENT"],
|
||||
AutoIncrementIncrement: DefaultAutoIncrementIncrement,
|
||||
}
|
||||
|
||||
for field.IndirectFieldType.Kind() == reflect.Ptr {
|
||||
field.IndirectFieldType = field.IndirectFieldType.Elem()
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(field.IndirectFieldType)
|
||||
// if field is valuer, used its value or first field as data type
|
||||
valuer, isValuer := fieldValue.Interface().(driver.Valuer)
|
||||
if isValuer {
|
||||
if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
|
||||
if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
|
||||
fieldValue = reflect.ValueOf(v)
|
||||
}
|
||||
|
||||
// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
|
||||
var getRealFieldValue func(reflect.Value)
|
||||
getRealFieldValue = func(v reflect.Value) {
|
||||
var (
|
||||
rv = reflect.Indirect(v)
|
||||
rvType = rv.Type()
|
||||
)
|
||||
|
||||
if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
|
||||
for i := 0; i < rvType.NumField(); i++ {
|
||||
for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
|
||||
if _, ok := field.TagSettings[key]; !ok {
|
||||
field.TagSettings[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < rvType.NumField(); i++ {
|
||||
newFieldType := rvType.Field(i).Type
|
||||
for newFieldType.Kind() == reflect.Ptr {
|
||||
newFieldType = newFieldType.Elem()
|
||||
}
|
||||
|
||||
fieldValue = reflect.New(newFieldType)
|
||||
if rvType != reflect.Indirect(fieldValue).Type() {
|
||||
getRealFieldValue(fieldValue)
|
||||
}
|
||||
|
||||
if fieldValue.IsValid() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getRealFieldValue(fieldValue)
|
||||
}
|
||||
}
|
||||
|
||||
if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
|
||||
field.DataType = String
|
||||
field.Serializer = v
|
||||
} else {
|
||||
serializerName := field.TagSettings["JSON"]
|
||||
if serializerName == "" {
|
||||
serializerName = field.TagSettings["SERIALIZER"]
|
||||
}
|
||||
if serializerName != "" {
|
||||
if serializer, ok := GetSerializer(serializerName); ok {
|
||||
// Set default data type to string for serializer
|
||||
field.DataType = String
|
||||
field.Serializer = serializer
|
||||
} else {
|
||||
schema.err = fmt.Errorf("invalid serializer type %v", serializerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
|
||||
field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
|
||||
}
|
||||
|
||||
if v, ok := field.TagSettings["DEFAULT"]; ok {
|
||||
field.HasDefaultValue = true
|
||||
field.DefaultValue = v
|
||||
}
|
||||
|
||||
if num, ok := field.TagSettings["SIZE"]; ok {
|
||||
if field.Size, err = strconv.Atoi(num); err != nil {
|
||||
field.Size = -1
|
||||
}
|
||||
}
|
||||
|
||||
if p, ok := field.TagSettings["PRECISION"]; ok {
|
||||
field.Precision, _ = strconv.Atoi(p)
|
||||
}
|
||||
|
||||
if s, ok := field.TagSettings["SCALE"]; ok {
|
||||
field.Scale, _ = strconv.Atoi(s)
|
||||
}
|
||||
|
||||
// default value is function or null or blank (primary keys)
|
||||
field.DefaultValue = strings.TrimSpace(field.DefaultValue)
|
||||
skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") &&
|
||||
strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
|
||||
switch reflect.Indirect(fieldValue).Kind() {
|
||||
case reflect.Bool:
|
||||
field.DataType = Bool
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.DataType = Int
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.DataType = Uint
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.DataType = Float
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
|
||||
schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err)
|
||||
}
|
||||
}
|
||||
case reflect.String:
|
||||
field.DataType = String
|
||||
if field.HasDefaultValue && !skipParseDefaultValue {
|
||||
field.DefaultValue = strings.Trim(field.DefaultValue, "'")
|
||||
field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
|
||||
field.DefaultValueInterface = field.DefaultValue
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, ok := fieldValue.Interface().(*time.Time); ok {
|
||||
field.DataType = Time
|
||||
} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
|
||||
field.DataType = Time
|
||||
} else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
|
||||
field.DataType = Time
|
||||
}
|
||||
if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time {
|
||||
if t, err := now.Parse(field.DefaultValue); err == nil {
|
||||
field.DefaultValueInterface = t
|
||||
}
|
||||
}
|
||||
case reflect.Array, reflect.Slice:
|
||||
if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
|
||||
field.DataType = Bytes
|
||||
}
|
||||
}
|
||||
|
||||
if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
|
||||
field.DataType = DataType(dataTyper.GormDataType())
|
||||
}
|
||||
|
||||
if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||
if field.DataType == Time {
|
||||
field.AutoCreateTime = UnixTime
|
||||
} else if strings.ToUpper(v) == "NANO" {
|
||||
field.AutoCreateTime = UnixNanosecond
|
||||
} else if strings.ToUpper(v) == "MILLI" {
|
||||
field.AutoCreateTime = UnixMillisecond
|
||||
} else {
|
||||
field.AutoCreateTime = UnixSecond
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
|
||||
if field.DataType == Time {
|
||||
field.AutoUpdateTime = UnixTime
|
||||
} else if strings.ToUpper(v) == "NANO" {
|
||||
field.AutoUpdateTime = UnixNanosecond
|
||||
} else if strings.ToUpper(v) == "MILLI" {
|
||||
field.AutoUpdateTime = UnixMillisecond
|
||||
} else {
|
||||
field.AutoUpdateTime = UnixSecond
|
||||
}
|
||||
}
|
||||
|
||||
if field.GORMDataType == "" {
|
||||
field.GORMDataType = field.DataType
|
||||
}
|
||||
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
switch DataType(strings.ToLower(val)) {
|
||||
case Bool, Int, Uint, Float, String, Time, Bytes:
|
||||
field.DataType = DataType(strings.ToLower(val))
|
||||
default:
|
||||
field.DataType = DataType(val)
|
||||
}
|
||||
}
|
||||
|
||||
if field.Size == 0 {
|
||||
switch reflect.Indirect(fieldValue).Kind() {
|
||||
case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
|
||||
field.Size = 64
|
||||
case reflect.Int8, reflect.Uint8:
|
||||
field.Size = 8
|
||||
case reflect.Int16, reflect.Uint16:
|
||||
field.Size = 16
|
||||
case reflect.Int32, reflect.Uint32, reflect.Float32:
|
||||
field.Size = 32
|
||||
}
|
||||
}
|
||||
|
||||
// setup permission
|
||||
if val, ok := field.TagSettings["-"]; ok {
|
||||
val = strings.ToLower(strings.TrimSpace(val))
|
||||
switch val {
|
||||
case "-":
|
||||
field.Creatable = false
|
||||
field.Updatable = false
|
||||
field.Readable = false
|
||||
field.DataType = ""
|
||||
case "all":
|
||||
field.Creatable = false
|
||||
field.Updatable = false
|
||||
field.Readable = false
|
||||
field.DataType = ""
|
||||
field.IgnoreMigration = true
|
||||
case "migration":
|
||||
field.IgnoreMigration = true
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := field.TagSettings["->"]; ok {
|
||||
field.Creatable = false
|
||||
field.Updatable = false
|
||||
if strings.ToLower(v) == "false" {
|
||||
field.Readable = false
|
||||
} else {
|
||||
field.Readable = true
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := field.TagSettings["<-"]; ok {
|
||||
field.Creatable = true
|
||||
field.Updatable = true
|
||||
|
||||
if v != "<-" {
|
||||
if !strings.Contains(v, "create") {
|
||||
field.Creatable = false
|
||||
}
|
||||
|
||||
if !strings.Contains(v, "update") {
|
||||
field.Updatable = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normal anonymous field or having `EMBEDDED` tag
|
||||
if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
|
||||
fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
|
||||
kind := reflect.Indirect(fieldValue).Kind()
|
||||
switch kind {
|
||||
case reflect.Struct:
|
||||
var err error
|
||||
field.Creatable = false
|
||||
field.Updatable = false
|
||||
field.Readable = false
|
||||
|
||||
cacheStore := &sync.Map{}
|
||||
cacheStore.Store(embeddedCacheKey, true)
|
||||
if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
|
||||
for _, ef := range field.EmbeddedSchema.Fields {
|
||||
ef.Schema = schema
|
||||
ef.OwnerSchema = field.EmbeddedSchema
|
||||
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
|
||||
// index is negative means is pointer
|
||||
if field.FieldType.Kind() == reflect.Struct {
|
||||
ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
|
||||
} else {
|
||||
ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
|
||||
}
|
||||
|
||||
if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
|
||||
ef.DBName = prefix + ef.DBName
|
||||
}
|
||||
|
||||
if ef.PrimaryKey {
|
||||
if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
|
||||
ef.PrimaryKey = false
|
||||
|
||||
if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
|
||||
ef.AutoIncrement = false
|
||||
}
|
||||
|
||||
if !ef.AutoIncrement && ef.DefaultValue == "" {
|
||||
ef.HasDefaultValue = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range field.TagSettings {
|
||||
ef.TagSettings[k] = v
|
||||
}
|
||||
}
|
||||
case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
|
||||
reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128:
|
||||
schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
|
||||
}
|
||||
}
|
||||
|
||||
return field
|
||||
}
|
||||
|
||||
// create valuer, setter when parse struct
|
||||
func (field *Field) setupValuerAndSetter() {
|
||||
// Setup NewValuePool
|
||||
field.setupNewValuePool()
|
||||
|
||||
// ValueOf returns field's value and if it is zero
|
||||
fieldIndex := field.StructField.Index[0]
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
|
||||
fieldValue := reflect.Indirect(value).Field(fieldIndex)
|
||||
return fieldValue.Interface(), fieldValue.IsZero()
|
||||
}
|
||||
default:
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
v = reflect.Indirect(v)
|
||||
for _, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
} else {
|
||||
v = v.Field(-fieldIdx - 1)
|
||||
|
||||
if !v.IsNil() {
|
||||
v = v.Elem()
|
||||
} else {
|
||||
return nil, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fv, zero := v.Interface(), v.IsZero()
|
||||
return fv, zero
|
||||
}
|
||||
}
|
||||
|
||||
if field.Serializer != nil {
|
||||
oldValuerOf := field.ValueOf
|
||||
field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
|
||||
value, zero := oldValuerOf(ctx, v)
|
||||
|
||||
s, ok := value.(SerializerValuerInterface)
|
||||
if !ok {
|
||||
s = field.Serializer
|
||||
}
|
||||
|
||||
return &serializer{
|
||||
Field: field,
|
||||
SerializeValuer: s,
|
||||
Destination: v,
|
||||
Context: ctx,
|
||||
fieldValue: value,
|
||||
}, zero
|
||||
}
|
||||
}
|
||||
|
||||
// ReflectValueOf returns field's reflect value
|
||||
switch {
|
||||
case len(field.StructField.Index) == 1 && fieldIndex > 0:
|
||||
field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
|
||||
return reflect.Indirect(value).Field(fieldIndex)
|
||||
}
|
||||
default:
|
||||
field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
for idx, fieldIdx := range field.StructField.Index {
|
||||
if fieldIdx >= 0 {
|
||||
v = v.Field(fieldIdx)
|
||||
} else {
|
||||
v = v.Field(-fieldIdx - 1)
|
||||
|
||||
if v.IsNil() {
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
|
||||
if idx < len(field.StructField.Index)-1 {
|
||||
v = v.Elem()
|
||||
}
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
|
||||
if v == nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
// Optimal value type acquisition for v
|
||||
reflectValType := reflectV.Type()
|
||||
|
||||
if reflectValType.AssignableTo(field.FieldType) {
|
||||
if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
|
||||
reflectV = reflect.Indirect(reflectV)
|
||||
}
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
return
|
||||
} else if reflectValType.ConvertibleTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
|
||||
return
|
||||
} else if field.FieldType.Kind() == reflect.Ptr {
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
fieldType := field.FieldType.Elem()
|
||||
|
||||
if reflectValType.AssignableTo(fieldType) {
|
||||
if !fieldValue.IsValid() {
|
||||
fieldValue = reflect.New(fieldType)
|
||||
} else if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldType))
|
||||
}
|
||||
fieldValue.Elem().Set(reflectV)
|
||||
return
|
||||
} else if reflectValType.ConvertibleTo(fieldType) {
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(fieldType))
|
||||
}
|
||||
|
||||
fieldValue.Elem().Set(reflectV.Convert(fieldType))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if reflectV.Kind() == reflect.Ptr {
|
||||
if reflectV.IsNil() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Type().Elem().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV.Elem())
|
||||
return
|
||||
} else {
|
||||
err = setter(ctx, value, reflectV.Elem().Interface())
|
||||
}
|
||||
} else if valuer, ok := v.(driver.Valuer); ok {
|
||||
if v, err = valuer.Value(); err == nil {
|
||||
err = setter(ctx, value, v)
|
||||
}
|
||||
} else if _, ok := v.(clause.Expr); !ok {
|
||||
return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Set
|
||||
switch field.FieldType.Kind() {
|
||||
case reflect.Bool:
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||
switch data := v.(type) {
|
||||
case **bool:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetBool(**data)
|
||||
}
|
||||
case bool:
|
||||
field.ReflectValueOf(ctx, value).SetBool(data)
|
||||
case int64:
|
||||
field.ReflectValueOf(ctx, value).SetBool(data > 0)
|
||||
case string:
|
||||
b, _ := strconv.ParseBool(data)
|
||||
field.ReflectValueOf(ctx, value).SetBool(b)
|
||||
default:
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
switch data := v.(type) {
|
||||
case **int64:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(**data)
|
||||
}
|
||||
case **int:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int8:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int16:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case **int32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(**data))
|
||||
}
|
||||
case int64:
|
||||
field.ReflectValueOf(ctx, value).SetInt(data)
|
||||
case int:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case int8:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case int16:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case int32:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case uint:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case uint8:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case uint16:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case uint32:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case uint64:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case float32:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case float64:
|
||||
field.ReflectValueOf(ctx, value).SetInt(int64(data))
|
||||
case []byte:
|
||||
return field.Set(ctx, value, string(data))
|
||||
case string:
|
||||
if i, err := strconv.ParseInt(data, 0, 64); err == nil {
|
||||
field.ReflectValueOf(ctx, value).SetInt(i)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
case time.Time:
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
case *time.Time:
|
||||
if data != nil {
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.UnixNano() / 1e6)
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(data.Unix())
|
||||
}
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetInt(0)
|
||||
}
|
||||
default:
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
switch data := v.(type) {
|
||||
case **uint64:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(**data)
|
||||
}
|
||||
case **uint:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint8:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint16:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case **uint32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
|
||||
}
|
||||
case uint64:
|
||||
field.ReflectValueOf(ctx, value).SetUint(data)
|
||||
case uint:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case uint8:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case uint16:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case uint32:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case int64:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case int:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case int8:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case int16:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case int32:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case float32:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case float64:
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data))
|
||||
case []byte:
|
||||
return field.Set(ctx, value, string(data))
|
||||
case time.Time:
|
||||
if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
|
||||
} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano() / 1e6))
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.ParseUint(data, 0, 64); err == nil {
|
||||
field.ReflectValueOf(ctx, value).SetUint(i)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
switch data := v.(type) {
|
||||
case **float64:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(**data)
|
||||
}
|
||||
case **float32:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
|
||||
}
|
||||
case float64:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(data)
|
||||
case float32:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case int64:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case int:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case int8:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case int16:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case int32:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case uint:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case uint8:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case uint16:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case uint32:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case uint64:
|
||||
field.ReflectValueOf(ctx, value).SetFloat(float64(data))
|
||||
case []byte:
|
||||
return field.Set(ctx, value, string(data))
|
||||
case string:
|
||||
if i, err := strconv.ParseFloat(data, 64); err == nil {
|
||||
field.ReflectValueOf(ctx, value).SetFloat(i)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
case reflect.String:
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
switch data := v.(type) {
|
||||
case **string:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).SetString(**data)
|
||||
}
|
||||
case string:
|
||||
field.ReflectValueOf(ctx, value).SetString(data)
|
||||
case []byte:
|
||||
field.ReflectValueOf(ctx, value).SetString(string(data))
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
|
||||
case float64, float32:
|
||||
field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
|
||||
default:
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
return err
|
||||
}
|
||||
default:
|
||||
fieldValue := reflect.New(field.FieldType)
|
||||
switch fieldValue.Elem().Interface().(type) {
|
||||
case time.Time:
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||
switch data := v.(type) {
|
||||
case **time.Time:
|
||||
if data != nil && *data != nil {
|
||||
field.Set(ctx, value, *data)
|
||||
}
|
||||
case time.Time:
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
|
||||
case *time.Time:
|
||||
if data != nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
|
||||
} else {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
|
||||
}
|
||||
case string:
|
||||
if t, err := now.Parse(data); err == nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
|
||||
} else {
|
||||
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
}
|
||||
default:
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case *time.Time:
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
|
||||
switch data := v.(type) {
|
||||
case **time.Time:
|
||||
if data != nil && *data != nil {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
|
||||
}
|
||||
case time.Time:
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
fieldValue.Elem().Set(reflect.ValueOf(v))
|
||||
case *time.Time:
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
|
||||
case string:
|
||||
if t, err := now.Parse(data); err == nil {
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
if fieldValue.IsNil() {
|
||||
if v == "" {
|
||||
return nil
|
||||
}
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
fieldValue.Elem().Set(reflect.ValueOf(t))
|
||||
} else {
|
||||
return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
|
||||
}
|
||||
default:
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
|
||||
// pointer scanner
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
} else {
|
||||
fieldValue := field.ReflectValueOf(ctx, value)
|
||||
if fieldValue.IsNil() {
|
||||
fieldValue.Set(reflect.New(field.FieldType.Elem()))
|
||||
}
|
||||
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
err = fieldValue.Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
|
||||
// struct scanner
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
reflectV := reflect.ValueOf(v)
|
||||
if !reflectV.IsValid() {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
|
||||
} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
|
||||
return
|
||||
} else if reflectV.Type().AssignableTo(field.FieldType) {
|
||||
field.ReflectValueOf(ctx, value).Set(reflectV)
|
||||
} else if reflectV.Kind() == reflect.Ptr {
|
||||
return field.Set(ctx, value, reflectV.Elem().Interface())
|
||||
} else {
|
||||
if valuer, ok := v.(driver.Valuer); ok {
|
||||
v, _ = valuer.Value()
|
||||
}
|
||||
|
||||
err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else {
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
return fallbackSetter(ctx, value, v, field.Set)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if field.Serializer != nil {
|
||||
var (
|
||||
oldFieldSetter = field.Set
|
||||
sameElemType bool
|
||||
sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type()
|
||||
)
|
||||
|
||||
if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr {
|
||||
sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
|
||||
}
|
||||
|
||||
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||
serializerType := serializerValue.Type()
|
||||
field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
|
||||
if s, ok := v.(*serializer); ok {
|
||||
if s.fieldValue != nil {
|
||||
err = oldFieldSetter(ctx, value, s.fieldValue)
|
||||
} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
|
||||
if sameElemType {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
|
||||
} else if sameType {
|
||||
field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
|
||||
}
|
||||
si := reflect.New(serializerType)
|
||||
si.Elem().Set(serializerValue)
|
||||
s.Serializer = si.Interface().(SerializerInterface)
|
||||
}
|
||||
} else {
|
||||
err = oldFieldSetter(ctx, value, v)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (field *Field) setupNewValuePool() {
|
||||
if field.Serializer != nil {
|
||||
serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
|
||||
serializerType := serializerValue.Type()
|
||||
field.NewValuePool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
si := reflect.New(serializerType)
|
||||
si.Elem().Set(serializerValue)
|
||||
return &serializer{
|
||||
Field: field,
|
||||
Serializer: si.Interface().(SerializerInterface),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if field.NewValuePool == nil {
|
||||
field.NewValuePool = poolInitializer(reflect.PtrTo(field.IndirectFieldType))
|
||||
}
|
||||
}
|
||||
166
vendor/gorm.io/gorm/schema/index.go
generated
vendored
Normal file
166
vendor/gorm.io/gorm/schema/index.go
generated
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Index struct {
|
||||
Name string
|
||||
Class string // UNIQUE | FULLTEXT | SPATIAL
|
||||
Type string // btree, hash, gist, spgist, gin, and brin
|
||||
Where string
|
||||
Comment string
|
||||
Option string // WITH PARSER parser_name
|
||||
Fields []IndexOption // Note: IndexOption's Field maybe the same
|
||||
}
|
||||
|
||||
type IndexOption struct {
|
||||
*Field
|
||||
Expression string
|
||||
Sort string // DESC, ASC
|
||||
Collate string
|
||||
Length int
|
||||
priority int
|
||||
}
|
||||
|
||||
// ParseIndexes parse schema indexes
|
||||
func (schema *Schema) ParseIndexes() map[string]Index {
|
||||
indexes := map[string]Index{}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
||||
fieldIndexes, err := parseFieldIndexes(field)
|
||||
if err != nil {
|
||||
schema.err = err
|
||||
break
|
||||
}
|
||||
for _, index := range fieldIndexes {
|
||||
idx := indexes[index.Name]
|
||||
idx.Name = index.Name
|
||||
if idx.Class == "" {
|
||||
idx.Class = index.Class
|
||||
}
|
||||
if idx.Type == "" {
|
||||
idx.Type = index.Type
|
||||
}
|
||||
if idx.Where == "" {
|
||||
idx.Where = index.Where
|
||||
}
|
||||
if idx.Comment == "" {
|
||||
idx.Comment = index.Comment
|
||||
}
|
||||
if idx.Option == "" {
|
||||
idx.Option = index.Option
|
||||
}
|
||||
|
||||
idx.Fields = append(idx.Fields, index.Fields...)
|
||||
sort.Slice(idx.Fields, func(i, j int) bool {
|
||||
return idx.Fields[i].priority < idx.Fields[j].priority
|
||||
})
|
||||
|
||||
indexes[index.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, index := range indexes {
|
||||
if index.Class == "UNIQUE" && len(index.Fields) == 1 {
|
||||
index.Fields[0].Field.UniqueIndex = index.Name
|
||||
}
|
||||
}
|
||||
return indexes
|
||||
}
|
||||
|
||||
func (schema *Schema) LookIndex(name string) *Index {
|
||||
if schema != nil {
|
||||
indexes := schema.ParseIndexes()
|
||||
for _, index := range indexes {
|
||||
if index.Name == name {
|
||||
return &index
|
||||
}
|
||||
|
||||
for _, field := range index.Fields {
|
||||
if field.Name == name {
|
||||
return &index
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseFieldIndexes(field *Field) (indexes []Index, err error) {
|
||||
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
|
||||
if value != "" {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||
if k == "INDEX" || k == "UNIQUEINDEX" {
|
||||
var (
|
||||
name string
|
||||
tag = strings.Join(v[1:], ":")
|
||||
idx = strings.Index(tag, ",")
|
||||
tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",")
|
||||
settings = ParseTagSetting(tagSetting, ",")
|
||||
length, _ = strconv.Atoi(settings["LENGTH"])
|
||||
)
|
||||
|
||||
if idx == -1 {
|
||||
idx = len(tag)
|
||||
}
|
||||
|
||||
if idx != -1 {
|
||||
name = tag[0:idx]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
subName := field.Name
|
||||
const key = "COMPOSITE"
|
||||
if composite, found := settings[key]; found {
|
||||
if len(composite) == 0 || composite == key {
|
||||
err = fmt.Errorf(
|
||||
"The composite tag of %s.%s cannot be empty",
|
||||
field.Schema.Name,
|
||||
field.Name)
|
||||
return
|
||||
}
|
||||
subName = composite
|
||||
}
|
||||
name = field.Schema.namer.IndexName(
|
||||
field.Schema.Table, subName)
|
||||
}
|
||||
|
||||
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
|
||||
settings["CLASS"] = "UNIQUE"
|
||||
}
|
||||
|
||||
priority, err := strconv.Atoi(settings["PRIORITY"])
|
||||
if err != nil {
|
||||
priority = 10
|
||||
}
|
||||
|
||||
indexes = append(indexes, Index{
|
||||
Name: name,
|
||||
Class: settings["CLASS"],
|
||||
Type: settings["TYPE"],
|
||||
Where: settings["WHERE"],
|
||||
Comment: settings["COMMENT"],
|
||||
Option: settings["OPTION"],
|
||||
Fields: []IndexOption{{
|
||||
Field: field,
|
||||
Expression: settings["EXPRESSION"],
|
||||
Sort: settings["SORT"],
|
||||
Collate: settings["COLLATE"],
|
||||
Length: length,
|
||||
priority: priority,
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = nil
|
||||
return
|
||||
}
|
||||
42
vendor/gorm.io/gorm/schema/interfaces.go
generated
vendored
Normal file
42
vendor/gorm.io/gorm/schema/interfaces.go
generated
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// ConstraintInterface database constraint interface
|
||||
type ConstraintInterface interface {
|
||||
GetName() string
|
||||
Build() (sql string, vars []interface{})
|
||||
}
|
||||
|
||||
// GormDataTypeInterface gorm data type interface
|
||||
type GormDataTypeInterface interface {
|
||||
GormDataType() string
|
||||
}
|
||||
|
||||
// FieldNewValuePool field new scan value pool
|
||||
type FieldNewValuePool interface {
|
||||
Get() interface{}
|
||||
Put(interface{})
|
||||
}
|
||||
|
||||
// CreateClausesInterface create clauses interface
|
||||
type CreateClausesInterface interface {
|
||||
CreateClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
// QueryClausesInterface query clauses interface
|
||||
type QueryClausesInterface interface {
|
||||
QueryClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
// UpdateClausesInterface update clauses interface
|
||||
type UpdateClausesInterface interface {
|
||||
UpdateClauses(*Field) []clause.Interface
|
||||
}
|
||||
|
||||
// DeleteClausesInterface delete clauses interface
|
||||
type DeleteClausesInterface interface {
|
||||
DeleteClauses(*Field) []clause.Interface
|
||||
}
|
||||
194
vendor/gorm.io/gorm/schema/naming.go
generated
vendored
Normal file
194
vendor/gorm.io/gorm/schema/naming.go
generated
vendored
Normal file
@@ -0,0 +1,194 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
)
|
||||
|
||||
// Namer namer interface
|
||||
type Namer interface {
|
||||
TableName(table string) string
|
||||
SchemaName(table string) string
|
||||
ColumnName(table, column string) string
|
||||
JoinTableName(joinTable string) string
|
||||
RelationshipFKName(Relationship) string
|
||||
CheckerName(table, column string) string
|
||||
IndexName(table, column string) string
|
||||
UniqueName(table, column string) string
|
||||
}
|
||||
|
||||
// Replacer replacer interface like strings.Replacer
|
||||
type Replacer interface {
|
||||
Replace(name string) string
|
||||
}
|
||||
|
||||
var _ Namer = (*NamingStrategy)(nil)
|
||||
|
||||
// NamingStrategy tables, columns naming strategy
|
||||
type NamingStrategy struct {
|
||||
TablePrefix string
|
||||
SingularTable bool
|
||||
NameReplacer Replacer
|
||||
NoLowerCase bool
|
||||
IdentifierMaxLength int
|
||||
}
|
||||
|
||||
// TableName convert string to table name
|
||||
func (ns NamingStrategy) TableName(str string) string {
|
||||
if ns.SingularTable {
|
||||
return ns.TablePrefix + ns.toDBName(str)
|
||||
}
|
||||
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
|
||||
}
|
||||
|
||||
// SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName
|
||||
func (ns NamingStrategy) SchemaName(table string) string {
|
||||
table = strings.TrimPrefix(table, ns.TablePrefix)
|
||||
|
||||
if ns.SingularTable {
|
||||
return ns.toSchemaName(table)
|
||||
}
|
||||
return ns.toSchemaName(inflection.Singular(table))
|
||||
}
|
||||
|
||||
// ColumnName convert string to column name
|
||||
func (ns NamingStrategy) ColumnName(table, column string) string {
|
||||
return ns.toDBName(column)
|
||||
}
|
||||
|
||||
// JoinTableName convert string to join table name
|
||||
func (ns NamingStrategy) JoinTableName(str string) string {
|
||||
if !ns.NoLowerCase && strings.ToLower(str) == str {
|
||||
return ns.TablePrefix + str
|
||||
}
|
||||
|
||||
if ns.SingularTable {
|
||||
return ns.TablePrefix + ns.toDBName(str)
|
||||
}
|
||||
return ns.TablePrefix + inflection.Plural(ns.toDBName(str))
|
||||
}
|
||||
|
||||
// RelationshipFKName generate fk name for relation
|
||||
func (ns NamingStrategy) RelationshipFKName(rel Relationship) string {
|
||||
return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name))
|
||||
}
|
||||
|
||||
// CheckerName generate checker name
|
||||
func (ns NamingStrategy) CheckerName(table, column string) string {
|
||||
return ns.formatName("chk", table, column)
|
||||
}
|
||||
|
||||
// IndexName generate index name
|
||||
func (ns NamingStrategy) IndexName(table, column string) string {
|
||||
return ns.formatName("idx", table, ns.toDBName(column))
|
||||
}
|
||||
|
||||
// UniqueName generate unique constraint name
|
||||
func (ns NamingStrategy) UniqueName(table, column string) string {
|
||||
return ns.formatName("uni", table, ns.toDBName(column))
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) formatName(prefix, table, name string) string {
|
||||
formattedName := strings.ReplaceAll(strings.Join([]string{
|
||||
prefix, table, name,
|
||||
}, "_"), ".", "_")
|
||||
|
||||
if ns.IdentifierMaxLength == 0 {
|
||||
ns.IdentifierMaxLength = 64
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength {
|
||||
h := sha1.New()
|
||||
h.Write([]byte(formattedName))
|
||||
bs := h.Sum(nil)
|
||||
|
||||
formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8]
|
||||
}
|
||||
return formattedName
|
||||
}
|
||||
|
||||
var (
|
||||
// https://github.com/golang/lint/blob/master/lint.go#L770
|
||||
commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"}
|
||||
commonInitialismsReplacer *strings.Replacer
|
||||
)
|
||||
|
||||
func init() {
|
||||
commonInitialismsForReplacer := make([]string, 0, len(commonInitialisms))
|
||||
for _, initialism := range commonInitialisms {
|
||||
commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, strings.Title(strings.ToLower(initialism)))
|
||||
}
|
||||
commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...)
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) toDBName(name string) string {
|
||||
if name == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if ns.NameReplacer != nil {
|
||||
tmpName := ns.NameReplacer.Replace(name)
|
||||
|
||||
if tmpName == "" {
|
||||
return name
|
||||
}
|
||||
|
||||
name = tmpName
|
||||
}
|
||||
|
||||
if ns.NoLowerCase {
|
||||
return name
|
||||
}
|
||||
|
||||
var (
|
||||
value = commonInitialismsReplacer.Replace(name)
|
||||
buf strings.Builder
|
||||
lastCase, nextCase, nextNumber bool // upper case == true
|
||||
curCase = value[0] <= 'Z' && value[0] >= 'A'
|
||||
)
|
||||
|
||||
for i, v := range value[:len(value)-1] {
|
||||
nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A'
|
||||
nextNumber = value[i+1] >= '0' && value[i+1] <= '9'
|
||||
|
||||
if curCase {
|
||||
if lastCase && (nextCase || nextNumber) {
|
||||
buf.WriteRune(v + 32)
|
||||
} else {
|
||||
if i > 0 && value[i-1] != '_' && value[i+1] != '_' {
|
||||
buf.WriteByte('_')
|
||||
}
|
||||
buf.WriteRune(v + 32)
|
||||
}
|
||||
} else {
|
||||
buf.WriteRune(v)
|
||||
}
|
||||
|
||||
lastCase = curCase
|
||||
curCase = nextCase
|
||||
}
|
||||
|
||||
if curCase {
|
||||
if !lastCase && len(value) > 1 {
|
||||
buf.WriteByte('_')
|
||||
}
|
||||
buf.WriteByte(value[len(value)-1] + 32)
|
||||
} else {
|
||||
buf.WriteByte(value[len(value)-1])
|
||||
}
|
||||
ret := buf.String()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (ns NamingStrategy) toSchemaName(name string) string {
|
||||
result := strings.ReplaceAll(strings.Title(strings.ReplaceAll(name, "_", " ")), " ", "")
|
||||
for _, initialism := range commonInitialisms {
|
||||
result = regexp.MustCompile(strings.Title(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1")
|
||||
}
|
||||
return result
|
||||
}
|
||||
19
vendor/gorm.io/gorm/schema/pool.go
generated
vendored
Normal file
19
vendor/gorm.io/gorm/schema/pool.go
generated
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// sync pools
|
||||
var (
|
||||
normalPool sync.Map
|
||||
poolInitializer = func(reflectType reflect.Type) FieldNewValuePool {
|
||||
v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{
|
||||
New: func() interface{} {
|
||||
return reflect.New(reflectType).Interface()
|
||||
},
|
||||
})
|
||||
return v.(FieldNewValuePool)
|
||||
}
|
||||
)
|
||||
764
vendor/gorm.io/gorm/schema/relationship.go
generated
vendored
Normal file
764
vendor/gorm.io/gorm/schema/relationship.go
generated
vendored
Normal file
@@ -0,0 +1,764 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/jinzhu/inflection"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
||||
// RelationshipType relationship type
|
||||
type RelationshipType string
|
||||
|
||||
const (
|
||||
HasOne RelationshipType = "has_one" // HasOneRel has one relationship
|
||||
HasMany RelationshipType = "has_many" // HasManyRel has many relationship
|
||||
BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship
|
||||
Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship
|
||||
has RelationshipType = "has"
|
||||
)
|
||||
|
||||
type Relationships struct {
|
||||
HasOne []*Relationship
|
||||
BelongsTo []*Relationship
|
||||
HasMany []*Relationship
|
||||
Many2Many []*Relationship
|
||||
Relations map[string]*Relationship
|
||||
|
||||
EmbeddedRelations map[string]*Relationships
|
||||
}
|
||||
|
||||
type Relationship struct {
|
||||
Name string
|
||||
Type RelationshipType
|
||||
Field *Field
|
||||
Polymorphic *Polymorphic
|
||||
References []*Reference
|
||||
Schema *Schema
|
||||
FieldSchema *Schema
|
||||
JoinTable *Schema
|
||||
foreignKeys, primaryKeys []string
|
||||
}
|
||||
|
||||
type Polymorphic struct {
|
||||
PolymorphicID *Field
|
||||
PolymorphicType *Field
|
||||
Value string
|
||||
}
|
||||
|
||||
type Reference struct {
|
||||
PrimaryKey *Field
|
||||
PrimaryValue string
|
||||
ForeignKey *Field
|
||||
OwnPrimaryKey bool
|
||||
}
|
||||
|
||||
func (schema *Schema) parseRelation(field *Field) *Relationship {
|
||||
var (
|
||||
err error
|
||||
fieldValue = reflect.New(field.IndirectFieldType).Interface()
|
||||
relation = &Relationship{
|
||||
Name: field.Name,
|
||||
Field: field,
|
||||
Schema: schema,
|
||||
foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]),
|
||||
primaryKeys: toColumns(field.TagSettings["REFERENCES"]),
|
||||
}
|
||||
)
|
||||
|
||||
cacheStore := schema.cacheStore
|
||||
|
||||
if relation.FieldSchema, err = getOrParse(fieldValue, cacheStore, schema.namer); err != nil {
|
||||
schema.err = err
|
||||
return nil
|
||||
}
|
||||
|
||||
if hasPolymorphicRelation(field.TagSettings) {
|
||||
schema.buildPolymorphicRelation(relation, field)
|
||||
} else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" {
|
||||
schema.buildMany2ManyRelation(relation, field, many2many)
|
||||
} else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" {
|
||||
schema.guessRelation(relation, field, guessBelongs)
|
||||
} else {
|
||||
switch field.IndirectFieldType.Kind() {
|
||||
case reflect.Struct:
|
||||
schema.guessRelation(relation, field, guessGuess)
|
||||
case reflect.Slice:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
default:
|
||||
schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema,
|
||||
field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if relation.Type == has {
|
||||
// don't add relations to embedded schema, which might be shared
|
||||
if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil {
|
||||
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation
|
||||
}
|
||||
|
||||
switch field.IndirectFieldType.Kind() {
|
||||
case reflect.Struct:
|
||||
relation.Type = HasOne
|
||||
case reflect.Slice:
|
||||
relation.Type = HasMany
|
||||
}
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
schema.setRelation(relation)
|
||||
switch relation.Type {
|
||||
case HasOne:
|
||||
schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation)
|
||||
case HasMany:
|
||||
schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation)
|
||||
case BelongsTo:
|
||||
schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation)
|
||||
case Many2Many:
|
||||
schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation)
|
||||
}
|
||||
}
|
||||
|
||||
return relation
|
||||
}
|
||||
|
||||
// hasPolymorphicRelation check if has polymorphic relation
|
||||
// 1. `POLYMORPHIC` tag
|
||||
// 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag
|
||||
func hasPolymorphicRelation(tagSettings map[string]string) bool {
|
||||
if _, ok := tagSettings["POLYMORPHIC"]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
_, hasType := tagSettings["POLYMORPHICTYPE"]
|
||||
_, hasId := tagSettings["POLYMORPHICID"]
|
||||
|
||||
return hasType && hasId
|
||||
}
|
||||
|
||||
func (schema *Schema) setRelation(relation *Relationship) {
|
||||
// set non-embedded relation
|
||||
if rel := schema.Relationships.Relations[relation.Name]; rel != nil {
|
||||
if len(rel.Field.BindNames) > 1 {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
} else {
|
||||
schema.Relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
|
||||
// set embedded relation
|
||||
if len(relation.Field.BindNames) <= 1 {
|
||||
return
|
||||
}
|
||||
relationships := &schema.Relationships
|
||||
for i, name := range relation.Field.BindNames {
|
||||
if i < len(relation.Field.BindNames)-1 {
|
||||
if relationships.EmbeddedRelations == nil {
|
||||
relationships.EmbeddedRelations = map[string]*Relationships{}
|
||||
}
|
||||
if r := relationships.EmbeddedRelations[name]; r == nil {
|
||||
relationships.EmbeddedRelations[name] = &Relationships{}
|
||||
}
|
||||
relationships = relationships.EmbeddedRelations[name]
|
||||
} else {
|
||||
if relationships.Relations == nil {
|
||||
relationships.Relations = map[string]*Relationship{}
|
||||
}
|
||||
relationships.Relations[relation.Name] = relation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner`
|
||||
//
|
||||
// type User struct {
|
||||
// Toys []Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Pet struct {
|
||||
// Toy Toy `gorm:"polymorphic:Owner;"`
|
||||
// }
|
||||
// type Toy struct {
|
||||
// OwnerID int
|
||||
// OwnerType string
|
||||
// }
|
||||
func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) {
|
||||
polymorphic := field.TagSettings["POLYMORPHIC"]
|
||||
|
||||
relation.Polymorphic = &Polymorphic{
|
||||
Value: schema.Table,
|
||||
}
|
||||
|
||||
var (
|
||||
typeName = polymorphic + "Type"
|
||||
typeId = polymorphic + "ID"
|
||||
)
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok {
|
||||
typeName = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICID"]; ok {
|
||||
typeId = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName]
|
||||
relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId]
|
||||
|
||||
if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok {
|
||||
relation.Polymorphic.Value = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicType == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"Type")
|
||||
}
|
||||
|
||||
if relation.Polymorphic.PolymorphicID == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s",
|
||||
relation.FieldSchema, schema, field.Name, polymorphic+"ID")
|
||||
}
|
||||
|
||||
if schema.err == nil {
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryValue: relation.Polymorphic.Value,
|
||||
ForeignKey: relation.Polymorphic.PolymorphicType,
|
||||
})
|
||||
|
||||
primaryKeyField := schema.PrioritizedPrimaryField
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || len(relation.foreignKeys) > 1 {
|
||||
schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys,
|
||||
schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if primaryKeyField == nil {
|
||||
schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field",
|
||||
relation.FieldSchema, schema, field.Name)
|
||||
return
|
||||
}
|
||||
|
||||
// use same data type for foreign keys
|
||||
if copyableDataType(primaryKeyField.DataType) {
|
||||
relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType
|
||||
}
|
||||
relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType
|
||||
if relation.Polymorphic.PolymorphicID.Size == 0 {
|
||||
relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size
|
||||
}
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: primaryKeyField,
|
||||
ForeignKey: relation.Polymorphic.PolymorphicID,
|
||||
OwnPrimaryKey: true,
|
||||
})
|
||||
}
|
||||
|
||||
relation.Type = has
|
||||
}
|
||||
|
||||
func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) {
|
||||
relation.Type = Many2Many
|
||||
|
||||
var (
|
||||
err error
|
||||
joinTableFields []reflect.StructField
|
||||
fieldsMap = map[string]*Field{}
|
||||
ownFieldsMap = map[string]*Field{} // fix self join many2many
|
||||
referFieldsMap = map[string]*Field{}
|
||||
joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"])
|
||||
joinReferences = toColumns(field.TagSettings["JOINREFERENCES"])
|
||||
)
|
||||
|
||||
ownForeignFields := schema.PrimaryFields
|
||||
refForeignFields := relation.FieldSchema.PrimaryFields
|
||||
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
ownForeignFields = []*Field{}
|
||||
for _, foreignKey := range relation.foreignKeys {
|
||||
if field := schema.LookUpField(foreignKey); field != nil {
|
||||
ownForeignFields = append(ownForeignFields, field)
|
||||
} else {
|
||||
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(relation.primaryKeys) > 0 {
|
||||
refForeignFields = []*Field{}
|
||||
for _, foreignKey := range relation.primaryKeys {
|
||||
if field := relation.FieldSchema.LookUpField(foreignKey); field != nil {
|
||||
refForeignFields = append(refForeignFields, field)
|
||||
} else {
|
||||
schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for idx, ownField := range ownForeignFields {
|
||||
joinFieldName := strings.Title(schema.Name) + ownField.Name
|
||||
if len(joinForeignKeys) > idx {
|
||||
joinFieldName = strings.Title(joinForeignKeys[idx])
|
||||
}
|
||||
|
||||
ownFieldsMap[joinFieldName] = ownField
|
||||
fieldsMap[joinFieldName] = ownField
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: joinFieldName,
|
||||
PkgPath: ownField.StructField.PkgPath,
|
||||
Type: ownField.StructField.Type,
|
||||
Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"),
|
||||
"column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
}
|
||||
|
||||
for idx, relField := range refForeignFields {
|
||||
joinFieldName := strings.Title(relation.FieldSchema.Name) + relField.Name
|
||||
|
||||
if _, ok := ownFieldsMap[joinFieldName]; ok {
|
||||
if field.Name != relation.FieldSchema.Name {
|
||||
joinFieldName = inflection.Singular(field.Name) + relField.Name
|
||||
} else {
|
||||
joinFieldName += "Reference"
|
||||
}
|
||||
}
|
||||
|
||||
if len(joinReferences) > idx {
|
||||
joinFieldName = strings.Title(joinReferences[idx])
|
||||
}
|
||||
|
||||
referFieldsMap[joinFieldName] = relField
|
||||
|
||||
if _, ok := fieldsMap[joinFieldName]; !ok {
|
||||
fieldsMap[joinFieldName] = relField
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: joinFieldName,
|
||||
PkgPath: relField.StructField.PkgPath,
|
||||
Type: relField.StructField.Type,
|
||||
Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"),
|
||||
"column", "autoincrement", "index", "unique", "uniqueindex"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
joinTableFields = append(joinTableFields, reflect.StructField{
|
||||
Name: strings.Title(schema.Name) + field.Name,
|
||||
Type: schema.ModelType,
|
||||
Tag: `gorm:"-"`,
|
||||
})
|
||||
|
||||
if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore,
|
||||
schema.namer); err != nil {
|
||||
schema.err = err
|
||||
}
|
||||
relation.JoinTable.Name = many2many
|
||||
relation.JoinTable.Table = schema.namer.JoinTableName(many2many)
|
||||
relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields))
|
||||
|
||||
relName := relation.Schema.Name
|
||||
relRefName := relation.FieldSchema.Name
|
||||
if relName == relRefName {
|
||||
relRefName = relation.Field.Name
|
||||
}
|
||||
|
||||
if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok {
|
||||
relation.JoinTable.Relationships.Relations[relName] = &Relationship{
|
||||
Name: relName,
|
||||
Type: BelongsTo,
|
||||
Schema: relation.JoinTable,
|
||||
FieldSchema: relation.Schema,
|
||||
}
|
||||
} else {
|
||||
relation.JoinTable.Relationships.Relations[relName].References = []*Reference{}
|
||||
}
|
||||
|
||||
if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok {
|
||||
relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{
|
||||
Name: relRefName,
|
||||
Type: BelongsTo,
|
||||
Schema: relation.JoinTable,
|
||||
FieldSchema: relation.FieldSchema,
|
||||
}
|
||||
} else {
|
||||
relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{}
|
||||
}
|
||||
|
||||
// build references
|
||||
for _, f := range relation.JoinTable.Fields {
|
||||
if f.Creatable || f.Readable || f.Updatable {
|
||||
// use same data type for foreign keys
|
||||
if copyableDataType(fieldsMap[f.Name].DataType) {
|
||||
f.DataType = fieldsMap[f.Name].DataType
|
||||
}
|
||||
f.GORMDataType = fieldsMap[f.Name].GORMDataType
|
||||
if f.Size == 0 {
|
||||
f.Size = fieldsMap[f.Name].Size
|
||||
}
|
||||
relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f)
|
||||
|
||||
if of, ok := ownFieldsMap[f.Name]; ok {
|
||||
joinRel := relation.JoinTable.Relationships.Relations[relName]
|
||||
joinRel.Field = relation.Field
|
||||
joinRel.References = append(joinRel.References, &Reference{
|
||||
PrimaryKey: of,
|
||||
ForeignKey: f,
|
||||
})
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: of,
|
||||
ForeignKey: f,
|
||||
OwnPrimaryKey: true,
|
||||
})
|
||||
}
|
||||
|
||||
if rf, ok := referFieldsMap[f.Name]; ok {
|
||||
joinRefRel := relation.JoinTable.Relationships.Relations[relRefName]
|
||||
if joinRefRel.Field == nil {
|
||||
joinRefRel.Field = relation.Field
|
||||
}
|
||||
joinRefRel.References = append(joinRefRel.References, &Reference{
|
||||
PrimaryKey: rf,
|
||||
ForeignKey: f,
|
||||
})
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: rf,
|
||||
ForeignKey: f,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type guessLevel int
|
||||
|
||||
const (
|
||||
guessGuess guessLevel = iota
|
||||
guessBelongs
|
||||
guessEmbeddedBelongs
|
||||
guessHas
|
||||
guessEmbeddedHas
|
||||
)
|
||||
|
||||
func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) {
|
||||
var (
|
||||
primaryFields, foreignFields []*Field
|
||||
primarySchema, foreignSchema = schema, relation.FieldSchema
|
||||
gl = cgl
|
||||
)
|
||||
|
||||
if gl == guessGuess {
|
||||
if field.Schema == relation.FieldSchema {
|
||||
gl = guessBelongs
|
||||
} else {
|
||||
gl = guessHas
|
||||
}
|
||||
}
|
||||
|
||||
reguessOrErr := func() {
|
||||
switch cgl {
|
||||
case guessGuess:
|
||||
schema.guessRelation(relation, field, guessBelongs)
|
||||
case guessBelongs:
|
||||
schema.guessRelation(relation, field, guessEmbeddedBelongs)
|
||||
case guessEmbeddedBelongs:
|
||||
schema.guessRelation(relation, field, guessHas)
|
||||
case guessHas:
|
||||
schema.guessRelation(relation, field, guessEmbeddedHas)
|
||||
// case guessEmbeddedHas:
|
||||
default:
|
||||
schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface",
|
||||
schema, field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
switch gl {
|
||||
case guessBelongs:
|
||||
primarySchema, foreignSchema = relation.FieldSchema, schema
|
||||
case guessEmbeddedBelongs:
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
|
||||
case guessHas:
|
||||
case guessEmbeddedHas:
|
||||
if field.OwnerSchema == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
|
||||
}
|
||||
|
||||
if len(relation.foreignKeys) > 0 {
|
||||
for _, foreignKey := range relation.foreignKeys {
|
||||
f := foreignSchema.LookUpField(foreignKey)
|
||||
if f == nil {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
foreignFields = append(foreignFields, f)
|
||||
}
|
||||
} else {
|
||||
primarySchemaName := primarySchema.Name
|
||||
if primarySchemaName == "" {
|
||||
primarySchemaName = relation.FieldSchema.Name
|
||||
}
|
||||
|
||||
if len(relation.primaryKeys) > 0 {
|
||||
for _, primaryKey := range relation.primaryKeys {
|
||||
if f := primarySchema.LookUpField(primaryKey); f != nil {
|
||||
primaryFields = append(primaryFields, f)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
primaryFields = primarySchema.PrimaryFields
|
||||
}
|
||||
|
||||
primaryFieldLoop:
|
||||
for _, primaryField := range primaryFields {
|
||||
lookUpName := primarySchemaName + primaryField.Name
|
||||
if gl == guessBelongs {
|
||||
lookUpName = field.Name + primaryField.Name
|
||||
}
|
||||
|
||||
lookUpNames := []string{lookUpName}
|
||||
if len(primaryFields) == 1 {
|
||||
lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID",
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table,
|
||||
strings.TrimSuffix(lookUpName, primaryField.Name)+"ID"))
|
||||
}
|
||||
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
continue primaryFieldLoop
|
||||
}
|
||||
}
|
||||
for _, name := range lookUpNames {
|
||||
if f := foreignSchema.LookUpField(name); f != nil {
|
||||
foreignFields = append(foreignFields, f)
|
||||
primaryFields = append(primaryFields, primaryField)
|
||||
continue primaryFieldLoop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(foreignFields) == 0:
|
||||
reguessOrErr()
|
||||
return
|
||||
case len(relation.primaryKeys) > 0:
|
||||
for idx, primaryKey := range relation.primaryKeys {
|
||||
if f := primarySchema.LookUpField(primaryKey); f != nil {
|
||||
if len(primaryFields) < idx+1 {
|
||||
primaryFields = append(primaryFields, f)
|
||||
} else if f != primaryFields[idx] {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
}
|
||||
case len(primaryFields) == 0:
|
||||
if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil {
|
||||
primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField)
|
||||
} else if len(primarySchema.PrimaryFields) == len(foreignFields) {
|
||||
primaryFields = append(primaryFields, primarySchema.PrimaryFields...)
|
||||
} else {
|
||||
reguessOrErr()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// build references
|
||||
for idx, foreignField := range foreignFields {
|
||||
// use same data type for foreign keys
|
||||
if copyableDataType(primaryFields[idx].DataType) {
|
||||
foreignField.DataType = primaryFields[idx].DataType
|
||||
}
|
||||
foreignField.GORMDataType = primaryFields[idx].GORMDataType
|
||||
if foreignField.Size == 0 {
|
||||
foreignField.Size = primaryFields[idx].Size
|
||||
}
|
||||
|
||||
relation.References = append(relation.References, &Reference{
|
||||
PrimaryKey: primaryFields[idx],
|
||||
ForeignKey: foreignField,
|
||||
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
|
||||
})
|
||||
}
|
||||
|
||||
if gl == guessHas || gl == guessEmbeddedHas {
|
||||
relation.Type = has
|
||||
} else {
|
||||
relation.Type = BelongsTo
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint is ForeignKey Constraint
|
||||
type Constraint struct {
|
||||
Name string
|
||||
Field *Field
|
||||
Schema *Schema
|
||||
ForeignKeys []*Field
|
||||
ReferenceSchema *Schema
|
||||
References []*Field
|
||||
OnDelete string
|
||||
OnUpdate string
|
||||
}
|
||||
|
||||
func (constraint *Constraint) GetName() string { return constraint.Name }
|
||||
|
||||
func (constraint *Constraint) Build() (sql string, vars []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys))
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
references := make([]interface{}, 0, len(constraint.References))
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
func (rel *Relationship) ParseConstraint() *Constraint {
|
||||
str := rel.Field.TagSettings["CONSTRAINT"]
|
||||
if str == "-" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rel.Type == BelongsTo {
|
||||
for _, r := range rel.FieldSchema.Relationships.Relations {
|
||||
if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) {
|
||||
matched := true
|
||||
for idx, ref := range r.References {
|
||||
if !(rel.References[idx].PrimaryKey == ref.PrimaryKey && rel.References[idx].ForeignKey == ref.ForeignKey &&
|
||||
rel.References[idx].PrimaryValue == ref.PrimaryValue) {
|
||||
matched = false
|
||||
}
|
||||
}
|
||||
|
||||
if matched {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
name string
|
||||
idx = strings.Index(str, ",")
|
||||
settings = ParseTagSetting(str, ",")
|
||||
)
|
||||
|
||||
// optimize match english letters and midline
|
||||
// The following code is basically called in for.
|
||||
// In order to avoid the performance problems caused by repeated compilation of regular expressions,
|
||||
// it only needs to be done once outside, so optimization is done here.
|
||||
if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) {
|
||||
name = str[0:idx]
|
||||
} else {
|
||||
name = rel.Schema.namer.RelationshipFKName(*rel)
|
||||
}
|
||||
|
||||
constraint := Constraint{
|
||||
Name: name,
|
||||
Field: rel.Field,
|
||||
OnUpdate: settings["ONUPDATE"],
|
||||
OnDelete: settings["ONDELETE"],
|
||||
}
|
||||
|
||||
for _, ref := range rel.References {
|
||||
if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) {
|
||||
constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey)
|
||||
constraint.References = append(constraint.References, ref.PrimaryKey)
|
||||
|
||||
if ref.OwnPrimaryKey {
|
||||
constraint.Schema = ref.ForeignKey.Schema
|
||||
constraint.ReferenceSchema = rel.Schema
|
||||
} else {
|
||||
constraint.Schema = rel.Schema
|
||||
constraint.ReferenceSchema = ref.PrimaryKey.Schema
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &constraint
|
||||
}
|
||||
|
||||
func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) {
|
||||
table := rel.FieldSchema.Table
|
||||
foreignFields := []*Field{}
|
||||
relForeignKeys := []string{}
|
||||
|
||||
if rel.JoinTable != nil {
|
||||
table = rel.JoinTable.Table
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
})
|
||||
} else {
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName},
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, ref := range rel.References {
|
||||
if ref.OwnPrimaryKey {
|
||||
relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
|
||||
foreignFields = append(foreignFields, ref.PrimaryKey)
|
||||
} else if ref.PrimaryValue != "" {
|
||||
conds = append(conds, clause.Eq{
|
||||
Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName},
|
||||
Value: ref.PrimaryValue,
|
||||
})
|
||||
} else {
|
||||
relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
|
||||
foreignFields = append(foreignFields, ref.ForeignKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields)
|
||||
column, values := ToQueryValues(table, relForeignKeys, foreignValues)
|
||||
|
||||
conds = append(conds, clause.IN{Column: column, Values: values})
|
||||
return
|
||||
}
|
||||
|
||||
func copyableDataType(str DataType) bool {
|
||||
for _, s := range []string{"auto_increment", "primary key"} {
|
||||
if strings.Contains(strings.ToLower(string(str)), s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
423
vendor/gorm.io/gorm/schema/schema.go
generated
vendored
Normal file
423
vendor/gorm.io/gorm/schema/schema.go
generated
vendored
Normal file
@@ -0,0 +1,423 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type callbackType string
|
||||
|
||||
const (
|
||||
callbackTypeBeforeCreate callbackType = "BeforeCreate"
|
||||
callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
|
||||
callbackTypeAfterCreate callbackType = "AfterCreate"
|
||||
callbackTypeAfterUpdate callbackType = "AfterUpdate"
|
||||
callbackTypeBeforeSave callbackType = "BeforeSave"
|
||||
callbackTypeAfterSave callbackType = "AfterSave"
|
||||
callbackTypeBeforeDelete callbackType = "BeforeDelete"
|
||||
callbackTypeAfterDelete callbackType = "AfterDelete"
|
||||
callbackTypeAfterFind callbackType = "AfterFind"
|
||||
)
|
||||
|
||||
// ErrUnsupportedDataType unsupported data type
|
||||
var ErrUnsupportedDataType = errors.New("unsupported data type")
|
||||
|
||||
type Schema struct {
|
||||
Name string
|
||||
ModelType reflect.Type
|
||||
Table string
|
||||
PrioritizedPrimaryField *Field
|
||||
DBNames []string
|
||||
PrimaryFields []*Field
|
||||
PrimaryFieldDBNames []string
|
||||
Fields []*Field
|
||||
FieldsByName map[string]*Field
|
||||
FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
|
||||
FieldsByDBName map[string]*Field
|
||||
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
|
||||
Relationships Relationships
|
||||
CreateClauses []clause.Interface
|
||||
QueryClauses []clause.Interface
|
||||
UpdateClauses []clause.Interface
|
||||
DeleteClauses []clause.Interface
|
||||
BeforeCreate, AfterCreate bool
|
||||
BeforeUpdate, AfterUpdate bool
|
||||
BeforeDelete, AfterDelete bool
|
||||
BeforeSave, AfterSave bool
|
||||
AfterFind bool
|
||||
err error
|
||||
initialized chan struct{}
|
||||
namer Namer
|
||||
cacheStore *sync.Map
|
||||
}
|
||||
|
||||
func (schema Schema) String() string {
|
||||
if schema.ModelType.Name() == "" {
|
||||
return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
|
||||
}
|
||||
|
||||
func (schema Schema) MakeSlice() reflect.Value {
|
||||
slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
|
||||
results := reflect.New(slice.Type())
|
||||
results.Elem().Set(slice)
|
||||
return results
|
||||
}
|
||||
|
||||
func (schema Schema) LookUpField(name string) *Field {
|
||||
if field, ok := schema.FieldsByDBName[name]; ok {
|
||||
return field
|
||||
}
|
||||
if field, ok := schema.FieldsByName[name]; ok {
|
||||
return field
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LookUpFieldByBindName looks for the closest field in the embedded struct.
|
||||
//
|
||||
// type Struct struct {
|
||||
// Embedded struct {
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
|
||||
// }
|
||||
// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
|
||||
// }
|
||||
func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
|
||||
if len(bindNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := len(bindNames) - 1; i >= 0; i-- {
|
||||
find := strings.Join(bindNames[:i], ".") + "." + name
|
||||
if field, ok := schema.FieldsByBindName[find]; ok {
|
||||
return field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Tabler interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type TablerWithNamer interface {
|
||||
TableName(Namer) string
|
||||
}
|
||||
|
||||
// Parse get data type from dialector
|
||||
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
return ParseWithSpecialTableName(dest, cacheStore, namer, "")
|
||||
}
|
||||
|
||||
// ParseWithSpecialTableName get data type from dialector with extra schema table
|
||||
func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
|
||||
if dest == nil {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
|
||||
value := reflect.ValueOf(dest)
|
||||
if value.Kind() == reflect.Ptr && value.IsNil() {
|
||||
value = reflect.New(value.Type().Elem())
|
||||
}
|
||||
modelType := reflect.Indirect(value).Type()
|
||||
|
||||
if modelType.Kind() == reflect.Interface {
|
||||
modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
|
||||
}
|
||||
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
if modelType.PkgPath() == "" {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
// Cache the Schema for performance,
|
||||
// Use the modelType or modelType + schemaTable (if it present) as cache key.
|
||||
var schemaCacheKey interface{}
|
||||
if specialTableName != "" {
|
||||
schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
|
||||
} else {
|
||||
schemaCacheKey = modelType
|
||||
}
|
||||
|
||||
// Load exist schema cache, return if exists
|
||||
if v, ok := cacheStore.Load(schemaCacheKey); ok {
|
||||
s := v.(*Schema)
|
||||
// Wait for the initialization of other goroutines to complete
|
||||
<-s.initialized
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
modelValue := reflect.New(modelType)
|
||||
tableName := namer.TableName(modelType.Name())
|
||||
if tabler, ok := modelValue.Interface().(Tabler); ok {
|
||||
tableName = tabler.TableName()
|
||||
}
|
||||
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
|
||||
tableName = tabler.TableName(namer)
|
||||
}
|
||||
if en, ok := namer.(embeddedNamer); ok {
|
||||
tableName = en.Table
|
||||
}
|
||||
if specialTableName != "" && specialTableName != tableName {
|
||||
tableName = specialTableName
|
||||
}
|
||||
|
||||
schema := &Schema{
|
||||
Name: modelType.Name(),
|
||||
ModelType: modelType,
|
||||
Table: tableName,
|
||||
FieldsByName: map[string]*Field{},
|
||||
FieldsByBindName: map[string]*Field{},
|
||||
FieldsByDBName: map[string]*Field{},
|
||||
Relationships: Relationships{Relations: map[string]*Relationship{}},
|
||||
cacheStore: cacheStore,
|
||||
namer: namer,
|
||||
initialized: make(chan struct{}),
|
||||
}
|
||||
// When the schema initialization is completed, the channel will be closed
|
||||
defer close(schema.initialized)
|
||||
|
||||
// Load exist schema cache, return if exists
|
||||
if v, ok := cacheStore.Load(schemaCacheKey); ok {
|
||||
s := v.(*Schema)
|
||||
// Wait for the initialization of other goroutines to complete
|
||||
<-s.initialized
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
|
||||
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
|
||||
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
|
||||
} else {
|
||||
schema.Fields = append(schema.Fields, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.DBName == "" && field.DataType != "" {
|
||||
field.DBName = namer.ColumnName(schema.Table, field.Name)
|
||||
}
|
||||
|
||||
bindName := field.BindName()
|
||||
if field.DBName != "" {
|
||||
// nonexistence or shortest path or first appear prioritized if has permission
|
||||
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
|
||||
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
|
||||
schema.DBNames = append(schema.DBNames, field.DBName)
|
||||
}
|
||||
schema.FieldsByDBName[field.DBName] = field
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
|
||||
if v != nil && v.PrimaryKey {
|
||||
for idx, f := range schema.PrimaryFields {
|
||||
if f == v {
|
||||
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if field.PrimaryKey {
|
||||
schema.PrimaryFields = append(schema.PrimaryFields, field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
}
|
||||
if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
|
||||
schema.FieldsByBindName[bindName] = field
|
||||
}
|
||||
|
||||
field.setupValuerAndSetter()
|
||||
}
|
||||
|
||||
prioritizedPrimaryField := schema.LookUpField("id")
|
||||
if prioritizedPrimaryField == nil {
|
||||
prioritizedPrimaryField = schema.LookUpField("ID")
|
||||
}
|
||||
|
||||
if prioritizedPrimaryField != nil {
|
||||
if prioritizedPrimaryField.PrimaryKey {
|
||||
schema.PrioritizedPrimaryField = prioritizedPrimaryField
|
||||
} else if len(schema.PrimaryFields) == 0 {
|
||||
prioritizedPrimaryField.PrimaryKey = true
|
||||
schema.PrioritizedPrimaryField = prioritizedPrimaryField
|
||||
schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
|
||||
}
|
||||
}
|
||||
|
||||
if schema.PrioritizedPrimaryField == nil {
|
||||
if len(schema.PrimaryFields) == 1 {
|
||||
schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
|
||||
} else if len(schema.PrimaryFields) > 1 {
|
||||
// If there are multiple primary keys, the AUTOINCREMENT field is prioritized
|
||||
for _, field := range schema.PrimaryFields {
|
||||
if field.AutoIncrement {
|
||||
schema.PrioritizedPrimaryField = field
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, field := range schema.PrimaryFields {
|
||||
schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
|
||||
}
|
||||
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
|
||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||
}
|
||||
}
|
||||
|
||||
if field := schema.PrioritizedPrimaryField; field != nil {
|
||||
switch field.GORMDataType {
|
||||
case Int, Uint:
|
||||
if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
|
||||
if !field.HasDefaultValue || field.DefaultValueInterface != nil {
|
||||
schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
|
||||
}
|
||||
|
||||
field.HasDefaultValue = true
|
||||
field.AutoIncrement = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
callbackTypes := []callbackType{
|
||||
callbackTypeBeforeCreate, callbackTypeAfterCreate,
|
||||
callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
|
||||
callbackTypeBeforeSave, callbackTypeAfterSave,
|
||||
callbackTypeBeforeDelete, callbackTypeAfterDelete,
|
||||
callbackTypeAfterFind,
|
||||
}
|
||||
for _, cbName := range callbackTypes {
|
||||
if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
|
||||
switch methodValue.Type().String() {
|
||||
case "func(*gorm.DB) error": // TODO hack
|
||||
reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
|
||||
default:
|
||||
logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the schema
|
||||
if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
|
||||
s := v.(*Schema)
|
||||
// Wait for the initialization of other goroutines to complete
|
||||
<-s.initialized
|
||||
return s, s.err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if schema.err != nil {
|
||||
logger.Default.Error(context.Background(), schema.err.Error())
|
||||
cacheStore.Delete(modelType)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
|
||||
for _, field := range schema.Fields {
|
||||
if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
|
||||
if schema.parseRelation(field); schema.err != nil {
|
||||
return schema, schema.err
|
||||
} else {
|
||||
schema.FieldsByName[field.Name] = field
|
||||
schema.FieldsByBindName[field.BindName()] = field
|
||||
}
|
||||
}
|
||||
|
||||
fieldValue := reflect.New(field.IndirectFieldType)
|
||||
fieldInterface := fieldValue.Interface()
|
||||
if fc, ok := fieldInterface.(CreateClausesInterface); ok {
|
||||
field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldInterface.(QueryClausesInterface); ok {
|
||||
field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
|
||||
field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
|
||||
}
|
||||
|
||||
if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
|
||||
field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return schema, schema.err
|
||||
}
|
||||
|
||||
// This unrolling is needed to show to the compiler the exact set of methods
|
||||
// that can be used on the modelType.
|
||||
// Prior to go1.22 any use of MethodByName would cause the linker to
|
||||
// abandon dead code elimination for the entire binary.
|
||||
// As of go1.22 the compiler supports one special case of a string constant
|
||||
// being passed to MethodByName. For enterprise customers or those building
|
||||
// large binaries, this gives a significant reduction in binary size.
|
||||
// https://github.com/golang/go/issues/62257
|
||||
func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
|
||||
switch cbType {
|
||||
case callbackTypeBeforeCreate:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeCreate))
|
||||
case callbackTypeAfterCreate:
|
||||
return modelType.MethodByName(string(callbackTypeAfterCreate))
|
||||
case callbackTypeBeforeUpdate:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeUpdate))
|
||||
case callbackTypeAfterUpdate:
|
||||
return modelType.MethodByName(string(callbackTypeAfterUpdate))
|
||||
case callbackTypeBeforeSave:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeSave))
|
||||
case callbackTypeAfterSave:
|
||||
return modelType.MethodByName(string(callbackTypeAfterSave))
|
||||
case callbackTypeBeforeDelete:
|
||||
return modelType.MethodByName(string(callbackTypeBeforeDelete))
|
||||
case callbackTypeAfterDelete:
|
||||
return modelType.MethodByName(string(callbackTypeAfterDelete))
|
||||
case callbackTypeAfterFind:
|
||||
return modelType.MethodByName(string(callbackTypeAfterFind))
|
||||
default:
|
||||
return reflect.ValueOf(nil)
|
||||
}
|
||||
}
|
||||
|
||||
func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
|
||||
modelType := reflect.ValueOf(dest).Type()
|
||||
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
if modelType.PkgPath() == "" {
|
||||
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
|
||||
}
|
||||
|
||||
if v, ok := cacheStore.Load(modelType); ok {
|
||||
return v.(*Schema), nil
|
||||
}
|
||||
|
||||
return Parse(dest, cacheStore, namer)
|
||||
}
|
||||
170
vendor/gorm.io/gorm/schema/serializer.go
generated
vendored
Normal file
170
vendor/gorm.io/gorm/schema/serializer.go
generated
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var serializerMap = sync.Map{}
|
||||
|
||||
// RegisterSerializer register serializer
|
||||
func RegisterSerializer(name string, serializer SerializerInterface) {
|
||||
serializerMap.Store(strings.ToLower(name), serializer)
|
||||
}
|
||||
|
||||
// GetSerializer get serializer
|
||||
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
|
||||
v, ok := serializerMap.Load(strings.ToLower(name))
|
||||
if ok {
|
||||
serializer, ok = v.(SerializerInterface)
|
||||
}
|
||||
return serializer, ok
|
||||
}
|
||||
|
||||
func init() {
|
||||
RegisterSerializer("json", JSONSerializer{})
|
||||
RegisterSerializer("unixtime", UnixSecondSerializer{})
|
||||
RegisterSerializer("gob", GobSerializer{})
|
||||
}
|
||||
|
||||
// Serializer field value serializer
|
||||
type serializer struct {
|
||||
Field *Field
|
||||
Serializer SerializerInterface
|
||||
SerializeValuer SerializerValuerInterface
|
||||
Destination reflect.Value
|
||||
Context context.Context
|
||||
value interface{}
|
||||
fieldValue interface{}
|
||||
}
|
||||
|
||||
// Scan implements sql.Scanner interface
|
||||
func (s *serializer) Scan(value interface{}) error {
|
||||
s.value = value
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements driver.Valuer interface
|
||||
func (s serializer) Value() (driver.Value, error) {
|
||||
return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
|
||||
}
|
||||
|
||||
// SerializerInterface serializer interface
|
||||
type SerializerInterface interface {
|
||||
Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
|
||||
SerializerValuerInterface
|
||||
}
|
||||
|
||||
// SerializerValuerInterface serializer valuer interface
|
||||
type SerializerValuerInterface interface {
|
||||
Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
|
||||
}
|
||||
|
||||
// JSONSerializer json serializer
|
||||
type JSONSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
fieldValue := reflect.New(field.FieldType)
|
||||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal JSONB value: %#v", dbValue)
|
||||
}
|
||||
|
||||
if len(bytes) > 0 {
|
||||
err = json.Unmarshal(bytes, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||
return
|
||||
}
|
||||
|
||||
// Value implements serializer interface
|
||||
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
result, err := json.Marshal(fieldValue)
|
||||
if string(result) == "null" {
|
||||
if field.TagSettings["NOT NULL"] != "" {
|
||||
return "", nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return string(result), err
|
||||
}
|
||||
|
||||
// UnixSecondSerializer json serializer
|
||||
type UnixSecondSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
t := sql.NullTime{}
|
||||
if err = t.Scan(dbValue); err == nil && t.Valid {
|
||||
err = field.Set(ctx, dst, t.Time.Unix())
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Value implements serializer interface
|
||||
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
|
||||
rv := reflect.ValueOf(fieldValue)
|
||||
switch v := fieldValue.(type) {
|
||||
case int64, int, uint, uint64, int32, uint32, int16, uint16:
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0)
|
||||
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
|
||||
if rv.IsZero() {
|
||||
return nil, nil
|
||||
}
|
||||
result = time.Unix(reflect.Indirect(rv).Int(), 0)
|
||||
default:
|
||||
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GobSerializer gob serializer
|
||||
type GobSerializer struct{}
|
||||
|
||||
// Scan implements serializer interface
|
||||
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
|
||||
fieldValue := reflect.New(field.FieldType)
|
||||
|
||||
if dbValue != nil {
|
||||
var bytesValue []byte
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytesValue = v
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
|
||||
}
|
||||
if len(bytesValue) > 0 {
|
||||
decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
|
||||
err = decoder.Decode(fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
|
||||
return
|
||||
}
|
||||
|
||||
// Value implements serializer interface
|
||||
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
err := gob.NewEncoder(buf).Encode(fieldValue)
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
213
vendor/gorm.io/gorm/schema/utils.go
generated
vendored
Normal file
213
vendor/gorm.io/gorm/schema/utils.go
generated
vendored
Normal file
@@ -0,0 +1,213 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
var embeddedCacheKey = "embedded_cache_store"
|
||||
|
||||
func ParseTagSetting(str string, sep string) map[string]string {
|
||||
settings := map[string]string{}
|
||||
names := strings.Split(str, sep)
|
||||
|
||||
for i := 0; i < len(names); i++ {
|
||||
j := i
|
||||
if len(names[j]) > 0 {
|
||||
for {
|
||||
if names[j][len(names[j])-1] == '\\' {
|
||||
i++
|
||||
names[j] = names[j][0:len(names[j])-1] + sep + names[i]
|
||||
names[i] = ""
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
values := strings.Split(names[j], ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(values[0]))
|
||||
|
||||
if len(values) >= 2 {
|
||||
settings[k] = strings.Join(values[1:], ":")
|
||||
} else if k != "" {
|
||||
settings[k] = k
|
||||
}
|
||||
}
|
||||
|
||||
return settings
|
||||
}
|
||||
|
||||
func toColumns(val string) (results []string) {
|
||||
if val != "" {
|
||||
for _, v := range strings.Split(val, ",") {
|
||||
results = append(results, strings.TrimSpace(v))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag {
|
||||
for _, name := range names {
|
||||
tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}"))
|
||||
}
|
||||
return tag
|
||||
}
|
||||
|
||||
func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag {
|
||||
t := tag.Get("gorm")
|
||||
if strings.Contains(t, value) {
|
||||
return tag
|
||||
}
|
||||
return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t))
|
||||
}
|
||||
|
||||
// GetRelationsValues get relations's values from a reflect value
|
||||
func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) {
|
||||
for _, rel := range rels {
|
||||
reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(rel.FieldSchema.ModelType)), 0, 1)
|
||||
|
||||
appendToResults := func(value reflect.Value) {
|
||||
if _, isZero := rel.Field.ValueOf(ctx, value); !isZero {
|
||||
result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value))
|
||||
switch result.Kind() {
|
||||
case reflect.Struct:
|
||||
reflectResults = reflect.Append(reflectResults, result.Addr())
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < result.Len(); i++ {
|
||||
if elem := result.Index(i); elem.Kind() == reflect.Ptr {
|
||||
reflectResults = reflect.Append(reflectResults, elem)
|
||||
} else {
|
||||
reflectResults = reflect.Append(reflectResults, elem.Addr())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
appendToResults(reflectValue)
|
||||
case reflect.Slice:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
appendToResults(reflectValue.Index(i))
|
||||
}
|
||||
}
|
||||
|
||||
reflectValue = reflectResults
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// GetIdentityFieldValuesMap get identity map from fields
|
||||
func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
||||
var (
|
||||
results = [][]interface{}{}
|
||||
dataResults = map[string][]reflect.Value{}
|
||||
loaded = map[interface{}]bool{}
|
||||
notZero, zero bool
|
||||
)
|
||||
|
||||
if reflectValue.Kind() == reflect.Ptr ||
|
||||
reflectValue.Kind() == reflect.Interface {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
results = [][]interface{}{make([]interface{}, len(fields))}
|
||||
|
||||
for idx, field := range fields {
|
||||
results[0][idx], zero = field.ValueOf(ctx, reflectValue)
|
||||
notZero = notZero || !zero
|
||||
}
|
||||
|
||||
if !notZero {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
elem := reflectValue.Index(i)
|
||||
elemKey := elem.Interface()
|
||||
if elem.Kind() != reflect.Ptr && elem.CanAddr() {
|
||||
elemKey = elem.Addr().Interface()
|
||||
}
|
||||
|
||||
if _, ok := loaded[elemKey]; ok {
|
||||
continue
|
||||
}
|
||||
loaded[elemKey] = true
|
||||
|
||||
fieldValues := make([]interface{}, len(fields))
|
||||
notZero = false
|
||||
for idx, field := range fields {
|
||||
fieldValues[idx], zero = field.ValueOf(ctx, elem)
|
||||
notZero = notZero || !zero
|
||||
}
|
||||
|
||||
if notZero {
|
||||
dataKey := utils.ToStringKey(fieldValues...)
|
||||
if _, ok := dataResults[dataKey]; !ok {
|
||||
results = append(results, fieldValues)
|
||||
dataResults[dataKey] = []reflect.Value{elem}
|
||||
} else {
|
||||
dataResults[dataKey] = append(dataResults[dataKey], elem)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return dataResults, results
|
||||
}
|
||||
|
||||
// GetIdentityFieldValuesMapFromValues get identity map from fields
|
||||
func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) {
|
||||
resultsMap := map[string][]reflect.Value{}
|
||||
results := [][]interface{}{}
|
||||
|
||||
for _, v := range values {
|
||||
rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields)
|
||||
for k, v := range rm {
|
||||
resultsMap[k] = append(resultsMap[k], v...)
|
||||
}
|
||||
results = append(results, rs...)
|
||||
}
|
||||
return resultsMap, results
|
||||
}
|
||||
|
||||
// ToQueryValues to query values
|
||||
func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
|
||||
queryValues := make([]interface{}, len(foreignValues))
|
||||
if len(foreignKeys) == 1 {
|
||||
for idx, r := range foreignValues {
|
||||
queryValues[idx] = r[0]
|
||||
}
|
||||
|
||||
return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
|
||||
}
|
||||
|
||||
columns := make([]clause.Column, len(foreignKeys))
|
||||
for idx, key := range foreignKeys {
|
||||
columns[idx] = clause.Column{Table: table, Name: key}
|
||||
}
|
||||
|
||||
for idx, r := range foreignValues {
|
||||
queryValues[idx] = r
|
||||
}
|
||||
|
||||
return columns, queryValues
|
||||
}
|
||||
|
||||
type embeddedNamer struct {
|
||||
Table string
|
||||
Namer
|
||||
}
|
||||
170
vendor/gorm.io/gorm/soft_delete.go
generated
vendored
Normal file
170
vendor/gorm.io/gorm/soft_delete.go
generated
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
|
||||
"github.com/jinzhu/now"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type DeletedAt sql.NullTime
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *DeletedAt) Scan(value interface{}) error {
|
||||
return (*sql.NullTime)(n).Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n DeletedAt) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Time, nil
|
||||
}
|
||||
|
||||
func (n DeletedAt) MarshalJSON() ([]byte, error) {
|
||||
if n.Valid {
|
||||
return json.Marshal(n.Time)
|
||||
}
|
||||
return json.Marshal(nil)
|
||||
}
|
||||
|
||||
func (n *DeletedAt) UnmarshalJSON(b []byte) error {
|
||||
if string(b) == "null" {
|
||||
n.Valid = false
|
||||
return nil
|
||||
}
|
||||
err := json.Unmarshal(b, &n.Time)
|
||||
if err == nil {
|
||||
n.Valid = true
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
func parseZeroValueTag(f *schema.Field) sql.NullString {
|
||||
if v, ok := f.TagSettings["ZEROVALUE"]; ok {
|
||||
if _, err := now.Parse(v); err == nil {
|
||||
return sql.NullString{String: v, Valid: true}
|
||||
}
|
||||
}
|
||||
return sql.NullString{Valid: false}
|
||||
}
|
||||
|
||||
type SoftDeleteQueryClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
|
||||
if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped {
|
||||
if c, ok := stmt.Clauses["WHERE"]; ok {
|
||||
if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 {
|
||||
for _, expr := range where.Exprs {
|
||||
if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
|
||||
where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
|
||||
c.Expression = where
|
||||
stmt.Clauses["WHERE"] = c
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{
|
||||
clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
|
||||
}})
|
||||
stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteUpdateClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
|
||||
return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
|
||||
}
|
||||
|
||||
type SoftDeleteDeleteClause struct {
|
||||
ZeroValue sql.NullString
|
||||
Field *schema.Field
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Name() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
|
||||
}
|
||||
|
||||
func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
|
||||
if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
|
||||
curTime := stmt.DB.NowFunc()
|
||||
stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
|
||||
stmt.SetColumn(sd.Field.DBName, curTime, true)
|
||||
|
||||
if stmt.Schema != nil {
|
||||
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
|
||||
column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
|
||||
if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
|
||||
_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
|
||||
column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)
|
||||
|
||||
if len(values) > 0 {
|
||||
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SoftDeleteQueryClause(sd).ModifyStatement(stmt)
|
||||
stmt.AddClauseIfNotExists(clause.Update{})
|
||||
stmt.Build(stmt.DB.Callback().Update().Clauses...)
|
||||
}
|
||||
}
|
||||
742
vendor/gorm.io/gorm/statement.go
generated
vendored
Normal file
742
vendor/gorm.io/gorm/statement.go
generated
vendored
Normal file
@@ -0,0 +1,742 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
// Statement statement
|
||||
type Statement struct {
|
||||
*DB
|
||||
TableExpr *clause.Expr
|
||||
Table string
|
||||
Model interface{}
|
||||
Unscoped bool
|
||||
Dest interface{}
|
||||
ReflectValue reflect.Value
|
||||
Clauses map[string]clause.Clause
|
||||
BuildClauses []string
|
||||
Distinct bool
|
||||
Selects []string // selected columns
|
||||
Omits []string // omit columns
|
||||
Joins []join
|
||||
Preloads map[string][]interface{}
|
||||
Settings sync.Map
|
||||
ConnPool ConnPool
|
||||
Schema *schema.Schema
|
||||
Context context.Context
|
||||
RaiseErrorOnNotFound bool
|
||||
SkipHooks bool
|
||||
SQL strings.Builder
|
||||
Vars []interface{}
|
||||
CurDestIndex int
|
||||
attrs []interface{}
|
||||
assigns []interface{}
|
||||
scopes []func(*DB) *DB
|
||||
}
|
||||
|
||||
type join struct {
|
||||
Name string
|
||||
Conds []interface{}
|
||||
On *clause.Where
|
||||
Selects []string
|
||||
Omits []string
|
||||
JoinType clause.JoinType
|
||||
}
|
||||
|
||||
// StatementModifier statement modifier interface
|
||||
type StatementModifier interface {
|
||||
ModifyStatement(*Statement)
|
||||
}
|
||||
|
||||
// WriteString write string
|
||||
func (stmt *Statement) WriteString(str string) (int, error) {
|
||||
return stmt.SQL.WriteString(str)
|
||||
}
|
||||
|
||||
// WriteByte write byte
|
||||
func (stmt *Statement) WriteByte(c byte) error {
|
||||
return stmt.SQL.WriteByte(c)
|
||||
}
|
||||
|
||||
// WriteQuoted write quoted value
|
||||
func (stmt *Statement) WriteQuoted(value interface{}) {
|
||||
stmt.QuoteTo(&stmt.SQL, value)
|
||||
}
|
||||
|
||||
// QuoteTo write quoted value to writer
|
||||
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
|
||||
write := func(raw bool, str string) {
|
||||
if raw {
|
||||
writer.WriteString(str)
|
||||
} else {
|
||||
stmt.DB.Dialector.QuoteTo(writer, str)
|
||||
}
|
||||
}
|
||||
|
||||
switch v := field.(type) {
|
||||
case clause.Table:
|
||||
if v.Name == clause.CurrentTable {
|
||||
if stmt.TableExpr != nil {
|
||||
stmt.TableExpr.Build(stmt)
|
||||
} else {
|
||||
write(v.Raw, stmt.Table)
|
||||
}
|
||||
} else {
|
||||
write(v.Raw, v.Name)
|
||||
}
|
||||
|
||||
if v.Alias != "" {
|
||||
writer.WriteByte(' ')
|
||||
write(v.Raw, v.Alias)
|
||||
}
|
||||
case clause.Column:
|
||||
if v.Table != "" {
|
||||
if v.Table == clause.CurrentTable {
|
||||
write(v.Raw, stmt.Table)
|
||||
} else {
|
||||
write(v.Raw, v.Table)
|
||||
}
|
||||
writer.WriteByte('.')
|
||||
}
|
||||
|
||||
if v.Name == clause.PrimaryKey {
|
||||
if stmt.Schema == nil {
|
||||
stmt.DB.AddError(ErrModelValueRequired)
|
||||
} else if stmt.Schema.PrioritizedPrimaryField != nil {
|
||||
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
|
||||
} else if len(stmt.Schema.DBNames) > 0 {
|
||||
write(v.Raw, stmt.Schema.DBNames[0])
|
||||
} else {
|
||||
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
|
||||
}
|
||||
} else {
|
||||
write(v.Raw, v.Name)
|
||||
}
|
||||
|
||||
if v.Alias != "" {
|
||||
writer.WriteString(" AS ")
|
||||
write(v.Raw, v.Alias)
|
||||
}
|
||||
case []clause.Column:
|
||||
writer.WriteByte('(')
|
||||
for idx, d := range v {
|
||||
if idx > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
stmt.QuoteTo(writer, d)
|
||||
}
|
||||
writer.WriteByte(')')
|
||||
case clause.Expr:
|
||||
v.Build(stmt)
|
||||
case string:
|
||||
stmt.DB.Dialector.QuoteTo(writer, v)
|
||||
case []string:
|
||||
writer.WriteByte('(')
|
||||
for idx, d := range v {
|
||||
if idx > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
stmt.DB.Dialector.QuoteTo(writer, d)
|
||||
}
|
||||
writer.WriteByte(')')
|
||||
default:
|
||||
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
|
||||
}
|
||||
}
|
||||
|
||||
// Quote returns quoted value
|
||||
func (stmt *Statement) Quote(field interface{}) string {
|
||||
var builder strings.Builder
|
||||
stmt.QuoteTo(&builder, field)
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// AddVar add var
|
||||
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
|
||||
for idx, v := range vars {
|
||||
if idx > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
|
||||
switch v := v.(type) {
|
||||
case sql.NamedArg:
|
||||
stmt.Vars = append(stmt.Vars, v.Value)
|
||||
case clause.Column, clause.Table:
|
||||
stmt.QuoteTo(writer, v)
|
||||
case Valuer:
|
||||
reflectValue := reflect.ValueOf(v)
|
||||
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
|
||||
stmt.AddVar(writer, nil)
|
||||
} else {
|
||||
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
|
||||
}
|
||||
case clause.Interface:
|
||||
c := clause.Clause{Name: v.Name()}
|
||||
v.MergeClause(&c)
|
||||
c.Build(stmt)
|
||||
case clause.Expression:
|
||||
v.Build(stmt)
|
||||
case driver.Valuer:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
case []byte:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
case []interface{}:
|
||||
if len(v) > 0 {
|
||||
writer.WriteByte('(')
|
||||
stmt.AddVar(writer, v...)
|
||||
writer.WriteByte(')')
|
||||
} else {
|
||||
writer.WriteString("(NULL)")
|
||||
}
|
||||
case *DB:
|
||||
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
|
||||
if v.Statement.SQL.Len() > 0 {
|
||||
var (
|
||||
vars = subdb.Statement.Vars
|
||||
sql = v.Statement.SQL.String()
|
||||
)
|
||||
|
||||
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
|
||||
for _, vv := range vars {
|
||||
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
|
||||
bindvar := strings.Builder{}
|
||||
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
|
||||
sql = strings.Replace(sql, bindvar.String(), "?", 1)
|
||||
}
|
||||
|
||||
subdb.Statement.SQL.Reset()
|
||||
subdb.Statement.Vars = stmt.Vars
|
||||
if strings.Contains(sql, "@") {
|
||||
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
} else {
|
||||
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
|
||||
}
|
||||
} else {
|
||||
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
|
||||
subdb.callbacks.Query().Execute(subdb)
|
||||
}
|
||||
|
||||
writer.WriteString(subdb.Statement.SQL.String())
|
||||
stmt.Vars = subdb.Statement.Vars
|
||||
default:
|
||||
switch rv := reflect.ValueOf(v); rv.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if rv.Len() == 0 {
|
||||
writer.WriteString("(NULL)")
|
||||
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
} else {
|
||||
writer.WriteByte('(')
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
if i > 0 {
|
||||
writer.WriteByte(',')
|
||||
}
|
||||
stmt.AddVar(writer, rv.Index(i).Interface())
|
||||
}
|
||||
writer.WriteByte(')')
|
||||
}
|
||||
default:
|
||||
stmt.Vars = append(stmt.Vars, v)
|
||||
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddClause add clause
|
||||
func (stmt *Statement) AddClause(v clause.Interface) {
|
||||
if optimizer, ok := v.(StatementModifier); ok {
|
||||
optimizer.ModifyStatement(stmt)
|
||||
} else {
|
||||
name := v.Name()
|
||||
c := stmt.Clauses[name]
|
||||
c.Name = name
|
||||
v.MergeClause(&c)
|
||||
stmt.Clauses[name] = c
|
||||
}
|
||||
}
|
||||
|
||||
// AddClauseIfNotExists add clause if not exists
|
||||
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
|
||||
if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
|
||||
stmt.AddClause(v)
|
||||
}
|
||||
}
|
||||
|
||||
// BuildCondition build condition
|
||||
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
|
||||
if s, ok := query.(string); ok {
|
||||
// if it is a number, then treats it as primary key
|
||||
if _, err := strconv.Atoi(s); err != nil {
|
||||
if s == "" && len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
|
||||
// looks like a where condition
|
||||
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
||||
}
|
||||
|
||||
if len(args) > 0 && strings.Contains(s, "@") {
|
||||
// looks like a named query
|
||||
return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
|
||||
}
|
||||
|
||||
if strings.Contains(strings.TrimSpace(s), " ") {
|
||||
// looks like a where condition
|
||||
return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
|
||||
}
|
||||
|
||||
if len(args) == 1 {
|
||||
return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
conds := make([]clause.Expression, 0, 4)
|
||||
args = append([]interface{}{query}, args...)
|
||||
for idx, arg := range args {
|
||||
if arg == nil {
|
||||
continue
|
||||
}
|
||||
if valuer, ok := arg.(driver.Valuer); ok {
|
||||
arg, _ = valuer.Value()
|
||||
}
|
||||
|
||||
switch v := arg.(type) {
|
||||
case clause.Expression:
|
||||
conds = append(conds, v)
|
||||
case *DB:
|
||||
v.executeScopes()
|
||||
|
||||
if cs, ok := v.Statement.Clauses["WHERE"]; ok {
|
||||
if where, ok := cs.Expression.(clause.Where); ok {
|
||||
if len(where.Exprs) == 1 {
|
||||
if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
|
||||
where.Exprs[0] = clause.AndConditions(orConds)
|
||||
}
|
||||
}
|
||||
conds = append(conds, clause.And(where.Exprs...))
|
||||
} else if cs.Expression != nil {
|
||||
conds = append(conds, cs.Expression)
|
||||
}
|
||||
}
|
||||
case map[interface{}]interface{}:
|
||||
for i, j := range v {
|
||||
conds = append(conds, clause.Eq{Column: i, Value: j})
|
||||
}
|
||||
case map[string]string:
|
||||
keys := make([]string, 0, len(v))
|
||||
for i := range v {
|
||||
keys = append(keys, i)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
}
|
||||
case map[string]interface{}:
|
||||
keys := make([]string, 0, len(v))
|
||||
for i := range v {
|
||||
keys = append(keys, i)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for _, key := range keys {
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if _, ok := v[key].(driver.Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
} else if _, ok := v[key].(Valuer); ok {
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
} else {
|
||||
// optimize reflect value length
|
||||
valueLen := reflectValue.Len()
|
||||
values := make([]interface{}, valueLen)
|
||||
for i := 0; i < valueLen; i++ {
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: key, Values: values})
|
||||
}
|
||||
default:
|
||||
conds = append(conds, clause.Eq{Column: key, Value: v[key]})
|
||||
}
|
||||
}
|
||||
default:
|
||||
reflectValue := reflect.Indirect(reflect.ValueOf(arg))
|
||||
for reflectValue.Kind() == reflect.Ptr {
|
||||
reflectValue = reflectValue.Elem()
|
||||
}
|
||||
|
||||
if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
|
||||
selectedColumns := map[string]bool{}
|
||||
if idx == 0 {
|
||||
for _, v := range args[1:] {
|
||||
if vs, ok := v.(string); ok {
|
||||
selectedColumns[vs] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
restricted := len(selectedColumns) != 0
|
||||
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Struct:
|
||||
for _, field := range s.Fields {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < reflectValue.Len(); i++ {
|
||||
for _, field := range s.Fields {
|
||||
selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
|
||||
if selected || (!restricted && field.Readable) {
|
||||
if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
|
||||
if field.DBName != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
|
||||
} else if field.DataType != "" {
|
||||
conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if restricted {
|
||||
break
|
||||
}
|
||||
} else if !reflectValue.IsValid() {
|
||||
stmt.AddError(ErrInvalidData)
|
||||
} else if len(conds) == 0 {
|
||||
if len(args) == 1 {
|
||||
switch reflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
// optimize reflect value length
|
||||
valueLen := reflectValue.Len()
|
||||
values := make([]interface{}, valueLen)
|
||||
for i := 0; i < valueLen; i++ {
|
||||
values[i] = reflectValue.Index(i).Interface()
|
||||
}
|
||||
|
||||
if len(values) > 0 {
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(conds) > 0 {
|
||||
return []clause.Expression{clause.And(conds...)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build build sql with clauses names
|
||||
func (stmt *Statement) Build(clauses ...string) {
|
||||
var firstClauseWritten bool
|
||||
|
||||
for _, name := range clauses {
|
||||
if c, ok := stmt.Clauses[name]; ok {
|
||||
if firstClauseWritten {
|
||||
stmt.WriteByte(' ')
|
||||
}
|
||||
|
||||
firstClauseWritten = true
|
||||
if b, ok := stmt.DB.ClauseBuilders[name]; ok {
|
||||
b(c, stmt)
|
||||
} else {
|
||||
c.Build(stmt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (stmt *Statement) Parse(value interface{}) (err error) {
|
||||
return stmt.ParseWithSpecialTableName(value, "")
|
||||
}
|
||||
|
||||
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
|
||||
if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
|
||||
if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
|
||||
stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
|
||||
stmt.Table = tables[1]
|
||||
return
|
||||
}
|
||||
|
||||
stmt.Table = stmt.Schema.Table
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (stmt *Statement) clone() *Statement {
|
||||
newStmt := &Statement{
|
||||
TableExpr: stmt.TableExpr,
|
||||
Table: stmt.Table,
|
||||
Model: stmt.Model,
|
||||
Unscoped: stmt.Unscoped,
|
||||
Dest: stmt.Dest,
|
||||
ReflectValue: stmt.ReflectValue,
|
||||
Clauses: map[string]clause.Clause{},
|
||||
Distinct: stmt.Distinct,
|
||||
Selects: stmt.Selects,
|
||||
Omits: stmt.Omits,
|
||||
Preloads: map[string][]interface{}{},
|
||||
ConnPool: stmt.ConnPool,
|
||||
Schema: stmt.Schema,
|
||||
Context: stmt.Context,
|
||||
RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
|
||||
SkipHooks: stmt.SkipHooks,
|
||||
}
|
||||
|
||||
if stmt.SQL.Len() > 0 {
|
||||
newStmt.SQL.WriteString(stmt.SQL.String())
|
||||
newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
|
||||
newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
|
||||
}
|
||||
|
||||
for k, c := range stmt.Clauses {
|
||||
newStmt.Clauses[k] = c
|
||||
}
|
||||
|
||||
for k, p := range stmt.Preloads {
|
||||
newStmt.Preloads[k] = p
|
||||
}
|
||||
|
||||
if len(stmt.Joins) > 0 {
|
||||
newStmt.Joins = make([]join, len(stmt.Joins))
|
||||
copy(newStmt.Joins, stmt.Joins)
|
||||
}
|
||||
|
||||
if len(stmt.scopes) > 0 {
|
||||
newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
|
||||
copy(newStmt.scopes, stmt.scopes)
|
||||
}
|
||||
|
||||
stmt.Settings.Range(func(k, v interface{}) bool {
|
||||
newStmt.Settings.Store(k, v)
|
||||
return true
|
||||
})
|
||||
|
||||
return newStmt
|
||||
}
|
||||
|
||||
// SetColumn set column's value
|
||||
//
|
||||
// stmt.SetColumn("Name", "jinzhu") // Hooks Method
|
||||
// stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
|
||||
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
|
||||
if v, ok := stmt.Dest.(map[string]interface{}); ok {
|
||||
v[name] = value
|
||||
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
|
||||
for _, m := range v {
|
||||
m[name] = value
|
||||
}
|
||||
} else if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
destValue := reflect.ValueOf(stmt.Dest)
|
||||
for destValue.Kind() == reflect.Ptr {
|
||||
destValue = destValue.Elem()
|
||||
}
|
||||
|
||||
if stmt.ReflectValue != destValue {
|
||||
if !destValue.CanAddr() {
|
||||
destValueCanAddr := reflect.New(destValue.Type())
|
||||
destValueCanAddr.Elem().Set(destValue)
|
||||
stmt.Dest = destValueCanAddr.Interface()
|
||||
destValue = destValueCanAddr.Elem()
|
||||
}
|
||||
|
||||
switch destValue.Kind() {
|
||||
case reflect.Struct:
|
||||
stmt.AddError(field.Set(stmt.Context, destValue, value))
|
||||
default:
|
||||
stmt.AddError(ErrInvalidData)
|
||||
}
|
||||
}
|
||||
|
||||
switch stmt.ReflectValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if len(fromCallbacks) > 0 {
|
||||
for i := 0; i < stmt.ReflectValue.Len(); i++ {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
|
||||
}
|
||||
} else {
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !stmt.ReflectValue.CanAddr() {
|
||||
stmt.AddError(ErrInvalidValue)
|
||||
return
|
||||
}
|
||||
|
||||
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
|
||||
}
|
||||
} else {
|
||||
stmt.AddError(ErrInvalidField)
|
||||
}
|
||||
} else {
|
||||
stmt.AddError(ErrInvalidField)
|
||||
}
|
||||
}
|
||||
|
||||
// Changed check model changed or not when updating
|
||||
func (stmt *Statement) Changed(fields ...string) bool {
|
||||
modelValue := stmt.ReflectValue
|
||||
switch modelValue.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
|
||||
}
|
||||
|
||||
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
|
||||
changed := func(field *schema.Field) bool {
|
||||
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
|
||||
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
|
||||
if mv, mok := stmt.Dest.(map[string]interface{}); mok {
|
||||
if fv, ok := mv[field.Name]; ok {
|
||||
return !utils.AssertEqual(fv, fieldValue)
|
||||
} else if fv, ok := mv[field.DBName]; ok {
|
||||
return !utils.AssertEqual(fv, fieldValue)
|
||||
}
|
||||
} else {
|
||||
destValue := reflect.ValueOf(stmt.Dest)
|
||||
for destValue.Kind() == reflect.Ptr {
|
||||
destValue = destValue.Elem()
|
||||
}
|
||||
|
||||
changedValue, zero := field.ValueOf(stmt.Context, destValue)
|
||||
if v {
|
||||
return !utils.AssertEqual(changedValue, fieldValue)
|
||||
}
|
||||
return !zero && !utils.AssertEqual(changedValue, fieldValue)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if len(fields) == 0 {
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
if changed(field) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, name := range fields {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
if changed(field) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var matchName = func() func(tableColumn string) (table, column string) {
|
||||
nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
|
||||
return func(tableColumn string) (table, column string) {
|
||||
if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
|
||||
table = matches[1]
|
||||
star := matches[2]
|
||||
columnName := matches[3]
|
||||
if star != "" {
|
||||
return table, star
|
||||
}
|
||||
return table, columnName
|
||||
}
|
||||
return "", ""
|
||||
}
|
||||
}()
|
||||
|
||||
// SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
|
||||
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
|
||||
results := map[string]bool{}
|
||||
notRestricted := false
|
||||
|
||||
processColumn := func(column string, result bool) {
|
||||
if stmt.Schema == nil {
|
||||
results[column] = result
|
||||
} else if column == "*" {
|
||||
notRestricted = result
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else if column == clause.Associations {
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
results[rel.Name] = result
|
||||
}
|
||||
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
|
||||
results[field.DBName] = result
|
||||
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
|
||||
if col == "*" {
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
results[dbName] = result
|
||||
}
|
||||
} else {
|
||||
results[col] = result
|
||||
}
|
||||
} else {
|
||||
results[column] = result
|
||||
}
|
||||
}
|
||||
|
||||
// select columns
|
||||
for _, column := range stmt.Selects {
|
||||
processColumn(column, true)
|
||||
}
|
||||
|
||||
// omit columns
|
||||
for _, column := range stmt.Omits {
|
||||
processColumn(column, false)
|
||||
}
|
||||
|
||||
if stmt.Schema != nil {
|
||||
for _, field := range stmt.Schema.FieldsByName {
|
||||
name := field.DBName
|
||||
if name == "" {
|
||||
name = field.Name
|
||||
}
|
||||
|
||||
if requireCreate && !field.Creatable {
|
||||
results[name] = false
|
||||
} else if requireUpdate && !field.Updatable {
|
||||
results[name] = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results, !notRestricted && len(stmt.Selects) > 0
|
||||
}
|
||||
160
vendor/gorm.io/gorm/utils/utils.go
generated
vendored
Normal file
160
vendor/gorm.io/gorm/utils/utils.go
generated
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
var gormSourceDir string
|
||||
|
||||
func init() {
|
||||
_, file, _, _ := runtime.Caller(0)
|
||||
// compatible solution to get gorm source directory with various operating systems
|
||||
gormSourceDir = sourceDir(file)
|
||||
}
|
||||
|
||||
func sourceDir(file string) string {
|
||||
dir := filepath.Dir(file)
|
||||
dir = filepath.Dir(dir)
|
||||
|
||||
s := filepath.Dir(dir)
|
||||
if filepath.Base(s) != "gorm.io" {
|
||||
s = dir
|
||||
}
|
||||
return filepath.ToSlash(s) + "/"
|
||||
}
|
||||
|
||||
// FileWithLineNum return the file name and line number of the current file
|
||||
func FileWithLineNum() string {
|
||||
// the second caller usually from gorm internal, so set i start from 2
|
||||
for i := 2; i < 15; i++ {
|
||||
_, file, line, ok := runtime.Caller(i)
|
||||
if ok && (!strings.HasPrefix(file, gormSourceDir) || strings.HasSuffix(file, "_test.go")) &&
|
||||
!strings.HasSuffix(file, ".gen.go") {
|
||||
return file + ":" + strconv.FormatInt(int64(line), 10)
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func IsValidDBNameChar(c rune) bool {
|
||||
return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' && c != '*' && c != '_' && c != '$' && c != '@'
|
||||
}
|
||||
|
||||
// CheckTruth check string true or not
|
||||
func CheckTruth(vals ...string) bool {
|
||||
for _, val := range vals {
|
||||
if val != "" && !strings.EqualFold(val, "false") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ToStringKey(values ...interface{}) string {
|
||||
results := make([]string, len(values))
|
||||
|
||||
for idx, value := range values {
|
||||
if valuer, ok := value.(driver.Valuer); ok {
|
||||
value, _ = valuer.Value()
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
results[idx] = v
|
||||
case []byte:
|
||||
results[idx] = string(v)
|
||||
case uint:
|
||||
results[idx] = strconv.FormatUint(uint64(v), 10)
|
||||
default:
|
||||
results[idx] = fmt.Sprint(reflect.Indirect(reflect.ValueOf(v)).Interface())
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(results, "_")
|
||||
}
|
||||
|
||||
func Contains(elems []string, elem string) bool {
|
||||
for _, e := range elems {
|
||||
if elem == e {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func AssertEqual(x, y interface{}) bool {
|
||||
if reflect.DeepEqual(x, y) {
|
||||
return true
|
||||
}
|
||||
if x == nil || y == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
xval := reflect.ValueOf(x)
|
||||
yval := reflect.ValueOf(y)
|
||||
if xval.Kind() == reflect.Ptr && xval.IsNil() ||
|
||||
yval.Kind() == reflect.Ptr && yval.IsNil() {
|
||||
return false
|
||||
}
|
||||
|
||||
if valuer, ok := x.(driver.Valuer); ok {
|
||||
x, _ = valuer.Value()
|
||||
}
|
||||
if valuer, ok := y.(driver.Valuer); ok {
|
||||
y, _ = valuer.Value()
|
||||
}
|
||||
return reflect.DeepEqual(x, y)
|
||||
}
|
||||
|
||||
func ToString(value interface{}) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case int:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(v), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10)
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(v), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(v, 10)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const nestedRelationSplit = "__"
|
||||
|
||||
// NestedRelationName nested relationships like `Manager__Company`
|
||||
func NestedRelationName(prefix, name string) string {
|
||||
return prefix + nestedRelationSplit + name
|
||||
}
|
||||
|
||||
// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}`
|
||||
func SplitNestedRelationName(name string) []string {
|
||||
return strings.Split(name, nestedRelationSplit)
|
||||
}
|
||||
|
||||
// JoinNestedRelationNames nested relationships like `Manager__Company`
|
||||
func JoinNestedRelationNames(relationNames []string) string {
|
||||
return strings.Join(relationNames, nestedRelationSplit)
|
||||
}
|
||||
Reference in New Issue
Block a user