From 857867d279aea385c2dcf2cbd8b5090ac8e73f16 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 17 Sep 2024 12:30:28 -0300 Subject: [PATCH 01/12] Add test. --- test/test_operations.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/test_operations.py b/test/test_operations.py index 1af928e6a47..6eac3d73774 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2783,6 +2783,28 @@ def test_unsafe_buffer_pointer(self): buf_ptr_3 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_3) self.assertGreaterEqual(buf_ptr_3, 0) + def test_consistent_strides(self): + def stride_is_contiguous(tensor): + sizes_and_strides = list(sorted(zip(tensor.shape, tensor.stride()), key=lambda t: t[1])) + if sizes_and_strides[0][1] != 1: + return False + for i, (size, stride) in enumerate(sizes_and_strides[:-1]): + if stride[i + 1] != stride[i] * size[i]: + return False + return True + + def assert_consistent(tensor): + self.assertEquals(tensor.is_contiguous(), stride_is_contiguous(tensor)) + + a = torch.rand(10).to(xm.xla_device()) + assert_consistent(a) + + b = a[::2] + assert_consistent(b) + + c = b[1:] + assert_consistent(c) + class TestDLPack(parameterized.TestCase): From 4c897e771acf36f4854044b1d2375c47a17ae548 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 17 Sep 2024 12:30:39 -0300 Subject: [PATCH 02/12] Forward `is_contiguous_custom` call to `TensorImpl`. --- torch_xla/csrc/tensor_impl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 4e69127ff81..52fc5cdc81f 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -175,7 +175,7 @@ int64_t XLATensorImpl::numel_custom() const { bool XLATensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { // Storage is always contiguous, but the tensor metadata is_contiguous_ might // be false due to the update in the functionalization layer.. 
- return true;
+ return c10::TensorImpl::is_contiguous_custom(memory_format); } void XLATensorImpl::SetupSizeProperties() { From dfc32cfc28c2626107ae9f43b9989ee82e9f171b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 17 Sep 2024 14:46:35 -0300 Subject: [PATCH 03/12] Add comments. --- test/test_operations.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 6eac3d73774..6f461005eba 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2784,26 +2784,42 @@ def test_unsafe_buffer_pointer(self): self.assertGreaterEqual(buf_ptr_3, 0) def test_consistent_strides(self): + # Tests whether the `is_contiguous()` method is consistent with the tensor's stride. + # In other words, if `is_contiguous()` is true, the tensor's stride should reflect + # in a contiguous storage. + def stride_is_contiguous(tensor): + # Order the sizes and strides tuple list in ascending stride order, so that the + # first element corresponds to the smallest stride. sizes_and_strides = list(sorted(zip(tensor.shape, tensor.stride()), key=lambda t: t[1])) + + # A contiguous tensor's smallest stride should be 1. if sizes_and_strides[0][1] != 1: return False + + # Check whether the next larger stride `stride[i + 1]` is equal to the current + # one `stride[i]` multiplied by the current size `size[i]`. for i, (size, stride) in enumerate(sizes_and_strides[:-1]): if stride[i + 1] != stride[i] * size[i]: return False + return True - def assert_consistent(tensor): + def assert_strides_consistent(tensor, value): + self.assertEquals(tensor.is_contiguous(), value) self.assertEquals(tensor.is_contiguous(), stride_is_contiguous(tensor)) + # Obviously contiguous, since it was created with random. a = torch.rand(10).to(xm.xla_device()) - assert_consistent(a) + assert_strides_consistent(a, True) + # Not contiguous, since we are skipping every other element. 
b = a[::2] - assert_consistent(b) + assert_strides_consistent(b, False) + # Still not contiguous, since 'b' is not contiguous. c = b[1:] - assert_consistent(c) + assert_strides_consistent(c, False) class TestDLPack(parameterized.TestCase): From ca025430bf0a7aa49563893a56ac84c354d8db20 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 17 Sep 2024 15:05:37 -0300 Subject: [PATCH 04/12] Fix lint issues. --- test/test_operations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_operations.py b/test/test_operations.py index 6f461005eba..75cbd5cb434 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2791,7 +2791,8 @@ def test_consistent_strides(self): def stride_is_contiguous(tensor): # Order the sizes and strides tuple list in ascending stride order, so that the # first element corresponds to the smallest stride. - sizes_and_strides = list(sorted(zip(tensor.shape, tensor.stride()), key=lambda t: t[1])) + sizes_and_strides = list( + sorted(zip(tensor.shape, tensor.stride()), key=lambda t: t[1])) # A contiguous tensor's smallest stride should be 1. if sizes_and_strides[0][1] != 1: From 408ae7458c6a28eaef7fe625ea062476f2f512bd Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 17 Sep 2024 16:21:19 -0300 Subject: [PATCH 05/12] Fix test. --- test/test_operations.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 75cbd5cb434..1933a321587 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2806,21 +2806,20 @@ def stride_is_contiguous(tensor): return True - def assert_strides_consistent(tensor, value): - self.assertEquals(tensor.is_contiguous(), value) + def assert_strides_consistent(tensor): self.assertEquals(tensor.is_contiguous(), stride_is_contiguous(tensor)) # Obviously contiguous, since it was created with random. 
a = torch.rand(10).to(xm.xla_device()) - assert_strides_consistent(a, True) + assert_strides_consistent(a) # Not contiguous, since we are skipping every other element. b = a[::2] - assert_strides_consistent(b, False) + assert_strides_consistent(b) # Still not contiguous, since 'b' is not contiguous. c = b[1:] - assert_strides_consistent(c, False) + assert_strides_consistent(c) class TestDLPack(parameterized.TestCase): From 8c324e0a6ea8c8852a0fbae96ab2a925e1a8627e Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 18 Sep 2024 11:47:23 -0300 Subject: [PATCH 06/12] Run meta function on `clone`. --- torch_xla/csrc/aten_xla_type.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index d355d6c378f..585625f20ca 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1227,11 +1227,26 @@ at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self, } at::Tensor XLANativeFunctions::clone( - const at::Tensor& self, - std::optional /* memory_format */) { + const at::Tensor& self, std::optional memory_format) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( + + at::Tensor out = bridge::AtenFromXlaTensor( tensor_methods::clone(bridge::GetXlaTensor(self))); + + at::Tensor ref; + if (memory_format.has_value() && + *memory_format != at::MemoryFormat::Preserve) { + // We need to run the meta function as reference, for setting the correct + // strides to the output tensor. 
+ at::Tensor ref_self = self.to(at::kMeta); + ref = ref_self.clone(memory_format); + } else { + ref = self; + } + out.unsafeGetTensorImpl()->set_sizes_and_strides(ref.sym_sizes(), + ref.sym_strides()); + + return out; } at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self, From 5d914a4afff395c1080baf91d127c51e8011f8cf Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 18 Sep 2024 17:40:26 -0300 Subject: [PATCH 07/12] Add test for contiguity on different memory formats. --- test/test_operations.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/test/test_operations.py b/test/test_operations.py index 1933a321587..83021b38c99 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -58,7 +58,7 @@ DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices']) XLA_DISABLE_FUNCTIONALIZATION = bool( - os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False)) + int(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', '0'))) def _is_on_tpu(): @@ -2821,6 +2821,26 @@ def assert_strides_consistent(tensor): c = b[1:] assert_strides_consistent(c) + def test_contiguity_on_different_memory_format(self): + # Create contiguous strided tensor. + a = torch.rand(2, 3, 4, 5).to(xm.xla_device()) + self.assertTrue(a.is_contiguous()) + # When functionalization is disabled, we fallback to the old behavior, where + # `is_contiguous()` calls always returns True. + self.assertEquals(a.is_contiguous(memory_format=torch.channels_last), XLA_DISABLE_FUNCTIONALIZATION) + + # Make `a` contiguous in torch.channels_last memory format. + # + # This should, in theory, be a no-op, since we can't really change the strides + # of XLA tensors. However, `contiguous` is a composite operation that checks the + # tensor's metadata. Therefore, it shall clone the tensor whenever its strides + # do not conform to the given memory format. 
+ b = a.contiguous(memory_format=torch.channels_last) + # When functionalization is disabled, we fallback to the old behavior, where + # `is_contiguous()` calls always returns True. + self.assertEquals(b.is_contiguous(), XLA_DISABLE_FUNCTIONALIZATION) + self.assertTrue(b.is_contiguous(memory_format=torch.channels_last)) + class TestDLPack(parameterized.TestCase): From 780d2c6bc759709a2e4f610f51f97387c7860de7 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 18 Sep 2024 17:40:58 -0300 Subject: [PATCH 08/12] Fallback to old `is_contiguous()` behavior when functionalization is disabled. --- torch_xla/csrc/aten_xla_type.cpp | 24 +++++++++++++----------- torch_xla/csrc/tensor_impl.cpp | 7 +++++++ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 585625f20ca..2ea7e4e6a87 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1233,18 +1233,20 @@ at::Tensor XLANativeFunctions::clone( at::Tensor out = bridge::AtenFromXlaTensor( tensor_methods::clone(bridge::GetXlaTensor(self))); - at::Tensor ref; - if (memory_format.has_value() && - *memory_format != at::MemoryFormat::Preserve) { - // We need to run the meta function as reference, for setting the correct - // strides to the output tensor. - at::Tensor ref_self = self.to(at::kMeta); - ref = ref_self.clone(memory_format); - } else { - ref = self; + if (!runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) { + at::Tensor ref; + if (memory_format.has_value() && + *memory_format != at::MemoryFormat::Preserve) { + // We need to run the meta function as reference, for setting the correct + // strides to the output tensor. 
+ at::Tensor ref_self = self.to(at::kMeta); + ref = ref_self.clone(memory_format); + } else { + ref = self; + } + out.unsafeGetTensorImpl()->set_sizes_and_strides(ref.sym_sizes(), + ref.sym_strides()); } - out.unsafeGetTensorImpl()->set_sizes_and_strides(ref.sym_sizes(), - ref.sym_strides()); return out; } diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 52fc5cdc81f..fb6c4508a76 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -173,6 +173,13 @@ int64_t XLATensorImpl::numel_custom() const { } bool XLATensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { + // If functionalization is disabled, the tensors' metadata aren't being updated + // w.r.t. the output of meta functions. Therefore, we fallback to the old behavior + // returning true, always. + if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) { + return true; + } + // Storage is always contiguous, but the tensor metadata is_contiguous_ might // be false due to the update in the functionalization layer.. return c10::TensorImpl::is_contiguous_custom(memory_format); From b045e2ced97ef3317f50b12d2b81c2978a3a97ac Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 18 Sep 2024 17:56:59 -0300 Subject: [PATCH 09/12] Fix lint issues. --- test/test_operations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_operations.py b/test/test_operations.py index 83021b38c99..4a1338dc30e 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2827,7 +2827,9 @@ def test_contiguity_on_different_memory_format(self): self.assertTrue(a.is_contiguous()) # When functionalization is disabled, we fallback to the old behavior, where # `is_contiguous()` calls always returns True. 
- self.assertEquals(a.is_contiguous(memory_format=torch.channels_last), XLA_DISABLE_FUNCTIONALIZATION) + self.assertEquals( + a.is_contiguous(memory_format=torch.channels_last), + XLA_DISABLE_FUNCTIONALIZATION) # Make `a` contiguous in torch.channels_last memory format. # From 8239b9b7c0d544662a61bc2f2f4aa68eccd06a1d Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 18 Sep 2024 18:00:15 -0300 Subject: [PATCH 10/12] Fix lint issues. --- torch_xla/csrc/tensor_impl.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index fb6c4508a76..22539532663 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -173,9 +173,9 @@ int64_t XLATensorImpl::numel_custom() const { } bool XLATensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { - // If functionalization is disabled, the tensors' metadata aren't being updated - // w.r.t. the output of meta functions. Therefore, we fallback to the old behavior - // returning true, always. + // If functionalization is disabled, the tensors' metadata aren't being + // updated w.r.t. the output of meta functions. Therefore, we fallback to the + // old behavior returning true, always. if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) { return true; } From fc49779066387c30de8cd23aefc61a75343593a0 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 20 Sep 2024 17:11:17 -0300 Subject: [PATCH 11/12] Fix type hints for Python 3.8. 
--- torch_xla/experimental/scan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/scan.py b/torch_xla/experimental/scan.py index 9008e03dbd9..d6eb4d1a65b 100644 --- a/torch_xla/experimental/scan.py +++ b/torch_xla/experimental/scan.py @@ -4,7 +4,7 @@ """ -from typing import Callable, TypeVar +from typing import Callable, Tuple, TypeVar import torch from torch.utils._pytree import tree_map, tree_iter @@ -15,10 +15,10 @@ def scan( - fn: Callable[[Carry, X], tuple[Carry, Y]], + fn: Callable[[Carry, X], Tuple[Carry, Y]], init: Carry, xs: X, -) -> tuple[Carry, Y]: +) -> Tuple[Carry, Y]: """Apply a function over leading dimension of tensors while carrying along state. This is similar to the JAX `jax.lax.scan` function found in [1]. From d42b1c4e59d07ba5e77759d34ce55668347d462b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 20 Sep 2024 19:00:01 -0300 Subject: [PATCH 12/12] Fix test. --- test/spmd/test_xla_sharding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 1473fd5f995..18a64629f25 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -1068,7 +1068,7 @@ def test_backward_optimization_barrier(self): hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad]) self.assertIn( - '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)', + '%opt-barrier.38 = (f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) %tuple.37)', hlo) def test_mark_shard_scalar(self):