
Pull request #18 (closed)

Mzouink wants to merge 14 commits.
@@ -5,7 +5,9 @@
 import numpy as np
 
 from typing import Dict, Any
+import logging
 
+logger = logging.getLogger(__file__)
 
 class ConcatArray(Array):
     """This is a wrapper around other `source_arrays` that concatenates
@@ -93,6 +95,7 @@ def num_channels(self):
         return len(self.channels)
 
     def __getitem__(self, roi: Roi) -> np.ndarray:
+        logger.info(f"Concat Array: Get Item {self.name} {roi}")
         default = (
             np.zeros_like(self.source_array[roi])
             if self.default_array is None
@@ -116,5 +119,6 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
             axis=0,
         )
         if concatenated.shape[0] == 1:
-            raise Exception(f"{concatenated.shape}, shapes")
+            logger.info(f"Concatenated array has only one channel: {self.name} {concatenated.shape}")
+            # raise Exception(f"{concatenated.shape}, shapes")
         return concatenated
@@ -41,7 +41,7 @@ def attrs(self):
 
     @property
     def axes(self):
-        return ["t", "z", "y", "x"][-self.dims :]
+        return ["c", "z", "y", "x"][-self.dims :]
 
     @property
     def dims(self) -> int:
@@ -35,7 +35,7 @@ def from_gp_array(cls, array: gp.Array):
             ((["b", "c"] if len(array.data.shape) == instance.dims + 2 else []))
             + (["c"] if len(array.data.shape) == instance.dims + 1 else [])
             + [
-                "t",
+                "c",
                 "z",
                 "y",
                 "x",
@@ -54,7 +54,7 @@ def axes(self):
                 f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
                 f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
             )
-        return ["t", "z", "y", "x"][-self.dims : :]
+        return ["c", "z", "y", "x"][-self.dims : :]
 
     @property
     def dims(self) -> int:
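The three hunks above all make the same rename: the default label for the leading axis of a volume changes from time-first ("t") to channel-first ("c"). A minimal sketch of how the slicing behaves (illustrative helper, not code from this PR):

```python
# Illustrative only: how the default axis labels are sliced by dimensionality.
def default_axes(dims: int) -> list:
    # A 3D volume keeps ["z", "y", "x"]; a 4D volume is now read as
    # channel-first (["c", "z", "y", "x"]) instead of time-first.
    return ["c", "z", "y", "x"][-dims:]

assert default_axes(3) == ["z", "y", "x"]
assert default_axes(4) == ["c", "z", "y", "x"]
```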
34 changes: 26 additions & 8 deletions dacapo/experiments/run.py
@@ -6,9 +6,10 @@
 from .validation_scores import ValidationScores
 from .starts import Start
 from .model import Model
 
+import logging
 import torch
 
+logger = logging.getLogger(__file__)
 
 class Run:
     name: str
@@ -53,14 +54,31 @@ def __init__(self, run_config):
             self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
         )
 
+        if run_config.start_config is None:
+            return
+        try:
+            from ..store import create_config_store
+
+            start_config_store = create_config_store()
+            starter_config = start_config_store.retrieve_run_config(run_config.start_config.run)
+        except Exception as e:
+            logger.error(f"could not load start config: {e}; the start run must be in the config store")
+            raise e
 
         # preloaded weights from previous run
-        self.start = (
-            Start(run_config.start_config)
-            if run_config.start_config is not None
-            else None
-        )
-        if self.start is not None:
-            self.start.initialize_weights(self.model)
+        if run_config.task_config.name == starter_config.task_config.name:
+            self.start = Start(run_config.start_config)
+        else:
+            # Match labels between the old and new head
+            if hasattr(run_config.task_config, "channels"):
+                old_head = starter_config.task_config.channels
+                new_head = run_config.task_config.channels
+                self.start = Start(run_config.start_config, old_head=old_head, new_head=new_head)
+            else:
+                logger.warning("Channel matching is not implemented for this task")
+                self.start = Start(run_config.start_config, remove_head=True)
+        self.start.initialize_weights(self.model)
 
     @staticmethod
     def get_validation_scores(run_config) -> ValidationScores:
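The new `__init__` logic chooses between three start modes, depending on how the new run relates to the run it starts from. Distilled as a sketch (a hypothetical helper with simplified names, not code from this PR):

```python
# Sketch of the three start modes selected above; `Start` is the class this PR
# extends, the import path is assumed.
from dacapo.experiments.starts import Start

def select_start(run_config, starter_config):
    if run_config.task_config.name == starter_config.task_config.name:
        # Same task: the checkpoint loads unchanged.
        return Start(run_config.start_config)
    if hasattr(run_config.task_config, "channels"):
        # Different task, but both sides define channel lists:
        # remap the head weights label by label.
        return Start(
            run_config.start_config,
            old_head=starter_config.task_config.channels,
            new_head=run_config.task_config.channels,
        )
    # No channel information: drop the old head and train it from scratch.
    return Start(run_config.start_config, remove_head=True)
```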
71 changes: 66 additions & 5 deletions dacapo/experiments/starts/start.py
@@ -3,21 +3,82 @@
 
 logger = logging.getLogger(__file__)
 
+# Example heads:
+# old_head = ["ecs", "plasma_membrane", "mito", "mito_membrane", "vesicle", "vesicle_membrane", "mvb", "mvb_membrane", "er", "er_membrane", "eres", "nucleus", "microtubules", "microtubules_out"]
+# new_head = ["mito", "nucleus", "ld", "ecs", "peroxisome"]
+head_keys = [
+    "prediction_head.weight",
+    "prediction_head.bias",
+    "chain.1.weight",
+    "chain.1.bias",
+]
+
+# Hack: if the label is mito_peroxisome or peroxisome, match it to mito
+mitos = ["mito_peroxisome", "peroxisome"]
+
+
+def match_heads(model, head_weights, old_head, new_head):
+    # copy each matching label's head weights from the old checkpoint
+    state_dict = model.state_dict()
+    for label in new_head:
+        old_label = "mito" if label in mitos else label
+        if old_label in old_head:
+            logger.warning(f"matching head for {label}")
+            old_index = old_head.index(old_label)
+            new_index = new_head.index(label)
+            for key in head_keys:
+                if key in state_dict:
+                    state_dict[key][new_index] = head_weights[key][old_index]
+            logger.warning(f"matched head for {label} with {old_label}")
+
+
 class Start(ABC):
-    def __init__(self, start_config):
+    def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
         self.run = start_config.run
         self.criterion = start_config.criterion
+        self.remove_head = remove_head
+        self.old_head = old_head
+        self.new_head = new_head
 
     def initialize_weights(self, model):
         from dacapo.store.create_store import create_weights_store
 
         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-        logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")
-
-        # load the model weights (taken from torch load_state_dict source)
+        logger.warning(f"loading weights from run {self.run}, criterion: {self.criterion}")
+
         try:
-            model.load_state_dict(weights.model)
+            if self.old_head and self.new_head:
+                try:
+                    self.load_model_using_head_matching(model, weights)
+                except RuntimeError as e:
+                    logger.error(f"ERROR starter matching head: {e}")
+                    self.load_model_using_head_removal(model, weights)
+            elif self.remove_head:
+                self.load_model_using_head_removal(model, weights)
+            else:
+                model.load_state_dict(weights.model)
         except RuntimeError as e:
-            logger.warning(e)
+            logger.warning(f"ERROR starter: {e}")
+
+    def load_model_using_head_removal(self, model, weights):
+        logger.warning(f"removing head from run {self.run}, criterion: {self.criterion}")
+        for key in head_keys:
+            weights.model.pop(key, None)
+        logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}")
+        model.load_state_dict(weights.model, strict=False)
+        logger.warning(f"loaded weights in non-strict mode from run {self.run}, criterion: {self.criterion}")
+
+    def load_model_using_head_matching(self, model, weights):
+        logger.warning(f"matching heads from run {self.run}, criterion: {self.criterion}")
+        logger.warning(f"old head: {self.old_head}")
+        logger.warning(f"new head: {self.new_head}")
+        head_weights = {key: weights.model[key] for key in head_keys}
+        for key in head_keys:
+            weights.model.pop(key, None)
+        model.load_state_dict(weights.model, strict=False)
+        match_heads(model, head_weights, self.old_head, self.new_head)
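What `match_heads` does, reduced to toy tensors: for each new label, find the matching old label (with peroxisome variants mapped back to "mito") and copy that channel's weights across. A self-contained sketch with stand-in labels and shapes, not the real checkpoint contents:

```python
# Toy version of the per-label copy performed by match_heads.
import torch

old_head = ["mito", "nucleus", "er"]
new_head = ["nucleus", "mito_peroxisome"]
old_weight = torch.tensor([0.0, 1.0, 2.0])  # one row per old label
new_weight = torch.zeros(len(new_head))

for new_index, label in enumerate(new_head):
    # peroxisome variants reuse the old "mito" channel, as in the hack above
    old_label = "mito" if label in ("mito_peroxisome", "peroxisome") else label
    if old_label in old_head:
        new_weight[new_index] = old_weight[old_head.index(old_label)]

assert new_weight.tolist() == [1.0, 0.0]  # "nucleus" copied; "mito_peroxisome" from "mito"
```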
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/distance_task.py
@@ -15,6 +15,7 @@ def __init__(self, task_config):
             channels=task_config.channels,
             scale_factor=task_config.scale_factor,
             mask_distances=task_config.mask_distances,
+            extra_conv=task_config.extra_conv,
         )
         self.loss = MSELoss()
         self.post_processor = ThresholdPostProcessor()
7 changes: 7 additions & 0 deletions dacapo/experiments/tasks/distance_task_config.py
@@ -46,3 +46,10 @@ class DistanceTaskConfig(TaskConfig):
             "is less than the distance to object boundary."
         },
     )
+
+    extra_conv: bool = attr.ib(
+        default=False,
+        metadata={
+            "help_text": "Whether or not to add an extra conv layer before the head"
+        },
+    )
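A hypothetical config using the new field. Only arguments visible in this diff are shown, the real class may require more, and the import path is assumed:

```python
from dacapo.experiments.tasks import DistanceTaskConfig  # assumed import path

task_config = DistanceTaskConfig(
    name="distances_extra_conv",   # placeholder name
    channels=["mito", "nucleus"],  # placeholder labels
    scale_factor=50.0,
    mask_distances=True,
    extra_conv=True,  # new: insert a 3x3(x3) conv before the 1x1(x1) head
)
```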
50 changes: 41 additions & 9 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -27,7 +27,7 @@ class DistancePredictor(Predictor):
     in the channels argument.
     """
 
-    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
+    def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool, extra_conv: bool):
         self.channels = channels
         self.norm = "tanh"
         self.dt_scale_factor = scale_factor
@@ -36,20 +36,52 @@ def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
         self.max_distance = 1 * scale_factor
         self.epsilon = 5e-2
         self.threshold = 0.8
+        self.extra_conv = extra_conv
+        self.extra_conv_dims = len(self.channels) * 2
 
     @property
     def embedding_dims(self):
         return len(self.channels)
 
     def create_model(self, architecture):
-        if architecture.dims == 2:
-            head = torch.nn.Conv2d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=1
-            )
-        elif architecture.dims == 3:
-            head = torch.nn.Conv3d(
-                architecture.num_out_channels, self.embedding_dims, kernel_size=1
-            )
+        if self.extra_conv:
+            if architecture.dims == 2:
+                head = torch.nn.Sequential(
+                    torch.nn.Conv2d(
+                        architecture.num_out_channels,
+                        self.extra_conv_dims,
+                        kernel_size=3,
+                        padding=1,
+                    ),
+                    torch.nn.Conv2d(
+                        self.extra_conv_dims,
+                        self.embedding_dims,
+                        kernel_size=1,
+                    ),
+                )
+            elif architecture.dims == 3:
+                head = torch.nn.Sequential(
+                    torch.nn.Conv3d(
+                        architecture.num_out_channels,
+                        self.extra_conv_dims,
+                        kernel_size=3,
+                        padding=1,
+                    ),
+                    torch.nn.Conv3d(
+                        self.extra_conv_dims,
+                        self.embedding_dims,
+                        kernel_size=1,
+                    ),
+                )
+        else:
+            if architecture.dims == 2:
+                head = torch.nn.Conv2d(
+                    architecture.num_out_channels, self.embedding_dims, kernel_size=1
+                )
+            elif architecture.dims == 3:
+                head = torch.nn.Conv3d(
+                    architecture.num_out_channels, self.embedding_dims, kernel_size=1
+                )
 
         return Model(architecture, head)
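The two head variants differ only in capacity, not in output shape. A standalone sketch for the 3D case with dummy channel counts (the DaCapo `Model` wrapper and architecture object are not shown):

```python
# Both heads map num_out_channels -> embedding_dims over the same spatial grid.
import torch

num_out_channels, embedding_dims = 12, 5
extra_conv_dims = embedding_dims * 2  # the diff uses len(channels) * 2

plain_head = torch.nn.Conv3d(num_out_channels, embedding_dims, kernel_size=1)
extra_head = torch.nn.Sequential(
    torch.nn.Conv3d(num_out_channels, extra_conv_dims, kernel_size=3, padding=1),
    torch.nn.Conv3d(extra_conv_dims, embedding_dims, kernel_size=1),
)

x = torch.randn(1, num_out_channels, 8, 8, 8)
# kernel_size=3 with padding=1 preserves the spatial shape, so both heads
# emit the same output size; only the capacity differs.
assert plain_head(x).shape == extra_head(x).shape == (1, embedding_dims, 8, 8, 8)
```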
43 changes: 29 additions & 14 deletions dacapo/experiments/trainers/gunpowder_trainer.py
@@ -42,10 +42,24 @@ def __init__(self, trainer_config):
         self.mask_integral_downsample_factor = 4
         self.clip_raw = trainer_config.clip_raw
 
+        # Testing out if calculating multiple times and multiplying is necessary
+        self.add_predictor_nodes_to_dataset = trainer_config.add_predictor_nodes_to_dataset
+        self.finetune_head_only = trainer_config.finetune_head_only
+
         self.scheduler = None
 
     def create_optimizer(self, model):
-        optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters())
+        if self.finetune_head_only:
+            logger.warning("Finetuning head only")
+            parameters = []
+            for name, param in model.named_parameters():
+                if "prediction_head" in name:
+                    parameters.append(param)
+                else:
+                    param.requires_grad = False
+        else:
+            parameters = model.parameters()
+        optimizer = torch.optim.RAdam(lr=self.learning_rate, params=parameters)
        self.scheduler = torch.optim.lr_scheduler.LinearLR(
             optimizer,
             start_factor=0.01,
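The freezing logic above has two effects: frozen parameters get `requires_grad = False`, and they are never handed to the optimizer. A toy version with made-up module names (only parameters whose name contains "prediction_head" stay trainable):

```python
# Self-contained sketch of the head-only finetuning setup.
import torch

model = torch.nn.Sequential()
model.add_module("backbone", torch.nn.Linear(4, 4))
model.add_module("prediction_head", torch.nn.Linear(4, 2))

parameters = []
for name, param in model.named_parameters():
    if "prediction_head" in name:
        parameters.append(param)
    else:
        param.requires_grad = False  # frozen: never passed to the optimizer

optimizer = torch.optim.RAdam(parameters, lr=1e-4)
assert all(not p.requires_grad for p in model.backbone.parameters())
assert len(optimizer.param_groups[0]["params"]) == 2  # head weight + bias
```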
@@ -146,13 +160,14 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
             for augment in self.augments:
                 dataset_source += augment.node(raw_key, gt_key, mask_key)
 
-            # Add predictor nodes to dataset_source
-            dataset_source += DaCapoTargetFilter(
-                task.predictor,
-                gt_key=gt_key,
-                weights_key=dataset_weight_key,
-                mask_key=mask_key,
-            )
+            if self.add_predictor_nodes_to_dataset:
+                # Add predictor nodes to dataset_source
+                dataset_source += DaCapoTargetFilter(
+                    task.predictor,
+                    gt_key=gt_key,
+                    weights_key=dataset_weight_key,
+                    mask_key=mask_key,
+                )
 
             dataset_sources.append(dataset_source)
         pipeline = tuple(dataset_sources) + gp.RandomProvider(weights)
@@ -162,11 +177,12 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
             task.predictor,
             gt_key=gt_key,
             target_key=target_key,
-            weights_key=datasets_weight_key,
+            weights_key=datasets_weight_key if self.add_predictor_nodes_to_dataset else weight_key,
             mask_key=mask_key,
         )
 
-        pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)
+        if self.add_predictor_nodes_to_dataset:
+            pipeline += Product(dataset_weight_key, datasets_weight_key, weight_key)
 
         # Trainer attributes:
         if self.num_data_fetchers > 1:
@@ -208,15 +224,13 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
     def iterate(self, num_iterations, model, optimizer, device):
         t_start_fetch = time.time()
 
-        logger.info("Starting iteration!")
-
         for iteration in range(self.iteration, self.iteration + num_iterations):
             raw, gt, target, weight, mask = self.next()
             logger.debug(
                 f"Trainer fetch batch took {time.time() - t_start_fetch} seconds"
             )
 
-            for param in model.parameters():
+            for param in model.parameters():  # TODO: get parameters from optimizer instead
                 param.grad = None
 
             t_start_prediction = time.time()
@@ -227,6 +241,7 @@ def iterate(self, num_iterations, model, optimizer, device):
                 torch.as_tensor(target[target.roi]).to(device).float(),
                 torch.as_tensor(weight[weight.roi]).to(device).float(),
             )
+
             loss.backward()
             optimizer.step()
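One way the TODO above could go: zero only the gradients of tensors the optimizer actually owns, which matters once `finetune_head_only` keeps frozen parameters out of the optimizer. A sketch with a toy model, not code from this PR:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.RAdam(model.parameters(), lr=1e-4)

# Zero gradients via the optimizer's own parameter groups.
for group in optimizer.param_groups:
    for param in group["params"]:
        param.grad = None

# Equivalent, and idiomatic in current PyTorch:
optimizer.zero_grad(set_to_none=True)
```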

@@ -337,4 +352,4 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         pass
 
     def can_train(self, datasets) -> bool:
-        return all([dataset.gt is not None for dataset in datasets])
\ No newline at end of file
+        return all([dataset.gt is not None for dataset in datasets])
10 changes: 10 additions & 0 deletions dacapo/experiments/trainers/gunpowder_trainer_config.py
@@ -29,3 +29,13 @@ class GunpowderTrainerConfig(TrainerConfig):
     )
     min_masked: Optional[float] = attr.ib(default=0.15)
     clip_raw: bool = attr.ib(default=True)
+
+    add_predictor_nodes_to_dataset: Optional[bool] = attr.ib(
+        default=True,
+        metadata={"help_text": "Whether to add a predictor node to dataset_source and apply the product of weights"},
+    )
+
+    finetune_head_only: Optional[bool] = attr.ib(
+        default=False,
+        metadata={"help_text": "Whether to fine-tune the head only or all layers"},
+    )
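A hypothetical trainer config exercising the two new flags. The import path and the remaining required fields are assumptions, following whatever the existing `GunpowderTrainerConfig` requires:

```python
from dacapo.experiments.trainers import GunpowderTrainerConfig  # assumed import path

trainer_config = GunpowderTrainerConfig(
    name="finetune_head",  # placeholder; other required fields omitted
    learning_rate=1e-5,
    add_predictor_nodes_to_dataset=False,  # skip per-dataset target/weight nodes
    finetune_head_only=True,  # freeze everything except the prediction head
)
```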