
Forward XLATensorImpl::is_contiguous_custom to TensorImpl. #8032

Open · wants to merge 12 commits into base: master
2 changes: 1 addition & 1 deletion test/spmd/test_xla_sharding.py
@@ -1068,7 +1068,7 @@ def test_backward_optimization_barrier(self):

hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
self.assertIn(
'%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
'%opt-barrier.38 = (f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{1,0}, f32[1]{0}, f32[2,64]{1,0}) %tuple.37)',
Collaborator (Author) commented:
This was needed to fix a CI failure in pytorch/pytorch#135237. @JackCaoG let me know if you think this should not be happening.

hlo)

def test_mark_shard_scalar(self):
62 changes: 61 additions & 1 deletion test/test_operations.py
@@ -58,7 +58,7 @@
DeviceSupport = collections.namedtuple('DeviceSupport', ['num_devices'])

XLA_DISABLE_FUNCTIONALIZATION = bool(
os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False))
int(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', '0')))
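For reference, a minimal sketch (editor's illustration in plain Python) of why the default is now parsed with int(): under the old code any non-empty string, including '0', is truthy, so the flag could never be switched off from the environment.

import os

os.environ['XLA_DISABLE_FUNCTIONALIZATION'] = '0'
# Old parsing: the non-empty string '0' is truthy, so the flag reads as enabled.
print(bool(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', False)))      # True
# New parsing: convert to int first, so '0' correctly disables the flag.
print(bool(int(os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', '0'))))   # False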


def _is_on_tpu():
@@ -2783,6 +2783,66 @@ def test_unsafe_buffer_pointer(self):
buf_ptr_3 = torch_xla._XLAC._unsafe_buffer_pointer(xla_tensor_3)
self.assertGreaterEqual(buf_ptr_3, 0)

def test_consistent_strides(self):
# Tests whether the `is_contiguous()` method is consistent with the tensor's strides.
# In other words, if `is_contiguous()` is true, the tensor's strides should correspond
# to a contiguous storage layout.

def stride_is_contiguous(tensor):
# Sort the (size, stride) pairs in ascending stride order, so that the
# first element corresponds to the smallest stride.
sizes_and_strides = list(
sorted(zip(tensor.shape, tensor.stride()), key=lambda t: t[1]))

# A contiguous tensor's smallest stride should be 1.
if sizes_and_strides[0][1] != 1:
return False

# Check whether each next larger stride equals the current stride
# multiplied by the current size.
for i, (size, stride) in enumerate(sizes_and_strides[:-1]):
if sizes_and_strides[i + 1][1] != stride * size:
return False

return True

def assert_strides_consistent(tensor):
self.assertEqual(tensor.is_contiguous(), stride_is_contiguous(tensor))

# Obviously contiguous, since it was created with random.
a = torch.rand(10).to(xm.xla_device())
assert_strides_consistent(a)

# Not contiguous, since we are skipping every other element.
b = a[::2]
assert_strides_consistent(b)

# Still not contiguous, since 'b' is not contiguous.
c = b[1:]
assert_strides_consistent(c)
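As an editorial aside, here is a small sketch (plain CPU tensors, no XLA required) of the invariant stride_is_contiguous checks: for a contiguous tensor, each stride equals the product of the sizes of the faster-moving dimensions, and slicing breaks that relationship.

import torch

t = torch.rand(2, 3, 4)
print(t.stride())         # (12, 4, 1): 12 == 3 * 4 and 4 == 4 * 1
print(t.is_contiguous())  # True

s = t[:, ::2]             # skipping every other index along dim 1 breaks contiguity
print(s.stride())         # (12, 8, 1)
print(s.is_contiguous())  # False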

def test_contiguity_on_different_memory_format(self):
# Create contiguous strided tensor.
a = torch.rand(2, 3, 4, 5).to(xm.xla_device())
self.assertTrue(a.is_contiguous())
# When functionalization is disabled, we fall back to the old behavior, where
# `is_contiguous()` always returns True.
self.assertEqual(
a.is_contiguous(memory_format=torch.channels_last),
XLA_DISABLE_FUNCTIONALIZATION)

# Make `a` contiguous in torch.channels_last memory format.
#
# This should, in theory, be a no-op, since we can't really change the strides
# of XLA tensors. However, `contiguous` is a composite operation that checks the
# tensor's metadata. Therefore, it will clone the tensor whenever its strides
# do not conform to the given memory format.
b = a.contiguous(memory_format=torch.channels_last)
# When functionalization is disabled, we fall back to the old behavior, where
# `is_contiguous()` always returns True.
self.assertEqual(b.is_contiguous(), XLA_DISABLE_FUNCTIONALIZATION)
self.assertTrue(b.is_contiguous(memory_format=torch.channels_last))
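For readers unfamiliar with memory formats, a short sketch (plain CPU tensors, editor's illustration) of what contiguous(memory_format=torch.channels_last) does on eager tensors: when the input's strides do not already conform, it returns a copy whose strides follow the requested NHWC layout.

import torch

a = torch.rand(2, 3, 4, 5)
print(a.stride())  # (60, 20, 5, 1): default NCHW-contiguous strides
b = a.contiguous(memory_format=torch.channels_last)
print(b.stride())  # (60, 1, 15, 3): channels-last strides (C is innermost in memory)
print(b.is_contiguous())                                   # False
print(b.is_contiguous(memory_format=torch.channels_last))  # True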


class TestDLPack(parameterized.TestCase):

23 changes: 20 additions & 3 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1227,11 +1227,28 @@ at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
}

at::Tensor XLANativeFunctions::clone(
const at::Tensor& self,
std::optional<at::MemoryFormat> /* memory_format */) {
const at::Tensor& self, std::optional<at::MemoryFormat> memory_format) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return bridge::AtenFromXlaTensor(

at::Tensor out = bridge::AtenFromXlaTensor(
tensor_methods::clone(bridge::GetXlaTensor(self)));

if (!runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
at::Tensor ref;
if (memory_format.has_value() &&
*memory_format != at::MemoryFormat::Preserve) {
// Run the meta function as a reference to set the correct strides on the
// output tensor.
at::Tensor ref_self = self.to(at::kMeta);
ref = ref_self.clone(memory_format);
} else {
ref = self;
}
out.unsafeGetTensorImpl()->set_sizes_and_strides(ref.sym_sizes(),
ref.sym_strides());
}

return out;
}
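A hedged Python analogue (editor's sketch) of the strategy used here: run the operation on a meta tensor to obtain reference sizes and strides for the requested memory format, without allocating or copying any real data.

import torch

self_tensor = torch.rand(2, 3, 4, 5)
# Clone on the meta device: computes output metadata (sizes/strides) only.
ref = self_tensor.to("meta").clone(memory_format=torch.channels_last)
print(ref.shape)     # torch.Size([2, 3, 4, 5])
print(ref.stride())  # (60, 1, 15, 3): the strides the XLA output is then given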

at::Tensor XLANativeFunctions::constant_pad_nd(const at::Tensor& self,
9 changes: 8 additions & 1 deletion torch_xla/csrc/tensor_impl.cpp
@@ -173,9 +173,16 @@ int64_t XLATensorImpl::numel_custom() const {
}

bool XLATensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
// If functionalization is disabled, the tensors' metadata isn't being
// updated w.r.t. the output of meta functions. Therefore, we fall back to the
// old behavior of always returning true.
if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
Collaborator commented:

maybe just add an IsFunctionalizationDisabled() helper (using a static bool to store the value of this env var) in https://github.com/pytorch/xla/blob/ea8c47a345a29c6a1b1bdf4ee38a9159b07a980f/torch_xla/csrc/xla_graph_executor.h or in tensor.h, instead of reading the env var every time

return true;
}

// Storage is always contiguous, but the tensor metadata is_contiguous_ might
// be false due to updates in the functionalization layer.
return true;
return c10::TensorImpl::is_contiguous_custom(memory_format);
}
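As an illustration (editor's sketch, assumes torch_xla is installed and an XLA device is available), the user-visible effect of forwarding is_contiguous_custom() to TensorImpl is that sliced XLA tensors now report non-contiguity, matching CPU behavior, as long as functionalization is enabled.

import torch
import torch_xla.core.xla_model as xm

a = torch.rand(10).to(xm.xla_device())
b = a[::2]                # strided view; not contiguous on CPU either
print(b.is_contiguous())  # False with functionalization enabled; was always True before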

void XLATensorImpl::SetupSizeProperties() {
6 changes: 3 additions & 3 deletions torch_xla/experimental/scan.py
Collaborator (Author) commented:
The changes in this file are needed because it seems that PyTorch runs these tests with Python 3.8, where subscripting built-in types (e.g. `tuple[Carry, Y]`) in annotations is not supported.

Collaborator commented:

hopefully pytorch/pytorch#135278 can be merged so we don't run into this issue anymore in the future...
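A quick sketch (editor's illustration) of the Python 3.8 limitation discussed in this thread: subscripting the built-in tuple in an annotation is evaluated at function definition time and raises, while typing.Tuple works.

from typing import Tuple

def ok(pair: Tuple[int, str]) -> Tuple[str, int]:  # fine on Python 3.8
    return pair[1], pair[0]

# On Python 3.8 (without `from __future__ import annotations`) the following
# raises "TypeError: 'type' object is not subscriptable" at definition time:
# def broken(pair: tuple[int, str]) -> tuple[str, int]: ...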

@@ -4,7 +4,7 @@

"""

from typing import Callable, TypeVar
from typing import Callable, Tuple, TypeVar

import torch
from torch.utils._pytree import tree_map, tree_iter
@@ -15,10 +15,10 @@


def scan(
fn: Callable[[Carry, X], tuple[Carry, Y]],
fn: Callable[[Carry, X], Tuple[Carry, Y]],
init: Carry,
xs: X,
) -> tuple[Carry, Y]:
) -> Tuple[Carry, Y]:
"""Apply a function over leading dimension of tensors while carrying along state.

This is similar to the JAX `jax.lax.scan` function found in [1].