Skip to content

Commit

Permalink
Add the function of Procrutes Alignment (open-mmlab#157)
Browse files Browse the repository at this point in the history
* Add the function of Procrutes Alignment used in 3D joints error evaluation and corresponding test codes.

* Fix the input shape of function compute_similarity_transform.
  • Loading branch information
zengwang430521 authored Sep 28, 2020
1 parent ebb92f7 commit 3a35ed2
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mmpose/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .bottom_up_eval import (aggregate_results, get_group_preds,
get_multi_stage_outputs)
from .eval_hooks import DistEvalHook, EvalHook
from .mesh_eval import compute_similarity_transform
from .top_down_eval import (keypoint_auc, keypoint_epe, keypoint_pck_accuracy,
keypoints_from_heatmaps, pose_pck_accuracy)

__all__ = [
'EvalHook', 'DistEvalHook', 'pose_pck_accuracy', 'keypoints_from_heatmaps',
'keypoint_pck_accuracy', 'keypoint_auc', 'keypoint_epe', 'get_group_preds',
'get_multi_stage_outputs', 'aggregate_results'
'get_multi_stage_outputs', 'aggregate_results',
'compute_similarity_transform'
]
66 changes: 66 additions & 0 deletions mmpose/core/evaluation/mesh_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# ------------------------------------------------------------------------------
# Adapted from https://github.com/akanazawa/hmr
# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
# ------------------------------------------------------------------------------

import numpy as np


def compute_similarity_transform(source_points, target_points):
"""Computes a similarity transform (sR, t) that takes a set of 3D points
source_points (N x 3) closest to a set of 3D points target_points, where R
is an 3x3 rotation matrix, t 3x1 translation, s scale. And return the
transformed 3D points source_points_hat (N x 3). i.e. solves the orthogonal
Procrutes problem.
Notes:
Points number: N
Args:
source_points (np.ndarray([N, 3])): Source point set.
target_points (np.ndarray([N, 3])): Target point set.
Returns:
source_points_hat (np.ndarray([N, 3])): Transformed source point set.
"""

assert (target_points.shape[0] == source_points.shape[0])
assert (target_points.shape[1] == 3 and source_points.shape[1] == 3)

source_points = source_points.T
target_points = target_points.T

# 1. Remove mean.
mu1 = source_points.mean(axis=1, keepdims=True)
mu2 = target_points.mean(axis=1, keepdims=True)
X1 = source_points - mu1
X2 = target_points - mu2

# 2. Compute variance of X1 used for scale.
var1 = np.sum(X1**2)

# 3. The outer product of X1 and X2.
K = X1.dot(X2.T)

# 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
# singular vectors of K.
U, _, Vh = np.linalg.svd(K)
V = Vh.T
# Construct Z that fixes the orientation of R to get det(R)=1.
Z = np.eye(U.shape[0])
Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
# Construct R.
R = V.dot(Z.dot(U.T))

# 5. Recover scale.
scale = np.trace(R.dot(K)) / var1

# 6. Recover translation.
t = mu2 - scale * (R.dot(mu1))

# 7. Transform the source points:
source_points_hat = scale * R.dot(source_points) + t

source_points_hat = source_points_hat.T

return source_points_hat
13 changes: 13 additions & 0 deletions tests/test_evaluation/test_mesh_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numpy as np
from numpy.testing import assert_array_almost_equal

from mmpose.core import compute_similarity_transform


def test_compute_similarity_transform():
source = np.random.rand(14, 3)
tran = np.random.rand(1, 3)
scale = 0.5
target = source * scale + tran
source_transformed = compute_similarity_transform(source, target)
assert_array_almost_equal(source_transformed, target)

0 comments on commit 3a35ed2

Please sign in to comment.