补充测试用例

This commit is contained in:
2025-09-13 19:48:13 +08:00
parent ec2595a751
commit bd22e452d3
9 changed files with 353 additions and 117 deletions

View File

@@ -9,6 +9,7 @@ help:
@echo " run Run the application" @echo " run Run the application"
@echo " build Build the application" @echo " build Build the application"
@echo " clean Clean generated files" @echo " clean Clean generated files"
@echo " test Run all tests"
@echo " help Show this help message" @echo " help Show this help message"
# 运行应用 # 运行应用
@@ -24,4 +25,9 @@ build:
# 清理生成文件 # 清理生成文件
.PHONY: clean .PHONY: clean
clean: clean:
rm -f bin/pig-farm-controller rm -f bin/pig-farm-controller
# 运行所有测试
.PHONY: test
test:
go test --count=1 ./...

View File

@@ -1,6 +1,7 @@
package token package token
import ( import (
"fmt"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -52,9 +53,16 @@ func (s *tokenService) ParseToken(tokenString string) (*Claims, error) {
return s.secret, nil return s.secret, nil
}) })
// 优先检查解析过程中是否发生错误
if err != nil {
return nil, err
}
// 只有当 token 对象有效时,才尝试获取 Claims 并验证
if claims, ok := token.Claims.(*Claims); ok && token.Valid { if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil return claims, nil
} }
return nil, err // 如果 token 无效(例如,过期但没有返回错误,或者 Claims 类型不匹配),则返回一个通用错误
return nil, fmt.Errorf("token is invalid")
} }

View File

@@ -1,6 +1,7 @@
package token_test package token_test
import ( import (
"errors"
"testing" "testing"
"time" "time"
@@ -45,36 +46,38 @@ func TestGenerateToken(t *testing.T) {
} }
func TestParseToken(t *testing.T) { func TestParseToken(t *testing.T) {
// 使用一个测试密钥初始化 TokenService // 使用两个不同的测试密钥
testSecret := []byte("test_secret_key") correctSecret := []byte("the_correct_secret")
service := token.NewTokenService(testSecret) wrongSecret := []byte("a_very_wrong_secret")
serviceWithCorrectKey := token.NewTokenService(correctSecret)
serviceWithWrongKey := token.NewTokenService(wrongSecret)
userID := uint(456) userID := uint(456)
// 生成一个有效的 token 用于解析测试 // 1. 生成一个有效的 token
validToken, err := service.GenerateToken(userID) validToken, err := serviceWithCorrectKey.GenerateToken(userID)
if err != nil { if err != nil {
t.Fatalf("为解析测试生成有效令牌失败: %v", err) t.Fatalf("为解析测试生成有效令牌失败: %v", err)
} }
// 测试用例 1: 有效 token // 测试用例 1: 使用正确的密钥成功解析
claims, err := service.ParseToken(validToken) claims, err := serviceWithCorrectKey.ParseToken(validToken)
if err != nil { if err != nil {
t.Errorf("解析有效令牌失败: %v", err) t.Errorf("使用正确密钥解析有效令牌失败: %v", err)
} }
if claims.UserID != userID { if claims.UserID != userID {
t.Errorf("解析有效令牌时期望用户ID %d, 实际为 %d", userID, claims.UserID) t.Errorf("解析有效令牌时期望用户ID %d, 实际为 %d", userID, claims.UserID)
} }
// 测试用例 2: 无效 token (例如, 错误的签名) // 测试用例 2: 无效 token (例如, 格式错误的字符串)
invalidToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoxMjMsImV4cCI6MTY3ODkwNTYwMCwiaXNzIjoicGlnLWZhcm0tY29udHJvbGxlciJ9.invalid_signature_here" invalidTokenString := "this.is.not.a.valid.jwt"
_, err = service.ParseToken(invalidToken) _, err = serviceWithCorrectKey.ParseToken(invalidTokenString)
if err == nil { if err == nil {
t.Error("解析无效令牌意外成功") t.Error("解析格式错误的令牌意外成功")
} }
// 测试用例 3: 过期 token (创建一个过期时间在过去的 token) // 测试用C:\Users\divano\Desktop\work\AA-Pig\pig-farm-controller\internal\infra\repository\plan_repository_test.go例 3: 过期 token
expiredClaims := token.Claims{ expiredClaims := token.Claims{
UserID: userID, UserID: userID,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
@@ -83,13 +86,22 @@ func TestParseToken(t *testing.T) {
}, },
} }
expiredTokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, expiredClaims) expiredTokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, expiredClaims)
expiredTokenString, err := expiredTokenClaims.SignedString(testSecret) expiredTokenString, err := expiredTokenClaims.SignedString(correctSecret)
if err != nil { if err != nil {
t.Fatalf("生成过期令牌失败: %v", err) t.Fatalf("生成过期令牌失败: %v", err)
} }
_, err = serviceWithCorrectKey.ParseToken(expiredTokenString)
_, err = service.ParseToken(expiredTokenString)
if err == nil { if err == nil {
t.Error("解析过期令牌意外成功") t.Error("解析过期令牌意外成功")
} }
// 新增测试用例 4: 使用错误的密钥解析
_, err = serviceWithWrongKey.ParseToken(validToken)
if err == nil {
t.Error("使用错误密钥解析令牌意外成功")
}
// 我们可以更精确地检查错误类型,以确保它是签名错误
if !errors.Is(err, jwt.ErrTokenSignatureInvalid) {
t.Errorf("期望得到签名无效错误 (ErrTokenSignatureInvalid),但得到了: %v", err)
}
} }

