Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify IRISPipeline output. #12

Merged
merged 3 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions colab/GettingStarted.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
"cell_type": "code",
"execution_count": null,
"id": "b377bfbb-d1a7-4df8-bcc7-ce13725229f8",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"!pip install open-iris"
Expand Down Expand Up @@ -221,9 +223,9 @@
"id": "3b2896e5-5069-48db-b855-cd80ea04cd6e",
"metadata": {},
"source": [
"The `iris_template` value contains generated by the `IRISPipeline` iris code for an iris texture visible in the input image. The `output[\"iris_template\"]` value is a `dict` containing two keys: `[\"iris_codes\", \"mask_codes\"]`. \n",
"The `iris_template` value contains generated by the `IRISPipeline` iris code for an iris texture visible in the input image. The `output[\"iris_template\"]` value is a `IrisTemplate` object containing two fields: `[\"iris_codes: List[np.ndarray]\", \"mask_codes: List[np.ndarray]\"]`. \n",
"\n",
"Each code available in `output[\"iris_template\"]` dictionary is a `numpy.ndarray` of shape `(16, 256, 2, 2)`. The output shape of iris code is determined by `IRISPipeline` filter bank parameters. The iris/mask code shape's dimmensions correspond to the following `(iris_code_height, iris_code_width, num_filters, 2)`. Values `iris_code_height` and `iris_code_width` are determined by `ProbeSchema`s defined for `ConvFilterBank` object and `num_filters` is determined by number of filters specified for `ConvFilterBank` object. The last `2` value of the iris/mask code dimmension corresponds to real and complex parts of each complex filter response.\n",
"Each code available in `output[\"iris_template\"]` object is a `numpy.ndarray` of shape `(16, 256, 2)`. The length of arrays containing iris codes and mask codes is determined by `IRISPipeline` filter bank parameters. The iris/mask code shape's dimmensions correspond to the following `(iris_code_height, iris_code_width, 2)`. Values `iris_code_height` and `iris_code_width` are determined by `ProbeSchema`'s definition for `ConvFilterBank` object and `num_filters` is determined by number of filters specified for `ConvFilterBank` object. The last `2` value of the iris/mask code dimmension corresponds to the real and imaginary parts of each complex filter response.\n",
"\n",
"_NOTE_: More about how to specify those parameters and configuring custom `IRISPipeline` can be found in the _Configuring custom pipeline_ tutorial."
]
Expand All @@ -235,27 +237,20 @@
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Available keys in `output[\"iris_template\"]` are: \"\"\" + str(output[\"iris_template\"].keys())"
"\"\"\"Available fields in `output[\"iris_template\"]` are: \"\"\" + str(output[\"iris_template\"].__fields__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bbfbc26-6479-43f2-b548-a57070d66092",
"id": "302b0b34-04d9-46dd-8f4c-038955731b66",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"`output[\"iris_template\"]` value types are: \"\"\" + type(output[\"iris_template\"][\"iris_codes\"]).__name__ + \", \" + type(output[\"iris_template\"][\"mask_codes\"]).__name__"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c87e03af-db75-4086-91a9-177251560c83",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"`output[\"iris_template\"]` value shapes are: \"\"\" + str(output[\"iris_template\"][\"iris_codes\"].shape) + \", \" + str(output[\"iris_template\"][\"mask_codes\"].shape)"
"num_codes = len(output[\"iris_template\"].iris_codes)\n",
"code_shape = output[\"iris_template\"].iris_codes[0].shape\n",
"\n",
"f\"\"\"Number of returned iris codes is equal to {num_codes} and each code shape is {code_shape}\"\"\""
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/source/examples/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ If ``output["error"]`` value is ``None``, ``IRISPipeline`` finished inference ca
'traceback': 'Very long exception traceback'
}

