Skip to content

Commit

Permalink
Only run GuaranteedUnmarshal statements if invoke was successful. (do…
Browse files Browse the repository at this point in the history
  • Loading branch information
jkoritzinsky authored Jul 1, 2021
1 parent 0fc1ca2 commit 9f1e3ea
Showing 1 changed file with 154 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,11 @@ internal sealed class StubCodeGenerator : StubCodeContext

private const string InvokeReturnIdentifier = "__invokeRetVal";
private const string LastErrorIdentifier = "__lastError";
private const string InvokeSucceededIdentifier = "__invokeSucceeded";

// Error code representing success. This maps to S_OK for Windows HRESULT semantics and 0 for POSIX errno semantics.
private const int SuccessErrorCode = 0;

private static readonly Stage[] Stages = new Stage[]
{
Stage.Setup,
Stage.Marshal,
Stage.Pin,
Stage.Invoke,
Stage.KeepAlive,
Stage.Unmarshal,
Stage.GuaranteedUnmarshal,
Stage.Cleanup
};

private readonly GeneratorDiagnostics diagnostics;
private readonly AnalyzerConfigOptions options;
private readonly IMethodSymbol stubMethod;
Expand Down Expand Up @@ -248,139 +237,32 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes)
}

var tryStatements = new List<StatementSyntax>();
var finallyStatements = new List<StatementSyntax>();
var guaranteedUnmarshalStatements = new List<StatementSyntax>();
var cleanupStatements = new List<StatementSyntax>();
var invoke = InvocationExpression(IdentifierName(dllImportName));
var fixedStatements = new List<FixedStatementSyntax>();
foreach (var stage in Stages)
{
var statements = GetStatements(stage);
int initialCount = statements.Count;
this.CurrentStage = stage;

if (!invokeReturnsVoid && (stage is Stage.Setup or Stage.Cleanup))
{
// Handle setup and unmarshalling for return
var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, this);
statements.AddRange(retStatements);
}

if (stage is Stage.Unmarshal or Stage.GuaranteedUnmarshal)
{
// For Unmarshal and GuaranteedUnmarshal stages, use the topologically sorted
// marshaller list to generate the marshalling statements

foreach (var marshaller in sortedMarshallers)
{
statements.AddRange(marshaller.Generator.Generate(marshaller.TypeInfo, this));
}
}
else
{
// Generate code for each parameter for the current stage
foreach (var marshaller in paramMarshallers)
{
if (stage == Stage.Invoke)
{
// Get arguments for invocation
ArgumentSyntax argSyntax = marshaller.Generator.AsArgument(marshaller.TypeInfo, this);
invoke = invoke.AddArgumentListArguments(argSyntax);
}
else
{
var generatedStatements = marshaller.Generator.Generate(marshaller.TypeInfo, this);
if (stage == Stage.Pin)
{
// Collect all the fixed statements. These will be used in the Invoke stage.
foreach (var statement in generatedStatements)
{
if (statement is not FixedStatementSyntax fixedStatement)
continue;

fixedStatements.Add(fixedStatement);
}
}
else
{
statements.AddRange(generatedStatements);
}
}
}
}

if (stage == Stage.Invoke)
{
StatementSyntax invokeStatement;

// Assign to return value if necessary
if (invokeReturnsVoid)
{
invokeStatement = ExpressionStatement(invoke);
}
else
{
invokeStatement = ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(this.GetIdentifiers(retMarshaller.TypeInfo).native),
invoke));
}

// Do not manually handle SetLastError when generating forwarders.
// We want the runtime to handle everything.
if (this.dllImportData.SetLastError && !options.GenerateForwarders())
{
// Marshal.SetLastSystemError(0);
var clearLastError = ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_Marshal),
IdentifierName("SetLastSystemError")),
ArgumentList(SingletonSeparatedList(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(SuccessErrorCode)))))));

// <lastError> = Marshal.GetLastSystemError();
var getLastError = ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(LastErrorIdentifier),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_Marshal),
IdentifierName("GetLastSystemError")))));

invokeStatement = Block(clearLastError, invokeStatement, getLastError);
}

// Nest invocation in fixed statements
if (fixedStatements.Any())
{
fixedStatements.Reverse();
invokeStatement = fixedStatements.First().WithStatement(invokeStatement);
foreach (var fixedStatement in fixedStatements.Skip(1))
{
invokeStatement = fixedStatement.WithStatement(Block(invokeStatement));
}
}
// Handle GuaranteedUnmarshal first since that stage producing statements affects multiple other stages.
GenerateStatementsForStage(Stage.GuaranteedUnmarshal, guaranteedUnmarshalStatements);
if (guaranteedUnmarshalStatements.Count > 0)
{
setupStatements.Add(MarshallerHelpers.DeclareWithDefault(PredefinedType(Token(SyntaxKind.BoolKeyword)), InvokeSucceededIdentifier));
}

statements.Add(invokeStatement);
}
GenerateStatementsForStage(Stage.Setup, setupStatements);
GenerateStatementsForStage(Stage.Marshal, tryStatements);
GenerateStatementsForInvoke(tryStatements, invoke);
GenerateStatementsForStage(Stage.KeepAlive, tryStatements);
GenerateStatementsForStage(Stage.Unmarshal, tryStatements);
GenerateStatementsForStage(Stage.Cleanup, cleanupStatements);

