Skip to content

Commit

Permalink
Merge pull request #34 from noshita/feat-thin-plate-spline
Browse files Browse the repository at this point in the history
feat: ✨ TPS
  • Loading branch information
noshita authored Mar 10, 2024
2 parents fa363ce + ae87df5 commit 18ab364
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 14 deletions.
74 changes: 63 additions & 11 deletions ktch/landmark/_Procrustes_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,12 @@

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from abc import ABCMeta

import numpy as np
import scipy as sp
import pandas as pd

import numpy.typing as npt


from sklearn.base import (
BaseEstimator,
TransformerMixin,
OneToOneFeatureMixin,
)
import scipy as sp
from sklearn.base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin


class GeneralizedProcrustesAnalysis(
Expand Down Expand Up @@ -248,3 +240,63 @@ def centroid_size(x):
x_c = x - np.mean(x, axis=0)
centroid_size = np.sqrt(np.trace(np.dot(x_c, x_c.T)))
return centroid_size


def _thin_plate_spline_2d(x_reference, x_target):
"""Thin-plate spline in 2D.
Parameters
----------
x_reference : array-like, shape (n_landmarks, n_dim)
Reference configuration.
x_target : array-like, shape (n_landmarks, n_dim)
Target configuration.
Returns
-------
W : ndarray, shape (n_landmarks, n_dim)
c : ndarray, shape (n_dim)
A : ndarray, shape (n_dim, n_dim)
"""

n_dim = 2

x_r = np.array(x_reference).reshape(-1, n_dim)
x_t = np.array(x_target).reshape(-1, n_dim)

n_landmarks = x_r.shape[0]

if not x_r.shape == x_t.shape:
raise ValueError("x_reference and x_target must have the same shape.")

r = sp.spatial.distance.cdist(x_r, x_r, "euclidean")

SMat = r**2 * np.log(r, out=np.zeros_like(r), where=(r != 0))
QMat = np.concatenate([np.ones(n_landmarks).reshape(-1, 1), x_r], 1)
zero_mat = np.zeros([n_dim + 1, n_dim + 1])

GammaMat = np.concatenate(
[
np.concatenate([SMat, QMat], 1),
np.concatenate([QMat.T, zero_mat], 1),
]
)
GammaInvMat = np.linalg.inv(GammaMat)
sol = np.dot(GammaInvMat, np.concatenate([x_t, np.zeros([n_dim + 1, n_dim])], 0))

W = sol[0:n_landmarks, :]
c = sol[n_landmarks, :]
A = sol[n_landmarks + 1 :, :]

return W, c, A


def _tps_2d(x, y, T, W, c, A):

t = np.array([x, y])

r = np.apply_along_axis(lambda v: np.sqrt(np.dot(v, v)), 1, t - T)

pred = c + np.dot(A, t) + np.dot(W.T, np.where(r == 0, 0, r**2 * np.log(r)))

return pred
6 changes: 3 additions & 3 deletions ktch/landmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._Procrustes_analysis import GeneralizedProcrustesAnalysis
from ._Procrustes_analysis import centroid_size
from ._plot._tps import tps_grid_2d_plot
from ._Procrustes_analysis import GeneralizedProcrustesAnalysis, centroid_size

__all__ = ["GeneralizedProcrustesAnalysis", "centroid_size"]
__all__ = ["GeneralizedProcrustesAnalysis", "centroid_size", "tps_grid_2d_plot"]
Empty file added ktch/landmark/_plot/__init__.py
Empty file.
92 changes: 92 additions & 0 deletions ktch/landmark/_plot/_tps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Plot functions for thin-plate spline warping."""

# Copyright 2024 Koji Noshita
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from .._Procrustes_analysis import _thin_plate_spline_2d, _tps_2d


def tps_grid_2d_plot(
x_reference, x_target, grid_size="auto", outer=0.1, n_grid_inner=10, ax=None
):
"""Plot the thin-plate spline 2D warped grid.
Parameters
----------
x_reference : array-like, shape (n_landmarks, n_dim)
Reference configuration.
x_target : array-like, shape (n_landmarks, n_dim)
Target configuration.
grid_size : str/float, optional
Grid size, by default "auto"
outer : float, optional
Outer range of x_reference covered by the grid, by default 0.1
n_grid_inner : int, optional
Number of inner points on each grid, by default 10
ax : :class:`matplotlib.axes.Axes`, optional
Pre-existing matplotlib axes for the plot. Otherwise, call :func:`matplotlib.pyplot.gca` internally.
Returns
-------
ax : :class:`matplotlib.axes.Axes`
Matplotlib axes.
"""

import matplotlib.pyplot as plt

W, c, A = _thin_plate_spline_2d(x_reference, x_target)

if ax is None:
ax = plt.gca()

x_min, y_min = (1 + outer) * np.min(x_reference, axis=0)
x_max, y_max = (1 + outer) * np.max(x_reference, axis=0)

w = x_max - x_min
h = y_max - y_min

grid_size_ = grid_size
if grid_size == "auto":
grid_size_ = np.min([w, h]) / 10

if w > h:
w = w - w % grid_size_ + grid_size_
else:
h = h - w % grid_size + grid_size_
n_grid_x = np.rint(w / grid_size_)
n_grid_y = np.rint(h / grid_size_)

n_grid_x_ = int(n_grid_x * n_grid_inner + 1)
n_grid_y_ = int(n_grid_y * n_grid_inner + 1)

warped = np.array(
[
_tps_2d(x, y, x_reference, W, c, A)
for x in np.linspace(x_min, x_max, n_grid_x_)
for y in np.linspace(y_min, y_max, n_grid_y_)
]
)

w_1 = warped.reshape(n_grid_x_, n_grid_y_, 2)
w_2 = w_1.transpose(1, 0, 2)

ax.plot(w_1[:, ::n_grid_inner, 0], w_1[:, ::n_grid_inner, 1], "gray")
ax.plot(w_2[:, ::n_grid_inner, 0], w_2[:, ::n_grid_inner, 1], "gray")
ax.axis("equal")

ax.scatter(x=x_reference[:, 0], y=x_reference[:, 1], zorder=2)
ax.scatter(x=x_target[:, 0], y=x_target[:, 1], zorder=2)
return ax
1 change: 1 addition & 0 deletions notebooks/landmark/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ Examples concerning the :mod:`ktch.landmark` module.
./generalized_Procrustes_analysis
./gpa_from_tps
./thin_plate_spline
```
Loading

0 comments on commit 18ab364

Please sign in to comment.