Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LibraryImportGenerator] Use basic forwarder in down-level support if any parameters can't be marshalled #104416

Merged
merged 4 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,12 @@ private static (MemberDeclarationSyntax, ImmutableArray<DiagnosticInfo>) Generat
LibraryImportGeneratorHelpers.CreateGeneratorResolver(pinvokeStub.TargetFramework, pinvokeStub.Options, pinvokeStub.EnvironmentFlags),
new CodeEmitOptions(SkipInit: pinvokeStub.TargetFramework is (TargetFramework.Net, _)));

// For down-level support, if some parameters cannot be marshalled, consider the target framework as not supported
if (stubGenerator.HasForwardedTypes && (pinvokeStub.TargetFramework.TargetFramework != TargetFramework.Net || pinvokeStub.TargetFramework.Version.Major < 7))
elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
{
supportsTargetFramework = false;
}

// Check if the generator should produce a forwarder stub - regular DllImport.
// This is done if the signature is blittable or the target framework is not supported.
if (stubGenerator.StubIsBasicForwarder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ internal sealed class PInvokeStubCodeGenerator
{
public bool StubIsBasicForwarder { get; }

public bool HasForwardedTypes { get; }

/// <summary>
/// Identifier for managed return value
/// </summary>
Expand Down Expand Up @@ -74,6 +76,12 @@ public PInvokeStubCodeGenerator(
// Check if generator is either blittable or just a forwarder.
noMarshallingNeeded &= generator is { Generator: BlittableMarshaller, TypeInfo.IsByRef: false }
or { Generator: Forwarder };

// Track if any generators are just forwarders - for types other than void, this indicates
// types that can't be marshalled by the source generated
elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
// In .NET 7+ support, we would have emitted a diagnostic error about lack of support
// In down-level support, we do not error - tracking this allows us to switch to generating a basic forwarder stub
elinor-fung marked this conversation as resolved.
Show resolved Hide resolved
HasForwardedTypes |= generator is { Generator: Forwarder, TypeInfo.ManagedType: not SpecialTypeInfo { SpecialType: Microsoft.CodeAnalysis.SpecialType.System_Void } };
}

StubIsBasicForwarder = !setLastError
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,80 +331,6 @@ await VerifySourceGeneratorAsync(
});
}

[Fact]
[OuterLoop("Uses the network for downlevel ref packs")]
public async Task InOutAttributes_Forwarded_To_ForwardedParameter()
{
// This code is invalid configuration from the source generator's perspective.
// We just use it as validation for forwarding the In and Out attributes.
string source = """
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
partial class C
{
[LibraryImportAttribute("DoesNotExist")]
[return: MarshalAs(UnmanagedType.Bool)]
public static partial bool Method1([In, Out] int {|SYSLIB1051:a|});
}
""" + CodeSnippets.LibraryImportAttributeDeclaration;

await VerifySourceGeneratorAsync(
source,
(targetMethod, newComp) =>
{
INamedTypeSymbol marshalAsAttribute = newComp.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute)!;
INamedTypeSymbol inAttribute = newComp.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_InAttribute)!;
INamedTypeSymbol outAttribute = newComp.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_OutAttribute)!;
Assert.Collection(targetMethod.Parameters,
param => Assert.Collection(param.GetAttributes(),
attr =>
{
Assert.Equal(inAttribute, attr.AttributeClass, SymbolEqualityComparer.Default);
Assert.Empty(attr.ConstructorArguments);
Assert.Empty(attr.NamedArguments);
},
attr =>
{
Assert.Equal(outAttribute, attr.AttributeClass, SymbolEqualityComparer.Default);
Assert.Empty(attr.ConstructorArguments);
Assert.Empty(attr.NamedArguments);
}));
},
TestTargetFramework.Standard);
}

