Skip to content

Commit

Permalink
Fix cases when preprocessor definitions are surrounding the method we…
Browse files Browse the repository at this point in the history
… are generating. (#947)
  • Loading branch information
jkoritzinsky authored Apr 12, 2021
1 parent 9aac138 commit 600ff5a
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 7 deletions.
68 changes: 68 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -929,5 +929,73 @@ class MySafeHandle : SafeHandle
protected override bool ReleaseHandle() => true;
}}";

public static string PreprocessorIfAroundFullFunctionDefinition(string define) =>
@$"
partial class Test
{{
#if {define}
[System.Runtime.InteropServices.GeneratedDllImport(""DoesNotExist"")]
public static partial int Method(
int p,
in int pIn,
out int pOut);
#endif
}}";

public static string PreprocessorIfAroundFullFunctionDefinitionWithFollowingFunction(string define) =>
@$"
using System.Runtime.InteropServices;
partial class Test
{{
#if {define}
[GeneratedDllImport(""DoesNotExist"")]
public static partial int Method(
int p,
in int pIn,
out int pOut);
#endif
public static int Method2(
SafeHandle p) => throw null;
}}";

public static string PreprocessorIfAfterAttributeAroundFunction(string define) =>
@$"
using System.Runtime.InteropServices;
partial class Test
{{
[GeneratedDllImport(""DoesNotExist"")]
#if {define}
public static partial int Method(
int p,
in int pIn,
out int pOut);
#else
public static partial int Method2(
int p,
in int pIn,
out int pOut);
#endif
}}";

public static string PreprocessorIfAfterAttributeAroundFunctionAdditionalFunctionAfter(string define) =>
@$"
using System.Runtime.InteropServices;
partial class Test
{{
[GeneratedDllImport(""DoesNotExist"")]
#if {define}
public static partial int Method(
int p,
in int pIn,
out int pOut);
#else
public static partial int Method2(
int p,
in int pIn,
out int pOut);
#endif
public static int Foo() => throw null;
}}";
}
}
26 changes: 26 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,32 @@ public async Task ValidateSnippets(string source)
var newCompDiags = newComp.GetDiagnostics();
Assert.Empty(newCompDiags);
}

public static IEnumerable<object[]> CodeSnippetsToCompileWithPreprocessorSymbols()
{
yield return new object[] { CodeSnippets.PreprocessorIfAroundFullFunctionDefinition("Foo"), new string[] { "Foo" } };
yield return new object[] { CodeSnippets.PreprocessorIfAroundFullFunctionDefinition("Foo"), Array.Empty<string>() };
yield return new object[] { CodeSnippets.PreprocessorIfAroundFullFunctionDefinitionWithFollowingFunction("Foo"), new string[] { "Foo" } };
yield return new object[] { CodeSnippets.PreprocessorIfAroundFullFunctionDefinitionWithFollowingFunction("Foo"), Array.Empty<string>() };
yield return new object[] { CodeSnippets.PreprocessorIfAfterAttributeAroundFunction("Foo"), new string[] { "Foo" } };
yield return new object[] { CodeSnippets.PreprocessorIfAfterAttributeAroundFunction("Foo"), Array.Empty<string>() };
yield return new object[] { CodeSnippets.PreprocessorIfAfterAttributeAroundFunctionAdditionalFunctionAfter("Foo"), new string[] { "Foo" } };
yield return new object[] { CodeSnippets.PreprocessorIfAfterAttributeAroundFunctionAdditionalFunctionAfter("Foo"), Array.Empty<string>() };
}

[Theory]
[MemberData(nameof(CodeSnippetsToCompileWithPreprocessorSymbols))]
public async Task ValidateSnippetsWithPreprocessorDefintions(string source, IEnumerable<string> preprocessorSymbols)
{
Compilation comp = await TestUtils.CreateCompilation(source, preprocessorSymbols: preprocessorSymbols);
TestUtils.AssertPreSourceGeneratorCompilation(comp);

var newComp = TestUtils.RunGenerators(comp, out var generatorDiags, new Microsoft.Interop.DllImportGenerator());
Assert.Empty(generatorDiags);

var newCompDiags = newComp.GetDiagnostics();
Assert.Empty(newCompDiags);
}

