Skip to content

Commit

Permalink
Mark most specific static DIM for types marked RelevantToVariantCasti…
Browse files Browse the repository at this point in the history
…ng (#97487)

Previously, we weren't handling static DIMs to ensure that DIMs that provided an implementation of an interface method for an inheriting type would be kept.

This method gets rid of _interfaceOverrides and uses _virtual_methods with TypeMapInfo to find all interface method / implementation pairs.

This PR adds static method handling to the ProcessDefaultImplementation method where it previously only handled instance interface methods. It assumes all static interface methods will be needed if the type implementing the interface IsRelevantToVariantCasting.

The DIM cache also is updated to include the method that provides the implementation for a type.
  • Loading branch information
jtschuster authored Feb 10, 2024
1 parent 96da5a0 commit 6ac1edf
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 51 deletions.
82 changes: 42 additions & 40 deletions src/tools/illink/src/linker/Linker.Steps/MarkStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ protected LinkContext Context {
readonly List<AttributeProviderPair> _ivt_attributes;
protected Queue<(AttributeProviderPair, DependencyInfo, MarkScopeStack.Scope)> _lateMarkedAttributes;
protected List<(TypeDefinition, MarkScopeStack.Scope)> _typesWithInterfaces;
protected HashSet<(OverrideInformation, MarkScopeStack.Scope)> _interfaceOverrides;
protected HashSet<AssemblyDefinition> _dynamicInterfaceCastableImplementationTypesDiscovered;
protected List<TypeDefinition> _dynamicInterfaceCastableImplementationTypes;
protected List<(MethodBody, MarkScopeStack.Scope)> _unreachableBodies;
Expand Down Expand Up @@ -226,7 +225,6 @@ public MarkStep ()
_ivt_attributes = new List<AttributeProviderPair> ();
_lateMarkedAttributes = new Queue<(AttributeProviderPair, DependencyInfo, MarkScopeStack.Scope)> ();
_typesWithInterfaces = new List<(TypeDefinition, MarkScopeStack.Scope)> ();
_interfaceOverrides = new HashSet<(OverrideInformation, MarkScopeStack.Scope)> ();
_dynamicInterfaceCastableImplementationTypesDiscovered = new HashSet<AssemblyDefinition> ();
_dynamicInterfaceCastableImplementationTypes = new List<TypeDefinition> ();
_unreachableBodies = new List<(MethodBody, MarkScopeStack.Scope)> ();
Expand Down Expand Up @@ -573,9 +571,10 @@ protected virtual void EnqueueMethod (MethodDefinition method, in DependencyInfo

void ProcessVirtualMethods ()
{
foreach ((MethodDefinition method, MarkScopeStack.Scope scope) in _virtual_methods) {
using (ScopeStack.PushScope (scope))
foreach ((var method, var scope) in _virtual_methods) {
using (ScopeStack.PushScope (scope)) {
ProcessVirtualMethod (method);
}
}
}

Expand Down Expand Up @@ -603,26 +602,19 @@ void ProcessMarkedTypesWithInterfaces ()
!unusedInterfacesOptimizationEnabled) {
MarkInterfaceImplementations (type);
}
// OverrideInformation for interfaces in PreservedScope aren't added yet
// Interfaces in PreservedScope should have their methods added to _virtual_methods so that they are properly processed
foreach (var method in type.Methods) {
var baseOverrideInformations = Annotations.GetBaseMethods (method);
if (baseOverrideInformations is null)
var baseMethods = Annotations.GetBaseMethods (method);
if (baseMethods is null)
continue;
foreach (var ov in baseOverrideInformations) {
if (ov.Base.DeclaringType is not null && ov.Base.DeclaringType.IsInterface && IgnoreScope (ov.Base.DeclaringType.Scope))
_interfaceOverrides.Add ((ov, ScopeStack.CurrentScope));
foreach (var ov in baseMethods) {
if (ov.Base.DeclaringType is not null && ov.Base.DeclaringType.IsInterface && IgnoreScope (ov.Base.DeclaringType.Scope)) {
_virtual_methods.Add ((ov.Base, ScopeStack.CurrentScope));
}
}
}
}
}

var interfaceOverrides = _interfaceOverrides.ToArray ();
foreach ((var overrideInformation, var scope) in interfaceOverrides) {
using (ScopeStack.PushScope (scope)) {
if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (overrideInformation))
MarkMethod (overrideInformation.Override, new DependencyInfo (DependencyKind.Override, overrideInformation.Base), scope.Origin);
}
}
}

