
Commit

Dev/main (#138)
rhoadesScholar committed Feb 22, 2024
2 parents 874a4a0 + c9d0c8a commit 655b9e0
Showing 106 changed files with 847 additions and 2,245 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,6 +13,7 @@ dist
build
dacapo.yaml
__pycache__
scratch/

# vscode stuff
.vscode
3 changes: 0 additions & 3 deletions CONTRIBUTING.md
@@ -20,9 +20,6 @@ To run tests with coverage locally:
`pytest tests --color=yes --cov --cov-report=term-missing`
This will also be run automatically when a PR is made to master and a codecov report will be generated telling you if your PR increased or decreased coverage.

## Doc Generation
Docstrings are generated using github action. but you can generate them using
`sphinx-build -M html docs/source/ docs/Cbuild/`

## Branching and PRs
- Users that have been added to the CellMap organization and the DaCapo project should be able to develop directly into the CellMap fork of DaCapo. Other users will need to create a fork.
18 changes: 0 additions & 18 deletions dacapo/__init__.py
@@ -1,24 +1,6 @@
"""
dacapo module
==============
This module contains several useful methods for performing common tasks in dacapo library.
Modules:
-----------
Options - Deals with configuring aspects of the program's operations.
experiments - This module is responsible for conducting experiments.
apply - Applies the results of the training process to the given dataset.
train - Trains the model using given data set.
validate - This module is for validating the model.
predict - This module is used to generate predictions based on the model.
"""

from .options import Options # noqa
from . import experiments # noqa
from .apply import apply # noqa
from .train import train # noqa
from .validate import validate # noqa
from .predict import predict # noqa

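For context: the imports retained above are the package's top-level entry points that the removed docstring described (Options, experiments, apply, train, validate, predict). A minimal usage sketch follows; the run name, iteration, and call signatures are illustrative assumptions, not taken from this diff:

import dacapo

# "my_run" is a placeholder for a run configuration assumed to already exist in the config store
dacapo.train("my_run")            # assumed signature: train(run_name)
dacapo.validate("my_run", 50000)  # assumed signature: validate(run_name, iteration)
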
23 changes: 9 additions & 14 deletions dacapo/apply.py
@@ -12,7 +12,6 @@
import dacapo.experiments.tasks.post_processors as post_processors
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.predict import predict
from dacapo.compute_context import LocalTorch, ComputeContext
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store.create_store import (
create_config_store,
@@ -34,9 +33,8 @@ def apply(
iteration: Optional[int] = None,
parameters: Optional[PostProcessorParameters | str] = None,
roi: Optional[Roi | str] = None,
num_workers: int = 30,
num_workers: int = 12,
output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore
compute_context: ComputeContext = LocalTorch(),
overwrite: bool = True,
file_format: str = "zarr",
):
@@ -144,7 +142,7 @@ )
)
output_container = Path(
output_path,
"".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}",
Path(input_container).stem + f".{file_format}",
)
prediction_array_identifier = LocalArrayIdentifier(
output_container, f"prediction_{run_name}_{iteration}"
@@ -160,7 +158,7 @@
Path(input_container, input_dataset),
)
return apply_run(
run.name,
run,
iteration,
parameters,
input_array_identifier,
@@ -169,46 +167,43 @@
roi,
num_workers,
output_dtype,
compute_context,
overwrite,
)


def apply_run(
run_name: str,
run: Run,
iteration: int,
parameters: PostProcessorParameters,
input_array_identifier: "LocalArrayIdentifier",
prediction_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
roi: Optional[Roi] = None,
num_workers: int = 30,
output_dtype: Optional[np.dtype] = np.uint8, # type: ignore
compute_context: ComputeContext = LocalTorch(),
num_workers: int = 12,
output_dtype: np.dtype | str = np.uint8, # type: ignore
overwrite: bool = True,
):
"""Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""

# render prediction dataset
logger.info("Predicting on dataset %s", prediction_array_identifier)
predict(
run_name,
run.name,
iteration,
input_container=input_array_identifier.container,
input_dataset=input_array_identifier.dataset,
output_path=prediction_array_identifier.container,
output_path=prediction_array_identifier,
output_roi=roi,
num_workers=num_workers,
output_dtype=output_dtype,
compute_context=compute_context,
overwrite=overwrite,
)

# post-process the output
logger.info("Post-processing output to dataset %s", output_array_identifier)
post_processor = run.task.post_processor
post_processor.set_prediction(prediction_array_identifier)
post_processor.process(parameters, output_array_identifier)
post_processor.process(parameters, output_array_identifier, num_workers=num_workers)

logger.info("Done")
return
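A side note on the output-container naming change above (Path(input_container).stem replacing the manual split-and-join): the old expression dropped every dot from a multi-suffix filename, whereas stem only strips the final suffix. A small illustration with a made-up filename:

from pathlib import Path

name = "my.volume.n5"
"".join(Path(name).name.split(".")[:-1])  # old behaviour: "myvolume"  (all dots removed)
Path(name).stem                           # new behaviour: "my.volume" (only the last suffix stripped)
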
9 changes: 5 additions & 4 deletions dacapo/blockwise/argmax_worker.py
@@ -1,7 +1,7 @@
from pathlib import Path
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.compute_context import ComputeContext, LocalTorch
from dacapo.compute_context import create_compute_context

import daisy

@@ -14,6 +14,7 @@

read_write_conflict: bool = False
fit: str = "valid"
path = __file__


@click.group()
@@ -74,20 +75,20 @@ def start_worker(
def spawn_worker(
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
compute_context: ComputeContext = LocalTorch(),
):
"""Spawn a worker to predict on a given dataset.
Args:
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch().
"""
compute_context = create_compute_context()

# Make the command for the worker to run
command = [
"python",
__file__,
path,
"start-worker",
"--input_container",
input_array_identifier.container,
10 changes: 1 addition & 9 deletions dacapo/blockwise/blockwise_task.py
@@ -2,15 +2,12 @@
from importlib.machinery import SourceFileLoader
from pathlib import Path
from daisy import Task, Roi
from dacapo.compute_context import ComputeContext
import dacapo.compute_context


class DaCapoBlockwiseTask(Task):
def __init__(
self,
worker_file: str | Path,
compute_context: ComputeContext | str,
total_roi: Roi,
read_roi: Roi,
write_roi: Roi,
@@ -21,9 +18,6 @@ def __init__(
*args,
**kwargs,
):
if isinstance(compute_context, str):
compute_context = getattr(dacapo.compute_context, compute_context)()

# Load worker functions
worker_name = Path(worker_file).stem
worker = SourceFileLoader(worker_name, str(worker_file)).load_module()
@@ -32,9 +26,7 @@ def __init__(
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
task_id = worker_name + timestamp

process_function = worker.spawn_worker(
*args, **kwargs, compute_context=compute_context
)
process_function = worker.spawn_worker(*args, **kwargs)
if hasattr(worker, "check_function"):
check_function = worker.check_function
else:
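For orientation, the loader above implies a small contract for worker modules: a few module-level attributes plus a spawn_worker factory, with the compute context now created inside the worker instead of being passed in. The skeleton below is inferred from the worker diffs in this commit and is only a sketch; the body and return value of spawn_worker are cut off in the hunks shown, and the module name is hypothetical:

# my_worker.py, a hypothetical skeleton of what DaCapoBlockwiseTask loads via SourceFileLoader
from dacapo.compute_context import create_compute_context

read_write_conflict: bool = False  # module-level flag the workers declare (its use is in code cut off above)
fit: str = "valid"                 # module-level flag the workers declare
path = __file__                    # workers now build their CLI command from `path`

def spawn_worker(*args, **kwargs):
    # no compute_context argument any more; the worker creates its own
    compute_context = create_compute_context()
    command = ["python", path, "start-worker"]  # worker-specific CLI flags would follow
    ...  # returns the per-block process function (not shown in the hunks above)

# a module may also define check_function, which the task uses when present
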
70 changes: 37 additions & 33 deletions dacapo/blockwise/predict_worker.py
@@ -1,17 +1,16 @@
from pathlib import Path

import torch
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.gp.dacapo_array_source import DaCapoArraySource
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.gp import DaCapoArraySource
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo.experiments import Run
from dacapo.compute_context import ComputeContext, LocalTorch
from dacapo.compute_context import create_compute_context
import gunpowder as gp
import gunpowder.torch as gp_torch

import daisy
from daisy import Coordinate
from funlib.geometry import Coordinate

import numpy as np
import click
@@ -22,6 +21,7 @@

read_write_conflict: bool = False
fit: str = "valid"
path = __file__


@click.group()
@@ -66,7 +66,7 @@ def start_worker(
input_dataset: str,
output_container: Path | str,
output_dataset: str,
device: str = "cuda",
device: str | torch.device = "cuda",
):
# retrieving run
config_store = create_config_store()
@@ -92,7 +92,7 @@
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval()
model = run.model.eval().to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
@@ -102,6 +102,7 @@
logger.info(
"Predicting with input size %s, output size %s", input_size, output_size
)

# create gunpowder keys

raw = gp.ArrayKey("RAW")
@@ -112,11 +113,13 @@
# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims))
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
pipeline += gp_torch.Predict(
model=model,
@@ -146,51 +149,52 @@ ) # assumes float32 is [0,1]
) # assumes float32 is [0,1]
pipeline += gp.AsType(prediction, output_array.dtype)

# wait for blocks to run pipeline
client = daisy.Client()

while True:
print("getting block")
with client.acquire_block() as block:
if block is None:
break

ref_request = gp.BatchRequest()
ref_request[raw] = gp.ArraySpec(
roi=block.read_roi, voxel_size=input_voxel_size, dtype=raw_array.dtype
)
ref_request[prediction] = gp.ArraySpec(
roi=block.write_roi,
voxel_size=output_voxel_size,
dtype=output_array.dtype,
)
# write to output array
pipeline += gp.ZarrWrite(
{
prediction: output_array_identifier.dataset,
},
store=str(output_array_identifier.container),
)

with gp.build(pipeline):
batch = pipeline.request_batch(ref_request)
# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)
# use daisy requests to run pipeline
pipeline += gp.DaisyRequestBlocks(
reference=request,
roi_map={raw: "read_roi", prediction: "write_roi"},
num_workers=1,
)

# write to output array
output_array[block.write_roi] = batch.arrays[prediction].data
with gp.build(pipeline):
batch = pipeline.request_batch(gp.BatchRequest())


def spawn_worker(
run_name: str,
iteration: int,
raw_array_identifier: "LocalArrayIdentifier",
prediction_array_identifier: "LocalArrayIdentifier",
compute_context: ComputeContext = LocalTorch(),
):
"""Spawn a worker to predict on a given dataset.
Args:
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch().
"""
compute_context = create_compute_context()

# Make the command for the worker to run
command = [
"python",
__file__,
path,
"start-worker",
"--run-name",
run_name,
