diff --git a/ptychonn/__init__.py b/ptychonn/__init__.py index 9476043..c3dc963 100644 --- a/ptychonn/__init__.py +++ b/ptychonn/__init__.py @@ -6,7 +6,7 @@ # package is not installed pass -from ptychonn._infer.__main__ import infer, stitch_from_inference, Tester -from ptychonn._train.__main__ import Trainer +from ptychonn._infer.__main__ import infer, stitch_from_inference +from ptychonn._train.__main__ import train, init_or_load_model, create_training_dataloader, ListLogger, create_model_checkpoint from ptychonn.model import * from ptychonn.plot import * diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py index c9200e6..1d23b95 100644 --- a/ptychonn/_infer/__main__.py +++ b/ptychonn/_infer/__main__.py @@ -2,8 +2,8 @@ import typing import glob -from torch.utils.data import TensorDataset, DataLoader import click +import lightning import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt @@ -128,23 +128,21 @@ def infer_cli( inferences = infer( data=data, - model_params_path=params_path, + model=ptychonn.init_or_load_model( + ptychonn.model.LitReconSmallModel, + model_checkpoint_path=params_path, + model_init_params=None, + ) ) + # Plotting some summary images + pstitched = stitch_from_inference( inferences[:, 0], scan, stitched_pixel_width=1, inference_pixel_width=1, ) - astitched = stitch_from_inference( - inferences[:, 1], - scan, - stitched_pixel_width=1, - inference_pixel_width=1, - ) - - # Plotting some summary images plt.figure(1, figsize=[8.5, 7]) plt.imshow(pstitched) plt.colorbar() @@ -152,6 +150,12 @@ def infer_cli( plt.title('stitched_phases') plt.savefig(out_dir / 'pstitched.png', bbox_inches='tight') + astitched = stitch_from_inference( + inferences[:, 1], + scan, + stitched_pixel_width=1, + inference_pixel_width=1, + ) plt.figure(2, figsize=[8.5, 7]) plt.imshow(astitched) plt.colorbar() @@ -177,8 +181,8 @@ def infer_cli( def infer( - data: npt.NDArray, - model_params_path: pathlib.Path, + data: npt.NDArray[np.float32], + model: lightning.LightningModule, *, inferences_out_file: typing.Optional[pathlib.Path] = None, ) -> npt.NDArray: @@ -190,10 +194,15 @@ def infer( Set the CUDA_VISIBLE_DEVICES environment variable to control which GPUs will be used. + Initialize a model for the model parameter using the `init_or_load_model()` + function. + Parameters ---------- data : (POSITION, WIDTH, HEIGHT) Diffraction patterns to be reconstructed. + model + An initialized PtychoNN model. inferences_out_file : pathlib.Path Optional file to save reconstructed patches. @@ -202,82 +211,19 @@ def infer( inferences : (POSITION, 2, WIDTH, HEIGHT) The reconstructed patches inferred by the model. ''' - tester = Tester( - model=ptychonn.model.ReconSmallModel(), - model_params_path=model_params_path, - ) - tester.setTestData( - data, - batch_size=max(torch.cuda.device_count(), 1) * 64, - ) - return tester.predictTestData(npz_save_path=inferences_out_file) - - -class Tester(): - - def __init__( - self, - *, - model: torch.nn.Module, - model_params_path: pathlib.Path, - ): - self.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu") - print(f"Let's use {torch.cuda.device_count()} GPUs!") - - self.model = model - - params = torch.load( - model_params_path, - map_location=self.device, - ) - self.model.load_state_dict(params) - - self.model = torch.nn.DataParallel(self.model) - - self.model.to(self.device) - - self.model.eval() - - def setTestData(self, X_test: np.ndarray, batch_size: int): - self.X_test = torch.tensor(X_test[:, None, ...], dtype=torch.float32) - self.test_data = TensorDataset(self.X_test) - self.testloader = DataLoader( - self.test_data, - batch_size=batch_size, - shuffle=False, - ) - - def predictTestData(self, npz_save_path: str = None): - - phs_eval = [] - with torch.inference_mode(): - for (ft_images, ) in self.testloader: - ph_eval = self.model(ft_images.to(self.device)) - phs_eval.append(ph_eval.detach().cpu().numpy()) - - self.phs_eval = np.concatenate(phs_eval, axis=0) - - if npz_save_path is not None: - np.savez_compressed(npz_save_path, ph=self.phs_eval) - print(f'Finished the inference stage and saved at {npz_save_path}') - - return self.phs_eval - - def calcErrors(self, phs_true: np.ndarray, npz_save_path: str = None): - from skimage.metrics import mean_squared_error as mse - - self.phs_true = phs_true - self.errors = [] - for i, (p1, p2) in enumerate(zip(self.phs_eval, self.phs_true)): - err2 = mse(p1, p2) - self.errors.append([err2]) - - self.errors = np.array(self.errors) - print("Mean errors in phase") - print(np.mean(self.errors, axis=0)) - - if npz_save_path is not None: - np.savez_compressed(npz_save_path, phs_err=self.errors[:, 0]) - - return self.errors + model.eval() + result = [] + with torch.no_grad(): + for batch in data: + result.append( + model(torch.from_numpy(batch[None, None, :, :]).to("cuda")) + .cpu() + .numpy() + ) + + result = np.concatenate(result, axis=0) + + if inferences_out_file is not None: + np.save(inferences_out_file, result) + + return result diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py index 7138a49..ccf5708 100644 --- a/ptychonn/_train/__main__.py +++ b/ptychonn/_train/__main__.py @@ -1,13 +1,15 @@ +import argparse import glob import logging import os import pathlib +import typing import click +import lightning import numpy as np import numpy.typing as npt import torch -import torchinfo import ptychonn.model import ptychonn.plot @@ -85,451 +87,248 @@ def train_cli( train( X_train=data, Y_train=patches, + model=init_or_load_model( + ptychonn.LitReconSmallModel, + model_checkpoint_path=None, + model_init_params=dict(), + ), out_dir=out_dir, epochs=epochs, batch_size=32, ) +class ListLogger(lightning.pytorch.loggers.logger.Logger): + """An in-memory logger that saves logged parameters to a List + + Parameters + ---------- + logs : + Each entry of this list is a dictionary with parameter name value + pairs. Each entry of the list represents the parameters during a single + step. Not every parameter is logged for each step. + hyperparameters : + Some hyperparameters that were logged? + + """ + + def __init__(self): + super().__init__() + self.logs: typing.List[typing.Dict] = [] + self.hyperparameters: argparse.Namespace = argparse.Namespace() + + @lightning.pytorch.utilities.rank_zero_only + def log_metrics(self, metrics, step=None): + metrics["step"] = step + self.logs.append(metrics) + + @lightning.pytorch.utilities.rank_zero_only + def log_hyperparams(self, params): + self.hyperparameters = params + + @lightning.pytorch.utilities.rank_zero_only + def save(self): + # No need to save anything for this logger + pass + + @lightning.pytorch.utilities.rank_zero_only + def finalize(self, status): + # Finalize the logger + pass + + @property + def name(self): + return "ListLogger" + + @property + def version(self): + return "0.1.0" + + def train( - X_train: npt.NDArray[float], - Y_train: npt.NDArray[float], + X_train: npt.NDArray[np.float32], + Y_train: npt.NDArray[np.float32], + model: lightning.LightningModule, out_dir: pathlib.Path | None, - load_model_path: pathlib.Path | None = None, epochs: int = 1, batch_size: int = 32, -): +) -> typing.Tuple[lightning.Trainer, lightning.pytorch.loggers.CSVLogger | ListLogger]: """Train a PtychoNN model. + Initialize a model for the model parameter using the `init_or_load_model()` + function. + + If out_dir is not None the following artifacts will be created: + - {out_dir}/best_model.ckpt + - {out_dir}/metrics.csv + - {out_dir}/hparams.yaml + - {out_dir}/metrics.png + Parameters ---------- - X_train (N, WIDTH, HEIGHT) + X_train : (N, WIDTH, HEIGHT) The diffraction patterns. - Y_train (N, 2, WIDTH, HEIGHT) + Y_train : (N, 2, WIDTH, HEIGHT) The corresponding reconstructed patches for the diffraction patterns. out_dir A folder where all the training artifacts are saved. - load_model_path - Load a previous model's parameters from this file. + model + An initialized PtychoNN model. + epochs + The maximum number of training epochs + batch_size + The size of one training batch. """ - assert Y_train.dtype == np.float32 - assert np.all(np.isfinite(Y_train)) - assert X_train.dtype == np.float32 - assert np.all(np.isfinite(X_train)) - logger.info("Creating the training model...") + if out_dir is not None: + checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint( + dirpath=out_dir, + filename="best_model", + save_top_k=1, + monitor="training_loss", + mode="min", + ) + + logger = lightning.pytorch.loggers.CSVLogger( + save_dir=out_dir, + name="", + version="", + prefix="", + ) - trainer = Trainer( - model=ptychonn.model.ReconSmallModel(), - batch_size=batch_size * torch.cuda.device_count(), - output_path=out_dir, + else: + logger = ListLogger() + + trainer = lightning.Trainer( + max_epochs=epochs, + default_root_dir=out_dir, + callbacks=None if out_dir is None else [checkpoint_callback], + logger=logger, + enable_checkpointing=False if out_dir is None else True, ) - trainer.setTrainingData( + + train_dataloader, val_dataloader = create_training_dataloader( X_train, Y_train, - valid_data_ratio=0.1, + batch_size, ) - trainer.setOptimizationParams( - epochs_per_half_cycle=6, - max_lr=1e-3, - min_lr=1e-4, + + trainer.fit( + model=model, + train_dataloaders=train_dataloader, + val_dataloaders=val_dataloader, ) - trainer.initModel(model_params_path=load_model_path) - trainer.run(epochs) if out_dir is not None: - trainer.plotLearningRate( - save_fname=out_dir / 'learning_rate.png', - show_fig=False, + with open(out_dir / "metrics.csv") as f: + headers = f.readline().strip("\n").split(",") + numbers = np.genfromtxt( + out_dir / "metrics.csv", + delimiter=",", + skip_header=1, ) + metrics = dict() + for col, header in enumerate(headers): + metrics[header] = numbers[:, col] + ptychonn.plot.plot_metrics( - trainer.metrics, - save_fname=out_dir / 'metrics.png', - show_fig=False, + metrics=metrics, + save_fname=out_dir / "metrics.png", ) - return trainer - + return trainer, logger -class Trainer(): - """A object that manages training PtychoNN - Artifacts - --------- - - When `output_path` is not None, the following artifacts are written to disk. +def create_training_dataloader( + X_train: npt.NDArray[np.float32], + Y_train: npt.NDArray[np.float32], + batch_size: int = 32, + training_fraction: float = 0.8, +) -> typing.Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader | None]: + """Create a Pytorch Dataloader from NumPy arrays.""" - ``` - `output_path` - reference - 00000.tiff - 00001.tiff - ... - inference - 00000.tiff - 00001.tiff - ... - metrics`output_suffix`.npz - best_model`output_suffix`.pth - ``` + if training_fraction > 1.0 or training_fraction <= 0.0: + msg = f"training_fraction must be >0,<=1, not {training_fraction}!" + raise ValueError(msg) - """ + assert Y_train.dtype == np.float32 + assert np.all(np.isfinite(Y_train)) + assert X_train.dtype == np.float32 + assert np.all(np.isfinite(X_train)) - def __init__( - self, - model: ptychonn.model.ReconSmallModel, - batch_size: int, - output_path: pathlib.Path | None = None, - output_suffix: str = '', - ): - logger.info("Initializing the training procedure...") - self.model = model - self.batch_size = batch_size - self.output_path = output_path - self.output_suffix = output_suffix - self.epoch = 0 - - def setTrainingData( - self, - X_train_full: np.ndarray, - Y_ph_train_full: np.ndarray, - valid_data_ratio: float = 0.1, - ): - """ - - Parameters - ---------- - X_train_full : (N, H, W) - The measured intensities at the detector - Y_ph_train_full : (N, C, H, W) - The phase and amplitude patches from the reconstructed object. - Phase in the zeroth channel and amplitude (optionally) in the first - channel - - """ - if (Y_ph_train_full.ndim != 4): - msg = ("Training data example patches must have a channel " - "dimension! i.e. the shape should be (N, C, H, W)") - raise ValueError(msg) - logger.info("Setting training data...") - - self.H, self.W = X_train_full.shape[-2:] - - self.X_train_full = torch.tensor( - X_train_full[:, None, ...], - dtype=torch.float32, + if X_train.ndim != 3: + msg = ( + "X_train must have 3 dimemnsions: (N, WIDTH, HEIGHT); " + f" not {X_train.shape}" ) - self.Y_ph_train_full = torch.tensor( - Y_ph_train_full, - dtype=torch.float32, + raise ValueError(msg) + if Y_train.ndim != 4: + msg = ( + f"Y_train must have 4 dimensions: (N, [1,2], WIDTH, HEIGHT); " + f"not {Y_train.shape}" ) - self.ntrain_full = self.X_train_full.shape[0] + raise ValueError(msg) - self.valid_data_ratio = valid_data_ratio - self.nvalid = int(self.ntrain_full * self.valid_data_ratio) - self.ntrain = self.ntrain_full - self.nvalid - - self.train_data_full = torch.utils.data.TensorDataset( - self.X_train_full, - self.Y_ph_train_full, - ) + dataset = torch.utils.data.TensorDataset( + torch.from_numpy(X_train[:, None, :, :]), + torch.from_numpy(Y_train), + ) - self.train_data, self.valid_data = torch.utils.data.random_split( - self.train_data_full, - [self.ntrain, self.nvalid], - ) - self.trainloader = torch.utils.data.DataLoader( - self.train_data, - batch_size=self.batch_size, + if training_fraction == 1.0: + trainingloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, shuffle=True, - num_workers=4, - drop_last=False, + drop_last=True, ) - self.validloader = torch.utils.data.DataLoader( - self.valid_data, - batch_size=self.batch_size, - shuffle=False, - num_workers=4, - drop_last=False, - ) + return trainingloader, None - self.iters_per_epoch = self.ntrain // self.batch_size + ( - self.ntrain % self.batch_size > 0) - - def setOptimizationParams( - self, - epochs_per_half_cycle: int = 6, - max_lr: float = 5e-4, - min_lr: float = 1e-4, - ): - logger.info("Setting optimization parameters...") - - # TODO: Move this note about iterations into the documentation string - # after figuring out what it means. Paper recommends 2-10 number of - # iterations - self.epochs_per_half_cycle = epochs_per_half_cycle - self.iters_per_half_cycle = epochs_per_half_cycle * self.iters_per_epoch - - logger.info( - "LR step size is: %d which is every %d epochs", - self.iters_per_half_cycle, - self.iters_per_half_cycle / self.iters_per_epoch, - ) + training, validation = torch.utils.data.random_split( + dataset, + [training_fraction, 1.0 - training_fraction], + ) - self.max_lr = max_lr - self.min_lr = min_lr + trainingloader = torch.utils.data.DataLoader( + training, + batch_size=batch_size, + shuffle=True, + drop_last=True, + ) - self.criterion = self.customLoss - self.optimizer = torch.optim.Adam( - self.model.parameters(), - lr=self.max_lr, - ) - self.scheduler = torch.optim.lr_scheduler.CyclicLR( - self.optimizer, - max_lr=self.max_lr, - base_lr=self.min_lr, - step_size_up=self.iters_per_half_cycle, - cycle_momentum=False, - mode='triangular2', - ) + validationloader = torch.utils.data.DataLoader( + validation, + batch_size=batch_size, + shuffle=False, + drop_last=True, + ) + + return trainingloader, validationloader - def initModel( - self, - model_params_path: pathlib.Path | None = None, - ): - """Load parameters from the disk then model to the GPU(s).""" - - self.device = torch.device( - "cuda" if torch.cuda.is_available() else "cpu") - print(f"Let's use {torch.cuda.device_count()} GPUs!") - - torchinfo.summary(self.model, (1, 1, self.H, self.W), device="cpu") - - self.model_params_path = model_params_path - - if model_params_path is not None: - self.model.load_state_dict( - torch.load( - self.model_params_path, - map_location=self.device, - )) - - self.model = torch.nn.DataParallel(self.model) - - self.model = self.model.to(self.device) - - self.scaler = torch.cuda.amp.GradScaler() - - logger.info("Setting up metrics...") - self.metrics = { - 'losses': [], - 'val_losses': [], - 'lrs': [], - 'best_val_loss': np.inf - } - logger.info(self.metrics) - - def train(self): - tot_loss = 0.0 - loss_ph = 0.0 - - for (ft_images, phs) in self.trainloader: - - # Move everything to device - ft_images = ft_images.to(self.device) - phs = phs.to(self.device) - - # Divide cumulative loss by number of batches-- slightly inaccurate - # because last batch is different size - pred_phs = self.model(ft_images) - loss_p = self.criterion(pred_phs, phs, self.ntrain) - # Monitor phase loss but only within support (which may not be same - # as true amp) - loss = loss_p - # Use equiweighted amps and phase - - # Zero current grads and do backprop - self.optimizer.zero_grad() - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - - tot_loss += loss.detach().item() - - loss_ph += loss_p.detach().item() - - # Update the LR according to the schedule -- CyclicLR updates each - # batch - self.scheduler.step() - self.metrics['lrs'].append(self.scheduler.get_last_lr()) - self.scaler.update() - - self.metrics['losses'].append([tot_loss, loss_ph]) - - def validate(self, epoch: int): - tot_val_loss = 0.0 - val_loss_ph = 0.0 - for (ft_images, phs) in self.validloader: - ft_images = ft_images.to(self.device) - phs = phs.to(self.device) - pred_phs = self.model(ft_images) - - val_loss_p = self.criterion(pred_phs, phs, self.nvalid) - val_loss = val_loss_p - - tot_val_loss += val_loss.detach().item() - val_loss_ph += val_loss_p.detach().item() - - self.metrics['val_losses'].append([tot_val_loss, val_loss_ph]) - - if self.output_path is not None: - self.saveMetrics( - self.metrics, - self.output_path, - self.output_suffix, - ) - - # Update saved model if val loss is lower - if (tot_val_loss < self.metrics['best_val_loss']): - logger.info( - "Saving improved model after Val Loss improved from %.5f to %.5f", - self.metrics['best_val_loss'], - tot_val_loss, - ) - self.metrics['best_val_loss'] = tot_val_loss - - if self.output_path is not None: - self.updateSavedModel( - self.model, - self.output_path, - self.output_suffix, - ) - - import matplotlib.pyplot as plt - os.makedirs(self.output_path / 'reference', exist_ok=True) - os.makedirs(self.output_path / 'inference', exist_ok=True) - plt.imsave(self.output_path / f'reference/{epoch:05d}.png', - phs[0, 0].detach().cpu().numpy().astype(np.float32)) - plt.imsave( - self.output_path / f'inference/{epoch:05d}.png', - pred_phs[0, 0].detach().cpu().numpy().astype(np.float32)) - - @staticmethod - def customLoss( - input: torch.tensor, - target: torch.tensor, - scaling: float, - ): - """A loss function which scales according to training set size.""" - assert torch.all(torch.isfinite(input)) - assert torch.all(torch.isfinite(target)) - return torch.sum(torch.mean( - torch.abs(input - target), - axis=(-1, -2), - )) / scaling - - # TODO: Use a callback instead of a static method for saving the model? - - @staticmethod - def updateSavedModel( - model: ptychonn.model.ReconSmallModel, - directory: pathlib.Path, - suffix: str = '', - ): - """Writes `model` parameters to `directory`/best_model`suffix`.pth - - The directory is created if it does not exist. - """ - fname = directory / f'best_model{ suffix }.pth' - logger.info("Saving best model as %s", fname) - os.makedirs(directory, exist_ok=True) - if isinstance(model, ( - torch.nn.DataParallel, - torch.nn.parallel.DistributedDataParallel, - )): - torch.save(model.module.state_dict(), fname) - else: - torch.save(model.state_dict(), fname) - - def getSavedModelPath(self) -> pathlib.Path | None: - """Return the path where `validate` will save the model weights""" - if self.output_path is None: - return None - return self.output_path / f'best_model{ self.output_suffix }.pth' - - @staticmethod - def saveMetrics( - metrics: dict, - directory: pathlib.Path, - suffix: str = '', - ): - """Writes `metrics` to `directory`/metrics`suffix`.npz - - The directory is created if it does not exist. - """ - os.makedirs(directory, exist_ok=True) - np.savez(directory / f'metrics{suffix}.npz', **metrics) - - def run(self, epochs: int, output_frequency: int = 1): - """The main training loop""" - - for epoch in range(epochs): - - #Set model to train mode - self.model.train() - - #Training loop - self.train() - - #Switch model to eval mode - self.model.eval() - - #Validation loop - with torch.inference_mode(): - self.validate(epoch) - - if epoch % output_frequency == 0: - logger.info( - 'Epoch: %d | FT | Train Loss: %1.03e | Val Loss: %1.03e', - epoch, - self.metrics['losses'][-1][0], - self.metrics['val_losses'][-1][0], - ) - logger.info( - 'Epoch: %d | Ph | Train Loss: %1.03e | Val Loss: %1.03e', - epoch, - self.metrics['losses'][-1][1], - self.metrics['val_losses'][-1][1], - ) - logger.info( - 'Epoch: %d | Ending LR: %1.03e', - epoch, - self.metrics['lrs'][-1][0], - ) - - def plotLearningRate( - self, - save_fname: pathlib.Path | None = None, - show_fig: bool = True, - ): - batches = np.linspace( - 0, - len(self.metrics['lrs']), - len(self.metrics['lrs']) + 1, + +def init_or_load_model( + model_type: typing.Type[lightning.LightningModule], + *, + model_checkpoint_path: pathlib.Path | None, + model_init_params: dict | None, +): + """Initialize one of the PtychoNN models via params or a checkpoint.""" + if not (model_checkpoint_path is None or model_init_params is None): + msg = ( + "One of model_checkpoint_path OR model_init_params must be None! " + "Both cannot be defined." ) - epoch_list = batches / self.iters_per_epoch - - import matplotlib.pyplot as plt - - f = plt.figure() - plt.plot(epoch_list[1:], self.metrics['lrs'], 'C3-') - plt.grid() - plt.ylabel("Learning rate") - plt.xlabel("Epoch") - plt.tight_layout() - - if save_fname is not None: - plt.savefig(save_fname) - if show_fig: - plt.show() - else: - plt.close(f) + raise ValueError(msg) + + if model_checkpoint_path is not None: + return model_type.load_from_checkpoint(model_checkpoint_path) + else: + return model_type(**model_init_params) + + +def create_model_checkpoint( + trainer: lightning.Trainer, + model_checkpoint_path: pathlib.Path, +): + trainer.save_checkpoint( + model_checkpoint_path, + ) diff --git a/ptychonn/model.py b/ptychonn/model.py index 0554108..f9f2737 100644 --- a/ptychonn/model.py +++ b/ptychonn/model.py @@ -1,11 +1,12 @@ """Define PtychoNN Pytorch models.""" +import lightning import numpy as np import torch import torch.nn as nn -class ReconSmallModel(nn.Module): +class LitReconSmallModel(lightning.LightningModule): """A small PychoNN model. Parameters @@ -30,12 +31,18 @@ def __init__( nconv: int = 16, use_batch_norm: bool = True, enable_amplitude: bool = True, + min_lr: float = 1e-4, + max_lr: float = 5e-4, ): super().__init__() self.nconv = nconv self.use_batch_norm = use_batch_norm self.enable_amplitude = enable_amplitude + self.epochs_per_half_cycle: int = 6 + self.max_lr: float = max_lr + self.min_lr: float = min_lr + # Appears sequential has similar functionality as TF avoiding need for # separate model definition and activ self.encoder = nn.Sequential( @@ -118,12 +125,71 @@ def up_block(self, filters_in: int, filters_out: int, groups: int): ] def forward(self, x): - with torch.cuda.amp.autocast(): - output = self.decoder(self.encoder(x)) - # Restore -pi to pi range - # Using tanh activation (-1 to 1) for phase so multiply by pi - output[..., 0, :, :] = torch.tanh(output[..., 0, :, :]) * np.pi - # Restrict amplitude to (0, 1) range with sigmoid - if self.enable_amplitude: - output[..., 1, :, :] = torch.sigmoid(output[..., 1, :, :]) + output = self.decoder(self.encoder(x)) + # Restore -pi to pi range + # Using tanh activation (-1 to 1) for phase so multiply by pi + output[..., 0, :, :] = torch.tanh(output[..., 0, :, :]) * np.pi + # Restrict amplitude to (0, 1) range with sigmoid + if self.enable_amplitude: + output[..., 1, :, :] = torch.sigmoid(output[..., 1, :, :]) return output + + def training_step(self, batch, batch_idx): + ft_images, object_roi = batch + predicted = self.forward(ft_images) + loss_phase = torch.nn.functional.mse_loss( + predicted[..., 0, :, :], + object_roi[..., 0, :, :], + ) + self.log("training_loss_phase", loss_phase) + if self.enable_amplitude: + loss_amp = torch.nn.functional.mse_loss( + predicted[..., 1, :, :], + object_roi[..., 1, :, :], + ) + loss = loss_phase + loss_amp + self.log("training_loss_amplitude", loss_amp) + else: + loss = loss_phase + self.log("training_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + ft_images, object_roi = batch + predicted = self.forward(ft_images) + loss_phase = torch.nn.functional.mse_loss( + predicted[..., 0, :, :], + object_roi[..., 0, :, :], + ) + self.log("validation_loss_phase", loss_phase) + if self.enable_amplitude: + loss_amp = torch.nn.functional.mse_loss( + predicted[..., 1, :, :], + object_roi[..., 1, :, :], + ) + loss = loss_phase + loss_amp + self.log("validation_loss_amplitude", loss_amp) + else: + loss = loss_phase + self.log("validation_loss", loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.max_lr, + ) + iters_per_epoch = self.trainer.estimated_stepping_batches / self.trainer.max_epochs + iters_per_half_cycle = self.epochs_per_half_cycle * iters_per_epoch + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + max_lr=self.max_lr, + base_lr=self.min_lr, + step_size_up=iters_per_half_cycle, + cycle_momentum=False, + mode="triangular2", + ) + return dict( + optimizer=optimizer, + lr_scheduler=scheduler, + ) diff --git a/ptychonn/plot.py b/ptychonn/plot.py index a54b930..9fa380f 100644 --- a/ptychonn/plot.py +++ b/ptychonn/plot.py @@ -2,40 +2,47 @@ import numpy as np +def _filter_nan(x, y): + mask = np.logical_and(np.isfinite(x), np.isfinite(y)) + return x[mask], y[mask] + def plot_metrics(metrics: dict, save_fname: str = None, show_fig: bool = False): - losses_arr = np.array(metrics['losses']) - val_losses_arr = np.array(metrics['val_losses']) - print("Shape of losses array is ", losses_arr.shape) - fig, ax = plt.subplots(3, sharex=True, figsize=(15, 8)) - ax[0].plot(losses_arr[1:, 0], 'C3o-', label="Train") - ax[0].plot(val_losses_arr[1:, 0], 'C0o-', label="Val") + fig, ax = plt.subplots(3, sharex=True, figsize=(16, 9)) + if 'training_loss' in metrics: + ax[0].plot(*_filter_nan(metrics['step'], metrics['training_loss']), 'C3o-', label="Train") + if 'validation_loss' in metrics: + ax[0].plot(*_filter_nan(metrics['step'], metrics['validation_loss']), 'C0o-', label="Val") ax[0].set(ylabel='Loss') ax[0].set_yscale('log') ax[0].grid() - ax[0].legend(loc='center right') + ax[0].legend() ax[0].set_title('Total loss') - ax[1].plot(losses_arr[1:, 1], 'C3o-', label="Train Amp loss") - ax[1].plot(val_losses_arr[1:, 1], 'C0o-', label="Val Amp loss") + if 'training_loss_amplitude' in metrics: + ax[1].plot(*_filter_nan(metrics['step'], metrics['training_loss_amplitude']), 'C3o-', label="Train Amp loss") + if 'validation_loss_amplitude' in metrics: + ax[1].plot(*_filter_nan(metrics['step'], metrics['validation_loss_amplitude']), 'C0o-', label="Val Amp loss") ax[1].set(ylabel='Loss') ax[1].set_yscale('log') ax[1].grid() - ax[1].legend(loc='center right', bbox_to_anchor=(1.5, 0.5)) + ax[1].legend() ax[1].set_title('Phase loss') - # ax[2].plot(losses_arr[1:, 2], 'C3o-', label="Train Ph loss") - # ax[2].plot(val_losses_arr[1:, 2], 'C0o-', label="Val Ph loss") - # ax[2].set(ylabel='Loss') - # ax[2].grid() - # ax[2].legend(loc='center right', bbox_to_anchor=(1.5, 0.5)) - # ax[2].set_yscale('log') - # ax[2].set_title('Mag los') + if 'training_loss_phase' in metrics: + ax[2].plot(*_filter_nan(metrics['step'], metrics['training_loss_phase']), 'C3o-', label="Train Ph loss") + if 'validation_loss_phase' in metrics: + ax[2].plot(*_filter_nan(metrics['step'], metrics['validation_loss_phase']), 'C0o-', label="Val Ph loss") + ax[2].set(ylabel='Loss') + ax[2].grid() + ax[2].legend() + ax[2].set_yscale('log') + ax[2].set_title('Mag los') plt.tight_layout() - plt.xlabel("Epochs") + plt.xlabel("Steps") if save_fname is not None: plt.savefig(save_fname) diff --git a/requirements-dev b/requirements-dev index 4a8ce60..46ad4f3 100644 --- a/requirements-dev +++ b/requirements-dev @@ -6,4 +6,4 @@ pytorch >=1.12,<2.1 scikit-image scipy tqdm -torchinfo +lightning >=2.1.3,<3.0 diff --git a/setup.cfg b/setup.cfg index 6183aee..5c03881 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,7 +13,7 @@ install_requires = scikit-image scipy tqdm - torchinfo + lightning >=2.1.3,<3 [options.entry_points] console_scripts =