Skip to content

Commit

Permalink
Implement support for collections of collections marshalling (#1226)
Browse files Browse the repository at this point in the history
* Refactor indexer naming to be nested-collection safe.

* Implement initial collections-of-collections/jagged array support.

Use a topological sort of the parameter+return marshallers to unmarshal "CountElementName"-referenced parameters/return before unmarshalling the elements that have a dependency on them through "CountElementName".

* Fixes for edge cases around HResult/Exception handling.

* Flip the edgeMap indices so we can use Array.IndexOf (which is optimized) to search for edges.

* Comments and optimizations.

* Add citation for algorithm,

* Hoist elementIndex out of the loop.

* Encapsulate the edgeMap in a private struct type and add some simple abstractions to enable more perf optimizations.

* Make Topological sort more flexible for element ids

* Use native index if managed index is unset to allow handling multiple native-only parameters as distinct nodes in the graph model.

* Fix using the return value for passing collection size. Validate the types of the CountElementName'd elements, even in nested scenarios.

* Change how we initialize the numRows array.

* Fix cycle breaking in count info.

* Update nested indexer creation for num elements expressions to handle non-collection sub contexts.

* Add stress test for collections of collections that uses 0-11 nested arrays.

* PR feedback.
  • Loading branch information
jkoritzinsky authored Jun 12, 2021
1 parent 4835f73 commit aa2c08c
Show file tree
Hide file tree
Showing 14 changed files with 528 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ public partial class Arrays
[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "and_all_members")]
[return:MarshalAs(UnmanagedType.U1)]
public static partial bool AndAllMembers(BoolStruct[] pArray, int length);

[GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "transpose_matrix")]
[return: MarshalUsing(CountElementName = "numColumns")]
[return: MarshalUsing(CountElementName = "numRows", ElementIndirectionLevel = 1)]
public static partial int[][] TransposeMatrix(int[][] matrix, int[] numRows, int numColumns);
}
}

Expand Down Expand Up @@ -273,6 +278,36 @@ public void ArrayWithSimpleNonBlittableTypeMarshalling(bool result)
Assert.Equal(result, NativeExportsNE.Arrays.AndAllMembers(boolValues, boolValues.Length));
}

[Fact]
public void ArraysOfArrays()
{
var random = new Random(42);
int numRows = random.Next(1, 5);
int numColumns = random.Next(1, 5);
int[][] matrix = new int[numRows][];
for (int i = 0; i < numRows; i++)
{
matrix[i] = new int[numColumns];
for (int j = 0; j < numColumns; j++)
{
matrix[i][j] = random.Next();
}
}

int[] numRowsArray = new int[numColumns];
numRowsArray.AsSpan().Fill(numRows);

int[][] transposed = NativeExportsNE.Arrays.TransposeMatrix(matrix, numRowsArray, numColumns);

for (int i = 0; i < numRows; i++)
{
for (int j = 0; j < numColumns; j++)
{
Assert.Equal(matrix[i][j], transposed[j][i]);
}
}
}

private static string ReverseChars(string value)
{
if (value == null)
Expand Down
77 changes: 77 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,83 @@ public static partial void Method(
[MarshalUsing(CountElementName=""arr"")] ref int[] arr2
);
}
";

