Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --use_cuda_nvcc flag to enable or disable compilation of CUDA code using NVCC. #23787

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ build:native_arch_posix --host_copt=-march=native

build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1

build:clang --action_env=CC="/usr/lib/llvm-18/bin/clang"
# Disable clang extension that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
# offsetof in the current version of upb.
# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
build:clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:clang --copt=-Qunused-arguments

build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --repo_env TF_NCCL_USE_STUB=1
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
Expand All @@ -68,14 +78,6 @@ build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true
# Default hermetic CUDA and CUDNN versions.
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# Requires MSVC and LLVM to be installed
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
build:win_clang --compiler=clang-cl

# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
Expand All @@ -89,23 +91,18 @@ build:win_clang --compiler=clang-cl
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
# The list of CUDA pip packages that JAX depends on are present in setup.py.
build:cuda --linkopt=-Wl,--disable-new-dtags
build:cuda --@local_config_cuda//:cuda_compiler=clang
build:cuda --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"

build:cuda_clang --@local_config_cuda//:cuda_compiler=clang
build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
# Disable clang extension that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
# offsetof in the current version of upb.
# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183.
build:cuda_clang --copt=-Wno-gnu-offsetof-extensions
# Disable clang extension that rejects unknown arguments.
build:cuda_clang --copt=-Qunused-arguments
# This flag is needed to include CUDA libraries for bazel tests.
test:cuda --@local_config_cuda//cuda:include_cuda_libs=true

# Build with nvcc for CUDA and clang for host
build:nvcc_clang --config=cuda
build:nvcc_clang --config=cuda_clang
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
# Build with NVCC for CUDA
build:cuda_nvcc --config=cuda
build:cuda_nvcc --config=clang
build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc
build:cuda_nvcc --action_env=TF_NVCC_CLANG="1"
build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"

build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
Expand All @@ -114,6 +111,11 @@ build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1

build:nonccl --define=no_nccl_support=true

# Requires MSVC and LLVM to be installed
build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl
build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl
build:win_clang --compiler=clang-cl

# Windows has a relatively short command line limit, which JAX has begun to hit.
# See https://docs.bazel.build/versions/main/windows.html
build:windows --features=compiler_param_file
Expand Down Expand Up @@ -200,7 +202,7 @@ build:rbe_linux --host_linkopt=-lm
# Use the GPU toolchain until the CPU one is ready.
# https://github.com/bazelbuild/bazel/issues/13623
build:rbe_cpu_linux_base --config=rbe_linux
build:rbe_cpu_linux_base --config=cuda_clang
build:rbe_cpu_linux_base --config=clang
build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain"
build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64"
Expand All @@ -223,7 +225,7 @@ build:rbe_linux_cuda_base --config=cuda
build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1

build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base
build:rbe_linux_cuda12.3_nvcc_base --config=nvcc_clang
build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain"
Expand Down
36 changes: 24 additions & 12 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def get_clang_path_or_exit():
return str(pathlib.Path(which_clang_output).resolve())
else:
print(
"--use_clang set, but --clang_path is unset and clang cannot be found"
"--clang_path is unset and clang cannot be found"
" on the PATH. Please pass --clang_path directly."
)
sys.exit(-1)
Expand All @@ -241,8 +241,9 @@ def write_bazelrc(*, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, target_cpu_features,
wheel_cpu, enable_mkl_dnn, use_clang, clang_path,
clang_major_version, enable_cuda, enable_nccl, enable_rocm,
python_version):
clang_major_version, python_version,
enable_cuda, enable_nccl, enable_rocm,
use_cuda_nvcc):

with open("../.jax_configure.bazelrc", "w") as f:
if not remote_build:
Expand Down Expand Up @@ -283,11 +284,11 @@ def write_bazelrc(*, remote_build,
f.write("build --config=mkl_open_source_only\n")
if enable_cuda:
f.write("build --config=cuda\n")
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if use_clang:
f.write("build --config=nvcc_clang\n")
f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
if use_cuda_nvcc:
f.write("build --config=cuda_nvcc\n")
if cuda_version:
f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n"
.format(cuda_version=cuda_version))
Expand Down Expand Up @@ -392,15 +393,14 @@ def main():
"use_clang",
default = "true",
help_str=(
"Should we build using clang as the host compiler? Requires "
"clang to be findable via the PATH, or a path to be given via "
"--clang_path."
"DEPRECATED: This flag is redundant because clang is "
"always used as default compiler."
),
)
parser.add_argument(
"--clang_path",
help=(
"Path to clang binary to use if --use_clang is set. The default is "
"Path to clang binary to use. The default is "
"to find clang via the PATH."
),
)
Expand All @@ -413,7 +413,18 @@ def main():
add_boolean_argument(
parser,
"enable_cuda",
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.")
help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN."
)
add_boolean_argument(
parser,
"use_cuda_nvcc",
default=True,
help_str=(
        "Should we build CUDA code using the NVCC compiler driver? The "
        "default value is true. If the --nouse_cuda_nvcc flag is used, CUDA "
        "code is built by the clang compiler instead."
),
)
add_boolean_argument(
parser,
"build_gpu_plugin",
Expand Down Expand Up @@ -617,10 +628,11 @@ def main():
use_clang=args.use_clang,
clang_path=clang_path,
clang_major_version=clang_major_version,
python_version=python_version,
enable_cuda=args.enable_cuda,
enable_nccl=args.enable_nccl,
enable_rocm=args.enable_rocm,
python_version=python_version,
use_cuda_nvcc=args.use_cuda_nvcc,
)

if args.requirements_update or args.requirements_nightly_update:
Expand Down
3 changes: 2 additions & 1 deletion docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ There are two ways to build `jaxlib` with CUDA support: (1) use
support, or (2) use
`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12`
to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and
jax-cuda-pjrt).
jax-cuda-pjrt). By default, all CUDA compilation steps are performed by NVCC
and clang, but compilation can be restricted to clang via the
`--nouse_cuda_nvcc` flag.

See `python build/build.py --help` for configuration options. Here
`python` should be the name of your Python 3 interpreter; on some systems, you
Expand Down
Loading