Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove base type rooting for types in rooted assemblies #92864

Merged
merged 5 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ namespace ILCompiler
{
public class RootingHelpers
{
public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc type, string reason)
public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc type, bool rootBaseTypes, string reason)
{
try
{
RootType(rootProvider, type, reason);
RootType(rootProvider, type, rootBaseTypes, reason);
return true;
}
catch (TypeSystemException)
Expand All @@ -24,7 +24,7 @@ public static bool TryRootType(IRootingServiceProvider rootProvider, TypeDesc ty
}
}

public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, string reason)
public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type, bool rootBaseTypes, string reason)
{
rootProvider.AddReflectionRoot(type, reason);

Expand All @@ -40,13 +40,13 @@ public static void RootType(IRootingServiceProvider rootProvider, TypeDesc type,
rootProvider.AddReflectionRoot(type, reason);
}

// Also root base types. This is so that we make methods on the base types callable.
// This helps in cases like "class Foo : Bar<int> { }" where we discover new
// generic instantiations.
TypeDesc baseType = type.BaseType;
if (baseType != null)
if (rootBaseTypes)
{
RootType(rootProvider, baseType.NormalizeInstantiation(), reason);
TypeDesc baseType = type.BaseType;
if (baseType != null)
{
RootType(rootProvider, baseType.NormalizeInstantiation(), rootBaseTypes, reason);
}
}

if (type.IsDefType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ protected override void GetMetadataDependenciesDueToReflectability(ref Dependenc
var rootProvider = new RootingServiceProvider(factory, dependencies.Add);
foreach (TypeDesc t in mdType.Module.GetAllTypes())
{
RootingHelpers.TryRootType(rootProvider, t, reason);
RootingHelpers.TryRootType(rootProvider, t, rootBaseTypes: false, reason);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ public static IEnumerable<object[]> InlineArrays ()
return TestNamesBySuiteName();
}

public static IEnumerable<object[]> LinkXml()
public static IEnumerable<object[]> Libraries()
{
return TestNamesBySuiteName();
}

public static IEnumerable<object[]> LinkXml()
{
return TestNamesBySuiteName();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ public void InlineArrays(string t)
Run(t);
}

[Theory]
[Theory]
[MemberData(nameof(TestDatabase.Libraries), MemberType = typeof(TestDatabase))]
public void Libraries(string t)
{
Run(t);
}

[Theory]
[MemberData (nameof (TestDatabase.LinkXml), MemberType = typeof (TestDatabase))]
public void LinkXml (string t)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,12 @@ public void Verify ()
throw new NotImplementedException ($"Don't know how to check member of type {originalMember.GetType ()}");
}

// Filter out all members which are not from the main assembly
// The Kept attributes are "optional" for non-main assemblies
string mainModuleName = originalAssembly.Name.Name;
// Verify anything not in the main assembly
VerifyLinkingOfOtherAssemblies(this.originalAssembly);

// Filter out all members which are not from the main assembly
// The Kept attributes are "optional" for non-main assemblies
string mainModuleName = originalAssembly.Name.Name;
List<AssemblyQualifiedToken> externalMembers = linkedMembers.Where (m => GetModuleName (m.Value.Entity) != mainModuleName).Select (m => m.Key).ToList ();
foreach (var externalMember in externalMembers) {
linkedMembers.Remove (externalMember);
Expand All @@ -136,7 +139,7 @@ public void Verify ()
false,
"Linked output includes unexpected member:\n " +
string.Join ("\n ", linkedMembers.Values.Select (e => e.Entity.GetDisplayName ())));
}
}

private void PopulateLinkedMembers ()
{
Expand Down Expand Up @@ -276,12 +279,23 @@ static bool ShouldIncludeType (TypeDesc type)
static bool ShouldIncludeMethod (MethodDesc method) => ShouldIncludeType (method.OwningType) && ShouldIncludeEntityByDisplayName (method);
}

