
Resuming should allow to differentiate what to resume (steps/opti/weights) #5339

Open · thoglu opened this issue Jan 3, 2021 · 25 comments

Labels: feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · priority: 1 (Medium priority task)

@thoglu

thoglu commented Jan 3, 2021

Currently it is possible to resume either the full training state (epoch, global step, optimizer, scheduler options, and weights) or only the weights.

I would like to be able to switch the optimizer at some point, i.e. skip restoring the optimizer/scheduler but still load the epoch/global step. At the moment I only see a way to do this with hacks. Is there any other way? Could this become a feature, e.g. a way to specify in the Trainer init exactly what to restore and what not to restore?

@thoglu thoglu added feature Is an improvement or enhancement help wanted Open to be worked on labels Jan 3, 2021
@rohitgr7
Contributor

rohitgr7 commented Jan 3, 2021

One way this could be done is by downgrading the error to a warning and making the relevant changes below:
https://github.com/PyTorchLightning/pytorch-lightning/blob/d20fd8e5ab1a52747fee2cd53290a679d8b726d0/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L120-L124
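
For illustration, a rough sketch of what such a change might look like (this is not the actual diff; the shape of the check is assumed from the error message quoted later in this thread, and rank_zero_warn is used as a stand-in for the project's warning helper):

from pytorch_lightning.utilities import rank_zero_warn

# Approximate shape of the check inside the checkpoint connector (assumed):
if "optimizer_states" not in checkpoint or "lr_schedulers" not in checkpoint:
    rank_zero_warn(
        "Checkpoint contains only the model. Skipping restore of optimizer and"
        " lr scheduler states; fresh ones from `configure_optimizers` will be used."
    )
else:
    ...  # restore optimizer and lr scheduler states as before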

cc: @PyTorchLightning/core-contributors thoughts?

@tchaton
Contributor

tchaton commented Jan 4, 2021

Hey @rohitgr7,

I am not sure resuming and changing optimizer / scheduler is the best option.

Possibly, we could extend LightningModule configure_optimizers to take current_epoch.

def configure_optimizers(self, current_epoch: int = None):
    ...
    return [optimizers], [lr_schedulers], should_update  # should_update: bool

Internally, we would inspect the model's configure_optimizers signature. If the user requested current_epoch, we call this function every epoch and replace the optimizers and lr_schedulers only if should_update is True.

Example:

def configure_optimizers(self, current_epoch: int = None):
    milestones = [0, 100, ...]
    optimizers = None
    lr_schedulers = None
    if current_epoch == milestones[0]:
        # init optimizers / schedulers
        ...
    elif current_epoch == milestones[-1]:
        # init new optimizers / schedulers
        ...

    return [optimizers], [lr_schedulers], current_epoch in milestones

What are your thoughts?

@rohitgr7
Contributor

rohitgr7 commented Jan 4, 2021

@tchaton yeah this is a good suggestion, but what if someone sets save_weights_only in ModelCheckpoint and then passes that checkpoint to resume_from_checkpoint? In that case this is still a problem.

@thoglu
Author

thoglu commented Jan 4, 2021

I brought this up because the current options (save_weights_only or not) are a little too restrictive. I have found a hack that works for me for now, but I think there is a more widespread use case for being able to change the optimizer: SGD, for example, has certain properties that Adam does not, and one might want to exploit those at different stages of training.

@carmocca
Contributor

carmocca commented Jan 4, 2021

What are your thoughts?

I think that's too complicated, and its complexity would grow too much if we wanted to trigger it on occasions other than epoch boundaries.

In general, I think we should let users save any combination of {training, optimizer, model} state and resume training with whichever combination is provided.

@stale stale bot added the won't fix This will not be worked on label Feb 3, 2021
@thoglu
Author

thoglu commented Feb 4, 2021

Is there any progress on how/when such a combinatorial choice will be made available to users?

@stale stale bot removed the won't fix This will not be worked on label Feb 4, 2021
@carmocca
Contributor

carmocca commented Feb 5, 2021

No current progress/plans that I know of. Implementing this will require attention.

@carmocca carmocca added this to the 1.3 milestone Feb 5, 2021
@carmocca carmocca added the priority: 2 Low priority task label Feb 5, 2021
@edenlightning edenlightning removed this from the 1.3 milestone Feb 22, 2021
@stale stale bot added the won't fix This will not be worked on label Mar 25, 2021
@Lightning-AI Lightning-AI deleted a comment from stale bot Mar 25, 2021
@stale stale bot removed the won't fix This will not be worked on label Mar 25, 2021
@Lightning-AI Lightning-AI deleted a comment from stale bot Mar 25, 2021
@carmocca
Contributor

@ananthsub what do you think about this? It could be included in our plans for better fault-tolerant training.

@Alexfinkelshtein

@carmocca, related but not addressing exactly the same matter: what about exposing the strict flag of torch.nn.Module.load_state_dict?
I have a use case where my model has some buffers (non-parametric) that I sometimes need to save (to have a standalone version) and sometimes don't, in favor of lean checkpoints. Most of the buffers are initialized on model init.
Alternatively, introducing a restore_model_state hook (instead of on_load_checkpoint) could also enable this, since I could then implement a custom model load explicitly.
What do you think?

@carmocca
Contributor

@carmocca, related but not addressing exactly the same matter: what about exposing the strict flag of torch.nn.Module.load_state_dict?

You can already set the strict flag in load_from_checkpoint:

https://github.com/PyTorchLightning/pytorch-lightning/blob/d12c6cf2b358c989d0d8bc17018049def99d6129/pytorch_lightning/core/saving.py#L62

which is passed to:

https://github.com/PyTorchLightning/pytorch-lightning/blob/d12c6cf2b358c989d0d8bc17018049def99d6129/pytorch_lightning/core/saving.py#L205
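
For reference, a minimal usage sketch (MyLitModel and the checkpoint path are placeholders):

from my_project import MyLitModel  # hypothetical LightningModule

# Non-strict weight loading: missing/unexpected keys are tolerated, but no
# training state (epoch, global step, optimizer, schedulers) is restored.
model = MyLitModel.load_from_checkpoint("path/to/checkpoint.ckpt", strict=False)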

@Alexfinkelshtein

Alexfinkelshtein commented Apr 20, 2021

@carmocca, thanks for the comment.
I am familiar with this; however, I also want the benefits of resuming the training state, which I don't get with the above.

@Alexfinkelshtein

@carmocca, does the thumbs up mean you agree that exposing it would be a good move?

@carmocca
Contributor

carmocca commented Apr 21, 2021

I agree there should be a way to set strict=False with Trainer(resume_from_checkpoint)

The relevant piece of code is here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/e4f3a8d3dd534d4ec2fe094280272513e652fba9/pytorch_lightning/plugins/training_type/training_type_plugin.py#L226

One idea would be to have a property in the LightningModule to change this, similar to how automatic_optimization works. To me, whether we need to load strict or non-strict usually depends on the model.

Another option would be to save whether to load strict or not inside the checkpoint itself and do:

model.load_state_dict(checkpoint['state_dict'], strict=checkpoint.get('strict', True))
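
As a rough sketch of that second option (hypothetical; it assumes the on_save_checkpoint hook is used to stamp the flag into the checkpoint dict):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    # Hypothetical: persist the intended loading mode in the checkpoint itself,
    # so the loading side can use checkpoint.get('strict', True) as shown above.
    def on_save_checkpoint(self, checkpoint):
        checkpoint["strict"] = False  # this model tolerates missing/extra keys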

Any thoughts?

@Alexfinkelshtein

@carmocca, somehow it seems like my comment disappeared, so I'll write it again.
I find this suggestion sub-optimal, mainly because when a checkpoint is saved you don't necessarily know the mode in which you will want to load it. It is true that you can then manually use the on_load_checkpoint hook and adjust that, but this is rather hacky.
I believe a cleaner solution would either expose the flag through the trainer directly or add a new hook that enables custom model loading.

@carmocca
Contributor

I believe a cleaner solution would either expose the flag through the trainer directly

I personally don't like flags that only work if others are active, so I'd rather avoid this solution.

enable a new hook that enables custom model loading.

Are you talking about a hook that encapsulates the load_state_dict call?

@Alexfinkelshtein

@carmocca ,
"Are you talking about a hook that encapsulates the load_state_dict call?"
Yes, exactly.

@stale stale bot added the won't fix This will not be worked on label Jun 18, 2021
@awaelchli awaelchli added this to the v1.4 milestone Jun 18, 2021
@awaelchli awaelchli added priority: 1 Medium priority task and removed priority: 2 Low priority task labels Aug 19, 2021
@ananthsub
Contributor

ananthsub commented Sep 9, 2021

@awaelchli I think this can be viewed as an extension of #9405

One half-baked idea is to specify a dataclass describing which parts of the checkpoint should be loaded:

from dataclasses import dataclass

@dataclass
class CheckpointLoadOptions:
    ckpt_path: _PATH
    load_callbacks: bool = True
    load_optimizer_states: bool = True
    load_loops: bool = True

Then trainer.fit/validate/test/predict would accept this dataclass instead of ckpt_path directly.
However, this makes instantiation & configuration more difficult.
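
A hypothetical usage sketch of that idea (none of this is an existing API; model, dm, and the path are placeholders):

# Resume loops and callbacks, but start with a fresh optimizer:
opts = CheckpointLoadOptions(
    ckpt_path="path/to/last.ckpt",
    load_optimizer_states=False,
)
trainer.fit(model, datamodule=dm, ckpt_path=opts)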

@imanuelroz

@tchaton yeah this is a good suggestion, but what if someone sets save_weights_only in ModelCheckpoint and then passes that checkpoint to resume_from_checkpoint? In that case this is still a problem.

I am having this problem: I set save_weights_only=True and now, if I try resume_from_checkpoint='/path/file.ckpt', it raises the following error: 'Trying to restore training state but checkpoint contains only the model. This is probably due to ModelCheckpoint.save_weights_only being set to True.'
What should I do?

@Ir1d
Contributor

Ir1d commented Jan 14, 2022

load_from_checkpoint can be given strict=False; it would be really great if resume_from_checkpoint could also accept strict=False.

@arsedler9

@carmocca I'm training models with PTL using ray.tune's implementation of Population-based Training. Checkpointing is done with trainer.save_checkpoint via the TuneReportCheckpointCallback. With the current setup, I can't tune the learning rate with PBT because calling trainer.fit(..., ckpt_path=ckpt_path) overwrites any new optimizer state with the state from the previous model's checkpoint. Restoring the model weights only via model = ModelClass.load_from_checkpoint(ckpt_path) isn't a good solution because current_epoch is reset to zero when I don't pass ckpt_path to fit, which disrupts logging. These limitations seem to prohibit training PTL models with PBT. Do you have any suggestions?

@carmocca
Contributor

One hacky way to do this currently would be to override the optimizer_states key in the checkpoint so that this piece of code does not run:

https://github.com/PyTorchLightning/pytorch-lightning/blob/dbd69b9a09f3954297d78996389207501e96cd1b/pytorch_lightning/strategies/strategy.py#L322-L323
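
One possible reading of that workaround, as a sketch (untested and version-dependent; it assumes Lightning calls on_load_checkpoint on the same checkpoint dict before the optimizer/scheduler states are restored):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    ...

    def on_load_checkpoint(self, checkpoint):
        # Replace the saved optimizer/scheduler states with empty lists so the
        # restore loop has nothing to iterate over and the fresh optimizers from
        # configure_optimizers() (e.g. with a new learning rate) are kept as-is.
        checkpoint["optimizer_states"] = []
        checkpoint["lr_schedulers"] = []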

@turian
Contributor

turian commented Sep 11, 2022

Here's a model doing linear warmup for 5 hours, but then the cosine-annealing base_lr is too high, so it diverges. I wish I could have played with that base_lr rather than retraining from scratch:

[screenshot: learning-rate / loss curves]

Here's a model that warms down (from LR 10 to 3.2) for 10 hours and then does cosine annealing with base_lr 1.0 and some cosine cycle schedule:

[screenshot: learning-rate / loss curves]

Here is a model that does warmup, but then the cosine cycle appears too big:

[screenshot: learning-rate / loss curves]

@samvanstroud

Any update on this? When using the OneCycleLR scheduler it is not possible to resume a training run once the number of steps is exceeded. It would be great to be able to restart the scheduler but keep the epoch and step info. As a workaround I am just loading the weights, following the approach in #16014 (comment).

@IngLP

IngLP commented Sep 7, 2023

Any updates? I have the same problem: I cannot change the optimizer/LR on a model when I resume training.

@lukasschmit

Any updates on this? Splitting up optimizer state / LR schedule / global step / weights is necessary for serious training setups where, e.g., the schedule is tuned on the fly.
