diff --git a/WORKSPACE b/WORKSPACE index dc4c1eaab47..bf39d2c4113 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -49,13 +49,11 @@ http_archive( "//openxla_patches:cache_urls.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:gpu_hanging.diff", "//openxla_patches:quant_dequant_converter.diff", - "//openxla_patches:stablehlo_quant_seralization.diff", ], - strip_prefix = "xla-c08cfb0377e4e33a21bde65950f986a21c8a8199", + strip_prefix = "xla-b166243711f71b0a55daa1eda36b1dc745886784", urls = [ - "https://github.com/openxla/xla/archive/c08cfb0377e4e33a21bde65950f986a21c8a8199.tar.gz", + "https://github.com/openxla/xla/archive/b166243711f71b0a55daa1eda36b1dc745886784.tar.gz", ], ) diff --git a/openxla_patches/cache_urls.diff b/openxla_patches/cache_urls.diff index 72cd103b92c..7197645055a 100644 --- a/openxla_patches/cache_urls.diff +++ b/openxla_patches/cache_urls.diff @@ -1,5 +1,5 @@ diff --git a/xla/mlir_hlo/WORKSPACE b/xla/mlir_hlo/WORKSPACE -index cc9eeb64f..b290eb455 100644 +index 7078b6cc7..1efd5e33b 100644 --- a/xla/mlir_hlo/WORKSPACE +++ b/xla/mlir_hlo/WORKSPACE @@ -35,7 +35,10 @@ http_archive( @@ -13,20 +13,4 @@ index cc9eeb64f..b290eb455 100644 + ], ) - load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") - load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure", "llvm_disable_optional_support_deps") -diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index a4574d75d..f9ce37094 100644 ---- a/third_party/llvm/workspace.bzl -+++ b/third_party/llvm/workspace.bzl -@@ -13,7 +13,9 @@ def repo(name): - strip_prefix = "llvm-project-{commit}".format(commit = LLVM_COMMIT), - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), -+ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT), - "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), -+ "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT), - ], - build_file = "//third_party/llvm:llvm.BUILD", - patch_file = [ - + load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") diff --git a/openxla_patches/f16_abi_clang.diff b/openxla_patches/f16_abi_clang.diff index 24cc8e5b74d..07565fde30a 100644 --- a/openxla_patches/f16_abi_clang.diff +++ b/openxla_patches/f16_abi_clang.diff @@ -4,9 +4,9 @@ index 3f7af5197..ce4491c5d 100644 --- a/xla/service/cpu/runtime_fp16.h +++ b/xla/service/cpu/runtime_fp16.h @@ -18,12 +18,7 @@ limitations under the License. - + #include - + -// _Float16 always gets us the correct ABI type, so use that if available. -// AArch64 GCC defines __FLT16_MANT_DIG__ even when _Float16 is not available. -#if defined(__FLT16_MANT_DIG__) && \ @@ -16,4 +16,4 @@ index 3f7af5197..ce4491c5d 100644 +#if defined(__x86_64__) // Older versions of Clang don't have _Float16. Since both float and _Float16 // are passed in the same register we can use the wider type and careful casting - // to conform to x86_64 psABI. This only works with the assumption that we're \ No newline at end of file + // to conform to x86_64 psABI. This only works with the assumption that we're diff --git a/openxla_patches/gpu_hanging.diff b/openxla_patches/gpu_hanging.diff deleted file mode 100644 index f64d18084b4..00000000000 --- a/openxla_patches/gpu_hanging.diff +++ /dev/null @@ -1,36 +0,0 @@ -// This patch is for https://github.com/openxla/xla/commit/ec0177de1748b4ebb0ecbd6f26043fdb1eb47d24. -// It can be removed in the next openXLA pin update after 01/26/2024. -diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc -index 0f1818be2..c181f3025 100644 ---- a/xla/service/gpu/gpu_executable.cc -+++ b/xla/service/gpu/gpu_executable.cc -@@ -382,9 +382,13 @@ absl::Status ExecuteThunks(const std::string& module_name, - } - } - -- // Maybe join a round of rendezvous after thunk initialization. -- TF_RETURN_IF_ERROR( -- MaybeRendezvousAfterInitialization(run_options, thunks_initialized)); -+ // Maybe join a round of rendezvous after thunk initialization. We do this -+ // only in presence of collective cliques which means that we have collective -+ // operations in the XLA operations that tend to cause deadlocks. -+ if (!collective_cliques.empty()) { -+ TF_RETURN_IF_ERROR( -+ MaybeRendezvousAfterInitialization(run_options, thunks_initialized)); -+ } - - // Prepare parameters for thunks execution. - Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create( -diff --git a/xla/service/gpu/thunk.h b/xla/service/gpu/thunk.h -index 51a566b8f..94bab421f 100644 ---- a/xla/service/gpu/thunk.h -+++ b/xla/service/gpu/thunk.h -@@ -175,6 +175,8 @@ class Thunk { - absl::StatusOr GetComm(const NcclCliqueKey& clique_key, - int32_t rank) const; - -+ bool empty() const { return cliques_map_.empty(); } -+ - private: - CliquesMap cliques_map_; - }; diff --git a/openxla_patches/quant_dequant_converter.diff b/openxla_patches/quant_dequant_converter.diff index d35e36a5e22..cd3d0805741 100644 --- a/openxla_patches/quant_dequant_converter.diff +++ b/openxla_patches/quant_dequant_converter.diff @@ -2,10 +2,10 @@ // stablehlo.uniform_quantize/dequantize to be converted to stablehlo.uniform_quantize/dequantize. // The patch can be removed after quantize/dequantize, quantized dtype support is added to HLO. diff --git a/xla/translate/hlo_to_mhlo/BUILD b/xla/translate/hlo_to_mhlo/BUILD -index f74973ae1..8e3f0e06b 100644 +index cd6e427c3..9deecdcb2 100644 --- a/xla/translate/hlo_to_mhlo/BUILD +++ b/xla/translate/hlo_to_mhlo/BUILD -@@ -67,6 +67,7 @@ cc_library( +@@ -68,6 +68,7 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", @@ -14,13 +14,13 @@ index f74973ae1..8e3f0e06b 100644 "@llvm-project//mlir:SparseTensorDialect", "@tsl//tsl/platform:statusor", diff --git a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc -index 08d5f49c8..2f9ad1e0b 100644 +index 5e5e32652..d246383bd 100644 --- a/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_function_importer.cc -@@ -664,6 +664,70 @@ StatusOr HloFunctionImporter::ImportInstruction( +@@ -669,6 +669,71 @@ StatusOr HloFunctionImporter::ImportInstruction( return importer.ImportInstructionWithLayout(instr, operands, builder, mode); } - + +Type getQuantizedType(mlir::DictionaryAttr& backend_config) { + std::vector scales; + std::vector zero_points; @@ -85,12 +85,13 @@ index 08d5f49c8..2f9ad1e0b 100644 + } +} + - StatusOr HloFunctionImporter::ImportInstructionImpl( - const HloInstruction* instruction, - const llvm::SmallVectorImpl& operands, -@@ -933,6 +997,25 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( ++ + StatusOr HloFunctionImporter::ImportCustomCallAsOp( + const HloInstruction* instruction, mlir::Location loc, + const Type result_type, mlir::ValueRange operands, +@@ -992,6 +1057,25 @@ StatusOr HloFunctionImporter::ImportInstructionImpl( "Couldn't parse backend config into a dictionary attribute"); - + attributes.push_back(builder_->getNamedAttr("backend_config", attr)); + auto backend_config = attr.cast(); + if (custom_call->custom_call_target() == @@ -115,13 +116,13 @@ index 08d5f49c8..2f9ad1e0b 100644 } else { attributes.push_back(builder_->getNamedAttr( diff --git a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc -index 9f05992c8..03cf4840d 100644 +index f9edd1272..23a747fb1 100644 --- a/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include - + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/IR/Attributes.h" // from @llvm-project @@ -133,5 +134,5 @@ index 9f05992c8..03cf4840d 100644 module.getContext()->loadDialect(); + module.getContext()->loadDialect(); } - + namespace { diff --git a/openxla_patches/stablehlo_quant_seralization.diff b/openxla_patches/stablehlo_quant_seralization.diff deleted file mode 100644 index fc4328dcfa7..00000000000 --- a/openxla_patches/stablehlo_quant_seralization.diff +++ /dev/null @@ -1,45 +0,0 @@ -// TODO(lsy323): This patch is needed to serialize stablehlo.uniform_quantize/dequantize in bytecode format -// This patch can be removed after https://github.com/openxla/stablehlo/issues/1812 is fixed. -diff --git a/third_party/stablehlo/stablehlo_quant_seralization.patch b/third_party/stablehlo/stablehlo_quant_seralization.patch -new file mode 100644 -index 000000000..24e23b67d ---- /dev/null -+++ b/third_party/stablehlo/stablehlo_quant_seralization.patch -@@ -0,0 +1,26 @@ -+diff --git a/stablehlo/api/PortableApi.cpp b/stablehlo/api/PortableApi.cpp -+index 07c856db..cd169cae 100644 -+--- a/stablehlo/api/PortableApi.cpp -++++ b/stablehlo/api/PortableApi.cpp -+@@ -15,10 +15,13 @@ limitations under the License. -+ -+ #include "stablehlo/api/PortableApi.h" -+ -++#include -+ #include -+ -+ #include "mlir/Bytecode/BytecodeWriter.h" -+ #include "mlir/Dialect/Func/IR/FuncOps.h" -++#include "mlir/Dialect/Quant/QuantOps.h" -++#include "mlir/Dialect/Quant/QuantTypes.h" -+ #include "mlir/IR/MLIRContext.h" -+ #include "mlir/Parser/Parser.h" -+ #include "stablehlo/dialect/Serialization.h" -+@@ -33,6 +36,7 @@ void loadSerializationDialects(MLIRContext* context) { -+ context->loadDialect(); -+ context->loadDialect(); -+ context->loadDialect(); -++ context->loadDialect(); -+ } -+ } // namespace -+ -diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl -index 9f4494aac..64fa072bb 100644 ---- a/third_party/stablehlo/workspace.bzl -+++ b/third_party/stablehlo/workspace.bzl -@@ -15,5 +15,6 @@ def repo(): - urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)), - patch_file = [ - "//third_party/stablehlo:temporary.patch", # Autogenerated, don't remove. -+ "//third_party/stablehlo:stablehlo_quant_seralization.patch", # Load quant dialect. - ], - ) diff --git a/setup.py b/setup.py index 94e1d048914..3b67330a1b2 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_libtpu_version = '0.1.dev20240124' +_libtpu_version = '0.1.dev20240213' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'