Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handleCall: cleanups & locking simplifications #403

Merged
merged 5 commits into from
Dec 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 64 additions & 57 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,16 @@ func (c *Conn) handleBootstrap(ctx context.Context, id answerID) error {
return err
}

func idempotent(f func()) func() {
called := false
return func() {
if !called {
called = true
f()
}
}
}

func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capnp.ReleaseFunc) error {
rl := &releaseList{}
defer rl.Release()
Expand Down Expand Up @@ -694,18 +704,25 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
return nil
}

c.lk.Lock()
if c.lk.answers[id] != nil {
c.lk.Unlock()
releaseCall()
return rpcerr.Failedf("incoming call: answer ID %d reused", id)
}
var (
err error
p parsedCall
parseErr error
)
syncutil.With(&c.lk, func() {
if c.lk.answers[id] != nil {
rl.Add(releaseCall)
err = rpcerr.Failedf("incoming call: answer ID %d reused", id)
return
}

var p parsedCall
parseErr := c.parseCall(&p, call) // parseCall sets CapTable
parseErr = c.parseCall(&p, call) // parseCall sets CapTable
})
if err != nil {
return err
}

// Create return message.
c.lk.Unlock()
ret, send, retReleaser, err := c.newReturn()
if err != nil {
err = rpcerr.Annotate(err, "incoming call")
Expand All @@ -720,74 +737,74 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
ret.SetReleaseParamCaps(false)

// Find target and start call.
c.lk.Lock()
ans := &answer{
c: c,
id: id,
ret: ret,
sendMsg: send,
msgReleaser: retReleaser,
}
c.lk.Lock()
defer c.lk.Unlock()

c.lk.answers[id] = ans
if parseErr != nil {
parseErr = rpcerr.Annotate(parseErr, "incoming call")
ans.sendException(rl, parseErr)
c.lk.Unlock()
c.er.ReportError(parseErr)
releaseCall()
rl.Add(func() {
c.er.ReportError(parseErr)
releaseCall()
})
return nil
}
released := false
releaseArgs := func() {
if released {
return
}
released = true
releaseCall()

recv := capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: idempotent(releaseCall),
Returner: ans,
}

switch p.target.which {
case rpccp.MessageTarget_Which_importedCap:
ent := c.findExport(p.target.importedCap)
if ent == nil {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.msgReleaser = nil
c.lk.Unlock()
retReleaser.Decr()
releaseCall()
rl.Add(func() {
retReleaser.Decr()
releaseCall()
})
return rpcerr.Failedf("incoming call: unknown export ID %d", id)
}
c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
c.lk.Unlock()
pcall := ent.client.RecvCall(callCtx, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := ent.client.RecvCall(callCtx, recv)
// Place PipelineCaller into answer. Since the receive goroutine is
// the only one that uses answer.pcall, it's fine that there's a
// time gap for this being set.
ans.setPipelineCaller(p.method, pcall)
})
// Place PipelineCaller into answer. Since the receive goroutine is
// the only one that uses answer.pcall, it's fine that there's a
// time gap for this being set.
ans.setPipelineCaller(p.method, pcall)
return nil
case rpccp.MessageTarget_Which_promisedAnswer:
tgtAns := c.lk.answers[p.target.promisedAnswer]
if tgtAns == nil || tgtAns.flags.Contains(finishReceived) {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.msgReleaser = nil
c.lk.Unlock()
retReleaser.Decr()
releaseCall()
rl.Add(func() {
retReleaser.Decr()
releaseCall()
})
return rpcerr.Failedf("incoming call: use of unknown or finished answer ID %d for promised answer target", p.target.promisedAnswer)
}
if tgtAns.flags.Contains(resultsReady) {
if tgtAns.err != nil {
ans.sendException(rl, tgtAns.err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
return nil
}
// tgtAns.results is guaranteed to stay alive because it hasn't
Expand All @@ -798,17 +815,15 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
if err != nil {
err = rpcerr.Failedf("incoming call: read results from target answer: %w", err)
ans.sendException(rl, err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
c.er.ReportError(err)
return nil
}
sub, err := capnp.Transform(content, p.target.transform)
if err != nil {
// Not reporting, as this is the caller's fault.
ans.sendException(rl, err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
return nil
}
iface := sub.Interface()
Expand All @@ -824,30 +839,22 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
c.lk.Unlock()
pcall := tgt.RecvCall(callCtx, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := tgt.RecvCall(callCtx, recv)
ans.setPipelineCaller(p.method, pcall)
})
ans.setPipelineCaller(p.method, pcall)
} else {
// Results not ready, use pipeline caller.
tgtAns.pcalls.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
tgt := tgtAns.pcall
c.tasks.Add(1) // will be finished by answer.Return
c.lk.Unlock()
pcall := tgt.PipelineRecv(callCtx, p.target.transform, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := tgt.PipelineRecv(callCtx, p.target.transform, recv)
tgtAns.pcalls.Done()
ans.setPipelineCaller(p.method, pcall)
})
tgtAns.pcalls.Done()
ans.setPipelineCaller(p.method, pcall)
}
return nil
default:
Expand Down