Merge pull request #24 from mdw771/package
Add position prediction
carterbox authored Mar 27, 2024
2 parents 5bc29a3 + bc80a68 commit c8935ad
Showing 13 changed files with 3,332 additions and 0 deletions.
98 changes: 98 additions & 0 deletions ptychonn/position/__init__.py
@@ -0,0 +1,98 @@
"""Ptychography probe position prediction with PtychoNN
A ptychography probe position prediction algorithm making use of PtychoNN [1],
a single-shot phase retrieval network.
## How it works
A trained PtychoNN model is able to predict the local phase around a scan point
given the diffraction pattern alone, not needing any overlapping diffraction
patterns or position information. With predictions made around each scan point,
this algorithm estimates their pairwise offsets using common (but customized)
image registration methods. It then finds a least-squares solution for the
positions of all points in a common coordinate system by solving a linear
system.
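
As a rough illustration of the last step (not the package's actual solver), the
sketch below recovers positions from pairwise offsets with
`numpy.linalg.lstsq`. The function name `solve_positions`, the convention that
each offset equals `r_j - r_i`, and pinning point 0 at the origin are
assumptions made for this example; the smoothness constraint controlled by
`smooth_constraint_weight` is omitted.

```python
import numpy as np


def solve_positions(n_points, pairs, offsets):
    # Hypothetical helper: one equation r_j - r_i = offset per registered pair.
    a = np.zeros((len(pairs) + 1, n_points))
    b = np.zeros((len(pairs) + 1, 2))
    for row, ((i, j), offset) in enumerate(zip(pairs, offsets)):
        a[row, i] = -1.0
        a[row, j] = 1.0
        b[row] = offset
    # Offsets only determine relative positions, so pin point 0 at the origin.
    a[-1, 0] = 1.0
    positions, *_ = np.linalg.lstsq(a, b, rcond=None)
    return positions


# Synthetic check: offsets computed from known positions are inverted exactly.
true_positions = np.array([[0.0, 0.0], [0.0, 5.0], [4.0, 0.0], [4.0, 5.0]])
pairs = [(0, 1), (0, 2), (1, 3), (2, 3), (0, 3)]
offsets = [true_positions[j] - true_positions[i] for i, j in pairs]
recovered = solve_positions(len(true_positions), pairs, offsets)
```

With noisy offsets the same system is overdetermined, and the least-squares
solution averages out inconsistencies between neighboring pairs.
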
## How to use it with PtychoNN
### Prediction
Run prediction on the diffraction patterns in PtychoNN and save all the
predicted images as a single 3D TIFF file.
### Create run configurations
Create an `InferenceConfig` object and set `reconstruction_image_path` and
other parameters like `num_neighbors_collective`, `method`, etc. Use the
default parameters as a starting point.
Image registration parameters are supplied in a `RegistrationConfig` object
to `registration_params` of `InferenceConfig`, for example:
```python
configs = InferenceConfig(
    ...
    registration_params=RegistrationConfig(
        registration_method='error_map',
        ...
    )
)
```
If the settings for `RegistrationConfig` are stored in and read from a JSON
or TOML file, just put these parameters at the **same level** as other
parameters. Don't create nested structures in config files.
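
To make the flat layout concrete, here is a minimal sketch of how a flat
settings file can be routed into a nested config. `DemoRegistrationConfig`,
`DemoInferenceConfig`, and `apply_flat_setting` are hypothetical stand-ins
written for this example, not classes from this package; they only mimic the
idea of searching nested config objects for a matching key.

```python
import dataclasses
import json


@dataclasses.dataclass
class DemoRegistrationConfig:
    registration_method: str = "error_map"
    max_shift: int = 7


@dataclasses.dataclass
class DemoInferenceConfig:
    registration_params: DemoRegistrationConfig = dataclasses.field(
        default_factory=DemoRegistrationConfig
    )
    method: str = "collective"
    num_neighbors_collective: int = 3


def apply_flat_setting(config, key, value):
    # Set `key` on this config, or on the first nested config that defines it.
    if key in config.__dict__:
        config.__dict__[key] = value
        return True
    for item in config.__dict__.values():
        if dataclasses.is_dataclass(item) and apply_flat_setting(item, key, value):
            return True
    return False


# Registration settings sit at the same level as the other settings.
flat_json = '{"method": "serial", "registration_method": "sift", "max_shift": 5}'
demo = DemoInferenceConfig()
for key, value in json.loads(flat_json).items():
    apply_flat_setting(demo, key, value)
```

Entries that the file does not mention keep their default values.
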

To start with an initial position set, create a `ProbePositionList` object with
the initial positions, and pass this object to the config object:
```python
configs = InferenceConfig(
    ...
    probe_position_list=ProbePositionList(position_list=arr)
)
```
where `arr` is an `(N, 2)` array of probe positions in pixels.
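
If the positions are recorded in nanometers, they need to be converted to
pixels first. The sketch below shows one way to build such an array for a small
raster scan; the helper `positions_nm_to_pixel`, the `(y, x)` column order, and
the step and pixel sizes are assumptions chosen for illustration.

```python
import numpy as np


def positions_nm_to_pixel(positions_nm, pixel_size_nm):
    # Hypothetical helper: divide physical positions by the pixel size.
    positions = np.asarray(positions_nm, dtype=float)
    if positions.ndim != 2 or positions.shape[1] != 2:
        raise ValueError("expected an (N, 2) array of positions")
    return positions / pixel_size_nm


# A 2 x 3 raster scan with 80 nm steps, stored as (y, x) rows.
ys, xs = np.meshgrid(np.arange(2) * 80.0, np.arange(3) * 80.0, indexing="ij")
positions_nm = np.stack([ys.ravel(), xs.ravel()], axis=1)
arr = positions_nm_to_pixel(positions_nm, pixel_size_nm=8.0)
```
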

The `ProbePositionCorrectorChain` class can run position prediction for
multiple iterations, with different settings for certain parameters in each
iteration. To do this, create keys in the config object named
`<name_of_existing_key>_multiiter` and assign each a list or tuple of values.
Each element of the list is the value for that iteration. For example, setting
`configs.__dict__['method_multiiter'] = ["serial", "collective", "collective"]`
tells the corrector chain to run 3 iterations, with `method` set to
`"serial"`, `"collective"`, and `"collective"` respectively.
> **Note:** do not pass `_multiiter` keys to the config object's constructor, as they will not be recognized.
> Instead, either create new keys in the config object's dictionary container (`configs.__dict__`) after
> the object is instantiated,
> or keep these settings in a JSON or TOML file and load them afterwards.
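
The naming convention can be pictured with the following sketch, which expands
`_multiiter` keys into one settings dictionary per iteration. `expand_multiiter`
is a hypothetical function written for this example and is not how the
corrector chain is implemented; requiring all lists to have the same length is
also an assumption.

```python
def expand_multiiter(settings):
    suffix = "_multiiter"
    # Split per-iteration lists from ordinary settings.
    multi = {k[: -len(suffix)]: v for k, v in settings.items() if k.endswith(suffix)}
    base = {k: v for k, v in settings.items() if not k.endswith(suffix)}
    lengths = {len(v) for v in multi.values()}
    if len(lengths) > 1:
        raise ValueError("all *_multiiter lists must have the same length")
    n_iters = lengths.pop() if lengths else 1
    # For each iteration, override the base value with that iteration's value.
    return [
        {**base, **{name: values[i] for name, values in multi.items()}}
        for i in range(n_iters)
    ]


settings = {
    "method": "collective",
    "num_neighbors_collective": 3,
    "method_multiiter": ["serial", "collective", "collective"],
}
per_iteration = expand_multiiter(settings)
```

Here `per_iteration` holds three dictionaries whose `method` values follow the
list, while `num_neighbors_collective` stays the same in every iteration.
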
### Run prediction
Just run the following:
```python
corrector_chain = ProbePositionCorrectorChain(configs)
corrector_chain.build()
corrector_chain.run()
```
Predicted positions can be obtained from
`corrector_chain.corrector_list[-1].new_probe_positions.array`.
### Examples
`tests/test_multiiter_pos_calculation.py` shows an example of a 3-iteration
position prediction run with images already predicted by PtychoNN. The script
demonstrates a case without any initial position input; to start from an
initial position set instead, provide it through the `probe_position_list` key
of the config object. See the comments in the config object constructor inside
the script.
**References**
1. M. J. Cherukara, T. Zhou, Y. Nashed, P. Enfedaque, A. Hexemer, R. J. Harder, M. V. Holt, AI-enabled high-resolution scanning coherent diffraction imaging. Appl Phys Lett 117, 044103 (2020).
"""

from .configs import InferenceConfig, RegistrationConfig
from .core import ProbePositionCorrectorChain
from .position_list import ProbePositionList
277 changes: 277 additions & 0 deletions ptychonn/position/configs.py
@@ -0,0 +1,277 @@
import json
import dataclasses
import warnings
from typing import Any, Optional, Union
from collections.abc import Sequence

from ptychonn.position.position_list import ProbePositionList
from ptychonn.position.io import DataFileHandle

try:
    import tomli
except ImportError:
    warnings.warn("Unable to import tomli, which is needed to load a TOML config file.")


@dataclasses.dataclass
class Config:
    def __str__(self, *args, **kwargs):
        s = ""
        for key in self.__dict__.keys():
            s += "{}: {}\n".format(key, self.__dict__[key])
        return s

    @staticmethod
    def is_jsonable(x):
        try:
            json.dumps(x)
            return True
        except (TypeError, OverflowError):
            return False

    def get_serializable_dict(self):
        d = {}
        for key in self.__dict__.keys():
            v = self.__dict__[key]
            if not self.__class__.is_jsonable(v):
                # Fall back to a string for values that JSON cannot represent.
                if isinstance(v, (tuple, list)):
                    v = "_".join([str(x) for x in v])
                else:
                    v = str(v)
            d[key] = v
        return d

    @staticmethod
    def recursive_query(config_obj, key):
        if key in config_obj.__dict__:
            return config_obj.__dict__[key]
        # Keep searching every nested Config instead of stopping at the first one.
        for item in config_obj.__dict__.values():
            if isinstance(item, Config):
                value = Config.recursive_query(item, key)
                if value is not None:
                    return value
        return None

    def query(self, key):
        return self.recursive_query(self, key)

    @staticmethod
    def overwrite_value_to_key(config_obj, key, value):
        """
        Recursively search a Config and any of its keys that are also objects of Config for `key`.
        Replace its value with `value` if found.
        """
        is_multiiter_key = key.endswith("_multiiter")
        key_basename = key if not is_multiiter_key else key[: -len("_multiiter")]
        # A `_multiiter` key is stored next to the existing key it is named after.
        if key_basename in config_obj.__dict__:
            config_obj.__dict__[key] = value
            return
        for item in config_obj.__dict__.values():
            if isinstance(item, Config):
                config_obj.overwrite_value_to_key(item, key, value)
        return

    def dump_to_json(self, filename):
        d = self.get_serializable_dict()
        with open(filename, "w") as f:
            json.dump(d, f)

    def load_from_json(self, filename):
        """
        This function only overwrites entries contained in the JSON file. Unspecified entries are unaffected.
        """
        with open(filename, "r") as f:
            d = json.load(f)
        for key in d.keys():
            self.overwrite_value_to_key(self, key, d[key])

    def load_from_toml(self, filename):
        """
        This function only overwrites entries contained in the TOML file. Unspecified entries are unaffected.
        """
        with open(filename, "rb") as f:
            d = tomli.load(f)
        for key in d.keys():
            self.overwrite_value_to_key(self, key, d[key])

    @staticmethod
    def from_toml(filename):
        obj = InferenceConfig()
        obj.load_from_toml(filename)
        return obj

    @staticmethod
    def from_json(filename):
        obj = InferenceConfig()
        obj.load_from_json(filename)
        return obj


@dataclasses.dataclass
class RegistrationConfig(Config):
    registration_method: str = "error_map"
    """Registration method. Can be "error_map", "sift", "hybrid"."""

    max_shift: int = 7
    """The maximum x/y shift allowed in error map."""

    do_subpixel: bool = True
    """If True, error map algorithm will attempt to get subpixel precision through quadratic fitting."""

    subpixel_fitting_window_size: int = 5
    """Window size for subpixel fitting."""

    subpixel_diff_tolerance: float = 2.0
    """
    If the x or y distance between the subpixel offset found and the integer offset is beyond this value, subpixel
    result will be rejected and integer offset will be used instead.
    """

    subpixel_fitting_check_coefficients: bool = True
    """
    If True, coefficients of the fitted quadratic function are checked and the result will be marked questionable
    if the quadratic function looks too smooth.
    """

    sift_outlier_removal_method: str = "kmeans"
    """Method for detecting outlier matches for SIFT. Can be "trial_error", "kmeans", "isoforest", "ransac"."""

    sift_border_exclusion_length: int = 16
    """
    The length of the near-boundary region of the image. When doing SIFT registration, if a matching pair of
    keypoints involves points in this region, it will be discarded. However, if all matches (after outlier removal)
    are near-boundary, they are used as they are. This operation is less aggressive than `central_crop`.
    """

    registration_downsample: int = 1
    """Image downsampling before registration."""

    hybrid_registration_algs: Sequence[str] = (
        "error_map_multilevel",
        "error_map_expandable",
        "sift",
    )
    """Hybrid registration algorithms. This value is disregarded unless registration method is hybrid."""

    hybrid_registration_tols: Sequence[float] = (0.15, 0.3, 0.3)
    """Hybrid registration tolerances. This value is disregarded unless registration method is hybrid."""

    nonhybrid_registration_tol: Optional[float] = None
    """Error tolerance for non-hybrid registration. This value is disregarded if registration method is hybrid."""

    registration_tol_schedule: Optional[Sequence[Sequence[Union[int, float]]]] = None
    """
    The schedule of error tolerance for registration algorithms. This should be a (N, 2) list. In each sub-list,
    the first value is the index of point, and the second value is the new tolerance value to be used at and
    after that point.
    """

    min_roi_stddev: float = 0.2
    """
    The minimum standard deviation required in the region where registration errors are calculated. If the standard
    deviation is below this value, registration result will be rejected as the area for error check might be too
    flat to be conclusive.
    """

    use_baseline_offsets_for_points_on_same_row: bool = False
    """
    If True, baseline offset's x-component will be used for the horizontal offsets of all points on the same
    row if they are arranged in a rectangular grid.
    """

    use_baseline_offsets_for_unregistered_points: bool = False
    """
    If True, if a point is not successfully registered with any neighbor in collective mode, it will fill
    the linear system with the offsets of the two adjacently indexed points to that point from baseline positions.
    """

    use_baseline_offsets_for_uncertain_pairs: bool = False
    """
    If True, if an image pair looks too empty to provide reliable registration result, it will fill
    the linear system with the offsets of the two adjacently indexed points to that point from baseline positions.
    """

    use_fast_errormap: bool = False
    """
    Use fast error map algorithm, where errors are calculated between a sliced region of image 1 and a cropped
    version of image 2, instead of rolling image 2 to impose shift.
    """

    errormap_error_check_tol: float = 0.3
    """Error map result will be marked questionable if the lowest error is beyond this value."""


@dataclasses.dataclass
class InferenceConfig(Config):
    # ===== General configs =====
    registration_params: RegistrationConfig = dataclasses.field(
        default_factory=RegistrationConfig
    )
    """Registration parameters."""

    reconstruction_image_path: str = ''
    """
    Path to the reconstructed images to be used for position prediction.
    """

    probe_position_list: Optional[ProbePositionList] = None
    """
    A ProbePositionList object used for finding nearest neighbors in collective mode.
    If None, `probe_position_data_path` must be provided.
    """

    probe_position_data_path: Optional[str] = None
    """
    Path to the data file containing probe positions, which should be a CSV file with each line containing the
    positions in y and x. Ignored if `probe_position_list` is provided.
    """

    probe_position_data_unit: Optional[str] = None
    """Unit of provided probe position. Can be 'nm', 'm', or 'pixel'. Ignored if `probe_position_list` is provided."""

    pixel_size_nm: Optional[float] = None
    """Pixel size of input positions. Ignored if `probe_position_list` is provided."""

    baseline_position_list: Optional[ProbePositionList] = None
    """Baseline positions. Used by ProbePositionCorrectorChain when the serial mode result is bad."""

    central_crop: Optional[Sequence[int]] = None
    """
    List or tuple of int. Patch size used for image registration. If smaller than the reconstructed object size,
    a patch will be cropped from the center.
    """

    method: str = "collective"
    """Method for correction. Can be 'serial' or 'collective'"""

    num_neighbors_collective: int = 3
    """Number of neighbors in collective registration"""

    offset_estimator_order: int = 1
    """
    Order of momentum used in the offset estimator. The estimator is used only in serial mode and when the
    registration result is not reliable.
    """

    offset_estimator_beta: float = 0.5
    """Weight of past offsets when updating the running average of offsets in the offset estimator."""

    smooth_constraint_weight: float = 1e-2
    """
    Weight of the smoothness constraint when solving for global-frame probe positions. This is the lambda_2
    in the equation in the paper.
    """

    rectangular_grid: bool = False
    """
    Whether the scan grid is a rectangular grid. Some parameters including
    `use_baseline_offsets_for_points_on_same_row` won't take effect unless this is set to True.
    """

    random_seed: Optional[int] = 123
    """Random seed."""

    debug: bool = False