重构 #4
							
								
								
									
										6
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								Makefile
									
									
									
									
									
								
							| @@ -9,6 +9,7 @@ help: | ||||
| 	@echo "  run          Run the application" | ||||
| 	@echo "  build        Build the application" | ||||
| 	@echo "  clean        Clean generated files" | ||||
| 	@echo "  test         Run all tests" | ||||
| 	@echo "  help         Show this help message" | ||||
|  | ||||
| # 运行应用 | ||||
| @@ -25,3 +26,8 @@ build: | ||||
| .PHONY: clean | ||||
| clean: | ||||
| 	rm -f bin/pig-farm-controller | ||||
|  | ||||
| # 运行所有测试 | ||||
| .PHONY: test | ||||
| test: | ||||
| 	go test --count=1 ./... | ||||
| @@ -1,6 +1,7 @@ | ||||
| package token | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/golang-jwt/jwt/v5" | ||||
| @@ -52,9 +53,16 @@ func (s *tokenService) ParseToken(tokenString string) (*Claims, error) { | ||||
| 		return s.secret, nil | ||||
| 	}) | ||||
|  | ||||
| 	// 优先检查解析过程中是否发生错误 | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// 只有当 token 对象有效时,才尝试获取 Claims 并验证 | ||||
| 	if claims, ok := token.Claims.(*Claims); ok && token.Valid { | ||||
| 		return claims, nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, err | ||||
| 	// 如果 token 无效(例如,过期但没有返回错误,或者 Claims 类型不匹配),则返回一个通用错误 | ||||
| 	return nil, fmt.Errorf("token is invalid") | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package token_test | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| @@ -45,36 +46,38 @@ func TestGenerateToken(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestParseToken(t *testing.T) { | ||||
| 	// 使用一个测试密钥初始化 TokenService | ||||
| 	testSecret := []byte("test_secret_key") | ||||
| 	service := token.NewTokenService(testSecret) | ||||
| 	// 使用两个不同的测试密钥 | ||||
| 	correctSecret := []byte("the_correct_secret") | ||||
| 	wrongSecret := []byte("a_very_wrong_secret") | ||||
|  | ||||
| 	serviceWithCorrectKey := token.NewTokenService(correctSecret) | ||||
| 	serviceWithWrongKey := token.NewTokenService(wrongSecret) | ||||
|  | ||||
| 	userID := uint(456) | ||||
|  | ||||
| 	// 生成一个有效的 token 用于解析测试 | ||||
| 	validToken, err := service.GenerateToken(userID) | ||||
| 	// 1. 生成一个有效的 token | ||||
| 	validToken, err := serviceWithCorrectKey.GenerateToken(userID) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("为解析测试生成有效令牌失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 测试用例 1: 有效 token | ||||
| 	claims, err := service.ParseToken(validToken) | ||||
| 	// 测试用例 1: 使用正确的密钥成功解析 | ||||
| 	claims, err := serviceWithCorrectKey.ParseToken(validToken) | ||||
| 	if err != nil { | ||||
| 		t.Errorf("解析有效令牌失败: %v", err) | ||||
| 		t.Errorf("使用正确密钥解析有效令牌失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if claims.UserID != userID { | ||||
| 		t.Errorf("解析有效令牌时期望用户ID %d, 实际为 %d", userID, claims.UserID) | ||||
| 	} | ||||
|  | ||||
| 	// 测试用例 2: 无效 token (例如, 错误的签名) | ||||
| 	invalidToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoxMjMsImV4cCI6MTY3ODkwNTYwMCwiaXNzIjoicGlnLWZhcm0tY29udHJvbGxlciJ9.invalid_signature_here" | ||||
| 	_, err = service.ParseToken(invalidToken) | ||||
| 	// 测试用例 2: 无效 token (例如, 格式错误的字符串) | ||||
| 	invalidTokenString := "this.is.not.a.valid.jwt" | ||||
| 	_, err = serviceWithCorrectKey.ParseToken(invalidTokenString) | ||||
| 	if err == nil { | ||||
| 		t.Error("解析无效令牌意外成功") | ||||
| 		t.Error("解析格式错误的令牌意外成功") | ||||
| 	} | ||||
|  | ||||
| 	// 测试用例 3: 过期 token (创建一个过期时间在过去的 token) | ||||
| 	// 测试用C:\Users\divano\Desktop\work\AA-Pig\pig-farm-controller\internal\infra\repository\plan_repository_test.go例 3: 过期 token | ||||
| 	expiredClaims := token.Claims{ | ||||
| 		UserID: userID, | ||||
| 		RegisteredClaims: jwt.RegisteredClaims{ | ||||
| @@ -83,13 +86,22 @@ func TestParseToken(t *testing.T) { | ||||
| 		}, | ||||
| 	} | ||||
| 	expiredTokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, expiredClaims) | ||||
| 	expiredTokenString, err := expiredTokenClaims.SignedString(testSecret) | ||||
| 	expiredTokenString, err := expiredTokenClaims.SignedString(correctSecret) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("生成过期令牌失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	_, err = service.ParseToken(expiredTokenString) | ||||
| 	_, err = serviceWithCorrectKey.ParseToken(expiredTokenString) | ||||
| 	if err == nil { | ||||
| 		t.Error("解析过期令牌意外成功") | ||||
| 	} | ||||
|  | ||||
| 	// 新增测试用例 4: 使用错误的密钥解析 | ||||
| 	_, err = serviceWithWrongKey.ParseToken(validToken) | ||||
| 	if err == nil { | ||||
| 		t.Error("使用错误密钥解析令牌意外成功") | ||||
| 	} | ||||
| 	// 我们可以更精确地检查错误类型,以确保它是签名错误 | ||||
| 	if !errors.Is(err, jwt.ErrTokenSignatureInvalid) { | ||||
| 		t.Errorf("期望得到签名无效错误 (ErrTokenSignatureInvalid),但得到了: %v", err) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -127,17 +127,3 @@ func (c *Config) Load(path string) error { | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // GetDatabaseConnectionString 获取数据库连接字符串 | ||||
| func (c *Config) GetDatabaseConnectionString() string { | ||||
| 	// 构建PostgreSQL连接字符串 | ||||
| 	return fmt.Sprintf( | ||||
| 		"user=%s password=%s dbname=%s host=%s port=%d sslmode=%s", | ||||
| 		c.Database.Username, | ||||
| 		c.Database.Password, | ||||
| 		c.Database.DBName, | ||||
| 		c.Database.Host, | ||||
| 		c.Database.Port, | ||||
| 		c.Database.SSLMode, | ||||
| 	) | ||||
| } | ||||
|   | ||||
| @@ -27,7 +27,7 @@ type Logger struct { | ||||
| // 这是实现依赖注入的关键,在应用启动时调用一次。 | ||||
| func NewLogger(cfg config.LogConfig) *Logger { | ||||
| 	// 1. 设置日志编码器 | ||||
| 	encoder := getEncoder(cfg.Format) | ||||
| 	encoder := GetEncoder(cfg.Format) | ||||
|  | ||||
| 	// 2. 设置日志写入器 (支持文件和控制台) | ||||
| 	writeSyncer := getWriteSyncer(cfg) | ||||
| @@ -49,8 +49,8 @@ func NewLogger(cfg config.LogConfig) *Logger { | ||||
| 	return &Logger{zapLogger.Sugar()} | ||||
| } | ||||
|  | ||||
| // getEncoder 根据指定的格式返回一个 zapcore.Encoder。 | ||||
| func getEncoder(format string) zapcore.Encoder { | ||||
| // GetEncoder 根据指定的格式返回一个 zapcore.Encoder。 | ||||
| func GetEncoder(format string) zapcore.Encoder { | ||||
| 	encoderConfig := zap.NewProductionEncoderConfig() | ||||
| 	encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder   // 时间格式: 2006-01-02T15:04:05.000Z0700 | ||||
| 	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package logs_test | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| @@ -18,12 +19,8 @@ import ( | ||||
| // captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区 | ||||
| func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { | ||||
| 	var buf bytes.Buffer | ||||
| 	// 使用一个简单的 Console Encoder 进行测试,方便断言字符串 | ||||
| 	encoderConfig := zap.NewDevelopmentEncoderConfig() | ||||
| 	encoderConfig.EncodeTime = nil // 忽略时间,避免测试结果不一致 | ||||
| 	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder | ||||
| 	encoderConfig.EncodeCaller = nil // 忽略调用者信息 | ||||
| 	encoder := zapcore.NewConsoleEncoder(encoderConfig) | ||||
|  | ||||
| 	encoder := logs.GetEncoder(cfg.Format) | ||||
|  | ||||
| 	writer := zapcore.AddSync(&buf) | ||||
|  | ||||
| @@ -31,29 +28,52 @@ func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { | ||||
| 	_ = level.UnmarshalText([]byte(cfg.Level)) | ||||
|  | ||||
| 	core := zapcore.NewCore(encoder, writer, level) | ||||
| 	// 在测试中我们直接操作 zap Logger,而不是通过封装的 NewLogger,以注入内存 writer | ||||
| 	zapLogger := zap.New(core) | ||||
| 	// 匹配 logs.go 中 NewLogger 的行为,添加调用者信息 | ||||
| 	zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)) | ||||
|  | ||||
| 	logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()} | ||||
| 	return logger, &buf | ||||
| } | ||||
|  | ||||
| func TestNewLogger(t *testing.T) { | ||||
| 	t.Run("构造函数不会 panic", func(t *testing.T) { | ||||
| 		// 测试 Console 格式 | ||||
| 		cfgConsole := config.LogConfig{Level: "info", Format: "console"} | ||||
| 		assert.NotPanics(t, func() { logs.NewLogger(cfgConsole) }) | ||||
| 	t.Run("日志级别应生效", func(t *testing.T) { | ||||
| 		// 1. 创建一个级别为 WARN 的 logger | ||||
| 		logger, buf := captureOutput(config.LogConfig{Level: "warn", Format: "console"}) | ||||
|  | ||||
| 		// 测试 JSON 格式 | ||||
| 		cfgJSON := config.LogConfig{Level: "info", Format: "json"} | ||||
| 		assert.NotPanics(t, func() { logs.NewLogger(cfgJSON) }) | ||||
| 		// 2. 调用不同级别的日志方法 | ||||
| 		logger.Info("这条 info 日志不应被打印") | ||||
| 		logger.Warn("这条 warn 日志应该被打印") | ||||
|  | ||||
| 		// 测试文件日志启用 | ||||
| 		// 不实际写入文件,只确保构造函数能正常运行 | ||||
| 		// 3. 断言输出 | ||||
| 		output := buf.String() | ||||
| 		assert.NotContains(t, output, "这条 info 日志不应被打印") | ||||
| 		assert.Contains(t, output, "这条 warn 日志应该被打印") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("JSON 格式应生效", func(t *testing.T) { | ||||
| 		// 1. 创建一个格式为 JSON 的 logger | ||||
| 		logger, buf := captureOutput(config.LogConfig{Level: "info", Format: "json"}) | ||||
|  | ||||
| 		// 2. 打印一条日志 | ||||
| 		logger.Info("测试json输出") | ||||
|  | ||||
| 		// 3. 断言输出 | ||||
| 		output := buf.String() | ||||
| 		// 验证它是否是合法的 JSON,并且包含预期的键值对 | ||||
| 		var logEntry map[string]interface{} | ||||
| 		// 注意:由于日志库可能会在行尾添加换行符,我们先 trim space | ||||
| 		err := json.Unmarshal([]byte(strings.TrimSpace(output)), &logEntry) | ||||
| 		assert.NoError(t, err, "日志输出应为合法的JSON") | ||||
| 		assert.Equal(t, "INFO", logEntry["level"]) | ||||
| 		assert.Equal(t, "测试json输出", logEntry["msg"]) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("文件日志构造函数不应 panic", func(t *testing.T) { | ||||
| 		// 这个测试保持原样,只验证构造函数在启用文件时不会崩溃 | ||||
| 		// 注意:我们不在单元测试中实际写入文件 | ||||
| 		cfgFile := config.LogConfig{ | ||||
| 			Level:      "info", | ||||
| 			EnableFile: true, | ||||
| 			FilePath:   "test.log", | ||||
| 			FilePath:   "test.log", // 在测试环境中,这个文件不会被真正创建 | ||||
| 		} | ||||
| 		assert.NotPanics(t, func() { logs.NewLogger(cfgFile) }) | ||||
| 	}) | ||||
| @@ -84,7 +104,7 @@ func TestGormLogger(t *testing.T) { | ||||
| 		return sql, rows | ||||
| 	} | ||||
|  | ||||
| 	t.Run("Slow Query", func(t *testing.T) { | ||||
| 	t.Run("慢查询应记录为警告", func(t *testing.T) { | ||||
| 		buf.Reset() | ||||
| 		// 模拟一个耗时超过 200ms 的查询 | ||||
| 		begin := time.Now().Add(-300 * time.Millisecond) | ||||
| @@ -96,7 +116,7 @@ func TestGormLogger(t *testing.T) { | ||||
| 		assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("Error Query", func(t *testing.T) { | ||||
| 	t.Run("普通错误应记录为Error", func(t *testing.T) { | ||||
| 		buf.Reset() | ||||
| 		queryError := errors.New("syntax error") | ||||
| 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | ||||
| @@ -106,8 +126,10 @@ func TestGormLogger(t *testing.T) { | ||||
| 		assert.Contains(t, output, "[GORM] error: syntax error") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("Record Not Found Error is Skipped", func(t *testing.T) { | ||||
| 	t.Run("当SkipErrRecordNotFound为true时应跳过RecordNotFound错误", func(t *testing.T) { | ||||
| 		buf.Reset() | ||||
| 		// 确保默认设置是 true | ||||
| 		gormLogger.SkipErrRecordNotFound = true | ||||
| 		// 错误必须包含 "record not found" 字符串以匹配 logs.go 中的判断逻辑 | ||||
| 		queryError := errors.New("record not found") | ||||
| 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | ||||
| @@ -115,7 +137,24 @@ func TestGormLogger(t *testing.T) { | ||||
| 		assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后,record not found 错误不应产生任何日志") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("Normal Query", func(t *testing.T) { | ||||
| 	t.Run("当SkipErrRecordNotFound为false时应记录RecordNotFound错误", func(t *testing.T) { | ||||
| 		buf.Reset() | ||||
| 		// 手动将 SkipErrRecordNotFound 设置为 false | ||||
| 		gormLogger.SkipErrRecordNotFound = false | ||||
|  | ||||
| 		queryError := errors.New("record not found") | ||||
| 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | ||||
|  | ||||
| 		// 恢复设置,避免影响其他测试 | ||||
| 		gormLogger.SkipErrRecordNotFound = true | ||||
|  | ||||
| 		output := buf.String() | ||||
| 		assert.NotEmpty(t, output, "关闭 SkipErrRecordNotFound 后,record not found 错误应该产生日志") | ||||
| 		assert.Contains(t, output, "ERROR") | ||||
| 		assert.Contains(t, output, "[GORM] error: record not found") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("正常查询应记录为Debug", func(t *testing.T) { | ||||
| 		buf.Reset() | ||||
| 		// 模拟一个快速查询 | ||||
| 		gormLogger.Trace(ctx, time.Now(), fc, nil) | ||||
|   | ||||
| @@ -58,7 +58,7 @@ type Device struct { | ||||
| 	gorm.Model | ||||
|  | ||||
| 	// Name 是设备的业务名称,应清晰可读,例如 "1号猪舍温度传感器" 或 "做料车间主控" | ||||
| 	Name string `gorm:"not null" json:"name"` | ||||
| 	Name string `gorm:"unique;not null" json:"name"` | ||||
|  | ||||
| 	// Type 是设备的高级类别,用于区分区域主控和普通设备。建立索引以优化按类型查询。 | ||||
| 	Type DeviceType `gorm:"not null;index" json:"type"` | ||||
|   | ||||
| @@ -38,7 +38,37 @@ func TestUser_CheckPassword(t *testing.T) { | ||||
| 		assert.False(t, match, "空密码应该校验失败") | ||||
| 	}) | ||||
| } | ||||
| func TestUser_BeforeCreate(t *testing.T) { | ||||
| 	t.Run("密码应被成功哈希", func(t *testing.T) { | ||||
| 		plainPassword := "securepassword123" | ||||
| 		user := &models.User{ | ||||
| 			Username: "testuser", | ||||
| 			Password: plainPassword, | ||||
| 		} | ||||
|  | ||||
| // 注意:BeforeSave 钩子是一个 GORM 框架的回调,它的正确性 | ||||
| // 将在 repository 的集成测试中,通过实际创建一个用户来得到验证, | ||||
| // 而不是在这里进行孤立的、脆弱的单元测试。 | ||||
| 		// 模拟 GORM 钩子调用 | ||||
| 		err := user.BeforeCreate(nil) // GORM 钩子通常接收 *gorm.DB,这里我们传入 nil,因为 BeforeCreate 不依赖 DB | ||||
| 		assert.NoError(t, err, "BeforeCreate 不应返回错误") | ||||
|  | ||||
| 		// 验证密码是否已被哈希(不再是明文) | ||||
| 		assert.NotEqual(t, plainPassword, user.Password, "密码应已被哈希") | ||||
|  | ||||
| 		// 验证哈希后的密码是否能被正确校验 | ||||
| 		assert.True(t, user.CheckPassword(plainPassword), "哈希后的密码应能通过校验") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("空密码不应被哈希", func(t *testing.T) { | ||||
| 		plainPassword := "" | ||||
| 		user := &models.User{ | ||||
| 			Username: "empty_pass_user", | ||||
| 			Password: plainPassword, | ||||
| 		} | ||||
|  | ||||
| 		// 模拟 GORM 钩子调用 | ||||
| 		err := user.BeforeCreate(nil) | ||||
| 		assert.NoError(t, err, "BeforeCreate 不应返回错误") | ||||
|  | ||||
| 		// 验证密码仍然是空字符串 | ||||
| 		assert.Equal(t, plainPassword, user.Password, "空密码不应被哈希") | ||||
| 	}) | ||||
| } | ||||
|   | ||||
| @@ -11,97 +11,252 @@ import ( | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| func TestGormDeviceRepository(t *testing.T) { | ||||
| // createTestDevice 辅助函数,用于创建测试设备 | ||||
| func createTestDevice(t *testing.T, db *gorm.DB, name string, deviceType models.DeviceType, parentID *uint) *models.Device { | ||||
| 	device := &models.Device{ | ||||
| 		Name:     name, | ||||
| 		Type:     deviceType, | ||||
| 		ParentID: parentID, | ||||
| 		// 其他字段可以根据需要添加 | ||||
| 	} | ||||
| 	err := db.Create(device).Error | ||||
| 	assert.NoError(t, err) | ||||
| 	return device | ||||
| } | ||||
|  | ||||
| func TestRepoCreate(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 	repo := repository.NewGormDeviceRepository(db) | ||||
|  | ||||
| 	// --- 准备测试数据 --- | ||||
| 	loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"}) | ||||
| 	busProps, _ := json.Marshal(models.BusProperties{BusID: 1, BusAddress: 10}) | ||||
|  | ||||
| 	areaController := &models.Device{ | ||||
| 		Name:       "1号猪舍主控", | ||||
| 	t.Run("成功创建区域主控", func(t *testing.T) { | ||||
| 		device := &models.Device{ | ||||
| 			Name:       "主控A", | ||||
| 			Type:       models.DeviceTypeAreaController, | ||||
| 		Location:   "1号猪舍", | ||||
| 			Location:   "猪舍1", | ||||
| 			Properties: loraProps, | ||||
| 		} | ||||
|  | ||||
| 	t.Run("创建 - 成功创建区域主控", func(t *testing.T) { | ||||
| 		err := repo.Create(areaController) | ||||
| 		err := repo.Create(device) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.NotZero(t, areaController.ID, "创建后应获得一个非零ID") | ||||
| 		assert.Nil(t, areaController.ParentID, "区域主控的 ParentID 应为 nil") | ||||
| 		assert.NotZero(t, device.ID, "创建后应获得一个非零ID") | ||||
| 		assert.Nil(t, device.ParentID, "区域主控的 ParentID 应为 nil") | ||||
| 	}) | ||||
|  | ||||
| 	var createdDevice *models.Device | ||||
| 	t.Run("通过ID查找 - 成功找到已创建的设备", func(t *testing.T) { | ||||
| 		var err error | ||||
| 		createdDevice, err = repo.FindByID(areaController.ID) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.NotNil(t, createdDevice) | ||||
| 		assert.Equal(t, areaController.Name, createdDevice.Name) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("通过字符串ID查找 - 使用有效字符串ID找到设备", func(t *testing.T) { | ||||
| 		foundDevice, err := repo.FindByIDString(strconv.FormatUint(uint64(areaController.ID), 10)) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.NotNil(t, foundDevice) | ||||
| 		assert.Equal(t, areaController.ID, foundDevice.ID) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("通过字符串ID查找 - 使用无效字符串ID", func(t *testing.T) { | ||||
| 		_, err := repo.FindByIDString("invalid-id") | ||||
| 		assert.Error(t, err, "使用无效ID字符串应返回错误") | ||||
| 	}) | ||||
|  | ||||
| 	// 创建一个子设备 | ||||
| 	childDevice := &models.Device{ | ||||
| 		Name:       "1号猪舍温度传感器", | ||||
| 	t.Run("成功创建子设备", func(t *testing.T) { | ||||
| 		parent := createTestDevice(t, db, "父设备", models.DeviceTypeAreaController, nil) | ||||
| 		child := &models.Device{ | ||||
| 			Name:     "子设备A", | ||||
| 			Type:     models.DeviceTypeDevice, | ||||
| 		SubType:    models.SubTypeSensorTemp, | ||||
| 		ParentID:   &areaController.ID, | ||||
| 		Location:   "1号猪舍东侧", | ||||
| 		Properties: busProps, | ||||
| 			ParentID: &parent.ID, | ||||
| 		} | ||||
| 		err := repo.Create(child) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.NotZero(t, child.ID) | ||||
| 		assert.NotNil(t, child.ParentID) | ||||
| 		assert.Equal(t, parent.ID, *child.ParentID) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| 	t.Run("创建 - 成功创建子设备", func(t *testing.T) { | ||||
| 		err := repo.Create(childDevice) | ||||
| func TestRepoFindByID(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 	repo := repository.NewGormDeviceRepository(db) | ||||
|  | ||||
| 	device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil) | ||||
|  | ||||
| 	t.Run("成功通过ID查找", func(t *testing.T) { | ||||
| 		foundDevice, err := repo.FindByID(device.ID) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.NotZero(t, childDevice.ID) | ||||
| 		assert.NotNil(t, childDevice.ParentID) | ||||
| 		assert.Equal(t, areaController.ID, *childDevice.ParentID) | ||||
| 		assert.NotNil(t, foundDevice) | ||||
| 		assert.Equal(t, device.ID, foundDevice.ID) | ||||
| 		assert.Equal(t, device.Name, foundDevice.Name) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("通过父ID列出 - 找到子设备", func(t *testing.T) { | ||||
| 		children, err := repo.ListByParentID(&areaController.ID) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Len(t, children, 1, "应找到一个子设备") | ||||
| 		assert.Equal(t, childDevice.ID, children[0].ID) | ||||
| 	t.Run("查找不存在的ID", func(t *testing.T) { | ||||
| 		_, err := repo.FindByID(9999) // 不存在的ID | ||||
| 		assert.Error(t, err) | ||||
| 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("通过父ID列出 - 找到顶层设备", func(t *testing.T) { | ||||
| 	t.Run("数据库查询失败", func(t *testing.T) { | ||||
| 		// 模拟数据库连接关闭,强制查询失败 | ||||
| 		sqlDB, _ := db.DB() | ||||
| 		sqlDB.Close() | ||||
|  | ||||
| 		_, err := repo.FindByID(device.ID) | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "database is closed") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestRepoFindByIDString(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 	repo := repository.NewGormDeviceRepository(db) | ||||
|  | ||||
| 	device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil) | ||||
|  | ||||
| 	t.Run("成功通过字符串ID查找", func(t *testing.T) { | ||||
| 		idStr := strconv.FormatUint(uint64(device.ID), 10) | ||||
| 		foundDevice, err := repo.FindByIDString(idStr) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.NotNil(t, foundDevice) | ||||
| 		assert.Equal(t, device.ID, foundDevice.ID) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("无效的字符串ID格式", func(t *testing.T) { | ||||
| 		_, err := repo.FindByIDString("invalid-id") | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "无效的设备ID格式") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("查找不存在的字符串ID", func(t *testing.T) { | ||||
| 		idStr := strconv.FormatUint(uint64(9999), 10) // 不存在的ID | ||||
| 		_, err := repo.FindByIDString(idStr) | ||||
| 		assert.Error(t, err) | ||||
| 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("数据库查询失败", func(t *testing.T) { | ||||
| 		sqlDB, _ := db.DB() | ||||
| 		sqlDB.Close() | ||||
|  | ||||
| 		idStr := strconv.FormatUint(uint64(device.ID), 10) | ||||
| 		_, err := repo.FindByIDString(idStr) | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "database is closed") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestRepoListAll(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 	repo := repository.NewGormDeviceRepository(db) | ||||
|  | ||||
| 	t.Run("成功获取空列表", func(t *testing.T) { | ||||
| 		devices, err := repo.ListAll() | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Empty(t, devices) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("成功获取包含设备的列表", func(t *testing.T) { | ||||
| 		createTestDevice(t, db, "设备1", models.DeviceTypeAreaController, nil) | ||||
| 		createTestDevice(t, db, "设备2", models.DeviceTypeDevice, nil) | ||||
|  | ||||
| 		devices, err := repo.ListAll() | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Len(t, devices, 2) | ||||
| 		assert.Equal(t, "设备1", devices[0].Name) | ||||
| 		assert.Equal(t, "设备2", devices[1].Name) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("数据库查询失败", func(t *testing.T) { | ||||
| 		sqlDB, _ := db.DB() | ||||
| 		sqlDB.Close() | ||||
|  | ||||
| 		_, err := repo.ListAll() | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "database is closed") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestRepoListByParentID(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 	repo := repository.NewGormDeviceRepository(db) | ||||
|  | ||||
| 	parent1 := createTestDevice(t, db, "父设备1", models.DeviceTypeAreaController, nil) | ||||
| 	parent2 := createTestDevice(t, db, "父设备2", models.DeviceTypeAreaController, nil) | ||||
| 	child1_1 := createTestDevice(t, db, "子设备1-1", models.DeviceTypeDevice, &parent1.ID) | ||||
| 	child1_2 := createTestDevice(t, db, "子设备1-2", models.DeviceTypeDevice, &parent1.ID) | ||||
| 	_ = createTestDevice(t, db, "子设备2-1", models.DeviceTypeDevice, &parent2.ID) | ||||
|  | ||||
| 	t.Run("成功通过父ID查找子设备", func(t *testing.T) { | ||||
| 		children, err := repo.ListByParentID(&parent1.ID) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Len(t, children, 2) | ||||
| 		assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[0].ID) | ||||
| 		assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[1].ID) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("成功通过nil父ID查找顶层设备", func(t *testing.T) { | ||||
| 		parents, err := repo.ListByParentID(nil) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Len(t, parents, 1, "应找到一个顶层设备") | ||||
| 		assert.Equal(t, areaController.ID, parents[0].ID) | ||||
| 		assert.Len(t, parents, 2) | ||||
| 		assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[0].ID) | ||||
| 		assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[1].ID) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("更新 - 成功更新设备信息", func(t *testing.T) { | ||||
| 		childDevice.Location = "1号猪舍西侧" | ||||
| 		err := repo.Update(childDevice) | ||||
| 	t.Run("查找不存在的父ID", func(t *testing.T) { | ||||
| 		nonExistentParentID := uint(9999) | ||||
| 		children, err := repo.ListByParentID(&nonExistentParentID) | ||||
| 		assert.NoError(t, err) // GORM 在未找到时返回空列表而不是错误 | ||||
| 		assert.Empty(t, children) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("数据库查询失败", func(t *testing.T) { | ||||
| 		sqlDB, _ := db.DB() | ||||
| 		sqlDB.Close() | ||||
|  | ||||
| 		_, err := repo.ListByParentID(&parent1.ID) | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "database is closed") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestRepoUpdate(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 	repo := repository.NewGormDeviceRepository(db) | ||||
|  | ||||
| 	device := createTestDevice(t, db, "原始设备", models.DeviceTypeAreaController, nil) | ||||
|  | ||||
| 	t.Run("成功更新设备信息", func(t *testing.T) { | ||||
| 		device.Name = "更新后的设备" | ||||
| 		device.Location = "新地点" | ||||
| 		err := repo.Update(device) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		updatedDevice, _ := repo.FindByID(childDevice.ID) | ||||
| 		assert.Equal(t, "1号猪舍西侧", updatedDevice.Location) | ||||
| 		updatedDevice, err := repo.FindByID(device.ID) | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Equal(t, "更新后的设备", updatedDevice.Name) | ||||
| 		assert.Equal(t, "新地点", updatedDevice.Location) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("删除 - 成功删除设备", func(t *testing.T) { | ||||
| 		err := repo.Delete(childDevice.ID) | ||||
| 	t.Run("数据库更新失败", func(t *testing.T) { | ||||
| 		sqlDB, _ := db.DB() | ||||
| 		sqlDB.Close() | ||||
|  | ||||
| 		device.Name = "更新失败的设备" | ||||
| 		err := repo.Update(device) | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "database is closed") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestRepoDelete(t *testing.T) { | ||||
| 	db := setupTestDB(t) | ||||
| 	repo := repository.NewGormDeviceRepository(db) | ||||
|  | ||||
| 	device := createTestDevice(t, db, "待删除设备", models.DeviceTypeAreaController, nil) | ||||
|  | ||||
| 	t.Run("成功删除设备", func(t *testing.T) { | ||||
| 		err := repo.Delete(device.ID) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 验证设备已被软删除 | ||||
| 		_, err = repo.FindByID(childDevice.ID) | ||||
| 		_, err = repo.FindByID(device.ID) | ||||
| 		assert.Error(t, err, "删除后应无法找到设备") | ||||
| 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("删除不存在的设备", func(t *testing.T) { | ||||
| 		err := repo.Delete(9999) // 不存在的ID | ||||
| 		assert.NoError(t, err)   // GORM 的 Delete 方法在删除不存在的记录时不会报错 | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("数据库删除失败", func(t *testing.T) { | ||||
| 		sqlDB, _ := db.DB() | ||||
| 		sqlDB.Close() | ||||
|  | ||||
| 		err := repo.Delete(device.ID) | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "database is closed") | ||||
| 	}) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user