From 27a7dd3530e661ef196a0b241cdd065473f571e9 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Tue, 19 Mar 2024 10:35:02 -0300 Subject: [PATCH] Re-land: Make `as_strided_copy` materialize a new tensor with `index`. (#6697) --- test/test_operations.py | 83 ++++++++++++++++++++-- test/test_ops.py | 7 ++ torch_xla/csrc/aten_xla_type.cpp | 118 ++++++++++++++++++++++++++----- 3 files changed, 184 insertions(+), 24 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index 8958f519a4b..ce0c94417e0 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -66,6 +66,22 @@ def _is_on_eager_debug_mode(): 'skip on eager debug mode') +def _skipIfFunctionalization(value=True, reason=""): + verb = "is" if value else "is not" + reason = f" Reason: {reason}" if reason else "" + return unittest.skipIf( + XLA_DISABLE_FUNCTIONALIZATION is value, + f'Works only when functionalization {verb} disabled.{reason}.') + + +def skipIfFunctionalizationEnabled(reason): + return _skipIfFunctionalization(value=False, reason=reason) + + +def skipIfFunctionalizationDisabled(reason): + return _skipIfFunctionalization(value=True, reason=reason) + + def _gen_tensor(*args, **kwargs): return torch.randn(*args, **kwargs) @@ -977,8 +993,8 @@ def func(a, b): # TODO - upstream behavior has changed and results in expected DestroyXlaTensor # counter as of 11/13/2023. Re-enable after reviewing the change. - @unittest.skipIf(True or XLA_DISABLE_FUNCTIONALIZATION, - 'Metrics differ when functionalization is disabled.') + # @skipIfFunctionalizationDisabled("metrics differ") + @unittest.skip def test_set(self): met.clear_all() @@ -996,8 +1012,7 @@ def test_set(self): # shouldn't crash self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10))) - @unittest.skipIf(XLA_DISABLE_FUNCTIONALIZATION, - 'Metrics differ when functionalization is disabled.') + @skipIfFunctionalizationDisabled("metrics differ") def test_replace_xla_tensor(self): met.clear_all() @@ -1340,8 +1355,7 @@ def test_fn(t, c): ), dtype=torch.int64) self.runAtenTest([token_type_ids, cat_ids], test_fn) - @unittest.skipIf(not XLA_DISABLE_FUNCTIONALIZATION, - 'When functionalization is enabled, views do not exist.') + @skipIfFunctionalizationEnabled("views do not exist") def test_save_view_alias_check(self): class Nested(object): @@ -1497,6 +1511,63 @@ def test_fn(r): self.runAtenTest([torch.arange(144, dtype=torch.int32)], test_fn) + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") + def test_as_strided_with_gap(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (8, 1)) + + self.runAtenTest([torch.arange(28, dtype=torch.int32)], test_fn) + + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") + def test_as_strided_with_gap_no_unit_stride(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (8, 2)) + + self.runAtenTest([torch.arange(31, dtype=torch.int32)], test_fn) + + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") + def test_as_strided_with_overlap(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (2, 1)) + + self.runAtenTest([torch.arange(10, dtype=torch.int32)], test_fn) + + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") + def test_as_strided_with_overlap_and_gap(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (4, 2)) + + self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn) + + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") + def test_as_strided_with_overlap_zero_stride(self): + + def test_fn(r): + return torch.as_strided(r, (4, 4), (0, 1)) + + self.runAtenTest([torch.arange(19, dtype=torch.int32)], test_fn) + + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") + def test_as_strided_with_gap_no_unit_stride(self): + + def test_fn(r): + x = r.view(8, 4) + return torch.as_strided(r, (4, 4), (6, 2)) + + self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn) + + @skipIfFunctionalizationDisabled("arbitrary as_strided unsupported") + def test_as_strided_with_empty_args(self): + + def test_fn(r): + return torch.as_strided(r, tuple(), tuple()) + + self.runAtenTest([torch.arange(32, dtype=torch.int32)], test_fn) + def test_basic_bfloat16(self): def test_fn(s): diff --git a/test/test_ops.py b/test/test_ops.py index a3db0a91cd1..12b874593bd 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -29,6 +29,7 @@ def __new__(cls, name, variant_test_name=""): { AllowedOpInfoEntry('abs'), AllowedOpInfoEntry('add'), + AllowedOpInfoEntry('as_strided'), AllowedOpInfoEntry('mul'), AllowedOpInfoEntry('sub'), AllowedOpInfoEntry('addmm'), @@ -349,6 +350,12 @@ def __new__(cls, name, variant_test_name=""): # AllowedOpInfoEntry('var_mean'), # AllowedOpInfoEntry('pow'), # for int64 don't work, likely rounding issue # AllowedOpInfoEntry('__rpow__'), + + # In theory, this should work. + # However, the problem is the way we prepare the reference (CPU) inputs: + # we clone them. If they were a view, they are not anymore. + # + # AllowedOpInfoEntry('as_strided', 'partial_views'), })) diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 855747ed8fe..ec9c7a54f12 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -706,24 +707,105 @@ at::Tensor XLANativeFunctions::as_strided_copy( // This function actually operates on the tensor's storage. Since XLA does not // expose the actual storage, we use the originally allocated tensor. const at::Tensor& base = bridge::GetXlaTensor(self)->Base(); - const at::Tensor& tensor = base.defined() ? base : self; - XLATensorPtr self_tensor = bridge::GetXlaTensor(tensor); - auto xsize = XlaHelpers::I64List(size); - auto xstride = XlaHelpers::I64List(stride); - if (!AsStrided::StrideIsSupported(self_tensor->shape(), xsize, xstride, - storage_offset.value_or(0))) { - return at::native::call_fallback_fn< - &xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride, - storage_offset); - } - // Sets the base tensor as tensor. - // Even though this function copies (without aliasing) tensor, it's still - // treated as a view function in the functionalization layer. - return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( - tensor_methods::as_strided(self_tensor, std::move(xsize), - std::move(xstride), - XlaHelpers::I64Optional(storage_offset)), - tensor)); + at::Tensor tensor = base.defined() ? base : self; + + // Fast path: PyTorch/XLA implementation for as_strided works only with + // non-overlapping and dense tensors. + if (c10::_compute_non_overlapping_and_dense(size, stride)) { + // Sets the base tensor as tensor. + // Even though this function copies (without aliasing) tensor, it's still + // treated as a view function in the functionalization layer. + return bridge::AtenFromXlaTensor(bridge::SetBaseTensor( + tensor_methods::as_strided(bridge::GetXlaTensor(tensor), + XlaHelpers::I64List(size), + XlaHelpers::I64List(stride), + XlaHelpers::I64Optional(storage_offset)), + tensor)); + } + + // Slow path: decompose as_strided into indexing (we use take, though) + // operations. We pre-compute the index on CPU, so as to avoid runtime + // overhead. + auto dim = size.size(); + auto itemsize = tensor.dtype().itemsize(); + int64_t storage_size = + at::detail::computeStorageNbytes(size, stride, itemsize); + + XLA_CHECK(tensor.numel() * itemsize >= storage_size) + << "as_strided: storage not big enough for size " << size << ": " + << storage_size << " (needed) vs " << tensor.numel() << " (actual)."; + + if (dim == 0 && tensor.numel() > 0) { + // If there's no specified dimension, return the first element of the + // storage. This behavior is consistent with eager. + return select_copy(view_copy_symint(tensor, {tensor.numel()}), 0, 0); + } + + if (storage_size == 0) { + // Return an empty tensor, if no storage is actually needed. + return empty_symint(c10::fromIntArrayRefSlow(size), tensor.scalar_type(), + /* layout= */ c10::nullopt, tensor.device(), + /* pin_memory= */ c10::nullopt, + /* memory_format= */ c10::nullopt); + } + + // At this point, the following is true: + XLA_CHECK(storage_size > 0); + XLA_CHECK(tensor.numel() > 0); + XLA_CHECK(dim > 0); + + // Index tensor for gathering the needed elements into contiguous data. + // + // PyTorch/XLA, by default, assumes dense and contiguous data. However, when + // specifying strides, that might not be the case. + // + // Therefore, we gather the elements selected by following the size, stride, + // and storage offset, materializing it into contiguous elements. + // + // In order to accomplish that, we create an index tensor. Specifically, we + // create an n-dimensional tensor (n is the number of dimensions of the + // output) of indices. Each element represent the at which position of the + // flattened tensor the desired element is in. + + // Example: arange(13).as_strided((2, 2, 2), (3, 4, 5)) + // + // Start with a 1-element n-dimensional tensor, initialized with 0: + // + // [[[0]]] + // + std::vector view_shape(dim, 1); + auto index_tensor = + at::tensor({storage_offset.value_or(self.storage_offset())}, + at::TensorOptions().dtype(at::kLong)) + .view(view_shape); + + // Then, add to the index_tensor the offset value introduced for each possible + // index of that corresponding dimension. + // + // - Iteration i=0: + // [[[0]]] + [[[0 * 3]], [[1 * 3]]] + // = [[[0 * 3]], [[1 * 3]]] + // = [[[0]], [[3]]] + // + // - Iteration i=1: + // [[[0]], [[3]]] + [[[0 * 4], [1 * 4]]] + // = [[[0 + 0 * 4], [0 + 1 * 4]], [[3 + 0 * 4], [3 + 1 * 4]]] + // = [[[0], [4]], [[3], [7]]] + // + // - Iteration i=2: + // [[[0], [4]], [[3], [7]]] + [[[0 * 5, 1 * 5]]] + // =[[[0 + 0 * 5, 0 + 1 * 5], [4 + 0 * 5, 4 + 1 * 5]], + // [[3 + 0 * 5, 3 + 1 * 5], [7 + 0 * 5, 7 + 1 * 5]]] + // =[[[0, 5], [4, 9]], [[3, 8], [7, 12]]] + for (int i = 0; i < dim; i++) { + auto vshape = view_shape; + vshape[i] = size[i]; + index_tensor = + index_tensor.add((at::arange(size[i]) * stride[i]).view(vshape)); + } + + // Finally, index the tensor with the computed indices. + return take(tensor, index_tensor.to(tensor.device())); } at::Tensor XLANativeFunctions::as_strided_scatter(