View File

@@ -127,17 +127,3 @@ func (c *Config) Load(path string) error {
return nil return nil
} }
// GetDatabaseConnectionString 获取数据库连接字符串
func (c *Config) GetDatabaseConnectionString() string {
// 构建PostgreSQL连接字符串
return fmt.Sprintf(
"user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
c.Database.Username,
c.Database.Password,
c.Database.DBName,
c.Database.Host,
c.Database.Port,
c.Database.SSLMode,
)
}

View File

@@ -27,7 +27,7 @@ type Logger struct {
// 这是实现依赖注入的关键,在应用启动时调用一次。 // 这是实现依赖注入的关键,在应用启动时调用一次。
func NewLogger(cfg config.LogConfig) *Logger { func NewLogger(cfg config.LogConfig) *Logger {
// 1. 设置日志编码器 // 1. 设置日志编码器
encoder := getEncoder(cfg.Format) encoder := GetEncoder(cfg.Format)
// 2. 设置日志写入器 (支持文件和控制台) // 2. 设置日志写入器 (支持文件和控制台)
writeSyncer := getWriteSyncer(cfg) writeSyncer := getWriteSyncer(cfg)
@@ -49,8 +49,8 @@ func NewLogger(cfg config.LogConfig) *Logger {
return &Logger{zapLogger.Sugar()} return &Logger{zapLogger.Sugar()}
} }
// getEncoder 根据指定的格式返回一个 zapcore.Encoder。 // GetEncoder 根据指定的格式返回一个 zapcore.Encoder。
func getEncoder(format string) zapcore.Encoder { func GetEncoder(format string) zapcore.Encoder {
encoderConfig := zap.NewProductionEncoderConfig() encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder // 时间格式: 2006-01-02T15:04:05.000Z0700 encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder // 时间格式: 2006-01-02T15:04:05.000Z0700
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO

View File

@@ -3,6 +3,7 @@ package logs_test
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"strings" "strings"
"testing" "testing"
@@ -18,12 +19,8 @@ import (
// captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区 // captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区
func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) {
var buf bytes.Buffer var buf bytes.Buffer
// 使用一个简单的 Console Encoder 进行测试,方便断言字符串
encoderConfig := zap.NewDevelopmentEncoderConfig() encoder := logs.GetEncoder(cfg.Format)
encoderConfig.EncodeTime = nil // 忽略时间,避免测试结果不一致
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
encoderConfig.EncodeCaller = nil // 忽略调用者信息
encoder := zapcore.NewConsoleEncoder(encoderConfig)
writer := zapcore.AddSync(&buf) writer := zapcore.AddSync(&buf)
@@ -31,29 +28,52 @@ func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) {
_ = level.UnmarshalText([]byte(cfg.Level)) _ = level.UnmarshalText([]byte(cfg.Level))
core := zapcore.NewCore(encoder, writer, level) core := zapcore.NewCore(encoder, writer, level)
// 在测试中我们直接操作 zap Logger而不是通过封装的 NewLogger以注入内存 writer // 匹配 logs.go 中 NewLogger 的行为,添加调用者信息
zapLogger := zap.New(core) zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()} logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()}
return logger, &buf return logger, &buf
} }
func TestNewLogger(t *testing.T) { func TestNewLogger(t *testing.T) {
t.Run("构造函数不会 panic", func(t *testing.T) { t.Run("日志级别应生效", func(t *testing.T) {
// 测试 Console 格式 // 1. 创建一个级别为 WARN 的 logger
cfgConsole := config.LogConfig{Level: "info", Format: "console"} logger, buf := captureOutput(config.LogConfig{Level: "warn", Format: "console"})
assert.NotPanics(t, func() { logs.NewLogger(cfgConsole) })
// 测试 JSON 格式 // 2. 调用不同级别的日志方法
cfgJSON := config.LogConfig{Level: "info", Format: "json"} logger.Info("这条 info 日志不应被打印")
assert.NotPanics(t, func() { logs.NewLogger(cfgJSON) }) logger.Warn("这条 warn 日志应该被打印")
// 测试文件日志启用 // 3. 断言输出
// 不实际写入文件,只确保构造函数能正常运行 output := buf.String()
assert.NotContains(t, output, "这条 info 日志不应被打印")
assert.Contains(t, output, "这条 warn 日志应该被打印")
})
t.Run("JSON 格式应生效", func(t *testing.T) {
// 1. 创建一个格式为 JSON 的 logger
logger, buf := captureOutput(config.LogConfig{Level: "info", Format: "json"})
// 2. 打印一条日志
logger.Info("测试json输出")
// 3. 断言输出
output := buf.String()
// 验证它是否是合法的 JSON并且包含预期的键值对
var logEntry map[string]interface{}
// 注意:由于日志库可能会在行尾添加换行符,我们先 trim space
err := json.Unmarshal([]byte(strings.TrimSpace(output)), &logEntry)
assert.NoError(t, err, "日志输出应为合法的JSON")
assert.Equal(t, "INFO", logEntry["level"])
assert.Equal(t, "测试json输出", logEntry["msg"])
})
t.Run("文件日志构造函数不应 panic", func(t *testing.T) {
// 这个测试保持原样,只验证构造函数在启用文件时不会崩溃
// 注意:我们不在单元测试中实际写入文件
cfgFile := config.LogConfig{ cfgFile := config.LogConfig{
Level: "info", Level: "info",
EnableFile: true, EnableFile: true,
FilePath: "test.log", FilePath: "test.log", // 在测试环境中,这个文件不会被真正创建
} }
assert.NotPanics(t, func() { logs.NewLogger(cfgFile) }) assert.NotPanics(t, func() { logs.NewLogger(cfgFile) })
}) })
@@ -84,7 +104,7 @@ func TestGormLogger(t *testing.T) {
return sql, rows return sql, rows
} }
t.Run("Slow Query", func(t *testing.T) { t.Run("慢查询应记录为警告", func(t *testing.T) {
buf.Reset() buf.Reset()
// 模拟一个耗时超过 200ms 的查询 // 模拟一个耗时超过 200ms 的查询
begin := time.Now().Add(-300 * time.Millisecond) begin := time.Now().Add(-300 * time.Millisecond)
@@ -96,7 +116,7 @@ func TestGormLogger(t *testing.T) {
assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句") assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句")
}) })
t.Run("Error Query", func(t *testing.T) { t.Run("普通错误应记录为Error", func(t *testing.T) {
buf.Reset() buf.Reset()
queryError := errors.New("syntax error") queryError := errors.New("syntax error")
gormLogger.Trace(ctx, time.Now(), fc, queryError) gormLogger.Trace(ctx, time.Now(), fc, queryError)
@@ -106,8 +126,10 @@ func TestGormLogger(t *testing.T) {
assert.Contains(t, output, "[GORM] error: syntax error") assert.Contains(t, output, "[GORM] error: syntax error")
}) })
t.Run("Record Not Found Error is Skipped", func(t *testing.T) { t.Run("当SkipErrRecordNotFound为true时应跳过RecordNotFound错误", func(t *testing.T) {
buf.Reset() buf.Reset()
// 确保默认设置是 true
gormLogger.SkipErrRecordNotFound = true
// 错误必须包含 "record not found" 字符串以匹配 logs.go 中的判断逻辑 // 错误必须包含 "record not found" 字符串以匹配 logs.go 中的判断逻辑
queryError := errors.New("record not found") queryError := errors.New("record not found")
gormLogger.Trace(ctx, time.Now(), fc, queryError) gormLogger.Trace(ctx, time.Now(), fc, queryError)
@@ -115,7 +137,24 @@ func TestGormLogger(t *testing.T) {
assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后record not found 错误不应产生任何日志") assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后record not found 错误不应产生任何日志")
}) })
t.Run("Normal Query", func(t *testing.T) { t.Run("当SkipErrRecordNotFound为false时应记录RecordNotFound错误", func(t *testing.T) {
buf.Reset()
// 手动将 SkipErrRecordNotFound 设置为 false
gormLogger.SkipErrRecordNotFound = false
queryError := errors.New("record not found")
gormLogger.Trace(ctx, time.Now(), fc, queryError)
// 恢复设置,避免影响其他测试
gormLogger.SkipErrRecordNotFound = true
output := buf.String()
assert.NotEmpty(t, output, "关闭 SkipErrRecordNotFound 后record not found 错误应该产生日志")
assert.Contains(t, output, "ERROR")
assert.Contains(t, output, "[GORM] error: record not found")
})
t.Run("正常查询应记录为Debug", func(t *testing.T) {
buf.Reset() buf.Reset()
// 模拟一个快速查询 // 模拟一个快速查询
gormLogger.Trace(ctx, time.Now(), fc, nil) gormLogger.Trace(ctx, time.Now(), fc, nil)

View File

@@ -58,7 +58,7 @@ type Device struct {
gorm.Model gorm.Model
// Name 是设备的业务名称,应清晰可读,例如 "1号猪舍温度传感器" 或 "做料车间主控" // Name 是设备的业务名称,应清晰可读,例如 "1号猪舍温度传感器" 或 "做料车间主控"
Name string `gorm:"not null" json:"name"` Name string `gorm:"unique;not null" json:"name"`
// Type 是设备的高级类别,用于区分区域主控和普通设备。建立索引以优化按类型查询。 // Type 是设备的高级类别,用于区分区域主控和普通设备。建立索引以优化按类型查询。
Type DeviceType `gorm:"not null;index" json:"type"` Type DeviceType `gorm:"not null;index" json:"type"`

View File

@@ -38,7 +38,37 @@ func TestUser_CheckPassword(t *testing.T) {
assert.False(t, match, "空密码应该校验失败") assert.False(t, match, "空密码应该校验失败")
}) })
} }
func TestUser_BeforeCreate(t *testing.T) {
t.Run("密码应被成功哈希", func(t *testing.T) {
plainPassword := "securepassword123"
user := &models.User{
Username: "testuser",
Password: plainPassword,
}
// 注意BeforeSave 钩子是一个 GORM 框架的回调,它的正确性 // 模拟 GORM 钩子调用
// 将在 repository 的集成测试中,通过实际创建一个用户来得到验证, err := user.BeforeCreate(nil) // GORM 钩子通常接收 *gorm.DB这里我们传入 nil因为 BeforeCreate 不依赖 DB
// 而不是在这里进行孤立的、脆弱的单元测试。 assert.NoError(t, err, "BeforeCreate 不应返回错误")
// 验证密码是否已被哈希(不再是明文)
assert.NotEqual(t, plainPassword, user.Password, "密码应已被哈希")
// 验证哈希后的密码是否能被正确校验
assert.True(t, user.CheckPassword(plainPassword), "哈希后的密码应能通过校验")
})
t.Run("空密码不应被哈希", func(t *testing.T) {
plainPassword := ""
user := &models.User{
Username: "empty_pass_user",
Password: plainPassword,
}
// 模拟 GORM 钩子调用
err := user.BeforeCreate(nil)
assert.NoError(t, err, "BeforeCreate 不应返回错误")
// 验证密码仍然是空字符串
assert.Equal(t, plainPassword, user.Password, "空密码不应被哈希")
})
}

View File

@@ -11,97 +11,252 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
func TestGormDeviceRepository(t *testing.T) { // createTestDevice 辅助函数,用于创建测试设备
func createTestDevice(t *testing.T, db *gorm.DB, name string, deviceType models.DeviceType, parentID *uint) *models.Device {
device := &models.Device{
Name: name,
Type: deviceType,
ParentID: parentID,
// 其他字段可以根据需要添加
}
err := db.Create(device).Error
assert.NoError(t, err)
return device
}
func TestRepoCreate(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
repo := repository.NewGormDeviceRepository(db) repo := repository.NewGormDeviceRepository(db)
// --- 准备测试数据 ---
loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"}) loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"})
busProps, _ := json.Marshal(models.BusProperties{BusID: 1, BusAddress: 10})
areaController := &models.Device{ t.Run("成功创建区域主控", func(t *testing.T) {
Name: "1号猪舍主控", device := &models.Device{
Type: models.DeviceTypeAreaController, Name: "主控A",
Location: "1号猪舍", Type: models.DeviceTypeAreaController,
Properties: loraProps, Location: "猪舍1",
} Properties: loraProps,
}
t.Run("创建 - 成功创建区域主控", func(t *testing.T) { err := repo.Create(device)
err := repo.Create(areaController)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotZero(t, areaController.ID, "创建后应获得一个非零ID") assert.NotZero(t, device.ID, "创建后应获得一个非零ID")
assert.Nil(t, areaController.ParentID, "区域主控的 ParentID 应为 nil") assert.Nil(t, device.ParentID, "区域主控的 ParentID 应为 nil")
}) })
var createdDevice *models.Device t.Run("成功创建子设备", func(t *testing.T) {
t.Run("通过ID查找 - 成功找到已创建的设备", func(t *testing.T) { parent := createTestDevice(t, db, "父设备", models.DeviceTypeAreaController, nil)
var err error child := &models.Device{
createdDevice, err = repo.FindByID(areaController.ID) Name: "子设备A",
Type: models.DeviceTypeDevice,
ParentID: &parent.ID,
}
err := repo.Create(child)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdDevice) assert.NotZero(t, child.ID)
assert.Equal(t, areaController.Name, createdDevice.Name) assert.NotNil(t, child.ParentID)
assert.Equal(t, parent.ID, *child.ParentID)
}) })
}
t.Run("通过字符串ID查找 - 使用有效字符串ID找到设备", func(t *testing.T) { func TestRepoFindByID(t *testing.T) {
foundDevice, err := repo.FindByIDString(strconv.FormatUint(uint64(areaController.ID), 10)) db := setupTestDB(t)
repo := repository.NewGormDeviceRepository(db)
device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil)
t.Run("成功通过ID查找", func(t *testing.T) {
foundDevice, err := repo.FindByID(device.ID)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, foundDevice) assert.NotNil(t, foundDevice)
assert.Equal(t, areaController.ID, foundDevice.ID) assert.Equal(t, device.ID, foundDevice.ID)
assert.Equal(t, device.Name, foundDevice.Name)
}) })
t.Run("通过字符串ID查找 - 使用无效字符串ID", func(t *testing.T) { t.Run("查找不存在的ID", func(t *testing.T) {
_, err := repo.FindByID(9999) // 不存在的ID
assert.Error(t, err)
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
})
t.Run("数据库查询失败", func(t *testing.T) {
// 模拟数据库连接关闭,强制查询失败
sqlDB, _ := db.DB()
sqlDB.Close()
_, err := repo.FindByID(device.ID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "database is closed")
})
}
func TestRepoFindByIDString(t *testing.T) {
db := setupTestDB(t)
repo := repository.NewGormDeviceRepository(db)
device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil)
t.Run("成功通过字符串ID查找", func(t *testing.T) {
idStr := strconv.FormatUint(uint64(device.ID), 10)
foundDevice, err := repo.FindByIDString(idStr)
assert.NoError(t, err)
assert.NotNil(t, foundDevice)
assert.Equal(t, device.ID, foundDevice.ID)
})
t.Run("无效的字符串ID格式", func(t *testing.T) {
_, err := repo.FindByIDString("invalid-id") _, err := repo.FindByIDString("invalid-id")
assert.Error(t, err, "使用无效ID字符串应返回错误") assert.Error(t, err)
assert.Contains(t, err.Error(), "无效的设备ID格式")
}) })
// 创建一个子设备 t.Run("查找不存在的字符串ID", func(t *testing.T) {
childDevice := &models.Device{ idStr := strconv.FormatUint(uint64(9999), 10) // 不存在的ID
Name: "1号猪舍温度传感器", _, err := repo.FindByIDString(idStr)
Type: models.DeviceTypeDevice, assert.Error(t, err)
SubType: models.SubTypeSensorTemp, assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
ParentID: &areaController.ID, })
Location: "1号猪舍东侧",
Properties: busProps,
}
t.Run("创建 - 成功创建子设备", func(t *testing.T) { t.Run("数据库查询失败", func(t *testing.T) {
err := repo.Create(childDevice) sqlDB, _ := db.DB()
sqlDB.Close()
idStr := strconv.FormatUint(uint64(device.ID), 10)
_, err := repo.FindByIDString(idStr)
assert.Error(t, err)
assert.Contains(t, err.Error(), "database is closed")
})
}
func TestRepoListAll(t *testing.T) {
db := setupTestDB(t)
repo := repository.NewGormDeviceRepository(db)
t.Run("成功获取空列表", func(t *testing.T) {
devices, err := repo.ListAll()
assert.NoError(t, err) assert.NoError(t, err)
assert.NotZero(t, childDevice.ID) assert.Empty(t, devices)
assert.NotNil(t, childDevice.ParentID)
assert.Equal(t, areaController.ID, *childDevice.ParentID)
}) })
t.Run("通过父ID列出 - 找到子设备", func(t *testing.T) { t.Run("成功获取包含设备的列表", func(t *testing.T) {
children, err := repo.ListByParentID(&areaController.ID) createTestDevice(t, db, "设备1", models.DeviceTypeAreaController, nil)
createTestDevice(t, db, "设备2", models.DeviceTypeDevice, nil)
devices, err := repo.ListAll()
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, children, 1, "应找到一个子设备") assert.Len(t, devices, 2)
assert.Equal(t, childDevice.ID, children[0].ID) assert.Equal(t, "设备1", devices[0].Name)
assert.Equal(t, "设备2", devices[1].Name)
}) })
t.Run("通过父ID列出 - 找到顶层设备", func(t *testing.T) { t.Run("数据库查询失败", func(t *testing.T) {
sqlDB, _ := db.DB()
sqlDB.Close()
_, err := repo.ListAll()
assert.Error(t, err)
assert.Contains(t, err.Error(), "database is closed")
})
}
func TestRepoListByParentID(t *testing.T) {
db := setupTestDB(t)
repo := repository.NewGormDeviceRepository(db)
parent1 := createTestDevice(t, db, "父设备1", models.DeviceTypeAreaController, nil)
parent2 := createTestDevice(t, db, "父设备2", models.DeviceTypeAreaController, nil)
child1_1 := createTestDevice(t, db, "子设备1-1", models.DeviceTypeDevice, &parent1.ID)
child1_2 := createTestDevice(t, db, "子设备1-2", models.DeviceTypeDevice, &parent1.ID)
_ = createTestDevice(t, db, "子设备2-1", models.DeviceTypeDevice, &parent2.ID)
t.Run("成功通过父ID查找子设备", func(t *testing.T) {
children, err := repo.ListByParentID(&parent1.ID)
assert.NoError(t, err)
assert.Len(t, children, 2)
assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[0].ID)
assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[1].ID)
})
t.Run("成功通过nil父ID查找顶层设备", func(t *testing.T) {
parents, err := repo.ListByParentID(nil) parents, err := repo.ListByParentID(nil)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, parents, 1, "应找到一个顶层设备") assert.Len(t, parents, 2)
assert.Equal(t, areaController.ID, parents[0].ID) assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[0].ID)
assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[1].ID)
}) })
t.Run("更新 - 成功更新设备信息", func(t *testing.T) { t.Run("查找不存在的父ID", func(t *testing.T) {
childDevice.Location = "1号猪舍西侧" nonExistentParentID := uint(9999)
err := repo.Update(childDevice) children, err := repo.ListByParentID(&nonExistentParentID)
assert.NoError(t, err) // GORM 在未找到时返回空列表而不是错误
assert.Empty(t, children)
})
t.Run("数据库查询失败", func(t *testing.T) {
sqlDB, _ := db.DB()
sqlDB.Close()
_, err := repo.ListByParentID(&parent1.ID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "database is closed")
})
}
func TestRepoUpdate(t *testing.T) {
db := setupTestDB(t)
repo := repository.NewGormDeviceRepository(db)
device := createTestDevice(t, db, "原始设备", models.DeviceTypeAreaController, nil)
t.Run("成功更新设备信息", func(t *testing.T) {
device.Name = "更新后的设备"
device.Location = "新地点"
err := repo.Update(device)
assert.NoError(t, err) assert.NoError(t, err)
updatedDevice, _ := repo.FindByID(childDevice.ID) updatedDevice, err := repo.FindByID(device.ID)
assert.Equal(t, "1号猪舍西侧", updatedDevice.Location) assert.NoError(t, err)
assert.Equal(t, "更新后的设备", updatedDevice.Name)
assert.Equal(t, "新地点", updatedDevice.Location)
}) })
t.Run("删除 - 成功删除设备", func(t *testing.T) { t.Run("数据库更新失败", func(t *testing.T) {
err := repo.Delete(childDevice.ID) sqlDB, _ := db.DB()
sqlDB.Close()
device.Name = "更新失败的设备"
err := repo.Update(device)
assert.Error(t, err)
assert.Contains(t, err.Error(), "database is closed")
})
}
func TestRepoDelete(t *testing.T) {
db := setupTestDB(t)
repo := repository.NewGormDeviceRepository(db)
device := createTestDevice(t, db, "待删除设备", models.DeviceTypeAreaController, nil)
t.Run("成功删除设备", func(t *testing.T) {
err := repo.Delete(device.ID)
assert.NoError(t, err) assert.NoError(t, err)
// 验证设备已被软删除 // 验证设备已被软删除
_, err = repo.FindByID(childDevice.ID) _, err = repo.FindByID(device.ID)
assert.Error(t, err, "删除后应无法找到设备") assert.Error(t, err, "删除后应无法找到设备")
assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound") assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound")
}) })
t.Run("删除不存在的设备", func(t *testing.T) {
err := repo.Delete(9999) // 不存在的ID
assert.NoError(t, err) // GORM 的 Delete 方法在删除不存在的记录时不会报错
})
t.Run("数据库删除失败", func(t *testing.T) {
sqlDB, _ := db.DB()
sqlDB.Close()
err := repo.Delete(device.ID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "database is closed")
})
} }