Skip to content

Commit

Permalink
feat: 🚧 Incorporate simple change from rhoadesj/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Feb 8, 2024
1 parent 33bbc8a commit fe23b5d
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 15 deletions.
2 changes: 1 addition & 1 deletion dacapo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
"-r", "--run_name", required=True, type=str, help="The name of the run to use."
"-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
"-ic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def axes(self):
logger.debug(
"DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
)
return ["c", "z", "y", "x"][-self.dims : :]

Expand Down
3 changes: 1 addition & 2 deletions dacapo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def predict(
num_cpu_workers: int = 4,
compute_context: ComputeContext = LocalTorch(),
output_roi: Optional[Roi] = None,
output_dtype: Optional[np.dtype] = np.uint8,
output_dtype: Optional[np.dtype] = np.float32, # add necessary type conversions
overwrite: bool = False,
):
# get the model's input and output size
Expand Down Expand Up @@ -71,7 +71,6 @@ def predict(

# prepare data source
pipeline = DaCapoArraySource(raw_array, raw)
pipeline += gp.Normalize(raw)
# raw: (c, d, h, w)
pipeline += gp.Pad(raw, Coordinate((None,) * input_voxel_size.dims))
# raw: (c, d, h, w)
Expand Down
11 changes: 0 additions & 11 deletions dacapo/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from copy import deepcopy
from dacapo.store.create_store import create_array_store
from .experiments import Run
from .compute_context import LocalTorch, ComputeContext
Expand All @@ -11,7 +10,6 @@
import logging

logger = logging.getLogger(__name__)
logger.setLevel("INFO")


def train(run_name: str, compute_context: ComputeContext = LocalTorch()):
Expand Down Expand Up @@ -103,16 +101,7 @@ def train_run(
logger.error(
f"Found weights for iteration {latest_weights_iteration}, but "
f"run {run.name} was only trained until {trained_until}. "
"Filling stats with last observed values."
)
last_iteration_stats = run.training_stats.iteration_stats[-1]
for i in range(
last_iteration_stats.iteration, latest_weights_iteration - 1
):
new_iteration_stats = deepcopy(last_iteration_stats)
new_iteration_stats.iteration = i + 1
run.training_stats.add_iteration_stats(new_iteration_stats)
trained_until = run.training_stats.trained_until()

# start/resume training

Expand Down

0 comments on commit fe23b5d

Please sign in to comment.