252 lines
6.1 KiB
Go
252 lines
6.1 KiB
Go
package websocket
|
||
|
||
import (
|
||
"bytes"
|
||
"compress/gzip"
|
||
"crypto/aes"
|
||
"crypto/cipher"
|
||
"crypto/rand"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"sync"
|
||
"time"
|
||
"unicode/utf8"
|
||
)
|
||
|
||
// WSMessage 定义WebSocket消息的结构
|
||
type WSMessage struct {
|
||
Type string `json:"type"` // 消息类型:broadcast/direct/system
|
||
Content string `json:"content"` // 消息内容
|
||
To string `json:"to"` // 接收者ID(用于私信)
|
||
From string `json:"from"` // 发送者ID
|
||
Time time.Time `json:"time"` // 消息发送时间
|
||
Data interface{} `json:"data"` // 额外数据
|
||
Ack bool `json:"ack,omitempty"` // 消息确认标志
|
||
}
|
||
|
||
// MessageHandler 处理WebSocket消息的处理器
|
||
type MessageHandler struct {
|
||
compressionEnabled bool // 是否启用压缩
|
||
encryptionEnabled bool // 是否启用加密
|
||
encryptionKey []byte // 加密密钥
|
||
mu sync.RWMutex // 读写锁
|
||
}
|
||
|
||
// NewMessageHandler 创建新的消息处理器
|
||
func NewMessageHandler(compressionEnabled bool, encryptionEnabled bool, encryptionKey []byte) *MessageHandler {
|
||
return &MessageHandler{
|
||
compressionEnabled: compressionEnabled,
|
||
encryptionEnabled: encryptionKey != nil,
|
||
encryptionKey: encryptionKey,
|
||
}
|
||
}
|
||
|
||
// ProcessOutgoingMessage 处理发送前的消息
|
||
func (h *MessageHandler) ProcessOutgoingMessage(msg *WSMessage) ([]byte, error) {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
// 确保消息内容是有效的 UTF-8
|
||
if !utf8.ValidString(msg.Content) {
|
||
msg.Content = string([]rune(msg.Content))
|
||
}
|
||
|
||
// 序列化消息
|
||
data, err := json.Marshal(msg)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("消息序列化失败: %v", err)
|
||
}
|
||
|
||
// 压缩消息(如果启用)
|
||
if h.compressionEnabled {
|
||
data, err = h.compress(data)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("消息压缩失败: %v", err)
|
||
}
|
||
}
|
||
|
||
// 加密消息(如果启用)
|
||
if h.encryptionEnabled {
|
||
data, err = h.encrypt(data)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("消息加密失败: %v", err)
|
||
}
|
||
}
|
||
|
||
return data, nil
|
||
}
|
||
|
||
// ProcessIncomingMessage 处理接收到的消息
|
||
func (h *MessageHandler) ProcessIncomingMessage(data []byte) (*WSMessage, error) {
|
||
h.mu.Lock()
|
||
defer h.mu.Unlock()
|
||
|
||
var err error
|
||
|
||
// 如果启用了加密,先解密消息
|
||
if h.encryptionEnabled {
|
||
data, err = h.decrypt(data)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("消息解密失败: %v", err)
|
||
}
|
||
}
|
||
|
||
// 如果启用了压缩,解压消息
|
||
if h.compressionEnabled {
|
||
data, err = h.decompress(data)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("消息解压缩失败: %v", err)
|
||
}
|
||
}
|
||
|
||
var msg WSMessage
|
||
if err := json.Unmarshal(data, &msg); err != nil {
|
||
return nil, fmt.Errorf("消息反序列化失败: %v", err)
|
||
}
|
||
|
||
// 验证消息
|
||
if err := h.validateMessage(&msg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &msg, nil
|
||
}
|
||
|
||
// validateMessage 验证消息的有效性
|
||
func (h *MessageHandler) validateMessage(msg *WSMessage) error {
|
||
// 确保消息内容是有效的 UTF-8
|
||
if !utf8.ValidString(msg.Content) {
|
||
msg.Content = string([]rune(msg.Content))
|
||
}
|
||
|
||
// 验证消息类型
|
||
if msg.Type == "" {
|
||
return errors.New("消息类型不能为空")
|
||
}
|
||
|
||
// 验证消息内容
|
||
if msg.Content == "" {
|
||
return errors.New("消息内容不能为空")
|
||
}
|
||
|
||
// 验证私信接收者
|
||
if msg.Type == "direct" && msg.To == "" {
|
||
return errors.New("私信接收者不能为空")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// compress 压缩数据
|
||
func (h *MessageHandler) compress(data []byte) ([]byte, error) {
|
||
var buf bytes.Buffer
|
||
zw := gzip.NewWriter(&buf)
|
||
if _, err := zw.Write(data); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := zw.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
return buf.Bytes(), nil
|
||
}
|
||
|
||
// decompress 解压数据
|
||
func (h *MessageHandler) decompress(data []byte) ([]byte, error) {
|
||
zr, err := gzip.NewReader(bytes.NewReader(data))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer zr.Close()
|
||
return io.ReadAll(zr)
|
||
}
|
||
|
||
// encrypt 加密数据
|
||
func (h *MessageHandler) encrypt(data []byte) ([]byte, error) {
|
||
block, err := aes.NewCipher(h.encryptionKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
gcm, err := cipher.NewGCM(block)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
nonce := make([]byte, gcm.NonceSize())
|
||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
ciphertext := gcm.Seal(nonce, nonce, data, nil)
|
||
return ciphertext, nil
|
||
}
|
||
|
||
// decrypt 解密数据
|
||
func (h *MessageHandler) decrypt(data []byte) ([]byte, error) {
|
||
block, err := aes.NewCipher(h.encryptionKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
gcm, err := cipher.NewGCM(block)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
nonceSize := gcm.NonceSize()
|
||
if len(data) < nonceSize {
|
||
return nil, errors.New("密文太短")
|
||
}
|
||
|
||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return plaintext, nil
|
||
}
|
||
|
||
// MessageStats 跟踪消息处理统计信息
|
||
type MessageStats struct {
|
||
CompressedSize int64 // 压缩后的大小
|
||
UncompressedSize int64 // 未压缩的大小
|
||
EncryptedSize int64 // 加密后的大小
|
||
DecryptedSize int64 // 解密后的大小
|
||
ProcessingTime time.Duration // 处理时间
|
||
mu sync.RWMutex // 读写锁
|
||
}
|
||
|
||
// UpdateStats 更新消息统计信息
|
||
func (s *MessageStats) UpdateStats(compressed, encrypted bool, sizes ...int64) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
if compressed {
|
||
s.CompressedSize += sizes[0]
|
||
s.UncompressedSize += sizes[1]
|
||
}
|
||
if encrypted {
|
||
s.EncryptedSize += sizes[2]
|
||
s.DecryptedSize += sizes[3]
|
||
}
|
||
s.ProcessingTime = time.Since(time.Now())
|
||
}
|
||
|
||
// GetStats 获取消息统计信息
|
||
func (s *MessageStats) GetStats() map[string]interface{} {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
return map[string]interface{}{
|
||
"compressed_size": s.CompressedSize,
|
||
"uncompressed_size": s.UncompressedSize,
|
||
"encrypted_size": s.EncryptedSize,
|
||
"decrypted_size": s.DecryptedSize,
|
||
"compression_ratio": float64(s.CompressedSize) / float64(s.UncompressedSize),
|
||
"processing_time_ms": s.ProcessingTime.Milliseconds(),
|
||
}
|
||
}
|