
Commit 5bcd5e0

Merge branch 'main' of github.com:janelia-cellmap/dacapo

rhoadesScholar committed Aug 12, 2024
2 parents 2ed077e + 2e77a18
Showing 14 changed files with 372 additions and 149 deletions.
62 changes: 62 additions & 0 deletions CITATION.cff
@@ -0,0 +1,62 @@

cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Patton"
given-names: "William"
orcid: "https://orcid.org/0000-0002-9652-3222"
- family-names: "Rhoades"
given-names: "Jeff L."
orcid: "https://orcid.org/0000-0001-5077-2533"
- family-names: "Zouinkhi"
given-names: "Marwan"
orcid: "https://orcid.org/0000-0002-9441-2908"
- family-names: "Funke"
given-names: "Jan"
orcid: "http://orcid.org/0000-0003-4388-7783"
title: "DaCapo"
version: 0.3.0
doi: 10.48550/arXiv.2408.02834
date-released: 2024-08-05
url: "https://github.com/janelia-cellmap/dacapo"
preferred-citation:
type: article
authors:
- family-names: "Patton"
given-names: "William"
orcid: "https://orcid.org/0000-0002-9652-3222"
- family-names: "Rhoades"
given-names: "Jeff L."
orcid: "https://orcid.org/0000-0001-5077-2533"
- family-names: "Zouinkhi"
given-names: "Marwan"
orcid: "https://orcid.org/0000-0002-9441-2908"
- family-names: "Ackerman"
given-names: "David G."
orcid: "http://orcid.org/0000-0003-0172-6594"
- family-names: "Malin-Mayor"
given-names: "Caroline"
orcid: "https://orcid.org/0000-0002-9627-6030"
- family-names: "Adjavon"
given-names: "Diane"
- family-names: "Heinrich"
given-names: "Larissa"
orcid: "http://orcid.org/0000-0003-2852-6664"
- family-names: "Bennett"
given-names: "Davis"
orcid: "http://orcid.org/0000-0001-7579-2848"
- family-names: "Zubov"
given-names: "Yurii"
orcid: "https://orcid.org/0000-0003-1988-8081"
- family-names: "Project Team"
given-names: "CellMap"
- family-names: "Weigel"
given-names: "Aubrey V."
orcid: "http://orcid.org/0000-0003-1694-4420"
- family-names: "Funke"
given-names: "Jan"
orcid: "http://orcid.org/0000-0003-4388-7783"
doi: 10.48550/arXiv.2408.02834
journal: "arXiv-cs.CV"
title: "DaCapo: a modular deep learning framework for scalable 3D image segmentation"
year: 2024
13 changes: 12 additions & 1 deletion README.md
@@ -72,4 +72,15 @@ Tasks we support and approaches for those tasks:
- Example of [groundtruth data](https://tinyurl.com/pu8mespz)
- Visualization
  - [Neuroglancer GitHub Repo](https://github.com/google/neuroglancer)


# Citing this repo
If you use our code, please cite us and spread the news!
```
@article{Patton_DaCapo_a_modular_2024,
author = {Patton, William and Rhoades, Jeff L. and Zouinkhi, Marwan and Ackerman, David G. and Malin-Mayor, Caroline and Adjavon, Diane and Heinrich, Larissa and Bennett, Davis and Zubov, Yurii and Project Team, CellMap and Weigel, Aubrey V. and Funke, Jan},
doi = {10.48550/arXiv.2408.02834},
journal = {arXiv-cs.CV},
title = {{DaCapo: a modular deep learning framework for scalable 3D image segmentation}},
year = {2024}
}
```
2 changes: 1 addition & 1 deletion dacapo/__init__.py
@@ -5,6 +5,6 @@
from . import experiments, utils # noqa
from .apply import apply # noqa
from .train import train # noqa
-from .validate import validate # noqa
+from .validate import validate, validate_run # noqa
from .predict import predict # noqa
from .blockwise import run_blockwise, segment_blockwise # noqa
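
With `validate_run` now exported at the package level, validation can be driven directly from Python. A minimal sketch, assuming the positional parameter order shown in the `dacapo/cli.py` hunk further down this page (the run name and iteration are hypothetical):

```
import dacapo

# Validate a stored run at a chosen training iteration; the argument
# order mirrors the call in the cli.py diff below.
dacapo.validate_run(
    "my_run",  # hypothetical run name
    5000,      # hypothetical iteration
    30,        # num_workers
    "uint8",   # output_dtype
    True,      # overwrite
)
```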
1 change: 1 addition & 0 deletions dacapo/blockwise/__init__.py
@@ -1,2 +1,3 @@
from .blockwise_task import DaCapoBlockwiseTask
from .scheduler import run_blockwise, segment_blockwise
from . import global_vars
1 change: 1 addition & 0 deletions dacapo/blockwise/global_vars.py
@@ -0,0 +1 @@
current_run = None
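
This module-level variable is the hand-off point that lets an in-process prediction worker reuse a `Run` the parent process already holds, instead of re-reading the config store. A hedged sketch of how a caller might prime it before running blockwise prediction (import paths and the run name are assumptions):

```
from dacapo.blockwise import global_vars
from dacapo.experiments import Run  # import path assumed
from dacapo.store.create_store import create_config_store  # import path assumed

# Load the run once in the parent process...
config_store = create_config_store()
run = Run(config_store.retrieve_run_config("my_run"))  # hypothetical run name

# ...and publish it. predict_worker.is_global_run_set("my_run") will now
# return True, so the worker skips its own config-store lookup.
global_vars.current_run = run
```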
188 changes: 104 additions & 84 deletions dacapo/blockwise/predict_worker.py
@@ -17,6 +17,7 @@

import numpy as np
import click
from dacapo.blockwise import global_vars

import logging

@@ -27,6 +28,20 @@
path = __file__


def is_global_run_set(run_name) -> bool:
    if global_vars.current_run is not None:
        if global_vars.current_run.name == run_name:
            return True
        else:
            logger.error(
                f"Found global run {global_vars.current_run.name} but looking for {run_name}"
            )
            return False
    else:
        logger.error("No global run is set.")
        return False


@click.group()
@click.option(
"--log-level",
@@ -70,7 +85,7 @@ def cli(log_level):
)
@click.option("-od", "--output_dataset", required=True, type=str)
def start_worker(
-    run_name: str | Run,
+    run_name: str,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
@@ -90,7 +105,7 @@ def start_worker(


def start_worker_fn(
-    run_name: str | Run,
+    run_name: str,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
@@ -109,93 +124,97 @@ def start_worker_fn(
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
"""
compute_context = create_compute_context()
device = compute_context.device

# retrieving run
if isinstance(run_name, Run):
run = run_name
run_name = run.name
else:
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
def io_loop():
daisy_client = daisy.Client()

if iteration is not None:
# create weights store
weights_store = create_weights_store()
compute_context = create_compute_context()
device = compute_context.device

if is_global_run_set(run_name):
logger.warning("Using global run variable")
run = global_vars.current_run
else:
logger.warning("initiating local run in predict_worker")
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

if iteration is not None and compute_context.distribute_workers:
# create weights store
weights_store = create_weights_store()

# load weights
run.model.load_state_dict(
weights_store.retrieve_weights(run_name, iteration).model
)

# load weights
run.model.load_state_dict(
weights_store.retrieve_weights(run_name, iteration).model
# get arrays
input_array_identifier = LocalArrayIdentifier(
Path(input_container), input_dataset
)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
raw_array = ZarrArray.open_from_array_identifier(input_array_identifier)

output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# set benchmark flag to True for performance
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval().to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

print(f"Predicting with input size {input_size}, output size {output_size}")

# create gunpowder keys

raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICTION")

# assemble prediction pipeline

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
# model.eval()
pipeline += gp_torch.Predict(
model=model,
inputs={"x": raw},
outputs={0: prediction},
array_specs={
prediction: gp.ArraySpec(
voxel_size=output_voxel_size,
dtype=np.float32, # assumes network output is float32
)
},
spawn_subprocess=False,
device=str(device),
)

# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)
output_array_identifier = LocalArrayIdentifier(
Path(output_container), output_dataset
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# set benchmark flag to True for performance
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval()
# .to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]

print(f"Predicting with input size {input_size}, output size {output_size}")

# create gunpowder keys

raw = gp.ArrayKey("RAW")
prediction = gp.ArrayKey("PREDICTION")

# assemble prediction pipeline

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
# model.eval()
pipeline += gp_torch.Predict(
model=model,
inputs={"x": raw},
outputs={0: prediction},
array_specs={
prediction: gp.ArraySpec(
voxel_size=output_voxel_size,
dtype=np.float32, # assumes network output is float32
)
},
spawn_subprocess=False,
device=str(device),
)

def io_loop():
daisy_client = daisy.Client()
# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)

while True:
with daisy_client.acquire_block() as block:
@@ -231,7 +250,7 @@ def io_loop():


def spawn_worker(
-    run_name: str | Run,
+    run_name: str,
iteration: int | None,
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
@@ -248,6 +267,7 @@
Callable: The function to run the worker.
"""
compute_context = create_compute_context()

if not compute_context.distribute_workers:
return start_worker_fn(
run_name=run_name,
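
The net effect of this refactor is that model loading and gunpowder pipeline construction now happen inside `io_loop`, weights are only fetched from the weights store when `compute_context.distribute_workers` is set, and an in-process worker can fall back to the shared `global_vars.current_run`. A rough sketch of the dispatch this implies, not the actual implementation, assuming `LocalArrayIdentifier` exposes `container` and `dataset` fields:

```
def spawn_worker_sketch(run_name, iteration, input_id, output_id):
    # Sketch only: mirrors the branch shown in the hunk above.
    compute_context = create_compute_context()
    if not compute_context.distribute_workers:
        # Same process: start_worker_fn can pick up global_vars.current_run,
        # so weights already loaded on the model are reused as-is.
        return start_worker_fn(
            run_name=run_name,
            iteration=iteration,
            input_container=input_id.container,  # attribute names assumed
            input_dataset=input_id.dataset,
            output_container=output_id.container,
            output_dataset=output_id.dataset,
        )
    # Otherwise, assemble a command line for a remote worker (elided).
```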
4 changes: 2 additions & 2 deletions dacapo/cli.py
@@ -94,8 +94,8 @@ def train(run_name):
@click.option("-w", "--num_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
@click.option("-ow", "--overwrite", is_flag=True)
-def validate(run_name, iteration):
-    dacapo.validate(run_name, iteration)
+def validate(run_name, iteration, num_workers, output_dtype, overwrite):
+    dacapo.validate_run(run_name, iteration, num_workers, output_dtype, overwrite)


@cli.command()
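
This fixes a mismatch in the `validate` command: `--num_workers`, `--output_dtype`, and `--overwrite` were declared as options but the handler did not accept them, so they never reached validation. The handler now forwards all five parameters to `dacapo.validate_run`, e.g. `dacapo validate ... -w 30 -dt uint8 -ow`, where the run-name and iteration options are elided because their flags do not appear in this excerpt.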
16 changes: 11 additions & 5 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -369,11 +369,17 @@ def data(self) -> Any:
"""
file_name = str(self.file_name)
# Zarr library does not detect the store for N5 datasets
-        if file_name.endswith(".n5"):
-            zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
-        else:
-            zarr_container = zarr.open(str(file_name), mode=self.mode)
-        return zarr_container[self.dataset]
+        try:
+            if file_name.endswith(".n5"):
+                zarr_container = zarr.open(N5FSStore(str(file_name)), mode=self.mode)
+            else:
+                zarr_container = zarr.open(str(file_name), mode=self.mode)
+            return zarr_container[self.dataset]
+        except Exception as e:
+            logger.error(
+                f"Could not open dataset {self.dataset} in file {file_name} in mode {self.mode}"
+            )
+            raise e

def __getitem__(self, roi: Roi) -> np.ndarray:
"""
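
The `data` property now reports which container, dataset, and mode failed before re-raising, which makes missing-dataset errors in long blockwise jobs easier to trace. A small sketch of the failure path (import paths, file names, and dataset names are hypothetical):

```
from pathlib import Path

from dacapo.experiments.datasplits.datasets.arrays import ZarrArray  # path per this diff
from dacapo.store.array_store import LocalArrayIdentifier  # import path assumed

array = ZarrArray.open_from_array_identifier(
    LocalArrayIdentifier(Path("predictions.n5"), "volumes/pred")  # hypothetical
)
try:
    data = array.data  # on failure: logs container, dataset, and mode, then re-raises
except Exception:
    # The original exception still propagates; the log line just adds context.
    raise
```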
2 changes: 1 addition & 1 deletion dacapo/experiments/datasplits/datasplit_generator.py
@@ -757,7 +757,7 @@ def __generate_semantic_seg_datasplit(self):
mask_config=mask_config,
)
)

return TrainValidateDataSplitConfig(
name=f"{self.name}_{self.segmentation_type}_{classes}_{self.output_resolution[0]}nm",
train_configs=train_dataset_configs,
