diff --git a/torch_xla/csrc/ops/dequant_tensor.cpp b/torch_xla/csrc/ops/dequant_tensor.cpp index bfafcb34e23..dd074ad2630 100644 --- a/torch_xla/csrc/ops/dequant_tensor.cpp +++ b/torch_xla/csrc/ops/dequant_tensor.cpp @@ -39,7 +39,7 @@ XlaOpVector DequantizeTensor::Lower(LoweringContext* loctx) const { xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, input_shape.dimensions()); // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO. - static const std::string opname = "stablehlo.uniform_dequantize"; + static const std::string opname = "mhlo.uniform_dequantize"; auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_, dtype_, xla::PrimitiveType::F32); diff --git a/torch_xla/csrc/ops/quant_tensor.cpp b/torch_xla/csrc/ops/quant_tensor.cpp index dcd1da53596..4845168be34 100644 --- a/torch_xla/csrc/ops/quant_tensor.cpp +++ b/torch_xla/csrc/ops/quant_tensor.cpp @@ -40,7 +40,7 @@ XlaOpVector QuantizeTensor::Lower(LoweringContext* loctx) const { GetTorchIntDtypeToHloDtype(dtype_), input_shape.dimensions()); // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO. - static const std::string opname = "stablehlo.uniform_quantize"; + static const std::string opname = "mhlo.uniform_quantize"; auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_, dtype_, input_shape.element_type());