Skip to content

Commit

Permalink
Merge branch 'master' into rj/storage-info
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky authored Feb 29, 2024
2 parents 365e945 + d1061cc commit cde9e95
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

### Features
- Added `get_workspace_status()` method to management API ([#1662](https://github.com/neptune-ai/neptune-client/pull/1662))
- Added auto-scaling pixel values for image logging ([#1664](https://github.com/neptune-ai/neptune-client/pull/1664))

### Fixes
- Restored support for SSL verification exception ([#1661](https://github.com/neptune-ai/neptune-client/pull/1661))
Expand Down
66 changes: 40 additions & 26 deletions src/neptune/internal/utils/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations

__all__ = [
"get_image_content",
"get_html_content",
Expand All @@ -37,6 +39,7 @@
)
from typing import Optional

import numpy as np
from packaging import version
from pandas import DataFrame

Expand All @@ -45,6 +48,7 @@

logger = get_logger()
SEABORN_GRID_CLASSES = {"FacetGrid", "PairGrid", "JointGrid"}
ALLOWED_IMG_PIXEL_RANGES = ("[0, 255]", "[0.0, 1.0]")

try:
from numpy import array as numpy_array
Expand All @@ -65,8 +69,8 @@ def pilimage_fromarray():
pass


def get_image_content(image) -> Optional[bytes]:
content = _image_to_bytes(image)
def get_image_content(image, autoscale=True) -> Optional[bytes]:
content = _image_to_bytes(image, autoscale)

return content

Expand All @@ -83,12 +87,12 @@ def get_pickle_content(obj) -> Optional[bytes]:
return content


def _image_to_bytes(image) -> bytes:
def _image_to_bytes(image, autoscale) -> bytes:
if image is None:
raise ValueError("image is None")

elif is_numpy_array(image):
return _get_numpy_as_image(image)
return _get_numpy_as_image(image, autoscale)

elif is_pil_image(image):
return _get_pil_image_data(image)
Expand All @@ -97,10 +101,10 @@ def _image_to_bytes(image) -> bytes:
return _get_figure_image_data(image)

elif _is_torch_tensor(image):
return _get_numpy_as_image(image.detach().numpy())
return _get_numpy_as_image(image.detach().numpy(), autoscale)

elif _is_tensorflow_tensor(image):
return _get_numpy_as_image(image.numpy())
return _get_numpy_as_image(image.numpy(), autoscale)

elif is_seaborn_figure(image):
return _get_figure_image_data(image.figure)
Expand Down Expand Up @@ -196,38 +200,48 @@ def _image_content_to_html(content: bytes) -> str:
return "<img src='data:image/png;base64," + str_equivalent_image + "'/>"


def _get_numpy_as_image(array):
def _get_numpy_as_image(array: np.ndarray, autoscale: bool) -> bytes:
array = array.copy() # prevent original array from modifying
if autoscale:
array = _scale_array(array)

data_range_warnings = []
array_min = array.min()
array_max = array.max()
if array_min < 0:
data_range_warnings.append(f"the smallest value in the array is {array_min}")
if array_max > 1:
data_range_warnings.append(f"the largest value in the array is {array_max}")
if data_range_warnings:
data_range_warning_message = (" and ".join(data_range_warnings) + ".").capitalize()
logger.warning(
"%s To be interpreted as colors correctly values in the array need to be in the [0, 1] range.",
data_range_warning_message,
)
array *= 255
shape = array.shape
if len(shape) == 2:
if len(array.shape) == 2:
return _get_pil_image_data(pilimage_fromarray(array.astype(numpy_uint8)))
if len(shape) == 3:
if shape[2] == 1:
if len(array.shape) == 3:
if array.shape[2] == 1:
array2d = numpy_array([[col[0] for col in row] for row in array])
return _get_pil_image_data(pilimage_fromarray(array2d.astype(numpy_uint8)))
if shape[2] in (3, 4):
if array.shape[2] in (3, 4):
return _get_pil_image_data(pilimage_fromarray(array.astype(numpy_uint8)))
raise ValueError(
"Incorrect size of numpy.ndarray. Should be 2-dimensional or"
"3-dimensional with 3rd dimension of size 1, 3 or 4."
)


def _scale_array(array: np.ndarray) -> np.ndarray:
array_min = array.min()
array_max = array.max()

if array_min >= 0 and 1 < array_max <= 255:
return array

if array_min >= 0 and array_max <= 1:
return array * 255

_warn_about_incorrect_image_data_range(array_min, array_max)
return array


def _warn_about_incorrect_image_data_range(array_min: int | float, array_max: int | float) -> None:
msg = f"Image data is in range [{array_min}, {array_max}]."
logger.warning(
"%s To be interpreted as colors correctly values in the array need to be in the %s or %s range.",
msg,
*ALLOWED_IMG_PIXEL_RANGES,
)


def _get_pil_image_data(image: PILImage) -> bytes:
with io.BytesIO() as image_buffer:
image.save(image_buffer, format="PNG")
Expand Down
4 changes: 2 additions & 2 deletions src/neptune/types/atoms/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def from_stream(stream: IOBase, *, seek: Optional[int] = 0, extension: Optional[
return File(file_composite=file_composite)

@staticmethod
def as_image(image) -> "File":
def as_image(image, autoscale: bool = True) -> "File":
"""Static method for converting image objects or image-like objects to an image File value object.
This way you can upload `Matplotlib` figures, `Seaborn` figures, `PIL` images, `NumPy` arrays, as static images.
Expand Down Expand Up @@ -207,7 +207,7 @@ def as_image(image) -> "File":
.. _as_image docs page:
https://docs.neptune.ai/api/field_types#as_image
"""
content_bytes = get_image_content(image)
content_bytes = get_image_content(image, autoscale=autoscale)
return File.from_content(content_bytes if content_bytes is not None else b"", extension="png")

@staticmethod
Expand Down
58 changes: 46 additions & 12 deletions tests/unit/neptune/new/internal/utils/test_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
IS_WINDOWS,
)
from neptune.internal.utils.images import (
_scale_array,
get_html_content,
get_image_content,
)
Expand Down Expand Up @@ -75,23 +76,13 @@ def test_get_image_content_from_2d_grayscale_array(self):
def test_get_image_content_from_3d_grayscale_array(self):
# given
image_array = numpy.array([[[1], [0]], [[-3], [4]], [[5], [6]]])
expected_array = numpy.array([[1, 0], [-3, 4], [5, 6]]) * 255
expected_array = numpy.array([[1, 0], [-3, 4], [5, 6]])
expected_image = Image.fromarray(expected_array.astype(numpy.uint8))

# when
_log = partial(format_log, "WARNING")

# expect
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
self.assertEqual(get_image_content(image_array), self._encode_pil_image(expected_image))
self.assertEqual(
stdout.getvalue(),
_log(
"The smallest value in the array is -3 and the largest value in the array is 6."
" To be interpreted as colors correctly values in the array need to be in the [0, 1] range.\n",
),
)
self.assertEqual(get_image_content(image_array), self._encode_pil_image(expected_image))

def test_get_image_content_from_rgb_array(self):
# given
Expand Down Expand Up @@ -292,3 +283,46 @@ def _random_image_array(w=20, h=30, d: Optional[int] = 3):
return numpy.random.rand(w, h, d)
else:
return numpy.random.rand(w, h)


def test_scale_array_when_array_already_scaled():
# given
arr = numpy.array([[123, 32], [255, 0]])

# when
result = _scale_array(arr)

# then
assert numpy.all(arr == result)


def test_scale_array_when_array_not_scaled():
# given
arr = numpy.array([[0.3, 0], [0.5, 1]])

# when
result = _scale_array(arr)
expected = numpy.array([[76.5, 0.0], [127.5, 255.0]])

# then
assert numpy.all(expected == result)


def test_scale_array_incorrect_range():
# given
arr = numpy.array([[-12, 7], [300, 0]])

# when
_log = partial(format_log, "WARNING")

stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
result = _scale_array(arr)

# then
assert numpy.all(arr == result) # returned original array

assert stdout.getvalue() == _log(
"Image data is in range [-12, 300]. To be interpreted as colors "
"correctly values in the array need to be in the [0, 255] or [0.0, 1.0] range.\n",
)

0 comments on commit cde9e95

Please sign in to comment.