Skip to content

Commit

Permalink
Re-land: Make as_strided_copy materialize a new tensor with index. (
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Mar 19, 2024
1 parent 6ac3223 commit 27a7dd3
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 24 deletions.
83 changes: 77 additions & 6 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ def _is_on_eager_debug_mode():
'skip on eager debug mode')


def _skipIfFunctionalization(value=True, reason=""):
verb = "is" if value else "is not"
reason = f" Reason: {reason}" if reason else ""
return unittest.skipIf(
XLA_DISABLE_FUNCTIONALIZATION is value,
f'Works only when functionalization {verb} disabled.{reason}.')


def skipIfFunctionalizationEnabled(reason):
return _skipIfFunctionalization(value=False, reason=reason)


def skipIfFunctionalizationDisabled(reason):
return _skipIfFunctionalization(value=True, reason=reason)


def _gen_tensor(*args, **kwargs):
return torch.randn(*args, **kwargs)

Expand Down Expand Up @@ -977,8 +993,8 @@ def func(a, b):

# TODO - upstream behavior has changed and results in expected DestroyXlaTensor
# counter as of 11/13/2023. Re-enable after reviewing the change.
@unittest.skipIf(True or XLA_DISABLE_FUNCTIONALIZATION,
'Metrics differ when functionalization is disabled.')
# @skipIfFunctionalizationDisabled("metrics differ")
@unittest.skip
def test_set(self):
met.clear_all()

Expand All @@ -996,8 +1012,7 @@ def test_set(self):
# shouldn't crash
self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10)))

@unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION,
'Metrics differ when functionalization is disabled.')
@skipIfFunctionalizationDisabled("metrics differ")
def test_replace_xla_tensor(self):
met.clear_all()

Expand Down Expand Up @@ -1340,8 +1355,7 @@ def test_fn(t, c):
), dtype=torch.int64)
self.runAtenTest([token_type_ids, cat_ids], test_fn)

@unittest.skipIf(not XLA_DISABLE_FUNCTIONALIZATION,
'When functionalization is enabled, views do not exist.')
@skipIfFunctionalizationEnabled("views do not exist")
def test_save_view_alias_check(self):

class Nested(object):
Expand Down Expand Up @@ -1497,6 +1511,63 @@ def test_fn(r):

self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (8, 1))

self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap_no_unit_stride(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (8, 2))

self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (2, 1))

self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap_and_gap(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (4, 2))

self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_overlap_zero_stride(self):

def test_fn(r):
return torch.as_strided(r, (4, 4), (0, 1))

self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_gap_no_unit_stride(self):

def test_fn(r):
x = r.view(8, 4)
return torch.as_strided(r, (4, 4), (6, 2))

self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)

@skipIfFunctionalizationDisabled("arbitrary as_strided unsupported")
def test_as_strided_with_empty_args(self):

def test_fn(r):
return torch.as_strided(r, tuple(), tuple())

self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn)

def test_basic_bfloat16(self):

def test_fn(s):
Expand Down
7 changes: 7 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __new__(cls, name, variant_test_name=""):
{
AllowedOpInfoEntry('abs'),
AllowedOpInfoEntry('add'),
AllowedOpInfoEntry('as_strided'),
AllowedOpInfoEntry('mul'),
AllowedOpInfoEntry('sub'),
AllowedOpInfoEntry('addmm'),
Expand Down Expand Up @@ -349,6 +350,12 @@ def __new__(cls, name, variant_test_name=""):
# AllowedOpInfoEntry('var_mean'),
# AllowedOpInfoEntry('pow'), # for int64 don't work, likely rounding issue
# AllowedOpInfoEntry('__rpow__'),

# In theory, this should work.
# However, the problem is the way we prepare the reference (CPU) inputs:
# we clone them. If they were a view, they are not anymore.
#
# AllowedOpInfoEntry('as_strided', 'partial_views'),
}))


Expand Down
118 changes: 100 additions & 18 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <ATen/native/CPUFallback.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/ops/expand_copy.h>
#include <c10/core/Contiguity.h>
#include <torch/csrc/lazy/core/shape_inference.h>
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/core/util.h>
Expand Down Expand Up @@ -706,24 +707,105 @@ at::Tensor XLANativeFunctions::as_strided_copy(
// This function actually operates on the tensor's storage. Since XLA does not
// expose the actual storage, we use the originally allocated tensor.
const at::Tensor& base = bridge::GetXlaTensor(self)->Base();
const at::Tensor& tensor = base.defined() ? base : self;
XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor);
auto xsize = XlaHelpers::I64List(size);
auto xstride = XlaHelpers::I64List(stride);
if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride,
storage_offset.value_or(0))) {
return at::native::call_fallback_fn<
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
storage_offset);
}
// Sets the base tensor as tensor.
// Even though this function copies (without aliasing) tensor, it's still
// treated as a view function in the functionalization layer.
return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
tensor_methods::as_strided(self_tensor, std::move(xsize),
std::move(xstride),
XlaHelpers::I64Optional(storage_offset)),
tensor));
at::Tensor tensor = base.defined() ? base : self;

