From b13902f2d42fec1903fea7f65b7fd6ed3c5c6882 Mon Sep 17 00:00:00 2001 From: William Patton Date: Thu, 15 Feb 2024 16:57:48 -0800 Subject: [PATCH 01/19] overhaul options --- dacapo/options.py | 92 ++++++++++++++++++++------------ dacapo/store/create_store.py | 20 +++---- tests/components/test_options.py | 59 ++++++++++++++++++++ tests/fixtures/db.py | 64 +++++++++++++--------- 4 files changed, 164 insertions(+), 71 deletions(-) create mode 100644 tests/components/test_options.py diff --git a/dacapo/options.py b/dacapo/options.py index cea11b38b..7d41240c1 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -3,8 +3,44 @@ from os.path import expanduser from pathlib import Path +import attr +from cattr import Converter + +from typing import Optional + + logger = logging.getLogger(__name__) + +@attr.s +class DaCapoConfig: + type: str = attr.ib( + default="files", + metadata={ + "help_text": "The type of store to use for storing configurations and statistics. " + "Currently, only 'files' and 'mongo' are supported with files being the default." + }, + ) + runs_base_dir: Path = attr.ib( + default=Path(expanduser("~/.dacapo")), + metadata={ + "help_text": "The path at DaCapo will use for reading and writing any necessary data." + }, + ) + mongo_db_host: Optional[str] = attr.ib( + default=None, + metadata={ + "help_text": "The host of the MongoDB instance to use for storing configurations and statistics." + }, + ) + mongo_db_name: Optional[str] = attr.ib( + default=None, + metadata={ + "help_text": "The name of the MongoDB database to use for storing configurations and statistics." + }, + ) + + # options files in order of precedence (highest first) options_files = [ Path("./dacapo.yaml"), @@ -22,48 +58,36 @@ def parse_options(): class Options: - _instance = None def __init__(self): raise RuntimeError("Singleton: Use Options.instance()") @classmethod - def instance(cls, **kwargs): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.__parse_options(**kwargs) - - return cls._instance - - def __getattr__(self, name): - try: - return self.__options[name] - except KeyError: - raise RuntimeError( - f"Configuration file {self.filename} does not contain an " - f"entry for option {name}" - ) - - def __parse_options(self, **kwargs): - if len(kwargs) > 0: - self.__options = kwargs - self.filename = "kwargs" - return + def instance(cls, **kwargs) -> DaCapoConfig: + config = cls.__parse_options(**kwargs) + return config + + @classmethod + def config_file(cls) -> Optional[Path]: for path in options_files: - if not path.exists(): - continue + if path.exists(): + return path + return None - with path.open("r") as f: - self.__options = yaml.safe_load(f) - self.filename = path + @classmethod + def __parse_options_from_file(cls): + if (config_file := cls.config_file()) is not None: + with config_file.open("r") as f: + return yaml.safe_load(f) + else: + return {} - return + @classmethod + def __parse_options(cls, **kwargs): + options = cls.__parse_options_from_file() + options.update(kwargs) - logger.error( - "No options file found. 
Please create any of the following " "files:" - ) - for path in options_files: - logger.error("\t%s", path.absolute()) + converter = Converter() - raise RuntimeError("Could not find a DaCapo options file.") + return converter.structure(options, DaCapoConfig) diff --git a/dacapo/store/create_store.py b/dacapo/store/create_store.py index 47e92626f..0fcc43ed2 100644 --- a/dacapo/store/create_store.py +++ b/dacapo/store/create_store.py @@ -14,19 +14,15 @@ def create_config_store(): options = Options.instance() - try: - store_type = options.type - except RuntimeError: - store_type = "files" - if store_type == "mongo": + if options.type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoConfigStore(db_host, db_name) - elif store_type == "files": + elif options.type == "files": store_path = Path(options.runs_base_dir).expanduser() return FileConfigStore(store_path / "configs") else: - raise ValueError(f"Unknown store type {store_type}") + raise ValueError(f"Unknown store type {options.type}") def create_stats_store(): @@ -34,17 +30,15 @@ def create_stats_store(): options = Options.instance() - try: - store_type = options.type - except RuntimeError: - store_type = "mongo" - if store_type == "mongo": + if options.type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoStatsStore(db_host, db_name) - elif store_type == "files": + elif options.type == "files": store_path = Path(options.runs_base_dir).expanduser() return FileStatsStore(store_path / "stats") + else: + raise ValueError(f"Unknown store type {options.type}") def create_weights_store(): diff --git a/tests/components/test_options.py b/tests/components/test_options.py new file mode 100644 index 000000000..0ca984e1f --- /dev/null +++ b/tests/components/test_options.py @@ -0,0 +1,59 @@ +from dacapo import Options + + +from pathlib import Path +import textwrap + + +def test_no_config(): + # Parse the options + options = Options.instance() + + # Check the options + assert isinstance(options.runs_base_dir, Path) + assert options.mongo_db_host is None + assert options.mongo_db_name is None + + # Parse the options + options = Options.instance( + runs_base_dir="tmp", mongo_db_host="localhost", mongo_db_name="dacapo" + ) + + # Check the options + assert options.runs_base_dir == Path("tmp") + assert options.mongo_db_host == "localhost" + assert options.mongo_db_name == "dacapo" + + +# we need to change the working directory because +# dacapo looks for the config file in the working directory +def test_local_config_file(change_working_directory): + # Create a config file + config_file = Path("dacapo.yaml") + config_file.write_text( + textwrap.dedent( + """ + runs_base_dir: /tmp + mongo_db_host: localhost + mongo_db_name: dacapo + """ + ) + ) + + # Parse the options + options = Options.instance() + + # Check the options + assert options.runs_base_dir == Path("/tmp") + assert options.mongo_db_host == "localhost" + assert options.mongo_db_name == "dacapo" + assert Options.config_file() == config_file + + # Parse the options + options = Options.instance(runs_base_dir="/tmp2") + + # Check the options + assert options.runs_base_dir == Path("/tmp2") + assert options.mongo_db_host == "localhost" + assert options.mongo_db_name == "dacapo" + assert Options.config_file() == config_file diff --git a/tests/fixtures/db.py b/tests/fixtures/db.py index 533950496..9717b8953 100644 --- a/tests/fixtures/db.py +++ b/tests/fixtures/db.py @@ -2,19 +2,20 @@ import pymongo import pytest +import attr + 
+import os +from pathlib import Path +import yaml def mongo_db_available(): - try: - options = Options.instance() - client = pymongo.MongoClient( - host=options.mongo_db_host, serverSelectionTimeoutMS=1000 - ) - Options._instance = None - except RuntimeError: - # cannot find a dacapo config file, mongodb is not available - Options._instance = None - return False + options = Options.instance() + client = pymongo.MongoClient( + host=options.mongo_db_host, + serverSelectionTimeoutMS=1000, + socketTimeoutMS=1000, + ) try: client.admin.command("ping") return True @@ -34,23 +35,38 @@ def mongo_db_available(): ) ) def options(request, tmp_path): - # TODO: Clean up this fixture. Its a bit clunky to use. - # Maybe just write the dacapo.yaml file instead of assigning to Options._instance - kwargs_from_file = {} - if request.param == "mongo": - options_from_file = Options.instance() - kwargs_from_file.update( - { - "mongo_db_host": options_from_file.mongo_db_host, - "mongo_db_name": "dacapo_tests", - } - ) - Options._instance = None + + # read the options from the users config file locally options = Options.instance( - type=request.param, runs_base_dir=f"{tmp_path}", **kwargs_from_file + type=request.param, runs_base_dir="tests", mongo_db_name="dacapo_tests" ) + + # change to a temporary directory for this test only + old_dir = os.getcwd() + os.chdir(tmp_path) + + # write the dacapo config in the current temporary directory. Now options + # will be read from this file instead of the users config file letting + # us test different configurations + config_file = Path("dacapo.yaml") + config_file.write_text( + yaml.safe_dump( + attr.asdict( + options, + value_serializer=lambda inst, field, value: ( + str(value) if value is not None else None + ), + ) + ) + ) + + # yield the options yield options + + # cleanup if request.param == "mongo": client = pymongo.MongoClient(host=options.mongo_db_host) client.drop_database("dacapo_tests") - Options._instance = None + + # reset working directory + os.chdir(old_dir) From ab58caa73188e8a8dcd534e73fb0426b00bc19e3 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 16 Feb 2024 13:20:12 -0500 Subject: [PATCH 02/19] =?UTF-8?q?fix:=20=E2=9A=A1=EF=B8=8F=20Pass=20comput?= =?UTF-8?q?e=20context=20as=20part=20of=20options,=20and=20pass=20options?= =?UTF-8?q?=20file=20path=20as=20environment=20variable.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/apply.py | 5 ---- dacapo/blockwise/argmax_worker.py | 6 ++-- dacapo/blockwise/blockwise_task.py | 9 +----- dacapo/blockwise/predict_worker.py | 6 ++-- dacapo/blockwise/relabel_worker.py | 6 ++-- dacapo/blockwise/scheduler.py | 10 ------- dacapo/blockwise/segment_worker.py | 6 ++-- dacapo/blockwise/threshold_worker.py | 6 ++-- dacapo/cli.py | 19 ++---------- dacapo/compute_context/__init__.py | 2 +- dacapo/compute_context/compute_context.py | 15 ++++++++++ .../post_processors/argmax_post_processor.py | 3 -- .../tasks/post_processors/post_processor.py | 2 -- .../threshold_post_processor.py | 3 -- .../watershed_post_processor.py | 4 --- dacapo/options.py | 29 +++++++++++++------ dacapo/predict.py | 5 ---- dacapo/store/create_store.py | 12 ++++---- dacapo/train.py | 13 ++++----- dacapo/validate.py | 5 ---- tests/operations/test_apply.py | 10 ++----- tests/operations/test_train.py | 4 +-- tests/operations/test_validate.py | 8 ++--- 23 files changed, 73 insertions(+), 115 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 3d1c78974..f12f38091 100644 --- 
a/dacapo/apply.py +++ b/dacapo/apply.py @@ -12,7 +12,6 @@ import dacapo.experiments.tasks.post_processors as post_processors from dacapo.store.array_store import LocalArrayIdentifier from dacapo.predict import predict -from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.create_store import ( create_config_store, @@ -36,7 +35,6 @@ def apply( roi: Optional[Roi | str] = None, num_workers: int = 30, output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore - compute_context: ComputeContext = LocalTorch(), overwrite: bool = True, file_format: str = "zarr", ): @@ -169,7 +167,6 @@ def apply( roi, num_workers, output_dtype, - compute_context, overwrite, ) @@ -184,7 +181,6 @@ def apply_run( roi: Optional[Roi] = None, num_workers: int = 30, output_dtype: Optional[np.dtype] = np.uint8, # type: ignore - compute_context: ComputeContext = LocalTorch(), overwrite: bool = True, ): """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.""" @@ -200,7 +196,6 @@ def apply_run( output_roi=roi, num_workers=num_workers, output_dtype=output_dtype, - compute_context=compute_context, overwrite=overwrite, ) diff --git a/dacapo/blockwise/argmax_worker.py b/dacapo/blockwise/argmax_worker.py index ac6ad044e..1398d9629 100644 --- a/dacapo/blockwise/argmax_worker.py +++ b/dacapo/blockwise/argmax_worker.py @@ -1,7 +1,7 @@ from pathlib import Path from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context import daisy @@ -74,7 +74,6 @@ def start_worker( def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -82,8 +81,9 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). 
""" + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", diff --git a/dacapo/blockwise/blockwise_task.py b/dacapo/blockwise/blockwise_task.py index 3b8bf9f9d..c2296515e 100644 --- a/dacapo/blockwise/blockwise_task.py +++ b/dacapo/blockwise/blockwise_task.py @@ -2,15 +2,12 @@ from importlib.machinery import SourceFileLoader from pathlib import Path from daisy import Task, Roi -from dacapo.compute_context import ComputeContext -import dacapo.compute_context class DaCapoBlockwiseTask(Task): def __init__( self, worker_file: str | Path, - compute_context: ComputeContext | str, total_roi: Roi, read_roi: Roi, write_roi: Roi, @@ -21,8 +18,6 @@ def __init__( *args, **kwargs, ): - if isinstance(compute_context, str): - compute_context = getattr(dacapo.compute_context, compute_context)() # Load worker functions worker_name = Path(worker_file).stem @@ -32,9 +27,7 @@ def __init__( timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") task_id = worker_name + timestamp - process_function = worker.spawn_worker( - *args, **kwargs, compute_context=compute_context - ) + process_function = worker.spawn_worker(*args, **kwargs) if hasattr(worker, "check_function"): check_function = worker.check_function else: diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 40856f191..41ee29c09 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -6,7 +6,7 @@ from dacapo.store.array_store import LocalArrayIdentifier from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.experiments import Run -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context import gunpowder as gp import gunpowder.torch as gp_torch @@ -177,7 +177,6 @@ def spawn_worker( iteration: int, raw_array_identifier: "LocalArrayIdentifier", prediction_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -185,8 +184,9 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). """ + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", diff --git a/dacapo/blockwise/relabel_worker.py b/dacapo/blockwise/relabel_worker.py index dc45fb53c..4ee2dd9af 100644 --- a/dacapo/blockwise/relabel_worker.py +++ b/dacapo/blockwise/relabel_worker.py @@ -1,7 +1,7 @@ from glob import glob import os import daisy -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context from dacapo.store.array_store import LocalArrayIdentifier from scipy.cluster.hierarchy import DisjointSet from funlib.persistence import open_ds @@ -88,7 +88,6 @@ def read_cross_block_merges(tmpdir): def spawn_worker( output_array_identifier: LocalArrayIdentifier, tmpdir: str, - compute_context: ComputeContext = LocalTorch(), *args, **kwargs, ): @@ -97,8 +96,9 @@ def spawn_worker( Args: output_array_identifier (LocalArrayIdentifier): The output array identifier tmpdir (str): The temporary directory - compute_context (ComputeContext, optional): The compute context. Defaults to LocalTorch(). 
""" + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index 675ca52fe..777b2c2bd 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -5,13 +5,11 @@ from funlib.geometry import Roi, Coordinate import yaml -from dacapo.compute_context import ComputeContext from dacapo.blockwise import DaCapoBlockwiseTask def run_blockwise( worker_file: str | Path, - compute_context: ComputeContext | str, total_roi: Roi, read_roi: Roi, write_roi: Roi, @@ -51,10 +49,6 @@ def run_blockwise( (either due to failed post check or application crashes or network failure) - compute_context (``ComputeContext``): - - The compute context to use for parallelization. - *args: Additional positional arguments to pass to ``worker_function``. @@ -72,7 +66,6 @@ def run_blockwise( # Make the task task = DaCapoBlockwiseTask( worker_file, - compute_context, total_roi, read_roi, write_roi, @@ -89,7 +82,6 @@ def run_blockwise( def segment_blockwise( segment_function_file: str or Path, - compute_context: ComputeContext | str, context: Coordinate, total_roi: Roi, read_roi: Roi, @@ -111,7 +103,6 @@ def segment_blockwise( # Make the task task = DaCapoBlockwiseTask( str(Path(Path(__file__).parent, "segment_worker.py")), - compute_context, total_roi.grow(context, context), read_roi, write_roi, @@ -136,7 +127,6 @@ def segment_blockwise( # Make the task task = DaCapoBlockwiseTask( str(Path(Path(__file__).parent, "relabel_worker.py")), - compute_context, total_roi, read_roi, write_roi, diff --git a/dacapo/blockwise/segment_worker.py b/dacapo/blockwise/segment_worker.py index bd15320d7..ebb57cb7d 100644 --- a/dacapo/blockwise/segment_worker.py +++ b/dacapo/blockwise/segment_worker.py @@ -8,7 +8,7 @@ import numpy as np import yaml -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier @@ -157,7 +157,6 @@ def spawn_worker( output_array_identifier: LocalArrayIdentifier, tmpdir: str, function_path: str, - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -165,8 +164,9 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). 
""" + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", diff --git a/dacapo/blockwise/threshold_worker.py b/dacapo/blockwise/threshold_worker.py index d8d645c2b..e848f3b66 100644 --- a/dacapo/blockwise/threshold_worker.py +++ b/dacapo/blockwise/threshold_worker.py @@ -1,7 +1,7 @@ from pathlib import Path from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier -from dacapo.compute_context import ComputeContext, LocalTorch +from dacapo.compute_context import create_compute_context import daisy @@ -76,7 +76,6 @@ def spawn_worker( input_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", threshold: float = 0.0, - compute_context: ComputeContext = LocalTorch(), ): """Spawn a worker to predict on a given dataset. @@ -84,8 +83,9 @@ def spawn_worker( model (Model): The model to use for prediction. raw_array (Array): The raw data to predict on. prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array. - compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch(). """ + compute_context = create_compute_context() + # Make the command for the worker to run command = [ "python", diff --git a/dacapo/cli.py b/dacapo/cli.py index 8c064aadc..cda08e40f 100644 --- a/dacapo/cli.py +++ b/dacapo/cli.py @@ -11,7 +11,6 @@ from dacapo.experiments.tasks.post_processors.post_processor_parameters import ( PostProcessorParameters, ) -from dacapo.compute_context import ComputeContext, LocalTorch @click.group() @@ -45,6 +44,9 @@ def train(run_name): type=int, help="The iteration at which to validate the run.", ) +@click.option("-w", "--num_workers", type=int, default=30) +@click.option("-dt", "--output_dtype", type=str, default="uint8") +@click.option("-ow", "--overwrite", is_flag=True) def validate(run_name, iteration): dacapo.validate(run_name, iteration) @@ -75,7 +77,6 @@ def validate(run_name, iteration): @click.option("-w", "--num_workers", type=int, default=30) @click.option("-dt", "--output_dtype", type=str, default="uint8") @click.option("-ow", "--overwrite", is_flag=True) -@click.option("-cc", "--compute_context", type=str, default="LocalTorch") def apply( run_name: str, input_container: Path | str, @@ -89,11 +90,7 @@ def apply( num_workers: int = 30, output_dtype: Optional[np.dtype | str] = "uint8", overwrite: bool = True, - compute_context: Optional[ComputeContext | str] = LocalTorch(), ): - if isinstance(compute_context, str): - compute_context = getattr(compute_context, compute_context)() - dacapo.apply( run_name, input_container, @@ -107,7 +104,6 @@ def apply( num_workers, output_dtype, overwrite=overwrite, - compute_context=compute_context, # type: ignore ) @@ -139,13 +135,6 @@ def apply( ) @click.option("-w", "--num_workers", type=int, default=30) @click.option("-dt", "--output_dtype", type=str, default="uint8") -@click.option( - "-cc", - "--compute_context", - type=str, - default="LocalTorch", - help="The compute context to use for prediction. 
Must be the name of a subclass of ComputeContext.", -) @click.option("-ow", "--overwrite", is_flag=True) def predict( run_name: str, @@ -156,7 +145,6 @@ def predict( output_roi: Optional[str | Roi] = None, num_workers: int = 30, output_dtype: np.dtype | str = np.uint8, # type: ignore - compute_context: ComputeContext | str = LocalTorch(), overwrite: bool = True, ): dacapo.predict( @@ -168,6 +156,5 @@ def predict( output_roi, num_workers, output_dtype, - compute_context, overwrite, ) diff --git a/dacapo/compute_context/__init__.py b/dacapo/compute_context/__init__.py index c1d859c50..e2cc14722 100644 --- a/dacapo/compute_context/__init__.py +++ b/dacapo/compute_context/__init__.py @@ -1,3 +1,3 @@ -from .compute_context import ComputeContext # noqa +from .compute_context import ComputeContext, create_compute_context # noqa from .local_torch import LocalTorch # noqa from .bsub import Bsub # noqa diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index 1cf660188..da56e854b 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import subprocess +from dacapo import Options, compute_context + class ComputeContext(ABC): @property @@ -19,3 +21,16 @@ def wrap_command(self, command): def execute(self, command): # A helper method to run a command in the context specific way. return subprocess.run(self.wrap_command(command)) + + +def create_compute_context(): + """Create a compute context based on the global DaCapo options.""" + + options = Options.instance() + + if hasattr(compute_context, options.compute_context_config["type"]): + return getattr(compute_context, options.compute_context_config["type"])( + **options.compute_context_config["config"] + ) + else: + raise ValueError(f"Unknown store type {options.db_type}") diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 02f8b1202..5fa6009ec 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -1,6 +1,5 @@ from pathlib import Path from dacapo.blockwise.scheduler import run_blockwise -from dacapo.compute_context import ComputeContext, LocalTorch from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier from .argmax_post_processor_parameters import ArgmaxPostProcessorParameters @@ -28,7 +27,6 @@ def process( self, parameters, output_array_identifier, - compute_context: ComputeContext | str = LocalTorch(), num_workers: int = 16, block_size: Coordinate = Coordinate((64, 64, 64)), ): @@ -47,7 +45,6 @@ def process( worker_file=str( Path(Path(__file__).parent, "blockwise", "predict_worker.py") ), - compute_context=compute_context, total_roi=self.prediction_array.roi, read_roi=read_roi, write_roi=read_roi, diff --git a/dacapo/experiments/tasks/post_processors/post_processor.py b/dacapo/experiments/tasks/post_processors/post_processor.py index 585063828..f0a991c51 100644 --- a/dacapo/experiments/tasks/post_processors/post_processor.py +++ b/dacapo/experiments/tasks/post_processors/post_processor.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from dacapo.compute_context import ComputeContext, LocalTorch from funlib.geometry import Coordinate from typing import Iterable, TYPE_CHECKING @@ -35,7 +34,6 @@ def process( self, parameters: 
"PostProcessorParameters", output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext | str = LocalTorch(), num_workers: int = 16, chunk_size: Coordinate = Coordinate((64, 64, 64)), ) -> "Array": diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index bbdc76aa1..712ef5baf 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -1,6 +1,5 @@ from pathlib import Path from dacapo.blockwise.scheduler import run_blockwise -from dacapo.compute_context import ComputeContext, LocalTorch from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters from .post_processor import PostProcessor @@ -34,7 +33,6 @@ def process( self, parameters: "ThresholdPostProcessorParameters", # type: ignore[override] output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext | str = LocalTorch(), num_workers: int = 16, block_size: Coordinate = Coordinate((64, 64, 64)), ) -> ZarrArray: @@ -64,7 +62,6 @@ def process( worker_file=str( Path(Path(__file__).parent, "blockwise", "predict_worker.py") ), - compute_context=compute_context, total_roi=self.prediction_array.roi, read_roi=read_roi, write_roi=read_roi, diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 64bec66e8..6fac51735 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -1,12 +1,10 @@ from pathlib import Path from dacapo.blockwise.scheduler import segment_blockwise -from dacapo.compute_context import ComputeContext, LocalTorch from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier from .watershed_post_processor_parameters import WatershedPostProcessorParameters from .post_processor import PostProcessor -from dacapo.compute_context import ComputeContext, LocalTorch from funlib.geometry import Coordinate, Roi @@ -36,7 +34,6 @@ def process( self, parameters: WatershedPostProcessorParameters, # type: ignore[override] output_array_identifier: "LocalArrayIdentifier", - compute_context: ComputeContext | str = LocalTorch(), num_workers: int = 16, block_size: Coordinate = Coordinate((64, 64, 64)), ): @@ -60,7 +57,6 @@ def process( segment_function_file=str( Path(Path(__file__).parent, "blockwise", "watershed_function.py") ), - compute_context=compute_context, context=parameters.context, total_roi=self.prediction_array.roi, read_roi=read_roi.grow(parameters.context, parameters.context), diff --git a/dacapo/options.py b/dacapo/options.py index 7d41240c1..0c3973b06 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -1,3 +1,4 @@ +import os import yaml import logging from os.path import expanduser @@ -14,7 +15,7 @@ @attr.s class DaCapoConfig: - type: str = attr.ib( + db_type: str = attr.ib( default="files", metadata={ "help_text": "The type of store to use for storing configurations and statistics. " @@ -27,6 +28,13 @@ class DaCapoConfig: "help_text": "The path at DaCapo will use for reading and writing any necessary data." 
}, ) + compute_context_config: dict = attr.ib( + default={"type": "LocalTorch", "config": {"_device": None}}, + metadata={ + "help_text": "The configuration for the compute context to use. " + "This is a dictionary with the keys being the names of the compute context and the values being the configuration for that context." + }, + ) mongo_db_host: Optional[str] = attr.ib( default=None, metadata={ @@ -41,13 +49,6 @@ class DaCapoConfig: ) -# options files in order of precedence (highest first) -options_files = [ - Path("./dacapo.yaml"), - Path(expanduser("~/.config/dacapo/dacapo.yaml")), -] - - def parse_options(): for path in options_files: if not path.exists(): @@ -58,7 +59,6 @@ def parse_options(): class Options: - def __init__(self): raise RuntimeError("Singleton: Use Options.instance()") @@ -70,6 +70,17 @@ def instance(cls, **kwargs) -> DaCapoConfig: @classmethod def config_file(cls) -> Optional[Path]: + env_dict = dict(os.environ) + if "OPTIONS_FILE" in env_dict: + options_files = [Path(env_dict["OPTIONS_FILE"])] + else: + options_files = [] + + # options files in order of precedence (highest first) + options_files += [ + Path("./dacapo.yaml"), + Path(expanduser("~/.config/dacapo/dacapo.yaml")), + ] for path in options_files: if path.exists(): return path diff --git a/dacapo/predict.py b/dacapo/predict.py index 4ce3f98bf..bfd46cb43 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,13 +1,10 @@ from pathlib import Path -import click from dacapo.blockwise import run_blockwise from dacapo.experiments import Run from dacapo.store.create_store import create_config_store from dacapo.store.local_array_store import LocalArrayIdentifier -from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.experiments.datasplits.datasets.arrays import ZarrArray -from dacapo.cli import cli from funlib.geometry import Coordinate, Roi import numpy as np @@ -28,7 +25,6 @@ def predict( output_roi: Optional[Roi | str] = None, num_workers: int = 30, output_dtype: np.dtype | str = np.uint8, # type: ignore - compute_context: ComputeContext | str = LocalTorch(), overwrite: bool = True, ): """_summary_ @@ -117,7 +113,6 @@ def predict( # run blockwise prediction run_blockwise( worker_file=str(Path(Path(__file__).parent, "blockwise", "predict_worker.py")), - compute_context=compute_context, total_roi=_input_roi, read_roi=Roi((0, 0, 0), input_size), write_roi=Roi((0, 0, 0), output_size), diff --git a/dacapo/store/create_store.py b/dacapo/store/create_store.py index 0fcc43ed2..b5442bbad 100644 --- a/dacapo/store/create_store.py +++ b/dacapo/store/create_store.py @@ -14,15 +14,15 @@ def create_config_store(): options = Options.instance() - if options.type == "mongo": + if options.db_type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoConfigStore(db_host, db_name) - elif options.type == "files": + elif options.db_type == "files": store_path = Path(options.runs_base_dir).expanduser() return FileConfigStore(store_path / "configs") else: - raise ValueError(f"Unknown store type {options.type}") + raise ValueError(f"Unknown store type {options.db_type}") def create_stats_store(): @@ -30,15 +30,15 @@ def create_stats_store(): options = Options.instance() - if options.type == "mongo": + if options.db_type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoStatsStore(db_host, db_name) - elif options.type == "files": + elif options.db_type == "files": store_path = Path(options.runs_base_dir).expanduser() return 
FileStatsStore(store_path / "stats") else: - raise ValueError(f"Unknown store type {options.type}") + raise ValueError(f"Unknown store type {options.db_type}") def create_weights_store(): diff --git a/dacapo/train.py b/dacapo/train.py index abf5ad48c..9ab270e5c 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -1,3 +1,4 @@ +from dacapo.compute_context import create_compute_context from dacapo.store.create_store import ( create_array_store, create_config_store, @@ -5,7 +6,6 @@ create_weights_store, ) from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch, ComputeContext from dacapo.validate import validate_run import torch @@ -16,7 +16,7 @@ logger = logging.getLogger(__name__) -def train(run_name: str, compute_context: ComputeContext = LocalTorch()): +def train(run_name: str): """Train a run""" # check config store to see if run is already being trained TODO @@ -34,13 +34,10 @@ def train(run_name: str, compute_context: ComputeContext = LocalTorch()): run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) - return train_run(run, compute_context=compute_context) + return train_run(run) -def train_run( - run: Run, - compute_context: ComputeContext = LocalTorch(), -): +def train_run(run: Run): logger.info("Starting/resuming training for run %s...", run) # create run @@ -117,6 +114,7 @@ def train_run( # loading weights directly from a checkpoint into cuda # can allocate twice the memory of loading to cpu before # moving to cuda. + compute_context = create_compute_context() run.model = run.model.to(compute_context.device) run.move_optimizer(compute_context.device) @@ -175,7 +173,6 @@ def train_run( validate_run( run, iteration_stats.iteration + 1, - compute_context=compute_context, ) stats_store.store_validation_iteration_scores( run.name, run.validation_scores diff --git a/dacapo/validate.py b/dacapo/validate.py index 65fcb03d8..74ffd9ddb 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -1,5 +1,4 @@ from .predict import predict -from .compute_context import LocalTorch, ComputeContext from .experiments import Run, ValidationIterationScores from .experiments.datasplits.datasets.arrays import ZarrArray from .store.create_store import ( @@ -18,7 +17,6 @@ def validate( run_name: str, iteration: int, - compute_context: ComputeContext = LocalTorch(), num_workers: int = 30, output_dtype: str = "uint8", overwrite: bool = True, @@ -50,7 +48,6 @@ def validate( return validate_run( run, iteration, - compute_context=compute_context, num_workers=num_workers, output_dtype=output_dtype, overwrite=overwrite, @@ -60,7 +57,6 @@ def validate( def validate_run( run: Run, iteration: int, - compute_context: ComputeContext = LocalTorch(), num_workers: int = 30, output_dtype: str = "uint8", overwrite: bool = True, @@ -163,7 +159,6 @@ def validate_run( output_roi=validation_dataset.gt.roi, num_workers=num_workers, output_dtype=output_dtype, - compute_context=compute_context, overwrite=overwrite, ) diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index 53ca30b7f..a59e41e91 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -1,7 +1,6 @@ from ..fixtures import * from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import apply @@ -25,9 +24,6 @@ def test_apply( options, run_config, ): - # TODO: test the apply function - return # remove this line to run the test - compute_context = 
LocalTorch(device="cpu") # create a store @@ -47,10 +43,10 @@ def test_apply( # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - apply(run_config.name, 0, compute_context=compute_context) + apply(run_config.name, 0) weights_store.store_weights(run, 1) - apply(run_config.name, 1, compute_context=compute_context) + apply(run_config.name, 1) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - apply(run_config.name, 2, compute_context=compute_context) + apply(run_config.name, 2) diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index ebcd4f1ad..6585971b6 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -2,7 +2,6 @@ from ..fixtures import * from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.train import train_run @@ -29,7 +28,6 @@ def test_train( options, run_config, ): - compute_context = LocalTorch(device="cpu") # create a store @@ -47,7 +45,7 @@ def test_train( # train weights_store.store_weights(run, 0) - train_run(run, compute_context=compute_context) + train_run(run) init_weights = weights_store.retrieve_weights(run.name, 0) final_weights = weights_store.retrieve_weights(run.name, run.train_until) diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index 54d6dc5e4..b214d64e3 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -1,7 +1,6 @@ from ..fixtures import * from dacapo.experiments import Run -from dacapo.compute_context import LocalTorch from dacapo.store.create_store import create_config_store, create_weights_store from dacapo import validate @@ -25,7 +24,6 @@ def test_validate( options, run_config, ): - compute_context = LocalTorch(device="cpu") # create a store @@ -45,10 +43,10 @@ def test_validate( # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - validate(run_config.name, 0, compute_context=compute_context) + validate(run_config.name, 0) weights_store.store_weights(run, 1) - validate(run_config.name, 1, compute_context=compute_context) + validate(run_config.name, 1) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - validate(run_config.name, 2, compute_context=compute_context) + validate(run_config.name, 2) From 9f1f84dce42403e663844caca5789f45c734ac70 Mon Sep 17 00:00:00 2001 From: Larissa Heinrich Date: Fri, 16 Feb 2024 15:33:40 -0500 Subject: [PATCH 03/19] fix: :poop: making options/config serializable --- dacapo/options.py | 6 +++++- tests/fixtures/db.py | 15 ++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/dacapo/options.py b/dacapo/options.py index 0c3973b06..5d6c7770d 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -29,7 +29,7 @@ class DaCapoConfig: }, ) compute_context_config: dict = attr.ib( - default={"type": "LocalTorch", "config": {"_device": None}}, + default={"type": "LocalTorch", "config": {"device": None}}, metadata={ "help_text": "The configuration for the compute context to use. " "This is a dictionary with the keys being the names of the compute context and the values being the configuration for that context." 
@@ -48,6 +48,10 @@ class DaCapoConfig: }, ) + def serialize(self): + converter = Converter() + return converter.unstructure(self) + def parse_options(): for path in options_files: diff --git a/tests/fixtures/db.py b/tests/fixtures/db.py index 9717b8953..985b0840f 100644 --- a/tests/fixtures/db.py +++ b/tests/fixtures/db.py @@ -2,7 +2,6 @@ import pymongo import pytest -import attr import os from pathlib import Path @@ -49,16 +48,10 @@ def options(request, tmp_path): # will be read from this file instead of the users config file letting # us test different configurations config_file = Path("dacapo.yaml") - config_file.write_text( - yaml.safe_dump( - attr.asdict( - options, - value_serializer=lambda inst, field, value: ( - str(value) if value is not None else None - ), - ) - ) - ) + with open(config_file, "w") as f: + yaml.safe_dump(options.serialize(), f) + # config_file.write_text(options.serialize() + # ) # yield the options yield options From 7a39d9d86386799116ccc074d4b60c4f7705e8de Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 16 Feb 2024 15:53:14 -0500 Subject: [PATCH 04/19] Remove unused code in options.py --- dacapo/options.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/dacapo/options.py b/dacapo/options.py index 5d6c7770d..41a6fda50 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -53,15 +53,6 @@ def serialize(self): return converter.unstructure(self) -def parse_options(): - for path in options_files: - if not path.exists(): - continue - - with path.open("r") as f: - return yaml.safe_load(f) - - class Options: def __init__(self): raise RuntimeError("Singleton: Use Options.instance()") From f5351611793de141137cdaa1a90ac47439bb66ee Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 16 Feb 2024 21:02:37 +0000 Subject: [PATCH 05/19] :art: Format Python code with psf/black --- dacapo/blockwise/blockwise_task.py | 1 - tests/fixtures/db.py | 1 - tests/operations/test_apply.py | 1 - tests/operations/test_train.py | 1 - tests/operations/test_validate.py | 1 - 5 files changed, 5 deletions(-) diff --git a/dacapo/blockwise/blockwise_task.py b/dacapo/blockwise/blockwise_task.py index c2296515e..54e1b7347 100644 --- a/dacapo/blockwise/blockwise_task.py +++ b/dacapo/blockwise/blockwise_task.py @@ -18,7 +18,6 @@ def __init__( *args, **kwargs, ): - # Load worker functions worker_name = Path(worker_file).stem worker = SourceFileLoader(worker_name, str(worker_file)).load_module() diff --git a/tests/fixtures/db.py b/tests/fixtures/db.py index 985b0840f..9429f5507 100644 --- a/tests/fixtures/db.py +++ b/tests/fixtures/db.py @@ -34,7 +34,6 @@ def mongo_db_available(): ) ) def options(request, tmp_path): - # read the options from the users config file locally options = Options.instance( type=request.param, runs_base_dir="tests", mongo_db_name="dacapo_tests" diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index a59e41e91..7bdf2bf39 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -24,7 +24,6 @@ def test_apply( options, run_config, ): - # create a store store = create_config_store() diff --git a/tests/operations/test_train.py b/tests/operations/test_train.py index 6585971b6..dcfcd12c0 100644 --- a/tests/operations/test_train.py +++ b/tests/operations/test_train.py @@ -28,7 +28,6 @@ def test_train( options, run_config, ): - # create a store store = create_config_store() diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index b214d64e3..df4f4774b 100644 --- 
a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -24,7 +24,6 @@ def test_validate( options, run_config, ): - # create a store store = create_config_store() From 89382b06960799d65c2666ac28cfe09ff9482075 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Fri, 16 Feb 2024 16:44:42 -0500 Subject: [PATCH 06/19] =?UTF-8?q?fix:=20=F0=9F=A7=AA=20Fix=20test=5Fapply?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/fixtures/post_processors.py | 1 + tests/operations/test_apply.py | 33 ++++++++++++++++++++++++------- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/fixtures/post_processors.py b/tests/fixtures/post_processors.py index 61d01e6d8..95f51abbf 100644 --- a/tests/fixtures/post_processors.py +++ b/tests/fixtures/post_processors.py @@ -1,6 +1,7 @@ from dacapo.experiments.tasks.post_processors import ( ArgmaxPostProcessor, ThresholdPostProcessor, + PostProcessorParameters, ) import pytest diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index 7bdf2bf39..8d5b2f781 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -20,10 +20,7 @@ lazy_fixture("onehot_run"), ], ) -def test_apply( - options, - run_config, -): +def test_apply(options, run_config, zarr_array, tmp_path): # create a store store = create_config_store() @@ -39,13 +36,35 @@ def test_apply( # ------------------------------------- # apply + parameters = list(run.task.post_processor.enumerate_parameters())[0] # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - apply(run_config.name, 0) + apply( + run_config.name, + zarr_array.file_name, + zarr_array.dataset, + output_path=tmp_path, + iteration=0, + parameters=parameters, + ) weights_store.store_weights(run, 1) - apply(run_config.name, 1) + apply( + run_config.name, + zarr_array.file_name, + zarr_array.dataset, + output_path=tmp_path, + iteration=1, + parameters=parameters, + ) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - apply(run_config.name, 2) + apply( + run_config.name, + zarr_array.file_name, + zarr_array.dataset, + output_path=tmp_path, + iteration=2, + parameters=parameters, + ) From 37fb9339d811ed2a7d09ae024085ae133c16c950 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Sun, 18 Feb 2024 23:18:57 -0500 Subject: [PATCH 07/19] =?UTF-8?q?fix:=20=F0=9F=A7=AA=20Working=20on=20test?= =?UTF-8?q?s.=20All=20"components"=20are=20passing.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/train.py | 5 ++++- tests/components/test_options.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/dacapo/train.py b/dacapo/train.py index 9ab270e5c..7802ded60 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -153,11 +153,14 @@ def train_run(run: Run): trained_until = run.training_stats.trained_until() # If this is not a validation iteration or final iteration, skip validation + # also skip for test cases where total iterations is less than validation interval no_its = iteration_stats is None # No training steps run validation_it = ( iteration_stats.iteration + 1 ) % run.validation_interval == 0 - final_it = trained_until >= run.train_until + final_it = (trained_until >= run.train_until) and ( + run.train_until >= run.validation_interval + ) if no_its or (not validation_it and not final_it): stats_store.store_training_stats(run.name, run.training_stats) continue diff --git 
a/tests/components/test_options.py b/tests/components/test_options.py index 0ca984e1f..d0f354953 100644 --- a/tests/components/test_options.py +++ b/tests/components/test_options.py @@ -27,7 +27,7 @@ def test_no_config(): # we need to change the working directory because # dacapo looks for the config file in the working directory -def test_local_config_file(change_working_directory): +def test_local_config_file(): # Create a config file config_file = Path("dacapo.yaml") config_file.write_text( From 3775d6d221166058ee35f0dda0c95e2a3c919412 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 19 Feb 2024 14:40:24 -0500 Subject: [PATCH 08/19] =?UTF-8?q?fix:=20=E2=9C=85=20Testing=20train=20pass?= =?UTF-8?q?es.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixed by 1) removing testing validation from tests of training, and 2) fixing the dummy trainer so that it moves arrays to the proper devices during training. --- dacapo/experiments/trainers/dummy_trainer.py | 29 ++++++++++++-------- dacapo/train.py | 12 ++++++-- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/dacapo/experiments/trainers/dummy_trainer.py b/dacapo/experiments/trainers/dummy_trainer.py index 85c7c1ee8..7e826979a 100644 --- a/dacapo/experiments/trainers/dummy_trainer.py +++ b/dacapo/experiments/trainers/dummy_trainer.py @@ -15,28 +15,35 @@ def __init__(self, trainer_config): self.mirror_augment = trainer_config.mirror_augment def create_optimizer(self, model): - return torch.optim.Adam(lr=self.learning_rate, params=model.parameters()) + return torch.optim.RAdam(lr=self.learning_rate, params=model.parameters()) def iterate(self, num_iterations: int, model: Model, optimizer, device): target_iteration = self.iteration + num_iterations - for self.iteration in range(self.iteration, target_iteration): + for iteration in range(self.iteration, target_iteration): optimizer.zero_grad() - raw = torch.from_numpy( - np.random.randn(1, model.num_in_channels, *model.input_shape) - ).float() - target = torch.from_numpy( - np.zeros((1, model.num_out_channels, *model.output_shape)) - ).float() + raw = ( + torch.from_numpy( + np.random.randn(1, model.num_in_channels, *model.input_shape) + ) + .float() + .to(device) + ) + target = ( + torch.from_numpy( + np.zeros((1, model.num_out_channels, *model.output_shape)) + ) + .float() + .to(device) + ) pred = model.forward(raw) loss = self._loss.compute(pred, target) loss.backward() optimizer.step() yield TrainingIterationStats( - loss=1.0 / (self.iteration + 1), iteration=self.iteration, time=0.1 + loss=1.0 / (iteration + 1), iteration=iteration, time=0.1 ) - - self.iteration += 1 + self.iteration += 1 def build_batch_provider(self, datasplit, architecture, task, snapshot_container): self._loss = task.loss diff --git a/dacapo/train.py b/dacapo/train.py index 7802ded60..620f08e55 100644 --- a/dacapo/train.py +++ b/dacapo/train.py @@ -158,9 +158,15 @@ def train_run(run: Run): validation_it = ( iteration_stats.iteration + 1 ) % run.validation_interval == 0 - final_it = (trained_until >= run.train_until) and ( - run.train_until >= run.validation_interval - ) + final_it = trained_until >= run.train_until + if final_it and (trained_until < run.validation_interval): + # Special case for tests - skip validation, but store weights + stats_store.store_training_stats(run.name, run.training_stats) + weights_store.store_weights(run, iteration_stats.iteration + 1) + run.move_optimizer(compute_context.device) + run.model.train() + continue + if no_its or 
(not validation_it and not final_it): stats_store.store_training_stats(run.name, run.training_stats) continue From e365495bfe30cfa3def23c7fb0177b9db117875a Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 19 Feb 2024 15:10:49 -0500 Subject: [PATCH 09/19] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fixed=20test=5Fopti?= =?UTF-8?q?ons=20for=20no=5Fconfig?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Made sure there wasn't any config file present to be read. --- tests/components/test_options.py | 5 +++ tests/operations/test_predict.py | 65 ++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 tests/operations/test_predict.py diff --git a/tests/components/test_options.py b/tests/components/test_options.py index d0f354953..7ac7e1488 100644 --- a/tests/components/test_options.py +++ b/tests/components/test_options.py @@ -6,6 +6,11 @@ def test_no_config(): + # Make sure the config file does not exist + config_file = Path("dacapo.yaml") + if config_file.exists(): + config_file.unlink() + # Parse the options options = Options.instance() diff --git a/tests/operations/test_predict.py b/tests/operations/test_predict.py new file mode 100644 index 000000000..8a095be45 --- /dev/null +++ b/tests/operations/test_predict.py @@ -0,0 +1,65 @@ +from ..fixtures import * + +from dacapo.experiments import Run +from dacapo.store.create_store import create_config_store, create_weights_store +from dacapo import predict + +import pytest +from pytest_lazyfixture import lazy_fixture + +import logging + +logging.basicConfig(level=logging.INFO) + + +@pytest.mark.parametrize( + "run_config", + [ + lazy_fixture("distance_run"), + lazy_fixture("dummy_run"), + lazy_fixture("onehot_run"), + ], +) +def test_predict(options, run_config, zarr_array, tmp_path): + # create a store + + store = create_config_store() + weights_store = create_weights_store() + + # store the configs + + store.store_run_config(run_config) + + run_config = store.retrieve_run_config(run_config.name) + run = Run(run_config) + + # ------------------------------------- + + # predict + # test predicting with iterations for which we know there are weights + weights_store.store_weights(run, 0) + predict( + run_config.name, + iteration=0, + input_container=zarr_array.file_name, + input_dataset=zarr_array.dataset, + output_path=tmp_path, + ) + weights_store.store_weights(run, 1) + predict( + run_config.name, + iteration=1, + input_container=zarr_array.file_name, + input_dataset=zarr_array.dataset, + output_path=tmp_path, + ) + + # test predicting with iterations for which we know there are no weights + with pytest.raises(ValueError): + predict( + run_config.name, + iteration=2, + input_container=zarr_array.file_name, + input_dataset=zarr_array.dataset, + output_path=tmp_path, + ) From 645be34d3ba182911cb4baf4e614105e3a716df6 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 19 Feb 2024 17:10:58 -0500 Subject: [PATCH 10/19] =?UTF-8?q?fix:=20=F0=9F=9A=A7=20Work=20to=20fix=20t?= =?UTF-8?q?est=5Fpredict.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/blockwise/argmax_worker.py | 3 ++- dacapo/blockwise/predict_worker.py | 19 ++++++++++++------- dacapo/blockwise/relabel_worker.py | 3 ++- dacapo/blockwise/segment_worker.py | 3 ++- dacapo/blockwise/threshold_worker.py | 3 ++- dacapo/predict.py | 4 +++- 6 files changed, 23 insertions(+), 12 deletions(-) diff --git a/dacapo/blockwise/argmax_worker.py b/dacapo/blockwise/argmax_worker.py 
index 1398d9629..e42dd0299 100644 --- a/dacapo/blockwise/argmax_worker.py +++ b/dacapo/blockwise/argmax_worker.py @@ -14,6 +14,7 @@ read_write_conflict: bool = False fit: str = "valid" +path = __file__ @click.group() @@ -87,7 +88,7 @@ def spawn_worker( # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--input_container", input_array_identifier.container, diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 41ee29c09..f98321e27 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -22,6 +22,7 @@ read_write_conflict: bool = False fit: str = "valid" +path = __file__ @click.group() @@ -112,7 +113,7 @@ def start_worker( # prepare data source pipeline = DaCapoArraySource(raw_array, raw) # raw: (c, d, h, w) - pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims)) + pipeline += gp.Pad(raw, None) # raw: (c, d, h, w) pipeline += gp.Unsqueeze([raw]) # raw: (1, c, d, h, w) @@ -155,18 +156,22 @@ def start_worker( if block is None: break - ref_request = gp.BatchRequest() - ref_request[raw] = gp.ArraySpec( - roi=block.read_roi, voxel_size=input_voxel_size, dtype=raw_array.dtype + request = gp.BatchRequest() + request[raw] = gp.ArraySpec( + roi=block.read_roi, + voxel_size=input_voxel_size, + dtype=raw_array.dtype, + interpolatable=True, ) - ref_request[prediction] = gp.ArraySpec( + request[prediction] = gp.ArraySpec( roi=block.write_roi, voxel_size=output_voxel_size, dtype=output_array.dtype, + interpolatable=True, ) with gp.build(pipeline): - batch = pipeline.request_batch(ref_request) + batch = pipeline.request_batch(request) # write to output array output_array[block.write_roi] = batch.arrays[prediction].data @@ -190,7 +195,7 @@ def spawn_worker( # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--run-name", run_name, diff --git a/dacapo/blockwise/relabel_worker.py b/dacapo/blockwise/relabel_worker.py index 4ee2dd9af..1b8580c28 100644 --- a/dacapo/blockwise/relabel_worker.py +++ b/dacapo/blockwise/relabel_worker.py @@ -27,6 +27,7 @@ def cli(log_level): fit = "shrink" read_write_conflict = False +path = __file__ @cli.command() @@ -102,7 +103,7 @@ def spawn_worker( # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--output_container", output_array_identifier.container, diff --git a/dacapo/blockwise/segment_worker.py b/dacapo/blockwise/segment_worker.py index ebb57cb7d..32c86cacb 100644 --- a/dacapo/blockwise/segment_worker.py +++ b/dacapo/blockwise/segment_worker.py @@ -28,6 +28,7 @@ def cli(log_level): fit = "shrink" read_write_conflict = True +path = __file__ @cli.command() @@ -170,7 +171,7 @@ def spawn_worker( # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--input_container", input_array_identifier.container, diff --git a/dacapo/blockwise/threshold_worker.py b/dacapo/blockwise/threshold_worker.py index e848f3b66..60fa0198e 100644 --- a/dacapo/blockwise/threshold_worker.py +++ b/dacapo/blockwise/threshold_worker.py @@ -14,6 +14,7 @@ read_write_conflict: bool = False fit: str = "valid" +path = __file__ @click.group() @@ -89,7 +90,7 @@ def spawn_worker( # Make the command for the worker to run command = [ "python", - __file__, + path, "start-worker", "--input_container", input_array_identifier.container, diff --git a/dacapo/predict.py b/dacapo/predict.py index bfd46cb43..fb70dec86 100644 --- a/dacapo/predict.py +++ 
b/dacapo/predict.py @@ -111,8 +111,10 @@ def predict( ) # run blockwise prediction + worker_file = str(Path(Path(__file__).parent, "blockwise", "predict_worker.py")) + logger.info("Running blockwise prediction with worker_file: ", worker_file) run_blockwise( - worker_file=str(Path(Path(__file__).parent, "blockwise", "predict_worker.py")), + worker_file=worker_file, total_roi=_input_roi, read_roi=Roi((0, 0, 0), input_size), write_roi=Roi((0, 0, 0), output_size), From 816c19f91f2979883fa38462d37d2c8fa877dcb0 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Mon, 19 Feb 2024 22:40:53 -0500 Subject: [PATCH 11/19] =?UTF-8?q?fix:=20=F0=9F=9A=A7=20Some=20cleanup,=20b?= =?UTF-8?q?ut=20still=20failing=20test=5Fpredict.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/blockwise/predict_worker.py | 57 ++++++++++++----------- dacapo/compute_context/bsub.py | 1 - dacapo/compute_context/compute_context.py | 2 +- dacapo/predict.py | 23 +++++---- tests/operations/test_predict.py | 3 ++ 5 files changed, 44 insertions(+), 42 deletions(-) diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index f98321e27..c5795fbcc 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -10,8 +10,7 @@ import gunpowder as gp import gunpowder.torch as gp_torch -import daisy -from daisy import Coordinate +from funlib.geometry import Coordinate, Roi import numpy as np import click @@ -147,34 +146,36 @@ def start_worker( ) # assumes float32 is [0,1] pipeline += gp.AsType(prediction, output_array.dtype) - # wait for blocks to run pipeline - client = daisy.Client() - - while True: - print("getting block") - with client.acquire_block() as block: - if block is None: - break - - request = gp.BatchRequest() - request[raw] = gp.ArraySpec( - roi=block.read_roi, - voxel_size=input_voxel_size, - dtype=raw_array.dtype, - interpolatable=True, - ) - request[prediction] = gp.ArraySpec( - roi=block.write_roi, - voxel_size=output_voxel_size, - dtype=output_array.dtype, - interpolatable=True, - ) + # write to output array + pipeline += gp.ZarrWrite( + { + prediction: output_array_identifier.dataset, + }, + store=str(output_array_identifier.container), + compression_type="gzip", + ) - with gp.build(pipeline): - batch = pipeline.request_batch(request) + # make reference batch request + request = gp.BatchRequest() + request[raw] = gp.ArraySpec( + roi=Roi((0,) * len(input_voxel_size), input_size), + voxel_size=input_voxel_size, + dtype=raw_array.dtype, + ) + request.add( + prediction, + output_size, + voxel_size=output_voxel_size, + ) + # use daisy requests to run pipeline + pipeline += gp.DaisyRequestBlocks( + reference=request, + roi_map={raw: "read_roi", prediction: "write_roi"}, + num_workers=1, + ) - # write to output array - output_array[block.write_roi] = batch.arrays[prediction].data + with gp.build(pipeline): + batch = pipeline.request_batch(request) def spawn_worker( diff --git a/dacapo/compute_context/bsub.py b/dacapo/compute_context/bsub.py index 54d3dadda..9f168e752 100644 --- a/dacapo/compute_context/bsub.py +++ b/dacapo/compute_context/bsub.py @@ -2,7 +2,6 @@ import attr -import subprocess from typing import Optional diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index da56e854b..b0fc84fe2 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -20,7 +20,7 @@ def wrap_command(self, command): def execute(self, 
command): # A helper method to run a command in the context specific way. - return subprocess.run(self.wrap_command(command)) + subprocess.run(self.wrap_command(command), shell=True) def create_compute_context(): diff --git a/dacapo/predict.py b/dacapo/predict.py index fb70dec86..7b91b8e1d 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -27,19 +27,18 @@ def predict( output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): - """_summary_ + """Predict with a trained model. Args: - run_name (str): _description_ - iteration (int): _description_ - input_container (Path | str): _description_ - input_dataset (str): _description_ - output_path (Path | str): _description_ - output_roi (Optional[str], optional): Defaults to None. If output roi is None, - it will be set to the raw roi. - num_workers (int, optional): _description_. Defaults to 30. - output_dtype (np.dtype | str, optional): _description_. Defaults to np.uint8. - overwrite (bool, optional): _description_. Defaults to True. + run_name (str): The name of the run to predict with. + iteration (int): The training iteration of the model to use for prediction. + input_container (Path | str): The container of the input array. + input_dataset (str): The dataset name of the input array. + output_path (Path | str): The path where the prediction array will be stored. + output_roi (Optional[Roi | str], optional): The ROI of the output array. If None, the ROI of the input array will be used. Defaults to None. + num_workers (int, optional): The number of workers to use for blockwise prediction. Defaults to 30. + output_dtype (np.dtype | str, optional): The dtype of the output array. Defaults to np.uint8. + overwrite (bool, optional): If True, the output array will be overwritten if it already exists. Defaults to True. """ # retrieving run config_store = create_config_store() @@ -51,7 +50,7 @@ def predict( raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) output_container = Path( output_path, - "".join(Path(input_container).name.split(".")[:-1]) + ".zarr", + Path(input_container).stem + ".zarr", ) # TODO: zarr hardcoded prediction_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}" diff --git a/tests/operations/test_predict.py b/tests/operations/test_predict.py index 8a095be45..e5da6f70b 100644 --- a/tests/operations/test_predict.py +++ b/tests/operations/test_predict.py @@ -1,3 +1,4 @@ +import os from ..fixtures import * from dacapo.experiments import Run @@ -21,6 +22,8 @@ ], ) def test_predict(options, run_config, zarr_array, tmp_path): + # os.environ["PYDEVD_UNBLOCK_THREADS_TIMEOUT"] = "2.0" + # create a store store = create_config_store() From bd171669b996884653681bd2bee0846d1d809b4e Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 20 Feb 2024 12:53:02 -0500 Subject: [PATCH 12/19] =?UTF-8?q?fix:=20=F0=9F=9A=A7=20Debug=20predict=5Fw?= =?UTF-8?q?orker.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit predict_worker.py works outside of pytest context. 
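The rewritten worker delegates block acquisition to gunpowder's DaisyRequestBlocks node instead of looping over daisy.Client blocks by hand. A minimal sketch of that pattern, assuming the pipeline, the raw/prediction array keys, and the input/output sizes are already built as in predict_worker.py:

    import gunpowder as gp

    # reference request describing one block's worth of data
    reference = gp.BatchRequest()
    reference.add(raw, input_size, voxel_size=input_voxel_size)
    reference.add(prediction, output_size, voxel_size=output_voxel_size)

    # let daisy hand out blocks; each block shifts the reference request to its
    # read/write ROI, and the ZarrWrite node persists each prediction
    pipeline += gp.DaisyRequestBlocks(
        reference=reference,
        roi_map={raw: "read_roi", prediction: "write_roi"},
        num_workers=1,
    )

    # an empty request: DaisyRequestBlocks keeps pulling blocks until daisy
    # reports that none are left, so no explicit while-loop is needed
    with gp.build(pipeline):
        pipeline.request_batch(gp.BatchRequest())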
--- .gitignore | 1 + dacapo/blockwise/predict_worker.py | 44 ++++++++++++++++++----- dacapo/compute_context/compute_context.py | 3 +- dacapo/predict.py | 35 +++++++++--------- tests/fixtures/arrays.py | 6 ++-- tests/operations/test_predict.py | 1 - 6 files changed, 58 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index b50398827..fd7633975 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ dist build dacapo.yaml __pycache__ +scratch/ # vscode stuff .vscode diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index c5795fbcc..e6c2dfdd4 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -10,7 +10,7 @@ import gunpowder as gp import gunpowder.torch as gp_torch -from funlib.geometry import Coordinate, Roi +from funlib.geometry import Coordinate import numpy as np import click @@ -92,7 +92,7 @@ def start_worker( torch.backends.cudnn.benchmark = True # get the model's input and output size - model = run.model.eval() + model = run.model.eval().to(device) input_voxel_size = Coordinate(raw_array.voxel_size) output_voxel_size = model.scale(input_voxel_size) input_shape = Coordinate(model.eval_input_shape) @@ -102,6 +102,35 @@ def start_worker( logger.info( "Predicting with input size %s, output size %s", input_size, output_size ) + + # # simple daisy case + # daisy_client = daisy.Client() + + # while True: + # with daisy_client.acquire_block() as block: + # if block is None: + # return + + # raw_in = raw_array[block.read_roi][None, ...] + # # convert to float32 if necessary: + # if raw_in.dtype != np.float32 and raw_in.dtype != np.float64: + # raw_in = raw_in.astype(np.float32) + # # normalize to [0,1] + # raw_in /= np.iinfo(raw_in.dtype).max + # elif raw_in.dtype == np.float64: + # raw_in = raw_in.astype(np.float32) + # print(raw_in.shape, raw_in.dtype, raw_in.min(), raw_in.max()) + # raw_in = torch.as_tensor(raw_in).to(device) + # with torch.no_grad(): + # prediction_out = model(raw_in) + # # convert to uint8 if necessary: + # if output_array.dtype == np.uint8: + # prediction_out = (prediction_out * 255).to(torch.uint8) + # # move to cpu and numpy + # prediction_out = prediction_out.cpu().numpy() + # # write to output array + # output_array[block.write_roi] = prediction_out + # create gunpowder keys raw = gp.ArrayKey("RAW") @@ -117,6 +146,8 @@ def start_worker( pipeline += gp.Unsqueeze([raw]) # raw: (1, c, d, h, w) + pipeline += gp.Normalize(raw) + # predict pipeline += gp_torch.Predict( model=model, @@ -152,16 +183,11 @@ def start_worker( prediction: output_array_identifier.dataset, }, store=str(output_array_identifier.container), - compression_type="gzip", ) # make reference batch request request = gp.BatchRequest() - request[raw] = gp.ArraySpec( - roi=Roi((0,) * len(input_voxel_size), input_size), - voxel_size=input_voxel_size, - dtype=raw_array.dtype, - ) + request.add(raw, input_size, voxel_size=input_voxel_size) request.add( prediction, output_size, @@ -175,7 +201,7 @@ def start_worker( ) with gp.build(pipeline): - batch = pipeline.request_batch(request) + batch = pipeline.request_batch(gp.BatchRequest()) def spawn_worker( diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index b0fc84fe2..1eab08b0f 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -20,7 +20,8 @@ def wrap_command(self, command): def execute(self, command): # A helper method to run a command in the context specific way. 
- subprocess.run(self.wrap_command(command), shell=True) + # subprocess.run(self.wrap_command(command), shell=True) + subprocess.run(self.wrap_command(command)) def create_compute_context(): diff --git a/dacapo/predict.py b/dacapo/predict.py index 7b91b8e1d..d19c9b885 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -56,8 +56,21 @@ def predict( output_container, f"prediction_{run_name}_{iteration}" ) + # get the model's input and output size + model = run.model.eval() + + input_voxel_size = Coordinate(raw_array.voxel_size) + output_voxel_size = model.scale(input_voxel_size) + input_shape = Coordinate(model.eval_input_shape) + input_size = input_voxel_size * input_shape + output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] + + # calculate input and output rois + + context = (input_size - output_size) // 2 + if output_roi is None: - output_roi = raw_array.roi + output_roi = raw_array.roi.grow(-context, -context) elif isinstance(output_roi, str): start, end = zip( *[ @@ -71,30 +84,16 @@ def predict( ) output_roi = output_roi.snap_to_grid( raw_array.voxel_size, mode="grow" - ).intersect(raw_array.roi) + ).intersect(raw_array.roi.grow(-context, -context)) + _input_roi = output_roi.grow(context, context) if isinstance(output_dtype, str): output_dtype = np.dtype(output_dtype) - model = run.model.eval() - - # get the model's input and output size - - input_voxel_size = Coordinate(raw_array.voxel_size) - output_voxel_size = model.scale(input_voxel_size) - input_shape = Coordinate(model.eval_input_shape) - input_size = input_voxel_size * input_shape - output_size = output_voxel_size * model.compute_output_shape(input_shape)[1] - logger.info( "Predicting with input size %s, output size %s", input_size, output_size ) - # calculate input and output rois - - context = (input_size - output_size) / 2 - _input_roi = output_roi.grow(context, context) - logger.info("Total input ROI: %s, output ROI: %s", _input_roi, output_roi) # prepare prediction dataset @@ -116,7 +115,7 @@ def predict( worker_file=worker_file, total_roi=_input_roi, read_roi=Roi((0, 0, 0), input_size), - write_roi=Roi((0, 0, 0), output_size), + write_roi=Roi(context, output_size), num_workers=num_workers, max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option diff --git a/tests/fixtures/arrays.py b/tests/fixtures/arrays.py index c2f45483e..cd09fe658 100644 --- a/tests/fixtures/arrays.py +++ b/tests/fixtures/arrays.py @@ -24,11 +24,11 @@ def zarr_array(tmp_path): ) zarr_container = zarr.open(str(tmp_path / "zarr_array.zarr")) dataset = zarr_container.create_dataset( - "volumes/test", data=np.zeros((100, 50, 25)) + "volumes/test", data=np.zeros((100, 50, 25), dtype=np.float32) ) dataset.attrs["offset"] = (12, 12, 12) dataset.attrs["resolution"] = (1, 2, 4) - dataset.attrs["axes"] = "zyx" + dataset.attrs["axes"] = ["zyx"] yield zarr_array_config @@ -41,7 +41,7 @@ def cellmap_array(tmp_path): ) zarr_container = zarr.open(str(tmp_path / "zarr_array.zarr")) dataset = zarr_container.create_dataset( - "volumes/test", data=np.arange(0, 100).reshape(10, 5, 2) + "volumes/test", data=np.arange(0, 100, dtype=np.uint8).reshape(10, 5, 2) ) dataset.attrs["offset"] = (12, 12, 12) dataset.attrs["resolution"] = (1, 2, 4) diff --git a/tests/operations/test_predict.py b/tests/operations/test_predict.py index e5da6f70b..956eadba8 100644 --- a/tests/operations/test_predict.py +++ b/tests/operations/test_predict.py @@ -1,4 +1,3 @@ -import os from ..fixtures import * from dacapo.experiments import Run 
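The predict.py change above keys the blockwise read/write geometry off the model context. An illustrative sketch of that bookkeeping (the sizes are made up for the example, not taken from any run):

    from funlib.geometry import Coordinate, Roi

    input_size = Coordinate((188, 188, 188))   # hypothetical model input size in world units
    output_size = Coordinate((100, 100, 100))  # hypothetical model output size in world units
    context = (input_size - output_size) // 2  # -> (44, 44, 44)

    read_roi = Roi((0, 0, 0), input_size)   # what each worker reads
    write_roi = Roi(context, output_size)   # what each worker writes, centered in the read ROI

    # shrinking the read ROI by the context on both sides recovers the write ROI,
    # mirroring raw_array.roi.grow(-context, -context) in predict.py
    assert read_roi.grow(-context, -context) == write_roi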
From e38090799717247e1de3ac364c0ec20eaabc6cc2 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Tue, 20 Feb 2024 12:58:45 -0500 Subject: [PATCH 13/19] Remove commented out code for daisy case in predict_worker.py --- dacapo/blockwise/predict_worker.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index e6c2dfdd4..cb40091b4 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -103,34 +103,6 @@ def start_worker( "Predicting with input size %s, output size %s", input_size, output_size ) - # # simple daisy case - # daisy_client = daisy.Client() - - # while True: - # with daisy_client.acquire_block() as block: - # if block is None: - # return - - # raw_in = raw_array[block.read_roi][None, ...] - # # convert to float32 if necessary: - # if raw_in.dtype != np.float32 and raw_in.dtype != np.float64: - # raw_in = raw_in.astype(np.float32) - # # normalize to [0,1] - # raw_in /= np.iinfo(raw_in.dtype).max - # elif raw_in.dtype == np.float64: - # raw_in = raw_in.astype(np.float32) - # print(raw_in.shape, raw_in.dtype, raw_in.min(), raw_in.max()) - # raw_in = torch.as_tensor(raw_in).to(device) - # with torch.no_grad(): - # prediction_out = model(raw_in) - # # convert to uint8 if necessary: - # if output_array.dtype == np.uint8: - # prediction_out = (prediction_out * 255).to(torch.uint8) - # # move to cpu and numpy - # prediction_out = prediction_out.cpu().numpy() - # # write to output array - # output_array[block.write_roi] = prediction_out - # create gunpowder keys raw = gp.ArrayKey("RAW") From 9b12bbd2ad6647b24527a7ea1cddd309fb6cc8a1 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 21 Feb 2024 00:15:33 -0500 Subject: [PATCH 14/19] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Passed=20test=5Fpre?= =?UTF-8?q?dict.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit also updated daisy dependency, and added "all" option for installing optional dependencies --- dacapo/blockwise/predict_worker.py | 2 +- dacapo/compute_context/compute_context.py | 1 - dacapo/predict.py | 13 +++++++++++-- pyproject.toml | 8 +++++--- tests/operations/test_predict.py | 3 +++ 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index cb40091b4..5074c3620 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -66,7 +66,7 @@ def start_worker( input_dataset: str, output_container: Path | str, output_dataset: str, - device: str = "cuda", + device: str | torch.device = "cuda", ): # retrieving run config_store = create_config_store() diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index 1eab08b0f..3c30079d1 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -20,7 +20,6 @@ def wrap_command(self, command): def execute(self, command): # A helper method to run a command in the context specific way. 
- # subprocess.run(self.wrap_command(command), shell=True) subprocess.run(self.wrap_command(command)) diff --git a/dacapo/predict.py b/dacapo/predict.py index d19c9b885..f856cffd0 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -2,7 +2,7 @@ from dacapo.blockwise import run_blockwise from dacapo.experiments import Run -from dacapo.store.create_store import create_config_store +from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.store.local_array_store import LocalArrayIdentifier from dacapo.experiments.datasplits.datasets.arrays import ZarrArray @@ -23,7 +23,7 @@ def predict( input_dataset: str, output_path: Path | str, output_roi: Optional[Roi | str] = None, - num_workers: int = 30, + num_workers: int = 12, output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): @@ -45,6 +45,15 @@ def predict( run_config = config_store.retrieve_run_config(run_name) run = Run(run_config) + # check to see if we can load the weights + weights_store = create_weights_store() + try: + weights_store.retrieve_weights(run_name, iteration) + except FileNotFoundError: + raise ValueError( + f"No weights found for run {run_name} at iteration {iteration}." + ) + # get arrays raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) diff --git a/pyproject.toml b/pyproject.toml index 325a77b35..a3f0b0015 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,8 @@ dependencies = [ "attrs", "bokeh", "numpy-indexed>=0.3.7", - "daisy>=1.0", + #"daisy>=1.0", + "daisy @ git+https://github.com/funkelab/daisy.git", "funlib.math>=0.1", "funlib.geometry>=0.2", "mwatershed>=0.1", @@ -83,9 +84,10 @@ examples = [ "ipykernel", "jupyter", ] +all = ["dacapo-ml[test,dev,docs,examples]"] [project.urls] -homepage = "https://github.com/janelia-cellmap/dacapo" +homepage = "https://github.io/janelia-cellmap/dacapo" repository = "https://github.com/janelia-cellmap/dacapo" # https://hatch.pypa.io/latest/config/metadata/ @@ -139,7 +141,7 @@ filterwarnings = [ "error", "ignore::DeprecationWarning", ] - + # https://mypy.readthedocs.io/en/stable/config_file.html [tool.mypy] files = "dacapo/**/" diff --git a/tests/operations/test_predict.py b/tests/operations/test_predict.py index 956eadba8..f537c8fd9 100644 --- a/tests/operations/test_predict.py +++ b/tests/operations/test_predict.py @@ -46,6 +46,7 @@ def test_predict(options, run_config, zarr_array, tmp_path): input_container=zarr_array.file_name, input_dataset=zarr_array.dataset, output_path=tmp_path, + num_workers=4, ) weights_store.store_weights(run, 1) predict( @@ -54,6 +55,7 @@ def test_predict(options, run_config, zarr_array, tmp_path): input_container=zarr_array.file_name, input_dataset=zarr_array.dataset, output_path=tmp_path, + num_workers=4, ) # test predicting with iterations for which we know there are no weights @@ -64,4 +66,5 @@ def test_predict(options, run_config, zarr_array, tmp_path): input_container=zarr_array.file_name, input_dataset=zarr_array.dataset, output_path=tmp_path, + num_workers=4, ) From 0ef1fd2c69ed81b618456d89a2e6f45d587cad11 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 21 Feb 2024 15:35:34 -0500 Subject: [PATCH 15/19] =?UTF-8?q?fix:=20=F0=9F=A7=AA=20Fix=20imports.=20te?= =?UTF-8?q?st=5Fapply.py=20passes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/apply.py | 16 ++++++++-------- dacapo/blockwise/scheduler.py | 5 +++-- 
.../post_processors/argmax_post_processor.py | 5 +++-- .../post_processors/dummy_post_processor.py | 2 +- .../post_processors/threshold_post_processor.py | 10 ++++------ .../post_processors/watershed_post_processor.py | 3 ++- dacapo/predict.py | 14 +++++++++----- tests/fixtures/arrays.py | 9 +++++---- tests/operations/test_apply.py | 3 +++ 9 files changed, 38 insertions(+), 29 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index f12f38091..939f803c2 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -33,7 +33,7 @@ def apply( iteration: Optional[int] = None, parameters: Optional[PostProcessorParameters | str] = None, roi: Optional[Roi | str] = None, - num_workers: int = 30, + num_workers: int = 12, output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore overwrite: bool = True, file_format: str = "zarr", @@ -142,7 +142,7 @@ def apply( ) output_container = Path( output_path, - "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}", + Path(input_container).stem + f".{file_format}", ) prediction_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}" @@ -158,7 +158,7 @@ def apply( Path(input_container, input_dataset), ) return apply_run( - run.name, + run, iteration, parameters, input_array_identifier, @@ -172,15 +172,15 @@ def apply( def apply_run( - run_name: str, + run: Run, iteration: int, parameters: PostProcessorParameters, input_array_identifier: "LocalArrayIdentifier", prediction_array_identifier: "LocalArrayIdentifier", output_array_identifier: "LocalArrayIdentifier", roi: Optional[Roi] = None, - num_workers: int = 30, - output_dtype: Optional[np.dtype] = np.uint8, # type: ignore + num_workers: int = 12, + output_dtype: np.dtype | str = np.uint8, # type: ignore overwrite: bool = True, ): """Apply the model to a dataset. If roi is None, the whole input dataset is used. 
Assumes model is already loaded.""" @@ -188,7 +188,7 @@ def apply_run( # render prediction dataset logger.info("Predicting on dataset %s", prediction_array_identifier) predict( - run_name, + run.name, iteration, input_container=input_array_identifier.container, input_dataset=input_array_identifier.dataset, @@ -203,7 +203,7 @@ def apply_run( logger.info("Post-processing output to dataset %s", output_array_identifier) post_processor = run.task.post_processor post_processor.set_prediction(prediction_array_identifier) - post_processor.process(parameters, output_array_identifier) + post_processor.process(parameters, output_array_identifier, num_workers=num_workers) logger.info("Done") return diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index 777b2c2bd..4430362af 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -2,6 +2,7 @@ import tempfile import time import daisy +import dacapo.blockwise from funlib.geometry import Roi, Coordinate import yaml @@ -102,7 +103,7 @@ def segment_blockwise( # Make the task task = DaCapoBlockwiseTask( - str(Path(Path(__file__).parent, "segment_worker.py")), + str(Path(Path(dacapo.blockwise.__file__).parent, "segment_worker.py")), total_roi.grow(context, context), read_roi, write_roi, @@ -126,7 +127,7 @@ def segment_blockwise( # Make the task task = DaCapoBlockwiseTask( - str(Path(Path(__file__).parent, "relabel_worker.py")), + str(Path(Path(dacapo.blockwise.__file__).parent, "relabel_worker.py")), total_roi, read_roi, write_roi, diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 5fa6009ec..2cecde806 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -1,5 +1,6 @@ from pathlib import Path -from dacapo.blockwise.scheduler import run_blockwise +from dacapo.blockwise import run_blockwise +import dacapo.blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier from .argmax_post_processor_parameters import ArgmaxPostProcessorParameters @@ -43,7 +44,7 @@ def process( # run blockwise prediction run_blockwise( worker_file=str( - Path(Path(__file__).parent, "blockwise", "predict_worker.py") + Path(Path(dacapo.blockwise.__file__).parent, "argmax_worker.py") ), total_roi=self.prediction_array.roi, read_roi=read_roi, diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 5a2c7810a..4ab60457a 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: def set_prediction(self, prediction_array): pass - def process(self, parameters, output_array_identifier): + def process(self, parameters, output_array_identifier, *args, **kwargs): # store some dummy data f = zarr.open(str(output_array_identifier.container), "a") f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 712ef5baf..c20579019 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ 
b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -3,16 +3,14 @@ from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters from .post_processor import PostProcessor +import dacapo.blockwise import numpy as np from daisy import Roi, Coordinate from typing import TYPE_CHECKING, Iterable -if TYPE_CHECKING: - from dacapo.store.local_array_store import LocalArrayIdentifier - from dacapo.experiments.tasks.post_processors import ( - ThresholdPostProcessorParameters, - ) +# if TYPE_CHECKING: +from dacapo.store.local_array_store import LocalArrayIdentifier class ThresholdPostProcessor(PostProcessor): @@ -60,7 +58,7 @@ def process( # run blockwise prediction run_blockwise( worker_file=str( - Path(Path(__file__).parent, "blockwise", "predict_worker.py") + Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py") ), total_roi=self.prediction_array.roi, read_roi=read_roi, diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 6fac51735..7ada70a6a 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -1,4 +1,5 @@ from pathlib import Path +import dacapo.blockwise from dacapo.blockwise.scheduler import segment_blockwise from dacapo.experiments.datasplits.datasets.arrays import ZarrArray from dacapo.store.array_store import LocalArrayIdentifier @@ -55,7 +56,7 @@ def process( } segment_blockwise( segment_function_file=str( - Path(Path(__file__).parent, "blockwise", "watershed_function.py") + Path(Path(dacapo.blockwise.__file__).parent, "watershed_function.py") ), context=parameters.context, total_roi=self.prediction_array.roi, diff --git a/dacapo/predict.py b/dacapo/predict.py index f856cffd0..87556e1c1 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -1,6 +1,7 @@ from pathlib import Path from dacapo.blockwise import run_blockwise +import dacapo.blockwise from dacapo.experiments import Run from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.store.local_array_store import LocalArrayIdentifier @@ -57,10 +58,13 @@ def predict( # get arrays raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) - output_container = Path( - output_path, - Path(input_container).stem + ".zarr", - ) # TODO: zarr hardcoded + if ".zarr" in str(output_path) or ".n5" in str(output_path): + output_container = Path(output_path) + else: + output_container = Path( + output_path, + Path(input_container).stem + ".zarr", + ) # TODO: zarr hardcoded prediction_array_identifier = LocalArrayIdentifier( output_container, f"prediction_{run_name}_{iteration}" ) @@ -118,7 +122,7 @@ def predict( ) # run blockwise prediction - worker_file = str(Path(Path(__file__).parent, "blockwise", "predict_worker.py")) + worker_file = str(Path(Path(dacapo.blockwise.__file__).parent, "predict_worker.py")) logger.info("Running blockwise prediction with worker_file: ", worker_file) run_blockwise( worker_file=worker_file, diff --git a/tests/fixtures/arrays.py b/tests/fixtures/arrays.py index cd09fe658..8af4e90f2 100644 --- a/tests/fixtures/arrays.py +++ b/tests/fixtures/arrays.py @@ -22,9 +22,9 @@ def zarr_array(tmp_path): file_name=tmp_path / "zarr_array.zarr", 
dataset="volumes/test", ) - zarr_container = zarr.open(str(tmp_path / "zarr_array.zarr")) + zarr_container = zarr.open(str(zarr_array_config.file_name)) dataset = zarr_container.create_dataset( - "volumes/test", data=np.zeros((100, 50, 25), dtype=np.float32) + zarr_array_config.dataset, data=np.zeros((100, 50, 25), dtype=np.float32) ) dataset.attrs["offset"] = (12, 12, 12) dataset.attrs["resolution"] = (1, 2, 4) @@ -39,9 +39,10 @@ def cellmap_array(tmp_path): file_name=tmp_path / "zarr_array.zarr", dataset="volumes/test", ) - zarr_container = zarr.open(str(tmp_path / "zarr_array.zarr")) + zarr_container = zarr.open(str(zarr_array_config.file_name)) dataset = zarr_container.create_dataset( - "volumes/test", data=np.arange(0, 100, dtype=np.uint8).reshape(10, 5, 2) + zarr_array_config.dataset, + data=np.arange(0, 100, dtype=np.uint8).reshape(10, 5, 2), ) dataset.attrs["offset"] = (12, 12, 12) dataset.attrs["resolution"] = (1, 2, 4) diff --git a/tests/operations/test_apply.py b/tests/operations/test_apply.py index 8d5b2f781..64390d0c3 100644 --- a/tests/operations/test_apply.py +++ b/tests/operations/test_apply.py @@ -47,6 +47,7 @@ def test_apply(options, run_config, zarr_array, tmp_path): output_path=tmp_path, iteration=0, parameters=parameters, + num_workers=4, ) weights_store.store_weights(run, 1) apply( @@ -56,6 +57,7 @@ def test_apply(options, run_config, zarr_array, tmp_path): output_path=tmp_path, iteration=1, parameters=parameters, + num_workers=4, ) # test validating weights that don't exist @@ -67,4 +69,5 @@ def test_apply(options, run_config, zarr_array, tmp_path): output_path=tmp_path, iteration=2, parameters=parameters, + num_workers=4, ) From 6d0a44482f8a51fc1216b8dd05fae8e466f0db10 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 21 Feb 2024 15:47:58 -0500 Subject: [PATCH 16/19] =?UTF-8?q?fix:=20=F0=9F=90=9B=20Fix=20abstract=20Co?= =?UTF-8?q?nfigStore=20to=20allow=20MongoConfigStore=20definition.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/store/config_store.py | 43 +++++++----------------------------- 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/dacapo/store/config_store.py b/dacapo/store/config_store.py index 8c91fd036..87fa6edb0 100644 --- a/dacapo/store/config_store.py +++ b/dacapo/store/config_store.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, TYPE_CHECKING +from typing import Any, List, TYPE_CHECKING if TYPE_CHECKING: from dacapo.experiments.run_config import RunConfig @@ -17,40 +17,13 @@ class DuplicateNameError(Exception): class ConfigStore(ABC): """Base class for configuration stores.""" - @property - @abstractmethod - def runs(self): - pass - - @property - @abstractmethod - def datasplits(self): - pass - - @property - @abstractmethod - def datasets(self): - pass - - @property - @abstractmethod - def arrays(self): - pass - - @property - @abstractmethod - def tasks(self): - pass - - @property - @abstractmethod - def trainers(self): - pass - - @property - @abstractmethod - def architectures(self): - pass + runs: Any + datasplits: Any + datasets: Any + arrays: Any + tasks: Any + trainers: Any + architectures: Any @abstractmethod def delete_config(self, database, config_name: str) -> None: From c1b7f8a894eb3ca7b9c88e42131caecdaa29a48b Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 21 Feb 2024 21:57:39 -0500 Subject: [PATCH 17/19] =?UTF-8?q?fix:=20=F0=9F=9A=A7=20Extend=20prediction?= 
=?UTF-8?q?=20output=20declaration,=20refactor,=20debug=20tests,=20and=20d?= =?UTF-8?q?ebug=20validate.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dacapo/apply.py | 2 +- dacapo/compute_context/compute_context.py | 2 +- .../tasks/evaluators/dummy_evaluator.py | 2 +- .../experiments/tasks/evaluators/evaluator.py | 5 ++-- dacapo/options.py | 2 +- dacapo/predict.py | 25 +++++++++++-------- dacapo/store/create_store.py | 12 ++++----- dacapo/validate.py | 11 ++++---- tests/fixtures/datasplits.py | 4 +-- tests/operations/test_validate.py | 6 ++--- 10 files changed, 38 insertions(+), 33 deletions(-) diff --git a/dacapo/apply.py b/dacapo/apply.py index 939f803c2..bfdb2c182 100644 --- a/dacapo/apply.py +++ b/dacapo/apply.py @@ -192,7 +192,7 @@ def apply_run( iteration, input_container=input_array_identifier.container, input_dataset=input_array_identifier.dataset, - output_path=prediction_array_identifier.container, + output_path=prediction_array_identifier, output_roi=roi, num_workers=num_workers, output_dtype=output_dtype, diff --git a/dacapo/compute_context/compute_context.py b/dacapo/compute_context/compute_context.py index 3c30079d1..a57de1a09 100644 --- a/dacapo/compute_context/compute_context.py +++ b/dacapo/compute_context/compute_context.py @@ -33,4 +33,4 @@ def create_compute_context(): **options.compute_context_config["config"] ) else: - raise ValueError(f"Unknown store type {options.db_type}") + raise ValueError(f"Unknown store type {options.type}") diff --git a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py index 3e2e27b94..1c27bb320 100644 --- a/dacapo/experiments/tasks/evaluators/dummy_evaluator.py +++ b/dacapo/experiments/tasks/evaluators/dummy_evaluator.py @@ -7,7 +7,7 @@ class DummyEvaluator(Evaluator): criteria = ["frizz_level", "blipp_score"] - def evaluate(self, output_array, evaluation_dataset): + def evaluate(self, output_array_identifier, evaluation_dataset): return DummyEvaluationScores( frizz_level=random.random(), blipp_score=random.random() ) diff --git a/dacapo/experiments/tasks/evaluators/evaluator.py b/dacapo/experiments/tasks/evaluators/evaluator.py index 9d5cbbda0..4039a20dc 100644 --- a/dacapo/experiments/tasks/evaluators/evaluator.py +++ b/dacapo/experiments/tasks/evaluators/evaluator.py @@ -9,6 +9,7 @@ from dacapo.experiments.tasks.evaluators.evaluation_scores import EvaluationScores from dacapo.experiments.datasplits.datasets import Dataset from dacapo.experiments.datasplits.datasets.arrays import Array + from dacapo.store.local_array_store import LocalArrayIdentifier from dacapo.experiments.tasks.post_processors import PostProcessorParameters from dacapo.experiments.validation_scores import ValidationScores @@ -28,9 +29,9 @@ class Evaluator(ABC): @abstractmethod def evaluate( - self, output_array: "Array", eval_array: "Array" + self, output_array_identifier: "LocalArrayIdentifier", evaluation_array: "Array" ) -> "EvaluationScores": - """Compare an `output_array` against ground-truth `eval_array`""" + """Compare an `output_array_identifier` against ground-truth `evaluation_array`""" pass @property diff --git a/dacapo/options.py b/dacapo/options.py index 41a6fda50..981979690 100644 --- a/dacapo/options.py +++ b/dacapo/options.py @@ -15,7 +15,7 @@ @attr.s class DaCapoConfig: - db_type: str = attr.ib( + type: str = attr.ib( default="files", metadata={ "help_text": "The type of store to use for storing configurations and statistics. 
" diff --git a/dacapo/predict.py b/dacapo/predict.py index 87556e1c1..ee0dcaa2b 100644 --- a/dacapo/predict.py +++ b/dacapo/predict.py @@ -22,7 +22,7 @@ def predict( iteration: int, input_container: Path | str, input_dataset: str, - output_path: Path | str, + output_path: LocalArrayIdentifier | str, output_roi: Optional[Roi | str] = None, num_workers: int = 12, output_dtype: np.dtype | str = np.uint8, # type: ignore @@ -35,7 +35,7 @@ def predict( iteration (int): The training iteration of the model to use for prediction. input_container (Path | str): The container of the input array. input_dataset (str): The dataset name of the input array. - output_path (Path | str): The path where the prediction array will be stored. + output_path (LocalArrayIdentifier | str): The path where the prediction array will be stored, or a LocalArryIdentifier for the prediction array. output_roi (Optional[Roi | str], optional): The ROI of the output array. If None, the ROI of the input array will be used. Defaults to None. num_workers (int, optional): The number of workers to use for blockwise prediction. Defaults to 30. output_dtype (np.dtype | str, optional): The dtype of the output array. Defaults to np.uint8. @@ -58,16 +58,19 @@ def predict( # get arrays raw_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset) raw_array = ZarrArray.open_from_array_identifier(raw_array_identifier) - if ".zarr" in str(output_path) or ".n5" in str(output_path): - output_container = Path(output_path) + if isinstance(output_path, LocalArrayIdentifier): + prediction_array_identifier = output_path else: - output_container = Path( - output_path, - Path(input_container).stem + ".zarr", - ) # TODO: zarr hardcoded - prediction_array_identifier = LocalArrayIdentifier( - output_container, f"prediction_{run_name}_{iteration}" - ) + if ".zarr" in str(output_path) or ".n5" in str(output_path): + output_container = Path(output_path) + else: + output_container = Path( + output_path, + Path(input_container).stem + ".zarr", + ) # TODO: zarr hardcoded + prediction_array_identifier = LocalArrayIdentifier( + output_container, f"prediction_{run_name}_{iteration}" + ) # get the model's input and output size model = run.model.eval() diff --git a/dacapo/store/create_store.py b/dacapo/store/create_store.py index b5442bbad..0fcc43ed2 100644 --- a/dacapo/store/create_store.py +++ b/dacapo/store/create_store.py @@ -14,15 +14,15 @@ def create_config_store(): options = Options.instance() - if options.db_type == "mongo": + if options.type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoConfigStore(db_host, db_name) - elif options.db_type == "files": + elif options.type == "files": store_path = Path(options.runs_base_dir).expanduser() return FileConfigStore(store_path / "configs") else: - raise ValueError(f"Unknown store type {options.db_type}") + raise ValueError(f"Unknown store type {options.type}") def create_stats_store(): @@ -30,15 +30,15 @@ def create_stats_store(): options = Options.instance() - if options.db_type == "mongo": + if options.type == "mongo": db_host = options.mongo_db_host db_name = options.mongo_db_name return MongoStatsStore(db_host, db_name) - elif options.db_type == "files": + elif options.type == "files": store_path = Path(options.runs_base_dir).expanduser() return FileStatsStore(store_path / "stats") else: - raise ValueError(f"Unknown store type {options.db_type}") + raise ValueError(f"Unknown store type {options.type}") def create_weights_store(): diff --git 
a/dacapo/validate.py b/dacapo/validate.py index 74ffd9ddb..0b87bff91 100644 --- a/dacapo/validate.py +++ b/dacapo/validate.py @@ -43,7 +43,7 @@ def validate( # create weights store and read weights weights_store = create_weights_store() - weights_store.retrieve_weights(run, iteration) + weights_store.retrieve_weights(run.name, iteration) return validate_run( run, @@ -147,7 +147,7 @@ def validate_run( logger.info("validation inputs already copied!") prediction_array_identifier = array_store.validation_prediction_array( - run.name, iteration, validation_dataset + run.name, iteration, validation_dataset.name ) logger.info("Predicting on dataset %s", validation_dataset.name) predict( @@ -155,7 +155,7 @@ def validate_run( iteration, input_container=input_raw_array_identifier.container, input_dataset=input_raw_array_identifier.dataset, - output_path=prediction_array_identifier.container, + output_path=prediction_array_identifier, output_roi=validation_dataset.gt.roi, num_workers=num_workers, output_dtype=output_dtype, @@ -170,7 +170,7 @@ def validate_run( for parameters in post_processor.enumerate_parameters(): output_array_identifier = array_store.validation_output_array( - run.name, iteration, parameters, validation_dataset + run.name, iteration, str(parameters), validation_dataset.name ) post_processed_array = post_processor.process( @@ -206,10 +206,11 @@ def validate_run( "iteration": iteration, criterion: getattr(scores, criterion), "parameters_id": parameters.id, + "parameters": str(parameters), } ) weights_store.store_best( - run, iteration, validation_dataset.name, criterion + run.name, iteration, validation_dataset.name, criterion ) # delete current output. We only keep the best outputs as determined by diff --git a/tests/fixtures/datasplits.py b/tests/fixtures/datasplits.py index 5af5a4036..7bb5672c6 100644 --- a/tests/fixtures/datasplits.py +++ b/tests/fixtures/datasplits.py @@ -67,7 +67,7 @@ def twelve_class_datasplit(tmp_path): gt_dataset = twelve_class_zarr.create_dataset( gt.dataset, shape=(40, 20, 20), dtype=np.uint8 ) - random_data = np.random.randn(40, 20, 20) + random_data = np.random.rand(40, 20, 20) # as intensities increase so does the class for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]: gt_dataset[:] += random_data > i @@ -178,7 +178,7 @@ def six_class_datasplit(tmp_path): gt_dataset = twelve_class_zarr.create_dataset( gt.dataset, shape=(40, 20, 20), dtype=np.uint8 ) - random_data = np.random.randn(40, 20, 20) + random_data = np.random.rand(40, 20, 20) # as intensities increase so does the class for i in list(np.linspace(random_data.min(), random_data.max(), 13))[1:]: gt_dataset[:] += random_data > i diff --git a/tests/operations/test_validate.py b/tests/operations/test_validate.py index df4f4774b..7489da40d 100644 --- a/tests/operations/test_validate.py +++ b/tests/operations/test_validate.py @@ -42,10 +42,10 @@ def test_validate( # test validating iterations for which we know there are weights weights_store.store_weights(run, 0) - validate(run_config.name, 0) + validate(run_config.name, 0, num_workers=4) weights_store.store_weights(run, 1) - validate(run_config.name, 1) + validate(run_config.name, 1, num_workers=4) # test validating weights that don't exist with pytest.raises(FileNotFoundError): - validate(run_config.name, 2) + validate(run_config.name, 2, num_workers=4) From d5f0e60b63ded288a62c6a4f0ed54a20d116796f Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Wed, 21 Feb 2024 23:17:32 -0500 Subject: [PATCH 18/19] 
=?UTF-8?q?chore:=20=F0=9F=A7=91=E2=80=8D?= =?UTF-8?q?=F0=9F=92=BB=20Clean=20imports.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WIP pytests hanging, then passing after keyboardinterupt. --- dacapo/blockwise/predict_worker.py | 4 ++-- dacapo/blockwise/scheduler.py | 2 +- dacapo/experiments/datasplits/datasplit.py | 2 +- dacapo/experiments/run.py | 6 +++--- .../tasks/post_processors/argmax_post_processor.py | 3 ++- .../tasks/post_processors/threshold_post_processor.py | 2 +- 6 files changed, 10 insertions(+), 9 deletions(-) diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py index 5074c3620..6b47bf76c 100644 --- a/dacapo/blockwise/predict_worker.py +++ b/dacapo/blockwise/predict_worker.py @@ -1,8 +1,8 @@ from pathlib import Path import torch -from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray -from dacapo.gp.dacapo_array_source import DaCapoArraySource +from dacapo.experiments.datasplits.datasets.arrays import ZarrArray +from dacapo.gp import DaCapoArraySource from dacapo.store.array_store import LocalArrayIdentifier from dacapo.store.create_store import create_config_store, create_weights_store from dacapo.experiments import Run diff --git a/dacapo/blockwise/scheduler.py b/dacapo/blockwise/scheduler.py index 4430362af..b9ee72765 100644 --- a/dacapo/blockwise/scheduler.py +++ b/dacapo/blockwise/scheduler.py @@ -82,7 +82,7 @@ def run_blockwise( def segment_blockwise( - segment_function_file: str or Path, + segment_function_file: str | Path, context: Coordinate, total_roi: Roi, read_roi: Roi, diff --git a/dacapo/experiments/datasplits/datasplit.py b/dacapo/experiments/datasplits/datasplit.py index 17c7e3ac1..eb5a55023 100644 --- a/dacapo/experiments/datasplits/datasplit.py +++ b/dacapo/experiments/datasplits/datasplit.py @@ -1,4 +1,4 @@ -from dacapo.experiments.datasplits.datasets import Dataset +from .datasets import Dataset import neuroglancer diff --git a/dacapo/experiments/run.py b/dacapo/experiments/run.py index 129f947ab..8aef6eb1d 100644 --- a/dacapo/experiments/run.py +++ b/dacapo/experiments/run.py @@ -1,7 +1,7 @@ -from .datasplits.datasplit import DataSplit +from .datasplits import DataSplit from .tasks.task import Task -from .architectures.architecture import Architecture -from .trainers.trainer import Trainer +from .architectures import Architecture +from .trainers import Trainer from .training_stats import TrainingStats from .validation_scores import ValidationScores from .starts import Start diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 2cecde806..4302b6d08 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -41,7 +41,7 @@ def process( ) read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size) - # run blockwise prediction + # run blockwise post-processing run_blockwise( worker_file=str( Path(Path(dacapo.blockwise.__file__).parent, "argmax_worker.py") @@ -58,4 +58,5 @@ def process( ), output_array_identifier=output_array_identifier, ) + return output_array diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index c20579019..4c337ed62 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ 
b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -55,7 +55,7 @@ def process( ) read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size) - # run blockwise prediction + # run blockwise post-processing run_blockwise( worker_file=str( Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py") From 9def264e9c5691b5921956097a5c79cd9da67c66 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 22 Feb 2024 00:38:03 -0500 Subject: [PATCH 19/19] =?UTF-8?q?chore:=20=F0=9F=94=A5=20Cleaning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tasks/post_processors/argmax_post_processor.py | 7 +++---- .../tasks/post_processors/dummy_post_processor.py | 2 +- .../post_processors/threshold_post_processor.py | 13 +++++-------- .../post_processors/watershed_post_processor.py | 5 ++--- tests/fixtures/post_processors.py | 1 - 5 files changed, 11 insertions(+), 17 deletions(-) diff --git a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py index 4302b6d08..42863b56d 100644 --- a/dacapo/experiments/tasks/post_processors/argmax_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/argmax_post_processor.py @@ -20,6 +20,7 @@ def enumerate_parameters(self): yield ArgmaxPostProcessorParameters(id=1) def set_prediction(self, prediction_array_identifier): + self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -27,7 +28,7 @@ def set_prediction(self, prediction_array_identifier): def process( self, parameters, - output_array_identifier, + output_array_identifier: "LocalArrayIdentifier", num_workers: int = 16, block_size: Coordinate = Coordinate((64, 64, 64)), ): @@ -53,9 +54,7 @@ def process( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - input_array_identifier=LocalArrayIdentifier( - self.prediction_array.file_name, self.prediction_array.dataset - ), + input_array_identifier=self.prediction_array_identifier, output_array_identifier=output_array_identifier, ) diff --git a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py index 4ab60457a..4a992ced2 100644 --- a/dacapo/experiments/tasks/post_processors/dummy_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/dummy_post_processor.py @@ -18,7 +18,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]: for i, min_size in enumerate(range(1, 11)): yield DummyPostProcessorParameters(id=i, min_size=min_size) - def set_prediction(self, prediction_array): + def set_prediction(self, prediction_array_identifier): pass def process(self, parameters, output_array_identifier, *args, **kwargs): diff --git a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py index 4c337ed62..5d3b45220 100644 --- a/dacapo/experiments/tasks/post_processors/threshold_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/threshold_post_processor.py @@ -2,15 +2,13 @@ from dacapo.blockwise.scheduler import run_blockwise from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray from .threshold_post_processor_parameters import ThresholdPostProcessorParameters +from dacapo.store.array_store import LocalArrayIdentifier from .post_processor 
import PostProcessor import dacapo.blockwise import numpy as np from daisy import Roi, Coordinate -from typing import TYPE_CHECKING, Iterable - -# if TYPE_CHECKING: -from dacapo.store.local_array_store import LocalArrayIdentifier +from typing import Iterable class ThresholdPostProcessor(PostProcessor): @@ -22,7 +20,8 @@ def enumerate_parameters(self) -> Iterable["ThresholdPostProcessorParameters"]: for i, threshold in enumerate([-0.1, 0.0, 0.1]): yield ThresholdPostProcessorParameters(id=i, threshold=threshold) - def set_prediction(self, prediction_array_identifier: "LocalArrayIdentifier"): + def set_prediction(self, prediction_array_identifier): + self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -67,9 +66,7 @@ def process( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - input_array_identifier=LocalArrayIdentifier( - self.prediction_array.file_name, self.prediction_array.dataset - ), + input_array_identifier=self.prediction_array_identifier, output_array_identifier=output_array_identifier, threshold=parameters.threshold, ) diff --git a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py index 7ada70a6a..fa9d10a47 100644 --- a/dacapo/experiments/tasks/post_processors/watershed_post_processor.py +++ b/dacapo/experiments/tasks/post_processors/watershed_post_processor.py @@ -27,6 +27,7 @@ def enumerate_parameters(self): yield WatershedPostProcessorParameters(id=i, bias=bias) def set_prediction(self, prediction_array_identifier): + self.prediction_array_identifier = prediction_array_identifier self.prediction_array = ZarrArray.open_from_array_identifier( prediction_array_identifier ) @@ -66,9 +67,7 @@ def process( max_retries=2, # TODO: make this an option timeout=None, # TODO: make this an option ###### - input_array_identifier=LocalArrayIdentifier( - self.prediction_array.file_name, self.prediction_array.dataset - ), + input_array_identifier=self.prediction_array_identifier, output_array_identifier=output_array_identifier, parameters=pars, ) diff --git a/tests/fixtures/post_processors.py b/tests/fixtures/post_processors.py index 95f51abbf..61d01e6d8 100644 --- a/tests/fixtures/post_processors.py +++ b/tests/fixtures/post_processors.py @@ -1,7 +1,6 @@ from dacapo.experiments.tasks.post_processors import ( ArgmaxPostProcessor, ThresholdPostProcessor, - PostProcessorParameters, ) import pytest
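With the patches above applied, predict() accepts either a plain output path or a LocalArrayIdentifier. A hedged usage sketch (the run name, iteration, and paths are placeholders, not taken from the repository):

    from pathlib import Path
    from dacapo.predict import predict
    from dacapo.store.local_array_store import LocalArrayIdentifier

    # output_path as a plain directory: an <input stem>.zarr container holding a
    # prediction_<run>_<iteration> dataset is created inside it; weights for this
    # run and iteration must already be stored, otherwise a ValueError is raised
    predict(
        run_name="example_run",                  # hypothetical run name
        iteration=1000,
        input_container=Path("/data/raw.zarr"),  # hypothetical input container
        input_dataset="volumes/raw",
        output_path="/data/predictions",
        num_workers=4,
        output_dtype="uint8",
        overwrite=True,
    )

    # output_path as a LocalArrayIdentifier pointing at an exact container/dataset
    out = LocalArrayIdentifier(Path("/data/predictions.zarr"), "prediction_example")
    predict("example_run", 1000, Path("/data/raw.zarr"), "volumes/raw", output_path=out)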