From 1e73e758c38f7f1673f00082da997f7d3d86cefb Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Fri, 25 Jun 2021 16:16:20 +0000 Subject: [PATCH 01/12] remove norm(), avoid memcpy after allgather 1) Removing the norm computation in debug printing 2) Changing _all_gather to be sync op in fetch_sub_module Reason: the async version is not async at all, because each all_gather calls torch.cuda.synchronize() to guarantee previous communication op to be completed 3) Adding new function _allgather_params_split_launch the existing _allgather_params has explicit memcpy after the all-gather op. We can avoid the explicit memory copy at python side, to improve the performance. Known issue: the `torch.distributed.all_gather` will do implicit memcpy at the end of each `ncclAllgather`. --- .../runtime/zero/partition_parameters.py | 59 ++++++++++++++++++- deepspeed/runtime/zero/stage3.py | 4 +- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 04a82d177611..6b5caeef0a8e 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -634,7 +634,9 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): all_gather_list.append(param) if not async_op: - ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) + # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) + ret_value = self._allgather_params_split_launch(all_gather_list, hierarchy=hierarchy) + for param in all_gather_list: param.ds_status = ZeroParamStatus.AVAILABLE return ret_value @@ -836,6 +838,61 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) param.data = replicated_tensor.data return handle + + def _allgather_params_split_launch(self, param_list, hierarchy=0): + """ blocking call + avoid explicit memory copy in _allgather_params + """ + if len(param_list) == 0: + return + # collect local tensors and partition sizes + partition_sizes = [] + local_tensors = [] + for param in param_list: + partition_sizes.append(param.ds_tensor.ds_numel) + local_tensors.append(param.ds_tensor) + + # allocate memory for allgather params + allgather_params = [] + for psize in partition_sizes: + tensor_size = psize * self.world_size + flat_tensor = torch.empty(tensor_size, + dtype=param_list[0].dtype, + device=self.local_device).view(-1) + flat_tensor.requres_grad = False + allgather_params.append(flat_tensor) + + # launch + launch_handles = [] + # backend = get_backend(self.ds_process_group) + # with _batch_p2p_manager(backend): + for param_idx, param in enumerate(param_list): + output_list = [] + for i in range(self.world_size): + psize = partition_sizes[param_idx] + partition = allgather_params[param_idx].narrow(0, i * psize, psize) + output_list.append(partition) + + input_tensor = local_tensors[param_idx].view(-1) + h = torch.distributed.all_gather(output_list, + input_tensor, + group=self.ds_process_group, + async_op=True) + launch_handles.append(h) + + # Wait ensures the operation is enqueued, but not necessarily complete. 
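+        # By default, work.wait() on an async NCCL op only makes the current CUDA
+        # stream wait on the communication stream; it is not a host-side barrier.
+        # Completion is only guaranteed by the torch.cuda.synchronize() below.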
+ launch_handles[-1].wait() + + # assign to param.data (not copy) + for i, param in enumerate(param_list): + gathered_tensor = allgather_params[i] + param.data = gathered_tensor.narrow( + 0, 0, param.ds_numel).view(param.ds_shape).data + + # guarantee the communication to be completed + torch.cuda.synchronize() + + return None def _allgather_params(self, param_list, hierarchy=0): if len(param_list) == 0: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 548c38a072c3..4ad8fcbe7ee5 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -432,7 +432,7 @@ def fetch_sub_module(self, sub_module): self.hierarchy += 1 # parameters are partitioned and need to be allgathered - self._all_gather(partitioned_params, async_op=True) + self._all_gather(partitioned_params, async_op=False) # parameters are inflight and communication needs to be completed if partitioned_params or params_in_flight: @@ -441,7 +441,7 @@ def fetch_sub_module(self, sub_module): for _, param in sub_module.named_parameters(recurse=False): param.ds_status = ZeroParamStatus.AVAILABLE print_rank_0( - f"Param id {param.ds_id}, Shape {param.shape}, device {param.device} norm {param.norm()}", + f"Param id {param.ds_id}, Shape {param.shape}, device {param.device}", force=False) #print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}") From 67b3db3e43265994bb1c408ea82d168911c4ba02 Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Wed, 30 Jun 2021 19:36:20 +0000 Subject: [PATCH 02/12] WIP: wrapped ncclAllgather as customized op in DS micro benchmark shows the improvement of allgather a transformer layer with 9834560 elements in half precision is about 1.1ms on aws-p4d instance. --- csrc/communication/collective_comm.cpp | 215 +++++++++++++++++++++++++ op_builder/__init__.py | 4 +- op_builder/communication.py | 27 ++++ tests/benchmarks/allgather_bench.py | 200 +++++++++++++++++++++++ 4 files changed, 445 insertions(+), 1 deletion(-) create mode 100644 csrc/communication/collective_comm.cpp create mode 100644 op_builder/communication.py create mode 100644 tests/benchmarks/allgather_bench.py diff --git a/csrc/communication/collective_comm.cpp b/csrc/communication/collective_comm.cpp new file mode 100644 index 000000000000..505e8235fb4a --- /dev/null +++ b/csrc/communication/collective_comm.cpp @@ -0,0 +1,215 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int debug_flag = std::getenv("DS_DEBUG")? 
std::stoi(std::getenv("DS_DEBUG")): 0; + +// recording created ncclComm_t +// using processGroup Name as key +std::unordered_map group_communicators; + +// NCCL type typing +// copied from pytorch source code +std::map ncclDataType = { + {at::kChar, ncclInt8}, + {at::kByte, ncclUint8}, + {at::kFloat, ncclFloat}, + {at::kDouble, ncclDouble}, + {at::kInt, ncclInt32}, + {at::kLong, ncclInt64}, + {at::kHalf, ncclHalf}, + {at::kBool, ncclUint8}, +#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301 + {at::kBFloat16, ncclBfloat16}, +#endif +}; + +// Helper function that gets the data type and issues error if not supported +// from pytorch source code +ncclDataType_t getNcclDataType(at::ScalarType type) { + auto it = ncclDataType.find(type); + TORCH_CHECK( + it != ncclDataType.end(), + "Input tensor data type is not supported for NCCL process group: ", + type); + return it->second; +} + +void check_tensors(std::vector& output_tensors, + std::vector& input_tensors, + int world_size) { + if (input_tensors.size() == 0 || output_tensors.size() == 0) { + TORCH_CHECK(false, "output/input tensor list must be nonempty"); + } + if (output_tensors.size() != input_tensors.size()) { + TORCH_CHECK(false, "output and input tensors must have same size"); + } + + for (size_t i = 0; i < input_tensors.size(); ++i) { + auto out = output_tensors[i]; + auto in = input_tensors[i]; + if (out.numel() != in.numel() * world_size) { + std::stringstream ss; + ss << "output tensor numel != input tensor numel * world_size at" << i ; + TORCH_CHECK(false, ss.str()); + } + } + +} + +// rank0 create the ncclUniqueId +// broadcast using old ProcessGroupNCCL +// ncclCommInitRank with ncclUniqueId and same rank and world size from current +// ProcessGroupNCCL +// +// Note: reason for creating new ncclComm_t, ::c10d::ProcessGroupNCCL didn't expose +// APIs for getting communicator +ncclComm_t create_communicator(std::vector& input_tensors, + std::string& pg_name, + ::c10d::ProcessGroupNCCL& pg) { + int rank = pg.getRank(); + int world_size = pg.getSize(); + at::Tensor& first_tensor = input_tensors[0]; + auto device_idx = first_tensor.get_device(); + if (debug_flag) + printf("creating new communicator at device %ld\n", device_idx); + + // + ncclUniqueId nccl_id; + ncclComm_t nccl_comm; + + auto id_tensor_option = + torch::TensorOptions() + .dtype(torch::kUInt8) + .layout(torch::kStrided) // dense tensor + .requires_grad(false); + + std::vector bcast_tensor; + if (rank == 0) { + auto _result = ncclGetUniqueId(&nccl_id); + if (_result != ncclSuccess) { + TORCH_CHECK(false, "Getting nccl unique id failed"); + // it suppose to exit + } + id_tensor_option.device(torch::kCPU); + at::Tensor cpu_tensor = torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_(); + memcpy(cpu_tensor.data_ptr(), &nccl_id, sizeof(ncclUniqueId)); + + at::Tensor id_tensor = cpu_tensor.to(first_tensor.device()); + bcast_tensor.push_back(std::move(id_tensor)); + } else { + at::Tensor id_tensor = + torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_().to(first_tensor.device()); + bcast_tensor.push_back(std::move(id_tensor)); + } + if (debug_flag) + printf("rank %d, created tensor holder, device %ld, is_cuda %d \n", + rank, + device_idx, + bcast_tensor[0].is_cuda()); + + // bcast + { + at::cuda::CUDAGuard gpuGuard(device_idx); + // make sure the allocated tensors are ready + AT_CUDA_CHECK(cudaDeviceSynchronize()); + auto work = pg.broadcast(bcast_tensor); + // make sure the broadcast finished + AT_CUDA_CHECK(cudaDeviceSynchronize()); + } + + // if rank != 
0 + // then need to copy ncclUniqueId from bcast_tensor + if (rank != 0) { + auto cpu_tensor = bcast_tensor[0].to(at::kCPU); + std::memcpy(&nccl_id, cpu_tensor.data_ptr(), cpu_tensor.nbytes()); + } + + { + at::cuda::CUDAGuard gpuGuard(device_idx); + // init communicator and save + ncclCommInitRank(&nccl_comm, world_size, nccl_id, rank); + group_communicators[pg_name] = nccl_comm; + + if (debug_flag) printf("nccl_comm initialized at rank %d, device %ld\n", rank, device_idx); + } + + return nccl_comm; +} + +// get communicator from global map +// if not found, create a new one +ncclComm_t get_communicator(std::vector& input_tensors, + std::string& pg_name, ::c10d::ProcessGroupNCCL& pg) { + auto found = group_communicators.find(pg_name); + if (found == group_communicators.end()) { + return create_communicator(input_tensors, pg_name, pg); + } else { + return found->second; + } +} + +int launch_nccl_allgather(std::vector& output_tensors, + std::vector& input_tensors, + ncclComm_t comm) { + auto& first_input = input_tensors[0]; + auto device_idx = first_input.get_device(); + if (debug_flag) + printf("launching allgather op with number of tensors %lu, at device %ld \n", + input_tensors.size(), + device_idx); + + // this suppose to get the cuda stream specified by `with torch.cuda.stream(comm_stream): ...` + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_idx); + + ncclGroupStart(); + for (size_t i = 0; i < input_tensors.size(); ++i) { + at::Tensor& input = input_tensors[i]; + at::Tensor& output = output_tensors[i]; + ncclAllGather(input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + } + ncclGroupEnd(); + + return 0; +} + +int inplaceAllgather(std::vector& output_tensors, + std::vector& input_tensors, + ::c10d::ProcessGroupNCCL& pg, + std::string pg_name + ) { + // ::c10d::ProcessGroup& p_pg = pg; + if (debug_flag) + printf("inplaceAllgather:: process group rank %d, size %d, pg_name %s \n", + pg.getRank(), + pg.getSize(), + pg_name.c_str()); + + check_tensors(output_tensors, input_tensors, pg.getSize()); + + auto nccl_comm = get_communicator(input_tensors, pg_name, pg); + + int res = launch_nccl_allgather(output_tensors, input_tensors, nccl_comm); + + return res; +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("inplace_allgather", &inplaceAllgather, "inplace all-gather (without memcpy)"); +} diff --git a/op_builder/__init__.py b/op_builder/__init__.py index f19ed916c332..2a854ed663e2 100755 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -12,6 +12,7 @@ from .builder import get_default_compute_capatabilities from .transformer_inference import InferenceBuilder from .quantizer import QuantizerBuilder +from .communication import CommunicationBuilder # TODO: infer this list instead of hard coded # List of all available ops @@ -25,6 +26,7 @@ AsyncIOBuilder(), InferenceBuilder(), UtilsBuilder(), - QuantizerBuilder() + QuantizerBuilder(), + CommunicationBuilder() ] ALL_OPS = {op.name: op for op in __op_builders__} diff --git a/op_builder/communication.py b/op_builder/communication.py new file mode 100644 index 000000000000..aa5da88242ea --- /dev/null +++ b/op_builder/communication.py @@ -0,0 +1,27 @@ +""" +Customized wrap for collective communications +""" + +from .builder import OpBuilder +import os + +class CommunicationBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_COLL_COMM" + NAME = "communication" + + def __init__(self): + super().__init__(name=self.NAME) + + def absolute_name(self): + 
return f'deepspeed.ops.{self.NAME}_op' + + def include_paths(self): + NCCL_HOME = os.getenv('NCCL_HOME') + if not NCCL_HOME: + NCCL_HOME='/usr/local/cuda' + + inc_path = os.path.join(NCCL_HOME, 'include') + return [inc_path] + + def sources(self): + return ['csrc/communication/collective_comm.cpp'] \ No newline at end of file diff --git a/tests/benchmarks/allgather_bench.py b/tests/benchmarks/allgather_bench.py new file mode 100644 index 000000000000..b7a24246804d --- /dev/null +++ b/tests/benchmarks/allgather_bench.py @@ -0,0 +1,200 @@ +""" +Test command: +python3 -m torch.distributed.launch --nnodes=1 --nproc_per_node=2 +""" + +import argparse + +import torch +import math +from torch._C import device +from torch.autograd.grad_mode import F +import torch.distributed as dist +from torch.distributed.distributed_c10d import _get_default_group, _pg_names +from deepspeed.ops.op_builder import CommunicationBuilder +import time +import numpy as np + +ds_coll_comm = CommunicationBuilder().load() + + +def sizeof_dtype(dtype): + if dtype == torch.half: + return 2 + elif dtype == torch.float: + return 4 + else: + return None + + +def prepare_tensor(partition_sizes, world_size, device, dtype=torch.half): + # transformer layer structure with partitioned onto 8 GPUs + + output_tensors = [] + input_tensors = [] + + for _size in partition_sizes: + std = 1 / math.sqrt(_size) + input_t = torch.empty(_size, + dtype=dtype, + device=device).view(-1).uniform_(-std, + std) + output_t = torch.empty(_size * world_size, + dtype=dtype, + device=device).view(-1).uniform_(-std, + std) + input_tensors.append(input_t) + output_tensors.append(output_t) + return output_tensors, input_tensors + + +def _torch_allgather_once(output_tensors, + input_tensors, + partition_sizes, + rank, + world_size): + """""" + s = torch.cuda.Stream() + handles = [] + for part_idx, part_size in enumerate(partition_sizes): + output_t = output_tensors[part_idx] + input_t = input_tensors[part_idx] + + output_list = [] + for i in range(world_size): + out_tensor = output_t.narrow(0, i * part_size, part_size) + output_list.append(out_tensor) + + h = dist.all_gather(output_list, input_t, async_op=True) + handles.append(h) + + torch.cuda.synchronize() + + +def print_bw_rank0(partition_sizes, time_costs, dtype): + if dist.get_rank() == 0: + elem_size = sizeof_dtype(dtype) # in bytes + assert elem_size != None + + numel = sum(partition_sizes) + + avg_t = np.mean(time_costs) + bw = numel * elem_size * (dist.get_world_size() - 1) / 1e9 / avg_t + print(f'avg time {avg_t * 1e3} ms, bw {bw} GB/s') + + +def bench_torch_allgather(output_tensors, + input_tensors, + partition_sizes, + rank, + world_size, + warm_up=5, + repeat=10): + ts = [] + for i in range(warm_up + repeat): + s = time.time() + _torch_allgather_once(output_tensors, + input_tensors, + partition_sizes, + rank, + world_size) + e = time.time() + + if i >= warm_up: + ts.append(e - s) + + print_bw_rank0(partition_sizes, ts, input_tensors[0].dtype) + +def _custom_allgather_once(output_tensors, input_tensors, comm_stream): + default_pg = _get_default_group() + pg_name = _pg_names[default_pg] + + with torch.cuda.stream(comm_stream): + res = ds_coll_comm.inplace_allgather(output_tensors, + input_tensors, + default_pg, + pg_name) + comm_stream.synchronize() + return res + + +def bench_custom_allgather(output_tensors, + input_tensors, + partition_sizes, + rank, + world_size, + warm_up=5, + repeat=10): + """""" + comm_stream = torch.cuda.Stream(rank % torch.cuda.device_count()) + + ts = [] + for i in 
range(warm_up+repeat): + s = time.time() + _custom_allgather_once(output_tensors, input_tensors, comm_stream) + e = time.time() + + if i >= warm_up: + ts.append(e - s) + + print_bw_rank0(partition_sizes, ts, input_tensors[0].dtype) + +def print_rank0(msg): + if (dist.get_rank() == 0): + print(msg) + +def main(): + """""" + dist.init_process_group(backend='nccl') + + rank = dist.get_rank() + local_size = torch.cuda.device_count() + device_id = rank % local_size + world_size = dist.get_world_size() + torch.cuda.set_device(device_id) + + partition_sizes = [ + 2457600, + 960, + 819200, + 320, + 320, + 320, + 3276800, + 1280, + 3276800, + 320, + 320, + 320 + ] + output_tensors, input_tensors = prepare_tensor(partition_sizes, + dist.get_world_size(), f'cuda:{device_id}', + torch.half) + print_rank0('Using torch.distributed.allgather') + bench_torch_allgather(output_tensors, + input_tensors, + partition_sizes, + rank, + world_size) + + combined_tensor_torch = torch.cat(output_tensors) + out_sum_torch = combined_tensor_torch.sum() + print_rank0(f'output tensors sum {out_sum_torch}') + + # clean output tensors + for t in output_tensors: + t.zero_() + print_rank0('Using customized allgather') + bench_custom_allgather(output_tensors, + input_tensors, + partition_sizes, + rank, + world_size) + combined_tensor_custom = torch.cat(output_tensors) + out_sum_custom = combined_tensor_custom.sum() + print_rank0(f'output tensor sum {out_sum_custom}') + + print_rank0(f'allgather results are close {combined_tensor_custom.allclose(combined_tensor_torch)}') + +if __name__ == '__main__': + main() \ No newline at end of file From 70e681f0c30e0aaef025df2e542cb76f5c69efaa Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Wed, 30 Jun 2021 23:08:11 +0000 Subject: [PATCH 03/12] WIP: integrated into partition_parameters Performance improvement of 5.1B bert on aws-p4d: fwd: 300ms -> 200ms bwd: 680ms -> 610ms --- csrc/communication/collective_comm.cpp | 2 +- deepspeed/ops/communication/__init__.py | 1 + deepspeed/ops/communication/communication.py | 59 ++++++++++++++++++ .../runtime/zero/partition_parameters.py | 62 ++++++++++++++++--- tests/benchmarks/allgather_bench.py | 2 +- 5 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 deepspeed/ops/communication/__init__.py create mode 100644 deepspeed/ops/communication/communication.py diff --git a/csrc/communication/collective_comm.cpp b/csrc/communication/collective_comm.cpp index 505e8235fb4a..6e78250c3ef0 100644 --- a/csrc/communication/collective_comm.cpp +++ b/csrc/communication/collective_comm.cpp @@ -211,5 +211,5 @@ int inplaceAllgather(std::vector& output_tensors, PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("inplace_allgather", &inplaceAllgather, "inplace all-gather (without memcpy)"); + m.def("_inplace_allgather", &inplaceAllgather, "inplace all-gather (without memcpy)"); } diff --git a/deepspeed/ops/communication/__init__.py b/deepspeed/ops/communication/__init__.py new file mode 100644 index 000000000000..ddf050037733 --- /dev/null +++ b/deepspeed/ops/communication/__init__.py @@ -0,0 +1 @@ +from .communication import inplace_allgather \ No newline at end of file diff --git a/deepspeed/ops/communication/communication.py b/deepspeed/ops/communication/communication.py new file mode 100644 index 000000000000..331d874ccec3 --- /dev/null +++ b/deepspeed/ops/communication/communication.py @@ -0,0 +1,59 @@ +import torch + +from torch.distributed.distributed_c10d import _pg_names, _get_default_group +import torch.distributed as dist + +from 
..op_builder import CommunicationBuilder + +coll_comm_module = None + + +class CommunicationHandle: + def __init__(self, + start_event: torch.cuda.Event, + end_event: torch.cuda.Event, + timing: bool) -> None: + + self.start = start_event + self.end = end_event + self.timing = timing + + def is_completed(self, ): + return self.end.query() + + def wait(self, stream): + self.end.wait(stream) + + def synchronize(self): + self.end.synchronize() + +def map_process_group(group): + # print(f'rank {dist.get_rank(group=group)}, _pg_names {_pg_names}, _pg_group_ranks {_pg_group_ranks}') + if group == dist.group.WORLD: + return _get_default_group() + else: + return group + + +def inplace_allgather(output_tensors, input_tensors, group, comm_stream, timing=False): + """""" + global coll_comm_module + if coll_comm_module is None: + coll_comm_module = CommunicationBuilder().load() + + group = map_process_group(group) + process_group_name = _pg_names[group] + + start_event = torch.cuda.Event(enable_timing=timing) + end_event = torch.cuda.Event(enable_timing=timing) + if timing: + start_event.record(comm_stream) + with torch.cuda.stream(comm_stream): + coll_comm_module._inplace_allgather(output_tensors, + input_tensors, + group, + process_group_name) + + end_event.record(comm_stream) + + return CommunicationHandle(start_event, end_event, timing) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 6b5caeef0a8e..5579309c8544 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -10,6 +10,8 @@ import functools import itertools +from deepspeed.ops.communication import communication as ds_comm + import torch from torch.distributed.distributed_c10d import _get_global_rank @@ -635,7 +637,8 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): if not async_op: # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) - ret_value = self._allgather_params_split_launch(all_gather_list, hierarchy=hierarchy) + # ret_value = self._allgather_params_split_launch(all_gather_list, hierarchy=hierarchy) + ret_value = self._allgather_params_with_custom_op(all_gather_list, hierarchy) for param in all_gather_list: param.ds_status = ZeroParamStatus.AVAILABLE @@ -838,7 +841,52 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) param.data = replicated_tensor.data return handle - + + def _allgather_params_with_custom_op(self, param_list, hierarchy=0): + """ using customized allgather op to avoid redundant cudaMemcpy + Note: the torch.distributed.allgather has extra copy: + https://github.com/pytorch/pytorch/blob/v1.9.0/torch/lib/c10d/ProcessGroupNCCL.cpp#L1469 + """ + if len(param_list) == 0: + return + # collect local tensors and partition sizes + partition_sizes = [] + local_tensors = [] + for param in param_list: + partition_sizes.append(param.ds_tensor.ds_numel) + local_tensors.append(param.ds_tensor) + + # allocate memory for allgather params + allgather_output_params = [] + for psize in partition_sizes: + tensor_size = psize * self.world_size + flat_tensor = torch.empty(tensor_size, + dtype=param_list[0].dtype, + device=self.local_device).view(-1) + flat_tensor.requres_grad = False + allgather_output_params.append(flat_tensor) + + # suppose to set the communication stream outside of this function + comm_stream = torch.cuda.current_stream() + + # the handle is a wrapper of the 
start and the end events + comm_handle = ds_comm.inplace_allgather(allgather_output_params, + local_tensors, + self.ds_process_group, + comm_stream) + + # assign to param.data (not copy) + for i, param in enumerate(param_list): + gathered_tensor = allgather_output_params[i] + param.data = gathered_tensor.narrow( + 0, 0, param.ds_numel).view(param.ds_shape).data + + # this synchronize on cuda.Event + comm_handle.synchronize() + + return None + + def _allgather_params_split_launch(self, param_list, hierarchy=0): """ blocking call avoid explicit memory copy in _allgather_params @@ -851,7 +899,7 @@ def _allgather_params_split_launch(self, param_list, hierarchy=0): for param in param_list: partition_sizes.append(param.ds_tensor.ds_numel) local_tensors.append(param.ds_tensor) - + # allocate memory for allgather params allgather_params = [] for psize in partition_sizes: @@ -861,8 +909,8 @@ def _allgather_params_split_launch(self, param_list, hierarchy=0): device=self.local_device).view(-1) flat_tensor.requres_grad = False allgather_params.append(flat_tensor) - - # launch + + # launch launch_handles = [] # backend = get_backend(self.ds_process_group) # with _batch_p2p_manager(backend): @@ -874,12 +922,12 @@ def _allgather_params_split_launch(self, param_list, hierarchy=0): output_list.append(partition) input_tensor = local_tensors[param_idx].view(-1) - h = torch.distributed.all_gather(output_list, + h = torch.distributed.all_gather(output_list, input_tensor, group=self.ds_process_group, async_op=True) launch_handles.append(h) - + # Wait ensures the operation is enqueued, but not necessarily complete. launch_handles[-1].wait() diff --git a/tests/benchmarks/allgather_bench.py b/tests/benchmarks/allgather_bench.py index b7a24246804d..48a7b380d836 100644 --- a/tests/benchmarks/allgather_bench.py +++ b/tests/benchmarks/allgather_bench.py @@ -110,7 +110,7 @@ def _custom_allgather_once(output_tensors, input_tensors, comm_stream): pg_name = _pg_names[default_pg] with torch.cuda.stream(comm_stream): - res = ds_coll_comm.inplace_allgather(output_tensors, + res = ds_coll_comm._inplace_allgather(output_tensors, input_tensors, default_pg, pg_name) From 81b4fc4a42a3a8fe7119cad425838f500d05fcc7 Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Thu, 1 Jul 2021 17:35:41 +0000 Subject: [PATCH 04/12] Fix format --- csrc/communication/collective_comm.cpp | 301 +++++++++--------- deepspeed/ops/communication/__init__.py | 2 +- deepspeed/ops/communication/communication.py | 3 +- .../runtime/zero/partition_parameters.py | 27 +- op_builder/communication.py | 7 +- tests/benchmarks/allgather_bench.py | 22 +- 6 files changed, 186 insertions(+), 176 deletions(-) diff --git a/csrc/communication/collective_comm.cpp b/csrc/communication/collective_comm.cpp index 6e78250c3ef0..5423533fb89c 100644 --- a/csrc/communication/collective_comm.cpp +++ b/csrc/communication/collective_comm.cpp @@ -1,23 +1,23 @@ #include -#include +#include +#include #include -#include +#include +#include #include #include -#include +#include #include #include -#include -#include -int debug_flag = std::getenv("DS_DEBUG")? std::stoi(std::getenv("DS_DEBUG")): 0; +int debug_flag = std::getenv("DS_DEBUG") ? 
std::stoi(std::getenv("DS_DEBUG")) : 0; // recording created ncclComm_t // using processGroup Name as key std::unordered_map group_communicators; -// NCCL type typing +// NCCL type typing // copied from pytorch source code std::map ncclDataType = { {at::kChar, ncclInt8}, @@ -35,180 +35,181 @@ std::map ncclDataType = { // Helper function that gets the data type and issues error if not supported // from pytorch source code -ncclDataType_t getNcclDataType(at::ScalarType type) { - auto it = ncclDataType.find(type); - TORCH_CHECK( - it != ncclDataType.end(), - "Input tensor data type is not supported for NCCL process group: ", - type); - return it->second; +ncclDataType_t getNcclDataType(at::ScalarType type) +{ + auto it = ncclDataType.find(type); + TORCH_CHECK(it != ncclDataType.end(), + "Input tensor data type is not supported for NCCL process group: ", + type); + return it->second; } void check_tensors(std::vector& output_tensors, - std::vector& input_tensors, - int world_size) { - if (input_tensors.size() == 0 || output_tensors.size() == 0) { - TORCH_CHECK(false, "output/input tensor list must be nonempty"); - } - if (output_tensors.size() != input_tensors.size()) { - TORCH_CHECK(false, "output and input tensors must have same size"); - } - - for (size_t i = 0; i < input_tensors.size(); ++i) { - auto out = output_tensors[i]; - auto in = input_tensors[i]; - if (out.numel() != in.numel() * world_size) { - std::stringstream ss; - ss << "output tensor numel != input tensor numel * world_size at" << i ; - TORCH_CHECK(false, ss.str()); + std::vector& input_tensors, + int world_size) +{ + if (input_tensors.size() == 0 || output_tensors.size() == 0) { + TORCH_CHECK(false, "output/input tensor list must be nonempty"); + } + if (output_tensors.size() != input_tensors.size()) { + TORCH_CHECK(false, "output and input tensors must have same size"); } - } + for (size_t i = 0; i < input_tensors.size(); ++i) { + auto out = output_tensors[i]; + auto in = input_tensors[i]; + if (out.numel() != in.numel() * world_size) { + std::stringstream ss; + ss << "output tensor numel != input tensor numel * world_size at" << i; + TORCH_CHECK(false, ss.str()); + } + } } -// rank0 create the ncclUniqueId +// rank0 create the ncclUniqueId // broadcast using old ProcessGroupNCCL // ncclCommInitRank with ncclUniqueId and same rank and world size from current // ProcessGroupNCCL -// +// // Note: reason for creating new ncclComm_t, ::c10d::ProcessGroupNCCL didn't expose // APIs for getting communicator ncclComm_t create_communicator(std::vector& input_tensors, - std::string& pg_name, - ::c10d::ProcessGroupNCCL& pg) { - int rank = pg.getRank(); - int world_size = pg.getSize(); - at::Tensor& first_tensor = input_tensors[0]; - auto device_idx = first_tensor.get_device(); - if (debug_flag) - printf("creating new communicator at device %ld\n", device_idx); - - // - ncclUniqueId nccl_id; - ncclComm_t nccl_comm; - - auto id_tensor_option = - torch::TensorOptions() - .dtype(torch::kUInt8) - .layout(torch::kStrided) // dense tensor - .requires_grad(false); - - std::vector bcast_tensor; - if (rank == 0) { - auto _result = ncclGetUniqueId(&nccl_id); - if (_result != ncclSuccess) { - TORCH_CHECK(false, "Getting nccl unique id failed"); - // it suppose to exit - } - id_tensor_option.device(torch::kCPU); - at::Tensor cpu_tensor = torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_(); - memcpy(cpu_tensor.data_ptr(), &nccl_id, sizeof(ncclUniqueId)); - - at::Tensor id_tensor = cpu_tensor.to(first_tensor.device()); - 
bcast_tensor.push_back(std::move(id_tensor)); - } else { - at::Tensor id_tensor = - torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_().to(first_tensor.device()); - bcast_tensor.push_back(std::move(id_tensor)); - } - if (debug_flag) - printf("rank %d, created tensor holder, device %ld, is_cuda %d \n", - rank, - device_idx, - bcast_tensor[0].is_cuda()); - - // bcast - { - at::cuda::CUDAGuard gpuGuard(device_idx); - // make sure the allocated tensors are ready - AT_CUDA_CHECK(cudaDeviceSynchronize()); - auto work = pg.broadcast(bcast_tensor); - // make sure the broadcast finished - AT_CUDA_CHECK(cudaDeviceSynchronize()); - } - - // if rank != 0 - // then need to copy ncclUniqueId from bcast_tensor - if (rank != 0) { - auto cpu_tensor = bcast_tensor[0].to(at::kCPU); - std::memcpy(&nccl_id, cpu_tensor.data_ptr(), cpu_tensor.nbytes()); - } - - { - at::cuda::CUDAGuard gpuGuard(device_idx); - // init communicator and save - ncclCommInitRank(&nccl_comm, world_size, nccl_id, rank); - group_communicators[pg_name] = nccl_comm; - - if (debug_flag) printf("nccl_comm initialized at rank %d, device %ld\n", rank, device_idx); - } - - return nccl_comm; + std::string& pg_name, + ::c10d::ProcessGroupNCCL& pg) +{ + int rank = pg.getRank(); + int world_size = pg.getSize(); + at::Tensor& first_tensor = input_tensors[0]; + auto device_idx = first_tensor.get_device(); + if (debug_flag) printf("creating new communicator at device %ld\n", device_idx); + + // + ncclUniqueId nccl_id; + ncclComm_t nccl_comm; + + auto id_tensor_option = torch::TensorOptions() + .dtype(torch::kUInt8) + .layout(torch::kStrided) // dense tensor + .requires_grad(false); + + std::vector bcast_tensor; + if (rank == 0) { + auto _result = ncclGetUniqueId(&nccl_id); + if (_result != ncclSuccess) { + TORCH_CHECK(false, "Getting nccl unique id failed"); + // it suppose to exit + } + id_tensor_option.device(torch::kCPU); + at::Tensor cpu_tensor = torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_(); + memcpy(cpu_tensor.data_ptr(), &nccl_id, sizeof(ncclUniqueId)); + + at::Tensor id_tensor = cpu_tensor.to(first_tensor.device()); + bcast_tensor.push_back(std::move(id_tensor)); + } else { + at::Tensor id_tensor = + torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_().to(first_tensor.device()); + bcast_tensor.push_back(std::move(id_tensor)); + } + if (debug_flag) + printf("rank %d, created tensor holder, device %ld, is_cuda %d \n", + rank, + device_idx, + bcast_tensor[0].is_cuda()); + + // bcast + { + at::cuda::CUDAGuard gpuGuard(device_idx); + // make sure the allocated tensors are ready + AT_CUDA_CHECK(cudaDeviceSynchronize()); + auto work = pg.broadcast(bcast_tensor); + // make sure the broadcast finished + AT_CUDA_CHECK(cudaDeviceSynchronize()); + } + + // if rank != 0 + // then need to copy ncclUniqueId from bcast_tensor + if (rank != 0) { + auto cpu_tensor = bcast_tensor[0].to(at::kCPU); + std::memcpy(&nccl_id, cpu_tensor.data_ptr(), cpu_tensor.nbytes()); + } + + { + at::cuda::CUDAGuard gpuGuard(device_idx); + // init communicator and save + ncclCommInitRank(&nccl_comm, world_size, nccl_id, rank); + group_communicators[pg_name] = nccl_comm; + + if (debug_flag) printf("nccl_comm initialized at rank %d, device %ld\n", rank, device_idx); + } + + return nccl_comm; } // get communicator from global map // if not found, create a new one -ncclComm_t get_communicator(std::vector& input_tensors, - std::string& pg_name, ::c10d::ProcessGroupNCCL& pg) { - auto found = group_communicators.find(pg_name); - if (found == 
group_communicators.end()) { - return create_communicator(input_tensors, pg_name, pg); - } else { - return found->second; - } +ncclComm_t get_communicator(std::vector& input_tensors, + std::string& pg_name, + ::c10d::ProcessGroupNCCL& pg) +{ + auto found = group_communicators.find(pg_name); + if (found == group_communicators.end()) { + return create_communicator(input_tensors, pg_name, pg); + } else { + return found->second; + } } int launch_nccl_allgather(std::vector& output_tensors, - std::vector& input_tensors, - ncclComm_t comm) { - auto& first_input = input_tensors[0]; - auto device_idx = first_input.get_device(); - if (debug_flag) - printf("launching allgather op with number of tensors %lu, at device %ld \n", - input_tensors.size(), - device_idx); - - // this suppose to get the cuda stream specified by `with torch.cuda.stream(comm_stream): ...` - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_idx); - - ncclGroupStart(); - for (size_t i = 0; i < input_tensors.size(); ++i) { - at::Tensor& input = input_tensors[i]; - at::Tensor& output = output_tensors[i]; - ncclAllGather(input.data_ptr(), - output.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - comm, - stream.stream()); - } - ncclGroupEnd(); - - return 0; + std::vector& input_tensors, + ncclComm_t comm) +{ + auto& first_input = input_tensors[0]; + auto device_idx = first_input.get_device(); + if (debug_flag) + printf("launching allgather op with number of tensors %lu, at device %ld \n", + input_tensors.size(), + device_idx); + + // this suppose to get the cuda stream specified by `with torch.cuda.stream(comm_stream): ...` + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_idx); + + ncclGroupStart(); + for (size_t i = 0; i < input_tensors.size(); ++i) { + at::Tensor& input = input_tensors[i]; + at::Tensor& output = output_tensors[i]; + ncclAllGather(input.data_ptr(), + output.data_ptr(), + input.numel(), + getNcclDataType(input.scalar_type()), + comm, + stream.stream()); + } + ncclGroupEnd(); + + return 0; } int inplaceAllgather(std::vector& output_tensors, std::vector& input_tensors, ::c10d::ProcessGroupNCCL& pg, - std::string pg_name - ) { - // ::c10d::ProcessGroup& p_pg = pg; - if (debug_flag) - printf("inplaceAllgather:: process group rank %d, size %d, pg_name %s \n", - pg.getRank(), - pg.getSize(), - pg_name.c_str()); + std::string pg_name) +{ + // ::c10d::ProcessGroup& p_pg = pg; + if (debug_flag) + printf("inplaceAllgather:: process group rank %d, size %d, pg_name %s \n", + pg.getRank(), + pg.getSize(), + pg_name.c_str()); - check_tensors(output_tensors, input_tensors, pg.getSize()); + check_tensors(output_tensors, input_tensors, pg.getSize()); - auto nccl_comm = get_communicator(input_tensors, pg_name, pg); + auto nccl_comm = get_communicator(input_tensors, pg_name, pg); - int res = launch_nccl_allgather(output_tensors, input_tensors, nccl_comm); + int res = launch_nccl_allgather(output_tensors, input_tensors, nccl_comm); - return res; + return res; } - PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("_inplace_allgather", &inplaceAllgather, "inplace all-gather (without memcpy)"); diff --git a/deepspeed/ops/communication/__init__.py b/deepspeed/ops/communication/__init__.py index ddf050037733..1e03579753f9 100644 --- a/deepspeed/ops/communication/__init__.py +++ b/deepspeed/ops/communication/__init__.py @@ -1 +1 @@ -from .communication import inplace_allgather \ No newline at end of file +from .communication import inplace_allgather diff --git 
a/deepspeed/ops/communication/communication.py b/deepspeed/ops/communication/communication.py index 331d874ccec3..dfaa785a2b68 100644 --- a/deepspeed/ops/communication/communication.py +++ b/deepspeed/ops/communication/communication.py @@ -27,6 +27,7 @@ def wait(self, stream): def synchronize(self): self.end.synchronize() + def map_process_group(group): # print(f'rank {dist.get_rank(group=group)}, _pg_names {_pg_names}, _pg_group_ranks {_pg_group_ranks}') if group == dist.group.WORLD: @@ -41,7 +42,7 @@ def inplace_allgather(output_tensors, input_tensors, group, comm_stream, timing= if coll_comm_module is None: coll_comm_module = CommunicationBuilder().load() - group = map_process_group(group) + group = map_process_group(group) process_group_name = _pg_names[group] start_event = torch.cuda.Event(enable_timing=timing) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 5579309c8544..68371dce4f5b 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -844,7 +844,7 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): def _allgather_params_with_custom_op(self, param_list, hierarchy=0): """ using customized allgather op to avoid redundant cudaMemcpy - Note: the torch.distributed.allgather has extra copy: + Note: the torch.distributed.allgather has extra copy: https://github.com/pytorch/pytorch/blob/v1.9.0/torch/lib/c10d/ProcessGroupNCCL.cpp#L1469 """ if len(param_list) == 0: @@ -861,8 +861,8 @@ def _allgather_params_with_custom_op(self, param_list, hierarchy=0): for psize in partition_sizes: tensor_size = psize * self.world_size flat_tensor = torch.empty(tensor_size, - dtype=param_list[0].dtype, - device=self.local_device).view(-1) + dtype=param_list[0].dtype, + device=self.local_device).view(-1) flat_tensor.requres_grad = False allgather_output_params.append(flat_tensor) @@ -878,15 +878,15 @@ def _allgather_params_with_custom_op(self, param_list, hierarchy=0): # assign to param.data (not copy) for i, param in enumerate(param_list): gathered_tensor = allgather_output_params[i] - param.data = gathered_tensor.narrow( - 0, 0, param.ds_numel).view(param.ds_shape).data + param.data = gathered_tensor.narrow(0, + 0, + param.ds_numel).view(param.ds_shape).data # this synchronize on cuda.Event comm_handle.synchronize() return None - def _allgather_params_split_launch(self, param_list, hierarchy=0): """ blocking call avoid explicit memory copy in _allgather_params @@ -905,8 +905,8 @@ def _allgather_params_split_launch(self, param_list, hierarchy=0): for psize in partition_sizes: tensor_size = psize * self.world_size flat_tensor = torch.empty(tensor_size, - dtype=param_list[0].dtype, - device=self.local_device).view(-1) + dtype=param_list[0].dtype, + device=self.local_device).view(-1) flat_tensor.requres_grad = False allgather_params.append(flat_tensor) @@ -923,9 +923,9 @@ def _allgather_params_split_launch(self, param_list, hierarchy=0): input_tensor = local_tensors[param_idx].view(-1) h = torch.distributed.all_gather(output_list, - input_tensor, - group=self.ds_process_group, - async_op=True) + input_tensor, + group=self.ds_process_group, + async_op=True) launch_handles.append(h) # Wait ensures the operation is enqueued, but not necessarily complete. 
@@ -934,8 +934,9 @@ def _allgather_params_split_launch(self, param_list, hierarchy=0): # assign to param.data (not copy) for i, param in enumerate(param_list): gathered_tensor = allgather_params[i] - param.data = gathered_tensor.narrow( - 0, 0, param.ds_numel).view(param.ds_shape).data + param.data = gathered_tensor.narrow(0, + 0, + param.ds_numel).view(param.ds_shape).data # guarantee the communication to be completed torch.cuda.synchronize() diff --git a/op_builder/communication.py b/op_builder/communication.py index aa5da88242ea..81fdb28d6dc4 100644 --- a/op_builder/communication.py +++ b/op_builder/communication.py @@ -5,10 +5,11 @@ from .builder import OpBuilder import os + class CommunicationBuilder(OpBuilder): BUILD_VAR = "DS_BUILD_COLL_COMM" NAME = "communication" - + def __init__(self): super().__init__(name=self.NAME) @@ -18,10 +19,10 @@ def absolute_name(self): def include_paths(self): NCCL_HOME = os.getenv('NCCL_HOME') if not NCCL_HOME: - NCCL_HOME='/usr/local/cuda' + NCCL_HOME = '/usr/local/cuda' inc_path = os.path.join(NCCL_HOME, 'include') return [inc_path] def sources(self): - return ['csrc/communication/collective_comm.cpp'] \ No newline at end of file + return ['csrc/communication/collective_comm.cpp'] diff --git a/tests/benchmarks/allgather_bench.py b/tests/benchmarks/allgather_bench.py index 48a7b380d836..744da9961bad 100644 --- a/tests/benchmarks/allgather_bench.py +++ b/tests/benchmarks/allgather_bench.py @@ -1,5 +1,5 @@ """ -Test command: +Test command: python3 -m torch.distributed.launch --nnodes=1 --nproc_per_node=2 """ @@ -105,15 +105,16 @@ def bench_torch_allgather(output_tensors, print_bw_rank0(partition_sizes, ts, input_tensors[0].dtype) + def _custom_allgather_once(output_tensors, input_tensors, comm_stream): default_pg = _get_default_group() pg_name = _pg_names[default_pg] with torch.cuda.stream(comm_stream): res = ds_coll_comm._inplace_allgather(output_tensors, - input_tensors, - default_pg, - pg_name) + input_tensors, + default_pg, + pg_name) comm_stream.synchronize() return res @@ -129,20 +130,22 @@ def bench_custom_allgather(output_tensors, comm_stream = torch.cuda.Stream(rank % torch.cuda.device_count()) ts = [] - for i in range(warm_up+repeat): + for i in range(warm_up + repeat): s = time.time() _custom_allgather_once(output_tensors, input_tensors, comm_stream) e = time.time() if i >= warm_up: ts.append(e - s) - + print_bw_rank0(partition_sizes, ts, input_tensors[0].dtype) + def print_rank0(msg): if (dist.get_rank() == 0): print(msg) + def main(): """""" dist.init_process_group(backend='nccl') @@ -194,7 +197,10 @@ def main(): out_sum_custom = combined_tensor_custom.sum() print_rank0(f'output tensor sum {out_sum_custom}') - print_rank0(f'allgather results are close {combined_tensor_custom.allclose(combined_tensor_torch)}') + print_rank0( + f'allgather results are close {combined_tensor_custom.allclose(combined_tensor_torch)}' + ) + if __name__ == '__main__': - main() \ No newline at end of file + main() From 32c8fa72189d760828d95411104f3d54edfa377a Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Tue, 6 Jul 2021 16:37:10 +0000 Subject: [PATCH 05/12] cleaned dead code, modified unit test --- .../runtime/zero/partition_parameters.py | 57 ------------------- tests/benchmarks/allgather_bench.py | 14 +---- 2 files changed, 3 insertions(+), 68 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 6d07dbe3844e..d92ec08d94c3 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ 
b/deepspeed/runtime/zero/partition_parameters.py @@ -644,7 +644,6 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): if not async_op: # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) - # ret_value = self._allgather_params_split_launch(all_gather_list, hierarchy=hierarchy) ret_value = self._allgather_params_with_custom_op(all_gather_list, hierarchy) for param in all_gather_list: @@ -894,62 +893,6 @@ def _allgather_params_with_custom_op(self, param_list, hierarchy=0): return None - def _allgather_params_split_launch(self, param_list, hierarchy=0): - """ blocking call - avoid explicit memory copy in _allgather_params - """ - if len(param_list) == 0: - return - # collect local tensors and partition sizes - partition_sizes = [] - local_tensors = [] - for param in param_list: - partition_sizes.append(param.ds_tensor.ds_numel) - local_tensors.append(param.ds_tensor) - - # allocate memory for allgather params - allgather_params = [] - for psize in partition_sizes: - tensor_size = psize * self.world_size - flat_tensor = torch.empty(tensor_size, - dtype=param_list[0].dtype, - device=self.local_device).view(-1) - flat_tensor.requres_grad = False - allgather_params.append(flat_tensor) - - # launch - launch_handles = [] - # backend = get_backend(self.ds_process_group) - # with _batch_p2p_manager(backend): - for param_idx, param in enumerate(param_list): - output_list = [] - for i in range(self.world_size): - psize = partition_sizes[param_idx] - partition = allgather_params[param_idx].narrow(0, i * psize, psize) - output_list.append(partition) - - input_tensor = local_tensors[param_idx].view(-1) - h = torch.distributed.all_gather(output_list, - input_tensor, - group=self.ds_process_group, - async_op=True) - launch_handles.append(h) - - # Wait ensures the operation is enqueued, but not necessarily complete. 
- launch_handles[-1].wait() - - # assign to param.data (not copy) - for i, param in enumerate(param_list): - gathered_tensor = allgather_params[i] - param.data = gathered_tensor.narrow(0, - 0, - param.ds_numel).view(param.ds_shape).data - - # guarantee the communication to be completed - torch.cuda.synchronize() - - return None - def _allgather_params(self, param_list, hierarchy=0): if len(param_list) == 0: return diff --git a/tests/benchmarks/allgather_bench.py b/tests/benchmarks/allgather_bench.py index 744da9961bad..52927d97aadc 100644 --- a/tests/benchmarks/allgather_bench.py +++ b/tests/benchmarks/allgather_bench.py @@ -11,12 +11,10 @@ from torch.autograd.grad_mode import F import torch.distributed as dist from torch.distributed.distributed_c10d import _get_default_group, _pg_names -from deepspeed.ops.op_builder import CommunicationBuilder +from deepspeed.ops.communication import inplace_allgather import time import numpy as np -ds_coll_comm = CommunicationBuilder().load() - def sizeof_dtype(dtype): if dtype == torch.half: @@ -108,13 +106,7 @@ def bench_torch_allgather(output_tensors, def _custom_allgather_once(output_tensors, input_tensors, comm_stream): default_pg = _get_default_group() - pg_name = _pg_names[default_pg] - - with torch.cuda.stream(comm_stream): - res = ds_coll_comm._inplace_allgather(output_tensors, - input_tensors, - default_pg, - pg_name) + res = inplace_allgather(output_tensors, input_tensors, default_pg, comm_stream) comm_stream.synchronize() return res @@ -198,7 +190,7 @@ def main(): print_rank0(f'output tensor sum {out_sum_custom}') print_rank0( - f'allgather results are close {combined_tensor_custom.allclose(combined_tensor_torch)}' + f'allgather results of torch API and customized op are close {combined_tensor_custom.allclose(combined_tensor_torch)}' ) From 52085080c2007ce0e8e27f17f80c334ced2af388 Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Fri, 23 Jul 2021 19:24:29 +0000 Subject: [PATCH 06/12] removed customized c++ extension revert back to use torch distributed API --- csrc/communication/collective_comm.cpp | 216 ------------------ deepspeed/ops/communication/__init__.py | 1 - deepspeed/ops/communication/communication.py | 60 ----- .../runtime/zero/partition_parameters.py | 95 +++++--- op_builder/__init__.py | 4 +- op_builder/communication.py | 28 --- tests/benchmarks/allgather_bench.py | 198 ---------------- 7 files changed, 65 insertions(+), 537 deletions(-) delete mode 100644 csrc/communication/collective_comm.cpp delete mode 100644 deepspeed/ops/communication/__init__.py delete mode 100644 deepspeed/ops/communication/communication.py delete mode 100644 op_builder/communication.py delete mode 100644 tests/benchmarks/allgather_bench.py diff --git a/csrc/communication/collective_comm.cpp b/csrc/communication/collective_comm.cpp deleted file mode 100644 index 5423533fb89c..000000000000 --- a/csrc/communication/collective_comm.cpp +++ /dev/null @@ -1,216 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -int debug_flag = std::getenv("DS_DEBUG") ? 
std::stoi(std::getenv("DS_DEBUG")) : 0; - -// recording created ncclComm_t -// using processGroup Name as key -std::unordered_map group_communicators; - -// NCCL type typing -// copied from pytorch source code -std::map ncclDataType = { - {at::kChar, ncclInt8}, - {at::kByte, ncclUint8}, - {at::kFloat, ncclFloat}, - {at::kDouble, ncclDouble}, - {at::kInt, ncclInt32}, - {at::kLong, ncclInt64}, - {at::kHalf, ncclHalf}, - {at::kBool, ncclUint8}, -#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301 - {at::kBFloat16, ncclBfloat16}, -#endif -}; - -// Helper function that gets the data type and issues error if not supported -// from pytorch source code -ncclDataType_t getNcclDataType(at::ScalarType type) -{ - auto it = ncclDataType.find(type); - TORCH_CHECK(it != ncclDataType.end(), - "Input tensor data type is not supported for NCCL process group: ", - type); - return it->second; -} - -void check_tensors(std::vector& output_tensors, - std::vector& input_tensors, - int world_size) -{ - if (input_tensors.size() == 0 || output_tensors.size() == 0) { - TORCH_CHECK(false, "output/input tensor list must be nonempty"); - } - if (output_tensors.size() != input_tensors.size()) { - TORCH_CHECK(false, "output and input tensors must have same size"); - } - - for (size_t i = 0; i < input_tensors.size(); ++i) { - auto out = output_tensors[i]; - auto in = input_tensors[i]; - if (out.numel() != in.numel() * world_size) { - std::stringstream ss; - ss << "output tensor numel != input tensor numel * world_size at" << i; - TORCH_CHECK(false, ss.str()); - } - } -} - -// rank0 create the ncclUniqueId -// broadcast using old ProcessGroupNCCL -// ncclCommInitRank with ncclUniqueId and same rank and world size from current -// ProcessGroupNCCL -// -// Note: reason for creating new ncclComm_t, ::c10d::ProcessGroupNCCL didn't expose -// APIs for getting communicator -ncclComm_t create_communicator(std::vector& input_tensors, - std::string& pg_name, - ::c10d::ProcessGroupNCCL& pg) -{ - int rank = pg.getRank(); - int world_size = pg.getSize(); - at::Tensor& first_tensor = input_tensors[0]; - auto device_idx = first_tensor.get_device(); - if (debug_flag) printf("creating new communicator at device %ld\n", device_idx); - - // - ncclUniqueId nccl_id; - ncclComm_t nccl_comm; - - auto id_tensor_option = torch::TensorOptions() - .dtype(torch::kUInt8) - .layout(torch::kStrided) // dense tensor - .requires_grad(false); - - std::vector bcast_tensor; - if (rank == 0) { - auto _result = ncclGetUniqueId(&nccl_id); - if (_result != ncclSuccess) { - TORCH_CHECK(false, "Getting nccl unique id failed"); - // it suppose to exit - } - id_tensor_option.device(torch::kCPU); - at::Tensor cpu_tensor = torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_(); - memcpy(cpu_tensor.data_ptr(), &nccl_id, sizeof(ncclUniqueId)); - - at::Tensor id_tensor = cpu_tensor.to(first_tensor.device()); - bcast_tensor.push_back(std::move(id_tensor)); - } else { - at::Tensor id_tensor = - torch::empty(sizeof(ncclUniqueId), id_tensor_option).zero_().to(first_tensor.device()); - bcast_tensor.push_back(std::move(id_tensor)); - } - if (debug_flag) - printf("rank %d, created tensor holder, device %ld, is_cuda %d \n", - rank, - device_idx, - bcast_tensor[0].is_cuda()); - - // bcast - { - at::cuda::CUDAGuard gpuGuard(device_idx); - // make sure the allocated tensors are ready - AT_CUDA_CHECK(cudaDeviceSynchronize()); - auto work = pg.broadcast(bcast_tensor); - // make sure the broadcast finished - AT_CUDA_CHECK(cudaDeviceSynchronize()); - } - - // if rank != 0 - 
// then need to copy ncclUniqueId from bcast_tensor - if (rank != 0) { - auto cpu_tensor = bcast_tensor[0].to(at::kCPU); - std::memcpy(&nccl_id, cpu_tensor.data_ptr(), cpu_tensor.nbytes()); - } - - { - at::cuda::CUDAGuard gpuGuard(device_idx); - // init communicator and save - ncclCommInitRank(&nccl_comm, world_size, nccl_id, rank); - group_communicators[pg_name] = nccl_comm; - - if (debug_flag) printf("nccl_comm initialized at rank %d, device %ld\n", rank, device_idx); - } - - return nccl_comm; -} - -// get communicator from global map -// if not found, create a new one -ncclComm_t get_communicator(std::vector& input_tensors, - std::string& pg_name, - ::c10d::ProcessGroupNCCL& pg) -{ - auto found = group_communicators.find(pg_name); - if (found == group_communicators.end()) { - return create_communicator(input_tensors, pg_name, pg); - } else { - return found->second; - } -} - -int launch_nccl_allgather(std::vector& output_tensors, - std::vector& input_tensors, - ncclComm_t comm) -{ - auto& first_input = input_tensors[0]; - auto device_idx = first_input.get_device(); - if (debug_flag) - printf("launching allgather op with number of tensors %lu, at device %ld \n", - input_tensors.size(), - device_idx); - - // this suppose to get the cuda stream specified by `with torch.cuda.stream(comm_stream): ...` - at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_idx); - - ncclGroupStart(); - for (size_t i = 0; i < input_tensors.size(); ++i) { - at::Tensor& input = input_tensors[i]; - at::Tensor& output = output_tensors[i]; - ncclAllGather(input.data_ptr(), - output.data_ptr(), - input.numel(), - getNcclDataType(input.scalar_type()), - comm, - stream.stream()); - } - ncclGroupEnd(); - - return 0; -} - -int inplaceAllgather(std::vector& output_tensors, - std::vector& input_tensors, - ::c10d::ProcessGroupNCCL& pg, - std::string pg_name) -{ - // ::c10d::ProcessGroup& p_pg = pg; - if (debug_flag) - printf("inplaceAllgather:: process group rank %d, size %d, pg_name %s \n", - pg.getRank(), - pg.getSize(), - pg_name.c_str()); - - check_tensors(output_tensors, input_tensors, pg.getSize()); - - auto nccl_comm = get_communicator(input_tensors, pg_name, pg); - - int res = launch_nccl_allgather(output_tensors, input_tensors, nccl_comm); - - return res; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("_inplace_allgather", &inplaceAllgather, "inplace all-gather (without memcpy)"); -} diff --git a/deepspeed/ops/communication/__init__.py b/deepspeed/ops/communication/__init__.py deleted file mode 100644 index 1e03579753f9..000000000000 --- a/deepspeed/ops/communication/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .communication import inplace_allgather diff --git a/deepspeed/ops/communication/communication.py b/deepspeed/ops/communication/communication.py deleted file mode 100644 index dfaa785a2b68..000000000000 --- a/deepspeed/ops/communication/communication.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch - -from torch.distributed.distributed_c10d import _pg_names, _get_default_group -import torch.distributed as dist - -from ..op_builder import CommunicationBuilder - -coll_comm_module = None - - -class CommunicationHandle: - def __init__(self, - start_event: torch.cuda.Event, - end_event: torch.cuda.Event, - timing: bool) -> None: - - self.start = start_event - self.end = end_event - self.timing = timing - - def is_completed(self, ): - return self.end.query() - - def wait(self, stream): - self.end.wait(stream) - - def synchronize(self): - self.end.synchronize() - - -def 
map_process_group(group): - # print(f'rank {dist.get_rank(group=group)}, _pg_names {_pg_names}, _pg_group_ranks {_pg_group_ranks}') - if group == dist.group.WORLD: - return _get_default_group() - else: - return group - - -def inplace_allgather(output_tensors, input_tensors, group, comm_stream, timing=False): - """""" - global coll_comm_module - if coll_comm_module is None: - coll_comm_module = CommunicationBuilder().load() - - group = map_process_group(group) - process_group_name = _pg_names[group] - - start_event = torch.cuda.Event(enable_timing=timing) - end_event = torch.cuda.Event(enable_timing=timing) - if timing: - start_event.record(comm_stream) - with torch.cuda.stream(comm_stream): - coll_comm_module._inplace_allgather(output_tensors, - input_tensors, - group, - process_group_name) - - end_event.record(comm_stream) - - return CommunicationHandle(start_event, end_event, timing) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index ce220a40bdab..904b19895e7f 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -10,10 +10,13 @@ import functools import itertools -from deepspeed.ops.communication import communication as ds_comm - import torch -from torch.distributed.distributed_c10d import _get_global_rank +from torch.distributed.distributed_c10d import _get_global_rank, group + +try: + from torch.distributed.distributed_c10d import _all_gather_base as all_gather +except: + from torch.distributed.distributed_c10d import all_gather from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3 from .offload_constants import * @@ -669,7 +672,7 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): if not async_op: # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) - ret_value = self._allgather_params_with_custom_op(all_gather_list, hierarchy) + ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy) for param in all_gather_list: param.ds_status = ZeroParamStatus.AVAILABLE @@ -857,26 +860,35 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): # param.ds_numel).view(param.ds_shape) # param.data = replicated_tensor.data # return None - partitions = [] - for i in range(self.world_size): - partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) + try: + # try the _all_gather_base on PyTorch master branch + handle = all_gather(flat_tensor, + param.ds_tensor, + group=self.ds_process_group, + async_op=async_op) + except: + partitions = [] + for i in range(self.world_size): + partitions.append( + flat_tensor.narrow(0, + partition_size * i, + partition_size)) - if i == torch.distributed.get_rank(group=self.ds_process_group): - partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) + if i == torch.distributed.get_rank(group=self.ds_process_group): + partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) - handle = torch.distributed.all_gather(partitions, - partitions[self.rank], - group=self.ds_process_group, - async_op=async_op) + handle = torch.distributed.all_gather(partitions, + partitions[self.rank], + group=self.ds_process_group, + async_op=async_op) replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) param.data = replicated_tensor.data return handle - def _allgather_params_with_custom_op(self, param_list, hierarchy=0): - """ using customized allgather op to avoid redundant cudaMemcpy - Note: the 
torch.distributed.allgather has extra copy: - https://github.com/pytorch/pytorch/blob/v1.9.0/torch/lib/c10d/ProcessGroupNCCL.cpp#L1469 + def _allgather_params_coalesced(self, param_list, hierarchy=0): + """ blocking call + avoid explicit memory copy in _allgather_params """ if len(param_list) == 0: return @@ -888,33 +900,54 @@ def _allgather_params_with_custom_op(self, param_list, hierarchy=0): local_tensors.append(param.ds_tensor) # allocate memory for allgather params - allgather_output_params = [] + allgather_params = [] for psize in partition_sizes: tensor_size = psize * self.world_size flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device).view(-1) flat_tensor.requres_grad = False - allgather_output_params.append(flat_tensor) - - # suppose to set the communication stream outside of this function - comm_stream = torch.cuda.current_stream() - - # the handle is a wrapper of the start and the end events - comm_handle = ds_comm.inplace_allgather(allgather_output_params, - local_tensors, - self.ds_process_group, - comm_stream) + allgather_params.append(flat_tensor) + + # launch + launch_handles = [] + # backend = get_backend(self.ds_process_group) + # with _batch_p2p_manager(backend): + for param_idx, param in enumerate(param_list): + input_tensor = local_tensors[param_idx].view(-1) + + try: + # try the _all_gather_base from Pytorch master + h = all_gather(allgather_params[param_idx], + input_tensor, + group=self.ds_process_group, + async_op=True) + except: + output_list = [] + for i in range(self.world_size): + psize = partition_sizes[param_idx] + partition = allgather_params[param_idx].narrow(0, i * psize, psize) + output_list.append(partition) + + # back to old all_gather function signature + h = all_gather(output_list, + input_tensor, + group=self.ds_process_group, + async_op=True) + launch_handles.append(h) + + # Wait ensures the operation is enqueued, but not necessarily complete. 
+ launch_handles[-1].wait() # assign to param.data (not copy) for i, param in enumerate(param_list): - gathered_tensor = allgather_output_params[i] + gathered_tensor = allgather_params[i] param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data - # this synchronize on cuda.Event - comm_handle.synchronize() + # guarantee the communication to be completed + torch.cuda.synchronize() return None diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 2a854ed663e2..f19ed916c332 100755 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -12,7 +12,6 @@ from .builder import get_default_compute_capatabilities from .transformer_inference import InferenceBuilder from .quantizer import QuantizerBuilder -from .communication import CommunicationBuilder # TODO: infer this list instead of hard coded # List of all available ops @@ -26,7 +25,6 @@ AsyncIOBuilder(), InferenceBuilder(), UtilsBuilder(), - QuantizerBuilder(), - CommunicationBuilder() + QuantizerBuilder() ] ALL_OPS = {op.name: op for op in __op_builders__} diff --git a/op_builder/communication.py b/op_builder/communication.py deleted file mode 100644 index 81fdb28d6dc4..000000000000 --- a/op_builder/communication.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -Customized wrap for collective communications -""" - -from .builder import OpBuilder -import os - - -class CommunicationBuilder(OpBuilder): - BUILD_VAR = "DS_BUILD_COLL_COMM" - NAME = "communication" - - def __init__(self): - super().__init__(name=self.NAME) - - def absolute_name(self): - return f'deepspeed.ops.{self.NAME}_op' - - def include_paths(self): - NCCL_HOME = os.getenv('NCCL_HOME') - if not NCCL_HOME: - NCCL_HOME = '/usr/local/cuda' - - inc_path = os.path.join(NCCL_HOME, 'include') - return [inc_path] - - def sources(self): - return ['csrc/communication/collective_comm.cpp'] diff --git a/tests/benchmarks/allgather_bench.py b/tests/benchmarks/allgather_bench.py deleted file mode 100644 index 52927d97aadc..000000000000 --- a/tests/benchmarks/allgather_bench.py +++ /dev/null @@ -1,198 +0,0 @@ -""" -Test command: -python3 -m torch.distributed.launch --nnodes=1 --nproc_per_node=2 -""" - -import argparse - -import torch -import math -from torch._C import device -from torch.autograd.grad_mode import F -import torch.distributed as dist -from torch.distributed.distributed_c10d import _get_default_group, _pg_names -from deepspeed.ops.communication import inplace_allgather -import time -import numpy as np - - -def sizeof_dtype(dtype): - if dtype == torch.half: - return 2 - elif dtype == torch.float: - return 4 - else: - return None - - -def prepare_tensor(partition_sizes, world_size, device, dtype=torch.half): - # transformer layer structure with partitioned onto 8 GPUs - - output_tensors = [] - input_tensors = [] - - for _size in partition_sizes: - std = 1 / math.sqrt(_size) - input_t = torch.empty(_size, - dtype=dtype, - device=device).view(-1).uniform_(-std, - std) - output_t = torch.empty(_size * world_size, - dtype=dtype, - device=device).view(-1).uniform_(-std, - std) - input_tensors.append(input_t) - output_tensors.append(output_t) - return output_tensors, input_tensors - - -def _torch_allgather_once(output_tensors, - input_tensors, - partition_sizes, - rank, - world_size): - """""" - s = torch.cuda.Stream() - handles = [] - for part_idx, part_size in enumerate(partition_sizes): - output_t = output_tensors[part_idx] - input_t = input_tensors[part_idx] - - output_list = [] - for i in range(world_size): - out_tensor = output_t.narrow(0, i * 
part_size, part_size) - output_list.append(out_tensor) - - h = dist.all_gather(output_list, input_t, async_op=True) - handles.append(h) - - torch.cuda.synchronize() - - -def print_bw_rank0(partition_sizes, time_costs, dtype): - if dist.get_rank() == 0: - elem_size = sizeof_dtype(dtype) # in bytes - assert elem_size != None - - numel = sum(partition_sizes) - - avg_t = np.mean(time_costs) - bw = numel * elem_size * (dist.get_world_size() - 1) / 1e9 / avg_t - print(f'avg time {avg_t * 1e3} ms, bw {bw} GB/s') - - -def bench_torch_allgather(output_tensors, - input_tensors, - partition_sizes, - rank, - world_size, - warm_up=5, - repeat=10): - ts = [] - for i in range(warm_up + repeat): - s = time.time() - _torch_allgather_once(output_tensors, - input_tensors, - partition_sizes, - rank, - world_size) - e = time.time() - - if i >= warm_up: - ts.append(e - s) - - print_bw_rank0(partition_sizes, ts, input_tensors[0].dtype) - - -def _custom_allgather_once(output_tensors, input_tensors, comm_stream): - default_pg = _get_default_group() - res = inplace_allgather(output_tensors, input_tensors, default_pg, comm_stream) - comm_stream.synchronize() - return res - - -def bench_custom_allgather(output_tensors, - input_tensors, - partition_sizes, - rank, - world_size, - warm_up=5, - repeat=10): - """""" - comm_stream = torch.cuda.Stream(rank % torch.cuda.device_count()) - - ts = [] - for i in range(warm_up + repeat): - s = time.time() - _custom_allgather_once(output_tensors, input_tensors, comm_stream) - e = time.time() - - if i >= warm_up: - ts.append(e - s) - - print_bw_rank0(partition_sizes, ts, input_tensors[0].dtype) - - -def print_rank0(msg): - if (dist.get_rank() == 0): - print(msg) - - -def main(): - """""" - dist.init_process_group(backend='nccl') - - rank = dist.get_rank() - local_size = torch.cuda.device_count() - device_id = rank % local_size - world_size = dist.get_world_size() - torch.cuda.set_device(device_id) - - partition_sizes = [ - 2457600, - 960, - 819200, - 320, - 320, - 320, - 3276800, - 1280, - 3276800, - 320, - 320, - 320 - ] - output_tensors, input_tensors = prepare_tensor(partition_sizes, - dist.get_world_size(), f'cuda:{device_id}', - torch.half) - print_rank0('Using torch.distributed.allgather') - bench_torch_allgather(output_tensors, - input_tensors, - partition_sizes, - rank, - world_size) - - combined_tensor_torch = torch.cat(output_tensors) - out_sum_torch = combined_tensor_torch.sum() - print_rank0(f'output tensors sum {out_sum_torch}') - - # clean output tensors - for t in output_tensors: - t.zero_() - print_rank0('Using customized allgather') - bench_custom_allgather(output_tensors, - input_tensors, - partition_sizes, - rank, - world_size) - combined_tensor_custom = torch.cat(output_tensors) - out_sum_custom = combined_tensor_custom.sum() - print_rank0(f'output tensor sum {out_sum_custom}') - - print_rank0( - f'allgather results of torch API and customized op are close {combined_tensor_custom.allclose(combined_tensor_torch)}' - ) - - -if __name__ == '__main__': - main() From 1ed96ce2f70d289a38b6dea24e9743b8506e5efd Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Fri, 23 Jul 2021 20:06:10 +0000 Subject: [PATCH 07/12] change torch.ones to torch empty --- deepspeed/runtime/zero/partition_parameters.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 904b19895e7f..bba10bfe26b7 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ 
b/deepspeed/runtime/zero/partition_parameters.py @@ -720,7 +720,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): f'Before partitioning param {param.ds_id} {param.shape}', force=False) #param.data does not store anything meaningful in partitioned state - param.data = torch.ones(1, dtype=self.dtype).to(param.device) + param.data = torch.empty(1, dtype=self.dtype, device=param.device) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) @@ -741,7 +741,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): numel=partition_size): final_location = OFFLOAD_NVME_DEVICE buffer = self.param_swapper.get_buffer(param, partition_size) - partitioned_tensor = torch.zeros(1, + partitioned_tensor = torch.empty(1, dtype=param.dtype, device=buffer.device) partitioned_tensor.data = buffer.data @@ -750,7 +750,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): ) else: - partitioned_tensor = torch.zeros( + partitioned_tensor = torch.empty( partition_size, dtype=param.dtype, device=OFFLOAD_CPU_DEVICE From 0e6d8e0c58a988d973555c1f50daf9691478bccc Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Mon, 9 Aug 2021 23:28:48 -0400 Subject: [PATCH 08/12] typo --- deepspeed/runtime/zero/partition_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index bba10bfe26b7..f9a8a51a14cd 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -906,7 +906,7 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): flat_tensor = torch.empty(tensor_size, dtype=param_list[0].dtype, device=self.local_device).view(-1) - flat_tensor.requres_grad = False + flat_tensor.requires_grad = False allgather_params.append(flat_tensor) # launch From 50a9215df64438bce62b6d89e3c7d0071100188b Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Fri, 15 Oct 2021 03:20:23 +0000 Subject: [PATCH 09/12] warn if not cuda tensor for allgather --- deepspeed/runtime/zero/partition_parameters.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index e172f51e606a..baf4c516e33d 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -942,6 +942,8 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): psize = partition_sizes[param_idx] partition = allgather_params[param_idx].narrow(0, i * psize, psize) output_list.append(partition) + if not partition.is_cuda: + logger.warning(f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}') # back to old all_gather function signature h = all_gather(output_list, From 813cb22701a5cd0fca83d7770c30f3cbe1b0515f Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Thu, 21 Oct 2021 19:08:45 +0000 Subject: [PATCH 10/12] fix formatting --- deepspeed/runtime/zero/partition_parameters.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index baf4c516e33d..2b79fd628ef3 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -943,7 +943,9 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): partition = allgather_params[param_idx].narrow(0, i * psize, psize) 
output_list.append(partition) if not partition.is_cuda: - logger.warning(f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}') + logger.warning( + f'param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}' + ) # back to old all_gather function signature h = all_gather(output_list, From c092b789ea2ba847bc5c905cf0c8a52819ebc254 Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Fri, 22 Oct 2021 19:48:16 +0000 Subject: [PATCH 11/12] fix: move ds_tensor to cuda device but it is strange that the ds_tensor haven't been moved to cuda --- deepspeed/runtime/zero/partition_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index d00945f9f94d..705c84ed0867 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -919,7 +919,7 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): local_tensors = [] for param in param_list: partition_sizes.append(param.ds_tensor.ds_numel) - local_tensors.append(param.ds_tensor) + local_tensors.append(param.ds_tensor.cuda()) # allocate memory for allgather params allgather_params = [] From 7a8017213bdbb7eab8dda32b5258b6c9a73a0049 Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Sat, 30 Oct 2021 00:52:25 +0000 Subject: [PATCH 12/12] remove try clause on the path for fetching params --- .../runtime/zero/partition_parameters.py | 56 ++++++++++--------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 705c84ed0867..4e1bd22b5d8c 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -12,11 +12,7 @@ import torch from torch.distributed.distributed_c10d import _get_global_rank, group - -try: - from torch.distributed.distributed_c10d import _all_gather_base as all_gather -except: - from torch.distributed.distributed_c10d import all_gather +import torch.distributed as dist from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3 from .offload_constants import * @@ -501,6 +497,14 @@ def get_model(): assert isinstance(module, torch.nn.Module) self._convert_to_zero_parameters(module.parameters(recurse=True)) + self.use_all_gather_base = False + try: + from torch.distributed.distributed_c10d import _all_gather_base as all_gather + self.use_all_gather_base = True + except: + logger.info( + f"_all_gather_base API is not available in torch {torch.__version__}") + def _convert_to_zero_parameters(self, param_list): for param in param_list: if is_zero_param(param): @@ -882,13 +886,13 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): # param.ds_numel).view(param.ds_shape) # param.data = replicated_tensor.data # return None - try: + if self.use_all_gather_base: # try the _all_gather_base on PyTorch master branch - handle = all_gather(flat_tensor, - param.ds_tensor, - group=self.ds_process_group, - async_op=async_op) - except: + handle = dist._all_gather_base(flat_tensor, + param.ds_tensor, + group=self.ds_process_group, + async_op=async_op) + else: partitions = [] for i in range(self.world_size): partitions.append( @@ -896,13 +900,13 @@ def _allgather_param(self, param, async_op=False, hierarchy=0): partition_size * i, partition_size)) - if i == torch.distributed.get_rank(group=self.ds_process_group): + if i == 
dist.get_rank(group=self.ds_process_group): partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) - handle = torch.distributed.all_gather(partitions, - partitions[self.rank], - group=self.ds_process_group, - async_op=async_op) + handle = dist.all_gather(partitions, + partitions[self.rank], + group=self.ds_process_group, + async_op=async_op) replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) param.data = replicated_tensor.data @@ -938,13 +942,13 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): for param_idx, param in enumerate(param_list): input_tensor = local_tensors[param_idx].view(-1) - try: + if self.use_all_gather_base: # try the _all_gather_base from Pytorch master - h = all_gather(allgather_params[param_idx], - input_tensor, - group=self.ds_process_group, - async_op=True) - except: + h = dist._all_gather_base(allgather_params[param_idx], + input_tensor, + group=self.ds_process_group, + async_op=True) + else: output_list = [] for i in range(self.world_size): psize = partition_sizes[param_idx] @@ -956,10 +960,10 @@ def _allgather_params_coalesced(self, param_list, hierarchy=0): ) # back to old all_gather function signature - h = all_gather(output_list, - input_tensor, - group=self.ds_process_group, - async_op=True) + h = dist.all_gather(output_list, + input_tensor, + group=self.ds_process_group, + async_op=True) launch_handles.append(h) # Wait ensures the operation is enqueued, but not necessarily complete.
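

Taken together, the series lands on a coalesced all-gather built entirely on stock torch.distributed instead of the custom NCCL op: one flat output buffer per parameter, async launches, a wait on the last handle, and a final torch.cuda.synchronize(). The sketch below distills that pattern into a standalone helper. It is a minimal sketch, assuming an initialized NCCL process group and per-rank shards already on GPU; the name gather_shards is illustrative and not part of the DeepSpeed API.

    import torch
    import torch.distributed as dist

    # Prefer the fused _all_gather_base when the installed torch provides it:
    # it gathers straight into one flat output buffer and avoids the extra
    # copy that the list-based all_gather performs internally.
    try:
        from torch.distributed.distributed_c10d import _all_gather_base
        HAS_ALL_GATHER_BASE = True
    except ImportError:
        HAS_ALL_GATHER_BASE = False


    def gather_shards(shards, group=None):
        """All-gather a list of 1-D CUDA shards; returns one flat tensor per shard."""
        world_size = dist.get_world_size(group=group)
        outputs, handles = [], []
        for shard in shards:
            flat = torch.empty(shard.numel() * world_size,
                               dtype=shard.dtype,
                               device=shard.device)
            flat.requires_grad = False
            if HAS_ALL_GATHER_BASE:
                handle = dist._all_gather_base(flat, shard, group=group, async_op=True)
            else:
                # Fall back to the list signature: the outputs are narrow views
                # into the same flat buffer, so no copy is needed afterwards.
                chunks = [flat.narrow(0, r * shard.numel(), shard.numel())
                          for r in range(world_size)]
                handle = dist.all_gather(chunks, shard, group=group, async_op=True)
            outputs.append(flat)
            handles.append(handle)

        # wait() only ensures the last op is enqueued on the communication stream;
        # synchronize() guarantees all of the gathers have actually completed.
        handles[-1].wait()
        torch.cuda.synchronize()
        return outputs

The caller can then reshape each returned flat tensor into the original parameter shape with narrow(0, 0, numel).view(shape), exactly as the patched _allgather_params_coalesced assigns into param.data.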
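The deleted tests/benchmarks/allgather_bench.py timed both paths with host timers around torch.cuda.synchronize(), while the removed CommunicationHandle wrapped CUDA events for the same purpose. A rough event-based variant of that measurement, reusing the bus-bandwidth formula from print_bw_rank0, might look like the following; it assumes torch.distributed is initialized with the NCCL backend and that the running torch exposes _all_gather_base, and time_allgather is an illustrative name only.

    import torch
    import torch.distributed as dist


    def time_allgather(flat_output, shard, group=None, warmup=5, iters=10):
        """Return (avg seconds, bus bandwidth in GB/s) for repeated all-gathers."""
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup + iters):
            start.record()
            dist._all_gather_base(flat_output, shard, group=group)
            end.record()
            end.synchronize()                  # block until the collective finishes
            if i >= warmup:
                times.append(start.elapsed_time(end) / 1e3)   # ms -> s

        avg = sum(times) / len(times)
        world_size = dist.get_world_size(group=group)
        nbytes = shard.numel() * shard.element_size()
        # Same formula as print_bw_rank0: bytes each rank must receive per all-gather.
        bw = nbytes * (world_size - 1) / 1e9 / avg
        return avg, bw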
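Patch 07's switch from torch.ones/torch.zeros to torch.empty targets buffers whose contents are never read (the one-element param.data placeholder in the partitioned state) or are fully overwritten right after allocation (the partition buffer), so the initialization kernel is pure overhead. A minimal illustration, with an arbitrary partition_size:

    import torch

    partition_size = 1024  # illustrative size

    # param.data placeholder in the partitioned state: its value is never read,
    # so skip the fill kernel that torch.ones(1, ...) would launch.
    placeholder = torch.empty(1, dtype=torch.half, device='cuda')

    # CPU-side partition buffer that is copied into immediately after allocation,
    # so it does not need to be zeroed first either.
    partitioned_tensor = torch.empty(partition_size, dtype=torch.half, device='cpu')

The saving per tensor is small, but it is paid once for every partitioned parameter, which is why the series applies it across all of the placeholder allocations.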