package service

import (
	"context"
	"encoding/json"
	"net/http"
	"time"

	"github.com/go-kratos/kratos/v2/log"
	"github.com/gorilla/websocket"

	"ky-go-kratos/app/websocket/internal/biz"
	ws "ky-go-kratos/app/websocket/internal/websocket"
)

// WebSocketConfig groups the tunable parameters of a WebSocket connection.
type WebSocketConfig struct {
	ReadBufferSize     int           // size of the read buffer
	WriteBufferSize    int           // size of the write buffer
	PingInterval       time.Duration // interval between outgoing ping messages
	PongWait           time.Duration // how long to wait for a pong reply
	WriteWait          time.Duration // timeout for a single write operation
	QueueSize          int           // capacity of the message queue
	RateLimit          float64       // message rate limit
	CompressionEnabled bool          // whether compression is enabled
	EncryptionEnabled  bool          // whether encryption is enabled
	EncryptionKey      []byte        // key used when encryption is enabled
}

// DefaultConfig returns a newly allocated WebSocketConfig filled with the
// package defaults: 1024-byte read/write buffers, a 54s ping interval, a 60s
// pong wait, a 10s write wait, QueueSize 100, RateLimit 10.0, and both
// compression and encryption switched off.
//
// NOTE(review): EncryptionKey defaults to 32 zero bytes. That is harmless
// while EncryptionEnabled is false, but a caller that only flips the flag
// would encrypt with an all-zero key. Confirm callers always supply a real key.
func DefaultConfig() *WebSocketConfig {
	cfg := new(WebSocketConfig)
	cfg.ReadBufferSize = 1024
	cfg.WriteBufferSize = 1024
	cfg.PingInterval = 54 * time.Second
	cfg.PongWait = 60 * time.Second
	cfg.WriteWait = 10 * time.Second
	cfg.QueueSize = 100
	cfg.RateLimit = 10.0
	cfg.CompressionEnabled = false
	cfg.EncryptionEnabled = false
	cfg.EncryptionKey = make([]byte, 32)
	return cfg
}

// hub tracks every active room. It is built during package initialisation
// and handed to the ws package by init.
var hub = ws.NewHub(ws.DefaultHubConfig())

// upgrader turns incoming HTTP requests into WebSocket connections.
//
// NOTE(review): CheckOrigin accepts every origin, so any web page can open a
// connection on a visitor's behalf (cross-site WebSocket hijacking). Restrict
// it to the expected origins before production use.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// Compression stays off to save CPU.
	EnableCompression: false,
	CheckOrigin: func(*http.Request) bool {
		return true
	},
}

// init registers the package-level hub with the ws package.
func init() {
	ws.SetHub(hub)
}

// WebSocketService serves WebSocket connections on top of a biz.SocketUsecase.
type WebSocketService struct {
	uc  biz.SocketUsecase
	log *log.Helper
}

// NewWebSocketService builds a WebSocketService around uc and wraps logger in
// a log.Helper.
func NewWebSocketService(uc biz.SocketUsecase, logger log.Logger) *WebSocketService {
	helper := log.NewHelper(logger)
	return &WebSocketService{uc: uc, log: helper}
}

