Skip to content

Commit

Permalink
Add use_cuda_nvcc flag to build.py to enable compilation of CUDA code…
Browse files Browse the repository at this point in the history
… with clang. If --use_cuda_nvcc NVCC compiler will be used to build CUDA code (default case), if --nouse_cuda_nvcc then Clang will be used to build CUDA code.

Refactor .bazelrc configs to match the new flag and to cleanup all previous confusing names

PiperOrigin-RevId: 676660938
  • Loading branch information
Google-ML-Automation committed Sep 20, 2024
1 parent 1db47fd commit 38ad98f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 17 deletions.
25 changes: 12 additions & 13 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,6 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# Requires MSVC and LLVM to be installed
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
build:win_clang --compiler=clang-cl

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
Expand All @@ -89,6 +81,8 @@ build:win_clang --compiler=clang-cl
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
# The list of CUDA pip packages that JAX depends on are present in setup.py.
build:cuda --linkopt=-Wl,--disable-new-dtags
# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true

build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
Expand All @@ -102,10 +96,10 @@ build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
build:cuda_clang --copt=-Qunused-arguments

# Build with nvcc for CUDA and clang for host
build:nvcc_clang --config=cuda
build:nvcc_clang --config=cuda_clang
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
build:cuda_nvcc --config=cuda
build:cuda_nvcc --config=cuda_clang
build:cuda_nvcc --action_env=TF_NVCC_CLANG="1"
build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
Expand All @@ -114,6 +108,11 @@ build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1

build:nonccl --define=no_nccl_support=true

# Requires MSVC and LLVM to be installed
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
build:win_clang --compiler=clang-cl

# Windows has a relatively short command line limit, which JAX has begun to hit.
# See https://docs.bazel.build/versions/main/windows.html
build:windows --features=compiler_param_file
Expand Down Expand Up @@ -223,7 +222,7 @@ build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1

build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
Expand Down
19 changes: 15 additions & 4 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def write_bazelrc(*, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, target_cpu_features,
wheel_cpu, enable_mkl_dnn, use_clang, clang_path,
clang_major_version, enable_cuda, enable_nccl, enable_rocm,
python_version):
clang_major_version, enable_cuda, use_cuda_nvcc,
enable_nccl, enable_rocm, python_version):

with open("../.jax_configure.bazelrc", "w") as f:
if not remote_build:
Expand Down Expand Up @@ -286,8 +286,9 @@ def write_bazelrc(*, remote_build,
if not enable_nccl:
f.write("build --config=nonccl\n")
if use_clang:
f.write("build --config=nvcc_clang\n")
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
if use_cuda_nvcc:
f.write("build --config=cuda_nvcc\n")
if cuda_version:
f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n"
.format(cuda_version=cuda_version))
Expand Down Expand Up @@ -413,7 +414,16 @@ def main():
add_boolean_argument(
parser,
"enable_cuda",
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN."
)
add_boolean_argument(
parser,
"use_cuda_nvcc",
default=True,
help_str=(
"Should we build CUDA using NVCC as the compiler? The default value is true."
),
)
add_boolean_argument(
parser,
"build_gpu_plugin",
Expand Down Expand Up @@ -618,6 +628,7 @@ def main():
clang_path=clang_path,
clang_major_version=clang_major_version,
enable_cuda=args.enable_cuda,
use_cuda_nvcc=args.use_cuda_nvcc,
enable_nccl=args.enable_nccl,
enable_rocm=args.enable_rocm,
python_version=python_version,
Expand Down

0 comments on commit 38ad98f

Please sign in to comment.