diff --git a/include/dxc/Support/WinAdapter.h b/include/dxc/Support/WinAdapter.h index 311ed522af..ac96f63b0c 100644 --- a/include/dxc/Support/WinAdapter.h +++ b/include/dxc/Support/WinAdapter.h @@ -45,10 +45,6 @@ #define CoTaskMemAlloc malloc #define CoTaskMemFree free -#define SysFreeString free -#define SysAllocStringLen(ptr, size) \ - (wchar_t *)realloc(ptr, (size + 1) * sizeof(wchar_t)) - #define ARRAYSIZE(array) (sizeof(array) / sizeof(array[0])) #define _countof(a) (sizeof(a) / sizeof(*(a))) @@ -916,6 +912,12 @@ class CHeapPtr : public CHeapPtrBase { #define CComHeapPtr CHeapPtr +//===--------------------------- BSTR Allocation --------------------------===// + +void SysFreeString(BSTR bstrString); +// Allocate string with length prefix +BSTR SysAllocStringLen(const OLECHAR *strIn, UINT ui); + //===--------------------- UTF-8 Related Types ----------------------------===// // Code Page diff --git a/lib/DxcSupport/WinAdapter.cpp b/lib/DxcSupport/WinAdapter.cpp index da2a17899c..dcc6303cc6 100644 --- a/lib/DxcSupport/WinAdapter.cpp +++ b/lib/DxcSupport/WinAdapter.cpp @@ -68,6 +68,36 @@ void *CAllocator::Reallocate(void *p, size_t nBytes) throw() { void *CAllocator::Allocate(size_t nBytes) throw() { return malloc(nBytes); } void CAllocator::Free(void *p) throw() { free(p); } +//===--------------------------- BSTR Allocation --------------------------===// + +void SysFreeString(BSTR bstrString) { + if (bstrString) + free((void *)((uintptr_t)bstrString - sizeof(uint32_t))); +} + +// Allocate string with length prefix +// https://docs.microsoft.com/en-us/previous-versions/windows/desktop/automat/bstr +BSTR SysAllocStringLen(const OLECHAR *strIn, UINT ui) { + uint32_t *blobOut = + (uint32_t *)malloc(sizeof(uint32_t) + (ui + 1) * sizeof(OLECHAR)); + + if (!blobOut) + return nullptr; + + // Size in bytes without trailing NULL character + blobOut[0] = ui * sizeof(OLECHAR); + + BSTR strOut = (BSTR)&blobOut[1]; + + if (strIn) + memcpy(strOut, strIn, blobOut[0]); + + // Write trailing NULL character: + strOut[ui] = 0; + + return strOut; +} + //===---------------------- Char converstion ------------------------------===// const char *CPToLocale(uint32_t CodePage) {