[AIR] <Part 3> Add LightningPredictor to support batch prediction (ray-project#33196)

Signed-off-by: woshiyyya <xiaoyunxuan1998@gmail.com>
Signed-off-by: Jack He <jackhe2345@gmail.com>
woshiyyya authored and ProjectsByJackHe committed May 4, 2023
1 parent 2445f47 commit a0b8938
Showing 6 changed files with 339 additions and 8 deletions.
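
For context before the per-file diffs: the new LightningPredictor is meant to plug into Ray AIR's batch prediction flow. The following is a minimal usage sketch, not part of the diff, assuming a checkpoint saved by an earlier ``LightningTrainer`` run and the ``BatchPredictor`` API of Ray 2.x; ``MyLightningModule`` and the checkpoint path are hypothetical.

    import numpy as np
    import ray
    from ray.train.batch_predictor import BatchPredictor
    from ray.train.lightning import LightningCheckpoint, LightningPredictor

    # Hypothetical: a checkpoint directory written by a LightningTrainer run,
    # and MyLightningModule, the pl.LightningModule subclass that was trained.
    checkpoint = LightningCheckpoint.from_directory("/tmp/lightning_ckpt")

    batch_predictor = BatchPredictor.from_checkpoint(
        checkpoint,
        LightningPredictor,
        model_class=MyLightningModule,  # a class type, not an instance
        use_gpu=False,
    )

    # Distributed batch prediction over a Ray Dataset.
    ds = ray.data.from_numpy(np.random.rand(32, 3).astype(np.float32))
    predictions = batch_predictor.predict(ds)
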
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -432,6 +432,14 @@ py_test(
     deps = [":train_lib"]
 )

+py_test(
+    name = "test_lightning_predictor",
+    size = "medium",
+    srcs = ["tests/test_lightning_predictor.py"],
+    tags = ["team:ml", "exclusive", "ray_air", "gpu"],
+    deps = [":train_lib"]
+)
+
 py_test(
     name = "test_minimal",
     size = "small",

8 changes: 7 additions & 1 deletion python/ray/train/lightning/__init__.py
@@ -9,9 +9,15 @@
 # isort: on

 from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
+from ray.train.lightning.lightning_predictor import LightningPredictor
 from ray.train.lightning.lightning_trainer import (
     LightningTrainer,
     LightningConfigBuilder,
 )

-__all__ = ["LightningTrainer", "LightningConfigBuilder", "LightningCheckpoint"]
+__all__ = [
+    "LightningTrainer",
+    "LightningConfigBuilder",
+    "LightningCheckpoint",
+    "LightningPredictor",
+]
15 changes: 10 additions & 5 deletions python/ray/train/lightning/lightning_checkpoint.py
@@ -5,7 +5,7 @@
 import shutil

 from inspect import isclass
-from typing import Optional, Type
+from typing import Optional, Type, Dict, Any

 from ray.air.constants import MODEL_KEY
 from ray.air._internal.checkpointing import save_preprocessor_to_dir
@@ -67,15 +67,17 @@ def from_path(
         return checkpoint

     def get_model(
-        self, model_class: Type[pl.LightningModule], **load_from_checkpoint_kwargs
+        self,
+        model_class: Type[pl.LightningModule],
+        load_from_checkpoint_kwargs: Optional[Dict[str, Any]] = None,
     ) -> pl.LightningModule:
         """Retrieve the model stored in this checkpoint.

         Args:
             model_class: A subclass of ``pytorch_lightning.LightningModule`` that
-                defines your model and training logic.
-            load_from_checkpoint_kwargs: Arguments to pass into
-                ``pl.Trainer.load_from_checkpoint``
+                defines your model and training logic.
+            load_from_checkpoint_kwargs: A dictionary of arguments to pass into
+                ``pl.LightningModule.load_from_checkpoint``

         Returns:
             pl.LightningModule: An instance of the loaded model.
@@ -85,6 +87,9 @@ def get_model(
                 "'model_class' must be a class, not an instantiated Lightning trainer."
             )

+        if not load_from_checkpoint_kwargs:
+            load_from_checkpoint_kwargs = {}
+
         with self.as_directory() as checkpoint_dir:
             ckpt_path = os.path.join(checkpoint_dir, MODEL_KEY)
             if not os.path.exists(ckpt_path):

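The signature change above alters call sites: ``load_from_checkpoint`` arguments now travel in an explicit dictionary instead of ``**kwargs``. A brief before/after sketch (``MyLightningModule`` and its ``input_dim`` argument are illustrative):

    # Before: extra arguments were forwarded implicitly via **kwargs.
    #     model = checkpoint.get_model(MyLightningModule, input_dim=32)

    # After: the same arguments are passed as an explicit dictionary, which
    # keeps them out of get_model's own signature.
    model = checkpoint.get_model(
        model_class=MyLightningModule,
        load_from_checkpoint_kwargs={"input_dim": 32},
    )
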
117 changes: 117 additions & 0 deletions python/ray/train/lightning/lightning_predictor.py
@@ -0,0 +1,117 @@
+import logging
+from typing import Optional, Type, Dict, Any
+
+from ray.data.preprocessor import Preprocessor
+from ray.train.lightning.lightning_checkpoint import LightningCheckpoint
+from ray.train.torch.torch_predictor import TorchPredictor
+from ray.util.annotations import PublicAPI
+import pytorch_lightning as pl
+
+logger = logging.getLogger(__name__)
+
+
+@PublicAPI(stability="alpha")
+class LightningPredictor(TorchPredictor):
+    """A predictor for PyTorch Lightning modules.
+
+    Example:
+        .. testcode:: python
+
+            import torch
+            import numpy as np
+            import pytorch_lightning as pl
+            from ray.train.lightning import LightningPredictor
+
+            class MyModel(pl.LightningModule):
+                def __init__(self, input_dim, output_dim):
+                    super().__init__()
+                    self.linear = torch.nn.Linear(input_dim, output_dim)
+
+                def forward(self, x):
+                    out = self.linear(x)
+                    return out
+
+                def training_step(self, batch, batch_idx):
+                    x, y = batch
+                    y_hat = self.forward(x)
+                    loss = torch.nn.functional.mse_loss(y_hat, y)
+                    self.log("train_loss", loss)
+                    return loss
+
+                def configure_optimizers(self):
+                    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+                    return optimizer
+
+            batch_size, input_dim, output_dim = 10, 3, 5
+            model = MyModel(input_dim=input_dim, output_dim=output_dim)
+            predictor = LightningPredictor(model=model, use_gpu=False)
+
+            batch = np.random.rand(batch_size, input_dim).astype(np.float32)
+
+            # Internally, LightningPredictor.predict() invokes the forward() method
+            # of the model to generate predictions
+            output = predictor.predict(batch)
+            assert output["predictions"].shape == (batch_size, output_dim)
+
+        .. testoutput::
+            :hide:
+            :options: +ELLIPSIS
+
+    Args:
+        model: The PyTorch Lightning module to use for predictions.
+        preprocessor: A preprocessor used to transform data batches prior
+            to prediction.
+        use_gpu: If set, the model will be moved to GPU on instantiation and
+            prediction happens on GPU.
+    """
+
+    def __init__(
+        self,
+        model: pl.LightningModule,
+        preprocessor: Optional["Preprocessor"] = None,
+        use_gpu: bool = False,
+    ):
+        super(LightningPredictor, self).__init__(
+            model=model, preprocessor=preprocessor, use_gpu=use_gpu
+        )
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        checkpoint: LightningCheckpoint,
+        model_class: Type[pl.LightningModule],
+        *,
+        preprocessor: Optional[Preprocessor] = None,
+        use_gpu: bool = False,
+        load_from_checkpoint_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "LightningPredictor":
+        """Instantiate the LightningPredictor from a Checkpoint.
+
+        The checkpoint is expected to be a result of ``LightningTrainer``.
+
+        Args:
+            checkpoint: The checkpoint to load the model and preprocessor from.
+                It is expected to be from the result of a ``LightningTrainer`` run.
+            model_class: A subclass of ``pytorch_lightning.LightningModule`` that
+                defines your model and training logic. Note that this is a class type
+                instead of a model instance.
+            preprocessor: A preprocessor used to transform data batches prior
+                to prediction.
+            use_gpu: If set, the model will be moved to GPU on instantiation and
+                prediction happens on GPU.
+            load_from_checkpoint_kwargs: A dictionary of arguments to pass into
+                ``pl.LightningModule.load_from_checkpoint``
+        """
+        if not load_from_checkpoint_kwargs:
+            load_from_checkpoint_kwargs = {}
+
+        model = checkpoint.get_model(
+            model_class=model_class,
+            load_from_checkpoint_kwargs=load_from_checkpoint_kwargs,
+        )
+        return cls(model=model, preprocessor=preprocessor, use_gpu=use_gpu)

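As a quick usage sketch of the new ``from_checkpoint`` path (``checkpoint`` and ``MyLightningModule`` are hypothetical, carried over from a ``LightningTrainer`` run; ``map_location`` is a standard PyTorch Lightning ``load_from_checkpoint`` argument):

    import numpy as np

    # Hypothetical: `checkpoint` is a LightningCheckpoint from a training run
    # and MyLightningModule is the trained pl.LightningModule subclass.
    predictor = LightningPredictor.from_checkpoint(
        checkpoint=checkpoint,
        model_class=MyLightningModule,
        use_gpu=False,
        load_from_checkpoint_kwargs={"map_location": "cpu"},
    )

    batch = np.random.rand(8, 3).astype(np.float32)
    output = predictor.predict(batch)  # predict() runs the module's forward()
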
59 changes: 57 additions & 2 deletions python/ray/train/tests/lightning_test_utils.py
@@ -1,7 +1,9 @@
-import pytorch_lightning as pl
-import torch.nn as nn
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
 from torch.utils.data import DataLoader
+from torchmetrics import Accuracy


 class LinearModule(pl.LightningModule):
@@ -78,3 +80,56 @@ def train_dataloader(self):

     def val_dataloader(self):
         return DataLoader(self.val_data, batch_size=self.batch_size)
+
+
+class LightningMNISTClassifier(pl.LightningModule):
+    def __init__(self, lr: float, layer_1: int, layer_2: int):
+        super(LightningMNISTClassifier, self).__init__()
+
+        self.lr = lr
+        # mnist images are (1, 28, 28) (channels, width, height)
+        self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
+        self.layer_2 = torch.nn.Linear(layer_1, layer_2)
+        self.layer_3 = torch.nn.Linear(layer_2, 10)
+        self.accuracy = Accuracy()
+
+    def forward(self, x):
+        batch_size, channels, width, height = x.size()
+        x = x.view(batch_size, -1)
+        x = self.layer_1(x)
+        x = torch.relu(x)
+        x = self.layer_2(x)
+        x = torch.relu(x)
+        x = self.layer_3(x)
+        x = torch.log_softmax(x, dim=1)
+        return x
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.lr)
+
+    def training_step(self, train_batch, batch_idx):
+        x, y = train_batch
+        logits = self.forward(x)
+        loss = F.nll_loss(logits, y)
+        acc = self.accuracy(logits, y)
+        self.log("ptl/train_loss", loss)
+        self.log("ptl/train_accuracy", acc)
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        x, y = val_batch
+        logits = self.forward(x)
+        loss = F.nll_loss(logits, y)
+        acc = self.accuracy(logits, y)
+        return {"val_loss": loss, "val_accuracy": acc}
+
+    def validation_epoch_end(self, outputs):
+        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
+        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
+        self.log("ptl/val_loss", avg_loss)
+        self.log("ptl/val_accuracy", avg_acc)
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=None):
+        x = batch
+        logits = self.forward(x)
+        return torch.argmax(logits, dim=-1)

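One subtlety in this fixture: ``LightningPredictor.predict`` invokes ``forward`` (as noted in the predictor's docstring), not ``predict_step``, so predictions come back as log-softmax outputs rather than the argmax labels ``predict_step`` returns. A sketch of exercising the fixture with the predictor (shapes and hyperparameters are illustrative):

    import numpy as np
    from ray.train.lightning import LightningPredictor

    model = LightningMNISTClassifier(lr=1e-3, layer_1=64, layer_2=32)
    predictor = LightningPredictor(model=model, use_gpu=False)

    # forward() unpacks (batch, channels, width, height), so feed 4D input.
    batch = np.random.rand(4, 1, 28, 28).astype(np.float32)
    output = predictor.predict(batch)

    log_probs = output["predictions"]       # (4, 10) log-softmax scores from forward()
    labels = np.argmax(log_probs, axis=-1)  # equivalent to what predict_step returns
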
140 changes: 140 additions & 0 deletions python/ray/train/tests/test_lightning_predictor.py