public static string CollectionsOfCollectionsStress => @"
using System.Runtime.InteropServices;
partial class Test
{
[GeneratedDllImport(""DoesNotExist"")]
public static partial void Method(
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]
[MarshalUsing(CountElementName=""arr4"", ElementIndirectionLevel = 4)]
[MarshalUsing(CountElementName=""arr5"", ElementIndirectionLevel = 5)]
[MarshalUsing(CountElementName=""arr6"", ElementIndirectionLevel = 6)]
[MarshalUsing(CountElementName=""arr7"", ElementIndirectionLevel = 7)]
[MarshalUsing(CountElementName=""arr8"", ElementIndirectionLevel = 8)]
[MarshalUsing(CountElementName=""arr9"", ElementIndirectionLevel = 9)]
[MarshalUsing(CountElementName=""arr10"", ElementIndirectionLevel = 10)] ref int[][][][][][][][][][][] arr11,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]
[MarshalUsing(CountElementName=""arr4"", ElementIndirectionLevel = 4)]
[MarshalUsing(CountElementName=""arr5"", ElementIndirectionLevel = 5)]
[MarshalUsing(CountElementName=""arr6"", ElementIndirectionLevel = 6)]
[MarshalUsing(CountElementName=""arr7"", ElementIndirectionLevel = 7)]
[MarshalUsing(CountElementName=""arr8"", ElementIndirectionLevel = 8)]
[MarshalUsing(CountElementName=""arr9"", ElementIndirectionLevel = 9)]ref int[][][][][][][][][][] arr10,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]
[MarshalUsing(CountElementName=""arr4"", ElementIndirectionLevel = 4)]
[MarshalUsing(CountElementName=""arr5"", ElementIndirectionLevel = 5)]
[MarshalUsing(CountElementName=""arr6"", ElementIndirectionLevel = 6)]
[MarshalUsing(CountElementName=""arr7"", ElementIndirectionLevel = 7)]
[MarshalUsing(CountElementName=""arr8"", ElementIndirectionLevel = 8)]ref int[][][][][][][][][] arr9,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]
[MarshalUsing(CountElementName=""arr4"", ElementIndirectionLevel = 4)]
[MarshalUsing(CountElementName=""arr5"", ElementIndirectionLevel = 5)]
[MarshalUsing(CountElementName=""arr6"", ElementIndirectionLevel = 6)]
[MarshalUsing(CountElementName=""arr7"", ElementIndirectionLevel = 7)]ref int[][][][][][][][][] arr8,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]
[MarshalUsing(CountElementName=""arr4"", ElementIndirectionLevel = 4)]
[MarshalUsing(CountElementName=""arr5"", ElementIndirectionLevel = 5)]
[MarshalUsing(CountElementName=""arr6"", ElementIndirectionLevel = 6)]ref int[][][][][][][] arr7,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]
[MarshalUsing(CountElementName=""arr4"", ElementIndirectionLevel = 4)]
[MarshalUsing(CountElementName=""arr5"", ElementIndirectionLevel = 5)]ref int[][][][][][] arr6,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]
[MarshalUsing(CountElementName=""arr4"", ElementIndirectionLevel = 4)]ref int[][][][][] arr5,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]
[MarshalUsing(CountElementName=""arr3"", ElementIndirectionLevel = 3)]ref int[][][][] arr4,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]
[MarshalUsing(CountElementName=""arr2"", ElementIndirectionLevel = 2)]ref int[][][] arr3,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]
[MarshalUsing(CountElementName=""arr1"", ElementIndirectionLevel = 1)]ref int[][] arr2,
[MarshalUsing(CountElementName=""arr0"", ElementIndirectionLevel = 0)]ref int[] arr1,
ref int arr0
);
}
";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()
yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerParametersAndModifiers<UIntPtr>() };
yield return new[] { CodeSnippets.CustomCollectionCustomMarshallerReturnValueLength<int>() };
yield return new[] { CodeSnippets.GenericCollectionWithCustomElementMarshalling };
yield return new[] { CodeSnippets.CollectionsOfCollectionsStress };
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,12 @@ public void AnalyzeReturnType(SymbolAnalysisContext context)

