add ability to use foreach
lucidrains committed Jun 14, 2024
1 parent 60c3b9e commit a1b3342
Showing 2 changed files with 52 additions and 12 deletions.
62 changes: 51 additions & 11 deletions ema_pytorch/ema_pytorch.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations

from copy import deepcopy
from functools import partial

@@ -6,14 +8,11 @@
from torch.nn import Module

from beartype import beartype
- from beartype.typing import Set, Optional
+ from beartype.typing import Set

def exists(val):
return val is not None

- def get_module_device(m: Module):
- return next(m.parameters()).device

def inplace_copy(tgt: Tensor, src: Tensor, *, auto_move_device = False):
if auto_move_device:
src = src.to(tgt.device)
@@ -51,7 +50,7 @@ class EMA(Module):
def __init__(
self,
model: Module,
- ema_model: Optional[Module] = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
+ ema_model: Module | None = None, # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model
beta = 0.9999,
update_after_step = 100,
update_every = 10,
@@ -62,7 +61,8 @@ def __init__(
ignore_names: Set[str] = set(),
ignore_startswith_names: Set[str] = set(),
include_online_model = True, # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally)
- allow_different_devices = False # if the EMA model is on a different device (say CPU), automatically move the tensor
+ allow_different_devices = False, # if the EMA model is on a different device (say CPU), automatically move the tensor
+ use_foreach = False
):
super().__init__()
self.beta = beta
Expand Down Expand Up @@ -122,6 +122,13 @@ def __init__(

self.allow_different_devices = allow_different_devices

+ # whether to use foreach

+ if use_foreach:
+ assert hasattr(torch, '_foreach_lerp_') and hasattr(torch, '_foreach_copy_'), 'your version of torch does not have the prerequisite foreach functions'

+ self.use_foreach = use_foreach

# init and step states

self.register_buffer('initted', torch.tensor(False))
@@ -199,9 +206,15 @@ def update_moving_average(self, ma_model, current_model):
if self.is_frozen:
return

- copy, lerp = self.inplace_copy, self.inplace_lerp
current_decay = self.get_current_decay()

+ # store all source and target tensors to copy or lerp

+ tensors_to_copy = []
+ tensors_to_lerp = []

+ # loop through parameters

for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), self.get_params_iter(ma_model)):
if name in self.ignore_names:
continue
@@ -210,10 +223,12 @@ def update_moving_average(self, ma_model, current_model):
continue

if name in self.param_or_buffer_names_no_ema:
- copy(ma_params.data, current_params.data)
+ tensors_to_copy.append((ma_params.data, current_params.data))
continue

- lerp(ma_params.data, current_params.data, 1. - current_decay)
+ tensors_to_lerp.append((ma_params.data, current_params.data))

+ # loop through buffers

for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), self.get_buffers_iter(ma_model)):
if name in self.ignore_names:
@@ -223,10 +238,35 @@ def update_moving_average(self, ma_model, current_model):
continue

if name in self.param_or_buffer_names_no_ema:
- copy(ma_buffer.data, current_buffer.data)
+ tensors_to_copy.append((ma_buffer.data, current_buffer.data))
continue

- lerp(ma_buffer.data, current_buffer.data, 1. - current_decay)
+ tensors_to_lerp.append((ma_buffer.data, current_buffer.data))

+ # execute inplace copy or lerp

+ if not self.use_foreach:

+ for tgt, src in tensors_to_copy:
+ self.inplace_copy(tgt, src)

+ for tgt, src in tensors_to_lerp:
+ self.inplace_lerp(tgt, src, 1. - current_decay)

+ else:
+ # use foreach if available and specified

+ if self.allow_different_devices:
+ tensors_to_copy = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_copy]
+ tensors_to_lerp = [(tgt, src.to(tgt.device)) for tgt, src in tensors_to_lerp]

+ if len(tensors_to_copy) > 0:
+ tgt_copy, src_copy = zip(*tensors_to_copy)
+ torch._foreach_copy_(tgt_copy, src_copy)

+ if len(tensors_to_lerp) > 0:
+ tgt_lerp, src_lerp = zip(*tensors_to_lerp)
+ torch._foreach_lerp_(tgt_lerp, src_lerp, 1. - current_decay)

def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
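To make the batched path concrete: torch._foreach_lerp_ and torch._foreach_copy_ apply a single fused in-place operation across a whole list of tensors, so the new branch issues a few calls per update instead of one call per parameter or buffer. A minimal standalone sketch of the equivalence follows; it is not part of the commit, and the tensor lists and decay value are made up for illustration.

import torch

decay = 0.9999
ema_params    = [torch.zeros(3), torch.zeros(5)]   # stand-ins for EMA parameters
online_params = [torch.ones(3),  torch.ones(5)]    # stand-ins for online parameters

# one fused in-place call updates every tensor in the list:
# ema <- decay * ema + (1 - decay) * online
torch._foreach_lerp_(ema_params, online_params, 1. - decay)

# the non-foreach branch reaches the same result with one call per tensor:
#   for ema_t, online_t in zip(ema_params, online_params):
#       ema_t.lerp_(online_t, 1. - decay)

# torch._foreach_copy_(tgt_list, src_list) is the analogous batched form of Tensor.copy_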
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'ema-pytorch',
packages = find_packages(exclude=[]),
- version = '0.4.8',
+ version = '0.5.0',
license='MIT',
description = 'Easy way to keep track of exponential moving average version of your pytorch module',
author = 'Phil Wang',
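For reference, a minimal sketch of enabling the new flag from user code. The nn.Linear model, the tensor shapes, and the training-loop framing are placeholders; only the EMA constructor arguments and the update and call methods come from the library.

import torch
from torch import nn
from ema_pytorch import EMA

net = nn.Linear(512, 512)   # placeholder online model

ema = EMA(
    net,
    beta = 0.9999,
    update_after_step = 100,
    update_every = 10,
    use_foreach = True      # new flag from this commit; requires a torch build with _foreach_lerp_ / _foreach_copy_
)

# call after each optimizer step in the training loop
ema.update()

# run inference with the averaged weights
out = ema(torch.randn(1, 512))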