Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format Python code with psf/black push #52

Merged
merged 2 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .one_hot_task_config import OneHotTaskConfig, OneHotTask # noqa
from .pretrained_task_config import PretrainedTaskConfig, PretrainedTask # noqa
from .affinities_task_config import AffinitiesTaskConfig, AffinitiesTask # noqa
from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
from .hot_distance_task_config import HotDistanceTaskConfig, HotDistanceTask # noqa
1 change: 1 addition & 0 deletions dacapo/experiments/tasks/hot_distance_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import List


@attr.s
class HotDistanceTaskConfig(TaskConfig):
"""This is a Hot Distance task config used for generating and
Expand Down
16 changes: 9 additions & 7 deletions dacapo/experiments/tasks/losses/hot_distance_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@ def compute(self, prediction, target, weight):
return self.hot_loss(
prediction_hot, target_hot, weight_hot
) + self.distance_loss(prediction_distance, target_distance, weight_distance)

def hot_loss(self, prediction, target, weight):
loss = torch.nn.BCEWithLogitsLoss(reduction='none')
return torch.mean(loss(prediction , target) * weight)
loss = torch.nn.BCEWithLogitsLoss(reduction="none")
return torch.mean(loss(prediction, target) * weight)

def distance_loss(self, prediction, target, weight):
loss = torch.nn.MSELoss()
return loss(prediction * weight, target * weight)

def split(self, x):
# Shape[0] is the batch size and Shape[1] is the number of channels.
assert x.shape[1] % 2 == 0, f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
assert (
x.shape[1] % 2 == 0
), f"First dimension (Channels) of target {x.shape} must be even to be splitted in hot and distance."
mid = x.shape[1] // 2
return torch.split(x,mid,dim=1)
return torch.split(x, mid, dim=1)
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/predictors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from .one_hot_predictor import OneHotPredictor # noqa
from .predictor import Predictor # noqa
from .affinities_predictor import AffinitiesPredictor # noqa
from .hot_distance_predictor import HotDistancePredictor # noqa
from .hot_distance_predictor import HotDistancePredictor # noqa
8 changes: 6 additions & 2 deletions dacapo/experiments/tasks/predictors/hot_distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
2,
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[mask[target.roi]],
moving_counts=None if moving_class_counts is None else moving_class_counts[: self.classes],
moving_counts=None
if moving_class_counts is None
else moving_class_counts[: self.classes],
)

if self.mask_distances:
Expand All @@ -95,7 +97,9 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
2,
slab=tuple(1 if c == "c" else -1 for c in gt.axes),
masks=[mask[target.roi], distance_mask],
moving_counts=None if moving_class_counts is None else moving_class_counts[-self.classes :],
moving_counts=None
if moving_class_counts is None
else moving_class_counts[-self.classes :],
)

weights = np.concatenate((one_hot_weights, distance_weights))
Expand Down
2 changes: 1 addition & 1 deletion dacapo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def train_run(

weights_store.retrieve_weights(run, iteration=trained_until)

elif latest_weights_iteration > trained_until:
elif latest_weights_iteration > trained_until:
logger.warn(
f"Found weights for iteration {latest_weights_iteration}, but "
f"run {run.name} was only trained until {trained_until}. "
Expand Down
Loading