// HandleConnection handles a new WebSocket connection.
func (s *WebSocketService) HandleConnection(conn *websocket.Conn, userID string, roomID string) { // 更新用户在线状态 err := s.uc.UpdateUserStatus(context.Background(), userID, roomID, true) if err != nil { s.log.Errorf("Failed to update user status: %v", err) return } // 定期清理不活跃用户 go s.cleanupInactiveUsers() // 处理消息 for { messageType, message, err := conn.ReadMessage() if err != nil { s.log.Errorf("Failed to read message: %v", err) break } if messageType == websocket.TextMessage { var msg Message if err := json.Unmarshal(message, &msg); err != nil { s.log.Errorf("Failed to unmarshal message: %v", err) continue } // 处理消息 if err := s.uc.SendMessage(context.Background(), msg.To, []string{userID}, []string{string(msg.Content)}); err != nil { s.log.Errorf("Failed to send message: %v", err) } } } // 用户断开连接时更新状态 err = s.uc.UpdateUserStatus(context.Background(), userID, roomID, false) if err != nil { s.log.Errorf("Failed to update user status: %v", err) } } // GetOnlineUsers returns a list of online users in a room. func (s *WebSocketService) GetOnlineUsers(roomID string) ([]string, error) { return s.uc.GetOnlineUsers(context.Background(), roomID) } // IsUserOnline checks if a user is online. func (s *WebSocketService) IsUserOnline(userID string) (bool, error) { return s.uc.IsOnline(context.Background(), userID) } // cleanupInactiveUsers periodically cleans up inactive users. func (s *WebSocketService) cleanupInactiveUsers() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { if err := s.uc.CleanupInactiveUsers(context.Background(), 10*time.Minute); err != nil { s.log.Errorf("Failed to cleanup inactive users: %v", err) } } } // Message represents a WebSocket message. 
type Message struct { To []string `json:"to"` Content []byte `json:"content"` } // WebSocketHandler 处理WebSocket连接请求 func WebSocketHandler(w http.ResponseWriter, r *http.Request) { // 记录连接请求信息 logger := log.NewHelper(log.DefaultLogger) logger.Infof("收到来自 %s 的WebSocket连接请求", r.RemoteAddr) logger.Infof("请求头: %v", r.Header) // // 设置CORS响应头 // w.Header().Set("Access-Control-Allow-Origin", "*") // w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") // w.Header().Set("Access-Control-Allow-Headers", "Content-Type") // 处理OPTIONS预检请求 if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) return } // 获取房间名和用户ID roomName := r.URL.Query().Get("room") uid := r.URL.Query().Get("uid") // 验证必要参数 if roomName == "" || uid == "" { logger.Errorf("无效的连接请求: 缺少房间名或用户ID") http.Error(w, "缺少必要参数: room和uid", http.StatusBadRequest) return } logger.Infof("尝试升级连接: room=%s, uid=%s", roomName, uid) // 升级HTTP连接为WebSocket连接 conn, err := upgrader.Upgrade(w, r, nil) if err != nil { logger.Errorf("连接升级失败: %v", err) return } logger.Infof("连接升级成功: room=%s, uid=%s", roomName, uid) // 设置连接参数 conn.SetReadLimit(512 * 1024) // 设置最大消息大小为512KB conn.SetReadDeadline(time.Now().Add(60 * time.Second)) conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) // 设置ping/pong处理器 conn.SetPingHandler(func(appData string) error { logger.Infof("收到来自客户端 %s 的ping", uid) return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(10*time.Second)) }) conn.SetPongHandler(func(string) error { logger.Infof("收到来自客户端 %s 的pong", uid) conn.SetReadDeadline(time.Now().Add(60 * time.Second)) return nil }) // 设置连接关闭处理器 conn.SetCloseHandler(func(code int, text string) error { logger.Infof("客户端 %s 断开连接: code=%d, text=%s", uid, code, text) return nil }) // 获取或创建房间 room := hub.GetOrCreateRoom(roomName) if room == nil { logger.Errorf("获取或创建房间失败: %s", roomName) conn.Close() return } // 创建客户端连接对象 client := &ws.Connection{ UID: uid, Conn: conn, Room: room, Send: make(chan *ws.WSMessage, 256), Config: 
&ws.ConnectionConfig{ ReadBufferSize: 1024, WriteBufferSize: 1024, PingInterval: 54 * time.Second, PongWait: 60 * time.Second, WriteWait: 10 * time.Second, MaxMessageSize: 512 * 1024, RateLimit: 100, QueueSize: 1000, CompressionEnabled: false, EncryptionEnabled: false, ReconnectConfig: ws.DefaultReconnectConfig(), }, Stats: ws.NewConnectionStats(), Queue: ws.NewMessageQueue(1000, 100), MessageHandler: ws.NewMessageHandler(false, false, nil), } // 注册客户端到房间 room.Register <- client // 启动读写协程 go client.WritePump() go client.ReadPump() // 发送欢迎消息 welcomeMsg := &ws.WSMessage{ Type: "system", Content: "欢迎加入聊天室!", Time: time.Now(), } client.Send <- welcomeMsg logger.Infof("客户端 %s 成功连接到房间 %s", uid, roomName) // 等待连接关闭 select { case <-r.Context().Done(): logger.Infof("请求上下文已关闭") case <-time.After(24 * time.Hour): // 设置一个超时时间,防止连接永远不关闭 logger.Infof("连接超时") } logger.Infof("客户端 %s 的连接已关闭", uid) } // Shutdown gracefully closes all WebSocket connections func Shutdown() { hub.Shutdown() }