post hoc ema from https://arxiv.org/pdf/2312.02696.pdf #17

Merged · 8 commits · Feb 8, 2024
README.md (45 changes: 43 additions & 2 deletions)
@@ -49,9 +49,50 @@ ema_output = ema(data)
# however, if you wish to access the copy of your model with EMA, then it will live at ema.ema_model
```

## Todo

- [ ] address the issue of annealing EMA to 1 near the end of training for BYOL https://github.com/lucidrains/byol-pytorch/issues/82

In order to use the post-hoc synthesized EMA proposed by Karras et al. in <a href="https://arxiv.org/abs/2312.02696">a recent paper</a>, follow the example below.

```python
import torch
from ema_pytorch import PostHocEMA

# your neural network as a pytorch module

net = torch.nn.Linear(512, 512)

# wrap your neural network, specifying the relative standard deviations (sigma_rels) of the EMAs to maintain

ema = PostHocEMA(
    net,
    sigma_rels = (0.05, 0.3),         # a tuple with the hyperparameter for the multiple EMAs. you need at least 2 here to synthesize a new one
    update_every = 10,                # how often to actually update, to save on compute (updates every 10th .update() call)
    checkpoint_every_num_steps = 10   # how often to checkpoint the maintained EMAs, needed to synthesize new ones later
)

net.train()

for _ in range(1000):
    # mutate your network, with SGD or otherwise

    with torch.no_grad():
        net.weight.copy_(torch.randn_like(net.weight))
        net.bias.copy_(torch.randn_like(net.bias))

    # you will call the update function on your moving average wrapper

    ema.update()

# now that you have a few checkpoints
# you can synthesize an EMA model with a different sigma_rel (say 0.15)

synthesized_ema_model = ema.synthesize_ema_model(sigma_rel = 0.15)

# output with synthesized EMA

data = torch.randn(1, 512)

synthesized_ema_output = synthesized_ema_model(data)
```
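
Conceptually, `synthesize_ema_model` reconstructs the requested EMA profile as a weighted combination of the stored checkpoints, using the least-squares recipe from the appendix of the Karras et al. paper. Below is a minimal numpy sketch of that weight-solving step, following the formulas in the paper; the helper names are illustrative and not necessarily the internals of `ema_pytorch/post_hoc_ema.py`.

```python
import numpy as np

# map the relative std dev (sigma_rel) of a power-function EMA profile to its exponent gamma,
# i.e. solve sigma_rel ** -2 * (gamma + 1) = (gamma + 2) ** 2 * (gamma + 3)
def sigma_rel_to_gamma(sigma_rel):
    t = sigma_rel ** -2
    return np.roots([1, 7, 16 - t, 12 - t]).real.max()

# inner product between two power-function EMA profiles p(x) proportional to x ** gamma on [0, t]
def p_dot_p(t_a, gamma_a, t_b, gamma_b):
    t_ratio = t_a / t_b
    t_exp = np.where(t_a < t_b, gamma_b, -gamma_a)
    t_max = np.maximum(t_a, t_b)
    num = (gamma_a + 1) * (gamma_b + 1) * t_ratio ** t_exp
    den = (gamma_a + gamma_b + 1) * t_max
    return num / den

# least-squares weights w so that sum_i w_i * checkpoint_i best approximates the target profile
def solve_weights(t_i, gamma_i, t_r, gamma_r):
    rv = lambda x: np.asarray(x, dtype = float).reshape(-1, 1)   # checkpoints as a column
    cv = lambda x: np.asarray(x, dtype = float).reshape(1, -1)   # target as a row
    A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))      # checkpoint-checkpoint Gram matrix
    b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))      # checkpoint-target inner products
    return np.linalg.solve(A, b)

# toy numbers matching the example above: two EMAs (sigma_rel 0.05 and 0.3),
# checkpointed every 10 steps up to step 100, synthesized into sigma_rel = 0.15 at step 100
gammas = [sigma_rel_to_gamma(s) for s in (0.05, 0.3)]
ckpt_times  = [t for t in range(10, 101, 10) for _ in gammas]
ckpt_gammas = gammas * 10
weights = solve_weights(ckpt_times, ckpt_gammas, 100., sigma_rel_to_gamma(0.15))
print(weights.shape)  # (20, 1) - one weight per stored checkpoint
```

The synthesized model's parameters are then just the weighted sum of the checkpointed EMA parameters under these weights.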

## Citations

ema_pytorch/__init__.py (5 changes: 5 additions & 0 deletions)
@@ -1 +1,6 @@
from ema_pytorch.ema_pytorch import EMA

from ema_pytorch.post_hoc_ema import (
    KarrasEMA,
    PostHocEMA
)
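
As a quick sanity check (assuming the package is installed from this branch), the two new classes become importable from the package root alongside the existing `EMA`:

```python
# smoke test for the top-level exports added in this pull request
from ema_pytorch import EMA, KarrasEMA, PostHocEMA

for cls in (EMA, KarrasEMA, PostHocEMA):
    print(cls.__name__)
```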
ema_pytorch/ema_pytorch.py (10 changes: 1 addition & 9 deletions)
@@ -66,8 +66,7 @@ def __init__(
         allow_different_devices = False # if the EMA model is on a different device (say CPU), automatically move the tensor
     ):
         super().__init__()
-        self._beta = beta
-        self.karras_beta = karras_beta
+        self.beta = beta
 
         self.is_frozen = beta == 1.
 
@@ -131,13 +130,6 @@ def __init__(
     @property
     def model(self):
         return self.online_model if self.include_online_model else self.online_model[0]
-
-    @property
-    def beta(self):
-        if self.karras_beta:
-            return (1 - 1 / (self.step + 1)) ** (1 + self.power)
-
-        return self._beta
 
     def eval(self):
         return self.ema_model.eval()
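
For context, the `beta` property removed above computed the Karras-style decay schedule beta_t = (1 - 1/t) ** (gamma + 1) from the paper, with `power` playing the role of gamma; presumably that behaviour now lives in the new `KarrasEMA` class. A standalone sketch of the schedule (the helper name is illustrative, not part of the library API):

```python
# minimal sketch (not library code) of the power-function EMA decay from Karras et al. (arXiv 2312.02696):
# beta_t = (1 - 1 / t) ** (gamma + 1), so the first step copies the online weights (beta = 0)
# and the effective averaging window keeps growing as training proceeds

def karras_beta(step: int, gamma: float = 6.94) -> float:
    # gamma = 6.94 corresponds to sigma_rel of roughly 0.10 in the paper's parameterization
    t = step + 1  # mirrors the removed property's (self.step + 1)
    return (1. - 1. / t) ** (gamma + 1.)

for step in (0, 10, 1_000, 100_000):
    print(step, round(karras_beta(step), 6))
```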