Switch JAX to use the OpenXLA repository.
hawkinsp committed Mar 13, 2023
1 parent 233911c · commit 172a831
Showing 39 changed files with 1,281 additions and 142 deletions.
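The change is largely mechanical: Bazel labels and C++ include paths that previously pointed into the TensorFlow tree now point at the standalone OpenXLA checkout, and TSL platform utilities move to the `@tsl` repository. A rough sketch of the rename (a hypothetical helper for illustration; the actual commit was produced by hand and internal tooling, not this script):

```python
# Hypothetical sketch of the renames applied across this diff.
import re

RENAMES = [
    # Bazel labels: @org_tensorflow//tensorflow/compiler/xla/... -> @xla//xla/...
    (re.escape("@org_tensorflow//tensorflow/compiler/xla"), "@xla//xla"),
    # C++ includes: tensorflow/compiler/xla/... -> xla/...
    (re.escape('#include "tensorflow/compiler/xla/'), '#include "xla/'),
    # Python imports inside generated stubs and tpu_client.py
    (re.escape("from tensorflow.compiler.xla.python"), "from xla.python"),
]

def migrate(text: str) -> str:
    """Apply the old-label -> new-label substitutions to one file's text."""
    for pattern, replacement in RENAMES:
        text = re.sub(pattern, replacement, text)
    return text
```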
18 changes: 11 additions & 7 deletions .bazelrc
@@ -1,6 +1,10 @@
############################################################################
# All default build options below.

+# Required by OpenXLA
+# https://github.com/openxla/xla/issues/1323
+build --nocheck_visibility

# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
build --macos_minimum_os=10.14
@@ -35,9 +39,9 @@ build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.

# Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled,
# these values are overridden.
-build --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false
-build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false
-build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false
+build --@xla//xla/python:enable_gpu=false
+build --@xla//xla/python:enable_tpu=false
+build --@xla//xla/python:enable_plugin_device=false

###########################################################################

@@ -65,12 +69,12 @@ build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_52,sm_60,sm_70,compute_80"
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda
-build:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true
+build:cuda --@xla//xla/python:enable_gpu=true
build:cuda --define=xla_python_enable_gpu=true

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
-build:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true
+build:rocm --@xla//xla/python:enable_gpu=true
build:rocm --define=xla_python_enable_gpu=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
@@ -113,10 +117,10 @@ build:macos --config=posix
# Suppress all warning messages.
build:short_logs --output_filter=DONT_MATCH_ANYTHING

-build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true
+build:tpu --@xla//xla/python:enable_tpu=true
build:tpu --define=with_tpu_support=true

-build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true
+build:plugin_device --@xla//xla/python:enable_plugin_device=true

#########################################################################
# RBE config options below.
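These `--config=...` groups are what `build/build.py` selects when a backend is requested. A minimal sketch of driving them directly, assuming the `//build:build_wheel` target that `build.py` normally invokes:

```python
# Minimal sketch; the bazel target name is an assumption for illustration.
import subprocess

CONFIGS = {
    "cuda": "--config=cuda",  # flips --@xla//xla/python:enable_gpu=true
    "rocm": "--config=rocm",  # likewise, plus the ROCm toolchain
    "tpu": "--config=tpu",    # flips --@xla//xla/python:enable_tpu=true
}

def bazel_build(backend: str = "cpu") -> None:
    """Build jaxlib with the .bazelrc config for the given backend."""
    args = ["bazel", "build", "//build:build_wheel"]
    if backend in CONFIGS:
        args.append(CONFIGS[backend])
    subprocess.run(args, check=True)
```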
42 changes: 24 additions & 18 deletions WORKSPACE
@@ -1,16 +1,16 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

-# To update TensorFlow to a new revision,
+# To update XLA to a new revision,
# a) update URL and strip_prefix to the new git commit hash
# b) get the sha256 hash of the commit by running:
-#    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
+#    curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.
http_archive(
name = "org_tensorflow",
sha256 = "08fd0ab0b672510229ad2fff276a3634f205fc539fa16a5bdeeaaccd881ece27",
strip_prefix = "tensorflow-2aaeef25361311b21b9e81e992edff94bcb6bae3",
name = "xla",
sha256 = "9f39af4d81d2c8bd52b47f4ef37dfd6642c6950112e4d9d3d95ae19982c46eba",
strip_prefix = "xla-0f31407ee498e6dba242d03f8d382ebcfcc61790",
urls = [
"https://github.com/tensorflow/tensorflow/archive/2aaeef25361311b21b9e81e992edff94bcb6bae3.tar.gz",
"https://github.com/openxla/xla/archive/0f31407ee498e6dba242d03f8d382ebcfcc61790.tar.gz",
],
)

@@ -19,26 +19,32 @@ http_archive(
# local checkout by either:
# a) overriding the TF repository on the build.py command line by passing a flag
# like:
-#   python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
+#   python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
# or
# b) by commenting out the http_archive above and uncommenting the following:
# local_repository(
# name = "org_tensorflow",
# path = "/path/to/tensorflow",
# name = "xla",
# path = "/path/to/xla",
# )

load("//third_party/ducc:workspace.bzl", ducc = "repo")
ducc()

-# Initialize TensorFlow's external dependencies.
-load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
-tf_workspace3()
+load("@xla//:workspace4.bzl", "xla_workspace4")
+xla_workspace4()

load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
tf_workspace2()
load("@xla//:workspace3.bzl", "xla_workspace3")
xla_workspace3()

load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
tf_workspace1()
load("@xla//:workspace2.bzl", "xla_workspace2")
xla_workspace2()

load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
tf_workspace0()
load("@xla//:workspace1.bzl", "xla_workspace1")
xla_workspace1()

load("@xla//:workspace0.bzl", "xla_workspace0")
xla_workspace0()


load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
flatbuffers()
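The update recipe in the WORKSPACE comment above can be scripted; a small sketch that mirrors the `curl ... | sha256sum` step, assuming network access and the archive URL format given in that comment:

```python
# Sketch of the sha256 step from the WORKSPACE comment.
import hashlib
import urllib.request

def xla_archive_sha256(commit: str) -> str:
    """Download the XLA source archive for `commit` and hash it."""
    url = f"https://github.com/openxla/xla/archive/{commit}.tar.gz"
    with urllib.request.urlopen(url) as resp:
        return hashlib.sha256(resp.read()).hexdigest()

# The pin used in this commit:
print(xla_archive_sha256("0f31407ee498e6dba242d03f8d382ebcfcc61790"))
```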
4 changes: 2 additions & 2 deletions build/BUILD.bazel
@@ -44,11 +44,11 @@ py_binary(
"//jaxlib:README.md",
"//jaxlib:setup.py",
"//jaxlib:setup.cfg",
"@org_tensorflow//tensorflow/compiler/xla/python:xla_client",
"@xla//xla/python:xla_client",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + select({
":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"],
":remote_tpu_enabled": ["@xla//xla/python/tpu_driver/client:py_tpu_client"],
"//conditions:default": [],
}) + if_cuda([
"//jaxlib/cuda:cuda_gpu_support",
18 changes: 9 additions & 9 deletions build/build_wheel.py
@@ -103,29 +103,29 @@ def patch_copy_xla_extension_stubs(dst_dir):
  os.makedirs(xla_extension_dir)
  for stub_name in _XLA_EXTENSION_STUBS:
    stub_path = r.Rlocation(
-        "org_tensorflow/tensorflow/compiler/xla/python/xla_extension/" + stub_name)
+        "xla/xla/python/xla_extension/" + stub_name)
    stub_path = str(stub_path)  # Make pytype accept os.path.exists(stub_path).
    if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):
      continue
    with open(stub_path) as f:
      src = f.read()
    src = src.replace(
-        "from tensorflow.compiler.xla.python import xla_extension",
+        "from xla.python import xla_extension",
        "from .. import xla_extension"
    )
    with open(os.path.join(xla_extension_dir, stub_name), "w") as f:
      f.write(src)


def patch_copy_tpu_client_py(dst_dir):
-  with open(r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py")) as f:
+  with open(r.Rlocation("xla/xla/python/tpu_driver/client/tpu_client.py")) as f:
    src = f.read()
-  src = src.replace("from tensorflow.compiler.xla.python import xla_extension as _xla",
+  src = src.replace("from xla.python import xla_extension as _xla",
                    "from . import xla_extension as _xla")
-  src = src.replace("from tensorflow.compiler.xla.python import xla_client",
+  src = src.replace("from xla.python import xla_client",
                    "from . import xla_client")
  src = src.replace(
-      "from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
+      "from xla.python.tpu_driver.client import tpu_client_extension as _tpu_client",
      "from . import tpu_client_extension as _tpu_client")
  with open(os.path.join(dst_dir, "tpu_client.py"), "w") as f:
    f.write(src)
@@ -143,7 +143,7 @@ def verify_mac_libraries_dont_reference_chkstack():
    return
  nm = subprocess.run(
      ["nm", "-g",
-       r.Rlocation("org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so")
+       r.Rlocation("xla/xla/python/xla_extension.so")
      ],
      capture_output=True, text=True,
      check=False)
@@ -250,8 +250,8 @@ def prepare_wheel(sources_path):
  copy_file("__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so", dst_dir=mlir_libs_dir)
  patch_copy_xla_extension_stubs(jaxlib_dir)

-  if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"):
-    copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")
+  if exists("xla/xla/python/tpu_driver/client/tpu_client_extension.so"):
+    copy_to_jaxlib("xla/xla/python/tpu_driver/client/tpu_client_extension.so")
    patch_copy_tpu_client_py(jaxlib_dir)


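For clarity, the effect of the stub rewrite in `patch_copy_xla_extension_stubs` after this change, shown on a hypothetical one-line stub:

```python
# Hypothetical stub content, before and after the rewrite above.
src = "from xla.python import xla_extension\n"
patched = src.replace("from xla.python import xla_extension",
                      "from .. import xla_extension")
assert patched == "from .. import xla_extension\n"  # relative import inside the wheel
```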
16 changes: 8 additions & 8 deletions docs/developer.md
@@ -66,24 +66,24 @@ specify the paths to CUDA and CUDNN, which you must have installed. Here
may need to use `python3` instead. By default, the wheel is written to the
`dist/` subdirectory of the current directory.

-### Building jaxlib from source with a modified TensorFlow repository.
+### Building jaxlib from source with a modified XLA repository.

JAX depends on XLA, whose source code is in the
-[Tensorflow GitHub repository](https://github.com/tensorflow/tensorflow).
-By default JAX uses a pinned copy of the TensorFlow repository, but we often
+[XLA GitHub repository](https://github.com/openxla/xla).
+By default JAX uses a pinned copy of the XLA repository, but we often
want to use a locally-modified copy of XLA when working on JAX. There are two
ways to do this:

* use Bazel's `override_repository` feature, which you can pass as a command
line flag to `build.py` as follows:

```
-  python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
+  python build/build.py --bazel_options=--override_repository=xla=/path/to/xla
```
* modify the `WORKSPACE` file in the root of the JAX source tree to point to
-  a different TensorFlow tree.
+  a different XLA tree.

-To contribute changes back to XLA, send PRs to the TensorFlow repository.
+To contribute changes back to XLA, send PRs to the XLA repository.

The version of XLA pinned by JAX is regularly updated, but is updated in
particular before each `jaxlib` release.
@@ -141,7 +141,7 @@ sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
```

-AMD's fork of the XLA (TensorFlow) repository may include fixes
+AMD's fork of the XLA repository may include fixes
not present in the upstream repository. To use AMD's fork, you should clone
their repository:
```
@@ -152,7 +152,7 @@ To build jaxlib with ROCM support, you can run the following build command,
suitably adjusted for your paths and ROCM version.
```
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
-  --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow-upstream
+  --bazel_options=--override_repository=xla=/path/to/xla-upstream
```

## Installing `jax`
4 changes: 2 additions & 2 deletions docs/jep/9419-jax-versioning.md
@@ -121,7 +121,7 @@ no released `jax` version uses that API.
`jaxlib` is split across two main repositories, namely the
[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
and in the
-[XLA source tree, which lives inside the TensorFlow repository](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla).
+[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla).
The JAX-specific pieces inside XLA are primarily in the
[`xla/python` subdirectory](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla/python).

@@ -164,7 +164,7 @@ compatibility, we have additional versioning that is independent of the `jaxlib`
release version numbers.

We maintain an additional version number (`_version`) in
-[`xla_client.py` in the XLA repository](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py).
+[`xla_client.py` in the XLA repository](https://github.com/openxla/xla/blob/main/xla/python/xla_client.py).
The idea is that this version number, defined in `xla/python`
together with the C++ parts of JAX, is also accessible to JAX Python as
`jax._src.lib.xla_extension_version`, and must
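A sketch of how that version number is typically consumed on the JAX side; `jax._src.lib.xla_extension_version` is named in the text above, but the threshold below is a made-up example:

```python
# Guarding a feature on the extension version described above.
from jax._src.lib import xla_extension_version

if xla_extension_version < 123:  # hypothetical minimum version
    raise RuntimeError("installed jaxlib is too old for this feature")
```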
29 changes: 12 additions & 17 deletions examples/jax_cpp/BUILD
@@ -12,28 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-load(
-    "@org_tensorflow//tensorflow:tensorflow.bzl",
-    "tf_cc_binary",
-)

licenses(["notice"])

-tf_cc_binary(
+cc_binary(
    name = "main",
    srcs = ["main.cc"],
    tags = ["manual"],
    deps = [
-        "@org_tensorflow//tensorflow/compiler/xla:literal",
-        "@org_tensorflow//tensorflow/compiler/xla:literal_util",
-        "@org_tensorflow//tensorflow/compiler/xla:shape_util",
-        "@org_tensorflow//tensorflow/compiler/xla:status",
-        "@org_tensorflow//tensorflow/compiler/xla:statusor",
-        "@org_tensorflow//tensorflow/compiler/xla/pjrt:pjrt_client",
-        "@org_tensorflow//tensorflow/compiler/xla/pjrt:tfrt_cpu_pjrt_client",
-        "@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
-        "@org_tensorflow//tensorflow/compiler/xla/tools:hlo_module_loader",
-        "@org_tensorflow//tensorflow/core/platform:logging",
-        "@org_tensorflow//tensorflow/core/platform:platform_port",
+        "@xla//xla:literal",
+        "@xla//xla:literal_util",
+        "@xla//xla:shape_util",
+        "@xla//xla:status",
+        "@xla//xla:statusor",
+        "@xla//xla/pjrt:pjrt_client",
+        "@xla//xla/pjrt:tfrt_cpu_pjrt_client",
+        "@xla//xla/service:hlo_proto_cc",
+        "@xla//xla/tools:hlo_module_loader",
+        "@tsl//tsl/platform:logging",
+        "@tsl//tsl/platform:platform_port",
    ],
)
20 changes: 10 additions & 10 deletions examples/jax_cpp/main.cc
@@ -40,18 +40,18 @@ limitations under the License.
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/hlo_module_loader.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/status.h"
#include "xla/statusor.h"
#include "xla/tools/hlo_module_loader.h"
#include "tsl/platform/init_main.h"
#include "tsl/platform/logging.h"

int main(int argc, char** argv) {
-  tensorflow::port::InitMain("", &argc, &argv);
+  tsl::port::InitMain("", &argc, &argv);

  // Load HloModule from file.
  std::string hlo_filename = "/tmp/fn_hlo.txt";
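The `/tmp/fn_hlo.txt` consumed by `main.cc` above is typically produced from Python; a minimal sketch using `jax.xla_computation`, the API available at the time of this commit (the function `fn` is an arbitrary example):

```python
# Sketch: dump a JAX function's HLO text for the C++ example to load.
import jax
import jax.numpy as jnp

def fn(x, y, z):
  return jnp.dot(x, y) / z

comp = jax.xla_computation(fn)(jnp.ones((2, 2)), jnp.ones((2, 2)), 2.0)
with open("/tmp/fn_hlo.txt", "w") as f:
  f.write(comp.as_hlo_text())
```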
8 changes: 4 additions & 4 deletions jaxlib/BUILD
@@ -68,16 +68,16 @@ symlink_files(

symlink_files(
name = "xla_client",
srcs = ["@org_tensorflow//tensorflow/compiler/xla/python:xla_client"],
srcs = ["@xla//xla/python:xla_client"],
dst = ".",
flatten = True,
)

symlink_files(
name = "xla_extension",
srcs = if_windows(
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.pyd"],
["@org_tensorflow//tensorflow/compiler/xla/python:xla_extension.so"],
["@xla//xla/python:xla_extension.pyd"],
["@xla//xla/python:xla_extension.so"],
),
dst = ".",
flatten = True,
@@ -140,7 +140,7 @@ pybind_extension(
srcs = ["cpu_feature_guard.c"],
module_name = "cpu_feature_guard",
deps = [
"@org_tensorflow//third_party/python_runtime:headers",
"@xla//third_party/python_runtime:headers",
],
)

6 changes: 3 additions & 3 deletions jaxlib/cpu/BUILD
@@ -31,7 +31,7 @@ cc_library(
srcs = ["lapack_kernels.cc"],
hdrs = ["lapack_kernels.h"],
deps = [
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/base:dynamic_annotations",
],
)
@@ -74,7 +74,7 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":ducc_fft_flatbuffers_cc",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@xla//xla/service:custom_call_status",
"@ducc",
"@flatbuffers//:runtime_cc",
],
@@ -106,7 +106,7 @@ cc_library(
        ":ducc_fft_kernels",
        ":lapack_kernels",
        ":lapack_kernels_using_lapack",
-        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_target_registry",
+        "@xla//xla/service:custom_call_target_registry",
    ],
    alwayslink = 1,
)
2 changes: 1 addition & 1 deletion jaxlib/cpu/cpu_kernels.cc
@@ -18,7 +18,7 @@ limitations under the License.

#include "jaxlib/cpu/lapack_kernels.h"
#include "jaxlib/cpu/ducc_fft_kernels.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "xla/service/custom_call_target_registry.h"

namespace jax {
namespace {
[Diff truncated: the remaining changed files are not shown.]