Skip to content

Commit

Permalink
fix output type of quantize operation
Browse files Browse the repository at this point in the history
  • Loading branch information
sdasgup3 committed Jan 10, 2024
1 parent 257f0f5 commit 8ad8b29
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
6 changes: 4 additions & 2 deletions torch_xla/csrc/ops/dequant_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ torch::lazy::NodePtr DequantizeTensor::Clone(
XlaOpVector DequantizeTensor::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input);
xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32,
input_shape.dimensions());
// TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
static const std::string opname = "stablehlo.uniform_dequantize";
auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
dtype_, input_shape.element_type());
dtype_, xla::PrimitiveType::F32);

xla::XlaOp output = xla::CustomCall(
input.builder(), opname, {input}, input_shape,
input.builder(), opname, {input}, output_shape,
qparams.SerializeToAttrDictStr(),
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
Expand Down
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/quant_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/quant_util.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/shape_helper.h"

namespace torch_xla {
Expand Down Expand Up @@ -35,13 +36,16 @@ torch::lazy::NodePtr QuantizeTensor::Clone(torch::lazy::OpList operands) const {
XlaOpVector QuantizeTensor::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input);
xla::Shape output_shape = xla::ShapeUtil::MakeShape(
GetTorchIntDtypeToHloDtype(dtype_), input_shape.dimensions());

// TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
static const std::string opname = "stablehlo.uniform_quantize";
auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
dtype_, input_shape.element_type());

xla::XlaOp output = xla::CustomCall(
input.builder(), opname, {input}, input_shape,
input.builder(), opname, {input}, output_shape,
qparams.SerializeToAttrDictStr(),
/*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/runtime/stablehlo_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,13 @@ GetHloDtypeToStablehloDtypeMap() {
return m_;
}

xla::PrimitiveType GetTorchIntDtypeToHloDtype(const std::string& dtype) {
if (dtype == "torch.int8") return xla::PrimitiveType::S8;
if (dtype == "torch.uint8") return xla::PrimitiveType::U8;
if (dtype == "torch.int16") return xla::PrimitiveType::S16;
if (dtype == "torch.int32") return xla::PrimitiveType::S32;
if (dtype == "torch.int64") return xla::PrimitiveType::S64;
XLA_ERROR() << "Unsupported dtype for conversion to Hlo type: " << dtype;
}

} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/stablehlo_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ GetTorchDtypeToStablehloDtypeMap();
const std::unordered_map<xla::PrimitiveType, std::string>&
GetHloDtypeToStablehloDtypeMap();

xla::PrimitiveType GetTorchIntDtypeToHloDtype(const std::string& dtype);

} // namespace torch_xla

#endif

0 comments on commit 8ad8b29

Please sign in to comment.