From 1bc5a87bc1baaaade7ec9e17eeac897d7b92122b Mon Sep 17 00:00:00 2001
From: Michal Futrega
Date: Tue, 6 Aug 2024 19:03:40 +0200
Subject: [PATCH] Add support for overlapped gradient and parameter
 synchronization for GPT SFT model (#10041)

* Add support for overlapped gradient and parameter synchronization for GPT SFT model

Signed-off-by: Michal Futrega

* Add finalize_model_grads

* Apply isort and black reformatting

Signed-off-by: michal2409

---------

Signed-off-by: Michal Futrega
Signed-off-by: michal2409
Co-authored-by: michal2409
---
 .../language_modeling/megatron_gpt_sft_model.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 75dda1c4e9c8..9c2372ef38ca 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -40,6 +40,7 @@
 
 try:
     from megatron.core import parallel_state
+    from megatron.core.distributed import finalize_model_grads
     from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
 
     HAVE_MEGATRON_CORE = True
@@ -378,11 +379,27 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
             )
             grad_sync_func = self.reduce_overlap_gradients
             param_sync_func = self.sync_overlap_parameters
+        elif not forward_only and self.use_mcore_dist_optim:
+            if self.cfg.optim.get("overlap_grad_sync", False):
+                no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
+                no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func
+
+                if self.cfg.optim.get("delay_grad_reduce", True):
+                    grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model]
+                    grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func
+            if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False):
+                param_sync_func = [
+                    lambda x, model_index=model_index: self._optimizer.finish_param_sync(model_index, x)
+                    for model_index in range(len(self.model))
+                ]
+                param_sync_func = param_sync_func[0] if len(self.model) == 1 else param_sync_func
 
         for module in self.get_model_module_list():
             module.config.no_sync_func = no_sync_func
             module.config.grad_sync_func = grad_sync_func
             module.config.param_sync_func = param_sync_func
+            if self.use_mcore_dist_optim:
+                module.config.finalize_model_grads_func = finalize_model_grads
 
         fwd_bwd_function = get_forward_backward_func()
 
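
For reference, a minimal sketch of the optim settings the new branch reacts to. The four overlap/delay keys are taken from the cfg.optim.get(...) calls in the diff; the optimizer name "mcore_distributed_optim" and the overall config layout are assumptions about how use_mcore_dist_optim is typically enabled, not something this patch defines.

# Illustrative sketch only (not part of the patch).
from omegaconf import OmegaConf

optim_overrides = OmegaConf.create(
    {
        "optim": {
            "name": "mcore_distributed_optim",  # assumed switch that makes use_mcore_dist_optim True
            "overlap_grad_sync": True,   # install per-model-chunk no_sync / start_grad_sync hooks
            "delay_grad_reduce": True,   # defer grad reduce so it can overlap with the backward pass
            "overlap_param_sync": True,  # overlap the param all-gather with forward compute
            "delay_param_gather": True,  # needed together with overlap_param_sync to set param_sync_func
        }
    }
)

# Print the resulting YAML; in practice these keys would be merged into model.optim.
print(OmegaConf.to_yaml(optim_overrides))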