Commit

change qdq custom call from stablehlo.qdq to mhlo.qdq to accommodate the change in converter
Siyuan Liu committed Mar 13, 2024
1 parent 402ebdf commit a51fc76
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/dequant_tensor.cpp

@@ -39,7 +39,7 @@ XlaOpVector DequantizeTensor::Lower(LoweringContext* loctx) const {
   xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32,
                                                       input_shape.dimensions());
   // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
-  static const std::string opname = "stablehlo.uniform_dequantize";
+  static const std::string opname = "mhlo.uniform_dequantize";
   auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
                              dtype_, xla::PrimitiveType::F32);

2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/quant_tensor.cpp

@@ -40,7 +40,7 @@ XlaOpVector QuantizeTensor::Lower(LoweringContext* loctx) const {
           GetTorchIntDtypeToHloDtype(dtype_), input_shape.dimensions());

   // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
-  static const std::string opname = "stablehlo.uniform_quantize";
+  static const std::string opname = "mhlo.uniform_quantize";
   auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
                              dtype_, input_shape.element_type());
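
For context, a minimal sketch of how a call-target string like the ones renamed above is typically turned into an XLA custom call during lowering. The torch_xla helper that actually consumes opname and qparams is outside this diff, so the function name, include path, and serialized-parameter string below are illustrative assumptions, not the real implementation.

// Minimal sketch (assumption, not the torch_xla helper elided by this diff):
// attaching a custom call target such as "mhlo.uniform_quantize" via the XLA
// builder API. The include path may differ between XLA versions.
#include <string>

#include "xla/client/xla_builder.h"

xla::XlaOp LowerQdqAsCustomCall(xla::XlaBuilder* builder, xla::XlaOp input,
                                const xla::Shape& output_shape,
                                const std::string& opname,
                                const std::string& serialized_qparams) {
  // The quantization parameters ride along in the custom call's opaque
  // string; a downstream converter pattern-matches the call-target name
  // ("mhlo.uniform_quantize" / "mhlo.uniform_dequantize") and rewrites the
  // custom call into a real quantize/dequantize op.
  return xla::CustomCall(builder, /*call_target_name=*/opname, {input},
                         output_shape, /*opaque=*/serialized_qparams);
}

Because the converter presumably keys only on the call-target name, renaming it from the stablehlo.* prefix to the mhlo.* prefix is the entire change needed on the torch_xla side.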
