package websocket import ( "context" "encoding/json" "fmt" "log" "sync" "time" "unicode/utf8" "github.com/gorilla/websocket" "golang.org/x/time/rate" ) // MessageQueue represents a message queue with rate limiting type MessageQueue struct { queue chan *WSMessage limiter *rate.Limiter maxSize int mu sync.RWMutex processing bool } func NewMessageQueue(maxSize int, rateLimit float64) *MessageQueue { return &MessageQueue{ queue: make(chan *WSMessage, maxSize), limiter: rate.NewLimiter(rate.Limit(rateLimit), int(rateLimit)), maxSize: maxSize, } } func (mq *MessageQueue) Push(msg *WSMessage) bool { select { case mq.queue <- msg: return true default: return false } } func (mq *MessageQueue) Start(handler func(*WSMessage)) { mq.mu.Lock() if mq.processing { mq.mu.Unlock() return } mq.processing = true mq.mu.Unlock() go func() { for msg := range mq.queue { if err := mq.limiter.Wait(context.Background()); err != nil { log.Printf("Rate limit error: %v", err) continue } handler(msg) } }() } func (mq *MessageQueue) Stop() { mq.mu.Lock() defer mq.mu.Unlock() if mq.processing { close(mq.queue) mq.processing = false } } // ConnectionConfig holds configuration for a Connection type ConnectionConfig struct { ReadBufferSize int WriteBufferSize int PingInterval time.Duration PongWait time.Duration WriteWait time.Duration MaxMessageSize int64 RateLimit float64 QueueSize int CompressionEnabled bool EncryptionEnabled bool EncryptionKey []byte ReconnectConfig *ReconnectConfig } // DefaultConnectionConfig returns default connection configuration func DefaultConnectionConfig() *ConnectionConfig { return &ConnectionConfig{ ReadBufferSize: 1024, WriteBufferSize: 1024, PingInterval: 54 * time.Second, PongWait: 60 * time.Second, WriteWait: 10 * time.Second, MaxMessageSize: 512 * 1024, // 512KB RateLimit: 100, // 100 messages per second QueueSize: 1000, // 1000 messages queue size CompressionEnabled: true, EncryptionEnabled: false, ReconnectConfig: DefaultReconnectConfig(), } } // Connection 表示一个WebSocket连接 type Connection struct { UID string // 用户唯一标识 Room *Room // 所属房间 Conn *websocket.Conn // WebSocket连接 Send chan *WSMessage // 发送消息的通道 Config *ConnectionConfig // 连接配置 mu sync.Mutex // 互斥锁 closed bool // 连接是否已关闭 Stats *ConnectionStats // 连接统计信息 Queue *MessageQueue // 消息队列 MessageHandler *MessageHandler // 消息处理器 reconnect *ReconnectManager // 重连管理器 } // ConnectionStats 记录连接的统计信息 type ConnectionStats struct { MessagesSent int64 // 发送的消息数 MessagesReceived int64 // 接收的消息数 BytesSent int64 // 发送的字节数 BytesReceived int64 // 接收的字节数 StartTime time.Time // 连接开始时间 mu sync.RWMutex // 读写锁 } // NewConnectionStats 创建新的连接统计对象 func NewConnectionStats() *ConnectionStats { return &ConnectionStats{ StartTime: time.Now(), } } // IncrementMessagesSent 增加发送消息计数 func (cs *ConnectionStats) IncrementMessagesSent(bytes int) { cs.mu.Lock() defer cs.mu.Unlock() cs.MessagesSent++ cs.BytesSent += int64(bytes) } // IncrementMessagesReceived 增加接收消息计数 func (cs *ConnectionStats) IncrementMessagesReceived(bytes int) { cs.mu.Lock() defer cs.mu.Unlock() cs.MessagesReceived++ cs.BytesReceived += int64(bytes) } // GetStats 获取统计信息 func (cs *ConnectionStats) GetStats() map[string]interface{} { cs.mu.RLock() defer cs.mu.RUnlock() return map[string]interface{}{ "messages_sent": cs.MessagesSent, "messages_received": cs.MessagesReceived, "bytes_sent": cs.BytesSent, "bytes_received": cs.BytesReceived, "uptime": time.Since(cs.StartTime).String(), } } // ReadPump 处理从WebSocket连接读取消息 func (c *Connection) ReadPump() { defer func() { log.Printf("客户端 %s 的ReadPump退出", c.UID) c.Close() c.Room.Unregister <- c }() // 设置连接参数 c.Conn.SetReadLimit(c.Config.MaxMessageSize) c.Conn.SetReadDeadline(time.Now().Add(c.Config.PongWait)) c.Conn.SetPongHandler(func(string) error { c.Conn.SetReadDeadline(time.Now().Add(c.Config.PongWait)) return nil }) log.Printf("启动客户端 %s 的ReadPump", c.UID) for { messageType, message, err := c.Conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("客户端 %s 读取错误: %v", c.UID, err) } break } log.Printf("收到来自客户端 %s 的消息类型 %d (长度: %d)", c.UID, messageType, len(message)) // 处理不同类型的消息 switch messageType { case websocket.TextMessage: // 记录原始消息内容 log.Printf("客户端 %s 的原始消息: %s", c.UID, string(message)) // 确保消息是有效的UTF-8 if !utf8.Valid(message) { log.Printf("客户端 %s 的消息包含无效的UTF-8序列,尝试修复", c.UID) message = []byte(string([]rune(string(message)))) } c.Stats.IncrementMessagesReceived(len(message)) // 处理接收到的消息 wsMsg, err := c.MessageHandler.ProcessIncomingMessage(message) if err != nil { log.Printf("处理客户端 %s 的消息时出错: %v", c.UID, err) // 发送错误消息给客户端 errorMsg := &WSMessage{ Type: "error", Content: fmt.Sprintf("处理消息时出错: %v", err), } select { case c.Send <- errorMsg: log.Printf("已发送错误消息给客户端 %s", c.UID) default: log.Printf("发送错误消息给客户端 %s 失败: 通道已满", c.UID) } continue } log.Printf("已处理消息: 类型=%s, 内容=%s", wsMsg.Type, wsMsg.Content) // 发送确认消息 ack := &WSMessage{ Type: wsMsg.Type, Ack: true, From: c.UID, To: wsMsg.To, Content: wsMsg.Content, Time: time.Now(), } select { case c.Send <- ack: log.Printf("已发送确认消息给客户端 %s", c.UID) default: log.Printf("发送确认消息给客户端 %s 失败: 通道已满", c.UID) } // 根据消息类型处理 switch wsMsg.Type { case "broadcast": log.Printf("广播来自客户端 %s 的消息: %s", c.UID, wsMsg.Content) select { case c.Room.Broadcast <- &BroadcastPayload{Message: wsMsg}: log.Printf("已将消息加入广播队列: 来自 %s", c.UID) default: log.Printf("将消息加入广播队列失败: 来自 %s, 通道已满", c.UID) } case "direct": log.Printf("发送来自客户端 %s 的私信给 %s: %s", c.UID, wsMsg.To, wsMsg.Content) directMsg := &WSMessage{ Type: wsMsg.Type, Content: wsMsg.Content, From: c.UID, To: wsMsg.To, Time: time.Now(), Data: wsMsg.Data, } select { case c.Room.Direct <- &DirectPayload{To: wsMsg.To, Message: directMsg}: log.Printf("已将私信加入队列: 从 %s 到 %s", c.UID, wsMsg.To) default: log.Printf("将私信加入队列失败: 从 %s 到 %s, 通道已满", c.UID, wsMsg.To) } case "stats": stats := c.Stats.GetStats() statsMsg := &WSMessage{ Type: "stats", Content: string(must(json.Marshal(stats))), } select { case c.Send <- statsMsg: log.Printf("已发送统计信息给客户端 %s", c.UID) default: log.Printf("发送统计信息给客户端 %s 失败: 通道已满", c.UID) } case "reconnect": c.handleReconnect() default: log.Printf("收到未知消息类型: 来自 %s, 类型=%s", c.UID, wsMsg.Type) } case websocket.BinaryMessage: log.Printf("收到二进制消息: 来自 %s, 长度=%d", c.UID, len(message)) c.Stats.IncrementMessagesReceived(len(message)) case websocket.PingMessage: log.Printf("收到ping: 来自 %s", c.UID) if err := c.Conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(c.Config.WriteWait)); err != nil { log.Printf("发送pong给客户端 %s 失败: %v", c.UID, err) } case websocket.PongMessage: log.Printf("收到pong: 来自 %s", c.UID) case websocket.CloseMessage: log.Printf("收到关闭消息: 来自 %s", c.UID) return default: log.Printf("收到未知消息类型 %d: 来自 %s", messageType, c.UID) } } } // WritePump 处理向WebSocket连接写入消息 func (c *Connection) WritePump() { ticker := time.NewTicker(c.Config.PingInterval) defer func() { ticker.Stop() c.Close() }() // 启动消息队列处理器 c.Queue.Start(func(msg *WSMessage) { c.mu.Lock() if c.closed { c.mu.Unlock() return } c.mu.Unlock() // 处理发送的消息 data, err := c.MessageHandler.ProcessOutgoingMessage(msg) if err != nil { log.Printf("处理发送的消息时出错: %v", err) return } c.mu.Lock() if c.closed { c.mu.Unlock() return } c.Conn.SetWriteDeadline(time.Now().Add(c.Config.WriteWait)) if err := c.Conn.WriteMessage(websocket.TextMessage, data); err != nil { log.Printf("写入错误: %v", err) c.mu.Unlock() c.Close() return } c.mu.Unlock() c.Stats.IncrementMessagesSent(len(data)) }) for { select { case message, ok := <-c.Send: c.mu.Lock() if !ok || c.closed { c.mu.Unlock() return } c.mu.Unlock() if !c.Queue.Push(message) { log.Printf("连接 %s 的消息队列已满", c.UID) c.Send <- &WSMessage{ Type: "error", Content: "Message queue full", } } case <-ticker.C: c.mu.Lock() if c.closed { c.mu.Unlock() return } c.Conn.SetWriteDeadline(time.Now().Add(c.Config.WriteWait)) if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { c.mu.Unlock() return } c.mu.Unlock() } } } func (c *Connection) handleReconnect() { if c.reconnect != nil { c.reconnect.Stop() } c.reconnect = NewReconnectManager(c.Config.ReconnectConfig) c.reconnect.Start(func() error { // Implement reconnection logic here // This could involve re-establishing the WebSocket connection // and rejoining the room return nil }) } func (c *Connection) Close() { c.mu.Lock() defer c.mu.Unlock() if !c.closed { c.closed = true if c.Queue != nil { c.Queue.Stop() } if c.reconnect != nil { c.reconnect.Stop() } c.Conn.Close() close(c.Send) } } // Helper function to handle JSON marshaling errors func must[T any](v T, err error) T { if err != nil { panic(err) } return v }