From d0172a9a8bead5350b5dbb1b7cd8fc23e1fd86a7 Mon Sep 17 00:00:00 2001 From: Andreas Schuh Date: Mon, 23 Oct 2023 15:32:00 +0000 Subject: [PATCH] [core] Fix flow field Jacobian determinant calculation --- src/deepali/core/flow.py | 57 ++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index 85e44cb..9af12f4 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -146,7 +146,8 @@ def divergence( kwargs = dict(mode=mode, sigma=sigma, spacing=spacing, stride=stride) which = FlowDerivativeKeys.divergence(spatial_dims=D) deriv = flow_derivatives(flow, which=which, **kwargs) - div = torch.zeros((N, 1) + flow.shape[2:], dtype=flow.dtype, device=flow.device) + ref = deriv["du/dx"] + div = torch.zeros((N, 1) + ref.shape[2:], dtype=ref.dtype, device=ref.device) for value in deriv.values(): div = div.add_(value) return div @@ -378,22 +379,44 @@ def jacobian_det( kwargs = dict(mode=mode, sigma=sigma, spacing=spacing, stride=stride) which = FlowDerivativeKeys.jacobian(spatial_dims=D) deriv = flow_derivatives(flow, which=which, **kwargs) - jac: Optional[Tensor] = None - for perm in permutations(range(D)): - term: Optional[Tensor] = None - for i, j in zip(range(D), perm): - dij = deriv[FlowDerivativeKeys.symbol(i, j)] - if i == j: - dij = dij.add_(1) # T(x) = x + u(x) - term = dij if term is None else term.mul_(dij) - assert term is not None - if jac is None: - jac = term - elif is_even_permutation(perm): - jac = jac.add_(term) - else: - jac = jac.sub_(term) - assert jac is not None + # Add 1 to diagonal elements of Jacobian matrix, because T(x) = x + u(x) + for i in range(D): + deriv[FlowDerivativeKeys.symbol(i, i)].add_(1) + if D == 2: + a = deriv["du/dx"] + b = deriv["du/dy"] + c = deriv["dv/dx"] + d = deriv["dv/dy"] + jac = a.mul(d).sub_(b.mul(c)) + elif D == 3: + a = deriv["du/dx"] + b = deriv["du/dy"] + c = deriv["du/dz"] + d = deriv["dv/dx"] + e = deriv["dv/dy"] + f = deriv["dv/dz"] + g = deriv["dw/dx"] + h = deriv["dw/dy"] + i = deriv["dw/dz"] + term_1 = a.mul(e.mul(i).sub_(f.mul(h))) + term_2 = b.mul(d.mul(i).sub_(g.mul(f))) + term_3 = c.mul(d.mul(h).sub_(e.mul(g))) + jac = term_1.sub(term_2).add(term_3) + else: + jac: Optional[Tensor] = None + for perm in permutations(range(D)): + term: Optional[Tensor] = None + for i, j in zip(range(D), perm): + dij = deriv[FlowDerivativeKeys.symbol(i, j)] + term = dij if term is None else term.mul_(dij) + assert term is not None + if jac is None: + jac = term + elif is_even_permutation(perm): + jac = jac.add_(term) + else: + jac = jac.sub_(term) + assert jac is not None return jac