diff --git a/internal/app/controller/user/user_controller_test.go b/internal/app/controller/user/user_controller_test.go index 7738d95..e553af4 100644 --- a/internal/app/controller/user/user_controller_test.go +++ b/internal/app/controller/user/user_controller_test.go @@ -10,6 +10,7 @@ import ( "testing" "git.huangwc.com/pig/pig-farm-controller/internal/app/controller/user" + "git.huangwc.com/pig/pig-farm-controller/internal/app/service/token" "git.huangwc.com/pig/pig-farm-controller/internal/infra/logs" "git.huangwc.com/pig/pig-farm-controller/internal/infra/models" "github.com/gin-gonic/gin" @@ -32,6 +33,7 @@ func (m *MockUserRepository) Create(user *models.User) error { } // FindByUsername 模拟 UserRepository 的 FindByUsername 方法 +// 返回类型改回 *models.User func (m *MockUserRepository) FindByUsername(username string) (*models.User, error) { args := m.Called(username) if args.Get(0) == nil { @@ -49,6 +51,26 @@ func (m *MockUserRepository) FindByID(id uint) (*models.User, error) { return args.Get(0).(*models.User), args.Error(1) } +// MockTokenService 是 token.TokenService 接口的模拟实现 +type MockTokenService struct { + mock.Mock +} + +// GenerateToken 模拟 TokenService 的 GenerateToken 方法 +func (m *MockTokenService) GenerateToken(userID uint) (string, error) { + args := m.Called(userID) + return args.String(0), args.Error(1) +} + +// ParseToken 模拟 TokenService 的 ParseToken 方法 +func (m *MockTokenService) ParseToken(tokenString string) (*token.Claims, error) { + args := m.Called(tokenString) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*token.Claims), args.Error(1) +} + // TestCreateUser 测试 CreateUser 方法 func TestCreateUser(t *testing.T) { gin.SetMode(gin.TestMode) // 设置 Gin 为测试模式 @@ -205,3 +227,218 @@ func TestCreateUser(t *testing.T) { }) } } + +// TestLogin 测试 Login 方法 +func TestLogin(t *testing.T) { + // 设置release模式阻止废话日志 + gin.SetMode(gin.ReleaseMode) + + // 创建一个不输出日志的真实 logs.Logger 实例 + discardSyncer := zapcore.AddSync(io.Discard) + encoderConfig := zap.NewProductionEncoderConfig() + encoder := zapcore.NewConsoleEncoder(encoderConfig) + core := zapcore.NewCore(encoder, discardSyncer, zap.DebugLevel) // 设置为 DebugLevel 以确保所有日志都被处理(并丢弃) + zapLogger := zap.New(core) + sugaredLogger := zapLogger.Sugar() + silentLogger := &logs.Logger{SugaredLogger: sugaredLogger} + + tests := []struct { + name string + requestBody user.LoginRequest + mockRepoSetup func(*MockUserRepository) + mockTokenServiceSetup func(*MockTokenService) + expectedResponse map[string]interface{} + }{ + { + name: "成功登录", + requestBody: user.LoginRequest{ + Username: "loginuser", + Password: "correctpassword", + }, + mockRepoSetup: func(m *MockUserRepository) { + mockUser := &models.User{ + Model: gorm.Model{ID: 1}, + Username: "loginuser", + Password: "correctpassword", // 明文密码,BeforeCreate 会哈希它 + } + // 调用 BeforeCreate 钩子来哈希密码 + _ = mockUser.BeforeCreate(nil) + m.On("FindByUsername", "loginuser").Return(mockUser, nil).Once() + }, + mockTokenServiceSetup: func(m *MockTokenService) { + m.On("GenerateToken", uint(1)).Return("mocked_token", nil).Once() + }, + expectedResponse: map[string]interface{}{ + "code": float64(http.StatusOK), + "message": "登录成功", + "data": map[string]interface{}{ + "username": "loginuser", + "id": float64(1), + "token": "mocked_token", + }, + }, + }, + { + name: "请求参数绑定失败_缺少用户名", + requestBody: user.LoginRequest{ + Username: "", // 缺少用户名 + Password: "password", + }, + mockRepoSetup: func(m *MockUserRepository) {}, + mockTokenServiceSetup: func(m *MockTokenService) {}, + expectedResponse: map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "message": "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'required' tag", + "data": nil, + }, + }, + { + name: "请求参数绑定失败_缺少密码", + requestBody: user.LoginRequest{ + Username: "testuser", + Password: "", // 缺少密码 + }, + mockRepoSetup: func(m *MockUserRepository) {}, + mockTokenServiceSetup: func(m *MockTokenService) {}, + expectedResponse: map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "message": "Key: 'LoginRequest.Password' Error:Field validation for 'Password' failed on the 'required' tag", + "data": nil, + }, + }, + { + name: "用户不存在", + requestBody: user.LoginRequest{ + Username: "nonexistent", + Password: "anypassword", + }, + mockRepoSetup: func(m *MockUserRepository) { + m.On("FindByUsername", "nonexistent").Return(nil, gorm.ErrRecordNotFound).Once() + }, + mockTokenServiceSetup: func(m *MockTokenService) {}, + expectedResponse: map[string]interface{}{ + "code": float64(http.StatusUnauthorized), + "message": "用户名或密码不正确", + "data": nil, + }, + }, + { + name: "查询用户失败_通用数据库错误", + requestBody: user.LoginRequest{ + Username: "dberroruser", + Password: "password", + }, + mockRepoSetup: func(m *MockUserRepository) { + m.On("FindByUsername", "dberroruser").Return(nil, errors.New("database connection error")).Once() + }, + mockTokenServiceSetup: func(m *MockTokenService) {}, expectedResponse: map[string]interface{}{ + "code": float64(http.StatusInternalServerError), + "message": "登录失败", + "data": nil, + }, + }, + { + name: "密码不正确", + requestBody: user.LoginRequest{ + Username: "loginuser", + Password: "wrongpassword", + }, + mockRepoSetup: func(m *MockUserRepository) { + mockUser := &models.User{ + Model: gorm.Model{ID: 1}, + Username: "loginuser", + Password: "correctpassword", // 明文密码,BeforeCreate 会哈希它 + } + // 调用 BeforeCreate 钩子来哈希密码 + _ = mockUser.BeforeCreate(nil) + m.On("FindByUsername", "loginuser").Return(mockUser, nil).Once() + }, + mockTokenServiceSetup: func(m *MockTokenService) {}, + expectedResponse: map[string]interface{}{ + "code": float64(http.StatusUnauthorized), + "message": "用户名或密码不正确", + "data": nil, + }, + }, + { + name: "生成Token失败", + requestBody: user.LoginRequest{ + Username: "loginuser", + Password: "correctpassword", + }, + mockRepoSetup: func(m *MockUserRepository) { + mockUser := &models.User{ + Model: gorm.Model{ID: 1}, + Username: "loginuser", + Password: "correctpassword", // 明文密码,BeforeCreate 会哈希它 + } + // 调用 BeforeCreate 钩子来哈希密码 + _ = mockUser.BeforeCreate(nil) + m.On("FindByUsername", "loginuser").Return(mockUser, nil).Once() + }, + mockTokenServiceSetup: func(m *MockTokenService) { + m.On("GenerateToken", uint(1)).Return("", errors.New("jwt error")).Once() + }, + expectedResponse: map[string]interface{}{ + "code": float64(http.StatusInternalServerError), + "message": "登录失败,无法生成认证信息", + "data": nil, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 初始化 Gin 上下文和记录器 + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Request = httptest.NewRequest(http.MethodPost, "/login", nil) // URL 路径不重要,因为我们不测试路由 + + // 设置请求体 + jsonBody, _ := json.Marshal(tt.requestBody) + ctx.Request.Body = io.NopCloser(bytes.NewBuffer(jsonBody)) + ctx.Request.Header.Set("Content-Type", "application/json") + + // 创建 Mock + mockRepo := new(MockUserRepository) + mockTokenService := new(MockTokenService) + + // 设置 Mock 行为 + tt.mockRepoSetup(mockRepo) + tt.mockTokenServiceSetup(mockTokenService) + + // 创建控制器实例 + controller := user.NewController(mockRepo, silentLogger, mockTokenService) + + // 调用被测试的方法 + controller.Login(ctx) + + // 解析响应体 + var responseBody map[string]interface{} + err := json.Unmarshal(w.Body.Bytes(), &responseBody) + assert.NoError(t, err) + + // 断言响应体中的 code 字段 + assert.Equal(t, tt.expectedResponse["code"], responseBody["code"]) + + // 断言响应内容 (除了 code 字段) + if tt.expectedResponse["code"] == float64(http.StatusOK) { + assert.NotNil(t, responseBody["data"].(map[string]interface{})["id"]) + assert.NotNil(t, responseBody["data"].(map[string]interface{})["token"]) + // 移除 ID 和 Token 字段以便进行通用断言 + delete(responseBody["data"].(map[string]interface{}), "id") + delete(tt.expectedResponse["data"].(map[string]interface{}), "id") + delete(responseBody["data"].(map[string]interface{}), "token") + delete(tt.expectedResponse["data"].(map[string]interface{}), "token") + } + // 移除 code 字段以便进行通用断言 + delete(responseBody, "code") + delete(tt.expectedResponse, "code") + assert.Equal(t, tt.expectedResponse, responseBody) + + // 验证 Mock 期望是否都已满足 + mockRepo.AssertExpectations(t) + mockTokenService.AssertExpectations(t) + }) + } +}