private static MetadataType? GetOwningType (TypeSystemEntity? entity)
{
return entity switch
{
MetadataType type => type.ContainingType as MetadataType,
MethodDesc method => method.OwningType as MetadataType,
PropertyPseudoDesc prop => prop.OwningType,
EventPseudoDesc e => e.OwningType,
_ => null
};
}

private static string? GetModuleName (TypeSystemEntity entity)
{
return entity switch {
MetadataType type => type.Module.ToString (),
MethodDesc { OwningType: MetadataType owningType } => owningType.Module.ToString (),
_ => null
_ => GetOwningType(entity)?.Module.ToString()
};
}

Expand Down Expand Up @@ -1310,38 +1324,38 @@ private static bool HasActiveKeptDerivedAttribute (ICustomAttributeProvider prov
return GetActiveKeptDerivedAttributes (provider).Any ();
}

private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original)
internal void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original)
{
var checks = BuildOtherAssemblyCheckTable (original);

// TODO
// For now disable the code below by simply removing all checks
checks.Clear ();

try {
foreach (var assemblyName in checks.Keys) {
var linkedAssembly = ResolveLinkedAssembly (assemblyName);
var linkedMembersInAssembly = ResolveLinkedMembersForAssembly (assemblyName);
var originalTargetAssembly = ResolveOriginalsAssembly(assemblyName);
foreach (var checkAttrInAssembly in checks[assemblyName]) {
var attributeTypeName = checkAttrInAssembly.AttributeType.Name;

switch (attributeTypeName) {
case nameof (KeptAllTypesAndMembersInAssemblyAttribute):
VerifyKeptAllTypesAndMembersInAssembly (linkedAssembly);
VerifyKeptAllTypesAndMembersInAssembly (assemblyName, linkedMembersInAssembly);
continue;
case nameof (KeptAttributeInAssemblyAttribute):
VerifyKeptAttributeInAssembly (checkAttrInAssembly, linkedAssembly);
// VerifyKeptAttributeInAssembly (checkAttrInAssembly, linkedAssembly);
continue;
case nameof (RemovedAttributeInAssembly):
VerifyRemovedAttributeInAssembly (checkAttrInAssembly, linkedAssembly);
// VerifyRemovedAttributeInAssembly (checkAttrInAssembly, linkedAssembly);
continue;
default:
break;
}

var expectedTypeName = checkAttrInAssembly.ConstructorArguments[1].Value.ToString ()!;
TypeDefinition? linkedType = linkedAssembly.MainModule.GetType (expectedTypeName);
var expectedTypeName = checkAttrInAssembly.ConstructorArguments[1].Value.ToString ()!;
var expectedType = originalTargetAssembly.MainModule.GetType(expectedTypeName);
linkedMembersInAssembly.TryGetValue(new AssemblyQualifiedToken(expectedType), out LinkedEntity? linkedTypeEntity);
MetadataType? linkedType = linkedTypeEntity?.Entity as MetadataType;

if (linkedType == null && linkedAssembly.MainModule.HasExportedTypes) {
#if false
if (linkedType == null && linkedAssembly.MainModule.HasExportedTypes) {
ExportedType? exportedType = linkedAssembly.MainModule.ExportedTypes
.FirstOrDefault (exported => exported.FullName == expectedTypeName);

Expand All @@ -1353,6 +1367,7 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original)

linkedType = exportedType?.Resolve ();
}
#endif

switch (attributeTypeName) {
case nameof (RemovedTypeInAssemblyAttribute):
Expand All @@ -1364,6 +1379,7 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original)
if (linkedType == null)
Assert.Fail ($"Type `{expectedTypeName}' should have been kept in assembly {assemblyName}");
break;
#if false
case nameof (RemovedInterfaceOnTypeInAssemblyAttribute):
if (linkedType == null)
Assert.Fail ($"Type `{expectedTypeName}' should have been kept in assembly {assemblyName}");
Expand Down Expand Up @@ -1416,11 +1432,15 @@ private void VerifyLinkingOfOtherAssemblies (AssemblyDefinition original)
Assert.Fail ($"Type `{expectedTypeName}` should have been kept in assembly {assemblyName}");
VerifyExpectedInstructionSequenceOnMemberInAssembly (checkAttrInAssembly, linkedType);
break;
default:
default:
UnhandledOtherAssemblyAssertion (expectedTypeName, checkAttrInAssembly, linkedType);
break;
}
}
#else
default:
break;
#endif
}
}
}
} catch (AssemblyResolutionException e) {
Assert.Fail ($"Failed to resolve linked assembly `{e.AssemblyReference.Name}`. It must not exist in the output.");
Expand Down Expand Up @@ -1712,54 +1732,62 @@ protected virtual bool TryVerifyKeptMemberInAssemblyAsMethod (string memberName,

private void VerifyKeptReferencesInAssembly (CustomAttribute inAssemblyAttribute)
{
var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!);
#if false
var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!);
var expectedReferenceNames = ((CustomAttributeArgument[]) inAssemblyAttribute.ConstructorArguments[1].Value).Select (attr => (string) attr.Value).ToList ();
for (int i = 0; i < expectedReferenceNames.Count; i++)
if (expectedReferenceNames[i].EndsWith (".dll"))
expectedReferenceNames[i] = expectedReferenceNames[i].Substring (0, expectedReferenceNames[i].LastIndexOf ("."));

Assert.Equal (assembly.MainModule.AssemblyReferences.Select (asm => asm.Name), expectedReferenceNames);
#endif
}

