Skip to content

Commit

Permalink
[LibraryImportGenerator] Allow span copy for char arrays instead of m…
Browse files Browse the repository at this point in the history
…anual copy loop (#69764)
  • Loading branch information
elinor-fung authored May 25, 2022
1 parent 7015fa2 commit 1601356
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,6 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
IMarshallingGenerator elementMarshaller = _elementMarshallingGenerator.Create(
elementInfo,
new LinearCollectionElementMarshallingCodeContext(StubCodeContext.Stage.Setup, string.Empty, string.Empty, context));
TypeSyntax elementType = elementMarshaller.AsNativeType(elementInfo);


ExpressionSyntax numElementsExpression = LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(0));
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
Expand All @@ -315,9 +313,9 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, collectionInfo.ElementCountInfo, context);
}

bool isBlittable = elementMarshaller is BlittableMarshaller;

if (isBlittable)
bool enableArrayPinning = elementMarshaller is BlittableMarshaller;
bool treatAsBlittable = enableArrayPinning || elementMarshaller is Utf16CharMarshaller;
if (treatAsBlittable)
{
marshallingStrategy = new LinearCollectionWithBlittableElementsMarshalling(marshallingStrategy, collectionInfo.ElementType.Syntax, numElementsExpression);
}
Expand All @@ -332,16 +330,17 @@ private IMarshallingGenerator CreateNativeCollectionMarshaller(
marshallingStrategy = DecorateWithTwoStageMarshallingStrategy(collectionInfo, marshallingStrategy);
}

TypeSyntax nativeElementType = elementMarshaller.AsNativeType(elementInfo);
marshallingStrategy = new SizeOfElementMarshalling(
marshallingStrategy,
SizeOfExpression(elementType));
SizeOfExpression(nativeElementType));

if (collectionInfo.UseDefaultMarshalling && info.ManagedType is SzArrayType)
{
return new ArrayMarshaller(
new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: true),
elementType,
isBlittable);
collectionInfo.ElementType.Syntax,
enableArrayPinning);
}

IMarshallingGenerator marshallingGenerator = new CustomNativeTypeMarshallingGenerator(marshallingStrategy, enableByValueContentsMarshalling: false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,19 @@ public IEnumerable<StatementSyntax> GenerateMarshalStatements(TypePositionInfo i
if (!info.IsByRef && info.ByValueContentsMarshalKind == ByValueContentsMarshalKind.Out)
{
// If the parameter is marshalled by-value [Out], then we don't marshal the contents of the collection.
// We do clear the span, so that if the invoke target doesn't fill it, we aren't left with undefined content.
// <nativeIdentifier>.GetNativeValuesDestination().Clear();
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetNativeValuesDestination)),
ArgumentList()),
IdentifierName("Clear"))));
yield break;
}

Expand Down Expand Up @@ -742,12 +755,76 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo
{
string nativeIdentifier = context.GetIdentifiers(info).native;
string numElementsIdentifier = context.GetAdditionalIdentifier(info, "numElements");
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.IntKeyword)),
SingletonSeparatedList(
VariableDeclarator(numElementsIdentifier).WithInitializer(EqualsValueClause(_numElementsExpression)))));
// MemoryMarshal.Cast<byte, <elementType>>(<nativeIdentifier>.GetNativeValuesSource(<numElements>)).CopyTo(<nativeIdentifier>.GetManagedValuesDestination(<numElements>));

ExpressionSyntax copySource;
ExpressionSyntax copyDestination;
if (!info.IsByRef && info.ByValueContentsMarshalKind.HasFlag(ByValueContentsMarshalKind.Out))
{
// <nativeIdentifier>.GetNativeValuesDestination()
copySource = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetNativeValuesDestination)));

// MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(<nativeIdentifier>.GetManagedValuesSource()), <nativeIdentifier>.GetManagedValuesSource().Length)
copyDestination = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal),
IdentifierName("CreateSpan")),
ArgumentList(
SeparatedList(new[]
{
Argument(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
ParseName(TypeNames.System_Runtime_InteropServices_MemoryMarshal),
IdentifierName("GetReference")),
ArgumentList(SingletonSeparatedList(
Argument(
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetManagedValuesSource))))))))
.WithRefKindKeyword(
Token(SyntaxKind.RefKeyword)),
Argument(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
InvocationExpression(
MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetManagedValuesSource))),
IdentifierName("Length")))
})));

}
else
{
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.IntKeyword)),
SingletonSeparatedList(
VariableDeclarator(numElementsIdentifier).WithInitializer(EqualsValueClause(_numElementsExpression)))));

// <nativeIdentifier>.GetNativeValuesSource(<numElements>)
copySource = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetNativeValuesSource)),
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(numElementsIdentifier)))));

// <nativeIdentifier>.GetManagedValuesDestination(<numElements>)
copyDestination = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetManagedValuesDestination)),
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(numElementsIdentifier)))));
}

// MemoryMarshal.Cast<byte, <elementType>>(<copySource>).CopyTo(<copyDestination>);
yield return ExpressionStatement(
InvocationExpression(
MemberAccessExpression(
Expand All @@ -767,22 +844,10 @@ public IEnumerable<StatementSyntax> GenerateUnmarshalStatements(TypePositionInfo
_elementType
})))))
.AddArgumentListArguments(
Argument(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetNativeValuesSource)),
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(numElementsIdentifier))))))),
Argument(copySource)),
IdentifierName("CopyTo")))
.AddArgumentListArguments(
Argument(
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName(nativeIdentifier),
IdentifierName(ShapeMemberNames.LinearCollection.GetManagedValuesDestination)),
ArgumentList(SingletonSeparatedList(Argument(IdentifierName(numElementsIdentifier))))))));
Argument(copyDestination)));

foreach (StatementSyntax statement in _innerMarshaller.GenerateUnmarshalStatements(info, context))
{
Expand Down

0 comments on commit 1601356

Please sign in to comment.