[Train] Unifying Lightning and AIR CheckpointConfig #36368

Merged

Member

@woshiyyya woshiyyya commented Jun 13, 2023

Why are these changes needed?

Currently, users have to specify the checkpoint config twice with the same parameters: once in LightningConfigBuilder.checkpointing() and once in the AIR CheckpointConfig.

If users only provide the config in LightningConfigBuilder, AIR saves all checkpoints by default, which is not desired. This PR automatically generates an AIR CheckpointConfig for users who forget to provide one.

(Another minor fix: raise a RuntimeError when users provide datasets but forget to provide a datasets_iter_config.)
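
The datasets check mentioned above might look roughly like the following. This is a minimal sketch, not the code from this PR: `validate_datasets_config` is a hypothetical helper, and the argument names simply mirror the ones used in this description.

```python
from typing import Any, Dict, Optional


def validate_datasets_config(
    datasets: Optional[Dict[str, Any]],
    datasets_iter_config: Optional[Dict[str, Any]],
) -> None:
    """Hypothetical helper: fail fast when datasets lack an iteration config."""
    # Datasets without an iteration config cannot be consumed, so raise
    # early instead of failing later inside the training loop.
    if datasets and not datasets_iter_config:
        raise RuntimeError(
            "`datasets` was provided without `datasets_iter_config`. "
            "Please specify how the datasets should be iterated."
        )
```

Failing at construction time keeps the error close to the user's mistake rather than surfacing it on a remote worker.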

Related issue number

Closes #35920
Closes #36509

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
@woshiyyya woshiyyya marked this pull request as ready for review June 23, 2023 17:36
Contributor

@justinvyu justinvyu left a comment

Left some comments.

One other thing: can we document what should be passed in through LightningTrainer(datasets)? I can't find any mentions of the special "val" key that needs to be passed in to enable evaluation.

Also, we have this EVALUATION_DATASET_KEY that TransformersTrainer uses. Does it make sense to use that instead of "val"?

python/ray/train/tests/test_lightning_trainer.py (outdated, resolved)
python/ray/train/tests/test_lightning_trainer.py (outdated, resolved)
python/ray/train/lightning/lightning_trainer.py (outdated, resolved)
@@ -427,6 +447,17 @@ def _check_checkpoint_configs(
                "through `LightningConfigBuilder.checkpointing()`."
            )

        # Auto-fill the AIR CheckpointConfig if the user didn't specify it.
        if air_ckpt_config == CheckpointConfig():
            save_top_k = ptl_ckpt_config.get("save_top_k", 1)
Contributor

We now default AIR to saving 1 checkpoint when it is used with LightningTrainer. Is this intentional? Does PTL only save 1 checkpoint by default?

Member Author

@woshiyyya woshiyyya Jun 23, 2023

Yeah, in PTL the default value of save_top_k is 1. Also, monitor is None by default, which saves a checkpoint only for the last epoch. We are now applying PTL's default behavior to AIR.
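
To make the mapping concrete, here is a rough sketch of how PTL checkpoint settings could be translated into an AIR-style config. It is not the code merged in this PR: `CheckpointConfig` below is a local stand-in dataclass whose field names and defaults are assumptions, and `autofill_air_checkpoint_config` is a hypothetical helper. The `save_top_k=0` case is deliberately left unhandled.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class CheckpointConfig:
    """Stand-in for AIR's CheckpointConfig (fields assumed for illustration)."""

    num_to_keep: Optional[int] = None  # None means "keep all checkpoints"
    checkpoint_score_attribute: Optional[str] = None
    checkpoint_score_order: str = "max"


def autofill_air_checkpoint_config(
    air_ckpt_config: CheckpointConfig, ptl_ckpt_config: Dict[str, Any]
) -> CheckpointConfig:
    """Hypothetical helper: derive the AIR config from the PTL one if unset."""
    if air_ckpt_config != CheckpointConfig():
        return air_ckpt_config  # the user customized it, so leave it alone

    save_top_k = ptl_ckpt_config.get("save_top_k", 1)  # PTL's default is 1
    # In PTL, save_top_k == -1 means "keep every checkpoint".
    num_to_keep = None if save_top_k == -1 else save_top_k
    return CheckpointConfig(
        num_to_keep=num_to_keep,
        checkpoint_score_attribute=ptl_ckpt_config.get("monitor"),
        checkpoint_score_order=ptl_ckpt_config.get("mode", "min"),
    )
```

With an empty PTL config, this sketch yields `num_to_keep=1` and no score attribute, which corresponds to the "keep only the latest checkpoint" behavior described above.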

python/ray/train/lightning/lightning_trainer.py (outdated, resolved)
python/ray/train/lightning/lightning_trainer.py (outdated, resolved)
python/ray/train/lightning/lightning_trainer.py (outdated, resolved)
@@ -427,6 +447,17 @@ def _check_checkpoint_configs(
                "through `LightningConfigBuilder.checkpointing()`."
            )

        # Auto-fill the AIR CheckpointConfig if the user didn't specify it.
        if air_ckpt_config == CheckpointConfig():
Contributor

One thing to be aware of is that if the user does explicitly pass in a default CheckpointConfig, then this will also get overridden. Is that expected?

Member Author

Yes, it will be overridden, and this is painful because there's no way to determine whether the user explicitly passed a CheckpointConfig or not. But I think few people will provide an empty CheckpointConfig?
But yeah, I'll think of a workaround for this corner case.
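
For reference, one common Python idiom for this kind of corner case is a sentinel default, which distinguishes "argument not passed" from "default-valued object passed". The sketch below only illustrates the idiom and is not what this PR does: `CheckpointConfig` is a stand-in dataclass, and `_UNSET` and `resolve_checkpoint_config` are hypothetical names. Applying it to the real API would presumably require changing how the run config stores its default, which is an assumption beyond this thread.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class CheckpointConfig:
    """Stand-in for AIR's CheckpointConfig (field assumed for illustration)."""

    num_to_keep: Optional[int] = None


class _Unset:
    """Sentinel type meaning "the caller did not pass this argument"."""


_UNSET = _Unset()


def resolve_checkpoint_config(
    checkpoint_config: Union[CheckpointConfig, _Unset] = _UNSET,
    ptl_save_top_k: int = 1,
) -> CheckpointConfig:
    """Hypothetical helper: auto-fill only when nothing was passed."""
    if isinstance(checkpoint_config, _Unset):
        # Nothing was passed, so derive the config from the PTL setting.
        return CheckpointConfig(num_to_keep=ptl_save_top_k)
    # An explicit config, even a default-valued one, is respected as-is.
    return checkpoint_config
```

The equality check `air_ckpt_config == CheckpointConfig()` cannot make this distinction, because an explicitly passed default object compares equal to an implicit one.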

woshiyyya and others added 3 commits June 23, 2023 15:17
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Contributor

@justinvyu justinvyu left a comment

Just one nit. Want to avoid deepcopy since run config can take in a bunch of user objects like callbacks and stoppers.

Also, any thoughts on the eval dataset thing? We can leave that for a separate PR.
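
To illustrate the deepcopy nit above: a shallow copy replaces only the checkpoint config while sharing user objects such as callbacks with the original. This is a minimal sketch under stated assumptions: `RunConfig` and `CheckpointConfig` are stand-in dataclasses with assumed fields, and `with_checkpoint_config` is a hypothetical helper, not a Ray API.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class CheckpointConfig:
    """Stand-in for AIR's CheckpointConfig (field assumed for illustration)."""

    num_to_keep: Optional[int] = None


@dataclass
class RunConfig:
    """Stand-in for AIR's RunConfig (fields assumed for illustration)."""

    callbacks: List[Any] = field(default_factory=list)
    checkpoint_config: CheckpointConfig = field(default_factory=CheckpointConfig)


def with_checkpoint_config(
    run_config: RunConfig, new_ckpt_config: CheckpointConfig
) -> RunConfig:
    """Hypothetical helper: swap the checkpoint config without deep-copying."""
    # copy.copy is shallow: callbacks and other user objects are shared,
    # not duplicated, so objects that cannot be deep-copied stay untouched.
    new_run_config = copy.copy(run_config)
    new_run_config.checkpoint_config = new_ckpt_config
    return new_run_config
```

The caller's original run config keeps its own checkpoint config, so the override stays local to the trainer.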

python/ray/train/lightning/lightning_trainer.py (outdated, resolved)
@woshiyyya
Member Author

woshiyyya commented Jun 27, 2023

@justinvyu We introduced the val key in LightningTrainer's docstring. I agree with changing the eval dataset key from val to evaluation, but we need to decide whether to keep both as valid keys or raise an error for val.

In my opinion, we should keep both for now and emit an API change warning if users pass val. I created an issue for this, #36873; let's do it in a separate PR.
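
The "keep both keys, warn on val" proposal could be sketched as follows. This is only an illustration of the idea, not code from this PR or from #36873: `get_eval_dataset` is a hypothetical helper, and the constant's value is assumed from the discussion above.

```python
import warnings
from typing import Any, Dict, Optional

EVALUATION_DATASET_KEY = "evaluation"  # value assumed from the discussion
LEGACY_VAL_KEY = "val"


def get_eval_dataset(datasets: Dict[str, Any]) -> Optional[Any]:
    """Hypothetical helper: prefer the new key, accept the old one with a warning."""
    if EVALUATION_DATASET_KEY in datasets:
        return datasets[EVALUATION_DATASET_KEY]
    if LEGACY_VAL_KEY in datasets:
        warnings.warn(
            f"The '{LEGACY_VAL_KEY}' dataset key may be replaced by "
            f"'{EVALUATION_DATASET_KEY}' in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
        return datasets[LEGACY_VAL_KEY]
    return None  # no evaluation dataset was provided
```

Preferring the new key when both are present avoids a silent behavior change once the old key is eventually removed.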

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Contributor

@justinvyu justinvyu left a comment

Thank you!

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
arvind-chandra pushed a commit to lmco/ray that referenced this pull request Aug 31, 2023
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
3 participants