Skip to content

Commit

Permalink
Clean up some code after the interface inheritance work (dotnet#86347)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkoritzinsky authored May 16, 2023
1 parent 17bf9b8 commit 41274d0
Show file tree
Hide file tree
Showing 20 changed files with 418 additions and 568 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,23 @@

namespace Microsoft.Interop
{
public sealed partial class ComInterfaceGenerator
/// <summary>
/// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces).
/// </summary>
internal sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray<ComMethodContext> Methods)
{
// Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext
// Have a step that runs CalculateMethodStub on each of them.
// Call GroupMethodsByInterfaceForGeneration

/// <summary>
/// Represents an interface and all of the methods that need to be generated for it (methods declared on the interface and methods inherited from base interfaces).
/// COM methods that are declared on the attributed interface declaration.
/// </summary>
private sealed record ComInterfaceAndMethodsContext(ComInterfaceContext Interface, SequenceEqualImmutableArray<ComMethodContext> Methods)
{
// Change Calc all methods to return an ordered list of all the methods and the data in comInterfaceandMethodsContext
// Have a step that runs CalculateMethodStub on each of them.
// Call GroupMethodsByInterfaceForGeneration

/// <summary>
/// COM methods that are declared on the attributed interface declaration.
/// </summary>
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);
public IEnumerable<ComMethodContext> DeclaredMethods => Methods.Where(m => !m.IsInheritedMethod);

/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
}
/// <summary>
/// COM methods that are declared on an interface the interface inherits from.
/// </summary>
public IEnumerable<ComMethodContext> ShadowingMethods => Methods.Where(m => m.IsInheritedMethod);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,51 +7,48 @@

namespace Microsoft.Interop
{
public sealed partial class ComInterfaceGenerator
internal sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base)
{
private sealed record ComInterfaceContext(ComInterfaceInfo Info, ComInterfaceContext? Base)
/// <summary>
/// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
/// </summary>
public static ImmutableArray<ComInterfaceContext> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
{
/// <summary>
/// Takes a list of ComInterfaceInfo, and creates a list of ComInterfaceContext.
/// </summary>
public static ImmutableArray<ComInterfaceContext> GetContexts(ImmutableArray<ComInterfaceInfo> data, CancellationToken _)
Dictionary<string, ComInterfaceInfo> symbolToInterfaceInfoMap = new();
var accumulator = ImmutableArray.CreateBuilder<ComInterfaceContext>(data.Length);
foreach (var iface in data)
{
Dictionary<string, ComInterfaceInfo> symbolToInterfaceInfoMap = new();
var accumulator = ImmutableArray.CreateBuilder<ComInterfaceContext>(data.Length);
foreach (var iface in data)
symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
}
Dictionary<string, ComInterfaceContext> symbolToContextMap = new();

foreach (var iface in data)
{
accumulator.Add(AddContext(iface));
}
return accumulator.MoveToImmutable();

ComInterfaceContext AddContext(ComInterfaceInfo iface)
{
if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
{
symbolToInterfaceInfoMap.Add(iface.ThisInterfaceKey, iface);
return cachedValue;
}
Dictionary<string, ComInterfaceContext> symbolToContextMap = new();

foreach (var iface in data)
if (iface.BaseInterfaceKey is null)
{
accumulator.Add(AddContext(iface));
var baselessCtx = new ComInterfaceContext(iface, null);
symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx;
return baselessCtx;
}
return accumulator.MoveToImmutable();

ComInterfaceContext AddContext(ComInterfaceInfo iface)
if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext))
{
if (symbolToContextMap.TryGetValue(iface.ThisInterfaceKey, out var cachedValue))
{
return cachedValue;
}

if (iface.BaseInterfaceKey is null)
{
var baselessCtx = new ComInterfaceContext(iface, null);
symbolToContextMap[iface.ThisInterfaceKey] = baselessCtx;
return baselessCtx;
}

if (!symbolToContextMap.TryGetValue(iface.BaseInterfaceKey, out var baseContext))
{
baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]);
}
var ctx = new ComInterfaceContext(iface, baseContext);
symbolToContextMap[iface.ThisInterfaceKey] = ctx;
return ctx;
baseContext = AddContext(symbolToInterfaceInfoMap[iface.BaseInterfaceKey]);
}
var ctx = new ComInterfaceContext(iface, baseContext);
symbolToContextMap[iface.ThisInterfaceKey] = ctx;
return ctx;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Collections.Specialized;
using System.IO;
using System.Linq;
using System.Reflection;
Expand All @@ -19,14 +17,6 @@ namespace Microsoft.Interop
[Generator]
public sealed partial class ComInterfaceGenerator : IIncrementalGenerator
{
private sealed record class GeneratedStubCodeContext(
ManagedTypeInfo OriginalDefiningType,
ContainingSyntaxContext ContainingSyntaxContext,
SyntaxEquivalentNode<MethodDeclarationSyntax> Stub,
SequenceEqualImmutableArray<Diagnostic> Diagnostics) : GeneratedMethodContextBase(OriginalDefiningType, Diagnostics);

private sealed record SkippedStubContext(ManagedTypeInfo OriginalDefiningType) : GeneratedMethodContextBase(OriginalDefiningType, new(ImmutableArray<Diagnostic>.Empty));

public static class StepNames
{
public const string CalculateStubInformation = nameof(CalculateStubInformation);
Expand Down Expand Up @@ -103,11 +93,9 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
{
var ((data, symbolMap), env) = param;
return new ComMethodContext(
data.Method.OriginalDeclaringInterface,
data.TypeKeyOwner,
data.Method.MethodInfo,
data.Method.Index,
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.TypeKeyOwner.Info.Type, ct));
data.Method,
data.OwningInterface,
CalculateStubInformation(data.Method.MethodInfo.Syntax, symbolMap[data.Method.MethodInfo], data.Method.Index, env, data.OwningInterface.Info.Type, ct));
}).WithTrackingName(StepNames.CalculateStubInformation);

