From 9f1e3ea2b1b254ac150b11d6e5a81e0f4c38f997 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Thu, 1 Jul 2021 12:56:30 -0700 Subject: [PATCH] Only run GuaranteedUnmarshal statements if invoke was successful. (dotnet/runtimelab#1288) Commit migrated from https://github.com/dotnet/runtimelab/commit/de39034ce26e5ab47c0967748bc839bd40e58266 --- .../DllImportGenerator/StubCodeGenerator.cs | 300 +++++++++--------- 1 file changed, 154 insertions(+), 146 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs index 31d13feef318d..b21f4348afa97 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/DllImportGenerator/StubCodeGenerator.cs @@ -30,22 +30,11 @@ internal sealed class StubCodeGenerator : StubCodeContext private const string InvokeReturnIdentifier = "__invokeRetVal"; private const string LastErrorIdentifier = "__lastError"; + private const string InvokeSucceededIdentifier = "__invokeSucceeded"; // Error code representing success. This maps to S_OK for Windows HRESULT semantics and 0 for POSIX errno semantics. private const int SuccessErrorCode = 0; - private static readonly Stage[] Stages = new Stage[] - { - Stage.Setup, - Stage.Marshal, - Stage.Pin, - Stage.Invoke, - Stage.KeepAlive, - Stage.Unmarshal, - Stage.GuaranteedUnmarshal, - Stage.Cleanup - }; - private readonly GeneratorDiagnostics diagnostics; private readonly AnalyzerConfigOptions options; private readonly IMethodSymbol stubMethod; @@ -248,139 +237,32 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) } var tryStatements = new List(); - var finallyStatements = new List(); + var guaranteedUnmarshalStatements = new List(); + var cleanupStatements = new List(); var invoke = InvocationExpression(IdentifierName(dllImportName)); - var fixedStatements = new List(); - foreach (var stage in Stages) - { - var statements = GetStatements(stage); - int initialCount = statements.Count; - this.CurrentStage = stage; - - if (!invokeReturnsVoid && (stage is Stage.Setup or Stage.Cleanup)) - { - // Handle setup and unmarshalling for return - var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, this); - statements.AddRange(retStatements); - } - - if (stage is Stage.Unmarshal or Stage.GuaranteedUnmarshal) - { - // For Unmarshal and GuaranteedUnmarshal stages, use the topologically sorted - // marshaller list to generate the marshalling statements - foreach (var marshaller in sortedMarshallers) - { - statements.AddRange(marshaller.Generator.Generate(marshaller.TypeInfo, this)); - } - } - else - { - // Generate code for each parameter for the current stage - foreach (var marshaller in paramMarshallers) - { - if (stage == Stage.Invoke) - { - // Get arguments for invocation - ArgumentSyntax argSyntax = marshaller.Generator.AsArgument(marshaller.TypeInfo, this); - invoke = invoke.AddArgumentListArguments(argSyntax); - } - else - { - var generatedStatements = marshaller.Generator.Generate(marshaller.TypeInfo, this); - if (stage == Stage.Pin) - { - // Collect all the fixed statements. These will be used in the Invoke stage. - foreach (var statement in generatedStatements) - { - if (statement is not FixedStatementSyntax fixedStatement) - continue; - - fixedStatements.Add(fixedStatement); - } - } - else - { - statements.AddRange(generatedStatements); - } - } - } - } - - if (stage == Stage.Invoke) - { - StatementSyntax invokeStatement; - - // Assign to return value if necessary - if (invokeReturnsVoid) - { - invokeStatement = ExpressionStatement(invoke); - } - else - { - invokeStatement = ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(this.GetIdentifiers(retMarshaller.TypeInfo).native), - invoke)); - } - - // Do not manually handle SetLastError when generating forwarders. - // We want the runtime to handle everything. - if (this.dllImportData.SetLastError && !options.GenerateForwarders()) - { - // Marshal.SetLastSystemError(0); - var clearLastError = ExpressionStatement( - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_InteropServices_Marshal), - IdentifierName("SetLastSystemError")), - ArgumentList(SingletonSeparatedList( - Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(SuccessErrorCode))))))); - - // = Marshal.GetLastSystemError(); - var getLastError = ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - IdentifierName(LastErrorIdentifier), - InvocationExpression( - MemberAccessExpression( - SyntaxKind.SimpleMemberAccessExpression, - ParseName(TypeNames.System_Runtime_InteropServices_Marshal), - IdentifierName("GetLastSystemError"))))); - - invokeStatement = Block(clearLastError, invokeStatement, getLastError); - } - - // Nest invocation in fixed statements - if (fixedStatements.Any()) - { - fixedStatements.Reverse(); - invokeStatement = fixedStatements.First().WithStatement(invokeStatement); - foreach (var fixedStatement in fixedStatements.Skip(1)) - { - invokeStatement = fixedStatement.WithStatement(Block(invokeStatement)); - } - } + // Handle GuaranteedUnmarshal first since that stage producing statements affects multiple other stages. + GenerateStatementsForStage(Stage.GuaranteedUnmarshal, guaranteedUnmarshalStatements); + if (guaranteedUnmarshalStatements.Count > 0) + { + setupStatements.Add(MarshallerHelpers.DeclareWithDefault(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier)); + } - statements.Add(invokeStatement); - } + GenerateStatementsForStage(Stage.Setup, setupStatements); + GenerateStatementsForStage(Stage.Marshal, tryStatements); + GenerateStatementsForInvoke(tryStatements, invoke); + GenerateStatementsForStage(Stage.KeepAlive, tryStatements); + GenerateStatementsForStage(Stage.Unmarshal, tryStatements); + GenerateStatementsForStage(Stage.Cleanup, cleanupStatements); - if (statements.Count > initialCount) - { - // Comment separating each stage - var newLeadingTrivia = TriviaList( - Comment($"//"), - Comment($"// {stage}"), - Comment($"//")); - var firstStatementInStage = statements[initialCount]; - newLeadingTrivia = newLeadingTrivia.AddRange(firstStatementInStage.GetLeadingTrivia()); - statements[initialCount] = firstStatementInStage.WithLeadingTrivia(newLeadingTrivia); - } + List allStatements = setupStatements; + List finallyStatements = new List(); + if (guaranteedUnmarshalStatements.Count > 0) + { + finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(guaranteedUnmarshalStatements))); } - List allStatements = setupStatements; + finallyStatements.AddRange(cleanupStatements); if (finallyStatements.Count > 0) { // Add try-finally block if there are any statements in the finally block @@ -445,15 +327,141 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes) return codeBlock.AddStatements(dllImport); - List GetStatements(Stage stage) + void GenerateStatementsForStage(Stage stage, List statementsToUpdate) { - return stage switch + int initialCount = statementsToUpdate.Count; + this.CurrentStage = stage; + + if (!invokeReturnsVoid && (stage is Stage.Setup or Stage.Cleanup)) { - Stage.Setup => setupStatements, - Stage.Marshal or Stage.Pin or Stage.Invoke or Stage.KeepAlive or Stage.Unmarshal => tryStatements, - Stage.GuaranteedUnmarshal or Stage.Cleanup => finallyStatements, - _ => throw new ArgumentOutOfRangeException(nameof(stage)) - }; + // Handle setup and unmarshalling for return + var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, this); + statementsToUpdate.AddRange(retStatements); + } + + if (stage is Stage.Unmarshal or Stage.GuaranteedUnmarshal) + { + // For Unmarshal and GuaranteedUnmarshal stages, use the topologically sorted + // marshaller list to generate the marshalling statements + + foreach (var marshaller in sortedMarshallers) + { + statementsToUpdate.AddRange(marshaller.Generator.Generate(marshaller.TypeInfo, this)); + } + } + else + { + // Generate code for each parameter for the current stage in declaration order. + foreach (var marshaller in paramMarshallers) + { + var generatedStatements = marshaller.Generator.Generate(marshaller.TypeInfo, this); + statementsToUpdate.AddRange(generatedStatements); + } + } + + if (statementsToUpdate.Count > initialCount) + { + // Comment separating each stage + var newLeadingTrivia = TriviaList( + Comment($"//"), + Comment($"// {stage}"), + Comment($"//")); + var firstStatementInStage = statementsToUpdate[initialCount]; + newLeadingTrivia = newLeadingTrivia.AddRange(firstStatementInStage.GetLeadingTrivia()); + statementsToUpdate[initialCount] = firstStatementInStage.WithLeadingTrivia(newLeadingTrivia); + } + } + + void GenerateStatementsForInvoke(List statementsToUpdate, InvocationExpressionSyntax invoke) + { + var fixedStatements = new List(); + this.CurrentStage = Stage.Pin; + // Generate code for each parameter for the current stage + foreach (var marshaller in paramMarshallers) + { + var generatedStatements = marshaller.Generator.Generate(marshaller.TypeInfo, this); + // Collect all the fixed statements. These will be used in the Invoke stage. + foreach (var statement in generatedStatements) + { + if (statement is not FixedStatementSyntax fixedStatement) + continue; + + fixedStatements.Add(fixedStatement); + } + } + + this.CurrentStage = Stage.Invoke; + // Generate code for each parameter for the current stage + foreach (var marshaller in paramMarshallers) + { + // Get arguments for invocation + ArgumentSyntax argSyntax = marshaller.Generator.AsArgument(marshaller.TypeInfo, this); + invoke = invoke.AddArgumentListArguments(argSyntax); + } + + StatementSyntax invokeStatement; + + // Assign to return value if necessary + if (invokeReturnsVoid) + { + invokeStatement = ExpressionStatement(invoke); + } + else + { + invokeStatement = ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(this.GetIdentifiers(retMarshaller.TypeInfo).native), + invoke)); + } + + // Do not manually handle SetLastError when generating forwarders. + // We want the runtime to handle everything. + if (this.dllImportData.SetLastError && !options.GenerateForwarders()) + { + // Marshal.SetLastSystemError(0); + var clearLastError = ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_Marshal), + IdentifierName("SetLastSystemError")), + ArgumentList(SingletonSeparatedList( + Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(SuccessErrorCode))))))); + + // = Marshal.GetLastSystemError(); + var getLastError = ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(LastErrorIdentifier), + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseName(TypeNames.System_Runtime_InteropServices_Marshal), + IdentifierName("GetLastSystemError"))))); + + invokeStatement = Block(clearLastError, invokeStatement, getLastError); + } + + // Nest invocation in fixed statements + if (fixedStatements.Any()) + { + fixedStatements.Reverse(); + invokeStatement = fixedStatements.First().WithStatement(invokeStatement); + foreach (var fixedStatement in fixedStatements.Skip(1)) + { + invokeStatement = fixedStatement.WithStatement(Block(invokeStatement)); + } + } + + statementsToUpdate.Add(invokeStatement); + // = true; + if (guaranteedUnmarshalStatements.Count > 0) + { + statementsToUpdate.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression, + IdentifierName(InvokeSucceededIdentifier), + LiteralExpression(SyntaxKind.TrueLiteralExpression)))); + } } }