Skip to content

Commit

Permalink
Use PyTorch's dynamo benchmark skip-list. (#6416)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Jan 30, 2024
1 parent a83be44 commit 492fe27
Showing 1 changed file with 82 additions and 57 deletions.
139 changes: 82 additions & 57 deletions benchmarks/torchbench_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs
import types
import yaml
from util import move_to_device, set_cwd
from benchmark_model import ModelLoader, BenchmarkModel

Expand Down Expand Up @@ -67,59 +68,51 @@
"cm3leon_generate": [
{
"test": "train",
},
}, # Model's DEFAULT_TRAIN_BSIZE is not implemented
{
"test": "eval",
"xla": "PJRT",
},
], # no install.py
}, # TIMEOUT
],
"hf_T5_generate": [
{
"test": "train",
},
}, # Model's DEFAULT_TRAIN_BSIZE is not implemented
{
"test": "eval",
"xla": "PJRT",
},
], # no install.py
}, # TIMEOUT
],
"doctr_det_predictor": [{
"test": "train"
},], # not implemented
},], # Model's DEFAULT_TRAIN_BSIZE is not implemented
"doctr_reco_predictor": [{
"test": "train"
},], # not implemented
},], # Model's DEFAULT_TRAIN_BSIZE is not implemented
"detectron2_fcos_r_50_fpn": [{
"test": "train"
},], # not implemented
# https://github.com/pytorch/torchdynamo/issues/145
"fambench_xlmr": [{}],
"llama": [{
"test": "train"
},], # not implemented
},], # FCOS train is not supported by upstream detectron2
"mobilenet_v2_quantized_qat": [
{
"test": "eval",
"accelerator": "cuda"
}, # not implemented
}, # The eval test only supports CPU
{
"test": "eval",
"accelerator": "tpu"
},
], # not implemented
}, # The eval test only supports CPU
],
# self.load_benchmark() exits the main process. See issue #6207.
"pytorch_CycleGAN_and_pix2pix": [{}],
"pyhpc_equation_of_state": [{
"test": "train"
},], # not implemented
},], # Model's DEFAULT_TRAIN_BSIZE is not implemented
"pyhpc_isoneutral_mixing": [{
"test": "train"
},], # not implemented
},], # Model's DEFAULT_TRAIN_BSIZE is not implemented
"pyhpc_turbulent_kinetic_energy": [{
"test": "train"
},], # not implemented
"pytorch_struct": [{
"test": "eval"
},], # not implemented
},], # Model's DEFAULT_TRAIN_BSIZE is not implemented
"pytorch_unet": [
{
# self.load_benchmark() exits the main process. See issue #6207.
Expand All @@ -130,20 +123,12 @@
{
"test": "eval",
"accelerator": "cuda"
}, # not implemented
}, # The eval test only supports CPU
{
"test": "eval",
"accelerator": "tpu"
},
], # not implemented
"tacotron2": [
{
# self.load_benchmark() exits the main process. See issue #6207.
"xla": "PJRT",
},
}, # The eval test only supports CPU
],
# https://github.com/pytorch/pytorch/issues/99438
"vision_maskrcnn": [{}],
}

# This strict deny list denies tests that hold for too long and timeoout.
Expand Down Expand Up @@ -172,32 +157,59 @@ def __init__(self, args):
super().__init__(args)
self.benchmark_model_class = TorchBenchModel
self.torchbench_dir = self.add_torchbench_dir()
self.skip = self.get_skip_data()

def _find_near_file(self, names):
"""Find a file near the current directory.
Looks for `names` in the current directory, up to its two direct parents.
"""
for dir in ("./", "../", "../../"):
for name in names:
path = os.path.join(dir, name)
if exists(path):
return abspath(path)
return None

def add_torchbench_dir(self):
os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam
for torchbench_dir in (
"./torchbenchmark",
"./torchbench",
"./benchmark",
"../torchbenchmark",
"../torchbench",
"../benchmark",
"../../torchbenchmark",
"../../torchbench",
"../../benchmark",
):
if exists(torchbench_dir):
break

if exists(torchbench_dir):
torchbench_dir = abspath(torchbench_dir)
torchbench_dir = self._find_near_file(
("torchbenchmark", "torchbench", "benchmark"))
assert torchbench_dir is not None, "Torch Benchmark folder not found."

if torchbench_dir is not None:
if torchbench_dir not in sys.path:
sys.path.append(torchbench_dir)
else:
raise Exception("Torch Benchmark folder not found.")

return torchbench_dir

def get_skip_data(self):
"""Retrieve the skip data in the PyTorch YAML file.
Reads the YAML file in PyTorch's dynamo benchmarks directory, and transform
its lists of models into sets of models.
"""

benchmarks_dir = self._find_near_file(("benchmarks",))
assert benchmarks_dir is not None, "PyTorch benchmarks folder not found."

skip_file = os.path.join(benchmarks_dir, "dynamo",
"torchbench_skip_models.yaml")
with open(skip_file) as f:
data = yaml.safe_load(f)

def maybe_list_to_set(obj):
if isinstance(obj, dict):
return {k: maybe_list_to_set(v) for k, v in obj.items()}
if isinstance(obj, list):
return set(obj)
return obj

return maybe_list_to_set(data)

def list_model_configs(self):
model_configs = []

Expand All @@ -221,16 +233,29 @@ def is_compatible(self,
dummy_benchmark_model,
benchmark_experiment,
use_strict_deny=False):
name = dummy_benchmark_model.model_name
deny_list = STRICT_DENY_LIST if use_strict_deny else DENY_LIST
if dummy_benchmark_model.model_name in deny_list:
for deny_experiment_config in deny_list[dummy_benchmark_model.model_name]:
matched = True
for k, v in deny_experiment_config.items():
if getattr(benchmark_experiment, k) != v:
matched = False
break
if matched:
return False

if name in self.skip["skip"]:
return False

if name in self.skip["test"].get(benchmark_experiment.test, {}):
return False

if name in self.skip["device"].get(benchmark_experiment.accelerator, {}):
return False

if name in self.skip["multiprocess"]:
# No support for multiprocess, yet. So, skip all benchmarks that
# only work with it.
return False

def is_attr_eq(k, v):
return getattr(benchmark_experiment, k) == v

for deny_experiment_config in deny_list.get(name, []):
if all(is_attr_eq(k, v) for k, v in deny_experiment_config.items()):
return False

return True

Expand Down

0 comments on commit 492fe27

Please sign in to comment.