Skip to content

Commit

Permalink
RegistrationConfigDict
Browse files Browse the repository at this point in the history
  • Loading branch information
mdw771 committed Mar 12, 2024
1 parent cc1a609 commit 41faf63
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 171 deletions.
15 changes: 15 additions & 0 deletions ptychonn/pospred/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@
Create an `InferenceConfigDict` object and set `reconstruction_image_path` and other parameters
like `num_neighbors_collective`, `method`, etc. Use the default parameters as a starting point.
Image registration parameters are supplied in a `RegistrationConfigDict` object to `registration_params`
of `InferenceConfigDict`, for example:
```
configs = InferenceConfigDict(
...
registration_params=RegistrationConfigDict(
registration_method='error_map',
...
)
)
```
If the settings for `RegistrationConfigDict` are stored in and read from a JSON or TOML file,
just put these parameters at the **same level** as other parameters. Don't create nested structures
in config files.
To start with an initial position set, create a `ProbePositionList` object with the initial positions,
and pass this object to the config obejct:
```
Expand Down
170 changes: 118 additions & 52 deletions ptychonn/pospred/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,36 @@ def get_serializable_dict(self):
d[key] = v
return d

@staticmethod
def recursive_query(config_obj, key):
for k in config_obj.__dict__.keys():
if k == key:
return config_obj.__dict__[k]
for k, item in config_obj.__dict__.items():
if isinstance(item, ConfigDict):
return config_obj.recursive_query(item, key)
return

def query(self, key):
return self.recursive_query(self, key)

@staticmethod
def overwrite_value_to_key(config_obj, key, value):
"""
Recursively search a ConfigDict and any of its keys that are also objects of ConfigDict for `key`.
Replace its value with `value` if found.
"""
is_multiiter_key = key.endswith('_multiiter')
key_basename = key if not is_multiiter_key else key[:-len('_multiiter')]
for k in config_obj.__dict__.keys():
if k == key_basename:
config_obj.__dict__[key] = value
return
for k, item in config_obj.__dict__.items():
if isinstance(item, ConfigDict):
config_obj.overwrite_value_to_key(item, key, value)
return

def dump_to_json(self, filename):
try:
f = open(filename, 'w')
Expand All @@ -57,7 +87,7 @@ def load_from_json(self, filename):
f = open(filename, 'r')
d = json.load(f)
for key in d.keys():
self.__dict__[key] = d[key]
self.overwrite_value_to_key(self, key, d[key])
f.close()

def load_from_toml(self, filename):
Expand All @@ -67,57 +97,37 @@ def load_from_toml(self, filename):
f = open(filename, 'rb')
d = tomli.load(f)
for key in d.keys():
self.__dict__[key] = d[key]
self.overwrite_value_to_key(self, key, d[key])
f.close()


@dataclasses.dataclass
class InferenceConfigDict(ConfigDict):

# ===== PtychoNN configs =====
batch_size: int = 1
"""Inference batch size."""

model_path: str = None
"""Path to a trained PtychoNN model."""
class RegistrationConfigDict(ConfigDict):

model: Any = None
"""
registration_method: str = 'error_map'
"""Registration method. Can be "error_map", "sift", "hybrid"."""

The model. Should be a tuple(nn.Module, kwargs): the first element of the tuple is the class handle of a
model class, and the second is a dictionary of keyword arguments. The model will be instantiated using these.
This value is used to instantiate a model object, whose weights are overwritten with those read from
max_shift: int = 7
"""The maximum x/y shift allowed in error map."""

do_subpixel: bool = True
"""If True, error map algorithm will attempt to get subpixel precision through quadratic fitting."""

`model_path`. The provided model class and arguments must match the model being loaded.
"""
subpixel_fitting_window_size: int = 5
"""Window size for subpixel fitting."""

ptycho_reconstructor: Any = None
subpixel_diff_tolerance: float = 2.0
"""
Should be either None or a Reconstructor object. If None, PyTorchReconstructor is used by default.
If the x or y distance between the subpixel offset found and the integer offset is beyond this value, subpixel
result will be rejected and integer offset will be used instead.
"""

dp_data_path: str = None
subpixel_fitting_check_coefficients: bool = True
"""
The path to the diffraction data file. When using a VirtualReconstrutor that uses already-reconstructed images,
keep this as None.
If True, coefficients of the fitted quadratic function are checked and the result will be marked questionable
if the quadratic function looks too smooth.
"""

prediction_output_path: str = None
"""Path to save PtychoNN prediction results."""

cpu_only: bool = False

onnx_mdl: Any = None
"""ONNX file when using ONNXReconstructor."""

# ===== Image registration configs =====
registration_method: str = 'error_map'

do_subpixel: bool = True

use_fast_errormap: bool = False

sift_outlier_removal_method: str = 'kmeans'
"""Method for detecting outlier matches for SIFT. Can be "trial_error", "kmeans", "isoforest", "ransac"."""

Expand Down Expand Up @@ -148,16 +158,84 @@ class InferenceConfigDict(ConfigDict):
"""

min_roi_stddev: float = 0.2
"""
The minimum standard deviation required in the region where registration errors are calculated. If the standard
deviation is below this value, registration result will be rejected as the area for error check might be too
flat to be conclusive.
"""

subpixel_fitting_window_size: int = 5
use_baseline_offsets_for_points_on_same_row: bool = False
"""
If True, baseline offset's x-component will be used for the horizontal offsets of all points on the same
row if they are arranged in a rectangular grid.
"""

subpixel_diff_tolerance: float = 2.0
use_baseline_offsets_for_unregistered_points: bool = False
"""
If True, if a point is not successfully registered with any neighbor in collective mode, it will fill
the linear system with the offsets of the two adjacently indexed points to that point from baseline positions.
"""

subpixel_fitting_check_coefficients: bool = True
use_baseline_offsets_for_uncertain_pairs: bool = False
"""
If True, if an image pair looks too empty to provide reliable registration result, it will fill
the linear system with the offsets of the two adjacently indexed points to that point from baseline positions.
"""

use_fast_errormap: bool = False
"""
Use fast error map algorithm, where errors are calculated between a sliced region of image 1 and a cropped
version of image 2, instead of rolling image 2 to impose shift.
"""

errormap_error_check_tol: float = 0.3
"""Error map result will be marked questionable if the lowest error is beyond this value."""


@dataclasses.dataclass
class InferenceConfigDict(ConfigDict):

# ===== PtychoNN configs =====
batch_size: int = 1
"""Inference batch size."""

model_path: str = None
"""Path to a trained PtychoNN model."""

model: Any = None
"""
The model. Should be a tuple(nn.Module, kwargs): the first element of the tuple is the class handle of a
model class, and the second is a dictionary of keyword arguments. The model will be instantiated using these.
This value is used to instantiate a model object, whose weights are overwritten with those read from
`model_path`. The provided model class and arguments must match the model being loaded.
"""

ptycho_reconstructor: Any = None
"""
Should be either None or a Reconstructor object. If None, PyTorchReconstructor is used by default.
"""

dp_data_path: str = None
"""
The path to the diffraction data file. When using a VirtualReconstrutor that uses already-reconstructed images,
keep this as None.
"""

prediction_output_path: str = None
"""Path to save PtychoNN prediction results."""

cpu_only: bool = False

onnx_mdl: Any = None
"""ONNX file when using ONNXReconstructor."""

# ===== General configs =====
registration_params: RegistrationConfigDict = dataclasses.field(default_factory=RegistrationConfigDict)
"""Registration parameters."""

reconstruction_image_path: Any = None
"""
Path to the reconstructed images to be used for position prediction. If None, PtychoNNProbePositionCorrector
Expand Down Expand Up @@ -196,8 +274,6 @@ class InferenceConfigDict(ConfigDict):
method: str = 'collective'
"""Method for correction. Can be 'serial' or 'collective'"""

max_shift: int = 7

num_neighbors_collective: int = 3
"""Number of neighbors in collective registration"""

Expand All @@ -207,18 +283,8 @@ class InferenceConfigDict(ConfigDict):

smooth_constraint_weight: float = 1e-2

use_baseline_offsets_for_uncertain_pairs: bool = False

rectangular_grid: bool = False

use_baseline_offsets_for_points_on_same_row: bool = False

use_baseline_offsets_for_unregistered_points: bool = False
"""
If True, if a point is not successfully registered with any neighbor in collective mode, it will fill
the linear system with the offsets of the two adjacently indexed points to that point from baseline positions.
"""

stitching_downsampling: int = 1

random_seed: Any = 123
Expand Down
Loading

0 comments on commit 41faf63

Please sign in to comment.