diff --git a/mlir/include/mlir/Support/ThreadLocalCache.h b/mlir/include/mlir/Support/ThreadLocalCache.h index 1be94ca14bcfac..87cc52cc56ac4f 100644 --- a/mlir/include/mlir/Support/ThreadLocalCache.h +++ b/mlir/include/mlir/Support/ThreadLocalCache.h @@ -25,28 +25,80 @@ namespace mlir { /// cache has very large lock contention. template class ThreadLocalCache { + struct PerInstanceState; + + /// The "observer" is owned by a thread-local cache instance. It is + /// constructed the first time a `ThreadLocalCache` instance is accessed by a + /// thread, unless `perInstanceState` happens to get re-allocated to the same + /// address as a previous one. A `thread_local` instance of this class is + /// destructed when the thread in which it lives is destroyed. + /// + /// This class is called the "observer" because while values cached in + /// thread-local caches are owned by `PerInstanceState`, a reference is stored + /// via this class in the TLC. With a double pointer, it knows when the + /// referenced value has been destroyed. + struct Observer { + /// This is the double pointer, explicitly allocated because we need to keep + /// the address stable if the TLC map re-allocates. It is owned by the + /// observer and shared with the value owner. + std::shared_ptr ptr = std::make_shared(nullptr); + /// Because the `Owner` instance that lives inside `PerInstanceState` + /// contains a reference to the double pointer, and likewise this class + /// contains a reference to the value, we need to synchronize destruction of + /// the TLC and the `PerInstanceState` to avoid racing. This weak pointer is + /// acquired during TLC destruction if the `PerInstanceState` hasn't entered + /// its destructor yet, and prevents it from happening. + std::weak_ptr keepalive; + }; + + /// This struct owns the cache entries. It contains a reference back to the + /// reference inside the cache so that it can be written to null to indicate + /// that the cache entry is invalidated. It needs to do this because + /// `perInstanceState` could get re-allocated to the same pointer and we don't + /// remove entries from the TLC when it is deallocated. Thus, we have to reset + /// the TLC entries to a starting state in case the `ThreadLocalCache` lives + /// shorter than the threads. + struct Owner { + /// Save a pointer to the reference and write it to the newly created entry. + Owner(Observer &observer) + : value(std::make_unique()), ptrRef(observer.ptr) { + *observer.ptr = value.get(); + } + ~Owner() { + if (std::shared_ptr ptr = ptrRef.lock()) + *ptr = nullptr; + } + + Owner(Owner &&) = default; + Owner &operator=(Owner &&) = default; + + std::unique_ptr value; + std::weak_ptr ptrRef; + }; + // Keep a separate shared_ptr protected state that can be acquired atomically // instead of using shared_ptr's for each value. This avoids a problem // where the instance shared_ptr is locked() successfully, and then the // ThreadLocalCache gets destroyed before remove() can be called successfully. struct PerInstanceState { - /// Remove the given value entry. This is generally called when a thread - /// local cache is destructing. + /// Remove the given value entry. This is called when a thread local cache + /// is destructing but still contains references to values owned by the + /// `PerInstanceState`. Removal is required because it prevents writeback to + /// a pointer that was deallocated. void remove(ValueT *value) { // Erase the found value directly, because it is guaranteed to be in the // list. llvm::sys::SmartScopedLock threadInstanceLock(instanceMutex); - auto it = - llvm::find_if(instances, [&](std::unique_ptr &instance) { - return instance.get() == value; - }); + auto it = llvm::find_if(instances, [&](Owner &instance) { + return instance.value.get() == value; + }); assert(it != instances.end() && "expected value to exist in cache"); instances.erase(it); } /// Owning pointers to all of the values that have been constructed for this /// object in the static cache. - SmallVector, 1> instances; + SmallVector instances; /// A mutex used when a new thread instance has been added to the cache for /// this object. @@ -57,13 +109,14 @@ class ThreadLocalCache { /// instance of the non-static cache and a weak reference to an instance of /// ValueT. We use a weak reference here so that the object can be destroyed /// without needing to lock access to the cache itself. - struct CacheType - : public llvm::SmallDenseMap> { + struct CacheType : public llvm::SmallDenseMap { ~CacheType() { - // Remove the values of this cache that haven't already expired. - for (auto &it : *this) - if (std::shared_ptr value = it.second.lock()) - it.first->remove(value.get()); + // Remove the values of this cache that haven't already expired. This is + // required because if we don't remove them, they will contain a reference + // back to the data here that is being destroyed. + for (auto &[instance, observer] : *this) + if (std::shared_ptr state = observer.keepalive.lock()) + state->remove(*observer.ptr); } /// Clear out any unused entries within the map. This method is not @@ -71,7 +124,7 @@ class ThreadLocalCache { void clearExpiredEntries() { for (auto it = this->begin(), e = this->end(); it != e;) { auto curIt = it++; - if (curIt->second.expired()) + if (!*curIt->second.ptr) this->erase(curIt); } } @@ -88,22 +141,23 @@ class ThreadLocalCache { ValueT &get() { // Check for an already existing instance for this thread. CacheType &staticCache = getStaticCache(); - std::weak_ptr &threadInstance = staticCache[perInstanceState.get()]; - if (std::shared_ptr value = threadInstance.lock()) + Observer &threadInstance = staticCache[perInstanceState.get()]; + if (ValueT *value = *threadInstance.ptr) return *value; // Otherwise, create a new instance for this thread. - llvm::sys::SmartScopedLock threadInstanceLock( - perInstanceState->instanceMutex); - perInstanceState->instances.push_back(std::make_unique()); - ValueT *instance = perInstanceState->instances.back().get(); - threadInstance = std::shared_ptr(perInstanceState, instance); + { + llvm::sys::SmartScopedLock threadInstanceLock( + perInstanceState->instanceMutex); + perInstanceState->instances.emplace_back(threadInstance); + } + threadInstance.keepalive = perInstanceState; // Before returning the new instance, take the chance to clear out any used // entries in the static map. The cache is only cleared within the same // thread to remove the need to lock the cache itself. staticCache.clearExpiredEntries(); - return *instance; + return **threadInstance.ptr; } ValueT &operator*() { return get(); } ValueT *operator->() { return &get(); }