diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py
index 1d23b95..8127495 100644
--- a/ptychonn/_infer/__main__.py
+++ b/ptychonn/_infer/__main__.py
@@ -211,7 +211,7 @@ def infer(
     inferences : (POSITION, 2, WIDTH, HEIGHT)
         The reconstructed patches inferred by the model.
     '''
-    model.eval()
+    model = model.eval().to("cuda")
     result = []
     with torch.no_grad():
         for batch in data:
diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py
index c0e0b70..92d884d 100644
--- a/ptychonn/_train/__main__.py
+++ b/ptychonn/_train/__main__.py
@@ -153,6 +153,7 @@ def train(
     epochs: int = 1,
     batch_size: int = 32,
     training_fraction: float = 0.8,
+    log_frequency: int = 50,
 ) -> typing.Tuple[lightning.Trainer, lightning.pytorch.loggers.CSVLogger
                   | ListLogger]:
     """Train a PtychoNN model.
@@ -181,7 +182,17 @@ def train(
         The size of one training batch.
     training_fraction
         The proprotion of X_train and Y_train that is used for training.
+    log_frequency
+        Write to the logs every this many steps.
     """
+    if batch_size <= 0:
+        msg = f"batch_size must be positive, not {batch_size}"
+        raise ValueError(msg)
+    if epochs <= 0:
+        msg = f"Number of epochs must be positive, not {epochs}"
+        raise ValueError(msg)
+    # X_train, Y_train and training_fraction are checked in create_training_dataloader
+
     if out_dir is not None:
         checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
             dirpath=out_dir,
@@ -207,6 +218,7 @@ def train(
         callbacks=None if out_dir is None else [checkpoint_callback],
         logger=logger,
         enable_checkpointing=False if out_dir is None else True,
+        log_every_n_steps=log_frequency,
     )

     train_dataloader, val_dataloader = create_training_dataloader(
@@ -259,13 +271,13 @@ def create_training_dataloader(
     assert X_train.dtype == np.float32
     assert np.all(np.isfinite(X_train))
-    if X_train.ndim != 3:
+    if X_train.ndim != 3 or X_train.shape[0] < 1:
         msg = (
-            "X_train must have 3 dimemnsions: (N, WIDTH, HEIGHT); "
+            "X_train must have 3 dimensions: (N, WIDTH, HEIGHT); "
            f" not {X_train.shape}"
         )
         raise ValueError(msg)
-    if Y_train.ndim != 4:
+    if Y_train.ndim != 4 or Y_train.shape[0] < 1 or Y_train.shape[1] not in [1, 2]:
         msg = (
             f"Y_train must have 4 dimensions: (N, [1,2], WIDTH, HEIGHT); "
             f"not {Y_train.shape}"
         )
         raise ValueError(msg)
@@ -292,6 +304,15 @@
         [training_fraction, 1.0 - training_fraction],
     )

+    if len(training) // batch_size <= 0:
+        msg = ("The training dataset is smaller than one batch. "
+               "Adjust the batch_size so there is training data.")
+        raise ValueError(msg)
+    if len(validation) // batch_size <= 0:
+        msg = ("The validation dataset is smaller than one batch. "
+               "Adjust the batch_size so there is validation data.")
+        raise ValueError(msg)
+
     trainingloader = torch.utils.data.DataLoader(
         training,
         batch_size=batch_size,
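
Note on the infer() change: the model is now moved to "cuda" unconditionally before evaluation, so inference will fail on a CPU-only host. A minimal guard a caller might add; this sketch is not part of the diff, and the error message is only illustrative:

    import torch

    # infer() now calls model.eval().to("cuda"), so check for a GPU up front.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "ptychonn infer() moves the model to 'cuda'; "
            "a CUDA-capable GPU and a CUDA build of PyTorch are required."
        )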
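
Note on log_frequency: the new parameter is forwarded to the Lightning Trainer as log_every_n_steps, so the logger receives a row every log_frequency training steps. A minimal sketch of the equivalent Trainer configuration; the value 10 is only an example, and the other Trainer arguments shown in the diff are omitted:

    import lightning

    # train(..., log_frequency=10) now configures logging roughly like this:
    trainer = lightning.pytorch.Trainer(log_every_n_steps=10)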
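
Note on the new input validation in train() and create_training_dataloader(): X_train must be a finite float32 array of shape (N, WIDTH, HEIGHT) with N >= 1; Y_train must have shape (N, 1, WIDTH, HEIGHT) or (N, 2, WIDTH, HEIGHT); batch_size and epochs must be positive; and after the training/validation split each subset must still contain at least one full batch. A minimal sketch of inputs that satisfy those checks; the array sizes and the float32 dtype for Y_train are assumptions for illustration, and the remaining train() parameters are unchanged by this diff and omitted here:

    import numpy as np

    N, W, H = 256, 128, 128
    X_train = np.zeros((N, W, H), dtype=np.float32)     # 3-D, finite, float32
    Y_train = np.zeros((N, 2, W, H), dtype=np.float32)  # channel axis must be 1 or 2

    # With training_fraction=0.8 and batch_size=32 the split is roughly 205/51
    # samples, so both subsets hold at least one full batch and the new
    # len(...) // batch_size checks pass; batch_size=0 or epochs=0 now raises
    # ValueError before any training starts.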