package user_test

import (
	"bytes"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"git.huangwc.com/pig/pig-farm-controller/internal/app/controller"
	"git.huangwc.com/pig/pig-farm-controller/internal/app/controller/user"
	"git.huangwc.com/pig/pig-farm-controller/internal/app/service/token"
	"git.huangwc.com/pig/pig-farm-controller/internal/infra/logs"
	"git.huangwc.com/pig/pig-farm-controller/internal/infra/models"
	"github.com/gin-gonic/gin"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"gorm.io/gorm"
)

// MockUserRepository is a testify mock of the user repository that
// user.NewController depends on.
type MockUserRepository struct {
	mock.Mock
}

// Create records the call and returns the error configured via On("Create", ...).
func (m *MockUserRepository) Create(u *models.User) error {
	args := m.Called(u)
	return args.Error(0)
}

// FindByUsername records the call and returns the configured user and error.
// A nil first return value is translated into a nil *models.User.
func (m *MockUserRepository) FindByUsername(username string) (*models.User, error) {
	args := m.Called(username)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*models.User), args.Error(1)
}

// FindByID records the call and returns the configured user and error.
// A nil first return value is translated into a nil *models.User.
func (m *MockUserRepository) FindByID(id uint) (*models.User, error) {
	args := m.Called(id)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*models.User), args.Error(1)
}

// MockTokenService is a testify mock of the token service passed to
// user.NewController.
type MockTokenService struct {
	mock.Mock
}

// GenerateToken records the call and returns the configured token and error.
func (m *MockTokenService) GenerateToken(userID uint) (string, error) {
	args := m.Called(userID)
	return args.String(0), args.Error(1)
}

// ParseToken records the call and returns the configured claims and error.
// A nil first return value is translated into nil *token.Claims.
func (m *MockTokenService) ParseToken(tokenString string) (*token.Claims, error) {
	args := m.Called(tokenString)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(*token.Claims), args.Error(1)
}

// newJSONPostContext returns a response recorder and a Gin test context whose
// request is a POST to path carrying body marshalled as JSON. The path is
// cosmetic: handlers are invoked directly, not through a router.
func newJSONPostContext(t *testing.T, path string, body interface{}) (*httptest.ResponseRecorder, *gin.Context) {
	t.Helper()

	jsonBody, err := json.Marshal(body)
	if err != nil {
		t.Fatalf("marshal request body: %v", err)
	}

	w := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(w)
	ctx.Request = httptest.NewRequest(http.MethodPost, path, nil)
	ctx.Request.Body = io.NopCloser(bytes.NewBuffer(jsonBody))
	ctx.Request.Header.Set("Content-Type", "application/json")
	return w, ctx
}

// decodeJSONResponse unmarshals the recorded response body into a generic map,
// failing the test immediately if the body cannot be decoded.
func decodeJSONResponse(t *testing.T, w *httptest.ResponseRecorder) map[string]interface{} {
	t.Helper()

	var body map[string]interface{}
	if err := json.Unmarshal(w.Body.Bytes(), &body); err != nil {
		t.Fatalf("unmarshal response body %q: %v", w.Body.String(), err)
	}
	return body
}

// mustHashedUser builds a user with the given ID and plain-text password, then
// runs the model's BeforeCreate hook on it so the password is hashed. It panics
// if the hook fails, because the mock setup functions that call it have no
// *testing.T to report through.
func mustHashedUser(id uint, username, password string) *models.User {
	u := &models.User{
		Model:    gorm.Model{ID: id},
		Username: username,
		Password: password,
	}
	if err := u.BeforeCreate(nil); err != nil {
		panic(err)
	}
	return u
}

