Commit c88016c

style: 🎨 Black format

rhoadesScholar committed Feb 7, 2024
1 parent 5c10f19 commit c88016c
Showing 9 changed files with 80 additions and 43 deletions.
@@ -9,6 +9,7 @@
 
 logger = logging.getLogger(__file__)
 
+
 class ConcatArray(Array):
     """This is a wrapper around other `source_arrays` that concatenates
     them along the channel dimension."""
@@ -118,5 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
             axis=0,
         )
         if concatenated.shape[0] == 1:
-            logger.info(f"Concatenated array has only one channel: {self.name} {concatenated.shape}")
+            logger.info(
+                f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
+            )
         return concatenated
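For orientation, the operation this hunk reformats is plain channel-wise concatenation; a minimal sketch of the same idea with NumPy (array names and shapes here are made up, not taken from the repository):

import numpy as np

# Two hypothetical source arrays with shape (channels, z, y, x).
a = np.zeros((1, 10, 10, 10))  # single-channel source
b = np.ones((2, 10, 10, 10))   # two-channel source

# ConcatArray stacks sources along axis 0, the channel dimension.
concatenated = np.concatenate([a, b], axis=0)
assert concatenated.shape == (3, 10, 10, 10)

# The reformatted logger.info above fires when the result has one channel,
# which usually signals a misconfigured source list.
if concatenated.shape[0] == 1:
    print(f"Concatenated array has only one channel: {concatenated.shape}")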
21 changes: 14 additions & 7 deletions dacapo/experiments/run.py
@@ -11,6 +11,7 @@
 
 logger = logging.getLogger(__file__)
 
+
 class Run:
     name: str
     train_until: int
@@ -58,28 +59,34 @@ def __init__(self, run_config):
             return
         try:
             from ..store import create_config_store
+
             start_config_store = create_config_store()
-            starter_config = start_config_store.retrieve_run_config(run_config.start_config.run)
+            starter_config = start_config_store.retrieve_run_config(
+                run_config.start_config.run
+            )
         except Exception as e:
-            logger.error(f"could not load start config: {e} Should be added to the database config store RUN")
+            logger.error(
+                f"could not load start config: {e} Should be added to the database config store RUN"
+            )
             raise e
 
         # preloaded weights from previous run
         if run_config.task_config.name == starter_config.task_config.name:
             self.start = Start(run_config.start_config)
         else:
             # Match labels between old and new head
-            if hasattr(run_config.task_config,"channels"):
+            if hasattr(run_config.task_config, "channels"):
                 # Map old head and new head
                 old_head = starter_config.task_config.channels
                 new_head = run_config.task_config.channels
-                self.start = Start(run_config.start_config,old_head=old_head,new_head=new_head)
+                self.start = Start(
+                    run_config.start_config, old_head=old_head, new_head=new_head
+                )
             else:
                 logger.warning("Not implemented channel match for this task")
-                self.start = Start(run_config.start_config,remove_head=True)
+                self.start = Start(run_config.start_config, remove_head=True)
         self.start.initialize_weights(self.model)
 
-
     @staticmethod
     def get_validation_scores(run_config) -> ValidationScores:
         """
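The logic Black reformats here has three branches: an identical task name reuses the starter weights as-is, a task config exposing `channels` triggers head matching, and everything else drops the prediction head. A condensed, self-contained sketch of that decision (the config classes are hypothetical stand-ins, not the repository's):

class TaskConfig:  # hypothetical stand-in for the real task config
    def __init__(self, name, channels=None):
        self.name = name
        if channels is not None:
            self.channels = channels

def pick_start_mode(run_task, starter_task):
    if run_task.name == starter_task.name:
        return "reuse starter weights unchanged"   # Start(start_config)
    if hasattr(run_task, "channels"):
        return "match old and new heads by label"  # Start(..., old_head=..., new_head=...)
    return "remove the prediction head"            # Start(..., remove_head=True)

print(pick_start_mode(TaskConfig("affs"), TaskConfig("affs")))
print(pick_start_mode(TaskConfig("lsd", ["mito"]), TaskConfig("affs", ["mito", "er"])))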
38 changes: 25 additions & 13 deletions dacapo/experiments/starts/start.py
@@ -3,15 +3,21 @@
 
 logger = logging.getLogger(__file__)
 
-# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
-# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]
-head_keys = ["prediction_head.weight","prediction_head.bias","chain.1.weight","chain.1.bias"]
+# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
+# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]
+head_keys = [
+    "prediction_head.weight",
+    "prediction_head.bias",
+    "chain.1.weight",
+    "chain.1.bias",
+]
 
 # Hack
 # if label is mito_peroxisome or peroxisome then change it to mito
-mitos = ["mito_proxisome","peroxisome"]
+mitos = ["mito_proxisome", "peroxisome"]
 
-def match_heads(model, head_weights, old_head, new_head ):
+
+def match_heads(model, head_weights, old_head, new_head):
     # match the heads
     for label in new_head:
         old_label = label
@@ -30,8 +36,9 @@ def match_heads(model, head_weights, old_head, new_head ):
             model.state_dict()[key][new_index] = n_val
         logger.warning(f"matched head for {label} with {old_label}")
 
+
 class Start(ABC):
-    def __init__(self, start_config,remove_head = False, old_head= None, new_head = None):
+    def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
         self.run = start_config.run
         self.criterion = start_config.criterion
         self.remove_head = remove_head
@@ -44,7 +51,9 @@ def initialize_weights(self, model):
         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)
 
-        logger.warning(f"loading weights from run {self.run}, criterion: {self.criterion}")
+        logger.warning(
+            f"loading weights from run {self.run}, criterion: {self.criterion}"
+        )
 
         try:
             if self.old_head and self.new_head:
@@ -61,15 +70,21 @@ def initialize_weights(self, model):
             logger.warning(f"ERROR starter: {e}")
 
     def load_model_using_head_removal(self, model, weights):
-        logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}")
+        logger.warning(
+            f"removing head from run {self.run}, criterion: {self.criterion}"
+        )
         for key in head_keys:
             weights.model.pop(key, None)
         logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}")
         model.load_state_dict(weights.model, strict=False)
-        logger.warning(f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}")
+        logger.warning(
+            f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
+        )
 
     def load_model_using_head_matching(self, model, weights):
-        logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}")
+        logger.warning(
+            f"matching heads from run {self.run}, criterion: {self.criterion}"
+        )
         logger.warning(f"old head: {self.old_head}")
         logger.warning(f"new head: {self.new_head}")
         head_weights = {}
@@ -79,6 +94,3 @@ def load_model_using_head_matching(self, model, weights):
             weights.model.pop(key, None)
         model.load_state_dict(weights.model, strict=False)
         model = match_heads(model, head_weights, self.old_head, self.new_head)
-
-
-
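The head matching above copies, for each label present in both heads, the old head's output channel into the corresponding row of the new head's parameters; channels without a match keep their fresh initialization. A self-contained sketch of that idea (tensor sizes and label lists are illustrative only):

import torch

old_head = ["ecs", "mito", "nucleus"]
new_head = ["mito", "nucleus", "ld"]

# Hypothetical prediction-head weight matrices: one output row per label.
old_weight = torch.randn(len(old_head), 16)
new_weight = torch.zeros(len(new_head), 16)

for label in new_head:
    if label in old_head:
        # Copy the matching output channel; "ld" has no source row
        # and keeps its initialization.
        new_weight[new_head.index(label)] = old_weight[old_head.index(label)]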
10 changes: 8 additions & 2 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -27,7 +27,13 @@ class DistancePredictor(Predictor):
     in the channels argument.
     """
 
-    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool,extra_conv :bool):
+    def __init__(
+        self,
+        channels: List[str],
+        scale_factor: float,
+        mask_distances: bool,
+        extra_conv: bool,
+    ):
         self.channels = channels
         self.norm = "tanh"
         self.dt_scale_factor = scale_factor
@@ -37,7 +43,7 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: boo
         self.epsilon = 5e-2
         self.threshold = 0.8
         self.extra_conv = extra_conv
-        self.extra_conv_dims =len(self.channels) *2
+        self.extra_conv_dims = len(self.channels) * 2
 
     @property
     def embedding_dims(self):
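Given the module path in this diff, the reformatted constructor would be called roughly as below; the argument values are illustrative, not repository defaults:

from dacapo.experiments.tasks.predictors.distance_predictor import DistancePredictor

predictor = DistancePredictor(
    channels=["mito", "nucleus"],  # hypothetical label channels
    scale_factor=50.0,             # illustrative distance-transform scale
    mask_distances=True,
    extra_conv=False,
)
# Per the hunk above, the optional extra convolution is sized at len(channels) * 2.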
14 changes: 10 additions & 4 deletions dacapo/experiments/trainers/gunpowder_trainer.py
@@ -43,7 +43,9 @@ def __init__(self, trainer_config):
         self.clip_raw = trainer_config.clip_raw
 
         # Testing out if calculating multiple times and multiplying is necessary
-        self.add_predictor_nodes_to_dataset = trainer_config.add_predictor_nodes_to_dataset
+        self.add_predictor_nodes_to_dataset = (
+            trainer_config.add_predictor_nodes_to_dataset
+        )
         self.finetune_head_only = trainer_config.finetune_head_only
 
         self.scheduler = None
@@ -177,7 +179,9 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
                 task.predictor,
                 gt_key=gt_key,
                 target_key=target_key,
-                weights_key=datasets_weight_key if self.add_predictor_nodes_to_dataset else weight_key,
+                weights_key=datasets_weight_key
+                if self.add_predictor_nodes_to_dataset
+                else weight_key,
                 mask_key=mask_key,
             )
@@ -230,7 +234,9 @@ def iterate(self, num_iterations, model, optimizer, device):
                     f"Trainer fetch batch took {time.time() - t_start_fetch} seconds"
                 )
 
-            for param in model.parameters():  # TODO: get parameters from optimizer instead
+            for (
+                param
+            ) in model.parameters():  # TODO: get parameters from optimizer instead
                 param.grad = None
 
             t_start_prediction = time.time()
@@ -352,4 +358,4 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         pass
 
     def can_train(self, datasets) -> bool:
-        return all([dataset.gt is not None for dataset in datasets])
+        return all([dataset.gt is not None for dataset in datasets])
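One functional detail in this file worth noting: the reformatted loop clears gradients by assigning `param.grad = None`, which avoids a zero-fill and lets freed gradient memory be reclaimed. In current PyTorch the same effect is available as a one-liner, sketched here with a stand-in model:

import torch

model = torch.nn.Linear(4, 2)  # stand-in for the trainer's model
optimizer = torch.optim.Adam(model.parameters())

# Equivalent to: for param in model.parameters(): param.grad = None
optimizer.zero_grad(set_to_none=True)  # set_to_none=True is the default since PyTorch 2.0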
8 changes: 5 additions & 3 deletions dacapo/experiments/trainers/gunpowder_trainer_config.py
@@ -32,10 +32,12 @@ class GunpowderTrainerConfig(TrainerConfig):
 
     add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
         default=True,
-        metadata={"help_text": "Whether to add a predictor node to dataset_source and apply product of weights"}
+        metadata={
+            "help_text": "Whether to add a predictor node to dataset_source and apply product of weights"
+        },
     )
 
     finetune_head_only: Optional[bool] = attr.ib(
         default=False,
-        metadata={"help_text": "Whether to fine-tune head only or all layers"}
-    )
+        metadata={"help_text": "Whether to fine-tune head only or all layers"},
+    )
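For reference, `attr.ib` stores `metadata` verbatim on the attribute definition, so help text like the strings above can be read back with `attr.fields`; a minimal sketch with an illustrative class:

import attr

@attr.s
class ExampleConfig:  # illustrative, not the repository's class
    finetune_head_only: bool = attr.ib(
        default=False,
        metadata={"help_text": "Whether to fine-tune head only or all layers"},
    )

# Metadata is available to tooling, e.g. auto-generated CLI or docs help.
for field in attr.fields(ExampleConfig):
    print(field.name, "->", field.metadata["help_text"])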
2 changes: 1 addition & 1 deletion dacapo/train.py
@@ -178,7 +178,7 @@ def train_run(
         stats_store.store_training_stats(run.name, run.training_stats)
 
         # make sure to move optimizer back to the correct device
-        run.move_optimizer(compute_context.device)
+        run.move_optimizer(compute_context.device)
         run.model.train()
 
         weights_store.store_weights(run, run.training_stats.trained_until())
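`run.move_optimizer` is repository code not shown in this diff; the standard PyTorch pattern for moving optimizer state between devices, which it presumably resembles, looks like this:

import torch

def move_optimizer_state(optimizer, device):
    # Optimizer state (e.g. Adam's exp_avg buffers) lives in per-parameter
    # dicts; move every tensor found there to the target device.
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)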
2 changes: 1 addition & 1 deletion dacapo/utils/balance_weights.py
@@ -77,4 +77,4 @@ def balance_weights(
     # scale_slab the masked-in scale_slab with the class weights
     scale_slab *= np.take(w, labels_slab)
 
-    return error_scale, moving_counts
+    return error_scale, moving_counts
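The context line `scale_slab *= np.take(w, labels_slab)` is a compact per-voxel class weighting: `np.take` gathers the weight for each voxel's label. A tiny illustration with made-up numbers:

import numpy as np

w = np.array([0.25, 0.75])                # per-class weights
labels_slab = np.array([[0, 1], [1, 1]])  # voxel labels
scale_slab = np.ones_like(labels_slab, dtype=float)

scale_slab *= np.take(w, labels_slab)     # gather each voxel's class weight
print(scale_slab)  # [[0.25 0.75] [0.75 0.75]]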
23 changes: 12 additions & 11 deletions docs/source/conf.py
@@ -12,43 +12,44 @@
 #
 import os
 import sys
-sys.path.insert(0, os.path.abspath('../..'))
+
+sys.path.insert(0, os.path.abspath("../.."))
 
 
 # -- Project information -----------------------------------------------------
 
-project = 'DaCapo'
-copyright = '2022, William Patton, David Ackerman, Jan Funke'
-author = 'William Patton, David Ackerman, Jan Funke'
+project = "DaCapo"
+copyright = "2022, William Patton, David Ackerman, Jan Funke"
+author = "William Patton, David Ackerman, Jan Funke"
 
 
 # -- General configuration ---------------------------------------------------
 
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx_autodoc_typehints']
+extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_autodoc_typehints"]
 
 # Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
 # This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
 
 
 # -- Options for HTML output -------------------------------------------------
 
 # The theme to use for HTML and HTML Help pages. See the documentation for
 # a list of builtin themes.
 #
-html_theme = 'sphinx_material'
+html_theme = "sphinx_material"
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]
 html_css_files = [
-    'css/custom.css',
-]
+    "css/custom.css",
+]