From 91e18c432cf3511a2f9638b361f03d9a1a57b988 Mon Sep 17 00:00:00 2001 From: huang <1724659546@qq.com> Date: Mon, 6 Oct 2025 15:08:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E5=8F=96=E4=BF=AE=E6=94=B9=E7=8C=AA?= =?UTF-8?q?=E7=BE=A4=E6=95=B0=E9=87=8F=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../device/device_controller_test.go | 12 +++--- .../controller/user/user_controller_test.go | 18 ++++----- internal/domain/pig/pig_batch.go | 4 ++ internal/domain/pig/pig_batch_service.go | 32 +++++++++++++++- .../domain/pig/pig_batch_service_pig_trade.go | 38 ++++--------------- .../repository/execution_log_repository.go | 2 +- .../repository/pig_batch_log_repository.go | 6 +-- .../infra/repository/plan_repository_test.go | 8 ++-- 8 files changed, 66 insertions(+), 54 deletions(-) diff --git a/internal/app/controller/device/device_controller_test.go b/internal/app/controller/device/device_controller_test.go index e943533..7bb1239 100644 --- a/internal/app/controller/device/device_controller_test.go +++ b/internal/app/controller/device/device_controller_test.go @@ -26,7 +26,7 @@ type MockDeviceRepository struct { mock.Mock } -// Create 模拟 DeviceRepository 的 Create 方法 +// CreateTx 模拟 DeviceRepository 的 CreateTx 方法 func (m *MockDeviceRepository) Create(device *models.Device) error { args := m.Called(device) return args.Error(0) @@ -169,7 +169,7 @@ func TestCreateDevice(t *testing.T) { Properties: controller.Properties(`{"lora_address":"0x1234"}`), }, mockRepoSetup: func(m *MockDeviceRepository) { - m.On("Create", mock.MatchedBy(func(dev *models.Device) bool { + m.On("CreateTx", mock.MatchedBy(func(dev *models.Device) bool { // 检查 Name 字段 nameMatch := dev.Name == "主控A" // 检查 Type 字段 @@ -215,7 +215,7 @@ func TestCreateDevice(t *testing.T) { Properties: controller.Properties(`{"bus_id":1,"bus_address":10}`), }, mockRepoSetup: func(m *MockDeviceRepository) { - m.On("Create", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + m.On("CreateTx", mock.Anything).Return(nil).Run(func(args mock.Arguments) { arg := args.Get(0).(*models.Device) arg.ID = 2 arg.CreatedAt = time.Now() @@ -259,7 +259,7 @@ func TestCreateDevice(t *testing.T) { Type: models.DeviceTypeDevice, }, mockRepoSetup: func(m *MockDeviceRepository) { - m.On("Create", mock.Anything).Return(errors.New("db error")).Once() + m.On("CreateTx", mock.Anything).Return(errors.New("db error")).Once() }, expectedStatus: http.StatusOK, expectedCode: controller.CodeInternalError, @@ -276,9 +276,9 @@ func TestCreateDevice(t *testing.T) { Properties: controller.Properties(`{invalid json}`), }, mockRepoSetup: func(m *MockDeviceRepository) { - // 期望 Create 方法被调用,并返回一个模拟的数据库错误 + // 期望 CreateTx 方法被调用,并返回一个模拟的数据库错误 // 这个错误模拟的是数据库层因为 Properties 字段的 JSON 格式无效而拒绝保存 - m.On("Create", mock.Anything).Return(errors.New("database error: invalid json format")).Run(func(args mock.Arguments) { + m.On("CreateTx", mock.Anything).Return(errors.New("database error: invalid json format")).Run(func(args mock.Arguments) { dev := args.Get(0).(*models.Device) assert.Equal(t, "无效JSON设备", dev.Name) assert.Equal(t, models.DeviceTypeDevice, dev.Type) diff --git a/internal/app/controller/user/user_controller_test.go b/internal/app/controller/user/user_controller_test.go index 2cd004c..c447f11 100644 --- a/internal/app/controller/user/user_controller_test.go +++ b/internal/app/controller/user/user_controller_test.go @@ -25,7 +25,7 @@ type MockUserRepository struct { mock.Mock } -// Create 模拟 UserRepository 的 Create 方法 +// CreateTx 模拟 UserRepository 的 CreateTx 方法 func (m *MockUserRepository) Create(user *models.User) error { args := m.Called(user) return args.Error(0) @@ -90,8 +90,8 @@ func TestCreateUser(t *testing.T) { Password: "password123", }, mockRepoSetup: func(m *MockUserRepository) { - // 模拟 Create 成功 - m.On("Create", mock.AnythingOfType("*models.User")).Return(nil).Run(func(args mock.Arguments) { + // 模拟 CreateTx 成功 + m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(nil).Run(func(args mock.Arguments) { // 模拟数据库自动填充 ID userArg := args.Get(0).(*models.User) userArg.ID = 1 // 设置一个非零的 ID @@ -114,7 +114,7 @@ func TestCreateUser(t *testing.T) { Password: "123", // 密码少于6位 }, mockRepoSetup: func(m *MockUserRepository) { - // 不会调用 Create 或 FindByUsername + // 不会调用 CreateTx 或 FindByUsername }, expectedResponse: map[string]interface{}{ "code": float64(controller.CodeBadRequest), @@ -128,7 +128,7 @@ func TestCreateUser(t *testing.T) { Password: "password123", }, mockRepoSetup: func(m *MockUserRepository) { - // 不会调用 Create 或 FindByUsername + // 不会调用 CreateTx 或 FindByUsername }, expectedResponse: map[string]interface{}{ "code": float64(controller.CodeBadRequest), @@ -143,8 +143,8 @@ func TestCreateUser(t *testing.T) { Password: "password123", }, mockRepoSetup: func(m *MockUserRepository) { - // 模拟 Create 失败,因为用户名已存在 - m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("duplicate entry")).Once() + // 模拟 CreateTx 失败,因为用户名已存在 + m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(errors.New("duplicate entry")).Once() // 模拟 FindByUsername 找到用户,确认是用户名重复 m.On("FindByUsername", "existinguser").Return(&models.User{Username: "existinguser"}, nil).Once() }, @@ -161,8 +161,8 @@ func TestCreateUser(t *testing.T) { Password: "password123", }, mockRepoSetup: func(m *MockUserRepository) { - // 模拟 Create 失败,通用数据库错误 - m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("database error")).Once() + // 模拟 CreateTx 失败,通用数据库错误 + m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(errors.New("database error")).Once() // 模拟 FindByUsername 找不到用户,确认不是用户名重复 m.On("FindByUsername", "db_error_user").Return(nil, gorm.ErrRecordNotFound).Once() }, diff --git a/internal/domain/pig/pig_batch.go b/internal/domain/pig/pig_batch.go index 63a2d93..2d56781 100644 --- a/internal/domain/pig/pig_batch.go +++ b/internal/domain/pig/pig_batch.go @@ -24,6 +24,8 @@ var ( ErrPenNotFound = errors.New("指定的猪栏不存在") // ErrPenNotAssociatedWithBatch 表示猪栏未与该批次关联 ErrPenNotAssociatedWithBatch = errors.New("猪栏未与该批次关联") + // ErrInvalidOperation 非法操作 + ErrInvalidOperation = errors.New("非法操作") ) // --- 领域服务接口 --- @@ -51,4 +53,6 @@ type PigBatchService interface { SellPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error // BuyPigs 处理买猪的业务逻辑。 BuyPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error + + UpdatePigBatchQuantity(operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error } diff --git a/internal/domain/pig/pig_batch_service.go b/internal/domain/pig/pig_batch_service.go index 110c0cb..65b0e07 100644 --- a/internal/domain/pig/pig_batch_service.go +++ b/internal/domain/pig/pig_batch_service.go @@ -67,7 +67,7 @@ func (s *pigBatchService) CreatePigBatch(operatorID uint, batch *models.PigBatch } // 3. 记录批次日志 - if err := s.pigBatchLogRepo.Create(tx, initialLog); err != nil { + if err := s.pigBatchLogRepo.CreateTx(tx, initialLog); err != nil { return fmt.Errorf("记录初始批次日志失败: %w", err) } @@ -274,3 +274,33 @@ func (s *pigBatchService) getCurrentPigQuantityTx(tx *gorm.DB, batchID uint) (in // 3. 如果找到最后一条日志,则当前数量为该日志的 AfterCount return lastLog.AfterCount, nil } + +func (s *pigBatchService) UpdatePigBatchQuantity(operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error { + return s.uow.ExecuteInTransaction(func(tx *gorm.DB) error { + return s.updatePigBatchQuantityTx(tx, operatorID, batchID, changeType, changeAmount, changeReason, happenedAt) + }) +} + +func (s *pigBatchService) updatePigBatchQuantityTx(tx *gorm.DB, operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error { + lastLog, err := s.pigBatchLogRepo.GetLastLogByBatchIDTx(tx, batchID) + if err != nil { + return err + } + // 检查数量不应该减到小于零 + if changeAmount < 0 { + if lastLog.AfterCount+changeAmount < 0 { + return ErrInvalidOperation + } + } + pigBatchLog := &models.PigBatchLog{ + PigBatchID: batchID, + ChangeType: changeType, + ChangeCount: changeAmount, + Reason: changeReason, + BeforeCount: lastLog.AfterCount, + AfterCount: lastLog.AfterCount + changeAmount, + OperatorID: operatorID, + HappenedAt: happenedAt, + } + return s.pigBatchLogRepo.CreateTx(tx, pigBatchLog) +} diff --git a/internal/domain/pig/pig_batch_service_pig_trade.go b/internal/domain/pig/pig_batch_service_pig_trade.go index 48603e1..de327a0 100644 --- a/internal/domain/pig/pig_batch_service_pig_trade.go +++ b/internal/domain/pig/pig_batch_service_pig_trade.go @@ -51,18 +51,10 @@ func (s *pigBatchService) SellPigs(batchID uint, quantity int, unitPrice float64 } // 4. 记录批次日志 - log := &models.PigBatchLog{ - PigBatchID: batchID, - HappenedAt: time.Now(), - ChangeType: models.ChangeTypeSale, - ChangeCount: -quantity, - Reason: fmt.Sprintf("猪批次 %d 销售 %d 头猪给 %s", batchID, quantity, traderName), - BeforeCount: currentQuantity, - AfterCount: currentQuantity - quantity, - OperatorID: operatorID, - } - if err := s.pigBatchLogRepo.Create(tx, log); err != nil { - return fmt.Errorf("记录销售批次日志失败: %w", err) + if err := s.updatePigBatchQuantityTx(tx, operatorID, batchID, models.ChangeTypeSale, -quantity, + fmt.Sprintf("猪批次 %d 销售 %d 头猪给 %s", batchID, quantity, traderName), + tradeDate); err != nil { + return fmt.Errorf("更新猪批次数量失败: %w", err) } return nil @@ -85,12 +77,6 @@ func (s *pigBatchService) BuyPigs(batchID uint, quantity int, unitPrice float64, return fmt.Errorf("获取猪批次 %d 信息失败: %w", batchID, err) } - // 2. 获取当前猪批次数量 - currentQuantity, err := s.getCurrentPigQuantityTx(tx, batchID) - if err != nil { - return fmt.Errorf("获取猪批次 %d 当前数量失败: %w", batchID, err) - } - // 3. 记录采购交易 purchase := &models.PigPurchase{ PigBatchID: batchID, @@ -107,18 +93,10 @@ func (s *pigBatchService) BuyPigs(batchID uint, quantity int, unitPrice float64, } // 4. 记录批次日志 - log := &models.PigBatchLog{ - PigBatchID: batchID, - HappenedAt: time.Now(), - ChangeType: models.ChangeTypeBuy, - ChangeCount: quantity, - Reason: fmt.Sprintf("猪批次 %d 采购 %d 头猪从 %s", batchID, quantity, traderName), - BeforeCount: currentQuantity, - AfterCount: currentQuantity + quantity, - OperatorID: operatorID, - } - if err := s.pigBatchLogRepo.Create(tx, log); err != nil { - return fmt.Errorf("记录采购批次日志失败: %w", err) + if err := s.updatePigBatchQuantityTx(tx, operatorID, batchID, models.ChangeTypeBuy, quantity, + fmt.Sprintf("猪批次 %d 采购 %d 头猪从 %s", batchID, quantity, traderName), + tradeDate); err != nil { + return fmt.Errorf("更新猪批次数量失败: %w", err) } return nil diff --git a/internal/infra/repository/execution_log_repository.go b/internal/infra/repository/execution_log_repository.go index d804247..33323e6 100644 --- a/internal/infra/repository/execution_log_repository.go +++ b/internal/infra/repository/execution_log_repository.go @@ -96,7 +96,7 @@ func (r *gormExecutionLogRepository) CreateTaskExecutionLogsInBatch(logs []*mode if len(logs) == 0 { return nil } - // GORM 的 Create 传入一个切片指针会执行批量插入。 + // GORM 的 CreateTx 传入一个切片指针会执行批量插入。 return r.db.Create(&logs).Error } diff --git a/internal/infra/repository/pig_batch_log_repository.go b/internal/infra/repository/pig_batch_log_repository.go index 283731f..34b4b59 100644 --- a/internal/infra/repository/pig_batch_log_repository.go +++ b/internal/infra/repository/pig_batch_log_repository.go @@ -9,8 +9,8 @@ import ( // PigBatchLogRepository 定义了与猪批次日志相关的数据库操作接口。 type PigBatchLogRepository interface { - // Create 在指定的事务中创建一条新的猪批次日志。 - Create(tx *gorm.DB, log *models.PigBatchLog) error + // CreateTx 在指定的事务中创建一条新的猪批次日志。 + CreateTx(tx *gorm.DB, log *models.PigBatchLog) error // GetLogsByBatchIDAndDateRangeTx 在指定的事务中,获取指定批次在特定时间范围内的所有日志记录。 GetLogsByBatchIDAndDateRangeTx(tx *gorm.DB, batchID uint, startDate, endDate time.Time) ([]*models.PigBatchLog, error) @@ -30,7 +30,7 @@ func NewGormPigBatchLogRepository(db *gorm.DB) PigBatchLogRepository { } // Create 实现了创建猪批次日志的逻辑。 -func (r *gormPigBatchLogRepository) Create(tx *gorm.DB, log *models.PigBatchLog) error { +func (r *gormPigBatchLogRepository) CreateTx(tx *gorm.DB, log *models.PigBatchLog) error { return tx.Create(log).Error } diff --git a/internal/infra/repository/plan_repository_test.go b/internal/infra/repository/plan_repository_test.go index a7b5389..ab511be 100644 --- a/internal/infra/repository/plan_repository_test.go +++ b/internal/infra/repository/plan_repository_test.go @@ -870,7 +870,7 @@ func TestPlanRepository_Create(t *testing.T) { type testCase struct { name string setupDB func(db *gorm.DB) // 准备数据库的初始状态 - inputPlan *models.Plan // 传入 Create 方法的计划对象 + inputPlan *models.Plan // 传入 CreateTx 方法的计划对象 expectedError error // 期望的错误类型 verifyDB func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) // 验证数据库状态 } @@ -1040,7 +1040,7 @@ func TestPlanRepository_Create(t *testing.T) { {Name: "Task 2", ExecutionOrder: 1}, // 重复的顺序 }, }, - expectedError: fmt.Errorf("任务执行顺序重复: %d", 1), // 假设 Create 方法会返回此错误 + expectedError: fmt.Errorf("任务执行顺序重复: %d", 1), // 假设 CreateTx 方法会返回此错误 verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) { var count int64 db.Model(&models.Plan{}).Where("name = ?", "重复任务顺序计划").Count(&count) @@ -1061,7 +1061,7 @@ func TestPlanRepository_Create(t *testing.T) { {ChildPlanID: 2, ExecutionOrder: 1}, // 重复的顺序 }, }, - expectedError: fmt.Errorf("子计划执行顺序重复: %d", 1), // 假设 Create 方法会返回此错误 + expectedError: fmt.Errorf("子计划执行顺序重复: %d", 1), // 假设 CreateTx 方法会返回此错误 verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) { var count int64 db.Model(&models.Plan{}).Where("name = ?", "重复子计划顺序计划").Count(&count) @@ -1078,7 +1078,7 @@ func TestPlanRepository_Create(t *testing.T) { // 准备数据库状态 tc.setupDB(db) - // 执行 Create 操作 + // 执行 CreateTx 操作 err := repo.CreatePlan(tc.inputPlan) // 断言错误