Skip to content

Commit

Permalink
rebase quant patch
Browse files Browse the repository at this point in the history
  • Loading branch information
Siyuan Liu committed Mar 12, 2024
1 parent 4c80495 commit 689e2dc
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 73 deletions.
2 changes: 1 addition & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ http_archive(
"//openxla_patches:cache_urls.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
# "//openxla_patches:quant_dequant_converter.diff",
"//openxla_patches:quant_dequant_converter.diff",
],
strip_prefix = "xla-18cbd2019898d3a7b563aeb73683f0c5a6ce14fd",
urls = [
Expand Down
128 changes: 56 additions & 72 deletions openxla_patches/quant_dequant_converter.diff
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
// TODO(lsy323): This is a patch on the HLO->StableHLO converter. It allows custom calls targeting
// stablehlo.uniform_quantize/dequantize to be converted to the stablehlo.uniform_quantize/dequantize ops.
// The patch can be removed once quantize/dequantize and quantized dtype support are added to HLO.
diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD
index cd6e427c3..9deecdcb2 100644
index a75baa9dad..a09fca4898 100644
--- a/xla/translate/hlo_to_mhlo/BUILD
+++ b/xla/translate/hlo_to_mhlo/BUILD
@@ -68,6 +68,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
@@ -40,6 +40,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SparseTensorDialect",
"@tsl//tsl/platform:statusor",
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 5e5e32652..d246383bd 100644
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -669,6 +669,71 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}

+Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)
diff --git a/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/xla/translate/hlo_to_mhlo/custom_call_importer.cc
index 9250747210..9cba4c992a 100644
--- a/xla/translate/hlo_to_mhlo/custom_call_importer.cc
+++ b/xla/translate/hlo_to_mhlo/custom_call_importer.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/AsmParser/AsmParser.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
@@ -93,6 +94,66 @@ absl::StatusOr<mlir::Operation*> ImportRealDynamicSliceOp(

} // namespace

+mlir::Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ std::vector<double> scales;
+ std::vector<int64_t> zero_points;
+ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0;
+ Type storage_type, expressed_type;
+ mlir::Type storage_type, expressed_type;
+
+ auto scales_attr = backend_config.get("scale");
+ if (scales_attr) {
Expand Down Expand Up @@ -61,15 +66,11 @@ index 5e5e32652..d246383bd 100644
+ auto storage_type_attr = backend_config.get("storage_type");
+ if (storage_type_attr) {
+ storage_type = storage_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto expressed_type_attr = backend_config.get("expressed_type");
+ if (expressed_type_attr) {
+ expressed_type = expressed_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto is_signed = storage_type.cast<mlir::IntegerType>().isSignless();
Expand All @@ -85,54 +86,37 @@ index 5e5e32652..d246383bd 100644
+ }
+}
+
absl::StatusOr<mlir::Operation*> ImportCustomCallAsOp(
const HloCustomCallInstruction* instruction, mlir::Location loc,
mlir::Type result_type, mlir::ValueRange operands,
@@ -112,6 +173,30 @@ absl::StatusOr<mlir::Operation*> ImportCustomCallAsOp(
return ImportRealDynamicSliceOp(backend_config_str, loc, result_type,
operands, builder);
}
+
StatusOr<mlir::Operation*> HloFunctionImporter::ImportCustomCallAsOp(
const HloInstruction* instruction, mlir::Location loc,
const Type result_type, mlir::ValueRange operands,
@@ -992,6 +1057,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
"Couldn't parse backend config into a dictionary attribute");

attributes.push_back(builder_->getNamedAttr("backend_config", attr));
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_quantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+ auto attr = mlir::parseAttribute(backend_config_str, builder->getContext())
+ .dyn_cast<mlir::DictionaryAttr>();
+ if (!attr) {
+ return Internal(
+ "Couldn't parse backend config into a dictionary attribute");
+ }
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call_target == "mhlo.uniform_quantize") {
+ return builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<mlir::RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_dequantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(
+ loc, result_type, operands) .getOperation();
+ }
}
} else {
attributes.push_back(builder_->getNamedAttr(
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
index f9edd1272..23a747fb1 100644
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>

+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module,
module.getContext()->loadDialect<mlir::arith::ArithDialect>();
module.getContext()->loadDialect<mlir::func::FuncDialect>();
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
+ module.getContext()->loadDialect<mlir::quant::QuantizationDialect>();
+ if (custom_call_target == "mhlo.uniform_dequantize") {
+ return builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(loc, result_type, operands)
+ .getOperation();
+ }
return InvalidArgument("Unsupported MHLO op custom_call %s",
custom_call_target);
}

namespace {

0 comments on commit 689e2dc

Please sign in to comment.