From 10bd7d88b0111198f954e4639c6adfd0ae4c131c Mon Sep 17 00:00:00 2001 From: Emilio Cota Date: Wed, 20 Dec 2023 19:32:49 -0500 Subject: [PATCH] Revert "Re-land: Fix model initialization. (#6182)" (#6220) This reverts commit 38e364432c77e27f3a2e8f1deaa43faa91e5d812 i.e. #6182. It is causing regressions for Inductor on inference -- see https://github.com/pytorch/xla/pull/6182#discussion_r1430373949 --- benchmarks/torchbench_model.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/benchmarks/torchbench_model.py b/benchmarks/torchbench_model.py index 2aca04a552c..2098fa0baa2 100644 --- a/benchmarks/torchbench_model.py +++ b/benchmarks/torchbench_model.py @@ -6,7 +6,6 @@ import sys import torch import torch.nn as nn -import torch.utils._pytree as pytree from torch._dynamo.testing import collect_results, reduce_to_scalar_loss from torch._dynamo.utils import clone_inputs import types @@ -158,6 +157,7 @@ def is_compatible(self, dummy_benchmark_model, benchmark_experiment): break if matched: return False + return True @@ -207,23 +207,12 @@ def load_benchmark(self): # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" # torch.backends.__allow_nonbracketed_mutation_flag = True - benchmark = benchmark_cls( - test=self.benchmark_experiment.test, - device=self.benchmark_experiment.accelerator, - batch_size=self.benchmark_experiment.batch_size, - ) - - self.module, self.example_inputs = benchmark.get_module() - - # Move the initialized model to XLA device. - device = self.benchmark_experiment.get_device() - if self.benchmark_experiment.xla: - self.module = self.module.to(device) - self.example_inputs = pytree.tree_map_only(torch.Tensor, - lambda t: t.to(device), - self.example_inputs) - - self.benchmark_experiment.batch_size = benchmark.batch_size + if self.benchmark_experiment.accelerator == "cpu": + device = "cpu" + elif self.benchmark_experiment.accelerator == "cuda" and not self.benchmark_experiment.xla: + device = "cuda" + else: + device = str(self.benchmark_experiment.get_device()) return benchmark_cls( test=self.benchmark_experiment.test,