重构 #4
							
								
								
									
										8
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								Makefile
									
									
									
									
									
								
							| @@ -9,6 +9,7 @@ help: | |||||||
| 	@echo "  run          Run the application" | 	@echo "  run          Run the application" | ||||||
| 	@echo "  build        Build the application" | 	@echo "  build        Build the application" | ||||||
| 	@echo "  clean        Clean generated files" | 	@echo "  clean        Clean generated files" | ||||||
|  | 	@echo "  test         Run all tests" | ||||||
| 	@echo "  help         Show this help message" | 	@echo "  help         Show this help message" | ||||||
|  |  | ||||||
| # 运行应用 | # 运行应用 | ||||||
| @@ -24,4 +25,9 @@ build: | |||||||
| # 清理生成文件 | # 清理生成文件 | ||||||
| .PHONY: clean | .PHONY: clean | ||||||
| clean: | clean: | ||||||
| 	rm -f bin/pig-farm-controller | 	rm -f bin/pig-farm-controller | ||||||
|  |  | ||||||
|  | # 运行所有测试 | ||||||
|  | .PHONY: test | ||||||
|  | test: | ||||||
|  | 	go test --count=1 ./... | ||||||
| @@ -1,6 +1,7 @@ | |||||||
| package token | package token | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/golang-jwt/jwt/v5" | 	"github.com/golang-jwt/jwt/v5" | ||||||
| @@ -52,9 +53,16 @@ func (s *tokenService) ParseToken(tokenString string) (*Claims, error) { | |||||||
| 		return s.secret, nil | 		return s.secret, nil | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	// 优先检查解析过程中是否发生错误 | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// 只有当 token 对象有效时,才尝试获取 Claims 并验证 | ||||||
| 	if claims, ok := token.Claims.(*Claims); ok && token.Valid { | 	if claims, ok := token.Claims.(*Claims); ok && token.Valid { | ||||||
| 		return claims, nil | 		return claims, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return nil, err | 	// 如果 token 无效(例如,过期但没有返回错误,或者 Claims 类型不匹配),则返回一个通用错误 | ||||||
|  | 	return nil, fmt.Errorf("token is invalid") | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package token_test | package token_test | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| @@ -45,36 +46,38 @@ func TestGenerateToken(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestParseToken(t *testing.T) { | func TestParseToken(t *testing.T) { | ||||||
| 	// 使用一个测试密钥初始化 TokenService | 	// 使用两个不同的测试密钥 | ||||||
| 	testSecret := []byte("test_secret_key") | 	correctSecret := []byte("the_correct_secret") | ||||||
| 	service := token.NewTokenService(testSecret) | 	wrongSecret := []byte("a_very_wrong_secret") | ||||||
|  |  | ||||||
|  | 	serviceWithCorrectKey := token.NewTokenService(correctSecret) | ||||||
|  | 	serviceWithWrongKey := token.NewTokenService(wrongSecret) | ||||||
|  |  | ||||||
| 	userID := uint(456) | 	userID := uint(456) | ||||||
|  |  | ||||||
| 	// 生成一个有效的 token 用于解析测试 | 	// 1. 生成一个有效的 token | ||||||
| 	validToken, err := service.GenerateToken(userID) | 	validToken, err := serviceWithCorrectKey.GenerateToken(userID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("为解析测试生成有效令牌失败: %v", err) | 		t.Fatalf("为解析测试生成有效令牌失败: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// 测试用例 1: 有效 token | 	// 测试用例 1: 使用正确的密钥成功解析 | ||||||
| 	claims, err := service.ParseToken(validToken) | 	claims, err := serviceWithCorrectKey.ParseToken(validToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Errorf("解析有效令牌失败: %v", err) | 		t.Errorf("使用正确密钥解析有效令牌失败: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if claims.UserID != userID { | 	if claims.UserID != userID { | ||||||
| 		t.Errorf("解析有效令牌时期望用户ID %d, 实际为 %d", userID, claims.UserID) | 		t.Errorf("解析有效令牌时期望用户ID %d, 实际为 %d", userID, claims.UserID) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// 测试用例 2: 无效 token (例如, 错误的签名) | 	// 测试用例 2: 无效 token (例如, 格式错误的字符串) | ||||||
| 	invalidToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoxMjMsImV4cCI6MTY3ODkwNTYwMCwiaXNzIjoicGlnLWZhcm0tY29udHJvbGxlciJ9.invalid_signature_here" | 	invalidTokenString := "this.is.not.a.valid.jwt" | ||||||
| 	_, err = service.ParseToken(invalidToken) | 	_, err = serviceWithCorrectKey.ParseToken(invalidTokenString) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		t.Error("解析无效令牌意外成功") | 		t.Error("解析格式错误的令牌意外成功") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// 测试用例 3: 过期 token (创建一个过期时间在过去的 token) | 	// 测试用C:\Users\divano\Desktop\work\AA-Pig\pig-farm-controller\internal\infra\repository\plan_repository_test.go例 3: 过期 token | ||||||
| 	expiredClaims := token.Claims{ | 	expiredClaims := token.Claims{ | ||||||
| 		UserID: userID, | 		UserID: userID, | ||||||
| 		RegisteredClaims: jwt.RegisteredClaims{ | 		RegisteredClaims: jwt.RegisteredClaims{ | ||||||
| @@ -83,13 +86,22 @@ func TestParseToken(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	expiredTokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, expiredClaims) | 	expiredTokenClaims := jwt.NewWithClaims(jwt.SigningMethodHS256, expiredClaims) | ||||||
| 	expiredTokenString, err := expiredTokenClaims.SignedString(testSecret) | 	expiredTokenString, err := expiredTokenClaims.SignedString(correctSecret) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatalf("生成过期令牌失败: %v", err) | 		t.Fatalf("生成过期令牌失败: %v", err) | ||||||
| 	} | 	} | ||||||
|  | 	_, err = serviceWithCorrectKey.ParseToken(expiredTokenString) | ||||||
| 	_, err = service.ParseToken(expiredTokenString) |  | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		t.Error("解析过期令牌意外成功") | 		t.Error("解析过期令牌意外成功") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// 新增测试用例 4: 使用错误的密钥解析 | ||||||
|  | 	_, err = serviceWithWrongKey.ParseToken(validToken) | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Error("使用错误密钥解析令牌意外成功") | ||||||
|  | 	} | ||||||
|  | 	// 我们可以更精确地检查错误类型,以确保它是签名错误 | ||||||
|  | 	if !errors.Is(err, jwt.ErrTokenSignatureInvalid) { | ||||||
|  | 		t.Errorf("期望得到签名无效错误 (ErrTokenSignatureInvalid),但得到了: %v", err) | ||||||
|  | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -127,17 +127,3 @@ func (c *Config) Load(path string) error { | |||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // GetDatabaseConnectionString 获取数据库连接字符串 |  | ||||||
| func (c *Config) GetDatabaseConnectionString() string { |  | ||||||
| 	// 构建PostgreSQL连接字符串 |  | ||||||
| 	return fmt.Sprintf( |  | ||||||
| 		"user=%s password=%s dbname=%s host=%s port=%d sslmode=%s", |  | ||||||
| 		c.Database.Username, |  | ||||||
| 		c.Database.Password, |  | ||||||
| 		c.Database.DBName, |  | ||||||
| 		c.Database.Host, |  | ||||||
| 		c.Database.Port, |  | ||||||
| 		c.Database.SSLMode, |  | ||||||
| 	) |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -27,7 +27,7 @@ type Logger struct { | |||||||
| // 这是实现依赖注入的关键,在应用启动时调用一次。 | // 这是实现依赖注入的关键,在应用启动时调用一次。 | ||||||
| func NewLogger(cfg config.LogConfig) *Logger { | func NewLogger(cfg config.LogConfig) *Logger { | ||||||
| 	// 1. 设置日志编码器 | 	// 1. 设置日志编码器 | ||||||
| 	encoder := getEncoder(cfg.Format) | 	encoder := GetEncoder(cfg.Format) | ||||||
|  |  | ||||||
| 	// 2. 设置日志写入器 (支持文件和控制台) | 	// 2. 设置日志写入器 (支持文件和控制台) | ||||||
| 	writeSyncer := getWriteSyncer(cfg) | 	writeSyncer := getWriteSyncer(cfg) | ||||||
| @@ -49,8 +49,8 @@ func NewLogger(cfg config.LogConfig) *Logger { | |||||||
| 	return &Logger{zapLogger.Sugar()} | 	return &Logger{zapLogger.Sugar()} | ||||||
| } | } | ||||||
|  |  | ||||||
| // getEncoder 根据指定的格式返回一个 zapcore.Encoder。 | // GetEncoder 根据指定的格式返回一个 zapcore.Encoder。 | ||||||
| func getEncoder(format string) zapcore.Encoder { | func GetEncoder(format string) zapcore.Encoder { | ||||||
| 	encoderConfig := zap.NewProductionEncoderConfig() | 	encoderConfig := zap.NewProductionEncoderConfig() | ||||||
| 	encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder   // 时间格式: 2006-01-02T15:04:05.000Z0700 | 	encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder   // 时间格式: 2006-01-02T15:04:05.000Z0700 | ||||||
| 	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO | 	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder // 日志级别大写: INFO | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ package logs_test | |||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
| @@ -18,12 +19,8 @@ import ( | |||||||
| // captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区 | // captureOutput 是一个辅助函数,用于捕获 logger 的输出到内存缓冲区 | ||||||
| func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { | func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { | ||||||
| 	var buf bytes.Buffer | 	var buf bytes.Buffer | ||||||
| 	// 使用一个简单的 Console Encoder 进行测试,方便断言字符串 |  | ||||||
| 	encoderConfig := zap.NewDevelopmentEncoderConfig() | 	encoder := logs.GetEncoder(cfg.Format) | ||||||
| 	encoderConfig.EncodeTime = nil // 忽略时间,避免测试结果不一致 |  | ||||||
| 	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder |  | ||||||
| 	encoderConfig.EncodeCaller = nil // 忽略调用者信息 |  | ||||||
| 	encoder := zapcore.NewConsoleEncoder(encoderConfig) |  | ||||||
|  |  | ||||||
| 	writer := zapcore.AddSync(&buf) | 	writer := zapcore.AddSync(&buf) | ||||||
|  |  | ||||||
| @@ -31,29 +28,52 @@ func captureOutput(cfg config.LogConfig) (*logs.Logger, *bytes.Buffer) { | |||||||
| 	_ = level.UnmarshalText([]byte(cfg.Level)) | 	_ = level.UnmarshalText([]byte(cfg.Level)) | ||||||
|  |  | ||||||
| 	core := zapcore.NewCore(encoder, writer, level) | 	core := zapcore.NewCore(encoder, writer, level) | ||||||
| 	// 在测试中我们直接操作 zap Logger,而不是通过封装的 NewLogger,以注入内存 writer | 	// 匹配 logs.go 中 NewLogger 的行为,添加调用者信息 | ||||||
| 	zapLogger := zap.New(core) | 	zapLogger := zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)) | ||||||
|  |  | ||||||
| 	logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()} | 	logger := &logs.Logger{SugaredLogger: zapLogger.Sugar()} | ||||||
| 	return logger, &buf | 	return logger, &buf | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestNewLogger(t *testing.T) { | func TestNewLogger(t *testing.T) { | ||||||
| 	t.Run("构造函数不会 panic", func(t *testing.T) { | 	t.Run("日志级别应生效", func(t *testing.T) { | ||||||
| 		// 测试 Console 格式 | 		// 1. 创建一个级别为 WARN 的 logger | ||||||
| 		cfgConsole := config.LogConfig{Level: "info", Format: "console"} | 		logger, buf := captureOutput(config.LogConfig{Level: "warn", Format: "console"}) | ||||||
| 		assert.NotPanics(t, func() { logs.NewLogger(cfgConsole) }) |  | ||||||
|  |  | ||||||
| 		// 测试 JSON 格式 | 		// 2. 调用不同级别的日志方法 | ||||||
| 		cfgJSON := config.LogConfig{Level: "info", Format: "json"} | 		logger.Info("这条 info 日志不应被打印") | ||||||
| 		assert.NotPanics(t, func() { logs.NewLogger(cfgJSON) }) | 		logger.Warn("这条 warn 日志应该被打印") | ||||||
|  |  | ||||||
| 		// 测试文件日志启用 | 		// 3. 断言输出 | ||||||
| 		// 不实际写入文件,只确保构造函数能正常运行 | 		output := buf.String() | ||||||
|  | 		assert.NotContains(t, output, "这条 info 日志不应被打印") | ||||||
|  | 		assert.Contains(t, output, "这条 warn 日志应该被打印") | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("JSON 格式应生效", func(t *testing.T) { | ||||||
|  | 		// 1. 创建一个格式为 JSON 的 logger | ||||||
|  | 		logger, buf := captureOutput(config.LogConfig{Level: "info", Format: "json"}) | ||||||
|  |  | ||||||
|  | 		// 2. 打印一条日志 | ||||||
|  | 		logger.Info("测试json输出") | ||||||
|  |  | ||||||
|  | 		// 3. 断言输出 | ||||||
|  | 		output := buf.String() | ||||||
|  | 		// 验证它是否是合法的 JSON,并且包含预期的键值对 | ||||||
|  | 		var logEntry map[string]interface{} | ||||||
|  | 		// 注意:由于日志库可能会在行尾添加换行符,我们先 trim space | ||||||
|  | 		err := json.Unmarshal([]byte(strings.TrimSpace(output)), &logEntry) | ||||||
|  | 		assert.NoError(t, err, "日志输出应为合法的JSON") | ||||||
|  | 		assert.Equal(t, "INFO", logEntry["level"]) | ||||||
|  | 		assert.Equal(t, "测试json输出", logEntry["msg"]) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("文件日志构造函数不应 panic", func(t *testing.T) { | ||||||
|  | 		// 这个测试保持原样,只验证构造函数在启用文件时不会崩溃 | ||||||
|  | 		// 注意:我们不在单元测试中实际写入文件 | ||||||
| 		cfgFile := config.LogConfig{ | 		cfgFile := config.LogConfig{ | ||||||
| 			Level:      "info", | 			Level:      "info", | ||||||
| 			EnableFile: true, | 			EnableFile: true, | ||||||
| 			FilePath:   "test.log", | 			FilePath:   "test.log", // 在测试环境中,这个文件不会被真正创建 | ||||||
| 		} | 		} | ||||||
| 		assert.NotPanics(t, func() { logs.NewLogger(cfgFile) }) | 		assert.NotPanics(t, func() { logs.NewLogger(cfgFile) }) | ||||||
| 	}) | 	}) | ||||||
| @@ -84,7 +104,7 @@ func TestGormLogger(t *testing.T) { | |||||||
| 		return sql, rows | 		return sql, rows | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	t.Run("Slow Query", func(t *testing.T) { | 	t.Run("慢查询应记录为警告", func(t *testing.T) { | ||||||
| 		buf.Reset() | 		buf.Reset() | ||||||
| 		// 模拟一个耗时超过 200ms 的查询 | 		// 模拟一个耗时超过 200ms 的查询 | ||||||
| 		begin := time.Now().Add(-300 * time.Millisecond) | 		begin := time.Now().Add(-300 * time.Millisecond) | ||||||
| @@ -96,7 +116,7 @@ func TestGormLogger(t *testing.T) { | |||||||
| 		assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句") | 		assert.Contains(t, output, "SELECT * FROM users WHERE id = 1", "应包含 SQL 语句") | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("Error Query", func(t *testing.T) { | 	t.Run("普通错误应记录为Error", func(t *testing.T) { | ||||||
| 		buf.Reset() | 		buf.Reset() | ||||||
| 		queryError := errors.New("syntax error") | 		queryError := errors.New("syntax error") | ||||||
| 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | ||||||
| @@ -106,8 +126,10 @@ func TestGormLogger(t *testing.T) { | |||||||
| 		assert.Contains(t, output, "[GORM] error: syntax error") | 		assert.Contains(t, output, "[GORM] error: syntax error") | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("Record Not Found Error is Skipped", func(t *testing.T) { | 	t.Run("当SkipErrRecordNotFound为true时应跳过RecordNotFound错误", func(t *testing.T) { | ||||||
| 		buf.Reset() | 		buf.Reset() | ||||||
|  | 		// 确保默认设置是 true | ||||||
|  | 		gormLogger.SkipErrRecordNotFound = true | ||||||
| 		// 错误必须包含 "record not found" 字符串以匹配 logs.go 中的判断逻辑 | 		// 错误必须包含 "record not found" 字符串以匹配 logs.go 中的判断逻辑 | ||||||
| 		queryError := errors.New("record not found") | 		queryError := errors.New("record not found") | ||||||
| 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | ||||||
| @@ -115,7 +137,24 @@ func TestGormLogger(t *testing.T) { | |||||||
| 		assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后,record not found 错误不应产生任何日志") | 		assert.Empty(t, buf.String(), "开启 SkipErrRecordNotFound 后,record not found 错误不应产生任何日志") | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("Normal Query", func(t *testing.T) { | 	t.Run("当SkipErrRecordNotFound为false时应记录RecordNotFound错误", func(t *testing.T) { | ||||||
|  | 		buf.Reset() | ||||||
|  | 		// 手动将 SkipErrRecordNotFound 设置为 false | ||||||
|  | 		gormLogger.SkipErrRecordNotFound = false | ||||||
|  |  | ||||||
|  | 		queryError := errors.New("record not found") | ||||||
|  | 		gormLogger.Trace(ctx, time.Now(), fc, queryError) | ||||||
|  |  | ||||||
|  | 		// 恢复设置,避免影响其他测试 | ||||||
|  | 		gormLogger.SkipErrRecordNotFound = true | ||||||
|  |  | ||||||
|  | 		output := buf.String() | ||||||
|  | 		assert.NotEmpty(t, output, "关闭 SkipErrRecordNotFound 后,record not found 错误应该产生日志") | ||||||
|  | 		assert.Contains(t, output, "ERROR") | ||||||
|  | 		assert.Contains(t, output, "[GORM] error: record not found") | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("正常查询应记录为Debug", func(t *testing.T) { | ||||||
| 		buf.Reset() | 		buf.Reset() | ||||||
| 		// 模拟一个快速查询 | 		// 模拟一个快速查询 | ||||||
| 		gormLogger.Trace(ctx, time.Now(), fc, nil) | 		gormLogger.Trace(ctx, time.Now(), fc, nil) | ||||||
|   | |||||||
| @@ -58,7 +58,7 @@ type Device struct { | |||||||
| 	gorm.Model | 	gorm.Model | ||||||
|  |  | ||||||
| 	// Name 是设备的业务名称,应清晰可读,例如 "1号猪舍温度传感器" 或 "做料车间主控" | 	// Name 是设备的业务名称,应清晰可读,例如 "1号猪舍温度传感器" 或 "做料车间主控" | ||||||
| 	Name string `gorm:"not null" json:"name"` | 	Name string `gorm:"unique;not null" json:"name"` | ||||||
|  |  | ||||||
| 	// Type 是设备的高级类别,用于区分区域主控和普通设备。建立索引以优化按类型查询。 | 	// Type 是设备的高级类别,用于区分区域主控和普通设备。建立索引以优化按类型查询。 | ||||||
| 	Type DeviceType `gorm:"not null;index" json:"type"` | 	Type DeviceType `gorm:"not null;index" json:"type"` | ||||||
|   | |||||||
| @@ -38,7 +38,37 @@ func TestUser_CheckPassword(t *testing.T) { | |||||||
| 		assert.False(t, match, "空密码应该校验失败") | 		assert.False(t, match, "空密码应该校验失败") | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  | func TestUser_BeforeCreate(t *testing.T) { | ||||||
|  | 	t.Run("密码应被成功哈希", func(t *testing.T) { | ||||||
|  | 		plainPassword := "securepassword123" | ||||||
|  | 		user := &models.User{ | ||||||
|  | 			Username: "testuser", | ||||||
|  | 			Password: plainPassword, | ||||||
|  | 		} | ||||||
|  |  | ||||||
| // 注意:BeforeSave 钩子是一个 GORM 框架的回调,它的正确性 | 		// 模拟 GORM 钩子调用 | ||||||
| // 将在 repository 的集成测试中,通过实际创建一个用户来得到验证, | 		err := user.BeforeCreate(nil) // GORM 钩子通常接收 *gorm.DB,这里我们传入 nil,因为 BeforeCreate 不依赖 DB | ||||||
| // 而不是在这里进行孤立的、脆弱的单元测试。 | 		assert.NoError(t, err, "BeforeCreate 不应返回错误") | ||||||
|  |  | ||||||
|  | 		// 验证密码是否已被哈希(不再是明文) | ||||||
|  | 		assert.NotEqual(t, plainPassword, user.Password, "密码应已被哈希") | ||||||
|  |  | ||||||
|  | 		// 验证哈希后的密码是否能被正确校验 | ||||||
|  | 		assert.True(t, user.CheckPassword(plainPassword), "哈希后的密码应能通过校验") | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("空密码不应被哈希", func(t *testing.T) { | ||||||
|  | 		plainPassword := "" | ||||||
|  | 		user := &models.User{ | ||||||
|  | 			Username: "empty_pass_user", | ||||||
|  | 			Password: plainPassword, | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// 模拟 GORM 钩子调用 | ||||||
|  | 		err := user.BeforeCreate(nil) | ||||||
|  | 		assert.NoError(t, err, "BeforeCreate 不应返回错误") | ||||||
|  |  | ||||||
|  | 		// 验证密码仍然是空字符串 | ||||||
|  | 		assert.Equal(t, plainPassword, user.Password, "空密码不应被哈希") | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|   | |||||||
| @@ -11,97 +11,252 @@ import ( | |||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestGormDeviceRepository(t *testing.T) { | // createTestDevice 辅助函数,用于创建测试设备 | ||||||
|  | func createTestDevice(t *testing.T, db *gorm.DB, name string, deviceType models.DeviceType, parentID *uint) *models.Device { | ||||||
|  | 	device := &models.Device{ | ||||||
|  | 		Name:     name, | ||||||
|  | 		Type:     deviceType, | ||||||
|  | 		ParentID: parentID, | ||||||
|  | 		// 其他字段可以根据需要添加 | ||||||
|  | 	} | ||||||
|  | 	err := db.Create(device).Error | ||||||
|  | 	assert.NoError(t, err) | ||||||
|  | 	return device | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRepoCreate(t *testing.T) { | ||||||
| 	db := setupTestDB(t) | 	db := setupTestDB(t) | ||||||
| 	repo := repository.NewGormDeviceRepository(db) | 	repo := repository.NewGormDeviceRepository(db) | ||||||
|  |  | ||||||
| 	// --- 准备测试数据 --- |  | ||||||
| 	loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"}) | 	loraProps, _ := json.Marshal(models.LoraProperties{LoraAddress: "0xABCD"}) | ||||||
| 	busProps, _ := json.Marshal(models.BusProperties{BusID: 1, BusAddress: 10}) |  | ||||||
|  |  | ||||||
| 	areaController := &models.Device{ | 	t.Run("成功创建区域主控", func(t *testing.T) { | ||||||
| 		Name:       "1号猪舍主控", | 		device := &models.Device{ | ||||||
| 		Type:       models.DeviceTypeAreaController, | 			Name:       "主控A", | ||||||
| 		Location:   "1号猪舍", | 			Type:       models.DeviceTypeAreaController, | ||||||
| 		Properties: loraProps, | 			Location:   "猪舍1", | ||||||
| 	} | 			Properties: loraProps, | ||||||
|  | 		} | ||||||
| 	t.Run("创建 - 成功创建区域主控", func(t *testing.T) { | 		err := repo.Create(device) | ||||||
| 		err := repo.Create(areaController) |  | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 		assert.NotZero(t, areaController.ID, "创建后应获得一个非零ID") | 		assert.NotZero(t, device.ID, "创建后应获得一个非零ID") | ||||||
| 		assert.Nil(t, areaController.ParentID, "区域主控的 ParentID 应为 nil") | 		assert.Nil(t, device.ParentID, "区域主控的 ParentID 应为 nil") | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	var createdDevice *models.Device | 	t.Run("成功创建子设备", func(t *testing.T) { | ||||||
| 	t.Run("通过ID查找 - 成功找到已创建的设备", func(t *testing.T) { | 		parent := createTestDevice(t, db, "父设备", models.DeviceTypeAreaController, nil) | ||||||
| 		var err error | 		child := &models.Device{ | ||||||
| 		createdDevice, err = repo.FindByID(areaController.ID) | 			Name:     "子设备A", | ||||||
|  | 			Type:     models.DeviceTypeDevice, | ||||||
|  | 			ParentID: &parent.ID, | ||||||
|  | 		} | ||||||
|  | 		err := repo.Create(child) | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 		assert.NotNil(t, createdDevice) | 		assert.NotZero(t, child.ID) | ||||||
| 		assert.Equal(t, areaController.Name, createdDevice.Name) | 		assert.NotNil(t, child.ParentID) | ||||||
|  | 		assert.Equal(t, parent.ID, *child.ParentID) | ||||||
| 	}) | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
| 	t.Run("通过字符串ID查找 - 使用有效字符串ID找到设备", func(t *testing.T) { | func TestRepoFindByID(t *testing.T) { | ||||||
| 		foundDevice, err := repo.FindByIDString(strconv.FormatUint(uint64(areaController.ID), 10)) | 	db := setupTestDB(t) | ||||||
|  | 	repo := repository.NewGormDeviceRepository(db) | ||||||
|  |  | ||||||
|  | 	device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil) | ||||||
|  |  | ||||||
|  | 	t.Run("成功通过ID查找", func(t *testing.T) { | ||||||
|  | 		foundDevice, err := repo.FindByID(device.ID) | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 		assert.NotNil(t, foundDevice) | 		assert.NotNil(t, foundDevice) | ||||||
| 		assert.Equal(t, areaController.ID, foundDevice.ID) | 		assert.Equal(t, device.ID, foundDevice.ID) | ||||||
|  | 		assert.Equal(t, device.Name, foundDevice.Name) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("通过字符串ID查找 - 使用无效字符串ID", func(t *testing.T) { | 	t.Run("查找不存在的ID", func(t *testing.T) { | ||||||
|  | 		_, err := repo.FindByID(9999) // 不存在的ID | ||||||
|  | 		assert.Error(t, err) | ||||||
|  | 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("数据库查询失败", func(t *testing.T) { | ||||||
|  | 		// 模拟数据库连接关闭,强制查询失败 | ||||||
|  | 		sqlDB, _ := db.DB() | ||||||
|  | 		sqlDB.Close() | ||||||
|  |  | ||||||
|  | 		_, err := repo.FindByID(device.ID) | ||||||
|  | 		assert.Error(t, err) | ||||||
|  | 		assert.Contains(t, err.Error(), "database is closed") | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRepoFindByIDString(t *testing.T) { | ||||||
|  | 	db := setupTestDB(t) | ||||||
|  | 	repo := repository.NewGormDeviceRepository(db) | ||||||
|  |  | ||||||
|  | 	device := createTestDevice(t, db, "测试设备", models.DeviceTypeAreaController, nil) | ||||||
|  |  | ||||||
|  | 	t.Run("成功通过字符串ID查找", func(t *testing.T) { | ||||||
|  | 		idStr := strconv.FormatUint(uint64(device.ID), 10) | ||||||
|  | 		foundDevice, err := repo.FindByIDString(idStr) | ||||||
|  | 		assert.NoError(t, err) | ||||||
|  | 		assert.NotNil(t, foundDevice) | ||||||
|  | 		assert.Equal(t, device.ID, foundDevice.ID) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("无效的字符串ID格式", func(t *testing.T) { | ||||||
| 		_, err := repo.FindByIDString("invalid-id") | 		_, err := repo.FindByIDString("invalid-id") | ||||||
| 		assert.Error(t, err, "使用无效ID字符串应返回错误") | 		assert.Error(t, err) | ||||||
|  | 		assert.Contains(t, err.Error(), "无效的设备ID格式") | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	// 创建一个子设备 | 	t.Run("查找不存在的字符串ID", func(t *testing.T) { | ||||||
| 	childDevice := &models.Device{ | 		idStr := strconv.FormatUint(uint64(9999), 10) // 不存在的ID | ||||||
| 		Name:       "1号猪舍温度传感器", | 		_, err := repo.FindByIDString(idStr) | ||||||
| 		Type:       models.DeviceTypeDevice, | 		assert.Error(t, err) | ||||||
| 		SubType:    models.SubTypeSensorTemp, | 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound) | ||||||
| 		ParentID:   &areaController.ID, | 	}) | ||||||
| 		Location:   "1号猪舍东侧", |  | ||||||
| 		Properties: busProps, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	t.Run("创建 - 成功创建子设备", func(t *testing.T) { | 	t.Run("数据库查询失败", func(t *testing.T) { | ||||||
| 		err := repo.Create(childDevice) | 		sqlDB, _ := db.DB() | ||||||
|  | 		sqlDB.Close() | ||||||
|  |  | ||||||
|  | 		idStr := strconv.FormatUint(uint64(device.ID), 10) | ||||||
|  | 		_, err := repo.FindByIDString(idStr) | ||||||
|  | 		assert.Error(t, err) | ||||||
|  | 		assert.Contains(t, err.Error(), "database is closed") | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRepoListAll(t *testing.T) { | ||||||
|  | 	db := setupTestDB(t) | ||||||
|  | 	repo := repository.NewGormDeviceRepository(db) | ||||||
|  |  | ||||||
|  | 	t.Run("成功获取空列表", func(t *testing.T) { | ||||||
|  | 		devices, err := repo.ListAll() | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 		assert.NotZero(t, childDevice.ID) | 		assert.Empty(t, devices) | ||||||
| 		assert.NotNil(t, childDevice.ParentID) |  | ||||||
| 		assert.Equal(t, areaController.ID, *childDevice.ParentID) |  | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("通过父ID列出 - 找到子设备", func(t *testing.T) { | 	t.Run("成功获取包含设备的列表", func(t *testing.T) { | ||||||
| 		children, err := repo.ListByParentID(&areaController.ID) | 		createTestDevice(t, db, "设备1", models.DeviceTypeAreaController, nil) | ||||||
|  | 		createTestDevice(t, db, "设备2", models.DeviceTypeDevice, nil) | ||||||
|  |  | ||||||
|  | 		devices, err := repo.ListAll() | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 		assert.Len(t, children, 1, "应找到一个子设备") | 		assert.Len(t, devices, 2) | ||||||
| 		assert.Equal(t, childDevice.ID, children[0].ID) | 		assert.Equal(t, "设备1", devices[0].Name) | ||||||
|  | 		assert.Equal(t, "设备2", devices[1].Name) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("通过父ID列出 - 找到顶层设备", func(t *testing.T) { | 	t.Run("数据库查询失败", func(t *testing.T) { | ||||||
|  | 		sqlDB, _ := db.DB() | ||||||
|  | 		sqlDB.Close() | ||||||
|  |  | ||||||
|  | 		_, err := repo.ListAll() | ||||||
|  | 		assert.Error(t, err) | ||||||
|  | 		assert.Contains(t, err.Error(), "database is closed") | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRepoListByParentID(t *testing.T) { | ||||||
|  | 	db := setupTestDB(t) | ||||||
|  | 	repo := repository.NewGormDeviceRepository(db) | ||||||
|  |  | ||||||
|  | 	parent1 := createTestDevice(t, db, "父设备1", models.DeviceTypeAreaController, nil) | ||||||
|  | 	parent2 := createTestDevice(t, db, "父设备2", models.DeviceTypeAreaController, nil) | ||||||
|  | 	child1_1 := createTestDevice(t, db, "子设备1-1", models.DeviceTypeDevice, &parent1.ID) | ||||||
|  | 	child1_2 := createTestDevice(t, db, "子设备1-2", models.DeviceTypeDevice, &parent1.ID) | ||||||
|  | 	_ = createTestDevice(t, db, "子设备2-1", models.DeviceTypeDevice, &parent2.ID) | ||||||
|  |  | ||||||
|  | 	t.Run("成功通过父ID查找子设备", func(t *testing.T) { | ||||||
|  | 		children, err := repo.ListByParentID(&parent1.ID) | ||||||
|  | 		assert.NoError(t, err) | ||||||
|  | 		assert.Len(t, children, 2) | ||||||
|  | 		assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[0].ID) | ||||||
|  | 		assert.Contains(t, []uint{child1_1.ID, child1_2.ID}, children[1].ID) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("成功通过nil父ID查找顶层设备", func(t *testing.T) { | ||||||
| 		parents, err := repo.ListByParentID(nil) | 		parents, err := repo.ListByParentID(nil) | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
| 		assert.Len(t, parents, 1, "应找到一个顶层设备") | 		assert.Len(t, parents, 2) | ||||||
| 		assert.Equal(t, areaController.ID, parents[0].ID) | 		assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[0].ID) | ||||||
|  | 		assert.Contains(t, []uint{parent1.ID, parent2.ID}, parents[1].ID) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("更新 - 成功更新设备信息", func(t *testing.T) { | 	t.Run("查找不存在的父ID", func(t *testing.T) { | ||||||
| 		childDevice.Location = "1号猪舍西侧" | 		nonExistentParentID := uint(9999) | ||||||
| 		err := repo.Update(childDevice) | 		children, err := repo.ListByParentID(&nonExistentParentID) | ||||||
|  | 		assert.NoError(t, err) // GORM 在未找到时返回空列表而不是错误 | ||||||
|  | 		assert.Empty(t, children) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("数据库查询失败", func(t *testing.T) { | ||||||
|  | 		sqlDB, _ := db.DB() | ||||||
|  | 		sqlDB.Close() | ||||||
|  |  | ||||||
|  | 		_, err := repo.ListByParentID(&parent1.ID) | ||||||
|  | 		assert.Error(t, err) | ||||||
|  | 		assert.Contains(t, err.Error(), "database is closed") | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRepoUpdate(t *testing.T) { | ||||||
|  | 	db := setupTestDB(t) | ||||||
|  | 	repo := repository.NewGormDeviceRepository(db) | ||||||
|  |  | ||||||
|  | 	device := createTestDevice(t, db, "原始设备", models.DeviceTypeAreaController, nil) | ||||||
|  |  | ||||||
|  | 	t.Run("成功更新设备信息", func(t *testing.T) { | ||||||
|  | 		device.Name = "更新后的设备" | ||||||
|  | 		device.Location = "新地点" | ||||||
|  | 		err := repo.Update(device) | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
|  |  | ||||||
| 		updatedDevice, _ := repo.FindByID(childDevice.ID) | 		updatedDevice, err := repo.FindByID(device.ID) | ||||||
| 		assert.Equal(t, "1号猪舍西侧", updatedDevice.Location) | 		assert.NoError(t, err) | ||||||
|  | 		assert.Equal(t, "更新后的设备", updatedDevice.Name) | ||||||
|  | 		assert.Equal(t, "新地点", updatedDevice.Location) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	t.Run("删除 - 成功删除设备", func(t *testing.T) { | 	t.Run("数据库更新失败", func(t *testing.T) { | ||||||
| 		err := repo.Delete(childDevice.ID) | 		sqlDB, _ := db.DB() | ||||||
|  | 		sqlDB.Close() | ||||||
|  |  | ||||||
|  | 		device.Name = "更新失败的设备" | ||||||
|  | 		err := repo.Update(device) | ||||||
|  | 		assert.Error(t, err) | ||||||
|  | 		assert.Contains(t, err.Error(), "database is closed") | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRepoDelete(t *testing.T) { | ||||||
|  | 	db := setupTestDB(t) | ||||||
|  | 	repo := repository.NewGormDeviceRepository(db) | ||||||
|  |  | ||||||
|  | 	device := createTestDevice(t, db, "待删除设备", models.DeviceTypeAreaController, nil) | ||||||
|  |  | ||||||
|  | 	t.Run("成功删除设备", func(t *testing.T) { | ||||||
|  | 		err := repo.Delete(device.ID) | ||||||
| 		assert.NoError(t, err) | 		assert.NoError(t, err) | ||||||
|  |  | ||||||
| 		// 验证设备已被软删除 | 		// 验证设备已被软删除 | ||||||
| 		_, err = repo.FindByID(childDevice.ID) | 		_, err = repo.FindByID(device.ID) | ||||||
| 		assert.Error(t, err, "删除后应无法找到设备") | 		assert.Error(t, err, "删除后应无法找到设备") | ||||||
| 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound") | 		assert.ErrorIs(t, err, gorm.ErrRecordNotFound, "错误类型应为 RecordNotFound") | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("删除不存在的设备", func(t *testing.T) { | ||||||
|  | 		err := repo.Delete(9999) // 不存在的ID | ||||||
|  | 		assert.NoError(t, err)   // GORM 的 Delete 方法在删除不存在的记录时不会报错 | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	t.Run("数据库删除失败", func(t *testing.T) { | ||||||
|  | 		sqlDB, _ := db.DB() | ||||||
|  | 		sqlDB.Close() | ||||||
|  |  | ||||||
|  | 		err := repo.Delete(device.ID) | ||||||
|  | 		assert.Error(t, err) | ||||||
|  | 		assert.Contains(t, err.Error(), "database is closed") | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user