From 067a9234a4f1de1f971a9c5f1ed3eec13b34da44 Mon Sep 17 00:00:00 2001
From: akoumpa
Date: Fri, 27 Sep 2024 16:37:49 +0000
Subject: [PATCH] Apply isort and black reformatting

Signed-off-by: akoumpa
---
 .../pytorch/strategies/megatron_strategy.py  | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index 107957f7f6fa..07cda991fd5d 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -641,18 +641,26 @@ def optimizer_sharded_state_dict(self, is_loading=False):
 
     @override
     def save_checkpoint(
-        self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None, allow_meta_tensors = False,
+        self,
+        checkpoint: Dict[str, Any],
+        filepath: Union[str, Path],
+        storage_options: Optional[Any] = None,
+        allow_meta_tensors=False,
     ) -> None:
         checkpoint["state_dict"] = OrderedDict([])  # remove device state_dict
         # retrieve `sharded_state_dict` if it has not already been configured in `on_save_checkpoint`
         if "sharded_state_dict" not in checkpoint:
             checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict()
-        meta_tensors = list(filter(lambda x: isinstance(x[1], torch.Tensor) and x[1].device.type == 'meta', checkpoint['sharded_state_dict'].items()))
-        for (name, tensor) in meta_tensors:
-            logging.warning(f"Got device=meta for {name}")
+        meta_tensors = list(
+            filter(
+                lambda x: isinstance(x[1], torch.Tensor) and x[1].device.type == 'meta',
+                checkpoint['sharded_state_dict'].items(),
+            )
+        )
+        for name, tensor in meta_tensors:
+            logging.warning(f"Got device=meta for {name}")
         if not allow_meta_tensors:
-            assert len(meta_tensors) == 0, meta_tensors
-
+            assert len(meta_tensors) == 0, meta_tensors
 
         ## replace unsharded optimizer_states with sharded dict.
         ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
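
Below is a minimal, self-contained sketch of the meta-tensor check that this patch reformats. It is illustrative only, not NeMo code: the helper names find_meta_tensors and check_state_dict are invented for the example and do not exist in the repository. Tensors on the "meta" device carry no storage, so letting them reach a checkpoint would silently save unusable parameter entries; the allow_meta_tensors flag merely downgrades that failure to a warning.

import logging

import torch

logging.basicConfig(level=logging.WARNING)


def find_meta_tensors(state_dict):
    """Return (name, tensor) pairs whose tensor lives on the meta device."""
    return [
        (name, value)
        for name, value in state_dict.items()
        if isinstance(value, torch.Tensor) and value.device.type == 'meta'
    ]


def check_state_dict(state_dict, allow_meta_tensors=False):
    """Warn about meta tensors; refuse to proceed unless explicitly allowed."""
    meta_tensors = find_meta_tensors(state_dict)
    for name, _tensor in meta_tensors:
        logging.warning(f"Got device=meta for {name}")
    if not allow_meta_tensors:
        # Fail fast: uninitialized (meta) parameters must not reach the checkpoint.
        assert len(meta_tensors) == 0, meta_tensors
    return meta_tensors


if __name__ == "__main__":
    state_dict = {
        "embedding.weight": torch.zeros(4, 8),
        "decoder.weight": torch.empty(4, 8, device="meta"),  # uninitialized
    }
    # Warns about decoder.weight; with allow_meta_tensors=False this would raise.
    check_state_dict(state_dict, allow_meta_tensors=True)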