diff --git a/ptychonn/position/core.py b/ptychonn/position/core.py index a8972f1..62a0a59 100644 --- a/ptychonn/position/core.py +++ b/ptychonn/position/core.py @@ -98,8 +98,8 @@ def estimate(self): class PtychoNNProbePositionCorrector: - def __init__(self, config_dict: InferenceConfig): - self.config_dict = config_dict + def __init__(self, configs: InferenceConfig): + self.configs = configs self.ptycho_reconstructor = VirtualReconstructor(InferenceConfig()) self.dp_data_fhdl = None self.orig_probe_positions = None @@ -114,14 +114,14 @@ def __init__(self, config_dict: InferenceConfig): self.unregistered_indices = [] def build(self): - if self.config_dict.random_seed is not None: + if self.configs.random_seed is not None: logging.debug( - "Random seed is set to {}.".format(self.config_dict.random_seed) + "Random seed is set to {}.".format(self.configs.random_seed) ) - np.random.seed(self.config_dict.random_seed) + np.random.seed(self.configs.random_seed) assert isinstance(self.ptycho_reconstructor, VirtualReconstructor) - recons = tifffile.imread(self.config_dict.reconstruction_image_path) + recons = tifffile.imread(self.configs.reconstruction_image_path) self.ptycho_reconstructor.set_object_image_array(recons) self.ptycho_reconstructor.build() @@ -135,25 +135,25 @@ def build(self): self.orig_probe_positions = ProbePositionList( position_list=np.zeros([self.n_dps, 2]) ) - if not self.config_dict.probe_position_list: - if self.config_dict.probe_position_data_path: + if not self.configs.probe_position_list: + if self.configs.probe_position_data_path: self.orig_probe_positions = ProbePositionList( - file_path=self.config_dict.probe_position_data_path, - unit=self.config_dict.probe_position_data_unit, - psize_nm=self.config_dict.pixel_size_nm, + file_path=self.configs.probe_position_data_path, + unit=self.configs.probe_position_data_unit, + psize_nm=self.configs.pixel_size_nm, ) else: - self.orig_probe_positions = self.config_dict.probe_position_list + self.orig_probe_positions = self.configs.probe_position_list self.new_probe_positions = self.orig_probe_positions.copy_with_zeros() self.registrator = Registrator( - self.config_dict.registration_params, - random_seed=self.config_dict.random_seed, + self.configs.registration_params, + random_seed=self.configs.random_seed, ) if ( - self.config_dict.rectangular_grid - and self.config_dict.registration_params.use_baseline_offsets_for_points_on_same_row + self.configs.rectangular_grid + and self.configs.registration_params.use_baseline_offsets_for_points_on_same_row ): self.build_row_index_list() @@ -174,14 +174,14 @@ def build_row_index_list(self): self.row_index_list = np.array(row_inds) def run(self): - if self.config_dict.method == "serial": + if self.configs.method == "serial": self.run_probe_position_correction_serial() - elif self.config_dict.method == "collective": + elif self.configs.method == "collective": self.run_probe_position_correction_collective() else: raise ValueError( "Correction method {} is not supported. ".format( - self.config_dict.method + self.configs.method ) ) @@ -198,7 +198,7 @@ def reconstruct_dp(self, ind=None, dp=None): obj_amp, obj_ph = self.ptycho_reconstructor.batch_infer( dp[np.newaxis, :, :] ) - # if self.config_dict.debug: + # if self.configs.debug: # fig, ax = plt.subplots(1, 3) # ax[0].imshow(dp) # ax[0].set_title('DP') @@ -211,9 +211,9 @@ def reconstruct_dp(self, ind=None, dp=None): return obj_amp, obj_ph def crop_center(self, image_list): - if self.config_dict.central_crop is None: + if self.configs.central_crop is None: return image_list - crop_shape = self.config_dict.central_crop + crop_shape = self.configs.central_crop for i in range(len(image_list)): orig_shape = image_list[i].shape[1:] start_point = [(orig_shape[j] - crop_shape[j]) // 2 for j in range(2)] @@ -230,15 +230,15 @@ def run_probe_position_correction_serial(self): Run serial mode probe position correction. """ offset_tracker = OffsetEstimator( - beta=self.config_dict.offset_estimator_beta, - order=self.config_dict.offset_estimator_order, + beta=self.configs.offset_estimator_beta, + order=self.configs.offset_estimator_order, ) previous_obj = self.reconstruct_dp(0)[1][0] for ind in trange(1, self.n_dps): current_obj = self.reconstruct_dp(ind)[1][0] offset = self.registrator.run(previous_obj, current_obj) if ( - self.config_dict.registration_params.use_baseline_offsets_for_uncertain_pairs + self.configs.registration_params.use_baseline_offsets_for_uncertain_pairs and self.registrator.get_status() == self.registrator.get_status_code("empty") ): @@ -251,7 +251,7 @@ def run_probe_position_correction_serial(self): ): offset = offset_tracker.estimate() self.count_bad_offset += 1 - if self.config_dict.debug: + if self.configs.debug: fig, ax = plt.subplots(1, 2) ax[0].imshow(previous_obj) ax[1].imshow(current_obj) @@ -270,7 +270,7 @@ def run_probe_position_correction_collective(self): self.build_linear_system_for_collective_correction() self.solve_linear_system( mode="residue", - smooth_constraint_weight=self.config_dict.smooth_constraint_weight, + smooth_constraint_weight=self.configs.smooth_constraint_weight, ) self.postprocess() @@ -284,7 +284,7 @@ def get_neightbor_inds(self, i_dp, knn_inds): else: this_neighbors_inds = [i_dp - 1, i_dp + 1] i_knn = 0 - while len(this_neighbors_inds) < self.config_dict.num_neighbors_collective: + while len(this_neighbors_inds) < self.configs.num_neighbors_collective: if knn_inds[i_knn] not in this_neighbors_inds: this_neighbors_inds.append(knn_inds[i_knn]) i_knn += 1 @@ -297,14 +297,14 @@ def build_linear_system_for_collective_correction(self): # registered with O(1) time. registered_pair_hash_table = collections.defaultdict(lambda: None) nn_engine = sklearn.neighbors.NearestNeighbors( - n_neighbors=self.config_dict.num_neighbors_collective + 1 + n_neighbors=self.configs.num_neighbors_collective + 1 ) nn_engine.fit(self.orig_probe_positions.array) nn_dists, nn_inds = nn_engine.kneighbors(self.orig_probe_positions.array) for i_dp, this_orig_pos in enumerate(tqdm(self.orig_probe_positions.array)): - if self.config_dict.registration_params.registration_tol_schedule: + if self.configs.registration_params.registration_tol_schedule: self.registrator.update_tol_by_tol_schedule( - i_dp, self.config_dict.registration_params.registration_tol_schedule + i_dp, self.configs.registration_params.registration_tol_schedule ) this_knn_inds = nn_inds[i_dp, 1:] this_neighbors_inds = self.get_neightbor_inds(i_dp, this_knn_inds) @@ -329,7 +329,7 @@ def build_linear_system_for_collective_correction(self): )[1][0] if ( - self.config_dict.registration_params.use_baseline_offsets_for_points_on_same_row + self.configs.registration_params.use_baseline_offsets_for_points_on_same_row and self.row_index_list is not None and self.row_index_list[i_dp] == self.row_index_list[ind_neighbor] ): @@ -349,7 +349,7 @@ def build_linear_system_for_collective_correction(self): offset = self.registrator.run(neighbor_obj, current_obj) # print('{} - {}: {}'.format(i_dp, ind_neighbor, offset)) if ( - self.config_dict.debug + self.configs.debug and self.registrator.get_status() != self.registrator.get_status_code("empty") ): @@ -363,7 +363,7 @@ def build_linear_system_for_collective_correction(self): plt.tight_layout() print("Offset: {}".format(offset)) if ( - self.config_dict.registration_params.use_baseline_offsets_for_uncertain_pairs + self.configs.registration_params.use_baseline_offsets_for_uncertain_pairs and self.registrator.get_status() == self.registrator.get_status_code("empty") ): @@ -382,7 +382,7 @@ def build_linear_system_for_collective_correction(self): self.a_mat.append(self._generate_amat_row(i_dp, ind_neighbor)) self.a_mat = np.stack(self.a_mat) self.b_vec = np.stack(self.b_vec) - if self.config_dict.registration_params.use_baseline_offsets_for_unregistered_points: + if self.configs.registration_params.use_baseline_offsets_for_unregistered_points: self.fill_gaps_in_linear_system() def fill_gaps_in_linear_system(self): @@ -437,7 +437,7 @@ def solve_linear_system(self, mode="residue", smooth_constraint_weight=1e-3): self.new_probe_positions.array = np.linalg.pinv(a_mat) @ b_vec def postprocess(self): - if self.config_dict.registration_params.use_baseline_offsets_for_unregistered_points: + if self.configs.registration_params.use_baseline_offsets_for_unregistered_points: for ind in self.unregistered_indices: if ind > 0: baseline_offset = ( @@ -487,12 +487,12 @@ def _update_probe_position_list(self, ind, offset): class ProbePositionCorrectorChain: - def __init__(self, config_dict: InferenceConfig): - self.config_dict = config_dict + def __init__(self, configs: InferenceConfig): + self.configs = configs self.corrector_list = [] self.multiiter_keys = [] self.n_iters = 1 - self.baseline_pos_list = config_dict.baseline_position_list + self.baseline_pos_list = configs.baseline_position_list self.collective_mode_offset_tol = 150 self.verbose = True self.redone_with_baseline = False @@ -513,10 +513,10 @@ def find_multiiter_keys(self, config_obj): def build_multiiter_entries(self): self.has_multiiter_key = False - self.find_multiiter_keys(self.config_dict) + self.find_multiiter_keys(self.configs) if not self.has_multiiter_key: raise ValueError( - "With ProbePositionCorrectorChain, there should be at least one entry in the config dict " + "With ProbePositionCorrectorChain, there should be at least one entry in configs.__dict__ " 'that ends with "_multiiter" and is a list whose length equals to the desired number of ' "iterations. " ) @@ -527,15 +527,15 @@ def run(self): def run_correction_iteration(self, iter): logging.debug("Now running iteration {}.".format(iter)) - self.update_config_dict(iter) + self.update_configs(iter) if self.verbose: - print(self.config_dict) - corrector = PtychoNNProbePositionCorrector(config_dict=self.config_dict) + print(self.configs) + corrector = PtychoNNProbePositionCorrector(configs=self.configs) corrector.build() if self.verbose: corrector.orig_probe_positions.plot() corrector.run() - if self.config_dict.method == "collective" and ( + if self.configs.method == "collective" and ( not self.is_collective_result_good(corrector) ): # Redo iteration using baseline as initialization if result is bad @@ -544,8 +544,8 @@ def run_correction_iteration(self, iter): "redo this iteration with baseline positions as initialization..." ) if self.baseline_pos_list: - self.config_dict.probe_position_list = self.baseline_pos_list - corrector = PtychoNNProbePositionCorrector(config_dict=self.config_dict) + self.configs.probe_position_list = self.baseline_pos_list + corrector = PtychoNNProbePositionCorrector(configs=self.configs) corrector.build() corrector.run() self.redone_with_baseline = True @@ -567,11 +567,11 @@ def get_ordinary_key_name(self, mikey): ind = mikey.find("_multiiter") return mikey[:ind] - def update_config_dict(self, iter, initialize_with_baseline=False): + def update_configs(self, iter, initialize_with_baseline=False): for mikey in self.multiiter_keys: key = self.get_ordinary_key_name(mikey) - self.config_dict.overwrite_value_to_key( - self.config_dict, key, self.config_dict.query(mikey)[iter] + self.configs.overwrite_value_to_key( + self.configs, key, self.configs.query(mikey)[iter] ) if iter > 0: last_corrector = self.corrector_list[iter - 1] @@ -584,7 +584,7 @@ def update_config_dict(self, iter, initialize_with_baseline=False): raise ValueError( "Cannot initialize with baseline positions: baseline position list is None." ) - self.config_dict.probe_position_list = probe_pos_list + self.configs.probe_position_list = probe_pos_list logging.debug( "Using result from the last iteration to initialize probe position array..." ) diff --git a/ptychonn/position/helper.py b/ptychonn/position/helper.py deleted file mode 100644 index ad64f60..0000000 --- a/ptychonn/position/helper.py +++ /dev/null @@ -1,133 +0,0 @@ -import logging -import warnings - -try: - import pycuda.driver as cuda - import tensorrt as trt -except ImportError: - warnings.warn( - "Unable to import pycuda and tensorrt. If you do not intend to use the ONNX reconstructor, ignore " - "this message. " - ) -from skimage.transform import resize -import numpy as np - - -def engine_build_from_onnx(onnx_mdl): - EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - TRT_LOGGER = trt.Logger(trt.Logger.ERROR) - builder = trt.Builder(TRT_LOGGER) - config = builder.create_builder_config() - # config.set_flag(trt.BuilderFlag.FP16) - config.set_flag(trt.BuilderFlag.TF32) - config.max_workspace_size = 1 * ( - 1 << 30 - ) # the maximum size that any layer in the network can use - - network = builder.create_network(EXPLICIT_BATCH) - parser = trt.OnnxParser(network, TRT_LOGGER) - # Load the Onnx model and parse it in order to populate the TensorRT network. - success = parser.parse_from_file(onnx_mdl) - for idx in range(parser.num_errors): - print(parser.get_error(idx)) - - if not success: - return None - - return builder.build_engine(network, config) - - -def mem_allocation(engine): - """ - Determine dimensions and create page-locked memory buffers (i.e. won't be swapped to disk) to hold host - inputs/outputs. - """ - logging.debug("Expected input node shape is {}".format(engine.get_binding_shape(0))) - in_sz = trt.volume(engine.get_binding_shape(0)) * engine.max_batch_size - logging.debug("Input size: {}".format(in_sz)) - h_input = cuda.pagelocked_empty(in_sz, dtype="float32") - - out_sz = trt.volume(engine.get_binding_shape(1)) * engine.max_batch_size - h_output = cuda.pagelocked_empty(out_sz, dtype="float32") - - # Allocate device memory for inputs and outputs. - d_input = cuda.mem_alloc(h_input.nbytes) - d_output = cuda.mem_alloc(h_output.nbytes) - - # Create a stream in which to copy inputs/outputs and run inference. - stream = cuda.Stream() - - return h_input, h_output, d_input, d_output, stream - - -def inference(context, h_input, h_output, d_input, d_output, stream): - # Transfer input data to the GPU. - cuda.memcpy_htod_async(d_input, h_input, stream) - - # Run inference. - context.execute_async_v2( - bindings=[int(d_input), int(d_output)], stream_handle=stream.handle - ) - - # Transfer predictions back from the GPU. - cuda.memcpy_dtoh_async(h_output, d_output, stream) - - # Synchronize the stream - stream.synchronize() - # Return the host - return h_output - - -def transform_data_for_ptychonn( - dp, target_shape, discard_len=None, overflow_correction=False -): - """ - Throw away 1/8 of the boundary region, and resize DPs to match label size. - - :param dp: np.ndarray. The data to be transformed. Can be either 3D [N, H, W] or 2D [H, W]. - :param target_shape: list[int]. The target shape. - :param discard_len: tuple[int]. The length to discard on each side. If None, the length is default to 1/8 of the raw - image size. If the numbers are negative, the images will be padded instead. - :param overflow_correction: bool. Whether to correct overflowing pixels, whose values wrap around to the negative - side whn the true values surpass int16 limit. - :return: np.ndarray. - """ - dp = dp.astype(float) - if overflow_correction: - dp = correct_overflow(dp) - if discard_len is None: - discard_len = [dp.shape[i] // 8 for i in (-2, -1)] - for i in (0, 1): - if discard_len[i] > 0: - slicer = [slice(None)] * (len(dp.shape) - 2) - slicer_appendix = [slice(None), slice(None)] - slicer_appendix[i] = slice(discard_len[i], -discard_len[i]) - dp = dp[tuple(slicer + slicer_appendix)] - elif discard_len[i] < 0: - pad_len = [(0, 0)] * (len(dp.shape) - 2) - pad_len_appendix = [(0, 0), (0, 0)] - pad_len_appendix[i] = (-discard_len[i], -discard_len[i]) - dp = np.pad(dp, np.array(pad_len + pad_len_appendix), mode="constant") - target_shape = list(dp.shape[:-2]) + list(target_shape) - if not (target_shape[-1] == dp.shape[-1] and target_shape[-2] == dp.shape[-2]): - dp = resize(dp, target_shape, preserve_range=True, anti_aliasing=True) - return dp - - -def crop_center(img, shape_to_keep=(64, 64)): - slicer = [slice(None)] * (len(img.shape) - 2) - for i in range(-2, 0, 1): - st = (img.shape[i] - shape_to_keep[i]) // 2 - end = st + shape_to_keep[i] - slicer.append(slice(st, end)) - img = img[tuple(slicer)] - return img - - -def correct_overflow(arr): - mask = arr < 0 - vals = arr[mask] - vals = 32768 + (vals - -32768) - arr[mask] = vals - # logging.debug('{} overflowing values corrected.'.format(np.count_nonzero(mask))) - return arr diff --git a/ptychonn/position/reconstructor.py b/ptychonn/position/reconstructor.py index 545bef6..c3a65dc 100644 --- a/ptychonn/position/reconstructor.py +++ b/ptychonn/position/reconstructor.py @@ -5,15 +5,11 @@ from ptychonn.position.io import * -class Reconstructor: - def __init__(self, config_dict: InferenceConfig): - """ - Inference engine for PtychoNN. - - :param config_dict: dict. Configuration dictionary. - """ - self.config_dict = config_dict +class VirtualReconstructor: + def __init__(self, configs: InferenceConfig): + self.configs = configs self.device = None + self.object_image_array = None def build(self): if torch.cuda.is_available(): @@ -21,15 +17,6 @@ def build(self): else: self.device = torch.device("cpu") - def batch_infer(self, x): - pass - - -class VirtualReconstructor(Reconstructor): - def __init__(self, config_dict: InferenceConfig): - super().__init__(config_dict) - self.object_image_array = None - def set_object_image_array(self, arr): self.object_image_array = arr