Skip to content

Commit

Permalink
[Data][Train] Fix remaining issues on DatasetConfig->DataConfig migra…
Browse files Browse the repository at this point in the history
…tion (#37215) (#37352)

- Change all examples to use DataConfig.
- Update function signature of all Trainer classes.
- Add a link in deprecation warning.

---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
  • Loading branch information
raulchen authored Jul 12, 2023
1 parent 066645f commit bf1b735
Show file tree
Hide file tree
Showing 12 changed files with 63 additions and 37 deletions.
14 changes: 6 additions & 8 deletions doc/source/ray-air/doc_code/computer_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def train_torch_model(dataset, preprocessor, per_epoch_preprocessor):

from ray import train
from ray.air import session
from ray.air.config import DatasetConfig, ScalingConfig
from ray.air.config import ScalingConfig
from ray.train.torch import TorchCheckpoint, TorchTrainer

def train_one_epoch(model, *, criterion, optimizer, batch_size, epoch):
Expand Down Expand Up @@ -237,13 +237,11 @@ def train_loop_per_worker(config):
# __torch_training_loop_stop__

# __torch_trainer_start__
dataset = per_epoch_preprocessor.transform(dataset)
trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
train_loop_config={"batch_size": 32, "lr": 0.02, "epochs": 1},
datasets={"train": dataset},
dataset_config={
"train": DatasetConfig(per_epoch_preprocessor=per_epoch_preprocessor)
},
scaling_config=ScalingConfig(num_workers=2),
preprocessor=preprocessor,
)
Expand Down Expand Up @@ -288,16 +286,16 @@ def train_loop_per_worker(config):
# __tensorflow_training_loop_stop__

# __tensorflow_trainer_start__
from ray.air import DatasetConfig, ScalingConfig
from ray.air import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

