-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Train] Unifying Lightning and AIR CheckpointConfig #36368
[Train] Unifying Lightning and AIR CheckpointConfig #36368
Conversation
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some comments.
One other thing: can we document what should be passed in through LightningTrainer(datasets)
? I can't find any mentions of the special "val"
key that needs to be passed in to enable evaluation.
Also, we have this EVALUATION_DATASET_KEY
that TransformersTrainer
uses. Does it make sense to use that instead of "val"
?
@@ -427,6 +447,17 @@ def _check_checkpoint_configs( | |||
"through `LightningConfigBuilder.checkpointing()`." | |||
) | |||
|
|||
# Auto-fill the AIR CheckpointConfig if the user didn't specify it. | |||
if air_ckpt_config == CheckpointConfig(): | |||
save_top_k = ptl_ckpt_config.get("save_top_k", 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We now default AIR to save 1 checkpoint when using it with lightning trainer -- is this intentional? PTL only saves 1 checkpoint by default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, in PTL, the default value of save_top_k
is 1. Also, by default monitor
is None which saves a checkpoint only for the last epoch. We are applying default behavior of PTL to AIR now.
@@ -427,6 +447,17 @@ def _check_checkpoint_configs( | |||
"through `LightningConfigBuilder.checkpointing()`." | |||
) | |||
|
|||
# Auto-fill the AIR CheckpointConfig if the user didn't specify it. | |||
if air_ckpt_config == CheckpointConfig(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing to be aware of is that if the user does explicitly pass in a default CheckpointConfig
, then this will also get overridden. Is that expected?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it will be overridden, and this is painful because there's no way to determine whether the users explicitly passed a CheckpointConfig or not. But I think few people will provide an empty CheckpointConfig
?
But yeah, I'll think of a workaround for this corner case.
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just one nit. Want to avoid deepcopy since run config can take in a bunch of user objects like callbacks and stoppers.
Also, any thoughts on the eval dataset thing? We can leave that for a separate PR.
@justinvyu We introduced the In my opinion, we should keep both for now, and throw an api change warning if they are using |
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com> Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com> Co-authored-by: matthewdeng <matthew.j.deng@gmail.com> Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
Why are these changes needed?
Currently the users have to specify checkpoint configs twice with same parameters, one in
LightningConfigBuilder.checkpointing()
, one in AIRCheckpointConfig
.If the users only provides configures in
LightningConfigBuilder
, by default AIR will save all checkpoints, which is not desired. This PR is intended to automatically generate an AIR CheckpointConfig for the user if they forget to provide one.(Another minor fix, raised a RuntimeError when users forget to provide a
datasets_iter_config
withdatasets
.)Related issue number
Closes #35920
Closes #36509
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.