From 139d891ed9ddb4f47c03e4f482cfe9b64f063e90 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Mon, 8 Apr 2024 14:21:44 -0300 Subject: [PATCH] Add dynamo `torch.Tensor.new` test. (#6661) --- test/dynamo/test_dynamo.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 27daaca4161..f0aee024f75 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -632,6 +632,31 @@ def test_all_cpu_tensor(self): self.assertIn('MarkStep', met.counter_names()) +class DynamoOperationsTests(test_utils.XlaTestCase): + + def test_new_with_sizes(self): + + # The addition operation is needed here, since the error only occurs when FakeTensorMode + # checks the device of the arguments of some operation. If there's no operation using the + # result of Tensor.new, this comparison never occurs. + def foo(x): + return x.new(*x.size()) + x + + optfoo = torch.compile(backend="openxla")(foo) + + t = torch.arange(9) + Xt = t.to(xm.xla_device()) + + expected = foo(t) + actual = optfoo(Xt).cpu() + + # Here, we don't expect the actual data to be the same. Reason being that Tensor.new + # returns uninitialized data. + self.assertEqual(expected.shape, actual.shape) + self.assertEqual(expected.dtype, actual.dtype) + self.assertEqual(expected.device, actual.device) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1)