diff --git a/Continual Learning/Inference/README b/Continual Learning/Inference/README deleted file mode 100644 index 2618203..0000000 --- a/Continual Learning/Inference/README +++ /dev/null @@ -1,7 +0,0 @@ -infer_workflow.py - does the inference on trained model using PyTorch -It also stitches the individual inferences using the positions in the optimized route.npz. -The pixel size in the stitiching can be changed accordingly. -Typical choices for pixelsize is 8 or 10 nm. - - -The base path in the code has to be changed in order to run the code. diff --git a/Continual Learning/Inference/inferences_0_to_4_scan506.png b/Continual Learning/Inference/inferences_0_to_4_scan506.png deleted file mode 100644 index 2c60af1..0000000 Binary files a/Continual Learning/Inference/inferences_0_to_4_scan506.png and /dev/null differ diff --git a/Continual Learning/Inference/out/README b/Continual Learning/Inference/out/README deleted file mode 100644 index 8efe7da..0000000 --- a/Continual Learning/Inference/out/README +++ /dev/null @@ -1 +0,0 @@ -The output of the inference is dumped as a npz file here diff --git a/Continual Learning/Inference/src/README b/Continual Learning/Inference/src/README deleted file mode 100644 index 9055a34..0000000 --- a/Continual Learning/Inference/src/README +++ /dev/null @@ -1 +0,0 @@ -Has the trained model, position file and sample test data diff --git a/Continual Learning/Inference/src/optimized_route.npz b/Continual Learning/Inference/src/optimized_route.npz deleted file mode 100644 index 6eb5afb..0000000 Binary files a/Continual Learning/Inference/src/optimized_route.npz and /dev/null differ diff --git a/Continual Learning/Inference/src/scan_506_000793.h5 b/Continual Learning/Inference/src/scan_506_000793.h5 deleted file mode 100644 index b6334f1..0000000 Binary files a/Continual Learning/Inference/src/scan_506_000793.h5 and /dev/null differ diff --git a/Continual Learning/Inference/stitched_506.png b/Continual Learning/Inference/stitched_506.png deleted file mode 100644 index 465dde6..0000000 Binary files a/Continual Learning/Inference/stitched_506.png and /dev/null differ diff --git a/Continual Learning/Training/README b/Continual Learning/Training/README deleted file mode 100644 index 3d7925a..0000000 --- a/Continual Learning/Training/README +++ /dev/null @@ -1,4 +0,0 @@ -This folder has the continual learning code for PtychoNN. - -small_1_ptychonn_model.py -- main code for training -dataPrepPtychoNN_tiff.py -- script for prepping the data for PtychoNN diff --git a/Continual Learning/Training/agx_update.py b/Continual Learning/Training/agx_update.py deleted file mode 100644 index 2cba601..0000000 --- a/Continual Learning/Training/agx_update.py +++ /dev/null @@ -1,12 +0,0 @@ -import paramiko -import os - -def agx_update(best_model_path): - ssh=paramiko.SSHClient() - ssh.load_host_keys(os.path.expanduser(os.path.join("~",".ssh","known_hosts"))) - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect("## IP of edge device", username=" ", password=" ", look_for_keys=False, allow_agent=False) - sftp=ssh.open_sftp() - - sftp.put(best_model_path, "# path to which the model has to be pushed") - diff --git a/Continual Learning/Training/custom_logger.py b/Continual Learning/Training/custom_logger.py deleted file mode 100644 index 407420c..0000000 --- a/Continual Learning/Training/custom_logger.py +++ /dev/null @@ -1,24 +0,0 @@ -import sys -import logging - -class LoggerWriter(object): - def __init__(self, writer): - self._writer = writer - self._msg = '' - - def write(self, message): - self._msg = self._msg + message - while '\n' in self._msg: - pos = self._msg.find('\n') - self._writer(self._msg[:pos]) - self._msg = self._msg[pos+1:] - - def flush(self): - if self._msg != '': - self._writer(self._msg) - self._msg = '' - -def setupLogging(out_path): - logging.basicConfig(filename=f'{out_path}/log', filemode='w', level=logging.DEBUG, format='%(message)s') - log = logging.getLogger('logger') - sys.stdout = LoggerWriter(log.info) \ No newline at end of file diff --git a/Continual Learning/Training/dataPrepPtychoNN_tiff.py b/Continual Learning/Training/dataPrepPtychoNN_tiff.py deleted file mode 100644 index e7e7109..0000000 --- a/Continual Learning/Training/dataPrepPtychoNN_tiff.py +++ /dev/null @@ -1,125 +0,0 @@ -import numpy as np -from scipy import interpolate -import sys -import os -import fabio -from datetime import datetime - -import time -from matplotlib import pyplot, colors -import glob -from PIL import Image -import hdf5plugin -import h5py - -## for denoising -from skimage.restoration import (denoise_tv_chambolle, denoise_bilateral, - denoise_wavelet, estimate_sigma, denoise_nl_means) - -def diff_data(reciprocalpath, scannum): - - dir_name = reciprocal_path+str(scannum)+"/"+str(scannum)+"/" - list_of_files = sorted(filter(os.path.isfile, - glob.glob(dir_name + '*') ) ) - - scan_arr = np.zeros((len(list_of_files), 512, 512)) - for i, file_path in enumerate(list_of_files): - imarray = np.asarray(Image.open(file_path)) - imarray = imarray[2:-2,2:-2] - scan_arr[i] = imarray[:,::-1] - - num_tiff = len(list_of_files) - return scan_arr, num_tiff - - -def Prep_PtychoNN(recon_path, reciprocal_path, scannum, search_path): - """ - recon_path: path to the Tike recon output - reciprocal: input diffraction patterns - scannum: scan ID as a command line argument - search_path: folder location to dump the preprocessed data for ptychoNN - Training code will monitor this folder for new files for training - - """ - # reading the diff patterns from h5 file - reciprocal_path = reciprocal_path + str(scannum) - for df in glob.glob(reciprocal_path+'/scan*.h5'): - with h5py.File(df) as f: - data = f['entry/data/data'][()] - data[data<0]=0 - - - ##post_message("preparing data for ptychoNN") - - angle_path = recon_path+ str(scannum) + '/scan-{0}_object_angle.tiff'.format(scannum) - obj_ph = np.asarray(Image.open(angle_path)) ## read from a tiff file - - ## do image denoising - #obj_ph = denoise_nl_means(obj_ph, preserve_range=True) - - real_path = recon_path+ str(scannum) + '/scan-{0}_object_amp.tiff'.format(scannum) - ampl = np.asarray(Image.open(real_path)) - - - - pixelsize = 11.176e-9 - - - - - - - amp = ampl ## - - pha = obj_ph - - - pha_mean = pha[int(pha.shape[0]/3.):int(pha.shape[0]/3.*2),int(pha.shape[1]/3.):int(pha.shape[1]/3.*2)].mean() - pha -= pha_mean - - - - - - - pos = np.genfromtxt(reciprocal_path +'/positions.csv', delimiter=',') - - - - x = np.arange(obj_ph.shape[1])*pixelsize - y = np.arange(obj_ph.shape[0])*pixelsize - x -= x.mean() - y -= y.mean() - - fint_pha = interpolate.interp2d(x, y, pha, kind='cubic') - fint_amp = interpolate.interp2d(x, y, amp, kind='cubic') - - - real = np.zeros((pos.shape[0], 128, 128), dtype=np.complex64) - xx = np.arange(128)*10e-9 # this is to predict on a 1 um area per point - yy = np.arange(128)*10e-9 - xx -= xx.mean() - yy -= yy.mean() - - for i in range(pos.shape[0]): - real[i] = fint_amp(xx+pos[i,1], yy+pos[i,0])*np.exp(1j*fint_pha(xx+pos[i,1], yy+pos[i,0])) - - - np.savez_compressed(search_path+"scan{0}.npz".format(scannum), real=real, reciprocal=data.astype('float32'), position=pos, pixelsize=pixelsize) - - -if __name__=="__main__": - - #bot_token = "xoxb-679835710832-2052497567909-oB0WeYpoEChiXF3FL0XYm1tb" - #webclient = WebClient(token=bot_token) - - scanID = int(sys.argv[1]) - recon_path ="/grand/hp-ptycho/bicer/202206_run00_workflow-Tao/output/" - search_path = "/grand/hp-ptycho/anakha/S26-beamtime/Training/" - reciprocal_path = "/grand/hp-ptycho/bicer/202206_run00_workflow-Tao/input/" - - #data, num_tiff = diff_data(reciprocal_path, scanID) - #for scanID in range(411, 422): - Prep_PtychoNN(recon_path, reciprocal_path, scanID, search_path) - - diff --git a/Continual Learning/Training/helper_small_model.py b/Continual Learning/Training/helper_small_model.py deleted file mode 100644 index 2140949..0000000 --- a/Continual Learning/Training/helper_small_model.py +++ /dev/null @@ -1,551 +0,0 @@ -import torch, torchvision -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from torchinfo import summary -from torch.utils.data import TensorDataset, DataLoader -import os - -import numpy as np - -import matplotlib.pyplot as plt -import matplotlib -import skimage - - - -def plot3(data: list, titles: list = [], save_fname: str =None): - if(len(titles)<3): - titles=["Plot1", "Plot2", "Plot3"] - fig,ax = plt.subplots(1,3, figsize=(20,12)) - im=ax[0].imshow(data[0]) - ax[0].set_title(titles[0]) - ax[0].axis('off') - plt.colorbar(im,ax=ax[0], fraction=0.046, pad=0.04) - im=ax[1].imshow(data[1]) - ax[1].set_title(titles[1]) - ax[1].axis('off') - plt.colorbar(im,ax=ax[1], fraction=0.046, pad=0.04) - im=ax[2].imshow(data[2]) - ax[2].set_title(titles[2]) - ax[2].axis('off') - plt.colorbar(im,ax=ax[2], fraction=0.046, pad=0.04) - - plt.tight_layout() - if save_fname is not None: - plt.savefig(save_fname) - plt.show() - - -def getIntPositions(positions_all: np.ndarray, # Set as (y,x). - columnwise: bool = False, # whether the images are column-wise or row-wise. True corresponds to columnwise - pixel_size:float =8e-9, - downsampling:float =1 - ): - """ Use downsampling=1 if the positions and pixel size have already been downsampled. Use downsampling < 1 if we want to upsample - the images before interpolation (to better account for subpixel shifts).""" - - raise NotImplementedError("I am not sure it is working correctly just yet.") - if columnwise: - pos_x = np.array(positions_all[:,0]) - pos_y = np.array(positions_all[:, 1]) - else: - pos_x = np.array(positions_all[:,1]) - pos_y = np.array(positions_all[:,0]) - - pos_row = (pos_x-np.min(pos_x)) / (pixel_size ) / downsampling - pos_col = (pos_y-np.min(pos_y)) / (pixel_size ) / downsampling - - # integer position - pos_int_row = pos_row.astype(np.int32) - pos_int_col = pos_col.astype(np.int32) - - pos_subpixel_row = pos_row - pos_int_row - pos_subpixel_col = pos_col - pos_int_col - return pos_int_row, pos_int_col, pos_subpixel_row, pos_subpixel_col - -def stitch(slices: np.ndarray, - positions_row: np.ndarray, - positions_col: np.ndarray, - columnwise: bool = False, # whether the images are column-wise or row-wise - upsample_factor: int = 1, # Use > 1 if we want to - ): - - """The 'columnwise' part of this function and getIntPositions can be signified, but I can do that later. - - Use upsampling_factor > 1 if we want to upsample images before interpolation (to better account for subpixel shifts). - Assumes that the positions supplied are floats and accurately represent the current (not upsampled) scan position. - """ - raise NotImplementedError("I am not sure it is working correctly just yet.") - pos_int_row = (positions_row * upsample_factor).astype(np.int32) - pos_int_col = (positions_col * upsample_factor).astype(np.int32) - - size = slices[0].shape[0] - weights = None - size_h = size // 2 - composite = np.zeros((np.max(pos_int_row) + size, np.max(pos_int_col) + size), slices.dtype) - #print('Composite shape before trimming', composite.shape) - ctr = np.zeros(composite.shape) - if weights is None: - weights = np.ones((np.array(slices[0].shape) * upsample_factor).astype('int32'), dtype='float32') - - for i in range(pos_int_row.shape[0]): - - slice_to_add = slices[i] if columnwise else slices[i].T - if upsample_factor > 1: - if slice_to_add.dtype in [np.complex64, np.complex128]: - mag_slice_to_add = skimage.transform.rescale(np.abs(slice_to_add), upsample_factor, preserve_range=True) - ph_slice_to_add = skimage.transform.rescale(np.angle(slice_to_add), upsample_factor, preserve_range=True) - slice_to_add = mag_slice_to_add * np.exp(1j * ph_slice_to_add) - else: - slice_to_add = skimage.transform.rescale(slice_to_add, upsample_factor, preserve_range=True) - - composite[pos_int_row[i]: pos_int_row[i] + size, pos_int_col[i]: pos_int_col[i] + size] += slice_to_add * weights - - ctr[pos_int_row[i]:pos_int_row[i] + size, pos_int_col[i]:pos_int_col[i] + size] += weights#pb_weight - - - composite = composite[size_h:-size_h,size_h:-size_h] - ctr = ctr[size_h:-size_h, size_h:-size_h] - - composite /= (ctr + 1e-8) - return composite - -class ReconSmallPhaseModel(nn.Module): - def __init__(self, nconv: int = 16): - super(ReconSmallPhaseModel, self).__init__() - self.nconv = nconv - - self.encoder = nn.Sequential( # Appears sequential has similar functionality as TF avoiding need for separate model definition and activ - *self.down_block(1, self.nconv), - *self.down_block(self.nconv, self.nconv * 2), - *self.down_block(self.nconv * 2, self.nconv * 4), - *self.down_block(self.nconv * 4, self.nconv * 8) - ) - - # amplitude model - #self.decoder1 = nn.Sequential( - # *self.up_block(self.nconv * 8, self.nconv * 8), - # *self.up_block(self.nconv * 8, self.nconv * 4), - # *self.up_block(self.nconv * 4, self.nconv * 2), - # *self.up_block(self.nconv * 2, self.nconv * 1), - # nn.Conv2d(self.nconv * 1, 1, 3, stride=1, padding=(1,1)), - #) - - # phase model - self.decoder2 = nn.Sequential( - *self.up_block(self.nconv * 8, self.nconv * 8), - *self.up_block(self.nconv * 8, self.nconv * 4), - *self.up_block(self.nconv * 4, self.nconv * 2), - *self.up_block(self.nconv * 2, self.nconv * 1), - nn.Conv2d(self.nconv * 1, 1, 3, stride=1, padding=(1,1)), - nn.Tanh() - ) - - def down_block(self, filters_in, filters_out): - block = [ - nn.Conv2d(in_channels=filters_in, out_channels=filters_out, kernel_size=3, stride=1, padding=(1,1)), - nn.ReLU(), - nn.Conv2d(filters_out, filters_out, 3, stride=1, padding=(1,1)), - nn.ReLU(), - nn.MaxPool2d((2,2)) - ] - return block - - - def up_block(self, filters_in, filters_out): - block = [ - nn.Conv2d(filters_in, filters_out, 3, stride=1, padding=(1,1)), - nn.ReLU(), - nn.Conv2d(filters_out, filters_out, 3, stride=1, padding=(1,1)), - nn.ReLU(), - nn.Upsample(scale_factor=2, mode='bilinear') - ] - return block - - - def forward(self,x): - with torch.cuda.amp.autocast(): - x1 = self.encoder(x) - #amp = self.decoder1(x1) - ph = self.decoder2(x1) - - #Restore -pi to pi range - ph = ph*np.pi #Using tanh activation (-1 to 1) for phase so multiply by pi - - return ph - - -def plot_metrics(metrics: dict, save_fname: str = None, show_fig: bool = False): - - losses_arr = np.array(metrics['losses']) - val_losses_arr = np.array(metrics['val_losses']) - print("Shape of losses array is ", losses_arr.shape) - fig, ax = plt.subplots(3,sharex=True, figsize=(15, 8)) - ax[0].plot(losses_arr[1:,0], 'C3o-', label = "Train") - ax[0].plot(val_losses_arr[1:,0], 'C0o-', label = "Val") - ax[0].set(ylabel='Loss') - ax[0].set_yscale('log') - ax[0].grid() - ax[0].legend(loc='center right') - ax[0].set_title('Total loss') - - #ax[1].plot(losses_arr[1:,1], 'C3o-', label = "Train Amp loss") - #ax[1].plot(val_losses_arr[1:,1], 'C0o-', label = "Val Amp loss") - #ax[1].set(ylabel='Loss') - #ax[1].set_yscale('log') - #ax[1].grid() - #ax[1].legend(loc='center right', bbox_to_anchor=(1.5, 0.5)) - #ax[1].set_title('Phase loss') - - ax[2].plot(losses_arr[1:,2], 'C3o-', label = "Train Ph loss") - ax[2].plot(val_losses_arr[1:,2], 'C0o-', label = "Val Ph loss") - ax[2].set(ylabel='Loss') - ax[2].grid() - #ax[2].legend(loc='center right', bbox_to_anchor=(1.5, 0.5)) - ax[2].set_yscale('log') - ax[2].set_title('Mag los') - - plt.tight_layout() - plt.xlabel("Epochs") - - if save_fname is not None: - plt.savefig(save_fname) - if show_fig: - plt.show() - else: - plt.close() - - -def plot_test_data(selected_diffs: np.ndarray, selected_phs_true: np.ndarray, - selected_phs_eval: np.ndarray, - save_fname: str = None, show_fig: bool = True): - - n = selected_diffs_eval.shape[0] - - plt.viridis() - - f,ax=plt.subplots(7, n, figsize=(n * 4, 15)) - plt.gcf().text(0.02, 0.85, "Input", fontsize=20) - plt.gcf().text(0.02, 0.72, "True I", fontsize=20) - plt.gcf().text(0.02, 0.6, "Predicted I", fontsize=20) - plt.gcf().text(0.02, 0.5, "Difference I", fontsize=20) - plt.gcf().text(0.02, 0.4, "True Phi", fontsize=20) - plt.gcf().text(0.02, 0.27, "Predicted Phi", fontsize=20) - plt.gcf().text(0.02, 0.17, "Difference Phi", fontsize=20) - - for i in range(0,n): - - # display FT - - im=ax[0,i].imshow(np.log10(selected_diffs[i])) - plt.colorbar(im, ax=ax[0,i], format='%.2f') - ax[0,i].get_xaxis().set_visible(False) - ax[0,i].get_yaxis().set_visible(False) - - - # display predicted intens - #im=ax[2,i].imshow(selected_amps_eval[i]) - #plt.colorbar(im, ax=ax[2,i], format='%.2f') - #ax[2,i].get_xaxis().set_visible(False) - #ax[2,i].get_yaxis().set_visible(False) - - # display original phase - im=ax[4,i].imshow(selected_phs_true[i]) - plt.colorbar(im, ax=ax[4,i], format='%.2f') - ax[4,i].get_xaxis().set_visible(False) - ax[4,i].get_yaxis().set_visible(False) - - # display predicted phase - im=ax[5,i].imshow(selected_phs_eval[i]) - plt.colorbar(im, ax=ax[5,i], format='%.2f') - ax[5,i].get_xaxis().set_visible(False) - ax[5,i].get_yaxis().set_visible(False) - - - # Difference in phase - im=ax[6,i].imshow(selected_phs_true[i] - selected_phs_eval[i]) - plt.colorbar(im, ax=ax[6,i], format='%.2f') - ax[6,i].get_xaxis().set_visible(False) - ax[6,i].get_yaxis().set_visible(False) - - if save_fname is not None: - plt.savefig(save_fname) - if show_fig: - plt.show() - else: - plt.close() - -class Trainer(): - def __init__(self, model: ReconSmallPhaseModel, batch_size: int, output_path: str, output_suffix: str): - self.model = model - self.batch_size = batch_size - self.output_path = output_path - self.output_suffix = output_suffix - self.epoch = 0 - - def setTrainingData(self, X_train_full: np.ndarray, Y_ph_train_full: np.ndarray, - valid_data_ratio: float = 0.1): - - self.H, self.W = X_train_full.shape[-2:] - - self.X_train_full = torch.tensor(X_train_full[:, None, ...].astype('float32')) - self.Y_ph_train_full = torch.tensor(Y_ph_train_full[:, None, ...].astype('float32')) - self.ntrain_full = self.X_train_full.shape[0] - - self.valid_data_ratio = valid_data_ratio - self.nvalid = int(self.ntrain_full * self.valid_data_ratio) - self.ntrain = self.ntrain_full - self.nvalid - - - self.train_data_full = TensorDataset(self.X_train_full, self.Y_ph_train_full) - - self.train_data, self.valid_data = torch.utils.data.random_split(self.train_data_full, [self.ntrain, self.nvalid]) - self.trainloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=4) - - self.validloader = DataLoader(self.valid_data, batch_size=self.batch_size, shuffle=True, num_workers=4) - - self.iters_per_epoch = int(np.floor((self.ntrain) / self.batch_size) + 1) #Final batch will be less than batch size - - - - - - def setOptimizationParams(self, epochs_per_half_cycle: int = 6, max_lr: float=5e-4, min_lr: float=1e-4): - #Optimizer details - - self.epochs_per_half_cycle = epochs_per_half_cycle - self.iters_per_half_cycle = epochs_per_half_cycle * self.iters_per_epoch #Paper recommends 2-10 number of iterations - - print("LR step size is:", self.iters_per_half_cycle, - "which is every %d epochs" %(self.iters_per_half_cycle / self.iters_per_epoch)) - - self.max_lr = max_lr - self.min_lr = min_lr - - #criterion = lambda t1, t2: nn.L1Loss() - self.criterion = self.customLoss - self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.max_lr) - self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, max_lr = self.max_lr, base_lr= self.min_lr, - step_size_up = self.iters_per_half_cycle, - cycle_momentum=False, mode='triangular2') - - def testForwardSingleBatch(self): - for ft_images, phs in self.trainloader: - print("batch size:", ft_images.shape) - ph_train = self.model(ft_images) - print("Phase batch shape: ", ph_train.shape) - print("Phase batch dtype", ph_train.dtype) - - loss_ph = self.criterion(ph_train, phs, self.ntrain) - print("Phase loss", loss_ph) - break - - def initModel(self, model_params_path: str = None): - - self.model_params_path = model_params_path - if model_params_path is not None: - self.model.load_state_dict(torch.load(self.model_params_path)) - summary(self.model, (1, 1, self.H, self.W), device="cpu") - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if torch.cuda.device_count() > 1: - print("Let's use", torch.cuda.device_count(), "GPUs!") - # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs - self.model = nn.DataParallel(self.model) #Default all devices - - self.model = self.model.to(self.device) - - print("Setting up mixed precision gradient calculation...") - self.scaler = torch.cuda.amp.GradScaler() - - print("Setting up metrics...") - self.metrics = {'losses':[],'val_losses':[], 'lrs':[], 'best_val_loss' : np.inf} - print(self.metrics) - - - def train(self): - tot_loss = 0.0 - loss_ph = 0.0 - - for i, (ft_images, phs) in enumerate(self.trainloader): - ft_images = ft_images.to(self.device) #Move everything to device - phs = phs.to(self.device) - - pred_phs = self.model(ft_images) #Forward pass - - #Compute losses - loss_p = self.criterion(pred_phs, phs, self.ntrain) #Monitor phase loss but only within support (which may not be same as true amp) - loss = loss_p #Use equiweighted amps and phase - - - #Zero current grads and do backprop - self.optimizer.zero_grad() - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - - tot_loss += loss.detach().item() - - loss_ph += loss_p.detach().item() - - #Update the LR according to the schedule -- CyclicLR updates each batch - self.scheduler.step() - self.metrics['lrs'].append(self.scheduler.get_last_lr()) - self.scaler.update() - - - #Divide cumulative loss by number of batches-- sli inaccurate because last batch is different size - self.metrics['losses'].append([tot_loss, loss_ph]) - - - def validate(self): - tot_val_loss = 0.0 - val_loss_ph = 0.0 - for j, (ft_images, phs) in enumerate(self.validloader): - ft_images = ft_images.to(self.device) - phs = phs.to(self.device) - pred_phs = self.model(ft_images) #Forward pass - - - val_loss_p = self.criterion(pred_phs,phs, self.nvalid) - val_loss = val_loss_p - - #try complex valued diff - #diff_real = pred_amps * torch.cos(pred_phs) - amps * torch.cos(phs) - #diff_imag = pred_amps * torch.sin(pred_phs) - amps * torch.sin(phs) - #val_loss = torch.mean(torch.abs(diff_real + diff_imag)) - - tot_val_loss += val_loss.detach().item() - val_loss_ph += val_loss_p.detach().item() - - self.metrics['val_losses'].append([tot_val_loss, val_loss_ph]) - - - self.saveMetrics(self.metrics, self.output_path, self.output_suffix) - #Update saved model if val loss is lower - - if(tot_val_loss < self.metrics['best_val_loss']): - print("Saving improved model after Val Loss improved from %.5f to %.5f" %(self.metrics['best_val_loss'],tot_val_loss)) - self.metrics['best_val_loss'] = tot_val_loss - self.updateSavedModel(self.model, self.output_path, self.output_suffix) - - @staticmethod - def customLoss(t1, t2, scaling): - return torch.sum(torch.mean(torch.abs(t1 - t2), axis=(-1, -2))) / scaling - - - @staticmethod - #Function to update saved model if validation loss is minimum - def updateSavedModel(model: ReconSmallPhaseModel, path: str, output_suffix: str=''): - if not os.path.isdir(path): - os.mkdir(path) - fname = path + '/best_model' + output_suffix + '.pth' - print("Saving best model as ", fname) - torch.save(model.module.state_dict(), fname) - - @staticmethod - def saveMetrics(metrics: dict, path: str, output_suffix: str=''): - np.savez(path + '/metrics' + output_suffix + '.npz', **metrics) - - - def run(self, epochs: int, output_frequency: int = 1 ): - for epoch in range (epochs): - - #Set model to train mode - self.model.train() - - #Training loop - self.train() - - #Switch model to eval mode - self.model.eval() - - #Validation loop - self.validate() - if epoch % output_frequency == 0: - print('Epoch: %d | FT | Train Loss: %.5f | Val Loss: %.5f' - %(epoch, self.metrics['losses'][-1][0], self.metrics['val_losses'][-1][0])) - print('Epoch: %d | Ph | Train Loss: %.3f | Val Loss: %.3f' - %(epoch, self.metrics['losses'][-1][1], self.metrics['val_losses'][-1][1])) - print('Epoch: %d | Ending LR: %.6f ' %(epoch, self.metrics['lrs'][-1][0])) - print() - - - def plotLearningRate(self, save_fname: str = None, show_fig: bool = True): - batches = np.linspace(0, len(self.metrics['lrs']), len(self.metrics['lrs'])+1) - epoch_list = batches / self.iters_per_epoch - - plt.plot(epoch_list[1:], self.metrics['lrs'], 'C3-') - plt.grid() - plt.ylabel("Learning rate") - plt.xlabel("Epoch") - - plt.tight_layout() - if save_fname is not None: - plt.savefig(save_fname) - if show_fig: - plt.show() - else: - plt.close() - -class Tester(): - def __init__(self, model: ReconSmallPhaseModel, batch_size: int, model_params_path: str): - - self.model = model - self.batch_size = batch_size - self.model_params_path = model_params_path - - self.model.load_state_dict(torch.load(self.model_params_path)) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if torch.cuda.device_count() > 1: - print("Let's use", torch.cuda.device_count(), "GPUs!") - self.model = nn.DataParallel(self.model) #Default all devices - - self.model = self.model.to(self.device) - - - def setTestData(self, X_test: np.ndarray): - self.X_test = torch.tensor(X_test[:,None,...].astype('float32')) - self.test_data = TensorDataset(self.X_test) - - self.testloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, num_workers=4) - - - def predictTestData(self, npz_save_path: str=None): - self.model.eval() - phs_eval = [] - for i, ft_images in enumerate(self.testloader): - ft_images = ft_images[0].to(self.device) - ph_eval = self.model(ft_images) - for j in range(ft_images.shape[0]): - phs_eval.append(ph_eval[j].detach().to("cpu").numpy()) - self.phs_eval = np.array(phs_eval).squeeze().astype('float32') - if npz_save_path is not None: - np.savez_compressed(npz_save_path, ph=self.phs_eval)#mag=self.amps_eval, ph=self.phs_eval) - #return self.amps_eval, self.phs_eval - return self.phs_eval - - def calcErrors(self, phs_true: np.ndarray, npz_save_path: str = None): - from skimage.metrics import mean_squared_error as mse - - - self.phs_true = phs_true - self.errors = [] - for i, (p1, p2) in enumerate(zip(self.phs_eval, self.phs_true)): - err2 = mse(p1, p2) - self.errors.append([err2]) - - self.errors = np.array(self.errors) - print("Mean errors in phase") - print(np.mean(self.errors, axis=0)) - - if npz_save_path is not None: - np.savez_compressed(npz_save_path, phs_err=self.errors[:,0]) - - return self.errors - - - - - \ No newline at end of file diff --git a/Continual Learning/Training/slack_update.py b/Continual Learning/Training/slack_update.py deleted file mode 100644 index 18030c5..0000000 --- a/Continual Learning/Training/slack_update.py +++ /dev/null @@ -1,13 +0,0 @@ - -from slack_sdk import WebClient - - -slack_bot_token = " " -slack_webclient = WebClient(token=slack_bot_token) - - -def post_message(msg): - slack_webclient.chat_postMessage(channel="#automated", text=msg) - -def post_figure(filename): - slack_webclient.files_upload(channels="#automated", file=filename) \ No newline at end of file diff --git a/Continual Learning/Training/small_1_ptychonn_model.py b/Continual Learning/Training/small_1_ptychonn_model.py deleted file mode 100644 index 7aa9d0a..0000000 --- a/Continual Learning/Training/small_1_ptychonn_model.py +++ /dev/null @@ -1,291 +0,0 @@ -import glob -import numpy as np -import torch - -import os -#os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3, 4, 5, 6, 7' -import sys - -#os.environ['EPICS_CA_ADDR_LIST'] = '164.54.128.27' - -import shutil -import custom_logger - -import helper_small_model as helper -from agx_update import agx_update -from slack_update import post_message -import time - -from scipy.stats import circmean - -import ssl -import epics - -ssl._create_default_https_context = ssl._create_unverified_context - -diffraction_downscaling = 1 - - - -# Basic paths -base_path = " " ## change here -out_path = f"{base_path}/workflow" ## with 4 probe modes -search_path = f"{base_path}/Training" ## this is the path to the preprocessed data for training -overall_path = out_path + f"/overall/small_downscaled_{diffraction_downscaling}" - -best_model_params_path = overall_path + f"/best_model.pth" - - - -# Default datasets for prediction (that the training does not access) -default_test_scan_indices=[] -default_test_datasets =[] -default_test_slice_nums =[] - - -""" default_test_scan_indices = ['scan247.npz']#'scan571.npz', 'scan572.npz'] -default_test_datasets = [f'{base_path}/Training/{scan}' for scan in default_test_scan_indices] -default_test_slice_nums = [200, 500] # this is random """ - - -# Basic training parameters -EPOCHS = 50 -NGPUS = torch.cuda.device_count() -BATCH_SIZE = NGPUS * 64 -LR = 1e-3 - - -def setupOutputDirectoryAndLogging(): - outer_train_iterations = len(glob.glob(f'{out_path}/small_downscaled_{diffraction_downscaling}_iteration_*')) - outer_train_iteration_this = outer_train_iterations + 1 - iteration_out_path = f'{out_path}/small_downscaled_{diffraction_downscaling}_iteration_{outer_train_iteration_this}' - - if not os.path.isdir(overall_path): - os.mkdir(overall_path) - if (not os.path.isdir(iteration_out_path)): - os.mkdir(iteration_out_path) - - - custom_logger.setupLogging(iteration_out_path) - return iteration_out_path - -def printBasicParams(): - print("base path", base_path) - print("out path", out_path) - print("search path", search_path) - print("overall results path", overall_path) - print("Default testing datasets (tested after each training phase)", default_test_scan_indices) - print("GPUs:", NGPUS, "Batch size:", BATCH_SIZE, "Learning rate:", LR) - - -def searchAndCheckDatafiles(): - datafiles_all = glob.glob(search_path + "/*.npz") - - # Read list of datafiles already incorporated into training - if os.path.isfile(f'{overall_path}/data_incorporated_list.txt'): - with open(f'{overall_path}/data_incorporated_list.txt', 'r') as f: - datafiles_incorporated = f.read().splitlines() - else: - datafiles_incorporated = [] - - if len(datafiles_incorporated) > 0: - print("Datafiles already included in model:") - for df in datafiles_incorporated: - print(df) - - datafiles_new = [] - for df in datafiles_all: - if df not in datafiles_incorporated: - continue_without_adding = False - # Excluding default test scans from training - if len(default_test_scan_indices) > 0: - for scan in default_test_scan_indices: - if scan in df: - continue_without_adding = True - - if not continue_without_adding: - datafiles_new.append(df) - if len(datafiles_new) == 0: - print("No new data detected.") - exit(0) - else: - print("Datafiles not previously included:") - for df in datafiles_new: - print(df) - return datafiles_incorporated, datafiles_new - -def testPredictionQualityForNewDatafiles(datafiles_new): - recon_model = helper.ReconSmallPhaseModel() - tester = helper.Tester(model=recon_model, batch_size=BATCH_SIZE, model_params_path=best_model_params_path) - for df in datafiles_new: - with np.load(df) as f: - X_test = np.array(f['reciprocal']) - positions_test = f['position'] - - realspace = np.array(f['real']) - phases = np.angle(realspace) - phase_mean = circmean(phases, low=-np.pi, high=np.pi) - Y_test = realspace * np.exp(-1j * phase_mean) - - print('Predicting for data in ', df) - - fname_prefix = df.split('/')[-1].removesuffix('.npz') - tester.setTestData(X_test) - - - phs_eval = tester.predictTestData(npz_save_path=iteration_out_path + '/preds_' + fname_prefix + '.npz') - - Y_ph_test = np.angle(Y_test) - - tester.calcErrors(Y_ph_test, npz_save_path=iteration_out_path + '/errs_' + fname_prefix + '.npz') - - n_plot = 5 - selected = np.random.randint(X_test.shape[0], size=5) - helper.plot_test_data(X_test[selected], Y_ph_test[selected], phs_eval[selected], - save_fname=iteration_out_path + '/test_imgs_' + fname_prefix + '.png', show_fig=False) - print() - - - - -def trainWithAdditionalData(datafiles_incorporated: list, datafiles_new: list, iteration_out_path: str, load_model_path: str=None): - print("Combining the training and test data for new training session.") - datafiles_train = datafiles_incorporated + datafiles_new - - X_train = [] - Y_train = [] - positions_train = [] - for df in datafiles_train: - print(df) - with np.load(df) as f: - try: - X_train.append(np.array(f['reciprocal'])) - except Exception as e: - print(e) - else: - positions_train.append(f['position']) - - realspace = np.array(f['real']) - phases = np.angle(realspace) - phase_mean = circmean(phases, low=-np.pi, high=np.pi) - Y_train.append(realspace * np.exp(-1j * phase_mean)) - - print("Shape of new training data is", np.shape(X_train)) - - shape12 = np.array(X_train).shape[-2:] - X_train = np.reshape(X_train, [-1, *shape12]) - Y_train = np.reshape(Y_train, [-1, *shape12]) - Y_ph_train = np.angle(Y_train) - - - print("After concatenating, shape of new training data is", np.shape(X_train)) - - print("Before downscaling, max of X_train is", np.max(X_train)) - X_train = np.floor(X_train / diffraction_downscaling) - print("After downscaling, max of X_train is", np.max(X_train)) - - # The actual training part - - print("Creating the training model...") - recon_model = helper.ReconSmallPhaseModel() - if load_model_path is not None: - print("Loading previous best model to initialize the training model.") - recon_model.load_state_dict(torch.load(best_model_params_path)) - - print("Initializing the training procedure...") - trainer = helper.Trainer(recon_model, batch_size=BATCH_SIZE, output_path=iteration_out_path, output_suffix='') - print("Setting training data...") - trainer.setTrainingData(X_train, Y_ph_train) - print("Setting optimization parameters...") - trainer.setOptimizationParams() - trainer.initModel() - - train_time = trainer.run(EPOCHS) - - #trainer.plotLearningRate(save_fname=iteration_out_path + '/learning_rate.png', show_fig=False) - #helper.plot_metrics(trainer.metrics, save_fname=iteration_out_path + '/metrics.png', show_fig=False) - - return datafiles_train, train_time - -def updateIncorporatedDataList(datafiles_train): - with open(f'{overall_path}/data_incorporated_list.txt', 'w') as f: - for df in datafiles_train: - print(df, file=f) - - -def runDefaultTests(iteration_out_path): - recon_model = helper.ReconSmallPhaseModel() - tester = helper.Tester(model=recon_model, batch_size=BATCH_SIZE, model_params_path=iteration_out_path + '/best_model.pth') - default_mean_pred_errors = [] - for df in default_test_datasets: - with np.load(df) as f: - X_test = np.array(f['reciprocal']) - positions_test = f['position'] - - realspace = np.array(f['real']) - phases = np.angle(realspace) - phase_mean = circmean(phases, low=-np.pi, high=np.pi) - Y_test = realspace * np.exp(-1j * phase_mean) - - print("Before downscaling, max of X_test is", np.max(X_test)) - X_test = np.floor(X_test / diffraction_downscaling) - print("After downscaling, max of X_test is", np.max(X_test)) - - print('Predicting for data in ', df) - - fname_prefix = df.split('/')[-1].removesuffix('.npz') - tester.setTestData(X_test) - - - phs_eval = tester.predictTestData(npz_save_path=iteration_out_path + '/preds_' + fname_prefix + '.npz') - - Y_ph_test = np.angle(Y_test) - - errors = tester.calcErrors(Y_ph_test, npz_save_path=iteration_out_path + '/errs_' + fname_prefix + '.npz') - - selected = default_test_slice_nums - #helper.plot_test_data(X_test[selected], Y_ph_test[selected], phs_eval[selected], - # save_fname=iteration_out_path + '/test_imgs_' + fname_prefix + '.png', show_fig=False) - print() - - default_mean_pred_errors.append(np.mean(errors, axis=0)) - return default_mean_pred_errors - - - - - -if __name__ == '__main__': - - iteration_out_path = setupOutputDirectoryAndLogging() - printBasicParams() - - datafiles_incorporated, datafiles_new = searchAndCheckDatafiles() - - print("Posting message to slack") - post_message(f"PtychoNN: Starting new training run for diffraction data ") - - epics.caput("26idbPBS:sft01:ph01:ao14.VAL", 0) - if os.path.isfile(best_model_params_path): - #default_mean_pred_errors = testPredictionQualityForNewDatafiles(datafiles_new) - datafiles_train, train_time = trainWithAdditionalData(datafiles_incorporated, datafiles_new, iteration_out_path, load_model_path=best_model_params_path) - else: - datafiles_train, train_time= trainWithAdditionalData(datafiles_incorporated, datafiles_new, iteration_out_path, load_model_path=None) - - runDefaultTests(iteration_out_path) - - updateIncorporatedDataList(datafiles_train) - - #print("Total Training Time %.3f hr" %(int(train_time)/(60*60))) - print("Copying new best model to ", overall_path) - shutil.copy(iteration_out_path + '/best_model.pth', best_model_params_path) - - print("Updating agx") - epics.caput("26idbPBS:sft01:ph01:ao14.VAL", 1) - agx_update(best_model_params_path) - - epics.caput("26idbPBS:sft01:ph01:ao14.VAL", 2) - print("Posting message to slack") - post_message("PtychoNN: Pushed new model to AGX.") - post_message(f"PtychoNN: Completed training run ") - \ No newline at end of file diff --git a/ptychonn/_infer/__main__.py b/ptychonn/_infer/__main__.py index eb51f2a..c9200e6 100644 --- a/ptychonn/_infer/__main__.py +++ b/ptychonn/_infer/__main__.py @@ -1,10 +1,9 @@ -import importlib.resources import pathlib import typing +import glob from torch.utils.data import TensorDataset, DataLoader import click -import h5py import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt @@ -43,8 +42,8 @@ def stitch_from_inference( stitched : (COMBINED_WIDTH, COMBINED_HEIGHT) np.array The stitched together image. ''' - pos_x = scan[..., 0] - pos_y = scan[..., 1] + pos_x = scan[..., 1] + pos_y = scan[..., 0] # The global axes of the stitched image in meters x = np.arange(pos_x.min(), @@ -85,14 +84,14 @@ def stitch_from_inference( @click.command(name='infer') @click.argument( - 'data_path', + 'data_dir', type=click.Path( exists=True, path_type=pathlib.Path, ), ) @click.argument( - 'scan_path', + 'params_path', type=click.Path( exists=True, path_type=pathlib.Path, @@ -106,8 +105,8 @@ def stitch_from_inference( ), ) def infer_cli( - data_path: pathlib.Path, - scan_path: pathlib.Path, + data_dir: pathlib.Path, + params_path: pathlib.Path, out_dir: pathlib.Path, ): '''Infer a reconstructed image from diffraction patterns at DATA_PATH and @@ -115,52 +114,71 @@ def infer_cli( OUT_DIR. ''' - inferences_out_file = out_dir / 'inferences_506.npz' - click.echo(f'Does data path exist? {data_path.exists()}') + dataslist = [] + scanlist = [] - with h5py.File(data_path) as f: - data = f['entry/data/data'][()] + for name in glob.glob(str(data_dir / '*.npz')): + print(name) + with np.load(name) as f: + dataslist.append(f['reciprocal']) + scanlist.append(f['scan']) + + data = np.concatenate(dataslist, axis=0) + scan = np.concatenate(scanlist, axis=0) inferences = infer( data=data, - inferences_out_file=inferences_out_file, + model_params_path=params_path, ) - ## parameters required for stitching individual inferences - spiral_step = 0.05 - step = spiral_step * -1e-6 - spiral_traj = np.load(scan_path) - scan = np.stack((spiral_traj['x'], spiral_traj['y']), axis=-1) * step - stitched = stitch_from_inference( - inferences, + pstitched = stitch_from_inference( + inferences[:, 0], + scan, + stitched_pixel_width=1, + inference_pixel_width=1, + ) + astitched = stitch_from_inference( + inferences[:, 1], scan, - stitched_pixel_width=10e-9, - inference_pixel_width=10e-9, + stitched_pixel_width=1, + inference_pixel_width=1, ) # Plotting some summary images plt.figure(1, figsize=[8.5, 7]) - plt.pcolormesh(stitched) + plt.imshow(pstitched) + plt.colorbar() + plt.tight_layout() + plt.title('stitched_phases') + plt.savefig(out_dir / 'pstitched.png', bbox_inches='tight') + + plt.figure(2, figsize=[8.5, 7]) + plt.imshow(astitched) plt.colorbar() plt.tight_layout() - plt.title('stitched_inferences') - plt.savefig(out_dir / 'stitched_506.png', bbox_inches='tight') + plt.title('stitched_amplitudes') + plt.savefig(out_dir / 'astitched.png', bbox_inches='tight') test_inferences = [0, 1, 2, 3] fig, axs = plt.subplots(1, 4, figsize=[13, 3]) for ix, inf in enumerate(test_inferences): - plt.subplot(1, 4, ix + 1) - plt.pcolormesh(inferences[inf]) + plt.subplot(2, 4, ix + 1) + plt.pcolormesh(inferences[inf, 0]) + plt.colorbar() + plt.title('Inference at position {0}'.format(inf)) + plt.subplot(2, 4, 4 + ix + 1) + plt.pcolormesh(inferences[inf, 1]) plt.colorbar() plt.title('Inference at position {0}'.format(inf)) plt.tight_layout() - plt.savefig(out_dir / 'inferences_0_to_4_scan506.png', bbox_inches='tight') + plt.savefig(out_dir / 'inferences.png', bbox_inches='tight') return 0 def infer( data: npt.NDArray, + model_params_path: pathlib.Path, *, inferences_out_file: typing.Optional[pathlib.Path] = None, ) -> npt.NDArray: @@ -181,10 +199,13 @@ def infer( Returns ------- - inferences : (POSITION, WIDTH, HEIGHT) + inferences : (POSITION, 2, WIDTH, HEIGHT) The reconstructed patches inferred by the model. ''' - tester = Tester() + tester = Tester( + model=ptychonn.model.ReconSmallModel(), + model_params_path=model_params_path, + ) tester.setTestData( data, batch_size=max(torch.cuda.device_count(), 1) * 64, @@ -193,56 +214,54 @@ def infer( class Tester(): - ''' - ''' def __init__( self, *, - model: typing.Optional[torch.nn.Module] = None, - model_params_path: typing.Optional[pathlib.Path] = None, + model: torch.nn.Module, + model_params_path: pathlib.Path, ): - self.model = ptychonn.model.ReconSmallPhaseModel( - ) if model is None else model - - if model_params_path is None: - with importlib.resources.path( - 'ptychonn._infer', - 'weights.pth', - ) as model_params_path: - self.model.load_state_dict(torch.load(model_params_path)) - else: - self.model.load_state_dict(torch.load(model_params_path)) - self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") print(f"Let's use {torch.cuda.device_count()} GPUs!") - if torch.cuda.device_count() > 1: - self.model = torch.nn.DataParallel(self.model) + self.model = model + + params = torch.load( + model_params_path, + map_location=self.device, + ) + self.model.load_state_dict(params) + + self.model = torch.nn.DataParallel(self.model) + + self.model.to(self.device) - self.model = self.model.to(self.device) + self.model.eval() def setTestData(self, X_test: np.ndarray, batch_size: int): - self.X_test = torch.tensor(X_test[:, None, ...].astype('float32')) + self.X_test = torch.tensor(X_test[:, None, ...], dtype=torch.float32) self.test_data = TensorDataset(self.X_test) - self.testloader = DataLoader(self.test_data, - batch_size=batch_size, - shuffle=False, - num_workers=4) + self.testloader = DataLoader( + self.test_data, + batch_size=batch_size, + shuffle=False, + ) def predictTestData(self, npz_save_path: str = None): - self.model.eval() + phs_eval = [] - for i, ft_images in enumerate(self.testloader): - ft_images = ft_images[0].to(self.device) - ph_eval = self.model(ft_images) - for j in range(ft_images.shape[0]): - phs_eval.append(ph_eval[j].detach().to("cpu").numpy()) - self.phs_eval = np.array(phs_eval).squeeze().astype('float32') + with torch.inference_mode(): + for (ft_images, ) in self.testloader: + ph_eval = self.model(ft_images.to(self.device)) + phs_eval.append(ph_eval.detach().cpu().numpy()) + + self.phs_eval = np.concatenate(phs_eval, axis=0) + if npz_save_path is not None: np.savez_compressed(npz_save_path, ph=self.phs_eval) print(f'Finished the inference stage and saved at {npz_save_path}') + return self.phs_eval def calcErrors(self, phs_true: np.ndarray, npz_save_path: str = None): diff --git a/ptychonn/_infer/weights.pth b/ptychonn/_infer/weights.pth deleted file mode 100644 index 69847e6..0000000 Binary files a/ptychonn/_infer/weights.pth and /dev/null differ diff --git a/ptychonn/_train/__main__.py b/ptychonn/_train/__main__.py index 0a56583..7138a49 100644 --- a/ptychonn/_train/__main__.py +++ b/ptychonn/_train/__main__.py @@ -36,9 +36,15 @@ path_type=pathlib.Path, ), ) +@click.option( + '--epochs', + type=click.INT, + default=100, +) def train_cli( data_dir: pathlib.Path, out_dir: pathlib.Path, + epochs: int, ): """Train a model from diffraction patterns and reconstructed patches. @@ -67,10 +73,12 @@ def train_cli( # centering of the phase in the center 3rd of the reconstructed patches. # The diffraction patterns are converted to float32 and otherwise # unaltered. - patches = np.angle(patches).astype('float32') - patches -= np.mean( - patches[..., patches.shape[-2] // 3:-patches.shape[-2] // 3, - patches.shape[-1] // 3:-patches.shape[-1] // 3], ) + phase = np.angle(patches).astype('float32') + phase -= np.mean( + phase[..., phase.shape[-2] // 3:-phase.shape[-2] // 3, + phase.shape[-1] // 3:-phase.shape[-1] // 3], ) + amplitude = np.abs(patches).astype('float32') + patches = np.stack((phase, amplitude), axis=1) os.makedirs(out_dir, exist_ok=True) @@ -78,8 +86,8 @@ def train_cli( X_train=data, Y_train=patches, out_dir=out_dir, - epochs=50, - batch_size=64, + epochs=epochs, + batch_size=32, ) @@ -89,7 +97,7 @@ def train( out_dir: pathlib.Path | None, load_model_path: pathlib.Path | None = None, epochs: int = 1, - batch_size: int = 64, + batch_size: int = 32, ): """Train a PtychoNN model. @@ -97,7 +105,7 @@ def train( ---------- X_train (N, WIDTH, HEIGHT) The diffraction patterns. - Y_train (N, WIDTH, HEIGHT) + Y_train (N, 2, WIDTH, HEIGHT) The corresponding reconstructed patches for the diffraction patterns. out_dir A folder where all the training artifacts are saved. @@ -111,7 +119,7 @@ def train( logger.info("Creating the training model...") trainer = Trainer( - model=ptychonn.model.ReconSmallPhaseModel(), + model=ptychonn.model.ReconSmallModel(), batch_size=batch_size * torch.cuda.device_count(), output_path=out_dir, ) @@ -130,12 +138,12 @@ def train( if out_dir is not None: trainer.plotLearningRate( - save_fname=out_dir / 'learning_rate.svg', + save_fname=out_dir / 'learning_rate.png', show_fig=False, ) ptychonn.plot.plot_metrics( trainer.metrics, - save_fname=out_dir / 'metrics.svg', + save_fname=out_dir / 'metrics.png', show_fig=False, ) @@ -168,7 +176,7 @@ class Trainer(): def __init__( self, - model: ptychonn.model.ReconSmallPhaseModel, + model: ptychonn.model.ReconSmallModel, batch_size: int, output_path: pathlib.Path | None = None, output_suffix: str = '', @@ -186,6 +194,22 @@ def setTrainingData( Y_ph_train_full: np.ndarray, valid_data_ratio: float = 0.1, ): + """ + + Parameters + ---------- + X_train_full : (N, H, W) + The measured intensities at the detector + Y_ph_train_full : (N, C, H, W) + The phase and amplitude patches from the reconstructed object. + Phase in the zeroth channel and amplitude (optionally) in the first + channel + + """ + if (Y_ph_train_full.ndim != 4): + msg = ("Training data example patches must have a channel " + "dimension! i.e. the shape should be (N, C, H, W)") + raise ValueError(msg) logger.info("Setting training data...") self.H, self.W = X_train_full.shape[-2:] @@ -195,7 +219,7 @@ def setTrainingData( dtype=torch.float32, ) self.Y_ph_train_full = torch.tensor( - Y_ph_train_full[:, None, ...], + Y_ph_train_full, dtype=torch.float32, ) self.ntrain_full = self.X_train_full.shape[0] @@ -269,20 +293,28 @@ def setOptimizationParams( mode='triangular2', ) - def initModel(self, model_params_path: pathlib.Path | None = None): + def initModel( + self, + model_params_path: pathlib.Path | None = None, + ): """Load parameters from the disk then model to the GPU(s).""" - self.model_params_path = model_params_path - if model_params_path is not None: - self.model.load_state_dict(torch.load(self.model_params_path)) - torchinfo.summary(self.model, (1, 1, self.H, self.W), device="cpu") - self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu") print(f"Let's use {torch.cuda.device_count()} GPUs!") - if torch.cuda.device_count() > 1: - self.model = torch.nn.DataParallel(self.model) + torchinfo.summary(self.model, (1, 1, self.H, self.W), device="cpu") + + self.model_params_path = model_params_path + + if model_params_path is not None: + self.model.load_state_dict( + torch.load( + self.model_params_path, + map_location=self.device, + )) + + self.model = torch.nn.DataParallel(self.model) self.model = self.model.to(self.device) @@ -372,14 +404,13 @@ def validate(self, epoch: int): self.output_suffix, ) - import tifffile + import matplotlib.pyplot as plt os.makedirs(self.output_path / 'reference', exist_ok=True) os.makedirs(self.output_path / 'inference', exist_ok=True) - tifffile.imwrite( - self.output_path / f'reference/{epoch:05d}.tiff', - phs[0, 0].detach().cpu().numpy().astype(np.float32)) - tifffile.imwrite( - self.output_path / f'inference/{epoch:05d}.tiff', + plt.imsave(self.output_path / f'reference/{epoch:05d}.png', + phs[0, 0].detach().cpu().numpy().astype(np.float32)) + plt.imsave( + self.output_path / f'inference/{epoch:05d}.png', pred_phs[0, 0].detach().cpu().numpy().astype(np.float32)) @staticmethod @@ -400,7 +431,7 @@ def customLoss( @staticmethod def updateSavedModel( - model: ptychonn.model.ReconSmallPhaseModel, + model: ptychonn.model.ReconSmallModel, directory: pathlib.Path, suffix: str = '', ): @@ -453,7 +484,8 @@ def run(self, epochs: int, output_frequency: int = 1): self.model.eval() #Validation loop - self.validate(epoch) + with torch.inference_mode(): + self.validate(epoch) if epoch % output_frequency == 0: logger.info( diff --git a/ptychonn/model.py b/ptychonn/model.py index e5a808f..0554108 100644 --- a/ptychonn/model.py +++ b/ptychonn/model.py @@ -1,40 +1,68 @@ +"""Define PtychoNN Pytorch models.""" + import numpy as np import torch import torch.nn as nn -class ReconSmallPhaseModel(nn.Module): +class ReconSmallModel(nn.Module): + """A small PychoNN model. + + Parameters + ---------- + nconv : + The number of convolution kernels at the smallest level + use_batch_norm : + Whether to use batch normalization after convolution layers + enable_amplitude : + Whether the amplitude branch is included in the model - def __init__(self, nconv: int = 16, use_batch_norm=False): - super(ReconSmallPhaseModel, self).__init__() + Shapes + ------ + input : (N, 1, H, W) + The measured intensity of the diffraction patterns + output : (N, C, H, W) + The phase (and amplitude if C is 2) in the patch of the object + """ + + def __init__( + self, + nconv: int = 16, + use_batch_norm: bool = True, + enable_amplitude: bool = True, + ): + super().__init__() self.nconv = nconv self.use_batch_norm = use_batch_norm + self.enable_amplitude = enable_amplitude - self.encoder = nn.Sequential( # Appears sequential has similar functionality as TF avoiding need for separate model definition and activ + # Appears sequential has similar functionality as TF avoiding need for + # separate model definition and activ + self.encoder = nn.Sequential( *self.down_block(1, self.nconv), *self.down_block(self.nconv, self.nconv * 2), *self.down_block(self.nconv * 2, self.nconv * 4), *self.down_block(self.nconv * 4, self.nconv * 8), ) - # amplitude model - #self.decoder1 = nn.Sequential( - # *self.up_block(self.nconv * 8, self.nconv * 8), - # *self.up_block(self.nconv * 8, self.nconv * 4), - # *self.up_block(self.nconv * 4, self.nconv * 2), - # *self.up_block(self.nconv * 2, self.nconv * 1), - # nn.Conv2d(self.nconv * 1, 1, 3, stride=1, padding=(1,1)), - #) - - # phase model - self.decoder2 = nn.Sequential( - *self.up_block(self.nconv * 8, self.nconv * 8), - *self.up_block(self.nconv * 8, self.nconv * 4), - *self.up_block(self.nconv * 4, self.nconv * 2), - *self.up_block(self.nconv * 2, self.nconv * 1), - nn.Conv2d(self.nconv * 1, 1, 3, stride=1, padding=(1, 1)), - *((nn.BatchNorm2d(1),) if self.use_batch_norm else ()), - nn.Tanh(), + # Double the number of channels when doing both phase and amplitude, + # but keep them separate with grouping + c = 2 if self.enable_amplitude else 1 + self.decoder = nn.Sequential( + *self.up_block(self.nconv * 8 * 1, self.nconv * 8 * c, groups=c), + *self.up_block(self.nconv * 8 * c, self.nconv * 4 * c, groups=c), + *self.up_block(self.nconv * 4 * c, self.nconv * 2 * c, groups=c), + *self.up_block(self.nconv * 2 * c, self.nconv * 1 * c, groups=c), + nn.Conv2d( + in_channels=self.nconv * 1 * c, + out_channels=c, + kernel_size=3, + stride=1, + padding=(1, 1), + bias=(not self.use_batch_norm), + groups=c, + ), + *((nn.BatchNorm2d(c), ) if self.use_batch_norm else ()), ) def down_block(self, filters_in, filters_out): @@ -45,39 +73,57 @@ def down_block(self, filters_in, filters_out): kernel_size=3, stride=1, padding=(1, 1), - ), - *((nn.BatchNorm2d(filters_out),) if self.use_batch_norm else ()), + bias=(not self.use_batch_norm), + ), *((nn.BatchNorm2d(filters_out), ) if self.use_batch_norm else + ()), nn.ReLU(), - nn.Conv2d(filters_out, filters_out, 3, stride=1, padding=(1, 1)), - *((nn.BatchNorm2d(filters_out),) if self.use_batch_norm else ()), + nn.Conv2d( + in_channels=filters_out, + out_channels=filters_out, + kernel_size=3, + stride=1, + padding=(1, 1), + bias=(not self.use_batch_norm), + ), *((nn.BatchNorm2d(filters_out), ) if self.use_batch_norm else + ()), nn.ReLU(), nn.MaxPool2d((2, 2)) ] - def up_block(self, filters_in, filters_out): + def up_block(self, filters_in: int, filters_out: int, groups: int): return [ nn.Conv2d( - filters_in, - filters_out, - 3, + in_channels=filters_in, + out_channels=filters_out, + kernel_size=3, stride=1, padding=(1, 1), - ), - *((nn.BatchNorm2d(filters_out),) if self.use_batch_norm else ()), + bias=(not self.use_batch_norm), + groups=groups, + ), *((nn.BatchNorm2d(filters_out), ) if self.use_batch_norm else + ()), nn.ReLU(), - nn.Conv2d(filters_out, filters_out, 3, stride=1, padding=(1, 1)), - *((nn.BatchNorm2d(filters_out),) if self.use_batch_norm else ()), + nn.Conv2d( + in_channels=filters_out, + out_channels=filters_out, + kernel_size=3, + stride=1, + padding=(1, 1), + bias=(not self.use_batch_norm), + groups=groups, + ), *((nn.BatchNorm2d(filters_out), ) if self.use_batch_norm else + ()), nn.ReLU(), nn.Upsample(scale_factor=2, mode='bilinear') ] def forward(self, x): with torch.cuda.amp.autocast(): - x1 = self.encoder(x) - #amplitude = self.decoder1(x1) - phase = self.decoder2(x1) - - #Restore -pi to pi range - phase = phase * np.pi #Using tanh activation (-1 to 1) for phase so multiply by pi - - return phase + output = self.decoder(self.encoder(x)) + # Restore -pi to pi range + # Using tanh activation (-1 to 1) for phase so multiply by pi + output[..., 0, :, :] = torch.tanh(output[..., 0, :, :]) * np.pi + # Restrict amplitude to (0, 1) range with sigmoid + if self.enable_amplitude: + output[..., 1, :, :] = torch.sigmoid(output[..., 1, :, :]) + return output diff --git a/requirements-dev b/requirements-dev index ed77e7c..4a8ce60 100644 --- a/requirements-dev +++ b/requirements-dev @@ -1,12 +1,9 @@ click -h5py importlib-metadata -importlib-resources matplotlib numpy ~=1.21 -pytorch ~=1.12 +pytorch >=1.12,<2.1 scikit-image scipy -tifffile tqdm torchinfo diff --git a/setup.cfg b/setup.cfg index 2678280..6183aee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,15 +6,12 @@ include_package_data = True packages = find: install_requires = click - h5py importlib-metadata; python_version < "3.8" - importlib-resources; python_version < "3.7" matplotlib numpy ~=1.21 - torch ~=1.12 + torch >=1.12,<2.1 scikit-image scipy - tifffile tqdm torchinfo diff --git a/tests/construct_test_data.py b/tests/construct_test_data.py index c8073cd..53dfa88 100644 --- a/tests/construct_test_data.py +++ b/tests/construct_test_data.py @@ -1,18 +1,18 @@ import libimage import numpy as np -import tifffile +import matplotlib.pyplot as plt import tike.ptycho import tike.ptycho.learn -def test_construct_simulated_training_set(W=2048, N=1024, S=128): +def test_construct_simulated_training_set(W=2048, N=2048, S=128): phase = libimage.load('coins', W) - 0.5 amplitude = 1 - libimage.load('earring', W) fov = (amplitude * np.exp(1j * phase * np.pi)).astype('complex64') - tifffile.imwrite('phase.tiff', np.angle(fov)) - tifffile.imwrite('amplitude.tiff', np.abs(fov)) - assert np.abs(fov).max() <= 1.0 - assert np.abs(fov).min() >= 0.0 + plt.imsave('phase.png', np.angle(fov)) + plt.imsave('amplitude.png', np.abs(fov)) + assert np.abs(fov).max() <= 1.0 + 1e-4, np.abs(fov).max() + assert np.abs(fov).min() >= 0.0, np.abs(fov).min() scan = np.random.uniform(1, W - S - 1, (N, 2)) @@ -36,7 +36,7 @@ def test_construct_simulated_training_set(W=2048, N=1024, S=128): axes=(-2, -1), ).astype('float32') print(diffraction.dtype, diffraction.shape) - tifffile.imwrite('diffraction.tiff', diffraction[N // 2]) + plt.imsave('diffraction.png', np.log10(diffraction[N // 2])) print(f'Training params = {np.prod(diffraction.shape)}') @@ -44,6 +44,7 @@ def test_construct_simulated_training_set(W=2048, N=1024, S=128): 'simulated_data.npz', reciprocal=diffraction, real=patches, + scan=scan, ) if __name__ == '__main__': diff --git a/tests/test_ptychonn.py b/tests/test_ptychonn.py index 79a3b9c..4e04b17 100644 --- a/tests/test_ptychonn.py +++ b/tests/test_ptychonn.py @@ -3,25 +3,36 @@ import pathlib from click.testing import CliRunner -import torch +import ptychonn import ptychonn.model from ptychonn import __main__ as cli _test_dir = pathlib.Path(__file__).resolve().parents[0] + def test_command_line_interface(): """Test the CLI.""" runner = CliRunner() result = runner.invoke(cli.main) assert result.exit_code == 0 - assert 'ptychonn.cli.main' in result.output help_result = runner.invoke(cli.main, ['--help']) assert help_result.exit_code == 0 assert 'Show this message and exit.' in help_result.output def test_load_weights(): - model = ptychonn.model.ReconSmallPhaseModel(16, False) - model.load_state_dict(torch.load(_test_dir / 'weights.pth')) - print(model) + model = ptychonn.ReconSmallModel() + trainer = ptychonn.Trainer( + model=model, + batch_size=32, + ) + trainer.updateSavedModel( + trainer.model, + _test_dir, + suffix='', + ) + tester = ptychonn.Tester( + model=model, + model_params_path=_test_dir / 'best_model.pth', + ) diff --git a/tests/weights.pth b/tests/weights.pth deleted file mode 100644 index 69847e6..0000000 Binary files a/tests/weights.pth and /dev/null differ