Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ILLink] Mark recursive interface implementations in MarkStep #99922

Merged
merged 8 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions src/tools/illink/src/linker/Linker.Steps/MarkStep.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using System.Reflection.Runtime.TypeParsing;
using System.Runtime.CompilerServices;
using System.Text.RegularExpressions;
using ILLink.Shared;
using ILLink.Shared.TrimAnalysis;
Expand Down Expand Up @@ -2450,26 +2450,27 @@ void MarkNamedProperty (TypeDefinition type, string property_name, in Dependency

void MarkInterfaceImplementations (TypeDefinition type)
{
if (!type.HasInterfaces)
var ifaces = Annotations.GetRecursiveInterfaces (type);
if (ifaces is null)
return;

foreach (var iface in type.Interfaces) {
foreach (var (ifaceType, impls) in ifaces) {
// Only mark interface implementations of interface types that have been marked.
// This enables stripping of interfaces that are never used
if (ShouldMarkInterfaceImplementation (type, iface))
MarkInterfaceImplementation (iface, new MessageOrigin (type));
if (ShouldMarkInterfaceImplementationList (type, impls, ifaceType))
MarkInterfaceImplementationList (impls, new MessageOrigin (type));
}
}

protected virtual bool ShouldMarkInterfaceImplementation (TypeDefinition type, InterfaceImplementation iface)

protected virtual bool ShouldMarkInterfaceImplementationList (TypeDefinition type, List<InterfaceImplementation> ifaces, TypeReference ifaceType)
{
if (Annotations.IsMarked (iface))
if (ifaces.All (Annotations.IsMarked))
return false;

if (!Context.IsOptimizationEnabled (CodeOptimizations.UnusedInterfaces, type))
return true;

if (Context.Resolve (iface.InterfaceType) is not TypeDefinition resolvedInterfaceType)
if (Context.Resolve (ifaceType) is not TypeDefinition resolvedInterfaceType)
return false;

if (Annotations.IsMarked (resolvedInterfaceType))
Expand Down Expand Up @@ -3764,8 +3765,7 @@ protected virtual void MarkInstruction (Instruction instruction, MethodDefinitio
ScopeStack.UpdateCurrentScopeInstructionOffset (instruction.Offset);
if (markForReflectionAccess) {
MarkMethodVisibleToReflection (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin);
}
else {
} else {
MarkMethod (methodReference, new DependencyInfo (dependencyKind, method), ScopeStack.CurrentScope.Origin);
}
break;
Expand Down Expand Up @@ -3825,6 +3825,12 @@ protected virtual void MarkInstruction (Instruction instruction, MethodDefinitio
}
}

void MarkInterfaceImplementationList (List<InterfaceImplementation> ifaces, MessageOrigin? origin = null, DependencyInfo? reason = null)
jtschuster marked this conversation as resolved.
Show resolved Hide resolved
{
foreach (var iface in ifaces) {
MarkInterfaceImplementation (iface, origin, reason);
}
}
jtschuster marked this conversation as resolved.
Show resolved Hide resolved

protected internal virtual void MarkInterfaceImplementation (InterfaceImplementation iface, MessageOrigin? origin = null, DependencyInfo? reason = null)
{
Expand Down
6 changes: 6 additions & 0 deletions src/tools/illink/src/linker/Linker/Annotations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection.Metadata.Ecma335;
using ILLink.Shared.TrimAnalysis;
using Mono.Cecil;
using Mono.Cecil.Cil;
Expand Down Expand Up @@ -717,5 +718,10 @@ public void EnqueueVirtualMethod (MethodDefinition method)
if (FlowAnnotations.RequiresVirtualMethodDataFlowAnalysis (method) || HasLinkerAttribute<RequiresUnreferencedCodeAttribute> (method))
VirtualMethodsWithAnnotationsToValidate.Add (method);
}

internal List<(TypeReference, List<InterfaceImplementation>)>? GetRecursiveInterfaces (TypeDefinition type)
{
return TypeMapInfo.GetRecursiveInterfaces (type);
}
}
}
171 changes: 171 additions & 0 deletions src/tools/illink/src/linker/Linker/MethodReferenceComparer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using Mono.Cecil;

namespace Mono.Linker
{
// Copied from https://github.com/jbevain/cecil/blob/master/Mono.Cecil/MethodReferenceComparer.cs
internal sealed class MethodReferenceComparer : EqualityComparer<MethodReference>
{
// Initialized lazily for each thread
[ThreadStatic]
static List<MethodReference>? xComparisonStack;

[ThreadStatic]
static List<MethodReference>? yComparisonStack;

public readonly ITryResolveMetadata _resolver;

public MethodReferenceComparer(ITryResolveMetadata resolver)
{
_resolver = resolver;
}

public override bool Equals (MethodReference? x, MethodReference? y)
{
return AreEqual (x, y, _resolver);
}

public override int GetHashCode (MethodReference obj)
{
return GetHashCodeFor (obj);
}

public static bool AreEqual (MethodReference? x, MethodReference? y, ITryResolveMetadata resolver)
{
if (ReferenceEquals (x, y))
return true;

if (x is null ^ y is null)
return false;

Debug.Assert (x is not null);
Debug.Assert (y is not null);

if (x.HasThis != y.HasThis)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
if (x.HasParameters != y.HasParameters)
return false;
#pragma warning restore RS0030

if (x.HasGenericParameters != y.HasGenericParameters)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
if (x.Parameters.Count != y.Parameters.Count)
return false;
#pragma warning restore RS0030

if (x.Name != y.Name)
return false;

if (!TypeReferenceEqualityComparer.AreEqual (x.DeclaringType, y.DeclaringType, resolver))
return false;

var xGeneric = x as GenericInstanceMethod;
var yGeneric = y as GenericInstanceMethod;
if (xGeneric != null || yGeneric != null) {
if (xGeneric == null || yGeneric == null)
return false;

if (xGeneric.GenericArguments.Count != yGeneric.GenericArguments.Count)
return false;

for (int i = 0; i < xGeneric.GenericArguments.Count; i++)
if (!TypeReferenceEqualityComparer.AreEqual (xGeneric.GenericArguments[i], yGeneric.GenericArguments[i], resolver))
return false;
}

var xResolved = resolver.TryResolve (x);
var yResolved = resolver.TryResolve (y);

if (xResolved != yResolved)
return false;

if (xResolved == null) {
// We couldn't resolve either method. In order for them to be equal, their parameter types _must_ match. But wait, there's a twist!
// There exists a situation where we might get into a recursive state: parameter type comparison might lead to comparing the same
// methods again if the parameter types are generic parameters whose owners are these methods. We guard against these by using a
// thread static list of all our comparisons carried out in the stack so far, and if we're in progress of comparing them already,
// we'll just say that they match.

xComparisonStack ??= new List<MethodReference> ();

yComparisonStack ??= new List<MethodReference> ();

for (int i = 0; i < xComparisonStack.Count; i++) {
if (xComparisonStack[i] == x && yComparisonStack[i] == y)
return true;
}

xComparisonStack.Add (x);

try {
yComparisonStack.Add (y);

try {
#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
for (int i = 0; i < x.Parameters.Count; i++) {
if (!TypeReferenceEqualityComparer.AreEqual (x.Parameters[i].ParameterType, y.Parameters[i].ParameterType, resolver))
return false;
}
#pragma warning restore RS0030
} finally {
yComparisonStack.RemoveAt (yComparisonStack.Count - 1);
}
} finally {
xComparisonStack.RemoveAt (xComparisonStack.Count - 1);
}
}

return true;
}

public static bool AreSignaturesEqual (MethodReference x, MethodReference y, ITryResolveMetadata resolver, TypeComparisonMode comparisonMode = TypeComparisonMode.Exact)
{
if (x.HasThis != y.HasThis)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
if (x.Parameters.Count != y.Parameters.Count)
return false;
#pragma warning restore RS0030

if (x.GenericParameters.Count != y.GenericParameters.Count)
return false;

#pragma warning disable RS0030 // MethodReference.HasParameters is banned - this code is copied from Cecil
for (var i = 0; i < x.Parameters.Count; i++)
if (!TypeReferenceEqualityComparer.AreEqual (x.Parameters[i].ParameterType, y.Parameters[i].ParameterType, resolver, comparisonMode))
return false;
#pragma warning restore RS0030

if (!TypeReferenceEqualityComparer.AreEqual (x.ReturnType, y.ReturnType, resolver, comparisonMode))
return false;

return true;
}

public static int GetHashCodeFor (MethodReference obj)
{
// a very good prime number
const int hashCodeMultiplier = 486187739;

var genericInstanceMethod = obj as GenericInstanceMethod;
if (genericInstanceMethod != null) {
var hashCode = GetHashCodeFor (genericInstanceMethod.ElementMethod);
for (var i = 0; i < genericInstanceMethod.GenericArguments.Count; i++)
hashCode = hashCode * hashCodeMultiplier + TypeReferenceEqualityComparer.GetHashCodeFor (genericInstanceMethod.GenericArguments[i]);
return hashCode;
}

return TypeReferenceEqualityComparer.GetHashCodeFor (obj.DeclaringType) * hashCodeMultiplier + obj.Name.GetHashCode ();
}
}
}
17 changes: 17 additions & 0 deletions src/tools/illink/src/linker/Linker/TypeComparisonMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Mono.Linker
{
// Copied from https://github.com/jbevain/cecil/blob/master/Mono.Cecil/TypeComparisonMode.cs
internal enum TypeComparisonMode
{
Exact,
SignatureOnly,

/// <summary>
/// Types can be in different assemblies, as long as the module, assembly, and type names match they will be considered equal
/// </summary>
SignatureOnlyLoose
}
}
46 changes: 46 additions & 0 deletions src/tools/illink/src/linker/Linker/TypeMapInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ public void AddDefaultInterfaceImplementation (MethodDefinition @base, Interface
default_interface_implementations.AddToList (@base, new OverrideInformation (@base, defaultImplementationMethod, interfaceImplementor));
}

Dictionary<TypeDefinition, List<(TypeReference, List<InterfaceImplementation>)>> interfaces = new ();
protected virtual void MapType (TypeDefinition type)
{
MapVirtualMethods (type);
MapInterfaceMethodsInTypeHierarchy (type);
interfaces[type] = GetRecursiveInterfaceImplementations (type);

if (!type.HasNestedTypes)
return;
Expand All @@ -128,6 +130,50 @@ protected virtual void MapType (TypeDefinition type)
MapType (nested);
}

internal List<(TypeReference, List<InterfaceImplementation>)>? GetRecursiveInterfaces (TypeDefinition type)
{
if (interfaces.TryGetValue (type, out var value))
return value;
return null;
}

List<(TypeReference, List<InterfaceImplementation>)> GetRecursiveInterfaceImplementations (TypeDefinition type)
{
List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain = new ();

AddRecursiveInterfaces (type, [], firstImplementationChain, context);
Debug.Assert (firstImplementationChain.All (kvp => context.Resolve (kvp.Item1) == context.Resolve (kvp.Item2.Last ().InterfaceType)));

return firstImplementationChain;

static void AddRecursiveInterfaces (TypeReference typeRef, IEnumerable<InterfaceImplementation> pathToType, List<(TypeReference, List<InterfaceImplementation>)> firstImplementationChain, LinkContext Context)
{
var type = Context.TryResolve (typeRef);
if (type is null)
return;
// Get all explicit interfaces of this type
foreach (var iface in type.Interfaces) {
var interfaceType = iface.InterfaceType.TryInflateFrom (typeRef, Context);
if (interfaceType is null) {
continue;
}
if (!firstImplementationChain.Any (i => TypeReferenceEqualityComparer.AreEqual (i.Item1, interfaceType, Context))) {
firstImplementationChain.Add ((interfaceType, pathToType.Append (iface).ToList ()));
}
}

// Recursive interfaces after all direct interfaces to preserve Inherit/Implement tree order
foreach (var iface in type.Interfaces) {
// If we can't resolve the interface type we can't find recursive interfaces
var ifaceDirectlyOnType = iface.InterfaceType.TryInflateFrom (typeRef, Context);
if (ifaceDirectlyOnType is null) {
continue;
}
AddRecursiveInterfaces (ifaceDirectlyOnType, pathToType.Append (iface), firstImplementationChain, Context);
}
}
}

void MapInterfaceMethodsInTypeHierarchy (TypeDefinition type)
{
if (!type.HasInterfaces)
Expand Down
Loading
Loading