Skip to content

Commit

Permalink
fix: torch backend mean reverted to work with keepdim, out args, …
Browse files Browse the repository at this point in the history
…as well as `dtype`
  • Loading branch information
mattbarrett98 committed Apr 22, 2024
1 parent 642e471 commit 9a1ac4c
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions ivy/functional/backends/torch/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,14 @@ def mean(
if dtype is not None:
dtype = ivy.as_native_dtype(dtype)
if axis is None:
ret = torch.mean(input=x, dtype=dtype)
else:
ret = torch.mean(input=x, dim=axis, keepdims=keepdims, dtype=dtype, out=out)
return ret
num_dims = len(x.shape)
axis = list(range(num_dims))
if axis in [(), []]:
if ivy.exists(out):
return ivy.inplace_update(out, x)
else:
return x
return torch.mean(x, dim=axis, dtype=dtype, keepdim=keepdims, out=out)


mean.support_native_out = True
Expand Down

0 comments on commit 9a1ac4c

Please sign in to comment.