Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decompose answer struct. #495

Merged
merged 4 commits into from
Apr 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 97 additions & 88 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,8 @@ import (
// An answerID is an index into the answers table.
type answerID uint32

// answer is an entry in a Conn's answer table.
type answer struct {
// c and id must be set before any answer methods are called.
c *Conn
id answerID

// cancel cancels the Context used in the received method call.
// May be nil.
cancel context.CancelFunc

// ret is the outgoing Return struct. ret is valid iff there was no
// error creating the message. If ret is invalid, then this answer
// entry is a placeholder until the remote vat cancels the call.
ret rpccp.Return

// sendMsg sends the return message. The caller MUST hold ans.c.lk;
// the argument should be the same as ans.c
sendMsg func(*lockedConn)

// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.lk.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
// Set by AllocResults and setBootstrap, but contents can only be read
// if flags has resultsReady but not finishReceived set.
results rpccp.Payload

// All fields below are protected by s.c.mu.

// ansent is an entry in a Conn's answer table.
type ansent struct {
// flags is a bitmask of events that have occurred in an answer's
// lifetime.
flags answerFlags
Expand All @@ -68,6 +40,46 @@ type answer struct {
// the Return message. Can only be read after resultsReady is set in
// flags.
err error

// sendMsg sends the return message.
sendMsg func()

// cancel cancels the Context used in the received method call.
// May be nil.
cancel context.CancelFunc

// Unlike other fields in this struct, it is ok to hand out pointers
// to this that can be used while not holding the connection lock.
returner ansReturner
}

// Returns the already-locked connection to which this entry belongs.
// Since ansents are only supposed to be accessed through c.lk, it is
// assumed that the caller already holds the lock.
func (ans *ansent) lockedConn() *lockedConn {
return (*lockedConn)(ans.returner.c)
}

// ansReturner is the implementation of capnp.Returner that is used when
// handling an incoming call on a local capability.
type ansReturner struct {
// c and id must be set before any answer methods are called.
c *Conn
id answerID

// ret is the outgoing Return struct. ret is valid iff there was no
// error creating the message. If ret is invalid, then this answer
// entry is a placeholder until the remote vat cancels the call.
ret rpccp.Return

// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.lk.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
// Set by AllocResults and setBootstrap, but contents can only be read
// if flags has resultsReady but not finishReceived set.
results rpccp.Payload
}

type answerFlags uint8
Expand All @@ -85,11 +97,13 @@ func (flags answerFlags) Contains(flag answerFlags) bool {
return flags&flag != 0
}

// errorAnswer returns a placeholder answer with an error result already set.
func errorAnswer(c *Conn, id answerID, err error) *answer {
return &answer{
c: c,
id: id,
// errorAnswer returns a placeholder answer entry with an error result already set.
func errorAnswer(c *Conn, id answerID, err error) *ansent {
return &ansent{
returner: ansReturner{
c: c,
id: id,
},
err: err,
flags: resultsReady | returnSent,
}
Expand All @@ -98,7 +112,7 @@ func errorAnswer(c *Conn, id answerID, err error) *answer {
// newReturn creates a new Return message. The returned Releaser will release the message when
// all references to it are dropped; the caller is responsible for one reference. This will not
// happen before the message is sent, as the returned send function retains a reference.
func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Releaser, _ error) {
func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ error) {
outMsg, err := c.transport.NewMessage()
if err != nil {
return rpccp.Return{}, nil, nil, rpcerr.WrapFailed("create return", err)
Expand All @@ -115,9 +129,7 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel
// 'releaseMsg' is called.
releaser := rc.NewReleaser(2, outMsg.Release)

unlockedConn := c
return ret, func(c *lockedConn) {
c.assertIs(unlockedConn)
return ret, func() {
c.lk.sendTx.Send(asyncSend{
send: outMsg.Send,
release: releaser.Decr,
Expand All @@ -134,17 +146,15 @@ func (c *Conn) newReturn() (_ rpccp.Return, sendMsg func(*lockedConn), _ *rc.Rel
// already returned. The caller MUST hold ans.c.lk.
//
// This also sets ans.promise to a new promise, wrapping pcall.
func (ans *answer) setPipelineCaller(c *lockedConn, m capnp.Method, pcall capnp.PipelineCaller) {
c.assertIs(ans.c)

func (ans *ansent) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
if !ans.flags.Contains(resultsReady) {
ans.pcall = pcall
ans.promise = capnp.NewPromise(m, pcall)
}
}

// AllocResults allocates the results struct.
func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {
func (ans *ansReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {
var err error
ans.results, err = ans.ret.NewResults()
if err != nil {
Expand All @@ -163,7 +173,7 @@ func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {

// setBootstrap sets the results to an interface pointer, stealing the
// reference.
func (ans *answer) setBootstrap(c capnp.Client) error {
func (ans *ansReturner) setBootstrap(c capnp.Client) error {
if ans.ret.HasResults() || len(ans.ret.Message().CapTable) > 0 {
panic("setBootstrap called after creating results")
}
Expand All @@ -183,33 +193,38 @@ func (ans *answer) setBootstrap(c capnp.Client) error {
}

// PrepareReturn implements capnp.Returner.PrepareReturn
func (ans *answer) PrepareReturn(e error) {
func (ans *ansReturner) PrepareReturn(e error) {
rl := &releaseList{}
defer rl.Release()

ans.c.withLocked(func(c *lockedConn) {
ent := c.lk.answers[ans.id]
if e == nil {
ans.prepareSendReturn(c, rl)
ent.prepareSendReturn(rl)
} else {
ans.prepareSendException(c, rl, e)
ent.prepareSendException(rl, e)
}
})
}

// Return implements capnp.Returner.Return
func (ans *answer) Return() {
func (ans *ansReturner) Return() {
rl := &releaseList{}
defer rl.Release()
defer ans.pcalls.Wait()
pcallsWait := func() {}

var err error
ans.c.withLocked(func(c *lockedConn) {
if ans.err == nil {
err = ans.completeSendReturn(c, rl)
ent := c.lk.answers[ans.id]
pcallsWait = ent.pcalls.Wait

if ent.err == nil {
err = ent.completeSendReturn(rl)
} else {
ans.completeSendException(c, rl)
ent.completeSendException(rl)
}
})
defer pcallsWait()
ans.c.tasks.Done() // added by handleCall

if err == nil {
Expand All @@ -221,7 +236,7 @@ func (ans *answer) Return() {
}
}

func (ans *answer) ReleaseResults() {
func (ans *ansReturner) ReleaseResults() {
if ans.results.IsValid() {
ans.msgReleaser.Decr()
}
Expand All @@ -237,33 +252,30 @@ func (ans *answer) ReleaseResults() {
// the lock, per usual).
//
// sendReturn MUST NOT be called if sendException was previously called.
func (ans *answer) sendReturn(c *lockedConn, rl *releaseList) error {
ans.prepareSendReturn(c, rl)
return ans.completeSendReturn(c, rl)
func (ans *ansent) sendReturn(rl *releaseList) error {
ans.prepareSendReturn(rl)
return ans.completeSendReturn(rl)
}

func (ans *answer) prepareSendReturn(c *lockedConn, rl *releaseList) {
c.assertIs(ans.c)

func (ans *ansent) prepareSendReturn(rl *releaseList) {
var err error
ans.exportRefs, err = c.fillPayloadCapTable(ans.results)
c := ans.lockedConn()
ans.exportRefs, err = c.fillPayloadCapTable(ans.returner.results)
if err != nil {
ans.c.er.ReportError(rpcerr.Annotate(err, "send return"))
c.er.ReportError(rpcerr.Annotate(err, "send return"))
}
// Continue. Don't fail to send return if cap table isn't fully filled.

select {
case <-ans.c.bgctx.Done():
case <-c.bgctx.Done():
// We're not going to send the message after all, so don't forget to release it.
ans.msgReleaser.Decr()
rl.Add(ans.returner.msgReleaser.Decr)
ans.sendMsg = nil
default:
}
}

func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error {
c.assertIs(ans.c)

func (ans *ansent) completeSendReturn(rl *releaseList) error {
ans.pcall = nil
ans.flags |= resultsReady

Expand All @@ -277,16 +289,16 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error {
// the cancel variant of return.
ans.promise.Reject(rpcerr.Failed(errors.New("received finish before return")))
} else {
ans.promise.Resolve(ans.results.Content())
ans.promise.Resolve(ans.returner.results.Content())
}
ans.promise = nil
}
ans.sendMsg(c)
ans.sendMsg()
}

ans.flags |= returnSent
if fin {
return ans.destroy(c, rl)
return ans.destroy(rl)
}
return nil
}
Expand All @@ -295,32 +307,30 @@ func (ans *answer) completeSendReturn(c *lockedConn, rl *releaseList) error {
//
// The caller MUST be holding onto ans.c.lk. sendException MUST NOT
// be called if sendReturn was previously called.
func (ans *answer) sendException(c *lockedConn, rl *releaseList, ex error) {
ans.prepareSendException(c, rl, ex)
ans.completeSendException(c, rl)
func (ans *ansent) sendException(rl *releaseList, ex error) {
ans.prepareSendException(rl, ex)
ans.completeSendException(rl)
}

func (ans *answer) prepareSendException(c *lockedConn, rl *releaseList, ex error) {
c.assertIs(ans.c)
func (ans *ansent) prepareSendException(rl *releaseList, ex error) {
ans.err = ex

c := ans.lockedConn()
select {
case <-ans.c.bgctx.Done():
case <-c.bgctx.Done():
default:
// Send exception.
if e, err := ans.ret.NewException(); err != nil {
ans.c.er.ReportError(exc.WrapError("send exception", err))
if e, err := ans.returner.ret.NewException(); err != nil {
c.er.ReportError(exc.WrapError("send exception", err))
ans.sendMsg = nil
} else if err := e.MarshalError(ex); err != nil {
ans.c.er.ReportError(exc.WrapError("send exception", err))
c.er.ReportError(exc.WrapError("send exception", err))
ans.sendMsg = nil
}
}
}

func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) {
c.assertIs(ans.c)

func (ans *ansent) completeSendException(rl *releaseList) {
ex := ans.err
ans.pcall = nil
ans.flags |= resultsReady
Expand All @@ -330,13 +340,13 @@ func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) {
ans.promise = nil
}
if ans.sendMsg != nil {
ans.sendMsg(c)
ans.sendMsg()
}
ans.flags |= returnSent
if ans.flags.Contains(finishReceived) {
// destroy will never return an error because sendException does
// create any exports.
_ = ans.destroy(c, rl)
_ = ans.destroy(rl)
}
}

Expand All @@ -345,11 +355,10 @@ func (ans *answer) completeSendException(c *lockedConn, rl *releaseList) {
// The caller must be holding onto ans.c.lk.
//
// shutdown has its own strategy for cleaning up an answer.
func (ans *answer) destroy(c *lockedConn, rl *releaseList) error {
c.assertIs(ans.c)

rl.Add(ans.msgReleaser.Decr)
delete(c.lk.answers, ans.id)
func (ans *ansent) destroy(rl *releaseList) error {
rl.Add(ans.returner.msgReleaser.Decr)
c := ans.lockedConn()
delete(c.lk.answers, ans.returner.id)
if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 {
return nil

Expand Down
Loading