Skip to content

Commit

Permalink
Validate patch (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored Aug 5, 2024
2 parents 2ed077e + c529ed2 commit d3a55c9
Show file tree
Hide file tree
Showing 12 changed files with 298 additions and 148 deletions.
2 changes: 1 addition & 1 deletion dacapo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
from . import experiments, utils # noqa
from .apply import apply # noqa
from .train import train # noqa
from .validate import validate # noqa
from .validate import validate, validate_run # noqa
from .predict import predict # noqa
from .blockwise import run_blockwise, segment_blockwise # noqa
1 change: 1 addition & 0 deletions dacapo/blockwise/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .blockwise_task import DaCapoBlockwiseTask
from .scheduler import run_blockwise, segment_blockwise
from . import global_vars
1 change: 1 addition & 0 deletions dacapo/blockwise/global_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
current_run = None
188 changes: 104 additions & 84 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import numpy as np
import click
from dacapo.blockwise import global_vars

import logging

Expand All @@ -27,6 +28,20 @@
path = __file__


def is_global_run_set(run_name) -> bool:
if global_vars.current_run is not None:
if global_vars.current_run.name == run_name:
return True
else:
logger.error(
f"Found global run {global_vars.current_run.name} but looking for {run_name}"
)
return False
else:
logger.error("No global run is set.")
return False


@click.group()
@click.option(
"--log-level",
Expand Down Expand Up @@ -70,7 +85,7 @@ def cli(log_level):
)
@click.option("-od", "--output_dataset", required=True, type=str)
def start_worker(
run_name: str | Run,
run_name: str,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
Expand All @@ -90,7 +105,7 @@ def start_worker(


def start_worker_fn(
run_name: str | Run,
run_name: str,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
Expand All @@ -109,93 +124,97 @@ def start_worker_fn(
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
"""
compute_context = create_compute_context()
device = compute_context.device

# retrieving run
if isinstance(run_name, Run):
run = run_name
run_name = run.name
else:
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
def io_loop():
daisy_client = daisy.Client()

if iteration is not None:
# create weights store
weights_store = create_weights_store()
compute_context = create_compute_context()
device = compute_context.device

if is_global_run_set(run_name):
logger.warning("Using global run variable")
run = global_vars.current_run
else:
logger.warning("initiating local run in predict_worker")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

if iteration is not None and compute_context.distribute_workers:
# create weights store
weights_store = create_weights_store()

# load weights
run.model.load_state_dict(
weights_store.retrieve_weights(run_name, iteration).model
)

# load weights
run.model.load_state_dict(
weights_store.retrieve_weights(run_name, iteration).model
# get arrays
input_array_identifier = LocalArrayIdentifier(
Path(input_container), input_dataset
)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# set benchmark flag to True for performance
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval().to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

print(f"Predicting with input size {input_size}, output size {output_size}")

# create gunpowder keys

raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICTION")

# assemble prediction pipeline

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
# model.eval()
pipeline += gp_torch.Predict(
model=model,
inputs={"x": raw},
outputs={0: prediction},
array_specs={
prediction: gp.ArraySpec(
voxel_size=output_voxel_size,
dtype=np.float32, # assumes network output is float32
)
},
spawn_subprocess=False,
device=str(device),
)

# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)
output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# set benchmark flag to True for performance
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval()
# .to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

print(f"Predicting with input size {input_size}, output size {output_size}")

# create gunpowder keys

raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICTION")

# assemble prediction pipeline

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
# model.eval()
pipeline += gp_torch.Predict(
model=model,
inputs={"x": raw},
outputs={0: prediction},
array_specs={
prediction: gp.ArraySpec(
voxel_size=output_voxel_size,
dtype=np.float32, # assumes network output is float32
)
},
spawn_subprocess=False,
device=str(device),
)

def io_loop():
daisy_client = daisy.Client()
# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)

while True:
with daisy_client.acquire_block() as block:
Expand Down Expand Up @@ -231,7 +250,7 @@ def io_loop():


def spawn_worker(
run_name: str | Run,
run_name: str,
iteration: int | None,
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
Expand All @@ -248,6 +267,7 @@ def spawn_worker(
Callable: The function to run the worker.
"""
compute_context = create_compute_context()

if not compute_context.distribute_workers:
return start_worker_fn(
run_name=run_name,
Expand Down
4 changes: 2 additions & 2 deletions dacapo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def train(run_name):
@click.option("-w", "--num_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
@click.option("-ow", "--overwrite", is_flag=True)
def validate(run_name, iteration):
dacapo.validate(run_name, iteration)
def validate(run_name, iteration, num_workers, output_dtype, overwrite):
dacapo.validate_run(run_name, iteration, num_workers, output_dtype, overwrite)


@cli.command()
Expand Down
16 changes: 11 additions & 5 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,11 +369,17 @@ def data(self) -> Any:
"""
file_name = str(self.file_name)
# Zarr library does not detect the store for N5 datasets
if file_name.endswith(".n5"):
zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
else:
zarr_container = zarr.open(str(file_name), mode=self.mode)
return zarr_container[self.dataset]
try:
if file_name.endswith(".n5"):
zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
else:
zarr_container = zarr.open(str(file_name), mode=self.mode)
return zarr_container[self.dataset]
except Exception as e:
logger.error(
f"Could not open dataset {self.dataset} in file {file_name} in mode {self.mode}"
)
raise e

def __getitem__(self, roi: Roi) -> np.ndarray:
"""
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/datasplits/datasplit_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def __generate_semantic_seg_datasplit(self):
mask_config=mask_config,
)
)

return TrainValidateDataSplitConfig(
name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm",
train_configs=train_dataset_configs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def process(
self,
parameters: "ThresholdPostProcessorParameters", # type: ignore[override]
output_array_identifier: "LocalArrayIdentifier",
num_workers: int = 16,
num_workers: int = 12,
block_size: Coordinate = Coordinate((256, 256, 256)),
) -> ZarrArray:
"""
Expand Down Expand Up @@ -122,7 +122,7 @@ def process(

read_roi = Roi((0, 0, 0), write_size[-self.prediction_array.dims :])
# run blockwise post-processing
run_blockwise(
sucess = run_blockwise(
worker_file=str(
Path(Path(dacapo.blockwise.__file__).parent, "threshold_worker.py")
),
Expand All @@ -138,4 +138,7 @@ def process(
threshold=parameters.threshold,
)

if not sucess:
raise RuntimeError("Blockwise post-processing failed.")

return output_array
Loading

0 comments on commit d3a55c9

Please sign in to comment.