425 lines
11 KiB
Go
425 lines
11 KiB
Go
package websocket
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log"
|
||
"sync"
|
||
"time"
|
||
"unicode/utf8"
|
||
|
||
"github.com/gorilla/websocket"
|
||
"golang.org/x/time/rate"
|
||
)
|
||
|
||
// MessageQueue represents a message queue with rate limiting
|
||
type MessageQueue struct {
|
||
queue chan *WSMessage
|
||
limiter *rate.Limiter
|
||
maxSize int
|
||
mu sync.RWMutex
|
||
processing bool
|
||
}
|
||
|
||
func NewMessageQueue(maxSize int, rateLimit float64) *MessageQueue {
|
||
return &MessageQueue{
|
||
queue: make(chan *WSMessage, maxSize),
|
||
limiter: rate.NewLimiter(rate.Limit(rateLimit), int(rateLimit)),
|
||
maxSize: maxSize,
|
||
}
|
||
}
|
||
|
||
func (mq *MessageQueue) Push(msg *WSMessage) bool {
|
||
select {
|
||
case mq.queue <- msg:
|
||
return true
|
||
default:
|
||
return false
|
||
}
|
||
}
|
||
|
||
func (mq *MessageQueue) Start(handler func(*WSMessage)) {
|
||
mq.mu.Lock()
|
||
if mq.processing {
|
||
mq.mu.Unlock()
|
||
return
|
||
}
|
||
mq.processing = true
|
||
mq.mu.Unlock()
|
||
|
||
go func() {
|
||
for msg := range mq.queue {
|
||
if err := mq.limiter.Wait(context.Background()); err != nil {
|
||
log.Printf("Rate limit error: %v", err)
|
||
continue
|
||
}
|
||
handler(msg)
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (mq *MessageQueue) Stop() {
|
||
mq.mu.Lock()
|
||
defer mq.mu.Unlock()
|
||
|
||
if mq.processing {
|
||
close(mq.queue)
|
||
mq.processing = false
|
||
}
|
||
}
|
||
|
||
// ConnectionConfig holds configuration for a Connection.
//
// Only some fields are read in this file (see per-field notes). The others are
// presumably consumed where connections are constructed — confirm against callers.
type ConnectionConfig struct {
	ReadBufferSize     int              // NOTE(review): not read here; presumably the upgrader read buffer size in bytes
	WriteBufferSize    int              // NOTE(review): not read here; presumably the upgrader write buffer size in bytes
	PingInterval       time.Duration    // period of the ping ticker in WritePump
	PongWait           time.Duration    // read deadline set by ReadPump and re-armed on every pong
	WriteWait          time.Duration    // deadline applied before each socket write
	MaxMessageSize     int64            // inbound limit passed to Conn.SetReadLimit in ReadPump
	RateLimit          float64          // messages per second; NOTE(review): presumably passed to NewMessageQueue
	QueueSize          int              // NOTE(review): presumably the MessageQueue capacity
	CompressionEnabled bool             // NOTE(review): not read in this file
	EncryptionEnabled  bool             // NOTE(review): not read in this file
	EncryptionKey      []byte           // NOTE(review): not read in this file; presumably used when EncryptionEnabled is set
	ReconnectConfig    *ReconnectConfig // passed to NewReconnectManager by handleReconnect
}
|
||
|
||
// DefaultConnectionConfig returns default connection configuration
|
||
func DefaultConnectionConfig() *ConnectionConfig {
|
||
return &ConnectionConfig{
|
||
ReadBufferSize: 1024,
|
||
WriteBufferSize: 1024,
|
||
PingInterval: 54 * time.Second,
|
||
PongWait: 60 * time.Second,
|
||
WriteWait: 10 * time.Second,
|
||
MaxMessageSize: 512 * 1024, // 512KB
|
||
RateLimit: 100, // 100 messages per second
|
||
QueueSize: 1000, // 1000 messages queue size
|
||
CompressionEnabled: true,
|
||
EncryptionEnabled: false,
|
||
ReconnectConfig: DefaultReconnectConfig(),
|
||
}
|
||
}
|
||
|
||
// Connection represents a single WebSocket connection.
type Connection struct {
	UID            string            // unique user identifier
	Room           *Room             // room this connection belongs to
	Conn           *websocket.Conn   // underlying WebSocket connection
	Send           chan *WSMessage   // outbound messages; drained by WritePump and closed by Close
	Config         *ConnectionConfig // connection configuration
	mu             sync.Mutex        // guards closed; WritePump also holds it around every write to Conn
	closed         bool              // whether Close has run
	Stats          *ConnectionStats  // traffic statistics
	Queue          *MessageQueue     // rate-limited queue that performs the actual outbound writes
	MessageHandler *MessageHandler   // processes incoming and outgoing message payloads
	reconnect      *ReconnectManager // reconnect manager started by handleReconnect and stopped by Close
}
|
||
|
||
// ConnectionStats holds per-connection traffic counters. Every field is
// protected by mu; use the methods rather than reading fields directly.
type ConnectionStats struct {
	MessagesSent     int64        // number of messages sent
	MessagesReceived int64        // number of messages received
	BytesSent        int64        // number of bytes sent
	BytesReceived    int64        // number of bytes received
	StartTime        time.Time    // when the connection started
	mu               sync.RWMutex // protects the fields above
}

// NewConnectionStats returns zeroed counters whose uptime clock starts now.
func NewConnectionStats() *ConnectionStats {
	s := new(ConnectionStats)
	s.StartTime = time.Now()
	return s
}

// IncrementMessagesSent records one outgoing message of the given byte size.
func (s *ConnectionStats) IncrementMessagesSent(bytes int) {
	s.mu.Lock()
	s.MessagesSent += 1
	s.BytesSent += int64(bytes)
	s.mu.Unlock()
}

// IncrementMessagesReceived records one incoming message of the given byte size.
func (s *ConnectionStats) IncrementMessagesReceived(bytes int) {
	s.mu.Lock()
	s.MessagesReceived += 1
	s.BytesReceived += int64(bytes)
	s.mu.Unlock()
}

// GetStats returns a snapshot of the counters, plus the elapsed time since
// StartTime formatted by time.Duration.String under the "uptime" key.
func (s *ConnectionStats) GetStats() map[string]interface{} {
	s.mu.RLock()
	defer s.mu.RUnlock()

	snapshot := make(map[string]interface{}, 5)
	snapshot["messages_sent"] = s.MessagesSent
	snapshot["messages_received"] = s.MessagesReceived
	snapshot["bytes_sent"] = s.BytesSent
	snapshot["bytes_received"] = s.BytesReceived
	snapshot["uptime"] = time.Since(s.StartTime).String()
	return snapshot
}
|
||
|
||
// ReadPump 处理从WebSocket连接读取消息
|
||
func (c *Connection) ReadPump() {
|
||
defer func() {
|
||
log.Printf("客户端 %s 的ReadPump退出", c.UID)
|
||
c.Close()
|
||
c.Room.Unregister <- c
|
||
}()
|
||
|
||
// 设置连接参数
|
||
c.Conn.SetReadLimit(c.Config.MaxMessageSize)
|
||
c.Conn.SetReadDeadline(time.Now().Add(c.Config.PongWait))
|
||
c.Conn.SetPongHandler(func(string) error {
|
||
c.Conn.SetReadDeadline(time.Now().Add(c.Config.PongWait))
|
||
return nil
|
||
})
|
||
|
||
log.Printf("启动客户端 %s 的ReadPump", c.UID)
|
||
|
||
for {
|
||
messageType, message, err := c.Conn.ReadMessage()
|
||
if err != nil {
|
||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||
log.Printf("客户端 %s 读取错误: %v", c.UID, err)
|
||
}
|
||
break
|
||
}
|
||
|
||
log.Printf("收到来自客户端 %s 的消息类型 %d (长度: %d)", c.UID, messageType, len(message))
|
||
|
||
// 处理不同类型的消息
|
||
switch messageType {
|
||
case websocket.TextMessage:
|
||
// 记录原始消息内容
|
||
log.Printf("客户端 %s 的原始消息: %s", c.UID, string(message))
|
||
|
||
// 确保消息是有效的UTF-8
|
||
if !utf8.Valid(message) {
|
||
log.Printf("客户端 %s 的消息包含无效的UTF-8序列,尝试修复", c.UID)
|
||
message = []byte(string([]rune(string(message))))
|
||
}
|
||
|
||
c.Stats.IncrementMessagesReceived(len(message))
|
||
|
||
// 处理接收到的消息
|
||
wsMsg, err := c.MessageHandler.ProcessIncomingMessage(message)
|
||
if err != nil {
|
||
log.Printf("处理客户端 %s 的消息时出错: %v", c.UID, err)
|
||
// 发送错误消息给客户端
|
||
errorMsg := &WSMessage{
|
||
Type: "error",
|
||
Content: fmt.Sprintf("处理消息时出错: %v", err),
|
||
}
|
||
select {
|
||
case c.Send <- errorMsg:
|
||
log.Printf("已发送错误消息给客户端 %s", c.UID)
|
||
default:
|
||
log.Printf("发送错误消息给客户端 %s 失败: 通道已满", c.UID)
|
||
}
|
||
continue
|
||
}
|
||
|
||
log.Printf("已处理消息: 类型=%s, 内容=%s", wsMsg.Type, wsMsg.Content)
|
||
|
||
// 发送确认消息
|
||
ack := &WSMessage{
|
||
Type: wsMsg.Type,
|
||
Ack: true,
|
||
From: c.UID,
|
||
To: wsMsg.To,
|
||
Content: wsMsg.Content,
|
||
Time: time.Now(),
|
||
}
|
||
select {
|
||
case c.Send <- ack:
|
||
log.Printf("已发送确认消息给客户端 %s", c.UID)
|
||
default:
|
||
log.Printf("发送确认消息给客户端 %s 失败: 通道已满", c.UID)
|
||
}
|
||
|
||
// 根据消息类型处理
|
||
switch wsMsg.Type {
|
||
case "broadcast":
|
||
log.Printf("广播来自客户端 %s 的消息: %s", c.UID, wsMsg.Content)
|
||
select {
|
||
case c.Room.Broadcast <- &BroadcastPayload{Message: wsMsg}:
|
||
log.Printf("已将消息加入广播队列: 来自 %s", c.UID)
|
||
default:
|
||
log.Printf("将消息加入广播队列失败: 来自 %s, 通道已满", c.UID)
|
||
}
|
||
case "direct":
|
||
log.Printf("发送来自客户端 %s 的私信给 %s: %s", c.UID, wsMsg.To, wsMsg.Content)
|
||
directMsg := &WSMessage{
|
||
Type: wsMsg.Type,
|
||
Content: wsMsg.Content,
|
||
From: c.UID,
|
||
To: wsMsg.To,
|
||
Time: time.Now(),
|
||
Data: wsMsg.Data,
|
||
}
|
||
select {
|
||
case c.Room.Direct <- &DirectPayload{To: wsMsg.To, Message: directMsg}:
|
||
log.Printf("已将私信加入队列: 从 %s 到 %s", c.UID, wsMsg.To)
|
||
default:
|
||
log.Printf("将私信加入队列失败: 从 %s 到 %s, 通道已满", c.UID, wsMsg.To)
|
||
}
|
||
case "stats":
|
||
stats := c.Stats.GetStats()
|
||
statsMsg := &WSMessage{
|
||
Type: "stats",
|
||
Content: string(must(json.Marshal(stats))),
|
||
}
|
||
select {
|
||
case c.Send <- statsMsg:
|
||
log.Printf("已发送统计信息给客户端 %s", c.UID)
|
||
default:
|
||
log.Printf("发送统计信息给客户端 %s 失败: 通道已满", c.UID)
|
||
}
|
||
case "reconnect":
|
||
c.handleReconnect()
|
||
default:
|
||
log.Printf("收到未知消息类型: 来自 %s, 类型=%s", c.UID, wsMsg.Type)
|
||
}
|
||
|
||
case websocket.BinaryMessage:
|
||
log.Printf("收到二进制消息: 来自 %s, 长度=%d", c.UID, len(message))
|
||
c.Stats.IncrementMessagesReceived(len(message))
|
||
|
||
case websocket.PingMessage:
|
||
log.Printf("收到ping: 来自 %s", c.UID)
|
||
if err := c.Conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(c.Config.WriteWait)); err != nil {
|
||
log.Printf("发送pong给客户端 %s 失败: %v", c.UID, err)
|
||
}
|
||
|
||
case websocket.PongMessage:
|
||
log.Printf("收到pong: 来自 %s", c.UID)
|
||
|
||
case websocket.CloseMessage:
|
||
log.Printf("收到关闭消息: 来自 %s", c.UID)
|
||
return
|
||
|
||
default:
|
||
log.Printf("收到未知消息类型 %d: 来自 %s", messageType, c.UID)
|
||
}
|
||
}
|
||
}
|
||
|
||
// WritePump 处理向WebSocket连接写入消息
|
||
func (c *Connection) WritePump() {
|
||
ticker := time.NewTicker(c.Config.PingInterval)
|
||
defer func() {
|
||
ticker.Stop()
|
||
c.Close()
|
||
}()
|
||
|
||
// 启动消息队列处理器
|
||
c.Queue.Start(func(msg *WSMessage) {
|
||
c.mu.Lock()
|
||
if c.closed {
|
||
c.mu.Unlock()
|
||
return
|
||
}
|
||
c.mu.Unlock()
|
||
|
||
// 处理发送的消息
|
||
data, err := c.MessageHandler.ProcessOutgoingMessage(msg)
|
||
if err != nil {
|
||
log.Printf("处理发送的消息时出错: %v", err)
|
||
return
|
||
}
|
||
|
||
c.mu.Lock()
|
||
if c.closed {
|
||
c.mu.Unlock()
|
||
return
|
||
}
|
||
c.Conn.SetWriteDeadline(time.Now().Add(c.Config.WriteWait))
|
||
if err := c.Conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
||
log.Printf("写入错误: %v", err)
|
||
c.mu.Unlock()
|
||
c.Close()
|
||
return
|
||
}
|
||
c.mu.Unlock()
|
||
|
||
c.Stats.IncrementMessagesSent(len(data))
|
||
})
|
||
|
||
for {
|
||
select {
|
||
case message, ok := <-c.Send:
|
||
c.mu.Lock()
|
||
if !ok || c.closed {
|
||
c.mu.Unlock()
|
||
return
|
||
}
|
||
c.mu.Unlock()
|
||
|
||
if !c.Queue.Push(message) {
|
||
log.Printf("连接 %s 的消息队列已满", c.UID)
|
||
c.Send <- &WSMessage{
|
||
Type: "error",
|
||
Content: "Message queue full",
|
||
}
|
||
}
|
||
|
||
case <-ticker.C:
|
||
c.mu.Lock()
|
||
if c.closed {
|
||
c.mu.Unlock()
|
||
return
|
||
}
|
||
c.Conn.SetWriteDeadline(time.Now().Add(c.Config.WriteWait))
|
||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||
c.mu.Unlock()
|
||
return
|
||
}
|
||
c.mu.Unlock()
|
||
}
|
||
}
|
||
}
|
||
|
||
func (c *Connection) handleReconnect() {
|
||
if c.reconnect != nil {
|
||
c.reconnect.Stop()
|
||
}
|
||
|
||
c.reconnect = NewReconnectManager(c.Config.ReconnectConfig)
|
||
c.reconnect.Start(func() error {
|
||
// Implement reconnection logic here
|
||
// This could involve re-establishing the WebSocket connection
|
||
// and rejoining the room
|
||
return nil
|
||
})
|
||
}
|
||
|
||
func (c *Connection) Close() {
|
||
c.mu.Lock()
|
||
defer c.mu.Unlock()
|
||
|
||
if !c.closed {
|
||
c.closed = true
|
||
if c.Queue != nil {
|
||
c.Queue.Stop()
|
||
}
|
||
if c.reconnect != nil {
|
||
c.reconnect.Stop()
|
||
}
|
||
c.Conn.Close()
|
||
close(c.Send)
|
||
}
|
||
}
|
||
|
||
// must returns v when err is nil and panics with err otherwise. Use it only
// where err cannot occur in practice (for example marshaling values of
// known-safe types).
func must[T any](v T, err error) T {
	if err == nil {
		return v
	}
	panic(err)
}
|