提取修改猪群数量逻辑

This commit is contained in:
2025-10-06 15:08:32 +08:00
parent 59b6977367
commit 91e18c432c
8 changed files with 66 additions and 54 deletions

View File

@@ -26,7 +26,7 @@ type MockDeviceRepository struct {
mock.Mock mock.Mock
} }
// Create 模拟 DeviceRepository 的 Create 方法 // CreateTx 模拟 DeviceRepository 的 CreateTx 方法
func (m *MockDeviceRepository) Create(device *models.Device) error { func (m *MockDeviceRepository) Create(device *models.Device) error {
args := m.Called(device) args := m.Called(device)
return args.Error(0) return args.Error(0)
@@ -169,7 +169,7 @@ func TestCreateDevice(t *testing.T) {
Properties: controller.Properties(`{"lora_address":"0x1234"}`), Properties: controller.Properties(`{"lora_address":"0x1234"}`),
}, },
mockRepoSetup: func(m *MockDeviceRepository) { mockRepoSetup: func(m *MockDeviceRepository) {
m.On("Create", mock.MatchedBy(func(dev *models.Device) bool { m.On("CreateTx", mock.MatchedBy(func(dev *models.Device) bool {
// 检查 Name 字段 // 检查 Name 字段
nameMatch := dev.Name == "主控A" nameMatch := dev.Name == "主控A"
// 检查 Type 字段 // 检查 Type 字段
@@ -215,7 +215,7 @@ func TestCreateDevice(t *testing.T) {
Properties: controller.Properties(`{"bus_id":1,"bus_address":10}`), Properties: controller.Properties(`{"bus_id":1,"bus_address":10}`),
}, },
mockRepoSetup: func(m *MockDeviceRepository) { mockRepoSetup: func(m *MockDeviceRepository) {
m.On("Create", mock.Anything).Return(nil).Run(func(args mock.Arguments) { m.On("CreateTx", mock.Anything).Return(nil).Run(func(args mock.Arguments) {
arg := args.Get(0).(*models.Device) arg := args.Get(0).(*models.Device)
arg.ID = 2 arg.ID = 2
arg.CreatedAt = time.Now() arg.CreatedAt = time.Now()
@@ -259,7 +259,7 @@ func TestCreateDevice(t *testing.T) {
Type: models.DeviceTypeDevice, Type: models.DeviceTypeDevice,
}, },
mockRepoSetup: func(m *MockDeviceRepository) { mockRepoSetup: func(m *MockDeviceRepository) {
m.On("Create", mock.Anything).Return(errors.New("db error")).Once() m.On("CreateTx", mock.Anything).Return(errors.New("db error")).Once()
}, },
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedCode: controller.CodeInternalError, expectedCode: controller.CodeInternalError,
@@ -276,9 +276,9 @@ func TestCreateDevice(t *testing.T) {
Properties: controller.Properties(`{invalid json}`), Properties: controller.Properties(`{invalid json}`),
}, },
mockRepoSetup: func(m *MockDeviceRepository) { mockRepoSetup: func(m *MockDeviceRepository) {
// 期望 Create 方法被调用,并返回一个模拟的数据库错误 // 期望 CreateTx 方法被调用,并返回一个模拟的数据库错误
// 这个错误模拟的是数据库层因为 Properties 字段的 JSON 格式无效而拒绝保存 // 这个错误模拟的是数据库层因为 Properties 字段的 JSON 格式无效而拒绝保存
m.On("Create", mock.Anything).Return(errors.New("database error: invalid json format")).Run(func(args mock.Arguments) { m.On("CreateTx", mock.Anything).Return(errors.New("database error: invalid json format")).Run(func(args mock.Arguments) {
dev := args.Get(0).(*models.Device) dev := args.Get(0).(*models.Device)
assert.Equal(t, "无效JSON设备", dev.Name) assert.Equal(t, "无效JSON设备", dev.Name)
assert.Equal(t, models.DeviceTypeDevice, dev.Type) assert.Equal(t, models.DeviceTypeDevice, dev.Type)

View File

@@ -25,7 +25,7 @@ type MockUserRepository struct {
mock.Mock mock.Mock
} }
// Create 模拟 UserRepository 的 Create 方法 // CreateTx 模拟 UserRepository 的 CreateTx 方法
func (m *MockUserRepository) Create(user *models.User) error { func (m *MockUserRepository) Create(user *models.User) error {
args := m.Called(user) args := m.Called(user)
return args.Error(0) return args.Error(0)
@@ -90,8 +90,8 @@ func TestCreateUser(t *testing.T) {
Password: "password123", Password: "password123",
}, },
mockRepoSetup: func(m *MockUserRepository) { mockRepoSetup: func(m *MockUserRepository) {
// 模拟 Create 成功 // 模拟 CreateTx 成功
m.On("Create", mock.AnythingOfType("*models.User")).Return(nil).Run(func(args mock.Arguments) { m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(nil).Run(func(args mock.Arguments) {
// 模拟数据库自动填充 ID // 模拟数据库自动填充 ID
userArg := args.Get(0).(*models.User) userArg := args.Get(0).(*models.User)
userArg.ID = 1 // 设置一个非零的 ID userArg.ID = 1 // 设置一个非零的 ID
@@ -114,7 +114,7 @@ func TestCreateUser(t *testing.T) {
Password: "123", // 密码少于6位 Password: "123", // 密码少于6位
}, },
mockRepoSetup: func(m *MockUserRepository) { mockRepoSetup: func(m *MockUserRepository) {
// 不会调用 Create 或 FindByUsername // 不会调用 CreateTx 或 FindByUsername
}, },
expectedResponse: map[string]interface{}{ expectedResponse: map[string]interface{}{
"code": float64(controller.CodeBadRequest), "code": float64(controller.CodeBadRequest),
@@ -128,7 +128,7 @@ func TestCreateUser(t *testing.T) {
Password: "password123", Password: "password123",
}, },
mockRepoSetup: func(m *MockUserRepository) { mockRepoSetup: func(m *MockUserRepository) {
// 不会调用 Create 或 FindByUsername // 不会调用 CreateTx 或 FindByUsername
}, },
expectedResponse: map[string]interface{}{ expectedResponse: map[string]interface{}{
"code": float64(controller.CodeBadRequest), "code": float64(controller.CodeBadRequest),
@@ -143,8 +143,8 @@ func TestCreateUser(t *testing.T) {
Password: "password123", Password: "password123",
}, },
mockRepoSetup: func(m *MockUserRepository) { mockRepoSetup: func(m *MockUserRepository) {
// 模拟 Create 失败,因为用户名已存在 // 模拟 CreateTx 失败,因为用户名已存在
m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("duplicate entry")).Once() m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(errors.New("duplicate entry")).Once()
// 模拟 FindByUsername 找到用户,确认是用户名重复 // 模拟 FindByUsername 找到用户,确认是用户名重复
m.On("FindByUsername", "existinguser").Return(&models.User{Username: "existinguser"}, nil).Once() m.On("FindByUsername", "existinguser").Return(&models.User{Username: "existinguser"}, nil).Once()
}, },
@@ -161,8 +161,8 @@ func TestCreateUser(t *testing.T) {
Password: "password123", Password: "password123",
}, },
mockRepoSetup: func(m *MockUserRepository) { mockRepoSetup: func(m *MockUserRepository) {
// 模拟 Create 失败,通用数据库错误 // 模拟 CreateTx 失败,通用数据库错误
m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("database error")).Once() m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(errors.New("database error")).Once()
// 模拟 FindByUsername 找不到用户,确认不是用户名重复 // 模拟 FindByUsername 找不到用户,确认不是用户名重复
m.On("FindByUsername", "db_error_user").Return(nil, gorm.ErrRecordNotFound).Once() m.On("FindByUsername", "db_error_user").Return(nil, gorm.ErrRecordNotFound).Once()
}, },

View File

@@ -24,6 +24,8 @@ var (
ErrPenNotFound = errors.New("指定的猪栏不存在") ErrPenNotFound = errors.New("指定的猪栏不存在")
// ErrPenNotAssociatedWithBatch 表示猪栏未与该批次关联 // ErrPenNotAssociatedWithBatch 表示猪栏未与该批次关联
ErrPenNotAssociatedWithBatch = errors.New("猪栏未与该批次关联") ErrPenNotAssociatedWithBatch = errors.New("猪栏未与该批次关联")
// ErrInvalidOperation 非法操作
ErrInvalidOperation = errors.New("非法操作")
) )
// --- 领域服务接口 --- // --- 领域服务接口 ---
@@ -51,4 +53,6 @@ type PigBatchService interface {
SellPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error SellPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error
// BuyPigs 处理买猪的业务逻辑。 // BuyPigs 处理买猪的业务逻辑。
BuyPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error BuyPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error
UpdatePigBatchQuantity(operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error
} }

View File

@@ -67,7 +67,7 @@ func (s *pigBatchService) CreatePigBatch(operatorID uint, batch *models.PigBatch
} }
// 3. 记录批次日志 // 3. 记录批次日志
if err := s.pigBatchLogRepo.Create(tx, initialLog); err != nil { if err := s.pigBatchLogRepo.CreateTx(tx, initialLog); err != nil {
return fmt.Errorf("记录初始批次日志失败: %w", err) return fmt.Errorf("记录初始批次日志失败: %w", err)
} }
@@ -274,3 +274,33 @@ func (s *pigBatchService) getCurrentPigQuantityTx(tx *gorm.DB, batchID uint) (in
// 3. 如果找到最后一条日志,则当前数量为该日志的 AfterCount // 3. 如果找到最后一条日志,则当前数量为该日志的 AfterCount
return lastLog.AfterCount, nil return lastLog.AfterCount, nil
} }
func (s *pigBatchService) UpdatePigBatchQuantity(operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error {
return s.uow.ExecuteInTransaction(func(tx *gorm.DB) error {
return s.updatePigBatchQuantityTx(tx, operatorID, batchID, changeType, changeAmount, changeReason, happenedAt)
})
}
func (s *pigBatchService) updatePigBatchQuantityTx(tx *gorm.DB, operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error {
lastLog, err := s.pigBatchLogRepo.GetLastLogByBatchIDTx(tx, batchID)
if err != nil {
return err
}
// 检查数量不应该减到小于零
if changeAmount < 0 {
if lastLog.AfterCount+changeAmount < 0 {
return ErrInvalidOperation
}
}
pigBatchLog := &models.PigBatchLog{
PigBatchID: batchID,
ChangeType: changeType,
ChangeCount: changeAmount,
Reason: changeReason,
BeforeCount: lastLog.AfterCount,
AfterCount: lastLog.AfterCount + changeAmount,
OperatorID: operatorID,
HappenedAt: happenedAt,
}
return s.pigBatchLogRepo.CreateTx(tx, pigBatchLog)
}

View File

@@ -51,18 +51,10 @@ func (s *pigBatchService) SellPigs(batchID uint, quantity int, unitPrice float64
} }
// 4. 记录批次日志 // 4. 记录批次日志
log := &models.PigBatchLog{ if err := s.updatePigBatchQuantityTx(tx, operatorID, batchID, models.ChangeTypeSale, -quantity,
PigBatchID: batchID, fmt.Sprintf("猪批次 %d 销售 %d 头猪给 %s", batchID, quantity, traderName),
HappenedAt: time.Now(), tradeDate); err != nil {
ChangeType: models.ChangeTypeSale, return fmt.Errorf("更新猪批次数量失败: %w", err)
ChangeCount: -quantity,
Reason: fmt.Sprintf("猪批次 %d 销售 %d 头猪给 %s", batchID, quantity, traderName),
BeforeCount: currentQuantity,
AfterCount: currentQuantity - quantity,
OperatorID: operatorID,
}
if err := s.pigBatchLogRepo.Create(tx, log); err != nil {
return fmt.Errorf("记录销售批次日志失败: %w", err)
} }
return nil return nil
@@ -85,12 +77,6 @@ func (s *pigBatchService) BuyPigs(batchID uint, quantity int, unitPrice float64,
return fmt.Errorf("获取猪批次 %d 信息失败: %w", batchID, err) return fmt.Errorf("获取猪批次 %d 信息失败: %w", batchID, err)
} }
// 2. 获取当前猪批次数量
currentQuantity, err := s.getCurrentPigQuantityTx(tx, batchID)
if err != nil {
return fmt.Errorf("获取猪批次 %d 当前数量失败: %w", batchID, err)
}
// 3. 记录采购交易 // 3. 记录采购交易
purchase := &models.PigPurchase{ purchase := &models.PigPurchase{
PigBatchID: batchID, PigBatchID: batchID,
@@ -107,18 +93,10 @@ func (s *pigBatchService) BuyPigs(batchID uint, quantity int, unitPrice float64,
} }
// 4. 记录批次日志 // 4. 记录批次日志
log := &models.PigBatchLog{ if err := s.updatePigBatchQuantityTx(tx, operatorID, batchID, models.ChangeTypeBuy, quantity,
PigBatchID: batchID, fmt.Sprintf("猪批次 %d 采购 %d 头猪从 %s", batchID, quantity, traderName),
HappenedAt: time.Now(), tradeDate); err != nil {
ChangeType: models.ChangeTypeBuy, return fmt.Errorf("更新猪批次数量失败: %w", err)
ChangeCount: quantity,
Reason: fmt.Sprintf("猪批次 %d 采购 %d 头猪从 %s", batchID, quantity, traderName),
BeforeCount: currentQuantity,
AfterCount: currentQuantity + quantity,
OperatorID: operatorID,
}
if err := s.pigBatchLogRepo.Create(tx, log); err != nil {
return fmt.Errorf("记录采购批次日志失败: %w", err)
} }
return nil return nil

View File

@@ -96,7 +96,7 @@ func (r *gormExecutionLogRepository) CreateTaskExecutionLogsInBatch(logs []*mode
if len(logs) == 0 { if len(logs) == 0 {
return nil return nil
} }
// GORM 的 Create 传入一个切片指针会执行批量插入。 // GORM 的 CreateTx 传入一个切片指针会执行批量插入。
return r.db.Create(&logs).Error return r.db.Create(&logs).Error
} }

View File

@@ -9,8 +9,8 @@ import (
// PigBatchLogRepository 定义了与猪批次日志相关的数据库操作接口。 // PigBatchLogRepository 定义了与猪批次日志相关的数据库操作接口。
type PigBatchLogRepository interface { type PigBatchLogRepository interface {
// Create 在指定的事务中创建一条新的猪批次日志。 // CreateTx 在指定的事务中创建一条新的猪批次日志。
Create(tx *gorm.DB, log *models.PigBatchLog) error CreateTx(tx *gorm.DB, log *models.PigBatchLog) error
// GetLogsByBatchIDAndDateRangeTx 在指定的事务中,获取指定批次在特定时间范围内的所有日志记录。 // GetLogsByBatchIDAndDateRangeTx 在指定的事务中,获取指定批次在特定时间范围内的所有日志记录。
GetLogsByBatchIDAndDateRangeTx(tx *gorm.DB, batchID uint, startDate, endDate time.Time) ([]*models.PigBatchLog, error) GetLogsByBatchIDAndDateRangeTx(tx *gorm.DB, batchID uint, startDate, endDate time.Time) ([]*models.PigBatchLog, error)
@@ -30,7 +30,7 @@ func NewGormPigBatchLogRepository(db *gorm.DB) PigBatchLogRepository {
} }
// Create 实现了创建猪批次日志的逻辑。 // Create 实现了创建猪批次日志的逻辑。
func (r *gormPigBatchLogRepository) Create(tx *gorm.DB, log *models.PigBatchLog) error { func (r *gormPigBatchLogRepository) CreateTx(tx *gorm.DB, log *models.PigBatchLog) error {
return tx.Create(log).Error return tx.Create(log).Error
} }

View File

@@ -870,7 +870,7 @@ func TestPlanRepository_Create(t *testing.T) {
type testCase struct { type testCase struct {
name string name string
setupDB func(db *gorm.DB) // 准备数据库的初始状态 setupDB func(db *gorm.DB) // 准备数据库的初始状态
inputPlan *models.Plan // 传入 Create 方法的计划对象 inputPlan *models.Plan // 传入 CreateTx 方法的计划对象
expectedError error // 期望的错误类型 expectedError error // 期望的错误类型
verifyDB func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) // 验证数据库状态 verifyDB func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) // 验证数据库状态
} }
@@ -1040,7 +1040,7 @@ func TestPlanRepository_Create(t *testing.T) {
{Name: "Task 2", ExecutionOrder: 1}, // 重复的顺序 {Name: "Task 2", ExecutionOrder: 1}, // 重复的顺序
}, },
}, },
expectedError: fmt.Errorf("任务执行顺序重复: %d", 1), // 假设 Create 方法会返回此错误 expectedError: fmt.Errorf("任务执行顺序重复: %d", 1), // 假设 CreateTx 方法会返回此错误
verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) { verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) {
var count int64 var count int64
db.Model(&models.Plan{}).Where("name = ?", "重复任务顺序计划").Count(&count) db.Model(&models.Plan{}).Where("name = ?", "重复任务顺序计划").Count(&count)
@@ -1061,7 +1061,7 @@ func TestPlanRepository_Create(t *testing.T) {
{ChildPlanID: 2, ExecutionOrder: 1}, // 重复的顺序 {ChildPlanID: 2, ExecutionOrder: 1}, // 重复的顺序
}, },
}, },
expectedError: fmt.Errorf("子计划执行顺序重复: %d", 1), // 假设 Create 方法会返回此错误 expectedError: fmt.Errorf("子计划执行顺序重复: %d", 1), // 假设 CreateTx 方法会返回此错误
verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) { verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) {
var count int64 var count int64
db.Model(&models.Plan{}).Where("name = ?", "重复子计划顺序计划").Count(&count) db.Model(&models.Plan{}).Where("name = ?", "重复子计划顺序计划").Count(&count)
@@ -1078,7 +1078,7 @@ func TestPlanRepository_Create(t *testing.T) {
// 准备数据库状态 // 准备数据库状态
tc.setupDB(db) tc.setupDB(db)
// 执行 Create 操作 // 执行 CreateTx 操作
err := repo.CreatePlan(tc.inputPlan) err := repo.CreatePlan(tc.inputPlan)
// 断言错误 // 断言错误