🚀 Feature
We should allow Callback objects to optionally persist state that can be reloaded from checkpoints.
Motivation
We already manually save the state for the early stopping and model checkpoint callbacks. This refactor would eliminate callback-specific code in the Trainer and extend the same ability to user-written callbacks.
Pitch
Each callback would just return a state_dict, which the Trainer could store in the checkpoint. The one thing I am unclear about is how we want to reinitialize the state for arbitrary callbacks. If we can expect that the exact same callbacks will be passed to the Trainer, it should be trivial. Alternatively, we could expect that you only pass in a single instance of each callback class (e.g. callbacks=[CustomerLogger(), EarlyStopping(), ModelCheckpoint()] and not callbacks=[CustomerLogger(params_a), CustomerLogger(params_b), EarlyStopping(), ModelCheckpoint()]) and just keep a mapping from callback class to state dict. However, if the user passes multiple callback instances of the same class, I'm not sure how we would want to handle that.
I would recommend that we document the following constraints:
All objects in the dictionary must be pickle-able.
You cannot persist multiple instances of the same callback class.
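To make the pitch concrete, here is a minimal sketch of how the class-to-state mapping could work under the two constraints above. Everything in it is an assumption for illustration: the `Callback` base class, the `state_dict` / `load_state_dict` hook names, the `CounterCallback` example, and the helper functions are hypothetical and are not the existing Trainer API.

```python
import pickle
from typing import Any, Dict, List


class Callback:
    """Hypothetical base class; the hook names are assumptions for this sketch."""

    def state_dict(self) -> Dict[str, Any]:
        # Stateless callbacks return an empty dict and are skipped when saving.
        return {}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Nothing to restore by default.
        pass


class CounterCallback(Callback):
    """Toy user-written callback that remembers how many batches it has seen."""

    def __init__(self) -> None:
        self.batches_seen = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.batches_seen = state["batches_seen"]


def _class_key(callback: Callback) -> str:
    # Key the mapping by fully qualified class name so it survives pickling.
    cls = type(callback)
    return f"{cls.__module__}.{cls.__qualname__}"


def collect_callback_states(callbacks: List[Callback]) -> Dict[str, Dict[str, Any]]:
    """Build the class -> state_dict mapping that would go into a checkpoint."""
    states: Dict[str, Dict[str, Any]] = {}
    for callback in callbacks:
        state = callback.state_dict()
        if not state:
            continue  # nothing to persist for this callback
        key = _class_key(callback)
        if key in states:
            # Constraint: only one stateful instance per callback class.
            raise ValueError(f"Cannot persist multiple instances of {key}")
        pickle.dumps(state)  # constraint: raises if the state cannot be pickled
        states[key] = state
    return states


def restore_callback_states(
    callbacks: List[Callback], states: Dict[str, Dict[str, Any]]
) -> None:
    """Hand each callback the state saved for its class, if there is one."""
    for callback in callbacks:
        state = states.get(_class_key(callback))
        if state is not None:
            callback.load_state_dict(state)
```

With this shape, resuming only requires that the callbacks passed on restart have the same classes as the ones that were saved. Constructor arguments are not restored unless the callback chooses to include them in its state.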
I think it is reasonable to assume there is only one instance of each of these special callbacks (and we should raise an error otherwise, e.g. see the progress bar callback). Note that currently the logger is not a callback.
Also, the documentation for callbacks should probably let the user know that the order of the list passed to the Trainer is not preserved (e.g. the Trainer should reorder it so that the early stopping callback comes after the checkpoint callback, right?).
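If the Trainer does reorder callbacks this way, the rule could be as small as a stable sort. The sketch below is only an illustration of that idea: the classes are local stand-ins defined here, and `order_callbacks` is a hypothetical helper, not something the Trainer is known to expose.

```python
from typing import List


class Callback:
    """Stand-in base class for this sketch only."""


class ModelCheckpoint(Callback):
    """Stand-in for the checkpoint callback."""


class EarlyStopping(Callback):
    """Stand-in for the early stopping callback."""


class CustomLogger(Callback):
    """Stand-in for an arbitrary user-written callback."""


def order_callbacks(callbacks: List[Callback]) -> List[Callback]:
    # sorted() is stable, so every callback keeps its relative position
    # except EarlyStopping instances, which are moved to the end and
    # therefore always run after any ModelCheckpoint.
    return sorted(callbacks, key=lambda cb: isinstance(cb, EarlyStopping))


ordered = order_callbacks([EarlyStopping(), CustomLogger(), ModelCheckpoint()])
print([type(cb).__name__ for cb in ordered])
```

Whatever rule is chosen, documenting it matters because a user who relies on list order would otherwise see their callbacks fire in a different sequence than they wrote.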