Skip to content

Commit

Permalink
Allow custom attribute filtering with an open generic type. (#68158)
Browse files Browse the repository at this point in the history
* Allow custom attribute filtering with an open generic type.
  • Loading branch information
madelson authored Apr 22, 2022
1 parent fdd881c commit c6a51bd
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,7 @@ private static bool FilterCustomAttributeRecord(
attributeType = (decoratedModule.ResolveType(scope.GetParentToken(caCtorToken), null, null) as RuntimeType)!;

// Test attribute type against user provided attribute type filter
if (!attributeFilterType.IsAssignableFrom(attributeType))
if (!MatchesTypeFilter(attributeType, attributeFilterType))
return false;

// Ensure if attribute type must be inheritable that it is inheritable
Expand Down Expand Up @@ -1374,6 +1374,23 @@ private static bool FilterCustomAttributeRecord(
GC.KeepAlive(ctorWithParameters);
return result;
}

private static bool MatchesTypeFilter(RuntimeType attributeType, RuntimeType attributeFilterType)
{
if (attributeFilterType.IsGenericTypeDefinition)
{
for (RuntimeType? type = attributeType; type != null; type = (RuntimeType?)type.BaseType)
{
if (type.IsConstructedGenericType && type.GetGenericTypeDefinition() == attributeFilterType)
{
return true;
}
}
return false;
}

return attributeFilterType.IsAssignableFrom(attributeType);
}
#endregion

#region Private Static Methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,30 @@ public IEnumerable<CustomAttributeData> GetMatchingCustomAttributes(E element, T
return true;
};
}
else if (optionalAttributeTypeFilter.IsGenericTypeDefinition)
{
passesFilter =
delegate (Type actualType)
{
if (actualType.IsConstructedGenericType && actualType.GetGenericTypeDefinition() == optionalAttributeTypeFilter)
{
return true;
}

if (!typeFilterKnownToBeSealed)
{
for (Type? type = actualType.BaseType; type != null; type = type.BaseType)
{
if (type.IsConstructedGenericType && type.GetGenericTypeDefinition() == optionalAttributeTypeFilter)
{
return true;
}
}
}

return false;
};
}
else
{
passesFilter =
Expand Down
26 changes: 25 additions & 1 deletion src/libraries/System.Runtime/tests/System/Attributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,27 @@ public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForEvent
GenericAttributesTestHelper<DateTime?>(t => Attribute.GetCustomAttributes(@event, t));
}

[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/56887", TestRuntimes.Mono)]
public static void GetCustomAttributesOnOpenGenericTypeRetrievesDerivedAttributes()
{
Attribute[] attributes = Attribute.GetCustomAttributes(typeof(HasGenericAttribute), typeof(GenericAttribute<>));
Assert.Equal(3, attributes.Length);
Assert.Equal(1, attributes.Count(a => a.GetType() == typeof(DerivesFromGenericAttribute)));
Assert.Equal(1, attributes.Count(a => a.GetType() == typeof(GenericAttribute<bool>)));
Assert.Equal(1, attributes.Count(a => a.GetType() == typeof(GenericAttribute<string>)));

attributes = Attribute.GetCustomAttributes(typeof(HasGenericAttribute), typeof(GenericAttribute<bool>));
Assert.Equal(2, attributes.Length);
Assert.Equal(1, attributes.Count(a => a.GetType() == typeof(DerivesFromGenericAttribute)));
Assert.Equal(1, attributes.Count(a => a.GetType() == typeof(GenericAttribute<bool>)));
}

private static void GenericAttributesTestHelper<TGenericParameter>(Func<Type, Attribute[]> getCustomAttributes)
{
Attribute[] openGenericAttributes = getCustomAttributes(typeof(GenericAttribute<>));
Assert.Empty(openGenericAttributes);
Assert.True(openGenericAttributes.Length >= 1);
Assert.Equal(1, openGenericAttributes.OfType<GenericAttribute<TGenericParameter>>().Count());

Attribute[] closedGenericAttributes = getCustomAttributes(typeof(GenericAttribute<TGenericParameter>));
Assert.Equal(1, closedGenericAttributes.Length);
Expand Down Expand Up @@ -911,11 +928,18 @@ public class NameableAttribute : Attribute, INameable
[Nameable]
public class ExampleWithAttribute { }

[AttributeUsage(AttributeTargets.All, AllowMultiple = true)]
public class GenericAttribute<T> : Attribute
{
}

public class DerivesFromGenericAttribute : GenericAttribute<bool>
{
}

[DerivesFromGeneric]
[GenericAttribute<string>]
[GenericAttribute<bool>]
public class HasGenericAttribute
{
[GenericAttribute<TimeSpan>]
Expand Down
14 changes: 7 additions & 7 deletions src/tests/reflection/GenericAttribute/GenericAttributeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@ static int Main(string[] args)
Assert(((ICustomAttributeProvider)assembly).IsDefined(typeof(SingleAttribute<int>), true));
Assert(CustomAttributeExtensions.IsDefined(assembly, typeof(SingleAttribute<bool>)));
Assert(((ICustomAttributeProvider)assembly).IsDefined(typeof(SingleAttribute<bool>), true));
Assert(!CustomAttributeExtensions.GetCustomAttributes(assembly, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
Assert(!CustomAttributeExtensions.GetCustomAttributes(assembly, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
Assert(CustomAttributeExtensions.GetCustomAttributes(assembly, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
Assert(CustomAttributeExtensions.GetCustomAttributes(assembly, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
*/

// Uncomment when https://github.com/dotnet/runtime/issues/66168 is resolved
// Module module = programTypeInfo.Module;
// AssertAny(CustomAttributeExtensions.GetCustomAttributes(module), a => a is SingleAttribute<long>);
// Assert(CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<long>)).GetEnumerator().MoveNext());
// Assert(CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<long>)).GetEnumerator().MoveNext());
// Assert(!CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
// Assert(!CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
// Assert(CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
// Assert(CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());

TypeInfo programTypeInfo = typeof(Class).GetTypeInfo();
Assert(CustomAttributeExtensions.GetCustomAttribute<SingleAttribute<int>>(programTypeInfo) != null);
Expand Down Expand Up @@ -161,9 +161,9 @@ static int Main(string[] args)
AssertAny(b10, a => (a as MultiAttribute<Type>)?.Value == typeof(Class));
AssertAny(b10, a => (a as MultiAttribute<Type>)?.Value == typeof(Class.Derive));

Assert(!CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), false).GetEnumerator().MoveNext());
Assert(!CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), true).GetEnumerator().MoveNext());
Assert(!((ICustomAttributeProvider)programTypeInfo).GetCustomAttributes(typeof(MultiAttribute<>), true).GetEnumerator().MoveNext());
Assert(CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), false).GetEnumerator().MoveNext());
Assert(CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), true).GetEnumerator().MoveNext());
Assert(((ICustomAttributeProvider)programTypeInfo).GetCustomAttributes(typeof(MultiAttribute<>), true).GetEnumerator().MoveNext());

// Test coverage for CustomAttributeData api surface
var a1_data = CustomAttributeData.GetCustomAttributes(programTypeInfo);
Expand Down

0 comments on commit c6a51bd

Please sign in to comment.