diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index 2098fa0baa2..10d250ead71 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -179,6 +179,15 @@ def set_up(self): self.benchmark_experiment.batch_size = benchmark.batch_size + # Move the initialized model to XLA device. + if self.benchmark_experiment.xla: + import torch.utils._pytree as pytree + device = self.benchmark_experiment.get_device() + self.module = self.module.to(device) + self.example_inputs = pytree.tree_map_only(torch.Tensor, + lambda t: t.to(device), + self.example_inputs) + # Torchbench has quite different setup for yolov3, so directly passing # the right example_inputs if self.model_name == "yolov3": @@ -207,16 +216,9 @@ def load_benchmark(self): # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" # torch.backends.__allow_nonbracketed_mutation_flag = True - if self.benchmark_experiment.accelerator == "cpu": - device = "cpu" - elif self.benchmark_experiment.accelerator == "cuda" and not self.benchmark_experiment.xla: - device = "cuda" - else: - device = str(self.benchmark_experiment.get_device()) - return benchmark_cls( test=self.benchmark_experiment.test, - device=device, + device=self.benchmark_experiment.accelerator, batch_size=self.benchmark_experiment.batch_size, )