From 9a1ac4c3dbab68bb564d495f60ac441cb9c766be Mon Sep 17 00:00:00 2001 From: Matt Barrett <83289589+mattbarrett98@users.noreply.github.com> Date: Mon, 22 Apr 2024 15:44:14 +0000 Subject: [PATCH] fix: torch backend mean reverted to work with `keepdim`, `out` args, as well as `dtype` --- ivy/functional/backends/torch/statistical.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ivy/functional/backends/torch/statistical.py b/ivy/functional/backends/torch/statistical.py index 75aae3649db61..b5ff291d030d0 100644 --- a/ivy/functional/backends/torch/statistical.py +++ b/ivy/functional/backends/torch/statistical.py @@ -97,10 +97,14 @@ def mean( if dtype is not None: dtype = ivy.as_native_dtype(dtype) if axis is None: - ret = torch.mean(input=x, dtype=dtype) - else: - ret = torch.mean(input=x, dim=axis, keepdims=keepdims, dtype=dtype, out=out) - return ret + num_dims = len(x.shape) + axis = list(range(num_dims)) + if axis in [(), []]: + if ivy.exists(out): + return ivy.inplace_update(out, x) + else: + return x + return torch.mean(x, dim=axis, dtype=dtype, keepdim=keepdims, out=out) mean.support_native_out = True