Skip to content

Commit

Permalink
Merge pull request #495 from zenhack/ansReturner
Browse files Browse the repository at this point in the history
Decompose answer struct.
  • Loading branch information
zenhack authored Apr 1, 2023
2 parents 1a829fd + b9a9bfb commit d5231d3
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 141 deletions.
185 changes: 97 additions & 88 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,8 @@ import (
// An answerID is an index into the answers table.
type answerID uint32

// answer is an entry in a Conn's answer table.
type answer struct {
// c and id must be set before any answer methods are called.
c *Conn
id answerID

// cancel cancels the Context used in the received method call.
// May be nil.
cancel context.CancelFunc

// ret is the outgoing Return struct. ret is valid iff there was no
// error creating the message. If ret is invalid, then this answer
// entry is a placeholder until the remote vat cancels the call.
ret rpccp.Return

// sendMsg sends the return message. The caller MUST hold ans.c.lk;
// the argument should be the same as ans.c
sendMsg func(*lockedConn)

// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.lk.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
// Set by AllocResults and setBootstrap, but contents can only be read
// if flags has resultsReady but not finishReceived set.
results rpccp.Payload

// All fields below are protected by s.c.mu.

// ansent is an entry in a Conn's answer table.
type ansent struct {
// flags is a bitmask of events that have occurred in an answer's
// lifetime.
flags answerFlags
Expand All @@ -68,6 +40,46 @@ type answer struct {
// the Return message. Can only be read after resultsReady is set in
// flags.
err error

// sendMsg sends the return message.
sendMsg func()

// cancel cancels the Context used in the received method call.
// May be nil.
cancel context.CancelFunc

// Unlike other fields in this struct, it is ok to hand out pointers
// to this that can be used while not holding the connection lock.
returner ansReturner
}

// Returns the already-locked connection to which this entry belongs.
// Since ansents are only supposed to be accessed through c.lk, it is
// assumed that the caller already holds the lock.
func (ans *ansent) lockedConn() *lockedConn {
return (*lockedConn)(ans.returner.c)
}

// ansReturner is the implementation of capnp.Returner that is used when
// handling an incoming call on a local capability.
type ansReturner struct {
// c and id must be set before any answer methods are called.
c *Conn
id answerID

// ret is the outgoing Return struct. ret is valid iff there was no
// error creating the message. If ret is invalid, then this answer
// entry is a placeholder until the remote vat cancels the call.
ret rpccp.Return

// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.lk.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
// Set by AllocResults and setBootstrap, but contents can only be read
// if flags has resultsReady but not finishReceived set.
results rpccp.Payload
}

type answerFlags uint8
Expand All @@ -85,11 +97,13 @@ func (flags answerFlags) Contains(flag answerFlags) bool {
return flags&flag != 0
}

// errorAnswer returns a placeholder answer with an error result already set.
func errorAnswer(c *Conn, id answerID, err error) *answer {
return &answer{
c: c,
id: id,
// errorAnswer returns a placeholder answer entry with an error result already set.
func errorAnswer(c *Conn, id answerID, err error) *ansent {
return &ansent{
returner: ansReturner{
c: c,
id: id,
},
err: err,
flags: resultsReady | returnSent,
}
Expand All @@ -98,7 +112,7 @@ func errorAnswer(c *Conn, id answerID, err error) *answer {
// newReturn creates a new Return message. The returned Releaser will release the message when
// all references to it are dropped; the caller is responsible for one reference. This will not
// happen before the message is sent, as the returned send function retains a reference.
func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Releaser, _ error) {
func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ error) {
outMsg, err := c.transport.NewMessage()
if err != nil {
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)
Expand All @@ -115,9 +129,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel
// 'releaseMsg' is called.
releaser := rc.NewReleaser(2, outMsg.Release)

unlockedConn := c
return ret, func(c *lockedConn) {
c.assertIs(unlockedConn)
return ret, func() {
c.lk.sendTx.Send(asyncSend{
send: outMsg.Send,
release: releaser.Decr,
Expand All @@ -134,17 +146,15 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel
// already returned. The caller MUST hold ans.c.lk.
//
// This also sets ans.promise to a new promise, wrapping pcall.
func (ans *answer) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp.PipelineCaller) {
c.assertIs(ans.c)

func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
if !ans.flags.Contains(resultsReady) {
ans.pcall = pcall
ans.promise = capnp.NewPromise(m, pcall)
}
}

// AllocResults allocates the results struct.
func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {
func (ans *ansReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {
var err error
ans.results, err = ans.ret.NewResults()
if err != nil {
Expand All @@ -163,7 +173,7 @@ func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {

// setBootstrap sets the results to an interface pointer, stealing the
// reference.
func (ans *answer) setBootstrap(c capnp.Client) error {
func (ans *ansReturner) setBootstrap(c capnp.Client) error {
if ans.ret.HasResults() || len(ans.ret.Message().CapTable) > 0 {
panic("setBootstrap called after creating results")
}
Expand All @@ -183,33 +193,38 @@ func (ans *answer) setBootstrap(c capnp.Client) error {
}

// PrepareReturn implements capnp.Returner.PrepareReturn
func (ans *answer) PrepareReturn(e error) {
func (ans *ansReturner) PrepareReturn(e error) {
rl := &releaseList{}
defer rl.Release()

ans.c.withLocked(func(c *lockedConn) {
ent := c.lk.answers[ans.id]
if e == nil {
ans.prepareSendReturn(c, rl)
ent.prepareSendReturn(rl)
} else {
ans.prepareSendException(c, rl, e)
ent.prepareSendException(rl, e)
}
})
}

// Return implements capnp.Returner.Return
func (ans *answer) Return() {
func (ans *ansReturner) Return() {
rl := &releaseList{}
defer rl.Release()
defer ans.pcalls.Wait()
pcallsWait := func() {}

var err error
ans.c.withLocked(func(c *lockedConn) {
if ans.err == nil {
err = ans.completeSendReturn(c, rl)
ent := c.lk.answers[ans.id]
pcallsWait = ent.pcalls.Wait

if ent.err == nil {
err = ent.completeSendReturn(rl)
} else {
ans.completeSendException(c, rl)
ent.completeSendException(rl)
}
})
defer pcallsWait()
ans.c.tasks.Done() // added by handleCall

if err == nil {
Expand All @@ -221,7 +236,7 @@ func (ans *answer) Return() {
}
}

func (ans *answer) ReleaseResults() {
func (ans *ansReturner) ReleaseResults() {
if ans.results.IsValid() {
ans.msgReleaser.Decr()
}
Expand All @@ -237,33 +252,30 @@ func (ans *answer) ReleaseResults() {
// the lock, per usual).
//
// sendReturn MUST NOT be called if sendException was previously called.
func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error {
ans.prepareSendReturn(c, rl)
return ans.completeSendReturn(c, rl)
func (ans *ansent) sendReturn(rl *releaseList) error {
ans.prepareSendReturn(rl)
return ans.completeSendReturn(rl)
}

func (ans *answer) prepareSendReturn(c *lockedConn, rl *releaseList) {
c.assertIs(ans.c)

func (ans *ansent) prepareSendReturn(rl *releaseList) {
var err error
ans.exportRefs, err = c.fillPayloadCapTable(ans.results)
c := ans.lockedConn()
ans.exportRefs, err = c.fillPayloadCapTable(ans.returner.results)
if err != nil {
ans.c.er.ReportError(rpcerr.Annotate(err, "send return"))
c.er.ReportError(rpcerr.Annotate(err, "send return"))
}
// Continue. Don't fail to send return if cap table isn't fully filled.

select {
case <-ans.c.bgctx.Done():
case <-c.bgctx.Done():
// We're not going to send the message after all, so don't forget to release it.
ans.msgReleaser.Decr()
rl.Add(ans.returner.msgReleaser.Decr)
ans.sendMsg = nil
default:
}
}

func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error {
c.assertIs(ans.c)

func (ans *ansent) completeSendReturn(rl *releaseList) error {
ans.pcall = nil
ans.flags |= resultsReady

Expand All @@ -277,16 +289,16 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error {
// the cancel variant of return.
ans.promise.Reject(rpcerr.Failed(errors.New("received finish before return")))
} else {
ans.promise.Resolve(ans.results.Content())
ans.promise.Resolve(ans.returner.results.Content())
}
ans.promise = nil
}
ans.sendMsg(c)
ans.sendMsg()
}

ans.flags |= returnSent
if fin {
return ans.destroy(c, rl)
return ans.destroy(rl)
}
return nil
}
Expand All @@ -295,32 +307,30 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error {
//
// The caller MUST be holding onto ans.c.lk. sendException MUST NOT
// be called if sendReturn was previously called.
func (ans *answer) sendException(c *lockedConn, rl *releaseList, ex error) {
ans.prepareSendException(c, rl, ex)
ans.completeSendException(c, rl)
func (ans *ansent) sendException(rl *releaseList, ex error) {
ans.prepareSendException(rl, ex)
ans.completeSendException(rl)
}

func (ans *answer) prepareSendException(c *lockedConn, rl *releaseList, ex error) {
c.assertIs(ans.c)
func (ans *ansent) prepareSendException(rl *releaseList, ex error) {
ans.err = ex

c := ans.lockedConn()
select {
case <-ans.c.bgctx.Done():
case <-c.bgctx.Done():
default:
// Send exception.
if e, err := ans.ret.NewException(); err != nil {
ans.c.er.ReportError(exc.WrapError("send exception", err))
if e, err := ans.returner.ret.NewException(); err != nil {
c.er.ReportError(exc.WrapError("send exception", err))
ans.sendMsg = nil
} else if err := e.MarshalError(ex); err != nil {
ans.c.er.ReportError(exc.WrapError("send exception", err))
c.er.ReportError(exc.WrapError("send exception", err))
ans.sendMsg = nil
}
}
}

func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) {
c.assertIs(ans.c)

func (ans *ansent) completeSendException(rl *releaseList) {
ex := ans.err
ans.pcall = nil
ans.flags |= resultsReady
Expand All @@ -330,13 +340,13 @@ func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) {
ans.promise = nil
}
if ans.sendMsg != nil {
ans.sendMsg(c)
ans.sendMsg()
}
ans.flags |= returnSent
if ans.flags.Contains(finishReceived) {
// destroy will never return an error because sendException does
// create any exports.
_ = ans.destroy(c, rl)
_ = ans.destroy(rl)
}
}

Expand All @@ -345,11 +355,10 @@ func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) {
// The caller must be holding onto ans.c.lk.
//
// shutdown has its own strategy for cleaning up an answer.
func (ans *answer) destroy(c *lockedConn, rl *releaseList) error {
c.assertIs(ans.c)

rl.Add(ans.msgReleaser.Decr)
delete(c.lk.answers, ans.id)
func (ans *ansent) destroy(rl *releaseList) error {
rl.Add(ans.returner.msgReleaser.Decr)
c := ans.lockedConn()
delete(c.lk.answers, ans.returner.id)
if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 {
return nil

Expand Down
Loading

0 comments on commit d5231d3

Please sign in to comment.