diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 8f10f67..26ef2da 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -18,30 +18,30 @@ Before submitting any issue, please perform a thorough search to see if your pro
 
 If you like the project and wish to contribute, you can start by looking at issues labeled `good first issue` (should only require a few lines of code) or `help wanted` (more involved). If you found a bug and want to fix it, please create an issue reporting the bug before creating a pull request. Similarly, if you want to add a new feature, first create a feature request issue. This keeps the discussion of the bug/feature separate from the discussion of the fix/implementation.
 
-### Code conventions
+### Testing
 
-We mostly follow the [PEP 8](https://peps.python.org/pep-0008/) style guide for Python code. It is recommended that you format your code with the opinionated [Black](https://github.com/psf/black) formatter. For example, if you created or modified a file `path/to/filename.py`, you can reformat it with
+We use [pytest](https://docs.pytest.org) to test our code base. If your contribution introduces new components, you should write new tests to make sure your code doesn't crash under normal circumstances. After installing `pytest`, add the tests to the [tests/](tests) directory and run them with
 
 ```
-black -S path/to/filename.py
+pytest tests
 ```
 
-Additionally, please follow these rules:
-
-* Use single quotes for strings (`'single-quoted'`) but double quotes (`"double-quoted"`) for text such as error messages.
-* Use informative but concise variable names. Single-letter names are fine if the context is clear.
-* Avoid explaining code with comments. If something is hard to understand, simplify or decompose it.
-* If Black's output [takes too much vertical space](https://github.com/psf/black/issues/1811), ignore its modifications.
+It is also recommended to ensure your code works as expected within toy experiments similar to the [tutorials](docs/tutorials). When you submit a pull request, tests are automatically executed (upon approval) for several versions of Python and PyTorch.
 
-### Testing
+### Code conventions
 
-We use [pytest](https://docs.pytest.org) to test our code base. If your contribution introduces new components, you should write new tests to make sure your code doesn't crash under normal circumstances. After installing `pytest`, add the tests to the [tests/](tests) directory and run them with
+We use [Ruff](https://github.com/astral-sh/ruff) to lint and format all Python code. After installing `ruff`, you can check if your code follows our conventions with
 
 ```
-pytest tests
+ruff check .
+ruff format --check .
 ```
 
-It is also recommended to ensure your code works as expected within toy experiments similar to the [tutorials](docs/tutorials). When you submit a pull request, tests are automatically (upon approval) executed for several versions of Python and PyTorch.
+Additionally, please follow these rules:
+
+* Use single quotes for strings (`'single-quoted'`) but double quotes (`"double-quoted"`) for text such as error messages.
+* Use informative but concise variable names. Single-letter names are fine if the context is clear.
+* Avoid explaining code with comments. If something is hard to understand, simplify or decompose it.
 
 ### Documentation
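Note: to illustrate the kind of test the new Testing section asks for, a contribution touching a network component might ship with something like the sketch below. This is hypothetical and not part of the patch; it only assumes the existing `lampe.nn.MLP` interface.

```python
import torch

from lampe.nn import MLP


def test_mlp():
    net = MLP(3, 5, hidden_features=[64, 64])

    # A new component should at least process batched inputs without crashing.
    x = torch.randn(256, 3)
    y = net(x)

    assert y.shape == (256, 5)
    assert y.requires_grad
```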
diff --git a/docs/conf.py b/docs/conf.py
index ee82d21..bf55105 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,8 +1,8 @@
 # Configuration file for the Sphinx documentation builder
 
 import glob
-import inspect
 import importlib
+import inspect
 import lampe
 import re
 import subprocess
@@ -59,22 +59,22 @@ def linkcode_resolve(domain: str, info: dict) -> str:
     for name in fullname.split('.'):
         objct = getattr(objct, name)
 
+    if hasattr(objct, '__wrapped__'):
+        objct = objct.__wrapped__
+
     try:
         file = inspect.getsourcefile(objct)
         file = file[file.rindex(package) :]
 
         lines, start = inspect.getsourcelines(objct)
         end = start + len(lines) - 1
-    except Exception as e:
+    except Exception:
         return None
     else:
         return f'{repository}/blob/{commit}/{file}#L{start}-L{end}'
 
-napoleon_custom_sections = [
-    ('Shapes', 'params_style'),
-    'Wikipedia',
-]
+napoleon_custom_sections = ['Wikipedia']
 
 nb_execution_mode = 'off'
 myst_enable_extensions = ['dollarmath']
@@ -129,6 +129,7 @@ def linkcode_resolve(domain: str, info: dict) -> str:
 
 ## Edit HTML
 
+
 def edit_html(app, exception):
     if exception:
         raise exception
@@ -137,12 +138,15 @@ def edit_html(app, exception):
         with open(file, 'r') as f:
             text = f.read()
 
+        # fmt: off
         text = text.replace('@pradyunsg\'s', '')
         text = text.replace('[source]', '')
         text = re.sub(r'()()', r'\2\1', text)
+        # fmt: on
 
         with open(file, 'w') as f:
             f.write(text)
 
+
 def setup(app):
     app.connect('build-finished', edit_html)
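Note on the `__wrapped__` lookup added to `linkcode_resolve` above: decorators built with `functools.wraps` store the original function in a `__wrapped__` attribute, so following it lets the resolver link to the decorated function's own source instead of the decorator's wrapper. A minimal sketch of the mechanism, with illustrative names:

```python
import functools


def checked(f):
    @functools.wraps(f)  # copies metadata and sets wrapper.__wrapped__ = f
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


@checked
def simulate(theta):
    return 2 * theta


# The public object is the wrapper, whose source lines are those of the
# decorator; one hop through __wrapped__ recovers the function to link to.
assert simulate.__wrapped__(21) == 42
assert simulate.__wrapped__.__name__ == 'simulate'
```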
diff --git a/docs/tutorials/coverage.ipynb b/docs/tutorials/coverage.ipynb
index 4dd2049..98bb001 100644
--- a/docs/tutorials/coverage.ipynb
+++ b/docs/tutorials/coverage.ipynb
@@ -15,21 +15,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import torch\n",
     "import torch.nn as nn\n",
     "import torch.optim as optim\n",
     "import zuko\n",
     "\n",
-    "from tqdm import tqdm\n",
-    "\n",
-    "from lampe.data import JointLoader, H5Dataset\n",
+    "from lampe.data import H5Dataset, JointLoader\n",
     "from lampe.diagnostics import expected_coverage_mc, expected_coverage_ni\n",
-    "from lampe.inference import NPE, NRE, FMPE, NPELoss, NRELoss, BNRELoss, FMPELoss\n",
-    "from lampe.plots import nice_rc, coverage_plot\n",
-    "from lampe.utils import GDStep"
+    "from lampe.inference import FMPE, NPE, NRE, BNRELoss, FMPELoss, NPELoss, NRELoss\n",
+    "from lampe.plots import coverage_plot, nice_rc\n",
+    "from lampe.utils import GDStep\n",
+    "from tqdm import trange"
    ]
   },
   {
@@ -147,20 +144,20 @@
     "def train(loss: nn.Module, epochs: int = 256) -> None:\n",
     "    loss.cuda().train()\n",
     "\n",
-    "    optimizer = optim.AdamW(loss.parameters(), lr=1e-3)\n",
+    "    optimizer = optim.Adam(loss.parameters(), lr=1e-3)\n",
     "    step = GDStep(optimizer, clip=1.0)\n",
     "\n",
-    "    with tqdm(range(epochs), unit='epoch') as tq:\n",
-    "        for epoch in tq:\n",
-    "            losses = torch.stack([\n",
-    "                step(loss(theta.cuda(), x.cuda()))\n",
-    "                for theta, x in trainset\n",
-    "            ])\n",
+    "    for epoch in (bar := trange(epochs, unit='epoch')):\n",
+    "        losses = []\n",
     "\n",
-    "            tq.set_postfix(loss=losses.mean().item())\n",
+    "        for theta, x in trainset:\n",
+    "            losses.append(step(loss(theta.cuda(), x.cuda())))\n",
+    "\n",
+    "        bar.set_postfix(loss=torch.stack(losses).mean().item())\n",
     "\n",
     "    loss.eval()\n",
     "\n",
+    "\n",
     "train(NPELoss(npe))\n",
     "train(NRELoss(nre))\n",
     "train(BNRELoss(bnre))\n",
@@ -199,7 +196,7 @@
     "\n",
     "npe_levels, npe_coverages = expected_coverage_mc(npe.flow, testset, device='cuda')\n",
     "\n",
-    "log_p = lambda theta, x: nre(theta, x) + prior.log_prob(theta) # log p(theta | x) = log r(theta, x) + log p(theta)\n",
+    "log_p = lambda theta, x: nre(theta, x) + prior.log_prob(theta)\n",
     "nre_levels, nre_coverages = expected_coverage_ni(log_p, testset, (LOWER, UPPER), device='cuda')\n",
     "\n",
     "log_p = lambda theta, x: bnre(theta, x) + prior.log_prob(theta)\n",
diff --git a/docs/tutorials/embedding.ipynb b/docs/tutorials/embedding.ipynb
index ad7787d..788beae 100644
--- a/docs/tutorials/embedding.ipynb
+++ b/docs/tutorials/embedding.ipynb
@@ -15,8 +15,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import torch\n",
@@ -24,13 +22,12 @@
     "import torch.optim as optim\n",
     "import zuko\n",
     "\n",
-    "from tqdm import tqdm\n",
-    "\n",
-    "from lampe.data import JointLoader, H5Dataset\n",
+    "from lampe.data import H5Dataset, JointLoader\n",
     "from lampe.diagnostics import expected_coverage_mc\n",
     "from lampe.inference import NPE, NPELoss\n",
-    "from lampe.plots import nice_rc, corner, mark_point, coverage_plot\n",
-    "from lampe.utils import GDStep"
+    "from lampe.plots import corner, coverage_plot, mark_point, nice_rc\n",
+    "from lampe.utils import GDStep\n",
+    "from tqdm import trange"
    ]
   },
   {
@@ -73,6 +70,7 @@
     "\n",
     "prior = zuko.distributions.BoxUniform(LOWER, UPPER)\n",
     "\n",
+    "\n",
     "class WorleyNoise:\n",
     "    def __init__(self):\n",
     "        self.domain = torch.cartesian_prod(\n",
@@ -97,6 +95,7 @@
     "\n",
     "    return x.reshape(64, 64)\n",
     "\n",
+    "\n",
     "simulator = WorleyNoise()\n",
     "\n",
     "theta = prior.sample()\n",
@@ -191,6 +190,7 @@
     "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
     "        return x + super().forward(x)\n",
     "\n",
+    "\n",
     "class NPEWithEmbedding(nn.Module):\n",
     "    def __init__(self):\n",
     "        super().__init__()\n",
@@ -247,27 +247,27 @@
    "source": [
     "estimator = NPEWithEmbedding().cuda()\n",
     "loss = NPELoss(estimator)\n",
-    "optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)\n",
+    "optimizer = optim.Adam(estimator.parameters(), lr=1e-3)\n",
     "step = GDStep(optimizer, clip=1.0)\n",
     "\n",
-    "with tqdm(range(32), unit='epoch') as tq:\n",
-    "    for epoch in tq:\n",
-    "        estimator.train()\n",
-    "\n",
-    "        losses = torch.stack([\n",
-    "            step(loss(preprocess(theta).cuda(), x.cuda()))\n",
-    "            for theta, x in trainset\n",
-    "        ])\n",
+    "for epoch in (bar := trange(32, unit='epoch')):\n",
+    "    losses, val_losses = [], []\n",
     "\n",
-    "        estimator.eval()\n",
+    "    estimator.train()\n",
+    "\n",
+    "    for theta, x in trainset:\n",
+    "        losses.append(step(loss(preprocess(theta).cuda(), x.cuda())))\n",
     "\n",
-    "        with torch.no_grad():\n",
-    "            val_losses = torch.stack([\n",
-    "                loss(preprocess(theta).cuda(), x.cuda())\n",
-    "                for theta, x in validset\n",
-    "            ])\n",
+    "    estimator.eval()\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        for theta, x in validset:\n",
+    "            val_losses.append(loss(preprocess(theta).cuda(), x.cuda()))\n",
     "\n",
-    "        tq.set_postfix(loss=losses.mean().item(), val_loss=val_losses.mean().item())"
+    "    bar.set_postfix(\n",
+    "        loss=torch.stack(losses).mean().item(),\n",
+    "        val_loss=torch.stack(val_losses).mean().item(),\n",
+    "    )"
    ]
   },
   {
diff --git a/docs/tutorials/fmpe.ipynb b/docs/tutorials/fmpe.ipynb
index 355b053..d93ee02 100644
--- a/docs/tutorials/fmpe.ipynb
+++ b/docs/tutorials/fmpe.ipynb
@@ -15,8 +15,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import torch\n",
     "import torch.nn as nn\n",
@@ -24,12 +22,11 @@
     "import zuko\n",
     "\n",
     "from itertools import islice\n",
-    "from tqdm import tqdm\n",
-    "\n",
     "from lampe.data import JointLoader\n",
     "from lampe.inference import FMPE, FMPELoss\n",
-    "from lampe.plots import nice_rc, corner, mark_point\n",
-    "from lampe.utils import GDStep"
+    "from lampe.plots import corner, mark_point, nice_rc\n",
+    "from lampe.utils import GDStep\n",
+    "from tqdm import trange"
    ]
   },
   {
@@ -149,22 +146,21 @@
    ],
    "source": [
     "loss = FMPELoss(estimator)\n",
-    "optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)\n",
+    "optimizer = optim.Adam(estimator.parameters(), lr=1e-3)\n",
     "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 128)\n",
     "step = GDStep(optimizer, clip=1.0) # gradient descent step with gradient clipping\n",
     "\n",
     "estimator.train()\n",
     "\n",
-    "with tqdm(range(128), unit='epoch') as tq:\n",
-    "    for epoch in tq:\n",
-    "        losses = torch.stack([\n",
-    "            step(loss(theta, x))\n",
-    "            for theta, x in islice(loader, 256) # 256 batches per epoch\n",
-    "        ])\n",
+    "for epoch in (bar := trange(128, unit='epoch')):\n",
+    "    losses = []\n",
     "\n",
-    "        tq.set_postfix(loss=losses.mean().item())\n",
+    "    for theta, x in islice(loader, 256): # 256 batches per epoch\n",
+    "        losses.append(step(loss(theta, x)))\n",
     "\n",
-    "    scheduler.step()"
+    "    bar.set_postfix(loss=torch.stack(losses).mean().item())\n",
+    "\n",
+    "    scheduler.step()"
    ]
   },
   {
diff --git a/docs/tutorials/npe.ipynb b/docs/tutorials/npe.ipynb
index 64b83cd..7012848 100644
--- a/docs/tutorials/npe.ipynb
+++ b/docs/tutorials/npe.ipynb
@@ -15,21 +15,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import torch\n",
-    "import torch.nn as nn\n",
     "import torch.optim as optim\n",
     "import zuko\n",
     "\n",
     "from itertools import islice\n",
-    "from tqdm import tqdm\n",
-    "\n",
     "from lampe.data import JointLoader\n",
     "from lampe.inference import NPE, NPELoss\n",
-    "from lampe.plots import nice_rc, corner, mark_point\n",
-    "from lampe.utils import GDStep"
+    "from lampe.plots import corner, mark_point, nice_rc\n",
+    "from lampe.utils import GDStep\n",
+    "from tqdm import trange"
    ]
   },
   {
@@ -162,7 +158,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "In the case of NPE, the loss to minimize is the expected log-likelihood of the data $ \\mathbb{E}_{p(\\theta, x)} \\big[ -\\log p_\\phi(\\theta | x) \\big] $, which is easy to implement ourselves."
+    "In the case of NPE, the loss to minimize is the expected negative log-likelihood of the data $ \\mathbb{E}_{p(\\theta, x)} \\big[ -\\log p_\\phi(\\theta | x) \\big] $, which is easy to implement ourselves."
    ]
   },
   {
@@ -171,7 +167,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def loss(theta, x):\n",
+    "def nll(theta, x):\n",
     "    log_p = estimator(theta, x) # log p(theta | x)\n",
     "    return -log_p.mean()"
    ]
@@ -213,19 +209,18 @@
    }
   ],
   "source": [
-    "optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)\n",
+    "optimizer = optim.Adam(estimator.parameters(), lr=1e-3)\n",
     "step = GDStep(optimizer, clip=1.0) # gradient descent step with gradient clipping\n",
     "\n",
     "estimator.train()\n",
     "\n",
-    "with tqdm(range(64), unit='epoch') as tq:\n",
-    "    for epoch in tq:\n",
-    "        losses = torch.stack([\n",
-    "            step(loss(theta, x))\n",
-    "            for theta, x in islice(loader, 256) # 256 batches per epoch\n",
-    "        ])\n",
+    "for epoch in (bar := trange(64, unit='epoch')):\n",
+    "    losses = []\n",
+    "\n",
+    "    for theta, x in islice(loader, 256): # 256 batches per epoch\n",
+    "        losses.append(step(nll(theta, x)))\n",
     "\n",
-    "        tq.set_postfix(loss=losses.mean().item())"
+    "    bar.set_postfix(loss=torch.stack(losses).mean().item())"
    ]
   },
   {
diff --git a/docs/tutorials/nre.ipynb b/docs/tutorials/nre.ipynb
index 9835281..f4ae050 100644
--- a/docs/tutorials/nre.ipynb
+++ b/docs/tutorials/nre.ipynb
@@ -15,8 +15,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%matplotlib inline\n",
-    "\n",
     "import matplotlib.pyplot as plt\n",
     "import torch\n",
     "import torch.nn as nn\n",
@@ -24,12 +22,11 @@
     "import zuko\n",
     "\n",
     "from itertools import islice\n",
-    "from tqdm import tqdm\n",
-    "\n",
     "from lampe.data import JointLoader\n",
-    "from lampe.inference import NRE, NRELoss, MetropolisHastings\n",
-    "from lampe.plots import nice_rc, corner, mark_point\n",
-    "from lampe.utils import GDStep"
+    "from lampe.inference import NRE, MetropolisHastings, NRELoss\n",
+    "from lampe.plots import corner, mark_point, nice_rc\n",
+    "from lampe.utils import GDStep\n",
+    "from tqdm import trange"
    ]
   },
   {
@@ -149,19 +146,18 @@
    ],
    "source": [
     "loss = NRELoss(estimator)\n",
-    "optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)\n",
+    "optimizer = optim.Adam(estimator.parameters(), lr=1e-3)\n",
     "step = GDStep(optimizer, clip=1.0) # gradient descent step with gradient clipping\n",
     "\n",
     "estimator.train()\n",
     "\n",
-    "with tqdm(range(128), unit='epoch') as tq:\n",
-    "    for epoch in tq:\n",
-    "        losses = torch.stack([\n",
-    "            step(loss(theta, x))\n",
-    "            for theta, x in islice(loader, 256) # 256 batches per epoch\n",
-    "        ])\n",
+    "for epoch in (bar := trange(128, unit='epoch')):\n",
+    "    losses = []\n",
+    "\n",
+    "    for theta, x in islice(loader, 256): # 256 batches per epoch\n",
+    "        losses.append(step(loss(theta, x)))\n",
     "\n",
-    "        tq.set_postfix(loss=losses.mean().item())"
+    "    bar.set_postfix(loss=torch.stack(losses).mean().item())"
    ]
   },
   {
@@ -185,8 +181,10 @@
     "estimator.eval()\n",
     "\n",
     "with torch.no_grad():\n",
-    "    theta_0 = prior.sample((1024,)) # 1024 concurrent Markov chains\n",
-    "    log_p = lambda theta: estimator(theta, x_star) + prior.log_prob(theta) # p(theta | x) = r(theta, x) p(theta)\n",
+    "    # 1024 concurrent Markov chains\n",
+    "    theta_0 = prior.sample((1024,))\n",
+    "    # p(theta | x) = r(theta, x) p(theta)\n",
+    "    log_p = lambda theta: estimator(theta, x_star) + prior.log_prob(theta)\n",
     "\n",
     "    sampler = MetropolisHastings(theta_0, log_f=log_p, sigma=0.5)\n",
     "    samples = torch.cat(list(sampler(2048, burn=1024, step=4)))"
"outputs": [], "source": [ - "%matplotlib inline\n", - "\n", "import lampe\n", "import torch\n", "import zuko\n", diff --git a/lampe/__init__.py b/lampe/__init__.py index 0e54525..1bbd3bc 100644 --- a/lampe/__init__.py +++ b/lampe/__init__.py @@ -2,7 +2,7 @@ __version__ = '0.8.2' -from . import data -from . import inference -from . import nn -from . import utils +from . import data as data +from . import inference as inference +from . import nn as nn +from . import utils as utils diff --git a/lampe/data.py b/lampe/data.py index ff0df86..8431e3d 100644 --- a/lampe/data.py +++ b/lampe/data.py @@ -8,7 +8,7 @@ from numpy import ndarray as Array from pathlib import Path -from torch import Tensor, Size +from torch import Size, Tensor from torch.distributions import Distribution from torch.utils.data import DataLoader, Dataset, IterableDataset from tqdm import tqdm @@ -216,8 +216,7 @@ def __getitem__(self, i: Union[int, slice]) -> Tuple[Tensor, Tensor]: def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]: chunks = torch.tensor([ - (i, i + self.chunk_size) - for i in range(0, len(self), self.chunk_size) + (i, i + self.chunk_size) for i in range(0, len(self), self.chunk_size) ]) if self.shuffle: diff --git a/lampe/diagnostics.py b/lampe/diagnostics.py index c033c13..2529c0a 100644 --- a/lampe/diagnostics.py +++ b/lampe/diagnostics.py @@ -7,6 +7,7 @@ from tqdm import tqdm from typing import * +# isort: local from .utils import gridapply diff --git a/lampe/inference/__init__.py b/lampe/inference/__init__.py index e3e96fa..cbfb628 100644 --- a/lampe/inference/__init__.py +++ b/lampe/inference/__init__.py @@ -1,11 +1,9 @@ r"""Inference components such as estimators, training losses and MCMC samplers.""" -from .mcmc import * - -from .nre import * from .amnre import * from .bnre import * from .cnre import * - -from .npe import * from .fmpe import * +from .mcmc import * +from .npe import * +from .nre import * diff --git a/lampe/inference/amnre.py b/lampe/inference/amnre.py index 80f0add..46525db 100644 --- a/lampe/inference/amnre.py +++ b/lampe/inference/amnre.py @@ -35,12 +35,12 @@ import torch.nn as nn import torch.nn.functional as F -from torch import Tensor, BoolTensor +from torch import BoolTensor, Tensor from torch.distributions import Distribution from typing import * - from zuko.utils import broadcast +# isort: local from .nre import NRE diff --git a/lampe/inference/fmpe.py b/lampe/inference/fmpe.py index a398a12..9e32f6a 100644 --- a/lampe/inference/fmpe.py +++ b/lampe/inference/fmpe.py @@ -29,11 +29,11 @@ from torch import Tensor from torch.distributions import Distribution from typing import * - from zuko.distributions import DiagNormal, NormalizingFlow from zuko.transforms import FreeFormJacobianTransform from zuko.utils import broadcast +# isort: local from ..nn import MLP diff --git a/lampe/inference/mcmc.py b/lampe/inference/mcmc.py index 16056bd..dde947c 100644 --- a/lampe/inference/mcmc.py +++ b/lampe/inference/mcmc.py @@ -5,13 +5,11 @@ ] import torch -import torch.nn as nn from itertools import islice from torch import Tensor from torch.distributions import Distribution from typing import * - from zuko.distributions import DiagNormal @@ -72,9 +70,7 @@ def __init__( self.x_0 = x_0 - assert ( - f is not None or log_f is not None - ), "Either 'f' or 'log_f' has to be provided." + assert f is not None or log_f is not None, "Either 'f' or 'log_f' has to be provided." 
diff --git a/lampe/inference/npe.py b/lampe/inference/npe.py
index bef01be..307a803 100644
--- a/lampe/inference/npe.py
+++ b/lampe/inference/npe.py
@@ -26,13 +26,11 @@
     'NPELoss',
 ]
 
-import torch
 import torch.nn as nn
 
 from torch import Tensor
 from typing import *
-
-from zuko.flows import Flow, MAF
+from zuko.flows import MAF, Flow
 from zuko.utils import broadcast
diff --git a/lampe/inference/nre.py b/lampe/inference/nre.py
index 7268ed8..dc86c92 100644
--- a/lampe/inference/nre.py
+++ b/lampe/inference/nre.py
@@ -45,9 +45,9 @@
 
 from torch import Tensor
 from typing import *
-
 from zuko.utils import broadcast
 
+# isort: local
 from ..nn import MLP
diff --git a/lampe/masks.py b/lampe/masks.py
index 0e3b366..ce0f0c8 100644
--- a/lampe/masks.py
+++ b/lampe/masks.py
@@ -1,10 +1,9 @@
 r"""Masking helpers."""
 
 import torch
-import torch.nn as nn
 
-from torch import Tensor, BoolTensor, Size
-from torch.distributions import Distribution, Bernoulli, Independent
+from torch import BoolTensor, Size, Tensor
+from torch.distributions import Bernoulli, Distribution, Independent
 from typing import *
diff --git a/lampe/nn.py b/lampe/nn.py
index 41f98a8..1e6f90d 100644
--- a/lampe/nn.py
+++ b/lampe/nn.py
@@ -2,7 +2,6 @@
 
 __all__ = ['MLP', 'ResMLP']
 
-import torch
 import torch.nn as nn
 
 from torch import Tensor
diff --git a/lampe/plots.py b/lampe/plots.py
index 04abdcb..3589e0f 100644
--- a/lampe/plots.py
+++ b/lampe/plots.py
@@ -68,7 +68,7 @@ def __new__(
         name: str = None,
     ):
         if name is None:
-            if type(color) is str:
+            if isinstance(color, str):
                 name = f'alpha_{color}'
             else:
                 name = f'alpha_{hash(color)}'
@@ -184,7 +184,7 @@ def corner(
         data = np.asarray(data)
         D = data.shape[-1]
 
-    if type(bins) is int:
+    if isinstance(bins, int):
         bins = [bins] * D
 
     if domain is None:
@@ -193,8 +193,7 @@ def corner(
         lower, upper = map(np.asarray, domain)
 
     bins = [
-        np.histogram_bin_edges(data, bins[i], range=(lower[i], upper[i]))
-        for i in range(D)
+        np.histogram_bin_edges(data, bins[i], range=(lower[i], upper[i])) for i in range(D)
     ]
 
     hists = np.ndarray((D, D), dtype=object)
diff --git a/lampe/utils.py b/lampe/utils.py
index 0fdc1de..b693d42 100644
--- a/lampe/utils.py
+++ b/lampe/utils.py
@@ -28,13 +28,9 @@ class GDStep(object):
     """
 
     def __init__(self, optimizer: Optimizer, clip: float = None):
         self.optimizer = optimizer
-        self.parameters = [
-            p
-            for group in optimizer.param_groups
-            for p in group['params']
-        ]
+        self.parameters = [p for group in optimizer.param_groups for p in group['params']]
         self.clip = clip
@@ -86,7 +82,7 @@ def gridapply(
     # Shape
     dims = len(lower)
 
-    if type(bins) is int:
+    if isinstance(bins, int):
         bins = [bins] * dims
 
     # Create grid
diff --git a/pyproject.toml b/pyproject.toml
index 821feef..43b396f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,11 +30,36 @@ keywords = [
 readme = "README.md"
 requires-python = ">=3.8"
 
+[project.optional-dependencies]
+dev = [
+    "pytest",
+    "ruff",
+]
+
 [project.urls]
 documentation = "https://lampe.readthedocs.io"
 source = "https://github.com/probabilists/lampe"
 tracker = "https://github.com/probabilists/lampe/issues"
 
+[tool.ruff]
+extend-include = ["*.ipynb"]
+extend-select = ["I"]
+line-length = 99
+
+[tool.ruff.lint]
+ignore = ["F403", "F405", "E731", "E741"]
+ignore-init-module-imports = true
+
+[tool.ruff.lint.isort]
+lines-between-types = 1
+no-sections = true
+relative-imports-order = "closest-to-furthest"
+
+[tool.ruff.format]
+exclude = ["*.ipynb"]
+preview = true
+quote-style = "preserve"
+
 [tool.setuptools.dynamic]
 dependencies = {file = "requirements.txt"}
 version = {attr = "lampe.__version__"}
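Note: the isort settings above are what produce the import layout seen throughout this patch: `no-sections = true` merges standard-library, third-party, and first-party imports into a single alphabetized block, `lines-between-types = 1` separates plain `import x` statements from `from x import y` statements with one blank line, and the `# isort: local` comments mark off the relative imports as a separate trailing block. For instance, the import header of `lampe/inference/amnre.py` now reads roughly as follows (assembled from the hunks above; surrounding lines may differ):

```python
import torch.nn as nn
import torch.nn.functional as F

from torch import BoolTensor, Tensor
from torch.distributions import Distribution
from typing import *
from zuko.utils import broadcast

# isort: local
from .nre import NRE
```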
"preserve" + [tool.setuptools.dynamic] dependencies = {file = "requirements.txt"} version = {attr = "lampe.__version__"} diff --git a/tests/test_data.py b/tests/test_data.py index abc08ec..fa91c8c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,6 +1,5 @@ r"""Tests for the lampe.data module.""" -import h5py import numpy as np import pytest import torch diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py index d805e15..7a38030 100644 --- a/tests/test_diagnostics.py +++ b/tests/test_diagnostics.py @@ -1,6 +1,5 @@ r"""Tests for the lampe.diagnostics module.""" -import pytest import torch from lampe.diagnostics import * @@ -46,17 +45,13 @@ def test_expected_coverage_ni(): assert torch.allclose(levels, coverages, atol=1e-1) # Conservative - estimator = lambda theta, x: Independent( - Truncated(Normal(0, 2 + x**2), -3, 3), 1 - ).log_prob(theta) + estimator = lambda theta, x: Truncated(Normal(0, 2 + x**2), -3, 3).log_prob(theta).sum(-1) levels, coverages = expected_coverage_ni(estimator, pairs, domain, bins=128) assert (coverages > levels).float().mean() > 0.9 # Overconfident - estimator = lambda theta, x: Independent( - Truncated(Normal(0, 0.5 + x**2), -3, 3), 1 - ).log_prob(theta) + estimator = lambda theta, x: Truncated(Normal(0, 0.5 + x**2), -3, 3).log_prob(theta).sum(-1) levels, coverages = expected_coverage_ni(estimator, pairs, domain, bins=128) assert (coverages < levels).float().mean() > 0.9 diff --git a/tests/test_inference.py b/tests/test_inference.py index 31a6fe4..97e626d 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -1,6 +1,5 @@ r"""Tests for the lampe.inference module.""" -import pytest import torch from lampe.inference import * diff --git a/tests/test_masks.py b/tests/test_masks.py index 686734c..4d0b3b8 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -1,6 +1,5 @@ r"""Tests for the lampe.masks module.""" -import pytest import torch from lampe.masks import * diff --git a/tests/test_nn.py b/tests/test_nn.py index e402d08..924069a 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,8 +1,5 @@ r"""Tests for the lampe.nn module.""" -import pytest -import torch - from lampe.nn import * from torch import randn diff --git a/tests/test_plots.py b/tests/test_plots.py index 2a20d34..8972fc8 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -1,13 +1,12 @@ r"""Tests for the lampe.plots module.""" -import pytest import numpy as np from lampe.plots import * def test_nice_rc(): - assert type(nice_rc()) is dict + assert isinstance(nice_rc(), dict) def test_corner(): @@ -38,4 +37,4 @@ def test_corner(): def test_coverage_plot(): levels = np.random.rand(512) ** 2 coverages = np.linspace(0, 1, 512) - figure = coverage_plot(levels, coverages) + figure = coverage_plot(levels, coverages) # noqa: F841 diff --git a/tests/test_utils.py b/tests/test_utils.py index 6f9b404..8e08bbd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,5 @@ r"""Tests for the lampe.utils module.""" -import math import pytest import torch