Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:CPU] Add support for cross-process collectives using mpi. #7849

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions third_party/mpitrampoline/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"])
149 changes: 149 additions & 0 deletions third_party/mpitrampoline/gen.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
diff --git a/gen/gen_decl.py b/gen/gen_decl.py
index 1005b95..696b4e0 100755
--- a/gen/gen_decl.py
+++ b/gen/gen_decl.py
@@ -9,8 +9,8 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))

from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran

support_profiling = True
have_weak_symbols = False
@@ -24,7 +24,7 @@ def wrap(line):
lines.append(line)
return "\n".join(lines)

-with open("include/mpi_decl_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Declare C MPI constants\n")
file.write("\n")
for (tp, nm) in constants:
@@ -32,7 +32,7 @@ with open("include/mpi_decl_constants_c.h", "w") as file:
'mpi_nm': nm}
file.write(Template("extern $mpi_tp MPITRAMPOLINE_CONST $mpi_nm;\n").substitute(subs))

-with open("include/mpi_decl_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Declare C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -90,7 +90,7 @@ with open("include/mpi_decl_functions_c.h", "w") as file:
file.write(Template("\n".join(tmpl)).substitute(subs))
file.write("\n")

-with open("include/mpi_decl_constants_fortran.h", "w") as file:
+if False:
file.write("! Declare Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -104,7 +104,7 @@ with open("include/mpi_decl_constants_fortran.h", "w") as file:
file.write("\n".join(map(lambda line: wrap(Template(line).substitute(subs)), tmpl)))
file.write("\n")

-with open("include/mpi_decl_functions_fortran.h", "w") as file:
+if False:
file.write("! Declare Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
diff --git a/gen/gen_defn.py b/gen/gen_defn.py
index bf31f35..318222e 100755
--- a/gen/gen_defn.py
+++ b/gen/gen_defn.py
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))

from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran

support_profiling = True
have_weak_symbols = False
replace_sentinels = False

-with open("src/mpi_defn_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Define C MPI constants")
file.write("\n")
for (tp, nm) in constants:
@@ -24,7 +24,7 @@ with open("src/mpi_defn_constants_c.h", "w") as file:
'mpi_nm': nm}
file.write(Template("$mpi_tp $mpi_nm = ($mpi_tp)0xdeadbeef;\n").substitute(subs))

