Commit b93a551
Apply isort and black reformatting
Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
marcromeyn committed Jul 8, 2024
1 parent a28150f commit b93a551
Showing 5 changed files with 21 additions and 18 deletions.
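Most of the hunks below follow one of two mechanical patterns: isort regrouping imports and black re-wrapping statements that overflow the configured line length. As a hedged, self-contained sketch of the black pattern seen in megatron_parallel.py (the logger setup and parameter counts are invented for illustration; the real line-length limit lives in the repository's formatter configuration, which is not part of this diff):

import logging

logging.basicConfig(level=logging.INFO)

num_params = 1_000_000          # illustrative values, not taken from NeMo
num_trainable_params = 250_000

# Pre-black style: the whole call sits on one long physical line.
logging.info(f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)")

# Post-black style: the same call, wrapped so each physical line fits the limit.
logging.info(
    f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
)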
2 changes: 1 addition & 1 deletion nemo/lightning/_strategy_lib.py
@@ -512,5 +512,5 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
_state_dict[key[len(to_remove) :]] = value
else:
_state_dict[key] = value

module.load_state_dict(_state_dict, strict=strict)
27 changes: 15 additions & 12 deletions nemo/lightning/megatron_parallel.py
@@ -424,12 +424,12 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
def init_model_parallel(self):
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from megatron.core import parallel_state

if self.convert_module_fn:
self.apply_convert_module_fn()

self.init_ddp()

for model_module in self:
if not self._cpu:
model_module.cuda(torch.cuda.current_device())
@@ -448,7 +448,7 @@ def init_model_parallel(self):
# Print number of parameters.
if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
from nemo.utils import logging

num_params = _calc_number_of_params(list(self))
num_trainable_params = _calc_number_of_trainable_params(list(self))

@@ -458,14 +458,16 @@ def init_model_parallel(self):
f"{num_params}"
)
logging.info(msg)

if num_params != num_trainable_params:
logging.info(f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)")

logging.info(
f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
)

def apply_convert_module_fn(self):
for i in range(len(self)):
self[i] = self.convert_module_fn(self[i])

def init_ddp(self):
if not isinstance(self.ddp_config, DistributedDataParallelConfig):
return
@@ -591,8 +593,10 @@ def __getattr__(self, item: Any) -> Any:
except AttributeError:
# If not found in superclass, check if we have any modules
if len(self) == 0:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules")

raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules"
)

# Try to get it from the first module
try:
return getattr(self._modules[self._get_abs_string_index(0)], item)
@@ -939,8 +943,7 @@ def _calc_number_of_params(model: List[nn.Module]) -> int:
def _calc_number_of_trainable_params(model: List[nn.Module]) -> int:
assert isinstance(model, list)

-return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad])
-            for model_module in model])
+return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad]) for model_module in model])


def is_list_of_iterators(var) -> bool:
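The import statements visible in the hunks above (for example from megatron.core import parallel_state and from nemo.utils import logging) are the kind of lines the isort half of this commit touches. A minimal sketch of that behaviour using isort's public isort.code helper with default settings rather than NeMo's actual configuration; the shuffled imports are only text to be sorted, so the packages do not need to be installed:

import isort

# A deliberately shuffled import block; isort rewrites the text only.
messy = (
    "from nemo.utils import logging\n"
    "import os\n"
    "from megatron.core import parallel_state\n"
    "import torch\n"
)

# Default isort separates standard-library imports from the rest and orders
# each section deterministically (first-party vs. third-party classification
# depends on configuration that is not shown in this diff).
print(isort.code(messy))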
4 changes: 2 additions & 2 deletions nemo/lightning/pytorch/callbacks/model_transform.py
@@ -62,11 +62,11 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._maybe_apply_transform(trainer)

def _maybe_apply_transform(self, trainer):
if self._needs_to_call:
self.apply_transform(trainer)

def apply_transform(self, trainer):
self.model_transform(trainer.model)

2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/callbacks/peft.py
@@ -98,7 +98,7 @@ def apply_transform(self, trainer):

logging.info("Setting up optimizers")
trainer.strategy.setup_optimizers(trainer)

if self.wrapped_io.adapter_ckpt_path is not None:
logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}")
adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path)
4 changes: 2 additions & 2 deletions nemo/lightning/pytorch/strategies.py
@@ -276,10 +276,10 @@ def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
ddp_config=self.ddp_config,
convert_module_fn=convert_module_fn,
)

if self._init_model_parallel:
self.init_model_parallel()

self.megatron_parallel.trainer = trainer

# check signature-def of self.model.configure_optimizers to check if there's an optional arg: megatron_parallel
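The strategies.py hunk above is whitespace-only, but the larger change in megatron_parallel.py (the nested sum over model parameters) can be reproduced with black's Python API. A hedged sketch, assuming a 119-character line length, which is a guess at this repository's setting rather than something stated in the diff:

import black

SRC = '''\
def _calc_number_of_trainable_params(model):
    return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad])
                for model_module in model])
'''

# The joined comprehension fits within the assumed limit, so black collapses it
# onto one line, matching the megatron_parallel.py hunk shown earlier.
print(black.format_str(SRC, mode=black.Mode(line_length=119)))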
