Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include versioned so files in cuda_runtime #114

Merged
merged 1 commit into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cuda/defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Core rules for building CUDA projects.
"""

load("//cuda/private:providers.bzl", _CudaArchsInfo = "CudaArchsInfo", _cuda_archs = "cuda_archs")
load("//cuda/private:os_helpers.bzl", _if_linux = "if_linux", _if_windows = "if_windows")
load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows")
load("//cuda/private:rules/cuda_objects.bzl", _cuda_objects = "cuda_objects")
load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library")
load("//cuda/private:rules/cuda_toolkit.bzl", _cuda_toolkit = "cuda_toolkit")
Expand Down Expand Up @@ -33,3 +33,5 @@ cuda_library = _cuda_library

if_linux = _if_linux
if_windows = _if_windows

cc_import_versioned_sos = _cc_import_versioned_sos
26 changes: 26 additions & 0 deletions cuda/private/os_helpers.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
load("@bazel_skylib//lib:paths.bzl", "paths")

def if_linux(if_true, if_false = []):
return select({
"@platforms//os:linux": if_true,
Expand All @@ -9,3 +11,27 @@ def if_windows(if_true, if_false = []):
"@platforms//os:windows": if_true,
"//conditions:default": if_false,
})

def cc_import_versioned_sos(name, shared_library):
"""Creates a cc_library that depends on all versioned .so files with the given prefix.

If <shared_library> is path/to/foo.so, and it is a symlink to foo.so.<version>,
this should be used instead of cc_import.
The versioned files are typically needed at runtime, but not at build time.

Args:
name: Name of the cc_library.
shared_library: Prefix of the versioned .so files.
"""
so_paths = native.glob([shared_library + "*"])

[native.cc_import(
name = paths.basename(p),
shared_library = p,
target_compatible_with = ["@platforms//os:linux"],
) for p in so_paths]

native.cc_library(
name = name,
deps = [":%s" % paths.basename(p) for p in so_paths],
)
49 changes: 18 additions & 31 deletions cuda/runtime/BUILD.local_cuda
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
load("@rules_cuda//cuda:defs.bzl", "if_linux", "if_windows")
load(":defs.bzl", "if_local_cuda")
load("@rules_cuda//cuda:defs.bzl", "cc_import_versioned_sos", "if_linux", "if_windows")

package(
default_visibility = ["//visibility:public"],
Expand All @@ -8,9 +7,9 @@ package(
filegroup(
name = "compiler_deps",
srcs = [
"cuda/version.txt",
":_cuda_header_files",
] + glob([
"cuda/version.txt",
"cuda/bin/**",
"cuda/lib64/**",
"cuda/nvvm/**",
Expand Down Expand Up @@ -39,10 +38,9 @@ cc_library(
]),
)

cc_import(
cc_import_versioned_sos(
name = "cudart_so",
shared_library = "cuda/lib64/libcudart.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_library(
Expand Down Expand Up @@ -136,16 +134,14 @@ cc_library(
]),
)

cc_import(
cc_import_versioned_sos(
name = "cublas_so",
shared_library = "cuda/lib64/libcublas.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
cc_import_versioned_sos(
name = "cublasLt_so",
shared_library = "cuda/lib64/libcublasLt.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand Down Expand Up @@ -193,10 +189,9 @@ cc_library(
)

# CUPTI
cc_import(
cc_import_versioned_sos(
name = "cupti_so",
shared_library = "cuda/lib64/libcupti.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand Down Expand Up @@ -301,10 +296,9 @@ cc_library(
)

# curand
cc_import(
cc_import_versioned_sos(
name = "curand_so",
shared_library = "cuda/lib64/libcurand.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand Down Expand Up @@ -351,19 +345,17 @@ cc_library(
"cuda/include",
],
visibility = ["//visibility:public"],
deps = [] +
if_linux([
":nvptxcompiler_so"
deps = [] + if_linux([
":nvptxcompiler_so",
]) + if_windows([
":nvptxcompiler_lib"
])
":nvptxcompiler_lib",
]),
)

# cufft
cc_import(
cc_import_versioned_sos(
name = "cufft_so",
shared_library = "cuda/lib64/libcufft.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand All @@ -373,10 +365,9 @@ cc_import(
target_compatible_with = ["@platforms//os:windows"],
)

cc_import(
cc_import_versioned_sos(
name = "cufftw_so",
shared_library = "cuda/lib64/libcufftw.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand All @@ -392,18 +383,17 @@ cc_library(
":cuda_headers",
] + if_linux([
":cufft_so",
":cufftw_so"
":cufftw_so",
]) + if_windows([
":cufft_lib",
":cufftw_lib",
]),
)

# cusolver
cc_import(
cc_import_versioned_sos(
name = "cusolver_so",
shared_library = "cuda/lib64/libcusolver.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand All @@ -425,10 +415,9 @@ cc_library(
)

# cusparse
cc_import(
cc_import_versioned_sos(
name = "cusparse_so",
shared_library = "cuda/lib64/libcusparse.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand All @@ -450,10 +439,9 @@ cc_library(
)

# nvtx
cc_import(
cc_import_versioned_sos(
name = "nvtx_so",
shared_library = "cuda/lib64/libnvToolsExt.so",
target_compatible_with = ["@platforms//os:linux"],
)

cc_import(
Expand Down Expand Up @@ -505,10 +493,9 @@ _NPP_LIBS = {
}

[
cc_import(
cc_import_versioned_sos(
name = name + "_so",
shared_library = "cuda/lib64/lib{}.so".format(name),
target_compatible_with = ["@platforms//os:linux"],
)
for name in _NPP_LIBS.keys()
]
Expand Down