forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the function of Procrutes Alignment (open-mmlab#157)
* Add the function of Procrutes Alignment used in 3D joints error evaluation and corresponding test codes. * Fix the input shape of function compute_similarity_transform.
- Loading branch information
1 parent
ebb92f7
commit 3a35ed2
Showing
3 changed files
with
82 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,13 @@ | ||
from .bottom_up_eval import (aggregate_results, get_group_preds, | ||
get_multi_stage_outputs) | ||
from .eval_hooks import DistEvalHook, EvalHook | ||
from .mesh_eval import compute_similarity_transform | ||
from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_pck_accuracy, | ||
keypoints_from_heatmaps, pose_pck_accuracy) | ||
|
||
__all__ = [ | ||
'EvalHook', 'DistEvalHook', 'pose_pck_accuracy', 'keypoints_from_heatmaps', | ||
'keypoint_pck_accuracy', 'keypoint_auc', 'keypoint_epe', 'get_group_preds', | ||
'get_multi_stage_outputs', 'aggregate_results' | ||
'get_multi_stage_outputs', 'aggregate_results', | ||
'compute_similarity_transform' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# ------------------------------------------------------------------------------ | ||
# Adapted from https://github.com/akanazawa/hmr | ||
# Original licence: Copyright (c) 2018 akanazawa, under the MIT License. | ||
# ------------------------------------------------------------------------------ | ||
|
||
import numpy as np | ||
|
||
|
||
def compute_similarity_transform(source_points, target_points): | ||
"""Computes a similarity transform (sR, t) that takes a set of 3D points | ||
source_points (N x 3) closest to a set of 3D points target_points, where R | ||
is an 3x3 rotation matrix, t 3x1 translation, s scale. And return the | ||
transformed 3D points source_points_hat (N x 3). i.e. solves the orthogonal | ||
Procrutes problem. | ||
Notes: | ||
Points number: N | ||
Args: | ||
source_points (np.ndarray([N, 3])): Source point set. | ||
target_points (np.ndarray([N, 3])): Target point set. | ||
Returns: | ||
source_points_hat (np.ndarray([N, 3])): Transformed source point set. | ||
""" | ||
|
||
assert (target_points.shape[0] == source_points.shape[0]) | ||
assert (target_points.shape[1] == 3 and source_points.shape[1] == 3) | ||
|
||
source_points = source_points.T | ||
target_points = target_points.T | ||
|
||
# 1. Remove mean. | ||
mu1 = source_points.mean(axis=1, keepdims=True) | ||
mu2 = target_points.mean(axis=1, keepdims=True) | ||
X1 = source_points - mu1 | ||
X2 = target_points - mu2 | ||
|
||
# 2. Compute variance of X1 used for scale. | ||
var1 = np.sum(X1**2) | ||
|
||
# 3. The outer product of X1 and X2. | ||
K = X1.dot(X2.T) | ||
|
||
# 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are | ||
# singular vectors of K. | ||
U, _, Vh = np.linalg.svd(K) | ||
V = Vh.T | ||
# Construct Z that fixes the orientation of R to get det(R)=1. | ||
Z = np.eye(U.shape[0]) | ||
Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) | ||
# Construct R. | ||
R = V.dot(Z.dot(U.T)) | ||
|
||
# 5. Recover scale. | ||
scale = np.trace(R.dot(K)) / var1 | ||
|
||
# 6. Recover translation. | ||
t = mu2 - scale * (R.dot(mu1)) | ||
|
||
# 7. Transform the source points: | ||
source_points_hat = scale * R.dot(source_points) + t | ||
|
||
source_points_hat = source_points_hat.T | ||
|
||
return source_points_hat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import numpy as np | ||
from numpy.testing import assert_array_almost_equal | ||
|
||
from mmpose.core import compute_similarity_transform | ||
|
||
|
||
def test_compute_similarity_transform(): | ||
source = np.random.rand(14, 3) | ||
tran = np.random.rand(1, 3) | ||
scale = 0.5 | ||
target = source * scale + tran | ||
source_transformed = compute_similarity_transform(source, target) | ||
assert_array_almost_equal(source_transformed, target) |