From 0b44032de40d98b08392c5f5f954191063d80a13 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Tue, 23 Jan 2024 10:24:07 +0000 Subject: [PATCH] [ROCm] Fix build after removal of cuda_kernels library --- xla/stream_executor/gpu/BUILD | 4 +++ xla/stream_executor/gpu/gpu_command_buffer.cc | 4 +++ xla/stream_executor/rocm/BUILD | 6 ++++ .../rocm/hip_noop_kernel.cu.cc | 28 +++++++++++++++++++ 4 files changed, 42 insertions(+) create mode 100644 xla/stream_executor/rocm/hip_noop_kernel.cu.cc diff --git a/xla/stream_executor/gpu/BUILD b/xla/stream_executor/gpu/BUILD index 8e8e6e88b6fec..ba70a62836b29 100644 --- a/xla/stream_executor/gpu/BUILD +++ b/xla/stream_executor/gpu/BUILD @@ -147,6 +147,9 @@ cc_library( name = "gpu_command_buffer", srcs = if_gpu_is_configured(["gpu_command_buffer.cc"]), hdrs = if_gpu_is_configured(["gpu_command_buffer.h"]), + local_defines = if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":gpu_driver_header", ":gpu_executor_header", @@ -174,6 +177,7 @@ cc_library( "//xla/stream_executor/cuda:cuda_conditional_kernels", ]) + if_rocm_is_configured([ "//xla/stream_executor/rocm:hip_conditional_kernels", + "//xla/stream_executor/rocm:hip_noop_kernel", ]), ) diff --git a/xla/stream_executor/gpu/gpu_command_buffer.cc b/xla/stream_executor/gpu/gpu_command_buffer.cc index 46047226ffc98..1663fbf6e77d5 100644 --- a/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -197,7 +197,11 @@ absl::StatusOr<GpuCommandBuffer::NoOpKernel*> GpuCommandBuffer::GetNoOpKernel( auto noop_kernel = std::make_unique<NoOpKernel>(executor); MultiKernelLoaderSpec spec(/*arity=*/0); +#if !defined(TENSORFLOW_USE_ROCM) spec.AddCudaPtxInMemory(gpu::kNoOpKernel, "noop"); +#else + spec.AddInProcessSymbol(gpu::GetNoOpKernel(), "noop"); +#endif // TENSORFLOW_USE_ROCM TF_RETURN_IF_ERROR(executor->GetKernel(spec, noop_kernel.get())); noop_kernel_ = std::move(noop_kernel); diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD index 
866e67976ea3e..0130053b92755 100644 --- a/xla/stream_executor/rocm/BUILD +++ b/xla/stream_executor/rocm/BUILD @@ -176,6 +176,12 @@ rocm_library( deps = if_rocm_is_configured(["@local_config_rocm//rocm:rocm_headers"]), ) +rocm_library( + name = "hip_noop_kernel", + srcs = if_rocm_is_configured(["hip_noop_kernel.cu.cc"]), + deps = if_rocm_is_configured(["@local_config_rocm//rocm:rocm_headers"]), +) + cc_library( name = "rocm_platform", srcs = if_rocm_is_configured(["rocm_platform.cc"]), diff --git a/xla/stream_executor/rocm/hip_noop_kernel.cu.cc b/xla/stream_executor/rocm/hip_noop_kernel.cu.cc new file mode 100644 index 0000000000000..4af3776c81b89 --- /dev/null +++ b/xla/stream_executor/rocm/hip_noop_kernel.cu.cc @@ -0,0 +1,28 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <hip/hip_runtime.h> + +namespace stream_executor { +namespace rocm { +namespace { + +__global__ void NoOp() {} + +} // namespace +} // namespace rocm + +namespace gpu { +void* GetNoOpKernel() { return reinterpret_cast<void*>(&rocm::NoOp); } +} // namespace gpu + +} // namespace stream_executor \ No newline at end of file