Skip to content

Commit

Permalink
[losses] Add 'spacing' option to flow field loss modules
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Dec 14, 2023
1 parent 41bcc55 commit c8c1626
Showing 1 changed file with 22 additions and 27 deletions.
49 changes: 22 additions & 27 deletions src/deepali/losses/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

from typing import Optional, Union

import torch
from torch import Tensor

from deepali.core.typing import ScalarOrTuple, Shape
from deepali.core.typing import Array, Scalar, ScalarOrTuple

from . import functional as L
from .base import DisplacementLoss
Expand All @@ -20,6 +19,7 @@ def __init__(
self,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple] = None,
reduction: str = "mean",
):
Expand All @@ -28,31 +28,27 @@ def __init__(
Args:
mode: Method used to approximate :func:`flow_derivatives()`.
sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
spacing: Spacing between grid elements. Should be given in the units of the flow vectors.
By default, flow vectors with respect to normalized grid coordinates are assumed.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
reduction: Operation to use for reducing spatially distributed loss values.
"""
super().__init__()
self.mode = mode
self.sigma = sigma
self.spacing = spacing
self.stride = stride
self.reduction = reduction

def _spacing(self, u_shape: Shape) -> Optional[Tensor]:
ndim = len(u_shape)
if ndim < 3:
raise ValueError(f"{type(self).__name__}.forward() 'u' must be at least 3-dimensional")
if ndim == 3:
return None
size = torch.tensor(u_shape[-1:1:-1], dtype=torch.float, device=torch.device("cpu"))
return 2 / (size - 1)

def extra_repr(self) -> str:
args = []
if self.mode:
args.append(f"mode={self.mode!r}")
if self.sigma:
args.append(f"sigma={self.sigma!r}")
if self.spacing:
args.append(f"spacing={self.spacing!r}")
if self.stride:
args.append(f"stride={self.stride!r}")
args.append(f"reduction={self.reduction!r}")
Expand All @@ -68,6 +64,7 @@ def __init__(
q: Optional[Union[int, float]] = 1,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple] = None,
reduction: str = "mean",
):
Expand All @@ -76,24 +73,27 @@ def __init__(
Args:
mode: Method used to approximate :func:`flow_derivatives()`.
sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
spacing: Spacing between grid elements. Should be given in the units of the flow vectors.
By default, flow vectors with respect to normalized grid coordinates are assumed.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
reduction: Operation to use for reducing spatially distributed loss values.
"""
super().__init__(mode=mode, sigma=sigma, stride=stride, reduction=reduction)
super().__init__(
mode=mode, sigma=sigma, spacing=spacing, stride=stride, reduction=reduction
)
self.p = p
self.q = 1 / p if q is None else q

def forward(self, u: Tensor) -> Tensor:
r"""Evaluate regularization loss for given transformation."""
spacing = self._spacing(u.shape)
return L.grad_loss(
u,
p=self.p,
q=self.q,
mode=self.mode,
sigma=self.sigma,
spacing=spacing,
spacing=self.spacing,
stride=self.stride,
reduction=self.reduction,
)
Expand All @@ -107,12 +107,11 @@ class Bending(_SpatialDerivativesLoss):

def forward(self, u: Tensor) -> Tensor:
r"""Evaluate regularization loss for given transformation."""
spacing = self._spacing(u.shape)
return L.bending_loss(
u,
mode=self.mode,
sigma=self.sigma,
spacing=spacing,
spacing=self.spacing,
stride=self.stride,
reduction=self.reduction,
)
Expand All @@ -127,12 +126,11 @@ class Curvature(_SpatialDerivativesLoss):

def forward(self, u: Tensor) -> Tensor:
r"""Evaluate regularization loss for given transformation."""
spacing = self._spacing(u.shape)
return L.curvature_loss(
u,
mode=self.mode,
sigma=self.sigma,
spacing=spacing,
spacing=self.spacing,
stride=self.stride,
reduction=self.reduction,
)
Expand All @@ -143,12 +141,11 @@ class Diffusion(_SpatialDerivativesLoss):

def forward(self, u: Tensor) -> Tensor:
r"""Evaluate regularization loss for given transformation."""
spacing = self._spacing(u.shape)
return L.diffusion_loss(
u,
mode=self.mode,
sigma=self.sigma,
spacing=spacing,
spacing=self.spacing,
stride=self.stride,
reduction=self.reduction,
)
Expand All @@ -159,12 +156,11 @@ class Divergence(_SpatialDerivativesLoss):

def forward(self, u: Tensor) -> Tensor:
r"""Evaluate regularization loss for given transformation."""
spacing = self._spacing(u.shape)
return L.divergence_loss(
u,
mode=self.mode,
sigma=self.sigma,
spacing=spacing,
spacing=self.spacing,
stride=self.stride,
reduction=self.reduction,
)
Expand All @@ -183,10 +179,11 @@ def __init__(
shear_modulus: Optional[float] = None,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple] = None,
reduction: str = "mean",
):
super().__init__(mode=mode, sigma=sigma, reduction=reduction)
super().__init__(mode=mode, sigma=sigma, spacing=spacing, reduction=reduction)
self.material_name = material_name
self.first_parameter = first_parameter
self.second_parameter = second_parameter
Expand All @@ -196,7 +193,6 @@ def __init__(

def forward(self, u: Tensor) -> Tensor:
r"""Evaluate regularization loss for given transformation."""
spacing = self._spacing(u.shape)
return L.elasticity_loss(
u,
material_name=self.material_name,
Expand All @@ -207,7 +203,7 @@ def forward(self, u: Tensor) -> Tensor:
shear_modulus=self.shear_modulus,
mode=self.mode,
sigma=self.sigma,
spacing=spacing,
spacing=self.spacing,
stride=self.stride,
reduction=self.reduction,
)
Expand All @@ -234,12 +230,11 @@ class TotalVariation(_SpatialDerivativesLoss):

def forward(self, u: Tensor) -> Tensor:
r"""Evaluate regularization loss for given transformation."""
spacing = self._spacing(u.shape)
return L.total_variation_loss(
u,
mode=self.mode,
sigma=self.sigma,
spacing=spacing,
spacing=self.spacing,
stride=self.stride,
reduction=self.reduction,
)
Expand Down

0 comments on commit c8c1626

Please sign in to comment.