diff --git a/src/pyro/core/utils.py b/src/pyro/core/utils.py index 02c8e9b..9526963 100644 --- a/src/pyro/core/utils.py +++ b/src/pyro/core/utils.py @@ -295,6 +295,10 @@ def get(self, name: str, suffix: str | None = None) -> typing.Any: raise KeyError( f"No object called '{name}' found in the '{self._name}' registrar" ) + elif ret is None: + raise KeyError( + f"No object called '{name}' found in the '{self._name}' registrar" + ) return ret def __contains__(self, name: str) -> bool: diff --git a/src/pyro/model/networks.py b/src/pyro/model/networks.py index 02fefb4..770d186 100644 --- a/src/pyro/model/networks.py +++ b/src/pyro/model/networks.py @@ -88,7 +88,7 @@ def _update(self, net: nn.Module, update_fn: typing.Callable) -> None: net.state_dict().values(), strict=True, ): - if self.device: + if self._device: net_v = net_v.to(device=self.device) ema_v.copy_(update_fn(ema_v, net_v))