Skip to content

Commit

Permalink
Merge pull request #379 from zenhack/Conn.lk
Browse files Browse the repository at this point in the history
Conn: wrap fields protected by mu in struct.
  • Loading branch information
lthibault authored Dec 13, 2022
2 parents 1339225 + 1790bbf commit d105cc4
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 192 deletions.
40 changes: 20 additions & 20 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ type answer struct {
// entry is a placeholder until the remote vat cancels the call.
ret rpccp.Return

// sendMsg sends the return message. The caller MUST NOT hold ans.c.mu.
// sendMsg sends the return message. The caller MUST NOT hold ans.c.lk.
sendMsg func()

// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.mu.
// The caller MUST NOT hold ans.c.lk.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
Expand Down Expand Up @@ -129,11 +129,11 @@ func (c *Conn) newReturn(ctx context.Context) (_ rpccp.Return, sendMsg func(), _
}

// setPipelineCaller sets ans.pcall to pcall if the answer has not
// already returned. The caller MUST NOT hold ans.c.mu.
// already returned. The caller MUST NOT hold ans.c.lk.
//
// This also sets ans.promise to a new promise, wrapping pcall.
func (ans *answer) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
syncutil.With(&ans.c.mu, func() {
syncutil.With(&ans.c.lk, func() {
if !ans.flags.Contains(resultsReady) {
ans.pcall = pcall
ans.promise = capnp.NewPromise(m, pcall)
Expand Down Expand Up @@ -181,12 +181,12 @@ func (ans *answer) setBootstrap(c capnp.Client) error {

// Return sends the return message.
//
// The caller MUST NOT hold ans.c.mu.
// The caller MUST NOT hold ans.c.lk.
func (ans *answer) Return(e error) {
ans.c.mu.Lock()
ans.c.lk.Lock()
if e != nil {
rl := ans.sendException(e)
ans.c.mu.Unlock()
ans.c.lk.Unlock()
rl.release()
ans.pcalls.Wait()
ans.c.tasks.Done() // added by handleCall
Expand All @@ -202,13 +202,13 @@ func (ans *answer) Return(e error) {
ans.c.er.ReportError(err)
}

ans.c.mu.Unlock()
ans.c.lk.Unlock()
rl.release()
ans.pcalls.Wait()
return
}
}
ans.c.mu.Unlock()
ans.c.lk.Unlock()
rl.release()
ans.pcalls.Wait()
ans.c.tasks.Done() // added by handleCall
Expand All @@ -219,7 +219,7 @@ func (ans *answer) Return(e error) {
// Finish with releaseResultCaps set to true, then sendReturn returns
// the number of references to be subtracted from each export.
//
// The caller MUST be holding onto ans.c.mu. sendReturn MUST NOT be
// The caller MUST be holding onto ans.c.lk. sendReturn MUST NOT be
// called if sendException was previously called.
func (ans *answer) sendReturn() (releaseList, error) {
ans.pcall = nil
Expand Down Expand Up @@ -252,13 +252,13 @@ func (ans *answer) sendReturn() (releaseList, error) {
}
ans.promise = nil
}
ans.c.mu.Unlock()
ans.c.lk.Unlock()
ans.sendMsg()
if fin {
ans.c.mu.Lock()
ans.c.lk.Lock()
return ans.destroy()
}
ans.c.mu.Lock()
ans.c.lk.Lock()
}
ans.flags |= returnSent
if !ans.flags.Contains(finishReceived) {
Expand All @@ -269,7 +269,7 @@ func (ans *answer) sendReturn() (releaseList, error) {

// sendException sends an exception on the answer's return message.
//
// The caller MUST be holding onto ans.c.mu. sendException MUST NOT
// The caller MUST be holding onto ans.c.lk. sendException MUST NOT
// be called if sendReturn was previously called.
func (ans *answer) sendException(ex error) releaseList {
ans.err = ex
Expand All @@ -286,7 +286,7 @@ func (ans *answer) sendException(ex error) releaseList {
default:
// Send exception.
fin := ans.flags.Contains(finishReceived)
ans.c.mu.Unlock()
ans.c.lk.Unlock()
if e, err := ans.ret.NewException(); err != nil {
ans.c.er.ReportError(fmt.Errorf("send exception: %w", err))
} else {
Expand All @@ -298,11 +298,11 @@ func (ans *answer) sendException(ex error) releaseList {
}
}
if fin {
ans.c.mu.Lock()
ans.c.lk.Lock()
rl, _ := ans.destroy()
return rl
}
ans.c.mu.Lock()
ans.c.lk.Lock()
}
ans.flags |= returnSent
if !ans.flags.Contains(finishReceived) {
Expand All @@ -316,12 +316,12 @@ func (ans *answer) sendException(ex error) releaseList {

// destroy removes the answer from the table and returns the clients to
// release. The answer must have sent a return and received a finish.
// The caller must be holding onto ans.c.mu.
// The caller must be holding onto ans.c.lk.
//
// shutdown has its own strategy for cleaning up an answer.
func (ans *answer) destroy() (releaseList, error) {
defer syncutil.Without(&ans.c.mu, ans.msgReleaser.Decr)
delete(ans.c.answers, ans.id)
defer syncutil.Without(&ans.c.lk, ans.msgReleaser.Decr)
delete(ans.c.lk.answers, ans.id)
if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 {
return nil, nil

Expand Down
38 changes: 19 additions & 19 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func (c *Conn) clearExportID(m *capnp.Metadata) {
}

// findExport returns the export entry with the given ID or nil if
// couldn't be found.
// couldn't be found. The caller must be holding c.mu
func (c *Conn) findExport(id exportID) *expent {
if int64(id) >= int64(len(c.exports)) {
if int64(id) >= int64(len(c.lk.exports)) {
return nil
}
return c.exports[id] // might be nil
return c.lk.exports[id] // might be nil
}

// releaseExport decreases the number of wire references to an export
Expand All @@ -63,8 +63,8 @@ func (c *Conn) releaseExport(id exportID, count uint32) (capnp.Client, error) {
switch {
case count == ent.wireRefs:
client := ent.client
c.exports[id] = nil
c.exportID.remove(uint32(id))
c.lk.exports[id] = nil
c.lk.exportID.remove(uint32(id))
metadata := client.State().Metadata
syncutil.With(metadata, func() {
c.clearExportID(metadata)
Expand Down Expand Up @@ -116,7 +116,7 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
state := client.State()
bv := state.Brand.Value
if ic, ok := bv.(*importClient); ok && ic.c == c {
if ent := c.imports[ic.id]; ent != nil && ent.generation == ic.generation {
if ent := c.lk.imports[ic.id]; ent != nil && ent.generation == ic.generation {
d.SetReceiverHosted(uint32(ic.id))
return 0, false, nil
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
defer state.Metadata.Unlock()
id, ok := c.findExportID(state.Metadata)
if ok {
ent := c.exports[id]
ent := c.lk.exports[id]
ent.wireRefs++
d.SetSenderHosted(uint32(id))
return id, true, nil
Expand All @@ -160,11 +160,11 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
client: client.AddRef(),
wireRefs: 1,
}
id = exportID(c.exportID.next())
if int64(id) == int64(len(c.exports)) {
c.exports = append(c.exports, ee)
id = exportID(c.lk.exportID.next())
if int64(id) == int64(len(c.lk.exports)) {
c.lk.exports = append(c.lk.exports, ee)
} else {
c.exports[id] = ee
c.lk.exports[id] = ee
}
c.setExportID(state.Metadata, id)
d.SetSenderHosted(uint32(id))
Expand Down Expand Up @@ -217,28 +217,28 @@ type embargo struct {
//
// The caller must be holding onto c.mu.
func (c *Conn) embargo(client capnp.Client) (embargoID, capnp.Client) {
id := embargoID(c.embargoID.next())
id := embargoID(c.lk.embargoID.next())
e := &embargo{
c: client,
lifted: make(chan struct{}),
}
if int64(id) == int64(len(c.embargoes)) {
c.embargoes = append(c.embargoes, e)
if int64(id) == int64(len(c.lk.embargoes)) {
c.lk.embargoes = append(c.lk.embargoes, e)
} else {
c.embargoes[id] = e
c.lk.embargoes[id] = e
}
var c2 capnp.Client
c2, c.embargoes[id].p = capnp.NewPromisedClient(c.embargoes[id])
c2, c.lk.embargoes[id].p = capnp.NewPromisedClient(c.lk.embargoes[id])
return id, c2
}

// findEmbargo returns the embargo entry with the given ID or nil if
// couldn't be found.
// couldn't be found. Must be holding c.mu
func (c *Conn) findEmbargo(id embargoID) *embargo {
if int64(id) >= int64(len(c.embargoes)) {
if int64(id) >= int64(len(c.lk.embargoes)) {
return nil
}
return c.embargoes[id] // might be nil
return c.lk.embargoes[id] // might be nil
}

// lift disembargoes the client. It must be called only once.
Expand Down
32 changes: 16 additions & 16 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type impent struct {
//
// The caller must be holding onto c.mu.
func (c *Conn) addImport(id importID) capnp.Client {
if ent := c.imports[id]; ent != nil {
if ent := c.lk.imports[id]; ent != nil {
ent.wireRefs++
client, ok := ent.wc.AddRef()
if !ok {
Expand All @@ -69,7 +69,7 @@ func (c *Conn) addImport(id importID) capnp.Client {
c: c,
id: id,
})
c.imports[id] = &impent{
c.lk.imports[id] = &impent{
wc: client.WeakRef(),
wireRefs: 1,
}
Expand All @@ -84,33 +84,33 @@ type importClient struct {
}

func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, capnp.ReleaseFunc) {
ic.c.mu.Lock()
defer ic.c.mu.Unlock()
ic.c.lk.Lock()
defer ic.c.lk.Unlock()

if !ic.c.startTask() {
return capnp.ErrorAnswer(s.Method, ExcClosed), func() {}
}
defer ic.c.tasks.Done()
ent := ic.c.imports[ic.id]
ent := ic.c.lk.imports[ic.id]
if ent == nil || ic.generation != ent.generation {
return capnp.ErrorAnswer(s.Method, rpcerr.Disconnectedf("send on closed import")), func() {}
}
q := ic.c.newQuestion(s.Method)

// Send call message.
syncutil.Without(&ic.c.mu, func() {
syncutil.Without(&ic.c.lk, func() {
ic.c.sendMessage(ctx, func(m rpccp.Message) error {
return ic.c.newImportCallMessage(m, ic.id, q.id, s)
}, func(err error) {
ic.c.mu.Lock()
defer ic.c.mu.Unlock()
ic.c.lk.Lock()
defer ic.c.lk.Unlock()

if err != nil {
ic.c.questions[q.id] = nil
syncutil.Without(&ic.c.mu, func() {
ic.c.lk.questions[q.id] = nil
syncutil.Without(&ic.c.lk, func() {
q.p.Reject(rpcerr.Failedf("send message: %w", err))
})
ic.c.questionID.remove(uint32(q.id))
ic.c.lk.questionID.remove(uint32(q.id))
return
}

Expand Down Expand Up @@ -164,7 +164,7 @@ func (c *Conn) newImportCallMessage(msg rpccp.Message, imp importID, qid questio
if err := s.PlaceArgs(args); err != nil {
return rpcerr.Failedf("place arguments: %w", err)
}
syncutil.With(&c.mu, func() {
syncutil.With(&c.lk, func() {
// TODO(soon): save param refs
_, err = c.fillPayloadCapTable(payload)
})
Expand Down Expand Up @@ -219,21 +219,21 @@ func (ic *importClient) Brand() capnp.Brand {
}

func (ic *importClient) Shutdown() {
ic.c.mu.Lock()
defer ic.c.mu.Unlock()
ic.c.lk.Lock()
defer ic.c.lk.Unlock()

if !ic.c.startTask() {
return
}
defer ic.c.tasks.Done()

ent := ic.c.imports[ic.id]
ent := ic.c.lk.imports[ic.id]
if ic.generation != ent.generation {
// A new reference was added concurrently with the Shutdown. See
// impent.generation documentation for an explanation.
return
}
delete(ic.c.imports, ic.id)
delete(ic.c.lk.imports, ic.id)
ic.c.sendMessage(ic.c.bgctx, func(msg rpccp.Message) error {
rel, err := msg.NewRelease()
if err == nil {
Expand Down
Loading

0 comments on commit d105cc4

Please sign in to comment.