diff --git a/Directory.Packages.props b/Directory.Packages.props
index d3c669aa..ecb877a3 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -28,7 +28,7 @@
-
+
diff --git a/src/Microsoft.Windows.CsWin32/Generator.Com.cs b/src/Microsoft.Windows.CsWin32/Generator.Com.cs
index 9584f577..122b4ef4 100644
--- a/src/Microsoft.Windows.CsWin32/Generator.Com.cs
+++ b/src/Microsoft.Windows.CsWin32/Generator.Com.cs
@@ -99,8 +99,9 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
var vtblMembers = new List();
TypeSyntaxSettings typeSettings = this.comSignatureTypeSettings;
IdentifierNameSyntax pThisLocal = IdentifierName("pThis");
- ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
+ ParameterSyntax? ccwThisParameter = this.canUseUnmanagedCallersOnlyAttribute && !this.options.AllowMarshaling && originalIfaceName != "IUnknown" && originalIfaceName != "IDispatch" && !this.IsNonCOMInterface(typeDef) ? Parameter(pThisLocal.Identifier).WithType(PointerType(ifaceName).WithTrailingTrivia(Space)) : null;
List ccwMethodsToSkip = new();
+ List ccwEntrypointMethods = new();
IdentifierNameSyntax vtblParamName = IdentifierName("vtable");
BlockSyntax populateVTableBody = Block();
IdentifierNameSyntax objectLocal = IdentifierName("__object");
@@ -119,7 +120,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
// We do *not* emit CCW methods for IUnknown, because those are provided by ComWrappers.
if (ccwThisParameter is not null &&
- (qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch")))
+ (qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IUnknown") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IDispatch") || qualifiedBaseType.Reader.StringComparer.Equals(baseType.Name, "IInspectable")))
{
ccwMethodsToSkip.AddRange(methodsThisType);
}
@@ -132,6 +133,8 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)),
originalIfaceName,
allowNonConsecutiveAccessors: true);
+ ISet? ifaceDeclaredProperties = ccwThisParameter is not null ? this.GetDeclarableProperties(allMethods.Select(qh => qh.Reader.GetMethodDefinition(qh.MethodHandle)), originalIfaceName, allowNonConsecutiveAccessors: false) : null;
+
foreach (QualifiedMethodDefinitionHandle methodDefHandle in allMethods)
{
methodCounter++;
@@ -147,6 +150,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
ParameterListSyntax parameterList = methodDefinition.Generator.CreateParameterList(methodDefinition.Method, signature, typeSettings);
ParameterListSyntax parameterListPreserveSig = parameterList; // preserve a copy that has no mutations.
+ bool requiresMarshaling = parameterList.Parameters.Any(p => p.AttributeLists.SelectMany(al => al.Attributes).Any(a => a.Name is IdentifierNameSyntax { Identifier.ValueText: "MarshalAs" }) || p.Modifiers.Any(SyntaxKind.RefKeyword) || p.Modifiers.Any(SyntaxKind.OutKeyword) || p.Modifiers.Any(SyntaxKind.InKeyword));
FunctionPointerParameterListSyntax funcPtrParameters = FunctionPointerParameterList()
.AddParameters(FunctionPointerParameter(PointerType(ifaceName)))
.AddParameters(parameterList.Parameters.Select(p => FunctionPointerParameter(p.Type!).WithModifiers(p.Modifiers)).ToArray())
@@ -174,7 +178,7 @@ private TypeDeclarationSyntax DeclareInterfaceAsStruct(TypeDefinitionHandle type
.AddArguments(Argument(pThisLocal))
.AddArguments(parameterList.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText)).WithRefKindKeyword(p.Modifiers.Count > 0 ? p.Modifiers[0] : default)).ToArray())));
- MemberDeclarationSyntax propertyOrMethod;
+ MemberDeclarationSyntax? propertyOrMethod;
MethodDeclarationSyntax? methodDeclaration = null;
// We can declare this method as a property accessor if it represents a property.
@@ -212,18 +216,6 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
resultLocalDeclaration,
vtblInvocationStatement,
returnStatement)).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));
-
- if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
- {
- //// *inputArg = @object.Property;
- StatementSyntax propertyGet = ExpressionStatement(AssignmentExpression(
- SyntaxKind.SimpleAssignmentExpression,
- PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)),
- MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName)));
- this.TryGenerateConstantOrThrow("S_OK");
- AddCcwThunk(propertyGet, returnSOK);
- }
-
break;
case SyntaxKind.SetAccessorDeclaration:
// vtblInvoke(pThis, value).ThrowOnFailure();
@@ -233,18 +225,6 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
VariableDeclaration(PointerType(ifaceName)).AddVariables(
VariableDeclarator(pThisLocal.Identifier).WithInitializer(EqualsValueClause(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, ThisExpression())))),
vtblInvocationStatement).WithFixedKeyword(TokenWithSpace(SyntaxKind.FixedKeyword)));
-
- if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
- {
- //// @object.Property = inputArg;
- StatementSyntax propertySet = ExpressionStatement(AssignmentExpression(
- SyntaxKind.SimpleAssignmentExpression,
- MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName),
- IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)));
- this.TryGenerateConstantOrThrow("S_OK");
- AddCcwThunk(propertySet, returnSOK);
- }
-
break;
default:
throw new NotSupportedException("Unsupported accessor kind: " + accessorKind);
@@ -258,7 +238,7 @@ StatementSyntax ThrowOnHRFailure(ExpressionSyntax hrExpression) => ExpressionSta
// Add the accessor to the existing property declaration.
PropertyDeclarationSyntax priorDeclaration = (PropertyDeclarationSyntax)members[priorPropertyDeclarationIndex];
members[priorPropertyDeclarationIndex] = priorDeclaration.WithAccessorList(priorDeclaration.AccessorList!.AddAccessors(accessor));
- continue;
+ propertyOrMethod = null;
}
else
{
@@ -359,8 +339,38 @@ StatementSyntax InvokeVtblAndThrow() => ExpressionStatement(InvocationExpression
propertyOrMethod = methodDeclaration;
members.AddRange(methodDefinition.Generator.DeclareFriendlyOverloads(methodDefinition.Method, methodDeclaration, IdentifierName(ifaceName.Identifier.ValueText), FriendlyOverloadOf.StructMethod, helperMethodsInStruct));
+ }
- if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
+ if (ccwThisParameter is not null && !ccwMethodsToSkip.Contains(methodDefHandle))
+ {
+ if (this.TryGetPropertyAccessorInfo(methodDefinition.Method, originalIfaceName, out propertyName, out accessorKind, out propertyType) &&
+ ifaceDeclaredProperties!.Contains(propertyName.Identifier.ValueText))
+ {
+ switch (accessorKind)
+ {
+ case SyntaxKind.GetAccessorDeclaration:
+ //// *inputArg = @object.Property;
+ StatementSyntax propertyGet = ExpressionStatement(AssignmentExpression(
+ SyntaxKind.SimpleAssignmentExpression,
+ PrefixUnaryExpression(SyntaxKind.PointerIndirectionExpression, IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)),
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName)));
+ this.TryGenerateConstantOrThrow("S_OK");
+ AddCcwThunk(propertyGet, returnSOK);
+ break;
+ case SyntaxKind.SetAccessorDeclaration:
+ //// @object.Property = inputArg;
+ StatementSyntax propertySet = ExpressionStatement(AssignmentExpression(
+ SyntaxKind.SimpleAssignmentExpression,
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, objectLocal, propertyName),
+ IdentifierName(parameterListPreserveSig.Parameters.Last().Identifier.ValueText)));
+ this.TryGenerateConstantOrThrow("S_OK");
+ AddCcwThunk(propertySet, returnSOK);
+ break;
+ default:
+ throw new NotSupportedException("Unsupported accessor kind: " + accessorKind);
+ }
+ }
+ else
{
// Prepare the args for the thunk call. The Interface we thunk into *always* uses PreserveSig, which is super convenient for us.
ArgumentListSyntax args = ArgumentList().AddArguments(parameterListPreserveSig.Parameters.Select(p => Argument(IdentifierName(p.Identifier.ValueText))).ToArray());
@@ -385,6 +395,20 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
return;
}
+ if (requiresMarshaling)
+ {
+ // Oops. This method requires marshaling, which isn't supported in a native-callable function.
+ // Abandon all efforts to add CCW support to this interface.
+ ccwThisParameter = null;
+ foreach (MethodDeclarationSyntax ccwEntrypointMethod in ccwEntrypointMethods)
+ {
+ members.Remove(ccwEntrypointMethod);
+ }
+
+ ccwEntrypointMethods.Clear();
+ return;
+ }
+
this.RequestComHelpers(context);
bool hrReturnType = returnTypePreserveSig is QualifiedNameSyntax { Right.Identifier.ValueText: "HRESULT" };
@@ -434,6 +458,7 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
ccwBody,
semicolonToken: default);
members.Add(ccwMethod);
+ ccwEntrypointMethods.Add(ccwMethod);
populateVTableBody = populateVTableBody.AddStatements(
ExpressionStatement(AssignmentExpression(
@@ -442,9 +467,12 @@ void AddCcwThunk(params StatementSyntax[] thunkInvokeAndReturn)
PrefixUnaryExpression(SyntaxKind.AddressOfExpression, SafeIdentifierName(methodName)))));
}
- // Add documentation if we can find it.
- propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod);
- members.Add(propertyOrMethod);
+ if (propertyOrMethod is not null)
+ {
+ // Add documentation if we can find it.
+ propertyOrMethod = this.AddApiDocumentation($"{ifaceName}.{methodName}", propertyOrMethod);
+ members.Add(propertyOrMethod);
+ }
}
// We expose the vtbl struct to support CCWs.
diff --git a/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs b/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
index 0366ffff..418e5c8f 100644
--- a/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
+++ b/src/Microsoft.Windows.CsWin32/SimpleSyntaxFactory.cs
@@ -5,27 +5,91 @@ namespace Microsoft.Windows.CsWin32;
internal static class SimpleSyntaxFactory
{
+ ///
+ /// C# keywords that must be escaped or changed when they appear as identifiers from metadata.
+ ///
+ ///
+ /// This list comes from this documentation.
+ ///
internal static readonly HashSet CSharpKeywords = new HashSet(StringComparer.Ordinal)
{
+ "abstract",
"as",
"base",
+ "bool",
+ "break",
+ "byte",
+ "case",
+ "catch",
+ "char",
"checked",
+ "class",
+ "const",
+ "continue",
"decimal",
+ "default",
+ "delegate",
+ "do",
+ "double",
+ "else",
+ "enum",
"event",
+ "explicit",
+ "extern",
+ "false",
+ "finally",
+ "fixed",
+ "float",
+ "for",
+ "foreach",
+ "goto",
+ "if",
+ "implicit",
"in",
- "is",
+ "int",
+ "interface",
"internal",
+ "is",
"lock",
+ "long",
+ "namespace",
+ "new",
+ "null",
"object",
+ "operator",
"out",
"override",
"params",
"private",
"protected",
"public",
+ "readonly",
"ref",
+ "return",
+ "sbyte",
+ "sealed",
+ "short",
+ "sizeof",
+ "stackalloc",
+ "static",
"string",
+ "struct",
+ "switch",
+ "this",
+ "throw",
+ "true",
+ "try",
+ "typeof",
+ "uint",
+ "ulong",
+ "unchecked",
+ "unsafe",
+ "ushort",
+ "using",
"virtual",
+ "void",
+ "volatile",
+ "while",
};
internal static readonly XmlTextSyntax DocCommentStart = XmlText(" ").WithLeadingTrivia(DocumentationCommentExterior("///"));
diff --git a/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs b/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs
index caeba087..5934a311 100644
--- a/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs
+++ b/test/GenerationSandbox.Unmarshalled.Tests/GeneratedForm.cs
@@ -28,6 +28,14 @@ private static unsafe void COMStructsPreserveSig()
o.MachineName = bstr;
}
+#if NET5_0_OR_GREATER
+ private static unsafe void IStream_GetsCCW()
+ {
+ IStream.Vtbl vtbl;
+ IStream.PopulateVTable(&vtbl);
+ }
+#endif
+
private static unsafe void IUnknownGetsVtbl()
{
// WinForms needs the v-table to be declared for these base interfaces.
diff --git a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt
index 0e0b9e49..c59c92ce 100644
--- a/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt
+++ b/test/GenerationSandbox.Unmarshalled.Tests/NativeMethods.txt
@@ -1,2 +1,3 @@
-IPersistFile
-IEventSubscription
+IEventSubscription
+IPersistFile
+IStream
diff --git a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
index dcd21371..e767b4fc 100644
--- a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
+++ b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
@@ -240,6 +240,20 @@ public void MethodWithHRParameter()
this.AssertNoDiagnostics();
}
+ [Theory]
+ [InlineData("IVssCreateWriterMetadata")] // A non-COM compliant interface (since it doesn't derive from IUnknown).
+ [InlineData("IProtectionPolicyManagerInterop3")] // An IInspectable-derived interface.
+ [InlineData("ICompositionCapabilitiesInteropFactory")] // An interface with managed types.
+ [InlineData("IPicture")] // An interface with properties that cannot be represented as properties.
+ public void InterestingComInterfaces(string api)
+ {
+ this.compilation = this.starterCompilations["net6.0"];
+ this.generator = this.CreateGenerator(DefaultTestGeneratorOptions with { AllowMarshaling = false });
+ Assert.True(this.generator.TryGenerate(api, CancellationToken.None));
+ this.CollectGeneratedCode(this.generator);
+ this.AssertNoDiagnostics();
+ }
+
[Fact]
public void ComOutPtrTypedAsOutObject()
{
diff --git a/test/Microsoft.Windows.CsWin32.Tests/FullGenerationTests.cs b/test/Microsoft.Windows.CsWin32.Tests/FullGenerationTests.cs
index 7c7efa77..0bddf65f 100644
--- a/test/Microsoft.Windows.CsWin32.Tests/FullGenerationTests.cs
+++ b/test/Microsoft.Windows.CsWin32.Tests/FullGenerationTests.cs
@@ -20,36 +20,47 @@ public FullGenerationTests(ITestOutputHelper logger)
[Trait("TestCategory", "FailsInCloudTest")] // these take ~4GB of memory to run.
[Theory, PairwiseData]
- public void Everything(MarshalingOptions marshaling, bool useIntPtrForComOutPtr, [CombinatorialMemberData(nameof(AnyCpuArchitectures))] Platform platform)
+ public void Everything(
+ MarshalingOptions marshaling,
+ bool useIntPtrForComOutPtr,
+ [CombinatorialMemberData(nameof(AnyCpuArchitectures))] Platform platform,
+ [CombinatorialMemberData(nameof(TFMDataNoNetFx35))] string tfm)
{
- this.TestHelper(marshaling, useIntPtrForComOutPtr, platform, generator => generator.GenerateAll(CancellationToken.None));
+ this.TestHelper(marshaling, useIntPtrForComOutPtr, platform, tfm, generator => generator.GenerateAll(CancellationToken.None));
}
[Theory, PairwiseData]
- public void InteropTypes(MarshalingOptions marshaling, bool useIntPtrForComOutPtr)
+ public void InteropTypes(
+ MarshalingOptions marshaling,
+ bool useIntPtrForComOutPtr,
+ [CombinatorialMemberData(nameof(TFMDataNoNetFx35))] string tfm)
{
- this.TestHelper(marshaling, useIntPtrForComOutPtr, Platform.X64, generator => generator.GenerateAllInteropTypes(CancellationToken.None));
+ this.TestHelper(marshaling, useIntPtrForComOutPtr, Platform.X64, tfm, generator => generator.GenerateAllInteropTypes(CancellationToken.None));
}
[Fact]
public void Constants()
{
- this.TestHelper(marshaling: MarshalingOptions.FullMarshaling, useIntPtrForComOutPtr: false, Platform.X64, generator => generator.GenerateAllConstants(CancellationToken.None));
+ this.TestHelper(marshaling: MarshalingOptions.FullMarshaling, useIntPtrForComOutPtr: false, Platform.X64, DefaultTFM, generator => generator.GenerateAllConstants(CancellationToken.None));
}
[Theory, PairwiseData]
- public void ExternMethods(MarshalingOptions marshaling, bool useIntPtrForComOutPtr, [CombinatorialMemberData(nameof(SpecificCpuArchitectures))] Platform platform)
+ public void ExternMethods(
+ MarshalingOptions marshaling,
+ bool useIntPtrForComOutPtr,
+ [CombinatorialMemberData(nameof(SpecificCpuArchitectures))] Platform platform,
+ [CombinatorialMemberData(nameof(TFMDataNoNetFx35))] string tfm)
{
- this.TestHelper(marshaling, useIntPtrForComOutPtr, platform, generator => generator.GenerateAllExternMethods(CancellationToken.None));
+ this.TestHelper(marshaling, useIntPtrForComOutPtr, platform, tfm, generator => generator.GenerateAllExternMethods(CancellationToken.None));
}
[Fact]
public void Macros()
{
- this.TestHelper(marshaling: MarshalingOptions.FullMarshaling, useIntPtrForComOutPtr: false, Platform.X64, generator => generator.GenerateAllMacros(CancellationToken.None));
+ this.TestHelper(marshaling: MarshalingOptions.FullMarshaling, useIntPtrForComOutPtr: false, Platform.X64, DefaultTFM, generator => generator.GenerateAllMacros(CancellationToken.None));
}
- private void TestHelper(MarshalingOptions marshaling, bool useIntPtrForComOutPtr, Platform platform, Action generationCommands)
+ private void TestHelper(MarshalingOptions marshaling, bool useIntPtrForComOutPtr, Platform platform, string targetFramework, Action generationCommands)
{
var generatorOptions = new GeneratorOptions
{
@@ -57,6 +68,7 @@ private void TestHelper(MarshalingOptions marshaling, bool useIntPtrForComOutPtr
UseSafeHandles = marshaling == MarshalingOptions.FullMarshaling,
ComInterop = new() { UseIntPtrForComOutPointers = useIntPtrForComOutPtr },
};
+ this.compilation = this.starterCompilations[targetFramework];
this.compilation = this.compilation.WithOptions(this.compilation.Options.WithPlatform(platform));
long? lastHeapSize = null;
diff --git a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTestBase.cs b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTestBase.cs
index 23f3cbcb..9218c3e3 100644
--- a/test/Microsoft.Windows.CsWin32.Tests/GeneratorTestBase.cs
+++ b/test/Microsoft.Windows.CsWin32.Tests/GeneratorTestBase.cs
@@ -3,6 +3,7 @@
public abstract class GeneratorTestBase : IDisposable, IAsyncLifetime
{
+ protected const string DefaultTFM = "netstandard2.0";
protected static readonly GeneratorOptions DefaultTestGeneratorOptions = new GeneratorOptions { EmitSingleFile = true };
protected static readonly string FileSeparator = new string('=', 140);
protected static readonly string MetadataPath = Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location!)!, "Windows.Win32.winmd");
@@ -44,12 +45,14 @@ public enum MarshalingOptions
new object[] { "net6.0" },
};
- public static IEnumerable