Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Evict program cache on PI_ERROR_OUT_OF_RESOURCES #11987

Merged
merged 8 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ std::optional<sycl::detail::pi::PiProgram> context_impl::getProgramForDevImgs(
const device &Device, const std::set<std::uintptr_t> &ImgIdentifiers,
const std::string &ObjectTypeName) {

KernelProgramCache::ProgramWithBuildStateT *BuildRes = nullptr;
KernelProgramCache::ProgramBuildResultPtr BuildRes = nullptr;
{
auto LockedCache = MKernelProgramCache.acquireCachedPrograms();
auto &KeyMap = LockedCache.get().KeyMap;
Expand All @@ -471,12 +471,18 @@ std::optional<sycl::detail::pi::PiProgram> context_impl::getProgramForDevImgs(
assert(KeyMappingsIt != KeyMap.end());
auto CachedProgIt = Cache.find(KeyMappingsIt->second);
assert(CachedProgIt != Cache.end());
BuildRes = &CachedProgIt->second;
BuildRes = CachedProgIt->second;
}
}
if (!BuildRes)
return std::nullopt;
return *MKernelProgramCache.waitUntilBuilt<compile_program_error>(BuildRes);
using BuildState = KernelProgramCache::BuildState;
BuildState NewState = BuildRes->waitUntilTransition();
if (NewState == BuildState::BS_Failed)
throw compile_program_error(BuildRes->Error.Msg, BuildRes->Error.Code);

assert(NewState == BuildState::BS_Done);
return BuildRes->Val;
}

