Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Profiler] Allow user to flush L2 cache in time_evalutor function for profiling CUDA kernels #13726

Merged
merged 12 commits into from
Jan 10, 2023

Conversation

yzh119
Copy link
Member

@yzh119 yzh119 commented Jan 8, 2023

Motivation

Currently, our default profiler (time_evaluator) does not flush the L2 cache per execution, this might lead to incorrect time measurement because the input data last run might reside in L2 cache and reduce the data fetching time in the next run. Both Triton and nvbench consider this effect thus reporting more accurate measurements.

Solution

time_evalutor has an argument f_preproc where user can specify a pre-processing function per execution of the kernel being evaluated. Currently, TVM supports cache_flush_cpu_non_first_arg which flushes CPU cache. But similar functionality for GPU is missing.

This PR completely borrows the design of nvbench's l2flush struct and allow the user to specify "l2_cache_flush_cuda" as a preprocessing function which flushes NVIDIA GPU's L2 cache. l2_cache_flush_cuda is not a default value so existing program's behavior would not be influenced.

Note that this PR also changes the location where `f_preproc` being triggered: previously `f_preproc` is triggered per repeat but that doesn't sound correct to me because most users specify `repeat=1` and `f_preproc` need to be triggered once per run.

cc @masahi @tkonolige @junrushao @spectrometerHBH @tqchen

@tvm-bot
Copy link
Collaborator

tvm-bot commented Jan 8, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

src/runtime/profiling.cc Outdated Show resolved Hide resolved
Copy link
Contributor

@Icemist Icemist left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, except for a couple of minor non-functional issues.

src/runtime/cuda/l2_cache_flush.cc Outdated Show resolved Hide resolved
src/runtime/cuda/l2_cache_flush.cc Outdated Show resolved Hide resolved
Copy link
Contributor

@tkonolige tkonolige left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @yzh119! This is a great addition to benchmarking. Maybe we should consider flushing the cache by default instead of requiring the user to specify a preprocessor. Maybe something like a cold-start mode.

There's one change I need to you make around moving this code to third party as it is almost directly from nvbench.

src/runtime/cuda/l2_cache_flush.cc Show resolved Hide resolved
@tkonolige tkonolige merged commit 92da138 into apache:main Jan 10, 2023
fzi-peccia pushed a commit to fzi-peccia/tvm that referenced this pull request Mar 27, 2023
…or profiling CUDA kernels (apache#13726)

Currently, our default profiler (time_evaluator) does not flush the L2 cache per execution, this might lead to incorrect time measurement because the input data last run might reside in L2 cache and reduce the data fetching time in the next run. Both Triton and nvbench consider this effect thus reporting more accurate measurements.

Solution: time_evalutor has an argument f_preproc where user can specify a pre-processing function per execution of the kernel being evaluated. Currently, TVM supports cache_flush_cpu_non_first_arg which flushes CPU cache. But similar functionality for GPU is missing.

This PR completely borrows the design of nvbench's l2flush struct and allow the user to specify "l2_cache_flush_cuda" as a preprocessing function which flushes NVIDIA GPU's L2 cache. l2_cache_flush_cuda is not a default value so existing program's behavior would not be influenced.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants