diff --git a/btf/btf.go b/btf/btf.go
index 6a3e53d67..1268a4135 100644
--- a/btf/btf.go
+++ b/btf/btf.go
@@ -66,8 +66,9 @@ func (s *immutableTypes) typeByID(id TypeID) (Type, bool) {
 // mutableTypes is a set of types which may be changed.
 type mutableTypes struct {
 	imm           immutableTypes
+	mu            *sync.RWMutex   // protects copies below
 	copies        map[Type]Type   // map[orig]copy
-	copiedTypeIDs map[Type]TypeID //map[copy]origID
+	copiedTypeIDs map[Type]TypeID // map[copy]origID
 }
 
 // add a type to the set of mutable types.
@@ -75,6 +76,20 @@ type mutableTypes struct {
 // Copies type and all of its children once. Repeated calls with the same type
 // do not copy again.
 func (mt *mutableTypes) add(typ Type, typeIDs map[Type]TypeID) Type {
+	mt.mu.RLock()
+	cpy, ok := mt.copies[typ]
+	mt.mu.RUnlock()
+
+	if ok {
+		// Fast path: the type has been copied before.
+		return cpy
+	}
+
+	// modifyGraphPreorder copies the type graph node by node, so we can't drop
+	// the lock in between.
+	mt.mu.Lock()
+	defer mt.mu.Unlock()
+
 	return modifyGraphPreorder(typ, func(t Type) (Type, bool) {
 		cpy, ok := mt.copies[t]
 		if ok {
@@ -98,6 +113,7 @@ func (mt *mutableTypes) add(typ Type, typeIDs map[Type]TypeID) Type {
 func (mt *mutableTypes) copy() mutableTypes {
 	mtCopy := mutableTypes{
 		mt.imm,
+		&sync.RWMutex{},
 		make(map[Type]Type, len(mt.copies)),
 		make(map[Type]TypeID, len(mt.copiedTypeIDs)),
 	}
@@ -122,6 +138,9 @@ func (mt *mutableTypes) typeID(typ Type) (TypeID, error) {
 		return 0, nil
 	}
 
+	mt.mu.RLock()
+	defer mt.mu.RUnlock()
+
 	id, ok := mt.copiedTypeIDs[typ]
 	if !ok {
 		return 0, fmt.Errorf("no ID for type %s: %w", typ, ErrNotFound)
@@ -343,6 +362,7 @@ func loadRawSpec(btf io.ReaderAt, bo binary.ByteOrder, base *Spec) (*Spec, error
 			typesByName,
 			bo,
 		},
+		&sync.RWMutex{},
 		make(map[Type]Type),
 		make(map[Type]TypeID),
 	},
diff --git a/btf/btf_test.go b/btf/btf_test.go
index 846439c3a..b5d256f1b 100644
--- a/btf/btf_test.go
+++ b/btf/btf_test.go
@@ -7,7 +7,9 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"runtime"
 	"sync"
+	"sync/atomic"
 	"testing"
 
 	"github.com/go-quicktest/qt"
@@ -525,6 +527,36 @@ func TestFixupDatasecLayout(t *testing.T) {
 	qt.Assert(t, qt.Equals(ds.Vars[5].Offset, 32))
 }
 
+func TestSpecConcurrentAccess(t *testing.T) {
+	spec := vmlinuxTestdataSpec(t)
+
+	n := runtime.GOMAXPROCS(0)
+	if n < 3 {
+		t.Error("GOMAXPROCS is too low:", n)
+	}
+
+	var cond atomic.Bool
+	var wg sync.WaitGroup
+	for i := 0; i < n-1; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+
+			for !cond.Load() {
+				// Spin to increase the chances of a race.
+			}
+
+			_, _ = spec.AnyTypeByName("gov_update_cpu_data")
+		}()
+
+		// Try to get the Goroutines scheduled and spinning.
+		runtime.Gosched()
+	}
+
+	cond.Store(true)
+	wg.Wait()
+}
+
 func BenchmarkSpecCopy(b *testing.B) {
 	spec := vmlinuxTestdataSpec(b)
 	b.ResetTimer()
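
For context on the locking scheme in mutableTypes.add: it takes the read lock for the common case (the type was already copied) and falls back to the exclusive lock only to populate the cache. What follows is a minimal, self-contained sketch of that read-mostly, double-checked pattern. The cache type and all names in it are illustrative only and not part of the package; note also that in the diff above the equivalent re-check under the write lock happens inside modifyGraphPreorder's callback rather than as a separate lookup.

    package main

    import (
    	"fmt"
    	"sync"
    )

    // cache is a read-mostly map guarded by an RWMutex, mirroring the
    // fast-path/slow-path split used by mutableTypes.add above.
    // (Hypothetical example type, not part of the btf package.)
    type cache struct {
    	mu     sync.RWMutex
    	copies map[string]string
    }

    func (c *cache) get(key string) string {
    	// Fast path: most lookups hit an existing entry and only
    	// need the shared read lock.
    	c.mu.RLock()
    	v, ok := c.copies[key]
    	c.mu.RUnlock()
    	if ok {
    		return v
    	}

    	// Slow path: take the exclusive lock, then re-check, since
    	// another goroutine may have filled the entry in the window
    	// between RUnlock and Lock.
    	c.mu.Lock()
    	defer c.mu.Unlock()
    	if v, ok := c.copies[key]; ok {
    		return v
    	}
    	v = "copy-of-" + key // stand-in for the expensive copy
    	c.copies[key] = v
    	return v
    }

    func main() {
    	c := &cache{copies: make(map[string]string)}
    	var wg sync.WaitGroup
    	for i := 0; i < 4; i++ {
    		wg.Add(1)
    		go func() {
    			defer wg.Done()
    			fmt.Println(c.get("task_struct"))
    		}()
    	}
    	wg.Wait()
    }

The re-check after acquiring the write lock is what makes the pattern safe: dropping the read lock before taking the write lock leaves a gap in which another goroutine can insert the same entry, so the slow path must not assume the entry is still missing.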