Skip to content

Commit

Permalink
TPUSpawn + IterableDataset error message (#6875)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
ethanwharris and carmocca authored Apr 8, 2021
1 parent 87f0aea commit 1c2ecbf
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 3 deletions.
46 changes: 43 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
import os
import re
import time
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

import torch
import torch.multiprocessing as mp

from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
Expand All @@ -40,14 +42,51 @@
from omegaconf import DictConfig, ListConfig, OmegaConf


if TYPE_CHECKING:
from torch.nn import Module
from torch.utils.data import DataLoader


class TPUSpawnPlugin(DDPSpawnPlugin):

def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
self.tpu_local_core_rank = 0
self.start_method = None

def setup(self, model: torch.nn.Module) -> torch.nn.Module:
@staticmethod
def _validate_dataloader(dataloaders: Union[List['DataLoader'], 'DataLoader']):
if not isinstance(dataloaders, list):
dataloaders = [dataloaders]

for dataloader in dataloaders:
if not has_len(dataloader):
raise MisconfigurationException(
"TPUs do not currently support IterableDataset objects, the dataset must implement `__len__`."
" HINT: You can mock the length on your dataset to bypass this MisconfigurationException."
)

@staticmethod
def _validate_patched_dataloaders(model: 'Module') -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit.
"""
if hasattr(model, 'train_dataloader') and isinstance(model.train_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.train_dataloader.dataloader)

if hasattr(model, 'val_dataloader') and isinstance(model.val_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.val_dataloader.dataloader)

if hasattr(model, 'test_dataloader') and isinstance(model.test_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.test_dataloader.dataloader)

if hasattr(model, 'predict_dataloader') and isinstance(model.predict_dataloader, _PatchDataLoader):
TPUSpawnPlugin._validate_dataloader(model.predict_dataloader.dataloader)

def connect(self, model: 'Module') -> None:
TPUSpawnPlugin._validate_patched_dataloaders(model)
return super().connect(model)

def setup(self, model: 'Module') -> 'Module':
self.create_mp_queue()
return self.model

Expand All @@ -64,7 +103,8 @@ def distributed_sampler_kwargs(self) -> dict:
def is_distributed(self):
return self.world_size != 1

def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
def process_dataloader(self, dataloader: 'DataLoader') -> MpDeviceLoader:
TPUSpawnPlugin._validate_dataloader(dataloader)
device = xm.xla_device()
dataloader = MpDeviceLoader(dataloader, device)
return dataloader
Expand Down
74 changes: 74 additions & 0 deletions tests/plugins/test_tpu_spawn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.plugins.training_type import TPUSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader


class BoringModelNoDataloaders(BoringModel):
def train_dataloader(self):
raise NotImplementedError

def val_dataloader(self):
raise NotImplementedError

def test_dataloader(self):
raise NotImplementedError

def predict_dataloader(self):
raise NotImplementedError


_loader = DataLoader(RandomDataset(32, 64))
_loader_no_len = CustomNotImplementedErrorDataloader(_loader)


@pytest.mark.parametrize(
"train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders",
[
(_loader_no_len, None, None, None),
(None, _loader_no_len, None, None),
(None, None, _loader_no_len, None),
(None, None, None, _loader_no_len),
(None, [_loader, _loader_no_len], None, None),
],
)
def test_error_patched_iterable_dataloaders(
tmpdir, train_dataloader, val_dataloaders, test_dataloaders, predict_dataloaders
):
model = BoringModelNoDataloaders()
connector = DataConnector(MagicMock())

connector.attach_dataloaders(
model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders,
test_dataloaders=test_dataloaders,
predict_dataloaders=predict_dataloaders,
)

with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
TPUSpawnPlugin(MagicMock()).connect(model)


def test_error_process_iterable_dataloader(tmpdir):
with pytest.raises(MisconfigurationException, match="TPUs do not currently support"):
TPUSpawnPlugin(MagicMock()).process_dataloader(_loader_no_len)

0 comments on commit 1c2ecbf

Please sign in to comment.