diff --git a/client.go b/client.go index 2eeabb1..083d316 100644 --- a/client.go +++ b/client.go @@ -32,6 +32,8 @@ type Client struct { Categories conn *websocket.Conn + connWLocker sync.Mutex + connRLocker sync.Mutex scheme string host string password string @@ -120,7 +122,7 @@ func (c *Client) Disconnect() error { c.client.Log.Printf("[DEBUG] Sending disconnect message") c.markDisconnected() - if err := c.conn.WriteMessage( + if err := c.writeMessage( websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Bye"), ); err != nil { @@ -131,6 +133,15 @@ func (c *Client) Disconnect() error { return nil } +func (c *Client) writeMessage(messageType int, data []byte) error { + c.connWLocker.Lock() + defer c.connWLocker.Unlock() + return c.conn.WriteMessage( + messageType, + data, + ) +} + func (c *Client) markDisconnected() { c.once.Do(func() { select { @@ -222,13 +233,19 @@ func (c *Client) connect() (err error) { } } +func (c *Client) readJSON(v any) error { + c.connRLocker.Lock() + defer c.connRLocker.Unlock() + return c.conn.ReadJSON(v) +} + // translates raw server messages into opcodes func (c *Client) handleRawServerMessages(auth chan<- error) { defer c.client.Log.Printf("[TRACE] Finished handling raw server messages") for { raw := json.RawMessage{} - if err := c.conn.ReadJSON(&raw); err != nil { + if err := c.readJSON(&raw); err != nil { switch t := err.(type) { case *json.UnmarshalTypeError: c.client.Log.Printf("[ERROR] Reading from connection: %s: %s", t, raw) @@ -323,7 +340,7 @@ func (c *Client) handleOpcodes(auth chan<- error) { c.client.Log.Printf("[INFO] Identify;") msg := opcodes.Wrap(val).Bytes() - if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil { + if err := c.writeMessage(websocket.TextMessage, msg); err != nil { auth <- fmt.Errorf("sending Identify to server `%s`: %w", msg, err) } @@ -354,7 +371,7 @@ func (c *Client) handleOpcodes(auth chan<- error) { c.client.Log.Printf("[TRACE] Got %s Request with ID %s", val.Type, val.ID) msg := opcodes.Wrap(val).Bytes() - if err := c.conn.WriteMessage(websocket.TextMessage, msg); err != nil { + if err := c.writeMessage(websocket.TextMessage, msg); err != nil { c.client.Log.Printf("[ERROR] Sending Request to server `%s`: %s", msg, err) }