Collect CUDA/CPU profiling info into result sheets. (pytorch#5921)
* Collect CUDA/CPU profiling info into result sheets.

This PR:
0. Adds CUDA/CPU collection capabilities to the script.
1. Modifies result_analyzer.py to analyze newly collected results.
2. Moves CUDA synchronize/XLA device synchronize into the profiler.
3. Fixes list typing for Python 3.8+.

Tested with command:
python3 xla/benchmarks/experiment_runner.py --dynamo=openxla --xla=PJRT --test=train --filter=basic_gnn_gcn$ --suite-name=torchbench --accelerator=cuda --progress-bar --output-dirname=/tmp/output --repeat=2 --print-subprocess --no-resume --profile-cuda-cpu-collect --profile-cuda
python3 xla/benchmarks/result_analyzer.py --output-dir=/tmp/output

* Lint, and add _s suffix to metrics

---------

Co-authored-by: root <root@olechwierowicz9.zrh.corp.google.com>
2 people authored and ManfeiBai committed Dec 1, 2023
1 parent 63b23e2 commit c36f0fd
Showing 2 changed files with 63 additions and 5 deletions.
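Before the per-file diffs, for reference: a minimal standalone sketch of the collection approach this PR wires into the runner — sum self CPU time from the profiler's total average, sum self CUDA time over CUDA-typed events, and convert microseconds to seconds. The helper name and the toy workload are illustrative, not part of the PR.

import torch
from torch.autograd import DeviceType
from torch.profiler import ProfilerActivity, profile

def profile_cpu_cuda_seconds(fn):
  # Run `fn` under the profiler and return (total_cpu_s, total_cuda_s).
  with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    fn()
  total_cpu_us = prof.profiler.total_average().self_cpu_time_total
  # Summing only CUDA-typed events avoids double counting under kineto.
  total_cuda_us = sum(evt.self_cuda_time_total
                      for evt in prof.profiler.key_averages()
                      if evt.device_type == DeviceType.CUDA)
  # Profiler times are reported in microseconds.
  return total_cpu_us / 1e6, total_cuda_us / 1e6

if torch.cuda.is_available():
  x = torch.randn(2048, 2048, device="cuda")
  cpu_s, cuda_s = profile_cpu_cuda_seconds(lambda: (x @ x).sum().item())
  print(f"total_cpu_time_s={cpu_s:.4f} total_cuda_time_s={cuda_s:.4f}")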
51 changes: 46 additions & 5 deletions benchmarks/experiment_runner.py
@@ -13,6 +13,8 @@
import torch
from tqdm import tqdm
from torch.profiler import profile, record_function, ProfilerActivity
from torch.autograd import DeviceType
from typing import List

try:
  from .benchmark_model import ModelLoader
@@ -249,8 +251,41 @@ def dump_profile_info(self, prof, model_name):
    with open(os.path.join(file_path, "kernel_dump.txt"), "a") as f:
      f.write(kernel_dump)

  def timed_run(self, benchmark_experiment, benchmark_model):
  def collect_profile_to_metrics(self, prof, metrics):
    assert prof is not None, 'Expecting profiler to be defined!'
    if not self._args.profile_cuda_cpu_collect:
      logger.warning(
          'Profiling enabled, but collection of CPU/CUDA profiling info disabled.'
      )
      return

    kernel_dump = prof.profiler.total_average()
    total_cuda_time = 0
    total_cpu_time = kernel_dump.self_cpu_time_total

    # Avoid double counting CUDA time for inductor. Copied here, since this
    # logic is not exposed via any public interface. The alternative would be
    # regex-matching the resulting string dump for CUDA kernel time.
    # Source: https://github.com/pytorch/pytorch/blob/2f3beb715c608a060934c237de402faa40ea211f/torch/autograd/profiler_util.py#L1025-L1037
    for evt in prof.profiler.key_averages():
      if evt.device_type == DeviceType.CPU:
        # In the legacy profiler, kernel info is stored on CPU events. The
        # upstream snippet also checks `use_device`, which is undefined in
        # this script (no custom device is profiled), so that check is dropped.
        if evt.is_legacy:
          total_cuda_time += evt.self_cuda_time_total
      elif evt.device_type == DeviceType.CUDA:
        # In the kineto profiler, events carry the correct device type (e.g. CUDA).
        total_cuda_time += evt.self_cuda_time_total

    # Profiler times are reported in microseconds; convert to seconds.
    total_cpu_time /= 1000000
    total_cuda_time /= 1000000
    metrics["total_cpu_time_s"] = total_cpu_time
    metrics["total_cuda_time_s"] = total_cuda_time
    metrics[
        "per_iter_cpu_time_s"] = total_cpu_time / self._args.iterations_per_run
    metrics[
        "per_iter_cuda_time_s"] = total_cuda_time / self._args.iterations_per_run

  def timed_run(self, benchmark_experiment, benchmark_model):
    reset_rng_state(benchmark_experiment)

    inputs_list = self.prepare_inputs(benchmark_model.example_inputs,
@@ -282,6 +317,7 @@ def loop(prof=None):

      if prof:
        prof.step()
      self._synchronize(benchmark_experiment)
      return output

    if enable_prof:
@@ -291,11 +327,10 @@
    else:
      output = loop()

    self._synchronize(benchmark_experiment)

    t_end = time.perf_counter()
    if enable_prof:
      self.dump_profile_info(prof, benchmark_model.model_name)
      self.collect_profile_to_metrics(prof, metrics)

    metrics["total_time"] = t_end - t_start
    metrics[
@@ -306,7 +341,7 @@ def loop(prof=None):
    return metrics, output


def append_filter_by_tier(filter_list: list[str], filter_by_tier: list[int]):
def append_filter_by_tier(filter_list: List[str], filter_by_tier: List[int]):
  _FILTER_BY_TIER = {
      1:
          r"^(BERT_pytorch|cm3leon_generate|DALLE2_pytorch|dlrm|hf_GPT2|hf_GPT2_large|GPT_3|hf_T5|hf_T5_base|hf_T5_generate|hf_T5_large|llama_v2_7b_16h|stable_diffusion_xl)$",
@@ -521,7 +556,13 @@ def parse_args(args=None):
"--profile-cuda-dump",
type=str,
default="./output/",
help="Directory specifying where to dump profiling information (summary, and trace)"
help="Directory specifying where to dump profiling information (summary, and trace)",
),

parser.add_argument(
"--profile-cuda-cpu-collect",
action="store_true",
help="Whether to collect CPU/GPU profiling information in the resulting file.",
),

  parser.add_argument(
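A note on item 2 of the PR description: the device synchronize now runs inside the profiled loop, so queued device work from each iteration is observed by the profiler rather than completing after the context exits. A hedged sketch of the pattern, with torch.cuda.synchronize standing in for the runner's accelerator-aware _synchronize helper:

import torch
from torch.profiler import ProfilerActivity, profile

def profiled_loop(step_fn, iterations):
  with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(iterations):
      step_fn()
      prof.step()
      # Synchronizing inside the profiled region ensures this iteration's
      # queued device work lands in the profile instead of finishing
      # after the profiler has already exited.
      torch.cuda.synchronize()
  return prof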
17 changes: 17 additions & 0 deletions benchmarks/result_analyzer.py
@@ -88,6 +88,23 @@ def get_calculated_metrics(self, d, dataline):
d["xla_median_trace_per_iter_time"] = -1
d["xla_compile_time"] = -1

if "total_cpu_time_s" in dataline["metrics"]:
total_cpu_time = np.asarray(
dataline["metrics"]["total_cpu_time_s"], dtype="float")
d["median_total_cpu_time_s"] = np.median(total_cpu_time)
if "per_iter_cpu_time_s" in dataline["metrics"]:
per_iter_cpu_time = np.asarray(
dataline["metrics"]["per_iter_cpu_time_s"], dtype="float")
d["median_per_iter_cpu_time_s"] = np.median(per_iter_cpu_time)
if "total_cuda_time_s" in dataline["metrics"]:
total_cuda_time = np.asarray(
dataline["metrics"]["total_cuda_time_s"], dtype="float")
d["median_total_cuda_time_s"] = np.median(total_cuda_time)
if "per_iter_cuda_time_s" in dataline["metrics"]:
per_iter_cuda_time = np.asarray(
dataline["metrics"]["per_iter_cuda_time_s"], dtype="float")
d["median_per_iter_cuda_time_s"] = np.median(per_iter_cuda_time)

if dataline["experiment"]["dynamo"]:
d["dynamo_compile_time"] = np.max(total_time) - np.median(total_time)
else:
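Downstream, the analyzer reduces each per-run metrics list to a median, as the added code above shows. A tiny illustrative example with made-up numbers (field names follow the PR; the values are not from any real run):

import numpy as np

# Hypothetical "metrics" entry from one JSONL dataline after --repeat=2.
metrics = {
    "total_cpu_time_s": [1.91, 1.87],
    "per_iter_cpu_time_s": [0.0955, 0.0935],
}
summary = {
    "median_" + name: float(np.median(np.asarray(vals, dtype="float")))
    for name, vals in metrics.items()
}
print(summary)  # medians: ~1.89 s total, ~0.0945 s per iteration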
