diff --git a/cuda/defs.bzl b/cuda/defs.bzl index ba1aa4ab..70459824 100644 --- a/cuda/defs.bzl +++ b/cuda/defs.bzl @@ -3,7 +3,7 @@ Core rules for building CUDA projects. """ load("//cuda/private:providers.bzl", _CudaArchsInfo = "CudaArchsInfo", _cuda_archs = "cuda_archs") -load("//cuda/private:os_helpers.bzl", _if_linux = "if_linux", _if_windows = "if_windows") +load("//cuda/private:os_helpers.bzl", _cc_import_versioned_sos = "cc_import_versioned_sos", _if_linux = "if_linux", _if_windows = "if_windows") load("//cuda/private:rules/cuda_objects.bzl", _cuda_objects = "cuda_objects") load("//cuda/private:rules/cuda_library.bzl", _cuda_library = "cuda_library") load("//cuda/private:rules/cuda_toolkit.bzl", _cuda_toolkit = "cuda_toolkit") @@ -33,3 +33,5 @@ cuda_library = _cuda_library if_linux = _if_linux if_windows = _if_windows + +cc_import_versioned_sos = _cc_import_versioned_sos diff --git a/cuda/private/os_helpers.bzl b/cuda/private/os_helpers.bzl index f9068a3e..50bcf73e 100644 --- a/cuda/private/os_helpers.bzl +++ b/cuda/private/os_helpers.bzl @@ -1,3 +1,5 @@ +load("@bazel_skylib//lib:paths.bzl", "paths") + def if_linux(if_true, if_false = []): return select({ "@platforms//os:linux": if_true, @@ -9,3 +11,27 @@ def if_windows(if_true, if_false = []): "@platforms//os:windows": if_true, "//conditions:default": if_false, }) + +def cc_import_versioned_sos(name, shared_library): + """Creates a cc_library that depends on all versioned .so files with the given prefix. + + If is path/to/foo.so, and it is a symlink to foo.so., + this should be used instead of cc_import. + The versioned files are typically needed at runtime, but not at build time. + + Args: + name: Name of the cc_library. + shared_library: Prefix of the versioned .so files. + """ + so_paths = native.glob([shared_library + "*"]) + + [native.cc_import( + name = paths.basename(p), + shared_library = p, + target_compatible_with = ["@platforms//os:linux"], + ) for p in so_paths] + + native.cc_library( + name = name, + deps = [":%s" % paths.basename(p) for p in so_paths], + ) diff --git a/cuda/runtime/BUILD.local_cuda b/cuda/runtime/BUILD.local_cuda index a775673f..6f5cae4b 100644 --- a/cuda/runtime/BUILD.local_cuda +++ b/cuda/runtime/BUILD.local_cuda @@ -1,5 +1,4 @@ -load("@rules_cuda//cuda:defs.bzl", "if_linux", "if_windows") -load(":defs.bzl", "if_local_cuda") +load("@rules_cuda//cuda:defs.bzl", "cc_import_versioned_sos", "if_linux", "if_windows") package( default_visibility = ["//visibility:public"], @@ -8,9 +7,9 @@ package( filegroup( name = "compiler_deps", srcs = [ + "cuda/version.txt", ":_cuda_header_files", ] + glob([ - "cuda/version.txt", "cuda/bin/**", "cuda/lib64/**", "cuda/nvvm/**", @@ -39,10 +38,9 @@ cc_library( ]), ) -cc_import( +cc_import_versioned_sos( name = "cudart_so", shared_library = "cuda/lib64/libcudart.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_library( @@ -136,16 +134,14 @@ cc_library( ]), ) -cc_import( +cc_import_versioned_sos( name = "cublas_so", shared_library = "cuda/lib64/libcublas.so", - target_compatible_with = ["@platforms//os:linux"], ) -cc_import( +cc_import_versioned_sos( name = "cublasLt_so", shared_library = "cuda/lib64/libcublasLt.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -193,10 +189,9 @@ cc_library( ) # CUPTI -cc_import( +cc_import_versioned_sos( name = "cupti_so", shared_library = "cuda/lib64/libcupti.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -301,10 +296,9 @@ cc_library( ) # curand -cc_import( +cc_import_versioned_sos( name = "curand_so", shared_library = "cuda/lib64/libcurand.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -351,19 +345,17 @@ cc_library( "cuda/include", ], visibility = ["//visibility:public"], - deps = [] + - if_linux([ - ":nvptxcompiler_so" + deps = [] + if_linux([ + ":nvptxcompiler_so", ]) + if_windows([ - ":nvptxcompiler_lib" - ]) + ":nvptxcompiler_lib", + ]), ) # cufft -cc_import( +cc_import_versioned_sos( name = "cufft_so", shared_library = "cuda/lib64/libcufft.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -373,10 +365,9 @@ cc_import( target_compatible_with = ["@platforms//os:windows"], ) -cc_import( +cc_import_versioned_sos( name = "cufftw_so", shared_library = "cuda/lib64/libcufftw.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -392,7 +383,7 @@ cc_library( ":cuda_headers", ] + if_linux([ ":cufft_so", - ":cufftw_so" + ":cufftw_so", ]) + if_windows([ ":cufft_lib", ":cufftw_lib", @@ -400,10 +391,9 @@ cc_library( ) # cusolver -cc_import( +cc_import_versioned_sos( name = "cusolver_so", shared_library = "cuda/lib64/libcusolver.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -425,10 +415,9 @@ cc_library( ) # cusparse -cc_import( +cc_import_versioned_sos( name = "cusparse_so", shared_library = "cuda/lib64/libcusparse.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -450,10 +439,9 @@ cc_library( ) # nvtx -cc_import( +cc_import_versioned_sos( name = "nvtx_so", shared_library = "cuda/lib64/libnvToolsExt.so", - target_compatible_with = ["@platforms//os:linux"], ) cc_import( @@ -505,10 +493,9 @@ _NPP_LIBS = { } [ - cc_import( + cc_import_versioned_sos( name = name + "_so", shared_library = "cuda/lib64/lib{}.so".format(name), - target_compatible_with = ["@platforms//os:linux"], ) for name in _NPP_LIBS.keys() ]