Skip to content

Commit

Permalink
Add Per-assembly Load Native Library callbacks
Browse files Browse the repository at this point in the history
This Change implements the Native Library resolution
Call-backs proposed in https://github.com/dotnet/corefx/issues/32015
  • Loading branch information
swaroop-sridhar authored and swaroop-sridhar committed Jan 11, 2019
1 parent 459b58a commit 39b81b9
Show file tree
Hide file tree
Showing 13 changed files with 334 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@

namespace System.Runtime.InteropServices
{

/// <summary>
/// A delegate used to resolve native libraries via callback.
/// </summary>
/// <param name="libraryName">The native library to resolve</param>
/// <param name="assembly">The assembly requesting the resolution</param>
/// <param name="DllImportSearchPath?">
/// The DllImportSearchPathsAttribute on the PInvoke, if any.
/// Otherwise, the DllImportSearchPathsAttribute on the assembly, if any.
/// Otherwise null.
/// </param>
/// <returns>The handle for the loaded native library on success, null on failure</returns>
public delegate IntPtr DllImportResolver(string libraryName,
Assembly assembly,
DllImportSearchPath? searchPath);

/// <summary>
/// APIs for managing Native Libraries
/// </summary>
Expand Down Expand Up @@ -58,7 +74,9 @@ public static bool TryLoad(string libraryPath, out IntPtr handle)
/// Otherwise, the flags specified by the DefaultDllImportSearchPaths attribute on the
/// calling assembly (if any) are used.
/// This LoadLibrary() method does not invoke the managed call-backs for native library resolution:
/// * The per-assembly registered callback
/// * AssemblyLoadContext.LoadUnmanagedDll()
/// * AssemblyLoadContext.ResolvingUnmanagedDllEvent
/// </summary>
/// <param name="libraryName">The name of the native library to be loaded</param>
/// <param name="assembly">The assembly loading the native library</param>
Expand Down Expand Up @@ -161,6 +179,74 @@ public static bool TryGetExport(IntPtr handle, string name, out IntPtr address)
return address != IntPtr.Zero;
}

/// <summary>
/// Map from assembly to native-library-resolution-callback.
/// Generally interop specific fields and properties are not added to assembly.
/// Therefore, this table uses weak assembly pointers to indirectly achieve
/// similar behavior.
/// </summary>
public static ConditionalWeakTable<Assembly, DllImportResolver> s_nativeDllResolveMap = null;

/// <summary>
/// Set a callback for resolving native library imports from an assembly.
/// This per-assembly callback is the first attempt to resolve native library loads
/// initiated by this assembly.
///
/// Only one resolver callback can be registered per assembly.
/// Trying to register a second callback fails with InvalidOperationException.
/// </summary>
/// <param name="assembly">The assembly for which the callback is registered</param>
/// <param name="resolver">The callback to register</param>
/// <exception cref="System.ArgumentNullException">If assembly or resolver is null</exception>
/// <exception cref="System.InvalidOperationException">If a callback is already set for this assembly</exception>
public static void SetDllImportResolver(Assembly assembly, DllImportResolver resolver)
{
if (assembly == null)
throw new ArgumentNullException(nameof(assembly));
if (resolver == null)
throw new ArgumentNullException(nameof(resolver));
if (!(assembly is RuntimeAssembly))
throw new ArgumentException(SR.Argument_MustBeRuntimeAssembly);

DllImportResolver existingResolver = null;

if (s_nativeDllResolveMap == null)
{
s_nativeDllResolveMap = new ConditionalWeakTable<Assembly, DllImportResolver>();
}
else if (s_nativeDllResolveMap.TryGetValue(assembly, out existingResolver))
{
if (existingResolver != resolver)
throw new InvalidOperationException();
return;
}

s_nativeDllResolveMap.Add(assembly, resolver);
}

/// <summary>
/// The helper function that calls the per-assembly native-library resolver
/// if one is registered for this assembly.
/// </summary>
/// <param name="libraryName">The native library to load</param>
/// <param name="assembly">The assembly trying load the native library</param>
/// <param name="hasDllImportSearchPathFlags">If the pInvoke has DefaultDllImportSearchPathAttribute</param>
/// <param name="dllImportSearchPathFlags">If hasdllImportSearchPathFlags is true, the flags in
/// DefaultDllImportSearchPathAttribute; meaningless otherwise </param>
/// <returns>The handle for the loaded library on success. Null on failure.</returns>
internal static IntPtr LoadLibraryCallbackStub(string libraryName, Assembly assembly,
bool hasDllImportSearchPathFlags, uint dllImportSearchPathFlags)
{
DllImportResolver resolver;

if (!s_nativeDllResolveMap.TryGetValue(assembly, out resolver))
{
return IntPtr.Zero;
}

return resolver(libraryName, assembly, hasDllImportSearchPathFlags ? (DllImportSearchPath?)dllImportSearchPathFlags : null);
}

