Cosem starter - Head matching #196

Merged · 8 commits · Mar 20, 2024

This change lets a pretrained COSEM starter initialize a run whose prediction head has a different channel list: channels present in both heads are matched by name and copied over, while the remaining channels keep their fresh initialization.
10 changes: 8 additions & 2 deletions dacapo/experiments/run.py
@@ -59,8 +59,14 @@ def __init__(self, run_config):
             if run_config.start_config is not None
             else None
         )
-        if self.start is not None:
-            self.start.initialize_weights(self.model)
+        if self.start is None:
+            return
+        else:
+            if hasattr(run_config.task_config, "channels"):
+                new_head = run_config.task_config.channels
+            else:
+                new_head = None
+            self.start.initialize_weights(self.model, new_head=new_head)

     @staticmethod
     def get_validation_scores(run_config) -> ValidationScores:
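In effect, Run.__init__ now reads the task's channel list (when the task config defines one) and forwards it to the starter as new_head. A minimal sketch of that forwarding logic, using a hypothetical stand-in task config rather than a real DaCapo one:

# Sketch only: ExampleTaskConfig is a stand-in, not a real DaCapo class.
class ExampleTaskConfig:
    channels = ["mito", "er"]

task_config = ExampleTaskConfig()
new_head = task_config.channels if hasattr(task_config, "channels") else None
print(new_head)  # ['mito', 'er']
# Run.__init__ would then call:
# self.start.initialize_weights(self.model, new_head=new_head)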
37 changes: 32 additions & 5 deletions dacapo/experiments/starts/cosem_start.py
@@ -2,15 +2,39 @@
 import logging
 from cellmap_models import cosem
 from pathlib import Path
-from .start import Start
+from .start import Start, _set_weights

 logger = logging.getLogger(__file__)


+def get_model_setup(run):
+    try:
+        model = cosem.load_model(run)
+        if hasattr(model, "classes_channels"):
+            classes_channels = model.classes_channels
+        else:
+            classes_channels = None
+        if hasattr(model, "voxel_size_input"):
+            voxel_size_input = model.voxel_size_input
+        else:
+            voxel_size_input = None
+        if hasattr(model, "voxel_size_output"):
+            voxel_size_output = model.voxel_size_output
+        else:
+            voxel_size_output = None
+        return classes_channels, voxel_size_input, voxel_size_output
+    except Exception as e:
+        logger.error(f"could not load model setup: {e} - not a big deal, the model will train without head matching")
+        return None, None, None

 class CosemStart(Start):
     def __init__(self, start_config):
         super().__init__(start_config)
         self.run = start_config.run
         self.criterion = start_config.criterion
         self.name = f"{self.run}/{self.criterion}"
+        channels, voxel_size_input, voxel_size_output = get_model_setup(self.run)
+        if voxel_size_input is not None:
+            logger.warning(f"Starter model resolution: input {voxel_size_input}, output {voxel_size_output}. Make sure to set the correct resolution for the input data.")
+        self.channels = channels

     def check(self):
         from dacapo.store.create_store import create_weights_store
@@ -25,7 +49,8 @@ def check(self):
         else:
             logger.info(f"Checkpoint for {self.name} exists.")

-    def initialize_weights(self, model):
+    def initialize_weights(self, model, new_head=None):
+        self.check()
         from dacapo.store.create_store import create_weights_store

         weights_store = create_weights_store()
@@ -36,4 +61,6 @@ def check(self):
             path = weights_dir / self.criterion
             cosem.download_checkpoint(self.name, path)
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-        super._set_weights(model, weights)
+        _set_weights(model, weights, self.run, self.criterion, self.channels, new_head)
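For orientation, a rough usage sketch of the new starter. The start_config stand-in and the "setup04"/"best" names below are placeholders (a real run would use a registered StartConfig and a valid COSEM model name), and only the attributes CosemStart actually reads are provided:

# Hedged sketch: SimpleNamespace stands in for a real StartConfig.
from types import SimpleNamespace
from dacapo.experiments.starts.cosem_start import CosemStart

start_config = SimpleNamespace(run="setup04", criterion="best", task_config=None)
start = CosemStart(start_config)  # logs the starter's voxel sizes when cellmap_models can report them
# Later, Run.__init__ calls something like:
# start.initialize_weights(model, new_head=["mito", "er"])
# which retrieves the COSEM checkpoint and matches overlapping head channels by name.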


83 changes: 61 additions & 22 deletions dacapo/experiments/starts/start.py
@@ -3,6 +3,60 @@

logger = logging.getLogger(__file__)

+head_keys = ["prediction_head.weight", "prediction_head.bias", "chain.1.weight", "chain.1.bias"]
+
+def match_heads(model, head_weights, old_head, new_head):
+    for label in new_head:
+        if label in old_head:
+            logger.warning(f"matching head for {label}.")
+            old_index = old_head.index(label)
+            new_index = new_head.index(label)
+            for key in head_keys:
+                if key in model.state_dict().keys():
+                    n_val = head_weights[key][old_index]
+                    model.state_dict()[key][new_index] = n_val
+            logger.warning(f"matched head for {label}.")
+
+def _set_weights(model, weights, run, criterion, old_head=None, new_head=None):
+    logger.warning(f"loading weights from run {run}, criterion: {criterion}, old_head: {old_head}, new_head: {new_head}")
+    try:
+        if old_head and new_head:
+            try:
+                logger.warning(f"matching heads from run {run}, criterion: {criterion}")
+                logger.warning(f"old head: {old_head}")
+                logger.warning(f"new head: {new_head}")
+                head_weights = {}
+                for key in head_keys:
+                    head_weights[key] = weights.model[key]
+                for key in head_keys:
+                    weights.model.pop(key, None)
+                model.load_state_dict(weights.model, strict=False)
+                match_heads(model, head_weights, old_head, new_head)
+            except RuntimeError as e:
+                logger.error(f"ERROR starter matching head: {e}")
+                logger.warning(f"removing head from run {run}, criterion: {criterion}")
+                for key in head_keys:
+                    weights.model.pop(key, None)
+                model.load_state_dict(weights.model, strict=False)
+                logger.warning(f"loaded weights in non-strict mode from run {run}, criterion: {criterion}")
+        else:
+            try:
+                model.load_state_dict(weights.model)
+            except RuntimeError as e:
+                logger.warning(e)
+                model_dict = model.state_dict()
+                pretrained_dict = {
+                    k: v
+                    for k, v in weights.model.items()
+                    if k in model_dict and v.size() == model_dict[k].size()
+                }
+                model_dict.update(pretrained_dict)  # update only the existing and matching layers
+                model.load_state_dict(model_dict)
+                logger.warning("loaded only common layers from weights")
+    except RuntimeError as e:
+        logger.warning(f"ERROR starter: {e}")

 class Start(ABC):
     """
@@ -32,28 +86,12 @@ def __init__(self, start_config):
         self.run = start_config.run
         self.criterion = start_config.criterion

-    def _set_weights(self, model, weights):
-        print(f"loading weights from run {self.run}, criterion: {self.criterion}")
-        # load the model weights (taken from torch load_state_dict source)
-        try:
-            model.load_state_dict(weights.model)
-        except RuntimeError as e:
-            logger.warning(e)
-            # if the model is not the same, we can try to load the weights
-            # of the common layers
-            model_dict = model.state_dict()
-            pretrained_dict = {
-                k: v
-                for k, v in weights.model.items()
-                if k in model_dict and v.size() == model_dict[k].size()
-            }
-            model_dict.update(
-                pretrained_dict
-            )  # update only the existing and matching layers
-            model.load_state_dict(model_dict)
-            logger.warning(f"loaded only common layers from weights")
+        if hasattr(start_config.task_config, "channels"):
+            self.channels = start_config.task_config.channels
+        else:
+            self.channels = None

-    def initialize_weights(self, model):
+    def initialize_weights(self, model, new_head=None):
"""
Retrieves the weights from the dacapo store and load them into
the model.
@@ -72,4 +110,5 @@ def initialize_weights(self, model):

         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-        self._set_weights(model, weights)
+        _set_weights(model, weights, self.run, self.criterion, self.channels, new_head)
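To make the index bookkeeping behind match_heads concrete, here is a self-contained sketch (plain PyTorch; the label names and tensor shapes are illustrative, not taken from any COSEM model) of copying matching output channels from a pretrained head into a new one:

import torch

old_head = ["mito", "er", "nucleus"]  # channel order of the pretrained head
new_head = ["er", "golgi"]            # channels the new task predicts; "golgi" has no pretrained counterpart

old_weight = torch.randn(len(old_head), 16)  # e.g. flattened 1x1-conv weights, one row per output channel
new_weight = torch.zeros(len(new_head), 16)  # freshly initialized head

for label in new_head:
    if label in old_head:
        old_index = old_head.index(label)
        new_index = new_head.index(label)
        new_weight[new_index] = old_weight[old_index]  # reuse the pretrained channel

# "er" is now initialized from the pretrained head; "golgi" keeps its fresh initialization.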

1 change: 1 addition & 0 deletions dacapo/store/conversion_hooks.py
@@ -21,6 +21,7 @@ def register_hierarchy_hooks(converter):
"""Central place to register type hierarchies for conversion."""

converter.register_hierarchy(TaskConfig, cls_fun)
converter.register_hierarchy(StartConfig, cls_fun)
converter.register_hierarchy(ArchitectureConfig, cls_fun)
converter.register_hierarchy(TrainerConfig, cls_fun)
converter.register_hierarchy(AugmentConfig, cls_fun)