Loss goes down on fp32
marcromeyn committed Jul 8, 2024
1 parent ee43c65 commit ab4e73c
Showing 2 changed files with 9 additions and 11 deletions.
18 changes: 8 additions & 10 deletions nemo/lightning/megatron_parallel.py
@@ -424,12 +424,7 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
     def init_model_parallel(self):
         from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
         from megatron.core import parallel_state
 
-        if self.convert_module_fn:
-            self.apply_convert_module_fn()
-
-        self.init_ddp()
-
         for model_module in self:
             if not self._cpu:
                 model_module.cuda(torch.cuda.current_device())
@@ -460,10 +455,13 @@ def init_model_parallel(self):
             logging.info(msg)
 
         if num_params != num_trainable_params:
-            logging.info(
-                f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
-            )
+            logging.info(f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)")
+
+        if self.convert_module_fn:
+            self.apply_convert_module_fn()
+
+        self.init_ddp()
 
     def apply_convert_module_fn(self):
         for i in range(len(self)):
             self[i] = self.convert_module_fn(self[i])
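For readers tracing the two hunks above, here is a minimal sketch of the resulting control flow in init_model_parallel after this commit. Everything other than the lines visible in the diff is elided and should be treated as an assumption, not the exact file contents:

    def init_model_parallel(self):
        from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
        from megatron.core import parallel_state

        # 1. Move each module to the current CUDA device and run the
        #    per-module setup / parameter counting first (this loop is
        #    unchanged by the commit).
        for model_module in self:
            if not self._cpu:
                model_module.cuda(torch.cuda.current_device())
            ...  # attribute setup and parameter logging, elided here

        # 2. Apply convert_module_fn only after the modules are initialized
        #    on the device (previously this ran before the loop above).
        if self.convert_module_fn:
            self.apply_convert_module_fn()

        # 3. DDP setup still follows the conversion, but both now happen
        #    after device placement and parameter counting.
        self.init_ddp()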
2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/strategies.py
@@ -148,7 +148,7 @@ def __init__(
 
         self._ddp = ddp
         if ddp == "megatron":
-            self.ddp_config = DistributedDataParallelConfig()
+            self.ddp_config = DistributedDataParallelConfig(check_for_nan_in_grad=True)
         elif isinstance(ddp, DistributedDataParallelConfig):
             self.ddp_config = ddp
         elif ddp == "pytorch":
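The change above only affects the default built for ddp == "megatron"; per the isinstance branch, a caller can still pass a fully specified config instead. A hedged usage sketch follows: it assumes the strategy class in this file is MegatronStrategy and that the import paths shown are available, and DistributedDataParallelConfig fields other than check_for_nan_in_grad are assumptions about Megatron-Core rather than anything shown in this commit:

    from megatron.core.distributed import DistributedDataParallelConfig
    from nemo.lightning.pytorch.strategies import MegatronStrategy

    # Supply an explicit DDP config instead of relying on the "megatron" default.
    ddp_config = DistributedDataParallelConfig(
        check_for_nan_in_grad=True,  # same flag this commit enables by default
        grad_reduce_in_fp32=True,    # assumption: reduce gradients in fp32
    )
    strategy = MegatronStrategy(ddp=ddp_config)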
