Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update CUDNN Frontend API to v0.9.1 #54949

Merged
merged 5 commits on Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions cmake/external/cudnn-frontend.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,24 @@ endif()

if((NOT DEFINED CUDNN_FRONTEND_NAME) OR (NOT DEFINED CUDNN_FRONTEND_URL))
set(CUDNN_FRONTEND_VER
"1.23.2"
"v0.9.1"
CACHE STRING "" FORCE)
set(CUDNN_FRONTEND_NAME
"cudnn-frontend"
CACHE STRING "" FORCE)
set(CUDNN_FRONTEND_URL
"https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v0.7.1.tar.gz"
"https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/${CUDNN_FRONTEND_VER}.tar.gz"
CACHE STRING "" FORCE)
set(CUDNN_FRONTEND_CACHE_FILENAME "v0.7.1.tar.gz")
endif()
set(CUDNN_FRONTEND_URL_MD5 "d8f911df571f8b0d40226efa9c0150c8")
set(CUDNN_FRONTEND_CACHE_FILENAME "${CUDNN_FRONTEND_VER}.tar.gz")
set(CUDNN_FRONTEND_URL_MD5 "da7cbad1305427f687dd4fd737178f80")

message(
STATUS
"CUDNN_FRONTEND_NAME: ${CUDNN_FRONTEND_NAME}, CUDNN_FRONTEND_URL: ${CUDNN_FRONTEND_URL}"
)
set(DIRENT_DOWNLOAD_DIR "${PADDLE_SOURCE_DIR}/third_party/cudnn-frontend")
# Version: v0.7.1
set(CUDNN_FRONTEND_DOWNLOAD_DIR
"${PADDLE_SOURCE_DIR}/third_party/cudnn-frontend")
set(CUDNN_FRONTEND_PREFIX_DIR ${THIRD_PARTY_PATH}/cudnn-frontend)
set(CUDNN_FRONTEND_SOURCE_DIR
${THIRD_PARTY_PATH}/cudnn-frontend/src/extern_cudnn_frontend/include)
Expand All @@ -55,7 +55,7 @@ include_directories(${CUDNN_FRONTEND_INCLUDE_DIR})

message(
STATUS
"Adding cudnn-frontend. Version: ${CUDNN_FRONTEND_VER}. Directory: ${DIRENT_DOWNLOAD_DIR}"
"Adding cudnn-frontend. Version: ${CUDNN_FRONTEND_VER}. Directory: ${CUDNN_FRONTEND_DOWNLOAD_DIR}"
)

