[AIR] <Part 2> Support metric logging and checkpointing for LightningTrainer #33183
Conversation
Nice! I have a few questions:
Also, do we need to implement a custom restore method, or is using DataParallelTrainer's default one OK? The only PTL-specific thing you may need to re-specify is lightning_config, i.e. re-specifying the lightning module class. We can talk offline about this.
Pretty nice. A few questions.
        super().setup(*args, **kwargs)
        self.is_checkpoint_step = False

    def _session_report(self, trainer: "pl.Trainer", stage: str):
Why is metrics reporting part of the checkpoint class? What if I want to report data every iteration, but don't want to create checkpoints?
The context here is that checkpointing and logging are separate mechanisms in Lightning. The checkpoint class can access both metrics and the checkpoint, but a Logger can only access metrics. To report the checkpoint and its metrics together, we implement reporting in the checkpoint class.
For logging, we recommend that users keep using Lightning's native loggers (e.g. the W&B, MLflow, and TensorBoard loggers). They can control the logging frequency themselves and retrieve logs as usual, which is less intrusive and aligns better with existing user habits.
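A minimal sketch of this design, assuming Ray AIR's session API; the helper name and the way metrics are pulled from the trainer are illustrative, not the actual implementation:

from ray.air import session
from ray.air.checkpoint import Checkpoint

def report_metrics_and_checkpoint(trainer, stage: str, local_ckpt_dir: str):
    # Gather the latest metrics Lightning has logged on this trainer.
    metrics = {"__report_on": stage}
    metrics.update({k: v.item() for k, v in trainer.callback_metrics.items()})
    # Reporting both in one call keeps the metrics associated with the checkpoint.
    session.report(metrics, checkpoint=Checkpoint.from_directory(local_ckpt_dir))

Because Lightning's native loggers still observe every log() call, logging frequency stays fully in the user's hands regardless of how often checkpoints are reported.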
        # Report latest logged metrics
        metrics = {"report_on": stage}
        for k, v in self._monitor_candidates(trainer).items():
            if k == "report_on":
I don't feel particularly safe about this. Why do we have this keyword, and why isn't it even __-prefixed? Also, we may be logging this warning msg every time this is called. Is there a list of such keywords defined somewhere?
I agree. I tried, but didn't find relevant keys in ray.air.constants. I removed this warning msg and changed the key name to __report_on now.
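For illustration, the filtering with the renamed key might look like this sketch (the constant and helper function here are hypothetical, not from the PR):

REPORT_ON_KEY = "__report_on"  # dunder prefix to avoid colliding with user-logged metrics

def build_report_metrics(candidates: dict, stage: str) -> dict:
    # Start from the reserved stage marker, then copy over user metrics,
    # skipping the reserved key if a user happens to log it themselves.
    metrics = {REPORT_ON_KEY: stage}
    for k, v in candidates.items():
        if k != REPORT_ON_KEY:
            metrics[k] = v
    return metrics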
@@ -308,6 +309,17 @@ def __init__(
        preprocessor: Optional[Preprocessor] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
    ):
        run_config = run_config or RunConfig()
        lightning_config = lightning_config or LightningConfigBuilder().build()
Wait, do we want people to pass in a dict or the raw Builder()? The type annotation says it's a raw dict already.
We expect them to pass a dict, but they can also choose not to and instead specify it later in the Tuner's param_space. We check whether lightning_config is provided in the training_loop_per_worker.
For example:
lightning_config = (
    LightningConfigBuilder()
    .module(MNISTClassifier, feature_dim=128, lr=tune.grid_search([0.1, 0.01, 0.001]))
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(datamodule=datamodule)
    .build()
)

trainer = LightningTrainer(scaling_config=scaling_config)
tuner = tune.Tuner(
    trainer,
    param_space={"lightning_config": lightning_config},
    ...
)
)

# Disable strict checking to allow metric reporting at different frequencies
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"
😓 Any chance we can disable the CSV writer so we don't have to do this hack here?
Oh, actually this metrics check is not in the CSV writer but in the trial_runner, and it's triggered every time we call session.report. I agree this is hacky, but it seems to be the only way we can skip this check for now. 😅
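To make the motivation concrete: training and validation hooks report different metric sets, so a result reported at train-epoch end may be missing a metric that Tune was configured to track. A rough sketch, with hypothetical metric names:

import os

# Must be set before Tune starts. Without it, the trial runner errors when a
# reported result lacks a metric Tune was told to track (e.g. "val_loss" is
# absent from reports made at the end of a training epoch).
os.environ["TUNE_DISABLE_STRICT_METRIC_CHECKING"] = "1"

from ray.air import session

def train_loop_per_worker(config):
    # Reported at the end of a training epoch -- no validation metrics yet:
    session.report({"__report_on": "train_epoch_end", "train_loss": 0.41})
    # Reported at the end of validation -- a different key set:
    session.report({"__report_on": "validation_end", "val_loss": 0.38})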
stamp for docs
This code is causing documentation to fail.
Why are these changes needed?
There will be a series of PRs for the PyTorch Lightning integration. This is the second one.
Content for this PR:
Related issue number
Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.