Skip to content

Commit

Permalink
Use refcounting for results
Browse files Browse the repository at this point in the history
There is no functional change here yet; this just factors the recounting
we were already doing into a package, and exposes the ability to
refcount to other parts of the code.

This is a stepping stone towards addressing #349; the idea is that the
method receiver should get a reference to this too to keep it alive for
the right amount of time.
  • Loading branch information
zenhack committed Dec 7, 2022
1 parent eac21b6 commit a4c11d8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 28 deletions.
31 changes: 31 additions & 0 deletions internal/rc/rc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package rc

import "sync/atomic"

type Releaser struct {
release func()
refcount atomic.Int32
}

func NewReleaser(init int32, release func()) *Releaser {
ret := &Releaser{
release: release,
}
ret.refcount.Store(init)
return ret
}

func (rc *Releaser) Decr() {
newCount := rc.refcount.Add(-1)
if newCount == 0 {
rc.release()
rc.release = nil
}
}

func (rc *Releaser) Incr() {
newCount := rc.refcount.Add(1)
if newCount == 1 {
panic("Incremented an already-zero refcount")
}
}
27 changes: 12 additions & 15 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"

"capnproto.org/go/capnp/v3"
"capnproto.org/go/capnp/v3/exc"
"capnproto.org/go/capnp/v3/internal/rc"
"capnproto.org/go/capnp/v3/internal/syncutil"
rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc"
)
Expand All @@ -33,9 +33,9 @@ type answer struct {
// sendMsg sends the return message. The caller MUST NOT hold ans.c.mu.
sendMsg func()

// releaseMsg releases the return message. The caller MUST NOT hold
// ans.c.mu.
releaseMsg capnp.ReleaseFunc
// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.mu.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
// Set by AllocResults and setBootstrap, but contents can only be read
Expand Down Expand Up @@ -95,8 +95,10 @@ func errorAnswer(c *Conn, id answerID, err error) *answer {
}
}

// newReturn creates a new Return message.
func (c *Conn) newReturn(ctx context.Context) (rpccp.Return, func(), capnp.ReleaseFunc, error) {
// newReturn creates a new Return message. The returned Releaser will release the message when
// all references to it are dropped; the caller is responsible for one reference. This will not
// happen before the message is sent, as the returned send function retains a reference.
func (c *Conn) newReturn(ctx context.Context) (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ error) {
msg, send, releaseMsg, err := c.transport.NewMessage()
if err != nil {
return rpccp.Return{}, nil, nil, rpcerr.Failedf("create return: %w", err)
Expand All @@ -111,24 +113,19 @@ func (c *Conn) newReturn(ctx context.Context) (rpccp.Return, func(), capnp.Relea
// until the local vat is done with it. We therefore implement a simple
// ref-counting mechanism such that 'release' must be called twice before
// 'releaseMsg' is called.
ref := int32(2)
release := func() {
if atomic.AddInt32(&ref, -1) == 0 {
releaseMsg()
}
}
releaser := rc.NewReleaser(2, releaseMsg)

return ret, func() {
c.sender.Send(asyncSend{
send: send,
release: release,
release: releaser.Decr,
callback: func(err error) {
if err != nil {
c.er.ReportError(fmt.Errorf("send return: %w", err))
}
},
})
}, release, nil
}, releaser, nil
}

// setPipelineCaller sets ans.pcall to pcall if the answer has not
Expand Down Expand Up @@ -319,7 +316,7 @@ func (ans *answer) sendException(ex error) releaseList {
//
// shutdown has its own strategy for cleaning up an answer.
func (ans *answer) destroy() (releaseList, error) {
defer syncutil.Without(&ans.c.mu, ans.releaseMsg)
defer syncutil.Without(&ans.c.mu, ans.msgReleaser.Decr)
delete(ans.c.answers, ans.id)
if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 {
return nil, nil
Expand Down
26 changes: 13 additions & 13 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ func (c *Conn) liftEmbargoes(embargoes []*embargo) {

func (c *Conn) releaseAnswers(answers map[answerID]*answer) {
for _, a := range answers {
if a != nil && a.releaseMsg != nil {
a.releaseMsg()
if a != nil && a.msgReleaser != nil {
a.msgReleaser.Decr()
}
}
}
Expand Down Expand Up @@ -614,7 +614,7 @@ func (c *Conn) handleBootstrap(ctx context.Context, id answerID) error {
)

syncutil.Without(&c.mu, func() {
ans.ret, ans.sendMsg, ans.releaseMsg, err = c.newReturn(ctx)
ans.ret, ans.sendMsg, ans.msgReleaser, err = c.newReturn(ctx)
if err == nil {
ans.ret.SetAnswerId(uint32(id))
ans.ret.SetReleaseParamCaps(false)
Expand Down Expand Up @@ -689,7 +689,7 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn

// Create return message.
c.mu.Unlock()
ret, send, releaseRet, err := c.newReturn(ctx)
ret, send, retReleaser, err := c.newReturn(ctx)
if err != nil {
err = rpcerr.Annotate(err, "incoming call")
syncutil.With(&c.mu, func() {
Expand All @@ -705,11 +705,11 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
// Find target and start call.
c.mu.Lock()
ans := &answer{
c: c,
id: id,
ret: ret,
sendMsg: send,
releaseMsg: releaseRet,
c: c,
id: id,
ret: ret,
sendMsg: send,
msgReleaser: retReleaser,
}
c.answers[id] = ans
if parseErr != nil {
Expand All @@ -735,9 +735,9 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
if ent == nil {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.releaseMsg = nil
ans.msgReleaser = nil
c.mu.Unlock()
releaseRet()
retReleaser.Decr()
releaseCall()
return rpcerr.Failedf("incoming call: unknown export ID %d", id)
}
Expand All @@ -761,9 +761,9 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
if tgtAns == nil || tgtAns.flags.Contains(finishReceived) {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.releaseMsg = nil
ans.msgReleaser = nil
c.mu.Unlock()
releaseRet()
retReleaser.Decr()
releaseCall()
return rpcerr.Failedf("incoming call: use of unknown or finished answer ID %d for promised answer target", p.target.promisedAnswer)
}
Expand Down

0 comments on commit a4c11d8

Please sign in to comment.