
Commit

format docs with 120 (#1057)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Borda and pre-commit-ci[bot] committed Jul 12, 2023
1 parent 011f209 commit c6f6d3b
Showing 99 changed files with 395 additions and 108 deletions.
8 changes: 2 additions & 6 deletions .pre-commit-config.yaml
@@ -32,7 +32,8 @@ repos:
rev: v1.7.3
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
@@ -44,11 +45,6 @@ repos:
- mdformat_frontmatter
exclude: CHANGELOG.md

#- repo: https://github.com/PyCQA/isort
# rev: 5.12.0
# hooks:
# - id: isort

- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -50,11 +50,11 @@ relative_files = true
line-length = 120
exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"

[tool.isort]
known_first_party = ["pl_bolts", "tests", "notebooks"]
skip_glob = []
profile = "black"
line_length = 120
[tool.docformatter]
recursive = true
wrap-summaries = 120
wrap-descriptions = 120
blank = true


[tool.ruff]
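
The two hunks above move docformatter's wrap settings out of the pre-commit args and into a [tool.docformatter] table in pyproject.toml, which is why the hook gains tomli as an additional dependency (needed to parse pyproject.toml on older Python versions). A minimal sketch of that lookup, purely illustrative and not part of this commit:

# Illustrative only: reading the [tool.docformatter] table via tomli.
import tomli  # tomllib backport for Python < 3.11

with open("pyproject.toml", "rb") as f:  # tomli expects a binary file handle
    pyproject = tomli.load(f)

print(pyproject.get("tool", {}).get("docformatter", {}))
# With the table above this would yield something like:
# {'recursive': True, 'wrap-summaries': 120, 'wrap-descriptions': 120, 'blank': True}
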
3 changes: 3 additions & 0 deletions setup.py
@@ -31,6 +31,7 @@ def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow", unfreeze=True)
'arrow'
"""
# filer all comments
if comment_char in ln:
@@ -61,6 +62,7 @@ def _load_requirements(path_dir: str, file_name: str, unfreeze: bool = not _FREE
>>> path_req = os.path.join(_PATH_ROOT, "requirements")
>>> _load_requirements(path_req, "docs.txt") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx>=4.0', ...]
"""
with open(os.path.join(path_dir, file_name)) as file:
lines = [ln.strip() for ln in file.readlines()]
@@ -77,6 +79,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
>>> _load_readme_description(_PATH_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'<div align="center">...'
"""
path_readme = os.path.join(path_dir, "README.md")
with open(path_readme, encoding="utf-8") as fo:
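
The doctests above pin down how the setup.py helpers treat inline comments: an ordinary comment is stripped from the requirement line, while a pin marked "# strict" is kept verbatim. A hypothetical, stripped-down version of just that rule (the helper name here is illustrative; the real _augment_requirement also handles unfreezing pins):

# Hypothetical simplification of the comment handling shown in the doctests above.
def strip_requirement_comment(line: str, comment_char: str = "#") -> str:
    if comment_char in line and "strict" not in line:
        line = line[: line.index(comment_char)]  # drop the trailing comment
    return line.strip()

print(strip_requirement_comment("arrow>=1.2.0  # anything"))         # -> 'arrow>=1.2.0'
print(strip_requirement_comment("arrow>=1.2.0, <=1.2.2  # strict"))  # kept as-is
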
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/byol_updates.py
@@ -30,6 +30,7 @@ class BYOLMAWeightUpdate(Callback):
model.target_network = ...
trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])
"""

def __init__(self, initial_tau: float = 0.996) -> None:
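
For context, the callback documented above maintains a target network as an exponential moving average (EMA) of the online network. A minimal sketch of that update rule, written for illustration only (the actual callback runs inside the Trainer loop and manages tau itself):

# Minimal EMA sketch of a BYOL-style target-network update (assumed form).
import torch
from torch import nn

@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float = 0.996) -> None:
    for online_p, target_p in zip(online.parameters(), target.parameters()):
        target_p.mul_(tau).add_((1.0 - tau) * online_p)  # target <- tau*target + (1-tau)*online

online_net, target_net = nn.Linear(8, 8), nn.Linear(8, 8)
ema_update(online_net, target_net, tau=0.996)
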
18 changes: 11 additions & 7 deletions src/pl_bolts/callbacks/data_monitor.py
@@ -35,12 +35,13 @@ class DataMonitorBase(Callback):

def __init__(self, log_every_n_steps: int = None) -> None:
"""Base class for monitoring data histograms in a LightningModule. This requires a logger configured in the
Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data
gets collected.
Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data gets
collected.
Args:
log_every_n_steps: The interval at which histograms should be logged. This defaults to the
interval defined in the Trainer. Use this to override the Trainer default.
"""
super().__init__()
self._log_every_n_steps: Optional[int] = log_every_n_steps
@@ -84,12 +85,13 @@ def log_histograms(self, batch: Any, group: str = "") -> None:
self.log_histogram(tensor, name)

def log_histogram(self, tensor: Tensor, name: str) -> None:
"""Override this method to customize the logging of histograms. Detaches the tensor from the graph and
moves it to the CPU for logging.
"""Override this method to customize the logging of histograms. Detaches the tensor from the graph and moves it
to the CPU for logging.
Args:
tensor: The tensor for which to log a histogram
name: The name of the tensor as determined by the callback. Example: ``ìnput/0/[64, 1, 28, 28]``
"""
logger = self._trainer.logger
tensor = tensor.detach().cpu()
@@ -234,9 +236,9 @@ def on_train_batch_start(


def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: str = "input") -> None:
"""Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data
in dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer.
The shape of the tensor gets appended to the name as well.
"""Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data in
dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer. The
shape of the tensor gets appended to the name as well.
Args:
data: A collection of data (potentially nested).
@@ -249,6 +251,7 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name:
>>> collect_and_name_tensors(data, output)
>>> output # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
{'input/x/[2, 3]': ..., 'input/y/z/[5]': ...}
"""
assert isinstance(output, dict)
if isinstance(data, Tensor):
@@ -273,5 +276,6 @@ def shape2str(tensor: Tensor) -> str:
'[1, 2, 3]'
>>> shape2str(torch.rand(4))
'[4]'
"""
return "[" + ", ".join(map(str, tensor.shape)) + "]"
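
The doctests above fully specify the naming scheme: dict entries are named by their key, sequence entries by an increasing index, and the tensor shape is appended last. A compact, self-contained sketch that reproduces those outputs (an assumed simplification of the module's actual helpers):

# Assumed simplification of collect_and_name_tensors / shape2str, matching the doctests above.
from typing import Any, Dict
import torch
from torch import Tensor

def shape_str(tensor: Tensor) -> str:
    return "[" + ", ".join(map(str, tensor.shape)) + "]"

def collect(data: Any, output: Dict[str, Tensor], parent_name: str = "input") -> None:
    if isinstance(data, Tensor):
        output[f"{parent_name}/{shape_str(data)}"] = data
    elif isinstance(data, dict):
        for key, value in data.items():        # dict entries are named by their key
            collect(value, output, f"{parent_name}/{key}")
    elif isinstance(data, (list, tuple)):
        for idx, value in enumerate(data):     # sequence entries are named by index
            collect(value, output, f"{parent_name}/{idx}")

out: Dict[str, Tensor] = {}
collect({"x": torch.rand(2, 3), "y": {"z": torch.rand(5)}}, out)
print(list(out))  # ['input/x/[2, 3]', 'input/y/z/[5]']
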
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/printing.py
@@ -33,6 +33,7 @@ class PrintTableMetricsCallback(Callback):
# loss│train_loss│val_loss│epoch
# ──────────────────────────────
# 2.2541470527648926│2.2541470527648926│2.2158432006835938│0
"""

def __init__(self) -> None:
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/sparseml.py
@@ -33,6 +33,7 @@ class SparseMLCallback(Callback):
Args:
recipe_path: Path to a SparseML compatible yaml recipe.
More information at https://docs.neuralmagic.com/sparseml/source/recipes.html
"""

def __init__(self, recipe_path: str) -> None:
2 changes: 2 additions & 0 deletions src/pl_bolts/callbacks/ssl_online.py
@@ -31,6 +31,7 @@ class SSLOnlineEvaluator(Callback): # pragma: no cover
online_eval = SSLOnlineEvaluator(
z_dim=model.z_dim
)
"""

def __init__(
@@ -182,6 +183,7 @@ def set_training(module: nn.Module, mode: bool):
Args:
module: module to set training mode
mode: whether to set training mode (True) or evaluation mode (False).
"""
original_mode = module.training

5 changes: 3 additions & 2 deletions src/pl_bolts/callbacks/variational.py
@@ -18,8 +18,8 @@

@under_review()
class LatentDimInterpolator(Callback):
"""Interpolates the latent space for a model by setting all dims to zero and stepping through the first two
dims increasing one unit at a time.
"""Interpolates the latent space for a model by setting all dims to zero and stepping through the first two dims
increasing one unit at a time.
Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5)
@@ -28,6 +28,7 @@ class LatentDimInterpolator(Callback):
from pl_bolts.callbacks import LatentDimInterpolator
Trainer(callbacks=[LatentDimInterpolator()])
"""

def __init__(
8 changes: 6 additions & 2 deletions src/pl_bolts/callbacks/verification/base.py
@@ -17,6 +17,7 @@ class VerificationBase:
All verifications should run with any
:class: `torch.nn.Module` unless otherwise stated.
"""

def __init__(self, model: nn.Module) -> None:
@@ -39,14 +40,16 @@ def check(self, *args: Any, **kwargs: Any) -> bool:
`True` if the test passes, and `False` otherwise. Some verifications can only be performed
with a heuristic accuracy, thus the return value may not always reflect the true state of
the system in these cases.
"""

def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
"""Returns a deep copy of the example input array in cases where it is expected that the input changes
during the verification process.
"""Returns a deep copy of the example input array in cases where it is expected that the input changes during
the verification process.
Arguments:
input_array: The input to clone.
"""
if input_array is None and isinstance(self.model, LightningModule):
input_array = self.model.example_input_array
@@ -89,6 +92,7 @@ class VerificationCallbackBase(Callback):
This type of verification is expected to only work with
:class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array
from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed.
"""

def __init__(self, warn: bool = True, error: bool = False) -> None:
4 changes: 4 additions & 0 deletions src/pl_bolts/callbacks/verification/batch_gradient.py
@@ -19,6 +19,7 @@ class BatchGradientVerification(VerificationBase):
This can happen if reshape- and/or permutation operations are carried out in the wrong order or on the wrong tensor
dimensions.
"""

NORM_LAYER_CLASSES = (
@@ -57,6 +58,7 @@ def check(
Returns:
``True`` if the data in the batch does not mix during the forward pass, and ``False`` otherwise.
"""
input_mapping = input_mapping or default_input_mapping
output_mapping = output_mapping or default_output_mapping
@@ -151,6 +153,7 @@ def default_input_mapping(data: Any) -> List[Tensor]:
torch.Size([3, 1])
>>> result[1].shape
torch.Size([3, 2])
"""
tensors = collect_tensors(data)
batches: List[Tensor] = []
@@ -181,6 +184,7 @@ def default_output_mapping(data: Any) -> Tensor:
>>> result = default_output_mapping(data)
>>> result.shape
torch.Size([3, 7])
"""
if isinstance(data, Tensor):
return data
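
The class docstring at the top of this file describes the failure mode being verified: samples leaking into each other when reshape or permutation operations touch the batch dimension. A rough sketch of the gradient-based check behind it (assumed approach; the real verification also deals with custom input/output mappings and batch-norm layers):

# Rough sketch of a batch-mixing check: back-propagate from one sample's output and
# verify that the input gradients of every other sample in the batch remain zero.
import torch
from torch import nn

def batch_is_independent(model: nn.Module, batch: torch.Tensor, sample_idx: int = 0) -> bool:
    batch = batch.clone().requires_grad_(True)
    output = model(batch)
    output[sample_idx].sum().backward()
    grad = batch.grad.detach()
    others = torch.cat([grad[:sample_idx], grad[sample_idx + 1 :]])
    return bool(torch.all(others == 0))

print(batch_is_independent(nn.Linear(4, 2), torch.rand(3, 4)))  # True: Linear acts per sample
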
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/vision/confused_logit.py
@@ -46,6 +46,7 @@ class ConfusedLogitCallback(Callback): # pragma: no cover
Authored by:
- Alfredo Canziani
"""

def __init__(
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/async_dataloader.py
@@ -28,6 +28,7 @@ class AsynchronousLoader:
if set and DataLoader has a __len__. Otherwise it can be left as None
**kwargs: Any additional arguments to pass to the dataloader if we're
constructing one here
"""

def __init__(
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/cifar10_datamodule.py
@@ -152,6 +152,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
dm = CIFAR10DataModule(PATH)
model = LitModel(datamodule=dm)
"""

dataset_cls = TrialCIFAR10
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/emnist_datamodule.py
@@ -176,6 +176,7 @@ def num_classes(self) -> int:
"""Returns the number of classes.
See the table above.
"""
return len(self.dataset_cls.classes_split_dict[self.split])

18 changes: 13 additions & 5 deletions src/pl_bolts/datamodules/experience_source.py
@@ -30,6 +30,7 @@ class ExperienceSourceDataset(IterableDataset):
Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is
generated is defined the Lightning model itself
"""

def __init__(self, generate_batch: Callable) -> None:
@@ -95,6 +96,7 @@ def runner(self, device: torch.device) -> Tuple[Experience]:
Returns:
Tuple of Experiences
"""
while True:
# get actions for all envs
@@ -116,14 +118,15 @@ def runner(self, device: torch.device) -> Tuple[Experience]:
self.iter_idx += 1

def update_history_queue(self, env_idx, exp, history) -> None:
"""Updates the experience history queue with the lastest experiences. In the event of an experience step is
in the done state, the history will be incrementally appended to the queue, removing the tail of the
history each time.
"""Updates the experience history queue with the lastest experiences. In the event of an experience step is in
the done state, the history will be incrementally appended to the queue, removing the tail of the history each
time.
Args:
env_idx: index of the environment
exp: the current experience
history: history of experience steps for this environment
"""
# If there is a full history of step, append history to queue
if len(history) == self.n_steps:
@@ -184,6 +187,7 @@ def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:
Returns:
Experience tuple
"""
next_state, r, is_done, _ = env.step(action[0])

@@ -198,6 +202,7 @@ def update_env_stats(self, env_idx: int) -> None:
Args:
env_idx: index of the environment used to update stats
"""
self._total_rewards.append(self.cur_rewards[env_idx])
self.total_steps.append(self.cur_steps[env_idx])
@@ -248,6 +253,7 @@ def runner(self, device: torch.device) -> Experience:
Yields:
Discounted Experience
"""
for experiences in super().runner(device):
last_exp_state, tail_experiences = self.split_head_tail_exp(experiences)
@@ -263,14 +269,15 @@ def runner(self, device: torch.device) -> Experience:
)

def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]:
"""Takes in a tuple of experiences and returns the last state and tail experiences based on if the last
state is the end of an episode.
"""Takes in a tuple of experiences and returns the last state and tail experiences based on if the last state is
the end of an episode.
Args:
experiences: Tuple of N Experience
Returns:
last state (Array or None) and remaining Experience
"""
if experiences[-1].done and len(experiences) <= self.steps:
last_exp_state = experiences[-1].new_state
@@ -288,6 +295,7 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
Returns:
total discounted reward
"""
total_reward = 0.0
for exp in reversed(experiences):
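
Several docstrings above refer to the n-step discounted return that the discounted experience source accumulates from the tail of an episode. A minimal sketch of that computation on a plain list of rewards (assumed form; the real discount_rewards walks a tuple of Experience objects in reverse):

# Minimal sketch of an n-step discounted return: fold rewards in from the last step backwards.
from typing import Sequence

def discounted_return(rewards: Sequence[float], gamma: float = 0.99) -> float:
    total = 0.0
    for r in reversed(rewards):
        total = r + gamma * total
    return total

print(discounted_return([1.0, 0.0, 2.0], gamma=0.9))  # 1.0 + 0.9 * (0.0 + 0.9 * 2.0) = 2.62
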
3 changes: 3 additions & 0 deletions src/pl_bolts/datamodules/imagenet_datamodule.py
@@ -119,6 +119,7 @@ def prepare_data(self) -> None:
"""This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin.
.. warning:: Please download imagenet on your own first.
"""
self._verify_splits(self.data_dir, "train")
self._verify_splits(self.data_dir, "val")
@@ -223,6 +224,7 @@ def train_transform(self) -> Callable:
std=[0.229, 0.224, 0.225]
),
])
"""
return transform_lib.Compose(
[
@@ -247,6 +249,7 @@ def val_transform(self) -> Callable:
std=[0.229, 0.224, 0.225]
),
])
"""

return transform_lib.Compose(
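
The train_transform and val_transform docstrings above end with the standard ImageNet normalization. For reference, the full training pipeline they describe looks roughly like this (illustrative only, using the commonly published ImageNet mean/std; the datamodule constructs its own version internally):

# Illustrative ImageNet-style training transform matching the docstring excerpt above.
from torchvision import transforms as transform_lib

train_transforms = transform_lib.Compose([
    transform_lib.RandomResizedCrop(224),
    transform_lib.RandomHorizontalFlip(),
    transform_lib.ToTensor(),
    transform_lib.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
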
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/kitti_datamodule.py
@@ -69,6 +69,7 @@ def __init__(
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")