diff --git a/ema_pytorch/ema_pytorch.py b/ema_pytorch/ema_pytorch.py index 7138770..35d68be 100644 --- a/ema_pytorch/ema_pytorch.py +++ b/ema_pytorch/ema_pytorch.py @@ -101,7 +101,7 @@ def __init__( # handle callable returning ema module - if callable(ema_model): + if not isinstance(ema_model, Module) and callable(ema_model): ema_model = ema_model() # ema model diff --git a/ema_pytorch/post_hoc_ema.py b/ema_pytorch/post_hoc_ema.py index 09b6d19..41e4d2d 100644 --- a/ema_pytorch/post_hoc_ema.py +++ b/ema_pytorch/post_hoc_ema.py @@ -77,7 +77,7 @@ def __init__( # handle callable returning ema module - if callable(ema_model): + if not isinstance(ema_model, Module) and callable(ema_model): ema_model = ema_model() # ema model diff --git a/setup.py b/setup.py index eb7832d..ddfed7f 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'ema-pytorch', packages = find_packages(exclude=[]), - version = '0.7.4', + version = '0.7.5', license='MIT', description = 'Easy way to keep track of exponential moving average version of your pytorch module', author = 'Phil Wang',