Skip to content

Commit

Permalink
Resolve SizeParamIndex to a TypePositionInfo during MarshallingInfo p…
Browse files Browse the repository at this point in the history
…arsing (#1293)
  • Loading branch information
jkoritzinsky authored Jul 7, 2021
1 parent d4b5fd0 commit eea0bae
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 139 deletions.
11 changes: 11 additions & 0 deletions DllImportGenerator/DllImportGenerator.UnitTests/CodeSnippets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,17 @@ public static partial void Method(
[MarshalUsing(CountElementName=""arr"")] ref int[] arr2
);
}
";
public static string MutuallyRecursiveSizeParamIndexOnParameter => @"
using System.Runtime.InteropServices;
partial class Test
{
[GeneratedDllImport(""DoesNotExist"")]
public static partial void Method(
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex=1)] ref int[] arr,
[MarshalAs(UnmanagedType.LPArray, SizeParamIndex=0)] ref int[] arr2
);
}
";

public static string CollectionsOfCollectionsStress => @"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ public static IEnumerable<object[]> CodeSnippetsToCompile()
yield return new object[] { CodeSnippets.RecursiveCountElementNameOnReturnValue, 2, 0 };
yield return new object[] { CodeSnippets.RecursiveCountElementNameOnParameter, 2, 0 };
yield return new object[] { CodeSnippets.MutuallyRecursiveCountElementNameOnParameter, 4, 0 };
yield return new object[] { CodeSnippets.MutuallyRecursiveSizeParamIndexOnParameter, 4, 0 };
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,6 @@ public override string GetAdditionalIdentifier(TypePositionInfo info, string nam
return $"{nativeSpanIdentifier}__{IndexerIdentifier}__{name}";
}

public override TypePositionInfo? GetTypePositionInfoForManagedIndex(int index)
{
// We don't have parameters to look at when we're in the middle of marshalling an array.
return null;
}

private static string CalculateIndexerIdentifierBasedOnParentContext(StubCodeContext? parentContext)
{
int i = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,20 @@ private bool TryRehydrateMarshalAsAttribute(TypePositionInfo info, out Attribute

if (collectionMarshalling.ElementCountInfo is SizeAndParamIndexInfo countInfo)
{
if (countInfo.ConstSize != SizeAndParamIndexInfo.UnspecifiedData)
if (countInfo.ConstSize != SizeAndParamIndexInfo.UnspecifiedConstSize)
{
marshalAsArguments.Add(
AttributeArgument(NameEquals("SizeConst"), null,
LiteralExpression(SyntaxKind.NumericLiteralExpression,
Literal(countInfo.ConstSize)))
);
}
if (countInfo.ParamIndex != SizeAndParamIndexInfo.UnspecifiedData)
if (countInfo.ParamAtIndex is { ManagedIndex: int paramIndex })
{
marshalAsArguments.Add(
AttributeArgument(NameEquals("SizeParamIndex"), null,
LiteralExpression(SyntaxKind.NumericLiteralExpression,
Literal(countInfo.ParamIndex)))
Literal(paramIndex)))
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,6 @@ public CustomNativeTypeWithValuePropertyStubContext(StubCodeContext parentContex

public override bool AdditionalTemporaryStateLivesAcrossStages => ParentContext!.AdditionalTemporaryStateLivesAcrossStages;

public override TypePositionInfo? GetTypePositionInfoForManagedIndex(int index)
{
return ParentContext!.GetTypePositionInfoForManagedIndex(index);
}

public override (string managed, string native) GetIdentifiers(TypePositionInfo info)
{
return (ParentContext!.GetIdentifiers(info).managed, MarshallerHelpers.GetMarshallerIdentifier(info, ParentContext));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,14 +376,14 @@ private static IMarshallingGenerator CreateStringMarshaller(TypePositionInfo inf
throw new MarshallingNotSupportedException(info, context);
}

private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(TypePositionInfo info, CountInfo count, StubCodeContext context, AnalyzerConfigOptions options)
private static ExpressionSyntax GetNumElementsExpressionFromMarshallingInfo(TypePositionInfo info, CountInfo count, StubCodeContext context)
{
return count switch
{
SizeAndParamIndexInfo(int size, SizeAndParamIndexInfo.UnspecifiedData) => GetConstSizeExpression(size),
SizeAndParamIndexInfo(int size, SizeAndParamIndexInfo.UnspecifiedParam) => GetConstSizeExpression(size),
ConstSizeCountInfo(int size) => GetConstSizeExpression(size),
SizeAndParamIndexInfo(SizeAndParamIndexInfo.UnspecifiedData, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, GetExpressionForParam(context.GetTypePositionInfoForManagedIndex(paramIndex))),
SizeAndParamIndexInfo(int size, int paramIndex) => CheckedExpression(SyntaxKind.CheckedExpression, BinaryExpression(SyntaxKind.AddExpression, GetConstSizeExpression(size), GetExpressionForParam(context.GetTypePositionInfoForManagedIndex(paramIndex)))),
SizeAndParamIndexInfo(SizeAndParamIndexInfo.UnspecifiedConstSize, TypePositionInfo param) => CheckedExpression(SyntaxKind.CheckedExpression, GetExpressionForParam(param)),
SizeAndParamIndexInfo(int size, TypePositionInfo param) => CheckedExpression(SyntaxKind.CheckedExpression, BinaryExpression(SyntaxKind.AddExpression, GetConstSizeExpression(size), GetExpressionForParam(param))),
CountElementCountInfo(TypePositionInfo elementInfo) => CheckedExpression(SyntaxKind.CheckedExpression, GetExpressionForParam(elementInfo)),
_ => throw new MarshallingNotSupportedException(info, context)
{
Expand All @@ -396,53 +396,43 @@ static LiteralExpressionSyntax GetConstSizeExpression(int size)
return LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(size));
}

ExpressionSyntax GetExpressionForParam(TypePositionInfo? paramInfo)
ExpressionSyntax GetExpressionForParam(TypePositionInfo paramInfo)
{
if (paramInfo is null)
{
throw new MarshallingNotSupportedException(info, context)
{
NotSupportedDetails = Resources.ArraySizeParamIndexOutOfRange
};
}
else
{
ExpressionSyntax numElementsExpression = GetIndexedNumElementsExpression(
context,
paramInfo,
out int numIndirectionLevels);
ExpressionSyntax numElementsExpression = GetIndexedNumElementsExpression(
context,
paramInfo,
out int numIndirectionLevels);

ITypeSymbol type = paramInfo.ManagedType;
MarshallingInfo marshallingInfo = paramInfo.MarshallingAttributeInfo;
ITypeSymbol type = paramInfo.ManagedType;
MarshallingInfo marshallingInfo = paramInfo.MarshallingAttributeInfo;

for (int i = 0; i < numIndirectionLevels; i++)
for (int i = 0; i < numIndirectionLevels; i++)
{
if (marshallingInfo is NativeContiguousCollectionMarshallingInfo collectionInfo)
{
if (marshallingInfo is NativeContiguousCollectionMarshallingInfo collectionInfo)
{
type = collectionInfo.ElementType;
marshallingInfo = collectionInfo.ElementMarshallingInfo;
}
else
{
throw new MarshallingNotSupportedException(info, context)
{
NotSupportedDetails = Resources.CollectionSizeParamTypeMustBeIntegral
};
}
type = collectionInfo.ElementType;
marshallingInfo = collectionInfo.ElementMarshallingInfo;
}

if (!type.IsIntegralType())
else
{
throw new MarshallingNotSupportedException(info, context)
{
NotSupportedDetails = Resources.CollectionSizeParamTypeMustBeIntegral
};
}
}

return CastExpression(
PredefinedType(Token(SyntaxKind.IntKeyword)),
ParenthesizedExpression(numElementsExpression));
if (!type.IsIntegralType())
{
throw new MarshallingNotSupportedException(info, context)
{
NotSupportedDetails = Resources.CollectionSizeParamTypeMustBeIntegral
};
}

return CastExpression(
PredefinedType(Token(SyntaxKind.IntKeyword)),
ParenthesizedExpression(numElementsExpression));
}

static ExpressionSyntax GetIndexedNumElementsExpression(StubCodeContext context, TypePositionInfo numElementsInfo, out int numIndirectionLevels)
Expand Down Expand Up @@ -607,7 +597,7 @@ private static IMarshallingGenerator CreateNativeCollectionMarshaller(
if (info.IsManagedReturnPosition || (info.IsByRef && info.RefKind != RefKind.In))
{
// In this case, we need a numElementsExpression supplied from metadata, so we'll calculate it here.
numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, collectionInfo.ElementCountInfo, context, options);
numElementsExpression = GetNumElementsExpressionFromMarshallingInfo(info, collectionInfo.ElementCountInfo, context);
}

marshallingStrategy = new NumElementsExpressionMarshalling(
Expand Down
Loading

0 comments on commit eea0bae

Please sign in to comment.