go-kratos/internal/websocket/message.go

252 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package websocket
import (
"bytes"
"compress/gzip"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"sync"
"time"
"unicode/utf8"
)
// WSMessage 定义WebSocket消息的结构
type WSMessage struct {
Type string `json:"type"` // 消息类型broadcast/direct/system
Content string `json:"content"` // 消息内容
To string `json:"to"` // 接收者ID用于私信
From string `json:"from"` // 发送者ID
Time time.Time `json:"time"` // 消息发送时间
Data interface{} `json:"data"` // 额外数据
Ack bool `json:"ack,omitempty"` // 消息确认标志
}
// MessageHandler 处理WebSocket消息的处理器
type MessageHandler struct {
compressionEnabled bool // 是否启用压缩
encryptionEnabled bool // 是否启用加密
encryptionKey []byte // 加密密钥
mu sync.RWMutex // 读写锁
}
// NewMessageHandler 创建新的消息处理器
func NewMessageHandler(compressionEnabled bool, encryptionEnabled bool, encryptionKey []byte) *MessageHandler {
return &MessageHandler{
compressionEnabled: compressionEnabled,
encryptionEnabled: encryptionKey != nil,
encryptionKey: encryptionKey,
}
}
// ProcessOutgoingMessage 处理发送前的消息
func (h *MessageHandler) ProcessOutgoingMessage(msg *WSMessage) ([]byte, error) {
h.mu.Lock()
defer h.mu.Unlock()
// 确保消息内容是有效的 UTF-8
if !utf8.ValidString(msg.Content) {
msg.Content = string([]rune(msg.Content))
}
// 序列化消息
data, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("消息序列化失败: %v", err)
}
// 压缩消息(如果启用)
if h.compressionEnabled {
data, err = h.compress(data)
if err != nil {
return nil, fmt.Errorf("消息压缩失败: %v", err)
}
}
// 加密消息(如果启用)
if h.encryptionEnabled {
data, err = h.encrypt(data)
if err != nil {
return nil, fmt.Errorf("消息加密失败: %v", err)
}
}
return data, nil
}
// ProcessIncomingMessage 处理接收到的消息
func (h *MessageHandler) ProcessIncomingMessage(data []byte) (*WSMessage, error) {
h.mu.Lock()
defer h.mu.Unlock()
var err error
// 如果启用了加密,先解密消息
if h.encryptionEnabled {
data, err = h.decrypt(data)
if err != nil {
return nil, fmt.Errorf("消息解密失败: %v", err)
}
}
// 如果启用了压缩,解压消息
if h.compressionEnabled {
data, err = h.decompress(data)
if err != nil {
return nil, fmt.Errorf("消息解压缩失败: %v", err)
}
}
var msg WSMessage
if err := json.Unmarshal(data, &msg); err != nil {
return nil, fmt.Errorf("消息反序列化失败: %v", err)
}
// 验证消息
if err := h.validateMessage(&msg); err != nil {
return nil, err
}
return &msg, nil
}
// validateMessage 验证消息的有效性
func (h *MessageHandler) validateMessage(msg *WSMessage) error {
// 确保消息内容是有效的 UTF-8
if !utf8.ValidString(msg.Content) {
msg.Content = string([]rune(msg.Content))
}
// 验证消息类型
if msg.Type == "" {
return errors.New("消息类型不能为空")
}
// 验证消息内容
if msg.Content == "" {
return errors.New("消息内容不能为空")
}
// 验证私信接收者
if msg.Type == "direct" && msg.To == "" {
return errors.New("私信接收者不能为空")
}
return nil
}
// compress 压缩数据
func (h *MessageHandler) compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
zw := gzip.NewWriter(&buf)
if _, err := zw.Write(data); err != nil {
return nil, err
}
if err := zw.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// decompress 解压数据
func (h *MessageHandler) decompress(data []byte) ([]byte, error) {
zr, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, err
}
defer zr.Close()
return io.ReadAll(zr)
}
// encrypt 加密数据
func (h *MessageHandler) encrypt(data []byte) ([]byte, error) {
block, err := aes.NewCipher(h.encryptionKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
ciphertext := gcm.Seal(nonce, nonce, data, nil)
return ciphertext, nil
}
// decrypt 解密数据
func (h *MessageHandler) decrypt(data []byte) ([]byte, error) {
block, err := aes.NewCipher(h.encryptionKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("密文太短")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
// MessageStats 跟踪消息处理统计信息
type MessageStats struct {
CompressedSize int64 // 压缩后的大小
UncompressedSize int64 // 未压缩的大小
EncryptedSize int64 // 加密后的大小
DecryptedSize int64 // 解密后的大小
ProcessingTime time.Duration // 处理时间
mu sync.RWMutex // 读写锁
}
// UpdateStats 更新消息统计信息
func (s *MessageStats) UpdateStats(compressed, encrypted bool, sizes ...int64) {
s.mu.Lock()
defer s.mu.Unlock()
if compressed {
s.CompressedSize += sizes[0]
s.UncompressedSize += sizes[1]
}
if encrypted {
s.EncryptedSize += sizes[2]
s.DecryptedSize += sizes[3]
}
s.ProcessingTime = time.Since(time.Now())
}
// GetStats 获取消息统计信息
func (s *MessageStats) GetStats() map[string]interface{} {
s.mu.RLock()
defer s.mu.RUnlock()
return map[string]interface{}{
"compressed_size": s.CompressedSize,
"uncompressed_size": s.UncompressedSize,
"encrypted_size": s.EncryptedSize,
"decrypted_size": s.DecryptedSize,
"compression_ratio": float64(s.CompressedSize) / float64(s.UncompressedSize),
"processing_time_ms": s.ProcessingTime.Milliseconds(),
}
}