Flash predict step (#6577)
* add predict_step

* Update predict_loop.py

* Update trainer.py

* Update trainer.py

* resolve bugs

* update

* update

* update

* resolve bug

* resolve some failing tests

* update tests

* update

* resolve tests

* add a test

* remove typo

* add a test for attachment

* update

* changed to on_train_dataloader

* remove __flash_special_attr__

* resolve tests

* update

* update

* update

* update on comments

* Update pytorch_lightning/trainer/data_loading.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people committed Mar 23, 2021
1 parent a74909a commit 0995d30
Showing 17 changed files with 174 additions and 30 deletions.
6 changes: 3 additions & 3 deletions docs/source/starter/introduction_guide.rst
@@ -882,8 +882,8 @@ Or maybe we have a model that we use to do generation
generated_imgs = model(z)
To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict`` function
By default, LightningModule ``predict`` calls forward, but it can be overriden to add any processing logic.
To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict_step`` function
By default, LightningModule ``predict_step`` calls forward, but it can be overriden to add any processing logic.

.. code-block:: python
@@ -893,7 +893,7 @@ By default, LightningModule ``predict`` calls forward, but it can be overriden t
imgs = self.decoder(z)
return imgs
def predict(self, batch, batch_idx: int , dataloader_idx: int = None):
def predict_step(self, batch, batch_idx: int , dataloader_idx: int = None):
return self(batch)
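To make the rename concrete, here is a minimal sketch of how the new predict_step pairs with trainer.predict. The LitGenerator module and the random-latent dataloader are illustrative assumptions, not part of this diff, and the trainer.predict(model, dataloaders=...) call assumes the public entry point available around this release.

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer

class LitGenerator(LightningModule):
    """Illustrative module: decode latent vectors into flattened images."""

    def __init__(self):
        super().__init__()
        self.decoder = torch.nn.Linear(64, 28 * 28)

    def forward(self, z):
        return self.decoder(z)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = None):
        # by default predict_step simply calls forward; post-processing goes here
        return self(batch)

    def configure_optimizers(self):
        # not used by trainer.predict, included only so the module is complete
        return torch.optim.Adam(self.parameters(), lr=1e-3)

if __name__ == "__main__":
    latents = DataLoader(torch.randn(8, 64), batch_size=4)  # hypothetical inputs
    predictions = Trainer(max_epochs=1).predict(LitGenerator(), dataloaders=latents)
    print(len(predictions))  # one decoded tensor per batch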
11 changes: 8 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -219,7 +219,7 @@ def test_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
with self.precision_plugin.test_step_context(), self.training_type_plugin.test_step_context():
return self.training_type_plugin.test_step(*args)

def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
def predict_step(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
"""The actual predict step.
Args:
@@ -235,7 +235,7 @@ def predict(self, args: List[Union[Any, int]]) -> _STEP_OUTPUT_TYPE:
args[0] = batch

with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
return self.training_type_plugin.predict(*args)
return self.training_type_plugin.predict_step(*args)

def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
"""A hook to do something at the end of the training step
@@ -359,7 +359,12 @@ def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None:

def to_device(self, batch: Any) -> Any:
"""Pushes the batch to the root device"""
return self.batch_to_device(batch, self.root_device)
# Todo (tchaton) Better fix
is_dict = isinstance(batch, dict)
if is_dict:
batch = [batch]
batch = self.batch_to_device(batch, self.root_device)
return batch[0] if is_dict else batch

@property
def amp_backend(self) -> Optional[LightningEnum]:
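The to_device change above is explicitly a stopgap (see the TODO): dict batches are wrapped in a one-element list before the transfer and unwrapped afterwards, so the moving utility treats them like any other collection. A standalone sketch of the same pattern, using move_data_to_device from pytorch_lightning.utilities as a stand-in for the accelerator's batch_to_device:

import torch
from pytorch_lightning.utilities import move_data_to_device

def to_device_sketch(batch, device):
    # mirror the workaround: wrap dicts, move everything, then unwrap
    is_dict = isinstance(batch, dict)
    if is_dict:
        batch = [batch]
    batch = move_data_to_device(batch, device)
    return batch[0] if is_dict else batch

sample = {"pixels": torch.zeros(2, 3), "labels": torch.tensor([0, 1])}
moved = to_device_sketch(sample, torch.device("cpu"))
print({name: tensor.device for name, tensor in moved.items()})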
24 changes: 24 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -282,6 +282,18 @@ def on_test_end(self) -> None:
"""
# do something at the end of testing

def on_predict_start(self) -> None:
"""
Called at the beginning of predicting.
"""
# do something at the start of predicting

def on_predict_end(self) -> None:
"""
Called at the end of predicting.
"""
# do something at the end of predicting

def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""
Called after optimizer.step() and before optimizer.zero_grad().
@@ -594,6 +606,18 @@ def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
will have an argument ``dataloader_idx`` which matches the order here.
"""

def on_train_dataloader(self) -> None:
"""Called before requesting the train dataloader."""

def on_val_dataloader(self) -> None:
"""Called before requesting the val dataloader."""

def on_test_dataloader(self) -> None:
"""Called before requesting the test dataloader."""

def on_predict_dataloader(self) -> None:
"""Called before requesting the predict dataloader."""

def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
"""
Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
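Together, the new hooks give user code a place to run logic around prediction and immediately before each dataloader is requested. A minimal sketch of a module overriding them; the print statements stand in for real setup or teardown work:

from pytorch_lightning import LightningModule

class HookedModule(LightningModule):

    def on_predict_start(self) -> None:
        # e.g. switch the model into an inference-friendly configuration
        print("starting prediction")

    def on_predict_end(self) -> None:
        print("finished prediction")

    def on_train_dataloader(self) -> None:
        # called just before the trainer requests train_dataloader(),
        # so the loader can still be swapped or wrapped at this point
        print("train dataloader about to be requested")

    def on_predict_dataloader(self) -> None:
        print("predict dataloader about to be requested")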
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1054,7 +1054,7 @@ def test_epoch_end(self, outputs):
self.log('final_metric', final_value)
"""

def predict(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None):
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None):
"""
Use this function with trainer.predict(...). Override if you need to add any processing logic.
"""
2 changes: 1 addition & 1 deletion pytorch_lightning/overrides/base.py
@@ -53,7 +53,7 @@ def forward(self, *inputs, **kwargs):
elif trainer and (trainer.sanity_checking or trainer.validating):
output = self.module.validation_step(*inputs, **kwargs)
elif trainer and trainer.predicting:
output = self.module.predict(*inputs, **kwargs)
output = self.module.predict_step(*inputs, **kwargs)
else:
output = self.module(*inputs, **kwargs)

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -298,7 +298,7 @@ def validation_step(self, *args, **kwargs):
def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def predict(self, *args, **kwargs):
def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def post_training_step(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -282,7 +282,7 @@ def validation_step(self, *args, **kwargs):
def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def predict(self, *args, **kwargs):
def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def post_training_step(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/dp.py
@@ -83,7 +83,7 @@ def validation_step(self, *args, **kwargs):
def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def predict(self, *args, **kwargs):
def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def training_step_end(self, output):
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -294,8 +294,8 @@ def validation_step(self, *args, **kwargs):
def test_step(self, *args, **kwargs):
return self.lightning_module.test_step(*args, **kwargs)

def predict(self, *args, **kwargs):
return self.lightning_module.predict(*args, **kwargs)
def predict_step(self, *args, **kwargs):
return self.lightning_module.predict_step(*args, **kwargs)

def save_checkpoint(self, filepath, weights_only: bool = False):
"""Save model/training states as a checkpoint file through state-dump and file-write.
@@ -154,8 +154,8 @@ def validation_step(self, *args, **kwargs):
def test_step(self, *args, **kwargs):
return self.lightning_module.test_step(*args, **kwargs)

def predict(self, *args, **kwargs):
return self.lightning_module.predict(*args, **kwargs)
def predict_step(self, *args, **kwargs):
return self.lightning_module.predict_step(*args, **kwargs)

def training_step_end(self, output):
return output
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -150,6 +150,10 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N
self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer

# experimental feature for Flash
if hasattr(datamodule, "data_pipeline"):
model.data_pipeline = datamodule.data_pipeline


class _PatchDataLoader(object):
r"""
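The attach_datamodule addition is marked experimental and exists for Lightning Flash: when the attached datamodule exposes a data_pipeline attribute, it is copied onto the model. A hedged sketch of the shape this relies on; IdentityPipeline and FlashStyleDataModule are illustrative names, not Flash's actual classes:

from pytorch_lightning import LightningDataModule

class IdentityPipeline:
    """Stand-in for a Flash-style pre/post-processing pipeline."""

    def __call__(self, sample):
        return sample

class FlashStyleDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        # the only thing the data connector checks for is this attribute
        self.data_pipeline = IdentityPipeline()

# after e.g. trainer.fit(model, datamodule=FlashStyleDataModule()), the connector
# sets model.data_pipeline = datamodule.data_pipeline so the model can reuse it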
15 changes: 8 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -16,7 +16,7 @@
import platform
from abc import ABC
from copy import deepcopy
from typing import Callable, Iterable, List, Tuple, Union
from typing import Iterable, List, Tuple, Union

from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
@@ -191,7 +191,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
Args:
model: The current `LightningModule`
"""
self.train_dataloader = self.request_dataloader(model.train_dataloader)
self.train_dataloader = self.request_dataloader(model, "train")

if self.overfit_batches > 0:
if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
@@ -271,7 +271,7 @@ def _reset_eval_dataloader(
"""
# always get the loaders first so we can count how many there are
loader_name = f'{mode}_dataloader'
dataloaders = self.request_dataloader(getattr(model, loader_name))
dataloaders = self.request_dataloader(model, mode)

if not isinstance(dataloaders, list):
dataloaders = [dataloaders]
@@ -280,7 +280,7 @@
# duplicate it the numb of times needed to match the train loaders
if self.overfit_batches > 0:
num_loaders = len(dataloaders)
train_dataloader = self.request_dataloader(getattr(model, 'train_dataloader'))
train_dataloader = self.request_dataloader(model, 'train')
dataloaders = [deepcopy(train_dataloader) for _ in range(num_loaders)]

self.dev_debugger.track_load_dataloader_call(loader_name, dataloaders=dataloaders)
@@ -380,7 +380,7 @@ def reset_predict_dataloader(self, model) -> None:
if has_loader:
self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(model, 'predict')

def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader:
"""Handles downloading data in the GPU or TPU case.
Args:
@@ -389,9 +389,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
Returns:
The dataloader
"""
dataloader = dataloader_fx()
if model.trainer is not None:
model.trainer.call_hook(f"on_{stage}_dataloader")
dataloader: DataLoader = getattr(model, f'{stage}_dataloader')()
dataloader = self._flatten_dl_only(dataloader)

self.accelerator.barrier('get_dataloaders')
return dataloader

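Passing the model and a stage name into request_dataloader is what makes the on_{stage}_dataloader hooks fire before the corresponding {stage}_dataloader() is fetched; the test added at the bottom of this commit exercises exactly that. A small sketch of the pattern, assuming a module that lazily replaces its own train loader (the dataset and batch sizes are arbitrary):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule

class LazyLoaderModule(LightningModule):

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)

    def on_train_dataloader(self) -> None:
        # runs right before request_dataloader calls train_dataloader(), so a
        # replacement loader (wrapped, re-batched, ...) can be installed here
        original = self.train_dataloader()
        self.train_dataloader = lambda: DataLoader(original.dataset, batch_size=4)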
12 changes: 10 additions & 2 deletions pytorch_lightning/trainer/predict_loop.py
@@ -65,7 +65,7 @@ def _get_num_dataloaders(self, dataloaders):
length = len(dataloaders[0])
return length

def predict(self, batch, batch_idx, dataloader_idx):
def predict_step(self, batch, batch_idx, dataloader_idx):
# configure args
args = [batch, batch_idx]
if self.num_dataloaders:
@@ -74,7 +74,7 @@ def predict(self, batch, batch_idx, dataloader_idx):
model_ref = self.trainer.lightning_module

model_ref._current_fx_name = "predict"
predictions = self.trainer.accelerator.predict(args)
predictions = self.trainer.accelerator.predict_step(args)

if predictions is None:
self.warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")
@@ -99,3 +99,11 @@ def _convert_to_numpy(v):
return results[0]

return results

def on_predict_start(self):
# hook
self.trainer.call_hook("on_predict_start")

def on_predict_end(self):
# hook
self.trainer.call_hook("on_predict_end")
8 changes: 5 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -762,6 +762,8 @@ def run_evaluate(self):
return eval_loop_results

def run_predict(self):
self.predict_loop.on_predict_start()

# prepare dataloaders
dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()

@@ -784,7 +786,6 @@ def run_predict(self):
for dataloader_idx, dataloader in enumerate(dataloaders):
dataloader = self.accelerator.process_dataloader(dataloader)
dl_max_batches = self.predict_loop.max_batches[dataloader_idx]

for batch_idx, batch in enumerate(dataloader):
if batch is None:
continue
@@ -794,10 +795,11 @@ def run_predict(self):
break

# lightning module methods
with self.profiler.profile("predict"):
self.predict_loop.predict(batch, batch_idx, dataloader_idx)
with self.profiler.profile("predict_step"):
self.predict_loop.predict_step(batch, batch_idx, dataloader_idx)

results = self.predict_loop.on_predict_epoch_end()
self.predict_loop.on_predict_end()
return results

def run_sanity_check(self, ref_model):
2 changes: 1 addition & 1 deletion tests/overrides/test_data_parallel.py
@@ -24,7 +24,7 @@
("training", "training_step"),
("testing", "test_step"),
("validating", "validation_step"),
("predicting", "predict"),
("predicting", "predict_step"),
]
)
def test_lightning_wrapper_module_methods(wrapper_class, stage):
68 changes: 68 additions & 0 deletions tests/trainer/test_dataloaders.py
@@ -1159,3 +1159,71 @@ def test_replace_sampler_with_multiprocessing_context(tmpdir):

new_data_loader = trainer.replace_sampler(train, SequentialSampler(train.dataset))
assert (new_data_loader.multiprocessing_context == train.multiprocessing_context)


def test_request_dataloader(tmpdir):
"""
This test asserts dataloader can be modified and properly set to the trainer.
"""

class DataLoaderWrapper:

def __init__(self, loader):
self.loader = loader
self._iter = iter(self.loader)

def __iter__(self):
self._iter = iter(self.loader)
return self._iter

def __next__(self):
return next(self._iter)

class DataLoaderFunc:

def __init__(self, loader):
self.loader = loader

def __call__(self):
return self.loader

class TestModel(BoringModel):

def __init__(self):
super().__init__()
self.on_train_dataloader_called = False
self.on_train_batch_start_called = False
self.on_val_dataloader_called = False
self.on_val_batch_start_called = False

def on_train_dataloader(self) -> None:
loader = self.train_dataloader()
self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader))
self.on_train_dataloader_called = True

def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
self.on_train_batch_start_called = True

def on_val_dataloader(self) -> None:
loader = self.val_dataloader()
self.val_dataloader = DataLoaderFunc(DataLoaderWrapper(loader))
self.on_val_dataloader_called = True

def on_validation_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
assert isinstance(self.trainer.val_dataloaders[0], DataLoaderWrapper)
self.on_val_batch_start_called = True

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
)
model = TestModel()
trainer.fit(model)
trainer.test(model)
assert model.on_train_dataloader_called
assert model.on_train_batch_start_called
assert model.on_val_dataloader_called
assert model.on_val_batch_start_called