std::optional<sycl::detail::pi::PiProgram>
Expand Down
28 changes: 2 additions & 26 deletions sycl/source/detail/kernel_program_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,8 @@
namespace sycl {
inline namespace _V1 {
namespace detail {
KernelProgramCache::~KernelProgramCache() {
for (auto &ProgIt : MCachedPrograms.Cache) {
ProgramWithBuildStateT &ProgWithState = ProgIt.second;
sycl::detail::pi::PiProgram *ToBeDeleted = ProgWithState.Ptr.load();

if (!ToBeDeleted)
continue;

auto KernIt = MKernelsPerProgramCache.find(*ToBeDeleted);

if (KernIt != MKernelsPerProgramCache.end()) {
for (auto &p : KernIt->second) {
BuildResult<KernelArgMaskPairT> &KernelWithState = p.second;
KernelArgMaskPairT *KernelArgMaskPair = KernelWithState.Ptr.load();

if (KernelArgMaskPair) {
const PluginPtr &Plugin = MParentContext->getPlugin();
Plugin->call<PiApiKind::piKernelRelease>(KernelArgMaskPair->first);
}
}
MKernelsPerProgramCache.erase(KernIt);
}

const PluginPtr &Plugin = MParentContext->getPlugin();
Plugin->call<PiApiKind::piProgramRelease>(*ToBeDeleted);
}
const PluginPtr &KernelProgramCache::getPlugin() {
return MParentContext->getPlugin();
}
} // namespace detail
} // namespace _V1
Expand Down
195 changes: 144 additions & 51 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,16 @@ class KernelProgramCache {
};

/// Denotes the state of a build.
enum BuildState { BS_InProgress, BS_Done, BS_Failed };
enum class BuildState { BS_Initial, BS_InProgress, BS_Done, BS_Failed };

/// Denotes pointer to some entity with its general state and build error.
/// The pointer is not null if and only if the entity is usable.
/// State of the entity is provided by the user of cache instance.
/// Currently there is only a single user - ProgramManager class.
template <typename T> struct BuildResult {
std::atomic<T *> Ptr;
T Val;
std::atomic<BuildState> State;
BuildError Error;
std::atomic<BuildState> State{BuildState::BS_Initial};
BuildError Error{"", 0};

/// Condition variable to signal that build result is ready.
/// A per-object (i.e. kernel or program) condition variable is employed
Expand All @@ -69,10 +68,38 @@ class KernelProgramCache {
/// A mutex to be employed along with MBuildCV.
std::mutex MBuildResultMutex;

BuildResult(T *P, BuildState S) : Ptr{P}, State{S}, Error{"", 0} {}
BuildState
waitUntilTransition(BuildState From = BuildState::BS_InProgress) {
BuildState To;
std::unique_lock Lock(MBuildResultMutex);
MBuildCV.wait(Lock, [&] {
To = State;
return State != From;
});
return To;
}

void updateAndNotify(BuildState DesiredState) {
{
std::lock_guard<std::mutex> Lock(MBuildResultMutex);
State.store(DesiredState);
}
MBuildCV.notify_all();
}
};

struct ProgramBuildResult : public BuildResult<sycl::detail::pi::PiProgram> {
PluginPtr Plugin;
ProgramBuildResult(const PluginPtr &Plugin) : Plugin(Plugin) {
Val = nullptr;
}
~ProgramBuildResult() {
if (Val)
Plugin->call<PiApiKind::piProgramRelease>(Val);
}
};
using ProgramBuildResultPtr = std::shared_ptr<ProgramBuildResult>;

using ProgramWithBuildStateT = BuildResult<sycl::detail::pi::PiProgram>;
/* Drop LinkOptions and CompileOptions from CacheKey since they are only used
* when debugging environment variables are set and we can just ignore them
* since all kernels will have their build options overridden with the same
Expand All @@ -83,7 +110,7 @@ class KernelProgramCache {
std::pair<std::uintptr_t, sycl::detail::pi::PiDevice>;

struct ProgramCache {
::boost::unordered_map<ProgramCacheKeyT, ProgramWithBuildStateT> Cache;
::boost::unordered_map<ProgramCacheKeyT, ProgramBuildResultPtr> Cache;
::boost::unordered_multimap<CommonProgramKeyT, ProgramCacheKeyT> KeyMap;

size_t size() const noexcept { return Cache.size(); }
Expand All @@ -93,8 +120,20 @@ class KernelProgramCache {

using KernelArgMaskPairT =
std::pair<sycl::detail::pi::PiKernel, const KernelArgMask *>;
struct KernelBuildResult : public BuildResult<KernelArgMaskPairT> {
PluginPtr Plugin;
KernelBuildResult(const PluginPtr &Plugin) : Plugin(Plugin) {
Val.first = nullptr;
}
~KernelBuildResult() {
if (Val.first)
Plugin->call<PiApiKind::piKernelRelease>(Val.first);
}
};
using KernelBuildResultPtr = std::shared_ptr<KernelBuildResult>;

using KernelByNameT =
::boost::unordered_map<std::string, BuildResult<KernelArgMaskPairT>>;
::boost::unordered_map<std::string, KernelBuildResultPtr>;
using KernelCacheT =
::boost::unordered_map<sycl::detail::pi::PiProgram, KernelByNameT>;

Expand All @@ -112,7 +151,7 @@ class KernelProgramCache {
using KernelFastCacheT =
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;

~KernelProgramCache();
~KernelProgramCache() = default;

void setContextPtr(const ContextPtr &AContext) { MParentContext = AContext; }

Expand All @@ -124,61 +163,30 @@ class KernelProgramCache {
return {MKernelsPerProgramCache, MKernelsPerProgramCacheMutex};
}

std::pair<ProgramWithBuildStateT *, bool>
std::pair<ProgramBuildResultPtr, bool>
getOrInsertProgram(const ProgramCacheKeyT &CacheKey) {
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
auto Inserted = ProgCache.Cache.emplace(
std::piecewise_construct, std::forward_as_tuple(CacheKey),
std::forward_as_tuple(nullptr, BS_InProgress));
if (Inserted.second) {
auto [It, DidInsert] = ProgCache.Cache.try_emplace(CacheKey, nullptr);
if (DidInsert) {
It->second = std::make_shared<ProgramBuildResult>(getPlugin());
// Save reference between the common key and the full key.
CommonProgramKeyT CommonKey =
std::make_pair(CacheKey.first.second, CacheKey.second);
ProgCache.KeyMap.emplace(std::piecewise_construct,
std::forward_as_tuple(CommonKey),
std::forward_as_tuple(CacheKey));
ProgCache.KeyMap.emplace(CommonKey, CacheKey);
}
return std::make_pair(&Inserted.first->second, Inserted.second);
return std::make_pair(It->second, DidInsert);
}

std::pair<BuildResult<KernelArgMaskPairT> *, bool>
std::pair<KernelBuildResultPtr, bool>
getOrInsertKernel(sycl::detail::pi::PiProgram Program,
const std::string &KernelName) {
auto LockedCache = acquireKernelsPerProgramCache();
auto &Cache = LockedCache.get()[Program];
auto Inserted = Cache.emplace(
std::piecewise_construct, std::forward_as_tuple(KernelName),
std::forward_as_tuple(nullptr, BS_InProgress));
return std::make_pair(&Inserted.first->second, Inserted.second);
}

template <typename T, class Predicate>
void waitUntilBuilt(BuildResult<T> &BR, Predicate Pred) const {
std::unique_lock<std::mutex> Lock(BR.MBuildResultMutex);

BR.MBuildCV.wait(Lock, Pred);
}

template <typename ExceptionT, typename RetT>
RetT *waitUntilBuilt(BuildResult<RetT> *BuildResult) {
// Any thread which will find nullptr in cache will wait until the pointer
// is not null anymore.
waitUntilBuilt(*BuildResult, [BuildResult]() {
int State = BuildResult->State.load();
return State == BuildState::BS_Done || State == BuildState::BS_Failed;
});

if (BuildResult->Error.isFilledIn()) {
const BuildError &Error = BuildResult->Error;
throw ExceptionT(Error.Msg, Error.Code);
}

return BuildResult->Ptr.load();
}

template <typename T> void notifyAllBuild(BuildResult<T> &BR) const {
BR.MBuildCV.notify_all();
auto [It, DidInsert] = Cache.try_emplace(KernelName, nullptr);
if (DidInsert)
It->second = std::make_shared<KernelBuildResult>(getPlugin());
return std::make_pair(It->second, DidInsert);
}

template <typename KeyT>
Expand All @@ -203,11 +211,94 @@ class KernelProgramCache {
///
/// This member function should only be used in unit tests.
void reset() {
std::lock_guard<std::mutex> L1(MProgramCacheMutex);
std::lock_guard<std::mutex> L2(MKernelsPerProgramCacheMutex);
std::lock_guard<std::mutex> L3(MKernelFastCacheMutex);
MCachedPrograms = ProgramCache{};
MKernelsPerProgramCache = KernelCacheT{};
MKernelFastCache = KernelFastCacheT{};
}

/// Try to fetch entity (kernel or program) from cache. If there is no such
/// entity try to build it. Throw any exception build process may throw.
/// This method eliminates unwanted builds by employing atomic variable with
/// build state and waiting until the entity is built in another thread.
/// If the building thread has failed the awaiting thread will fail either.
/// Exception thrown by build procedure are rethrown.
///
/// \tparam RetT type of entity to get
/// \tparam ExceptionT type of exception to throw on awaiting thread if the
/// building thread fails build step.
/// \tparam KeyT key (in cache) to fetch built entity with
/// \tparam AcquireFT type of function which will acquire the locked version
/// of
/// the cache. Accept reference to KernelProgramCache.
/// \tparam GetCacheFT type of function which will fetch proper cache from
/// locked version. Accepts reference to locked version of cache.
/// \tparam BuildFT type of function which will build the entity if it is not
/// in
/// cache. Accepts nothing. Return pointer to built entity.
///
/// \return a pointer to cached build result, return value must not be
/// nullptr.
template <typename ExceptionT, typename GetCachedBuildFT, typename BuildFT>
auto getOrBuild(GetCachedBuildFT &&GetCachedBuild, BuildFT &&Build) {
using BuildState = KernelProgramCache::BuildState;
constexpr size_t MaxAttempts = 2;
for (size_t AttemptCounter = 0;; ++AttemptCounter) {
auto Res = GetCachedBuild();
auto &BuildResult = Res.first;
BuildState Expected = BuildState::BS_Initial;
BuildState Desired = BuildState::BS_InProgress;
if (!BuildResult->State.compare_exchange_strong(Expected, Desired)) {
// no insertion took place, thus some other thread has already inserted
// smth in the cache
BuildState NewState = BuildResult->waitUntilTransition();

// Build succeeded.
if (NewState == BuildState::BS_Done)
return BuildResult;

// Build failed, or this is the last attempt.
if (NewState == BuildState::BS_Failed ||
AttemptCounter + 1 == MaxAttempts) {
if (BuildResult->Error.isFilledIn())
throw ExceptionT(BuildResult->Error.Msg, BuildResult->Error.Code);
else
throw exception();
}

// NewState == BuildState::BS_Initial
// Build state was set back to the initial state,
// which means to go back to the beginning of the
// loop and try again.
continue;
}

// only the building thread will run this
try {
BuildResult->Val = Build();

BuildResult->updateAndNotify(BuildState::BS_Done);
return BuildResult;
} catch (const exception &Ex) {
BuildResult->Error.Msg = Ex.what();
BuildResult->Error.Code = Ex.get_cl_code();
if (BuildResult->Error.Code == PI_ERROR_OUT_OF_RESOURCES) {
reset();
BuildResult->updateAndNotify(BuildState::BS_Initial);
continue;
}

BuildResult->updateAndNotify(BuildState::BS_Failed);
std::rethrow_exception(std::current_exception());
} catch (...) {
BuildResult->updateAndNotify(BuildState::BS_Initial);
std::rethrow_exception(std::current_exception());
}
}
}

private:
std::mutex MProgramCacheMutex;
std::mutex MKernelsPerProgramCacheMutex;
Expand All @@ -219,6 +310,8 @@ class KernelProgramCache {
std::mutex MKernelFastCacheMutex;
KernelFastCacheT MKernelFastCache;
friend class ::MockKernelProgramCache;

const PluginPtr &getPlugin();
};
} // namespace detail
} // namespace _V1
Expand Down
Loading