Skip to content

Commit

Permalink
support float32 and double64 of ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
ShenYuhan authored Dec 23, 2020
1 parent e5150df commit cf66d95
Showing 1 changed file with 26 additions and 11 deletions.
37 changes: 26 additions & 11 deletions visualdl/component/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def imgarray2bytes(np_array):
"""Convert image ndarray to bytes.
Args:
np_array (numpy.ndarray): Array to converte.
np_array (np.ndarray): Array to converte.
Returns:
Binary bytes of np_array.
Expand Down Expand Up @@ -106,7 +106,7 @@ def convert_to_HWC(tensor, input_format):
"""Convert `NCHW`, `HWC`, `HW` to `HWC`
Args:
tensor (numpy.ndarray): Value of image
tensor (np.ndarray): Value of image
input_format (string): Format of image
Return:
Expand Down Expand Up @@ -138,19 +138,34 @@ def convert_to_HWC(tensor, input_format):
return tensor


def denormalization(image_array):
"""Renormalise ndarray matrix.
Args:
image_array(np.ndarray): Value of image
Return:
Matrix after renormalising.
"""
if image_array.max() <= 1 and image_array.min() >= 0:
image_array *= 255
return image_array.astype(np.uint8)


def image(tag, image_array, step, walltime=None, dataformats="HWC"):
"""Package data to one image.
Args:
tag (string): Data identifier
image_array (numpy.ndarray): Value of iamge
image_array (np.ndarray): Value of image
step (int): Step of image
walltime (int): Wall time of image
dataformats (string): Format of image
Return:
Package with format of record_pb2.Record
"""
image_array = denormalization(image_array)
image_array = convert_to_HWC(image_array, dataformats)
image_bytes = imgarray2bytes(image_array)
image = Record.Image(encoded_image_string=image_bytes)
Expand All @@ -165,7 +180,7 @@ def embedding(tag, labels, hot_vectors, step, labels_meta=None, walltime=None):
Args:
tag (string): Data identifier
labels (list): A list of labels.
hot_vectors (numpy.array or list): A matrix which each row is
hot_vectors (np.array or list): A matrix which each row is
feature of labels.
step (int): Step of embeddings.
walltime (int): Wall time of embeddings.
Expand Down Expand Up @@ -199,7 +214,7 @@ def audio(tag, audio_array, sample_rate, step, walltime):
Args:
tag (string): Data identifier
audio_array (numpy.ndarray or list): audio represented by a numpy.array
audio_array (np.ndarray or list): audio represented by a np.array
sample_rate (int): Sample rate of audio
step (int): Step of audio
walltime (int): Wall time of audio
Expand Down Expand Up @@ -246,8 +261,8 @@ def histogram(tag, hist, bin_edges, step, walltime):
Args:
tag (string): Data identifier
hist (numpy.ndarray or list): The values of the histogram
bin_edges (numpy.ndarray or list): The bin edges
hist (np.ndarray or list): The values of the histogram
bin_edges (np.ndarray or list): The bin edges
step (int): Step of histogram
walltime (int): Wall time of histogram
Expand All @@ -265,8 +280,8 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None):
""" Compute precision-recall curve data by labels and predictions.
Args:
labels (numpy.ndarray or list): Binary labels for each element.
predictions (numpy.ndarray or list): The probability that an element be
labels (np.ndarray or list): Binary labels for each element.
predictions (np.ndarray or list): The probability that an element be
classified as true.
num_thresholds (int): Number of thresholds used to draw the curve.
weights (float): Multiple of data to display on the curve.
Expand Down Expand Up @@ -318,8 +333,8 @@ def pr_curve(tag, labels, predictions, step, walltime, num_thresholds=127,
Args:
tag (string): Data identifier
labels (numpy.ndarray or list): Binary labels for each element.
predictions (numpy.ndarray or list): The probability that an element be
labels (np.ndarray or list): Binary labels for each element.
predictions (np.ndarray or list): The probability that an element be
classified as true.
step (int): Step of pr_curve
walltime (int): Wall time of pr_curve
Expand Down

0 comments on commit cf66d95

Please sign in to comment.