From 83f3fd6e86cb676607c5adc3e94ec2678ea65068 Mon Sep 17 00:00:00 2001 From: Clemens Giuliani Date: Fri, 29 Mar 2024 11:49:47 -0700 Subject: [PATCH] PR #7849: [XLA:CPU] Add support for cross-process collectives using mpi. Imported from GitHub PR https://github.com/openxla/xla/pull/7849 Mpi collectives as proposed in https://github.com/google/jax/issues/11182?notification_referrer_id=NT_kwDOAG8zGbIzODQ5MDcxMzM0OjcyODc1Nzc#issuecomment-1851591135. I only implemented the inter-process communication and this does not yet support more than 1 thread per process. Adding support for multiple threads/devices per process in the future seems quite a bit more involved if one wanted to do it properly. For MPI I am building and linking against https://github.com/eschnett/MPItrampoline, which dlopens the (wrapped) mpi library at runtime. To wrap and load the desired mpi library one needs to compile https://github.com/eschnett/MPIwrapper and set `MPITRAMPOLINE_LIB=/path/to/libmpiwrapper.so`. @hawkinsp Copybara import of the project: -- b74bbb909d902bd30523f943a7c15f2c754cf98a by Clemens Giuliani : add mpi collectives -- 23508eb46848464f6711dd8f3f91830ea1adb16d by Clemens Giuliani : add explicit Init and Finalize methods and export them to python -- bbe5840b8eb56a306a66ed03d701fd8976e01491 by Clemens Giuliani : add comment -- 38d156282ecc89509f4b21d80db1a37cb290437a by Clemens Giuliani : fix windows build -- 201f7238f166197ede5cf5d4d70e117a91eddcd7 by Clemens Giuliani : fmt -- 2784869df650c1c123c346401db2f67cb153b03e by Clemens Giuliani : bump xla_extension_version Merging this change closes #7849 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/7849 from inailuig:mpi_collectives 2784869df650c1c123c346401db2f67cb153b03e PiperOrigin-RevId: 620302264 --- third_party/mpitrampoline/BUILD | 1 + third_party/mpitrampoline/gen.patch | 149 +++++++++ third_party/mpitrampoline/mpitrampoline.BUILD | 135 +++++++++ third_party/mpitrampoline/workspace.bzl | 18 ++ workspace2.bzl | 2 + 
xla/pjrt/cpu/BUILD | 32 ++ xla/pjrt/cpu/mpi_collectives.cc | 283 ++++++++++++++++++ xla/pjrt/cpu/mpi_collectives.h | 102 +++++++ xla/python/BUILD | 6 + xla/python/xla.cc | 22 ++ xla/python/xla_client.py | 2 +- 11 files changed, 751 insertions(+), 1 deletion(-) create mode 100644 third_party/mpitrampoline/BUILD create mode 100644 third_party/mpitrampoline/gen.patch create mode 100644 third_party/mpitrampoline/mpitrampoline.BUILD create mode 100644 third_party/mpitrampoline/workspace.bzl create mode 100644 xla/pjrt/cpu/mpi_collectives.cc create mode 100644 xla/pjrt/cpu/mpi_collectives.h diff --git a/third_party/mpitrampoline/BUILD b/third_party/mpitrampoline/BUILD new file mode 100644 index 0000000000000..3c413807167ae --- /dev/null +++ b/third_party/mpitrampoline/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/mpitrampoline/gen.patch b/third_party/mpitrampoline/gen.patch new file mode 100644 index 0000000000000..35124db0abb1e --- /dev/null +++ b/third_party/mpitrampoline/gen.patch @@ -0,0 +1,149 @@ +diff --git a/gen/gen_decl.py b/gen/gen_decl.py +index 1005b95..696b4e0 100755 +--- a/gen/gen_decl.py ++++ b/gen/gen_decl.py +@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False +@@ -24,7 +24,7 @@ def wrap(line): + lines.append(line) + return "\n".join(lines) + +-with open("include/mpi_decl_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Declare C MPI constants\n") + file.write("\n") + for (tp, nm) in constants: +@@ -32,7 +32,7 @@ with 
open("include/mpi_decl_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs)) + +-with open("include/mpi_decl_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Declare C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("include/mpi_decl_constants_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file: + file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl))) + file.write("\n") + +-with open("include/mpi_decl_functions_fortran.h", "w") as file: ++if False: + file.write("! Declare Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_defn.py b/gen/gen_defn.py +index bf31f35..318222e 100755 +--- a/gen/gen_defn.py ++++ b/gen/gen_defn.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_defn_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Define C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file: + 'mpi_nm': nm} + file.write(Template("$mpi_tp 
$mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Define C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file: + file.write(Template("\n".join(tmpl)).substitute(subs)) + file.write("\n") + +-with open("src/mpi_defn_constants_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file: + # Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes + file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs)) + +-with open("src/mpi_defn_functions_fortran.h", "w") as file: ++if False: + file.write("// Define Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: +diff --git a/gen/gen_init.py b/gen/gen_init.py +index 4939261..0e52822 100755 +--- a/gen/gen_init.py ++++ b/gen/gen_init.py +@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi")) + + from mpi_constants import constants + from mpi_functions import functions +-from mpi_constants_fortran import constants_fortran +-from mpi_functions_fortran import functions_fortran ++# from mpi_constants_fortran import constants_fortran ++# from mpi_functions_fortran import functions_fortran + + support_profiling = True + have_weak_symbols = False + replace_sentinels = False + +-with open("src/mpi_init_constants_c.h", "w") as file: ++with open(sys.argv[1], "w") as file: + file.write("// Initialize C MPI constants") + file.write("\n") + for (tp, nm) in constants: +@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)} + file.write(Template("$mpi_nm = 
*($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_c.h", "w") as file: ++with open(sys.argv[2], "w") as file: + file.write("// Initialize C MPI functions\n") + file.write("\n") + for (tp, nm, args, flags) in functions: +@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file: + subs['anm{0}'.format(i)] = anm + file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_constants_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI constants\n") + file.write("\n") + for (tp, nm) in constants_fortran: +@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file: + 'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"} + file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs)) + +-with open("src/mpi_init_functions_fortran.h", "w") as file: ++if False: + file.write("// Initialize Fortran MPI functions\n") + file.write("\n") + for (tp, nm, args) in functions_fortran: diff --git a/third_party/mpitrampoline/mpitrampoline.BUILD b/third_party/mpitrampoline/mpitrampoline.BUILD new file mode 100644 index 0000000000000..20c5514b164e7 --- /dev/null +++ b/third_party/mpitrampoline/mpitrampoline.BUILD @@ -0,0 +1,135 @@ +# Description: +# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI + +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") +load("@xla//xla:strict.default.bzl", "py_strict_binary") + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE.md"]) + +genrule( + name = "mpi_version", + srcs = [ + "CMakeLists.txt", + "include/mpi_version.h.in", + ], + outs = ["include/mpi_version.h"], + cmd = """ + PROJECT_VERSION=`cat $(location CMakeLists.txt) \ + | grep "MPItrampoline VERSION" | awk '{print $$NF}'` + PROJECT_VERSION_MAJOR=`echo 
$$PROJECT_VERSION | cut -d. -f1` + PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2` + PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3` + sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \ + -e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \ + -e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \ + -e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \ + $(location include/mpi_version.h.in) > $(location include/mpi_version.h) + """, +) + +expand_template( + name = "mpi_defaults", + out = "src/mpi_defaults.h", + substitutions = { + "@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "", + "@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "", + "@MPITRAMPOLINE_DEFAULT_LIB@": "", + "@MPITRAMPOLINE_DEFAULT_PRELOAD@": "", + "@MPITRAMPOLINE_DEFAULT_VERBOSE@": "", + }, + template = "src/mpi_defaults.h.in", +) + +py_strict_binary( + name = "gen_decl", + srcs = [ + "gen/gen_decl.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "decl", + outs = [ + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + ], + cmd = "$(location :gen_decl) $(location include/mpi_decl_constants_c.h) \ + $(location include/mpi_decl_functions_c.h)", + tools = [":gen_decl"], +) + +py_strict_binary( + name = "gen_defn", + srcs = [ + "gen/gen_defn.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "defn", + outs = [ + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + ], + cmd = "$(location :gen_defn) $(location include/mpi_defn_constants_c.h) \ + $(location include/mpi_defn_functions_c.h)", + tools = [":gen_defn"], +) + +py_strict_binary( + name = "gen_init", + srcs = [ + "gen/gen_init.py", + "mpiabi/mpi_constants.py", + "mpiabi/mpi_functions.py", + ], +) + +genrule( + name = "init", + outs = [ + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + ], + cmd = "$(location :gen_init) $(location 
include/mpi_init_constants_c.h) \ + $(location include/mpi_init_functions_c.h)", + tools = [":gen_init"], +) + +cc_library( + name = "mpitrampoline", + srcs = [ + "src/mpi.c", + ], + hdrs = [ + "include/mpi.h", + "include/mpi_decl_constants_c.h", + "include/mpi_decl_functions_c.h", + "include/mpi_defn_constants_c.h", + "include/mpi_defn_functions_c.h", + "include/mpi_init_constants_c.h", + "include/mpi_init_functions_c.h", + "include/mpi_version.h", + "mpiabi/mpiabi.h", + "src/mpi_defaults.h", + ], + copts = [ + "-fexceptions", + ], + includes = [ + "include", + "mpiabi", + "src", + ], +) diff --git a/third_party/mpitrampoline/workspace.bzl b/third_party/mpitrampoline/workspace.bzl new file mode 100644 index 0000000000000..4748931ae6e36 --- /dev/null +++ b/third_party/mpitrampoline/workspace.bzl @@ -0,0 +1,18 @@ +"""Provides the repository macro to import mpitrampoline.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports mpitrampoline.""" + + MPITRAMPOLINE_COMMIT = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed" + MPITRAMPOLINE_SHA256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84" + + tf_http_archive( + name = "mpitrampoline", + sha256 = MPITRAMPOLINE_SHA256, + strip_prefix = "MPItrampoline-{commit}".format(commit = MPITRAMPOLINE_COMMIT), + urls = tf_mirror_urls("https://github.com/eschnett/mpitrampoline/archive/{commit}.tar.gz".format(commit = MPITRAMPOLINE_COMMIT)), + patch_file = ["//third_party/mpitrampoline:gen.patch"], + build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD", + ) diff --git a/workspace2.bzl b/workspace2.bzl index 53f9cec96eac1..f5e7ac03ae409 100644 --- a/workspace2.bzl +++ b/workspace2.bzl @@ -10,6 +10,7 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # Import third party repository rules. See go/tfbr-thirdparty. 
load("//third_party/dlpack:workspace.bzl", dlpack = "repo") load("//third_party/gloo:workspace.bzl", gloo = "repo") +load("//third_party/mpitrampoline:workspace.bzl", mpitrampoline = "repo") load("//third_party/nanobind:workspace.bzl", nanobind = "repo") load("//third_party/robin_map:workspace.bzl", robin_map = "repo") load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") @@ -19,6 +20,7 @@ def _initialize_third_party(): """ Load third party repositories. See above load() statements. """ dlpack() gloo() + mpitrampoline() nanobind() robin_map() stablehlo() diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index c5d3f2e367951..024727c474109 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -1,3 +1,4 @@ +load("@tsl//tsl:tsl.bzl", "if_oss") load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") @@ -286,3 +287,34 @@ cc_library( "@tsl//tsl/platform:logging", ], ) + +cc_library( + name = "mpi_collectives", + srcs = if_oss(["mpi_collectives.cc"]), + hdrs = if_oss(["mpi_collectives.h"]), + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = if_oss([ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@mpitrampoline", + ]), +) diff --git a/xla/pjrt/cpu/mpi_collectives.cc b/xla/pjrt/cpu/mpi_collectives.cc new file mode 100644 index 
0000000000000..d2c93fd75450f --- /dev/null +++ b/xla/pjrt/cpu/mpi_collectives.cc @@ -0,0 +1,283 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/pjrt/cpu/mpi_collectives.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mpi.h" // NOLINT +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/primitive_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/status_macros.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" + +namespace xla::cpu { + +absl::StatusOr PrimitiveTypeToMpiType( + PrimitiveType element_type) { + switch (element_type) { + case S8: + return MPI_INT8_T; + case U8: + case PRED: + return MPI_UINT8_T; + case S16: + return MPI_INT16_T; + case U16: + return MPI_UINT16_T; + case S32: + return MPI_INT32_T; + case U32: + return MPI_UINT32_T; + case S64: + return MPI_INT64_T; + case U64: + return MPI_UINT64_T; + case F32: + return MPI_FLOAT; + case F64: + return MPI_DOUBLE; + 
case C64: + return MPI_C_COMPLEX; + case C128: + return MPI_C_DOUBLE_COMPLEX; + default: + // For implementing the reduction of unsupported types + // see e.g. https://stackoverflow.com/a/29643391 + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported primitive type for reduction: ", + primitive_util::LowercasePrimitiveTypeName(element_type))); + } +} + +bool MpiTypeIsComplex(MPI_Datatype type) { + return type == MPI_C_COMPLEX || type == MPI_C_DOUBLE_COMPLEX; +} + +absl::StatusOr ReductionKindToMpiOp(ReductionKind reduction_kind, + MPI_Datatype type) { + switch (reduction_kind) { + case ReductionKind::SUM: + return MPI_SUM; + case ReductionKind::PRODUCT: + return MPI_PROD; + case ReductionKind::MIN: + if (!MpiTypeIsComplex(type)) { + return MPI_MIN; + } else { + return absl::InvalidArgumentError( + "MIN reduction not supported for complex types"); + } + case ReductionKind::MAX: + if (!MpiTypeIsComplex(type)) { + return MPI_MAX; + } else { + return absl::InvalidArgumentError( + "MAX reduction not supported for complex types"); + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Unknown reduction kind: ", reduction_kind)); + } +} + +static absl::Status MpiErrorToAbslStatus(int error) { + if (error != MPI_SUCCESS) { + char error_str[MPI_MAX_ERROR_STRING]; + int len; + MPI_Error_string(error, error_str, &len); + return absl::UnknownError(absl::StrCat("MPI error: ", error_str)); + } + return absl::OkStatus(); +} + +MpiCollectivesCommunicator::MpiCollectivesCommunicator(int color, int key) { + MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_); + MPI_Comm_rank(comm_, &mpi_rank_); + MPI_Comm_size(comm_, &mpi_size_); +} + +MpiCollectivesCommunicator::~MpiCollectivesCommunicator() { + MPI_Comm_free(&comm_); +}; + +absl::Status MpiCollectivesCommunicator::AllReduce( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + 
TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus(MPI_Allreduce(input_buffer, output_buffer, + num_elements, type, op, comm_)); +} + +absl::Status MpiCollectivesCommunicator::CollectivePermute( + const RendezvousKey& key, size_t num_bytes, std::optional source_rank, + absl::Span target_ranks, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + int tag = 0; // TODO come up with better tags. + + const int rank = mpi_rank_; + + std::vector requests; + + if (source_rank) { + if (*source_rank == rank) { + std::memcpy(output_buffer, input_buffer, num_bytes); + } else { + VLOG(1) << "recv at " << rank << " from " << *source_rank; + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Irecv(output_buffer, num_bytes, MPI_BYTE, *source_rank, tag, + comm_, &requests.back()))); + } + } else { + std::memset(output_buffer, 0, num_bytes); + } + + for (int target : target_ranks) { + if (target != rank) { + VLOG(1) << "send from " << rank << " to " << target; + requests.emplace_back(); + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Isend(input_buffer, num_bytes, MPI_BYTE, target, tag, comm_, + &requests.back()))); + } + } + + for (auto& request : requests) { + TF_RETURN_IF_ERROR( + MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCollectivesCommunicator::AllToAll( + const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, absl::Duration timeout) { + // We can't use MPI_Alltoall directly because it assumes that the inputs and + // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv. + + int tag = 0; // TODO use better tags. 
+ const int rank = mpi_rank_; + const int size = mpi_size_; + TF_RET_CHECK(size == input_buffers.size()); + TF_RET_CHECK(size == output_buffers.size()); + + std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); + + for (int i = 1; i < size; i++) { + int send_rank = (rank + i) % size; + int recv_rank = (rank + size - i) % size; + TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( + MPI_Sendrecv(input_buffers[send_rank], chunk_bytes, MPI_BYTE, send_rank, + tag, output_buffers[recv_rank], chunk_bytes, MPI_BYTE, + recv_rank, tag, comm_, MPI_STATUS_IGNORE))); + } + + return absl::OkStatus(); +} + +absl::Status MpiCollectivesCommunicator::AllGather(const RendezvousKey& key, + size_t chunk_bytes, + const void* input_buffer, + void* output_buffer, + absl::Duration timeout) { + return MpiErrorToAbslStatus(MPI_Allgather(input_buffer, chunk_bytes, MPI_BYTE, + output_buffer, chunk_bytes, + MPI_BYTE, comm_)); +} + +absl::Status MpiCollectivesCommunicator::ReduceScatter( + const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, const void* input_buffer, + void* output_buffer, absl::Duration timeout) { + const int size = mpi_size_; + std::vector recvcounts(size, chunk_elems); + TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(element_type)); + TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); + return MpiErrorToAbslStatus(MPI_Reduce_scatter( + input_buffer, output_buffer, recvcounts.data(), type, op, comm_)); +} + +void MpiCollectives::Init() { + int provided; + MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided); + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_); + VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_; +} + +void MpiCollectives::Finalize() { + contexts_.clear(); + MPI_Finalize(); +} + +absl::StatusOr> +MpiCollectives::GetCommunicator(absl::Span global_devices, + int rank) { + int flag; + 
MPI_Is_thread_main(&flag); + if (!flag) { + return absl::UnknownError( + absl::StrCat("MPI: Communicator requested from a thread that is not " + "the one MPI was initialized from. Multiple " + "threads/devices per process are not yet supported.")); + } + + auto& context = contexts_[std::make_tuple( + std::vector(global_devices.begin(), global_devices.end()), + rank)]; + if (context) { + return context; + } + + int color; + int key = 0; + if (global_devices.size() > 0) { + color = static_cast(global_devices.at(0).value()); + key = rank; + } else { + color = MPI_UNDEFINED; + } + context = std::make_shared(color, key); + return context; +} + +} // namespace xla::cpu diff --git a/xla/pjrt/cpu/mpi_collectives.h b/xla/pjrt/cpu/mpi_collectives.h new file mode 100644 index 0000000000000..fdf6ec81b6dc6 --- /dev/null +++ b/xla/pjrt/cpu/mpi_collectives.h @@ -0,0 +1,102 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_ +#define XLA_PJRT_CPU_MPI_COLLECTIVES_H_ + +#include +#include +#include +#include +#include + +#include "mpi.h" // NOLINT +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/cpu/collectives_interface.h" +#include "xla/service/global_device_id.h" +#include "xla/xla_data.pb.h" + +namespace xla::cpu { + +class MpiCollectivesCommunicator : public CollectivesCommunicator { + public: + explicit MpiCollectivesCommunicator(int color, int key); + ~MpiCollectivesCommunicator() override; + + absl::Status AllReduce(const RendezvousKey& key, ReductionKind reduction_kind, + PrimitiveType element_type, size_t num_elements, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status CollectivePermute(const RendezvousKey& key, size_t num_bytes, + std::optional source_rank, + absl::Span target_ranks, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status AllToAll(const RendezvousKey& key, size_t chunk_bytes, + absl::Span input_buffers, + absl::Span output_buffers, + absl::Duration timeout) override; + absl::Status AllGather(const RendezvousKey& key, size_t chunk_bytes, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + absl::Status ReduceScatter(const RendezvousKey& key, + ReductionKind reduction_kind, + PrimitiveType element_type, size_t chunk_elems, + const void* input_buffer, void* output_buffer, + absl::Duration timeout) override; + + private: + MPI_Comm comm_; + int mpi_rank_; + int mpi_size_; +}; + +class MpiCollectives : public CollectivesInterface { + public: + /* + The user has to explicitly call Init() 
and Finalize() before and + after use. + For example, using the Python client, this can be achieved with: + + collectives = xla_client._xla.make_mpi_collectives() + collectives.Init() + atexit.register(collectives.Finalize) + */ + void Init(); + void Finalize(); + + absl::StatusOr> GetCommunicator( + absl::Span global_devices, int rank) override; + + private: + absl::Status ExchangeGlobalDeviceIds( + absl::Span global_devices, int rank); + + int mpi_world_rank_; + int mpi_world_size_; + absl::flat_hash_map, int>, + std::shared_ptr> + contexts_; +}; + +} // namespace xla::cpu + +#endif // XLA_PJRT_CPU_MPI_COLLECTIVES_H_ diff --git a/xla/python/BUILD b/xla/python/BUILD index 69ba00bac6fc5..9734ab7538b46 100644 --- a/xla/python/BUILD +++ b/xla/python/BUILD @@ -1261,6 +1261,12 @@ tsl_pybind_extension( "//xla/pjrt/cpu:gloo_kv_store", "@gloo//:transport_tcp", ], + }) + select({ + # mpitrampoline does not build on windows + "@tsl//tsl:windows": [], + "//conditions:default": [ + "//xla/pjrt/cpu:mpi_collectives", + ], }) + select({ ":gpu_enabled": [ "//xla/pjrt/gpu:se_gpu_pjrt_client", diff --git a/xla/python/xla.cc b/xla/python/xla.cc index 70d9a39388dd1..a45776a38fa30 100644 --- a/xla/python/xla.cc +++ b/xla/python/xla.cc @@ -67,6 +67,11 @@ limitations under the License. 
#include "xla/pjrt/cpu/gloo_collectives.h" #include "xla/pjrt/cpu/gloo_kv_store.h" #endif // __linux__ + +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) +#include "xla/pjrt/cpu/mpi_collectives.h" +#endif // !_WIN32 && !PLATFORM_GOOGLE + #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/exceptions.h" @@ -270,6 +275,23 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, nb::arg("interface").none() = std::nullopt); +#if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) + nb::class_ mpi_collectives(m_nb, "MpiCollectives", + cpu_collectives); + mpi_collectives.def("Init", &cpu::MpiCollectives::Init); + mpi_collectives.def("Finalize", &cpu::MpiCollectives::Finalize); + m_nb.def("make_mpi_collectives", + []() -> std::shared_ptr { + return std::make_shared(); + }); +#else // !_WIN32 && !PLATFORM_GOOGLE + m_nb.def("make_mpi_collectives", + []() -> std::shared_ptr { + throw xla::XlaRuntimeError( + "make_mpi_collectives is not implemented for Windows"); + }); +#endif // !_WIN32 && !PLATFORM_GOOGLE + m_nb.def( "get_tfrt_cpu_client", [](bool asynchronous, diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index ef9e4e5291ab9..64eb6cd7d4e1d 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -48,7 +48,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 250 +_version = 251 # Version number for MLIR:Python components. mlir_api_version = 55