144 lines
3.3 KiB
Go
144 lines
3.3 KiB
Go
package server
|
|
|
|
import (
	"context"
	"log"
	"net"
	"net/http"
	"sync"
	"time"

	"ky-go-kratos/app/websocket/internal/conf"
	"ky-go-kratos/app/websocket/internal/service"

	"github.com/gorilla/websocket"
)
|
|
|
|
type WebSocketServer struct {
|
|
conf *conf.Server
|
|
svc *service.WebSocketService
|
|
upgrader websocket.Upgrader
|
|
clients map[*websocket.Conn]bool
|
|
broadcast chan []byte
|
|
register chan *websocket.Conn
|
|
unregister chan *websocket.Conn
|
|
mu sync.Mutex
|
|
server *http.Server
|
|
}
|
|
|
|
// NewWebSocketServer builds a WebSocketServer from the server configuration
// and the message-handling service. Nothing is started here; call Start to
// listen and to launch the hub loop.
func NewWebSocketServer(c *conf.Server, svc *service.WebSocketService) *WebSocketServer {
	// Accept connections from every origin.
	// NOTE(review): this disables the Origin check entirely, which permits
	// cross-site WebSocket hijacking if browsers connect with ambient
	// credentials — confirm this is intended for the deployment.
	acceptAnyOrigin := func(*http.Request) bool { return true }

	srv := &WebSocketServer{conf: c, svc: svc}
	srv.upgrader = websocket.Upgrader{CheckOrigin: acceptAnyOrigin}
	srv.clients = make(map[*websocket.Conn]bool)

	// All hub channels are unbuffered: senders block until run receives.
	srv.broadcast = make(chan []byte)
	srv.register = make(chan *websocket.Conn)
	srv.unregister = make(chan *websocket.Conn)

	return srv
}
|
|
|
|
// Start registers the WebSocket handler on the configured path, binds the
// configured address, launches the client hub, and serves connections in a
// background goroutine. It returns without waiting for the server to stop.
//
// The listener is opened synchronously so that bind failures (for example
// the address already being in use) are returned to the caller rather than
// only being logged from the serving goroutine. The hub goroutine is started
// only after a successful bind. ctx is currently unused.
func (s *WebSocketServer) Start(ctx context.Context) error {
	mux := http.NewServeMux()
	mux.HandleFunc(s.conf.Websocket.Path, s.handleWebSocket)

	srv := &http.Server{
		Addr:              s.conf.Websocket.Addr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second, // bound slow header delivery before the upgrade
	}
	s.server = srv

	// Mirror http.Server.ListenAndServe: an empty address means ":http".
	addr := srv.Addr
	if addr == "" {
		addr = ":http"
	}
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}

	go s.run()

	log.Printf("Starting WebSocket server on %s%s", s.conf.Websocket.Addr, s.conf.Websocket.Path)

	go func() {
		// Serve closes ln when it returns; ErrServerClosed is the normal
		// result of Stop calling Shutdown.
		if err := srv.Serve(ln); err != nil && err != http.ErrServerClosed {
			log.Printf("Failed to start WebSocket server: %v", err)
		}
	}()

	return nil
}
|
|
|
|
// Stop gracefully shuts down the HTTP server, then closes every tracked
// WebSocket client. It is a no-op if Start was never called.
//
// http.Server.Shutdown neither closes nor waits for hijacked connections, and
// every upgraded WebSocket is hijacked. Without the explicit sweep below,
// clients and their readPump goroutines would outlive Stop. The Shutdown
// error, if any, is returned after the sweep so that clients are closed
// even when ctx expires first.
func (s *WebSocketServer) Stop(ctx context.Context) error {
	if s.server == nil {
		return nil
	}

	err := s.server.Shutdown(ctx)

	s.mu.Lock()
	for conn := range s.clients {
		// Best-effort: a close error just means the connection is already gone.
		_ = conn.Close()
		delete(s.clients, conn)
	}
	s.mu.Unlock()

	return err
}
|
|
|
|
// handleWebSocket upgrades an incoming HTTP request to a WebSocket
// connection, registers it with the hub, and starts its read loop in a new
// goroutine. If the upgrade fails, the error is logged and the handler returns.
func (s *WebSocketServer) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	peer := r.RemoteAddr
	log.Printf("Received WebSocket connection request from %s", peer)

	ws, upgradeErr := s.upgrader.Upgrade(w, r, nil)
	if upgradeErr != nil {
		log.Printf("Failed to upgrade connection: %v", upgradeErr)
		return
	}
	log.Printf("WebSocket connection established with %s", peer)

	// Blocks until the hub loop accepts the new client.
	s.register <- ws
	go s.readPump(ws)
}
|
|
|
|
// readPump reads messages from conn until a read fails. It passes each
// message to the service and, when the service accepts it, forwards it to
// the hub for broadcast. On exit it unregisters and closes conn.
func (s *WebSocketServer) readPump(conn *websocket.Conn) {
	defer func() {
		s.unregister <- conn
		conn.Close()
	}()

	for {
		_, payload, readErr := conn.ReadMessage()
		if readErr != nil {
			// Going-away and abnormal closures are routine disconnects; log
			// only other close codes or non-close errors.
			routine := []int{websocket.CloseGoingAway, websocket.CloseAbnormalClosure}
			if websocket.IsUnexpectedCloseError(readErr, routine...) {
				log.Printf("error: %v", readErr)
			}
			return
		}

		log.Printf("Received message: %v", string(payload))

		// Messages the service rejects are logged and not broadcast.
		if handleErr := s.svc.HandleMessage(context.Background(), payload); handleErr != nil {
			log.Printf("Error handling message: %v", handleErr)
		} else {
			s.broadcast <- payload
		}
	}
}
|
|
|
|
func (s *WebSocketServer) run() {
|
|
for {
|
|
select {
|
|
case client := <-s.register:
|
|
s.mu.Lock()
|
|
s.clients[client] = true
|
|
s.mu.Unlock()
|
|
log.Printf("New client connected. Total clients: %d", len(s.clients))
|
|
case client := <-s.unregister:
|
|
s.mu.Lock()
|
|
if _, ok := s.clients[client]; ok {
|
|
delete(s.clients, client)
|
|
client.Close()
|
|
}
|
|
s.mu.Unlock()
|
|
log.Printf("Client disconnected. Total clients: %d", len(s.clients))
|
|
case message := <-s.broadcast:
|
|
s.mu.Lock()
|
|
for client := range s.clients {
|
|
err := client.WriteMessage(websocket.TextMessage, message)
|
|
if err != nil {
|
|
log.Printf("Error broadcasting message: %v", err)
|
|
client.Close()
|
|
delete(s.clients, client)
|
|
}
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
}
|
|
}
|