var interfaceAndMethodsContexts = comMethodContexts
Expand All @@ -117,7 +105,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

// Generate the code for the managed-to-unmanaged stubs and the diagnostics from code-generation.
context.RegisterDiagnostics(interfaceAndMethodsContexts
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetManagedToUnmanagedStub().Diagnostics)));
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.ManagedToUnmanagedStub.Diagnostics)));
var managedToNativeInterfaceImplementations = interfaceAndMethodsContexts
.Select(GenerateImplementationInterface)
.WithTrackingName(StepNames.GenerateManagedToNativeInterfaceImplementation)
Expand All @@ -126,7 +114,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

// Generate the code for the unmanaged-to-managed stubs and the diagnostics from code-generation.
context.RegisterDiagnostics(interfaceAndMethodsContexts
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.GetNativeToManagedStub().Diagnostics)));
.SelectMany((data, ct) => data.DeclaredMethods.SelectMany(m => m.UnmanagedToManagedStub.Diagnostics)));
var nativeToManagedVtableMethods = interfaceAndMethodsContexts
.Select(GenerateImplementationVTableMethods)
.WithTrackingName(StepNames.GenerateNativeToManagedVTableMethods)
Expand All @@ -145,11 +133,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Select((data, ct) =>
{
var context = data.Interface.Info;
var methods = data.ShadowingMethods.Select(m => (MemberDeclarationSyntax)m.GenerateShadow());
var methods = data.ShadowingMethods.Select(m => m.Shadow);
var typeDecl = TypeDeclaration(context.ContainingSyntax.TypeKind, context.ContainingSyntax.Identifier)
.WithModifiers(context.ContainingSyntax.Modifiers)
.WithTypeParameterList(context.ContainingSyntax.TypeParameters)
.WithMembers(List(methods));
.WithMembers(List<MemberDeclarationSyntax>(methods));
return data.Interface.Info.TypeDefinitionContext.WrapMemberInContainingSyntaxWithUnsafeModifier(typeDecl);
})
.WithTrackingName(StepNames.GenerateShadowingMethods)
Expand Down Expand Up @@ -211,33 +199,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
});
}

private static string GenerateMarkerInterfaceSource(ComInterfaceInfo iface) => $$"""
file unsafe class InterfaceInformation : global::System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType
{
public static global::System.Guid Iid => new(new global::System.ReadOnlySpan<byte>(new byte[] { {{string.Join(",", iface.InterfaceId.ToByteArray())}} }));
private static void** m_vtable;
public static void** ManagedVirtualMethodTable
{
get
{
if (m_vtable == null)
{
nint* vtable = (nint*)global::System.Runtime.CompilerServices.RuntimeHelpers.AllocateTypeAssociatedMemory(typeof({{iface.Type.FullTypeName}}), sizeof(nint) * 3);
global::System.Runtime.InteropServices.ComWrappers.GetIUnknownImpl(out vtable[0], out vtable[1], out vtable[2]);
m_vtable = (void**)vtable;
}
return m_vtable;
}
}
}

[global::System.Runtime.InteropServices.DynamicInterfaceCastableImplementation]
file interface InterfaceImplementation : {{iface.Type.FullTypeName}}
{}
""";

