Skip to content

Commit

Permalink
Add support for default implementation of static virtuals with method…
Browse files Browse the repository at this point in the history
… constraints (#89061)

- The major problem was the logic which incorrectly would instantiate the methods when it wasn't necessary
- As the number of flags to the implementation functions has grown very large, this change also includes logic converting them all to a single flags variable when passing them around

Fixes #73658
Fixes #78865
  • Loading branch information
davidwrighton authored Jul 19, 2023
1 parent ef4860a commit d1adf81
Show file tree
Hide file tree
Showing 13 changed files with 488 additions and 41 deletions.
52 changes: 52 additions & 0 deletions src/coreclr/inc/enum_class_flags.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

#ifndef ENUM_CLASS_FLAGS_OPERATORS
#define ENUM_CLASS_FLAGS_OPERATORS

template <typename T>
inline auto operator& (T left, T right) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(static_cast<int>(left) & static_cast<int>(right));
}

template <typename T>
inline auto operator| (T left, T right) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(static_cast<int>(left) | static_cast<int>(right));
}

template <typename T>
inline auto operator^ (T left, T right) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(static_cast<int>(left) ^ static_cast<int>(right));
}

template <typename T>
inline auto operator~ (T value) -> decltype(T::support_use_as_flags)
{
return static_cast<T>(~static_cast<int>(value));
}

template <typename T>
inline auto operator |= (T& left, T right) -> const decltype(T::support_use_as_flags)&
{
left = left | right;
return left;
}

template <typename T>
inline auto operator &= (T& left, T right) -> const decltype(T::support_use_as_flags)&
{
left = left & right;
return left;
}

template <typename T>
inline auto operator ^= (T& left, T right) -> const decltype(T::support_use_as_flags)&
{
left = left ^ right;
return left;
}

#endif /* ENUM_CLASS_FLAGS_OPERATORS */
6 changes: 3 additions & 3 deletions src/coreclr/vm/genericdict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1144,9 +1144,9 @@ Dictionary::PopulateEntry(
pResolvedMD = constraintType.GetMethodTable()->ResolveVirtualStaticMethod(
ownerType.GetMethodTable(),
pMethod,
/* allowNullResult */ TRUE,
/* verifyImplemented */ FALSE,
/* allowVariantMatches */ TRUE,
ResolveVirtualStaticMethodFlags::AllowNullResult |
ResolveVirtualStaticMethodFlags::AllowVariantMatches |
ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc,
&uniqueResolution);

// If we couldn't get an exact result, fall back to using a stub to make the exact function call
Expand Down
94 changes: 68 additions & 26 deletions src/coreclr/vm/methodtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6210,22 +6210,26 @@ MethodTable::FindDispatchImpl(

// Try exact match first
MethodDesc *pDefaultMethod = NULL;

FindDefaultInterfaceImplementationFlags flags = FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc;
if (throwOnConflict)
flags = flags | FindDefaultInterfaceImplementationFlags::ThrowOnConflict;

BOOL foundDefaultInterfaceImplementation = FindDefaultInterfaceImplementation(
pIfcMD, // the interface method being resolved
pIfcMT, // the interface being resolved
&pDefaultMethod,
FALSE, // allowVariance
throwOnConflict);
flags);

// If there's no exact match, try a variant match
if (!foundDefaultInterfaceImplementation && pIfcMT->HasVariance())
{
flags = flags | FindDefaultInterfaceImplementationFlags::AllowVariance;
foundDefaultInterfaceImplementation = FindDefaultInterfaceImplementation(
pIfcMD, // the interface method being resolved
pIfcMT, // the interface being resolved
&pDefaultMethod,
TRUE, // allowVariance
throwOnConflict);
flags);
}

if (foundDefaultInterfaceImplementation)
Expand Down Expand Up @@ -6324,10 +6328,13 @@ namespace
MethodTable *pMT,
MethodDesc *interfaceMD,
MethodTable *interfaceMT,
BOOL allowVariance,
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
MethodDesc **candidateMD,
ClassLoadLevel level)
{
bool allowVariance = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::AllowVariance) != FindDefaultInterfaceImplementationFlags::None;
bool instantiateMethodInstantiation = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc) != FindDefaultInterfaceImplementationFlags::None;

*candidateMD = NULL;