private void VerifyKeptResourceInAssembly (CustomAttribute inAssemblyAttribute)
{
var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!);
#if false
var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!);
var resourceName = inAssemblyAttribute.ConstructorArguments[1].Value.ToString ();

Assert.Contains (resourceName, assembly.MainModule.Resources.Select (r => r.Name));
#endif
}

private void VerifyRemovedResourceInAssembly (CustomAttribute inAssemblyAttribute)
{
var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!);
#if false
var assembly = ResolveLinkedAssembly (inAssemblyAttribute.ConstructorArguments[0].Value.ToString ()!);
var resourceName = inAssemblyAttribute.ConstructorArguments[1].Value.ToString ();

Assert.DoesNotContain (resourceName, assembly.MainModule.Resources.Select (r => r.Name));
#endif
}

private void VerifyKeptAllTypesAndMembersInAssembly (AssemblyDefinition linked)
private void VerifyKeptAllTypesAndMembersInAssembly (string assemblyName, Dictionary<AssemblyQualifiedToken, LinkedEntity> linkedMembers)
{
var original = ResolveOriginalsAssembly (linked.MainModule.Assembly.Name.Name);
var original = ResolveOriginalsAssembly (assemblyName);

if (original == null)
Assert.Fail ($"Failed to resolve original assembly {linked.MainModule.Assembly.Name.Name}");
Assert.Fail ($"Failed to resolve original assembly {assemblyName}");

var originalTypes = original.AllDefinedTypes ().ToDictionary (t => t.FullName);
var linkedTypes = linked.AllDefinedTypes ().ToDictionary (t => t.FullName);
var originalTypes = original.AllDefinedTypes ().ToDictionary (t => new AssemblyQualifiedToken(t));
var linkedTypes = linkedMembers.Where(t => t.Value.Entity is TypeDesc).ToDictionary();

var missingInLinked = originalTypes.Keys.Except (linkedTypes.Keys);

Assert.True (missingInLinked.Any (), $"Expected all types to exist in the linked assembly, but one or more were missing");
Assert.False (missingInLinked.Any (), $"Expected all types to exist in the linked assembly {assemblyName}, but one or more were missing");

foreach (var originalKvp in originalTypes) {
var linkedType = linkedTypes[originalKvp.Key];
TypeDesc linkedTypeDesc = (TypeDesc)linkedType.Entity;

var originalMembers = originalKvp.Value.AllMembers ().Select (m => m.FullName);
var linkedMembers = linkedType.AllMembers ().Select (m => m.FullName);
// NativeAOT field trimming is very different (it basically doesn't trim fields, not in the same way trimmer does)
var originalMembers = originalKvp.Value.AllMembers ().Where(m => m is not FieldDefinition).Select (m => new AssemblyQualifiedToken(m));
var linkedMembersOnType = linkedMembers.Where(t => GetOwningType(t.Value.Entity) == linkedTypeDesc).Select(t => t.Key);

var missingMembersInLinked = originalMembers.Except (linkedMembers);
var missingMembersInLinked = originalMembers.Except (linkedMembersOnType);

Assert.True (missingMembersInLinked.Any (), $"Expected all members of `{originalKvp.Key}`to exist in the linked assembly, but one or more were missing");
Assert.False (missingMembersInLinked.Any (), $"Expected all members of `{linkedTypeDesc.GetDisplayName()}`to exist in the linked assembly, but one or more were missing");
}
}

