
Commit

Merge branch 'rhoadesj/hot_distance' into actions/black
rhoadesScholar authored Feb 11, 2024
2 parents 232047c + 53b57b6 commit 70169e2
Showing 28 changed files with 799 additions and 183 deletions.
10 changes: 6 additions & 4 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -49,7 +49,7 @@ def axes(self):
try:
return self._attributes["axes"]
except KeyError:
logger.debug(
logger.info(
"DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
@@ -58,7 +58,7 @@ def axes(self):

@property
def dims(self) -> int:
return self.voxel_size.dims
return len(self.data.shape)

@lazy_property.LazyProperty
def _daisy_array(self) -> funlib.persistence.Array:
@@ -81,7 +81,7 @@ def writable(self) -> bool:

@property
def dtype(self) -> Any:
return self.data.dtype
return self.data.dtype # TODO: why not use self._daisy_array.dtype?

@property
def num_channels(self) -> Optional[int]:
@@ -92,7 +92,7 @@ def spatial_axes(self) -> List[str]:
return [ax for ax in self.axes if ax not in set(["c", "b"])]

@property
def data(self) -> Any:
def data(self) -> Any: # TODO: why not use self._daisy_array.data?
zarr_container = zarr.open(str(self.file_name))
return zarr_container[self.dataset]

@@ -116,6 +116,7 @@ def create_from_array_identifier(
dtype,
write_size=None,
name=None,
overwrite=False,
):
"""
Create a new ZarrArray given an array identifier. It is assumed that
@@ -145,6 +146,7 @@
dtype,
num_channels=num_channels,
write_size=write_size,
delete=overwrite,
)
zarr_dataset = zarr_container[array_identifier.dataset]
zarr_dataset.attrs["offset"] = (
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/__init__.py
@@ -5,3 +5,4 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
46 changes: 46 additions & 0 deletions dacapo/experiments/tasks/evaluators/evaluator.py
@@ -63,6 +63,52 @@ def is_best(
else:
return getattr(score, criterion) < previous_best_score

def get_overall_best(self, dataset: "Dataset", criterion: str):
overall_best = None
if self.best_scores:
for _, parameter, _ in self.best_scores.keys():
score = self.best_scores[(dataset, parameter, criterion)]
if score is None:
overall_best = None
else:
_, current_parameter_score = score
if overall_best is None:
overall_best = current_parameter_score
else:
if current_parameter_score:
if self.higher_is_better(criterion):
if current_parameter_score > overall_best:
overall_best = current_parameter_score
else:
if current_parameter_score < overall_best:
overall_best = current_parameter_score
return overall_best

def get_overall_best_parameters(self, dataset: "Dataset", criterion: str):
overall_best = None
overall_best_parameters = None
if self.best_scores:
for _, parameter, _ in self.best_scores.keys():
score = self.best_scores[(dataset, parameter, criterion)]
if score is None:
overall_best = None
else:
_, current_parameter_score = score
if overall_best is None:
overall_best = current_parameter_score
overall_best_parameters = parameter
else:
if current_parameter_score:
if self.higher_is_better(criterion):
if current_parameter_score > overall_best:
overall_best = current_parameter_score
overall_best_parameters = parameter
else:
if current_parameter_score < overall_best:
overall_best = current_parameter_score
overall_best_parameters = parameter
return overall_best_parameters

def set_best(self, validation_scores: "ValidationScores") -> None:
"""
Find the best iteration for each dataset/post_processing_parameter/criterion
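A brief, hypothetical usage sketch of the new get_overall_best and get_overall_best_parameters helpers (it assumes set_best has already populated best_scores; the evaluator, dataset, and criterion names below are placeholders, not part of this commit):

    best_score = evaluator.get_overall_best(val_dataset, criterion="voi")
    best_parameters = evaluator.get_overall_best_parameters(val_dataset, criterion="voi")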
@@ -6,10 +6,11 @@

@attr.s
class InstanceEvaluationScores(EvaluationScores):
criteria = ["voi_split", "voi_merge", "voi"]
criteria = ["voi_split", "voi_merge", "voi", "avg_iou"]

voi_split: float = attr.ib(default=float("nan"))
voi_merge: float = attr.ib(default=float("nan"))
avg_iou: float = attr.ib(default=float("nan"))

@property
def voi(self):
@@ -21,6 +22,7 @@ def higher_is_better(criterion: str) -> bool:
"voi_split": False,
"voi_merge": False,
"voi": False,
"avg_iou": True,
}
return mapping[criterion]

@@ -30,6 +32,7 @@ def bounds(criterion: str) -> Tuple[float, float]:
"voi_split": (0, 1),
"voi_merge": (0, 1),
"voi": (0, 1),
"avg_iou": (0, None),
}
return mapping[criterion]

14 changes: 12 additions & 2 deletions dacapo/experiments/tasks/evaluators/instance_evaluator.py
@@ -3,7 +3,7 @@
from .evaluator import Evaluator
from .instance_evaluation_scores import InstanceEvaluationScores

from funlib.evaluate import rand_voi
from funlib.evaluate import rand_voi, detection_scores

import numpy as np

@@ -16,9 +16,19 @@ def evaluate(self, output_array_identifier, evaluation_array):
evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64)
output_data = output_array[output_array.roi].astype(np.uint64)
results = rand_voi(evaluation_data, output_data)
results.update(
detection_scores(
evaluation_data,
output_data,
matching_score="iou",
voxel_size=output_array.voxel_size,
)
)

return InstanceEvaluationScores(
voi_merge=results["voi_merge"], voi_split=results["voi_split"]
voi_merge=results["voi_merge"],
voi_split=results["voi_split"],
avg_iou=results["avg_iou"],
)

@property
25 changes: 25 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task.py
@@ -0,0 +1,25 @@
from .evaluators import BinarySegmentationEvaluator
from .losses import HotDistanceLoss
from .post_processors import ThresholdPostProcessor
from .predictors import HotDistancePredictor
from .task import Task


class HotDistanceTask(Task):
"""This is just a Hot Distance Task that combine Binary and distance prediction."""

def __init__(self, task_config):
"""Create a `HotDistanceTask` from a `HotDistanceTaskConfig`."""

self.predictor = HotDistancePredictor(
channels=task_config.channels,
scale_factor=task_config.scale_factor,
mask_distances=task_config.mask_distances,
)
self.loss = HotDistanceLoss()
self.post_processor = ThresholdPostProcessor()
self.evaluator = BinarySegmentationEvaluator(
clip_distance=task_config.clip_distance,
tol_distance=task_config.tol_distance,
channels=task_config.channels,
)
47 changes: 47 additions & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
@@ -0,0 +1,47 @@
import attr

from .hot_distance_task import HotDistanceTask
from .task_config import TaskConfig

from typing import List

@attr.s
class HotDistanceTaskConfig(TaskConfig):
"""This is a Hot Distance task config used for generating and
evaluating signed distance transforms as a way of generating
segmentations.
The advantage of generating distance transforms over regular
affinities is that you get a denser signal: a single misclassified
pixel in an affinity prediction can merge two otherwise very
distinct objects, which cannot happen with distances.
"""

task_type = HotDistanceTask

channels: List[str] = attr.ib(metadata={"help_text": "A list of channel names."})
clip_distance: float = attr.ib(
metadata={
"help_text": "Maximum distance to consider for false positive/negatives."
},
)
tol_distance: float = attr.ib(
metadata={
"help_text": "Tolerance distance for counting false positives/negatives"
},
)
scale_factor: float = attr.ib(
default=1,
metadata={
"help_text": "The amount by which to scale distances before applying "
"a tanh normalization."
},
)
mask_distances: bool = attr.ib(
default=False,
metadata={
"help_text": "Whether or not to mask out regions where the true distance to "
"object boundary cannot be known. This is anywhere that the distance to crop boundary "
"is less than the distance to object boundary."
},
)
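For orientation, a minimal sketch of how the new task might be configured; every value below is a made-up placeholder, the name field is assumed to be inherited from the base TaskConfig, and only the fields defined above come from this commit:

    from dacapo.experiments.tasks import HotDistanceTaskConfig, HotDistanceTask

    config = HotDistanceTaskConfig(
        name="example_hot_distance",  # assumed base-class field
        channels=["membrane"],  # placeholder channel name
        clip_distance=40.0,
        tol_distance=40.0,
        scale_factor=50.0,
        mask_distances=True,
    )
    task = HotDistanceTask(config)  # wires up predictor, loss, post-processor, and evaluator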
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/losses/__init__.py
@@ -2,3 +2,4 @@
from .mse_loss import MSELoss # noqa
from .loss import Loss # noqa
from .affinities_loss import AffinitiesLoss # noqa
from .hot_distance_loss import HotDistanceLoss # noqa
30 changes: 30 additions & 0 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
@@ -0,0 +1,30 @@
from .loss import Loss
import torch


# HotDistance is used for predicting hot and distance maps at the same time.
# The first half of the channels are the hot maps, the second half are the distance maps.
# The loss is the sum of the BCELoss for the hot maps and the MSELoss for the distance maps.
# The model should predict twice as many channels as the target.
class HotDistanceLoss(Loss):
def compute(self, prediction, target, weight):
target_hot, target_distance = self.split(target)
prediction_hot, prediction_distance = self.split(prediction)
weight_hot, weight_distance = self.split(weight)
return self.hot_loss(
prediction_hot, target_hot, weight_hot
) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
loss = torch.nn.BCEWithLogitsLoss(reduction="none")
return torch.mean(loss(prediction, target) * weight)

def distance_loss(self, prediction, target, weight):
loss = torch.nn.MSELoss()
return loss(prediction * weight, target * weight)

def split(self, x):
# x.shape[0] is the batch size and x.shape[1] is the number of channels.
assert x.shape[1] % 2 == 0, f"Channel dimension of {x.shape} must be even to be split into hot and distance."
mid = x.shape[1] // 2
return torch.split(x, mid, dim=1)
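To make the channel convention described at the top of this new file concrete, here is a small usage sketch; the shapes and tensors are invented for illustration only:

    import torch

    from dacapo.experiments.tasks.losses import HotDistanceLoss

    loss_fn = HotDistanceLoss()
    # 2 samples, 4 channels: channels 0-1 are the hot maps, channels 2-3 are the distance maps.
    prediction = torch.randn(2, 4, 32, 32)
    target = torch.cat(
        [torch.randint(0, 2, (2, 2, 32, 32)).float(), torch.randn(2, 2, 32, 32)], dim=1
    )
    weight = torch.ones_like(target)
    value = loss_fn.compute(prediction, target, weight)  # BCE on hot maps + weighted MSE on distances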
@@ -19,14 +19,15 @@ def set_prediction(self, prediction_array_identifier):
prediction_array_identifier
)

def process(self, parameters, output_array_identifier):
def process(self, parameters, output_array_identifier, overwrite: bool = False):
output_array = ZarrArray.create_from_array_identifier(
output_array_identifier,
[dim for dim in self.prediction_array.axes if dim != "c"],
self.prediction_array.roi,
None,
self.prediction_array.voxel_size,
np.uint8,
overwrite=overwrite,
)

output_array[self.prediction_array.roi] = np.argmax(
@@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]:
def set_prediction(self, prediction_array):
pass

def process(self, parameters, output_array_identifier):
def process(self, parameters, output_array_identifier, overwrite: bool = False):
# store some dummy data
f = zarr.open(str(output_array_identifier.container), "a")
f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size
2 changes: 2 additions & 0 deletions dacapo/experiments/tasks/post_processors/post_processor.py
@@ -33,6 +33,8 @@ def process(
self,
parameters: "PostProcessorParameters",
output_array_identifier: "LocalArrayIdentifier",
overwrite: "bool",
blockwise: "bool",
) -> "Array":
"""Convert predictions into the final output."""
pass
@@ -28,6 +28,7 @@ def process(
self,
parameters: "PostProcessorParameters",
output_array_identifier: "LocalArrayIdentifier",
overwrite: bool = False,
) -> ZarrArray:
# TODO: Investigate Liskov substitution princple and whether it is a problem here
# OOP theory states the super class should always be replaceable with its subclasses
@@ -47,6 +48,7 @@
self.prediction_array.num_channels,
self.prediction_array.voxel_size,
np.uint8,
overwrite=overwrite,
)

output_array[self.prediction_array.roi] = (
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/predictors/__init__.py
@@ -3,3 +3,4 @@
from .one_hot_predictor import OneHotPredictor # noqa
from .predictor import Predictor # noqa
from .affinities_predictor import AffinitiesPredictor # noqa
from .hot_distance_predictor import HotDistancePredictor # noqa