diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index 633a22e0591..0de15d86cc4 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -15,7 +15,7 @@ import torch_xla.core.xla_model as xm import types import yaml -from util import move_to_device, set_cwd, get_torchbench_test_name +from util import move_to_device, set_cwd, get_torchbench_test_name, find_near_file from benchmark_model import ModelLoader, BenchmarkModel logger = logging.getLogger(__name__) @@ -112,6 +112,7 @@ "hf_T5_generate", } +# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py FORCE_AMP_FOR_FP16_BF16_MODELS = { "DALLE2_pytorch", "doctr_det_predictor", @@ -122,33 +123,54 @@ "detectron2_fcos_r_50_fpn", } +# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py FORCE_FP16_FOR_BF16_MODELS = {"vision_maskrcnn"} +@functools.lru_cache(maxsize=1) +def config_data(): + """Load and cache the torchbench configuration from PyTorch's YAML file. + + Reads torchbench.yaml in PyTorch's dynamo benchmarks directory and turns + each list found under its (nested) dict keys into a flat set of items. + """ + + benchmarks_dynamo_dir = find_near_file( + ("pytorch/benchmarks/dynamo", "benchmarks/dynamo")) + assert benchmarks_dynamo_dir is not None, "PyTorch benchmarks folder not found."
+ + skip_file = os.path.join(benchmarks_dynamo_dir, "torchbench.yaml") + with open(skip_file) as f: + data = yaml.safe_load(f) + + def flatten(lst): + for item in lst: + if isinstance(item, list): + yield from flatten(item) + else: + yield item + + def maybe_list_to_set(obj): + if isinstance(obj, dict): + return {k: maybe_list_to_set(v) for k, v in obj.items()} + if isinstance(obj, list): + return set(flatten(obj)) + return obj + + return maybe_list_to_set(data) + + class TorchBenchModelLoader(ModelLoader): def __init__(self, args): super().__init__(args) self.benchmark_model_class = TorchBenchModel self.torchbench_dir = self.add_torchbench_dir() - self.config = self.get_config_data() - - def _find_near_file(self, names): - """Find a file near the current directory. - - Looks for `names` in the current directory, up to its two direct parents. - """ - for dir in ("./", "../", "../../", "../../../"): - for name in names: - path = os.path.join(dir, name) - if exists(path): - return abspath(path) - return None def add_torchbench_dir(self): os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam - torchbench_dir = self._find_near_file( + torchbench_dir = find_near_file( ("torchbenchmark", "torchbench", "benchmark")) assert torchbench_dir is not None, "Torch Benchmark folder not found." @@ -160,37 +182,6 @@ def add_torchbench_dir(self): return torchbench_dir - def get_config_data(self): - """Retrieve the skip data in the PyTorch YAML file. - - Reads the YAML file in PyTorch's dynamo benchmarks directory, and transform - its lists of models into sets of models. - """ - - benchmarks_dynamo_dir = self._find_near_file( - ("pytorch/benchmarks/dynamo", "benchmarks/dynamo")) - assert benchmarks_dynamo_dir is not None, "PyTorch benchmarks folder not found."
- - skip_file = os.path.join(benchmarks_dynamo_dir, "torchbench.yaml") - with open(skip_file) as f: - data = yaml.safe_load(f) - - def flatten(lst): - for item in lst: - if isinstance(item, list): - yield from flatten(item) - else: - yield item - - def maybe_list_to_set(obj): - if isinstance(obj, dict): - return {k: maybe_list_to_set(v) for k, v in obj.items()} - if isinstance(obj, list): - return set(flatten(obj)) - return obj - - return maybe_list_to_set(data) - def list_model_configs(self): model_configs = [] @@ -212,7 +203,7 @@ def list_model_configs(self): @property def skip(self): - return self.config["skip"] + return config_data()["skip"] def is_compatible(self, dummy_benchmark_model, benchmark_experiment): name = dummy_benchmark_model.model_name @@ -308,12 +299,28 @@ def benchmark_cls(self): logger.warning(f"Unable to import {module_src}.") return None + @property + def batch_size(self): + return config_data()["batch_size"] + def load_benchmark(self): - cant_change_batch_size = (not getattr(self.benchmark_cls(), - "ALLOW_CUSTOMIZE_BSIZE", True)) + cant_change_batch_size = ( + not getattr(self.benchmark_cls(), "ALLOW_CUSTOMIZE_BSIZE", True) or + self.model_name in config_data()["dont_change_batch_size"]) + if cant_change_batch_size: self.benchmark_experiment.batch_size = None + if self.benchmark_experiment.batch_size is not None: + batch_size = self.benchmark_experiment.batch_size + elif self.is_training() and self.model_name in self.batch_size["training"]: + batch_size = self.batch_size["training"][self.model_name] + elif self.is_inference( + ) and self.model_name in self.batch_size["inference"]: + batch_size = self.batch_size["inference"][self.model_name] + else: + batch_size = None + # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" # torch.backends.__allow_nonbracketed_mutation_flag = True @@ -324,7 +331,7 @@ def load_benchmark(self): return self.benchmark_cls()( test=self.benchmark_experiment.test, device=device, -
batch_size=self.benchmark_experiment.batch_size, + batch_size=batch_size, ) def update_process_env(self, process_env): diff --git a/benchmarks/util.py b/benchmarks/util.py index 3c2da358b35..21f3736d91a 100644 --- a/benchmarks/util.py +++ b/benchmarks/util.py @@ -156,3 +156,16 @@ def get_tpu_name(): def get_torchbench_test_name(test): return {"train": "training", "eval": "inference"}[test] + + +def find_near_file(self, names): + """Find a file near the current directory. + + Looks for `names` in the current directory, up to its two direct parents. + """ + for dir in ("./", "../", "../../", "../../../"): + for name in names: + path = os.path.join(dir, name) + if exists(path): + return abspath(path) + return None