diff --git a/dacapo/blockwise/predict_worker.py b/dacapo/blockwise/predict_worker.py
index 9ec06d298..bdaa8fc53 100644
--- a/dacapo/blockwise/predict_worker.py
+++ b/dacapo/blockwise/predict_worker.py
@@ -113,7 +113,6 @@ def start_worker_fn(
         device = compute_context.device
 
     # retrieving run
-    logger.error(f"run_name: {run_name} {type(run_name)}")
     if isinstance(run_name, Run):
         run = run_name
         run_name = run.name
diff --git a/dacapo/cli.py b/dacapo/cli.py
index f4a2a6c41..f05c32012 100644
--- a/dacapo/cli.py
+++ b/dacapo/cli.py
@@ -2,7 +2,7 @@
 from typing import Optional
 
 import numpy as np
-
+import yaml
 import dacapo
 import click
 import logging
@@ -17,6 +17,8 @@
 )
 from dacapo.store.local_array_store import LocalArrayIdentifier
 from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
+from dacapo.options import DaCapoConfig
+import os
 
 
 @click.group()
@@ -686,6 +688,125 @@ def segment_blockwise(
     )
 
 
+def prompt_with_choices(prompt_text, choices, default_index=0):
+    """
+    Prompts the user with a list of choices and returns the selected choice.
+
+    Args:
+        prompt_text (str): The prompt text to display to the user.
+        choices (list): The list of choices to present.
+        default_index (int): The index of the default choice (0-based).
+
+    Returns:
+        str: The selected choice.
+    """
+    while True:
+        click.echo(prompt_text)
+        for i, choice in enumerate(choices, 1):
+            click.echo(f"{i} - {choice}")
+
+        # Clamp default_index into the valid range of choices
+        default_index = max(0, min(default_index, len(choices) - 1))
+
+        try:
+            # Prompt the user for input
+            choice_num = click.prompt(
+                f"Enter your choice (default: {choices[default_index]})",
+                default=default_index + 1,
+                type=int,
+            )
+
+            # Check if the provided number is valid
+            if 1 <= choice_num <= len(choices):
+                return choices[choice_num - 1]
+            else:
+                click.echo("Invalid choice number. Please try again.")
+        except click.BadParameter:
+            click.echo("Invalid input. Please enter a number.")
+
+
+@cli.command()
+def config():
+    """Interactively create a dacapo.yaml configuration file."""
+    if os.path.exists("dacapo.yaml"):
+        overwrite = click.confirm(
+            "dacapo.yaml already exists. Do you want to overwrite it?", default=False
+        )
+        if not overwrite:
+            click.echo("Aborting configuration creation.")
+            return
+    runs_base_dir = click.prompt("Enter the base directory for runs", type=str)
+    storage_type = prompt_with_choices("Enter the type of storage:", ["files", "mongo"])
+    mongo_db_name = None
+    mongo_db_host = None
+    if storage_type == "mongo":
+        mongo_db_name = click.prompt("Enter the name of the MongoDB database", type=str)
+        mongo_db_host = click.prompt("Enter the MongoDB host URI", type=str)
+
+    compute_type = prompt_with_choices(
+        "Enter the type of compute context:", ["LocalTorch", "Bsub"]
+    )
+    if compute_type == "Bsub":
+        queue = click.prompt("Enter the queue for compute context", type=str)
+        num_gpus = click.prompt("Enter the number of GPUs", type=int)
+        num_cpus = click.prompt("Enter the number of CPUs", type=int)
+        billing = click.prompt("Enter the billing account", type=str)
+        compute_context = {
+            "type": compute_type,
+            "config": {
+                "queue": queue,
+                "num_gpus": num_gpus,
+                "num_cpus": num_cpus,
+                "billing": billing,
+            },
+        }
+    else:
+        compute_context = {"type": compute_type}
+
+    try:
+        generate_config(
+            runs_base_dir,
+            storage_type,
+            compute_type,
+            compute_context,
+            mongo_db_name,
+            mongo_db_host,
+        )
+    except ValueError as e:
+        logger.error(str(e))
+
+
+def generate_dacapo_yaml(config):
+    with open("dacapo.yaml", "w") as f:
+        yaml.dump(config.serialize(), f, default_flow_style=False)
+    print("dacapo.yaml has been created.")
+
+
+def generate_config(
+    runs_base_dir,
+    storage_type,
+    compute_type,
+    compute_context,
+    mongo_db_name=None,
+    mongo_db_host=None,
+):
+    config = DaCapoConfig(
+        type=storage_type,
+        runs_base_dir=Path(runs_base_dir).expanduser(),
+        compute_context=compute_context,
+    )
+
+    if storage_type == "mongo":
+        if not mongo_db_name or not mongo_db_host:
+            raise ValueError(
+                "--mongo_db_name and --mongo_db_host are required when type is 'mongo'"
+            )
+        config.mongo_db_name = mongo_db_name
+        config.mongo_db_host = mongo_db_host
+
+    generate_dacapo_yaml(config)
+
+
 def unpack_ctx(ctx):
     """
     Unpacks the context object and returns a dictionary of keyword arguments.
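
For a quick sanity check of the new `dacapo config` command without the interactive prompts, the underlying helper can be driven directly. A minimal sketch, assuming this branch is installed; the directory and values are illustrative:

```python
# Sketch: call the new helper non-interactively (values are illustrative).
from dacapo.cli import generate_config

generate_config(
    runs_base_dir="~/dacapo_runs",           # expanded via Path.expanduser()
    storage_type="files",                    # "files" or "mongo"
    compute_type="LocalTorch",               # "LocalTorch" or "Bsub"
    compute_context={"type": "LocalTorch"},
)
# Writes ./dacapo.yaml from DaCapoConfig.serialize(), roughly:
#   compute_context:
#     type: LocalTorch
#   runs_base_dir: /home/user/dacapo_runs
#   type: files
```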
diff --git a/dacapo/experiments/datasplits/datasplit_generator.py b/dacapo/experiments/datasplits/datasplit_generator.py
index a1a0e45c3..3f0b908ef 100644
--- a/dacapo/experiments/datasplits/datasplit_generator.py
+++ b/dacapo/experiments/datasplits/datasplit_generator.py
@@ -75,6 +75,10 @@ def resize_if_needed(
         raw_upsample = raw_voxel_size / target_resolution
         raw_downsample = target_resolution / raw_voxel_size
+        assert len(target_resolution) == zarr_array.dims, (
+            f"Target resolution {target_resolution} does not match the number of "
+            f"dimensions ({zarr_array.dims}) of {array_config.name} with voxel size {raw_voxel_size}"
+        )
         if any([u > 1 or d > 1 for u, d in zip(raw_upsample, raw_downsample)]):
             return ResampledArrayConfig(
                 name=f"{extra_str}_{array_config.name}_{array_config.dataset}_resampled",
@@ -653,7 +657,7 @@ def check_class_name(self, class_name):
         """
 
         datasets, classes = format_class_name(
-            class_name, self.classes_separator_caracter
+            class_name, self.classes_separator_caracter, self.targets
         )
         if self.class_name is None:
             self.class_name = classes
@@ -893,7 +897,7 @@ def generate_from_csv(
     )
 
 
-def format_class_name(class_name, separator_character="&"):
+def format_class_name(class_name, separator_character="&", targets=None):
     """
     Format the class name.
 
@@ -919,4 +923,7 @@ def format_class_name(class_name, separator_character="&"):
         base_class_name = class_name.split("[")[0]
         return [f"{base_class_name}{c}" for c in classes], classes
     else:
-        raise ValueError(f"Invalid class name {class_name} missing '[' and ']'")
+        if targets is None or len(targets) > 1:
+            raise ValueError(f"Invalid class name {class_name} missing '[' and ']'")
+        return [class_name], [targets[0]]
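
The `targets` fallback added to `format_class_name` is easiest to see in isolation. A sketch against this branch, using the existing `name[a&b]` bracket convention:

```python
from dacapo.experiments.datasplits.datasplit_generator import format_class_name

# Bracketed names expand exactly as before:
format_class_name("gt[mito&er]")
# -> (["gtmito", "gter"], ["mito", "er"])

# New fallback: a bare name is accepted when a single target class
# is known from the datasplit generator:
format_class_name("labels", targets=["mito"])
# -> (["labels"], ["mito"])

# Without that hint (or with several targets) it still raises:
format_class_name("labels")  # ValueError
```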
diff --git a/dacapo/experiments/training_stats.py b/dacapo/experiments/training_stats.py
index eef5f2c97..1b84a7424 100644
--- a/dacapo/experiments/training_stats.py
+++ b/dacapo/experiments/training_stats.py
@@ -5,6 +5,9 @@
 from typing import List
 
 import attr
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 @attr.s
@@ -62,9 +65,11 @@ def add_iteration_stats(self, iteration_stats: TrainingIterationStats) -> None:
             - The inner list contains the stats for each training iteration.
         """
         if len(self.iteration_stats) > 0:
-            assert (
-                iteration_stats.iteration == self.iteration_stats[-1].iteration + 1
-            ), f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}"
+            if iteration_stats.iteration <= self.iteration_stats[-1].iteration:
+                logger.error(
+                    f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}. Will remove stats after iteration {iteration_stats.iteration - 1}."
+                )
+                self.delete_after(iteration_stats.iteration - 1)
         self.iteration_stats.append(iteration_stats)
 
diff --git a/dacapo/options.py b/dacapo/options.py
index fb40b94b2..e52a0c57a 100644
--- a/dacapo/options.py
+++ b/dacapo/options.py
@@ -73,7 +73,8 @@ def serialize(self):
-            {'type': 'files', 'runs_base_dir': '/home/user/dacapo', 'compute_context': {'type': 'LocalTorch', 'config': {}}, 'mongo_db_host': None, 'mongo_db_name': None}
+            {'type': 'files', 'runs_base_dir': '/home/user/dacapo', 'compute_context': {'type': 'LocalTorch', 'config': {}}}
         """
         converter = Converter()
-        return converter.unstructure(self)
+        data = converter.unstructure(self)
+        return {k: v for k, v in data.items() if v is not None}
 
 
 class Options:
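
The effect of the `serialize()` change is that unset optional fields no longer leak into `dacapo.yaml` (which `generate_dacapo_yaml` writes from this dict). A sketch, assuming the default fields shown in the docstring above:

```python
from pathlib import Path
from dacapo.options import DaCapoConfig

config = DaCapoConfig(
    type="files",
    runs_base_dir=Path("/home/user/dacapo"),
    compute_context={"type": "LocalTorch", "config": {}},
)

# mongo_db_host / mongo_db_name were never set, so their None entries
# are filtered out of the serialized dict:
print(config.serialize())
# {'type': 'files', 'runs_base_dir': '/home/user/dacapo',
#  'compute_context': {'type': 'LocalTorch', 'config': {}}}
```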
diff --git a/dacapo/store/converter.py b/dacapo/store/converter.py
index 62bb2f4df..7e5451b3c 100644
--- a/dacapo/store/converter.py
+++ b/dacapo/store/converter.py
@@ -121,10 +121,11 @@ class from unstructured data.
             cls = cls_fn(obj_data["__type__"])
             structure_fn = make_dict_structure_fn(cls, self)
             return structure_fn(obj_data, cls)
-        except:
+        except Exception as e:
             print(
                 f"Could not structure object of type {obj_data}. will try unstructured data. attr __type__ can be missing because of old version of the data."
             )
+            print(e)
             return obj_data
diff --git a/dacapo/train.py b/dacapo/train.py
index 9a2b5474c..e5c2ae4cc 100644
--- a/dacapo/train.py
+++ b/dacapo/train.py
@@ -47,7 +47,7 @@ def train(run_name: str):
     return train_run(run)
 
 
-def train_run(run: Run):
+def train_run(run: Run, do_validate: bool = True):
     """
     Train a run
 
@@ -184,30 +184,31 @@ def train_run(run: Run):
         stats_store.store_training_stats(run.name, run.training_stats)
         weights_store.store_weights(run, iteration_stats.iteration + 1)
 
-        try:
-            # launch validation in a separate thread to avoid blocking training
-            if compute_context.distribute_workers:
-                validate_thread = threading.Thread(
-                    target=validate,
-                    args=(run, iteration_stats.iteration + 1),
-                    name=f"validate_{run.name}_{iteration_stats.iteration + 1}",
-                    daemon=True,
+        if do_validate:
+            try:
+                # launch validation in a separate thread to avoid blocking training
+                if compute_context.distribute_workers:
+                    validate_thread = threading.Thread(
+                        target=validate,
+                        args=(run, iteration_stats.iteration + 1),
+                        name=f"validate_{run.name}_{iteration_stats.iteration + 1}",
+                        daemon=True,
+                    )
+                    validate_thread.start()
+                else:
+                    validate(
+                        run,
+                        iteration_stats.iteration + 1,
+                    )
+
+                stats_store.store_validation_iteration_scores(
+                    run.name, run.validation_scores
                 )
-                validate_thread.start()
-            else:
-                validate(
-                    run,
-                    iteration_stats.iteration + 1,
+            except Exception as e:
+                logger.error(
+                    f"Validation failed for run {run.name} at iteration "
+                    f"{iteration_stats.iteration + 1}.",
+                    exc_info=e,
                 )
 
-            stats_store.store_validation_iteration_scores(
-                run.name, run.validation_scores
-            )
-        except Exception as e:
-            logger.error(
-                f"Validation failed for run {run.name} at iteration "
-                f"{iteration_stats.iteration + 1}.",
-                exc_info=e,
-            )
-
     print(f"Trained until {trained_until}. Finished.")
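
With the new `do_validate` flag, in-loop validation becomes opt-out. A sketch of the intended call pattern, assuming a stored run config; the run name "my_run" is hypothetical:

```python
from dacapo.experiments import Run
from dacapo.store.create_store import create_config_store
from dacapo.train import train_run

config_store = create_config_store()
run = Run(config_store.retrieve_run_config("my_run"))  # "my_run" is hypothetical

# Train without launching validation after each checkpoint; scores can
# be generated later with dacapo.validate(...).
train_run(run, do_validate=False)
```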