Skip to content

Commit

Permalink
fix(plot): convert uint8 before plotting (#472)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Aug 3, 2022
1 parent 04c50a7 commit 933ddf0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def plot_image_sprites(
col_id = _idx % img_per_row

if show_index:
_img = Image.fromarray(_d.tensor)
_img = Image.fromarray(np.asarray(_d.tensor, dtype='uint8'))
draw = ImageDraw.Draw(_img)
draw.text((0, 0), str(_idx), (255, 255, 255))
_d.tensor = np.asarray(_img)
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/array/mixins/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ def test_plot_embeddings(da_and_dam):
_test_plot_embeddings(da)


def test_plot_sprites(tmpdir):
da = DocumentArray.empty(5)
da.tensors = np.random.random([5, 3, 226, 226])
da.plot_image_sprites(tmpdir / 'a.png', channel_axis=0, show_index=True)
assert os.path.exists(tmpdir / 'a.png')


def _test_plot_embeddings(da):
p = da.plot_embeddings(start_server=False)
assert os.path.exists(p)
Expand Down

0 comments on commit 933ddf0

Please sign in to comment.