// Fast path: PyTorch/XLA implementation for as_strided works only with
// non-overlapping and dense tensors.
if (c10::_compute_non_overlapping_and_dense(size, stride)) {
// Sets the base tensor as tensor.
// Even though this function copies (without aliasing) tensor, it's still
// treated as a view function in the functionalization layer.
return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
tensor_methods::as_strided(bridge::GetXlaTensor(tensor),
XlaHelpers::I64List(size),
XlaHelpers::I64List(stride),
XlaHelpers::I64Optional(storage_offset)),
tensor));
}

// Slow path: decompose as_strided into indexing (we use take, though)
// operations. We pre-compute the index on CPU, so as to avoid runtime
// overhead.
auto dim = size.size();
auto itemsize = tensor.dtype().itemsize();
int64_t storage_size =
at::detail::computeStorageNbytes(size, stride, itemsize);

XLA_CHECK(tensor.numel() * itemsize >= storage_size)
<< "as_strided: storage not big enough for size " << size << ": "
<< storage_size << " (needed) vs " << tensor.numel() << " (actual).";

if (dim == 0 && tensor.numel() > 0) {
// If there's no specified dimension, return the first element of the
// storage. This behavior is consistent with eager.
return select_copy(view_copy_symint(tensor, {tensor.numel()}), 0, 0);
}

if (storage_size == 0) {
// Return an empty tensor, if no storage is actually needed.
return empty_symint(c10::fromIntArrayRefSlow(size), tensor.scalar_type(),
/* layout= */ c10::nullopt, tensor.device(),
/* pin_memory= */ c10::nullopt,
/* memory_format= */ c10::nullopt);
}

// At this point, the following is true:
XLA_CHECK(storage_size > 0);
XLA_CHECK(tensor.numel() > 0);
XLA_CHECK(dim > 0);

// Index tensor for gathering the needed elements into contiguous data.
//
// PyTorch/XLA, by default, assumes dense and contiguous data. However, when
// specifying strides, that might not be the case.
//
// Therefore, we gather the elements selected by following the size, stride,
// and storage offset, materializing it into contiguous elements.
//
// In order to accomplish that, we create an index tensor. Specifically, we
// create an n-dimensional tensor (n is the number of dimensions of the
// output) of indices. Each element represent the at which position of the
// flattened tensor the desired element is in.

// Example: arange(13).as_strided((2, 2, 2), (3, 4, 5))
//
// Start with a 1-element n-dimensional tensor, initialized with 0:
//
// [[[0]]]
//
std::vector<int64_t> view_shape(dim, 1);
auto index_tensor =
at::tensor({storage_offset.value_or(self.storage_offset())},
at::TensorOptions().dtype(at::kLong))
.view(view_shape);

// Then, add to the index_tensor the offset value introduced for each possible
// index of that corresponding dimension.
//
// - Iteration i=0:
// [[[0]]] + [[[0 * 3]], [[1 * 3]]]
// = [[[0 * 3]], [[1 * 3]]]
// = [[[0]], [[3]]]
//
// - Iteration i=1:
// [[[0]], [[3]]] + [[[0 * 4], [1 * 4]]]
// = [[[0 + 0 * 4], [0 + 1 * 4]], [[3 + 0 * 4], [3 + 1 * 4]]]
// = [[[0], [4]], [[3], [7]]]
//
// - Iteration i=2:
// [[[0], [4]], [[3], [7]]] + [[[0 * 5, 1 * 5]]]
// =[[[0 + 0 * 5, 0 + 1 * 5], [4 + 0 * 5, 4 + 1 * 5]],
// [[3 + 0 * 5, 3 + 1 * 5], [7 + 0 * 5, 7 + 1 * 5]]]
// =[[[0, 5], [4, 9]], [[3, 8], [7, 12]]]
for (int i = 0; i < dim; i++) {
auto vshape = view_shape;
vshape[i] = size[i];
index_tensor =
index_tensor.add((at::arange(size[i]) * stride[i]).view(vshape));
}

// Finally, index the tensor with the computed indices.
return take(tensor, index_tensor.to(tensor.device()));
}

at::Tensor XLANativeFunctions::as_strided_scatter(
Expand Down

0 comments on commit 27a7dd3

Please sign in to comment.