Skip to content

Commit

Permalink
Using YAML configuration.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Feb 23, 2024
1 parent 4873571 commit 10419ed
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 77 deletions.
119 changes: 42 additions & 77 deletions benchmarks/torchbench_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch_xla.core.xla_model as xm
import types
import yaml
from util import move_to_device, set_cwd, get_torchbench_test_name
from util import move_to_device, set_cwd, get_torchbench_test_name, find_near_file
from benchmark_model import ModelLoader, BenchmarkModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,33 +126,38 @@
# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py
FORCE_FP16_FOR_BF16_MODELS = {"vision_maskrcnn"}

# Some models have large dataset that doesn't fit in memory. Lower the batch
# size to test the accuracy.
# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py
USE_SMALL_BATCH_SIZE = {
"demucs": 4,
"dlrm": 1024,
"densenet121": 4,
"hf_Reformer": 4,
"hf_T5_base": 4,
"timm_efficientdet": 1,
"llama_v2_7b_16h": 1,
# reduced from 16 due to cudagraphs OOM in TorchInductor dashboard
"yolov3": 8,
}

# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py
INFERENCE_SMALL_BATCH_SIZE = {
"timm_efficientdet": 32,
}
@functools.lru_cache(maxsize=1)
def config_data():
"""Retrieve the skip data in the PyTorch YAML file.
# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py
DONT_CHANGE_BATCH_SIZE = {
"demucs",
"pytorch_struct",
"pyhpc_turbulent_kinetic_energy",
"vision_maskrcnn", # https://github.com/pytorch/benchmark/pull/1656
}
Reads the YAML file in PyTorch's dynamo benchmarks directory, and transform
its lists of models into sets of models.
"""

benchmarks_dynamo_dir = find_near_file(
("pytorch/benchmarks/dynamo", "benchmarks/dynamo"))
assert benchmarks_dynamo_dir is not None, "PyTorch benchmarks folder not found."

skip_file = os.path.join(benchmarks_dynamo_dir, "torchbench.yaml")
with open(skip_file) as f:
data = yaml.safe_load(f)

def flatten(lst):
for item in lst:
if isinstance(item, list):
yield from flatten(item)
else:
yield item

def maybe_list_to_set(obj):
if isinstance(obj, dict):
return {k: maybe_list_to_set(v) for k, v in obj.items()}
if isinstance(obj, list):
return set(flatten(obj))
return obj

return maybe_list_to_set(data)


class TorchBenchModelLoader(ModelLoader):
Expand All @@ -161,24 +166,11 @@ def __init__(self, args):
super().__init__(args)
self.benchmark_model_class = TorchBenchModel
self.torchbench_dir = self.add_torchbench_dir()
self.config = self.get_config_data()

def _find_near_file(self, names):
"""Find a file near the current directory.
Looks for `names` in the current directory, up to its two direct parents.
"""
for dir in ("./", "../", "../../", "../../../"):
for name in names:
path = os.path.join(dir, name)
if exists(path):
return abspath(path)
return None

def add_torchbench_dir(self):
os.environ["KALDI_ROOT"] = "/tmp" # avoids some spam

torchbench_dir = self._find_near_file(
torchbench_dir = find_near_file(
("torchbenchmark", "torchbench", "benchmark"))
assert torchbench_dir is not None, "Torch Benchmark folder not found."

Expand All @@ -190,37 +182,6 @@ def add_torchbench_dir(self):

return torchbench_dir

def get_config_data(self):
"""Retrieve the skip data in the PyTorch YAML file.
Reads the YAML file in PyTorch's dynamo benchmarks directory, and transform
its lists of models into sets of models.
"""

benchmarks_dynamo_dir = self._find_near_file(
("pytorch/benchmarks/dynamo", "benchmarks/dynamo"))
assert benchmarks_dynamo_dir is not None, "PyTorch benchmarks folder not found."

skip_file = os.path.join(benchmarks_dynamo_dir, "torchbench.yaml")
with open(skip_file) as f:
data = yaml.safe_load(f)

def flatten(lst):
for item in lst:
if isinstance(item, list):
yield from flatten(item)
else:
yield item

def maybe_list_to_set(obj):
if isinstance(obj, dict):
return {k: maybe_list_to_set(v) for k, v in obj.items()}
if isinstance(obj, list):
return set(flatten(obj))
return obj

return maybe_list_to_set(data)

def list_model_configs(self):
model_configs = []

Expand All @@ -242,7 +203,7 @@ def list_model_configs(self):

@property
def skip(self):
return self.config["skip"]
return config_data()["skip"]

def is_compatible(self, dummy_benchmark_model, benchmark_experiment):
name = dummy_benchmark_model.model_name
Expand Down Expand Up @@ -338,20 +299,24 @@ def benchmark_cls(self):
logger.warning(f"Unable to import {module_src}.")
return None

@property
def batch_size(self):
return config_data()["batch_size"]

def load_benchmark(self):
cant_change_batch_size = (
not getattr(self.benchmark_cls(), "ALLOW_CUSTOMIZE_BSIZE", True) or
model_name in DONT_CHANGE_BATCH_SIZE)
model_name in config_data()["dont_change_batch_size"])

if cant_change_batch_size:
self.benchmark_experiment.batch_size = None

if self.benchmark_experiment.batch_size is not None:
batch_size = self.benchmark_experiment.batch_size
elif self.is_training() and self.model_name in USE_SMALL_BATCH_SIZE:
batch_size = USE_SMALL_BATCH_SIZE[self.model_name]
elif self.is_inference() and self.model_name in INFERENCE_SMALL_BATCH_SIZE:
batch_size = INFERENCE_SMALL_BATCH_SIZE[self.model_name]
elif self.is_training() and self.model_name in self.batch_size["training"]:
batch_size = self.batch_size["training"][self.model_name]
elif self.is_inference() and self.model_name in self.batch_size["inference"]:
batch_size = self.batch_size["inference"][self.model_name]

# workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
# torch.backends.__allow_nonbracketed_mutation_flag = True
Expand Down
15 changes: 15 additions & 0 deletions benchmarks/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,18 @@ def get_tpu_name():

def get_torchbench_test_name(test):
return {"train": "training", "eval": "inference"}[test]


def find_near_file(self, names):
"""Find a file near the current directory.
Looks for `names` in the current directory, up to its two direct parents.
"""
for dir in ("./", "../", "../../", "../../../"):
for name in names:
path = os.path.join(dir, name)
if exists(path):
return abspath(path)
return None


0 comments on commit 10419ed

Please sign in to comment.