void DiscoverDynamicCastableImplementationInterfaces ()
Expand Down Expand Up @@ -705,10 +697,23 @@ void ProcessVirtualMethod (MethodDefinition method)
{
Annotations.EnqueueVirtualMethod (method);

var defaultImplementations = Annotations.GetDefaultInterfaceImplementations (method);
if (defaultImplementations != null) {
foreach (var defaultImplementationInfo in defaultImplementations) {
ProcessDefaultImplementation (defaultImplementationInfo.InstanceType, defaultImplementationInfo.ProvidingInterface);
if (method.DeclaringType.IsInterface) {
var defaultImplementations = Annotations.GetDefaultInterfaceImplementations (method);
if (defaultImplementations is not null) {
foreach (var dimInfo in defaultImplementations) {
ProcessDefaultImplementation (dimInfo.ImplementingType, dimInfo.InterfaceImpl, dimInfo.DefaultInterfaceMethod);

var ov = new OverrideInformation (method, dimInfo.DefaultInterfaceMethod, Context);
if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (ov, dimInfo.ImplementingType))
MarkMethod (ov.Override, new DependencyInfo (DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin);
}
}
var overridingMethods = Annotations.GetOverrides (method);
if (overridingMethods is not null) {
foreach (var ov in overridingMethods) {
if (IsInterfaceImplementationMethodNeededByTypeDueToInterface (ov, ov.Override.DeclaringType))
MarkMethod (ov.Override, new DependencyInfo (DependencyKind.Override, ov.Base), ScopeStack.CurrentScope.Origin);
}
}
}
}
Expand All @@ -724,10 +729,8 @@ bool ShouldMarkOverrideForBase (OverrideInformation overrideInformation)
Debug.Assert (Annotations.IsMarked (overrideInformation.Base) || IgnoreScope (overrideInformation.Base.DeclaringType.Scope));
if (!Annotations.IsMarked (overrideInformation.Override.DeclaringType))
return false;
if (overrideInformation.IsOverrideOfInterfaceMember) {
_interfaceOverrides.Add ((overrideInformation, ScopeStack.CurrentScope));
if (overrideInformation.IsOverrideOfInterfaceMember)
return false;
}

if (!Context.IsOptimizationEnabled (CodeOptimizations.OverrideRemoval, overrideInformation.Override))
return true;
Expand Down Expand Up @@ -816,9 +819,10 @@ bool RequiresInterfaceRecursively (TypeDefinition typeToExamine, TypeDefinition
return false;
}

void ProcessDefaultImplementation (TypeDefinition typeWithDefaultImplementedInterfaceMethod, InterfaceImplementation implementation)
void ProcessDefaultImplementation (TypeDefinition typeWithDefaultImplementedInterfaceMethod, InterfaceImplementation implementation, MethodDefinition implementationMethod)
{
if (!Annotations.IsInstantiated (typeWithDefaultImplementedInterfaceMethod))
if ((!implementationMethod.IsStatic && !Annotations.IsInstantiated (typeWithDefaultImplementedInterfaceMethod))
|| implementationMethod.IsStatic && !Annotations.IsRelevantToVariantCasting (typeWithDefaultImplementedInterfaceMethod))
return;

MarkInterfaceImplementation (implementation);
Expand Down Expand Up @@ -2275,9 +2279,9 @@ void MarkTypeWithDebuggerDisplayAttribute (TypeDefinition type, CustomAttribute
// Record a logical dependency on the attribute so that we can blame it for the kept members below.
Tracer.AddDirectDependency (attribute, new DependencyInfo (DependencyKind.CustomAttribute, type), marked: false);

MarkTypeWithDebuggerDisplayAttributeValue(type, attribute, (string) attribute.ConstructorArguments[0].Value);
MarkTypeWithDebuggerDisplayAttributeValue (type, attribute, (string) attribute.ConstructorArguments[0].Value);
if (attribute.HasProperties) {
foreach (var property in attribute.Properties) {
foreach (var property in attribute.Properties) {
if (property.Name is "Name" or "Type") {
MarkTypeWithDebuggerDisplayAttributeValue (type, attribute, (string) property.Argument.Value);
}
Expand Down Expand Up @@ -2545,19 +2549,17 @@ bool IsMethodNeededByTypeDueToPreservedScope (MethodDefinition method)
/// <summary>
/// Returns true if the override method is required due to the interface that the base method is declared on. See doc at <see href="docs/methods-kept-by-interface.md"/> for explanation of logic.
/// </summary>
bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformation overrideInformation)
bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformation overrideInformation, TypeDefinition typeThatImplsInterface)
{
var @base = overrideInformation.Base;
var method = overrideInformation.Override;
Debug.Assert (@base.DeclaringType.IsInterface);
if (@base is null || method is null || @base.DeclaringType is null)
return false;

if (Annotations.IsMarked (method))
return false;

if (!@base.DeclaringType.IsInterface)
return false;

// If the interface implementation is not marked, do not mark the implementation method
// A type that doesn't implement the interface isn't required to have methods that implement the interface.
InterfaceImplementation? iface = overrideInformation.MatchingInterfaceImplementation;
Expand All @@ -2578,12 +2580,12 @@ bool IsInterfaceImplementationMethodNeededByTypeDueToInterface (OverrideInformat
// If the method is static and the implementing type is relevant to variant casting, mark the implementation method.
// A static method may only be called through a constrained call if the type is relevant to variant casting.
if (@base.IsStatic)
return Annotations.IsRelevantToVariantCasting (method.DeclaringType)
return Annotations.IsRelevantToVariantCasting (typeThatImplsInterface)
|| IgnoreScope (@base.DeclaringType.Scope);

// If the implementing type is marked as instantiated, mark the implementation method.
// If the type is not instantiated, do not mark the implementation method
return Annotations.IsInstantiated (method.DeclaringType);
return Annotations.IsInstantiated (typeThatImplsInterface);
}

static bool IsSpecialSerializationConstructor (MethodDefinition method)
Expand Down Expand Up @@ -3231,7 +3233,7 @@ protected virtual void ProcessMethod (MethodDefinition method, in DependencyInfo
} else if (method.TryGetProperty (out PropertyDefinition? property))
MarkProperty (property, new DependencyInfo (PropagateDependencyKindToAccessors (reason.Kind, DependencyKind.PropertyOfPropertyMethod), method));
else if (method.TryGetEvent (out EventDefinition? @event)) {
MarkEvent (@event, new DependencyInfo (PropagateDependencyKindToAccessors(reason.Kind, DependencyKind.EventOfEventMethod), method));
MarkEvent (@event, new DependencyInfo (PropagateDependencyKindToAccessors (reason.Kind, DependencyKind.EventOfEventMethod), method));
}

if (method.HasMetadataParameters ()) {
Expand Down Expand Up @@ -3315,7 +3317,7 @@ protected virtual void DoAdditionalMethodProcessing (MethodDefinition method)
{
}

static DependencyKind PropagateDependencyKindToAccessors(DependencyKind parentDependencyKind, DependencyKind kind)
static DependencyKind PropagateDependencyKindToAccessors (DependencyKind parentDependencyKind, DependencyKind kind)
{
switch (parentDependencyKind) {
// If the member is marked due to descriptor or similar, propagate the original reason to suppress some warnings correctly
Expand All @@ -3335,11 +3337,11 @@ void MarkImplicitlyUsedFields (TypeDefinition type)
return;

// keep fields for types with explicit layout, for enums and for InlineArray types
if (!type.IsAutoLayout || type.IsEnum || TypeIsInlineArrayType(type))
if (!type.IsAutoLayout || type.IsEnum || TypeIsInlineArrayType (type))
MarkFields (type, includeStatic: type.IsEnum, reason: new DependencyInfo (DependencyKind.MemberOfType, type));
}

static bool TypeIsInlineArrayType(TypeDefinition type)
static bool TypeIsInlineArrayType (TypeDefinition type)
{
if (!type.IsValueType)
return false;
Expand Down Expand Up @@ -3584,7 +3586,7 @@ protected internal virtual void MarkEvent (EventDefinition evt, in DependencyInf

MarkCustomAttributes (evt, new DependencyInfo (DependencyKind.CustomAttribute, evt));

DependencyKind dependencyKind = PropagateDependencyKindToAccessors(reason.Kind, DependencyKind.EventMethod);
DependencyKind dependencyKind = PropagateDependencyKindToAccessors (reason.Kind, DependencyKind.EventMethod);
MarkMethodIfNotNull (evt.AddMethod, new DependencyInfo (dependencyKind, evt), ScopeStack.CurrentScope.Origin);
MarkMethodIfNotNull (evt.InvokeMethod, new DependencyInfo (dependencyKind, evt), ScopeStack.CurrentScope.Origin);
MarkMethodIfNotNull (evt.RemoveMethod, new DependencyInfo (dependencyKind, evt), ScopeStack.CurrentScope.Origin);
Expand Down
14 changes: 11 additions & 3 deletions src/tools/illink/src/linker/Linker/Annotations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -447,22 +447,30 @@ public bool IsPublic (IMetadataTokenProvider provider)
}

/// <summary>
/// Returns a list of all known methods that override <paramref name="method"/>. The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet
/// Returns a list of all known methods that override <paramref name="method"/>.
/// The list may be incomplete if other overrides exist in assemblies that haven't been processed by TypeMapInfo yet
/// </summary>
public IEnumerable<OverrideInformation>? GetOverrides (MethodDefinition method)
{
return TypeMapInfo.GetOverrides (method);
}

public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface)>? GetDefaultInterfaceImplementations (MethodDefinition method)
/// <summary>
/// Returns a list of all default interface methods that implement <paramref name="method"/> for a type.
/// ImplementingType is the type that implements the interface,
/// InterfaceImpl is the <see cref="InterfaceImplementation" /> for the interface <paramref name="method" /> is declared on, and
/// DefaultInterfaceMethod is the method that implements <paramref name="method"/>.
/// </summary>
/// <param name="method">The interface method to find default implementations for</param>
public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultInterfaceMethod)>? GetDefaultInterfaceImplementations (MethodDefinition method)
{
return TypeMapInfo.GetDefaultInterfaceImplementations (method);
}

/// <summary>
/// Returns all base methods that <paramref name="method"/> overrides.
/// This includes methods on <paramref name="method"/>'s declaring type's base type (but not methods higher up in the type hierarchy),
/// methods on an interface that <paramref name="method"/>'s delcaring type implements,
/// methods on an interface that <paramref name="method"/>'s declaring type implements,
/// and methods an interface implemented by a derived type of <paramref name="method"/>'s declaring type if the derived type uses <paramref name="method"/> as the implementing method.
/// The list may be incomplete if there are derived types in assemblies that havent been processed yet that use <paramref name="method"/> to implement an interface.
/// </summary>
Expand Down
28 changes: 20 additions & 8 deletions src/tools/illink/src/linker/Linker/TypeMapInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
//

using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Mono.Cecil;

Expand All @@ -42,7 +43,7 @@ public class TypeMapInfo
readonly LinkContext context;
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> base_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<OverrideInformation>> override_methods = new Dictionary<MethodDefinition, List<OverrideInformation>> ();
protected readonly Dictionary<MethodDefinition, List<(TypeDefinition InstanceType, InterfaceImplementation ImplementationProvider)>> default_interface_implementations = new Dictionary<MethodDefinition, List<(TypeDefinition, InterfaceImplementation)>> ();
protected readonly Dictionary<MethodDefinition, List<(TypeDefinition InstanceType, InterfaceImplementation ImplementationProvider, MethodDefinition DefaultImplementationMethod)>> default_interface_implementations = new Dictionary<MethodDefinition, List<(TypeDefinition, InterfaceImplementation, MethodDefinition)>> ();

public TypeMapInfo (LinkContext context)
{
Expand Down Expand Up @@ -84,9 +85,16 @@ public void EnsureProcessed (AssemblyDefinition assembly)
return bases;
}

public IEnumerable<(TypeDefinition InstanceType, InterfaceImplementation ProvidingInterface)>? GetDefaultInterfaceImplementations (MethodDefinition method)
/// <summary>
/// Returns a list of all default interface methods that implement <paramref name="method"/> for a type.
/// ImplementingType is the type that implements the interface,
/// InterfaceImpl is the <see cref="InterfaceImplementation" /> for the interface <paramref name="method" /> is declared on, and
/// DefaultInterfaceMethod is the method that implements <paramref name="method"/>.
/// </summary>
/// <param name="method">The interface method to find default implementations for</param>
public IEnumerable<(TypeDefinition ImplementingType, InterfaceImplementation InterfaceImpl, MethodDefinition DefaultImplementationMethod)>? GetDefaultInterfaceImplementations (MethodDefinition baseMethod)
{
default_interface_implementations.TryGetValue (method, out var ret);
default_interface_implementations.TryGetValue (baseMethod, out var ret);
return ret;
}

Expand All @@ -110,14 +118,15 @@ public void AddOverride (MethodDefinition @base, MethodDefinition @override, Int
methods.Add (new OverrideInformation (@base, @override, context, matchingInterfaceImplementation));
}

public void AddDefaultInterfaceImplementation (MethodDefinition @base, TypeDefinition implementingType, InterfaceImplementation matchingInterfaceImplementation)
public void AddDefaultInterfaceImplementation (MethodDefinition @base, TypeDefinition implementingType, (InterfaceImplementation, MethodDefinition) matchingInterfaceImplementation)
{
Debug.Assert(@base.DeclaringType.IsInterface);
if (!default_interface_implementations.TryGetValue (@base, out var implementations)) {
implementations = new List<(TypeDefinition, InterfaceImplementation)> ();
implementations = new List<(TypeDefinition, InterfaceImplementation, MethodDefinition)> ();
default_interface_implementations.Add (@base, implementations);
}

implementations.Add ((implementingType, matchingInterfaceImplementation));
implementations.Add ((implementingType, matchingInterfaceImplementation.Item1, matchingInterfaceImplementation.Item2));
}

protected virtual void MapType (TypeDefinition type)
Expand Down Expand Up @@ -278,6 +287,7 @@ void FindAndAddDefaultInterfaceImplementations (TypeDefinition type, MethodDefin
{
// Go over all interfaces, trying to find a method that is an explicit MethodImpl of the
// interface method in question.

foreach (var interfaceImpl in type.Interfaces) {
var potentialImplInterface = context.TryResolve (interfaceImpl.InterfaceType);
if (potentialImplInterface == null)
Expand All @@ -288,7 +298,9 @@ void FindAndAddDefaultInterfaceImplementations (TypeDefinition type, MethodDefin
foreach (var potentialImplMethod in potentialImplInterface.Methods) {
if (potentialImplMethod == interfaceMethod &&
!potentialImplMethod.IsAbstract) {
AddDefaultInterfaceImplementation (interfaceMethod, type, interfaceImpl);
AddDefaultInterfaceImplementation (interfaceMethod, type, (interfaceImpl, potentialImplMethod));
foundImpl = true;
break;
}

if (!potentialImplMethod.HasOverrides)
Expand All @@ -297,7 +309,7 @@ void FindAndAddDefaultInterfaceImplementations (TypeDefinition type, MethodDefin
// This method is an override of something. Let's see if it's the method we are looking for.
foreach (var @override in potentialImplMethod.Overrides) {
if (context.TryResolve (@override) == interfaceMethod) {
AddDefaultInterfaceImplementation (interfaceMethod, type, interfaceImpl);
AddDefaultInterfaceImplementation (interfaceMethod, type, (interfaceImpl, @potentialImplMethod));
foundImpl = true;
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,30 @@ public Task InterfaceWithAttributeOnImplementation ()
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task MostSpecificDefaultImplementationKeptInstance ()
{
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task MostSpecificDefaultImplementationKeptStatic ()
{
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task SimpleDefaultInterfaceMethod ()
{
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task StaticDefaultInterfaceMethodOnStruct ()
{
return RunTest (allowMissingWarnings: true);
}

[Fact]
public Task UnusedDefaultInterfaceImplementation ()
{
Expand Down
Loading

0 comments on commit 6ac1edf

Please sign in to comment.