[AIR] <Part 2> Support metric logging and checkpointing for LightningTrainer #33183

Merged

Commits (42)
4c8edc1
init Lightning Trainer with ci tests
woshiyyya Mar 9, 2023
e0fd5d9
add test for trainer with categorical ray dataset
woshiyyya Mar 9, 2023
ebea347
fix LightningEnvironment import error
woshiyyya Mar 9, 2023
01f9461
fix test_lightning_trainer not found error
woshiyyya Mar 9, 2023
22725b6
format code
woshiyyya Mar 9, 2023
b3ab873
check linter
woshiyyya Mar 9, 2023
6ba0577
init LightningCheckpoint
woshiyyya Mar 9, 2023
a9d8ab6
fix linting issues
woshiyyya Mar 10, 2023
dde861d
fix linting issues
woshiyyya Mar 10, 2023
55b7916
fix lint again
woshiyyya Mar 10, 2023
a75fec6
Change lightning_config to dict type to align with Ray Tune
woshiyyya Mar 10, 2023
e1b9d81
Change lightning_config to dict type to align with Ray Tune
woshiyyya Mar 10, 2023
2e6500f
fix non_monitored_checkpoint saving error
woshiyyya Mar 10, 2023
895ab43
Apply suggestions from code review
woshiyyya Mar 13, 2023
f76f48d
replace LightningConfigBuilder setter with semantic meanings
woshiyyya Mar 13, 2023
c1448db
add fixture and builder_tests
woshiyyya Mar 13, 2023
66ab606
add example code for config builder, add api index
woshiyyya Mar 13, 2023
2842b14
fix lint, change test tag to large
woshiyyya Mar 13, 2023
38cd101
remove resume from ckpt for next PR
woshiyyya Mar 13, 2023
5210cb7
rename trainer.rst to avoid cross reference
woshiyyya Mar 13, 2023
300ba71
Merge remote-tracking branch 'upstream/master' into air/lightning_bas…
woshiyyya Mar 13, 2023
1935473
address comments
woshiyyya Mar 14, 2023
55a7b32
fix document issue
woshiyyya Mar 14, 2023
6598894
Merge remote-tracking branch 'upstream/master' into air/lightning_bas…
woshiyyya Mar 14, 2023
171a5e8
Merge remote-tracking branch 'origin/air/lightning_base_trainer' into…
woshiyyya Mar 14, 2023
247ba48
align report and checkpointing frequency
woshiyyya Mar 15, 2023
71b24fa
Merge remote-tracking branch 'upstream/master' into air/lightning_log…
woshiyyya Mar 15, 2023
6eea366
convert air ckpt config to lightning config
woshiyyya Mar 15, 2023
8fe5b05
add checks for checkpoint config
woshiyyya Mar 15, 2023
cc8d3e8
check tuning target metric
woshiyyya Mar 15, 2023
e606682
fix tuning metric logic
woshiyyya Mar 15, 2023
5cc5a42
Merge remote-tracking branch 'upstream/master' into air/lightning_log…
woshiyyya Mar 15, 2023
7bac4e7
finalize logging logic
woshiyyya Mar 16, 2023
b009e0f
remove tuning metric setup
woshiyyya Mar 16, 2023
ff89639
disable metric checking by default
woshiyyya Mar 16, 2023
40cc5dd
Merge remote-tracking branch 'upstream/master' into air/lightning_log…
woshiyyya Mar 16, 2023
9403bf5
address the review comments
woshiyyya Mar 17, 2023
a88e57a
Merge remote-tracking branch 'upstream/master' into air/lightning_log…
woshiyyya Mar 17, 2023
c57c310
add doc API reference
woshiyyya Mar 17, 2023
55918b9
fixing CI
woshiyyya Mar 17, 2023
d4be393
add comments on strict check env var
woshiyyya Mar 20, 2023
9bb4136
Merge remote-tracking branch 'upstream/master' into air/lightning_log…
woshiyyya Mar 20, 2023
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -400,6 +400,14 @@ py_test(
deps = [":train_lib"]
)

py_test(
name = "test_lightning_checkpoint",
size = "medium",
srcs = ["tests/test_lightning_checkpoint.py"],
tags = ["team:ml", "exclusive", "ray_air", "gpu"],
deps = [":train_lib"]
)

py_test(
name = "test_lightning_trainer",
size = "large",
4 changes: 2 additions & 2 deletions python/ray/train/lightning/__init__.py
@@ -8,10 +8,10 @@
)
# isort: on

from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.lightning.lightning_trainer import (
LightningTrainer,
LightningConfigBuilder,
)


__all__ = ["LightningTrainer", "LightningConfigBuilder"]
__all__ = ["LightningTrainer", "LightningConfigBuilder", "LightningCheckpoint"]
130 changes: 125 additions & 5 deletions python/ray/train/lightning/_lightning_utils.py
@@ -1,14 +1,20 @@
import os
import logging
import torch
from typing import Any, Dict, Optional

import pytorch_lightning as pl

from torch import Tensor
from copy import deepcopy
from typing import Any, Dict, Optional
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.plugins.environments import LightningEnvironment

import ray
from ray.air import session

from ray.air.constants import MODEL_KEY
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from torch.utils.data import IterableDataset, DataLoader
from ray.data.dataset import DatasetIterator

@@ -46,11 +52,13 @@ def node_rank(self) -> int:
return session.get_node_rank()

def set_world_size(self, size: int) -> None:
logger.warning("world_size setter is disabled in AIR LightningTrainer.")
if self.global_rank() == 0:
logger.warning("world_size setter is disabled in AIR LightningTrainer.")
pass

def set_global_rank(self, rank: int) -> None:
logger.warning("global_rank setter is disabled in AIR LightningTrainer.")
if self.global_rank() == 0:
logger.warning("global_rank setter is disabled in AIR LightningTrainer.")
pass

def teardown(self):
@@ -94,3 +102,115 @@ def _val_dataloader() -> DataLoader:
# setting, we only override this method when `val_dataset` is not `None`.
if val_dataset:
self.val_dataloader = _val_dataloader


class RayModelCheckpoint(ModelCheckpoint):
"""
AIR customized ModelCheckpoint callback.

A subclass of ``pytorch_lightning.callbacks.ModelCheckpoint``.
This callback reports the latest metrics to the AIR session and
creates an AIR checkpoint whenever a Lightning checkpoint is saved.
"""

def setup(self, *args, **kwargs) -> None:
super().setup(*args, **kwargs)
self.last_best_k_models = {}
self.last_best_model_path = None
self.is_checkpoint_step = False

def format_checkpoint_name(
self,
metrics: Dict[str, Tensor],
filename: Optional[str] = None,
ver: Optional[int] = None,
) -> str:
"""
Change the checkpoint file path to align with the AIR checkpoint format.

e.g. './epoch=2-loss=0.12.ckpt' -> './epoch=2-loss=0.12.ckpt/model'
"""
filepath = super().format_checkpoint_name(metrics, filename, ver)
return f"{filepath}/{MODEL_KEY}"

def _session_report(self, trainer: "pl.Trainer", stage: str):
"""Report latest metrics dict and checkpoint to AIR training session."""

# Align the frequency of session.report() and checkpointing.
if not self.is_checkpoint_step:
return
self.is_checkpoint_step = False

# Report latest logged metrics
kwargs = {}
metrics = {}
for k, v in self._monitor_candidates(trainer).items():
if k == "_stage":
logger.warning(
"'_stage' is a reserved key in AIR report metrics. "
"Original values are overwritten!"
)
continue
if isinstance(v, torch.Tensor):
metrics[k] = v.item()

metrics["_stage"] = stage
kwargs["metrics"] = metrics

filepath = None
if self.monitor:
# Capture metric-based top-k checkpoint
new_checkpoint = self.best_k_models.keys() - self.last_best_k_models.keys()
if new_checkpoint:
filepath = new_checkpoint.pop()
else:
# Capture frequency-based checkpoint
if self.last_best_model_path != self.best_model_path:
filepath = self.best_model_path

# Report latest saved checkpoint
# Note that AIR only takes the checkpoint of rank 0.
# Save a dummy checkpoint on the other workers to avoid blocking.
if filepath:
if trainer.global_rank == 0:
kwargs["checkpoint"] = LightningCheckpoint.from_directory(
path=os.path.dirname(filepath)
)
else:
kwargs["checkpoint"] = LightningCheckpoint.from_dict(
{"rank": session.get_world_rank()}
)

self.last_best_k_models = deepcopy(self.best_k_models)
self.last_best_model_path = self.best_model_path

session.report(**kwargs)

def _save_topk_checkpoint(
self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
) -> None:
self.is_checkpoint_step = True
Contributor:
I don't think we actually need this. Let's just return early in the cases where no checkpoint is found. Better not to mess with this private method.

Member Author (woshiyyya):
I agree that overriding this method is not elegant, but it seems to be the only way. The name of the last checkpoint is always "last.ckpt", so we can't determine if there's a new checkpoint based on the file name alone.

To give you more context, the logic of ModelCheckpoint.on_train_batch_end() is like:

def on_train_batch_end(self, ...):
    #####################################################################
    # Code block that determines whether this is a checkpoint step;
    # returns early if it's not.
    #####################################################################
    self._save_topk_checkpoint(...)
    self._save_last_checkpoint(...)

Since Lightning did not modularize that code block into a self.should_checkpoint() method, we can only track whether _save_topk_checkpoint() and _save_last_checkpoint() are being called.

return super()._save_topk_checkpoint(trainer, monitor_candidates)

def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._session_report(trainer=trainer, stage="train_batch_end")

def on_train_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
super().on_train_epoch_end(trainer, pl_module)
self._session_report(trainer=trainer, stage="train_epoch_end")

def on_validation_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
super().on_validation_end(trainer, pl_module)
self._session_report(trainer=trainer, stage="validation_end")
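
To make the reporting flow concrete, here is a rough driver-side sketch of how these reports could be consumed. The ``LightningConfigBuilder`` setter names and the ``lightning_config`` argument are assumptions inferred from this PR's commit messages, not a confirmed API, and ``MyLightningModule`` is a hypothetical ``pl.LightningModule`` subclass; the ``Result`` fields follow the existing AIR ``trainer.fit()`` contract.

# Rough sketch only; builder/constructor names are assumed from this PR's
# commits, and MyLightningModule is a hypothetical LightningModule subclass.
from ray.air.config import ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

lightning_config = (
    LightningConfigBuilder()
    .module(MyLightningModule, lr=1e-3)                # assumed setter
    .trainer(max_epochs=2)                             # kwargs forwarded to pl.Trainer
    .checkpointing(monitor="val_loss", save_top_k=2)   # kwargs forwarded to ModelCheckpoint
    .build()
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)
result = trainer.fit()

# RayModelCheckpoint calls session.report() on every checkpoint step, so the
# final Result carries the last reported metrics (including the "_stage" key)
# and a LightningCheckpoint produced by the rank-0 worker.
print(result.metrics["_stage"])
print(result.checkpoint)
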
97 changes: 97 additions & 0 deletions python/ray/train/lightning/lightning_checkpoint.py
@@ -0,0 +1,97 @@
import os
import logging
import pytorch_lightning as pl
import tempfile
import shutil

from inspect import isclass
from typing import Optional, Type

from ray.air.constants import MODEL_KEY
from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.data import Preprocessor
from ray.train.torch import TorchCheckpoint
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class LightningCheckpoint(TorchCheckpoint):
"""A :class:`~ray.air.checkpoint.Checkpoint` with Lightning-specific functionality.

LightningCheckpoint only supports file-based checkpoint loading.
Create this by calling ``LightningCheckpoint.from_directory(ckpt_dir)``,
``LightningCheckpoint.from_uri(uri)``, or ``LightningCheckpoint.from_path(path)``.

LightningCheckpoint loads the file named ``model`` under the specified directory.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._cache_dir = None

@classmethod
def from_path(
cls,
path: str,
*,
preprocessor: Optional["Preprocessor"] = None,
) -> "LightningCheckpoint":
"""Create a ``ray.air.lightning.LightningCheckpoint`` from a checkpoint path.
Args:
path: The file path to the PyTorch Lightning checkpoint.
preprocessor: A fitted preprocessor to be applied before inference.
Returns:
An :py:class:`LightningCheckpoint` containing the model.
Examples:
>>> from ray.train.lightning import LightningCheckpoint
>>>
>>> checkpoint = LightningCheckpoint.from_path("/path/to/checkpoint.ckpt")
"""

assert os.path.exists(path), f"Lightning checkpoint {path} doesn't exist!"

cache_dir = tempfile.mkdtemp()
new_checkpoint_path = os.path.join(cache_dir, MODEL_KEY)
shutil.copy(path, new_checkpoint_path)
if preprocessor:
save_preprocessor_to_dir(preprocessor, cache_dir)
checkpoint = cls.from_directory(cache_dir)
checkpoint._cache_dir = cache_dir
return checkpoint

def get_model(
self, model_class: Type[pl.LightningModule], **load_from_checkpoint_kwargs
) -> pl.LightningModule:
"""Retrieve the model stored in this checkpoint.

Args:
model_class: A subclass of ``pytorch_lightning.LightningModule`` that
defines your model and training logic.
load_from_checkpoint_kwargs: Arguments to pass into
``model_class.load_from_checkpoint``.

Returns:
pl.LightningModule: An instance of the loaded model.
"""
if not isclass(model_class):
raise ValueError(
"'lightning_module' must be a class, not a class instance."
)

with self.as_directory() as checkpoint_dir:
ckpt_path = os.path.join(checkpoint_dir, MODEL_KEY)
if not os.path.exists(ckpt_path):
raise RuntimeError(
f"File {ckpt_path} not found under the checkpoint directory."
)

model = model_class.load_from_checkpoint(
ckpt_path, **load_from_checkpoint_kwargs
)
return model

def __del__(self):
if self._cache_dir and os.path.exists(self._cache_dir):
shutil.rmtree(self._cache_dir)
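
For reference, a minimal loading sketch using only the methods added in this file; ``MyLightningModule`` is a placeholder for any ``pl.LightningModule`` subclass:

from ray.train.lightning import LightningCheckpoint

# from_path() copies the standalone ".ckpt" file into a temporary cache
# directory under the name "model" so it matches the AIR checkpoint layout.
checkpoint = LightningCheckpoint.from_path("/path/to/checkpoint.ckpt")

# get_model() reconstructs the LightningModule; extra kwargs are forwarded
# to MyLightningModule.load_from_checkpoint().
model = checkpoint.get_model(MyLightningModule)
model.eval()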