// Source: go-kratos/app/websocket/internal/server/websocket.go

package server
import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"

	"github.com/gorilla/websocket"

	"ky-go-kratos/app/websocket/internal/conf"
	"ky-go-kratos/app/websocket/internal/service"
)
// WebSocketServer serves WebSocket connections on a single HTTP path and
// fans out every message the service accepts to all connected clients.
type WebSocketServer struct {
	conf     *conf.Server
	svc      *service.WebSocketService
	upgrader websocket.Upgrader

	// clients is the set of live connections; access it only while holding mu.
	clients map[*websocket.Conn]bool

	// Hub channels drained by run: messages to fan out, and connections to
	// add to or remove from clients.
	broadcast            chan []byte
	register, unregister chan *websocket.Conn

	mu     sync.RWMutex
	server *http.Server

	// done is closed to tell background goroutines to shut down gracefully.
	done chan struct{}
}
// NewWebSocketServer builds a server for configuration c that passes inbound
// messages to svc. It only allocates state; nothing listens until Start.
func NewWebSocketServer(c *conf.Server, svc *service.WebSocketService) *WebSocketServer {
	// Accept connections from any Origin.
	// NOTE(review): this disables the upgrader's same-origin protection;
	// consider restricting allowed origins outside development.
	allowAnyOrigin := func(*http.Request) bool {
		return true
	}

	s := &WebSocketServer{
		conf: c,
		svc:  svc,
		done: make(chan struct{}),
	}
	s.upgrader = websocket.Upgrader{CheckOrigin: allowAnyOrigin}
	s.clients = make(map[*websocket.Conn]bool)
	s.broadcast = make(chan []byte)
	s.register = make(chan *websocket.Conn)
	s.unregister = make(chan *websocket.Conn)
	return s
}
// Start serves the WebSocket endpoint at s.conf.Websocket.Path on
// s.conf.Websocket.Addr in the background, and launches run, the hub
// goroutine that tracks clients and fans out broadcasts.
//
// The listener is opened synchronously, so bind errors (address already in
// use, malformed address) are returned to the caller rather than only logged.
// Errors that occur later while serving are still only logged.
// ctx is not used; call Stop to shut the server down.
func (s *WebSocketServer) Start(ctx context.Context) error {
	addr := s.conf.Websocket.Addr
	path := s.conf.Websocket.Path

	mux := http.NewServeMux()
	mux.HandleFunc(path, s.handleWebSocket)

	listenAddr := addr
	if listenAddr == "" {
		// Same default http.Server.ListenAndServe applies to an empty Addr.
		listenAddr = ":http"
	}
	ln, err := net.Listen("tcp", listenAddr)
	if err != nil {
		return fmt.Errorf("websocket server listen on %q: %w", listenAddr, err)
	}

	srv := &http.Server{
		Addr:    addr,
		Handler: mux,
	}
	s.server = srv

	go s.run()
	log.Printf("Starting WebSocket server on %s%s", addr, path)
	go func() {
		// Serve returns http.ErrServerClosed after Shutdown; that is the
		// normal exit path and not worth logging.
		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
			log.Printf("Failed to start WebSocket server: %v", err)
		}
	}()
	return nil
}
// Stop shuts the server down. It closes done to signal background goroutines,
// sends a CloseGoingAway frame to every connected client and closes it, then
// gracefully shuts down the HTTP server, bounded by ctx.
//
// Only done is closed. The broadcast, register and unregister channels are
// deliberately left open: per-connection goroutines may still send on them
// while connections wind down, and a send on a closed channel panics. Unused
// channels are reclaimed by the GC.
//
// Calling Stop more than once is safe; later calls do not re-close done.
func (s *WebSocketServer) Stop(ctx context.Context) error {
	s.mu.Lock()
	// Checking and closing under mu keeps concurrent Stop calls from
	// double-closing done.
	select {
	case <-s.done:
		// Already stopped by an earlier call.
	default:
		close(s.done)
	}
	for client := range s.clients {
		// Best effort: the peer may already be gone, in which case the write
		// fails and we close the connection anyway.
		_ = client.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server is shutting down"))
		client.Close()
		delete(s.clients, client)
	}
	s.mu.Unlock()

	if s.server != nil {
		return s.server.Shutdown(ctx)
	}
	return nil
}
// handleWebSocket upgrades the HTTP request to a WebSocket connection,
// registers it with the hub loop (run) and starts a reader goroutine for it.
// If the server is already shutting down, the new connection is closed
// instead of being registered.
func (s *WebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	log.Printf("Received WebSocket connection request from %s", r.RemoteAddr)
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("Failed to upgrade connection: %v", err)
		return
	}
	log.Printf("WebSocket connection established with %s", r.RemoteAddr)

	// register is unbuffered and only run receives from it. Once done is
	// closed run returns, so a plain send would block this handler forever.
	select {
	case s.register <- conn:
	case <-s.done:
		conn.Close()
		return
	}
	go s.readPump(conn)
}
// readPump reads messages from conn until a read fails or the server stops.
// Each message is passed to the service; those handled without error are
// queued on broadcast for run to fan out. On exit, conn is handed back to run
// for removal and closed.
func (s *WebSocketServer) readPump(conn *websocket.Conn) {
	defer func() {
		// run returns once done is closed, after which nobody receives on
		// unregister; give up on done instead of leaking this goroutine.
		select {
		case s.unregister <- conn:
		case <-s.done:
		}
		conn.Close()
	}()

	for {
		// Non-blocking shutdown check. ReadMessage blocks, so this only runs
		// between messages.
		select {
		case <-s.done:
			return
		default:
		}

		_, message, err := conn.ReadMessage()
		if err != nil {
			// Only close errors with a code other than GoingAway or
			// AbnormalClosure are logged.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("error: %v", err)
			}
			return
		}
		log.Printf("Received message: %v", string(message))

		// No request-scoped context is available here, hence Background.
		// A message the service rejects is logged and not broadcast.
		if err := s.svc.HandleMessage(context.Background(), message); err != nil {
			log.Printf("Error handling message: %v", err)
			continue
		}

		// Queue the message for broadcast to all clients, unless stopping.
		select {
		case s.broadcast <- message:
		case <-s.done:
			return
		}
	}
}
// run is the hub loop. It serially applies registrations and unregistrations
// to s.clients and writes each broadcast message to every client, until done
// is closed. Every access to s.clients, including the len() reads used for
// logging, happens while holding s.mu, because Stop mutates the map from
// another goroutine.
func (s *WebSocketServer) run() {
	for {
		select {
		case <-s.done:
			return

		case client, ok := <-s.register:
			if !ok {
				return // a closed hub channel means the server is shutting down
			}
			s.mu.Lock()
			// select picks randomly among ready cases, so a registration can
			// be taken after done is closed. Stop may already have closed and
			// removed every client; adding this one would leak it.
			select {
			case <-s.done:
				s.mu.Unlock()
				client.Close()
				continue
			default:
			}
			s.clients[client] = true
			total := len(s.clients)
			s.mu.Unlock()
			log.Printf("New client connected. Total clients: %d", total)

		case client, ok := <-s.unregister:
			if !ok {
				return // a closed hub channel means the server is shutting down
			}
			s.mu.Lock()
			if _, found := s.clients[client]; found {
				delete(s.clients, client)
				client.Close()
			}
			total := len(s.clients)
			s.mu.Unlock()
			log.Printf("Client disconnected. Total clients: %d", total)

		case message, ok := <-s.broadcast:
			if !ok {
				return // a closed hub channel means the server is shutting down
			}
			s.mu.Lock()
			// Writes are serialized by mu (Stop also writes under it). A
			// client whose write fails is dropped; deleting during range is
			// permitted in Go.
			for client := range s.clients {
				if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
					log.Printf("Error broadcasting message: %v", err)
					client.Close()
					delete(s.clients, client)
				}
			}
			s.mu.Unlock()
		}
	}
}