diff --git a/CHANGELOG.md b/CHANGELOG.md index e5e037486423..6e2f29cd9093 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,3 +18,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - `LambdaCheckpointHook` uses global step and doesn't save on first step. +- Switched opencv2 functions with manual ones to get rid of the dependency. diff --git a/edflow/data/util/__init__.py b/edflow/data/util/__init__.py index 05850603595f..b4038e7b87d4 100644 --- a/edflow/data/util/__init__.py +++ b/edflow/data/util/__init__.py @@ -6,7 +6,6 @@ import matplotlib.gridspec as gridspec # noqa import numpy as np # noqa -import cv2 # noqa from edflow.util import walk # noqa @@ -31,18 +30,27 @@ def flow2hsv(flow): hsv[:, :, 2] = 255 # magnitude and angle - mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) + mag, ang = cart2polar(flow[..., 0], flow[..., 1]) # make it colorful - hsv[..., 0] = ang * 180 / np.pi / 2 - hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) - + hsv[..., 0] = ang * 180 / np.pi + normalizer = mpl.colors.Normalize(mag.min(), mag.max()) + hsv[..., 1] = np.int32(normalizer(mag) * 255) return hsv +def cart2polar(x, y): + """ + Takes two array as x and y coordinates and returns the magnitude and angle. + """ + r = np.sqrt(x ** 2 + y ** 2) + phi = np.arctan2(x, y) + return r, phi + + def hsv2rgb(hsv): """color space conversion hsv -> rgb. simple wrapper for nice name.""" - rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) + rgb = mpl.colors.hsv_to_rgb(hsv) return rgb @@ -196,6 +204,17 @@ def heatmap_fn(key, im, ax): im_fn(key, im, ax) +def keypoints_fn(key, keypoints, ax): + """ + Plots a list of keypoints as a dot plot. + """ + add_im_info(keypoints, ax) + x = keypoints[:, 0] + y = keypoints[:, 1] + ax.plot(x, y, "go", markersize=1) + ax.set_ylabel(key) + + def flow_fn(key, im, ax): """Plot an flow. Used by :func:`plot_datum`.""" im = flow2rgb(im) @@ -213,6 +232,7 @@ def other_fn(key, obj, ax): PLOT_FUNCTIONS = { "image": im_fn, "heat": heatmap_fn, + "keypoints": keypoints_fn, "flow": flow_fn, "other": other_fn, } @@ -228,7 +248,10 @@ def default_heuristic(key, obj): if obj.shape[-1] in [3, 4]: return "image" elif obj.shape[-1] == 2: - return "flow" + if len(obj.shape) <= 2: + return "keypoints" + else: + return "flow" else: return "heat" return "other" diff --git a/setup.py b/setup.py index 39a51a5f7154..91dd7b8cb14f 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,6 @@ package_data={"": ["*.yaml"]}, install_requires=[ "pyyaml", - "opencv-python", "tqdm", "Pillow", "chainer", diff --git a/tests/test_data/test_data_util.py b/tests/test_data/test_data_util.py new file mode 100644 index 000000000000..16eef340300f --- /dev/null +++ b/tests/test_data/test_data_util.py @@ -0,0 +1,28 @@ +import os + +import numpy as np + +from edflow.data.util import * + + +def test_plot_datum(): + test_image = np.ones((128, 128, 3), dtype=int) + test_heatmap = np.zeros((128, 128, 25), dtype=int) + test_keypoints = np.random.randint(0, 128, (25, 2)) + + test_example = { + "image": test_image, + "heatmap": test_heatmap, + "keypoints": test_keypoints, + } + plot_datum(test_example, "test_plot.png") + assert os.path.exists("test_plot.png") + os.remove("test_plot.png") + + +def test_cart2polar(): + x = np.array([1, 0]) + y = np.array([0, 0]) + r, phi = cart2polar(x, y) + assert r[0] == 1 + assert round(phi[0], 2) == round(np.pi / 2, 2)