// Package websocket 提供WebSocket通信功能 // 实现中继设备与平台之间的实时通信 package websocket import ( "encoding/json" "fmt" "net/http" "time" "git.huangwc.com/pig/pig-farm-controller/internal/logs" "git.huangwc.com/pig/pig-farm-controller/internal/storage/repository" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" ) // Server WebSocket服务器结构 type Server struct { hub *Hub logger *logs.Logger deviceRepo repository.DeviceRepo } const ( // 允许写入的最长时间 writeWait = 10 * time.Second // 允许读取的最长时间 pongWait = 60 * time.Second // 发送ping消息的周期 pingPeriod = (pongWait * 9) / 10 // 发送队列的最大容量 maxMessageSize = 512 ) var ( newline = []byte{'\n'} space = []byte{' '} ) // Upgrader WebSocket升级器 var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { // 允许所有来源的连接(在生产环境中应该更严格) return true }, } // NewServer 创建新的WebSocket服务器实例 func NewServer(deviceRepo repository.DeviceRepo) *Server { return &Server{ hub: NewHub(deviceRepo), logger: logs.NewLogger(), deviceRepo: deviceRepo, } } // getDeviceDisplayName 获取设备显示名称 func (s *Server) getDeviceDisplayName(deviceID string) string { if s.deviceRepo != nil { if device, err := s.deviceRepo.FindByIDString(deviceID); err == nil && device != nil { return fmt.Sprintf("%s(id:%s)", device.Name, deviceID) } } return fmt.Sprintf("未知设备(id:%s)", deviceID) } // Start 启动WebSocket服务器 func (s *Server) Start() { // 启动hub go s.hub.Run() } func (s *Server) Stop() { s.hub.Close() } // readPump 从WebSocket连接读取消息 func (c *Client) readPump() { defer func() { c.hub.unregister <- c c.conn.Close() }() c.conn.SetReadLimit(maxMessageSize) c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)) return nil }) for { var msg Message err := c.conn.ReadJSON(&msg) if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { c.logger.Error("[WebSocket] 读取错误: " + err.Error()) } break } // 处理收到的消息 c.hub.broadcast <- msg } } // writePump 向WebSocket连接写入消息 func (c *Client) writePump() { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() c.conn.Close() }() for { select { case message, ok := <-c.send: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if !ok { // hub关闭了send通道 c.conn.WriteMessage(websocket.CloseMessage, []byte{}) return } w, err := c.conn.NextWriter(websocket.TextMessage) if err != nil { return } // 将消息序列化为JSON data, err := json.Marshal(message) if err != nil { c.logger.Error("[WebSocket] 消息序列化失败: " + err.Error()) continue } w.Write(data) // 添加队列中的其他消息 n := len(c.send) for i := 0; i < n; i++ { msg := <-c.send data, err := json.Marshal(msg) if err != nil { c.logger.Error("[WebSocket] 消息序列化失败: " + err.Error()) continue } w.Write(newline) w.Write(data) } if err := w.Close(); err != nil { return } case <-ticker.C: c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { return } } } } // HandleConnection 处理WebSocket连接请求 func (s *Server) HandleConnection(c *gin.Context) { // 升级HTTP连接为WebSocket连接 conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { s.logger.Error("[WebSocket] 连接升级失败: " + err.Error()) return } // 从查询参数获取设备ID deviceID := c.Query("device_id") if deviceID == "" { s.logger.Warn("[WebSocket] 缺少设备ID参数") conn.Close() return } // 创建客户端 client := &Client{ hub: s.hub, conn: conn, send: make(chan Message, 256), DeviceID: deviceID, Request: c.Request, logger: s.logger, } // 注册客户端 client.hub.register <- client // 启动读写goroutine go client.writePump() go client.readPump() deviceName := s.getDeviceDisplayName(deviceID) s.logger.Info("[WebSocket] 设备 " + deviceName + " 连接成功") } // SendToDevice 向指定设备发送消息 func (s *Server) SendToDevice(deviceID string, msgType string, data interface{}) error { return s.hub.SendToDevice(deviceID, msgType, data) } // GetConnectedDevices 获取已连接的设备列表 func (s *Server) GetConnectedDevices() []string { return s.hub.GetConnectedDevices() }