Use benchmark_cls for checking precision. #6375

Merged · 1 commit · Jan 25, 2024
26 changes: 15 additions & 11 deletions benchmarks/torchbench_model.py
@@ -1,3 +1,4 @@
+import functools
 import gc
 import importlib
 import logging
@@ -244,16 +245,18 @@ def set_up(self):
     del benchmark
     self._cleanup()

-  def load_benchmark(self):
+  @functools.lru_cache(maxsize=1)
+  def benchmark_cls(self):
     try:
       module = importlib.import_module(
           f"torchbenchmark.models.{self.model_name}")
     except ModuleNotFoundError:
       module = importlib.import_module(
           f"torchbenchmark.models.fb.{self.model_name}")
-    benchmark_cls = getattr(module, "Model", None)
+    return getattr(module, "Model", None)

-    cant_change_batch_size = (not getattr(benchmark_cls,
+  def load_benchmark(self):
+    cant_change_batch_size = (not getattr(self.benchmark_cls(),
                                           "ALLOW_CUSTOMIZE_BSIZE", True))
     if cant_change_batch_size:
       self.benchmark_experiment.batch_size = None
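
For context on the `@functools.lru_cache(maxsize=1)` decorator introduced above: it memoizes the method on its arguments (here just `self`), so the `importlib.import_module` lookup runs once and later calls to `benchmark_cls()` from `load_benchmark` and `default_precision_flag` reuse the cached class. A minimal sketch of that behavior, with illustrative names (not the PR's code):

```python
# Sketch only: shows how functools.lru_cache(maxsize=1) memoizes a
# zero-argument method, so the body runs once per instance and later
# calls return the cached result.
import functools


class ModelLoader:

  def __init__(self, model_name):
    self.model_name = model_name
    self.import_count = 0  # only to show how often the body runs

  @functools.lru_cache(maxsize=1)
  def benchmark_cls(self):
    # Stand-in for importlib.import_module(...); the real method returns
    # getattr(module, "Model", None).
    self.import_count += 1
    return f"Model<{self.model_name}>"


loader = ModelLoader("resnet50")
assert loader.benchmark_cls() == loader.benchmark_cls()
assert loader.import_count == 1  # the import-like work ran only once
```

Note that `maxsize=1` keeps only the most recently used `self` key, which appears sufficient here since each model wrapper is used on its own.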
@@ -264,7 +267,8 @@ def load_benchmark(self):
     # torchbench uses `xla` as device instead of `tpu`
     if device := self.benchmark_experiment.accelerator == 'tpu':
       device = str(self.benchmark_experiment.get_device())
-    return benchmark_cls(
+
+    return self.benchmark_cls()(
         test=self.benchmark_experiment.test,
         device=device,
         batch_size=self.benchmark_experiment.batch_size,
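
The `self.benchmark_cls()(...)` expression in this hunk is a two-step call: the first pair of parentheses returns the (cached) class object, and the trailing parentheses instantiate it with the experiment's settings. A small illustrative sketch with made-up names:

```python
# Sketch only: calling a function that returns a class, then immediately
# instantiating that class.
class FakeModel:

  def __init__(self, test, device, batch_size):
    self.test = test
    self.device = device
    self.batch_size = batch_size


def benchmark_cls():
  # Stands in for self.benchmark_cls(): returns the class, not an instance.
  return FakeModel


benchmark = benchmark_cls()(test="eval", device="xla", batch_size=8)
assert isinstance(benchmark, FakeModel)
```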
@@ -285,20 +289,20 @@ def default_precision_flag(self):
     """
     test = self.benchmark_experiment.test
     try:
-      benchmark = self.load_benchmark()
+      benchmark_cls = self.benchmark_cls()
     except Exception:
-      logger.exception("Cannot load benchmark model")
+      logger.exception("Cannot import benchmark model")
       return None

-    if test == "eval" and hasattr(benchmark, 'DEFAULT_EVAL_CUDA_PRECISION'):
-      precision = benchmark.DEFAULT_EVAL_CUDA_PRECISION
-    elif test == "train" and hasattr(benchmark, 'DEFAULT_TRAIN_CUDA_PRECISION'):
-      precision = benchmark.DEFAULT_TRAIN_CUDA_PRECISION
+    if test == "eval" and hasattr(benchmark_cls, 'DEFAULT_EVAL_CUDA_PRECISION'):
+      precision = benchmark_cls.DEFAULT_EVAL_CUDA_PRECISION
+    elif test == "train" and hasattr(benchmark_cls,
+                                     'DEFAULT_TRAIN_CUDA_PRECISION'):
+      precision = benchmark_cls.DEFAULT_TRAIN_CUDA_PRECISION
     else:
       precision = None
       logger.warning("No default precision set. No patching needed.")

-    del benchmark
     self._cleanup()

     precision_flag = None
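
The net effect of this hunk is that `default_precision_flag` now reads the `DEFAULT_EVAL_CUDA_PRECISION` / `DEFAULT_TRAIN_CUDA_PRECISION` attributes off the class returned by `benchmark_cls()` instead of instantiating the full benchmark via `load_benchmark()`, so checking precision no longer builds the model. A rough sketch of that lookup logic, using an invented class for illustration:

```python
# Sketch only: precision defaults are class attributes, so they can be read
# without ever constructing the benchmark.
class FakeBenchmark:
  DEFAULT_EVAL_CUDA_PRECISION = "fp16"
  # No DEFAULT_TRAIN_CUDA_PRECISION defined.

  def __init__(self):
    raise RuntimeError("should not need to build the model to read defaults")


def default_precision(benchmark_cls, test):
  if test == "eval" and hasattr(benchmark_cls, "DEFAULT_EVAL_CUDA_PRECISION"):
    return benchmark_cls.DEFAULT_EVAL_CUDA_PRECISION
  if test == "train" and hasattr(benchmark_cls, "DEFAULT_TRAIN_CUDA_PRECISION"):
    return benchmark_cls.DEFAULT_TRAIN_CUDA_PRECISION
  return None


assert default_precision(FakeBenchmark, "eval") == "fp16"
assert default_precision(FakeBenchmark, "train") is None
```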