Skip to content

Commit

Permalink
Fix: pass `.dtype` to `is_float_dtype` and unpack the `(name, tensor)` pairs when copying params/buffers from the online model to the EMA model
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 29, 2022
1 parent ba85be0 commit 03e75ca
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions ema_pytorch/ema_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def __init__(

self.ema_model.requires_grad_(False)

self.parameter_names = {name for name, param in self.ema_model.named_parameters() if is_float_dtype(param)}
self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if is_float_dtype(buffer)}
self.parameter_names = {name for name, param in self.ema_model.named_parameters() if is_float_dtype(param.dtype)}
self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if is_float_dtype(buffer.dtype)}

self.update_every = update_every
self.update_after_step = update_after_step
Expand Down Expand Up @@ -101,10 +101,10 @@ def get_buffers_iter(self, model):
yield name, buffer

def copy_params_from_model_to_ema(self):
    """Hard-copy all tracked parameters and buffers from the online model
    into the EMA model (a direct overwrite — no decay/averaging applied).

    ``get_params_iter`` / ``get_buffers_iter`` yield ``(name, tensor)``
    pairs, so each pair must be unpacked before touching ``.data``;
    iterating the pairs directly and calling ``.data`` on the tuple
    raises ``AttributeError``.
    """
    # Both models share the same architecture, so zipping the two
    # (name, tensor) streams lines them up one-to-one. The names are
    # ignored here — only the tensors are copied.
    for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.online_model)):
        ma_params.data.copy_(current_params.data)

    for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.online_model)):
        ma_buffers.data.copy_(current_buffers.data)

def get_current_decay(self):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'ema-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.0',
version = '0.1.1',
license='MIT',
description = 'Easy way to keep track of exponential moving average version of your pytorch module',
author = 'Phil Wang',
Expand Down

0 comments on commit 03e75ca

Please sign in to comment.