[Train] Unifying Lightning and AIR CheckpointConfig (ray-project#36368)
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Yunxuan Xiao <xiaoyunxuan1998@gmail.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
2 people authored and arvind-chandra committed Aug 31, 2023
1 parent 91ada04 commit 24cceb0
Showing 2 changed files with 109 additions and 19 deletions.
79 changes: 60 additions & 19 deletions python/ray/train/lightning/lightning_trainer.py
@@ -1,6 +1,7 @@
import os
import pytorch_lightning as pl

from copy import copy
from inspect import isclass
from typing import Any, Dict, Optional, Type
from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -102,10 +103,14 @@ class definition instead of a class instance.
def trainer(self, **kwargs) -> "LightningConfigBuilder":
"""Set up the configurations of ``pytorch_lightning.Trainer``.
Note that you don't have to specify the `strategy` argument here since the
``LightningTrainer`` creates a PyTorch Lightning Strategy object with the
configurations specified in the `.strategy()` method. If no configuration
is specified, it creates a DDPStrategy by default.
Note that you don't have to specify the ``strategy``, ``device`` and
``num_nodes`` arguments here, since the ``LightningTrainer`` creates
a PyTorch Lightning Strategy object with the configurations specified
in the `.strategy()` method. The ``device`` and ``num_nodes`` are also
configured automatically by the LightningTrainer. If no configuration
is specified, it creates a ``DDPStrategy`` by default.
For ``accelerator``, currently only ``"cpu"`` and ``"gpu"`` are supported.
Args:
kwargs: The initialization arguments for ``pytorch_lightning.Trainer``
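
As a quick orientation for reviewers, here is a minimal usage sketch of the builder documented above. The module class ``MyLightningModule`` and all argument values are placeholders, not part of this commit:

# Hypothetical sketch only; `MyLightningModule` and the chosen values are placeholders.
from ray.train.lightning import LightningConfigBuilder

lightning_config = (
    LightningConfigBuilder()
    .module(MyLightningModule, lr=1e-3)          # LightningModule class + init kwargs
    .trainer(max_epochs=10, accelerator="gpu")   # no `strategy`/`devices`/`num_nodes` here
    .strategy("ddp")                             # LightningTrainer builds the Strategy object
    .checkpointing(monitor="val_loss", mode="min", save_top_k=2)
    .build()
)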
@@ -223,7 +228,7 @@ class LightningTrainer(TorchTrainer):
``pytorch_lightning.LightningModule`` using the arguments provided in
``LightningConfigBuilder.module()``.
For data ingestion, the LightningTrainer will then either convert the Dataset
For data ingestion, the LightningTrainer will then either convert the Ray Dataset
shards to a ``pytorch_lightning.LightningDataModule``, or directly use the
datamodule or dataloaders if provided by users.
@@ -339,19 +344,27 @@ def configure_optimizers(self):
scaling_config: Configuration for how to scale data parallel training.
dataset_config: Configuration for dataset ingest.
run_config: Configuration for the execution of the training run.
datasets: A dictionary of Datasets to use for training.
datasets: A dictionary of Ray Datasets to use for training.
Use the key "train" to denote which dataset is the training
dataset and (optionally) key "val" to denote the validation
dataset. If a ``preprocessor`` is provided and has not already
been fit, it will be fit on the training dataset. All datasets will be
transformed by the ``preprocessor`` if one is provided.
datasets_iter_config: Configurations for iterating over input Datasets.
This configuration is only valid when `datasets` argument is provided to
the LightningTrainer. Otherwise, LightningTrainer will use datamodule
or dataloaders specified in ``LightningConfig.trainer_init_config``.
For valid arguments to pass, please refer to:
dataset. Internally, LightningTrainer shards the training dataset
across all workers, and creates a PyTorch Dataloader for each shard.
The datasets will be transformed by ``preprocessor`` if it is provided.
If the ``preprocessor`` has not already been fit, it will be fit on the
training dataset.
If ``datasets`` is not specified, ``LightningTrainer`` will use datamodule
or dataloaders specified in ``LightningConfigBuilder.fit_params`` instead.
datasets_iter_config: Configuration for iterating over the input Ray Datasets.
You can configure the per-device batch size, prefetch batch size, collate
function, and more. For valid arguments to pass, please refer to:
:py:meth:`Dataset.iter_torch_batches
<ray.data.Dataset.iter_torch_batches>`
Note that if you provide a ``datasets`` parameter, you must always specify
``datasets_iter_config`` for it.
preprocessor: A ray.data.Preprocessor to preprocess the
provided datasets.
resume_from_checkpoint: A checkpoint to resume training from.
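
The ``datasets`` / ``datasets_iter_config`` pairing described above, sketched under the assumption that a ``lightning_config`` has already been built with ``LightningConfigBuilder`` as in the earlier sketch; the dataset contents and batch size are illustrative only:

# Illustrative sketch; the dataset and batch size are arbitrary placeholder values.
import ray
from ray.air import ScalingConfig
from ray.train.lightning import LightningTrainer

trainer = LightningTrainer(
    lightning_config=lightning_config,            # built via LightningConfigBuilder (see above)
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": ray.data.range(256)},
    # Required whenever `datasets` is passed; kwargs are forwarded to iter_torch_batches().
    datasets_iter_config={"batch_size": 32},
)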
@@ -370,10 +383,17 @@ def __init__(
preprocessor: Optional[Preprocessor] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
):
run_config = run_config or RunConfig()
run_config = copy(run_config) or RunConfig()
lightning_config = lightning_config or LightningConfigBuilder().build()

self._check_checkpoint_configs(
if datasets and not datasets_iter_config:
raise RuntimeError(
"No `datasets_iter_config` provided for the input `datasets`!"
"Please refer to the API of `ray.data.DataIterator.iter_torch_batches`"
"for all valid arguments."
)

run_config.checkpoint_config = self._unify_checkpoint_configs(
ptl_ckpt_config=lightning_config["_model_checkpoint_config"],
air_ckpt_config=run_config.checkpoint_config,
)
@@ -402,10 +422,11 @@ def __init__(
resume_from_checkpoint=resume_from_checkpoint,
)

def _check_checkpoint_configs(
def _unify_checkpoint_configs(
self, ptl_ckpt_config: Dict, air_ckpt_config: CheckpointConfig
):
"""Check if configs are set correctly"""
) -> CheckpointConfig:
"""Unify the Lightning checkpointing config and the AIR CheckpointConfig."""

ptl_ckpt_metric = ptl_ckpt_config.get("monitor", None)
air_ckpt_metric = air_ckpt_config.checkpoint_score_attribute

@@ -427,6 +448,26 @@ def _check_checkpoint_configs(
"through `LightningConfigBuilder.checkpointing()`."
)

# Auto-fill the AIR CheckpointConfig if the user didn't specify it.
if air_ckpt_config == CheckpointConfig():
# save_top_k = 1 -> num_to_keep = 1 : Lightning saves 1 ckpt by default
# save_top_k = 0 -> num_to_keep = 1 : AIR saves at least 1 ckpt
# save_top_k = -1 -> num_to_keep = None : Save all ckpts

save_top_k = ptl_ckpt_config.get("save_top_k", 1)
if save_top_k == -1:
num_to_keep = None
else:
num_to_keep = max(save_top_k, 1)

return CheckpointConfig(
num_to_keep=num_to_keep,
checkpoint_score_attribute=ptl_ckpt_config.get("monitor", None),
checkpoint_score_order=ptl_ckpt_config.get("mode", "min"),
)
else:
return air_ckpt_config

@PublicAPI(stability="alpha")
@classmethod
def restore(
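
For clarity, the ``save_top_k``/``monitor``/``mode`` to ``CheckpointConfig`` translation implemented in ``_unify_checkpoint_configs`` above can be restated in isolation. The helper below is an illustration of that mapping, not code from this commit:

# Restatement of the mapping in `_unify_checkpoint_configs` (illustration only).
from typing import Dict, Optional
from ray.air import CheckpointConfig

def to_air_checkpoint_config(ptl_ckpt_config: Dict) -> CheckpointConfig:
    save_top_k = ptl_ckpt_config.get("save_top_k", 1)
    # save_top_k == -1 -> keep all checkpoints; 0 or 1 -> AIR keeps at least one.
    num_to_keep: Optional[int] = None if save_top_k == -1 else max(save_top_k, 1)
    return CheckpointConfig(
        num_to_keep=num_to_keep,
        checkpoint_score_attribute=ptl_ckpt_config.get("monitor"),
        checkpoint_score_order=ptl_ckpt_config.get("mode", "min"),
    )

# Mirrors the expectations in test_trainer_checkpoint_configs below.
assert to_air_checkpoint_config({"monitor": "metric_a", "mode": "min", "save_top_k": 3}).num_to_keep == 3
assert to_air_checkpoint_config({"save_top_k": -1}).num_to_keep is None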
49 changes: 49 additions & 0 deletions python/ray/train/tests/test_lightning_trainer.py
@@ -198,6 +198,55 @@ def test_trainer_with_categorical_ray_data(ray_start_6_cpus_2_gpus, accelerator)
assert results.checkpoint


def test_trainer_checkpoint_configs():
num_epochs = 1
batch_size = 8
input_dim = 32
output_dim = 4
dataset_size = 256

datamodule = DummyDataModule(batch_size, dataset_size)

config_builder = (
LightningConfigBuilder()
.module(LinearModule, input_dim=input_dim, output_dim=output_dim)
.trainer(max_epochs=num_epochs, accelerator="gpu")
.strategy("fsdp")
.checkpointing(monitor="metric_a", mode="min", save_top_k=3, save_last=True)
.fit_params(datamodule=datamodule)
)

scaling_config = ray.air.ScalingConfig(num_workers=2, use_gpu=True)

trainer = LightningTrainer(
lightning_config=config_builder.build(), scaling_config=scaling_config
)

# Test checkpoint configs
air_ckpt_config = trainer.run_config.checkpoint_config
assert air_ckpt_config.checkpoint_score_attribute == "metric_a"
assert air_ckpt_config.checkpoint_score_order == "min"
assert air_ckpt_config.num_to_keep == 3

config_builder.checkpointing(save_top_k=-1, monitor=None)
trainer = LightningTrainer(
lightning_config=config_builder.build(), scaling_config=scaling_config
)
air_ckpt_config = trainer.run_config.checkpoint_config
assert air_ckpt_config.checkpoint_score_attribute is None
assert air_ckpt_config.checkpoint_score_order == "min"
assert air_ckpt_config.num_to_keep is None


def test_trainer_dataset_iter_config():
# Test missing datasets_iter_config
with pytest.raises(RuntimeError, match="No `datasets_iter_config` provided"):
LightningTrainer(
scaling_config=ray.air.ScalingConfig(num_workers=2),
datasets={"train": ray.data.range(100)},
)


if __name__ == "__main__":
import sys
