From 2f696ed14d60ffff4c3c57f98a3a0298e653ea48 Mon Sep 17 00:00:00 2001
From: Hanbin Hu
Date: Sat, 10 Apr 2021 20:40:53 -0700
Subject: [PATCH] Improve neighbor allreduce (#78)

* Fixed the self_weight under empty receiving case
* Enable empty send neighbors and fix HalfTensor for recv_size==0
* Fixed the self_weight under empty receiving case
* Enable empty send neighbors and fix HalfTensor for recv_size==0
* Rename neighbor_weights to src_weights, and send_neighbors to dst_weights for neighbor_allreduce
* A script to test existing examples
* Accept dst_weights as Dict, and reorganize DoNeighborAllreduce
* Reorganize CheckNeighborSendRecvPattern
* Fix timeline_ptr for NCCL
* Fix timeline_ptr for NCCL
* Put dst_weights information into TensorTableEntry
* First version of neighbor_allreduce dst_weight; known issues: fusion not implemented, CUDA data_weight problem
* Add some delay after data_weight as a temporary solution
* CPU fusion for dst_weighted added
* Add ReadyEvent for dst_weight for single-entry neighbor_allreduce
* Remove const identifier for tensor dtype as it is meaningless
* Add CUDA source for scale buffer
* Scale buffer to modify itself
* Add .o files to .gitignore
* dst_weight using CUDA for fused entry & compile flow in Python setup.py
* Make `make clean` remove *.o files generated by nvcc
* Add fix for NCCL single entry
* Make setup.py more robust
* Add timeout and CUDA check
* Move test example
* Fix NCCL-side dst_weight fusion bug
* Use the agg backend to make matplotlib more stable
* Address comments for setup.py
* Simpler logic for dst_weighting_enabled and weighted_average_computation
* Better consideration for weight buffer size
* Make src_weights a std::map, and simplify the logic for PerformNeighborAllreduceCallback
* Add TODOs #80 and #81, and simplify the logic for dst_weight
* Wrap CheckNeighborSendRecvPattern again
* Add two more TODOs
* Address review comments
* Add condition variable to control the loop (#88)
* Add condition variable to control the loop
* Minor update on topology_setting in global_state
* Add missing header
* Change cv.wait to cv.wait_for with a 10-second timeout
* Address comment and remove the resetVersionWinMem adjustment in ibfrun

Co-authored-by: ybc
---
 .gitignore | 1 +
 Makefile | 10 +-
 bluefog/common/common.h | 9 +-
 bluefog/common/cuda/cuda_kernels.cu | 120 +++++++++
 bluefog/common/cuda/cuda_kernels.h | 33 +++
 bluefog/common/global_state.h | 14 +-
 bluefog/common/mpi_context.cc | 11 +-
 bluefog/common/mpi_context.h | 15 +-
 bluefog/common/mpi_controller.cc | 334 +++++++++++++++++++-------
 bluefog/common/mpi_controller.h | 15 +-
 bluefog/common/nccl_controller.cc | 138 +++++++++--
 bluefog/common/nccl_controller.h | 8 +
 bluefog/common/operations.cc | 170 ++++++-------
 bluefog/common/operations.h | 2 +
 bluefog/common/tensor_queue.cc | 30 +++
 bluefog/common/tensor_queue.h | 44 +++-
 bluefog/run/interactive_run.py | 3 +-
 bluefog/tensorflow/adapter.cc | 4 +-
 bluefog/tensorflow/adapter.h | 4 +-
 bluefog/torch/adapter.cc | 12 +-
 bluefog/torch/adapter.h | 5 +-
 bluefog/torch/mpi_ops.cc | 297 +++++++++--------
 bluefog/torch/mpi_ops.h | 13 +-
 bluefog/torch/mpi_ops.py | 117 ++++-----
 bluefog/torch/optimizers.py | 16 +-
 examples/pytorch_average_consensus.py | 4 +-
 examples/pytorch_benchmark.py | 4 +-
 examples/pytorch_mnist.py | 4 +-
 examples/pytorch_optimization.py | 19 +-
 examples/pytorch_resnet.py | 4 +-
 setup.py | 42 +++-
 test/test_all_example.sh | 115 +++++++
 test/torch_ops_test.py | 120 ++++++++-
 test/torch_optimizer_test.py | 4 +-
 34 files changed, 1233 insertions(+), 508 deletions(-)
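For readers of this patch, a minimal usage sketch of the renamed neighbor_allreduce arguments (self_weight, src_weights, dst_weights, with dst_weights accepted as a dict) follows. The keyword names track the commit bullets above; the exact signature, defaults, and launcher setup are assumptions for illustration, not a copy of the shipped API.

    # Illustrative sketch only: a dynamic-topology neighbor_allreduce call on a ring.
    # Keyword names (self_weight, src_weights, dst_weights) follow the renames described
    # in this patch; treat the precise signature as an assumption.
    import torch
    import bluefog.torch as bf

    bf.init()
    rank, size = bf.rank(), bf.size()

    x = torch.ones(4) * rank

    recv_from = (rank - 1) % size   # receive from the left neighbor
    send_to = (rank + 1) % size     # send to the right neighbor

    y = bf.neighbor_allreduce(
        x,
        self_weight=0.5,                # weight applied to the local tensor
        src_weights={recv_from: 0.5},   # weights for tensors received from neighbors
        dst_weights={send_to: 1.0},     # per-destination scaling applied before sending
    )
    print(f"rank {rank}: {y}")

Run under the usual MPI/bfrun launcher; weights are chosen here so that self_weight plus the src_weights sum to one.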
create mode 100644 bluefog/common/cuda/cuda_kernels.cu create mode 100644 bluefog/common/cuda/cuda_kernels.h create mode 100755 test/test_all_example.sh diff --git a/.gitignore b/.gitignore index 56029e30..bb14086e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ # C extensions *.so +*.o # Distribution / packaging .Python diff --git a/Makefile b/Makefile index 5b3ebdfb..dadbdf88 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ test_torch: test_torch_basic test_torch_ops test_torch_win_ops test_torch_optimi test_tensorflow: test_tensorflow_basic test_tensorflow_ops test_all: test_torch test_tensorflow -clean: clean_build clean_so +clean: clean_build clean_so clean_o .PHONY: test_torch_basic test_torch_basic: @@ -51,8 +51,12 @@ test_tensorflow_ops: .PHONY: clean_build clean_build: - rm -R build + rm -fR build .PHONY: clean_so clean_so: - rm ./bluefog/torch/mpi_lib.*.so + rm -f ./bluefog/torch/mpi_lib.*.so + +.PHONY: clean_o +clean_o: + rm -f ./bluefog/common/cuda/*.o \ No newline at end of file diff --git a/bluefog/common/common.h b/bluefog/common/common.h index 8843761c..ab78dac9 100644 --- a/bluefog/common/common.h +++ b/bluefog/common/common.h @@ -209,10 +209,10 @@ class TensorShape { class Tensor { public: - virtual const DataType dtype() const = 0; + virtual DataType dtype() const = 0; virtual const TensorShape shape() const = 0; virtual const void* data() const = 0; - virtual std::shared_ptr data_weight(float weight) = 0; + virtual std::unique_ptr data_weight(float weight) = 0; virtual int64_t size() const = 0; virtual ~Tensor() = default; }; @@ -241,6 +241,7 @@ class OpContext { std::shared_ptr* tensor) = 0; virtual Status AllocateZeros(int64_t num_elements, DataType dtype, std::shared_ptr* tensor) = 0; + virtual std::shared_ptr RecordReadyEvent(int device) = 0; virtual Framework framework() const = 0; virtual ~OpContext() = default; }; @@ -279,10 +280,14 @@ struct TensorTableEntry { // Neighbors for dynamic neighbor_allreduce. std::shared_ptr> send_neighbors; std::shared_ptr> recv_neighbors; + std::shared_ptr> send_weights; // Boolean value if dynamic neighbor is enabled. bool dynamic_neighbors_enabled = false; + // Boolean value for enabling destination(send) weighting operation or not. + bool dst_weighting_enabled = false; + // Boolean value for enabling topology check. bool enable_topo_check = false; diff --git a/bluefog/common/cuda/cuda_kernels.cu b/bluefog/common/cuda/cuda_kernels.cu new file mode 100644 index 00000000..144b21e8 --- /dev/null +++ b/bluefog/common/cuda/cuda_kernels.cu @@ -0,0 +1,120 @@ +// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "cuda_kernels.h" + +#include +#include + +namespace bluefog { +namespace common { + +template +__global__ void scale_buffer_k(T* buffer, int64_t num_elements, const TS scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + buffer[i] *= scale_factor; + } +} + +// Specialization for half2 +__global__ void scale_buffer_half2_k(__half* buffer, int64_t num_elements, const __half scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + +#if __CUDA_ARCH__ > 530 + __half2* buffer_h2 = reinterpret_cast<__half2 *>(buffer); + const __half2 scale_factor_h2 = __halves2half2(scale_factor, scale_factor); + + for (size_t i = idx; i < num_elements / 2; i += gridDim.x * blockDim.x) { + buffer_h2[i] = __hmul2(scale_factor_h2, buffer_h2[i]); + } + + // Deal with last element if num_elements is odd + if (idx == 0 && num_elements % 2) { + buffer[num_elements - 1] = __hmul(scale_factor, buffer[num_elements - 1]); + } +#else + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + buffer[i] = __float2half(__half2float(scale_factor) * __half2float(buffer[i])); + } +#endif +} + +// Specialization for architectures without __half compute +template<> +__global__ void scale_buffer_k(__half* buffer, int64_t num_elements, const __half scale_factor) { + + const size_t idx = static_cast(blockDim.x) * blockIdx.x + threadIdx.x; + +#if __CUDA_ARCH__ > 530 + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + buffer[i] *= scale_factor; + } +#else + for (size_t i = idx; i < num_elements; i += gridDim.x * blockDim.x) { + buffer[i] = __float2half(__half2float(scale_factor) * __half2float(buffer[i])); + } +#endif +} + +#define NTHREADS_SCALE_BUFFER_KERNEL 512 +void ScaleBufferCudaImpl(double scale_factor, void* buffer_data, const int64_t num_elements, + DataType dtype, cudaStream_t stream) { + const int64_t blocks = (num_elements + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL; + const int threads = NTHREADS_SCALE_BUFFER_KERNEL; + switch (dtype) { + case DataType::BLUEFOG_UINT8: + scale_buffer_k<<>>((uint8_t*) buffer_data, num_elements, scale_factor); + break; + case DataType::BLUEFOG_INT8: + scale_buffer_k<<>>((int8_t*) buffer_data, num_elements, scale_factor); + break; + case DataType::BLUEFOG_INT32: + scale_buffer_k<<>>((int32_t*) buffer_data, num_elements, scale_factor); + break; + case DataType::BLUEFOG_INT64: + scale_buffer_k<<>>((int64_t*) buffer_data, num_elements, scale_factor); + break; + case DataType::BLUEFOG_FLOAT16: + { + __half scale_factor_half = __float2half((float) scale_factor); + if ((size_t) buffer_data % 4 == 0) { + // If alignment allows, use half2 specialized kernel + int64_t num_elements_h2 = (num_elements + 1) / 2; + int64_t blocks_h2 = (num_elements_h2 + NTHREADS_SCALE_BUFFER_KERNEL - 1) / NTHREADS_SCALE_BUFFER_KERNEL; + scale_buffer_half2_k<<>>((__half*) buffer_data, num_elements, scale_factor_half); + } else { + scale_buffer_k<<>>((__half*) buffer_data, num_elements, scale_factor_half); + } + break; + } + case DataType::BLUEFOG_FLOAT32: + scale_buffer_k<<>>((float*) buffer_data, num_elements, (float) scale_factor); + break; + case DataType::BLUEFOG_FLOAT64: + scale_buffer_k<<>>((double*) buffer_data, num_elements, scale_factor); + break; + default: + throw std::logic_error("Type " + DataType_Name(dtype) + + " 
not supported by ScaleBufferCudaImpl."); + } +} + +} // namespace common +} // namespace bluefog + diff --git a/bluefog/common/cuda/cuda_kernels.h b/bluefog/common/cuda/cuda_kernels.h new file mode 100644 index 00000000..2ba27a09 --- /dev/null +++ b/bluefog/common/cuda/cuda_kernels.h @@ -0,0 +1,33 @@ +// Copyright (C) 2020 NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef CUDA_KERNELS_H +#define CUDA_KERNELS_H + +#include + +#include "../common.h" + +namespace bluefog { +namespace common { + +// Scales buffer by scalar +void ScaleBufferCudaImpl(double scale_factor, void* buffer_data, const int64_t num_elements, + DataType dtype, cudaStream_t stream); + +} // namespace common +} // namespace bluefog + +#endif // CUDA_KERNELS_H \ No newline at end of file diff --git a/bluefog/common/global_state.h b/bluefog/common/global_state.h index d9addcac..2683ad14 100644 --- a/bluefog/common/global_state.h +++ b/bluefog/common/global_state.h @@ -18,6 +18,7 @@ #define BLUEFOG_COMMON_GLOBAL_STATE_H #include +#include #include #include #include @@ -54,6 +55,14 @@ struct BluefogGlobalState { // Whether collective context has been completed on the background thread. std::atomic_bool initialization_done{false}; + // Condition variable and its mutex for main loop in communication thread. + std::condition_variable loop_cv; + std::mutex loop_mutex; + + // Under negotiation, the entries sends to master first and wait until it + // returns ok to run. This variable keeps the records of that. + std::atomic_int unfinished_enqueued_entries{0}; + // Timeline writer. Timeline timeline; @@ -80,13 +89,12 @@ struct BluefogGlobalState { // Threshold for Tensor Fusion. All tensors that occupy memory beyond this // threshold will be fused. int64_t tensor_fusion_threshold = 8 * 1024 * 1024; + int64_t tensor_fusion_threshold_for_dst_weight = tensor_fusion_threshold; FusionBufferManager fusion_buffer; // Because setting topology happens in the main thread instead of communication - // thread. Following three variables are to sync between them. + // thread. Not really used since the condition variable refactor. std::atomic_bool setting_topology{false}; - std::atomic_bool setting_topology_done{false}; - std::atomic_bool ready_to_setting_topology{false}; // Only exists on the coordinator node (rank zero). Maintains a vector of // requests to allreduce every tensor (keyed by tensor name). 
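The loop_cv, loop_mutex, and unfinished_enqueued_entries fields added to BluefogGlobalState above replace the old busy-wait topology flags with a condition-variable-driven background loop (see the RunLoopOnce and enqueue changes later in this patch). The Python sketch below models that wait/notify pattern; it is an illustration of the control flow with made-up names, not the C++ implementation.

    # Minimal model of the condition-variable loop: the background thread sleeps on a
    # condition variable with a 10-second timeout and wakes when work is enqueued or
    # shutdown is requested. Names here are illustrative.
    import threading
    from collections import deque

    class LoopState:
        def __init__(self):
            self.loop_cv = threading.Condition()
            self.queue = deque()
            self.unfinished = 0
            self.shut_down = False

        def enqueue(self, item):
            with self.loop_cv:
                self.queue.append(item)
                self.loop_cv.notify()   # mirrors loop_cv.notify_one() after AddToTensorQueue

        def run_loop_once(self):
            with self.loop_cv:
                self.loop_cv.wait_for(
                    lambda: self.shut_down or self.unfinished > 0 or len(self.queue) > 0,
                    timeout=10.0,       # mirrors cv.wait_for with a 10-second timeout
                )
                work = list(self.queue)
                self.queue.clear()
            # Never hold the lock while performing the (possibly remote) operations.
            for item in work:
                pass  # perform the operation here
            return not self.shut_down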
diff --git a/bluefog/common/mpi_context.cc b/bluefog/common/mpi_context.cc index 5d6cefc8..e5ef0e6c 100644 --- a/bluefog/common/mpi_context.cc +++ b/bluefog/common/mpi_context.cc @@ -75,7 +75,7 @@ bool WindowManager::InitializeMutexWin(const MPI_Comm& mpi_comm) { std::vector WindowManager::GetVersionMemoryCopy() { return version_mem_; } void WindowManager::resetVersionWinMem(int initialValue /*=0*/) { - for (int i = 0; i < version_mem_.size(); i++) { + for (size_t i = 0; i < version_mem_.size(); i++) { version_mem_[i] = initialValue; } } @@ -222,7 +222,7 @@ MPI_Op MPIContext::GetMPISumOp(DataType dtype) { return dtype == DataType::BLUEFOG_FLOAT16 ? mpi_float16_sum : MPI_SUM; } -MPI_Comm MPIContext::GetMPICommunicator(Communicator comm) { +MPI_Comm MPIContext::GetMPICommunicator(Communicator comm) const { switch (comm) { case Communicator::GLOBAL: return mpi_comm; @@ -332,6 +332,13 @@ void MPIContext::Initialize(const std::vector& ranks, // Create custom MPI float16 summation op. MPI_Op_create(&float16_sum, 1, &mpi_float16_sum); + +#if HAVE_CUDA + int greatest_priority; + CUDACHECK(cudaDeviceGetStreamPriorityRange(NULL, &greatest_priority)); + CUDACHECK(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, + greatest_priority)); +#endif } void MPIContext::Finalize(MPIContextManager& ctx_manager) { diff --git a/bluefog/common/mpi_context.h b/bluefog/common/mpi_context.h index 6ee2a25c..c220d497 100644 --- a/bluefog/common/mpi_context.h +++ b/bluefog/common/mpi_context.h @@ -22,6 +22,10 @@ #include #include +#if HAVE_CUDA +#include "cuda_runtime.h" +#endif + #include "common.h" #include "mpi.h" @@ -144,7 +148,7 @@ class MPIContext { MPI_Op GetMPISumOp(DataType dtype); - MPI_Comm GetMPICommunicator(Communicator comm); + MPI_Comm GetMPICommunicator(Communicator comm) const; int GetMPITypeSize(DataType dtype); @@ -232,8 +236,17 @@ class MPIContext { // MPI Custom data type for float16. MPI_Datatype mpi_float16_t; MPI_Op mpi_float16_sum; + + // TODO(hhb): #80 We should use a common context for MPI and NCCL controller for CUDA usage. +#if HAVE_CUDA + // CUDA Stream + cudaStream_t stream; +#endif }; +std::string GenerateNeighborExchangeErrorMessage(const std::vector& statuses, + int nsend, int nrecv); + } // namespace common } // namespace bluefog diff --git a/bluefog/common/mpi_controller.cc b/bluefog/common/mpi_controller.cc index c3996407..cd632049 100644 --- a/bluefog/common/mpi_controller.cc +++ b/bluefog/common/mpi_controller.cc @@ -29,9 +29,40 @@ #include "operations.h" #include "timeline.h" +#if HAVE_CUDA +#include "cuda/cuda_kernels.h" +#endif + namespace bluefog { namespace common { +namespace { + +template +void ScaleBufferCPUImpl(T* buffer, int64_t num_elements, TS scale_factor) { + for (int64_t i = 0; i < num_elements; ++i) { + buffer[i] = buffer[i] * scale_factor; + } +} + +void ScaleCPUBuffer(double scale_factor, void* weighted_fused_input_data, + int64_t num_elements, DataType dtype) { + switch (dtype) { + // TODO(hhb): FLOAT16 support + case DataType::BLUEFOG_FLOAT32: + ScaleBufferCPUImpl((float*) weighted_fused_input_data, num_elements, (float) scale_factor); + break; + case DataType::BLUEFOG_FLOAT64: + ScaleBufferCPUImpl((double*) weighted_fused_input_data, num_elements, scale_factor); + break; + default: + throw std::logic_error("Type " + DataType_Name(dtype) + + " not supported by ScaleBufferCPUImpl."); + } +} + +} + // It may be because the win_create is called at different // threads from the win_put, win_get, etc. 
After moving win_create into // communicaiton thread, it resolved. (works in Openmpi=4.0.2 and MPICH). @@ -263,15 +294,13 @@ void MPIController::NeighborAllgather(TensorTableEntry& entry) { displcmnts = new int[mpi_ctx_.neighbor_indgree_]; status = mpi_ctx_.AllocateOutput(entry, recvcounts, Communicator::GRAPH); } else { - bool is_topo_check_fail = CheckNeighborSendRecvPattern( - mpi_ctx_.size_, entry, timeline_ptr, - mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL)); - + bool is_topo_check_fail = CheckNeighborSendRecvPatternForEntry(entry, mpi_ctx_, timeline_ptr); if (is_topo_check_fail) { entry.callback(Status::InvalidArgument( "Src and dst neighbor ranks do not match")); return; } + recvcounts = new int[entry.recv_neighbors->size()]; displcmnts = new int[entry.recv_neighbors->size()]; status = mpi_ctx_.AllocateOutput(entry, recvcounts, Communicator::DYNAMIC, @@ -332,55 +361,59 @@ void MPIController::NeighborAllgather(TensorTableEntry& entry) { } // Function to check if the sending and receiving neighbors match in the topology. -bool CheckNeighborSendRecvPattern(int size, const TensorTableEntry& entry, - Timeline* timeline_ptr, const MPI_Comm& comm) { +bool CheckNeighborSendRecvPattern( + const std::vector& send_neighbors, const std::vector& recv_neighbors, + const std::string& tensor_name, int size, Timeline* timeline_ptr, const MPI_Comm& comm) { bool res = false; - // enabled the check if enable_topo_check is true and partial - // neighbor_allreduce is activated. + timeline_ptr->ActivityStart(tensor_name, "NEGOTIATION"); + // Put all the send and recv neighbors in a single vector, and obtain a send + // matrix and a recv matrix through MPI_Allgather. + bool* send_check_buf = new bool[2 * size]; + std::fill_n(send_check_buf, 2 * size, false); + bool* recv_check_buf = new bool[2 * size * size]; + for (int send_rank : send_neighbors) + send_check_buf[send_rank] = true; + for (int recv_rank : recv_neighbors) + send_check_buf[size + recv_rank] = true; + MPICHECK(MPI_Allgather(send_check_buf, size * 2, MPI_C_BOOL, + recv_check_buf, size * 2, MPI_C_BOOL, comm)); + // This checks that send matrix and transposed recv matrix should be the + // same. If same, the topology is good to go. If not, there is mismatch edge + // to be fixed. + auto GetSendIndex = [size](int i, int j) -> int { return 2*size*i+j; }; + auto GetRecvIndex = [size](int i, int j) -> int { return 2*size*i+j+size; }; + for (int i = 0; i < size; ++i) { + if (res) break; + for (int j = 0; j < size; ++j) { + if (recv_check_buf[GetSendIndex(i, j)] != + recv_check_buf[GetRecvIndex(j, i)]) { + res = true; + break; + } + } + } + delete [] send_check_buf; + delete [] recv_check_buf; + timeline_ptr->ActivityEnd(tensor_name); + return res; +} + +bool CheckNeighborSendRecvPatternForEntry(const TensorTableEntry& entry, + const MPIContext& mpi_ctx, + Timeline* timeline_ptr) { + bool is_topo_check_fail = false; if (entry.enable_topo_check && entry.dynamic_neighbors_enabled) { if (entry.is_hierarchical) { // TODO: support check. - BFLOG(INFO) << "Request to check topology for hierarchical neighbor " - << "allreduce ops but it is not supported yet."; - return res; - } - timeline_ptr->ActivityStart(entry.tensor_name, "NEGOTIATION"); - // Put all the send and recv neighbors in a single vector, and obtain a send - // matrix and a recv matrix through MPI_Allgather. 
- bool* send_check_buf = new bool[2 * size]; - std::fill_n(send_check_buf, 2 * size, false); - bool* recv_check_buf = new bool[2 * size * size]; - for (int send_rank : *(entry.send_neighbors)) - send_check_buf[send_rank] = true; - for (int recv_rank : *(entry.recv_neighbors)) - send_check_buf[size + recv_rank] = true; - int ret_code = MPI_Allgather(send_check_buf, size * 2, MPI_C_BOOL, - recv_check_buf, size * 2, MPI_C_BOOL, comm); - if (ret_code != MPI_SUCCESS) { - throw std::runtime_error( - "MPI_Allgather (for dynamic neighbor_allreduce negotiation) failed, " - "see MPI output for details."); - } - // This checks that send matrix and transposed recv matrix should be the - // same. If same, the topology is good to go. If not, there is mismatch edge - // to be fixed. - auto GetSendIndex = [size](int i, int j) -> int { return 2*size*i+j; }; - auto GetRecvIndex = [size](int i, int j) -> int { return 2*size*i+j+size; }; - for (int i = 0; i < size; ++i) { - if (res) break; - for (int j = 0; j < size; ++j) { - if (recv_check_buf[GetSendIndex(i, j)] != - recv_check_buf[GetRecvIndex(j, i)]) { - res = true; - break; - } - } + BFLOG(WARNING) << "Request to check topology for hierarchical neighbor " + << "allreduce ops but it is not supported yet."; + } else { + is_topo_check_fail = CheckNeighborSendRecvPattern( + *entry.send_neighbors, *entry.recv_neighbors, entry.tensor_name, + mpi_ctx.size_, timeline_ptr, mpi_ctx.GetMPICommunicator(Communicator::GLOBAL)); } - delete [] send_check_buf; - delete [] recv_check_buf; - timeline_ptr->ActivityEnd(entry.tensor_name); } - return res; + return is_topo_check_fail; } void MPIController::NeighborAllreduce(TensorTableEntry& entry) { @@ -402,10 +435,7 @@ void MPIController::NeighborAllreduce(TensorTableEntry& entry) { // If only partial sending is enabled, the following code block checks whether the sending // and recieving neighbors match each other when enable_topo_check is set to be True. - bool is_topo_check_fail = CheckNeighborSendRecvPattern( - mpi_ctx_.size_, entry, timeline_ptr, - mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL)); - + bool is_topo_check_fail = CheckNeighborSendRecvPatternForEntry(entry, mpi_ctx_, timeline_ptr); if (is_topo_check_fail) { entry.callback(Status::InvalidArgument( "Src(recv from) and dst(send to) neighbor ranks do not match")); @@ -419,25 +449,62 @@ void MPIController::NeighborAllreduce(TensorTableEntry& entry) { // including itself is more intuitive. std::string error_message = ""; - if (!entry.is_hierarchical) { - if (!entry.dynamic_neighbors_enabled) { - int ret_code = MPI_Neighbor_allgather( + if (!entry.is_hierarchical) { // neighbor allreduce without hierarchy + if (!entry.dynamic_neighbors_enabled) { // static topology + MPICHECK(MPI_Neighbor_allgather( sendbuf, num_elements, mpi_ctx_.GetMPIDataType(entry.tensor), buffer_data, num_elements, mpi_ctx_.GetMPIDataType(entry.output), - mpi_ctx_.GetMPICommunicator(Communicator::GRAPH)); - if (ret_code != MPI_SUCCESS) { - throw std::runtime_error( - "MPI_Neighbor_allreduce (through neighbor_allgather) failed, see " - "MPI " - "output for details."); + mpi_ctx_.GetMPICommunicator(Communicator::GRAPH))); + } else { // dynamic topology + int nsend = entry.send_neighbors->size(); + int nrecv = entry.recv_neighbors->size(); + + // Ensure the lifecycle of the weighted tensors are alive after communication. 
+ std::vector> weighted_tensors; + if (entry.dst_weighting_enabled) { + for (int i = 0; i < nsend; ++i) { + auto weighted_tensor_ptr = entry.tensor->data_weight(entry.send_weights->at(i)); + weighted_tensors.push_back(std::move(weighted_tensor_ptr)); + } } - } else { - error_message = mpi_ctx_.NeighborValueExchangeWithConstantElements( - sendbuf, (void *)entry.output->data(), num_elements, entry.output->dtype(), - entry.send_neighbors.get(), entry.recv_neighbors.get() - ); + // TODO(ybc) #83 Better design pattern for data_weight synchronization + // This ready event makes sure the data_weight computation is done before communication, as + // Pytorch CUDA stream is not synchronized with our CUDA stream, and it does nothing when + // it is running on CPU. + std::shared_ptr ready_event = + entry.context->RecordReadyEvent(entry.device); + + std::vector requests(nsend + nrecv); + std::vector statuses(nsend + nrecv); + int element_size = mpi_ctx_.GetMPITypeSize(entry.output->dtype()); + for (int i = 0; i < nrecv; ++i) { + void* recvbuf = (void*)(static_cast(entry.output->data()) + + num_elements * i * element_size); + MPICHECK(MPI_Irecv(recvbuf, num_elements, + mpi_ctx_.GetMPIDataType(entry.output), entry.recv_neighbors->at(i), + /*tag=*/mpi_ctx_.rank_ + entry.recv_neighbors->at(i), + mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i + nsend])); + } + + if (entry.dst_weighting_enabled) { + while ((ready_event != nullptr) && !ready_event->Ready()) { + std::this_thread::sleep_for(std::chrono::nanoseconds(100)); + } + } + for (int i = 0; i < nsend; ++i) { + const void* buffer_send = (entry.dst_weighting_enabled) + ? weighted_tensors[i]->data() + : sendbuf; + MPICHECK(MPI_Isend(buffer_send, num_elements, + mpi_ctx_.GetMPIDataType(entry.tensor), entry.send_neighbors->at(i), + /*tag=*/mpi_ctx_.rank_ + entry.send_neighbors->at(i), + mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i])); + } + MPI_Waitall(nsend + nrecv, requests.data(), statuses.data()); + error_message = + GenerateNeighborExchangeErrorMessage(statuses, nsend, nrecv); } - } else { + } else { // hierarchical neighbor allreduce if (entry.send_neighbors->empty()) { throw std::runtime_error( "Under hierarchical neighbor_allreduce, argument " @@ -448,6 +515,11 @@ void MPIController::NeighborAllreduce(TensorTableEntry& entry) { "Local size is smaller than 2, in this case, you should use " "neighbor_allreduce instead of hierarchical_neighbor_allreduce."); } + if (entry.dst_weighting_enabled) { + throw std::runtime_error( + "Under hierarchical neighbor_allreduce, argument " + "dst_weight should not be enabled for now."); + } // 1. In-place allreduce MPI_Allreduce(MPI_IN_PLACE, (void*)sendbuf, num_elements, mpi_ctx_.GetMPIDataType(entry.tensor), MPI_SUM, @@ -537,10 +609,8 @@ void MPIController::NeighborAllreduce(std::vector& entries) { // If only partial sending is enabled, the following code block checks whether // the sending and recieving neighbors match each other when enable_topo_check // is set to be True. 
- bool is_topo_check_fail = CheckNeighborSendRecvPattern( - mpi_ctx_.size_, first_entry, timeline_ptr, - mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL)); - + bool is_topo_check_fail = CheckNeighborSendRecvPatternForEntry(first_entry, mpi_ctx_, + timeline_ptr); if (is_topo_check_fail) { for (auto& entry : entries) { entry.callback(Status::InvalidArgument( @@ -555,6 +625,15 @@ void MPIController::NeighborAllreduce(std::vector& entries) { timeline_ptr->ActivityEndAll(entries); const void* fused_input_data = buffer_data; + const void* weighted_fused_input_data = nullptr; + if (first_entry.dst_weighting_enabled) { + // Generate weighted data fusion for sending + timeline_ptr->ActivityStartAll(entries, "MEMCPY_IN_WEIGHT_FUSION_BUFFER"); + weighted_fused_input_data = GenerateWeightedFusedInputData(fused_input_data, first_entry, + num_elements, element_size); + timeline_ptr->ActivityEndAll(entries); + } + // Unlike allreduce, the storage for neighbor_allreduce in fusion buffer // is like [t_1, t_2 | t_1_n1, t_2_n1, t_1_n2, t_2_n2]. // Here t_1 and t_2 means self tensor 1 and 2 and _n1 and _n2 means the @@ -569,24 +648,45 @@ void MPIController::NeighborAllreduce(std::vector& entries) { // including itself is more intuitive. std::string error_message = ""; - if (!first_entry.is_hierarchical) { - if (!first_entry.dynamic_neighbors_enabled) { - int ret_code = MPI_Neighbor_allgather( + if (!first_entry.is_hierarchical) { // neighbor allreduce without hierarchy + if (!first_entry.dynamic_neighbors_enabled) { // static topology + MPICHECK(MPI_Neighbor_allgather( fused_input_data, num_elements, mpi_ctx_.GetMPIDataType(first_entry.tensor), buffer_data, num_elements, mpi_ctx_.GetMPIDataType(first_entry.output), - mpi_ctx_.GetMPICommunicator(Communicator::GRAPH)); - if (ret_code != MPI_SUCCESS) { - throw std::runtime_error( - "MPI_Neighbor_allreduce (through neighbor_allgather) failed, see MPI " - "output for details."); + mpi_ctx_.GetMPICommunicator(Communicator::GRAPH))); + } else { // dynamic topology + int nsend = first_entry.send_neighbors->size(); + int nrecv = first_entry.recv_neighbors->size(); + std::vector requests(nsend + nrecv); + std::vector statuses(nsend + nrecv); + for (int i = 0; i < nrecv; ++i) { + void* recvbuf = + (void*)((uint8_t*)buffer_data + num_elements * i * element_size); + MPICHECK(MPI_Irecv(recvbuf, num_elements, + mpi_ctx_.GetMPIDataType(first_entry.output), first_entry.recv_neighbors->at(i), + /*tag=*/mpi_ctx_.rank_ + first_entry.recv_neighbors->at(i), + mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i + nsend])); } - } else { - error_message = mpi_ctx_.NeighborValueExchangeWithConstantElements( - fused_input_data, buffer_data, num_elements, first_entry.output->dtype(), - first_entry.send_neighbors.get(), first_entry.recv_neighbors.get() - ); +#if HAVE_CUDA + if (first_entry.dst_weighting_enabled && first_entry.device != CPU_DEVICE_ID) { + cudaStreamSynchronize(mpi_ctx_.stream); + } +#endif + for (int i = 0; i < nsend; ++i) { + const void* sendbuf = + (first_entry.dst_weighting_enabled) + ? 
(void*)((uint8_t*)weighted_fused_input_data + num_elements * i * element_size) + : fused_input_data; + MPICHECK(MPI_Isend(sendbuf, num_elements, + mpi_ctx_.GetMPIDataType(first_entry.tensor), first_entry.send_neighbors->at(i), + /*tag=*/mpi_ctx_.rank_ + first_entry.send_neighbors->at(i), + mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL), &requests[i])); + } + MPI_Waitall(nsend + nrecv, requests.data(), statuses.data()); + error_message = + GenerateNeighborExchangeErrorMessage(statuses, nsend, nrecv); } - } else { + } else { // hierarchical neighbor allreduce if (first_entry.send_neighbors->empty()) { throw std::runtime_error( "Under hierarchical neighbor_allreduce, argument " @@ -597,6 +697,11 @@ void MPIController::NeighborAllreduce(std::vector& entries) { "Local size is smaller than 2, in this case, you should use " "neighbor_allreduce instead of hierarchical_neighbor_allreduce."); } + if (first_entry.dst_weighting_enabled) { + throw std::runtime_error( + "Under hierarchical neighbor_allreduce, argument " + "dst_weight should not be enabled for now."); + } // 1. In-place allreduce MPI_Allreduce(MPI_IN_PLACE, (void*)fused_input_data, num_elements, mpi_ctx_.GetMPIDataType(first_entry.tensor), MPI_SUM, @@ -1286,6 +1391,67 @@ Status MPIController::GetWindowVersionValue(const std::string& name, return Status::OK(); } +void MPIController::MemcpyInWeightFusionBuffer( + void*& weight_buffer_data, size_t num_dst, + const void* buffer_data, int64_t num_elements, int element_size, + std::shared_ptr context, int device) { + // Access the fusion buffer. + FusionBufferManager* buffer_manager; + auto fusion_status = GetBluefogFusionBuffer(buffer_manager); + if (!fusion_status.ok()){ + throw std::runtime_error(fusion_status.reason()); + } + std::shared_ptr buffer = + buffer_manager->GetWeightBuffer(device); + weight_buffer_data = const_cast(buffer->AccessData(context)); + size_t data_size = num_elements * element_size; + + int64_t offset = 0; + for (size_t i = 0; i < num_dst; ++i) { + void* weight_buffer_data_at_offset = (uint8_t*)weight_buffer_data + offset; +#if HAVE_CUDA + if (device != CPU_DEVICE_ID) { + CUDACHECK(cudaMemcpy(weight_buffer_data_at_offset, buffer_data, data_size, + cudaMemcpyDeviceToDevice)); + } else { +#endif + std::memcpy(weight_buffer_data_at_offset, buffer_data, data_size); +#if HAVE_CUDA + } +#endif + offset += data_size; + } +} + +const void* MPIController::GenerateWeightedFusedInputData(const void* fused_input_data, + const TensorTableEntry& entry, + int64_t num_elements, int element_size) { + // Given a fused_input_data like [t_1, t_2], the storage for neighbor_allreduce in + // weighted fused input data is like [t_1_w1, t_2_w1 | t_1_w2, t_2_w2 | t_1_w3, t_2_w3]. + // Here t_1 and t_2 means self tensor 1 and 2 and _w1, _w2, and _w3 means the + // destination weights to destination 1, 2, and 3. 
+ void* weight_buffer_data; + MemcpyInWeightFusionBuffer(weight_buffer_data, entry.send_neighbors->size(), + fused_input_data, num_elements, element_size, + entry.context, entry.device); + int64_t offset = 0; + for (size_t i = 0; i < entry.send_neighbors->size(); ++i) { + double dst_weight = entry.send_weights->at(i); + void* weight_buffer_data_offset = (uint8_t*)weight_buffer_data + offset; + if (entry.device == CPU_DEVICE_ID) { + ScaleCPUBuffer(dst_weight, weight_buffer_data_offset, num_elements, + entry.tensor->dtype()); + } else { +#if HAVE_CUDA + ScaleBufferCudaImpl(dst_weight, weight_buffer_data_offset, num_elements, + entry.tensor->dtype(), mpi_ctx_.stream); +#endif + } + offset += num_elements * element_size; + } + return weight_buffer_data; +} + void MPIController::MemcpyInFusionBuffer( const std::vector& entries, void*& buffer_data, size_t& buffer_len) { diff --git a/bluefog/common/mpi_controller.h b/bluefog/common/mpi_controller.h index f42fc182..b3eda810 100644 --- a/bluefog/common/mpi_controller.h +++ b/bluefog/common/mpi_controller.h @@ -27,8 +27,11 @@ namespace common { // Function to check if the sending and receiving neighbors match in the // topology. -bool CheckNeighborSendRecvPattern(int size, const TensorTableEntry& entry, - Timeline* timeline_ptr, const MPI_Comm& comm); +bool CheckNeighborSendRecvPattern( + const std::vector& send_neighbors, const std::vector& recv_neighbors, + const std::string& tensor_name, int size, Timeline* timeline_ptr, const MPI_Comm& comm); +bool CheckNeighborSendRecvPatternForEntry( + const TensorTableEntry& entry, const MPIContext& mpi_ctx, Timeline* timeline_ptr); class MPIController { public: @@ -120,6 +123,14 @@ class MPIController { const int num_recv_neighbors, const int64_t fused_data_size); + void MemcpyInWeightFusionBuffer(void*& weight_buffer_data, size_t num_dst, + const void* buffer_data, int64_t num_elements, int element_size, + std::shared_ptr context, int device); + + const void* GenerateWeightedFusedInputData(const void* fused_input_data, + const TensorTableEntry& entry, + int64_t num_elements, int element_size); + void MemcpyOutFusionBufferForInputs(const void* fused_input_data, std::vector& entries); diff --git a/bluefog/common/nccl_controller.cc b/bluefog/common/nccl_controller.cc index 5475b167..f5fbd1df 100644 --- a/bluefog/common/nccl_controller.cc +++ b/bluefog/common/nccl_controller.cc @@ -27,6 +27,8 @@ #include "operations.h" #include "timeline.h" +#include "cuda/cuda_kernels.h" + namespace bluefog { namespace common { @@ -577,15 +579,13 @@ void NCCLController::NeighborAllgather(TensorTableEntry& entry) { displcmnts = new int[mpi_ctx_.neighbor_indgree_]; status = mpi_ctx_.AllocateOutput(entry, recvcounts, Communicator::GRAPH); } else { - bool is_topo_check_fail = CheckNeighborSendRecvPattern( - mpi_ctx_.size_, entry, timeline_ptr_, - mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL)); - + bool is_topo_check_fail = CheckNeighborSendRecvPatternForEntry(entry, mpi_ctx_, timeline_ptr_); if (is_topo_check_fail) { entry.callback(Status::InvalidArgument( "Src and dst neighbor ranks do not match")); return; } + recvcounts = new int[entry.recv_neighbors->size()]; displcmnts = new int[entry.recv_neighbors->size()]; status = mpi_ctx_.AllocateOutput(entry, recvcounts, Communicator::DYNAMIC, @@ -727,9 +727,7 @@ void NCCLController::NeighborAllreduce(TensorTableEntry& entry) { // If only partial sending is enabled, the following code block checks whether // the sending and recieving neighbors match each other when enable_topo_check 
// is set to be True. - bool is_topo_check_fail = CheckNeighborSendRecvPattern( - mpi_ctx_.size_, entry, timeline_ptr_, - mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL)); + bool is_topo_check_fail = CheckNeighborSendRecvPatternForEntry(entry, mpi_ctx_, timeline_ptr_); if (is_topo_check_fail) { entry.callback(Status::InvalidArgument( "Send and recv neighbors dont' match in neighbor " @@ -737,10 +735,13 @@ void NCCLController::NeighborAllreduce(TensorTableEntry& entry) { return; } + std::vector> weighted_tensors; + // Ensure the lifecycle of the weighted tensors are alive after communication. + #if NCCL_MINOR > 6 - if (!entry.is_hierarchical) { + if (!entry.is_hierarchical) { // neighbor allreduce without hierarchy ncclGroupStart(); - if (!entry.dynamic_neighbors_enabled) { + if (!entry.dynamic_neighbors_enabled) { // static topology for (int i = 0; i < mpi_ctx_.neighbor_indgree_; i++) { int recv_rank = mpi_ctx_.neighbor_in_ranks_[i]; void* recvbuf = (void*)(static_cast(entry.output->data()) + @@ -752,7 +753,19 @@ void NCCLController::NeighborAllreduce(TensorTableEntry& entry) { NCCLCHECK(ncclSend(sendbuf, num_elements, GetNCCLDataType(entry.tensor), send_rank, nccl_ctx_.nccl_comm, nccl_ctx_.stream)); } - } else { + } else { // dynamic topology + if(entry.dst_weighting_enabled) { + for (size_t i = 0; i < entry.send_neighbors->size(); ++i) { + auto weighted_tensor_ptr = entry.tensor->data_weight(entry.send_weights->at(i)); + weighted_tensors.push_back(std::move(weighted_tensor_ptr)); + } + } + // TODO(ybc) #83 Better design pattern for data_weight synchronization + // This ready event makes sure the data_weight computation is done before communication, as + // Pytorch CUDA stream is not synchronized with our CUDA stream, and it does nothing when + // it is running on CPU. + std::shared_ptr ready_event = + entry.context->RecordReadyEvent(entry.device); for (size_t i = 0; i < entry.recv_neighbors->size(); ++i) { int recv_rank = entry.recv_neighbors->at(i); void* recvbuf = (void*)(static_cast(entry.output->data()) + @@ -760,19 +773,32 @@ void NCCLController::NeighborAllreduce(TensorTableEntry& entry) { NCCLCHECK(ncclRecv(recvbuf, num_elements, GetNCCLDataType(entry.tensor), recv_rank, nccl_ctx_.nccl_comm, nccl_ctx_.stream)); } - for (int send_rank : *entry.send_neighbors) { - NCCLCHECK(ncclSend(sendbuf, num_elements, GetNCCLDataType(entry.tensor), - send_rank, nccl_ctx_.nccl_comm, nccl_ctx_.stream)); + if(entry.dst_weighting_enabled) { + while ((ready_event != nullptr) && !ready_event->Ready()) { + std::this_thread::sleep_for(std::chrono::nanoseconds(100)); + } + } + for (size_t i = 0; i < entry.send_neighbors->size(); ++i) { + const void* buffer_send = (entry.dst_weighting_enabled) + ? weighted_tensors[i]->data() + : sendbuf; + NCCLCHECK(ncclSend(buffer_send, num_elements, GetNCCLDataType(entry.tensor), + entry.send_neighbors->at(i), nccl_ctx_.nccl_comm, nccl_ctx_.stream)); } } ncclGroupEnd(); - } else { + } else { // hierarchical neighbor allreduce if (entry.send_neighbors->empty()) { throw std::runtime_error( "Under hierarchical neighbor_allreduce, argument " "send_machine_neighbors should " "not be empty."); } + if (entry.dst_weighting_enabled) { + throw std::runtime_error( + "Under hierarchical neighbor_allreduce, argument " + "dst_weight should not be enabled for now."); + } // 1. In-place allreduce for all local ranks. Note it is sum, so we need to // divided by local size at call back stage. 
NCCLCHECK(ncclAllReduce(sendbuf, (void*)sendbuf, num_elements, @@ -812,7 +838,7 @@ void NCCLController::NeighborAllreduce(TensorTableEntry& entry) { auto tid = std::this_thread::get_id(); nccl_ctx_.finalizer_thread_pool.execute( - [this, entry, event_queue, tid]() mutable { + [this, entry, event_queue, tid, weighted_tensors]() mutable { with_device device_guard(entry.device); WaitForEvents(event_queue, {entry}, this->timeline_ptr_, tid); @@ -973,9 +999,8 @@ void NCCLController::NeighborAllreduce(std::vector& entries) { // If only partial sending is enabled, the following code block checks whether // the sending and recieving neighbors match each other when enable_topo_check // is set to be True. - bool is_topo_check_fail = CheckNeighborSendRecvPattern( - mpi_ctx_.size_, first_entry, timeline_ptr_, - mpi_ctx_.GetMPICommunicator(Communicator::GLOBAL)); + bool is_topo_check_fail = CheckNeighborSendRecvPatternForEntry(first_entry, mpi_ctx_, + timeline_ptr_); if (is_topo_check_fail) { for (auto& entry : entries) { entry.callback(Status::InvalidArgument( @@ -993,6 +1018,12 @@ void NCCLController::NeighborAllreduce(std::vector& entries) { RecordEvent(event_queue, "MEM_CPY_IN"); } + const void* weighted_fused_input_data = + (first_entry.dst_weighting_enabled) + ? GenerateWeightedFusedInputData(fused_input_data, first_entry, + num_elements, element_size) + : nullptr; + // Unlike allreduce, the storage for neighbor_allreduce in fusion buffer // is like [t_1, t_2 | t_1_n1, t_2_n1, t_1_n2, t_2_n2]. // Here t_1 and t_2 means self tensor 1 and 2 and _n1 and _n2 means the @@ -1000,9 +1031,9 @@ void NCCLController::NeighborAllreduce(std::vector& entries) { // Hence, we need to offset the buffer data to location for neighbors. buffer_data = (uint8_t*)buffer_data + num_elements * element_size; - if (!first_entry.is_hierarchical) { + if (!first_entry.is_hierarchical) { // neighbor allreduce without hierarchy ncclGroupStart(); - if (!first_entry.dynamic_neighbors_enabled) { + if (!first_entry.dynamic_neighbors_enabled) { // static topology for (int i = 0; i < mpi_ctx_.neighbor_indgree_; i++) { int recv_rank = mpi_ctx_.neighbor_in_ranks_[i]; void* recvbuf = @@ -1016,7 +1047,7 @@ void NCCLController::NeighborAllreduce(std::vector& entries) { GetNCCLDataType(first_entry.tensor), send_rank, nccl_ctx_.nccl_comm, nccl_ctx_.stream)); } - } else { + } else { // dynamic topology for (size_t i = 0; i < first_entry.recv_neighbors->size(); ++i) { int recv_rank = first_entry.recv_neighbors->at(i); void* recvbuf = @@ -1025,14 +1056,18 @@ void NCCLController::NeighborAllreduce(std::vector& entries) { GetNCCLDataType(first_entry.tensor), recv_rank, nccl_ctx_.nccl_comm, nccl_ctx_.stream)); } - for (int send_rank : *first_entry.send_neighbors) { - NCCLCHECK(ncclSend(fused_input_data, num_elements, - GetNCCLDataType(first_entry.tensor), send_rank, + for (size_t i = 0; i < first_entry.send_neighbors->size(); ++i) { + const void* sendbuf = + (first_entry.dst_weighting_enabled) + ? 
(void*)((uint8_t*)weighted_fused_input_data + num_elements * i * element_size) + : fused_input_data; + NCCLCHECK(ncclSend(sendbuf, num_elements, GetNCCLDataType(first_entry.tensor), + first_entry.send_neighbors->at(i), nccl_ctx_.nccl_comm, nccl_ctx_.stream)); } } ncclGroupEnd(); - } else { + } else { // hierarchical neighbor allreduce if (first_entry.send_neighbors->empty()) { throw std::runtime_error( "Under hierarchical neighbor_allreduce, argument " @@ -1044,6 +1079,11 @@ void NCCLController::NeighborAllreduce(std::vector& entries) { "neighbor_allreduce instead of hierarchical_neighbor_allreduce." ); } + if (first_entry.dst_weighting_enabled) { + throw std::runtime_error( + "Under hierarchical neighbor_allreduce, argument " + "dst_weight should not be enabled for now."); + } // 1. In-place allreduce for all local ranks. Note it is sum, so we need to // divided by local size at call back stage. @@ -1103,7 +1143,7 @@ void NCCLController::NeighborAllreduce(std::vector& entries) { auto tid = std::this_thread::get_id(); nccl_ctx_.finalizer_thread_pool.execute( - [this, entries, event_queue, tid, buffer_data]() mutable { + [this, entries, event_queue, tid, buffer_data, weighted_fused_input_data]() mutable { auto& first_entry = entries[0]; with_device device_guard(first_entry.device); WaitForEvents(event_queue, entries, this->timeline_ptr_, tid); @@ -1853,6 +1893,52 @@ void NCCLController::MemcpyInFusionBuffer( buffer_len = (size_t)offset; } +void NCCLController::MemcpyInWeightFusionBuffer( + void*& weight_buffer_data, size_t num_dst, + const void* buffer_data, int64_t num_elements, int element_size, + std::shared_ptr context, int device) { + // Access the fusion buffer. + FusionBufferManager* buffer_manager; + auto fusion_status = GetBluefogFusionBuffer(buffer_manager); + if (!fusion_status.ok()){ + throw std::runtime_error(fusion_status.reason()); + } + std::shared_ptr buffer = + buffer_manager->GetWeightBuffer(device); + weight_buffer_data = const_cast(buffer->AccessData(context)); + size_t data_size = num_elements * element_size; + + int64_t offset = 0; + for (size_t i = 0; i < num_dst; ++i) { + void* weight_buffer_data_at_offset = (uint8_t*)weight_buffer_data + offset; + CUDACHECK(cudaMemcpyAsync(weight_buffer_data_at_offset, buffer_data, data_size, + cudaMemcpyDeviceToDevice, nccl_ctx_.stream)); + offset += data_size; + } +} + +const void* NCCLController::GenerateWeightedFusedInputData(const void* fused_input_data, + const TensorTableEntry& entry, + int64_t num_elements, int element_size) { + // Given a fused_input_data like [t_1, t_2], the storage for neighbor_allreduce in + // weighted fused input data is like [t_1_w1, t_2_w1 | t_1_w2, t_2_w2 | t_1_w3, t_2_w3]. + // Here t_1 and t_2 means self tensor 1 and 2 and _w1, _w2, and _w3 means the + // destination weights to destination 1, 2, and 3. 
+ void* weight_buffer_data = nullptr; + MemcpyInWeightFusionBuffer(weight_buffer_data, entry.send_neighbors->size(), + fused_input_data, num_elements, element_size, + entry.context, entry.device); + int64_t offset = 0; + for (size_t i = 0; i < entry.send_neighbors->size(); ++i) { + double dst_weight = entry.send_weights->at(i); + void* weight_buffer_data_offset = (uint8_t*)weight_buffer_data + offset; + ScaleBufferCudaImpl(dst_weight, weight_buffer_data_offset, num_elements, + entry.tensor->dtype(), nccl_ctx_.stream); + offset += num_elements * element_size; + } + return weight_buffer_data; +} + void NCCLController::MemcpyOutFusionBuffer( const void* buffer_data, std::vector& entries) { int64_t offset = 0; diff --git a/bluefog/common/nccl_controller.h b/bluefog/common/nccl_controller.h index a646ddae..952cad5e 100644 --- a/bluefog/common/nccl_controller.h +++ b/bluefog/common/nccl_controller.h @@ -195,6 +195,14 @@ class NCCLController { const int num_recv_neighbors, const int64_t fused_data_size); + void MemcpyInWeightFusionBuffer(void*& weight_buffer_data, size_t num_dst, + const void* buffer_data, int64_t num_elements, int element_size, + std::shared_ptr context, int device); + + const void* GenerateWeightedFusedInputData(const void* fused_input_data, + const TensorTableEntry& entry, + int64_t num_elements, int element_size); + void MemcpyOutFusionBufferForInputs(const void* fused_input_data, std::vector& entries); diff --git a/bluefog/common/operations.cc b/bluefog/common/operations.cc index d505ad7e..578291e0 100644 --- a/bluefog/common/operations.cc +++ b/bluefog/common/operations.cc @@ -432,6 +432,20 @@ void CheckForStalledTensors(BluefogGlobalState& state) { } } +template +bool IsSameList (std::shared_ptr> n1, std::shared_ptr> n2) { + if (n1 == nullptr && n2 == nullptr) return true; + if (n1 == nullptr || n2 == nullptr) return false; + if (n1->size() != n2->size()) return false; + // The order matters as well. + for (size_t i = 0; i < n1->size(); i++) { + if (n1->at(i) != n2->at(i)) { + return false; + } + } + return true; +} + } // namespace bool RunLoopOnce(BluefogGlobalState& state); @@ -468,6 +482,8 @@ void BackgroundThreadLoop(BluefogGlobalState& state) { if (bluefog_fusion_threshold != nullptr) { state.tensor_fusion_threshold = std::strtol(bluefog_fusion_threshold, nullptr, 10); + state.tensor_fusion_threshold_for_dst_weight = + state.tensor_fusion_threshold; } // Initialize the tensor count table. No tensors are available yet. @@ -765,7 +781,19 @@ void PerformOperationWithFusion(std::vector& entries) { first_entry.context, [&]() { timeline.ActivityStartAll(entries, "INIT_FUSION_BUFFER"); }, [&]() { timeline.ActivityEndAll(entries); }); - if (!status.ok()) { + + // As the dst_weight requires extra memory to scale the tensor for each destination, therefore, + // extra memory is required. 
+ Status status_dst_weight = Status::OK(); + if (first_entry.dst_weighting_enabled) { + status_dst_weight = bluefog_global.fusion_buffer.InitializeWeightBuffer( + bluefog_global.tensor_fusion_threshold_for_dst_weight, mpi_context.size_, + first_entry.device, first_entry.context, + [&]() { timeline.ActivityStartAll(entries, "INIT_WEIGHT_FUSION_BUFFER"); }, + [&]() { timeline.ActivityEndAll(entries); }); + } + + if (!status.ok() || !status_dst_weight.ok()) { for (auto& e : entries) { e.callback(status); } @@ -825,7 +853,9 @@ void PerformOperationWithFusion(std::vector& entries) { void NegotiateOfRequestOfMaster(BluefogGlobalState& state, std::deque& message_queue_buffer, bool& should_change_topo, - bool& should_shut_down) { + bool& should_shut_down) { + state.unfinished_enqueued_entries.fetch_add(message_queue_buffer.size()); + std::vector ready_to_reduce; RequestList message_list; message_list.set_shutdown(should_shut_down); @@ -948,20 +978,6 @@ void NegotiateOfRequestOfMaster(BluefogGlobalState& state, // Attempt to add more responses to this fused response. const TensorTableEntry& entry = state.tensor_queue.GetTensorEntry(response.tensor_names()[0]); - auto IsSameNeighborList = - [](std::shared_ptr> n1, - std::shared_ptr> n2) -> bool { - if (n1 == nullptr && n2 == nullptr) return true; - if (n1 == nullptr || n2 == nullptr) return false; - if (n1->size() != n2->size()) return false; - // The order matters as well. - for (int i = 0; i < n1->size(); i++) { - if (n1->at(i) != n2->at(i)) { - return false; - } - } - return true; - }; // Recall that send_neighbors is empty or not determines we use partial // neighbor allreduce or not. int num_recv_neighbors = !entry.dynamic_neighbors_enabled @@ -984,11 +1000,11 @@ void NegotiateOfRequestOfMaster(BluefogGlobalState& state, response.devices() == new_response.devices() && entry.tensor->dtype() == new_entry.tensor->dtype() && entry.dynamic_neighbors_enabled == new_entry.dynamic_neighbors_enabled && + entry.dst_weighting_enabled == new_entry.dst_weighting_enabled && entry.is_hierarchical == new_entry.is_hierarchical && - IsSameNeighborList(entry.send_neighbors, - new_entry.send_neighbors) && - IsSameNeighborList(entry.recv_neighbors, - new_entry.recv_neighbors) && + IsSameList(entry.send_neighbors, new_entry.send_neighbors) && + IsSameList(entry.send_weights, new_entry.send_weights) && + IsSameList(entry.recv_neighbors, new_entry.recv_neighbors) && tensor_size + new_tensor_size <= state.tensor_fusion_threshold) { // These tensors will fuse together well. tensor_size += new_tensor_size; @@ -1021,6 +1037,7 @@ void NegotiateOfRequestOfMaster(BluefogGlobalState& state, } else { PerformOperation(nego_entries); } + state.unfinished_enqueued_entries.fetch_sub(nego_entries.size()); } // Check for stalled tensors. 
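To make the INIT_WEIGHT_FUSION_BUFFER step above concrete, the NumPy sketch below models how the weighted fused input is laid out: the fused input [t_1, t_2] is copied once per destination and each copy is scaled by that destination's dst_weight, giving [t_1_w1, t_2_w1 | t_1_w2, t_2_w2 | ...]. This is also why the weight buffer is sized at roughly (world_size - 1) copies of the regular fusion buffer. The sketch is an illustration of the layout, not the C++/CUDA implementation.

    # NumPy model of GenerateWeightedFusedInputData's buffer layout (illustrative only).
    import numpy as np

    def generate_weighted_fused_input(fused_input, dst_weights):
        """fused_input: 1-D array holding the fused tensors; dst_weights: one weight per destination."""
        num_dst = len(dst_weights)
        weighted = np.empty(num_dst * fused_input.size, dtype=fused_input.dtype)
        for i, w in enumerate(dst_weights):
            start = i * fused_input.size
            # Copy the whole fused input, then scale it for destination i.
            weighted[start:start + fused_input.size] = fused_input * w
        return weighted

    fused = np.concatenate([np.ones(3), 2 * np.ones(2)])   # t_1 and t_2 fused together
    print(generate_weighted_fused_input(fused, dst_weights=[0.25, 0.75]))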
@@ -1035,6 +1052,7 @@ void NegotiateOfRequestOfSlave(BluefogGlobalState& state, std::deque& message_queue_buffer, bool& should_change_topo, bool& should_shut_down) { + state.unfinished_enqueued_entries.fetch_add(message_queue_buffer.size()); std::string encoded_message; RequestList message_list; message_list.set_shutdown(state.shut_down); @@ -1043,6 +1061,7 @@ void NegotiateOfRequestOfSlave(BluefogGlobalState& state, message_list.add_request(message_queue_buffer.front()); message_queue_buffer.pop_front(); } + RequestList::SerializeToString(message_list, encoded_message); int encoded_message_length = (int)encoded_message.length() + 1; MPI_Gather(&encoded_message_length, 1, MPI_INT, nullptr, 1, MPI_INT, @@ -1070,6 +1089,7 @@ void NegotiateOfRequestOfSlave(BluefogGlobalState& state, } else { PerformOperation(nego_entries); } + state.unfinished_enqueued_entries.fetch_sub(nego_entries.size()); } if (response_list.shutdown()) { @@ -1083,6 +1103,8 @@ void NegotiateOfRequestOfSlave(BluefogGlobalState& state, void NegotiationOfRequest(BluefogGlobalState& state, std::deque& message_queue_buffer, bool& should_change_topo, bool& should_shut_down) { + // TODO(ybc) should_change_topo has no effect after condition variable refactor. + // Just keep it for a while. will remove. if (bluefog_rank() == COORDINATE_RANK) { NegotiateOfRequestOfMaster(state, message_queue_buffer, should_change_topo, should_shut_down); @@ -1097,6 +1119,15 @@ bool RunLoopOnce(BluefogGlobalState& state) { bool should_shut_down = state.shut_down; bool should_change_topo = state.setting_topology; + std::unique_lock lk(state.loop_mutex); + // The real mutex for queue is the on under TensorQueue. + state.loop_cv.wait_for(lk, std::chrono::seconds(10), [&state] { + // When we requesting shut_down, or any unfinished entries waiting in the + // negotiation we should not wait. + return state.shut_down || (state.unfinished_enqueued_entries > 0) || + (state.tensor_queue.size() > 0); + }); + // This delay determines thread frequency and MPI message latency auto sleep_duration = state.last_cycle_start + @@ -1117,7 +1148,7 @@ bool RunLoopOnce(BluefogGlobalState& state) { std::vector entries; auto IsRequestConvertToEntryDirectly = [](const Request& request) -> bool { return global_skip_negotiate_stage || - (request.request_type() != Request::ALLREDUCE && + (request.request_type() != Request::ALLREDUCE && request.request_type() != Request::ALLGATHER && request.request_type() != Request::BROADCAST && request.request_type() != Request::NEIGHBOR_ALLREDUCE && @@ -1136,7 +1167,8 @@ bool RunLoopOnce(BluefogGlobalState& state) { std::remove_if(message_queue_buffer.begin(), message_queue_buffer.end(), IsRequestConvertToEntryDirectly), message_queue_buffer.end()); - + + lk.unlock(); // Never hold the mutex when there is remote function. PerformOperation(entries); // For the rest requests, they needs to coordinate and neogiate. @@ -1149,20 +1181,6 @@ bool RunLoopOnce(BluefogGlobalState& state) { NegotiationOfRequest(state, message_queue_buffer, should_change_topo, should_shut_down); } - // Seperate the setting topology and negotiate communnication. - // TODO(ybc) Use conditional variable and mutex to re-implement this. - if (should_change_topo) { - bluefog_global.ready_to_setting_topology = true; - while (!bluefog_global.setting_topology_done) { - std::this_thread::sleep_for(SUSPEND_BACKGROUND_WAITTING_DURATION); - } - bluefog_global.ready_to_setting_topology = false; - // Wait for main thread reset. 
- while (bluefog_global.setting_topology_done) { - std::this_thread::sleep_for(SUSPEND_BACKGROUND_WAITTING_DURATION); - } - } - return !should_shut_down; } @@ -1201,6 +1219,7 @@ void bluefog_init() { InitializeBluefogOnce(); } void bluefog_shutdown() { if (bluefog_global.background_thread.joinable()) { bluefog_global.shut_down = true; + bluefog_global.loop_cv.notify_all(); bluefog_global.background_thread.join(); // Reset the initialization flag to allow restarting with bluefog_init(...) //bluefog_global.initialize_flag.clear(); @@ -1276,36 +1295,21 @@ int bluefog_set_topology(int indegree, const int* sources, int outdegree, return -1; } #endif - bluefog_global.setting_topology = true; - while (!bluefog_global.ready_to_setting_topology.load()) { - std::this_thread::sleep_for(SUSPEND_BACKGROUND_WAITTING_DURATION); - } - bluefog_global.tensor_queue.LockTensorQueue(); - if (bluefog_global.tensor_queue.size() > 0) { - BFLOG(ERROR) - << "Cannot set the topology because there are unfinished MPI ops."; - bluefog_global.tensor_queue.UnlockTensorQueue(); - return -1; - } - - bool mpi_result = bluefog_global.controller->SetTopology( - indegree, sources, outdegree, destinations); + bool mpi_result; + // When we change the topology, there should be no entries being processed at + // same time. + { + std::lock_guard lk(bluefog_global.loop_mutex); + mpi_result = bluefog_global.controller->SetTopology( + indegree, sources, outdegree, destinations); #if HAVE_NCCL && NCCL_MINOR < 7 - if (mpi_result && nccl_context.is_initialized) { - bluefog_global.nccl_controller->DestroyPeerCommunicators(); - bluefog_global.nccl_controller->InitPeerCommunicators(); - } + if (mpi_result && nccl_context.is_initialized) { + bluefog_global.nccl_controller->DestroyPeerCommunicators(); + bluefog_global.nccl_controller->InitPeerCommunicators(); + } #endif - bluefog_global.tensor_queue.UnlockTensorQueue(); - - bluefog_global.setting_topology = false; - bluefog_global.setting_topology_done = true; - // Wait for the background thread receive the setting_topology_done and - // close the ready_to_setting_topology epoch. 
- while (bluefog_global.ready_to_setting_topology) { - std::this_thread::sleep_for(SUSPEND_BACKGROUND_WAITTING_DURATION); } - bluefog_global.setting_topology_done = false; + bluefog_global.loop_cv.notify_one(); return mpi_result; } @@ -1433,6 +1437,7 @@ Status EnqueueTensorAllreduce(std::shared_ptr tensor, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1468,6 +1473,7 @@ Status EnqueueTensorBroadcast(std::shared_ptr tensor, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1502,6 +1508,7 @@ Status EnqueueTensorAllgather(std::shared_ptr tensor, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1546,6 +1553,7 @@ Status EnqueueTensorNeighborAllgather(std::shared_ptr tensor, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1555,7 +1563,9 @@ Status EnqueueTensorNeighborAllreduce(std::shared_ptr tensor, std::shared_ptr ready_event, std::shared_ptr> recv_neighbors, std::shared_ptr> send_neighbors, + std::shared_ptr> send_weights, bool dynamic_neighbors_enabled, + bool dst_weighting_enabled, bool is_hierarchical, bool enable_topo_check, const std::string& name, const int device, @@ -1579,7 +1589,9 @@ Status EnqueueTensorNeighborAllreduce(std::shared_ptr tensor, e.ready_event = ready_event; e.recv_neighbors = recv_neighbors; e.send_neighbors = send_neighbors; + e.send_weights = send_weights; e.dynamic_neighbors_enabled = dynamic_neighbors_enabled; + e.dst_weighting_enabled = dst_weighting_enabled; e.is_hierarchical = is_hierarchical; e.enable_topo_check = enable_topo_check; e.device = device; @@ -1593,6 +1605,7 @@ Status EnqueueTensorNeighborAllreduce(std::shared_ptr tensor, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1633,6 +1646,7 @@ Status EnqueueTensorPairGossip(std::shared_ptr tensor, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1665,6 +1679,7 @@ Status EnqueueTensorWindowCreate( return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1689,6 +1704,7 @@ Status EnqueueTensorWindowFree(const std::string& name, int device, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1722,6 +1738,7 @@ Status EnqueueTensorWindowPut(std::shared_ptr tensor, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1753,6 +1770,7 @@ Status EnqueueTensorWindowAccumulate( return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -1780,6 +1798,7 @@ Status EnqueueTensorWindowGet(const std::string& name, return SUSPEND_ERROR; } Status status = bluefog_global.tensor_queue.AddToTensorQueue(e, message); + bluefog_global.loop_cv.notify_one(); return status; } @@ -2022,28 
+2041,15 @@ void SetSkipNegotiateStageState(bool value) { if (value == global_skip_negotiate_stage) { return; } - if (value) { - // From running negotiate to skipping negotiate, we need to properly turn - // off negotiate stage. Otherwise, it may hang the processes. Use setting - // topology flag to suspend the negotiate stage then skip it. - bluefog_global.setting_topology = true; - while (!bluefog_global.ready_to_setting_topology.load()) { - std::this_thread::sleep_for(SUSPEND_BACKGROUND_WAITTING_DURATION); - } - global_skip_negotiate_stage = value; - - bluefog_global.setting_topology = false; - bluefog_global.setting_topology_done = true; - // Wait for the background thread receive the setting_topology_done and - // close the ready_to_setting_topology epoch. - while (bluefog_global.ready_to_setting_topology) { - std::this_thread::sleep_for(SUSPEND_BACKGROUND_WAITTING_DURATION); - } - bluefog_global.setting_topology_done = false; - } else { + // From running negotiate to skipping negotiate, we need to properly turn + // off negotiate stage. Otherwise, it may hang the processes. Use setting + // topology flag to suspend the negotiate stage then skip it. + { + std::lock_guard lk(bluefog_global.loop_mutex); global_skip_negotiate_stage = value; } + bluefog_global.loop_cv.notify_one(); } bool GetSkipNegotiateStageState() { diff --git a/bluefog/common/operations.h b/bluefog/common/operations.h index 94ac99de..e31dff0f 100644 --- a/bluefog/common/operations.h +++ b/bluefog/common/operations.h @@ -148,7 +148,9 @@ Status EnqueueTensorNeighborAllreduce(std::shared_ptr tensor, std::shared_ptr ready_event, std::shared_ptr> recv_neighbors, std::shared_ptr> send_neighbors, + std::shared_ptr> send_weights, bool dynamic_neighbors_enabled, + bool dst_weighting_enabled, bool is_hierarchical, bool enable_topo_check, const std::string& name, const int device, diff --git a/bluefog/common/tensor_queue.cc b/bluefog/common/tensor_queue.cc index 513459f1..0737db91 100644 --- a/bluefog/common/tensor_queue.cc +++ b/bluefog/common/tensor_queue.cc @@ -154,5 +154,35 @@ std::shared_ptr FusionBufferManager::GetBuffer(int device) { return tensor_fusion_buffers_[device].first; } +Status FusionBufferManager::InitializeWeightBuffer( + int64_t threshold, int world_size, int device, std::shared_ptr context, + std::function on_start_init, std::function on_end_init) { + auto& elem = weight_tensor_fusion_buffers_[device]; + auto& buffer = elem.first; + int64_t& size = elem.second; + // threshold * (world_size-1) is the upper bound for buffer + if (size != threshold*(world_size-1)) { + buffer.reset(); + size = 0; + } + + if (buffer == nullptr) { + on_start_init(); + size = threshold*(world_size-1); + + // Lazily allocate persistent buffer for Tensor Fusion and keep it + // forever per device. + Status status = context->AllocatePersistent(size, &buffer); + on_end_init(); + + return status; + } + + return Status::OK(); +} + +std::shared_ptr FusionBufferManager::GetWeightBuffer(int device) { + return weight_tensor_fusion_buffers_[device].first; +} } // namespace common } // namespace bluefog \ No newline at end of file diff --git a/bluefog/common/tensor_queue.h b/bluefog/common/tensor_queue.h index 50740164..de476743 100644 --- a/bluefog/common/tensor_queue.h +++ b/bluefog/common/tensor_queue.h @@ -46,15 +46,20 @@ class TensorQueue { void PushMessageToQueue(Request& message); - // Used when setting Topology, which require the tensor queue should be empty always. 
+ // Used when setting Topology, which requires the tensor queue to be empty + // at all times. inline void LockTensorQueue() { mutex_.lock(); } inline void UnlockTensorQueue() { mutex_.unlock(); } - inline size_t size() { return message_queue_.size(); } + inline size_t size() { + std::lock_guard guard(mutex_); + return message_queue_.size(); + } protected: // Tensors waiting to be processed. - // Key is based upon the message name since tensor_name in table entry for win ops - // is for window and we need to add "win_put."/"win_create." before it in message. + // Key is based upon the message name since tensor_name in table entry for win + // ops is for window and we need to add "win_put."/"win_create." before it in + // message. std::unordered_map tensor_table_; // Queue of MPI requests waiting to be sent to the coordinator node. @@ -65,8 +70,8 @@ class TensorQueue { mutable std::mutex mutex_; }; -// Encapsulates the process of creating and destroying fusion buffers as the requested -// threshold is changed. +// Encapsulates the process of creating and destroying fusion buffers as the +// requested threshold is changed. class FusionBufferManager { public: // Initializes a buffer of the given threshold size if not already cached. @@ -82,13 +87,40 @@ class FusionBufferManager { std::function on_start_init, std::function on_end_init); + // Initializes a buffer of the given threshold size times MPI size if not + // already cached. One constraint to note here: the weight buffer must always + // be at least (world_size-1) times the fusion buffer, since a tensor that + // fits into the fusion buffer must also fit into the weight buffer. + // + // Args: + // threshold: Size of the buffer in bytes. + // world_size: Number of MPI processes. + // device: Device ID to associate the buffer. + // context: Framework used to create the buffer and associate it. + // on_start_init: Callback on starting buffer initialization. + // on_end_init: Callback on completing buffer initialization. + Status InitializeWeightBuffer(int64_t threshold, int world_size, int device, + std::shared_ptr context, + std::function on_start_init, + std::function on_end_init); + // Returns the buffer associated with the given device and framework, or null. std::shared_ptr GetBuffer(int device); + // Returns the weight buffer associated with the given device and framework, + // or null. + std::shared_ptr GetWeightBuffer(int device); + private: // Memory buffers for Tensor Fusion. They are keyed by device ID. std::unordered_map, int64_t>> tensor_fusion_buffers_; + + // Memory buffers for Tensor Fusion with dst weight. They are keyed by device + // ID. + std::unordered_map, int64_t>> + weight_tensor_fusion_buffers_; }; } // namespace common diff --git a/bluefog/run/interactive_run.py b/bluefog/run/interactive_run.py index ccfc0c79..6682c8e6 100644 --- a/bluefog/run/interactive_run.py +++ b/bluefog/run/interactive_run.py @@ -406,7 +406,8 @@ def handler(signum, frame): signal.signal(signal.SIGINT, handler) env = os.environ.copy() - env['BLUEFOG_CYCLE_TIME'] = str(20) # Increase the cycle time + # No longer needed after using condition variable.
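The cycle-time knob above is no longer needed because every producer path now calls loop_cv.notify_one() right after AddToTensorQueue, so the background loop can sleep on a condition variable and wake as soon as work arrives instead of polling on a fixed interval. Below is a minimal Python sketch of that producer/consumer pattern; it is illustrative only, and the queue, helper names, and the 10-second timeout are assumptions, not the C++ implementation in this patch.

import threading
from collections import deque

work_queue = deque()
loop_cv = threading.Condition()
shutdown = False

def process(item):
    print("processing", item)  # stand-in for the real communication step

def enqueue(item):
    # Producer side: append work and notify the loop, analogous to
    # AddToTensorQueue(...) followed by loop_cv.notify_one() in the hunks above.
    with loop_cv:
        work_queue.append(item)
        loop_cv.notify()

def background_loop():
    # Consumer side: sleep until work arrives (with a timeout as a safety net)
    # instead of waking up on a fixed cycle time to check an empty queue.
    while not shutdown:
        with loop_cv:
            loop_cv.wait_for(lambda: bool(work_queue) or shutdown, timeout=10.0)
            items = list(work_queue)
            work_queue.clear()
        for item in items:
            process(item)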
+ # env['BLUEFOG_CYCLE_TIME'] = str(20) # Increase the cycle time # action of stop if args.action == "stop": diff --git a/bluefog/tensorflow/adapter.cc b/bluefog/tensorflow/adapter.cc index f7864cbc..278cc426 100644 --- a/bluefog/tensorflow/adapter.cc +++ b/bluefog/tensorflow/adapter.cc @@ -89,7 +89,7 @@ const void* TFPersistentBuffer::AccessData( TFTensor::TFTensor(::tensorflow::Tensor& tensor) : tensor_(tensor) {} -const common::DataType TFTensor::dtype() const { +common::DataType TFTensor::dtype() const { switch (tensor_.dtype()) { case ::tensorflow::DT_UINT8: return common::DataType::BLUEFOG_UINT8; @@ -128,7 +128,7 @@ const void* TFTensor::data() const { return (const void*)tensor_.tensor_data().data(); } -std::shared_ptr TFTensor::data_weight(float weight) { +std::unique_ptr TFTensor::data_weight(float weight) { throw std::runtime_error("Tensorflow with weight is not implemented yet."); }; diff --git a/bluefog/tensorflow/adapter.h b/bluefog/tensorflow/adapter.h index c61c88ad..428a0a3e 100644 --- a/bluefog/tensorflow/adapter.h +++ b/bluefog/tensorflow/adapter.h @@ -47,10 +47,10 @@ class TFPersistentBuffer : public common::PersistentBuffer { class TFTensor : public common::Tensor { public: TFTensor(::tensorflow::Tensor& tensor); - virtual const common::DataType dtype() const override; + virtual common::DataType dtype() const override; virtual const common::TensorShape shape() const override; virtual const void* data() const override; - virtual std::shared_ptr data_weight(float weight) override; + virtual std::unique_ptr data_weight(float weight) override; virtual int64_t size() const override; protected: diff --git a/bluefog/torch/adapter.cc b/bluefog/torch/adapter.cc index 3cb38aab..8dd8560c 100644 --- a/bluefog/torch/adapter.cc +++ b/bluefog/torch/adapter.cc @@ -43,7 +43,7 @@ using ::bluefog::common::with_device; TorchTensor::TorchTensor(::torch::Tensor tensor) : tensor_(tensor) {} -const DataType TorchTensor::dtype() const { +DataType TorchTensor::dtype() const { switch (tensor_.scalar_type()) { case ::torch::kByte: return DataType::BLUEFOG_UINT8; @@ -76,15 +76,15 @@ const common::TensorShape TorchTensor::shape() const { const void* TorchTensor::data() const { return tensor_.data_ptr(); } -std::shared_ptr TorchTensor::data_weight(float weight) { +std::unique_ptr TorchTensor::data_weight(float weight) { if (weight == 1.0) { - return std::make_shared(tensor_); + return std::make_unique(tensor_); } else { int device = tensor_.device().is_cuda() ? 
tensor_.device().index() : CPU_DEVICE_ID; with_device device_context(device); // Note we call mul instead of mul_ - return std::make_shared(tensor_.mul(weight)); + return std::make_unique(tensor_.mul(weight)); } } @@ -178,6 +178,10 @@ Status TorchOpContext::AllocateZeros(int64_t num_elements, DataType dtype, Framework TorchOpContext::framework() const { return Framework::PYTORCH; } +std::shared_ptr TorchOpContext::RecordReadyEvent(int device) { + return torch::RecordReadyEvent(device); +} + #if HAVE_CUDA struct ReadyEventRegistry { std::unordered_map> cuda_events; diff --git a/bluefog/torch/adapter.h b/bluefog/torch/adapter.h index 4e192ed2..85f07994 100644 --- a/bluefog/torch/adapter.h +++ b/bluefog/torch/adapter.h @@ -32,10 +32,10 @@ namespace torch { class TorchTensor : public common::Tensor { public: TorchTensor(::torch::Tensor tensor); - virtual const common::DataType dtype() const override; + virtual common::DataType dtype() const override; virtual const common::TensorShape shape() const override; virtual const void* data() const override; - virtual std::shared_ptr data_weight(float weight) override; + virtual std::unique_ptr data_weight(float weight) override; virtual int64_t size() const override; // TODO(ybc) Figure out a better encapsulated way to do it. @@ -70,6 +70,7 @@ class TorchOpContext : public common::OpContext { virtual common::Status AllocateZeros( int64_t num_elements, common::DataType dtype, std::shared_ptr* tensor) override; + virtual std::shared_ptr RecordReadyEvent(int device) override; virtual common::Framework framework() const override; private: diff --git a/bluefog/torch/mpi_ops.cc b/bluefog/torch/mpi_ops.cc index 5f011ac1..9a980d54 100644 --- a/bluefog/torch/mpi_ops.cc +++ b/bluefog/torch/mpi_ops.cc @@ -82,8 +82,6 @@ void MaybeCopyBufferBack(::torch::Tensor tensor, ::torch::Tensor buffer) { if (IsCPUHalfTensor(tensor)) tensor.copy_(buffer.to(::torch::kFloat16)); } -} // namespace - std::function(std::function)> GetCallbackWrapper(int handle, Timeline* timeline_ptr, const std::string& op_name, std::thread::id tid) { @@ -98,6 +96,75 @@ std::function(std::function)> }; } +void PerformNeighborAllreduceCallback(::torch::Tensor tensor, ::torch::Tensor output, + double self_weight, + const std::map& src_weights, + bool avg_computation, + bool dynamic_neighbors_enabled, + bool is_hierarchical) { + int src_size = bluefog_neighbor_size(); + if (dynamic_neighbors_enabled) src_size = src_weights.size(); + if (src_size > 0) { + ::torch::Tensor output_buffer = MaybeCopyToTensorBuffer(output); + ::torch::Tensor tensor_buffer = MaybeCopyToTensorBuffer(tensor); + + int first_dim = output_buffer.size(0) / src_size; + std::vector shape_vector; + shape_vector.push_back(first_dim); + for (int idx = 1; idx < tensor_buffer.dim(); ++idx) { + shape_vector.push_back(tensor_buffer.size(idx)); + } + + // if avg_computation is set to be False, sum computation will be taken place. + if (avg_computation) { + auto output_reduced = output_buffer.slice(0, 0, first_dim); + int i = 0; + for (auto kv : src_weights) { + double weight = kv.second; + if (i == 0) { + output_reduced.mul_(weight); + } else { + output_reduced.add_( + output_buffer.slice(0, i * first_dim, (i + 1) * first_dim), weight); + } + ++i; + } + output_buffer.resize_(shape_vector); + output_buffer.add_(tensor_buffer, self_weight); + if (is_hierarchical){ + // Because there is ncclAllreduce just take sum. 
+ output_buffer.div_(bluefog_local_size()); + } + } else { // avg_computation is False, using sum operation + if (src_size > 1) { + auto output_reduced = output_buffer.slice(0, 0, first_dim); + for (int i = 1; i < src_size; i++) { + output_reduced.add_( + output_buffer.slice(0, i * first_dim, (i + 1) * first_dim)); + } + output_buffer.resize_(shape_vector); + } + // Include self data as well. + output_buffer.add_(tensor_buffer); + if (is_hierarchical){ + // Because there is ncclAllreduce just take sum. + output_buffer.div_(bluefog_local_size() * (src_size + 1)); + } else { + output_buffer.div_(src_size + 1); + } + } + output.resize_(shape_vector); + MaybeCopyBufferBack(output, output_buffer); + } else { // recv_size == 0 + output.set_(tensor); + ::torch::Tensor output_buffer = MaybeCopyToTensorBuffer(output); + output_buffer.mul_(self_weight); + MaybeCopyBufferBack(output, output_buffer); + } +} + +} // namespace + int DoAllreduce(::torch::Tensor tensor, ::torch::Tensor output, int average, bool is_hierarchical_local, const std::string& name) { ThrowIfError(common::CheckInitialized()); @@ -320,8 +387,10 @@ int DoNeighborAllgather(::torch::Tensor tensor, ::torch::Tensor output, } int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output, - double self_weight, const std::unordered_map& neighbor_weights, - const std::vector& send_neighbors, bool dynamic_neighbors_enabled, + double self_weight, + const std::unordered_map& src_weights, + const std::unordered_map& dst_weights, + bool dynamic_neighbors_enabled, bool dst_weighting_enabled, bool enable_topo_check, bool avg_computation, bool is_hierarchical, const std::string& name) { ThrowIfError(common::CheckInitialized()); @@ -338,113 +407,49 @@ int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output, auto callback_wrapper = GetCallbackWrapper(handle, timeline_ptr, op_name, tid); - std::vector recv_neighbors; - for (auto kv : neighbor_weights) - recv_neighbors.push_back(kv.first); - std::sort(recv_neighbors.begin(), recv_neighbors.end()); - + // src_neighbors, dst_neighbors --> list of ranks only used in Enqueue + // src_weights_ordered_map --> used in callback only + // dst_weights_vec --> used in Enqueue for sending (same order of dst_neighbors) + std::map src_weights_ordered_map; + for (auto kv : src_weights) + src_weights_ordered_map.insert(kv); + std::vector src_neighbors; + for (auto kv : src_weights_ordered_map) + src_neighbors.push_back(kv.first); + + std::vector dst_neighbors; + for (auto kv : dst_weights) + dst_neighbors.push_back(kv.first); + std::sort(dst_neighbors.begin(), dst_neighbors.end()); + std::vector dst_weights_vec; + for (int rank : dst_neighbors) + dst_weights_vec.push_back(dst_weights.at(rank)); + + auto bf_src_neighbors = std::make_shared>(src_neighbors); + auto bf_dst_neighbors = std::make_shared>(dst_neighbors); + auto bf_dst_weights_vec = std::make_shared>(dst_weights_vec); + auto ready_event = RecordReadyEvent(device); if (OPS_ON_CPU && tensor.device().is_cuda()) { ::torch::Tensor cpu_buffer = tensor.to(::torch::Device(::torch::kCPU), /*non_blocking=*/false); ::torch::Tensor cpu_output = output.to(::torch::Device(::torch::kCPU), /*non_blocking=*/false); auto bf_tensor = std::make_shared(cpu_buffer); - auto bf_context = - std::make_shared(CPU_DEVICE_ID, cpu_output); + auto bf_context = std::make_shared(CPU_DEVICE_ID, cpu_output); auto bf_output = std::make_shared(cpu_output); - auto bf_recv_neighbors = std::make_shared>(recv_neighbors); - auto bf_send_neighbors = 
std::make_shared>(send_neighbors); - auto ready_event = RecordReadyEvent(device); + auto enqueue_result = EnqueueTensorNeighborAllreduce( - bf_tensor, bf_output, bf_context, ready_event, bf_recv_neighbors, - bf_send_neighbors, dynamic_neighbors_enabled, is_hierarchical, + bf_tensor, bf_output, bf_context, ready_event, + bf_src_neighbors, bf_dst_neighbors, bf_dst_weights_vec, + dynamic_neighbors_enabled, dst_weighting_enabled, is_hierarchical, enable_topo_check, op_name, CPU_DEVICE_ID, - callback_wrapper([self_weight, neighbor_weights, avg_computation, - cpu_output, tensor, recv_neighbors, send_neighbors, - dynamic_neighbors_enabled, is_hierarchical, output, - device]() mutable { + callback_wrapper([self_weight, src_weights_ordered_map, avg_computation, cpu_output, tensor, + dynamic_neighbors_enabled, is_hierarchical, output, device]() mutable { with_device device_guard(device); output.copy_(cpu_output); - int recv_size = bluefog_neighbor_size(); - if (dynamic_neighbors_enabled) recv_size = recv_neighbors.size(); - if (recv_size > 0) { - ::torch::Tensor output_buffer = MaybeCopyToTensorBuffer(output); - ::torch::Tensor tensor_buffer = MaybeCopyToTensorBuffer(tensor); - - int first_dim = output_buffer.size(0) / recv_size; - std::vector shape_vector; - shape_vector.push_back(first_dim); - for (int idx = 1; idx < tensor_buffer.dim(); ++idx) { - shape_vector.push_back(tensor_buffer.size(idx)); - } - - // if avg_computation is set to be False, sum computation will be taken place. - if (avg_computation) { - // 1) For a distributed graph topology, created with - // MPI_Dist_graph_create, the sequence of neighbors in the send and - // receive buffers at each process is defined as the sequence returned - // by MPI_Dist_graph_neighbors for destinations and sources, - // respectively. 2) MPI_Dist_graph_neighbors: If the communicator was - // created with MPI_Dist_graph_create_adjacent then the order of the - // values in sources and destinations is identical to the input that - // was used by the process with the same rank in comm_old in the - // creation call. - int indgree = 0; - int outdegree = 0; - int* sources_ptr = nullptr; - int* destinations_ptr = nullptr; - bluefog_load_topology(&indgree, sources_ptr, &outdegree, - destinations_ptr); - - auto output_reduced = output_buffer.slice(0, 0, first_dim); - if (dynamic_neighbors_enabled) indgree = recv_neighbors.size(); - for (int i = 0; i < indgree; i++) { - double weight = 0.0; - int recv_rank; - if (!dynamic_neighbors_enabled) recv_rank = *(sources_ptr+i); - else recv_rank = recv_neighbors[i]; - auto it = neighbor_weights.find(recv_rank); - if (it != neighbor_weights.end()) { - weight = it->second; - } - - if (i == 0) { - output_reduced.mul_(weight); - } else { - output_reduced.add_( - output_buffer.slice(0, i * first_dim, (i + 1) * first_dim), weight); - } - } - output_buffer.resize_(shape_vector); - output_buffer.add_(tensor_buffer, self_weight); - if (is_hierarchical){ - // Because there is ncclAllreduce just take sum. - output_buffer.div_(bluefog_local_size()); - } - } else { - int neighbor_size = !dynamic_neighbors_enabled - ? bluefog_neighbor_size() - : recv_neighbors.size(); - if (neighbor_size > 1) { - auto output_reduced = output_buffer.slice(0, 0, first_dim); - for (int i = 1; i < neighbor_size; i++) { - output_reduced.add_( - output_buffer.slice(0, i * first_dim, (i + 1) * first_dim)); - } - output_buffer.resize_(shape_vector); - } - // Include self data as well. 
- output_buffer.add_(tensor_buffer); - if (is_hierarchical){ - // Because there is ncclAllreduce just take sum. - output_buffer.div_(bluefog_local_size() * (neighbor_size + 1)); - } else { - output_buffer.div_(neighbor_size + 1); - } - } - output.resize_(shape_vector); - MaybeCopyBufferBack(output, output_buffer); - } + PerformNeighborAllreduceCallback(tensor, output, self_weight, src_weights_ordered_map, + avg_computation, dynamic_neighbors_enabled, + is_hierarchical); })); ThrowIfError(enqueue_result); @@ -452,91 +457,17 @@ int DoNeighborAllreduce(::torch::Tensor tensor, ::torch::Tensor output, auto bf_tensor = std::make_shared(tensor); auto bf_context = std::make_shared(device, output); auto bf_output = std::make_shared(output); - auto bf_recv_neighbors = std::make_shared>(recv_neighbors); - auto bf_send_neighbors = std::make_shared>(send_neighbors); - auto ready_event = RecordReadyEvent(device); auto enqueue_result = EnqueueTensorNeighborAllreduce( - bf_tensor, bf_output, bf_context, ready_event, bf_recv_neighbors, - bf_send_neighbors, dynamic_neighbors_enabled, is_hierarchical, - enable_topo_check, op_name, device, - callback_wrapper([self_weight, neighbor_weights, avg_computation, - recv_neighbors, send_neighbors, dynamic_neighbors_enabled, - is_hierarchical, tensor, output]() mutable { - int recv_size = bluefog_neighbor_size(); - if (dynamic_neighbors_enabled) recv_size = recv_neighbors.size(); - if (recv_size > 0) { - ::torch::Tensor output_buffer = MaybeCopyToTensorBuffer(output); - ::torch::Tensor tensor_buffer = MaybeCopyToTensorBuffer(tensor); - - int first_dim = output_buffer.size(0) / recv_size; - std::vector shape_vector; - shape_vector.push_back(first_dim); - for (int idx = 1; idx < tensor_buffer.dim(); ++idx) { - shape_vector.push_back(tensor_buffer.size(idx)); - } - // if avg_computation is set to be True, average computation will be taken place. - if (avg_computation) { - int indgree = 0; - int outdegree = 0; - int* sources_ptr = nullptr; - int* destinations_ptr = nullptr; - auto output_reduced = output_buffer.slice(0, 0, first_dim); - if (!dynamic_neighbors_enabled) { - bluefog_load_topology(&indgree, sources_ptr, &outdegree, - destinations_ptr); - } else { - indgree = recv_neighbors.size(); - } - for (int i = 0; i < indgree; i++) { - double weight = 0.0; - int recv_rank; - if (!dynamic_neighbors_enabled) recv_rank = *(sources_ptr+i); - else recv_rank = recv_neighbors[i]; - auto it = neighbor_weights.find(recv_rank); - if (it != neighbor_weights.end()) { - weight = it->second; - } - - if (i == 0) { - output_reduced.mul_(weight); - } else { - output_reduced.add_( - output_buffer.slice(0, i * first_dim, (i + 1) * first_dim), weight); - } - } - output_buffer.resize_(shape_vector); - output_buffer.add_(tensor_buffer, self_weight); - if (is_hierarchical){ - // Because there is ncclAllreduce just take sum. - output_buffer.div_(bluefog_local_size()); - } - } else { - int neighbor_size = !dynamic_neighbors_enabled - ? bluefog_neighbor_size() - : recv_neighbors.size(); - if (neighbor_size > 1) { - auto output_reduced = output_buffer.slice(0, 0, first_dim); - for (int i = 1; i < neighbor_size; i++) { - output_reduced.add_( - output.slice(0, i * first_dim, (i + 1) * first_dim)); - } - } - output_buffer.resize_(shape_vector); - // Include self data as well. - output_buffer.add_(tensor_buffer); - if (is_hierarchical){ - // Because there is ncclAllreduce just take sum. 
- output_buffer.div_(bluefog_local_size() * (neighbor_size + 1)); - } else { - output_buffer.div_(neighbor_size + 1); - } - } - output.resize_(shape_vector); - MaybeCopyBufferBack(output, output_buffer); - } else { - output.set_(tensor); - } + bf_tensor, bf_output, bf_context, ready_event, + bf_src_neighbors, bf_dst_neighbors, bf_dst_weights_vec, + dynamic_neighbors_enabled, dst_weighting_enabled, + is_hierarchical, enable_topo_check, op_name, device, + callback_wrapper([self_weight, src_weights_ordered_map, avg_computation, + dynamic_neighbors_enabled, is_hierarchical, tensor, output]() mutable { + PerformNeighborAllreduceCallback(tensor, output, self_weight, src_weights_ordered_map, + avg_computation, dynamic_neighbors_enabled, + is_hierarchical); })); ThrowIfError(enqueue_result); } diff --git a/bluefog/torch/mpi_ops.h b/bluefog/torch/mpi_ops.h index 728652a2..788f3ccc 100644 --- a/bluefog/torch/mpi_ops.h +++ b/bluefog/torch/mpi_ops.h @@ -113,12 +113,13 @@ NEIGHBOR_ALLGATHER_H(torch_cuda_FloatTensor, THCudaTensor) NEIGHBOR_ALLGATHER_H(torch_cuda_DoubleTensor, THCudaDoubleTensor) #endif -#define NEIGHBOR_ALLREDUCE_H(torch_Tensor, THTensor) \ - extern "C" int bluefog_torch_neighbor_allreduce_nonblocking_##torch_Tensor( \ - THTensor* tensor, THTensor* output, double self_weight, \ - const std::unordered_map& neighbor_weights, \ - const std::vector& send_neighbors, bool dynamic_neighbors_enabled, \ - bool enable_topo_check, bool avg_computation, bool is_hierarchical, \ +#define NEIGHBOR_ALLREDUCE_H(torch_Tensor, THTensor) \ + extern "C" int bluefog_torch_neighbor_allreduce_nonblocking_##torch_Tensor( \ + THTensor* tensor, THTensor* output, double self_weight, \ + const std::unordered_map& src_weights, \ + const std::unordered_map& dst_weights, \ + bool dynamic_neighbors_enabled, bool dst_weighting_enabled, \ + bool enable_topo_check, bool avg_computation, bool is_hierarchical, \ char* name); NEIGHBOR_ALLREDUCE_H(torch_HalfTensor, THHalfTensor) diff --git a/bluefog/torch/mpi_ops.py b/bluefog/torch/mpi_ops.py index 328a7a89..4c6b1dbc 100644 --- a/bluefog/torch/mpi_ops.py +++ b/bluefog/torch/mpi_ops.py @@ -15,8 +15,9 @@ # ============================================================================== from contextlib import contextmanager -from typing import List, Dict, Optional +from typing import List, Dict, Union, Optional +import numpy as np import torch from bluefog.torch import mpi_lib # C library @@ -474,68 +475,64 @@ def _neighbor_allreduce_function_factory(tensor): return 'bluefog_torch_neighbor_allreduce_nonblocking_' + tensor.type().replace('.', '_') -def _neighbor_allreduce_nonblocking(tensor, output, self_weight, neighbor_weights, - send_neighbors, enable_topo_check, name): +def _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights, + dst_weights, enable_topo_check, name): function = _check_function(_neighbor_allreduce_function_factory, tensor) - if send_neighbors is None: - send_neighbors = [] + if dst_weights is None: + dst_weights = {} dynamic_neighbors_enabled = False - elif len(set(send_neighbors)) != len(send_neighbors): - raise ValueError("Argument send_neighbors should only contain the unique ranks.") - elif self_weight is None or neighbor_weights is None: - raise ValueError("Arguments self_weight and neighbor_weights should be presented if " + dst_weighting_enabled = False + elif len(set(dst_weights)) != len(dst_weights): + raise ValueError("Argument dst_weights should only contain the unique ranks.") + elif self_weight is None or src_weights is None: + 
raise ValueError("Arguments self_weight and src_weights should be presented if " "enabling dynamic topology.") - elif not send_neighbors: - raise ValueError("Argument send_neighbors cannot be empty list but we plan to support " - "it in future.") else: dynamic_neighbors_enabled = True - if self_weight is None and neighbor_weights is None: + if isinstance(dst_weights, list): + dst_weights = {dst:1.0 for dst in dst_weights} + dst_weighting_enabled = not np.allclose(list(dst_weights.values()), 1.0) + if self_weight is None and src_weights is None: # Implying this is static graph. if is_topo_weighted(): topology = load_topology() - self_weight, neighbor_weights = GetRecvWeights(topology, rank()) + self_weight, src_weights = GetRecvWeights(topology, rank()) weighted_average_computation = True else: weight = 1.0/(len(in_neighbor_ranks())+1) self_weight = weight - neighbor_weights = {r: weight for r in in_neighbor_ranks()} + src_weights = {r: weight for r in in_neighbor_ranks()} weighted_average_computation = False - elif self_weight is not None and neighbor_weights is not None: - if not isinstance(neighbor_weights, dict): + elif self_weight is not None and src_weights is not None: + if not isinstance(src_weights, dict): raise ValueError("Argument neighbor_weights has to be a dictionary map from the " "(in-)neighbor rank to the weights.") if not isinstance(self_weight, float): raise ValueError( "Argument self_weight has to be a float for self rank.") if not dynamic_neighbors_enabled and \ - not set(neighbor_weights.keys()).issubset(set(in_neighbor_ranks())): + not set(src_weights.keys()).issubset(set(in_neighbor_ranks())): raise ValueError("The key of weights should only contain the ranks that belong to " " in-neighbors and self rank.") - uniform_weights = 1.0/(len(neighbor_weights)+1) - weighted_average_computation = False - if abs(self_weight - uniform_weights) > 1e-6: - weighted_average_computation = True - for n_weights in neighbor_weights.values(): - if abs(n_weights - uniform_weights) > 1e-6: - weighted_average_computation = True - break + uniform_weights = 1.0/(len(src_weights)+1) + weighted_average_computation = not(np.isclose(self_weight, uniform_weights) and + np.allclose(list(src_weights.values()), uniform_weights)) else: raise ValueError("Arguments self_weight and neighbor_weights have to be presented at " "the same time") is_hierarchical = False - handle = getattr(mpi_lib, function)(tensor, output, self_weight, neighbor_weights, - send_neighbors, dynamic_neighbors_enabled, + handle = getattr(mpi_lib, function)(tensor, output, self_weight, src_weights, dst_weights, + dynamic_neighbors_enabled, dst_weighting_enabled, enable_topo_check, weighted_average_computation, is_hierarchical, name.encode() if name is not None else "") _handle_map[handle] = (tensor, output) return handle -def neighbor_allreduce(tensor: torch.Tensor, +def neighbor_allreduce(tensor: torch.Tensor, *, self_weight: Optional[float] = None, - neighbor_weights: Optional[Dict[int, float]] = None, - send_neighbors: Optional[List[int]] = None, + src_weights: Optional[Dict[int, float]] = None, + dst_weights: Optional[Union[Dict[int, float], List[int]]] = None, enable_topo_check: bool = True, name: Optional[str] = None) -> torch.Tensor: """ @@ -552,17 +549,17 @@ def neighbor_allreduce(tensor: torch.Tensor, Arguments: tensor: A tensor to execute weighted average with neighbors. self_weight: The weight for self node, used with neighbor_weights. - neighbor_weights: The weights for in-neighbor nodes, used with self weight. 
+ src_weights: The weights for in-neighbor nodes, used with self weight. If neighbor_weights is presented, the return tensor will return the weighted average defined by these weights and the self_weight. If not, the return tensor will return the weighted average defined by the topology weights is provided or uniformly average. The data structure of weights should be {rank : weight} and rank has to belong to the (in-)neighbors. - send_neighbors: The list of neighbor nodes to be sent to. If set to be None, assume the + dst_weights: The weights for out-neighbor nodes. If set to be None, assume the the current node sends to all of its (out-)neighbors. If having values, assume only - part of (out-)neighbors will be sent to. In this mode, this node sends its value to - partial neighbors listed in this variable in a dynamic graph, and `self_weight` and - `neighbor_weights` must be present. + part of (out-)neighbors will be sent to. If set to be a list, assume all the weights + are one. In this mode, this node sends its value to partial neighbors listed in this + variable in a dynamic graph, and `self_weight` and `src_weights` must be present. enable_topo_check: When send_neighbors is present, enabling this option checks if the sending and recieving neighbors match with each other. Disabling this check can boost the performance. @@ -573,19 +570,24 @@ def neighbor_allreduce(tensor: torch.Tensor, Note: self_weight and neighbor_weights must be presented at the same time. """ - if (self_weight is None and neighbor_weights is not None) or \ - (self_weight is not None and neighbor_weights is None): - raise ValueError("Arguments self_weight and neighbor_weights have to be presented at " + # TODO(hanbinhu) #82 Symmetrical argument for self_weight, src_weights, dst_weights + if (self_weight is None and src_weights is not None) or \ + (self_weight is not None and src_weights is None): + raise ValueError("Arguments self_weight and src_weights have to be presented at " "the same time") - handle = neighbor_allreduce_nonblocking(tensor, self_weight, neighbor_weights, - send_neighbors, enable_topo_check, name) + handle = neighbor_allreduce_nonblocking(tensor, + self_weight=self_weight, + src_weights=src_weights, + dst_weights=dst_weights, + enable_topo_check=enable_topo_check, + name=name) return synchronize(handle) -def neighbor_allreduce_nonblocking(tensor: torch.Tensor, +def neighbor_allreduce_nonblocking(tensor: torch.Tensor, *, self_weight: Optional[float] = None, - neighbor_weights: Optional[Dict[int, float]] = None, - send_neighbors: Optional[List[int]] = None, + src_weights: Optional[Dict[int, float]] = None, + dst_weights: Optional[Union[Dict[int, float], List[int]]] = None, enable_topo_check: bool = True, name: Optional[str] = None) -> int: """ @@ -602,17 +604,17 @@ def neighbor_allreduce_nonblocking(tensor: torch.Tensor, Arguments: tensor: A tensor to execute weighted average with neighbors. self_weight: The weight for self node, used with neighbor_weights. - neighbor_weights: The weights for in-neighbor nodes, used with self weight. + src_weights: The weights for in-neighbor nodes, used with self weight. If neighbor_weights is presented, the return tensor will return the weighted average defined by these weights and the self_weight. If not, the return tensor will return the weighted average defined by the topology weights is provided or uniformly average. The data structure of weights should be {rank : weight} and rank has to belong to the (in-)neighbors. 
- send_neighbors: The list of neighbor nodes to be sent to. If set to be None, assume the + dst_weights: The weights for out-neighbor nodes. If set to be None, assume the the current node sends to all of its (out-)neighbors. If having values, assume only - part of (out-)neighbors will be sent to. In this mode, this node sends its value to - partial neighbors listed in this variable in a dynamic graph, and `self_weight` and - `neighbor_weights` must be present. + part of (out-)neighbors will be sent to. If set to be a list, assume all the weights + are one. In this mode, this node sends its value to partial neighbors listed in this + variable in a dynamic graph, and `self_weight` and `src_weights` must be present. enable_topo_check: When send_neighbors is present, enabling this option checks if the sending and recieving neighbors match with each other. Disabling this check can boost the performance. @@ -624,20 +626,21 @@ def neighbor_allreduce_nonblocking(tensor: torch.Tensor, Note: self_weight and neighbor_weights must be presented at the same time. """ - if (self_weight is None and neighbor_weights is not None) or \ - (self_weight is not None and neighbor_weights is None): - raise ValueError("Arguments self_weight and neighbor_weights have to be presented at " + # TODO(hanbinhu) #82 Symmetrical argument for self_weight, src_weights, dst_weights + if (self_weight is None and src_weights is not None) or \ + (self_weight is not None and src_weights is None): + raise ValueError("Arguments self_weight and src_weights have to be presented at " "the same time") - if send_neighbors is None: + if dst_weights is None: first_dim = tensor.shape[0] * len(in_neighbor_ranks()) else: - first_dim = tensor.shape[0] * len(neighbor_weights) + first_dim = tensor.shape[0] * len(src_weights) new_shape = torch.Size([first_dim] + list(tensor.shape[1:])) output = tensor.new(new_shape) # Pre-allocate the memory for the output. - return _neighbor_allreduce_nonblocking(tensor, output, self_weight, neighbor_weights, - send_neighbors, enable_topo_check, name=name) - + return _neighbor_allreduce_nonblocking(tensor, output, self_weight, src_weights, + dst_weights, enable_topo_check, name=name) +# TODO(hanbinhu) #81 Add dst_weight for hierarchical neighbor allreduce. 
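To make the renamed keyword-only interface concrete, a small usage sketch follows. It is not taken from this patch: the ranks, weights, and the assumption that the sliced ranks form a consistent send/receive pattern under the default topology are all illustrative.

import torch
import bluefog.torch as bf

bf.init()  # assumes a default topology has been set up
x = torch.randn(10)

# Static topology: weights come from the registered (possibly weighted) topology.
y = bf.neighbor_allreduce(x, name='x')

# Dynamic topology: src_weights controls how received tensors (plus self) are
# combined; dst_weights selects the out-neighbors to send to. A dict scales the
# outgoing tensor per destination, while a plain list means all send weights are 1.
recv_ranks = bf.in_neighbor_ranks()[:1]
send_ranks = bf.out_neighbor_ranks()[:1]
self_w = 1.0 / (len(recv_ranks) + 1)
z = bf.neighbor_allreduce(x,
                          self_weight=self_w,
                          src_weights={r: self_w for r in recv_ranks},
                          dst_weights={s: 0.5 for s in send_ranks},
                          enable_topo_check=False)

Passing a plain list for dst_weights is treated as all-ones weights, in which case the destination-weighting step is skipped; the wrapper decides this by checking the values with np.allclose before setting dst_weighting_enabled.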
def hierarchical_neighbor_allreduce(tensor: torch.Tensor, self_weight: float = None, neighbor_machine_weights: Dict[int, float] = None, diff --git a/bluefog/torch/optimizers.py b/bluefog/torch/optimizers.py index ceeb5a94..82fc677c 100644 --- a/bluefog/torch/optimizers.py +++ b/bluefog/torch/optimizers.py @@ -324,8 +324,8 @@ def __init__(self, params, model, communication_type, num_steps_per_communicatio named_parameters, models = _check_named_parameters(self, model) # knobs for neighbor communication behavior self.self_weight = None - self.neighbor_weights = None - self.send_neighbors = None + self.src_weights = None + self.dst_weights = None self.neighbor_machine_weights = None self.send_neighbor_machines = None self.enable_topo_check = False @@ -394,8 +394,8 @@ def hook(model, *unused): def _neighbor_allreduce_data_async(self, p): name = self._parameter_names.get(p) handle = bf.neighbor_allreduce_nonblocking(p.data, name=name, self_weight=self.self_weight, - neighbor_weights=self.neighbor_weights, - send_neighbors=self.send_neighbors, + src_weights=self.src_weights, + dst_weights=self.dst_weights, enable_topo_check=self.enable_topo_check) return handle @@ -489,8 +489,8 @@ def __init__(self, params, model, communication_type, backward_passes_per_step=1 named_parameters, models = _check_named_parameters(self, model) # knobs for neighbor communication behavior self.self_weight = None - self.neighbor_weights = None - self.send_neighbors = None + self.src_weights = None + self.dst_weights = None self.neighbor_machine_weights = None self.send_neighbor_machines = None self.enable_topo_check = False @@ -762,8 +762,8 @@ def _adadelta_step(self, p, grad, param_group): def _neighbor_allreduce_data_async(self, p): name = self._parameter_names.get(p) handle = bf.neighbor_allreduce_nonblocking(p.data, name=name, self_weight=self.self_weight, - neighbor_weights=self.neighbor_weights, - send_neighbors=self.send_neighbors, + src_weights=self.src_weights, + dst_weights=self.dst_weights, enable_topo_check=self.enable_topo_check) return handle diff --git a/examples/pytorch_average_consensus.py b/examples/pytorch_average_consensus.py index 0b1b2f1e..95711596 100644 --- a/examples/pytorch_average_consensus.py +++ b/examples/pytorch_average_consensus.py @@ -99,8 +99,8 @@ self_weight = 1 / (len(recv_neighbors) + 1) x = bf.neighbor_allreduce(x, name='x', self_weight=self_weight, - neighbor_weights=neighbor_weights, - send_neighbors=send_neighbors, enable_topo_check=False) + src_weights=neighbor_weights, + dst_weights=send_neighbors, enable_topo_check=False) mse.append(torch.norm(x-x_bar, p=2) / torch.norm(x_bar, p=2)) else: outdegree = len(bf.out_neighbor_ranks()) diff --git a/examples/pytorch_benchmark.py b/examples/pytorch_benchmark.py index 958e6571..64cf1eb2 100644 --- a/examples/pytorch_benchmark.py +++ b/examples/pytorch_benchmark.py @@ -186,8 +186,8 @@ def dynamic_topology_update(batch_idx): optimizer.dst_weights = {sent_neighbor: 1.0} elif args.dist_optimizer == 'neighbor_allreduce': send_neighbors, recv_neighbors = next(dynamic_neighbor_allreduce_gen) - optimizer.send_neighbors = send_neighbors - optimizer.neighbor_weights = { + optimizer.dst_weights = send_neighbors + optimizer.src_weights = { r: 1/(len(recv_neighbors) + 1) for r in recv_neighbors} optimizer.self_weight = 1 / (len(recv_neighbors) + 1) optimizer.enable_topo_check = False diff --git a/examples/pytorch_mnist.py b/examples/pytorch_mnist.py index 8315dffe..7643369c 100644 --- a/examples/pytorch_mnist.py +++ b/examples/pytorch_mnist.py @@ 
-215,8 +215,8 @@ def dynamic_topology_update(epoch, batch_idx): optimizer.dst_weights = {sent_neighbor: 1.0} elif args.dist_optimizer == 'neighbor_allreduce': send_neighbors, recv_neighbors = next(dynamic_neighbor_allreduce_gen) - optimizer.send_neighbors = send_neighbors - optimizer.neighbor_weights = {r: 1/(len(recv_neighbors) + 1) for r in recv_neighbors} + optimizer.dst_weights = send_neighbors + optimizer.src_weights = {r: 1/(len(recv_neighbors) + 1) for r in recv_neighbors} optimizer.self_weight = 1 / (len(recv_neighbors) + 1) optimizer.enable_topo_check = False elif args.dist_optimizer == 'hierarchical_neighbor_allreduce': diff --git a/examples/pytorch_optimization.py b/examples/pytorch_optimization.py index 058a1ca2..658cd485 100644 --- a/examples/pytorch_optimization.py +++ b/examples/pytorch_optimization.py @@ -15,6 +15,8 @@ import os import torch +import matplotlib +matplotlib.use('agg') # Make matplotlib more robust when interface plotting is impossible. import matplotlib.pyplot as plt import argparse @@ -199,11 +201,13 @@ def diffusion(X, y, w_opt, loss, maxite=2000, alpha=1e-1, **kwargs): loss_step(X, y, w, tensor_name='neighbor.allreduce.local_variable', loss=loss, rho=rho) + # diffusion with torch.no_grad(): - # diffusion phi = w - alpha * w.grad.data - w.data = bf.neighbor_allreduce( - phi, self_weight, neighbor_weights, name='local variable') + w.data = bf.neighbor_allreduce(phi, + self_weight=self_weight, + src_weights=neighbor_weights, + name='local variable') w.grad.data.zero_() # record convergence @@ -271,8 +275,10 @@ def exact_diffusion(X, y, w_opt, loss, maxite=2000, alpha=1e-1, use_Abar=True, * with torch.no_grad(): psi = w - alpha * w.grad.data phi = psi + w.data - psi_prev - w.data = bf.neighbor_allreduce( - phi, self_weight, neighbor_weights, name='local variable') + w.data = bf.neighbor_allreduce(phi, + self_weight=self_weight, + src_weights=neighbor_weights, + name='local variable') psi_prev = psi.clone() w.grad.data.zero_() @@ -330,8 +336,7 @@ def gradient_tracking(X, y, w_opt, loss, maxite=2000, alpha=1e-1, **kwargs): # q^{k+1} = neighbor_allreduce(q^k) + grad(w^{k+1}) - grad(w^k) # Notice the communication of neighbor_allreduce can overlap with gradient computation. 
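The comment above is the reason the nonblocking variant exists: the neighbor communication can be launched first, overlapped with local computation, and synchronized only when the result is needed. A compact sketch of that pattern follows; the variable names, step size, and the stand-in local computation are made up, not code from this patch.

import torch
import bluefog.torch as bf

bf.init()  # assumes a topology has been set up
alpha = 0.01
w = torch.randn(100)
q = torch.randn(100)

# Launch both communications first ...
w_handle = bf.neighbor_allreduce_nonblocking(w, name='sketch.w')
q_handle = bf.neighbor_allreduce_nonblocking(q, name='sketch.q')

# ... overlap them with local computation ...
local_term = alpha * q  # stand-in for the local gradient work

# ... and only block when the results are actually needed.
w = bf.synchronize(w_handle) - local_term
q = bf.synchronize(q_handle)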
- w_handle = bf.neighbor_allreduce_nonblocking( - w.data, name='Grad.Tracking.w') + w_handle = bf.neighbor_allreduce_nonblocking(w.data, name='Grad.Tracking.w') q_handle = bf.neighbor_allreduce_nonblocking(q, name='Grad.Tracking.q') w.data = bf.synchronize(w_handle) - alpha * q # calculate local gradient diff --git a/examples/pytorch_resnet.py b/examples/pytorch_resnet.py index b0ab4ce4..62f3eb01 100644 --- a/examples/pytorch_resnet.py +++ b/examples/pytorch_resnet.py @@ -361,8 +361,8 @@ def dynamic_topology_update(epoch, batch_idx): optimizer.dst_weights = {sent_neighbor: 1.0} elif args.dist_optimizer == 'neighbor_allreduce': send_neighbors, recv_neighbors = next(dynamic_neighbor_allreduce_gen) - optimizer.send_neighbors = send_neighbors - optimizer.neighbor_weights = {r: 1/(len(recv_neighbors) + 1) for r in recv_neighbors} + optimizer.dst_weights = send_neighbors + optimizer.src_weights = {r: 1/(len(recv_neighbors) + 1) for r in recv_neighbors} optimizer.self_weight = 1 / (len(recv_neighbors) + 1) optimizer.enable_topo_check = False elif args.dist_optimizer == 'hierarchical_neighbor_allreduce': diff --git a/setup.py b/setup.py index d06cb445..fa736de5 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,9 @@ import subprocess import sys import textwrap +import shlex import traceback +from typing import List from distutils.errors import CompileError, DistutilsError, \ @@ -437,16 +439,52 @@ def is_torch_cuda(build_ext, include_dirs, extra_compile_args): print('INFO: Above error indicates that this PyTorch installation does not support CUDA.') return False +def get_nvcc_cmd() -> str: + from shutil import which + nvcc_cmd = which('nvcc') + if nvcc_cmd is None: + raise DistutilsPlatformError('Unable to find NVCC compiler') + return nvcc_cmd + +def build_nvcc_extra_objects(nvcc_cmd: str, cxx11_abi: bool) -> List[str]: + # nvcc --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0' -rdc=true -c cuda_kernels.cu + # nvcc --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0' -dlink -o cuda_kernels_link.o \ + # cuda_kernels.o -lcudart + nvcc_flags = f'-fPIC -D_GLIBCXX_USE_CXX11_ABI={int(cxx11_abi)}' + + extra_object_dir = 'bluefog/common/cuda/' + source = extra_object_dir+'cuda_kernels.cu' + object_file = extra_object_dir+'cuda_kernels.o' + object_link = extra_object_dir+'cuda_kernels_link.o' + + command_object = [nvcc_cmd, '--compiler-options', nvcc_flags, + '-rdc=true', '-c', source, '-o', object_file] + command_link = [nvcc_cmd, '--compiler-options', nvcc_flags, + '-dlink', object_file, '-lcudart', '-o', object_link] + + command_object_str = ' '.join(shlex.quote(par) for par in command_object) + command_link_str = ' '.join(shlex.quote(par) for par in command_link) + + subprocess.check_call(command_object_str, shell=True) + subprocess.check_call(command_link_str, shell=True) + return [object_file, object_link] def build_torch_extension(build_ext, global_options, torch_version): # Backup the options, preventing other plugins access libs that # compiled with compiler of this plugin + import torch + is_cxx11_abi = torch.compiled_with_cxx11_abi() + options = copy.deepcopy(global_options) have_cuda = is_torch_cuda(build_ext, include_dirs=options['INCLUDES'], extra_compile_args=options['COMPILE_FLAGS']) if have_cuda: cuda_include_dirs, cuda_lib_dirs = get_cuda_dirs( build_ext, options['COMPILE_FLAGS']) + nvcc_cmd = get_nvcc_cmd() + cuda_extra_objects = build_nvcc_extra_objects(nvcc_cmd, is_cxx11_abi) + options['EXTRA_OBJECTS'] += cuda_extra_objects + options['INCLUDES'] += cuda_include_dirs 
options['LIBRARY_DIRS'] += cuda_lib_dirs options['LIBRARIES'] += ['cudart'] @@ -478,9 +516,8 @@ def build_torch_extension(build_ext, global_options, torch_version): updated_macros, 'TORCH_VERSION', str(torch_version)) # Always set _GLIBCXX_USE_CXX11_ABI, since PyTorch can only detect whether it was set to 1. - import torch updated_macros = set_macro(updated_macros, '_GLIBCXX_USE_CXX11_ABI', - str(int(torch.compiled_with_cxx11_abi()))) + str(int(is_cxx11_abi))) # PyTorch requires -DTORCH_API_INCLUDE_EXTENSION_H updated_macros = set_macro( @@ -504,6 +541,7 @@ def build_torch_extension(build_ext, global_options, torch_version): extra_compile_args=options['COMPILE_FLAGS'], extra_link_args=options['LINK_FLAGS'], library_dirs=options['LIBRARY_DIRS'], + extra_objects=options['EXTRA_OBJECTS'], libraries=options['LIBRARIES']) # Patch an existing bluefog_torch_mpi_lib extension object. diff --git a/test/test_all_example.sh b/test/test_all_example.sh new file mode 100755 index 00000000..1e2c92d9 --- /dev/null +++ b/test/test_all_example.sh @@ -0,0 +1,115 @@ +#!/bin/bash +NUM_PROC=4 +BFRUN="bfrun -np ${NUM_PROC}" +RUN_DIR="$( pwd )" +EXAMPLE_DIR="$( cd "$( dirname "$0" )" && pwd )/../examples" + +die() { echo >&2 -e "\nERROR: $@\n"; exit 1; } +check() { + timeout 2m "$@" >>/dev/null 2>&1; + local exit_code=$?; + [ $exit_code -eq 0 ] \ + && echo "Command [$*] succeed" \ + || die "Command [$*] failed with error code $exit_code"; +} + +# check GPU exists +nvidia-smi >>/dev/null 2>&1 +gpu_exit_code=$? +if [[ $gpu_exit_code -eq 0 ]]; then + isgpu=1 +else + isgpu=0 +fi + +if [[ $isgpu -eq 1 ]]; then + echo "GPU Detected" +else + echo "No GPU Detected." +fi + +# PyTorch Average Concensus Cases +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_average_consensus.py +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_average_consensus.py --enable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_average_consensus.py --asynchronous-mode +if [[ $isgpu -eq 1 ]]; then + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_average_consensus.py --no-cuda + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_average_consensus.py --no-cuda --enable-dynamic-topology + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_average_consensus.py --no-cuda --asynchronous-mode +fi + +# PyTorch Optimization Cases +[ -f "${RUN_DIR}/plot.png" ] && EXIST_PLOT_PNG=1 || EXIST_PLOT_PNG=0 +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_optimization.py --method=diffusion +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_optimization.py --method=exact_diffusion +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_optimization.py --method=gradient_tracking +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_optimization.py --method=push_diging +[ "${EXIST_PLOT_PNG}" == 0 ] && rm -f ${RUN_DIR}/plot.png + +# PyTorch MNIST Cases +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=gradient_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=neighbor_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=neighbor_allreduce --disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=neighbor_allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=neighbor_allreduce --atc-style 
--disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=win_put +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --dist-optimizer=win_put --disable-dynamic-topology +if [[ $isgpu -eq 1 ]]; then + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=gradient_allreduce + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=allreduce + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=allreduce --atc-style + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce --disable-dynamic-topology + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce --atc-style + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce --atc-style --disable-dynamic-topology + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=win_put + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_mnist.py --epochs=1 --no-cuda --dist-optimizer=win_put --disable-dynamic-topology +fi + +# PyTorch Benchmark Cases +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=gradient_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=neighbor_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=neighbor_allreduce --disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=neighbor_allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=neighbor_allreduce --atc-style --disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=win_put +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --dist-optimizer=win_put --disable-dynamic-topology +if [[ $isgpu -eq 1 ]]; then + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=gradient_allreduce + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=allreduce + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=allreduce --atc-style + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=neighbor_allreduce + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=neighbor_allreduce --disable-dynamic-topology + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=neighbor_allreduce --atc-style + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda 
--dist-optimizer=neighbor_allreduce --atc-style --disable-dynamic-topology + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=win_put + check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_benchmark.py --model=lenet --num-iters=1 --no-cuda --dist-optimizer=win_put --disable-dynamic-topology +fi + +# PyTorch ResNet Cases +[ -d "${RUN_DIR}/logs" ] && EXIST_LOGS_DIR=1 || EXIST_LOGS_DIR=0 +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=gradient_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=neighbor_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=neighbor_allreduce --disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=neighbor_allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=neighbor_allreduce --atc-style --disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=win_put +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --dist-optimizer=win_put --disable-dynamic-topology +if [[ $isgpu -eq 1 ]]; then +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=gradient_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce --disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce --atc-style +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=neighbor_allreduce --atc-style --disable-dynamic-topology +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=win_put +check ${BFRUN} python ${EXAMPLE_DIR}/pytorch_resnet.py --model=squeezenet1_0 --epochs=1 --no-cuda --dist-optimizer=win_put --disable-dynamic-topology +fi +[ "${EXIST_LOGS_DIR}" == 0 ] && rm -rf ${RUN_DIR}/logs \ No newline at end of file diff --git a/test/torch_ops_test.py b/test/torch_ops_test.py index 9be0502f..c8c8ca5a 100644 --- a/test/torch_ops_test.py +++ b/test/torch_ops_test.py @@ -171,6 +171,7 @@ def test_allreduce_sum(self): dims = [1, 2, 3] for dtype, dim in itertools.product(dtypes, dims): + torch.manual_seed(123456) tensor = torch.FloatTensor(*([23] * dim)).random_(-100, 100) tensor = self.cast_and_place(tensor, dtype) name = "allreduce_tensor_{}_{}".format(dim, dtype) @@ -387,7 +388,7 @@ def 
test_neighbor_allreduce_sum_precision(self): name = "neighbor_allreduce_{}_{}".format(dim, dtype) nw = {i: 1.0 for i in neighbor_ranks} reduced_tensor = bf.neighbor_allreduce(tensor, self_weight=1.0, - neighbor_weights=nw, name=name) + src_weights=nw, name=name) assert ( list(reduced_tensor.shape) == [23] * dim ), "bf.neighbor_allreduce (avg) produces incorrect reduced shape" @@ -450,7 +451,7 @@ def test_neighbor_allreduce_dynamic_topo_check(self): name = "neighbor_allreduce_{}_{}".format(dim, dtype) with pytest.raises(ValueError): bf.neighbor_allreduce(tensor, name=name, self_weight=self_weight, - neighbor_weights=neighbor_weights, send_neighbors=send_ranks) + src_weights=neighbor_weights, dst_weights=send_ranks) def test_neighbor_allreduce_dynamic_topo_outside_static_topo_move(self): """Test that the neighbor all reduce (move) 1D, 2D, 3D tensors correctly @@ -477,7 +478,7 @@ def test_neighbor_allreduce_dynamic_topo_outside_static_topo_move(self): name = "neighbor_allreduce_{}_{}".format(dim, dtype) reduced_tensor = bf.neighbor_allreduce( tensor, name=name, self_weight=self_weight, - neighbor_weights=neighbor_weights, send_neighbors=send_ranks) + src_weights=neighbor_weights, dst_weights=send_ranks) eps = EPSILON if tensor.dtype != torch.float16 else LOOSE_EPSILON tensor, reduced_tensor = self.convert_cpu_fp16_to_fp32(tensor, reduced_tensor) assert ( @@ -511,7 +512,7 @@ def test_neighbor_allreduce_dynamic_topo_move(self): name = "neighbor_allreduce_{}_{}".format(dim, dtype) reduced_tensor = bf.neighbor_allreduce( tensor, name=name, self_weight=self_weight, - neighbor_weights=neighbor_weights, send_neighbors=send_ranks) + src_weights=neighbor_weights, dst_weights=send_ranks) eps = EPSILON if tensor.dtype != torch.float16 else LOOSE_EPSILON tensor, reduced_tensor = self.convert_cpu_fp16_to_fp32(tensor, reduced_tensor) assert ( @@ -521,7 +522,7 @@ def test_neighbor_allreduce_dynamic_topo_move(self): (reduced_tensor.data - (rank-1) % size).abs().max() < eps ), "bf.neighbor_allreduce (move) produces incorrect reduced tensor" - @unittest.skip("Haven't fully clear on the usage due to sync issues. Temporarily disabled") + #@unittest.skip("Haven't fully clear on the usage due to sync issues. 
Temporarily disabled") def test_neighbor_allreduce_dynamic_topo_with_empty_send_neighbors(self): """Test that the neighbor all reduce (avg) 1D, 2D, 3D tensors correctly with empty send_neighbors.""" @@ -553,7 +554,7 @@ def test_neighbor_allreduce_dynamic_topo_with_empty_send_neighbors(self): name = "neighbor_allreduce_{}_{}".format(dim, dtype) reduced_tensor = bf.neighbor_allreduce( tensor, name=name, self_weight=self_weight, - neighbor_weights=neighbor_weights, send_neighbors=send_ranks) + src_weights=neighbor_weights, dst_weights=send_ranks) eps = EPSILON if tensor.dtype != torch.float16 else LOOSE_EPSILON tensor, reduced_tensor = self.convert_cpu_fp16_to_fp32(tensor, reduced_tensor) assert ( @@ -591,7 +592,7 @@ def test_neighbor_allreduce_dynamic_topo_avg(self): name = "neighbor_allreduce_{}_{}".format(dim, dtype) reduced_tensor = bf.neighbor_allreduce( tensor, name=name, self_weight=self_weight, - neighbor_weights=neighbor_weights, send_neighbors=send_ranks) + src_weights=neighbor_weights, dst_weights=send_ranks) eps = EPSILON if tensor.dtype != torch.float16 else LOOSE_EPSILON tensor, reduced_tensor = self.convert_cpu_fp16_to_fp32(tensor, reduced_tensor) assert ( @@ -818,7 +819,7 @@ def test_neighbor_allreduce_sum(self): tensor = self.cast_and_place(tensor, dtype) nw = {i: 1.0 for i in neighbor_ranks} reduced_tensor = bf.neighbor_allreduce(tensor, self_weight=1.0, - neighbor_weights=nw) + src_weights=nw) tensor, reduced_tensor = self.convert_cpu_fp16_to_fp32(tensor, reduced_tensor) assert ( list(reduced_tensor.shape) == [23] * dim @@ -830,6 +831,45 @@ def test_neighbor_allreduce_sum(self): reduced_tensor.data.max() == sum_value ), "bf.neighbor_allreduce (sum) produces incorrect reduced tensor" + def test_neighbor_allreduce_dst_weight(self): + """Test that the neighbor allreduce with destination weights works correctly.""" + size = bf.size() + rank = bf.rank() + if size <= 1: + fname = inspect.currentframe().f_code.co_name + warnings.warn("Skip {} due to size 1".format(fname)) + return + + dtypes = [torch.FloatTensor, torch.DoubleTensor] + if TEST_ON_GPU: + dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + + # By default, we use exponential two ring topology. 
+ num_indegree = int(np.ceil(np.log2(size))) + neighbor_ranks = [(rank - 2**i) % size for i in range(num_indegree)] + dst_weight = 0.5 + expected_value = (np.sum(neighbor_ranks) + rank)*dst_weight + + self_weight = 0.5 + src_weights = {i: 1.0 for i in neighbor_ranks} + dst_weights = {(rank + 2**i) % size: dst_weight for i in range(num_indegree)} + + dims = [1, 2, 3] + for dtype, dim in itertools.product(dtypes, dims): + tensor = torch.FloatTensor(*([23] * dim)).fill_(1).mul_(rank) + tensor = self.cast_and_place(tensor, dtype) + name = "neighbor_allreduce_{}_{}".format(dim, dtype) + reduced_tensor = bf.neighbor_allreduce( + tensor, name=name, self_weight=self_weight, + src_weights=src_weights, dst_weights=dst_weights) + eps = EPSILON if tensor.dtype != torch.float16 else LOOSE_EPSILON + assert ( + list(reduced_tensor.shape) == [23] * dim + ), "bf.neighbor_allreduce (dst_weight) produces incorrect reduced shape" + assert ( + (reduced_tensor - expected_value).abs().max() < eps + ), "bf.neighbor_allreduce (dst_weight) produces incorrect reduced tensor" + def test_neighbor_allreduce_weighted_avg(self): """Test that the neighbor all reduce (avg) 1D, 2D, 3D tensors correctly.""" size = bf.size() @@ -994,10 +1034,10 @@ def test_neighbor_allreduce_dynamic_topo_fusion(self): handle_1 = bf.neighbor_allreduce_nonblocking( tensor_1, name=name_1, self_weight=self_weight, - neighbor_weights=neighbor_weights, send_neighbors=send_ranks) + src_weights=neighbor_weights, dst_weights=send_ranks) handle_2 = bf.neighbor_allreduce_nonblocking( tensor_2, name=name_2, self_weight=self_weight, - neighbor_weights=neighbor_weights, send_neighbors=send_ranks) + src_weights=neighbor_weights, dst_weights=send_ranks) output_1 = bf.synchronize(handle_1) output_2 = bf.synchronize(handle_2) @@ -1020,6 +1060,59 @@ def test_neighbor_allreduce_dynamic_topo_fusion(self): sum_value).abs().max() < eps ), "bf.neighbor_allreduce_2 (fusion) produces incorrect reduced tensor" + def test_neighbor_allreduce_dst_weight_fusion(self): + """Test that neighbor allreduce works with destination weights under tensor fusion.""" + size = bf.size() + rank = bf.rank() + K = 50 # number of tensors sent in a short time + if size <= 1: + fname = inspect.currentframe().f_code.co_name + warnings.warn("Skip {} due to size 1".format(fname)) + return + + dtypes = [torch.FloatTensor, torch.DoubleTensor] + if TEST_ON_GPU: + dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor] + + # By default, we use exponential two ring topology.
+ num_indegree = int(np.ceil(np.log2(size))) + neighbor_ranks = [(rank - 2**i) % size for i in range(num_indegree)] + sum_value = (np.sum(neighbor_ranks)+rank)*K + + dst_weight = 0.5 + self_weight = dst_weight + src_weights = {i: 1.0 for i in neighbor_ranks} + dst_weights = {(rank + 2**i) % size: dst_weight for i in range(num_indegree)} + + dims = [1, 2, 3] + for dtype, dim in itertools.product(dtypes, dims): + tensor_list, handles, names = [], [], [] + for i in range(K): + tensor = torch.FloatTensor(*([23] * dim)).fill_(i + rank*K) + tensor = self.cast_and_place(tensor, dtype) + tensor_list.append(tensor) + names.append("index{}_{}_{}".format(i, dtype, dim)) + + for i in range(K): + handle = bf.neighbor_allreduce_nonblocking( + tensor_list[i], name=names[i], + self_weight=self_weight, src_weights=src_weights, dst_weights=dst_weights) + handles.append(handle) + + outputs = [] + for i in range(K): + output = bf.synchronize(handles[i]) + outputs.append(output) + + for i in range(K): + assert ( + list(outputs[i].shape) == [23] * dim + ), f"{names[i]} (fusion) produces incorrect reduced shape" + output_normalized = outputs[i]/dst_weight - i*(num_indegree+1) + assert ( + (output_normalized - sum_value).abs().max() < EPSILON + ), f"{names[i]} (fusion) produces incorrect reduced tensor" + def test_neighbor_allgather(self): """Test that the neighbor all gather 1D, 2D, 3D tensors correctly.""" size = bf.size() @@ -1115,7 +1208,7 @@ def test_neighbor_allgather_dynamic_variable_size(self): torch.ByteTensor, torch.CharTensor, torch.ShortTensor, torch.HalfTensor] if TEST_ON_GPU: dtypes += [torch.cuda.FloatTensor, torch.cuda.DoubleTensor] - + # Connect to all other ranks neighbor_ranks = [i for i in range(size) if i != rank] dims = [1, 2, 3] @@ -1130,10 +1223,11 @@ def test_neighbor_allgather_dynamic_variable_size(self): tensor = torch.FloatTensor( *([tensor_sizes[rank]] + [17] * (dim - 1))).fill_(1).mul_(rank) tensor = self.cast_and_place(tensor, dtype) - gathered = bf.neighbor_allgather(tensor, dst_ranks=neighbor_ranks, src_ranks=neighbor_ranks) + gathered = bf.neighbor_allgather( + tensor, dst_ranks=neighbor_ranks, src_ranks=neighbor_ranks) tensor, gathered = self.convert_cpu_fp16_to_fp32(tensor, gathered) - tensor_sizes[rank] = 0 # remove self-size since neighbor_allgather does not include self. + tensor_sizes[rank] = 0 # remove self since neighbor_allgather does not include self expected_size = sum(tensor_sizes) assert list(gathered.shape) == [expected_size] + [17] * (dim - 1) diff --git a/test/torch_optimizer_test.py b/test/torch_optimizer_test.py index 8da25053..7508a4b9 100644 --- a/test/torch_optimizer_test.py +++ b/test/torch_optimizer_test.py @@ -226,8 +226,8 @@ def dynamic_neighbor_allreduce_train(model, optimizer, dataloader, isCUDA, dynam model.train() for data, target in dataloader: send_neighbors, recv_neighbors = next(dynamic_topo_gen) - optimizer.send_neighbors = send_neighbors - optimizer.neighbor_weights = { + optimizer.dst_weights = send_neighbors + optimizer.src_weights = { r: 1/(len(recv_neighbors) + 1) for r in recv_neighbors} optimizer.self_weight = 1 / (len(recv_neighbors) + 1)
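Note on the renamed neighbor_allreduce arguments exercised by the tests above: neighbor_weights is now src_weights, send_neighbors is now dst_weights, and dst_weights may additionally carry a per-destination scaling factor applied to the outgoing tensor. The following is a minimal usage sketch, not part of the patch; it assumes bluefog has been launched with at least two processes (e.g. via bfrun), and the one-directional ring neighbors and the 0.5 scaling are purely illustrative:

    import torch
    import bluefog.torch as bf

    bf.init()
    rank, size = bf.rank(), bf.size()

    # Illustrative dynamic topology: receive from the left neighbor only,
    # send to the right neighbor only, scaling the outgoing data by 0.5.
    src = [(rank - 1) % size]
    dst = [(rank + 1) % size]

    tensor = torch.ones(23).mul_(rank)
    reduced = bf.neighbor_allreduce(
        tensor,
        self_weight=1.0 / (len(src) + 1),
        src_weights={r: 1.0 / (len(src) + 1) for r in src},
        dst_weights={r: 0.5 for r in dst})
    # With this weighting each rank receives 0.5 * left_rank from its left
    # neighbor and combines it as 0.5 * rank + 0.5 * (0.5 * left_rank),
    # mirroring the expectation used in test_neighbor_allreduce_dst_weight.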