// TestCreateUser exercises Controller.CreateUser against a mocked repository:
// success, request validation failures, duplicate usernames, and generic
// database errors.
func TestCreateUser(t *testing.T) {
	// Run Gin in test mode rather than its default debug mode.
	gin.SetMode(gin.TestMode)
	// A real logs.Logger instance that produces no output.
	silentLogger := logs.NewSilentLogger()

	tests := []struct {
		name             string
		requestBody      user.CreateUserRequest
		mockRepoSetup    func(*MockUserRepository)
		expectedResponse map[string]interface{}
	}{
		{
			name: "成功创建用户",
			requestBody: user.CreateUserRequest{
				Username: "testuser",
				Password: "password123",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				// Create succeeds and, standing in for the database, assigns a non-zero ID.
				m.On("Create", mock.AnythingOfType("*models.User")).Return(nil).Run(func(args mock.Arguments) {
					args.Get(0).(*models.User).ID = 1
				}).Once()
				// FindByUsername is not called on the success path, so no expectation is set.
			},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeCreated),
				"message": "用户创建成功",
				// data.id is generated by the mocked Create and is verified separately.
				"data": map[string]interface{}{
					"username": "testuser",
				},
			},
		},
		{
			name: "请求参数绑定失败_密码过短",
			requestBody: user.CreateUserRequest{
				Username: "testuser2",
				Password: "123", // shorter than 6 characters
			},
			mockRepoSetup: func(m *MockUserRepository) {
				// Neither Create nor FindByUsername should be called.
			},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeBadRequest),
				"message": "Key: 'CreateUserRequest.Password' Error:Field validation for 'Password' failed on the 'min' tag",
				"data":    nil,
			},
		},
		{
			name: "请求参数绑定失败_缺少用户名",
			requestBody: user.CreateUserRequest{
				Password: "password123",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				// Neither Create nor FindByUsername should be called.
			},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeBadRequest),
				"message": "Key: 'CreateUserRequest.Username' Error:Field validation for 'Username' failed on the 'required' tag",
				"data":    nil,
			},
		},
		{
			name: "用户名已存在",
			requestBody: user.CreateUserRequest{
				Username: "existinguser",
				Password: "password123",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				// Create fails; the follow-up lookup finds the user, identifying a duplicate username.
				m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("duplicate entry")).Once()
				m.On("FindByUsername", "existinguser").Return(&models.User{Username: "existinguser"}, nil).Once()
			},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeConflict),
				"message": "用户名已存在",
				"data":    nil,
			},
		},
		{
			name: "创建用户失败_通用数据库错误",
			requestBody: user.CreateUserRequest{
				Username: "db_error_user",
				Password: "password123",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				// Create fails and the follow-up lookup finds no user, so this is not a duplicate.
				m.On("Create", mock.AnythingOfType("*models.User")).Return(errors.New("database error")).Once()
				m.On("FindByUsername", "db_error_user").Return(nil, gorm.ErrRecordNotFound).Once()
			},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeInternalError),
				"message": "创建用户失败",
				"data":    nil,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w, ctx := newJSONPostContext(t, "/users", tt.requestBody)

			mockRepo := new(MockUserRepository)
			tt.mockRepoSetup(mockRepo)

			// CreateUser does not use the token service, so nil is passed for it.
			userController := user.NewController(mockRepo, silentLogger, nil)
			userController.CreateUser(ctx)

			responseBody := decodeJSONResponse(t, w)

			// On success the ID is generated: verify it is present and non-zero,
			// then drop it from the response so the rest can be compared exactly.
			// The assertion guards the map access so a malformed body fails the
			// test instead of panicking.
			if tt.expectedResponse["code"] == float64(controller.CodeCreated) {
				data, ok := responseBody["data"].(map[string]interface{})
				if assert.True(t, ok, "响应体中的 data 字段应为 map[string]interface{}") {
					id, idOk := data["id"].(float64)
					assert.True(t, idOk, "响应体中的 data.id 字段应为 float64 类型")
					assert.NotEqual(t, float64(0), id, "响应体中的 data.id 不应为零")
					delete(data, "id")
				}
			}

			// Compare the whole body (code included) without mutating the test table.
			assert.Equal(t, tt.expectedResponse, responseBody)

			mockRepo.AssertExpectations(t)
		})
	}
}

