From 847b4b8083241c3ecb38d86ba4381c8c4b112af2 Mon Sep 17 00:00:00 2001 From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com> Date: Tue, 9 Apr 2024 17:25:06 -0700 Subject: [PATCH] [brief] Few fixes. [detailed] - When looking up a key in the registry, if the key is not found it should raise an exception even if the suffix is none. - Fixes a typo in EMA. --- src/pyro/core/utils.py | 4 ++++ src/pyro/model/networks.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pyro/core/utils.py b/src/pyro/core/utils.py index 02c8e9b..9526963 100644 --- a/src/pyro/core/utils.py +++ b/src/pyro/core/utils.py @@ -295,6 +295,10 @@ def get(self, name: str, suffix: str | None = None) -> typing.Any: raise KeyError( f"No object called '{name}' found in the '{self._name}' registrar" ) + elif ret is None: + raise KeyError( + f"No object called '{name}' found in the '{self._name}' registrar" + ) return ret def __contains__(self, name: str) -> bool: diff --git a/src/pyro/model/networks.py b/src/pyro/model/networks.py index 02fefb4..770d186 100644 --- a/src/pyro/model/networks.py +++ b/src/pyro/model/networks.py @@ -88,7 +88,7 @@ def _update(self, net: nn.Module, update_fn: typing.Callable) -> None: net.state_dict().values(), strict=True, ): - if self.device: + if self._device: net_v = net_v.to(device=self.device) ema_v.copy_(update_fn(ema_v, net_v))