Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conn: wrap fields protected by mu in struct. #379

Merged
merged 2 commits into from
Dec 13, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (ans *answer) sendException(ex error) releaseList {
// shutdown has its own strategy for cleaning up an answer.
func (ans *answer) destroy() (releaseList, error) {
defer syncutil.Without(&ans.c.mu, ans.msgReleaser.Decr)
delete(ans.c.answers, ans.id)
delete(ans.c.lk.answers, ans.id)
if !ans.flags.Contains(releaseResultCapsFlag) || len(ans.exportRefs) == 0 {
return nil, nil

Expand Down
38 changes: 19 additions & 19 deletions rpc/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ func (c *Conn) clearExportID(m *capnp.Metadata) {
}

// findExport returns the export entry with the given ID or nil if
// couldn't be found.
// couldn't be found. The caller must be holding c.mu
func (c *Conn) findExport(id exportID) *expent {
if int64(id) >= int64(len(c.exports)) {
if int64(id) >= int64(len(c.lk.exports)) {
return nil
}
return c.exports[id] // might be nil
return c.lk.exports[id] // might be nil
}

// releaseExport decreases the number of wire references to an export
Expand All @@ -63,8 +63,8 @@ func (c *Conn) releaseExport(id exportID, count uint32) (capnp.Client, error) {
switch {
case count == ent.wireRefs:
client := ent.client
c.exports[id] = nil
c.exportID.remove(uint32(id))
c.lk.exports[id] = nil
c.lk.exportID.remove(uint32(id))
metadata := client.State().Metadata
syncutil.With(metadata, func() {
c.clearExportID(metadata)
Expand Down Expand Up @@ -116,7 +116,7 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
state := client.State()
bv := state.Brand.Value
if ic, ok := bv.(*importClient); ok && ic.c == c {
if ent := c.imports[ic.id]; ent != nil && ent.generation == ic.generation {
if ent := c.lk.imports[ic.id]; ent != nil && ent.generation == ic.generation {
d.SetReceiverHosted(uint32(ic.id))
return 0, false, nil
}
Expand Down Expand Up @@ -149,7 +149,7 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
defer state.Metadata.Unlock()
id, ok := c.findExportID(state.Metadata)
if ok {
ent := c.exports[id]
ent := c.lk.exports[id]
ent.wireRefs++
d.SetSenderHosted(uint32(id))
return id, true, nil
Expand All @@ -160,11 +160,11 @@ func (c *Conn) sendCap(d rpccp.CapDescriptor, client capnp.Client) (_ exportID,
client: client.AddRef(),
wireRefs: 1,
}
id = exportID(c.exportID.next())
if int64(id) == int64(len(c.exports)) {
c.exports = append(c.exports, ee)
id = exportID(c.lk.exportID.next())
if int64(id) == int64(len(c.lk.exports)) {
c.lk.exports = append(c.lk.exports, ee)
} else {
c.exports[id] = ee
c.lk.exports[id] = ee
}
c.setExportID(state.Metadata, id)
d.SetSenderHosted(uint32(id))
Expand Down Expand Up @@ -217,28 +217,28 @@ type embargo struct {
//
// The caller must be holding onto c.mu.
func (c *Conn) embargo(client capnp.Client) (embargoID, capnp.Client) {
id := embargoID(c.embargoID.next())
id := embargoID(c.lk.embargoID.next())
e := &embargo{
c: client,
lifted: make(chan struct{}),
}
if int64(id) == int64(len(c.embargoes)) {
c.embargoes = append(c.embargoes, e)
if int64(id) == int64(len(c.lk.embargoes)) {
c.lk.embargoes = append(c.lk.embargoes, e)
} else {
c.embargoes[id] = e
c.lk.embargoes[id] = e
}
var c2 capnp.Client
c2, c.embargoes[id].p = capnp.NewPromisedClient(c.embargoes[id])
c2, c.lk.embargoes[id].p = capnp.NewPromisedClient(c.lk.embargoes[id])
return id, c2
}

// findEmbargo returns the embargo entry with the given ID or nil if
// couldn't be found.
// couldn't be found. Must be holding c.mu
func (c *Conn) findEmbargo(id embargoID) *embargo {
if int64(id) >= int64(len(c.embargoes)) {
if int64(id) >= int64(len(c.lk.embargoes)) {
return nil
}
return c.embargoes[id] // might be nil
return c.lk.embargoes[id] // might be nil
}

// lift disembargoes the client. It must be called only once.
Expand Down
14 changes: 7 additions & 7 deletions rpc/import.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type impent struct {
//
// The caller must be holding onto c.mu.
func (c *Conn) addImport(id importID) capnp.Client {
if ent := c.imports[id]; ent != nil {
if ent := c.lk.imports[id]; ent != nil {
ent.wireRefs++
client, ok := ent.wc.AddRef()
if !ok {
Expand All @@ -69,7 +69,7 @@ func (c *Conn) addImport(id importID) capnp.Client {
c: c,
id: id,
})
c.imports[id] = &impent{
c.lk.imports[id] = &impent{
wc: client.WeakRef(),
wireRefs: 1,
}
Expand All @@ -91,7 +91,7 @@ func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer,
return capnp.ErrorAnswer(s.Method, ExcClosed), func() {}
}
defer ic.c.tasks.Done()
ent := ic.c.imports[ic.id]
ent := ic.c.lk.imports[ic.id]
if ent == nil || ic.generation != ent.generation {
return capnp.ErrorAnswer(s.Method, rpcerr.Disconnectedf("send on closed import")), func() {}
}
Expand All @@ -106,11 +106,11 @@ func (ic *importClient) Send(ctx context.Context, s capnp.Send) (*capnp.Answer,
defer ic.c.mu.Unlock()

if err != nil {
ic.c.questions[q.id] = nil
ic.c.lk.questions[q.id] = nil
syncutil.Without(&ic.c.mu, func() {
q.p.Reject(rpcerr.Failedf("send message: %w", err))
})
ic.c.questionID.remove(uint32(q.id))
ic.c.lk.questionID.remove(uint32(q.id))
return
}

Expand Down Expand Up @@ -227,13 +227,13 @@ func (ic *importClient) Shutdown() {
}
defer ic.c.tasks.Done()

ent := ic.c.imports[ic.id]
ent := ic.c.lk.imports[ic.id]
if ic.generation != ent.generation {
// A new reference was added concurrently with the Shutdown. See
// impent.generation documentation for an explanation.
return
}
delete(ic.c.imports, ic.id)
delete(ic.c.lk.imports, ic.id)
ic.c.sendMessage(ic.c.bgctx, func(msg rpccp.Message) error {
rel, err := msg.NewRelease()
if err == nil {
Expand Down
12 changes: 6 additions & 6 deletions rpc/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ const (
func (c *Conn) newQuestion(method capnp.Method) *question {
q := &question{
c: c,
id: questionID(c.questionID.next()),
id: questionID(c.lk.questionID.next()),
release: func() {},
finishMsgSend: make(chan struct{}),
}
q.p = capnp.NewPromise(method, q) // TODO(someday): customize error message for bootstrap
c.setAnswerQuestion(q.p.Answer(), q)
if int(q.id) == len(c.questions) {
c.questions = append(c.questions, q)
if int(q.id) == len(c.lk.questions) {
c.lk.questions = append(c.lk.questions, q)
} else {
c.questions[q.id] = q
c.lk.questions[q.id] = q
}
return q
}
Expand Down Expand Up @@ -158,11 +158,11 @@ func (q *question) PipelineSend(ctx context.Context, transform []capnp.PipelineO
}, func(err error) {
if err != nil {
syncutil.With(&q.c.mu, func() {
q.c.questions[q2.id] = nil
q.c.lk.questions[q2.id] = nil
})
q2.p.Reject(rpcerr.Failedf("send message: %w", err))
syncutil.With(&q.c.mu, func() {
q.c.questionID.remove(uint32(q2.id))
q.c.lk.questionID.remove(uint32(q2.id))
})
return
}
Expand Down
Loading