Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] <Part 3> Add LightningPredictor to support batch prediction #33196

Merged
merged 21 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,14 @@ py_test(
deps = [":train_lib"]
)

# Test target for tests/test_lightning_predictor.py. It carries the "gpu" tag;
# the test module parametrizes its cases over use_gpu=True/False.
py_test(
name = "test_lightning_predictor",
size = "medium",
srcs = ["tests/test_lightning_predictor.py"],
tags = ["team:ml", "exclusive", "ray_air", "gpu"],
deps = [":train_lib"]
)

py_test(
name = "test_minimal",
size = "small",
Expand Down
8 changes: 7 additions & 1 deletion python/ray/train/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,15 @@
# isort: on

from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.lightning.lightning_predictor import LightningPredictor
from ray.train.lightning.lightning_trainer import (
LightningTrainer,
LightningConfigBuilder,
)

# Public names re-exported by this package. A single assignment is kept: the
# earlier three-name list was overwritten immediately and had no effect.
__all__ = [
    "LightningTrainer",
    "LightningConfigBuilder",
    "LightningCheckpoint",
    "LightningPredictor",
]
4 changes: 2 additions & 2 deletions python/ray/train/lightning/lightning_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def get_model(

Args:
model_class: A subclass of ``pytorch_lightning.LightningModule`` that
defines your model and training logic.
defines your model and training logic.
load_from_checkpoint_kwargs: Arguments to pass into
``pl.Trainer.load_from_checkpoint``
``pl.LightningModule.load_from_checkpoint``

Returns:
pl.LightningModule: An instance of the loaded model.
Expand Down
111 changes: 111 additions & 0 deletions python/ray/train/lightning/lightning_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import logging
from typing import Optional, Type

from ray.data.preprocessor import Preprocessor
from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
from ray.train.torch.torch_predictor import TorchPredictor
from ray.util.annotations import PublicAPI
import pytorch_lightning as pl

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class LightningPredictor(TorchPredictor):
    """Predictor that runs inference with a PyTorch Lightning module.

    The class is a thin layer over ``TorchPredictor``: the constructor hands
    its arguments to the parent unchanged, and ``from_checkpoint`` first
    restores a ``pl.LightningModule`` from a ``LightningCheckpoint``.

    Example:
        .. testcode:: python

            import torch
            import numpy as np
            import pytorch_lightning as pl
            from ray.train.lightning import LightningPredictor


            class MyModel(pl.LightningModule):
                def __init__(self, input_dim, output_dim):
                    super().__init__()
                    self.linear = torch.nn.Linear(input_dim, output_dim)

                def forward(self, x):
                    out = self.linear(x)
                    return out

                def training_step(self, batch, batch_idx):
                    x, y = batch
                    y_hat = self.forward(x)
                    loss = torch.nn.functional.mse_loss(y_hat, y)
                    self.log("train_loss", loss)
                    return loss

                def configure_optimizers(self):
                    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
                    return optimizer


            batch_size, input_dim, output_dim = 10, 3, 5
            model = MyModel(input_dim=input_dim, output_dim=output_dim)
            predictor = LightningPredictor(model=model, use_gpu=False)
            batch = np.random.rand(batch_size, input_dim).astype(np.float32)

            # Internally, LightningPredictor.predict() invokes the forward() method
            # of the model to generate predictions
            output = predictor.predict(batch)

            assert output["predictions"].shape == (batch_size, output_dim)


        .. testoutput::
            :hide:
            :options: +ELLIPSIS

    Args:
        model: The PyTorch Lightning module to use for predictions.
        preprocessor: A preprocessor used to transform data batches prior
            to prediction.
        use_gpu: If set, the model will be moved to GPU on instantiation and
            prediction happens on GPU.
    """

    def __init__(
        self,
        model: pl.LightningModule,
        preprocessor: Optional["Preprocessor"] = None,
        use_gpu: bool = False,
    ):
        # All three arguments are forwarded to ``TorchPredictor`` as keywords.
        super().__init__(model=model, preprocessor=preprocessor, use_gpu=use_gpu)

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: LightningCheckpoint,
        model: Type[pl.LightningModule],
        *,
        preprocessor: Optional[Preprocessor] = None,
        use_gpu: bool = False,
        **load_from_checkpoint_kwargs,
    ) -> "LightningPredictor":
        """Build a ``LightningPredictor`` from a ``LightningCheckpoint``.

        The checkpoint is expected to be a result of ``LightningTrainer``.

        Args:
            checkpoint: The checkpoint to restore the model from. It is
                expected to come from a ``LightningTrainer`` run.
            model: A subclass of ``pytorch_lightning.LightningModule`` that
                defines your model and training logic. This is the class
                itself, not an instance of it.
            preprocessor: A preprocessor used to transform data batches prior
                to prediction. It is passed to the predictor exactly as
                given; this method does not read one from ``checkpoint``.
            use_gpu: If set, the model will be moved to GPU on instantiation
                and prediction happens on GPU.
            **load_from_checkpoint_kwargs: Arguments to pass into
                ``pl.LightningModule.load_from_checkpoint``.

        Returns:
            A predictor wrapping the module restored from ``checkpoint``.
        """
        # NOTE(review): a reviewer suggested forwarding the kwargs as one
        # ``load_from_checkpoint_kwargs=...`` argument instead of unpacking
        # them. Whether that works depends on the signature of
        # ``LightningCheckpoint.get_model``, which is not visible here, so the
        # call is left as it was.
        loaded_module = checkpoint.get_model(model, **load_from_checkpoint_kwargs)
        return cls(model=loaded_module, preprocessor=preprocessor, use_gpu=use_gpu)
59 changes: 57 additions & 2 deletions python/ray/train/tests/lightning_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytorch_lightning as pl
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchmetrics import Accuracy


class LinearModule(pl.LightningModule):
Expand Down Expand Up @@ -78,3 +80,56 @@ def train_dataloader(self):

def val_dataloader(self):
return DataLoader(self.val_data, batch_size=self.batch_size)


class LightningMNISTClassifier(pl.LightningModule):
    """Fully connected three-layer classifier used by the Lightning tests.

    Args:
        lr: Learning rate given to the Adam optimizer.
        layer_1: Output width of the first linear layer.
        layer_2: Output width of the second linear layer.
    """

    def __init__(self, lr: float, layer_1: int, layer_2: int):
        super().__init__()

        self.lr = lr
        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = nn.Linear(28 * 28, layer_1)
        self.layer_2 = nn.Linear(layer_1, layer_2)
        self.layer_3 = nn.Linear(layer_2, 10)
        self.accuracy = Accuracy()

    def forward(self, x):
        """Flatten a 4-D batch and return log-probabilities over 10 classes."""
        # Unpacking four sizes means inputs that are not 4-D raise here.
        batch_size, _channels, _width, _height = x.size()
        hidden = x.view(batch_size, -1)
        hidden = torch.relu(self.layer_1(hidden))
        hidden = torch.relu(self.layer_2(hidden))
        return torch.log_softmax(self.layer_3(hidden), dim=1)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def training_step(self, train_batch, batch_idx):
        """Compute NLL loss and accuracy for one batch, log both, return loss."""
        inputs, targets = train_batch
        log_probs = self.forward(inputs)
        loss = F.nll_loss(log_probs, targets)
        acc = self.accuracy(log_probs, targets)
        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", acc)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Return the per-batch validation loss and accuracy as a dict."""
        inputs, targets = val_batch
        log_probs = self.forward(inputs)
        return {
            "val_loss": F.nll_loss(log_probs, targets),
            "val_accuracy": self.accuracy(log_probs, targets),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation results and log the means."""
        losses = [step["val_loss"] for step in outputs]
        accuracies = [step["val_accuracy"] for step in outputs]
        avg_loss = torch.stack(losses).mean()
        avg_acc = torch.stack(accuracies).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the arg-max class index for every sample in ``batch``."""
        return torch.argmax(self.forward(batch), dim=-1)
140 changes: 140 additions & 0 deletions python/ray/train/tests/test_lightning_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import re
import pytest
import torch

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader

import ray
from ray.air.constants import MAX_REPR_LENGTH, MODEL_KEY
from ray.air.util.data_batch_conversion import convert_batch_type_to_pandas
from ray.train.tests.conftest import * # noqa
from ray.train.batch_predictor import BatchPredictor
from ray.train.lightning import LightningCheckpoint, LightningPredictor
from ray.train.tests.dummy_preprocessor import DummyPreprocessor
from ray.train.tests.lightning_test_utils import LightningMNISTClassifier


def test_repr():
    """Check that the predictor's ``repr`` is short and well-formed.

    The text must be shorter than ``MAX_REPR_LENGTH`` and look like
    ``LightningPredictor(...)``.
    """
    predictor = LightningPredictor(pl.LightningModule())
    text = repr(predictor)

    assert len(text) < MAX_REPR_LENGTH
    assert re.match("^LightningPredictor\\((.*)\\)$", text)


def save_checkpoint(model: pl.LightningModule, ckpt_path: str):
    """Write a Lightning checkpoint for ``model`` to ``ckpt_path``.

    A trainer created with ``max_epochs=0`` is fitted on a one-element random
    dataloader, then asked to save its checkpoint.

    Args:
        model: The Lightning module to attach to the trainer.
        ckpt_path: Destination path handed to ``Trainer.save_checkpoint``.
    """
    trainer = pl.Trainer(max_epochs=0)
    placeholder_loader = DataLoader(torch.randn(1))
    trainer.fit(model, train_dataloaders=placeholder_loader)
    trainer.save_checkpoint(ckpt_path)


@pytest.mark.parametrize(
    "checkpoint_source", ["from_path", "from_uri", "from_directory"]
)
@pytest.mark.parametrize("use_gpu", [True, False])
@pytest.mark.parametrize("use_preprocessor", [True, False])
def test_predictor(
    mock_s3_bucket_uri,
    tmpdir,
    checkpoint_source: str,
    use_preprocessor: bool,
    use_gpu: bool,
):
    """Test LightningPredictor instantiation and prediction step."""
    hparams = dict(layer_1=32, layer_2=64, lr=1e-4)

    # Produce a native Lightning checkpoint file under the temp directory.
    ckpt_path = str(tmpdir / MODEL_KEY)
    save_checkpoint(LightningMNISTClassifier(**hparams), ckpt_path)

    # The path-based checkpoint is always built first; the other two sources
    # replace it with one re-read from a URI or from the directory.
    checkpoint = LightningCheckpoint.from_path(ckpt_path)
    if checkpoint_source == "from_uri":
        checkpoint.to_uri(mock_s3_bucket_uri)
        checkpoint = LightningCheckpoint.from_uri(mock_s3_bucket_uri)
    elif checkpoint_source == "from_directory":
        checkpoint = LightningCheckpoint.from_directory(tmpdir)

    preprocessor = None
    if use_preprocessor:
        preprocessor = DummyPreprocessor()

    # The model hyperparameters travel as load_from_checkpoint kwargs.
    predictor = LightningPredictor.from_checkpoint(
        checkpoint=checkpoint,
        model=LightningMNISTClassifier,
        use_gpu=use_gpu,
        preprocessor=preprocessor,
        **hparams,
    )

    # Synthetic input batch of 10 float32 arrays shaped (1, 28, 28).
    num_samples = 10
    batch = np.random.rand(num_samples, 1, 28, 28).astype(np.float32)

    output = predictor.predict(batch)

    assert len(output["predictions"]) == num_samples
    if preprocessor:
        assert predictor.get_preprocessor().has_preprocessed


@pytest.mark.parametrize("use_gpu", [True, False])
def test_batch_predictor(tmpdir, use_gpu: bool):
    """Test batch prediction with a LightningPredictor."""
    num_rows = 32

    # Synthetic dataset: random float32 "image" arrays plus integer labels.
    images = np.random.rand(num_rows, 1, 28, 28).astype(np.float32)
    labels = np.random.randint(0, 10, (num_rows,))
    frame = convert_batch_type_to_pandas({"image": images, "label": labels})
    ds = ray.data.from_pandas(frame)

    # Create a PTL native checkpoint
    ckpt_path = str(tmpdir / MODEL_KEY)
    hparams = dict(layer_1=32, layer_2=64, lr=1e-4)
    save_checkpoint(LightningMNISTClassifier(**hparams), ckpt_path)

    # Create a LightningCheckpoint from the native checkpoint
    checkpoint = LightningCheckpoint.from_path(ckpt_path)

    batch_predictor = BatchPredictor(
        checkpoint=checkpoint,
        predictor_cls=LightningPredictor,
        use_gpu=use_gpu,
        model=LightningMNISTClassifier,
        **hparams,
    )

    predictions = batch_predictor.predict(
        ds,
        feature_columns=["image"],
        keep_columns=["label"],
        batch_size=8,
        min_scoring_workers=2,
        max_scoring_workers=2,
        # One GPU per scoring worker when use_gpu is set, otherwise none.
        num_gpus_per_worker=int(use_gpu),
    )

    # Every input row should produce exactly one output row.
    assert predictions.count() == num_rows


if __name__ == "__main__":
    import sys

    import pytest

    # Run this file under pytest (verbose, stop at first failure) and use
    # pytest's return value as the process exit status.
    exit_code = pytest.main(["-v", "-x", __file__])
    sys.exit(exit_code)