diff --git a/Makefile b/Makefile index 6dcde32..7d38b0b 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,7 @@ help: @echo " run Run the application" @echo " build Build the application" @echo " clean Clean generated files" + @echo " test Run all tests" @echo " help Show this help message" # 运行应用 @@ -24,4 +25,9 @@ build: # 清理生成文件 .PHONY: clean clean: - rm -f bin/pig-farm-controller \ No newline at end of file + rm -f bin/pig-farm-controller + +# 运行所有测试 +.PHONY: test +test: + go test --count=1 ./... \ No newline at end of file diff --git a/internal/app/service/token/token_service.go b/internal/app/service/token/token_service.go index b2bbb12..361e305 100644 --- a/internal/app/service/token/token_service.go +++ b/internal/app/service/token/token_service.go @@ -1,6 +1,7 @@ package token import ( + "fmt" "time" "github.com/golang-jwt/jwt/v5" @@ -52,9 +53,16 @@ func (s *tokenService) ParseToken(tokenString string) (*Claims, error) { return s.secret, nil }) + // 优先检查解析过程中是否发生错误 + if err != nil { + return nil, err + } + + // 只有当 token 对象有效时,才尝试获取 Claims 并验证 if claims, ok := token.Claims.(*Claims); ok && token.Valid { return claims, nil } - return nil, err + // 如果 token 无效(例如,过期但没有返回错误,或者 Claims 类型不匹配),则返回一个通用错误 + return nil, fmt.Errorf("token is invalid") } diff --git a/internal/app/service/token/token_service_test.go b/internal/app/service/token/token_service_test.go index 7f5198e..6c01426 100644 --- a/internal/app/service/token/token_service_test.go +++ b/internal/app/service/token/token_service_test.go @@ -1,6 +1,7 @@ package token_test import ( + "errors" "testing" "time" @@ -45,36 +46,38 @@ func TestGenerateToken(t *testing.T) { } func TestParseToken(t *testing.T) { - // 使用一个测试密钥初始化 TokenService - testSecret := []byte("test_secret_key") - service := token.NewTokenService(testSecret) + // 使用两个不同的测试密钥 + correctSecret := []byte("the_correct_secret") + wrongSecret := []byte("a_very_wrong_secret") + + serviceWithCorrectKey := token.NewTokenService(correctSecret) + serviceWithWrongKey := token.NewTokenService(wrongSecret) userID := uint(456) - // 生成一个有效的 token 用于解析测试 - validToken, err := service.GenerateToken(userID) + // 1. 生成一个有效的 token + validToken, err := serviceWithCorrectKey.GenerateToken(userID) if err != nil { t.Fatalf("为解析测试生成有效令牌失败: %v", err) } - // 测试用例 1: 有效 token - claims, err := service.ParseToken(validToken) + // 测试用例 1: 使用正确的密钥成功解析 + claims, err := serviceWithCorrectKey.ParseToken(validToken) if err != nil { - t.Errorf("解析有效令牌失败: %v", err) + t.Errorf("使用正确密钥解析有效令牌失败: %v", err) } - if claims.UserID != userID { t.Errorf("解析有效令牌时期望用户ID %d, 实际为 %d", userID, claims.UserID) } - // 测试用例 2: 无效 token (例如, 错误的签名) - invalidToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoxMjMsImV4cCI6MTY3ODkwNTYwMCwiaXNzIjoicGlnLWZhcm0tY29udHJvbGxlciJ9.invalid_signature_here" - _, err = service.ParseToken(invalidToken) + // 测试用例 2: 无效 token (例如, 格式错误的字符串) + invalidTokenString := "this.is.not.a.valid.jwt" + _, err = serviceWithCorrectKey.ParseToken(invalidTokenString) if err == nil { - t.Error("解析无效令牌意外成功") + t.Error("解析格式错误的令牌意外成功") } - // 测试用例 3: 过期 token (创建一个过期时间在过去的 token) + // 测试用C:\Users\divano\Desktop\work\AA-Pig\pig-farm-controller\internal\infra\repository\plan_repository_test.go例 3: 过期 token expiredClaims := token.Claims{ UserID: userID, RegisteredClaims: jwt.RegisteredClaims{ @@ -83,13 +86,22 @@ func TestParseToken(t *testing.T) { }, } expiredTokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, expiredClaims) - expiredTokenString, err := expiredTokenClaims.SignedString(testSecret) + expiredTokenString, err := expiredTokenClaims.SignedString(correctSecret) if err != nil { t.Fatalf("生成过期令牌失败: %v", err) } - - _, err = service.ParseToken(expiredTokenString) + _, err = serviceWithCorrectKey.ParseToken(expiredTokenString) if err == nil { t.Error("解析过期令牌意外成功") } + + // 新增测试用例 4: 使用错误的密钥解析 + _, err = serviceWithWrongKey.ParseToken(validToken) + if err == nil { + t.Error("使用错误密钥解析令牌意外成功") + } + // 我们可以更精确地检查错误类型,以确保它是签名错误 + if !errors.Is(err, jwt.ErrTokenSignatureInvalid) { + t.Errorf("期望得到签名无效错误 (ErrTokenSignatureInvalid),但得到了: %v", err) + } } diff --git a/internal/infra/config/config.go b/internal/infra/config/config.go index df19122..74679bf 100644 --- a/internal/infra/config/config.go +++ b/internal/infra/config/config.go @@ -127,17 +127,3 @@ func (c *Config) Load(path string) error { return nil } - -// GetDatabaseConnectionString 获取数据库连接字符串 -func (c *Config) GetDatabaseConnectionString() string { - // 构建PostgreSQL连接字符串 - return fmt.Sprintf( - "user=%s password=%s dbname=%s host=%s port=%d sslmode=%s", - c.Database.Username, - c.Database.Password, - c.Database.DBName, - c.Database.Host, - c.Database.Port, - c.Database.SSLMode, - ) -} diff --git a/internal/infra/logs/logs.go b/internal/infra/logs/logs.go index bf16752..63f4799 100644 --- a/internal/infra/logs/logs.go +++ b/internal/infra/logs/logs.go @@ -27,7 +27,7 @@ type Logger struct { // 这是实现依赖注入的关键,在应用启动时调用一次。 func NewLogger(cfg config.LogConfig) *Logger { // 1. 设置日志编码器 - encoder := getEncoder(cfg.Format) + encoder := GetEncoder(cfg.Format) // 2. 设置日志写入器 (支持文件和控制台) writeSyncer := getWriteSyncer(cfg) @@ -49,8 +49,8 @@ func NewLogger(cfg config.LogConfig) *Logger { return &Logger{zapLogger.Sugar()} } -// getEncoder 根据指定的格式返回一个 zapcore.Encoder。 -func getEncoder(format string) zapcore.Encoder { +// GetEncoder 根据指定的格式返回一个 zapcore.Encoder。 +func GetEncoder(format string) zapcore.Encoder { encoderConfig := zap.NewProductionEncoderConfig() encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder // 时间格式: 2006-01-02T15:04:05.000Z0700 encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO diff --git a/internal/infra/logs/logs_test.go b/internal/infra/logs/logs_test.go index abb4cff..b5c3f9c 100644 --- a/internal/infra/logs/logs_test.go +++ b/internal/infra/logs/logs_test.go @@ -3,6 +3,7 @@ package logs_test import ( "bytes" "context" + "encoding/json" "errors" "strings" "testing" @@ -18,12 +19,8 @@ import ( // captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区 func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { var buf bytes.Buffer - // 使用一个简单的 Console Encoder 进行测试,方便断言字符串 - encoderConfig := zap.NewDevelopmentEncoderConfig() - encoderConfig.EncodeTime = nil // 忽略时间,避免测试结果不一致 - encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder - encoderConfig.EncodeCaller = nil // 忽略调用者信息 - encoder := zapcore.NewConsoleEncoder(encoderConfig) + + encoder := logs.GetEncoder(cfg.Format) writer := zapcore.AddSync(&buf) @@ -31,29 +28,52 @@ func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { _ = level.UnmarshalText([]byte(cfg.Level)) core := zapcore.NewCore(encoder, writer, level) - // 在测试中我们直接操作 zap Logger,而不是通过封装的 NewLogger,以注入内存 writer - zapLogger := zap.New(core) + // 匹配 logs.go 中 NewLogger 的行为,添加调用者信息 + zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)) logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()} return logger, &buf } - func TestNewLogger(t *testing.T) { - t.Run("构造函数不会 panic", func(t *testing.T) { - // 测试 Console 格式 - cfgConsole := config.LogConfig{Level: "info", Format: "console"} - assert.NotPanics(t, func() { logs.NewLogger(cfgConsole) }) + t.Run("日志级别应生效", func(t *testing.T) { + // 1. 创建一个级别为 WARN 的 logger + logger, buf := captureOutput(config.LogConfig{Level: "warn", Format: "console"}) - // 测试 JSON 格式 - cfgJSON := config.LogConfig{Level: "info", Format: "json"} - assert.NotPanics(t, func() { logs.NewLogger(cfgJSON) }) + // 2. 调用不同级别的日志方法 + logger.Info("这条 info 日志不应被打印") + logger.Warn("这条 warn 日志应该被打印") - // 测试文件日志启用 - // 不实际写入文件,只确保构造函数能正常运行 + // 3. 断言输出 + output := buf.String() + assert.NotContains(t, output, "这条 info 日志不应被打印") + assert.Contains(t, output, "这条 warn 日志应该被打印") + }) + + t.Run("JSON 格式应生效", func(t *testing.T) { + // 1. 创建一个格式为 JSON 的 logger + logger, buf := captureOutput(config.LogConfig{Level: "info", Format: "json"}) + + // 2. 打印一条日志 + logger.Info("测试json输出") + + // 3. 断言输出 + output := buf.String() + // 验证它是否是合法的 JSON,并且包含预期的键值对 + var logEntry map[string]interface{} + // 注意:由于日志库可能会在行尾添加换行符,我们先 trim space + err := json.Unmarshal([]byte(strings.TrimSpace(output)), &logEntry) + assert.NoError(t, err, "日志输出应为合法的JSON") + assert.Equal(t, "INFO", logEntry["level"]) + assert.Equal(t, "测试json输出", logEntry["msg"]) + }) + + t.Run("文件日志构造函数不应 panic", func(t *testing.T) { + // 这个测试保持原样,只验证构造函数在启用文件时不会崩溃 + // 注意:我们不在单元测试中实际写入文件 cfgFile := config.LogConfig{ Level: "info", EnableFile: true, - FilePath: "test.log", + FilePath: "test.log", // 在测试环境中,这个文件不会被真正创建 } assert.NotPanics(t, func() { logs.NewLogger(cfgFile) }) }) @@ -84,7 +104,7 @@ func TestGormLogger(t *testing.T) { return sql, rows } - t.Run("Slow Query", func(t *testing.T) { + t.Run("慢查询应记录为警告", func(t *testing.T) { buf.Reset() // 模拟一个耗时超过 200ms 的查询 begin := time.Now().Add(-300 * time.Millisecond) @@ -96,7 +116,7 @@ func TestGormLogger(t *testing.T) { assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句") }) - t.Run("Error Query", func(t *testing.T) { + t.Run("普通错误应记录为Error", func(t *testing.T) { buf.Reset() queryError := errors.New("syntax error") gormLogger.Trace(ctx, time.Now(), fc, queryError) @@ -106,8 +126,10 @@ func TestGormLogger(t *testing.T) { assert.Contains(t, output, "[GORM] error: syntax error") }) - t.Run("Record Not Found Error is Skipped", func(t *testing.T) { + t.Run("当SkipErrRecordNotFound为true时应跳过RecordNotFound错误", func(t *testing.T) { buf.Reset() + // 确保默认设置是 true + gormLogger.SkipErrRecordNotFound = true // 错误必须包含 "record not found" 字符串以匹配 logs.go 中的判断逻辑 queryError := errors.New("record not found") gormLogger.Trace(ctx, time.Now(), fc, queryError) @@ -115,7 +137,24 @@ func TestGormLogger(t *testing.T) { assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后,record not found 错误不应产生任何日志") }) - t.Run("Normal Query", func(t *testing.T) { + t.Run("当SkipErrRecordNotFound为false时应记录RecordNotFound错误", func(t *testing.T) { + buf.Reset() + // 手动将 SkipErrRecordNotFound 设置为 false + gormLogger.SkipErrRecordNotFound = false + + queryError := errors.New("record not found") + gormLogger.Trace(ctx, time.Now(), fc, queryError) + + // 恢复设置,避免影响其他测试 + gormLogger.SkipErrRecordNotFound = true + + output := buf.String() + assert.NotEmpty(t, output, "关闭 SkipErrRecordNotFound 后,record not found 错误应该产生日志") + assert.Contains(t, output, "ERROR") + assert.Contains(t, output, "[GORM] error: record not found") + }) + + t.Run("正常查询应记录为Debug", func(t *testing.T) { buf.Reset() // 模拟一个快速查询 gormLogger.Trace(ctx, time.Now(), fc, nil) diff --git a/internal/infra/models/device.go b/internal/infra/models/device.go index cbae8f3..00b2dbf 100644 --- a/internal/infra/models/device.go +++ b/internal/infra/models/device.go @@ -58,7 +58,7 @@ type Device struct { gorm.Model // Name 是设备的业务名称,应清晰可读,例如 "1号猪舍温度传感器" 或 "做料车间主控" - Name string `gorm:"not null" json:"name"` + Name string `gorm:"unique;not null" json:"name"` // Type 是设备的高级类别,用于区分区域主控和普通设备。建立索引以优化按类型查询。 Type DeviceType `gorm:"not null;index" json:"type"` diff --git a/internal/infra/models/user_test.go b/internal/infra/models/user_test.go index 877abb2..338959f 100644 --- a/internal/infra/models/user_test.go +++ b/internal/infra/models/user_test.go @@ -38,7 +38,37 @@ func TestUser_CheckPassword(t *testing.T) { assert.False(t, match, "空密码应该校验失败") }) } +func TestUser_BeforeCreate(t *testing.T) { + t.Run("密码应被成功哈希", func(t *testing.T) { + plainPassword := "securepassword123" + user := &models.User{ + Username: "testuser", + Password: plainPassword, + } -// 注意:BeforeSave 钩子是一个 GORM 框架的回调,它的正确性 -// 将在 repository 的集成测试中,通过实际创建一个用户来得到验证, -// 而不是在这里进行孤立的、脆弱的单元测试。 + // 模拟 GORM 钩子调用 + err := user.BeforeCreate(nil) // GORM 钩子通常接收 *gorm.DB,这里我们传入 nil,因为 BeforeCreate 不依赖 DB + assert.NoError(t, err, "BeforeCreate 不应返回错误") + + // 验证密码是否已被哈希(不再是明文) + assert.NotEqual(t, plainPassword, user.Password, "密码应已被哈希") + + // 验证哈希后的密码是否能被正确校验 + assert.True(t, user.CheckPassword(plainPassword), "哈希后的密码应能通过校验") + }) + + t.Run("空密码不应被哈希", func(t *testing.T) { + plainPassword := "" + user := &models.User{ + Username: "empty_pass_user", + Password: plainPassword, + } + + // 模拟 GORM 钩子调用 + err := user.BeforeCreate(nil) + assert.NoError(t, err, "BeforeCreate 不应返回错误") + + // 验证密码仍然是空字符串 + assert.Equal(t, plainPassword, user.Password, "空密码不应被哈希") + }) +} diff --git a/internal/infra/repository/device_repository_test.go b/internal/infra/repository/device_repository_test.go index 034dbe3..03e1b7e 100644 --- a/internal/infra/repository/device_repository_test.go +++ b/internal/infra/repository/device_repository_test.go @@ -11,97 +11,252 @@ import ( "gorm.io/gorm" ) -func TestGormDeviceRepository(t *testing.T) { +// createTestDevice 辅助函数,用于创建测试设备 +func createTestDevice(t *testing.T, db *gorm.DB, name string, deviceType models.DeviceType, parentID *uint) *models.Device { + device := &models.Device{ + Name: name, + Type: deviceType, + ParentID: parentID, + // 其他字段可以根据需要添加 + } + err := db.Create(device).Error + assert.NoError(t, err) + return device +} + +func TestRepoCreate(t *testing.T) { db := setupTestDB(t) repo := repository.NewGormDeviceRepository(db) - // --- 准备测试数据 --- loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"}) - busProps, _ := json.Marshal(models.BusProperties{BusID: 1, BusAddress: 10}) - areaController := &models.Device{ - Name: "1号猪舍主控", - Type: models.DeviceTypeAreaController, - Location: "1号猪舍", - Properties: loraProps, - } - - t.Run("创建 - 成功创建区域主控", func(t *testing.T) { - err := repo.Create(areaController) + t.Run("成功创建区域主控", func(t *testing.T) { + device := &models.Device{ + Name: "主控A", + Type: models.DeviceTypeAreaController, + Location: "猪舍1", + Properties: loraProps, + } + err := repo.Create(device) assert.NoError(t, err) - assert.NotZero(t, areaController.ID, "创建后应获得一个非零ID") - assert.Nil(t, areaController.ParentID, "区域主控的 ParentID 应为 nil") + assert.NotZero(t, device.ID, "创建后应获得一个非零ID") + assert.Nil(t, device.ParentID, "区域主控的 ParentID 应为 nil") }) - var createdDevice *models.Device - t.Run("通过ID查找 - 成功找到已创建的设备", func(t *testing.T) { - var err error - createdDevice, err = repo.FindByID(areaController.ID) + t.Run("成功创建子设备", func(t *testing.T) { + parent := createTestDevice(t, db, "父设备", models.DeviceTypeAreaController, nil) + child := &models.Device{ + Name: "子设备A", + Type: models.DeviceTypeDevice, + ParentID: &parent.ID, + } + err := repo.Create(child) assert.NoError(t, err) - assert.NotNil(t, createdDevice) - assert.Equal(t, areaController.Name, createdDevice.Name) + assert.NotZero(t, child.ID) + assert.NotNil(t, child.ParentID) + assert.Equal(t, parent.ID, *child.ParentID) }) +} - t.Run("通过字符串ID查找 - 使用有效字符串ID找到设备", func(t *testing.T) { - foundDevice, err := repo.FindByIDString(strconv.FormatUint(uint64(areaController.ID), 10)) +func TestRepoFindByID(t *testing.T) { + db := setupTestDB(t) + repo := repository.NewGormDeviceRepository(db) + + device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil) + + t.Run("成功通过ID查找", func(t *testing.T) { + foundDevice, err := repo.FindByID(device.ID) assert.NoError(t, err) assert.NotNil(t, foundDevice) - assert.Equal(t, areaController.ID, foundDevice.ID) + assert.Equal(t, device.ID, foundDevice.ID) + assert.Equal(t, device.Name, foundDevice.Name) }) - t.Run("通过字符串ID查找 - 使用无效字符串ID", func(t *testing.T) { + t.Run("查找不存在的ID", func(t *testing.T) { + _, err := repo.FindByID(9999) // 不存在的ID + assert.Error(t, err) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + }) + + t.Run("数据库查询失败", func(t *testing.T) { + // 模拟数据库连接关闭,强制查询失败 + sqlDB, _ := db.DB() + sqlDB.Close() + + _, err := repo.FindByID(device.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "database is closed") + }) +} + +func TestRepoFindByIDString(t *testing.T) { + db := setupTestDB(t) + repo := repository.NewGormDeviceRepository(db) + + device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil) + + t.Run("成功通过字符串ID查找", func(t *testing.T) { + idStr := strconv.FormatUint(uint64(device.ID), 10) + foundDevice, err := repo.FindByIDString(idStr) + assert.NoError(t, err) + assert.NotNil(t, foundDevice) + assert.Equal(t, device.ID, foundDevice.ID) + }) + + t.Run("无效的字符串ID格式", func(t *testing.T) { _, err := repo.FindByIDString("invalid-id") - assert.Error(t, err, "使用无效ID字符串应返回错误") + assert.Error(t, err) + assert.Contains(t, err.Error(), "无效的设备ID格式") }) - // 创建一个子设备 - childDevice := &models.Device{ - Name: "1号猪舍温度传感器", - Type: models.DeviceTypeDevice, - SubType: models.SubTypeSensorTemp, - ParentID: &areaController.ID, - Location: "1号猪舍东侧", - Properties: busProps, - } + t.Run("查找不存在的字符串ID", func(t *testing.T) { + idStr := strconv.FormatUint(uint64(9999), 10) // 不存在的ID + _, err := repo.FindByIDString(idStr) + assert.Error(t, err) + assert.ErrorIs(t, err, gorm.ErrRecordNotFound) + }) - t.Run("创建 - 成功创建子设备", func(t *testing.T) { - err := repo.Create(childDevice) + t.Run("数据库查询失败", func(t *testing.T) { + sqlDB, _ := db.DB() + sqlDB.Close() + + idStr := strconv.FormatUint(uint64(device.ID), 10) + _, err := repo.FindByIDString(idStr) + assert.Error(t, err) + assert.Contains(t, err.Error(), "database is closed") + }) +} + +func TestRepoListAll(t *testing.T) { + db := setupTestDB(t) + repo := repository.NewGormDeviceRepository(db) + + t.Run("成功获取空列表", func(t *testing.T) { + devices, err := repo.ListAll() assert.NoError(t, err) - assert.NotZero(t, childDevice.ID) - assert.NotNil(t, childDevice.ParentID) - assert.Equal(t, areaController.ID, *childDevice.ParentID) + assert.Empty(t, devices) }) - t.Run("通过父ID列出 - 找到子设备", func(t *testing.T) { - children, err := repo.ListByParentID(&areaController.ID) + t.Run("成功获取包含设备的列表", func(t *testing.T) { + createTestDevice(t, db, "设备1", models.DeviceTypeAreaController, nil) + createTestDevice(t, db, "设备2", models.DeviceTypeDevice, nil) + + devices, err := repo.ListAll() assert.NoError(t, err) - assert.Len(t, children, 1, "应找到一个子设备") - assert.Equal(t, childDevice.ID, children[0].ID) + assert.Len(t, devices, 2) + assert.Equal(t, "设备1", devices[0].Name) + assert.Equal(t, "设备2", devices[1].Name) }) - t.Run("通过父ID列出 - 找到顶层设备", func(t *testing.T) { + t.Run("数据库查询失败", func(t *testing.T) { + sqlDB, _ := db.DB() + sqlDB.Close() + + _, err := repo.ListAll() + assert.Error(t, err) + assert.Contains(t, err.Error(), "database is closed") + }) +} + +func TestRepoListByParentID(t *testing.T) { + db := setupTestDB(t) + repo := repository.NewGormDeviceRepository(db) + + parent1 := createTestDevice(t, db, "父设备1", models.DeviceTypeAreaController, nil) + parent2 := createTestDevice(t, db, "父设备2", models.DeviceTypeAreaController, nil) + child1_1 := createTestDevice(t, db, "子设备1-1", models.DeviceTypeDevice, &parent1.ID) + child1_2 := createTestDevice(t, db, "子设备1-2", models.DeviceTypeDevice, &parent1.ID) + _ = createTestDevice(t, db, "子设备2-1", models.DeviceTypeDevice, &parent2.ID) + + t.Run("成功通过父ID查找子设备", func(t *testing.T) { + children, err := repo.ListByParentID(&parent1.ID) + assert.NoError(t, err) + assert.Len(t, children, 2) + assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[0].ID) + assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[1].ID) + }) + + t.Run("成功通过nil父ID查找顶层设备", func(t *testing.T) { parents, err := repo.ListByParentID(nil) assert.NoError(t, err) - assert.Len(t, parents, 1, "应找到一个顶层设备") - assert.Equal(t, areaController.ID, parents[0].ID) + assert.Len(t, parents, 2) + assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[0].ID) + assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[1].ID) }) - t.Run("更新 - 成功更新设备信息", func(t *testing.T) { - childDevice.Location = "1号猪舍西侧" - err := repo.Update(childDevice) + t.Run("查找不存在的父ID", func(t *testing.T) { + nonExistentParentID := uint(9999) + children, err := repo.ListByParentID(&nonExistentParentID) + assert.NoError(t, err) // GORM 在未找到时返回空列表而不是错误 + assert.Empty(t, children) + }) + + t.Run("数据库查询失败", func(t *testing.T) { + sqlDB, _ := db.DB() + sqlDB.Close() + + _, err := repo.ListByParentID(&parent1.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "database is closed") + }) +} + +func TestRepoUpdate(t *testing.T) { + db := setupTestDB(t) + repo := repository.NewGormDeviceRepository(db) + + device := createTestDevice(t, db, "原始设备", models.DeviceTypeAreaController, nil) + + t.Run("成功更新设备信息", func(t *testing.T) { + device.Name = "更新后的设备" + device.Location = "新地点" + err := repo.Update(device) assert.NoError(t, err) - updatedDevice, _ := repo.FindByID(childDevice.ID) - assert.Equal(t, "1号猪舍西侧", updatedDevice.Location) + updatedDevice, err := repo.FindByID(device.ID) + assert.NoError(t, err) + assert.Equal(t, "更新后的设备", updatedDevice.Name) + assert.Equal(t, "新地点", updatedDevice.Location) }) - t.Run("删除 - 成功删除设备", func(t *testing.T) { - err := repo.Delete(childDevice.ID) + t.Run("数据库更新失败", func(t *testing.T) { + sqlDB, _ := db.DB() + sqlDB.Close() + + device.Name = "更新失败的设备" + err := repo.Update(device) + assert.Error(t, err) + assert.Contains(t, err.Error(), "database is closed") + }) +} + +func TestRepoDelete(t *testing.T) { + db := setupTestDB(t) + repo := repository.NewGormDeviceRepository(db) + + device := createTestDevice(t, db, "待删除设备", models.DeviceTypeAreaController, nil) + + t.Run("成功删除设备", func(t *testing.T) { + err := repo.Delete(device.ID) assert.NoError(t, err) // 验证设备已被软删除 - _, err = repo.FindByID(childDevice.ID) + _, err = repo.FindByID(device.ID) assert.Error(t, err, "删除后应无法找到设备") assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound") }) + + t.Run("删除不存在的设备", func(t *testing.T) { + err := repo.Delete(9999) // 不存在的ID + assert.NoError(t, err) // GORM 的 Delete 方法在删除不存在的记录时不会报错 + }) + + t.Run("数据库删除失败", func(t *testing.T) { + sqlDB, _ := db.DB() + sqlDB.Close() + + err := repo.Delete(device.ID) + assert.Error(t, err) + assert.Contains(t, err.Error(), "database is closed") + }) }