Merge branch 'hot_distance' into rhoadesj/hot_distance
rhoadesScholar authored Feb 9, 2024
2 parents 448f766 + cd4077d commit 53b57b6
Showing 28 changed files with 510 additions and 254 deletions.
2 changes: 1 addition & 1 deletion dacapo/cli.py
@@ -42,7 +42,7 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
-    "-r", "--run-name", required=True, type=str, help="The name of the run to apply."
+    "-r", "--run_name", required=True, type=str, help="The name of the run to use."
)
@click.option(
    "-ic",
10 changes: 6 additions & 4 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
Expand Up @@ -49,7 +49,7 @@ def axes(self):
try:
return self._attributes["axes"]
except KeyError:
logger.debug(
logger.info(
"DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
@@ -58,7 +58,7 @@ def axes(self):

    @property
    def dims(self) -> int:
-        return self.voxel_size.dims
+        return len(self.data.shape)

    @lazy_property.LazyProperty
    def _daisy_array(self) -> funlib.persistence.Array:
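Note on the `dims` change above: `voxel_size.dims` presumably counts only the spatial dimensions carried by the voxel size, while `len(self.data.shape)` counts every axis of the stored dataset, including any channel axis. A minimal standalone sketch of the difference (illustrative numpy values, not DaCapo code):

    import numpy as np

    data = np.zeros((2, 10, 20, 30))  # hypothetical dataset with axes ("c", "z", "y", "x")
    voxel_size_dims = 3               # a 3D voxel size has three (spatial) dims

    old_dims = voxel_size_dims        # previous behavior: spatial dims only -> 3
    new_dims = len(data.shape)        # new behavior: every dataset axis -> 4
    print(old_dims, new_dims)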
@@ -81,7 +81,7 @@ def writable(self) -> bool:

    @property
    def dtype(self) -> Any:
-        return self.data.dtype
+        return self.data.dtype # TODO: why not use self._daisy_array.dtype?

    @property
    def num_channels(self) -> Optional[int]:
@@ -92,7 +92,7 @@ def spatial_axes(self) -> List[str]:
        return [ax for ax in self.axes if ax not in set(["c", "b"])]

    @property
-    def data(self) -> Any:
+    def data(self) -> Any: # TODO: why not use self._daisy_array.data?
        zarr_container = zarr.open(str(self.file_name))
        return zarr_container[self.dataset]

@@ -116,6 +116,7 @@ def create_from_array_identifier(
        dtype,
        write_size=None,
        name=None,
+        overwrite=False,
    ):
        """
        Create a new ZarrArray given an array identifier. It is assumed that
@@ -145,6 +146,7 @@ def create_from_array_identifier(
            dtype,
            num_channels=num_channels,
            write_size=write_size,
+            delete=overwrite,
        )
        zarr_dataset = zarr_container[array_identifier.dataset]
        zarr_dataset.attrs["offset"] = (
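The new `overwrite` argument is forwarded as `delete=overwrite` when the zarr dataset is prepared, presumably so an existing dataset at the same location is replaced instead of raising. A rough zarr-level analogy (plain zarr v2 API with a hypothetical container path, not the DaCapo call path):

    import numpy as np
    import zarr

    container = zarr.open("example.zarr", mode="a")  # hypothetical container path
    container.create_dataset("volumes/pred", data=np.zeros((8, 8, 8), dtype=np.uint8))
    # Re-creating the same dataset raises unless the existing one may be replaced:
    container.create_dataset(
        "volumes/pred",
        data=np.ones((8, 8, 8), dtype=np.uint8),
        overwrite=True,  # analogous role to overwrite=True -> delete=overwrite above
    )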
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/__init__.py
@@ -5,3 +5,4 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
+from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
46 changes: 46 additions & 0 deletions dacapo/experiments/tasks/evaluators/evaluator.py
@@ -63,6 +63,52 @@ def is_best(
        else:
            return getattr(score, criterion) < previous_best_score

+    def get_overall_best(self, dataset: "Dataset", criterion: str):
+        overall_best = None
+        if self.best_scores:
+            for _, parameter, _ in self.best_scores.keys():
+                score = self.best_scores[(dataset, parameter, criterion)]
+                if score is None:
+                    overall_best = None
+                else:
+                    _, current_parameter_score = score
+                    if overall_best is None:
+                        overall_best = current_parameter_score
+                    else:
+                        if current_parameter_score:
+                            if self.higher_is_better(criterion):
+                                if current_parameter_score > overall_best:
+                                    overall_best = current_parameter_score
+                            else:
+                                if current_parameter_score < overall_best:
+                                    overall_best = current_parameter_score
+        return overall_best
+
+    def get_overall_best_parameters(self, dataset: "Dataset", criterion: str):
+        overall_best = None
+        overall_best_parameters = None
+        if self.best_scores:
+            for _, parameter, _ in self.best_scores.keys():
+                score = self.best_scores[(dataset, parameter, criterion)]
+                if score is None:
+                    overall_best = None
+                else:
+                    _, current_parameter_score = score
+                    if overall_best is None:
+                        overall_best = current_parameter_score
+                        overall_best_parameters = parameter
+                    else:
+                        if current_parameter_score:
+                            if self.higher_is_better(criterion):
+                                if current_parameter_score > overall_best:
+                                    overall_best = current_parameter_score
+                                    overall_best_parameters = parameter
+                            else:
+                                if current_parameter_score < overall_best:
+                                    overall_best = current_parameter_score
+                                    overall_best_parameters = parameter
+        return overall_best_parameters
+
    def set_best(self, validation_scores: "ValidationScores") -> None:
        """
        Find the best iteration for each dataset/post_processing_parameter/criterion
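For reference, `best_scores` is keyed by `(dataset, parameter, criterion)` tuples and stores pairs whose second element is the score (the first appears to be the iteration), and the new helpers scan those entries to pick the best post-processing parameters for a criterion. A simplified, standalone re-expression of that selection logic (made-up values; parameters are plain strings here rather than PostProcessorParameters objects):

    # Standalone re-expression (not DaCapo code) of the parameter selection above.
    best_scores = {
        ("val_ds", "threshold=0.3", "voi"): (1000, 0.80),
        ("val_ds", "threshold=0.5", "voi"): (2000, 0.65),
        ("val_ds", "threshold=0.7", "voi"): (1500, 0.72),
    }
    higher_is_better = False  # voi: lower is better

    def overall_best_parameters(dataset, criterion):
        best_score, best_parameters = None, None
        for (ds, parameter, crit), (_, score) in best_scores.items():
            if ds != dataset or crit != criterion:
                continue
            if best_score is None or (
                score > best_score if higher_is_better else score < best_score
            ):
                best_score, best_parameters = score, parameter
        return best_parameters

    print(overall_best_parameters("val_ds", "voi"))  # -> threshold=0.5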
@@ -6,10 +6,11 @@

@attr.s
class InstanceEvaluationScores(EvaluationScores):
-    criteria = ["voi_split", "voi_merge", "voi"]
+    criteria = ["voi_split", "voi_merge", "voi", "avg_iou"]

    voi_split: float = attr.ib(default=float("nan"))
    voi_merge: float = attr.ib(default=float("nan"))
+    avg_iou: float = attr.ib(default=float("nan"))

    @property
    def voi(self):
@@ -21,6 +22,7 @@ def higher_is_better(criterion: str) -> bool:
            "voi_split": False,
            "voi_merge": False,
            "voi": False,
+            "avg_iou": True,
        }
        return mapping[criterion]

@@ -30,6 +32,7 @@ def bounds(criterion: str) -> Tuple[float, float]:
            "voi_split": (0, 1),
            "voi_merge": (0, 1),
            "voi": (0, 1),
+            "avg_iou": (0, None),
        }
        return mapping[criterion]

14 changes: 12 additions & 2 deletions dacapo/experiments/tasks/evaluators/instance_evaluator.py
@@ -3,7 +3,7 @@
from .evaluator import Evaluator
from .instance_evaluation_scores import InstanceEvaluationScores

-from funlib.evaluate import rand_voi
+from funlib.evaluate import rand_voi, detection_scores

import numpy as np

@@ -16,9 +16,19 @@ def evaluate(self, output_array_identifier, evaluation_array):
        evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64)
        output_data = output_array[output_array.roi].astype(np.uint64)
        results = rand_voi(evaluation_data, output_data)
+        results.update(
+            detection_scores(
+                evaluation_data,
+                output_data,
+                matching_score="iou",
+                voxel_size=output_array.voxel_size,
+            )
+        )

        return InstanceEvaluationScores(
-            voi_merge=results["voi_merge"], voi_split=results["voi_split"]
+            voi_merge=results["voi_merge"],
+            voi_split=results["voi_split"],
+            avg_iou=results["avg_iou"],
        )

    @property
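The `avg_iou` value comes from `funlib.evaluate.detection_scores(..., matching_score="iou")`; it is presumably the mean intersection-over-union of matched ground-truth/predicted objects. A standalone sketch of the per-pair IoU computation (plain numpy, illustrative values only, not funlib code):

    import numpy as np

    gt = np.zeros((10, 10), dtype=bool)
    pred = np.zeros((10, 10), dtype=bool)
    gt[2:7, 2:7] = True    # 25 voxels
    pred[4:9, 4:9] = True  # 25 voxels, overlapping gt on a 3x3 patch = 9 voxels

    intersection = np.logical_and(gt, pred).sum()
    union = np.logical_or(gt, pred).sum()
    iou = intersection / union  # 9 / 41 ~= 0.22
    print(iou)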
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/hot_distance_task_config.py
@@ -5,7 +5,7 @@

from typing import List


@attr.s
class HotDistanceTaskConfig(TaskConfig):
    """This is a Hot Distance task config used for generating and
    evaluating signed distance transforms as a way of generating
21 changes: 11 additions & 10 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
@@ -14,16 +14,17 @@ def compute(self, prediction, target, weight):
        return self.hot_loss(
            prediction_hot, target_hot, weight_hot
        ) + self.distance_loss(prediction_distance, target_distance, weight_distance)

    def hot_loss(self, prediction, target, weight):
-        return torch.nn.BCELoss().forward(prediction * weight, target * weight)
+        loss = torch.nn.BCEWithLogitsLoss(reduction='none')
+        return torch.mean(loss(prediction, target) * weight)

    def distance_loss(self, prediction, target, weight):
-        return torch.nn.MSELoss().forward(prediction * weight, target * weight)
+        loss = torch.nn.MSELoss()
+        return loss(prediction * weight, target * weight)

    def split(self, x):
-        assert (
-            x.shape[0] % 2 == 0
-        ), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
-        mid = x.shape[0] // 2
-        return x[:mid], x[-mid:]
+        # Shape[0] is the batch size and Shape[1] is the number of channels.
+        assert x.shape[1] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
+        mid = x.shape[1] // 2
+        return torch.split(x, mid, dim=1)
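A standalone sketch (not the DaCapo class itself) of how the pieces above fit together: prediction, target, and weight stack one-hot channels and distance channels along dim 1, `split` cuts them at the channel midpoint, the one-hot half gets a weighted BCE-with-logits loss, and the distance half a weighted MSE loss. Shapes below are made up for illustration:

    import torch

    batch, channels, depth = 2, 4, 8  # 2 one-hot channels + 2 distance channels
    prediction = torch.randn(batch, channels, depth)  # raw logits for the hot half
    target = torch.rand(batch, channels, depth)
    weight = torch.ones(batch, channels, depth)

    def split(x):
        mid = x.shape[1] // 2  # dim 0 is batch, dim 1 is channels
        return torch.split(x, mid, dim=1)

    pred_hot, pred_dist = split(prediction)
    target_hot, target_dist = split(target)
    weight_hot, weight_dist = split(weight)

    hot_loss = torch.mean(
        torch.nn.BCEWithLogitsLoss(reduction="none")(pred_hot, target_hot) * weight_hot
    )
    distance_loss = torch.nn.MSELoss()(pred_dist * weight_dist, target_dist * weight_dist)
    print((hot_loss + distance_loss).item())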
@@ -19,14 +19,15 @@ def set_prediction(self, prediction_array_identifier):
            prediction_array_identifier
        )

-    def process(self, parameters, output_array_identifier):
+    def process(self, parameters, output_array_identifier, overwrite: bool = False):
        output_array = ZarrArray.create_from_array_identifier(
            output_array_identifier,
            [dim for dim in self.prediction_array.axes if dim != "c"],
            self.prediction_array.roi,
            None,
            self.prediction_array.voxel_size,
            np.uint8,
+            overwrite=overwrite,
        )

        output_array[self.prediction_array.roi] = np.argmax(
@@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]:
    def set_prediction(self, prediction_array):
        pass

-    def process(self, parameters, output_array_identifier):
+    def process(self, parameters, output_array_identifier, overwrite: bool = False):
        # store some dummy data
        f = zarr.open(str(output_array_identifier.container), "a")
        f[output_array_identifier.dataset] = np.ones((10, 10, 10)) * parameters.min_size
2 changes: 2 additions & 0 deletions dacapo/experiments/tasks/post_processors/post_processor.py
@@ -33,6 +33,8 @@ def process(
        self,
        parameters: "PostProcessorParameters",
        output_array_identifier: "LocalArrayIdentifier",
+        overwrite: "bool",
+        blockwise: "bool",
    ) -> "Array":
        """Convert predictions into the final output."""
        pass
@@ -28,6 +28,7 @@ def process(
        self,
        parameters: "PostProcessorParameters",
        output_array_identifier: "LocalArrayIdentifier",
+        overwrite: bool = False,
    ) -> ZarrArray:
        # TODO: Investigate Liskov substitution princple and whether it is a problem here
        # OOP theory states the super class should always be replaceable with its subclasses
@@ -47,6 +48,7 @@ def process(
            self.prediction_array.num_channels,
            self.prediction_array.voxel_size,
            np.uint8,
+            overwrite=overwrite,
        )

        output_array[self.prediction_array.roi] = (
103 changes: 65 additions & 38 deletions dacapo/experiments/tasks/post_processors/watershed_post_processor.py
@@ -24,48 +24,75 @@ def enumerate_parameters(self):
        """Enumerate all possible parameters of this post-processor. Should
        return instances of ``PostProcessorParameters``."""

-        for i, bias in enumerate([0.1, 0.5, 0.9]):
+        for i, bias in enumerate(
+            [0.1, 0.3, 0.5, 0.7, 0.9]
+        ):  # TODO: add this to the config
            yield WatershedPostProcessorParameters(id=i, bias=bias)

    def set_prediction(self, prediction_array_identifier):
        self.prediction_array = ZarrArray.open_from_array_identifier(
            prediction_array_identifier
        )

-    def process(self, parameters, output_array_identifier):
-        output_array = ZarrArray.create_from_array_identifier(
-            output_array_identifier,
-            [axis for axis in self.prediction_array.axes if axis != "c"],
-            self.prediction_array.roi,
-            None,
-            self.prediction_array.voxel_size,
-            np.uint64,
-        )
-        # if a previous segmentation is provided, it must have a "grid graph"
-        # in its metadata.
-        pred_data = self.prediction_array[self.prediction_array.roi]
-        affs = pred_data[: len(self.offsets)]
-        segmentation = mws.agglom(
-            affs - 0.5,
-            self.offsets,
-        )
-        # filter fragments
-        average_affs = np.mean(affs, axis=0)
-
-        filtered_fragments = []
-
-        fragment_ids = np.unique(segmentation)
-
-        for fragment, mean in zip(
-            fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids)
-        ):
-            if mean < 0.5:
-                filtered_fragments.append(fragment)
-
-        filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype)
-        replace = np.zeros_like(filtered_fragments)
-        segmentation = npi.remap(segmentation, filtered_fragments, replace)
-
-        output_array[self.prediction_array.roi] = segmentation
-
-        return output_array
+    def process(
+        self,
+        parameters,
+        output_array_identifier,
+        overwrite: bool = False,
+        blockwise: bool = False,
+    ):  # TODO: will probably break with large arrays...
+        if not blockwise:
+            output_array = ZarrArray.create_from_array_identifier(
+                output_array_identifier,
+                [axis for axis in self.prediction_array.axes if axis != "c"],
+                self.prediction_array.roi,
+                None,
+                self.prediction_array.voxel_size,
+                np.uint64,
+                overwrite=overwrite,
+            )
+            # if a previous segmentation is provided, it must have a "grid graph"
+            # in its metadata.
+            # pred_data = self.prediction_array[self.prediction_array.roi]
+            # affs = pred_data[: len(self.offsets)].astype(
+            #     np.float64
+            # )  # TODO: shouldn't need to be float64
+            affs = self.prediction_array[self.prediction_array.roi][: len(self.offsets)]
+            if affs.dtype == np.uint8:
+                affs = affs.astype(np.float64) / 255.0
+            else:
+                affs = affs.astype(np.float64)
+            segmentation = mws.agglom(
+                affs - parameters.bias,
+                self.offsets,
+            )
+            # filter fragments
+            average_affs = np.mean(affs, axis=0)
+
+            filtered_fragments = []
+
+            fragment_ids = np.unique(segmentation)
+
+            for fragment, mean in zip(
+                fragment_ids,
+                measurements.mean(average_affs, segmentation, fragment_ids),
+            ):
+                if mean < parameters.bias:
+                    filtered_fragments.append(fragment)
+
+            filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype)
+            replace = np.zeros_like(filtered_fragments)
+
+            # DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input
+            if filtered_fragments.size > 0:
+                segmentation = npi.remap(
+                    segmentation.flatten(), filtered_fragments, replace
+                ).reshape(segmentation.shape)

+            output_array[self.prediction_array.roi] = segmentation
+
+            return output_array
+        else:
+            raise NotImplementedError(
+                "Blockwise processing not yet implemented."
+            )  # TODO: add rusty mws
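Two notes on the rewritten `process` above, with a standalone sketch of the second. Subtracting `parameters.bias` before `mws.agglom` follows the usual mutex-watershed convention that positive values are attractive and negative values repulsive, so a larger bias should produce more splits. The fragment filtering that follows then discards fragments whose mean affinity falls below the same bias. A plain numpy/scipy illustration of that filtering step (no mwatershed required; `np.isin` stands in for the `npi.remap` call):

    import numpy as np
    from scipy import ndimage

    bias = 0.5
    average_affs = np.array([[0.9, 0.9, 0.2, 0.2],
                             [0.9, 0.9, 0.2, 0.2]])
    segmentation = np.array([[1, 1, 2, 2],
                             [1, 1, 2, 2]], dtype=np.uint64)

    fragment_ids = np.unique(segmentation)
    means = ndimage.mean(average_affs, segmentation, fragment_ids)  # per-fragment mean affinity
    filtered = fragment_ids[np.asarray(means) < bias]               # fragments to discard
    segmentation[np.isin(segmentation, filtered)] = 0               # same effect as the npi.remap call
    print(segmentation)  # fragment 2 (mean 0.2 < bias) is set to background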
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/predictors/__init__.py
@@ -3,3 +3,4 @@
from .one_hot_predictor import OneHotPredictor # noqa
from .predictor import Predictor # noqa
from .affinities_predictor import AffinitiesPredictor # noqa
+from .hot_distance_predictor import HotDistancePredictor # noqa
(The remaining changed files in this commit were not loaded in this view.)
