package server

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"
	"time"

	"ky-go-kratos/app/websocket/internal/conf"
	"ky-go-kratos/app/websocket/internal/service"

	"github.com/gorilla/websocket"
)

// WebSocketServer accepts WebSocket connections over HTTP, hands every inbound
// message to the service layer and fans each successfully handled message out
// to all connected clients.
type WebSocketServer struct {
	conf     *conf.Server
	svc      *service.WebSocketService
	upgrader websocket.Upgrader

	// mu guards clients and server.
	mu      sync.Mutex
	clients map[*websocket.Conn]bool
	server  *http.Server

	// Unbuffered hand-off channels; run is the only receiver.
	broadcast  chan []byte
	register   chan *websocket.Conn
	unregister chan *websocket.Conn

	// done is closed exactly once by Stop. Every goroutine that sends on the
	// channels above also selects on done so it cannot block forever after
	// run has exited.
	done     chan struct{}
	stopOnce sync.Once
}

// NewWebSocketServer builds a server from the given configuration and service.
// Nothing is listening until Start is called.
func NewWebSocketServer(c *conf.Server, svc *service.WebSocketService) *WebSocketServer {
	return &WebSocketServer{
		conf: c,
		svc:  svc,
		upgrader: websocket.Upgrader{
			// Every origin is accepted.
			// NOTE(review): this switches off the origin check that protects
			// browsers against cross-site WebSocket hijacking; restrict it to
			// trusted origins if this endpoint is reachable from browsers.
			CheckOrigin: func(r *http.Request) bool {
				return true
			},
		},
		clients:    make(map[*websocket.Conn]bool),
		broadcast:  make(chan []byte),
		register:   make(chan *websocket.Conn),
		unregister: make(chan *websocket.Conn),
		done:       make(chan struct{}),
	}
}

// Start binds the configured address, then serves the WebSocket endpoint and
// runs the client hub in background goroutines. It does not block.
//
// The listener is bound before Start returns so that a bind failure (port in
// use, malformed address) is reported to the caller instead of only being
// logged from a goroutine. ctx is currently unused.
func (s *WebSocketServer) Start(ctx context.Context) error {
	mux := http.NewServeMux()
	mux.HandleFunc(s.conf.Websocket.Path, s.handleWebSocket)

	srv := &http.Server{
		Addr:    s.conf.Websocket.Addr,
		Handler: mux,
		// Bounds how long a peer may take to send the request headers.
		// net/http clears connection deadlines when a connection is hijacked,
		// so established WebSocket sessions are not affected.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Same default that http.Server.ListenAndServe applies to an empty Addr.
	addr := srv.Addr
	if addr == "" {
		addr = ":http"
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("listening on %q: %w", addr, err)
	}

	s.mu.Lock()
	s.server = srv
	s.mu.Unlock()

	go s.run()

	log.Printf("Starting WebSocket server on %s%s", s.conf.Websocket.Addr, s.conf.Websocket.Path)
	go func() {
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
			log.Printf("WebSocket server stopped serving: %v", err)
		}
	}()
	return nil
}

// Stop shuts the server down. It signals the hub to close every WebSocket
// connection and exit, then gracefully shuts down the HTTP server, returning
// the error from http.Server.Shutdown (for example when ctx expires first).
//
// Closing the clients explicitly is required: http.Server.Shutdown neither
// closes nor waits for hijacked connections such as WebSockets. Stop is safe
// to call more than once and before Start.
func (s *WebSocketServer) Stop(ctx context.Context) error {
	s.stopOnce.Do(func() { close(s.done) })

	s.mu.Lock()
	srv := s.server
	s.mu.Unlock()

	if srv == nil {
		return nil
	}
	return srv.Shutdown(ctx)
}

// handleWebSocket upgrades an HTTP request to a WebSocket connection,
// registers the connection with the hub and starts its read loop.
func (s *WebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	log.Printf("Received WebSocket connection request from %s", r.RemoteAddr)

	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already written the HTTP error response.
		log.Printf("Failed to upgrade connection: %v", err)
		return
	}
	log.Printf("WebSocket connection established with %s", r.RemoteAddr)

	select {
	case s.register <- conn:
	case <-s.done:
		// The hub is gone, so nobody would ever track or close this
		// connection; drop it right away. The close error is irrelevant.
		_ = conn.Close()
		return
	}

	go s.readPump(conn)
}

// readPump reads messages from conn until the connection fails or the server
// stops. Each message is passed to the service; messages the service accepts
// are queued for broadcast, messages it rejects are logged and skipped.
// On exit the connection is unregistered and closed.
func (s *WebSocketServer) readPump(conn *websocket.Conn) {
	defer func() {
		select {
		case s.unregister <- conn:
		case <-s.done:
		}
		// The hub may already have closed conn (failed write or shutdown);
		// a second Close only returns an error, which is deliberately ignored.
		_ = conn.Close()
	}()

	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("error: %v", err)
			}
			return
		}
		log.Printf("Received message: %v", string(message))

		// Let the service layer process the inbound message.
		if err := s.svc.HandleMessage(context.Background(), message); err != nil {
			log.Printf("Error handling message: %v", err)
			continue
		}

		// Queue the message for delivery to every connected client.
		select {
		case s.broadcast <- message:
		case <-s.done:
			return
		}
	}
}

// run is the hub loop. It is the only goroutine that mutates the client set
// and the only one that writes to client connections. It returns once done
// is closed, after closing every remaining client.
func (s *WebSocketServer) run() {
	for {
		select {
		case <-s.done:
			s.closeAllClients()
			return

		case client := <-s.register:
			s.mu.Lock()
			s.clients[client] = true
			total := len(s.clients)
			s.mu.Unlock()
			log.Printf("New client connected. Total clients: %d", total)

		case client := <-s.unregister:
			// readPump closes the connection itself; only forget it here.
			s.mu.Lock()
			delete(s.clients, client)
			total := len(s.clients)
			s.mu.Unlock()
			log.Printf("Client disconnected. Total clients: %d", total)

		case message := <-s.broadcast:
			s.broadcastMessage(message)
		}
	}
}

// broadcastMessage writes message as a text frame to every registered client
// and drops any client whose write fails.
func (s *WebSocketServer) broadcastMessage(message []byte) {
	// Upper bound for a single write, so one stalled peer cannot block the
	// hub (and with it every other client and shutdown) indefinitely.
	const writeWait = 10 * time.Second

	s.mu.Lock()
	defer s.mu.Unlock()

	for client := range s.clients {
		err := client.SetWriteDeadline(time.Now().Add(writeWait))
		if err == nil {
			err = client.WriteMessage(websocket.TextMessage, message)
		}
		if err != nil {
			log.Printf("Error broadcasting message: %v", err)
			// Closing makes the client's readPump fail and exit.
			_ = client.Close()
			delete(s.clients, client) // deleting during range is safe in Go
		}
	}
}

// closeAllClients closes and forgets every registered connection, which makes
// the corresponding readPump goroutines fail their pending read and exit.
func (s *WebSocketServer) closeAllClients() {
	s.mu.Lock()
	defer s.mu.Unlock()

	for client := range s.clients {
		_ = client.Close()
		delete(s.clients, client)
	}
}