MethodDesc *candidateMaybe = NULL;
Expand Down Expand Up @@ -6418,11 +6425,20 @@ namespace
else
{
// Static virtual methods don't record MethodImpl slots so they need special handling
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags = ResolveVirtualStaticMethodFlags::None;
if (allowVariance)
{
resolveVirtualStaticMethodFlags |= ResolveVirtualStaticMethodFlags::AllowVariantMatches;
}
if (instantiateMethodInstantiation)
{
resolveVirtualStaticMethodFlags |= ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc;
}

candidateMaybe = pMT->TryResolveVirtualStaticMethodOnThisType(
interfaceMT,
interfaceMD,
/* verifyImplemented */ FALSE,
/* allowVariance */ allowVariance,
resolveVirtualStaticMethodFlags,
/* level */ level);
}
}
Expand Down Expand Up @@ -6461,8 +6477,7 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
MethodDesc *pInterfaceMD,
MethodTable *pInterfaceMT,
MethodDesc **ppDefaultMethod,
BOOL allowVariance,
BOOL throwOnConflict,
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
ClassLoadLevel level
)
{
Expand All @@ -6478,12 +6493,13 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
} CONTRACT_END;

#ifdef FEATURE_DEFAULT_INTERFACES
bool allowVariance = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::AllowVariance) != FindDefaultInterfaceImplementationFlags::None;
CQuickArray<MatchCandidate> candidates;
unsigned candidatesCount = 0;

// Check the current method table itself
MethodDesc *candidateMaybe = NULL;
if (IsInterface() && TryGetCandidateImplementation(this, pInterfaceMD, pInterfaceMT, allowVariance, &candidateMaybe, level))
if (IsInterface() && TryGetCandidateImplementation(this, pInterfaceMD, pInterfaceMT, findDefaultImplementationFlags, &candidateMaybe, level))
{
_ASSERTE(candidateMaybe != NULL);

Expand Down Expand Up @@ -6523,7 +6539,7 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
MethodTable *pCurMT = it.GetInterface(pMT, level);

MethodDesc *pCurMD = NULL;
if (TryGetCandidateImplementation(pCurMT, pInterfaceMD, pInterfaceMT, allowVariance, &pCurMD, level))
if (TryGetCandidateImplementation(pCurMT, pInterfaceMD, pInterfaceMT, findDefaultImplementationFlags, &pCurMD, level))
{
//
// Found a match. But is it a more specific match (we want most specific interfaces)
Expand Down Expand Up @@ -6619,6 +6635,8 @@ BOOL MethodTable::FindDefaultInterfaceImplementation(
}
else if (pBestCandidateMT != candidates[i].pMT)
{
bool throwOnConflict = (findDefaultImplementationFlags & FindDefaultInterfaceImplementationFlags::ThrowOnConflict) != FindDefaultInterfaceImplementationFlags::None;

if (throwOnConflict)
ThrowExceptionForConflictingOverride(this, pInterfaceMT, pInterfaceMD);

Expand Down Expand Up @@ -8875,12 +8893,15 @@ MethodDesc *
MethodTable::ResolveVirtualStaticMethod(
MethodTable* pInterfaceType,
MethodDesc* pInterfaceMD,
BOOL allowNullResult,
BOOL verifyImplemented,
BOOL allowVariantMatches,
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags,
BOOL* uniqueResolution,
ClassLoadLevel level)
{
bool verifyImplemented = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::VerifyImplemented) != ResolveVirtualStaticMethodFlags::None;
bool allowVariantMatches = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowVariantMatches) != ResolveVirtualStaticMethodFlags::None;
bool instantiateMethodParameters = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc) != ResolveVirtualStaticMethodFlags::None;
bool allowNullResult = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowNullResult) != ResolveVirtualStaticMethodFlags::None;

