Skip to content

Commit

Permalink
Don't export Message.capTable. Clean up tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
lthibault committed Apr 2, 2023
1 parent 73e5ddf commit 1cba92d
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 103 deletions.
10 changes: 5 additions & 5 deletions capability.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (i Interface) Client() Client {
if msg == nil {
return Client{}
}
tab := msg.CapTable
tab := msg.capTable
if int64(i.cap) >= int64(len(tab)) {
return Client{}
}
Expand Down Expand Up @@ -400,10 +400,10 @@ func (c Client) SendCall(ctx context.Context, s Send) (*Answer, ReleaseFunc) {

// SendStreamCall is like SendCall except that:
//
// 1. It does not return an answer for the eventual result.
// 2. If the call returns an error, all future calls on this
// client will return the same error (without starting
// the method or calling PlaceArgs).
// 1. It does not return an answer for the eventual result.
// 2. If the call returns an error, all future calls on this
// client will return the same error (without starting
// the method or calling PlaceArgs).
func (c Client) SendStreamCall(ctx context.Context, s Send) error {
streamError := mutex.With1(&c.state, func(c *clientState) error {
err := c.stream.err
Expand Down
38 changes: 32 additions & 6 deletions message.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,24 @@ const (

const maxDepth = ^uint(0)

type CapTable []Client

func (ct CapTable) Len() int {
return len(ct)
}

func (ct CapTable) Contains(ifc Interface) bool {
return ifc.IsValid() && ifc.Capability() < CapabilityID(ct.Len())
}

func (ct CapTable) Get(ifc Interface) (c Client) {
if ct.Contains(ifc) {
c = ct[ifc.Capability()]
}

return
}

// A Message is a tree of Cap'n Proto objects, split into one or more
// segments of contiguous memory. The only required field is Arena.
// A Message is safe to read from multiple goroutines.
Expand All @@ -35,14 +53,14 @@ type Message struct {

Arena Arena

// CapTable is the indexed list of the clients referenced in the
// capTable is the indexed list of the clients referenced in the
// message. Capability pointers inside the message will use this table
// to map pointers to Clients. The table is usually populated by the
// RPC system.
//
// See https://capnproto.org/encoding.html#capabilities-interfaces for
// more details on the capability table.
CapTable []Client
capTable CapTable

// TraverseLimit limits how many total bytes of data are allowed to be
// traversed while reading. Traversal is counted when a Struct or
Expand Down Expand Up @@ -105,7 +123,7 @@ func (m *Message) Release() {
// the Message, releases all clients in the cap table, and
// releases the current Arena, so use with caution.
func (m *Message) Reset(arena Arena) (first *Segment, err error) {
for _, c := range m.CapTable {
for _, c := range m.capTable {
c.Release()
}

Expand All @@ -121,7 +139,7 @@ func (m *Message) Reset(arena Arena) (first *Segment, err error) {
Arena: arena,
TraverseLimit: m.TraverseLimit,
DepthLimit: m.DepthLimit,
CapTable: m.CapTable[:0],
capTable: m.capTable[:0],
segs: m.segs,
}

Expand Down Expand Up @@ -222,12 +240,20 @@ func (m *Message) SetRoot(p Ptr) error {
return nil
}

func (m *Message) CapTable() CapTable {
return m.capTable
}

func (m *Message) SetCapTable(ct []Client) {
m.capTable = ct
}

// AddCap appends a capability to the message's capability table and
// returns its ID. It "steals" c's reference: the Message will release
// the client when calling Reset.
func (m *Message) AddCap(c Client) CapabilityID {
n := CapabilityID(len(m.CapTable))
m.CapTable = append(m.CapTable, c)
n := CapabilityID(len(m.capTable))
m.capTable = append(m.capTable, c)
return n
}

Expand Down
67 changes: 22 additions & 45 deletions message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,64 +401,41 @@ func TestAddCap(t *testing.T) {

// Simple case: distinct non-nil clients.
id1 := msg.AddCap(client1.AddRef())
if id1 != 0 {
t.Errorf("first AddCap ID = %d; want 0", id1)
}
if len(msg.CapTable) != 1 {
t.Errorf("after first AddCap, len(msg.CapTable) = %d; want 1", len(msg.CapTable))
} else if !msg.CapTable[0].IsSame(client1) {
t.Errorf("msg.CapTable[0] = %v; want %v", msg.CapTable[0], client1)
}
assert.Equal(t, CapabilityID(0), id1, "first capability ID should be 0")
assert.Len(t, msg.capTable, 1, "should have exactly one capability in the capTable")
assert.True(t, msg.capTable[0].IsSame(client1), "client does not match entry in cap table")

id2 := msg.AddCap(client2.AddRef())
if id2 != 1 {
t.Errorf("second AddCap ID = %d; want 1", id2)
}
if len(msg.CapTable) != 2 {
t.Errorf("after second AddCap, len(msg.CapTable) = %d; want 2", len(msg.CapTable))
} else if !msg.CapTable[1].IsSame(client2) {
t.Errorf("msg.CapTable[1] = %v; want %v", msg.CapTable[1], client1)
}
assert.Equal(t, CapabilityID(1), id2, "second capability ID should be 1")
assert.Len(t, msg.capTable, 2, "should have exactly two capabilities in the capTable")
assert.True(t, msg.capTable[1].IsSame(client2), "client does not match entry in cap table")

// nil client
id3 := msg.AddCap(Client{})
if id3 != 2 {
t.Errorf("third AddCap ID = %d; want 2", id3)
}
if len(msg.CapTable) != 3 {
t.Errorf("after third AddCap, len(msg.CapTable) = %d; want 3", len(msg.CapTable))
} else if !msg.CapTable[2].IsSame(Client{}) {
t.Errorf("msg.CapTable[2] = %v; want <nil>", msg.CapTable[2])
}
assert.Equal(t, CapabilityID(2), id3, "third capability ID should be 2")
assert.Len(t, msg.capTable, 3, "should have exactly three capabilities in the capTable")
assert.True(t, msg.capTable[2].IsSame(Client{}), "client does not match entry in cap table")

// AddCap should not attempt to deduplicate.
id4 := msg.AddCap(client1.AddRef())
if id4 != 3 {
t.Errorf("fourth AddCap ID = %d; want 3", id4)
}
if len(msg.CapTable) != 4 {
t.Errorf("after fourth AddCap, len(msg.CapTable) = %d; want 4", len(msg.CapTable))
} else if !msg.CapTable[3].IsSame(client1) {
t.Errorf("msg.CapTable[3] = %v; want %v", msg.CapTable[3], client1)
}
assert.Equal(t, CapabilityID(3), id4, "fourth capability ID should be 3")
assert.Len(t, msg.capTable, 4, "should have exactly four capabilities in the capTable")
assert.True(t, msg.capTable[3].IsSame(client1), "client does not match entry in cap table")

// Verify that AddCap steals the reference: once client1 and client2
// and the message's capabilities released, hook1 and hook2 should be
// shut down. If they are not, then AddCap created a new reference.
client1.Release()
if hook1.shutdowns > 0 {
t.Error("hook1 shut down before releasing msg.CapTable")
}
assert.Zero(t, hook1.shutdowns, "hook1 shut down before releasing msg.capTable")
client2.Release()
if hook2.shutdowns > 0 {
t.Error("hook2 shut down before releasing msg.CapTable")
}
for _, c := range msg.CapTable {
assert.Zero(t, hook2.shutdowns, "hook2 shut down before releasing msg.capTable")

for _, c := range msg.capTable {
c.Release()
}
if hook1.shutdowns == 0 {
t.Error("hook1 not shut down after releasing msg.CapTable")
}
if hook2.shutdowns == 0 {
t.Error("hook2 not shut down after releasing msg.CapTable")
}

assert.NotZero(t, hook1.shutdowns, "hook1 not shut down after releasing msg.capTable")
assert.NotZero(t, hook2.shutdowns, "hook2 not shut down after releasing msg.capTable")
}

func TestFirstSegmentMessage_SingleSegment(t *testing.T) {
Expand Down
32 changes: 16 additions & 16 deletions pointer.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,21 +269,21 @@ func isZeroFilled(b []byte) bool {
//
// Equality is defined to be:
//
// - Two structs are equal iff all of their fields are equal. If one
// struct has more fields than the other, the extra fields must all be
// zero.
// - Two lists are equal iff they have the same length and their
// corresponding elements are equal. If one list is a list of
// primitives and the other is a list of structs, then the list of
// primitives is treated as if it was a list of structs with the
// element value as the sole field.
// - Two interfaces are equal iff they point to a capability created by
// the same call to NewClient or they are referring to the same
// capability table index in the same message. The latter is
// significant when the message's capability table has not been
// populated.
// - Two null pointers are equal.
// - All other combinations of things are not equal.
// - Two structs are equal iff all of their fields are equal. If one
// struct has more fields than the other, the extra fields must all be
// zero.
// - Two lists are equal iff they have the same length and their
// corresponding elements are equal. If one list is a list of
// primitives and the other is a list of structs, then the list of
// primitives is treated as if it was a list of structs with the
// element value as the sole field.
// - Two interfaces are equal iff they point to a capability created by
// the same call to NewClient or they are referring to the same
// capability table index in the same message. The latter is
// significant when the message's capability table has not been
// populated.
// - Two null pointers are equal.
// - All other combinations of things are not equal.
func Equal(p1, p2 Ptr) (bool, error) {
if !p1.IsValid() && !p2.IsValid() {
return true, nil
Expand Down Expand Up @@ -376,7 +376,7 @@ func Equal(p1, p2 Ptr) (bool, error) {
if i1.Capability() == i2.Capability() {
return true, nil
}
ntab := len(i1.Message().CapTable)
ntab := len(i1.Message().capTable)
if int64(i1.Capability()) >= int64(ntab) || int64(i2.Capability()) >= int64(ntab) {
return false, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pointer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestEqual(t *testing.T) {
plistB, _ := NewPointerList(seg, 1)
plistB.Set(0, structB.ToPtr())
ec := ErrorClient(errors.New("boo"))
msg.CapTable = []Client{
msg.capTable = []Client{
0: ec,
1: ec,
2: ErrorClient(errors.New("another boo")),
Expand Down
4 changes: 2 additions & 2 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ func (ans *ansReturner) AllocResults(sz capnp.ObjectSize) (capnp.Struct, error)
// setBootstrap sets the results to an interface pointer, stealing the
// reference.
func (ans *ansReturner) setBootstrap(c capnp.Client) error {
if ans.ret.HasResults() || len(ans.ret.Message().CapTable) > 0 {
if ans.ret.HasResults() || ans.ret.Message().CapTable().Len() > 0 {
panic("setBootstrap called after creating results")
}
// Add the capability to the table early to avoid leaks if setBootstrap fails.
ans.ret.Message().CapTable = []capnp.Client{c}
ans.ret.Message().SetCapTable([]capnp.Client{c})

var err error
ans.results, err = ans.ret.NewResults()
Expand Down
8 changes: 4 additions & 4 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,11 @@ func (c *lockedConn) fillPayloadCapTable(payload rpccp.Payload) (map[exportID]ui
if !payload.IsValid() {
return nil, nil
}
clients := payload.Message().CapTable
if len(clients) == 0 {
clients := payload.Message().CapTable()
if clients.Len() == 0 {
return nil, nil
}
list, err := payload.NewCapTable(int32(len(clients)))
list, err := payload.NewCapTable(int32(clients.Len()))
if err != nil {
return nil, rpcerr.WrapFailed("payload capability table", err)
}
Expand All @@ -284,7 +284,7 @@ func (c *lockedConn) fillPayloadCapTable(payload rpccp.Payload) (map[exportID]ui
continue
}
if refs == nil {
refs = make(map[exportID]uint32, len(clients)-i)
refs = make(map[exportID]uint32, clients.Len()-i)
}
refs[id]++
}
Expand Down
23 changes: 9 additions & 14 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,14 +924,12 @@ func (c *Conn) handleCall(ctx context.Context, in transport.IncomingMessage) err
}
iface := sub.Interface()
var tgt capnp.Client
switch {
case sub.IsValid() && !iface.IsValid():
if sub.IsValid() && !iface.IsValid() {
tgt = capnp.ErrorClient(rpcerr.Failed(ErrNotACapability))
case !iface.IsValid() || int64(iface.Capability()) >= int64(len(tgtAns.returner.results.Message().CapTable)):
tgt = capnp.Client{}
default:
tgt = tgtAns.returner.results.Message().CapTable[iface.Capability()]
} else {
tgt = tgtAns.returner.results.Message().CapTable().Get(iface)
}

c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
Expand Down Expand Up @@ -1231,15 +1229,12 @@ func (c *lockedConn) parseReturn(rl *releaseList, ret rpccp.Return, called [][]c

var embargoCaps uintSet
var disembargoes []senderLoopback
mtab := ret.Message().CapTable
mtab := ret.Message().CapTable()
for _, xform := range called {
p2, _ := capnp.Transform(content, xform)
iface := p2.Interface()
if !iface.IsValid() {
continue
}
i := iface.Capability()
if int64(i) >= int64(len(mtab)) || !locals.has(uint(i)) || embargoCaps.has(uint(i)) {
if !mtab.Contains(iface) || !locals.has(uint(i)) || embargoCaps.has(uint(i)) {
continue
}
var id embargoID
Expand Down Expand Up @@ -1502,7 +1497,7 @@ func (c *lockedConn) recvPayload(rl *releaseList, payload rpccp.Payload) (_ capn
// and just return.
return capnp.Ptr{}, nil, nil
}
if payload.Message().CapTable != nil {
if payload.Message().CapTable() != nil {
// RecvMessage likely violated its invariant.
return capnp.Ptr{}, nil, rpcerr.WrapFailed("read payload", ErrCapTablePopulated)
}
Expand Down Expand Up @@ -1533,7 +1528,7 @@ func (c *lockedConn) recvPayload(rl *releaseList, payload rpccp.Payload) (_ capn
locals.add(uint(i))
}
}
payload.Message().CapTable = mtab
payload.Message().SetCapTable(mtab)
return p, locals, nil
}

Expand Down Expand Up @@ -1656,7 +1651,7 @@ func (c *Conn) handleDisembargo(ctx context.Context, in transport.IncomingMessag
}

iface := ptr.Interface()
if !iface.IsValid() || int64(iface.Capability()) >= int64(len(ans.returner.results.Message().CapTable)) {
if !ans.returner.results.Message().CapTable().Contains(iface) {
err = rpcerr.Failed(errors.New(
"incoming disembargo: sender loopback requested on a capability that is not an import",
))
Expand Down
6 changes: 3 additions & 3 deletions rpc/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error))
}
// simulate mutating CapTable
callMsg.Message().Message().AddCap(capnp.ErrorClient(errors.New("foo")))
callMsg.Message().Message().CapTable = nil
callMsg.Message().Message().SetCapTable(nil)
capPtr := capnp.NewInterface(params.Segment(), 0).ToPtr()
if err := params.SetContent(capPtr); err != nil {
t.Fatal("SetContent:", err)
Expand All @@ -102,7 +102,7 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error))
if err != nil {
t.Fatal("t2.RecvMessage:", err)
}
if r1.Message().Message().CapTable != nil {
if r1.Message().Message().CapTable() != nil {
t.Error("t2.RecvMessage(ctx).Message().CapTable is not nil")
}
if r1.Message().Which() != rpccp.Message_Which_bootstrap {
Expand All @@ -124,7 +124,7 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error))
if err != nil {
t.Fatal("t2.RecvMessage:", err)
}
if r2.Message().Message().CapTable != nil {
if r2.Message().Message().CapTable() != nil {
t.Error("t2.RecvMessage(ctx).Message().CapTable is not nil")
}
if r2.Message().Which() != rpccp.Message_Which_call {
Expand Down
Loading

0 comments on commit 1cba92d

Please sign in to comment.