private static readonly AttributeSyntax s_iUnknownDerivedAttributeTemplate =
Attribute(
GenericName(TypeNames.IUnknownDerivedAttribute)
Expand All @@ -252,8 +213,7 @@ private static MemberDeclarationSyntax GenerateIUnknownDerivedAttributeApplicati
.WithTypeParameterList(context.ContainingSyntax.TypeParameters)
.AddAttributeLists(AttributeList(SingletonSeparatedList(s_iUnknownDerivedAttributeTemplate))));

// Todo: extract info needed from the IMethodSymbol into MethodInfo and only pass a MethodInfo here
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo typeKeyOwner, CancellationToken ct)
private static IncrementalMethodStubGenerationContext CalculateStubInformation(MethodDeclarationSyntax syntax, IMethodSymbol symbol, int index, StubEnvironment environment, ManagedTypeInfo owningInterface, CancellationToken ct)
{
ct.ThrowIfCancellationRequested();
INamedTypeSymbol? lcidConversionAttrType = environment.Compilation.GetTypeByMetadataName(TypeNames.LCIDConversionAttribute);
Expand Down Expand Up @@ -366,7 +326,7 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
new ComExceptionMarshalling(),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.ManagedToUnmanaged),
ComInterfaceGeneratorHelpers.CreateGeneratorFactory(environment, MarshalDirection.UnmanagedToManaged),
typeKeyOwner,
owningInterface,
declaringType,
generatorDiagnostics.Diagnostics.ToSequenceEqualImmutableArray(),
ComInterfaceDispatchMarshallingInfo.Instance);
Expand Down Expand Up @@ -413,31 +373,32 @@ private static ImmutableArray<ComInterfaceAndMethodsContext> GroupComContextsFor
private static InterfaceDeclarationSyntax GenerateImplementationInterface(ComInterfaceAndMethodsContext interfaceGroup, CancellationToken _)
{
var definingType = interfaceGroup.Interface.Info.Type;
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.GetManagedToUnmanagedStub()))
var shadowImplementations = interfaceGroup.ShadowingMethods.Select(m => (Method: m, ManagedToUnmanagedStub: m.ManagedToUnmanagedStub))
.Where(p => p.ManagedToUnmanagedStub is GeneratedStubCodeContext)
.Select(ctx => ((GeneratedStubCodeContext)ctx.ManagedToUnmanagedStub).Stub.Node
.WithExplicitInterfaceSpecifier(
ExplicitInterfaceSpecifier(ParseName(definingType.FullTypeName))));
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.GenerateUnreachableExceptionStub());
var inheritedStubs = interfaceGroup.ShadowingMethods.Select(m => m.UnreachableExceptionStub);
return ImplementationInterfaceTemplate
.AddBaseListTypes(SimpleBaseType(definingType.Syntax))
.WithMembers(
List<MemberDeclarationSyntax>(
interfaceGroup.DeclaredMethods
.Select(m => m.GetManagedToUnmanagedStub())
.Select(m => m.ManagedToUnmanagedStub)
.OfType<GeneratedStubCodeContext>()
.Select(ctx => ctx.Stub.Node)
.Concat(shadowImplementations)
.Concat(inheritedStubs)))
.AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(ParseName(TypeNames.System_Runtime_InteropServices_DynamicInterfaceCastableImplementationAttribute)))));
}

private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(ComInterfaceAndMethodsContext comInterfaceAndMethods, CancellationToken _)
{
return ImplementationInterfaceTemplate
.WithMembers(
List<MemberDeclarationSyntax>(
comInterfaceAndMethods.DeclaredMethods
.Select(m => m.GetNativeToManagedStub())
.Select(m => m.UnmanagedToManagedStub)
.OfType<GeneratedStubCodeContext>()
.Select(context => context.Stub.Node)));
}
Expand All @@ -448,6 +409,7 @@ private static InterfaceDeclarationSyntax GenerateImplementationVTableMethods(Co

private static readonly MethodDeclarationSyntax CreateManagedVirtualFunctionTableMethodTemplate = MethodDeclaration(VoidStarStarSyntax, CreateManagedVirtualFunctionTableMethodName)
.AddModifiers(Token(SyntaxKind.InternalKeyword), Token(SyntaxKind.StaticKeyword));

private static InterfaceDeclarationSyntax GenerateImplementationVTable(ComInterfaceAndMethodsContext interfaceMethods, CancellationToken _)
{
const string vtableLocalName = "vtable";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;

namespace Microsoft.Interop
Expand Down
Loading

0 comments on commit 41274d0

Please sign in to comment.