diff --git a/app/websocket/cmd/websocket/wire_gen.go b/app/websocket/cmd/websocket/wire_gen.go index 5cb77f1..82e1661 100644 --- a/app/websocket/cmd/websocket/wire_gen.go +++ b/app/websocket/cmd/websocket/wire_gen.go @@ -31,8 +31,7 @@ func wireApp(confServer *conf.Server, confData *conf.Data, logger log.Logger) (* webSocketRepo := data.NewWebSocketRepo(dataData, logger) webSocketUsecase := biz.NewWebSocketUsecase(webSocketRepo) webSocketService := service.NewWebSocketService(webSocketUsecase) - messageHandler := service.NewMessageHandler(webSocketService) - kafkaConsumerServer, cleanup2 := server.NewKafkaConsumerServer(confData, messageHandler) + kafkaConsumerServer, cleanup2 := server.NewKafkaConsumerServer(confData, webSocketService) webSocketServer := server.NewWebSocketServer(confServer, webSocketService) grpcServer := server.NewGRPCServer(confServer, logger) httpServer := server.NewHTTPServer(confServer, logger) diff --git a/app/websocket/configs/config.yaml b/app/websocket/configs/config.yaml index 4646622..20ac7f3 100644 --- a/app/websocket/configs/config.yaml +++ b/app/websocket/configs/config.yaml @@ -21,3 +21,4 @@ data: - 47.108.232.131:9092 topic: websocket-topic partition: 0 + group_id: websocket-topic diff --git a/app/websocket/internal/conf/conf.pb.go b/app/websocket/internal/conf/conf.pb.go index 1792ba8..e7bac09 100644 --- a/app/websocket/internal/conf/conf.pb.go +++ b/app/websocket/internal/conf/conf.pb.go @@ -355,6 +355,7 @@ type Data_Kafka struct { Brokers []string `protobuf:"bytes,1,rep,name=brokers,proto3" json:"brokers,omitempty"` Topic string `protobuf:"bytes,2,opt,name=topic,proto3" json:"topic,omitempty"` Partition int64 `protobuf:"varint,3,opt,name=partition,proto3" json:"partition,omitempty"` + GroupId string `protobuf:"bytes,4,opt,name=group_id,json=groupId,proto3" json:"group_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -410,6 +411,13 @@ func (x *Data_Kafka) GetPartition() int64 { return 0 } +func (x *Data_Kafka) GetGroupId() string { + if x != nil { + return x.GroupId + } + return "" +} + var File_conf_conf_proto protoreflect.FileDescriptor const file_conf_conf_proto_rawDesc = "" + @@ -433,13 +441,14 @@ const file_conf_conf_proto_rawDesc = "" + "\atimeout\x18\x03 \x01(\v2\x19.google.protobuf.DurationR\atimeout\x1a3\n" + "\tWebSocket\x12\x12\n" + "\x04addr\x18\x01 \x01(\tR\x04addr\x12\x12\n" + - "\x04path\x18\x02 \x01(\tR\x04path\"\x8b\x01\n" + + "\x04path\x18\x02 \x01(\tR\x04path\"\xa6\x01\n" + "\x04Data\x12,\n" + - "\x05kafka\x18\x01 \x01(\v2\x16.kratos.api.Data.KafkaR\x05kafka\x1aU\n" + + "\x05kafka\x18\x01 \x01(\v2\x16.kratos.api.Data.KafkaR\x05kafka\x1ap\n" + "\x05Kafka\x12\x18\n" + "\abrokers\x18\x01 \x03(\tR\abrokers\x12\x14\n" + "\x05topic\x18\x02 \x01(\tR\x05topic\x12\x1c\n" + - "\tpartition\x18\x03 \x01(\x03R\tpartitionB/Z-ky-go-kratos/app/websocket/internal/conf;confb\x06proto3" + "\tpartition\x18\x03 \x01(\x03R\tpartition\x12\x19\n" + + "\bgroup_id\x18\x04 \x01(\tR\agroupIdB/Z-ky-go-kratos/app/websocket/internal/conf;confb\x06proto3" var ( file_conf_conf_proto_rawDescOnce sync.Once diff --git a/app/websocket/internal/conf/conf.proto b/app/websocket/internal/conf/conf.proto index 82f6bed..413cba2 100644 --- a/app/websocket/internal/conf/conf.proto +++ b/app/websocket/internal/conf/conf.proto @@ -35,6 +35,7 @@ message Data { repeated string brokers = 1; string topic = 2; int64 partition = 3; + string group_id = 4; } Kafka kafka = 1; } diff --git a/app/websocket/internal/data/websocket.go b/app/websocket/internal/data/websocket.go index 886aad1..e755fb6 100644 --- a/app/websocket/internal/data/websocket.go +++ b/app/websocket/internal/data/websocket.go @@ -30,7 +30,6 @@ type Message struct { // SendMessage sends a message to Kafka. func (r *WebSocketRepo) SendMessage(ctx context.Context, message []byte) error { - r.log.Warnf(">>>>>>>> %v", string(message)) err := r.data.kafka.SendToKafkaMessage(ctx, "websocket", message) if err != nil { r.log.Errorf("failed to send message to kafka: %v", err) diff --git a/app/websocket/internal/server/kafka_consume.go b/app/websocket/internal/server/kafka_consume.go index eb8c6ae..c6295f5 100644 --- a/app/websocket/internal/server/kafka_consume.go +++ b/app/websocket/internal/server/kafka_consume.go @@ -17,30 +17,23 @@ type KafkaConsumerServer struct { wg sync.WaitGroup } -func NewKafkaConsumerServer(c *conf.Data, messageHandler *service.MessageHandler) (*KafkaConsumerServer, func()) { - consumer := kafka.NewKafkaReader(c.Kafka.Brokers, c.Kafka.Topic, int(c.Kafka.Partition)) +func NewKafkaConsumerServer(c *conf.Data, wsService *service.WebSocketService) (*KafkaConsumerServer, func()) { + consumer := kafka.NewKafkaReader(c.Kafka.Brokers, c.Kafka.Topic, int(c.Kafka.Partition), c.Kafka.GroupId) + messageHandler := service.NewMessageHandler(wsService, consumer.Reader) + server := &KafkaConsumerServer{ consumer: consumer, messageHandler: messageHandler, } - ctx := context.Background() - if err := server.Start(ctx); err != nil { - panic(err) - } - return server, func() { - if err := server.Stop(ctx); err != nil { + if err := server.Stop(context.Background()); err != nil { log.Errorf("failed to stop kafka consumer: %v", err) } } } func (s *KafkaConsumerServer) Start(ctx context.Context) error { - if err := s.messageHandler.Start(); err != nil { - return err - } - s.wg.Add(1) go s.consumeMessages(ctx) return nil @@ -48,9 +41,6 @@ func (s *KafkaConsumerServer) Start(ctx context.Context) error { func (s *KafkaConsumerServer) Stop(ctx context.Context) error { s.wg.Wait() - if err := s.messageHandler.Stop(); err != nil { - return err - } return s.consumer.Close() } @@ -63,7 +53,6 @@ func (s *KafkaConsumerServer) consumeMessages(ctx context.Context) { return default: message, err := s.consumer.Reader.ReadMessage(ctx) - log.Infof("message data: %v", message) if err != nil { if s.consumer.IsTransientNetworkError(err) { continue diff --git a/app/websocket/internal/server/websocket.go b/app/websocket/internal/server/websocket.go index e97a44d..032f85b 100644 --- a/app/websocket/internal/server/websocket.go +++ b/app/websocket/internal/server/websocket.go @@ -20,8 +20,9 @@ type WebSocketServer struct { broadcast chan []byte register chan *websocket.Conn unregister chan *websocket.Conn - mu sync.Mutex + mu sync.RWMutex server *http.Server + done chan struct{} // 添加 done 通道用于优雅关闭 } func NewWebSocketServer(c *conf.Server, svc *service.WebSocketService) *WebSocketServer { @@ -37,6 +38,7 @@ func NewWebSocketServer(c *conf.Server, svc *service.WebSocketService) *WebSocke broadcast: make(chan []byte), register: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn), + done: make(chan struct{}), } } @@ -63,6 +65,22 @@ func (s *WebSocketServer) Start(ctx context.Context) error { } func (s *WebSocketServer) Stop(ctx context.Context) error { + // 关闭所有 WebSocket 连接 + s.mu.Lock() + for client := range s.clients { + client.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "Server is shutting down")) + client.Close() + delete(s.clients, client) + } + s.mu.Unlock() + + // 关闭通道 + close(s.done) + close(s.broadcast) + close(s.register) + close(s.unregister) + + // 关闭 HTTP 服务器 if s.server != nil { return s.server.Shutdown(ctx) } @@ -90,30 +108,41 @@ func (s *WebSocketServer) readPump(conn *websocket.Conn) { }() for { - _, message, err := conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("error: %v", err) + select { + case <-s.done: + return + default: + _, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("error: %v", err) + } + return + } + + log.Printf("Received message: %v", string(message)) + + // 处理接收到的消息 + if err := s.svc.HandleMessage(context.Background(), message); err != nil { + log.Printf("Error handling message: %v", err) + continue + } + + // 广播消息给所有客户端 + select { + case s.broadcast <- message: + case <-s.done: + return } - break } - - log.Printf("Received message: %v", string(message)) - - // 处理接收到的消息 - if err := s.svc.HandleMessage(context.Background(), message); err != nil { - log.Printf("Error handling message: %v", err) - continue - } - - // 广播消息给所有客户端 - s.broadcast <- message } } func (s *WebSocketServer) run() { for { select { + case <-s.done: + return case client := <-s.register: s.mu.Lock() s.clients[client] = true diff --git a/app/websocket/internal/service/message_handler.go b/app/websocket/internal/service/kafka_message_handler.go similarity index 75% rename from app/websocket/internal/service/message_handler.go rename to app/websocket/internal/service/kafka_message_handler.go index 4083d42..9192fd7 100644 --- a/app/websocket/internal/service/message_handler.go +++ b/app/websocket/internal/service/kafka_message_handler.go @@ -11,6 +11,7 @@ import ( // MessageHandler 处理 Kafka 消息 type MessageHandler struct { wsService *WebSocketService + reader *kafkago.Reader } // WebSocketMessage 定义 WebSocket 消息结构 @@ -20,20 +21,13 @@ type WebSocketMessage struct { Payload json.RawMessage `json:"payload,omitempty"` } -func NewMessageHandler(wsService *WebSocketService) *MessageHandler { +func NewMessageHandler(wsService *WebSocketService, reader *kafkago.Reader) *MessageHandler { return &MessageHandler{ wsService: wsService, + reader: reader, } } -func (h *MessageHandler) Start() error { - return nil -} - -func (h *MessageHandler) Stop() error { - return nil -} - // parseJSONMessage 解析并格式化 JSON 消息 func (h *MessageHandler) parseJSONMessage(message []byte) (string, error) { var jsonData interface{} @@ -65,7 +59,7 @@ func (h *MessageHandler) HandleMessage(message *kafkago.Message) error { log.Printf("Error parsing message: %v", err) return err } - fmt.Printf("Received message:\n%s\n", prettyJSON) + fmt.Printf("kafka parseJSONMessage:\n%s\n", prettyJSON) // 解析为 WebSocketMessage wsMsg, err := h.parseWebSocketMessage(message.Value) @@ -75,18 +69,27 @@ func (h *MessageHandler) HandleMessage(message *kafkago.Message) error { } // 根据消息类型处理不同的业务逻辑 + var processErr error switch wsMsg.Type { case "message": // 处理普通消息 - return h.wsService.Broadcast(message.Value) + processErr = h.wsService.Broadcast(message.Value) case "broadcast": // 广播消息给所有连接的客户端 - return h.wsService.Broadcast(wsMsg.Payload) + processErr = h.wsService.Broadcast(wsMsg.Payload) case "private": // 处理私聊消息 - return h.wsService.SendPrivateMessage(wsMsg.Payload) + processErr = h.wsService.SendPrivateMessage(wsMsg.Payload) default: log.Printf("Unknown message type: %s", wsMsg.Type) - return nil } + + if processErr != nil { + log.Printf("Error processing message: %v", processErr) + } else { + log.Printf("Message processed successfully - Topic: %s, Partition: %d, Offset: %d", + message.Topic, message.Partition, message.Offset) + } + + return processErr } diff --git a/pkg/kafka/consumer.go b/pkg/kafka/consumer.go index 85ac62b..4cc1165 100644 --- a/pkg/kafka/consumer.go +++ b/pkg/kafka/consumer.go @@ -13,7 +13,7 @@ type KafkaConsumer struct { Reader *kafka.Reader } -func NewKafkaReader(brokers []string, topic string, partition int) *KafkaConsumer { +func NewKafkaReader(brokers []string, topic string, partition int, groupID string) *KafkaConsumer { return &KafkaConsumer{ Reader: kafka.NewReader(kafka.ReaderConfig{ Brokers: brokers, @@ -23,6 +23,7 @@ func NewKafkaReader(brokers []string, topic string, partition int) *KafkaConsume MaxBytes: 10e6, MaxWait: 500 * time.Millisecond, CommitInterval: 5 * time.Second, + // GroupID: groupID, }), } }