/// External functions that implement the NativeLibrary interface

[DllImport(JitHelpers.QCall, CharSet = CharSet.Unicode)]
Expand Down
1 change: 1 addition & 0 deletions src/vm/callhelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ enum DispatchCallSimpleFlags
#define STRINGREF_TO_ARGHOLDER(x) (LPVOID)STRINGREFToObject(x)
#define PTR_TO_ARGHOLDER(x) (LPVOID)x
#define DWORD_TO_ARGHOLDER(x) (LPVOID)(SIZE_T)x
#define BOOL_TO_ARGHOLDER(x) DWORD_TO_ARGHOLDER(!!(x))

#define INIT_VARIABLES(count) \
DWORD __numArgs = count; \
Expand Down
119 changes: 91 additions & 28 deletions src/vm/dllimport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6144,7 +6144,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryFromPath(LPCWSTR libraryPath, BOOL thr

// static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *callingAssembly,
BOOL hasDllImportSearchFlag, DWORD dllImportSearchFlag,
BOOL hasDllImportSearchFlags, DWORD dllImportSearchFlags,
BOOL throwOnError)
{
CONTRACTL
Expand All @@ -6157,15 +6157,15 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *

LoadLibErrorTracker errorTracker;

// First checks if a default DllImportSearchPathFlag was passed in, if so, use that value.
// First checks if a default dllImportSearchPathFlags was passed in, if so, use that value.
// Otherwise checks if the assembly has the DefaultDllImportSearchPathsAttribute attribute. If so, use that value.
BOOL searchAssemblyDirectory = TRUE;
DWORD dllImportSearchPathFlag = 0;
DWORD dllImportSearchPathFlags = 0;

if (hasDllImportSearchFlag)
if (hasDllImportSearchFlags)
{
dllImportSearchPathFlag = dllImportSearchFlag & ~DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
searchAssemblyDirectory = dllImportSearchFlag & DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
dllImportSearchPathFlags = dllImportSearchFlags & ~DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
searchAssemblyDirectory = dllImportSearchFlags & DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;

}
else
Expand All @@ -6174,13 +6174,13 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryByName(LPCWSTR libraryName, Assembly *

if (pModule->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlag = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
dllImportSearchPathFlags = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pModule->DllImportSearchAssemblyDirectory();
}
}

NATIVE_LIBRARY_HANDLE hmod =
LoadLibraryModuleBySearch(callingAssembly, searchAssemblyDirectory, dllImportSearchPathFlag, &errorTracker, libraryName);
LoadLibraryModuleBySearch(callingAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, &errorTracker, libraryName);

if (throwOnError && (hmod == nullptr))
{
Expand All @@ -6199,11 +6199,11 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(NDirectMethodDesc * pMD
// First checks if the method has DefaultDllImportSearchPathsAttribute. If so, use that value.
// Otherwise checks if the assembly has the attribute. If so, use that value.
BOOL searchAssemblyDirectory = TRUE;
DWORD dllImportSearchPathFlag = 0;
DWORD dllImportSearchPathFlags = 0;

if (pMD->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlag = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
dllImportSearchPathFlags = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pMD->DllImportSearchAssemblyDirectory();
}
else
Expand All @@ -6212,13 +6212,13 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(NDirectMethodDesc * pMD

if (pModule->HasDefaultDllImportSearchPathsAttribute())
{
dllImportSearchPathFlag = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
dllImportSearchPathFlags = pModule->DefaultDllImportSearchPathsAttributeCachedValue();
searchAssemblyDirectory = pModule->DllImportSearchAssemblyDirectory();
}
}

Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();
return LoadLibraryModuleBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlag, pErrorTracker, wszLibName);
return LoadLibraryModuleBySearch(pAssembly, searchAssemblyDirectory, dllImportSearchPathFlags, pErrorTracker, wszLibName);
}

// static
Expand Down Expand Up @@ -6267,23 +6267,34 @@ INT_PTR NDirect::GetNativeLibraryExport(NATIVE_LIBRARY_HANDLE handle, LPCWSTR sy
return address;
}

#ifndef PLATFORM_UNIX
BOOL IsWindowsAPISet(PCWSTR wszLibName)
{
STANDARD_VM_CONTRACT;

// This is replicating quick check from the OS implementation of api sets.
return SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 ||
SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0;
}
#endif // !PLATFORM_UNIX

// static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaHost(NDirectMethodDesc * pMD, AppDomain* pDomain, PCWSTR wszLibName)
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaHost(NDirectMethodDesc * pMD, PCWSTR wszLibName)
{
STANDARD_VM_CONTRACT;
//Dynamic Pinvoke Support:
//Check if we need to provide the host a chance to provide the unmanaged dll

#ifndef PLATFORM_UNIX
// Prevent Overriding of Windows API sets.
// This is replicating quick check from the OS implementation of api sets.
if (SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 || SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0)
if (IsWindowsAPISet(wszLibName))
{
// Prevent Overriding of Windows API sets.
return NULL;
}
#endif
#endif // !PLATFORM_UNIX

LPVOID hmod = NULL;
AppDomain* pDomain = GetAppDomain();
CLRPrivBinderCoreCLR *pTPABinder = pDomain->GetTPABinderContext();
Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();

Expand Down Expand Up @@ -6349,6 +6360,49 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaHost(NDirectMethodDesc * pMD,
return (NATIVE_LIBRARY_HANDLE)hmod;
}

NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleViaCallback(NDirectMethodDesc * pMD, LPCWSTR wszLibName)
{
STANDARD_VM_CONTRACT;

NATIVE_LIBRARY_HANDLE handle = NULL;

DWORD dllImportSearchPathFlags = 0;
BOOL hasDllImportSearchPathFlags = pMD->HasDefaultDllImportSearchPathsAttribute();
if (hasDllImportSearchPathFlags)
{
dllImportSearchPathFlags = pMD->DefaultDllImportSearchPathsAttributeCachedValue();
if (pMD->DllImportSearchAssemblyDirectory())
dllImportSearchPathFlags |= DLLIMPORTSEARCHPATH_ASSEMBLYDIRECTORY;
}

Assembly* pAssembly = pMD->GetMethodTable()->GetAssembly();

GCX_COOP();

struct {
STRINGREF libNameRef;
OBJECTREF assemblyRef;
} gc = { NULL, NULL };

GCPROTECT_BEGIN(gc);

gc.libNameRef = StringObject::NewString(wszLibName);
gc.assemblyRef = pAssembly->GetExposedObject();

PREPARE_NONVIRTUAL_CALLSITE(METHOD__MARSHAL__LOADLIBRARYCALLBACKSTUB);
DECLARE_ARGHOLDER_ARRAY(args, 4);
args[ARGNUM_0] = STRINGREF_TO_ARGHOLDER(gc.libNameRef);
args[ARGNUM_1] = OBJECTREF_TO_ARGHOLDER(gc.assemblyRef);
args[ARGNUM_2] = BOOL_TO_ARGHOLDER(hasDllImportSearchPathFlags);
args[ARGNUM_3] = DWORD_TO_ARGHOLDER(dllImportSearchPathFlags);

// Make the call
CALL_MANAGED_METHOD(handle, NATIVE_LIBRARY_HANDLE, args);
GCPROTECT_END();

return handle;
}

// Try to load the module alongside the assembly where the PInvoke was declared.
NATIVE_LIBRARY_HANDLE NDirect::LoadFromPInvokeAssemblyDirectory(Assembly *pAssembly, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker)
{
Expand All @@ -6372,11 +6426,12 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadFromPInvokeAssemblyDirectory(Assembly *pAssem
}

// Try to load the module from the native DLL search directories
NATIVE_LIBRARY_HANDLE NDirect::LoadFromNativeDllSearchDirectories(AppDomain* pDomain, LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker)
NATIVE_LIBRARY_HANDLE NDirect::LoadFromNativeDllSearchDirectories(LPCWSTR libName, DWORD flags, LoadLibErrorTracker *pErrorTracker)
{
STANDARD_VM_CONTRACT;

NATIVE_LIBRARY_HANDLE hmod = NULL;
AppDomain* pDomain = GetAppDomain();

if (pDomain->HasNativeDllSearchDirectories())
{
Expand Down Expand Up @@ -6498,7 +6553,7 @@ static void DetermineLibNameVariations(const WCHAR** libNameVariations, int* num
// Search for the library and variants of its name in probing directories.
//static
NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssembly,
BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlag,
BOOL searchAssemblyDirectory, DWORD dllImportSearchPathFlags,
LoadLibErrorTracker * pErrorTracker, LPCWSTR wszLibName)
{
STANDARD_VM_CONTRACT;
Expand All @@ -6508,7 +6563,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
#if defined(FEATURE_CORESYSTEM) && !defined(PLATFORM_UNIX)
// Try to go straight to System32 for Windows API sets. This is replicating quick check from
// the OS implementation of api sets.
if (SString::_wcsnicmp(wszLibName, W("api-"), 4) == 0 || SString::_wcsnicmp(wszLibName, W("ext-"), 4) == 0)
if (IsWindowsAPISet(wszLibName))
{
hmod = LocalLoadLibraryHelper(wszLibName, LOAD_LIBRARY_SEARCH_SYSTEM32, pErrorTracker);
if (hmod != NULL)
Expand Down Expand Up @@ -6536,7 +6591,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
currLibNameVariation.Printf(prefixSuffixCombinations[i], PLATFORM_SHARED_LIB_PREFIX_W, wszLibName, PLATFORM_SHARED_LIB_SUFFIX_W);

// NATIVE_DLL_SEARCH_DIRECTORIES set by host is considered well known path
hmod = LoadFromNativeDllSearchDirectories(pDomain, currLibNameVariation, loadWithAlteredPathFlags, pErrorTracker);
hmod = LoadFromNativeDllSearchDirectories(currLibNameVariation, loadWithAlteredPathFlags, pErrorTracker);
if (hmod != NULL)
{
return hmod;
Expand All @@ -6545,11 +6600,11 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
if (!libNameIsRelativePath)
{
DWORD flags = loadWithAlteredPathFlags;
if ((dllImportSearchPathFlag & LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) != 0)
if ((dllImportSearchPathFlags & LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) != 0)
{
// LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR is the only flag affecting absolute path. Don't OR the flags
// unconditionally as all absolute path P/Invokes could then lose LOAD_WITH_ALTERED_SEARCH_PATH.
flags |= dllImportSearchPathFlag;
flags |= dllImportSearchPathFlags;
}

hmod = LocalLoadLibraryHelper(currLibNameVariation, flags, pErrorTracker);
Expand All @@ -6560,14 +6615,14 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
}
else if ((callingAssembly != nullptr) && searchAssemblyDirectory)
{
hmod = LoadFromPInvokeAssemblyDirectory(callingAssembly, currLibNameVariation, loadWithAlteredPathFlags | dllImportSearchPathFlag, pErrorTracker);
hmod = LoadFromPInvokeAssemblyDirectory(callingAssembly, currLibNameVariation, loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker);
if (hmod != NULL)
{
return hmod;
}
}

hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlag, pErrorTracker);
hmod = LocalLoadLibraryHelper(currLibNameVariation, dllImportSearchPathFlags, pErrorTracker);
if (hmod != NULL)
{
return hmod;
Expand Down Expand Up @@ -6597,7 +6652,7 @@ NATIVE_LIBRARY_HANDLE NDirect::LoadLibraryModuleBySearch(Assembly *callingAssemb
Assembly *pAssembly = spec.LoadAssembly(FILE_LOADED);
Module *pModule = pAssembly->FindModuleByName(szLibName);

hmod = LocalLoadLibraryHelper(pModule->GetPath(), loadWithAlteredPathFlags | dllImportSearchPathFlag, pErrorTracker);
hmod = LocalLoadLibraryHelper(pModule->GetPath(), loadWithAlteredPathFlags | dllImportSearchPathFlags, pErrorTracker);
}
}

Expand All @@ -6618,19 +6673,27 @@ HINSTANCE NDirect::LoadLibraryModule(NDirectMethodDesc * pMD, LoadLibErrorTracke
if ( !name || !*name )
return NULL;

ModuleHandleHolder hmod;

PREFIX_ASSUME( name != NULL );
MAKE_WIDEPTR_FROMUTF8( wszLibName, name );

ModuleHandleHolder hmod = LoadLibraryModuleViaCallback(pMD, wszLibName);
if (hmod != NULL)
{
#ifdef FEATURE_PAL
hmod = PAL_RegisterLibraryDirect(hmod, wszLibName);
#endif // FEATURE_PAL
return hmod.Extract();
}

AppDomain* pDomain = GetAppDomain();

// AssemblyLoadContext is not supported in AppX mode and thus,
// we should not perform PInvoke resolution via it when operating in
// AppX mode.
if (!AppX::IsAppXProcess())
{
hmod = LoadLibraryModuleViaHost(pMD, pDomain, wszLibName);
hmod = LoadLibraryModuleViaHost(pMD, wszLibName);
if (hmod != NULL)
{
#ifdef FEATURE_PAL
Expand Down
Loading

0 comments on commit 39b81b9

Please sign in to comment.