
Commit

Merge branch 'NVIDIA:main' into pr/ssl_generative_pretraining
Kuray107 committed Aug 6, 2024
2 parents 6f9a9c6 + 1bc5a87 commit ffacac2
Showing 1 changed file with 17 additions and 0 deletions.
@@ -40,6 +40,7 @@

 try:
     from megatron.core import parallel_state
+    from megatron.core.distributed import finalize_model_grads
     from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

     HAVE_MEGATRON_CORE = True
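
The added import sits inside the existing guarded import block, so environments without Megatron-Core keep falling back to HAVE_MEGATRON_CORE = False instead of failing at import time. A sketch of the full guard for context; the except branch is outside this hunk, so its exact wording below is assumed:

# Sketch of the guarded-import convention around this hunk (the except
# branch is not part of the diff; its exact form here is assumed).
try:
    from megatron.core import parallel_state
    from megatron.core.distributed import finalize_model_grads
    from megatron.core.pipeline_parallel.schedules import get_forward_backward_func

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False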
@@ -378,11 +379,27 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
             )
             grad_sync_func = self.reduce_overlap_gradients
             param_sync_func = self.sync_overlap_parameters
+        elif not forward_only and self.use_mcore_dist_optim:
+            if self.cfg.optim.get("overlap_grad_sync", False):
+                no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
+                no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func
+
+                if self.cfg.optim.get("delay_grad_reduce", True):
+                    grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model]
+                    grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func
+            if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False):
+                param_sync_func = [
+                    lambda x, model_index=model_index: self._optimizer.finish_param_sync(model_index, x)
+                    for model_index in range(len(self.model))
+                ]
+                param_sync_func = param_sync_func[0] if len(self.model) == 1 else param_sync_func

         for module in self.get_model_module_list():
             module.config.no_sync_func = no_sync_func
             module.config.grad_sync_func = grad_sync_func
             module.config.param_sync_func = param_sync_func
+            if self.use_mcore_dist_optim:
+                module.config.finalize_model_grads_func = finalize_model_grads

         fwd_bwd_function = get_forward_backward_func()

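The new elif branch wires the Megatron-Core distributed optimizer into the pipeline schedule: each model chunk exposes its own no_sync/start_grad_sync handle, a single-chunk model is unwrapped from the list form, and the param_sync_func lambdas bind the loop index via a default argument so each callback finishes the parameter all-gather for its own chunk. Below is a minimal, self-contained sketch of those two idioms; _FakeOptimizer and build_param_sync_func are illustrative names, not part of NeMo or Megatron-Core.

# Illustrative sketch (not NeMo/Megatron-Core code): the two idioms used in
# the added branch -- binding the loop index with a default argument, and
# unwrapping a one-element list of per-chunk callbacks.

class _FakeOptimizer:
    def finish_param_sync(self, model_index, x):
        print(f"finished param all-gather for model chunk {model_index}")


def build_param_sync_func(optimizer, num_chunks):
    # `model_index=model_index` freezes the loop value at definition time;
    # a plain closure would leave every lambda pointing at the final index.
    funcs = [
        lambda x, model_index=model_index: optimizer.finish_param_sync(model_index, x)
        for model_index in range(num_chunks)
    ]
    # Same unwrap pattern as the diff: a single callable for one model chunk,
    # a list of callables when virtual pipeline parallelism creates several.
    return funcs[0] if num_chunks == 1 else funcs


if __name__ == "__main__":
    sync = build_param_sync_func(_FakeOptimizer(), num_chunks=2)
    for fn in sync:
        fn(None)  # chunk 0, then chunk 1

Registering finalize_model_grads on module.config lets the schedule returned by get_forward_backward_func() trigger the final gradient reduction itself once the backward pass completes, which is the behavior the Megatron-Core distributed optimizer path relies on.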
