diff --git a/internal/rest/chat/chat.go b/internal/rest/chat/chat.go index e780738..35debf8 100644 --- a/internal/rest/chat/chat.go +++ b/internal/rest/chat/chat.go @@ -14,6 +14,7 @@ import ( "socio/usecase/csrf" "strconv" "strings" + "sync" "time" "github.com/gorilla/mux" @@ -36,6 +37,7 @@ const ( type ChatServer struct { Service ChatService + wsConns *sync.Map } type ChatService interface { @@ -63,9 +65,25 @@ var upgrader = &websocket.Upgrader{ func NewChatServer(service ChatService) (chatServer *ChatServer) { return &ChatServer{ Service: service, + wsConns: &sync.Map{}, } } +func (c *ChatServer) getWSConns(userID uint) (conns []*websocket.Conn, ok bool) { + untypedConns, ok := c.wsConns.Load(userID) + if !ok { + conns = nil + return + } + + conns, ok = untypedConns.([]*websocket.Conn) + if !ok { + return + } + + return +} + // HandleGetDialogs godoc // // @Summary get user dialogs @@ -246,7 +264,18 @@ func (c *ChatServer) ServeWS(w http.ResponseWriter, r *http.Request) { return } - go c.listenWrite(r.Context(), conn, client) + conns, ok := c.wsConns.Load(userID) + if !ok { + conns = make([]*websocket.Conn, 0, 1) + conns = append(conns.([]*websocket.Conn), conn) + c.wsConns.Store(userID, conns) + + go c.listenWrite(r.Context(), client) + } else { + conns = append(conns.([]*websocket.Conn), conn) + c.wsConns.Store(userID, conns) + } + go c.listenRead(r.Context(), conn, client) } @@ -308,17 +337,25 @@ func (c *ChatServer) listenRead(ctx context.Context, conn *websocket.Conn, clien } } -func (c *ChatServer) listenWrite(ctx context.Context, conn *websocket.Conn, client *chat.Client) { +func (c *ChatServer) listenWrite(ctx context.Context, client *chat.Client) { ticker := time.NewTicker(pingPeriod) defer func() { ticker.Stop() - err := conn.Close() - if err != nil { + + conns, ok := c.getWSConns(client.UserID) + if !ok { return } - err = c.Service.Unregister(client.UserID) + for _, conn := range conns { + err := conn.Close() + if err != nil { + return + } + } + + err := c.Service.Unregister(client.UserID) if err != nil { return } @@ -327,64 +364,82 @@ func (c *ChatServer) listenWrite(ctx context.Context, conn *websocket.Conn, clie for { select { case message, ok := <-client.Send: - err := conn.SetWriteDeadline(time.Now().Add(writeWait)) + messages := make([][]byte, 0, len(client.Send)+1) + + messageData, err := easyjson.Marshal(message) if err != nil { return } - if !ok { - err := conn.WriteMessage(websocket.CloseMessage, []byte{}) + messages = append(messages, messageData) + + n := len(client.Send) + for i := 0; i < n; i++ { + messageData, err = easyjson.Marshal(<-client.Send) if err != nil { return } - return - } - - w, err := conn.NextWriter(websocket.TextMessage) - if err != nil { - return - } - - messageData, err := easyjson.Marshal(message) - if err != nil { - return + messages = append(messages, messageData) } - _, err = w.Write(messageData) - if err != nil { + conns, ok := c.getWSConns(client.UserID) + if !ok { return } - n := len(client.Send) - for i := 0; i < n; i++ { - messageData, err = easyjson.Marshal(<-client.Send) + for _, conn := range conns { + err := conn.SetWriteDeadline(time.Now().Add(writeWait)) if err != nil { return } - _, err := w.Write([]byte{newline}) - if err != nil { + + if !ok { + err := conn.WriteMessage(websocket.CloseMessage, []byte{}) + if err != nil { + return + } + return } - _, err = w.Write(messageData) + w, err := conn.NextWriter(websocket.TextMessage) if err != nil { return } - } - if err := w.Close(); err != nil { - return + for _, message := range messages { + _, err := w.Write([]byte{newline}) + if err != nil { + return + } + + _, err = w.Write(message) + if err != nil { + return + } + } + + if err := w.Close(); err != nil { + return + } } case <-ticker.C: - err := conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err != nil { + conns, ok := c.getWSConns(client.UserID) + if !ok { return } - if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { - return + for _, conn := range conns { + err := conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err != nil { + return + } + + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } } } }