diff --git a/btf/core.go b/btf/core.go index b743ffe3b..a3d311a06 100644 --- a/btf/core.go +++ b/btf/core.go @@ -374,7 +374,7 @@ func coreCalculateFixup(relo *CORERelocation, target Type, targetID TypeID, bo b return zero, fmt.Errorf("unexpected accessor %v", relo.accessor) } - err := coreAreTypesCompatible(local, target) + err := CheckTypeCompatibility(local, target) if errors.Is(err, errIncompatibleTypes) { return poison() } @@ -901,7 +901,11 @@ func coreFindEnumValue(local Type, localAcc coreAccessor, target Type) (localVal // // Only layout compatibility is checked, ignoring names of the root type. func CheckTypeCompatibility(localType Type, targetType Type) error { - return coreAreTypesCompatible(localType, targetType) + return coreAreTypesCompatible(localType, targetType, nil) +} + +type pair struct { + A, B Type } /* The comment below is from bpf_core_types_are_compat in libbpf.c: @@ -927,59 +931,60 @@ func CheckTypeCompatibility(localType Type, targetType Type) error { * * Returns errIncompatibleTypes if types are not compatible. */ -func coreAreTypesCompatible(localType Type, targetType Type) error { +func coreAreTypesCompatible(localType Type, targetType Type, visited map[pair]struct{}) error { + localType = UnderlyingType(localType) + targetType = UnderlyingType(targetType) - var ( - localTs, targetTs typeDeque - l, t = &localType, &targetType - depth = 0 - ) + if reflect.TypeOf(localType) != reflect.TypeOf(targetType) { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } - for ; l != nil && t != nil; l, t = localTs.Shift(), targetTs.Shift() { - if depth >= maxResolveDepth { - return errors.New("types are nested too deep") - } + if _, ok := visited[pair{localType, targetType}]; ok { + return nil + } + if visited == nil { + visited = make(map[pair]struct{}) + } + visited[pair{localType, targetType}] = struct{}{} - localType = UnderlyingType(*l) - targetType = UnderlyingType(*t) + switch lv := localType.(type) { + case *Void, *Struct, *Union, *Enum, *Fwd, *Int: + return nil - if reflect.TypeOf(localType) != reflect.TypeOf(targetType) { - return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) - } + case *Pointer: + tv := targetType.(*Pointer) + return coreAreTypesCompatible(lv.Target, tv.Target, visited) - switch lv := (localType).(type) { - case *Void, *Struct, *Union, *Enum, *Fwd, *Int: - // Nothing to do here + case *Array: + tv := targetType.(*Array) + if err := coreAreTypesCompatible(lv.Index, tv.Index, visited); err != nil { + return err + } - case *Pointer, *Array: - depth++ - walkType(localType, localTs.Push) - walkType(targetType, targetTs.Push) + return coreAreTypesCompatible(lv.Type, tv.Type, visited) - case *FuncProto: - tv := targetType.(*FuncProto) - if len(lv.Params) != len(tv.Params) { - return fmt.Errorf("function param mismatch: %w", errIncompatibleTypes) - } + case *FuncProto: + tv := targetType.(*FuncProto) + if err := coreAreTypesCompatible(lv.Return, tv.Return, visited); err != nil { + return err + } - depth++ - walkType(localType, localTs.Push) - walkType(targetType, targetTs.Push) + if len(lv.Params) != len(tv.Params) { + return fmt.Errorf("function param mismatch: %w", errIncompatibleTypes) + } - default: - return fmt.Errorf("unsupported type %T", localType) + for i, localParam := range lv.Params { + targetParam := tv.Params[i] + if err := coreAreTypesCompatible(localParam.Type, targetParam.Type, visited); err != nil { + return err + } } - } - if l != nil { - return fmt.Errorf("dangling local type %T", *l) - } + return nil - if t != nil { - return fmt.Errorf("dangling target type %T", *t) + default: + return fmt.Errorf("unsupported type %T", localType) } - - return nil } /* coreAreMembersCompatible checks two types for field-based relocation compatibility. diff --git a/btf/core_test.go b/btf/core_test.go index 5ceda0301..f56d3d481 100644 --- a/btf/core_test.go +++ b/btf/core_test.go @@ -17,42 +17,6 @@ import ( ) func TestCheckTypeCompatibility(t *testing.T) { - tests := []struct { - a, b Type - compatible bool - }{ - {&FuncProto{Return: &Typedef{Type: &Int{}}}, &FuncProto{Return: &Int{}}, true}, - {&FuncProto{Return: &Typedef{Type: &Int{}}}, &FuncProto{Return: &Void{}}, false}, - } - for _, test := range tests { - err := CheckTypeCompatibility(test.a, test.b) - if test.compatible { - if err != nil { - t.Errorf("Expected types to be compatible: %s\na = %#v\nb = %#v", err, test.a, test.b) - continue - } - } else { - if !errors.Is(err, errIncompatibleTypes) { - t.Errorf("Expected types to be incompatible: %s\na = %#v\nb = %#v", err, test.a, test.b) - continue - } - } - - err = CheckTypeCompatibility(test.b, test.a) - if test.compatible { - if err != nil { - t.Errorf("Expected reversed types to be compatible: %s\na = %#v\nb = %#v", err, test.a, test.b) - } - } else { - if !errors.Is(err, errIncompatibleTypes) { - t.Errorf("Expected reversed types to be incompatible: %s\na = %#v\nb = %#v", err, test.a, test.b) - } - } - } - -} - -func TestCOREAreTypesCompatible(t *testing.T) { tests := []struct { a, b Type compatible bool @@ -84,10 +48,12 @@ func TestCOREAreTypesCompatible(t *testing.T) { &FuncProto{Return: &Void{}, Params: []FuncParam{{Type: &Void{}}}}, false, }, + {&FuncProto{Return: &Typedef{Type: &Int{}}}, &FuncProto{Return: &Int{}}, true}, + {&FuncProto{Return: &Typedef{Type: &Int{}}}, &FuncProto{Return: &Void{}}, false}, } for _, test := range tests { - err := coreAreTypesCompatible(test.a, test.b) + err := CheckTypeCompatibility(test.a, test.b) if test.compatible { if err != nil { t.Errorf("Expected types to be compatible: %s\na = %#v\nb = %#v", err, test.a, test.b) @@ -100,7 +66,7 @@ func TestCOREAreTypesCompatible(t *testing.T) { } } - err = coreAreTypesCompatible(test.b, test.a) + err = CheckTypeCompatibility(test.b, test.a) if test.compatible { if err != nil { t.Errorf("Expected reversed types to be compatible: %s\na = %#v\nb = %#v", err, test.a, test.b) @@ -113,7 +79,7 @@ func TestCOREAreTypesCompatible(t *testing.T) { } for _, invalid := range []Type{&Var{}, &Datasec{}} { - err := coreAreTypesCompatible(invalid, invalid) + err := CheckTypeCompatibility(invalid, invalid) if errors.Is(err, errIncompatibleTypes) { t.Errorf("Expected an error for %T, not errIncompatibleTypes", invalid) } else if err == nil {