From 24cceb0198f1d3c359f8702f28c719641d91d206 Mon Sep 17 00:00:00 2001
From: Yunxuan Xiao
Date: Wed, 28 Jun 2023 14:16:07 -0700
Subject: [PATCH] [Train] Unifying Lightning and AIR CheckpointConfig (#36368)

Signed-off-by: woshiyyya
Signed-off-by: Yunxuan Xiao
Co-authored-by: matthewdeng
Signed-off-by: e428265
---
 .../ray/train/lightning/lightning_trainer.py  | 79 ++++++++++++++-----
 .../ray/train/tests/test_lightning_trainer.py | 49 ++++++++++++
 2 files changed, 109 insertions(+), 19 deletions(-)

diff --git a/python/ray/train/lightning/lightning_trainer.py b/python/ray/train/lightning/lightning_trainer.py
index 630f8a3eacd2..fb47f5f1eaf4 100644
--- a/python/ray/train/lightning/lightning_trainer.py
+++ b/python/ray/train/lightning/lightning_trainer.py
@@ -1,6 +1,7 @@
 import os
 
 import pytorch_lightning as pl
+from copy import copy
 from inspect import isclass
 from typing import Any, Dict, Optional, Type
 from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -102,10 +103,14 @@ class definition instead of a class instance.
     def trainer(self, **kwargs) -> "LightningConfigBuilder":
         """Set up the configurations of ``pytorch_lightning.Trainer``.
 
-        Note that you don't have to specify the `strategy` argument here since the
-        ``LightningTrainer`` creates a PyTorch Lightning Strategy object with the
-        configurations specified in the `.strategy()` method. If no configuration
-        is specified, it creates a DDPStrategy by default.
+        Note that you don't have to specify the ``strategy``, ``devices``, and
+        ``num_nodes`` arguments here, since the ``LightningTrainer`` creates
+        a PyTorch Lightning Strategy object with the configurations specified
+        in the `.strategy()` method. The ``devices`` and ``num_nodes`` are also
+        configured automatically by the LightningTrainer. If no configuration
+        is specified, it creates a ``DDPStrategy`` by default.
+
+        For ``accelerator``, currently only ``"cpu"`` and ``"gpu"`` are supported.
 
         Args:
             kwargs: The initialization arguments for ``pytorch_lightning.Trainer``
@@ -223,7 +228,7 @@ class LightningTrainer(TorchTrainer):
     ``pytorch_lightning.LightningModule`` using the arguments provided in
     ``LightningConfigBuilder.module()``.
 
-    For data ingestion, the LightningTrainer will then either convert the Dataset
+    For data ingestion, the LightningTrainer will then either convert the Ray Dataset
     shards to a ``pytorch_lightning.LightningDataModule``, or directly use the
     datamodule or dataloaders if provided by users.
 
@@ -339,19 +344,27 @@ def configure_optimizers(self):
         scaling_config: Configuration for how to scale data parallel training.
         dataset_config: Configuration for dataset ingest.
         run_config: Configuration for the execution of the training run.
-        datasets: A dictionary of Datasets to use for training.
+        datasets: A dictionary of Ray Datasets to use for training.
             Use the key "train" to denote which dataset is the training
             dataset and (optionally) key "val" to denote the validation
-            dataset. If a ``preprocessor`` is provided and has not already
-            been fit, it will be fit on the training dataset. All datasets will be
-            transformed by the ``preprocessor`` if one is provided.
-        datasets_iter_config: Configurations for iterating over input Datasets.
-            This configuration is only valid when `datasets` argument is provided to
-            the LightningTrainer. Otherwise, LightningTrainer will use datamodule
-            or dataloaders specified in ``LightningConfig.trainer_init_config``.
-            For valid arguments to pass, please refer to:
+            dataset. Internally, LightningTrainer shards the training dataset
+            across all workers, and creates a PyTorch Dataloader for each shard.
+
+            The datasets will be transformed by ``preprocessor`` if it is provided.
+            If the ``preprocessor`` has not already been fit, it will be fit on the
+            training dataset.
+
+            If ``datasets`` is not specified, ``LightningTrainer`` will use the datamodule
+            or dataloaders specified in ``LightningConfigBuilder.fit_params`` instead.
+        datasets_iter_config: Configuration for iterating over the input Ray Datasets.
+            You can configure the per-device batch size, prefetch batch size, collate
+            function, and more. For valid arguments to pass, please refer to:
             :py:meth:`Dataset.iter_torch_batches`
+
+            Note that if you provide a ``datasets`` parameter, you must always specify
+            ``datasets_iter_config`` for it.
+
         preprocessor: A ray.data.Preprocessor to preprocess the provided datasets.
         resume_from_checkpoint: A checkpoint to resume training from.
@@ -370,10 +383,17 @@ def __init__(
         preprocessor: Optional[Preprocessor] = None,
         resume_from_checkpoint: Optional[Checkpoint] = None,
     ):
-        run_config = run_config or RunConfig()
+        run_config = copy(run_config) or RunConfig()
         lightning_config = lightning_config or LightningConfigBuilder().build()
 
-        self._check_checkpoint_configs(
+        if datasets and not datasets_iter_config:
+            raise RuntimeError(
+                "No `datasets_iter_config` provided for the input `datasets`! "
+                "Please refer to the API of `ray.data.DataIterator.iter_torch_batches` "
+                "for all valid arguments."
+            )
+
+        run_config.checkpoint_config = self._unify_checkpoint_configs(
             ptl_ckpt_config=lightning_config["_model_checkpoint_config"],
             air_ckpt_config=run_config.checkpoint_config,
         )
@@ -402,10 +422,11 @@ def __init__(
             resume_from_checkpoint=resume_from_checkpoint,
         )
 
-    def _check_checkpoint_configs(
+    def _unify_checkpoint_configs(
         self, ptl_ckpt_config: Dict, air_ckpt_config: CheckpointConfig
-    ):
-        """Check if configs are set correctly"""
+    ) -> CheckpointConfig:
+        """Unify the Lightning checkpointing config and the AIR CheckpointConfig."""
+
         ptl_ckpt_metric = ptl_ckpt_config.get("monitor", None)
         air_ckpt_metric = air_ckpt_config.checkpoint_score_attribute
 
@@ -427,6 +448,26 @@ def _check_checkpoint_configs(
                 "through `LightningConfigBuilder.checkpointing()`."
             )
 
+        # Auto-fill the AIR CheckpointConfig if the user didn't specify it.
+        if air_ckpt_config == CheckpointConfig():
+            # save_top_k = 1  -> num_to_keep = 1    : Lightning saves 1 ckpt by default
+            # save_top_k = 0  -> num_to_keep = 1    : AIR saves at least 1 ckpt
+            # save_top_k = -1 -> num_to_keep = None : Save all ckpts
+
+            save_top_k = ptl_ckpt_config.get("save_top_k", 1)
+            if save_top_k == -1:
+                num_to_keep = None
+            else:
+                num_to_keep = max(save_top_k, 1)
+
+            return CheckpointConfig(
+                num_to_keep=num_to_keep,
+                checkpoint_score_attribute=ptl_ckpt_config.get("monitor", None),
+                checkpoint_score_order=ptl_ckpt_config.get("mode", "min"),
+            )
+        else:
+            return air_ckpt_config
+
     @PublicAPI(stability="alpha")
     @classmethod
     def restore(
diff --git a/python/ray/train/tests/test_lightning_trainer.py b/python/ray/train/tests/test_lightning_trainer.py
index 767ce2fcc10c..5693dcd15b0d 100644
--- a/python/ray/train/tests/test_lightning_trainer.py
+++ b/python/ray/train/tests/test_lightning_trainer.py
@@ -198,6 +198,55 @@ def test_trainer_with_categorical_ray_data(ray_start_6_cpus_2_gpus, accelerator)
     assert results.checkpoint
 
 
+def test_trainer_checkpoint_configs():
+    num_epochs = 1
+    batch_size = 8
+    input_dim = 32
+    output_dim = 4
+    dataset_size = 256
+
+    datamodule = DummyDataModule(batch_size, dataset_size)
+
+    config_builder = (
+        LightningConfigBuilder()
+        .module(LinearModule, input_dim=input_dim, output_dim=output_dim)
+        .trainer(max_epochs=num_epochs, accelerator="gpu")
+        .strategy("fsdp")
+        .checkpointing(monitor="metric_a", mode="min", save_top_k=3, save_last=True)
+        .fit_params(datamodule=datamodule)
+    )
+
+    scaling_config = ray.air.ScalingConfig(num_workers=2, use_gpu=True)
+
+    trainer = LightningTrainer(
+        lightning_config=config_builder.build(), scaling_config=scaling_config
+    )
+
+    # Test checkpoint configs
+    air_ckpt_config = trainer.run_config.checkpoint_config
+    assert air_ckpt_config.checkpoint_score_attribute == "metric_a"
+    assert air_ckpt_config.checkpoint_score_order == "min"
+    assert air_ckpt_config.num_to_keep == 3
+
+    config_builder.checkpointing(save_top_k=-1, monitor=None)
+    trainer = LightningTrainer(
+        lightning_config=config_builder.build(), scaling_config=scaling_config
+    )
+    air_ckpt_config = trainer.run_config.checkpoint_config
+    assert air_ckpt_config.checkpoint_score_attribute is None
+    assert air_ckpt_config.checkpoint_score_order == "min"
+    assert air_ckpt_config.num_to_keep is None
+
+
+def test_trainer_dataset_iter_config():
+    # Test missing datasets_iter_config
+    with pytest.raises(RuntimeError, match="No `datasets_iter_config` provided"):
+        LightningTrainer(
+            scaling_config=ray.air.ScalingConfig(num_workers=2),
+            datasets={"train": ray.data.range(100)},
+        )
+
+
 if __name__ == "__main__":
     import sys
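
Usage sketch (reviewer note, not part of the patch): the snippet below illustrates the behavior this PR adds, namely that the Lightning checkpointing options set through `LightningConfigBuilder.checkpointing()` are propagated into the AIR `CheckpointConfig` when the user does not pass one explicitly, and that `datasets_iter_config` is now required whenever `datasets` is given. `TinyModule` and the concrete numbers are illustrative placeholders, not code from this patch; only the builder/trainer calls mirror the API exercised in the new tests.

import torch
import pytorch_lightning as pl
import ray
from ray.air import ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer


class TinyModule(pl.LightningModule):
    """Minimal LightningModule used only to keep the sketch self-contained."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        # Ray Datasets created via ray.data.range() yield an "id" column.
        x = batch["id"].float().unsqueeze(1)
        loss = torch.nn.functional.mse_loss(self.layer(x), x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-3)


lightning_config = (
    LightningConfigBuilder()
    .module(TinyModule)
    .trainer(max_epochs=1, accelerator="cpu")
    .checkpointing(monitor="train_loss", mode="min", save_top_k=2)
    .build()
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": ray.data.range(128)},
    # Required whenever `datasets` is passed; forwarded to iter_torch_batches().
    datasets_iter_config={"batch_size": 32},
)

# The Lightning settings are mirrored into the AIR CheckpointConfig:
# save_top_k=2 -> num_to_keep=2, monitor -> checkpoint_score_attribute,
# mode -> checkpoint_score_order. trainer.fit() is not called here; the
# point is only the configuration mapping performed at construction time.
assert trainer.run_config.checkpoint_config.num_to_keep == 2
assert trainer.run_config.checkpoint_config.checkpoint_score_attribute == "train_loss"
assert trainer.run_config.checkpoint_config.checkpoint_score_order == "min"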