Skip to content

Commit

Permalink
Merge pull request #1017 from microsoft/fix600
Browse files Browse the repository at this point in the history
Implement `SetLastError` ourselves on .NET when marshaling is not allowed
  • Loading branch information
AArnott authored Aug 11, 2023
2 parents 5951d13 + 2c9e6a7 commit 093ccb6
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 6 deletions.
52 changes: 48 additions & 4 deletions src/Microsoft.Windows.CsWin32/Generator.Extern.cs
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,13 @@ private void DeclareExternMethod(MethodDefinitionHandle methodDefinitionHandle)
}
}

bool setLastError = (import.Attributes & MethodImportAttributes.SetLastError) == MethodImportAttributes.SetLastError;
bool setLastErrorViaMarshaling = setLastError && (this.Options.AllowMarshaling || !this.canUseSetLastPInvokeError);
bool setLastErrorManually = setLastError && !setLastErrorViaMarshaling;

AttributeListSyntax CreateDllImportAttributeList() => AttributeList()
.WithCloseBracketToken(TokenWithLineFeed(SyntaxKind.CloseBracketToken))
.AddAttributes(DllImport(import, moduleName, entrypoint, requiresUnicodeCharSet ? CharSet.Unicode : CharSet.Ansi));
.AddAttributes(DllImport(import, moduleName, entrypoint, setLastErrorViaMarshaling, requiresUnicodeCharSet ? CharSet.Unicode : CharSet.Ansi));

MethodDeclarationSyntax externDeclaration = MethodDeclaration(
List<AttributeListSyntax>().Add(CreateDllImportAttributeList()),
Expand Down Expand Up @@ -246,7 +250,7 @@ AttributeListSyntax CreateDllImportAttributeList() => AttributeList()
}

MethodDeclarationSyntax exposedMethod;
if (returnTypeEnumName is null && parameterEnumType is null)
if (returnTypeEnumName is null && parameterEnumType is null && !setLastErrorManually)
{
// No need for wrapping the extern method, so just expose it directly.
exposedMethod = externDeclaration.WithModifiers(externDeclaration.Modifiers.Insert(0, TokenWithSpace(this.Visibility)));
Expand Down Expand Up @@ -288,17 +292,57 @@ static SyntaxToken RefInOutKeyword(ParameterSyntax p) =>
invocation = CastExpression(returnTypeEnumName, invocation);
}

StatementSyntax forwardingStatement = returnType.Type is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.VoidKeyword } ? ExpressionStatement(invocation) : ReturnStatement(invocation);
BlockSyntax body = Block();

IdentifierNameSyntax? retValLocalName = returnType.Type is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.VoidKeyword } ? null : IdentifierName("__retVal");

if (setLastErrorManually)
{
// Marshal.SetLastSystemError(0);
body = body.AddStatements(ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Marshal)), IdentifierName("SetLastSystemError")),
ArgumentList().AddArguments(Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0)))))));
}

if (retValLocalName is not null)
{
// var __retVal = LocalExternFunction(...);
body = body.AddStatements(
LocalDeclarationStatement(VariableDeclaration(returnTypeEnumName ?? returnType.Type).AddVariables(
VariableDeclarator(retValLocalName.Identifier).WithInitializer(EqualsValueClause(invocation)))));
}
else
{
// LocalExternFunction(...);
body = body.AddStatements(ExpressionStatement(invocation));
}

if (setLastErrorManually)
{
// Marshal.SetLastPInvokeError(Marshal.GetLastSystemError())
body = body.AddStatements(ExpressionStatement(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Marshal)), IdentifierName("SetLastPInvokeError")),
ArgumentList().AddArguments(Argument(InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, IdentifierName(nameof(Marshal)), IdentifierName("GetLastSystemError"))))))));
}

if (retValLocalName is not null)
{
// return __retVal;
body = body.AddStatements(ReturnStatement(retValLocalName));
}

LocalFunctionStatementSyntax externFunction = LocalFunctionStatement(externDeclaration.ReturnType, localExternFunctionName.Identifier)
.AddAttributeLists(CreateDllImportAttributeList().WithOpenBracketToken(Token(SyntaxKind.OpenBracketToken).WithLeadingTrivia(LineFeed)))
.WithModifiers(externDeclaration.Modifiers)
.WithParameterList(externDeclaration.ParameterList)
.WithSemicolonToken(SemicolonWithLineFeed);
body = body.AddStatements(externFunction);

