Skip to content

Commit

Permalink
Remove unused definitions and modules; rename config_dict to configs
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Mar 21, 2024
1 parent fdc7ae7 commit bc80a68
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 201 deletions.
102 changes: 51 additions & 51 deletions ptychonn/position/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def estimate(self):


class PtychoNNProbePositionCorrector:
def __init__(self, config_dict: InferenceConfig):
self.config_dict = config_dict
def __init__(self, configs: InferenceConfig):
self.configs = configs
self.ptycho_reconstructor = VirtualReconstructor(InferenceConfig())
self.dp_data_fhdl = None
self.orig_probe_positions = None
Expand All @@ -114,14 +114,14 @@ def __init__(self, config_dict: InferenceConfig):
self.unregistered_indices = []

def build(self):
if self.config_dict.random_seed is not None:
if self.configs.random_seed is not None:
logging.debug(
"Random seed is set to {}.".format(self.config_dict.random_seed)
"Random seed is set to {}.".format(self.configs.random_seed)
)
np.random.seed(self.config_dict.random_seed)
np.random.seed(self.configs.random_seed)

assert isinstance(self.ptycho_reconstructor, VirtualReconstructor)
recons = tifffile.imread(self.config_dict.reconstruction_image_path)
recons = tifffile.imread(self.configs.reconstruction_image_path)
self.ptycho_reconstructor.set_object_image_array(recons)
self.ptycho_reconstructor.build()

Expand All @@ -135,25 +135,25 @@ def build(self):
self.orig_probe_positions = ProbePositionList(
position_list=np.zeros([self.n_dps, 2])
)
if not self.config_dict.probe_position_list:
if self.config_dict.probe_position_data_path:
if not self.configs.probe_position_list:
if self.configs.probe_position_data_path:
self.orig_probe_positions = ProbePositionList(
file_path=self.config_dict.probe_position_data_path,
unit=self.config_dict.probe_position_data_unit,
psize_nm=self.config_dict.pixel_size_nm,
file_path=self.configs.probe_position_data_path,
unit=self.configs.probe_position_data_unit,
psize_nm=self.configs.pixel_size_nm,
)
else:
self.orig_probe_positions = self.config_dict.probe_position_list
self.orig_probe_positions = self.configs.probe_position_list
self.new_probe_positions = self.orig_probe_positions.copy_with_zeros()

self.registrator = Registrator(
self.config_dict.registration_params,
random_seed=self.config_dict.random_seed,
self.configs.registration_params,
random_seed=self.configs.random_seed,
)

if (
self.config_dict.rectangular_grid
and self.config_dict.registration_params.use_baseline_offsets_for_points_on_same_row
self.configs.rectangular_grid
and self.configs.registration_params.use_baseline_offsets_for_points_on_same_row
):
self.build_row_index_list()

Expand All @@ -174,14 +174,14 @@ def build_row_index_list(self):
self.row_index_list = np.array(row_inds)

def run(self):
if self.config_dict.method == "serial":
if self.configs.method == "serial":
self.run_probe_position_correction_serial()
elif self.config_dict.method == "collective":
elif self.configs.method == "collective":
self.run_probe_position_correction_collective()
else:
raise ValueError(
"Correction method {} is not supported. ".format(
self.config_dict.method
self.configs.method
)
)

Expand All @@ -198,7 +198,7 @@ def reconstruct_dp(self, ind=None, dp=None):
obj_amp, obj_ph = self.ptycho_reconstructor.batch_infer(
dp[np.newaxis, :, :]
)
# if self.config_dict.debug:
# if self.configs.debug:
# fig, ax = plt.subplots(1, 3)
# ax[0].imshow(dp)
# ax[0].set_title('DP')
Expand All @@ -211,9 +211,9 @@ def reconstruct_dp(self, ind=None, dp=None):
return obj_amp, obj_ph

