Skip to content

Commit

Permalink
Set base tensor of as_strided.
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi committed Nov 22, 2023
1 parent 4cfe21a commit e66a809
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,9 +708,14 @@ at::Tensor XLANativeFunctions::as_strided_copy(
&xla_cpu_fallback, ATEN_OP(as_strided)>::call(self, size, stride,
storage_offset);
}
return bridge::AtenFromXlaTensor(tensor_methods::as_strided(
self_tensor, std::move(xsize), std::move(xstride),
XlaHelpers::I64Optional(storage_offset)));
// Sets the base tensor as tensor.
// Even though this function copies (without aliasing) tensor, it's still treated
// as a view function in the functionalization layer.
return bridge::AtenFromXlaTensor(
bridge::SetBaseTensor(
tensor_methods::as_strided(
self_tensor, std::move(xsize), std::move(xstride), XlaHelpers::I64Optional(storage_offset)),
tensor));
}

at::Tensor XLANativeFunctions::as_strided_scatter(
Expand Down Expand Up @@ -2796,9 +2801,6 @@ at::Tensor XLANativeFunctions::slice_copy(const at::Tensor& self, int64_t dim,
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
int64_t start_val = start.has_value() ? start.value() : 0;
int64_t end_val = end.has_value() ? end.value() : INT64_MAX;
// Sets the base tensor as self.
// Even though this function copies (without aliasing) self, it's still treated
// as a view function in the functionalization layer.
return bridge::AtenFromXlaTensor(
bridge::SetBaseTensor(
tensor_methods::slice(
Expand Down

0 comments on commit e66a809

Please sign in to comment.