# The following transform operation is lazy.
# It will be re-run every epoch.
dataset = per_epoch_preprocessor.transform(dataset)
trainer = TensorflowTrainer(
train_loop_per_worker=train_loop_per_worker,
train_loop_config={"batch_size": 32, "lr": 0.02, "epochs": 1},
datasets={"train": dataset},
dataset_config={
"train": DatasetConfig(per_epoch_preprocessor=per_epoch_preprocessor)
},
scaling_config=ScalingConfig(num_workers=2),
preprocessor=preprocessor,
)
Expand Down
36 changes: 26 additions & 10 deletions doc/source/ray-air/examples/torch_detection.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "da5b9b7e",
"metadata": {},
Expand All @@ -23,6 +24,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "e9a6d043",
"metadata": {},
Expand All @@ -45,6 +47,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "9b3d4302",
"metadata": {},
Expand All @@ -67,6 +70,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "65bf13b8",
"metadata": {},
Expand All @@ -91,6 +95,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5567a6d6",
"metadata": {},
Expand All @@ -112,6 +117,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f821e93d",
"metadata": {},
Expand Down Expand Up @@ -153,6 +159,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b8ab2cf1",
"metadata": {},
Expand Down Expand Up @@ -210,6 +217,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "686f0885",
"metadata": {},
Expand Down Expand Up @@ -293,6 +301,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "10d6ed44",
"metadata": {},
Expand Down Expand Up @@ -332,6 +341,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "db3d0ee6",
"metadata": {},
Expand Down Expand Up @@ -367,6 +377,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5ff0097f",
"metadata": {},
Expand All @@ -375,6 +386,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "87846ae1",
"metadata": {},
Expand Down Expand Up @@ -438,6 +450,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "e7cdc755",
"metadata": {},
Expand All @@ -446,6 +459,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "9cfddd49",
"metadata": {},
Expand All @@ -464,6 +478,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "6b68209a",
"metadata": {},
Expand All @@ -472,6 +487,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "9dbea4b4",
"metadata": {},
Expand Down Expand Up @@ -503,6 +519,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "1c647be8",
"metadata": {},
Expand Down Expand Up @@ -616,6 +633,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0d68c97c",
"metadata": {},
Expand All @@ -624,6 +642,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "eef58891",
"metadata": {},
Expand Down Expand Up @@ -806,9 +825,12 @@
}
],
"source": [
"from ray.air.config import DatasetConfig, ScalingConfig\n",
"from ray.air.config import ScalingConfig\n",
"from ray.train.torch import TorchTrainer\n",
"\n",
"# The following transform operation is lazy.\n",
"# It will be re-run every epoch.\n",
"train_dataset = per_epoch_preprocessor.transform(train_dataset)\n",
"\n",
"trainer = TorchTrainer(\n",
" train_loop_per_worker=train_loop_per_worker,\n",
Expand All @@ -823,19 +845,13 @@
" },\n",
" scaling_config=ScalingConfig(num_workers=4, use_gpu=True),\n",
" datasets={\"train\": train_dataset},\n",
" dataset_config={\n",
" # Don't augment test images. Only apply `per_epoch_preprocessor` to the train\n",
" # set.\n",
" \"train\": DatasetConfig(\n",
" per_epoch_preprocessor=per_epoch_preprocessor\n",
" ),\n",
" },\n",
" preprocessor=preprocessor,\n",
")\n",
"results = trainer.fit()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "838101c2",
"metadata": {},
Expand All @@ -854,8 +870,8 @@
},
"language_info": {
"name": "python",
"version": "3.10.9",
"pygments_lexer": "ipython3"
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"vscode": {
"interpreter": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import boto3
import mlflow
import pandas as pd
from ray.air.config import DatasetConfig, ScalingConfig
from ray.air.config import ScalingConfig
from ray.train.data_config import DataConfig
from ray.train.torch.torch_trainer import TorchTrainer
import torch
import torch.nn as nn
Expand Down Expand Up @@ -601,6 +602,10 @@ def to_torch_dataset(torch_batch_iterator):
DROPOUT_EVERY = 5
DROPOUT_PROB = 0.2

# The following random_shuffle operations are lazy.
# They will be re-run every epoch.
train_dataset = train_dataset.random_shuffle()
test_dataset = test_dataset.random_shuffle()
datasets = {"train": train_dataset, "test": test_dataset}

config = {
Expand Down Expand Up @@ -633,7 +638,7 @@ def to_torch_dataset(torch_batch_iterator):
resources_per_worker=resources_per_worker,
),
run_config=RunConfig(callbacks=callbacks),
dataset_config={"train": DatasetConfig(global_shuffle=True)},
dataset_config=DataConfig(datasets_to_split=["train", "test"]),
)
results = trainer.fit()
state_dict = results.checkpoint.to_dict()["model"]
Expand Down
3 changes: 2 additions & 1 deletion python/ray/air/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,8 @@ def from_placement_group_factory(
@dataclass
@Deprecated(
message="Use `ray.train.DataConfig` instead of DatasetConfig to "
"configure data ingest for training."
"configure data ingest for training. "
"See https://docs.ray.io/en/master/ray-air/check-ingest.html for more details."
)
class DatasetConfig:
"""Configuration for ingest of a single Dataset.
Expand Down
5 changes: 3 additions & 2 deletions python/ray/train/horovod/horovod_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Callable, Optional, Union, TYPE_CHECKING

from ray.air.config import ScalingConfig, RunConfig, DatasetConfig
from ray.air.config import ScalingConfig, RunConfig
from ray.train.data_config import DataConfig
from ray.train.trainer import GenDataset
from ray.air.checkpoint import Checkpoint

Expand Down Expand Up @@ -181,7 +182,7 @@ def __init__(
train_loop_config: Optional[Dict] = None,
horovod_config: Optional[HorovodConfig] = None,
scaling_config: Optional[ScalingConfig] = None,
dataset_config: Optional[Dict[str, DatasetConfig]] = None,
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
Expand Down
5 changes: 3 additions & 2 deletions python/ray/train/huggingface/accelerate/accelerate_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.air.config import RunConfig, ScalingConfig
from ray.train.data_config import DataConfig
from ray.train.torch import TorchConfig
from ray.train.trainer import GenDataset

Expand Down Expand Up @@ -263,7 +264,7 @@ def __init__(
accelerate_config: Optional[Union[dict, str, Path, os.PathLike]] = None,
torch_config: Optional[TorchConfig] = None,
scaling_config: Optional[ScalingConfig] = None,
dataset_config: Optional[Dict[str, DatasetConfig]] = None,
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.air.config import RunConfig, ScalingConfig
from ray.train.constants import (
EVALUATION_DATASET_KEY,
TRAIN_DATASET_KEY,
)
from ray.train.data_config import DataConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
from ray.train.torch import TorchConfig, TorchTrainer
from ray.train.trainer import GenDataset
Expand Down Expand Up @@ -254,7 +255,7 @@ def __init__(
trainer_init_config: Optional[Dict] = None,
torch_config: Optional[TorchConfig] = None,
scaling_config: Optional[ScalingConfig] = None,
dataset_config: Optional[Dict[str, DatasetConfig]] = None,
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
preprocessor: Optional["Preprocessor"] = None,
Expand Down
5 changes: 3 additions & 2 deletions python/ray/train/lightning/lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from pytorch_lightning.plugins.environments import ClusterEnvironment

from ray.air import session
from ray.air.config import CheckpointConfig, DatasetConfig, RunConfig, ScalingConfig
from ray.air.config import CheckpointConfig, RunConfig, ScalingConfig
from ray.air.constants import MODEL_KEY
from ray.air.checkpoint import Checkpoint
from ray.data.preprocessor import Preprocessor
from ray.train.data_config import DataConfig
from ray.train.trainer import GenDataset
from ray.train.torch import TorchTrainer
from ray.train.torch.config import TorchConfig
Expand Down Expand Up @@ -395,7 +396,7 @@ def __init__(
*,
torch_config: Optional[TorchConfig] = None,
scaling_config: Optional[ScalingConfig] = None,
dataset_config: Optional[Dict[str, DatasetConfig]] = None,
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
datasets_iter_config: Optional[Dict[str, Any]] = None,
Expand Down
5 changes: 3 additions & 2 deletions python/ray/train/mosaic/mosaic_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from composer.loggers.logger_destination import LoggerDestination

from ray.air.checkpoint import Checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig
from ray.air.config import RunConfig, ScalingConfig
from ray.train.data_config import DataConfig
from ray.train.mosaic._mosaic_utils import RayLogger
from ray.train.torch import TorchConfig, TorchTrainer
from ray.train.trainer import GenDataset
Expand Down Expand Up @@ -139,7 +140,7 @@ def __init__(
trainer_init_config: Optional[Dict] = None,
torch_config: Optional[TorchConfig] = None,
scaling_config: Optional[ScalingConfig] = None,
dataset_config: Optional[Dict[str, DatasetConfig]] = None,
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
preprocessor: Optional["Preprocessor"] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
Expand Down
Loading

0 comments on commit bf1b735

Please sign in to comment.