Commit 067a923
Apply isort and black reformatting
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa committed Sep 27, 2024
1 parent 59f9873 commit 067a923
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -641,18 +641,26 @@ def optimizer_sharded_state_dict(self, is_loading=False):
 
     @override
     def save_checkpoint(
-        self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None, allow_meta_tensors = False,
+        self,
+        checkpoint: Dict[str, Any],
+        filepath: Union[str, Path],
+        storage_options: Optional[Any] = None,
+        allow_meta_tensors=False,
     ) -> None:
         checkpoint["state_dict"] = OrderedDict([])  # remove device state_dict
         # retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
         if "sharded_state_dict" not in checkpoint:
             checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict()
-        meta_tensors = list(filter(lambda x: isinstance(x[1], torch.Tensor) and x[1].device.type == 'meta', checkpoint['sharded_state_dict'].items()))
-        for (name, tensor) in meta_tensors:
-            logging.warning(f"Got device=meta for {name}")
+        meta_tensors = list(
+            filter(
+                lambda x: isinstance(x[1], torch.Tensor) and x[1].device.type == 'meta',
+                checkpoint['sharded_state_dict'].items(),
+            )
+        )
+        for name, tensor in meta_tensors:
+            logging.warning(f"Got device=meta for {name}")
         if not allow_meta_tensors:
-            assert len(meta_tensors) == 0, meta_tensors
-
+            assert len(meta_tensors) == 0, meta_tensors
 
         ## replace unsharded optimizer_states with sharded dict.
         ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
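
Aside from the reformatting itself, the behavioral core of this hunk is the meta-tensor guard in save_checkpoint. Below is a minimal, self-contained sketch of that filter pattern, runnable outside NeMo; the sharded_state_dict dict here is a hypothetical stand-in for checkpoint["sharded_state_dict"], which the strategy builds via self.megatron_parallel.sharded_state_dict().

import torch

# Hypothetical stand-in for checkpoint["sharded_state_dict"].
sharded_state_dict = {
    "embedding.weight": torch.empty(4, 8, device="meta"),  # lives on the meta device, no storage
    "head.bias": torch.zeros(8),                           # an ordinary materialized tensor
    "step": 100,                                           # non-tensor entries fail the isinstance check
}

# Same filter as in the diff: keep (name, tensor) pairs whose tensor lives on "meta".
meta_tensors = list(
    filter(
        lambda x: isinstance(x[1], torch.Tensor) and x[1].device.type == 'meta',
        sharded_state_dict.items(),
    )
)

for name, tensor in meta_tensors:
    print(f"Got device=meta for {name}")  # the strategy logs this via logging.warning

# With allow_meta_tensors=False (the default), the strategy then asserts the list
# is empty, refusing to save tensors that were never materialized.
assert len(meta_tensors) == 1  # here we expect exactly the one meta tensor above

The list(filter(...)) call is what black splits across lines once it exceeds the line-length limit; the semantics are identical to the original one-liner.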
