215 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			215 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Package websocket 提供WebSocket通信功能
 | |
| // 实现中继设备与平台之间的实时通信
 | |
| package websocket
 | |
| 
 | |
| import (
 | |
| 	"encoding/json"
 | |
| 	"fmt"
 | |
| 	"net/http"
 | |
| 	"time"
 | |
| 
 | |
| 	"git.huangwc.com/pig/pig-farm-controller/internal/logs"
 | |
| 	"git.huangwc.com/pig/pig-farm-controller/internal/storage/repository"
 | |
| 	"github.com/gin-gonic/gin"
 | |
| 	"github.com/gorilla/websocket"
 | |
| )
 | |
| 
 | |
| // Server WebSocket服务器结构
 | |
| type Server struct {
 | |
| 	hub        *Hub
 | |
| 	logger     *logs.Logger
 | |
| 	deviceRepo repository.DeviceRepo
 | |
| }
 | |
| 
 | |
const (
	// writeWait is the write deadline applied before each write to the peer.
	writeWait = 10 * time.Second

	// pongWait bounds how long reads may block: the read deadline is set to
	// now+pongWait when reading starts and is pushed forward by the same
	// amount every time a pong arrives.
	pongWait = 60 * time.Second

	// pingPeriod is the interval between pings sent to the peer. It must be
	// shorter than pongWait so the answering pong lands before the read
	// deadline expires.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the maximum size, in bytes, of a single message read
	// from the peer (enforced with SetReadLimit).
	maxMessageSize = 512
)
 | |
| 
 | |
// Single-byte separator slices, allocated once at package level.
var (
	newline = []byte("\n")
	space   = []byte(" ")
)
 | |
| 
 | |
| // Upgrader WebSocket升级器
 | |
| var upgrader = websocket.Upgrader{
 | |
| 	ReadBufferSize:  1024,
 | |
| 	WriteBufferSize: 1024,
 | |
| 	CheckOrigin: func(r *http.Request) bool {
 | |
| 		// 允许所有来源的连接(在生产环境中应该更严格)
 | |
| 		return true
 | |
| 	},
 | |
| }
 | |
| 
 | |
| // NewServer 创建新的WebSocket服务器实例
 | |
