Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/main #198

Merged
merged 41 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
2eef2d2
fix formatting
davidackerman Mar 19, 2024
d6705fb
Merge branch 'dev/examples' of github.com:janelia-cellmap/dacapo into…
rhoadesScholar Mar 19, 2024
1c756a4
fix: 🐛 Fix default runs_base_dir
rhoadesScholar Mar 19, 2024
8cae470
:art: Format Python code with psf/black
rhoadesScholar Mar 19, 2024
34ba7a7
chore: 🩹 Make prediction/validation single worker to fix patches
rhoadesScholar Mar 19, 2024
f9a0bbb
chore: 🎨 Black format.
rhoadesScholar Mar 19, 2024
66ba704
Format Python code with psf/black push (#195)
rhoadesScholar Mar 19, 2024
bc1fdba
fix registry
mzouink Mar 19, 2024
9104dcc
Merge branch 'main' into dev/main
rhoadesScholar Mar 19, 2024
74bdc18
Merge branch 'dev/main' into cosem_starter
mzouink Mar 19, 2024
d9f32c4
Merge branch 'dev/main' of github.com:janelia-cellmap/dacapo into dev…
rhoadesScholar Mar 19, 2024
6bb5d47
head matching
mzouink Mar 19, 2024
c669d43
Merge branch 'cosem_starter' of https://github.com/janelia-cellmap/da…
mzouink Mar 19, 2024
aa65907
Merge branch 'main' into dev/main
rhoadesScholar Mar 19, 2024
d8076d5
fix minor errors
mzouink Mar 19, 2024
74461db
Merge branch 'main' into dev/main
rhoadesScholar Mar 19, 2024
1c1cf4f
Merge branch 'main' into dev/main
rhoadesScholar Mar 19, 2024
c128020
Merge branch 'dev/main' into cosem_starter
mzouink Mar 20, 2024
8bb5d50
Update start.py
rhoadesScholar Mar 20, 2024
34d11be
Update start.py
rhoadesScholar Mar 20, 2024
4c594fe
Cosem starter - Head matching (#196)
rhoadesScholar Mar 20, 2024
838891e
Merge branch 'dev/main' of github.com:janelia-cellmap/dacapo into dev…
rhoadesScholar Mar 20, 2024
c39f72b
perf: ⚡️ Restrict local prediction to one worker.
rhoadesScholar Mar 20, 2024
5e6b0f4
perf: ⚡️ Change default validation worker number.
rhoadesScholar Mar 20, 2024
04c05d1
feat: 🚀 Improve model loading/prediction.
rhoadesScholar Mar 20, 2024
8185d14
feat: 🚀 Improve model loading/prediction. (#199)
rhoadesScholar Mar 20, 2024
88d088a
chore: 🙈 Remove ipynotebook checkpoints.
rhoadesScholar Mar 20, 2024
d887f8f
Merge branch 'dev/main' into dev/examples
rhoadesScholar Mar 20, 2024
ff006df
Merge branch 'dev/examples' of github.com:janelia-cellmap/dacapo into…
rhoadesScholar Mar 20, 2024
d931648
Dev/examples (#194)
rhoadesScholar Mar 20, 2024
72ba70f
Merge branch 'dev/main' of github.com:janelia-cellmap/dacapo into dev…
rhoadesScholar Mar 20, 2024
d470a58
fix: 🐛 Predict fix.
rhoadesScholar Mar 20, 2024
be87111
feat: ✨ Generalize get_viewer util
rhoadesScholar Mar 20, 2024
5228051
Update validate.py
rhoadesScholar Mar 20, 2024
34f4b7f
Merge branch 'dev/main' into dev/get_viewer
rhoadesScholar Mar 20, 2024
0d24e2a
Merge branch 'dev/get_viewer' of github.com:janelia-cellmap/dacapo in…
rhoadesScholar Mar 20, 2024
7f44619
feat: ✨ Update synthetic example notebook.
rhoadesScholar Mar 20, 2024
6ae0dc8
feat: ✨ Generalize get_viewer util (#200)
rhoadesScholar Mar 20, 2024
2400b1e
Merge branch 'dev/main' of github.com:janelia-cellmap/dacapo into dev…
rhoadesScholar Mar 20, 2024
14c42db
fix: 🐛 Fix predict_worker
rhoadesScholar Mar 20, 2024
af8b671
style: 🎨 Black format.
rhoadesScholar Mar 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*.hdf
*.h5
# *.ipynb
.ipynb_checkpoints/
*.pyc
*.egg-info
*.dat
Expand Down
24 changes: 13 additions & 11 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from pathlib import Path
from typing import Optional

import torch
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
Expand Down Expand Up @@ -45,9 +46,9 @@ def cli(log_level):
@click.option(
"-i",
"--iteration",
required=True,
type=int,
type=Optional[int],
help="The training iteration of the model to use for prediction.",
default=None,
)
@click.option(
"-ic",
Expand All @@ -62,7 +63,7 @@ def cli(log_level):
@click.option("-od", "--output_dataset", required=True, type=str)
def start_worker(
run_name: str,
iteration: int,
iteration: int | None,
input_container: Path | str,
input_dataset: str,
output_container: Path | str,
Expand All @@ -76,11 +77,12 @@ def start_worker(
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

# create weights store
weights_store = create_weights_store()
if iteration is not None:
# create weights store
weights_store = create_weights_store()

# load weights
weights_store.retrieve_weights(run_name, iteration)
# load weights
weights_store.retrieve_weights(run_name, iteration)

# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
Expand Down Expand Up @@ -178,15 +180,15 @@ def start_worker(

def spawn_worker(
run_name: str,
iteration: int,
iteration: int | None,
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
):
"""Spawn a worker to predict on a given dataset.

Args:
run_name (str): The name of the run to apply.
iteration (int): The training iteration of the model to use for prediction.
iteration (int or None): The training iteration of the model to use for prediction.
input_array_identifier (LocalArrayIdentifier): The raw data to predict on.
output_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
"""
Expand All @@ -200,8 +202,6 @@ def spawn_worker(
"start-worker",
"--run-name",
run_name,
"--iteration",
iteration,
"--input_container",
input_array_identifier.container,
"--input_dataset",
Expand All @@ -211,6 +211,8 @@ def spawn_worker(
"--output_dataset",
output_array_identifier.dataset,
]
if iteration is not None:
command.extend(["--iteration", str(iteration)])

print("Defining worker with command: ", compute_context.wrap_command(command))

Expand Down
2 changes: 1 addition & 1 deletion dacapo/examples/distance_task/cosem_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
1 change: 0 additions & 1 deletion dacapo/examples/distance_task/cosem_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

# %% [markdown]
# # Dacapo
#
Expand Down
17 changes: 8 additions & 9 deletions dacapo/examples/distance_task/cosem_example_fill_in_the_blank.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
# Create the datasplit, produce the neuroglancer link and store the datasplit
datasplit = ...
viewer = ...
config_store...

config_store
# %% [markdown]
# ## Task
# What do you want to learn? An instance segmentation? If so, how? Affinities,
Expand All @@ -40,9 +39,8 @@

# Create a distance task config where the clip_distance=tol_distance=10x the output resolution,
# and scale_factor = 20x the output resolution
task_config =
config_store....

task_config = ...
config_store
# %% [markdown]
# ## Architecture
#
Expand Down Expand Up @@ -97,14 +95,14 @@
# Create a gamma augment config with range .5 to 2
...,
# Create an intensity scale shift agument config to rescale data from the range 0->1 to -1->1
...,
...,
],
snapshot_interval=10000,
min_masked=0.05,
clip_raw=True,
)
# Store the trainer
config_store....
config_store

# %% [markdown]
# ## Run
Expand All @@ -128,7 +126,7 @@
run_config = ...

print(run_config.name)
config_store...
config_store

# %% [markdown]
# ## Train
Expand All @@ -138,6 +136,7 @@
# %%
from dacapo.train import train_run
from dacapo.experiments.run import Run

# load the run and train it
run = Run(config_store...)
run = Run(config_store)
train_run(run)
20 changes: 10 additions & 10 deletions dacapo/examples/distance_task/synthetic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@

datasplit = datasplit_config.datasplit_type(datasplit_config)
viewer = datasplit._neuroglancer()
config_store.store_datasplit_config(datasplit_config)
# config_store.store_datasplit_config(datasplit_config)

# %% [markdown]
# The above datasplit_generator automates a lot of the heavy lifting for configuring data to set up a run. The following shows everything that it is doing, and an equivalent way to set up the datasplit.
Expand Down Expand Up @@ -232,7 +232,7 @@
tol_distance=80.0,
scale_factor=160.0,
)
config_store.store_task_config(task_config)
# config_store.store_task_config(task_config)

# %% [markdown]
# ## Architecture
Expand All @@ -252,11 +252,11 @@
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
eval_shape_increase=(72, 72, 72),
)
try:
config_store.store_architecture_config(architecture_config)
except:
config_store.delete_architecture_config(architecture_config.name)
config_store.store_architecture_config(architecture_config)
# try:
# config_store.store_architecture_config(architecture_config)
# except:
# config_store.delete_architecture_config(architecture_config.name)
# config_store.store_architecture_config(architecture_config)

# %% [markdown]
# ## Trainer
Expand Down Expand Up @@ -293,7 +293,7 @@
min_masked=0.05,
clip_raw=True,
)
config_store.store_trainer_config(trainer_config)
# config_store.store_trainer_config(trainer_config)

# %% [markdown]
# ## Run
Expand All @@ -311,7 +311,7 @@
# "best",
# )

iterations = 2000
iterations = 200
validation_interval = iterations // 2
repetitions = 1
for i in range(repetitions):
Expand Down Expand Up @@ -376,7 +376,7 @@
# %%
from dacapo.validate import validate

validate(run_config.name, iterations, num_workers=16, overwrite=True)
validate(run_config.name, iterations, num_workers=1, overwrite=True)

# %% [markdown]
# ## Predict
Expand Down
10 changes: 8 additions & 2 deletions dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,14 @@ def __init__(self, run_config):
if run_config.start_config is not None
else None
)
if self.start is not None:
self.start.initialize_weights(self.model)
if self.start is None:
return
else:
if hasattr(run_config.task_config,"channels"):
new_head = run_config.task_config.channels
else:
new_head = None
self.start.initialize_weights(self.model,new_head=new_head)

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
Expand Down
37 changes: 32 additions & 5 deletions dacapo/experiments/starts/cosem_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,39 @@
import logging
from cellmap_models import cosem
from pathlib import Path
from .start import Start
from .start import Start, _set_weights

logger = logging.getLogger(__file__)


def get_model_setup(run):
    """Load the pretrained COSEM model for ``run`` and report its head/voxel setup.

    Args:
        run: Identifier of the pretrained COSEM model to load.

    Returns:
        Tuple ``(classes_channels, voxel_size_input, voxel_size_output)``.
        Each entry is ``None`` when the loaded model does not expose the
        corresponding attribute, or when loading fails entirely.
    """
    try:
        model = cosem.load_model(run)
        # getattr with a None default collapses the repeated hasattr/else blocks.
        classes_channels = getattr(model, "classes_channels", None)
        voxel_size_input = getattr(model, "voxel_size_input", None)
        voxel_size_output = getattr(model, "voxel_size_output", None)
        return classes_channels, voxel_size_input, voxel_size_output
    except Exception as e:
        # Best-effort lookup: head matching is optional, so log and fall back.
        logger.error(
            f"could not load model setup: {e} - Not a big deal, model will train without head matching"
        )
        return None, None, None

class CosemStart(Start):
def __init__(self, start_config):
super().__init__(start_config)
self.run = start_config.run
self.criterion = start_config.criterion
self.name = f"{self.run}/{self.criterion}"
channels, voxel_size_input, voxel_size_output = get_model_setup(self.run)
if voxel_size_input is not None:
logger.warning(f"Starter model resolution: input {voxel_size_input} output {voxel_size_output}, Make sure to set the correct resolution for the input data.")
self.channels = channels

def check(self):
from dacapo.store.create_store import create_weights_store
Expand All @@ -25,7 +49,8 @@ def check(self):
else:
logger.info(f"Checkpoint for {self.name} exists.")

def initialize_weights(self, model):
def initialize_weights(self, model, new_head=None):
self.check()
from dacapo.store.create_store import create_weights_store

weights_store = create_weights_store()
Expand All @@ -36,4 +61,6 @@ def initialize_weights(self, model):
path = weights_dir / self.criterion
cosem.download_checkpoint(self.name, path)
weights = weights_store._retrieve_weights(self.run, self.criterion)
super._set_weights(model, weights)
_set_weights(model, weights, self.run, self.criterion, self.channels, new_head)


Loading
Loading