only update float types
0.0.7
lucidrains committed Jun 22, 2022
1 parent b0d4346 commit 784f744
Showing 2 changed files with 21 additions and 6 deletions.
25 changes: 20 additions & 5 deletions ema_pytorch/ema_pytorch.py
@@ -5,6 +5,9 @@
 def exists(val):
     return val is not None
 
+def is_float_dtype(dtype):
+    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
 def clamp(value, min_value = None, max_value = None):
     assert exists(min_value) or exists(max_value)
     if exists(min_value):
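
Note: the new `is_float_dtype` helper simply whitelists the four floating-point dtypes. A minimal sanity check of what it accepts and rejects (a sketch; in recent PyTorch the built-in `dtype.is_floating_point` property should be equivalent):

    import torch

    def is_float_dtype(dtype):
        # same whitelist as the helper added above
        return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

    assert is_float_dtype(torch.float32) and is_float_dtype(torch.bfloat16)
    assert not is_float_dtype(torch.int64)  # e.g. embedding indices, counters
    assert not is_float_dtype(torch.bool)   # e.g. masks registered as buffers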
@@ -75,11 +78,17 @@ def restore_ema_model_device(self):
         self.ema_model.to(device)
 
     def copy_params_from_model_to_ema(self):
-        for ma_param, current_param in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
-            ma_param.data.copy_(current_param.data)
+        for ma_params, current_params in zip(list(self.ema_model.parameters()), list(self.online_model.parameters())):
+            if not is_float_dtype(current_params.dtype):
+                continue
+
+            ma_params.data.copy_(current_params.data)
 
-        for ma_buffer, current_buffer in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
-            ma_buffer.data.copy_(current_buffer.data)
+        for ma_buffers, current_buffers in zip(list(self.ema_model.buffers()), list(self.online_model.buffers())):
+            if not is_float_dtype(current_buffers.dtype):
+                continue
+
+            ma_buffers.data.copy_(current_buffers.data)
 
     def get_current_decay(self):
         epoch = clamp(self.step.item() - self.update_after_step - 1, min_value = 0)
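
Note: the dtype guard matters most for buffers. A `BatchNorm` layer, for instance, registers an integer `num_batches_tracked` buffer alongside its float running statistics, and that is exactly the kind of tensor the new check skips. A small sketch showing which buffers pass:

    import torch
    from torch import nn

    float_dtypes = (torch.float64, torch.float32, torch.float16, torch.bfloat16)

    bn = nn.BatchNorm1d(8)
    for name, buf in bn.named_buffers():
        print(name, buf.dtype, buf.dtype in float_dtypes)
    # running_mean         torch.float32  True
    # running_var          torch.float32  True
    # num_batches_tracked  torch.int64    False  <- skipped by the new check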
@@ -112,15 +121,21 @@ def update_moving_average(self, ma_model, current_model):
         current_decay = self.get_current_decay()
 
         for (name, current_params), (_, ma_params) in zip(list(current_model.named_parameters()), list(ma_model.named_parameters())):
+            if not is_float_dtype(current_params.dtype):
+                continue
+
             if name in self.param_or_buffer_names_no_ema:
-                ma_param.data.copy_(current_param.data)
+                ma_params.data.copy_(current_params.data)
                 continue
 
             difference = ma_params.data - current_params.data
             difference.mul_(1.0 - current_decay)
             ma_params.sub_(difference)
 
         for (name, current_buffer), (_, ma_buffer) in zip(list(current_model.named_buffers()), list(ma_model.named_buffers())):
+            if not is_float_dtype(current_buffer.dtype):
+                continue
+
             if name in self.param_or_buffer_names_no_ema:
                 ma_buffer.data.copy_(current_buffer.data)
                 continue
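
Note: besides adding the dtype guard, the replaced line in the parameter loop fixes a latent NameError: the old `ma_param` / `current_param` names were never bound in this loop, which iterates over `ma_params` / `current_params`. The unchanged three-line EMA body is the standard update written as an in-place subtraction, ma <- ma - (1 - decay) * (ma - current), algebraically the same as the textbook ma <- decay * ma + (1 - decay) * current. A quick numeric check of that equivalence:

    import torch

    decay = 0.99
    ma = torch.tensor([1.0, 2.0])
    current = torch.tensor([3.0, 4.0])

    # the commit's form: ma -= (1 - decay) * (ma - current)
    updated = ma - (1.0 - decay) * (ma - current)

    # textbook form: decay * ma + (1 - decay) * current
    expected = decay * ma + (1.0 - decay) * current
    assert torch.allclose(updated, expected)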
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.7',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',
