diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index c791cd8ca417..999a251f2c91 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -4,7 +4,6 @@ import importlib import logging import os -from os.path import abspath, exists import sys import torch import torch.amp @@ -306,18 +305,19 @@ def batch_size(self): def load_benchmark(self): cant_change_batch_size = ( not getattr(self.benchmark_cls(), "ALLOW_CUSTOMIZE_BSIZE", True) or - model_name in config_data()["dont_change_batch_size"]) + self.model_name in config_data()["dont_change_batch_size"]) if cant_change_batch_size: self.benchmark_experiment.batch_size = None - if self.benchmark_experiment.batch_size is not None: - batch_size = self.benchmark_experiment.batch_size - elif self.is_training() and self.model_name in self.batch_size["training"]: - batch_size = self.batch_size["training"][self.model_name] - elif self.is_inference( - ) and self.model_name in self.batch_size["inference"]: - batch_size = self.batch_size["inference"][self.model_name] + batch_size = self.benchmark_experiment.batch_size + + if batch_size is None: + if self.is_training() and self.model_name in self.batch_size["training"]: + batch_size = self.batch_size["training"][self.model_name] + elif self.is_inference( + ) and self.model_name in self.batch_size["inference"]: + batch_size = self.batch_size["inference"][self.model_name] # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" # torch.backends.__allow_nonbracketed_mutation_flag = True diff --git a/benchmarks/util.py b/benchmarks/util.py index 21f3736d91a2..ce56ceb4143e 100644 --- a/benchmarks/util.py +++ b/benchmarks/util.py @@ -3,7 +3,7 @@ import logging import numpy as np import os -from os.path import abspath +from os.path import abspath, exists import random import subprocess import torch @@ -158,7 +158,7 @@ def get_torchbench_test_name(test): return {"train": "training", "eval": "inference"}[test] -def find_near_file(self, names): +def find_near_file(names): """Find a file near the current directory. Looks for `names` in the current directory, up to its two direct parents.