Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiGPU training + changes to Checkpointing logic #218

Merged
merged 68 commits into from
May 25, 2023
Merged

Conversation

prabhuteja12
Copy link
Contributor

This is WIP to rework checkpointing and the multiGPU training in Renate.

Description of changes:

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@@ -16,7 +16,7 @@ def __init__(self, num_hidden: int) -> None:
# Model hyperparameters as well as the loss function need to registered via RenateModule's
# constructor, see documentation. Otherwise, this is a standard torch model.
super().__init__(
constructor_arguments={"num_hidden": num_hidden}, loss_fn=torch.nn.CrossEntropyLoss()
constructor_arguments={"num_hidden": num_hidden}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should fit on single line?

examples/nlp_finetuning/renate_config.py Outdated Show resolved Hide resolved
requirements.txt Outdated Show resolved Hide resolved
src/renate/benchmark/models/base.py Outdated Show resolved Hide resolved
src/renate/cli/parsing_functions.py Outdated Show resolved Hide resolved


def int_or_str(x: str) -> Union[str, int]:
"""Function to cast to int or str. This is used to tackle precision
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Single first line of doc string.

src/renate/utils/misc.py Outdated Show resolved Hide resolved
@@ -254,6 +270,11 @@ def get_renate_module_mlp(
)


@pytest.helpers.register
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems overkill for a single line ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is reused subsequently.

test/conftest.py Show resolved Hide resolved
doc/getting_started/how_to_renate_config.rst Show resolved Hide resolved
examples/getting_started/renate_config.py Outdated Show resolved Hide resolved
src/renate/updaters/experimental/er.py Outdated Show resolved Hide resolved
src/renate/updaters/learner.py Outdated Show resolved Hide resolved
"""Returns the state of the learner."""
return {
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
learner_state_dict = {
"learner_class_name": self.__class__.__name__,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this is not used anymore. I don't mind leaving it in, but we could remove it. Your call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it isnt. But I left it in as a sanity check.

@@ -411,32 +379,42 @@ def __init__(
**kwargs,
) -> None:
super().__init__(seed=seed, **kwargs)
self.save_hyperparameters(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this. Based on playing with a toy example, I think the save_hyperparameters call in the base class might be enough. Let's talk about it offline.

pl_module.load_state_dict(self._model, torch.load(learner_state_path)["state_dict"])
loaded_state = trainer.strategy.load_checkpoint(learner_state_path)
pl_module.on_load_checkpoint(loaded_state)
# This loads the state dict only if its not Deepspeed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsure what the comment refers to. Can you be more specific about which part is needed for Deepspeed?

# Finalize model update.
pl_module.on_model_update_end()
# Save permanently.
pl_module.save(self._output_state_folder)
# Overwrite checkpoint.
self._save_checkpoint(trainer, learner_state_path)

def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A doc string would be helpful here. I don't exactly follow what we do here. Is the goal to have a separate file that just contains the model weights?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. added a detailed description

@lballes
Copy link
Contributor

lballes commented May 24, 2023

We need to check the doc strings. I saw that the RenateModule doc string still contains the loss_fn argument. Can you make sure that all doc strings reflect the change? I.e., remove it from RenateModule and its subclasses and add it to Learner and its subclasses.

lballes
lballes previously approved these changes May 25, 2023
Signed-off-by: Prabhu Teja S <prabhuteja12@gmail.com>
lballes
lballes previously approved these changes May 25, 2023
@wistuba wistuba merged commit 19a2271 into dev May 25, 2023
@wistuba wistuba deleted the checkpointing branch May 25, 2023 14:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants