From 9148dad9ce6636a938bd1e64380a8cd2f1064c40 Mon Sep 17 00:00:00 2001 From: Elinor Fung Date: Wed, 3 Nov 2021 16:15:15 -0700 Subject: [PATCH] Fix stub generation for char array marshalling --- .../Marshalling/CharMarshaller.cs | 2 +- .../DllImportGenerator.Tests/ArrayTests.cs | 22 +++++++++++++++ .../tests/TestAssets/NativeExports/Arrays.cs | 28 +++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs index 4852974aa8053..293b97959d63b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/Marshalling/CharMarshaller.cs @@ -96,7 +96,7 @@ public IEnumerable Generate(TypePositionInfo info, StubCodeCont case StubCodeContext.Stage.Setup: break; case StubCodeContext.Stage.Marshal: - if (info.IsByRef && info.RefKind != RefKind.Out) + if ((info.IsByRef && info.RefKind != RefKind.Out) || !context.SingleFrameSpansNativeContext) { yield return ExpressionStatement( AssignmentExpression( diff --git a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.Tests/ArrayTests.cs b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.Tests/ArrayTests.cs index 82039c4fac635..b52876d425f25 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.Tests/ArrayTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/DllImportGenerator.Tests/ArrayTests.cs @@ -35,6 +35,12 @@ public partial class Arrays [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")] public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 2)] out int[] res); + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_char_array", CharSet = CharSet.Unicode)] + public static partial int SumChars(char[] chars, int numElements); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "reverse_char_array", CharSet = CharSet.Unicode)] + public static partial void ReverseChars([MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)] ref char[] chars, int numElements); + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_string_lengths")] public static partial int SumStringLengths([MarshalAs(UnmanagedType.LPArray, ArraySubType = UnmanagedType.LPWStr)] string[] strArray); @@ -118,6 +124,22 @@ public void IntArrayRefParameter() Assert.Equal((IEnumerable)array, newArray); } + [Fact] + public void CharArrayMarshalledToNativeAsExpected() + { + char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray(); + Assert.Equal(array.Sum(c => c), NativeExportsNE.Arrays.SumChars(array, array.Length)); + } + + [Fact] + public void CharArrayRefParameter() + { + char[] array = CharacterTests.CharacterMappings().Select(o => (char)o[0]).ToArray(); + var newArray = array; + NativeExportsNE.Arrays.ReverseChars(ref newArray, array.Length); + Assert.Equal(array.Reverse(), newArray); + } + [Fact] public void ArraysReturnedFromNative() { diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Arrays.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Arrays.cs index 3e11f6c91298e..a22016275d69c 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Arrays.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Arrays.cs @@ -90,6 +90,34 @@ public static void DoubleValues([DNNE.C99Type("struct int_struct_wrapper*")] Int } } + [UnmanagedCallersOnly(EntryPoint = "sum_char_array")] + public static int SumChars(ushort* values, int numValues) + { + if (values == null) + { + return -1; + } + + int sum = 0; + for (int i = 0; i < numValues; i++) + { + sum += values[i]; + } + return sum; + } + + [UnmanagedCallersOnly(EntryPoint = "reverse_char_array")] + public static void ReverseChars(ushort** values, int numValues) + { + if (*values == null) + { + return; + } + + var span = new Span(*values, numValues); + span.Reverse(); + } + [UnmanagedCallersOnly(EntryPoint = "sum_string_lengths")] public static int SumStringLengths(ushort** strArray) {