From ab4e73c7f5a950435a225c465c56e4eb4b20cf6f Mon Sep 17 00:00:00 2001
From: Marc Romeijn
Date: Mon, 8 Jul 2024 04:59:12 -0700
Subject: [PATCH] Loss goes down on fp32

---
 nemo/lightning/megatron_parallel.py  | 18 ++++++++----------
 nemo/lightning/pytorch/strategies.py |  2 +-
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 480a49b6dd09..c3ee4ae96ebe 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -424,12 +424,7 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
     def init_model_parallel(self):
         from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
         from megatron.core import parallel_state
-
-        if self.convert_module_fn:
-            self.apply_convert_module_fn()
-
-        self.init_ddp()
-
+
         for model_module in self:
             if not self._cpu:
                 model_module.cuda(torch.cuda.current_device())
@@ -460,10 +455,13 @@ def init_model_parallel(self):
             logging.info(msg)
 
             if num_params != num_trainable_params:
-                logging.info(
-                    f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
-                )
-
+                logging.info(f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)")
+
+        if self.convert_module_fn:
+            self.apply_convert_module_fn()
+
+        self.init_ddp()
+
     def apply_convert_module_fn(self):
         for i in range(len(self)):
             self[i] = self.convert_module_fn(self[i])
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index 6b008831985e..d0e502839f2f 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -148,7 +148,7 @@ def __init__(
         self._ddp = ddp
 
         if ddp == "megatron":
-            self.ddp_config = DistributedDataParallelConfig()
+            self.ddp_config = DistributedDataParallelConfig(check_for_nan_in_grad=True)
         elif isinstance(ddp, DistributedDataParallelConfig):
             self.ddp_config = ddp
         elif ddp == "pytorch":