🎨 Lint and format code with Ruff

francois-rozet committed Jan 29, 2024
1 parent 338046f commit 07ad235
Showing 29 changed files with 143 additions and 166 deletions.
26 changes: 13 additions & 13 deletions CONTRIBUTING.md
@@ -18,30 +18,30 @@ Before submitting any issue, please perform a thorough search to see if your pro

If you like the project and wish to contribute, you can start by looking at issues labeled `good first issue` (should only require a few lines of code) or `help wanted` (more involved). If you found a bug and want to fix it, please create an issue reporting the bug before creating a pull request. Similarly, if you want to add a new feature, first create a feature request issue. This allows to separate the discussions related to the bug/feature, from the discussions related to the fix/implementation.

-### Code conventions
+### Testing

-We mostly follow the [PEP 8](https://peps.python.org/pep-0008/) style guide for Python code. It is recommended that you format your code with the opinionated [Black](https://github.com/psf/black) formatter. For example, if you created or modified a file `path/to/filename.py`, you can reformat it with
+We use [pytest](https://docs.pytest.org) to test our code base. If your contribution introduces new components, you should write new tests to make sure your code doesn't crash under normal circumstances. After installing `pytest`, add the tests to the [tests/](tests) directory and run them with

```
-black -S path/to/filename.py
+pytest tests
```

-Additionally, please follow these rules:
-
-* Use single quotes for strings (`'single-quoted'`) but double quotes (`"double-quoted"`) for text such as error messages.
-* Use informative but concise variable names. Single-letter names are fine if the context is clear.
-* Avoid explaining code with comments. If something is hard to understand, simplify or decompose it.
-* If Black's output [takes too much vertical space](https://github.com/psf/black/issues/1811), ignore its modifications.
+It is also recommended to ensure your code works as expected within toy experiments similar to the [tutorials](docs/tutorials). When you submit a pull request, tests are automatically (upon approval) executed for several versions of Python and PyTorch.

-### Testing
+### Code conventions

-We use [pytest](https://docs.pytest.org) to test our code base. If your contribution introduces new components, you should write new tests to make sure your code doesn't crash under normal circumstances. After installing `pytest`, add the tests to the [tests/](tests) directory and run them with
+We use [Ruff](https://github.com/astral-sh/ruff) to lint and format all Python code. After installing `ruff`, you can check if your code follows our conventions with

```
-pytest tests
+ruff check .
+ruff format --check .
```

-It is also recommended to ensure your code works as expected within toy experiments similar to the [tutorials](docs/tutorials). When you submit a pull request, tests are automatically (upon approval) executed for several versions of Python and PyTorch.
+Additionally, please follow these rules:
+
+* Use single quotes for strings (`'single-quoted'`) but double quotes (`"double-quoted"`) for text such as error messages.
+* Use informative but concise variable names. Single-letter names are fine if the context is clear.
+* Avoid explaining code with comments. If something is hard to understand, simplify or decompose it.

### Documentation

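Note: the Ruff configuration itself is not part of the diff shown here, but the conventions above and the import reordering in the notebook diffs below are consistent with isort-style lint rules and a single-quote formatter preference. A hypothetical `pyproject.toml` sketch of such settings, not the project's actual configuration:

```
[tool.ruff.lint]
# "I" enables isort-style import sorting, which would explain the
# reordered imports in the notebooks below; E and F are pycodestyle
# and Pyflakes rules.
select = ["E", "F", "I"]

[tool.ruff.format]
# Matches the single-quote rule kept in CONTRIBUTING.md above.
quote-style = "single"
```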
16 changes: 10 additions & 6 deletions docs/conf.py
@@ -1,8 +1,8 @@
# Configuration file for the Sphinx documentation builder

import glob
-import inspect
import importlib
+import inspect
import lampe
import re
import subprocess
@@ -59,22 +59,22 @@ def linkcode_resolve(domain: str, info: dict) -> str:
    for name in fullname.split('.'):
        objct = getattr(objct, name)

+    if hasattr(objct, '__wrapped__'):
+        objct = objct.__wrapped__
+
    try:
        file = inspect.getsourcefile(objct)
        file = file[file.rindex(package) :]

        lines, start = inspect.getsourcelines(objct)
        end = start + len(lines) - 1
-    except Exception as e:
+    except Exception:
        return None
    else:
        return f'{repository}/blob/{commit}/{file}#L{start}-L{end}'


-napoleon_custom_sections = [
-    ('Shapes', 'params_style'),
-    'Wikipedia',
-]
+napoleon_custom_sections = ['Wikipedia']

nb_execution_mode = 'off'
myst_enable_extensions = ['dollarmath']
@@ -129,6 +129,7 @@ def linkcode_resolve(domain: str, info: dict) -> str:

## Edit HTML

+
def edit_html(app, exception):
    if exception:
        raise exception
@@ -137,12 +138,15 @@ def edit_html(app, exception):
    with open(file, 'r') as f:
        text = f.read()

+    # fmt: off
    text = text.replace('<a class="muted-link" href="https://pradyunsg.me">@pradyunsg</a>\'s', '')
    text = text.replace('<span class="pre">[source]</span>', '<i class="fa-solid fa-code"></i>')
    text = re.sub(r'(<a class="reference external".*</a>)(<a class="headerlink".*</a>)', r'\2\1', text)
+    # fmt: on

    with open(file, 'w') as f:
        f.write(text)

+
def setup(app):
    app.connect('build-finished', edit_html)
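Note on the `__wrapped__` check added to `linkcode_resolve` above: for a decorated function, the code object that `inspect` sees belongs to the wrapper, so generated source links can point inside the decorator instead of at the actual definition. A minimal sketch of the behavior, where the `timed` decorator is hypothetical and only for illustration:

```
import functools
import inspect


def timed(f):
    @functools.wraps(f)  # records the original function as wrapper.__wrapped__
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    return wrapper


@timed
def g():
    pass


# Unwrapping before inspection, as conf.py now does, targets g's own definition
# rather than the wrapper's code inside timed().
objct = g
if hasattr(objct, '__wrapped__'):
    objct = objct.__wrapped__

lines, start = inspect.getsourcelines(objct)
print(start)  # the line where `def g` appears
```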
31 changes: 14 additions & 17 deletions docs/tutorials/coverage.ipynb
@@ -15,21 +15,18 @@
"metadata": {},
"outputs": [],
"source": [
-"%matplotlib inline\n",
-"\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import zuko\n",
"\n",
-"from tqdm import tqdm\n",
-"\n",
-"from lampe.data import JointLoader, H5Dataset\n",
+"from lampe.data import H5Dataset, JointLoader\n",
"from lampe.diagnostics import expected_coverage_mc, expected_coverage_ni\n",
-"from lampe.inference import NPE, NRE, FMPE, NPELoss, NRELoss, BNRELoss, FMPELoss\n",
-"from lampe.plots import nice_rc, coverage_plot\n",
-"from lampe.utils import GDStep"
+"from lampe.inference import FMPE, NPE, NRE, BNRELoss, FMPELoss, NPELoss, NRELoss\n",
+"from lampe.plots import coverage_plot, nice_rc\n",
+"from lampe.utils import GDStep\n",
+"from tqdm import trange"
]
},
{
@@ -147,20 +144,20 @@
"def train(loss: nn.Module, epochs: int = 256) -> None:\n",
"    loss.cuda().train()\n",
"\n",
-"    optimizer = optim.AdamW(loss.parameters(), lr=1e-3)\n",
+"    optimizer = optim.Adam(loss.parameters(), lr=1e-3)\n",
"    step = GDStep(optimizer, clip=1.0)\n",
"\n",
-"    with tqdm(range(epochs), unit='epoch') as tq:\n",
-"        for epoch in tq:\n",
-"            losses = torch.stack([\n",
-"                step(loss(theta.cuda(), x.cuda()))\n",
-"                for theta, x in trainset\n",
-"            ])\n",
+"    for epoch in (bar := trange(epochs, unit='epoch')):\n",
+"        losses = []\n",
"\n",
-"            tq.set_postfix(loss=losses.mean().item())\n",
+"        for theta, x in trainset:\n",
+"            losses.append(step(loss(theta.cuda(), x.cuda())))\n",
+"\n",
+"        bar.set_postfix(loss=torch.stack(losses).mean().item())\n",
"\n",
"    loss.eval()\n",
"\n",
+"\n",
"train(NPELoss(npe))\n",
"train(NRELoss(nre))\n",
"train(BNRELoss(bnre))\n",
@@ -199,7 +196,7 @@
"\n",
"npe_levels, npe_coverages = expected_coverage_mc(npe.flow, testset, device='cuda')\n",
"\n",
-"log_p = lambda theta, x: nre(theta, x) + prior.log_prob(theta)  # log p(theta | x) = log r(theta, x) + log p(theta)\n",
+"log_p = lambda theta, x: nre(theta, x) + prior.log_prob(theta)\n",
"nre_levels, nre_coverages = expected_coverage_ni(log_p, testset, (LOWER, UPPER), device='cuda')\n",
"\n",
"log_p = lambda theta, x: bnre(theta, x) + prior.log_prob(theta)\n",
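The inline comment dropped from the first `log_p` lambda above is worth preserving: an NRE network approximates the log-likelihood-to-evidence ratio, $ \log r(\theta, x) \approx \log p(x | \theta) - \log p(x) $, so by Bayes' rule $ \log p(\theta | x) = \log r(\theta, x) + \log p(\theta) $, which is exactly the quantity the lambda passes to `expected_coverage_ni`.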
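The same loop refactor recurs in every notebook of this commit: the `with tqdm(...) as tq:` block and `torch.stack([...])` comprehension become a walrus-assigned `trange` bar with an explicit list of per-batch losses. A self-contained sketch of the pattern, using a stand-in quadratic objective instead of lampe's estimators; the explicit `zero_grad`/`backward`/`clip`/`step` sequence is an assumption about what `lampe.utils.GDStep` does internally:

```
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange

w = nn.Parameter(torch.zeros(()))  # stand-in for an estimator's parameters
optimizer = optim.Adam([w], lr=1e-1)

for epoch in (bar := trange(64, unit='epoch')):  # the walrus keeps a handle on the bar
    losses = []

    for _ in range(16):  # stand-in for iterating over (theta, x) batches
        loss = (w - 3.0) ** 2  # stand-in objective
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_([w], 1.0)  # presumably what GDStep's clip=1.0 applies
        optimizer.step()
        losses.append(loss.detach())

    bar.set_postfix(loss=torch.stack(losses).mean().item())
```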
41 changes: 18 additions & 23 deletions docs/tutorials/embedding.ipynb
@@ -15,22 +15,19 @@
"metadata": {},
"outputs": [],
"source": [
-"%matplotlib inline\n",
-"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import zuko\n",
"\n",
-"from tqdm import tqdm\n",
-"\n",
-"from lampe.data import JointLoader, H5Dataset\n",
+"from lampe.data import H5Dataset, JointLoader\n",
"from lampe.diagnostics import expected_coverage_mc\n",
"from lampe.inference import NPE, NPELoss\n",
-"from lampe.plots import nice_rc, corner, mark_point, coverage_plot\n",
-"from lampe.utils import GDStep"
+"from lampe.plots import corner, coverage_plot, mark_point, nice_rc\n",
+"from lampe.utils import GDStep\n",
+"from tqdm import trange"
]
},
{
@@ -73,6 +70,7 @@
"\n",
"prior = zuko.distributions.BoxUniform(LOWER, UPPER)\n",
"\n",
+"\n",
"class WorleyNoise:\n",
"    def __init__(self):\n",
"        self.domain = torch.cartesian_prod(\n",
@@ -97,6 +95,7 @@
"\n",
"        return x.reshape(64, 64)\n",
"\n",
+"\n",
"simulator = WorleyNoise()\n",
"\n",
"theta = prior.sample()\n",
@@ -191,6 +190,7 @@
"    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
"        return x + super().forward(x)\n",
"\n",
+"\n",
"class NPEWithEmbedding(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
@@ -247,27 +247,22 @@
"source": [
"estimator = NPEWithEmbedding().cuda()\n",
"loss = NPELoss(estimator)\n",
-"optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)\n",
+"optimizer = optim.Adam(estimator.parameters(), lr=1e-3)\n",
"step = GDStep(optimizer, clip=1.0)\n",
"\n",
-"with tqdm(range(32), unit='epoch') as tq:\n",
-"    for epoch in tq:\n",
-"        estimator.train()\n",
-"\n",
-"        losses = torch.stack([\n",
-"            step(loss(preprocess(theta).cuda(), x.cuda()))\n",
-"            for theta, x in trainset\n",
-"        ])\n",
+"for epoch in (bar := trange(32, unit='epoch')):\n",
+"    losses, val_losses = [], []\n",
"\n",
-"        estimator.eval()\n",
+"    for theta, x in trainset:\n",
+"        losses.append(step(loss(preprocess(theta).cuda(), x.cuda())))\n",
"\n",
"    with torch.no_grad():\n",
-"            val_losses = torch.stack([\n",
-"                loss(preprocess(theta).cuda(), x.cuda())\n",
-"                for theta, x in validset\n",
-"            ])\n",
+"        for theta, x in validset:\n",
+"            val_losses.append(loss(preprocess(theta).cuda(), x.cuda()))\n",
"\n",
-"        tq.set_postfix(loss=losses.mean().item(), val_loss=val_losses.mean().item())"
+"    bar.set_postfix(\n",
+"        loss=torch.stack(losses).mean().item(),\n",
+"        val_loss=torch.stack(val_losses).mean().item(),\n",
+"    )"
]
},
{
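The embedding loop above additionally tracks a validation loss, evaluated under `torch.no_grad()` so no autograd graph is built. A toy sketch of that train/validation epoch structure, with a linear model and random data standing in for `NPEWithEmbedding` and the HDF5 datasets (the explicit `train()`/`eval()` toggles are kept here for clarity, although the refactored cell appears to drop them):

```
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange

model = nn.Linear(4, 1)  # toy stand-in for NPEWithEmbedding
optimizer = optim.Adam(model.parameters(), lr=1e-3)
trainset = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(32)]
validset = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(8)]

for epoch in (bar := trange(4, unit='epoch')):
    losses, val_losses = [], []

    model.train()

    for x, y in trainset:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach())

    model.eval()

    with torch.no_grad():  # validation needs no gradients
        for x, y in validset:
            val_losses.append(nn.functional.mse_loss(model(x), y))

    bar.set_postfix(
        loss=torch.stack(losses).mean().item(),
        val_loss=torch.stack(val_losses).mean().item(),
    )
```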
24 changes: 9 additions & 15 deletions docs/tutorials/fmpe.ipynb
@@ -15,21 +15,18 @@
"metadata": {},
"outputs": [],
"source": [
-"%matplotlib inline\n",
-"\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import zuko\n",
"\n",
"from itertools import islice\n",
-"from tqdm import tqdm\n",
-"\n",
"from lampe.data import JointLoader\n",
"from lampe.inference import FMPE, FMPELoss\n",
-"from lampe.plots import nice_rc, corner, mark_point\n",
-"from lampe.utils import GDStep"
+"from lampe.plots import corner, mark_point, nice_rc\n",
+"from lampe.utils import GDStep\n",
+"from tqdm import trange"
]
},
{
@@ -149,22 +146,19 @@
],
"source": [
"loss = FMPELoss(estimator)\n",
-"optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)\n",
+"optimizer = optim.Adam(estimator.parameters(), lr=1e-3)\n",
"scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 128)\n",
"step = GDStep(optimizer, clip=1.0)  # gradient descent step with gradient clipping\n",
"\n",
"estimator.train()\n",
"\n",
-"with tqdm(range(128), unit='epoch') as tq:\n",
-"    for epoch in tq:\n",
-"        losses = torch.stack([\n",
-"            step(loss(theta, x))\n",
-"            for theta, x in islice(loader, 256)  # 256 batches per epoch\n",
-"        ])\n",
+"for epoch in (bar := trange(128, unit='epoch')):\n",
+"    losses = []\n",
"\n",
-"        tq.set_postfix(loss=losses.mean().item())\n",
+"    for theta, x in islice(loader, 256):  # 256 batches per epoch\n",
+"        losses.append(step(loss(theta, x)))\n",
"\n",
"    scheduler.step()\n",
+"    bar.set_postfix(loss=torch.stack(losses).mean().item())"
]
},
{
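Unlike the other tutorials, the FMPE notebook pairs the optimizer with `CosineAnnealingLR`, stepped once per epoch; with `T_max` equal to the number of epochs, the learning rate follows half a cosine from its initial value down to `eta_min` (0 by default) over the whole run. A minimal sketch of that schedule in isolation, with a hypothetical toy objective:

```
import torch
import torch.nn as nn
import torch.optim as optim

w = nn.Parameter(torch.zeros(()))
optimizer = optim.Adam([w], lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=8)

for epoch in range(8):
    loss = (w - 1.0) ** 2  # stand-in for a full epoch of batches
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # one scheduler step per epoch, as in the notebook

    print(epoch, optimizer.param_groups[0]['lr'])  # decays from 1e-3 towards 0
```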
29 changes: 12 additions & 17 deletions docs/tutorials/npe.ipynb
@@ -15,21 +15,17 @@
"metadata": {},
"outputs": [],
"source": [
-"%matplotlib inline\n",
-"\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import zuko\n",
"\n",
"from itertools import islice\n",
-"from tqdm import tqdm\n",
-"\n",
"from lampe.data import JointLoader\n",
"from lampe.inference import NPE, NPELoss\n",
-"from lampe.plots import nice_rc, corner, mark_point\n",
-"from lampe.utils import GDStep"
+"from lampe.plots import corner, mark_point, nice_rc\n",
+"from lampe.utils import GDStep\n",
+"from tqdm import trange"
]
},
{
@@ -162,7 +158,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-"In the case of NPE, the loss to minimize is the expected log-likelihood of the data $ \mathbb{E}_{p(\theta, x)} \big[ -\log p_\phi(\theta | x) \big] $, which is easy to implement ourselves."
+"In the case of NPE, the loss to minimize is the expected negative log-likelihood of the data $ \mathbb{E}_{p(\theta, x)} \big[ -\log p_\phi(\theta | x) \big] $, which is easy to implement ourselves."
]
},
{
@@ -171,7 +167,7 @@
"metadata": {},
"outputs": [],
"source": [
-"def loss(theta, x):\n",
+"def nll(theta, x):\n",
"    log_p = estimator(theta, x)  # log p(theta | x)\n",
"    return -log_p.mean()"
@@ -213,19 +209,18 @@
}
],
"source": [
-"optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)\n",
+"optimizer = optim.Adam(estimator.parameters(), lr=1e-3)\n",
"step = GDStep(optimizer, clip=1.0)  # gradient descent step with gradient clipping\n",
"\n",
"estimator.train()\n",
"\n",
-"with tqdm(range(64), unit='epoch') as tq:\n",
-"    for epoch in tq:\n",
-"        losses = torch.stack([\n",
-"            step(loss(theta, x))\n",
-"            for theta, x in islice(loader, 256)  # 256 batches per epoch\n",
-"        ])\n",
+"for epoch in (bar := trange(64, unit='epoch')):\n",
+"    losses = []\n",
"\n",
+"    for theta, x in islice(loader, 256):  # 256 batches per epoch\n",
+"        losses.append(step(loss(theta, x)))\n",
"\n",
-"        tq.set_postfix(loss=losses.mean().item())"
+"    bar.set_postfix(loss=torch.stack(losses).mean().item())"
]
},
{
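The rename from `def loss(...)` to `def nll(...)` above avoids shadowing: `loss` in the training cell presumably refers to an `NPELoss(estimator)` instance defined in a cell not shown in this diff. The hand-written version is just the batch-averaged negative log-posterior; a sketch with a hypothetical stand-in estimator:

```
import torch


def estimator(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a flow's log-posterior log p(theta | x).
    return -(theta - x).square().sum(dim=-1)


def nll(theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    log_p = estimator(theta, x)  # log p(theta | x)
    return -log_p.mean()  # expectation over the batch


theta, x = torch.randn(16, 3), torch.randn(16, 3)
print(nll(theta, x))  # lampe's NPELoss(estimator) should reduce to the same quantity
```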