Skip to content

Commit

Permalink
handle case in adaptive_avg_pool2d where jax array has no device attr
Browse files Browse the repository at this point in the history
  • Loading branch information
Nightcrab authored Aug 15, 2024
1 parent 8b5f4e7 commit 7497a4e
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions ivy/functional/ivy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2464,11 +2464,16 @@ def adaptive_avg_pool2d(
return ivy.squeeze(pooled_output, axis=0)
return pooled_output

if not hasattr(input, "device"):
device = list(input.devices())[0]
else:
device = input.device

idxh, length_h, range_max_h, adaptive_h = _compute_idx(
input.shape[-2], output_size[-2], input.device
input.shape[-2], output_size[-2], device
)
idxw, length_w, range_max_w, adaptive_w = _compute_idx(
input.shape[-1], output_size[-1], input.device
input.shape[-1], output_size[-1], device
)

# to numpy and back in order to bypass a slicing error in tensorflow
Expand Down

0 comments on commit 7497a4e

Please sign in to comment.