
Fix typing in pl.callbacks.lr_monitor #10802

Merged · 19 commits · Dec 22, 2021
1 change: 0 additions & 1 deletion pyproject.toml
@@ -44,7 +44,6 @@ warn_no_return = "False"
module = [
    "pytorch_lightning.accelerators.gpu",
    "pytorch_lightning.callbacks.finetuning",
-    "pytorch_lightning.callbacks.lr_monitor",
    "pytorch_lightning.callbacks.model_checkpoint",
    "pytorch_lightning.callbacks.progress.base",
    "pytorch_lightning.callbacks.progress.progress",
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/lr_monitor.py
@@ -86,7 +86,7 @@ def configure_optimizer(self):

    """

-    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False):
+    def __init__(self, logging_interval: Optional[str] = None, log_momentum: bool = False) -> None:
        if logging_interval not in (None, "step", "epoch"):
            raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")

@@ -146,6 +146,7 @@ def _check_no_key(key: str) -> bool:
        self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

    def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        assert trainer.logger is not None
        if not trainer.logger_connector.should_update_logs:
            return

@@ -157,6 +158,7 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
            trainer.logger.log_metrics(latest_stat, step=trainer.global_step)

    def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
+        assert trainer.logger is not None
        if self.logging_interval != "step":
            interval = "epoch" if self.logging_interval is None else "any"
            latest_stat = self._extract_stats(trainer, interval)
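The two `assert trainer.logger is not None` additions exist for the type checker rather than for runtime behavior: `Trainer.logger` is an Optional attribute, and the assert narrows it to a non-None type before `.log_metrics` is called. A minimal standalone sketch of this narrowing pattern (the classes below are illustrative stand-ins, not the real Lightning ones):

from typing import Optional


class _DemoLogger:
    def log_metrics(self, metrics: dict, step: int) -> None:
        print(step, metrics)


class _DemoTrainer:
    def __init__(self, logger: Optional[_DemoLogger]) -> None:
        self.logger = logger  # may be None when logging is disabled


def log_lr(trainer: _DemoTrainer, lr: float, step: int) -> None:
    # Without the assert, mypy reports: Item "None" of "Optional[_DemoLogger]"
    # has no attribute "log_metrics". The assert narrows the type for the call below.
    assert trainer.logger is not None
    trainer.logger.log_metrics({"lr": lr}, step=step)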
(File header not captured; the hunks below modify the Strategy base class.)
@@ -19,7 +19,6 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

import pytorch_lightning as pl
@@ -32,7 +31,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT

TBroadcast = TypeVar("TBroadcast")

@@ -52,7 +51,7 @@ def __init__(
        self.checkpoint_io = checkpoint_io
        self.precision_plugin = precision_plugin
        self.optimizers: List[Optimizer] = []
-        self.lr_schedulers: List[_LRScheduler] = []
+        self.lr_schedulers: List[LRSchedulerConfig] = []
        self.optimizer_frequencies: List[int] = []
        if is_overridden("post_dispatch", self, parent=Strategy):
            rank_zero_deprecation(
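Annotating `Strategy.lr_schedulers` as `List[LRSchedulerConfig]` matches what the list actually stores: full scheduler-configuration dicts rather than bare scheduler objects. A rough sketch, not Lightning's actual loop code, of why the extra keys matter when schedulers are stepped:

from typing import Any, Dict, List


def step_epoch_schedulers(lr_scheduler_configs: List[Dict[str, Any]], epoch: int, metrics: Dict[str, float]) -> None:
    # Each config carries bookkeeping ("interval", "frequency", "monitor", ...)
    # alongside the scheduler object itself under the "scheduler" key.
    for config in lr_scheduler_configs:
        if config["interval"] != "epoch" or (epoch + 1) % config["frequency"] != 0:
            continue
        if config["reduce_on_plateau"]:
            # ReduceLROnPlateau needs the monitored value; assumes "monitor" is set.
            config["scheduler"].step(metrics[config["monitor"]])
        else:
            config["scheduler"].step()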
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/optimizers.py
@@ -23,6 +23,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import LRSchedulerConfig


class TrainerOptimizersMixin(ABC):
@@ -122,7 +123,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
    @staticmethod
    def _configure_schedulers(
        schedulers: list, monitor: Optional[str], is_manual_optimization: bool
-    ) -> List[Dict[str, Any]]:
+    ) -> List[LRSchedulerConfig]:
        """Convert each scheduler into dict structure with relevant information."""
        lr_schedulers = []
        default_config = _get_default_scheduler_config()
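`_configure_schedulers` normalizes whatever `configure_optimizers` returned (a bare scheduler or a partial dict) into fully populated config dicts, which is why `List[LRSchedulerConfig]` is the accurate return type. A hedged sketch of that normalization; the defaults below are assumptions chosen to mirror the `LRSchedulerConfig` keys, not copied from Lightning's `_get_default_scheduler_config`:

from typing import Any, Dict, Optional


def _default_scheduler_config() -> Dict[str, Any]:
    # Assumed defaults; one entry per LRSchedulerConfig key.
    return {
        "scheduler": None,
        "name": None,
        "interval": "epoch",
        "frequency": 1,
        "reduce_on_plateau": False,
        "monitor": None,
        "strict": True,
        "opt_idx": None,
    }


def _to_scheduler_config(scheduler: Any, monitor: Optional[str] = None) -> Dict[str, Any]:
    config = _default_scheduler_config()
    if isinstance(scheduler, dict):
        # User supplied a partial dict, e.g. {"scheduler": s, "interval": "step"}.
        config.update(scheduler)
    else:
        # User returned the scheduler object directly.
        config["scheduler"] = scheduler
    if config["monitor"] is None:
        config["monitor"] = monitor
    return config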
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -107,7 +107,7 @@
    _PATH,
    _PREDICT_OUTPUT,
    EVAL_DATALOADERS,
-    LRSchedulerTypeUnion,
+    LRSchedulerConfig,
    STEP_OUTPUT,
    TRAIN_DATALOADERS,
)
@@ -1839,11 +1839,11 @@ def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
        self.strategy.optimizers = new_optims

    @property
-    def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
+    def lr_schedulers(self) -> List[LRSchedulerConfig]:
        return self.strategy.lr_schedulers

    @lr_schedulers.setter
-    def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
+    def lr_schedulers(self, new_schedulers: List[LRSchedulerConfig]) -> None:
        self.strategy.lr_schedulers = new_schedulers

    @property
81 changes: 75 additions & 6 deletions pytorch_lightning/utilities/types.py
@@ -14,15 +14,16 @@
"""
Convention:
 - Do not include any `_TYPE` suffix
- - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
+ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
"""
from pathlib import Path
-from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union

import torch
-from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau
+from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import Metric
+from typing_extensions import TypedDict

_NUMBER = Union[int, float]
_METRIC = Union[Metric, torch.Tensor, _NUMBER]
@@ -43,7 +44,75 @@
    Dict[str, Sequence[DataLoader]],
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class _LRScheduler:
+    optimizer: Optimizer
+
+    def __init__(self, optimizer: Optimizer, last_epoch: int = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+    def get_last_lr(self) -> List[float]:
+        ...
+
+    def get_lr(self) -> float:
+        ...
+
+    def step(self, epoch: Optional[int] = ...) -> None:
+        ...
+
+
+# Copied from `torch.optim.lr_scheduler.pyi`
+# Missing attributes were added to improve typing
+class ReduceLROnPlateau:
+    in_cooldown: bool
+    optimizer: Optimizer
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        mode: str = ...,
+        factor: float = ...,
+        patience: int = ...,
+        verbose: bool = ...,
+        threshold: float = ...,
+        threshold_mode: str = ...,
+        cooldown: int = ...,
+        min_lr: float = ...,
+        eps: float = ...,
+    ) -> None:
+        ...
+
+    def step(self, metrics: Any, epoch: Optional[int] = ...) -> None:
+        ...
+
+    def state_dict(self) -> dict:
+        ...
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        ...
+
+
# todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (_LRScheduler, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[_LRScheduler, ReduceLROnPlateau]
-LRSchedulerType = Union[Type[_LRScheduler], Type[ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+
+
+class LRSchedulerConfig(TypedDict):
+    scheduler: Union[_LRScheduler, ReduceLROnPlateau]
+    name: Optional[str]
+    interval: str
+    frequency: int
+    reduce_on_plateau: bool
+    monitor: Optional[str]
+    strict: bool
+    opt_idx: Optional[int]
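The split between the new stub classes and the `LRSchedulerType*` aliases is deliberate: the aliases must reference the real `torch.optim.lr_scheduler` classes because they are used at runtime (for example in `isinstance` checks), while the local `_LRScheduler`/`ReduceLROnPlateau` stubs exist only so mypy sees attributes such as `optimizer` and `in_cooldown` that, per the comments above, torch's own `.pyi` did not declare at the time. A small sketch of the runtime side, assuming a torch version where `torch.optim.lr_scheduler._LRScheduler` is the scheduler base class:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
# Same shape as LRSchedulerTypeTuple: the real classes, usable with isinstance.
runtime_tuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

for scheduler in (StepLR(optimizer, step_size=1), ReduceLROnPlateau(optimizer)):
    # isinstance needs the real classes; the typing-only stubs above would not work here.
    assert isinstance(scheduler, runtime_tuple)
    print(type(scheduler).__name__, "is reduce-on-plateau:", isinstance(scheduler, ReduceLROnPlateau))

On the static side, because both stub classes declare `optimizer`, code annotated with the `LRSchedulerConfig` TypedDict can access, for example, `config["scheduler"].optimizer` and have mypy check it.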