补充测试用例
This commit is contained in:
8
Makefile
8
Makefile
@@ -9,6 +9,7 @@ help:
|
||||
@echo " run Run the application"
|
||||
@echo " build Build the application"
|
||||
@echo " clean Clean generated files"
|
||||
@echo " test Run all tests"
|
||||
@echo " help Show this help message"
|
||||
|
||||
# 运行应用
|
||||
@@ -24,4 +25,9 @@ build:
|
||||
# 清理生成文件
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm -f bin/pig-farm-controller
|
||||
rm -f bin/pig-farm-controller
|
||||
|
||||
# 运行所有测试
|
||||
.PHONY: test
|
||||
test:
|
||||
go test --count=1 ./...
|
||||
@@ -1,6 +1,7 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -52,9 +53,16 @@ func (s *tokenService) ParseToken(tokenString string) (*Claims, error) {
|
||||
return s.secret, nil
|
||||
})
|
||||
|
||||
// 优先检查解析过程中是否发生错误
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 只有当 token 对象有效时,才尝试获取 Claims 并验证
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
// 如果 token 无效(例如,过期但没有返回错误,或者 Claims 类型不匹配),则返回一个通用错误
|
||||
return nil, fmt.Errorf("token is invalid")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package token_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -45,36 +46,38 @@ func TestGenerateToken(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseToken(t *testing.T) {
|
||||
// 使用一个测试密钥初始化 TokenService
|
||||
testSecret := []byte("test_secret_key")
|
||||
service := token.NewTokenService(testSecret)
|
||||
// 使用两个不同的测试密钥
|
||||
correctSecret := []byte("the_correct_secret")
|
||||
wrongSecret := []byte("a_very_wrong_secret")
|
||||
|
||||
serviceWithCorrectKey := token.NewTokenService(correctSecret)
|
||||
serviceWithWrongKey := token.NewTokenService(wrongSecret)
|
||||
|
||||
userID := uint(456)
|
||||
|
||||
// 生成一个有效的 token 用于解析测试
|
||||
validToken, err := service.GenerateToken(userID)
|
||||
// 1. 生成一个有效的 token
|
||||
validToken, err := serviceWithCorrectKey.GenerateToken(userID)
|
||||
if err != nil {
|
||||
t.Fatalf("为解析测试生成有效令牌失败: %v", err)
|
||||
}
|
||||
|
||||
// 测试用例 1: 有效 token
|
||||
claims, err := service.ParseToken(validToken)
|
||||
// 测试用例 1: 使用正确的密钥成功解析
|
||||
claims, err := serviceWithCorrectKey.ParseToken(validToken)
|
||||
if err != nil {
|
||||
t.Errorf("解析有效令牌失败: %v", err)
|
||||
t.Errorf("使用正确密钥解析有效令牌失败: %v", err)
|
||||
}
|
||||
|
||||
if claims.UserID != userID {
|
||||
t.Errorf("解析有效令牌时期望用户ID %d, 实际为 %d", userID, claims.UserID)
|
||||
}
|
||||
|
||||
// 测试用例 2: 无效 token (例如, 错误的签名)
|
||||
invalidToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoxMjMsImV4cCI6MTY3ODkwNTYwMCwiaXNzIjoicGlnLWZhcm0tY29udHJvbGxlciJ9.invalid_signature_here"
|
||||
_, err = service.ParseToken(invalidToken)
|
||||
// 测试用例 2: 无效 token (例如, 格式错误的字符串)
|
||||
invalidTokenString := "this.is.not.a.valid.jwt"
|
||||
_, err = serviceWithCorrectKey.ParseToken(invalidTokenString)
|
||||
if err == nil {
|
||||
t.Error("解析无效令牌意外成功")
|
||||
t.Error("解析格式错误的令牌意外成功")
|
||||
}
|
||||
|
||||
// 测试用例 3: 过期 token (创建一个过期时间在过去的 token)
|
||||
// 测试用C:\Users\divano\Desktop\work\AA-Pig\pig-farm-controller\internal\infra\repository\plan_repository_test.go例 3: 过期 token
|
||||
expiredClaims := token.Claims{
|
||||
UserID: userID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
@@ -83,13 +86,22 @@ func TestParseToken(t *testing.T) {
|
||||
},
|
||||
}
|
||||
expiredTokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, expiredClaims)
|
||||
expiredTokenString, err := expiredTokenClaims.SignedString(testSecret)
|
||||
expiredTokenString, err := expiredTokenClaims.SignedString(correctSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("生成过期令牌失败: %v", err)
|
||||
}
|
||||
|
||||
_, err = service.ParseToken(expiredTokenString)
|
||||
_, err = serviceWithCorrectKey.ParseToken(expiredTokenString)
|
||||
if err == nil {
|
||||
t.Error("解析过期令牌意外成功")
|
||||
}
|
||||
|
||||
// 新增测试用例 4: 使用错误的密钥解析
|
||||
_, err = serviceWithWrongKey.ParseToken(validToken)
|
||||
if err == nil {
|
||||
t.Error("使用错误密钥解析令牌意外成功")
|
||||
}
|
||||
// 我们可以更精确地检查错误类型,以确保它是签名错误
|
||||
if !errors.Is(err, jwt.ErrTokenSignatureInvalid) {
|
||||
t.Errorf("期望得到签名无效错误 (ErrTokenSignatureInvalid),但得到了: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,17 +127,3 @@ func (c *Config) Load(path string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDatabaseConnectionString 获取数据库连接字符串
|
||||
func (c *Config) GetDatabaseConnectionString() string {
|
||||
// 构建PostgreSQL连接字符串
|
||||
return fmt.Sprintf(
|
||||
"user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||
c.Database.Username,
|
||||
c.Database.Password,
|
||||
c.Database.DBName,
|
||||
c.Database.Host,
|
||||
c.Database.Port,
|
||||
c.Database.SSLMode,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ type Logger struct {
|
||||
// 这是实现依赖注入的关键,在应用启动时调用一次。
|
||||
func NewLogger(cfg config.LogConfig) *Logger {
|
||||
// 1. 设置日志编码器
|
||||
encoder := getEncoder(cfg.Format)
|
||||
encoder := GetEncoder(cfg.Format)
|
||||
|
||||
// 2. 设置日志写入器 (支持文件和控制台)
|
||||
writeSyncer := getWriteSyncer(cfg)
|
||||
@@ -49,8 +49,8 @@ func NewLogger(cfg config.LogConfig) *Logger {
|
||||
return &Logger{zapLogger.Sugar()}
|
||||
}
|
||||
|
||||
// getEncoder 根据指定的格式返回一个 zapcore.Encoder。
|
||||
func getEncoder(format string) zapcore.Encoder {
|
||||
// GetEncoder 根据指定的格式返回一个 zapcore.Encoder。
|
||||
func GetEncoder(format string) zapcore.Encoder {
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder // 时间格式: 2006-01-02T15:04:05.000Z0700
|
||||
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO
|
||||
|
||||
@@ -3,6 +3,7 @@ package logs_test
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -18,12 +19,8 @@ import (
|
||||
// captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区
|
||||
func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) {
|
||||
var buf bytes.Buffer
|
||||
// 使用一个简单的 Console Encoder 进行测试,方便断言字符串
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.EncodeTime = nil // 忽略时间,避免测试结果不一致
|
||||
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
|
||||
encoderConfig.EncodeCaller = nil // 忽略调用者信息
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
|
||||
encoder := logs.GetEncoder(cfg.Format)
|
||||
|
||||
writer := zapcore.AddSync(&buf)
|
||||
|
||||
@@ -31,29 +28,52 @@ func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) {
|
||||
_ = level.UnmarshalText([]byte(cfg.Level))
|
||||
|
||||
core := zapcore.NewCore(encoder, writer, level)
|
||||
// 在测试中我们直接操作 zap Logger,而不是通过封装的 NewLogger,以注入内存 writer
|
||||
zapLogger := zap.New(core)
|
||||
// 匹配 logs.go 中 NewLogger 的行为,添加调用者信息
|
||||
zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
|
||||
|
||||
logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()}
|
||||
return logger, &buf
|
||||
}
|
||||
|
||||
func TestNewLogger(t *testing.T) {
|
||||
t.Run("构造函数不会 panic", func(t *testing.T) {
|
||||
// 测试 Console 格式
|
||||
cfgConsole := config.LogConfig{Level: "info", Format: "console"}
|
||||
assert.NotPanics(t, func() { logs.NewLogger(cfgConsole) })
|
||||
t.Run("日志级别应生效", func(t *testing.T) {
|
||||
// 1. 创建一个级别为 WARN 的 logger
|
||||
logger, buf := captureOutput(config.LogConfig{Level: "warn", Format: "console"})
|
||||
|
||||
// 测试 JSON 格式
|
||||
cfgJSON := config.LogConfig{Level: "info", Format: "json"}
|
||||
assert.NotPanics(t, func() { logs.NewLogger(cfgJSON) })
|
||||
// 2. 调用不同级别的日志方法
|
||||
logger.Info("这条 info 日志不应被打印")
|
||||
logger.Warn("这条 warn 日志应该被打印")
|
||||
|
||||
// 测试文件日志启用
|
||||
// 不实际写入文件,只确保构造函数能正常运行
|
||||
// 3. 断言输出
|
||||
output := buf.String()
|
||||
assert.NotContains(t, output, "这条 info 日志不应被打印")
|
||||
assert.Contains(t, output, "这条 warn 日志应该被打印")
|
||||
})
|
||||
|
||||
t.Run("JSON 格式应生效", func(t *testing.T) {
|
||||
// 1. 创建一个格式为 JSON 的 logger
|
||||
logger, buf := captureOutput(config.LogConfig{Level: "info", Format: "json"})
|
||||
|
||||
// 2. 打印一条日志
|
||||
logger.Info("测试json输出")
|
||||
|
||||
// 3. 断言输出
|
||||
output := buf.String()
|
||||
// 验证它是否是合法的 JSON,并且包含预期的键值对
|
||||
var logEntry map[string]interface{}
|
||||
// 注意:由于日志库可能会在行尾添加换行符,我们先 trim space
|
||||
err := json.Unmarshal([]byte(strings.TrimSpace(output)), &logEntry)
|
||||
assert.NoError(t, err, "日志输出应为合法的JSON")
|
||||
assert.Equal(t, "INFO", logEntry["level"])
|
||||
assert.Equal(t, "测试json输出", logEntry["msg"])
|
||||
})
|
||||
|
||||
t.Run("文件日志构造函数不应 panic", func(t *testing.T) {
|
||||
// 这个测试保持原样,只验证构造函数在启用文件时不会崩溃
|
||||
// 注意:我们不在单元测试中实际写入文件
|
||||
cfgFile := config.LogConfig{
|
||||
Level: "info",
|
||||
EnableFile: true,
|
||||
FilePath: "test.log",
|
||||
FilePath: "test.log", // 在测试环境中,这个文件不会被真正创建
|
||||
}
|
||||
assert.NotPanics(t, func() { logs.NewLogger(cfgFile) })
|
||||
})
|
||||
@@ -84,7 +104,7 @@ func TestGormLogger(t *testing.T) {
|
||||
return sql, rows
|
||||
}
|
||||
|
||||
t.Run("Slow Query", func(t *testing.T) {
|
||||
t.Run("慢查询应记录为警告", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
// 模拟一个耗时超过 200ms 的查询
|
||||
begin := time.Now().Add(-300 * time.Millisecond)
|
||||
@@ -96,7 +116,7 @@ func TestGormLogger(t *testing.T) {
|
||||
assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句")
|
||||
})
|
||||
|
||||
t.Run("Error Query", func(t *testing.T) {
|
||||
t.Run("普通错误应记录为Error", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
queryError := errors.New("syntax error")
|
||||
gormLogger.Trace(ctx, time.Now(), fc, queryError)
|
||||
@@ -106,8 +126,10 @@ func TestGormLogger(t *testing.T) {
|
||||
assert.Contains(t, output, "[GORM] error: syntax error")
|
||||
})
|
||||
|
||||
t.Run("Record Not Found Error is Skipped", func(t *testing.T) {
|
||||
t.Run("当SkipErrRecordNotFound为true时应跳过RecordNotFound错误", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
// 确保默认设置是 true
|
||||
gormLogger.SkipErrRecordNotFound = true
|
||||
// 错误必须包含 "record not found" 字符串以匹配 logs.go 中的判断逻辑
|
||||
queryError := errors.New("record not found")
|
||||
gormLogger.Trace(ctx, time.Now(), fc, queryError)
|
||||
@@ -115,7 +137,24 @@ func TestGormLogger(t *testing.T) {
|
||||
assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后,record not found 错误不应产生任何日志")
|
||||
})
|
||||
|
||||
t.Run("Normal Query", func(t *testing.T) {
|
||||
t.Run("当SkipErrRecordNotFound为false时应记录RecordNotFound错误", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
// 手动将 SkipErrRecordNotFound 设置为 false
|
||||
gormLogger.SkipErrRecordNotFound = false
|
||||
|
||||
queryError := errors.New("record not found")
|
||||
gormLogger.Trace(ctx, time.Now(), fc, queryError)
|
||||
|
||||
// 恢复设置,避免影响其他测试
|
||||
gormLogger.SkipErrRecordNotFound = true
|
||||
|
||||
output := buf.String()
|
||||
assert.NotEmpty(t, output, "关闭 SkipErrRecordNotFound 后,record not found 错误应该产生日志")
|
||||
assert.Contains(t, output, "ERROR")
|
||||
assert.Contains(t, output, "[GORM] error: record not found")
|
||||
})
|
||||
|
||||
t.Run("正常查询应记录为Debug", func(t *testing.T) {
|
||||
buf.Reset()
|
||||
// 模拟一个快速查询
|
||||
gormLogger.Trace(ctx, time.Now(), fc, nil)
|
||||
|
||||
@@ -58,7 +58,7 @@ type Device struct {
|
||||
gorm.Model
|
||||
|
||||
// Name 是设备的业务名称,应清晰可读,例如 "1号猪舍温度传感器" 或 "做料车间主控"
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Name string `gorm:"unique;not null" json:"name"`
|
||||
|
||||
// Type 是设备的高级类别,用于区分区域主控和普通设备。建立索引以优化按类型查询。
|
||||
Type DeviceType `gorm:"not null;index" json:"type"`
|
||||
|
||||
@@ -38,7 +38,37 @@ func TestUser_CheckPassword(t *testing.T) {
|
||||
assert.False(t, match, "空密码应该校验失败")
|
||||
})
|
||||
}
|
||||
func TestUser_BeforeCreate(t *testing.T) {
|
||||
t.Run("密码应被成功哈希", func(t *testing.T) {
|
||||
plainPassword := "securepassword123"
|
||||
user := &models.User{
|
||||
Username: "testuser",
|
||||
Password: plainPassword,
|
||||
}
|
||||
|
||||
// 注意:BeforeSave 钩子是一个 GORM 框架的回调,它的正确性
|
||||
// 将在 repository 的集成测试中,通过实际创建一个用户来得到验证,
|
||||
// 而不是在这里进行孤立的、脆弱的单元测试。
|
||||
// 模拟 GORM 钩子调用
|
||||
err := user.BeforeCreate(nil) // GORM 钩子通常接收 *gorm.DB,这里我们传入 nil,因为 BeforeCreate 不依赖 DB
|
||||
assert.NoError(t, err, "BeforeCreate 不应返回错误")
|
||||
|
||||
// 验证密码是否已被哈希(不再是明文)
|
||||
assert.NotEqual(t, plainPassword, user.Password, "密码应已被哈希")
|
||||
|
||||
// 验证哈希后的密码是否能被正确校验
|
||||
assert.True(t, user.CheckPassword(plainPassword), "哈希后的密码应能通过校验")
|
||||
})
|
||||
|
||||
t.Run("空密码不应被哈希", func(t *testing.T) {
|
||||
plainPassword := ""
|
||||
user := &models.User{
|
||||
Username: "empty_pass_user",
|
||||
Password: plainPassword,
|
||||
}
|
||||
|
||||
// 模拟 GORM 钩子调用
|
||||
err := user.BeforeCreate(nil)
|
||||
assert.NoError(t, err, "BeforeCreate 不应返回错误")
|
||||
|
||||
// 验证密码仍然是空字符串
|
||||
assert.Equal(t, plainPassword, user.Password, "空密码不应被哈希")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,97 +11,252 @@ import (
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func TestGormDeviceRepository(t *testing.T) {
|
||||
// createTestDevice 辅助函数,用于创建测试设备
|
||||
func createTestDevice(t *testing.T, db *gorm.DB, name string, deviceType models.DeviceType, parentID *uint) *models.Device {
|
||||
device := &models.Device{
|
||||
Name: name,
|
||||
Type: deviceType,
|
||||
ParentID: parentID,
|
||||
// 其他字段可以根据需要添加
|
||||
}
|
||||
err := db.Create(device).Error
|
||||
assert.NoError(t, err)
|
||||
return device
|
||||
}
|
||||
|
||||
func TestRepoCreate(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormDeviceRepository(db)
|
||||
|
||||
// --- 准备测试数据 ---
|
||||
loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"})
|
||||
busProps, _ := json.Marshal(models.BusProperties{BusID: 1, BusAddress: 10})
|
||||
|
||||
areaController := &models.Device{
|
||||
Name: "1号猪舍主控",
|
||||
Type: models.DeviceTypeAreaController,
|
||||
Location: "1号猪舍",
|
||||
Properties: loraProps,
|
||||
}
|
||||
|
||||
t.Run("创建 - 成功创建区域主控", func(t *testing.T) {
|
||||
err := repo.Create(areaController)
|
||||
t.Run("成功创建区域主控", func(t *testing.T) {
|
||||
device := &models.Device{
|
||||
Name: "主控A",
|
||||
Type: models.DeviceTypeAreaController,
|
||||
Location: "猪舍1",
|
||||
Properties: loraProps,
|
||||
}
|
||||
err := repo.Create(device)
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, areaController.ID, "创建后应获得一个非零ID")
|
||||
assert.Nil(t, areaController.ParentID, "区域主控的 ParentID 应为 nil")
|
||||
assert.NotZero(t, device.ID, "创建后应获得一个非零ID")
|
||||
assert.Nil(t, device.ParentID, "区域主控的 ParentID 应为 nil")
|
||||
})
|
||||
|
||||
var createdDevice *models.Device
|
||||
t.Run("通过ID查找 - 成功找到已创建的设备", func(t *testing.T) {
|
||||
var err error
|
||||
createdDevice, err = repo.FindByID(areaController.ID)
|
||||
t.Run("成功创建子设备", func(t *testing.T) {
|
||||
parent := createTestDevice(t, db, "父设备", models.DeviceTypeAreaController, nil)
|
||||
child := &models.Device{
|
||||
Name: "子设备A",
|
||||
Type: models.DeviceTypeDevice,
|
||||
ParentID: &parent.ID,
|
||||
}
|
||||
err := repo.Create(child)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, createdDevice)
|
||||
assert.Equal(t, areaController.Name, createdDevice.Name)
|
||||
assert.NotZero(t, child.ID)
|
||||
assert.NotNil(t, child.ParentID)
|
||||
assert.Equal(t, parent.ID, *child.ParentID)
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("通过字符串ID查找 - 使用有效字符串ID找到设备", func(t *testing.T) {
|
||||
foundDevice, err := repo.FindByIDString(strconv.FormatUint(uint64(areaController.ID), 10))
|
||||
func TestRepoFindByID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormDeviceRepository(db)
|
||||
|
||||
device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil)
|
||||
|
||||
t.Run("成功通过ID查找", func(t *testing.T) {
|
||||
foundDevice, err := repo.FindByID(device.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, foundDevice)
|
||||
assert.Equal(t, areaController.ID, foundDevice.ID)
|
||||
assert.Equal(t, device.ID, foundDevice.ID)
|
||||
assert.Equal(t, device.Name, foundDevice.Name)
|
||||
})
|
||||
|
||||
t.Run("通过字符串ID查找 - 使用无效字符串ID", func(t *testing.T) {
|
||||
t.Run("查找不存在的ID", func(t *testing.T) {
|
||||
_, err := repo.FindByID(9999) // 不存在的ID
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
|
||||
})
|
||||
|
||||
t.Run("数据库查询失败", func(t *testing.T) {
|
||||
// 模拟数据库连接关闭,强制查询失败
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
_, err := repo.FindByID(device.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database is closed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepoFindByIDString(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormDeviceRepository(db)
|
||||
|
||||
device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil)
|
||||
|
||||
t.Run("成功通过字符串ID查找", func(t *testing.T) {
|
||||
idStr := strconv.FormatUint(uint64(device.ID), 10)
|
||||
foundDevice, err := repo.FindByIDString(idStr)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, foundDevice)
|
||||
assert.Equal(t, device.ID, foundDevice.ID)
|
||||
})
|
||||
|
||||
t.Run("无效的字符串ID格式", func(t *testing.T) {
|
||||
_, err := repo.FindByIDString("invalid-id")
|
||||
assert.Error(t, err, "使用无效ID字符串应返回错误")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "无效的设备ID格式")
|
||||
})
|
||||
|
||||
// 创建一个子设备
|
||||
childDevice := &models.Device{
|
||||
Name: "1号猪舍温度传感器",
|
||||
Type: models.DeviceTypeDevice,
|
||||
SubType: models.SubTypeSensorTemp,
|
||||
ParentID: &areaController.ID,
|
||||
Location: "1号猪舍东侧",
|
||||
Properties: busProps,
|
||||
}
|
||||
t.Run("查找不存在的字符串ID", func(t *testing.T) {
|
||||
idStr := strconv.FormatUint(uint64(9999), 10) // 不存在的ID
|
||||
_, err := repo.FindByIDString(idStr)
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound)
|
||||
})
|
||||
|
||||
t.Run("创建 - 成功创建子设备", func(t *testing.T) {
|
||||
err := repo.Create(childDevice)
|
||||
t.Run("数据库查询失败", func(t *testing.T) {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
idStr := strconv.FormatUint(uint64(device.ID), 10)
|
||||
_, err := repo.FindByIDString(idStr)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database is closed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepoListAll(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormDeviceRepository(db)
|
||||
|
||||
t.Run("成功获取空列表", func(t *testing.T) {
|
||||
devices, err := repo.ListAll()
|
||||
assert.NoError(t, err)
|
||||
assert.NotZero(t, childDevice.ID)
|
||||
assert.NotNil(t, childDevice.ParentID)
|
||||
assert.Equal(t, areaController.ID, *childDevice.ParentID)
|
||||
assert.Empty(t, devices)
|
||||
})
|
||||
|
||||
t.Run("通过父ID列出 - 找到子设备", func(t *testing.T) {
|
||||
children, err := repo.ListByParentID(&areaController.ID)
|
||||
t.Run("成功获取包含设备的列表", func(t *testing.T) {
|
||||
createTestDevice(t, db, "设备1", models.DeviceTypeAreaController, nil)
|
||||
createTestDevice(t, db, "设备2", models.DeviceTypeDevice, nil)
|
||||
|
||||
devices, err := repo.ListAll()
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, children, 1, "应找到一个子设备")
|
||||
assert.Equal(t, childDevice.ID, children[0].ID)
|
||||
assert.Len(t, devices, 2)
|
||||
assert.Equal(t, "设备1", devices[0].Name)
|
||||
assert.Equal(t, "设备2", devices[1].Name)
|
||||
})
|
||||
|
||||
t.Run("通过父ID列出 - 找到顶层设备", func(t *testing.T) {
|
||||
t.Run("数据库查询失败", func(t *testing.T) {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
_, err := repo.ListAll()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database is closed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepoListByParentID(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormDeviceRepository(db)
|
||||
|
||||
parent1 := createTestDevice(t, db, "父设备1", models.DeviceTypeAreaController, nil)
|
||||
parent2 := createTestDevice(t, db, "父设备2", models.DeviceTypeAreaController, nil)
|
||||
child1_1 := createTestDevice(t, db, "子设备1-1", models.DeviceTypeDevice, &parent1.ID)
|
||||
child1_2 := createTestDevice(t, db, "子设备1-2", models.DeviceTypeDevice, &parent1.ID)
|
||||
_ = createTestDevice(t, db, "子设备2-1", models.DeviceTypeDevice, &parent2.ID)
|
||||
|
||||
t.Run("成功通过父ID查找子设备", func(t *testing.T) {
|
||||
children, err := repo.ListByParentID(&parent1.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, children, 2)
|
||||
assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[0].ID)
|
||||
assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[1].ID)
|
||||
})
|
||||
|
||||
t.Run("成功通过nil父ID查找顶层设备", func(t *testing.T) {
|
||||
parents, err := repo.ListByParentID(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, parents, 1, "应找到一个顶层设备")
|
||||
assert.Equal(t, areaController.ID, parents[0].ID)
|
||||
assert.Len(t, parents, 2)
|
||||
assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[0].ID)
|
||||
assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[1].ID)
|
||||
})
|
||||
|
||||
t.Run("更新 - 成功更新设备信息", func(t *testing.T) {
|
||||
childDevice.Location = "1号猪舍西侧"
|
||||
err := repo.Update(childDevice)
|
||||
t.Run("查找不存在的父ID", func(t *testing.T) {
|
||||
nonExistentParentID := uint(9999)
|
||||
children, err := repo.ListByParentID(&nonExistentParentID)
|
||||
assert.NoError(t, err) // GORM 在未找到时返回空列表而不是错误
|
||||
assert.Empty(t, children)
|
||||
})
|
||||
|
||||
t.Run("数据库查询失败", func(t *testing.T) {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
_, err := repo.ListByParentID(&parent1.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database is closed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepoUpdate(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormDeviceRepository(db)
|
||||
|
||||
device := createTestDevice(t, db, "原始设备", models.DeviceTypeAreaController, nil)
|
||||
|
||||
t.Run("成功更新设备信息", func(t *testing.T) {
|
||||
device.Name = "更新后的设备"
|
||||
device.Location = "新地点"
|
||||
err := repo.Update(device)
|
||||
assert.NoError(t, err)
|
||||
|
||||
updatedDevice, _ := repo.FindByID(childDevice.ID)
|
||||
assert.Equal(t, "1号猪舍西侧", updatedDevice.Location)
|
||||
updatedDevice, err := repo.FindByID(device.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "更新后的设备", updatedDevice.Name)
|
||||
assert.Equal(t, "新地点", updatedDevice.Location)
|
||||
})
|
||||
|
||||
t.Run("删除 - 成功删除设备", func(t *testing.T) {
|
||||
err := repo.Delete(childDevice.ID)
|
||||
t.Run("数据库更新失败", func(t *testing.T) {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
device.Name = "更新失败的设备"
|
||||
err := repo.Update(device)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database is closed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRepoDelete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := repository.NewGormDeviceRepository(db)
|
||||
|
||||
device := createTestDevice(t, db, "待删除设备", models.DeviceTypeAreaController, nil)
|
||||
|
||||
t.Run("成功删除设备", func(t *testing.T) {
|
||||
err := repo.Delete(device.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证设备已被软删除
|
||||
_, err = repo.FindByID(childDevice.ID)
|
||||
_, err = repo.FindByID(device.ID)
|
||||
assert.Error(t, err, "删除后应无法找到设备")
|
||||
assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound")
|
||||
})
|
||||
|
||||
t.Run("删除不存在的设备", func(t *testing.T) {
|
||||
err := repo.Delete(9999) // 不存在的ID
|
||||
assert.NoError(t, err) // GORM 的 Delete 方法在删除不存在的记录时不会报错
|
||||
})
|
||||
|
||||
t.Run("数据库删除失败", func(t *testing.T) {
|
||||
sqlDB, _ := db.DB()
|
||||
sqlDB.Close()
|
||||
|
||||
err := repo.Delete(device.ID)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "database is closed")
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user