Skip to content

Commit

Permalink
Add manual dtype conversion for aten_reciprocal
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjoolee95 committed Jan 22, 2024
1 parent 5b60a40 commit 34786e8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
4 changes: 4 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,10 @@ torch_xla::XlaOpVector NeTensor::Lower(LoweringContext* loctx) const {

torch_xla::XlaOpVector Reciprocal::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
}
return ReturnOp(BuildReciprocal(xla_input), loctx);
}

Expand Down
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,11 @@ xla::Shape NeTensorOutputShape(const torch::lazy::Value& self,
}

xla::Shape ReciprocalOutputShape(const torch::lazy::Value& input) {
return GetXlaShape(input);
xla::Shape result_shape = GetXlaShape(input);
if (xla::primitive_util::IsIntegralType(result_shape.element_type())) {
result_shape.set_element_type(xla::PrimitiveType::F32);
}
return result_shape;
}

xla::Shape ReluOutputShape(const torch::lazy::Value& input) {
Expand Down

0 comments on commit 34786e8

Please sign in to comment.