From 198406ce74f222157126992280e8df21de2c3c5d Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Fri, 31 Mar 2023 20:32:31 -0400 Subject: [PATCH 1/4] Split answer struct into ansent and ansReturner. ...based on whether or not we need to access the field through c.lk. --- rpc/answer.go | 155 +++++++++++++++++++++++++++----------------------- rpc/rpc.go | 88 ++++++++++++++-------------- 2 files changed, 130 insertions(+), 113 deletions(-) diff --git a/rpc/answer.go b/rpc/answer.go index f1390877..c57ce0de 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -14,36 +14,8 @@ import ( // An answerID is an index into the answers table. type answerID uint32 -// answer is an entry in a Conn's answer table. -type answer struct { - // c and id must be set before any answer methods are called. - c *Conn - id answerID - - // cancel cancels the Context used in the received method call. - // May be nil. - cancel context.CancelFunc - - // ret is the outgoing Return struct. ret is valid iff there was no - // error creating the message. If ret is invalid, then this answer - // entry is a placeholder until the remote vat cancels the call. - ret rpccp.Return - - // sendMsg sends the return message. The caller MUST hold ans.c.lk; - // the argument should be the same as ans.c - sendMsg func(*lockedConn) - - // msgReleaser releases the return message when its refcount hits zero. - // The caller MUST NOT hold ans.c.lk. - msgReleaser *rc.Releaser - - // results is the memoized answer to ret.Results(). - // Set by AllocResults and setBootstrap, but contents can only be read - // if flags has resultsReady but not finishReceived set. - results rpccp.Payload - - // All fields below are protected by s.c.mu. - +// ansent is an entry in a Conn's answer table. +type ansent struct { // flags is a bitmask of events that have occurred in an answer's // lifetime. flags answerFlags @@ -68,6 +40,40 @@ type answer struct { // the Return message. Can only be read after resultsReady is set in // flags. err error + + // sendMsg sends the return message. The caller MUST hold ans.c.lk; + // the argument should be the same as ans.c + sendMsg func(*lockedConn) + + // Unlike other fields in this struct, it is ok to hand out pointers + // to this that can be used while not holding the connection lock. + returner ansReturner +} + +// ansReturner is the implementation of capnp.Returner that is used when +// handling an incoming call on a local capability. +type ansReturner struct { + // c and id must be set before any answer methods are called. + c *Conn + id answerID + + // cancel cancels the Context used in the received method call. + // May be nil. + cancel context.CancelFunc + + // ret is the outgoing Return struct. ret is valid iff there was no + // error creating the message. If ret is invalid, then this answer + // entry is a placeholder until the remote vat cancels the call. + ret rpccp.Return + + // msgReleaser releases the return message when its refcount hits zero. + // The caller MUST NOT hold ans.c.lk. + msgReleaser *rc.Releaser + + // results is the memoized answer to ret.Results(). + // Set by AllocResults and setBootstrap, but contents can only be read + // if flags has resultsReady but not finishReceived set. + results rpccp.Payload } type answerFlags uint8 @@ -85,11 +91,13 @@ func (flags answerFlags) Contains(flag answerFlags) bool { return flags&flag != 0 } -// errorAnswer returns a placeholder answer with an error result already set. -func errorAnswer(c *Conn, id answerID, err error) *answer { - return &answer{ - c: c, - id: id, +// errorAnswer returns a placeholder answer entry with an error result already set. +func errorAnswer(c *Conn, id answerID, err error) *ansent { + return &ansent{ + returner: ansReturner{ + c: c, + id: id, + }, err: err, flags: resultsReady | returnSent, } @@ -134,8 +142,8 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel // already returned. The caller MUST hold ans.c.lk. // // This also sets ans.promise to a new promise, wrapping pcall. -func (ans *answer) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp.PipelineCaller) { - c.assertIs(ans.c) +func (ans *ansent) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp.PipelineCaller) { + c.assertIs(ans.returner.c) if !ans.flags.Contains(resultsReady) { ans.pcall = pcall @@ -144,7 +152,7 @@ func (ans *answer) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp. } // AllocResults allocates the results struct. -func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) { +func (ans *ansReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) { var err error ans.results, err = ans.ret.NewResults() if err != nil { @@ -163,7 +171,7 @@ func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) { // setBootstrap sets the results to an interface pointer, stealing the // reference. -func (ans *answer) setBootstrap(c capnp.Client) error { +func (ans *ansReturner) setBootstrap(c capnp.Client) error { if ans.ret.HasResults() || len(ans.ret.Message().CapTable) > 0 { panic("setBootstrap called after creating results") } @@ -183,33 +191,38 @@ func (ans *answer) setBootstrap(c capnp.Client) error { } // PrepareReturn implements capnp.Returner.PrepareReturn -func (ans *answer) PrepareReturn(e error) { +func (ans *ansReturner) PrepareReturn(e error) { rl := &releaseList{} defer rl.Release() ans.c.withLocked(func(c *lockedConn) { + ent := c.lk.answers[ans.id] if e == nil { - ans.prepareSendReturn(c, rl) + ent.prepareSendReturn(c, rl) } else { - ans.prepareSendException(c, rl, e) + ent.prepareSendException(c, rl, e) } }) } // Return implements capnp.Returner.Return -func (ans *answer) Return() { +func (ans *ansReturner) Return() { rl := &releaseList{} defer rl.Release() - defer ans.pcalls.Wait() + pcallsWait := func() {} var err error ans.c.withLocked(func(c *lockedConn) { - if ans.err == nil { - err = ans.completeSendReturn(c, rl) + ent := c.lk.answers[ans.id] + pcallsWait = ent.pcalls.Wait + + if ent.err == nil { + err = ent.completeSendReturn(c, rl) } else { - ans.completeSendException(c, rl) + ent.completeSendException(c, rl) } }) + defer pcallsWait() ans.c.tasks.Done() // added by handleCall if err == nil { @@ -221,7 +234,7 @@ func (ans *answer) Return() { } } -func (ans *answer) ReleaseResults() { +func (ans *ansReturner) ReleaseResults() { if ans.results.IsValid() { ans.msgReleaser.Decr() } @@ -237,32 +250,32 @@ func (ans *answer) ReleaseResults() { // the lock, per usual). // // sendReturn MUST NOT be called if sendException was previously called. -func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error { +func (ans *ansent) sendReturn(c *lockedConn, rl *releaseList) error { ans.prepareSendReturn(c, rl) return ans.completeSendReturn(c, rl) } -func (ans *answer) prepareSendReturn(c *lockedConn, rl *releaseList) { - c.assertIs(ans.c) +func (ans *ansent) prepareSendReturn(c *lockedConn, rl *releaseList) { + c.assertIs(ans.returner.c) var err error - ans.exportRefs, err = c.fillPayloadCapTable(ans.results) + ans.exportRefs, err = c.fillPayloadCapTable(ans.returner.results) if err != nil { - ans.c.er.ReportError(rpcerr.Annotate(err, "send return")) + c.er.ReportError(rpcerr.Annotate(err, "send return")) } // Continue. Don't fail to send return if cap table isn't fully filled. select { - case <-ans.c.bgctx.Done(): + case <-c.bgctx.Done(): // We're not going to send the message after all, so don't forget to release it. - ans.msgReleaser.Decr() + ans.returner.msgReleaser.Decr() ans.sendMsg = nil default: } } -func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error { - c.assertIs(ans.c) +func (ans *ansent) completeSendReturn(c *lockedConn, rl *releaseList) error { + c.assertIs(ans.returner.c) ans.pcall = nil ans.flags |= resultsReady @@ -277,7 +290,7 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error { // the cancel variant of return. ans.promise.Reject(rpcerr.Failed(errors.New("received finish before return"))) } else { - ans.promise.Resolve(ans.results.Content()) + ans.promise.Resolve(ans.returner.results.Content()) } ans.promise = nil } @@ -295,31 +308,31 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error { // // The caller MUST be holding onto ans.c.lk. sendException MUST NOT // be called if sendReturn was previously called. -func (ans *answer) sendException(c *lockedConn, rl *releaseList, ex error) { +func (ans *ansent) sendException(c *lockedConn, rl *releaseList, ex error) { ans.prepareSendException(c, rl, ex) ans.completeSendException(c, rl) } -func (ans *answer) prepareSendException(c *lockedConn, rl *releaseList, ex error) { - c.assertIs(ans.c) +func (ans *ansent) prepareSendException(c *lockedConn, rl *releaseList, ex error) { + c.assertIs(ans.returner.c) ans.err = ex select { - case <-ans.c.bgctx.Done(): + case <-c.bgctx.Done(): default: // Send exception. - if e, err := ans.ret.NewException(); err != nil { - ans.c.er.ReportError(exc.WrapError("send exception", err)) + if e, err := ans.returner.ret.NewException(); err != nil { + c.er.ReportError(exc.WrapError("send exception", err)) ans.sendMsg = nil } else if err := e.MarshalError(ex); err != nil { - ans.c.er.ReportError(exc.WrapError("send exception", err)) + c.er.ReportError(exc.WrapError("send exception", err)) ans.sendMsg = nil } } } -func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) { - c.assertIs(ans.c) +func (ans *ansent) completeSendException(c *lockedConn, rl *releaseList) { + c.assertIs(ans.returner.c) ex := ans.err ans.pcall = nil @@ -345,11 +358,11 @@ func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) { // The caller must be holding onto ans.c.lk. // // shutdown has its own strategy for cleaning up an answer. -func (ans *answer) destroy(c *lockedConn, rl *releaseList) error { - c.assertIs(ans.c) +func (ans *ansent) destroy(c *lockedConn, rl *releaseList) error { + c.assertIs(ans.returner.c) - rl.Add(ans.msgReleaser.Decr) - delete(c.lk.answers, ans.id) + rl.Add(ans.returner.msgReleaser.Decr) + delete(c.lk.answers, ans.returner.id) if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 { return nil diff --git a/rpc/rpc.go b/rpc/rpc.go index 14247d16..2501502d 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -122,7 +122,7 @@ type Conn struct { // Tables questions []*question questionID idgen - answers map[answerID]*answer + answers map[answerID]*ansent exports []*expent exportID idgen imports map[importID]*impent @@ -234,7 +234,7 @@ func NewConn(t Transport, opts *Options) *Conn { c.sendRx = &sender.Rx c.lk.sendTx = &sender.Tx - c.lk.answers = make(map[answerID]*answer) + c.lk.answers = make(map[answerID]*ansent) c.lk.imports = make(map[importID]*impent) if opts != nil { @@ -442,8 +442,8 @@ func (c *Conn) shutdown(abortErr error) (err error) { // Called by 'shutdown'. Callers MUST hold c.lk. func (c *lockedConn) cancelTasks() { for _, a := range c.lk.answers { - if a != nil && a.cancel != nil { - a.cancel() + if a != nil && a.returner.cancel != nil { + a.returner.cancel() } } } @@ -507,10 +507,10 @@ func (c *lockedConn) liftEmbargoes(rl *releaseList, embargoes []*embargo) { } } -func (c *lockedConn) releaseAnswers(rl *releaseList, answers map[answerID]*answer) { +func (c *lockedConn) releaseAnswers(rl *releaseList, answers map[answerID]*ansent) { for _, a := range answers { - if a != nil && a.msgReleaser != nil { - rl.Add(a.msgReleaser.Decr) + if a != nil && a.returner.msgReleaser != nil { + rl.Add(a.returner.msgReleaser.Decr) } } } @@ -717,37 +717,39 @@ func (c *Conn) handleBootstrap(in transport.IncomingMessage) error { rl := &releaseList{} defer rl.Release() - ans := answer{ - c: c, - id: answerID(bootstrap.QuestionId()), + ans := ansent{ + returner: ansReturner{ + c: c, + id: answerID(bootstrap.QuestionId()), + }, } - ans.ret, ans.sendMsg, ans.msgReleaser, err = c.newReturn() + ans.returner.ret, ans.sendMsg, ans.returner.msgReleaser, err = c.newReturn() if err == nil { - ans.ret.SetAnswerId(uint32(ans.id)) - ans.ret.SetReleaseParamCaps(false) + ans.returner.ret.SetAnswerId(uint32(ans.returner.id)) + ans.returner.ret.SetReleaseParamCaps(false) } c.withLocked(func(c *lockedConn) { - if c.lk.answers[ans.id] != nil { - rl.Add(ans.msgReleaser.Decr) - err = rpcerr.Failed(errors.New("incoming bootstrap: answer ID " + str.Utod(ans.id) + " reused")) + if c.lk.answers[ans.returner.id] != nil { + rl.Add(ans.returner.msgReleaser.Decr) + err = rpcerr.Failed(errors.New("incoming bootstrap: answer ID " + str.Utod(ans.returner.id) + " reused")) return } if err != nil { err = rpcerr.Annotate(err, "incoming bootstrap") - c.lk.answers[ans.id] = errorAnswer((*Conn)(c), ans.id, err) + c.lk.answers[ans.returner.id] = errorAnswer((*Conn)(c), ans.returner.id, err) c.er.ReportError(err) return } - c.lk.answers[ans.id] = &ans + c.lk.answers[ans.returner.id] = &ans if !c.bootstrap.IsValid() { ans.sendException(c, rl, exc.New(exc.Failed, "", "vat does not expose a public/bootstrap interface")) return } - if err := ans.setBootstrap(c.bootstrap.AddRef()); err != nil { + if err := ans.returner.setBootstrap(c.bootstrap.AddRef()); err != nil { ans.sendException(c, rl, err) return } @@ -836,12 +838,14 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err ret.SetReleaseParamCaps(false) // Find target and start call. - ans := &answer{ - c: c, - id: id, - ret: ret, - sendMsg: send, - msgReleaser: retReleaser, + ans := &ansent{ + returner: ansReturner{ + c: c, + id: id, + ret: ret, + msgReleaser: retReleaser, + }, + sendMsg: send, } return withLockedConn1(c, func(c *lockedConn) error { c.lk.answers[id] = ans @@ -859,16 +863,16 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err Args: p.args, Method: p.method, ReleaseArgs: util.Idempotent(in.Release), - Returner: ans, + Returner: &ans.returner, } switch p.target.which { case rpccp.MessageTarget_Which_importedCap: ent := c.findExport(p.target.importedCap) if ent == nil { - ans.ret = rpccp.Return{} + ans.returner.ret = rpccp.Return{} ans.sendMsg = nil - ans.msgReleaser = nil + ans.returner.msgReleaser = nil rl.Add(func() { retReleaser.Decr() in.Release() @@ -877,7 +881,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err } c.tasks.Add(1) // will be finished by answer.Return var callCtx context.Context - callCtx, ans.cancel = context.WithCancel(c.bgctx) + callCtx, ans.returner.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() ans.setPipelineCaller(c, p.method, pcall) rl.Add(func() { @@ -887,9 +891,9 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err case rpccp.MessageTarget_Which_promisedAnswer: tgtAns := c.lk.answers[p.target.promisedAnswer] if tgtAns == nil || tgtAns.flags.Contains(finishReceived) { - ans.ret = rpccp.Return{} + ans.returner.ret = rpccp.Return{} ans.sendMsg = nil - ans.msgReleaser = nil + ans.returner.msgReleaser = nil rl.Add(func() { retReleaser.Decr() in.Release() @@ -910,7 +914,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err // received finish yet (it would have been deleted from the // answers table), and it can't receive a finish because this is // happening on the receive goroutine. - content, err := tgtAns.results.Content() + content, err := tgtAns.returner.results.Content() if err != nil { err = rpcerr.WrapFailed("incoming call: read results from target answer", err) ans.sendException(c, rl, err) @@ -930,14 +934,14 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err switch { case sub.IsValid() && !iface.IsValid(): tgt = capnp.ErrorClient(rpcerr.Failed(ErrNotACapability)) - case !iface.IsValid() || int64(iface.Capability()) >= int64(len(tgtAns.results.Message().CapTable)): + case !iface.IsValid() || int64(iface.Capability()) >= int64(len(tgtAns.returner.results.Message().CapTable)): tgt = capnp.Client{} default: - tgt = tgtAns.results.Message().CapTable[iface.Capability()] + tgt = tgtAns.returner.results.Message().CapTable[iface.Capability()] } c.tasks.Add(1) // will be finished by answer.Return var callCtx context.Context - callCtx, ans.cancel = context.WithCancel(c.bgctx) + callCtx, ans.returner.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() ans.setPipelineCaller(c, p.method, pcall) rl.Add(func() { @@ -947,7 +951,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err // Results not ready, use pipeline caller. tgtAns.pcalls.Add(1) // will be finished by answer.Return var callCtx context.Context - callCtx, ans.cancel = context.WithCancel(c.bgctx) + callCtx, ans.returner.cancel = context.WithCancel(c.bgctx) tgt := tgtAns.pcall c.tasks.Add(1) // will be finished by answer.Return pcall := newPromisedPipelineCaller() @@ -1319,8 +1323,8 @@ func (c *Conn) handleFinish(ctx context.Context, in transport.IncomingMessage) e if releaseResultCaps { ans.flags |= releaseResultCapsFlag } - if ans.cancel != nil { - ans.cancel() + if ans.returner.cancel != nil { + ans.returner.cancel() } if !ans.flags.Contains(returnSent) { return nil @@ -1417,7 +1421,7 @@ func (c *lockedConn) recvCap(d rpccp.CapDescriptor) (capnp.Client, error) { } // Helper for lockedConn.recvCap(); handles the receiverAnswer case. -func (c *lockedConn) recvCapReceiverAnswer(ans *answer, transform []capnp.PipelineOp) capnp.Client { +func (c *lockedConn) recvCapReceiverAnswer(ans *ansent, transform []capnp.PipelineOp) capnp.Client { if ans.promise != nil { // Still unresolved. future := ans.promise.Answer().Future() @@ -1431,7 +1435,7 @@ func (c *lockedConn) recvCapReceiverAnswer(ans *answer, transform []capnp.Pipeli return capnp.ErrorClient(ans.err) } - ptr, err := ans.results.Content() + ptr, err := ans.returner.results.Content() if err != nil { return capnp.ErrorClient(rpcerr.WrapFailed("except.Failed reading results", err)) } @@ -1641,7 +1645,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag } var content capnp.Ptr - if content, err = ans.results.Content(); err != nil { + if content, err = ans.returner.results.Content(); err != nil { err = rpcerr.Failed(errors.New( "incoming disembargo: read answer ID " + str.Utod(tgt.promisedAnswer) + ": " + err.Error(), @@ -1659,7 +1663,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag } iface := ptr.Interface() - if !iface.IsValid() || int64(iface.Capability()) >= int64(len(ans.results.Message().CapTable)) { + if !iface.IsValid() || int64(iface.Capability()) >= int64(len(ans.returner.results.Message().CapTable)) { err = rpcerr.Failed(errors.New( "incoming disembargo: sender loopback requested on a capability that is not an import", )) From 777d0b525a967270aaa2982ddd5b1cdafb7dcd62 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Fri, 31 Mar 2023 20:33:48 -0400 Subject: [PATCH 2/4] Move ansReturner.cancel into ansent. We never actually invoke it directly from ansReturner. --- rpc/answer.go | 8 ++++---- rpc/rpc.go | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/rpc/answer.go b/rpc/answer.go index c57ce0de..02a94350 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -45,6 +45,10 @@ type ansent struct { // the argument should be the same as ans.c sendMsg func(*lockedConn) + // cancel cancels the Context used in the received method call. + // May be nil. + cancel context.CancelFunc + // Unlike other fields in this struct, it is ok to hand out pointers // to this that can be used while not holding the connection lock. returner ansReturner @@ -57,10 +61,6 @@ type ansReturner struct { c *Conn id answerID - // cancel cancels the Context used in the received method call. - // May be nil. - cancel context.CancelFunc - // ret is the outgoing Return struct. ret is valid iff there was no // error creating the message. If ret is invalid, then this answer // entry is a placeholder until the remote vat cancels the call. diff --git a/rpc/rpc.go b/rpc/rpc.go index 2501502d..35a0c648 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -442,8 +442,8 @@ func (c *Conn) shutdown(abortErr error) (err error) { // Called by 'shutdown'. Callers MUST hold c.lk. func (c *lockedConn) cancelTasks() { for _, a := range c.lk.answers { - if a != nil && a.returner.cancel != nil { - a.returner.cancel() + if a != nil && a.cancel != nil { + a.cancel() } } } @@ -881,7 +881,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err } c.tasks.Add(1) // will be finished by answer.Return var callCtx context.Context - callCtx, ans.returner.cancel = context.WithCancel(c.bgctx) + callCtx, ans.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() ans.setPipelineCaller(c, p.method, pcall) rl.Add(func() { @@ -941,7 +941,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err } c.tasks.Add(1) // will be finished by answer.Return var callCtx context.Context - callCtx, ans.returner.cancel = context.WithCancel(c.bgctx) + callCtx, ans.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() ans.setPipelineCaller(c, p.method, pcall) rl.Add(func() { @@ -951,7 +951,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err // Results not ready, use pipeline caller. tgtAns.pcalls.Add(1) // will be finished by answer.Return var callCtx context.Context - callCtx, ans.returner.cancel = context.WithCancel(c.bgctx) + callCtx, ans.cancel = context.WithCancel(c.bgctx) tgt := tgtAns.pcall c.tasks.Add(1) // will be finished by answer.Return pcall := newPromisedPipelineCaller() @@ -1323,8 +1323,8 @@ func (c *Conn) handleFinish(ctx context.Context, in transport.IncomingMessage) e if releaseResultCaps { ans.flags |= releaseResultCapsFlag } - if ans.returner.cancel != nil { - ans.returner.cancel() + if ans.cancel != nil { + ans.cancel() } if !ans.flags.Contains(returnSent) { return nil From 9161f0e8d63c1b09060c2228037c6d5f6e04e786 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Fri, 31 Mar 2023 20:34:58 -0400 Subject: [PATCH 3/4] Fix invariant violation. The comments for msgReleaser suggest we can't hold the lock while doing this, so add it to the releaseList instead. I'd like to make this harder to mess up in the future. --- rpc/answer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rpc/answer.go b/rpc/answer.go index 02a94350..e5fc37a5 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -268,7 +268,7 @@ func (ans *ansent) prepareSendReturn(c *lockedConn, rl *releaseList) { select { case <-c.bgctx.Done(): // We're not going to send the message after all, so don't forget to release it. - ans.returner.msgReleaser.Decr() + rl.Add(ans.returner.msgReleaser.Decr) ans.sendMsg = nil default: } From b9a9bfb105ba8bce6b14718f516ef970593c8e41 Mon Sep 17 00:00:00 2001 From: Ian Denhardt Date: Fri, 31 Mar 2023 20:44:55 -0400 Subject: [PATCH 4/4] Get rid of lockedConn.assertIs(). This is obsolete, now that we have to go through c.lk to get to the entry. This patch also drops all of the parameters that were transitively passed to assertIs(). We do still sometimes want to actually *use* the connection though, so we add a method to ansent to fetch it, which localizes the cast. --- rpc/answer.go | 72 ++++++++++++++++++++++++--------------------------- rpc/rpc.go | 29 ++++++++------------- 2 files changed, 45 insertions(+), 56 deletions(-) diff --git a/rpc/answer.go b/rpc/answer.go index e5fc37a5..99a74b37 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -41,9 +41,8 @@ type ansent struct { // flags. err error - // sendMsg sends the return message. The caller MUST hold ans.c.lk; - // the argument should be the same as ans.c - sendMsg func(*lockedConn) + // sendMsg sends the return message. + sendMsg func() // cancel cancels the Context used in the received method call. // May be nil. @@ -54,6 +53,13 @@ type ansent struct { returner ansReturner } +// Returns the already-locked connection to which this entry belongs. +// Since ansents are only supposed to be accessed through c.lk, it is +// assumed that the caller already holds the lock. +func (ans *ansent) lockedConn() *lockedConn { + return (*lockedConn)(ans.returner.c) +} + // ansReturner is the implementation of capnp.Returner that is used when // handling an incoming call on a local capability. type ansReturner struct { @@ -106,7 +112,7 @@ func errorAnswer(c *Conn, id answerID, err error) *ansent { // newReturn creates a new Return message. The returned Releaser will release the message when // all references to it are dropped; the caller is responsible for one reference. This will not // happen before the message is sent, as the returned send function retains a reference. -func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Releaser, _ error) { +func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ error) { outMsg, err := c.transport.NewMessage() if err != nil { return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err) @@ -123,9 +129,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel // 'releaseMsg' is called. releaser := rc.NewReleaser(2, outMsg.Release) - unlockedConn := c - return ret, func(c *lockedConn) { - c.assertIs(unlockedConn) + return ret, func() { c.lk.sendTx.Send(asyncSend{ send: outMsg.Send, release: releaser.Decr, @@ -142,9 +146,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel // already returned. The caller MUST hold ans.c.lk. // // This also sets ans.promise to a new promise, wrapping pcall. -func (ans *ansent) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp.PipelineCaller) { - c.assertIs(ans.returner.c) - +func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) { if !ans.flags.Contains(resultsReady) { ans.pcall = pcall ans.promise = capnp.NewPromise(m, pcall) @@ -198,9 +200,9 @@ func (ans *ansReturner) PrepareReturn(e error) { ans.c.withLocked(func(c *lockedConn) { ent := c.lk.answers[ans.id] if e == nil { - ent.prepareSendReturn(c, rl) + ent.prepareSendReturn(rl) } else { - ent.prepareSendException(c, rl, e) + ent.prepareSendException(rl, e) } }) } @@ -217,9 +219,9 @@ func (ans *ansReturner) Return() { pcallsWait = ent.pcalls.Wait if ent.err == nil { - err = ent.completeSendReturn(c, rl) + err = ent.completeSendReturn(rl) } else { - ent.completeSendException(c, rl) + ent.completeSendException(rl) } }) defer pcallsWait() @@ -250,15 +252,14 @@ func (ans *ansReturner) ReleaseResults() { // the lock, per usual). // // sendReturn MUST NOT be called if sendException was previously called. -func (ans *ansent) sendReturn(c *lockedConn, rl *releaseList) error { - ans.prepareSendReturn(c, rl) - return ans.completeSendReturn(c, rl) +func (ans *ansent) sendReturn(rl *releaseList) error { + ans.prepareSendReturn(rl) + return ans.completeSendReturn(rl) } -func (ans *ansent) prepareSendReturn(c *lockedConn, rl *releaseList) { - c.assertIs(ans.returner.c) - +func (ans *ansent) prepareSendReturn(rl *releaseList) { var err error + c := ans.lockedConn() ans.exportRefs, err = c.fillPayloadCapTable(ans.returner.results) if err != nil { c.er.ReportError(rpcerr.Annotate(err, "send return")) @@ -274,9 +275,7 @@ func (ans *ansent) prepareSendReturn(c *lockedConn, rl *releaseList) { } } -func (ans *ansent) completeSendReturn(c *lockedConn, rl *releaseList) error { - c.assertIs(ans.returner.c) - +func (ans *ansent) completeSendReturn(rl *releaseList) error { ans.pcall = nil ans.flags |= resultsReady @@ -294,12 +293,12 @@ func (ans *ansent) completeSendReturn(c *lockedConn, rl *releaseList) error { } ans.promise = nil } - ans.sendMsg(c) + ans.sendMsg() } ans.flags |= returnSent if fin { - return ans.destroy(c, rl) + return ans.destroy(rl) } return nil } @@ -308,15 +307,15 @@ func (ans *ansent) completeSendReturn(c *lockedConn, rl *releaseList) error { // // The caller MUST be holding onto ans.c.lk. sendException MUST NOT // be called if sendReturn was previously called. -func (ans *ansent) sendException(c *lockedConn, rl *releaseList, ex error) { - ans.prepareSendException(c, rl, ex) - ans.completeSendException(c, rl) +func (ans *ansent) sendException(rl *releaseList, ex error) { + ans.prepareSendException(rl, ex) + ans.completeSendException(rl) } -func (ans *ansent) prepareSendException(c *lockedConn, rl *releaseList, ex error) { - c.assertIs(ans.returner.c) +func (ans *ansent) prepareSendException(rl *releaseList, ex error) { ans.err = ex + c := ans.lockedConn() select { case <-c.bgctx.Done(): default: @@ -331,9 +330,7 @@ func (ans *ansent) prepareSendException(c *lockedConn, rl *releaseList, ex error } } -func (ans *ansent) completeSendException(c *lockedConn, rl *releaseList) { - c.assertIs(ans.returner.c) - +func (ans *ansent) completeSendException(rl *releaseList) { ex := ans.err ans.pcall = nil ans.flags |= resultsReady @@ -343,13 +340,13 @@ func (ans *ansent) completeSendException(c *lockedConn, rl *releaseList) { ans.promise = nil } if ans.sendMsg != nil { - ans.sendMsg(c) + ans.sendMsg() } ans.flags |= returnSent if ans.flags.Contains(finishReceived) { // destroy will never return an error because sendException does // create any exports. - _ = ans.destroy(c, rl) + _ = ans.destroy(rl) } } @@ -358,10 +355,9 @@ func (ans *ansent) completeSendException(c *lockedConn, rl *releaseList) { // The caller must be holding onto ans.c.lk. // // shutdown has its own strategy for cleaning up an answer. -func (ans *ansent) destroy(c *lockedConn, rl *releaseList) error { - c.assertIs(ans.returner.c) - +func (ans *ansent) destroy(rl *releaseList) error { rl.Add(ans.returner.msgReleaser.Decr) + c := ans.lockedConn() delete(c.lk.answers, ans.returner.id) if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 { return nil diff --git a/rpc/rpc.go b/rpc/rpc.go index 35a0c648..708317d4 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -172,13 +172,6 @@ func withLockedConn2[A, B any](c *Conn, f func(*lockedConn) (A, B)) (a A, b B) { return } -// assertIs panics if the receiver and the argument are not the same connection. -func (lc *lockedConn) assertIs(c *Conn) { - if (*Conn)(lc) != c { - panic("lockedConn.assertIs: different connections.") - } -} - // Options specifies optional parameters for creating a Conn. type Options struct { // BootstrapClient is the capability that will be returned to the @@ -746,14 +739,14 @@ func (c *Conn) handleBootstrap(in transport.IncomingMessage) error { c.lk.answers[ans.returner.id] = &ans if !c.bootstrap.IsValid() { - ans.sendException(c, rl, exc.New(exc.Failed, "", "vat does not expose a public/bootstrap interface")) + ans.sendException(rl, exc.New(exc.Failed, "", "vat does not expose a public/bootstrap interface")) return } if err := ans.returner.setBootstrap(c.bootstrap.AddRef()); err != nil { - ans.sendException(c, rl, err) + ans.sendException(rl, err) return } - err = ans.sendReturn(c, rl) + err = ans.sendReturn(rl) if err != nil { // Answer cannot possibly encounter a Finish, since we still // haven't returned to receive(). @@ -851,7 +844,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err c.lk.answers[id] = ans if parseErr != nil { parseErr = rpcerr.Annotate(parseErr, "incoming call") - ans.sendException(c, rl, parseErr) + ans.sendException(rl, parseErr) rl.Add(func() { c.er.ReportError(parseErr) in.Release() @@ -883,7 +876,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err var callCtx context.Context callCtx, ans.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() - ans.setPipelineCaller(c, p.method, pcall) + ans.setPipelineCaller(p.method, pcall) rl.Add(func() { pcall.resolve(ent.client.RecvCall(callCtx, recv)) }) @@ -906,7 +899,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err } if tgtAns.flags.Contains(resultsReady) { if tgtAns.err != nil { - ans.sendException(c, rl, tgtAns.err) + ans.sendException(rl, tgtAns.err) rl.Add(in.Release) return nil } @@ -917,7 +910,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err content, err := tgtAns.returner.results.Content() if err != nil { err = rpcerr.WrapFailed("incoming call: read results from target answer", err) - ans.sendException(c, rl, err) + ans.sendException(rl, err) rl.Add(in.Release) c.er.ReportError(err) return nil @@ -925,7 +918,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err sub, err := capnp.Transform(content, p.target.transform) if err != nil { // Not reporting, as this is the caller's fault. - ans.sendException(c, rl, err) + ans.sendException(rl, err) rl.Add(in.Release) return nil } @@ -943,7 +936,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err var callCtx context.Context callCtx, ans.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() - ans.setPipelineCaller(c, p.method, pcall) + ans.setPipelineCaller(p.method, pcall) rl.Add(func() { pcall.resolve(tgt.RecvCall(callCtx, recv)) }) @@ -955,7 +948,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err tgt := tgtAns.pcall c.tasks.Add(1) // will be finished by answer.Return pcall := newPromisedPipelineCaller() - ans.setPipelineCaller(c, p.method, pcall) + ans.setPipelineCaller(p.method, pcall) rl.Add(func() { pcall.resolve(tgt.PipelineRecv(callCtx, p.target.transform, recv)) tgtAns.pcalls.Done() @@ -1331,7 +1324,7 @@ func (c *Conn) handleFinish(ctx context.Context, in transport.IncomingMessage) e } // Return sent and finish received: time to destroy answer. - err := ans.destroy(c, rl) + err := ans.destroy(rl) if err != nil { return rpcerr.Annotate(err, "incoming finish: release result caps") }