From 661405dac034e87f14410e90d14bf6b69c06880b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 23 Feb 2020 02:24:59 +0100 Subject: [PATCH 01/32] first pass for LightningModule typehints --- pytorch_lightning/core/lightning.py | 110 ++++++++++++++++------------ 1 file changed, 65 insertions(+), 45 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 1c1373a0b6c61..7d24343b4317b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -5,11 +5,14 @@ import warnings from abc import ABC, abstractmethod from argparse import Namespace -from typing import Optional, Union, Dict, Callable +from typing import Any, Union, Tuple, List, Optional, Callable, Dict import torch import torch.distributed as dist -from torch.optim import Adam +from torch import Tensor +from torch.nn import Module +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader from pytorch_lightning.core.decorators import data_loader from pytorch_lightning.core.grads import GradInformation @@ -89,7 +92,7 @@ def forward(self, x): log.info(*args, **kwargs) @abstractmethod - def forward(self, *args, **kwargs): + def forward(self, *args: Any, **kwargs: Any) -> Any: r""" Same as torch.nn.Module.forward(), however in Lightning you want this to define the operations you want to use for prediction (ie: on a server or as a feature extractor). @@ -138,7 +141,7 @@ def forward(self, batch): """ - def training_step(self, *args, **kwargs): + def training_step(self, *args, **kwargs) -> dict: r"""return loss, dict with metrics for tqdm Args: @@ -219,7 +222,7 @@ def training_step(self, batch, batch_idx, hiddens): if you want to break out of the current training epoch early. """ - def training_end(self, *args, **kwargs): + def training_end(self, outputs: dict) -> dict: """return loss, dict with metrics for tqdm :param outputs: What you return in `training_step`. @@ -292,7 +295,7 @@ def training_step(self, batch, batch_idx, hiddens): break out of the current training epoch early. """ - def validation_step(self, *args, **kwargs): + def validation_step(self, *args, **kwargs) -> dict: r""" This is the validation loop. It is called for each batch of the validation set. @@ -364,7 +367,7 @@ def validation_step(self, batch, batch_idx, dataset_idx): have been disabled. At the end of validation, model goes back to training mode and gradients are enabled. """ - def test_step(self, *args, **kwargs): + def test_step(self, *args, **kwargs) -> dict: """return whatever outputs will need to be aggregated in test_end :param batch: The output of your dataloader. A tensor, tuple or list :param int batch_idx: Integer displaying which batch this is @@ -433,7 +436,7 @@ def test_step(self, batch, batch_idx, dataset_idx): The `dataset_idx` corresponds to the order of datasets returned in `test_dataloader`. """ - def validation_end(self, outputs): + def validation_end(self, outputs: list) -> dict: """Outputs has the appended output after each validation step. :param outputs: List of outputs you defined in validation_step, or if there are multiple dataloaders, @@ -505,7 +508,7 @@ def validation_end(self, outputs): """ - def test_end(self, outputs): + def test_end(self, outputs: list) -> dict: """Outputs has the appended output after each test step. 
:param outputs: List of outputs you defined in test_step, or if there are multiple dataloaders, @@ -569,7 +572,7 @@ def test_end(self, outputs): """ - def configure_ddp(self, model, device_ids): + def configure_ddp(self, model: 'LightningModule', device_ids: list) -> Module: r""" Override to init DDP in your own way or with your own wrapper. @@ -580,8 +583,8 @@ def configure_ddp(self, model, device_ids): 3. On a testing batch, the call goes to model.test_step Args: - model (:class:`.LightningModule`): the LightningModule currently being optimized - device_ids (list): the list of GPU ids + model: the LightningModule currently being optimized + device_ids: the list of GPU ids Return: DDP wrapped model @@ -609,7 +612,7 @@ def configure_ddp(self, model, device_ids): ) return model - def init_ddp_connection(self, proc_rank, world_size): + def init_ddp_connection(self, proc_rank: int, world_size: int): r""" Override to define your custom way of setting up a distributed environment. @@ -617,8 +620,8 @@ def init_ddp_connection(self, proc_rank, world_size): Lightning's implementation uses env:// init by default and sets the first node as root. Args: - proc_rank (int): The current process rank within the node. - world_size (int): Number of GPUs being use across all nodes. (num_nodes*nb_gpu_nodes). + proc_rank: The current process rank within the node. + world_size: Number of GPUs being use across all nodes. (num_nodes*nb_gpu_nodes). Example ------- .. code-block:: python @@ -687,16 +690,22 @@ def init_ddp_connection(self): os.environ['MASTER_ADDR'] = root_node dist.init_process_group('nccl', rank=proc_rank, world_size=world_size) - def configure_apex(self, amp, model, optimizers, amp_level): + def configure_apex( + self, + amp: object, + model: 'LightningModule', + optimizers: List[Optimizer], + amp_level: str + ): r""" Override to init AMP your own way Must return a model and list of optimizers Args: - amp (object): pointer to amp library object - model (:class:`.LightningModule`): pointer to current lightningModule - optimizers (list): list of optimizers passed in configure_optimizers() - amp_level (str): AMP mode chosen ('O1', 'O2', etc...) + amp: pointer to amp library object + model: pointer to current lightningModule + optimizers: list of optimizers passed in configure_optimizers() + amp_level: AMP mode chosen ('O1', 'O2', etc...) Return: Apex wrapped model and optimizers @@ -719,7 +728,9 @@ def configure_apex(self, amp, model, optimizers, amp_level): return model, optimizers - def configure_optimizers(self): + def configure_optimizers(self) -> Union[ + Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], list] + ]: r""" This is where you choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple. @@ -773,18 +784,25 @@ def configure_optimizers(self): """ return Adam(self.parameters(), lr=1e-3) - def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None): + def optimizer_step( + self, + epoch: int, + batch_idx: int, + optimizer: Optimizer, + optimizer_idx: int, + second_order_closure: Optional[Callable] = None, + ): r""" Override this method to adjust the default way the Trainer calls each optimizer. By default, Lightning calls .step() and zero_grad() as shown in the example once per optimizer. 
Args: - epoch (int): Current epoch - batch_idx (int): Index of current batch - optimizer (torch.nn.Optimizer): A PyTorch optimizer - optimizer_idx (int): If you used multiple optimizers this indexes into that list - second_order_closure (int): closure for second order methods + epoch: Current epoch + batch_idx: Index of current batch + optimizer: A PyTorch optimizer + optimizer_idx: If you used multiple optimizers this indexes into that list + second_order_closure: closure for second order methods Example ------- @@ -840,15 +858,15 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, sec # clear gradients optimizer.zero_grad() - def tbptt_split_batch(self, batch, split_size): + def tbptt_split_batch(self, batch: Tensor, split_size: int) -> List[Tensor]: r""" When using truncated backpropagation through time, each batch must be split along the time dimension. Lightning handles this by default, but for custom behavior override this function. Args: - batch (torch.nn.Tensor): Current batch - split_size (int): How big the split is + batch: Current batch + split_size: How big the split is Return: list of batch splits. Each split will be passed to forward_step to enable truncated @@ -929,7 +947,7 @@ def prepare_data(self): """ return None - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """Implement a PyTorch DataLoader :return: PyTorch DataLoader @@ -962,7 +980,6 @@ def train_dataloader(self): return loader """ - return None @data_loader def tng_dataloader(self): # todo: remove in v0.8.0 @@ -975,7 +992,7 @@ def tng_dataloader(self): # todo: remove in v0.8.0 " and this method will be removed in v0.8.0", DeprecationWarning) return output - def test_dataloader(self): + def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]: r""" Return a dataloader. It will not be called every epoch unless you set @@ -993,7 +1010,7 @@ def test_dataloader(self): .. note:: Lightning adds the correct sampler for distributed and arbitrary hardware. No need to set yourself. Return: - PyTorch DataLoader + Single or multiple PyTorch DataLoader Example ------- @@ -1016,9 +1033,8 @@ def test_dataloader(self): .. note:: If you want to change the data during every epoch DON'T use the data_loader decorator. """ - return None - def val_dataloader(self): + def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]: r""" Return a dataloader. It will not be called every epoch unless you set @@ -1035,7 +1051,7 @@ def val_dataloader(self): .. note:: Lightning adds the correct sampler for distributed and arbitrary hardware No need to set yourself. Return: - PyTorch DataLoader + Single or multiple PyTorch DataLoader Example ------- @@ -1086,10 +1102,14 @@ def val_dataloader(self): .. note:: In the case where you return multiple `val_dataloaders`, the `validation_step` will have an argument `dataset_idx` which matches the order here. """ - return None @classmethod - def load_from_metrics(cls, weights_path, tags_csv, map_location=None): + def load_from_metrics( + cls, + weights_path: str, + tags_csv: str, + map_location: Optional[Dict[str, str]] = None + ) -> 'LightningModule': r""" Warning: Deprecated in version 0.7.0. 
@@ -1226,7 +1246,7 @@ def _load_model_state(cls, checkpoint): return model - def summarize(self, mode): + def summarize(self, mode: str): model_summary = ModelSummary(self, mode=mode) log.info('\n' + model_summary.__str__()) @@ -1261,13 +1281,13 @@ def unfreeze(self): self.train() - def on_load_checkpoint(self, checkpoint): + def on_load_checkpoint(self, checkpoint: dict): r""" Called by lightning to restore your model. If you saved something with **on_save_checkpoint** this is your chance to restore this. Args: - checkpoint (dict): Loaded checkpoint + checkpoint: Loaded checkpoint Example @@ -1283,14 +1303,14 @@ def on_load_checkpoint(self, checkpoint): No need for you to restore anything regarding training. """ - def on_save_checkpoint(self, checkpoint): + def on_save_checkpoint(self, checkpoint: dict): r""" Called by lightning when saving a checkpoint to give you a chance to store anything else you might want to save Args: - checkpoint (dic): Checkpoint to be saved + checkpoint: Checkpoint to be saved Example ------- @@ -1306,7 +1326,7 @@ def on_save_checkpoint(self, checkpoint): """ - def get_tqdm_dict(self): + def get_tqdm_dict(self) -> dict: r""" Additional items to be displayed in the progress bar. From 9b5e7840288cbabb0397537358fee456ed736600 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 23 Feb 2020 19:01:19 +0100 Subject: [PATCH 02/32] fix return types --- pytorch_lightning/core/lightning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7d24343b4317b..7272e3dc4359e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -696,7 +696,7 @@ def configure_apex( model: 'LightningModule', optimizers: List[Optimizer], amp_level: str - ): + ) -> Tuple['LightningModule', List[Optimizer]]: r""" Override to init AMP your own way Must return a model and list of optimizers @@ -858,7 +858,7 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, optimizer_idx, sec # clear gradients optimizer.zero_grad() - def tbptt_split_batch(self, batch: Tensor, split_size: int) -> List[Tensor]: + def tbptt_split_batch(self, batch: Tensor, split_size: int) -> list: r""" When using truncated backpropagation through time, each batch must be split along the time dimension. 
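For orientation, a minimal sketch of a user-defined module written against the signatures annotated in PATCH 01-02 above. The class name, layer sizes, and random stand-in dataset are illustrative assumptions (nothing here is taken from the patches), and the dict-style `training_step` return follows the 0.7-era docstrings:

.. code-block:: python

    from typing import Dict, Union

    import torch
    from torch import Tensor
    from torch.nn import functional as F
    from torch.optim import Adam
    from torch.optim.optimizer import Optimizer
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl


    class LitClassifier(pl.LightningModule):
        """Toy classifier used only to exercise the hinted hook signatures."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(28 * 28, 10)

        def forward(self, x: Tensor) -> Tensor:
            # flatten images and project to class logits
            return self.layer(x.view(x.size(0), -1))

        def training_step(self, batch, batch_idx: int) -> Dict[str, Union[Tensor, dict]]:
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            # a 'loss' tensor plus optional 'log' / 'progress_bar' dicts
            return {'loss': loss, 'log': {'train_loss': loss}}

        def configure_optimizers(self) -> Optimizer:
            # single-optimizer case of the Union return type
            return Adam(self.parameters(), lr=1e-3)

        def train_dataloader(self) -> DataLoader:
            # random stand-in data; swap in a real dataset
            data = TensorDataset(torch.randn(256, 1, 28, 28),
                                 torch.randint(0, 10, (256,)))
            return DataLoader(data, batch_size=32)

Under the same assumptions, `pl.Trainer(max_epochs=1).fit(LitClassifier())` would exercise each of these hooks with the annotated argument and return types.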
From dc89b90c48321500c0a20269370cbfd98ee6f377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 21:27:59 +0100 Subject: [PATCH 03/32] add missing types --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 7272e3dc4359e..5ddac94251231 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1213,7 +1213,7 @@ def __init__(self, hparams): return model @classmethod - def _load_model_state(cls, checkpoint): + def _load_model_state(cls, checkpoint: dict) -> 'LightningModule': cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters ckpt_hparams = checkpoint.get('hparams') From d3460cb955d33f662872bb50e60447023462284a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 21:32:56 +0100 Subject: [PATCH 04/32] add type annotations to grads.py --- pytorch_lightning/core/grads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py index d98a6efeb85ca..08d99b78ec9c1 100644 --- a/pytorch_lightning/core/grads.py +++ b/pytorch_lightning/core/grads.py @@ -7,7 +7,7 @@ class GradInformation(nn.Module): - def grad_norm(self, norm_type): + def grad_norm(self, norm_type: float) -> dict: results = {} total_norm = 0 for name, p in self.named_parameters(): From cdfd1d10385b2a0b0847c01b6eeb1a2aae7aae56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 21:35:57 +0100 Subject: [PATCH 05/32] add type annotations to hooks.py --- pytorch_lightning/core/hooks.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 500ba247f1414..9e97503e55559 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -14,10 +14,13 @@ 3. Add the correct place in the :py:mod:`pytorch_lightning.models.trainer` where it should be called. """ - +from typing import Any import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer +from pytorch_lightning import Trainer try: from apex import amp @@ -49,7 +52,7 @@ def on_train_end(self): """ # do something at the end of training - def on_batch_start(self, batch): + def on_batch_start(self, batch: Any): """Called in the training loop before anything happens for that batch. :param batch: @@ -77,7 +80,7 @@ def on_post_performance_check(self): """Called at the very end of the validation loop.""" # do something before validation end - def on_before_zero_grad(self, optimizer): + def on_before_zero_grad(self, optimizer: Optimizer): """Called after optimizer.step() and before optimizer.zero_grad() Called in the training loop after taking an optimizer step and before zeroing grads. 
@@ -116,7 +119,7 @@ def on_after_backward(self): """ - def backward(self, trainer, loss, optimizer, optimizer_idx): + def backward(self, trainer: Trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int): """Override backward with your own implementation if you need to :param trainer: Pointer to the trainer From 3526a77b7a16afc2cec4c30fc3924cc4abadbe79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 21:49:01 +0100 Subject: [PATCH 06/32] add type annotation to memory.py --- pytorch_lightning/core/memory.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 967710d11d594..58c10af67201f 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -7,14 +7,17 @@ import os import subprocess from subprocess import PIPE +from typing import Tuple, Dict, Union import numpy as np import torch +from pytorch_lightning import LightningModule + class ModelSummary(object): - def __init__(self, model, mode='full'): + def __init__(self, model: LightningModule, mode: str = 'full'): ''' Generates summaries of model layers and dimensions. ''' @@ -31,7 +34,7 @@ def __str__(self): def __repr__(self): return self.summary.__str__() - def named_modules(self): + def named_modules(self) -> list: if self.mode == 'full': mods = self.model.named_modules() mods = list(mods)[1:] # do not include root module (LightningModule) @@ -159,7 +162,7 @@ def summarize(self): self.make_summary() -def _format_summary_table(*cols): +def _format_summary_table(*cols) -> str: ''' Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big @@ -213,7 +216,7 @@ def print_mem_stack(): # pragma: no cover pass -def count_mem_items(): # pragma: no cover +def count_mem_items() -> Tuple[int, int]: # pragma: no cover num_params = 0 num_tensors = 0 for obj in gc.get_objects(): @@ -230,7 +233,7 @@ def count_mem_items(): # pragma: no cover return num_params, num_tensors -def get_memory_profile(mode): +def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: """ 'all' means return memory for all gpus 'min_max' means return memory for max and min @@ -248,7 +251,7 @@ def get_memory_profile(mode): return memory_map -def get_gpu_memory_map(): +def get_gpu_memory_map() -> Dict[str, int]: """Get the current gpu usage. Returns @@ -273,7 +276,7 @@ def get_gpu_memory_map(): return gpu_memory_map -def get_human_readable_count(number): +def get_human_readable_count(number: int) -> str: """ Abbreviates an integer number with K, M, B, T for thousands, millions, billions and trillions, respectively. From 614256c0db320becfb081df0cb104d8d12a9ea50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 21:49:27 +0100 Subject: [PATCH 07/32] proper docstring quotation marks --- pytorch_lightning/core/memory.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 58c10af67201f..c3e9e3cd703eb 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -18,9 +18,7 @@ class ModelSummary(object): def __init__(self, model: LightningModule, mode: str = 'full'): - ''' - Generates summaries of model layers and dimensions. - ''' + """ Generates summaries of model layers and dimensions. 
""" self.model = model self.mode = mode self.in_sizes = [] @@ -46,7 +44,7 @@ def named_modules(self) -> list: return list(mods) def get_variable_sizes(self): - '''Run sample input through each layer to get output sizes''' + """ Run sample input through each layer to get output sizes """ mods = self.named_modules() in_sizes = [] out_sizes = [] @@ -102,7 +100,7 @@ def get_variable_sizes(self): assert len(in_sizes) == len(out_sizes) def get_layer_names(self): - '''Collect Layer Names''' + """ Collect Layer Names """ mods = self.named_modules() names = [] layers = [] @@ -116,7 +114,7 @@ def get_layer_names(self): self.layer_types = layer_types def get_parameter_sizes(self): - '''Get sizes of all parameters in `model`''' + """ Get sizes of all parameters in `model` """ mods = self.named_modules() sizes = [] for _, m in mods: @@ -127,7 +125,7 @@ def get_parameter_sizes(self): self.param_sizes = sizes def get_parameter_nums(self): - '''Get number of parameters in each layer''' + """ Get number of parameters in each layer """ param_nums = [] for mod in self.param_sizes: all_params = 0 @@ -137,11 +135,11 @@ def get_parameter_nums(self): self.param_nums = param_nums def make_summary(self): - ''' + """ Makes a summary listing with: Layer Name, Layer Type, Input Size, Output Size, Number of Parameters - ''' + """ arrays = [['Name', self.layer_names], ['Type', self.layer_types], ['Params', list(map(get_human_readable_count, self.param_nums))]] @@ -150,7 +148,6 @@ def make_summary(self): arrays.append(['Out sizes', self.out_sizes]) self.summary = _format_summary_table(*arrays) - return def summarize(self): self.get_layer_names() @@ -163,11 +160,11 @@ def summarize(self): def _format_summary_table(*cols) -> str: - ''' + """ Takes in a number of arrays, each specifying a column in the summary table, and combines them all into one big string defining the summary table that are nicely formatted. - ''' + """ n_rows = len(cols[0][1]) n_cols = 1 + len(cols) From e36f42adb82f32c155b982fa9f5e6612a871b18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 21:53:03 +0100 Subject: [PATCH 08/32] add type annotations to saving.py --- pytorch_lightning/core/saving.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index e8a1fac72dd84..2e7cf498354b9 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -2,11 +2,12 @@ import csv import logging as log from argparse import Namespace +from typing import Union class ModelIO(object): - def on_load_checkpoint(self, checkpoint): + def on_load_checkpoint(self, checkpoint: dict): """ Do something with the checkpoint Gives model a chance to load something before state_dict is restored @@ -14,7 +15,7 @@ def on_load_checkpoint(self, checkpoint): :return: """ - def on_save_checkpoint(self, checkpoint): + def on_save_checkpoint(self, checkpoint: dict): """ Give the model a chance to add something to the checkpoint. 
state_dict is already there @@ -23,20 +24,20 @@ def on_save_checkpoint(self, checkpoint): # ------------------------- # OPTIONAL HOOKS # ------------------------- - def on_hpc_save(self, checkpoint): + def on_hpc_save(self, checkpoint: dict): """ Hook to do whatever you need right before Slurm manager saves the model :return: """ - def on_hpc_load(self, checkpoint): + def on_hpc_load(self, checkpoint: dict): """ Hook to do whatever you need right before Slurm manager loads the model :return: """ -def load_hparams_from_tags_csv(tags_csv): +def load_hparams_from_tags_csv(tags_csv: str) -> Namespace: if not os.path.isfile(tags_csv): log.warning(f'Missing Tags: {tags_csv}.') return Namespace() @@ -50,7 +51,7 @@ def load_hparams_from_tags_csv(tags_csv): return ns -def convert(val): +def convert(val: str) -> Union[int, float, bool, str]: constructors = [int, float, str] if isinstance(val, str): From 0c86afebdee710df870f0ccf24ebd55abfdeee89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 22:13:39 +0100 Subject: [PATCH 09/32] fix cyclic import problem --- pytorch_lightning/core/hooks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9e97503e55559..ccf327d1ff858 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -20,7 +20,7 @@ from torch import Tensor from torch.optim.optimizer import Optimizer -from pytorch_lightning import Trainer +import pytorch_lightning as pl try: from apex import amp @@ -119,7 +119,7 @@ def on_after_backward(self): """ - def backward(self, trainer: Trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int): + def backward(self, trainer: 'pl.Trainer', loss: Tensor, optimizer: Optimizer, optimizer_idx: int): """Override backward with your own implementation if you need to :param trainer: Pointer to the trainer From 280a0d0dad5904107e49ffafe19dd47b395c366a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 22:16:30 +0100 Subject: [PATCH 10/32] fix cyclic import problem --- pytorch_lightning/core/memory.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index c3e9e3cd703eb..ecbe68aa12a38 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -12,12 +12,11 @@ import numpy as np import torch -from pytorch_lightning import LightningModule - +import pytorch_lightning as pl class ModelSummary(object): - def __init__(self, model: LightningModule, mode: str = 'full'): + def __init__(self, model: 'pl.LightningModule', mode: str = 'full'): """ Generates summaries of model layers and dimensions. 
""" self.model = model self.mode = mode From e433490e97e10a735ecd55a39eb43a135f48bef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Feb 2020 22:22:16 +0100 Subject: [PATCH 11/32] add missing whitespace --- pytorch_lightning/core/memory.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index ecbe68aa12a38..95bdb032d0562 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -14,6 +14,7 @@ import pytorch_lightning as pl + class ModelSummary(object): def __init__(self, model: 'pl.LightningModule', mode: str = 'full'): From 4d24fdf6e5b1dc8049b4cbc504d9f4f3bcf02c67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 06:05:29 +0100 Subject: [PATCH 12/32] finish type hints for load_from_ methods --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 5ddac94251231..4367d65fdcc56 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1108,7 +1108,7 @@ def load_from_metrics( cls, weights_path: str, tags_csv: str, - map_location: Optional[Dict[str, str]] = None + map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None ) -> 'LightningModule': r""" Warning: From c2e1cecf520cf04cd9a75cec4980dbf534568cfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 06:08:12 +0100 Subject: [PATCH 13/32] docs: prepare_data does not return anything --- pytorch_lightning/core/lightning.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4367d65fdcc56..cf2bc8923e4d5 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -924,8 +924,6 @@ def prepare_data(self): """Use this to download and prepare data. In distributed (GPU, TPU), this will only be called once - :return: PyTorch DataLoader - This is called before requesting the dataloaders .. code-block:: python From a878adb17e006ec5e61d6d925d3d571870f679d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 06:17:19 +0100 Subject: [PATCH 14/32] fix auto types in docs --- pytorch_lightning/core/lightning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index cf2bc8923e4d5..d61dd436d3803 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -226,7 +226,7 @@ def training_end(self, outputs: dict) -> dict: """return loss, dict with metrics for tqdm :param outputs: What you return in `training_step`. - :return dict: dictionary with loss key and optional log, progress keys: + :return: Dictionary with loss key and optional log, progress keys: - loss -> tensor scalar [REQUIRED] - progress_bar -> Dict for progress bar display. Must have only tensors - log -> Dict of metrics to add to logger. 
Must have only tensors (no images, etc) @@ -441,7 +441,7 @@ def validation_end(self, outputs: list) -> dict: :param outputs: List of outputs you defined in validation_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader - :return dict: Dictionary or OrderedDict with optional: + :return: Dictionary or OrderedDict with optional: progress_bar -> Dict for progress bar display. Must have only tensors log -> Dict of metrics to add to logger. Must have only tensors (no images, etc) @@ -513,7 +513,7 @@ def test_end(self, outputs: list) -> dict: :param outputs: List of outputs you defined in test_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader - :return dict: Dict of OrderedDict with metrics to display in progress bar + :return: Dict of OrderedDict with metrics to display in progress bar If you didn't define a test_step, this won't be called. Called at the end of the test step with the output of each test_step. From 35eb67c12832fdeab23c0f8d47ab1514d7e164be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 06:30:15 +0100 Subject: [PATCH 15/32] revert typehint for trainer in hook --- pytorch_lightning/core/hooks.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index ccf327d1ff858..cf8656136e6e1 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -20,8 +20,6 @@ from torch import Tensor from torch.optim.optimizer import Optimizer -import pytorch_lightning as pl - try: from apex import amp @@ -119,7 +117,7 @@ def on_after_backward(self): """ - def backward(self, trainer: 'pl.Trainer', loss: Tensor, optimizer: Optimizer, optimizer_idx: int): + def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int): """Override backward with your own implementation if you need to :param trainer: Pointer to the trainer From 0272309748ae3e5ca5fe979277d17fc56bb6a9e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 06:37:10 +0100 Subject: [PATCH 16/32] remove unnecessary return docs --- pytorch_lightning/core/hooks.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index cf8656136e6e1..e65538c94b753 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -39,14 +39,12 @@ def on_sanity_check_start(self): def on_train_start(self): """Called at the beginning of training before sanity check - :return: """ # do something at the start of training def on_train_end(self): """ Called at the end of training before logger experiment is closed - :return: """ # do something at the end of training @@ -54,7 +52,6 @@ def on_batch_start(self, batch: Any): """Called in the training loop before anything happens for that batch. :param batch: - :return: """ # do something when the batch starts @@ -90,17 +87,13 @@ def on_before_zero_grad(self, optimizer: Optimizer): model.on_before_zero_grad(optimizer) # < ---- called here optimizer.zero_grad - :param optimizer: - :return: + :param optimizer: The optimizer optimizer for which grads should be zeroed. """ # do something with the optimizer or inspect it. def on_after_backward(self): - """Called after loss.backward() and before optimizers do anything. + """Called in the training loop after loss.backward() and before optimizers do anything. 
- :return: - - Called in the training loop after model.backward() This is the ideal place to inspect or log gradient information .. code-block:: python @@ -124,7 +117,6 @@ def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: i :param loss: Loss is already scaled by accumulated grads :param optimizer: Current optimizer being used :param optimizer_idx: Index of the current optimizer being used - :return: Called to perform backward step. Feel free to override as needed. From 34f7c7dbf10794a4299e709c85a2686546dda545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 07:00:25 +0100 Subject: [PATCH 17/32] some fixes for memory docs --- pytorch_lightning/core/memory.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 95bdb032d0562..39b128142c933 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -231,10 +231,11 @@ def count_mem_items() -> Tuple[int, int]: # pragma: no cover def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: - """ - 'all' means return memory for all gpus - 'min_max' means return memory for max and min - :param mode: + """ Get a profile of the current memory usage. + + :param mode: There are two modes: + - 'all' means return memory for all gpus + - 'min_max' means return memory for max and min :return: """ memory_map = get_gpu_memory_map() @@ -251,11 +252,9 @@ def get_memory_profile(mode: str) -> Union[Dict[str, int], Dict[int, int]]: def get_gpu_memory_map() -> Dict[str, int]: """Get the current gpu usage. - Returns - ------- - usage: dict - Keys are device ids as integers. - Values are memory usage as integers in MB. + Return: + A dictionary in which the keys are device ids as integers and + values are memory usage as integers in MB. """ result = subprocess.run( [ From 2fa525f89d446d2b767840b9476dd672f7af7bbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 19:30:52 +0100 Subject: [PATCH 18/32] revert typing for args kwargs --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d61dd436d3803..9e154c8dcfa12 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -92,7 +92,7 @@ def forward(self, x): log.info(*args, **kwargs) @abstractmethod - def forward(self, *args: Any, **kwargs: Any) -> Any: + def forward(self, *args, **kwargs): r""" Same as torch.nn.Module.forward(), however in Lightning you want this to define the operations you want to use for prediction (ie: on a server or as a feature extractor). 
From 1187ef763099c6c9892b321a8632728956b90c88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 19:42:28 +0100 Subject: [PATCH 19/32] added all missing None return types --- pytorch_lightning/core/hooks.py | 24 ++++++++++++------------ pytorch_lightning/core/lightning.py | 19 +++++++++---------- pytorch_lightning/core/memory.py | 18 +++++++++--------- pytorch_lightning/core/saving.py | 10 ++++------ 4 files changed, 34 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index e65538c94b753..d903de21d887a 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -37,45 +37,45 @@ def on_sanity_check_start(self): :return: """ - def on_train_start(self): + def on_train_start(self) -> None: """Called at the beginning of training before sanity check """ # do something at the start of training - def on_train_end(self): + def on_train_end(self) -> None: """ Called at the end of training before logger experiment is closed """ # do something at the end of training - def on_batch_start(self, batch: Any): + def on_batch_start(self, batch: Any) -> None: """Called in the training loop before anything happens for that batch. :param batch: """ # do something when the batch starts - def on_batch_end(self): + def on_batch_end(self) -> None: """Called in the training loop after the batch.""" # do something when the batch ends - def on_epoch_start(self): + def on_epoch_start(self) -> None: """Called in the training loop at the very beginning of the epoch.""" # do something when the epoch starts - def on_epoch_end(self): + def on_epoch_end(self) -> None: """Called in the training loop at the very end of the epoch.""" # do something when the epoch ends - def on_pre_performance_check(self): + def on_pre_performance_check(self) -> None: """Called at the very beginning of the validation loop.""" # do something before validation starts - def on_post_performance_check(self): + def on_post_performance_check(self) -> None: """Called at the very end of the validation loop.""" # do something before validation end - def on_before_zero_grad(self, optimizer: Optimizer): + def on_before_zero_grad(self, optimizer: Optimizer) -> None: """Called after optimizer.step() and before optimizer.zero_grad() Called in the training loop after taking an optimizer step and before zeroing grads. @@ -87,11 +87,11 @@ def on_before_zero_grad(self, optimizer: Optimizer): model.on_before_zero_grad(optimizer) # < ---- called here optimizer.zero_grad - :param optimizer: The optimizer optimizer for which grads should be zeroed. + :param optimizer: The optimizer for which grads should be zeroed. """ # do something with the optimizer or inspect it. - def on_after_backward(self): + def on_after_backward(self) -> None: """Called in the training loop after loss.backward() and before optimizers do anything. 
This is the ideal place to inspect or log gradient information @@ -110,7 +110,7 @@ def on_after_backward(self): """ - def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int): + def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None: """Override backward with your own implementation if you need to :param trainer: Pointer to the trainer diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9e154c8dcfa12..23c3a4ef4aa40 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -71,7 +71,7 @@ def __init__(self, *args, **kwargs): #: True if using amp self.use_amp = False - def print(self, *args, **kwargs): + def print(self, *args, **kwargs) -> None: r""" Prints only from process 0. Use this in any distributed mode to log only once @@ -612,7 +612,7 @@ def configure_ddp(self, model, device_ids): ) return model - def init_ddp_connection(self, proc_rank: int, world_size: int): + def init_ddp_connection(self, proc_rank: int, world_size: int) -> None: r""" Override to define your custom way of setting up a distributed environment. @@ -791,7 +791,7 @@ def optimizer_step( optimizer: Optimizer, optimizer_idx: int, second_order_closure: Optional[Callable] = None, - ): + ) -> None: r""" Override this method to adjust the default way the Trainer calls each optimizer. By default, Lightning @@ -920,7 +920,7 @@ def tbptt_split_batch(self, batch, split_size): return splits - def prepare_data(self): + def prepare_data(self) -> None: """Use this to download and prepare data. In distributed (GPU, TPU), this will only be called once @@ -943,7 +943,6 @@ def prepare_data(self): clean_imagenet() cache_imagenet() """ - return None def train_dataloader(self) -> DataLoader: """Implement a PyTorch DataLoader @@ -1244,11 +1243,11 @@ def _load_model_state(cls, checkpoint: dict) -> 'LightningModule': return model - def summarize(self, mode: str): + def summarize(self, mode: str) -> None: model_summary = ModelSummary(self, mode=mode) log.info('\n' + model_summary.__str__()) - def freeze(self): + def freeze(self) -> None: r""" Freeze all params for inference @@ -1265,7 +1264,7 @@ def freeze(self): self.eval() - def unfreeze(self): + def unfreeze(self) -> None: """Unfreeze all params for inference. .. code-block:: python @@ -1279,7 +1278,7 @@ def unfreeze(self): self.train() - def on_load_checkpoint(self, checkpoint: dict): + def on_load_checkpoint(self, checkpoint: dict) -> None: r""" Called by lightning to restore your model. If you saved something with **on_save_checkpoint** this is your chance to restore this. @@ -1301,7 +1300,7 @@ def on_load_checkpoint(self, checkpoint): No need for you to restore anything regarding training. 
""" - def on_save_checkpoint(self, checkpoint: dict): + def on_save_checkpoint(self, checkpoint: dict) -> None: r""" Called by lightning when saving a checkpoint to give you a chance to store anything else you diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 39b128142c933..5631e11b4afd8 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -1,6 +1,6 @@ -''' +""" Generates a summary of a model's layers and dimensionality -''' +""" import gc import logging as log @@ -43,7 +43,7 @@ def named_modules(self) -> list: mods = [] return list(mods) - def get_variable_sizes(self): + def get_variable_sizes(self) -> None: """ Run sample input through each layer to get output sizes """ mods = self.named_modules() in_sizes = [] @@ -99,7 +99,7 @@ def get_variable_sizes(self): self.out_sizes = out_sizes assert len(in_sizes) == len(out_sizes) - def get_layer_names(self): + def get_layer_names(self) -> None: """ Collect Layer Names """ mods = self.named_modules() names = [] @@ -113,7 +113,7 @@ def get_layer_names(self): self.layer_names = names self.layer_types = layer_types - def get_parameter_sizes(self): + def get_parameter_sizes(self) -> None: """ Get sizes of all parameters in `model` """ mods = self.named_modules() sizes = [] @@ -124,7 +124,7 @@ def get_parameter_sizes(self): self.param_sizes = sizes - def get_parameter_nums(self): + def get_parameter_nums(self) -> None: """ Get number of parameters in each layer """ param_nums = [] for mod in self.param_sizes: @@ -134,7 +134,7 @@ def get_parameter_nums(self): param_nums.append(all_params) self.param_nums = param_nums - def make_summary(self): + def make_summary(self) -> None: """ Makes a summary listing with: @@ -149,7 +149,7 @@ def make_summary(self): self.summary = _format_summary_table(*arrays) - def summarize(self): + def summarize(self) -> None: self.get_layer_names() self.get_parameter_sizes() self.get_parameter_nums() @@ -204,7 +204,7 @@ def _format_summary_table(*cols) -> str: return summary -def print_mem_stack(): # pragma: no cover +def print_mem_stack() -> None: # pragma: no cover for obj in gc.get_objects(): try: if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)): diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 2e7cf498354b9..919bf9fca6300 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -7,7 +7,7 @@ class ModelIO(object): - def on_load_checkpoint(self, checkpoint: dict): + def on_load_checkpoint(self, checkpoint: dict) -> None: """ Do something with the checkpoint Gives model a chance to load something before state_dict is restored @@ -15,7 +15,7 @@ def on_load_checkpoint(self, checkpoint: dict): :return: """ - def on_save_checkpoint(self, checkpoint: dict): + def on_save_checkpoint(self, checkpoint: dict) -> None: """ Give the model a chance to add something to the checkpoint. 
state_dict is already there @@ -24,16 +24,14 @@ def on_save_checkpoint(self, checkpoint: dict): # ------------------------- # OPTIONAL HOOKS # ------------------------- - def on_hpc_save(self, checkpoint: dict): + def on_hpc_save(self, checkpoint: dict) -> None: """ Hook to do whatever you need right before Slurm manager saves the model - :return: """ - def on_hpc_load(self, checkpoint: dict): + def on_hpc_load(self, checkpoint: dict) -> None: """ Hook to do whatever you need right before Slurm manager loads the model - :return: """ From fe19508f23c88dab83b178e05b9a37e8e5ca4a0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 19:50:15 +0100 Subject: [PATCH 20/32] remove unused import --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 23c3a4ef4aa40..4194be6e41ade 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -5,7 +5,7 @@ import warnings from abc import ABC, abstractmethod from argparse import Namespace -from typing import Any, Union, Tuple, List, Optional, Callable, Dict +from typing import Union, Tuple, List, Optional, Callable, Dict import torch import torch.distributed as dist From 02e50c35379249ec2fcc108ce05cbc0bf91edcc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 26 Feb 2020 23:13:37 +0100 Subject: [PATCH 21/32] add more details to dict/list return types --- pytorch_lightning/core/grads.py | 3 ++- pytorch_lightning/core/lightning.py | 34 +++++++++++++++-------------- pytorch_lightning/core/memory.py | 5 +++-- pytorch_lightning/core/saving.py | 10 ++++----- 4 files changed, 28 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py index 08d99b78ec9c1..b5d2d5616a60f 100644 --- a/pytorch_lightning/core/grads.py +++ b/pytorch_lightning/core/grads.py @@ -1,13 +1,14 @@ """ Module to describe gradients """ +from typing import Dict from torch import nn class GradInformation(nn.Module): - def grad_norm(self, norm_type: float) -> dict: + def grad_norm(self, norm_type: float) -> Dict[str, int]: results = {} total_norm = 0 for name, p in self.named_parameters(): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4194be6e41ade..dcd5103d446a2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -5,12 +5,12 @@ import warnings from abc import ABC, abstractmethod from argparse import Namespace -from typing import Union, Tuple, List, Optional, Callable, Dict +from typing import Union, Tuple, List, Optional, Callable, Dict, Any import torch import torch.distributed as dist from torch import Tensor -from torch.nn import Module +from torch.nn.parallel import DistributedDataParallel from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -141,7 +141,7 @@ def forward(self, batch): """ - def training_step(self, *args, **kwargs) -> dict: + def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]: r"""return loss, dict with metrics for tqdm Args: @@ -151,8 +151,6 @@ def training_step(self, *args, **kwargs) -> dict: optimizer_idx (int): If using multiple optimizers, this argument will also be present. hiddens(:`Tensor `_): Passed in if truncated_bptt_steps > 0. 
- :param - :return: dict with loss key and optional log, progress keys if implementing training_step, return whatever you need in that step: @@ -222,7 +220,7 @@ def training_step(self, batch, batch_idx, hiddens): if you want to break out of the current training epoch early. """ - def training_end(self, outputs: dict) -> dict: + def training_end(self, outputs: dict) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]: """return loss, dict with metrics for tqdm :param outputs: What you return in `training_step`. @@ -295,7 +293,7 @@ def training_step(self, batch, batch_idx, hiddens): break out of the current training epoch early. """ - def validation_step(self, *args, **kwargs) -> dict: + def validation_step(self, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]: r""" This is the validation loop. It is called for each batch of the validation set. @@ -367,12 +365,12 @@ def validation_step(self, batch, batch_idx, dataset_idx): have been disabled. At the end of validation, model goes back to training mode and gradients are enabled. """ - def test_step(self, *args, **kwargs) -> dict: + def test_step(self, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]: """return whatever outputs will need to be aggregated in test_end :param batch: The output of your dataloader. A tensor, tuple or list :param int batch_idx: Integer displaying which batch this is :param int dataloader_idx: Integer displaying which dataloader this is (only if multiple test datasets used) - :return dict: Dict or OrderedDict with metrics to display in progress bar. All keys must be tensors. + :return: Single tensor or dict with metrics to display in progress bar. All values must be tensors. .. code-block:: python @@ -436,7 +434,9 @@ def test_step(self, batch, batch_idx, dataset_idx): The `dataset_idx` corresponds to the order of datasets returned in `test_dataloader`. """ - def validation_end(self, outputs: list) -> dict: + def validation_end(self, outputs: list) -> Dict[ + str, Union[Dict[str, Union[float, Tensor]], float, Tensor] + ]: """Outputs has the appended output after each validation step. :param outputs: List of outputs you defined in validation_step, or if there are multiple dataloaders, @@ -508,7 +508,9 @@ def validation_end(self, outputs): """ - def test_end(self, outputs: list) -> dict: + def test_end(self, outputs: list) -> Dict[ + str, Union[Dict[str, Union[float, Tensor]], float, Tensor] + ]: """Outputs has the appended output after each test step. :param outputs: List of outputs you defined in test_step, or if there are multiple dataloaders, @@ -572,7 +574,7 @@ def test_end(self, outputs): """ - def configure_ddp(self, model: 'LightningModule', device_ids: list) -> Module: + def configure_ddp(self, model: 'LightningModule', device_ids: List[int]) -> DistributedDataParallel: r""" Override to init DDP in your own way or with your own wrapper. @@ -1210,7 +1212,7 @@ def __init__(self, hparams): return model @classmethod - def _load_model_state(cls, checkpoint: dict) -> 'LightningModule': + def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule': cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters ckpt_hparams = checkpoint.get('hparams') @@ -1278,7 +1280,7 @@ def unfreeze(self) -> None: self.train() - def on_load_checkpoint(self, checkpoint: dict) -> None: + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: r""" Called by lightning to restore your model. If you saved something with **on_save_checkpoint** this is your chance to restore this. 
@@ -1300,7 +1302,7 @@ def on_load_checkpoint(self, checkpoint): No need for you to restore anything regarding training. """ - def on_save_checkpoint(self, checkpoint: dict) -> None: + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: r""" Called by lightning when saving a checkpoint to give you a chance to store anything else you @@ -1323,7 +1325,7 @@ def on_save_checkpoint(self, checkpoint): """ - def get_tqdm_dict(self) -> dict: + def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: r""" Additional items to be displayed in the progress bar. diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index 5631e11b4afd8..926786268d90f 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -7,10 +7,11 @@ import os import subprocess from subprocess import PIPE -from typing import Tuple, Dict, Union +from typing import Tuple, Dict, Union, List import numpy as np import torch +from torch.nn import Module import pytorch_lightning as pl @@ -32,7 +33,7 @@ def __str__(self): def __repr__(self): return self.summary.__str__() - def named_modules(self) -> list: + def named_modules(self) -> List[Tuple[str, Module]]: if self.mode == 'full': mods = self.model.named_modules() mods = list(mods)[1:] # do not include root module (LightningModule) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 919bf9fca6300..97ffbc495d833 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -2,12 +2,12 @@ import csv import logging as log from argparse import Namespace -from typing import Union +from typing import Union, Dict, Any class ModelIO(object): - def on_load_checkpoint(self, checkpoint: dict) -> None: + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """ Do something with the checkpoint Gives model a chance to load something before state_dict is restored @@ -15,7 +15,7 @@ def on_load_checkpoint(self, checkpoint: dict) -> None: :return: """ - def on_save_checkpoint(self, checkpoint: dict) -> None: + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """ Give the model a chance to add something to the checkpoint. 
state_dict is already there @@ -24,12 +24,12 @@ def on_save_checkpoint(self, checkpoint: dict) -> None: # ------------------------- # OPTIONAL HOOKS # ------------------------- - def on_hpc_save(self, checkpoint: dict) -> None: + def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None: """ Hook to do whatever you need right before Slurm manager saves the model """ - def on_hpc_load(self, checkpoint: dict) -> None: + def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: """ Hook to do whatever you need right before Slurm manager loads the model """ From d11f747c7032706ef867b70c5589bf1feb0ff8e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 27 Feb 2020 23:13:07 +0100 Subject: [PATCH 22/32] fix line too long --- pytorch_lightning/core/lightning.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dcd5103d446a2..851c3bfda0153 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -141,7 +141,9 @@ def forward(self, batch): """ - def training_step(self, *args, **kwargs) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]: + def training_step(self, *args, **kwargs) -> Union[ + int, Dict[str, Union[Tensor, Dict[str, Tensor]]] + ]: r"""return loss, dict with metrics for tqdm Args: @@ -220,7 +222,9 @@ def training_step(self, batch, batch_idx, hiddens): if you want to break out of the current training epoch early. """ - def training_end(self, outputs: dict) -> Union[int, Dict[str, Union[Tensor, Dict[str, Tensor]]]]: + def training_end(self, outputs: dict) -> Union[ + int, Dict[str, Union[Tensor, Dict[str, Tensor]]] + ]: """return loss, dict with metrics for tqdm :param outputs: What you return in `training_step`. @@ -369,8 +373,10 @@ def test_step(self, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]: """return whatever outputs will need to be aggregated in test_end :param batch: The output of your dataloader. A tensor, tuple or list :param int batch_idx: Integer displaying which batch this is - :param int dataloader_idx: Integer displaying which dataloader this is (only if multiple test datasets used) - :return: Single tensor or dict with metrics to display in progress bar. All values must be tensors. + :param int dataloader_idx: Integer displaying which dataloader this is + (only if multiple test datasets used) + :return: Single tensor or dict with metrics to display in progress bar. + All values must be tensors. .. code-block:: python @@ -574,7 +580,11 @@ def test_end(self, outputs): """ - def configure_ddp(self, model: 'LightningModule', device_ids: List[int]) -> DistributedDataParallel: + def configure_ddp( + self, + model: 'LightningModule', + device_ids: List[int] + ) -> DistributedDataParallel: r""" Override to init DDP in your own way or with your own wrapper. 
From bee44d020f503e88eab510cb841a1b867f91f408 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 3 Mar 2020 13:23:26 +0100 Subject: [PATCH 23/32] optimize imports --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 851c3bfda0153..cd46eaffdd9fc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -11,14 +11,15 @@ import torch.distributed as dist from torch import Tensor from torch.nn.parallel import DistributedDataParallel +from torch.optim import Adam from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from pytorch_lightning.core.decorators import data_loader from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import ModelHooks -from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities.debugging import MisconfigurationException From 85559611e84e312bce64f4e73b638d4999a8439e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 5 Mar 2020 14:54:50 +0100 Subject: [PATCH 24/32] linted --- pytorch_lightning/core/saving.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 3058a2ba1f91c..c41a3205ba05f 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -34,7 +34,6 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: Hook to do whatever you need right before Slurm manager loads the model """ - def load_hparams_from_tags_csv(tags_csv: str) -> Namespace: if not os.path.isfile(tags_csv): log.warning(f'Missing Tags: {tags_csv}.') From 7482cb8482afbb1f0675eca020148c013df51a5c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 5 Mar 2020 14:54:50 +0100 Subject: [PATCH 25/32] Revert "linted" This reverts commit 85559611e84e312bce64f4e73b638d4999a8439e. 
--- pytorch_lightning/core/saving.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index c41a3205ba05f..3058a2ba1f91c 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -34,6 +34,7 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: Hook to do whatever you need right before Slurm manager loads the model """ + def load_hparams_from_tags_csv(tags_csv: str) -> Namespace: if not os.path.isfile(tags_csv): log.warning(f'Missing Tags: {tags_csv}.') From cad32d8b50d7c0655c32bb1e32222729baf8e69a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 5 Mar 2020 17:05:25 +0100 Subject: [PATCH 26/32] remove whitespace --- pytorch_lightning/core/saving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 3058a2ba1f91c..a7a0125680e3e 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -34,7 +34,7 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: Hook to do whatever you need right before Slurm manager loads the model """ - + def load_hparams_from_tags_csv(tags_csv: str) -> Namespace: if not os.path.isfile(tags_csv): log.warning(f'Missing Tags: {tags_csv}.') From b3f0ba5b45ef4ed1826a7b14fa34711174d759ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 7 Mar 2020 22:49:45 +0100 Subject: [PATCH 27/32] update --- pytorch_lightning/core/lightning.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 535bd888263ab..331eac81e5337 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -221,15 +221,15 @@ def training_step(self, batch, batch_idx, hiddens): if you want to break out of the current training epoch early. """ - def training_end(self, outputs: dict) -> Union[ - int, Dict[str, Union[Tensor, Dict[str, Tensor]]] - ]: + def training_end(self, outputs: dict): """ Warnings: Deprecated in v0.7.0. use training_step_end instead """ - def training_step_end(self, *args, **kwargs): + def training_step_end(self, *args, **kwargs) -> Dict[ + str, Union[Tensor, Dict[str, Tensor]] + ]: """ Use this when training with dp or ddp2 because training_step will operate on only part of the batch. However, this is still optional @@ -290,7 +290,7 @@ def training_step_end(self, outputs): .. seealso:: see the `multi-gpu guide for more details `_. """ - def validation_step(self, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]: + def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]: r""" Operate on a single batch of data from the validation set In this step you'd might generate examples or calculate anything of interest like accuracy. @@ -656,8 +656,9 @@ def test_end(self, outputs: list) -> Dict[ ]: """ Warnings: - """ + Deprecated in v0.7.0. use test_epoch_end instead. 
Will be removed 1.0.0 + """ def test_epoch_end(self, outputs): """ From 15516dec2d519f2111386f1b1f6eb78011958a73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 7 Mar 2020 22:58:42 +0100 Subject: [PATCH 28/32] update --- pytorch_lightning/core/lightning.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 82cdd2dae8e5a..b13ce8e05b745 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -377,7 +377,7 @@ def validation_step(self, batch, batch_idx, dataset_idx): the model goes back to training mode and gradients are enabled. """ - def validation_step_end(self, *args, **kwargs): + def validation_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: """ Use this when validating with dp or ddp2 because validation_step will operate on only part of the batch. However, this is still optional @@ -441,7 +441,9 @@ def validation_end(self, outputs): Deprecated in v0.7.0. use validation_epoch_end instead. Will be removed 1.0.0 """ - def validation_epoch_end(self, outputs: list): + def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[ + str, Dict[str, Tensor] + ]: """ Called at end of validation epoch with the output of all validation_steps @@ -515,7 +517,7 @@ def validation_epoch_end(self, outputs): return results """ - def test_step(self, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]: + def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: r""" Operate on a single batch of data from the test set In this step you'd normally generate examples or calculate anything of interest @@ -596,7 +598,7 @@ def test_step(self, batch, batch_idx, dataset_idx): to training mode and gradients are enabled. """ - def test_step_end(self, *args, **kwargs): + def test_step_end(self, *args, **kwargs) -> Dict[str, Tensor]: """ Use this when testing with dp or ddp2 because test_step will operate on only part of the batch. However, this is still optional @@ -654,16 +656,16 @@ def test_step_end(self, outputs): .. seealso:: see the `multi-gpu guide for more details `_. """ - def test_end(self, outputs: list) -> Dict[ - str, Union[Dict[str, Union[float, Tensor]], float, Tensor] - ]: + def test_end(self, outputs): """ Warnings: Deprecated in v0.7.0. use test_epoch_end instead. Will be removed 1.0.0 """ - def test_epoch_end(self, outputs): + def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[ + str, Dict[str, Tensor] + ]: """ Called at end of test epoch with the output of all test_steps. @@ -894,7 +896,7 @@ def configure_apex(self, amp, model, optimizers, amp_level): return model, optimizers def configure_optimizers(self) -> Union[ - Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], list] + Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], List] ]: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. @@ -1284,12 +1286,7 @@ def val_dataloader(self): """ @classmethod - def load_from_metrics( - cls, - weights_path: str, - tags_csv: str, - map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None - ) -> 'LightningModule': + def load_from_metrics(cls, weights_path, tags_csv, map_location=None): r""" Warning: Deprecated in version 0.7.0. You should use `load_from_checkpoint` instead. 
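The two patches above settle the ``*_epoch_end`` hooks and the ``configure_optimizers`` return annotation. The sketch below is illustrative only and not taken from the diffs; it assumes a GAN-style two-optimizer setup, and all layer and metric names are placeholders.

.. code-block:: python

    from typing import Dict, List, Tuple

    import pytorch_lightning as pl
    import torch
    from torch import Tensor
    from torch.optim.optimizer import Optimizer


    class LitGAN(pl.LightningModule):

        def __init__(self):
            super().__init__()
            self.generator = torch.nn.Linear(32, 32)
            self.discriminator = torch.nn.Linear(32, 1)

        def forward(self, z: Tensor) -> Tensor:
            return self.generator(z)

        def validation_step(self, batch, batch_idx: int) -> Dict[str, Tensor]:
            # placeholder metric; each step returns a Dict[str, Tensor]
            x, _ = batch
            return {'val_loss': self.discriminator(self(x)).abs().mean()}

        def validation_epoch_end(
            self,
            outputs: List[Dict[str, Tensor]]
        ) -> Dict[str, Dict[str, Tensor]]:
            avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
            return {'log': {'avg_val_loss': avg_loss}}

        def configure_optimizers(self) -> Tuple[List[Optimizer], List]:
            # one of the annotated return variants: two optimizers plus schedulers
            opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
            opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
            schedulers = [torch.optim.lr_scheduler.StepLR(opt_g, step_size=10)]
            return [opt_g, opt_d], schedulers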
From 9c026ce2b9e490be1d7ec34aaa9ef1b51f482ee6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 7 Mar 2020 23:03:25 +0100 Subject: [PATCH 29/32] update --- pytorch_lightning/core/lightning.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b13ce8e05b745..4c98e8c026884 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -71,6 +71,8 @@ def __init__(self, *args, **kwargs): #: True if using amp self.use_amp = False + self.hparams = None + def print(self, *args, **kwargs) -> None: r""" Prints only from process 0. Use this in any distributed mode to log only once @@ -220,7 +222,7 @@ def training_step(self, batch, batch_idx, hiddens): if you want to break out of the current training epoch early. """ - def training_end(self, outputs: dict): + def training_end(self, *args, **kwargs): """ Warnings: Deprecated in v0.7.0. use training_step_end instead @@ -659,9 +661,8 @@ def test_step_end(self, outputs): def test_end(self, outputs): """ Warnings: - Deprecated in v0.7.0. use test_epoch_end instead. Will be removed 1.0.0 - """ + """ def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[ str, Dict[str, Tensor] @@ -1442,7 +1443,7 @@ def freeze(self) -> None: self.eval() def unfreeze(self) -> None: - """Unfreeze all params for inference. + """Unfreeze all params for training. .. code-block:: python From 2915e6b4dffc3207ffb329567e8387f8aebb4df7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 7 Mar 2020 23:05:28 +0100 Subject: [PATCH 30/32] update --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4c98e8c026884..17b2cfe303911 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -5,7 +5,7 @@ import warnings from abc import ABC, abstractmethod from argparse import Namespace -from typing import Union, Tuple, List, Optional, Callable, Dict, Any +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist From 4bd0437ba3c5530d524dfe3ceb64e42a6a56d5af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 7 Mar 2020 23:10:40 +0100 Subject: [PATCH 31/32] update --- pytorch_lightning/core/lightning.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 17b2cfe303911..b531c1542ee6d 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -443,9 +443,10 @@ def validation_end(self, outputs): Deprecated in v0.7.0. use validation_epoch_end instead. Will be removed 1.0.0 """ - def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[ - str, Dict[str, Tensor] - ]: + def validation_epoch_end( + self, + outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] + ) -> Dict[str, Dict[str, Tensor]]: """ Called at end of validation epoch with the output of all validation_steps @@ -664,9 +665,10 @@ def test_end(self, outputs): Deprecated in v0.7.0. use test_epoch_end instead. 
Will be removed 1.0.0 """ - def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[ - str, Dict[str, Tensor] - ]: + def test_epoch_end( + self, + outputs: Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]] + ) -> Dict[str, Dict[str, Tensor]]: """ Called at end of test epoch with the output of all test_steps. @@ -681,7 +683,7 @@ def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[ test_epoch_end(test_outs) Args: - outputs (list): List of outputs you defined in test_step, or if there are multiple + outputs: List of outputs you defined in test_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader Return: From 4f68fc1637c398cfdbbe926385373b79c402c80f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 10 Mar 2020 20:35:45 +0100 Subject: [PATCH 32/32] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 985accb595374..54f7950c4a6c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946)) ### Changed
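As a closing illustration of the Union-typed epoch outputs introduced above, here is a small self-contained sketch, not taken from the diffs, that normalises the single-dataloader and multi-dataloader shapes before aggregating; the metric key is a placeholder.

.. code-block:: python

    from typing import Dict, List, Union

    import torch
    from torch import Tensor

    EpochOutputs = Union[List[Dict[str, Tensor]], List[List[Dict[str, Tensor]]]]


    def flatten_epoch_outputs(outputs: EpochOutputs) -> List[Dict[str, Tensor]]:
        """Collapse the per-dataloader nesting produced by multiple test/val loaders."""
        if outputs and isinstance(outputs[0], list):
            return [step_out for loader_outs in outputs for step_out in loader_outs]
        return list(outputs)


    # single test dataloader: a flat list of step dicts
    single = [{'test_loss': torch.tensor(0.4)}, {'test_loss': torch.tensor(0.2)}]
    # two test dataloaders: a list of per-loader lists
    multi = [single, [{'test_loss': torch.tensor(0.6)}]]

    for outs in (single, multi):
        flat = flatten_epoch_outputs(outs)
        print(torch.stack([out['test_loss'] for out in flat]).mean())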