if (statements.Count > initialCount)
{
// Comment separating each stage
var newLeadingTrivia = TriviaList(
Comment($"//"),
Comment($"// {stage}"),
Comment($"//"));
var firstStatementInStage = statements[initialCount];
newLeadingTrivia = newLeadingTrivia.AddRange(firstStatementInStage.GetLeadingTrivia());
statements[initialCount] = firstStatementInStage.WithLeadingTrivia(newLeadingTrivia);
}
List<StatementSyntax> allStatements = setupStatements;
List<StatementSyntax> finallyStatements = new List<StatementSyntax>();
if (guaranteedUnmarshalStatements.Count > 0)
{
finallyStatements.Add(IfStatement(IdentifierName(InvokeSucceededIdentifier), Block(guaranteedUnmarshalStatements)));
}

List<StatementSyntax> allStatements = setupStatements;
finallyStatements.AddRange(cleanupStatements);
if (finallyStatements.Count > 0)
{
// Add try-finally block if there are any statements in the finally block
Expand Down Expand Up @@ -445,15 +327,141 @@ public BlockSyntax GenerateSyntax(AttributeListSyntax? forwardedAttributes)

return codeBlock.AddStatements(dllImport);

List<StatementSyntax> GetStatements(Stage stage)
void GenerateStatementsForStage(Stage stage, List<StatementSyntax> statementsToUpdate)
{
return stage switch
int initialCount = statementsToUpdate.Count;
this.CurrentStage = stage;

if (!invokeReturnsVoid && (stage is Stage.Setup or Stage.Cleanup))
{
Stage.Setup => setupStatements,
Stage.Marshal or Stage.Pin or Stage.Invoke or Stage.KeepAlive or Stage.Unmarshal => tryStatements,
Stage.GuaranteedUnmarshal or Stage.Cleanup => finallyStatements,
_ => throw new ArgumentOutOfRangeException(nameof(stage))
};
// Handle setup and unmarshalling for return
var retStatements = retMarshaller.Generator.Generate(retMarshaller.TypeInfo, this);
statementsToUpdate.AddRange(retStatements);
}

if (stage is Stage.Unmarshal or Stage.GuaranteedUnmarshal)
{
// For Unmarshal and GuaranteedUnmarshal stages, use the topologically sorted
// marshaller list to generate the marshalling statements

foreach (var marshaller in sortedMarshallers)
{
statementsToUpdate.AddRange(marshaller.Generator.Generate(marshaller.TypeInfo, this));
}
}
else
{
// Generate code for each parameter for the current stage in declaration order.
foreach (var marshaller in paramMarshallers)
{
var generatedStatements = marshaller.Generator.Generate(marshaller.TypeInfo, this);
statementsToUpdate.AddRange(generatedStatements);
}
}

if (statementsToUpdate.Count > initialCount)
{
// Comment separating each stage
var newLeadingTrivia = TriviaList(
Comment($"//"),
Comment($"// {stage}"),
Comment($"//"));
var firstStatementInStage = statementsToUpdate[initialCount];
newLeadingTrivia = newLeadingTrivia.AddRange(firstStatementInStage.GetLeadingTrivia());
statementsToUpdate[initialCount] = firstStatementInStage.WithLeadingTrivia(newLeadingTrivia);
}
}

void GenerateStatementsForInvoke(List<StatementSyntax> statementsToUpdate, InvocationExpressionSyntax invoke)
{
var fixedStatements = new List<FixedStatementSyntax>();
this.CurrentStage = Stage.Pin;
// Generate code for each parameter for the current stage
foreach (var marshaller in paramMarshallers)
{
var generatedStatements = marshaller.Generator.Generate(marshaller.TypeInfo, this);
// Collect all the fixed statements. These will be used in the Invoke stage.
foreach (var statement in generatedStatements)
{
if (statement is not FixedStatementSyntax fixedStatement)
continue;

fixedStatements.Add(fixedStatement);
}
}

this.CurrentStage = Stage.Invoke;
// Generate code for each parameter for the current stage
foreach (var marshaller in paramMarshallers)
{
// Get arguments for invocation
ArgumentSyntax argSyntax = marshaller.Generator.AsArgument(marshaller.TypeInfo, this);
invoke = invoke.AddArgumentListArguments(argSyntax);
}

StatementSyntax invokeStatement;

// Assign to return value if necessary
if (invokeReturnsVoid)
{
invokeStatement = ExpressionStatement(invoke);
}
else
{
invokeStatement = ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(this.GetIdentifiers(retMarshaller.TypeInfo).native),
invoke));
}

// Do not manually handle SetLastError when generating forwarders.
// We want the runtime to handle everything.
if (this.dllImportData.SetLastError && !options.GenerateForwarders())
{
// Marshal.SetLastSystemError(0);
var clearLastError = ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_Marshal),
IdentifierName("SetLastSystemError")),
ArgumentList(SingletonSeparatedList(
Argument(LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(SuccessErrorCode)))))));

// <lastError> = Marshal.GetLastSystemError();
var getLastError = ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
IdentifierName(LastErrorIdentifier),
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_Marshal),
IdentifierName("GetLastSystemError")))));

invokeStatement = Block(clearLastError, invokeStatement, getLastError);
}

// Nest invocation in fixed statements
if (fixedStatements.Any())
{
fixedStatements.Reverse();
invokeStatement = fixedStatements.First().WithStatement(invokeStatement);
foreach (var fixedStatement in fixedStatements.Skip(1))
{
invokeStatement = fixedStatement.WithStatement(Block(invokeStatement));
}
}

statementsToUpdate.Add(invokeStatement);
// <invokeSucceeded> = true;
if (guaranteedUnmarshalStatements.Count > 0)
{
statementsToUpdate.Add(ExpressionStatement(AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
IdentifierName(InvokeSucceededIdentifier),
LiteralExpression(SyntaxKind.TrueLiteralExpression))));
}
}
}

Expand Down

0 comments on commit 9f1e3ea

Please sign in to comment.