private void AnalyzeNativeMarshalerType(SymbolAnalysisContext context, ITypeSymbol type, AttributeData nativeMarshalerAttributeData, bool validateManagedGetPinnableReference, bool validateAllScenarioSupport)
{
if (nativeMarshalerAttributeData.ConstructorArguments.Length == 0)
{
// This is a MarshalUsing with just count information.
return;
}

if (nativeMarshalerAttributeData.ConstructorArguments[0].IsNull)
{
context.ReportDiagnostic(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand All @@ -12,14 +12,14 @@ namespace Microsoft.Interop
{
internal sealed class ContiguousCollectionElementMarshallingCodeContext : StubCodeContext
{
private readonly string indexerIdentifier;
private readonly string nativeSpanIdentifier;
private readonly StubCodeContext parentContext;

public override bool SingleFrameSpansNativeContext => false;

public override bool AdditionalTemporaryStateLivesAcrossStages => false;

public string IndexerIdentifier { get; }

/// <summary>
/// Create a <see cref="StubCodeContext"/> for marshalling elements of an collection.
/// </summary>
Expand All @@ -29,14 +29,13 @@ internal sealed class ContiguousCollectionElementMarshallingCodeContext : StubCo
/// <param name="parentContext">The parent context.</param>
public ContiguousCollectionElementMarshallingCodeContext(
Stage currentStage,
string indexerIdentifier,
string nativeSpanIdentifier,
StubCodeContext parentContext)
{
CurrentStage = currentStage;
this.indexerIdentifier = indexerIdentifier;
IndexerIdentifier = CalculateIndexerIdentifierBasedOnParentContext(parentContext);
this.nativeSpanIdentifier = nativeSpanIdentifier;
this.parentContext = parentContext;
ParentContext = parentContext;
}

/// <summary>
Expand All @@ -46,22 +45,39 @@ public ContiguousCollectionElementMarshallingCodeContext(
/// <returns>Managed and native identifiers</returns>
public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
var (_, native) = parentContext.GetIdentifiers(info);
var (_, native) = ParentContext!.GetIdentifiers(info);
return (
$"{native}.ManagedValues[{indexerIdentifier}]",
$"{nativeSpanIdentifier}[{indexerIdentifier}]"
$"{native}.ManagedValues[{IndexerIdentifier}]",
$"{nativeSpanIdentifier}[{IndexerIdentifier}]"
);
}

public override string GetAdditionalIdentifier(TypePositionInfo info, string name)
{
return $"{nativeSpanIdentifier}__{indexerIdentifier}__{name}";
return $"{nativeSpanIdentifier}__{IndexerIdentifier}__{name}";
}

public override TypePositionInfo? GetTypePositionInfoForManagedIndex(int index)
{
// We don't have parameters to look at when we're in the middle of marshalling an array.
return null;
}

private static string CalculateIndexerIdentifierBasedOnParentContext(StubCodeContext? parentContext)
{
int i = 0;
while (parentContext is StubCodeContext context)
{
if (context is ContiguousCollectionElementMarshallingCodeContext)
{
i++;
}
parentContext = context.ParentContext;
}

// Follow a progression of indexers of the following form:
// __i0, __i1, __i2, __i3, etc/
return $"__i{i}";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,26 +127,24 @@ public IEnumerable<StatementSyntax> GeneratePinStatements(TypePositionInfo info,
/// </summary>
internal class CustomNativeTypeWithValuePropertyStubContext : StubCodeContext
{
private readonly StubCodeContext parentContext;

public CustomNativeTypeWithValuePropertyStubContext(StubCodeContext parentContext)
{
this.parentContext = parentContext;
ParentContext = parentContext;
CurrentStage = parentContext.CurrentStage;
}

public override bool SingleFrameSpansNativeContext => parentContext.SingleFrameSpansNativeContext;
public override bool SingleFrameSpansNativeContext => ParentContext!.SingleFrameSpansNativeContext;

public override bool AdditionalTemporaryStateLivesAcrossStages => parentContext.AdditionalTemporaryStateLivesAcrossStages;
public override bool AdditionalTemporaryStateLivesAcrossStages => ParentContext!.AdditionalTemporaryStateLivesAcrossStages;

public override TypePositionInfo? GetTypePositionInfoForManagedIndex(int index)
{
return parentContext.GetTypePositionInfoForManagedIndex(index);
return ParentContext!.GetTypePositionInfoForManagedIndex(index);
}

public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
return (parentContext.GetIdentifiers(info).managed, MarshallerHelpers.GetMarshallerIdentifier(info, parentContext));
return (ParentContext!.GetIdentifiers(info).managed, MarshallerHelpers.GetMarshallerIdentifier(info, ParentContext));
}
}

Expand Down Expand Up @@ -859,8 +857,6 @@ public bool UsesNativeIdentifier(TypePositionInfo info, StubCodeContext context)
/// </summary>
internal sealed class ContiguousNonBlittableElementCollectionMarshalling : ICustomNativeTypeMarshallingStrategy
{
private const string IndexerIdentifier = "__i";

private readonly ICustomNativeTypeMarshallingStrategy innerMarshaller;
private readonly IMarshallingGenerator elementMarshaller;
private readonly TypePositionInfo elementInfo;
Expand All @@ -874,15 +870,10 @@ public ContiguousNonBlittableElementCollectionMarshalling(ICustomNativeTypeMarsh
this.elementInfo = elementInfo;
}

private string GetNativeSpanIdentifier(TypePositionInfo info, StubCodeContext context)
{
return context.GetAdditionalIdentifier(info, "nativeSpan");
}

private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(TypePositionInfo info, StubCodeContext context)
{
string nativeIdentifier = context.GetIdentifiers(info).native;
string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context);
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
return LocalDeclarationStatement(VariableDeclaration(
GenericName(
Identifier(TypeNames.System_Span),
Expand Down Expand Up @@ -915,15 +906,13 @@ private LocalDeclarationStatementSyntax GenerateNativeSpanDeclaration(TypePositi
private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo info, StubCodeContext context, bool useManagedSpanForLength)
{
string nativeIdentifier = context.GetIdentifiers(info).native;
string nativeSpanIdentifier = GetNativeSpanIdentifier(info, context);
string nativeSpanIdentifier = MarshallerHelpers.GetNativeSpanIdentifier(info, context);
var elementSetupSubContext = new ContiguousCollectionElementMarshallingCodeContext(
StubCodeContext.Stage.Setup,
IndexerIdentifier,
nativeSpanIdentifier,
context);
var elementSubContext = new ContiguousCollectionElementMarshallingCodeContext(
context.CurrentStage,
IndexerIdentifier,
nativeSpanIdentifier,
context);

Expand Down Expand Up @@ -956,7 +945,7 @@ private StatementSyntax GenerateContentsMarshallingStatement(TypePositionInfo in
// Iterate through the elements of the native collection to unmarshal them
return Block(
GenerateNativeSpanDeclaration(info, context),
MarshallerHelpers.GetForLoop(collectionIdentifierForLength, IndexerIdentifier)
MarshallerHelpers.GetForLoop(collectionIdentifierForLength, elementSubContext.IndexerIdentifier)
.WithStatement(marshallingStatement));
}
return EmptyStatement();
Expand Down
Loading

0 comments on commit aa2c08c

Please sign in to comment.