MultiGPU training + changes to Checkpointing logic #218
Conversation
@@ -16,7 +16,7 @@ def __init__(self, num_hidden: int) -> None:
         # Model hyperparameters as well as the loss function need to be registered via
         # RenateModule's constructor, see documentation. Otherwise, this is a standard torch model.
         super().__init__(
-            constructor_arguments={"num_hidden": num_hidden}, loss_fn=torch.nn.CrossEntropyLoss()
+            constructor_arguments={"num_hidden": num_hidden}
Should this fit on a single line?
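For reference, a minimal sketch of the single-line form (the `MyMLP` class name and the `renate.models` import path are illustrative, not taken from this diff):

```python
import torch

from renate.models import RenateModule  # import path assumed


class MyMLP(RenateModule):
    def __init__(self, num_hidden: int) -> None:
        # With loss_fn removed from the constructor, the call fits on one line.
        super().__init__(constructor_arguments={"num_hidden": num_hidden})
        self._layer = torch.nn.Linear(num_hidden, num_hidden)
```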
src/renate/utils/misc.py
Outdated
def int_or_str(x: str) -> Union[str, int]:
    """Function to cast to int or str. This is used to tackle precision
The first line of the docstring should be a single summary line.
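For illustration, a sketch with a single-line summary (the function body is an assumption based on the signature; the precision use case comes from the truncated docstring above):

```python
from typing import Union


def int_or_str(x: str) -> Union[str, int]:
    """Casts a string to int where possible, otherwise returns it unchanged.

    Used to handle the precision setting, which PyTorch Lightning accepts
    either as an int (16, 32) or as a str ("bf16").
    """
    try:
        return int(x)
    except ValueError:
        return x
```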
@@ -254,6 +270,11 @@ def get_renate_module_mlp(
     )


+@pytest.helpers.register
Seems overkill for a single line ;)
It is reused subsequently.
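For context, a minimal sketch of the `pytest-helpers-namespace` pattern used here (the helper name and body are illustrative):

```python
# conftest.py
import pytest
import torch


@pytest.helpers.register
def make_small_mlp(num_hidden: int = 32) -> torch.nn.Module:
    # Registered helpers become available to all tests via pytest.helpers.
    return torch.nn.Sequential(
        torch.nn.Linear(num_hidden, num_hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(num_hidden, 2),
    )
```

In a test module the helper is then called as `pytest.helpers.make_small_mlp(num_hidden=16)`, which is why registering it pays off once it is reused.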
Signed-off-by: Prabhu Teja S <prabhuteja12@gmail.com>
"""Returns the state of the learner.""" | ||
return { | ||
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: | ||
learner_state_dict = { | ||
"learner_class_name": self.__class__.__name__, |
It seems like this is not used anymore. I don't mind leaving it in, but we could remove it. Your call.
It isn't. But I left it in as a sanity check.
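For clarity, a minimal sketch of the hook's shape (only the `learner_class_name` key is taken from this diff; the rest of the learner state is omitted):

```python
from typing import Any, Dict

from pytorch_lightning import LightningModule


class Learner(LightningModule):
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Lightning passes the checkpoint dict by reference just before it is
        # written to disk, so extra state is added in place, not returned.
        learner_state_dict = {
            # Not read back anywhere; kept purely as a sanity check.
            "learner_class_name": self.__class__.__name__,
        }
        checkpoint.update(learner_state_dict)
```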
src/renate/updaters/learner.py
Outdated
@@ -411,32 +379,42 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(seed=seed, **kwargs)
+        self.save_hyperparameters(
I am not sure about this. Based on playing with a toy example, I think the `save_hyperparameters` call in the base class might be enough. Let's talk about it offline.
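A toy example of the behavior in question (class names are illustrative): `save_hyperparameters()` collects constructor arguments via frame inspection, so a call in the base class also records the subclass's `__init__` arguments.

```python
from pytorch_lightning import LightningModule


class BaseLearner(LightningModule):
    def __init__(self, seed: int = 0, **kwargs) -> None:
        super().__init__()
        # Also picks up the arguments of the subclass __init__ that
        # invoked this constructor.
        self.save_hyperparameters()


class ReplayLearner(BaseLearner):
    def __init__(self, memory_size: int = 100, seed: int = 0) -> None:
        super().__init__(seed=seed)


learner = ReplayLearner(memory_size=42)
print(learner.hparams)  # contains both memory_size and seed
```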
src/renate/updaters/model_updater.py
Outdated
-        pl_module.load_state_dict(self._model, torch.load(learner_state_path)["state_dict"])
+        loaded_state = trainer.strategy.load_checkpoint(learner_state_path)
+        pl_module.on_load_checkpoint(loaded_state)
+        # This loads the state dict only if it's not DeepSpeed.
Unsure what the comment refers to. Can you be more specific about which part is needed for DeepSpeed?
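If I read it correctly, the distinction is the one sketched below (a guess at the intent, not the actual implementation): DeepSpeed's strategy restores module weights itself from its sharded checkpoint directory, while other strategies return a plain checkpoint dict whose `state_dict` must be loaded explicitly.

```python
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy


def restore_learner_state(
    trainer: Trainer, pl_module: LightningModule, learner_state_path: str
) -> None:
    loaded_state = trainer.strategy.load_checkpoint(learner_state_path)
    pl_module.on_load_checkpoint(loaded_state)
    if not isinstance(trainer.strategy, DeepSpeedStrategy):
        # DeepSpeed already restored the module via its engine; for all
        # other strategies the weights still need to be loaded by hand.
        pl_module.load_state_dict(loaded_state["state_dict"])
```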
        # Finalize model update.
        pl_module.on_model_update_end()
        # Save permanently.
        pl_module.save(self._output_state_folder)
        # Overwrite checkpoint.
        self._save_checkpoint(trainer, learner_state_path)

    def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
A doc string would be helpful here. I don't exactly follow what we do here. Is the goal to have a separate file that just contains the model weights?
Yes, added a detailed description.
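Something along these lines would already help (a sketch of the docstring, based on the explanation above; the callback class name is illustrative):

```python
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback


class RenateModelCheckpoint(Callback):  # illustrative name
    def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Writes a weights-only copy of the model at the end of the update.

        The regular Lightning checkpoint keeps the full learner state so an
        update can be resumed; this hook additionally saves just the model
        weights to a separate file that can be loaded without reconstructing
        the learner.
        """
```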
We need to check the doc strings. I saw that the
Signed-off-by: Prabhu Teja S <prabhuteja12@gmail.com>
This is a WIP to rework checkpointing and multi-GPU training in Renate.
Description of changes:
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.