Expand Down Expand Up @@ -1795,6 +1823,11 @@ private static Dictionary<string, List<CustomAttribute>> BuildOtherAssemblyCheck
foreach (var typeWithRemoveInAssembly in original.AllDefinedTypes ()) {
foreach (var attr in typeWithRemoveInAssembly.CustomAttributes.Where (IsTypeInOtherAssemblyAssertion)) {
var assemblyName = (string) attr.ConstructorArguments[0].Value;

Tool? toolTarget = (Tool?)(int?)attr.GetPropertyValue("Tool");
if (toolTarget is not null && !toolTarget.Value.HasFlag(Tool.NativeAot))
continue;

if (!checks.TryGetValue (assemblyName, out List<CustomAttribute>? checksForAssembly))
checks[assemblyName] = checksForAssembly = new List<CustomAttribute> ();

Expand All @@ -1805,14 +1838,13 @@ private static Dictionary<string, List<CustomAttribute>> BuildOtherAssemblyCheck
return checks;
}

protected AssemblyDefinition ResolveLinkedAssembly (string assemblyName)
private Dictionary<AssemblyQualifiedToken, LinkedEntity> ResolveLinkedMembersForAssembly (string assemblyName)
{
//var cleanAssemblyName = assemblyName;
//if (assemblyName.EndsWith (".exe") || assemblyName.EndsWith (".dll"))
//cleanAssemblyName = System.IO.Path.GetFileNameWithoutExtension (assemblyName);
//return _linkedResolver.Resolve (new AssemblyNameReference (cleanAssemblyName, null), _linkedReaderParameters);
// TODO - adapt to Native AOT
return ResolveOriginalsAssembly (assemblyName);
var cleanAssemblyName = assemblyName;
if (assemblyName.EndsWith(".exe") || assemblyName.EndsWith(".dll"))
cleanAssemblyName = System.IO.Path.GetFileNameWithoutExtension(assemblyName);

return this.linkedMembers.Where(e => GetModuleName(e.Value.Entity) == cleanAssemblyName).ToDictionary();
}

protected AssemblyDefinition ResolveOriginalsAssembly (string assemblyName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
using Mono.Cecil;
using Mono.Cecil.Cil;
using Mono.Linker.Tests.Cases.Expectations.Assertions;
using Mono.Linker.Tests.Cases.Expectations.Metadata;
using Mono.Linker.Tests.Extensions;
using Xunit;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public ILScanResults Trim (ILCompilerOptions options, TrimmingCustomizations? cu
new ManifestResourceBlockingPolicy (logger, options.FeatureSwitches, new Dictionary<ModuleDesc, IReadOnlySet<string>>()),
logFile: null,
new NoStackTraceEmissionPolicy (),
new NoDynamicInvokeThunkGenerationPolicy (),
new DefaultDynamicInvokeThunkGenerationPolicy (),
new FlowAnnotations (logger, ilProvider, compilerGeneratedState),
UsageBasedMetadataGenerationOptions.ReflectionILScanning,
options: default,
Expand Down
4 changes: 2 additions & 2 deletions src/coreclr/tools/aot/ILCompiler/RdXmlRootProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ private void ProcessAssemblyDirective(IRootingServiceProvider rootProvider, XEle

foreach (TypeDesc type in ((EcmaModule)assembly).GetAllTypes())
{
RootingHelpers.TryRootType(rootProvider, type, "RD.XML root");
RootingHelpers.TryRootType(rootProvider, type, rootBaseTypes: true, "RD.XML root");
}
}

Expand Down Expand Up @@ -103,7 +103,7 @@ private static void ProcessTypeDirective(IRootingServiceProvider rootProvider, M
if (dynamicDegreeAttribute.Value != "Required All")
throw new NotSupportedException($"\"{dynamicDegreeAttribute.Value}\" is not a supported value for the \"Dynamic\" attribute of the \"Type\" Runtime Directive. Supported values are \"Required All\".");

RootingHelpers.RootType(rootProvider, type, "RD.XML root");
RootingHelpers.RootType(rootProvider, type, rootBaseTypes: true, "RD.XML root");
}

var marshalStructureDegreeAttribute = typeElement.Attribute("MarshalStructure");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ void trimAndEnsureCapacity()

private static int GetUnderlyingBufferCapacity<TPriority, TElement>(PriorityQueue<TPriority, TElement> queue)
{
FieldInfo nodesField = queue.GetType().GetField("_nodes", BindingFlags.NonPublic | BindingFlags.Instance);
FieldInfo nodesField = typeof(PriorityQueue<TPriority, TElement>).GetField("_nodes", BindingFlags.NonPublic | BindingFlags.Instance);
Assert.NotNull(nodesField);
var nodes = ((TElement Element, TPriority Priority)[])nodesField.GetValue(queue);
return nodes.Length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public void WriteString_NotUtf8(int stringLengthInChars)

private static bool IsUsingFastUtf8(BinaryWriter writer)
{
return (bool)writer.GetType().GetField("_useFastUtf8", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(writer);
return (bool)typeof(BinaryWriter).GetField("_useFastUtf8", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(writer);
}

private static string GenerateLargeUnicodeString(int charCount)
Expand Down
4 changes: 2 additions & 2 deletions src/libraries/System.Security.Cryptography/tests/DSATests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,15 @@ protected override void Dispose(bool disposing)
public override void ImportParameters(DSAParameters parameters) => _dsa.ImportParameters(parameters);
public override bool VerifySignature(byte[] rgbHash, byte[] rgbSignature) => _dsa.VerifySignature(rgbHash, rgbSignature);
protected override byte[] HashData(Stream data, HashAlgorithmName hashAlgorithm) =>
(byte[])_dsa.GetType().GetMethod(
(byte[])typeof(DSA).GetMethod(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are multiple implementations of DSA/ECDSA so this is technically getting a different method, but since it's virtual it probably doesn't matter. Just wanted to point out for double checking.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks - I was aware of that, but as you stated, the method is declared as virtual on the DSA class, so unless some derived class does new this will work (and new would be really weird for a protected virtual method). If this was in product I would be much more worried, but for tests I think this is perfectly OK.

nameof(HashData),
BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance,
null,
new Type[] { typeof(Stream), typeof(HashAlgorithmName) },
null)
.Invoke(_dsa, new object[] { data, hashAlgorithm });
protected override byte[] HashData(byte[] data, int offset, int count, HashAlgorithmName hashAlgorithm) =>
(byte[])_dsa.GetType().GetMethod(
(byte[])typeof(DSA).GetMethod(
nameof(HashData),
BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance,
null,
Expand Down
Loading
Loading