go-kratos/internal/websocket/connection.go

425 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package websocket
import (
"context"
"encoding/json"
"fmt"
"log"
"sync"
"time"
"unicode/utf8"
"github.com/gorilla/websocket"
"golang.org/x/time/rate"
)
// MessageQueue is a bounded FIFO of outgoing messages that a single worker
// goroutine (see Start) drains at a limited rate.
type MessageQueue struct {
	queue      chan *WSMessage
	limiter    *rate.Limiter
	maxSize    int
	mu         sync.RWMutex
	processing bool
	// closed is set by Stop. Once true, queue has been closed and Push rejects
	// every message. It is guarded by mu.
	closed bool
}

// NewMessageQueue creates a queue holding at most maxSize pending messages and
// releasing them at up to rateLimit messages per second.
//
// A rateLimit <= 0 disables rate limiting. The limiter burst is always at
// least 1: with a burst of 0 (what int(rateLimit) yields for rates below 1)
// every limiter.Wait call fails, and the worker would drop every message.
// A negative maxSize is treated as 0 instead of panicking in make.
func NewMessageQueue(maxSize int, rateLimit float64) *MessageQueue {
	if maxSize < 0 {
		maxSize = 0
	}
	limit := rate.Limit(rateLimit)
	if rateLimit <= 0 {
		limit = rate.Inf
	}
	burst := int(rateLimit)
	if burst < 1 {
		burst = 1
	}
	return &MessageQueue{
		queue:   make(chan *WSMessage, maxSize),
		limiter: rate.NewLimiter(limit, burst),
		maxSize: maxSize,
	}
}

// Push enqueues msg without blocking. It reports false when the queue is full
// or has been stopped.
func (mq *MessageQueue) Push(msg *WSMessage) bool {
	// Hold the read lock so Stop cannot close the channel mid-send. A send on
	// a closed channel panics even inside a select with a default case.
	mq.mu.RLock()
	defer mq.mu.RUnlock()
	if mq.closed {
		return false
	}
	select {
	case mq.queue <- msg:
		return true
	default:
		return false
	}
}

// Start launches the worker goroutine. The worker waits on the rate limiter
// before passing each queued message to handler. It is a no-op if the worker
// is already running or the queue has been stopped. The worker exits after
// Stop, once the messages already buffered have been handled.
func (mq *MessageQueue) Start(handler func(*WSMessage)) {
	mq.mu.Lock()
	if mq.processing || mq.closed {
		mq.mu.Unlock()
		return
	}
	mq.processing = true
	mq.mu.Unlock()

	go func() {
		for msg := range mq.queue {
			if err := mq.limiter.Wait(context.Background()); err != nil {
				log.Printf("Rate limit error: %v", err)
				continue
			}
			handler(msg)
		}
	}()
}

