211 lines
4.7 KiB
Go
211 lines
4.7 KiB
Go
// Package websocket 提供WebSocket通信功能
|
|
// 实现中继设备与平台之间的实时通信
|
|
package websocket
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"git.huangwc.com/pig/pig-farm-controller/internal/logs"
|
|
"git.huangwc.com/pig/pig-farm-controller/internal/storage/repository"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Server WebSocket服务器结构
|
|
type Server struct {
|
|
hub *Hub
|
|
logger *logs.Logger
|
|
deviceRepo repository.DeviceRepo
|
|
}
|
|
|
|
// Connection timing and size limits shared by readPump and writePump.
const (
	// writeWait bounds how long a single write (data, ping or close frame)
	// may take before the write deadline expires.
	writeWait = 10 * time.Second

	// pongWait is the read deadline window: it is set when reading starts
	// and pushed forward by this amount each time a pong is received.
	pongWait = 60 * time.Second

	// pingPeriod is how often pings are sent. It must stay shorter than
	// pongWait so the peer's pong can arrive before the read deadline.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the largest incoming message, in bytes, accepted from
	// the peer (passed to Conn.SetReadLimit). It is NOT a send-queue capacity.
	maxMessageSize = 512
)
|
|
|
|
// Shared single-byte separators: newline is "\n" and space is " ".
var (
	newline = []byte("\n")
	space   = []byte(" ")
)
|
|
|
|
// Upgrader WebSocket升级器
|
|
var upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
// 允许所有来源的连接(在生产环境中应该更严格)
|
|
return true
|
|
},
|
|
}
|
|
|
|
// NewServer 创建新的WebSocket服务器实例
|
|
func NewServer(deviceRepo repository.DeviceRepo) *Server {
|
|
return &Server{
|
|
hub: NewHub(deviceRepo),
|
|
logger: logs.NewLogger(),
|
|
deviceRepo: deviceRepo,
|
|
}
|
|
}
|
|
|
|
// getDeviceDisplayName 获取设备显示名称
|
|
func (s *Server) getDeviceDisplayName(deviceID string) string {
|
|
if s.deviceRepo != nil {
|
|
if device, err := s.deviceRepo.FindByIDString(deviceID); err == nil && device != nil {
|
|
return fmt.Sprintf("%s(id:%s)", device.Name, deviceID)
|
|
}
|
|
}
|
|
return fmt.Sprintf("未知设备(id:%s)", deviceID)
|
|
}
|
|
|
|
// Start 启动WebSocket服务器
|
|
func (s *Server) Start() {
|
|
// 启动hub
|
|
go s.hub.Run()
|
|
}
|
|
|
|
// readPump 从WebSocket连接读取消息
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
c.hub.unregister <- c
|
|
c.conn.Close()
|
|
}()
|
|
|
|
c.conn.SetReadLimit(maxMessageSize)
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
var msg Message
|
|
err := c.conn.ReadJSON(&msg)
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
c.logger.Error("[WebSocket] 读取错误: " + err.Error())
|
|
}
|
|
break
|
|
}
|
|
|
|
// 处理收到的消息
|
|
c.hub.broadcast <- msg
|
|
}
|
|
}
|
|
|
|
// writePump 向WebSocket连接写入消息
|
|
func (c *Client) writePump() {
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.conn.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if !ok {
|
|
// hub关闭了send通道
|
|
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
w, err := c.conn.NextWriter(websocket.TextMessage)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// 将消息序列化为JSON
|
|
data, err := json.Marshal(message)
|
|
if err != nil {
|
|
c.logger.Error("[WebSocket] 消息序列化失败: " + err.Error())
|
|
continue
|
|
}
|
|
|
|
w.Write(data)
|
|
|
|
// 添加队列中的其他消息
|
|
n := len(c.send)
|
|
for i := 0; i < n; i++ {
|
|
msg := <-c.send
|
|
data, err := json.Marshal(msg)
|
|
if err != nil {
|
|
c.logger.Error("[WebSocket] 消息序列化失败: " + err.Error())
|
|
continue
|
|
}
|
|
w.Write(newline)
|
|
w.Write(data)
|
|
}
|
|
|
|
if err := w.Close(); err != nil {
|
|
return
|
|
}
|
|
case <-ticker.C:
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// HandleConnection 处理WebSocket连接请求
|
|
func (s *Server) HandleConnection(c *gin.Context) {
|
|
// 升级HTTP连接为WebSocket连接
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
s.logger.Error("[WebSocket] 连接升级失败: " + err.Error())
|
|
return
|
|
}
|
|
|
|
// 从查询参数获取设备ID
|
|
deviceID := c.Query("device_id")
|
|
if deviceID == "" {
|
|
s.logger.Warn("[WebSocket] 缺少设备ID参数")
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// 创建客户端
|
|
client := &Client{
|
|
hub: s.hub,
|
|
conn: conn,
|
|
send: make(chan Message, 256),
|
|
DeviceID: deviceID,
|
|
Request: c.Request,
|
|
logger: s.logger,
|
|
}
|
|
|
|
// 注册客户端
|
|
client.hub.register <- client
|
|
|
|
// 启动读写goroutine
|
|
go client.writePump()
|
|
go client.readPump()
|
|
|
|
deviceName := s.getDeviceDisplayName(deviceID)
|
|
s.logger.Info("[WebSocket] 设备 " + deviceName + " 连接成功")
|
|
}
|
|
|
|
// SendToDevice 向指定设备发送消息
|
|
func (s *Server) SendToDevice(deviceID string, msgType string, data interface{}) error {
|
|
return s.hub.SendToDevice(deviceID, msgType, data)
|
|
}
|
|
|
|
// GetConnectedDevices 获取已连接的设备列表
|
|
func (s *Server) GetConnectedDevices() []string {
|
|
return s.hub.GetConnectedDevices()
|
|
}
|