Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conn: wrap fields protected by mu in struct. #379

Merged
merged 2 commits into from
Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ type answer struct {
// entry is a placeholder until the remote vat cancels the call.
ret rpccp.Return

// sendMsg sends the return message. The caller MUST NOT hold ans.c.mu.
// sendMsg sends the return message. The caller MUST NOT hold ans.c.lk.
sendMsg func()

// msgReleaser releases the return message when its refcount hits zero.
// The caller MUST NOT hold ans.c.mu.
// The caller MUST NOT hold ans.c.lk.
msgReleaser *rc.Releaser

// results is the memoized answer to ret.Results().
Expand Down Expand Up @@ -129,11 +129,11 @@ func (c *Conn) newReturn(ctx context.Context) (_ rpccp.Return, sendMsg func(), _
}

// setPipelineCaller sets ans.pcall to pcall if the answer has not
// already returned. The caller MUST NOT hold ans.c.mu.
// already returned. The caller MUST NOT hold ans.c.lk.
//
// This also sets ans.promise to a new promise, wrapping pcall.
func (ans *answer) setPipelineCaller(m capnp.Method, pcall capnp.PipelineCaller) {
syncutil.With(&ans.c.mu, func() {
syncutil.With(&ans.c.lk, func() {
if !ans.flags.Contains(resultsReady) {
ans.pcall = pcall
ans.promise = capnp.NewPromise(m, pcall)
Expand Down Expand Up @@ -181,12 +181,12 @@ func (ans *answer) setBootstrap(c capnp.Client) error {

// Return sends the return message.
//
// The caller MUST NOT hold ans.c.mu.
// The caller MUST NOT hold ans.c.lk.
func (ans *answer) Return(e error) {
ans.c.mu.Lock()
ans.c.lk.Lock()
if e != nil {
rl := ans.sendException(e)
ans.c.mu.Unlock()
ans.c.lk.Unlock()
rl.release()
ans.pcalls.Wait()
ans.c.tasks.Done() // added by handleCall
Expand All @@ -202,13 +202,13 @@ func (ans *answer) Return(e error) {
ans.c.er.ReportError(err)
}

ans.c.mu.Unlock()
ans.c.lk.Unlock()
rl.release()
ans.pcalls.Wait()
return
}
}
ans.c.mu.Unlock()
ans.c.lk.Unlock()
rl.release()
ans.pcalls.Wait()
ans.c.tasks.Done() // added by handleCall
Expand All @@ -219,7 +219,7 @@ func (ans *answer) Return(e error) {
// Finish with releaseResultCaps set to true, then sendReturn returns
// the number of references to be subtracted from each export.
//
// The caller MUST be holding onto ans.c.mu. sendReturn MUST NOT be
// The caller MUST be holding onto ans.c.lk. sendReturn MUST NOT be
// called if sendException was previously called.
func (ans *answer) sendReturn() (releaseList, error) {
ans.pcall = nil
Expand Down Expand Up @@ -252,13 +252,13 @@ func (ans *answer) sendReturn() (releaseList, error) {
}
ans.promise = nil
}
ans.c.mu.Unlock()
ans.c.lk.Unlock()
ans.sendMsg()
if fin {
ans.c.mu.Lock()
ans.c.lk.Lock()
return ans.destroy()
}
ans.c.mu.Lock()
ans.c.lk.Lock()
}
ans.flags |= returnSent
if !ans.flags.Contains(finishReceived) {
Expand All @@ -269,7 +269,7 @@ func (ans *answer) sendReturn() (releaseList, error) {

// sendException sends an exception on the answer's return message.
//
// The caller MUST be holding onto ans.c.mu. sendException MUST NOT
// The caller MUST be holding onto ans.c.lk. sendException MUST NOT
// be called if sendReturn was previously called.
func (ans *answer) sendException(ex error) releaseList {
ans.err = ex
Expand All @@ -286,7 +286,7 @@ func (ans *answer) sendException(ex error) releaseList {
default:
// Send exception.
fin := ans.flags.Contains(finishReceived)
ans.c.mu.Unlock()
ans.c.lk.Unlock()
if e, err := ans.ret.NewException(); err != nil {
ans.c.er.ReportError(fmt.Errorf("send exception: %w", err))
} else {
Expand All @@ -298,11 +298,11 @@ func (ans *answer) sendException(ex error) releaseList {
}
}
if fin {
ans.c.mu.Lock()
ans.c.lk.Lock()
rl, _ := ans.destroy()
return rl
}
ans.c.mu.Lock()
ans.c.lk.Lock()
}
ans.flags |= returnSent
if !ans.flags.Contains(finishReceived) {
Expand All @@ -316,12 +316,12 @@ func (ans *answer) sendException(ex error) releaseList {

// destroy removes the answer from the table and returns the clients to
// release. The answer must have sent a return and received a finish.
// The caller must be holding onto ans.c.mu.
// The caller must be holding onto ans.c.lk.
//
// shutdown has its own strategy for cleaning up an answer.
func (ans *answer) destroy() (releaseList, error) {
defer syncutil.Without(&ans.c.mu, ans.msgReleaser.Decr)
delete(ans.c.answers, ans.id)
defer syncutil.Without(&ans.c.lk, ans.msgReleaser.Decr)
delete(ans.c.lk.answers, ans.id)
if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 {
return nil, nil

Expand Down
38 changes: 19 additions & 19 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func (c *Conn) clearExportID(m *capnp.Metadata) {
}

// findExport returns the export entry with the given ID or nil if
// couldn't be found.
// couldn't be found. The caller must be holding c.mu
func (c *Conn) findExport(id exportID) *expent {
if int64(id) >= int64(len(c.exports)) {
if int64(id) >= int64(len(c.lk.exports)) {
return nil
}
return c.exports[id] // might be nil
return c.lk.exports[id] // might be nil
}

// releaseExport decreases the number of wire references to an export
Expand All @@ -63,8 +63,8 @@ func (c *Conn) releaseExport(id exportID, count uint32) (capnp.Client, error) {
switch {
case count == ent.wireRefs:
client := ent.client
c.exports[id] = nil
c.exportID.remove(uint32(id))
c.lk.exports[id] = nil
c.lk.exportID.remove(uint32(id))
metadata := client.State().Metadata
syncutil.With(metadata, func() {
c.clearExportID(metadata)
Expand Down Expand Up @@ -116,7 +116,7 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
state := client.State()
bv := state.Brand.Value
if ic, ok := bv.(*importClient); ok && ic.c == c {
if ent := c.imports[ic.id]; ent != nil && ent.generation == ic.generation {
if ent := c.lk.imports[ic.id]; ent != nil && ent.generation == ic.generation {
d.SetReceiverHosted(uint32(ic.id))
return 0, false, nil
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
defer state.Metadata.Unlock()
id, ok := c.findExportID(state.Metadata)
if ok {
ent := c.exports[id]
ent := c.lk.exports[id]
ent.wireRefs++
d.SetSenderHosted(uint32(id))
return id, true, nil
Expand All @@ -160,11 +160,11 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
client: client.AddRef(),
wireRefs: 1,
}
id = exportID(c.exportID.next())
if int64(id) == int64(len(c.exports)) {
c.exports = append(c.exports, ee)
id = exportID(c.lk.exportID.next())
if int64(id) == int64(len(c.lk.exports)) {
c.lk.exports = append(c.lk.exports, ee)
} else {
c.exports[id] = ee
c.lk.exports[id] = ee
}
c.setExportID(state.Metadata, id)
d.SetSenderHosted(uint32(id))
Expand Down Expand Up @@ -217,28 +217,28 @@ type embargo struct {
//
// The caller must be holding onto c.mu.
func (c *Conn) embargo(client capnp.Client) (embargoID, capnp.Client) {
id := embargoID(c.embargoID.next())
id := embargoID(c.lk.embargoID.next())
e := &embargo{
c: client,
lifted: make(chan struct{}),
}
if int64(id) == int64(len(c.embargoes)) {
c.embargoes = append(c.embargoes, e)
if int64(id) == int64(len(c.lk.embargoes)) {
c.lk.embargoes = append(c.lk.embargoes, e)
} else {
c.embargoes[id] = e
c.lk.embargoes[id] = e
}
var c2 capnp.Client
c2, c.embargoes[id].p = capnp.NewPromisedClient(c.embargoes[id])
c2, c.lk.embargoes[id].p = capnp.NewPromisedClient(c.lk.embargoes[id])
return id, c2
}

// findEmbargo returns the embargo entry with the given ID or nil if
// couldn't be found.
// couldn't be found. Must be holding c.mu
func (c *Conn) findEmbargo(id embargoID) *embargo {
if int64(id) >= int64(len(c.embargoes)) {
if int64(id) >= int64(len(c.lk.embargoes)) {
return nil
}
return c.embargoes[id] // might be nil
return c.lk.embargoes[id] // might be nil
}

// lift disembargoes the client. It must be called only once.
Expand Down
32 changes: 16 additions & 16 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type impent struct {
//
// The caller must be holding onto c.mu.
func (c *Conn) addImport(id importID) capnp.Client {
if ent := c.imports[id]; ent != nil {
if ent := c.lk.imports[id]; ent != nil {
ent.wireRefs++
client, ok := ent.wc.AddRef()
if !ok {
Expand All @@ -69,7 +69,7 @@ func (c *Conn) addImport(id importID) capnp.Client {
c: c,
id: id,
})
c.imports[id] = &impent{
c.lk.imports[id] = &impent{
wc: client.WeakRef(),
wireRefs: 1,
}
Expand All @@ -84,33 +84,33 @@ type importClient struct {
}

func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer, capnp.ReleaseFunc) {
ic.c.mu.Lock()
defer ic.c.mu.Unlock()
ic.c.lk.Lock()
defer ic.c.lk.Unlock()

if !ic.c.startTask() {
return capnp.ErrorAnswer(s.Method, ExcClosed), func() {}
}
defer ic.c.tasks.Done()
ent := ic.c.imports[ic.id]
ent := ic.c.lk.imports[ic.id]
if ent == nil || ic.generation != ent.generation {
return capnp.ErrorAnswer(s.Method, rpcerr.Disconnectedf("send on closed import")), func() {}
}
q := ic.c.newQuestion(s.Method)

// Send call message.
syncutil.Without(&ic.c.mu, func() {
syncutil.Without(&ic.c.lk, func() {
ic.c.sendMessage(ctx, func(m rpccp.Message) error {
return ic.c.newImportCallMessage(m, ic.id, q.id, s)
}, func(err error) {
ic.c.mu.Lock()
defer ic.c.mu.Unlock()
ic.c.lk.Lock()
defer ic.c.lk.Unlock()

if err != nil {
ic.c.questions[q.id] = nil
syncutil.Without(&ic.c.mu, func() {
ic.c.lk.questions[q.id] = nil
syncutil.Without(&ic.c.lk, func() {
q.p.Reject(rpcerr.Failedf("send message: %w", err))
})
ic.c.questionID.remove(uint32(q.id))
ic.c.lk.questionID.remove(uint32(q.id))
return
}

Expand Down Expand Up @@ -164,7 +164,7 @@ func (c *Conn) newImportCallMessage(msg rpccp.Message, imp importID, qid questio
if err := s.PlaceArgs(args); err != nil {
return rpcerr.Failedf("place arguments: %w", err)
}
syncutil.With(&c.mu, func() {
syncutil.With(&c.lk, func() {
// TODO(soon): save param refs
_, err = c.fillPayloadCapTable(payload)
})
Expand Down Expand Up @@ -219,21 +219,21 @@ func (ic *importClient) Brand() capnp.Brand {
}

func (ic *importClient) Shutdown() {
ic.c.mu.Lock()
defer ic.c.mu.Unlock()
ic.c.lk.Lock()
defer ic.c.lk.Unlock()

if !ic.c.startTask() {
return
}
defer ic.c.tasks.Done()

ent := ic.c.imports[ic.id]
ent := ic.c.lk.imports[ic.id]
if ic.generation != ent.generation {
// A new reference was added concurrently with the Shutdown. See
// impent.generation documentation for an explanation.
return
}
delete(ic.c.imports, ic.id)
delete(ic.c.lk.imports, ic.id)
ic.c.sendMessage(ic.c.bgctx, func(msg rpccp.Message) error {
rel, err := msg.NewRelease()
if err == nil {
Expand Down
Loading