Skip to content

Commit

Permalink
fixing CI
Browse files Browse the repository at this point in the history
Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
  • Loading branch information
woshiyyya committed Mar 17, 2023
1 parent c57c310 commit 55918b9
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 115 deletions.
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,14 @@ py_test(
deps = [":train_lib"]
)

# Restore tests for LightningTrainer; tagged "gpu" so they are scheduled on
# GPU machines, and "exclusive" so they do not share a worker with other tests.
py_test(
    name = "test_lightning_trainer_restore",
    size = "medium",
    srcs = ["tests/test_lightning_trainer_restore.py"],
    tags = ["team:ml", "exclusive", "ray_air", "gpu"],
    deps = [":train_lib"]
)

py_test(
name = "test_lightning_trainer",
size = "large",
Expand Down
12 changes: 5 additions & 7 deletions python/ray/train/lightning/_lightning_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

logger = logging.getLogger(__name__)

LIGHTNING_REPORT_STAGE_KEY = "__report_on"


class RayDDPStrategy(DDPStrategy):
"""Subclass of DDPStrategy to ensure compatibility with Ray orchestration."""
Expand Down Expand Up @@ -50,9 +52,11 @@ def node_rank(self) -> int:
return session.get_node_rank()

def set_world_size(self, size: int) -> None:
    """Intentionally a no-op.

    The world size is read straight from the AIR session by ``world_size()``,
    so there is no local state to update here; the ``size`` argument is ignored.
    """

def set_global_rank(self, rank: int) -> None:
    """Intentionally a no-op.

    The global rank is read straight from the AIR session by ``global_rank()``,
    so there is no local state to update here; the ``rank`` argument is ignored.
    """

def teardown(self):
Expand Down Expand Up @@ -124,14 +128,8 @@ def _session_report(self, trainer: "pl.Trainer", stage: str):
return

# Report latest logged metrics
metrics = {"report_on": stage}
metrics = {LIGHTNING_REPORT_STAGE_KEY: stage}
for k, v in self._monitor_candidates(trainer).items():
if k == "report_on":
logger.warning(
"'report_on' is a reserved key in AIR report metrics. "
"Original values are overwritten!"
)
continue
if isinstance(v, torch.Tensor):
metrics[k] = v.item()

Expand Down
65 changes: 0 additions & 65 deletions python/ray/train/tests/test_lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import ray
from ray.train.lightning import LightningConfigBuilder, LightningTrainer
from ray.train.constants import MODEL_KEY
from ray.air.util.data_batch_conversion import convert_batch_type_to_pandas
from ray.train.tests.lightning_test_utils import (
LinearModule,
Expand Down Expand Up @@ -188,70 +187,6 @@ def test_trainer_with_categorical_ray_data(ray_start_6_cpus_2_gpus, accelerator)
assert results.checkpoint


def test_resume_from_checkpoint(ray_start_6_cpus):
    """Resume LightningTrainer training from a previous run's checkpoint.

    Trains for 2 epochs on CPU, then feeds the checkpointed model file as
    ``ckpt_path`` into a second trainer that continues to epoch 4, and checks
    the reported epoch/step counters and metrics.
    """
    num_epochs = 2
    batch_size = 8
    num_workers = 2
    dataset_size = 64

    # Create simple categorical ray dataset with two input columns.
    input_1 = np.random.rand(dataset_size, 32).astype(np.float32)
    input_2 = np.random.rand(dataset_size, 32).astype(np.float32)
    pd = convert_batch_type_to_pandas({"input_1": input_1, "input_2": input_2})
    train_dataset = ray.data.from_pandas(pd)
    val_dataset = ray.data.from_pandas(pd)

    config_builder = (
        LightningConfigBuilder()
        .module(
            DoubleLinearModule,
            input_dim_1=32,
            input_dim_2=32,
            output_dim=4,
        )
        .trainer(max_epochs=num_epochs, accelerator="cpu")
    )

    lightning_config = config_builder.build()

    scaling_config = ray.air.ScalingConfig(num_workers=num_workers, use_gpu=False)

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
        datasets={"train": train_dataset, "val": val_dataset},
        datasets_iter_config={"batch_size": batch_size},
    )
    results = trainer.fit()

    # Resume training for another 2 epochs.
    num_epochs += 2
    # Strip the "file://" URI scheme to obtain a local filesystem path
    # (was a magic `[7:]` slice; make the prefix length explicit).
    ckpt_dir = results.checkpoint.uri[len("file://") :]
    ckpt_path = f"{ckpt_dir}/{MODEL_KEY}"

    lightning_config = (
        config_builder.fit_params(ckpt_path=ckpt_path)
        .trainer(max_epochs=num_epochs)
        .build()
    )

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
        datasets={"train": train_dataset, "val": val_dataset},
        datasets_iter_config={"batch_size": batch_size},
    )
    results = trainer.fit()

    # Epochs are 0-indexed, so after `num_epochs` total epochs the last
    # reported epoch is num_epochs - 1.
    assert results.metrics["epoch"] == num_epochs - 1
    assert (
        results.metrics["step"] == num_epochs * dataset_size / num_workers / batch_size
    )
    assert "loss" in results.metrics
    assert "val_loss" in results.metrics
    assert results.checkpoint


