Skip to content

Commit

Permalink
fix(GODT-1586): Implement context cancellation for commands
Browse files Browse the repository at this point in the history
Check for context cancellation for long running commands such as list,
fetch, lsub and search.

The command reading process has also been updated so that it returns the
error in the channel. This way we can report the error back to the
client.
  • Loading branch information
LBeernaertProton committed Sep 1, 2022
1 parent bdbefef commit ccc6245
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 69 deletions.
35 changes: 16 additions & 19 deletions internal/session/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package session

import (
"context"
"errors"
"io"
"runtime/pprof"
"strconv"

Expand All @@ -14,39 +12,38 @@ import (
type command struct {
tag string
cmd *proto.Command
err error
}

func (s *Session) getCommandCh(ctx context.Context, del string) <-chan command {
func (s *Session) startCommandReader(ctx context.Context, del string) <-chan command {
cmdCh := make(chan command)

go func() {
labels := pprof.Labels("go", "CommandReader", "SessionID", strconv.Itoa(s.sessionID))
pprof.Do(ctx, labels, func(_ context.Context) {
defer close(cmdCh)

for {
tag, cmd, err := s.readCommand(del)
if err != nil {
if errors.Is(err, io.EOF) {
return
} else if err := response.Bad(tag).WithError(err).Send(s); err != nil {
return

if err == nil && cmd.GetStartTLS() != nil {
// TLS needs to be handled here in order to ensure that next command read is over the
// tls connection.
if e := s.handleStartTLS(tag, cmd.GetStartTLS()); e != nil {
cmd = nil
err = e
} else {
continue
}

continue
}

switch {
case cmd.GetStartTLS() != nil:
if err := s.handleStartTLS(tag, cmd.GetStartTLS()); err != nil {
if err := response.Bad(tag).WithError(err).Send(s); err != nil {
return
}
select {
case cmdCh <- command{tag: tag, cmd: cmd, err: err}:

continue
}
case <-ctx.Done():
return

default:
cmdCh <- command{tag: tag, cmd: cmd}
}
}
})
Expand Down
9 changes: 7 additions & 2 deletions internal/session/handle_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@ func (s *Session) handleList(ctx context.Context, tag string, cmd *proto.List, c

return s.state.List(ctx, cmd.GetReference(), nameUTF8, false, func(matches map[string]state.Match) error {
for _, match := range matches {
ch <- response.List().
select {
case ch <- response.List().
WithName(match.Name).
WithDelimiter(match.Delimiter).
WithAttributes(match.Atts)
WithAttributes(match.Atts):

case <-ctx.Done():
return ctx.Err()
}
}

ch <- response.Ok(tag).WithMessage("LIST")
Expand Down
9 changes: 7 additions & 2 deletions internal/session/handle_lsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@ func (s *Session) handleLsub(ctx context.Context, tag string, cmd *proto.Lsub, c

return s.state.List(ctx, cmd.GetReference(), nameUTF8, true, func(matches map[string]state.Match) error {
for _, match := range matches {
ch <- response.Lsub().
select {
case ch <- response.Lsub().
WithName(match.Name).
WithDelimiter(match.Delimiter).
WithAttributes(match.Atts)
WithAttributes(match.Atts):

case <-ctx.Done():
return ctx.Err()
}
}

ch <- response.Ok(tag).WithMessage("LSUB")
Expand Down
7 changes: 6 additions & 1 deletion internal/session/handle_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ func (s *Session) handleSearch(ctx context.Context, tag string, cmd *proto.Searc
return err
}

ch <- response.Search(seq...)
select {
case ch <- response.Search(seq...):

case <-ctx.Done():
return ctx.Err()
}

var items []response.Item

Expand Down
36 changes: 28 additions & 8 deletions internal/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package session
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -114,15 +115,15 @@ func (s *Session) Serve(ctx context.Context) error {
return err
}

if err := s.serve(ctx, s.getCommandCh(ctx, s.backend.GetDelimiter())); err != nil {
if err := s.serve(ctx); err != nil {
logrus.WithError(err).Errorf("Failed to serve session %v", s.sessionID)
return err
}

return nil
}

func (s *Session) serve(ctx context.Context, cmdCh <-chan command) error {
func (s *Session) serve(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand All @@ -134,15 +135,31 @@ func (s *Session) serve(ctx context.Context, cmdCh <-chan command) error {
cmd *proto.Command
)

cmdCh := s.startCommandReader(ctx, s.backend.GetDelimiter())

for {
select {
case res, ok := <-cmdCh:
if !ok {
logrus.Debugf("Failed to read from command channel")
return nil
}

tag, cmd = res.tag, res.cmd

if res.err != nil {
logrus.WithError(res.err).Debugf("Error during command parsing")

if errors.Is(res.err, io.EOF) {
logrus.Debugf("Connection to client lost")
return nil
} else if err := response.Bad(tag).WithError(res.err).Send(s); err != nil {
return err
}

continue
}

case <-s.state.Done():
return nil

Expand Down Expand Up @@ -178,12 +195,15 @@ func (s *Session) serve(ctx context.Context, cmdCh <-chan command) error {
responseCh := s.handleOther(withStartTime(ctx, time.Now()), tag, cmd, profiler)
for res := range responseCh {
if err := res.Send(s); err != nil {
// Consume all remaining channel response since the connection is no longer available.
// Failing to do so can cause a deadlock in the program as `s.handleOther` never finishes
// executing and can hold onto a number of locks indefinitely.
for range responseCh {
// ...
}
go func() {
// Consume all remaining channel response since the connection is no longer available.
// Failing to do so can cause a deadlock in the program as `s.handleOther` never finishes
// executing and can hold onto a number of locks indefinitely.
// Consumed on a separate go routine to not block the return.
for range responseCh {
// ...
}
}()

return fmt.Errorf("failed to send response to client: %w", err)
}
Expand Down
7 changes: 6 additions & 1 deletion internal/state/mailbox_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ func (m *Mailbox) Fetch(ctx context.Context, seq *proto.SequenceSet, attributes
return err
}

ch <- response.Fetch(seq).WithItems(items...)
select {
case ch <- response.Fetch(seq).WithItems(items...):

case <-ctx.Done():
return ctx.Err()
}
}

return nil
Expand Down
Loading

0 comments on commit ccc6245

Please sign in to comment.