diff --git a/regression_test.go b/regression_test.go index 5f917f6..01f409d 100644 --- a/regression_test.go +++ b/regression_test.go @@ -5,6 +5,7 @@ package jrpc2_test import ( "context" "strings" + "sync" "testing" "time" @@ -192,3 +193,51 @@ func TestCheckBatchDuplicateID(t *testing.T) { t.Errorf("Server response: (-want, +got)\n%s", diff) } } + +// Verify that callbacks from notification handlers cannot deadlock on delivery +// of their own replies. Reported in #78, test case courtesy of @radeksimko. +func TestServer_NotificationCallbackDeadlock(t *testing.T) { + defer leaktest.Check(t)() + + var wg sync.WaitGroup + loc := server.NewLocal(handler.Map{ + "NotifyMe": handler.New(func(ctx context.Context) error { + defer wg.Done() + if _, err := jrpc2.ServerFromContext(ctx).Callback(ctx, "succeed", nil); err != nil { + t.Errorf("Callback failed: %v", err) + } + return nil + }), + }, &server.LocalOptions{ + Server: &jrpc2.ServerOptions{AllowPush: true}, + Client: &jrpc2.ClientOptions{ + OnCallback: func(ctx context.Context, req *jrpc2.Request) (interface{}, error) { + switch req.Method() { + case "succeed": + return true, nil + } + panic("broken test: you should not see this") + }, + }, + }) + defer loc.Close() + ctx := context.Background() + + // Call the notification method that posts a callback. + wg.Add(2) + if err := loc.Client.Notify(ctx, "NotifyMe", nil); err != nil { + t.Fatalf("Notify: unexpected error: %v", err) + } + if err := loc.Client.Notify(ctx, "NotifyMe", nil); err != nil { + t.Fatalf("Notify: unexpected error: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + // all is well + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for callbacks to return") + } +} diff --git a/server.go b/server.go index 735557c..837d6f7 100644 --- a/server.go +++ b/server.go @@ -281,31 +281,21 @@ func (s *Server) checkAndAssign(next jmessages) tasks { var ids []string dup := make(map[string]*task) // :: id ⇒ first task in batch with id - // Phase 1: Filter out responses from push calls and check for duplicate - // request ID.s + // Phase 1: Check for errors and duplicate request IDs. for _, req := range next { - fid := fixID(req.ID) - id := string(fid) - if !req.isRequestOrNotification() && s.call[id] != nil { - // This is a result or error for a pending push-call. - // - // N.B. It is important to check for this before checking for - // duplicate request IDs, since the ID spaces could overlap. - rsp := s.call[id] - delete(s.call, id) - rsp.ch <- req - continue // don't send a reply for this - } else if req.err != nil { + if req.err != nil { // keep the existing error } else if !s.versionOK(req.V) { req.err = ErrInvalidVersion } + fid := fixID(req.ID) t := &task{ hreq: &Request{id: fid, method: req.M, params: req.P}, batch: req.batch, err: req.err, } + id := string(fid) if old := dup[id]; old != nil { // A previous task already used this ID, fail both. old.err = errDuplicateID.WithData(id) @@ -651,16 +641,51 @@ func (s *Server) read(ch receiver) { } else if len(in) == 0 { s.pushError(errEmptyBatch) } else { - s.log("Received request batch of size %d (qlen=%d)", len(in), s.inq.size()) - s.inq.push(in) - if s.inq.size() == 1 { // the queue was empty - s.signal() + // Filter out response messages. It's possible that the entire batch + // was responses, so re-check the length after doing this. + keep := s.filterBatch(in) + if len(keep) != 0 { + s.log("Received request batch of size %d (qlen=%d)", len(keep), s.inq.size()) + s.inq.push(keep) + if s.inq.size() == 1 { // the queue was empty + s.signal() + } } } s.mu.Unlock() } } +// filterBatch removes and handles any response messages from next, dispatching +// replies to pending callbacks as required. The remainder is returned. +// The caller must hold s.mu, and must re-check that the result is not empty. +func (s *Server) filterBatch(next jmessages) jmessages { + keep := make(jmessages, 0, len(next)) + for _, req := range next { + if req.isRequestOrNotification() { + keep = append(keep, req) + continue + } + + // If this is a response implicating the ID of a pending push-call, + // deliver the result to that call. Do this early to avoid deadlocking on + // the sequencing barrier (see #78). + // + // Note, however, if it does NOT correspond to a known push-call, keep it + // in the batch so it can be serviced as an error. + id := string(fixID(req.ID)) + if s.call[id] != nil { + rsp := s.call[id] + delete(s.call, id) + rsp.ch <- req + s.log("Received response for callback %q", id) + } else { + keep = append(keep, req) + } + } + return keep +} + // ServerInfo is the concrete type of responses from the rpc.serverInfo method. type ServerInfo struct { // The list of method names exported by this server.