Skip to content

Commit

Permalink
client: check for client disconnect before sending opcodes (#142)
Browse files Browse the repository at this point in the history
  • Loading branch information
andreykaipov authored Mar 4, 2024
1 parent ce2d164 commit 1267b7f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
15 changes: 11 additions & 4 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ type Client struct {
// Ya like logs?
Log Logger

Disconnected chan bool

mutex sync.Mutex
}

Expand Down Expand Up @@ -75,10 +77,15 @@ func (c *Client) SendRequest(requestBody Params, responseBody Response) error {
c.mutex.Lock()
defer c.mutex.Unlock()

c.Opcodes <- &opcodes.Request{
Type: name,
ID: id,
Data: requestBody,
select {
case <-c.Disconnected:
return fmt.Errorf("request %s: client already disconnected", name)
default:
c.Opcodes <- &opcodes.Request{
Type: name,
ID: id,
Data: requestBody,
}
}

var response *opcodes.RequestResponse
Expand Down
13 changes: 6 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ type Client struct {
dialer *websocket.Dialer
requestHeader http.Header
eventSubscriptions int
disconnected chan bool
profiler *profile.Profile
once sync.Once
}
Expand Down Expand Up @@ -123,15 +122,15 @@ func (c *Client) Disconnect() error {
func (c *Client) markDisconnected() {
c.once.Do(func() {
select {
case c.disconnected <- true:
case c.client.Disconnected <- true:
default:
}

c.client.Log.Printf("[TRACE] Closing internal channels")
close(c.IncomingEvents)
close(c.client.Opcodes)
close(c.client.IncomingResponses)
close(c.disconnected)
close(c.client.Disconnected)
})
}

Expand All @@ -145,8 +144,8 @@ func New(host string, opts ...Option) (*Client, error) {
dialer: websocket.DefaultDialer,
requestHeader: http.Header{"User-Agent": []string{"goobs/" + LibraryVersion}},
eventSubscriptions: subscriptions.All,
disconnected: make(chan bool),
client: &api.Client{
Disconnected: make(chan bool),
IncomingResponses: make(chan *opcodes.RequestResponse),
Opcodes: make(chan opcodes.Opcode),
ResponseTimeout: 10000,
Expand Down Expand Up @@ -248,7 +247,7 @@ func (c *Client) handleRawServerMessages(auth chan<- error) {
c.markDisconnected()
default:
select {
case <-c.disconnected:
case <-c.client.Disconnected:
default:
c.client.Log.Printf("[ERROR] Unhandled error: %s", t)
}
Expand All @@ -260,7 +259,7 @@ func (c *Client) handleRawServerMessages(auth chan<- error) {
c.client.Log.Printf("[TRACE] Raw server message: %s", raw)

select {
case <-c.disconnected:
case <-c.client.Disconnected:
// This might happen if the server sends messages to us
// after we've already disconnected, e.g.:
//
Expand Down Expand Up @@ -362,7 +361,7 @@ func (c *Client) handleOpcodes(auth chan<- error) {
// to use it, they'll have the latest events available to them.
func (c *Client) writeEvent(event any) {
select {
case <-c.disconnected:
case <-c.client.Disconnected:
case c.IncomingEvents <- event:
default:
if len(c.IncomingEvents) == cap(c.IncomingEvents) {
Expand Down

0 comments on commit 1267b7f

Please sign in to comment.