| func NewServer(deviceRepo repository.DeviceRepo) *Server {
 | |
| 	return &Server{
 | |
| 		hub:        NewHub(deviceRepo),
 | |
| 		logger:     logs.NewLogger(),
 | |
| 		deviceRepo: deviceRepo,
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // getDeviceDisplayName 获取设备显示名称
 | |
| func (s *Server) getDeviceDisplayName(deviceID string) string {
 | |
| 	if s.deviceRepo != nil {
 | |
| 		if device, err := s.deviceRepo.FindByIDString(deviceID); err == nil && device != nil {
 | |
| 			return fmt.Sprintf("%s(id:%s)", device.Name, deviceID)
 | |
| 		}
 | |
| 	}
 | |
| 	return fmt.Sprintf("未知设备(id:%s)", deviceID)
 | |
| }
 | |
| 
 | |
| // Start 启动WebSocket服务器
 | |
| func (s *Server) Start() {
 | |
| 	// 启动hub
 | |
| 	go s.hub.Run()
 | |
| }
 | |
| 
 | |
| func (s *Server) Stop() {
 | |
| 	s.hub.Close()
 | |
| }
 | |
| 
 | |
| // readPump 从WebSocket连接读取消息
 | |
| func (c *Client) readPump() {
 | |
| 	defer func() {
 | |
| 		c.hub.unregister <- c
 | |
| 		c.conn.Close()
 | |
| 	}()
 | |
| 
 | |
| 	c.conn.SetReadLimit(maxMessageSize)
 | |
| 	c.conn.SetReadDeadline(time.Now().Add(pongWait))
 | |
| 	c.conn.SetPongHandler(func(string) error {
 | |
| 		c.conn.SetReadDeadline(time.Now().Add(pongWait))
 | |
| 		return nil
 | |
| 	})
 | |
| 
 | |
| 	for {
 | |
| 		var msg Message
 | |
| 		err := c.conn.ReadJSON(&msg)
 | |
| 		if err != nil {
 | |
| 			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
 | |
| 				c.logger.Error("[WebSocket] 读取错误: " + err.Error())
 | |
| 			}
 | |
| 			break
 | |
| 		}
 | |
| 
 | |
| 		// 处理收到的消息
 | |
| 		c.hub.broadcast <- msg
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // writePump 向WebSocket连接写入消息
 | |
| func (c *Client) writePump() {
 | |
| 	ticker := time.NewTicker(pingPeriod)
 | |
| 	defer func() {
 | |
| 		ticker.Stop()
 | |
| 		c.conn.Close()
 | |
| 	}()
 | |
| 
 | |
| 	for {
 | |
| 		select {
 | |
| 		case message, ok := <-c.send:
 | |
| 			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
 | |
| 			if !ok {
 | |
| 				// hub关闭了send通道
 | |
| 				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 			w, err := c.conn.NextWriter(websocket.TextMessage)
 | |
| 			if err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 
 | |
| 			// 将消息序列化为JSON
 | |
| 			data, err := json.Marshal(message)
 | |
| 			if err != nil {
 | |
| 				c.logger.Error("[WebSocket] 消息序列化失败: " + err.Error())
 | |
| 				continue
 | |
| 			}
 | |
| 
 | |
| 			w.Write(data)
 | |
| 
 | |
| 			// 添加队列中的其他消息
 | |
| 			n := len(c.send)
 | |
| 			for i := 0; i < n; i++ {
 | |
| 				msg := <-c.send
 | |
| 				data, err := json.Marshal(msg)
 | |
| 				if err != nil {
 | |
| 					c.logger.Error("[WebSocket] 消息序列化失败: " + err.Error())
 | |
| 					continue
 | |
| 				}
 | |
| 				w.Write(newline)
 | |
| 				w.Write(data)
 | |
| 			}
 | |
| 
 | |
| 			if err := w.Close(); err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 		case <-ticker.C:
 | |
| 			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
 | |
| 			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
 | |
| 				return
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // HandleConnection 处理WebSocket连接请求
 | |
| func (s *Server) HandleConnection(c *gin.Context) {
 | |
| 	// 升级HTTP连接为WebSocket连接
 | |
| 	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
 | |
| 	if err != nil {
 | |
| 		s.logger.Error("[WebSocket] 连接升级失败: " + err.Error())
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// 从查询参数获取设备ID
 | |
| 	deviceID := c.Query("device_id")
 | |
| 	if deviceID == "" {
 | |
| 		s.logger.Warn("[WebSocket] 缺少设备ID参数")
 | |
| 		conn.Close()
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	// 创建客户端
 | |
| 	client := &Client{
 | |
| 		hub:      s.hub,
 | |
| 		conn:     conn,
 | |
| 		send:     make(chan Message, 256),
 | |
| 		DeviceID: deviceID,
 | |
| 		Request:  c.Request,
 | |
| 		logger:   s.logger,
 | |
| 	}
 | |
| 
 | |
| 	// 注册客户端
 | |
| 	client.hub.register <- client
 | |
| 
 | |
| 	// 启动读写goroutine
 | |
| 	go client.writePump()
 | |
| 	go client.readPump()
 | |
| 
 | |
| 	deviceName := s.getDeviceDisplayName(deviceID)
 | |
| 	s.logger.Info("[WebSocket] 设备 " + deviceName + " 连接成功")
 | |
| }
 | |
| 
 | |
| // SendToDevice 向指定设备发送消息
 | |
| func (s *Server) SendToDevice(deviceID string, msgType string, data interface{}) error {
 | |
| 	return s.hub.SendToDevice(deviceID, msgType, data)
 | |
| }
 | |
| 
 | |
| // GetConnectedDevices 获取已连接的设备列表
 | |
| func (s *Server) GetConnectedDevices() []string {
 | |
| 	return s.hub.GetConnectedDevices()
 | |
| }
 |