Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added std instance method to pytorch frontend #15003

Merged
merged 9 commits into from
May 20, 2023
2 changes: 2 additions & 0 deletions ivy/functional/frontends/torch/reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def median(input, dim=None, keepdim=False, *, out=None):


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch")

def std(input, dim=None, unbiased=True, keepdim=False, *, out=None):
return ivy.std(input, axis=dim, correction=int(unbiased), keepdims=keepdim, out=out)

Expand Down
5 changes: 5 additions & 0 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,11 @@ def addcdiv(self, tensor1, tensor2, *, value=1):
def sign(self):
return torch_frontend.sign(self._ivy_array)

def std(self, dim=None, unbiased=True, keepdim=False, *, out=None):
return torch_frontend.std(
self, dim=dim, unbiased=unbiased, keepdim=keepdim, out=out
)

@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, "torch")
def fmod(self, other, *, out=None):
return torch_frontend.fmod(self, other, out=out)
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/ivy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ def _iter_product(*args, repeat=1):
for prod in result:
yield tuple(prod)


@handle_exceptions
@inputs_to_ivy_arrays
def ndenumerate(
Expand Down
32 changes: 32 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from ivy_tests.test_ivy.test_functional.test_core.test_statistical import (
_get_castable_dtype,
statistical_dtype_values,
)


Expand Down Expand Up @@ -8017,6 +8018,37 @@ def test_torch_instance_sign(
)


# std
@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="torch.tensor",
method_name="std",
dtype_and_x=statistical_dtype_values(function="std"),
)
def test_torch_instance_std(
dtype_and_x,
frontend,
frontend_method_data,
init_flags,
method_flags,
on_device,
):
input_dtype, x, _, _ = dtype_and_x
helpers.test_frontend_method(
init_input_dtypes=input_dtype,
init_all_as_kwargs_np={
"data": x[0],
},
method_input_dtypes=input_dtype,
method_all_as_kwargs_np={},
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
frontend=frontend,
on_device=on_device,
)


# fmod
@handle_frontend_method(
class_tree=CLASS_TREE,
Expand Down