[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Sep 21, 2023
1 parent 7ac5190 commit e99bb9d
Showing 4 changed files with 52 additions and 28 deletions.
2 changes: 1 addition & 1 deletion configs/callbacks/default.yaml
@@ -24,4 +24,4 @@ stochastic_weight_averaging:
_target_: pvnet_summation.callbacks.StochasticWeightAveraging
swa_lrs: 0.0000001
swa_epoch_start: 0.8
annealing_epochs: 5
annealing_epochs: 5
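
For context, the YAML above is a standard Hydra callback config: `_target_` names the class to construct and the remaining keys are passed to its constructor. A minimal sketch of how such a config is turned into a callback object (assuming the usual hydra/omegaconf pattern; the exact wiring in pvnet_summation's training code may differ):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Same keys as configs/callbacks/default.yaml above.
cfg = OmegaConf.create(
    {
        "_target_": "pvnet_summation.callbacks.StochasticWeightAveraging",
        "swa_lrs": 0.0000001,
        "swa_epoch_start": 0.8,
        "annealing_epochs": 5,
    }
)

# Builds the callback named by `_target_`, forwarding the other keys as kwargs.
swa_callback = instantiate(cfg)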
70 changes: 47 additions & 23 deletions pvnet_summation/callbacks.py
@@ -13,21 +13,19 @@
# limitations under the License.
r"""Stochastic Weight Averaging Callback ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^"""
from copy import deepcopy
from typing import Any, Callable, cast, Dict, List, Optional, Union

import torch
from torch import nn, Tensor
from torch.optim.swa_utils import SWALR
from typing import Any, Callable, Dict, List, Optional, Union, cast

import lightning.pytorch as pl
import torch
from lightning.fabric.utilities.types import LRScheduler
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.callbacks import StochasticWeightAveraging
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig
from lightning.pytorch.callbacks import StochasticWeightAveraging
from torch import Tensor, nn
from torch.optim.swa_utils import SWALR

_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]

@@ -64,7 +62,6 @@ def __init__(
See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`
Arguments:
swa_lrs: The SWA learning rate to use:
- ``float``. Use this value for all parameter groups of the optimizer.
@@ -101,15 +98,21 @@ def __init__(

wrong_type = not isinstance(swa_lrs, (float, list))
wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
wrong_list = isinstance(swa_lrs, list) and not all(
lr > 0 and isinstance(lr, float) for lr in swa_lrs
)
if wrong_type or wrong_float or wrong_list:
raise MisconfigurationException("The `swa_lrs` should a positive float, or a list of positive floats")
raise MisconfigurationException(
"The `swa_lrs` should a positive float, or a list of positive floats"
)

if avg_fn is not None and not callable(avg_fn):
raise MisconfigurationException("The `avg_fn` should be callable.")

if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")
raise MisconfigurationException(
f"device is expected to be a torch.device or a str. Found {device}"
)

self.n_averaged: Optional[Tensor] = None
self._swa_epoch_start = swa_epoch_start
@@ -140,7 +143,9 @@ def swa_end(self) -> int:

@staticmethod
def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())
return any(
isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules()
)

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)):
@@ -154,7 +159,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
raise MisconfigurationException(
"SWA currently not supported for more than 1 `lr_scheduler`."
)

assert trainer.max_epochs is not None
if isinstance(self._swa_epoch_start, float):
@@ -216,7 +223,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
if trainer.lr_scheduler_configs:
scheduler_cfg = trainer.lr_scheduler_configs[0]
if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
rank_zero_warn(
f"SWA is currently only supported every epoch. Found {scheduler_cfg}"
)
rank_zero_info(
f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
f" for `{self._swa_scheduler.__class__.__name__}`"
@@ -226,7 +235,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
trainer.lr_scheduler_configs.append(default_scheduler_cfg)

if self.n_averaged is None:
self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)
self.n_averaged = torch.tensor(
self._init_n_averaged, dtype=torch.long, device=pl_module.device
)

if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
trainer.current_epoch > self._latest_update_epoch
@@ -254,7 +265,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
trainer.accumulate_grad_batches = self._train_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
if trainer.current_epoch==0:
if trainer.current_epoch == 0:
self._train_batches = trainer.global_step
trainer.fit_loop._skip_backward = False

@@ -273,7 +284,9 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
self.transfer_weights(self._average_model, pl_module)

@staticmethod
def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
def transfer_weights(
src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"
) -> None:
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))

@@ -307,7 +320,10 @@ def reset_momenta(self) -> None:

@staticmethod
def update_parameters(
average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
average_model: "pl.LightningModule",
model: "pl.LightningModule",
n_averaged: Tensor,
avg_fn: _AVG_FN,
) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
@@ -319,16 +335,24 @@ def update_parameters(
n_averaged += 1

@staticmethod
def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
def avg_fn(
averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor
) -> Tensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (
num_averaged + 1
)

def state_dict(self) -> Dict[str, Any]:
return {
"n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
"latest_update_epoch": self._latest_update_epoch,
"scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(),
"average_model_state": None if self._average_model is None else self._average_model.state_dict(),
"scheduler_state": None
if self._swa_scheduler is None
else self._swa_scheduler.state_dict(),
"average_model_state": None
if self._average_model is None
else self._average_model.state_dict(),
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -353,4 +377,4 @@ def _clear_schedulers(trainer: "pl.Trainer") -> None:
def _load_average_model_state(self, model_state: Any) -> None:
if self._average_model is None:
return
self._average_model.load_state_dict(model_state)
self._average_model.load_state_dict(model_state)
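
For context, the reflowed avg_fn above is the standard cumulative-mean update used by SWA: after averaging n snapshots, folding in snapshot n+1 via avg + (p - avg)/(n + 1) yields the arithmetic mean of all n+1 snapshots. A small self-contained check of that identity (illustrative only, independent of the callback and of Lightning):

import torch

def avg_fn(averaged_param, new_param, num_averaged):
    # Same update rule as the callback's avg_fn: incremental arithmetic mean.
    return averaged_param + (new_param - averaged_param) / (num_averaged + 1)

snapshots = [torch.randn(3) for _ in range(5)]
running = snapshots[0].clone()
for n, p in enumerate(snapshots[1:], start=1):
    running = avg_fn(running, p, torch.tensor(n))

# The running average matches the plain mean over all snapshots.
assert torch.allclose(running, torch.stack(snapshots).mean(dim=0))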
6 changes: 3 additions & 3 deletions pvnet_summation/models/model.py
@@ -84,10 +84,10 @@ def forward(self, x):
eff_cap = x["effective_capacity"].unsqueeze(-1)
else:
eff_cap = x["effective_capacity"]
# Multiply by (effective capacity / 100) since the capacities are roughly of magnitude

# Multiply by (effective capacity / 100) since the capacities are roughly of magnitude
# of 100 MW. We still want the inputs to the network to be order of magnitude 1.
x_in = x["pvnet_outputs"] * (eff_cap/100)
x_in = x["pvnet_outputs"] * (eff_cap / 100)
else:
x_in = x["pvnet_outputs"]

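The reformatted line only adds spaces around the division; the scaling itself is unchanged. A rough illustration of why the division by 100 is there (shapes and magnitudes below are made up for the example, not taken from the repository):

import torch

# Hypothetical shapes: batch of 4 samples, 300 GSPs, 16 forecast steps per GSP.
pvnet_outputs = torch.rand(4, 300, 16)            # per-GSP outputs, already roughly O(1)
effective_capacity = 100 * torch.rand(4, 300, 1)  # capacities of order 100 MW

# Dividing by 100 keeps the network inputs at order of magnitude 1,
# as the comment in forward() explains.
x_in = pvnet_outputs * (effective_capacity / 100)
print(x_in.abs().mean())  # roughly 0.25, i.e. still O(1)
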
2 changes: 1 addition & 1 deletion pvnet_summation/training.py
@@ -166,7 +166,7 @@ def train(config: DictConfig) -> Optional[float]:

# Train the model completely
trainer.fit(model=model, datamodule=datamodule)

# Validate after end - useful if using stochastic weight averaging
trainer.validate(model=model, datamodule=datamodule)

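For context on the comment above: StochasticWeightAveraging.on_train_end (see the callbacks.py diff earlier) copies the averaged weights back into the LightningModule when SWA was active, so a validation pass placed after fit scores the SWA weights rather than the weights from the last optimizer step. The pattern, with names as in train():

# Train, then validate; after fit the SWA callback has already swapped the
# averaged weights into `model`, so this validation reflects the SWA model.
trainer.fit(model=model, datamodule=datamodule)
trainer.validate(model=model, datamodule=datamodule)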