-with open("src/mpi_defn_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Define C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -89,7 +89,7 @@ with open("src/mpi_defn_functions_c.h", "w") as file:
file.write(Template("\n".join(tmpl)).substitute(subs))
file.write("\n")

-with open("src/mpi_defn_constants_fortran.h", "w") as file:
+if False:
file.write("// Define Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -98,7 +98,7 @@ with open("src/mpi_defn_constants_fortran.h", "w") as file:
# Fortran common blocks with `-march=skylake-avx512` are aligned to 64 bytes
file.write(Template("$mpi_tp $abi_nm __attribute__((__aligned__(64))) = (int)0xdeadbeef;\n").substitute(subs))

-with open("src/mpi_defn_functions_fortran.h", "w") as file:
+if False:
file.write("// Define Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
diff --git a/gen/gen_init.py b/gen/gen_init.py
index 4939261..0e52822 100755
--- a/gen/gen_init.py
+++ b/gen/gen_init.py
@@ -9,14 +9,14 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "mpiabi"))

from mpi_constants import constants
from mpi_functions import functions
-from mpi_constants_fortran import constants_fortran
-from mpi_functions_fortran import functions_fortran
+# from mpi_constants_fortran import constants_fortran
+# from mpi_functions_fortran import functions_fortran

support_profiling = True
have_weak_symbols = False
replace_sentinels = False

-with open("src/mpi_init_constants_c.h", "w") as file:
+with open(sys.argv[1], "w") as file:
file.write("// Initialize C MPI constants")
file.write("\n")
for (tp, nm) in constants:
@@ -25,7 +25,7 @@ with open("src/mpi_init_constants_c.h", "w") as file:
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm)}
file.write(Template("$mpi_nm = *($mpi_tp const *)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))

-with open("src/mpi_init_functions_c.h", "w") as file:
+with open(sys.argv[2], "w") as file:
file.write("// Initialize C MPI functions\n")
file.write("\n")
for (tp, nm, args, flags) in functions:
@@ -39,7 +39,7 @@ with open("src/mpi_init_functions_c.h", "w") as file:
subs['anm{0}'.format(i)] = anm
file.write(Template("$abi_nm = get_symbol(handle, \"$abi_nm\");\n").substitute(subs))

-with open("src/mpi_init_constants_fortran.h", "w") as file:
+if False:
file.write("// Initialize Fortran MPI constants\n")
file.write("\n")
for (tp, nm) in constants_fortran:
@@ -47,7 +47,7 @@ with open("src/mpi_init_constants_fortran.h", "w") as file:
'abi_nm': re.sub(r"MPI(X?)_", r"MPI\1ABI_", nm).lower() + "_"}
file.write(Template("$abi_nm = *($abi_tp const*)get_symbol(handle, \"$abi_nm\");\n").substitute(subs))

-with open("src/mpi_init_functions_fortran.h", "w") as file:
+if False:
file.write("// Initialize Fortran MPI functions\n")
file.write("\n")
for (tp, nm, args) in functions_fortran:
131 changes: 131 additions & 0 deletions third_party/mpitrampoline/mpitrampoline.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Description:
# A forwarding MPI implementation that can use any other MPI implementation via an MPI ABI

load("@bazel_skylib//rules:expand_template.bzl", "expand_template")

# Every target declared in this package is visible to all other packages
# unless a target overrides its own visibility.
package(
default_visibility = ["//visibility:public"],
)

# License category declared for the contents of this package.
licenses(["notice"])

# Expose the license file so that other packages can reference it by label.
exports_files(["LICENSE.md"])

# Generates include/mpi_version.h from the include/mpi_version.h.in template.
#
# The shell command reads the project version out of CMakeLists.txt: it greps
# for "MPItrampoline VERSION" and keeps the last whitespace-separated field of
# the matching output. That value is split on "." into major/minor/patch with
# cut, and sed replaces the @PROJECT_VERSION@, @PROJECT_VERSION_MAJOR@,
# @PROJECT_VERSION_MINOR@ and @PROJECT_VERSION_PATCH@ placeholders.
#
# NOTE(review): assumes exactly one CMakeLists.txt line matches and that the
# version number is the final field on that line -- confirm against the pinned
# upstream commit whenever it is bumped.
genrule(
name = "mpi_version",
srcs = [
"CMakeLists.txt",
"include/mpi_version.h.in",
],
outs = ["include/mpi_version.h"],
# "$$" is the genrule escape for a literal "$" handed to the shell, so
# "$$NF" / "$$PROJECT_VERSION" reach awk/sh as "$NF" / "$PROJECT_VERSION".
cmd = """
PROJECT_VERSION=`cat $(location CMakeLists.txt) \
| grep "MPItrampoline VERSION" | awk '{print $$NF}'`
PROJECT_VERSION_MAJOR=`echo $$PROJECT_VERSION | cut -d. -f1`
PROJECT_VERSION_MINOR=`echo $$PROJECT_VERSION | cut -d. -f2`
PROJECT_VERSION_PATCH=`echo $$PROJECT_VERSION | cut -d. -f3`
sed -e "s/@PROJECT_VERSION@/$${PROJECT_VERSION}/" \
-e "s/@PROJECT_VERSION_MAJOR@/$${PROJECT_VERSION_MAJOR}/" \
-e "s/@PROJECT_VERSION_MINOR@/$${PROJECT_VERSION_MINOR}/" \
-e "s/@PROJECT_VERSION_PATCH@/$${PROJECT_VERSION_PATCH}/" \
$(location include/mpi_version.h.in) > $(location include/mpi_version.h)
""",
)

# Produces src/mpi_defaults.h from src/mpi_defaults.h.in by replacing each
# @MPITRAMPOLINE_DEFAULT_*@ placeholder listed below with an empty string.
# NOTE(review): presumably an empty value means "no compiled-in default", so
# the settings have to be supplied at run time -- confirm against how
# src/mpi.c consumes mpi_defaults.h.
expand_template(
name = "mpi_defaults",
out = "src/mpi_defaults.h",
substitutions = {
"@MPITRAMPOLINE_DEFAULT_DELAY_INIT@": "",
"@MPITRAMPOLINE_DEFAULT_DLOPEN_BINDING@": "",
"@MPITRAMPOLINE_DEFAULT_DLOPEN_MODE@": "",
"@MPITRAMPOLINE_DEFAULT_LIB@": "",
"@MPITRAMPOLINE_DEFAULT_PRELOAD@": "",
"@MPITRAMPOLINE_DEFAULT_VERBOSE@": "",
},
template = "src/mpi_defaults.h.in",
)

# Code generator for the C declaration headers; it is executed by the ":decl"
# genrule. The mpiabi/*.py sources are the constant/function tables that the
# script imports.
py_binary(
    name = "gen_decl",
    srcs = [
        "gen/gen_decl.py",
        "mpiabi/mpi_constants.py",
        "mpiabi/mpi_functions.py",
    ],
    # The entry point lives in a subdirectory, so name it explicitly instead of
    # relying on Bazel matching the target name against the entries of srcs.
    main = "gen/gen_decl.py",
)

# Runs :gen_decl to emit the generated C declaration headers. The two output
# paths are handed to the script as its first and second command-line
# arguments, in that order.
genrule(
    name = "decl",
    outs = [
        "include/mpi_decl_constants_c.h",
        "include/mpi_decl_functions_c.h",
    ],
    cmd = "$(location :gen_decl) " +
          "$(location include/mpi_decl_constants_c.h) " +
          "$(location include/mpi_decl_functions_c.h)",
    tools = [":gen_decl"],
)

# Code generator for the C definition headers; it is executed by the ":defn"
# genrule. The mpiabi/*.py sources are the constant/function tables that the
# script imports.
py_binary(
    name = "gen_defn",
    srcs = [
        "gen/gen_defn.py",
        "mpiabi/mpi_constants.py",
        "mpiabi/mpi_functions.py",
    ],
    # The entry point lives in a subdirectory, so name it explicitly instead of
    # relying on Bazel matching the target name against the entries of srcs.
    main = "gen/gen_defn.py",
)

# Runs :gen_defn to emit the generated C definition headers. The two output
# paths are handed to the script as its first and second command-line
# arguments, in that order.
genrule(
    name = "defn",
    outs = [
        "include/mpi_defn_constants_c.h",
        "include/mpi_defn_functions_c.h",
    ],
    cmd = "$(location :gen_defn) " +
          "$(location include/mpi_defn_constants_c.h) " +
          "$(location include/mpi_defn_functions_c.h)",
    tools = [":gen_defn"],
)

# Code generator for the C initialization headers; it is executed by the
# ":init" genrule. The mpiabi/*.py sources are the constant/function tables
# that the script imports.
py_binary(
    name = "gen_init",
    srcs = [
        "gen/gen_init.py",
        "mpiabi/mpi_constants.py",
        "mpiabi/mpi_functions.py",
    ],
    # The entry point lives in a subdirectory, so name it explicitly instead of
    # relying on Bazel matching the target name against the entries of srcs.
    main = "gen/gen_init.py",
)

# Runs :gen_init to emit the generated C initialization headers. The two
# output paths are handed to the script as its first and second command-line
# arguments, in that order.
genrule(
    name = "init",
    outs = [
        "include/mpi_init_constants_c.h",
        "include/mpi_init_functions_c.h",
    ],
    cmd = "$(location :gen_init) " +
          "$(location include/mpi_init_constants_c.h) " +
          "$(location include/mpi_init_functions_c.h)",
    tools = [":gen_init"],
)

# The MPItrampoline C library: src/mpi.c compiled together with the checked-in
# headers and the headers produced by the generator rules in this package.
cc_library(
    name = "mpitrampoline",
    srcs = ["src/mpi.c"],
    hdrs = [
        "include/mpi.h",
        "include/mpi_version.h",
        "include/mpi_decl_constants_c.h",
        "include/mpi_decl_functions_c.h",
        "include/mpi_defn_constants_c.h",
        "include/mpi_defn_functions_c.h",
        "include/mpi_init_constants_c.h",
        "include/mpi_init_functions_c.h",
        "mpiabi/mpiabi.h",
        "src/mpi_defaults.h",
    ],
    copts = ["-fexceptions"],
    # Directories added to the include search path for this target and for
    # everything that depends on it.
    includes = [
        "include",
        "mpiabi",
        "src",
    ],
)
18 changes: 18 additions & 0 deletions third_party/mpitrampoline/workspace.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Provides the repository macro to import mpitrampoline."""

load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")

def repo():
    """Registers the `mpitrampoline` external repository.

    Downloads the pinned MPItrampoline commit archive via `tf_http_archive`,
    applies `gen.patch` to it, and uses `mpitrampoline.BUILD` as its build
    file.
    """
    commit = "25efb0f7a4cd00ed82bafb8b1a6285fc50d297ed"
    sha256 = "5a36656205c472bdb639bffebb0f014523b32dda0c2cbedd9ce7abfc9e879e84"
    archive_url = "https://github.com/eschnett/mpitrampoline/archive/" + commit + ".tar.gz"

    tf_http_archive(
        name = "mpitrampoline",
        build_file = "//third_party/mpitrampoline:mpitrampoline.BUILD",
        patch_file = ["//third_party/mpitrampoline:gen.patch"],
        sha256 = sha256,
        # The archive unpacks into a single "MPItrampoline-<commit>" directory.
        strip_prefix = "MPItrampoline-" + commit,
        urls = tf_mirror_urls(archive_url),
    )
2 changes: 2 additions & 0 deletions workspace2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
# Import third party repository rules. See go/tfbr-thirdparty.
load("//third_party/dlpack:workspace.bzl", dlpack = "repo")
load("//third_party/gloo:workspace.bzl", gloo = "repo")
load("//third_party/mpitrampoline:workspace.bzl", mpitrampoline = "repo")
load("//third_party/nanobind:workspace.bzl", nanobind = "repo")
load("//third_party/robin_map:workspace.bzl", robin_map = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
Expand All @@ -19,6 +20,7 @@ def _initialize_third_party():
""" Load third party repositories. See above load() statements. """
dlpack()
gloo()
mpitrampoline()
nanobind()
robin_map()
stablehlo()
Expand Down
30 changes: 30 additions & 0 deletions xla/pjrt/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,33 @@ cc_library(
"@tsl//tsl/platform:logging",
],
)

# MPI-based collectives library for the CPU PjRt client, going by the target
# name and its dependencies on the CPU collectives interface and on the
# vendored @mpitrampoline library.
# NOTE(review): the reasons for -fexceptions / -fno-strict-aliasing and for
# disabling the use_header_modules feature are not visible in this hunk --
# confirm they are required by mpi.h and record the reason here.
cc_library(
name = "mpi_collectives",
srcs = ["mpi_collectives.cc"],
hdrs = ["mpi_collectives.h"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
deps = [
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/service:collective_ops_utils",
"//xla/service:global_device_id",
"//xla/service/cpu:collectives_interface",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
"@com_google_absl//absl/types:span",
"@mpitrampoline",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
],
)
Loading
Loading