Skip to content

Commit

Permalink
Don't pin managed collection objects when the contents are non-blitta…
Browse files Browse the repository at this point in the history
…ble. (#69696)

Co-authored-by: Aaron Robinson <arobins@microsoft.com>
  • Loading branch information
jkoritzinsky and AaronRobinsonMSFT authored May 27, 2022
1 parent f05fa01 commit 554aa54
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
}

bool enableArrayPinning = elementMarshaller is BlittableMarshaller;
bool treatAsBlittable = enableArrayPinning || elementMarshaller is Utf16CharMarshaller;
if (treatAsBlittable)
bool treatElementAsBlittable = enableArrayPinning || elementMarshaller is Utf16CharMarshaller;
if (treatElementAsBlittable)
{
marshallingStrategy = new LinearCollectionWithBlittableElementsMarshalling(marshallingStrategy, collectionInfo.ElementType.Syntax, numElementsExpression);
}
Expand Down Expand Up @@ -345,7 +345,8 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(

IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);

if (collectionInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.ManagedType))
// Elements in the collection must be blittable to use the pinnable marshaller.
if (collectionInfo.PinningFeatures.HasFlag(CustomTypeMarshallerPinning.ManagedType) && treatElementAsBlittable)
{
return new PinnableManagedValueMarshaller(marshallingGenerator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ public partial class Collections
[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")]
public static partial int SumStringLengths([MarshalUsing(typeof(ListMarshaller<string>)), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] List<string> strArray);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")]
public static partial int SumStringLengths([MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] WrappedList<string> strArray);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "reverse_strings_replace")]
public static partial void ReverseStrings_Ref([MarshalUsing(typeof(ListMarshaller<string>), CountElementName = "numElements"), MarshalUsing(typeof(Utf16StringMarshaller), ElementIndirectionDepth = 1)] ref List<string> strArray, out int numElements);

Expand All @@ -57,7 +60,7 @@ public static partial void ReverseStrings_Out(
public static partial List<byte> GetLongBytes(long l);

[LibraryImport(NativeExportsNE_Binary, EntryPoint = "and_all_members")]
[return:MarshalAs(UnmanagedType.U1)]
[return: MarshalAs(UnmanagedType.U1)]
public static partial bool AndAllMembers([MarshalUsing(typeof(ListMarshaller<BoolStruct>))] List<BoolStruct> pArray, int length);
}
}
Expand Down Expand Up @@ -143,6 +146,13 @@ public void ByValueNullCollectionWithNonBlittableElements()
Assert.Equal(0, NativeExportsNE.Collections.SumStringLengths(null));
}

[Fact]
public void ByValueCollectionWithNonBlittableElements_WithDefaultMarshalling()
{
var strings = new WrappedList<string>(GetStringList());
Assert.Equal(strings.Wrapped.Sum(str => str?.Length ?? 0), NativeExportsNE.Collections.SumStringLengths(strings));
}

[Fact]
public void ByRefCollectionWithNonBlittableElements()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,4 +227,57 @@ public void FreeNative()
Marshal.FreeCoTaskMem(allocatedMemory);
}
}

[NativeMarshalling(typeof(WrappedListMarshaller<>))]
public struct WrappedList<T>
{
public WrappedList(List<T> list)
{
Wrapped = list;
}

public List<T> Wrapped { get; }

public ref T GetPinnableReference() => ref CollectionsMarshal.AsSpan(Wrapped).GetPinnableReference();
}

[CustomTypeMarshaller(typeof(WrappedList<>), CustomTypeMarshallerKind.LinearCollection, Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer, BufferSize = 0x200)]
public unsafe ref struct WrappedListMarshaller<T>
{
private ListMarshaller<T> _marshaller;

public WrappedListMarshaller(int sizeOfNativeElement)
: this()
{
this._marshaller = new ListMarshaller<T>(sizeOfNativeElement);
}

public WrappedListMarshaller(WrappedList<T> managed, int sizeOfNativeElement)
: this(managed, Span<byte>.Empty, sizeOfNativeElement)
{
}

public WrappedListMarshaller(WrappedList<T> managed, Span<byte> stackSpace, int sizeOfNativeElement)
{
this._marshaller = new ListMarshaller<T>(managed.Wrapped, stackSpace, sizeOfNativeElement);
}

public ReadOnlySpan<T> GetManagedValuesSource() => _marshaller.GetManagedValuesSource();

public Span<T> GetManagedValuesDestination(int length) => _marshaller.GetManagedValuesDestination(length);

public Span<byte> GetNativeValuesDestination() => _marshaller.GetNativeValuesDestination();

public ReadOnlySpan<byte> GetNativeValuesSource(int length) => _marshaller.GetNativeValuesSource(length);

public ref byte GetPinnableReference() => ref _marshaller.GetPinnableReference();

public byte* ToNativeValue() => _marshaller.ToNativeValue();

public void FromNativeValue(byte* value) => _marshaller.FromNativeValue(value);

public WrappedList<T> ToManaged() => new(_marshaller.ToManaged());

public void FreeNative() => _marshaller.FreeNative();
}
}

0 comments on commit 554aa54

Please sign in to comment.