// Stop closes the queue permanently. Later Push calls return false, and later
// Start calls do nothing. Stop is idempotent, so calling it more than once
// never double-closes the channel.
func (mq *MessageQueue) Stop() {
	mq.mu.Lock()
	defer mq.mu.Unlock()
	if mq.closed {
		return
	}
	mq.closed = true
	mq.processing = false
	close(mq.queue)
}
// ConnectionConfig holds the per-connection settings used by ReadPump,
// WritePump and handleReconnect. DefaultConnectionConfig returns a populated
// instance.
type ConnectionConfig struct {
ReadBufferSize int // NOTE(review): not read in this file; presumably the upgrader's read buffer size — confirm
WriteBufferSize int // NOTE(review): not read in this file; presumably the upgrader's write buffer size — confirm
PingInterval time.Duration // WritePump sends a ping frame at this interval
PongWait time.Duration // read deadline set by ReadPump; each received pong extends it by this amount
WriteWait time.Duration // deadline applied to each write on the connection
MaxMessageSize int64 // largest incoming message accepted, passed to Conn.SetReadLimit
RateLimit float64 // messages per second; NOTE(review): not read in this file, presumably fed to NewMessageQueue — confirm
QueueSize int // NOTE(review): not read in this file, presumably NewMessageQueue's maxSize — confirm
CompressionEnabled bool // NOTE(review): not read in this file
EncryptionEnabled bool // NOTE(review): not read in this file
EncryptionKey []byte // NOTE(review): not read in this file; unset by DefaultConnectionConfig
ReconnectConfig *ReconnectConfig // handleReconnect builds its ReconnectManager from this
}
// DefaultConnectionConfig returns a newly allocated ConnectionConfig filled
// with the package defaults. Each call returns a fresh value, so callers may
// modify it freely.
func DefaultConnectionConfig() *ConnectionConfig {
	// Pinging at 9/10 of the pong deadline lets a responsive peer's pong
	// arrive before the read deadline derived from PongWait expires.
	const pongWait = 60 * time.Second

	cfg := &ConnectionConfig{
		ReadBufferSize:     1024,
		WriteBufferSize:    1024,
		PingInterval:       pongWait * 9 / 10, // 54s
		PongWait:           pongWait,
		WriteWait:          10 * time.Second,
		MaxMessageSize:     512 << 10, // 512 KiB
		RateLimit:          100,       // messages per second
		QueueSize:          1000,      // queued messages
		CompressionEnabled: true,
		EncryptionEnabled:  false,
		ReconnectConfig:    DefaultReconnectConfig(),
	}
	return cfg
}
// Connection represents one client's WebSocket connection. ReadPump reads from
// Conn, WritePump writes to it, and Close tears the connection down once.
type Connection struct {
UID string // unique user identifier; used in log messages and as the sender of acks/direct messages
Room *Room // room this connection belongs to; ReadPump forwards broadcast/direct messages to it and unregisters on exit
Conn *websocket.Conn // underlying WebSocket connection
Send chan *WSMessage // outgoing messages; WritePump moves them into Queue; closed by Close
Config *ConnectionConfig // connection settings (deadlines, limits, reconnect config)
mu sync.Mutex // guards closed; held around WritePump's writes to Conn and throughout Close
closed bool // set once by Close
Stats *ConnectionStats // traffic counters
Queue *MessageQueue // outgoing message queue started by WritePump and stopped by Close
MessageHandler *MessageHandler // decodes incoming and encodes outgoing messages
reconnect *ReconnectManager // current reconnect manager; replaced by handleReconnect, stopped by Close
}
// ConnectionStats holds per-connection traffic counters. All fields are
// guarded by mu. Use IncrementMessagesSent, IncrementMessagesReceived and
// GetStats instead of accessing the fields directly.
type ConnectionStats struct {
MessagesSent int64 // number of messages sent
MessagesReceived int64 // number of messages received
BytesSent int64 // total bytes sent
BytesReceived int64 // total bytes received
StartTime time.Time // when the stats object was created; GetStats reports uptime from it
mu sync.RWMutex // read/write lock protecting the fields above
}
// NewConnectionStats returns a stats object with all counters at zero and its
// uptime clock starting now.
func NewConnectionStats() *ConnectionStats {
	stats := new(ConnectionStats)
	stats.StartTime = time.Now()
	return stats
}
// IncrementMessagesSent records one outgoing message of the given size in bytes.
func (cs *ConnectionStats) IncrementMessagesSent(bytes int) {
	cs.mu.Lock()
	cs.BytesSent += int64(bytes)
	cs.MessagesSent++
	cs.mu.Unlock()
}
// IncrementMessagesReceived records one incoming message of the given size in bytes.
func (cs *ConnectionStats) IncrementMessagesReceived(bytes int) {
	cs.mu.Lock()
	cs.BytesReceived += int64(bytes)
	cs.MessagesReceived++
	cs.mu.Unlock()
}
// GetStats returns a snapshot of the counters with snake_case keys. It also
// includes "uptime": the time elapsed since StartTime, formatted with
// time.Duration.String.
func (cs *ConnectionStats) GetStats() map[string]interface{} {
	cs.mu.RLock()
	defer cs.mu.RUnlock()

	snapshot := make(map[string]interface{}, 5)
	snapshot["messages_sent"] = cs.MessagesSent
	snapshot["messages_received"] = cs.MessagesReceived
	snapshot["bytes_sent"] = cs.BytesSent
	snapshot["bytes_received"] = cs.BytesReceived
	snapshot["uptime"] = time.Since(cs.StartTime).String()
	return snapshot
}
// ReadPump reads frames from c.Conn until a read fails. It then closes the
// connection and unregisters it from its room, in that order.
//
// Text frames are decoded with MessageHandler.ProcessIncomingMessage and
// acknowledged to the sender through c.Send. They are then routed by type:
//   - "broadcast" and "direct" are forwarded to the room's channels.
//   - "stats" replies with this connection's counters.
//   - "reconnect" restarts the reconnect manager.
//
// Binary frames are only counted.
//
// gorilla/websocket supports a single concurrent reader, so run exactly one
// ReadPump per connection, in its own goroutine.
func (c *Connection) ReadPump() {
	defer func() {
		log.Printf("客户端 %s 的ReadPump退出", c.UID)
		c.Close()
		c.Room.Unregister <- c
	}()

	// trySend makes a non-blocking send on c.Send and reports whether the
	// message was accepted. Close closes c.Send while holding c.mu, and a send
	// on a closed channel panics even inside a select with a default case.
	// The closed check and the send therefore happen under the same lock.
	trySend := func(msg *WSMessage) bool {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.closed {
			return false
		}
		select {
		case c.Send <- msg:
			return true
		default:
			return false
		}
	}

	// Enforce the size limit and a read deadline. Every pong pushes the
	// deadline forward, so a peer that stops answering pings times out.
	c.Conn.SetReadLimit(c.Config.MaxMessageSize)
	c.Conn.SetReadDeadline(time.Now().Add(c.Config.PongWait))
	c.Conn.SetPongHandler(func(string) error {
		c.Conn.SetReadDeadline(time.Now().Add(c.Config.PongWait))
		return nil
	})

	log.Printf("启动客户端 %s 的ReadPump", c.UID)
	for {
		messageType, message, err := c.Conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("客户端 %s 读取错误: %v", c.UID, err)
			}
			return
		}
		log.Printf("收到来自客户端 %s 的消息类型 %d (长度: %d)", c.UID, messageType, len(message))

		// ReadMessage only returns data frames (text or binary). Ping, pong and
		// close frames are consumed by the connection's control handlers, so
		// this switch has no cases for them.
		switch messageType {
		case websocket.TextMessage:
			// %q escapes client-controlled bytes (for example newlines) so a
			// client cannot forge log lines.
			log.Printf("客户端 %s 的原始消息: %q", c.UID, message)
			if !utf8.Valid(message) {
				log.Printf("客户端 %s 的消息包含无效的UTF-8序列尝试修复", c.UID)
				// Round-tripping through []rune replaces each invalid byte with U+FFFD.
				message = []byte(string([]rune(string(message))))
			}
			c.Stats.IncrementMessagesReceived(len(message))

			wsMsg, err := c.MessageHandler.ProcessIncomingMessage(message)
			if err != nil {
				log.Printf("处理客户端 %s 的消息时出错: %v", c.UID, err)
				errorMsg := &WSMessage{
					Type:    "error",
					Content: fmt.Sprintf("处理消息时出错: %v", err),
				}
				if trySend(errorMsg) {
					log.Printf("已发送错误消息给客户端 %s", c.UID)
				} else {
					log.Printf("发送错误消息给客户端 %s 失败: 通道已满或连接已关闭", c.UID)
				}
				continue
			}
			log.Printf("已处理消息: 类型=%s, 内容=%q", wsMsg.Type, wsMsg.Content)

			// Acknowledge receipt to the sender.
			ack := &WSMessage{
				Type:    wsMsg.Type,
				Ack:     true,
				From:    c.UID,
				To:      wsMsg.To,
				Content: wsMsg.Content,
				Time:    time.Now(),
			}
			if trySend(ack) {
				log.Printf("已发送确认消息给客户端 %s", c.UID)
			} else {
				log.Printf("发送确认消息给客户端 %s 失败: 通道已满或连接已关闭", c.UID)
			}

			switch wsMsg.Type {
			case "broadcast":
				log.Printf("广播来自客户端 %s 的消息: %q", c.UID, wsMsg.Content)
				// Stamp the authenticated sender, as the ack and direct paths do,
				// so a client cannot broadcast under another user's ID.
				wsMsg.From = c.UID
				select {
				case c.Room.Broadcast <- &BroadcastPayload{Message: wsMsg}:
					log.Printf("已将消息加入广播队列: 来自 %s", c.UID)
				default:
					log.Printf("将消息加入广播队列失败: 来自 %s, 通道已满", c.UID)
				}
			case "direct":
				log.Printf("发送来自客户端 %s 的私信给 %s: %q", c.UID, wsMsg.To, wsMsg.Content)
				directMsg := &WSMessage{
					Type:    wsMsg.Type,
					Content: wsMsg.Content,
					From:    c.UID,
					To:      wsMsg.To,
					Time:    time.Now(),
					Data:    wsMsg.Data,
				}
				select {
				case c.Room.Direct <- &DirectPayload{To: wsMsg.To, Message: directMsg}:
					log.Printf("已将私信加入队列: 从 %s 到 %s", c.UID, wsMsg.To)
				default:
					log.Printf("将私信加入队列失败: 从 %s 到 %s, 通道已满", c.UID, wsMsg.To)
				}
			case "stats":
				// Marshaling a map of int64s and strings cannot fail, so must is safe here.
				statsMsg := &WSMessage{
					Type:    "stats",
					Content: string(must(json.Marshal(c.Stats.GetStats()))),
				}
				if trySend(statsMsg) {
					log.Printf("已发送统计信息给客户端 %s", c.UID)
				} else {
					log.Printf("发送统计信息给客户端 %s 失败: 通道已满或连接已关闭", c.UID)
				}
			case "reconnect":
				c.handleReconnect()
			default:
				log.Printf("收到未知消息类型: 来自 %s, 类型=%s", c.UID, wsMsg.Type)
			}
		case websocket.BinaryMessage:
			log.Printf("收到二进制消息: 来自 %s, 长度=%d", c.UID, len(message))
			c.Stats.IncrementMessagesReceived(len(message))
		default:
			log.Printf("收到未知消息类型 %d: 来自 %s", messageType, c.UID)
		}
	}
}
// WritePump is the connection's writer. It starts the rate-limited message
// queue, moves messages from c.Send into that queue, and sends a ping every
// Config.PingInterval. It returns, closing the connection, when c.Send is
// closed, the connection is marked closed, or a ping write fails.
//
// Every WriteMessage on c.Conn (queued messages and pings) happens while
// holding c.mu. This keeps the queue worker and the ping ticker from writing
// concurrently, as gorilla/websocket requires.
func (c *Connection) WritePump() {
	ticker := time.NewTicker(c.Config.PingInterval)
	defer func() {
		ticker.Stop()
		c.Close()
	}()

	// The queue worker encodes and writes each message it releases.
	c.Queue.Start(func(msg *WSMessage) {
		c.mu.Lock()
		closed := c.closed
		c.mu.Unlock()
		if closed {
			return
		}

		// Encode outside the lock so a slow encoder does not stall pings or Close.
		data, err := c.MessageHandler.ProcessOutgoingMessage(msg)
		if err != nil {
			log.Printf("处理发送的消息时出错: %v", err)
			return
		}

		c.mu.Lock()
		if c.closed {
			c.mu.Unlock()
			return
		}
		c.Conn.SetWriteDeadline(time.Now().Add(c.Config.WriteWait))
		err = c.Conn.WriteMessage(websocket.TextMessage, data)
		c.mu.Unlock()
		if err != nil {
			log.Printf("写入错误: %v", err)
			// Close acquires c.mu itself, so call it only after unlocking.
			c.Close()
			return
		}
		c.Stats.IncrementMessagesSent(len(data))
	})

	for {
		select {
		case message, ok := <-c.Send:
			c.mu.Lock()
			if !ok || c.closed {
				c.mu.Unlock()
				return
			}
			// Push while still holding c.mu. Close stops (and closes) the queue
			// under this same lock, so the queue cannot be closed mid-push.
			pushed := c.Queue.Push(message)
			c.mu.Unlock()
			if !pushed {
				// Drop the message instead of reporting back through c.Send.
				// This loop is the only reader of c.Send in this file, so a
				// blocking send here could deadlock when c.Send is full. It
				// would also panic if Close had closed c.Send, and it would keep
				// re-entering this branch while the queue stays full.
				log.Printf("连接 %s 的消息队列已满, 消息已丢弃", c.UID)
			}
		case <-ticker.C:
			c.mu.Lock()
			if c.closed {
				c.mu.Unlock()
				return
			}
			c.Conn.SetWriteDeadline(time.Now().Add(c.Config.WriteWait))
			err := c.Conn.WriteMessage(websocket.PingMessage, nil)
			c.mu.Unlock()
			if err != nil {
				log.Printf("发送ping给客户端 %s 失败: %v", c.UID, err)
				return
			}
		}
	}
}
// handleReconnect replaces the connection's ReconnectManager with a new one
// built from c.Config.ReconnectConfig and starts it, stopping any previous
// manager first. It does nothing once the connection is closed, so a manager
// started here never outlives Close.
//
// Close reads c.reconnect under c.mu, and this method runs on the ReadPump
// goroutine. The field is therefore swapped under the same lock to avoid a
// data race with a concurrent Close.
// NOTE(review): Stop and Start are called while holding c.mu. Close already
// calls Stop under c.mu; this additionally assumes ReconnectManager.Start does
// not block — confirm.
func (c *Connection) handleReconnect() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return
	}
	if c.reconnect != nil {
		c.reconnect.Stop()
	}
	c.reconnect = NewReconnectManager(c.Config.ReconnectConfig)
	c.reconnect.Start(func() error {
		// TODO: implement reconnection. This could re-establish the WebSocket
		// connection and rejoin the room.
		return nil
	})
}
// Close shuts the connection down exactly once; later calls are no-ops. It:
//   - stops the message queue and any reconnect manager,
//   - closes the underlying WebSocket,
//   - closes c.Send, which makes WritePump's receive loop return.
//
// Close acquires c.mu, so it must not be called while c.mu is held.
//
// NOTE(review): because Close closes c.Send, any other goroutine that sends on
// c.Send (for example the Room, not visible in this file) will panic if it
// sends after Close — confirm those senders coordinate with Close.
func (c *Connection) Close() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return
	}
	c.closed = true

	if q := c.Queue; q != nil {
		q.Stop()
	}
	if rm := c.reconnect; rm != nil {
		rm.Stop()
	}
	_ = c.Conn.Close() // best-effort: the connection is being discarded either way
	close(c.Send)
}
// must unwraps a (value, error) pair and panics if err is non-nil. Use it only
// where a non-nil error would indicate a programming bug, such as marshaling
// a map that holds only integers and strings to JSON.
func must[T any](v T, err error) T {
	if err == nil {
		return v
	}
	panic(err)
}