
Make as_strided_copy materialize a new tensor with index. #6624

Merged 7 commits on Mar 5, 2024
83 changes: 77 additions & 6 deletions test/test_operations.py
@@ -67,6 +67,22 @@ def _is_on_eager_debug_mode():
'skip on eager debug mode')


def _skipIfFunctionalization(value=True, reason=""):
  verb = "is not" if value else "is"
  reason = f" Reason: {reason}" if reason else ""
  return unittest.skipIf(
      XLA_DISABLE_FUNCTIONALIZATION is value,
      f'Works only when functionalization {verb} disabled.{reason}')


def skipIfFunctionalizationEnabled(reason):
  return _skipIfFunctionalization(value=False, reason=reason)


def skipIfFunctionalizationDisabled(reason):
  return _skipIfFunctionalization(value=True, reason=reason)


def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)

@@ -978,8 +994,8 @@ def func(a, b):

# TODO - upstream behavior has changed and results in expected DestroyXlaTensor
# counter as of 11/13/2023. Re-enable after reviewing the change.
-  @unittest.skipIf(True or XLA_DISABLE_FUNCTIONALIZATION,
-                   'Metrics differ when functionalization is disabled.')
+  # @skipIfFunctionalizationDisabled("metrics differ")
+  @unittest.skip("Upstream behavior changed as of 11/13/2023; re-enable after review.")
def test_set(self):
met.clear_all()

@@ -997,8 +1013,7 @@ def test_set(self):
# shouldn't crash
self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10)))

-  @unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION,
-                   'Metrics differ when functionalization is disabled.')
+  @skipIfFunctionalizationDisabled("metrics differ")
def test_replace_xla_tensor(self):
met.clear_all()

@@ -1341,8 +1356,7 @@ def test_fn(t, c):
), dtype=torch.int64)
self.runAtenTest([token_type_ids, cat_ids], test_fn)

-  @unittest.skipIf(not XLA_DISABLE_FUNCTIONALIZATION,
-                   'When functionalization is enabled, views do not exist.')
+  @skipIfFunctionalizationEnabled("views do not exist")
def test_save_view_alias_check(self):

class Nested(object):
@@ -1498,6 +1512,63 @@ def test_fn(r):

self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (8, 1))

self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap_no_unit_stride(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (8, 2))

self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (2, 1))

self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn)
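
(Editor's illustration, not part of the diff: what the "gap" and "overlap" patterns above select in eager PyTorch.)

import torch

# Gap: each row starts 8 elements apart but only reads 4, skipping 4.
torch.as_strided(torch.arange(28, dtype=torch.int32), (4, 4), (8, 1))
# tensor([[ 0,  1,  2,  3],
#         [ 8,  9, 10, 11],
#         [16, 17, 18, 19],
#         [24, 25, 26, 27]], dtype=torch.int32)

# Overlap: rows start only 2 elements apart, so neighboring rows share data.
torch.as_strided(torch.arange(10, dtype=torch.int32), (4, 4), (2, 1))
# tensor([[0, 1, 2, 3],
#         [2, 3, 4, 5],
#         [4, 5, 6, 7],
#         [6, 7, 8, 9]], dtype=torch.int32)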

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap_and_gap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (4, 2))

self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap_zero_stride(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (0, 1))

self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

  @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
  def test_as_strided_on_view_no_unit_stride(self):

    def test_fn(r):
      x = r.view(8, 4)
      return torch.as_strided(x, (4, 4), (6, 2))

    self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_empty_args(self):

def test_fn(r):
return torch.as_strided(r, tuple(), tuple())

self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)

def test_basic_bfloat16(self):

def test_fn(s):
7 changes: 7 additions & 0 deletions test/test_ops.py
@@ -29,6 +29,7 @@ def __new__(cls, name, variant_test_name=""):
{
AllowedOpInfoEntry('abs'),
AllowedOpInfoEntry('add'),
AllowedOpInfoEntry('as_strided'),
AllowedOpInfoEntry('mul'),
AllowedOpInfoEntry('sub'),
AllowedOpInfoEntry('addmm'),
@@ -349,6 +350,12 @@ def __new__(cls, name, variant_test_name=""):
# AllowedOpInfoEntry('var_mean'),
# AllowedOpInfoEntry('pow'), # for int64 don't work, likely rounding issue
# AllowedOpInfoEntry('__rpow__'),

      # In theory, this should work.
      # However, the problem is the way we prepare the reference (CPU) inputs:
      # we clone them. If the input was a view, its clone no longer is one.
      # (See the short illustration right after this block.)
      #
      # AllowedOpInfoEntry('as_strided', 'partial_views'),
}))
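
A quick eager-mode illustration of that caveat (editor's sketch, not part of the diff): cloning a view yields a tensor that is no longer a view.

import torch

base = torch.arange(10)
view = base.as_strided((2, 2), (2, 1))
assert view._base is base   # `view` aliases `base`'s storage.
clone = view.clone()        # Reference inputs are prepared like this.
assert clone._base is None  # The clone is a standalone tensor, not a view.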


100 changes: 82 additions & 18 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -697,24 +697,88 @@ at::Tensor XLANativeFunctions::as_strided_copy(
// This function actually operates on the tensor's storage. Since XLA does not
// expose the actual storage, we use the originally allocated tensor.
const at::Tensor& base = bridge::GetXlaTensor(self)->Base();
-  const at::Tensor& tensor = base.defined() ? base : self;
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
-  auto xsize = XlaHelpers::I64List(size);
-  auto xstride = XlaHelpers::I64List(stride);
-  if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
-                                    storage_offset.value_or(0))) {
-    return at::native::call_fallback_fn<
-        &xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
-                                                      storage_offset);
-  }
-  // Sets the base tensor as tensor.
-  // Even though this function copies (without aliasing) tensor, it's still
-  // treated as a view function in the functionalization layer.
-  return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
-      tensor_methods::as_strided(self_tensor, std::move(xsize),
-                                 std::move(xstride),
-                                 XlaHelpers::I64Optional(storage_offset)),
-      tensor));
+  at::Tensor tensor = base.defined() ? base : self;

auto dim = size.size();
auto itemsize = tensor.dtype().itemsize();
int64_t storage_size =
at::detail::computeStorageNbytes(size, stride, itemsize);

  XLA_CHECK(tensor.numel() * itemsize >= storage_size)
      << "as_strided: storage not big enough for size " << size << ": "
      << storage_size << " bytes (needed) vs " << tensor.numel() * itemsize
      << " bytes (actual).";
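
  // Illustrative check (editor's note, not in the original diff): for
  // size=(4, 4), stride=(8, 1), itemsize=4 (int32), the bytes needed are
  // ((4-1)*8 + (4-1)*1 + 1) * 4 = 112, exactly what the 28-element int32
  // input of test_as_strided_with_gap provides.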

  if (dim == 0 && tensor.numel() > 0) {
    // If the output has no dimensions (i.e. it is a scalar), return the first
    // element of the storage. This behavior is consistent with eager.
    return select_copy(view_copy_symint(tensor, {tensor.numel()}), 0, 0);
}
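  // Illustrative eager reference for this case (editor's note, not in the
  // original diff):
  //   torch.arange(5).as_strided((), ())  =>  tensor(0)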

if (storage_size == 0) {
    // Return an empty tensor if no storage is actually needed.
return empty_symint(c10::fromIntArrayRefSlow(size), tensor.scalar_type(),
/* layout= */ c10::nullopt, tensor.device(),
/* pin_memory= */ c10::nullopt,
/* memory_format= */ c10::nullopt);
}

// At this point, the following is true:
XLA_CHECK(storage_size > 0);
XLA_CHECK(tensor.numel() > 0);
XLA_CHECK(dim > 0);

// Index tensor for gathering the needed elements into contiguous data.
//
// PyTorch/XLA, by default, assumes dense and contiguous data. However, when
// specifying strides, that might not be the case.
//
// Therefore, we gather the elements selected by following the size, stride,
// and storage offset, materializing it into contiguous elements.
//
  // In order to accomplish that, we create an index tensor. Specifically, we
  // create an n-dimensional tensor of indices (where n is the number of
  // dimensions of the output). Each element represents the position in the
  // flattened tensor at which the desired element lives.
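  //
  // In symbols, for output coordinate (i_0, ..., i_{n-1}):
  //
  //   index[i_0, ..., i_{n-1}] = storage_offset + sum_k i_k * stride[k]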

// Example: arange(13).as_strided((2, 2, 2), (3, 4, 5))
//
// Start with a 1-element n-dimensional tensor, initialized with 0:
//
// [[[0]]]
//
std::vector<int64_t> view_shape(dim, 1);
[Review thread attached here]
Reviewer: I assume this is computed by CPU eager in the following code?
Author: Yes. Given the size, stride, and offset argument spec, we compute the correct indices for materializing the tensor ahead of time. No need for computing them at runtime.
Reviewer: Thanks!

  auto index_tensor =
      at::tensor({storage_offset.value_or(self.storage_offset())},
                 at::TensorOptions().dtype(at::kLong))
          .view(view_shape);

  // Then, for each dimension, add to index_tensor the offsets contributed by
  // each possible index along that dimension.
  //
// - Iteration i=0:
// [[[0]]] + [[[0 * 3]], [[1 * 3]]]
// = [[[0 * 3]], [[1 * 3]]]
// = [[[0]], [[3]]]
//
// - Iteration i=1:
// [[[0]], [[3]]] + [[[0 * 4], [1 * 4]]]
// = [[[0 + 0 * 4], [0 + 1 * 4]], [[3 + 0 * 4], [3 + 1 * 4]]]
// = [[[0], [4]], [[3], [7]]]
//
// - Iteration i=2:
// [[[0], [4]], [[3], [7]]] + [[[0 * 5, 1 * 5]]]
  // = [[[0 + 0 * 5, 0 + 1 * 5], [4 + 0 * 5, 4 + 1 * 5]],
  //    [[3 + 0 * 5, 3 + 1 * 5], [7 + 0 * 5, 7 + 1 * 5]]]
  // = [[[0, 5], [4, 9]], [[3, 8], [7, 12]]]
for (int i = 0; i < dim; i++) {
auto vshape = view_shape;
vshape[i] = size[i];
index_tensor =
index_tensor.add((at::arange(size[i]) * stride[i]).view(vshape));
}

// Finally, index the tensor with the computed indices.
return take(tensor, index_tensor.to(tensor.device()));
}
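
For reference, the index construction above can be reproduced in a few lines of eager PyTorch. This is an editor's sketch for illustration only (the helper name is made up); it mirrors the C++ logic and checks it against eager as_strided:

import torch

def as_strided_via_take(tensor, size, stride, storage_offset=0):
  # Start with a 1-element n-dimensional tensor holding the storage offset.
  dim = len(size)
  view_shape = [1] * dim
  index_tensor = torch.tensor([storage_offset],
                              dtype=torch.long).view(view_shape)
  # For each dimension, add the offsets contributed by every possible index
  # along that dimension (broadcasting performs the outer sum).
  for i in range(dim):
    vshape = list(view_shape)
    vshape[i] = size[i]
    index_tensor = index_tensor + (torch.arange(size[i]) * stride[i]).view(vshape)
  # Gather the selected elements into contiguous storage; `take` treats its
  # input as flattened, like the C++ `take` call above.
  return tensor.take(index_tensor)

r = torch.arange(13)
assert torch.equal(
    as_strided_via_take(r, (2, 2, 2), (3, 4, 5)),
    torch.as_strided(r, (2, 2, 2), (3, 4, 5)))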

at::Tensor XLANativeFunctions::as_strided_scatter(