if (uniqueResolution != nullptr)
{
*uniqueResolution = TRUE;
Expand Down Expand Up @@ -8912,7 +8933,7 @@ MethodTable::ResolveVirtualStaticMethod(
// Search for match on a per-level in the type hierarchy
for (MethodTable* pMT = this; pMT != nullptr; pMT = pMT->GetParentMethodTable())
{
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
MethodDesc* pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pInterfaceType, pInterfaceMD, resolveVirtualStaticMethodFlags & ~ResolveVirtualStaticMethodFlags::AllowVariantMatches, level);
if (pMD != nullptr)
{
return pMD;
Expand Down Expand Up @@ -8956,7 +8977,7 @@ MethodTable::ResolveVirtualStaticMethod(
{
// Variant or equivalent matching interface found
// Attempt to resolve on variance matched interface
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, verifyImplemented, /*allowVariance*/ FALSE, level);
pMD = pMT->TryResolveVirtualStaticMethodOnThisType(pItfInMap, pInterfaceMD, resolveVirtualStaticMethodFlags & ~ResolveVirtualStaticMethodFlags::AllowVariantMatches, level);
if (pMD != nullptr)
{
return pMD;
Expand All @@ -8970,12 +8991,25 @@ MethodTable::ResolveVirtualStaticMethod(
BOOL allowVariantMatchInDefaultImplementationLookup = FALSE;
do
{
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags = FindDefaultInterfaceImplementationFlags::None;
if (allowVariantMatchInDefaultImplementationLookup)
{
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::AllowVariance;
}
if (uniqueResolution == nullptr)
{
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::ThrowOnConflict;
}
if (instantiateMethodParameters)
{
findDefaultImplementationFlags |= FindDefaultInterfaceImplementationFlags::InstantiateFoundMethodDesc;
}

BOOL haveUniqueDefaultImplementation = FindDefaultInterfaceImplementation(
pInterfaceMD,
pInterfaceType,
&pMDDefaultImpl,
/* allowVariance */ allowVariantMatchInDefaultImplementationLookup,
/* throwOnConflict */ uniqueResolution == nullptr,
findDefaultImplementationFlags,
level);
if (haveUniqueDefaultImplementation || (pMDDefaultImpl != nullptr && (verifyImplemented || uniqueResolution != nullptr)))
{
Expand Down Expand Up @@ -9018,8 +9052,12 @@ MethodTable::ResolveVirtualStaticMethod(
// Try to locate the appropriate MethodImpl matching a given interface static virtual method.
// Returns nullptr on failure.
MethodDesc*
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level)
MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level)
{
bool instantiateMethodParameters = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc) != ResolveVirtualStaticMethodFlags::None;
bool allowVariance = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::AllowVariantMatches) != ResolveVirtualStaticMethodFlags::None;
bool verifyImplemented = (resolveVirtualStaticMethodFlags & ResolveVirtualStaticMethodFlags::VerifyImplemented) != ResolveVirtualStaticMethodFlags::None;

HRESULT hr = S_OK;
IMDInternalImport* pMDInternalImport = GetMDImport();
HENUMInternalMethodImplHolder hEnumMethodImpl(pMDInternalImport);
Expand Down Expand Up @@ -9148,7 +9186,7 @@ MethodTable::TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType
COMPlusThrow(kTypeLoadException, E_FAIL);
}

if (!verifyImplemented)
if (!verifyImplemented && instantiateMethodParameters)
{
pMethodImpl = pMethodImpl->FindOrCreateAssociatedMethodDesc(
pMethodImpl,
Expand Down Expand Up @@ -9202,9 +9240,7 @@ MethodTable::VerifyThatAllVirtualStaticMethodsAreImplemented()
!ResolveVirtualStaticMethod(
pInterfaceMT,
pMD,
/* allowNullResult */ TRUE,
/* verifyImplemented */ TRUE,
/* allowVariantMatches */ FALSE,
ResolveVirtualStaticMethodFlags::AllowNullResult | ResolveVirtualStaticMethodFlags::VerifyImplemented,
/* uniqueResolution */ &uniqueResolution,
/* level */ CLASS_LOAD_EXACTPARENTS)))
{
Expand Down Expand Up @@ -9240,12 +9276,18 @@ MethodTable::TryResolveConstraintMethodApprox(
_ASSERTE(!thInterfaceType.IsTypeDesc());
_ASSERTE(thInterfaceType.IsInterface());
BOOL uniqueResolution;

ResolveVirtualStaticMethodFlags flags = ResolveVirtualStaticMethodFlags::AllowVariantMatches
| ResolveVirtualStaticMethodFlags::InstantiateResultOverFinalMethodDesc;
if (pfForceUseRuntimeLookup != NULL)
{
flags |= ResolveVirtualStaticMethodFlags::AllowNullResult;
}

MethodDesc *result = ResolveVirtualStaticMethod(
thInterfaceType.GetMethodTable(),
pInterfaceMD,
/* allowNullResult */pfForceUseRuntimeLookup != NULL,
/* verifyImplemented */ FALSE,
/* allowVariantMatches */ TRUE,
flags,
&uniqueResolution);
if (result == NULL || !uniqueResolution)
{
Expand Down
33 changes: 26 additions & 7 deletions src/coreclr/vm/methodtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "contractimpl.h"
#include "generics.h"
#include "gcinfotypes.h"
#include "enum_class_flags.h"

/*
* Forward Declarations
Expand Down Expand Up @@ -63,6 +64,28 @@ class ClassFactoryBase;
class ArgDestination;
enum class WellKnownAttribute : DWORD;

enum class ResolveVirtualStaticMethodFlags
{
None = 0,
AllowNullResult = 1,
VerifyImplemented = 2,
AllowVariantMatches = 4,
InstantiateResultOverFinalMethodDesc = 8,

support_use_as_flags // Enable the template functions in enum_class_flags.h
};


enum class FindDefaultInterfaceImplementationFlags
{
None,
AllowVariance = 1,
ThrowOnConflict = 2,
InstantiateFoundMethodDesc = 4,

support_use_as_flags // Enable the template functions in enum_class_flags.h
};

//============================================================================
// This is the in-memory structure of a class and it will evolve.
//============================================================================
Expand Down Expand Up @@ -2084,7 +2107,6 @@ class MethodTable
MethodDesc *GetMethodDescForComInterfaceMethod(MethodDesc *pItfMD, bool fNullOk);
#endif // FEATURE_COMINTEROP


// Resolve virtual static interface method pInterfaceMD on this type.
//
// Specify allowNullResult to return NULL instead of throwing if the there is no implementation
Expand All @@ -2096,9 +2118,7 @@ class MethodTable
MethodDesc *ResolveVirtualStaticMethod(
MethodTable* pInterfaceType,
MethodDesc* pInterfaceMD,
BOOL allowNullResult,
BOOL verifyImplemented = FALSE,
BOOL allowVariantMatches = TRUE,
ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags,
BOOL *uniqueResolution = NULL,
ClassLoadLevel level = CLASS_LOADED);

Expand Down Expand Up @@ -2178,8 +2198,7 @@ class MethodTable
MethodDesc *pInterfaceMD,
MethodTable *pObjectMT,
MethodDesc **ppDefaultMethod,
BOOL allowVariance,
BOOL throwOnConflict,
FindDefaultInterfaceImplementationFlags findDefaultImplementationFlags,
ClassLoadLevel level = CLASS_LOADED);
#endif // DACCESS_COMPILE

Expand Down Expand Up @@ -2219,7 +2238,7 @@ class MethodTable

// Try to resolve a given static virtual method override on this type. Return nullptr
// when not found.
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, BOOL verifyImplemented, BOOL allowVariance, ClassLoadLevel level);
MethodDesc *TryResolveVirtualStaticMethodOnThisType(MethodTable* pInterfaceType, MethodDesc* pInterfaceMD, ResolveVirtualStaticMethodFlags resolveVirtualStaticMethodFlags, ClassLoadLevel level);

public:
static MethodDesc *MapMethodDeclToMethodImpl(MethodDesc *pMDDecl);
Expand Down
5 changes: 2 additions & 3 deletions src/coreclr/vm/runtimehandles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1102,9 +1102,8 @@ extern "C" MethodDesc* QCALLTYPE RuntimeTypeHandle_GetInterfaceMethodImplementat
pResult = typeHandle.GetMethodTable()->ResolveVirtualStaticMethod(
thOwnerOfMD.GetMethodTable(),
pMD,
/* allowNullResult */ TRUE,
/* verifyImplemented*/ FALSE,
/* allowVariantMatches */ TRUE);
ResolveVirtualStaticMethodFlags::AllowNullResult |
ResolveVirtualStaticMethodFlags::AllowVariantMatches);
}
else
{
Expand Down
5 changes: 3 additions & 2 deletions src/coreclr/vm/typedesc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1611,8 +1611,9 @@ BOOL TypeVarTypeDesc::SatisfiesConstraints(SigTypeContext *pTypeContextOfConstra
if (pMD->IsVirtual() &&
pMD->IsStatic() &&
(pMD->IsAbstract() && !thElem.AsMethodTable()->ResolveVirtualStaticMethod(
pInterfaceMT, pMD, /* allowNullResult */ TRUE, /* verifyImplemented */ TRUE,
/*allowVariantMatches*/ TRUE, /*uniqueResolution*/ NULL, CLASS_DEPENDENCIES_LOADED)))
pInterfaceMT, pMD,
ResolveVirtualStaticMethodFlags::AllowNullResult | ResolveVirtualStaticMethodFlags::VerifyImplemented | ResolveVirtualStaticMethodFlags::AllowVariantMatches,
/*uniqueResolution*/ NULL, CLASS_DEPENDENCIES_LOADED)))
{
virtualStaticResolutionCheckFailed = true;
break;
Expand Down
Loading

0 comments on commit d1adf81

Please sign in to comment.