public static IEnumerable<object[]> CodeSnippetsToCompileWithForwarder()
{
Expand Down
4 changes: 2 additions & 2 deletions DllImportGenerator/DllImportGenerator.UnitTests/TestUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ public static void AssertPreSourceGeneratorCompilation(Compilation comp)
/// <param name="outputKind">Output type</param>
/// <param name="allowUnsafe">Whether or not use of the unsafe keyword should be allowed</param>
/// <returns>The resulting compilation</returns>
public static async Task<Compilation> CreateCompilation(string source, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true)
public static async Task<Compilation> CreateCompilation(string source, OutputKind outputKind = OutputKind.DynamicallyLinkedLibrary, bool allowUnsafe = true, IEnumerable<string>? preprocessorSymbols = null)
{
var (mdRefs, ancillary) = GetReferenceAssemblies();

return CSharpCompilation.Create("compilation",
new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview)) },
new[] { CSharpSyntaxTree.ParseText(source, new CSharpParseOptions(LanguageVersion.Preview, preprocessorSymbols: preprocessorSymbols)) },
(await mdRefs.ResolveAsync(LanguageNames.CSharp, CancellationToken.None)).Add(ancillary),
new CSharpCompilationOptions(outputKind, allowUnsafe: allowUnsafe));
}
Expand Down
29 changes: 24 additions & 5 deletions DllImportGenerator/DllImportGenerator/DllImportGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,25 @@ public void Initialize(GeneratorInitializationContext context)
context.RegisterForSyntaxNotifications(() => new SyntaxReceiver());
}

private SyntaxTokenList StripTriviaFromModifiers(SyntaxTokenList tokenList)
{
SyntaxToken[] strippedTokens = new SyntaxToken[tokenList.Count];
for (int i = 0; i < tokenList.Count; i++)
{
strippedTokens[i] = tokenList[i].WithoutTrivia();
}
return new SyntaxTokenList(strippedTokens);
}

private TypeDeclarationSyntax CreateTypeDeclarationWithoutTrivia(TypeDeclarationSyntax typeDeclaration)
{
return TypeDeclaration(
typeDeclaration.Kind(),
typeDeclaration.Identifier)
.WithTypeParameterList(typeDeclaration.TypeParameterList)
.WithModifiers(typeDeclaration.Modifiers);
}

private void PrintGeneratedSource(
StringBuilder builder,
MethodDeclarationSyntax userDeclaredMethod,
Expand All @@ -137,27 +156,27 @@ private void PrintGeneratedSource(
// Create stub function
var stubMethod = MethodDeclaration(stub.StubReturnType, userDeclaredMethod.Identifier)
.AddAttributeLists(stub.AdditionalAttributes)
.WithModifiers(userDeclaredMethod.Modifiers)
.WithModifiers(StripTriviaFromModifiers(userDeclaredMethod.Modifiers))
.WithParameterList(ParameterList(SeparatedList(stub.StubParameters)))
.WithBody(stub.StubCode);

// Create the DllImport declaration.
var dllImport = stub.DllImportDeclaration.AddAttributeLists(
AttributeList(
SingletonSeparatedList<AttributeSyntax>(dllImportAttr)));
SingletonSeparatedList(dllImportAttr)));

// Stub should have at least one containing type
Debug.Assert(stub.StubContainingTypes.Any());

// Add stub function and DllImport declaration to the first (innermost) containing
MemberDeclarationSyntax containingType = stub.StubContainingTypes.First()
MemberDeclarationSyntax containingType = CreateTypeDeclarationWithoutTrivia(stub.StubContainingTypes.First())
.AddMembers(stubMethod, dllImport);

// Add type to the remaining containing types (skipping the first which was handled above)
foreach (var typeDecl in stub.StubContainingTypes.Skip(1))
{
containingType = typeDecl.WithMembers(
SingletonList<MemberDeclarationSyntax>(containingType));
containingType = CreateTypeDeclarationWithoutTrivia(typeDecl)
.WithMembers(SingletonList(containingType));
}

MemberDeclarationSyntax toPrint = containingType;
Expand Down

0 comments on commit 600ff5a

Please sign in to comment.