diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 2352257e9a0..8e80b8d0316 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -708,9 +708,14 @@ at::Tensor XLANativeFunctions::as_strided_copy(
         &xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
                                                       storage_offset);
   }
-  return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
-      self_tensor, std::move(xsize), std::move(xstride),
-      XlaHelpers::I64Optional(storage_offset)));
+  // Sets the base tensor as self.
+  // Even though this function copies (without aliasing) self, it's still treated
+  // as a view function in the functionalization layer.
+  return bridge::AtenFromXlaTensor(bridge::SetBaseTensor(
+      tensor_methods::as_strided(self_tensor, std::move(xsize),
+                                 std::move(xstride),
+                                 XlaHelpers::I64Optional(storage_offset)),
+      self));
 }
 
 at::Tensor XLANativeFunctions::as_strided_scatter(
@@ -2796,9 +2801,6 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim,
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   int64_t start_val = start.has_value() ? start.value() : 0;
   int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
-  // Sets the base tensor as self.
-  // Even though this function copies (without aliasing) self, it's still treated
-  // as a view function in the functionalization layer.
   return bridge::AtenFromXlaTensor(
       bridge::SetBaseTensor(
           tensor_methods::slice(