better datasplit generator (#272)
mzouink committed Jul 22, 2024
2 parents 5f6b7ea + 8ca78c4 commit 33c6f52
Showing 7 changed files with 169 additions and 34 deletions.
1 change: 0 additions & 1 deletion dacapo/blockwise/predict_worker.py
@@ -113,7 +113,6 @@ def start_worker_fn(
device = compute_context.device

# retrieving run
logger.error(f"run_name: {run_name} {type(run_name)}")
if isinstance(run_name, Run):
run = run_name
run_name = run.name
122 changes: 121 additions & 1 deletion dacapo/cli.py
@@ -2,7 +2,7 @@
from typing import Optional

import numpy as np

import yaml
import dacapo
import click
import logging
@@ -17,6 +17,8 @@
)
from dacapo.store.local_array_store import LocalArrayIdentifier
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.options import DaCapoConfig
import os


@click.group()
@@ -686,6 +688,124 @@ def segment_blockwise(
)


def prompt_with_choices(prompt_text, choices, default_index=0):
"""
Prompts the user with a list of choices and returns the selected choice.
Args:
prompt_text (str): The prompt text to display to the user.
choices (list): The list of choices to present.
default_index (int): The index of the default choice (0-based).
Returns:
str: The selected choice.
"""
while True:
click.echo(prompt_text)
for i, choice in enumerate(choices, 1):
click.echo(f"{i} - {choice}")

# If the default_index is out of range, set to 0
default_index = max(0, min(default_index, len(choices) - 1))

try:
# Prompt the user for input
choice_num = click.prompt(
f"Enter your choice (default: {choices[default_index]})",
default=default_index + 1,
type=int,
)

# Check if the provided number is valid
if 1 <= choice_num <= len(choices):
return choices[choice_num - 1]
else:
click.echo("Invalid choice number. Please try again.")
except click.BadParameter:
click.echo("Invalid input. Please enter a number.")


@cli.command()
def config():
if os.path.exists("dacapo.yaml"):
overwrite = click.confirm(
"dacapo.yaml already exists. Do you want to overwrite it?", default=False
)
if not overwrite:
click.echo("Aborting configuration creation.")
return
runs_base_dir = click.prompt("Enter the base directory for runs", type=str)
storage_type = prompt_with_choices("Enter the type of storage:", ["files", "mongo"])
mongo_db_name = None
mongo_db_host = None
if storage_type == "mongo":
mongo_db_name = click.prompt("Enter the name of the MongoDB database", type=str)
mongo_db_host = click.prompt("Enter the MongoDB host URI", type=str)

compute_type = prompt_with_choices(
"Enter the type of compute context:", ["LocalTorch", "Bsub"]
)
if compute_type == "Bsub":
queue = click.prompt("Enter the queue for compute context", type=str)
num_gpus = click.prompt("Enter the number of GPUs", type=int)
num_cpus = click.prompt("Enter the number of CPUs", type=int)
billing = click.prompt("Enter the billing account", type=str)
compute_context = {
"type": compute_type,
"config": {
"queue": queue,
"num_gpus": num_gpus,
"num_cpus": num_cpus,
"billing": billing,
},
}
else:
compute_context = {"type": compute_type}

try:
generate_config(
runs_base_dir,
storage_type,
compute_type,
compute_context,
mongo_db_name,
mongo_db_host,
)
except ValueError as e:
logger.error(str(e))


def generate_dacapo_yaml(config):
with open("dacapo.yaml", "w") as f:
yaml.dump(config.serialize(), f, default_flow_style=False)
print("dacapo.yaml has been created.")


def generate_config(
runs_base_dir,
storage_type,
compute_type,
compute_context,
mongo_db_name=None,
mongo_db_host=None,
):
config = DaCapoConfig(
type=storage_type,
runs_base_dir=Path(runs_base_dir).expanduser(),
compute_context=compute_context,
)

if storage_type == "mongo":
if not mongo_db_name or not mongo_db_host:
raise ValueError(
"--mongo_db_name and --mongo_db_host are required when type is 'mongo'"
)
config.mongo_db_name = mongo_db_name
config.mongo_db_host = mongo_db_host

generate_dacapo_yaml(config)


def unpack_ctx(ctx):
"""
Unpacks the context object and returns a dictionary of keyword arguments.
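For reference, the new `dacapo config` command simply collects the answers to the prompts above and hands them to generate_config. A minimal non-interactive sketch of the same flow, assuming a local "files" setup (all values are illustrative):

from dacapo.cli import generate_config

generate_config(
    runs_base_dir="~/dacapo_runs",            # expanded with Path(...).expanduser()
    storage_type="files",                     # "files" or "mongo"
    compute_type="LocalTorch",                # "LocalTorch" or "Bsub"
    compute_context={"type": "LocalTorch"},
    # mongo_db_name / mongo_db_host are only required when storage_type == "mongo"
)

# Writes dacapo.yaml in the current directory, roughly:
#   type: files
#   runs_base_dir: /home/<user>/dacapo_runs
#   compute_context:
#     type: LocalTorch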
14 changes: 11 additions & 3 deletions dacapo/experiments/datasplits/datasplit_generator.py
@@ -75,6 +75,10 @@ def resize_if_needed(

raw_upsample = raw_voxel_size / target_resolution
raw_downsample = target_resolution / raw_voxel_size
assert len(target_resolution) == zarr_array.dims, (
f"Target resolution {target_resolution} and raw voxel size {raw_voxel_size} "
f"have different dimensions {zarr_array.dims}"
)
if any([u > 1 or d > 1 for u, d in zip(raw_upsample, raw_downsample)]):
return ResampledArrayConfig(
name=f"{extra_str}_{array_config.name}_{array_config.dataset}_resampled",
@@ -653,7 +657,7 @@ def check_class_name(self, class_name):
"""
datasets, classes = format_class_name(
-        class_name, self.classes_separator_caracter
+        class_name, self.classes_separator_caracter, self.targets
)
if self.class_name is None:
self.class_name = classes
@@ -893,7 +897,7 @@ def generate_from_csv(
)


-def format_class_name(class_name, separator_character="&"):
+def format_class_name(class_name, separator_character="&", targets=None):
"""
Format the class name.
@@ -919,4 +923,8 @@ def format_class_name(class_name, separator_character="&"):
base_class_name = class_name.split("[")[0]
return [f"{base_class_name}{c}" for c in classes], classes
else:
raise ValueError(f"Invalid class name {class_name} missing '[' and ']'")
if targets is None:
raise ValueError(f"Invalid class name {class_name} missing '[' and ']'")
if len(targets) > 1:
raise ValueError(f"Invalid class name {class_name} missing '[' and ']'")
return [class_name], [targets[0]]
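The extra targets argument gives plain class names (without brackets) a defined meaning when exactly one target is configured. A small sketch of the behaviour implied by the diff above (return values are approximate):

from dacapo.experiments.datasplits.datasplit_generator import format_class_name

# bracketed names expand as before, one dataset name per class
format_class_name("labels/[mito&nucleus]")
# -> (["labels/mito", "labels/nucleus"], ["mito", "nucleus"])

# a plain name is now accepted when a single target is given
format_class_name("labels/mito", targets=["mito"])
# -> (["labels/mito"], ["mito"])

# a plain name with targets=None, or with more than one target, still raises ValueError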
11 changes: 8 additions & 3 deletions dacapo/experiments/training_stats.py
@@ -5,6 +5,9 @@

from typing import List
import attr
import logging

logger = logging.getLogger(__name__)


@attr.s
@@ -62,9 +65,11 @@ def add_iteration_stats(self, iteration_stats: TrainingIterationStats) -> None:
- The inner list contains the stats for each training iteration.
"""
if len(self.iteration_stats) > 0:
-    assert (
-        iteration_stats.iteration == self.iteration_stats[-1].iteration + 1
-    ), f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}"
+    if iteration_stats.iteration <= self.iteration_stats[-1].iteration:
+        logger.error(
+            f"Expected iteration {self.iteration_stats[-1].iteration + 1}, got {iteration_stats.iteration}. will remove stats after {iteration_stats.iteration-1}"
+        )
+        self.delete_after(iteration_stats.iteration - 1)

self.iteration_stats.append(iteration_stats)

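In effect, re-adding an already-seen iteration now rolls the stats back instead of raising. A short sketch (field names follow the attrs classes in this module; the values are illustrative):

from dacapo.experiments.training_stats import TrainingStats, TrainingIterationStats

stats = TrainingStats()
stats.add_iteration_stats(TrainingIterationStats(iteration=0, loss=1.0, time=0.1))
stats.add_iteration_stats(TrainingIterationStats(iteration=1, loss=0.9, time=0.1))

# previously an AssertionError; now logs an error, drops iterations >= 1
# via delete_after(0), and appends the new entry
stats.add_iteration_stats(TrainingIterationStats(iteration=1, loss=0.8, time=0.1))
assert stats.iteration_stats[-1].iteration == 1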
3 changes: 2 additions & 1 deletion dacapo/options.py
@@ -73,7 +73,8 @@ def serialize(self):
{'type': 'files', 'runs_base_dir': '/home/user/dacapo', 'compute_context': {'type': 'LocalTorch', 'config': {}}, 'mongo_db_host': None, 'mongo_db_name': None}
"""
converter = Converter()
-    return converter.unstructure(self)
+    data = converter.unstructure(self)
+    return {k: v for k, v in data.items() if v is not None}


class Options:
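With the None-filtering in place, unset Mongo fields no longer leak into dacapo.yaml. A rough sketch, reusing the field values from the docstring example above:

from pathlib import Path
from dacapo.options import DaCapoConfig

config = DaCapoConfig(
    type="files",
    runs_base_dir=Path("/home/user/dacapo"),
    compute_context={"type": "LocalTorch", "config": {}},
)

config.serialize()
# before: {..., 'mongo_db_host': None, 'mongo_db_name': None}
# after:  {'type': 'files', 'runs_base_dir': '/home/user/dacapo',
#          'compute_context': {'type': 'LocalTorch', 'config': {}}}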
3 changes: 2 additions & 1 deletion dacapo/store/converter.py
@@ -121,10 +121,11 @@ class from unstructured data.
cls = cls_fn(obj_data["__type__"])
structure_fn = make_dict_structure_fn(cls, self)
return structure_fn(obj_data, cls)
-    except:
+    except Exception as e:
print(
f"Could not structure object of type {obj_data}. will try unstructured data. attr __type__ can be missing because of old version of the data."
)
print(e)
return obj_data


49 changes: 25 additions & 24 deletions dacapo/train.py
@@ -47,7 +47,7 @@ def train(run_name: str):
return train_run(run)


-def train_run(run: Run):
+def train_run(run: Run, do_validate=True):
"""
Train a run
@@ -184,30 +184,31 @@ def train_run(run: Run):

stats_store.store_training_stats(run.name, run.training_stats)
weights_store.store_weights(run, iteration_stats.iteration + 1)
-try:
-    # launch validation in a separate thread to avoid blocking training
-    if compute_context.distribute_workers:
-        validate_thread = threading.Thread(
-            target=validate,
-            args=(run, iteration_stats.iteration + 1),
-            name=f"validate_{run.name}_{iteration_stats.iteration + 1}",
-            daemon=True,
-        )
-        validate_thread.start()
-    else:
-        validate(
-            run,
-            iteration_stats.iteration + 1,
-        )
-
-    stats_store.store_validation_iteration_scores(
-        run.name, run.validation_scores
-    )
-except Exception as e:
-    logger.error(
-        f"Validation failed for run {run.name} at iteration "
-        f"{iteration_stats.iteration + 1}.",
-        exc_info=e,
-    )
+if do_validate:
+    try:
+        # launch validation in a separate thread to avoid blocking training
+        if compute_context.distribute_workers:
+            validate_thread = threading.Thread(
+                target=validate,
+                args=(run, iteration_stats.iteration + 1),
+                name=f"validate_{run.name}_{iteration_stats.iteration + 1}",
+                daemon=True,
+            )
+            validate_thread.start()
+        else:
+            validate(
+                run,
+                iteration_stats.iteration + 1,
+            )
+
+        stats_store.store_validation_iteration_scores(
+            run.name, run.validation_scores
+        )
+    except Exception as e:
+        logger.error(
+            f"Validation failed for run {run.name} at iteration "
+            f"{iteration_stats.iteration + 1}.",
+            exc_info=e,
+        )

print(f"Trained until {trained_until}. Finished.")
