Merge pull request #556 from jeffdaily/rccl_stream_hack

create nccl stream after compute stream, add to GPUDeviceContext
whchung authored Jul 19, 2019
2 parents a15f429 + ab67c33 commit 593010e
Showing 8 changed files with 65 additions and 25 deletions.
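At a glance: each GPU's `StreamGroup` now creates a dedicated nccl stream right after its compute stream, `GPUDeviceContext` carries a borrowed pointer to it, and `NcclManager` reuses that stream instead of allocating its own. A minimal sketch of the resulting lookup, assuming the op runs on a GPU device so its `DeviceContext` really is a `GPUDeviceContext` (the same assumption the `static_cast`s in this diff rely on):

```cpp
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

// Sketch only: fetch the per-device nccl stream from inside a GPU op kernel.
// The pointer is borrowed; the StreamGroup creates and owns it.
se::Stream* GetNcclStream(OpKernelContext* c) {
  auto* gpu_ctx = static_cast<const GPUDeviceContext*>(c->op_device_context());
  return gpu_ctx->nccl_stream();
}

}  // namespace tensorflow
```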
10 changes: 8 additions & 2 deletions tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -249,6 +249,11 @@ class BaseGPUDevice::StreamGroupFactory {
     VLOG(2) << "Created stream[" << stream_group_within_gpu
             << "] = " << group->compute;
 
+    group->nccl = new se::Stream(executor);
+    group->nccl->Init();
+    VLOG(2) << "Created nccl_stream[" << stream_group_within_gpu
+            << "] = " << group->nccl;
+
     group->host_to_device = new se::Stream(executor);
     group->host_to_device->Init();
     VLOG(2) << "Created host_to_device_stream[" << stream_group_within_gpu
@@ -371,8 +376,9 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
     streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
         tf_gpu_id_, i, executor_, options.config.gpu_options()));
     device_contexts_.push_back(new GPUDeviceContext(
-        i, streams_.back()->compute, streams_.back()->host_to_device,
-        streams_.back()->device_to_host, streams_.back()->device_to_device));
+        i, streams_.back()->compute, streams_.back()->nccl,
+        streams_.back()->host_to_device, streams_.back()->device_to_host,
+        streams_.back()->device_to_device));
   }
 
   em_ = EventMgrFactory::Singleton()->GetEventMgr(executor_,
1 change: 1 addition & 0 deletions tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -142,6 +142,7 @@ class BaseGPUDevice : public LocalDevice {
   friend class GPUDeviceTestHelper;
   struct StreamGroup {
     se::Stream* compute = nullptr;
+    se::Stream* nccl = nullptr;
     se::Stream* host_to_device = nullptr;
     se::Stream* device_to_host = nullptr;
     gtl::InlinedVector<se::Stream*, 4> device_to_device;
5 changes: 5 additions & 0 deletions tensorflow/core/common_runtime/gpu_device_context.h
@@ -30,18 +30,21 @@ class GPUDeviceContext : public DeviceContext {
  public:
   // Does not take ownership of streams.
   GPUDeviceContext(int stream_id, se::Stream* stream,
+                   se::Stream* nccl_stream,
                    se::Stream* host_to_device_stream,
                    se::Stream* device_to_host_stream,
                    gtl::InlinedVector<se::Stream*, 4> device_to_device_stream)
       : stream_id_(stream_id),
         stream_(stream),
+        nccl_stream_(nccl_stream),
         host_to_device_stream_(host_to_device_stream),
         device_to_host_stream_(device_to_host_stream),
         device_to_device_stream_(device_to_device_stream) {}
 
   ~GPUDeviceContext() override {}
 
   se::Stream* stream() const override { return stream_; }
+  se::Stream* nccl_stream() const { return nccl_stream_; }
   se::Stream* host_to_device_stream() const { return host_to_device_stream_; }
   se::Stream* device_to_host_stream() const { return device_to_host_stream_; }
   se::Stream* device_to_device_stream(int index) const {
@@ -72,6 +75,8 @@ class GPUDeviceContext : public DeviceContext {
   // The default primary stream to use for this context.
   // All the memory belongs to this stream.
   se::Stream* stream_;
+  // The stream to use for nccl operations.
+  se::Stream* nccl_stream_;
   // The stream to use for copying data from host into GPU.
   se::Stream* host_to_device_stream_;
   // The stream to use for copying data from GPU to host.
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/BUILD
@@ -202,6 +202,7 @@ tf_kernel_library(
     prefix = "collective_ops",
     deps = [
         "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_runtime",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/profiler/lib:traceme",
@@ -382,6 +383,7 @@ tf_kernel_library(
         "//tensorflow/core/nccl:nccl_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:gpu_headers_lib",
+        "//tensorflow/core:gpu_runtime",
     ] + nccl_config()
 ),
 )
10 changes: 7 additions & 3 deletions tensorflow/core/kernels/collective_nccl_reducer.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #include "tensorflow/core/common_runtime/collective_util.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/nccl/nccl_manager.h"
 #include "tensorflow/core/platform/tracing.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -122,6 +123,9 @@ void NcclReducer::Run(StatusCallback done) {
   Notification nccl_done;
   Status nccl_status;
   auto* compute_stream = col_ctx_->op_ctx->op_device_context()->stream();
+  auto* nccl_stream =
+      static_cast<GPUDeviceContext*>(col_ctx_->op_ctx->op_device_context())
+          ->nccl_stream();
   auto* gpu_info = col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
   // `AddToAllReduce` performs consistency checks for the NCCL call and enqueues
   // the `Participant` struct locally. When all local participants with this
@@ -142,9 +146,9 @@ void NcclReducer::Run(StatusCallback done) {
     nccl_done.Notify();
   };
   auto participant = absl::make_unique<NcclManager::Participant>(
-      compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-      gpu_info->gpu_id, col_ctx_->input, col_ctx_->output,
-      col_params_->default_rank, std::move(done_callback));
+      compute_stream->parent(), compute_stream, nccl_stream,
+      gpu_info->event_mgr, gpu_info->gpu_id, col_ctx_->input,
+      col_ctx_->output, col_params_->default_rank, std::move(done_callback));
   VLOG(1) << "NcclReducer calling NcclManager::AddToAllReduce num_tasks "
           << col_params_->group.num_tasks << " current task "
           << col_params_->instance.task_names[col_params_->default_rank]
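The same `static_cast<const GPUDeviceContext*>` retrieval recurs in every kernel in nccl_ops.cc below. The downcast is unchecked; it holds only because these kernels are registered for GPU devices, where the op's `DeviceContext` is a `GPUDeviceContext`. A defensive variant (hypothetical, not part of this commit) could verify the assumption at runtime:

```cpp
// Hypothetical checked variant (not in this commit). DeviceContext is
// polymorphic, so dynamic_cast can validate the downcast; `done` is the
// async kernel's DoneCallback parameter.
auto* gpu_ctx =
    dynamic_cast<const GPUDeviceContext*>(c->op_device_context());
OP_REQUIRES_ASYNC(c, gpu_ctx != nullptr,
                  errors::Internal("nccl ops require a GPUDeviceContext"),
                  done);
auto* nccl_stream = gpu_ctx->nccl_stream();
```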
46 changes: 31 additions & 15 deletions tensorflow/core/kernels/nccl_ops.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #if TENSORFLOW_USE_ROCM
 #include "rocm/include/rccl/rccl.h"
 #endif
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/nccl/nccl_manager.h"
 
@@ -107,11 +108,14 @@ class NcclAllReduceOpKernel : public NcclReduceOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
+    auto* nccl_stream =
+        static_cast<const GPUDeviceContext*>(c->op_device_context())
+            ->nccl_stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, input, output, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, nccl_stream,
+        gpu_info->event_mgr, gpu_info->gpu_id, input, output,
+        /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddToAllReduce(
         std::move(participant),
         {GetCollectiveKey(c),
@@ -139,11 +143,14 @@ class NcclReduceSendKernel : public NcclReduceOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
+    auto* nccl_stream =
+        static_cast<const GPUDeviceContext*>(c->op_device_context())
+            ->nccl_stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, &c->input(0), /*output=*/nullptr, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, nccl_stream,
+        gpu_info->event_mgr, gpu_info->gpu_id, &c->input(0),
+        /*output=*/nullptr, /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddReduceSend(
         std::move(participant),
         {GetCollectiveKey(c),
@@ -176,11 +183,14 @@ class NcclReduceRecvKernel : public NcclReduceOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
+    auto* nccl_stream =
+        static_cast<const GPUDeviceContext*>(c->op_device_context())
+            ->nccl_stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, input, output, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, nccl_stream,
+        gpu_info->event_mgr, gpu_info->gpu_id, input, output,
+        /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddReduceRecv(
         std::move(participant),
         {GetCollectiveKey(c),
@@ -211,11 +221,14 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
+    auto* nccl_stream =
+        static_cast<const GPUDeviceContext*>(c->op_device_context())
+            ->nccl_stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, &c->input(0), /*output=*/nullptr, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, nccl_stream,
+        gpu_info->event_mgr, gpu_info->gpu_id, &c->input(0),
+        /*output=*/nullptr, /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddBroadcastSend(
         std::move(participant), {GetCollectiveKey(c),
                                  /*num_local_devices=*/num_devices(),
@@ -248,11 +261,14 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
+    auto* nccl_stream =
+        static_cast<const GPUDeviceContext*>(c->op_device_context())
+            ->nccl_stream();
     auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     auto participant = absl::make_unique<NcclManager::Participant>(
-        compute_stream->parent(), compute_stream, gpu_info->event_mgr,
-        gpu_info->gpu_id, /*input=*/nullptr, output, /*global_rank=*/-1,
-        std::move(actual_done));
+        compute_stream->parent(), compute_stream, nccl_stream,
+        gpu_info->event_mgr, gpu_info->gpu_id, /*input=*/nullptr,
+        output, /*global_rank=*/-1, std::move(actual_done));
     NcclManager::instance()->AddBroadcastRecv(
         std::move(participant), {GetCollectiveKey(c),
                                  /*num_local_devices=*/num_devices(),
8 changes: 4 additions & 4 deletions tensorflow/core/nccl/nccl_manager.cc
@@ -69,7 +69,7 @@ struct NcclManager::NcclStream : public core::RefCounted {
 
   // The stream on which to run the nccl collective.
   // This is a different stream than the tensorflow compute stream.
-  std::unique_ptr<se::Stream> stream;
+  se::Stream* stream = nullptr;
 
   // `mu` protects access to `pending_launches_`, which is the list of
   // collectives ready but whose kernels are yet to be launched. When the
@@ -297,6 +297,7 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
   std::vector<int> devices(collective->num_local_devices);
   for (int i = 0; i < collective->num_local_devices; ++i) {
     auto* executor = collective->participants[i]->executor;
+    auto* borrowed_nccl_stream = collective->participants[i]->nccl_stream;
 
     // Find a communication stream to use for the device.
     auto& streams = device_to_comm_streams_[executor];
@@ -310,8 +311,7 @@ Status NcclManager::GetCommunicator(NcclManager::Collective* collective,
     if (nccl_stream == nullptr) {
       nccl_stream = new NcclStream();
       nccl_stream->executor = executor;
-      nccl_stream->stream.reset(new se::Stream(executor));
-      nccl_stream->stream->Init();
+      nccl_stream->stream = borrowed_nccl_stream;
 
       streams.emplace_back(nccl_stream);
       used_streams.insert(nccl_stream);
@@ -578,7 +578,7 @@ void NcclManager::RunCollective(Collective* collective) {
 }
 
 void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
-  se::Stream* comm_stream = nccl_stream->stream.get();
+  se::Stream* comm_stream = nccl_stream->stream;
   ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
   const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
       comm_stream->implementation()->GpuStreamMemberHack());
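Worth spelling out the ownership change above: `NcclStream::stream` was an owned `std::unique_ptr` that `GetCommunicator` populated with a freshly created stream; it is now a raw pointer to the stream the `StreamGroupFactory` created. My reading of the resulting lifetime contract, as a commented sketch (not text from the commit):

```cpp
// Lifetime sketch after this change; assumes StreamGroup streams live for
// the life of the process, like the group's other streams do.
//
//   StreamGroupFactory        --owns-->    group->nccl   (created, Init()ed)
//   GPUDeviceContext          --borrows--> nccl_stream_  ("does not take ownership")
//   NcclManager::Participant  --borrows--> nccl_stream   (DCHECKed non-null)
//   NcclManager::NcclStream   --borrows--> stream        (was an owned unique_ptr)
//
// Consequence: deleting an NcclStream no longer destroys the underlying
// se::Stream, and collectives launch on the stream the device context
// already hands out instead of one NcclManager created on the side.
```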
8 changes: 7 additions & 1 deletion tensorflow/core/nccl/nccl_manager.h
@@ -60,11 +60,13 @@ class NcclManager {
 
   // A participant in a Collective.
   struct Participant {
-    Participant(se::StreamExecutor* executor, se::Stream* tensor_stream,
+    Participant(se::StreamExecutor* executor,
+                se::Stream* tensor_stream, se::Stream* nccl_stream,
                 EventMgr* event_mgr, int gpu_device_id, const Tensor* input,
                 Tensor* output, int global_rank, DoneCallback done_callback)
         : executor(executor),
           tensor_stream(tensor_stream),
+          nccl_stream(nccl_stream),
           event_mgr(event_mgr),
           gpu_device_id(gpu_device_id),
           input(input),
@@ -75,6 +77,7 @@ class NcclManager {
       DCHECK(executor != nullptr);
       DCHECK(event_mgr != nullptr);
       DCHECK(tensor_stream != nullptr);
+      DCHECK(nccl_stream != nullptr);
     }
 
     // StreamExecutor for the device. Expected to be live for process lifetime.
@@ -88,6 +91,9 @@ class NcclManager {
     // `done_callback` is called.
     se::Stream* const tensor_stream;
 
+    // `nccl_stream` is the stream for all nccl operations.
+    se::Stream* const nccl_stream;
+
     // EventMgr which polls on executor.
     // Owned by the caller, who must keep it live until `done_callback` is
     // called.
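With the widened `Participant` constructor, every call site passes the nccl stream between the compute stream and the `EventMgr`, and the new `DCHECK` turns a null nccl stream into a hard error. A hedged usage sketch folding the call-site pattern above into a helper (`MakeNcclParticipant` is hypothetical, introduced here only for illustration):

```cpp
// Hypothetical helper (not in this commit) capturing the invariant the new
// DCHECK enforces: both streams are non-null and come from the same device's
// StreamGroup.
std::unique_ptr<NcclManager::Participant> MakeNcclParticipant(
    OpKernelContext* c, const Tensor* input, Tensor* output,
    NcclManager::DoneCallback done) {
  auto* ctx = static_cast<const GPUDeviceContext*>(c->op_device_context());
  auto* gpu_info = c->device()->tensorflow_gpu_device_info();
  return absl::make_unique<NcclManager::Participant>(
      ctx->stream()->parent(),  // executor for the device
      ctx->stream(),            // compute stream producing the tensors
      ctx->nccl_stream(),       // dedicated collective stream
      gpu_info->event_mgr, gpu_info->gpu_id, input, output,
      /*global_rank=*/-1, std::move(done));
}
```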
