diff --git a/ivy/functional/backends/torch/statistical.py b/ivy/functional/backends/torch/statistical.py index 75aae3649db61..b5ff291d030d0 100644 --- a/ivy/functional/backends/torch/statistical.py +++ b/ivy/functional/backends/torch/statistical.py @@ -97,10 +97,14 @@ def mean( if dtype is not None: dtype = ivy.as_native_dtype(dtype) if axis is None: - ret = torch.mean(input=x, dtype=dtype) - else: - ret = torch.mean(input=x, dim=axis, keepdims=keepdims, dtype=dtype, out=out) - return ret + num_dims = len(x.shape) + axis = list(range(num_dims)) + if axis in [(), []]: + if ivy.exists(out): + return ivy.inplace_update(out, x) + else: + return x + return torch.mean(x, dim=axis, dtype=dtype, keepdim=keepdims, out=out) mean.support_native_out = True