Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/tests #134

Merged
merged 28 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b13902f
overhaul options
pattonw Feb 16, 2024
70c0e43
Merge branch 'dev/main' into options-overhaul
rhoadesScholar Feb 16, 2024
ab58caa
fix: ⚡️ Pass compute context as part of options, and pass options fil…
rhoadesScholar Feb 16, 2024
9f1f84d
fix: :poop: making options/config serializable
neptunes5thmoon Feb 16, 2024
d0a3fc6
Merge branch '57-parallel-blockwise-processing' of github.com:janelia…
rhoadesScholar Feb 16, 2024
7a39d9d
Remove unused code in options.py
rhoadesScholar Feb 16, 2024
f535161
:art: Format Python code with psf/black
rhoadesScholar Feb 16, 2024
51d2040
Dev/main (#132)
rhoadesScholar Feb 16, 2024
96b2cb5
Merge branch '57-parallel-blockwise-processing' into actions/black
rhoadesScholar Feb 16, 2024
1d6f0f9
Format Python code with psf/black push (#131)
rhoadesScholar Feb 16, 2024
1f5990b
Merge branch '57-parallel-blockwise-processing' of github.com:janelia…
rhoadesScholar Feb 16, 2024
89382b0
fix: 🧪 Fix test_apply
rhoadesScholar Feb 16, 2024
37fb933
fix: 🧪 Working on tests. All "components" are passing.
rhoadesScholar Feb 19, 2024
3775d6d
fix: ✅ Testing train passes.
rhoadesScholar Feb 19, 2024
e365495
fix: 🐛 Fixed test_options for no_config
rhoadesScholar Feb 19, 2024
645be34
fix: 🚧 Work to fix test_predict.
rhoadesScholar Feb 19, 2024
816c19f
fix: 🚧 Some cleanup, but still failing test_predict.
rhoadesScholar Feb 20, 2024
bd17166
fix: 🚧 Debug predict_worker.py
rhoadesScholar Feb 20, 2024
e380907
Remove commented out code for daisy case in predict_worker.py
rhoadesScholar Feb 20, 2024
9b12bbd
fix: 🐛 Passed test_predict.py
rhoadesScholar Feb 21, 2024
0ef1fd2
fix: 🧪 Fix imports. test_apply.py passes
rhoadesScholar Feb 21, 2024
6d0a444
fix: 🐛 Fix abstract ConfigStore to allow MongoConfigStore definition.
rhoadesScholar Feb 21, 2024
c1b7f8a
fix: 🚧 Extend prediction output declaration, refactor, debug tests, a…
rhoadesScholar Feb 22, 2024
d5f0e60
chore: 🧑‍💻 Clean imports.
rhoadesScholar Feb 22, 2024
0aae770
Merge branch 'dev/main' into dev/tests
rhoadesScholar Feb 22, 2024
c7ec5ad
Merge branch 'dev/main' into dev/tests
rhoadesScholar Feb 22, 2024
6c9be49
Merge branch 'dev/tests' of github.com:janelia-cellmap/dacapo into de…
rhoadesScholar Feb 22, 2024
9def264
chore: 🔥 Cleaning
rhoadesScholar Feb 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dist
build
dacapo.yaml
__pycache__
scratch/

# vscode stuff
.vscode
Expand Down
23 changes: 9 additions & 14 deletions dacapo/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import dacapo.experiments.tasks.post_processors as post_processors
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.predict import predict
from dacapo.compute_context import LocalTorch, ComputeContext
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store.create_store import (
create_config_store,
Expand All @@ -34,9 +33,8 @@ def apply(
iteration: Optional[int] = None,
parameters: Optional[PostProcessorParameters | str] = None,
roi: Optional[Roi | str] = None,
num_workers: int = 30,
num_workers: int = 12,
output_dtype: Optional[np.dtype | str] = np.uint8, # type: ignore
compute_context: ComputeContext = LocalTorch(),
overwrite: bool = True,
file_format: str = "zarr",
):
Expand Down Expand Up @@ -144,7 +142,7 @@ def apply(
)
output_container = Path(
output_path,
"".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}",
Path(input_container).stem + f".{file_format}",
)
prediction_array_identifier = LocalArrayIdentifier(
output_container, f"prediction_{run_name}_{iteration}"
Expand All @@ -160,7 +158,7 @@ def apply(
Path(input_container, input_dataset),
)
return apply_run(
run.name,
run,
iteration,
parameters,
input_array_identifier,
Expand All @@ -169,46 +167,43 @@ def apply(
roi,
num_workers,
output_dtype,
compute_context,
overwrite,
)


def apply_run(
run_name: str,
run: Run,
iteration: int,
parameters: PostProcessorParameters,
input_array_identifier: "LocalArrayIdentifier",
prediction_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
roi: Optional[Roi] = None,
num_workers: int = 30,
output_dtype: Optional[np.dtype] = np.uint8, # type: ignore
compute_context: ComputeContext = LocalTorch(),
num_workers: int = 12,
output_dtype: np.dtype | str = np.uint8, # type: ignore
overwrite: bool = True,
):
"""Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""

# render prediction dataset
logger.info("Predicting on dataset %s", prediction_array_identifier)
predict(
run_name,
run.name,
iteration,
input_container=input_array_identifier.container,
input_dataset=input_array_identifier.dataset,
output_path=prediction_array_identifier.container,
output_path=prediction_array_identifier,
output_roi=roi,
num_workers=num_workers,
output_dtype=output_dtype,
compute_context=compute_context,
overwrite=overwrite,
)

# post-process the output
logger.info("Post-processing output to dataset %s", output_array_identifier)
post_processor = run.task.post_processor
post_processor.set_prediction(prediction_array_identifier)
post_processor.process(parameters, output_array_identifier)
post_processor.process(parameters, output_array_identifier, num_workers=num_workers)

logger.info("Done")
return
9 changes: 5 additions & 4 deletions dacapo/blockwise/argmax_worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.compute_context import ComputeContext, LocalTorch
from dacapo.compute_context import create_compute_context

import daisy

Expand All @@ -14,6 +14,7 @@

read_write_conflict: bool = False
fit: str = "valid"
path = __file__


@click.group()
Expand Down Expand Up @@ -74,20 +75,20 @@ def start_worker(
def spawn_worker(
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
compute_context: ComputeContext = LocalTorch(),
):
"""Spawn a worker to predict on a given dataset.

Args:
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch().
"""
compute_context = create_compute_context()

# Make the command for the worker to run
command = [
"python",
__file__,
path,
"start-worker",
"--input_container",
input_array_identifier.container,
Expand Down
10 changes: 1 addition & 9 deletions dacapo/blockwise/blockwise_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
from importlib.machinery import SourceFileLoader
from pathlib import Path
from daisy import Task, Roi
from dacapo.compute_context import ComputeContext
import dacapo.compute_context


class DaCapoBlockwiseTask(Task):
def __init__(
self,
worker_file: str | Path,
compute_context: ComputeContext | str,
total_roi: Roi,
read_roi: Roi,
write_roi: Roi,
Expand All @@ -21,9 +18,6 @@ def __init__(
*args,
**kwargs,
):
if isinstance(compute_context, str):
compute_context = getattr(dacapo.compute_context, compute_context)()

# Load worker functions
worker_name = Path(worker_file).stem
worker = SourceFileLoader(worker_name, str(worker_file)).load_module()
Expand All @@ -32,9 +26,7 @@ def __init__(
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
task_id = worker_name + timestamp

process_function = worker.spawn_worker(
*args, **kwargs, compute_context=compute_context
)
process_function = worker.spawn_worker(*args, **kwargs)
if hasattr(worker, "check_function"):
check_function = worker.check_function
else:
Expand Down
70 changes: 37 additions & 33 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from pathlib import Path

import torch
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.gp.dacapo_array_source import DaCapoArraySource
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.gp import DaCapoArraySource
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo.experiments import Run
from dacapo.compute_context import ComputeContext, LocalTorch
from dacapo.compute_context import create_compute_context
import gunpowder as gp
import gunpowder.torch as gp_torch

import daisy
from daisy import Coordinate
from funlib.geometry import Coordinate

import numpy as np
import click
Expand All @@ -22,6 +21,7 @@

read_write_conflict: bool = False
fit: str = "valid"
path = __file__


@click.group()
Expand Down Expand Up @@ -66,7 +66,7 @@ def start_worker(
input_dataset: str,
output_container: Path | str,
output_dataset: str,
device: str = "cuda",
device: str | torch.device = "cuda",
):
# retrieving run
config_store = create_config_store()
Expand All @@ -92,7 +92,7 @@ def start_worker(
torch.backends.cudnn.benchmark = True

# get the model's input and output size
model = run.model.eval()
model = run.model.eval().to(device)
input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
Expand All @@ -102,6 +102,7 @@ def start_worker(
logger.info(
"Predicting with input size %s, output size %s", input_size, output_size
)

# create gunpowder keys

raw = gp.ArrayKey("RAW")
Expand All @@ -112,11 +113,13 @@ def start_worker(
# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims))
pipeline += gp.Pad(raw, None)
# raw: (c, d, h, w)
pipeline += gp.Unsqueeze([raw])
# raw: (1, c, d, h, w)

pipeline += gp.Normalize(raw)

# predict
pipeline += gp_torch.Predict(
model=model,
Expand Down Expand Up @@ -146,51 +149,52 @@ def start_worker(
) # assumes float32 is [0,1]
pipeline += gp.AsType(prediction, output_array.dtype)

# wait for blocks to run pipeline
client = daisy.Client()

while True:
print("getting block")
with client.acquire_block() as block:
if block is None:
break

ref_request = gp.BatchRequest()
ref_request[raw] = gp.ArraySpec(
roi=block.read_roi, voxel_size=input_voxel_size, dtype=raw_array.dtype
)
ref_request[prediction] = gp.ArraySpec(
roi=block.write_roi,
voxel_size=output_voxel_size,
dtype=output_array.dtype,
)
# write to output array
pipeline += gp.ZarrWrite(
{
prediction: output_array_identifier.dataset,
},
store=str(output_array_identifier.container),
)

with gp.build(pipeline):
batch = pipeline.request_batch(ref_request)
# make reference batch request
request = gp.BatchRequest()
request.add(raw, input_size, voxel_size=input_voxel_size)
request.add(
prediction,
output_size,
voxel_size=output_voxel_size,
)
# use daisy requests to run pipeline
pipeline += gp.DaisyRequestBlocks(
reference=request,
roi_map={raw: "read_roi", prediction: "write_roi"},
num_workers=1,
)

# write to output array
output_array[block.write_roi] = batch.arrays[prediction].data
with gp.build(pipeline):
batch = pipeline.request_batch(gp.BatchRequest())


def spawn_worker(
run_name: str,
iteration: int,
raw_array_identifier: "LocalArrayIdentifier",
prediction_array_identifier: "LocalArrayIdentifier",
compute_context: ComputeContext = LocalTorch(),
):
"""Spawn a worker to predict on a given dataset.

Args:
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
compute_context (ComputeContext, optional): The compute context to use. Defaults to LocalTorch().
"""
compute_context = create_compute_context()

# Make the command for the worker to run
command = [
"python",
__file__,
path,
"start-worker",
"--run-name",
run_name,
Expand Down
9 changes: 5 additions & 4 deletions dacapo/blockwise/relabel_worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from glob import glob
import os
import daisy
from dacapo.compute_context import ComputeContext, LocalTorch
from dacapo.compute_context import create_compute_context
from dacapo.store.array_store import LocalArrayIdentifier
from scipy.cluster.hierarchy import DisjointSet
from funlib.persistence import open_ds
Expand All @@ -27,6 +27,7 @@ def cli(log_level):

fit = "shrink"
read_write_conflict = False
path = __file__


@cli.command()
Expand Down Expand Up @@ -88,7 +89,6 @@ def read_cross_block_merges(tmpdir):
def spawn_worker(
output_array_identifier: LocalArrayIdentifier,
tmpdir: str,
compute_context: ComputeContext = LocalTorch(),
*args,
**kwargs,
):
Expand All @@ -97,12 +97,13 @@ def spawn_worker(
Args:
output_array_identifier (LocalArrayIdentifier): The output array identifier
tmpdir (str): The temporary directory
compute_context (ComputeContext, optional): The compute context. Defaults to LocalTorch().
"""
compute_context = create_compute_context()

# Make the command for the worker to run
command = [
"python",
__file__,
path,
"start-worker",
"--output_container",
output_array_identifier.container,
Expand Down
Loading
Loading