diff --git a/.gitignore b/.gitignore index b50398827..fd7633975 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ dist build dacapo.yaml __pycache__ +scratch/ # vscode stuff .vscode diff --git a/dacapo/apply.py b/dacapo/apply.py index 3d1c78974..bfdb2c182 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -12,7 +12,6 @@ import dacapo.experiments.tasks.post_processors as post_processors from dacapo.store.array_store import LocalArrayIdentifier from dacapo.predict import predict -from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.create_store import ( create_config_store, @@ -34,9 +33,8 @@ def apply( iteration: Optional[int] = None, parameters: Optional[PostProcessorParameters | str] = None, roi: Optional[Roi | str] = None, - num_workers: int = 30, + num_workers: int = 12, output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore - compute_context: ComputeContext = LocalTorch(), overwrite: bool = True, file_format: str = "zarr", ): @@ -144,7 +142,7 @@ def apply( ) output_container = Path( output_path, - "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}", + Path(input_container).stem + f".{file_format}", ) prediction_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}" @@ -160,7 +158,7 @@ def apply( Path(input_container, input_dataset), ) return apply_run( - run.name, + run, iteration, parameters, input_array_identifier, @@ -169,22 +167,20 @@ def apply( roi, num_workers, output_dtype, - compute_context, overwrite, ) def apply_run( - run_name: str, + run: Run, iteration: int, parameters: PostProcessorParameters, input_array_identifier: "LocalArrayIdentifier", prediction_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", roi: Optional[Roi] = None, - num_workers: int = 30, - output_dtype: Optional[np.dtype] = np.uint8, # type: ignore - compute_context: ComputeContext = LocalTorch(), + num_workers: int = 12, + output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): """Apply the model to a dataset. If roi is None, the whole input dataset is used. 
Assumes model is already loaded.""" @@ -192,15 +188,14 @@ def apply_run( # render prediction dataset logger.info("Predicting on dataset %s", prediction_array_identifier) predict( - run_name, + run.name, iteration, input_container=input_array_identifier.container, input_dataset=input_array_identifier.dataset, - output_path=prediction_array_identifier.container, + output_path=prediction_array_identifier, output_roi=roi, num_workers=num_workers, output_dtype=output_dtype, - compute_context=compute_context, overwrite=overwrite, ) @@ -208,7 +203,7 @@ def apply_run( logger.info("Post-processing output to dataset %s", output_array_identifier) post_processor = run.task.post_processor post_processor.set_prediction(prediction_array_identifier) - post_processor.process(parameters, output_array_identifier) + post_processor.process(parameters, output_array_identifier, num_workers=num_workers) logger.info("Done") return diff --git a/dacapo/blockwise/argmax_worker.py b/dacapo/blockwise/argmax_worker.py index ac6ad044e..e42dd0299 100644 --- a/dacapo/blockwise/argmax_worker.py +++ b/dacapo/blockwise/argmax_worker.py @@ -1,7 +1,7 @@ from pathlib import Path from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context import daisy @@ -14,6 +14,7 @@ read_write_conflict: bool = False fit: str = "valid" +path = __file__ @click.group() @@ -74,7 +75,6 @@ def start_worker( def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -82,12 +82,13 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). 
""" + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--input_container", input_array_identifier.container, diff --git a/dacapo/blockwise/blockwise_task.py b/dacapo/blockwise/blockwise_task.py index 3b8bf9f9d..54e1b7347 100644 --- a/dacapo/blockwise/blockwise_task.py +++ b/dacapo/blockwise/blockwise_task.py @@ -2,15 +2,12 @@ from importlib.machinery import SourceFileLoader from pathlib import Path from daisy import Task, Roi -from dacapo.compute_context import ComputeContext -import dacapo.compute_context class DaCapoBlockwiseTask(Task): def __init__( self, worker_file: str | Path, - compute_context: ComputeContext | str, total_roi: Roi, read_roi: Roi, write_roi: Roi, @@ -21,9 +18,6 @@ def __init__( *args, **kwargs, ): - if isinstance(compute_context, str): - compute_context = getattr(dacapo.compute_context, compute_context)() - # Load worker functions worker_name = Path(worker_file).stem worker = SourceFileLoader(worker_name, str(worker_file)).load_module() @@ -32,9 +26,7 @@ def __init__( timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") task_id = worker_name + timestamp - process_function = worker.spawn_worker( - *args, **kwargs, compute_context=compute_context - ) + process_function = worker.spawn_worker(*args, **kwargs) if hasattr(worker, "check_function"): check_function = worker.check_function else: diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 40856f191..6b47bf76c 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -1,17 +1,16 @@ from pathlib import Path import torch -from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray -from dacapo.gp.dacapo_array_source import DaCapoArraySource +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray +from dacapo.gp import DaCapoArraySource from dacapo.store.array_store import LocalArrayIdentifier from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.experiments import Run -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context import gunpowder as gp import gunpowder.torch as gp_torch -import daisy -from daisy import Coordinate +from funlib.geometry import Coordinate import numpy as np import click @@ -22,6 +21,7 @@ read_write_conflict: bool = False fit: str = "valid" +path = __file__ @click.group() @@ -66,7 +66,7 @@ def start_worker( input_dataset: str, output_container: Path | str, output_dataset: str, - device: str = "cuda", + device: str | torch.device = "cuda", ): # retrieving run config_store = create_config_store() @@ -92,7 +92,7 @@ def start_worker( torch.backends.cudnn.benchmark = True # get the model's input and output size - model = run.model.eval() + model = run.model.eval().to(device) input_voxel_size = Coordinate(raw_array.voxel_size) output_voxel_size = model.scale(input_voxel_size) input_shape = Coordinate(model.eval_input_shape) @@ -102,6 +102,7 @@ def start_worker( logger.info( "Predicting with input size %s, output size %s", input_size, output_size ) + # create gunpowder keys raw = gp.ArrayKey("RAW") @@ -112,11 +113,13 @@ def start_worker( # prepare data source pipeline = DaCapoArraySource(raw_array, raw) # raw: (c, d, h, w) - pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) + pipeline += gp.Pad(raw, None) # raw: (c, d, h, w) pipeline += gp.Unsqueeze([raw]) # raw: (1, c, d, h, w) + 
pipeline += gp.Normalize(raw) + # predict pipeline += gp_torch.Predict( model=model, @@ -146,30 +149,31 @@ def start_worker( ) # assumes float32 is [0,1] pipeline += gp.AsType(prediction, output_array.dtype) - # wait for blocks to run pipeline - client = daisy.Client() - - while True: - print("getting block") - with client.acquire_block() as block: - if block is None: - break - - ref_request = gp.BatchRequest() - ref_request[raw] = gp.ArraySpec( - roi=block.read_roi, voxel_size=input_voxel_size, dtype=raw_array.dtype - ) - ref_request[prediction] = gp.ArraySpec( - roi=block.write_roi, - voxel_size=output_voxel_size, - dtype=output_array.dtype, - ) + # write to output array + pipeline += gp.ZarrWrite( + { + prediction: output_array_identifier.dataset, + }, + store=str(output_array_identifier.container), + ) - with gp.build(pipeline): - batch = pipeline.request_batch(ref_request) + # make reference batch request + request = gp.BatchRequest() + request.add(raw, input_size, voxel_size=input_voxel_size) + request.add( + prediction, + output_size, + voxel_size=output_voxel_size, + ) + # use daisy requests to run pipeline + pipeline += gp.DaisyRequestBlocks( + reference=request, + roi_map={raw: "read_roi", prediction: "write_roi"}, + num_workers=1, + ) - # write to output array - output_array[block.write_roi] = batch.arrays[prediction].data + with gp.build(pipeline): + batch = pipeline.request_batch(gp.BatchRequest()) def spawn_worker( @@ -177,7 +181,6 @@ def spawn_worker( iteration: int, raw_array_identifier: "LocalArrayIdentifier", prediction_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -185,12 +188,13 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). """ + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--run-name", run_name, diff --git a/dacapo/blockwise/relabel_worker.py b/dacapo/blockwise/relabel_worker.py index dc45fb53c..1b8580c28 100644 --- a/dacapo/blockwise/relabel_worker.py +++ b/dacapo/blockwise/relabel_worker.py @@ -1,7 +1,7 @@ from glob import glob import os import daisy -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context from dacapo.store.array_store import LocalArrayIdentifier from scipy.cluster.hierarchy import DisjointSet from funlib.persistence import open_ds @@ -27,6 +27,7 @@ def cli(log_level): fit = "shrink" read_write_conflict = False +path = __file__ @cli.command() @@ -88,7 +89,6 @@ def read_cross_block_merges(tmpdir): def spawn_worker( output_array_identifier: LocalArrayIdentifier, tmpdir: str, - compute_context: ComputeContext = LocalTorch(), *args, **kwargs, ): @@ -97,12 +97,13 @@ def spawn_worker( Args: output_array_identifier (LocalArrayIdentifier): The output array identifier tmpdir (str): The temporary directory - compute_context (ComputeContext, optional): The compute context. Defaults to LocalTorch(). 
""" + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--output_container", output_array_identifier.container, diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index 675ca52fe..b9ee72765 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -2,16 +2,15 @@ import tempfile import time import daisy +import dacapo.blockwise from funlib.geometry import Roi, Coordinate import yaml -from dacapo.compute_context import ComputeContext from dacapo.blockwise import DaCapoBlockwiseTask def run_blockwise( worker_file: str | Path, - compute_context: ComputeContext | str, total_roi: Roi, read_roi: Roi, write_roi: Roi, @@ -51,10 +50,6 @@ def run_blockwise( (either due to failed post check or application crashes or network failure) - compute_context (``ComputeContext``): - - The compute context to use for parallelization. - *args: Additional positional arguments to pass to ``worker_function``. @@ -72,7 +67,6 @@ def run_blockwise( # Make the task task = DaCapoBlockwiseTask( worker_file, - compute_context, total_roi, read_roi, write_roi, @@ -88,8 +82,7 @@ def run_blockwise( def segment_blockwise( - segment_function_file: str or Path, - compute_context: ComputeContext | str, + segment_function_file: str | Path, context: Coordinate, total_roi: Roi, read_roi: Roi, @@ -110,8 +103,7 @@ def segment_blockwise( # Make the task task = DaCapoBlockwiseTask( - str(Path(Path(__file__).parent, "segment_worker.py")), - compute_context, + str(Path(Path(dacapo.blockwise.__file__).parent, "segment_worker.py")), total_roi.grow(context, context), read_roi, write_roi, @@ -135,8 +127,7 @@ def segment_blockwise( # Make the task task = DaCapoBlockwiseTask( - str(Path(Path(__file__).parent, "relabel_worker.py")), - compute_context, + str(Path(Path(dacapo.blockwise.__file__).parent, "relabel_worker.py")), total_roi, read_roi, write_roi, diff --git a/dacapo/blockwise/segment_worker.py b/dacapo/blockwise/segment_worker.py index bd15320d7..32c86cacb 100644 --- a/dacapo/blockwise/segment_worker.py +++ b/dacapo/blockwise/segment_worker.py @@ -8,7 +8,7 @@ import numpy as np import yaml -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier @@ -28,6 +28,7 @@ def cli(log_level): fit = "shrink" read_write_conflict = True +path = __file__ @cli.command() @@ -157,7 +158,6 @@ def spawn_worker( output_array_identifier: LocalArrayIdentifier, tmpdir: str, function_path: str, - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -165,12 +165,13 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). 
""" + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--input_container", input_array_identifier.container, diff --git a/dacapo/blockwise/threshold_worker.py b/dacapo/blockwise/threshold_worker.py index d8d645c2b..60fa0198e 100644 --- a/dacapo/blockwise/threshold_worker.py +++ b/dacapo/blockwise/threshold_worker.py @@ -1,7 +1,7 @@ from pathlib import Path from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context import daisy @@ -14,6 +14,7 @@ read_write_conflict: bool = False fit: str = "valid" +path = __file__ @click.group() @@ -76,7 +77,6 @@ def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", threshold: float = 0.0, - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -84,12 +84,13 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). """ + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--input_container", input_array_identifier.container, diff --git a/dacapo/cli.py b/dacapo/cli.py index 8c064aadc..cda08e40f 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -11,7 +11,6 @@ from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( PostProcessorParameters, ) -from dacapo.compute_context import ComputeContext, LocalTorch @click.group() @@ -45,6 +44,9 @@ def train(run_name): type=int, help="The iteration at which to validate the run.", ) +@click.option("-w", "--num_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +@click.option("-ow", "--overwrite", is_flag=True) def validate(run_name, iteration): dacapo.validate(run_name, iteration) @@ -75,7 +77,6 @@ def validate(run_name, iteration): @click.option("-w", "--num_workers", type=int, default=30) @click.option("-dt", "--output_dtype", type=str, default="uint8") @click.option("-ow", "--overwrite", is_flag=True) -@click.option("-cc", "--compute_context", type=str, default="LocalTorch") def apply( run_name: str, input_container: Path | str, @@ -89,11 +90,7 @@ def apply( num_workers: int = 30, output_dtype: Optional[np.dtype | str] = "uint8", overwrite: bool = True, - compute_context: Optional[ComputeContext | str] = LocalTorch(), ): - if isinstance(compute_context, str): - compute_context = getattr(compute_context, compute_context)() - dacapo.apply( run_name, input_container, @@ -107,7 +104,6 @@ def apply( num_workers, output_dtype, overwrite=overwrite, - compute_context=compute_context, # type: ignore ) @@ -139,13 +135,6 @@ def apply( ) @click.option("-w", "--num_workers", type=int, default=30) @click.option("-dt", "--output_dtype", type=str, default="uint8") -@click.option( - "-cc", - "--compute_context", - type=str, - default="LocalTorch", - help="The compute context to use for prediction. 
Must be the name of a subclass of ComputeContext.", -) @click.option("-ow", "--overwrite", is_flag=True) def predict( run_name: str, @@ -156,7 +145,6 @@ def predict( output_roi: Optional[str | Roi] = None, num_workers: int = 30, output_dtype: np.dtype | str = np.uint8, # type: ignore - compute_context: ComputeContext | str = LocalTorch(), overwrite: bool = True, ): dacapo.predict( @@ -168,6 +156,5 @@ def predict( output_roi, num_workers, output_dtype, - compute_context, overwrite, ) diff --git a/dacapo/compute_context/__init__.py b/dacapo/compute_context/__init__.py index c1d859c50..e2cc14722 100644 --- a/dacapo/compute_context/__init__.py +++ b/dacapo/compute_context/__init__.py @@ -1,3 +1,3 @@ -from .compute_context import ComputeContext # noqa +from .compute_context import ComputeContext, create_compute_context # noqa from .local_torch import LocalTorch # noqa from .bsub import Bsub # noqa diff --git a/dacapo/compute_context/bsub.py b/dacapo/compute_context/bsub.py index 54d3dadda..9f168e752 100644 --- a/dacapo/compute_context/bsub.py +++ b/dacapo/compute_context/bsub.py @@ -2,7 +2,6 @@ import attr -import subprocess from typing import Optional diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index 1cf660188..a57de1a09 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import subprocess +from dacapo import Options, compute_context + class ComputeContext(ABC): @property @@ -18,4 +20,17 @@ def wrap_command(self, command): def execute(self, command): # A helper method to run a command in the context specific way. - return subprocess.run(self.wrap_command(command)) + subprocess.run(self.wrap_command(command)) + + +def create_compute_context(): + """Create a compute context based on the global DaCapo options.""" + + options = Options.instance() + + if hasattr(compute_context, options.compute_context_config["type"]): + return getattr(compute_context, options.compute_context_config["type"])( + **options.compute_context_config["config"] + ) + else: + raise ValueError(f"Unknown store type {options.type}") diff --git a/dacapo/experiments/datasplits/datasplit.py b/dacapo/experiments/datasplits/datasplit.py index 17c7e3ac1..eb5a55023 100644 --- a/dacapo/experiments/datasplits/datasplit.py +++ b/dacapo/experiments/datasplits/datasplit.py @@ -1,4 +1,4 @@ -from dacapo.experiments.datasplits.datasets import Dataset +from .datasets import Dataset import neuroglancer diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 129f947ab..8aef6eb1d 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -1,7 +1,7 @@ -from .datasplits.datasplit import DataSplit +from .datasplits import DataSplit from .tasks.task import Task -from .architectures.architecture import Architecture -from .trainers.trainer import Trainer +from .architectures import Architecture +from .trainers import Trainer from .training_stats import TrainingStats from .validation_scores import ValidationScores from .starts import Start diff --git a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py index 3e2e27b94..abedb4459 100644 --- a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py @@ -7,7 +7,17 @@ class DummyEvaluator(Evaluator): criteria = ["frizz_level", "blipp_score"] - def evaluate(self, output_array, evaluation_dataset): + def 
evaluate(self, output_array_identifier, evaluation_dataset): + """ + Evaluate the given output array and dataset and returns the scores based on predefined criteria. + + Args: + output_array_identifier : The output array to be evaluated. + evaluation_dataset : The dataset to be used for evaluation. + + Returns: + DummyEvaluationScore: An object of DummyEvaluationScores class, with the evaluation scores. + """ return DummyEvaluationScores( frizz_level=random.random(), blipp_score=random.random() ) diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index 9d5cbbda0..64d7873ca 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -9,6 +9,7 @@ from dacapo.experiments.tasks.evaluators.evaluation_scores import EvaluationScores from dacapo.experiments.datasplits.datasets import Dataset from dacapo.experiments.datasplits.datasets.arrays import Array + from dacapo.store.local_array_store import LocalArrayIdentifier from dacapo.experiments.tasks.post_processors import PostProcessorParameters from dacapo.experiments.validation_scores import ValidationScores @@ -28,9 +29,23 @@ class Evaluator(ABC): @abstractmethod def evaluate( - self, output_array: "Array", eval_array: "Array" + self, output_array_identifier: "LocalArrayIdentifier", evaluation_array: "Array" ) -> "EvaluationScores": - """Compare an `output_array` against ground-truth `eval_array`""" + """ + Compares and evaluates the output array against the evaluation array. + + Parameters + ---------- + output_array_identifier : Array + The output data array to evaluate + evaluation_array : Array + The evaluation data array to compare with the output + + Returns + ------- + EvaluationScores + The detailed evaluation scores after the comparison. 
+ """ pass @property diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 02f8b1202..42863b56d 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -1,6 +1,6 @@ from pathlib import Path -from dacapo.blockwise.scheduler import run_blockwise -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.blockwise import run_blockwise +import dacapo.blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier from .argmax_post_processor_parameters import ArgmaxPostProcessorParameters @@ -20,6 +20,7 @@ def enumerate_parameters(self): yield ArgmaxPostProcessorParameters(id=1) def set_prediction(self, prediction_array_identifier): + self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -27,8 +28,7 @@ def set_prediction(self, prediction_array_identifier): def process( self, parameters, - output_array_identifier, - compute_context: ComputeContext | str = LocalTorch(), + output_array_identifier: "LocalArrayIdentifier", num_workers: int = 16, block_size: Coordinate = Coordinate((64, 64, 64)), ): @@ -42,12 +42,11 @@ def process( ) read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size) - # run blockwise prediction + # run blockwise post-processing run_blockwise( worker_file=str( - Path(Path(__file__).parent, "blockwise", "predict_worker.py") + Path(Path(dacapo.blockwise.__file__).parent, "argmax_worker.py") ), - compute_context=compute_context, total_roi=self.prediction_array.roi, read_roi=read_roi, write_roi=read_roi, @@ -55,9 +54,8 @@ def process( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - input_array_identifier=LocalArrayIdentifier( - self.prediction_array.file_name, self.prediction_array.dataset - ), + input_array_identifier=self.prediction_array_identifier, output_array_identifier=output_array_identifier, ) + return output_array diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 5a2c7810a..4a992ced2 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -18,10 +18,10 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: for i, min_size in enumerate(range(1, 11)): yield DummyPostProcessorParameters(id=i, min_size=min_size) - def set_prediction(self, prediction_array): + def set_prediction(self, prediction_array_identifier): pass - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, *args, **kwargs): # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index 585063828..f0a991c51 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from dacapo.compute_context import ComputeContext, 
LocalTorch from funlib.geometry import Coordinate from typing import Iterable, TYPE_CHECKING @@ -35,7 +34,6 @@ def process( self, parameters: "PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext | str = LocalTorch(), num_workers: int = 16, chunk_size: Coordinate = Coordinate((64, 64, 64)), ) -> "Array": diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index bbdc76aa1..5d3b45220 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -1,19 +1,14 @@ from pathlib import Path from dacapo.blockwise.scheduler import run_blockwise -from dacapo.compute_context import ComputeContext, LocalTorch from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters +from dacapo.store.array_store import LocalArrayIdentifier from .post_processor import PostProcessor +import dacapo.blockwise import numpy as np from daisy import Roi, Coordinate -from typing import TYPE_CHECKING, Iterable - -if TYPE_CHECKING: - from dacapo.store.local_array_store import LocalArrayIdentifier - from dacapo.experiments.tasks.post_processors import ( - ThresholdPostProcessorParameters, - ) +from typing import Iterable class ThresholdPostProcessor(PostProcessor): @@ -25,7 +20,8 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: for i, threshold in enumerate([-0.1, 0.0, 0.1]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) - def set_prediction(self, prediction_array_identifier: "LocalArrayIdentifier"): + def set_prediction(self, prediction_array_identifier): + self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -34,7 +30,6 @@ def process( self, parameters: "ThresholdPostProcessorParameters", # type: ignore[override] output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext | str = LocalTorch(), num_workers: int = 16, block_size: Coordinate = Coordinate((64, 64, 64)), ) -> ZarrArray: @@ -59,12 +54,11 @@ def process( ) read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size) - # run blockwise prediction + # run blockwise post-processing run_blockwise( worker_file=str( - Path(Path(__file__).parent, "blockwise", "predict_worker.py") + Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py") ), - compute_context=compute_context, total_roi=self.prediction_array.roi, read_roi=read_roi, write_roi=read_roi, @@ -72,9 +66,7 @@ def process( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - input_array_identifier=LocalArrayIdentifier( - self.prediction_array.file_name, self.prediction_array.dataset - ), + input_array_identifier=self.prediction_array_identifier, output_array_identifier=output_array_identifier, threshold=parameters.threshold, ) diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 64bec66e8..fa9d10a47 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -1,12 +1,11 @@ from pathlib import Path +import dacapo.blockwise 
from dacapo.blockwise.scheduler import segment_blockwise -from dacapo.compute_context import ComputeContext, LocalTorch from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier from .watershed_post_processor_parameters import WatershedPostProcessorParameters from .post_processor import PostProcessor -from dacapo.compute_context import ComputeContext, LocalTorch from funlib.geometry import Coordinate, Roi @@ -28,6 +27,7 @@ def enumerate_parameters(self): yield WatershedPostProcessorParameters(id=i, bias=bias) def set_prediction(self, prediction_array_identifier): + self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -36,7 +36,6 @@ def process( self, parameters: WatershedPostProcessorParameters, # type: ignore[override] output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext | str = LocalTorch(), num_workers: int = 16, block_size: Coordinate = Coordinate((64, 64, 64)), ): @@ -58,9 +57,8 @@ def process( } segment_blockwise( segment_function_file=str( - Path(Path(__file__).parent, "blockwise", "watershed_function.py") + Path(Path(dacapo.blockwise.__file__).parent, "watershed_function.py") ), - compute_context=compute_context, context=parameters.context, total_roi=self.prediction_array.roi, read_roi=read_roi.grow(parameters.context, parameters.context), @@ -69,9 +67,7 @@ def process( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - input_array_identifier=LocalArrayIdentifier( - self.prediction_array.file_name, self.prediction_array.dataset - ), + input_array_identifier=self.prediction_array_identifier, output_array_identifier=output_array_identifier, parameters=pars, ) diff --git a/dacapo/experiments/trainers/dummy_trainer.py b/dacapo/experiments/trainers/dummy_trainer.py index 85c7c1ee8..7e826979a 100644 --- a/dacapo/experiments/trainers/dummy_trainer.py +++ b/dacapo/experiments/trainers/dummy_trainer.py @@ -15,28 +15,35 @@ def __init__(self, trainer_config): self.mirror_augment = trainer_config.mirror_augment def create_optimizer(self, model): - return torch.optim.Adam(lr=self.learning_rate, params=model.parameters()) + return torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) def iterate(self, num_iterations: int, model: Model, optimizer, device): target_iteration = self.iteration + num_iterations - for self.iteration in range(self.iteration, target_iteration): + for iteration in range(self.iteration, target_iteration): optimizer.zero_grad() - raw = torch.from_numpy( - np.random.randn(1, model.num_in_channels, *model.input_shape) - ).float() - target = torch.from_numpy( - np.zeros((1, model.num_out_channels, *model.output_shape)) - ).float() + raw = ( + torch.from_numpy( + np.random.randn(1, model.num_in_channels, *model.input_shape) + ) + .float() + .to(device) + ) + target = ( + torch.from_numpy( + np.zeros((1, model.num_out_channels, *model.output_shape)) + ) + .float() + .to(device) + ) pred = model.forward(raw) loss = self._loss.compute(pred, target) loss.backward() optimizer.step() yield TrainingIterationStats( - loss=1.0 / (self.iteration + 1), iteration=self.iteration, time=0.1 + loss=1.0 / (iteration + 1), iteration=iteration, time=0.1 ) - - self.iteration += 1 + self.iteration += 1 def build_batch_provider(self, datasplit, architecture, task, snapshot_container): self._loss = task.loss diff --git a/dacapo/options.py 
b/dacapo/options.py index cea11b38b..d2e2f14c1 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -1,69 +1,98 @@ +import os import yaml import logging from os.path import expanduser from pathlib import Path -logger = logging.getLogger(__name__) +import attr +from cattr import Converter -# options files in order of precedence (highest first) -options_files = [ - Path("./dacapo.yaml"), - Path(expanduser("~/.config/dacapo/dacapo.yaml")), -] +from typing import Optional +logger = logging.getLogger(__name__) -def parse_options(): - for path in options_files: - if not path.exists(): - continue - with path.open("r") as f: - return yaml.safe_load(f) +@attr.s +class DaCapoConfig: + type: str = attr.ib( + default="files", + metadata={ + "help_text": "The type of store to use for storing configurations and statistics. " + "Currently, only 'files' and 'mongo' are supported with files being the default." + }, + ) + runs_base_dir: Path = attr.ib( + default=Path(expanduser("~/.dacapo")), + metadata={ + "help_text": "The path at DaCapo will use for reading and writing any necessary data." + }, + ) + compute_context_config: dict = attr.ib( + default={"type": "LocalTorch", "config": {"device": None}}, + metadata={ + "help_text": "The configuration for the compute context to use. " + "This is a dictionary with the keys being the names of the compute context and the values being the configuration for that context." + }, + ) + mongo_db_host: Optional[str] = attr.ib( + default=None, + metadata={ + "help_text": "The host of the MongoDB instance to use for storing configurations and statistics." + }, + ) + mongo_db_name: Optional[str] = attr.ib( + default=None, + metadata={ + "help_text": "The name of the MongoDB database to use for storing configurations and statistics." + }, + ) + + def serialize(self): + converter = Converter() + return converter.unstructure(self) class Options: - _instance = None - def __init__(self): raise RuntimeError("Singleton: Use Options.instance()") @classmethod - def instance(cls, **kwargs): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.__parse_options(**kwargs) - - return cls._instance - - def __getattr__(self, name): - try: - return self.__options[name] - except KeyError: - raise RuntimeError( - f"Configuration file {self.filename} does not contain an " - f"entry for option {name}" - ) - - def __parse_options(self, **kwargs): - if len(kwargs) > 0: - self.__options = kwargs - self.filename = "kwargs" - return + def instance(cls, **kwargs) -> DaCapoConfig: + config = cls.__parse_options(**kwargs) + return config + + @classmethod + def config_file(cls) -> Optional[Path]: + env_dict = dict(os.environ) + if "OPTIONS_FILE" in env_dict: + options_files = [Path(env_dict["OPTIONS_FILE"])] + else: + options_files = [] + + # options files in order of precedence (highest first) + options_files += [ + Path("./dacapo.yaml"), + Path(expanduser("~/.config/dacapo/dacapo.yaml")), + ] for path in options_files: - if not path.exists(): - continue + if path.exists(): + return path + return None - with path.open("r") as f: - self.__options = yaml.safe_load(f) - self.filename = path + @classmethod + def __parse_options_from_file(cls): + if (config_file := cls.config_file()) is not None: + with config_file.open("r") as f: + return yaml.safe_load(f) + else: + return {} - return + @classmethod + def __parse_options(cls, **kwargs): + options = cls.__parse_options_from_file() + options.update(kwargs) - logger.error( - "No options file found. 
Please create any of the following " "files:" - ) - for path in options_files: - logger.error("\t%s", path.absolute()) + converter = Converter() - raise RuntimeError("Could not find a DaCapo options file.") + return converter.structure(options, DaCapoConfig) diff --git a/dacapo/predict.py b/dacapo/predict.py index 4ce3f98bf..ee0dcaa2b 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,13 +1,11 @@ from pathlib import Path -import click from dacapo.blockwise import run_blockwise +import dacapo.blockwise from dacapo.experiments import Run -from dacapo.store.create_store import create_config_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.store.local_array_store import LocalArrayIdentifier -from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray -from dacapo.cli import cli from funlib.geometry import Coordinate, Roi import numpy as np @@ -24,45 +22,71 @@ def predict( iteration: int, input_container: Path | str, input_dataset: str, - output_path: Path | str, + output_path: LocalArrayIdentifier | str, output_roi: Optional[Roi | str] = None, - num_workers: int = 30, + num_workers: int = 12, output_dtype: np.dtype | str = np.uint8, # type: ignore - compute_context: ComputeContext | str = LocalTorch(), overwrite: bool = True, ): - """_summary_ + """Predict with a trained model. Args: - run_name (str): _description_ - iteration (int): _description_ - input_container (Path | str): _description_ - input_dataset (str): _description_ - output_path (Path | str): _description_ - output_roi (Optional[str], optional): Defaults to None. If output roi is None, - it will be set to the raw roi. - num_workers (int, optional): _description_. Defaults to 30. - output_dtype (np.dtype | str, optional): _description_. Defaults to np.uint8. - overwrite (bool, optional): _description_. Defaults to True. + run_name (str): The name of the run to predict with. + iteration (int): The training iteration of the model to use for prediction. + input_container (Path | str): The container of the input array. + input_dataset (str): The dataset name of the input array. + output_path (LocalArrayIdentifier | str): The path where the prediction array will be stored, or a LocalArryIdentifier for the prediction array. + output_roi (Optional[Roi | str], optional): The ROI of the output array. If None, the ROI of the input array will be used. Defaults to None. + num_workers (int, optional): The number of workers to use for blockwise prediction. Defaults to 30. + output_dtype (np.dtype | str, optional): The dtype of the output array. Defaults to np.uint8. + overwrite (bool, optional): If True, the output array will be overwritten if it already exists. Defaults to True. """ # retrieving run config_store = create_config_store() run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) + # check to see if we can load the weights + weights_store = create_weights_store() + try: + weights_store.retrieve_weights(run_name, iteration) + except FileNotFoundError: + raise ValueError( + f"No weights found for run {run_name} at iteration {iteration}." 
+ ) + # get arrays raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) - output_container = Path( - output_path, - "".join(Path(input_container).name.split(".")[:-1]) + ".zarr", - ) # TODO: zarr hardcoded - prediction_array_identifier = LocalArrayIdentifier( - output_container, f"prediction_{run_name}_{iteration}" - ) + if isinstance(output_path, LocalArrayIdentifier): + prediction_array_identifier = output_path + else: + if ".zarr" in str(output_path) or ".n5" in str(output_path): + output_container = Path(output_path) + else: + output_container = Path( + output_path, + Path(input_container).stem + ".zarr", + ) # TODO: zarr hardcoded + prediction_array_identifier = LocalArrayIdentifier( + output_container, f"prediction_{run_name}_{iteration}" + ) + + # get the model's input and output size + model = run.model.eval() + + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + # calculate input and output rois + + context = (input_size - output_size) // 2 if output_roi is None: - output_roi = raw_array.roi + output_roi = raw_array.roi.grow(-context, -context) elif isinstance(output_roi, str): start, end = zip( *[ @@ -76,30 +100,16 @@ def predict( ) output_roi = output_roi.snap_to_grid( raw_array.voxel_size, mode="grow" - ).intersect(raw_array.roi) + ).intersect(raw_array.roi.grow(-context, -context)) + _input_roi = output_roi.grow(context, context) if isinstance(output_dtype, str): output_dtype = np.dtype(output_dtype) - model = run.model.eval() - - # get the model's input and output size - - input_voxel_size = Coordinate(raw_array.voxel_size) - output_voxel_size = model.scale(input_voxel_size) - input_shape = Coordinate(model.eval_input_shape) - input_size = input_voxel_size * input_shape - output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] - logger.info( "Predicting with input size %s, output size %s", input_size, output_size ) - # calculate input and output rois - - context = (input_size - output_size) / 2 - _input_roi = output_roi.grow(context, context) - logger.info("Total input ROI: %s, output ROI: %s", _input_roi, output_roi) # prepare prediction dataset @@ -115,12 +125,13 @@ def predict( ) # run blockwise prediction + worker_file = str(Path(Path(dacapo.blockwise.__file__).parent, "predict_worker.py")) + logger.info("Running blockwise prediction with worker_file: ", worker_file) run_blockwise( - worker_file=str(Path(Path(__file__).parent, "blockwise", "predict_worker.py")), - compute_context=compute_context, + worker_file=worker_file, total_roi=_input_roi, read_roi=Roi((0, 0, 0), input_size), - write_roi=Roi((0, 0, 0), output_size), + write_roi=Roi(context, output_size), num_workers=num_workers, max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option diff --git a/dacapo/store/config_store.py b/dacapo/store/config_store.py index 8c91fd036..87fa6edb0 100644 --- a/dacapo/store/config_store.py +++ b/dacapo/store/config_store.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, TYPE_CHECKING +from typing import Any, List, TYPE_CHECKING if TYPE_CHECKING: from dacapo.experiments.run_config import RunConfig @@ -17,40 +17,13 @@ class DuplicateNameError(Exception): class ConfigStore(ABC): 
"""Base class for configuration stores.""" - @property - @abstractmethod - def runs(self): - pass - - @property - @abstractmethod - def datasplits(self): - pass - - @property - @abstractmethod - def datasets(self): - pass - - @property - @abstractmethod - def arrays(self): - pass - - @property - @abstractmethod - def tasks(self): - pass - - @property - @abstractmethod - def trainers(self): - pass - - @property - @abstractmethod - def architectures(self): - pass + runs: Any + datasplits: Any + datasets: Any + arrays: Any + tasks: Any + trainers: Any + architectures: Any @abstractmethod def delete_config(self, database, config_name: str) -> None: diff --git a/dacapo/store/create_store.py b/dacapo/store/create_store.py index 47e92626f..0fcc43ed2 100644 --- a/dacapo/store/create_store.py +++ b/dacapo/store/create_store.py @@ -14,19 +14,15 @@ def create_config_store(): options = Options.instance() - try: - store_type = options.type - except RuntimeError: - store_type = "files" - if store_type == "mongo": + if options.type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoConfigStore(db_host, db_name) - elif store_type == "files": + elif options.type == "files": store_path = Path(options.runs_base_dir).expanduser() return FileConfigStore(store_path / "configs") else: - raise ValueError(f"Unknown store type {store_type}") + raise ValueError(f"Unknown store type {options.type}") def create_stats_store(): @@ -34,17 +30,15 @@ def create_stats_store(): options = Options.instance() - try: - store_type = options.type - except RuntimeError: - store_type = "mongo" - if store_type == "mongo": + if options.type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoStatsStore(db_host, db_name) - elif store_type == "files": + elif options.type == "files": store_path = Path(options.runs_base_dir).expanduser() return FileStatsStore(store_path / "stats") + else: + raise ValueError(f"Unknown store type {options.type}") def create_weights_store(): diff --git a/dacapo/train.py b/dacapo/train.py index abf5ad48c..620f08e55 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,3 +1,4 @@ +from dacapo.compute_context import create_compute_context from dacapo.store.create_store import ( create_array_store, create_config_store, @@ -5,7 +6,6 @@ create_weights_store, ) from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.validate import validate_run import torch @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -def train(run_name: str, compute_context: ComputeContext = LocalTorch()): +def train(run_name: str): """Train a run""" # check config store to see if run is already being trained TODO @@ -34,13 +34,10 @@ def train(run_name: str, compute_context: ComputeContext = LocalTorch()): run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - return train_run(run, compute_context=compute_context) + return train_run(run) -def train_run( - run: Run, - compute_context: ComputeContext = LocalTorch(), -): +def train_run(run: Run): logger.info("Starting/resuming training for run %s...", run) # create run @@ -117,6 +114,7 @@ def train_run( # loading weights directly from a checkpoint into cuda # can allocate twice the memory of loading to cpu before # moving to cuda. 
+ compute_context = create_compute_context() run.model = run.model.to(compute_context.device) run.move_optimizer(compute_context.device) @@ -155,11 +153,20 @@ def train_run( trained_until = run.training_stats.trained_until() # If this is not a validation iteration or final iteration, skip validation + # also skip for test cases where total iterations is less than validation interval no_its = iteration_stats is None # No training steps run validation_it = ( iteration_stats.iteration + 1 ) % run.validation_interval == 0 final_it = trained_until >= run.train_until + if final_it and (trained_until < run.validation_interval): + # Special case for tests - skip validation, but store weights + stats_store.store_training_stats(run.name, run.training_stats) + weights_store.store_weights(run, iteration_stats.iteration + 1) + run.move_optimizer(compute_context.device) + run.model.train() + continue + if no_its or (not validation_it and not final_it): stats_store.store_training_stats(run.name, run.training_stats) continue @@ -175,7 +182,6 @@ def train_run( validate_run( run, iteration_stats.iteration + 1, - compute_context=compute_context, ) stats_store.store_validation_iteration_scores( run.name, run.validation_scores diff --git a/dacapo/validate.py b/dacapo/validate.py index 65fcb03d8..0b87bff91 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -1,5 +1,4 @@ from .predict import predict -from .compute_context import LocalTorch, ComputeContext from .experiments import Run, ValidationIterationScores from .experiments.datasplits.datasets.arrays import ZarrArray from .store.create_store import ( @@ -18,7 +17,6 @@ def validate( run_name: str, iteration: int, - compute_context: ComputeContext = LocalTorch(), num_workers: int = 30, output_dtype: str = "uint8", overwrite: bool = True, @@ -45,12 +43,11 @@ def validate( # create weights store and read weights weights_store = create_weights_store() - weights_store.retrieve_weights(run, iteration) + weights_store.retrieve_weights(run.name, iteration) return validate_run( run, iteration, - compute_context=compute_context, num_workers=num_workers, output_dtype=output_dtype, overwrite=overwrite, @@ -60,7 +57,6 @@ def validate( def validate_run( run: Run, iteration: int, - compute_context: ComputeContext = LocalTorch(), num_workers: int = 30, output_dtype: str = "uint8", overwrite: bool = True, @@ -151,7 +147,7 @@ def validate_run( logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset + run.name, iteration, validation_dataset.name ) logger.info("Predicting on dataset %s", validation_dataset.name) predict( @@ -159,11 +155,10 @@ def validate_run( iteration, input_container=input_raw_array_identifier.container, input_dataset=input_raw_array_identifier.dataset, - output_path=prediction_array_identifier.container, + output_path=prediction_array_identifier, output_roi=validation_dataset.gt.roi, num_workers=num_workers, output_dtype=output_dtype, - compute_context=compute_context, overwrite=overwrite, ) @@ -175,7 +170,7 @@ def validate_run( for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration, parameters, validation_dataset + run.name, iteration, str(parameters), validation_dataset.name ) post_processed_array = post_processor.process( @@ -211,10 +206,11 @@ def validate_run( "iteration": iteration, criterion: getattr(scores, criterion), "parameters_id": parameters.id, + 
"parameters": str(parameters), } ) weights_store.store_best( - run, iteration, validation_dataset.name, criterion + run.name, iteration, validation_dataset.name, criterion ) # delete current output. We only keep the best outputs as determined by diff --git a/pyproject.toml b/pyproject.toml index 325a77b35..a3f0b0015 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ dependencies = [ "attrs", "bokeh", "numpy-indexed>=0.3.7", - "daisy>=1.0", + #"daisy>=1.0", + "daisy @ git+https://github.com/funkelab/daisy.git", "funlib.math>=0.1", "funlib.geometry>=0.2", "mwatershed>=0.1", @@ -83,9 +84,10 @@ examples = [ "ipykernel", "jupyter", ] +all = ["dacapo-ml[test,dev,docs,examples]"] [project.urls] -homepage = "https://github.com/janelia-cellmap/dacapo" +homepage = "https://github.io/janelia-cellmap/dacapo" repository = "https://github.com/janelia-cellmap/dacapo" # https://hatch.pypa.io/latest/config/metadata/ @@ -139,7 +141,7 @@ filterwarnings = [ "error", "ignore::DeprecationWarning", ] - + # https://mypy.readthedocs.io/en/stable/config_file.html [tool.mypy] files = "dacapo/**/" diff --git a/tests/components/test_options.py b/tests/components/test_options.py new file mode 100644 index 000000000..7ac7e1488 --- /dev/null +++ b/tests/components/test_options.py @@ -0,0 +1,64 @@ +from dacapo import Options + + +from pathlib import Path +import textwrap + + +def test_no_config(): + # Make sure the config file does not exist + config_file = Path("dacapo.yaml") + if config_file.exists(): + config_file.unlink() + + # Parse the options + options = Options.instance() + + # Check the options + assert isinstance(options.runs_base_dir, Path) + assert options.mongo_db_host is None + assert options.mongo_db_name is None + + # Parse the options + options = Options.instance( + runs_base_dir="tmp", mongo_db_host="localhost", mongo_db_name="dacapo" + ) + + # Check the options + assert options.runs_base_dir == Path("tmp") + assert options.mongo_db_host == "localhost" + assert options.mongo_db_name == "dacapo" + + +# we need to change the working directory because +# dacapo looks for the config file in the working directory +def test_local_config_file(): + # Create a config file + config_file = Path("dacapo.yaml") + config_file.write_text( + textwrap.dedent( + """ + runs_base_dir: /tmp + mongo_db_host: localhost + mongo_db_name: dacapo + """ + ) + ) + + # Parse the options + options = Options.instance() + + # Check the options + assert options.runs_base_dir == Path("/tmp") + assert options.mongo_db_host == "localhost" + assert options.mongo_db_name == "dacapo" + assert Options.config_file() == config_file + + # Parse the options + options = Options.instance(runs_base_dir="/tmp2") + + # Check the options + assert options.runs_base_dir == Path("/tmp2") + assert options.mongo_db_host == "localhost" + assert options.mongo_db_name == "dacapo" + assert Options.config_file() == config_file diff --git a/tests/fixtures/arrays.py b/tests/fixtures/arrays.py index c2f45483e..8af4e90f2 100644 --- a/tests/fixtures/arrays.py +++ b/tests/fixtures/arrays.py @@ -22,13 +22,13 @@ def zarr_array(tmp_path): file_name=tmp_path / "zarr_array.zarr", dataset="volumes/test", ) - zarr_container = zarr.open(str(tmp_path / "zarr_array.zarr")) + zarr_container = zarr.open(str(zarr_array_config.file_name)) dataset = zarr_container.create_dataset( - "volumes/test", data=np.zeros((100, 50, 25)) + zarr_array_config.dataset, data=np.zeros((100, 50, 25), dtype=np.float32) ) dataset.attrs["offset"] = (12, 12, 12) dataset.attrs["resolution"] 
= (1, 2, 4) - dataset.attrs["axes"] = "zyx" + dataset.attrs["axes"] = ["zyx"] yield zarr_array_config @@ -39,9 +39,10 @@ def cellmap_array(tmp_path): file_name=tmp_path / "zarr_array.zarr", dataset="volumes/test", ) - zarr_container = zarr.open(str(tmp_path / "zarr_array.zarr")) + zarr_container = zarr.open(str(zarr_array_config.file_name)) dataset = zarr_container.create_dataset( - "volumes/test", data=np.arange(0, 100).reshape(10, 5, 2) + zarr_array_config.dataset, + data=np.arange(0, 100, dtype=np.uint8).reshape(10, 5, 2), ) dataset.attrs["offset"] = (12, 12, 12) dataset.attrs["resolution"] = (1, 2, 4) diff --git a/tests/fixtures/datasplits.py b/tests/fixtures/datasplits.py index 5af5a4036..7bb5672c6 100644 --- a/tests/fixtures/datasplits.py +++ b/tests/fixtures/datasplits.py @@ -67,7 +67,7 @@ def twelve_class_datasplit(tmp_path): gt_dataset = twelve_class_zarr.create_dataset( gt.dataset, shape=(40, 20, 20), dtype=np.uint8 ) - random_data = np.random.randn(40, 20, 20) + random_data = np.random.rand(40, 20, 20) # as intensities increase so does the class for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]: gt_dataset[:] += random_data > i @@ -178,7 +178,7 @@ def six_class_datasplit(tmp_path): gt_dataset = twelve_class_zarr.create_dataset( gt.dataset, shape=(40, 20, 20), dtype=np.uint8 ) - random_data = np.random.randn(40, 20, 20) + random_data = np.random.rand(40, 20, 20) # as intensities increase so does the class for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]: gt_dataset[:] += random_data > i diff --git a/tests/fixtures/db.py b/tests/fixtures/db.py index 533950496..9429f5507 100644 --- a/tests/fixtures/db.py +++ b/tests/fixtures/db.py @@ -3,18 +3,18 @@ import pymongo import pytest +import os +from pathlib import Path +import yaml + def mongo_db_available(): - try: - options = Options.instance() - client = pymongo.MongoClient( - host=options.mongo_db_host, serverSelectionTimeoutMS=1000 - ) - Options._instance = None - except RuntimeError: - # cannot find a dacapo config file, mongodb is not available - Options._instance = None - return False + options = Options.instance() + client = pymongo.MongoClient( + host=options.mongo_db_host, + serverSelectionTimeoutMS=1000, + socketTimeoutMS=1000, + ) try: client.admin.command("ping") return True @@ -34,23 +34,31 @@ def mongo_db_available(): ) ) def options(request, tmp_path): - # TODO: Clean up this fixture. Its a bit clunky to use. - # Maybe just write the dacapo.yaml file instead of assigning to Options._instance - kwargs_from_file = {} - if request.param == "mongo": - options_from_file = Options.instance() - kwargs_from_file.update( - { - "mongo_db_host": options_from_file.mongo_db_host, - "mongo_db_name": "dacapo_tests", - } - ) - Options._instance = None + # read the options from the users config file locally options = Options.instance( - type=request.param, runs_base_dir=f"{tmp_path}", **kwargs_from_file + type=request.param, runs_base_dir="tests", mongo_db_name="dacapo_tests" ) + + # change to a temporary directory for this test only + old_dir = os.getcwd() + os.chdir(tmp_path) + + # write the dacapo config in the current temporary directory. 
Now options + # will be read from this file instead of the users config file letting + # us test different configurations + config_file = Path("dacapo.yaml") + with open(config_file, "w") as f: + yaml.safe_dump(options.serialize(), f) + # config_file.write_text(options.serialize() + # ) + + # yield the options yield options + + # cleanup if request.param == "mongo": client = pymongo.MongoClient(host=options.mongo_db_host) client.drop_database("dacapo_tests") - Options._instance = None + + # reset working directory + os.chdir(old_dir) diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index 53ca30b7f..64390d0c3 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -1,7 +1,6 @@ from ..fixtures import * from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import apply @@ -21,14 +20,7 @@ lazy_fixture("onehot_run"), ], ) -def test_apply( - options, - run_config, -): - # TODO: test the apply function - return # remove this line to run the test - compute_context = LocalTorch(device="cpu") - +def test_apply(options, run_config, zarr_array, tmp_path): # create a store store = create_config_store() @@ -44,13 +36,38 @@ def test_apply( # ------------------------------------- # apply + parameters = list(run.task.post_processor.enumerate_parameters())[0] # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - apply(run_config.name, 0, compute_context=compute_context) + apply( + run_config.name, + zarr_array.file_name, + zarr_array.dataset, + output_path=tmp_path, + iteration=0, + parameters=parameters, + num_workers=4, + ) weights_store.store_weights(run, 1) - apply(run_config.name, 1, compute_context=compute_context) + apply( + run_config.name, + zarr_array.file_name, + zarr_array.dataset, + output_path=tmp_path, + iteration=1, + parameters=parameters, + num_workers=4, + ) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - apply(run_config.name, 2, compute_context=compute_context) + apply( + run_config.name, + zarr_array.file_name, + zarr_array.dataset, + output_path=tmp_path, + iteration=2, + parameters=parameters, + num_workers=4, + ) diff --git a/tests/operations/test_predict.py b/tests/operations/test_predict.py new file mode 100644 index 000000000..f537c8fd9 --- /dev/null +++ b/tests/operations/test_predict.py @@ -0,0 +1,70 @@ +from ..fixtures import * + +from dacapo.experiments import Run +from dacapo.store.create_store import create_config_store, create_weights_store +from dacapo import predict + +import pytest +from pytest_lazyfixture import lazy_fixture + +import logging + +logging.basicConfig(level=logging.INFO) + + +@pytest.mark.parametrize( + "run_config", + [ + lazy_fixture("distance_run"), + lazy_fixture("dummy_run"), + lazy_fixture("onehot_run"), + ], +) +def test_predict(options, run_config, zarr_array, tmp_path): + # os.environ["PYDEVD_UNBLOCK_THREADS_TIMEOUT"] = "2.0" + + # create a store + + store = create_config_store() + weights_store = create_weights_store() + + # store the configs + + store.store_run_config(run_config) + + run_config = store.retrieve_run_config(run_config.name) + run = Run(run_config) + + # ------------------------------------- + + # predict + # test predicting with iterations for which we know there are weights + weights_store.store_weights(run, 0) + predict( + run_config.name, + iteration=0, + 
input_container=zarr_array.file_name, + input_dataset=zarr_array.dataset, + output_path=tmp_path, + num_workers=4, + ) + weights_store.store_weights(run, 1) + predict( + run_config.name, + iteration=1, + input_container=zarr_array.file_name, + input_dataset=zarr_array.dataset, + output_path=tmp_path, + num_workers=4, + ) + + # test predicting with iterations for which we know there are no weights + with pytest.raises(ValueError): + predict( + run_config.name, + iteration=2, + input_container=zarr_array.file_name, + input_dataset=zarr_array.dataset, + output_path=tmp_path, + num_workers=4, + ) diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index ebcd4f1ad..dcfcd12c0 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -2,7 +2,6 @@ from ..fixtures import * from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.train import train_run @@ -29,8 +28,6 @@ def test_train( options, run_config, ): - compute_context = LocalTorch(device="cpu") - # create a store store = create_config_store() @@ -47,7 +44,7 @@ def test_train( # train weights_store.store_weights(run, 0) - train_run(run, compute_context=compute_context) + train_run(run) init_weights = weights_store.retrieve_weights(run.name, 0) final_weights = weights_store.retrieve_weights(run.name, run.train_until) diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index 54d6dc5e4..7489da40d 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -1,7 +1,6 @@ from ..fixtures import * from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import validate @@ -25,8 +24,6 @@ def test_validate( options, run_config, ): - compute_context = LocalTorch(device="cpu") - # create a store store = create_config_store() @@ -45,10 +42,10 @@ def test_validate( # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - validate(run_config.name, 0, compute_context=compute_context) + validate(run_config.name, 0, num_workers=4) weights_store.store_weights(run, 1) - validate(run_config.name, 1, compute_context=compute_context) + validate(run_config.name, 1, num_workers=4) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - validate(run_config.name, 2, compute_context=compute_context) + validate(run_config.name, 2, num_workers=4)
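
Note (not part of the diff above): after this change the compute context is no longer passed as a function argument; workers and entry points call create_compute_context(), which reads the "compute_context_config" entry of the global DaCapo options (see DaCapoConfig and create_compute_context in the diff). The following is a minimal, illustrative sketch of that flow; the concrete paths and values are placeholders, not taken from the PR.

    # Minimal sketch, assuming a dacapo.yaml in the working directory
    # (Options.config_file() checks ./dacapo.yaml, then ~/.config/dacapo/dacapo.yaml).
    import yaml
    from pathlib import Path

    from dacapo.compute_context import create_compute_context

    # Write an illustrative options file; field names follow DaCapoConfig,
    # values here are placeholders.
    Path("dacapo.yaml").write_text(
        yaml.safe_dump(
            {
                "type": "files",                        # config/stats store backend
                "runs_base_dir": "/tmp/dacapo_runs",    # illustrative path
                "compute_context_config": {
                    "type": "LocalTorch",               # name of a ComputeContext subclass
                    "config": {"device": "cpu"},        # kwargs for that subclass
                },
            }
        )
    )

    # Resolves the options file and instantiates the configured context,
    # here LocalTorch(device="cpu").
    compute_context = create_compute_context()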