// TestLogin exercises Controller.Login against mocked repository and token
// service: success, request validation failures, unknown users, database
// errors, wrong passwords, and token generation failures.
func TestLogin(t *testing.T) {
	// Run Gin in test mode, consistent with TestCreateUser.
	gin.SetMode(gin.TestMode)
	// A real logs.Logger instance that produces no output.
	silentLogger := logs.NewSilentLogger()

	tests := []struct {
		name                  string
		requestBody           user.LoginRequest
		mockRepoSetup         func(*MockUserRepository)
		mockTokenServiceSetup func(*MockTokenService)
		expectedResponse      map[string]interface{}
	}{
		{
			name: "成功登录",
			requestBody: user.LoginRequest{
				Username: "loginuser",
				Password: "correctpassword",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				m.On("FindByUsername", "loginuser").Return(mustHashedUser(1, "loginuser", "correctpassword"), nil).Once()
			},
			mockTokenServiceSetup: func(m *MockTokenService) {
				m.On("GenerateToken", uint(1)).Return("mocked_token", nil).Once()
			},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeSuccess),
				"message": "登录成功",
				"data": map[string]interface{}{
					"username": "loginuser",
					"id":       float64(1),
					"token":    "mocked_token",
				},
			},
		},
		{
			name: "请求参数绑定失败_缺少用户名",
			requestBody: user.LoginRequest{
				Username: "", // missing username
				Password: "password",
			},
			mockRepoSetup:         func(m *MockUserRepository) {},
			mockTokenServiceSetup: func(m *MockTokenService) {},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeBadRequest),
				"message": "Key: 'LoginRequest.Username' Error:Field validation for 'Username' failed on the 'required' tag",
				"data":    nil,
			},
		},
		{
			name: "请求参数绑定失败_缺少密码",
			requestBody: user.LoginRequest{
				Username: "testuser",
				Password: "", // missing password
			},
			mockRepoSetup:         func(m *MockUserRepository) {},
			mockTokenServiceSetup: func(m *MockTokenService) {},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeBadRequest),
				"message": "Key: 'LoginRequest.Password' Error:Field validation for 'Password' failed on the 'required' tag",
				"data":    nil,
			},
		},
		{
			name: "用户不存在",
			requestBody: user.LoginRequest{
				Username: "nonexistent",
				Password: "anypassword",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				m.On("FindByUsername", "nonexistent").Return(nil, gorm.ErrRecordNotFound).Once()
			},
			mockTokenServiceSetup: func(m *MockTokenService) {},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeUnauthorized),
				"message": "用户名或密码不正确",
				"data":    nil,
			},
		},
		{
			name: "查询用户失败_通用数据库错误",
			requestBody: user.LoginRequest{
				Username: "dberroruser",
				Password: "password",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				m.On("FindByUsername", "dberroruser").Return(nil, errors.New("database connection error")).Once()
			},
			mockTokenServiceSetup: func(m *MockTokenService) {},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeInternalError),
				"message": "登录失败",
				"data":    nil,
			},
		},
		{
			name: "密码不正确",
			requestBody: user.LoginRequest{
				Username: "loginuser",
				Password: "wrongpassword",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				m.On("FindByUsername", "loginuser").Return(mustHashedUser(1, "loginuser", "correctpassword"), nil).Once()
			},
			mockTokenServiceSetup: func(m *MockTokenService) {},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeUnauthorized),
				"message": "用户名或密码不正确",
				"data":    nil,
			},
		},
		{
			name: "生成Token失败",
			requestBody: user.LoginRequest{
				Username: "loginuser",
				Password: "correctpassword",
			},
			mockRepoSetup: func(m *MockUserRepository) {
				m.On("FindByUsername", "loginuser").Return(mustHashedUser(1, "loginuser", "correctpassword"), nil).Once()
			},
			mockTokenServiceSetup: func(m *MockTokenService) {
				m.On("GenerateToken", uint(1)).Return("", errors.New("jwt error")).Once()
			},
			expectedResponse: map[string]interface{}{
				"code":    float64(controller.CodeInternalError),
				"message": "登录失败,无法生成认证信息",
				"data":    nil,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w, ctx := newJSONPostContext(t, "/login", tt.requestBody)

			mockRepo := new(MockUserRepository)
			mockTokenService := new(MockTokenService)
			tt.mockRepoSetup(mockRepo)
			tt.mockTokenServiceSetup(mockTokenService)

			userController := user.NewController(mockRepo, silentLogger, mockTokenService)
			userController.Login(ctx)

			// Every expected value is deterministic here (the mocked user has
			// ID 1 and the mocked token is fixed), so the whole body, including
			// data.id and data.token on success, is compared exactly.
			assert.Equal(t, tt.expectedResponse, decodeJSONResponse(t, w))

			mockRepo.AssertExpectations(t)
			mockTokenService.AssertExpectations(t)
		})
	}
}