package websocket import ( "bytes" "compress/gzip" "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/json" "errors" "fmt" "io" "sync" "time" "unicode/utf8" ) // WSMessage 定义WebSocket消息的结构 type WSMessage struct { Type string `json:"type"` // 消息类型:broadcast/direct/system Content string `json:"content"` // 消息内容 To string `json:"to"` // 接收者ID(用于私信) From string `json:"from"` // 发送者ID Time time.Time `json:"time"` // 消息发送时间 Data interface{} `json:"data"` // 额外数据 Ack bool `json:"ack,omitempty"` // 消息确认标志 } // MessageHandler 处理WebSocket消息的处理器 type MessageHandler struct { compressionEnabled bool // 是否启用压缩 encryptionEnabled bool // 是否启用加密 encryptionKey []byte // 加密密钥 mu sync.RWMutex // 读写锁 } // NewMessageHandler 创建新的消息处理器 func NewMessageHandler(compressionEnabled bool, encryptionEnabled bool, encryptionKey []byte) *MessageHandler { return &MessageHandler{ compressionEnabled: compressionEnabled, encryptionEnabled: encryptionKey != nil, encryptionKey: encryptionKey, } } // ProcessOutgoingMessage 处理发送前的消息 func (h *MessageHandler) ProcessOutgoingMessage(msg *WSMessage) ([]byte, error) { h.mu.Lock() defer h.mu.Unlock() // 确保消息内容是有效的 UTF-8 if !utf8.ValidString(msg.Content) { msg.Content = string([]rune(msg.Content)) } // 序列化消息 data, err := json.Marshal(msg) if err != nil { return nil, fmt.Errorf("消息序列化失败: %v", err) } // 压缩消息(如果启用) if h.compressionEnabled { data, err = h.compress(data) if err != nil { return nil, fmt.Errorf("消息压缩失败: %v", err) } } // 加密消息(如果启用) if h.encryptionEnabled { data, err = h.encrypt(data) if err != nil { return nil, fmt.Errorf("消息加密失败: %v", err) } } return data, nil } // ProcessIncomingMessage 处理接收到的消息 func (h *MessageHandler) ProcessIncomingMessage(data []byte) (*WSMessage, error) { h.mu.Lock() defer h.mu.Unlock() var err error // 如果启用了加密,先解密消息 if h.encryptionEnabled { data, err = h.decrypt(data) if err != nil { return nil, fmt.Errorf("消息解密失败: %v", err) } } // 如果启用了压缩,解压消息 if h.compressionEnabled { data, err = h.decompress(data) if err != nil { return nil, fmt.Errorf("消息解压缩失败: %v", err) } } var msg WSMessage if err := json.Unmarshal(data, &msg); err != nil { return nil, fmt.Errorf("消息反序列化失败: %v", err) } // 验证消息 if err := h.validateMessage(&msg); err != nil { return nil, err } return &msg, nil } // validateMessage 验证消息的有效性 func (h *MessageHandler) validateMessage(msg *WSMessage) error { // 确保消息内容是有效的 UTF-8 if !utf8.ValidString(msg.Content) { msg.Content = string([]rune(msg.Content)) } // 验证消息类型 if msg.Type == "" { return errors.New("消息类型不能为空") } // 验证消息内容 if msg.Content == "" { return errors.New("消息内容不能为空") } // 验证私信接收者 if msg.Type == "direct" && msg.To == "" { return errors.New("私信接收者不能为空") } return nil } // compress 压缩数据 func (h *MessageHandler) compress(data []byte) ([]byte, error) { var buf bytes.Buffer zw := gzip.NewWriter(&buf) if _, err := zw.Write(data); err != nil { return nil, err } if err := zw.Close(); err != nil { return nil, err } return buf.Bytes(), nil } // decompress 解压数据 func (h *MessageHandler) decompress(data []byte) ([]byte, error) { zr, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return nil, err } defer zr.Close() return io.ReadAll(zr) } // encrypt 加密数据 func (h *MessageHandler) encrypt(data []byte) ([]byte, error) { block, err := aes.NewCipher(h.encryptionKey) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } ciphertext := gcm.Seal(nonce, nonce, data, nil) return ciphertext, nil } // decrypt 解密数据 func (h *MessageHandler) decrypt(data []byte) ([]byte, error) { block, err := aes.NewCipher(h.encryptionKey) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonceSize := gcm.NonceSize() if len(data) < nonceSize { return nil, errors.New("密文太短") } nonce, ciphertext := data[:nonceSize], data[nonceSize:] plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, err } return plaintext, nil } // MessageStats 跟踪消息处理统计信息 type MessageStats struct { CompressedSize int64 // 压缩后的大小 UncompressedSize int64 // 未压缩的大小 EncryptedSize int64 // 加密后的大小 DecryptedSize int64 // 解密后的大小 ProcessingTime time.Duration // 处理时间 mu sync.RWMutex // 读写锁 } // UpdateStats 更新消息统计信息 func (s *MessageStats) UpdateStats(compressed, encrypted bool, sizes ...int64) { s.mu.Lock() defer s.mu.Unlock() if compressed { s.CompressedSize += sizes[0] s.UncompressedSize += sizes[1] } if encrypted { s.EncryptedSize += sizes[2] s.DecryptedSize += sizes[3] } s.ProcessingTime = time.Since(time.Now()) } // GetStats 获取消息统计信息 func (s *MessageStats) GetStats() map[string]interface{} { s.mu.RLock() defer s.mu.RUnlock() return map[string]interface{}{ "compressed_size": s.CompressedSize, "uncompressed_size": s.UncompressedSize, "encrypted_size": s.EncryptedSize, "decrypted_size": s.DecryptedSize, "compression_ratio": float64(s.CompressedSize) / float64(s.UncompressedSize), "processing_time_ms": s.ProcessingTime.Milliseconds(), } }