From 1267b7f47597143743e1d2e847d6be7bc580ccc8 Mon Sep 17 00:00:00 2001 From: Andrey Kaipov <9457739+andreykaipov@users.noreply.github.com> Date: Sun, 3 Mar 2024 23:14:59 -0500 Subject: [PATCH] client: check for client disconnect before sending opcodes (#142) --- api/client.go | 15 +++++++++++---- client.go | 13 ++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/api/client.go b/api/client.go index 0f7d23b6..134c1228 100644 --- a/api/client.go +++ b/api/client.go @@ -42,6 +42,8 @@ type Client struct { // Ya like logs? Log Logger + Disconnected chan bool + mutex sync.Mutex } @@ -75,10 +77,15 @@ func (c *Client) SendRequest(requestBody Params, responseBody Response) error { c.mutex.Lock() defer c.mutex.Unlock() - c.Opcodes <- &opcodes.Request{ - Type: name, - ID: id, - Data: requestBody, + select { + case <-c.Disconnected: + return fmt.Errorf("request %s: client already disconnected", name) + default: + c.Opcodes <- &opcodes.Request{ + Type: name, + ID: id, + Data: requestBody, + } } var response *opcodes.RequestResponse diff --git a/client.go b/client.go index 015e1970..a49647a6 100644 --- a/client.go +++ b/client.go @@ -37,7 +37,6 @@ type Client struct { dialer *websocket.Dialer requestHeader http.Header eventSubscriptions int - disconnected chan bool profiler *profile.Profile once sync.Once } @@ -123,7 +122,7 @@ func (c *Client) Disconnect() error { func (c *Client) markDisconnected() { c.once.Do(func() { select { - case c.disconnected <- true: + case c.client.Disconnected <- true: default: } @@ -131,7 +130,7 @@ func (c *Client) markDisconnected() { close(c.IncomingEvents) close(c.client.Opcodes) close(c.client.IncomingResponses) - close(c.disconnected) + close(c.client.Disconnected) }) } @@ -145,8 +144,8 @@ func New(host string, opts ...Option) (*Client, error) { dialer: websocket.DefaultDialer, requestHeader: http.Header{"User-Agent": []string{"goobs/" + LibraryVersion}}, eventSubscriptions: subscriptions.All, - disconnected: make(chan bool), client: &api.Client{ + Disconnected: make(chan bool), IncomingResponses: make(chan *opcodes.RequestResponse), Opcodes: make(chan opcodes.Opcode), ResponseTimeout: 10000, @@ -248,7 +247,7 @@ func (c *Client) handleRawServerMessages(auth chan<- error) { c.markDisconnected() default: select { - case <-c.disconnected: + case <-c.client.Disconnected: default: c.client.Log.Printf("[ERROR] Unhandled error: %s", t) } @@ -260,7 +259,7 @@ func (c *Client) handleRawServerMessages(auth chan<- error) { c.client.Log.Printf("[TRACE] Raw server message: %s", raw) select { - case <-c.disconnected: + case <-c.client.Disconnected: // This might happen if the server sends messages to us // after we've already disconnected, e.g.: // @@ -362,7 +361,7 @@ func (c *Client) handleOpcodes(auth chan<- error) { // to use it, they'll have the latest events available to them. func (c *Client) writeEvent(event any) { select { - case <-c.disconnected: + case <-c.client.Disconnected: case c.IncomingEvents <- event: default: if len(c.IncomingEvents) == cap(c.IncomingEvents) {