From f1c4a3ae69dd4489733efb094e38bd8807a18457 Mon Sep 17 00:00:00 2001
From: ShenYuhan
+ + +
+ + ## Audio--音频播放组件 ### 介绍 diff --git a/visualdl/utils/img_util.py b/visualdl/utils/img_util.py new file mode 100644 index 000000000..0bb64c06f --- /dev/null +++ b/visualdl/utils/img_util.py @@ -0,0 +1,89 @@ +# Copyright (c) 2020 VisualDL Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ======================================================================= +import math +from functools import reduce + +import numpy as np + +from visualdl.component.base_component import convert_to_HWC + + +def padding_image(img, height, width): + height_old, width_old, _ = img.shape + height_before = math.floor((height - height_old) / 2) + height_after = height - height_old - height_before + + width_before = math.floor((width - width_old) / 2) + width_after = width - width_old - width_before + + return np.pad(img, ((height_before, height_after), (width_before, width_after), (0, 0))) + + +def merge_images(imgs, dataformats, scale=1.0, rows=-1): + assert rows <= len(imgs), "rows should not greater than numbers of pictures" + channel = imgs[0].shape[2] + # convert format of each image to `hwc` + for i, img in enumerate(imgs): + imgs[i] = convert_to_HWC(img, dataformats) + + height = -1 + width = -1 + + for img in imgs: + height = height if height > img.shape[0] else img.shape[0] + width = width if width > img.shape[1] else img.shape[1] + + # padding every sub-image with height and width + for i, img in enumerate(imgs): + imgs[i] = padding_image(img, height, width) + + # get row and col + len_imgs = len(imgs) + if -1 == rows: + rows = cols = math.floor(math.sqrt(len_imgs)) + while rows*cols < len_imgs: + if rows <= cols: + rows += 1 + else: + cols += 1 + else: + cols = math.ceil(len_imgs/rows) + + # add white sub-image + for i in range(rows*cols-len_imgs): + imgs = np.concatenate((imgs, np.zeros((height, width, channel), dtype=np.uint8)[None, :])) + + imgs = reduce(lambda x, y: np.concatenate((x, y)), [ + reduce(lambda x, y: np.concatenate((x, y), 1), + imgs[i * cols: (i + 1) * cols]) for i in range(rows)]) + + # choose bigger number of rows and cols + + scale = 1.0/scale * rows if rows > cols else 1.0/scale * cols + + dsize = tuple(map(lambda x: math.floor(x/scale), imgs.shape))[-2::-1] + + try: + import cv2 + + imgs = cv2.resize(src=imgs, dsize=dsize) + except ImportError: + from PIL import Image + + imgs = Image.fromarray(imgs) + imgs.resize(dsize) + imgs = np.array(imgs) + + return imgs diff --git a/visualdl/writer/writer.py b/visualdl/writer/writer.py index 77f09bddb..1714d8e15 100644 --- a/visualdl/writer/writer.py +++ b/visualdl/writer/writer.py @@ -17,6 +17,7 @@ import time import numpy as np from visualdl.writer.record_writer import RecordFileWriter +from visualdl.utils.img_util import merge_images from visualdl.component.base_component import scalar, image, embedding, audio, histogram, pr_curve, roc_curve, meta_data @@ -189,6 +190,36 @@ def add_image(self, tag, img, step, walltime=None, dataformats="HWC"): image(tag=tag, image_array=img, step=step, walltime=walltime, dataformats=dataformats)) + def add_image_matrix(self, tag, imgs, step, rows=-1, scale=1.0, walltime=None, dataformats="HWC"): + """Add an image to vdl record file. + + Args: + tag (string): Data identifier + imgs (np.ndarray): Image represented by a numpy.array + step (int): Step of image + rows (int): Number of rows, -1 means as close as possible to the square + scale (float): Image zoom scale + walltime (int): Wall time of image + dataformats (string): Format of image + + Example: + from PIL import Image + import numpy as np + + I = Image.open("./test.png") + I_array = np.array([I, I, I]) + writer.add_image_matrix(tag="lll", imgs=I_array, step=0) + """ + if '%' in tag: + raise RuntimeError("% can't appear in tag!") + walltime = round(time.time() * 1000) if walltime is None else walltime + img = merge_images(imgs=imgs, dataformats=dataformats, scale=scale, rows=rows) + self.add_image(tag=tag, + img=img, + step=step, + walltime=walltime, + dataformats=dataformats) + def add_embeddings(self, tag, labels, hot_vectors, labels_meta=None, walltime=None): """Add embeddings to vdl record file.