Skip to content

Commit

Permalink
Merge e88c2c8 into 1b90080
Browse files Browse the repository at this point in the history
  • Loading branch information
mzouink authored May 30, 2024
2 parents 1b90080 + e88c2c8 commit 330365b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
51 changes: 39 additions & 12 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,36 @@ def cli(log_level):
)
@click.option("-od", "--output_dataset", required=True, type=str)
def start_worker(
run_name: str,
run_name: str| Run,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
output_container: Path | str,
output_dataset: str,
return_io_loop: Optional[bool] = False,
):
return start_worker_fn(
run_name=run_name,
iteration=iteration,
input_container=input_container,
input_dataset=input_dataset,
output_container=output_container,
output_dataset=output_dataset,
return_io_loop=return_io_loop,
)



def start_worker_fn(
run_name: str| Run,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
output_container: Path | str,
output_dataset: str,
return_io_loop: bool,
):

"""
Start a worker to apply a trained model to a dataset.
Expand All @@ -93,9 +115,14 @@ def start_worker(
device = compute_context.device

# retrieving run
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)
logger.error(f"run_name: {run_name} {type(run_name)}" )
if isinstance(run_name, Run):
run = run_name
run_name = run.name
else:
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

if iteration is not None:
# create weights store
Expand Down Expand Up @@ -207,7 +234,7 @@ def io_loop():


def spawn_worker(
run_name: str,
run_name: str| Run,
iteration: int | None,
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
Expand All @@ -225,13 +252,13 @@ def spawn_worker(
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
return start_worker(
run_name,
iteration,
input_array_identifier.container,
input_array_identifier.dataset,
output_array_identifier.container,
output_array_identifier.dataset,
return start_worker_fn(
run_name= run_name,
iteration=iteration,
input_container= input_array_identifier.container,
input_dataset= input_array_identifier.dataset,
output_container=output_array_identifier.container,
output_dataset=output_array_identifier.dataset,
return_io_loop=True,
)

Expand Down
4 changes: 2 additions & 2 deletions dacapo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def predict(
input_size = input_voxel_size * input_shape
output_size = output_voxel_size * model.compute_output_shape(input_shape)[1]
num_out_channels = model.num_out_channels
del model
# del model

# calculate input and output rois

Expand Down Expand Up @@ -149,7 +149,7 @@ def predict(
max_retries=2, # TODO: make this an option
timeout=None, # TODO: make this an option
######
run_name=run_name,
run_name=run,
iteration=iteration,
input_array_identifier=input_array_identifier,
output_array_identifier=output_array_identifier,
Expand Down

0 comments on commit 330365b

Please sign in to comment.