From 20692cb04256ca24b1dd2c7d00ffba0bb0cb69ac Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 16 Feb 2024 10:56:54 -0300 Subject: [PATCH] [benchmarks] Default to `bfloat16` (inference) and AMP (training) precision. (#6518) --- benchmarks/benchmark_model.py | 11 ++- benchmarks/torchbench_model.py | 128 +++++++++++---------------------- benchmarks/util.py | 9 +++ 3 files changed, 61 insertions(+), 87 deletions(-) diff --git a/benchmarks/benchmark_model.py b/benchmarks/benchmark_model.py index 1a62d986c36..59a430d3982 100644 --- a/benchmarks/benchmark_model.py +++ b/benchmarks/benchmark_model.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from torch._dynamo.testing import collect_results -from util import move_to_device +from util import cast_to_dtype, move_to_device logger = logging.getLogger(__name__) @@ -104,8 +104,17 @@ def _prepare_for_train(self): # optimizer to use. So only initialize it when there is none existing. self.optimizer = self.optimizer_class(self.module.parameters(), lr=0.01) + def conversion_dtype(self): + return None + def prepare_for_experiment(self, dynamo_compilation_opts): self.device = self.benchmark_experiment.get_device() + self.dtype = self.conversion_dtype() + + if self.dtype is not None: + self.module = self.module.to(self.dtype) + self.example_inputs = cast_to_dtype(self.example_inputs, self.dtype) + self.module = self.module.to(self.device) self.example_inputs = move_to_device(self.example_inputs, self.device) diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index 7d1cd35be86..d3f644b4b73 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -144,6 +144,18 @@ "hf_T5_generate", } +FORCE_AMP_FOR_FP16_BF16_MODELS = { + "DALLE2_pytorch", + "doctr_det_predictor", + "doctr_reco_predictor", + "Super_SloMo", + "tts_angular", + "pyhpc_turbulent_kinetic_energy", + "detectron2_fcos_r_50_fpn", +} + +FORCE_FP16_FOR_BF16_MODELS = {"vision_maskrcnn"} + class TorchBenchModelLoader(ModelLoader): @@ -295,10 +307,6 @@ def set_up(self): self.example_inputs = move_to_device(self.example_inputs, "cpu") self._cleanup() - device = self.benchmark_experiment.get_device() - self.module = self.module.to(device) - self.example_inputs = move_to_device(self.example_inputs, device) - # Torchbench has quite different setup for yolov3, so directly passing # the right example_inputs if self.model_name == "yolov3": @@ -334,88 +342,13 @@ def load_benchmark(self): if (device := self.benchmark_experiment.accelerator) == 'tpu': device = str(self.benchmark_experiment.get_device()) - kwargs = { - "test": self.benchmark_experiment.test, - "device": device, - "batch_size": self.benchmark_experiment.batch_size, - } - - # Force FP32 when precision is either FP16 or BF16 (only for XLA:CUDA). - # If the model is, e.g. FP16 and XLA_USE_FP16, XLA will unexpectedly up-cast - # return values to FP32. - # Issue: https://github.com/pytorch/xla/issues/6348 - if self.benchmark_experiment.accelerator == "cuda" and self.benchmark_experiment.xla: - if self.is_cuda_precision_fp16() or self.is_cuda_precision_bf16(): - # PyTorch/benchmark will use these 'extra_args' for converting the model. 
- kwargs["extra_args"] = ["--precision", "fp32"] - - return self.benchmark_cls()(**kwargs) - - def get_cuda_precision(self): - test = self.benchmark_experiment.test.upper() - attr = f"DEFAULT_{test}_CUDA_PRECISION" - return getattr(self.benchmark_cls(), attr, None) - - def is_cuda_precision_fp16(self): - return self.get_cuda_precision() == "fp16" - - def is_cuda_precision_fp32(self): - return self.get_cuda_precision() == "fp32" - - def is_cuda_precision_bf16(self): - return self.get_cuda_precision() == "bf16" - - def is_cuda_precision_amp(self): - return self.get_cuda_precision() == "amp" - - @property - def default_precision_flag(self): - """ - Get the default precision config to XLA, if present. - - Whenever a model has a default precision for cuda set - we need to set proper environment flags so XLA catches - the requird precision. - - This function is a workaround. Proper solution requires - changes to the PT/XLA bridge so that the input shape - is properly inferred after issuing converts to `torch.nn.Module`. - """ - # At this moment, this method checks the precision flags only if both - # of the items below are true: - # - # 1. Device is CUDA: only check for 'DEFAULT_CUDA__PRECISION' - # - # 2. Dynamo backend is not inductor: PyTorch/benchmark scripts already - # take care of converting the model to the right precision. - # - if (self.benchmark_experiment.accelerator != "cuda" or - self.benchmark_experiment.dynamo == "inductor"): - return None - - if self.get_cuda_precision() is None: - return None - - if self.is_cuda_precision_fp16(): - return 'XLA_USE_FP16' - - if self.is_cuda_precision_bf16(): - return 'XLA_USE_BF16' - - if self.is_cuda_precision_amp(): - return None - - if self.is_cuda_precision_fp32(): - logger.warning("Sticking with the default fp32 precision.") - return None - - raise ValueError(f"Unknown precision: {precision}") + return self.benchmark_cls()( + test=self.benchmark_experiment.test, + device=device, + batch_size=self.benchmark_experiment.batch_size, + ) def update_process_env(self, process_env): - precision_flag = self.default_precision_flag - if precision_flag is not None: - process_env[precision_flag] = '1' - if self.model_name in NEED_LARGER_CACHE: process_env["XLA_COMPILATION_CACHE_SIZE"] = "2048" @@ -425,9 +358,32 @@ def pick_grad(self): return torch.enable_grad() return super().pick_grad() + def is_inference(self): + return self.benchmark_experiment.test == "eval" + + def is_training(self): + return self.benchmark_experiment.test == "train" + + def use_amp(self): + return self.is_training( + ) or self.model_name in FORCE_AMP_FOR_FP16_BF16_MODELS + + def use_fp16(self): + return self.is_inference() and self.model_name in FORCE_FP16_FOR_BF16_MODELS + + def conversion_dtype(self): + if self.is_training() or self.use_amp(): + return super().conversion_dtype() + + # From here, we are running inference without AMP, for sure. + # Do we have to use float16, instead of bfloat16? + if self.use_fp16(): + return torch.float16 + + return torch.bfloat16 + def _get_autocast_with_kwargs(self): - if (self.benchmark_experiment.accelerator == "cuda" and - self.is_cuda_precision_amp()): + if self.use_amp(): kwargs = {"dtype": torch.bfloat16} if self.benchmark_experiment.xla: # Should call device specific autocast implementations. 
diff --git a/benchmarks/util.py b/benchmarks/util.py index 6379e67d1f3..540839f2ac2 100644 --- a/benchmarks/util.py +++ b/benchmarks/util.py @@ -80,6 +80,15 @@ def move_to_device(item, device): return pytree.tree_map_only(torch.Tensor, lambda t: t.to(device), item) +def cast_to_dtype(item, dtype): + return pytree.tree_map_only( + torch.Tensor, + lambda t: t.to(dtype) + if isinstance(t, torch.Tensor) and t.is_floating_point() else t, + item, + ) + + def randomize_input(inputs): if isinstance(inputs, torch.Tensor): if inputs.dtype in (torch.float32, torch.float64):
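
Note: a minimal, hypothetical usage sketch (not part of the patch) of the new
cast_to_dtype() helper added to benchmarks/util.py. It assumes the same
torch.utils._pytree module used by that file and restates the lambda in slightly
simplified form; the point is that only floating-point tensors in the (possibly
nested) example_inputs structure are cast, while integer tensors such as token
ids keep their dtype.

  import torch
  import torch.utils._pytree as pytree

  def cast_to_dtype(item, dtype):
    # Cast only floating-point tensors; leave all other leaves untouched.
    return pytree.tree_map_only(
        torch.Tensor,
        lambda t: t.to(dtype) if t.is_floating_point() else t,
        item,
    )

  example_inputs = {
      "input_ids": torch.ones(2, 8, dtype=torch.int64),  # stays int64
      "pixel_values": torch.randn(2, 3, 224, 224),        # cast to bfloat16
  }
  casted = cast_to_dtype(example_inputs, torch.bfloat16)
  assert casted["input_ids"].dtype == torch.int64
  assert casted["pixel_values"].dtype == torch.bfloat16
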