diff --git a/eng/Versions.props b/eng/Versions.props
index c97134bff4f7c..feab8f380f2fc 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -206,7 +206,7 @@
2.45.0
- 1.1.2-beta1.22403.2
+ 1.1.2-beta1.23205.1
7.0.0-preview-20221010.1
diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
index 4ca18b10a5e5f..77b40b5dc3d2a 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
@@ -550,7 +550,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
{
if (baseInterface is not null)
{
- return Diagnostic.Create(GeneratorDiagnostics.MultipleComInterfaceBaseTypesAttribute, syntax.Identifier.GetLocation(), type.ToDisplayString());
+ return Diagnostic.Create(GeneratorDiagnostics.MultipleComInterfaceBaseTypes, syntax.Identifier.GetLocation(), type.ToDisplayString());
}
baseInterface = implemented;
}
diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs
index 5d563b290942f..9df719d19d51f 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/GeneratorDiagnostics.cs
@@ -187,7 +187,7 @@ public class Ids
isEnabledByDefault: true,
description: GetResourceString(nameof(SR.InvalidGeneratedComInterfaceAttributeUsageDescription)));
- public static readonly DiagnosticDescriptor MultipleComInterfaceBaseTypesAttribute =
+ public static readonly DiagnosticDescriptor MultipleComInterfaceBaseTypes =
new DiagnosticDescriptor(
Ids.MultipleComInterfaceBaseTypes,
GetResourceString(nameof(SR.MultipleComInterfaceBaseTypesTitle)),
diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs
index 3792231bfadd7..d6c6564cb2afe 100644
--- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs
+++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ManagedToNativeVTableMethodGenerator.cs
@@ -238,7 +238,7 @@ private ParenthesizedExpressionSyntax CreateFunctionPointerExpression(
{
List functionPointerParameters = new();
var (paramList, retType, _) = _marshallers.GenerateTargetMethodSignatureData(_context);
- functionPointerParameters.AddRange(paramList.Parameters.Select(p => FunctionPointerParameter(p.Type)));
+ functionPointerParameters.AddRange(paramList.Parameters.Select(p => FunctionPointerParameter(attributeLists: default, p.Modifiers, p.Type)));
functionPointerParameters.Add(FunctionPointerParameter(retType));
// ((delegate* unmanaged<...>))
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs
index a5e44816cb562..627d5f860eaf5 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs
@@ -2,16 +2,16 @@
// The .NET Foundation licenses this file to you under the MIT license.
using System;
-using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata;
-using System.Text;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Operations;
-using Microsoft.Interop.UnitTests;
+using Microsoft.CodeAnalysis.Testing;
using Xunit;
+using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier;
+
namespace ComInterfaceGenerator.Unit.Tests
{
public class CallingConventionForwarding
@@ -31,16 +31,12 @@ partial interface INativeAPI : IUnmanagedInterfaceType
void Method();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- // Allow the Native nested type name to be missing in the pre-source-generator compilation
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
-
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.VtableIndexStubGenerator());
- var signature = await FindFunctionPointerInvocationSignature(newComp, "INativeAPI", "Method");
-
- Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
- Assert.Empty(signature.UnmanagedCallingConventionTypes);
+ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) =>
+ {
+ Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
+ Assert.Empty(signature.UnmanagedCallingConventionTypes);
+ });
}
[Fact]
@@ -59,16 +55,12 @@ partial interface INativeAPI : IUnmanagedInterfaceType
void Method();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- // Allow the Native nested type name to be missing in the pre-source-generator compilation
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.VtableIndexStubGenerator());
-
- var signature = await FindFunctionPointerInvocationSignature(newComp, "INativeAPI", "Method");
-
- Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
- Assert.Equal(newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"), Assert.Single(signature.UnmanagedCallingConventionTypes), SymbolEqualityComparer.Default);
+ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+ {
+ Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
+ Assert.Equal(newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"), Assert.Single(signature.UnmanagedCallingConventionTypes), SymbolEqualityComparer.Default);
+ });
}
[Fact]
@@ -87,22 +79,19 @@ partial interface INativeAPI : IUnmanagedInterfaceType
void Method();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- // Allow the Native nested type name to be missing in the pre-source-generator compilation
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.VtableIndexStubGenerator());
-
- var signature = await FindFunctionPointerInvocationSignature(newComp, "INativeAPI", "Method");
-
- Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
- Assert.Empty(signature.UnmanagedCallingConventionTypes);
+ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
+ {
+ Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
+ Assert.Empty(signature.UnmanagedCallingConventionTypes);
+ });
}
[Fact]
public async Task SimpleUnmanagedCallConvAttributeForwarded()
{
string source = $$"""
+ using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
@@ -115,22 +104,19 @@ partial interface INativeAPI : IUnmanagedInterfaceType
void Method();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- // Allow the Native nested type name to be missing in the pre-source-generator compilation
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
-
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.VtableIndexStubGenerator());
- var signature = await FindFunctionPointerInvocationSignature(newComp, "INativeAPI", "Method");
-
- Assert.Equal(SignatureCallingConvention.CDecl, signature.CallingConvention);
- Assert.Empty(signature.UnmanagedCallingConventionTypes);
+ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
+ {
+ Assert.Equal(SignatureCallingConvention.CDecl, signature.CallingConvention);
+ Assert.Empty(signature.UnmanagedCallingConventionTypes);
+ });
}
[Fact]
public async Task ComplexUnmanagedCallConvAttributeForwarded()
{
string source = $$"""
+ using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
@@ -143,28 +129,25 @@ partial interface INativeAPI : IUnmanagedInterfaceType
void Method();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- // Allow the Native nested type name to be missing in the pre-source-generator compilation
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
-
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.VtableIndexStubGenerator());
- var signature = await FindFunctionPointerInvocationSignature(newComp, "INativeAPI", "Method");
-
- Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
- Assert.Equal(new[]
+ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
{
- newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvCdecl"),
- newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvMemberFunction"),
- },
- signature.UnmanagedCallingConventionTypes,
- SymbolEqualityComparer.Default);
+ Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
+ Assert.Equal(new[]
+ {
+ newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvCdecl"),
+ newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvMemberFunction"),
+ },
+ signature.UnmanagedCallingConventionTypes,
+ SymbolEqualityComparer.Default);
+ });
}
[Fact]
public async Task ComplexUnmanagedCallConvAttributeWithSuppressGCTransitionForwarded()
{
string source = $$"""
+ using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
@@ -178,41 +161,67 @@ partial interface INativeAPI : IUnmanagedInterfaceType
void Method();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- // Allow the Native nested type name to be missing in the pre-source-generator compilation
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.VtableIndexStubGenerator());
-
- var signature = await FindFunctionPointerInvocationSignature(newComp, "INativeAPI", "Method");
+ await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+ {
+ Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
+ Assert.Equal(new[]
+ {
+ newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"),
+ newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvCdecl"),
+ newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvMemberFunction"),
+ },
+ signature.UnmanagedCallingConventionTypes,
+ SymbolEqualityComparer.Default);
+ });
+ }
- Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
- Assert.Equal(new[]
+ private static async Task VerifySourceGeneratorAsync(string source, string interfaceName, string methodName, Action signatureValidator)
+ {
+ CallingConventionForwardingTest test = new(interfaceName, methodName, signatureValidator)
{
- newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"),
- newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvCdecl"),
- newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvMemberFunction"),
- },
- signature.UnmanagedCallingConventionTypes,
- SymbolEqualityComparer.Default);
+ TestCode = source,
+ TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
+ };
+
+ await test.RunAsync();
}
- private static async Task FindFunctionPointerInvocationSignature(Compilation compilation, string userDefinedInterfaceName, string methodName)
+ class CallingConventionForwardingTest : VerifyCS.Test
{
- INamedTypeSymbol? userDefinedInterface = compilation.Assembly.GetTypeByMetadataName(userDefinedInterfaceName);
- Assert.NotNull(userDefinedInterface);
+ private readonly Action _signatureValidator;
+ private readonly string _interfaceName;
+ private readonly string _methodName;
+
+ public CallingConventionForwardingTest(string interfaceName, string methodName, Action signatureValidator)
+ : base(referenceAncillaryInterop: true)
+ {
+ _signatureValidator = signatureValidator;
+ _interfaceName = interfaceName;
+ _methodName = methodName;
+ }
+
+ protected override void VerifyFinalCompilation(Compilation compilation)
+ {
+ _signatureValidator(compilation, FindFunctionPointerInvocationSignature(compilation));
+ }
+ private IMethodSymbol FindFunctionPointerInvocationSignature(Compilation compilation)
+ {
+ INamedTypeSymbol? userDefinedInterface = compilation.Assembly.GetTypeByMetadataName(_interfaceName);
+ Assert.NotNull(userDefinedInterface);
- INamedTypeSymbol generatedInterfaceImplementation = Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
+ INamedTypeSymbol generatedInterfaceImplementation = Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
- IMethodSymbol methodImplementation = Assert.Single(generatedInterfaceImplementation.GetMembers($"global::{userDefinedInterfaceName}.{methodName}").OfType());
+ IMethodSymbol methodImplementation = Assert.Single(generatedInterfaceImplementation.GetMembers($"global::{_interfaceName}.{_methodName}").OfType());
- SyntaxNode emittedImplementationSyntax = await methodImplementation.DeclaringSyntaxReferences[0].GetSyntaxAsync();
+ SyntaxNode emittedImplementationSyntax = methodImplementation.DeclaringSyntaxReferences[0].GetSyntax();
- SemanticModel model = compilation.GetSemanticModel(emittedImplementationSyntax.SyntaxTree);
+ SemanticModel model = compilation.GetSemanticModel(emittedImplementationSyntax.SyntaxTree);
- IOperation body = model.GetOperation(emittedImplementationSyntax)!;
+ IOperation body = model.GetOperation(emittedImplementationSyntax)!;
- return Assert.Single(body.Descendants().OfType()).GetFunctionPointerSignature();
+ return Assert.Single(body.Descendants().OfType()).GetFunctionPointerSignature();
+ }
}
}
}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs
index 459fd323c6319..9a1ab29d00219 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs
@@ -126,7 +126,7 @@ public string BasicParametersAndModifiers(string typeName, string methodModifier
partial interface INativeAPI
{
{{VirtualMethodIndex(0)}}
- {{methodModifiers}} {{typeName}} Method({{typeName}} value, in {{typeName}} inValue, ref {{typeName}} refValue, out {{typeName}} outValue);
+ {{methodModifiers}} {{typeName}} {|#0:Method|}({{typeName}} {|#1:value|}, in {{typeName}} {|#2:inValue|}, ref {{typeName}} {|#3:refValue|}, out {{typeName}} {|#4:outValue|});
}
{{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
""";
@@ -277,7 +277,7 @@ partial interface IOtherComInterface
void MethodA();
}
{{GeneratedComInterface}}
- partial interface IComInterface2 : IComInterface, IOtherComInterface
+ partial interface {|#0:IComInterface2|} : IComInterface, IOtherComInterface
{
void Method2();
}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs
index 2bb8d5e187ea3..cc3dc85bbe5e4 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs
@@ -5,9 +5,11 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
-using Microsoft.Interop.UnitTests;
+using Microsoft.CodeAnalysis.Testing;
using Xunit;
+using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier;
+
namespace ComInterfaceGenerator.Unit.Tests
{
public class ComClassGeneratorOutputShape
@@ -27,15 +29,8 @@ partial interface INativeAPI
[GeneratedComClass]
partial class C : INativeAPI {}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
-
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator());
- TestUtils.AssertPostSourceGeneratorCompilation(newComp);
- // We'll create one syntax tree for the new interface.
- Assert.Equal(comp.SyntaxTrees.Count() + 1, newComp.SyntaxTrees.Count());
- VerifyShape(newComp, "C");
+ await VerifySourceGeneratorAsync(source, "C");
}
[Fact]
@@ -69,36 +64,56 @@ partial class E : C
{
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator());
- TestUtils.AssertPostSourceGeneratorCompilation(newComp);
- // We'll create one syntax tree per user-defined interface.
- Assert.Equal(comp.SyntaxTrees.Count() + 3, newComp.SyntaxTrees.Count());
+ await VerifySourceGeneratorAsync(source, "C", "D", "E");
+ }
+
+ private static async Task VerifySourceGeneratorAsync(string source, params string[] typeNames)
+ {
+ GeneratedShapeTest test = new(typeNames)
+ {
+ TestCode = source,
+ TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
+ };
- VerifyShape(newComp, "C");
- VerifyShape(newComp, "D");
- VerifyShape(newComp, "E");
+ await test.RunAsync();
}
- private static void VerifyShape(Compilation comp, string userDefinedClassMetadataName)
+ class GeneratedShapeTest : VerifyCS.Test
{
- INamedTypeSymbol? userDefinedClass = comp.Assembly.GetTypeByMetadataName(userDefinedClassMetadataName);
- Assert.NotNull(userDefinedClass);
+ private readonly string[] _typeNames;
- INamedTypeSymbol? comExposedClassAttribute = comp.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute`1");
+ public GeneratedShapeTest(params string[] typeNames)
+ :base(referenceAncillaryInterop: false)
+ {
+ _typeNames = typeNames;
+ }
- Assert.NotNull(comExposedClassAttribute);
+ protected override void VerifyFinalCompilation(Compilation compilation)
+ {
+ // Generate one source file per attributed interface.
+ Assert.Equal(TestState.Sources.Count + _typeNames.Length, compilation.SyntaxTrees.Count());
+ Assert.All(_typeNames, name => VerifyShape(compilation, name));
+ }
- AttributeData iUnknownDerivedAttribute = Assert.Single(
- userDefinedClass.GetAttributes(),
- attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, comExposedClassAttribute));
+ private static void VerifyShape(Compilation comp, string userDefinedClassMetadataName)
+ {
+ INamedTypeSymbol? userDefinedClass = comp.Assembly.GetTypeByMetadataName(userDefinedClassMetadataName);
+ Assert.NotNull(userDefinedClass);
- Assert.Collection(Assert.IsAssignableFrom(iUnknownDerivedAttribute.AttributeClass).TypeArguments,
- infoType =>
- {
- Assert.True(Assert.IsAssignableFrom(infoType).IsFileLocal);
- });
+ INamedTypeSymbol? comExposedClassAttribute = comp.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute`1");
+
+ Assert.NotNull(comExposedClassAttribute);
+
+ AttributeData iUnknownDerivedAttribute = Assert.Single(
+ userDefinedClass.GetAttributes(),
+ attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, comExposedClassAttribute));
+
+ Assert.Collection(Assert.IsAssignableFrom(iUnknownDerivedAttribute.AttributeClass).TypeArguments,
+ infoType =>
+ {
+ Assert.True(Assert.IsAssignableFrom(infoType).IsFileLocal);
+ });
+ }
}
}
}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj
index 64e7dd10a070d..6766f8e39fb05 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGenerator.Unit.Tests.csproj
@@ -1,4 +1,4 @@
-
+
$(NetCoreAppCurrent)
@@ -27,6 +27,8 @@
Link="Verifiers\CSharpAnalyzerVerifier.cs"/>
+
@@ -35,6 +37,7 @@
+
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs
index 4b1eafc9c8240..d9063c07a7860 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComInterfaceGeneratorOutputShape.cs
@@ -8,12 +8,16 @@
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
+using Microsoft;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
+using Microsoft.CodeAnalysis.Testing;
using Microsoft.Interop.UnitTests;
using Xunit;
+using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier;
+
namespace ComInterfaceGenerator.Unit.Tests
{
public class ComInterfaceGeneratorOutputShape
@@ -32,15 +36,8 @@ partial interface INativeAPI
void Method2();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
-
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComInterfaceGenerator());
- TestUtils.AssertPostSourceGeneratorCompilation(newComp);
- // We'll create one syntax tree for the new interface.
- Assert.Equal(comp.SyntaxTrees.Count() + 1, newComp.SyntaxTrees.Count());
- VerifyShape(newComp, "INativeAPI");
+ await VerifySourceGeneratorAsync(source, "INativeAPI");
}
[Fact]
@@ -63,16 +60,8 @@ partial interface J
void Method2();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
-
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComInterfaceGenerator());
- TestUtils.AssertPostSourceGeneratorCompilation(newComp);
- // We'll create one syntax tree per user-defined interface.
- Assert.Equal(comp.SyntaxTrees.Count() + 2, newComp.SyntaxTrees.Count());
- VerifyShape(newComp, "I");
- VerifyShape(newComp, "J");
+ await VerifySourceGeneratorAsync(source, "I", "J");
}
[Fact]
@@ -99,17 +88,8 @@ partial interface J
void Method2();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComInterfaceGenerator());
- TestUtils.AssertPostSourceGeneratorCompilation(newComp);
- // We'll create one syntax tree per user-defined interface.
- Assert.Equal(comp.SyntaxTrees.Count() + 3, newComp.SyntaxTrees.Count());
-
- VerifyShape(newComp, "I");
- VerifyShape(newComp, "Empty");
- VerifyShape(newComp, "J");
+ await VerifySourceGeneratorAsync(source, "I", "Empty", "J");
}
[Fact]
@@ -132,47 +112,67 @@ partial interface J : I
void MethodB();
}
""";
- Compilation comp = await TestUtils.CreateCompilation(source);
- TestUtils.AssertPreSourceGeneratorCompilation(comp);
-
- var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComInterfaceGenerator());
- TestUtils.AssertPostSourceGeneratorCompilation(newComp);
- // We'll create one syntax tree per user-defined interface.
- Assert.Equal(comp.SyntaxTrees.Count() + 2, newComp.SyntaxTrees.Count());
- VerifyShape(newComp, "I");
- VerifyShape(newComp, "J");
+ await VerifySourceGeneratorAsync(source, "I", "J");
}
- private static void VerifyShape(Compilation comp, string userDefinedInterfaceMetadataName)
+ private static async Task VerifySourceGeneratorAsync(string source, params string[] typeNames)
{
- INamedTypeSymbol? userDefinedInterface = comp.Assembly.GetTypeByMetadataName(userDefinedInterfaceMetadataName);
- Assert.NotNull(userDefinedInterface);
-
- INamedTypeSymbol? iUnknownDerivedAttributeType = comp.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute`2");
+ GeneratedShapeTest test = new(typeNames)
+ {
+ TestCode = source,
+ TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
+ };
- Assert.NotNull(iUnknownDerivedAttributeType);
-
- AttributeData iUnknownDerivedAttribute = Assert.Single(
- userDefinedInterface.GetAttributes(),
- attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, iUnknownDerivedAttributeType));
-
- Assert.Collection(Assert.IsAssignableFrom(iUnknownDerivedAttribute.AttributeClass).TypeArguments,
- infoType =>
- {
- Assert.True(Assert.IsAssignableFrom(infoType).IsFileLocal);
- },
- implementationType =>
- {
- Assert.True(Assert.IsAssignableFrom(implementationType).IsFileLocal);
- Assert.Contains(userDefinedInterface, implementationType.Interfaces, SymbolEqualityComparer.Default);
- Assert.Contains(implementationType.GetAttributes(), attr => attr.AttributeClass?.ToDisplayString() == typeof(DynamicInterfaceCastableImplementationAttribute).FullName);
- Assert.All(userDefinedInterface.GetMembers().OfType().Where(method => method.IsAbstract && !method.IsStatic),
- method =>
- {
- Assert.NotNull(implementationType.FindImplementationForInterfaceMember(method));
- });
- });
+ await test.RunAsync();
+ }
+ class GeneratedShapeTest : VerifyCS.Test
+ {
+ private readonly string[] _typeNames;
+
+ public GeneratedShapeTest(params string[] typeNames)
+ : base(referenceAncillaryInterop: false)
+ {
+ _typeNames = typeNames;
+ }
+
+ protected override void VerifyFinalCompilation(Compilation compilation)
+ {
+ // Generate one source file per attributed interface.
+ Assert.Equal(TestState.Sources.Count + _typeNames.Length, compilation.SyntaxTrees.Count());
+ Assert.All(_typeNames, name => VerifyShape(compilation, name));
+ }
+
+ private static void VerifyShape(Compilation comp, string userDefinedInterfaceMetadataName)
+ {
+ INamedTypeSymbol? userDefinedInterface = comp.Assembly.GetTypeByMetadataName(userDefinedInterfaceMetadataName);
+ Assert.NotNull(userDefinedInterface);
+
+ INamedTypeSymbol? iUnknownDerivedAttributeType = comp.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute`2");
+
+ Assert.NotNull(iUnknownDerivedAttributeType);
+
+ AttributeData iUnknownDerivedAttribute = Assert.Single(
+ userDefinedInterface.GetAttributes(),
+ attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, iUnknownDerivedAttributeType));
+
+ Assert.Collection(Assert.IsAssignableFrom(iUnknownDerivedAttribute.AttributeClass).TypeArguments,
+ infoType =>
+ {
+ Assert.True(Assert.IsAssignableFrom(infoType).IsFileLocal);
+ },
+ implementationType =>
+ {
+ Assert.True(Assert.IsAssignableFrom(implementationType).IsFileLocal);
+ Assert.Contains(userDefinedInterface, implementationType.Interfaces, SymbolEqualityComparer.Default);
+ Assert.Contains(implementationType.GetAttributes(), attr => attr.AttributeClass?.ToDisplayString() == typeof(DynamicInterfaceCastableImplementationAttribute).FullName);
+ Assert.All(userDefinedInterface.GetMembers().OfType().Where(method => method.IsAbstract && !method.IsStatic),
+ method =>
+ {
+ Assert.NotNull(implementationType.FindImplementationForInterfaceMember(method));
+ });
+ });
+ }
}
}
}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs
index d69ce9e2c7941..e4db5b789ec8a 100644
--- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs
+++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CompileFails.cs
@@ -11,9 +11,16 @@
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.Testing;
using Microsoft.Interop.UnitTests;
using Xunit;
+using System.Diagnostics;
+
+
+using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier;
+using Microsoft.Interop;
+
namespace ComInterfaceGenerator.Unit.Tests
{
public class CompileFails
@@ -27,31 +34,106 @@ public static IEnumerable