149 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			149 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Package middleware 提供HTTP中间件功能
 | |
| // 包含鉴权、日志、恢复等中间件实现
 | |
| package middleware
 | |
| 
 | |
import (
	"errors"
	"net/http"
	"os"
	"strings"
	"time"

	"git.huangwc.com/pig/pig-farm-controller/internal/logs"
	"git.huangwc.com/pig/pig-farm-controller/internal/storage/repository"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"gorm.io/gorm"
)
 | |
| 
 | |
| // AuthMiddleware 鉴权中间件结构
 | |
| type AuthMiddleware struct {
 | |
| 	userRepo repository.UserRepo
 | |
| 	logger   *logs.Logger
 | |
| }
 | |
| 
 | |
| // AuthUser 用于在上下文中存储的用户信息
 | |
| type AuthUser struct {
 | |
| 	ID       uint   `json:"id"`
 | |
| 	Username string `json:"username"`
 | |
| }
 | |
| 
 | |
| // JWTClaims 自定义JWT声明
 | |
| type JWTClaims struct {
 | |
| 	UserID   uint   `json:"user_id"`
 | |
| 	Username string `json:"username"`
 | |
| 	jwt.RegisteredClaims
 | |
| }
 | |
| 
 | |
| // NewAuthMiddleware 创建鉴权中间件实例
 | |
| func NewAuthMiddleware(userRepo repository.UserRepo) *AuthMiddleware {
 | |
| 	return &AuthMiddleware{
 | |
| 		userRepo: userRepo,
 | |
| 		logger:   logs.NewLogger(),
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // getJWTSecret 获取JWT密钥
 | |
| func (m *AuthMiddleware) getJWTSecret() []byte {
 | |
| 	// 在实际项目中,应该从配置文件或环境变量中读取
 | |
| 	secret := os.Getenv("JWT_SECRET")
 | |
| 	if secret == "" {
 | |
| 		secret = "pig-farm-controller-secret-key" // 默认密钥
 | |
| 	}
 | |
| 	return []byte(secret)
 | |
| }
 | |
| 
 | |
| // GenerateToken 为用户生成JWT token
 | |
| func (m *AuthMiddleware) GenerateToken(userID uint, username string) (string, error) {
 | |
| 	claims := JWTClaims{
 | |
| 		UserID:   userID,
 | |
| 		Username: username,
 | |
| 		RegisteredClaims: jwt.RegisteredClaims{
 | |
| 			ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), // 24小时过期
 | |
| 			IssuedAt:  jwt.NewNumericDate(time.Now()),
 | |
| 			NotBefore: jwt.NewNumericDate(time.Now()),
 | |
| 			Issuer:    "pig-farm-controller",
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 | |
| 	return token.SignedString(m.getJWTSecret())
 | |
| }
 | |
| 
 | |
| // Handle 鉴权中间件处理函数
 | |
| func (m *AuthMiddleware) Handle() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		// 从请求头中获取认证信息
 | |
| 		authHeader := c.GetHeader("Authorization")
 | |
| 		if authHeader == "" {
 | |
| 			c.JSON(http.StatusUnauthorized, gin.H{"error": "缺少认证信息"})
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// 检查Bearer token格式
 | |
| 		if !strings.HasPrefix(authHeader, "Bearer ") {
 | |
| 			c.JSON(http.StatusUnauthorized, gin.H{"error": "认证信息格式错误"})
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// 解析token
 | |
| 		tokenString := strings.TrimPrefix(authHeader, "Bearer ")
 | |
| 
 | |
| 		// 验证token并获取用户信息
 | |
| 		user, err := m.getUserFromJWT(tokenString)
 | |
| 		if err != nil {
 | |
| 			if err == gorm.ErrRecordNotFound {
 | |
| 				c.JSON(http.StatusUnauthorized, gin.H{"error": "用户不存在"})
 | |
| 			} else {
 | |
| 				m.logger.Error("Token验证失败: " + err.Error())
 | |
| 				c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的认证令牌"})
 | |
| 			}
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// 将用户信息保存到上下文中,供后续处理函数使用
 | |
| 		c.Set("user", user)
 | |
| 
 | |
| 		// 继续处理请求
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // getUserFromJWT 从JWT token中获取用户信息
 | |
| func (m *AuthMiddleware) getUserFromJWT(tokenString string) (*AuthUser, error) {
 | |
| 	// 解析token
 | |
| 	token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
 | |
| 		return m.getJWTSecret(), nil
 | |
| 	})
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// 验证token
 | |
| 	if !token.Valid {
 | |
| 		return nil, gorm.ErrRecordNotFound
 | |
| 	}
 | |
| 
 | |
| 	// 获取声明
 | |
| 	claims, ok := token.Claims.(*JWTClaims)
 | |
| 	if !ok {
 | |
| 		return nil, gorm.ErrRecordNotFound
 | |
| 	}
 | |
| 
 | |
| 	// 根据用户ID查找用户
 | |
| 	userModel, err := m.userRepo.FindByID(claims.UserID)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	user := &AuthUser{
 | |
| 		ID:       userModel.ID,
 | |
| 		Username: userModel.Username,
 | |
| 	}
 | |
| 
 | |
| 	return user, nil
 | |
| }
 |