The ``iris_template`` value contains generated by the ``IRISPipeline`` iris code for an iris texture visible in the input image. The ``output["iris_template"]`` value is a ``dict`` containing two keys: ``["iris_codes", "mask_codes"]``.
The ``iris_template`` value contains generated by the ``IRISPipeline`` iris code for an iris texture visible in the input image. The ``output["iris_template"]`` value is a ``IrisTemplate`` object containing two fields: ``["iris_codes: List[np.ndarray]", "mask_codes: List[np.ndarray]"]``.

Each code available in ``output["iris_template"]`` dictionary is a ``numpy.ndarray`` of shape ``(16, 256, 2, 2)``. The output shape of iris code is determined by ``IRISPipeline`` filter bank parameters. The iris/mask code shape's dimensions correspond to the following ``(iris_code_height, iris_code_width, num_filters, 2)``. Values ``iris_code_height`` and ``iris_code_width`` are determined by ``ProbeSchema``s defined for ``ConvFilterBank`` object and ``num_filters`` is determined by number of filters specified for ``ConvFilterBank`` object. The last ``2`` value of the iris/mask code dimension corresponds to real and complex parts of each complex filter response.
Each code available in ``output["iris_template"]`` dictionary is a ``numpy.ndarray`` of shape ``(16, 256, 2)``. The length of arrays containing iris codes and mask codes is determined by ``IRISPipeline`` filter bank parameters. The iris/mask code shape's dimensions correspond to the following ``(iris_code_height, iris_code_width, 2)``. Values ``iris_code_height`` and ``iris_code_width`` are determined by ``ProbeSchema``'s definition for ``ConvFilterBank`` object and ``num_filters`` is determined by number of filters specified for ``ConvFilterBank`` object. The last ``2`` value of the iris/mask code dimension corresponds to the real and imaginary parts of each complex filter response.

*NOTE*: More about how to specify those parameters and configuring custom ``IRISPipeline`` can be found in the *Configuring custom pipeline* tutorial.

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ exclude_lines = ["if __name__ == .__main__.:"]

[tool.ruff]
exclude = ["__init__.py"]
select = ["E", "F", "PLC", "PLE", "PLR", "PLW"]
ignore = ["E501", "F722", "F821", "PLR2004", "PLR0915", "PLR0913", "PLC0414", "PLR0402", "PLR5501", "PLR0911", "PLR0912", "PLW0603", "PLW2901"]
lint.select = ["E", "F", "PLC", "PLE", "PLR", "PLW"]
lint.ignore = ["E501", "F722", "F821", "PLR2004", "PLR0915", "PLR0913", "PLC0414", "PLR0402", "PLR5501", "PLR0911", "PLR0912", "PLW0603", "PLW2901"]

[tool.isort]
profile = "black"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,7 @@ def compute_kernel_values(self) -> np.ndarray:

# calculate envelope and orientation
envelope = np.exp(
-0.5
* np.log2(radius * self.params.lambda_rho / self.params.kernel_size[1]) ** 2
/ self.params.sigma_rho**2
-0.5 * np.log2(radius * self.params.lambda_rho / self.params.kernel_size[1]) ** 2 / self.params.sigma_rho**2
)
envelope[ksize_rho_half][ksize_phi_half] = 0
orientation = np.exp(-0.5 * dtheta**2 / self.params.sigma_phi**2)
Expand Down
2 changes: 1 addition & 1 deletion src/iris/nodes/normalization/nonlinear_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, res_in_r: int = 128, oversat_threshold: int = 254) -> None:

Args:
res_in_r (int): Normalized image r resolution. Defaults to 128.
oversat_threshold (int, optional): threshold for masking over-satuated pixels. Defaults to 254.
oversat_threshold (int, optional): threshold for masking over-satuated pixels. Defaults to 254.
"""
intermediate_radiuses = np.array([getgrids(max(0, res_in_r), p2i_ratio) for p2i_ratio in range(100)])
super().__init__(
Expand Down
37 changes: 37 additions & 0 deletions src/iris/orchestration/output_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,43 @@
from iris.io.dataclasses import ImmutableModel


def build_simple_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
"""Build the output for the Orb.