if __name__ == "__main__":
import sys

Expand Down
156 changes: 156 additions & 0 deletions python/ray/train/tests/test_lightning_trainer_restore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import numpy as np
import pytest

import ray
from ray.air import RunConfig, CheckpointConfig
from ray.air.util.data_batch_conversion import convert_batch_type_to_pandas
from ray.train.constants import MODEL_KEY
from ray.train.lightning import LightningConfigBuilder, LightningTrainer
from ray.train.tests.lightning_test_utils import (
DoubleLinearModule,
DummyDataModule,
LinearModule,
)
from ray.tune import Callback, TuneError


@pytest.fixture
def ray_start_4_cpus_2_gpus():
    """Yield a freshly initialized Ray instance with 4 CPUs and 2 GPUs."""
    context = ray.init(num_cpus=4, num_gpus=2)
    yield context
    # Teardown: shut the cluster down once the dependent test finishes.
    ray.shutdown()


@pytest.fixture
def ray_start_6_cpus():
    """Yield a freshly initialized Ray instance with 6 CPUs (no GPUs)."""
    context = ray.init(num_cpus=6)
    yield context
    # Teardown: shut the cluster down once the dependent test finishes.
    ray.shutdown()


class FailureInjectionCallback(Callback):
    """Tune callback that raises an error once a trial reaches a set iteration.

    Used to simulate a mid-training crash so that trainer restore logic can
    be exercised by the tests below.
    """

    def __init__(self, num_iters=2):
        # Training iteration at which the simulated failure fires.
        self.num_iters = num_iters

    def on_trial_save(self, iteration, trials, trial, **info):
        if trial.last_result["training_iteration"] == self.num_iters:
            print(f"Failing after {self.num_iters} iters...")
            # Same exception type as before, now with a diagnostic message so
            # unexpected propagation is easier to debug.
            raise RuntimeError(
                f"Injected failure after {self.num_iters} training iterations."
            )


def test_native_trainer_restore(ray_start_4_cpus_2_gpus):
    """Test restoring trainer in Lightning's native way.

    Trains for 2 epochs on GPU, then feeds the checkpointed model file as
    ``ckpt_path`` into a second trainer that continues to epoch 4, and checks
    the reported epoch/step counters and metrics.
    """
    num_epochs = 2
    batch_size = 8
    num_workers = 2
    dataset_size = 64

    # Create simple categorical ray dataset with two input columns.
    input_1 = np.random.rand(dataset_size, 32).astype(np.float32)
    input_2 = np.random.rand(dataset_size, 32).astype(np.float32)
    pd = convert_batch_type_to_pandas({"input_1": input_1, "input_2": input_2})
    train_dataset = ray.data.from_pandas(pd)
    val_dataset = ray.data.from_pandas(pd)

    config_builder = (
        LightningConfigBuilder()
        .module(
            DoubleLinearModule,
            input_dim_1=32,
            input_dim_2=32,
            output_dim=4,
        )
        .trainer(max_epochs=num_epochs, accelerator="gpu")
    )

    lightning_config = config_builder.build()

    scaling_config = ray.air.ScalingConfig(num_workers=num_workers, use_gpu=True)

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
        datasets={"train": train_dataset, "val": val_dataset},
        datasets_iter_config={"batch_size": batch_size},
    )
    results = trainer.fit()

    # Resume training for another 2 epochs.
    num_epochs += 2
    # Strip the "file://" URI scheme to obtain a local filesystem path
    # (was a magic `[7:]` slice; make the prefix length explicit).
    ckpt_dir = results.checkpoint.uri[len("file://") :]
    ckpt_path = f"{ckpt_dir}/{MODEL_KEY}"

    lightning_config = (
        config_builder.fit_params(ckpt_path=ckpt_path)
        .trainer(max_epochs=num_epochs)
        .build()
    )

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
        datasets={"train": train_dataset, "val": val_dataset},
        datasets_iter_config={"batch_size": batch_size},
    )
    results = trainer.fit()

    # Epochs are 0-indexed, so after `num_epochs` total epochs the last
    # reported epoch is num_epochs - 1.
    assert results.metrics["epoch"] == num_epochs - 1
    assert (
        results.metrics["step"] == num_epochs * dataset_size / num_workers / batch_size
    )
    assert "loss" in results.metrics
    assert "val_loss" in results.metrics
    assert results.checkpoint


