[benchmarks] Run some models with smaller batch sizes. (pytorch#6542)
ysiraichi authored and amithrm committed Mar 1, 2024
1 parent 1e772e3 commit 40bc884
Showing 2 changed files with 68 additions and 50 deletions.
105 changes: 55 additions & 50 deletions benchmarks/torchbench_model.py
@@ -15,7 +15,7 @@
import torch_xla.core.xla_model as xm
import types
import yaml
from util import move_to_device, set_cwd, get_torchbench_test_name
from util import move_to_device, set_cwd, get_torchbench_test_name, find_near_file
from benchmark_model import ModelLoader, BenchmarkModel

logger = logging.getLogger(__name__)
@@ -112,6 +112,7 @@
"hf_T5_generate",
}

# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py
FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "DALLE2_pytorch",
    "doctr_det_predictor",
@@ -122,33 +123,54 @@
    "detectron2_fcos_r_50_fpn",
}

# This list was extracted from PyTorch's repository: benchmarks/dynamo/torchbench.py
FORCE_FP16_FOR_BF16_MODELS = {"vision_maskrcnn"}


@functools.lru_cache(maxsize=1)
def config_data():
"""Retrieve the skip data in the PyTorch YAML file.
Reads the YAML file in PyTorch's dynamo benchmarks directory, and transform
its lists of models into sets of models.
"""

  benchmarks_dynamo_dir = find_near_file(
      ("pytorch/benchmarks/dynamo", "benchmarks/dynamo"))
  assert benchmarks_dynamo_dir is not None, "PyTorch benchmarks folder not found."

  skip_file = os.path.join(benchmarks_dynamo_dir, "torchbench.yaml")
  with open(skip_file) as f:
    data = yaml.safe_load(f)

  def flatten(lst):
    for item in lst:
      if isinstance(item, list):
        yield from flatten(item)
      else:
        yield item

  def maybe_list_to_set(obj):
    if isinstance(obj, dict):
      return {k: maybe_list_to_set(v) for k, v in obj.items()}
    if isinstance(obj, list):
      return set(flatten(obj))
    return obj

  return maybe_list_to_set(data)
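
For context, a minimal standalone sketch (not part of the diff) of the normalization config_data() applies after yaml.safe_load: dicts are rebuilt key by key, lists (even nested ones) are flattened into sets, and scalars pass through. The keys and model names below are illustrative assumptions, not the real contents of PyTorch's torchbench.yaml.

# Standalone sketch: the same list->set normalization used inside config_data(),
# run on a made-up payload.
def _flatten(lst):
  for item in lst:
    if isinstance(item, list):
      yield from _flatten(item)
    else:
      yield item

def _normalize(obj):
  if isinstance(obj, dict):
    return {k: _normalize(v) for k, v in obj.items()}
  if isinstance(obj, list):
    return set(_flatten(obj))
  return obj

example = {
    "skip": ["moco", ["tacotron2", "DALLE2_pytorch"]],      # nested lists flatten
    "batch_size": {"training": {"llama": 16}},               # dicts stay dicts
    "dont_change_batch_size": ["demucs"],
}
print(_normalize(example))
# {'skip': {'moco', 'tacotron2', 'DALLE2_pytorch'},
#  'batch_size': {'training': {'llama': 16}},
#  'dont_change_batch_size': {'demucs'}}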


class TorchBenchModelLoader(ModelLoader):

  def __init__(self, args):
    super().__init__(args)
    self.benchmark_model_class = TorchBenchModel
    self.torchbench_dir = self.add_torchbench_dir()
    self.config = self.get_config_data()

  def _find_near_file(self, names):
    """Find a file near the current directory.
    Looks for `names` in the current directory, up to its two direct parents.
    """
    for dir in ("./", "../", "../../", "../../../"):
      for name in names:
        path = os.path.join(dir, name)
        if exists(path):
          return abspath(path)
    return None

  def add_torchbench_dir(self):
    os.environ["KALDI_ROOT"] = "/tmp"  # avoids some spam

    torchbench_dir = self._find_near_file(
    torchbench_dir = find_near_file(
        ("torchbenchmark", "torchbench", "benchmark"))
    assert torchbench_dir is not None, "Torch Benchmark folder not found."

Expand All @@ -160,37 +182,6 @@ def add_torchbench_dir(self):

return torchbench_dir

  def get_config_data(self):
    """Retrieve the skip data in the PyTorch YAML file.
    Reads the YAML file in PyTorch's dynamo benchmarks directory, and transform
    its lists of models into sets of models.
    """

    benchmarks_dynamo_dir = self._find_near_file(
        ("pytorch/benchmarks/dynamo", "benchmarks/dynamo"))
    assert benchmarks_dynamo_dir is not None, "PyTorch benchmarks folder not found."

    skip_file = os.path.join(benchmarks_dynamo_dir, "torchbench.yaml")
    with open(skip_file) as f:
      data = yaml.safe_load(f)

    def flatten(lst):
      for item in lst:
        if isinstance(item, list):
          yield from flatten(item)
        else:
          yield item

    def maybe_list_to_set(obj):
      if isinstance(obj, dict):
        return {k: maybe_list_to_set(v) for k, v in obj.items()}
      if isinstance(obj, list):
        return set(flatten(obj))
      return obj

    return maybe_list_to_set(data)

  def list_model_configs(self):
    model_configs = []

@@ -212,7 +203,7 @@ def list_model_configs(self):

  @property
  def skip(self):
    return self.config["skip"]
    return config_data()["skip"]

  def is_compatible(self, dummy_benchmark_model, benchmark_experiment):
    name = dummy_benchmark_model.model_name
@@ -308,12 +299,26 @@ def benchmark_cls(self):
        logger.warning(f"Unable to import {module_src}.")
    return None

  @property
  def batch_size(self):
    return config_data()["batch_size"]

  def load_benchmark(self):
    cant_change_batch_size = (not getattr(self.benchmark_cls(),
                                          "ALLOW_CUSTOMIZE_BSIZE", True))
    cant_change_batch_size = (
        not getattr(self.benchmark_cls(), "ALLOW_CUSTOMIZE_BSIZE", True) or
        self.model_name in config_data()["dont_change_batch_size"])

    if cant_change_batch_size:
      self.benchmark_experiment.batch_size = None

    if self.benchmark_experiment.batch_size is not None:
      batch_size = self.benchmark_experiment.batch_size
    elif self.is_training() and self.model_name in self.batch_size["training"]:
      batch_size = self.batch_size["training"][self.model_name]
    elif self.is_inference(
    ) and self.model_name in self.batch_size["inference"]:
      batch_size = self.batch_size["inference"][self.model_name]
    else:
      # No override configured: let torchbench pick its own default batch size.
      batch_size = None

    # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
    # torch.backends.__allow_nonbracketed_mutation_flag = True

@@ -324,7 +329,7 @@ def load_benchmark(self):
    return self.benchmark_cls()(
        test=self.benchmark_experiment.test,
        device=device,
        batch_size=self.benchmark_experiment.batch_size,
        batch_size=batch_size,
    )
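
For readability, here is a standalone sketch (an illustration, not code from this commit) of the batch-size precedence the updated load_benchmark() follows: an explicitly requested batch size wins, then the per-model override from torchbench.yaml for the current test, then the benchmark's own default. The helper name, parameters, and the `overrides` payload are assumptions made for the sketch.

# Sketch only: `overrides` stands in for config_data()["batch_size"],
# e.g. {"training": {"llama": 16}, "inference": {"hf_GPT2": 8}}.
def pick_batch_size(requested, test, model_name, overrides):
  if requested is not None:
    return requested  # 1. explicit batch size from the experiment
  per_test = overrides.get(test, {})
  if model_name in per_test:
    return per_test[model_name]  # 2. per-model override from the YAML config
  return None  # 3. fall back to the benchmark's default batch size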

  def update_process_env(self, process_env):
13 changes: 13 additions & 0 deletions benchmarks/util.py
@@ -156,3 +156,16 @@ def get_tpu_name():

def get_torchbench_test_name(test):
return {"train": "training", "eval": "inference"}[test]


def find_near_file(names):
  """Find a file near the current directory.

  Looks for `names` in the current directory and in up to three parent
  directories.
  """
  for dir in ("./", "../", "../../", "../../../"):
    for name in names:
      path = os.path.join(dir, name)
      if exists(path):
        return abspath(path)
  return None
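
A brief usage note: this mirrors how the torchbench loader above calls the helper; the explicit error below is illustrative, while the actual caller uses an assert.

# Illustrative usage, modeled on add_torchbench_dir() in torchbench_model.py.
torchbench_dir = find_near_file(("torchbenchmark", "torchbench", "benchmark"))
if torchbench_dir is None:
  raise RuntimeError("Torch Benchmark folder not found near the current directory.")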
