Validate patch #278

Merged: 23 commits, Aug 5, 2024
Changes from 13 commits
2 changes: 1 addition & 1 deletion dacapo/__init__.py
@@ -5,6 +5,6 @@
from . import experiments, utils # noqa
from .apply import apply # noqa
from .train import train # noqa
from .validate import validate # noqa
from .validate import validate, validate_run # noqa
from .predict import predict # noqa
from .blockwise import run_blockwise, segment_blockwise # noqa
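
Note: with validate_run exported at the package level, validation can also be driven from Python. A minimal sketch, assuming the same positional arguments the CLI forwards further down in this diff (the run name, iteration, and option values here are hypothetical):

import dacapo

# Arguments mirror the CLI forwarding in dacapo/cli.py:
# run_name, iteration, num_workers, output_dtype, overwrite
dacapo.validate_run("my_run", 50000, 30, "uint8", True)
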
1 change: 1 addition & 0 deletions dacapo/blockwise/__init__.py
@@ -1,2 +1,3 @@
from .blockwise_task import DaCapoBlockwiseTask
from .scheduler import run_blockwise, segment_blockwise
from . import global_vars
1 change: 1 addition & 0 deletions dacapo/blockwise/global_vars.py
@@ -0,0 +1 @@
current_run = None
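
Note: this module-level variable lets a process that already holds a Run object share it with in-process prediction workers, which check it via is_global_run_set in predict_worker.py below. A sketch of the setter side, which presumably happens in validation code from commits not shown here (the helper name is hypothetical):

from dacapo.blockwise import global_vars

def share_run_with_workers(run):
    # Stash the already-built Run so predict_worker can reuse it
    # (via is_global_run_set) instead of rebuilding it from the config store.
    global_vars.current_run = run
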
185 changes: 101 additions & 84 deletions dacapo/blockwise/predict_worker.py
@@ -17,6 +17,7 @@

import numpy as np
import click
from dacapo.blockwise import global_vars

import logging

@@ -27,6 +28,17 @@
path = __file__


def is_global_run_set(run_name) -> bool:
found = global_vars.current_run is not None
if found:
found = global_vars.current_run.name == run_name
if not found:
logger.error(
f"Found global run {global_vars.current_run.name} but looking for {run_name}"
)
return found


@click.group()
@click.option(
"--log-level",
@@ -70,7 +82,7 @@ def cli(log_level):
)
@click.option("-od", "--output_dataset", required=True, type=str)
def start_worker(
run_name: str | Run,
run_name: str,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
@@ -90,7 +102,7 @@ def start_worker(


def start_worker_fn(
run_name: str | Run,
run_name: str,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
@@ -109,93 +121,97 @@ def start_worker_fn(
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
"""
compute_context = create_compute_context()
device = compute_context.device

# retrieving run
if isinstance(run_name, Run):
run = run_name
run_name = run.name
else:
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
def io_loop():
daisy_client = daisy.Client()

if iteration is not None:
# create weights store
weights_store = create_weights_store()
compute_context = create_compute_context()
device = compute_context.device

if is_global_run_set(run_name):
logger.warning("Using global run variable")
run = global_vars.current_run
else:
logger.warning("initiating local run in predict_worker")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

if iteration is not None and compute_context.distribute_workers:
# create weights store
weights_store = create_weights_store()

# load weights
run.model.load_state_dict(
weights_store.retrieve_weights(run_name, iteration).model
)

# load weights
run.model.load_state_dict(
weights_store.retrieve_weights(run_name, iteration).model
# get arrays
input_array_identifier = LocalArrayIdentifier(
Path(input_container), input_dataset
)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# set benchmark flag to True for performance
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval().to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

print(f"Predicting with input size {input_size}, output size {output_size}")

# create gunpowder keys

raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICTION")

# assemble prediction pipeline

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
# model.eval()
pipeline += gp_torch.Predict(
model=model,
inputs={"x": raw},
outputs={0: prediction},
array_specs={
prediction: gp.ArraySpec(
voxel_size=output_voxel_size,
dtype=np.float32, # assumes network output is float32
)
},
spawn_subprocess=False,
device=str(device),
)

# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)
output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# set benchmark flag to True for performance
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval()
# .to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

print(f"Predicting with input size {input_size}, output size {output_size}")

# create gunpowder keys

raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICTION")

# assemble prediction pipeline

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
# model.eval()
pipeline += gp_torch.Predict(
model=model,
inputs={"x": raw},
outputs={0: prediction},
array_specs={
prediction: gp.ArraySpec(
voxel_size=output_voxel_size,
dtype=np.float32, # assumes network output is float32
)
},
spawn_subprocess=False,
device=str(device),
)

def io_loop():
daisy_client = daisy.Client()
# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)

while True:
with daisy_client.acquire_block() as block:
@@ -231,7 +247,7 @@ def io_loop():


def spawn_worker(
run_name: str | Run,
run_name: str,
iteration: int | None,
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
@@ -248,6 +264,7 @@ def spawn_worker(
Callable: The function to run the worker.
"""
compute_context = create_compute_context()

if not compute_context.distribute_workers:
return start_worker_fn(
run_name=run_name,
4 changes: 2 additions & 2 deletions dacapo/cli.py
@@ -94,8 +94,8 @@ def train(run_name):
@click.option("-w", "--num_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
@click.option("-ow", "--overwrite", is_flag=True)
def validate(run_name, iteration):
dacapo.validate(run_name, iteration)
def validate(run_name, iteration, num_workers, output_dtype, overwrite):
dacapo.validate_run(run_name, iteration, num_workers, output_dtype, overwrite)


@cli.command()
@@ -71,6 +71,7 @@ def __init__(self, array_config):
`like` method to create a new OnesArray with the same metadata as
another array.
"""
logger.warning("OnesArray is deprecated. Use ConstantArray instead.")
self._source_array = array_config.source_array_config.array_type(
array_config.source_array_config
)
@@ -406,5 +407,4 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
specified by the region of interest. This method returns a subarray
of the array with all values set to 1.
"""
logger.warning("OnesArray is deprecated. Use ConstantArray instead.")
return np.ones_like(self.source_array.__getitem__(roi), dtype=bool)
16 changes: 11 additions & 5 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -369,11 +369,17 @@ def data(self) -> Any:
"""
file_name = str(self.file_name)
# Zarr library does not detect the store for N5 datasets
if file_name.endswith(".n5"):
zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
else:
zarr_container = zarr.open(str(file_name), mode=self.mode)
return zarr_container[self.dataset]
try:
if file_name.endswith(".n5"):
zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
else:
zarr_container = zarr.open(str(file_name), mode=self.mode)
return zarr_container[self.dataset]
except Exception as e:
logger.error(
f"Could not open dataset {self.dataset} in file {file_name} in mode {self.mode}"
)
raise e
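
Note: the N5 branch exists because the zarr library does not infer an N5 store from a ".n5" path on its own. A standalone sketch of the same open logic (the container path is hypothetical):

import zarr
from zarr.n5 import N5FSStore

file_name = "/path/to/container.n5"  # hypothetical path
if file_name.endswith(".n5"):
    # N5 containers need an explicit N5FSStore; zarr.open(path) alone would not detect them.
    container = zarr.open(N5FSStore(file_name), mode="r")
else:
    container = zarr.open(file_name, mode="r")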

def __getitem__(self, roi: Roi) -> np.ndarray:
"""
20 changes: 13 additions & 7 deletions dacapo/experiments/datasplits/datasplit_generator.py
@@ -616,6 +616,11 @@ def class_name(self):
Notes:
This function is used to get the class name.
"""
if self._class_name is None:
if self.targets is None:
logger.warning("Both targets and class name are None.")
return None
self._class_name = self.targets
return self._class_name
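
Note: the effect of this change is that class_name lazily falls back to targets the first time it is read, and stays pinned afterwards. A small stand-in illustrating that behavior (this is not the actual DataSplitGenerator class, and it omits the warning log):

class ClassNameFallback:
    # Minimal stand-in reproducing only the fallback logic above.
    def __init__(self, targets=None):
        self.targets = targets
        self._class_name = None

    @property
    def class_name(self):
        if self._class_name is None:
            if self.targets is None:
                return None  # nothing to fall back to
            self._class_name = self.targets  # first read pins class_name to targets
        return self._class_name

print(ClassNameFallback(targets=["mito", "er"]).class_name)  # ['mito', 'er']
print(ClassNameFallback().class_name)                        # None
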

# Goal is to force class_name to be set only once, so we have the same classes for all datasets
@@ -730,10 +735,14 @@ def __generate_semantic_seg_datasplit(self):
gt_config,
mask_config,
) = self.__generate_semantic_seg_dataset_crop(dataset)
if type(self.class_name) == list:
classes = self.classes_separator_caracter.join(self.class_name)
else:
classes = self.class_name
if dataset.dataset_type == DatasetType.train:
train_dataset_configs.append(
RawGTDatasetConfig(
name=f"{dataset}_{self.class_name}_{self.output_resolution[0]}nm",
name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
raw_config=raw_config,
gt_config=gt_config,
mask_config=mask_config,
@@ -742,16 +751,13 @@
else:
validation_dataset_configs.append(
RawGTDatasetConfig(
name=f"{dataset}_{self.class_name}_{self.output_resolution[0]}nm",
name=f"{dataset}_{gt_config.name}_{classes}_{self.output_resolution[0]}nm",
raw_config=raw_config,
gt_config=gt_config,
mask_config=mask_config,
)
)
if type(self.class_name) == list:
classes = self.classes_separator_caracter.join(self.class_name)
else:
classes = self.class_name

return TrainValidateDataSplitConfig(
name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm",
train_configs=train_dataset_configs,
@@ -815,7 +821,7 @@ def __generate_semantic_seg_dataset_crop(self, dataset: DatasetSpec):
organelle_arrays = {}
# classes_datasets, classes = self.check_class_name(gt_dataset)
classes_datasets, classes = format_class_name(
gt_dataset, self.classes_separator_caracter
gt_dataset, self.classes_separator_caracter, self.targets
)
for current_class_dataset, current_class_name in zip(classes_datasets, classes):
if not (gt_path / current_class_dataset).exists():
@@ -68,7 +68,7 @@ def process(
self,
parameters: "ThresholdPostProcessorParameters", # type: ignore[override]
output_array_identifier: "LocalArrayIdentifier",
num_workers: int = 16,
num_workers: int = 12,
block_size: Coordinate = Coordinate((256, 256, 256)),
) -> ZarrArray:
"""
@@ -122,7 +122,7 @@ def process(

read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
# run blockwise post-processing
run_blockwise(
success = run_blockwise(
worker_file=str(
Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py")
),
@@ -138,4 +138,7 @@
threshold=parameters.threshold,
)

if not success:
raise RuntimeError("Blockwise post-processing failed.")

return output_array