Skip to content

Commit

Permalink
Don't create a COM weak reference if the object is an aggregated COMW…
Browse files Browse the repository at this point in the history
…rappers RCW. (#61267)

* Don't create a COM weak reference if the object is an aggregated COMWrappers RCW.

* Add test for weak reference + aggregation with native weak reference impl.

* Apply suggestions from code review

Co-authored-by: Aaron Robinson <arobins@microsoft.com>

Co-authored-by: Aaron Robinson <arobins@microsoft.com>
  • Loading branch information
jkoritzinsky and AaronRobinsonMSFT authored Nov 6, 2021
1 parent dee4c0c commit aefb0fc
Show file tree
Hide file tree
Showing 5 changed files with 218 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/coreclr/vm/interoplibinterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ComWrappersNative
static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe);

public: // Unwrapping support
static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId);
static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated);
static bool HasManagedObjectComWrapper(_In_ OBJECTREF object, _Out_ bool* isActive);

public: // GC interaction
Expand Down
12 changes: 10 additions & 2 deletions src/coreclr/vm/interoplibinterface_comwrappers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ namespace
// The EOC is "detached" and no longer used to map between identity and a managed object.
// This will only be set if the EOC was inserted into the cache.
Flags_Detached = 8,

// This EOC is an aggregated instance
Flags_Aggregated = 16
};
DWORD Flags;

Expand Down Expand Up @@ -900,7 +903,11 @@ namespace
: ExternalObjectContext::Flags_None) |
(uniqueInstance
? ExternalObjectContext::Flags_None
: ExternalObjectContext::Flags_InCache);
: ExternalObjectContext::Flags_InCache) |
((flags & CreateObjectFlags::CreateObjectFlags_Aggregated) != 0
? ExternalObjectContext::Flags_Aggregated
: ExternalObjectContext::Flags_None);

ExternalObjectContext::Construct(
resultHolder.GetContext(),
identity,
Expand Down Expand Up @@ -1774,7 +1781,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
objRef);
}

IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId)
IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated)
{
CONTRACTL
{
Expand Down Expand Up @@ -1807,6 +1814,7 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE
{
ExternalObjectContext* context = reinterpret_cast<ExternalObjectContext*>(contextMaybe);
*wrapperId = context->WrapperId;
*isAggregated = context->IsSet(ExternalObjectContext::Flags_Aggregated);

IUnknown* identity = reinterpret_cast<IUnknown*>(context->Identity);
GCX_PREEMP();
Expand Down
16 changes: 12 additions & 4 deletions src/coreclr/vm/weakreferencenative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ struct WeakHandleSpinLockHolder
//
// In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must:
// * be an RCW
// * not be an aggregated RCW
// * respond to a QI for IWeakReferenceSource
// * succeed when asked for an IWeakReference*
//
Expand Down Expand Up @@ -149,7 +150,14 @@ NativeComWeakHandleInfo* GetComWeakReferenceInfo(OBJECTREF* pObject)
#endif
{
#ifdef FEATURE_COMWRAPPERS
pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId));
bool isAggregated = false;
pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId, &isAggregated));
if (isAggregated)
{
// If the RCW is an aggregated RCW, then the managed object cannot be recreated from the IUnknown as the outer IUnknown wraps the managed object.
// In this case, don't create a weak reference backed by a COM weak reference.
pWeakReferenceSource = nullptr;
}
#endif
}

Expand Down Expand Up @@ -448,7 +456,7 @@ FCIMPL3(void, WeakReferenceNative::Create, WeakReferenceObject * pThisUNSAFE, Ob
_ASSERTE(gc.pThis->GetMethodTable()->CanCastToClass(pWeakReferenceMT));

// Create the handle.
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
if (gc.pTarget != NULL)
{
Expand Down Expand Up @@ -690,7 +698,7 @@ FCIMPL1(Object *, WeakReferenceNative::GetTarget, WeakReferenceObject * pThisUNS

OBJECTREF pTarget = GetWeakReferenceTarget(pThis);

#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
// If we found an object, or we're not a native COM weak reference, then we're done. Othewrise
// we can try to create a new RCW to the underlying native COM object if it's still alive.
if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))
Expand Down Expand Up @@ -718,7 +726,7 @@ FCIMPL1(Object *, WeakReferenceOfTNative::GetTarget, WeakReferenceObject * pThis
OBJECTREF pTarget = GetWeakReferenceTarget(pThis);


#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
// If we found an object, or we're not a native COM weak reference, then we're done. Othewrise
// we can try to create a new RCW to the underlying native COM object if it's still alive.
if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,122 @@ namespace
return UnknownImpl::DoRelease();
}
};

struct WeakReferenceSource : public IWeakReferenceSource, public IInspectable
{
private:
IUnknown* _outerUnknown;
ComSmartPtr<WeakReference> _weakReference;
public:
WeakReferenceSource(IUnknown* outerUnknown)
:_outerUnknown(outerUnknown),
_weakReference(new WeakReference(this, 1))
{
}

STDMETHOD(GetWeakReference)(IWeakReference** ppWeakReference)
{
_weakReference->AddRef();
*ppWeakReference = _weakReference;
return S_OK;
}

STDMETHOD(QueryInterface)(
/* [in] */ REFIID riid,
/* [iid_is][out] */ void ** ppvObject)
{
if (riid == __uuidof(IWeakReferenceSource))
{
*ppvObject = static_cast<IWeakReferenceSource*>(this);
_weakReference->AddStrongRef();
return S_OK;
}
return _outerUnknown->QueryInterface(riid, ppvObject);
}
STDMETHOD_(ULONG, AddRef)(void)
{
return _weakReference->AddStrongRef();
}
STDMETHOD_(ULONG, Release)(void)
{
return _weakReference->ReleaseStrongRef();
}

STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
{
return E_NOTIMPL;
}

STDMETHOD(GetIids)(
ULONG *iidCount,
IID **iids)
{
return E_NOTIMPL;
}

STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
{
*trustLevel = FullTrust;
return S_OK;
}
};

struct AggregatedWeakReferenceSource : IInspectable
{
private:
IUnknown* _outerUnknown;
ComSmartPtr<WeakReferenceSource> _weakReference;
public:
AggregatedWeakReferenceSource(IUnknown* outerUnknown)
:_outerUnknown(outerUnknown),
_weakReference(new WeakReferenceSource(outerUnknown))
{
}

STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
{
return E_NOTIMPL;
}

STDMETHOD(GetIids)(
ULONG *iidCount,
IID **iids)
{
return E_NOTIMPL;
}

STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
{
*trustLevel = FullTrust;
return S_OK;
}

STDMETHOD(QueryInterface)(
/* [in] */ REFIID riid,
/* [iid_is][out] */ void ** ppvObject)
{
if (riid == __uuidof(IWeakReferenceSource))
{
return _weakReference->QueryInterface(riid, ppvObject);
}
return _outerUnknown->QueryInterface(riid, ppvObject);
}
STDMETHOD_(ULONG, AddRef)(void)
{
return _outerUnknown->AddRef();
}
STDMETHOD_(ULONG, Release)(void)
{
return _outerUnknown->Release();
}
};
}
extern "C" DLL_EXPORT WeakReferencableObject* STDMETHODCALLTYPE CreateWeakReferencableObject()
{
return new WeakReferencableObject();
}

extern "C" DLL_EXPORT AggregatedWeakReferenceSource* STDMETHODCALLTYPE CreateAggregatedWeakReferenceObject(IUnknown* pOuter)
{
return new AggregatedWeakReferenceSource(pOuter);
}
Loading

0 comments on commit aefb0fc

Please sign in to comment.