diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py
index 2aca04a552c8..2098fa0baa21 100644
--- a/benchmarks/torchbench_model.py
+++ b/benchmarks/torchbench_model.py
@@ -6,7 +6,6 @@
 import sys
 import torch
 import torch.nn as nn
-import torch.utils._pytree as pytree
 from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
 from torch._dynamo.utils import clone_inputs
 import types
@@ -158,6 +157,7 @@ def is_compatible(self, dummy_benchmark_model, benchmark_experiment):
           break
       if matched:
         return False
+
     return True
@@ -207,23 +207,12 @@ def load_benchmark(self):
     # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
     # torch.backends.__allow_nonbracketed_mutation_flag = True
-    benchmark = benchmark_cls(
-        test=self.benchmark_experiment.test,
-        device=self.benchmark_experiment.accelerator,
-        batch_size=self.benchmark_experiment.batch_size,
-    )
-
-    self.module, self.example_inputs = benchmark.get_module()
-
-    # Move the initialized model to XLA device.
-    device = self.benchmark_experiment.get_device()
-    if self.benchmark_experiment.xla:
-      self.module = self.module.to(device)
-      self.example_inputs = pytree.tree_map_only(torch.Tensor,
-                                                 lambda t: t.to(device),
-                                                 self.example_inputs)
-
-    self.benchmark_experiment.batch_size = benchmark.batch_size
+    if self.benchmark_experiment.accelerator == "cpu":
+      device = "cpu"
+    elif self.benchmark_experiment.accelerator == "cuda" and not self.benchmark_experiment.xla:
+      device = "cuda"
+    else:
+      device = str(self.benchmark_experiment.get_device())
 
     return benchmark_cls(
         test=self.benchmark_experiment.test,
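
The substantive piece of the last hunk is the device-string selection that replaces the eager `get_module()` call and the `pytree.tree_map_only` move of `example_inputs`. Below is a minimal standalone sketch of that selection rule, assuming only the `accelerator`, `xla`, and `get_device()` members of the experiment object referenced in the diff; the helper name `select_device` is illustrative and not part of the patch.

# Illustrative sketch (not part of the patch): the device-selection rule
# introduced in load_benchmark, extracted as a standalone helper.
def select_device(benchmark_experiment) -> str:
  # Plain CPU runs keep the eager "cpu" device string.
  if benchmark_experiment.accelerator == "cpu":
    return "cpu"
  # CUDA runs without XLA keep the eager "cuda" device string.
  if benchmark_experiment.accelerator == "cuda" and not benchmark_experiment.xla:
    return "cuda"
  # Everything else (i.e. XLA runs) stringifies the XLA device, e.g. "xla:0".
  return str(benchmark_experiment.get_device())

Handing this string to `benchmark_cls(...)` presumably lets TorchBench place the model and inputs itself, which is what allows the explicit `pytree.tree_map_only` move (and the `pytree` import) to be dropped.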