Args:
call_trace (PipelineCallTraceStorage): Pipeline call results storage.

Returns:
Dict[str, Any]: {
"iris_template": (Optional[IrisTemplate]) the iris template object if the pipeline succeeded,
"error": (Optional[Dict]) the error dict if the pipeline returned an error,
"metadata": (Dict) the metadata dict,
}.
"""
metadata = __get_metadata(call_trace=call_trace)
error = __get_error(call_trace=call_trace)
iris_template = None

exception = call_trace.get_error()
if exception is None:
iris_template = call_trace["encoder"]
error = None
elif isinstance(exception, Exception):
error = {
"error_type": type(exception).__name__,
"message": str(exception),
"traceback": "".join(traceback.format_tb(exception.__traceback__)),
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wiktorlazarski This is already dealt with in __get_error, we can simplify this function even more 🙂

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I completely missed that. Thanks for the hint. That'll allow us to simplify also build_orb_output function.


output = {
"error": error,
"iris_template": iris_template,
"metadata": metadata,
}

return output


def build_orb_output(call_trace: PipelineCallTraceStorage) -> Dict[str, Any]:
"""Build the output for the Orb.

Expand Down
12 changes: 9 additions & 3 deletions src/iris/pipelines/iris_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from iris.io.errors import IRISPipelineError
from iris.orchestration.environment import Environment
from iris.orchestration.error_managers import store_error_manager
from iris.orchestration.output_builders import build_debugging_output, build_orb_output
from iris.orchestration.output_builders import build_debugging_output, build_orb_output, build_simple_output
from iris.orchestration.pipeline_dataclasses import PipelineClass, PipelineMetadata, PipelineNode
from iris.orchestration.validators import pipeline_config_duplicate_node_name_check

Expand All @@ -41,6 +41,12 @@ class IRISPipeline(Algorithm):
call_trace_initialiser=PipelineCallTraceStorage.initialise,
)

ORB_ENVIRONMENT = Environment(
pipeline_output_builder=build_orb_output,
error_manager=store_error_manager,
call_trace_initialiser=PipelineCallTraceStorage.initialise,
)

class Parameters(Algorithm.Parameters):
"""IRISPipeline parameters, all derived from the input `config`."""

Expand All @@ -57,7 +63,7 @@ def __init__(
self,
config: Union[Dict[str, Any], Optional[str]] = None,
env: Environment = Environment(
pipeline_output_builder=build_orb_output,
pipeline_output_builder=build_simple_output,
error_manager=store_error_manager,
call_trace_initialiser=PipelineCallTraceStorage.initialise,
),
Expand All @@ -66,7 +72,7 @@ def __init__(

Args:
config (Union[Dict[str, Any], Optional[str]]): Input configuration, as a YAML-formatted string or dictionary specifying all nodes configuration. Defaults to None, which loads the default config.
env (Environment, optional): Environment properties. Defaults to Environment(output_builder=build_orb_output, error_manager=store_error_manager, call_trace_initialiser=PipelineCallTraceStorage).
env (Environment, optional): Environment properties. Defaults to Environment(output_builder=build_simple_output, error_manager=store_error_manager, call_trace_initialiser=PipelineCallTraceStorage).
"""
deserialized_config = self.load_config(config) if isinstance(config, str) or config is None else config
super().__init__(**deserialized_config)
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e_tests/pipelines/test_e2e_iris_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def expected_debug_pipeline_output() -> Dict[str, Any]:

def test_e2e_iris_pipeline(ir_image: np.ndarray, expected_iris_pipeline_output: Dict[str, Any]) -> None:
"""End-to-end test of the IRISPipeline in the Orb setup"""
iris_pipeline = IRISPipeline()
iris_pipeline = IRISPipeline(env=IRISPipeline.ORB_ENVIRONMENT)
computed_pipeline_output = iris_pipeline(img_data=ir_image, eye_side="right")

compare_iris_pipeline_outputs(computed_pipeline_output, expected_iris_pipeline_output)
Expand Down
Loading