[AIR] <Part 2> Support metric logging and checkpointing for LightningTrainer #33183

Merged

Conversation

woshiyyya
Member

Why are these changes needed?

There will be a list of PRs for PyTorch Lightning Integration. This is the second one.

  • LightningTrainer + Test
  • Logging and Checkpointing + Test
  • Predictor + Batch Prediction + Test
  • Integrate with Tune + Test

Content for this PR:

  • Implemented RayModelCheckpoint callback to report latest metrics and checkpoint to trainer session
  • Implemented LightningCheckpoint, a subclass of TorchCheckpoint.
  • CI Tests

Related issue number

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
@woshiyyya woshiyyya marked this pull request as ready for review March 9, 2023 22:54
@woshiyyya woshiyyya requested review from amogkam, justinvyu and Yard1 March 9, 2023 22:54
@woshiyyya woshiyyya added train Ray Train Related Issue air labels Mar 9, 2023
woshiyyya and others added 14 commits March 9, 2023 17:17
Co-authored-by: Amog Kamsetty <amogkam@users.noreply.github.com>
Co-authored-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Contributor

@justinvyu justinvyu left a comment

Nice! I have a few questions:

Also, do we need to implement a custom restore method, or is DataParallelTrainer's default one OK? The only PTL-specific argument you may need to re-specify is lightning_config, i.e., re-specifying the Lightning module class. We can talk offline about this.

Member

@gjoliver gjoliver left a comment

pretty nice. a few questions.

        super().setup(*args, **kwargs)
        self.is_checkpoint_step = False

    def _session_report(self, trainer: "pl.Trainer", stage: str):
Member

why is metrics reporting part of the checkpoint class?
what if I want to report data / iteration, but don't want to create checkpoints?

Member Author

For context: checkpointing and logging are separate code paths in Lightning. A checkpoint callback can access both the metrics and the checkpoint, whereas a Logger can only access metrics. To report the checkpoint and metrics together, we implement reporting in the checkpoint class.

For logging, we recommend that users keep using Lightning's native loggers (e.g., the wandb, mlflow, and tensorboard loggers). They can control the logging frequency themselves and retrieve logs as usual, which is less intrusive and better matches existing user habits.
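
A minimal sketch of the design described above, with no Lightning or Ray dependency. `FakeTrainer`, `MetricsOnlyLogger`, `ReportingCheckpointCallback`, and the `report` callable are hypothetical stand-ins, not the classes added in this PR; they only illustrate why a checkpoint hook is the one place that sees both metrics and a checkpoint.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class FakeTrainer:
    """Hypothetical stand-in for a trainer that tracks logged metrics."""

    def __init__(self) -> None:
        self.callback_metrics: Dict[str, float] = {}
        self.global_step = 0

    def save_checkpoint(self) -> Dict[str, Any]:
        # A real trainer would serialize model and optimizer state here.
        return {"global_step": self.global_step}


class MetricsOnlyLogger:
    """A logger hook receives metrics but has no handle on checkpoints."""

    def __init__(self) -> None:
        self.history: List[Tuple[int, Dict[str, float]]] = []

    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        self.history.append((step, dict(metrics)))


class ReportingCheckpointCallback:
    """A checkpoint hook receives the trainer, so it can pair metrics with a checkpoint."""

    def __init__(
        self, report: Callable[[Dict[str, Any], Optional[Dict[str, Any]]], None]
    ) -> None:
        self._report = report

    def on_train_epoch_end(self, trainer: FakeTrainer) -> None:
        metrics = dict(trainer.callback_metrics)
        checkpoint = trainer.save_checkpoint()
        # Metrics and checkpoint go out in a single report call.
        self._report(metrics, checkpoint)


# Usage: collect reports in a list instead of sending them to a session.
reports: List[Tuple[Dict[str, Any], Optional[Dict[str, Any]]]] = []
trainer = FakeTrainer()
trainer.callback_metrics["loss"] = 0.5
trainer.global_step = 10
callback = ReportingCheckpointCallback(lambda m, c: reports.append((m, c)))
callback.on_train_epoch_end(trainer)
```

In this sketch the logger could still run alongside the callback at its own frequency, which matches the recommendation to keep native loggers for logging.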

        # Report latest logged metrics
        metrics = {"report_on": stage}
        for k, v in self._monitor_candidates(trainer).items():
            if k == "report_on":
Member

don't feel particularly safe about this. why do we have this keyword, and it's not even __ prefixed ...
also we may be logging this warning msg every time this is called.
is there a list of such keywords defined somewhere?

Member Author

I agree. I looked but didn't find any relevant keys in ray.air.constants. I have removed this warning message and changed the key name to __report_on.
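
A minimal sketch of how a reserved key such as `__report_on` could be kept safe when merging user-logged metrics. The helper `build_report` and the constant `REPORT_ON_KEY` are hypothetical names. This thread does not show how the merged code handles a colliding user metric, so silently skipping it is an assumption.

```python
from typing import Any, Dict

# Reserved key set by the callback itself; the name follows the comment above.
REPORT_ON_KEY = "__report_on"


def build_report(stage: str, logged_metrics: Dict[str, Any]) -> Dict[str, Any]:
    """Merge user-logged metrics into a report without overwriting the reserved key."""
    report: Dict[str, Any] = {REPORT_ON_KEY: stage}
    for key, value in logged_metrics.items():
        if key == REPORT_ON_KEY:
            # Assumption: skip silently, so no warning is emitted on every call.
            continue
        report[key] = value
    return report


result = build_report("train_epoch_end", {"loss": 0.25, "__report_on": "oops"})
```

The double-underscore prefix makes an accidental collision with a user metric name less likely, which is the concern raised in the review.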

@@ -308,6 +309,17 @@ def __init__(
        preprocessor: Optional[Preprocessor] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        run_config = run_config or RunConfig()
        lightning_config = lightning_config or LightningConfigBuilder().build()
Member

wait, do we want people to pass in a dict or the raw Builder()?
the type annotation says it's a raw dict already.

Member Author

We expect users to pass a dict, but they can also omit it and specify it later in the Tuner's param_space. We check whether lightning_config is provided in training_loop_per_worker.

For example:

lightning_config = (
    LightningConfigBuilder()
    .module(MNISTClassifier, feature_dim=128, lr=tune.grid_search([0.1, 0.01, 0.001]))
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(datamodule=datamodule)
    .build()
)

trainer = LightningTrainer(scaling_config=scaling_config)

tuner = tune.Tuner(
    trainer,
    param_space={"lightning_config": lightning_config},
    # ...
)
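
A minimal sketch of the late validation described above, written without Ray or Lightning. The function `resolve_lightning_config`, the key `module_class`, and the rule that a `param_space` value takes precedence over the constructor value are all assumptions made for illustration; only the idea of checking inside `training_loop_per_worker` comes from the comment.

```python
from typing import Any, Dict, Optional


def resolve_lightning_config(
    constructor_config: Optional[Dict[str, Any]],
    param_space: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Pick the config a worker would see (assumed: param_space wins if present)."""
    config = dict(constructor_config or {})
    if param_space and "lightning_config" in param_space:
        config = dict(param_space["lightning_config"])
    return config


def training_loop_per_worker(config: Dict[str, Any]) -> str:
    """Validate late, inside the worker, so the constructor can accept an empty config."""
    if not config.get("module_class"):
        raise ValueError(
            "lightning_config must specify a module class before training starts"
        )
    return config["module_class"]


# Config omitted in the constructor and supplied later through param_space.
late_config = resolve_lightning_config(
    None, {"lightning_config": {"module_class": "MNISTClassifier"}}
)
module_name = training_loop_per_worker(late_config)
```

Deferring the check to the worker is what lets the trainer be constructed with only a `scaling_config`, as in the example above.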

)

# Disable strict checking to allow metric reporting at different frequencies
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
Member

😓
any chance we can disable the csv writer so we don't have to do this hack here?

Member Author

Oh, actually this metrics check is not in the CSV writer but in trial_runner, and it is triggered every time we call session.report. I agree this is hacky, but it seems to be the only way to skip this check for now. 😅
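
A minimal sketch of what a strict metric check gated by this environment variable might look like. The function `validate_result_metric` is a hypothetical stand-in, not the code in trial_runner; only the variable name `TUNE_DISABLE_STRICT_METRIC_CHECKING` and the value `"1"` come from the diff above.

```python
import os
from typing import Any, Dict


def validate_result_metric(result: Dict[str, Any], metric: str) -> None:
    """Hypothetical strict check: the tracked metric must appear in every report."""
    strict = os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", "0") != "1"
    if strict and metric not in result:
        raise ValueError(
            f"Metric {metric!r} missing from reported result: {sorted(result)}"
        )


# Validation metrics may be reported less often than training metrics,
# so some reports legitimately lack "val_loss".
train_only_report = {"train_loss": 0.4}

os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
validate_result_metric(train_only_report, "val_loss")  # no error: check disabled
```

With the variable unset or set to anything other than `"1"`, the same call would raise in this sketch, which is the situation the workaround avoids when metrics are reported at different frequencies.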

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
@woshiyyya woshiyyya force-pushed the air/lightning_log_and_checkpoint branch from fafc218 to 55918b9 Compare March 17, 2023 22:05
Contributor

@richardliaw richardliaw left a comment

stamp for docs

@gjoliver gjoliver merged commit aa02993 into ray-project:master Mar 21, 2023
ollie-iterators commented Mar 21, 2023

This code is causing documentation to fail

edoakes pushed a commit to edoakes/ray that referenced this pull request Mar 22, 2023
…Trainer (ray-project#33183)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
clarng pushed a commit to clarng/ray that referenced this pull request Mar 23, 2023
…Trainer (ray-project#33183)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
elliottower pushed a commit to elliottower/ray that referenced this pull request Apr 22, 2023
…Trainer (ray-project#33183)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: elliottower <elliot@elliottower.com>
ProjectsByJackHe pushed a commit to ProjectsByJackHe/ray that referenced this pull request May 4, 2023
…Trainer (ray-project#33183)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Jack He <jackhe2345@gmail.com>
Labels
train Ray Train Related Issue
7 participants