diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py index 27daaca4161..34a37c710cc 100644 --- a/test/dynamo/test_dynamo.py +++ b/test/dynamo/test_dynamo.py @@ -632,6 +632,24 @@ def test_all_cpu_tensor(self): self.assertIn('MarkStep', met.counter_names()) +class DynamoOperationsTests(test_utils.XlaTestCase): + + def test_new_with_sizes(self): + + def foo(x): + return x.new(*x.sizes()) + x + + optfoo = torch.compile(backend="openxla")(foo) + + t = torch.arange(10) + Xt = t.to(xm.xla_device()) + + expected = foo(t) + actual = optfoo(Xt) + + self.assertEqual(expected, actual.cpu()) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1)