diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py
index a28b73c..eb51f2a 100644
--- a/ptychonn/_infer/__main__.py
+++ b/ptychonn/_infer/__main__.py
@@ -202,24 +202,25 @@ def __init__(
         model: typing.Optional[torch.nn.Module] = None,
         model_params_path: typing.Optional[pathlib.Path] = None,
     ):
-        self.device = torch.device(
-            "cuda" if torch.cuda.is_available() else "cpu")
-        print(f"Let's use {torch.cuda.device_count()} GPUs!")
+        self.model = ptychonn.model.ReconSmallPhaseModel(
+        ) if model is None else model
 
-        if model is None or model_params_path is None:
-            self.model = ptychonn.model.ReconSmallPhaseModel()
+        if model_params_path is None:
             with importlib.resources.path(
                     'ptychonn._infer',
                     'weights.pth',
             ) as model_params_path:
-                self.model.load_state_dict(
-                    torch.load(model_params_path, map_location=self.device))
+                self.model.load_state_dict(torch.load(model_params_path))
         else:
-            self.model = model
-            self.model.load_state_dict(
-                torch.load(model_params_path, map_location=self.device))
+            self.model.load_state_dict(torch.load(model_params_path))
+
+        self.device = torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Let's use {torch.cuda.device_count()} GPUs!")
+
+        if torch.cuda.device_count() > 1:
+            self.model = torch.nn.DataParallel(self.model)
 
-        self.model = torch.nn.DataParallel(self.model)  #Default all devices
         self.model = self.model.to(self.device)
 
     def setTestData(self, X_test: np.ndarray, batch_size: int):
diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py
index e4e1792..0a56583 100644
--- a/ptychonn/_train/__main__.py
+++ b/ptychonn/_train/__main__.py
@@ -279,13 +279,10 @@ def initModel(self, model_params_path: pathlib.Path | None = None):
 
         self.device = torch.device(
             "cuda" if torch.cuda.is_available() else "cpu")
+        print(f"Let's use {torch.cuda.device_count()} GPUs!")
+
         if torch.cuda.device_count() > 1:
-            logger.info("Let's use %d GPUs!", torch.cuda.device_count())
-            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
-            self.model = torch.nn.DataParallel(
-                self.model,
-                device_ids=None,  # Default all devices
-            )
+            self.model = torch.nn.DataParallel(self.model)
 
         self.model = self.model.to(self.device)
 
@@ -414,7 +411,13 @@ def updateSavedModel(
         fname = directory / f'best_model{ suffix }.pth'
         logger.info("Saving best model as %s", fname)
         os.makedirs(directory, exist_ok=True)
-        torch.save(model.state_dict(), fname)
+        if isinstance(model, (
+                torch.nn.DataParallel,
+                torch.nn.parallel.DistributedDataParallel,
+        )):
+            torch.save(model.module.state_dict(), fname)
+        else:
+            torch.save(model.state_dict(), fname)
 
     def getSavedModelPath(self) -> pathlib.Path | None:
         """Return the path where `validate` will save the model weights"""