diff --git a/rpc/answer.go b/rpc/answer.go index f1390877..99a74b37 100644 --- a/rpc/answer.go +++ b/rpc/answer.go @@ -14,36 +14,8 @@ import ( // An answerID is an index into the answers table. type answerID uint32 -// answer is an entry in a Conn's answer table. -type answer struct { - // c and id must be set before any answer methods are called. - c *Conn - id answerID - - // cancel cancels the Context used in the received method call. - // May be nil. - cancel context.CancelFunc - - // ret is the outgoing Return struct. ret is valid iff there was no - // error creating the message. If ret is invalid, then this answer - // entry is a placeholder until the remote vat cancels the call. - ret rpccp.Return - - // sendMsg sends the return message. The caller MUST hold ans.c.lk; - // the argument should be the same as ans.c - sendMsg func(*lockedConn) - - // msgReleaser releases the return message when its refcount hits zero. - // The caller MUST NOT hold ans.c.lk. - msgReleaser *rc.Releaser - - // results is the memoized answer to ret.Results(). - // Set by AllocResults and setBootstrap, but contents can only be read - // if flags has resultsReady but not finishReceived set. - results rpccp.Payload - - // All fields below are protected by s.c.mu. - +// ansent is an entry in a Conn's answer table. +type ansent struct { // flags is a bitmask of events that have occurred in an answer's // lifetime. flags answerFlags @@ -68,6 +40,46 @@ type answer struct { // the Return message. Can only be read after resultsReady is set in // flags. err error + + // sendMsg sends the return message. + sendMsg func() + + // cancel cancels the Context used in the received method call. + // May be nil. + cancel context.CancelFunc + + // Unlike other fields in this struct, it is ok to hand out pointers + // to this that can be used while not holding the connection lock. + returner ansReturner +} + +// Returns the already-locked connection to which this entry belongs. +// Since ansents are only supposed to be accessed through c.lk, it is +// assumed that the caller already holds the lock. +func (ans *ansent) lockedConn() *lockedConn { + return (*lockedConn)(ans.returner.c) +} + +// ansReturner is the implementation of capnp.Returner that is used when +// handling an incoming call on a local capability. +type ansReturner struct { + // c and id must be set before any answer methods are called. + c *Conn + id answerID + + // ret is the outgoing Return struct. ret is valid iff there was no + // error creating the message. If ret is invalid, then this answer + // entry is a placeholder until the remote vat cancels the call. + ret rpccp.Return + + // msgReleaser releases the return message when its refcount hits zero. + // The caller MUST NOT hold ans.c.lk. + msgReleaser *rc.Releaser + + // results is the memoized answer to ret.Results(). + // Set by AllocResults and setBootstrap, but contents can only be read + // if flags has resultsReady but not finishReceived set. + results rpccp.Payload } type answerFlags uint8 @@ -85,11 +97,13 @@ func (flags answerFlags) Contains(flag answerFlags) bool { return flags&flag != 0 } -// errorAnswer returns a placeholder answer with an error result already set. -func errorAnswer(c *Conn, id answerID, err error) *answer { - return &answer{ - c: c, - id: id, +// errorAnswer returns a placeholder answer entry with an error result already set. +func errorAnswer(c *Conn, id answerID, err error) *ansent { + return &ansent{ + returner: ansReturner{ + c: c, + id: id, + }, err: err, flags: resultsReady | returnSent, } @@ -98,7 +112,7 @@ func errorAnswer(c *Conn, id answerID, err error) *answer { // newReturn creates a new Return message. The returned Releaser will release the message when // all references to it are dropped; the caller is responsible for one reference. This will not // happen before the message is sent, as the returned send function retains a reference. -func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Releaser, _ error) { +func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ error) { outMsg, err := c.transport.NewMessage() if err != nil { return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err) @@ -115,9 +129,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel // 'releaseMsg' is called. releaser := rc.NewReleaser(2, outMsg.Release) - unlockedConn := c - return ret, func(c *lockedConn) { - c.assertIs(unlockedConn) + return ret, func() { c.lk.sendTx.Send(asyncSend{ send: outMsg.Send, release: releaser.Decr, @@ -134,9 +146,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel // already returned. The caller MUST hold ans.c.lk. // // This also sets ans.promise to a new promise, wrapping pcall. -func (ans *answer) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp.PipelineCaller) { - c.assertIs(ans.c) - +func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) { if !ans.flags.Contains(resultsReady) { ans.pcall = pcall ans.promise = capnp.NewPromise(m, pcall) @@ -144,7 +154,7 @@ func (ans *answer) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp. } // AllocResults allocates the results struct. -func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) { +func (ans *ansReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) { var err error ans.results, err = ans.ret.NewResults() if err != nil { @@ -163,7 +173,7 @@ func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) { // setBootstrap sets the results to an interface pointer, stealing the // reference. -func (ans *answer) setBootstrap(c capnp.Client) error { +func (ans *ansReturner) setBootstrap(c capnp.Client) error { if ans.ret.HasResults() || len(ans.ret.Message().CapTable) > 0 { panic("setBootstrap called after creating results") } @@ -183,33 +193,38 @@ func (ans *answer) setBootstrap(c capnp.Client) error { } // PrepareReturn implements capnp.Returner.PrepareReturn -func (ans *answer) PrepareReturn(e error) { +func (ans *ansReturner) PrepareReturn(e error) { rl := &releaseList{} defer rl.Release() ans.c.withLocked(func(c *lockedConn) { + ent := c.lk.answers[ans.id] if e == nil { - ans.prepareSendReturn(c, rl) + ent.prepareSendReturn(rl) } else { - ans.prepareSendException(c, rl, e) + ent.prepareSendException(rl, e) } }) } // Return implements capnp.Returner.Return -func (ans *answer) Return() { +func (ans *ansReturner) Return() { rl := &releaseList{} defer rl.Release() - defer ans.pcalls.Wait() + pcallsWait := func() {} var err error ans.c.withLocked(func(c *lockedConn) { - if ans.err == nil { - err = ans.completeSendReturn(c, rl) + ent := c.lk.answers[ans.id] + pcallsWait = ent.pcalls.Wait + + if ent.err == nil { + err = ent.completeSendReturn(rl) } else { - ans.completeSendException(c, rl) + ent.completeSendException(rl) } }) + defer pcallsWait() ans.c.tasks.Done() // added by handleCall if err == nil { @@ -221,7 +236,7 @@ func (ans *answer) Return() { } } -func (ans *answer) ReleaseResults() { +func (ans *ansReturner) ReleaseResults() { if ans.results.IsValid() { ans.msgReleaser.Decr() } @@ -237,33 +252,30 @@ func (ans *answer) ReleaseResults() { // the lock, per usual). // // sendReturn MUST NOT be called if sendException was previously called. -func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error { - ans.prepareSendReturn(c, rl) - return ans.completeSendReturn(c, rl) +func (ans *ansent) sendReturn(rl *releaseList) error { + ans.prepareSendReturn(rl) + return ans.completeSendReturn(rl) } -func (ans *answer) prepareSendReturn(c *lockedConn, rl *releaseList) { - c.assertIs(ans.c) - +func (ans *ansent) prepareSendReturn(rl *releaseList) { var err error - ans.exportRefs, err = c.fillPayloadCapTable(ans.results) + c := ans.lockedConn() + ans.exportRefs, err = c.fillPayloadCapTable(ans.returner.results) if err != nil { - ans.c.er.ReportError(rpcerr.Annotate(err, "send return")) + c.er.ReportError(rpcerr.Annotate(err, "send return")) } // Continue. Don't fail to send return if cap table isn't fully filled. select { - case <-ans.c.bgctx.Done(): + case <-c.bgctx.Done(): // We're not going to send the message after all, so don't forget to release it. - ans.msgReleaser.Decr() + rl.Add(ans.returner.msgReleaser.Decr) ans.sendMsg = nil default: } } -func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error { - c.assertIs(ans.c) - +func (ans *ansent) completeSendReturn(rl *releaseList) error { ans.pcall = nil ans.flags |= resultsReady @@ -277,16 +289,16 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error { // the cancel variant of return. ans.promise.Reject(rpcerr.Failed(errors.New("received finish before return"))) } else { - ans.promise.Resolve(ans.results.Content()) + ans.promise.Resolve(ans.returner.results.Content()) } ans.promise = nil } - ans.sendMsg(c) + ans.sendMsg() } ans.flags |= returnSent if fin { - return ans.destroy(c, rl) + return ans.destroy(rl) } return nil } @@ -295,32 +307,30 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error { // // The caller MUST be holding onto ans.c.lk. sendException MUST NOT // be called if sendReturn was previously called. -func (ans *answer) sendException(c *lockedConn, rl *releaseList, ex error) { - ans.prepareSendException(c, rl, ex) - ans.completeSendException(c, rl) +func (ans *ansent) sendException(rl *releaseList, ex error) { + ans.prepareSendException(rl, ex) + ans.completeSendException(rl) } -func (ans *answer) prepareSendException(c *lockedConn, rl *releaseList, ex error) { - c.assertIs(ans.c) +func (ans *ansent) prepareSendException(rl *releaseList, ex error) { ans.err = ex + c := ans.lockedConn() select { - case <-ans.c.bgctx.Done(): + case <-c.bgctx.Done(): default: // Send exception. - if e, err := ans.ret.NewException(); err != nil { - ans.c.er.ReportError(exc.WrapError("send exception", err)) + if e, err := ans.returner.ret.NewException(); err != nil { + c.er.ReportError(exc.WrapError("send exception", err)) ans.sendMsg = nil } else if err := e.MarshalError(ex); err != nil { - ans.c.er.ReportError(exc.WrapError("send exception", err)) + c.er.ReportError(exc.WrapError("send exception", err)) ans.sendMsg = nil } } } -func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) { - c.assertIs(ans.c) - +func (ans *ansent) completeSendException(rl *releaseList) { ex := ans.err ans.pcall = nil ans.flags |= resultsReady @@ -330,13 +340,13 @@ func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) { ans.promise = nil } if ans.sendMsg != nil { - ans.sendMsg(c) + ans.sendMsg() } ans.flags |= returnSent if ans.flags.Contains(finishReceived) { // destroy will never return an error because sendException does // create any exports. - _ = ans.destroy(c, rl) + _ = ans.destroy(rl) } } @@ -345,11 +355,10 @@ func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) { // The caller must be holding onto ans.c.lk. // // shutdown has its own strategy for cleaning up an answer. -func (ans *answer) destroy(c *lockedConn, rl *releaseList) error { - c.assertIs(ans.c) - - rl.Add(ans.msgReleaser.Decr) - delete(c.lk.answers, ans.id) +func (ans *ansent) destroy(rl *releaseList) error { + rl.Add(ans.returner.msgReleaser.Decr) + c := ans.lockedConn() + delete(c.lk.answers, ans.returner.id) if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 { return nil diff --git a/rpc/rpc.go b/rpc/rpc.go index 14247d16..708317d4 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -122,7 +122,7 @@ type Conn struct { // Tables questions []*question questionID idgen - answers map[answerID]*answer + answers map[answerID]*ansent exports []*expent exportID idgen imports map[importID]*impent @@ -172,13 +172,6 @@ func withLockedConn2[A, B any](c *Conn, f func(*lockedConn) (A, B)) (a A, b B) { return } -// assertIs panics if the receiver and the argument are not the same connection. -func (lc *lockedConn) assertIs(c *Conn) { - if (*Conn)(lc) != c { - panic("lockedConn.assertIs: different connections.") - } -} - // Options specifies optional parameters for creating a Conn. type Options struct { // BootstrapClient is the capability that will be returned to the @@ -234,7 +227,7 @@ func NewConn(t Transport, opts *Options) *Conn { c.sendRx = &sender.Rx c.lk.sendTx = &sender.Tx - c.lk.answers = make(map[answerID]*answer) + c.lk.answers = make(map[answerID]*ansent) c.lk.imports = make(map[importID]*impent) if opts != nil { @@ -507,10 +500,10 @@ func (c *lockedConn) liftEmbargoes(rl *releaseList, embargoes []*embargo) { } } -func (c *lockedConn) releaseAnswers(rl *releaseList, answers map[answerID]*answer) { +func (c *lockedConn) releaseAnswers(rl *releaseList, answers map[answerID]*ansent) { for _, a := range answers { - if a != nil && a.msgReleaser != nil { - rl.Add(a.msgReleaser.Decr) + if a != nil && a.returner.msgReleaser != nil { + rl.Add(a.returner.msgReleaser.Decr) } } } @@ -717,41 +710,43 @@ func (c *Conn) handleBootstrap(in transport.IncomingMessage) error { rl := &releaseList{} defer rl.Release() - ans := answer{ - c: c, - id: answerID(bootstrap.QuestionId()), + ans := ansent{ + returner: ansReturner{ + c: c, + id: answerID(bootstrap.QuestionId()), + }, } - ans.ret, ans.sendMsg, ans.msgReleaser, err = c.newReturn() + ans.returner.ret, ans.sendMsg, ans.returner.msgReleaser, err = c.newReturn() if err == nil { - ans.ret.SetAnswerId(uint32(ans.id)) - ans.ret.SetReleaseParamCaps(false) + ans.returner.ret.SetAnswerId(uint32(ans.returner.id)) + ans.returner.ret.SetReleaseParamCaps(false) } c.withLocked(func(c *lockedConn) { - if c.lk.answers[ans.id] != nil { - rl.Add(ans.msgReleaser.Decr) - err = rpcerr.Failed(errors.New("incoming bootstrap: answer ID " + str.Utod(ans.id) + " reused")) + if c.lk.answers[ans.returner.id] != nil { + rl.Add(ans.returner.msgReleaser.Decr) + err = rpcerr.Failed(errors.New("incoming bootstrap: answer ID " + str.Utod(ans.returner.id) + " reused")) return } if err != nil { err = rpcerr.Annotate(err, "incoming bootstrap") - c.lk.answers[ans.id] = errorAnswer((*Conn)(c), ans.id, err) + c.lk.answers[ans.returner.id] = errorAnswer((*Conn)(c), ans.returner.id, err) c.er.ReportError(err) return } - c.lk.answers[ans.id] = &ans + c.lk.answers[ans.returner.id] = &ans if !c.bootstrap.IsValid() { - ans.sendException(c, rl, exc.New(exc.Failed, "", "vat does not expose a public/bootstrap interface")) + ans.sendException(rl, exc.New(exc.Failed, "", "vat does not expose a public/bootstrap interface")) return } - if err := ans.setBootstrap(c.bootstrap.AddRef()); err != nil { - ans.sendException(c, rl, err) + if err := ans.returner.setBootstrap(c.bootstrap.AddRef()); err != nil { + ans.sendException(rl, err) return } - err = ans.sendReturn(c, rl) + err = ans.sendReturn(rl) if err != nil { // Answer cannot possibly encounter a Finish, since we still // haven't returned to receive(). @@ -836,18 +831,20 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err ret.SetReleaseParamCaps(false) // Find target and start call. - ans := &answer{ - c: c, - id: id, - ret: ret, - sendMsg: send, - msgReleaser: retReleaser, + ans := &ansent{ + returner: ansReturner{ + c: c, + id: id, + ret: ret, + msgReleaser: retReleaser, + }, + sendMsg: send, } return withLockedConn1(c, func(c *lockedConn) error { c.lk.answers[id] = ans if parseErr != nil { parseErr = rpcerr.Annotate(parseErr, "incoming call") - ans.sendException(c, rl, parseErr) + ans.sendException(rl, parseErr) rl.Add(func() { c.er.ReportError(parseErr) in.Release() @@ -859,16 +856,16 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err Args: p.args, Method: p.method, ReleaseArgs: util.Idempotent(in.Release), - Returner: ans, + Returner: &ans.returner, } switch p.target.which { case rpccp.MessageTarget_Which_importedCap: ent := c.findExport(p.target.importedCap) if ent == nil { - ans.ret = rpccp.Return{} + ans.returner.ret = rpccp.Return{} ans.sendMsg = nil - ans.msgReleaser = nil + ans.returner.msgReleaser = nil rl.Add(func() { retReleaser.Decr() in.Release() @@ -879,7 +876,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err var callCtx context.Context callCtx, ans.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() - ans.setPipelineCaller(c, p.method, pcall) + ans.setPipelineCaller(p.method, pcall) rl.Add(func() { pcall.resolve(ent.client.RecvCall(callCtx, recv)) }) @@ -887,9 +884,9 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err case rpccp.MessageTarget_Which_promisedAnswer: tgtAns := c.lk.answers[p.target.promisedAnswer] if tgtAns == nil || tgtAns.flags.Contains(finishReceived) { - ans.ret = rpccp.Return{} + ans.returner.ret = rpccp.Return{} ans.sendMsg = nil - ans.msgReleaser = nil + ans.returner.msgReleaser = nil rl.Add(func() { retReleaser.Decr() in.Release() @@ -902,7 +899,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err } if tgtAns.flags.Contains(resultsReady) { if tgtAns.err != nil { - ans.sendException(c, rl, tgtAns.err) + ans.sendException(rl, tgtAns.err) rl.Add(in.Release) return nil } @@ -910,10 +907,10 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err // received finish yet (it would have been deleted from the // answers table), and it can't receive a finish because this is // happening on the receive goroutine. - content, err := tgtAns.results.Content() + content, err := tgtAns.returner.results.Content() if err != nil { err = rpcerr.WrapFailed("incoming call: read results from target answer", err) - ans.sendException(c, rl, err) + ans.sendException(rl, err) rl.Add(in.Release) c.er.ReportError(err) return nil @@ -921,7 +918,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err sub, err := capnp.Transform(content, p.target.transform) if err != nil { // Not reporting, as this is the caller's fault. - ans.sendException(c, rl, err) + ans.sendException(rl, err) rl.Add(in.Release) return nil } @@ -930,16 +927,16 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err switch { case sub.IsValid() && !iface.IsValid(): tgt = capnp.ErrorClient(rpcerr.Failed(ErrNotACapability)) - case !iface.IsValid() || int64(iface.Capability()) >= int64(len(tgtAns.results.Message().CapTable)): + case !iface.IsValid() || int64(iface.Capability()) >= int64(len(tgtAns.returner.results.Message().CapTable)): tgt = capnp.Client{} default: - tgt = tgtAns.results.Message().CapTable[iface.Capability()] + tgt = tgtAns.returner.results.Message().CapTable[iface.Capability()] } c.tasks.Add(1) // will be finished by answer.Return var callCtx context.Context callCtx, ans.cancel = context.WithCancel(c.bgctx) pcall := newPromisedPipelineCaller() - ans.setPipelineCaller(c, p.method, pcall) + ans.setPipelineCaller(p.method, pcall) rl.Add(func() { pcall.resolve(tgt.RecvCall(callCtx, recv)) }) @@ -951,7 +948,7 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err tgt := tgtAns.pcall c.tasks.Add(1) // will be finished by answer.Return pcall := newPromisedPipelineCaller() - ans.setPipelineCaller(c, p.method, pcall) + ans.setPipelineCaller(p.method, pcall) rl.Add(func() { pcall.resolve(tgt.PipelineRecv(callCtx, p.target.transform, recv)) tgtAns.pcalls.Done() @@ -1327,7 +1324,7 @@ func (c *Conn) handleFinish(ctx context.Context, in transport.IncomingMessage) e } // Return sent and finish received: time to destroy answer. - err := ans.destroy(c, rl) + err := ans.destroy(rl) if err != nil { return rpcerr.Annotate(err, "incoming finish: release result caps") } @@ -1417,7 +1414,7 @@ func (c *lockedConn) recvCap(d rpccp.CapDescriptor) (capnp.Client, error) { } // Helper for lockedConn.recvCap(); handles the receiverAnswer case. -func (c *lockedConn) recvCapReceiverAnswer(ans *answer, transform []capnp.PipelineOp) capnp.Client { +func (c *lockedConn) recvCapReceiverAnswer(ans *ansent, transform []capnp.PipelineOp) capnp.Client { if ans.promise != nil { // Still unresolved. future := ans.promise.Answer().Future() @@ -1431,7 +1428,7 @@ func (c *lockedConn) recvCapReceiverAnswer(ans *answer, transform []capnp.Pipeli return capnp.ErrorClient(ans.err) } - ptr, err := ans.results.Content() + ptr, err := ans.returner.results.Content() if err != nil { return capnp.ErrorClient(rpcerr.WrapFailed("except.Failed reading results", err)) } @@ -1641,7 +1638,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag } var content capnp.Ptr - if content, err = ans.results.Content(); err != nil { + if content, err = ans.returner.results.Content(); err != nil { err = rpcerr.Failed(errors.New( "incoming disembargo: read answer ID " + str.Utod(tgt.promisedAnswer) + ": " + err.Error(), @@ -1659,7 +1656,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag } iface := ptr.Interface() - if !iface.IsValid() || int64(iface.Capability()) >= int64(len(ans.results.Message().CapTable)) { + if !iface.IsValid() || int64(iface.Capability()) >= int64(len(ans.returner.results.Message().CapTable)) { err = rpcerr.Failed(errors.New( "incoming disembargo: sender loopback requested on a capability that is not an import", ))