Skip to content

Commit

Permalink
Merge pull request #403 from zenhack/handleCall-locking
Browse files Browse the repository at this point in the history
handleCall: cleanups & locking simplifications
  • Loading branch information
lthibault authored Dec 28, 2022
2 parents de17112 + 6da1fd2 commit 6d8c1b2
Showing 1 changed file with 64 additions and 57 deletions.
121 changes: 64 additions & 57 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,16 @@ func (c *Conn) handleBootstrap(ctx context.Context, id answerID) error {
return err
}

func idempotent(f func()) func() {
called := false
return func() {
if !called {
called = true
f()
}
}
}

func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capnp.ReleaseFunc) error {
rl := &releaseList{}
defer rl.Release()
Expand Down Expand Up @@ -694,18 +704,25 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
return nil
}

c.lk.Lock()
if c.lk.answers[id] != nil {
c.lk.Unlock()
releaseCall()
return rpcerr.Failedf("incoming call: answer ID %d reused", id)
}
var (
err error
p parsedCall
parseErr error
)
syncutil.With(&c.lk, func() {
if c.lk.answers[id] != nil {
rl.Add(releaseCall)
err = rpcerr.Failedf("incoming call: answer ID %d reused", id)
return
}

var p parsedCall
parseErr := c.parseCall(&p, call) // parseCall sets CapTable
parseErr = c.parseCall(&p, call) // parseCall sets CapTable
})
if err != nil {
return err
}

// Create return message.
c.lk.Unlock()
ret, send, retReleaser, err := c.newReturn()
if err != nil {
err = rpcerr.Annotate(err, "incoming call")
Expand All @@ -720,74 +737,74 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
ret.SetReleaseParamCaps(false)

// Find target and start call.
c.lk.Lock()
ans := &answer{
c: c,
id: id,
ret: ret,
sendMsg: send,
msgReleaser: retReleaser,
}
c.lk.Lock()
defer c.lk.Unlock()

c.lk.answers[id] = ans
if parseErr != nil {
parseErr = rpcerr.Annotate(parseErr, "incoming call")
ans.sendException(rl, parseErr)
c.lk.Unlock()
c.er.ReportError(parseErr)
releaseCall()
rl.Add(func() {
c.er.ReportError(parseErr)
releaseCall()
})
return nil
}
released := false
releaseArgs := func() {
if released {
return
}
released = true
releaseCall()

recv := capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: idempotent(releaseCall),
Returner: ans,
}

switch p.target.which {
case rpccp.MessageTarget_Which_importedCap:
ent := c.findExport(p.target.importedCap)
if ent == nil {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.msgReleaser = nil
c.lk.Unlock()
retReleaser.Decr()
releaseCall()
rl.Add(func() {
retReleaser.Decr()
releaseCall()
})
return rpcerr.Failedf("incoming call: unknown export ID %d", id)
}
c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
c.lk.Unlock()
pcall := ent.client.RecvCall(callCtx, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := ent.client.RecvCall(callCtx, recv)
// Place PipelineCaller into answer. Since the receive goroutine is
// the only one that uses answer.pcall, it's fine that there's a
// time gap for this being set.
ans.setPipelineCaller(p.method, pcall)
})
// Place PipelineCaller into answer. Since the receive goroutine is
// the only one that uses answer.pcall, it's fine that there's a
// time gap for this being set.
ans.setPipelineCaller(p.method, pcall)
return nil
case rpccp.MessageTarget_Which_promisedAnswer:
tgtAns := c.lk.answers[p.target.promisedAnswer]
if tgtAns == nil || tgtAns.flags.Contains(finishReceived) {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.msgReleaser = nil
c.lk.Unlock()
retReleaser.Decr()
releaseCall()
rl.Add(func() {
retReleaser.Decr()
releaseCall()
})
return rpcerr.Failedf("incoming call: use of unknown or finished answer ID %d for promised answer target", p.target.promisedAnswer)
}
if tgtAns.flags.Contains(resultsReady) {
if tgtAns.err != nil {
ans.sendException(rl, tgtAns.err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
return nil
}
// tgtAns.results is guaranteed to stay alive because it hasn't
Expand All @@ -798,17 +815,15 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
if err != nil {
err = rpcerr.Failedf("incoming call: read results from target answer: %w", err)
ans.sendException(rl, err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
c.er.ReportError(err)
return nil
}
sub, err := capnp.Transform(content, p.target.transform)
if err != nil {
// Not reporting, as this is the caller's fault.
ans.sendException(rl, err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
return nil
}
iface := sub.Interface()
Expand All @@ -824,30 +839,22 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
c.lk.Unlock()
pcall := tgt.RecvCall(callCtx, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := tgt.RecvCall(callCtx, recv)
ans.setPipelineCaller(p.method, pcall)
})
ans.setPipelineCaller(p.method, pcall)
} else {
// Results not ready, use pipeline caller.
tgtAns.pcalls.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
tgt := tgtAns.pcall
c.tasks.Add(1) // will be finished by answer.Return
c.lk.Unlock()
pcall := tgt.PipelineRecv(callCtx, p.target.transform, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := tgt.PipelineRecv(callCtx, p.target.transform, recv)
tgtAns.pcalls.Done()
ans.setPipelineCaller(p.method, pcall)
})
tgtAns.pcalls.Done()
ans.setPipelineCaller(p.method, pcall)
}
return nil
default:
Expand Down

0 comments on commit 6d8c1b2

Please sign in to comment.