Skip to content

Commit

Permalink
support no cli call (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Jun 4, 2024
2 parents 1b90080 + 5a2f22f commit 83347b9
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 27 deletions.
18 changes: 17 additions & 1 deletion dacapo/blockwise/argmax_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,22 @@ def start_worker(
output_container: Path | str,
output_dataset: str,
return_io_loop: bool = False,
):
return start_worker_fn(
input_container=input_container,
input_dataset=input_dataset,
output_container=output_container,
output_dataset=output_dataset,
return_io_loop=return_io_loop,
)


def start_worker_fn(
input_container: Path | str,
input_dataset: str,
output_container: Path | str,
output_dataset: str,
return_io_loop: bool = False,
):
"""
Start the threshold worker.
Expand Down Expand Up @@ -111,7 +127,7 @@ def spawn_worker(
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
return start_worker(
return start_worker_fn(
input_array_identifier.container,
input_array_identifier.dataset,
output_array_identifier.container,
Expand Down
49 changes: 37 additions & 12 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,27 @@ def cli(log_level):
)
@click.option("-od", "--output_dataset", required=True, type=str)
def start_worker(
run_name: str,
run_name: str | Run,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
output_container: Path | str,
output_dataset: str,
return_io_loop: Optional[bool] = False,
):
return start_worker_fn(
run_name=run_name,
iteration=iteration,
input_container=input_container,
input_dataset=input_dataset,
output_container=output_container,
output_dataset=output_dataset,
return_io_loop=return_io_loop,
)


def start_worker_fn(
run_name: str | Run,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
Expand All @@ -93,9 +113,14 @@ def start_worker(
device = compute_context.device

# retrieving run
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
logger.error(f"run_name: {run_name} {type(run_name)}")
if isinstance(run_name, Run):
run = run_name
run_name = run.name
else:
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

if iteration is not None:
# create weights store
Expand Down Expand Up @@ -207,7 +232,7 @@ def io_loop():


def spawn_worker(
run_name: str,
run_name: str | Run,
iteration: int | None,
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
Expand All @@ -225,13 +250,13 @@ def spawn_worker(
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
return start_worker(
run_name,
iteration,
input_array_identifier.container,
input_array_identifier.dataset,
output_array_identifier.container,
output_array_identifier.dataset,
return start_worker_fn(
run_name=run_name,
iteration=iteration,
input_container=input_array_identifier.container,
input_dataset=input_array_identifier.dataset,
output_container=output_array_identifier.container,
output_dataset=output_array_identifier.dataset,
return_io_loop=True,
)

Expand Down
20 changes: 15 additions & 5 deletions dacapo/blockwise/relabel_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,20 @@ def start_worker(
output_dataset,
tmpdir,
return_io_loop=False,
*args,
**kwargs,
):
return start_worker_fn(
output_container=output_container,
output_dataset=output_dataset,
tmpdir=tmpdir,
return_io_loop=return_io_loop,
)


def start_worker_fn(
output_container,
output_dataset,
tmpdir,
return_io_loop=False,
):
"""
Start the relabel worker.
Expand Down Expand Up @@ -145,8 +157,6 @@ def read_cross_block_merges(tmpdir):
def spawn_worker(
output_array_identifier: LocalArrayIdentifier,
tmpdir: str,
*args,
**kwargs,
):
"""
Spawn a worker to predict on a given dataset.
Expand All @@ -160,7 +170,7 @@ def spawn_worker(
compute_context = create_compute_context()

if not compute_context.distribute_workers:
return start_worker(
return start_worker_fn(
output_array_identifier.container,
output_array_identifier.dataset,
tmpdir,
Expand Down
30 changes: 25 additions & 5 deletions dacapo/blockwise/segment_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,32 @@ def cli(log_level):
@click.option("--tmpdir", type=str, help="Temporary directory")
@click.option("--function_path", type=str, help="Path to the segment function")
def start_worker(
input_container: str,
input_container: str | Path,
input_dataset: str,
output_container: str,
output_container: str | Path,
output_dataset: str,
tmpdir: str,
function_path: str,
tmpdir: str | Path,
function_path: str | Path,
return_io_loop: bool = False,
):
return start_worker_fn(
input_container=input_container,
input_dataset=input_dataset,
output_container=output_container,
output_dataset=output_dataset,
tmpdir=tmpdir,
function_path=function_path,
return_io_loop=return_io_loop,
)


def start_worker_fn(
input_container: str | Path,
input_dataset: str,
output_container: str | Path,
output_dataset: str,
tmpdir: str | Path,
function_path: str | Path,
return_io_loop: bool = False,
):
"""
Expand Down Expand Up @@ -211,7 +231,7 @@ def spawn_worker(
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
return start_worker(
return start_worker_fn(
input_array_identifier.container,
input_array_identifier.dataset,
output_array_identifier.container,
Expand Down
20 changes: 19 additions & 1 deletion dacapo/blockwise/threshold_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def start_worker(
output_dataset: str,
threshold: float = 0.0,
return_io_loop: bool = False,
):
return start_worker_fn(
input_container=input_container,
input_dataset=input_dataset,
output_container=output_container,
output_dataset=output_dataset,
threshold=threshold,
return_io_loop=return_io_loop,
)


def start_worker_fn(
input_container: Path | str,
input_dataset: str,
output_container: Path | str,
output_dataset: str,
threshold: float = 0.0,
return_io_loop: bool = False,
):
"""
Start the threshold worker.
Expand Down Expand Up @@ -109,7 +127,7 @@ def spawn_worker(
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
return start_worker(
return start_worker_fn(
input_array_identifier.container,
input_array_identifier.dataset,
output_array_identifier.container,
Expand Down
6 changes: 6 additions & 0 deletions dacapo/compute_context/compute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@


class ComputeContext(ABC):
distribute_workers: Optional[bool] = attr.ib(
default=False,
metadata={
"help_text": "Whether to distribute the workers across multiple nodes or processes."
},
)
"""
The ComputeContext class is an abstract base class for defining the context in which computations are to be done.
It is inherited from the built-in class `ABC` (Abstract Base Classes). Other classes can inherit this class to define
Expand Down
4 changes: 2 additions & 2 deletions dacapo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def predict(
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]
num_out_channels = model.num_out_channels
del model
# del model

# calculate input and output rois

Expand Down Expand Up @@ -149,7 +149,7 @@ def predict(
max_retries=2, # TODO: make this an option
timeout=None, # TODO: make this an option
######
run_name=run_name,
run_name=run,
iteration=iteration,
input_array_identifier=input_array_identifier,
output_array_identifier=output_array_identifier,
Expand Down
2 changes: 1 addition & 1 deletion tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"run_config",
[
lf("distance_run"),
# lf("onehot_run"),
lf("onehot_run"),
],
)
def test_validate(
Expand Down

0 comments on commit 83347b9

Please sign in to comment.