diff --git a/rpc/rpc.go b/rpc/rpc.go index 29c2a677..f49fe56f 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -663,6 +663,16 @@ func (c *Conn) handleBootstrap(ctx context.Context, id answerID) error { return err } +func idempotent(f func()) func() { + called := false + return func() { + if !called { + called = true + f() + } + } +} + func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capnp.ReleaseFunc) error { rl := &releaseList{} defer rl.Release() @@ -694,18 +704,25 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn return nil } - c.lk.Lock() - if c.lk.answers[id] != nil { - c.lk.Unlock() - releaseCall() - return rpcerr.Failedf("incoming call: answer ID %d reused", id) - } + var ( + err error + p parsedCall + parseErr error + ) + syncutil.With(&c.lk, func() { + if c.lk.answers[id] != nil { + rl.Add(releaseCall) + err = rpcerr.Failedf("incoming call: answer ID %d reused", id) + return + } - var p parsedCall - parseErr := c.parseCall(&p, call) // parseCall sets CapTable + parseErr = c.parseCall(&p, call) // parseCall sets CapTable + }) + if err != nil { + return err + } // Create return message. - c.lk.Unlock() ret, send, retReleaser, err := c.newReturn() if err != nil { err = rpcerr.Annotate(err, "incoming call") @@ -720,7 +737,6 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn ret.SetReleaseParamCaps(false) // Find target and start call. - c.lk.Lock() ans := &answer{ c: c, id: id, @@ -728,23 +744,27 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn sendMsg: send, msgReleaser: retReleaser, } + c.lk.Lock() + defer c.lk.Unlock() + c.lk.answers[id] = ans if parseErr != nil { parseErr = rpcerr.Annotate(parseErr, "incoming call") ans.sendException(rl, parseErr) - c.lk.Unlock() - c.er.ReportError(parseErr) - releaseCall() + rl.Add(func() { + c.er.ReportError(parseErr) + releaseCall() + }) return nil } - released := false - releaseArgs := func() { - if released { - return - } - released = true - releaseCall() + + recv := capnp.Recv{ + Args: p.args, + Method: p.method, + ReleaseArgs: idempotent(releaseCall), + Returner: ans, } + switch p.target.which { case rpccp.MessageTarget_Which_importedCap: ent := c.findExport(p.target.importedCap) @@ -752,25 +772,22 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn ans.ret = rpccp.Return{} ans.sendMsg = nil ans.msgReleaser = nil - c.lk.Unlock() - retReleaser.Decr() - releaseCall() + rl.Add(func() { + retReleaser.Decr() + releaseCall() + }) return rpcerr.Failedf("incoming call: unknown export ID %d", id) } c.tasks.Add(1) // will be finished by answer.Return var callCtx context.Context callCtx, ans.cancel = context.WithCancel(c.bgctx) - c.lk.Unlock() - pcall := ent.client.RecvCall(callCtx, capnp.Recv{ - Args: p.args, - Method: p.method, - ReleaseArgs: releaseArgs, - Returner: ans, + rl.Add(func() { + pcall := ent.client.RecvCall(callCtx, recv) + // Place PipelineCaller into answer. Since the receive goroutine is + // the only one that uses answer.pcall, it's fine that there's a + // time gap for this being set. + ans.setPipelineCaller(p.method, pcall) }) - // Place PipelineCaller into answer. Since the receive goroutine is - // the only one that uses answer.pcall, it's fine that there's a - // time gap for this being set. - ans.setPipelineCaller(p.method, pcall) return nil case rpccp.MessageTarget_Which_promisedAnswer: tgtAns := c.lk.answers[p.target.promisedAnswer] @@ -778,16 +795,16 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn ans.ret = rpccp.Return{} ans.sendMsg = nil ans.msgReleaser = nil - c.lk.Unlock() - retReleaser.Decr() - releaseCall() + rl.Add(func() { + retReleaser.Decr() + releaseCall() + }) return rpcerr.Failedf("incoming call: use of unknown or finished answer ID %d for promised answer target", p.target.promisedAnswer) } if tgtAns.flags.Contains(resultsReady) { if tgtAns.err != nil { ans.sendException(rl, tgtAns.err) - c.lk.Unlock() - releaseCall() + rl.Add(releaseCall) return nil } // tgtAns.results is guaranteed to stay alive because it hasn't @@ -798,8 +815,7 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn if err != nil { err = rpcerr.Failedf("incoming call: read results from target answer: %w", err) ans.sendException(rl, err) - c.lk.Unlock() - releaseCall() + rl.Add(releaseCall) c.er.ReportError(err) return nil } @@ -807,8 +823,7 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn if err != nil { // Not reporting, as this is the caller's fault. ans.sendException(rl, err) - c.lk.Unlock() - releaseCall() + rl.Add(releaseCall) return nil } iface := sub.Interface() @@ -824,14 +839,10 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn c.tasks.Add(1) // will be finished by answer.Return var callCtx context.Context callCtx, ans.cancel = context.WithCancel(c.bgctx) - c.lk.Unlock() - pcall := tgt.RecvCall(callCtx, capnp.Recv{ - Args: p.args, - Method: p.method, - ReleaseArgs: releaseArgs, - Returner: ans, + rl.Add(func() { + pcall := tgt.RecvCall(callCtx, recv) + ans.setPipelineCaller(p.method, pcall) }) - ans.setPipelineCaller(p.method, pcall) } else { // Results not ready, use pipeline caller. tgtAns.pcalls.Add(1) // will be finished by answer.Return @@ -839,15 +850,11 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn callCtx, ans.cancel = context.WithCancel(c.bgctx) tgt := tgtAns.pcall c.tasks.Add(1) // will be finished by answer.Return - c.lk.Unlock() - pcall := tgt.PipelineRecv(callCtx, p.target.transform, capnp.Recv{ - Args: p.args, - Method: p.method, - ReleaseArgs: releaseArgs, - Returner: ans, + rl.Add(func() { + pcall := tgt.PipelineRecv(callCtx, p.target.transform, recv) + tgtAns.pcalls.Done() + ans.setPipelineCaller(p.method, pcall) }) - tgtAns.pcalls.Done() - ans.setPipelineCaller(p.method, pcall) } return nil default: