Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use refcounting for results #374

Merged
merged 8 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions internal/rc/rc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package rc

import "sync/atomic"

type Releaser struct {
release func()
refcount int32
}

func NewReleaser(init int32, release func()) *Releaser {
return &Releaser{
release: release,
refcount: init,
}
}

func (rc *Releaser) Decr() {
newCount := atomic.AddInt32(&rc.refcount, -1)
if newCount == 0 {
rc.release()
rc.release = nil
} else if newCount < 0 {
panic("Decremented an already-zero refcount")
}
lthibault marked this conversation as resolved.
Show resolved Hide resolved
}

func (rc *Releaser) Incr() {
newCount := atomic.AddInt32(&rc.refcount, 1)
if newCount == 1 {
panic("Incremented an already-zero refcount")
}
}
62 changes: 25 additions & 37 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"

"capnproto.org/go/capnp/v3"
"capnproto.org/go/capnp/v3/exc"
"capnproto.org/go/capnp/v3/internal/rc"
"capnproto.org/go/capnp/v3/internal/syncutil"
rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc"
)
Expand All @@ -33,9 +33,9 @@ type answer struct {
// sendMsg sends the return message. The caller MUST NOT hold ans.c.mu.
sendMsg func()

// releaseMsg releases the return message. The caller MUST NOT hold
// ans.c.mu.
releaseMsg capnp.ReleaseFunc
// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.mu.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
// Set by AllocResults and setBootstrap, but contents can only be read
Expand All @@ -48,12 +48,6 @@ type answer struct {
// lifetime.
flags answerFlags

// resultCapTable is the CapTable for results. It is not kept in the
// results message because CapTable cannot be used once results are
// sent. However, the capabilities need to be retained for promised
// answer targets.
resultCapTable []capnp.Client

// exportRefs is the number of references to exports placed in the
// results.
exportRefs map[exportID]uint32
Expand Down Expand Up @@ -101,8 +95,10 @@ func errorAnswer(c *Conn, id answerID, err error) *answer {
}
}

// newReturn creates a new Return message.
func (c *Conn) newReturn(ctx context.Context) (rpccp.Return, func(), capnp.ReleaseFunc, error) {
// newReturn creates a new Return message. The returned Releaser will release the message when
// all references to it are dropped; the caller is responsible for one reference. This will not
// happen before the message is sent, as the returned send function retains a reference.
func (c *Conn) newReturn(ctx context.Context) (_ rpccp.Return, sendMsg func(), _ *rc.Releaser, _ error) {
msg, send, releaseMsg, err := c.transport.NewMessage()
if err != nil {
return rpccp.Return{}, nil, nil, rpcerr.Failedf("create return: %w", err)
Expand All @@ -117,24 +113,19 @@ func (c *Conn) newReturn(ctx context.Context) (rpccp.Return, func(), capnp.Relea
// until the local vat is done with it. We therefore implement a simple
// ref-counting mechanism such that 'release' must be called twice before
// 'releaseMsg' is called.
ref := int32(2)
release := func() {
if atomic.AddInt32(&ref, -1) == 0 {
releaseMsg()
}
}
releaser := rc.NewReleaser(2, releaseMsg)

return ret, func() {
c.sender.Send(asyncSend{
send: send,
release: release,
release: releaser.Decr,
callback: func(err error) {
if err != nil {
c.er.ReportError(fmt.Errorf("send return: %w", err))
}
},
})
}, release, nil
}, releaser, nil
}

// setPipelineCaller sets ans.pcall to pcall if the answer has not
Expand Down Expand Up @@ -170,11 +161,11 @@ func (ans *answer) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error) {
// setBootstrap sets the results to an interface pointer, stealing the
// reference.
func (ans *answer) setBootstrap(c capnp.Client) error {
if ans.ret.HasResults() || len(ans.ret.Message().CapTable) > 0 || len(ans.resultCapTable) > 0 {
if ans.ret.HasResults() || len(ans.ret.Message().CapTable) > 0 {
panic("setBootstrap called after creating results")
}
// Add the capability to the table early to avoid leaks if setBootstrap fails.
ans.resultCapTable = []capnp.Client{c}
ans.ret.Message().CapTable = []capnp.Client{c}

var err error
ans.results, err = ans.ret.NewResults()
Expand All @@ -192,9 +183,6 @@ func (ans *answer) setBootstrap(c capnp.Client) error {
//
// The caller MUST NOT hold ans.c.mu.
func (ans *answer) Return(e error) {
if ans.results.IsValid() {
ans.resultCapTable = ans.results.Message().CapTable
}
ans.c.mu.Lock()
if e != nil {
rl := ans.sendException(e)
Expand Down Expand Up @@ -231,22 +219,25 @@ func (ans *answer) Return(e error) {
// Finish with releaseResultCaps set to true, then sendReturn returns
// the number of references to be subtracted from each export.
//
// The caller MUST be holding onto ans.c.mu. The result's capability table
// MUST have been extracted into ans.resultCapTable before calling sendReturn.
// sendReturn MUST NOT be called if sendException was previously called.
// The caller MUST be holding onto ans.c.mu. sendReturn MUST NOT be
// called if sendException was previously called.
func (ans *answer) sendReturn() (releaseList, error) {
ans.pcall = nil
ans.flags |= resultsReady

var err error
ans.exportRefs, err = ans.c.fillPayloadCapTable(ans.results, ans.resultCapTable)
ans.exportRefs, err = ans.c.fillPayloadCapTable(ans.results)
if err != nil {
// We're not going to send the message after all, so don't forget to release it.
ans.msgReleaser.Decr()
ans.c.er.ReportError(rpcerr.Annotate(err, "send return"))
}
// Continue. Don't fail to send return if cap table isn't fully filled.

select {
case <-ans.c.bgctx.Done():
// We're not going to send the message after all, so don't forget to release it.
ans.msgReleaser.Decr()
default:
fin := ans.flags.Contains(finishReceived)
if ans.promise != nil {
Expand Down Expand Up @@ -278,9 +269,8 @@ func (ans *answer) sendReturn() (releaseList, error) {

// sendException sends an exception on the answer's return message.
//
// The caller MUST be holding onto ans.c.mu. The result's capability table
// MUST have been extracted into ans.resultCapTable before calling sendException.
// sendException MUST NOT be called if sendReturn was previously called.
// The caller MUST be holding onto ans.c.mu. sendException MUST NOT
// be called if sendReturn was previously called.
func (ans *answer) sendException(ex error) releaseList {
ans.err = ex
ans.pcall = nil
Expand Down Expand Up @@ -330,13 +320,11 @@ func (ans *answer) sendException(ex error) releaseList {
//
// shutdown has its own strategy for cleaning up an answer.
func (ans *answer) destroy() (releaseList, error) {
defer syncutil.Without(&ans.c.mu, ans.releaseMsg)
defer syncutil.Without(&ans.c.mu, ans.msgReleaser.Decr)
delete(ans.c.answers, ans.id)
rl := releaseList(ans.resultCapTable)
if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 {
return rl, nil
return nil, nil

}
exportReleases, err := ans.c.releaseExportRefs(ans.exportRefs)
return append(rl, exportReleases...), err
return ans.c.releaseExportRefs(ans.exportRefs)
}
8 changes: 6 additions & 2 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,12 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
// reference counts that have been added to the exports table.
//
// The caller must be holding onto c.mu.
func (c *Conn) fillPayloadCapTable(payload rpccp.Payload, clients []capnp.Client) (map[exportID]uint32, error) {
if !payload.IsValid() || len(clients) == 0 {
func (c *Conn) fillPayloadCapTable(payload rpccp.Payload) (map[exportID]uint32, error) {
if !payload.IsValid() {
return nil, nil
}
clients := payload.Message().CapTable
if len(clients) == 0 {
return nil, nil
}
list, err := payload.NewCapTable(int32(len(clients)))
Expand Down
4 changes: 1 addition & 3 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,12 @@ func (c *Conn) newImportCallMessage(msg rpccp.Message, imp importID, qid questio
if s.PlaceArgs == nil {
return nil
}
m := args.Message()
if err := s.PlaceArgs(args); err != nil {
return rpcerr.Failedf("place arguments: %w", err)
}
clients := m.CapTable
syncutil.With(&c.mu, func() {
// TODO(soon): save param refs
_, err = c.fillPayloadCapTable(payload, clients)
_, err = c.fillPayloadCapTable(payload)
})
if err != nil {
return rpcerr.Annotatef(err, "build call message")
Expand Down
4 changes: 1 addition & 3 deletions rpc/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,12 @@ func (c *Conn) newPipelineCallMessage(msg rpccp.Message, tgt questionID, transfo
if s.PlaceArgs == nil {
return nil
}
m := args.Message()
if err := s.PlaceArgs(args); err != nil {
return rpcerr.Failedf("place arguments: %w", err)
}
clients := m.CapTable
syncutil.With(&c.mu, func() {
// TODO(soon): save param refs
_, err = c.fillPayloadCapTable(payload, clients)
_, err = c.fillPayloadCapTable(payload)
})

if err != nil {
Expand Down
46 changes: 18 additions & 28 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,8 @@ func (c *Conn) liftEmbargoes(embargoes []*embargo) {

func (c *Conn) releaseAnswers(answers map[answerID]*answer) {
for _, a := range answers {
if a != nil {
releaseList(a.resultCapTable).release()
if a.releaseMsg != nil {
a.releaseMsg()
}
if a != nil && a.msgReleaser != nil {
a.msgReleaser.Decr()
}
}
}
Expand Down Expand Up @@ -617,7 +614,7 @@ func (c *Conn) handleBootstrap(ctx context.Context, id answerID) error {
)

syncutil.Without(&c.mu, func() {
ans.ret, ans.sendMsg, ans.releaseMsg, err = c.newReturn(ctx)
ans.ret, ans.sendMsg, ans.msgReleaser, err = c.newReturn(ctx)
if err == nil {
ans.ret.SetAnswerId(uint32(id))
ans.ret.SetReleaseParamCaps(false)
Expand Down Expand Up @@ -692,7 +689,7 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn

// Create return message.
c.mu.Unlock()
ret, send, releaseRet, err := c.newReturn(ctx)
ret, send, retReleaser, err := c.newReturn(ctx)
if err != nil {
err = rpcerr.Annotate(err, "incoming call")
syncutil.With(&c.mu, func() {
Expand All @@ -708,11 +705,11 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
// Find target and start call.
c.mu.Lock()
ans := &answer{
c: c,
id: id,
ret: ret,
sendMsg: send,
releaseMsg: releaseRet,
c: c,
id: id,
ret: ret,
sendMsg: send,
msgReleaser: retReleaser,
}
c.answers[id] = ans
if parseErr != nil {
Expand All @@ -738,9 +735,9 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
if ent == nil {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.releaseMsg = nil
ans.msgReleaser = nil
c.mu.Unlock()
releaseRet()
retReleaser.Decr()
releaseCall()
return rpcerr.Failedf("incoming call: unknown export ID %d", id)
}
Expand All @@ -764,9 +761,9 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
if tgtAns == nil || tgtAns.flags.Contains(finishReceived) {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.releaseMsg = nil
ans.msgReleaser = nil
c.mu.Unlock()
releaseRet()
retReleaser.Decr()
releaseCall()
return rpcerr.Failedf("incoming call: use of unknown or finished answer ID %d for promised answer target", p.target.promisedAnswer)
}
Expand Down Expand Up @@ -806,10 +803,10 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
switch {
case sub.IsValid() && !iface.IsValid():
tgt = capnp.ErrorClient(rpcerr.Failed(ErrNotACapability))
case !iface.IsValid() || int64(iface.Capability()) >= int64(len(tgtAns.resultCapTable)):
case !iface.IsValid() || int64(iface.Capability()) >= int64(len(tgtAns.results.Message().CapTable)):
tgt = capnp.Client{}
default:
tgt = tgtAns.resultCapTable[iface.Capability()]
tgt = tgtAns.results.Message().CapTable[iface.Capability()]
}
c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
Expand Down Expand Up @@ -1226,14 +1223,7 @@ func (c *Conn) recvCapReceiverAnswer(ans *answer, transform []capnp.PipelineOp)
return capnp.ErrorClient(rpcerr.Failedf("Result is not a capability"))
}

// We can't just call Client(), because the CapTable has been cleared; instead,
// look it up in resultCapTable ourselves:
capId := int(iface.Capability())
if capId < 0 || capId >= len(ans.resultCapTable) {
return capnp.Client{}
}

return ans.resultCapTable[capId].AddRef()
return iface.Client().AddRef()
}

// Returns whether the client should be treated as local, for the purpose of
Expand Down Expand Up @@ -1392,12 +1382,12 @@ func (c *Conn) handleDisembargo(ctx context.Context, d rpccp.Disembargo, release
}

iface := ptr.Interface()
if !iface.IsValid() || int64(iface.Capability()) >= int64(len(ans.resultCapTable)) {
if !iface.IsValid() || int64(iface.Capability()) >= int64(len(ans.results.Message().CapTable)) {
err = rpcerr.Failedf("incoming disembargo: sender loopback requested on a capability that is not an import")
return
}

client := ans.resultCapTable[iface.Capability()] //.AddRef()
client := iface.Client()

var ok bool
syncutil.Without(&c.mu, func() {
Expand Down