[Fact]
[OuterLoop("Uses the network for downlevel ref packs")]
public async Task MarshalAsAttribute_Forwarded_To_ForwardedParameter()
{
string source = """
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
partial class C
{
[LibraryImportAttribute("DoesNotExist")]
[return: MarshalAs(UnmanagedType.Bool)]
public static partial bool Method1([MarshalAs(UnmanagedType.I2)] int a);
}
""" + CodeSnippets.LibraryImportAttributeDeclaration;

await VerifySourceGeneratorAsync(
source,
(targetMethod, newComp) =>
{
INamedTypeSymbol marshalAsAttribute = newComp.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_MarshalAsAttribute)!;
Assert.Collection(targetMethod.Parameters,
param => Assert.Collection(param.GetAttributes(),
attr =>
{
Assert.Equal(marshalAsAttribute, attr.AttributeClass, SymbolEqualityComparer.Default);
Assert.Equal(UnmanagedType.I2, (UnmanagedType)attr.ConstructorArguments[0].Value!);
Assert.Empty(attr.NamedArguments);
}));
},
TestTargetFramework.Standard);
}

private static Task VerifySourceGeneratorAsync(string source, Action<IMethodSymbol, Compilation> targetPInvokeAssertion, TestTargetFramework targetFramework = TestTargetFramework.Net)
{
var test = new GeneratedTargetPInvokeTest(targetPInvokeAssertion, targetFramework)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -813,13 +813,16 @@ partial class Test
}
""";

public static string BasicReturnAndParameterByValue(string returnType, string parameterType, string preDeclaration = "") => $$"""
/// <summary>
/// Declaration with a non-blittable parameter that is always supported for marshalling
/// </summary>
public static string BasicReturnAndParameterWithAlwaysSupportedParameter(string returnType, string parameterType, string preDeclaration = "") => $$"""
using System.Runtime.InteropServices;
{{preDeclaration}}
partial class Test
{
[LibraryImport("DoesNotExist")]
public static partial {{returnType}} Method({{parameterType}} p);
public static partial {{returnType}} Method({{parameterType}} p, out int i);
}
""";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,46 @@ public static IEnumerable<object[]> CodeSnippetsToValidateFallbackForwarder()
yield return new object[] { ID(), code, TestTargetFramework.Framework, true };
}

// Confirm that if support is missing for any type (like arrays), we fall back to a forwarder even if other types are supported.
// Confirm that if support is missing for a type with an ITypeBasedMarshallingInfoProvider (like arrays and SafeHandles), we fall back to a forwarder even if other types are supported.
{
string code = CodeSnippets.BasicReturnAndParameterByValue("System.Runtime.InteropServices.SafeHandle", "int[]", CodeSnippets.LibraryImportAttributeDeclaration);
string code = CodeSnippets.BasicReturnAndParameterWithAlwaysSupportedParameter("void", "System.Runtime.InteropServices.SafeHandle", CodeSnippets.LibraryImportAttributeDeclaration);
yield return new object[] { ID(), code, TestTargetFramework.Net6, true };
yield return new object[] { ID(), code, TestTargetFramework.Core, true };
yield return new object[] { ID(), code, TestTargetFramework.Standard, true };
yield return new object[] { ID(), code, TestTargetFramework.Framework, true };
}
{
string code = CodeSnippets.BasicReturnAndParameterWithAlwaysSupportedParameter("System.Runtime.InteropServices.SafeHandle", "int", CodeSnippets.LibraryImportAttributeDeclaration);
yield return new object[] { ID(), code, TestTargetFramework.Net6, true };
yield return new object[] { ID(), code, TestTargetFramework.Core, true };
yield return new object[] { ID(), code, TestTargetFramework.Standard, true };
yield return new object[] { ID(), code, TestTargetFramework.Framework, true };
}
{
string code = CodeSnippets.BasicReturnAndParameterWithAlwaysSupportedParameter("void", "int[]", CodeSnippets.LibraryImportAttributeDeclaration);
yield return new object[] { ID(), code, TestTargetFramework.Net6, true };
yield return new object[] { ID(), code, TestTargetFramework.Core, true };
yield return new object[] { ID(), code, TestTargetFramework.Standard, true };
yield return new object[] { ID(), code, TestTargetFramework.Framework, true };
}
{
string code = CodeSnippets.BasicReturnAndParameterWithAlwaysSupportedParameter("int", "int[]", CodeSnippets.LibraryImportAttributeDeclaration);
yield return new object[] { ID(), code, TestTargetFramework.Net6, true };
yield return new object[] { ID(), code, TestTargetFramework.Core, true };
yield return new object[] { ID(), code, TestTargetFramework.Standard, true };
yield return new object[] { ID(), code, TestTargetFramework.Framework, true };
}

