From 38e364432c77e27f3a2e8f1deaa43faa91e5d812 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Fri, 15 Dec 2023 15:50:05 -0300
Subject: [PATCH] Re-land: Fix model initialization. (#6182)

---
 benchmarks/torchbench_model.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py
index fcde37f4528..2f7046f2687 100644
--- a/benchmarks/torchbench_model.py
+++ b/benchmarks/torchbench_model.py
@@ -6,6 +6,7 @@
 import sys
 import torch
 import torch.nn as nn
+import torch.utils._pytree as pytree
 from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
 from torch._dynamo.utils import clone_inputs
 import types
@@ -143,7 +144,6 @@ def is_compatible(self, dummy_benchmark_model, benchmark_experiment):
           break
       if matched:
         return False
-
     return True
@@ -193,12 +193,23 @@ def load_benchmark(self):
     # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
     # torch.backends.__allow_nonbracketed_mutation_flag = True
 
-    if self.benchmark_experiment.accelerator == "cpu":
-      device = "cpu"
-    elif self.benchmark_experiment.accelerator == "cuda" and not self.benchmark_experiment.xla:
-      device = "cuda"
-    else:
-      device = str(self.benchmark_experiment.get_device())
+    benchmark = benchmark_cls(
+        test=self.benchmark_experiment.test,
+        device=self.benchmark_experiment.accelerator,
+        batch_size=self.benchmark_experiment.batch_size,
+    )
+
+    self.module, self.example_inputs = benchmark.get_module()
+
+    # Move the initialized model to XLA device.
+    if self.benchmark_experiment.xla:
+      device = self.benchmark_experiment.get_device()
+      self.module = self.module.to(device)
+      self.example_inputs = pytree.tree_map_only(torch.Tensor,
+                                                 lambda t: t.to(device),
+                                                 self.example_inputs)
+
+    self.benchmark_experiment.batch_size = benchmark.batch_size
 
     return benchmark_cls(
         test=self.benchmark_experiment.test,