Skip to content

Commit

Permalink
Merge pull request mmistakes#166 from mritv/master
Browse files Browse the repository at this point in the history
 Fixes mmistakes#150. Removed cv2 dependency.
  • Loading branch information
jhaux authored Sep 26, 2019
2 parents 9e687b9 + b4bf988 commit a25d3ec
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- `LambdaCheckpointHook` uses global step and doesn't save on first step.
- Switched opencv2 functions with manual ones to get rid of the dependency.
37 changes: 30 additions & 7 deletions edflow/data/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import matplotlib.gridspec as gridspec # noqa

import numpy as np # noqa
import cv2 # noqa

from edflow.util import walk # noqa

Expand All @@ -31,18 +30,27 @@ def flow2hsv(flow):
hsv[:, :, 2] = 255

# magnitude and angle
mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
mag, ang = cart2polar(flow[..., 0], flow[..., 1])

# make it colorful
hsv[..., 0] = ang * 180 / np.pi / 2
hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)

hsv[..., 0] = ang * 180 / np.pi
normalizer = mpl.colors.Normalize(mag.min(), mag.max())
hsv[..., 1] = np.int32(normalizer(mag) * 255)
return hsv


def cart2polar(x, y):
"""
Takes two array as x and y coordinates and returns the magnitude and angle.
"""
r = np.sqrt(x ** 2 + y ** 2)
phi = np.arctan2(x, y)
return r, phi


def hsv2rgb(hsv):
"""color space conversion hsv -> rgb. simple wrapper for nice name."""
rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
rgb = mpl.colors.hsv_to_rgb(hsv)
return rgb


Expand Down Expand Up @@ -196,6 +204,17 @@ def heatmap_fn(key, im, ax):
im_fn(key, im, ax)


def keypoints_fn(key, keypoints, ax):
"""
Plots a list of keypoints as a dot plot.
"""
add_im_info(keypoints, ax)
x = keypoints[:, 0]
y = keypoints[:, 1]
ax.plot(x, y, "go", markersize=1)
ax.set_ylabel(key)


def flow_fn(key, im, ax):
"""Plot an flow. Used by :func:`plot_datum`."""
im = flow2rgb(im)
Expand All @@ -213,6 +232,7 @@ def other_fn(key, obj, ax):
PLOT_FUNCTIONS = {
"image": im_fn,
"heat": heatmap_fn,
"keypoints": keypoints_fn,
"flow": flow_fn,
"other": other_fn,
}
Expand All @@ -228,7 +248,10 @@ def default_heuristic(key, obj):
if obj.shape[-1] in [3, 4]:
return "image"
elif obj.shape[-1] == 2:
return "flow"
if len(obj.shape) <= 2:
return "keypoints"
else:
return "flow"
else:
return "heat"
return "other"
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package_data={"": ["*.yaml"]},
install_requires=[
"pyyaml",
"opencv-python",
"tqdm",
"Pillow",
"chainer",
Expand Down
28 changes: 28 additions & 0 deletions tests/test_data/test_data_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os

import numpy as np

from edflow.data.util import *


def test_plot_datum():
test_image = np.ones((128, 128, 3), dtype=int)
test_heatmap = np.zeros((128, 128, 25), dtype=int)
test_keypoints = np.random.randint(0, 128, (25, 2))

test_example = {
"image": test_image,
"heatmap": test_heatmap,
"keypoints": test_keypoints,
}
plot_datum(test_example, "test_plot.png")
assert os.path.exists("test_plot.png")
os.remove("test_plot.png")


def test_cart2polar():
x = np.array([1, 0])
y = np.array([0, 0])
r, phi = cart2polar(x, y)
assert r[0] == 1
assert round(phi[0], 2) == round(np.pi / 2, 2)

0 comments on commit a25d3ec

Please sign in to comment.