Skip to content

Commit

Permalink
experimenting
Browse files Browse the repository at this point in the history
  • Loading branch information
hypdeb committed Sep 9, 2024
1 parent d297ef0 commit 6d31409
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 37 deletions.
10 changes: 5 additions & 5 deletions cuda/private/rules/cuda_library.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
load("//cuda/private:cuda_helper.bzl", "cuda_helper")
load("//cuda/private:providers.bzl", "CudaInfo")
load("//cuda/private:toolchain.bzl", "find_cuda_toolchain", "use_cpp_toolchain", "use_cuda_toolchain")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
load("//cuda/private:actions/compile.bzl", "compile")
load("//cuda/private:actions/dlink.bzl", "device_link")
load("//cuda/private:cuda_helper.bzl", "cuda_helper")
load("//cuda/private:providers.bzl", "CudaInfo")
load("//cuda/private:rules/common.bzl", "ALLOW_CUDA_HDRS", "ALLOW_CUDA_SRCS")
load("//cuda/private:toolchain.bzl", "find_cuda_toolchain", "use_cuda_toolchain")

def _cuda_library_impl(ctx):
"""cuda_library is a rule that perform device link.
Expand Down Expand Up @@ -175,6 +175,6 @@ cuda_library = rule(
"_default_cuda_archs": attr.label(default = "//cuda:archs"),
},
fragments = ["cpp"],
toolchains = use_cpp_toolchain() + use_cuda_toolchain(),
toolchains = use_cpp_toolchain(mandatory = True) + use_cuda_toolchain(),
provides = [DefaultInfo, OutputGroupInfo, CcInfo, CudaInfo],
)
8 changes: 4 additions & 4 deletions cuda/private/rules/cuda_objects.bzl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
load("//cuda/private:actions/compile.bzl", "compile")
load("//cuda/private:cuda_helper.bzl", "cuda_helper")
load("//cuda/private:providers.bzl", "CudaInfo")
load("//cuda/private:toolchain.bzl", "find_cuda_toolchain", "use_cpp_toolchain", "use_cuda_toolchain")
load("//cuda/private:actions/compile.bzl", "compile")
load("//cuda/private:rules/common.bzl", "ALLOW_CUDA_HDRS", "ALLOW_CUDA_SRCS")
load("//cuda/private:toolchain.bzl", "find_cuda_toolchain", "use_cuda_toolchain")

def _cuda_objects_impl(ctx):
attr = ctx.attr
Expand Down Expand Up @@ -110,6 +110,6 @@ code and device link time optimization source files.""",
"_default_cuda_archs": attr.label(default = "//cuda:archs"),
},
fragments = ["cpp"],
toolchains = use_cpp_toolchain() + use_cuda_toolchain(),
toolchains = use_cpp_toolchain(mandatory = True) + use_cuda_toolchain(),
provides = [DefaultInfo, OutputGroupInfo, CcInfo, CudaInfo],
)
9 changes: 0 additions & 9 deletions cuda/private/toolchain.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ cuda_toolchain = rule(
CPP_TOOLCHAIN_TYPE = "@bazel_tools//tools/cpp:toolchain_type"
CUDA_TOOLCHAIN_TYPE = "//cuda:toolchain_type"

# buildifier: disable=unused-variable
def use_cpp_toolchain(mandatory = True):
"""Helper to depend on the C++ toolchain.
Notes:
Copied from [toolchain_utils.bzl](https://github.com/bazelbuild/bazel/blob/ac48e65f70/tools/cpp/toolchain_utils.bzl#L53-L72)
"""
return [CPP_TOOLCHAIN_TYPE]

def use_cuda_toolchain():
"""Helper to depend on the CUDA toolchain."""
return [CUDA_TOOLCHAIN_TYPE]
Expand Down
5 changes: 2 additions & 3 deletions cuda/private/toolchain_configs/clang.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/build_defs/cc:action_names.bzl", CC_ACTION_NAMES = "ACTION_NAMES")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
load("//cuda/private:action_names.bzl", "ACTION_NAMES")
load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo")
load("//cuda/private:toolchain.bzl", "use_cpp_toolchain")
load(
"//cuda/private:toolchain_config_lib.bzl",
"action_config",
Expand Down Expand Up @@ -516,5 +515,5 @@ cuda_toolchain_config = rule(
"_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"), # legacy behaviour
},
provides = [CudaToolchainConfigInfo],
toolchains = use_cpp_toolchain(),
toolchains = use_cpp_toolchain(mandatory = True),
)
5 changes: 2 additions & 3 deletions cuda/private/toolchain_configs/nvcc.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/build_defs/cc:action_names.bzl", CC_ACTION_NAMES = "ACTION_NAMES")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
load("//cuda/private:action_names.bzl", "ACTION_NAMES")
load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo")
load("//cuda/private:toolchain.bzl", "use_cpp_toolchain")
load(
"//cuda/private:toolchain_config_lib.bzl",
"action_config",
Expand Down Expand Up @@ -526,5 +525,5 @@ cuda_toolchain_config = rule(
"_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"), # legacy behaviour
},
provides = [CudaToolchainConfigInfo],
toolchains = use_cpp_toolchain(),
toolchains = use_cpp_toolchain(mandatory = True),
)
5 changes: 2 additions & 3 deletions cuda/private/toolchain_configs/nvcc_msvc.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_tools//tools/build_defs/cc:action_names.bzl", CC_ACTION_NAMES = "ACTION_NAMES")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain")
load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain")
load("//cuda/private:action_names.bzl", "ACTION_NAMES")
load("//cuda/private:artifact_categories.bzl", "ARTIFACT_CATEGORIES")
load("//cuda/private:providers.bzl", "CudaToolchainConfigInfo", "CudaToolkitInfo")
load("//cuda/private:toolchain.bzl", "use_cpp_toolchain")
load(
"//cuda/private:toolchain_config_lib.bzl",
"action_config",
Expand Down Expand Up @@ -614,5 +613,5 @@ cuda_toolchain_config = rule(
"_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"), # legacy behaviour
},
provides = [CudaToolchainConfigInfo],
toolchains = use_cpp_toolchain(),
toolchains = use_cpp_toolchain(mandatory = True),
)
18 changes: 8 additions & 10 deletions docs/developer_docs.bzl
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
load("@rules_cuda//cuda/private:cuda_helper.bzl", _cuda_helper = "cuda_helper")
load("@rules_cuda//cuda/private:actions/compile.bzl", _compile = "compile")
load("@rules_cuda//cuda/private:actions/dlink.bzl", _device_link = "device_link")
load(
"@rules_cuda//cuda/private:toolchain.bzl",
_find_cuda_toolchain = "find_cuda_toolchain",
_find_cuda_toolkit = "find_cuda_toolkit",
_use_cpp_toolchain = "use_cpp_toolchain",
_use_cuda_toolchain = "use_cuda_toolchain",
)
load("@rules_cuda//cuda/private:toolchain_config_lib.bzl", _config_helper = "config_helper")
load("@rules_cuda//cuda/private:cuda_helper.bzl", _cuda_helper = "cuda_helper")
load(
"@rules_cuda//cuda/private:repositories.bzl",
_config_clang = "config_clang",
_config_cuda_toolkit_and_nvcc = "config_cuda_toolkit_and_nvcc",
_detect_clang = "detect_clang",
_detect_cuda_toolkit = "detect_cuda_toolkit",
)
load(
"@rules_cuda//cuda/private:toolchain.bzl",
_find_cuda_toolchain = "find_cuda_toolchain",
_find_cuda_toolkit = "find_cuda_toolkit",
_use_cuda_toolchain = "use_cuda_toolchain",
)
load("@rules_cuda//cuda/private:toolchain_config_lib.bzl", _config_helper = "config_helper")

# create a struct to group toolchain symbols semantically
toolchain = struct(
use_cpp_toolchain = _use_cpp_toolchain,
use_cuda_toolchain = _use_cuda_toolchain,
find_cuda_toolchain = _find_cuda_toolchain,
find_cuda_toolkit = _find_cuda_toolkit,
Expand Down

0 comments on commit 6d31409

Please sign in to comment.