package server import ( "context" "log" "net/http" "sync" "ky-go-kratos/app/websocket/internal/conf" "ky-go-kratos/app/websocket/internal/service" "github.com/gorilla/websocket" ) type WebSocketServer struct { conf *conf.Server svc *service.WebSocketService upgrader websocket.Upgrader clients map[*websocket.Conn]bool broadcast chan []byte register chan *websocket.Conn unregister chan *websocket.Conn mu sync.RWMutex server *http.Server done chan struct{} // 添加 done 通道用于优雅关闭 } // NewWebSocketServer 创建 WebSocket 服务器 func NewWebSocketServer(c *conf.Server, svc *service.WebSocketService) *WebSocketServer { return &WebSocketServer{ conf: c, svc: svc, upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // 允许所有来源的连接 }, }, clients: make(map[*websocket.Conn]bool), broadcast: make(chan []byte), register: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn), done: make(chan struct{}), } } // Start 启动 WebSocket 服务器 func (s *WebSocketServer) Start(ctx context.Context) error { mux := http.NewServeMux() mux.HandleFunc(s.conf.Websocket.Path, s.handleWebSocket) s.server = &http.Server{ Addr: s.conf.Websocket.Addr, Handler: mux, } go s.run() log.Printf("Starting WebSocket server on %s%s", s.conf.Websocket.Addr, s.conf.Websocket.Path) go func() { if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Printf("Failed to start WebSocket server: %v", err) } }() return nil } // Stop 停止 WebSocket 服务器 func (s *WebSocketServer) Stop(ctx context.Context) error { // 关闭所有 WebSocket 连接 s.mu.Lock() for client := range s.clients { client.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server is shutting down")) client.Close() delete(s.clients, client) } s.mu.Unlock() // 关闭通道 close(s.done) close(s.broadcast) close(s.register) close(s.unregister) // 关闭 HTTP 服务器 if s.server != nil { return s.server.Shutdown(ctx) } return nil } // handleWebSocket 处理 WebSocket 连接请求 func (s *WebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { log.Printf("Received WebSocket connection request from %s", r.RemoteAddr) conn, err := s.upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("Failed to upgrade connection: %v", err) return } log.Printf("WebSocket connection established with %s", r.RemoteAddr) s.register <- conn go s.readPump(conn) } // readPump 读取 WebSocket 消息 func (s *WebSocketServer) readPump(conn *websocket.Conn) { defer func() { s.unregister <- conn conn.Close() }() for { select { case <-s.done: return default: _, message, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { log.Printf("error: %v", err) } return } log.Printf("Received message: %v", string(message)) // 处理接收到的消息 if err := s.svc.HandleMessage(context.Background(), message); err != nil { log.Printf("Error handling message: %v", err) continue } // 广播消息给所有客户端 select { case s.broadcast <- message: case <-s.done: return } } } } // run 运行 WebSocket 服务器 func (s *WebSocketServer) run() { for { select { case <-s.done: // 如果 done 通道被关闭,则退出循环 return case client := <-s.register: // 如果 register 通道有新的连接,则将连接添加到 clients 中 s.mu.Lock() s.clients[client] = true s.mu.Unlock() log.Printf("New client connected. Total clients: %d", len(s.clients)) case client := <-s.unregister: // 如果 unregister 通道有连接关闭,则将连接从 clients 中删除 s.mu.Lock() if _, ok := s.clients[client]; ok { delete(s.clients, client) client.Close() } s.mu.Unlock() log.Printf("Client disconnected. Total clients: %d", len(s.clients)) case message := <-s.broadcast: // 如果 broadcast 通道有消息,则广播消息给所有客户端 s.mu.Lock() for client := range s.clients { err := client.WriteMessage(websocket.TextMessage, message) if err != nil { log.Printf("Error broadcasting message: %v", err) client.Close() delete(s.clients, client) } } s.mu.Unlock() } } }