-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Checkpoint connector bugfixes #10647
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,8 @@ | |
|
||
import pytorch_lightning as pl | ||
from filelock import FileLock, Timeout | ||
from pytorch_lightning.trainer.states import TrainerFn | ||
|
||
|
||
# Dynamically inherit from the correct Path subclass based on the operating system. | ||
if os.name == 'nt': | ||
|
@@ -139,26 +141,25 @@ | |
Args: | ||
model (pl.LightningModule): The model to be set up. | ||
trainer (Optional[pl.Trainer]): The trainer to be used, if not provided a new one will be created. | ||
|
||
Returns | ||
------- | ||
pl.Trainer: The trainer configured with the model and strategy. | ||
""" | ||
from nemo.lightning import MegatronStrategy, Trainer | ||
from nemo.lightning._strategy_lib import megatron_lazy_init_context | ||
from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib | ||
|
||
_trainer = trainer or Trainer( | ||
devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False) | ||
devices=1, | ||
accelerator="cpu", | ||
strategy=MegatronStrategy(ckpt_save_optimizer=False, always_save_context=True), | ||
) | ||
|
||
_trainer.state.fn = TrainerFn.FITTING # needed for proper save. | ||
_trainer.strategy.connect(model) | ||
_trainer.strategy.setup_environment() | ||
|
||
if not model.state_dict(): | ||
_trainer.strategy.lazy_init = True | ||
with _trainer.init_module(), megatron_lazy_init_context(model.config): | ||
with _trainer.init_module(), _strategy_lib.megatron_lazy_init_context(model.config): | ||
model.configure_model() | ||
|
||
return _trainer | ||
|
||
def nemo_save(self, output_path: Path, trainer: pl.Trainer, dump_io: bool = True) -> None: | ||
|
@@ -170,16 +171,20 @@ | |
trainer (pl.Trainer): The trainer with the strategy to save the model. | ||
dump_io (bool): If True, the IO configuration will be saved to the output path. | ||
""" | ||
# Import here to avoid circular import | ||
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint | ||
Check notice Code scanning / CodeQL Cyclic import Note
Import of module
nemo.lightning.pytorch.callbacks.model_checkpoint Error loading related location Loading There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jstjohn any way to avoid this? |
||
|
||
trainer.strategy._setup_optimizers = False | ||
trainer.strategy._init_model_parallel = False | ||
trainer.strategy.setup(trainer) | ||
trainer.save_checkpoint(output_path) | ||
output_path = Path(output_path) | ||
trainer.save_checkpoint(output_path / ModelCheckpoint.WEIGHTS_PATH) | ||
|
||
from nemo.lightning.io.pl import TrainerContext | ||
from nemo.utils.get_rank import is_global_rank_zero | ||
|
||
if is_global_rank_zero() and dump_io: | ||
TrainerContext.from_trainer(trainer).io_dump(output_path) | ||
TrainerContext.from_trainer(trainer).io_dump(output_path / ModelCheckpoint.CONTEXT_PATH) | ||
|
||
def nemo_load( | ||
self, path: Path, trainer: Optional[pl.Trainer] = None, cpu: bool = True | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,8 +57,9 @@ class ModelCheckpoint(PTLModelCheckpoint): | |
async_save: Whether to enable asynchronous checkpointing. | ||
""" | ||
|
||
UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished" | ||
WEIGHTS_PATH = "weights" | ||
UNFINISHED_CHECKPOINT_SUFFIX: str = "-unfinished" | ||
WEIGHTS_PATH: str = "weights" | ||
CONTEXT_PATH: str = "context" | ||
|
||
def __init__( | ||
self, | ||
|
@@ -277,7 +278,7 @@ def on_train_end(self, trainer, pl_module): | |
else: | ||
super()._save_last_checkpoint(trainer, monitor_candidates) | ||
if self.save_context_on_train_end and not self.always_save_context and is_global_rank_zero(): | ||
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / "context") | ||
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(self.last_model_path) / self.CONTEXT_PATH) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jstjohn Ideally I'd like to avoid changing the checkpoint structure. If we do have to change it, let's add a comment giving an example of the use-case, and cherry-pick this PR so it makes it into the 24.09 release. |
||
# Call parent on_train_end() to save the -last checkpoint | ||
super().on_train_end(trainer, pl_module) | ||
|
||
|
@@ -449,7 +450,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) | |
trainer.save_checkpoint(ckpt_filepath, save_weights_only, storage_options=storage_options) | ||
|
||
if self.always_save_context and is_global_rank_zero(): | ||
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / "context") | ||
TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath) / self.CONTEXT_PATH) | ||
|
||
if self.async_save: | ||
logging.info(f'Scheduled async checkpoint save for {filepath}') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jstjohn Please clarify what you mean here — what would be missing if `fn` were not set to `FITTING`?