fix output type of quantize operation
sdasgup3 committed Jan 10, 2024
1 parent 257f0f5 commit 9da267d
Showing 5 changed files with 35 additions and 15 deletions.
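For context: both quantize and dequantize are lowered as StableHLO custom calls, and both previously declared the call's result with input_shape, so the result inherited the input's element type (f32 for quantize, the integer storage type for dequantize). This commit builds a dedicated output_shape with the correct element type. A minimal sketch of the idea, assuming XLA's ShapeUtil API (the include path, shapes, and values here are illustrative, not part of the diff):

#include "xla/shape_util.h"  // xla::ShapeUtil, xla::PrimitiveType (path varies by XLA version)

// Illustrative: an f32[2,3] input being quantized to si8.
xla::Shape input_shape =
    xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 3});

// Before the fix, the custom call result was declared as input_shape,
// i.e. f32[2,3], which is wrong for "stablehlo.uniform_quantize".
// The fix keeps the dimensions but swaps in the storage element type:
xla::Shape output_shape = xla::ShapeUtil::MakeShape(
    xla::PrimitiveType::S8, input_shape.dimensions());  // s8[2,3]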
6 changes: 4 additions & 2 deletions torch_xla/csrc/ops/dequant_tensor.cpp
@@ -36,13 +36,15 @@ torch::lazy::NodePtr DequantizeTensor::Clone(
 XlaOpVector DequantizeTensor::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32,
+                                                      input_shape.dimensions());
   // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
   static const std::string opname = "stablehlo.uniform_dequantize";
   auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
-                             dtype_, input_shape.element_type());
+                             dtype_, xla::PrimitiveType::F32);
 
   xla::XlaOp output = xla::CustomCall(
-      input.builder(), opname, {input}, input_shape,
+      input.builder(), opname, {input}, output_shape,
       qparams.SerializeToAttrDictStr(),
       /*has_side_effect=*/false,
       /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
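Net effect in dequant_tensor.cpp: the custom call now advertises an f32 result, and QuantParams receives xla::PrimitiveType::F32 as the expressed type instead of the input's storage type. A quick sketch of the shape arithmetic, assuming xla::ShapeUtil (shapes illustrative):

// Dequantizing si8[2,3] must yield f32[2,3]: same dimensions,
// floating-point element type.
xla::Shape input_shape =
    xla::ShapeUtil::MakeShape(xla::PrimitiveType::S8, {2, 3});
xla::Shape output_shape = xla::ShapeUtil::MakeShape(
    xla::PrimitiveType::F32, input_shape.dimensions());
// xla::ShapeUtil::HumanString(output_shape) == "f32[2,3]"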
6 changes: 5 additions & 1 deletion torch_xla/csrc/ops/quant_tensor.cpp
@@ -5,6 +5,7 @@
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/quant_util.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
#include "torch_xla/csrc/shape_helper.h"

namespace torch_xla {
@@ -35,13 +36,16 @@ torch::lazy::NodePtr QuantizeTensor::Clone(torch::lazy::OpList operands) const {
 XlaOpVector QuantizeTensor::Lower(LoweringContext* loctx) const {
   xla::XlaOp input = loctx->GetOutputOp(operand(0));
   xla::Shape input_shape = ShapeHelper::ShapeOfXlaOp(input);
+  xla::Shape output_shape = xla::ShapeUtil::MakeShape(
+      GetTorchIntDtypeToHloDtype(dtype_), input_shape.dimensions());
+
   // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
   static const std::string opname = "stablehlo.uniform_quantize";
   auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
                              dtype_, input_shape.element_type());
 
   xla::XlaOp output = xla::CustomCall(
-      input.builder(), opname, {input}, input_shape,
+      input.builder(), opname, {input}, output_shape,
       qparams.SerializeToAttrDictStr(),
       /*has_side_effect=*/false,
       /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
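The quantize side mirrors the dequantize fix, except the result element type comes from the torch dtype via the new GetTorchIntDtypeToHloDtype helper (added in stablehlo_helper.cc below). A hedged usage sketch (the dtype string is illustrative):

// Quantizing f32[4] with dtype_ == "torch.int8" should declare the
// custom call result as s8[4], not f32[4].
xla::Shape input_shape =
    xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4});
xla::Shape output_shape = xla::ShapeUtil::MakeShape(
    GetTorchIntDtypeToHloDtype("torch.int8"),  // -> xla::PrimitiveType::S8
    input_shape.dimensions());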
6 changes: 5 additions & 1 deletion torch_xla/csrc/quant_util.cpp
@@ -4,6 +4,7 @@
 #include <iostream>
 #include <unordered_map>
 
+#include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
 
 namespace torch_xla {
@@ -44,7 +45,10 @@ std::string QuantParams::SerializeToAttrDictStr() const {
   }
   ss << "scale=" << SeralizeFloatVector<float>(scale, true) << ',';
   ss << "zero_point=" << SeralizeFloatVector<int>(zero_point) << ',';
-  ss << "storage_type=" << GetTorchDtypeToStablehloDtypeMap().at(dtype) << ',';
+  ss << "storage_type=" << GetTorchDtypeToStablehloDtype(dtype) << ',';
+  if (!GetHloDtypeToStablehloDtypeMap().count(expressed_type))
+    XLA_ERROR() << "Unsupported dtype for conversion from Hlo to Stablehlo: "
+                << dtype;
   ss << "expressed_type=" << GetHloDtypeToStablehloDtypeMap().at(expressed_type)
      << ',';
   ss << "storage_min=" << quant_min << ',';
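For reference, SerializeToAttrDictStr builds the backend-config string that the StableHLO converter later parses back into quantization parameters. With only the fields visible in this hunk, the output looks roughly like the line below (values and the exact vector formatting produced by SeralizeFloatVector are illustrative guesses):

scale=[0.02],zero_point=[0],storage_type=si8,expressed_type=f32,storage_min=-128,...

One small wart carried by the commit: the new error message reports dtype rather than the offending expressed_type.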
27 changes: 18 additions & 9 deletions torch_xla/csrc/runtime/stablehlo_helper.cc
@@ -168,14 +168,14 @@ void ConvertStableHloToHlo(mlir::ModuleOp* mlir_module,
              << getMlirModuleStr(*mlir_module);
 }
 
-const std::unordered_map<std::string, std::string>&
-GetTorchDtypeToStablehloDtypeMap() {
-  static const std::unordered_map<std::string, std::string> m_{
-      {"torch.int8", "si8"},
-      {"torch.uint8", "ui8"},
-      {"torch.int16", "si16"},
-  };
-  return m_;
+const std::string GetTorchDtypeToStablehloDtype(const std::string& dtype) {
+  if (dtype == "torch.int8") return "si8";
+  if (dtype == "torch.uint8") return "ui8";
+  if (dtype == "torch.int16") return "si16";
+  if (dtype == "torch.int32") return "si32";
+  if (dtype == "torch.int64") return "si64";
+  XLA_ERROR() << "Unsupported dtype for conversion to Stablehlo type: "
+              << dtype;
 }
 
 const std::unordered_map<xla::PrimitiveType, std::string>&
@@ -187,9 +187,18 @@ GetHloDtypeToStablehloDtypeMap() {
       {xla::PrimitiveType::U8, "ui8"},   {xla::PrimitiveType::U16, "ui16"},
       {xla::PrimitiveType::U32, "ui32"}, {xla::PrimitiveType::U64, "ui64"},
       {xla::PrimitiveType::F16, "f16"},  {xla::PrimitiveType::BF16, "bf16"},
-      {xla::PrimitiveType::F32, "f32"},
+      {xla::PrimitiveType::F32, "f32"},  {xla::PrimitiveType::F64, "f64"},
   };
   return m_;
 }
 
+xla::PrimitiveType GetTorchIntDtypeToHloDtype(const std::string& dtype) {
+  if (dtype == "torch.int8") return xla::PrimitiveType::S8;
+  if (dtype == "torch.uint8") return xla::PrimitiveType::U8;
+  if (dtype == "torch.int16") return xla::PrimitiveType::S16;
+  if (dtype == "torch.int32") return xla::PrimitiveType::S32;
+  if (dtype == "torch.int64") return xla::PrimitiveType::S64;
+  XLA_ERROR() << "Unsupported dtype for conversion to Hlo type: " << dtype;
+}

} // namespace torch_xla
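Both new helpers fail loudly through XLA_ERROR instead of the opaque std::out_of_range that an unordered_map::at lookup would throw, which is presumably why the map accessor was replaced with a function. A minimal usage sketch (namespace qualification assumed):

// Supported torch integer dtypes map to their StableHLO / HLO spellings;
// anything else (e.g. "torch.bool") trips XLA_ERROR.
std::string storage =
    torch_xla::GetTorchDtypeToStablehloDtype("torch.int8");   // "si8"
xla::PrimitiveType ty =
    torch_xla::GetTorchIntDtypeToHloDtype("torch.int16");     // S16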
5 changes: 3 additions & 2 deletions torch_xla/csrc/runtime/stablehlo_helper.h
@@ -22,12 +22,13 @@ void ConvertStableHloToHlo(mlir::ModuleOp* mlir_module,

 std::string GetHloModuleStr(const xla::HloModuleProto* proto);
 
-const std::unordered_map<std::string, std::string>&
-GetTorchDtypeToStablehloDtypeMap();
+const std::string GetTorchDtypeToStablehloDtype(const std::string& dtype);
 
 const std::unordered_map<xla::PrimitiveType, std::string>&
 GetHloDtypeToStablehloDtypeMap();
 
+xla::PrimitiveType GetTorchIntDtypeToHloDtype(const std::string& dtype);
+
 } // namespace torch_xla
 
 #endif
