From b0ba18e70c7375f2439d84012cac686c2c880320 Mon Sep 17 00:00:00 2001
From: Hongwei
Date: Mon, 16 Dec 2024 15:41:37 -0800
Subject: [PATCH 1/2] initialize communication backend

---
 deepspeed/runtime/domino/transformer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py
index 8eb95e49c29d..621d445eae78 100644
--- a/deepspeed/runtime/domino/transformer.py
+++ b/deepspeed/runtime/domino/transformer.py
@@ -6,13 +6,13 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
-import deepspeed
-from deepspeed import comm as dist
+import deepspeed.comm
+from deepspeed.comm.comm import init_distributed
 from deepspeed.accelerator import get_accelerator
 
 
 def is_rank_0():
-    if dist.get_rank() == 0:
+    if deepspeed.comm.get_rank() == 0:
         return True
 
 
@@ -249,6 +249,8 @@ def __init__(self,
                  output_bias=None):
         super(DominoTransformerLayer, self).__init__()
 
+        init_distributed()
+
         self.llama_model = config.llama_model
         self.layer_number = layer_number
         self.layer_type = layer_type

From 2833db753240b7e826b9d6b313ca43d599f4d2e1 Mon Sep 17 00:00:00 2001
From: Hongwei
Date: Mon, 16 Dec 2024 17:46:33 -0800
Subject: [PATCH 2/2] add initialization check

---
 deepspeed/runtime/domino/transformer.py | 23 ++++++++++-------------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/deepspeed/runtime/domino/transformer.py b/deepspeed/runtime/domino/transformer.py
index 621d445eae78..88c5494c8147 100644
--- a/deepspeed/runtime/domino/transformer.py
+++ b/deepspeed/runtime/domino/transformer.py
@@ -6,13 +6,12 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
-import deepspeed.comm
-from deepspeed.comm.comm import init_distributed
+import deepspeed.comm as dist
 from deepspeed.accelerator import get_accelerator
 
 
 def is_rank_0():
-    if deepspeed.comm.get_rank() == 0:
+    if dist.get_rank() == 0:
         return True
 
 
@@ -97,7 +96,7 @@ def backward(ctx, grad_output):
             return grad_output
 
         # Async All-reduce.
-        handle = deepspeed.comm.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle = dist.all_reduce(grad_output, group=ctx.mpu.get_tensor_model_parallel_group(), async_op=True)
         ctx.handle_dic[ctx.h_id] = handle
         return None, grad_output, None, None
 
@@ -249,7 +248,9 @@ def __init__(self,
                  output_bias=None):
         super(DominoTransformerLayer, self).__init__()
 
-        init_distributed()
+        if not dist.is_initialized():
+            dist.init_distributed()
+        assert dist.is_initialized(), "deepspeed.comm is not initialized!"
 
         self.llama_model = config.llama_model
         self.layer_number = layer_number
@@ -360,18 +361,14 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
                 layernorm_output0,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle0 = deepspeed.comm.all_reduce(attention_output0,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle0 = dist.all_reduce(attention_output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
 
         attention_output1, attention_bias1 = \
             self.self_attention(
                 layernorm_output1,
                 attention_mask,
                 rotary_pos_emb=rotary_pos_emb)
-        handle1 = deepspeed.comm.all_reduce(attention_output1,
-                                            group=self.mpu.get_tensor_model_parallel_group(),
-                                            async_op=True)
+        handle1 = dist.all_reduce(attention_output1, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
         handle0.wait()
 
         # Residual0 connection.
@@ -415,7 +412,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
             output0 = output0 + bias_c
         output0 = self.mlp_activation_func(output0)
         output0 = torch.matmul(output0, self.weight_r.t())
-        handle2 = deepspeed.comm.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
+        handle2 = dist.all_reduce(output0, group=self.mpu.get_tensor_model_parallel_group(), async_op=True)
 
         handle1.wait()
 
@@ -427,7 +424,7 @@ def forward(self, hidden_states, attention_mask, rotary_pos_emb=None):
         if bias_c is not None:
             output1 = output1 + bias_c
         output1 = torch.matmul(output1, self.weight_r.t())
-        deepspeed.comm.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
+        dist.all_reduce(output1, group=self.mpu.get_tensor_model_parallel_group())
 
         handle2.wait()
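Note (not part of the patches above): PATCH 2/2 replaces the unconditional init_distributed() call with a lazy guard on deepspeed.comm plus a fail-fast assertion, and routes all collectives through the dist alias. A minimal standalone sketch of that pattern follows; it assumes the script is started under a distributed launcher (e.g. deepspeed or torchrun) so init_distributed() can discover rank and world size, and ensure_comm_initialized is a hypothetical helper name used only for illustration, not an API introduced by the patch.

import torch
import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator


def ensure_comm_initialized():
    # Same guard as the DominoTransformerLayer constructor in PATCH 2/2:
    # initialize the communication backend only if nothing has done so yet.
    if not dist.is_initialized():
        dist.init_distributed()
    assert dist.is_initialized(), "deepspeed.comm is not initialized!"


if __name__ == "__main__":
    # Assumes launch via a distributed launcher so ranks are discoverable.
    ensure_comm_initialized()
    # Async all-reduce pattern used throughout the patched forward pass:
    # launch the collective with async_op=True, overlap other work, then wait.
    t = torch.ones(4, device=get_accelerator().current_device_name())
    handle = dist.all_reduce(t, async_op=True)
    handle.wait()
    if dist.get_rank() == 0:
        print(t)  # each element equals the world size after the reduction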