From a8a4065c8664a4eae292f63ac7fc2f287e1573d1 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 24 Jan 2024 17:07:59 -0300 Subject: [PATCH] Cache `benchmark_cls` and use it for checking precision. --- benchmarks/torchbench_model.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index 03254fa4051..49ea232af5a 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -1,3 +1,4 @@ +import functools import gc import importlib import logging @@ -244,16 +245,18 @@ def set_up(self): del benchmark self._cleanup() - def load_benchmark(self): + @functools.lru_cache(maxsize=1) + def benchmark_cls(self): try: module = importlib.import_module( f"torchbenchmark.models.{self.model_name}") except ModuleNotFoundError: module = importlib.import_module( f"torchbenchmark.models.fb.{self.model_name}") - benchmark_cls = getattr(module, "Model", None) + return getattr(module, "Model", None) - cant_change_batch_size = (not getattr(benchmark_cls, + def load_benchmark(self): + cant_change_batch_size = (not getattr(self.benchmark_cls(), "ALLOW_CUSTOMIZE_BSIZE", True)) if cant_change_batch_size: self.benchmark_experiment.batch_size = None @@ -264,7 +267,8 @@ def load_benchmark(self): # torchbench uses `xla` as device instead of `tpu` if device := self.benchmark_experiment.accelerator == 'tpu': device = str(self.benchmark_experiment.get_device()) - return benchmark_cls( + + return self.benchmark_cls()( test=self.benchmark_experiment.test, device=device, batch_size=self.benchmark_experiment.batch_size, @@ -285,20 +289,20 @@ def default_precision_flag(self): """ test = self.benchmark_experiment.test try: - benchmark = self.load_benchmark() + benchmark_cls = self.benchmark_cls() except Exception: - logger.exception("Cannot load benchmark model") + logger.exception("Cannot import benchmark model") return None - if test == "eval" and hasattr(benchmark, 'DEFAULT_EVAL_CUDA_PRECISION'): - precision = benchmark.DEFAULT_EVAL_CUDA_PRECISION - elif test == "train" and hasattr(benchmark, 'DEFAULT_TRAIN_CUDA_PRECISION'): - precision = benchmark.DEFAULT_TRAIN_CUDA_PRECISION + if test == "eval" and hasattr(benchmark_cls, 'DEFAULT_EVAL_CUDA_PRECISION'): + precision = benchmark_cls.DEFAULT_EVAL_CUDA_PRECISION + elif test == "train" and hasattr(benchmark_cls, + 'DEFAULT_TRAIN_CUDA_PRECISION'): + precision = benchmark_cls.DEFAULT_TRAIN_CUDA_PRECISION else: precision = None logger.warning("No default precision set. No patching needed.") - del benchmark self._cleanup() precision_flag = None