From 80ca0fc0509df8b20aee5e4590220bd04c1e6853 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 23 Nov 2023 16:26:00 +0000
Subject: [PATCH] Collect CUDA/CPU profiling info into result sheets.

This PR:
0. Adds CUDA/CPU collection capabilities to the script.
1. Modifies result_analyzer.py to analyze newly collected results.
2. Moves CUDA synchronize/XLA device synchronize into the profiler.
3. Fixes list typing for Python 3.8+.

Tested with command:
python3 xla/benchmarks/experiment_runner.py --dynamo=openxla --xla=PJRT --test=train --filter=basic_gnn_gcn$ --suite-name=torchbench --accelerator=cuda --progress-bar --output-dirname=/tmp/output --repeat=2 --print-subprocess --no-resume --profile-cuda-cpu-collect --profile-cuda
python3 xla/benchmarks/result_analyzer.py --output-dir=/tmp/output
---
 benchmarks/experiment_runner.py | 46 ++++++++++++++++++++++++++++++---
 benchmarks/result_analyzer.py | 13 ++++++++++
 2 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/benchmarks/experiment_runner.py b/benchmarks/experiment_runner.py
index 660fc68a604..11f9f5c827c 100644
--- a/benchmarks/experiment_runner.py
+++ b/benchmarks/experiment_runner.py
@@ -13,6 +13,8 @@
 import torch
 from tqdm import tqdm
 from torch.profiler import profile, record_function, ProfilerActivity
+from torch.autograd import DeviceType
+from typing import List
 
 try:
   from .benchmark_model import ModelLoader
@@ -249,8 +251,37 @@ def dump_profile_info(self, prof, model_name):
     with open(os.path.join(file_path, "kernel_dump.txt"), "a") as f:
       f.write(kernel_dump)
 
-  def timed_run(self, benchmark_experiment, benchmark_model):
+  def collect_profile_to_metrics(self, prof, metrics):
+    assert prof is not None, 'Expecting profiler to be defined!'
+    if not self._args.profile_cuda_cpu_collect:
+      logger.warning('Profiling enabled, but collection of CPU/CUDA profiling info disabled.')
+      return
+    kernel_dump = prof.profiler.total_average()
+    total_cuda_time = 0
+    total_cpu_time = kernel_dump.self_cpu_time_total
+
+    # Avoid double counting CUDA time for inductor. Copied here, since this logic is not exposed via any public interface.
+    # The alternative is regex matching the resulting string dump for CUDA kernel time.
+    # Source: https://github.com/pytorch/pytorch/blob/2f3beb715c608a060934c237de402faa40ea211f/torch/autograd/profiler_util.py#L1025-L1037
+    for evt in prof.profiler.key_averages():
+      if evt.device_type == DeviceType.CPU:
+        # in legacy profiler, kernel info is stored in cpu events
+        if evt.is_legacy:
+          if not getattr(prof.profiler, 'use_device', None):
+            total_cuda_time += evt.self_cuda_time_total
+      elif evt.device_type == DeviceType.CUDA:
+        # in kineto profiler, there're events with the correct device type (e.g. CUDA)
+        total_cuda_time += evt.self_cuda_time_total
+
+    total_cpu_time /= 1000000
+    total_cuda_time /= 1000000
+    metrics["total_cpu_time"] = total_cpu_time
+    metrics["total_cuda_time"] = total_cuda_time
+    metrics["per_iter_cpu_time"] = total_cpu_time / self._args.iterations_per_run
+    metrics["per_iter_cuda_time"] = total_cuda_time / self._args.iterations_per_run
+
+  def timed_run(self, benchmark_experiment, benchmark_model):
     reset_rng_state(benchmark_experiment)
 
     inputs_list = self.prepare_inputs(benchmark_model.example_inputs,
@@ -282,6 +313,7 @@ def loop(prof=None):
         if prof:
           prof.step()
+      self._synchronize(benchmark_experiment)
       return output
 
     if enable_prof:
@@ -291,11 +323,11 @@ def loop(prof=None):
     else:
       output = loop()
 
-    self._synchronize(benchmark_experiment)
     t_end = time.perf_counter()
 
     if enable_prof:
       self.dump_profile_info(prof, benchmark_model.model_name)
+      self.collect_profile_to_metrics(prof, metrics)
 
     metrics["total_time"] = t_end - t_start
     metrics[
@@ -306,7 +338,7 @@ def loop(prof=None):
     return metrics, output
 
 
-def append_filter_by_tier(filter_list: list[str], filter_by_tier: list[int]):
+def append_filter_by_tier(filter_list: List[str], filter_by_tier: List[int]):
   _FILTER_BY_TIER = {
       1:
          r"^(BERT_pytorch|cm3leon_generate|DALLE2_pytorch|dlrm|hf_GPT2|hf_GPT2_large|GPT_3|hf_T5|hf_T5_base|hf_T5_generate|hf_T5_large|llama_v2_7b_16h|stable_diffusion_xl)$",
@@ -521,7 +553,13 @@ def parse_args(args=None):
       "--profile-cuda-dump",
       type=str,
       default="./output/",
-      help="Directory specifying where to dump profiling information (summary, and trace)"
+      help="Directory specifying where to dump profiling information (summary, and trace)",
+  ),
+
+  parser.add_argument(
+      "--profile-cuda-cpu-collect",
+      action="store_true",
+      help="Whether to collect CPU/GPU profiling information in the resulting file.",
   ),
 
   parser.add_argument(
diff --git a/benchmarks/result_analyzer.py b/benchmarks/result_analyzer.py
index 15deb45e8b9..3b510129024 100644
--- a/benchmarks/result_analyzer.py
+++ b/benchmarks/result_analyzer.py
@@ -88,6 +88,19 @@ def get_calculated_metrics(self, d, dataline):
       d["xla_median_trace_per_iter_time"] = -1
       d["xla_compile_time"] = -1
 
+    if "total_cpu_time" in dataline["metrics"]:
+      total_cpu_time = np.asarray(dataline["metrics"]["total_cpu_time"], dtype="float")
+      d["median_total_cpu_time"] = np.median(total_cpu_time)
+    if "per_iter_cpu_time" in dataline["metrics"]:
+      per_iter_cpu_time = np.asarray(dataline["metrics"]["per_iter_cpu_time"], dtype="float")
+      d["median_per_iter_cpu_time"] = np.median(per_iter_cpu_time)
+    if "total_cuda_time" in dataline["metrics"]:
+      total_cuda_time = np.asarray(dataline["metrics"]["total_cuda_time"], dtype="float")
+      d["median_total_cuda_time"] = np.median(total_cuda_time)
+    if "per_iter_cuda_time" in dataline["metrics"]:
+      per_iter_cuda_time = np.asarray(dataline["metrics"]["per_iter_cuda_time"], dtype="float")
+      d["median_per_iter_cuda_time"] = np.median(per_iter_cuda_time)
+
     if dataline["experiment"]["dynamo"]:
       d["dynamo_compile_time"] = np.max(total_time) - np.median(total_time)
     else: