Commit fb7ed82: attempt mypy fix
sumanthratna committed Jan 28, 2021
1 parent: b52fa07
Showing 2 changed files with 35 additions and 30 deletions.
tests/models/data/horovod/train_default_model.py (16 changes: 9 additions & 7 deletions)

@@ -21,31 +21,33 @@
 import os
 import sys
 
+from pytorch_lightning import Trainer  # noqa: E402
+from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
+from pytorch_lightning.utilities import HOROVOD_AVAILABLE  # noqa: E402
+from tests.base import EvalModelTemplate  # noqa: E402
+from tests.base.develop_pipelines import run_prediction  # noqa: E402
+from tests.base.develop_utils import (reset_seed,  # noqa: E402
+                                      set_random_master_port)
+
 # this is needed because Conda does not use `PYTHONPATH` env var while pip and virtualenv do
 PYTHONPATH = os.getenv('PYTHONPATH', '')
 if ':' in PYTHONPATH:
     sys.path = PYTHONPATH.split(':') + sys.path
 
-from pytorch_lightning import Trainer  # noqa: E402
-from pytorch_lightning.callbacks import ModelCheckpoint  # noqa: E402
-from pytorch_lightning.utilities import HOROVOD_AVAILABLE  # noqa: E402
 
 if HOROVOD_AVAILABLE:
     import horovod.torch as hvd  # noqa: E402
 else:
     print('You requested to import Horovod which is missing or not supported for your OS.')
 
-from tests.base import EvalModelTemplate  # noqa: E402
-from tests.base.develop_pipelines import run_prediction  # noqa: E402
-from tests.base.develop_utils import set_random_master_port, reset_seed  # noqa: E402
 
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--trainer-options', required=True)
 parser.add_argument('--on-gpu', action='store_true', default=False)
 
 
-def run_test_from_config(trainer_options):
+def run_test_from_config(trainer_options) -> None:
     """Trains the default model with the given config."""
     set_random_master_port()
     reset_seed()
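A note on the new return-type annotations: by default, mypy does not type-check the body of a function that carries no annotations at all, so adding even a bare -> None opts the function into checking. A minimal sketch of the effect (the function names are illustrative, not from this repository):

    def unchecked(x):
        return x + "1"  # mypy stays silent: the def carries no annotations

    def checked(x: int) -> None:
        return x + "1"  # mypy flags the int + str operands and the unexpected return value

The same checking can be enabled for unannotated functions globally with mypy's --check-untyped-defs flag.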
tests/trainer/logging/test_logger_connector.py (49 changes: 26 additions & 23 deletions)

@@ -15,22 +15,25 @@
 Tests to ensure that the training loop works with a dict (1.0)
 """
 from copy import deepcopy
+from typing import Any, Callable, Dict, List, Tuple, TypeVar
 
 import pytest
 import torch
+from torch.utils.data import DataLoader
 
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer import Trainer
-from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
+from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import \
+    CallbackHookNameValidator
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.boring_model import BoringModel, RandomDataset
-from torch.utils.data import DataLoader
 
+F = TypeVar('F', bound=Callable[..., Any])
 
-def decorator_with_arguments(fx_name='', hook_fx_name=None):
-    def decorator(func):
-        def wrapper(self, *args, **kwargs):
+def decorator_with_arguments(fx_name='', hook_fx_name=None) -> Callable[[F], F]:
+    def decorator(func: F) -> F:
+        def wrapper(self, *args, **kwargs) -> Any:
             # Set information
             self._current_fx_name = fx_name
             self._current_hook_fx_name = hook_fx_name
@@ -47,7 +50,7 @@ def wrapper(self, *args, **kwargs):
     return decorator
 
 
-def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch):
+def test__logger_connector__epoch_result_store__train(tmpdir, monkeypatch) -> None:
     """
     Tests that LoggerConnector will properly capture logged information
     and reduce them
@@ -59,7 +62,7 @@ class TestModel(BoringModel):
         train_losses = []
 
         @decorator_with_arguments(fx_name="training_step")
-        def training_step(self, batch, batch_idx):
+        def training_step(self, batch, batch_idx) -> Dict[str, Any]:
             output = self.layer(batch)
             loss = self.loss(batch, output)
 
@@ -69,7 +72,7 @@ def training_step(self, batch, batch_idx):
 
             return {"loss": loss}
 
-        def training_step_end(self, *_):
+        def training_step_end(self, *_) -> None:
             self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
 
     model = TestModel()
@@ -105,7 +108,7 @@ def training_step_end(self, *_):
     assert generated == excepted
 
 
-def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
+def test__logger_connector__epoch_result_store__train__ttbt(tmpdir) -> None:
     """
     Tests that LoggerConnector will properly capture logged information with ttbt
    and reduce them
@@ -118,23 +121,23 @@ def test__logger_connector__epoch_result_store__train__ttbt(tmpdir):
     y_seq_list = torch.rand(batch_size, sequence_size, 1).tolist()
 
     class MockSeq2SeqDataset(torch.utils.data.Dataset):
-        def __getitem__(self, i):
+        def __getitem__(self, i) -> Tuple[Any, List[List[List[float]]]]:
            return x_seq, y_seq_list
 
-        def __len__(self):
+        def __len__(self) -> int:
            return 1
 
     class TestModel(BoringModel):
 
        train_losses = []
 
-        def __init__(self):
+        def __init__(self) -> None:
            super().__init__()
            self.test_hidden = None
            self.layer = torch.nn.Linear(2, 2)
 
        @decorator_with_arguments(fx_name="training_step")
-        def training_step(self, batch, batch_idx, hiddens):
+        def training_step(self, batch, batch_idx, hiddens) -> Dict[str, Any]:
            self.test_hidden = torch.rand(1)
 
            x_tensor, y_list = batch
@@ -155,15 +158,15 @@ def training_step(self, batch, batch_idx, hiddens):
        def on_train_epoch_start(self) -> None:
            self.test_hidden = None
 
-        def train_dataloader(self):
+        def train_dataloader(self) -> Any:
            return torch.utils.data.DataLoader(
                dataset=MockSeq2SeqDataset(),
                batch_size=batch_size,
                shuffle=False,
                sampler=None,
            )
 
-        def training_step_end(self, *_):
+        def training_step_end(self, *_) -> None:
            self.train_results = deepcopy(self.trainer.logger_connector.cached_results)
 
    model = TestModel()
@@ -200,7 +203,7 @@ def training_step_end(self, *_):
 
 
 @pytest.mark.parametrize('num_dataloaders', [1, 2])
-def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders):
+def test__logger_connector__epoch_result_store__test_multi_dataloaders(tmpdir, monkeypatch, num_dataloaders) -> None:
    """
    Tests that LoggerConnector will properly capture logged information in multi_dataloaders scenario
    """
@@ -211,7 +214,7 @@ class TestModel(BoringModel):
        test_losses = {}
 
        @decorator_with_arguments(fx_name="test_step")
-        def test_step(self, batch, batch_idx, dl_idx=0):
+        def test_step(self, batch, batch_idx, dl_idx=0) -> Dict[str, Any]:
            output = self.layer(batch)
            loss = self.loss(batch, output)
 
@@ -221,15 +224,15 @@ def test_step(self, batch, batch_idx, dl_idx=0):
            self.log("test_loss", loss, on_step=True, on_epoch=True)
            return {"test_loss": loss}
 
-        def on_test_batch_end(self, *args, **kwargs):
+        def on_test_batch_end(self, *args, **kwargs) -> None:
            # save objects as it will be reset at the end of epoch.
            self.batch_results = deepcopy(self.trainer.logger_connector.cached_results)
 
-        def on_test_epoch_end(self):
+        def on_test_epoch_end(self) -> None:
            # save objects as it will be reset at the end of epoch.
            self.reduce_results = deepcopy(self.trainer.logger_connector.cached_results)
 
-        def test_dataloader(self):
+        def test_dataloader(self) -> List[Any]:
            return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]
 
    model = TestModel()
@@ -266,7 +269,7 @@ def test_dataloader(self):
    assert abs(expected.item() - generated.item()) < 1e-6
 
 
-def test_call_back_validator(tmpdir):
+def test_call_back_validator(tmpdir) -> None:
 
    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
 
@@ -368,7 +371,7 @@ def test_call_back_validator(tmpdir):
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires two GPUs")
-def test_epoch_results_cache_dp(tmpdir):
+def test_epoch_results_cache_dp(tmpdir) -> None:
 
    root_device = torch.device("cuda", 0)
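The F = TypeVar('F', bound=Callable[..., Any]) pattern introduced above is what lets mypy see through decorator_with_arguments: annotating the inner decorator as (func: F) -> F tells the checker that a decorated function keeps its original signature. A minimal, self-contained sketch of the same idiom (with_label and add are hypothetical names, not code from this commit):

    from typing import Any, Callable, TypeVar, cast

    F = TypeVar('F', bound=Callable[..., Any])

    def with_label(label: str = '') -> Callable[[F], F]:
        def decorator(func: F) -> F:
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                # illustrative side effect before delegating to the wrapped function
                print(f"[{label}] calling {func.__name__}")
                return func(*args, **kwargs)
            # wrapper is an ordinary Callable, not literally an F, so cast it back
            return cast(F, wrapper)
        return decorator

    @with_label("demo")
    def add(a: int, b: int) -> int:
        return a + b

Because the factory returns Callable[[F], F], mypy still sees add as (int, int) -> int after decoration, so a bad call such as add(1, "2") is reported even though it would pass through wrapper at runtime.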
