diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index c1b21085a3bb44..7190401322b1b2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -17,7 +17,6 @@ load( ) load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl") load("//tensorflow:tensorflow.bzl", "if_nccl") -load("//tensorflow:tensorflow.bzl", "nccl_config") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests") load( @@ -38,8 +37,8 @@ load( "if_mkl_ml", "mkl_deps", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda", "if_cuda_is_configured") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm", "if_rocm_is_configured") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") # Description: # Op kernel implementations for TensorFlow. @@ -193,6 +192,16 @@ tf_cc_test( ], ) +# virtual targets since nested select statements not possible +tf_kernel_library( + name = "virtual_nccl", + deps = if_cuda(["@local_config_nccl//:nccl"]), +) +tf_kernel_library( + name = "virtual_rccl", + deps = if_rocm(["@local_config_rocm//rocm:rccl"]), +) + tf_kernel_library( name = "collective_ops", srcs = if_nccl([ @@ -207,9 +216,10 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/profiler/lib:traceme", ] + if_nccl([ + ":virtual_nccl", + ":virtual_rccl", "//tensorflow/core/nccl:nccl_lib", - ] + nccl_config() - ), + ]), ) tf_cuda_cc_test( @@ -376,16 +386,19 @@ cc_library( tf_kernel_library( name = "nccl_kernels", - srcs = if_nccl([ + srcs = if_cuda_or_rocm([ "nccl_ops.cc", ]), - deps = if_nccl([ + deps = if_cuda([ + "@local_config_nccl//:nccl", + ]) + if_rocm([ + "@local_config_rocm//rocm:rccl", + ]) + if_cuda_or_rocm([ "//tensorflow/core/nccl:nccl_lib", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:gpu_runtime", - ] + nccl_config() - ), + ]), ) cc_library( @@ -3258,7 +3271,7 @@ cc_library( ":cholesky_grad", ":cholesky_op", ":determinant_op", - ":einsum_op", # disabling to move past build errors (no_rocm, no_cuda) + ":einsum_op", ":lu_op", ":matrix_exponential_op", ":matrix_inverse_op", diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index 095e6c5ca9e47a..e6419c9dfcff3c 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -4,10 +4,9 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_copts") -load("//tensorflow:tensorflow.bzl", "if_nccl") -load("//tensorflow:tensorflow.bzl", "nccl_config") -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") -load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") +load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm") load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", @@ -22,23 +21,26 @@ exports_files(["LICENSE"]) cc_library( name = "nccl_lib", - srcs = if_nccl([ + srcs = if_cuda_or_rocm([ "nccl_manager.cc", "nccl_rewrite.cc", ]), - hdrs = if_nccl([ + hdrs = if_cuda_or_rocm([ "nccl_manager.h", ]), copts = tf_copts(), - deps = if_nccl([ + deps = if_cuda([ + "@local_config_nccl//:nccl" + ]) + if_rocm([ + "@local_config_rocm//rocm:rccl" + ]) + if_cuda_or_rocm([ "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", "//tensorflow/core:stream_executor", - ] + nccl_config() - ), + ]), alwayslink = 1, ) @@ -53,12 +55,13 @@ tf_cuda_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - ] + if_nccl([ + ] + if_cuda_or_rocm([ ":nccl_lib", - ] + nccl_config() - ) + if_cuda_is_configured([ + ]) + if_cuda([ + "@local_config_nccl//:nccl", "//tensorflow/core:cuda", - ]) + if_rocm_is_configured([ + ]) + if_rocm([ + "@local_config_rocm//rocm:rccl", "//tensorflow/core:rocm", ]), ) diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index e206b9f830b177..69e1de54d63602 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -245,11 +245,6 @@ def if_nccl(a): "//conditions:default": a, }) -def nccl_config(): - return (if_rocm_is_configured(["@local_config_rocm//rocm:rccl"]) - or if_cuda_is_configured(["@local_config_nccl//:nccl"]) - or []) - def get_win_copts(is_external = False): WINDOWS_COPTS = [ "/DPLATFORM_WINDOWS",