diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5604ad317..892c7333a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -26,3 +26,11 @@ This will also be run automatically when a PR is made to master and a codecov re - For a completely new feature, make a branch off of the `dev/main` branch of CellMap's fork of DaCapo with a name describing the feature. If you are collaborating on a feature that already has a branch, you can branch off that feature branch. - Currently, you should make your PRs into the `dev/main` branch of CellMap's fork, or the feature branch you branched off of. PRs currently require one maintainer's approval before merging. Once the PR is merged, the feature branch should be deleted. - `dev/main` will be regularly merged to `main` when new features are fully implemented and all tests are passing. + + +## Documentation +Documentation is built using Sphinx. To build the documentation locally, run +```bash +sphinx-build -M html docs/source docs/build +``` +This will generate the html files in the `docs/build/html` directory. \ No newline at end of file diff --git a/dacapo/apply.py b/dacapo/apply.py index f9a1af28f..0bbb66ea6 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -18,7 +18,7 @@ create_weights_store, ) -from pathlib import Path +from upath import UPath as Path logger = logging.getLogger(__name__) @@ -38,7 +38,40 @@ def apply( overwrite: bool = True, file_format: str = "zarr", ): - """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.""" + """ + Load weights and apply a trained model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used. + + Args: + run_name (str): Name of the run to apply. + input_container (Path | str): Path to the input container. + input_dataset (str): Name of the input dataset. + output_path (Path | str): Path to the output container. + validation_dataset (Optional[Dataset | str], optional): Validation dataset to use for finding the best parameters. Defaults to None. + criterion (str, optional): Criterion to use for finding the best parameters. Defaults to "voi". + iteration (Optional[int], optional): Iteration to use. If None, the best iteration is used. Defaults to None. + parameters (Optional[PostProcessorParameters | str], optional): Post-processor parameters to use. If None, the best parameters are found. Defaults to None. + roi (Optional[Roi | str], optional): Region of interest to use. If None, the whole input dataset is used. Defaults to None. + num_workers (int, optional): Number of workers to use. Defaults to 12. + output_dtype (np.dtype | str, optional): Output dtype. Defaults to np.uint8. + overwrite (bool, optional): Overwrite existing output. Defaults to True. + file_format (str, optional): File format to use. Defaults to "zarr". + Raises: + ValueError: If validation_dataset is None and criterion is not None. + ValueError: If parameters is a string that cannot be parsed to PostProcessorParameters. + ValueError: If parameters is not a PostProcessorParameters object. + Examples: + >>> apply( + ... run_name="run_1", + ... input_container="data.zarr", + ... input_dataset="raw", + ... output_path="output.zarr", + ... validation_dataset="validate", + ... criterion="voi", + ... num_workers=12, + ... output_dtype=np.uint8, + ... overwrite=True, + ... ) + """ if isinstance(output_dtype, str): output_dtype = np.dtype(output_dtype) @@ -178,8 +211,36 @@ def apply_run( output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): - """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.""" + """ + Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded. + Args: + run (Run): The run object containing the task and post-processor. + iteration (int): The iteration number. + parameters (PostProcessorParameters): The post-processor parameters. + input_array_identifier (LocalArrayIdentifier): The identifier for the input array. + prediction_array_identifier (LocalArrayIdentifier): The identifier for the prediction array. + output_array_identifier (LocalArrayIdentifier): The identifier for the output array. + roi (Optional[Roi], optional): The region of interest. Defaults to None. + num_workers (int, optional): The number of workers for parallel processing. Defaults to 12. + output_dtype (np.dtype | str, optional): The output data type. Defaults to np.uint8. + overwrite (bool, optional): Whether to overwrite existing output. Defaults to True. + Raises: + ValueError: If the input array is not a ZarrArray. + Examples: + >>> apply_run( + ... run=run, + ... iteration=1, + ... parameters=parameters, + ... input_array_identifier=LocalArrayIdentifier(Path("data.zarr"), "raw"), + ... prediction_array_identifier=LocalArrayIdentifier(Path("output.zarr"), "prediction_run_1_1"), + ... output_array_identifier=LocalArrayIdentifier(Path("output.zarr"), "output_run_1_1"), + ... roi=None, + ... num_workers=12, + ... output_dtype=np.uint8, + ... overwrite=True, + ... ) + """ # render prediction dataset print(f"Predicting on dataset {prediction_array_identifier}") predict( diff --git a/dacapo/blockwise/argmax_worker.py b/dacapo/blockwise/argmax_worker.py index 59a17d752..d9f893452 100644 --- a/dacapo/blockwise/argmax_worker.py +++ b/dacapo/blockwise/argmax_worker.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path import sys from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier @@ -27,6 +27,12 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the threshold worker. + + Args: + log_level (str): The log level to use. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -47,7 +53,17 @@ def start_worker( input_dataset: str, output_container: Path | str, output_dataset: str, + return_io_loop: bool = False, ): + """ + Start the threshold worker. + + Args: + input_container (Path | str): The input container. + input_dataset (str): The input dataset. + output_container (Path | str): The output container. + output_dataset (str): The output dataset. + """ # get arrays input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) input_array = ZarrArray.open_from_array_identifier(input_array_identifier) @@ -57,34 +73,51 @@ def start_worker( ) output_array = ZarrArray.open_from_array_identifier(output_array_identifier) - # wait for blocks to run pipeline - client = daisy.Client() + def io_loop(): + # wait for blocks to run pipeline + client = daisy.Client() - while True: - print("getting block") - with client.acquire_block() as block: - if block is None: - break + while True: + print("getting block") + with client.acquire_block() as block: + if block is None: + break - # write to output array - output_array[block.write_roi] = np.argmax( - input_array[block.write_roi], - axis=input_array.axes.index("c"), - ) + # write to output array + output_array[block.write_roi] = np.argmax( + input_array[block.write_roi], + axis=input_array.axes.index("c"), + ) + + if return_io_loop: + return io_loop + else: + io_loop() def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Returns: + Callable: The function to run the worker. """ compute_context = create_compute_context() + if not compute_context.distribute_workers: + return start_worker( + input_array_identifier.container, + input_array_identifier.dataset, + output_array_identifier.container, + output_array_identifier.dataset, + return_io_loop=True, + ) # Make the command for the worker to run command = [ @@ -103,7 +136,9 @@ def spawn_worker( ] def run_worker(): - # Run the worker in the given compute context + """ + Run the worker in the given compute context. + """ compute_context.execute(command) return run_worker diff --git a/dacapo/blockwise/blockwise_task.py b/dacapo/blockwise/blockwise_task.py index cbae73b8b..090c9614d 100644 --- a/dacapo/blockwise/blockwise_task.py +++ b/dacapo/blockwise/blockwise_task.py @@ -1,10 +1,32 @@ from datetime import datetime from importlib.machinery import SourceFileLoader -from pathlib import Path +from upath import UPath as Path from daisy import Task, Roi class DaCapoBlockwiseTask(Task): + """ + A task to run a blockwise worker function. This task is used to run a + blockwise worker function on a given ROI. + + Attributes: + worker_file (str | Path): The path to the worker file. + total_roi (Roi): The ROI to process. + read_roi (Roi): The ROI to read from for a block. + write_roi (Roi): The ROI to write to for a block. + num_workers (int): The number of workers to use. + max_retries (int): The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) + timeout: The timeout for the task. + upstream_tasks: The upstream tasks. + *args: Additional positional arguments to pass to ``worker_function``. + **kwargs: Additional keyword arguments to pass to ``worker_function``. + Methods: + __init__: + Initialize the task. + """ + def __init__( self, worker_file: str | Path, @@ -18,6 +40,23 @@ def __init__( *args, **kwargs, ): + """ + Initialize the task. + + Args: + worker_file (str | Path): The path to the worker file. + total_roi (Roi): The ROI to process. + read_roi (Roi): The ROI to read from for a block. + write_roi (Roi): The ROI to write to for a block. + num_workers (int): The number of workers to use. + max_retries (int): The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) + timeout: The timeout for the task. + upstream_tasks: The upstream tasks. + *args: Additional positional arguments to pass to ``worker_function``. + **kwargs: Additional keyword arguments to pass to ``worker_function``. + """ # Load worker functions worker_name = Path(worker_file).stem worker = SourceFileLoader(worker_name, str(worker_file)).load_module() diff --git a/dacapo/blockwise/empanada_function.py b/dacapo/blockwise/empanada_function.py index 4175f8577..301a282e8 100644 --- a/dacapo/blockwise/empanada_function.py +++ b/dacapo/blockwise/empanada_function.py @@ -45,6 +45,35 @@ def segment_function(input_array, block, **parameters): + """ + Segment a 3D block using the empanada-napari library. + + Args: + input_array (np.ndarray): The 3D array to segment. + block (dask.array.core.Block): The block object. + **parameters: Parameters for the empanada-napari segmenter. + Returns: + np.ndarray: The segmented 3D array. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from dask import array as da + >>> from dacapo.blockwise.empanada_function import segment_function + >>> input_array = np.random.rand(64, 64, 64) + >>> block = da.from_array(input_array, chunks=(32, 32, 32)) + >>> segmented_array = segment_function(block, model_config="MitoNet_v1") + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 + """ vols, class_names = [], [] for vol, class_name, _ in empanada_segmenter( input_array[block.read_roi], **parameters @@ -60,12 +89,66 @@ def segment_function(input_array, block, **parameters): def stack_inference(engine, volume, axis_name): + """ + Perform inference on a single axis of a 3D volume. + + Args: + engine (Engine3d): The engine object. + volume (np.ndarray): The 3D volume to segment. + axis_name (str): The axis name to segment. + Returns: + tuple: The stack, axis name, and trackers dictionary. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from empanada_napari.inference import Engine3d + >>> from dacapo.blockwise.empanada_function import stack_inference + >>> model_config = "MitoNet_v1" + >>> use_gpu = True + >>> use_quantized = False + >>> engine = Engine3d(model_config, use_gpu=use_gpu, use_quantized=use_quantized) + >>> volume = np.random.rand(64, 64, 64) + >>> axis_name = "xy" + >>> stack, axis_name, trackers_dict = stack_inference(engine, volume, axis_name) + Note: + The `axis_name` parameter should be one of the following: + """ stack, trackers = engine.infer_on_axis(volume, axis_name) trackers_dict = {axis_name: trackers} return stack, axis_name, trackers_dict def orthoplane_inference(engine, volume): + """ + Perform inference on the orthogonal planes of a 3D volume. + + Args: + engine (Engine3d): The engine object. + volume (np.ndarray): The 3D volume to segment. + Returns: + dict: The trackers dictionary. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from empanada_napari.inference import Engine3d + >>> from dacapo.blockwise.empanada_function import orthoplane_inference + >>> model_config = "MitoNet_v1" + >>> use_gpu = True + >>> use_quantized = False + >>> engine = Engine3d(model_config, use_gpu=use_gpu, use_quantized=use_quantized) + >>> volume = np.random.rand(64, 64, 64) + >>> trackers_dict = orthoplane_inference(engine, volume) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + """ trackers_dict = {} for axis_name in ["xy", "xz", "yz"]: stack, trackers = engine.infer_on_axis(volume, axis_name) @@ -103,6 +186,93 @@ def empanada_segmenter( pixel_vote_thr=1, allow_one_view=False, ): + """ + Segment a 3D volume using the empanada-napari library. + + Args: + image (np.ndarray): The 3D volume to segment. + model_config (str): The model configuration to use. + use_gpu (bool): Whether to use the GPU. + use_quantized (bool): Whether to use quantized inference. + multigpu (bool): Whether to use multiple GPUs. + downsampling (int): The downsampling factor. + confidence_thr (float): The confidence threshold. + center_confidence_thr (float): The center confidence threshold. + min_distance_object_centers (int): The minimum distance between object centers. + fine_boundaries (bool): Whether to use fine boundaries. + semantic_only (bool): Whether to use semantic segmentation only. + median_slices (int): The number of median slices. + min_size (int): The minimum size of objects. + min_extent (int): The minimum extent. + maximum_objects_per_class (int): The maximum number of objects per class. + inference_plane (str): The inference plane. + orthoplane (bool): Whether to use orthoplane inference. + return_panoptic (bool): Whether to return the panoptic segmentation. + pixel_vote_thr (int): The pixel vote threshold. + allow_one_view (bool): Whether to allow one view. + Returns: + tuple: The volume, class name, and tracker. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> import numpy as np + >>> from empanada_napari.inference import Engine3d + >>> from dacapo.blockwise.empanada_function import empanada_segmenter + >>> image = np.random.rand(64, 64, 64) + >>> model_config = "MitoNet_v1" + >>> use_gpu = True + >>> use_quantized = False + >>> multigpu = False + >>> downsampling = 1 + >>> confidence_thr = 0.5 + >>> center_confidence_thr = 0.1 + >>> min_distance_object_centers = 21 + >>> fine_boundaries = True + >>> semantic_only = False + >>> median_slices = 11 + >>> min_size = 10000 + >>> min_extent = 50 + >>> maximum_objects_per_class = 1000000 + >>> inference_plane = "xy" + >>> orthoplane = True + >>> return_panoptic = False + >>> pixel_vote_thr = 1 + >>> allow_one_view = False + >>> for vol, class_name, tracker in empanada_segmenter( + ... image, + ... model_config=model_config, + ... use_gpu=use_gpu, + ... use_quantized=use_quantized, + ... multigpu=multigpu, + ... downsampling=downsampling, + ... confidence_thr=confidence_thr, + ... center_confidence_thr=center_confidence_thr, + ... min_distance_object_centers=min_distance_object_centers, + ... fine_boundaries=fine_boundaries, + ... semantic_only=semantic_only, + ... median_slices=median_slices, + ... min_size=min_size, + ... min_extent=min_extent, + ... maximum_objects_per_class=maximum_objects_per_class, + ... inference_plane=inference_plane, + ... orthoplane=orthoplane, + ... return_panoptic=return_panoptic, + ... pixel_vote_thr=pixel_vote_thr, + ... allow_one_view=allow_one_view + ... ): + ... print(vol.shape, class_name, tracker) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 + + """ # load the model config model_config = read_yaml(model_configs[model_config]) min_size = int(min_size) @@ -144,6 +314,22 @@ def empanada_segmenter( ) def start_postprocess_worker(*args): + """ + Start the postprocessing worker. + + Args: + *args: The arguments to pass to the worker. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in start_postprocess_worker(*args): + ... print(vol.shape, class_name, tracker) + Note: + The `args` parameter should be a tuple of arguments. + + """ trackers_dict = args[0][2] for vol, class_name, tracker in stack_postprocessing( trackers_dict, @@ -157,6 +343,21 @@ def start_postprocess_worker(*args): yield vol, class_name, tracker def start_consensus_worker(trackers_dict): + """ + Start the consensus worker. + + Args: + trackers_dict (dict): The trackers dictionary. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in start_consensus_worker(trackers_dict): + ... print(vol.shape, class_name, tracker) + Note: + The `trackers_dict` parameter should be a dictionary of trackers. + """ for vol, class_name, tracker in tracker_consensus( trackers_dict, model_config, @@ -202,8 +403,34 @@ def stack_postprocessing( min_extent=4, dtype=np.uint32, ): - r"""Relabels and filters each class defined in trackers. Yields a numpy + """ + Relabels and filters each class defined in trackers. Yields a numpy or zarr volume along with the name of the class that is segmented. + + Args: + trackers (dict): The trackers dictionary. + model_config (str): The model configuration to use. + label_divisor (int): The label divisor. + min_size (int): The minimum size of objects. + min_extent (int): The minimum extent of objects. + dtype (type): The data type. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in stack_postprocessing(trackers, model_config): + ... print(vol.shape, class_name, tracker) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 """ thing_list = model_config["thing_list"] class_names = model_config["class_names"] @@ -244,8 +471,36 @@ def tracker_consensus( min_extent=4, dtype=np.uint32, ): - r"""Calculate the orthoplane consensus from trackers. Yields a numpy + """ + Calculate the orthoplane consensus from trackers. Yields a numpy or zarr volume along with the name of the class that is segmented. + + Args: + trackers (dict): The trackers dictionary. + model_config (str): The model configuration to use. + pixel_vote_thr (int): The pixel vote threshold. + cluster_iou_thr (float): The cluster IoU threshold. + allow_one_view (bool): Whether to allow one view. + min_size (int): The minimum size of objects. + min_extent (int): The minimum extent of objects. + dtype (type): The data type. + Returns: + generator: The generator object. + Raises: + ImportError: If empanada-napari is not installed. + Examples: + >>> for vol, class_name, tracker in tracker_consensus(trackers, model_config): + ... print(vol.shape, class_name, tracker) + Note: + The `model_config` parameter should be one of the following: + - MitoNet_v1 + - MitoNet_v2 + - MitoNet_v3 + - MitoNet_v4 + - MitoNet_v5 + - MitoNet_v6 + Reference: + - doi: 10.1016/j.cels.2022.12.006 """ labels = model_config["labels"] thing_list = model_config["thing_list"] diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 739e8699a..c8b666734 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -1,5 +1,5 @@ import sys -from pathlib import Path +from upath import UPath as Path from typing import Optional import torch @@ -36,6 +36,14 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the predict worker. + + The predict worker is used to apply a trained model to a dataset. + + Args: + log_level (str): The log level to use for logging. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -68,7 +76,19 @@ def start_worker( input_dataset: str, output_container: Path | str, output_dataset: str, + return_io_loop: Optional[bool] = False, ): + """ + Start a worker to apply a trained model to a dataset. + + Args: + run_name (str): The name of the run to apply. + iteration (int or None): The training iteration of the model to use for prediction. + input_container (Path | str): The input container. + input_dataset (str): The input dataset. + output_container (Path | str): The output container. + output_dataset (str): The output dataset. + """ compute_context = create_compute_context() device = compute_context.device @@ -150,34 +170,40 @@ def start_worker( voxel_size=output_voxel_size, ) - daisy_client = daisy.Client() + def io_loop(): + daisy_client = daisy.Client() - while True: - with daisy_client.acquire_block() as block: - if block is None: - return + while True: + with daisy_client.acquire_block() as block: + if block is None: + return - print(f"Processing block {block}") + print(f"Processing block {block}") - chunk_request = request.copy() - chunk_request[raw].roi = block.read_roi - chunk_request[prediction].roi = block.write_roi + chunk_request = request.copy() + chunk_request[raw].roi = block.read_roi + chunk_request[prediction].roi = block.write_roi - with gp.build(pipeline): - batch = pipeline.request_batch(chunk_request) - # prediction: (1, [c,] d, h, w) - output = batch.arrays[prediction].data.squeeze() + with gp.build(pipeline): + batch = pipeline.request_batch(chunk_request) + # prediction: (1, [c,] d, h, w) + output = batch.arrays[prediction].data.squeeze() - # convert to uint8 if necessary: - if output_array.dtype == np.uint8: - if "sigmoid" not in str(model.eval_activation).lower(): - # assume output is in [-1, 1] - output += 1 - output /= 2 - output *= 255 - output = output.clip(0, 255) - output = output.astype(np.uint8) - output_array[block.write_roi] = output + # convert to uint8 if necessary: + if output_array.dtype == np.uint8: + if "sigmoid" not in str(model.eval_activation).lower(): + # assume output is in [-1, 1] + output += 1 + output /= 2 + output *= 255 + output = output.clip(0, 255) + output = output.astype(np.uint8) + output_array[block.write_roi] = output + + if return_io_loop: + return io_loop + else: + io_loop() def spawn_worker( @@ -186,15 +212,28 @@ def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: run_name (str): The name of the run to apply. iteration (int or None): The training iteration of the model to use for prediction. input_array_identifier (LocalArrayIdentifier): The raw data to predict on. output_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Returns: + Callable: The function to run the worker. """ compute_context = create_compute_context() + if not compute_context.distribute_workers: + return start_worker( + run_name, + iteration, + input_array_identifier.container, + input_array_identifier.dataset, + output_array_identifier.container, + output_array_identifier.dataset, + return_io_loop=True, + ) # Make the command for the worker to run command = [ @@ -219,7 +258,9 @@ def spawn_worker( print("Defining worker with command: ", compute_context.wrap_command(command)) def run_worker(): - # Run the worker in the given compute context + """ + Run the worker in the given compute context. + """ print("Running worker with command: ", command) compute_context.execute(command) diff --git a/dacapo/blockwise/relabel_worker.py b/dacapo/blockwise/relabel_worker.py index b374f7120..423a878f7 100644 --- a/dacapo/blockwise/relabel_worker.py +++ b/dacapo/blockwise/relabel_worker.py @@ -24,6 +24,12 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the relabel worker. + + Args: + log_level (str): The log level to use. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -40,9 +46,18 @@ def start_worker( output_container, output_dataset, tmpdir, + return_io_loop=False, *args, **kwargs, ): + """ + Start the relabel worker. + + Args: + output_container (str): The output container + output_dataset (str): The output dataset + tmpdir (str): The temporary directory + """ client = daisy.Client() array_out = open_ds(output_container, output_dataset, mode="a") @@ -50,22 +65,38 @@ def start_worker( components = find_components(nodes, edges) - while True: - with client.acquire_block() as block: - if block is None: - break + def io_loop(): + client = daisy.Client() + while True: + with client.acquire_block() as block: + if block is None: + break + + try: + relabel_in_block(array_out, nodes, components, block) + except OSError as e: + logging.error( + f"Failed to relabel block {block.write_roi}: {e}. Trying again." + ) + sleep(1) + relabel_in_block(array_out, nodes, components, block) - try: - relabel_in_block(array_out, nodes, components, block) - except OSError as e: - logging.error( - f"Failed to relabel block {block.write_roi}: {e}. Trying again." - ) - sleep(1) - relabel_in_block(array_out, nodes, components, block) + if return_io_loop: + return io_loop + else: + io_loop() def relabel_in_block(array_out, old_values, new_values, block): + """ + Relabel the array in the given block. + + Args: + array_out (np.ndarray): The output array + old_values (np.ndarray): The old values + new_values (np.ndarray): The new values + block (daisy.Block): The block + """ a = array_out.to_ndarray(block.write_roi) # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input if old_values.size > 0: @@ -74,6 +105,15 @@ def relabel_in_block(array_out, old_values, new_values, block): def find_components(nodes, edges): + """ + Find the components. + + Args: + nodes (np.ndarray): The nodes + edges (np.ndarray): The edges + Returns: + List[int]: The components + """ # scipy disjoint_set = DisjointSet(nodes) for edge in edges: @@ -82,6 +122,14 @@ def find_components(nodes, edges): def read_cross_block_merges(tmpdir): + """ + Read the cross block merges. + + Args: + tmpdir (str): The temporary directory + Returns: + Tuple[np.ndarray, np.ndarray]: The nodes and edges + """ block_files = glob(os.path.join(tmpdir, "block_*.npz")) nodes = [] @@ -100,14 +148,25 @@ def spawn_worker( *args, **kwargs, ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: output_array_identifier (LocalArrayIdentifier): The output array identifier tmpdir (str): The temporary directory + Returns: + Callable: The function to run the worker """ compute_context = create_compute_context() + if not compute_context.distribute_workers: + return start_worker( + output_array_identifier.container, + output_array_identifier.dataset, + tmpdir, + return_io_loop=True, + ) + # Make the command for the worker to run command = [ # "python", @@ -123,7 +182,9 @@ def spawn_worker( ] def run_worker(): - # Run the worker in the given compute context + """ + Run the worker in the given compute context. + """ compute_context.execute(command) return run_worker diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index ddea38280..b2a015a75 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path import shutil import tempfile import time @@ -28,46 +28,34 @@ def run_blockwise( *args, **kwargs, ): - """Run a function in parallel over a large volume. + """ + Run a function in parallel over a large volume. Args: - worker_file (``str`` or ``Path``): - The path to the file containing the necessary worker functions: ``spawn_worker`` and ``start_worker``. Optionally, the file can also contain a ``check_function`` and an ``init_callback_fn``. - total_roi (``Roi``): The ROI to process. - read_roi (``Roi``): The ROI to read from for a block. - write_roi (``Roi``): The ROI to write to for a block. - num_workers (``int``): - - The number of workers to use. - + The number of workers to use. max_retries (``int``): - - The maximum number of times a task will be retried if failed - (either due to failed post check or application crashes or network - failure) - + The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) *args: - Additional positional arguments to pass to ``worker_function``. - **kwargs: - Additional keyword arguments to pass to ``worker_function``. - Returns: - - ``Bool``. + ``Bool``. + Examples: + >>> run_blockwise(worker_file, total_roi, read_roi, write_roi, num_workers, max_retries, timeout, upstream_tasks) """ @@ -100,61 +88,46 @@ def segment_blockwise( max_retries: int = 2, timeout=None, upstream_tasks=None, + keep_tmpdir=False, *args, **kwargs, ): - """Run a segmentation function in parallel over a large volume. + """ + Run a segmentation function in parallel over a large volume. Args: - - segment_function_file (``str`` or ``Path``): - - The path to the file containing the necessary worker functions: - ``spawn_worker`` and ``start_worker``. - Optionally, the file can also contain a ``check_function`` and an ``init_callback_fn``. - - context (``Coordinate``): - - The context to add to the read and write ROI. - - total_roi (``Roi``): - The ROI to process. - - read_roi (``Roi``): - The ROI to read from for a block. - - write_roi (``Roi``): - The ROI to write to for a block. - - num_workers (``int``): - - The number of workers to use. - - max_retries (``int``): - - The maximum number of times a task will be retried if failed - (either due to failed post check or application crashes or network - failure) - - timeout (``int``): - - The maximum time in seconds to wait for a worker to complete a task. - - upstream_tasks (``List``): - - List of upstream tasks. - - *args: - - Additional positional arguments to pass to ``worker_function``. - - **kwargs: - - Additional keyword arguments to pass to ``worker_function``. - + segment_function_file (``str`` or ``Path``): + The path to the file containing the necessary worker functions: + ``spawn_worker`` and ``start_worker``. + Optionally, the file can also contain a ``check_function`` and an ``init_callback_fn``. + context (``Coordinate``): + The context to add to the read and write ROI. + total_roi (``Roi``): + The ROI to process. + read_roi (``Roi``): + The ROI to read from for a block. + write_roi (``Roi``): + The ROI to write to for a block. + num_workers (``int``): + The number of workers to use. + max_retries (``int``): + The maximum number of times a task will be retried if failed + (either due to failed post check or application crashes or network + failure) + timeout (``int``): + The maximum time in seconds to wait for a worker to complete a task. + upstream_tasks (``List``): + List of upstream tasks. + keep_tmpdir (``bool``): + Whether to keep the temporary directory. + *args: + Additional positional arguments to pass to ``worker_function``. + **kwargs: + Additional keyword arguments to pass to ``worker_function``. Returns: - ``Bool``. + Examples: + >>> segment_blockwise(segment_function_file, context, total_roi, read_roi, write_roi, num_workers, max_retries, timeout, upstream_tasks) """ options = Options.instance() if not options.runs_base_dir.exists(): @@ -219,5 +192,50 @@ def segment_blockwise( success = success and daisy.run_blockwise([task]) - shutil.rmtree(tmpdir, ignore_errors=True) + if success and not keep_tmpdir: + shutil.rmtree(tmpdir, ignore_errors=True) + else: + # Write a relabel script to tmpdir + output_container = kwargs["output_array_identifier"].container + output_dataset = kwargs["output_array_identifier"].dataset + out_string = "from dacapo.blockwise import DaCapoBlockwiseTask\n" + out_string += ( + "from dacapo.store.local_array_store import LocalArrayIdentifier\n" + ) + out_string += "import daisy\n" + out_string += "from funlib.geometry import Roi, Coordinate\n" + out_string += "from upath import UPath as Path\n" + out_string += f"output_array_identifier = LocalArrayIdentifier(Path({output_container}), {output_dataset})\n" + out_string += ( + f"total_roi = Roi({total_roi.get_begin()}, {total_roi.get_shape()})\n" + ) + out_string += ( + f"read_roi = Roi({read_roi.get_begin()}, {read_roi.get_shape()})\n" + ) + out_string += ( + f"write_roi = Roi({write_roi.get_begin()}, {write_roi.get_shape()})\n" + ) + out_string += "task = DaCapoBlockwiseTask(\n" + out_string += f' "{str(Path(Path(dacapo.blockwise.__file__).parent, "relabel_worker.py"))}"),\n' + out_string += " total_roi,\n" + out_string += " read_roi,\n" + out_string += " write_roi,\n" + out_string += f" {num_workers},\n" + out_string += f" {max_retries},\n" + out_string += f" {timeout},\n" + out_string += f" tmpdir={tmpdir},\n" + out_string += f" output_array_identifier=output_array_identifier,\n" + out_string += ")\n" + out_string += "success = daisy.run_blockwise([task])\n" + out_string += "if success:\n" + out_string += f" shutil.rmtree({tmpdir}, ignore_errors=True)\n" + out_string += "else:\n" + out_string += ' print("Relabeling failed")\n' + with open(Path(tmpdir, "relabel.py"), "w") as f: + f.write(out_string) + raise RuntimeError( + f"Blockwise segmentation failed. Can rerun with merge files stored at:\n\t{tmpdir}" + f"Use read_roi: {read_roi} and write_roi: {write_roi} to rerun." + f"Or simply run the script at {Path(tmpdir, 'relabel.py')}" + ) return success diff --git a/dacapo/blockwise/segment_worker.py b/dacapo/blockwise/segment_worker.py index da1e0c098..0ecec2e19 100644 --- a/dacapo/blockwise/segment_worker.py +++ b/dacapo/blockwise/segment_worker.py @@ -1,7 +1,7 @@ from importlib.machinery import SourceFileLoader import logging import os -from pathlib import Path +from upath import UPath as Path import sys import click import daisy @@ -26,6 +26,12 @@ default="INFO", ) def cli(log_level): + """ + CLI for running the segment worker. + + Args: + log_level (str): The log level to use. + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -48,8 +54,10 @@ def start_worker( output_dataset: str, tmpdir: str, function_path: str, + return_io_loop: bool = False, ): - """Start a worker to run a segment function on a given dataset. + """ + Start a worker to run a segment function on a given dataset. Args: input_container (str): The input container. @@ -58,6 +66,7 @@ def start_worker( output_dataset (str): The output dataset. tmpdir (str): The temporary directory. function_path (str): The path to the segment function. + return_io_loop (bool): Whether to return the io loop or run it. """ print("Starting worker") @@ -91,91 +100,97 @@ def start_worker( parameters.update(yaml.safe_load(f)) # wait for blocks to run pipeline - client = daisy.Client() - num_voxels_in_block = None - - while True: - with client.acquire_block() as block: - if block is None: - break - if num_voxels_in_block is None: - num_voxels_in_block = np.prod(block.write_roi.size) - - segmentation = segment_function(input_array, block, **parameters) - - assert ( - segmentation.dtype == np.uint64 - ), "Instance segmentations returned by segment_function is expected to be uint64" - - id_bump = block.block_id[1] * num_voxels_in_block - segmentation += id_bump - segmentation[segmentation == id_bump] = 0 - - # wrap segmentation into daisy array - segmentation = Array( - segmentation, roi=block.read_roi, voxel_size=input_array.voxel_size - ) - - # store segmentation in out array - output_array[block.write_roi] = segmentation[block.write_roi] - - neighbor_roi = block.write_roi.grow( - input_array.voxel_size, input_array.voxel_size - ) - - # clip segmentation to 1-voxel context - segmentation = segmentation.to_ndarray(roi=neighbor_roi, fill_value=0) - neighbors = output_array._daisy_array.to_ndarray( - roi=neighbor_roi, fill_value=0 - ) - - unique_pairs = [] - - for d in range(3): - slices_neg = tuple( - slice(None) if dd != d else slice(0, 1) for dd in range(3) - ) - slices_pos = tuple( - slice(None) if dd != d else slice(-1, None) for dd in range(3) + def io_loop(): + client = daisy.Client() + num_voxels_in_block = None + + while True: + with client.acquire_block() as block: + if block is None: + break + if num_voxels_in_block is None: + num_voxels_in_block = np.prod(block.write_roi.size) + + segmentation = segment_function(input_array, block, **parameters) + + assert ( + segmentation.dtype == np.uint64 + ), "Instance segmentations returned by segment_function is expected to be uint64" + + id_bump = block.block_id[1] * num_voxels_in_block + segmentation += id_bump + segmentation[segmentation == id_bump] = 0 + + # wrap segmentation into daisy array + segmentation = Array( + segmentation, roi=block.read_roi, voxel_size=input_array.voxel_size ) - pairs_neg = np.array( - [ - segmentation[slices_neg].flatten(), - neighbors[slices_neg].flatten(), - ] - ) - pairs_neg = pairs_neg.transpose() + # store segmentation in out array + output_array[block.write_roi] = segmentation[block.write_roi] - pairs_pos = np.array( - [ - segmentation[slices_pos].flatten(), - neighbors[slices_pos].flatten(), - ] + neighbor_roi = block.write_roi.grow( + input_array.voxel_size, input_array.voxel_size ) - pairs_pos = pairs_pos.transpose() - unique_pairs.append( - np.unique(np.concatenate([pairs_neg, pairs_pos]), axis=0) + # clip segmentation to 1-voxel context + segmentation = segmentation.to_ndarray(roi=neighbor_roi, fill_value=0) + neighbors = output_array._daisy_array.to_ndarray( + roi=neighbor_roi, fill_value=0 ) - unique_pairs = np.concatenate(unique_pairs) - zero_u = unique_pairs[:, 0] == 0 # type: ignore - zero_v = unique_pairs[:, 1] == 0 # type: ignore - non_zero_filter = np.logical_not(np.logical_or(zero_u, zero_v)) - - edges = unique_pairs[non_zero_filter] - nodes = np.unique(edges) - - assert os.path.exists(tmpdir) - path = os.path.join(tmpdir, f"block_{block.block_id[1]}.npz") - print(f"Writing ids to {path}") - with open(path, "wb") as f: - np.savez_compressed( - f, - nodes=nodes, - edges=edges, - ) + unique_pairs = [] + + for d in range(3): + slices_neg = tuple( + slice(None) if dd != d else slice(0, 1) for dd in range(3) + ) + slices_pos = tuple( + slice(None) if dd != d else slice(-1, None) for dd in range(3) + ) + + pairs_neg = np.array( + [ + segmentation[slices_neg].flatten(), + neighbors[slices_neg].flatten(), + ] + ) + pairs_neg = pairs_neg.transpose() + + pairs_pos = np.array( + [ + segmentation[slices_pos].flatten(), + neighbors[slices_pos].flatten(), + ] + ) + pairs_pos = pairs_pos.transpose() + + unique_pairs.append( + np.unique(np.concatenate([pairs_neg, pairs_pos]), axis=0) + ) + + unique_pairs = np.concatenate(unique_pairs) + zero_u = unique_pairs[:, 0] == 0 # type: ignore + zero_v = unique_pairs[:, 1] == 0 # type: ignore + non_zero_filter = np.logical_not(np.logical_or(zero_u, zero_v)) + + edges = unique_pairs[non_zero_filter] + nodes = np.unique(edges) + + assert os.path.exists(tmpdir) + path = os.path.join(tmpdir, f"block_{block.block_id[1]}.npz") + print(f"Writing ids to {path}") + with open(path, "wb") as f: + np.savez_compressed( + f, + nodes=nodes, + edges=edges, + ) + + if return_io_loop: + return io_loop + else: + io_loop() def spawn_worker( @@ -184,14 +199,27 @@ def spawn_worker( tmpdir: str, function_path: str, ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Returns: + Callable: The function to run the worker. """ compute_context = create_compute_context() + if not compute_context.distribute_workers: + return start_worker( + input_array_identifier.container, + input_array_identifier.dataset, + output_array_identifier.container, + output_array_identifier.dataset, + tmpdir, + function_path, + return_io_loop=True, + ) # Make the command for the worker to run command = [ @@ -214,7 +242,9 @@ def spawn_worker( ] def run_worker(): - # Run the worker in the given compute context + """ + Run the worker in the given compute context. + """ compute_context.execute(command) return run_worker diff --git a/dacapo/blockwise/threshold_worker.py b/dacapo/blockwise/threshold_worker.py index 3ff08c1e6..3e05f13cc 100644 --- a/dacapo/blockwise/threshold_worker.py +++ b/dacapo/blockwise/threshold_worker.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path import sys from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier @@ -49,7 +49,19 @@ def start_worker( output_container: Path | str, output_dataset: str, threshold: float = 0.0, + return_io_loop: bool = False, ): + """ + Start the threshold worker. + + Args: + input_container (Path | str): The input container. + input_dataset (str): The input dataset. + output_container (Path | str): The output container. + output_dataset (str): The output dataset. + threshold (float): The threshold. + + """ # get arrays input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) input_array = ZarrArray.open_from_array_identifier(input_array_identifier) @@ -59,19 +71,25 @@ def start_worker( ) output_array = ZarrArray.open_from_array_identifier(output_array_identifier) - # wait for blocks to run pipeline - client = daisy.Client() + def io_loop(): + # wait for blocks to run pipeline + client = daisy.Client() + + while True: + print("getting block") + with client.acquire_block() as block: + if block is None: + break - while True: - print("getting block") - with client.acquire_block() as block: - if block is None: - break + # write to output array + output_array[block.write_roi] = ( + input_array[block.write_roi] > threshold + ).astype(np.uint8) - # write to output array - output_array[block.write_roi] = ( - input_array[block.write_roi] > threshold - ).astype(np.uint8) + if return_io_loop: + return io_loop + else: + io_loop() def spawn_worker( @@ -79,18 +97,28 @@ def spawn_worker( output_array_identifier: "LocalArrayIdentifier", threshold: float = 0.0, ): - """Spawn a worker to predict on a given dataset. + """ + Spawn a worker to predict on a given dataset. Args: - model (Model): The model to use for prediction. - raw_array (Array): The raw data to predict on. - prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + input_array_identifier (LocalArrayIdentifier): The raw data to predict on. + output_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + threshold (float): The threshold. + Returns: + Callable: The function to run the worker. """ compute_context = create_compute_context() + if not compute_context.distribute_workers: + return start_worker( + input_array_identifier.container, + input_array_identifier.dataset, + output_array_identifier.container, + output_array_identifier.dataset, + return_io_loop=True, + ) # Make the command for the worker to run command = [ - # "python", sys.executable, path, "start-worker", @@ -107,7 +135,9 @@ def spawn_worker( ] def run_worker(): - # Run the worker in the given compute context + """ + Run the worker in the given compute context. + """ compute_context.execute(command) return run_worker diff --git a/dacapo/blockwise/watershed_function.py b/dacapo/blockwise/watershed_function.py index 0c5deae6f..05440e228 100644 --- a/dacapo/blockwise/watershed_function.py +++ b/dacapo/blockwise/watershed_function.py @@ -5,6 +5,30 @@ def segment_function(input_array, block, offsets, bias): + """ + Segment the input array using the multicut watershed algorithm. + + Args: + input_array (np.ndarray): The input array. + block (daisy.Block): The block to be processed. + offsets (List[Tuple[int]]): The offsets. + bias (float): The bias. + Returns: + np.ndarray: The segmented array. + Examples: + >>> input_array = np.random.rand(128, 128, 128) + >>> total_roi = daisy.Roi((0, 0, 0), (128, 128, 128)) + >>> read_roi = daisy.Roi((0, 0, 0), (64, 64, 64)) + >>> write_roi = daisy.Roi((0, 0, 0), (32, 32, 32)) + >>> block_id = 0 + >>> task_id = "task_id" + >>> block = daisy.Block(total_roi, read_roi, write_roi, block_id, task_id) + >>> offsets = [(0, 1, 0), (1, 0, 0), (0, 0, 1)] + >>> bias = 0.1 + >>> segmentation = segment_function(input_array, block, offsets, bias) + Note: + DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input + """ # if a previous segmentation is provided, it must have a "grid graph" # in its metadata. pred_data = input_array[block.read_roi] diff --git a/dacapo/cli.py b/dacapo/cli.py index a54fcd058..f4a2a6c41 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path from typing import Optional import numpy as np @@ -7,7 +7,6 @@ import click import logging from funlib.geometry import Roi, Coordinate -from funlib.persistence import open_ds from dacapo.experiments.datasplits.datasets.dataset import Dataset from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( PostProcessorParameters, @@ -29,6 +28,42 @@ default="INFO", ) def cli(log_level): + """ + Command-line interface for the DACAPO application. + + Args: + log_level (str): The desired log level for the application. + Examples: + To train a model, run: + ``` + dacapo train --run-name my_run + ``` + + To validate a model, run: + ``` + dacapo validate --run-name my_run --iteration 100 + ``` + + To apply a model, run: + ``` + dacapo apply --run-name my_run --input-container /path/to/input --input-dataset my_dataset --output-path /path/to/output + ``` + + To predict with a model, run: + ``` + dacapo predict --run-name my_run --iteration 100 --input-container /path/to/input --input-dataset my_dataset --output-path /path/to/output + ``` + + To run a blockwise operation, run: + ``` + dacapo run-blockwise --input-container /path/to/input --input-dataset my_dataset --output-container /path/to/output --output-dataset my_output --worker-file /path/to/worker.py --total-roi [0:100,0:100,0:100] --read-roi-size [10,10,10] --write-roi-size [10,10,10] --num-workers 16 + ``` + + To segment blockwise, run: + ``` + dacapo segment-blockwise --input-container /path/to/input --input-dataset my_dataset --output-container /path/to/output --output-dataset my_output --segment-function-file /path/to/segment_function.py --total-roi [0:100,0:100,0:100] --read-roi-size [10,10,10] --write-roi-size [10,10,10] --num-workers 16 + ``` + """ logging.basicConfig(level=getattr(logging, log_level.upper())) @@ -70,13 +105,49 @@ def validate(run_name, iteration): "--input_container", required=True, type=click.Path(exists=True, file_okay=False), + help="The path to the input container.", +) +@click.option( + "-id", + "--input_dataset", + required=True, + type=str, + help="The name of the input dataset.", +) +@click.option( + "-op", + "--output_path", + required=True, + type=click.Path(file_okay=False), + help="The path to the output directory.", +) +@click.option( + "-vd", + "--validation_dataset", + type=str, + default=None, + help="The name of the validation dataset.", +) +@click.option( + "-c", + "--criterion", + default="voi", + help="The criterion to use for applying the run.", +) +@click.option( + "-i", + "--iteration", + type=int, + default=None, + help="The iteration of the model to use for prediction.", +) +@click.option( + "-p", + "--parameters", + type=str, + default=None, + help="The parameters for the post-processor.", ) -@click.option("-id", "--input_dataset", required=True, type=str) -@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) -@click.option("-vd", "--validation_dataset", type=str, default=None) -@click.option("-c", "--criterion", default="voi") -@click.option("-i", "--iteration", type=int, default=None) -@click.option("-p", "--parameters", type=str, default=None) @click.option( "-roi", "--roi", @@ -84,9 +155,26 @@ def validate(run_name, iteration): required=False, help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]", ) -@click.option("-w", "--num_workers", type=int, default=30) -@click.option("-dt", "--output_dtype", type=str, default="uint8") -@click.option("-ow", "--overwrite", is_flag=True) +@click.option( + "-w", + "--num_workers", + type=int, + default=30, + help="The number of workers to use for prediction.", +) +@click.option( + "-dt", + "--output_dtype", + type=str, + default="uint8", + help="The output data type.", +) +@click.option( + "-ow", + "--overwrite", + is_flag=True, + help="Whether to overwrite existing output files.", +) def apply( run_name: str, input_container: Path | str, @@ -101,6 +189,30 @@ def apply( output_dtype: np.dtype | str = "uint8", overwrite: bool = True, ): + """ + Apply a trained run to an input dataset. + + Args: + run_name (str): The name of the run to apply. + input_container (Path | str): The path to the input container. + input_dataset (str): The name of the input dataset. + output_path (Path | str): The path to the output directory. + validation_dataset (Dataset | str, optional): The name of the validation dataset. Defaults to None. + criterion (str, optional): The criterion to use for applying the run. Defaults to "voi". + iteration (int, optional): The iteration of the model to use for prediction. Defaults to None. + parameters (PostProcessorParameters | str, optional): The parameters for the post-processor. Defaults to None. + roi (Roi | str, optional): The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]. Defaults to None. + num_workers (int, optional): The number of workers to use for prediction. Defaults to 30. + output_dtype (np.dtype | str, optional): The output data type. Defaults to "uint8". + overwrite (bool, optional): Whether to overwrite existing output files. Defaults to True. + Raises: + ValueError: If the run_name is not valid. + Examples: + To apply a trained run to an input dataset, run: + ``` + dacapo apply --run-name my_run --input-container /path/to/input --input-dataset my_dataset --output-path /path/to/output + ``` + """ dacapo.apply( run_name, input_container, @@ -133,9 +245,22 @@ def apply( "--input_container", required=True, type=click.Path(exists=True, file_okay=False), + help="The path to the input container.", +) +@click.option( + "-id", + "--input_dataset", + required=True, + type=str, + help="The name of the input dataset.", +) +@click.option( + "-op", + "--output_path", + required=True, + type=click.Path(file_okay=False), + help="The path to the output directory.", ) -@click.option("-id", "--input_dataset", required=True, type=str) -@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False)) @click.option( "-roi", "--output_roi", @@ -143,9 +268,22 @@ def apply( required=False, help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]", ) -@click.option("-w", "--num_workers", type=int, default=30) -@click.option("-dt", "--output_dtype", type=str, default="uint8") -@click.option("-ow", "--overwrite", is_flag=True) +@click.option( + "-w", + "--num_workers", + type=int, + default=30, + help="The number of workers to use for prediction.", +) +@click.option( + "-dt", "--output_dtype", type=str, default="uint8", help="The output data type." +) +@click.option( + "-ow", + "--overwrite", + is_flag=True, + help="Whether to overwrite existing output files.", +) def predict( run_name: str, iteration: int, @@ -157,6 +295,27 @@ def predict( output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): + """ + Apply a trained model to predict on a dataset. + + Args: + run_name (str): The name of the run to apply. + iteration (int): The training iteration of the model to use for prediction. + input_container (Path | str): The path to the input container. + input_dataset (str): The name of the input dataset. + output_path (Path | str): The path to the output directory. + output_roi (Optional[str | Roi], optional): The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]. Defaults to None. + num_workers (int, optional): The number of workers to use for prediction. Defaults to 30. + output_dtype (np.dtype | str, optional): The output data type. Defaults to np.uint8. + overwrite (bool, optional): Whether to overwrite existing output files. Defaults to True. + Raises: + ValueError: If the run_name is not valid. + Examples: + To predict with a model, run: + ``` + dacapo predict --run-name my_run --iteration 100 --input-container /path/to/input --input-dataset my_dataset --output-path /path/to/output + ``` + """ dacapo.predict( run_name, iteration, @@ -181,12 +340,29 @@ def predict( "--input_container", required=True, type=click.Path(exists=True, file_okay=False), + help="The path to the input container.", +) +@click.option( + "-id", + "--input_dataset", + required=True, + type=str, + help="The name of the input dataset.", +) +@click.option( + "-oc", + "--output_container", + required=True, + type=click.Path(file_okay=False), + help="The path to the output container.", ) -@click.option("-id", "--input_dataset", required=True, type=str) @click.option( - "-oc", "--output_container", required=True, type=click.Path(file_okay=False) + "-od", + "--output_dataset", + required=True, + type=str, + help="The name of the output dataset.", ) -@click.option("-od", "--output_dataset", required=True, type=str) @click.option( "-w", "--worker_file", required=True, type=str, help="The path to the worker file." ) @@ -196,7 +372,6 @@ def predict( required=True, type=str, help="The total roi to be processed. Format is [start:end, start:end, ... ] in voxels. Defaults to the roi of the input dataset. Do not use spaces in CLI argument.", - default=None, ) @click.option( "-rr", @@ -212,12 +387,30 @@ def predict( type=str, help="The size of the roi to be written for each block, in the format of [z,y,x] in voxels.", ) -@click.option("-nw", "--num_workers", type=int, default=16) -@click.option("-mr", "--max_retries", type=int, default=2) -@click.option("-t", "--timeout", type=int, default=None) -@click.option("-ow", "--overwrite", is_flag=True, default=True) -@click.option("-co", "-channels_out", type=int, default=None) -@click.option("-dt", "--output_dtype", type=str, default="uint8") +@click.option( + "-nw", "--num_workers", type=int, default=16, help="The number of workers to use." +) +@click.option( + "-mr", "--max_retries", type=int, default=2, help="The maximum number of retries." +) +@click.option("-t", "--timeout", type=int, default=None, help="The timeout in seconds.") +@click.option( + "-ow", + "--overwrite", + is_flag=True, + default=True, + help="Whether to overwrite existing output files.", +) +@click.option( + "-co", + "-channels_out", + type=int, + default=None, + help="The number of output channels.", +) +@click.option( + "-dt", "--output_dtype", type=str, default="uint8", help="The output data type." +) @click.pass_context def run_blockwise( ctx, @@ -238,6 +431,32 @@ def run_blockwise( *args, **kwargs, ): + """ + Run blockwise processing on a dataset. + + Args: + input_container: The path to the input container. + input_dataset: The name of the input dataset. + output_container: The path to the output container. + output_dataset: The name of the output dataset. + worker_file: The path to the worker file. + total_roi: The total roi to be processed. Format is [start:end, start:end, ... ] in voxels. Defaults to the roi of the input dataset. Do not use spaces in CLI argument. + read_roi_size: The size of the roi to be read for each block, in the format of [z,y,x] in voxels. + write_roi_size: The size of the roi to be written for each block, in the format of [z,y,x] in voxels. + num_workers: The number of workers to use. + max_retries: The maximum number of retries. + timeout: The timeout in seconds. + overwrite: Whether to overwrite existing output files. + channels_out: The number of output channels. + output_dtype: The output data type. + Raises: + ValueError: If the run_name is not valid. + Examples: + To run a blockwise operation, run: + ``` + dacapo run-blockwise --input-container /path/to/input --input-dataset my_dataset --output-container /path/to/output --output-dataset my_output --worker-file /path/to/worker.py --total-roi [0:100,0:100,0:100] --read-roi-size [10,10,10] --write-roi-size [10,10,10] --num-workers 16 + ``` + """ # get arbitrary args and kwargs parameters = unpack_ctx(ctx) @@ -291,13 +510,36 @@ def run_blockwise( "--input_container", required=True, type=click.Path(exists=True, file_okay=False), + help="The path to the input container.", +) +@click.option( + "-id", + "--input_dataset", + required=True, + type=str, + help="The name of the input dataset.", +) +@click.option( + "-oc", + "--output_container", + required=True, + type=click.Path(file_okay=False), + help="The path to the output container.", +) +@click.option( + "-od", + "--output_dataset", + required=True, + type=str, + help="The name of the output dataset.", ) -@click.option("-id", "--input_dataset", required=True, type=str) @click.option( - "-oc", "--output_container", required=True, type=click.Path(file_okay=False) + "-sf", + "--segment_function_file", + required=True, + type=click.Path(), + help="The path to the segment function file.", ) -@click.option("-od", "--output_dataset", required=True, type=str) -@click.option("-sf", "--segment_function_file", required=True, type=click.Path()) @click.option( "-tr", "--total_roi", @@ -326,11 +568,27 @@ def run_blockwise( help="The context to be used, in the format of [z,y,x] in voxels. Defaults to the difference between the read and write rois.", default=None, ) -@click.option("-nw", "--num_workers", type=int, default=16) -@click.option("-mr", "--max_retries", type=int, default=2) -@click.option("-t", "--timeout", type=int, default=None) -@click.option("-ow", "--overwrite", is_flag=True, default=True) -@click.option("-co", "--channels_out", type=int, default=None) +@click.option( + "-nw", "--num_workers", type=int, default=16, help="The number of workers to use." +) +@click.option( + "-mr", "--max_retries", type=int, default=2, help="The maximum number of retries." +) +@click.option("-t", "--timeout", type=int, default=None, help="The timeout in seconds.") +@click.option( + "-ow", + "--overwrite", + is_flag=True, + default=True, + help="Whether to overwrite existing output files.", +) +@click.option( + "-co", + "--channels_out", + type=int, + default=None, + help="The number of output channels.", +) @click.pass_context def segment_blockwise( ctx, @@ -351,6 +609,32 @@ def segment_blockwise( *args, **kwargs, ): + """ + Segment the input dataset blockwise using a segment function file. + + Args: + input_container (str): The path to the input container. + input_dataset (str): The name of the input dataset. + output_container (str): The path to the output container. + output_dataset (str): The name of the output dataset. + segment_function_file (str): The path to the segment function file. + total_roi (str): The total roi to be processed. Format is [start:end,start:end,...] in voxels. Defaults to the roi of the input dataset. Do not use spaces in CLI argument. + read_roi_size (str): The size of the roi to be read for each block, in the format of [z,y,x] in voxels. + write_roi_size (str): The size of the roi to be written for each block, in the format of [z,y,x] in voxels. + context (str, optional): The context to be used, in the format of [z,y,x] in voxels. Defaults to the difference between the read and write rois. + num_workers (int, optional): The number of workers to use. Defaults to 16. + max_retries (int, optional): The maximum number of retries. Defaults to 2. + timeout (int, optional): The timeout in seconds. Defaults to None. + overwrite (bool, optional): Whether to overwrite existing output files. Defaults to True. + channels_out (int, optional): The number of output channels. Defaults to None. + Raises: + ValueError: If the run_name is not valid. + Examples: + To segment blockwise, run: + ``` + dacapo segment-blockwise --input-container /path/to/input --input-dataset my_dataset --output-container /path/to/output --output-dataset my_output --segment-function-file /path/to/segment_function.py --total-roi [0:100,0:100,0:100] --read-roi-size [10,10,10] --write-roi-size [10,10,10] --num-workers 16 + ``` + """ # get arbitrary args and kwargs parameters = unpack_ctx(ctx) @@ -403,7 +687,21 @@ def segment_blockwise( def unpack_ctx(ctx): - # print(ctx.args) + """ + Unpacks the context object and returns a dictionary of keyword arguments. + + Args: + ctx (object): The context object containing the arguments. + Returns: + dict: A dictionary of keyword arguments. + Raises: + ValueError: If the run_name is not valid. + Example: + >>> ctx = ... + >>> kwargs = unpack_ctx(ctx) + >>> print(kwargs) + {'arg1': value1, 'arg2': value2, ...} + """ kwargs = { ctx.args[i].lstrip("-"): ctx.args[i + 1] for i in range(0, len(ctx.args), 2) } @@ -413,11 +711,25 @@ def unpack_ctx(ctx): elif v.replace(".", "").isnumeric(): kwargs[k] = float(v) print(f"{k}: {kwargs[k]}") - # print(f"{type(k)}: {k} --> {type(kwargs[k])} {kwargs[k]}") return kwargs def get_rois(total_roi, read_roi_size, write_roi_size, input_array): + """ + Get the ROIs for processing. + + Args: + total_roi (str): The total ROI to be processed. + read_roi_size (str): The size of the ROI to be read for each block. + write_roi_size (str): The size of the ROI to be written for each block. + input_array (ZarrArray): The input array. + Returns: + tuple: A tuple containing the total ROI, read ROI, write ROI, and context. + Raises: + ValueError: If the run_name is not valid. + Example: + >>> total_roi, read_roi, write_roi, context = get_rois(total_roi, read_roi_size, write_roi_size, input_array) + """ if total_roi is not None: # parse the string into a Roi start, end = zip( diff --git a/dacapo/compute_context/bsub.py b/dacapo/compute_context/bsub.py index a3fb6aac5..3ea5dfdba 100644 --- a/dacapo/compute_context/bsub.py +++ b/dacapo/compute_context/bsub.py @@ -1,5 +1,5 @@ import os -from pathlib import Path +from upath import UPath as Path from .compute_context import ComputeContext import daisy @@ -10,6 +10,29 @@ @attr.s class Bsub(ComputeContext): + distribute_workers: Optional[bool] = attr.ib( + default=True, + metadata={ + "help_text": "Whether to distribute the workers across multiple nodes or processes." + }, + ) + """ + The Bsub class is a subclass of the ComputeContext class. It is used to specify the + context in which computations are to be done. Bsub is used to specify that + computations are to be done on a cluster using LSF. + + Attributes: + queue (str): The queue to run on. + num_gpus (int): The number of gpus to train on. Currently only 1 gpu can be used. + num_cpus (int): The number of cpus to use to generate training data. + billing (Optional[str]): Project name that will be paying for this Job. + Methods: + device(): Returns the device on which computations are to be done. + _wrap_command(command): Wraps a command in the context specific command. + Note: + The class is a subclass of the ComputeContext class. + + """ queue: str = attr.ib(default="local", metadata={"help_text": "The queue to run on"}) num_gpus: int = attr.ib( default=1, @@ -33,12 +56,35 @@ class Bsub(ComputeContext): @property def device(self): + """ + A property method that returns the device on which computations are to be done. + + A device can be a CPU, GPU, TPU, etc. It is used to specify the context in which computations are to be done. + + Returns: + str: The device on which computations are to be done. + Examples: + >>> context = Bsub() + >>> device = context.device + """ if self.num_gpus > 0: return "cuda" else: return "cpu" def _wrap_command(self, command): + """ + A helper method to wrap a command in the context specific command. + + Args: + command (List[str]): The command to be wrapped. + Returns: + List[str]: The wrapped command. + Examples: + >>> context = Bsub() + >>> command = ["python", "script.py"] + >>> wrapped_command = context._wrap_command(command) + """ try: client = daisy.Client() basename = str( @@ -53,8 +99,6 @@ def _wrap_command(self, command): f"{self.queue}", "-n", f"{self.num_cpus}", - "-gpu", - f"num={self.num_gpus}", "-J", "dacapo", "-o", @@ -62,6 +106,14 @@ def _wrap_command(self, command): "-e", f"{basename}.err", ] + + ( + [ + "-gpu", + f"num={self.num_gpus}", + ] + if self.num_gpus > 0 + else [] + ) + ( [ "-P", diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index 57b4c4064..4d88f8376 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -1,26 +1,109 @@ from abc import ABC, abstractmethod -import os +from typing import Optional +import attr import subprocess -import sys from dacapo import Options, compute_context class ComputeContext(ABC): + distribute_workers: Optional[bool] = attr.ib( + default=False, + metadata={ + "help_text": "Whether to distribute the workers across multiple nodes or processes." + }, + ) + """ + The ComputeContext class is an abstract base class for defining the context in which computations are to be done. + It is inherited from the built-in class `ABC` (Abstract Base Classes). Other classes can inherit this class to define + their own specific variations of the context. It requires to implement several property methods, and also includes + additional methods related to the context design. + + Attributes: + device: The device on which computations are to be done. + Methods: + _wrap_command(command): Wraps a command in the context specific command. + wrap_command(command): Wraps a command in the context specific command and returns it. + execute(command): Runs a command in the context specific way. + Note: + The class is abstract and requires to implement the abstract methods. + """ + @property @abstractmethod def device(self): + """ + Abstract property method to define the device on which computations are to be done. + + A device can be a CPU, GPU, TPU, etc. It is used to specify the context in which computations are to be done. + + Returns: + str: The device on which computations are to be done. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> device = context.device + Note: + The method should be implemented in the derived class. + """ pass def _wrap_command(self, command): + """ + A helper method to wrap a command in the context specific command. + + Args: + command (List[str]): The command to be wrapped. + Returns: + List[str]: The wrapped command. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> command = ["python", "script.py"] + >>> wrapped_command = context._wrap_command(command) + Note: + The method should be implemented in the derived class. + """ # A helper method to wrap a command in the context specific command. return command def wrap_command(self, command): + """ + A method to wrap a command in the context specific command and return it. + + Args: + command (List[str]): The command to be wrapped. + Returns: + List[str]: The wrapped command. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> command = ["python", "script.py"] + >>> wrapped_command = context.wrap_command(command) + Note: + The method should be implemented in the derived class. + """ command = [str(com) for com in self._wrap_command(command)] return command def execute(self, command): + """ + A method to run a command in the context specific way. + + Args: + command (List[str]): The command to be executed. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> context = ComputeContext() + >>> command = ["python", "script.py"] + >>> context.execute(command) + Note: + The method should be implemented in the derived class. + """ # A helper method to run a command in the context specific way. # add pythonpath to the environment @@ -31,7 +114,18 @@ def execute(self, command): def create_compute_context() -> ComputeContext: - """Create a compute context based on the global DaCapo options.""" + """ + Create a compute context based on the global DaCapo options. + + Returns: + ComputeContext: The compute context object. + Raises: + ValueError: If the store type is unknown. + Examples: + >>> context = create_compute_context() + Note: + The method is implemented in the module. + """ options = Options.instance() diff --git a/dacapo/compute_context/local_torch.py b/dacapo/compute_context/local_torch.py index 330e1899a..2cda40582 100644 --- a/dacapo/compute_context/local_torch.py +++ b/dacapo/compute_context/local_torch.py @@ -15,6 +15,12 @@ class LocalTorch(ComputeContext): Attributes: _device (Optional[str]): This stores the type of device on which torch computations are to be done. It can take "cuda" for GPU or "cpu" for CPU. None value results in automatic detection of device type. + oom_limit (Optional[float | int]): The out of GPU memory to leave free in GB. If the free memory is below + this limit, we will fall back on CPU. + Methods: + device(): Returns the torch device object. + Note: + The class is a subclass of the ComputeContext class. """ _device: Optional[str] = attr.ib( @@ -37,6 +43,9 @@ def device(self): """ A property method that returns the torch device object. It automatically detects and uses "cuda" (GPU) if available, else it falls back on using "cpu". + + Returns: + torch.device: The torch device object. """ if self._device is None: if torch.cuda.is_available(): diff --git a/dacapo/experiments/architectures/architecture.py b/dacapo/experiments/architectures/architecture.py index 888030adb..0f188560e 100644 --- a/dacapo/experiments/architectures/architecture.py +++ b/dacapo/experiments/architectures/architecture.py @@ -11,6 +11,17 @@ class Architecture(torch.nn.Module, ABC): It is inherited from PyTorch's Module and built-in class `ABC` (Abstract Base Classes). Other classes can inherit this class to define their own specific variations of architecture. It requires to implement several property methods, and also includes additional methods related to the architecture design. + + Attributes: + input_shape (Coordinate): The spatial input shape for the neural network architecture. + eval_shape_increase (Coordinate): The amount to increase the input shape during prediction. + num_in_channels (int): The number of input channels required by the architecture. + num_out_channels (int): The number of output channels provided by the architecture. + Methods: + dims: Returns the number of dimensions of the input shape. + scale: Scales the input voxel size as required by the architecture. + Note: + The class is abstract and requires to implement the abstract methods. """ @property @@ -22,6 +33,14 @@ def input_shape(self) -> Coordinate: Returns: Coordinate: The spatial input shape. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> input_shape = Coordinate((128, 128, 128)) + >>> model = MyModel(input_shape) + Note: + The method should be implemented in the derived class. + """ pass @@ -32,6 +51,13 @@ def eval_shape_increase(self) -> Coordinate: Returns: Coordinate: An instance representing the amount to increase in each dimension of the input shape. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> eval_shape_increase = Coordinate((0, 0, 0)) + >>> model = MyModel(input_shape, eval_shape_increase) + Note: + The method is optional and can be overridden in the derived class. """ return Coordinate((0,) * self.input_shape.dims) @@ -43,6 +69,13 @@ def num_in_channels(self) -> int: Returns: int: Required number of input channels. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> num_in_channels = 1 + >>> model = MyModel(input_shape, num_in_channels) + Note: + The method should be implemented in the derived class. """ pass @@ -54,6 +87,14 @@ def num_out_channels(self) -> int: Returns: int: Number of output channels. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> num_out_channels = 1 + >>> model = MyModel(input_shape, num_out_channels) + Note: + The method should be implemented in the derived class. + """ pass @@ -64,6 +105,15 @@ def dims(self) -> int: Returns: int: The number of dimensions. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> input_shape = Coordinate((128, 128, 128)) + >>> model = MyModel(input_shape) + >>> model.dims + 3 + Note: + The method is optional and can be overridden in the derived class. """ return self.input_shape.dims @@ -73,8 +123,16 @@ def scale(self, input_voxel_size: Coordinate) -> Coordinate: Args: input_voxel_size (Coordinate): The original size of the input voxel. - Returns: Coordinate: The scaled voxel size. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> input_voxel_size = Coordinate((1, 1, 1)) + >>> model = MyModel(input_shape) + >>> model.scale(input_voxel_size) + Coordinate((1, 1, 1)) + Note: + The method is optional and can be overridden in the derived class. """ return input_voxel_size diff --git a/dacapo/experiments/architectures/architecture_config.py b/dacapo/experiments/architectures/architecture_config.py index 09455ce55..67ea080a2 100644 --- a/dacapo/experiments/architectures/architecture_config.py +++ b/dacapo/experiments/architectures/architecture_config.py @@ -5,18 +5,16 @@ @attr.s class ArchitectureConfig: """ - A class to represent the base configurations of any architecture. - - Attributes - ---------- - name : str - a unique name for the architecture. - - Methods - ------- - verify() - validates the given architecture. - + A class to represent the base configurations of any architecture. It is used to define the architecture of a neural network model. + + Attributes: + name : str + a unique name for the architecture. + Methods: + verify() + validates the given architecture. + Note: + The class is abstract and requires to implement the abstract methods. """ name: str = attr.ib( @@ -31,11 +29,15 @@ def verify(self) -> Tuple[bool, str]: """ A method to validate an architecture configuration. - Returns - ------- - bool - A flag indicating whether the config is valid or not. - str - A description of the architecture. + Returns: + Tuple[bool, str]: A tuple of a boolean indicating if the architecture is valid and a message. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> config = ArchitectureConfig("MyModel") + >>> is_valid, message = config.verify() + >>> print(is_valid, message) + Note: + The method should be implemented in the derived class. """ return True, "No validation for this Architecture" diff --git a/dacapo/experiments/architectures/cnnectome_unet.py b/dacapo/experiments/architectures/cnnectome_unet.py index bb2be3586..acb345c88 100644 --- a/dacapo/experiments/architectures/cnnectome_unet.py +++ b/dacapo/experiments/architectures/cnnectome_unet.py @@ -7,7 +7,151 @@ class CNNectomeUNet(Architecture): + """ + A U-Net architecture for 3D or 4D data. The U-Net expects 3D or 4D tensors + shaped like:: + + ``(batch=1, channels, [length,] depth, height, width)``. + + This U-Net performs only "valid" convolutions, i.e., sizes of the feature + maps decrease after each convolution. It will perfrom 4D convolutions as + long as ``length`` is greater than 1. As soon as ``length`` is 1 due to a + valid convolution, the time dimension will be dropped and tensors with + ``(b, c, z, y, x)`` will be use (and returned) from there on. + + Attributes: + fmaps_in: + The number of input channels. + fmaps_out: + The number of feature maps in the output layer. This is also the + number of output feature maps. Stored in the ``channels`` dimension. + num_fmaps: + The number of feature maps in the first layer. This is also the + number of output feature maps. Stored in the ``channels`` dimension. + fmap_inc_factor: + By how much to multiply the number of feature maps between layers. + If layer 0 has ``k`` feature maps, layer ``l`` will have + ``k*fmap_inc_factor**l``. + downsample_factors: + List of tuples ``(z, y, x)`` to use to down- and up-sample the + feature maps between layers. + kernel_size_down (optional): + List of lists of kernel sizes. The number of sizes in a list + determines the number of convolutional layers in the corresponding + level of the build on the left side. Kernel sizes can be given as + tuples or integer. If not given, each convolutional pass will + consist of two 3x3x3 convolutions. + kernel_size_up (optional): + List of lists of kernel sizes. The number of sizes in a list + determines the number of convolutional layers in the corresponding + level of the build on the right side. Within one of the lists going + from left to right. Kernel sizes can be given as tuples or integer. + If not given, each convolutional pass will consist of two 3x3x3 + convolutions. + activation + Which activation to use after a convolution. Accepts the name of + any tensorflow activation function (e.g., ``ReLU`` for + ``torch.nn.ReLU``). + fov (optional): + Initial field of view in physical units + voxel_size (optional): + Size of a voxel in the input data, in physical units + num_heads (optional): + Number of decoders. The resulting U-Net has one single encoder + path and num_heads decoder paths. This is useful in a multi-task + learning context. + constant_upsample (optional): + If set to true, perform a constant upsampling instead of a + transposed convolution in the upsampling layers. + padding (optional): + How to pad convolutions. Either 'same' or 'valid' (default). + upsample_channel_contraction: + When performing the ConvTranspose, whether to reduce the number + of channels by the fmap_increment_factor. can be either bool or + list of bools to apply independently per layer. + activation_on_upsample: + Whether or not to add an activation after the upsample operation. + use_attention: + Whether or not to use an attention block in the U-Net. + Methods: + forward(x): + Forward pass of the U-Net. + scale(voxel_size): + Scale the voxel size according to the upsampling factors. + input_shape: + Return the input shape of the U-Net. + num_in_channels: + Return the number of input channels. + num_out_channels: + Return the number of output channels. + eval_shape_increase: + Return the increase in shape due to the U-Net. + Note: + This class is a wrapper around the ``CNNectomeUNetModule`` class. + The ``CNNectomeUNetModule`` class is the actual implementation of the + U-Net architecture. + """ + def __init__(self, architecture_config): + """ + Initialize the U-Net architecture. + + Args: + architecture_config (dict): A dictionary containing the configuration + of the U-Net architecture. The dictionary should contain the following + keys: + - input_shape: The shape of the input data. + - fmaps_out: The number of output feature maps. + - fmaps_in: The number of input feature maps. + - num_fmaps: The number of feature maps in the first layer. + - fmap_inc_factor: The factor by which the number of feature maps + increases between layers. + - downsample_factors: List of tuples ``(z, y, x)`` to use to down- + and up-sample the feature maps between layers. + - kernel_size_down (optional): List of lists of kernel sizes. The + number of sizes in a list determines the number of convolutional + layers in the corresponding level of the build on the left side. + Kernel sizes can be given as tuples or integer. If not given, each + convolutional pass will consist of two 3x3x3 convolutions. + - kernel_size_up (optional): List of lists of kernel sizes. The + number of sizes in a list determines the number of convolutional + layers in the corresponding level of the build on the right side. + Within one of the lists going from left to right. Kernel sizes can + be given as tuples or integer. If not given, each convolutional + pass will consist of two 3x3x3 convolutions. + - constant_upsample (optional): If set to true, perform a constant + upsampling instead of a transposed convolution in the upsampling + layers. + - padding (optional): How to pad convolutions. Either 'same' or + 'valid' (default). + - upsample_factors (optional): List of tuples ``(z, y, x)`` to use + to upsample the feature maps between layers. + - activation_on_upsample (optional): Whether or not to add an + activation after the upsample operation. + - use_attention (optional): Whether or not to use an attention block + in the U-Net. + Raises: + ValueError: If the input shape is not given. + Examples: + >>> architecture_config = { + ... "input_shape": (1, 1, 128, 128, 128), + ... "fmaps_out": 1, + ... "fmaps_in": 1, + ... "num_fmaps": 24, + ... "fmap_inc_factor": 2, + ... "downsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... "kernel_size_down": [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... "kernel_size_up": [[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... "constant_upsample": False, + ... "padding": "valid", + ... "upsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... "activation_on_upsample": True, + ... "use_attention": False + ... } + >>> unet = CNNectomeUNet(architecture_config) + Note: + The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``. + """ super().__init__() self._input_shape = architecture_config.input_shape @@ -31,11 +175,62 @@ def __init__(self, architecture_config): @property def eval_shape_increase(self): + """ + The increase in shape due to the U-Net. + + Returns: + The increase in shape due to the U-Net. + Raises: + AttributeError: If the increase in shape is not given. + Examples: + >>> unet.eval_shape_increase + (1, 1, 128, 128, 128) + Note: + The increase in shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``. + """ if self._eval_shape_increase is None: return super().eval_shape_increase return self._eval_shape_increase def module(self): + """ + Create the U-Net module. + + Returns: + The U-Net module. + Raises: + AttributeError: If the number of input channels is not given. + AttributeError: If the number of output channels is not given. + AttributeError: If the number of feature maps in the first layer is not given. + AttributeError: If the factor by which the number of feature maps increases between layers is not given. + AttributeError: If the downsample factors are not given. + AttributeError: If the kernel sizes for the down pass are not given. + AttributeError: If the kernel sizes for the up pass are not given. + AttributeError: If the constant upsample flag is not given. + AttributeError: If the padding is not given. + AttributeError: If the upsample factors are not given. + AttributeError: If the activation on upsample flag is not given. + AttributeError: If the use attention flag is not given. + Examples: + >>> unet.module() + CNNectomeUNetModule( + in_channels=1, + num_fmaps=24, + num_fmaps_out=1, + fmap_inc_factor=2, + kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + constant_upsample=False, + padding='valid', + activation_on_upsample=True, + upsample_channel_contraction=[False, True, True], + use_attention=False + ) + Note: + The U-Net module is an instance of the ``CNNectomeUNetModule`` class. + + """ fmaps_in = self.fmaps_in levels = len(self.downsample_factors) + 1 dims = len(self.downsample_factors[0]) @@ -91,27 +286,159 @@ def module(self): return unet def scale(self, voxel_size): + """ + Scale the voxel size according to the upsampling factors. + + Args: + voxel_size (tuple): The size of a voxel in the input data. + Returns: + The scaled voxel size. + Raises: + ValueError: If the voxel size is not given. + Examples: + >>> unet.scale((1, 1, 1)) + (1, 1, 1) + Note: + The voxel size should be given as a tuple ``(z, y, x)``. + """ for upsample_factor in self.upsample_factors: voxel_size = voxel_size / upsample_factor return voxel_size @property def input_shape(self): + """ + Return the input shape of the U-Net. + + Returns: + The input shape of the U-Net. + Raises: + AttributeError: If the input shape is not given. + Examples: + >>> unet.input_shape + (1, 1, 128, 128, 128) + Note: + The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``. + """ return self._input_shape @property def num_in_channels(self) -> int: + """ + Return the number of input channels. + + Returns: + The number of input channels. + Raises: + AttributeError: If the number of input channels is not given. + Examples: + >>> unet.num_in_channels + 1 + Note: + The number of input channels should be given as an integer. + """ return self.fmaps_in @property def num_out_channels(self) -> int: + """ + Return the number of output channels. + + Returns: + The number of output channels. + Raises: + AttributeError: If the number of output channels is not given. + Examples: + >>> unet.num_out_channels + 1 + Note: + The number of output channels should be given as an integer. + """ return self.fmaps_out def forward(self, x): + """ + Forward pass of the U-Net. + + Args: + x (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> unet = CNNectomeUNet(architecture_config) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> unet(x) + Note: + The input tensor should be given as a 5D tensor. + """ return self.unet(x) class CNNectomeUNetModule(torch.nn.Module): + """ + A U-Net module for 3D or 4D data. The U-Net expects 3D or 4D tensors shaped + like:: + + ``(batch=1, channels, [length,] depth, height, width)``. + + This U-Net performs only "valid" convolutions, i.e., sizes of the feature maps + decrease after each convolution. It will perfrom 4D convolutions as long as + ``length`` is greater than 1. As soon as ``length`` is 1 due to a valid + convolution, the time dimension will be dropped and tensors with ``(b, c, z, y, x)`` + will be use (and returned) from there on. + + Attributes: + num_levels: + The number of levels in the U-Net. + num_heads: + The number of decoders. + in_channels: + The number of input channels. + out_channels: + The number of output channels. + dims: + The number of dimensions. + use_attention: + Whether or not to use an attention block in the U-Net. + l_conv: + The left convolutional passes. + l_down: + The left downsample layers. + r_up: + The right up/crop/concatenate layers. + r_conv: + The right convolutional passes. + kernel_size_down: + The kernel sizes for the down pass. + kernel_size_up: + The kernel sizes for the up pass. + fmap_inc_factor: + The factor by which the number of feature maps increases between layers. + downsample_factors: + The downsample factors. + constant_upsample: + Whether to perform a constant upsampling instead of a transposed convolution. + padding: + How to pad convolutions. + upsample_channel_contraction: + Whether to reduce the number of channels by the fmap_increment_factor. + activation_on_upsample: + Whether or not to add an activation after the upsample operation. + use_attention: + Whether or not to use an attention block in the U-Net. + attention: + The attention blocks. + Methods: + rec_forward(level, f_in): + Recursive forward pass of the U-Net. + forward(x): + Forward pass of the U-Net. + Note: + The input tensor should be given as a 5D tensor. + """ + def __init__( self, in_channels, @@ -129,7 +456,8 @@ def __init__( activation_on_upsample=False, use_attention=False, ): - """Create a U-Net:: + """ + Create a U-Net:: f_in --> f_left --------------------------->> f_right--> f_out | ^ @@ -155,83 +483,80 @@ def __init__( from there on. Args: - in_channels: - The number of input channels. - num_fmaps: - The number of feature maps in the first layer. This is also the number of output feature maps. Stored in the ``channels`` dimension. - fmap_inc_factor: - By how much to multiply the number of feature maps between layers. If layer 0 has ``k`` feature maps, layer ``l`` will have ``k*fmap_inc_factor**l``. - downsample_factors: - List of tuples ``(z, y, x)`` to use to down- and up-sample the feature maps between layers. - kernel_size_down (optional): - List of lists of kernel sizes. The number of sizes in a list determines the number of convolutional layers in the corresponding level of the build on the left side. Kernel sizes can be given as tuples or integer. If not given, each convolutional pass will consist of two 3x3x3 convolutions. - kernel_size_up (optional): - List of lists of kernel sizes. The number of sizes in a list determines the number of convolutional layers in the corresponding level of the build on the right side. Within one of the lists going from left to right. Kernel sizes can be given as tuples or integer. If not given, each convolutional pass will consist of two 3x3x3 convolutions. - activation: - Which activation to use after a convolution. Accepts the name of any tensorflow activation function (e.g., ``ReLU`` for ``torch.nn.ReLU``). - fov (optional): - Initial field of view in physical units - voxel_size (optional): - Size of a voxel in the input data, in physical units - num_heads (optional): - Number of decoders. The resulting U-Net has one single encoder path and num_heads decoder paths. This is useful in a multi-task learning context. - constant_upsample (optional): - If set to true, perform a constant upsampling instead of a transposed convolution in the upsampling layers. - padding (optional): - How to pad convolutions. Either 'same' or 'valid' (default). - upsample_channel_contraction: - When performing the ConvTranspose, whether to reduce the number of channels by the fmap_increment_factor. can be either bool or list of bools to apply independently per layer. - activation_on_upsample: - Whether or not to add an activation after the upsample operation. + use_attention: + Whether or not to use an attention block in the U-Net. + attention: + The attention blocks. + Returns: + The U-Net module. + Raises: + ValueError: If the number of input channels is not given. + Examples: + >>> unet = CNNectomeUNetModule( + ... in_channels=1, + ... num_fmaps=24, + ... num_fmaps_out=1, + ... fmap_inc_factor=2, + ... kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... constant_upsample=False, + ... padding='valid', + ... activation_on_upsample=True, + ... upsample_channel_contraction=[False, True, True], + ... use_attention=False + ... ) + Note: + The input tensor should be given as a 5D tensor. """ super().__init__() @@ -378,6 +703,36 @@ def __init__( ) def rec_forward(self, level, f_in): + """ + Recursive forward pass of the U-Net. + + Args: + level (int): The level of the U-Net. + f_in (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> unet = CNNectomeUNetModule( + ... in_channels=1, + ... num_fmaps=24, + ... num_fmaps_out=1, + ... fmap_inc_factor=2, + ... kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... constant_upsample=False, + ... padding='valid', + ... activation_on_upsample=True, + ... upsample_channel_contraction=[False, True, True], + ... use_attention=False + ... ) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> unet.rec_forward(2, x) + Note: + The input tensor should be given as a 5D tensor. + """ # index of level in layer arrays i = self.num_levels - level - 1 @@ -415,6 +770,35 @@ def rec_forward(self, level, f_in): return fs_out def forward(self, x): + """ + Forward pass of the U-Net. + + Args: + x (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> unet = CNNectomeUNetModule( + ... in_channels=1, + ... num_fmaps=24, + ... num_fmaps_out=1, + ... fmap_inc_factor=2, + ... kernel_size_down=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... kernel_size_up=[[(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)], [(3, 3, 3), (3, 3, 3)]], + ... downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)], + ... constant_upsample=False, + ... padding='valid', + ... activation_on_upsample=True, + ... upsample_channel_contraction=[False, True, True], + ... use_attention=False + ... ) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> unet(x) + Note: + The input tensor should be given as a 5D tensor. + """ y = self.rec_forward(self.num_levels - 1, x) if self.num_heads == 1: @@ -424,9 +808,45 @@ def forward(self, x): class ConvPass(torch.nn.Module): + """ + Convolutional pass module. This module performs a series of convolutional + layers followed by an activation function. The module can also pad the + feature maps to ensure translation equivariance. The module can perform + 2D or 3D convolutions. + + Attributes: + dims: + The number of dimensions. + conv_pass: + The convolutional pass module. + Methods: + forward(x): + Forward pass of the Conv + Note: + The input tensor should be given as a 5D tensor. + """ + def __init__( self, in_channels, out_channels, kernel_sizes, activation, padding="valid" ): + """ + Convolutional pass module. This module performs a series of + convolutional layers followed by an activation function. + + Args: + in_channels (int): The number of input channels. + out_channels (int): The number of output channels. + kernel_sizes (list): The kernel sizes for the convolutional layers. + activation (str): The activation function to use. + padding (optional): How to pad convolutions. Either 'same' or 'valid'. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> conv_pass = ConvPass(1, 1, [(3, 3, 3), (3, 3, 3)], "ReLU") + Note: + The input tensor should be given as a 5D tensor. + + """ super(ConvPass, self).__init__() if activation is not None: @@ -460,11 +880,61 @@ def __init__( self.conv_pass = torch.nn.Sequential(*layers) def forward(self, x): + """ + Forward pass of the ConvPass module. + + Args: + x (Tensor): The input tensor. + Returns: + The output tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> conv_pass = ConvPass(1, 1, [(3, 3, 3), (3, 3, 3)], "ReLU") + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> conv_pass(x) + Note: + The input tensor should be given as a 5D tensor. + """ return self.conv_pass(x) class Downsample(torch.nn.Module): + """ + Downsample module. This module performs downsampling of the input tensor + using either max-pooling or average pooling. The module can also crop the + feature maps to ensure translation equivariance with a stride of the + downsampling factor. + + Attributes: + dims: + The number of dimensions. + downsample_factor: + The downsampling factor. + down: + The downsampling layer. + Methods: + forward(x): + Downsample the input tensor. + Note: + The input tensor should be given as a 5D tensor. + + """ + def __init__(self, downsample_factor): + """ + Downsample module. This module performs downsampling of the input tensor + using either max-pooling or average pooling. + + Args: + downsample_factor (tuple): The downsampling factor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> downsample = Downsample((2, 2, 2)) + Note: + The input tensor should be given as a 5D tensor. + """ super(Downsample, self).__init__() self.dims = len(downsample_factor) @@ -479,6 +949,22 @@ def __init__(self, downsample_factor): self.down = pool(downsample_factor, stride=downsample_factor) def forward(self, x): + """ + Downsample the input tensor. + + Args: + x (Tensor): The input tensor. + Returns: + The downsampled tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> downsample = Downsample((2, 2, 2)) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> downsample(x) + Note: + The input tensor should be given as a 5D tensor. + """ for d in range(1, self.dims + 1): if x.size()[-d] % self.downsample_factor[-d] != 0: raise RuntimeError( @@ -491,6 +977,34 @@ def forward(self, x): class Upsample(torch.nn.Module): + """ + Upsample module. This module performs upsampling of the input tensor using + either transposed convolutions or nearest neighbor interpolation. The + module can also crop the feature maps to ensure translation equivariance + with a stride of the upsampling factor. + + Attributes: + crop_factor: + The crop factor. + next_conv_kernel_sizes: + The kernel sizes for the convolutional layers. + dims: + The number of dimensions. + up: + The upsampling layer. + Methods: + crop_to_factor(x, factor, kernel_sizes): + Crop feature maps to ensure translation equivariance with stride of + upsampling factor. + crop(x, shape): + Center-crop x to match spatial dimensions given by shape. + forward(g_out, f_left=None): + Forward pass of the Upsample module. + Note: + The input tensor should be given as a 5D tensor. + + """ + def __init__( self, scale_factor, @@ -501,6 +1015,27 @@ def __init__( next_conv_kernel_sizes=None, activation=None, ): + """ + Upsample module. This module performs upsampling of the input tensor + + Args: + scale_factor (tuple): The upsampling factor. + mode (optional): The upsampling mode. Either 'transposed_conv' or + 'nearest + in_channels (optional): The number of input channels. + out_channels (optional): The number of output channels. + crop_factor (optional): The crop factor. + next_conv_kernel_sizes (optional): The kernel sizes for the convolutional layers. + activation (optional): The activation function to use. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1, activation="ReLU") + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1, crop_factor=(2, 2, 2), next_conv_kernel_sizes=[(3, 3, 3), (3, 3, 3)]) + Note: + The input tensor should be given as a 5D tensor. + """ super(Upsample, self).__init__() if activation is not None: @@ -548,12 +1083,50 @@ def __init__( self.up = layers[0] def crop_to_factor(self, x, factor, kernel_sizes): - """Crop feature maps to ensure translation equivariance with stride of + """ + Crop feature maps to ensure translation equivariance with stride of upsampling factor. This should be done right after upsampling, before application of the convolutions with the given kernel sizes. The crop could be done after the convolutions, but it is more efficient to do that before (feature maps will be smaller). + + We need to ensure that the feature map is large enough to ensure that + the translation equivariance is maintained. This is done by cropping + the feature map to the largest size that is a multiple of the factor + and that is large enough to ensure that the translation equivariance + is maintained. + + We need (spatial_shape - convolution_crop) to be a multiple of factor, + i.e.: + (s - c) = n*k + + where s is the spatial size of the feature map, c is the crop due to + the convolutions, n is the number of strides of the upsampling factor, + and k is the upsampling factor. + + We want to find the largest n for which s' = n*k + c <= s + + n = floor((s - c)/k) + + This gives us the target shape s' + + s' = n*k + c + + Args: + x (Tensor): The input tensor. + factor (tuple): The upsampling factor. + kernel_sizes (list): The kernel sizes for the convolutional layers. + Returns: + The cropped tensor. + Raises: + RuntimeError: If the feature map is too small to ensure translation equivariance. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> upsample.crop_to_factor(x, (2, 2, 2), [(3, 3, 3), (3, 3, 3)]) + Note: + The input tensor should be given as a 5D tensor. """ shape = x.size() @@ -599,7 +1172,23 @@ def crop_to_factor(self, x, factor, kernel_sizes): return x def crop(self, x, shape): - """Center-crop x to match spatial dimensions given by shape.""" + """ + Center-crop x to match spatial dimensions given by shape. + + Args: + x (Tensor): The input tensor. + shape (tuple): The target shape. + Returns: + The center-cropped tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> x = torch.randn(1, 1, 64, 64, 64) + >>> upsample.crop(x, (32, 32, 32)) + Note: + The input tensor should be given as a 5D tensor. + """ x_target_size = x.size()[: -self.dims] + shape @@ -610,6 +1199,24 @@ def crop(self, x, shape): return x[slices] def forward(self, g_out, f_left=None): + """ + Forward pass of the Upsample module. + + Args: + g_out (Tensor): The gating signal tensor. + f_left (Tensor): The input feature tensor. + Returns: + The output feature tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> upsample = Upsample(scale_factor=(2, 2, 2), in_channels=1, out_channels=1) + >>> g_out = torch.randn(1, 1, 64, 64, 64) + >>> f_left = torch.randn(1, 1, 32, 32, 32) + >>> upsample(g_out, f_left) + Note: + The gating signal and input feature tensors should be given as 5D tensors. + """ g_up = self.up(g_out) if self.next_conv_kernel_sizes is not None: @@ -628,41 +1235,72 @@ def forward(self, g_out, f_left=None): class AttentionBlockModule(nn.Module): - def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): - """Attention Block Module:: - - The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). - - [g] --> W_g --\ /--> psi --> * --> [output] - \ / - [x] --> W_x --> [+] --> relu -- + """ + Attention Block Module: + + The AttentionBlock uses two separate pathways to process 'g' and 'x', + combines them, and applies a sigmoid activation to generate an attention map. + This map is then used to scale the input features 'x', resulting in an output + that focuses on important features as dictated by the gating signal 'g'. + + The attention block takes two inputs: 'g' (gating signal) and 'x' (input features). + + [g] --> W_g --\ /--> psi --> * --> [output] + \ / + [x] --> W_x --> [+] --> relu -- + + Where: + - W_g and W_x are 1x1 Convolution followed by Batch Normalization + - [+] indicates element-wise addition + - relu is the Rectified Linear Unit activation function + - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation + - * indicates element-wise multiplication between the output of psi and input feature 'x' + - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + + Attributes: + dims: + The number of dimensions of the input tensors. + kernel_sizes: + The kernel sizes for the convolutional layers. + upsample_factor: + The factor by which to upsample the attention map. + W_g: + The 1x1 Convolutional layer for the gating signal. + W_x: + The 1x1 Convolutional layer for the input features. + psi: + The 1x1 Convolutional layer followed by Sigmoid activation. + up: + The upsampling layer to match the dimensions of the input features. + relu: + The Rectified Linear Unit activation function. + Methods: + calculate_and_apply_padding(smaller_tensor, larger_tensor): + Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor. + forward(g, x): + Forward pass of the Attention Block. + Note: + The AttentionBlockModule is an instance of the ``torch.nn.Module`` class. + """ - Where: - - W_g and W_x are 1x1 Convolution followed by Batch Normalization - - [+] indicates element-wise addition - - relu is the Rectified Linear Unit activation function - - psi is a sequence of 1x1 Convolution, Batch Normalization, and Sigmoid activation - - * indicates element-wise multiplication between the output of psi and input feature 'x' - - [output] has the same dimensions as input 'x', selectively emphasized by attention weights + def __init__(self, F_g, F_l, F_int, dims, upsample_factor=None): + """ + Initialize the Attention Block Module. Args: - F_g (int): The number of feature channels in the gating signal (g). - This is the input channel dimension for the W_g convolutional layer. - - F_l (int): The number of feature channels in the input features (x). - This is the input channel dimension for the W_x convolutional layer. - - F_int (int): The number of intermediate feature channels. - This represents the output channel dimension of the W_g and W_x convolutional layers - and the input channel dimension for the psi layer. Typically, F_int is smaller - than F_g and F_l, as it serves to compress the feature representations before - applying the attention mechanism. - - The AttentionBlock uses two separate pathways to process 'g' and 'x', combines them, - and applies a sigmoid activation to generate an attention map. This map is then used - to scale the input features 'x', resulting in an output that focuses on important - features as dictated by the gating signal 'g'. - + F_g (int): The number of feature maps in the gating signal tensor. + F_l (int): The number of feature maps in the input feature tensor. + F_int (int): The number of feature maps in the intermediate tensor. + dims (int): The number of dimensions of the input tensors. + upsample_factor (optional): The factor by which to upsample the attention map. + Returns: + The Attention Block Module. + Raises: + RuntimeError: If the gating signal and input feature tensors have different dimensions. + Examples: + >>> attention_block = AttentionBlockModule(F_g=1, F_l=1, F_int=1, dims=3) + Note: + The number of feature maps should be given as an integer. """ super(AttentionBlockModule, self).__init__() @@ -709,11 +1347,19 @@ def calculate_and_apply_padding(self, smaller_tensor, larger_tensor): Calculate and apply symmetric padding to the smaller tensor to match the dimensions of the larger tensor. Args: - smaller_tensor (Tensor): The tensor to be padded. - larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match. - + smaller_tensor (Tensor): The tensor to be padded. + larger_tensor (Tensor): The tensor whose dimensions the smaller tensor needs to match. Returns: - Tensor: The padded smaller tensor with the same dimensions as the larger tensor. + Tensor: The padded smaller tensor with the same dimensions as the larger tensor. + Raises: + RuntimeError: If the tensors have different dimensions. + Examples: + >>> larger_tensor = torch.randn(1, 1, 128, 128, 128) + >>> smaller_tensor = torch.randn(1, 1, 64, 64, 64) + >>> attention_block = AttentionBlockModule(F_g=1, F_l=1, F_int=1, dims=3) + >>> padded_tensor = attention_block.calculate_and_apply_padding(smaller_tensor, larger_tensor) + Note: + The tensors should have the same dimensions. """ padding = [] for i in range(2, 2 + self.dims): @@ -727,6 +1373,24 @@ def calculate_and_apply_padding(self, smaller_tensor, larger_tensor): return nn.functional.pad(smaller_tensor, padding, mode="constant", value=0) def forward(self, g, x): + """ + Forward pass of the Attention Block. + + Args: + g (Tensor): The gating signal tensor. + x (Tensor): The input feature tensor. + Returns: + Tensor: The output tensor with the same dimensions as the input feature tensor. + Raises: + RuntimeError: If the gating signal and input feature tensors have different dimensions. + Examples: + >>> g = torch.randn(1, 1, 128, 128, 128) + >>> x = torch.randn(1, 1, 128, 128, 128) + >>> attention_block = AttentionBlockModule(F_g=1, F_l=1, F_int=1, dims=3) + >>> output = attention_block(g, x) + Note: + The gating signal and input feature tensors should have the same dimensions. + """ g1 = self.W_g(g) x1 = self.W_x(x) g1 = self.calculate_and_apply_padding(g1, x1) diff --git a/dacapo/experiments/architectures/cnnectome_unet_config.py b/dacapo/experiments/architectures/cnnectome_unet_config.py index c0e9e5b9d..77905d79c 100644 --- a/dacapo/experiments/architectures/cnnectome_unet_config.py +++ b/dacapo/experiments/architectures/cnnectome_unet_config.py @@ -10,10 +10,49 @@ @attr.s class CNNectomeUNetConfig(ArchitectureConfig): - """This class configures the CNNectomeUNet based on + """ + This class configures the CNNectomeUNet based on https://github.com/saalfeldlab/CNNectome/blob/master/CNNectome/networks/unet_class.py Includes support for super resolution via the upsampling factors. + + Attributes: + input_shape: Coordinate + The shape of the data passed into the network during training. + fmaps_out: int + The number of channels produced by your architecture. + fmaps_in: int + The number of channels expected from the raw data. + num_fmaps: int + The number of feature maps in the top level of the UNet. + fmap_inc_factor: int + The multiplication factor for the number of feature maps for each level of the UNet. + downsample_factors: List[Coordinate] + The factors to downsample the feature maps along each axis per layer. + kernel_size_down: Optional[List[Coordinate]] + The size of the convolutional kernels used before downsampling in each layer. + kernel_size_up: Optional[List[Coordinate]] + The size of the convolutional kernels used before upsampling in each layer. + _eval_shape_increase: Optional[Coordinate] + The amount by which to increase the input size when just prediction rather than training. + It is generally possible to significantly increase the input size since we don't have the memory + constraints of the gradients, the optimizer and the batch size. + upsample_factors: Optional[List[Coordinate]] + The amount by which to upsample the output of the UNet. + constant_upsample: bool + Whether to use a transpose convolution or simply copy voxels to upsample. + padding: str + The padding to use in convolution operations. + use_attention: bool + Whether to use attention blocks in the UNet. This is supported for 2D and 3D. + Methods: + architecture_type() + Returns the architecture type. + Note: + The architecture_type attribute is set to CNNectomeUNet. + References: + Saalfeld, S., Fetter, R., Cardona, A., & Tomancak, P. (2012). + """ architecture_type = CNNectomeUNet @@ -45,13 +84,13 @@ class CNNectomeUNetConfig(ArchitectureConfig): "help_text": "The factors to downsample the feature maps along each axis per layer." } ) - kernel_size_down: Optional[List[Coordinate]] = attr.ib( + kernel_size_down: Optional[List[List[Coordinate]]] = attr.ib( default=None, metadata={ "help_text": "The size of the convolutional kernels used before downsampling in each layer." }, ) - kernel_size_up: Optional[List[Coordinate]] = attr.ib( + kernel_size_up: Optional[List[List[Coordinate]]] = attr.ib( default=None, metadata={ "help_text": "The size of the convolutional kernels used before upsampling in each layer." diff --git a/dacapo/experiments/architectures/dummy_architecture.py b/dacapo/experiments/architectures/dummy_architecture.py index 70a0d5d3e..fa5a889e7 100644 --- a/dacapo/experiments/architectures/dummy_architecture.py +++ b/dacapo/experiments/architectures/dummy_architecture.py @@ -12,15 +12,27 @@ class DummyArchitecture(Architecture): channels_out: An integer representing the number of output channels. conv: A 3D convolution object. input_shape: A coordinate object representing the shape of the input. - Methods: forward(x): Performs the forward pass of the network. + num_in_channels(): Returns the number of input channels for this architecture. + num_out_channels(): Returns the number of output channels for this architecture. + Note: + This class is used to represent a dummy architecture layer for a 3D CNN. """ def __init__(self, architecture_config): """ + Constructor for the DummyArchitecture class. Initializes the 3D convolution object. + Args: - architecture_config: An object containing the configuration settings for the architecture. + architecture_config: An architecture configuration object. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> architecture_config = ArchitectureConfig(num_in_channels=1, num_out_channels=1) + >>> dummy_architecture = DummyArchitecture(architecture_config) + Note: + This method is used to initialize the DummyArchitecture class. """ super().__init__() @@ -36,6 +48,13 @@ def input_shape(self): Returns: Coordinate: Input shape of the architecture. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture.input_shape + Coordinate(x=40, y=20, z=20) + Note: + This method is used to return the input shape for this architecture. """ return Coordinate(40, 20, 20) @@ -46,6 +65,13 @@ def num_in_channels(self): Returns: int: Number of input channels. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture.num_in_channels + 1 + Note: + This method is used to return the number of input channels for this architecture. """ return self.channels_in @@ -56,6 +82,13 @@ def num_out_channels(self): Returns: int: Number of output channels. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture.num_out_channels + 1 + Note: + This method is used to return the number of output channels for this architecture. """ return self.channels_out @@ -65,8 +98,15 @@ def forward(self, x): Args: x: Input tensor. - Returns: Tensor: Output tensor after the forward pass. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture = DummyArchitecture(architecture_config) + >>> x = torch.randn(1, 1, 40, 20, 20) + >>> dummy_architecture.forward(x) + Note: + This method is used to perform the forward pass of the network. """ return self.conv(x) diff --git a/dacapo/experiments/architectures/dummy_architecture_config.py b/dacapo/experiments/architectures/dummy_architecture_config.py index eaf9b7027..695d8bc41 100644 --- a/dacapo/experiments/architectures/dummy_architecture_config.py +++ b/dacapo/experiments/architectures/dummy_architecture_config.py @@ -8,7 +8,8 @@ @attr.s class DummyArchitectureConfig(ArchitectureConfig): - """A dummy architecture configuration class used for testing purposes. + """ + A dummy architecture configuration class used for testing purposes. It extends the base class "ArchitectureConfig". This class contains dummy attributes and always returns that the configuration is invalid when verified. @@ -20,6 +21,10 @@ class DummyArchitectureConfig(ArchitectureConfig): functionality or meaning. num_out_channels (int): The number of output channels. This is also a dummy attribute and has no real functionality or meaning. + Methods: + verify(self) -> Tuple[bool, str]: This method is used to check whether this is a valid architecture configuration. + Note: + This class is used to represent a DummyArchitectureConfig object in the system. """ architecture_type = DummyArchitecture @@ -29,13 +34,22 @@ class DummyArchitectureConfig(ArchitectureConfig): num_out_channels: int = attr.ib(metadata={"help_text": "Dummy attribute."}) def verify(self) -> Tuple[bool, str]: - """Verifies the configuration validity. + """ + Verifies the configuration validity. Since this is a dummy configuration for testing purposes, this method always returns False indicating that the configuration is invalid. Returns: tuple: A tuple containing a boolean validity flag and a reason message string. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> dummy_architecture_config = DummyArchitectureConfig(num_in_channels=1, num_out_channels=1) + >>> dummy_architecture_config.verify() + (False, "This is a DummyArchitectureConfig and is never valid") + Note: + This method is used to check whether this is a valid architecture configuration. """ return False, "This is a DummyArchitectureConfig and is never valid" diff --git a/dacapo/experiments/arraytypes/annotations.py b/dacapo/experiments/arraytypes/annotations.py index f7fc2f9b1..f90d9dd09 100644 --- a/dacapo/experiments/arraytypes/annotations.py +++ b/dacapo/experiments/arraytypes/annotations.py @@ -8,7 +8,15 @@ class AnnotationArray(ArrayType): """ An AnnotationArray is a uint8, uint16, uint32 or uint64 Array where each - voxel has a value associated with its class. + voxel has a value associated with its class. The class of each voxel can be + determined by simply taking the value. + + Attributes: + classes (Dict[int, str]): A mapping from class label to class name. + Methods: + interpolatable(self) -> bool: It is a method that returns False. + Note: + This class is used to create an AnnotationArray object which is used to represent an array of class labels. """ classes: Dict[int, str] = attr.ib( @@ -20,4 +28,20 @@ class AnnotationArray(ArrayType): @property def interpolatable(self): + """ + Method to return False. + + Returns: + bool + Returns a boolean value of False representing that the values are not interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> annotation_array = AnnotationArray(classes={1: "mitochondria", 2: "membrane"}) + >>> annotation_array.interpolatable + False + Note: + This method is used to check if the array is interpolatable. + """ return False diff --git a/dacapo/experiments/arraytypes/arraytype.py b/dacapo/experiments/arraytypes/arraytype.py index 0dce23ec0..c4ec2f050 100644 --- a/dacapo/experiments/arraytypes/arraytype.py +++ b/dacapo/experiments/arraytypes/arraytype.py @@ -8,7 +8,16 @@ class ArrayType(ABC): track of the semantic meaning of an Array. Additionally the ArrayType keeps track of metadata that is specific to this datatype such as num_classes for an annotated volume or channel names for intensity - arrays. + arrays. The ArrayType class is an abstract class and should be subclassed + to represent different types of arrays. + + Attributes: + num_classes (int): The number of classes in the array. + channel_names (List[str]): The names of the channels in the array. + Methods: + interpolatable: This is an abstract method which should be overridden in each of the subclasses to determine if an array is interpolatable or not. + Note: + This class is used to create an ArrayType object which is used to represent the type of data provided by an array. """ @property @@ -20,5 +29,13 @@ def interpolatable(self) -> bool: Returns: bool: True if the array is interpolatable, False otherwise. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> array_type = ArrayType() + >>> array_type.interpolatable + NotImplementedError + Note: + This method is used to check if the array is interpolatable. """ pass diff --git a/dacapo/experiments/arraytypes/binary.py b/dacapo/experiments/arraytypes/binary.py index e6c57faeb..dcc95b109 100644 --- a/dacapo/experiments/arraytypes/binary.py +++ b/dacapo/experiments/arraytypes/binary.py @@ -9,16 +9,14 @@ class BinaryArray(ArrayType): """ A subclass of ArrayType representing BinaryArray. The BinaryArray object is created with two attributes; channels. - Each voxel in this array is either 1 or 0. + Each voxel in this array is either 1 or 0. The class of each voxel can be determined by simply taking the argmax. Attributes: channels (Dict[int, str]): A dictionary attribute representing channel mapping with its binary classification. - - Args: - channels (Dict[int, str]): A dictionary input where keys are channel numbers and values are their corresponding class for binary classification. - Methods: interpolatable: Returns False as binary array type is not interpolatable. + Note: + This class is used to represent a BinaryArray object in the system. """ channels: Dict[int, str] = attr.ib( @@ -34,5 +32,13 @@ def interpolatable(self) -> bool: Returns: bool: Always returns False because interpolation is not possible. + Raises: + NotImplementedError: This method is not implemented in this class. + Examples: + >>> binary_array = BinaryArray(channels={1: "class1"}) + >>> binary_array.interpolatable + False + Note: + This method is used to check if the array is interpolatable. """ return False diff --git a/dacapo/experiments/arraytypes/distances.py b/dacapo/experiments/arraytypes/distances.py index 057f8f1b2..cd9754cbf 100644 --- a/dacapo/experiments/arraytypes/distances.py +++ b/dacapo/experiments/arraytypes/distances.py @@ -9,7 +9,17 @@ class DistanceArray(ArrayType): """ An array containing signed distances to the nearest boundary voxel for a particular label class. - Distances should be positive outside an object and negative inside an object. + Distances should be positive outside an object and negative inside an object. The distance should be 0 on the boundary. + The class of each voxel can be determined by simply taking the argmin. The distance should be in the range [-max, max]. + + Attributes: + classes (Dict[int, str]): A mapping from channel to class on which distances were calculated. + max (float): The maximum possible distance value of your distances. + Methods: + interpolatable(self) -> bool: It is a method that returns True. + Note: + This class is used to create a DistanceArray object which is used to represent an array containing signed distances to the nearest boundary voxel for a particular label class. + The class of each voxel can be determined by simply taking the argmin. """ classes: Dict[int, str] = attr.ib( @@ -20,4 +30,18 @@ class DistanceArray(ArrayType): @property def interpolatable(self) -> bool: + """ + Checks if the array is interpolatable. Returns True for this class. + + Returns: + bool: True indicating that the data can be interpolated. + Raises: + NotImplementedError: This method is not implemented in this class + Examples: + >>> distance_array = DistanceArray(classes={1: "class1"}) + >>> distance_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. + """ return True diff --git a/dacapo/experiments/arraytypes/embedding.py b/dacapo/experiments/arraytypes/embedding.py index 81fcadce3..ed751ca59 100644 --- a/dacapo/experiments/arraytypes/embedding.py +++ b/dacapo/experiments/arraytypes/embedding.py @@ -7,7 +7,16 @@ class EmbeddingArray(ArrayType): """ A generic output of a model that could represent almost anything. Assumed to be - float, interpolatable, and have sum number of channels. + float, interpolatable, and have sum number of channels. The channels are not + specified, and the array can be of any shape. + + Attributes: + embedding_dims (int): The dimension of your embedding. + Methods: + interpolatable(): + It is a method that returns True. + Note: + This class is used to represent an EmbeddingArray object in the system. """ embedding_dims: int = attr.ib( @@ -16,4 +25,20 @@ class EmbeddingArray(ArrayType): @property def interpolatable(self) -> bool: + """ + Method to return True. + + Returns: + bool + Returns a boolean value of True representing that the values are interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> embedding_array = EmbeddingArray(embedding_dims=10) + >>> embedding_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. + """ return True diff --git a/dacapo/experiments/arraytypes/intensities.py b/dacapo/experiments/arraytypes/intensities.py index 84cf9227d..6cc74e96c 100644 --- a/dacapo/experiments/arraytypes/intensities.py +++ b/dacapo/experiments/arraytypes/intensities.py @@ -9,7 +9,17 @@ @attr.s class IntensitiesArray(ArrayType): """ - An IntensitiesArray is an Array of measured intensities. + An IntensitiesArray is an Array of measured intensities. Each voxel has a value in the range [min, max]. + + Attributes: + channels (Dict[int, str]): A mapping from channel to a name describing that channel. + min (float): The minimum possible value of your intensities. + max (float): The maximum possible value of your intensities. + Methods: + __attrs_post_init__(self): This method is called after the instance has been initialized by the constructor. + interpolatable(self) -> bool: It is a method that returns True. + Note: + This class is used to create an IntensitiesArray object which is used to represent an array of measured intensities. """ channels: Dict[int, str] = attr.ib( @@ -26,4 +36,20 @@ class IntensitiesArray(ArrayType): @property def interpolatable(self) -> bool: + """ + Method to return True. + + Returns: + bool + Returns a boolean value of True representing that the values are interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> intensities_array = IntensitiesArray(channels={1: "channel1"}, min=0, max=1) + >>> intensities_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. + """ return True diff --git a/dacapo/experiments/arraytypes/mask.py b/dacapo/experiments/arraytypes/mask.py index 7f188ca73..cf2a04eaf 100644 --- a/dacapo/experiments/arraytypes/mask.py +++ b/dacapo/experiments/arraytypes/mask.py @@ -8,10 +8,11 @@ class Mask(ArrayType): """ A class that inherits the ArrayType class. This is a representation of a Mask in the system. - Methods - ------- - interpolatable(): - It is a method that returns False. + Methods: + interpolatable(): + It is a method that returns False. + Note: + This class is used to represent a Mask object in the system. """ @property @@ -19,9 +20,17 @@ def interpolatable(self) -> bool: """ Method to return False. - Returns - ------ - bool - Returns a boolean value of False representing that the values are not interpolatable. + Returns: + bool + Returns a boolean value of False representing that the values are not interpolatable. + Raises: + NotImplementedError + This method is not implemented in this class. + Examples: + >>> mask = Mask() + >>> mask.interpolatable + False + Note: + This method is used to check if the array is interpolatable. """ return False diff --git a/dacapo/experiments/arraytypes/probabilities.py b/dacapo/experiments/arraytypes/probabilities.py index e6510190f..d237aa601 100644 --- a/dacapo/experiments/arraytypes/probabilities.py +++ b/dacapo/experiments/arraytypes/probabilities.py @@ -14,6 +14,9 @@ class ProbabilityArray(ArrayType): Attributes: classes (List[str]): A mapping from channel to class on which distances were calculated. + Note: + This class is used to create a ProbabilityArray object which is used to represent an array containing probability distributions for each voxel pointed by its coordinate. + The class of each voxel can be determined by simply taking the argmax. """ classes: List[str] = attr.ib( @@ -29,5 +32,13 @@ def interpolatable(self) -> bool: Returns: bool: True indicating that the data can be interpolated. + Raises: + NotImplementedError: This method is not implemented in this class + Examples: + >>> probability_array = ProbabilityArray(classes=["class1", "class2"]) + >>> probability_array.interpolatable + True + Note: + This method is used to check if the array is interpolatable. """ return True diff --git a/dacapo/experiments/datasplits/datasets/arrays/array.py b/dacapo/experiments/datasplits/datasets/arrays/array.py index 37479e6af..da040067c 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/array.py @@ -7,48 +7,169 @@ class Array(ABC): + """ + An Array is a multi-dimensional array of data that can be read from and written to. It is + defined by a region of interest (ROI) in world units, a voxel size, and a number of spatial + dimensions. The data is stored in a numpy array, and can be accessed using numpy-like slicing + syntax. + + The Array class is an abstract base class that defines the interface for all Array + implementations. It provides a number of properties that must be implemented by subclasses, + such as the ROI, voxel size, and data type of the array. It also provides a method for fetching + data from the array, which is implemented by slicing the numpy array. + + The Array class also provides a method for checking if the array can be visualized in + Neuroglancer, and a method for generating a Neuroglancer layer for the array. These methods are + implemented by subclasses that support visualization in Neuroglancer. + + Attributes: + attrs (Dict[str, Any]): A dictionary of metadata attributes stored on this array. + axes (List[str]): The axes of this dataset as a string of characters, as they are indexed. + Permitted characters are: + * ``zyx`` for spatial dimensions + * ``c`` for channels + * ``s`` for samples + dims (int): The number of spatial dimensions. + voxel_size (Coordinate): The size of a voxel in physical units. + roi (Roi): The total ROI of this array, in world units. + dtype (Any): The dtype of this array, in numpy dtypes + num_channels (Optional[int]): The number of channels provided by this dataset. Should return + None if the channel dimension doesn't exist. + data (np.ndarray): A numpy-like readable and writable view into this array. + writable (bool): Can we write to this Array? + Methods: + __getitem__(self, roi: Roi) -> np.ndarray: Get a numpy like readable and writable view into + this array. + _can_neuroglance(self) -> bool: Check if this array can be visualized in Neuroglancer. + _neuroglancer_layer(self): Generate a Neuroglancer layer for this array. + _slices(self, roi: Roi) -> Iterable[slice]: Generate a list of slices for the given ROI. + Note: + This class is used to define the interface for all Array implementations. It provides a + number of properties that must be implemented by subclasses, such as the ROI, voxel size, and + data type of the array. It also provides a method for fetching data from the array, which is + implemented by slicing the numpy array. The Array class also provides a method for checking + if the array can be visualized in Neuroglancer, and a method for generating a Neuroglancer + layer for the array. These methods are implemented by subclasses that support visualization + in Neuroglancer. + """ + @property @abstractmethod def attrs(self) -> Dict[str, Any]: """ Return a dictionary of metadata attributes stored on this array. + + Returns: + Dict[str, Any]: A dictionary of metadata attributes stored on this array. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.attrs + {} + Note: + This method must be implemented by the subclass. """ pass @property @abstractmethod def axes(self) -> List[str]: - """Returns the axes of this dataset as a string of charactes, as they + """ + Returns the axes of this dataset as a string of charactes, as they are indexed. Permitted characters are: * ``zyx`` for spatial dimensions * ``c`` for channels * ``s`` for samples + + Returns: + List[str]: The axes of this dataset as a string of characters, as they are indexed. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.axes + ['z', 'y', 'x'] + Note: + This method must be implemented by the subclass. """ pass @property @abstractmethod def dims(self) -> int: - """Returns the number of spatial dimensions.""" + """ + Returns the number of spatial dimensions. + + Returns: + int: The number of spatial dimensions. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.dims + 3 + Note: + This method must be implemented by the subclass. + """ pass @property @abstractmethod def voxel_size(self) -> Coordinate: - """The size of a voxel in physical units.""" + """ + The size of a voxel in physical units. + + Returns: + Coordinate: The size of a voxel in physical units. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.voxel_size + Coordinate((1, 1, 1)) + Note: + This method must be implemented by the subclass. + """ pass @property @abstractmethod def roi(self) -> Roi: - """The total ROI of this array, in world units.""" + """ + The total ROI of this array, in world units. + + Returns: + Roi: The total ROI of this array, in world units. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.roi + Roi(offset=Coordinate((0, 0, 0)), shape=Coordinate((100, 100, 100))) + Note: + This method must be implemented by the subclass. + """ pass @property @abstractmethod def dtype(self) -> Any: - """The dtype of this array, in numpy dtypes""" + """ + The dtype of this array, in numpy dtypes + + Returns: + Any: The dtype of this array, in numpy dtypes. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.dtype + np.dtype('uint8') + Note: + This method must be implemented by the subclass. + """ pass @property @@ -57,6 +178,17 @@ def num_channels(self) -> Optional[int]: """ The number of channels provided by this dataset. Should return None if the channel dimension doesn't exist. + + Returns: + Optional[int]: The number of channels provided by this dataset. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.num_channels + 1 + Note: + This method must be implemented by the subclass. """ pass @@ -65,6 +197,17 @@ def num_channels(self) -> Optional[int]: def data(self) -> np.ndarray: """ Get a numpy like readable and writable view into this array. + + Returns: + np.ndarray: A numpy like readable and writable view into this array. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.data + np.ndarray + Note: + This method must be implemented by the subclass. """ pass @@ -73,10 +216,38 @@ def data(self) -> np.ndarray: def writable(self) -> bool: """ Can we write to this Array? + + Returns: + bool: Can we write to this Array? + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array.writable + False + Note: + This method must be implemented by the subclass. """ pass def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get a numpy like readable and writable view into this array. + + Args: + roi (Roi): The region of interest to fetch data from. + Returns: + np.ndarray: A numpy like readable and writable view into this array. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> roi = Roi(offset=Coordinate((0, 0, 0)), shape=Coordinate((100, 100, 100))) + >>> array[roi] + np.ndarray + Note: + This method must be implemented by the subclass. + """ if not self.roi.contains(roi): raise ValueError(f"Cannot fetch data from outside my roi: {self.roi}!") @@ -92,12 +263,53 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return self.data[slices] def _can_neuroglance(self) -> bool: + """ + Check if this array can be visualized in Neuroglancer. + + Returns: + bool: Whether this array can be visualized in Neuroglancer. + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array._can_neuroglance() + False + Note: + This method must be implemented by the subclass. + """ return False def _neuroglancer_layer(self): + """ + Generate a Neuroglancer layer for this array. + + Raises: + NotImplementedError: This method must be implemented by the subclass. + Examples: + >>> array = Array() + >>> array._neuroglancer_layer() + NotImplementedError + Note: + This method must be implemented by the subclass. + """ pass def _slices(self, roi: Roi) -> Iterable[slice]: + """ + Generate a list of slices for the given ROI. + + Args: + roi (Roi): The region of interest to generate slices for. + Returns: + Iterable[slice]: A list of slices for the given ROI. + Examples: + >>> array = Array() + >>> roi = Roi(offset=Coordinate((0, 0, 0)), shape=Coordinate((100, 100, 100))) + >>> array._slices(roi) + [slice(None, None, None), slice(None, None, None), slice(None, None, None)] + Note: + This method must be implemented by the subclass. + """ offset = (roi.offset - self.roi.offset) / self.voxel_size shape = roi.shape / self.voxel_size spatial_slices: Dict[str, slice] = { diff --git a/dacapo/experiments/datasplits/datasets/arrays/array_config.py b/dacapo/experiments/datasplits/datasets/arrays/array_config.py index 0642cbb52..a8e51dfd2 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/array_config.py @@ -5,9 +5,22 @@ @attr.s class ArrayConfig: - """Base class for array configurations. Each subclass of an + """ + Base class for array configurations. Each subclass of an `Array` should have a corresponding config class derived from - `ArrayConfig`. + `ArrayConfig`. This class should be used to store the configuration + of the array. + + Attributes: + name (str): A unique name for this array. This will be saved so you + and others can find and reuse this array. Keep it short + and avoid special characters. + Methods: + verify(self) -> Tuple[bool, str]: This method is used to check whether this is a valid Array. + Note: + This class is used to create a base class for array configurations. Each subclass of an + `Array` should have a corresponding config class derived from `ArrayConfig`. + This class should be used to store the configuration of the array. """ name: str = attr.ib( @@ -21,5 +34,18 @@ class ArrayConfig: def verify(self) -> Tuple[bool, str]: """ Check whether this is a valid Array + + Returns: + Tuple[bool, str]: A tuple with the first element being a boolean + indicating whether the array is valid and the second element being + a string with a message explaining why the array is invalid + Raises: + NotImplementedError: This method is not implemented in this class + Examples: + >>> array_config = ArrayConfig(name="array_config") + >>> array_config.verify() + (True, "No validation for this Array") + Note: + This method is used to check whether this is a valid Array. """ return True, "No validation for this Array" diff --git a/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py b/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py index 791c1051c..dc79fcae5 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/binarize_array.py @@ -13,19 +13,50 @@ class BinarizeArray(Array): Because we often want to predict classes that are a combination of a set of labels we wrap a ZarrArray with the BinarizeArray and provide something like `groupings=[("mito", [3,4,5])]` - where 4 corresponds to mito_membrane, 5 is mito_ribos, and - 3 is everything else that is part of a mitochondria. The BinarizeArray - will simply combine labels 3,4,5 into a single binary channel for th - class of "mito". + where 4 corresponds to mito_mem (mitochondria membrane), 5 is mito_ribo + (mitochondria ribosomes), and 3 is everything else that is part of a + mitochondria. The BinarizeArray will simply combine labels 3,4,5 into + a single binary channel for the class of "mito". + We use a single channel per class because some classes may overlap. For example if you had `groupings=[("mito", [3,4,5]), ("membrane", [4, 8, 1])]` - where 4 is mito_membrane, 8 is er_membrane, and 1 is plasma_membrane. + where 4 is mito_mem, 8 is er_mem (ER membrane), and 1 is pm (plasma membrane). Now you can have a binary classification for membrane or not which in some cases overlaps with the channel for mitochondria which includes the mito membrane. + + Attributes: + name (str): The name of the array. + source_array (Array): The source array to binarize. + background (int): The label to treat as background. + groupings (List[Tuple[str, List[int]]]): A list of tuples where the first + element is the name of the class and the second element is a list of + labels that should be combined into a single binary channel. + Methods: + __init__(self, array_config): This method initializes the BinarizeArray object. + __attrs_post_init__(self): This method is called after the instance has been initialized by the constructor. It is used to set the default_config to an instance of ArrayConfig if it is None. + __getitem__(self, roi: Roi) -> np.ndarray: This method returns the binary channels for the given region of interest. + _can_neuroglance(self): This method returns True if the source array can be visualized in neuroglance. + _neuroglancer_source(self): This method returns the source array for neuroglancer. + _neuroglancer_layer(self): This method returns the neuroglancer layer for the source array. + _source_name(self): This method returns the name of the source array. + Note: + This class is used to create a BinarizeArray object which is a wrapper around a ZarrArray containing uint annotations. """ def __init__(self, array_config): + """ + This method initializes the BinarizeArray object. + + Args: + array_config (ArrayConfig): The array configuration. + Raises: + AssertionError: If the source array has channels. + Examples: + >>> binarize_array = BinarizeArray(array_config) + Note: + This method is used to initialize the BinarizeArray object. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -40,38 +71,147 @@ def __init__(self, array_config): @property def attrs(self): + """ + This method returns the attributes of the source array. + + Returns: + Dict: The attributes of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.attrs + Note: + This method is used to return the attributes of the source array. + """ return self._source_array.attrs @property def axes(self): + """ + This method returns the axes of the source array. + + Returns: + List[str]: The axes of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.axes + Note: + This method is used to return the axes of the source array. + """ return ["c"] + self._source_array.axes @property def dims(self) -> int: + """ + This method returns the dimensions of the source array. + + Returns: + int: The dimensions of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.dims + Note: + This method is used to return the dimensions of the source array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + This method returns the voxel size of the source array. + + Returns: + Coordinate: The voxel size of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.voxel_size + Note: + This method is used to return the voxel size of the source array. + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + This method returns the region of interest of the source array. + + Returns: + Roi: The region of interest of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.roi + Note: + This method is used to return the region of interest of the source array. + """ return self._source_array.roi @property def writable(self) -> bool: + """ + This method returns True if the source array is writable. + + Returns: + bool: True if the source array is writable. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.writable + Note: + This method is used to return True if the source array is writable. + """ return False @property def dtype(self): + """ + This method returns the data type of the source array. + + Returns: + np.dtype: The data type of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.dtype + Note: + This method is used to return the data type of the source array. + """ return np.uint8 @property def num_channels(self) -> int: + """ + This method returns the number of channels in the source array. + + Returns: + int: The number of channels in the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.num_channels + Note: + This method is used to return the number of channels in the source array. + + """ return len(self._groupings) @property def data(self): + """ + This method returns the data of the source array. + + Returns: + np.ndarray: The data of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.data + Note: + This method is used to return the data of the source array. + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -79,9 +219,35 @@ def data(self): @property def channels(self): + """ + This method returns the channel names of the source array. + + Returns: + Iterator[str]: The channel names of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array.channels + Note: + This method is used to return the channel names of the source array. + """ return (name for name, _ in self._groupings) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + This method returns the binary channels for the given region of interest. + + Args: + roi (Roi): The region of interest. + Returns: + np.ndarray: The binary channels for the given region of interest. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array[roi] + Note: + This method is used to return the binary channels for the given region of interest. + """ labels = self._source_array[roi] grouped = np.zeros((len(self._groupings), *labels.shape), dtype=np.uint8) for i, (_, ids) in enumerate(self._groupings): @@ -92,14 +258,62 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return grouped def _can_neuroglance(self): + """ + This method returns True if the source array can be visualized in neuroglance. + + Returns: + bool: True if the source array can be visualized in neuroglance. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._can_neuroglance() + Note: + This method is used to return True if the source array can be visualized in neuroglance. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + This method returns the source array for neuroglancer. + + Returns: + neuroglancer.LocalVolume: The source array for neuroglancer. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._neuroglancer_source() + Note: + This method is used to return the source array for neuroglancer. + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + This method returns the neuroglancer layer for the source array. + + Returns: + neuroglancer.SegmentationLayer: The neuroglancer layer for the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._neuroglancer_layer() + Note: + This method is used to return the neuroglancer layer for the source array. + """ layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) return layer def _source_name(self): + """ + This method returns the name of the source array. + + Returns: + str: The name of the source array. + Raises: + ValueError: If the source array is not writable. + Examples: + >>> binarize_array._source_name() + Note: + This method is used to return the name of the source array. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py index 62f4c4da6..195c9eb16 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/binarize_array_config.py @@ -8,8 +8,21 @@ @attr.s class BinarizeArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for turning an Annotated dataset into a - multi class binary classification problem""" + """ + This config class provides the necessary configuration for turning an Annotated dataset into a + multi class binary classification problem. Each class will be binarized into a separate channel. + + Attributes: + source_array_config (ArrayConfig): The Array from which to pull annotated data. Is expected to contain a volume with uint64 voxels and no channel dimension + groupings (List[Tuple[str, List[int]]]): List of id groups with a symantic name. Each id group is a List of ids. + Group i found in groupings[i] will be binarized and placed in channel i. + An empty group will binarize all non background labels. + background (int): The id considered background. Will never be binarized to 1, defaults to 0. + Note: + This class is used to create a BinarizeArray object which is used to turn an Annotated dataset into a multi class binary classification problem. + Each class will be binarized into a separate channel. + + """ array_type = BinarizeArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py index 37cf650f6..2cea77a00 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array.py @@ -11,10 +11,53 @@ class ConcatArray(Array): - """This is a wrapper around other `source_arrays` that concatenates - them along the channel dimension.""" + """ + This is a wrapper around other `source_arrays` that concatenates + them along the channel dimension. The `source_arrays` are expected + to have the same shape and ROI, but can have different data types. + + Attributes: + name: The name of the array. + channels: The list of channel names. + source_arrays: A dictionary mapping channel names to source arrays. + default_array: An optional default array to use for channels that are + not present in `source_arrays`. + Methods: + from_toml(cls, toml_path: str) -> ConcatArrayConfig: + Load the ConcatArrayConfig from a TOML file + to_toml(self, toml_path: str) -> None: + Save the ConcatArrayConfig to a TOML file + create_array(self) -> ConcatArray: + Create the ConcatArray from the config + Note: + This class is a subclass of Array and inherits all its attributes + and methods. The only difference is that the array_type is ConcatArray. + + """ def __init__(self, array_config): + """ + Initialize the ConcatArray from a ConcatArrayConfig. + + Args: + array_config (ConcatArrayConfig): The config to create the ConcatArray from. + Raises: + AssertionError: If the source arrays have different shapes or ROIs. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + Note: + The `source_arrays` are expected to have the same shape and ROI, + but can have different data types. + """ self.name = array_config.name self.channels = array_config.channels self.source_arrays = { @@ -29,14 +72,82 @@ def __init__(self, array_config): @property def attrs(self): + """ + Return the attributes of the ConcatArray as a dictionary. + + Returns: + Dict[str, Any]: The attributes of the ConcatArray. + Raises: + AssertionError: If the source arrays have different attributes. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.attrs + {'axes': 'cxyz', 'roi': Roi(...), 'voxel_size': (1, 1, 1)} + Note: + The `source_arrays` are expected to have the same attributes. + """ return dict() @property def source_arrays(self) -> Dict[str, Array]: + """ + Return the source arrays of the ConcatArray. + + Returns: + Dict[str, Array]: The source arrays of the ConcatArray. + Raises: + AssertionError: If the source arrays are empty. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.source_arrays + {'A': Array(...), 'B': Array(...)} + Note: + The `source_arrays` are expected to have the same shape and ROI. + """ return self._source_arrays @source_arrays.setter def source_arrays(self, value: Dict[str, Array]): + """ + Set the source arrays of the ConcatArray. + + Args: + value (Dict[str, Array]): The source arrays to set. + Raises: + AssertionError: If the source arrays are empty. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.source_arrays = {'A': Array(...), 'B': Array(...)} + Note: + The `source_arrays` are expected to have the same shape and ROI. + """ assert len(value) > 0, "Source arrays is empty!" self._source_arrays = value attrs: Dict[str, Any] = {} @@ -58,10 +169,56 @@ def source_arrays(self, value: Dict[str, Array]): @property def source_array(self) -> Array: + """ + Return the source array of the ConcatArray. + + Returns: + Array: The source array of the ConcatArray. + Raises: + AssertionError: If the source array is None. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.source_array + Array(...) + Note: + The `source_array` is expected to have the same shape and ROI. + """ return self._source_array @property def axes(self): + """ + Return the axes of the ConcatArray. + + Returns: + str: The axes of the ConcatArray. + Raises: + AssertionError: If the source arrays have different axes. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.axes + 'cxyz' + Note: + The `source_arrays` are expected to have the same axes. + """ source_axes = self.source_array.axes if "c" not in source_axes: source_axes = ["c"] + source_axes @@ -69,33 +226,210 @@ def axes(self): @property def dims(self): + """ + Return the dimensions of the ConcatArray. + + Returns: + Tuple[int]: The dimensions of the ConcatArray. + Raises: + AssertionError: If the source arrays have different dimensions. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.dims + (2, 100, 100, 100) + Note: + The `source_arrays` are expected to have the same dimensions. + """ return self.source_array.dims @property def voxel_size(self): + """ + Return the voxel size of the ConcatArray. + + Returns: + Tuple[float]: The voxel size of the ConcatArray. + Raises: + AssertionError: If the source arrays have different voxel sizes. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.voxel_size + (1, 1, 1) + Note: + The `source_arrays` are expected to have the same voxel size. + """ return self.source_array.voxel_size @property def roi(self): + """ + Return the ROI of the ConcatArray. + + Returns: + Roi: The ROI of the ConcatArray. + Raises: + AssertionError: If the source arrays have different ROIs. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.roi + Roi(...) + Note: + The `source_arrays` are expected to have the same ROI. + """ return self.source_array.roi @property def writable(self) -> bool: + """ + Return whether the ConcatArray is writable. + + Returns: + bool: Whether the ConcatArray is writable. + Raises: + AssertionError: If the ConcatArray is writable. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.writable + False + Note: + The ConcatArray is not writable. + """ return False @property def data(self): + """ + Return the data of the ConcatArray. + + Returns: + np.ndarray: The data of the ConcatArray. + Raises: + RuntimeError: If the ConcatArray is not writable. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.data + np.ndarray(...) + Note: + The ConcatArray is not writable. + """ raise RuntimeError("Cannot get writable version of this data!") @property def dtype(self): + """ + Return the data type of the ConcatArray. + + Returns: + np.dtype: The data type of the ConcatArray. + Raises: + AssertionError: If the source arrays have different data types. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.dtype + np.float32 + Note: + The `source_arrays` are expected to have the same data type. + """ return self.source_array.dtype @property def num_channels(self): + """ + Return the number of channels of the ConcatArray. + + Returns: + int: The number of channels of the ConcatArray. + Raises: + AssertionError: If the source arrays have different numbers of channels. + Examples: + >>> config = ConcatArrayConfig( + ... name="my_concat_array", + ... channels=["A", "B"], + ... source_array_configs={ + ... "A": ArrayConfig(...), + ... "B": ArrayConfig(...), + ... }, + ... default_config=ArrayConfig(...), + ... ) + >>> array = ConcatArray(config) + >>> array.num_channels + 2 + Note: + The `source_arrays` are expected to have the same number of channels. + """ return len(self.channels) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Return the data of the ConcatArray for a given ROI. + + Args: + roi (Roi): The ROI to get the data for. + Returns: + np.ndarray: The data of the ConcatArray for the given ROI. + Raises: + AssertionError: If the source arrays have different shapes or ROIs. + Examples: + >>> roi = Roi(...) + >>> array[roi] + np.ndarray(...) + Note: + The `source_arrays` are expected to have the same shape and ROI. + """ default = ( np.zeros_like(self.source_array[roi]) if self.default_array is None diff --git a/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py index ca76c167b..cc734f70b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/concat_array_config.py @@ -8,7 +8,20 @@ @attr.s class ConcatArrayConfig(ArrayConfig): - """This array read data from the source array and then return a np.ones_like() version.""" + """ + This array read data from the source array and then return a np.ones_like() version of the data. + + Attributes: + channels (List[str]): An ordering for the source_arrays. + source_array_configs (Dict[str, ArrayConfig]): A mapping from channels to array_configs. If a channel has no ArrayConfig it will be filled with zeros + default_config (Optional[ArrayConfig]): An optional array providing the default array per channel. If not provided, missing channels will simply be filled with 0s + Methods: + __attrs_post_init__(self): This method is called after the instance has been initialized by the constructor. It is used to set the default_config to an instance of ArrayConfig if it is None. + get_array(self, source_arrays: Dict[str, np.ndarray]) -> np.ndarray: This method reads data from the source array and then return a np.ones_like() version of the data. + Note: + This class is used to create a ConcatArray object which is used to read data from the source array and then return a np.ones_like() version of the data. + The source array is a dictionary with the key being the channel and the value being the array. + """ array_type = ConcatArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/crop_array.py b/dacapo/experiments/datasplits/datasets/arrays/crop_array.py index 04b163513..96bdad0fd 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/crop_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/crop_array.py @@ -7,10 +7,64 @@ class CropArray(Array): """ - Used to crop a larger array to a smaller array. + Used to crop a larger array to a smaller array. This is useful when you + want to work with a subset of a larger array, but don't want to copy the + data. The crop is done on demand, so the data is not copied until you + actually access it. + + Attributes: + name: The name of the array. + source_array: The array to crop. + crop_roi: The region of interest to crop to. + Methods: + attrs: Returns the attributes of the source array. + axes: Returns the axes of the source array. + dims: Returns the number of dimensions of the source array. + voxel_size: Returns the voxel size of the source array. + roi: Returns the region of interest of the source array. + writable: Returns whether the array is writable. + dtype: Returns the data type of the source array. + num_channels: Returns the number of channels of the source array. + data: Returns the data of the source array. + channels: Returns the channels of the source array. + __getitem__(roi): Returns the data of the source array within the + region of interest. + _can_neuroglance(): Returns whether the source array can be viewed in + Neuroglancer. + _neuroglancer_source(): Returns the source of the source array for + Neuroglancer. + _neuroglancer_layer(): Returns the layer of the source array for + Neuroglancer. + _source_name(): Returns the name of the source array. + Note: + This class is a subclass of Array. + + """ def __init__(self, array_config): + """ + Initializes the CropArray. + + Args: + array_config: The configuration of the array to crop. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + Note: + The source array configuration must be an instance of ArrayConfig. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -19,38 +73,265 @@ def __init__(self, array_config): @property def attrs(self): + """ + Returns the attributes of the source array. + + Returns: + The attributes of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.attrs + {} + Note: + The attributes are empty because the source array is not modified. + """ return self._source_array.attrs @property def axes(self): + """ + Returns the axes of the source array. + + Returns: + The axes of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.axes + 'zyx' + Note: + The axes are 'zyx' because the source array is not modified. + """ return self._source_array.axes @property def dims(self) -> int: + """ + Returns the number of dimensions of the source array. + + Returns: + The number of dimensions of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.dims + 3 + Note: + The number of dimensions is 3 because the source array is not + modified. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the source array. + + Returns: + The voxel size of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Note: + The voxel size is (1.0, 1.0, 1.0) because the source array is not + modified. + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the source array. + + Returns: + The region of interest of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.roi + Roi(offset=(0, 0, 0), shape=(10, 10, 10)) + Note: + The region of interest is (0, 0, 0) with shape (10, 10, 10) + because the source array is not modified. + """ return self.crop_roi.intersect(self._source_array.roi) @property def writable(self) -> bool: + """ + Returns whether the array is writable. + + Returns: + False + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.writable + False + Note: + The array is not writable because it is a virtual array created by + modifying another array on demand. + """ return False @property def dtype(self): + """ + Returns the data type of the source array. + + Returns: + The data type of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.dtype + numpy.dtype('uint8') + Note: + The data type is uint8 because the source array is not modified. + """ return self._source_array.dtype @property def num_channels(self) -> int: + """ + Returns the number of channels of the source array. + + Returns: + The number of channels of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.num_channels + 1 + Note: + The number of channels is 1 because the source array is not + modified. + """ return self._source_array.num_channels @property def data(self): + """ + Returns the data of the source array. + + Returns: + The data of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.data + array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0, 0, 0, 0, 0 + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -58,20 +339,170 @@ def data(self): @property def channels(self): + """ + Returns the channels of the source array. + + Returns: + The channels of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array.channels + 1 + Note: + The channels is 1 because the source array is not modified. + """ return self._source_array.channels def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the data of the source array within the region of interest. + + Args: + roi: The region of interest. + Returns: + The data of the source array within the region of interest. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> array_config = ArrayConfig( + ... name='array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array[Roi((0, 0, 0), (5, 5, 5))] + array([[[ + Note: + The data is the same as the source array because the source array + is not modified. + """ assert self.roi.contains(roi) return self._source_array[roi] def _can_neuroglance(self): + """ + Returns whether the source array can be viewed in Neuroglancer. + + Returns: + Whether the source array can be viewed in Neuroglancer. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._can_neuroglance() + False + Note: + The source array cannot be viewed in Neuroglancer because the + source array is not modified. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Returns the source of the source array for Neuroglancer. + + Returns: + The source of the source array for Neuroglancer. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._neuroglancer_source() + {'source': 'source_array'} + Note: + The source is the source array because the source array is not + modified. + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Returns the layer of the source array for Neuroglancer. + + Returns: + The layer of the source array for Neuroglancer. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._neuroglancer_layer() + {'source': 'source_array', 'type': 'image'} + Note: + The layer is an image because the source array is not modified. + """ return self._source_array._neuroglancer_layer() def _source_name(self): + """ + Returns the name of the source array. + + Returns: + The name of the source array. + Raises: + ValueError: If the region of interest to crop to is not within the + region of interest of the source array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import CropArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array_config = ArrayConfig( + ... name='source_array', + ... source_array_config=source_array_config, + ... roi=Roi((0, 0, 0), (10, 10, 10)) + ... ) + >>> crop_array = CropArray(array_config) + >>> crop_array._source_name() + 'source_array' + Note: + The name is the source array because the source array is not + modified. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py index 0a8d885fd..899120e90 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/crop_array_config.py @@ -8,9 +8,26 @@ @attr.s class CropArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for cropping an + """ + This config class provides the necessary configuration for cropping an Array to a smaller ROI. Especially useful for validation volumes that may - be too large for quick evaluation""" + be too large for quick evaluation. The ROI is specified in the config. The + cropped Array will have the same dtype as the source Array. + + Attributes: + source_array_config (ArrayConfig): The Array to crop + roi (Roi): The ROI for cropping + Methods: + from_toml(cls, toml_path: str) -> CropArrayConfig: + Load the CropArrayConfig from a TOML file + to_toml(self, toml_path: str) -> None: + Save the CropArrayConfig to a TOML file + create_array(self) -> CropArray: + Create the CropArray from the config + Note: + This class is a subclass of ArrayConfig and inherits all its attributes + and methods. The only difference is that the array_type is CropArray. + """ array_type = CropArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py b/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py index 8e3ce3daa..3d23ebf05 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dummy_array.py @@ -6,44 +6,188 @@ class DummyArray(Array): - """This is just a dummy array for testing.""" + """ + This is just a dummy array for testing. It has a shape of (100, 50, 50) and is filled with zeros. + + Attributes: + array_config (ArrayConfig): The config object for the array + Methods: + __getitem__: Returns the intensities normalized to the range (0, 1) + Notes: + The array_config must be an ArrayConfig object. + The min and max values are used to normalize the intensities. + All intensities are converted to float32. + + """ def __init__(self, array_config): + """ + Initializes the IntensitiesArray object + + Args: + array_config (ArrayConfig): The config object for the array + Raises: + ValueError: If the array_config is not an ArrayConfig object + Examples: + >>> array_config = ArrayConfig(...) + >>> intensities_array = IntensitiesArray(array_config) + Notes: + The array_config must be an ArrayConfig object. + """ super().__init__() self._data = np.zeros((100, 50, 50)) @property def attrs(self): + """ + Returns the attributes of the source array + + Returns: + dict: The attributes of the source array + Raises: + ValueError: If the attributes is not a dictionary + Examples: + >>> intensities_array.attrs + {'resolution': (1.0, 1.0, 1.0), 'unit': 'micrometer'} + """ return dict() @property def axes(self): + """ + Returns the axes of the source array + + Returns: + str: The axes of the source array + Raises: + ValueError: If the axes is not a string + Examples: + >>> intensities_array.axes + 'zyx' + Notes: + The axes are the same as the source array + """ return ["z", "y", "x"] @property def dims(self): + """ + Returns the number of dimensions of the source array + + Returns: + int: The number of dimensions of the source array + Raises: + ValueError: If the dims is not an integer + Examples: + >>> intensities_array.dims + 3 + Notes: + The dims are the same as the source array + """ return 3 @property def voxel_size(self): + """ + Returns the voxel size of the source array + + Returns: + Coordinate: The voxel size of the source array + Raises: + ValueError: If the voxel size is not a Coordinate object + Examples: + >>> intensities_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel size is the same as the source array + """ return Coordinate(1, 2, 2) @property def roi(self): + """ + Returns the region of interest of the source array + + Returns: + Roi: The region of interest of the source array + Raises: + ValueError: If the roi is not a Roi object + Examples: + >>> intensities_array.roi + Roi(offset=(0, 0, 0), shape=(100, 100, 100)) + Notes: + The roi is the same as the source array + """ return Roi((0, 0, 0), (100, 100, 100)) @property def writable(self) -> bool: + """ + Returns whether the array is writable + + Returns: + bool: Whether the array is writable + Examples: + >>> intensities_array.writable + True + Notes: + The array is always writable + """ return True @property def data(self): + """ + Returns the data of the source array + + Returns: + np.ndarray: The data of the source array + Raises: + ValueError: If the data is not a numpy array + Examples: + >>> intensities_array.data + array([[[0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + ..., + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.], + [0., 0., 0., ..., 0., 0., 0.]], + Notes: + The data is the same as the source array + """ return self._data @property def dtype(self): + """ + Returns the data type of the array + + Returns: + type: The data type of the array + Raises: + ValueError: If the data type is not a type + Examples: + >>> intensities_array.dtype + numpy.float32 + Notes: + The data type is the same as the source array + """ return self._data.dtype @property def num_channels(self): + """ + Returns the number of channels in the source array + + Returns: + int: The number of channels in the source array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> intensities_array.num_channels + 1 + Notes: + The number of channels is the same as the source array + """ return None diff --git a/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py index fba67ec51..44632ae2b 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dummy_array_config.py @@ -8,10 +8,34 @@ @attr.s class DummyArrayConfig(ArrayConfig): - """This is just a dummy array config used for testing. None of the - attributes have any particular meaning.""" + """ + This is just a dummy array config used for testing. None of the + attributes have any particular meaning. It is used to test the + ArrayConfig class. + + Methods: + to_array: Returns the DummyArray object + verify: Returns whether the DummyArrayConfig is valid + Notes: + The source_array_config must be an ArrayConfig object. + + """ array_type = DummyArray def verify(self) -> Tuple[bool, str]: + """ + Check whether this is a valid Array + + Returns: + Tuple[bool, str]: Whether the Array is valid and a message + Raises: + ValueError: If the source is not a tuple of strings + Examples: + >>> dummy_array_config = DummyArrayConfig(...) + >>> dummy_array_config.verify() + (False, "This is a DummyArrayConfig and is never valid") + Notes: + The source must be a tuple of strings. + """ return False, "This is a DummyArrayConfig and is never valid" diff --git a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py index e08ffe562..b6abc29e1 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dvid_array.py @@ -22,41 +22,180 @@ class DVIDArray(Array): - """This is a DVID array""" + """ + This is a DVID array. It is a wrapper around a DVID array that provides + the necessary methods to interact with the array. It is used to fetch data + from a DVID server. The source is a tuple of three strings: the server, the UUID, + and the data name. + + DVID: data management system for terabyte-sized 3D images + + Attributes: + name (str): The name of the array + source (tuple[str, str, str]): The source of the array + Methods: + __getitem__: Returns the data from the array for a given region of interest + Notes: + The source is a tuple of three strings: the server, the UUID, and the data name. + """ def __init__(self, array_config): + """ + Initializes the DVIDArray object + + Args: + array_config (ArrayConfig): The config object for the array + Returns: + DVIDArray: The DVIDArray object + Raises: + ValueError: If the array_config is not an ArrayConfig object + Examples: + >>> array_config = ArrayConfig(...) + >>> dvid_array = DVIDArray(array_config) + Notes: + The array_config must be an ArrayConfig object. + + """ super().__init__() self.name: str = array_config.name self.source: tuple[str, str, str] = array_config.source def __str__(self): + """ + Returns the string representation of the DVIDArray object + + Returns: + str: The string representation of the DVIDArray object + Raises: + ValueError: If the source is not a tuple of three strings + Examples: + >>> str(dvid_array) + DVIDArray(('server', 'UUID', 'data_name')) + Notes: + The string representation is the source of the array + """ return f"DVIDArray({self.source})" def __repr__(self): + """ + Returns the string representation of the DVIDArray object + + Returns: + str: The string representation of the DVIDArray object + Raises: + ValueError: If the source is not a tuple of three strings + Examples: + >>> repr(dvid_array) + DVIDArray(('server', 'UUID', 'data_name')) + Notes: + The string representation is the source of the array + """ return f"DVIDArray({self.source})" @lazy_property.LazyProperty def attrs(self): + """ + Returns the attributes of the DVID array + + Returns: + dict: The attributes of the DVID array + Raises: + ValueError: If the attributes is not a dictionary + Examples: + >>> dvid_array.attrs + {'Extended': {'VoxelSize': (1.0, 1.0, 1.0), 'Values': [{'DataType': 'uint64'}]}, 'Extents': {'MinPoint': (0, 0, 0), 'MaxPoint': (100, 100, 100)}} + Notes: + The attributes are the same as the source array + """ return fetch_info(*self.source) @property def axes(self): + """ + Returns the axes of the DVID array + + Returns: + str: The axes of the DVID array + Raises: + ValueError: If the axes is not a string + Examples: + >>> dvid_array.axes + 'zyx' + Notes: + The axes are the same as the source array + """ return ["c", "z", "y", "x"][-self.dims :] @property def dims(self) -> int: + """ + Returns the dimensions of the DVID array + + Returns: + int: The dimensions of the DVID array + Raises: + ValueError: If the dimensions is not an integer + Examples: + >>> dvid_array.dims + 3 + Notes: + The dimensions are the same as the source array + """ return self.voxel_size.dims @lazy_property.LazyProperty def _daisy_array(self) -> funlib.persistence.Array: + """ + Returns the DVID array as a Daisy array + + Returns: + funlib.persistence.Array: The DVID array as a Daisy array + Raises: + ValueError: If the DVID array is not a Daisy array + Examples: + >>> dvid_array._daisy_array + Array(...) + Notes: + The DVID array is a Daisy array + """ raise NotImplementedError() @lazy_property.LazyProperty def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the DVID array + + Returns: + Coordinate: The voxel size of the DVID array + Raises: + ValueError: If the voxel size is not a Coordinate object + Examples: + >>> dvid_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel size is the same as the source array + """ return Coordinate(self.attrs["Extended"]["VoxelSize"]) @lazy_property.LazyProperty def roi(self) -> Roi: + """ + Returns the region of interest of the DVID array + + Returns: + Roi: The region of interest of the DVID array + Raises: + ValueError: If the region of interest is not a Roi object + Examples: + >>> dvid_array.roi + Roi(...) + Notes: + The region of interest is the same as the source array + """ + return Roi( + Coordinate(self.attrs["Extents"]["MinPoint"]) * self.voxel_size, + Coordinate(self.attrs["Extents"]["MaxPoint"]) * self.voxel_size, + ) return Roi( Coordinate(self.attrs["Extents"]["MinPoint"]) * self.voxel_size, Coordinate(self.attrs["Extents"]["MaxPoint"]) * self.voxel_size, @@ -64,25 +203,105 @@ def roi(self) -> Roi: @property def writable(self) -> bool: + """ + Returns whether the DVID array is writable + + Returns: + bool: Whether the DVID array is writable + Raises: + ValueError: If the writable is not a boolean + Examples: + >>> dvid_array.writable + False + Notes: + The writable is the same as the source array + """ return False @property def dtype(self) -> Any: + """ + Returns the data type of the DVID array + + Returns: + type: The data type of the DVID array + Raises: + ValueError: If the data type is not a type + Examples: + >>> dvid_array.dtype + numpy.uint64 + Notes: + The data type is the same as the source array + """ return np.dtype(self.attrs["Extended"]["Values"][0]["DataType"]) @property def num_channels(self) -> Optional[int]: + """ + Returns the number of channels of the DVID array + + Returns: + int: The number of channels of the DVID array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> dvid_array.num_channels + 1 + Notes: + The number of channels is the same as the source array + """ return None @property def spatial_axes(self) -> List[str]: + """ + Returns the spatial axes of the DVID array + + Returns: + List[str]: The spatial axes of the DVID array + Raises: + ValueError: If the spatial axes is not a list + Examples: + >>> dvid_array.spatial_axes + ['z', 'y', 'x'] + Notes: + The spatial axes are the same as the source array + """ return [ax for ax in self.axes if ax not in set(["c", "b"])] @property def data(self) -> Any: + """ + Returns the number of channels of the DVID array + + Returns: + int: The number of channels of the DVID array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> dvid_array.num_channels + 1 + Notes: + The number of channels is the same as the source array + """ raise NotImplementedError() def __getitem__(self, roi: Roi) -> np.ndarray[Any, Any]: + """ + Returns the data of the DVID array for a given region of interest + + Args: + roi (Roi): The region of interest for which to get the data + Returns: + np.ndarray: The data of the DVID array for the region of interest + Raises: + ValueError: If the data is not a numpy array + Examples: + >>> dvid_array[roi] + array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]) + Notes: + The data is the same as the source array + """ box = np.array( (roi.offset / self.voxel_size, (roi.offset + roi.shape) / self.voxel_size) ) @@ -95,22 +314,114 @@ def __getitem__(self, roi: Roi) -> np.ndarray[Any, Any]: return data def _can_neuroglance(self) -> bool: + """ + Returns whether the DVID array can be used with neuroglance + + Returns: + bool: Whether the DVID array can be used with neuroglance + Raises: + ValueError: If the DVID array cannot be used with neuroglance + Examples: + >>> dvid_array._can_neuroglance() + True + Notes: + The DVID array can be used with neuroglance + """ return True def _neuroglancer_source(self): + """ + Returns the neuroglancer source of the DVID array + + Returns: + Tuple[str, str, str]: The neuroglancer source of the DVID array + Raises: + ValueError: If the neuroglancer source is not a tuple of three strings + Examples: + >>> dvid_array._neuroglancer_source() + ('server', 'UUID', 'data_name') + Notes: + The neuroglancer source is the same as the source array + """ raise NotImplementedError() def _neuroglancer_layer(self) -> Tuple[neuroglancer.ImageLayer, Dict[str, Any]]: + """ + Returns the neuroglancer layer of the DVID array + + Returns: + Tuple[neuroglancer.ImageLayer, dict]: The neuroglancer layer of the DVID array + Raises: + ValueError: If the neuroglancer layer is not a tuple of an ImageLayer and a dictionary + Examples: + >>> dvid_array._neuroglancer_layer() + (ImageLayer(...), {}) + Notes: + The neuroglancer layer is the same as the source array + """ raise NotImplementedError() def _transform_matrix(self): + """ + Returns the transformation matrix of the DVID array + + Returns: + np.ndarray: The transformation matrix of the DVID array + Raises: + ValueError: If the transformation matrix is not a numpy array + Examples: + >>> dvid_array._transform_matrix() + array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + Notes: + The transformation matrix is the same as the source array + """ raise NotImplementedError() def _output_dimensions(self) -> Dict[str, Tuple[float, str]]: + """ + Returns the output dimensions of the DVID array + + Returns: + dict: The output dimensions of the DVID array + Raises: + ValueError: If the output dimensions is not a dictionary + Examples: + >>> dvid_array._output_dimensions() + {'z': (100, 'nm'), 'y': (100, 'nm'), 'x': (100, 'nm')} + Notes: + The output dimensions are the same as the source array + """ raise NotImplementedError() def _source_name(self) -> str: + """ + Returns the source name of the DVID array + + Returns: + str: The source name of the DVID array + Raises: + ValueError: If the source name is not a string + Examples: + >>> dvid_array._source_name() + 'data_name' + Notes: + The source name is the same as the source array + """ raise NotImplementedError() def add_metadata(self, metadata: Dict[str, Any]) -> None: + """ + Adds metadata to the DVID array + + Args: + metadata (dict): The metadata to add to the DVID array + Returns: + None + Raises: + ValueError: If the metadata is not a dictionary + Examples: + >>> dvid_array.add_metadata({'description': 'This is a DVID array'}) + Notes: + The metadata is added to the source array + """ raise NotImplementedError() diff --git a/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py index d9c5071c0..db63e2750 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/dvid_array_config.py @@ -9,7 +9,17 @@ @attr.s class DVIDArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for a DVID array""" + """ + This config class provides the necessary configuration for a DVID array. It takes a source string and returns the DVIDArray object. + + Attributes: + source (Tuple[str, str, str]): The source strings + Methods: + to_array: Returns the DVIDArray object + Notes: + The source must be a tuple of strings. + + """ array_type = DVIDArray @@ -20,5 +30,16 @@ class DVIDArrayConfig(ArrayConfig): def verify(self) -> Tuple[bool, str]: """ Check whether this is a valid Array + + Returns: + Tuple[bool, str]: Whether the Array is valid and a message + Raises: + ValueError: If the source is not a tuple of strings + Examples: + >>> dvid_array_config = DVIDArrayConfig(...) + >>> dvid_array_config.verify() + (True, "No validation for this Array") + Notes: + The source must be a tuple of strings. """ return True, "No validation for this Array" diff --git a/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py b/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py index 9840cddd9..7c1365106 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/intensity_array.py @@ -11,9 +11,33 @@ class IntensitiesArray(Array): the range (0, 1) and convert to float32. Use this if you have your intensities stored as uint8 or similar and want your model to have floats as input. + + Attributes: + array_config (ArrayConfig): The config object for the array + min (float): The minimum intensity value in the array + max (float): The maximum intensity value in the array + Methods: + __getitem__: Returns the intensities normalized to the range (0, 1) + Notes: + The array_config must be an ArrayConfig object. + The min and max values are used to normalize the intensities. + All intensities are converted to float32. """ def __init__(self, array_config): + """ + Initializes the IntensitiesArray object + + Args: + array_config (ArrayConfig): The config object for the array + Raises: + ValueError: If the array_config is not an ArrayConfig object + Examples: + >>> array_config = ArrayConfig(...) + >>> intensities_array = IntensitiesArray(array_config) + Notes: + The array_config must be an ArrayConfig object. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -24,44 +48,176 @@ def __init__(self, array_config): @property def attrs(self): + """ + Returns the attributes of the source array + + Returns: + dict: The attributes of the source array + Raises: + ValueError: If the attributes is not a dictionary + Examples: + >>> intensities_array.attrs + {'resolution': (1.0, 1.0, 1.0), 'unit': 'micrometer'} + Notes: + The attributes are the same as the source array + """ return self._source_array.attrs @property def axes(self): + """ + Returns the axes of the source array + + Returns: + str: The axes of the source array + Raises: + ValueError: If the axes is not a string + Examples: + >>> intensities_array.axes + 'zyx' + Notes: + The axes are the same as the source array + """ return self._source_array.axes @property def dims(self) -> int: + """ + Returns the dimensions of the source array + + Returns: + int: The dimensions of the source array + Raises: + ValueError: If the dimensions is not an integer + Examples: + >>> intensities_array.dims + 3 + Notes: + The dimensions are the same as the source array + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the source array + + Returns: + Coordinate: The voxel size of the source array + Raises: + ValueError: If the voxel size is not a Coordinate object + Examples: + >>> intensities_array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel size is the same as the source array + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the source array + + Returns: + Roi: The region of interest of the source array + Raises: + ValueError: If the region of interest is not a Roi object + Examples: + >>> intensities_array.roi + Roi(offset=(0, 0, 0), shape=(10, 20, 30)) + Notes: + The region of interest is the same as the source array + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Returns whether the array is writable + + Returns: + bool: Whether the array is writable + Raises: + ValueError: If the array is not writable + Examples: + >>> intensities_array.writable + False + Notes: + The array is not writable because it is a virtual array created by modifying another array on demand. + """ return False @property def dtype(self): + """ + Returns the data type of the array + + Returns: + type: The data type of the array + Raises: + ValueError: If the data type is not a type + Examples: + >>> intensities_array.dtype + numpy.float32 + Notes: + The data type is always float32 + """ return np.float32 @property def num_channels(self) -> int: + """ + Returns the number of channels in the source array + + Returns: + int: The number of channels in the source array + Raises: + ValueError: If the number of channels is not an integer + Examples: + >>> intensities_array.num_channels + 3 + Notes: + The number of channels is the same as the source array + """ return self._source_array.num_channels @property def data(self): + """ + Returns the data of the source array + + Returns: + np.ndarray: The data of the source array + Raises: + ValueError: If the data is not a numpy array + Examples: + >>> intensities_array.data + array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]) + Notes: + The data is the same as the source array + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." ) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the intensities normalized to the range (0, 1) + + Args: + roi (Roi): The region of interest to get the intensities from + Returns: + np.ndarray: The intensities normalized to the range (0, 1) + Raises: + ValueError: If the intensities are not in the range (0, 1) + Examples: + >>> intensities_array[roi] + array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [[0.7, 0.8, 0.9], [1.0, 1.1, 1.2]]]) + Notes: + The intensities are normalized to the range (0, 1) + """ intensities = self._source_array[roi] normalized = (intensities.astype(np.float32) - self._min) / ( self._max - self._min @@ -69,13 +225,66 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return normalized def _can_neuroglance(self): + """ + Returns whether the array can be visualized with Neuroglancer + + Returns: + bool: Whether the array can be visualized with Neuroglancer + Raises: + ValueError: If the array cannot be visualized with Neuroglancer + Examples: + >>> intensities_array._can_neuroglance() + True + Notes: + The array can be visualized with Neuroglancer if the source array can be visualized with Neuroglancer + + """ return self._source_array._can_neuroglance() def _neuroglancer_layer(self): + """ + Returns the Neuroglancer layer of the source array + + Returns: + dict: The Neuroglancer layer of the source array + Raises: + ValueError: If the Neuroglancer layer is not a dictionary + Examples: + >>> intensities_array._neuroglancer_layer() + {'type': 'image', 'source': 'precomputed://https://mybucket.s3.amazonaws.com/mydata'} + Notes: + The Neuroglancer layer is the same as the source array + """ return self._source_array._neuroglancer_layer() def _source_name(self): + """ + Returns the name of the source array + + Returns: + str: The name of the source array + Raises: + ValueError: If the name is not a string + Examples: + >>> intensities_array._source_name() + 'mydata' + Notes: + The name is the same as the source array + """ return self._source_array._source_name() def _neuroglancer_source(self): + """ + Returns the Neuroglancer source of the source array + + Returns: + str: The Neuroglancer source of the source array + Raises: + ValueError: If the Neuroglancer source is not a string + Examples: + >>> intensities_array._neuroglancer_source() + 'precomputed://https://mybucket.s3.amazonaws.com/mydata' + Notes: + The Neuroglancer source is the same as the source array + """ return self._source_array._neuroglancer_source() diff --git a/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py index 87281f69f..7ea13385c 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/intensity_array_config.py @@ -6,8 +6,20 @@ @attr.s class IntensitiesArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for turning an Annotated dataset into a - multi class binary classification problem""" + """ + This config class provides the necessary configuration for turning an Annotated dataset into a + multi class binary classification problem. It takes a source array and normalizes the intensities + between 0 and 1. The source array is expected to contain a volume with uint64 voxels and no channel dimension. + + Attributes: + source_array_config (ArrayConfig): The Array from which to pull annotated data + min (float): The minimum intensity in your data + max (float): The maximum intensity in your data + Methods: + to_array: Returns the IntensitiesArray object + Notes: + The source_array_config must be an ArrayConfig object. + """ array_type = IntensitiesArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py index 995f27d05..212d933ac 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array.py @@ -9,9 +9,93 @@ class LogicalOrArray(Array): - """ """ + """ + Array that computes the logical OR of the instances in a list of source arrays. + + Attributes: + name: str + The name of the array + source_array: Array + The source array from which to take the logical OR + Methods: + axes: () -> List[str] + Get the axes of the array + dims: () -> int + Get the number of dimensions of the array + voxel_size: () -> Coordinate + Get the voxel size of the array + roi: () -> Roi + Get the region of interest of the array + writable: () -> bool + Get whether the array is writable + dtype: () -> type + Get the data type of the array + num_channels: () -> int + Get the number of channels in the array + data: () -> np.ndarray + Get the data of the array + attrs: () -> dict + Get the attributes of the array + __getitem__: (roi: Roi) -> np.ndarray + Get the data of the array in the region of interest + _can_neuroglance: () -> bool + Get whether the array can be visualized in neuroglance + _neuroglancer_source: () -> dict + Get the neuroglancer source of the array + _neuroglancer_layer: () -> Tuple[neuroglancer.Layer, dict] + Get the neuroglancer layer of the array + _source_name: () -> str + Get the name of the source array + Notes: + The LogicalOrArray class is used to create a LogicalOrArray. The LogicalOrArray + class is a subclass of the Array class. + """ def __init__(self, array_config): + """ + Create a LogicalOrArray instance from a configuration + Args: + array_config: MergeInstancesArrayConfig + The configuration for the array + Returns: + LogicalOrArray + The LogicalOrArray instance created from the configuration + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.name + 'logical_or' + >>> array.source_array.name + 'mask1' + >>> array.source_array.mask_id + 1 + Notes: + The create_array method is used to create a LogicalOrArray instance from a + configuration. The LogicalOrArray instance is created by taking the logical OR + of the instances in the source arrays. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -19,34 +103,330 @@ def __init__(self, array_config): @property def axes(self): + """ + Get the axes of the array + + Returns: + List[str]: The axes of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.axes + ['x', 'y', 'z'] + Notes: + The axes method is used to get the axes of the array. The axes are the dimensions + of the array. + """ return [x for x in self._source_array.axes if x != "c"] @property def dims(self) -> int: + """ + Get the number of dimensions of the array + + Returns: + int: The number of dimensions of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.dims + 3 + Notes: + The dims method is used to get the number of dimensions of the array. The number + of dimensions is the number of axes of the array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Get the voxel size of the array + + Returns: + Coordinate: The voxel size of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.voxel_size + Coordinate(x=1.0, y=1.0, z=1.0) + Notes: + The voxel_size method is used to get the voxel size of the array. The voxel size + is the size of a voxel in the array. + + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Get the region of interest of the array + + Returns: + Roi: The region of interest of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.roi + Roi(offset=(0, 0, 0), shape=(10, 10, 10)) + Notes: + The roi method is used to get the region of interest of the array. The region of + interest is the shape and offset of the array. + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Get whether the array is writable + + Returns: + bool: Whether the array is writable + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.writable + False + Notes: + The writable method is used to get whether the array is writable. An array is + writable if it can be modified. + """ return False @property def dtype(self): + """ + Get the data type of the array + + Returns: + type: The data type of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.dtype + + Notes: + The dtype method is used to get the data type of the array. The data type is the + type of the data in the array. + """ return np.uint8 @property def num_channels(self): + """ + Get the number of channels in the array + + Returns: + int: The number of channels in the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.num_channels + 1 + Notes: + The num_channels method is used to get the number of channels in the array. The + number of channels is the number of channels in the array. + """ return None @property def data(self): + """ + Get the data of the array + + Returns: + np.ndarray: The data of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.data + array([[[1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + ..., + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1]]], dtype=uint8) + Notes: + The data method is used to get the data of the array. The data is the content of + the array. + + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -54,21 +434,211 @@ def data(self): @property def attrs(self): + """ + Get the attributes of the array + + Returns: + dict: The attributes of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array.attrs + {'name': 'logical_or'} + Notes: + The attrs method is used to get the attributes of the array. The attributes are + the metadata of the array. + """ return self._source_array.attrs def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get the data of the array in the region of interest + + Args: + roi: Roi + The region of interest of the array + Returns: + np.ndarray: The data of the array in the region of interest + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> roi = Roi((0, 0, 0), (10, 10, 10)) + >>> array[roi] + array([[[1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + ..., + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1], + [1, 1, 1, ..., 1, 1, 1]]], dtype=uint8) + Notes: + The __getitem__ method is used to get the data of the array in the region of interest. + The data is the content of the array in the region of interest. + """ mask = self._source_array[roi] if "c" in self._source_array.axes: mask = np.max(mask, axis=self._source_array.axes.index("c")) return mask def _can_neuroglance(self): + """ + Get whether the array can be visualized in neuroglance + + Returns: + bool: Whether the array can be visualized in neuroglance + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._can_neuroglance() + True + Notes: + The _can_neuroglance method is used to get whether the array can be visualized + in neuroglance. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Get the neuroglancer source of the array + + Returns: + dict: The neuroglancer source of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._neuroglancer_source() + {'source': 'precomputed://https://mybucket.storage.googleapis.com/path/to/logical_or'} + Notes: + The _neuroglancer_source method is used to get the neuroglancer source of the array. + The neuroglancer source is the source that is displayed in the neuroglancer viewer. + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Get the neuroglancer layer of the array + + Returns: + Tuple[neuroglancer.Layer, dict]: The neuroglancer layer of the array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._neuroglancer_layer() + (SegmentationLayer(source='precomputed://https://mybucket.storage.googleapis.com/path/to/logical_or'), {'visible': False}) + Notes: + The _neuroglancer_layer method is used to get the neuroglancer layer of the array. + The neuroglancer layer is the layer that is displayed in the neuroglancer viewer. + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -78,4 +648,40 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Get the name of the source array + + Returns: + str: The name of the source array + Raises: + ValueError: If the array is not writable + Examples: + >>> array_config = MergeInstancesArrayConfig( + ... name="logical_or", + ... source_array_configs=[ + ... ArrayConfig( + ... name="mask1", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask1", + ... mask_id=1, + ... ), + ... ), + ... ArrayConfig( + ... name="mask2", + ... array_type=MaskArray, + ... source_array_config=MaskArrayConfig( + ... name="mask2", + ... mask_id=2, + ... ), + ... ), + ... ], + ... ) + >>> array = array_config.create_array() + >>> array._source_name() + 'mask1' + Notes: + The _source_name method is used to get the name of the source array. The name + of the source array is the name of the array that is being modified. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py index d0a211a8a..a22591405 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/logical_or_array_config.py @@ -6,8 +6,17 @@ @attr.s class LogicalOrArrayConfig(ArrayConfig): - """This config class takes a source array and performs a logical or over the channels. - Good for union multiple masks.""" + """ + This config class takes a source array and performs a logical or over the channels. + Good for union multiple masks. + + Attributes: + source_array_config (ArrayConfig): The Array of masks from which to take the union + Methods: + to_array: Returns the LogicalOrArray object + Notes: + The source_array_config must be an ArrayConfig object. + """ array_type = LogicalOrArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py index 944c69b69..4a36efc29 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array.py @@ -9,9 +9,68 @@ class MergeInstancesArray(Array): - """ """ + """ + This array merges multiple source arrays into a single array by summing them. This is useful for merging + instance segmentation arrays into a single array. NeuoGlancer will display each instance as a different color. + + Attributes: + name : str + The name of the array + source_array_configs : List[ArrayConfig] + A list of source arrays to merge + Methods: + __getitem__(roi: Roi) -> np.ndarray + Returns a numpy array with the requested region of interest + _can_neuroglance() -> bool + Returns True if the array can be visualized in neuroglancer + _neuroglancer_source() -> str + Returns the source name for the array in neuroglancer + _neuroglancer_layer() -> Tuple[neuroglancer.SegmentationLayer, Dict[str, Any]] + Returns a neuroglancer layer and its configuration + _source_name() -> str + Returns the source name for the array + Note: + This array is not writable + Source arrays must have the same shape. + + """ def __init__(self, array_config): + """ + Constructor for MergeInstancesArray + + Args: + array_config : MergeInstancesArrayConfig + The configuration for the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + ``` + Note: + This example shows how to create a MergeInstancesArray object + """ self.name = array_config.name self._source_arrays = [ source_config.array_type(source_config) @@ -21,34 +80,317 @@ def __init__(self, array_config): @property def axes(self): + """ + Returns the axes of the array + + Returns: + List[str]: The axes of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + axes = array.axes + ``` + Note: + This example shows how to get the axes of the array + + """ return [x for x in self._source_array.axes if x != "c"] @property def dims(self) -> int: + """ + Returns the number of dimensions of the array + + Returns: + int: The number of dimensions of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + dims = array.dims + ``` + Note: + This example shows how to get the number of dimensions of the array + + + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the array + + Returns: + Coordinate: The voxel size of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + voxel_size = array.voxel_size + ``` + Note: + This example shows how to get the voxel size of the array + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the array + + Returns: + Roi: The region of interest of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + roi = array.roi + ``` + Note: + This example shows how to get the region of interest of the array + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Returns True if the array is writable, False otherwise + + Returns: + bool: True if the array is writable, False otherwise + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + writable = array.writable + ``` + Note: + This example shows how to check if the array is writable + """ return False @property def dtype(self): + """ + Returns the data type of the array + + Returns: + np.dtype: The data type of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + dtype = array.dtype + ``` + Note: + This example shows how to get the data type of the array + """ return np.uint8 @property def num_channels(self): + """ + Returns the number of channels of the array + + Returns: + int: The number of channels of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + num_channels = array.num_channels + ``` + Note: + This example shows how to get the number of channels of the array + """ return None @property def data(self): + """ + Returns the data of the array + + Returns: + np.ndarray: The data of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + data = array.data + ``` + Note: + This example shows how to get the data of the array + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -56,9 +398,83 @@ def data(self): @property def attrs(self): + """ + Returns the attributes of the array + + Returns: + Dict[str, Any]: The attributes of the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + attributes = array.attrs + ``` + Note: + This example shows how to get the attributes of the array + """ return self._source_array.attrs def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns a numpy array with the requested region of interest + + Args: + roi : Roi + The region of interest to get + Returns: + np.ndarray: A numpy array with the requested region of interest + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + roi = Roi((0, 0, 0), (100, 100, 100)) + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + array_data = array[roi] + ``` + Note: + This example shows how to get a numpy array with the requested region of interest + """ arrays = [source_array[roi] for source_array in self._source_arrays] offset = 0 for array in arrays: @@ -67,12 +483,117 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return np.sum(arrays, axis=0) def _can_neuroglance(self): + """ + Returns True if the array can be visualized in neuroglancer, False otherwise + + Returns: + bool: True if the array can be visualized in neuroglancer, False otherwise + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + can_neuroglance = array._can_neuroglance() + ``` + Note: + This example shows how to check if the array can be visualized in neuroglancer + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Returns the source name for the array in neuroglancer + + Returns: + str: The source name for the array in neuroglancer + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + source = array._neuroglancer_source() + ``` + Note: + This example shows how to get the source name for the array in neuroglancer + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Returns a neuroglancer layer and its configuration + + Returns: + Tuple[neuroglancer.SegmentationLayer, Dict[str, Any]]: A neuroglancer layer and its configuration + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + layer, kwargs = array._neuroglancer_layer() + ``` + Note: + This example shows how to get a neuroglancer layer and its configuration + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -82,4 +603,39 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Returns the source name for the array + + Returns: + str: The source name for the array + Raises: + ValueError: If the source arrays have different shapes + Example: + ```python + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArray + from dacapo.experiments.datasplits.datasets.arrays import MergeInstancesArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + from dacapo.experiments.datasplits.datasets.arrays import ArrayType + from funlib.geometry import Coordinate, Roi + array_config = MergeInstancesArrayConfig( + name="array", + source_array_configs=[ + ArrayConfig( + name="array1", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array1.h5", + ), + ArrayConfig( + name="array2", + array_type=ArrayType.INSTANCE_SEGMENTATION, + path="path/to/array2.h5", + ), + ], + ) + array = MergeInstancesArray(array_config) + source_name = array._source_name() + ``` + Note: + This example shows how to get the source name for the array + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py index 31c6e5acd..d7a523215 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/merge_instances_array_config.py @@ -8,6 +8,21 @@ @attr.s class MergeInstancesArrayConfig(ArrayConfig): + """ + Configuration for an array that merges instances from multiple arrays + into a single array. The instances are merged by taking the union of the + instances in the source arrays. + + Attributes: + source_array_configs: List[ArrayConfig] + The Array of masks from which to take the union + Methods: + create_array: () -> MergeInstancesArray + Create a MergeInstancesArray instance from the configuration + Notes: + The MergeInstancesArrayConfig class is used to create a MergeInstancesArray + """ + array_type = MergeInstancesArray source_array_configs: List[ArrayConfig] = attr.ib( diff --git a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py index 3d1a86b93..aaf59cb69 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py +++ b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask.py @@ -20,9 +20,32 @@ class MissingAnnotationsMask(Array): See package fibsem_tools for appropriate metadata format for indicating presence of labels in your ground truth. "https://github.com/janelia-cosem/fibsem-tools" + + Attributes: + array_config: A BinarizeArrayConfig object + Methods: + __getitem__(roi: Roi) -> np.ndarray: Returns a binary mask of the + annotations that are present but not annotated. + Note: + This class is not meant to be used directly. It is used by the + BinarizeArray class to mask out annotations that are present but + not annotated. """ def __init__(self, array_config): + """ + Initializes the MissingAnnotationsMask class + + Args: + array_config (BinarizeArrayConfig): A BinarizeArrayConfig object + Raises: + AssertionError: If the source array has channels + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> missing_annotations_mask = MissingAnnotationsMask(MissingAnnotationsMaskConfig(source_array, groupings)) + Notes: + This is a helper function for the BinarizeArray class + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -36,34 +59,152 @@ def __init__(self, array_config): @property def axes(self): + """ + Returns the axes of the source array + + Returns: + list: Axes of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.axes + ['x', 'y', 'z'] + Notes: + This is a helper function for the BinarizeArray class + """ return ["c"] + self._source_array.axes @property def dims(self) -> int: + """ + Returns the number of dimensions of the source array + + Returns: + int: Number of dimensions of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.dims + 3 + Notes: + This is a helper function for the BinarizeArray class + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the source array + + Returns: + Coordinate: Voxel size of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.voxel_size + Coordinate(x=4, y=4, z=40) + Notes: + This is a helper function for the BinarizeArray class + + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + Returns the region of interest of the source array + + Returns: + Roi: Region of interest of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.roi + Roi(offset=(0, 0, 0), shape=(100, 100, 100)) + Notes: + This is a helper function for the BinarizeArray class + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Returns whether the source array is writable + + Returns: + bool: Whether the source array is writable + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.writable + False + Notes: + This is a helper function for the BinarizeArray class + + """ return False @property def dtype(self): + """ + Returns the data type of the source array + + Returns: + np.dtype: Data type of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.dtype + np.uint8 + Notes: + This is a helper function for the BinarizeArray class + + """ return np.uint8 @property def num_channels(self) -> int: + """ + Returns the number of channels + + Returns: + int: Number of channels + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.num_channels + 2 + Notes: + This is a helper function for the BinarizeArray class + + + """ return len(self._groupings) @property def data(self): + """ + Returns the data of the source array + + Returns: + np.ndarray: Data of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.data + np.ndarray(...) + Notes: + This is a helper function for the BinarizeArray class + + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -71,13 +212,62 @@ def data(self): @property def attrs(self): + """ + Returns the attributes of the source array + + Returns: + dict: Attributes of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.attrs + {'name': 'source_array', 'resolution': [4, 4, 40]} + Notes: + This is a helper function for the BinarizeArray class + """ return self._source_array.attrs @property def channels(self): + """ + Returns the names of the channels + + Returns: + Generator[str]: Names of the channels + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array.channels + Generator['channel1', 'channel2', ...] + Notes: + This is a helper function for the BinarizeArray class + """ return (name for name, _ in self._groupings) def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns a binary mask of the annotations that are present but not annotated. + + Args: + roi (Roi): Region of interest to get the mask for + Returns: + np.ndarray: Binary mask of the annotations that are present but not annotated + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> missing_annotations_mask = MissingAnnotationsMask(MissingAnnotationsMaskConfig(source_array, groupings)) + >>> roi = Roi(...) + >>> missing_annotations_mask[roi] + np.ndarray(...) + Notes: + - This is a helper function for the BinarizeArray class + - Number of channels in the mask is equal to the number of groupings + - Nuclues is a special case where we mask out the whole channel if any of the + sub-organelles are present but not annotated + """ labels = self._source_array[roi] grouped = np.ones((len(self._groupings), *labels.shape), dtype=bool) grouped[:] = labels > 0 @@ -93,32 +283,63 @@ def __getitem__(self, roi: Roi) -> np.ndarray: ) for i, (_, ids) in enumerate(self._groupings): if any([id in present_not_annotated for id in ids]): - # specially handle id 37 - # TODO: find more general solution - if 37 in ids and 37 not in present_not_annotated: - # 37 marks any kind of nucleus voxel. There many be nucleus sub - # organelles marked as "present not annotated", but we can safely - # train any channel that includes those organelles as long as - # 37 is annotated. - pass - else: - # mask out this whole channel - grouped[i] = 0 - - # for id in ids: - # grouped[i][labels == id] = 0 + grouped[i] = 0 except KeyError: pass return grouped def _can_neuroglance(self): + """ + Returns whether the array can be visualized in neuroglancer + + Returns: + bool: Whether the array can be visualized in neuroglancer + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._can_neuroglance() + True + Notes: + This is a helper function for the neuroglancer layer + + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Returns a neuroglancer source for the array + + Returns: + neuroglancer.LocalVolume: Neuroglancer source for the array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._neuroglancer_source() + neuroglancer.LocalVolume(...) + Notes: + This is a helper function for the neuroglancer layer + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Returns a neuroglancer Segmentation layer for the array + + Returns: + neuroglancer.SegmentationLayer: Segmentation layer for the array + dict: Keyword arguments for the layer + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._neuroglancer_layer() + (neuroglancer.SegmentationLayer, dict) + Notes: + This is a helper function for the neuroglancer layer + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -128,4 +349,18 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Returns the name of the source array + + Returns: + str: Name of the source array + Raises: + ValueError: If the source array does not have a name + Examples: + >>> source_array = ZarrArray(ZarrArrayConfig(...)) + >>> source_array._source_name() + 'source_array' + Notes: + This is a helper function for the neuroglancer layer name + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py index 6fae4d51d..08faece08 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/missing_annotations_mask_config.py @@ -8,8 +8,20 @@ @attr.s class MissingAnnotationsMaskConfig(ArrayConfig): - """This config class provides the necessary configuration for turning an Annotated dataset into a - multi class binary classification problem""" + """ + This config class provides the necessary configuration for turning an Annotated dataset into a + multi class binary classification problem + + Attributes: + source_array_config : ArrayConfig + The Array from which to pull annotated data. Is expected to contain a volume with uint64 voxels and no channel dimension + groupings : List[Tuple[str, List[int]]] + List of id groups with a symantic name. Each id group is a List of ids. + Group i found in groupings[i] will be binarized and placed in channel i. + Note: + The output array will have a channel dimension equal to the number of groups. + Each channel will be a binary mask of the ids in the groupings list. + """ array_type = MissingAnnotationsMask diff --git a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py index 5f2bc0483..63c73e228 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/numpy_array.py @@ -9,7 +9,21 @@ class NumpyArray(Array): - """This is just a wrapper for a numpy array to make it fit the DaCapo Array interface.""" + """ + This is just a wrapper for a numpy array to make it fit the DaCapo Array interface. + + Attributes: + data: The numpy array. + dtype: The data type of the numpy array. + roi: The region of interest of the numpy array. + voxel_size: The voxel size of the numpy array. + axes: The axes of the numpy array. + Methods: + from_gp_array: Create a NumpyArray from a Gunpowder Array. + from_np_array: Create a NumpyArray from a numpy array. + Note: + This class is a subclass of Array. + """ _data: np.ndarray _dtype: np.dtype @@ -18,14 +32,73 @@ class NumpyArray(Array): _axes: List[str] def __init__(self, array_config): + """ + Create a NumpyArray from an array config. + + Args: + array_config: The array config. + Returns: + NumpyArray: The NumpyArray. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = NumpyArray(OnesArrayConfig(source_array_config=ArrayConfig())) + >>> array.data + array([[[1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.]], + + [[1., 1., 1., 1.], + [1., 1., 1., 1.], + [1., 1., 1., 1.]]]) + Note: + This method creates a NumpyArray from an array config. + """ raise RuntimeError("Numpy Array cannot be built from a config file") @property def attrs(self): + """ + Returns the attributes of the array. + + Returns: + dict: The attributes of the array. + Raises: + ValueError: If the array does not have attributes. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.attrs + {} + Note: + This method is a property. It returns the attributes of the array. + """ return dict() @classmethod def from_gp_array(cls, array: gp.Array): + """ + Create a NumpyArray from a Gunpowder Array. + + Args: + array (gp.Array): The Gunpowder Array. + Returns: + NumpyArray: The NumpyArray. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = gp.Array(data=np.zeros((2, 3, 4)), spec=gp.ArraySpec(roi=Roi((0, 0, 0), (2, 3, 4)), voxel_size=Coordinate((1, 1, 1)))) + >>> array = NumpyArray.from_gp_array(array) + >>> array.data + array([[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]], + + [[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]) + Note: + This method creates a NumpyArray from a Gunpowder Array. + """ instance = cls.__new__(cls) instance._data = array.data instance._dtype = array.data.dtype @@ -45,6 +118,32 @@ def from_gp_array(cls, array: gp.Array): @classmethod def from_np_array(cls, array: np.ndarray, roi, voxel_size, axes): + """ + Create a NumpyArray from a numpy array. + + Args: + array (np.ndarray): The numpy array. + roi (Roi): The region of interest of the array. + voxel_size (Coordinate): The voxel size of the array. + axes (List[str]): The axes of the array. + Returns: + NumpyArray: The NumpyArray. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.data + array([[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]], + + [[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]) + Note: + This method creates a NumpyArray from a numpy array. + + """ instance = cls.__new__(cls) instance._data = array instance._dtype = array.dtype @@ -55,34 +154,151 @@ def from_np_array(cls, array: np.ndarray, roi, voxel_size, axes): @property def axes(self): + """ + Returns the axes of the array. + + Returns: + List[str]: The axes of the array. + Raises: + ValueError: If the array does not have axes. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.axes + ['z', 'y', 'x'] + Note: + This method is a property. It returns the axes of the array. + """ return self._axes @property def dims(self): + """ + Returns the number of dimensions of the array. + + Returns: + int: The number of dimensions of the array. + Raises: + ValueError: If the array does not have a dimension. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.dims + 3 + Note: + This method is a property. It returns the number of dimensions of the array. + """ return self._roi.dims @property def voxel_size(self): + """ + Returns the voxel size of the array. + + Returns: + Coordinate: The voxel size of the array. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.voxel_size + Coordinate((1, 1, 1)) + Note: + This method is a property. It returns the voxel size of the array. + """ return self._voxel_size @property def roi(self): + """ + Returns the region of interest of the array. + + Returns: + Roi: The region of interest of the array. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.roi + Roi((0, 0, 0), (2, 3, 4)) + Note: + This method is a property. It returns the region of interest of the array. + """ return self._roi @property def writable(self) -> bool: + """ + Returns whether the array is writable. + + Returns: + bool: Whether the array is writable. + Raises: + ValueError: If the array is not writable. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.writable + True + Note: + This method is a property. It returns whether the array is writable. + """ return True @property def data(self): + """ + Returns the numpy array. + + Returns: + np.ndarray: The numpy array. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.data + array([[[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]], + + [[0., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]) + Note: + This method is a property. It returns the numpy array. + """ return self._data @property def dtype(self): + """ + Returns the data type of the array. + + Returns: + np.dtype: The data type of the array. + Raises: + ValueError: If the array does not have a data type. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.dtype + dtype('float64') + Note: + This method is a property. It returns the data type of the array. + """ return self.data.dtype @property def num_channels(self): + """ + Returns the number of channels in the array. + + Returns: + int: The number of channels in the array. + Raises: + ValueError: If the array does not have a channel dimension. + Examples: + >>> array = NumpyArray.from_np_array(np.zeros((1, 2, 3, 4)), Roi((0, 0, 0), (1, 2, 3)), Coordinate((1, 1, 1)), ["b", "c", "z", "y", "x"]) + >>> array.num_channels + 1 + >>> array = NumpyArray.from_np_array(np.zeros((2, 3, 4)), Roi((0, 0, 0), (2, 3, 4)), Coordinate((1, 1, 1)), ["z", "y", "x"]) + >>> array.num_channels + Traceback (most recent call last): + ... + ValueError: Array does not have a channel dimension. + Note: + This method is a property. It returns the number of channels in the array. + """ try: channel_dim = self.axes.index("c") return self.data.shape[channel_dim] diff --git a/dacapo/experiments/datasplits/datasets/arrays/ones_array.py b/dacapo/experiments/datasplits/datasets/arrays/ones_array.py index 4fe0aaca1..6fd5c4faf 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/ones_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/ones_array.py @@ -6,59 +6,400 @@ class OnesArray(Array): - """This is a wrapper around another `source_array` that simply provides ones - with the same metadata as the `source_array`.""" + """ + This is a wrapper around another `source_array` that simply provides ones + with the same metadata as the `source_array`. + + This is useful for creating a mask array that is the same size as the + original array, but with all values set to 1. + + Attributes: + source_array: The source array that this array is based on. + Methods: + like: Create a new OnesArray with the same metadata as another array. + attrs: Get the attributes of the array. + axes: Get the axes of the array. + dims: Get the dimensions of the array. + voxel_size: Get the voxel size of the array. + roi: Get the region of interest of the array. + writable: Check if the array is writable. + data: Get the data of the array. + dtype: Get the data type of the array. + num_channels: Get the number of channels of the array. + __getitem__: Get a subarray of the array. + Note: + This class is not meant to be instantiated directly. Instead, use the + `like` method to create a new OnesArray with the same metadata as + another array. + """ def __init__(self, array_config): + """ + Initialize the OnesArray with the given array configuration. + + Args: + array_config: The configuration of the source array. + Raises: + RuntimeError: If the source array is not specified in the + configuration. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import ArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> source_array_config = ArrayConfig(source_array) + >>> ones_array = OnesArray(source_array_config) + >>> ones_array.source_array + NumpyArray(data=array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]), voxel_size=(1.0, 1.0, 1.0), roi=Roi((0, 0, 0), (10, 10, 10)), num_channels=1) + Notes: + This class is not meant to be instantiated directly. Instead, use the + `like` method to create a new OnesArray with the same metadata as + another array. + """ self._source_array = array_config.source_array_config.array_type( array_config.source_array_config ) @classmethod def like(cls, array: Array): + """ + Create a new OnesArray with the same metadata as another array. + + Args: + array: The source array. + Returns: + The new OnesArray with the same metadata as the source array. + Raises: + RuntimeError: If the source array is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray.like(source_array) + >>> ones_array.source_array + NumpyArray(data=array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]), voxel_size=(1.0, 1.0, 1.0), roi=Roi((0, 0, 0), (10, 10, 10)), num_channels=1) + Notes: + This class is not meant to be instantiated directly. Instead, use the + `like` method to create a new OnesArray with the same metadata as + another array. + + """ instance = cls.__new__(cls) instance._source_array = array return instance @property def attrs(self): + """ + Get the attributes of the array. + + Returns: + An empty dictionary. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.attrs + {} + Notes: + This method is used to get the attributes of the array. The attributes + are stored as key-value pairs in a dictionary. This method returns an + empty dictionary because the OnesArray does not have any attributes. + """ return dict() @property def source_array(self) -> Array: + """ + Get the source array that this array is based on. + + Returns: + The source array. + Raises: + RuntimeError: If the source array is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.source_array + NumpyArray(data=array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]), voxel_size=(1.0, 1.0, 1.0), roi=Roi((0, 0, 0), (10, 10, 10)), num_channels=1) + Notes: + This method is used to get the source array that this array is based on. + The source array is the array that the OnesArray is created from. This + method returns the source array that was specified when the OnesArray + was created. + """ return self._source_array @property def axes(self): + """ + Get the axes of the array. + + Returns: + The axes of the array. + Raises: + RuntimeError: If the axes are not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.axes + 'zyx' + Notes: + This method is used to get the axes of the array. The axes are the + order of the dimensions of the array. This method returns the axes of + the array that was specified when the OnesArray was created. + """ return self.source_array.axes @property def dims(self): + """ + Get the dimensions of the array. + + Returns: + The dimensions of the array. + Raises: + RuntimeError: If the dimensions are not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.dims + (10, 10, 10) + Notes: + This method is used to get the dimensions of the array. The dimensions + are the size of the array along each axis. This method returns the + dimensions of the array that was specified when the OnesArray was created. + """ return self.source_array.dims @property def voxel_size(self): + """ + Get the voxel size of the array. + + Returns: + The voxel size of the array. + Raises: + RuntimeError: If the voxel size is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.voxel_size + (1.0, 1.0, 1.0) + Notes: + This method is used to get the voxel size of the array. The voxel size + is the size of each voxel in the array. This method returns the voxel + size of the array that was specified when the OnesArray was created. + """ return self.source_array.voxel_size @property def roi(self): + """ + Get the region of interest of the array. + + Returns: + The region of interest of the array. + Raises: + RuntimeError: If the region of interest is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.roi + Roi((0, 0, 0), (10, 10, 10)) + Notes: + This method is used to get the region of interest of the array. The + region of interest is the region of the array that contains the data. + This method returns the region of interest of the array that was specified + when the OnesArray was created. + """ return self.source_array.roi @property def writable(self) -> bool: + """ + Check if the array is writable. + + Returns: + False. + Raises: + RuntimeError: If the writability of the array is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.writable + False + Notes: + This method is used to check if the array is writable. An array is + writable if it can be modified in place. This method returns False + because the OnesArray is read-only and cannot be modified. + """ return False @property def data(self): + """ + Get the data of the array. + + Returns: + The data of the array. + Raises: + RuntimeError: If the data is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.data + array([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]) + Notes: + This method is used to get the data of the array. The data is the + values that are stored in the array. This method returns a subarray + of the array with all values set to 1. + """ raise RuntimeError("Cannot get writable version of this data!") @property def dtype(self): + """ + Get the data type of the array. + + Returns: + The data type of the array. + Raises: + RuntimeError: If the data type is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.dtype + + Notes: + This method is used to get the data type of the array. The data type + is the type of the values that are stored in the array. This method + returns the data type of the array that was specified when the OnesArray + was created. + """ return bool @property def num_channels(self): + """ + Get the number of channels of the array. + + Returns: + The number of channels of the array. + Raises: + RuntimeError: If the number of channels is not specified. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> ones_array.num_channels + 1 + Notes: + This method is used to get the number of channels of the array. The + number of channels is the number of values that are stored at each + voxel in the array. This method returns the number of channels of the + array that was specified when the OnesArray was created. + """ return self.source_array.num_channels def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get a subarray of the array. + + Args: + roi: The region of interest. + Returns: + A subarray of the array with all values set to 1. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import OnesArray + >>> from dacapo.experiments.datasplits.datasets.arrays import NumpyArray + >>> from funlib.geometry import Roi + >>> import numpy as np + >>> source_array = NumpyArray(np.zeros((10, 10, 10))) + >>> ones_array = OnesArray(source_array) + >>> roi = Roi((0, 0, 0), (10, 10, 10)) + >>> ones_array[roi] + array([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]) + Notes: + This method is used to get a subarray of the array. The subarray is + specified by the region of interest. This method returns a subarray + of the array with all values set to 1. + """ return np.ones_like(self.source_array.__getitem__(roi), dtype=bool) diff --git a/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py index 649aaa390..152b357c2 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/ones_array_config.py @@ -6,7 +6,20 @@ @attr.s class OnesArrayConfig(ArrayConfig): - """This array read data from the source array and then return a np.ones_like() version.""" + """ + This array read data from the source array and then return a np.ones_like() version. + + This is useful for creating a mask array from a source array. For example, if you have a + 2D array of data and you want to create a mask array that is the same shape as the data + array, you can use this class to create the mask array. + + Attributes: + source_array_config: The source array that you want to copy and fill with ones. + Methods: + create_array: Create the array. + Note: + This class is a subclass of ArrayConfig. + """ array_type = OnesArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py index f74d5bf1d..5c60a5df4 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array.py @@ -8,9 +8,76 @@ class ResampledArray(Array): - """This is a zarr array""" + """ + This is a zarr array that is a resampled version of another array. + + Resampling is done by rescaling the source array with the given + upsample and downsample factors. The voxel size of the resampled array + is the voxel size of the source array divided by the downsample factor + and multiplied by the upsample factor. + + Attributes: + name: str + The name of the array + source_array: Array + The source array + upsample: Coordinate + The upsample factor for each dimension + downsample: Coordinate + The downsample factor for each dimension + interp_order: int + The order of the interpolation used for resampling + Methods: + attrs: Dict + Returns the attributes of the source array + axes: str + Returns the axes of the source array + dims: int + Returns the number of dimensions of the source array + voxel_size: Coordinate + Returns the voxel size of the resampled array + roi: Roi + Returns the region of interest of the resampled array + writable: bool + Returns whether the resampled array is writable + dtype: np.dtype + Returns the data type of the resampled array + num_channels: int + Returns the number of channels of the resampled array + data: np.ndarray + Returns the data of the resampled array + scale: Tuple[float] + Returns the scale of the resampled array + __getitem__(roi: Roi) -> np.ndarray + Returns the data of the resampled array within the given region of interest + _can_neuroglance() -> bool + Returns whether the source array can be visualized with neuroglance + _neuroglancer_layer() -> Dict + Returns the neuroglancer layer of the source array + _neuroglancer_source() -> Dict + Returns the neuroglancer source of the source array + _source_name() -> str + Returns the name of the source array + Note: + This class is a subclass of Array. + + + """ def __init__(self, array_config): + """ + Constructor of the ResampledArray class. + + Args: + array_config: ArrayConfig + The configuration of the array + Raises: + AssertionError: If the voxel size of the resampled array is not equal to the voxel size of the source array divided by the downsample factor and multiplied by the upsample factor + Examples: + >>> resampled_array = ResampledArray(array_config) + Note: + This constructor resamples the source array with the given upsample and downsample factors. + """ self.name = array_config.name self._source_array = array_config.source_array_config.array_type( array_config.source_array_config @@ -26,38 +93,149 @@ def __init__(self, array_config): @property def attrs(self): + """ + Returns the attributes of the source array. + + Returns: + Dict: The attributes of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.attrs + Note: + This method returns the attributes of the source array. + + """ return self._source_array.attrs @property def axes(self): + """ + Returns the axes of the source array. + + Returns: + str: The axes of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.axes + Note: + This method returns the axes of the source array. + """ return self._source_array.axes @property def dims(self) -> int: + """ + Returns the number of dimensions of the source array. + + Returns: + int: The number of dimensions of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.dims + Note: + This method returns the number of dimensions of the source array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the resampled array. + + Returns: + Coordinate: The voxel size of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.voxel_size + Note: + This method returns the voxel size of the resampled array. + """ return (self._source_array.voxel_size * self.downsample) / self.upsample @property def roi(self) -> Roi: + """ + Returns the region of interest of the resampled array. + + Returns: + Roi: The region of interest of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.roi + Note: + This method returns the region of interest of the resampled array. + + """ return self._source_array.roi.snap_to_grid(self.voxel_size, mode="shrink") @property def writable(self) -> bool: + """ + Returns whether the resampled array is writable. + + Returns: + bool: True if the resampled array is writable, False otherwise + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.writable + Note: + This method returns whether the resampled array is writable. + + """ return False @property def dtype(self): + """ + Returns the data type of the resampled array. + + Returns: + np.dtype: The data type of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.dtype + Note: + This method returns the data type of the resampled array. + """ return self._source_array.dtype @property def num_channels(self) -> int: + """ + Returns the number of channels of the resampled array. + + Returns: + int: The number of channels of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.num_channels + Note: + This method returns the number of channels of the resampled array. + """ return self._source_array.num_channels @property def data(self): + """ + Returns the data of the resampled array. + + Returns: + np.ndarray: The data of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.data + Note: + This method returns the data of the resampled array. + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -65,6 +243,19 @@ def data(self): @property def scale(self): + """ + Returns the scale of the resampled array. + + Returns: + Tuple[float]: The scale of the resampled array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array.scale + Note: + This method returns the scale of the resampled array. + + """ spatial_scales = tuple(u / d for d, u in zip(self.downsample, self.upsample)) if "c" in self.axes: scales = list(spatial_scales) @@ -74,6 +265,21 @@ def scale(self): return spatial_scales def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the data of the resampled array within the given region of interest. + + Args: + roi: Roi + The region of interest + Returns: + np.ndarray: The data of the resampled array within the given region of interest + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array[roi] + Note: + This method returns the data of the resampled array within the given region of interest. + """ snapped_roi = roi.snap_to_grid(self._source_array.voxel_size, mode="grow") resampled_array = funlib.persistence.Array( rescale( @@ -88,13 +294,61 @@ def __getitem__(self, roi: Roi) -> np.ndarray: return resampled_array.to_ndarray(roi) def _can_neuroglance(self): + """ + Returns whether the source array can be visualized with neuroglance. + + Returns: + bool: True if the source array can be visualized with neuroglance, False otherwise + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._can_neuroglance() + Note: + This method returns whether the source array can be visualized with neuroglance. + """ return self._source_array._can_neuroglance() def _neuroglancer_layer(self): + """ + Returns the neuroglancer layer of the source array. + + Returns: + Dict: The neuroglancer layer of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._neuroglancer_layer() + Note: + This method returns the neuroglancer layer of the source array. + """ return self._source_array._neuroglancer_layer() def _neuroglancer_source(self): + """ + Returns the neuroglancer source of the source array. + + Returns: + Dict: The neuroglancer source of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._neuroglancer_source() + Note: + This method returns the neuroglancer source of the source array. + """ return self._source_array._neuroglancer_source() def _source_name(self): + """ + Returns the name of the source array. + + Returns: + str: The name of the source array + Raises: + ValueError: If the resampled array is not writable + Examples: + >>> resampled_array._source_name() + Note: + This method returns the name of the source array. + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py index e080b8304..c4c5a1c54 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/resampled_array_config.py @@ -8,7 +8,20 @@ @attr.s class ResampledArrayConfig(ArrayConfig): - """This array will up or down sample an array into the desired voxel size.""" + """ + A configuration for a ResampledArray. This array will up or down sample an array into the desired voxel size. + + Attributes: + source_array_config (ArrayConfig): The Array that you want to upsample or downsample. + upsample (Coordinate): The amount by which to upsample! + downsample (Coordinate): The amount by which to downsample! + interp_order (bool): The order of the interpolation! + Methods: + create_array: Creates a ResampledArray from the configuration. + Note: + This class is meant to be used with the ArrayDataset class. + + """ array_type = ResampledArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/sum_array.py b/dacapo/experiments/datasplits/datasets/arrays/sum_array.py index 845b69810..ce1dcd087 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/sum_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/sum_array.py @@ -9,9 +9,56 @@ class SumArray(Array): - """ """ + """ + This class provides a sum array. This array is a virtual array that is created by summing + multiple source arrays. The source arrays must have the same shape and ROI. + + Attributes: + name: str + The name of the array. + _source_arrays: List[Array] + The source arrays to sum. + _source_array: Array + The first source array. + Methods: + __getitem__(roi: Roi) -> np.ndarray + Get the data for the given region of interest. + _can_neuroglance() -> bool + Check if neuroglance can be used. + _neuroglancer_source() -> Dict + Return the source for neuroglance. + _neuroglancer_layer() -> Tuple[neuroglancer.SegmentationLayer, Dict] + Return the neuroglancer layer. + _source_name() -> str + Return the source name. + Note: + This class is a subclass of Array. + """ def __init__(self, array_config): + """ + Initialize the SumArray. + + Args: + array_config: SumArrayConfig + The configuration for the sum array. + Returns: + SumArray: The sum array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays.sum_array import SumArray + >>> from dacapo.experiments.datasplits.datasets.arrays.sum_array_config import SumArrayConfig + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array import TiffArray + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array_config import TiffArrayConfig + >>> from funlib.geometry import Coordinate + >>> from pathlib import Path + >>> sum_array = SumArray(SumArrayConfig(name="sum", source_array_configs=[TiffArrayConfig(file_name=Path("data.tiff"), offset=Coordinate([0, 0, 0]), voxel_size=Coordinate([1, 1, 1]), axes=["x", "y", "z"])])) + Note: + This class is a subclass of Array. + + """ self.name = array_config.name self._source_arrays = [ source_config.array_type(source_config) @@ -21,34 +68,163 @@ def __init__(self, array_config): @property def axes(self): + """ + The axes of the array. + + Returns: + List[str]: The axes of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.axes + ['x', 'y', 'z'] + Note: + This class is a subclass of Array. + """ return [x for x in self._source_array.axes if x != "c"] @property def dims(self) -> int: + """ + The number of dimensions of the array. + + Returns: + int: The number of dimensions of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.dims + 3 + Note: + This class is a subclass of Array. + """ return self._source_array.dims @property def voxel_size(self) -> Coordinate: + """ + The size of each voxel in each dimension. + + Returns: + Coordinate: The size of each voxel in each dimension. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.voxel_size + Coordinate([1, 1, 1]) + Note: + This class is a subclass of Array. + """ return self._source_array.voxel_size @property def roi(self) -> Roi: + """ + The region of interest of the array. + + Args: + roi: Roi + The region of interest. + Returns: + Roi: The region of interest. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.roi + Roi(Coordinate([0, 0, 0]), Coordinate([100, 100, 100])) + Note: + This class is a subclass of Array. + """ return self._source_array.roi @property def writable(self) -> bool: + """ + Check if the array is writable. + + Args: + writable: bool + Check if the array is writable. + Returns: + bool: True if the array is writable, otherwise False. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.writable + False + Note: + This class is a subclass of Array. + """ return False @property def dtype(self): + """ + The data type of the array. + + Args: + dtype: np.uint8 + The data type of the array. + Returns: + np.uint8: The data type of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.dtype + np.uint8 + Note: + This class is a subclass of Array. + + """ return np.uint8 @property def num_channels(self): + """ + The number of channels in the array. + + Args: + num_channels: Optional[int] + The number of channels in the array. + Returns: + Optional[int]: The number of channels in the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.num_channels + None + Note: + This class is a subclass of Array. + + """ return None @property def data(self): + """ + Get the data of the array. + + Args: + data: np.ndarray + The data of the array. + Returns: + np.ndarray: The data of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.data + np.array([[[0, 0], [0, 0]], [[0, 0], [0, 0]]]) + Note: + This class is a subclass of Array. + """ raise ValueError( "Cannot get a writable view of this array because it is a virtual " "array created by modifying another array on demand." @@ -56,20 +232,107 @@ def data(self): @property def attrs(self): + """ + Return the attributes of the array. + + Args: + attrs: Dict + The attributes of the array. + Returns: + Dict: The attributes of the array. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array.attrs + {} + Note: + This class is a subclass of Array. + """ return self._source_array.attrs def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Get the data for the given region of interest. + + Args: + roi: Roi + The region of interest. + Returns: + np.ndarray: The data for the given region of interest. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array[roi] + np.array([[[0, 0], [0, 0]], [[0, 0], [0, 0]]]) + Note: + This class is a subclass of Array. + """ return np.sum( [source_array[roi] for source_array in self._source_arrays], axis=0 ) def _can_neuroglance(self): + """ + Check if neuroglance can be used. + + Args: + can_neuroglance: bool + Check if neuroglance can be used. + Returns: + bool: True if neuroglance can be used, otherwise False. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._can_neuroglance() + False + Note: + This class is a subclass of Array. + """ return self._source_array._can_neuroglance() def _neuroglancer_source(self): + """ + Return the source for neuroglance. + + Args: + source: Dict + The source for neuroglance. + Returns: + Dict: The source for neuroglance. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._neuroglancer_source() + {'source': 'precomputed://https://mybucket/segmentation', 'type': 'segmentation', 'voxel_size': [1, 1, 1]} + Note: + This class is a subclass of Array. + + """ return self._source_array._neuroglancer_source() def _neuroglancer_layer(self): + """ + Return the neuroglancer layer. + + Args: + layer: Tuple[neuroglancer.SegmentationLayer, Dict] + The neuroglancer layer. + Returns: + Tuple[neuroglancer.SegmentationLayer, Dict]: The neuroglancer layer. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._neuroglancer_layer() + (SegmentationLayer(source={'source': 'precomputed://https://mybucket/segmentation', 'type': 'segmentation', 'voxel_size': [1, 1, 1]}, visible=False), {}) + Note: + This class is a subclass of Array. + + """ # Generates an Segmentation layer layer = neuroglancer.SegmentationLayer(source=self._neuroglancer_source()) @@ -79,4 +342,22 @@ def _neuroglancer_layer(self): return layer, kwargs def _source_name(self): + """ + Return the source name. + + Args: + source_name: str + The source name. + Returns: + str: The source name. + Raises: + ValueError: + Cannot get a writable view of this array because it is a virtual array created by modifying another array on demand. + Examples: + >>> sum_array._source_name() + 'data.tiff' + Note: + This class is a subclass of Array. + + """ return self._source_array._source_name() diff --git a/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py index 4cc12ddd7..0c2912140 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/sum_array_config.py @@ -8,6 +8,17 @@ @attr.s class SumArrayConfig(ArrayConfig): + """ + This config class provides the necessary configuration for a sum + array. + + Attributes: + source_array_configs: List[ArrayConfig] + The Array of masks from which to take the union + Note: + This class is a subclass of ArrayConfig. + """ + array_type = SumArray source_array_configs: List[ArrayConfig] = attr.ib( diff --git a/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py b/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py index ccdf50376..34e582b4e 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/tiff_array.py @@ -6,14 +6,32 @@ import tifffile import logging -from pathlib import Path +from upath import UPath as Path from typing import List, Optional logger = logging.getLogger(__name__) class TiffArray(Array): - """This is a tiff array""" + """ + This class provides the necessary configuration for a tiff array. + + Attributes: + _offset: Coordinate + The offset of the array. + _file_name: Path + The file name of the tiff. + _voxel_size: Coordinate + The voxel size of the array. + _axes: List[str] + The axes of the array. + Methods: + attrs() -> Dict + Return the attributes of the tiff. + Note: + This class is a subclass of Array. + + """ _offset: Coordinate _file_name: Path @@ -21,6 +39,24 @@ class TiffArray(Array): _axes: List[str] def __init__(self, array_config): + """ + Initialize the TiffArray. + + Args: + array_config: TiffArrayConfig + The configuration for the tiff array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array import TiffArray + >>> from dacapo.experiments.datasplits.datasets.arrays.tiff_array_config import TiffArrayConfig + >>> from funlib.geometry import Coordinate + >>> from pathlib import Path + >>> tiff_array = TiffArray(TiffArrayConfig(file_name=Path("data.tiff"), offset=Coordinate([0, 0, 0]), voxel_size=Coordinate([1, 1, 1]), axes=["x", "y", "z"])) + Note: + This class is a subclass of Array. + """ super().__init__() self._file_name = array_config.file_name @@ -30,20 +66,76 @@ def __init__(self, array_config): @property def attrs(self): + """ + Return the attributes of the tiff. + + Returns: + Dict: The attributes of the tiff. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.attrs + {'axes': ['x', 'y', 'z'], 'offset': [0, 0, 0], 'voxel_size': [1, 1, 1]} + Note: + Tiffs have tons of different locations for metadata. + """ raise NotImplementedError( "Tiffs have tons of different locations for metadata." ) @property def axes(self) -> List[str]: + """ + Return the axes of the array. + + Returns: + List[str]: The axes of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.axes + ['x', 'y', 'z'] + Note: + Tiffs have tons of different locations for metadata. + """ return self._axes @property def dims(self) -> int: + """ + Return the number of dimensions of the array. + + Returns: + int: The number of dimensions of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.dims + 3 + Note: + Tiffs have tons of different locations for metadata. + """ return self.voxel_size.dims @lazy_property.LazyProperty def shape(self) -> Coordinate: + """ + Return the shape of the array. + + Returns: + Coordinate: The shape of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.shape + Coordinate([100, 100, 100]) + Note: + Tiffs have tons of different locations for metadata. + """ data_shape = self.data.shape spatial_shape = Coordinate( [data_shape[self.axes.index(axis)] for axis in self.spatial_axes] @@ -52,22 +144,94 @@ def shape(self) -> Coordinate: @lazy_property.LazyProperty def voxel_size(self) -> Coordinate: + """ + Return the voxel size of the array. + + Returns: + Coordinate: The voxel size of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.voxel_size + Coordinate([1, 1, 1]) + Note: + Tiffs have tons of different locations for metadata. + """ return self._voxel_size @lazy_property.LazyProperty def roi(self) -> Roi: + """ + Return the region of interest of the array. + + Returns: + Roi: The region of interest of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.roi + Roi([0, 0, 0], [100, 100, 100]) + Note: + Tiffs have tons of different locations for metadata. + """ return Roi(self._offset, self.shape) @property def writable(self) -> bool: + """ + Return whether the array is writable. + + Returns: + bool: Whether the array is writable. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.writable + False + Note: + Tiffs have tons of different locations for metadata. + """ return False @property def dtype(self): + """ + Return the data type of the array. + + Returns: + np.dtype: The data type of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.dtype + np.float32 + Note: + Tiffs have tons of different locations for metadata. + + """ return self.data.dtype @property def num_channels(self) -> Optional[int]: + """ + Return the number of channels of the array. + + Returns: + Optional[int]: The number of channels of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.num_channels + 1 + Note: + Tiffs have tons of different locations for metadata. + + """ if "c" in self.axes: return self.data.shape[self.axes.index("c")] else: @@ -75,8 +239,36 @@ def num_channels(self) -> Optional[int]: @property def spatial_axes(self) -> List[str]: + """ + Return the spatial axes of the array. + + Returns: + List[str]: The spatial axes of the array. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.spatial_axes + ['x', 'y', 'z'] + Note: + Tiffs have tons of different locations for metadata. + """ return [c for c in self.axes if c != "c"] @lazy_property.LazyProperty def data(self): + """ + Return the data of the tiff. + + Returns: + np.ndarray: The data of the tiff. + Raises: + NotImplementedError: + Tiffs have tons of different locations for metadata. + Examples: + >>> tiff_array.data + np.ndarray + Note: + Tiffs have tons of different locations for metadata. + """ return tifffile.TiffFile(self._file_name).values diff --git a/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py index d1930e55a..27b4e623a 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/tiff_array_config.py @@ -5,13 +5,27 @@ from funlib.geometry import Coordinate -from pathlib import Path +from upath import UPath as Path from typing import List @attr.s class ZarrArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for a tiff array""" + """ + This config class provides the necessary configuration for a tiff array + + Attributes: + file_name: Path + The file name of the tiff. + offset: Coordinate + The offset of the array. + voxel_size: Coordinate + The voxel size of the array. + axes: List[str] + The axes of the array. + Note: + This class is a subclass of ArrayConfig. + """ array_type = TiffArray diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py index 51046fd2e..f61bf0cd4 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array.py @@ -9,10 +9,11 @@ import lazy_property import numpy as np import zarr +from zarr.n5 import N5FSStore from collections import OrderedDict import logging -from pathlib import Path +from upath import UPath as Path import json from typing import Dict, Tuple, Any, Optional, List @@ -20,30 +21,172 @@ class ZarrArray(Array): - """This is a zarr array""" + """ + This is a zarr array. + + Attributes: + name (str): The name of the array. + file_name (Path): The file name of the array. + dataset (str): The dataset name. + _axes (Optional[List[str]]): The axes of the array. + snap_to_grid (Optional[Coordinate]): The snap to grid. + Methods: + __init__(array_config): + Initializes the array type 'raw' and name for the DummyDataset instance. + __str__(): + Returns the string representation of the ZarrArray. + __repr__(): + Returns the string representation of the ZarrArray. + attrs(): + Returns the attributes of the array. + axes(): + Returns the axes of the array. + dims(): + Returns the dimensions of the array. + _daisy_array(): + Returns the daisy array. + voxel_size(): + Returns the voxel size of the array. + roi(): + Returns the region of interest of the array. + writable(): + Returns the boolean value of the array. + dtype(): + Returns the data type of the array. + num_channels(): + Returns the number of channels of the array. + spatial_axes(): + Returns the spatial axes of the array. + data(): + Returns the data of the array. + __getitem__(roi): + Returns the data of the array for the given region of interest. + __setitem__(roi, value): + Sets the data of the array for the given region of interest. + create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False): + Creates a new ZarrArray given an array identifier. + open_from_array_identifier(array_identifier, name=""): + Opens a new ZarrArray given an array identifier. + _can_neuroglance(): + Returns the boolean value of the array. + _neuroglancer_source(): + Returns the neuroglancer source of the array. + _neuroglancer_layer(): + Returns the neuroglancer layer of the array. + _transform_matrix(): + Returns the transform matrix of the array. + _output_dimensions(): + Returns the output dimensions of the array. + _source_name(): + Returns the source name of the array. + add_metadata(metadata): + Adds metadata to the array. + Notes: + This class is used to create a zarr array. + """ def __init__(self, array_config): + """ + Initializes the array type 'raw' and name for the DummyDataset instance. + + Args: + array_config (object): an instance of a configuration class that includes the name and + raw configuration of the data. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset = DummyDataset(dataset_config) + Notes: + This method is used to initialize the dataset. + """ super().__init__() self.name = array_config.name self.file_name = array_config.file_name self.dataset = array_config.dataset - + self._mode = array_config.mode self._attributes = self.data.attrs self._axes = array_config._axes self.snap_to_grid = array_config.snap_to_grid def __str__(self): + """ + Returns the string representation of the ZarrArray. + + Args: + ZarrArray (str): The string representation of the ZarrArray. + Returns: + str: The string representation of the ZarrArray. + Raises: + NotImplementedError + Examples: + >>> print(ZarrArray) + Notes: + This method is used to return the string representation of the ZarrArray. + """ return f"ZarrArray({self.file_name}, {self.dataset})" def __repr__(self): + """ + Returns the string representation of the ZarrArray. + + Args: + ZarrArray (str): The string representation of the ZarrArray. + Returns: + str: The string representation of the ZarrArray. + Raises: + NotImplementedError + Examples: + >>> print(ZarrArray) + Notes: + This method is used to return the string representation of the ZarrArray. + + """ return f"ZarrArray({self.file_name}, {self.dataset})" + @property + def mode(self): + if not hasattr(self, "_mode"): + self._mode = "a" + if self._mode not in ["r", "w", "a"]: + raise ValueError(f"Mode {self._mode} not in ['r', 'w', 'a']") + return self._mode + @property def attrs(self): + """ + Returns the attributes of the array. + + Args: + attrs (Any): The attributes of the array. + Returns: + Any: The attributes of the array. + Raises: + NotImplementedError + Examples: + >>> attrs() + Notes: + This method is used to return the attributes of the array. + + """ return self.data.attrs @property def axes(self): + """ + Returns the axes of the array. + + Args: + axes (List[str]): The axes of the array. + Returns: + List[str]: The axes of the array. + Raises: + NotImplementedError + Examples: + >>> axes() + Notes: + This method is used to return the axes of the array. + """ if self._axes is not None: return self._axes try: @@ -58,18 +201,77 @@ def axes(self): @property def dims(self) -> int: + """ + Returns the dimensions of the array. + + Args: + dims (int): The dimensions of the array. + Returns: + int: The dimensions of the array. + Raises: + NotImplementedError + Examples: + >>> dims() + Notes: + This method is used to return the dimensions of the array. + + """ return self.voxel_size.dims @lazy_property.LazyProperty def _daisy_array(self) -> funlib.persistence.Array: + """ + Returns the daisy array. + + Args: + voxel_size (Coordinate): The voxel size. + Returns: + funlib.persistence.Array: The daisy array. + Raises: + NotImplementedError + Examples: + >>> _daisy_array() + Notes: + This method is used to return the daisy array. + + """ return funlib.persistence.open_ds(f"{self.file_name}", self.dataset) @lazy_property.LazyProperty def voxel_size(self) -> Coordinate: + """ + Returns the voxel size of the array. + + Args: + voxel_size (Coordinate): The voxel size. + Returns: + Coordinate: The voxel size of the array. + Raises: + NotImplementedError + Examples: + >>> voxel_size() + Notes: + This method is used to return the voxel size of the array. + + """ return self._daisy_array.voxel_size @lazy_property.LazyProperty def roi(self) -> Roi: + """ + Returns the region of interest of the array. + + Args: + roi (Roi): The region of interest. + Returns: + Roi: The region of interest of the array. + Raises: + NotImplementedError + Examples: + >>> roi() + Notes: + This method is used to return the region of interest of the array. + """ if self.snap_to_grid is not None: return self._daisy_array.roi.snap_to_grid(self.snap_to_grid, mode="shrink") else: @@ -77,32 +279,136 @@ def roi(self) -> Roi: @property def writable(self) -> bool: + """ + Returns the boolean value of the array. + + Args: + writable (bool): The boolean value of the array. + Returns: + bool: The boolean value of the array. + Raises: + NotImplementedError + Examples: + >>> writable() + Notes: + This method is used to return the boolean value of the array. + """ return True @property def dtype(self) -> Any: + """ + Returns the data type of the array. + + Args: + dtype (Any): The data type of the array. + Returns: + Any: The data type of the array. + Raises: + NotImplementedError + Examples: + >>> dtype() + Notes: + This method is used to return the data type of the array. + """ return self.data.dtype @property def num_channels(self) -> Optional[int]: + """ + Returns the number of channels of the array. + + Args: + num_channels (Optional[int]): The number of channels of the array. + Returns: + Optional[int]: The number of channels of the array. + Raises: + NotImplementedError + Examples: + >>> num_channels() + Notes: + This method is used to return the number of channels of the array. + + """ return None if "c" not in self.axes else self.data.shape[self.axes.index("c")] @property def spatial_axes(self) -> List[str]: + """ + Returns the spatial axes of the array. + + Args: + spatial_axes (List[str]): The spatial axes of the array. + Returns: + List[str]: The spatial axes of the array. + Raises: + NotImplementedError + Examples: + >>> spatial_axes() + Notes: + This method is used to return the spatial axes of the array. + + """ return [ax for ax in self.axes if ax not in set(["c", "b"])] @property def data(self) -> Any: - zarr_container = zarr.open(str(self.file_name)) + """ + Returns the data of the array. + + Args: + data (Any): The data of the array. + Returns: + Any: The data of the array. + Raises: + NotImplementedError + Examples: + >>> data() + Notes: + This method is used to return the data of the array. + """ + file_name = str(self.file_name) + # Zarr library does not detect the store for N5 datasets + if file_name.endswith(".n5"): + zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode) + else: + zarr_container = zarr.open(str(file_name), mode=self.mode) return zarr_container[self.dataset] def __getitem__(self, roi: Roi) -> np.ndarray: + """ + Returns the data of the array for the given region of interest. + + Args: + roi (Roi): The region of interest. + Returns: + np.ndarray: The data of the array for the given region of interest. + Raises: + NotImplementedError + Examples: + >>> __getitem__(roi) + Notes: + This method is used to return the data of the array for the given region of interest. + """ data: np.ndarray = funlib.persistence.Array( self.data, self.roi, self.voxel_size ).to_ndarray(roi=roi) return data def __setitem__(self, roi: Roi, value: np.ndarray): + """ + Sets the data of the array for the given region of interest. + + Args: + roi (Roi): The region of interest. + value (np.ndarray): The value to set. + Raises: + NotImplementedError + Examples: + >>> __setitem__(roi, value) + Notes: + This method is used to set the data of the array for the given region of interest. + """ funlib.persistence.Array(self.data, self.roi, self.voxel_size)[roi] = value @classmethod @@ -114,13 +420,33 @@ def create_from_array_identifier( num_channels, voxel_size, dtype, + mode="a", write_size=None, name=None, overwrite=False, ): """ Create a new ZarrArray given an array identifier. It is assumed that - this array_identifier points to a dataset that does not yet exist + this array_identifier points to a dataset that does not yet exist. + + Args: + array_identifier (ArrayIdentifier): The array identifier. + axes (List[str]): The axes of the array. + roi (Roi): The region of interest. + num_channels (int): The number of channels. + voxel_size (Coordinate): The voxel size. + dtype (Any): The data type. + write_size (Optional[Coordinate]): The write size. + name (Optional[str]): The name of the array. + overwrite (bool): The boolean value to overwrite the array. + Returns: + ZarrArray: The ZarrArray. + Raises: + NotImplementedError + Examples: + >>> create_from_array_identifier(array_identifier, axes, roi, num_channels, voxel_size, dtype, write_size=None, name=None, overwrite=False) + Notes: + This method is used to create a new ZarrArray given an array identifier. """ if write_size is None: # total storage per block is approx c*x*y*z*dtype_size @@ -214,6 +540,21 @@ def create_from_array_identifier( @classmethod def open_from_array_identifier(cls, array_identifier, name=""): + """ + Opens a new ZarrArray given an array identifier. + + Args: + array_identifier (ArrayIdentifier): The array identifier. + name (str): The name of the array. + Returns: + ZarrArray: The ZarrArray. + Raises: + NotImplementedError + Examples: + >>> open_from_array_identifier(array_identifier, name="") + Notes: + This method is used to open a new ZarrArray given an array identifier. + """ zarr_array = cls.__new__(cls) zarr_array.name = name zarr_array.file_name = array_identifier.container @@ -224,9 +565,38 @@ def open_from_array_identifier(cls, array_identifier, name=""): return zarr_array def _can_neuroglance(self) -> bool: + """ + Returns the boolean value of the array. + + Args: + can_neuroglance (bool): The boolean value of the array. + Returns: + bool: The boolean value of the array. + Raises: + NotImplementedError + Examples: + >>> can_neuroglance() + Notes: + This method is used to return the boolean value of the array. + """ return True def _neuroglancer_source(self): + """ + Returns the neuroglancer source of the array. + + Args: + neuroglancer.LocalVolume: The neuroglancer source of the array. + Returns: + neuroglancer.LocalVolume: The neuroglancer source of the array. + Raises: + NotImplementedError + Examples: + >>> neuroglancer_source() + Notes: + This method is used to return the neuroglancer source of the array. + + """ d = open_ds(str(self.file_name), self.dataset) return neuroglancer.LocalVolume( data=d.data, @@ -239,10 +609,38 @@ def _neuroglancer_source(self): ) def _neuroglancer_layer(self) -> Tuple[neuroglancer.ImageLayer, Dict[str, Any]]: + """ + Returns the neuroglancer layer of the array. + + Args: + layer (neuroglancer.ImageLayer): The neuroglancer layer of the array. + Returns: + Tuple[neuroglancer.ImageLayer, Dict[str, Any]]: The neuroglancer layer of the array. + Raises: + NotImplementedError + Examples: + >>> neuroglancer_layer() + Notes: + This method is used to return the neuroglancer layer of the array. + """ layer = neuroglancer.ImageLayer(source=self._neuroglancer_source()) return layer def _transform_matrix(self): + """ + Returns the transform matrix of the array. + + Args: + transform_matrix (List[List[float]]): The transform matrix of the array. + Returns: + List[List[float]]: The transform matrix of the array. + Raises: + NotImplementedError + Examples: + >>> transform_matrix() + Notes: + This method is used to return the transform matrix of the array. + """ is_zarr = self.file_name.name.endswith(".zarr") if is_zarr: offset = self.roi.offset @@ -267,6 +665,20 @@ def _transform_matrix(self): return [[0] * i + [1] + [0] * (self.dims - i) for i in range(self.dims)] def _output_dimensions(self) -> Dict[str, Tuple[float, str]]: + """ + Returns the output dimensions of the array. + + Args: + output_dimensions (Dict[str, Tuple[float, str]]): The output dimensions of the array. + Returns: + Dict[str, Tuple[float, str]]: The output dimensions of the array. + Raises: + NotImplementedError + Examples: + >>> output_dimensions() + Notes: + This method is used to return the output dimensions of the array. + """ is_zarr = self.file_name.name.endswith(".zarr") if is_zarr: spatial_dimensions = OrderedDict() @@ -282,9 +694,37 @@ def _output_dimensions(self) -> Dict[str, Tuple[float, str]]: } def _source_name(self) -> str: + """ + Returns the source name of the array. + + Args: + source_name (str): The source name of the array. + Returns: + str: The source name of the array. + Raises: + NotImplementedError + Examples: + >>> source_name() + Notes: + This method is used to return the source name of the array. + + """ return self.name def add_metadata(self, metadata: Dict[str, Any]) -> None: + """ + Adds metadata to the array. + + Args: + metadata (Dict[str, Any]): The metadata to add to the array. + Raises: + NotImplementedError + Examples: + >>> add_metadata(metadata) + Notes: + This method is used to add metadata to the array. + + """ dataset = zarr.open(self.file_name, mode="a")[self.dataset] for k, v in metadata.items(): dataset.attrs[k] = v diff --git a/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py b/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py index 69bce2378..b67717647 100644 --- a/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py +++ b/dacapo/experiments/datasplits/datasets/arrays/zarr_array_config.py @@ -5,14 +5,36 @@ from funlib.geometry import Coordinate -from pathlib import Path +from upath import UPath as Path from typing import Optional, List, Tuple @attr.s class ZarrArrayConfig(ArrayConfig): - """This config class provides the necessary configuration for a zarr array""" + """ + This config class provides the necessary configuration for a zarr array. + + A zarr array is a container for large, multi-dimensional arrays. It is similar to HDF5, but is designed to work + with large arrays that do not fit into memory. Zarr arrays can be stored on disk or in the cloud + and can be accessed concurrently by multiple processes. Zarr arrays can be compressed and + support chunked, N-dimensional arrays. + + Attributes: + file_name: Path + The file name of the zarr container. + dataset: str + The name of your dataset. May include '/' characters for nested heirarchies + snap_to_grid: Optional[Coordinate] + If you need to make sure your ROI's align with a specific voxel_size + _axes: Optional[List[str]] + The axes of your data! + Methods: + verify() -> Tuple[bool, str] + Check whether this is a valid Array + Note: + This class is a subclass of ArrayConfig. + """ array_type = ZarrArray @@ -33,10 +55,30 @@ class ZarrArrayConfig(ArrayConfig): _axes: Optional[List[str]] = attr.ib( default=None, metadata={"help_text": "The axes of your data!"} ) + mode: Optional[str] = attr.ib( + default="a", metadata={"help_text": "The access mode!"} + ) def verify(self) -> Tuple[bool, str]: """ Check whether this is a valid Array + + Returns: + Tuple[bool, str]: A tuple of a boolean and a string. The boolean indicates whether the Array is valid or not. + The string provides a reason why the Array is not valid. + Raises: + NotImplementedError: This method is not implemented for this Array + Examples: + >>> zarr_array_config = ZarrArrayConfig( + ... file_name=Path("data.zarr"), + ... dataset="data", + ... snap_to_grid=Coordinate(1, 1, 1), + ... _axes=["x", "y", "z"] + ... ) + >>> zarr_array_config.verify() + (True, 'No validation for this Array') + Note: + This method is not implemented for this Array """ if not self.file_name.exists(): return False, f"{self.file_name} does not exist!" diff --git a/dacapo/experiments/datasplits/datasets/dataset.py b/dacapo/experiments/datasplits/datasets/dataset.py index 663805227..ced4f58d6 100644 --- a/dacapo/experiments/datasplits/datasets/dataset.py +++ b/dacapo/experiments/datasplits/datasets/dataset.py @@ -15,6 +15,20 @@ class Dataset(ABC): mask (Array, optional): The mask for the data. weight (int, optional): The weight of the dataset. sample_points (list[Coordinate], optional): The list of sample points in the dataset. + Methods: + __eq__(other): + Overloaded equality operator for dataset objects. + __hash__(): + Calculates a hash for the dataset. + __repr__(): + Returns the official string representation of the dataset object. + __str__(): + Returns the string representation of the dataset object. + _neuroglancer_layers(prefix="", exclude_layers=None): + Generates neuroglancer layers for raw, gt and mask if they can be viewed by neuroglance, excluding those in + the exclude_layers. + Notes: + This class is a base class and should not be instantiated. """ name: str @@ -30,9 +44,17 @@ def __eq__(self, other: Any) -> bool: Args: other (Any): The object to compare with the dataset. - Returns: bool: True if the object is also a dataset and they have the same name, False otherwise. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset1 = Dataset("dataset1") + >>> dataset2 = Dataset("dataset2") + >>> dataset1 == dataset2 + False + Notes: + This method is used to compare two dataset objects. """ return isinstance(other, type(self)) and self.name == other.name @@ -42,6 +64,14 @@ def __hash__(self) -> int: Returns: int: The hash of the dataset name. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> hash(dataset) + 123456 + Notes: + This method is used to calculate a hash for the dataset. """ return hash(self.name) @@ -51,6 +81,14 @@ def __repr__(self) -> str: Returns: str: String representation of the dataset. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> dataset + Dataset(dataset) + Notes: + This method is used to return the official string representation of the dataset object. """ return f"Dataset({self.name})" @@ -58,8 +96,18 @@ def __str__(self) -> str: """ Returns the string representation of the dataset object. + Args: + self (Dataset): The dataset object. Returns: str: String representation of the dataset. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> print(dataset) + Dataset(dataset) + Notes: + This method is used to return the string representation of the dataset object. """ return f"Dataset({self.name})" @@ -71,9 +119,16 @@ def _neuroglancer_layers(self, prefix="", exclude_layers=None): Args: prefix (str, optional): A prefix to be added to the layer names. exclude_layers (set, optional): A set of layer names to exclude. - Returns: dict: A dictionary containing layer names as keys and corresponding neuroglancer layer as values. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset = Dataset("dataset") + >>> dataset._neuroglancer_layers() + {"raw": neuroglancer_layer} + Notes: + This method is used to generate neuroglancer layers for raw, gt and mask if they can be viewed by neuroglance. """ layers = {} exclude_layers = exclude_layers if exclude_layers is not None else set() diff --git a/dacapo/experiments/datasplits/datasets/dataset_config.py b/dacapo/experiments/datasplits/datasets/dataset_config.py index c860d600e..4217eb00e 100644 --- a/dacapo/experiments/datasplits/datasets/dataset_config.py +++ b/dacapo/experiments/datasplits/datasets/dataset_config.py @@ -5,7 +5,8 @@ @attr.s class DatasetConfig: - """A class used to define configuration for datasets. This provides the + """ + A class used to define configuration for datasets. This provides the framework to create a Dataset instance. Attributes: @@ -18,11 +19,12 @@ class DatasetConfig: A numeric value that indicates how frequently this dataset should be sampled in comparison to others. Higher the weight, more frequently it gets sampled. - Methods: verify: Checks and validates the dataset configuration. The specific rules for validation need to be defined by the user. + Notes: + This class is used to create a configuration object for datasets. """ name: str = attr.ib( @@ -51,5 +53,13 @@ def verify(self) -> Tuple[bool, str]: Returns: tuple: A tuple of boolean value indicating the check (True or False) and message specifying result of validation. + Raises: + NotImplementedError: If the method is not implemented in the derived class. + Examples: + >>> dataset_config = DatasetConfig(name="sample_dataset") + >>> dataset_config.verify() + (True, "No validation for this DataSet") + Notes: + This method is used to validate the configuration of the dataset. """ return True, "No validation for this DataSet" diff --git a/dacapo/experiments/datasplits/datasets/dummy_dataset.py b/dacapo/experiments/datasplits/datasets/dummy_dataset.py index cec9e05b4..4fc34e84b 100644 --- a/dacapo/experiments/datasplits/datasets/dummy_dataset.py +++ b/dacapo/experiments/datasplits/datasets/dummy_dataset.py @@ -3,20 +3,35 @@ class DummyDataset(Dataset): - """DummyDataset is a child class of the Dataset. This class has property 'raw' of Array type and a name. + """ + DummyDataset is a child class of the Dataset. This class has property 'raw' of Array type and a name. - Args: - dataset_config (object): an instance of a configuration class. + Attributes: + raw: Array + The raw data. + Methods: + __init__(dataset_config): + Initializes the array type 'raw' and name for the DummyDataset instance. + Notes: + This class is used to create a dataset with raw data. """ raw: Array def __init__(self, dataset_config): - """Initializes the array type 'raw' and name for the DummyDataset instance. + """ + Initializes the array type 'raw' and name for the DummyDataset instance. Args: dataset_config (object): an instance of a configuration class that includes the name and raw configuration of the data. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset = DummyDataset(dataset_config) + Notes: + This method is used to initialize the dataset. """ super().__init__() self.name = dataset_config.name diff --git a/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py b/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py index ecdf3e36e..6aaefc98a 100644 --- a/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py +++ b/dacapo/experiments/datasplits/datasets/dummy_dataset_config.py @@ -15,9 +15,10 @@ class DummyDatasetConfig(DatasetConfig): Attributes: dataset_type : Clearly mentions the type of dataset raw_config : This attribute holds the configurations related to dataset arrays. - Methods: verify: A dummy verification method for testing purposes, always returns False and a message. + Notes: + This class is used to create a configuration object for the dummy dataset. """ dataset_type = DummyDataset @@ -25,10 +26,20 @@ class DummyDatasetConfig(DatasetConfig): raw_config: ArrayConfig = attr.ib(DummyArrayConfig(name="dummy_array")) def verify(self) -> Tuple[bool, str]: - """A dummy method that always indicates the dataset config is not valid. + """ + A dummy method that always indicates the dataset config is not valid. Returns: A tuple of False and a message indicating the invalidity. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset_config = DummyDatasetConfig(raw_config=DummyArrayConfig(name="dummy_array")) + >>> dataset_config.verify() + (False, "This is a DummyDatasetConfig and is never valid") + Notes: + This method is used to validate the configuration of the dataset. """ return False, "This is a DummyDatasetConfig and is never valid" diff --git a/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py b/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py index d7d587d78..1a2a7745f 100644 --- a/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py +++ b/dacapo/experiments/datasplits/datasets/graphstores/graph_source_config.py @@ -3,9 +3,17 @@ @attr.s class GraphStoreConfig: - """Base class for graph store configurations. Each subclass of a + """ + Base class for graph store configurations. Each subclass of a `GraphStore` should have a corresponding config class derived from `GraphStoreConfig`. + + Attributes: + store_type (class): The type of graph store that is being configured. + Methods: + verify: A method to verify the validity of the configuration. + Notes: + This class is used to create a configuration object for the graph store. """ pass diff --git a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py index 040c5baa3..8539e8339 100644 --- a/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py +++ b/dacapo/experiments/datasplits/datasets/raw_gt_dataset.py @@ -7,12 +7,47 @@ class RawGTDataset(Dataset): + """ + A dataset that contains raw and ground truth data. Optionally, it can also contain a mask. + + Attributes: + raw: Array + The raw data. + gt: Array + The ground truth data. + mask: Optional[Array] + The mask data. + sample_points: Optional[List[Coordinate]] + The sample points in the graph. + weight: Optional[float] + The weight of the dataset. + Methods: + __init__(dataset_config): + Initialize the dataset. + Notes: + This class is a base class and should not be instantiated. + """ + raw: Array gt: Array mask: Optional[Array] sample_points: Optional[List[Coordinate]] def __init__(self, dataset_config): + """ + Initialize the dataset. + + Args: + dataset_config: DataSplitConfig + The configuration of the dataset. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> dataset = RawGTDataset(dataset_config) + Notes: + This method is used to initialize the dataset. + """ self.name = dataset_config.name self.raw = dataset_config.raw_config.array_type(dataset_config.raw_config) self.gt = dataset_config.gt_config.array_type(dataset_config.gt_config) diff --git a/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py b/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py index 705bcb467..e967b83d6 100644 --- a/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py +++ b/dacapo/experiments/datasplits/datasets/raw_gt_dataset_config.py @@ -27,6 +27,10 @@ class RawGTDatasetConfig(DatasetConfig): equal to zero on voxels where the mask is 1. sample_points (Optional[List[Coordinate]]): An optional list of points around which training samples will be extracted. + Methods: + verify: A method to verify the validity of the configuration. + Notes: + This class is used to create a configuration object for the standard dataset with both raw and GT Array. """ dataset_type = RawGTDataset diff --git a/dacapo/experiments/datasplits/datasplit.py b/dacapo/experiments/datasplits/datasplit.py index 62eaa4b27..f6f34e5d5 100644 --- a/dacapo/experiments/datasplits/datasplit.py +++ b/dacapo/experiments/datasplits/datasplit.py @@ -7,10 +7,46 @@ class DataSplit(ABC): + """ + A class for creating a simple train dataset and no validation dataset. It is derived from `DataSplit` class. + It is used to split the data into training and validation datasets. The training and validation datasets are + used to train and validate the model respectively. + + Attributes: + train : list + The list containing training datasets. In this class, it contains only one dataset for training. + validate : list + The list containing validation datasets. In this class, it is an empty list as no validation dataset is set. + Methods: + __init__(self, datasplit_config): + The constructor for DummyDataSplit class. It initialises a list with training datasets according to the input configuration. + Notes: + This class is used to split the data into training and validation datasets. + """ + train: List[Dataset] validate: Optional[List[Dataset]] def _neuroglancer(self, embedded=False): + """ + A method to visualize the data in Neuroglancer. It creates a Neuroglancer viewer and adds the layers of the training and validation datasets to it. + + Args: + embedded : bool + A boolean flag to indicate if the Neuroglancer viewer is to be embedded in the notebook. + Returns: + viewer : obj + The Neuroglancer viewer object. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which is logged and handled by training the model without head matching. + Examples: + >>> viewer = datasplit._neuroglancer(embedded=True) + Notes: + This function is called by the DataSplit class to visualize the data in Neuroglancer. + It creates a Neuroglancer viewer and adds the layers of the training and validation datasets to it. + Neuroglancer is a powerful tool for visualizing large-scale volumetric data. + """ neuroglancer.set_server_bind_address("0.0.0.0") viewer = neuroglancer.Viewer() with viewer.txn() as s: diff --git a/dacapo/experiments/datasplits/datasplit_config.py b/dacapo/experiments/datasplits/datasplit_config.py index f00069960..992113d47 100644 --- a/dacapo/experiments/datasplits/datasplit_config.py +++ b/dacapo/experiments/datasplits/datasplit_config.py @@ -8,17 +8,16 @@ class DataSplitConfig: """ A class used to create a DataSplit configuration object. - Attributes - ---------- - name : str - A name for the datasplit. This name will be saved so it can be found - and reused easily. It is recommended to keep it short and avoid special - characters. - - Methods - ------- - verify() -> Tuple[bool, str]: - Validates if it is a valid data split configuration. + Attributes: + name : str + A name for the datasplit. This name will be saved so it can be found + and reused easily. It is recommended to keep it short and avoid special + characters. + Methods: + verify() -> Tuple[bool, str]: + Validates if it is a valid data split configuration. + Notes: + This class is used to create a DataSplit configuration object. """ name: str = attr.ib( @@ -33,10 +32,18 @@ def verify(self) -> Tuple[bool, str]: """ Validates if the current configuration is a valid data split configuration. - Returns - ------- - Tuple[bool, str] - True if the configuration is valid, - False otherwise along with respective validation error message. + Returns: + Tuple[bool, str] + True if the configuration is valid, + False otherwise along with respective validation error message. + Raises: + NotImplementedError + If the method is not implemented in the derived class. + Examples: + >>> datasplit_config = DataSplitConfig(name="datasplit") + >>> datasplit_config.verify() + (True, "No validation for this DataSplit") + Notes: + This method is used to validate the configuration of DataSplit. """ return True, "No validation for this DataSplit" diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py index 8f177e187..a69fd633f 100644 --- a/dacapo/experiments/datasplits/datasplit_generator.py +++ b/dacapo/experiments/datasplits/datasplit_generator.py @@ -1,11 +1,12 @@ from dacapo.experiments.tasks import TaskConfig -from pathlib import Path +from upath import UPath as Path from typing import List from enum import Enum, EnumMeta from funlib.geometry import Coordinate from typing import Union, Optional import zarr +from zarr.n5 import N5FSStore from dacapo.experiments.datasplits.datasets.arrays import ( ZarrArrayConfig, ZarrArray, @@ -21,14 +22,55 @@ logger = logging.getLogger(__name__) -def is_zarr_group(file_name: str, dataset: str): - zarr_file = zarr.open(str(file_name)) +def is_zarr_group(file_name: Path, dataset: str): + """ + Check if the dataset is a Zarr group. If the dataset is a Zarr group, it will return True, otherwise False. + + Args: + file_name : str + The name of the file. + dataset : str + The name of the dataset. + Returns: + bool : True if the dataset is a Zarr group, otherwise False. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> is_zarr_group(file_name, dataset) + Notes: + This function is used to check if the dataset is a Zarr group. + """ + if file_name.suffix == ".n5": + zarr_file = zarr.open(N5FSStore(str(file_name)), mode="r") + else: + zarr_file = zarr.open(str(file_name), mode="r") return isinstance(zarr_file[dataset], zarr.hierarchy.Group) def resize_if_needed( array_config: ZarrArrayConfig, target_resolution: Coordinate, extra_str="" ): + """ + Resize the array if needed. If the array needs to be resized, it will return the resized array, otherwise it will return the original array. + + Args: + array_config : obj + The configuration of the array. + target_resolution : obj + The target resolution. + extra_str : str + An extra string. + Returns: + obj : The resized array if needed, otherwise the original array. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> resize_if_needed(array_config, target_resolution, extra_str) + Notes: + This function is used to resize the array if needed. + """ zarr_array = ZarrArray(array_config) raw_voxel_size = zarr_array.voxel_size @@ -49,6 +91,28 @@ def resize_if_needed( def get_right_resolution_array_config( container: Path, dataset, target_resolution, extra_str="" ): + """ + Get the right resolution array configuration. It will return the right resolution array configuration. + + Args: + container : obj + The container. + dataset : str + The dataset. + target_resolution : obj + The target resolution. + extra_str : str + An extra string. + Returns: + obj : The right resolution array configuration. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> get_right_resolution_array_config(container, dataset, target_resolution, extra_str) + Notes: + This function is used to get the right resolution array configuration. + """ level = 0 current_dataset_path = Path(dataset, f"s{level}") if not (container / current_dataset_path).exists(): @@ -61,6 +125,7 @@ def get_right_resolution_array_config( file_name=container, dataset=str(current_dataset_path), snap_to_grid=target_resolution, + mode="r", ) zarr_array = ZarrArray(zarr_config) while ( @@ -73,6 +138,7 @@ def get_right_resolution_array_config( file_name=container, dataset=str(Path(dataset, f"s{level}")), snap_to_grid=target_resolution, + mode="r", ) zarr_array = ZarrArray(zarr_config) @@ -80,7 +146,36 @@ def get_right_resolution_array_config( class CustomEnumMeta(EnumMeta): + """ + Custom Enum Meta class to raise KeyError when an invalid option is passed. + + Attributes: + _member_names_ : list + The list of member names. + Methods: + __getitem__(self, item) + A method to get the item. + Notes: + This class is used to raise KeyError when an invalid option is passed. + """ + def __getitem__(self, item): + """ + Get the item. + + Args: + item : obj + The item. + Returns: + obj : The item. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Examples: + >>> __getitem__(item) + Notes: + This function is used to get the item. + """ if item not in self._member_names_: raise KeyError( f"{item} is not a valid option of {self.__name__}, the valid options are {self._member_names_}" @@ -89,21 +184,103 @@ def __getitem__(self, item): class CustomEnum(Enum, metaclass=CustomEnumMeta): + """ + A custom Enum class to raise KeyError when an invalid option is passed. + + Attributes: + __str__ : str + The string representation of the class. + Methods: + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to raise KeyError when an invalid option is passed. + """ + def __str__(self) -> str: + """ + Get the string representation of the class. + + Args: + self : obj + The object. + Returns: + str : The string representation of the class. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Examples: + >>> __str__() + Notes: + This function is used to get the string representation of the class. + """ return self.name class DatasetType(CustomEnum): + """ + An Enum class to represent the dataset type. It is derived from `CustomEnum` class. + + Attributes: + val : int + The value of the dataset type. + train : int + The training dataset type. + Methods: + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to represent the dataset type. + """ + val = 1 train = 2 class SegmentationType(CustomEnum): + """ + An Enum class to represent the segmentation type. It is derived from `CustomEnum` class. + + Attributes: + semantic : int + The semantic segmentation type. + instance : int + The instance segmentation type. + Methods: + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to represent the segmentation type. + """ + semantic = 1 instance = 2 class DatasetSpec: + """ + A class for dataset specification. It is used to specify the dataset. + + Attributes: + dataset_type : obj + The dataset type. + raw_container : obj + The raw container. + raw_dataset : str + The raw dataset. + gt_container : obj + The ground truth container. + gt_dataset : str + The ground truth dataset. + Methods: + __init__(dataset_type, raw_container, raw_dataset, gt_container, gt_dataset) + Initializes the DatasetSpec class with the specified dataset type, raw container, raw dataset, ground truth container, and ground truth dataset. + __str__(self) + A method to get the string representation of the class. + Notes: + This class is used to specify the dataset. + """ + def __init__( self, dataset_type: Union[str, DatasetType], @@ -112,6 +289,28 @@ def __init__( gt_container: Union[str, Path], gt_dataset: str, ): + """ + Initializes the DatasetSpec class with the specified dataset type, raw container, raw dataset, ground truth container, and ground truth dataset. + + Args: + dataset_type : obj + The dataset type. + raw_container : obj + The raw container. + raw_dataset : str + The raw dataset. + gt_container : obj + The ground truth container. + gt_dataset : str + The ground truth dataset. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Methods: + __init__(dataset_type, raw_container, raw_dataset, gt_container, gt_dataset) + Notes: + This function is used to initialize the DatasetSpec class with the specified dataset type, raw container, raw dataset, ground truth container, and ground truth dataset. + """ if isinstance(dataset_type, str): dataset_type = DatasetType[dataset_type.lower()] @@ -128,10 +327,42 @@ def __init__( self.gt_dataset = gt_dataset def __str__(self) -> str: + """ + Get the string representation of the class. + + Args: + self : obj + The object. + Returns: + str : The string representation of the class. + Raises: + KeyError + If the item is not a valid option, a KeyError is raised. + Examples: + >>> __str__() + Notes: + This function is used to get the string representation of the class. + """ return f"{self.raw_container.stem}_{self.gt_dataset}" def generate_dataspec_from_csv(csv_path: Path): + """ + Generate the dataset specification from the CSV file. It will return the dataset specification. + + Args: + csv_path : obj + The CSV file path. + Returns: + list : The dataset specification. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> generate_dataspec_from_csv(csv_path) + Notes: + This function is used to generate the dataset specification from the CSV file. + """ datasets = [] if not csv_path.exists(): raise FileNotFoundError(f"CSV file {csv_path} does not exist.") @@ -158,14 +389,73 @@ def generate_dataspec_from_csv(csv_path: Path): class DataSplitGenerator: - """Generates DataSplitConfig for a given task config and datasets. - class names in gt_dataset shoulb be within [] e.g. [mito&peroxisome&er] for mutiple classes or [mito] for one class - Currently only supports: - - semantic segmentation. - Supports: + """ + Generates DataSplitConfig for a given task config and datasets. A csv file can be generated + from the DataSplitConfig and used to generate the DataSplitConfig again. + + Currently only supports semantic segmentation. + Supports: - 2D and 3D datasets. - Zarr, N5 and OME-Zarr datasets. - Multi class targets. + - Different resolutions for raw and ground truth datasets. + - Different resolutions for training and validation datasets. + + Attributes: + name : str + The name of the data split generator. + datasets : list + The list of dataset specifications. + input_resolution : obj + The input resolution. + output_resolution : obj + The output resolution. + targets : list + The list of targets. + segmentation_type : obj + The segmentation type. + max_gt_downsample : int + The maximum ground truth downsample. + max_gt_upsample : int + The maximum ground truth upsample. + max_raw_training_downsample : int + The maximum raw training downsample. + max_raw_training_upsample : int + The maximum raw training upsample. + max_raw_validation_downsample : int + The maximum raw validation downsample. + max_raw_validation_upsample : int + The maximum raw validation upsample. + min_training_volume_size : int + The minimum training volume size. + raw_min : int + The minimum raw value. + raw_max : int + The maximum raw value. + classes_separator_caracter : str + The classes separator character. + Methods: + __init__(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_caracter) + Initializes the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character. + __str__(self) + A method to get the string representation of the class. + class_name(self) + A method to get the class name. + check_class_name(self, class_name) + A method to check the class name. + compute(self) + A method to compute the data split. + __generate_semantic_seg_datasplit(self) + A method to generate the semantic segmentation data split. + __generate_semantic_seg_dataset_crop(self, dataset) + A method to generate the semantic segmentation dataset crop. + generate_csv(datasets, csv_path) + A method to generate the CSV file. + generate_from_csv(csv_path, input_resolution, output_resolution, name, **kwargs) + A method to generate the data split from the CSV file. + Notes: + - This class is used to generate the DataSplitConfig for a given task config and datasets. + - Class names in gt_dataset shoulb be within [] e.g. [mito&peroxisome&er] for mutiple classes or [mito] for one class """ def __init__( @@ -187,6 +477,69 @@ def __init__( raw_max=255, classes_separator_caracter="&", ): + """ + Initializes the DataSplitGenerator class with the specified: + - name + - datasets + - input resolution + - output resolution + - targets + - segmentation type + - maximum ground truth downsample + - maximum ground truth upsample + - maximum raw training downsample + - maximum raw training upsample + - maximum raw validation downsample + - maximum raw validation upsample + - minimum training volume size + - minimum raw value + - maximum raw value + - classes separator character + + Args: + name : str + The name of the data split generator. + datasets : list + The list of dataset specifications. + input_resolution : obj + The input resolution. + output_resolution : obj + The output resolution. + targets : list + The list of targets. + segmentation_type : obj + The segmentation type. + max_gt_downsample : int + The maximum ground truth downsample. + max_gt_upsample : int + The maximum ground truth upsample. + max_raw_training_downsample : int + The maximum raw training downsample. + max_raw_training_upsample : int + The maximum raw training upsample. + max_raw_validation_downsample : int + The maximum raw validation downsample. + max_raw_validation_upsample : int + The maximum raw validation upsample. + min_training_volume_size : int + The minimum training volume size. + raw_min : int + The minimum raw value. + raw_max : int + The maximum raw value. + classes_separator_caracter : str + The classes separator character. + Returns: + obj : The DataSplitGenerator class. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> DataSplitGenerator(name, datasets, input_resolution, output_resolution, targets, segmentation_type, max_gt_downsample, max_gt_upsample, max_raw_training_downsample, max_raw_training_upsample, max_raw_validation_downsample, max_raw_validation_upsample, min_training_volume_size, raw_min, raw_max, classes_separator_caracter) + Notes: + This function is used to initialize the DataSplitGenerator class with the specified name, datasets, input resolution, output resolution, targets, segmentation type, maximum ground truth downsample, maximum ground truth upsample, maximum raw training downsample, maximum raw training upsample, maximum raw validation downsample, maximum raw validation upsample, minimum training volume size, minimum raw value, maximum raw value, and classes separator character. + + """ self.name = name self.datasets = datasets self.input_resolution = input_resolution @@ -210,15 +563,65 @@ def __init__( self.classes_separator_caracter = classes_separator_caracter def __str__(self) -> str: + """ + Get the string representation of the class. + + Args: + self : obj + The object. + Returns: + str : The string representation of the class. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> __str__() + Notes: + This function is used to get the string representation of the class. + """ return f"DataSplitGenerator:{self.name}_{self.segmentation_type}_{self.class_name}_{self.output_resolution[0]}nm" @property def class_name(self): + """ + Get the class name. + + Args: + self : obj + The object. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> class_name + Notes: + This function is used to get the class name. + """ return self._class_name # Goal is to force class_name to be set only once, so we have the same classes for all datasets @class_name.setter def class_name(self, class_name): + """ + Set the class name. + + Args: + self : obj + The object. + class_name : obj + The class name. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> class_name + Notes: + This function is used to set the class name. + """ if self._class_name is not None: raise ValueError( f"Class name already set. Current class name is {self.class_name} and new class name is {class_name}" @@ -226,6 +629,25 @@ def class_name(self, class_name): self._class_name = class_name def check_class_name(self, class_name): + """ + Check the class name. + + Args: + self : obj + The object. + class_name : obj + The class name. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is already set, a ValueError is raised. + Examples: + >>> check_class_name(class_name) + Notes: + This function is used to check the class name. + + """ datasets, classes = format_class_name( class_name, self.classes_separator_caracter ) @@ -242,6 +664,22 @@ def check_class_name(self, class_name): return datasets, classes def compute(self): + """ + Compute the data split. + + Args: + self : obj + The object. + Returns: + obj : The data split. + Raises: + NotImplementedError + If the segmentation type is not implemented, a NotImplementedError is raised. + Examples: + >>> compute() + Notes: + This function is used to compute the data split. + """ if self.segmentation_type == SegmentationType.semantic: return self.__generate_semantic_seg_datasplit() else: @@ -250,6 +688,23 @@ def compute(self): ) def __generate_semantic_seg_datasplit(self): + """ + Generate the semantic segmentation data split. + + Args: + self : obj + The object. + Returns: + obj : The data split. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> __generate_semantic_seg_datasplit() + Notes: + This function is used to generate the semantic segmentation data split. + + """ train_dataset_configs = [] validation_dataset_configs = [] for dataset in self.datasets: @@ -281,6 +736,24 @@ def __generate_semantic_seg_datasplit(self): ) def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): + """ + Generate the semantic segmentation dataset crop. + + Args: + self : obj + The object. + dataset : obj + The dataset. + Returns: + obj : The dataset crop. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> __generate_semantic_seg_dataset_crop(dataset) + Notes: + This function is used to generate the semantic segmentation dataset crop. + """ raw_container = dataset.raw_container raw_dataset = dataset.raw_dataset gt_path = dataset.gt_container @@ -295,7 +768,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): # f"Processing raw_container:{raw_container} raw_dataset:{raw_dataset} gt_path:{gt_path} gt_dataset:{gt_dataset}" # ) - if is_zarr_group(str(raw_container), raw_dataset): + if is_zarr_group(raw_container, raw_dataset): raw_config = get_right_resolution_array_config( raw_container, raw_dataset, self.input_resolution, "raw" ) @@ -305,6 +778,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): name=f"raw_{raw_container.stem}_uint8", file_name=raw_container, dataset=raw_dataset, + mode="r", ), self.input_resolution, "raw", @@ -322,7 +796,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): raise FileNotFoundError( f"GT path {gt_path/current_class_dataset} does not exist." ) - if is_zarr_group(str(gt_path), current_class_dataset): + if is_zarr_group(gt_path, current_class_dataset): gt_config = get_right_resolution_array_config( gt_path, current_class_dataset, self.output_resolution, "gt" ) @@ -332,6 +806,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec): name=f"gt_{gt_path.stem}_{current_class_dataset}_uint8", file_name=gt_path, dataset=current_class_dataset, + mode="r", ), self.output_resolution, "gt", @@ -374,6 +849,31 @@ def generate_from_csv( name: Optional[str] = None, **kwargs, ): + """ + Generate the data split from the CSV file. + + Args: + csv_path : obj + The CSV file path. + input_resolution : obj + The input resolution. + output_resolution : obj + The output resolution. + name : str + The name. + **kwargs : dict + The keyword arguments. + Returns: + obj : The data split. + Raises: + FileNotFoundError + If the file does not exist, a FileNotFoundError is raised. + Examples: + >>> generate_from_csv(csv_path, input_resolution, output_resolution, name, **kwargs) + Notes: + This function is used to generate the data split from the CSV file. + + """ if isinstance(csv_path, str): csv_path = Path(csv_path) @@ -390,6 +890,24 @@ def generate_from_csv( def format_class_name(class_name, separator_character="&"): + """ + Format the class name. + + Args: + class_name : obj + The class name. + separator_character : str + The separator character. + Returns: + obj : The class name. + Raises: + ValueError + If the class name is invalid, a ValueError is raised. + Examples: + >>> format_class_name(class_name, separator_character) + Notes: + This function is used to format the class name. + """ if "[" in class_name: if "]" not in class_name: raise ValueError(f"Invalid class name {class_name} missing ']'") diff --git a/dacapo/experiments/datasplits/dummy_datasplit.py b/dacapo/experiments/datasplits/dummy_datasplit.py index 6a5476ef0..b8bde7327 100644 --- a/dacapo/experiments/datasplits/dummy_datasplit.py +++ b/dacapo/experiments/datasplits/dummy_datasplit.py @@ -5,34 +5,40 @@ class DummyDataSplit(DataSplit): - """A class for creating a simple train dataset and no validation dataset. - - It is derived from `DataSplit` class. - - ... - Attributes - ---------- - train : list - The list containing training datasets. In this class, it contains only one dataset for training. - validate : list - The list containing validation datasets. In this class, it is an empty list as no validation dataset is set. - - Methods - ------- - __init__(self, datasplit_config): - The constructor for DummyDataSplit class. It initialises a list with training datasets according to the input configuration. + """ + A class for creating a simple train dataset and no validation dataset. It is derived from `DataSplit` class. + It is used to split the data into training and validation datasets. The training and validation datasets are + used to train and validate the model respectively. + + Attributes: + train : list + The list containing training datasets. In this class, it contains only one dataset for training. + validate : list + The list containing validation datasets. In this class, it is an empty list as no validation dataset is set. + Methods: + __init__(self, datasplit_config): + The constructor for DummyDataSplit class. It initialises a list with training datasets according to the input configuration. + Notes: + This class is used to split the data into training and validation datasets. """ train: List[Dataset] validate: List[Dataset] def __init__(self, datasplit_config): - """Constructor method for initializing the instance of `DummyDataSplit` class. It sets up the list of training datasets based on the passed configuration. - - Parameters - ---------- - datasplit_config : DatasplitConfig - The configuration setup for processing the datasets into the training sets. + """ + Constructor method for initializing the instance of `DummyDataSplit` class. It sets up the list of training datasets based on the passed configuration. + + Args: + datasplit_config : obj + The configuration to initialize the DummyDataSplit class. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which is logged and handled by training the model without head matching. + Examples: + >>> dummy_datasplit = DummyDataSplit(datasplit_config) + Notes: + This function is called by the DummyDataSplit class to initialize the DummyDataSplit class with specified config to split the data into training and validation datasets. """ super().__init__() diff --git a/dacapo/experiments/datasplits/dummy_datasplit_config.py b/dacapo/experiments/datasplits/dummy_datasplit_config.py index d320df949..fc343909a 100644 --- a/dacapo/experiments/datasplits/dummy_datasplit_config.py +++ b/dacapo/experiments/datasplits/dummy_datasplit_config.py @@ -9,7 +9,8 @@ @attr.s class DummyDataSplitConfig(DataSplitConfig): - """A simple class representing config for Dummy DataSplit. + """ + A simple class representing config for Dummy DataSplit. This class is derived from 'DataSplitConfig' and is initialized with 'DatasetConfig' for training dataset. @@ -17,6 +18,12 @@ class DummyDataSplitConfig(DataSplitConfig): Attributes: datasplit_type: Class of dummy data split functionality. train_config: Config for the training dataset. Defaults to DummyDatasetConfig. + Methods: + verify() + A method for verification. This method always return 'False' plus + a string indicating the condition. + Notes: + This class is used to represent the configuration for Dummy DataSplit. """ @@ -25,10 +32,17 @@ class DummyDataSplitConfig(DataSplitConfig): train_config: DatasetConfig = attr.ib(DummyDatasetConfig(name="dummy_dataset")) def verify(self) -> Tuple[bool, str]: - """A method for verification. This method always return 'False' plus + """ + A method for verification. This method always return 'False' plus a string indicating the condition. Returns: Tuple[bool, str]: A tuple contains a boolean 'False' and a string. + Examples: + >>> dummy_datasplit_config = DummyDataSplitConfig(train_config) + >>> dummy_datasplit_config.verify() + (False, "This is a DummyDataSplit and is never valid") + Notes: + This method is used to verify the configuration of DummyDataSplit. """ return False, "This is a DummyDataSplit and is never valid" diff --git a/dacapo/experiments/datasplits/keys/keys.py b/dacapo/experiments/datasplits/keys/keys.py index 7da64dd78..531e43d49 100644 --- a/dacapo/experiments/datasplits/keys/keys.py +++ b/dacapo/experiments/datasplits/keys/keys.py @@ -2,7 +2,26 @@ class DataKey(Enum): - """Represent a base class for various types of keys in Dacapo library.""" + """ + Represent a base class for various types of keys in Dacapo library. + + Attributes: + RAW: str + The raw data key. + GT: str + The ground truth data key. + MASK: str + The data mask key. + NON_EMPTY: str + The data key for non-empty mask. + SPECIFIED_LOCATIONS: str + The key for specified locations in the graph. + Methods: + __str__(): + Return the string representation of the key. + Notes: + This class is a base class and should not be instantiated. + """ pass @@ -12,16 +31,20 @@ class ArrayKey(DataKey): """ A unique enumeration representing different types of array keys - Attributes - ---------- - RAW: str - The raw data key. - GT: str - The ground truth data key. - MASK: str - The data mask key. - NON_EMPTY: str - The data key for non-empty mask. + Attributes: + RAW: str + The raw data key. + GT: str + The ground truth data key. + MASK: str + The data mask key. + NON_EMPTY: str + The data key for non-empty mask. + Methods: + __str__(): + Return the string representation of the key. + Notes: + This class is a base class and should not be instantiated. """ RAW = "raw" @@ -35,10 +58,14 @@ class GraphKey(DataKey): """ A unique enumeration representing different types of graph keys - Attributes - ---------- - SPECIFIED_LOCATIONS: str - The key for specified locations in the graph. + Attributes: + SPECIFIED_LOCATIONS: str + The key for specified locations in the graph. + Methods: + __str__(): + Return the string representation of the key. + Notes: + This class is a base class and should not be instantiated. """ SPECIFIED_LOCATIONS = "specified_locations" diff --git a/dacapo/experiments/datasplits/train_validate_datasplit.py b/dacapo/experiments/datasplits/train_validate_datasplit.py index 3fdfe6c41..0b93663a3 100644 --- a/dacapo/experiments/datasplits/train_validate_datasplit.py +++ b/dacapo/experiments/datasplits/train_validate_datasplit.py @@ -5,10 +5,47 @@ class TrainValidateDataSplit(DataSplit): + """ + A DataSplit that contains a list of training and validation datasets. This + class is used to split the data into training and validation datasets. The + training and validation datasets are used to train and validate the model + respectively. + + Attributes: + train : list + The list of training datasets. + validate : list + The list of validation datasets. + Methods: + __init__(datasplit_config) + Initializes the TrainValidateDataSplit class with specified config to + split the data into training and validation datasets. + Notes: + This class is used to split the data into training and validation datasets. + """ + train: List[Dataset] validate: List[Dataset] def __init__(self, datasplit_config): + """ + Initializes the TrainValidateDataSplit class with specified config to + split the data into training and validation datasets. + + Args: + datasplit_config : obj + The configuration to initialize the TrainValidateDataSplit class. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> train_validate_datasplit = TrainValidateDataSplit(datasplit_config) + Notes: + This function is called by the TrainValidateDataSplit class to initialize + the TrainValidateDataSplit class with specified config to split the data + into training and validation datasets. + """ super().__init__() self.train = [ diff --git a/dacapo/experiments/datasplits/train_validate_datasplit_config.py b/dacapo/experiments/datasplits/train_validate_datasplit_config.py index 9970250a6..3cb7f9364 100644 --- a/dacapo/experiments/datasplits/train_validate_datasplit_config.py +++ b/dacapo/experiments/datasplits/train_validate_datasplit_config.py @@ -10,7 +10,23 @@ @attr.s class TrainValidateDataSplitConfig(DataSplitConfig): """ - This is the standard Train/Validate DataSplit config. + This is the standard Train/Validate DataSplit config. It contains a list of + training and validation datasets. This class is used to split the data into + training and validation datasets. The training and validation datasets are + used to train and validate the model respectively. + + Attributes: + train_configs : list + The list of training datasets. + validate_configs : list + The list of validation datasets. + Methods: + __init__(datasplit_config) + Initializes the TrainValidateDataSplitConfig class with specified config to + split the data into training and validation datasets. + Notes: + This class is used to split the data into training and validation datasets. + """ datasplit_type = TrainValidateDataSplit diff --git a/dacapo/experiments/model.py b/dacapo/experiments/model.py index 75777cd81..7df576cb3 100644 --- a/dacapo/experiments/model.py +++ b/dacapo/experiments/model.py @@ -8,13 +8,34 @@ class Model(torch.nn.Module): - """A trainable DaCapo model. Consists of an ``Architecture`` and a + """ + A trainable DaCapo model. Consists of an ``Architecture`` and a prediction head. Models are generated by ``Predictor``s. May include an optional eval_activation that is only executed when the model is in eval mode. This is particularly useful if you want to train with something like BCELossWithLogits, since you want to avoid applying softmax while training, but apply it during evaluation. + + Attributes: + architecture (Architecture): The architecture of the model. + prediction_head (torch.nn.Module): The prediction head of the model. + chain (torch.nn.Sequential): The architecture followed by the prediction head. + num_in_channels (int): The number of input channels. + input_shape (Coordinate): The shape of the input tensor. + eval_input_shape (Coordinate): The shape of the input tensor during evaluation. + num_out_channels (int): The number of output channels. + output_shape (Coordinate): The shape of the output + eval_activation (torch.nn.Module | None): The activation function to apply during evaluation. + Methods: + forward(x: torch.Tensor) -> torch.Tensor: + Forward pass of the model. + compute_output_shape(input_shape: Coordinate) -> Tuple[int, Coordinate]: + Compute the spatial shape of this model, when fed a tensor of the given spatial shape as input. + scale(voxel_size: Coordinate) -> Coordinate: + Scale the model by the given voxel size. + Note: + The output shape is the spatial shape of the model, i.e., not accounting for channels and batch dimensions. """ num_out_channels: int @@ -26,6 +47,42 @@ def __init__( prediction_head: torch.nn.Module, eval_activation: torch.nn.Module | None = None, ): + """ + Initializes a Model object. + + Args: + architecture (Architecture): The architecture of the model. + prediction_head (torch.nn.Module): The prediction head of the model. + eval_activation (torch.nn.Module | None): The activation function to apply during evaluation. + Raises: + AssertionError: If the architecture is not an instance of Architecture. + Examples: + >>> model = Model(architecture, prediction_head) + >>> model + Model object + >>> model.architecture + Architecture object + >>> model.prediction_head + Prediction head object + >>> model.chain + Sequential object + >>> model.num_in_channels + 1 + >>> model.input_shape + Coordinate(1, 1, 1) + >>> model.eval_input_shape + Coordinate(1, 1, 1) + >>> model.num_out_channels + 1 + >>> model.output_shape + Coordinate(1, 1, 1) + >>> model.eval_activation + None + Note: + The output shape is the spatial shape of the model, i.e., not accounting for channels and batch dimensions. Update the weight initialization to use Kaiming. + The eval_activation is only applied during evaluation. This is particularly useful if you want to train with something like BCELossWithLogits, since you want to avoid applying softmax while training, but apply it during evaluation. + To Do: Put this somewhere better, there might be conv layers that aren't follwed by relus. + """ super().__init__() self.architecture = architecture @@ -48,21 +105,71 @@ def __init__( torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu") def forward(self, x): + """ + Forward pass of the model. + + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor. + Examples: + >>> model = Model(architecture, prediction_head) + >>> model.forward(x) + torch.Tensor + Note: + The eval_activation is only applied during evaluation. This is particularly useful if you want to train with something like BCELossWithLogits, since you want to avoid applying softmax while training, but apply it during evaluation. + + """ result = self.chain(x) if not self.training and self.eval_activation is not None: result = self.eval_activation(result) return result def compute_output_shape(self, input_shape: Coordinate) -> Tuple[int, Coordinate]: - """Compute the spatial shape (i.e., not accounting for channels and + """ + Compute the spatial shape (i.e., not accounting for channels and batch dimensions) of this model, when fed a tensor of the given spatial - shape as input.""" + shape as input. + + Args: + input_shape (Coordinate): The shape of the input tensor. + Returns: + Tuple[int, Coordinate]: The number of output channels and the spatial shape of the output. + Raises: + AssertionError: If the input_shape is not a Coordinate. + Examples: + >>> model = Model(architecture, prediction_head) + >>> model.compute_output_shape(input_shape) + (1, Coordinate(1, 1, 1)) + Note: + The output shape is the spatial shape of the model, i.e., not accounting for channels and batch dimensions. + """ return self.__get_output_shape(input_shape, self.num_in_channels) def __get_output_shape( self, input_shape: Coordinate, in_channels: int ) -> Tuple[int, Coordinate]: + """ + Compute the spatial shape (i.e., not accounting for channels and + batch dimensions) of this model, when fed a tensor of the given spatial + shape as input. + + Args: + input_shape (Coordinate): The shape of the input tensor. + in_channels (int): The number of input channels. + Returns: + Tuple[int, Coordinate]: The number of output channels and the spatial shape of the output. + Raises: + AssertionError: If the input_shape is not a Coordinate. + Examples: + >>> model = Model(architecture, prediction_head) + >>> model.__get_output_shape(input_shape, in_channels) + (1, Coordinate(1, 1, 1)) + Note: + The output shape is the spatial shape of the model, i.e., not accounting for channels and batch dimensions. + + """ device = torch.device("cpu") for parameter in self.parameters(): device = parameter.device @@ -74,4 +181,20 @@ def __get_output_shape( return out.shape[1], Coordinate(out.shape[2:]) def scale(self, voxel_size: Coordinate) -> Coordinate: + """ + Scale the model by the given voxel size. + + Args: + voxel_size (Coordinate): The voxel size to scale the model by. + Returns: + Coordinate: The scaled model. + Raises: + AssertionError: If the voxel_size is not a Coordinate. + Examples: + >>> model = Model(architecture, prediction_head) + >>> model.scale(voxel_size) + Coordinate(1, 1, 1) + Note: + The output shape is the spatial shape of the model, i.e., not accounting for channels and batch dimensions. + """ return self.architecture.scale(voxel_size) diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index f425876e3..e7afb03ad 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -11,6 +11,36 @@ class Run: + """ + Class representing a run in the experiment. A run is a combination of a task, architecture, trainer, datasplit, + model, optimizer, training stats, and validation scores. It also contains the name of the run, the number of + iterations to train for, and the interval at which to validate. It also contains a start object that can be used to + initialize the model with preloaded weights. The run object can be used to move the optimizer to a specified device. + + Attributes: + name (str): The name of the run. + train_until (int): The number of iterations to train for. + validation_interval (int): The interval at which to validate. + task (Task): The task object. + architecture (Architecture): The architecture object. + trainer (Trainer): The trainer object. + datasplit (DataSplit): The datasplit object. + model (Model): The model object. + optimizer (torch.optim.Optimizer): The optimizer object. + training_stats (TrainingStats): The training stats object. + validation_scores (ValidationScores): The validation scores object. + start (Start): The start object. + Methods: + move_optimizer(device: torch.device, empty_cuda_cache: bool) -> None: + Moves the optimizer to the specified device. + get_validation_scores(run_config) -> ValidationScores: + Static method to get the validation scores without initializing model, optimizer, trainer, etc. + Note: + The iteration stats list is structured as follows: + - The outer list contains the stats for each iteration. + - The inner list contains the stats for each training iteration. + """ + name: str train_until: int validation_interval: int @@ -26,8 +56,44 @@ class Run: training_stats: TrainingStats validation_scores: ValidationScores - def __init__(self, run_config): + def __init__(self, run_config, load_starter_model: bool = True): + """ + Initializes a Run object. + + Args: + run_config: The configuration for the run. + Raises: + AssertionError: If the task, architecture, trainer, or datasplit types are not specified in the run_config. + Examples: + >>> run = Run(run_config) + >>> run.name + 'run_name' + >>> run.train_until + 100 + >>> run.validation_interval + 10 + >>> run.task + Task object + >>> run.architecture + Architecture object + >>> run.trainer + Trainer object + >>> run.datasplit + DataSplit object + >>> run.model + Model object + >>> run.optimizer + Optimizer object + >>> run.training_stats + TrainingStats object + >>> run.validation_scores + ValidationScores object + >>> run.start + Start object + + """ self.name = run_config.name + self._config = run_config self.train_until = run_config.num_iterations self.validation_interval = run_config.validation_interval @@ -41,7 +107,9 @@ def __init__(self, run_config): self.task = task_type(run_config.task_config) self.architecture = architecture_type(run_config.architecture_config) self.trainer = trainer_type(run_config.trainer_config) - self.datasplit = datasplit_type(run_config.datasplit_config) + + # lazy load datasplit + self._datasplit = None # combined pieces self.model = self.task.create_model(self.architecture) @@ -49,9 +117,11 @@ def __init__(self, run_config): # tracking self.training_stats = TrainingStats() - self.validation_scores = ValidationScores( - self.task.parameters, self.datasplit.validate, self.task.evaluation_scores - ) + self._validation_scores = None + + if not load_starter_model: + self.start = None + return # preloaded weights from previous run self.start = ( @@ -73,10 +143,40 @@ def __init__(self, run_config): self.start.initialize_weights(self.model, new_head=new_head) + @property + def datasplit(self): + if self._datasplit is None: + self._datasplit = self._config.datasplit_config.datasplit_type( + self._config.datasplit_config + ) + return self._datasplit + + @property + def validation_scores(self): + if self._validation_scores is None: + self._validation_scores = ValidationScores( + self.task.parameters, + self.datasplit.validate, + self.task.evaluation_scores, + ) + return self._validation_scores + @staticmethod def get_validation_scores(run_config) -> ValidationScores: """ - Static method to avoid having to initialize model, optimizer, trainer, etc. + Static method to get the validation scores without initializing model, optimizer, trainer, etc. + + Args: + run_config: The configuration for the run. + Returns: + The validation scores. + Raises: + AssertionError: If the task or datasplit types are not specified in the run_config. + Examples: + >>> validation_scores = Run.get_validation_scores(run_config) + >>> validation_scores + ValidationScores object + """ task_type = run_config.task_config.task_type datasplit_type = run_config.datasplit_config.datasplit_type @@ -91,6 +191,20 @@ def get_validation_scores(run_config) -> ValidationScores: def move_optimizer( self, device: torch.device, empty_cuda_cache: bool = False ) -> None: + """ + Moves the optimizer to the specified device. + + Args: + device: The device to move the optimizer to. + empty_cuda_cache: Whether to empty the CUDA cache after moving the optimizer. + Raises: + AssertionError: If the optimizer state is not a dictionary. + Examples: + >>> run.move_optimizer(device) + >>> run.optimizer + Optimizer object + + """ for state in self.optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): diff --git a/dacapo/experiments/starts/cosem_start.py b/dacapo/experiments/starts/cosem_start.py index fb943b45a..6e5c513d0 100644 --- a/dacapo/experiments/starts/cosem_start.py +++ b/dacapo/experiments/starts/cosem_start.py @@ -8,6 +8,31 @@ def get_model_setup(run): + """ + Loads the model setup from the dacapo store for the specified run. The + model setup includes the classes_channels, voxel_size_input and + voxel_size_output. + + Args: + run : str + The run for which the model setup is to be loaded. + Returns: + classes_channels : list + The classes_channels of the model. + voxel_size_input : list + The voxel_size_input of the model. + voxel_size_output : list + The voxel_size_output of the model. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> classes_channels, voxel_size_input, voxel_size_output = get_model_setup(run) + Notes: + This function is called by the CosemStart class to load the model setup + from the dacapo store for the specified run. + """ try: model = cosem.load_model(run) if hasattr(model, "classes_channels"): @@ -31,7 +56,56 @@ def get_model_setup(run): class CosemStart(Start): + """ + A class to represent the starting point for tasks. This class inherits + from the Start class and is used to load the weights of the starter model + used for finetuning. The weights are loaded from the dacapo store for the + specified run and criterion. + + Attributes: + run : str + The run to be used as a starting point for tasks. + criterion : str + The criterion to be used for choosing weights from run. + name : str + The name of the run and criterion. + channels : list + The classes_channels of the model. + Methods: + __init__(start_config) + Initializes the CosemStart class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + check() + Checks if the checkpoint for the specified run and criterion exists. + initialize_weights(model, new_head=None) + Retrieves the weights from the dacapo store and load them into + the model. + Notes: + This class is used to represent the starting point for tasks. The weights + of the starter model used for finetuning are loaded from the dacapo store. + """ + def __init__(self, start_config): + """ + Initializes the CosemStart class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + + Args: + start_config : obj + The configuration to initialize the CosemStart class. + Raises: + Exception + If the model setup cannot be loaded, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> start = CosemStart(start_config) + Notes: + This function is called by the CosemStart class to initialize the + CosemStart class with specified config to run the initialization of + weights for a model associated with a specific criterion. + """ self.run = start_config.run self.criterion = start_config.criterion self.name = f"{self.run}/{self.criterion}" @@ -43,6 +117,19 @@ def __init__(self, start_config): self.channels = channels def check(self): + """ + Checks if the checkpoint for the specified run and criterion exists. + + Raises: + Exception + If the checkpoint does not exist, an Exception is thrown which + is logged and handled by training the model without head matching. + Examples: + >>> check() + Notes: + This function is called by the CosemStart class to check if the + checkpoint for the specified run and criterion exists. + """ from dacapo.store.create_store import create_weights_store weights_store = create_weights_store() @@ -56,6 +143,29 @@ def check(self): logger.info(f"Checkpoint for {self.name} exists.") def initialize_weights(self, model, new_head=None): + """ + Retrieves the weights from the dacapo store and load them into + the model. + + Args: + model : obj + The model to which the weights are to be loaded. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights loaded from the dacapo store. + Raises: + RuntimeError + If weights of a non-existing or mismatched layer are being + loaded, a RuntimeError exception is thrown which is logged + and handled by loading only the common layers from weights. + Examples: + >>> model = initialize_weights(model, new_head) + Notes: + This function is called by the CosemStart class to retrieve the weights + from the dacapo store and load them into the model. + """ self.check() from dacapo.store.create_store import create_weights_store diff --git a/dacapo/experiments/starts/cosem_start_config.py b/dacapo/experiments/starts/cosem_start_config.py index de16477b1..6d841b121 100644 --- a/dacapo/experiments/starts/cosem_start_config.py +++ b/dacapo/experiments/starts/cosem_start_config.py @@ -5,9 +5,30 @@ @attr.s class CosemStartConfig(StartConfig): - """Starter for COSEM pretained models. This is a subclass of `StartConfig` and + """ + Starter for COSEM pretained models. This is a subclass of `StartConfig` and should be used to initialize the model with pretrained weights from a previous run. + + The weights are loaded from the dacapo store for the specified run. The + configuration is used to initialize the weights for the model associated with + a specific criterion. + + Attributes: + run : str + The run to be used as a starting point for tasks. + criterion : str + The criterion to be used for choosing weights from run. + Methods: + __init__(start_config) + Initializes the CosemStartConfig class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + Examples: + >>> start_config = CosemStartConfig(run="run_1", criterion="best") + Notes: + This class is used to represent the configuration for running tasks. + """ start_type = CosemStart diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index 8b667b56b..6afabbdcf 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -12,6 +12,37 @@ def match_heads(model, head_weights, old_head, new_head): + """ + Matches the head of the model to the new head by copying the weights + of the old head to the new head. The weights of the old head are + copied to the new head by matching the labels of the old head to the + labels of the new head. + + Args: + model : obj + The model to which the weights are to be loaded. + head_weights : dict + The weights of the old head. + old_head : list + The labels of the old head. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights of the old head copied to the new + head. + Raises: + RuntimeError + If the old head is not found in the new head, a RuntimeError + exception is thrown which is logged and handled by loading + only the common layers from weights. + Examples: + >>> model = match_heads(model, head_weights, old_head, new_head) + Notes: + This function is called by the Start class to match the head of + the model to the new head by copying the weights of the old head + to the new head. + """ for label in new_head: if label in old_head: logger.warning(f"matching head for {label}.") @@ -25,6 +56,45 @@ def match_heads(model, head_weights, old_head, new_head): def _set_weights(model, weights, run, criterion, old_head=None, new_head=None): + """ + Loads the weights of the model from the dacapo store into the model. If + the old head and new head are provided, the weights of the old head are + copied to the new head by matching the labels of the old head to the labels + of the new head. If the old head is not found in the new head, a RuntimeError + exception is thrown which is logged and handled by loading only the common + layers from weights. + + Args: + model : obj + The model to which the weights are to be loaded. + weights : obj + The weights of the model retrieved from the dacapo store. + run : str + The specified run to retrieve weights for the model. + criterion : str + The policy that was used to decide when to store the weights. + old_head : list + The labels of the old head. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights loaded from the dacapo store. + Raises: + RuntimeError + If weights of a non-existing or mismatched layer are being + loaded, a RuntimeError exception is thrown which is logged + and handled by loading only the common layers from weights. + Examples: + >>> model = _set_weights(model, weights, run, criterion, old_head, new_head) + Notes: + This function is called by the Start class to load the weights of the + model from the dacapo store into the model. If the old head and new head + are provided, the weights of the old head are copied to the new head by + matching the labels of the old head to the labels of the new head. If the + old head is not found in the new head, a RuntimeError exception is thrown + which is logged and handled by loading only the common layers from weights. + """ logger.warning( f"loading weights from run {run}, criterion: {criterion}, old_head {old_head}, new_head: {new_head}" ) @@ -79,12 +149,24 @@ class Start(ABC): This class interfaces with the dacapo store to retrieve and load the weights of the starter model used for finetuning. - Attributes - ---------- - run : str - The specified run to retrieve weights for the model. - criterion : str - The policy that was used to decide when to store the weights. + Attributes: + run : str + The specified run to retrieve weights for the model. + criterion : str + The policy that was used to decide when to store the weights. + channels : int + The number of channels in the input data. + Methods: + __init__(start_config) + Initializes the Start class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + initialize_weights(model, new_head=None) + Retrieves the weights from the dacapo store and load them into + the model. + Notes: + This class is used to retrieve and load the weights of the starter + model used for finetuning from the dacapo store. """ def __init__(self, start_config): @@ -93,14 +175,23 @@ def __init__(self, start_config): initialization of weights for a model associated with a specific criterion. - Parameters - ---------- - start_config : obj - An object containing configuration details for the model - initialization. + Args: + start_config : obj + The configuration to initialize the Start class. + Examples: + >>> start = Start(start_config) + Notes: + This function is called by the Start class to initialize the + Start class with specified config to run the initialization of + weights for a model associated with a specific criterion. """ - self.run = start_config.run - self.criterion = start_config.criterion + # Old version return a dict, new version return an object, this line is to support both + if isinstance(start_config, dict): + self.run = start_config["run"] + self.criterion = start_config["criterion"] + else: + self.run = start_config.run + self.criterion = start_config.criterion self.channels = None @@ -112,16 +203,25 @@ def initialize_weights(self, model, new_head=None): """ Retrieves the weights from the dacapo store and load them into the model. - Parameters - ---------- - model : obj - The model to which the weights are to be loaded. - Raises - ------ - RuntimeError - If weights of a non-existing or mismatched layer are being - loaded, a RuntimeError exception is thrown which is logged - and handled by loading only the common layers from weights. + + Args: + model : obj + The model to which the weights are to be loaded. + new_head : list + The labels of the new head. + Returns: + model : obj + The model with the weights loaded from the dacapo store. + Raises: + RuntimeError + If weights of a non-existing or mismatched layer are being + loaded, a RuntimeError exception is thrown which is logged + and handled by loading only the common layers from weights. + Examples: + >>> model = start.initialize_weights(model, new_head) + Notes: + This function is called by the Start class to retrieve the weights + from the dacapo store and load them into the model. """ from dacapo.store.create_store import create_weights_store diff --git a/dacapo/experiments/starts/start_config.py b/dacapo/experiments/starts/start_config.py index 60ae35ff9..0c961f250 100644 --- a/dacapo/experiments/starts/start_config.py +++ b/dacapo/experiments/starts/start_config.py @@ -5,16 +5,22 @@ @attr.s class StartConfig: """ - A class to represent the configuration for running tasks. - - Attributes - ---------- - run : str - The run to be used as a starting point for tasks. - - criterion : str - The criterion to be used for choosing weights from run. + A class to represent the configuration for running tasks. This class + interfaces with the dacapo store to retrieve and load the weights of the + starter model used for finetuning. + Attributes: + run : str + The run to be used as a starting point for tasks. + criterion : str + The criterion to be used for choosing weights from run. + Methods: + __init__(start_config) + Initializes the StartConfig class with specified config to run the + initialization of weights for a model associated with a specific + criterion. + Notes: + This class is used to represent the configuration for running tasks. """ start_type = Start diff --git a/dacapo/experiments/tasks/affinities_task.py b/dacapo/experiments/tasks/affinities_task.py index 08cbe7888..a355288d2 100644 --- a/dacapo/experiments/tasks/affinities_task.py +++ b/dacapo/experiments/tasks/affinities_task.py @@ -6,10 +6,36 @@ class AffinitiesTask(Task): - """This is a task for generating voxel affinities.""" + """ + This is a task for generating voxel affinities. It uses an `AffinitiesPredictor` for prediction, + an `AffinitiesLoss` for loss calculation, a `WatershedPostProcessor` for post-processing, and an + `InstanceEvaluator` for evaluation. + + Attributes: + predictor: AffinitiesPredictor object + loss: AffinitiesLoss object + post_processor: WatershedPostProcessor object + evaluator: InstanceEvaluator object + Methods: + __init__(self, task_config): Initializes all components for the affinities task. + Notes: + This is a subclass of Task. + + """ def __init__(self, task_config): - """Create a `DummyTask` from a `DummyTaskConfig`.""" + """ + Create a `DummyTask` from a `DummyTaskConfig`. + + Args: + task_config: The configuration for the task. + Returns: + A `DummyTask` object. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> task = AffinitiesTask(task_config) + """ self.predictor = AffinitiesPredictor( neighborhood=task_config.neighborhood, diff --git a/dacapo/experiments/tasks/affinities_task_config.py b/dacapo/experiments/tasks/affinities_task_config.py index 0bbb8f4bc..5e22f2a0d 100644 --- a/dacapo/experiments/tasks/affinities_task_config.py +++ b/dacapo/experiments/tasks/affinities_task_config.py @@ -10,8 +10,23 @@ @attr.s class AffinitiesTaskConfig(TaskConfig): - """This is a Affinities task config used for generating and + """ + This is a Affinities task config used for generating and evaluating voxel affinities for instance segmentations. + + Attributes: + neighborhood: A list of Coordinate objects. + lsds: Whether or not to train lsds along with your affinities. + lsds_to_affs_weight_ratio: If training with lsds, set how much they should be weighted compared to affs. + affs_weight_clipmin: The minimum value for affinities weights. + affs_weight_clipmax: The maximum value for affinities weights. + lsd_weight_clipmin: The minimum value for lsds weights. + lsd_weight_clipmax: The maximum value for lsds weights. + background_as_object: Whether to treat the background as a separate object. + Methods: + verify(self) -> Tuple[bool, str]: This method verifies the AffinitiesTaskConfig + Notes: + This is a subclass of TaskConfig. """ task_type = AffinitiesTask diff --git a/dacapo/experiments/tasks/distance_task.py b/dacapo/experiments/tasks/distance_task.py index e31976b37..a7c747cdd 100644 --- a/dacapo/experiments/tasks/distance_task.py +++ b/dacapo/experiments/tasks/distance_task.py @@ -19,6 +19,10 @@ class DistanceTask(Task): loss: MSELoss object post_processor: ThresholdPostProcessor object evaluator: BinarySegmentationEvaluator object + Methods: + __init__(self, task_config): Initializes attributes of DistanceTask + Notes: + This is a subclass of Task. """ def __init__(self, task_config): @@ -29,6 +33,10 @@ def __init__(self, task_config): Args: task_config: Object of task configuration + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> task = DistanceTask(task_config) """ self.predictor = DistancePredictor( diff --git a/dacapo/experiments/tasks/distance_task_config.py b/dacapo/experiments/tasks/distance_task_config.py index 9a2ab9cbe..bbd3585ec 100644 --- a/dacapo/experiments/tasks/distance_task_config.py +++ b/dacapo/experiments/tasks/distance_task_config.py @@ -16,6 +16,21 @@ class DistanceTaskConfig(TaskConfig): affinities is you can get a denser signal, i.e. 1 misclassified pixel in an affinity prediction could merge 2 otherwise very distinct objects, this cannot happen with distances. + + Attributes: + channels: A list of channel names. + clip_distance: Maximum distance to consider for false positive/negatives. + tol_distance: Tolerance distance for counting false positives/negatives + scale_factor: The amount by which to scale distances before applying a tanh normalization. + mask_distances: Whether or not to mask out regions where the true distance to + object boundary cannot be known. This is anywhere that the distance to crop boundary + is less than the distance to object boundary. + clipmin: The minimum value for distance weights. + clipmax: The maximum value for distance weights. + Methods: + verify(self) -> Tuple[bool, str]: This method verifies the DistanceTaskConfig object. + Notes: + This is a subclass of TaskConfig. """ task_type = DistanceTask diff --git a/dacapo/experiments/tasks/dummy_task.py b/dacapo/experiments/tasks/dummy_task.py index ebdb51206..bfaacffe1 100644 --- a/dacapo/experiments/tasks/dummy_task.py +++ b/dacapo/experiments/tasks/dummy_task.py @@ -11,26 +11,38 @@ class DummyTask(Task): post-processing, and evaluator) for the dummy task. Primarily used for testing purposes. Inherits from the Task class. - Attributes - ---------- - predictor : Object - Instance of DummyPredictor class. - loss : Object - Instance of DummyLoss class. - post_processor : Object - Instance of DummyPostProcessor class. - evaluator : Object - Instance of DummyEvaluator class. + Attributes: + predictor : Object + Instance of DummyPredictor class. + loss : Object + Instance of DummyLoss class. + post_processor : Object + Instance of DummyPostProcessor class. + evaluator : Object + Instance of DummyEvaluator class. + Methods: + __init__(self, task_config) + Initializes all components for the dummy task. + Notes: + This is a subclass of Task. """ def __init__(self, task_config): """ Initializes dummy task with predictor, loss function, post processor and evaluator. - Parameters - ---------- - task_config : Object - Configurations for the task, contains `embedding_dims` and `detection_threshold` + Parameters: + task_config : Object + Configurations for the task, contains `embedding_dims` and `detection_threshold` + Args: + task_config : TaskConfig + The configuration of the task. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> task = DummyTask(task_config) + Notes: + This is a base class for all tasks that use dummy components. """ self.predictor = DummyPredictor(task_config.embedding_dims) diff --git a/dacapo/experiments/tasks/dummy_task_config.py b/dacapo/experiments/tasks/dummy_task_config.py index 1cca4b31a..769bc1b6b 100644 --- a/dacapo/experiments/tasks/dummy_task_config.py +++ b/dacapo/experiments/tasks/dummy_task_config.py @@ -18,6 +18,10 @@ class DummyTaskConfig(TaskConfig): task_type (cls): The type of task. Here, set to DummyTask. embedding_dims (int): A dummy attribute represented as an integer. detection_threshold (float): Another dummy attribute represented as a float. + Methods: + verify(self) -> Tuple[bool, str]: This method verifies the DummyTaskConfig object. + Note: + This is a subclass of TaskConfig. """ @@ -35,5 +39,9 @@ def verify(self) -> Tuple[bool, str]: Returns: tuple: A tuple containing a boolean status and a string message. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> valid, reason = task_config.verify() """ return False, "This is a DummyTaskConfig and is never valid" diff --git a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py index a8eb68dce..6e75205fd 100644 --- a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluation_scores.py @@ -6,6 +6,60 @@ @attr.s class BinarySegmentationEvaluationScores(EvaluationScores): + """ + Class representing evaluation scores for binary segmentation tasks. + + The metrics include: + - Dice coefficient: 2 * |A ∩ B| / |A| + |B| ; where A and B are the binary segmentations + - Jaccard coefficient: |A ∩ B| / |A ∪ B| ; where A and B are the binary segmentations + - Hausdorff distance: max(h(A, B), h(B, A)) ; where h(A, B) is the Hausdorff distance between A and B + - False negative rate: |A - B| / |A| ; where A and B are the binary segmentations + - False positive rate: |B - A| / |B| ; where A and B are the binary segmentations + - False discovery rate: |B - A| / |A| ; where A and B are the binary segmentations + - VOI: Variation of Information; split and merge errors combined into a single measure of segmentation quality + - Mean false distance: 0.5 * (mean false positive distance + mean false negative distance) + - Mean false negative distance: mean distance of false negatives + - Mean false positive distance: mean distance of false positives + - Mean false distance clipped: 0.5 * (mean false positive distance clipped + mean false negative distance clipped) ; clipped to a maximum distance + - Mean false negative distance clipped: mean distance of false negatives clipped ; clipped to a maximum distance + - Mean false positive distance clipped: mean distance of false positives clipped ; clipped to a maximum distance + - Precision with tolerance: TP / (TP + FP) ; where TP and FP are the true and false positives within a tolerance distance + - Recall with tolerance: TP / (TP + FN) ; where TP and FN are the true and false positives within a tolerance distance + - F1 score with tolerance: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives within a tolerance distance + - Precision: TP / (TP + FP) ; where TP and FP are the true and false positives + - Recall: TP / (TP + FN) ; where TP and FN are the true and false positives + - F1 score: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives + + Attributes: + dice (float): The Dice coefficient. + jaccard (float): The Jaccard index. + hausdorff (float): The Hausdorff distance. + false_negative_rate (float): The false negative rate. + false_negative_rate_with_tolerance (float): The false negative rate with tolerance. + false_positive_rate (float): The false positive rate. + false_discovery_rate (float): The false discovery rate. + false_positive_rate_with_tolerance (float): The false positive rate with tolerance. + voi (float): The variation of information. + mean_false_distance (float): The mean false distance. + mean_false_negative_distance (float): The mean false negative distance. + mean_false_positive_distance (float): The mean false positive distance. + mean_false_distance_clipped (float): The mean false distance clipped. + mean_false_negative_distance_clipped (float): The mean false negative distance clipped. + mean_false_positive_distance_clipped (float): The mean false positive distance clipped. + precision_with_tolerance (float): The precision with tolerance. + recall_with_tolerance (float): The recall with tolerance. + f1_score_with_tolerance (float): The F1 score with tolerance. + precision (float): The precision. + recall (float): The recall. + f1_score (float): The F1 score. + Methods: + store_best(criterion: str) -> bool: Whether or not to store the best weights/validation blocks for this criterion. + higher_is_better(criterion: str) -> bool: Determines whether a higher value is better for a given criterion. + bounds(criterion: str) -> Tuple[Union[int, float, None], Union[int, float, None]]: Determines the bounds for a given criterion. + Notes: + The evaluation scores are stored as attributes of the class. The class also contains methods to determine whether a higher value is better for a given criterion, whether or not to store the best weights/validation blocks for a given criterion, and the bounds for a given criterion. + """ + dice: float = attr.ib(default=float("nan")) jaccard: float = attr.ib(default=float("nan")) hausdorff: float = attr.ib(default=float("nan")) @@ -54,6 +108,24 @@ class BinarySegmentationEvaluationScores(EvaluationScores): @staticmethod def store_best(criterion: str) -> bool: + """ + Determines whether or not to store the best weights/validation blocks for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if the best weights/validation blocks should be stored, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> BinarySegmentationEvaluationScores.store_best("dice") + False + >>> BinarySegmentationEvaluationScores.store_best("f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether or not to store the best weights/validation blocks for a given criterion is determined by the mapping dictionary. + + """ # Whether or not to store the best weights/validation blocks for this # criterion. mapping = { @@ -83,6 +155,23 @@ def store_best(criterion: str) -> bool: @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Determines whether a higher value is better for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if a higher value is better, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> BinarySegmentationEvaluationScores.higher_is_better("dice") + True + >>> BinarySegmentationEvaluationScores.higher_is_better("f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether a higher value is better for a given criterion is determined by the mapping dictionary. + """ mapping = { "dice": True, "jaccard": True, @@ -112,6 +201,23 @@ def higher_is_better(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Determines the bounds for a given criterion. The bounds are used to determine the best value for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + Tuple[Union[int, float, None], Union[int, float, None]]: The lower and upper bounds for the criterion. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> BinarySegmentationEvaluationScores.bounds("dice") + (0, 1) + >>> BinarySegmentationEvaluationScores.bounds("hausdorff") + (0, nan) + Notes: + The method returns the lower and upper bounds for the criterion. The bounds are determined by the mapping dictionary. + """ mapping = { "dice": (0, 1), "jaccard": (0, 1), @@ -140,15 +246,53 @@ def bounds( @attr.s class MultiChannelBinarySegmentationEvaluationScores(EvaluationScores): + """ + Class representing evaluation scores for multi-channel binary segmentation tasks. + + Attributes: + channel_scores (List[Tuple[str, BinarySegmentationEvaluationScores]]): The list of channel scores. + Methods: + higher_is_better(criterion: str) -> bool: Determines whether a higher value is better for a given criterion. + store_best(criterion: str) -> bool: Whether or not to store the best weights/validation blocks for this criterion. + bounds(criterion: str) -> Tuple[Union[int, float, None], Union[int, float, None]]: Determines the bounds for a given criterion. + Notes: + The evaluation scores are stored as attributes of the class. The class also contains methods to determine whether a higher value is better for a given criterion, whether or not to store the best weights/validation blocks for a given criterion, and the bounds for a given criterion. + """ + channel_scores: List[Tuple[str, BinarySegmentationEvaluationScores]] = attr.ib() def __attrs_post_init__(self): + """ + Post-initialization method to set attributes for each channel and criterion. + + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> channel_scores = [("channel1", BinarySegmentationEvaluationScores()), ("channel2", BinarySegmentationEvaluationScores())] + >>> MultiChannelBinarySegmentationEvaluationScores(channel_scores) + Notes: + The method sets attributes for each channel and criterion. The attributes are stored as attributes of the class. + """ for channel, scores in self.channel_scores: for criteria in BinarySegmentationEvaluationScores.criteria: setattr(self, f"{channel}__{criteria}", getattr(scores, criteria)) @property def criteria(self): + """ + Returns a list of all criteria for all channels. + + Returns: + List[str]: The list of criteria. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> channel_scores = [("channel1", BinarySegmentationEvaluationScores()), ("channel2", BinarySegmentationEvaluationScores())] + >>> MultiChannelBinarySegmentationEvaluationScores(channel_scores).criteria + Notes: + The method returns a list of all criteria for all channels. The criteria are stored as attributes of the class. + """ + return [ f"{channel}__{criteria}" for channel, _ in self.channel_scores @@ -157,11 +301,45 @@ def criteria(self): @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Determines whether a higher value is better for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if a higher value is better, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> MultiChannelBinarySegmentationEvaluationScores.higher_is_better("channel1__dice") + True + >>> MultiChannelBinarySegmentationEvaluationScores.higher_is_better("channel1__f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether a higher value is better for a given criterion is determined by the mapping dictionary. + """ _, criterion = criterion.split("__") return BinarySegmentationEvaluationScores.higher_is_better(criterion) @staticmethod def store_best(criterion: str) -> bool: + """ + Determines whether or not to store the best weights/validation blocks for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + bool: True if the best weights/validation blocks should be stored, False otherwise. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> MultiChannelBinarySegmentationEvaluationScores.store_best("channel1__dice") + False + >>> MultiChannelBinarySegmentationEvaluationScores.store_best("channel1__f1_score") + True + Notes: + The method returns True if the criterion is recognized and False otherwise. Whether or not to store the best weights/validation blocks for a given criterion is determined by the mapping dictionary. + """ _, criterion = criterion.split("__") return BinarySegmentationEvaluationScores.store_best(criterion) @@ -169,5 +347,22 @@ def store_best(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Determines the bounds for a given criterion. The bounds are used to determine the best value for a given criterion. + + Args: + criterion (str): The evaluation criterion. + Returns: + Tuple[Union[int, float, None], Union[int, float, None]]: The lower and upper bounds for the criterion. + Raises: + ValueError: If the criterion is not recognized. + Examples: + >>> MultiChannelBinarySegmentationEvaluationScores.bounds("channel1__dice") + (0, 1) + >>> MultiChannelBinarySegmentationEvaluationScores.bounds("channel1__hausdorff") + (0, nan) + Notes: + The method returns the lower and upper bounds for the criterion. The bounds are determined by the mapping dictionary. + """ _, criterion = criterion.split("__") return BinarySegmentationEvaluationScores.bounds(criterion) diff --git a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py index 542083c4d..d6ade542e 100644 --- a/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/binary_segmentation_evaluator.py @@ -23,12 +23,78 @@ class BinarySegmentationEvaluator(Evaluator): """ - Given a binary segmentation, compute various metrics to determine their similarity. + Given a binary segmentation, compute various metrics to determine their similarity. The metrics include: + - Dice coefficient: 2 * |A ∩ B| / |A| + |B| ; where A and B are the binary segmentations + - Jaccard coefficient: |A ∩ B| / |A ∪ B| ; where A and B are the binary segmentations + - Hausdorff distance: max(h(A, B), h(B, A)) ; where h(A, B) is the Hausdorff distance between A and B + - False negative rate: |A - B| / |A| ; where A and B are the binary segmentations + - False positive rate: |B - A| / |B| ; where A and B are the binary segmentations + - False discovery rate: |B - A| / |A| ; where A and B are the binary segmentations + - VOI: Variation of Information; split and merge errors combined into a single measure of segmentation quality + - Mean false distance: 0.5 * (mean false positive distance + mean false negative distance) + - Mean false negative distance: mean distance of false negatives + - Mean false positive distance: mean distance of false positives + - Mean false distance clipped: 0.5 * (mean false positive distance clipped + mean false negative distance clipped) ; clipped to a maximum distance + - Mean false negative distance clipped: mean distance of false negatives clipped ; clipped to a maximum distance + - Mean false positive distance clipped: mean distance of false positives clipped ; clipped to a maximum distance + - Precision with tolerance: TP / (TP + FP) ; where TP and FP are the true and false positives within a tolerance distance + - Recall with tolerance: TP / (TP + FN) ; where TP and FN are the true and false positives within a tolerance distance + - F1 score with tolerance: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives within a tolerance distance + - Precision: TP / (TP + FP) ; where TP and FP are the true and false positives + - Recall: TP / (TP + FN) ; where TP and FN are the true and false positives + - F1 score: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives + + Attributes: + clip_distance : float + the clip distance + tol_distance : float + the tolerance distance + channels : List[str] + the channels + criteria : List[str] + the evaluation criteria + Methods: + evaluate(output_array_identifier, evaluation_array) + Evaluate the output array against the evaluation array. + score + Return the evaluation scores. + Note: + The BinarySegmentationEvaluator class is used to evaluate the performance of a binary segmentation task. + The class provides methods to evaluate the output array against the evaluation array and return the evaluation scores. + All evaluation scores should inherit from this class. + + Clip distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a false positive. + Tolerance distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a true positive. + Channels are the channels of the binary segmentation. + Criteria are the evaluation criteria. + """ criteria = ["jaccard", "voi"] def __init__(self, clip_distance: float, tol_distance: float, channels: List[str]): + """ + Initialize the binary segmentation evaluator. + + Args: + clip_distance : float + the clip distance + tol_distance : float + the tolerance distance + channels : List[str] + the channels + Raises: + ValueError: if the clip distance is not valid + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + Note: + This function is used to initialize the binary segmentation evaluator. + + Clip distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a false positive. + Tolerance distance is the maximum distance between the ground truth and the predicted segmentation for a pixel to be considered a true positive. + Channels are the channels of the binary segmentation. + Criteria are the evaluation criteria. + """ self.clip_distance = clip_distance self.tol_distance = tol_distance self.channels = channels @@ -38,6 +104,28 @@ def __init__(self, clip_distance: float, tol_distance: float, channels: List[str ] def evaluate(self, output_array_identifier, evaluation_array): + """ + Evaluate the output array against the evaluation array. + + Args: + output_array_identifier : str + the identifier of the output array + evaluation_array : ZarrArray + the evaluation array + Returns: + BinarySegmentationEvaluationScores or MultiChannelBinarySegmentationEvaluationScores + the evaluation scores + Raises: + ValueError: if the output array identifier is not valid + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + >>> output_array_identifier = "output_array" + >>> evaluation_array = ZarrArray.open_from_array_identifier("evaluation_array") + >>> binary_segmentation_evaluator.evaluate(output_array_identifier, evaluation_array) + BinarySegmentationEvaluationScores(dice=0.0, jaccard=0.0, hausdorff=0.0, false_negative_rate=0.0, false_positive_rate=0.0, false_discovery_rate=0.0, voi=0.0, mean_false_distance=0.0, mean_false_negative_distance=0.0, mean_false_positive_distance=0.0, mean_false_distance_clipped=0.0, mean_false_negative_distance_clipped=0.0, mean_false_positive_distance_clipped=0.0, precision_with_tolerance=0.0, recall_with_tolerance=0.0, f1_score_with_tolerance=0.0, precision=0.0, recall=0.0, f1_score=0.0) + Note: + This function is used to evaluate the output array against the evaluation array. + """ output_array = ZarrArray.open_from_array_identifier(output_array_identifier) evaluation_data = evaluation_array[evaluation_array.roi].squeeze() output_data = output_array[output_array.roi].squeeze() @@ -135,12 +223,50 @@ def evaluate(self, output_array_identifier, evaluation_array): @property def score(self): + """ + Return the evaluation scores. + + Returns: + BinarySegmentationEvaluationScores or MultiChannelBinarySegmentationEvaluationScores + the evaluation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + >>> binary_segmentation_evaluator.score + BinarySegmentationEvaluationScores(dice=0.0, jaccard=0.0, hausdorff=0.0, false_negative_rate=0.0, false_positive_rate=0.0, false_discovery_rate=0.0, voi=0.0, mean_false_distance=0.0, mean_false_negative_distance=0.0, mean_false_positive_distance=0.0, mean_false_distance_clipped=0.0, mean_false_negative_distance_clipped=0.0, mean_false_positive_distance_clipped=0.0, precision_with_tolerance=0.0, recall_with_tolerance=0.0, f1_score_with_tolerance=0.0, precision=0.0, recall=0.0, f1_score=0.0) + Note: + This function is used to return the evaluation scores. + """ channel_scores = [] for channel in self.channels: channel_scores.append((channel, BinarySegmentationEvaluationScores())) return MultiChannelBinarySegmentationEvaluationScores(channel_scores) def _evaluate(self, output_data, evaluation_data, voxel_size): + """ + Evaluate the output array against the evaluation array. + + Args: + output_data : np.ndarray + the output data + evaluation_data : np.ndarray + the evaluation data + voxel_size : Tuple[float, float, float] + the voxel size + Returns: + BinarySegmentationEvaluationScores + the evaluation scores + Examples: + >>> binary_segmentation_evaluator = BinarySegmentationEvaluator(clip_distance=200, tol_distance=40, channels=["channel1", "channel2"]) + >>> output_data = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> evaluation_data = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> voxel_size = (1, 1, 1) + >>> binary_segmentation_evaluator._evaluate(output_data, evaluation_data, voxel_size) + BinarySegmentationEvaluationScores(dice=0.0, jaccard=0.0, hausdorff=0.0, false_negative_rate=0.0, false_positive_rate=0.0, false_discovery_rate=0.0, voi=0.0, mean_false_distance=0.0, mean_false_negative_distance=0.0, mean_false_positive_distance=0.0, mean_false_distance_clipped=0.0, mean_false_negative_distance_clipped=0.0, mean_false_positive_distance_clipped=0.0, precision_with_tolerance=0.0, recall_with_tolerance=0.0, f1_score_with_tolerance=0.0, precision=0.0, recall=0.0, f1_score=0.0) + Note: + This function is used to evaluate the output array against the evaluation array. + """ evaluator = ArrayEvaluator( evaluation_data, output_data, @@ -178,6 +304,90 @@ def _evaluate(self, output_data, evaluation_data, voxel_size): class ArrayEvaluator: + """ + Given a binary segmentation, compute various metrics to determine their similarity. The metrics include: + - Dice coefficient: 2 * |A ∩ B| / |A| + |B| ; where A and B are the binary segmentations + - Jaccard coefficient: |A ∩ B| / |A ∪ B| ; where A and B are the binary segmentations + - Hausdorff distance: max(h(A, B), h(B, A)) ; where h(A, B) is the Hausdorff distance between A and B + - False negative rate: |A - B| / |A| ; where A and B are the binary segmentations + - False positive rate: |B - A| / |B| ; where A and B are the binary segmentations + - False discovery rate: |B - A| / |A| ; where A and B are the binary segmentations + - VOI: Variation of Information; split and merge errors combined into a single measure of segmentation quality + - Mean false distance: 0.5 * (mean false positive distance + mean false negative distance) + - Mean false negative distance: mean distance of false negatives + - Mean false positive distance: mean distance of false positives + - Mean false distance clipped: 0.5 * (mean false positive distance clipped + mean false negative distance clipped) ; clipped to a maximum distance + - Mean false negative distance clipped: mean distance of false negatives clipped ; clipped to a maximum distance + - Mean false positive distance clipped: mean distance of false positives clipped ; clipped to a maximum distance + - Precision with tolerance: TP / (TP + FP) ; where TP and FP are the true and false positives within a tolerance distance + - Recall with tolerance: TP / (TP + FN) ; where TP and FN are the true and false positives within a tolerance distance + - F1 score with tolerance: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives within a tolerance distance + - Precision: TP / (TP + FP) ; where TP and FP are the true and false positives + - Recall: TP / (TP + FN) ; where TP and FN are the true and false positives + - F1 score: 2 * (Recall * Precision) / (Recall + Precision) ; where Recall and Precision are the true and false positives + + Attributes: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + truth_empty : bool + whether the truth binary segmentation is empty + test_empty : bool + whether the test binary segmentation is empty + cremieval : CremiEvaluator + the cremi evaluator + resolution : Tuple[float, float, float] + the resolution + Methods: + dice + Return the Dice coefficient. + jaccard + Return the Jaccard coefficient. + hausdorff + Return the Hausdorff distance. + false_negative_rate + Return the false negative rate. + false_positive_rate + Return the false positive rate. + false_discovery_rate + Return the false discovery rate. + precision + Return the precision. + recall + Return the recall. + f1_score + Return the F1 score. + voi + Return the VOI. + mean_false_distance + Return the mean false distance. + mean_false_negative_distance + Return the mean false negative distance. + mean_false_positive_distance + Return the mean false positive distance. + mean_false_distance_clipped + Return the mean false distance clipped. + mean_false_negative_distance_clipped + Return the mean false negative distance clipped. + mean_false_positive_distance_clipped + Return the mean false positive distance clipped. + false_positive_rate_with_tolerance + Return the false positive rate with tolerance. + false_negative_rate_with_tolerance + Return the false negative rate with tolerance. + precision_with_tolerance + Return the precision with tolerance. + recall_with_tolerance + Return the recall with tolerance. + f1_score_with_tolerance + Return the F1 score with tolerance. + Note: + The ArrayEvaluator class is used to evaluate the performance of a binary segmentation task. + The class provides methods to evaluate the truth binary segmentation against the test binary segmentation. + All evaluation scores should inherit from this class. + """ + def __init__( self, truth_binary, @@ -187,6 +397,38 @@ def __init__( metric_params, resolution, ): + """ + Initialize the array evaluator. + + Args: + truth_binary : np.ndarray + the truth binary segmentation + test_binary : np.ndarray + the test binary segmentation + truth_empty : bool + whether the truth binary segmentation is empty + test_empty : bool + whether the test binary segmentation is empty + metric_params : Dict[str, float] + the metric parameters + resolution : Tuple[float, float, float] + the resolution + Returns: + ArrayEvaluator + the array evaluator + Raises: + ValueError: if the truth binary segmentation is not valid + Examples: + >>> truth_binary = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> test_binary = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> truth_empty = False + >>> test_empty = False + >>> metric_params = {"clip_distance": 200, "tol_distance": 40} + >>> resolution = (1, 1, 1) + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + Note: + This function is used to initialize the array evaluator. + """ self.truth = truth_binary.astype(np.uint8) self.test = test_binary.astype(np.uint8) self.truth_empty = truth_empty @@ -202,35 +444,148 @@ def __init__( @lazy_property.LazyProperty def truth_itk(self): + """ + A SimpleITK image of the truth binary segmentation. + + Returns: + sitk.Image + the truth binary segmentation as a SimpleITK image + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.truth_itk + ::value_type *' at 0x7f8b1c0b3f30> > + Note: + This function is used to return the truth binary segmentation as a SimpleITK image. + """ res = sitk.GetImageFromArray(self.truth) res.SetSpacing(self.resolution) return res @lazy_property.LazyProperty def test_itk(self): + """ + A SimpleITK image of the test binary segmentation. + + Args: + test : np.ndarray + the test binary segmentation + resolution : Tuple[float, float, float] + the resolution + Returns: + sitk.Image + the test binary segmentation as a SimpleITK image + Raises: + ValueError: if the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.test_itk + ::value_type *' at 0x7f8b1c0b3f30> > + Note: + This function is used to return the test binary segmentation as a SimpleITK image. + """ res = sitk.GetImageFromArray(self.test) res.SetSpacing(self.resolution) return res @lazy_property.LazyProperty def overlap_measures_filter(self): + """ + A SimpleITK filter to compute overlap measures. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + sitk.LabelOverlapMeasuresImageFilter + the overlap measures filter + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.overlap_measures_filter + > + Note: + This function is used to return the overlap measures filter. + """ overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter() overlap_measures_filter.Execute(self.test_itk, self.truth_itk) return overlap_measures_filter def dice(self): + """ + The Dice coefficient. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + float + the Dice coefficient + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.dice() + 0.0 + Note: + This function is used to return the Dice coefficient. + """ if (not self.truth_empty) or (not self.test_empty): return self.overlap_measures_filter.GetDiceCoefficient() else: return np.nan def jaccard(self): + """ + The Jaccard coefficient. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + float + the Jaccard coefficient + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.jaccard() + 0.0 + Note: + This function is used to return the Jaccard coefficient. + + """ if (not self.truth_empty) or (not self.test_empty): return self.overlap_measures_filter.GetJaccardCoefficient() else: return np.nan def hausdorff(self): + """ + The Hausdorff distance. + + Args: + None + Returns: + float: the Hausdorff distance + Raises: + None + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.hausdorff() + 0.0 + Note: + This function is used to return the Hausdorff distance between the truth binary segmentation and the test binary segmentation. + + If either the truth or test binary segmentation is empty, the function returns 0. + Otherwise, it calculates the Hausdorff distance using the HausdorffDistanceImageFilter from the SimpleITK library. + """ if self.truth_empty and self.test_empty: return 0 elif not self.truth_empty and not self.test_empty: @@ -241,12 +596,47 @@ def hausdorff(self): return np.nan def false_negative_rate(self): + """ + The false negative rate. + + Returns: + float + the false negative rate + Returns: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.false_negative_rate() + 0.0 + Note: + This function is used to return the false negative rate. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.overlap_measures_filter.GetFalseNegativeError() def false_positive_rate(self): + """ + The false positive rate. + + Args: + truth_itk : sitk.Image + the truth binary segmentation as a SimpleITK image + test_itk : sitk.Image + the test binary segmentation as a SimpleITK image + Returns: + float + the false positive rate + Raises: + ValueError: if the truth binary segmentation or the test binary segmentation is not valid + Examples: + >>> array_evaluator = ArrayEvaluator(truth_binary, test_binary, truth_empty, test_empty, metric_params, resolution) + >>> array_evaluator.false_positive_rate() + 0.0 + Note: + This function is used to return the false positive rate. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -255,12 +645,45 @@ def false_positive_rate(self): ) def false_discovery_rate(self): + """ + Calculate the false discovery rate (FDR) for the binary segmentation evaluation. + + Returns: + float: The false discovery rate. + Raises: + None + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_discovery_rate() + 0.25 + Note: + The false discovery rate is a measure of the proportion of false positives among the predicted positive samples. + It is calculated as the ratio of false positives to the sum of true positives and false positives. + If either the ground truth or the predicted segmentation is empty, the FDR is set to NaN. + """ if (not self.truth_empty) or (not self.test_empty): return self.overlap_measures_filter.GetFalsePositiveError() else: return np.nan def precision(self): + """ + Calculate the precision of the binary segmentation evaluation. + + Returns: + float: The precision value. + Raises: + None. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.precision() + 0.75 + Note: + Precision is a measure of the accuracy of the positive predictions made by the model. + It is calculated as the ratio of true positives to the total number of positive predictions. + If either the ground truth or the predicted values are empty, the precision value will be NaN. + """ + if self.truth_empty or self.test_empty: return np.nan else: @@ -269,6 +692,21 @@ def precision(self): return float(np.float32(tp) / np.float32(pred_pos)) def recall(self): + """ + Calculate the recall metric for binary segmentation evaluation. + + Returns: + float: The recall value. + Raises: + None + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall() + 0.75 + Note: + Recall is a measure of the ability of a binary classifier to identify all positive samples. + It is calculated as the ratio of true positives to the total number of actual positives. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -277,6 +715,23 @@ def recall(self): return float(np.float32(tp) / np.float32(cond_pos)) def f1_score(self): + """ + Calculate the F1 score for binary segmentation evaluation. + + Returns: + float: The F1 score value. + Raises: + None. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.f1_score() + 0.75 + Note: + The F1 score is the harmonic mean of precision and recall. + It is a measure of the balance between precision and recall, providing a single metric to evaluate the model's performance. + + If either the ground truth or the predicted values are empty, the F1 score will be NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -288,6 +743,25 @@ def f1_score(self): return 2 * (rec * prec) / (rec + prec) def voi(self): + """ + Calculate the Variation of Information (VOI) for binary segmentation evaluation. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The VOI value. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.voi() + 0.75 + Note: + The VOI is a measure of the similarity between two segmentations. + It combines the split and merge errors into a single measure of segmentation quality. + If either the ground truth or the predicted values are empty, the VOI will be NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -297,66 +771,273 @@ def voi(self): return voi_split + voi_merge def mean_false_distance(self): + """ + Calculate the mean false distance between the ground truth and the test results. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance() + 0.25 + Note: + - This method returns np.nan if either the ground truth or the test results are empty. + - The mean false distance is a measure of the average distance between the false positive pixels in the test results and the nearest true positive pixels in the ground truth. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_distance def mean_false_negative_distance(self): + """ + Calculate the mean false negative distance between the ground truth and the test results. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false negative distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_negative_distance() + 0.25 + Note: + This method returns np.nan if either the ground truth or the test results are empty. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_negative_distance def mean_false_positive_distance(self): + """ + Calculate the mean false positive distance. + + This method calculates the mean false positive distance between the ground truth and the test results. + If either the ground truth or the test results are empty, the method returns NaN. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false positive distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_positive_distance() + 0.5 + Note: + The mean false positive distance is a measure of the average distance between false positive pixels in the + test results and the corresponding ground truth pixels. It is commonly used to evaluate the performance of + binary segmentation algorithms. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_positive_distance def mean_false_distance_clipped(self): + """ + Calculate the mean false distance (clipped) between the ground truth and the test results. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false distance (clipped) value. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance_clipped() + 0.123 + Note: + This method returns np.nan if either the ground truth or the test results are empty. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_distance_clipped def mean_false_negative_distance_clipped(self): + """ + Calculate the mean false negative distance, with clipping. + + This method calculates the mean false negative distance between the ground truth and the test results. + The distance is clipped to avoid extreme values. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false negative distance with clipping. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_negative_distance_clipped() + 0.123 + Note: + - The mean false negative distance is a measure of the average distance between the false negative pixels in the ground truth and the test results. + - Clipping the distance helps to avoid extreme values that may skew the overall evaluation. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_negative_distances_clipped def mean_false_positive_distance_clipped(self): + """ + Calculate the mean false positive distance, with clipping. + + This method calculates the mean false positive distance between the ground truth and the test results, + taking into account any clipping that may have been applied. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The mean false positive distance with clipping. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_positive_distance_clipped() + 0.25 + Note: + - The mean false positive distance is a measure of the average distance between false positive pixels + in the test results and the corresponding ground truth pixels. + - If either the ground truth or the test results are empty, the method returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.mean_false_positive_distances_clipped def false_positive_rate_with_tolerance(self): + """ + Calculate the false positive rate with tolerance. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The false positive rate with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_rate_with_tolerance() + 0.25 + Note: + This method calculates the false positive rate with tolerance by comparing the truth and test data. + If either the truth or test data is empty, it returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.false_positive_rate_with_tolerance def false_negative_rate_with_tolerance(self): + """ + Calculate the false negative rate with tolerance. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + The false negative rate with tolerance as a floating-point number. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_rate_with_tolerance() + 0.25 + Note: + This method calculates the false negative rate with tolerance, which is a measure of the proportion of false negatives in a binary segmentation evaluation. If either the ground truth or the test data is empty, the method returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.false_negative_rate_with_tolerance def precision_with_tolerance(self): + """ + Calculate the precision with tolerance. + + This method calculates the precision with tolerance by comparing the truth and test data. + Precision is the ratio of true positives to the sum of true positives and false positives. + Tolerance is a distance threshold within which two pixels are considered to be a match. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The precision with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.precision_with_tolerance() + 0.75 + Note: + - Precision is a measure of the accuracy of the positive predictions. + - If either the ground truth or the test data is empty, the method returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.precision_with_tolerance def recall_with_tolerance(self): + """ + Calculate the recall with tolerance for the binary segmentation evaluator. + + Returns: + float: The recall with tolerance value. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall_with_tolerance() + 0.75 + Note: + This method calculates the recall with tolerance, which is a measure of how well the binary segmentation evaluator performs. It returns the recall with tolerance value as a float. If either the truth or test data is empty, it returns NaN. + """ if self.truth_empty or self.test_empty: return np.nan else: return self.cremieval.recall_with_tolerance def f1_score_with_tolerance(self): + """ + Calculate the F1 score with tolerance. + + This method calculates the F1 score with tolerance between the ground truth and the test results. + If either the ground truth or the test results are empty, the function returns NaN. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + Returns: + float: The F1 score with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.f1_score_with_tolerance() + 0.85 + Note: + The F1 score is a measure of a test's accuracy. It considers both the precision and recall of the test to compute the score. + The tolerance parameter allows for a certain degree of variation between the ground truth and the test results. + """ if self.truth_empty or self.test_empty: return np.nan else: @@ -364,9 +1045,95 @@ def f1_score_with_tolerance(self): class CremiEvaluator: + """ + Evaluate the performance of a binary segmentation task using the CREMI score. + The CREMI score is a measure of the similarity between two binary segmentations. + + Attributes: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + sampling : Tuple[float, float, float] + the sampling resolution + clip_distance : float + the maximum distance to clip + tol_distance : float + the tolerance distance + Methods: + false_positive_distances + Return the false positive distances. + false_positives_with_tolerance + Return the false positives with tolerance. + false_positive_rate_with_tolerance + Return the false positive rate with tolerance. + false_negatives_with_tolerance + Return the false negatives with tolerance. + false_negative_rate_with_tolerance + Return the false negative rate with tolerance. + true_positives_with_tolerance + Return the true positives with tolerance. + precision_with_tolerance + Return the precision with tolerance. + recall_with_tolerance + Return the recall with tolerance. + f1_score_with_tolerance + Return the F1 score with tolerance. + mean_false_positive_distances_clipped + Return the mean false positive distances clipped. + mean_false_negative_distances_clipped + Return the mean false negative distances clipped. + mean_false_positive_distance + Return the mean false positive distance. + false_negative_distances + Return the false negative distances. + mean_false_negative_distance + Return the mean false negative distance. + mean_false_distance + Return the mean false distance. + mean_false_distance_clipped + Return the mean false distance clipped. + Note: + - The CremiEvaluator class is used to evaluate the performance of a binary segmentation task using the CREMI score. + - True and test binary segmentations are compared to calculate various evaluation metrics. + - The class provides methods to evaluate the performance of the binary segmentation task. + - Toleration distance is used to determine the tolerance level for the evaluation. + - Clip distance is used to clip the distance values to avoid extreme values. + - All evaluation scores should inherit from this class. + """ + def __init__( self, truth, test, sampling=(1, 1, 1), clip_distance=200, tol_distance=40 ): + """ + Initialize the Cremi evaluator. + + Args: + truth : np.ndarray + the truth binary segmentation + test : np.ndarray + the test binary segmentation + sampling : Tuple[float, float, float] + the sampling resolution + clip_distance : float + the maximum distance to clip + tol_distance : float + the tolerance distance + Returns: + CremiEvaluator + the Cremi evaluator + Raises: + ValueError: if the truth binary segmentation is not valid + Examples: + >>> truth = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> test = np.array([[[0, 1], [1, 0]], [[0, 1], [1, 0]]]) + >>> sampling = (1, 1, 1) + >>> clip_distance = 200 + >>> tol_distance = 40 + >>> cremi_evaluator = CremiEvaluator(truth, test, sampling, clip_distance, tol_distance) + Note: + This function is used to initialize the Cremi evaluator. + """ self.test = test self.truth = truth self.sampling = sampling @@ -375,37 +1142,176 @@ def __init__( @lazy_property.LazyProperty def test_mask(self): + """ + Generate a binary mask for the test data. + + Args: + test : np.ndarray + the test binary segmentation + Returns: + test_mask (ndarray): A binary mask indicating the regions of interest in the test data. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.test = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]]) + >>> evaluator.test_mask() + array([[False, True, False], + [ True, True, True], + [False, True, False]]) + Note: + This method assumes that the background class is represented by the constant `BG`. + """ # todo: more involved masking test_mask = self.test == BG return test_mask @lazy_property.LazyProperty def truth_mask(self): + """ + Returns a binary mask indicating the truth values. + + Args: + truth : np.ndarray + the truth binary segmentation + Returns: + truth_mask (ndarray): A binary mask where True indicates the truth values and False indicates other values. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> mask = evaluator.truth_mask() + >>> print(mask) + [[ True True False] + [False True False] + [ True False False]] + Note: + The truth mask is computed by comparing the truth values with a predefined background value (BG). + """ truth_mask = self.truth == BG return truth_mask @lazy_property.LazyProperty def test_edt(self): + """ + Calculate the Euclidean Distance Transform (EDT) of the test mask. + + Args: + self.test_mask (ndarray): The binary test mask. + self.sampling (float or sequence of floats): The pixel spacing or sampling along each dimension. + Returns: + ndarray: The Euclidean Distance Transform of the test mask. + Examples: + # Example 1: + test_mask = np.array([[0, 0, 1], + [1, 1, 1], + [0, 0, 0]]) + sampling = 1.0 + result = test_edt(test_mask, sampling) + # Output: array([[1. , 1. , 0. ], + # [0. , 0. , 0. ], + # [1. , 1. , 1.41421356]]) + + # Example 2: + test_mask = np.array([[0, 1, 0], + [1, 0, 1], + [0, 1, 0]]) + sampling = 0.5 + result = test_edt(test_mask, sampling) + # Output: array([[0.5 , 0. , 0.5 ], + # [0. , 0.70710678, 0. ], + # [0.5 , 0. , 0.5 ]]) + + Note: + The Euclidean Distance Transform (EDT) calculates the distance from each pixel in the binary mask to the nearest boundary pixel. It is commonly used in image processing and computer vision tasks, such as edge detection and shape analysis. + """ test_edt = scipy.ndimage.distance_transform_edt(self.test_mask, self.sampling) return test_edt @lazy_property.LazyProperty def truth_edt(self): + """ + Calculate the Euclidean Distance Transform (EDT) of the ground truth mask. + + Args: + self.truth_mask (ndarray): The binary ground truth mask. + self.sampling (float or sequence of floats): The pixel spacing or sampling along each dimension. + Returns: + ndarray: The Euclidean Distance Transform of the ground truth mask. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> edt = evaluator.truth_edt() + Note: + The Euclidean Distance Transform (EDT) calculates the distance from each pixel in the binary mask to the nearest boundary pixel. It is commonly used in image processing and computer vision tasks. + """ truth_edt = scipy.ndimage.distance_transform_edt(self.truth_mask, self.sampling) return truth_edt @lazy_property.LazyProperty def false_positive_distances(self): + """ + Calculate the distances of false positive pixels from the ground truth segmentation. + + Args: + self.test_mask (ndarray): The binary test mask. + self.truth_edt (ndarray): The Euclidean Distance Transform of the ground truth segmentation. + Returns: + numpy.ndarray: An array containing the distances of false positive pixels from the ground truth segmentation. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> distances = evaluator.false_positive_distances() + >>> print(distances) + [1.2, 0.8, 2.5, 1.0] + Note: + This method assumes that the ground truth segmentation and the test mask have been initialized. + The ground truth segmentation is stored in the `truth_edt` attribute, and the test mask is obtained by inverting the `test_mask` attribute. + """ test_bin = np.invert(self.test_mask) false_positive_distances = self.truth_edt[test_bin] return false_positive_distances @lazy_property.LazyProperty def false_positives_with_tolerance(self): + """ + Calculate the number of false positives with a given tolerance distance. + + Args: + self.false_positive_distances (ndarray): The distances of false positive pixels from the ground truth segmentation. + self.tol_distance (float): The tolerance distance. + Returns: + int: The number of false positives with a distance greater than the tolerance distance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_distances = [1, 2, 3] + >>> evaluator.tol_distance = 2 + >>> false_positives = evaluator.false_positives_with_tolerance() + >>> print(false_positives) + 1 + Note: + The `false_positive_distances` attribute should be initialized before calling this method. + + """ return np.sum(self.false_positive_distances > self.tol_distance) @lazy_property.LazyProperty def false_positive_rate_with_tolerance(self): + """ + Calculate the false positive rate with tolerance. + + This method calculates the false positive rate by dividing the number of false positives with tolerance + by the number of condition negatives. + + Args: + self.false_positives_with_tolerance (int): The number of false positives with tolerance. + self.truth_mask (ndarray): The binary ground truth mask. + Returns: + float: The false positive rate with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positives_with_tolerance = 10 + >>> evaluator.truth_mask = np.array([0, 1, 0, 1, 0]) + >>> evaluator.false_positive_rate_with_tolerance() + 0.5 + Note: + The false positive rate with tolerance is a measure of the proportion of false positive predictions + with respect to the total number of condition negatives. It is commonly used in binary segmentation tasks. + """ condition_negative = np.sum(self.truth_mask) return float( np.float32(self.false_positives_with_tolerance) @@ -414,10 +1320,51 @@ def false_positive_rate_with_tolerance(self): @lazy_property.LazyProperty def false_negatives_with_tolerance(self): + """ + Calculate the number of false negatives with tolerance. + + Args: + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth segmentation. + self.tol_distance (float): The tolerance distance. + Returns: + int: The number of false negatives with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1, 2, 3] + >>> evaluator.tol_distance = 2 + >>> false_negatives = evaluator.false_negatives_with_tolerance() + >>> print(false_negatives) + 1 + Note: + False negatives are cases where the model incorrectly predicts the absence of a positive class. + The tolerance distance is used to determine whether a false negative is within an acceptable range. + + """ return np.sum(self.false_negative_distances > self.tol_distance) @lazy_property.LazyProperty def false_negative_rate_with_tolerance(self): + """ + Calculate the false negative rate with tolerance. + + This method calculates the false negative rate by dividing the number of false negatives + with tolerance by the number of condition positives. + + Args: + self.false_negatives_with_tolerance (int): The number of false negatives with tolerance. + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth segmentation. + Returns: + float: The false negative rate with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1, 2, 3] + >>> evaluator.false_negatives_with_tolerance = 2 + >>> evaluator.false_negative_rate_with_tolerance() + 0.6666666666666666 + Note: + The false negative rate with tolerance is a measure of the proportion of condition positives + that are incorrectly classified as negatives, considering a certain tolerance level. + """ condition_positive = len(self.false_negative_distances) return float( np.float32(self.false_negatives_with_tolerance) @@ -426,6 +1373,29 @@ def false_negative_rate_with_tolerance(self): @lazy_property.LazyProperty def true_positives_with_tolerance(self): + """ + Calculate the number of true positives with tolerance. + + Args: + self.test_mask (ndarray): The test binary segmentation mask. + self.truth_mask (ndarray): The ground truth binary segmentation mask. + self.false_negatives_with_tolerance (int): The number of false negatives with tolerance. + self.false_positives_with_tolerance (int): The number of false positives with tolerance. + Returns: + int: The number of true positives with tolerance. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.test_mask = np.array([[0, 1], [1, 0]]) + >>> evaluator.truth_mask = np.array([[0, 1], [1, 0]]) + >>> evaluator.false_negatives_with_tolerance = 1 + >>> evaluator.false_positives_with_tolerance = 1 + >>> true_positives = evaluator.true_positives_with_tolerance() + >>> print(true_positives) + 2 + Note: + True positives are cases where the model correctly predicts the presence of a positive class. + The tolerance distance is used to determine whether a true positive is within an acceptable range. + """ all_pos = np.sum(np.invert(self.test_mask & self.truth_mask)) return ( all_pos @@ -435,6 +1405,31 @@ def true_positives_with_tolerance(self): @lazy_property.LazyProperty def precision_with_tolerance(self): + """ + Calculate the precision with tolerance. + + This method calculates the precision with tolerance by dividing the number of true positives + with tolerance by the sum of true positives with tolerance and false positives with tolerance. + + Args: + self.true_positives_with_tolerance (int): The number of true positives with tolerance. + self.false_positives_with_tolerance (int): The number of false positives with tolerance. + Returns: + float: The precision with tolerance. + Raises: + ZeroDivisionError: If the sum of true positives with tolerance and false positives with tolerance is zero. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.true_positives_with_tolerance = 10 + >>> evaluator.false_positives_with_tolerance = 5 + >>> evaluator.precision_with_tolerance() + 0.6666666666666666 + Note: + The precision with tolerance is a measure of the proportion of true positives with tolerance + out of the total number of predicted positives with tolerance. + It indicates how well the binary segmentation evaluator performs in terms of correctly identifying positive samples. + If the sum of true positives with tolerance and false positives with tolerance is zero, the precision with tolerance is undefined and a ZeroDivisionError is raised. + """ return float( np.float32(self.true_positives_with_tolerance) / np.float32( @@ -444,6 +1439,23 @@ def precision_with_tolerance(self): @lazy_property.LazyProperty def recall_with_tolerance(self): + """ + A measure of the ability of a binary classifier to identify all positive samples. + + Args: + self.true_positives_with_tolerance (int): The number of true positives with tolerance. + self.false_negatives_with_tolerance (int): The number of false negatives with tolerance. + Returns: + float: The recall with tolerance value. + Raises: + ZeroDivisionError: If the sum of true positives with tolerance and false negatives with tolerance is zero. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall_with_tolerance() + 0.75 + Note: + This method calculates the recall with tolerance, which is a measure of how well the binary segmentation evaluator performs. It returns the recall with tolerance value as a float. If either the truth or test data is empty, it returns NaN. + """ return float( np.float32(self.true_positives_with_tolerance) / np.float32( @@ -453,6 +1465,28 @@ def recall_with_tolerance(self): @lazy_property.LazyProperty def f1_score_with_tolerance(self): + """ + Calculate the F1 score with tolerance. + + Args: + self.recall_with_tolerance (float): The recall with tolerance value. + self.precision_with_tolerance (float): The precision with tolerance value. + Returns: + float: The F1 score with tolerance. + Raises: + ZeroDivisionError: If both the recall with tolerance and precision with tolerance are zero. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.recall_with_tolerance = 0.8 + >>> evaluator.precision_with_tolerance = 0.9 + >>> evaluator.f1_score_with_tolerance() + 0.8571428571428571 + Note: + The F1 score is a measure of a test's accuracy. It considers both the precision and recall of the test to compute the score. + The F1 score with tolerance is calculated using the formula: + F1 = 2 * (recall_with_tolerance * precision_with_tolerance) / (recall_with_tolerance + precision_with_tolerance) + If both recall_with_tolerance and precision_with_tolerance are 0, the F1 score with tolerance will be NaN. + """ if self.recall_with_tolerance == 0 and self.precision_with_tolerance == 0: return np.nan else: @@ -464,6 +1498,26 @@ def f1_score_with_tolerance(self): @lazy_property.LazyProperty def mean_false_positive_distances_clipped(self): + """ + Calculate the mean of the false positive distances, clipped to a maximum distance. + + Args: + self.false_positive_distances (ndarray): The distances of false positive pixels from the ground truth segmentation. + self.clip_distance (float): The maximum distance to clip. + Returns: + float: The mean of the false positive distances, clipped to a maximum distance. + Raises: + ValueError: If the clip distance is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_distances = [1, 2, 3, 4, 5] + >>> evaluator.clip_distance = 3 + >>> evaluator.mean_false_positive_distances_clipped() + 2.5 + Note: + + This method calculates the mean of the false positive distances, where the distances are clipped to a maximum distance. The `false_positive_distances` attribute should be set before calling this method. The `clip_distance` attribute determines the maximum distance to which the distances are clipped. + """ mean_false_positive_distance_clipped = np.mean( np.clip(self.false_positive_distances, None, self.clip_distance) ) @@ -471,6 +1525,25 @@ def mean_false_positive_distances_clipped(self): @lazy_property.LazyProperty def mean_false_negative_distances_clipped(self): + """ + Calculate the mean of the false negative distances, clipped to a maximum distance. + + Args: + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth segmentation. + self.clip_distance (float): The maximum distance to clip. + Returns: + float: The mean of the false negative distances, clipped to a maximum distance. + Raises: + ValueError: If the clip distance is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1, 2, 3, 4, 5] + >>> evaluator.clip_distance = 3 + >>> evaluator.mean_false_negative_distances_clipped() + 2.5 + Note: + This method calculates the mean of the false negative distances, where the distances are clipped to a maximum distance. The `false_negative_distances` attribute should be set before calling this method. The `clip_distance` attribute determines the maximum distance to which the distances are clipped. + """ mean_false_negative_distance_clipped = np.mean( np.clip(self.false_negative_distances, None, self.clip_distance) ) @@ -478,22 +1551,98 @@ def mean_false_negative_distances_clipped(self): @lazy_property.LazyProperty def mean_false_positive_distance(self): + """ + Calculate the mean false positive distance. + + This method calculates the mean distance between the false positive points and the ground truth points. + + Args: + self.false_positive_distances (ndarray): The distances of false positive pixels from the ground truth mask. + Returns: + float: The mean false positive distance. + Raises: + ValueError: If the false positive distances are not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_positive_distances = [1.2, 3.4, 2.1] + >>> evaluator.mean_false_positive_distance() + 2.2333333333333334 + Note: + The false positive distances should be set before calling this method using the `false_positive_distances` attribute. + """ mean_false_positive_distance = np.mean(self.false_positive_distances) return mean_false_positive_distance @lazy_property.LazyProperty def false_negative_distances(self): + """ + Calculate the distances of false negative pixels from the ground truth mask. + + Args: + self.truth_mask (ndarray): The binary ground truth mask. + Returns: + numpy.ndarray: An array containing the distances of false negative pixels. + Raises: + ValueError: If the truth mask is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> distances = evaluator.false_negative_distances() + >>> print(distances) + [0.5, 1.0, 1.5, 2.0] + Note: + This method assumes that the ground truth mask and the test mask have already been set. + """ truth_bin = np.invert(self.truth_mask) false_negative_distances = self.test_edt[truth_bin] return false_negative_distances @lazy_property.LazyProperty def mean_false_negative_distance(self): + """ + Calculate the mean false negative distance. + + Args: + self.false_negative_distances (ndarray): The distances of false negative pixels from the ground truth mask. + Returns: + float: The mean false negative distance. + Raises: + ValueError: If the false negative distances are not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.false_negative_distances = [1.2, 3.4, 2.1] + >>> evaluator.mean_false_negative_distance() + 2.2333333333333334 + Note: + The mean false negative distance is calculated as the average of all false negative distances. + + """ mean_false_negative_distance = np.mean(self.false_negative_distances) return mean_false_negative_distance @lazy_property.LazyProperty def mean_false_distance(self): + """ + Calculate the mean false distance. + + This method calculates the mean false distance by taking the average of the mean false positive distance + and the mean false negative distance. + + Args: + self.mean_false_positive_distance (float): The mean false positive distance. + self.mean_false_negative_distance (float): The mean false negative distance. + Returns: + float: The calculated mean false distance. + Raises: + ValueError: If the mean false positive distance or the mean false negative distance is not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance() + 5.0 + Note: + The mean false distance is a metric used to evaluate the performance of a binary segmentation model. + It provides a measure of the average distance between false positive and false negative predictions. + + """ mean_false_distance = 0.5 * ( self.mean_false_positive_distance + self.mean_false_negative_distance ) @@ -501,6 +1650,28 @@ def mean_false_distance(self): @lazy_property.LazyProperty def mean_false_distance_clipped(self): + """ + Calculates the mean false distance clipped. + + This method calculates the mean false distance clipped by taking the average of the mean false positive distances + clipped and the mean false negative distances clipped. + + Args: + self.mean_false_positive_distances_clipped (float): The mean false positive distances clipped. + self.mean_false_negative_distances_clipped (float): The mean false negative distances clipped. + Returns: + float: The calculated mean false distance clipped. + Raises: + ValueError: If the mean false positive distances clipped or the mean false negative distances clipped are not set. + Examples: + >>> evaluator = BinarySegmentationEvaluator() + >>> evaluator.mean_false_distance_clipped() + 2.5 + Note: + The mean false distance clipped is calculated as 0.5 * (mean_false_positive_distances_clipped + + mean_false_negative_distances_clipped). + + """ mean_false_distance_clipped = 0.5 * ( self.mean_false_positive_distances_clipped + self.mean_false_negative_distances_clipped diff --git a/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py index eb7879cbc..ffca2fc74 100644 --- a/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/dummy_evaluation_scores.py @@ -6,6 +6,25 @@ @attr.s class DummyEvaluationScores(EvaluationScores): + """ + The evaluation scores for the dummy task. The scores include the frizz level and blipp score. A higher frizz level indicates more frizz, while a higher blipp score indicates better performance. + + Attributes: + frizz_level : float + the frizz level + blipp_score : float + the blipp score + Methods: + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The DummyEvaluationScores class is used to store the evaluation scores for the dummy task. The class also provides methods to determine whether higher is better for a given criterion, the bounds for a given criterion, and whether to store the best score for a given criterion. + """ + criteria = ["frizz_level", "blipp_score"] frizz_level: float = attr.ib(default=float("nan")) @@ -13,6 +32,23 @@ class DummyEvaluationScores(EvaluationScores): @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Return whether higher is better for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether higher is better for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> DummyEvaluationScores.higher_is_better("frizz_level") + True + Note: + This function is used to determine whether higher is better for the given criterion. + """ mapping = { "frizz_level": True, "blipp_score": False, @@ -23,6 +59,23 @@ def higher_is_better(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Return the bounds for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> DummyEvaluationScores.bounds("frizz_level") + (0.0, 1.0) + Note: + This function is used to return the bounds for the given criterion. + """ mapping = { "frizz_level": (0.0, 1.0), "blipp_score": (0.0, 1.0), @@ -31,4 +84,21 @@ def bounds( @staticmethod def store_best(criterion: str) -> bool: + """ + Return whether to store the best score for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether to store the best score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> DummyEvaluationScores.store_best("frizz_level") + True + Note: + This function is used to determine whether to store the best score for the given criterion. + """ return True diff --git a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py index f9a4dc1ea..db2d68ac5 100644 --- a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py @@ -5,6 +5,21 @@ class DummyEvaluator(Evaluator): + """ + A class representing a dummy evaluator. This evaluator is used for testing purposes. + + Attributes: + criteria : List[str] + the evaluation criteria + Methods: + evaluate(output_array_identifier, evaluation_dataset) + Evaluate the output array against the evaluation dataset. + score + Return the evaluation scores. + Note: + The DummyEvaluator class is used to evaluate the performance of a dummy task. + """ + criteria = ["frizz_level", "blipp_score"] def evaluate(self, output_array_identifier, evaluation_dataset): @@ -14,9 +29,18 @@ def evaluate(self, output_array_identifier, evaluation_dataset): Args: output_array_identifier : The output array to be evaluated. evaluation_dataset : The dataset to be used for evaluation. - Returns: DummyEvaluationScore: An object of DummyEvaluationScores class, with the evaluation scores. + Raises: + ValueError: if the output array identifier is not valid + Examples: + >>> dummy_evaluator = DummyEvaluator() + >>> output_array_identifier = "output_array" + >>> evaluation_dataset = "evaluation_dataset" + >>> dummy_evaluator.evaluate(output_array_identifier, evaluation_dataset) + DummyEvaluationScores(frizz_level=0.0, blipp_score=0.0) + Note: + This function is used to evaluate the output array against the evaluation dataset. """ return DummyEvaluationScores( frizz_level=random.random(), blipp_score=random.random() @@ -24,4 +48,16 @@ def evaluate(self, output_array_identifier, evaluation_dataset): @property def score(self) -> DummyEvaluationScores: + """ + Return the evaluation scores. + + Returns: + DummyEvaluationScores: An object of DummyEvaluationScores class, with the evaluation scores. + Examples: + >>> dummy_evaluator = DummyEvaluator() + >>> dummy_evaluator.score + DummyEvaluationScores(frizz_level=0.0, blipp_score=0.0) + Note: + This function is used to return the evaluation scores. + """ return DummyEvaluationScores() diff --git a/dacapo/experiments/tasks/evaluators/evaluation_scores.py b/dacapo/experiments/tasks/evaluators/evaluation_scores.py index 3733b9133..7cf341709 100644 --- a/dacapo/experiments/tasks/evaluators/evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/evaluation_scores.py @@ -6,11 +6,44 @@ @attr.s class EvaluationScores: - """Base class for evaluation scores.""" + """ + Base class for evaluation scores. This class is used to store the evaluation scores for a task. + The scores include the evaluation criteria. The class also provides methods to determine whether higher is better for a given criterion, + the bounds for a given criterion, and whether to store the best score for a given criterion. + + Attributes: + criteria : List[str] + the evaluation criteria + Methods: + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The EvaluationScores class is used to store the evaluation scores for a task. All evaluation scores should inherit from this class. + + """ @property @abstractmethod def criteria(self) -> List[str]: + """ + The evaluation criteria. + + Returns: + List[str] + the evaluation criteria + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> evaluation_scores.criteria + ["criterion1", "criterion2"] + Note: + This function is used to return the evaluation criteria. + """ pass @staticmethod @@ -18,6 +51,23 @@ def criteria(self) -> List[str]: def higher_is_better(criterion: str) -> bool: """ Wether or not higher is better for this criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether higher is better for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> criterion = "criterion1" + >>> evaluation_scores.higher_is_better(criterion) + True + Note: + This function is used to determine whether higher is better for a given criterion. + """ pass @@ -27,7 +77,24 @@ def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: """ - The bounds for this criterion + The bounds for this criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> criterion = "criterion1" + >>> evaluation_scores.bounds(criterion) + (0, 1) + Note: + This function is used to return the bounds for the given criterion. + """ pass @@ -37,5 +104,21 @@ def store_best(criterion: str) -> bool: """ Whether or not to save the best validation block and model weights for this criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether to store the best score for this criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluation_scores = EvaluationScores() + >>> criterion = "criterion1" + >>> evaluation_scores.store_best(criterion) + True + Note: + This function is used to return whether to store the best score for the given criterion. """ pass diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index 83e4763b3..beccc57c5 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -21,10 +21,38 @@ class Evaluator(ABC): - """Base class of all evaluators. + """ + Base class of all evaluators: An abstract class representing an evaluator that compares and evaluates the output array against the evaluation array. An evaluator takes a post-processor's output and compares it against - ground-truth. + ground-truth. It then returns a set of scores that can be used to + determine the quality of the post-processor's output. + + Attributes: + best_scores : Dict[OutputIdentifier, BestScore] + the best scores for each dataset/post-processing parameter/criterion combination + Methods: + evaluate(output_array_identifier, evaluation_array) + Compare and evaluate the output array against the evaluation array. + is_best(dataset, parameter, criterion, score) + Check if the provided score is the best for this dataset/parameter/criterion combo. + get_overall_best(dataset, criterion) + Return the best score for the given dataset and criterion. + get_overall_best_parameters(dataset, criterion) + Return the best parameters for the given dataset and criterion. + compare(score_1, score_2, criterion) + Compare two scores for the given criterion. + set_best(validation_scores) + Find the best iteration for each dataset/post_processing_parameter/criterion. + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The Evaluator class is used to compare and evaluate the output array against the evaluation array. + """ @abstractmethod @@ -34,17 +62,24 @@ def evaluate( """ Compares and evaluates the output array against the evaluation array. - Parameters - ---------- - output_array_identifier : Array - The output data array to evaluate - evaluation_array : Array - The evaluation data array to compare with the output - - Returns - ------- - EvaluationScores - The detailed evaluation scores after the comparison. + Args: + output_array_identifier : LocalArrayIdentifier + The identifier of the output array. + evaluation_array : Array + The evaluation array. + Returns: + EvaluationScores + The evaluation scores. + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> output_array_identifier = LocalArrayIdentifier("output_array") + >>> evaluation_array = Array() + >>> evaluator.evaluate(output_array_identifier, evaluation_array) + EvaluationScores() + Note: + This function is used to compare and evaluate the output array against the evaluation array. """ pass @@ -52,6 +87,21 @@ def evaluate( def best_scores( self, ) -> Dict[OutputIdentifier, BestScore]: + """ + The best scores for each dataset/post-processing parameter/criterion combination. + + Returns: + Dict[OutputIdentifier, BestScore] + the best scores for each dataset/post-processing parameter/criterion combination + Raises: + AttributeError: if the best scores are not set + Examples: + >>> evaluator = Evaluator() + >>> evaluator.best_scores + {} + Note: + This function is used to return the best scores for each dataset/post-processing parameter/criterion combination. + """ if not hasattr(self, "_best_scores"): self._best_scores: Dict[OutputIdentifier, BestScore] = {} return self._best_scores @@ -64,7 +114,32 @@ def is_best( score: "EvaluationScores", ) -> bool: """ - Check if the provided score is the best for this dataset/parameter/criterion combo + Check if the provided score is the best for this dataset/parameter/criterion combo. + + Args: + dataset : Dataset + the dataset + parameter : PostProcessorParameters + the post-processor parameters + criterion : str + the criterion + score : EvaluationScores + the evaluation scores + Returns: + bool + whether the provided score is the best for this dataset/parameter/criterion combo + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> dataset = Dataset() + >>> parameter = PostProcessorParameters() + >>> criterion = "criterion" + >>> score = EvaluationScores() + >>> evaluator.is_best(dataset, parameter, criterion, score) + False + Note: + This function is used to check if the provided score is the best for this dataset/parameter/criterion combo. """ if not self.store_best(criterion) or math.isnan(getattr(score, criterion)): return False @@ -78,6 +153,28 @@ def is_best( ) def get_overall_best(self, dataset: "Dataset", criterion: str): + """ + Return the best score for the given dataset and criterion. + + Args: + dataset : Dataset + the dataset + criterion : str + the criterion + Returns: + Optional[float] + the best score for the given dataset and criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> dataset = Dataset() + >>> criterion = "criterion" + >>> evaluator.get_overall_best(dataset, criterion) + None + Note: + This function is used to return the best score for the given dataset and criterion. + """ overall_best = None if self.best_scores: for _, parameter, _ in self.best_scores.keys(): @@ -99,6 +196,28 @@ def get_overall_best(self, dataset: "Dataset", criterion: str): return overall_best def get_overall_best_parameters(self, dataset: "Dataset", criterion: str): + """ + Return the best parameters for the given dataset and criterion. + + Args: + dataset : Dataset + the dataset + criterion : str + the criterion + Returns: + Optional[PostProcessorParameters] + the best parameters for the given dataset and criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> dataset = Dataset() + >>> criterion = "criterion" + >>> evaluator.get_overall_best_parameters(dataset, criterion) + None + Note: + This function is used to return the best parameters for the given dataset and criterion. + """ overall_best = None overall_best_parameters = None if self.best_scores: @@ -121,6 +240,31 @@ def get_overall_best_parameters(self, dataset: "Dataset", criterion: str): return overall_best_parameters def compare(self, score_1, score_2, criterion): + """ + Compare two scores for the given criterion. + + Args: + score_1 : float + the first score + score_2 : float + the second score + criterion : str + the criterion + Returns: + bool + whether the first score is better than the second score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> score_1 = 0.0 + >>> score_2 = 0.0 + >>> criterion = "criterion" + >>> evaluator.compare(score_1, score_2, criterion) + False + Note: + This function is used to compare two scores for the given criterion. + """ if self.higher_is_better(criterion): return score_1 > score_2 else: @@ -128,7 +272,21 @@ def compare(self, score_1, score_2, criterion): def set_best(self, validation_scores: "ValidationScores") -> None: """ - Find the best iteration for each dataset/post_processing_parameter/criterion + Find the best iteration for each dataset/post_processing_parameter/criterion. + + Args: + validation_scores : ValidationScores + the validation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> validation_scores = ValidationScores() + >>> evaluator.set_best(validation_scores) + None + Note: + This function is used to find the best iteration for each dataset/post_processing_parameter/criterion. + Typically, this function is called after the validation scores have been computed. """ scores = validation_scores.to_xarray() @@ -185,12 +343,40 @@ def criteria(self) -> List[str]: criteria might be "precision", "recall", and "jaccard". It is unlikely that the best iteration/post processing parameters will be the same for all 3 of these criteria + + Returns: + List[str] + the evaluation criteria + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> evaluator.criteria + [] + Note: + This function is used to return the evaluation criteria. """ pass def higher_is_better(self, criterion: str) -> bool: """ Wether or not higher is better for this criterion. + + Args: + criterion : str + the criterion + Returns: + bool + whether higher is better for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> criterion = "criterion" + >>> evaluator.higher_is_better(criterion) + False + Note: + This function is used to determine whether higher is better for the given criterion. """ return self.score.higher_is_better(criterion) @@ -199,16 +385,63 @@ def bounds( ) -> Tuple[Union[int, float, None], Union[int, float, None]]: """ The bounds for this criterion + + Args: + criterion : str + the criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> criterion = "criterion" + >>> evaluator.bounds(criterion) + (0, 1) + Note: + This function is used to return the bounds for the given criterion. """ return self.score.bounds(criterion) def store_best(self, criterion: str) -> bool: """ The bounds for this criterion + + Args: + criterion : str + the criterion + Returns: + bool + whether to store the best score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> criterion = "criterion" + >>> evaluator.store_best(criterion) + False + Note: + This function is used to return whether to store the best score for the given criterion. """ return self.score.store_best(criterion) @property @abstractmethod def score(self) -> "EvaluationScores": + """ + The evaluation scores. + + Returns: + EvaluationScores + the evaluation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> evaluator = Evaluator() + >>> evaluator.score + EvaluationScores() + Note: + This function is used to return the evaluation scores. + """ pass diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py index 4e4df9cca..9e3e54e69 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluation_scores.py @@ -6,6 +6,27 @@ @attr.s class InstanceEvaluationScores(EvaluationScores): + """ + The evaluation scores for the instance segmentation task. The scores include the variation of information (VOI) split, VOI merge, and VOI. + + Attributes: + voi_split : float + the variation of information (VOI) split + voi_merge : float + the variation of information (VOI) merge + voi : float + the variation of information (VOI) + Methods: + higher_is_better(criterion) + Return whether higher is better for the given criterion. + bounds(criterion) + Return the bounds for the given criterion. + store_best(criterion) + Return whether to store the best score for the given criterion. + Note: + The InstanceEvaluationScores class is used to store the evaluation scores for the instance segmentation task. + """ + criteria = ["voi_split", "voi_merge", "voi"] voi_split: float = attr.ib(default=float("nan")) @@ -13,10 +34,42 @@ class InstanceEvaluationScores(EvaluationScores): @property def voi(self): + """ + Return the average of the VOI split and VOI merge. + + Returns: + float + the average of the VOI split and VOI merge + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> instance_evaluation_scores = InstanceEvaluationScores(voi_split=0.1, voi_merge=0.2) + >>> instance_evaluation_scores.voi + 0.15 + Note: + This function is used to calculate the average of the VOI split and VOI merge. + """ return (self.voi_split + self.voi_merge) / 2 @staticmethod def higher_is_better(criterion: str) -> bool: + """ + Return whether higher is better for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether higher is better for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> InstanceEvaluationScores.higher_is_better("voi_split") + False + Note: + This function is used to determine whether higher is better for the given criterion. + """ mapping = { "voi_split": False, "voi_merge": False, @@ -28,6 +81,24 @@ def higher_is_better(criterion: str) -> bool: def bounds( criterion: str, ) -> Tuple[Union[int, float, None], Union[int, float, None]]: + """ + Return the bounds for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + Tuple[Union[int, float, None], Union[int, float, None]] + the bounds for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> InstanceEvaluationScores.bounds("voi_split") + (0, 1) + Note: + This function is used to return the bounds for the given criterion. + + """ mapping = { "voi_split": (0, 1), "voi_merge": (0, 1), @@ -37,4 +108,21 @@ def bounds( @staticmethod def store_best(criterion: str) -> bool: + """ + Return whether to store the best score for the given criterion. + + Args: + criterion : str + the evaluation criterion + Returns: + bool + whether to store the best score for the given criterion + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> InstanceEvaluationScores.store_best("voi_split") + True + Note: + This function is used to determine whether to store the best score for the given criterion. + """ return True diff --git a/dacapo/experiments/tasks/evaluators/instance_evaluator.py b/dacapo/experiments/tasks/evaluators/instance_evaluator.py index 30707b369..d2e179eaa 100644 --- a/dacapo/experiments/tasks/evaluators/instance_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/instance_evaluator.py @@ -14,30 +14,34 @@ def relabel(array, return_backwards_map=False, inplace=False): - """Relabel array, such that IDs are consecutive. Excludes 0. + """ + Relabel array, such that IDs are consecutive. Excludes 0. Args: - array (ndarray): - The array to relabel. - return_backwards_map (``bool``, optional): - If ``True``, return an ndarray that maps new labels (indices in the array) to old labels. - inplace (``bool``, optional): - Perform the replacement in-place on ``array``. - Returns: - A tuple ``(relabelled, n)``, where ``relabelled`` is the relabelled array and ``n`` the number of unique labels found. - If ``return_backwards_map`` is ``True``, returns ``(relabelled, n, backwards_map)``. + Raises: + ValueError: + If ``array`` is not of type ``np.ndarray``. + Examples: + >>> array = np.array([[1, 2, 0], [0, 2, 1]]) + >>> relabel(array) + (array([[1, 2, 0], [0, 2, 1]]), 2) + >>> relabel(array, return_backwards_map=True) + (array([[1, 2, 0], [0, 2, 1]]), 2, [0, 1, 2]) + Note: + This function is used to relabel an array, such that IDs are consecutive. Excludes 0. + """ if array.size == 0: @@ -71,9 +75,48 @@ def relabel(array, return_backwards_map=False, inplace=False): class InstanceEvaluator(Evaluator): + """ + A class representing an evaluator for instance segmentation tasks. + + Attributes: + criteria : List[str] + the evaluation criteria + Methods: + evaluate(output_array_identifier, evaluation_array) + Evaluate the output array against the evaluation array. + score + Return the evaluation scores. + Note: + The InstanceEvaluator class is used to evaluate the performance of an instance segmentation task. + + """ + criteria: List[str] = ["voi_merge", "voi_split", "voi"] def evaluate(self, output_array_identifier, evaluation_array): + """ + Evaluate the output array against the evaluation array. + + Args: + output_array_identifier : str + the identifier of the output array + evaluation_array : ZarrArray + the evaluation array + Returns: + InstanceEvaluationScores + the evaluation scores + Raises: + ValueError: if the output array identifier is not valid + Examples: + >>> instance_evaluator = InstanceEvaluator() + >>> output_array_identifier = "output_array" + >>> evaluation_array = ZarrArray.open_from_array_identifier("evaluation_array") + >>> instance_evaluator.evaluate(output_array_identifier, evaluation_array) + InstanceEvaluationScores(voi_merge=0.0, voi_split=0.0) + Note: + This function is used to evaluate the output array against the evaluation array. + + """ output_array = ZarrArray.open_from_array_identifier(output_array_identifier) evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64) output_data = output_array[output_array.roi].astype(np.uint64) @@ -86,9 +129,47 @@ def evaluate(self, output_array_identifier, evaluation_array): @property def score(self) -> InstanceEvaluationScores: + """ + Return the evaluation scores. + + Returns: + InstanceEvaluationScores + the evaluation scores + Raises: + NotImplementedError: if the function is not implemented + Examples: + >>> instance_evaluator = InstanceEvaluator() + >>> instance_evaluator.score + InstanceEvaluationScores(voi_merge=0.0, voi_split=0.0) + Note: + This function is used to return the evaluation scores. + + """ return InstanceEvaluationScores() def voi(truth, test): + """ + Calculate the variation of information (VOI) between two segmentations. + + Args: + truth : ndarray + the ground truth segmentation + test : ndarray + the test segmentation + Returns: + dict + the variation of information (VOI) scores + Raises: + ValueError: if the truth and test arrays are not of type np.ndarray + Examples: + >>> truth = np.array([[1, 1, 0], [0, 2, 2]]) + >>> test = np.array([[1, 1, 0], [0, 2, 2]]) + >>> voi(truth, test) + {'voi_split': 0.0, 'voi_merge': 0.0} + Note: + This function is used to calculate the variation of information (VOI) between two segmentations. + + """ voi_split, voi_merge = _voi(test + 1, truth + 1, ignore_groundtruth=[]) return {"voi_split": voi_split, "voi_merge": voi_merge} diff --git a/dacapo/experiments/tasks/hot_distance_task.py b/dacapo/experiments/tasks/hot_distance_task.py index d9e8fbca3..630e58ed5 100644 --- a/dacapo/experiments/tasks/hot_distance_task.py +++ b/dacapo/experiments/tasks/hot_distance_task.py @@ -16,6 +16,10 @@ class HotDistanceTask(Task): loss: HotDistanceLoss object. post_processor: ThresholdPostProcessor object. evaluator: BinarySegmentationEvaluator object. + Methods: + __init__(self, task_config): Constructs all the necessary attributes for the HotDistanceTask object. + Notes: + This is a subclass of Task. """ def __init__(self, task_config): @@ -24,6 +28,10 @@ def __init__(self, task_config): Args: task_config : The task configuration parameters. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> task = HotDistanceTask(task_config) """ self.predictor = HotDistancePredictor( diff --git a/dacapo/experiments/tasks/hot_distance_task_config.py b/dacapo/experiments/tasks/hot_distance_task_config.py index ccb036ae4..18cab91b3 100644 --- a/dacapo/experiments/tasks/hot_distance_task_config.py +++ b/dacapo/experiments/tasks/hot_distance_task_config.py @@ -8,7 +8,8 @@ @attr.s class HotDistanceTaskConfig(TaskConfig): - """Class for generating TaskConfigs for the HotDistanceTask, which predicts one hot encodings of classes, as well as signed distance transforms of those classes. + """ + Class for generating TaskConfigs for the HotDistanceTask, which predicts one hot encodings of classes, as well as signed distance transforms of those classes. Attributes: task_type: A reference to the Hot Distance Task class. @@ -19,7 +20,8 @@ class HotDistanceTaskConfig(TaskConfig): a tanh normalization. Defaults to 1. mask_distances (bool): Whether or not to mask out regions where the true distance to object boundary cannot be known. Defaults to False - + Methods: + verify(self) -> Tuple[bool, str]: This method verifies the HotDistanceTaskConfig object. Note: Generating distance transforms over regular affinities provides you with a denser signal, i.e., one misclassified pixel in an affinity prediction can merge 2 diff --git a/dacapo/experiments/tasks/inner_distance_task.py b/dacapo/experiments/tasks/inner_distance_task.py index 25af2158d..45315a672 100644 --- a/dacapo/experiments/tasks/inner_distance_task.py +++ b/dacapo/experiments/tasks/inner_distance_task.py @@ -15,6 +15,10 @@ class InnerDistanceTask(Task): loss: Used for calculating the mean square error loss. post_processor: Used for applying threshold post-processing. evaluator: Used for evaluating the results using binary segmentation. + Methods: + __init__(self, task_config): Initializes an instance of InnerDistanceTask. + Notes: + This is a subclass of Task. """ def __init__(self, task_config): @@ -24,6 +28,11 @@ def __init__(self, task_config): Args: task_config: The configuration for the task including channel and scale factor for prediction, and clip distance, tolerance distance, and channels for evaluation. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> task = InnerDistanceTask(task_config) + """ self.predictor = InnerDistancePredictor( diff --git a/dacapo/experiments/tasks/inner_distance_task_config.py b/dacapo/experiments/tasks/inner_distance_task_config.py index 1a66cc47d..e193a1ae2 100644 --- a/dacapo/experiments/tasks/inner_distance_task_config.py +++ b/dacapo/experiments/tasks/inner_distance_task_config.py @@ -8,7 +8,8 @@ @attr.s class InnerDistanceTaskConfig(TaskConfig): - """This is a Distance task config used for generating and + """ + This is a Distance task config used for generating and evaluating signed distance transforms as a way of generating segmentations. @@ -16,6 +17,15 @@ class InnerDistanceTaskConfig(TaskConfig): affinities is you can get a denser signal, i.e. 1 misclassified pixel in an affinity prediction could merge 2 otherwise very distinct objects, this cannot happen with distances. + + Attributes: + channels: A list of channel names. + clip_distance: Maximum distance to consider for false positive/negatives. + tol_distance: Tolerance distance for counting false positives/negatives + scale_factor: The amount by which to scale distances before applying a tanh normalization. + Notes: + This is a subclass of TaskConfig. + """ task_type = InnerDistanceTask diff --git a/dacapo/experiments/tasks/losses/affinities_loss.py b/dacapo/experiments/tasks/losses/affinities_loss.py index 74fc7fe67..1bc9aded5 100644 --- a/dacapo/experiments/tasks/losses/affinities_loss.py +++ b/dacapo/experiments/tasks/losses/affinities_loss.py @@ -3,11 +3,71 @@ class AffinitiesLoss(Loss): + """ + A class representing a loss function that calculates the loss between affinities and local shape descriptors (LSDs). + + Attributes: + num_affinities : int + the number of affinities + lsds_to_affs_weight_ratio : float + the ratio of the weight of the loss between affinities and LSDs + Methods: + compute(prediction, target, weight=None) + Calculate the total loss between prediction and target. + Note: + The AffinitiesLoss class is used to calculate the loss between affinities and local shape descriptors (LSDs). + + """ + def __init__(self, num_affinities: int, lsds_to_affs_weight_ratio: float): + """ + Initialize the AffinitiesLoss class with the number of affinities and the ratio of the weight of the loss between affinities and LSDs. + + Args: + num_affinities : int + the number of affinities + lsds_to_affs_weight_ratio : float + the ratio of the weight of the loss between affinities and LSDs + Examples: + >>> affinities_loss = AffinitiesLoss(3, 0.5) + >>> affinities_loss.num_affinities + 3 + >>> affinities_loss.lsds_to_affs_weight_ratio + 0.5 + Note: + The AffinitiesLoss class is used to calculate the loss between affinities and local shape descriptors (LSDs). + + """ self.num_affinities = num_affinities self.lsds_to_affs_weight_ratio = lsds_to_affs_weight_ratio def compute(self, prediction, target, weight): + """ + Method to calculate the total loss between affinities and LSDs. + + Args: + prediction : torch.Tensor + the model's prediction + target : torch.Tensor + the target values + weight : torch.Tensor + the weight to apply to the loss + Returns: + torch.Tensor + the total loss between affinities and LSDs + Raises: + ValueError: if the number of affinities in the prediction and target does not match + Examples: + >>> affinities_loss = AffinitiesLoss(3, 0.5) + >>> prediction = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + >>> target = torch.tensor([[9, 10, 11, 12], [13, 14, 15, 16]]) + >>> weight = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 1]]) + >>> affinities_loss.compute(prediction, target, weight) + tensor(0.5) + Note: + The AffinitiesLoss class is used to calculate the loss between affinities and local shape descriptors (LSDs). + + """ affs, affs_target, affs_weight = ( prediction[:, 0 : self.num_affinities, ...], target[:, 0 : self.num_affinities, ...], diff --git a/dacapo/experiments/tasks/losses/dummy_loss.py b/dacapo/experiments/tasks/losses/dummy_loss.py index f68206d01..b9543a5b9 100644 --- a/dacapo/experiments/tasks/losses/dummy_loss.py +++ b/dacapo/experiments/tasks/losses/dummy_loss.py @@ -7,29 +7,41 @@ class DummyLoss(Loss): Inherits the Loss class. - Methods - ------- - compute(prediction, target, weight=None) - Calculate the total loss between prediction and target. + Attributes: + name : str + name of the loss function + Methods: + compute(prediction, target, weight=None) + Calculate the total loss between prediction and target. + Note: + The dummy loss is used to test the training loop and the loss calculation. It is not a real loss function. + It is used to test the training loop and the loss calculation. + """ def compute(self, prediction, target, weight=None): """ Method to calculate the total dummy loss. - Parameters - ---------- - prediction : float or int - predicted output - target : float or int - true output - weight : float or int, optional - weight parameter for the loss, by default None - - Returns - ------- - float or int - Total loss calculated as the sum of absolute differences between prediction and target. + Args: + prediction : torch.Tensor + the model's prediction + target : torch.Tensor + the target values + weight : torch.Tensor + the weight to apply to the loss + Returns: + torch.Tensor + the total loss between prediction and target + Examples: + >>> dummy_loss = DummyLoss() + >>> prediction = torch.tensor([1, 2, 3]) + >>> target = torch.tensor([4, 5, 6]) + >>> dummy_loss.compute(prediction, target) + tensor(9) + Note: + The dummy loss is used to test the training loop and the loss calculation. It is not a real loss function. + It is used to test the training loop and the loss calculation. """ return abs(prediction - target).sum() diff --git a/dacapo/experiments/tasks/losses/hot_distance_loss.py b/dacapo/experiments/tasks/losses/hot_distance_loss.py index 784176bd0..8974ee3fa 100644 --- a/dacapo/experiments/tasks/losses/hot_distance_loss.py +++ b/dacapo/experiments/tasks/losses/hot_distance_loss.py @@ -7,7 +7,52 @@ # The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps. # Model should predict twice the number of channels as the target. class HotDistanceLoss(Loss): + """ + A class used to represent the Hot Distance Loss function. This class inherits from the Loss class. The Hot Distance Loss + function is used for predicting hot and distance maps at the same time. The first half of the channels are the hot maps, + the second half are the distance maps. The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance + maps. The model should predict twice the number of channels as the target. + + Attributes: + hot_loss: The Binary Cross Entropy Loss function. + distance_loss: The Mean Square Error Loss function. + Methods: + compute(prediction, target, weight) -> torch.Tensor + Function to compute the Hot Distance Loss for the provided prediction and target, with respect to the weight. + split(x) -> Tuple[torch.Tensor, torch.Tensor] + Function to split the input tensor into two tensors. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once created, the values of its attributes + cannot be changed. + + """ + def compute(self, prediction, target, weight): + """ + Function to compute the Hot Distance Loss for the provided prediction and target, with respect to the weight. + + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. + Returns: + torch.Tensor + The computed Hot Distance Loss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.compute(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed Hot Distance Loss tensor. + """ target_hot, target_distance = self.split(target) prediction_hot, prediction_distance = self.split(prediction) weight_hot, weight_distance = self.split(weight) @@ -16,14 +61,83 @@ def compute(self, prediction, target, weight): ) + self.distance_loss(prediction_distance, target_distance, weight_distance) def hot_loss(self, prediction, target, weight): + """ + The Binary Cross Entropy Loss function. This function computes the BCELoss for the hot maps. + + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. + Returns: + torch.Tensor + The computed BCELoss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.hot_loss(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed BCELoss tensor. + """ loss = torch.nn.BCEWithLogitsLoss(reduction="none") return torch.mean(loss(prediction, target) * weight) def distance_loss(self, prediction, target, weight): + """ + The Mean Square Error Loss function. This function computes the MSELoss for the distance maps. + + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. + Returns: + torch.Tensor + The computed MSELoss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.distance_loss(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed MSELoss tensor. + """ loss = torch.nn.MSELoss() return loss(prediction * weight, target * weight) def split(self, x): + """ + Function to split the input tensor into two tensors. + + Args: + x : torch.Tensor + The input tensor. + Returns: + Tuple[torch.Tensor, torch.Tensor] + The two split tensors. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = HotDistanceLoss() + >>> x = torch.tensor([1.0, 2.0, 3.0]) + >>> loss.split(x) + (tensor([1.0]), tensor([2.0])) + Note: + This method must be implemented in the subclass. It should return the two split tensors. + """ # Shape[0] is the batch size and Shape[1] is the number of channels. assert ( x.shape[1] % 2 == 0 diff --git a/dacapo/experiments/tasks/losses/loss.py b/dacapo/experiments/tasks/losses/loss.py index 20824d6ab..34b3e9aa8 100644 --- a/dacapo/experiments/tasks/losses/loss.py +++ b/dacapo/experiments/tasks/losses/loss.py @@ -5,6 +5,18 @@ class Loss(ABC): + """ + A class used to represent a loss function. This class is an abstract class + that should be inherited by any loss function class. + + Methods: + compute(prediction, target, weight) -> torch.Tensor + Function to compute the loss for the provided prediction and target, with respect to the weight. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once created, the values of its attributes + cannot be changed. + """ + @abstractmethod def compute( self, @@ -12,10 +24,31 @@ def compute( target: torch.Tensor, weight: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Compute the loss for the given prediction and target. Optionally, if + """ + Compute the loss for the given prediction and target. Optionally, if given, a loss weight should be considered. All arguments are ``torch`` tensors. The return type should be a ``torch`` scalar that can be used with an optimizer, just as usual when - training with ``torch``.""" + training with ``torch``. + + Args: + prediction: The predicted tensor. + target: The target tensor. + weight: The weight tensor. + Returns: + The computed loss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = MSELoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.compute(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the + computed loss tensor. + """ pass diff --git a/dacapo/experiments/tasks/losses/mse_loss.py b/dacapo/experiments/tasks/losses/mse_loss.py index 348042c11..e3b0dac0a 100644 --- a/dacapo/experiments/tasks/losses/mse_loss.py +++ b/dacapo/experiments/tasks/losses/mse_loss.py @@ -4,34 +4,40 @@ class MSELoss(Loss): """ - A class used to represent the Mean Square Error Loss function (MSELoss). + A class used to represent the Mean Square Error Loss function (MSELoss). This class inherits from the Loss class. - Attributes - ---------- - None - - Methods - ------- - compute(prediction, target, weight): - Computes the MSELoss with the given weight for the predictiom and target. + Methods: + compute(prediction, target, weight) -> torch.Tensor + Function to compute the MSELoss for the provided prediction and target, with respect to the weight. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once created, the values of its attributes + cannot be changed. """ def compute(self, prediction, target, weight): """ Function to compute the MSELoss for the provided prediction and target, with respect to the weight. - Parameters: - ---------- - prediction : torch.Tensor - The prediction tensor for which loss needs to be calculated. - target : torch.Tensor - The target tensor with respect to which loss is calculated. - weight : torch.Tensor - The weight tensor used to weigh the prediction in the loss calculation. - + Args: + prediction : torch.Tensor + The predicted tensor. + target : torch.Tensor + The target tensor. + weight : torch.Tensor + The weight tensor. Returns: - ------- - torch.Tensor - The computed MSELoss tensor. + torch.Tensor + The computed MSELoss tensor. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> loss = MSELoss() + >>> prediction = torch.tensor([1.0, 2.0, 3.0]) + >>> target = torch.tensor([1.0, 2.0, 3.0]) + >>> weight = torch.tensor([1.0, 1.0, 1.0]) + >>> loss.compute(prediction, target, weight) + tensor(0.) + Note: + This method must be implemented in the subclass. It should return the computed MSELoss tensor. """ return torch.nn.MSELoss().forward(prediction * weight, target * weight) diff --git a/dacapo/experiments/tasks/one_hot_task.py b/dacapo/experiments/tasks/one_hot_task.py index 7abc27fda..870140f50 100644 --- a/dacapo/experiments/tasks/one_hot_task.py +++ b/dacapo/experiments/tasks/one_hot_task.py @@ -6,7 +6,30 @@ class OneHotTask(Task): + """ + A task that uses a one-hot predictor. The model is loaded from a file + and the weights are loaded from a file. The loss is a dummy loss and the + post processor is an argmax post processor. The evaluator is a dummy evaluator. + + Attributes: + weights (Path): The path to the weights file. + Methods: + create_model(self, architecture) -> Model: This method creates a model using the given architecture. + Notes: + This is a base class for all tasks that use one-hot predictors. + """ + def __init__(self, task_config): + """ + Initialize the OneHotTask object. + + Args: + task_config (TaskConfig): The configuration of the task. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> task = OneHotTask(task_config) + """ self.predictor = OneHotPredictor(classes=task_config.classes) self.loss = DummyLoss() self.post_processor = ArgmaxPostProcessor() diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index bfd4584d9..5bdfa767f 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path from dacapo.blockwise import run_blockwise import dacapo.blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray @@ -10,16 +10,75 @@ class ArgmaxPostProcessor(PostProcessor): + """ + Post-processor that takes the argmax of the input array along the channel + axis. The output is a binary array where the value is 1 if the argmax is + greater than the threshold, and 0 otherwise. + + Attributes: + prediction_array: The array containing the model's prediction. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array identifier. + process: Convert predictions into the final output. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once + created, the values of its attributes cannot be changed. + """ + def __init__(self): + """ + Initialize the post-processor. + + Args: + detection_threshold: The detection threshold. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + Note: + This method must be implemented in the subclass. It should set the + `detection_threshold` attribute. + """ pass def enumerate_parameters(self): - """Enumerate all possible parameters of this post-processor. Should - return instances of ``PostProcessorParameters``.""" + """ + Enumerate all possible parameters of this post-processor. Should + return instances of ``PostProcessorParameters``. + + Returns: + An iterable of `PostProcessorParameters` instances. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + ArgmaxPostProcessorParameters(id=0) + Note: + This method must be implemented in the subclass. It should return an + iterable of `PostProcessorParameters` instances. + """ yield ArgmaxPostProcessorParameters(id=1) def set_prediction(self, prediction_array_identifier): + """ + Set the prediction array identifier. + + Args: + prediction_array_identifier: The identifier of the array containing + the model's prediction. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + >>> post_processor.set_prediction("prediction") + Note: + This method must be implemented in the subclass. It should set the + `prediction_array_identifier` attribute. + """ self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier @@ -32,6 +91,26 @@ def process( num_workers: int = 16, block_size: Coordinate = Coordinate((256, 256, 256)), ): + """ + Convert predictions into the final output. + + Args: + parameters: The parameters of the post-processor. + output_array_identifier: The identifier of the output array. + num_workers: The number of workers to use. + block_size: The size of the blocks to process. + Returns: + The output array. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = ArgmaxPostProcessor() + >>> post_processor.set_prediction("prediction") + >>> post_processor.process(parameters, "output") + Note: + This method must be implemented in the subclass. It should process the + predictions and return the output array. + """ if self.prediction_array._daisy_array.chunk_shape is not None: block_size = Coordinate( self.prediction_array._daisy_array.chunk_shape[ diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py index 331faf5e6..308a470a6 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor_parameters.py @@ -4,4 +4,15 @@ @attr.s(frozen=True) class ArgmaxPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the argmax post-processor. The argmax post-processor will set + the output to the index of the maximum value in the input array. + + Methods: + parameter_names: Get the names of the parameters. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + """ + pass diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 4a992ced2..7d54f0714 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -8,20 +8,108 @@ class DummyPostProcessor(PostProcessor): + """ + Dummy post-processor that stores some dummy data. The dummy data is a 10x10x10 + array filled with the value of the min_size parameter. The min_size parameter + is specified in the parameters of the post-processor. The post-processor has + a detection threshold that is used to determine if an object is detected. + + Attributes: + detection_threshold: The detection threshold. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array identifier. + process: Convert predictions into the final output. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once + created, the values of its attributes cannot be changed. + """ + def __init__(self, detection_threshold: float): + """ + Initialize the post-processor. + + Args: + detection_threshold: The detection threshold. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor(0.5) + Note: + This method must be implemented in the subclass. It should set the + `detection_threshold` attribute. + """ self.detection_threshold = detection_threshold def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: - """Enumerate all possible parameters of this post-processor. Should - return instances of ``PostProcessorParameters``.""" + """ + Enumerate all possible parameters of this post-processor. Should + return instances of ``PostProcessorParameters``. + + Returns: + An iterable of `PostProcessorParameters` instances. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor() + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + DummyPostProcessorParameters(id=0, min_size=1) + DummyPostProcessorParameters(id=1, min_size=2) + DummyPostProcessorParameters(id=2, min_size=3) + DummyPostProcessorParameters(id=3, min_size=4) + DummyPostProcessorParameters(id=4, min_size=5) + DummyPostProcessorParameters(id=5, min_size=6) + DummyPostProcessorParameters(id=6, min_size=7) + DummyPostProcessorParameters(id=7, min_size=8) + DummyPostProcessorParameters(id=8, min_size=9) + DummyPostProcessorParameters(id=9, min_size=10) + Note: + This method must be implemented in the subclass. It should return an + iterable of `PostProcessorParameters` instances. + """ for i, min_size in enumerate(range(1, 11)): yield DummyPostProcessorParameters(id=i, min_size=min_size) def set_prediction(self, prediction_array_identifier): + """ + Set the prediction array identifier. + + Args: + prediction_array_identifier: The identifier of the array containing + the model's prediction. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor() + >>> post_processor.set_prediction("prediction") + Note: + This method must be implemented in the subclass. It should set the + `prediction_array_identifier` attribute. + """ pass def process(self, parameters, output_array_identifier, *args, **kwargs): + """ + Convert predictions into the final output. + + Args: + parameters: The parameters of the post-processor. + output_array_identifier: The identifier of the output array. + num_workers: The number of workers to use. + chunk_size: The size of the chunks to process. + Returns: + The output array. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = DummyPostProcessor() + >>> post_processor.process(parameters, "output") + Note: + This method must be implemented in the subclass. It should process the + predictions and store the output in the output array. + """ # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py index bfa09e583..6321c27d9 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor_parameters.py @@ -4,4 +4,19 @@ @attr.s(frozen=True) class DummyPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the dummy post-processor. The dummy post-processor will set + the output to 1 if the input is greater than the minimum size, and 0 + otherwise. + + Attributes: + min_size: The minimum size. If the input is greater than this value, the + output will be set to 1. Otherwise, the output will be set to 0. + Methods: + parameter_names: Get the names of the parameters. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + """ + min_size: int = attr.ib() diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index f0a991c51..2b63b15c0 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -12,21 +12,71 @@ class PostProcessor(ABC): - """Base class of all post-processors. + """ + Base class of all post-processors. A post-processor takes a model's prediction and converts it into the final - output (e.g., per-voxel class probabilities into a semantic segmentation). + output (e.g., per-voxel class probabilities into a semantic segmentation). A + post-processor can have multiple parameters, which can be enumerated using + the `enumerate_parameters` method. The `process` method takes a set of + parameters and applies the post-processing to the prediction. + + Attributes: + prediction_array_identifier: The identifier of the array containing the + model's prediction. + Methods: + enumerate_parameters: Enumerate all possible parameters of this + post-processor. + set_prediction: Set the prediction array identifier. + process: Convert predictions into the final output. + Note: + This class is abstract. Subclasses must implement the abstract methods. Once + created, the values of its attributes cannot be changed. """ @abstractmethod def enumerate_parameters(self) -> Iterable["PostProcessorParameters"]: - """Enumerate all possible parameters of this post-processor.""" + """ + Enumerate all possible parameters of this post-processor. + + Returns: + An iterable of `PostProcessorParameters` instances. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = MyPostProcessor() + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + MyPostProcessorParameters(param1=0.0, param2=0.0) + MyPostProcessorParameters(param1=0.0, param2=1.0) + MyPostProcessorParameters(param1=1.0, param2=0.0) + MyPostProcessorParameters(param1=1.0, param2=1.0) + Note: + This method must be implemented in the subclass. It should return an + iterable of `PostProcessorParameters` instances. + + """ pass @abstractmethod def set_prediction( self, prediction_array_identifier: "LocalArrayIdentifier" ) -> None: + """ + Set the prediction array identifier. + + Args: + prediction_array_identifier: The identifier of the array containing + the model's prediction. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = MyPostProcessor() + >>> post_processor.set_prediction("prediction") + Note: + This method must be implemented in the subclass. It should set the + `prediction_array_identifier` attribute. + """ pass @abstractmethod @@ -37,5 +87,26 @@ def process( num_workers: int = 16, chunk_size: Coordinate = Coordinate((64, 64, 64)), ) -> "Array": - """Convert predictions into the final output.""" + """ + Convert predictions into the final output. + + Args: + parameters: The parameters of the post-processor. + output_array_identifier: The identifier of the array to store the + output. + num_workers: The number of workers to use. + chunk_size: The size of the chunks to process. + Returns: + The output array. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> post_processor = MyPostProcessor() + >>> post_processor.set_prediction("prediction") + >>> parameters = MyPostProcessorParameters(param1=0.0, param2=0.0) + >>> output = post_processor.process(parameters, "output") + Note: + This method must be implemented in the subclass. It should convert the + model's prediction into the final output. + """ pass diff --git a/dacapo/experiments/tasks/post_processors/post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/post_processor_parameters.py index dd08ab41c..271b5127e 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/post_processor_parameters.py @@ -5,12 +5,40 @@ @attr.s(frozen=True) class PostProcessorParameters: - """Base class for post-processor parameters.""" + """ + Base class for post-processor parameters. Post-processor parameters are + immutable objects that define the parameters of a post-processor. The + parameters are used to configure the post-processor. + + Attributes: + id: The identifier of the post-processor parameter. + Methods: + parameter_names: Get the names of the parameters. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + + """ id: int = attr.ib() @property def parameter_names(self) -> List[str]: + """ + Get the names of the parameters. + + Returns: + A list of parameter names. + Raises: + NotImplementedError: If the method is not implemented in the subclass. + Examples: + >>> parameters = PostProcessorParameters(0) + >>> parameters.parameter_names + ["id"] + Note: + This method must be implemented in the subclass. It should return a + list of parameter names. + """ return ["id"] diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 2ea537a6c..c0e10418c 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path from dacapo.blockwise.scheduler import run_blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters @@ -12,15 +12,53 @@ class ThresholdPostProcessor(PostProcessor): + """ + A post-processor that applies a threshold to the prediction. + + Attributes: + prediction_array_identifier: The identifier of the prediction array. + prediction_array: The prediction array. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array. + process: Process the prediction with the given parameters. + Note: + This post-processor applies a threshold to the prediction. The threshold is used to define the segmentation. The prediction array is set using the `set_prediction` method. + """ + def __init__(self): pass def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: - """Enumerate all possible parameters of this post-processor.""" + """ + Enumerate all possible parameters of this post-processor. + + Returns: + Generator[ThresholdPostProcessorParameters]: A generator of parameters. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + Note: + This method should return a generator of instances of ``ThresholdPostProcessorParameters``. + """ for i, threshold in enumerate([100, 127, 150]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) def set_prediction(self, prediction_array_identifier): + """ + Set the prediction array. + + Args: + prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.set_prediction(prediction_array_identifier) + Note: + This method should set the prediction array using the given identifier. + """ self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier @@ -33,6 +71,24 @@ def process( num_workers: int = 16, block_size: Coordinate = Coordinate((256, 256, 256)), ) -> ZarrArray: + """ + Process the prediction with the given parameters. + + Args: + parameters (ThresholdPostProcessorParameters): The parameters to use for processing. + output_array_identifier (LocalArrayIdentifier): The identifier of the output array. + num_workers (int): The number of workers to use for processing. + block_size (Coordinate): The block size to use for processing. + Returns: + ZarrArray: The output array. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.process(parameters, output_array_identifier) + Note: + This method should process the prediction with the given parameters and return the output array. The method uses the `run_blockwise` function from the `dacapo.blockwise.scheduler` module to run the blockwise post-processing. + The output array is created using the `ZarrArray.create_from_array_identifier` function from the `dacapo.experiments.datasplits.datasets.arrays` module. + """ # TODO: Investigate Liskov substitution princple and whether it is a problem here # OOP theory states the super class should always be replaceable with its subclasses # meaning the input arguments to methods on the subclass can only be more loosely @@ -54,11 +110,6 @@ def process( self.prediction_array.voxel_size, ) ] - if ( - self.prediction_array.num_channels is not None - and self.prediction_array.num_channels > 1 - ): - write_size = [self.prediction_array.num_channels] + write_size output_array = ZarrArray.create_from_array_identifier( output_array_identifier, self.prediction_array.axes, diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py index 9a28ba970..014fc1ec2 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor_parameters.py @@ -4,4 +4,18 @@ @attr.s(frozen=True) class ThresholdPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the threshold post-processor. The threshold post-processor + will set the output to 1 if the input is greater than the threshold, and 0 + otherwise. + + Attributes: + threshold: The threshold value. If the input is greater than this + value, the output will be set to 1. Otherwise, the output will be + set to 0. + Note: + This class is immutable. Once created, the values of its attributes + cannot be changed. + """ + threshold: float = attr.ib(default=0.0) diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 7a3467daa..fd22b426a 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path import dacapo.blockwise from dacapo.blockwise.scheduler import segment_blockwise from dacapo.experiments.datasplits.datasets.arrays import ZarrArray @@ -16,12 +16,51 @@ class WatershedPostProcessor(PostProcessor): + """ + A post-processor that applies a watershed transformation to the + prediction. + + Attributes: + offsets: List of offsets for the watershed transformation. + Methods: + enumerate_parameters: Enumerate all possible parameters of this post-processor. + set_prediction: Set the prediction array. + process: Process the prediction with the given parameters. + Note: + This post-processor uses the `watershed_function.py` script to apply the watershed transformation. The offsets are used to define the neighborhood for the watershed transformation. + + """ + def __init__(self, offsets: List[Coordinate]): + """ + A post-processor that applies a watershed transformation to the + prediction. + + Args: + offsets (List[Coordinate]): List of offsets for the watershed transformation. + Examples: + >>> WatershedPostProcessor(offsets=[(0, 0, 1), (0, 1, 0), (1, 0, 0)]) + Note: + This post-processor uses the `watershed_function.py` script to apply the watershed transformation. The offsets are used to define the neighborhood for the watershed transformation. + """ self.offsets = offsets def enumerate_parameters(self): - """Enumerate all possible parameters of this post-processor. Should - return instances of ``PostProcessorParameters``.""" + """ + Enumerate all possible parameters of this post-processor. Should + return instances of ``PostProcessorParameters``. + + Returns: + Generator[WatershedPostProcessorParameters]: A generator of parameters. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> for parameters in post_processor.enumerate_parameters(): + ... print(parameters) + Note: + This method should be implemented by the subclass. It should return a generator of instances of ``WatershedPostProcessorParameters``. + + """ for i, bias in enumerate([0.1, 0.25, 0.5, 0.75, 0.9]): yield WatershedPostProcessorParameters(id=i, bias=bias) @@ -31,6 +70,18 @@ def set_prediction(self, prediction_array_identifier): self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) + """ + Set the prediction array. + + Args: + prediction_array_identifier (LocalArrayIdentifier): The prediction array identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.set_prediction(prediction_array_identifier) + Note: + This method should be implemented by the subclass. To set the prediction array, the method uses the `ZarrArray.open_from_array_identifier` function from the `dacapo.experiments.datasplits.datasets.arrays` module. + """ def process( self, @@ -39,6 +90,23 @@ def process( num_workers: int = 16, block_size: Coordinate = Coordinate((256, 256, 256)), ): + """ + Process the prediction with the given parameters. + + Args: + parameters (WatershedPostProcessorParameters): The parameters to use for processing. + output_array_identifier (LocalArrayIdentifier): The output array identifier. + num_workers (int): The number of workers to use for processing. + block_size (Coordinate): The block size to use for processing. + Returns: + LocalArrayIdentifier: The output array identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> post_processor.process(parameters, output_array_identifier) + Note: + This method should be implemented by the subclass. To run the watershed transformation, the method uses the `segment_blockwise` function from the `dacapo.blockwise.scheduler` module. + """ if self.prediction_array._daisy_array.chunk_shape is not None: block_size = Coordinate( self.prediction_array._daisy_array.chunk_shape[ diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py index 6a3a1e271..70dadaca8 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor_parameters.py @@ -5,5 +5,21 @@ @attr.s(frozen=True) class WatershedPostProcessorParameters(PostProcessorParameters): + """ + Parameters for the watershed post-processor. + + Attributes: + offsets: List of offsets for the watershed transformation. + threshold: Threshold for the watershed transformation. + sigma: Sigma for the watershed transformation. + min_size: Minimum size of the segments. + bias: Bias for the watershed transformation. + context: Context for the watershed transformation. + Examples: + >>> WatershedPostProcessorParameters(offsets=[(0, 0, 1), (0, 1, 0), (1, 0, 0)], threshold=0.5, sigma=1.0, min_size=100, bias=0.5, context=(32, 32, 32)) + Note: + This class is used by the ``WatershedPostProcessor`` to define the parameters for the watershed transformation. The offsets are used to define the neighborhood for the watershed transformation. + """ + bias: float = attr.ib(default=0.5) context: Coordinate = attr.ib(default=Coordinate((32, 32, 32))) diff --git a/dacapo/experiments/tasks/predictors/affinities_predictor.py b/dacapo/experiments/tasks/predictors/affinities_predictor.py index d68541349..59f0cfa60 100644 --- a/dacapo/experiments/tasks/predictors/affinities_predictor.py +++ b/dacapo/experiments/tasks/predictors/affinities_predictor.py @@ -14,9 +14,65 @@ import itertools from typing import List +from .predictor import Predictor +from dacapo.experiments import Model +from dacapo.experiments.arraytypes import EmbeddingArray +from dacapo.experiments.datasplits.datasets.arrays import NumpyArray +from dacapo.utils.affinities import seg_to_affgraph, padding as aff_padding +from dacapo.utils.balance_weights import balance_weights +from funlib.geometry import Coordinate +from lsd.train import LsdExtractor +from scipy import ndimage +import numpy as np +import torch +import itertools +from typing import List class AffinitiesPredictor(Predictor): + """ + A predictor for generating affinity predictions from input data. + + Attributes: + neighborhood (List[Coordinate]): The neighborhood. + lsds (bool): Whether to compute local shape descriptors. + num_voxels (int): The number of voxels. + downsample_lsds (int): The downsample rate for LSDs. + grow_boundary_iterations (int): The number of iterations to grow the boundary. + affs_weight_clipmin (float): The minimum weight for affinities. + affs_weight_clipmax (float): The maximum weight for affinities. + lsd_weight_clipmin (float): The minimum weight for LSDs. + lsd_weight_clipmax (float): The maximum weight for LSDs. + background_as_object (bool): Whether to treat the background as an object. + Methods: + __init__( + self, + neighborhood: List[Coordinate], + lsds: bool = True, + num_voxels: int = 20, + downsample_lsds: int = 1, + grow_boundary_iterations: int = 0, + affs_weight_clipmin: float = 0.05, + affs_weight_clipmax: float = 0.95, + lsd_weight_clipmin: float = 0.05, + lsd_weight_clipmax: float = 0.95, + background_as_object: bool = False + ): Initializes the AffinitiesPredictor. + extractor(self, voxel_size): Get the LSD extractor. + dims: Get the number of dimensions. + sigma(self, voxel_size): Compute the sigma value for LSD computation. + lsd_pad(self, voxel_size): Compute the padding for LSD computation. + num_channels: Get the number of channels. + create_model(self, architecture): Create the model. + create_target(self, gt): Create the target data. + _grow_boundaries(self, mask, slab): Grow the boundaries of the mask. + create_weight(self, gt, target, mask, moving_class_counts=None): Create the weight data. + gt_region_for_roi(self, target_spec): Get the ground truth region for the target region of interest (ROI). + output_array_type: Get the output array type. + Notes: + This is a subclass of Predictor. + """ + def __init__( self, neighborhood: List[Coordinate], @@ -30,6 +86,16 @@ def __init__( lsd_weight_clipmax: float = 0.95, background_as_object: bool = False, ): + """ + Initializes the AffinitiesPredictor. + + Args: + neighborhood (List[Coordinate]): The neighborhood. + Raises: + ValueError: If the number of dimensions is not 2 or 3. + Examples: + >>> neighborhood = [Coordinate((0, 1)), Coordinate((1, 0))] + """ self.neighborhood = neighborhood self.lsds = lsds self.num_voxels = num_voxels @@ -55,6 +121,18 @@ def __init__( self.background_as_object = background_as_object def extractor(self, voxel_size): + """ + Get the LSD extractor. + + Args: + voxel_size (Coordinate): The voxel size. + Returns: + LsdExtractor: The LSD extractor. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> extractor = predictor.extractor(voxel_size) + """ if self._extractor is None: self._extractor = LsdExtractor( self.sigma(voxel_size), downsample=self.downsample_lsds @@ -64,23 +142,80 @@ def extractor(self, voxel_size): @property def dims(self): + """ + Get the number of dimensions. + + Returns: + int: The number of dimensions. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.dims + + """ return self.neighborhood[0].dims def sigma(self, voxel_size): + """ + Compute the sigma value for LSD computation. + + Args: + voxel_size (Coordinate): The voxel size. + Returns: + Coordinate: The sigma value. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.sigma(voxel_size) + """ voxel_dist = max(voxel_size) # arbitrarily chosen sigma = voxel_dist * self.num_voxels # arbitrarily chosen return Coordinate((sigma,) * self.dims) def lsd_pad(self, voxel_size): + """ + Compute the padding for LSD computation. + + Args: + voxel_size (Coordinate): The voxel size. + Returns: + Coordinate: The padding value. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.lsd_pad(voxel_size) + """ multiplier = 3 # from AddLocalShapeDescriptor Node in funlib.lsd padding = Coordinate(self.sigma(voxel_size) * multiplier) return padding @property def num_channels(self): + """ + Get the number of channels. + + Returns: + int: The number of channels. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.num_channels + """ return len(self.neighborhood) + self.num_lsds def create_model(self, architecture): + """ + Create the model. + + Args: + architecture: The architecture for the model. + Returns: + Model: The created model. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> model = predictor.create_model(architecture) + """ if self.dims == 2: head = torch.nn.Conv2d( architecture.num_out_channels, self.num_channels, kernel_size=1 @@ -97,6 +232,19 @@ def create_model(self, architecture): return Model(architecture, head, eval_activation=torch.nn.Sigmoid()) def create_target(self, gt): + """ + Create the target data. + + Args: + gt: The ground truth data. + Returns: + NumpyArray: The created target data. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_target(gt) + + """ # zeros assert gt.num_channels is None or gt.num_channels == 1, ( "Cannot create affinities from ground truth with multiple channels.\n" @@ -130,6 +278,19 @@ def create_target(self, gt): ) def _grow_boundaries(self, mask, slab): + """ + Grow the boundaries of the mask. + + Args: + mask: The mask data. + slab: The slab definition. + Returns: + np.ndarray: The mask with grown boundaries. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor._grow_boundaries(mask, slab) + """ # get all foreground voxels by erosion of each component foreground = np.zeros(shape=mask.shape, dtype=bool) @@ -153,6 +314,21 @@ def _grow_boundaries(self, mask, slab): return mask def create_weight(self, gt, target, mask, moving_class_counts=None): + """ + Create the weight data. + + Args: + gt: The ground truth data. + target: The target data. + mask: The mask data. + moving_class_counts: The moving class counts. + Returns: + Tuple[NumpyArray, Tuple]: The created weight data and moving class counts. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_weight(gt, target, mask, moving_class_counts) + """ (moving_class_counts, moving_lsd_class_counts) = ( moving_class_counts if moving_class_counts is not None else (None, None) ) @@ -198,6 +374,17 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): ), (moving_class_counts, moving_lsd_class_counts) def gt_region_for_roi(self, target_spec): + """ + Get the ground truth region for the target region of interest (ROI). + + Args: + target_spec: The target region of interest (ROI) specification. + Returns: + The ground truth region specification. + Raises: + NotImplementedError: This method is not implemented. + + """ gt_spec = target_spec.copy() pad_neg, pad_pos = aff_padding(self.neighborhood, target_spec.voxel_size) if self.lsds: @@ -221,4 +408,14 @@ def gt_region_for_roi(self, target_spec): @property def output_array_type(self): + """ + Get the output array type. + + Returns: + EmbeddingArray: The output array type. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.output_array_type + """ return EmbeddingArray(self.dims) diff --git a/dacapo/experiments/tasks/predictors/distance_predictor.py b/dacapo/experiments/tasks/predictors/distance_predictor.py index eb19cd9e1..403565b00 100644 --- a/dacapo/experiments/tasks/predictors/distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/distance_predictor.py @@ -25,6 +25,33 @@ class DistancePredictor(Predictor): Multiple classes can be predicted via multiple distance channels. The names of each class that is being segmented can be passed in as a list of strings in the channels argument. + + Attributes: + channels (List[str]): The list of class labels. + scale_factor (float): The scale factor for the distance transform. + mask_distances (bool): Whether to mask distances. + clipmin (float): The minimum value to clip the weights to. + clipmax (float): The maximum value to clip the weights to. + Methods: + __init__(self, channels: List[str], scale_factor: float, mask_distances: bool, clipmin: float = 0.05, clipmax: float = 0.95): Initializes the DistancePredictor. + create_model(self, architecture): Create the model for the predictor. + create_target(self, gt): Create the target array for training. + create_weight(self, gt, target, mask, moving_class_counts=None): Create the weight array for training. + output_array_type: Get the output array type. + create_distance_mask(self, distances, mask, voxel_size, normalize=None, normalize_args=None): Create the distance mask. + process(self, labels, voxel_size, normalize=None, normalize_args=None): Process the labels array. + gt_region_for_roi(self, target_spec): Get the ground-truth region for the ROI. + padding(self, gt_voxel_size: Coordinate) -> Coordinate: Get the padding needed for the ground-truth array. + Notes: + The DistancePredictor is used to predict signed distances for a binary segmentation task. + The distances are calculated using the distance_transform_edt function from scipy.ndimage.morphology. + The distances are then passed through a tanh function to saturate the distances at +-1. + The distances are calculated for each class that is being segmented and are stored in separate channels. + The names of each class that is being segmented can be passed in as a list of strings in the channels argument. + The scale_factor argument is used to scale the distances. + The mask_distances argument is used to determine whether to mask distances. + The clipmin argument is used to determine the minimum value to clip the weights to. + The clipmax argument is used to determine the maximum value to clip the weights to. """ def __init__( @@ -35,6 +62,20 @@ def __init__( clipmin: float = 0.05, clipmax: float = 0.95, ): + """ + Initialize the DistancePredictor object. + + Args: + channels (List[str]): List of channel names. + scale_factor (float): Scale factor for distance calculation. + mask_distances (bool): Flag indicating whether to mask distances. + clipmin (float, optional): Minimum clipping value. Defaults to 0.05. + clipmax (float, optional): Maximum clipping value. Defaults to 0.95. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor = DistancePredictor(channels, scale_factor, mask_distances, clipmin, clipmax) + """ self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -48,9 +89,31 @@ def __init__( @property def embedding_dims(self): + """ + Get the number of embedding dimensions. + + Returns: + int: The number of embedding dimensions. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> embedding_dims = predictor.embedding_dims + """ return len(self.channels) def create_model(self, architecture): + """ + Create the model for the predictor. + + Args: + architecture: The architecture for the model. + Returns: + Model: The created model. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> model = predictor.create_model(architecture) + """ if architecture.dims == 2: head = torch.nn.Conv2d( architecture.num_out_channels, self.embedding_dims, kernel_size=1 @@ -63,6 +126,19 @@ def create_model(self, architecture): return Model(architecture, head) def create_target(self, gt): + """ + Create the target array for training. + + Args: + gt: The ground-truth array. + Returns: + NumpyArray: The created target array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_target(gt) + + """ distances = self.process( gt.data, gt.voxel_size, self.norm, self.dt_scale_factor ) @@ -74,6 +150,22 @@ def create_target(self, gt): ) def create_weight(self, gt, target, mask, moving_class_counts=None): + """ + Create the weight array for training, given a ground-truth and + associated target array. + + Args: + gt: The ground-truth array. + target: The target array. + mask: The mask array. + moving_class_counts: The moving class counts. + Returns: + The weight array and the moving class counts. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_weight(gt, target, mask, moving_class_counts) + """ # balance weights independently for each channel if self.mask_distances: distance_mask = self.create_distance_mask( @@ -107,6 +199,16 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): @property def output_array_type(self): + """ + Get the output array type. + + Returns: + DistanceArray: The output array type. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.output_array_type + """ return DistanceArray(self.embedding_dims) def create_distance_mask( @@ -117,6 +219,23 @@ def create_distance_mask( normalize=None, normalize_args=None, ): + """ + Create a distance mask. + + Args: + distances (np.ndarray): The distances array. + mask (np.ndarray): The mask array. + voxel_size (Coordinate): The voxel size. + normalize (str): The normalization method. + normalize_args: The normalization arguments. + Returns: + np.ndarray: The distance mask. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_distance_mask(distances, mask, voxel_size, normalize, normalize_args) + + """ mask_output = mask.copy() for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)): tmp = np.zeros( @@ -176,6 +295,22 @@ def process( normalize=None, normalize_args=None, ): + """ + Process the labels array and convert it to one-hot encoding. + + Args: + labels (np.ndarray): The labels array. + voxel_size (Coordinate): The voxel size. + normalize (str): The normalization method. + normalize_args: The normalization arguments. + Returns: + np.ndarray: The distances array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.process(labels, voxel_size, normalize, normalize_args) + + """ all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 for ii, channel in enumerate(labels): boundaries = self.__find_boundaries(channel) @@ -213,6 +348,19 @@ def process( return all_distances def __find_boundaries(self, labels): + """ + Find the boundaries in the labels. + + Args: + labels: The labels. + Returns: + The boundaries. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.__find_boundaries(labels) + + """ # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 @@ -251,6 +399,21 @@ def __find_boundaries(self, labels): return boundaries def __normalize(self, distances, norm, normalize_args): + """ + Normalize the distances. + + Args: + distances: The distances to normalize. + norm: The normalization method. + normalize_args: The normalization arguments. + Returns: + The normalized distances. + Raises: + ValueError: If the normalization method is not supported. + Examples: + >>> predictor.__normalize(distances, norm, normalize_args) + + """ if norm == "tanh": scale = normalize_args return np.tanh(distances / scale) @@ -258,6 +421,20 @@ def __normalize(self, distances, norm, normalize_args): raise ValueError("Only tanh is supported for normalization") def gt_region_for_roi(self, target_spec): + """ + Report how much spatial context this predictor needs to generate a + target for the given ROI. By default, uses the same ROI. + + Args: + target_spec: The ROI for which the target is requested. + Returns: + The ROI for which the ground-truth is requested. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.gt_region_for_roi(target_spec) + + """ if self.mask_distances: gt_spec = target_spec.copy() gt_spec.roi = gt_spec.roi.grow( @@ -269,4 +446,17 @@ def gt_region_for_roi(self, target_spec): return gt_spec def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + """ + Return the padding needed for the ground-truth array. + + Args: + gt_voxel_size (Coordinate): The voxel size of the ground-truth array. + Returns: + Coordinate: The padding needed for the ground-truth array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.padding(gt_voxel_size) + + """ return Coordinate((self.max_distance,) * gt_voxel_size.dims) diff --git a/dacapo/experiments/tasks/predictors/dummy_predictor.py b/dacapo/experiments/tasks/predictors/dummy_predictor.py index 5e7ba8b6c..3fb64b9ac 100644 --- a/dacapo/experiments/tasks/predictors/dummy_predictor.py +++ b/dacapo/experiments/tasks/predictors/dummy_predictor.py @@ -8,10 +8,47 @@ class DummyPredictor(Predictor): + """ + A dummy predictor class that inherits from the base Predictor class. + + Attributes: + embedding_dims (int): The number of embedding dimensions. + Methods: + __init__(self, embedding_dims: int): Initializes a new instance of the DummyPredictor class. + create_model(self, architecture): Creates a model using the given architecture. + create_target(self, gt): Creates a target based on the ground truth. + create_weight(self, gt, target, mask, moving_class_counts=None): Creates a weight based on the ground truth, target, and mask. + output_array_type: Gets the output array type. + Notes: + This is a subclass of Predictor. + """ + def __init__(self, embedding_dims): + """ + Initializes a new instance of the DummyPredictor class. + + Args: + embedding_dims (int): The number of embedding dimensions. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor = DummyPredictor(embedding_dims) + """ self.embedding_dims = embedding_dims def create_model(self, architecture): + """ + Creates a model using the given architecture. + + Args: + architecture: The architecture to use for creating the model. + Returns: + Model: The created model. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> model = predictor.create_model(architecture) + """ head = torch.nn.Conv3d( architecture.num_out_channels, self.embedding_dims, kernel_size=3 ) @@ -19,6 +56,18 @@ def create_model(self, architecture): return Model(architecture, head) def create_target(self, gt): + """ + Creates a target based on the ground truth. + + Args: + gt: The ground truth. + Returns: + NumpyArray: The created target. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_target(gt) + """ # zeros return NumpyArray.from_np_array( np.zeros((self.embedding_dims,) + gt.data.shape[-gt.dims :]), @@ -28,6 +77,21 @@ def create_target(self, gt): ) def create_weight(self, gt, target, mask, moving_class_counts=None): + """ + Creates a weight based on the ground truth, target, and mask. + + Args: + gt: The ground truth. + target: The target. + mask: The mask. + moving_class_counts: The moving class counts. + Returns: + Tuple[NumpyArray, None]: The created weight and None. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_weight(gt, target, mask, moving_class_counts) + """ # ones return ( NumpyArray.from_np_array( @@ -41,4 +105,14 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): @property def output_array_type(self): + """ + Gets the output array type. + + Returns: + EmbeddingArray: The output array type. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.output_array_type + """ return EmbeddingArray(self.embedding_dims) diff --git a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py index 7aab7c21f..c25df23ec 100644 --- a/dacapo/experiments/tasks/predictors/hot_distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/hot_distance_predictor.py @@ -26,9 +26,44 @@ class HotDistancePredictor(Predictor): Multiple classes can be predicted via multiple distance channels. The names of each class that is being segmented can be passed in as a list of strings in the channels argument. + + Attributes: + channels: List of strings, each string is the name of a class that is being segmented. + scale_factor: The scale factor for the distance transform. + mask_distances: Whether to mask distances based on the distance to the boundary. + norm: The normalization function to use for the distance transform. + dt_scale_factor: The scale factor for the distance transform. + max_distance: The maximum distance to consider for the distance transform. + epsilon: The epsilon value to use for the distance transform. + threshold: The threshold value to use for the distance transform. + Methods: + __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): Initializes the HotDistancePredictor. + create_model(self, architecture): Create the model for the predictor. + create_target(self, gt): Create the target array for training. + create_weight(self, gt, target, mask, moving_class_counts=None): Create the weight array for training. + create_distance_mask(self, distances, mask, voxel_size, normalize=None, normalize_args=None): Create the distance mask for training. + process(self, labels, voxel_size, normalize=None, normalize_args=None): Process the labels array and convert it to one-hot encoding. + gt_region_for_roi(self, target_spec): Report how much spatial context this predictor needs to generate a target for the given ROI. + padding(self, gt_voxel_size): Return the padding needed for the ground-truth + Notes: + This is a subclass of Predictor. """ def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool): + """ + Initializes the HotDistancePredictor. + + Args: + channels (List[str]): The list of class labels. + scale_factor (float): The scale factor for the distance transform. + mask_distances (bool): Whether to mask distances based on the distance to the boundary. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor = HotDistancePredictor(channels, scale_factor, mask_distances) + Note: + The channels argument is a list of strings, each string is the name of a class that is being segmented. + """ self.channels = ( channels * 2 ) # one hot + distance (TODO: add hot/distance to channel names) @@ -42,13 +77,46 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo @property def embedding_dims(self): + """ + Get the number of embedding dimensions. + + Returns: + int: The number of embedding dimensions. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> embedding_dims = predictor.embedding_dims + """ return len(self.channels) @property def classes(self): + """ + Get the number of classes. + + Returns: + int: The number of classes. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> classes = predictor.classes + + """ return len(self.channels) // 2 def create_model(self, architecture): + """ + Create the model for the predictor. + + Args: + architecture: The architecture for the model. + Returns: + Model: The created model. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> model = predictor.create_model(architecture) + """ if architecture.dims == 2: head = torch.nn.Conv2d( architecture.num_out_channels, self.embedding_dims, kernel_size=3 @@ -61,6 +129,18 @@ def create_model(self, architecture): return Model(architecture, head) def create_target(self, gt): + """ + Create the target array for training. + + Args: + gt: The ground truth array. + Returns: + NumpyArray: The created target array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> target = predictor.create_target(gt) + """ target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor) return NumpyArray.from_np_array( target, @@ -70,6 +150,22 @@ def create_target(self, gt): ) def create_weight(self, gt, target, mask, moving_class_counts=None): + """ + Create the weight array for training, given a ground-truth and + associated target array. + + Args: + gt: The ground-truth array. + target: The target array. + mask: The mask array. + moving_class_counts: The moving class counts. + Returns: + Tuple[NumpyArray, np.ndarray]: The weight array and the moving class counts. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_weight(gt, target, mask, moving_class_counts) + """ # balance weights independently for each channel one_hot_weights, one_hot_moving_class_counts = balance_weights( gt[target.roi], @@ -122,6 +218,18 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): @property def output_array_type(self): + """ + Get the output array type. + + Returns: + ProbabilityArray: The output array type. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> output_array_type = predictor.output_array_type + Notes: + Technically this is a probability array + distance array, but it is only ever referenced for interpolatability (which is true for both). + """ # technically this is a probability array + distance array, but it is only ever referenced for interpolatability (which is true for both) (TODO) return ProbabilityArray(self.embedding_dims) @@ -133,6 +241,22 @@ def create_distance_mask( normalize=None, normalize_args=None, ): + """ + Create the distance mask for training. + + Args: + distances: The distances array. + mask: The mask array. + voxel_size: The voxel size. + normalize: The normalization function to use. + normalize_args: The normalization arguments. + Returns: + np.ndarray: The distance mask. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> distance_mask = self.create_distance_mask(distances, mask, voxel_size, normalize, normalize_args) + """ mask_output = mask.copy() for i, (channel_distance, channel_mask) in enumerate(zip(distances, mask)): tmp = np.zeros( @@ -192,6 +316,22 @@ def process( normalize=None, normalize_args=None, ): + """ + Process the labels array and convert it to one-hot encoding. + + Args: + labels: The labels array. + voxel_size: The voxel size. + normalize: The normalization function to use. + normalize_args: The normalization arguments. + Returns: + np.ndarray: The one-hot encoded array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.process(labels, voxel_size, normalize, normalize_args) + + """ all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 for ii, channel in enumerate(labels): boundaries = self.__find_boundaries(channel) @@ -229,6 +369,20 @@ def process( return np.concatenate((labels, all_distances)) def __find_boundaries(self, labels): + """ + Find the boundaries in the labels array. + + Args: + labels: The labels array. + Returns: + The boundaries array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> boundaries = self.__find_boundaries(labels) + Notes: + Assumes labels has a singleton channel dim and channel dim is first. + """ # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 @@ -267,6 +421,20 @@ def __find_boundaries(self, labels): return boundaries def __normalize(self, distances, norm, normalize_args): + """ + Normalize the distances. + + Args: + distances: The distances to normalize. + norm: The normalization function to use. + normalize_args: The normalization arguments. + Returns: + The normalized distances. + Raises: + ValueError: Only tanh is supported for normalization. + Examples: + >>> normalized_distances = self.__normalize(distances, norm, normalize_args) + """ if norm == "tanh": scale = normalize_args return np.tanh(distances / scale) @@ -274,6 +442,20 @@ def __normalize(self, distances, norm, normalize_args): raise ValueError("Only tanh is supported for normalization") def gt_region_for_roi(self, target_spec): + """ + Report how much spatial context this predictor needs to generate a + target for the given ROI. By default, uses the same ROI. + + Args: + target_spec: The ROI for which the target is requested. + Returns: + The ROI for which the ground-truth is requested. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.gt_region_for_roi(target_spec) + + """ if self.mask_distances: gt_spec = target_spec.copy() gt_spec.roi = gt_spec.roi.grow( @@ -285,4 +467,16 @@ def gt_region_for_roi(self, target_spec): return gt_spec def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + """ + Return the padding needed for the ground-truth array. + + Args: + gt_voxel_size: The voxel size of the ground-truth array. + Returns: + Coordinate: The padding needed for the ground-truth array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.padding(gt_voxel_size) + """ return Coordinate((self.max_distance,) * gt_voxel_size.dims) diff --git a/dacapo/experiments/tasks/predictors/inner_distance_predictor.py b/dacapo/experiments/tasks/predictors/inner_distance_predictor.py index b91e790d4..7e168c0e5 100644 --- a/dacapo/experiments/tasks/predictors/inner_distance_predictor.py +++ b/dacapo/experiments/tasks/predictors/inner_distance_predictor.py @@ -26,9 +26,37 @@ class InnerDistancePredictor(Predictor): Multiple classes can be predicted via multiple distance channels. The names of each class that is being segmented can be passed in as a list of strings in the channels argument. + + Attributes: + channels (List[str]): The list of channel names. + scale_factor (float): The amount by which to scale distances before applying a tanh normalization. + Methods: + __init__(self, channels: List[str], scale_factor: float): Initializes the InnerDistancePredictor. + create_model(self, architecture): Create the model for the predictor. + create_target(self, gt): Create the target array for training. + create_weight(self, gt, target, mask, moving_class_counts=None): Create the weight array for training. + output_array_type: Get the output array type. + process(self, labels: np.ndarray, voxel_size: Coordinate, normalize=None, normalize_args=None): Process the labels array and convert it to signed distances. + __find_boundaries(self, labels): Find the boundaries in a labels array. + __normalize(self, distances, norm, normalize_args): Normalize the distances. + gt_region_for_roi(self, target_spec): Get the ground-truth region for the given ROI. + padding(self, gt_voxel_size: Coordinate) -> Coordinate: Get the padding needed for the ground-truth array. + Notes: + This is a subclass of Predictor. """ def __init__(self, channels: List[str], scale_factor: float): + """ + Initialize the InnerDistancePredictor. + + Args: + channels (List[str]): The list of channel names. + scale_factor (float): The amount by which to scale distances before applying a tanh normalization. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor = InnerDistancePredictor(channels, scale_factor) + """ self.channels = channels self.norm = "tanh" self.dt_scale_factor = scale_factor @@ -39,9 +67,29 @@ def __init__(self, channels: List[str], scale_factor: float): @property def embedding_dims(self): + """ + Get the number of embedding dimensions. + + Returns: + int: The number of embedding dimensions. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> embedding_dims = predictor.embedding_dims + """ return len(self.channels) def create_model(self, architecture): + """ + Create the model for the predictor. + + Args: + architecture: The architecture for the model. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> model = predictor.create_model(architecture) + """ if architecture.dims == 2: head = torch.nn.Conv2d( architecture.num_out_channels, self.embedding_dims, kernel_size=1 @@ -54,6 +102,19 @@ def create_model(self, architecture): return Model(architecture, head) def create_target(self, gt): + """ + Create the target array for training. + + Args: + gt: The ground-truth array. + Returns: + The DistanceArray. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_target(gt) + + """ distances = self.process( gt.data, gt.voxel_size, self.norm, self.dt_scale_factor ) @@ -65,6 +126,23 @@ def create_target(self, gt): ) def create_weight(self, gt, target, mask, moving_class_counts=None): + """ + Create the weight array for training, given a ground-truth and + associated target array. + + Args: + gt: The ground-truth array. + target: The target array. + mask: The mask array. + moving_class_counts: The moving class counts. + Returns: + The weight array and the moving class counts. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_weight(gt, target, mask, moving_class_counts) + + """ # balance weights independently for each channel weights, moving_class_counts = balance_weights( @@ -86,6 +164,17 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): @property def output_array_type(self): + """ + Get the output array type. + + Returns: + The DistanceArray. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.output_array_type + + """ return DistanceArray(self.embedding_dims) def process( @@ -95,6 +184,22 @@ def process( normalize=None, normalize_args=None, ): + """ + Process the labels array and convert it to signed distances. + + Args: + labels: The labels array. + voxel_size: The voxel size. + normalize: The normalization method. + normalize_args: The normalization arguments. + Returns: + The signed distances. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.process(labels, voxel_size, normalize, normalize_args) + + """ all_distances = np.zeros(labels.shape, dtype=np.float32) - 1 for ii, channel in enumerate(labels): boundaries = self.__find_boundaries(channel) @@ -132,6 +237,19 @@ def process( return all_distances * labels def __find_boundaries(self, labels): + """ + Find boundaries in a labels array. + + Args: + labels: The labels array. + Returns: + The boundaries array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.__find_boundaries(labels) + + """ # labels: 1 1 1 1 0 0 2 2 2 2 3 3 n # shift : 1 1 1 1 0 0 2 2 2 2 3 n - 1 # diff : 0 0 0 1 0 1 0 0 0 1 0 n - 1 @@ -170,6 +288,20 @@ def __find_boundaries(self, labels): return boundaries def __normalize(self, distances, norm, normalize_args): + """ + Normalize distances. + + Args: + distances: The distances to normalize. + norm: The normalization method. + normalize_args: The normalization arguments. + Returns: + The normalized distances. + Raises: + ValueError: If the normalization method is not supported. + Examples: + >>> predictor.__normalize(distances, norm, normalize_args) + """ if norm == "tanh": scale = normalize_args return np.tanh(distances / scale) @@ -177,6 +309,20 @@ def __normalize(self, distances, norm, normalize_args): raise ValueError("Only tanh is supported for normalization") def gt_region_for_roi(self, target_spec): + """ + Report how much spatial context this predictor needs to generate a + target for the given ROI. By default, uses the same ROI. + + Args: + target_spec: The ROI for which the target is requested. + Returns: + The ROI for which the ground-truth is requested. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.gt_region_for_roi(target_spec) + + """ if self.mask_distances: gt_spec = target_spec.copy() gt_spec.roi = gt_spec.roi.grow( @@ -188,4 +334,16 @@ def gt_region_for_roi(self, target_spec): return gt_spec def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + """ + Return the padding needed for the ground-truth array. + + Args: + gt_voxel_size: The voxel size of the ground-truth array. + Returns: + The padding needed for the ground-truth array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.padding(gt_voxel_size) + """ return Coordinate((self.max_distance,) * gt_voxel_size.dims) diff --git a/dacapo/experiments/tasks/predictors/one_hot_predictor.py b/dacapo/experiments/tasks/predictors/one_hot_predictor.py index 7aa55936a..abf90be7e 100644 --- a/dacapo/experiments/tasks/predictors/one_hot_predictor.py +++ b/dacapo/experiments/tasks/predictors/one_hot_predictor.py @@ -13,14 +13,62 @@ class OneHotPredictor(Predictor): + """ + A predictor that uses one-hot encoding for classification tasks. + + Attributes: + classes (List[str]): The list of class labels. + Methods: + __init__(self, classes: List[str]): Initializes the OneHotPredictor. + create_model(self, architecture): Create the model for the predictor. + create_target(self, gt): Create the target array for training. + create_weight(self, gt, target, mask, moving_class_counts=None): Create the weight array for training. + output_array_type: Get the output array type. + process(self, labels: np.ndarray): Process the labels array and convert it to one-hot encoding. + Notes: + This is a subclass of Predictor. + """ + def __init__(self, classes: List[str]): + """ + Initialize the OneHotPredictor. + + Args: + classes (List[str]): The list of class labels. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor = OneHotPredictor(classes) + """ self.classes = classes @property def embedding_dims(self): + """ + Get the number of embedding dimensions. + + Returns: + int: The number of embedding dimensions. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> embedding_dims = predictor.embedding_dims + """ return len(self.classes) def create_model(self, architecture): + """ + Create the model for the predictor. + + Args: + architecture: The architecture for the model. + Returns: + Model: The created model. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> model = predictor.create_model(architecture) + """ head = torch.nn.Conv3d( architecture.num_out_channels, self.embedding_dims, kernel_size=3 ) @@ -28,6 +76,19 @@ def create_model(self, architecture): return Model(architecture, head) def create_target(self, gt): + """ + Create the target array for training. + + Args: + gt: The ground truth array. + Returns: + NumpyArray: The created target array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> target = predictor.create_target(gt) + + """ one_hots = self.process(gt.data) return NumpyArray.from_np_array( one_hots, @@ -37,6 +98,22 @@ def create_target(self, gt): ) def create_weight(self, gt, target, mask, moving_class_counts=None): + """ + Create the weight array for training. + + Args: + gt: The ground truth array. + target: The target array. + mask: The mask array. + moving_class_counts: The moving class counts. + Returns: + Tuple[NumpyArray, None]: The created weight array and None. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_weight(gt, target, mask, moving_class_counts) + + """ return ( NumpyArray.from_np_array( np.ones(target.data.shape), @@ -49,12 +126,36 @@ def create_weight(self, gt, target, mask, moving_class_counts=None): @property def output_array_type(self): + """ + Get the output array type. + + Returns: + ProbabilityArray: The output array type. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> output_array_type = predictor.output_array_type + """ return ProbabilityArray(self.classes) def process( self, labels: np.ndarray, ): + """ + Process the labels array and convert it to one-hot encoding. + + Args: + labels (np.ndarray): The labels array. + Returns: + np.ndarray: The one-hot encoded array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> one_hots = predictor.process(labels) + Notes: + Assumes labels has a singleton channel dim and channel dim is first. + """ # TODO: Assumes labels has a singleton channel dim and channel dim is first one_hots = np.zeros((self.embedding_dims,) + labels.shape[1:], dtype=np.uint8) for i, _ in enumerate(self.classes): diff --git a/dacapo/experiments/tasks/predictors/predictor.py b/dacapo/experiments/tasks/predictors/predictor.py index 166156f31..8c1dce00d 100644 --- a/dacapo/experiments/tasks/predictors/predictor.py +++ b/dacapo/experiments/tasks/predictors/predictor.py @@ -10,17 +10,47 @@ class Predictor(ABC): + """ + A predictor is a class that defines how to train a model to predict a + certain output from a certain input. + + A predictor is responsible for creating the model, the target, the weight, + and the output array type for a given training architecture. + + Methods: + create_model(self, architecture: "Architecture") -> "Model": Given a training architecture, create a model for this predictor. + create_target(self, gt: "Array") -> "Array": Create the target array for training, given a ground-truth array. + create_weight(self, gt: "Array", target: "Array", mask: "Array", moving_class_counts: Any) -> Tuple["Array", Any]: Create the weight array for training, given a ground-truth and associated target array. + gt_region_for_roi(self, target_spec): Report how much spatial context this predictor needs to generate a target for the given ROI. + padding(self, gt_voxel_size: Coordinate) -> Coordinate: Return the padding needed for the ground-truth array. + Notes: + This is a subclass of ABC. + """ + @abstractmethod def create_model(self, architecture: "Architecture") -> "Model": - """Given a training architecture, create a model for this predictor. + """ + Given a training architecture, create a model for this predictor. This is usually done by appending extra layers to the output of the architecture to get the output tensor of the architecture into the - right shape for this predictor.""" + right shape for this predictor. + + Args: + architecture: The training architecture. + Returns: + The model. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_model(architecture) + + """ pass @abstractmethod def create_target(self, gt: "Array") -> "Array": - """Create the target array for training, given a ground-truth array. + """ + Create the target array for training, given a ground-truth array. In general, the target is different from the ground-truth. @@ -37,6 +67,16 @@ def create_target(self, gt: "Array") -> "Array": (e.g., because it predicts the distance to a boundary, up to a certain threshold), you can request a larger ground-truth region. See method ``gt_region_for_roi``. + + Args: + gt: The ground-truth array. + Returns: + The target array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_target(gt) + """ pass @@ -48,8 +88,22 @@ def create_weight( mask: "Array", moving_class_counts: Any, ) -> Tuple["Array", Any]: - """Create the weight array for training, given a ground-truth and + """ + Create the weight array for training, given a ground-truth and associated target array. + + Args: + gt: The ground-truth array. + target: The target array. + mask: The mask array. + moving_class_counts: The moving class counts. + Returns: + The weight array and the moving class counts. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.create_weight(gt, target, mask, moving_class_counts) + """ pass @@ -59,12 +113,37 @@ def output_array_type(self): pass def gt_region_for_roi(self, target_spec): - """Report how much spatial context this predictor needs to generate a + """ + Report how much spatial context this predictor needs to generate a target for the given ROI. By default, uses the same ROI. Overwrite this method to request ground-truth in a larger ROI, as - needed.""" + needed. + + Args: + target_spec: The ROI for which the target is requested. + Returns: + The ROI for which the ground-truth is requested. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.gt_region_for_roi(target_spec) + + + """ return target_spec def padding(self, gt_voxel_size: Coordinate) -> Coordinate: + """ + Return the padding needed for the ground-truth array. + + Args: + gt_voxel_size: The voxel size of the ground-truth array. + Returns: + The padding needed for the ground-truth array. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> predictor.padding(gt_voxel_size) + """ return Coordinate((0,) * gt_voxel_size.dims) diff --git a/dacapo/experiments/tasks/pretrained_task.py b/dacapo/experiments/tasks/pretrained_task.py index 1be9b57c0..b8ae83fd5 100644 --- a/dacapo/experiments/tasks/pretrained_task.py +++ b/dacapo/experiments/tasks/pretrained_task.py @@ -4,7 +4,30 @@ class PretrainedTask(Task): + """ + A task that uses a pretrained model. The model is loaded from a file + and the weights are loaded from a file. + + Attributes: + weights (Path): The path to the weights file. + Methods: + create_model(self, architecture) -> Model: This method creates a model using the given architecture. + Notes: + This is a base class for all tasks that use pretrained models. + + """ + def __init__(self, task_config): + """ + Initialize the PretrainedTask object. + + Args: + task_config (TaskConfig): The configuration of the task. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> task = PretrainedTask(task_config) + """ sub_task = task_config.sub_task_config.task_type(task_config.sub_task_config) self.weights = task_config.weights @@ -14,6 +37,19 @@ def __init__(self, task_config): self.evaluator = sub_task.evaluator def create_model(self, architecture): + """ + Create a model using the given architecture. + + Args: + architecture (str): The architecture of the model. + Returns: + Model: The model created using the given architecture. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> model = task.create_model(architecture) + + """ model = self.predictor.create_model(architecture) saved_state_dict = torch.load(str(self.weights)) diff --git a/dacapo/experiments/tasks/pretrained_task_config.py b/dacapo/experiments/tasks/pretrained_task_config.py index 6f7263a21..195ccd742 100644 --- a/dacapo/experiments/tasks/pretrained_task_config.py +++ b/dacapo/experiments/tasks/pretrained_task_config.py @@ -3,12 +3,23 @@ from .pretrained_task import PretrainedTask from .task_config import TaskConfig -from pathlib import Path +from upath import UPath as Path @attr.s class PretrainedTaskConfig(TaskConfig): - """ """ + """ + Configuration for a task that uses a pretrained model. The model is loaded from a file + and the weights are loaded from a file. + + Attributes: + sub_task_config (TaskConfig): The task to run starting with the provided pretrained weights. + weights (Path): A checkpoint containing pretrained model weights. + Methods: + verify(self) -> Tuple[bool, str]: This method verifies the PretrainedTaskConfig object. + Notes: + This is a subclass of TaskConfig. + """ task_type = PretrainedTask diff --git a/dacapo/experiments/tasks/task_config.py b/dacapo/experiments/tasks/task_config.py index bdfbe8579..4afef4ec1 100644 --- a/dacapo/experiments/tasks/task_config.py +++ b/dacapo/experiments/tasks/task_config.py @@ -5,8 +5,19 @@ @attr.s class TaskConfig: - """Base class for task configurations. Each subclass of a `Task` should + """ + Base class for task configurations. Each subclass of a `Task` should have a corresponding config class derived from `TaskConfig`. + + Attributes: + name: A unique name for this task. This will be saved so you and + others can find and reuse this task. Keep it short and avoid + special characters. + Methods: + verify(self) -> Tuple[bool, str]: This method verifies the TaskConfig object. + Notes: + This is a base class for all task configurations. It is not meant to be + used directly. """ name: str = attr.ib( @@ -20,5 +31,13 @@ class TaskConfig: def verify(self) -> Tuple[bool, str]: """ Check whether this is a valid Task + + Returns: + Tuple[bool, str]: A tuple containing a boolean value indicating whether the TaskConfig object is valid + and a string containing the reason why the object is invalid. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> valid, reason = task_config.verify() """ return True, "No validation for this Task" diff --git a/dacapo/experiments/trainers/dummy_trainer.py b/dacapo/experiments/trainers/dummy_trainer.py index 7e826979a..6c35424d0 100644 --- a/dacapo/experiments/trainers/dummy_trainer.py +++ b/dacapo/experiments/trainers/dummy_trainer.py @@ -7,17 +7,77 @@ class DummyTrainer(Trainer): + """ + This class is used to train a model using dummy data and is used for testing purposes. It contains attributes + related to learning rate, batch size, and mirror augment. It also contains methods to create an optimizer, iterate + over the training data, build a batch provider, and check if the trainer can train on the given data split. This class + contains methods to enter and exit the context manager. The iterate method yields training iteration statistics. + + Attributes: + learning_rate (float): The learning rate to use. + batch_size (int): The batch size to use. + mirror_augment (bool): A boolean value indicating whether to use mirror augmentation or not. + Methods: + __init__(self, trainer_config): This method initializes the DummyTrainer object. + create_optimizer(self, model): This method creates an optimizer for the given model. + iterate(self, num_iterations: int, model, optimizer, device): This method iterates over the training data for the specified number of iterations. + build_batch_provider(self, datasplit, architecture, task, snapshot_container): This method builds a batch provider for the given data split, architecture, task, and snapshot container. + can_train(self, datasplit): This method checks if the trainer can train on the given data split. + __enter__(self): This method enters the context manager. + __exit__(self, exc_type, exc_val, exc_tb): This method exits the context manager. + Note: + The iterate method yields TrainingIterationStats. + """ + iteration = 0 def __init__(self, trainer_config): + """ + Initialize the DummyTrainer object. + + Args: + trainer_config (TrainerConfig): The configuration object for the trainer. + Returns: + DummyTrainer: The DummyTrainer object. + Examples: + >>> trainer = DummyTrainer(trainer_config) + + """ self.learning_rate = trainer_config.learning_rate self.batch_size = trainer_config.batch_size self.mirror_augment = trainer_config.mirror_augment def create_optimizer(self, model): + """ + Create an optimizer for the given model. + + Args: + model (Model): The model to optimize. + Returns: + torch.optim.Optimizer: The optimizer object. + Examples: + >>> optimizer = create_optimizer(model) + + """ return torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) def iterate(self, num_iterations: int, model: Model, optimizer, device): + """ + Iterate over the training data for the specified number of iterations. + + Args: + num_iterations (int): The number of iterations to perform. + model (Model): The model to train. + optimizer (torch.optim.Optimizer): The optimizer to use. + device (torch.device): The device to perform the computations on. + Yields: + TrainingIterationStats: The training iteration statistics. + Raises: + ValueError: If the number of iterations is less than or equal to zero. + Examples: + >>> for stats in iterate(num_iterations, model, optimizer, device): + >>> print(stats) + """ target_iteration = self.iteration + num_iterations for iteration in range(self.iteration, target_iteration): @@ -46,13 +106,66 @@ def iterate(self, num_iterations: int, model: Model, optimizer, device): self.iteration += 1 def build_batch_provider(self, datasplit, architecture, task, snapshot_container): + """ + Build a batch provider for the given data split, architecture, task, and snapshot container. + + Args: + datasplit (DataSplit): The data split to use. + architecture (Architecture): The architecture to use. + task (Task): The task to perform. + snapshot_container (SnapshotContainer): The snapshot container to use. + Returns: + BatchProvider: The batch provider object. + Raises: + ValueError: If the task loss is not set. + Examples: + >>> batch_provider = build_batch_provider(datasplit, architecture, task, snapshot_container) + + """ self._loss = task.loss def can_train(self, datasplit): + """ + Check if the trainer can train on the given data split. + + Args: + datasplit (DataSplit): The data split to check. + Returns: + bool: True if the trainer can train on the data split, False otherwise. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> can_train(datasplit) + + """ return True def __enter__(self): + """ + Enter the context manager. + + Returns: + DummyTrainer: The trainer object. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> with trainer as t: + >>> print(t) + """ return self def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit the context manager. + + Args: + exc_type: The type of the exception. + exc_val: The exception value. + exc_tb: The exception traceback. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> with trainer as t: + >>> print(t) + """ pass diff --git a/dacapo/experiments/trainers/dummy_trainer_config.py b/dacapo/experiments/trainers/dummy_trainer_config.py index b6b64412f..5a6a92029 100644 --- a/dacapo/experiments/trainers/dummy_trainer_config.py +++ b/dacapo/experiments/trainers/dummy_trainer_config.py @@ -8,12 +8,31 @@ @attr.s class DummyTrainerConfig(TrainerConfig): - """This is just a dummy trainer config used for testing. None of the - attributes have any particular meaning.""" + """ + This is just a dummy trainer config used for testing. None of the + attributes have any particular meaning. This is just to test the trainer + and the trainer config. + + Attributes: + mirror_augment (bool): A boolean value indicating whether to use mirror + augmentation or not. + Methods: + verify(self) -> Tuple[bool, str]: This method verifies the DummyTrainerConfig object. + + """ trainer_type = DummyTrainer mirror_augment: bool = attr.ib(metadata={"help_text": "Dummy attribute."}) def verify(self) -> Tuple[bool, str]: + """ + Verify the DummyTrainerConfig object. + + Returns: + Tuple[bool, str]: A tuple containing a boolean value indicating whether the DummyTrainerConfig object is valid + and a string containing the reason why the object is invalid. + Examples: + >>> valid, reason = trainer_config.verify() + """ return False, "This is a DummyTrainerConfig and is never valid" diff --git a/dacapo/experiments/trainers/gp_augments/augment_config.py b/dacapo/experiments/trainers/gp_augments/augment_config.py index c46e2a1ee..dcca4142b 100644 --- a/dacapo/experiments/trainers/gp_augments/augment_config.py +++ b/dacapo/experiments/trainers/gp_augments/augment_config.py @@ -10,6 +10,14 @@ class AugmentConfig(ABC): """ Base class for gunpowder augment configurations. Each subclass of a `Augment` should have a corresponding config class derived from `AugmentConfig`. + + Attributes: + _raw_key: Key for raw data. Not used in this implementation. Defaults to None. + _gt_key: Key for ground truth data. Not used in this implementation. Defaults to None. + _mask_key: Key for mask data. Not used in this implementation. Defaults to None. + Methods: + node(_raw_key=None, _gt_key=None, _mask_key=None): Get a gp.Augment node. + """ @abstractmethod @@ -17,6 +25,18 @@ def node( self, raw_key: gp.ArrayKey, gt_key: gp.ArrayKey, mask_key: gp.ArrayKey ) -> gp.BatchFilter: """ - return a gunpowder node that performs this augmentation + Get a gunpowder augment node. + + Args: + raw_key (gp.ArrayKey): Key for raw data. + gt_key (gp.ArrayKey): Key for ground truth data. + mask_key (gp.ArrayKey): Key for mask data. + Returns: + gunpowder.BatchFilter : Augmentation node which can be incorporated in the pipeline. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> node = augment_config.node(raw_key, gt_key, mask_key) + """ pass diff --git a/dacapo/experiments/trainers/gp_augments/elastic_config.py b/dacapo/experiments/trainers/gp_augments/elastic_config.py index 55d25a02a..40f40e800 100644 --- a/dacapo/experiments/trainers/gp_augments/elastic_config.py +++ b/dacapo/experiments/trainers/gp_augments/elastic_config.py @@ -22,6 +22,10 @@ class ElasticAugmentConfig(AugmentConfig): on a grid. Default is 1. uniform_3d_rotation (bool): Should 3D rotations be performed uniformly. The 'rotation_interval' will be ignored. Default is False. + Methods: + node(_raw_key=None, _gt_key=None, _mask_key=None): Returns the object of ElasticAugment with the given + configuration details. + """ control_point_spacing: List[int] = attr.ib( @@ -67,11 +71,15 @@ def node(self, _raw_key=None, _gt_key=None, _mask_key=None): _raw_key: Unused variable, kept for future use. _gt_key: Unused variable, kept for future use. _mask_key: Unused variable, kept for future use. - Returns: ElasticAugment: A ElasticAugment object configured with `control_point_spacing`, `control_point_displacement_sigma`, `rotation_interval`, `subsample` and `uniform_3d_rotation`. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> node = elastic_augment_config.node() + """ return ElasticAugment( control_point_spacing=self.control_point_spacing, diff --git a/dacapo/experiments/trainers/gp_augments/gamma_config.py b/dacapo/experiments/trainers/gp_augments/gamma_config.py index 434a2986d..8eed470b5 100644 --- a/dacapo/experiments/trainers/gp_augments/gamma_config.py +++ b/dacapo/experiments/trainers/gp_augments/gamma_config.py @@ -16,7 +16,6 @@ class GammaAugmentConfig(AugmentConfig): Attributes: gamma_range: A tuple of float values represents the min and max range of gamma noise to apply on the raw data. - Methods: node(): Constructs a node in the augmentation pipeline. """ @@ -35,9 +34,12 @@ def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): raw_key (gp.ArrayKey): Key to an Array (volume) in the pipeline _gt_key (gp.ArrayKey, optional): Ground Truth key, not used in this function. Defaults to None. _mask_key (gp.ArrayKey, optional): Mask Key, not used in this function. Defaults to None. - Returns: GammaAugment instance: The augmentation method to be applied on the source data. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> node = gamma_augment_config.node(raw_key) """ return GammaAugment( [raw_key], gamma_min=self.gamma_range[0], gamma_max=self.gamma_range[1] diff --git a/dacapo/experiments/trainers/gp_augments/intensity_config.py b/dacapo/experiments/trainers/gp_augments/intensity_config.py index 105336be8..fef1b26df 100644 --- a/dacapo/experiments/trainers/gp_augments/intensity_config.py +++ b/dacapo/experiments/trainers/gp_augments/intensity_config.py @@ -9,6 +9,18 @@ @attr.s class IntensityAugmentConfig(AugmentConfig): + """ + This class is an implementation of AugmentConfig that applies intensity augmentations. + + Attributes: + scale (Tuple[float, float]): A range within which to choose a random scale factor. + shift (Tuple[float, float]): A range within which to choose a random additive shift. + clip (bool): Set to False if modified values should not be clipped to [0, 1] + Methods: + node(raw_key, _gt_key=None, _mask_key=None): Get a gp.IntensityAugment node. + + """ + scale: Tuple[float, float] = attr.ib( metadata={"help_text": "A range within which to choose a random scale factor."} ) @@ -25,6 +37,20 @@ class IntensityAugmentConfig(AugmentConfig): ) def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): + """ + Get a gp.IntensityAugment node. + + Args: + raw_key (gp.ArrayKey): Key for raw data. + _gt_key ([type], optional): Specific key for ground truth data, not used in this implementation. Defaults to None. + _mask_key ([type], optional): Specific key for mask data, not used in this implementation. Defaults to None. + Returns: + gunpowder.IntensityAugment : Intensity augmentation node which can be incorporated in the pipeline. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> node = intensity_augment_config.node(raw_key) + """ return gp.IntensityAugment( raw_key, scale_min=self.scale[0], diff --git a/dacapo/experiments/trainers/gp_augments/intensity_scale_shift_config.py b/dacapo/experiments/trainers/gp_augments/intensity_scale_shift_config.py index 081b15066..4a033406b 100644 --- a/dacapo/experiments/trainers/gp_augments/intensity_scale_shift_config.py +++ b/dacapo/experiments/trainers/gp_augments/intensity_scale_shift_config.py @@ -7,6 +7,18 @@ @attr.s class IntensityScaleShiftAugmentConfig(AugmentConfig): + """ + This class is an implementation of AugmentConfig that applies intensity scaling and shifting. + + Attributes: + scale (float): A constant to scale your intensities. + shift (float): A constant to shift your intensities. + Methods: + node(raw_key, _gt_key=None, _mask_key=None): Get a gp.IntensityScaleShift node. + Note: + This class is a subclass of AugmentConfig. + """ + scale: float = attr.ib( metadata={"help_text": "A constant to scale your intensities."} ) @@ -15,4 +27,18 @@ class IntensityScaleShiftAugmentConfig(AugmentConfig): ) def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): + """ + Get a gp.IntensityScaleShift node. + + Args: + raw_key (gp.ArrayKey): Key for raw data. + _gt_key ([type], optional): Specific key for ground truth data, not used in this implementation. Defaults to None. + _mask_key ([type], optional): Specific key for mask data, not used in this implementation. Defaults to None. + Returns: + gunpowder.IntensityScaleShift : Intensity scaling and shifting node which can be incorporated in the pipeline. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> node = intensity_scale_shift_augment_config.node(raw_key) + """ return gp.IntensityScaleShift(raw_key, scale=self.scale, shift=self.shift) diff --git a/dacapo/experiments/trainers/gp_augments/simple_config.py b/dacapo/experiments/trainers/gp_augments/simple_config.py index f214883df..77c8e6e5a 100644 --- a/dacapo/experiments/trainers/gp_augments/simple_config.py +++ b/dacapo/experiments/trainers/gp_augments/simple_config.py @@ -14,9 +14,10 @@ class SimpleAugmentConfig(AugmentConfig): _raw_key: Key for raw data. Not used in this implementation. Defaults to None. _gt_key: Key for ground truth data. Not used in this implementation. Defaults to None. _mask_key: Key for mask data. Not used in this implementation. Defaults to None. - - Returns: - Gunpowder SimpleAugment Node: A node that can be included in a pipeline to perform simple data augmentations. + Methods: + node(_raw_key=None, _gt_key=None, _mask_key=None): Get a gp.SimpleAugment node. + Note: + This class is a subclass of AugmentConfig. """ def node(self, _raw_key=None, _gt_key=None, _mask_key=None): @@ -27,8 +28,12 @@ def node(self, _raw_key=None, _gt_key=None, _mask_key=None): _raw_key ([type], optional): Specific key for raw data, not used in this implementation. Defaults to None. _gt_key ([type], optional): Specific key for ground truth data, not used in this implementation. Defaults to None. _mask_key ([type], optional): Specific key for mask data, not used in this implementation. Defaults to None. - Returns: gunpowder.SimpleAugment : Simple augmentation node which can be incorporated in the pipeline. + Raises: + NotImplementedError: This method is not implemented. + Examples: + >>> node = simple_augment_config.node() + """ return gp.SimpleAugment() diff --git a/dacapo/experiments/trainers/gunpowder_trainer.py b/dacapo/experiments/trainers/gunpowder_trainer.py index 74fb9b807..4916e557f 100644 --- a/dacapo/experiments/trainers/gunpowder_trainer.py +++ b/dacapo/experiments/trainers/gunpowder_trainer.py @@ -28,9 +28,59 @@ class GunpowderTrainer(Trainer): + """ + GunpowderTrainer class for training a model using gunpowder. This class is a subclass of the Trainer class. It + implements the abstract methods defined in the Trainer class. The GunpowderTrainer class is used to train a model + using gunpowder, a data loading and augmentation library. It is used to train a model on a dataset using a specific + task. + + Attributes: + learning_rate (float): The learning rate for the optimizer. + batch_size (int): The size of the training batch. + num_data_fetchers (int): The number of data fetchers. + print_profiling (int): The number of iterations after which to print profiling stats. + snapshot_iteration (int): The number of iterations after which to save a snapshot. + min_masked (float): The minimum value of the mask. + augments (List[Augment]): The list of augmentations to apply to the data. + mask_integral_downsample_factor (int): The downsample factor for the mask integral. + clip_raw (bool): Whether to clip the raw data. + scheduler (torch.optim.lr_scheduler.LinearLR): The learning rate scheduler. + Methods: + create_optimizer(model: Model) -> torch.optim.Optimizer: + Creates an optimizer for the model. + build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None: + Initializes the training pipeline using various components. + iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]: + Performs a number of training iterations. + __iter__() -> Iterator[None]: + Initializes the training pipeline. + next() -> Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]: + Fetches the next batch of data. + __enter__() -> GunpowderTrainer: + Enters the context manager. + __exit__(exc_type, exc_val, exc_tb) -> None: + Exits the context manager. + can_train(datasets: List[Dataset]) -> bool: + Checks if the trainer can train with a specific set of datasets. + Note: + The GunpowderTrainer class is a subclass of the Trainer class. It is used to train a model using gunpowder. + + """ + iteration = 0 def __init__(self, trainer_config): + """ + Initializes the GunpowderTrainer object. + + Args: + trainer_config (TrainerConfig): The trainer configuration. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> trainer = GunpowderTrainer(trainer_config) + + """ self.learning_rate = trainer_config.learning_rate self.batch_size = trainer_config.batch_size self.num_data_fetchers = trainer_config.num_data_fetchers @@ -45,7 +95,24 @@ def __init__(self, trainer_config): self.scheduler = None def create_optimizer(self, model): - optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) + """ + Creates an optimizer for the model. + + Args: + model (Model): The model for which the optimizer will be created. + Returns: + torch.optim.Optimizer: The optimizer created for the model. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> optimizer = trainer.create_optimizer(model) + + """ + optimizer = torch.optim.RAdam( + lr=self.learning_rate, + params=model.parameters(), + decoupled_weight_decay=True, + ) self.scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=0.01, @@ -56,6 +123,20 @@ def create_optimizer(self, model): return optimizer def build_batch_provider(self, datasets, model, task, snapshot_container=None): + """ + Initializes the training pipeline using various components. + + Args: + datasets (List[Dataset]): The list of datasets. + model (Model): The model to be trained. + task (Task): The task to be performed. + snapshot_container (LocalContainerIdentifier): The snapshot container. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> trainer.build_batch_provider(datasets, model, task, snapshot_container) + + """ input_shape = Coordinate(model.input_shape) output_shape = Coordinate(model.output_shape) @@ -194,6 +275,23 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None): self.snapshot_container = snapshot_container def iterate(self, num_iterations, model, optimizer, device): + """ + Performs a number of training iterations. + + Args: + num_iterations (int): The number of training iterations. + model (Model): The model to be trained. + optimizer (torch.optim.Optimizer): The optimizer for the model. + device (torch.device): The device (GPU/CPU) where the model will be trained. + Returns: + Iterator[TrainingIterationStats]: An iterator of the training statistics. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device): + >>> print(iteration_stats) + + """ t_start_fetch = time.time() print("Starting iteration!") @@ -296,6 +394,17 @@ def iterate(self, num_iterations, model, optimizer, device): t_start_fetch = time.time() def __iter__(self): + """ + Initializes the training pipeline. + + Returns: + Iterator[None]: An iterator of None. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> for _ in trainer: + >>> pass + """ with gp.build(self._pipeline): teardown = False while not teardown: @@ -305,6 +414,17 @@ def __iter__(self): yield None def next(self): + """ + Fetches the next batch of data. + + Returns: + Tuple[NumpyArray, NumpyArray, NumpyArray, NumpyArray, NumpyArray]: A tuple containing the raw data, ground truth data, target data, weight data, and mask data. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> raw, gt, target, weight, mask = trainer.next() + + """ batch = next(self._iter) self._iter.send(False) return ( @@ -320,10 +440,34 @@ def next(self): ) def __enter__(self): + """ + Enters the context manager. + + Returns: + GunpowderTrainer: The GunpowderTrainer object. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> with trainer: + >>> pass + """ self._iter = iter(self) return self def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exits the context manager. + + Args: + exc_type: The exception type. + exc_val: The exception value. + exc_tb: The exception traceback. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> with trainer: + >>> pass + """ try: self._iter.send(True) except TypeError: @@ -331,4 +475,17 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def can_train(self, datasets) -> bool: + """ + Checks if the trainer can train with a specific set of datasets. + + Args: + datasets (List[Dataset]): The list of datasets. + Returns: + bool: True if the trainer can train with the datasets, False otherwise. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> can_train = trainer.can_train(datasets) + + """ return all([dataset.gt is not None for dataset in datasets]) diff --git a/dacapo/experiments/trainers/trainer.py b/dacapo/experiments/trainers/trainer.py index 5e3bee653..5a8def054 100644 --- a/dacapo/experiments/trainers/trainer.py +++ b/dacapo/experiments/trainers/trainer.py @@ -12,11 +12,28 @@ class Trainer(ABC): - """Trainer Abstract Base Class + """ + Trainer Abstract Base Class This serves as the blueprint for any trainer classes in the dacapo library. It defines essential methods that every subclass must implement for effective training of a neural network model. + + Attributes: + iteration (int): The number of training iterations. + batch_size (int): The size of the training batch. + learning_rate (float): The learning rate for the optimizer. + Methods: + create_optimizer(model: Model) -> torch.optim.Optimizer: + Creates an optimizer for the model. + iterate(num_iterations: int, model: Model, optimizer: torch.optim.Optimizer, device: torch.device) -> Iterator[TrainingIterationStats]: + Performs a number of training iterations. + can_train(datasets: List[Dataset]) -> bool: + Checks if the trainer can train with a specific set of datasets. + build_batch_provider(datasets: List[Dataset], model: Model, task: Task, snapshot_container: LocalContainerIdentifier) -> None: + Initializes the training pipeline using various components. + Note: + The Trainer class is an abstract class that cannot be instantiated directly. It is meant to be subclassed. """ iteration: int @@ -25,13 +42,19 @@ class Trainer(ABC): @abstractmethod def create_optimizer(self, model: "Model") -> torch.optim.Optimizer: - """Creates an optimizer for the model. + """ + Creates an optimizer for the model. Args: model (Model): The model for which the optimizer will be created. - Returns: torch.optim.Optimizer: The optimizer created for the model. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> optimizer = trainer.create_optimizer(model) + Note: + This method must be implemented by the subclass. """ pass @@ -43,30 +66,43 @@ def iterate( optimizer: torch.optim.Optimizer, device: torch.device, ) -> Iterator["TrainingIterationStats"]: - """Performs a number of training iterations. + """ + Performs a number of training iterations. Args: num_iterations (int): Number of training iterations. model (Model): The model to be trained. optimizer (torch.optim.Optimizer): The optimizer for the model. device (torch.device): The device (GPU/CPU) where the model will be trained. - Returns: Iterator[TrainingIterationStats]: An iterator of the training statistics. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> for iteration_stats in trainer.iterate(num_iterations, model, optimizer, device): + >>> print(iteration_stats) + Note: + This method must be implemented by the subclass. """ pass @abstractmethod def can_train(self, datasets: List["Dataset"]) -> bool: - """Checks if the trainer can train with a specific set of datasets. + """ + Checks if the trainer can train with a specific set of datasets. Some trainers may have specific requirements for their training datasets. Args: datasets (List[Dataset]): The training datasets. - Returns: bool: True if the trainer can train on the given datasets, False otherwise. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> can_train = trainer.can_train(datasets) + Note: + This method must be implemented by the subclass. """ pass @@ -88,15 +124,44 @@ def build_batch_provider( model (Model): The model to inform the pipeline of required input/output sizes. task (Task): The task to transform ground truth into target. snapshot_container (LocalContainerIdentifier): Defines where snapshots will be saved. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> trainer.build_batch_provider(datasets, model, task, snapshot_container) + Note: + This method must be implemented by the subclass. """ pass @abstractmethod def __enter__(self): - """Defines the functionality of the '__enter__' method for use in a 'with' statement.""" + """ + Defines the functionality of the '__enter__' method for use in a 'with' statement. + + Returns: + Trainer: The trainer object. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> with trainer as t: + >>> print(t) + + """ return self @abstractmethod def __exit__(self, exc_type, exc_val, exc_tb): - """Defines the functionality of the '__exit__' method for use in a 'with' statement.""" + """ + Defines the functionality of the '__exit__' method for use in a 'with' statement. + + Args: + exc_type: The type of exception raised. + exc_val: The exception value. + exc_tb: The traceback. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> with trainer as t: + >>> print(t) + """ pass diff --git a/dacapo/experiments/trainers/trainer_config.py b/dacapo/experiments/trainers/trainer_config.py index d30781285..8158ace2f 100644 --- a/dacapo/experiments/trainers/trainer_config.py +++ b/dacapo/experiments/trainers/trainer_config.py @@ -15,6 +15,11 @@ class TrainerConfig: name (str): A unique name for this trainer. batch_size (int): The batch size to be used during training. learning_rate (float): The learning rate of the optimizer. + Methods: + verify() -> Tuple[bool, str]: + Verify whether this TrainerConfig is valid or not. + Note: + The TrainerConfig class is an abstract class that cannot be instantiated directly. It is meant to be subclassed. """ name: str = attr.ib( @@ -39,9 +44,20 @@ class TrainerConfig: def verify(self) -> Tuple[bool, str]: """ Verify whether this TrainerConfig is valid or not. + A TrainerConfig is considered valid if it has a valid batch size and learning rate. Returns: tuple: A tuple containing a boolean indicating whether the TrainerConfig is valid and a message explaining why. + Raises: + NotImplementedError: If the method is not implemented by the subclass. + Examples: + >>> valid, message = trainer_config.verify() + >>> valid + True + >>> message + "No validation for this Trainer" + Note: + This method must be implemented by the subclass. """ return True, "No validation for this Trainer" diff --git a/dacapo/experiments/training_iteration_stats.py b/dacapo/experiments/training_iteration_stats.py index 5d0507bbd..4f193a8eb 100644 --- a/dacapo/experiments/training_iteration_stats.py +++ b/dacapo/experiments/training_iteration_stats.py @@ -4,12 +4,16 @@ @attr.s class TrainingIterationStats: """ - A class to represent the training iteration statistics. + A class to represent the training iteration statistics. It contains the loss and time taken for each iteration. Attributes: iteration (int): The iteration that produced these stats. loss (float): The loss value of this iteration. time (float): The time it took to process this iteration. + Note: + The iteration stats list is structured as follows: + - The outer list contains the stats for each iteration. + - The inner list contains the stats for each training iteration. """ diff --git a/dacapo/experiments/training_stats.py b/dacapo/experiments/training_stats.py index b2f30abf5..eef5f2c97 100644 --- a/dacapo/experiments/training_stats.py +++ b/dacapo/experiments/training_stats.py @@ -10,12 +10,14 @@ @attr.s class TrainingStats: """ - A class used to represent Training Statistics. + A class used to represent Training Statistics. It contains a list of training + iteration statistics. It also provides methods to add new iteration stats, + delete stats after a specified iteration, get the number of iterations trained + for, and convert the stats to a xarray data array. Attributes: iteration_stats: List[TrainingIterationStats] an ordered list of training stats. - Methods: add_iteration_stats(iteration_stats: TrainingIterationStats) -> None: Add a new set of iterations stats to the existing list of iteration @@ -26,6 +28,10 @@ class TrainingStats: Gets the number of iterations that the model has been trained for. to_xarray() -> xr.DataArray: Converts the iteration statistics to a xarray data array. + Note: + The iteration stats list is structured as follows: + - The outer list contains the stats for each iteration. + - The inner list contains the stats for each training iteration. """ iteration_stats: List[TrainingIterationStats] = attr.ib( @@ -39,9 +45,21 @@ def add_iteration_stats(self, iteration_stats: TrainingIterationStats) -> None: Args: iteration_stats (TrainingIterationStats): a new iteration stats object. - Raises: assert: if the new iteration stats do not follow the order of existing iteration stats. + Examples: + >>> training_stats = TrainingStats() + >>> training_stats.add_iteration_stats(TrainingIterationStats(0, 0.1)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(1, 0.2)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(2, 0.3)) + >>> training_stats.iteration_stats + [TrainingIterationStats(iteration=0, loss=0.1), + TrainingIterationStats(iteration=1, loss=0.2), + TrainingIterationStats(iteration=2, loss=0.3)] + Note: + The iteration stats list is structured as follows: + - The outer list contains the stats for each iteration. + - The inner list contains the stats for each training iteration. """ if len(self.iteration_stats) > 0: assert ( @@ -56,6 +74,20 @@ def delete_after(self, iteration: int) -> None: Args: iteration (int): the iteration after which the stats are to be deleted. + Raises: + assert: if the iteration number is less than the maximum iteration number. + Examples: + >>> training_stats = TrainingStats() + >>> training_stats.add_iteration_stats(TrainingIterationStats(0, 0.1)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(1, 0.2)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(2, 0.3)) + >>> training_stats.delete_after(1) + >>> training_stats.iteration_stats + [TrainingIterationStats(iteration=0, loss=0.1)] + Note: + The iteration stats list is structured as follows: + - The outer list contains the stats for each iteration. + - The inner list contains the stats for each training iteration. """ self.iteration_stats = [ stats for stats in self.iteration_stats if stats.iteration < iteration @@ -68,6 +100,19 @@ def trained_until(self) -> int: Returns: int: number of iterations that the model has been trained for. + Raises: + assert: if the iteration stats list is empty. + Examples: + >>> training_stats = TrainingStats() + >>> training_stats.add_iteration_stats(TrainingIterationStats(0, 0.1)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(1, 0.2)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(2, 0.3)) + >>> training_stats.trained_until() + 3 + Note: + The iteration stats list is structured as follows: + - The outer list contains the stats for each iteration. + - The inner list contains the stats for each training iteration. """ if not self.iteration_stats: return 0 @@ -79,6 +124,22 @@ def to_xarray(self) -> xr.DataArray: Returns: xr.DataArray: xarray DataArray of iteration losses. + Raises: + assert: if the iteration stats list is empty. + Examples: + >>> training_stats = TrainingStats() + >>> training_stats.add_iteration_stats(TrainingIterationStats(0, 0.1)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(1, 0.2)) + >>> training_stats.add_iteration_stats(TrainingIterationStats(2, 0.3)) + >>> training_stats.to_xarray() + + array([0.1, 0.2, 0.3]) + Coordinates: + * iterations (iterations) int64 0 1 2 + Note: + The iteration stats list is structured as follows: + - The outer list contains the stats for each iteration. + - The inner list contains the stats for each training iteration. """ return xr.DataArray( np.array( diff --git a/dacapo/experiments/validation_iteration_scores.py b/dacapo/experiments/validation_iteration_scores.py index 6e2133cce..e4f2932ea 100644 --- a/dacapo/experiments/validation_iteration_scores.py +++ b/dacapo/experiments/validation_iteration_scores.py @@ -11,6 +11,11 @@ class ValidationIterationScores: iteration (int): The iteration associated with these validation scores. scores (List[List[List[float]]]): A list of scores per dataset, post processor parameters, and evaluation criterion. + Note: + The scores list is structured as follows: + - The outer list contains the scores for each dataset. + - The middle list contains the scores for each post processor parameter. + - The inner list contains the scores for each evaluation criterion. """ diff --git a/dacapo/experiments/validation_scores.py b/dacapo/experiments/validation_scores.py index 17727cc22..18e2df029 100644 --- a/dacapo/experiments/validation_scores.py +++ b/dacapo/experiments/validation_scores.py @@ -11,6 +11,31 @@ @attr.s class ValidationScores: + """ + Class representing the validation scores for a set of parameters and datasets. + + Attributes: + parameters (List[PostProcessorParameters]): The list of parameters that are being evaluated. + datasets (List[Dataset]): The datasets that will be evaluated at each iteration. + evaluation_scores (EvaluationScores): The scores that are collected on each iteration per + `PostProcessorParameters` and `Dataset`. + scores (List[ValidationIterationScores]): A list of evaluation scores and their associated + post-processing parameters. + Methods: + subscores(iteration_scores): Create a new ValidationScores object with a subset of the iteration scores. + add_iteration_scores(iteration_scores): Add iteration scores to the list of scores. + delete_after(iteration): Delete scores after a specified iteration. + validated_until(): Get the number of iterations validated for (the maximum iteration plus one). + compare(existing_iteration_scores): Compare iteration stats provided from elsewhere to scores we have saved locally. + criteria(): Get the list of evaluation criteria. + parameter_names(): Get the list of parameter names. + to_xarray(): Convert the validation scores to an xarray DataArray. + get_best(data, dim): Compute the Best scores along dimension "dim" per criterion. + Notes: + The `scores` attribute is a list of `ValidationIterationScores` objects, each of which + contains the scores for a single iteration. + """ + parameters: List[PostProcessorParameters] = attr.ib( metadata={"help_text": "The list of parameters that are being evaluated"} ) @@ -33,6 +58,23 @@ class ValidationScores: def subscores( self, iteration_scores: List[ValidationIterationScores] ) -> "ValidationScores": + """ + Create a new ValidationScores object with a subset of the iteration scores. + + Args: + iteration_scores: The iteration scores to include in the new ValidationScores object. + Returns: + A new ValidationScores object with the specified iteration scores. + Raises: + ValueError: If the iteration scores are not in the list of scores. + Examples: + >>> validation_scores.subscores([validation_scores.scores[0]]) + Note: + This method is used to create a new ValidationScores object with a subset of the + iteration scores. This is useful when you want to create a new ValidationScores object + that only contains the scores up to a certain iteration. + + """ return ValidationScores( self.parameters, self.datasets, @@ -44,15 +86,54 @@ def add_iteration_scores( self, iteration_scores: ValidationIterationScores, ) -> None: + """ + Add iteration scores to the list of scores. + + Args: + iteration_scores: The iteration scores to add. + Raises: + ValueError: If the iteration scores are already in the list of scores. + Examples: + >>> validation_scores.add_iteration_scores(validation_scores.scores[0]) + Note: + This method is used to add iteration scores to the list of scores. This is useful when + you want to add scores for a new iteration to the ValidationScores object. + + """ self.scores.append(iteration_scores) def delete_after(self, iteration: int) -> None: + """ + Delete scores after a specified iteration. + + Args: + iteration: The iteration after which to delete the scores. + Raises: + ValueError: If the iteration scores are not in the list of scores. + Examples: + >>> validation_scores.delete_after(0) + Note: + This method is used to delete scores after a specified iteration. This is useful when + you want to delete scores after a certain iteration. + + """ self.scores = [scores for scores in self.scores if scores.iteration < iteration] def validated_until(self) -> int: - """The number of iterations validated for (the maximum iteration plus - one).""" + """ + Get the number of iterations validated for (the maximum iteration plus one). + Returns: + The number of iterations validated for. + Raises: + ValueError: If there are no scores. + Examples: + >>> validation_scores.validated_until() + Note: + This method is used to get the number of iterations validated for (the maximum iteration + plus one). This is useful when you want to know how many iterations have been validated. + + """ if not self.scores: return 0 return max([score.iteration for score in self.scores]) + 1 @@ -61,11 +142,26 @@ def compare( self, existing_iteration_scores: List[ValidationIterationScores] ) -> Tuple[bool, int]: """ - Compares iteration stats provided from elsewhere to scores we have saved locally. + Compare iteration stats provided from elsewhere to scores we have saved locally. Local scores take priority. If local scores are at a lower iteration than the existing ones, delete the existing ones and replace with local. If local iteration > existing iteration, just update existing scores with the last overhanging local scores. + + Args: + existing_iteration_scores: The existing iteration scores to compare with. + Returns: + A tuple indicating whether the local scores should replace the existing ones + and the existing iteration number. + Raises: + ValueError: If the iteration scores are not in the list of scores. + Examples: + >>> validation_scores.compare([validation_scores.scores[0]]) + Note: + This method is used to compare iteration stats provided from elsewhere to scores we have + saved locally. Local scores take priority. If local scores are at a lower iteration than + the existing ones, delete the existing ones and replace with local. If local iteration > + existing iteration, just update existing scores with the last overhanging local scores. """ if not existing_iteration_scores: return False, 0 @@ -80,13 +176,53 @@ def compare( @property def criteria(self) -> List[str]: + """ + Get the list of evaluation criteria. + + Returns: + The list of evaluation criteria. + Raises: + ValueError: If there are no scores. + Examples: + >>> validation_scores.criteria + Note: + This property is used to get the list of evaluation criteria. This is useful when you + want to know what criteria are being used to evaluate the scores. + """ return self.evaluation_scores.criteria @property def parameter_names(self) -> List[str]: + """ + Get the list of parameter names. + + Returns: + The list of parameter names. + Raises: + ValueError: If there are no scores. + Examples: + >>> validation_scores.parameter_names + Note: + This property is used to get the list of parameter names. This is useful when you want + to know what parameters are being used to evaluate the scores. + """ return self.parameters[0].parameter_names def to_xarray(self) -> xr.DataArray: + """ + Convert the validation scores to an xarray DataArray. + + Returns: + An xarray DataArray representing the validation scores. + Raises: + ValueError: If there are no scores. + Examples: + >>> validation_scores.to_xarray() + Note: + This method is used to convert the validation scores to an xarray DataArray. This is + useful when you want to work with the validation scores as an xarray DataArray. + + """ return xr.DataArray( np.array( [iteration_score.scores for iteration_score in self.scores] @@ -110,7 +246,25 @@ def get_best( """ Compute the Best scores along dimension "dim" per criterion. Returns both the index associated with the best value, and the - best value in two seperate arrays. + best value in two separate arrays. + + Args: + data: The data array to compute the best scores from. + dim: The dimension along which to compute the best scores. + Returns: + A tuple containing the index associated with the best value and the best value + in two separate arrays. + Raises: + ValueError: If the criteria are not in the data array. + Examples: + >>> validation_scores.get_best(data, "iterations") + Note: + This method is used to compute the Best scores along dimension "dim" per criterion. It + returns both the index associated with the best value and the best value in two separate + arrays. This is useful when you want to know the best scores for a given data array. + Fix: The method is currently not able to handle the case where the criteria are not in the data array. + To fix this, we need to add a check to see if the criteria are in the data array and raise an error if they are not. + """ if "criteria" in data.coords.keys(): if len(data.coords["criteria"].shape) > 1: diff --git a/dacapo/ext/__init__.py b/dacapo/ext/__init__.py index e0308e1fb..9efd04751 100644 --- a/dacapo/ext/__init__.py +++ b/dacapo/ext/__init__.py @@ -3,11 +3,43 @@ class NoSuchModule: + """ + This class is used to raise an exception when a module is not found. + + Attributes: + __name (str): The name of the module that was not found. + __traceback_str (str): The traceback string of the exception. + __exception (Exception): The exception raised. + Methods: + __getattr__(item): Raises the exception. + + """ + def __init__(self, name): + """ + Initializes the NoSuchModule object. + + Args: + name (str): The name of the module that was not found. + Examples: + >>> module = NoSuchModule("module") + + """ self.__name = name self.__traceback_str = traceback.format_tb(sys.exc_info()[2]) errtype, value = sys.exc_info()[:2] self.__exception = errtype(value) def __getattr__(self, item): + """ + Raises the exception. + + Args: + item: The item to get. + Raises: + Exception: The exception raised. + Examples: + >>> module.function() + + """ raise self.__exception diff --git a/dacapo/gp/copy.py b/dacapo/gp/copy.py index e0ea6e94c..f40ce29c6 100644 --- a/dacapo/gp/copy.py +++ b/dacapo/gp/copy.py @@ -9,11 +9,13 @@ class CopyMask(gp.BatchFilter): array_key (gp.ArrayKey): Original key of the array from where the mask will be copied. copy_key (gp.ArrayKey): New key where the copied mask will reside. drop_channels (bool): If True, channels will be dropped via a max collapse. - Methods: setup: Sets up the filter by enabling autoskip and providing the copied key. prepare: Prepares the filter by copying the request of copy_key into a dependency. process: Processes the batch by copying the mask from the array_key to the copy_key. + Note: + This class is a subclass of gunpowder.BatchFilter and is used to + copy a mask into a new key with the option to drop channels via max collapse. """ def __init__( @@ -26,6 +28,13 @@ def __init__( array_key (gp.ArrayKey): Original key of the array from where the mask will be copied. copy_key (gp.ArrayKey): New key where the copied mask will reside. drop_channels (bool): If True, channels will be dropped via a max collapse. Default is False. + Raises: + TypeError: If array_key is not of type gp.ArrayKey. + TypeError: If copy_key is not of type gp.ArrayKey. + Examples: + >>> array_key = gp.ArrayKey("ARRAY") + >>> copy_key = gp.ArrayKey("COPY") + >>> copy_mask = CopyMask(array_key, copy_key) """ self.array_key = array_key self.copy_key = copy_key @@ -34,6 +43,12 @@ def __init__( def setup(self): """ Sets up the filter by enabling autoskip and providing the copied key. + + Raises: + RuntimeError: If the key is already provided. + Examples: + >>> copy_mask.setup() + """ self.enable_autoskip() self.provides(self.copy_key, self.spec[self.array_key].copy()) @@ -44,9 +59,14 @@ def prepare(self, request): Args: request: The request to prepare. - Returns: deps: The prepared dependencies. + Raises: + NotImplementedError: If the copy_key is not provided. + Examples: + >>> request = gp.BatchRequest() + >>> request[self.copy_key] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) + >>> copy_mask.prepare(request) """ deps = gp.BatchRequest() deps[self.array_key] = request[self.copy_key].copy() @@ -61,9 +81,14 @@ def process(self, batch, request): Args: batch: The batch to process. request: The request for processing. - Returns: outputs: The processed outputs. + Raises: + KeyError: If the requested key is not in the request. + Examples: + >>> request = gp.BatchRequest() + >>> request[gp.ArrayKey("ARRAY")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) + >>> copy_mask.process(batch, request) """ outputs = gp.Batch() diff --git a/dacapo/gp/dacapo_array_source.py b/dacapo/gp/dacapo_array_source.py index c00b2d504..2fb750c8b 100644 --- a/dacapo/gp/dacapo_array_source.py +++ b/dacapo/gp/dacapo_array_source.py @@ -10,20 +10,36 @@ class DaCapoArraySource(gp.BatchProvider): - """A DaCapo Array source node - - Args: - - Array (Array): - - The DaCapo Array to pull data from - - key (``gp.ArrayKey``): - - The key to provide data into + """ + A DaCapo Array source node + + Attributes: + array (Array): The array to be served. + key (gp.ArrayKey): The key of the array to be served. + Methods: + setup(): Set up the provider. + provide(request): Provides the array for the requested ROI. + Note: + This class is a subclass of gunpowder.BatchProvider and is used to + serve array data to gunpowder pipelines. """ def __init__(self, array: Array, key: gp.ArrayKey): + """ + Create a DaCapoArraySource object. + + Args: + array (Array): The array to be served. + key (gp.ArrayKey): The key of the array to be served. + Raises: + TypeError: If key is not of type gp.ArrayKey. + TypeError: If array is not of type Array. + Examples: + >>> from dacapo.experiments.datasplits.datasets.arrays import Array + >>> from gunpowder import ArrayKey + >>> array = Array() + >>> array_source = DaCapoArraySource(array, gp.ArrayKey("ARRAY")) + """ self.array = array self.array_spec = ArraySpec( roi=self.array.roi, voxel_size=self.array.voxel_size @@ -31,9 +47,31 @@ def __init__(self, array: Array, key: gp.ArrayKey): self.key = key def setup(self): + """ + Adds the key and the array spec to the provider. + + Raises: + RuntimeError: If the key is already provided. + Examples: + >>> array_source.setup() + + """ self.provides(self.key, self.array_spec.copy()) def provide(self, request): + """ + Provides data based on the given request. + + Args: + request (gp.BatchRequest): The request for data + Returns: + gp.Batch: The batch containing the provided data + Raises: + ValueError: If the input data contains NaN values + Examples: + >>> array_source.provide(request) + + """ output = gp.Batch() timing_provide = Timing(self, "provide") diff --git a/dacapo/gp/dacapo_create_target.py b/dacapo/gp/dacapo_create_target.py index f136c5c7b..13514cebc 100644 --- a/dacapo/gp/dacapo_create_target.py +++ b/dacapo/gp/dacapo_create_target.py @@ -7,21 +7,27 @@ class DaCapoTargetFilter(gp.BatchFilter): - """A Gunpowder node for generating the target from the ground truth - - Args: + """ + A Gunpowder node for generating the target from the ground truth + Attributes: Predictor (Predictor): - The DaCapo Predictor to use to transform gt into target - gt (``Array``): - The dataset to use for generating the target. - target_key (``gp.ArrayKey``): - The key with which to provide the target. + weights_key (``gp.ArrayKey``): + The key with which to provide the weights. + mask_key (``gp.ArrayKey``): + The key with which to provide the mask. + Methods: + setup(): Set up the provider. + prepare(request): Prepare the request. + process(batch, request): Process the batch. + Note: + This class is a subclass of gunpowder.BatchFilter and is used to + generate the target from the ground truth. """ def __init__( @@ -32,6 +38,32 @@ def __init__( weights_key: Optional[gp.ArrayKey] = None, mask_key: Optional[gp.ArrayKey] = None, ): + """ + Initialize the DacapoCreateTarget object. + + Args: + predictor (Predictor): The predictor object used for prediction. + gt_key (gp.ArrayKey): The ground truth key. + target_key (Optional[gp.ArrayKey]): The target key. Defaults to None. + weights_key (Optional[gp.ArrayKey]): The weights key. Defaults to None. + mask_key (Optional[gp.ArrayKey]): The mask key. Defaults to None. + Raises: + AssertionError: If neither target_key nor weights_key is provided. + Examples: + >>> from dacapo.experiments.tasks.predictors import Predictor + >>> from gunpowder import ArrayKey + >>> from gunpowder import ArrayKey + >>> from gunpowder import ArrayKey + >>> predictor = Predictor() + >>> gt_key = ArrayKey("GT") + >>> target_key = ArrayKey("TARGET") + >>> weights_key = ArrayKey("WEIGHTS") + >>> mask_key = ArrayKey("MASK") + >>> target_filter = DaCapoTargetFilter(predictor, gt_key, target_key, weights_key, mask_key) + Note: + The target filter is used to generate the target from the ground truth. + + """ self.predictor = predictor self.gt_key = gt_key self.target_key = target_key @@ -45,6 +77,16 @@ def __init__( ), "Must provide either target or weights" def setup(self): + """ + Set up the provider. This function sets the provider to provide the + target with the given key. + + Raises: + RuntimeError: If the key is already provided. + Examples: + >>> target_filter.setup() + + """ provided_spec = gp.ArraySpec( roi=self.spec[self.gt_key].roi, voxel_size=self.spec[self.gt_key].voxel_size, @@ -62,6 +104,21 @@ def setup(self): self.provides(self.weights_key, provided_spec) def prepare(self, request): + """ + Prepare the request. + + Args: + request (gp.BatchRequest): The request to prepare. + Returns: + deps (gp.BatchRequest): The dependencies. + Raises: + NotImplementedError: If the target_key is not provided. + Examples: + >>> request = gp.BatchRequest() + >>> request[gp.ArrayKey("GT")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) + >>> target_filter.prepare(request) + + """ deps = gp.BatchRequest() # TODO: Does the gt depend on weights too? request_spec = None @@ -80,6 +137,19 @@ def prepare(self, request): return deps def process(self, batch, request): + """ + Process the batch. + + Args: + batch (gp.Batch): The batch to process. + request (gp.BatchRequest): The request to process. + Returns: + output (gp.Batch): The output batch. + Examples: + >>> request = gp.BatchRequest() + >>> request[gp.ArrayKey("GT")] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) + >>> target_filter.process(batch, request) + """ output = gp.Batch() gt_array = NumpyArray.from_gp_array(batch[self.gt_key]) diff --git a/dacapo/gp/dacapo_points_source.py b/dacapo/gp/dacapo_points_source.py index 5e599d3fa..5f20d9f25 100644 --- a/dacapo/gp/dacapo_points_source.py +++ b/dacapo/gp/dacapo_points_source.py @@ -13,6 +13,13 @@ class GraphSource(gp.BatchProvider): Attributes: key (gp.GraphKey): The key of the graph to be served. graph (gp.Graph): The graph to be served. + Methods: + setup(): Set up the provider. + provide(request): Provides the graph for the requested ROI. + + Note: + This class is a subclass of gunpowder.BatchProvider and is used to + serve graph data to gunpowder pipelines. """ def __init__(self, key: gp.GraphKey, graph: gp.Graph): @@ -20,6 +27,15 @@ def __init__(self, key: gp.GraphKey, graph: gp.Graph): Args: key (gp.GraphKey): The key of the graph to be served. graph (gp.Graph): The graph to be served. + Raises: + TypeError: If key is not of type gp.GraphKey. + TypeError: If graph is not of type gp.Graph. + Examples: + >>> graph = gp.Graph() + >>> graph.add_node(1, position=[0, 0, 0]) + >>> graph.add_node(2, position=[1, 1, 1]) + >>> graph.add_edge(1, 2) + >>> graph_source = GraphSource(gp.GraphKey("GRAPH"), graph) """ self.key = key self.graph = graph @@ -28,6 +44,12 @@ def setup(self): """ Set up the provider. This function sets the provider to provide the graph with the given key. + + Raises: + RuntimeError: If the key is already provided. + Examples: + >>> graph_source.setup() + """ self.provides(self.key, self.graph.spec) @@ -42,9 +64,14 @@ def provide(self, request): Args: request (gp.BatchRequest): BatchRequest with the same ROI for each requested array and graph. - Returns: outputs (gp.Batch): The graph contained in a Batch. + Raises: + KeyError: If the requested key is not in the request. + Examples: + >>> request = gp.BatchRequest() + >>> request[gp.GraphKey("GRAPH")] = gp.GraphSpec(roi=gp.Roi((0, 0, 0), (1, 1, 1))) + >>> graph_source.provide(request) """ outputs = gp.Batch() if self.key in request: diff --git a/dacapo/gp/elastic_augment_fuse.py b/dacapo/gp/elastic_augment_fuse.py index e8d02e65d..a703559ec 100644 --- a/dacapo/gp/elastic_augment_fuse.py +++ b/dacapo/gp/elastic_augment_fuse.py @@ -16,6 +16,21 @@ def _create_identity_transformation(shape, voxel_size=None, offset=None, subsample=1): + """ + Create an identity transformation grid. + + Args: + shape (tuple): The shape of the transformation grid. + voxel_size (tuple, optional): The voxel size of the grid. Defaults to None. + offset (tuple, optional): The offset of the grid. Defaults to None. + subsample (int, optional): The subsampling factor. Defaults to 1. + Returns: + numpy.ndarray: The identity transformation grid. + Raises: + AssertionError: If the subsample is not an integer. + Examples: + >>> _create_identity_transformation((10, 10, 10)) + """ dims = len(shape) if voxel_size is None: @@ -40,6 +55,23 @@ def _create_identity_transformation(shape, voxel_size=None, offset=None, subsamp def _upscale_transformation( transformation, output_shape, interpolate_order=1, dtype=np.float32 ): + """ + Upscales a given transformation to match the specified output shape. + + Args: + transformation (ndarray): The input transformation to be upscaled. + output_shape (tuple): The desired output shape of the upscaled transformation. + interpolate_order (int, optional): The order of interpolation used during upscaling. Defaults to 1. + dtype (type, optional): The data type of the upscaled transformation. Defaults to np.float32. + Returns: + ndarray: The upscaled transformation with the specified output shape. + Raises: + AssertionError: If the transformation and output shape have different dimensions. + Examples: + >>> _upscale_transformation(transformation, (10, 10, 10)) + + """ + input_shape = transformation.shape[1:] dims = len(output_shape) @@ -59,6 +91,20 @@ def _upscale_transformation( def _rotate(point, angle): + """ + Rotate a point by a given angle. + + Args: + point (list or tuple): The coordinates of the point to rotate. + angle (float): The angle (in radians) by which to rotate the point. + Returns: + numpy.ndarray: The rotated point. + Raises: + AssertionError: If the point is not a list or tuple. + Examples: + >>> _rotate((1, 2), 0.5) + + """ res = np.array(point) res[0] = math.sin(angle) * point[1] + math.cos(angle) * point[0] res[1] = -math.sin(angle) * point[0] + math.cos(angle) * point[1] @@ -67,6 +113,24 @@ def _rotate(point, angle): def _create_rotation_transformation(shape, angle, subsample=1, voxel_size=None): + """ + Create a rotation transformation. + + Args: + shape (tuple): The shape of the input volume. + angle (float): The rotation angle in degrees. + subsample (int, optional): The subsampling factor. Defaults to 1. + voxel_size (tuple, optional): The voxel size of the input volume. Defaults to None. + Returns: + ndarray: The rotation transformation. + Raises: + AssertionError: If the subsample is not an integer. + Examples: + >>> _create_rotation_transformation((10, 10, 10), 0.5) + Notes: + The rotation is performed around the center of the volume. + + """ dims = len(shape) subsample_shape = tuple(max(1, int(s / subsample)) for s in shape) control_points = (2,) * dims @@ -101,6 +165,23 @@ def _create_rotation_transformation(shape, angle, subsample=1, voxel_size=None): def _create_uniform_3d_transformation(shape, rotation, subsample=1, voxel_size=None): + """ + Create a uniform 3D transformation. + + Args: + shape (tuple): The shape of the input volume. + rotation (Rotation): The rotation to be applied to the control points. + subsample (int, optional): The subsampling factor. Defaults to 1. + voxel_size (Coordinate, optional): The voxel size of the input volume. Defaults to None. + Returns: + ndarray: The transformed control point offsets. + Raises: + AssertionError: If the subsample is not an integer. + Examples: + >>> _create_uniform_3d_transformation((10, 10, 10), Rotation.from_euler('xyz', [0.5, 0.5, 0.5])) + Notes: + The rotation is performed around the center of the volume. + """ dims = len(shape) subsample_shape = tuple(max(1, int(s / subsample)) for s in shape) control_points = (2,) * dims @@ -135,6 +216,18 @@ def _create_uniform_3d_transformation(shape, rotation, subsample=1, voxel_size=N def _min_max_mean_std(ndarray, prefix=""): + """ + Calculate the minimum, maximum, mean, and standard deviation of the given ndarray. + + Args: + ndarray (numpy.ndarray): The input ndarray. + Returns: + str: A string containing the calculated statistics with the given prefix. + Raises: + AssertionError: If the input is not a numpy array. + Examples: + >>> _min_max_mean_std(ndarray) + """ return "" @@ -142,36 +235,26 @@ class ElasticAugment(BatchFilter): """ Elasticly deform a batch. Requests larger batches upstream to avoid data loss due to rotation and jitter. - Args: - - control_point_spacing (``tuple`` of ``int``): - - Distance between control points for the elastic deformation, in - voxels per dimension. - - control_point_displacement_sigma (``tuple`` of ``float``): - - Standard deviation of control point displacement distribution, in world coordinates. - - rotation_interval (``tuple`` of two ``floats``): - - Interval to randomly sample rotation angles from (0, 2PI). - - subsample (``int``): - - Instead of creating an elastic transformation on the full - resolution, create one sub-sampled by the given factor, and linearly - interpolate to obtain the full resolution transformation. This can - significantly speed up this node, at the expense of having visible - piecewise linear deformations for large factors. Usually, a factor - of 4 can safely be used without noticeable changes. However, the - default is 1 (i.e., no sub-sampling). - - seed (``int``): - - Set random state for reproducible results (tests only, do not use - in production code!!) + control_point_spacing (tuple of int): Distance between control points for the elastic deformation, in voxels per dimension. + control_point_displacement_sigma (tuple of float): Standard deviation of control point displacement distribution, in world coordinates. + rotation_interval (tuple of two float): Interval to randomly sample rotation angles from (0, 2PI). + subsample (int): Instead of creating an elastic transformation on the full resolution, create one sub-sampled by the given factor, and linearly interpolate to obtain the full resolution transformation. This can significantly speed up this node, at the expense of having visible piecewise linear deformations for large factors. Usually, a factor of 4 can safely be used without noticeable changes. However, the default is 1 (i.e., no sub-sampling). + seed (int): Set random state for reproducible results (tests only, do not use in production code!!) + augmentation_probability (float): Probability to apply the augmentation. + uniform_3d_rotation (bool): Use a uniform 3D rotation instead of a rotation around a random axis. + Provides: + * The arrays in the batch, deformed. + Requests: + * The arrays in the batch, enlarged such that the deformed ROI fits into + the enlarged input ROI. + Method: + setup: Set up the ElasticAugment node. + prepare: Prepare the ElasticAugment node. + process: Process the ElasticAugment node. + Notes: + This node is a port of the ElasticAugment node from the original + `gunpowder < """ def __init__( @@ -184,6 +267,23 @@ def __init__( seed=None, uniform_3d_rotation=False, ): + """ + Initialize the BatchFilter object. + + Args: + control_point_spacing (float): The spacing between control points. + control_point_displacement_sigma (float): The standard deviation of the control point displacements. + rotation_interval (tuple): A tuple containing the start and end angles for rotation. + subsample (int, optional): The subsampling factor. Defaults to 1. + augmentation_probability (float, optional): The probability of applying augmentation. Defaults to 1.0. + seed (int, optional): The seed value for random number generation. Defaults to None. + uniform_3d_rotation (bool, optional): Whether to use uniform 3D rotation. Defaults to False. + Raises: + AssertionError: If the subsample is not an integer. + Examples: + >>> ElasticAugment(control_point_spacing, control_point_displacement_sigma, rotation_interval, subsample=1, augmentation_probability=1.0, seed=None, uniform_3d_rotation=False) + + """ super(BatchFilter, self).__init__() self.control_point_spacing = control_point_spacing self.control_point_displacement_sigma = control_point_displacement_sigma @@ -211,6 +311,20 @@ def __init__( self.target_rois = {} def setup(self): + """ + Set up the object by calculating the voxel size and spatial dimensions. + + This method calculates the voxel size by finding the minimum value for each axis + from the voxel sizes of all array specs. It then sets the `voxel_size` attribute + of the object. The spatial dimensions are also set based on the dimensions of the + voxel size. + + Raises: + AssertionError: If the voxel size is not a Coordinate object. + Examples: + >>> setup() + + """ self.voxel_size = Coordinate( min(axis) for axis in zip( @@ -223,6 +337,27 @@ def setup(self): self.spatial_dims = self.voxel_size.dims def prepare(self, request): + """ + Prepares the request for augmentation. + + Args: + request: The request object containing the data to be augmented. + Raises: + AssertionError: If the key in the request is not an ArrayKey or GraphKey. + Examples: + >>> prepare(request) + Notes: + This method prepares the request for augmentation by performing the following steps: + 1. Logs the preparation details, including the transformation voxel size. + 2. Calculates the master ROI based on the total ROI. + 3. Generates a uniform random sample and determines whether to perform augmentation based on the augmentation probability. + 4. If augmentation is not required, logs the decision and returns. + 5. Snaps the master ROI to the grid based on the voxel size and calculates the master transformation. + 6. Clears the existing transformations and target ROIs. + 7. Iterates over each key in the request and prepares it for augmentation. + 8. Updates the upstream request with the modified ROI. + + """ logger.debug( logger.debug( f"{type(self).__name__} preparing request {request} with transformation voxel size {self.voxel_size}" @@ -328,6 +463,20 @@ def prepare(self, request): ) def process(self, batch, request): + """ + Process the ElasticAugment node. + + Args: + batch: The batch object containing the data to be processed. + request: The request object specifying the data to be processed. + Raises: + AssertionError: If the key in the request is not an ArrayKey or GraphKey. + Examples: + >>> process(batch, request) + Notes: + This method applies the transformation to the data in the batch and restores the original ROIs. + + """ if not self.do_augment: logger.debug( f"Process: Randomly not augmenting at all. (probability to augment: {self.augmentation_probability})" @@ -389,6 +538,20 @@ def process(self, batch, request): array.spec.roi = request[key].roi def _create_transformation(self, target_shape, offset): + """ + Create a transformation matrix for augmenting the input data. + + Args: + target_shape (tuple): The shape of the target data. + offset (tuple): The offset of the target data. + Returns: + np.ndarray: The transformation matrix. + Raises: + AssertionError: If the subsample is not an integer. + Examples: + >>> _create_transformation((10, 10, 10), (0, 0, 0)) + + """ logger.debug( f"creating displacement for shape {target_shape}, subsample {self.subsample}", ) @@ -445,15 +608,27 @@ def _create_transformation(self, target_shape, offset): return transformation def _spatial_roi(self, roi): + """ + Returns a spatial region of interest (ROI) based on the given ROI. + + Args: + roi (Roi): The input ROI. + Returns: + Roi: The spatial ROI. + Raises: + AssertionError: If the ROI is not a Roi object. + Examples: + >>> _spatial_roi(roi) + + """ return Roi( roi.get_begin()[-self.spatial_dims :], roi.get_shape()[-self.spatial_dims :] ) def _affine(self, array, scale, offset, target_roi, dtype=np.float32, order=1): - """taken from the scipy 0.18.1 doc: - https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.ndimage.affine_transform.html#scipy.ndimage.affine_transform - + """ Apply an affine transformation. + The given matrix and offset are used to find for each point in the output the corresponding coordinates in the input by an affine transformation. The value of the input at those coordinates is determined by spline interpolation of the requested order. Points outside the boundaries of the input are filled according to the given mode. @@ -467,6 +642,27 @@ def _affine(self, array, scale, offset, target_roi, dtype=np.float32, order=1): Changed in version 0.18.0: Previously, the exact interpretation of the affine transformation depended on whether the matrix was supplied as a one-dimensional or two-dimensional array. If a one-dimensional array was supplied to the matrix parameter, the output pixel value at index o was determined from the input image at position matrix * (o + offset). + If a two-dimensional array was supplied, the output pixel value at index o was determined from the input image at position + np.dot(matrix, o) + offset. This behavior was inconsistent and error-prone. As of version 0.18.0, the interpretation of + the matrix parameter is consistent, and the offset parameter is always added to the input pixel index vector. + + Args: + array (ndarray): The input array. + scale (float or ndarray): The scale factor(s). + offset (float or ndarray): The offset. + target_roi (Roi): The target region of interest. + dtype (type, optional): The data type of the output array. Defaults to np.float32. + order (int, optional): The order of interpolation. Defaults to 1. + Returns: + ndarray: The transformed array. + Raises: + AssertionError: If the scale is not a scalar or 1-D array. + Examples: + >>> _affine(array, scale, offset, target_roi) + References: + taken from the scipy 0.18.1 doc: + https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.ndimage.affine_transform.html#scipy.ndimage.affine_transform + """ ndim = array.shape[0] output = np.empty((ndim,) + target_roi.get_shape(), dtype=dtype) @@ -488,10 +684,40 @@ def _affine(self, array, scale, offset, target_roi, dtype=np.float32, order=1): return output def _shift_transformation(self, shift, transformation): + """ + Shift the given transformation. + + Args: + shift (tuple): The shift to apply to the transformation. + transformation (ndarray): The transformation to shift. + Returns: + ndarray: The shifted transformation. + Raises: + AssertionError: If the shift is not a tuple. + Examples: + >>> _shift_transformation(shift, transformation) + + """ for d in range(transformation.shape[0]): transformation[d] += shift[d] def _get_source_roi(self, transformation): + """ + Get the source region of interest (ROI) for the given transformation. + + Args: + transformation (ndarray): The transformation. + Returns: + Roi: The source ROI. + Raises: + AssertionError: If the transformation is not an ndarray. + Examples: + >>> transformation = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + >>> _get_source_roi(transformation) + Roi(Coordinate([0, 0, 0]), Coordinate([2, 2, 2])) + Notes: + Create the source ROI sufficiently large to feed the transformation. + """ dims = transformation.shape[0] # get bounding box of needed data for transformation diff --git a/dacapo/gp/gamma_noise.py b/dacapo/gp/gamma_noise.py index 8fc897e14..6a61a3a21 100644 --- a/dacapo/gp/gamma_noise.py +++ b/dacapo/gp/gamma_noise.py @@ -19,6 +19,7 @@ class GammaAugment(BatchFilter): setup(): Method to configure the internal state of the class process(): Method to apply gamma noise to the desired arrays __augment(): Private method to perform the actual augmentation + """ def __init__(self, arrays, gamma_min, gamma_max): @@ -29,6 +30,11 @@ def __init__(self, arrays, gamma_min, gamma_max): arrays : An iterable collection of np arrays to augment gamma_min : A float representing the lower limit of gamma perturbation gamma_max : A float representing the upper limit of gamma perturbation + Raises: + AssertionError: If gamma_max is less than gamma_min + Examples: + >>> GammaAugment(arrays, gamma_min, gamma_max) + GammaAugment(arrays, gamma_min, gamma_max) """ if not isinstance(arrays, Iterable): arrays = [ @@ -42,6 +48,12 @@ def __init__(self, arrays, gamma_min, gamma_max): def setup(self): """ Configuring the internal state by iterating over arrays. + + Raises: + AssertionError: If the array data type is not float32 or float64 + Examples: + >>> setup() + setup() """ for array in self.arrays: self.updates(array, self.spec[array]) @@ -53,6 +65,13 @@ def process(self, batch, request): Args: batch : The input batch to be processed. request : An object which holds the requested output location. + Returns: + The batch with the gamma noise applied. + Raises: + AssertionError: If the array data type is not float32 or float64 + Examples: + >>> process(batch, request) + process(batch, request) """ sample_gamma_min = (max(self.gamma_min, 1.0 / self.gamma_min) - 1) * (-1) ** ( self.gamma_min < 1 @@ -84,6 +103,13 @@ def __augment(self, a, gamma): Args: a: raw array to be augmented gamma: gamma index to be applied + Returns: + The augmented array. + Raises: + AssertionError: If the array data type is not float32 or float64 + Examples: + >>> __augment(a, gamma) + __augment(a, gamma) """ # normalize a a_min = a.min() diff --git a/dacapo/gp/product.py b/dacapo/gp/product.py index 45926bea6..9a3adfe44 100644 --- a/dacapo/gp/product.py +++ b/dacapo/gp/product.py @@ -3,25 +3,91 @@ class Product(gp.BatchFilter): """ - multiplies two arrays + A BatchFilter that multiplies two input arrays and produces an output array. + + Attributes: + x1_key (:class:`ArrayKey`): The key of the first input array. + x2_key (:class:`ArrayKey`): The key of the second input array. + y_key (:class:`ArrayKey`): The key of the output array. + Provides: + y_key (gp.ArrayKey): The key of the output array. + Method: + __init__: Initialize the Product BatchFilter. + setup: Set up the Product BatchFilter. + prepare: Prepare the Product BatchFilter. + process: Process the Product BatchFilter. + """ def __init__(self, x1_key: gp.ArrayKey, x2_key: gp.ArrayKey, y_key: gp.ArrayKey): + """ + Initialize the Product BatchFilter. + + Args: + x1_key (gp.ArrayKey): The key of the first input array. + x2_key (gp.ArrayKey): The key of the second input array. + y_key (gp.ArrayKey): The key of the output array. + Raises: + AssertionError: If the input arrays are not provided. + Examples: + >>> Product(x1_key, x2_key, y_key) + Product(x1_key, x2_key, y_key) + """ self.x1_key = x1_key self.x2_key = x2_key self.y_key = y_key def setup(self): + """ + Set up the Product BatchFilter. + + Enables autoskip and specifies the output array. + + Raises: + AssertionError: If the input arrays are not provided. + Examples: + >>> setup() + setup() + """ self.enable_autoskip() self.provides(self.y_key, self.spec[self.x1_key].copy()) def prepare(self, request): + """ + Prepare the Product BatchFilter. + + Args: + request (gp.BatchRequest): The batch request. + Returns: + gp.BatchRequest: The dependencies. + Raises: + AssertionError: If the input arrays are not provided. + Examples: + >>> prepare(request) + prepare(request) + + """ deps = gp.BatchRequest() deps[self.x1_key] = request[self.y_key].copy() deps[self.x2_key] = request[self.y_key].copy() return deps def process(self, batch, request): + """ + Process the Product BatchFilter. + + Args: + batch (gp.Batch): The input batch. + request (gp.BatchRequest): The batch request. + Returns: + gp.Batch: The output batch. + Raises: + AssertionError: If the input arrays are not provided. + Examples: + >>> process(batch, request) + process(batch, request) + + """ outputs = gp.Batch() outputs[self.y_key] = gp.Array( diff --git a/dacapo/gp/reject_if_empty.py b/dacapo/gp/reject_if_empty.py index a0b1528d8..0e3e662a7 100644 --- a/dacapo/gp/reject_if_empty.py +++ b/dacapo/gp/reject_if_empty.py @@ -8,30 +8,66 @@ class RejectIfEmpty(BatchFilter): - """Reject batches based on the masked-in vs. masked-out ratio. - - Args: - + """ + Reject batches based on the masked-in vs. masked-out ratio. + Attributes: gt (:class:`ArrayKey`, optional): - The gt array to use - p (``float``, optional): - The probability that we reject until gt is nonempty + Method: + setup: Set up the provider. + provide: Provide a batch. + """ def __init__(self, gt=None, p=0.5, background=0): + """ + Initialize the RejectIfEmpty filter. + + Args: + gt (:class:`ArrayKey`, optional): The gt array to use. + p (float, optional): The probability that we reject until gt is nonempty. + background (int, optional): The background value to consider as empty. + Raises: + AssertionError: If only 1 upstream provider is supported. + Examples: + >>> RejectIfEmpty(gt=gt, p=0.5, background=0) + RejectIfEmpty(gt=gt, p=0.5, background=0) + """ self.gt = gt self.p = p self.background = 0 def setup(self): + """ + Set up the provider. + + Raises: + AssertionError: If only 1 upstream provider is supported. + Examples: + >>> setup() + setup() + """ upstream_providers = self.get_upstream_providers() assert len(upstream_providers) == 1, "Only 1 upstream provider supported" self.upstream_provider = upstream_providers[0] def provide(self, request): + """ + Provides a batch of data, rejecting empty ground truth (gt) if requested. + + Args: + request: The request object containing the necessary information. + Returns: + The batch of data. + Raises: + AssertionError: If the requested gt is not present in the request. + Examples: + >>> provide(request) + provide(request) + + """ random.seed(request.random_seed) report_next_timeout = 10 diff --git a/dacapo/options.py b/dacapo/options.py index cb45cc989..589d3d7cd 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -2,7 +2,7 @@ import yaml import logging from os.path import expanduser -from pathlib import Path +from upath import UPath as Path import attr from cattr import Converter @@ -14,6 +14,20 @@ @attr.s class DaCapoConfig: + """ + Configuration class for DaCapo. + + Attributes: + type (str): The type of store to use for storing configurations and statistics. + runs_base_dir (Path): The path at DaCapo will use for reading and writing any necessary data. + compute_context (dict): The configuration for the compute context to use. + mongo_db_host (Optional[str]): The host of the MongoDB instance to use for storing configurations and statistics. + mongo_db_name (Optional[str]): The name of the MongoDB database to use for storing configurations and statistics. + Methods: + serialize: Serialize the DaCapoConfig object. + + """ + type: str = attr.ib( default="files", metadata={ @@ -48,21 +62,75 @@ class DaCapoConfig: ) def serialize(self): + """ + Serialize the DaCapoConfig object. + + Returns: + dict: The serialized representation of the DaCapoConfig object. + Examples: + >>> config = DaCapoConfig() + >>> config.serialize() + {'type': 'files', 'runs_base_dir': '/home/user/dacapo', 'compute_context': {'type': 'LocalTorch', 'config': {}}, 'mongo_db_host': None, 'mongo_db_name': None} + """ converter = Converter() return converter.unstructure(self) class Options: + """ + A class that provides options for configuring DaCapo. + + This class is designed as a singleton and should be accessed using the `instance` method. + + Methods: + instance: Returns an instance of the Options class. + config_file: Returns the path to the configuration file. + __parse_options_from_file: Parses options from the configuration file. + __parse_options: Parses options from the configuration file and updates them with the provided kwargs. + """ + def __init__(self): + """ + Initializes the Options class. + + Raises: + RuntimeError: If the constructor is called directly instead of using Options.instance(). + Examples: + >>> options = Options() + Traceback (most recent call last): + ... + RuntimeError: Singleton: Use Options.instance() + """ raise RuntimeError("Singleton: Use Options.instance()") @classmethod def instance(cls, **kwargs) -> DaCapoConfig: + """ + Returns an instance of the Options class. + + Args: + kwargs: Additional keyword arguments to update the options. + Returns: + An instance of the DaCapoConfig class. + Examples: + >>> options = Options.instance() + >>> options + DaCapoConfig(type='files', runs_base_dir=PosixPath('/home/user/dacapo'), compute_context={'type': 'LocalTorch', 'config': {}}, mongo_db_host=None, mongo_db_name=None) + """ config = cls.__parse_options(**kwargs) return config @classmethod def config_file(cls) -> Optional[Path]: + """ + Returns the path to the configuration file. + + Returns: + The path to the configuration file if found, otherwise None. + Examples: + >>> Options.config_file() + PosixPath('/home/user/.config/dacapo/dacapo.yaml') + """ env_dict = dict(os.environ) if "OPTIONS_FILE" in env_dict: options_files = [Path(env_dict["OPTIONS_FILE"])] @@ -83,6 +151,15 @@ def config_file(cls) -> Optional[Path]: @classmethod def __parse_options_from_file(cls): + """ + Parses options from the configuration file. + + Returns: + A dictionary containing the parsed options. + Examples: + >>> Options.__parse_options_from_file() + {'type': 'files', 'runs_base_dir': '/home/user/dacapo', 'compute_context': {'type': 'LocalTorch', 'config': {}}, 'mongo_db_host': None, 'mongo_db_name': None} + """ if (config_file := cls.config_file()) is not None: with config_file.open("r") as f: return yaml.safe_load(f) @@ -91,6 +168,17 @@ def __parse_options_from_file(cls): @classmethod def __parse_options(cls, **kwargs): + """ + Parses options from the configuration file and updates them with the provided kwargs. + + Args: + kwargs: Additional keyword arguments to update the options. + Returns: + A dictionary containing the parsed and updated options. + Examples: + >>> Options.__parse_options() + {'type': 'files', 'runs_base_dir': '/home/user/dacapo', 'compute_context': {'type': 'LocalTorch', 'config': {}}, 'mongo_db_host': None, 'mongo_db_name': None} + """ options = cls.__parse_options_from_file() options.update(kwargs) diff --git a/dacapo/plot.py b/dacapo/plot.py index fcd1c6ee2..e86f697b3 100644 --- a/dacapo/plot.py +++ b/dacapo/plot.py @@ -29,6 +29,21 @@ def smooth_values(a, n, stride=1): + """ + Smooth values with a moving average. + + Args: + a: values to smooth + n: number of values to average + stride: stride of the smoothing + Returns: + m: smoothed values + s: standard deviation of the smoothed values + Raises: + ValueError: If run_name is not found in config store + Examples: + >>> smooth_values([1,2,3,4,5], 3) + """ a = np.array(a) # mean @@ -59,6 +74,20 @@ def get_runs_info( config_store = create_config_store() stats_store = create_stats_store() runs = [] + """ + Get information about runs for plotting. + + Args: + run_config_names: Names of run configs to plot + validation_score_names: Names of validation scores to plot + plot_losses: Whether to plot losses + Returns: + runs: List of RunInfo objects + Raises: + ValueError: If run_name is not found in config store + Examples: + >>> get_runs_info(["run_name"], ["validation_score_name"], [True]) + """ for run_config_name, validation_score_name, plot_loss in zip( run_config_names, validation_score_names, plot_losses @@ -96,6 +125,23 @@ def plot_runs( plot_losses=None, return_json=False, ): + """ + Plot runs. + Args: + run_config_base_names: Names of run configs to plot + smooth: Smoothing factor + validation_scores: Validation scores to plot + higher_is_betters: Whether higher is better + plot_losses: Whether to plot losses + return_json: Whether to return JSON + Returns: + JSON or HTML plot + Raises: + ValueError: If run_name is not found in config store + Examples: + >>> plot_runs(["run_name"], 100, None, None, [True]) + + """ print("PLOTTING RUNS") runs = get_runs_info(run_config_base_names, validation_scores, plot_losses) print("GOT RUNS INFO") diff --git a/dacapo/predict.py b/dacapo/predict.py index 723a63566..0c09a9f7a 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,4 +1,4 @@ -from pathlib import Path +from upath import UPath as Path from dacapo.blockwise import run_blockwise import dacapo.blockwise @@ -40,6 +40,11 @@ def predict( num_workers (int, optional): The number of workers to use for blockwise prediction. Defaults to 1 for local processing, otherwise 12. output_dtype (np.dtype | str, optional): The dtype of the output array. Defaults to np.uint8. overwrite (bool, optional): If True, the output array will be overwritten if it already exists. Defaults to True. + Raises: + ValueError: If run_name is not found in config store + Examples: + >>> predict("run_name", 100, "input.zarr", "raw", "output.zarr", output_roi="[0:100,0:100,0:100]") + """ # retrieving run if isinstance(run_name, Run): diff --git a/dacapo/store/array_store.py b/dacapo/store/array_store.py index 7c44ab7ab..0e48a0882 100644 --- a/dacapo/store/array_store.py +++ b/dacapo/store/array_store.py @@ -7,43 +7,115 @@ from abc import ABC, abstractmethod import itertools import json -from pathlib import Path +from upath import UPath as Path from typing import Optional, Tuple @attr.s class LocalArrayIdentifier: + """ + Represents a local array identifier. + + Attributes: + container (Path): The path to the container. + dataset (str): The dataset name. + Method: + __str__ : Returns the string representation of the identifier. + """ + container: Path = attr.ib() dataset: str = attr.ib() @attr.s class LocalContainerIdentifier: + """ + Represents a local container identifier. + + Attributes: + container (Path): The path to the container. + Method: + array_identifier : Creates a local array identifier for the given dataset. + + """ + container: Path = attr.ib() def array_identifier(self, dataset) -> LocalArrayIdentifier: + """ + Creates a local array identifier for the given dataset. + + Args: + dataset: The dataset for which to create the array identifier. + Returns: + LocalArrayIdentifier: The local array identifier. + Raises: + TypeError: If the dataset is not a string. + Examples: + >>> container = Path('path/to/container') + >>> container.array_identifier('dataset') + LocalArrayIdentifier(container=Path('path/to/container'), dataset='dataset') + """ return LocalArrayIdentifier(self.container, dataset) class ArrayStore(ABC): - """Base class for array stores. + """ + Base class for array stores. Creates identifiers for the caller to create and write arrays. Provides only rudimentary support for IO itself (currently only to remove - arrays).""" + arrays). + + Attributes: + container (Path): The path to the container. + dataset (str): The dataset name. + Method: + __str__ : Returns the string representation of the identifier. + + """ @abstractmethod def validation_prediction_array( self, run_name: str, iteration: int, dataset: str ) -> LocalArrayIdentifier: - """Get the array identifier for a particular validation prediction.""" + """ + Get the array identifier for a particular validation prediction. + + Args: + run_name: The name of the run. + iteration: The iteration number. + dataset: The dataset name. + Returns: + LocalArrayIdentifier: The array identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> validation_prediction_array('run_name', 1, 'dataset') + LocalArrayIdentifier(container=Path('path/to/container'), dataset='dataset') + """ pass @abstractmethod def validation_output_array( self, run_name: str, iteration: int, parameters: str, dataset: str ) -> LocalArrayIdentifier: - """Get the array identifier for a particular validation output.""" + """ + Get the array identifier for a particular validation output. + + Args: + run_name: The name of the run. + iteration: The iteration number. + parameters: The parameters. + dataset: The dataset name. + Returns: + LocalArrayIdentifier: The array identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> validation_output_array('run_name', 1, 'parameters', 'dataset') + LocalArrayIdentifier(container=Path('path/to/container'), dataset='dataset') + """ pass @abstractmethod @@ -58,18 +130,50 @@ def validation_input_arrays( and figure out where to find the inputs for each run. If we write the data then we don't need to search for it. This convenience comes at the cost of some extra memory usage. + + Args: + run_name: The name of the run. + index: The index of the validation input. + Returns: + Tuple[LocalArrayIdentifier, LocalArrayIdentifier]: The array identifiers. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> validation_input_arrays('run_name', 'index') + (LocalArrayIdentifier(container=Path('path/to/container'), dataset='dataset'), LocalArrayIdentifier(container=Path('path/to/container'), dataset='dataset')) + """ pass @abstractmethod def remove(self, array_identifier: "LocalArrayIdentifier") -> None: - """Remove an array by its identifier.""" + """ + Remove an array by its identifier. + + Args: + array_identifier: The array identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> remove(LocalArrayIdentifier(container=Path('path/to/container'), dataset='dataset')) + + """ pass @abstractmethod def snapshot_container(self, run_name: str) -> LocalContainerIdentifier: """ Get a container identifier for storage of a snapshot. + + Args: + run_name: The name of the run. + Returns: + LocalContainerIdentifier: The container identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> snapshot_container('run_name') + LocalContainerIdentifier(container=Path('path/to/container')) """ pass @@ -77,10 +181,34 @@ def snapshot_container(self, run_name: str) -> LocalContainerIdentifier: def validation_container(self, run_name: str) -> LocalContainerIdentifier: """ Get a container identifier for storage of a snapshot. + + Args: + run_name: The name of the run. + Returns: + LocalContainerIdentifier: The container identifier. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> validation_container('run_name') + LocalContainerIdentifier(container=Path('path/to/container')) """ pass def _visualize_training(self, run): + """ + Returns a neuroglancer link to visualize snapshots and validations. + + Args: + run: The run. + Returns: + str: The neuroglancer link. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> _visualize_training(run) + 'http://neuroglancer-demo.appspot.com/#!{}' + + """ # returns a neuroglancer link to visualize snapshots and validations snapshot_container = self.snapshot_container(run.name) validation_container = self.validation_container(run.name) @@ -91,7 +219,35 @@ def _visualize_training(self, run): validations = [] def generate_groups(container): + """ + Generate groups for snapshots and validations. + + Args: + container: The container. + Returns: + function: The add_element function. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> generate_groups(container) + function + + """ + def add_element(name, obj): + """ + Add elements to the container. + + Args: + name: The name of the element. + obj: The object. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> add_element('name', 'obj') + None + + """ if isinstance(obj, zarr.hierarchy.Array): container.append(name) diff --git a/dacapo/store/config_store.py b/dacapo/store/config_store.py index 87fa6edb0..7a06e0081 100644 --- a/dacapo/store/config_store.py +++ b/dacapo/store/config_store.py @@ -11,11 +11,66 @@ class DuplicateNameError(Exception): + """ + Exception raised when trying to store a config with a name that already + exists. + + Attributes: + message (str): The error message. + Methods: + __str__: Return the error message. + + """ + pass class ConfigStore(ABC): - """Base class for configuration stores.""" + """ + Base class for configuration stores. + + Attributes: + runs (Any): The runs store. + datasplits (Any): The datasplits store. + datasets (Any): The datasets store. + arrays (Any): The arrays store. + tasks (Any): The tasks store. + trainers (Any): The trainers store. + architectures (Any): The architectures store. + Methods: + delete_config: Delete a config from a store. + store_run_config: Store a run config. + retrieve_run_config: Retrieve a run config from a run name. + retrieve_run_config_names: Retrieve all run config names. + delete_run_config: Delete a run config. + store_task_config: Store a task config. + retrieve_task_config: Retrieve a task config from a task name. + retrieve_task_config_names: Retrieve all task config names. + delete_task_config: Delete a task config. + store_architecture_config: Store a architecture config. + retrieve_architecture_config: Retrieve a architecture config from a architecture name. + retrieve_architecture_config_names: Retrieve all architecture config names. + delete_architecture_config: Delete a architecture config. + store_trainer_config: Store a trainer config. + retrieve_trainer_config: Retrieve a trainer config from a trainer name. + retrieve_trainer_config_names: Retrieve all trainer config names. + delete_trainer_config: Delete a trainer config. + store_datasplit_config: Store a datasplit config. + retrieve_datasplit_config: Retrieve a datasplit config from a datasplit name. + retrieve_datasplit_config_names: Retrieve all datasplit names. + delete_datasplit_config: Delete a datasplit config. + store_array_config: Store a array config. + retrieve_array_config: Retrieve a array config from a array name. + retrieve_array_config_names: Retrieve all array names. + delete_array_config: Delete a array config. + Note: + This class is an abstract base class for configuration stores. It + defines the interface for storing and retrieving configuration objects + (e.g., run, task, architecture, trainer, datasplit, dataset, array + configs). Concrete implementations of this class should define how + these objects are stored and retrieved (e.g., in a database, in files). + + """ runs: Any datasplits: Any @@ -27,99 +82,322 @@ class ConfigStore(ABC): @abstractmethod def delete_config(self, database, config_name: str) -> None: + """ + Delete a config from a store. + + Args: + database (Any): The store to delete the config from. + config_name (str): The name of the config to delete. + Raises: + KeyError: If the config does not exist. + Examples: + >>> store.delete_config(store.runs, "run1") + + """ pass @abstractmethod def store_run_config(self, run_config: "RunConfig") -> None: - """Store a run config. This should also store the configs that are part + """ + Store a run config. This should also store the configs that are part of the run config (i.e., task, architecture, trainer, and dataset - config).""" + config). + + Args: + run_config (RunConfig): The run config to store. + Raises: + DuplicateNameError: If a run config with the same name already + exists. + Examples: + >>> store.store_run_config(run_config) + + """ pass @abstractmethod def retrieve_run_config(self, run_name: str) -> "RunConfig": - """Retrieve a run config from a run name.""" + """ + Retrieve a run config from a run name. + + Args: + run_name (str): The name of the run config to retrieve. + Returns: + RunConfig: The run config with the given name. + Raises: + KeyError: If the run config does not exist. + Examples: + >>> run_config = store.retrieve_run_config("run1") + + """ pass @abstractmethod def retrieve_run_config_names(self) -> List[str]: - """Retrieve all run config names.""" + """ + Retrieve all run config names. + + Returns: + List[str]: The names of all run configs. + Raises: + KeyError: If no run configs exist. + Examples: + >>> run_names = store.retrieve_run_config_names() + """ pass def delete_run_config(self, run_name: str) -> None: + """ + Delete a run config from the store. + + Args: + run_name (str): The name of the run config to delete. + Raises: + KeyError: If the run config does not exist. + Examples: + >>> store.delete_run_config("run1") + + """ self.delete_config(self.runs, run_name) @abstractmethod def store_task_config(self, task_config: "TaskConfig") -> None: - """Store a task config.""" + """ + Store a task config. + + Args: + task_config (TaskConfig): The task config to store. + Raises: + DuplicateNameError: If a task config with the same name already + exists. + Examples: + >>> store.store_task_config(task_config) + + """ pass @abstractmethod def retrieve_task_config(self, task_name: str) -> "TaskConfig": - """Retrieve a task config from a task name.""" + """ + Retrieve a task config from a task name. + + Args: + task_name (str): The name of the task config to retrieve. + Returns: + TaskConfig: The task config with the given name. + Raises: + KeyError: If the task config does not exist. + Examples: + >>> task_config = store.retrieve_task_config("task1") + + """ pass @abstractmethod def retrieve_task_config_names(self) -> List[str]: - """Retrieve all task config names.""" + """ + Retrieve all task config names. + + Args: + List[str]: The names of all task configs. + Returns: + List[str]: The names of all task configs. + Raises: + KeyError: If no task configs exist. + Examples: + >>> task_names = store.retrieve_task_config_names() + + """ pass def delete_task_config(self, task_name: str) -> None: + """ + Delete a task config from the store. + + Args: + task_name (str): The name of the task config to delete. + Raises: + KeyError: If the task config does not exist. + Examples: + >>> store.delete_task_config("task1") + + """ self.delete_config(self.tasks, task_name) @abstractmethod def store_architecture_config( self, architecture_config: "ArchitectureConfig" ) -> None: - """Store a architecture config.""" + """ + Store a architecture config. + + Args: + architecture_config (ArchitectureConfig): The architecture config + to store. + Raises: + DuplicateNameError: If a architecture config with the same name + already exists. + Examples: + >>> store.store_architecture_config(architecture_config) + """ pass @abstractmethod def retrieve_architecture_config( self, architecture_name: str ) -> "ArchitectureConfig": - """Retrieve a architecture config from a architecture name.""" + """ + Retrieve a architecture config from a architecture name. + + Args: + architecture_name (str): The name of the architecture config to + retrieve. + Returns: + ArchitectureConfig: The architecture config with the given name. + Raises: + KeyError: If the architecture config does not exist. + Examples: + >>> architecture_config = store.retrieve_architecture_config("architecture1") + """ pass @abstractmethod def retrieve_architecture_config_names(self) -> List[str]: - """Retrieve all architecture config names.""" + """ + Retrieve all architecture config names. + + Args: + List[str]: The names of all architecture configs. + Returns: + List[str]: The names of all architecture configs. + Raises: + KeyError: If no architecture configs exist. + Examples: + >>> architecture_names = store.retrieve_architecture_config_names() + """ pass def delete_architecture_config(self, architecture_name: str) -> None: + """ + Delete a architecture config from the store. + + Args: + architecture_name (str): The name of the architecture config to + delete. + Raises: + KeyError: If the architecture config does not exist. + Examples: + >>> store.delete_architecture_config("architecture1") + """ self.delete_config(self.architectures, architecture_name) @abstractmethod def store_trainer_config(self, trainer_config: "TrainerConfig") -> None: - """Store a trainer config.""" + """ + Store a trainer config. + + Args: + trainer_config (TrainerConfig): The trainer config to store. + Raises: + DuplicateNameError: If a trainer config with the same name already + exists. + Examples: + >>> store.store_trainer_config(trainer_config) + """ pass @abstractmethod def retrieve_trainer_config(self, trainer_name: str) -> None: - """Retrieve a trainer config from a trainer name.""" + """ + Retrieve a trainer config from a trainer name. + + Args: + trainer_name (str): The name of the trainer config to retrieve. + Returns: + TrainerConfig: The trainer config with the given name. + Raises: + KeyError: If the trainer config does not exist. + Examples: + >>> trainer_config = store.retrieve_trainer_config("trainer1") + """ pass @abstractmethod def retrieve_trainer_config_names(self) -> List[str]: - """Retrieve all trainer config names.""" + """ + Retrieve all trainer config names. + + Args: + List[str]: The names of all trainer configs. + Returns: + List[str]: The names of all trainer configs. + Raises: + KeyError: If no trainer configs exist. + Examples: + >>> trainer_names = store.retrieve_trainer_config_names() + + """ pass def delete_trainer_config(self, trainer_name: str) -> None: + """ + Delete a trainer config from the store. + + Args: + trainer_name (str): The name of the trainer config to delete. + Raises: + KeyError: If the trainer config does not exist. + Examples: + >>> store.delete_trainer_config("trainer1") + + """ self.delete_config(self.trainers, trainer_name) @abstractmethod def store_datasplit_config(self, datasplit_config: "DataSplitConfig") -> None: - """Store a datasplit config.""" + """ + Store a datasplit config. + + Args: + datasplit_config (DataSplitConfig): The datasplit config to store. + Raises: + DuplicateNameError: If a datasplit config with the same name already + exists. + Examples: + >>> store.store_datasplit_config(datasplit_config) + """ pass @abstractmethod def retrieve_datasplit_config(self, datasplit_name: str) -> "DataSplitConfig": - """Retrieve a datasplit config from a datasplit name.""" + """ + Retrieve a datasplit config from a datasplit name. + + Args: + datasplit_name (str): The name of the datasplit config to retrieve. + Returns: + DataSplitConfig: The datasplit config with the given name. + Raises: + KeyError: If the datasplit config does not exist. + Examples: + >>> datasplit_config = store.retrieve_datasplit_config("datasplit1") + """ pass @abstractmethod def retrieve_datasplit_config_names(self) -> List[str]: - """Retrieve all datasplit names.""" + """ + Retrieve all datasplit names. + + Args: + List[str]: The names of all datasplit configs. + Returns: + List[str]: The names of all datasplit configs. + Raises: + KeyError: If no datasplit configs exist. + Examples: + >>> datasplit_names = store.retrieve_datasplit_config_names() + + """ pass def delete_datasplit_config(self, datasplit_name: str) -> None: @@ -127,18 +405,60 @@ def delete_datasplit_config(self, datasplit_name: str) -> None: @abstractmethod def store_array_config(self, array_config: "ArrayConfig") -> None: - """Store a array config.""" + """ + Store a array config. + + Args: + array_config (ArrayConfig): The array config to store. + Raises: + DuplicateNameError: If a array config with the same name already + exists. + Examples: + >>> store.store_array_config(array_config) + """ pass @abstractmethod def retrieve_array_config(self, array_name: str) -> "ArrayConfig": - """Retrieve a array config from a array name.""" + """ + Retrieve a array config from a array name. + + Args: + array_name (str): The name of the array config to retrieve. + Returns: + ArrayConfig: The array config with the given name. + Raises: + KeyError: If the array config does not exist. + Examples: + >>> array_config = store.retrieve_array_config("array1") + """ pass @abstractmethod def retrieve_array_config_names(self) -> List[str]: - """Retrieve all array names.""" + """ + Retrieve all array names. + + Args: + List[str]: The names of all array configs. + Returns: + List[str]: The names of all array configs. + Raises: + KeyError: If no array configs exist. + Examples: + >>> array_names = store.retrieve_array_config_names() + """ pass def delete_array_config(self, array_name: str) -> None: + """ + Delete a array config from the store. + + Args: + array_name (str): The name of the array config to delete. + Raises: + KeyError: If the array config does not exist. + Examples: + >>> store.delete_array_config("array1") + """ self.delete_config(self.arrays, array_name) diff --git a/dacapo/store/conversion_hooks.py b/dacapo/store/conversion_hooks.py index 934b4e47b..f90dd5f79 100644 --- a/dacapo/store/conversion_hooks.py +++ b/dacapo/store/conversion_hooks.py @@ -14,11 +14,39 @@ from funlib.geometry import Coordinate, Roi -from pathlib import Path +from upath import UPath as Path def register_hierarchy_hooks(converter): - """Central place to register type hierarchies for conversion.""" + """ + Central place to register type hierarchies for conversion. + + Args: + converter (Converter): The converter to register the hooks with. + Raises: + TypeError: If ``cls`` is not a class. + Example: + If class ``A`` is the base of class ``B``, and + ``converter.register_hierarchy(A, lambda typ: eval(typ))`` has been + called, the dictionary ``y = converter.unstructure(x)`` will + contain a ``__type__`` field that is ``'A'`` if ``x = A()`` and + ``B`` if ``x = B()``. + + This ``__type__`` field is then used by ``x = + converter.structure(y, A)`` to recreate the concrete type of ``x``. + Note: + This method is used to register a class hierarchy for typed + structure/unstructure conversion. For each class in the hierarchy + under (including) ``cls``, this will store an additional + ``__type__`` attribute (a string) in the object dictionary. This + ``__type__`` string will be the concrete class of the object, and + will be used to structure the dictionary back into an object of the + correct class. + + For this to work, this function needs to know how to convert a + ``__type__`` string back into a class, for which it used the + provided ``cls_fn``. + """ converter.register_hierarchy(TaskConfig, cls_fun) converter.register_hierarchy(StartConfig, cls_fun) @@ -34,8 +62,36 @@ def register_hierarchy_hooks(converter): def register_hooks(converter): - """Central place to register all conversion hooks with the given - converter.""" + """ + Central place to register all conversion hooks with the given + converter. + + Args: + converter (Converter): The converter to register the hooks with. + Raises: + TypeError: If ``cls`` is not a class. + Example: + If class ``A`` is the base of class ``B``, and + ``converter.register_hierarchy(A, lambda typ: eval(typ))`` has been + called, the dictionary ``y = converter.unstructure(x)`` will + contain a ``__type__`` field that is ``'A'`` if ``x = A()`` and + ``B`` if ``x = B()``. + + This ``__type__`` field is then used by ``x = + converter.structure(y, A)`` to recreate the concrete type of ``x``. + Note: + This method is used to register a class hierarchy for typed + structure/unstructure conversion. For each class in the hierarchy + under (including) ``cls``, this will store an additional + ``__type__`` attribute (a string) in the object dictionary. This + ``__type__`` string will be the concrete class of the object, and + will be used to structure the dictionary back into an object of the + correct class. + + For this to work, this function needs to know how to convert a + ``__type__`` string back into a class, for which it used the + provided ``cls_fn``. + """ ######################### # DaCapo specific hooks # @@ -80,6 +136,21 @@ def register_hooks(converter): def cls_fun(typ): - """Convert a type string into the corresponding class. The class must be - visible to this module (hence the star imports at the top).""" + """ + Convert a type string into the corresponding class. The class must be + visible to this module (hence the star imports at the top). + + Args: + typ (str): The type string to convert. + Returns: + class: The class corresponding to the type string. + Raises: + NameError: If the class is not visible to this module. + Example: + ``cls_fun('TaskConfig')`` will return the class ``TaskConfig``. + Note: + This function is used to convert a type string back into a class. It is + used in conjunction with the ``register_hierarchy`` function to + register a class hierarchy for typed structure/unstructure conversion. + """ return eval(typ) diff --git a/dacapo/store/converter.py b/dacapo/store/converter.py index d50ca0225..62bb2f4df 100644 --- a/dacapo/store/converter.py +++ b/dacapo/store/converter.py @@ -6,10 +6,28 @@ class TypedConverter(Converter): """A converter that stores and retrieves type information for selected class hierarchies. Used to reconstruct a concrete class from unstructured - data.""" + data. + + Attributes: + hooks (Dict[Type, List[Hook]]): A dictionary mapping classes to lists of + hooks that should be applied to them. + Methods: + register_hierarchy: Register a class hierarchy for typed + structure/unstructure conversion. + __typed_unstructure: Unstructure an object, adding a '__type__' field + with the class name. + __typed_structure: Structure an object, using the '__type__' field to + determine the class. + Note: + This class is a subclass of cattr.Converter, and extends it with the + ability to store and retrieve type information for selected class + hierarchies. This is useful for reconstructing a concrete class from + unstructured data. + """ def register_hierarchy(self, cls, cls_fn): - """Register a class hierarchy for typed structure/unstructure + """ + Register a class hierarchy for typed structure/unstructure conversion. For each class in the hierarchy under (including) ``cls``, this will @@ -23,17 +41,16 @@ def register_hierarchy(self, cls, cls_fn): ``cls_fn``. Args: - cls (class): The top-level class of the hierarchy to register. - cls_fn (function): A function mapping type strings to classes. This can be as simple as ``lambda typ: eval(typ)``, if all subclasses of ``cls`` are visible to the module that calls this method. - + Raises: + TypeError: If ``cls`` is not a class. Example: If class ``A`` is the base of class ``B``, and @@ -44,6 +61,18 @@ def register_hierarchy(self, cls, cls_fn): This ``__type__`` field is then used by ``x = converter.structure(y, A)`` to recreate the concrete type of ``x``. + Note: + This method is used to register a class hierarchy for typed + structure/unstructure conversion. For each class in the hierarchy + under (including) ``cls``, this will store an additional + ``__type__`` attribute (a string) in the object dictionary. This + ``__type__`` string will be the concrete class of the object, and + will be used to structure the dictionary back into an object of the + correct class. + + For this to work, this function needs to know how to convert a + ``__type__`` string back into a class, for which it used the + provided ``cls_fn``. """ self.register_unstructure_hook(cls, lambda obj: self.__typed_unstructure(obj)) @@ -53,14 +82,50 @@ def register_hierarchy(self, cls, cls_fn): ) def __typed_unstructure(self, obj): + """ + Unstructure an object, adding a '__type__' field with the class name. + + Args: + obj (object): The object to unstructure. + Returns: + Dict: The unstructured object. + Examples: + >>> converter.__typed_unstructure(A()) + {'__type__': 'A'} + """ cls = type(obj) unstructure_fn = make_dict_unstructure_fn(cls, self) return {"__type__": type(obj).__name__, **unstructure_fn(obj)} def __typed_structure(self, obj_data, cls, cls_fn): - cls = cls_fn(obj_data["__type__"]) - structure_fn = make_dict_structure_fn(cls, self) - return structure_fn(obj_data, cls) + """ + Structure an object, using the '__type__' field to determine the class. + + Args: + obj_data (Dict): The unstructured object. + cls (class): The class + cls_fn (function): A function mapping type strings to classes. + Returns: + object: The structured object. + Raises: + ValueError: If the '__type__' field is missing. + Examples: + >>> converter.__typed_structure({'__type__': 'A'}, A, lambda typ: eval(typ + 'A') + Note: + This method is used to structure an object, using the '__type__' field + to determine the class. This is useful for reconstructing a concrete + class from unstructured data. + """ + try: + cls = cls_fn(obj_data["__type__"]) + structure_fn = make_dict_structure_fn(cls, self) + return structure_fn(obj_data, cls) + except: + print( + f"Could not structure object of type {obj_data}. will try unstructured data. attr __type__ can be missing because of old version of the data." + ) + return obj_data # The global converter object, to be used by stores to convert objects into diff --git a/dacapo/store/create_store.py b/dacapo/store/create_store.py index 0fcc43ed2..e04060e90 100644 --- a/dacapo/store/create_store.py +++ b/dacapo/store/create_store.py @@ -6,11 +6,23 @@ from .file_stats_store import FileStatsStore from dacapo import Options -from pathlib import Path +from upath import UPath as Path def create_config_store(): - """Create a config store based on the global DaCapo options.""" + """ + Create a config store based on the global DaCapo options. + + Returns: + ConfigStore: The created config store. + Raises: + ValueError: If the store type is not supported. + Examples: + >>> create_config_store() + + Note: + Currently, only the FileConfigStore and MongoConfigStore are supported. + """ options = Options.instance() @@ -19,14 +31,29 @@ def create_config_store(): db_name = options.mongo_db_name return MongoConfigStore(db_host, db_name) elif options.type == "files": - store_path = Path(options.runs_base_dir).expanduser() + store_path = Path(options.runs_base_dir) return FileConfigStore(store_path / "configs") else: raise ValueError(f"Unknown store type {options.type}") def create_stats_store(): - """Create a statistics store based on the global DaCapo options.""" + """ + Create a statistics store based on the global DaCapo options. + + Args: + options (Options): The global DaCapo options. + Returns: + StatsStore: The created statistics store. + Raises: + ValueError: If the store type is not supported. + Examples: + >>> create_stats_store() + + Note: + Currently, only the FileStatsStore and MongoStatsStore are supported. + + """ options = Options.instance() @@ -35,24 +62,49 @@ def create_stats_store(): db_name = options.mongo_db_name return MongoStatsStore(db_host, db_name) elif options.type == "files": - store_path = Path(options.runs_base_dir).expanduser() + store_path = Path(options.runs_base_dir) return FileStatsStore(store_path / "stats") else: raise ValueError(f"Unknown store type {options.type}") def create_weights_store(): - """Create a weights store based on the global DaCapo options.""" + """ + Create a weights store based on the global DaCapo options. + + Args: + options (Options): The global DaCapo options. + Returns: + WeightsStore: The created weights store. + Examples: + >>> create_weights_store() + + Note: + Currently, only the LocalWeightsStore is supported. + """ options = Options.instance() - # currently, only the LocalWeightsStore is supported - base_dir = Path(options.runs_base_dir).expanduser() + base_dir = Path(options.runs_base_dir) return LocalWeightsStore(base_dir) def create_array_store(): - """Create an array store based on the global DaCapo options.""" + """ + Create an array store based on the global DaCapo options. + + Args: + options (Options): The global DaCapo options. + Returns: + ArrayStore: The created array store. + Raises: + ValueError: If the store type is not supported. + Examples: + >>> create_array_store() + + Note: + Currently, only the LocalArrayStore is supported. + """ options = Options.instance() diff --git a/dacapo/store/file_config_store.py b/dacapo/store/file_config_store.py index 014fe4fef..55543b462 100644 --- a/dacapo/store/file_config_store.py +++ b/dacapo/store/file_config_store.py @@ -9,17 +9,55 @@ import logging import yaml -from pathlib import Path +from upath import UPath as Path logger = logging.getLogger(__name__) class FileConfigStore(ConfigStore): - """A Local File based store for configurations. Used to store and retrieve + """ + A Local File based store for configurations. Used to store and retrieve configurations for runs, tasks, architectures, trainers, and datasplits. + + Attributes: + path (Path): The path to the file. + Methods: + store_run_config(run_config, ignore=None): Stores the run configuration in the file config store. + retrieve_run_config(run_name): Retrieve the run configuration for a given run name. + retrieve_run_config_names(): Retrieve the names of the run configurations. + store_task_config(task_config, ignore=None): Stores the task configuration in the file config store. + retrieve_task_config(task_name): Retrieve the task configuration for a given task name. + retrieve_task_config_names(): Retrieve the names of the task configurations. + store_architecture_config(architecture_config, ignore=None): Stores the architecture configuration in the file config store. + retrieve_architecture_config(architecture_name): Retrieve the architecture configuration for a given architecture name. + retrieve_architecture_config_names(): Retrieve the names of the architecture configurations. + store_trainer_config(trainer_config, ignore=None): Stores the trainer configuration in the file config store. + retrieve_trainer_config(trainer_name): Retrieve the trainer configuration for a given trainer name. + retrieve_trainer_config_names(): Retrieve the names of the trainer configurations. + store_datasplit_config(datasplit_config, ignore=None): Stores the datasplit configuration in the file config store. + retrieve_datasplit_config(datasplit_name): Retrieve the datasplit configuration for a given datasplit name. + retrieve_datasplit_config_names(): Retrieve the names of the datasplit configurations. + store_array_config(array_config, ignore=None): Stores the array configuration in the file config store. + retrieve_array_config(array_name): Retrieve the array configuration for a given array name. + retrieve_array_config_names(): Retrieve the names of the array configurations. + __save_insert(collection, data, ignore=None): Saves the data to the collection. + __load(collection, name): Loads the data + Notes: + The FileConfigStore is used to store and retrieve configurations for runs, tasks, architectures, trainers, and datasplits. + The FileConfigStore is a local file based store for configurations. """ def __init__(self, path): + """ + Initializes a new instance of the FileConfigStore class. + + Args: + path (str): The path to the file. + Raises: + ValueError: If the path is not a valid directory. + Examples: + >>> store = FileConfigStore("path/to/configs") + """ print(f"Creating FileConfigStore:\n\tpath: {path}") self.path = Path(path) @@ -28,72 +66,295 @@ def __init__(self, path): self.__init_db() def store_run_config(self, run_config, ignore=None): + """ + Stores the run configuration in the file config store. + + Args: + run_config (RunConfig): The run configuration to store. + ignore (list, optional): A list of keys to ignore when comparing the stored configuration with the new configuration. Defaults to None. + Raises: + DuplicateNameError: If a configuration with the same name already exists. + Examples: + >>> store.store_run_config(run_config) + """ run_doc = converter.unstructure(run_config) self.__save_insert(self.runs, run_doc, ignore) def retrieve_run_config(self, run_name): + """ + Retrieve the run configuration for a given run name. + + Args: + run_name (str): The name of the run configuration to retrieve. + Returns: + RunConfig: The run configuration object. + Raises: + KeyError: If the run name does not exist in the store. + Examples: + >>> run_config = store.retrieve_run_config("run1") + + """ run_doc = self.__load(self.runs, run_name) return converter.structure(run_doc, RunConfig) def retrieve_run_config_names(self): + """ + Retrieve the names of the run configurations. + + Returns: + A list of run configuration names. + Raises: + KeyError: If no run configurations are stored. + Examples: + >>> run_names = store.retrieve_run_config_names() + + """ return [f.name[:-5] for f in self.runs.iterdir()] def store_task_config(self, task_config, ignore=None): + """ + Stores the task configuration in the file config store. + + Args: + task_config (TaskConfig): The task configuration to store. + ignore (list, optional): A list of keys to ignore when comparing the stored configuration with the new configuration. Defaults to None. + Raises: + DuplicateNameError: If a configuration with the same name already exists. + Examples: + >>> store.store_task_config(task_config) + + """ task_doc = converter.unstructure(task_config) self.__save_insert(self.tasks, task_doc, ignore) def retrieve_task_config(self, task_name): + """ + Retrieve the task configuration for a given task name. + + Args: + task_name (str): The name of the task configuration to retrieve. + Returns: + TaskConfig: The task configuration object. + Raises: + KeyError: If the task name does not exist in the store. + Examples: + >>> task_config = store.retrieve_task_config("task1") + + """ task_doc = self.__load(self.tasks, task_name) return converter.structure(task_doc, TaskConfig) def retrieve_task_config_names(self): + """ + Retrieve the names of the task configurations. + + Returns: + A list of task configuration names. + Raises: + KeyError: If no task configurations are stored. + Examples: + >>> task_names = store.retrieve_task_config_names() + """ return [f.name[:-5] for f in self.tasks.iterdir()] def store_architecture_config(self, architecture_config, ignore=None): + """ + Stores the architecture configuration in the file config store. + + Args: + architecture_config (ArchitectureConfig): The architecture configuration to store. + ignore (list, optional): A list of keys to ignore when comparing the stored configuration with the new configuration. Defaults to None. + Raises: + DuplicateNameError: If a configuration with the same name already exists. + Examples: + >>> store.store_architecture_config(architecture_config) + """ architecture_doc = converter.unstructure(architecture_config) self.__save_insert(self.architectures, architecture_doc, ignore) def retrieve_architecture_config(self, architecture_name): + """ + Retrieve the architecture configuration for a given architecture name. + + Args: + architecture_name (str): The name of the architecture configuration to retrieve. + Returns: + ArchitectureConfig: The architecture configuration object. + Raises: + KeyError: If the architecture name does not exist in the store. + Examples: + >>> architecture_config = store.retrieve_architecture_config("architecture1") + """ architecture_doc = self.__load(self.architectures, architecture_name) return converter.structure(architecture_doc, ArchitectureConfig) def retrieve_architecture_config_names(self): + """ + Retrieve the names of the architecture configurations. + + Returns: + A list of architecture configuration names. + Raises: + KeyError: If no architecture configurations are stored. + Examples: + >>> architecture_names = store.retrieve_architecture_config_names() + """ return [f.name[:-5] for f in self.architectures.iterdir()] def store_trainer_config(self, trainer_config, ignore=None): + """ + Stores the trainer configuration in the file config store. + + Args: + trainer_config (TrainerConfig): The trainer configuration to store. + ignore (list, optional): A list of keys to ignore when comparing the stored configuration with the new configuration. Defaults to None. + Raises: + DuplicateNameError: If a configuration with the same name already exists. + Examples: + >>> store.store_trainer_config(trainer_config) + """ trainer_doc = converter.unstructure(trainer_config) self.__save_insert(self.trainers, trainer_doc, ignore) def retrieve_trainer_config(self, trainer_name): + """ + Retrieve the trainer configuration for a given trainer name. + + Args: + trainer_name (str): The name of the trainer configuration to retrieve. + Returns: + TrainerConfig: The trainer configuration object. + Raises: + KeyError: If the trainer name does not exist in the store. + Examples: + >>> trainer_config = store.retrieve_trainer_config("trainer1") + + """ trainer_doc = self.__load(self.trainers, trainer_name) return converter.structure(trainer_doc, TrainerConfig) def retrieve_trainer_config_names(self): + """ + Retrieve the names of the trainer configurations. + + Args: + trainer_name (str): The name of the trainer configuration to retrieve. + Returns: + TrainerConfig: The trainer configuration object. + Raises: + KeyError: If the trainer name does not exist in the store. + Examples: + >>> trainer_config = store.retrieve_trainer_config("trainer1") + """ return [f.name[:-5] for f in self.trainers.iterdir()] def store_datasplit_config(self, datasplit_config, ignore=None): + """ + Stores the datasplit configuration in the file config store. + + Args: + datasplit_config (DataSplitConfig): The datasplit configuration to store. + ignore (list, optional): A list of keys to ignore when comparing the stored configuration with the new configuration. Defaults to None. + Raises: + DuplicateNameError: If a configuration with the same name already exists. + Examples: + >>> store.store_datasplit_config(datasplit_config) + """ datasplit_doc = converter.unstructure(datasplit_config) self.__save_insert(self.datasplits, datasplit_doc, ignore) def retrieve_datasplit_config(self, datasplit_name): + """ + Retrieve the datasplit configuration for a given datasplit name. + + Args: + datasplit_name (str): The name of the datasplit configuration to retrieve. + Returns: + DataSplitConfig: The datasplit configuration object. + Raises: + KeyError: If the datasplit name does not exist in the store. + Examples: + >>> datasplit_config = store.retrieve_datasplit_config("datasplit1") + + """ datasplit_doc = self.__load(self.datasplits, datasplit_name) return converter.structure(datasplit_doc, DataSplitConfig) def retrieve_datasplit_config_names(self): + """ + Retrieve the names of the datasplit configurations. + + Args: + datasplit_name (str): The name of the datasplit configuration to retrieve. + Returns: + DataSplitConfig: The datasplit configuration object. + Raises: + KeyError: If the datasplit name does not exist in the store. + Examples: + >>> datasplit_config = store.retrieve_datasplit_config("datasplit1") + + """ return [f.name[:-5] for f in self.datasplits.iterdir()] def store_array_config(self, array_config, ignore=None): + """ + Stores the array configuration in the file config store. + + Args: + array_config (ArrayConfig): The array configuration to store. + ignore (list, optional): A list of keys to ignore when comparing the stored configuration with the new configuration. Defaults to None. + Raises: + DuplicateNameError: If a configuration with the same name already exists. + Examples: + >>> store.store_array_config(array_config) + + """ array_doc = converter.unstructure(array_config) self.__save_insert(self.arrays, array_doc, ignore) def retrieve_array_config(self, array_name): + """ + Retrieve the array configuration for a given array name. + + Args: + array_name (str): The name of the array configuration to retrieve. + Returns: + ArrayConfig: The array configuration object. + Raises: + KeyError: If the array name does not exist in the store. + Examples: + >>> array_config = store.retrieve_array_config("array1") + """ array_doc = self.__load(self.arrays, array_name) return converter.structure(array_doc, ArrayConfig) def retrieve_array_config_names(self): + """ + Retrieve the names of the array configurations. + + Returns: + A list of array configuration names. + Raises: + KeyError: If no array configurations are stored. + Examples: + >>> array_names = store.retrieve_array_config_names() + + """ return [f.name[:-5] for f in self.arrays.iterdir()] def __save_insert(self, collection, data, ignore=None): + """ + Saves the data to the collection. + + Args: + collection (Path): The path to the collection. + data (dict): The data to store. + ignore (list, optional): A list of keys to ignore when comparing the stored configuration with the new configuration. Defaults to None. + Raises: + DuplicateNameError: If a configuration with the same name already exists. + Examples: + >>> store.__save_insert(collection, data) + """ name = data["name"] file_store = collection / f"{name}.yaml" @@ -113,6 +374,19 @@ def __save_insert(self, collection, data, ignore=None): ) def __load(self, collection, name): + """ + Loads the data from the collection. + + Args: + collection (Path): The path to the collection. + name (str): The name of the data to load. + Returns: + The data from the collection. + Raises: + ValueError: If the config with the name does not exist in the collection. + Examples: + >>> store.__load(collection, name) + """ file_store = collection / f"{name}.yaml" if file_store.exists(): with file_store.open("r") as f: @@ -121,6 +395,20 @@ def __load(self, collection, name): raise ValueError(f"No config with name: {name} in collection: {collection}") def __same_doc(self, a, b, ignore=None): + """ + Compares two dictionaries for equality, ignoring certain keys. + + Args: + a (dict): The first dictionary to compare. + b (dict): The second dictionary to compare. + ignore (list, optional): A list of keys to ignore. Defaults to None. + Returns: + bool: True if the dictionaries are equal, False otherwise. + Raises: + KeyError: If the keys do not match. + Examples: + >>> store.__same_doc(a, b) + """ if ignore: a = dict(a) b = dict(b) @@ -133,11 +421,29 @@ def __same_doc(self, a, b, ignore=None): return a == b def __init_db(self): + """ + Initializes the FileConfigStore database. + Adds the collections for the FileConfigStore. + + Raises: + FileNotFoundError: If the collections do not exist. + Examples: + >>> store.__init_db() + + """ # no indexing for filesystem # please only use this config store for debugging pass def __open_collections(self): + """ + Opens the collections for the FileConfigStore. + + Raises: + FileNotFoundError: If the collections do not exist. + Examples: + >>> store.__open_collections() + """ self.users.mkdir(exist_ok=True, parents=True) self.runs.mkdir(exist_ok=True, parents=True) self.tasks.mkdir(exist_ok=True, parents=True) @@ -148,35 +454,136 @@ def __open_collections(self): @property def users(self) -> Path: + """ + Returns the path to the users directory. + + Returns: + Path: The path to the users directory. + Raises: + FileNotFoundError: If the users directory does not exist. + Examples: + >>> store.users + Path("path/to/configs/users") + """ return self.path / "users" @property def runs(self) -> Path: + """ + Returns the path to the runs directory. + + Returns: + Path: The path to the runs directory. + Raises: + FileNotFoundError: If the runs directory does not exist. + Examples: + >>> store.runs + Path("path/to/configs/runs") + """ return self.path / "runs" @property def tasks(self) -> Path: + """ + Returns the path to the tasks directory. + + Returns: + Path: The path to the tasks directory. + Raises: + FileNotFoundError: If the tasks directory does not exist. + Examples: + >>> store.tasks + Path("path/to/configs/tasks") + """ return self.path / "tasks" @property def datasplits(self) -> Path: + """ + Returns the path to the datasplits directory. + + Returns: + Path: The path to the datasplits directory. + Raises: + FileNotFoundError: If the datasplits directory does not exist. + Examples: + >>> store.datasplits + Path("path/to/configs/datasplits") + """ return self.path / "datasplits" @property def arrays(self) -> Path: + """ + Returns the path to the arrays directory. + + Returns: + Path: The path to the arrays directory. + Raises: + FileNotFoundError: If the arrays directory does not exist. + Examples: + >>> store.arrays + Path("path/to/configs/arrays") + """ return self.path / "arrays" @property def architectures(self) -> Path: + """ + Returns the path to the architectures directory. + + Returns: + Path: The path to the architectures directory. + Raises: + FileNotFoundError: If the architectures directory does not exist. + Examples: + >>> store.architectures + Path("path/to/configs/architectures") + """ return self.path / "architectures" @property def trainers(self) -> Path: + """ + Returns the path to the trainers directory. + + Returns: + Path: The path to the trainers directory. + Raises: + FileNotFoundError: If the trainers directory does not exist. + Examples: + >>> store.trainers + Path("path/to/configs/trainers") + """ return self.path / "trainers" @property def datasets(self) -> Path: + """ + Returns the path to the datasets directory. + + Returns: + Path: The path to the datasets directory. + Raises: + FileNotFoundError: If the datasets directory does not exist. + Examples: + >>> store.datasets + Path("path/to/configs/datasets") + + """ return self.path / "datasets" def delete_config(self, database: Path, config_name: str) -> None: + """ + Deletes a configuration file from the specified database. + + Args: + database (Path): The path to the database where the configuration file is stored. + config_name (str): The name of the configuration file to be deleted. + Raises: + FileNotFoundError: If the configuration file does not exist. + Examples: + >>> store.delete_config(Path("path/to/configs"), "run1") + + """ (database / f"{config_name}.yaml").unlink() diff --git a/dacapo/store/file_stats_store.py b/dacapo/store/file_stats_store.py index 0d2c28e08..72cf9df58 100644 --- a/dacapo/store/file_stats_store.py +++ b/dacapo/store/file_stats_store.py @@ -6,17 +6,47 @@ import logging import pickle -from pathlib import Path +from upath import UPath as Path logger = logging.getLogger(__name__) class FileStatsStore(StatsStore): - """A File based store for run statistics. Used to store and retrieve training - statistics and validation scores. + """A File based store for run statistics. Used to store and retrieve training statistics and validation scores. + + The store is organized as follows: + - A directory for training statistics, with a subdirectory for each run. Each run directory contains a pickled list of TrainingIterationStats objects. + - A directory for validation scores, with a subdirectory for each run. Each run directory contains a pickled list of ValidationIterationScores objects. + + Attributes: + - path: The root directory for the store. + - training_stats: The directory for training statistics. + - validation_scores: The directory for validation scores. + + Methods: + - store_training_stats(run_name, stats): Store the training statistics for a run. + - retrieve_training_stats(run_name): Retrieve the training statistics for a run. + - store_validation_iteration_scores(run_name, scores): Store the validation scores for a run. + - retrieve_validation_iteration_scores(run_name): Retrieve the validation scores for a run. + - delete_training_stats(run_name): Delete the training statistics for a run. + + Note: The store does not support concurrent access. It is intended for use in single-threaded applications. + + """ def __init__(self, path): + """ + Initializes a new instance of the FileStatsStore class. + + Args: + path (str): The path to the file. + Raises: + ValueError: If the path is invalid. + Examples: + >>> store = FileStatsStore("store") + + """ print(f"Creating FileStatsStore:\n\tpath : {path}") self.path = Path(path) @@ -25,6 +55,23 @@ def __init__(self, path): self.__init_db() def store_training_stats(self, run_name, stats): + """ + Stores the training statistics for a specific run. + + Args: + run_name (str): The name of the run. + stats (Stats): The training statistics to be stored. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.store_training_stats("run1", stats) + Notes: + - If the training statistics for the given run already exist in the database, the method will compare the + existing statistics with the new statistics and update or overwrite them accordingly. + - If the new statistics go further than the existing statistics, the method will update the statistics from + the last stored iteration. + - If the new statistics are behind the existing statistics, the method will overwrite the existing statistics. + """ existing_stats = self.__read_training_stats(run_name) store_from_iteration = 0 @@ -51,11 +98,37 @@ def store_training_stats(self, run_name, stats): ) def retrieve_training_stats(self, run_name): + """ + Retrieve the training statistics for a specific run. + + Parameters: + run_name (str): The name of the run for which to retrieve the training statistics. + + Returns: + dict: A dictionary containing the training statistics for the specified run. + """ return self.__read_training_stats(run_name) def store_validation_iteration_scores(self, run_name, scores): - existing_iteration_scores = self.__read_validation_iteration_scores(run_name) - store_from_iteration, drop_db = scores.compare(existing_iteration_scores) + """ + Stores the validation scores for a specific run. + + Args: + run_name (str): The name of the run. + scores (Scores): The validation scores to be stored. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.store_validation_iteration_scores("run1", scores) + Notes: + - If the validation scores for the given run already exist in the database, the method will compare the + existing scores with the new scores and update or overwrite them accordingly. + - If the new scores go further than the existing scores, the method will update the scores from + the last stored iteration. + - If the new scores are behind the existing scores, the method will overwrite the existing scores. + """ + existing_scores = self.__read_validation_iteration_scores(run_name) + store_from_iteration, drop_db = scores.compare(existing_scores) if drop_db: # current scores are behind DB--drop DB @@ -72,12 +145,50 @@ def store_validation_iteration_scores(self, run_name, scores): ) def retrieve_validation_iteration_scores(self, run_name): + """ + Retrieve the validation iteration scores for a given run. + + Args: + run_name (str): The name of the run for which to retrieve the validation iteration scores. + Returns: + list: A list of validation iteration scores. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.retrieve_validation_iteration_scores("run1") + + """ return self.__read_validation_iteration_scores(run_name) def delete_training_stats(self, run_name: str) -> None: + """ + Deletes the training stats for a specific run. + + Args: + run_name (str): The name of the run for which to delete the training stats. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.delete_training_stats("run1") + + """ self.__delete_training_stats(run_name) def __store_training_stats(self, stats, begin, end, run_name): + """ + Store the training statistics for a specific run. + + Args: + stats (Stats): The statistics object containing the training stats. + begin (int): The starting index of the iteration stats to store. + end (int): The ending index of the iteration stats to store. + run_name (str): The name of the run. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.__store_training_stats(stats, 0, 100, "run1") + + """ docs = converter.unstructure(stats.iteration_stats[begin:end]) for doc in docs: doc.update({"run_name": run_name}) @@ -88,6 +199,18 @@ def __store_training_stats(self, stats, begin, end, run_name): pickle.dump(docs, fd) def __read_training_stats(self, run_name): + """ + Read the training statistics for a given run. + + Args: + run_name (str): The name of the run for which to read the training statistics. + Returns: + TrainingStats: The training statistics for the run. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.__read_training_stats("run1") + """ file_store = self.training_stats / run_name if file_store.exists(): with file_store.open("rb") as fd: @@ -98,6 +221,16 @@ def __read_training_stats(self, run_name): return stats def __delete_training_stats(self, run_name): + """ + Deletes the training stats file for a given run. + + Args: + run_name (str): The name of the run for which to delete the training stats. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.__delete_training_stats("run1") + """ file_store = self.training_stats / run_name if file_store.exists(): file_store.unlink() @@ -105,6 +238,20 @@ def __delete_training_stats(self, run_name): def __store_validation_iteration_scores( self, validation_scores: ValidationScores, begin: int, end: int, run_name: str ) -> None: + """ + Store the validation iteration scores. + + Args: + validation_scores (ValidationScores): The validation scores object. + begin (int): The starting iteration index. + end (int): The ending iteration index. + run_name (str): The name of the run. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.__store_validation_iteration_scores(validation_scores, 0, 100, "run1") + + """ docs = [ converter.unstructure(scores) for scores in validation_scores.scores @@ -119,6 +266,18 @@ def __store_validation_iteration_scores( pickle.dump(docs, fd) def __read_validation_iteration_scores(self, run_name): + """ + Read the validation iteration scores for a given run. + + Args: + run_name (str): The name of the run for which to read the validation iteration scores. + Returns: + ValidationScores: The validation iteration scores for the run. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.__read_validation_iteration_scores("run1") + """ file_store = self.validation_scores / run_name if file_store.exists(): with file_store.open("rb") as fd: @@ -129,14 +288,46 @@ def __read_validation_iteration_scores(self, run_name): return scores def __delete_validation_iteration_scores(self, run_name): + """ + Delete the validation iteration scores for a given run. + + Args: + run_name (str): The name of the run for which to delete the validation iteration scores. + Raises: + ValueError: If the run name is invalid. + Examples: + >>> store.__delete_validation_iteration_scores("run1") + + """ file_store = self.validation_scores / run_name if file_store.exists(): file_store.unlink() def __init_db(self): + """ + Initialize the database for the file stats store. + + This method creates the necessary directories for storing training statistics and validation scores. + + Raises: + ValueError: If the path is invalid. + Examples: + >>> store.__init_db() + """ + pass def __open_collections(self): + """ + Open the collections for the file stats store. + + This method initializes the directories for storing training statistics and validation scores. + + Raises: + ValueError: If the path is invalid. + Examples: + >>> store.__open_collections() + """ self.training_stats = self.path / "training_stats" self.training_stats.mkdir(exist_ok=True, parents=True) self.validation_scores = self.path / "validation_scores" diff --git a/dacapo/store/local_array_store.py b/dacapo/store/local_array_store.py index b04b4555d..f47119831 100644 --- a/dacapo/store/local_array_store.py +++ b/dacapo/store/local_array_store.py @@ -1,6 +1,6 @@ from .array_store import ArrayStore, LocalArrayIdentifier, LocalContainerIdentifier -from pathlib import Path +from upath import UPath as Path import logging import shutil from typing import Optional, Tuple @@ -9,61 +9,133 @@ class LocalArrayStore(ArrayStore): - """A local array store that uses zarr containers.""" + """ + A local array store that uses zarr containers. + + Attributes: + basedir: The base directory where the store will write data. + Methods: + best_validation_array: Get the array identifier for the best validation array. + validation_prediction_array: Get the array identifier for a particular validation prediction. + validation_output_array: Get the array identifier for a particular validation output. + validation_input_arrays: Get the array identifiers for the validation input raw/gt. + snapshot_container: Get a container identifier for storage of a snapshot. + validation_container: Get a container identifier for storage of a validation. + remove: Remove a dataset from a container. + + """ def __init__(self, basedir): + """ + Initialize the LocalArrayStore. + + Args: + basedir (str): The base directory where the store will write data. + Raises: + ValueError: If the basedir is not a directory. + Examples: + >>> store = LocalArrayStore("/path/to/store") + + """ self.basedir = basedir def best_validation_array( self, run_name: str, criterion: str, index: Optional[str] = None ) -> LocalArrayIdentifier: + """ + Get the array identifier for the best validation array. + + Args: + run_name (str): The name of the run. + criterion (str): The criterion for the validation array. + index (str, optional): The index of the validation array. Defaults to None. + Returns: + LocalArrayIdentifier: The array identifier for the best validation array. + Raises: + ValueError: If the container does not exist. + Examples: + >>> store.best_validation_array("run1", "loss") + + """ + container = self.validation_container(run_name).container if index is None: dataset = f"{criterion}" else: dataset = f"{index}/{criterion}" - return LocalArrayIdentifier(container, dataset) def validation_prediction_array( self, run_name: str, iteration: int, dataset: str ) -> LocalArrayIdentifier: - """Get the array identifier for a particular validation prediction.""" - + """ + Get the array identifier for a particular validation prediction. + + Args: + run_name (str): The name of the run. + iteration (int): The iteration of the validation prediction. + dataset (str): The dataset of the validation prediction. + Returns: + LocalArrayIdentifier: The array identifier for the validation prediction. + Raises: + ValueError: If the container does not exist. + Examples: + >>> store.validation_prediction_array("run1", 0, "train") + """ container = self.validation_container(run_name).container dataset = f"{iteration}/{dataset}/prediction" - return LocalArrayIdentifier(container, dataset) def validation_output_array( self, run_name: str, iteration: int, parameters: str, dataset: str ) -> LocalArrayIdentifier: - """Get the array identifier for a particular validation output.""" - + """ + Get the array identifier for a particular validation output. + + Args: + run_name (str): The name of the run. + iteration (int): The iteration of the validation output. + parameters (str): The parameters of the validation output. + dataset (str): The dataset of the validation output. + Returns: + LocalArrayIdentifier: The array identifier for the validation output. + Raises: + ValueError: If the container does not exist. + Examples: + >>> store.validation_output_array("run1", 0, "params1", "train") + """ container = self.validation_container(run_name).container dataset = f"{iteration}/{dataset}/output/{parameters}" - return LocalArrayIdentifier(container, dataset) def validation_input_arrays( self, run_name: str, index: Optional[str] = None ) -> Tuple[LocalArrayIdentifier, LocalArrayIdentifier]: """ - Get an array identifiers for the validation input raw/gt. + Get the array identifiers for the validation input raw/gt. It would be nice to store raw/gt with the validation predictions/outputs. If we don't store these we would have to look up the datasplit config and figure out where to find the inputs for each run. If we write the data then we don't need to search for it. This convenience comes at the cost of some extra memory usage. - """ + Args: + run_name (str): The name of the run. + index (str, optional): The index of the validation input. Defaults to None. + Returns: + Tuple[LocalArrayIdentifier, LocalArrayIdentifier]: The array identifiers for the validation input raw/gt. + Raises: + ValueError: If the container does not exist. + Examples: + >>> store.validation_input_arrays("run1") + """ container = self.validation_container(run_name).container if index is not None: dataset_prefix = f"inputs/{index}" + else: dataset_prefix = "inputs" - return ( LocalArrayIdentifier(container, f"{dataset_prefix}/raw"), LocalArrayIdentifier(container, f"{dataset_prefix}/gt"), @@ -72,28 +144,61 @@ def validation_input_arrays( def snapshot_container(self, run_name: str) -> LocalContainerIdentifier: """ Get a container identifier for storage of a snapshot. + + Args: + run_name (str): The name of the run. + Returns: + LocalContainerIdentifier: The container identifier for the snapshot. + Raises: + ValueError: If the container does not exist. + Examples: + >>> store.snapshot_container("run1") + """ + return LocalContainerIdentifier( Path(self.__get_run_dir(run_name), "snapshot.zarr") ) def validation_container(self, run_name: str) -> LocalContainerIdentifier: """ - Get a container identifier for storage of a snapshot. + Get a container identifier for storage of a validation. + + Args: + run_name (str): The name of the run. + Returns: + LocalContainerIdentifier: The container identifier for the validation. + Raises: + ValueError: If the container does not exist. + Examples: + >>> store.validation_container("run1") + """ + return LocalContainerIdentifier( Path(self.__get_run_dir(run_name), "validation.zarr") ) def remove(self, array_identifier: "LocalArrayIdentifier") -> None: + """ + Remove a dataset from a container. + + Args: + array_identifier (LocalArrayIdentifier): The array identifier of the dataset to remove. + Raises: + ValueError: If the container path does not end with '.zarr'. + Examples: + >>> store.remove(array_identifier) + + """ container = array_identifier.container + dataset = array_identifier.dataset assert container.suffix == ".zarr", ( f"The container path does not end with '.zarr'. Stopping here to " f"prevent data loss." ) - path = Path(container, dataset) if not path.exists(): @@ -107,9 +212,21 @@ def remove(self, array_identifier: "LocalArrayIdentifier") -> None: f"Asked to remove dataset {dataset} in container {container}, but it is not a directory. Will not delete." ) return - print(f"Removing dataset {dataset} in container {container}") + shutil.rmtree(path) def __get_run_dir(self, run_name: str) -> Path: + """ + Get the directory path for a run. + + Args: + run_name (str): The name of the run. + Returns: + Path: The directory path for the run. + Raises: + ValueError: If the run directory does not exist. + Examples: + >>> store.__get_run_dir("run1") + """ return Path(self.basedir, run_name) diff --git a/dacapo/store/local_weights_store.py b/dacapo/store/local_weights_store.py index 7f2384547..fb375602b 100644 --- a/dacapo/store/local_weights_store.py +++ b/dacapo/store/local_weights_store.py @@ -5,7 +5,7 @@ import torch import json -from pathlib import Path +from upath import UPath as Path import logging from typing import Optional, Union @@ -14,16 +14,76 @@ class LocalWeightsStore(WeightsStore): - """A local store for network weights.""" + """ + A local store for network weights. + + All weights are stored in a directory structure like this: + + ``` + basedir + ├── run1 + │ ├── checkpoints + │ │ ├── iterations + │ │ │ ├── 0 + │ │ │ ├── 1 + │ │ │ ├── ... + │ ├── dataset1 + │ │ ├── criterion1.json + │ ├── dataset2 + │ │ ├── criterion2.json + ├── run2 + │ ├── ... + ``` + + Attributes: + basedir: The base directory where the weights are stored. + Methods: + latest_iteration: Return the latest iteration for which weights are available for the given run. + store_weights: Store the network weights of the given run. + retrieve_weights: Retrieve the network weights of the given run. + remove: Remove the network weights of the given run. + store_best: Store the best weights in a easy to find location. + retrieve_best: Retrieve the best weights of the given run. + Note: + The weights are stored in the format of a Weights object, which is a simple container for the model and optimizer state dicts. + + """ def __init__(self, basedir): + """ + Create a new local weights store. + + Args: + basedir: The base directory where the weights are stored. + Raises: + FileNotFoundError: If the directory does not exist. + Examples: + >>> store = LocalWeightsStore("weights") + Note: + The directory is created if it does not exist. + + """ print(f"Creating local weights store in directory {basedir}") self.basedir = basedir def latest_iteration(self, run: str) -> Optional[int]: - """Return the latest iteration for which weights are available for the - given run.""" + """ + Return the latest iteration for which weights are available for the + given run. + + Args: + run: The name of the run. + Returns: + The latest iteration for which weights are available, or None if no + weights are available. + Raises: + FileNotFoundError: If the run directory does not exist. + Examples: + >>> store.latest_iteration("run1") + Note: + The iteration is determined by the number of the subdirectories in the "iterations" directory. + """ weights_dir = self.__get_weights_dir(run) / "iterations" @@ -35,7 +95,19 @@ def latest_iteration(self, run: str) -> Optional[int]: return iterations[-1] def store_weights(self, run: Run, iteration: int): - """Store the network weights of the given run.""" + """ + Store the network weights of the given run. + + Args: + run: The run object. + iteration: The iteration number. + Raises: + FileNotFoundError: If the run directory does not exist. + Examples: + >>> store.store_weights(run, 0) + Note: + The weights are stored in the format of a Weights object, which is a simple container for the model and optimizer state dicts. + """ logger.warning(f"Storing weights for run {run}, iteration {iteration}") @@ -50,7 +122,21 @@ def store_weights(self, run: Run, iteration: int): torch.save(weights, weights_name) def retrieve_weights(self, run: str, iteration: int) -> Weights: - """Retrieve the network weights of the given run.""" + """ + Retrieve the network weights of the given run. + + Args: + run: The name of the run. + iteration: The iteration number. + Returns: + The network weights. + Raises: + FileNotFoundError: If the weights file does not exist. + Examples: + >>> store.retrieve_weights("run1", 0) + Note: + The weights are stored in the format of a Weights object, which is a simple container for the model and optimizer state dicts. + """ print(f"Retrieving weights for run {run}, iteration {iteration}") @@ -64,6 +150,21 @@ def retrieve_weights(self, run: str, iteration: int) -> Weights: return weights def _retrieve_weights(self, run: str, key: str) -> Weights: + """ + Retrieves the weights for a given run and key. + + Args: + run (str): The name of the run. + key (str): The key of the weights. + Returns: + Weights: The retrieved weights. + Raises: + FileNotFoundError: If the weights file does not exist. + Examples: + >>> store._retrieve_weights("run1", "key1") + Note: + The weights are stored in the format of a Weights object, which is a simple container for the model and optimizer state dicts. + """ weights_name = self.__get_weights_dir(run) / key if not weights_name.exists(): weights_name = self.__get_weights_dir(run) / "iterations" / key @@ -76,6 +177,19 @@ def _retrieve_weights(self, run: str, key: str) -> Weights: return weights def remove(self, run: str, iteration: int): + """ + Remove the weights for a specific run and iteration. + + Args: + run (str): The name of the run. + iteration (int): The iteration number. + Raises: + FileNotFoundError: If the weights file does not exist. + Examples: + >>> store.remove("run1", 0) + Note: + The weights are stored in the format of a Weights object, which is a simple container for the model and optimizer state dicts. + """ weights = self.__get_weights_dir(run) / "iterations" / str(iteration) weights.unlink() @@ -84,6 +198,18 @@ def store_best(self, run: str, iteration: int, dataset: str, criterion: str): Store the best weights in a easy to find location. Symlinks weights from appropriate iteration # TODO: simply store a yaml of dataset/criterion -> iteration/parameter id + + Args: + run (str): The name of the run. + iteration (int): The iteration number. + dataset (str): The name of the dataset. + criterion (str): The criterion for selecting the best weights. + Raises: + FileNotFoundError: If the weights file does not exist. + Examples: + >>> store.store_best("run1", 0, "dataset1", "criterion1") + Note: + The best weights are stored in a json file that contains the iteration number. """ # must exist since we must read run/iteration weights @@ -107,6 +233,22 @@ def store_best(self, run: str, iteration: int, dataset: str, criterion: str): f.write(json.dumps({"iteration": iteration})) def retrieve_best(self, run: str, dataset: str | Dataset, criterion: str) -> int: + """ + Retrieve the best weights for a given run, dataset, and criterion. + + Args: + run (str): The name of the run. + dataset (str | Dataset): The name of the dataset or a Dataset object. + criterion (str): The criterion for selecting the best weights. + Returns: + int: The iteration number of the best weights. + Raises: + FileNotFoundError: If the weights file does not exist. + Examples: + >>> store.retrieve_best("run1", "dataset1", "criterion1") + Note: + The best weights are stored in a json file that contains the iteration number. + """ print(f"Retrieving weights for run {run}, criterion {criterion}") with (self.__get_weights_dir(run) / criterion / f"{dataset}.json").open( @@ -117,6 +259,17 @@ def retrieve_best(self, run: str, dataset: str | Dataset, criterion: str) -> int return weights_info["iteration"] def _load_best(self, run: Run, criterion: str): + """ + Load the best weights for a given run and criterion. + + Args: + run (Run): The run for which to load the weights. + criterion (str): The criterion for which to load the weights. + Examples: + >>> store._load_best(run, "criterion1") + Note: + This method is used internally by the store to load the best weights for a given run and criterion. + """ print(f"Retrieving weights for run {run}, criterion {criterion}") weights_name = self.__get_weights_dir(run) / f"{criterion}" @@ -128,6 +281,20 @@ def _load_best(self, run: Run, criterion: str): run.model.load_state_dict(weights.model) def __get_weights_dir(self, run: Union[str, Run]): + """ + Get the directory path for storing weights checkpoints. + + Args: + run: The name of the run or the run object. + Returns: + Path: The directory path for storing weights checkpoints. + Raises: + FileNotFoundError: If the run directory does not exist. + Examples: + >>> store.__get_weights_dir("run1") + Note: + The directory is created if it does not exist. + """ run = run if isinstance(run, str) else run.name return Path(self.basedir, run, "checkpoints") diff --git a/dacapo/store/mongo_config_store.py b/dacapo/store/mongo_config_store.py index a33aaf710..c5daecd4d 100644 --- a/dacapo/store/mongo_config_store.py +++ b/dacapo/store/mongo_config_store.py @@ -17,11 +17,64 @@ class MongoConfigStore(ConfigStore): - """A MongoDB store for configurations. Used to store and retrieve + """ + A MongoDB store for configurations. Used to store and retrieve configurations for runs, tasks, architectures, trainers, and datasets. + + Attributes: + db_host (str): The host of the MongoDB database. + db_name (str): The name of the MongoDB database. + client (MongoClient): The MongoDB client. + database (Database): The MongoDB database. + users (Collection): The users collection. + runs (Collection): The runs collection. + tasks (Collection): The tasks collection. + datasplits (Collection): The datasplits collection. + datasets (Collection): The datasets collection. + arrays (Collection): The arrays collection. + architectures (Collection): The architectures collection. + trainers (Collection): The trainers collection. + Methods: + store_run_config(run_config, ignore): Store the run configuration. + retrieve_run_config(run_name): Retrieve the run configuration. + delete_run_config(run_name): Delete the run configuration. + retrieve_run_config_names(task_names, datasplit_names, architecture_names, trainer_names): Retrieve the names of the run configurations. + store_task_config(task_config, ignore): Store the task configuration. + retrieve_task_config(task_name): Retrieve the task configuration. + retrieve_task_config_names(): Retrieve the names of the task configurations. + store_architecture_config(architecture_config, ignore): Store the architecture configuration. + retrieve_architecture_config(architecture_name): Retrieve the architecture configuration. + retrieve_architecture_config_names(): Retrieve the names of the architecture configurations. + store_trainer_config(trainer_config, ignore): Store the trainer configuration. + retrieve_trainer_config(trainer_name): Retrieve the trainer configuration. + retrieve_trainer_config_names(): Retrieve the names of the trainer configurations. + store_datasplit_config(datasplit_config, ignore): Store the datasplit configuration. + retrieve_datasplit_config(datasplit_name): Retrieve the datasplit configuration. + retrieve_datasplit_config_names(): Retrieve the names of the datasplit configurations. + store_dataset_config(dataset_config, ignore): Store the dataset configuration. + retrieve_dataset_config(dataset_name): Retrieve the dataset configuration. + retrieve_dataset_config_names(): Retrieve the names of the dataset configurations. + store_array_config(array_config, ignore): Store the array configuration. + retrieve_array_config(array_name): Retrieve the array configuration. + retrieve_array_config_names(): Retrieve the names of the array configurations. + __save_insert(collection, data, ignore): Save or insert a document into a collection. + __same_doc(a, b, ignore): Check if two documents are the same. + __init_db(): Initialize the database. + __open_collections(): Open the collections. + Notes: + The store is initialized with the host and database name. """ def __init__(self, db_host, db_name): + """ + Initialize a MongoConfigStore object. + + Args: + db_host (str): The host address of the MongoDB server. + db_name (str): The name of the database to connect to. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + """ print( f"Creating MongoConfigStore:\n\thost : {db_host}\n\tdatabase: {db_name}" ) @@ -35,13 +88,53 @@ def __init__(self, db_host, db_name): self.__init_db() def delete_config(self, database, config_name: str) -> None: + """ + Deletes a configuration from the database. + + Args: + database: The database object. + config_name: The name of the configuration to delete. + Raises: + ValueError: If the configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> config_name = 'config_0' + >>> store.delete_config(store.tasks, config_name) + """ database.delete_one({"name": config_name}) def store_run_config(self, run_config, ignore=None): + """ + Stores the run configuration in the MongoDB runs collection. + + Args: + run_config (dict): The run configuration to be stored. + ignore (list, optional): A list of fields to ignore during the storage process. + Raises: + DuplicateNameError: If the run configuration is already stored. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> run_config = {'name': 'run_0'} + >>> store.store_run_config(run_config) + """ run_doc = converter.unstructure(run_config) self.__save_insert(self.runs, run_doc, ignore) def retrieve_run_config(self, run_name): + """ + Retrieve the run configuration for a given run name. + + Args: + run_name (str): The name of the run. + Returns: + RunConfig: The run configuration for the given run name. + Raises: + ValueError: If the run configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.retrieve_run_config(run_name) + """ run_doc = self.runs.find_one({"name": run_name}, projection={"_id": False}) try: return converter.structure(run_doc, RunConfig) @@ -49,6 +142,18 @@ def retrieve_run_config(self, run_name): raise TypeError(f"Could not structure run: {run_name} as RunConfig!") from e def delete_run_config(self, run_name): + """ + Delete a run configuration from the MongoDB collection. + + Args: + run_name (str): The name of the run configuration to delete. + Raises: + ValueError: If the run configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.delete_run_config(run_name) + """ self.runs.delete_one({"name": run_name}) def retrieve_run_config_names( @@ -58,6 +163,26 @@ def retrieve_run_config_names( architecture_names=None, trainer_names=None, ): + """ + Retrieve the names of run configurations based on specified filters. + + Args: + task_names (list, optional): List of task names to filter the run configurations. Defaults to None. + datasplit_names (list, optional): List of datasplit names to filter the run configurations. Defaults to None. + architecture_names (list, optional): List of architecture names to filter the run configurations. Defaults to None. + trainer_names (list, optional): List of trainer names to filter the run configurations. Defaults to None. + Returns: + list: A list of run configuration names that match the specified filters. + Raises: + ValueError: If the run configurations are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> task_names = ['task_0'] + >>> datasplit_names = ['datasplit_0'] + >>> architecture_names = ['architecture_0'] + >>> trainer_names = ['trainer_0'] + >>> store.retrieve_run_config_names(task_names, datasplit_names, architecture_names, trainer_names) + """ filters = {} if task_names is not None: filters["task_config.name"] = {"$in": task_names} @@ -71,90 +196,349 @@ def retrieve_run_config_names( return list([run["name"] for run in runs]) def store_task_config(self, task_config, ignore=None): + """ + Store the task configuration in the MongoDB tasks collection. + + Args: + task_config (TaskConfig): The task configuration to be stored. + ignore (list, optional): A list of fields to ignore during the storage process. + Raises: + DuplicateNameError: If the task configuration is already stored. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> task_config = TaskConfig(name='task_0') + >>> store.store_task_config(task_config) + """ task_doc = converter.unstructure(task_config) self.__save_insert(self.tasks, task_doc, ignore) def retrieve_task_config(self, task_name): + """ + Retrieve the task configuration for a given task name. + + Args: + task_name (str): The name of the task. + Returns: + TaskConfig: The task configuration object. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> task_name = 'task_0' + >>> store.retrieve_task_config(task_name) + + """ task_doc = self.tasks.find_one({"name": task_name}, projection={"_id": False}) return converter.structure(task_doc, TaskConfig) def retrieve_task_config_names(self): + """ + Retrieve the names of all task configurations. + + Returns: + A list of task configuration names. + Raises: + ValueError: If the task configurations are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> store.retrieve_task_config_names() + + """ tasks = self.tasks.find({}, projection={"_id": False, "name": True}) return list([task["name"] for task in tasks]) def store_architecture_config(self, architecture_config, ignore=None): + """ + Store the architecture configuration in the MongoDB. + + Args: + architecture_config (ArchitectureConfig): The architecture configuration to be stored. + ignore (list, optional): List of fields to ignore during storage. Defaults to None. + Raises: + DuplicateNameError: If the architecture configuration is already stored. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> architecture_config = ArchitectureConfig(name='architecture_0') + >>> store.store_architecture_config(architecture_config) + """ architecture_doc = converter.unstructure(architecture_config) self.__save_insert(self.architectures, architecture_doc, ignore) def retrieve_architecture_config(self, architecture_name): + """ + Retrieve the architecture configuration for the given architecture name. + + Args: + architecture_name (str): The name of the architecture. + Returns: + ArchitectureConfig: The architecture configuration object. + Raises: + ValueError: If the architecture configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> architecture_name = 'architecture_0' + >>> store.retrieve_architecture_config(architecture_name) + + """ architecture_doc = self.architectures.find_one( {"name": architecture_name}, projection={"_id": False} ) return converter.structure(architecture_doc, ArchitectureConfig) def retrieve_architecture_config_names(self): + """ + Retrieve the names of all architecture configurations. + + Returns: + A list of architecture configuration names. + Raises: + ValueError: If the architecture configurations are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> store.retrieve_architecture_config_names() + + """ architectures = self.architectures.find( {}, projection={"_id": False, "name": True} ) return list([architecture["name"] for architecture in architectures]) def store_trainer_config(self, trainer_config, ignore=None): + """ + Store the trainer configuration in the MongoDB. + + Args: + trainer_config (TrainerConfig): The trainer configuration to be stored. + ignore (list, optional): List of fields to ignore during storage. Defaults to None. + Returns: + DuplicateNameError: If the trainer configuration is already stored. + Raises: + DuplicateNameError: If the trainer configuration is already stored. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> trainer_config = TrainerConfig(name='trainer_0') + >>> store.store_trainer_config(trainer_config) + """ trainer_doc = converter.unstructure(trainer_config) self.__save_insert(self.trainers, trainer_doc, ignore) def retrieve_trainer_config(self, trainer_name): + """ + Retrieve the trainer configuration for the given trainer name. + + Args: + trainer_name (str): The name of the trainer. + Returns: + TrainerConfig: The trainer configuration object. + Raises: + ValueError: If the trainer configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> trainer_name = 'trainer_0' + >>> store.retrieve_trainer_config(trainer_name) + """ trainer_doc = self.trainers.find_one( {"name": trainer_name}, projection={"_id": False} ) return converter.structure(trainer_doc, TrainerConfig) def retrieve_trainer_config_names(self): + """ + Retrieve the names of all trainer configurations. + + Args: + trainer_name (str): The name of the trainer. + Returns: + TrainerConfig: The trainer configuration object. + Raises: + ValueError: If the trainer configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> trainer_name = 'trainer_0' + >>> store.retrieve_trainer_config(trainer_name) + """ trainers = self.trainers.find({}, projection={"_id": False, "name": True}) return list([trainer["name"] for trainer in trainers]) def store_datasplit_config(self, datasplit_config, ignore=None): + """ + Store the datasplit configuration in the MongoDB. + + Args: + datasplit_config (DataSplitConfig): The datasplit configuration to be stored. + ignore (list, optional): List of fields to ignore during storage. Defaults to None. + Returns: + DuplicateNameError: If the datasplit configuration is already stored. + Raises: + DuplicateNameError: If the datasplit configuration is already stored. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> datasplit_config = DataSplitConfig(name='datasplit_0') + >>> store.store_datasplit_config(datasplit_config) + """ datasplit_doc = converter.unstructure(datasplit_config) self.__save_insert(self.datasplits, datasplit_doc, ignore) def retrieve_datasplit_config(self, datasplit_name): + """ + Retrieve the datasplit configuration for the given datasplit name. + + Args: + datasplit_name (str): The name of the datasplit. + Returns: + DataSplitConfig: The datasplit configuration object. + Raises: + ValueError: If the datasplit configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> datasplit_name = 'datasplit_0' + >>> store.retrieve_datasplit_config(datasplit_name) + """ datasplit_doc = self.datasplits.find_one( {"name": datasplit_name}, projection={"_id": False} ) return converter.structure(datasplit_doc, DataSplitConfig) def retrieve_datasplit_config_names(self): + """ + Retrieve the names of all datasplit configurations. + + Returns: + A list of datasplit configuration names. + Raises: + ValueError: If the datasplit configurations are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> store.retrieve_datasplit_config_names() + """ datasplits = self.datasplits.find({}, projection={"_id": False, "name": True}) return list([datasplit["name"] for datasplit in datasplits]) def store_dataset_config(self, dataset_config, ignore=None): + """ + Store the dataset configuration in the MongoDB. + + Args: + dataset_config (DatasetConfig): The dataset configuration to be stored. + ignore (list, optional): List of fields to ignore during storage. Defaults to None. + Returns: + DuplicateNameError: If the dataset configuration is already stored. + Raises: + DuplicateNameError: If the dataset configuration is already stored. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> dataset_config = DatasetConfig(name='dataset_0') + >>> store.store_dataset_config(dataset_config) + + """ dataset_doc = converter.unstructure(dataset_config) self.__save_insert(self.datasets, dataset_doc, ignore) def retrieve_dataset_config(self, dataset_name): + """ + Retrieve the dataset configuration for the given dataset name. + + Args: + dataset_name (str): The name of the dataset. + Returns: + DatasetConfig: The dataset configuration object. + Raises: + ValueError: If the dataset configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> dataset_name = 'dataset_0' + >>> store.retrieve_dataset_config(dataset_name) + """ dataset_doc = self.datasets.find_one( {"name": dataset_name}, projection={"_id": False} ) return converter.structure(dataset_doc, DatasetConfig) def retrieve_dataset_config_names(self): + """ + Retrieve the names of all dataset configurations. + + Returns: + A list of dataset configuration names. + Raises: + ValueError: If the dataset configurations are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> store.retrieve_dataset_config_names() + + """ datasets = self.datasets.find({}, projection={"_id": False, "name": True}) return list([dataset["name"] for dataset in datasets]) def store_array_config(self, array_config, ignore=None): + """ + Store the array configuration in the MongoDB. + + Args: + array_config (ArrayConfig): The array configuration to be stored. + ignore (list, optional): List of fields to ignore during storage. Defaults to None. + Returns: + DuplicateNameError: If the array configuration is already stored. + Raises: + DuplicateNameError: If the array configuration is already stored. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> array_config = ArrayConfig(name='array_0') + >>> store.store_array_config(array_config) + """ array_doc = converter.unstructure(array_config) self.__save_insert(self.arrays, array_doc, ignore) def retrieve_array_config(self, array_name): + """ + Retrieve the array configuration for the given array name. + + Args: + array_name (str): The name of the array. + Returns: + ArrayConfig: The array configuration object. + Raises: + ValueError: If the array configuration is not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> array_name = 'array_0' + >>> store.retrieve_array_config(array_name) + + """ array_doc = self.arrays.find_one( {"name": array_name}, projection={"_id": False} ) return converter.structure(array_doc, ArrayConfig) def retrieve_array_config_names(self): + """ + Retrieve the names of all array configurations. + + Returns: + A list of array configuration names. + Raises: + ValueError: If the array configurations are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> store.retrieve_array_config_names() + """ arrays = self.arrays.find({}, projection={"_id": False, "name": True}) return list([array["name"] for array in arrays]) def __save_insert(self, collection, data, ignore=None): + """ + Save and insert data into the specified collection. + + Args: + collection (pymongo.collection.Collection): The collection to insert the data into. + data (dict): The data to be inserted. + ignore (list, optional): A list of keys to ignore when comparing existing and new data. + Raises: + DuplicateNameError: If the data for the given name does not match the already stored entry. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> collection = store.runs + >>> data = {'name': 'run_0'} + >>> store.__save_insert(collection, data) + """ name = data["name"] try: @@ -171,6 +555,24 @@ def __save_insert(self, collection, data, ignore=None): ) def __same_doc(self, a, b, ignore=None): + """ + Check if two documents are the same. + + Args: + a (dict): The first document. + b (dict): The second document. + ignore (list, optional): A list of fields to ignore during the comparison. + Returns: + bool: True if the documents are the same, False otherwise. + Raises: + ValueError: If the documents are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> a = {'name': 'run_0'} + >>> b = {'name': 'run_0'} + >>> store.__same_doc(a, b) + + """ if ignore: a = dict(a) b = dict(b) @@ -186,6 +588,15 @@ def __same_doc(self, a, b, ignore=None): return bson_a == bson_b def __init_db(self): + """ + Initialize the MongoDB database. + + Raises: + ValueError: If the collections are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> store.__init_db() + """ self.users.create_index([("username", ASCENDING)], name="username", unique=True) self.runs.create_index( @@ -207,6 +618,18 @@ def __init_db(self): self.trainers.create_index([("name", ASCENDING)], name="name", unique=True) def __open_collections(self): + """ + Opens the collections in the MongoDB database. + + This method initializes the collection attributes for various entities such as users, runs, tasks, datasplits, datasets, + arrays, architectures, and trainers. These attributes can be used to interact with the corresponding collections in the database. + + Raises: + ValueError: If the collections are not available. + Examples: + >>> store = MongoConfigStore('localhost', 'dacapo') + >>> store.__open_collections() + """ self.users = self.database["users"] self.runs = self.database["runs"] self.tasks = self.database["tasks"] diff --git a/dacapo/store/mongo_stats_store.py b/dacapo/store/mongo_stats_store.py index 5de35aea8..e71d3cac7 100644 --- a/dacapo/store/mongo_stats_store.py +++ b/dacapo/store/mongo_stats_store.py @@ -13,9 +13,38 @@ class MongoStatsStore(StatsStore): """A MongoDB store for run statistics. Used to store and retrieve training statistics and validation scores. + + Attributes: + db_host (str): The host of the MongoDB database. + db_name (str): The name of the MongoDB database. + client (MongoClient): The MongoDB client. + database (Database): The MongoDB database. + training_stats (Collection): The collection for training statistics. + validation_scores (Collection): The collection for validation scores. + Methods: + store_training_stats(run_name, stats): Store the training stats of a given run. + retrieve_training_stats(run_name): Retrieve the training stats for a given run. + store_validation_iteration_scores(run_name, scores): Store the validation iteration scores of a given run. + retrieve_validation_iteration_scores(run_name): Retrieve the validation iteration scores for a given run. + delete_training_stats(run_name): Delete the training stats associated with a specific run. + Notes: + The MongoStatsStore uses the 'training_stats' and 'validation_scores' collections to store the statistics. + """ def __init__(self, db_host, db_name): + """ + Initialize the MongoStatsStore with the given host and database name. + + Args: + db_host (str): The host of the MongoDB database. + db_name (str): The name of the MongoDB database. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + Notes: + The MongoStatsStore will connect to the MongoDB database at the given host. + + """ print( f"Creating MongoStatsStore:\n\thost : {db_host}\n\tdatabase: {db_name}" ) @@ -29,6 +58,22 @@ def __init__(self, db_host, db_name): self.__init_db() def store_training_stats(self, run_name: str, stats: TrainingStats): + """ + Store the training statistics for a specific run. + + Args: + run_name (str): The name of the run. + stats (TrainingStats): The training statistics to be stored. + Raises: + ValueError: If the training statistics are already stored. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> stats = TrainingStats() + >>> store.store_training_stats(run_name, stats) + Notes: + The training statistics are stored in the 'training_stats' collection. + """ existing_stats = self.__read_training_stats(run_name) store_from_iteration = 0 @@ -57,11 +102,43 @@ def store_training_stats(self, run_name: str, stats: TrainingStats): def retrieve_training_stats( self, run_name: str, subsample: bool = False ) -> TrainingStats: + """ + Retrieve the training statistics for a given run. + + Args: + run_name (str): The name of the run. + subsample (bool, optional): Whether to subsample the training statistics. Defaults to False. + Returns: + TrainingStats: The training statistics for the specified run. + Raises: + ValueError: If the training statistics are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.retrieve_training_stats(run_name) + Notes: + The training statistics are retrieved from the 'training_stats' collection. + """ return self.__read_training_stats(run_name, subsample=subsample) def store_validation_iteration_scores( self, run_name: str, scores: ValidationScores ): + """ + Stores the validation iteration scores for a given run. + + Args: + run_name (str): The name of the run. + scores (ValidationScores): The validation scores to store. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> scores = ValidationScores() + >>> store.store_validation_iteration_scores(run_name, scores) + Notes: + The validation iteration scores are stored in the 'validation_scores' collection. + """ + existing_iteration_scores = self.__read_validation_iteration_scores(run_name) drop_db, store_from_iteration = scores.compare(existing_iteration_scores) @@ -86,6 +163,24 @@ def retrieve_validation_iteration_scores( subsample: bool = False, validation_interval: Optional[int] = None, ) -> List[ValidationIterationScores]: + """ + Retrieve the validation iteration scores for a given run. + + Args: + run_name (str): The name of the run. + subsample (bool, optional): Whether to subsample the scores. Defaults to False. + validation_interval (int, optional): The interval at which to retrieve the scores. Defaults to None. + Returns: + List[ValidationIterationScores]: A list of validation iteration scores. + Raises: + ValueError: If the validation iteration scores are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.retrieve_validation_iteration_scores(run_name) + Notes: + The validation iteration scores are retrieved from the 'validation_scores' collection. + """ return self.__read_validation_iteration_scores( run_name, subsample=subsample, validation_interval=validation_interval ) @@ -93,6 +188,27 @@ def retrieve_validation_iteration_scores( def __store_training_stats( self, stats: TrainingStats, begin: int, end: int, run_name: str ) -> None: + """ + Store the training statistics in the database. + + Args: + stats (TrainingStats): The training statistics to store. + begin (int): The first iteration to store. + end (int): The last iteration to store. + run_name (str): The name of the run. + Raises: + ValueError: If the training statistics are already stored. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> stats = TrainingStats() + >>> begin = 0 + >>> end = 1000 + >>> run_name = 'run_0' + >>> store.__store_training_stats(stats, begin, end, run_name) + Notes: + The training statistics are stored in the 'training_stats' collection. + + """ docs = converter.unstructure(stats.iteration_stats[begin:end]) for doc in docs: doc.update({"run_name": run_name}) @@ -103,6 +219,23 @@ def __store_training_stats( def __read_training_stats( self, run_name: str, subsample: bool = False ) -> TrainingStats: + """ + Read training statistics from the MongoDB collection. + + Args: + run_name (str): The name of the training run. + subsample (bool, optional): Whether to subsample the statistics to get 1000 iterations. Defaults to False. + Returns: + TrainingStats: The training statistics. + Raises: + ValueError: If the training statistics are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.__read_training_stats(run_name) + Notes: + The training statistics are read from the 'training_stats' collection. + """ filters: Dict[str, Any] = {"run_name": run_name} if subsample: # if possible subsample s.t. we get 1000 iterations @@ -124,6 +257,20 @@ def __read_training_stats( return stats def __delete_training_stats(self, run_name: str) -> None: + """ + Delete training stats for a given run name. + + Args: + run_name (str): The name of the run. + Raises: + ValueError: If the training statistics are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.__delete_training_stats(run_name) + Notes: + The training statistics are deleted from the 'training_stats' collection. + """ self.training_stats.delete_many({"run_name": run_name}) def __store_validation_iteration_scores( @@ -133,6 +280,26 @@ def __store_validation_iteration_scores( end: int, run_name: str, ) -> None: + """ + Store the validation scores for a specific range of iterations. + + Args: + validation_scores (ValidationScores): The validation scores object containing the scores to be stored. + begin (int): The starting iteration (inclusive) for which the scores should be stored. + end (int): The ending iteration (exclusive) for which the scores should be stored. + run_name (str): The name of the run associated with the scores. + Raises: + ValueError: If the validation scores are already stored. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> validation_scores = ValidationScores() + >>> begin = 0 + >>> end = 1000 + >>> run_name = 'run_0' + >>> store.__store_validation_iteration_scores(validation_scores, begin, end, run_name) + Notes: + The validation scores are stored in the 'validation_scores' collection. + """ docs = [ converter.unstructure(scores) for scores in validation_scores.scores @@ -150,6 +317,25 @@ def __read_validation_iteration_scores( subsample: bool = False, validation_interval: Optional[int] = None, ) -> List[ValidationIterationScores]: + """ + Read and retrieve validation iteration scores from the MongoDB collection. + + Args: + run_name (str): The name of the run. + subsample (bool, optional): Whether to subsample the scores. Defaults to False. + validation_interval (int, optional): The interval at which to subsample the scores. + Only applicable if subsample is True. Defaults to None. + Returns: + List[ValidationIterationScores]: A list of validation iteration scores. + Raises: + ValueError: If there is an error in processing the documents. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.__read_validation_iteration_scores(run_name) + Notes: + The validation iteration scores are read from the 'validation_scores' collection. + """ filters: Dict[str, Any] = {"run_name": run_name} if subsample: # if possible subsample s.t. we get 1000 iterations @@ -178,15 +364,78 @@ def __read_validation_iteration_scores( return scores def delete_validation_scores(self, run_name: str) -> None: + """ + Deletes the validation scores for a given run. + + Args: + run_name (str): The name of the run for which validation scores should be deleted. + Raises: + ValueError: If the validation scores are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.delete_validation_scores(run_name) + Notes: + The validation scores are deleted from the 'validation_scores' collection. + """ self.__delete_validation_scores(run_name) def __delete_validation_scores(self, run_name: str) -> None: + """ + Delete validation scores for a given run name. + + Args: + run_name (str): The name of the run. + Raises: + ValueError: If the validation scores are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.__delete_validation_scores(run_name) + Notes: + The validation scores are deleted from the 'validation_scores' collection. + + """ self.validation_scores.delete_many({"run_name": run_name}) def delete_training_stats(self, run_name: str) -> None: + """ + Deletes the training stats for a given run. + + Args: + run_name (str): The name of the run for which training stats should be deleted. + Raises: + ValueError: If the training statistics are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> run_name = 'run_0' + >>> store.delete_training_stats(run_name) + Notes: + The training statistics are deleted from the 'training_stats' collection. + """ self.__delete_training_stats(run_name) def __init_db(self): + """ + Initialize the database by creating indexes for the training_stats and validation_scores collections. + + This method creates indexes on specific fields to improve query performance. + + Indexes created: + - For training_stats collection: + - run_name and iteration (unique index) + - iteration + - For validation_scores collection: + - run_name, iteration, and dataset (unique index) + - iteration + Raises: + ValueError: If the indexes cannot be created. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> store.__init_db() + Notes: + The indexes are created to improve query performance. + """ self.training_stats.create_index( [("run_name", ASCENDING), ("iteration", ASCENDING)], name="run_it", @@ -201,5 +450,19 @@ def __init_db(self): self.validation_scores.create_index([("iteration", ASCENDING)], name="it") def __open_collections(self): + """ + Opens the collections in the MongoDB database. + + This method initializes the `training_stats` and `validation_scores` collections + in the MongoDB database. + + Raises: + ValueError: If the collections are not available. + Examples: + >>> store = MongoStatsStore('localhost', 'dacapo') + >>> store.__open_collections() + Notes: + The collections are used to store training statistics and validation scores. + """ self.training_stats = self.database["training_stats"] self.validation_scores = self.database["validation_scores"] diff --git a/dacapo/store/stats_store.py b/dacapo/store/stats_store.py index 6912ae208..48bf3671d 100644 --- a/dacapo/store/stats_store.py +++ b/dacapo/store/stats_store.py @@ -11,32 +11,105 @@ class StatsStore(ABC): - """Base class for statistics stores.""" + """ + Base class for statistics stores. + + Methods: + store_training_stats(run_name, training_stats): Store the training stats of a given run. + retrieve_training_stats(run_name): Retrieve the training stats for a given run. + store_validation_iteration_scores(run_name, validation_scores): Store the validation iteration scores of a given run. + retrieve_validation_iteration_scores(run_name): Retrieve the validation iteration scores for a given run. + delete_training_stats(run_name): Delete the training stats associated with a specific run. + """ @abstractmethod def store_training_stats(self, run_name: str, training_stats: "TrainingStats"): - """Store training stats of a given run.""" + """ + Store training stats of a given run. + + Args: + run_name (str): The name of the run. + training_stats (TrainingStats): The training stats to store. + Raises: + ValueError: If the training stats are already stored. + Examples: + >>> store = StatsStore() + >>> run_name = 'run_0' + >>> training_stats = TrainingStats() + >>> store.store_training_stats(run_name, training_stats) + """ pass @abstractmethod def retrieve_training_stats(self, run_name: str) -> "TrainingStats": - """Retrieve the training stats for a given run.""" + """ + Retrieve the training stats for a given run. + + Args: + run_name (str): The name of the run. + Returns: + TrainingStats: The training stats for the given run. + Raises: + ValueError: If the training stats are not available. + Examples: + >>> store = StatsStore() + >>> run_name = 'run_0' + >>> store.retrieve_training_stats(run_name) + """ pass @abstractmethod def store_validation_iteration_scores( self, run_name: str, validation_scores: "ValidationScores" ): - """Store the validation iteration scores of a given run.""" + """ + Store the validation iteration scores of a given run. + + Args: + run_name (str): The name of the run. + validation_scores (ValidationScores): The validation scores to store. + Raises: + ValueError: If the validation iteration scores are already stored. + Examples: + >>> store = StatsStore() + >>> run_name = 'run_0' + >>> validation_scores = ValidationScores() + >>> store.store_validation_iteration_scores(run_name, validation_scores) + """ pass @abstractmethod def retrieve_validation_iteration_scores( self, run_name: str ) -> List["ValidationIterationScores"]: - """Retrieve the validation iteration scores for a given run.""" + """ + Retrieve the validation iteration scores for a given run. + + Args: + run_name (str): The name of the run. + Returns: + List[ValidationIterationScores]: The validation iteration scores for the given run. + Raises: + ValueError: If the validation iteration scores are not available. + Examples: + >>> store = StatsStore() + >>> run_name = 'run_0' + >>> store.retrieve_validation_iteration_scores(run_name) + """ pass @abstractmethod def delete_training_stats(self, run_name: str) -> None: + """ + Deletes the training statistics for a given run. + + Args: + run_name (str): The name of the run. + Raises: + ValueError: If the training stats are not available. + Example: + >>> store = StatsStore() + >>> run_name = 'run_0' + >>> store.delete_training_stats(run_name) + """ pass diff --git a/dacapo/store/weights_store.py b/dacapo/store/weights_store.py index 9e4c16d58..d26bdff82 100644 --- a/dacapo/store/weights_store.py +++ b/dacapo/store/weights_store.py @@ -8,20 +8,62 @@ class Weights: + """ + A class representing the weights of a model and optimizer. + + Attributes: + optimizer (OrderedDict[str, torch.Tensor]): The optimizer's state dictionary. + model (OrderedDict[str, torch.Tensor]): The model's state dictionary. + Methods: + __init__(model_state_dict, optimizer_state_dict): Initializes the Weights object with the given model and optimizer state dictionaries. + """ + optimizer: OrderedDict[str, torch.Tensor] model: OrderedDict[str, torch.Tensor] def __init__(self, model_state_dict, optimizer_state_dict): + """ + Initializes the Weights object with the given model and optimizer state dictionaries. + + Args: + model_state_dict (OrderedDict[str, torch.Tensor]): The state dictionary of the model. + optimizer_state_dict (OrderedDict[str, torch.Tensor]): The state dictionary of the optimizer. + Examples: + >>> model = torch.nn.Linear(2, 2) + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + >>> weights = Weights(model.state_dict(), optimizer.state_dict()) + """ self.model = model_state_dict self.optimizer = optimizer_state_dict class WeightsStore(ABC): - """Base class for network weight stores.""" + """ + Base class for network weight stores. + + Methods: + load_weights(run, iteration): Load the weights of the given iteration into the given run. + load_best(run, dataset, criterion): Load the best weights for the given run, dataset, and criterion into the given run. + latest_iteration(run): Return the latest iteration for which weights are available for the given run. + store_weights(run, iteration): Store the network weights of the given run. + retrieve_weights(run, iteration): Retrieve the network weights of the given run. + remove(run, iteration): Delete the weights associated with a specific run/iteration. + retrieve_best(run, dataset, criterion): Retrieve the best weights for the given run, dataset, and criterion. + """ def load_weights(self, run: Run, iteration: int) -> None: """ Load this iterations weights into the given run. + Args: + run (Run): The run to load the weights into. + iteration (int): The iteration to load the weights from. + Raises: + ValueError: If the iteration is not available. + Examples: + >>> store = WeightsStore() + >>> run = Run() + >>> iteration = 0 + >>> store.load_weights(run, iteration) """ weights = self.retrieve_weights(run.name, iteration) run.model.load_state_dict(weights.model) @@ -30,30 +72,97 @@ def load_weights(self, run: Run, iteration: int) -> None: def load_best(self, run: Run, dataset: str, criterion: str) -> None: """ Load the best weights for this Run,dataset,criterion into Run.model + + Args: + run (Run): The run to load the weights into. + dataset (str): The dataset to load the weights from. + criterion (str): The criterion to load the weights from. + Raises: + ValueError: If the best iteration is not available. + Examples: + >>> store = WeightsStore() + >>> run = Run() + >>> dataset = 'mnist' + >>> criterion = 'accuracy' + >>> store.load_best(run, dataset, criterion) + """ best_iteration = self.retrieve_best(run.name, dataset, criterion) self.load_weights(run, best_iteration) @abstractmethod def latest_iteration(self, run: str) -> Optional[int]: - """Return the latest iteration for which weights are available for the - given run.""" + """ + Return the latest iteration for which weights are available for the + given run. + + Args: + run (str): The name of the run. + Returns: + int: The latest iteration for which weights are available. + Raises: + ValueError: If no weights are available for the given run. + Examples: + >>> store = WeightsStore() + >>> run = 'run_0' + >>> store.latest_iteration(run) + """ pass @abstractmethod def store_weights(self, run: Run, iteration: int) -> None: - """Store the network weights of the given run.""" + """ + Store the network weights of the given run. + + Args: + run (Run): The run to store the weights of. + iteration (int): The iteration to store the weights for. + Raises: + ValueError: If the iteration is already stored. + Examples: + >>> store = WeightsStore() + >>> run = Run() + >>> iteration = 0 + >>> store.store_weights(run, iteration) + """ pass @abstractmethod def retrieve_weights(self, run: str, iteration: int) -> Weights: - """Retrieve the network weights of the given run.""" + """ + Retrieve the network weights of the given run. + + Args: + run (str): The name of the run. + iteration (int): The iteration to retrieve the weights for. + Returns: + Weights: The weights of the given run and iteration. + Raises: + ValueError: If the weights are not available. + Examples: + >>> store = WeightsStore() + >>> run = 'run_0' + >>> iteration = 0 + >>> store.retrieve_weights(run, iteration) + """ pass @abstractmethod def remove(self, run: str, iteration: int) -> None: """ Delete the weights associated with a specific run/iteration + + Args: + run (str): The name of the run. + iteration (int): The iteration to delete the weights for. + Raises: + ValueError: If the weights are not available. + Examples: + >>> store = WeightsStore() + >>> run = 'run_0' + >>> iteration = 0 + >>> store.remove(run, iteration) + """ pass @@ -61,5 +170,20 @@ def remove(self, run: str, iteration: int) -> None: def retrieve_best(self, run: str, dataset: str, criterion: str) -> int: """ Retrieve the best weights for this run/dataset/criterion + + Args: + run (str): The name of the run. + dataset (str): The dataset to retrieve the best weights for. + criterion (str): The criterion to retrieve the best weights for. + Returns: + int: The iteration of the best weights. + Raises: + ValueError: If the best weights are not available. + Examples: + >>> store = WeightsStore() + >>> run = 'run_0' + >>> dataset = 'mnist' + >>> criterion = 'accuracy' + >>> store.retrieve_best(run, dataset, criterion) """ pass diff --git a/dacapo/train.py b/dacapo/train.py index 36c42c8a6..f218b3251 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -6,7 +6,7 @@ create_weights_store, ) from dacapo.experiments import Run -from dacapo.validate import validate_run +from dacapo.validate import validate import torch from tqdm import tqdm @@ -18,7 +18,16 @@ def train(run_name: str): - """Train a run""" + """ + Train a run + + Args: + run_name: Name of the run to train + Raises: + ValueError: If run_name is not found in config store + Examples: + >>> train("run_name") + """ # check config store to see if run is already being trained TODO # if ...: @@ -39,6 +48,15 @@ def train(run_name: str): def train_run(run: Run): + """ + Train a run + + Args: + run: Run object to train + Raises: + ValueError: If run_name is not found in config store + + """ print(f"Starting/resuming training for run {run}...") # create run @@ -169,13 +187,13 @@ def train_run(run: Run): try: # launch validation in a separate thread to avoid blocking training validate_thread = threading.Thread( - target=validate_run, + target=validate, args=(run, iteration_stats.iteration + 1), name=f"validate_{run.name}_{iteration_stats.iteration + 1}", daemon=True, ) validate_thread.start() - # validate_run( + # validate( # run, # iteration_stats.iteration + 1, # ) diff --git a/dacapo/utils/affinities.py b/dacapo/utils/affinities.py index 9c2dcec76..e6750e8af 100644 --- a/dacapo/utils/affinities.py +++ b/dacapo/utils/affinities.py @@ -9,6 +9,32 @@ def seg_to_affgraph(seg: np.ndarray, neighborhood: List[Coordinate]) -> np.ndarray: + """ + Constructs an affinity graph from a segmentation. + + Args: + seg (np.ndarray): The segmentation array. + neighborhood (List[Coordinate]): The list of coordinates representing the neighborhood. + Returns: + np.ndarray: The affinity graph. + Raises: + RuntimeError: If the number of dimensions is not 2 or 3. + Examples: + >>> seg = np.array([[1, 1, 2], [1, 1, 2], [3, 3, 4]]) + >>> neighborhood = [Coordinate(1, 0), Coordinate(0, 1)] + >>> seg_to_affgraph(seg, neighborhood) + array([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + + [[1, 1, 0], + [1, 1, 0], + [0, 0, 0]]], dtype=int32) + Notes: + The affinity graph is represented as: + shape = (e, z, y, x) + nhood.shape = (edges, 3) + """ nhood: np.ndarray = np.array(neighborhood) # constructs an affinity graph from a segmentation @@ -100,6 +126,19 @@ def seg_to_affgraph(seg: np.ndarray, neighborhood: List[Coordinate]) -> np.ndarr def padding(neighborhood, voxel_size): """ Get the appropriate padding to make sure all provided affinities are "True" + + Args: + neighborhood (List[Coordinate]): The list of coordinates representing the neighborhood. + voxel_size (Coordinate): The voxel size. + Returns: + Tuple[Coordinate, Coordinate]: The negative and positive padding. + Raises: + RuntimeError: If the number of dimensions is not 2 or 3. + Examples: + >>> neighborhood = [Coordinate(1, 0), Coordinate(0, 1)] + >>> voxel_size = Coordinate(1, 1) + >>> padding(neighborhood, voxel_size) + (Coordinate(0, 0), Coordinate(1, 1)) """ dims = voxel_size.dims padding_neg = ( diff --git a/dacapo/utils/balance_weights.py b/dacapo/utils/balance_weights.py index f5adcffca..e713745c6 100644 --- a/dacapo/utils/balance_weights.py +++ b/dacapo/utils/balance_weights.py @@ -13,6 +13,62 @@ def balance_weights( clipmax: float = 0.95, moving_counts: Optional[List[Dict[int, Tuple[int, int]]]] = None, ): + """ + Balances the weights based on the label data and other parameters. + + Args: + label_data (np.ndarray): The label data. + num_classes (int): The number of classes. + masks (List[np.ndarray], optional): List of masks. Defaults to an empty list. + slab (optional): The slab parameter. Defaults to None. + clipmin (float, optional): The minimum clipping value. Defaults to 0.05. + clipmax (float, optional): The maximum clipping value. Defaults to 0.95. + moving_counts (Optional[List[Dict[int, Tuple[int, int]]]], optional): List of moving counts. Defaults to None. + Returns: + Tuple[np.ndarray, List[Dict[int, Tuple[int, int]]]]: The balanced error scale and moving counts. + Raises: + AssertionError: If the number of unique labels is greater than the number of classes. + AssertionError: If the minimum label is less than 0 or the maximum label is greater than the number of classes. + Examples: + >>> label_data = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]]) + >>> num_classes = 3 + >>> masks = [np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])] + >>> balance_weights(label_data, num_classes, masks) + (array([[0.33333334, 0.33333334, 0.33333334], + [0.33333334, 0.33333334, 0.33333334], + [0.33333334, 0.33333334, 0.33333334]], dtype=float32), + [{0: (3, 9), 1: (3, 9), 2: (3, 9)}]) + Notes: + The balanced error scale is computed as: + error_scale = np.ones(label_data.shape, dtype=np.float32) + for mask in masks: + error_scale = error_scale * mask + slab_ranges = (range(0, m, s) for m, s in zip(error_scale.shape, slab)) + for ind, start in enumerate(itertools.product(*slab_ranges)): + slab_counts = moving_counts[ind] + slices = tuple(slice(start[d], start[d] + slab[d]) for d in range(len(slab))) + scale_slab = error_scale[slices] + labels_slab = label_data[slices] + masked_in = scale_slab.sum() + classes, counts = np.unique(labels_slab[np.nonzero(scale_slab)], return_counts=True) + updated_fracs = [] + for key, (num, den) in slab_counts.items(): + slab_counts[key] = (num, den + masked_in) + for class_id, num in zip(classes, counts): + (old_num, den) = slab_counts[class_id] + slab_counts[class_id] = (num + old_num, den) + updated_fracs.append(slab_counts[class_id][0] / slab_counts[class_id][1]) + fracs = np.array(updated_fracs) + if clipmin is not None or clipmax is not None: + np.clip(fracs, clipmin, clipmax, fracs) + total_frac = 1.0 + w_sparse = total_frac / float(num_classes) / fracs + w = np.zeros(num_classes) + w[classes] = w_sparse + labels_slab = labels_slab.astype(np.int64) + scale_slab *= np.take(w, labels_slab) + """ + if moving_counts is None: moving_counts = [] unique_labels = np.unique(label_data) diff --git a/dacapo/utils/pipeline.py b/dacapo/utils/pipeline.py index 8dbc950ba..99f823eb9 100644 --- a/dacapo/utils/pipeline.py +++ b/dacapo/utils/pipeline.py @@ -13,15 +13,45 @@ class CreatePoints(gp.BatchFilter): + """ + A batch filter that creates random points in a 3D label volume. + + Attributes: + labels (str): The key of the label data in the batch. + num_points (tuple): A tuple specifying the range of the number of points to create. + Methods: + process: Create random points in the label volume. + """ + def __init__( self, labels, num_points=(20, 150), ): + """ + Initialize the Pipeline object. + + Args: + labels (list): A list of labels. + num_points (tuple, optional): A tuple representing the range of number of points. Defaults to (20, 150). + Examples: + >>> CreatePoints(labels="LABELS", num_points=(20, 150)) + """ self.labels = labels self.num_points = num_points def process(self, batch, request): + """ + Process the batch by creating random points in the label volume. + + Args: + batch (dict): The input batch containing label data. + request (gp.BatchRequest): The batch request. + Raises: + ValueError: If the number of points is not an integer. + Examples: + >>> CreatePoints.process(batch, request) + """ labels = batch[self.labels].data num_points = random.randint(*self.num_points) @@ -36,32 +66,91 @@ def process(self, batch, request): class MakeRaw(gp.BatchFilter): - def __init__( - self, - raw, - labels, - gaussian_noise_args: Iterable = (0.5, 0.1), - gaussian_noise_lim: float = 0.3, - gaussian_blur_args: Iterable = (0.5, 1.5), - membrane_like=True, - membrane_size=3, - inside_value=0.5, - ): - self.raw = raw - self.labels = labels - self.gaussian_noise_args = gaussian_noise_args - self.gaussian_noise_lim = gaussian_noise_lim - self.gaussian_blur_args = gaussian_blur_args - self.membrane_like = membrane_like - self.membrane_size = membrane_size - self.inside_value = inside_value + """ + A batch filter that generates a raw image from labels. + + Attributes: + raw (str): The key of the raw data in the batch. + labels (str): The key of the label data in the batch. + gaussian_noise_args (tuple): A tuple specifying the mean and standard deviation of the Gaussian noise. + gaussian_noise_lim (float): The limit of the Gaussian noise. + gaussian_blur_args (tuple): A tuple specifying the mean and standard deviation of the Gaussian blur. + membrane_like (bool): A boolean indicating whether to generate a membrane-like structure. + membrane_size (int): The size of the membrane-like structure. + inside_value (float): The value to set inside the membranes of objects. + Methods: + setup: Set up the batch filter by defining the specification of the raw image. + process: Generate the raw image from the labels. + """ + + class Pipeline: + def __init__( + self, + raw, + labels, + gaussian_noise_args: Iterable = (0.5, 0.1), + gaussian_noise_lim: float = 0.3, + gaussian_blur_args: Iterable = (0.5, 1.5), + membrane_like=True, + membrane_size=3, + inside_value=0.5, + ): + """ + Initialize the Pipeline object. + + Args: + raw: The raw data. + labels: The labels data. + gaussian_noise_args: Tuple of two floats representing the mean and standard deviation + of the Gaussian noise to be added to the data. Default is (0.5, 0.1). + gaussian_noise_lim: The limit of the Gaussian noise. Default is 0.3. + gaussian_blur_args: Tuple of two floats representing the mean and standard deviation + of the Gaussian blur to be applied to the data. Default is (0.5, 1.5). + membrane_like: Boolean indicating whether to apply membrane-like transformation to the data. + Default is True. + membrane_size: The size of the membrane. Default is 3. + inside_value: The value to be assigned to the inside of the membrane. Default is 0.5. + Examples: + >>> Pipeline(raw="RAW", labels="LABELS", gaussian_noise_args=(0.5, 0.1), gaussian_noise_lim=0.3, + >>> gaussian_blur_args=(0.5, 1.5), membrane_like=True, membrane_size=3, inside_value=0.5) + """ + self.raw = raw + self.labels = labels + self.gaussian_noise_args = gaussian_noise_args + self.gaussian_noise_lim = gaussian_noise_lim + self.gaussian_blur_args = gaussian_blur_args + self.membrane_like = membrane_like + self.membrane_size = membrane_size + self.inside_value = inside_value def setup(self): + """ + Set up the batch filter by defining the specification of the raw image. + + Raises: + ValueError: If the data type is not float32. + Examples: + >>> MakeRaw.setup() + + """ spec = self.spec[self.labels].copy() # type: ignore spec.dtype = np.float32 self.provides(self.raw, spec) def process(self, batch, request): + """ + Process the batch by generating the raw image from the labels. + + Args: + batch (gp.Batch): The input batch. + request (gp.BatchRequest): The request for the output batch. + Returns: + gp.Batch: The output batch. + Raises: + ValueError: If the data type is not float32. + Examples: + >>> MakeRaw.process(batch, request) + """ labels = batch[self.labels].data raw: np.ndarray = np.zeros_like(labels, dtype=np.float32) raw[labels > 0] = 1 @@ -95,11 +184,45 @@ def process(self, batch, request): class DilatePoints(gp.BatchFilter): + """ + A batch filter that performs dilation on labeled points. + + Attributes: + labels (str): The key of the labels data in the batch. + dilations (list[int]): A list of two integers representing the range of dilations to apply. + Methods: + process: Perform dilation on the labeled points. + """ + def __init__(self, labels, dilations=[2, 8]): + """ + Initialize the DilatePoints batch filter. + + Args: + labels (str): The key of the labels data in the batch. + dilations (list[int]): A list of two integers representing the range of dilations to apply. + Raises: + ValueError: If the dilations are not integers. + Examples: + >>> DilatePoints(labels="LABELS", dilations=[2, 8]) + + """ self.labels = labels self.dilations = dilations def process(self, batch, request): + """ + Process the batch by performing dilation on the labeled points. + + Args: + batch (Batch): The input batch. + request (Request): The request object. + Raises: + ValueError: If the dilations are not integers. + Examples: + >>> DilatePoints.process(batch, request) + + """ labels = batch[self.labels].data dilations = random.randint(*self.dilations) @@ -109,11 +232,34 @@ def process(self, batch, request): class RandomDilateLabels(gp.BatchFilter): + """ + A batch filter that randomly dilates labels in a batch. + + Attributes: + labels (str): The key of the labels in the batch. + dilations (list[int]): A list of two integers representing the range of dilations. + Methods: + process: Randomly dilate the labels in the batch. + + """ + def __init__(self, labels, dilations=[2, 8]): self.labels = labels self.dilations = dilations def process(self, batch, request): + """ + Process the batch by randomly dilating labels. + + Args: + batch (dict): The input batch. + request: The request object. + Raises: + ValueError: If the dilations are not integers. + Examples: + >>> RandomDilateLabels.process(batch, request) + + """ labels = batch[self.labels].data new_labels = np.zeros_like(labels) @@ -122,7 +268,7 @@ def process(self, batch, request): continue dilations = np.random.randint(*self.dilations) - # # make sure we don't overlap existing labels + # make sure we don't overlap existing labels new_labels[ np.logical_or( labels == id, @@ -136,11 +282,48 @@ def process(self, batch, request): class Relabel(gp.BatchFilter): + """ + A batch filter that relabels the labels in a batch. + + Args: + labels (str): The name of the labels data in the batch. + connectivity (int, optional): The connectivity used for relabeling. Defaults to 1. + Methods: + process: Process the batch and relabel the labels. + + """ + def __init__(self, labels, connectivity=1): + """ + Initialize the Pipeline object. + + Args: + labels (str): The name of the labels data in the batch. + connectivity (int, optional): The connectivity used for relabeling. Defaults to 1. + Raises: + ValueError: If the connectivity is not an integer. + Examples: + >>> Relabel(labels="LABELS", connectivity=1) + """ self.labels = labels self.connectivity = connectivity def process(self, batch, request): + """ + Process the batch and relabel the labels. + + Args: + batch (Batch): The input batch. + request (Request): The request for processing. + Returns: + Batch: The output batch. + Raises: + ValueError: If the connectivity is not an integer. + Examples: + >>> Relabel.process(batch, request) + + + """ labels = batch[self.labels].data relabeled = relabel(labels, connectivity=self.connectivity).astype(labels.dtype) # type: ignore @@ -149,11 +332,46 @@ def process(self, batch, request): class ExpandLabels(gp.BatchFilter): + """ + A batch filter that expands labels by assigning the nearest label to each pixel within a specified distance. + + Attributes: + labels (str): The name of the labels data in the batch. + background (int): The label value representing the background. + Methods: + process: Process the batch and expand the labels. + + """ + def __init__(self, labels, background=0): + """ + Initialize the Pipeline object. + + Args: + labels (list): A list of labels. + background (int, optional): The background value. Defaults to 0. + Raises: + ValueError: If the background is not an integer. + Examples: + >>> ExpandLabels(labels="LABELS", background=0) + + """ self.labels = labels self.background = background def process(self, batch, request): + """ + Process the batch by expanding the labels. + + Args: + batch (gp.Batch): The input batch. + request (gp.BatchRequest): The batch request. + Raises: + ValueError: If the background is not an integer. + Examples: + >>> ExpandLabels.process(batch, request) + + """ labels_data = batch[self.labels].data distance = labels_data.shape[0] @@ -177,14 +395,59 @@ def process(self, batch, request): class ZerosSource(gp.BatchProvider): + """ + A batch provider that generates arrays filled with zeros. + + Attributes: + key (str): The key to use for the generated array. + _spec (dict): A dictionary containing the specification of the array. + Methods: + setup: Perform any necessary setup before providing batches. + provide: Provide a batch containing an array filled with zeros. + + """ + def __init__(self, key, spec): + """ + Initialize a Pipeline object. + + Args: + key (str): The key to use for the generated array. + spec (ArraySpec): The specification of the array. + Raises: + ValueError: If the key is not a string. + Examples: + >>> ZerosSource(key="LABELS", spec=ArraySpec(roi=gp.Roi((0, 0, 0), (100, 100, 100)), voxel_size=(8, 8, 8), dtype=np.uint8)) + + """ self.key = key self._spec = {key: spec} def setup(self): + """ + Perform any necessary setup before providing batches. + + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> ZerosSource.setup() + + """ pass def provide(self, request): + """ + Provide a batch containing an array filled with zeros. + + Args: + request (gp.BatchRequest): The request for the batch. + Returns: + gp.Batch: The provided batch. + Raises: + NotImplementedError: If the method is not implemented. + Examples: + >>> ZerosSource.provide(request) + """ batch = gp.Batch() roi = request[self.key].roi @@ -214,7 +477,6 @@ def random_source_pipeline( """Create a random source pipeline and batch request for example training. Args: - voxel_size (tuple of int): The size of a voxel in world units. input_shape (tuple of int): The shape of the input arrays. dtype (numpy.dtype): The dtype of the label arrays. @@ -227,11 +489,15 @@ def random_source_pipeline( membrane_like (bool): Whether to generate a membrane-like structure in the raw array. membrane_size (int): The width of the membrane-like structure on the outside of the objects. inside_value (float): The value to set inside the membranes of objects. - Returns: - gunpowder.Pipeline: The batch generating Gunpowder pipeline. gunpowder.BatchRequest: The batch request for the pipeline. + Raises: + ValueError: If the input_shape is not an integer. + Examples: + >>> random_source_pipeline(voxel_size=(8, 8, 8), input_shape=(148, 148, 148), dtype=np.uint8, expand_labels=False, + >>> relabel_connectivity=1, random_dilate=True, num_points=(20, 150), gaussian_noise_args=(0, 0.1), + >>> gaussian_blur_args=(0.5, 1.5), membrane_like=True, membrane_size=3, inside_value=0.5) """ voxel_size = gp.Coordinate(voxel_size) diff --git a/dacapo/utils/view.py b/dacapo/utils/view.py index f5b6d8e83..203f98cf4 100644 --- a/dacapo/utils/view.py +++ b/dacapo/utils/view.py @@ -18,6 +18,35 @@ def get_viewer( arrays: dict, width: int = 1500, height: int = 600, headless: bool = True ) -> neuroglancer.Viewer | IFrame: + """ + Creates a neuroglancer viewer to visualize arrays. + + Args: + arrays (dict): A dictionary containing arrays to be visualized. + width (int, optional): The width of the viewer window in pixels. Defaults to 1500. + height (int, optional): The height of the viewer window in pixels. Defaults to 600. + headless (bool, optional): If True, returns the viewer object. If False, returns an IFrame object embedding the viewer. Defaults to True. + Returns: + neuroglancer.Viewer | IFrame: The neuroglancer viewer object or an IFrame object embedding the viewer. + Raises: + ValueError: If the array is not a numpy array or a neuroglancer.LocalVolume object. + Examples: + >>> import numpy as np + >>> import neuroglancer + >>> from dacapo.utils.view import get_viewer + >>> arrays = { + ... "raw": { + ... "array": np.random.rand(100, 100, 100) + ... }, + ... "seg": { + ... "array": np.random.randint(0, 10, (100, 100, 100)), + ... "is_seg": True + ... } + ... } + >>> viewer = get_viewer(arrays) + >>> viewer + """ + for name, array_data in arrays.items(): array = array_data["array"] if hasattr(array, "to_ndarray"): @@ -52,12 +81,31 @@ def get_viewer( def add_seg_layer(state, name, data, voxel_size, meshes=False): + """ + Add a segmentation layer to the Neuroglancer state. + + Args: + state (neuroglancer.ViewerState): The Neuroglancer viewer state. + name (str): The name of the segmentation layer. + data (ndarray): The segmentation data. + voxel_size (list): The voxel size in nm. + meshes (bool, optional): Whether to generate meshes for the segments. Defaults to False. + Raises: + ValueError: If the data is not a numpy array. + Examples: + >>> import numpy as np + >>> import neuroglancer + >>> from dacapo.utils.view import add_seg_layer + >>> state = neuroglancer.ViewerState() + >>> data = np.random.randint(0, 10, (100, 100, 100)) + >>> voxel_size = [1, 1, 1] + >>> add_seg_layer(state, "seg", data, voxel_size) + """ if meshes: kwargs = {"segments": np.unique(data[data > 0])} else: kwargs = {} state.layers[name] = neuroglancer.SegmentationLayer( - # segments=[str(i) for i in np.unique(data[data > 0])], # this line will cause all objects to be selected and thus all meshes to be generated...will be slow if lots of high res meshes source=neuroglancer.LocalVolume( data=data, dimensions=neuroglancer.CoordinateSpace( @@ -71,6 +119,25 @@ def add_seg_layer(state, name, data, voxel_size, meshes=False): def add_scalar_layer(state, name, data, voxel_size): + """ + Add a scalar layer to the state. + + Args: + state (neuroglancer.ViewerState): The viewer state to add the layer to. + name (str): The name of the layer. + data (ndarray): The scalar data to be displayed. + voxel_size (list): The voxel size in nm. + Raises: + ValueError: If the data is not a numpy array. + Examples: + >>> import numpy as np + >>> import neuroglancer + >>> from dacapo.utils.view import add_scalar_layer + >>> state = neuroglancer.ViewerState() + >>> data = np.random.rand(100, 100, 100) + >>> voxel_size = [1, 1, 1] + >>> add_scalar_layer(state, "raw", data, voxel_size) + """ state.layers[name] = neuroglancer.ImageLayer( source=neuroglancer.LocalVolume( data=data, @@ -84,7 +151,35 @@ def add_scalar_layer(state, name, data, voxel_size): class BestScore: + """ + Represents the best score achieved during a run. + + Attributes: + run (Run): The run object associated with the best score. + score (float): The best score achieved. + iteration (int): The iteration number at which the best score was achieved. + parameter (Optional[str]): The parameter associated with the best score. + validation_parameters: The validation parameters used during the run. + array_store: The array store object used to store prediction arrays. + stats_store: The stats store object used to store iteration scores. + ds: The dataset object associated with the best score. + Methods: + get_ds(iteration, validation_dataset): Retrieves the dataset object associated with the best score. + does_new_best_exist(): Checks if a new best score exists. + """ + def __init__(self, run: Run): + """ + Initializes a new instance of the BestScore class. + + Args: + run (Run): The run object associated with the best score. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import BestScore + >>> run = Run() + >>> best_score = BestScore(run) + """ self.run: Run = run self.score: float = -1 self.iteration: int = 0 @@ -95,6 +190,23 @@ def __init__(self, run: Run): self.stats_store = create_stats_store() def get_ds(self, iteration, validation_dataset): + """ + Retrieves the dataset object associated with the best score. + + Args: + iteration (int): The iteration number. + validation_dataset: The validation dataset object. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import BestScore + >>> run = Run() + >>> best_score = BestScore(run) + >>> iteration = 0 + >>> validation_dataset = run.datasplit.validate[0] + >>> best_score.get_ds(iteration, validation_dataset) + """ prediction_array_identifier = self.array_store.validation_prediction_array( self.run.name, iteration, validation_dataset.name ) @@ -107,6 +219,20 @@ def get_ds(self, iteration, validation_dataset): self.ds = open_ds(container, dataset) def does_new_best_exist(self): + """ + Checks if a new best score exists. + + Returns: + bool: True if a new best score exists, False otherwise. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import BestScore + >>> run = Run() + >>> best_score = BestScore(run) + >>> new_best_exists = best_score.does_new_best_exist() + """ new_best_exists = False self.validation_scores = self.stats_store.retrieve_validation_iteration_scores( self.run.name @@ -132,12 +258,74 @@ def does_new_best_exist(self): class NeuroglancerRunViewer: + """ + A class for viewing neuroglancer runs. + + Attributes: + run (Run): The run object. + best_score (BestScore): The best score object. + embedded (bool): Whether the viewer is embedded. + viewer: The neuroglancer viewer. + raw: The raw dataset. + gt: The ground truth dataset. + segmentation: The segmentation dataset. + most_recent_iteration: The most recent iteration. + run_thread: The run thread. + array_store: The array store object. + Methods: + updated_neuroglancer_layer(layer_name, ds): Update the neuroglancer layer with the given name and data source. + deprecated_start_neuroglancer(): Deprecated method to start the neuroglancer viewer. + start_neuroglancer(): Start the neuroglancer viewer. + start(): Start the viewer. + open_from_array_identitifier(array_identifier): Open the array from the given identifier. + get_datasets(): Get the datasets for validation. + update_best_info(): Update the best info. + update_neuroglancer(): Update the neuroglancer viewer. + update_best_layer(): Update the best layer. + new_validation_checker(): Start a new validation checker thread. + update_with_new_validation_if_possible(): Update with new validation if possible. + stop(): Stop the viewer. + """ + def __init__(self, run: Run, embedded=False): + """ + Initialize a View object. + + Args: + run (Run): The run object. + embedded (bool, optional): Whether the viewer is embedded. Defaults to False. + Returns: + View: The view object. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + """ self.run: Run = run self.best_score = BestScore(run) self.embedded = embedded def updated_neuroglancer_layer(self, layer_name, ds): + """ + Update the neuroglancer layer with the given name and data source. + + Args: + layer_name (str): The name of the layer. + ds: The data source. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> layer_name = "prediction" + >>> ds = viewer.run.datasplit.validate[0].raw._source_array + >>> viewer.updated_neuroglancer_layer(layer_name, ds) + """ source = neuroglancer.LocalVolume( data=ds.data, dimensions=neuroglancer.CoordinateSpace( @@ -160,10 +348,38 @@ def updated_neuroglancer_layer(self, layer_name, ds): self.viewer.set_state(new_state) def deprecated_start_neuroglancer(self): + """ + Deprecated method to start the neuroglancer viewer. + + Returns: + IFrame: The embedded viewer. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.deprecated_start_neuroglancer() + """ neuroglancer.set_server_bind_address("0.0.0.0") self.viewer = neuroglancer.Viewer() def start_neuroglancer(self): + """ + Start the neuroglancer viewer. + + Returns: + IFrame: The embedded viewer. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.start_neuroglancer() + """ neuroglancer.set_server_bind_address("0.0.0.0") self.viewer = neuroglancer.Viewer() print(f"Neuroglancer viewer: {self.viewer}") @@ -185,6 +401,20 @@ def start_neuroglancer(self): return IFrame(src=self.viewer, width=1800, height=900) def start(self): + """ + Start the viewer. + + Returns: + IFrame: The embedded viewer. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.start() + """ self.run_thread = True self.array_store = create_array_store() self.get_datasets() @@ -192,12 +422,45 @@ def start(self): return self.start_neuroglancer() def open_from_array_identitifier(self, array_identifier): + """ + Open the array from the given identifier. + + Args: + array_identifier: The array identifier. + Returns: + The opened dataset or None if it doesn't exist. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> array_identifier = viewer.run.datasplit.validate[0].raw._source_array + >>> ds = viewer.open_from_array_identitifier(array_identifier) + """ if os.path.exists(array_identifier.container / array_identifier.dataset): return open_ds(str(array_identifier.container), array_identifier.dataset) else: return None def get_datasets(self): + """ + Get the datasets for validation. + + Args: + run (Run): The run object. + Returns: + The raw and ground truth datasets for validation. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.get_datasets() + """ for validation_dataset in self.run.datasplit.validate: raw = validation_dataset.raw._source_array gt = validation_dataset.gt._source_array @@ -205,10 +468,40 @@ def get_datasets(self): self.gt = open_ds(str(gt.file_name), gt.dataset) def update_best_info(self): + """ + Update the best info. + + Args: + run (Run): The run object. + Returns: + IFrame: The embedded viewer. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.update_best_info() + """ self.segmentation = self.best_score.ds self.most_recent_iteration = self.best_score.iteration def update_neuroglancer(self): + """ + Update the neuroglancer viewer. + + Args: + run (Run): The run object. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.update_neuroglancer() + """ self.updated_neuroglancer_layer( f"prediction at iteration {self.best_score.iteration}, f1 score {self.best_score.score}", self.segmentation, @@ -216,10 +509,42 @@ def update_neuroglancer(self): return None def update_best_layer(self): + """ + Update the best layer. + + Args: + run (Run): The run object. + Returns: + IFrame: The embedded viewer. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.update_best_layer() + """ self.update_best_info() self.update_neuroglancer() def new_validation_checker(self): + """ + Start a new validation checker thread. + + Args: + run (Run): The run object. + Returns: + IFrame: The embedded viewer. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.new_validation_checker() + """ self.thread = threading.Thread( target=self.update_with_new_validation_if_possible, daemon=True ) @@ -227,6 +552,22 @@ def new_validation_checker(self): self.thread.start() def update_with_new_validation_if_possible(self): + """ + Update with new validation if possible. + + Args: + run (Run): The run object. + Returns: + IFrame: The embedded viewer. + Raises: + FileNotFoundError: If the dataset object does not exist. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.update_with_new_validation_if_possible() + """ thread = threading.currentThread() # Here we are assuming that we are checking the directory .../valdiation_config/prediction # Ideally we will only have to check for the current best validation @@ -240,4 +581,20 @@ def update_with_new_validation_if_possible(self): self.update_best_layer() def stop(self): + """ + Stop the viewer. + + Args: + run (Run): The run object. + Raises: + FileNotFoundError: If the dataset object does not exist. + Returns: + IFrame: The embedded viewer. + Examples: + >>> from dacapo.experiments.run import Run + >>> from dacapo.utils.view import NeuroglancerRunViewer + >>> run = Run() + >>> viewer = NeuroglancerRunViewer(run) + >>> viewer.stop() + """ self.thread.run_thread = False diff --git a/dacapo/utils/voi.py b/dacapo/utils/voi.py index e5399a443..17f658b0b 100644 --- a/dacapo/utils/voi.py +++ b/dacapo/utils/voi.py @@ -37,6 +37,11 @@ def voi(reconstruction, groundtruth, ignore_reconstruction=[], ignore_groundtrut (split, merge) : float The variation of information split and merge error, i.e., H(X|Y) and H(Y|X) + Raises + ------ + ValueError + If `reconstruction` and `groundtruth` have different shapes. + References ---------- [1] Meila, M. (2007). Comparing clusterings - an information based @@ -108,6 +113,11 @@ def vi_tables(x, y=None, ignore_x=[0], ignore_y=[0]): The proportions of each label in `x` and `y` (`px`, `py`), the per-segment conditional entropies of `x` given `y` and vice-versa, the per-segment conditional probability p log p. + + Raises + ------ + ValueError + If `x` and `y` have different shapes. """ if y is not None: pxy = contingency_table(x, y, ignore_x, ignore_y) @@ -164,6 +174,12 @@ def contingency_table(seg, gt, ignore_seg=[0], ignore_gt=[0], norm=True): A contingency table. `cont[i, j]` will equal the number of voxels labeled `i` in `seg` and `j` in `gt`. (Or the proportion of such voxels if `norm=True`.) + + Raises + ------ + ValueError + If `seg` and `gt` have different shapes. + """ segr = seg.ravel() gtr = gt.ravel() @@ -198,6 +214,11 @@ def divide_columns(matrix, row, in_place=False): ------- out : same type as `matrix` The result of the row-wise division. + + Raises + ------ + ValueError + If `row` contains zeros. """ if in_place: out = matrix @@ -237,6 +258,11 @@ def divide_rows(matrix, column, in_place=False): ------- out : same type as `matrix` The result of the row-wise division. + + Raises + ------ + ValueError + If `column` contains zeros. """ if in_place: out = matrix @@ -276,6 +302,11 @@ def xlogx(x, out=None, in_place=False): ------- y : same type as x Result of x * log_2(x). + + Raises + ------ + ValueError + If x contains negative values. """ if in_place: y = x diff --git a/dacapo/validate.py b/dacapo/validate.py index 8c6461e7d..b49ffe4c1 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -8,59 +8,78 @@ create_weights_store, ) -from pathlib import Path +from upath import UPath as Path import logging +from warnings import warn logger = logging.getLogger(__name__) -def validate( - run_name: str, +def validate_run( + run: Run, iteration: int, num_workers: int = 1, output_dtype: str = "uint8", overwrite: bool = True, ): - """Validate a run at a given iteration. Loads the weights from a previously - stored checkpoint. Returns the best parameters and scores for this - iteration.""" - - print(f"Validating run {run_name} at iteration {iteration}...") - - # create run - - config_store = create_config_store() - run_config = config_store.retrieve_run_config(run_name) - run = Run(run_config) - - # read in previous training/validation stats - stats_store = create_stats_store() - run.training_stats = stats_store.retrieve_training_stats(run_name) - run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( - run_name + """ + validate_run is deprecated and will be removed in a future version. Please use validate instead. + """ + warn( + "validate_run is deprecated and will be removed in a future version. Please use validate instead.", + DeprecationWarning, ) - - return validate_run( - run, - iteration, + return validate( + run_name=run, + iteration=iteration, num_workers=num_workers, output_dtype=output_dtype, overwrite=overwrite, ) -# @reloading # allows us to fix validation bugs without interrupting training -def validate_run( - run: Run, +def validate( + run_name: str | Run, iteration: int, num_workers: int = 1, output_dtype: str = "uint8", overwrite: bool = True, ): - """Validate an already loaded run at the given iteration. This does not - load the weights of that iteration, it is assumed that the model is already - loaded correctly. Returns the best parameters and scores for this - iteration.""" + """ + Validate a run at a given iteration. Loads the weights from a previously + stored checkpoint. Returns the best parameters and scores for this + iteration. + + Args: + run_name: The name of the run to validate. + iteration: The iteration to validate. + num_workers: The number of workers to use for validation. + output_dtype: The dtype to use for the output arrays. + overwrite: Whether to overwrite existing output arrays + Returns: + The best parameters and scores for this iteration + Raises: + ValueError: If the run does not have a validation dataset or the dataset does not have ground truth. + Example: + validate("my_run", 1000) + """ + + print(f"Validating run {run_name} at iteration {iteration}...") + + if isinstance(run_name, Run): + run = run_name + run_name = run.name + else: + config_store = create_config_store() + run_config = config_store.retrieve_run_config(run_name) + run = Run(run_config) + + # read in previous training/validation stats + stats_store = create_stats_store() + run.training_stats = stats_store.retrieve_training_stats(run_name) + run.validation_scores.scores = stats_store.retrieve_validation_iteration_scores( + run_name + ) if ( run.datasplit.validate is None diff --git a/dockerfile b/dockerfile new file mode 100644 index 000000000..b766d30c3 --- /dev/null +++ b/dockerfile @@ -0,0 +1,30 @@ +FROM python:3.11-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y \ + gcc \ + g++ \ + pkg-config \ + make \ + libhdf5-dev \ + libc-dev \ + npm \ + git \ + && rm -rf /var/lib/apt/lists/* + +RUN npm install -g configurable-http-proxy + +RUN pip install --upgrade pip + +RUN pip install h5py +RUN pip install dacapo-ml +RUN pip install notebook + +RUN git clone https://github.com/janelia-cellmap/dacapo.git +RUN mv dacapo/examples examples && rm -rf dacapo + +EXPOSE 8000 + +CMD ["jupyter", "notebook", "--ip=0.0.0.0", "--port=8000", "--NotebookApp.allow_origin='*'", "--allow-root", "--NotebookApp.token=''", "--NotebookApp.password=''", "--NotebookApp.notebook_dir='/app/'"] + diff --git a/docs/source/docker.rst b/docs/source/docker.rst new file mode 100644 index 000000000..7782c350f --- /dev/null +++ b/docs/source/docker.rst @@ -0,0 +1,48 @@ +.. automodule:: dacapo + +.. contents:: + :depth: 1 + :local: + +Docker Configuration for JupyterHub-Dacapo +========================================= + +This document provides instructions on how to build and run the Docker image for the JupyterHub-Dacapo project. + +Requirements +------------ +Before you begin, ensure you have Docker installed on your system. You can download it from `Docker's official website `_. + +Building the Docker Image +------------------------- +To build the Docker image, navigate to the directory containing your Dockerfile and run the following command: + +.. code-block:: bash + + docker build -t jupyterhub-dacapo . + +This command builds a Docker image with the tag `jupyterhub-dacapo` using the Dockerfile in the current directory. + +Running the Docker Container +---------------------------- +Once the image is built, you can run a container from the image with the following command: + +.. code-block:: bash + + docker run -p 8000:8000 jupyterhub-dacapo + +This command starts a container based on the `jupyterhub-dacapo` image. It maps port 8000 of the container to port 8000 on the host, allowing you to access JupyterHub by navigating to `http://localhost:8000` in your web browser. + +Stopping the Docker Container +----------------------------- +To stop the running container, you can use the Docker CLI to stop the container: + +.. code-block:: bash + + docker stop [CONTAINER_ID] + +Replace `[CONTAINER_ID]` with the actual ID of your running container. You can find the container ID by listing all running containers with `docker ps`. + +Further Configuration +--------------------- +For additional configuration options, such as setting environment variables or configuring volumes, refer to the Docker documentation or the specific documentation for the JupyterHub or Dacapo configurations. diff --git a/docs/source/index.rst b/docs/source/index.rst index 703084605..a4390edd3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,6 +10,7 @@ overview install tutorial + docker autoapi/index cli diff --git a/examples/aws/README.md b/examples/aws/README.md new file mode 100644 index 000000000..96f8c9499 --- /dev/null +++ b/examples/aws/README.md @@ -0,0 +1,14 @@ +You can work locally using S3 data by setting the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables. You can also set the `AWS_REGION` environment variable to specify the region to use. If you are using a profile, you can set the `AWS_PROFILE` environment variable to specify the profile to use. + +```bash +aws configure +``` + +In order to store checkpoints and experiments data in S3, you need to modify `dacapo.yaml` to include the following: + +```yaml +runs_base_dir: "s3://dacapotest" +``` + +For configs and stats, you can save them locally or s3 by setting `type: files` or for mongodb by setting `type: mongo` in the `dacapo.yaml` file. + diff --git a/examples/aws/aws_store_check.py b/examples/aws/aws_store_check.py new file mode 100644 index 000000000..f44b261ed --- /dev/null +++ b/examples/aws/aws_store_check.py @@ -0,0 +1,30 @@ +# %% +import dacapo + +# from import create_config_store + +config_store = dacapo.store.create_store.create_config_store() + +# %% +from dacapo import Options + +options = Options.instance() + +# %% +options +# %% +from dacapo.experiments.tasks import DistanceTaskConfig + +task_config = DistanceTaskConfig( + name="cosem_distance_task_4nm", + channels=["mito"], + clip_distance=40.0, + tol_distance=40.0, + scale_factor=80.0, +) + +# %% + +config_store.store_task_config(task_config) + +# %% diff --git a/examples/aws/cloud_csv.csv b/examples/aws/cloud_csv.csv new file mode 100644 index 000000000..99a407a82 --- /dev/null +++ b/examples/aws/cloud_csv.csv @@ -0,0 +1,3 @@ +train,s3://janelia-cosem-datasets/jrc_hela-2/jrc_hela-2.zarr,recon-1/em/fibsem-uint8,s3://janelia-cosem-datasets/jrc_hela-2/jrc_hela-2.zarr,recon-1/labels/groundtruth/crop155/[nuc] +train,s3://janelia-cosem-datasets/jrc_hela-2/jrc_hela-2.zarr,recon-1/em/fibsem-uint8,s3://janelia-cosem-datasets/jrc_hela-2/jrc_hela-2.zarr,recon-1/labels/groundtruth/crop7/[nuc] +val,s3://janelia-cosem-datasets/jrc_hela-2/jrc_hela-2.zarr,recon-1/em/fibsem-uint8,s3://janelia-cosem-datasets/jrc_hela-2/jrc_hela-2.zarr,recon-1/labels/groundtruth/crop6/[nuc] \ No newline at end of file diff --git a/examples/aws/dacapo.yaml b/examples/aws/dacapo.yaml new file mode 100644 index 000000000..960719a6d --- /dev/null +++ b/examples/aws/dacapo.yaml @@ -0,0 +1,3 @@ + +runs_base_dir: "s3://dacapotest" +type: "files" diff --git a/examples/aws/s3_datasplit.py b/examples/aws/s3_datasplit.py new file mode 100644 index 000000000..f5bb72b79 --- /dev/null +++ b/examples/aws/s3_datasplit.py @@ -0,0 +1,16 @@ +# %% +from dacapo.experiments.datasplits import DataSplitGenerator +from funlib.geometry import Coordinate + +input_resolution = Coordinate(8, 8, 8) +output_resolution = Coordinate(4, 4, 4) +datasplit_config = DataSplitGenerator.generate_from_csv( + "cloud_csv.csv", + input_resolution, + output_resolution, +).compute() +# %% +datasplit = datasplit_config.datasplit_type(datasplit_config) +# %% +viewer = datasplit._neuroglancer() +# %% diff --git a/examples/blockwise/dummy_script.py b/examples/blockwise/dummy_script.py new file mode 100644 index 000000000..2a96af9c0 --- /dev/null +++ b/examples/blockwise/dummy_script.py @@ -0,0 +1,28 @@ +from dacapo.blockwise.scheduler import run_blockwise +from funlib.geometry import Roi + +# Make the ROIs +path_to_worker = "dummy_worker.py" +total_roi = Roi(offset=(0, 0, 0), shape=(100, 100, 100)) +read_roi = Roi(offset=(0, 0, 0), shape=(10, 10, 10)) +write_roi = Roi(offset=(0, 0, 0), shape=(1, 1, 1)) +num_workers = 16 + +# Run the script blockwise +success = run_blockwise( + worker_file=path_to_worker, + total_roi=total_roi, + read_roi=read_roi, + write_roi=write_roi, + num_workers=num_workers, + arg="Thing", +) + +# Print the success +if success: + print("Success") +else: + print("Failure") + +# example run command: +# bsub -n 4 python dummy_script.py diff --git a/examples/blockwise/dummy_worker.py b/examples/blockwise/dummy_worker.py new file mode 100644 index 000000000..9bdb33e7e --- /dev/null +++ b/examples/blockwise/dummy_worker.py @@ -0,0 +1,124 @@ +from typing import Any, Optional +import sys +from dacapo.compute_context import create_compute_context + +import daisy + +import click + +import logging + +logger = logging.getLogger(__file__) + +read_write_conflict: bool = False +fit: str = "valid" +path = __file__ + +# OPTIONALLY DEFINE GLOBALS HERE + + +@click.group() +@click.option( + "--log-level", + type=click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False + ), + default="INFO", +) +def cli(log_level): + """ + CLI for running the threshold worker. + + Args: + log_level (str): The log level to use. + """ + logging.basicConfig(level=getattr(logging, log_level.upper())) + + +@cli.command() +@click.option( + "-a", + "--arg", + required=True, + type=any, + default=None, +) +# ADD MORE CLICK OPTION ARGUMENTS HERE +def start_worker( + arg: Any, + # ADD MORE ARGUMENTS HERE + return_io_loop: Optional[bool] = False, +): + """ + Start the worker. + + Args: + arg (Any): An example argument to use. + + """ + # Do something with the argument + print(arg) + + def io_loop(): + # wait for blocks to run pipeline + client = daisy.Client() + + while True: + print("getting block") + with client.acquire_block() as block: + if block is None: + break + + # Do the blockwise process + print( + f"processing block: {block.id}, with read_roi: {block.read_roi}, using arg: {arg}" + ) + # DO SOMETHING WITH THE BLOCK + + if return_io_loop: + return io_loop + else: + io_loop() + + +def spawn_worker( + arg: Any, + # ADD MORE ARGUMENTS HERE +): + """ + Spawn a worker. + + Args: + arg (Any): An example argument to use. + Returns: + Callable: The function to run the worker. + """ + compute_context = create_compute_context() + if not compute_context.distribute_workers: + return start_worker( + arg=arg, + # ADD MORE ARGUMENTS HERE + return_io_loop=True, + ) + + # Make the command for the worker to run + command = [ + sys.executable, + path, + "start-worker", + "--arg", + str(arg), + # ADD MORE ARGUMENTS HERE, THEY MUST BE STRINGS + ] + + def run_worker(): + """ + Run the worker in the given compute context. + """ + compute_context.execute(command) + + return run_worker + + +if __name__ == "__main__": + cli() diff --git a/examples/postprocessing/README.md b/examples/postprocessing/README.md new file mode 100644 index 000000000..436694ba4 --- /dev/null +++ b/examples/postprocessing/README.md @@ -0,0 +1,10 @@ +# Post processing example scripts for distribute blockwise processing of peroxisome data. + +The goal of the script is to : +- Gaussian filter the data +- Threshold the distance data to get binary data +- Apply watershed to get connected components +- Find the connected components +- Mask False Positives Mitochondria using Mitochondria data +- Merge crops +- Filter the connected components based on size diff --git a/examples/postprocessing/blockwise_postprocess_script.py b/examples/postprocessing/blockwise_postprocess_script.py new file mode 100644 index 000000000..5008e476d --- /dev/null +++ b/examples/postprocessing/blockwise_postprocess_script.py @@ -0,0 +1,50 @@ +from dacapo.blockwise.scheduler import run_blockwise +from funlib.geometry import Roi +from postprocessing.postprocess_worker import open_ds +import daisy +import numpy as np + +# Make the ROIs +path_to_worker = "postprocess_worker.py" +num_workers = 16 +overlap = 20 + +peroxi_container = "/path/to/peroxi_container.zarr" +peroxi_dataset = "peroxisomes" +mito_container = "/path/to/mito_container.zarr" +mito_dataset = "mitochondria" +threshold = "0.5" +gaussian_kernel = 2 + +array_in = open_ds(peroxi_container, peroxi_dataset) +total_roi = array_in.roi + +voxel_size = array_in.voxel_size +block_size = np.array(array_in.data.chunks) * np.array(voxel_size) + +write_size = daisy.Coordinate(block_size) +write_roi = daisy.Roi((0,) * len(write_size), write_size) + +context = np.array(voxel_size) * overlap + +read_roi = write_roi.grow(context, context) +total_roi = array_in.roi.grow(context, context) + + +# Run the script blockwise +success = run_blockwise( + worker_file=path_to_worker, + total_roi=total_roi, + read_roi=read_roi, + write_roi=write_roi, + num_workers=num_workers, +) + +# Print the success +if success: + print("Success") +else: + print("Failure") + +# example run command: +# bsub -n 4 python blockwise_postprocess_script.py diff --git a/examples/postprocessing/postprocess_worker.py b/examples/postprocessing/postprocess_worker.py new file mode 100644 index 000000000..b5183efd4 --- /dev/null +++ b/examples/postprocessing/postprocess_worker.py @@ -0,0 +1,210 @@ +from typing import Any, Optional +import sys +from dacapo.compute_context import create_compute_context + +import daisy + +import click + +import logging + +import skimage.measure +import skimage.filters +import skimage.morphology +from funlib.persistence import open_ds +import numpy as np +from skimage.segmentation import watershed +from scipy import ndimage as ndi + +logger = logging.getLogger(__file__) + +read_write_conflict: bool = False +fit: str = "valid" +path = __file__ + + +# OPTIONALLY DEFINE GLOBALS HERE + + +@click.group() +@click.option( + "--log-level", + type=click.Choice( + ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False + ), + default="INFO", +) +def cli(log_level): + """ + CLI for running the threshold worker. + + Args: + log_level (str): The log level to use. + """ + logging.basicConfig(level=getattr(logging, log_level.upper())) + + +@cli.command() +@click.option( + "-pc", + "--peroxi-container", + required=True, + type=str, + default=None, +) +@click.option( + "-pd", + "--peroxi-dataset", + required=True, + type=str, + default=None, +) +@click.option( + "-mc", + "--mito-container", + required=True, + type=str, + default=None, +) +@click.option( + "-md", + "--mito-dataset", + required=True, + type=str, + default=None, +) +@click.option( + "-t", + "--threshold", + required=False, + type=float, + default=0.5, +) +@click.option( + "-g", + "--gaussian-kernel", + required=False, + type=int, + default=2, +) +def start_worker( + peroxi_container, + peroxi_dataset, + mito_container, + mito_dataset, + threshold, + gaussian_kernel, + return_io_loop: Optional[bool] = False, +): + """ + Start the worker. + + Args: + peroxi_container (str): The container of the peroxisome predictions. + peroxi_dataset (str): The dataset of the peroxisome predictions. + mito_container (str): The container of the mitochondria predictions. + mito_dataset (str): The dataset of the mitochondria predictions. + threshold (float): The threshold to use for the peroxisome predictions. + gaussian_kernel (int): The kernel size to use for the gaussian filter. + + returns: + instance_peroxi (np.ndarray): The instance labels of the peroxisome predictions. + + """ + # Do something with the argument + # print(arg) + + def io_loop(): + # wait for blocks to run pipeline + client = daisy.Client() + peroxi_ds = open_ds(peroxi_container, peroxi_dataset) + mito_ds = open_ds(mito_container, mito_dataset) + + while True: + print("getting block") + with client.acquire_block() as block: + if block is None: + break + + # Do the blockwise process + peroxi = peroxi_ds.to_ndarray(block.read_roi) + mito = mito_ds.to_ndarray(block.read_roi) + + print(f"processing block: {block.id}, with read_roi: {block.read_roi}") + peroxi = skimage.filters.gaussian(peroxi, gaussian_kernel) + # threshold precictions + binary_peroxi = peroxi > threshold + # get instance labels + markers, _ = ndi.label(binary_peroxi) + # Apply Watershed + ws_labels = watershed(-peroxi, markers, mask=peroxi) + instance_peroxi = skimage.measure.label(ws_labels).astype(np.int64) + # relabel background to 0 + instance_peroxi[mito > 0] = 0 + # make mask of unwanted object class overlaps + return instance_peroxi.astype(np.uint64) + + if return_io_loop: + return io_loop + else: + io_loop() + + +def spawn_worker( + peroxi_container, + peroxi_dataset, + mito_container, + mito_dataset, + threshold, + gaussian_kernel, +): + """ + Spawn a worker. + + Args: + arg (Any): An example argument to use. + Returns: + Callable: The function to run the worker. + """ + compute_context = create_compute_context() + if not compute_context.distribute_workers: + return start_worker( + peroxi_container=peroxi_container, + peroxi_dataset=peroxi_dataset, + mito_container=mito_container, + mito_dataset=mito_dataset, + threshold=threshold, + gaussian_kernel=gaussian_kernel, + return_io_loop=True, + ) + + # Make the command for the worker to run + command = [ + sys.executable, + path, + "start-worker", + "--peroxi-container", + peroxi_container, + "--peroxi-dataset", + peroxi_dataset, + "--mito-container", + mito_container, + "--mito-dataset", + mito_dataset, + "--threshold", + str(threshold), + "--gaussian-kernel", + str(gaussian_kernel), + ] + + def run_worker(): + """ + Run the worker in the given compute context. + """ + compute_context.execute(command) + + return run_worker + + +if __name__ == "__main__": + cli() diff --git a/pyproject.toml b/pyproject.toml index acdeb0fa6..0ab64cdff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ ] dynamic = ["version"] dependencies = [ + "universal-pathlib>=0.2.2,<1.0.0", "numpy>=1.22.4", "pyyaml", "zarr", @@ -48,8 +49,6 @@ dependencies = [ "cellmap-models", "funlib.persistence>=0.3.0", "gunpowder>=1.3", - # "lsds>=0.1.3", - # "lsds @ git+https://github.com/funkelab/lsd", "lsds", "xarray", "cattrs", @@ -57,6 +56,8 @@ dependencies = [ "click", "pyyaml", "scipy", + "upath", + "boto3", ] # extras