def crop_center(self, image_list):
if self.config_dict.central_crop is None:
if self.configs.central_crop is None:
return image_list
crop_shape = self.config_dict.central_crop
crop_shape = self.configs.central_crop
for i in range(len(image_list)):
orig_shape = image_list[i].shape[1:]
start_point = [(orig_shape[j] - crop_shape[j]) // 2 for j in range(2)]
Expand All @@ -230,15 +230,15 @@ def run_probe_position_correction_serial(self):
Run serial mode probe position correction.
"""
offset_tracker = OffsetEstimator(
beta=self.config_dict.offset_estimator_beta,
order=self.config_dict.offset_estimator_order,
beta=self.configs.offset_estimator_beta,
order=self.configs.offset_estimator_order,
)
previous_obj = self.reconstruct_dp(0)[1][0]
for ind in trange(1, self.n_dps):
current_obj = self.reconstruct_dp(ind)[1][0]
offset = self.registrator.run(previous_obj, current_obj)
if (
self.config_dict.registration_params.use_baseline_offsets_for_uncertain_pairs
self.configs.registration_params.use_baseline_offsets_for_uncertain_pairs
and self.registrator.get_status()
== self.registrator.get_status_code("empty")
):
Expand All @@ -251,7 +251,7 @@ def run_probe_position_correction_serial(self):
):
offset = offset_tracker.estimate()
self.count_bad_offset += 1
if self.config_dict.debug:
if self.configs.debug:
fig, ax = plt.subplots(1, 2)
ax[0].imshow(previous_obj)
ax[1].imshow(current_obj)
Expand All @@ -270,7 +270,7 @@ def run_probe_position_correction_collective(self):
self.build_linear_system_for_collective_correction()
self.solve_linear_system(
mode="residue",
smooth_constraint_weight=self.config_dict.smooth_constraint_weight,
smooth_constraint_weight=self.configs.smooth_constraint_weight,
)
self.postprocess()

Expand All @@ -284,7 +284,7 @@ def get_neightbor_inds(self, i_dp, knn_inds):
else:
this_neighbors_inds = [i_dp - 1, i_dp + 1]
i_knn = 0
while len(this_neighbors_inds) < self.config_dict.num_neighbors_collective:
while len(this_neighbors_inds) < self.configs.num_neighbors_collective:
if knn_inds[i_knn] not in this_neighbors_inds:
this_neighbors_inds.append(knn_inds[i_knn])
i_knn += 1
Expand All @@ -297,14 +297,14 @@ def build_linear_system_for_collective_correction(self):
# registered with O(1) time.
registered_pair_hash_table = collections.defaultdict(lambda: None)
nn_engine = sklearn.neighbors.NearestNeighbors(
n_neighbors=self.config_dict.num_neighbors_collective + 1
n_neighbors=self.configs.num_neighbors_collective + 1
)
nn_engine.fit(self.orig_probe_positions.array)
nn_dists, nn_inds = nn_engine.kneighbors(self.orig_probe_positions.array)
for i_dp, this_orig_pos in enumerate(tqdm(self.orig_probe_positions.array)):
if self.config_dict.registration_params.registration_tol_schedule:
if self.configs.registration_params.registration_tol_schedule:
self.registrator.update_tol_by_tol_schedule(
i_dp, self.config_dict.registration_params.registration_tol_schedule
i_dp, self.configs.registration_params.registration_tol_schedule
)
this_knn_inds = nn_inds[i_dp, 1:]
this_neighbors_inds = self.get_neightbor_inds(i_dp, this_knn_inds)
Expand All @@ -329,7 +329,7 @@ def build_linear_system_for_collective_correction(self):
)[1][0]

if (
self.config_dict.registration_params.use_baseline_offsets_for_points_on_same_row
self.configs.registration_params.use_baseline_offsets_for_points_on_same_row
and self.row_index_list is not None
and self.row_index_list[i_dp] == self.row_index_list[ind_neighbor]
):
Expand All @@ -349,7 +349,7 @@ def build_linear_system_for_collective_correction(self):
offset = self.registrator.run(neighbor_obj, current_obj)
# print('{} - {}: {}'.format(i_dp, ind_neighbor, offset))
if (
self.config_dict.debug
self.configs.debug
and self.registrator.get_status()
!= self.registrator.get_status_code("empty")
):
Expand All @@ -363,7 +363,7 @@ def build_linear_system_for_collective_correction(self):
plt.tight_layout()
print("Offset: {}".format(offset))
if (
self.config_dict.registration_params.use_baseline_offsets_for_uncertain_pairs
self.configs.registration_params.use_baseline_offsets_for_uncertain_pairs
and self.registrator.get_status()
== self.registrator.get_status_code("empty")
):
Expand All @@ -382,7 +382,7 @@ def build_linear_system_for_collective_correction(self):
self.a_mat.append(self._generate_amat_row(i_dp, ind_neighbor))
self.a_mat = np.stack(self.a_mat)
self.b_vec = np.stack(self.b_vec)
if self.config_dict.registration_params.use_baseline_offsets_for_unregistered_points:
if self.configs.registration_params.use_baseline_offsets_for_unregistered_points:
self.fill_gaps_in_linear_system()

def fill_gaps_in_linear_system(self):
Expand Down Expand Up @@ -437,7 +437,7 @@ def solve_linear_system(self, mode="residue", smooth_constraint_weight=1e-3):
self.new_probe_positions.array = np.linalg.pinv(a_mat) @ b_vec

def postprocess(self):
if self.config_dict.registration_params.use_baseline_offsets_for_unregistered_points:
if self.configs.registration_params.use_baseline_offsets_for_unregistered_points:
for ind in self.unregistered_indices:
if ind > 0:
baseline_offset = (
Expand Down Expand Up @@ -487,12 +487,12 @@ def _update_probe_position_list(self, ind, offset):


class ProbePositionCorrectorChain:
def __init__(self, config_dict: InferenceConfig):
self.config_dict = config_dict
def __init__(self, configs: InferenceConfig):
self.configs = configs
self.corrector_list = []
self.multiiter_keys = []
self.n_iters = 1
self.baseline_pos_list = config_dict.baseline_position_list
self.baseline_pos_list = configs.baseline_position_list
self.collective_mode_offset_tol = 150
self.verbose = True
self.redone_with_baseline = False
Expand All @@ -513,10 +513,10 @@ def find_multiiter_keys(self, config_obj):

def build_multiiter_entries(self):
self.has_multiiter_key = False
self.find_multiiter_keys(self.config_dict)
self.find_multiiter_keys(self.configs)
if not self.has_multiiter_key:
raise ValueError(
"With ProbePositionCorrectorChain, there should be at least one entry in the config dict "
"With ProbePositionCorrectorChain, there should be at least one entry in configs.__dict__ "
'that ends with "_multiiter" and is a list whose length equals to the desired number of '
"iterations. "
)
Expand All @@ -527,15 +527,15 @@ def run(self):

def run_correction_iteration(self, iter):
logging.debug("Now running iteration {}.".format(iter))
self.update_config_dict(iter)
self.update_configs(iter)
if self.verbose:
print(self.config_dict)
corrector = PtychoNNProbePositionCorrector(config_dict=self.config_dict)
print(self.configs)
corrector = PtychoNNProbePositionCorrector(configs=self.configs)
corrector.build()
if self.verbose:
corrector.orig_probe_positions.plot()
corrector.run()
if self.config_dict.method == "collective" and (
if self.configs.method == "collective" and (
not self.is_collective_result_good(corrector)
):
# Redo iteration using baseline as initialization if result is bad
Expand All @@ -544,8 +544,8 @@ def run_correction_iteration(self, iter):
"redo this iteration with baseline positions as initialization..."
)
if self.baseline_pos_list:
self.config_dict.probe_position_list = self.baseline_pos_list
corrector = PtychoNNProbePositionCorrector(config_dict=self.config_dict)
self.configs.probe_position_list = self.baseline_pos_list
corrector = PtychoNNProbePositionCorrector(configs=self.configs)
corrector.build()
corrector.run()
self.redone_with_baseline = True
Expand All @@ -567,11 +567,11 @@ def get_ordinary_key_name(self, mikey):
ind = mikey.find("_multiiter")
return mikey[:ind]

def update_config_dict(self, iter, initialize_with_baseline=False):
def update_configs(self, iter, initialize_with_baseline=False):
for mikey in self.multiiter_keys:
key = self.get_ordinary_key_name(mikey)
self.config_dict.overwrite_value_to_key(
self.config_dict, key, self.config_dict.query(mikey)[iter]
self.configs.overwrite_value_to_key(
self.configs, key, self.configs.query(mikey)[iter]
)
if iter > 0:
last_corrector = self.corrector_list[iter - 1]
Expand All @@ -584,7 +584,7 @@ def update_config_dict(self, iter, initialize_with_baseline=False):
raise ValueError(
"Cannot initialize with baseline positions: baseline position list is None."
)
self.config_dict.probe_position_list = probe_pos_list
self.configs.probe_position_list = probe_pos_list
logging.debug(
"Using result from the last iteration to initialize probe position array..."
)
Loading

0 comments on commit bc80a68

Please sign in to comment.