diff --git a/cities/modeling/model_interactions.py b/cities/modeling/model_interactions.py index 586e2d7a..3fc1712b 100644 --- a/cities/modeling/model_interactions.py +++ b/cities/modeling/model_interactions.py @@ -3,12 +3,17 @@ from typing import Optional import dill + import torch import pyro import pyro.distributions as dist -from cities.modeling.modeling_utils import (prep_wide_data_for_inference, - train_interactions_model) +import torch + +from cities.modeling.modeling_utils import ( + prep_wide_data_for_inference, + train_interactions_model, +) from cities.utils.data_grabber import DataGrabber, find_repo_root