// Confirm that if support is missing for a type without an ITypeBasedMarshallingInfoProvider (like StringBuilder), we fall back to a forwarder even if other types are supported.
{
string code = CodeSnippets.BasicReturnAndParameterWithAlwaysSupportedParameter("void", "System.Text.StringBuilder", CodeSnippets.LibraryImportAttributeDeclaration);
yield return new object[] { ID(), code, TestTargetFramework.Net6, true };
yield return new object[] { ID(), code, TestTargetFramework.Core, true };
yield return new object[] { ID(), code, TestTargetFramework.Standard, true };
yield return new object[] { ID(), code, TestTargetFramework.Framework, true };
}
{
string code = CodeSnippets.BasicReturnAndParameterWithAlwaysSupportedParameter("int", "System.Text.StringBuilder", CodeSnippets.LibraryImportAttributeDeclaration);
yield return new object[] { ID(), code, TestTargetFramework.Net6, true };
yield return new object[] { ID(), code, TestTargetFramework.Core, true };
yield return new object[] { ID(), code, TestTargetFramework.Standard, true };
Expand Down Expand Up @@ -724,7 +761,6 @@ public class Basic { }
await test.RunAsync();
}


[OuterLoop("Uses the network for downlevel ref packs")]
[InlineData(TestTargetFramework.Standard)]
[InlineData(TestTargetFramework.Framework)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ partial class Test
public static partial void {|#0:Method1|}(string s);

[LibraryImport("DoesNotExist", StringMarshalling = StringMarshalling.Custom, StringMarshallingCustomType = typeof(Native))]
public static partial void Method2(string {|#1:s|});
public static partial void {|#2:Method2|}(string {|#1:s|});

struct Native
{
Expand All @@ -266,7 +266,13 @@ public Native(string s) { }
.WithArguments($"{nameof(TypeNames.LibraryImportAttribute)}{Type.Delimiter}{nameof(StringMarshalling)}={nameof(StringMarshalling)}{Type.Delimiter}{nameof(StringMarshalling.Utf8)}"),
VerifyCS.Diagnostic(GeneratorDiagnostics.ParameterTypeNotSupportedWithDetails)
.WithLocation(1)
.WithArguments("Marshalling string or char without explicit marshalling information is not supported. Specify 'LibraryImportAttribute.StringMarshalling', 'LibraryImportAttribute.StringMarshallingCustomType', 'MarshalUsingAttribute' or 'MarshalAsAttribute'.", "s")
.WithArguments("Marshalling string or char without explicit marshalling information is not supported. Specify 'LibraryImportAttribute.StringMarshalling', 'LibraryImportAttribute.StringMarshallingCustomType', 'MarshalUsingAttribute' or 'MarshalAsAttribute'.", "s"),
VerifyCS.Diagnostic(GeneratorDiagnostics.CannotForwardToDllImport)
.WithLocation(2)
.WithArguments($"{nameof(TypeNames.LibraryImportAttribute)}{Type.Delimiter}{nameof(StringMarshalling)}={nameof(StringMarshalling)}{Type.Delimiter}{nameof(StringMarshalling.Custom)}"),
VerifyCS.Diagnostic(GeneratorDiagnostics.CannotForwardToDllImport)
.WithLocation(2)
.WithArguments($"{nameof(TypeNames.LibraryImportAttribute)}{Type.Delimiter}{nameof(LibraryImportAttribute.StringMarshallingCustomType)}")
};

var test = new VerifyCS.Test(TestTargetFramework.Standard)
Expand All @@ -289,10 +295,10 @@ partial class Test
{
[{|#0:LibraryImport("DoesNotExist", StringMarshalling = StringMarshalling.Custom)|}]
public static partial void Method1(out int i);

[{|#1:LibraryImport("DoesNotExist", StringMarshalling = StringMarshalling.Utf8, StringMarshallingCustomType = typeof(Native))|}]
public static partial void Method2(out int i);

struct Native
{
public Native(string s) { }
Expand Down
Loading