From c686786b497e927651b97b499a93c228c916b09c Mon Sep 17 00:00:00 2001 From: Sam-Armstrong Date: Tue, 10 Sep 2024 03:30:33 +0100 Subject: [PATCH] fix: array squeeze in torch frontend median --- ivy/functional/frontends/torch/reduction_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/torch/reduction_ops.py b/ivy/functional/frontends/torch/reduction_ops.py index fe5dc5d06eb9..8a3bdfa2f281 100644 --- a/ivy/functional/frontends/torch/reduction_ops.py +++ b/ivy/functional/frontends/torch/reduction_ops.py @@ -172,9 +172,9 @@ def median(input, dim=None, keepdim=False, *, out=None): median_indices = ivy.gather( sorted_indices, (sorted_indices.shape[dim] - 1) // 2, axis=dim ) - median_values = ivy.take_along_axis( + median_values = ivy.squeeze(ivy.take_along_axis( input, ivy.expand_dims(median_indices, axis=dim), dim - ).squeeze(axis=dim) + ), axis=dim) if keepdim: median_values = ivy.expand_dims(median_values, axis=dim)