Skip to content

Commit

Permalink
OpenXLA pin update (pytorch#6530)
Browse files — Browse the repository at this point in the history
  • Loading branch information
yeounoh committed Feb 14, 2024
1 parent 1d4a8da commit 27d1f70
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 120 deletions.
6 changes: 2 additions & 4 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,11 @@ http_archive(
"//openxla_patches:cache_urls.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_hanging.diff",
"//openxla_patches:quant_dequant_converter.diff",
"//openxla_patches:stablehlo_quant_seralization.diff",
],
strip_prefix = "xla-c08cfb0377e4e33a21bde65950f986a21c8a8199",
strip_prefix = "xla-b166243711f71b0a55daa1eda36b1dc745886784",
urls = [
"https://github.com/openxla/xla/archive/c08cfb0377e4e33a21bde65950f986a21c8a8199.tar.gz",
"https://github.com/openxla/xla/archive/b166243711f71b0a55daa1eda36b1dc745886784.tar.gz",
],
)

Expand Down
20 changes: 2 additions & 18 deletions openxla_patches/cache_urls.diff
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/xla/mlir_hlo/WORKSPACE b/xla/mlir_hlo/WORKSPACE
index cc9eeb64f..b290eb455 100644
index 7078b6cc7..1efd5e33b 100644
--- a/xla/mlir_hlo/WORKSPACE
+++ b/xla/mlir_hlo/WORKSPACE
@@ -35,7 +35,10 @@ http_archive(
Expand All @@ -13,20 +13,4 @@ index cc9eeb64f..b290eb455 100644
+ ],
)

load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps")
diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index a4574d75d..f9ce37094 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -13,7 +13,9 @@ def repo(name):
strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT),
urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
+ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
"https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
+ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
],
build_file = "//third_party/llvm:llvm.BUILD",
patch_file = [

load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
6 changes: 3 additions & 3 deletions openxla_patches/f16_abi_clang.diff
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ index 3f7af5197..ce4491c5d 100644
--- a/xla/service/cpu/runtime_fp16.h
+++ b/xla/service/cpu/runtime_fp16.h
@@ -18,12 +18,7 @@ limitations under the License.

#include <stdint.h>

-// _Float16 always gets us the correct ABI type, so use that if available.
-// AArch64 GCC defines __FLT16_MANT_DIG__ even when _Float16 is not available.
-#if defined(__FLT16_MANT_DIG__) && \
Expand All @@ -16,4 +16,4 @@ index 3f7af5197..ce4491c5d 100644
+#if defined(__x86_64__)
// Older versions of Clang don't have _Float16. Since both float and _Float16
// are passed in the same register we can use the wider type and careful casting
// to conform to x86_64 psABI. This only works with the assumption that we're
36 changes: 0 additions & 36 deletions openxla_patches/gpu_hanging.diff

This file was deleted.

27 changes: 14 additions & 13 deletions openxla_patches/quant_dequant_converter.diff
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
// stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize.
// The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO.
diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD
index f74973ae1..8e3f0e06b 100644
index cd6e427c3..9deecdcb2 100644
--- a/xla/translate/hlo_to_mhlo/BUILD
+++ b/xla/translate/hlo_to_mhlo/BUILD
@@ -67,6 +67,7 @@ cc_library(
@@ -68,6 +68,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:FuncDialect",
Expand All @@ -14,13 +14,13 @@ index f74973ae1..8e3f0e06b 100644
"@llvm-project//mlir:SparseTensorDialect",
"@tsl//tsl/platform:statusor",
diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
index 08d5f49c8..2f9ad1e0b 100644
index 5e5e32652..d246383bd 100644
--- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc
@@ -664,6 +664,70 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
@@ -669,6 +669,71 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
return importer.ImportInstructionWithLayout(instr, operands, builder, mode);
}

+Type getQuantizedType(mlir::DictionaryAttr& backend_config) {
+ std::vector<double> scales;
+ std::vector<int64_t> zero_points;
Expand Down Expand Up @@ -85,12 +85,13 @@ index 08d5f49c8..2f9ad1e0b 100644
+ }
+}
+
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
const HloInstruction* instruction,
const llvm::SmallVectorImpl<mlir::Value>& operands,
@@ -933,6 +997,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
+
StatusOr<mlir::Operation*> HloFunctionImporter::ImportCustomCallAsOp(
const HloInstruction* instruction, mlir::Location loc,
const Type result_type, mlir::ValueRange operands,
@@ -992,6 +1057,25 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstructionImpl(
"Couldn't parse backend config into a dictionary attribute");

attributes.push_back(builder_->getNamedAttr("backend_config", attr));
+ auto backend_config = attr.cast<mlir::DictionaryAttr>();
+ if (custom_call->custom_call_target() ==
Expand All @@ -115,13 +116,13 @@ index 08d5f49c8..2f9ad1e0b 100644
} else {
attributes.push_back(builder_->getNamedAttr(
diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
index 9f05992c8..03cf4840d 100644
index f9edd1272..23a747fb1 100644
--- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
+++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include <memory>
#include <vector>

+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
Expand All @@ -133,5 +134,5 @@ index 9f05992c8..03cf4840d 100644
module.getContext()->loadDialect<mlir::mhlo::MhloDialect>();
+ module.getContext()->loadDialect<mlir::quant::QuantizationDialect>();
}

namespace {
45 changes: 0 additions & 45 deletions openxla_patches/stablehlo_quant_seralization.diff

This file was deleted.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_libtpu_version = '0.1.dev20240124'
_libtpu_version = '0.1.dev20240213'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'


Expand Down

0 comments on commit 27d1f70

Please sign in to comment.