[AIR] <Part 2> Support metric logging and checkpointing for LightningTrainer #33183

Merged

Commits (42)
4c8edc1  init Lightning Trainer with ci tests (woshiyyya, Mar 9, 2023)
e0fd5d9  add test for trainer with categorical ray dataset (woshiyyya, Mar 9, 2023)
ebea347  fix lightiningEnvironment import error (woshiyyya, Mar 9, 2023)
01f9461  fix test_lightning_trainer not found error (woshiyyya, Mar 9, 2023)
22725b6  format code (woshiyyya, Mar 9, 2023)
b3ab873  check linter (woshiyyya, Mar 9, 2023)
6ba0577  init LightningCheckpoint (woshiyyya, Mar 9, 2023)
a9d8ab6  fix linting issues (woshiyyya, Mar 10, 2023)
dde861d  fix linting issues (woshiyyya, Mar 10, 2023)
55b7916  fix lint again (woshiyyya, Mar 10, 2023)
a75fec6  Change lightning_config to dict type to aligh with Ray Tune (woshiyyya, Mar 10, 2023)
e1b9d81  Change lightning_config to dict type to aligh with Ray Tune (woshiyyya, Mar 10, 2023)
2e6500f  fix non_monitored_checkpoint saving error (woshiyyya, Mar 10, 2023)
895ab43  Apply suggestions from code review (woshiyyya, Mar 13, 2023)
f76f48d  replace LightningConfigBuilder setter with semantic meanings (woshiyyya, Mar 13, 2023)
c1448db  add fixture and builder_tests (woshiyyya, Mar 13, 2023)
66ab606  add example code for config builder, add api index (woshiyyya, Mar 13, 2023)
2842b14  fix lint, change test tag to large (woshiyyya, Mar 13, 2023)
38cd101  remove resume from ckpt for next PR (woshiyyya, Mar 13, 2023)
5210cb7  rename trainer.rst to avoid cross reference (woshiyyya, Mar 13, 2023)
300ba71  Merge remote-tracking branch 'upstream/master' into air/lightning_bas… (woshiyyya, Mar 13, 2023)
1935473  address comments (woshiyyya, Mar 14, 2023)
55a7b32  fix document issue (woshiyyya, Mar 14, 2023)
6598894  Merge remote-tracking branch 'upstream/master' into air/lightning_bas… (woshiyyya, Mar 14, 2023)
171a5e8  Merge remote-tracking branch 'origin/air/lightning_base_trainer' into… (woshiyyya, Mar 14, 2023)
247ba48  align report and checkpointing frequency (woshiyyya, Mar 15, 2023)
71b24fa  Merge remote-tracking branch 'upstream/master' into air/lightning_log… (woshiyyya, Mar 15, 2023)
6eea366  convert air ckpt config to lightning config (woshiyyya, Mar 15, 2023)
8fe5b05  add checks for checkpoint config (woshiyyya, Mar 15, 2023)
cc8d3e8  check tuning target metric (woshiyyya, Mar 15, 2023)
e606682  fix tuning metric logic (woshiyyya, Mar 15, 2023)
5cc5a42  Merge remote-tracking branch 'upstream/master' into air/lightning_log… (woshiyyya, Mar 15, 2023)
7bac4e7  finalize logging logic (woshiyyya, Mar 16, 2023)
b009e0f  remove tuning metric setup (woshiyyya, Mar 16, 2023)
ff89639  disable metric checking by default (woshiyyya, Mar 16, 2023)
40cc5dd  Merge remote-tracking branch 'upstream/master' into air/lightning_log… (woshiyyya, Mar 16, 2023)
9403bf5  address the review comments (woshiyyya, Mar 17, 2023)
a88e57a  Merge remote-tracking branch 'upstream/master' into air/lightning_log… (woshiyyya, Mar 17, 2023)
c57c310  add doc API reference (woshiyyya, Mar 17, 2023)
55918b9  fixing CI (woshiyyya, Mar 17, 2023)
d4be393  add comments on strict check env var (woshiyyya, Mar 20, 2023)
9bb4136  Merge remote-tracking branch 'upstream/master' into air/lightning_log… (woshiyyya, Mar 20, 2023)
2 changes: 1 addition & 1 deletion doc/source/train/api/api.rst
@@ -93,7 +93,7 @@ PyTorch Lightning

~train.lightning.LightningTrainer
~train.lightning.LightningConfigBuilder

~train.lightning.LightningCheckpoint

Tensorflow/Keras
~~~~~~~~~~~~~~~~
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -400,6 +400,14 @@ py_test(
deps = [":train_lib"]
)

py_test(
name = "test_lightning_checkpoint",
size = "medium",
srcs = ["tests/test_lightning_checkpoint.py"],
tags = ["team:ml", "exclusive", "ray_air", "gpu"],
deps = [":train_lib"]
)

py_test(
name = "test_lightning_trainer",
size = "large",
4 changes: 2 additions & 2 deletions python/ray/train/lightning/__init__.py
@@ -8,10 +8,10 @@
)
# isort: on

from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.lightning.lightning_trainer import (
LightningTrainer,
LightningConfigBuilder,
)


__all__ = ["LightningTrainer", "LightningConfigBuilder"]
__all__ = ["LightningTrainer", "LightningConfigBuilder", "LightningCheckpoint"]
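For reference, a downstream import then looks like this (a minimal sketch using only the names exported above):

from ray.train.lightning import (
    LightningTrainer,
    LightningConfigBuilder,
    LightningCheckpoint,
)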
81 changes: 76 additions & 5 deletions python/ray/train/lightning/_lightning_utils.py
@@ -1,14 +1,18 @@
import logging
import shutil
import torch
from typing import Any, Dict, Optional

import tempfile
import pytorch_lightning as pl

from typing import Any, Dict, Optional
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.plugins.environments import LightningEnvironment

import ray
from ray.air import session

from ray.air.constants import MODEL_KEY
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from torch.utils.data import IterableDataset, DataLoader
from ray.data.dataset import DatasetIterator

@@ -46,11 +50,9 @@ def node_rank(self) -> int:
return session.get_node_rank()

def set_world_size(self, size: int) -> None:
logger.warning("world_size setter is disabled in AIR LightningTrainer.")
pass

def set_global_rank(self, rank: int) -> None:
logger.warning("global_rank setter is disabled in AIR LightningTrainer.")
pass

def teardown(self):
@@ -94,3 +96,72 @@ def _val_dataloader() -> DataLoader:
# setting, we only override this method when `val_dataset` is not `None`.
if val_dataset:
self.val_dataloader = _val_dataloader


class RayModelCheckpoint(ModelCheckpoint):
"""
AIR customized ModelCheckpoint callback.

A subclass of ``pytorch_lightning.callbacks.ModelCheckpoint``.
This callback function reports the latest metrics to the AIR session and
creates an AIR checkpoint whenever a lightning checkpoint is saved.
"""

def setup(self, *args, **kwargs) -> None:
super().setup(*args, **kwargs)
self.is_checkpoint_step = False

def _session_report(self, trainer: "pl.Trainer", stage: str):
Member:
why is metrics reporting part of the checkpoint class?
what if I want to report data / iteration, but don't want to create checkpoints?

Member Author:
The context here is that checkpointing and logging are separate concerns in Lightning. A checkpoint callback can access both the metrics and the checkpoint, but a Logger can only access the metrics. To report the checkpoint and metrics together, we implement reporting in the checkpoint class.

For logging, we recommend that users keep using Lightning's native loggers (e.g. the wandb, mlflow, and tensorboard loggers). They can control the logging frequency themselves and retrieve logs as usual, which is less intrusive and aligns better with user habits.
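As a rough sketch of that recommendation (not part of this PR's diff; TensorBoardLogger is standard Lightning API, and the rest is a hypothetical user setup), logging can stay entirely on the Lightning side:

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

# Hypothetical user-side configuration: metrics go through Lightning's native
# TensorBoardLogger at a user-chosen frequency, while the AIR checkpoint
# callback only reports metrics and checkpoints back to the AIR session.
tb_logger = TensorBoardLogger(save_dir="logs/", name="my_run")
trainer = pl.Trainer(max_epochs=2, logger=tb_logger, log_every_n_steps=10)
# trainer.fit(model, train_dataloaders=train_loader)

With LightningTrainer, the same logger object would be supplied through the Lightning trainer configuration rather than by constructing pl.Trainer directly.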

"""Report latest metrics dict and checkpoint to AIR training session.

This method is called whenever a new checkpoint is created. It creates
a `LightningCheckpoint` and reports it to the AIR session along with
the latest metrics.
"""

# Align the frequency of checkpointing and logging
if not self.is_checkpoint_step:
return

# Report latest logged metrics
metrics = {"report_on": stage}
for k, v in self._monitor_candidates(trainer).items():
if k == "report_on":
Member:
don't feel particularly safe about this. why do we have this keyword, and it's not even __ prefixed ...
also we may be logging this warning msg every time this is called.
is there a list of such keywords defined somewhere?

Member Author:
I agree. I looked but didn't find relevant keys in ray.air.constants. I removed the warning message and changed the key name to __report_on now.

logger.warning(
"'report_on' is a reserved key in AIR report metrics. "
"Original values are overwritten!"
)
continue
if isinstance(v, torch.Tensor):
metrics[k] = v.item()

# Report latest saved checkpoint
# Note that AIR only takes the checkpoint of rank 0.
# Save a dummy checkpoint on the other workers to avoid blocking.
with tempfile.TemporaryDirectory() as tmpdir:
if trainer.global_rank == 0:
shutil.copy(self.last_model_path, f"{tmpdir}/{MODEL_KEY}")
checkpoint = LightningCheckpoint.from_directory(path=tmpdir)
else:
checkpoint = LightningCheckpoint.from_dict(
{"rank": session.get_world_rank()}
)
session.report(metrics=metrics, checkpoint=checkpoint)

self.is_checkpoint_step = False

def _save_last_checkpoint(self, *args, **kwargs) -> None:
super()._save_last_checkpoint(*args, **kwargs)
self.is_checkpoint_step = True
Contributor:
I don't think we actually need this. Let's just return early in the cases where no checkpoint is found. Better not to mess with this private method.

Member Author:
I agree that overriding this method is not elegant, but it seems to be the only way. The name of the last checkpoint is always "last.ckpt", so we can't tell whether a new checkpoint was written based on the file name alone.

To give you more context, the logic of ModelCheckpoint.on_train_batch_end() is like:

def on_train_batch_end():
    #####################################################################
    # Code block that determines whether this is a checkpoint step.
    # return if it's not.
    #####################################################################
    self._save_top_k_checkpoint()
    self._save_last_checkpoint()

Since Lightning does not factor that code block out into a self.should_checkpoint() method, the only way to detect a checkpoint step is to track whether _save_top_k_checkpoint() and _save_last_checkpoint() are actually called.


def on_train_batch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
super().on_train_batch_end(trainer, *args, **kwargs)
self._session_report(trainer=trainer, stage="train_batch_end")

def on_train_epoch_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
super().on_train_epoch_end(trainer, *args, **kwargs)
self._session_report(trainer=trainer, stage="train_epoch_end")

def on_validation_end(self, trainer: "pl.Trainer", *args, **kwargs) -> None:
super().on_validation_end(trainer, *args, **kwargs)
self._session_report(trainer=trainer, stage="validation_end")
100 changes: 100 additions & 0 deletions python/ray/train/lightning/lightning_checkpoint.py
@@ -0,0 +1,100 @@
import os
import logging
import pytorch_lightning as pl
import tempfile
import shutil

from inspect import isclass
from typing import Optional, Type

from ray.air.constants import MODEL_KEY
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.data import Preprocessor
from ray.train.torch import TorchCheckpoint
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class LightningCheckpoint(TorchCheckpoint):
"""A :class:`~ray.air.checkpoint.Checkpoint` with Lightning-specific functionality.

LightningCheckpoint only supports file-based checkpoint loading.
Create one by calling ``LightningCheckpoint.from_directory(ckpt_dir)``,
``LightningCheckpoint.from_uri(uri)``, or ``LightningCheckpoint.from_path(path)``.

LightningCheckpoint loads the file named ``model`` under the specified directory.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._cache_dir = None

@classmethod
def from_path(
cls,
path: str,
*,
preprocessor: Optional["Preprocessor"] = None,
) -> "LightningCheckpoint":
"""Create a ``ray.air.lightning.LightningCheckpoint`` from a checkpoint path.

Args:
path: The file path to the PyTorch Lightning checkpoint.
preprocessor: A fitted preprocessor to be applied before inference.

Returns:
A :py:class:`LightningCheckpoint` containing the model.

Examples:
>>> from ray.train.lightning import LightningCheckpoint
>>>
>>> checkpoint = LightningCheckpoint.from_path("/path/to/checkpoint.ckpt")
"""

assert os.path.exists(path), f"Lightning checkpoint {path} doesn't exist!"

cache_dir = tempfile.mkdtemp()
new_checkpoint_path = os.path.join(cache_dir, MODEL_KEY)
shutil.copy(path, new_checkpoint_path)
if preprocessor:
save_preprocessor_to_dir(preprocessor, cache_dir)
checkpoint = cls.from_directory(cache_dir)
checkpoint._cache_dir = cache_dir
return checkpoint

def get_model(
self, model_class: Type[pl.LightningModule], **load_from_checkpoint_kwargs
) -> pl.LightningModule:
"""Retrieve the model stored in this checkpoint.

Args:
model_class: A subclass of ``pytorch_lightning.LightningModule`` that
defines your model and training logic.
load_from_checkpoint_kwargs: Arguments to pass into
``pl.LightningModule.load_from_checkpoint``

Returns:
pl.LightningModule: An instance of the loaded model.
"""
if not isclass(model_class):
raise ValueError(
"'model_class' must be a class, not an instantiated Lightning trainer."
)

with self.as_directory() as checkpoint_dir:
ckpt_path = os.path.join(checkpoint_dir, MODEL_KEY)
if not os.path.exists(ckpt_path):
raise RuntimeError(
f"File {ckpt_path} not found under the checkpoint directory."
)

model = model_class.load_from_checkpoint(
ckpt_path, **load_from_checkpoint_kwargs
)
return model

def __del__(self):
if self._cache_dir and os.path.exists(self._cache_dir):
shutil.rmtree(self._cache_dir)
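To round this out, a rough end-to-end usage sketch (MyLightningModule and the checkpoint path are hypothetical placeholders; from_path and get_model are the methods defined above):

import pytorch_lightning as pl
from ray.train.lightning import LightningCheckpoint

# Hypothetical user-defined model; any pl.LightningModule subclass works here.
class MyLightningModule(pl.LightningModule):
    ...

# Wrap an existing Lightning .ckpt file into an AIR checkpoint, then rebuild
# the model from it. The checkpoint path below is a placeholder.
checkpoint = LightningCheckpoint.from_path("/path/to/checkpoint.ckpt")
model = checkpoint.get_model(model_class=MyLightningModule)
model.eval()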