-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update CUDNN Frontend API to v0.9.1 #54949
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
#include <map> | ||
#include <mutex> | ||
#include <string> | ||
#include <thread> | ||
#include <vector> | ||
|
||
#include "paddle/phi/backends/dynload/cudnn_frontend.h" | ||
|
@@ -34,7 +35,13 @@ class CudnnFrontendPlanCache { | |
saturation_count_ = FLAGS_cudnn_cache_saturation_count; | ||
} | ||
|
||
int64_t Size() const { return map_.size(); } | ||
int64_t Size() const { | ||
int64_t total_size = 0; | ||
for (auto it = map_.begin(); it != map_.end(); it++) { | ||
total_size += (it->second).size(); | ||
} | ||
return total_size; | ||
} | ||
|
||
int64_t CacheHits() const { return cache_hits_; } | ||
|
||
|
@@ -58,11 +65,12 @@ class CudnnFrontendPlanCache { | |
cache_misses_ = 0; | ||
} | ||
|
||
bool FindPlan(const cudnn_frontend::OperationGraph& op_graph, | ||
bool use_addto = false) { | ||
bool FindPlan(const cudnn_frontend::feature_vector_t &feature, | ||
cudnnHandle_t handle) { | ||
bool ret = false; | ||
std::lock_guard<std::mutex> lock(*cache_mutex_); | ||
if (map_.count(MakeKey(op_graph, use_addto)) > 0) { | ||
auto &local_map = map_[hasher(std::this_thread::get_id())]; | ||
if (local_map.count(GetExtendedFeature(feature, handle)) > 0) { | ||
cache_hits_++; | ||
ret = true; | ||
} else { | ||
|
@@ -71,58 +79,98 @@ class CudnnFrontendPlanCache { | |
return ret; | ||
} | ||
|
||
cudnn_frontend::ManagedOpaqueDescriptor GetConfig( | ||
const cudnn_frontend::OperationGraph& op_graph, | ||
cudnnHandle_t handle, | ||
bool use_addto = false) { | ||
void GetPlan(const cudnn_frontend::feature_vector_t &feature, | ||
const cudnn_frontend::ExecutionPlan **plan, | ||
int64_t *workspace_size, | ||
cudnnHandle_t handle) { | ||
// Note(tizheng): CUDNNv8 execution plan is not thread-safe. | ||
// A shared plan being executed by different threads is | ||
// generally not safe (for now). | ||
std::lock_guard<std::mutex> lock(*cache_mutex_); | ||
auto engine_config = map_[MakeKey(op_graph, use_addto)]; | ||
return engine_config; | ||
auto &local_map = map_[hasher(std::this_thread::get_id())]; | ||
|
||
auto it = local_map.find(GetExtendedFeature(feature, handle)); | ||
if (it == local_map.end()) { | ||
PADDLE_THROW(phi::errors::InvalidArgument( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以直接用 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 下个PR修改 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed at #55026 |
||
"[cudnn_frontend] Cached Plan Not Found.")); | ||
return; | ||
} | ||
*plan = &(it->second); | ||
*workspace_size = (*plan)->getWorkspaceSize(); | ||
VLOG(4) << "Cached execution plan found." << (*plan)->getTag() | ||
<< "; Require workspace: " << *workspace_size; | ||
} | ||
|
||
void InsertPlan(const cudnn_frontend::OperationGraph& op_graph, | ||
const cudnn_frontend::ExecutionPlan& plan, | ||
bool use_addto = false) { | ||
VLOG(4) << "[cudnn_frontend] cache: Insert graph tag: " | ||
<< op_graph.getTag(); | ||
void InsertPlan(const cudnn_frontend::feature_vector_t &feature, | ||
const cudnn_frontend::ExecutionPlan &plan, | ||
cudnnHandle_t handle) { | ||
VLOG(4) << "[cudnn_frontend] cache: Insert plan: " << plan.getTag(); | ||
std::lock_guard<std::mutex> lock(*cache_mutex_); | ||
map_.insert( | ||
std::make_pair(MakeKey(op_graph, use_addto), plan.GetEngineConfig())); | ||
auto &local_map = map_[hasher(std::this_thread::get_id())]; | ||
local_map.insert(std::make_pair(GetExtendedFeature(feature, handle), plan)); | ||
} | ||
|
||
bool IsStable(const cudnn_frontend::OperationGraph& op_graph, | ||
const std::string& tag, | ||
bool use_addto = false) { | ||
bool IsStable(const cudnn_frontend::feature_vector_t &feature, | ||
const std::string &tag, | ||
cudnnHandle_t handle) { | ||
if (saturation_count_ == 1) { | ||
return true; | ||
} | ||
std::lock_guard<std::mutex> lock(*cache_mutex_); | ||
if (map_.count(MakeKey(op_graph, use_addto))) { | ||
auto &local_map = map_[hasher(std::this_thread::get_id())]; | ||
auto &local_tracker = tracker_[hasher(std::this_thread::get_id())]; | ||
auto ext_feature = GetExtendedFeature(feature, handle); | ||
if (local_map.count(ext_feature)) { | ||
return false; | ||
} | ||
int cnt = tracker_[std::make_pair(MakeKey(op_graph, use_addto), tag)] += 1; | ||
VLOG(4) << "[cudnn_frontend] SaturationTracker: " << op_graph.getTag() | ||
<< " " << tag << " " << cnt; | ||
int cnt = local_tracker[std::make_pair(ext_feature, tag)] += 1; | ||
VLOG(4) << "[cudnn_frontend] SaturationTracker: " << tag << " " << cnt; | ||
return cnt >= saturation_count_; | ||
} | ||
|
||
bool FindPlan(const cudnn_frontend::OperationGraph &op_graph, | ||
cudnnHandle_t handle) { | ||
return FindPlan(op_graph.getFeatureVector(), handle); | ||
} | ||
|
||
void GetPlan(const cudnn_frontend::OperationGraph &op_graph, | ||
const cudnn_frontend::ExecutionPlan **plan, | ||
int64_t *workspace_size, | ||
cudnnHandle_t handle) { | ||
GetPlan(op_graph.getFeatureVector(), plan, workspace_size, handle); | ||
} | ||
|
||
void InsertPlan(const cudnn_frontend::OperationGraph &op_graph, | ||
const cudnn_frontend::ExecutionPlan &plan, | ||
cudnnHandle_t handle) { | ||
InsertPlan(op_graph.getFeatureVector(), plan, handle); | ||
} | ||
|
||
bool IsStable(const cudnn_frontend::OperationGraph &op_graph, | ||
const std::string &tag, | ||
cudnnHandle_t handle) { | ||
return IsStable(op_graph.getFeatureVector(), tag, handle); | ||
} | ||
|
||
private: | ||
static cudnn_frontend::feature_vector_t MakeKey( | ||
const cudnn_frontend::OperationGraph& op_graph, bool use_addto) { | ||
Tom-Zheng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auto key = op_graph.getFeatureVector(); | ||
key.push_back(static_cast<uint64_t>(use_addto)); | ||
return key; | ||
cudnn_frontend::feature_vector_t GetExtendedFeature( | ||
cudnn_frontend::feature_vector_t feat, cudnnHandle_t handle) { | ||
int64_t val = 0; | ||
memcpy(&val, &handle, sizeof(int64_t)); | ||
feat.push_back(val); | ||
return feat; | ||
} | ||
using FeatureVectorToPlanMap = | ||
std::map<cudnn_frontend::feature_vector_t, cudnn_frontend::ExecutionPlan>; | ||
std::map<std::size_t, FeatureVectorToPlanMap> map_; | ||
std::hash<std::thread::id> hasher; | ||
|
||
std::map<cudnn_frontend::feature_vector_t, | ||
cudnn_frontend::ManagedOpaqueDescriptor> | ||
map_; | ||
std::shared_ptr<std::mutex> cache_mutex_; | ||
int saturation_count_; | ||
|
||
using SaturationTracker = | ||
std::map<std::pair<cudnn_frontend::feature_vector_t, std::string>, int>; | ||
SaturationTracker tracker_; | ||
std::map<std::size_t, SaturationTracker> tracker_; | ||
|
||
int64_t cache_hits_{0}; | ||
int64_t cache_misses_{0}; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GetPlan
->GetPlanAndWorkspaceSize
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下个PR修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed at #55026