function(download_cudnn_frontend)
Expand Down Expand Up @@ -99,9 +99,7 @@ ExternalProject_Add(
DOWNLOAD_DIR ${CUDNN_FRONTEND_DOWNLOAD_DIR}
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
PATCH_COMMAND
patch -d ${CUDNN_FRONTEND_SOURCE_DIR} -p2 <
${PADDLE_SOURCE_DIR}/patches/cudnn-frontend/0001-patch-for-paddle.patch
PATCH_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
Expand Down
114 changes: 81 additions & 33 deletions paddle/phi/kernels/autotune/cache_cudnn_frontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include "paddle/phi/backends/dynload/cudnn_frontend.h"
Expand All @@ -34,7 +35,13 @@ class CudnnFrontendPlanCache {
saturation_count_ = FLAGS_cudnn_cache_saturation_count;
}

int64_t Size() const { return map_.size(); }
int64_t Size() const {
int64_t total_size = 0;
for (auto it = map_.begin(); it != map_.end(); it++) {
total_size += (it->second).size();
}
return total_size;
}

int64_t CacheHits() const { return cache_hits_; }

Expand All @@ -58,11 +65,12 @@ class CudnnFrontendPlanCache {
cache_misses_ = 0;
}

bool FindPlan(const cudnn_frontend::OperationGraph& op_graph,
bool use_addto = false) {
bool FindPlan(const cudnn_frontend::feature_vector_t &feature,
cudnnHandle_t handle) {
bool ret = false;
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (map_.count(MakeKey(op_graph, use_addto)) > 0) {
auto &local_map = map_[hasher(std::this_thread::get_id())];
if (local_map.count(GetExtendedFeature(feature, handle)) > 0) {
cache_hits_++;
ret = true;
} else {
Expand All @@ -71,58 +79,98 @@ class CudnnFrontendPlanCache {
return ret;
}

cudnn_frontend::ManagedOpaqueDescriptor GetConfig(
const cudnn_frontend::OperationGraph& op_graph,
cudnnHandle_t handle,
bool use_addto = false) {
void GetPlan(const cudnn_frontend::feature_vector_t &feature,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetPlan -> GetPlanAndWorkspaceSize

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下个PR修改 (Will fix in the next PR)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed at #55026

const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
// Note(tizheng): CUDNNv8 execution plan is not thread-safe.
// A shared plan being executed by different threads is
// generally not safe (for now).
std::lock_guard<std::mutex> lock(*cache_mutex_);
auto engine_config = map_[MakeKey(op_graph, use_addto)];
return engine_config;
auto &local_map = map_[hasher(std::this_thread::get_id())];

auto it = local_map.find(GetExtendedFeature(feature, handle));
if (it == local_map.end()) {
PADDLE_THROW(phi::errors::InvalidArgument(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以直接用PADDLE_ENFORCE_NE(it, local_map.end(), ...) (You can use PADDLE_ENFORCE_NE(it, local_map.end(), ...) directly here.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下个PR修改 (Will fix in the next PR)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed at #55026

"[cudnn_frontend] Cached Plan Not Found."));
return;
}
*plan = &(it->second);
*workspace_size = (*plan)->getWorkspaceSize();
VLOG(4) << "Cached execution plan found." << (*plan)->getTag()
<< "; Require workspace: " << *workspace_size;
}

void InsertPlan(const cudnn_frontend::OperationGraph& op_graph,
const cudnn_frontend::ExecutionPlan& plan,
bool use_addto = false) {
VLOG(4) << "[cudnn_frontend] cache: Insert graph tag: "
<< op_graph.getTag();
void InsertPlan(const cudnn_frontend::feature_vector_t &feature,
const cudnn_frontend::ExecutionPlan &plan,
cudnnHandle_t handle) {
VLOG(4) << "[cudnn_frontend] cache: Insert plan: " << plan.getTag();
std::lock_guard<std::mutex> lock(*cache_mutex_);
map_.insert(
std::make_pair(MakeKey(op_graph, use_addto), plan.GetEngineConfig()));
auto &local_map = map_[hasher(std::this_thread::get_id())];
local_map.insert(std::make_pair(GetExtendedFeature(feature, handle), plan));
}

bool IsStable(const cudnn_frontend::OperationGraph& op_graph,
const std::string& tag,
bool use_addto = false) {
bool IsStable(const cudnn_frontend::feature_vector_t &feature,
const std::string &tag,
cudnnHandle_t handle) {
if (saturation_count_ == 1) {
return true;
}
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (map_.count(MakeKey(op_graph, use_addto))) {
auto &local_map = map_[hasher(std::this_thread::get_id())];
auto &local_tracker = tracker_[hasher(std::this_thread::get_id())];
auto ext_feature = GetExtendedFeature(feature, handle);
if (local_map.count(ext_feature)) {
return false;
}
int cnt = tracker_[std::make_pair(MakeKey(op_graph, use_addto), tag)] += 1;
VLOG(4) << "[cudnn_frontend] SaturationTracker: " << op_graph.getTag()
<< " " << tag << " " << cnt;
int cnt = local_tracker[std::make_pair(ext_feature, tag)] += 1;
VLOG(4) << "[cudnn_frontend] SaturationTracker: " << tag << " " << cnt;
return cnt >= saturation_count_;
}

bool FindPlan(const cudnn_frontend::OperationGraph &op_graph,
cudnnHandle_t handle) {
return FindPlan(op_graph.getFeatureVector(), handle);
}

void GetPlan(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan **plan,
int64_t *workspace_size,
cudnnHandle_t handle) {
GetPlan(op_graph.getFeatureVector(), plan, workspace_size, handle);
}

void InsertPlan(const cudnn_frontend::OperationGraph &op_graph,
const cudnn_frontend::ExecutionPlan &plan,
cudnnHandle_t handle) {
InsertPlan(op_graph.getFeatureVector(), plan, handle);
}

bool IsStable(const cudnn_frontend::OperationGraph &op_graph,
const std::string &tag,
cudnnHandle_t handle) {
return IsStable(op_graph.getFeatureVector(), tag, handle);
}

private:
static cudnn_frontend::feature_vector_t MakeKey(
const cudnn_frontend::OperationGraph& op_graph, bool use_addto) {
Tom-Zheng marked this conversation as resolved.
Show resolved Hide resolved
auto key = op_graph.getFeatureVector();
key.push_back(static_cast<uint64_t>(use_addto));
return key;
cudnn_frontend::feature_vector_t GetExtendedFeature(
cudnn_frontend::feature_vector_t feat, cudnnHandle_t handle) {
int64_t val = 0;
memcpy(&val, &handle, sizeof(int64_t));
feat.push_back(val);
return feat;
}
using FeatureVectorToPlanMap =
std::map<cudnn_frontend::feature_vector_t, cudnn_frontend::ExecutionPlan>;
std::map<std::size_t, FeatureVectorToPlanMap> map_;
std::hash<std::thread::id> hasher;

std::map<cudnn_frontend::feature_vector_t,
cudnn_frontend::ManagedOpaqueDescriptor>
map_;
std::shared_ptr<std::mutex> cache_mutex_;
int saturation_count_;

using SaturationTracker =
std::map<std::pair<cudnn_frontend::feature_vector_t, std::string>, int>;
SaturationTracker tracker_;
std::map<std::size_t, SaturationTracker> tracker_;

int64_t cache_hits_{0};
int64_t cache_misses_{0};
Expand Down
Loading