diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py
index 8ff96ec..d0eb2ec 100644
--- a/ema_pytorch/ema_pytorch.py
+++ b/ema_pytorch/ema_pytorch.py
@@ -66,8 +66,8 @@ def __init__(
 
         self.ema_model.requires_grad_(False)
 
-        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if is_float_dtype(param)}
-        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if is_float_dtype(buffer)}
+        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if is_float_dtype(param.dtype)}
+        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if is_float_dtype(buffer.dtype)}
 
         self.update_every = update_every
         self.update_after_step = update_after_step
@@ -101,10 +101,10 @@ def get_buffers_iter(self, model):
             yield name, buffer
 
     def copy_params_from_model_to_ema(self):
-        for ma_params, current_params in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.online_model)):
+        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), self.get_params_iter(self.online_model)):
             ma_params.data.copy_(current_params.data)
 
-        for ma_buffers, current_buffers in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.online_model)):
+        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), self.get_buffers_iter(self.online_model)):
             ma_buffers.data.copy_(current_buffers.data)
 
     def get_current_decay(self):
diff --git a/setup.py b/setup.py
index cbb4eee..ebcc7a1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'ema-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'Easy way to keep track of exponential moving average version of your pytorch module',
   author = 'Phil Wang',