def test_air_trainer_restore(ray_start_6_cpus, tmpdir):
    """Test restore for LightningTrainer from a failed/interrupted trial."""
    exp_name = "air_trainer_restore_test"

    # In-memory dataloaders from a dummy datamodule (batch size 8, 256 samples).
    datamodule = DummyDataModule(8, 256)
    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()

    lightning_config = (
        LightningConfigBuilder()
        .module(LinearModule, input_dim=32, output_dim=4)
        .trainer(max_epochs=5, accelerator="cpu")
        .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
        .build()
    )

    scaling_config = ray.air.ScalingConfig(num_workers=2, use_gpu=False)

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
        run_config=RunConfig(
            local_dir=str(tmpdir),
            name=exp_name,
            checkpoint_config=CheckpointConfig(num_to_keep=1),
            # Simulate a crash after the second training iteration.
            callbacks=[FailureInjectionCallback(num_iters=2)],
        ),
    )

    # The injected failure surfaces from fit() as a TuneError; the result of
    # the failed run is not needed, so don't bind it.
    with pytest.raises(TuneError):
        trainer.fit()

    # Restore from the experiment directory and finish the remaining epochs.
    trainer = LightningTrainer.restore(str(tmpdir / exp_name))
    result = trainer.fit()

    assert not result.error
    assert result.metrics["training_iteration"] == 5
    # 2 iterations ran before the failure, so 3 remain after restoring.
    assert result.metrics["iterations_since_restore"] == 3
    assert tmpdir / exp_name in result.log_dir.parents


if __name__ == "__main__":
    import sys

    # `pytest` is already imported at module scope; the duplicate local
    # import has been removed.
    sys.exit(pytest.main(["-v", "-x", __file__]))
43 changes: 0 additions & 43 deletions python/ray/train/tests/test_trainer_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from ray.train.torch import TorchTrainer
from ray.train.xgboost import XGBoostTrainer
from ray.train.lightgbm import LightGBMTrainer
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray.train.tests.lightning_test_utils import LinearModule, DummyDataModule
from ray.train.huggingface import HuggingFaceTrainer
from ray.train.rl import RLTrainer
from ray.tune import Callback, TuneError
Expand Down Expand Up @@ -217,47 +215,6 @@ def test_trainer_with_init_fn_restore(ray_start_4_cpus, tmpdir, trainer_cls):
assert tmpdir / exp_name in result.log_dir.parents


def test_lightning_trainer_restore(ray_start_4_cpus, tmpdir):
    """Tests restore for LightningTrainer. Same success criteria as above."""
    exp_name = "lightning_trainer_restore_test"

    # In-memory dataloaders from a dummy datamodule (batch size 8, 256 samples).
    datamodule = DummyDataModule(8, 256)
    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()

    lightning_config = (
        LightningConfigBuilder()
        .module(LinearModule, input_dim=32, output_dim=4)
        .trainer(max_epochs=5, accelerator="cpu")
        .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
        .build()
    )

    scaling_config = ray.air.ScalingConfig(num_workers=2, use_gpu=False)

    trainer = LightningTrainer(
        lightning_config=lightning_config,
        scaling_config=scaling_config,
        run_config=RunConfig(
            local_dir=str(tmpdir),
            name=exp_name,
            checkpoint_config=CheckpointConfig(num_to_keep=1),
            # Simulates a crash after the second training iteration.
            callbacks=[FailureInjectionCallback(num_iters=2)],
        ),
    )

    # The injected failure surfaces from fit() as a TuneError.
    with pytest.raises(TuneError):
        result = trainer.fit()

    # Restore from the experiment directory and finish the remaining epochs.
    trainer = LightningTrainer.restore(str(tmpdir / exp_name))
    result = trainer.fit()

    assert not result.error
    assert result.metrics["training_iteration"] == 5
    # 2 iterations ran before the failure, so 3 remain after restoring.
    assert result.metrics["iterations_since_restore"] == 3
    assert tmpdir / exp_name in result.log_dir.parents


def test_rl_trainer_restore(ray_start_4_cpus, tmpdir):
"""Tests restore for RL trainer. Same success criteria as above."""

Expand Down

0 comments on commit 55918b9

Please sign in to comment.