Skip to content

Commit

Permalink
Merge pull request #572 from jeffdaily/rccl_fix_build_files
Browse files Browse the repository at this point in the history
rework rccl BUILD rules to match more closely with upstream
  • Loading branch information
whchung authored Jul 22, 2019
2 parents 593010e + 2244aee commit aa50da8
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 28 deletions.
33 changes: 23 additions & 10 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ load(
)
load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
load("//tensorflow:tensorflow.bzl", "if_nccl")
load("//tensorflow:tensorflow.bzl", "nccl_config")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
load(
Expand All @@ -38,8 +37,8 @@ load(
"if_mkl_ml",
"mkl_deps",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda", "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm", "if_rocm_is_configured")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")

# Description:
# Op kernel implementations for TensorFlow.
Expand Down Expand Up @@ -193,6 +192,16 @@ tf_cc_test(
],
)

# virtual targets since nested select statements not possible
tf_kernel_library(
name = "virtual_nccl",
deps = if_cuda(["@local_config_nccl//:nccl"]),
)
tf_kernel_library(
name = "virtual_rccl",
deps = if_rocm(["@local_config_rocm//rocm:rccl"]),
)

tf_kernel_library(
name = "collective_ops",
srcs = if_nccl([
Expand All @@ -207,9 +216,10 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
] + if_nccl([
":virtual_nccl",
":virtual_rccl",
"//tensorflow/core/nccl:nccl_lib",
] + nccl_config()
),
]),
)

tf_cuda_cc_test(
Expand Down Expand Up @@ -376,16 +386,19 @@ cc_library(

tf_kernel_library(
name = "nccl_kernels",
srcs = if_nccl([
srcs = if_cuda_or_rocm([
"nccl_ops.cc",
]),
deps = if_nccl([
deps = if_cuda([
"@local_config_nccl//:nccl",
]) + if_rocm([
"@local_config_rocm//rocm:rccl",
]) + if_cuda_or_rocm([
"//tensorflow/core/nccl:nccl_lib",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:gpu_runtime",
] + nccl_config()
),
]),
)

cc_library(
Expand Down Expand Up @@ -3258,7 +3271,7 @@ cc_library(
":cholesky_grad",
":cholesky_op",
":determinant_op",
":einsum_op", # disabling to move past build errors (no_rocm, no_cuda)
":einsum_op",
":lu_op",
":matrix_exponential_op",
":matrix_inverse_op",
Expand Down
29 changes: 16 additions & 13 deletions tensorflow/core/nccl/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "if_nccl")
load("//tensorflow:tensorflow.bzl", "nccl_config")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("//tensorflow:tensorflow.bzl", "if_cuda_or_rocm")
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"tf_cuda_tests_tags",
Expand All @@ -22,23 +21,26 @@ exports_files(["LICENSE"])

cc_library(
name = "nccl_lib",
srcs = if_nccl([
srcs = if_cuda_or_rocm([
"nccl_manager.cc",
"nccl_rewrite.cc",
]),
hdrs = if_nccl([
hdrs = if_cuda_or_rocm([
"nccl_manager.h",
]),
copts = tf_copts(),
deps = if_nccl([
deps = if_cuda([
"@local_config_nccl//:nccl"
]) + if_rocm([
"@local_config_rocm//rocm:rccl"
]) + if_cuda_or_rocm([
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor",
] + nccl_config()
),
]),
alwayslink = 1,
)

Expand All @@ -53,12 +55,13 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
] + if_nccl([
] + if_cuda_or_rocm([
":nccl_lib",
] + nccl_config()
) + if_cuda_is_configured([
]) + if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:cuda",
]) + if_rocm_is_configured([
]) + if_rocm([
"@local_config_rocm//rocm:rccl",
"//tensorflow/core:rocm",
]),
)
5 changes: 0 additions & 5 deletions tensorflow/tensorflow.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,6 @@ def if_nccl(a):
"//conditions:default": a,
})

def nccl_config():
return (if_rocm_is_configured(["@local_config_rocm//rocm:rccl"])
or if_cuda_is_configured(["@local_config_nccl//:nccl"])
or [])

def get_win_copts(is_external = False):
WINDOWS_COPTS = [
"/DPLATFORM_WINDOWS",
Expand Down

0 comments on commit aa50da8

Please sign in to comment.