Backport #6677 to r2.3 (#6737)
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
lsy323 and Siyuan Liu authored Mar 13, 2024
1 parent a188f7e commit db30eb3
Showing 9 changed files with 74 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .bazelversion
@@ -1 +1 @@
-5.3.0
+6.5.0
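Bazelisk reads .bazelversion and fetches the pinned Bazel automatically; a quick local check (a sketch, assuming bazelisk is installed and run from the repo root):

    bazelisk version   # expected to report Bazel 6.5.0 after this bump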
9 changes: 6 additions & 3 deletions .circleci/common.sh
@@ -144,10 +144,13 @@ function run_torch_xla_python_tests() {
 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=$num_devices test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=16 --num_epochs=1 --num_steps=25 --model=resnet18

 # single-host-SPMD
-XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 16 --model=resnet50 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18
+# TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
+XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 8 --model=resnet50 --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18

-PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
-PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
+# TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
+PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
+# TODO: Reduce BS due to GPU test OOM in CI after pin update to 03/05/2024 (#6677)
+PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1 --batch_size 32 --test_set_batch_size 32
 XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
 # Syncfree SGD optimizer tests
 if [ -d ./torch_xla/amp/syncfree ]; then
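The OOM workaround above only pins smaller batch sizes; the test commands are otherwise unchanged. A minimal local repro of the reduced-batch SPMD run (a sketch; assumes a CUDA build of torch_xla and a single visible GPU):

    XLA_USE_SPMD=1 PJRT_DEVICE=CUDA torchrun --nnodes=1 --node_rank=0 --nproc_per_node=1 \
        test/spmd/test_train_spmd_imagenet.py --fake_data --batch_size 8 --model=resnet50 \
        --sharding=batch --num_epochs=1 --num_steps=25 --model=resnet18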
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -51,9 +51,9 @@ http_archive(
 "//openxla_patches:f16_abi_clang.diff",
 "//openxla_patches:quant_dequant_converter.diff",
 ],
-strip_prefix = "xla-419a3d736bdfe6891bc40e9ab5c78b466c8e0dc6",
+strip_prefix = "xla-18cbd2019898d3a7b563aeb73683f0c5a6ce14fd",
 urls = [
-"https://github.com/openxla/xla/archive/419a3d736bdfe6891bc40e9ab5c78b466c8e0dc6.tar.gz",
+"https://github.com/openxla/xla/archive/18cbd2019898d3a7b563aeb73683f0c5a6ce14fd.tar.gz",
 ],
 )

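The pin bump above changes only strip_prefix and the archive URL, which must reference the same commit hash. A hypothetical helper for future bumps (not part of this commit; assumes GNU sed and the 40-hex-digit commit layout shown above):

    NEW_SHA=18cbd2019898d3a7b563aeb73683f0c5a6ce14fd   # target OpenXLA commit
    sed -i "s/xla-[0-9a-f]\{40\}/xla-${NEW_SHA}/" WORKSPACE           # strip_prefix
    sed -i "s|archive/[0-9a-f]\{40\}|archive/${NEW_SHA}|" WORKSPACE   # download URL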
128 changes: 56 additions & 72 deletions openxla_patches/quant_dequant_converter.diff
@@ -1,31 +1,36 @@
// TODO(lsy323): This is a patch on the HLO->StableHLO converter, this allows the custom call to
// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize.
// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO.
diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD
index cd6e427c3..9deecdcb2 100644
index a75baa9dad..a09fca4898 100644
--- a/xla/translate/hlo_to_mhlo/BUILD
+++ b/xla/translate/hlo_to_mhlo/BUILD
@@ -68,6 +68,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
@@ -40,6 +40,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SparseTensorDialect",
"@tsl//tsl/platform:statusor",
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 5e5e32652..d246383bd 100644
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -669,6 +669,71 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}

+Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)
diff --git a/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/xla/translate/hlo_to_mhlo/custom_call_importer.cc
index 9250747210..9cba4c992a 100644
--- a/xla/translate/hlo_to_mhlo/custom_call_importer.cc
+++ b/xla/translate/hlo_to_mhlo/custom_call_importer.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/match.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/AsmParser/AsmParser.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
@@ -93,6 +94,66 @@ absl::StatusOr<mlir::Operation*> ImportRealDynamicSliceOp(

} // namespace

+mlir::Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ std::vector<double> scales;
+ std::vector<int64_t> zero_points;
+ int64_t quantization_dimension = -1, storage_max = 0, storage_min = 0;
+ Type storage_type, expressed_type;
+ mlir::Type storage_type, expressed_type;
+
+ auto scales_attr = backend_config.get("scale");
+ if (scales_attr) {
@@ -61,15 +66,11 @@ index 5e5e32652..d246383bd 100644
+ auto storage_type_attr = backend_config.get("storage_type");
+ if (storage_type_attr) {
+ storage_type = storage_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto expressed_type_attr = backend_config.get("expressed_type");
+ if (expressed_type_attr) {
+ expressed_type = expressed_type_attr.cast<mlir::TypeAttr>().getValue();
+ //.cast<mlir::ShapedType>()
+ //.getElementType();
+ }
+
+ auto is_signed = storage_type.cast<mlir::IntegerType>().isSignless();
@@ -85,54 +86,37 @@ index 5e5e32652..d246383bd 100644
+ }
+}
+
absl::StatusOr<mlir::Operation*> ImportCustomCallAsOp(
const HloCustomCallInstruction* instruction, mlir::Location loc,
mlir::Type result_type, mlir::ValueRange operands,
@@ -112,6 +173,30 @@ absl::StatusOr<mlir::Operation*> ImportCustomCallAsOp(
return ImportRealDynamicSliceOp(backend_config_str, loc, result_type,
operands, builder);
}
+
StatusOr<mlir::Operation*> HloFunctionImporter::ImportCustomCallAsOp(
const HloInstruction* instruction, mlir::Location loc,
const Type result_type, mlir::ValueRange operands,
@@ -992,6 +1057,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
"Couldn't parse backend config into a dictionary attribute");

attributes.push_back(builder_->getNamedAttr("backend_config", attr));
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_quantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+ auto attr = mlir::parseAttribute(backend_config_str, builder->getContext())
+ .dyn_cast<mlir::DictionaryAttr>();
+ if (!attr) {
+ return Internal(
+ "Couldn't parse backend config into a dictionary attribute");
+ }
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call_target == "mhlo.uniform_quantize") {
+ return builder
+ ->create<mlir::mhlo::UniformQuantizeOp>(
+ loc,
+ mlir::RankedTensorType::get(
+ result_type.cast<mlir::RankedTensorType>().getShape(),
+ getQuantizedType(backend_config)),
+ operands)
+ .getOperation();
+ }
+
+ if (custom_call->custom_call_target() ==
+ "stablehlo.uniform_dequantize") {
+ return func_builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(
+ loc, result_type, operands) .getOperation();
+ }
}
} else {
attributes.push_back(builder_->getNamedAttr(
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
index f9edd1272..23a747fb1 100644
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>

+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
@@ -41,6 +43,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module,
module.getContext()->loadDialect<mlir::arith::ArithDialect>();
module.getContext()->loadDialect<mlir::func::FuncDialect>();
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
+ module.getContext()->loadDialect<mlir::quant::QuantizationDialect>();
+ if (custom_call_target == "mhlo.uniform_dequantize") {
+ return builder
+ ->create<mlir::mhlo::UniformDequantizeOp>(loc, result_type, operands)
+ .getOperation();
+ }
return InvalidArgument("Unsupported MHLO op custom_call %s",
custom_call_target);
}

namespace {
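For context on how this file is consumed: Bazel applies every diff listed under patches in the WORKSPACE hunk above on top of the pinned OpenXLA archive. To try the rewritten patch by hand (a sketch; the a/ and b/ prefixes in the diff imply -p1, and curl, tar, and patch are assumed available):

    curl -L https://github.com/openxla/xla/archive/18cbd2019898d3a7b563aeb73683f0c5a6ce14fd.tar.gz | tar xz
    cd xla-18cbd2019898d3a7b563aeb73683f0c5a6ce14fd
    patch -p1 < ../openxla_patches/quant_dequant_converter.diff   # path to the patch is illustrative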
2 changes: 1 addition & 1 deletion setup.py
@@ -64,7 +64,7 @@

 base_dir = os.path.dirname(os.path.abspath(__file__))

-_date = '20240213'
+_date = '20240305'
 _libtpu_version = f'0.1.dev{_date}'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
 _jax_version = f'0.4.25.dev{_date}'
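Both nightly pins key off the same date stamp. A quick sanity check that the derived libtpu wheel URL resolves (a sketch mirroring the f-strings above):

    _date=20240305
    curl -sI "https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev${_date}-py3-none-any.whl" | head -n 1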
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/dequant_tensor.cpp
@@ -39,7 +39,7 @@ XlaOpVector DequantizeTensor::Lower(LoweringContext* loctx) const {
 xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32,
 input_shape.dimensions());
 // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
-static const std::string opname = "stablehlo.uniform_dequantize";
+static const std::string opname = "mhlo.uniform_dequantize";
 auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
 dtype_, xla::PrimitiveType::F32);
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/quant_tensor.cpp
@@ -40,7 +40,7 @@ XlaOpVector QuantizeTensor::Lower(LoweringContext* loctx) const {
 GetTorchIntDtypeToHloDtype(dtype_), input_shape.dimensions());

 // TODO(lsy323): Lower to HLO directly once qdtype is added to HLO.
-static const std::string opname = "stablehlo.uniform_quantize";
+static const std::string opname = "mhlo.uniform_quantize";
 auto qparams = QuantParams(scale_, zero_point_, quant_min_, quant_max_, axis_,
 dtype_, input_shape.element_type());
4 changes: 3 additions & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -56,7 +56,9 @@ std::unordered_map<int, int> build_index_map(
 xla::Shape host_output_shape(xla::PjRtBuffer* buffer) {
 xla::Shape shape = xla::ShapeUtil::MakeShape(
 buffer->element_type(), buffer->logical_dimensions().value());
-*shape.mutable_layout() = buffer->layout();
+*shape.mutable_layout() =
+dynamic_cast<const xla::PjRtXlaLayout*>(buffer->layout().get())
+->xla_layout();

 return xla::ShapeUtil::DeviceShapeToHostShape(shape);
 }
6 changes: 3 additions & 3 deletions torch_xla/csrc/runtime/profiler.cc
@@ -17,10 +17,10 @@ namespace profiler {
 namespace {

 const PLUGIN_Profiler_Api* FindProfilerApi(const PJRT_Api* pjrt_api) {
-const PJRT_Structure_Base* next =
-reinterpret_cast<const PJRT_Structure_Base*>(pjrt_api->extension_start);
+const PJRT_Extension_Base* next =
+reinterpret_cast<const PJRT_Extension_Base*>(pjrt_api->extension_start);
 while (next != nullptr &&
-next->type != PJRT_Structure_Type::PJRT_Structure_Type_Profiler) {
+next->type != PJRT_Extension_Type::PJRT_Extension_Type_Profiler) {
 next = next->next;
 }
 if (next == nullptr) {
