From dda1a3bfc250e4a7a700c5a2f2cea934ce9f5aa6 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 21 Dec 2023 19:02:07 -0600 Subject: [PATCH 01/12] REF: Replace Tester/Trainer with Pytorch Lightning --- ptychonn/__init__.py | 4 +- ptychonn/_infer/__main__.py | 93 ++------ ptychonn/_train/__main__.py | 432 ++---------------------------------- ptychonn/model.py | 53 ++++- requirements-dev | 2 +- setup.cfg | 2 +- 6 files changed, 86 insertions(+), 500 deletions(-) diff --git a/ptychonn/__init__.py b/ptychonn/__init__.py index 9476043..ff8b0f7 100644 --- a/ptychonn/__init__.py +++ b/ptychonn/__init__.py @@ -6,7 +6,7 @@ # package is not installed pass -from ptychonn._infer.__main__ import infer, stitch_from_inference, Tester -from ptychonn._train.__main__ import Trainer +from ptychonn._infer.__main__ import infer, stitch_from_inference +from ptychonn._train.__main__ import train from ptychonn.model import * from ptychonn.plot import * diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py index c9200e6..4b4bc53 100644 --- a/ptychonn/_infer/__main__.py +++ b/ptychonn/_infer/__main__.py @@ -10,6 +10,7 @@ import scipy.interpolate import torch import tqdm +import lightning import ptychonn.model @@ -202,82 +203,18 @@ def infer( inferences : (POSITION, 2, WIDTH, HEIGHT) The reconstructed patches inferred by the model. ''' - tester = Tester( - model=ptychonn.model.ReconSmallModel(), - model_params_path=model_params_path, + model = ptychonn.model.LitReconSmallModel.load_from_checkpoint( + model_params_path, ) - tester.setTestData( - data, - batch_size=max(torch.cuda.device_count(), 1) * 64, - ) - return tester.predictTestData(npz_save_path=inferences_out_file) - - -class Tester(): - - def __init__( - self, - *, - model: torch.nn.Module, - model_params_path: pathlib.Path, - ): - self.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu") - print(f"Let's use {torch.cuda.device_count()} GPUs!") - - self.model = model - - params = torch.load( - model_params_path, - map_location=self.device, - ) - self.model.load_state_dict(params) - - self.model = torch.nn.DataParallel(self.model) - - self.model.to(self.device) - - self.model.eval() - - def setTestData(self, X_test: np.ndarray, batch_size: int): - self.X_test = torch.tensor(X_test[:, None, ...], dtype=torch.float32) - self.test_data = TensorDataset(self.X_test) - self.testloader = DataLoader( - self.test_data, - batch_size=batch_size, - shuffle=False, - ) - - def predictTestData(self, npz_save_path: str = None): - - phs_eval = [] - with torch.inference_mode(): - for (ft_images, ) in self.testloader: - ph_eval = self.model(ft_images.to(self.device)) - phs_eval.append(ph_eval.detach().cpu().numpy()) - - self.phs_eval = np.concatenate(phs_eval, axis=0) - - if npz_save_path is not None: - np.savez_compressed(npz_save_path, ph=self.phs_eval) - print(f'Finished the inference stage and saved at {npz_save_path}') - - return self.phs_eval - - def calcErrors(self, phs_true: np.ndarray, npz_save_path: str = None): - from skimage.metrics import mean_squared_error as mse - - self.phs_true = phs_true - self.errors = [] - for i, (p1, p2) in enumerate(zip(self.phs_eval, self.phs_true)): - err2 = mse(p1, p2) - self.errors.append([err2]) - - self.errors = np.array(self.errors) - print("Mean errors in phase") - print(np.mean(self.errors, axis=0)) - - if npz_save_path is not None: - np.savez_compressed(npz_save_path, phs_err=self.errors[:, 0]) - - return self.errors + model.eval() + result = list() + with torch.no_grad(): + for batch in 
data: + result.append( + model(torch.from_numpy(batch[None, None, :, :]).to("cuda")) + .cpu() + .numpy() + ) + + result = np.concatenate(result, axis=0) + return result diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py index 7138a49..a521852 100644 --- a/ptychonn/_train/__main__.py +++ b/ptychonn/_train/__main__.py @@ -116,420 +116,34 @@ def train( assert np.all(np.isfinite(Y_train)) assert X_train.dtype == np.float32 assert np.all(np.isfinite(X_train)) - logger.info("Creating the training model...") + assert X_train.ndim == 4 + assert Y_train.ndim == 4 - trainer = Trainer( - model=ptychonn.model.ReconSmallModel(), - batch_size=batch_size * torch.cuda.device_count(), - output_path=out_dir, + traindata = torch.utils.data.TensorDataset( + torch.from_numpy(X_train), + torch.from_numpy(Y_train), ) - trainer.setTrainingData( - X_train, - Y_train, - valid_data_ratio=0.1, - ) - trainer.setOptimizationParams( - epochs_per_half_cycle=6, - max_lr=1e-3, - min_lr=1e-4, - ) - trainer.initModel(model_params_path=load_model_path) - trainer.run(epochs) - - if out_dir is not None: - trainer.plotLearningRate( - save_fname=out_dir / 'learning_rate.png', - show_fig=False, - ) - ptychonn.plot.plot_metrics( - trainer.metrics, - save_fname=out_dir / 'metrics.png', - show_fig=False, - ) - - return trainer - - -class Trainer(): - """A object that manages training PtychoNN - - Artifacts - --------- - - When `output_path` is not None, the following artifacts are written to disk. - - ``` - `output_path` - reference - 00000.tiff - 00001.tiff - ... - inference - 00000.tiff - 00001.tiff - ... - metrics`output_suffix`.npz - best_model`output_suffix`.pth - ``` - - """ - - def __init__( - self, - model: ptychonn.model.ReconSmallModel, - batch_size: int, - output_path: pathlib.Path | None = None, - output_suffix: str = '', - ): - logger.info("Initializing the training procedure...") - self.model = model - self.batch_size = batch_size - self.output_path = output_path - self.output_suffix = output_suffix - self.epoch = 0 - - def setTrainingData( - self, - X_train_full: np.ndarray, - Y_ph_train_full: np.ndarray, - valid_data_ratio: float = 0.1, - ): - """ - - Parameters - ---------- - X_train_full : (N, H, W) - The measured intensities at the detector - Y_ph_train_full : (N, C, H, W) - The phase and amplitude patches from the reconstructed object. - Phase in the zeroth channel and amplitude (optionally) in the first - channel - - """ - if (Y_ph_train_full.ndim != 4): - msg = ("Training data example patches must have a channel " - "dimension! i.e. 
the shape should be (N, C, H, W)") - raise ValueError(msg) - logger.info("Setting training data...") - - self.H, self.W = X_train_full.shape[-2:] - - self.X_train_full = torch.tensor( - X_train_full[:, None, ...], - dtype=torch.float32, - ) - self.Y_ph_train_full = torch.tensor( - Y_ph_train_full, - dtype=torch.float32, - ) - self.ntrain_full = self.X_train_full.shape[0] - - self.valid_data_ratio = valid_data_ratio - self.nvalid = int(self.ntrain_full * self.valid_data_ratio) - self.ntrain = self.ntrain_full - self.nvalid - - self.train_data_full = torch.utils.data.TensorDataset( - self.X_train_full, - self.Y_ph_train_full, - ) - - self.train_data, self.valid_data = torch.utils.data.random_split( - self.train_data_full, - [self.ntrain, self.nvalid], - ) - self.trainloader = torch.utils.data.DataLoader( - self.train_data, - batch_size=self.batch_size, - shuffle=True, - num_workers=4, - drop_last=False, - ) - - self.validloader = torch.utils.data.DataLoader( - self.valid_data, - batch_size=self.batch_size, - shuffle=False, - num_workers=4, - drop_last=False, - ) - - self.iters_per_epoch = self.ntrain // self.batch_size + ( - self.ntrain % self.batch_size > 0) - - def setOptimizationParams( - self, - epochs_per_half_cycle: int = 6, - max_lr: float = 5e-4, - min_lr: float = 1e-4, - ): - logger.info("Setting optimization parameters...") - - # TODO: Move this note about iterations into the documentation string - # after figuring out what it means. Paper recommends 2-10 number of - # iterations - self.epochs_per_half_cycle = epochs_per_half_cycle - self.iters_per_half_cycle = epochs_per_half_cycle * self.iters_per_epoch - - logger.info( - "LR step size is: %d which is every %d epochs", - self.iters_per_half_cycle, - self.iters_per_half_cycle / self.iters_per_epoch, - ) - - self.max_lr = max_lr - self.min_lr = min_lr - - self.criterion = self.customLoss - self.optimizer = torch.optim.Adam( - self.model.parameters(), - lr=self.max_lr, - ) - self.scheduler = torch.optim.lr_scheduler.CyclicLR( - self.optimizer, - max_lr=self.max_lr, - base_lr=self.min_lr, - step_size_up=self.iters_per_half_cycle, - cycle_momentum=False, - mode='triangular2', - ) - - def initModel( - self, - model_params_path: pathlib.Path | None = None, - ): - """Load parameters from the disk then model to the GPU(s).""" - - self.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu") - print(f"Let's use {torch.cuda.device_count()} GPUs!") - - torchinfo.summary(self.model, (1, 1, self.H, self.W), device="cpu") - - self.model_params_path = model_params_path - - if model_params_path is not None: - self.model.load_state_dict( - torch.load( - self.model_params_path, - map_location=self.device, - )) - - self.model = torch.nn.DataParallel(self.model) - - self.model = self.model.to(self.device) - self.scaler = torch.cuda.amp.GradScaler() - - logger.info("Setting up metrics...") - self.metrics = { - 'losses': [], - 'val_losses': [], - 'lrs': [], - 'best_val_loss': np.inf - } - logger.info(self.metrics) - - def train(self): - tot_loss = 0.0 - loss_ph = 0.0 - - for (ft_images, phs) in self.trainloader: - - # Move everything to device - ft_images = ft_images.to(self.device) - phs = phs.to(self.device) - - # Divide cumulative loss by number of batches-- slightly inaccurate - # because last batch is different size - pred_phs = self.model(ft_images) - loss_p = self.criterion(pred_phs, phs, self.ntrain) - # Monitor phase loss but only within support (which may not be same - # as true amp) - loss = loss_p - # Use equiweighted amps 
and phase - - # Zero current grads and do backprop - self.optimizer.zero_grad() - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - - tot_loss += loss.detach().item() - - loss_ph += loss_p.detach().item() - - # Update the LR according to the schedule -- CyclicLR updates each - # batch - self.scheduler.step() - self.metrics['lrs'].append(self.scheduler.get_last_lr()) - self.scaler.update() - - self.metrics['losses'].append([tot_loss, loss_ph]) - - def validate(self, epoch: int): - tot_val_loss = 0.0 - val_loss_ph = 0.0 - for (ft_images, phs) in self.validloader: - ft_images = ft_images.to(self.device) - phs = phs.to(self.device) - pred_phs = self.model(ft_images) - - val_loss_p = self.criterion(pred_phs, phs, self.nvalid) - val_loss = val_loss_p - - tot_val_loss += val_loss.detach().item() - val_loss_ph += val_loss_p.detach().item() - - self.metrics['val_losses'].append([tot_val_loss, val_loss_ph]) - - if self.output_path is not None: - self.saveMetrics( - self.metrics, - self.output_path, - self.output_suffix, - ) - - # Update saved model if val loss is lower - if (tot_val_loss < self.metrics['best_val_loss']): - logger.info( - "Saving improved model after Val Loss improved from %.5f to %.5f", - self.metrics['best_val_loss'], - tot_val_loss, - ) - self.metrics['best_val_loss'] = tot_val_loss - - if self.output_path is not None: - self.updateSavedModel( - self.model, - self.output_path, - self.output_suffix, - ) - - import matplotlib.pyplot as plt - os.makedirs(self.output_path / 'reference', exist_ok=True) - os.makedirs(self.output_path / 'inference', exist_ok=True) - plt.imsave(self.output_path / f'reference/{epoch:05d}.png', - phs[0, 0].detach().cpu().numpy().astype(np.float32)) - plt.imsave( - self.output_path / f'inference/{epoch:05d}.png', - pred_phs[0, 0].detach().cpu().numpy().astype(np.float32)) - - @staticmethod - def customLoss( - input: torch.tensor, - target: torch.tensor, - scaling: float, - ): - """A loss function which scales according to training set size.""" - assert torch.all(torch.isfinite(input)) - assert torch.all(torch.isfinite(target)) - return torch.sum(torch.mean( - torch.abs(input - target), - axis=(-1, -2), - )) / scaling - - # TODO: Use a callback instead of a static method for saving the model? - - @staticmethod - def updateSavedModel( - model: ptychonn.model.ReconSmallModel, - directory: pathlib.Path, - suffix: str = '', - ): - """Writes `model` parameters to `directory`/best_model`suffix`.pth - - The directory is created if it does not exist. - """ - fname = directory / f'best_model{ suffix }.pth' - logger.info("Saving best model as %s", fname) - os.makedirs(directory, exist_ok=True) - if isinstance(model, ( - torch.nn.DataParallel, - torch.nn.parallel.DistributedDataParallel, - )): - torch.save(model.module.state_dict(), fname) - else: - torch.save(model.state_dict(), fname) - - def getSavedModelPath(self) -> pathlib.Path | None: - """Return the path where `validate` will save the model weights""" - if self.output_path is None: - return None - return self.output_path / f'best_model{ self.output_suffix }.pth' - - @staticmethod - def saveMetrics( - metrics: dict, - directory: pathlib.Path, - suffix: str = '', - ): - """Writes `metrics` to `directory`/metrics`suffix`.npz - - The directory is created if it does not exist. 
- """ - os.makedirs(directory, exist_ok=True) - np.savez(directory / f'metrics{suffix}.npz', **metrics) - - def run(self, epochs: int, output_frequency: int = 1): - """The main training loop""" - - for epoch in range(epochs): - - #Set model to train mode - self.model.train() - - #Training loop - self.train() - - #Switch model to eval mode - self.model.eval() - - #Validation loop - with torch.inference_mode(): - self.validate(epoch) + trainloader = torch.utils.data.DataLoader( + traindata, + batch_size=batch_size, + shuffle=True, + drop_last=True, + ) - if epoch % output_frequency == 0: - logger.info( - 'Epoch: %d | FT | Train Loss: %1.03e | Val Loss: %1.03e', - epoch, - self.metrics['losses'][-1][0], - self.metrics['val_losses'][-1][0], - ) - logger.info( - 'Epoch: %d | Ph | Train Loss: %1.03e | Val Loss: %1.03e', - epoch, - self.metrics['losses'][-1][1], - self.metrics['val_losses'][-1][1], - ) - logger.info( - 'Epoch: %d | Ending LR: %1.03e', - epoch, - self.metrics['lrs'][-1][0], - ) + trainer = lightning.Trainer( + max_epochs=epochs, + default_root_dir=out_dir, + ) - def plotLearningRate( - self, - save_fname: pathlib.Path | None = None, - show_fig: bool = True, - ): - batches = np.linspace( - 0, - len(self.metrics['lrs']), - len(self.metrics['lrs']) + 1, - ) - epoch_list = batches / self.iters_per_epoch + model = ptychonn.model.LitReconSmallModel() - import matplotlib.pyplot as plt + if load_model_path is not None: + model = ptychonn.model.LitReconSmallModel.load_from_checkpoint(load_model_path) - f = plt.figure() - plt.plot(epoch_list[1:], self.metrics['lrs'], 'C3-') - plt.grid() - plt.ylabel("Learning rate") - plt.xlabel("Epoch") - plt.tight_layout() + trainer.fit( + model=model, + train_dataloaders=trainloader, + ) - if save_fname is not None: - plt.savefig(save_fname) - if show_fig: - plt.show() - else: - plt.close(f) + return trainer diff --git a/ptychonn/model.py b/ptychonn/model.py index 0554108..55039c0 100644 --- a/ptychonn/model.py +++ b/ptychonn/model.py @@ -1,11 +1,12 @@ """Define PtychoNN Pytorch models.""" +import lightning import numpy as np import torch import torch.nn as nn -class ReconSmallModel(nn.Module): +class LitReconSmallModel(lightning.LightningModule): """A small PychoNN model. 
Parameters @@ -36,6 +37,10 @@ def __init__( self.use_batch_norm = use_batch_norm self.enable_amplitude = enable_amplitude + self.epochs_per_half_cycle: int = 6 + self.max_lr: float = 5e-4 + self.min_lr: float = 1e-4 + # Appears sequential has similar functionality as TF avoiding need for # separate model definition and activ self.encoder = nn.Sequential( @@ -118,12 +123,42 @@ def up_block(self, filters_in: int, filters_out: int, groups: int): ] def forward(self, x): - with torch.cuda.amp.autocast(): - output = self.decoder(self.encoder(x)) - # Restore -pi to pi range - # Using tanh activation (-1 to 1) for phase so multiply by pi - output[..., 0, :, :] = torch.tanh(output[..., 0, :, :]) * np.pi - # Restrict amplitude to (0, 1) range with sigmoid - if self.enable_amplitude: - output[..., 1, :, :] = torch.sigmoid(output[..., 1, :, :]) + output = self.decoder(self.encoder(x)) + # Restore -pi to pi range + # Using tanh activation (-1 to 1) for phase so multiply by pi + output[..., 0, :, :] = torch.tanh(output[..., 0, :, :]) * np.pi + # Restrict amplitude to (0, 1) range with sigmoid + if self.enable_amplitude: + output[..., 1, :, :] = torch.sigmoid(output[..., 1, :, :]) return output + + def training_step(self, batch, batch_idx): + ft_images, object_roi = batch + predicted = self.forward(ft_images) + loss = torch.nn.functional.mse_loss(predicted, object_roi) + self.log("training_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + ft_images, object_roi = batch + predicted = self.forward(ft_images) + loss = torch.nn.functional.mse_loss(predicted, object_roi) + self.log("validation_loss", loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.max_lr, + ) + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + max_lr=self.max_lr, + base_lr=self.min_lr, + cycle_momentum=False, + mode="triangular2", + ) + return dict( + optimizer=optimizer, + lr_scheduler=scheduler, + ) diff --git a/requirements-dev b/requirements-dev index 4a8ce60..46ad4f3 100644 --- a/requirements-dev +++ b/requirements-dev @@ -6,4 +6,4 @@ pytorch >=1.12,<2.1 scikit-image scipy tqdm -torchinfo +lightning >=2.1.3,<3.0 diff --git a/setup.cfg b/setup.cfg index 6183aee..5c03881 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,7 @@ install_requires = scikit-image scipy tqdm - torchinfo + lightning >=2.1.3,<3 [options.entry_points] console_scripts = From 52e53edf28164b4267c7d9ea13bbfc6927cef698 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 22 Dec 2023 10:17:08 -0600 Subject: [PATCH 02/12] REF: Bring back some API functionality --- ptychonn/_infer/__main__.py | 10 +++-- ptychonn/_train/__main__.py | 74 +++++++++++++++++++++++++------------ ptychonn/model.py | 6 ++- 3 files changed, 60 insertions(+), 30 deletions(-) diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py index 4b4bc53..417139c 100644 --- a/ptychonn/_infer/__main__.py +++ b/ptychonn/_infer/__main__.py @@ -2,7 +2,6 @@ import typing import glob -from torch.utils.data import TensorDataset, DataLoader import click import matplotlib.pyplot as plt import numpy as np @@ -10,7 +9,6 @@ import scipy.interpolate import torch import tqdm -import lightning import ptychonn.model @@ -178,7 +176,7 @@ def infer_cli( def infer( - data: npt.NDArray, + data: npt.NDArray[np.float32], model_params_path: pathlib.Path, *, inferences_out_file: typing.Optional[pathlib.Path] = None, @@ -207,7 +205,7 @@ def infer( model_params_path, ) model.eval() - result = list() + result = 
[]
     with torch.no_grad():
         for batch in data:
             result.append(
                 model(torch.from_numpy(batch[None, None, :, :]).to("cuda"))
                 .cpu()
                 .numpy()
             )
 
     result = np.concatenate(result, axis=0)
+
+    if inferences_out_file is not None:
+        np.save(inferences_out_file, result)
+
     return result
diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py
index a521852..a4fe614 100644
--- a/ptychonn/_train/__main__.py
+++ b/ptychonn/_train/__main__.py
@@ -4,10 +4,10 @@
 import pathlib
 
 import click
+import lightning
 import numpy as np
 import numpy.typing as npt
 import torch
-import torchinfo
 
 import ptychonn.model
 import ptychonn.plot
@@ -92,8 +92,8 @@ def train_cli(
 
 
 def train(
-    X_train: npt.NDArray[float],
-    Y_train: npt.NDArray[float],
+    X_train: npt.NDArray[np.float32],
+    Y_train: npt.NDArray[np.float32],
     out_dir: pathlib.Path | None,
     load_model_path: pathlib.Path | None = None,
     epochs: int = 1,
@@ -112,38 +112,64 @@ def train(
     load_model_path
         Load a previous model's parameters from this file.
     """
-    assert Y_train.dtype == np.float32
-    assert np.all(np.isfinite(Y_train))
-    assert X_train.dtype == np.float32
-    assert np.all(np.isfinite(X_train))
-    assert X_train.ndim == 4
-    assert Y_train.ndim == 4
-
-    traindata = torch.utils.data.TensorDataset(
-        torch.from_numpy(X_train),
-        torch.from_numpy(Y_train),
-    )
-
-    trainloader = torch.utils.data.DataLoader(
-        traindata,
-        batch_size=batch_size,
-        shuffle=True,
-        drop_last=True,
-    )
 
     trainer = lightning.Trainer(
         max_epochs=epochs,
         default_root_dir=out_dir,
     )
 
-    model = ptychonn.model.LitReconSmallModel()
-
     if load_model_path is not None:
         model = ptychonn.model.LitReconSmallModel.load_from_checkpoint(load_model_path)
+    else:
+        model = ptychonn.model.LitReconSmallModel()
 
     trainer.fit(
         model=model,
-        train_dataloaders=trainloader,
+        train_dataloaders=create_training_dataloader(
+            X_train,
+            Y_train,
+            batch_size,
+        ),
     )
 
     return trainer
+
+
+def create_training_dataloader(
+    X_train: npt.NDArray[np.float32],
+    Y_train: npt.NDArray[np.float32],
+    batch_size: int = 32,
+) -> torch.utils.data.DataLoader:
+    """Create a Pytorch Dataloader from numpy arrays."""
+
+    assert Y_train.dtype == np.float32
+    assert np.all(np.isfinite(Y_train))
+    assert X_train.dtype == np.float32
+    assert np.all(np.isfinite(X_train))
+
+    if X_train.ndim != 3:
+        msg = (
+            "X_train must have 3 dimensions: (N, WIDTH, HEIGHT); "
+            f"not {X_train.shape}"
+        )
+        raise ValueError(msg)
+    if Y_train.ndim != 4:
+        msg = (
+            f"Y_train must have 4 dimensions: (N, [1,2], WIDTH, HEIGHT); "
+            f"not {Y_train.shape}"
+        )
+        raise ValueError(msg)
+
+    dataset = torch.utils.data.TensorDataset(
+        torch.from_numpy(X_train[:, None, :, :]),
+        torch.from_numpy(Y_train),
+    )
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        drop_last=True,
+    )
+
+    return dataloader
diff --git a/ptychonn/model.py b/ptychonn/model.py
index 55039c0..e6cbbb8 100644
--- a/ptychonn/model.py
+++ b/ptychonn/model.py
@@ -31,6 +31,8 @@ def __init__(
         nconv: int = 16,
         use_batch_norm: bool = True,
         enable_amplitude: bool = True,
+        min_lr: float = 1e-4,
+        max_lr: float = 5e-4,
     ):
         super().__init__()
         self.nconv = nconv
@@ -38,8 +40,8 @@ def __init__(
         self.enable_amplitude = enable_amplitude
 
         self.epochs_per_half_cycle: int = 6
-        self.max_lr: float = 5e-4
-        self.min_lr: float = 1e-4
+        self.max_lr: float = max_lr
+        self.min_lr: float = min_lr
 
         # Appears sequential has similar functionality as TF avoiding need for
         # separate model definition and activ
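
At this point in the series, training collapses to a single function call. A minimal usage sketch against the patch-02 signature; the `.npy` file names are hypothetical placeholders, and both arrays must be finite float32, shaped (N, WIDTH, HEIGHT) and (N, [1,2], WIDTH, HEIGHT) as the checks above require:

```python
import pathlib

import numpy as np
import ptychonn

# Hypothetical input files: diffraction patterns and reconstructed patches.
X = np.load("diffraction.npy").astype(np.float32)  # (N, WIDTH, HEIGHT)
Y = np.load("patches.npy").astype(np.float32)      # (N, 2, WIDTH, HEIGHT)

trainer = ptychonn.train(
    X_train=X,
    Y_train=Y,
    out_dir=pathlib.Path("out"),  # becomes Lightning's default_root_dir
    epochs=10,
    batch_size=32,
)
```
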
From a92c7ded652f237b3245636e041efaf09a7bf0f3 Mon Sep 17 00:00:00 2001
From: Daniel Ching
Date: Fri, 22 Dec 2023 12:42:45 -0600
Subject: [PATCH 03/12] REF: Remove external deps from API

---
 ptychonn/_train/__main__.py | 67 ++++++++++++++++++++++++++++++++-----
 1 file changed, 58 insertions(+), 9 deletions(-)

diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py
index a4fe614..be0de6a 100644
--- a/ptychonn/_train/__main__.py
+++ b/ptychonn/_train/__main__.py
@@ -2,6 +2,7 @@
 import logging
 import os
 import pathlib
+import typing
 
 import click
 import lightning
@@ -85,6 +86,9 @@ def train_cli(
     train(
         X_train=data,
         Y_train=patches,
+        model=init_or_load_model(
+            ptychonn.LitReconSmallModel,
+        ),
         out_dir=out_dir,
         epochs=epochs,
         batch_size=32,
     )


 def train(
     X_train: npt.NDArray[np.float32],
     Y_train: npt.NDArray[np.float32],
+    model: lightning.LightningModule,
     out_dir: pathlib.Path | None,
-    load_model_path: pathlib.Path | None = None,
     epochs: int = 1,
     batch_size: int = 32,
 ):
     """Train a PtychoNN model.
 
+    Initialize a model for the model parameter using the `init_or_load_model()`
+    function.
+
+    If out_dir is not None the following artifacts will be created:
+    - {out_dir}/best_model.ckpt
+    - {out_dir}/metrics.csv
+    - {out_dir}/hparams.yaml
+
     Parameters
     ----------
     X_train (N, WIDTH, HEIGHT)
         The diffraction patterns.
     Y_train (N, 2, WIDTH, HEIGHT)
         The corresponding reconstructed patches for the diffraction patterns.
     out_dir
         A folder where all the training artifacts are saved.
-    load_model_path
-        Load a previous model's parameters from this file.
+    model
+        An initialized PtychoNN model.
+    epochs
+        The maximum number of training epochs
+    batch_size
+        The size of one training batch.
     """
+    checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
+        dirpath=out_dir,
+        filename="best_model",
+        save_top_k=1,
+        monitor="training_loss",
+        mode="min",
+    )
+
+    logger = lightning.pytorch.loggers.CSVLogger(
+        save_dir=out_dir,
+        name="",
+        version="",
+        prefix="",
+    )
 
     trainer = lightning.Trainer(
         max_epochs=epochs,
         default_root_dir=out_dir,
+        callbacks=[checkpoint_callback],
+        logger=logger,
     )
 
     trainer.fit(
         model=model,
         train_dataloaders=create_training_dataloader(
@@ -140,7 +168,7 @@ def create_training_dataloader(
     Y_train: npt.NDArray[np.float32],
     batch_size: int = 32,
 ) -> torch.utils.data.DataLoader:
-    """Create a Pytorch Dataloader from numpy arrays."""
+    """Create a Pytorch Dataloader from NumPy arrays."""
 
     assert Y_train.dtype == np.float32
     assert np.all(np.isfinite(Y_train))
@@ -173,3 +201,24 @@ def create_training_dataloader(
     )
 
     return dataloader
+
+
+def init_or_load_model(
+    model_type: typing.Type[lightning.LightningModule],
+    model_checkpoint_path: pathlib.Path | None,
+    model_init_params: dict | None,
+):
+    """Initialize one of the PtychoNN models via params or a checkpoint."""
+    if not (model_checkpoint_path is None or model_init_params is None):
+        msg = (
+            "One of model_checkpoint_path OR model_init_params must be None! "
+            "Both cannot be defined."
+        )
+        raise ValueError(msg)
+
+    if model_checkpoint_path is not None:
+        return model_type.load_from_checkpoint(
+            model_checkpoint_path
+        )
+    else:
+        return model_type(**model_init_params)
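
The contract here is exclusive-or: exactly one of model_checkpoint_path and model_init_params should be non-None. Two edges worth noting: the guard rejects both-defined but not both-None (which would fail later at model_type(**None)), and the train_cli call in this patch passes neither keyword, which the next patch supplies. A sketch of the two intended call patterns; the checkpoint path is a placeholder, and the helper is only exported from the package top level in the next patch:

```python
import pathlib

import ptychonn
from ptychonn._train.__main__ import init_or_load_model

# Fresh model from constructor arguments.
model = init_or_load_model(
    ptychonn.LitReconSmallModel,
    model_checkpoint_path=None,
    model_init_params=dict(nconv=16, enable_amplitude=True),
)

# Or resume from a Lightning checkpoint written by the ModelCheckpoint above.
model = init_or_load_model(
    ptychonn.LitReconSmallModel,
    model_checkpoint_path=pathlib.Path("out/best_model.ckpt"),
    model_init_params=None,
)
```
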
From 243870e47d5d73ca31bd1f360e7eecae65d0f955 Mon Sep 17 00:00:00 2001
From: Daniel Ching
Date: Fri, 22 Dec 2023 14:06:16 -0600
Subject: [PATCH 04/12] REF: Move model init into separate function

---
 ptychonn/__init__.py        |  2 +-
 ptychonn/_infer/__main__.py | 17 ++++++++++++-----
 ptychonn/_train/__main__.py |  7 +++++--
 3 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/ptychonn/__init__.py b/ptychonn/__init__.py
index ff8b0f7..1895be5 100644
--- a/ptychonn/__init__.py
+++ b/ptychonn/__init__.py
@@ -7,6 +7,6 @@
     pass
 
 from ptychonn._infer.__main__ import infer, stitch_from_inference
-from ptychonn._train.__main__ import train
+from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader
 from ptychonn.model import *
 from ptychonn.plot import *
diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py
index 417139c..6d3a624 100644
--- a/ptychonn/_infer/__main__.py
+++ b/ptychonn/_infer/__main__.py
@@ -3,6 +3,7 @@
 import glob
 
 import click
+import lightning
 import matplotlib.pyplot as plt
 import numpy as np
 import numpy.typing as npt
@@ -127,7 +128,11 @@ def infer_cli(
 
     inferences = infer(
         data=data,
-        model_params_path=params_path,
+        model=ptychonn.init_or_load_model(
+            ptychonn.model.LitReconSmallModel,
+            model_checkpoint_path=params_path,
+            model_init_params=None,
+        )
     )
 
     pstitched = stitch_from_inference(
@@ -177,7 +182,7 @@ def infer_cli(
 
 def infer(
     data: npt.NDArray[np.float32],
-    model_params_path: pathlib.Path,
+    model: lightning.LightningModule,
     *,
     inferences_out_file: typing.Optional[pathlib.Path] = None,
 ) -> npt.NDArray:
@@ -189,10 +194,15 @@ def infer(
     Set the CUDA_VISIBLE_DEVICES environment variable to control which GPUs
     will be used.
 
+    Initialize a model for the model parameter using the `init_or_load_model()`
+    function.
+
     Parameters
     ----------
     data : (POSITION, WIDTH, HEIGHT)
         Diffraction patterns to be reconstructed.
+    model
+        An initialized PtychoNN model.
     inferences_out_file : pathlib.Path
         Optional file to save reconstructed patches.
 
@@ -201,9 +211,6 @@ def infer(
     inferences : (POSITION, 2, WIDTH, HEIGHT)
         The reconstructed patches inferred by the model.
     '''
-    model = ptychonn.model.LitReconSmallModel.load_from_checkpoint(
-        model_params_path,
-    )
     model.eval()
     result = []
     with torch.no_grad():
diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py
index be0de6a..2618402 100644
--- a/ptychonn/_train/__main__.py
+++ b/ptychonn/_train/__main__.py
@@ -88,6 +88,8 @@ def train_cli(
         Y_train=patches,
         model=init_or_load_model(
             ptychonn.LitReconSmallModel,
+            model_checkpoint_path=None,
+            model_init_params=dict(),
         ),
         out_dir=out_dir,
         epochs=epochs,
@@ -115,9 +117,9 @@ def train(
 
     Parameters
     ----------
-    X_train (N, WIDTH, HEIGHT)
+    X_train : (N, WIDTH, HEIGHT)
         The diffraction patterns.
-    Y_train (N, 2, WIDTH, HEIGHT)
+    Y_train : (N, 2, WIDTH, HEIGHT)
         The corresponding reconstructed patches for the diffraction patterns.
     out_dir
         A folder where all the training artifacts are saved.
@@ -205,6 +207,7 @@ def create_training_dataloader( def init_or_load_model( model_type: typing.Type[lightning.LightningModule], + *, model_checkpoint_path: pathlib.Path | None, model_init_params: dict | None, ): From 19a9a6c3062aa405fe928b8e84c8be963121957d Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 22 Dec 2023 14:33:08 -0600 Subject: [PATCH 05/12] Plot as soon as done stitching --- ptychonn/_infer/__main__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py index 6d3a624..1d23b95 100644 --- a/ptychonn/_infer/__main__.py +++ b/ptychonn/_infer/__main__.py @@ -135,20 +135,14 @@ def infer_cli( ) ) + # Plotting some summary images + pstitched = stitch_from_inference( inferences[:, 0], scan, stitched_pixel_width=1, inference_pixel_width=1, ) - astitched = stitch_from_inference( - inferences[:, 1], - scan, - stitched_pixel_width=1, - inference_pixel_width=1, - ) - - # Plotting some summary images plt.figure(1, figsize=[8.5, 7]) plt.imshow(pstitched) plt.colorbar() @@ -156,6 +150,12 @@ def infer_cli( plt.title('stitched_phases') plt.savefig(out_dir / 'pstitched.png', bbox_inches='tight') + astitched = stitch_from_inference( + inferences[:, 1], + scan, + stitched_pixel_width=1, + inference_pixel_width=1, + ) plt.figure(2, figsize=[8.5, 7]) plt.imshow(astitched) plt.colorbar() From 8dceda1f2dc7dcee4f3b9ab67e44df446420a296 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 22 Dec 2023 16:22:43 -0600 Subject: [PATCH 06/12] REF: Reimplement loss plots --- ptychonn/_train/__main__.py | 12 ++++++++++++ ptychonn/model.py | 30 ++++++++++++++++++++++++++-- ptychonn/plot.py | 39 ++++++++++++++++++++----------------- 3 files changed, 61 insertions(+), 20 deletions(-) diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py index 2618402..bda1e06 100644 --- a/ptychonn/_train/__main__.py +++ b/ptychonn/_train/__main__.py @@ -162,6 +162,18 @@ def train( ), ) + with open(out_dir / "metrics.csv") as f: + headers = f.readline().strip('\n').split(",") + numbers = np.genfromtxt(out_dir / "metrics.csv", delimiter=",", skip_header=1,) + metrics = dict() + for col, header in enumerate(headers): + metrics[header] = numbers[:, col] + + ptychonn.plot.plot_metrics( + metrics=metrics, + save_fname=out_dir / "metrics.png", + ) + return trainer diff --git a/ptychonn/model.py b/ptychonn/model.py index e6cbbb8..53cbae9 100644 --- a/ptychonn/model.py +++ b/ptychonn/model.py @@ -137,14 +137,40 @@ def forward(self, x): def training_step(self, batch, batch_idx): ft_images, object_roi = batch predicted = self.forward(ft_images) - loss = torch.nn.functional.mse_loss(predicted, object_roi) + loss_phase = torch.nn.functional.mse_loss( + predicted[..., 0, :, :], + object_roi[..., 0, :, :], + ) + self.log("training_loss_phase", loss_phase) + if self.enable_amplitude: + loss_amp = torch.nn.functional.mse_loss( + predicted[..., 1, :, :], + object_roi[..., 1, :, :], + ) + loss = loss_phase + loss_amp + self.log("training_loss_amplitude", loss_amp) + else: + loss = loss_phase self.log("training_loss", loss) return loss def validation_step(self, batch, batch_idx): ft_images, object_roi = batch predicted = self.forward(ft_images) - loss = torch.nn.functional.mse_loss(predicted, object_roi) + loss_phase = torch.nn.functional.mse_loss( + predicted[..., 0, :, :], + object_roi[..., 0, :, :], + ) + self.log("validation_loss_phase", loss_phase) + if self.enable_amplitude: + loss_amp = torch.nn.functional.mse_loss( + 
predicted[..., 1, :, :],
+                object_roi[..., 1, :, :],
+            )
+            loss = loss_phase + loss_amp
+            self.log("validation_loss_amplitude", loss_amp)
+        else:
+            loss = loss_phase
         self.log("validation_loss", loss)
         return loss
 
diff --git a/ptychonn/plot.py b/ptychonn/plot.py
index a54b930..f56c5d6 100644
--- a/ptychonn/plot.py
+++ b/ptychonn/plot.py
@@ -6,36 +6,39 @@ def plot_metrics(metrics: dict,
                  save_fname: str = None,
                  show_fig: bool = False):
-    losses_arr = np.array(metrics['losses'])
-    val_losses_arr = np.array(metrics['val_losses'])
-    print("Shape of losses array is ", losses_arr.shape)
-    fig, ax = plt.subplots(3, sharex=True, figsize=(15, 8))
-    ax[0].plot(losses_arr[1:, 0], 'C3o-', label="Train")
-    ax[0].plot(val_losses_arr[1:, 0], 'C0o-', label="Val")
+    fig, ax = plt.subplots(3, sharex=True, figsize=(16, 9))
+    if 'training_loss' in metrics:
+        ax[0].plot(metrics['step'], metrics['training_loss'], 'C3o-', label="Train")
+    if 'validation_loss' in metrics:
+        ax[0].plot(metrics['step'], metrics['validation_loss'], 'C0o-', label="Val")
     ax[0].set(ylabel='Loss')
     ax[0].set_yscale('log')
     ax[0].grid()
-    ax[0].legend(loc='center right')
+    ax[0].legend()
     ax[0].set_title('Total loss')
 
-    ax[1].plot(losses_arr[1:, 1], 'C3o-', label="Train Amp loss")
-    ax[1].plot(val_losses_arr[1:, 1], 'C0o-', label="Val Amp loss")
+    if 'training_loss_amplitude' in metrics:
+        ax[1].plot(metrics['step'], metrics['training_loss_amplitude'], 'C3o-', label="Train Amp loss")
+    if 'validation_loss_amplitude' in metrics:
+        ax[1].plot(metrics['step'], metrics['validation_loss_amplitude'], 'C0o-', label="Val Amp loss")
     ax[1].set(ylabel='Loss')
     ax[1].set_yscale('log')
     ax[1].grid()
-    ax[1].legend(loc='center right', bbox_to_anchor=(1.5, 0.5))
+    ax[1].legend()
-    ax[1].set_title('Phase loss')
+    ax[1].set_title('Amplitude loss')
 
-    # ax[2].plot(losses_arr[1:, 2], 'C3o-', label="Train Ph loss")
-    # ax[2].plot(val_losses_arr[1:, 2], 'C0o-', label="Val Ph loss")
-    # ax[2].set(ylabel='Loss')
-    # ax[2].grid()
-    # ax[2].legend(loc='center right', bbox_to_anchor=(1.5, 0.5))
-    # ax[2].set_yscale('log')
-    # ax[2].set_title('Mag los')
+    if 'training_loss_phase' in metrics:
+        ax[2].plot(metrics['step'], metrics['training_loss_phase'], 'C3o-', label="Train Ph loss")
+    if 'validation_loss_phase' in metrics:
+        ax[2].plot(metrics['step'], metrics['validation_loss_phase'], 'C0o-', label="Val Ph loss")
+    ax[2].set(ylabel='Loss')
+    ax[2].grid()
+    ax[2].legend()
+    ax[2].set_yscale('log')
+    ax[2].set_title('Phase loss')
 
     plt.tight_layout()
-    plt.xlabel("Epochs")
+    plt.xlabel("Steps")
 
     if save_fname is not None:
         plt.savefig(save_fname)
From 63c9b05614d2e09e0fc0c67214bd8f2f167b5a7d Mon Sep 17 00:00:00 2001
From: Daniel Ching
Date: Fri, 22 Dec 2023 16:23:10 -0600
Subject: [PATCH 07/12] Reimplement lr scheduler step rates

---
 ptychonn/model.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/ptychonn/model.py b/ptychonn/model.py
index 53cbae9..f9f2737 100644
--- a/ptychonn/model.py
+++ b/ptychonn/model.py
@@ -179,10 +179,13 @@ def configure_optimizers(self):
             self.parameters(),
             lr=self.max_lr,
         )
+        iters_per_epoch = self.trainer.estimated_stepping_batches / self.trainer.max_epochs
+        iters_per_half_cycle = self.epochs_per_half_cycle * iters_per_epoch
         scheduler = torch.optim.lr_scheduler.CyclicLR(
             optimizer,
             max_lr=self.max_lr,
             base_lr=self.min_lr,
+            step_size_up=iters_per_half_cycle,
             cycle_momentum=False,
             mode="triangular2",
         )
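
A caveat on this scheduler wiring: when configure_optimizers returns a bare scheduler under the lr_scheduler key, Lightning steps it once per epoch by default, while step_size_up above is measured in optimizer steps (batches). If per-batch cycling, as in the removed Trainer.run loop, is the intent, the dict form makes the interval explicit. A sketch of that variant, assuming Lightning 2.x semantics:

```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.max_lr)
    iters_per_epoch = (
        self.trainer.estimated_stepping_batches / self.trainer.max_epochs
    )
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer,
        max_lr=self.max_lr,
        base_lr=self.min_lr,
        # Round to an integer number of optimizer steps per half cycle.
        step_size_up=int(self.epochs_per_half_cycle * iters_per_epoch),
        cycle_momentum=False,
        mode="triangular2",
    )
    return {
        "optimizer": optimizer,
        # interval="step" steps the scheduler every batch, matching the
        # units of step_size_up; Lightning's default interval is "epoch".
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```
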
From 434de05c5945a8cbb42d1f0e655320f923dbcdaf Mon Sep 17 00:00:00 2001
From: Daniel Ching
Date: Tue, 2 Jan 2024 16:28:02 -0600
Subject: [PATCH 08/12] BUG: Allow training out_dir to be None

---
 ptychonn/_train/__main__.py | 57 ++++++++++++++++++++-----------------
 1 file changed, 31 insertions(+), 26 deletions(-)

diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py
index bda1e06..35611a0 100644
--- a/ptychonn/_train/__main__.py
+++ b/ptychonn/_train/__main__.py
@@ -114,6 +114,7 @@ def train(
     - {out_dir}/best_model.ckpt
     - {out_dir}/metrics.csv
     - {out_dir}/hparams.yaml
+    - {out_dir}/metrics.png
 
     Parameters
     ----------
@@ -130,27 +131,29 @@ def train(
     batch_size
         The size of one training batch.
     """
+    if out_dir is not None:
+
+        checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
+            dirpath=out_dir,
+            filename="best_model",
+            save_top_k=1,
+            monitor="training_loss",
+            mode="min",
+        )
-    checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
-        dirpath=out_dir,
-        filename="best_model",
-        save_top_k=1,
-        monitor="training_loss",
-        mode="min",
-    )
-
-    logger = lightning.pytorch.loggers.CSVLogger(
-        save_dir=out_dir,
-        name="",
-        version="",
-        prefix="",
-    )
+        logger = lightning.pytorch.loggers.CSVLogger(
+            save_dir=out_dir,
+            name="",
+            version="",
+            prefix="",
+        )
 
     trainer = lightning.Trainer(
         max_epochs=epochs,
         default_root_dir=out_dir,
-        callbacks=[checkpoint_callback],
-        logger=logger,
+        callbacks=None if out_dir is None else [checkpoint_callback],
+        logger=False if out_dir is None else logger,
+        enable_checkpointing=False if out_dir is None else True,
    )
 
     trainer.fit(
@@ -162,17 +165,19 @@ def train(
         ),
     )
 
-    with open(out_dir / "metrics.csv") as f:
-        headers = f.readline().strip('\n').split(",")
-    numbers = np.genfromtxt(out_dir / "metrics.csv", delimiter=",", skip_header=1,)
-    metrics = dict()
-    for col, header in enumerate(headers):
-        metrics[header] = numbers[:, col]
+    if out_dir is not None:
+
+        with open(out_dir / "metrics.csv") as f:
+            headers = f.readline().strip('\n').split(",")
+        numbers = np.genfromtxt(out_dir / "metrics.csv", delimiter=",", skip_header=1,)
+        metrics = dict()
+        for col, header in enumerate(headers):
+            metrics[header] = numbers[:, col]
+
+        ptychonn.plot.plot_metrics(
+            metrics=metrics,
+            save_fname=out_dir / "metrics.png",
+        )
-    ptychonn.plot.plot_metrics(
-        metrics=metrics,
-        save_fname=out_dir / "metrics.png",
-    )
 
     return trainer
From 0eca71616c23585c9dbb73a5ab3c292e004681fb Mon Sep 17 00:00:00 2001
From: Daniel Ching
Date: Wed, 3 Jan 2024 14:34:18 -0600
Subject: [PATCH 09/12] NEW: Implement an in-memory logger

---
 ptychonn/__init__.py        |  2 +-
 ptychonn/_train/__main__.py | 70 +++++++++++++++++++++++++++++++------
 2 files changed, 61 insertions(+), 11 deletions(-)

diff --git a/ptychonn/__init__.py b/ptychonn/__init__.py
index 1895be5..7d858ca 100644
--- a/ptychonn/__init__.py
+++ b/ptychonn/__init__.py
@@ -7,6 +7,6 @@
     pass
 
 from ptychonn._infer.__main__ import infer, stitch_from_inference
-from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader
+from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader, ListLogger
 from ptychonn.model import *
 from ptychonn.plot import *
diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py
index 35611a0..c82c754 100644
--- a/ptychonn/_train/__main__.py
+++ b/ptychonn/_train/__main__.py
@@ -1,3 +1,4 @@
+import argparse
 import glob
 import logging
 import os
@@ -97,6 +98,52 @@ def train_cli(
     )
 
 
+class ListLogger(lightning.pytorch.loggers.logger.Logger):
+    """An in-memory logger that saves logged parameters to a List
+
+    Parameters
+    ----------
+    logs :
+        Each entry of this list is a dictionary with 
parameter name value + pairs. Each entry of the list represents the parameters during a single + step. + hyperparameters : + Some hyperparameters that were logged? + + """ + def __init__(self): + super().__init__() + self.logs: typing.List[typing.Dict] = [] + self.hyperparameters: argparse.Namespace = argparse.Namespace() + + @lightning.pytorch.utilities.rank_zero_only + def log_metrics(self, metrics, step=None): + metrics["step"] = step + self.logs.append(metrics) + + @lightning.pytorch.utilities.rank_zero_only + def log_hyperparams(self, params): + self.hyperparameters = params + + @lightning.pytorch.utilities.rank_zero_only + def save(self): + # No need to save anything for this logger + pass + + @lightning.pytorch.utilities.rank_zero_only + def finalize(self, status): + # Finalize the logger + pass + + @property + def name(self): + return "ListLogger" + + @property + def version(self): + return "0.1.0" + + def train( X_train: npt.NDArray[np.float32], Y_train: npt.NDArray[np.float32], @@ -104,7 +151,7 @@ def train( out_dir: pathlib.Path | None, epochs: int = 1, batch_size: int = 32, -): +) -> typing.Tuple[lightning.Trainer, lightning.pytorch.loggers.CSVLogger | ListLogger]: """Train a PtychoNN model. Initialize a model for the model parameter using the `init_or_load_model()` @@ -132,7 +179,6 @@ def train( The size of one training batch. """ if out_dir is not None: - checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint( dirpath=out_dir, filename="best_model", @@ -148,11 +194,14 @@ def train( prefix="", ) + else: + logger = ListLogger() + trainer = lightning.Trainer( max_epochs=epochs, default_root_dir=out_dir, callbacks=None if out_dir is None else [checkpoint_callback], - logger=False if out_dir is None else logger, + logger=logger, enable_checkpointing=False if out_dir is None else True, ) @@ -166,10 +215,13 @@ def train( ) if out_dir is not None: - with open(out_dir / "metrics.csv") as f: - headers = f.readline().strip('\n').split(",") - numbers = np.genfromtxt(out_dir / "metrics.csv", delimiter=",", skip_header=1,) + headers = f.readline().strip("\n").split(",") + numbers = np.genfromtxt( + out_dir / "metrics.csv", + delimiter=",", + skip_header=1, + ) metrics = dict() for col, header in enumerate(headers): metrics[header] = numbers[:, col] @@ -179,7 +231,7 @@ def train( save_fname=out_dir / "metrics.png", ) - return trainer + return trainer, logger def create_training_dataloader( @@ -237,8 +289,6 @@ def init_or_load_model( raise ValueError(msg) if model_checkpoint_path is not None: - return model_type.load_from_checkpoint( - model_checkpoint_path - ) + return model_type.load_from_checkpoint(model_checkpoint_path) else: return model_type(**model_init_params) From 0aeef4e3540fe88692560caefc7bf27c71f972ec Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 3 Jan 2024 18:34:17 -0600 Subject: [PATCH 10/12] NEW: Add a function to save a model checkpoint --- ptychonn/__init__.py | 2 +- ptychonn/_train/__main__.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/ptychonn/__init__.py b/ptychonn/__init__.py index 7d858ca..c3dc963 100644 --- a/ptychonn/__init__.py +++ b/ptychonn/__init__.py @@ -7,6 +7,6 @@ pass from ptychonn._infer.__main__ import infer, stitch_from_inference -from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader, ListLogger +from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader, ListLogger, create_model_checkpoint from ptychonn.model import * from ptychonn.plot 
import * diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py index c82c754..6cfbcc7 100644 --- a/ptychonn/_train/__main__.py +++ b/ptychonn/_train/__main__.py @@ -292,3 +292,11 @@ def init_or_load_model( return model_type.load_from_checkpoint(model_checkpoint_path) else: return model_type(**model_init_params) + +def create_model_checkpoint( + trainer: lightning.Trainer, + model_checkpoint_path: pathlib.Path, +): + trainer.save_checkpoint( + model_checkpoint_path, + ) From 4e68a7748ec471a878a25a9e18bba7dbcd3d21be Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 16 Jan 2024 14:34:11 -0600 Subject: [PATCH 11/12] REF: Create both test and validate dataloaders --- ptychonn/_train/__main__.py | 50 ++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 9 deletions(-) diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py index 6cfbcc7..ccf5708 100644 --- a/ptychonn/_train/__main__.py +++ b/ptychonn/_train/__main__.py @@ -106,11 +106,12 @@ class ListLogger(lightning.pytorch.loggers.logger.Logger): logs : Each entry of this list is a dictionary with parameter name value pairs. Each entry of the list represents the parameters during a single - step. + step. Not every parameter is logged for each step. hyperparameters : Some hyperparameters that were logged? """ + def __init__(self): super().__init__() self.logs: typing.List[typing.Dict] = [] @@ -205,13 +206,16 @@ def train( enable_checkpointing=False if out_dir is None else True, ) + train_dataloader, val_dataloader = create_training_dataloader( + X_train, + Y_train, + batch_size, + ) + trainer.fit( model=model, - train_dataloaders=create_training_dataloader( - X_train, - Y_train, - batch_size, - ), + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, ) if out_dir is not None: @@ -238,9 +242,14 @@ def create_training_dataloader( X_train: npt.NDArray[np.float32], Y_train: npt.NDArray[np.float32], batch_size: int = 32, -) -> torch.utils.data.DataLoader: + training_fraction: float = 0.8, +) -> typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader | None]: """Create a Pytorch Dataloader from NumPy arrays.""" + if training_fraction > 1.0 or training_fraction <= 0.0: + msg = f"training_fraction must be >0,<=1, not {training_fraction}!" 
+ raise ValueError(msg) + assert Y_train.dtype == np.float32 assert np.all(np.isfinite(Y_train)) assert X_train.dtype == np.float32 @@ -264,14 +273,36 @@ def create_training_dataloader( torch.from_numpy(Y_train), ) - dataloader = torch.utils.data.DataLoader( + if training_fraction == 1.0: + trainingloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + drop_last=True, + ) + + return trainingloader, None + + training, validation = torch.utils.data.random_split( dataset, + [training_fraction, 1.0 - training_fraction], + ) + + trainingloader = torch.utils.data.DataLoader( + training, batch_size=batch_size, shuffle=True, drop_last=True, ) - return dataloader + validationloader = torch.utils.data.DataLoader( + validation, + batch_size=batch_size, + shuffle=False, + drop_last=True, + ) + + return trainingloader, validationloader def init_or_load_model( @@ -293,6 +324,7 @@ def init_or_load_model( else: return model_type(**model_init_params) + def create_model_checkpoint( trainer: lightning.Trainer, model_checkpoint_path: pathlib.Path, From 051f4dcbc84555b2c658ce50bad33f88b7410b73 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 16 Jan 2024 15:07:22 -0600 Subject: [PATCH 12/12] BUG: Filter nans out of metrics plot --- ptychonn/plot.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/ptychonn/plot.py b/ptychonn/plot.py index f56c5d6..9fa380f 100644 --- a/ptychonn/plot.py +++ b/ptychonn/plot.py @@ -2,15 +2,19 @@ import numpy as np +def _filter_nan(x, y): + mask = np.logical_and(np.isfinite(x), np.isfinite(y)) + return x[mask], y[mask] + def plot_metrics(metrics: dict, save_fname: str = None, show_fig: bool = False): fig, ax = plt.subplots(3, sharex=True, figsize=(16, 9)) if 'training_loss' in metrics: - ax[0].plot(metrics['step'], metrics['training_loss'], 'C3o-', label="Train") + ax[0].plot(*_filter_nan(metrics['step'], metrics['training_loss']), 'C3o-', label="Train") if 'validation_loss' in metrics: - ax[0].plot(metrics['step'], metrics['validation_loss'], 'C0o-', label="Val") + ax[0].plot(*_filter_nan(metrics['step'], metrics['validation_loss']), 'C0o-', label="Val") ax[0].set(ylabel='Loss') ax[0].set_yscale('log') ax[0].grid() @@ -18,9 +22,9 @@ def plot_metrics(metrics: dict, ax[0].set_title('Total loss') if 'training_loss_amplitude' in metrics: - ax[1].plot(metrics['step'], metrics['training_loss_amplitude'], 'C3o-', label="Train Amp loss") + ax[1].plot(*_filter_nan(metrics['step'], metrics['training_loss_amplitude']), 'C3o-', label="Train Amp loss") if 'validation_loss_amplitude' in metrics: - ax[1].plot(metrics['step'], metrics['validation_loss_amplitude'], 'C0o-', label="Val Amp loss") + ax[1].plot(*_filter_nan(metrics['step'], metrics['validation_loss_amplitude']), 'C0o-', label="Val Amp loss") ax[1].set(ylabel='Loss') ax[1].set_yscale('log') ax[1].grid() @@ -28,9 +32,9 @@ def plot_metrics(metrics: dict, ax[1].set_title('Phase loss') if 'training_loss_phase' in metrics: - ax[2].plot(metrics['step'], metrics['training_loss_phase'], 'C3o-', label="Train Ph loss") + ax[2].plot(*_filter_nan(metrics['step'], metrics['training_loss_phase']), 'C3o-', label="Train Ph loss") if 'validation_loss_phase' in metrics: - ax[2].plot(metrics['step'], metrics['validation_loss_phase'], 'C0o-', label="Val Ph loss") + ax[2].plot(*_filter_nan(metrics['step'], metrics['validation_loss_phase']), 'C0o-', label="Val Ph loss") ax[2].set(ylabel='Loss') ax[2].grid() ax[2].legend()
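
Taken together, the series replaces the old Tester/Trainer classes with three entry points: init_or_load_model(), train(), and infer(). An end-to-end sketch of the resulting API; the file and directory names are placeholders, and since infer() as written moves batches to "cuda", inference requires a CUDA device:

```python
import pathlib

import numpy as np
import ptychonn

out_dir = pathlib.Path("out")

# Train from scratch; best_model.ckpt, metrics.csv, hparams.yaml, and
# metrics.png are written to out_dir because it is not None.
trainer, logger = ptychonn.train(
    X_train=np.load("diffraction.npy").astype(np.float32),  # (N, W, H)
    Y_train=np.load("patches.npy").astype(np.float32),      # (N, 2, W, H)
    model=ptychonn.init_or_load_model(
        ptychonn.LitReconSmallModel,
        model_checkpoint_path=None,
        model_init_params=dict(),
    ),
    out_dir=out_dir,
    epochs=10,
    batch_size=32,
)

# Reload the best checkpoint and reconstruct patches from new data; move the
# weights onto the CUDA device that infer() sends its batches to.
model = ptychonn.init_or_load_model(
    ptychonn.LitReconSmallModel,
    model_checkpoint_path=out_dir / "best_model.ckpt",
    model_init_params=None,
).to("cuda")
patches = ptychonn.infer(
    data=np.load("new_diffraction.npy").astype(np.float32),
    model=model,
)
```
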