issue_29 #32

Merged
huang merged 65 commits from issue_29 into main 2025-10-07 13:33:25 +08:00
8 changed files with 66 additions and 54 deletions
Showing only changes of commit 91e18c432c - Show all commits

View File

@@ -26,7 +26,7 @@ type MockDeviceRepository struct {
mock.Mock
}
// Create 模拟 DeviceRepository 的 Create 方法
// CreateTx 模拟 DeviceRepository 的 CreateTx 方法
func (m *MockDeviceRepository) Create(device *models.Device) error {
args := m.Called(device)
return args.Error(0)
@@ -169,7 +169,7 @@ func TestCreateDevice(t *testing.T) {
Properties: controller.Properties(`{"lora_address":"0x1234"}`),
},
mockRepoSetup: func(m *MockDeviceRepository) {
m.On("Create", mock.MatchedBy(func(dev *models.Device) bool {
m.On("CreateTx", mock.MatchedBy(func(dev *models.Device) bool {
// 检查 Name 字段
nameMatch := dev.Name == "主控A"
// 检查 Type 字段
@@ -215,7 +215,7 @@ func TestCreateDevice(t *testing.T) {
Properties: controller.Properties(`{"bus_id":1,"bus_address":10}`),
},
mockRepoSetup: func(m *MockDeviceRepository) {
m.On("Create", mock.Anything).Return(nil).Run(func(args mock.Arguments) {
m.On("CreateTx", mock.Anything).Return(nil).Run(func(args mock.Arguments) {
arg := args.Get(0).(*models.Device)
arg.ID = 2
arg.CreatedAt = time.Now()
@@ -259,7 +259,7 @@ func TestCreateDevice(t *testing.T) {
Type: models.DeviceTypeDevice,
},
mockRepoSetup: func(m *MockDeviceRepository) {
m.On("Create", mock.Anything).Return(errors.New("db error")).Once()
m.On("CreateTx", mock.Anything).Return(errors.New("db error")).Once()
},
expectedStatus: http.StatusOK,
expectedCode: controller.CodeInternalError,
@@ -276,9 +276,9 @@ func TestCreateDevice(t *testing.T) {
Properties: controller.Properties(`{invalid json}`),
},
mockRepoSetup: func(m *MockDeviceRepository) {
// 期望 Create 方法被调用,并返回一个模拟的数据库错误
// 期望 CreateTx 方法被调用,并返回一个模拟的数据库错误
// 这个错误模拟的是数据库层因为 Properties 字段的 JSON 格式无效而拒绝保存
m.On("Create", mock.Anything).Return(errors.New("database error: invalid json format")).Run(func(args mock.Arguments) {
m.On("CreateTx", mock.Anything).Return(errors.New("database error: invalid json format")).Run(func(args mock.Arguments) {
dev := args.Get(0).(*models.Device)
assert.Equal(t, "无效JSON设备", dev.Name)
assert.Equal(t, models.DeviceTypeDevice, dev.Type)

View File

@@ -25,7 +25,7 @@ type MockUserRepository struct {
mock.Mock
}
// Create 模拟 UserRepository 的 Create 方法
// CreateTx 模拟 UserRepository 的 CreateTx 方法
func (m *MockUserRepository) Create(user *models.User) error {
args := m.Called(user)
return args.Error(0)
@@ -90,8 +90,8 @@ func TestCreateUser(t *testing.T) {
Password: "password123",
},
mockRepoSetup: func(m *MockUserRepository) {
// 模拟 Create 成功
m.On("Create", mock.AnythingOfType("*models.User")).Return(nil).Run(func(args mock.Arguments) {
// 模拟 CreateTx 成功
m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(nil).Run(func(args mock.Arguments) {
// 模拟数据库自动填充 ID
userArg := args.Get(0).(*models.User)
userArg.ID = 1 // 设置一个非零的 ID
@@ -114,7 +114,7 @@ func TestCreateUser(t *testing.T) {
Password: "123", // 密码少于6位
},
mockRepoSetup: func(m *MockUserRepository) {
// 不会调用 Create 或 FindByUsername
// 不会调用 CreateTx 或 FindByUsername
},
expectedResponse: map[string]interface{}{
"code": float64(controller.CodeBadRequest),
@@ -128,7 +128,7 @@ func TestCreateUser(t *testing.T) {
Password: "password123",
},
mockRepoSetup: func(m *MockUserRepository) {
// 不会调用 Create 或 FindByUsername
// 不会调用 CreateTx 或 FindByUsername
},
expectedResponse: map[string]interface{}{
"code": float64(controller.CodeBadRequest),
@@ -143,8 +143,8 @@ func TestCreateUser(t *testing.T) {
Password: "password123",
},
mockRepoSetup: func(m *MockUserRepository) {
// 模拟 Create 失败,因为用户名已存在
m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("duplicate entry")).Once()
// 模拟 CreateTx 失败,因为用户名已存在
m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(errors.New("duplicate entry")).Once()
// 模拟 FindByUsername 找到用户,确认是用户名重复
m.On("FindByUsername", "existinguser").Return(&models.User{Username: "existinguser"}, nil).Once()
},
@@ -161,8 +161,8 @@ func TestCreateUser(t *testing.T) {
Password: "password123",
},
mockRepoSetup: func(m *MockUserRepository) {
// 模拟 Create 失败,通用数据库错误
m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("database error")).Once()
// 模拟 CreateTx 失败,通用数据库错误
m.On("CreateTx", mock.AnythingOfType("*models.User")).Return(errors.New("database error")).Once()
// 模拟 FindByUsername 找不到用户,确认不是用户名重复
m.On("FindByUsername", "db_error_user").Return(nil, gorm.ErrRecordNotFound).Once()
},

View File

@@ -24,6 +24,8 @@ var (
ErrPenNotFound = errors.New("指定的猪栏不存在")
// ErrPenNotAssociatedWithBatch 表示猪栏未与该批次关联
ErrPenNotAssociatedWithBatch = errors.New("猪栏未与该批次关联")
// ErrInvalidOperation 非法操作
ErrInvalidOperation = errors.New("非法操作")
)
// --- 领域服务接口 ---
@@ -51,4 +53,6 @@ type PigBatchService interface {
SellPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error
// BuyPigs 处理买猪的业务逻辑。
BuyPigs(batchID uint, quantity int, unitPrice float64, tatalPrice float64, traderName string, tradeDate time.Time, remarks string, operatorID uint) error
UpdatePigBatchQuantity(operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error
}

View File

@@ -67,7 +67,7 @@ func (s *pigBatchService) CreatePigBatch(operatorID uint, batch *models.PigBatch
}
// 3. 记录批次日志
if err := s.pigBatchLogRepo.Create(tx, initialLog); err != nil {
if err := s.pigBatchLogRepo.CreateTx(tx, initialLog); err != nil {
return fmt.Errorf("记录初始批次日志失败: %w", err)
}
@@ -274,3 +274,33 @@ func (s *pigBatchService) getCurrentPigQuantityTx(tx *gorm.DB, batchID uint) (in
// 3. 如果找到最后一条日志,则当前数量为该日志的 AfterCount
return lastLog.AfterCount, nil
}
func (s *pigBatchService) UpdatePigBatchQuantity(operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error {
return s.uow.ExecuteInTransaction(func(tx *gorm.DB) error {
return s.updatePigBatchQuantityTx(tx, operatorID, batchID, changeType, changeAmount, changeReason, happenedAt)
})
}
func (s *pigBatchService) updatePigBatchQuantityTx(tx *gorm.DB, operatorID uint, batchID uint, changeType models.LogChangeType, changeAmount int, changeReason string, happenedAt time.Time) error {
lastLog, err := s.pigBatchLogRepo.GetLastLogByBatchIDTx(tx, batchID)
if err != nil {
return err
}
// 检查数量不应该减到小于零
if changeAmount < 0 {
if lastLog.AfterCount+changeAmount < 0 {
return ErrInvalidOperation
}
}
pigBatchLog := &models.PigBatchLog{
PigBatchID: batchID,
ChangeType: changeType,
ChangeCount: changeAmount,
Reason: changeReason,
BeforeCount: lastLog.AfterCount,
AfterCount: lastLog.AfterCount + changeAmount,
OperatorID: operatorID,
HappenedAt: happenedAt,
}
return s.pigBatchLogRepo.CreateTx(tx, pigBatchLog)
}

View File

@@ -51,18 +51,10 @@ func (s *pigBatchService) SellPigs(batchID uint, quantity int, unitPrice float64
}
// 4. 记录批次日志
log := &models.PigBatchLog{
PigBatchID: batchID,
HappenedAt: time.Now(),
ChangeType: models.ChangeTypeSale,
ChangeCount: -quantity,
Reason: fmt.Sprintf("猪批次 %d 销售 %d 头猪给 %s", batchID, quantity, traderName),
BeforeCount: currentQuantity,
AfterCount: currentQuantity - quantity,
OperatorID: operatorID,
}
if err := s.pigBatchLogRepo.Create(tx, log); err != nil {
return fmt.Errorf("记录销售批次日志失败: %w", err)
if err := s.updatePigBatchQuantityTx(tx, operatorID, batchID, models.ChangeTypeSale, -quantity,
fmt.Sprintf("猪批次 %d 销售 %d 头猪给 %s", batchID, quantity, traderName),
tradeDate); err != nil {
return fmt.Errorf("更新猪批次数量失败: %w", err)
}
return nil
@@ -85,12 +77,6 @@ func (s *pigBatchService) BuyPigs(batchID uint, quantity int, unitPrice float64,
return fmt.Errorf("获取猪批次 %d 信息失败: %w", batchID, err)
}
// 2. 获取当前猪批次数量
currentQuantity, err := s.getCurrentPigQuantityTx(tx, batchID)
if err != nil {
return fmt.Errorf("获取猪批次 %d 当前数量失败: %w", batchID, err)
}
// 3. 记录采购交易
purchase := &models.PigPurchase{
PigBatchID: batchID,
@@ -107,18 +93,10 @@ func (s *pigBatchService) BuyPigs(batchID uint, quantity int, unitPrice float64,
}
// 4. 记录批次日志
log := &models.PigBatchLog{
PigBatchID: batchID,
HappenedAt: time.Now(),
ChangeType: models.ChangeTypeBuy,
ChangeCount: quantity,
Reason: fmt.Sprintf("猪批次 %d 采购 %d 头猪从 %s", batchID, quantity, traderName),
BeforeCount: currentQuantity,
AfterCount: currentQuantity + quantity,
OperatorID: operatorID,
}
if err := s.pigBatchLogRepo.Create(tx, log); err != nil {
return fmt.Errorf("记录采购批次日志失败: %w", err)
if err := s.updatePigBatchQuantityTx(tx, operatorID, batchID, models.ChangeTypeBuy, quantity,
fmt.Sprintf("猪批次 %d 采购 %d 头猪从 %s", batchID, quantity, traderName),
tradeDate); err != nil {
return fmt.Errorf("更新猪批次数量失败: %w", err)
}
return nil

View File

@@ -96,7 +96,7 @@ func (r *gormExecutionLogRepository) CreateTaskExecutionLogsInBatch(logs []*mode
if len(logs) == 0 {
return nil
}
// GORM 的 Create 传入一个切片指针会执行批量插入。
// GORM 的 CreateTx 传入一个切片指针会执行批量插入。
return r.db.Create(&logs).Error
}

View File

@@ -9,8 +9,8 @@ import (
// PigBatchLogRepository 定义了与猪批次日志相关的数据库操作接口。
type PigBatchLogRepository interface {
// Create 在指定的事务中创建一条新的猪批次日志。
Create(tx *gorm.DB, log *models.PigBatchLog) error
// CreateTx 在指定的事务中创建一条新的猪批次日志。
CreateTx(tx *gorm.DB, log *models.PigBatchLog) error
// GetLogsByBatchIDAndDateRangeTx 在指定的事务中,获取指定批次在特定时间范围内的所有日志记录。
GetLogsByBatchIDAndDateRangeTx(tx *gorm.DB, batchID uint, startDate, endDate time.Time) ([]*models.PigBatchLog, error)
@@ -30,7 +30,7 @@ func NewGormPigBatchLogRepository(db *gorm.DB) PigBatchLogRepository {
}
// Create 实现了创建猪批次日志的逻辑。
func (r *gormPigBatchLogRepository) Create(tx *gorm.DB, log *models.PigBatchLog) error {
func (r *gormPigBatchLogRepository) CreateTx(tx *gorm.DB, log *models.PigBatchLog) error {
return tx.Create(log).Error
}

View File

@@ -870,7 +870,7 @@ func TestPlanRepository_Create(t *testing.T) {
type testCase struct {
name string
setupDB func(db *gorm.DB) // 准备数据库的初始状态
inputPlan *models.Plan // 传入 Create 方法的计划对象
inputPlan *models.Plan // 传入 CreateTx 方法的计划对象
expectedError error // 期望的错误类型
verifyDB func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) // 验证数据库状态
}
@@ -1040,7 +1040,7 @@ func TestPlanRepository_Create(t *testing.T) {
{Name: "Task 2", ExecutionOrder: 1}, // 重复的顺序
},
},
expectedError: fmt.Errorf("任务执行顺序重复: %d", 1), // 假设 Create 方法会返回此错误
expectedError: fmt.Errorf("任务执行顺序重复: %d", 1), // 假设 CreateTx 方法会返回此错误
verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) {
var count int64
db.Model(&models.Plan{}).Where("name = ?", "重复任务顺序计划").Count(&count)
@@ -1061,7 +1061,7 @@ func TestPlanRepository_Create(t *testing.T) {
{ChildPlanID: 2, ExecutionOrder: 1}, // 重复的顺序
},
},
expectedError: fmt.Errorf("子计划执行顺序重复: %d", 1), // 假设 Create 方法会返回此错误
expectedError: fmt.Errorf("子计划执行顺序重复: %d", 1), // 假设 CreateTx 方法会返回此错误
verifyDB: func(t *testing.T, db *gorm.DB, createdPlan *models.Plan) {
var count int64
db.Model(&models.Plan{}).Where("name = ?", "重复子计划顺序计划").Count(&count)
@@ -1078,7 +1078,7 @@ func TestPlanRepository_Create(t *testing.T) {
// 准备数据库状态
tc.setupDB(db)
// 执行 Create 操作
// 执行 CreateTx 操作
err := repo.CreatePlan(tc.inputPlan)
// 断言错误