exposedMethod = MethodDeclaration(returnTypeEnumName ?? returnType.Type, externDeclaration.Identifier)
.AddModifiers(TokenWithSpace(this.Visibility), TokenWithSpace(SyntaxKind.StaticKeyword))
.WithParameterList(exposedParameterList)
.AddBodyStatements(forwardingStatement, externFunction);
.WithBody(body);
if (requiresUnsafe)
{
exposedMethod = exposedMethod.AddModifiers(TokenWithSpace(SyntaxKind.UnsafeKeyword));
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.Features.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public partial class Generator
private readonly bool canUseUnsafeAsRef;
private readonly bool canUseUnsafeNullRef;
private readonly bool canUseUnmanagedCallersOnlyAttribute;
private readonly bool canUseSetLastPInvokeError;
private readonly bool unscopedRefAttributePredefined;
private readonly INamedTypeSymbol? runtimeFeatureClass;
private readonly bool generateSupportedOSPlatformAttributes;
Expand Down
1 change: 1 addition & 0 deletions src/Microsoft.Windows.CsWin32/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public Generator(string metadataLibraryPath, Docs? docs, GeneratorOptions option
this.canUseUnsafeAsRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("AsRef").Any() is true;
this.canUseUnsafeNullRef = this.compilation?.GetTypeByMetadataName(typeof(Unsafe).FullName)?.GetMembers("NullRef").Any() is true;
this.canUseUnmanagedCallersOnlyAttribute = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute") is not null;
this.canUseSetLastPInvokeError = this.compilation?.GetTypeByMetadataName("System.Runtime.InteropServices.Marshal")?.GetMembers("GetLastSystemError").IsEmpty is false;
this.unscopedRefAttributePredefined = this.FindTypeSymbolIfAlreadyAvailable("System.Diagnostics.CodeAnalysis.UnscopedRefAttribute") is not null;
this.runtimeFeatureClass = (INamedTypeSymbol?)this.FindTypeSymbolIfAlreadyAvailable("System.Runtime.CompilerServices.RuntimeFeature");
this.comIIDInterfacePredefined = this.FindTypeSymbolIfAlreadyAvailable($"{this.Namespace}.{IComIIDGuidInterfaceName}") is not null;
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ internal static AttributeSyntax InterfaceType(ComInterfaceType interfaceType)
IdentifierName(Enum.GetName(typeof(ComInterfaceType), interfaceType)!))));
}

internal static AttributeSyntax DllImport(MethodImport import, string moduleName, string? entrypoint, CharSet charSet = CharSet.Ansi)
internal static AttributeSyntax DllImport(MethodImport import, string moduleName, string? entrypoint, bool setLastError, CharSet charSet = CharSet.Ansi)
{
List<AttributeArgumentSyntax> args = new();
AttributeSyntax? dllImportAttribute = Attribute(IdentifierName("DllImport"));
Expand All @@ -216,7 +216,7 @@ internal static AttributeSyntax DllImport(MethodImport import, string moduleName
.WithNameEquals(NameEquals(nameof(DllImportAttribute.EntryPoint))));
}

if ((import.Attributes & MethodImportAttributes.SetLastError) == MethodImportAttributes.SetLastError)
if (setLastError)
{
args.Add(AttributeArgument(LiteralExpression(SyntaxKind.TrueLiteralExpression))
.WithNameEquals(NameEquals(nameof(DllImportAttribute.SetLastError))));
Expand Down
20 changes: 20 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/ExternMethodTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,24 @@ public void WdkMethod_NtCreateFile()
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();
}

[Theory, CombinatorialData]
public void SetLastError_ByMarshaling(
bool allowMarshaling,
[CombinatorialMemberData(nameof(TFMDataNoNetFx35))] string tfm)
{
this.compilation = this.starterCompilations[tfm];
this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = allowMarshaling });
Assert.True(this.generator.TryGenerate("GetVersionEx", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics();

bool expectMarshalingAttribute = allowMarshaling || tfm is "net472" or "netstandard2.0";
MethodDeclarationSyntax originalMethod = this.FindGeneratedMethod("GetVersionEx").Single(m => m.ParameterList.Parameters[0].Type is PointerTypeSyntax);
AttributeSyntax? attribute = FindDllImportAttribute(originalMethod.AttributeLists) ?? FindDllImportAttribute(originalMethod.Body?.Statements.OfType<LocalFunctionStatementSyntax>().SingleOrDefault()?.AttributeLists ?? default);
Assert.NotNull(attribute);
Assert.Equal(expectMarshalingAttribute, attribute.ArgumentList!.Arguments.Any(a => a.NameEquals?.Name.Identifier.ValueText == "SetLastError"));

static AttributeSyntax? FindDllImportAttribute(SyntaxList<AttributeListSyntax> attributeLists) => attributeLists.SelectMany(al => al.Attributes).FirstOrDefault(a => a.Name.ToString() == "DllImport");
}
}

0 comments on commit 093ccb6

Please sign in to comment.