diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index d95c5a3df..2d648c959 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -101,16 +101,30 @@ def _reduce_scatter_along_seq_dim(input_, seq_dim): if get_fp32_allreduce(): input_ = input_.float() - assert input_.shape[seq_dim] % world_size == 0 - tensor_list = list( - torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim) - ) - output = torch.empty_like(tensor_list[0]) - torch.distributed.reduce_scatter(output, tensor_list) + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + assert dim_size[seq_dim] % world_size == 0 + + if seq_dim == 0: + dim_size[seq_dim] = dim_size[seq_dim] // world_size + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + tensor_list = list( + torch.split(input_, input_.shape[seq_dim] // world_size, seq_dim) + ) + output = torch.empty_like(tensor_list[0]) + torch.distributed.reduce_scatter(output, tensor_list) # reconvert to original Bf16/Fp16 dtype if get_fp32_allreduce(): - input_ = input_.to(dt) + output = output.to(dt) return output @@ -123,12 +137,28 @@ def _gather_along_seq_dim(input_, seq_dim): if world_size == 1: return input_ - input_ = input_.contiguous() - rank = get_model_parallel_rank() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather(tensor_list, input_, group=get_model_parallel_group()) - output = torch.cat(tensor_list, dim=seq_dim) + dim_size = list(input_.size()) + assert ( + isinstance(seq_dim, int) and seq_dim < len(dim_size) and seq_dim >= 0 + ), "seq_dim must be a valid tensor dim" + dim_size[seq_dim] = dim_size[seq_dim] * world_size + + if seq_dim == 0: + output = torch.empty( + dim_size, dtype=input_.dtype, device=torch.cuda.current_device() + ) + torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=get_model_parallel_group() + ) + else: + input_ = input_.contiguous() + rank = get_model_parallel_rank() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather( + tensor_list, input_, group=get_model_parallel_group() + ) + output = torch.cat(tensor_list, dim=seq_dim) return output