Skip to content

Commit

Permalink
Reverts b4e36ea
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 679573999
  • Loading branch information
beckerhe authored and Google-ML-Automation committed Sep 27, 2024
1 parent 9b1056c commit 6c3194e
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 116 deletions.
1 change: 0 additions & 1 deletion xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,6 @@ bzl_library(
deps = [
"//xla/tsl:tsl_bzl",
"@bazel_skylib//lib:paths",
"@tsl//tsl/platform/default:cuda_build_defs_bzl",
],
)

Expand Down
7 changes: 1 addition & 6 deletions xla/lit.bzl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Helper rules for writing LIT tests."""

load("@bazel_skylib//lib:paths.bzl", "paths")
load("@tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured")
load("//xla/tsl:tsl.bzl", "if_cuda_tools", "if_google", "if_oss")

def enforce_glob(files, **kwargs):
Expand Down Expand Up @@ -210,11 +209,7 @@ def lit_test(
srcs = tools,
bin_dir = bin_dir,
lib_dir = lib_dir,
deps = if_cuda_is_configured(
[
"//xla/stream_executor/cuda:all_runtime",
],
),
deps = ["//xla/stream_executor/cuda:all_runtime"],
visibility = ["//visibility:private"],
**kwargs
)
Expand Down
1 change: 1 addition & 0 deletions xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ xla_test(
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_compiler",
"//xla/service:hlo_parser",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/tests:literal_test_util",
"@com_google_absl//absl/status",
"@com_google_googletest//:gtest",
Expand Down
3 changes: 3 additions & 0 deletions xla/service/gpu/fusions/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ xla_test(
"//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/tests:filecheck",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
Expand Down Expand Up @@ -268,6 +269,7 @@ xla_test(
"//xla/service/gpu/model:tiled_hlo_instruction_or_computation",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"//xla/tsl/lib/core:status_test_util",
Expand Down Expand Up @@ -399,6 +401,7 @@ xla_test(
"//xla/hlo/ir:hlo",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"//xla/stream_executor/cuda:cublas_plugin",
"//xla/tests:xla_internal_test_main", # fixdeps: keep
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ xla_test(
"//xla/stream_executor",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor/cuda:cuda_platform",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/random",
"@com_google_absl//absl/strings",
Expand Down
30 changes: 30 additions & 0 deletions xla/stream_executor/build_defs.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Configurations for StreamExecutor builds"""

load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load(
"@local_config_rocm//rocm:build_defs.bzl",
_if_cuda_or_rocm = "if_cuda_or_rocm",
Expand Down Expand Up @@ -63,5 +64,34 @@ def gpu_only_cc_library(name, tags = [], **kwargs):
target_compatible_with = kwargs.get("target_compatible_with"),
)

def cuda_only_cc_library(name, tags = [], **kwargs):
"""A library that only gets compiled when CUDA is configured, otherwise it's an empty target.
Args:
name: Name of the target
tags: Tags being applied to the implementation target
**kwargs: Accepts all arguments that a `cc_library` would also accept
"""
if not native.package_name().startswith("xla/stream_executor"):
fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.")

cc_library(
name = "%s_non_cuda" % name,
tags = ["manual"],
)
cc_library(
name = "%s_cuda_only" % name,
tags = tags + ["manual", "cuda-only"],
**kwargs
)
native.alias(
name = name,
actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name),
visibility = kwargs.get("visibility"),
compatible_with = kwargs.get("compatible_with"),
restricted_to = kwargs.get("restricted_to"),
target_compatible_with = kwargs.get("target_compatible_with"),
)

def stream_executor_build_defs_bzl_deps():
return []
Loading

0 comments on commit 6c3194e

Please sign in to comment.