Skip to content

Commit

Permalink
fix(plot): remove empty black sprites on nonsquare length (#431)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao authored Jul 8, 2022
1 parent df5ad71 commit 42d6005
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions docarray/array/mixins/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,18 +432,18 @@ def plot_image_sprites(
import matplotlib.pyplot as plt

img_per_row = ceil(sqrt(len(self)))
img_per_col = ceil(len(self) / img_per_row)
img_size = int(canvas_size / img_per_row)

if img_size < min_size:
# image is too small, recompute the size
img_size = min_size
img_per_row = int(canvas_size / img_size)

max_num_img = img_per_row**2
max_num_img = img_per_row * img_per_col
sprite_img = np.zeros(
[img_size * img_per_row, img_size * img_per_row, 3], dtype='uint8'
[img_size * img_per_col, img_size * img_per_row, 3], dtype='uint8'
)
img_id = 0
img_size_w, img_size_h = img_size, img_size
set_aspect_ratio = False

Expand Down Expand Up @@ -479,15 +479,15 @@ def plot_image_sprites(
h, w, _ = _d.tensor.shape
img_size_h = int(h * img_size / w)
sprite_img = np.zeros(
[img_size_h * img_per_row, img_size_w * img_per_row, 3],
[img_size_h * img_per_col, img_size_w * img_per_row, 3],
dtype='uint8',
)
set_aspect_ratio = True

_d.set_image_tensor_shape(shape=(img_size_h, img_size_w))

row_id = floor(img_id / img_per_row)
col_id = img_id % img_per_row
row_id = floor(_idx / img_per_row)
col_id = _idx % img_per_row

if show_index:
_img = Image.fromarray(_d.tensor)
Expand All @@ -500,9 +500,6 @@ def plot_image_sprites(
(col_id * img_size_w) : ((col_id + 1) * img_size_w),
] = _d.tensor

img_id += 1
if img_id >= max_num_img:
break
except Exception as ex:
raise ValueError(
'Bad image tensor. Try different `image_source` or `channel_axis`'
Expand Down

0 comments on commit 42d6005

Please sign in to comment.