From 810a7f9cd78d6611c0473940cb733231f3eabc06 Mon Sep 17 00:00:00 2001 From: Aaron Robinson Date: Fri, 20 May 2022 07:58:35 -0700 Subject: [PATCH] Add `BStrStringMarshaller` to source generator (#69213) * Add BStrStringMarshaller to source generator * Convert to use void* for BStr and Utf16 marshallers' native types. Co-authored-by: Jan Kotas --- .../System.Private.CoreLib.Shared.projitems | 1 + .../Marshalling/BStrStringMarshaller.cs | 125 ++++++++++++++++++ .../Marshalling/Utf16StringMarshaller.cs | 8 +- .../MarshallingAttributeInfo.cs | 39 +++--- .../TypeNames.cs | 1 + .../ref/System.Runtime.InteropServices.cs | 18 ++- .../StringTests.cs | 77 +++++++++++ .../Compiles.cs | 2 + .../tests/TestAssets/NativeExports/Strings.cs | 62 +++++++++ 9 files changed, 307 insertions(+), 26 deletions(-) create mode 100644 src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/BStrStringMarshaller.cs diff --git a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems index 478bd1cb3d008..a374c87fb383f 100644 --- a/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems +++ b/src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems @@ -866,6 +866,7 @@ + diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/BStrStringMarshaller.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/BStrStringMarshaller.cs new file mode 100644 index 0000000000000..ce5a83a69e1ab --- /dev/null +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/BStrStringMarshaller.cs @@ -0,0 +1,125 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Text; + +namespace System.Runtime.InteropServices.Marshalling +{ + /// + /// Marshaller for BSTR strings + /// + [CLSCompliant(false)] + [CustomTypeMarshaller(typeof(string), BufferSize = 0x100, + Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling | CustomTypeMarshallerFeatures.CallerAllocatedBuffer)] + public unsafe ref struct BStrStringMarshaller + { + private void* _ptrToFirstChar; + private bool _allocated; + + /// + /// Initializes a new instance of the . + /// + /// The string to marshal. + public BStrStringMarshaller(string? str) + : this(str, default) + { } + + /// + /// Initializes a new instance of the . + /// + /// The string to marshal. + /// Buffer that may be used for marshalling. + /// + /// The must not be movable - that is, it should not be + /// on the managed heap or it should be pinned. + /// + /// + public BStrStringMarshaller(string? str, Span buffer) + { + _allocated = false; + + if (str is null) + { + _ptrToFirstChar = null; + return; + } + + ushort* ptrToFirstChar; + int lengthInBytes = checked(sizeof(char) * str.Length); + + // A caller provided buffer must be at least (lengthInBytes + 6) bytes + // in order to be constructed manually. The 6 extra bytes are 4 for byte length and 2 for wide null. + int manualBstrNeeds = checked(lengthInBytes + 6); + if (manualBstrNeeds > buffer.Length) + { + // Use precise byte count when the provided stack-allocated buffer is not sufficient + ptrToFirstChar = (ushort*)Marshal.AllocBSTRByteLen((uint)lengthInBytes); + _allocated = true; + } + else + { + // Set length and update buffer target + byte* pBuffer = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(buffer)); + *((uint*)pBuffer) = (uint)lengthInBytes; + ptrToFirstChar = (ushort*)(pBuffer + sizeof(uint)); + } + + // Confirm the size is properly set for the allocated BSTR. + Debug.Assert(lengthInBytes == Marshal.SysStringByteLen((IntPtr)ptrToFirstChar)); + + // Copy characters from the managed string + str.CopyTo(new Span(ptrToFirstChar, str.Length)); + ptrToFirstChar[str.Length] = '\0'; // null-terminate + _ptrToFirstChar = ptrToFirstChar; + } + + /// + /// Returns the native value representing the string. + /// + /// + /// + /// + public void* ToNativeValue() => _ptrToFirstChar; + + /// + /// Sets the native value representing the string. + /// + /// The native value. + /// + /// + /// + public void FromNativeValue(void* value) + { + _ptrToFirstChar = value; + _allocated = true; + } + + /// + /// Returns the managed string. + /// + /// + /// + /// + public string? ToManaged() + { + if (_ptrToFirstChar is null) + return null; + + return Marshal.PtrToStringBSTR((IntPtr)_ptrToFirstChar); + } + + /// + /// Frees native resources. + /// + /// + /// + /// + public void FreeNative() + { + if (_allocated) + Marshal.FreeBSTR((IntPtr)_ptrToFirstChar); + } + } +} diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs index e207bec5f9fcf..04ada6d9d5132 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/Utf16StringMarshaller.cs @@ -13,7 +13,7 @@ namespace System.Runtime.InteropServices.Marshalling Features = CustomTypeMarshallerFeatures.UnmanagedResources | CustomTypeMarshallerFeatures.TwoStageMarshalling)] public unsafe ref struct Utf16StringMarshaller { - private ushort* _nativeValue; + private void* _nativeValue; /// /// Initializes a new instance of the . @@ -25,7 +25,7 @@ public unsafe ref struct Utf16StringMarshaller /// The string to marshal. public Utf16StringMarshaller(string? str) { - _nativeValue = (ushort*)Marshal.StringToCoTaskMemUni(str); + _nativeValue = (void*)Marshal.StringToCoTaskMemUni(str); } /// @@ -34,7 +34,7 @@ public Utf16StringMarshaller(string? str) /// /// /// - public ushort* ToNativeValue() => _nativeValue; + public void* ToNativeValue() => _nativeValue; /// /// Sets the native value representing the string. @@ -43,7 +43,7 @@ public Utf16StringMarshaller(string? str) /// /// /// - public void FromNativeValue(ushort* value) => _nativeValue = value; + public void FromNativeValue(void* value) => _nativeValue = value; /// /// Returns the managed string. diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs index 33270eb954550..f2a98d6b246fd 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshallingAttributeInfo.cs @@ -11,7 +11,6 @@ namespace Microsoft.Interop { - /// /// Type used to pass on default marshalling details. /// @@ -72,7 +71,6 @@ public enum CharEncoding Undefined, Utf8, Utf16, - Ansi, Custom } @@ -761,7 +759,13 @@ private bool TryCreateTypeBasedMarshallingInfo( } else { - marshallingInfo = CreateStringMarshallingInfo(type, _defaultInfo.CharEncoding); + marshallingInfo = _defaultInfo.CharEncoding switch + { + CharEncoding.Utf16 => CreateStringMarshallingInfo(type, TypeNames.Utf16StringMarshaller), + CharEncoding.Utf8 => CreateStringMarshallingInfo(type, TypeNames.Utf8StringMarshaller), + _ => throw new InvalidOperationException() + }; + return true; } @@ -842,30 +846,25 @@ private MarshallingInfo CreateStringMarshallingInfo( ITypeSymbol type, UnmanagedType unmanagedType) { - CharEncoding charEncoding = unmanagedType switch + string? marshallerName = unmanagedType switch { - UnmanagedType.LPStr => CharEncoding.Ansi, - UnmanagedType.LPTStr or UnmanagedType.LPWStr => CharEncoding.Utf16, - MarshalAsInfo.UnmanagedType_LPUTF8Str => CharEncoding.Utf8, - _ => CharEncoding.Undefined + UnmanagedType.BStr => TypeNames.BStrStringMarshaller, + UnmanagedType.LPStr => TypeNames.AnsiStringMarshaller, + UnmanagedType.LPTStr or UnmanagedType.LPWStr => TypeNames.Utf16StringMarshaller, + MarshalAsInfo.UnmanagedType_LPUTF8Str => TypeNames.Utf8StringMarshaller, + _ => null }; - if (charEncoding == CharEncoding.Undefined) + + if (marshallerName is null) return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding); - return CreateStringMarshallingInfo(type, charEncoding); + return CreateStringMarshallingInfo(type, marshallerName); } private MarshallingInfo CreateStringMarshallingInfo( ITypeSymbol type, - CharEncoding charEncoding) + string marshallerName) { - string? marshallerName = charEncoding switch - { - CharEncoding.Ansi => TypeNames.AnsiStringMarshaller, - CharEncoding.Utf16 => TypeNames.Utf16StringMarshaller, - CharEncoding.Utf8 => TypeNames.Utf8StringMarshaller, - _ => throw new InvalidOperationException() - }; INamedTypeSymbol? stringMarshaller = _compilation.GetTypeByMetadataName(marshallerName); if (stringMarshaller is null) return new MissingSupportMarshallingInfo(); @@ -876,9 +875,9 @@ private MarshallingInfo CreateStringMarshallingInfo( return CreateNativeMarshallingInfoForValue( type, stringMarshaller, - default, + null, customTypeMarshallerData.Value, - allowPinningManagedType: charEncoding == CharEncoding.Utf16, + allowPinningManagedType: marshallerName is TypeNames.Utf16StringMarshaller, useDefaultMarshalling: false); } diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs index 876dc649001f9..778c1cd2beb39 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs @@ -18,6 +18,7 @@ public static class TypeNames public const string CustomTypeMarshallerAttributeGenericPlaceholder = "System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute.GenericPlaceholder"; public const string AnsiStringMarshaller = "System.Runtime.InteropServices.Marshalling.AnsiStringMarshaller"; + public const string BStrStringMarshaller = "System.Runtime.InteropServices.Marshalling.BStrStringMarshaller"; public const string Utf16StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf16StringMarshaller"; public const string Utf8StringMarshaller = "System.Runtime.InteropServices.Marshalling.Utf8StringMarshaller"; diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index ca618ba5c7b93..895df2ce26a16 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -2103,6 +2103,20 @@ public void FromNativeValue(byte* value) { } public T[]? ToManaged() { throw null; } public void FreeNative() { } } + [System.CLSCompliant(false)] + [System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerAttribute(typeof(string), BufferSize = 0x100, + Features = System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.UnmanagedResources + | System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.CallerAllocatedBuffer + | System.Runtime.InteropServices.Marshalling.CustomTypeMarshallerFeatures.TwoStageMarshalling )] + public unsafe ref struct BStrStringMarshaller + { + public BStrStringMarshaller(string? str) { } + public BStrStringMarshaller(string? str, System.Span buffer) { } + public void* ToNativeValue() { throw null; } + public void FromNativeValue(void* value) { } + public string? ToManaged() { throw null; } + public void FreeNative() { } + } [System.AttributeUsageAttribute(System.AttributeTargets.Struct)] public sealed partial class CustomTypeMarshallerAttribute : System.Attribute { @@ -2197,8 +2211,8 @@ public void FreeNative() { } public unsafe ref struct Utf16StringMarshaller { public Utf16StringMarshaller(string? str) { } - public ushort* ToNativeValue() { throw null; } - public void FromNativeValue(ushort* value) { } + public void* ToNativeValue() { throw null; } + public void FromNativeValue(void* value) { } public string? ToManaged() { throw null; } public void FreeNative() { } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs index 9987c687f8455..b518f6ed0d9fa 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.Tests/StringTests.cs @@ -23,6 +23,7 @@ private class EntryPoints private const string UShortSuffix = "_ushort"; private const string ByteSuffix = "_byte"; + private const string BStrSuffix = "_bstr"; public class Byte { @@ -41,6 +42,15 @@ public class UShort public const string ReverseInplace = EntryPoints.ReverseInplace + UShortSuffix; public const string ReverseReplace = EntryPoints.ReverseReplace + UShortSuffix; } + + public class BStr + { + public const string ReturnLength = EntryPoints.ReturnLength + BStrSuffix; + public const string ReverseReturn = EntryPoints.ReverseReturn + BStrSuffix; + public const string ReverseOut = EntryPoints.ReverseOut + BStrSuffix; + public const string ReverseInplace = EntryPoints.ReverseInplace + BStrSuffix; + public const string ReverseReplace = EntryPoints.ReverseReplace + BStrSuffix; + } } public partial class Utf16 @@ -185,6 +195,31 @@ public partial class LPStr public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.LPStr)] ref string s); } + public partial class BStr + { + [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength)] + public static partial int ReturnLength([MarshalAs(UnmanagedType.BStr)] string s); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReturnLength, StringMarshalling = StringMarshalling.Utf16)] + public static partial int ReturnLength_IgnoreStringMarshalling([MarshalAs(UnmanagedType.BStr)] string s); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReturn)] + [return: MarshalAs(UnmanagedType.BStr)] + public static partial string Reverse_Return([MarshalAs(UnmanagedType.BStr)] string s); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseOut)] + public static partial void Reverse_Out([MarshalAs(UnmanagedType.BStr)] string s, [MarshalAs(UnmanagedType.BStr)] out string ret); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)] + public static partial void Reverse_Ref([MarshalAs(UnmanagedType.BStr)] ref string s); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseInplace)] + public static partial void Reverse_In([MarshalAs(UnmanagedType.BStr)] in string s); + + [LibraryImport(NativeExportsNE_Binary, EntryPoint = EntryPoints.BStr.ReverseReplace)] + public static partial void Reverse_Replace_Ref([MarshalAs(UnmanagedType.BStr)] ref string s); + } + public partial class StringMarshallingCustomType { public partial class Utf16 @@ -418,6 +453,48 @@ public void AnsiStringByRef(string value) Assert.Equal(expected, refValue); } + [Theory] + [MemberData(nameof(UnicodeStrings))] + public void BStrStringMarshalledAsExpected(string value) + { + int expectedLen = value != null ? value.Length : -1; + + Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength(value)); + Assert.Equal(expectedLen, NativeExportsNE.BStr.ReturnLength_IgnoreStringMarshalling(value)); + } + + [Theory] + [MemberData(nameof(UnicodeStrings))] + public void BStrStringReturn(string value) + { + string expected = ReverseChars(value); + + Assert.Equal(expected, NativeExportsNE.BStr.Reverse_Return(value)); + + string ret; + NativeExportsNE.BStr.Reverse_Out(value, out ret); + Assert.Equal(expected, ret); + } + + [Theory] + [MemberData(nameof(UnicodeStrings))] + public void BStrStringByRef(string value) + { + string refValue = value; + string expected = ReverseChars(value); + + NativeExportsNE.BStr.Reverse_In(in refValue); + Assert.Equal(value, refValue); // Should not be updated when using 'in' + + refValue = value; + NativeExportsNE.BStr.Reverse_Ref(ref refValue); + Assert.Equal(expected, refValue); + + refValue = value; + NativeExportsNE.BStr.Reverse_Replace_Ref(ref refValue); + Assert.Equal(expected, refValue); + } + [Theory] [MemberData(nameof(UnicodeStrings))] public void StringMarshallingCustomType_MarshalledAsExpected(string value) diff --git a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs index e8ca396d89560..3448c83275caf 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs @@ -108,9 +108,11 @@ public static IEnumerable CodeSnippetsToCompile() yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers(UnmanagedType.LPTStr) }; yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers(UnmanagedType.LPUTF8Str) }; yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers(UnmanagedType.LPStr) }; + yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers(UnmanagedType.BStr) }; yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType.LPWStr) }; yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType.LPUTF8Str) }; yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType.LPStr) }; + yield return new[] { CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo(UnmanagedType.BStr) }; // [In, Out] attributes // By value non-blittable array diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Strings.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Strings.cs index acb72da759ce5..1610a49069eed 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Strings.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/Strings.cs @@ -26,6 +26,15 @@ public static int ReturnLengthByte(byte* input) return GetLength(input); } + [UnmanagedCallersOnly(EntryPoint = "return_length_bstr")] + public static int ReturnLengthBStr(byte* input) + { + if (input == null) + return -1; + + return GetLengthBStr(input); + } + [UnmanagedCallersOnly(EntryPoint = "reverse_return_ushort")] public static ushort* ReverseReturnUShort(ushort* input) { @@ -38,6 +47,12 @@ public static int ReturnLengthByte(byte* input) return Reverse(input); } + [UnmanagedCallersOnly(EntryPoint = "reverse_return_bstr")] + public static byte* ReverseReturnBStr(byte* input) + { + return ReverseBStr(input); + } + [UnmanagedCallersOnly(EntryPoint = "reverse_out_ushort")] public static void ReverseReturnAsOutUShort(ushort* input, ushort** ret) { @@ -50,6 +65,12 @@ public static void ReverseReturnAsOutByte(byte* input, byte** ret) *ret = Reverse(input); } + [UnmanagedCallersOnly(EntryPoint = "reverse_out_bstr")] + public static void ReverseReturnAsOutBStr(byte* input, byte** ret) + { + *ret = ReverseBStr(input); + } + [UnmanagedCallersOnly(EntryPoint = "reverse_inplace_ref_ushort")] public static void ReverseInPlaceUShort(ushort** refInput) { @@ -69,6 +90,17 @@ public static void ReverseInPlaceByte(byte** refInput) span.Reverse(); } + [UnmanagedCallersOnly(EntryPoint = "reverse_inplace_ref_bstr")] + public static void ReverseInPlaceBStr(byte** refInput) + { + int len = GetLengthBStr(*refInput); + + // Testing of BSTRs is done under the assumption the + // test character input size is 16 bit. + var span = new Span(*refInput, len); + span.Reverse(); + } + [UnmanagedCallersOnly(EntryPoint = "reverse_replace_ref_ushort")] public static void ReverseReplaceRefUShort(ushort** s) { @@ -91,6 +123,17 @@ public static void ReverseReplaceRefByte(byte** s) *s = ret; } + [UnmanagedCallersOnly(EntryPoint = "reverse_replace_ref_bstr")] + public static void ReverseReplaceRefBStr(byte** s) + { + if (*s == null) + return; + + byte* ret = ReverseBStr(*s); + Marshal.FreeBSTR((IntPtr)(*s)); + *s = ret; + } + internal static ushort* Reverse(ushort *s) { if (s == null) @@ -121,6 +164,17 @@ public static void ReverseReplaceRefByte(byte** s) return ret; } + internal static byte* ReverseBStr(byte* s) + { + if (s == null) + return null; + + var arr = Marshal.PtrToStringBSTR((IntPtr)s).ToCharArray(); + Array.Reverse(arr); + var revStr = new string(arr); + return (byte*)Marshal.StringToBSTR(revStr); + } + private static int GetLength(ushort* input) { if (input == null) @@ -150,5 +204,13 @@ private static int GetLength(byte* input) return len; } + + private static int GetLengthBStr(byte* input) + { + if (input == null) + return 0; + + return Marshal.PtrToStringBSTR((IntPtr)input).Length; + } } }