diff --git a/README.md b/README.md index f47249f..0a3b430 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,10 @@ your CUDA version (like `cuda118`). When building from source it's enough to spe Note that all the precompiled binaries assume glibc 2.31 or newer. +##### Notes for ROCm: + +The ROCm precompiled build is currently broken due to an issue in our TensorFlow version with ROCm 5.4. You can still compile for ROCm by changing `TENSORFLOW_GIT_REV` per the instructions [here](https://github.com/elixir-nx/xla/issues/29) and running `XLA_BUILD=true mix compile`. + #### `XLA_BUILD` Defaults to `false`. If `true` the binary is built locally, which may be intended diff --git a/extension/BUILD b/extension/BUILD index 1019dba..6040923 100644 --- a/extension/BUILD +++ b/extension/BUILD @@ -1,5 +1,7 @@ load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_cuda_or_rocm",) load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_with_tpu_support",) +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda",) +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm",) load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependencies",) load("@org_tensorflow//tensorflow:tensorflow.bzl", "transitive_hdrs",) load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_tar",) @@ -49,7 +51,12 @@ cc_binary( ]) + if_cuda_or_rocm([ "@org_tensorflow//tensorflow/compiler/xla/service:gpu_plugin", + ]) + + if_cuda([ "@org_tensorflow//tensorflow/compiler/xla/stream_executor:cuda_platform" + ]) + + if_rocm([ + "@org_tensorflow//tensorflow/compiler/xla/stream_executor:rocm_platform" ]), linkopts = ["-shared"], linkshared = 1,