From 6dc524766c40431fb8c6a049a80f68892d9eed21 Mon Sep 17 00:00:00 2001 From: Boghdady9 Date: Sat, 23 Sep 2023 15:48:54 +0300 Subject: [PATCH] feat(paddle): Add backward method to Paddle Frontend --- .../frontends/paddle/tensor/tensor.py | 4 +++ .../test_paddle/test_tensor/test_tensor.py | 36 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 619c02107b635..b4a110f47d363 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -795,3 +795,7 @@ def nonzero(self): @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") def inner(self, y, name=None): return paddle_frontend.inner(self._ivy_array, y, name) + + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def mean(self, axis=None, keepdim=False, name=None): + return paddle_frontend.mean(self._ivy_array, axis=axis, keepdim=keepdim) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 0bdcb4f424790..9e0b554dcd6c5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -424,6 +424,42 @@ def test_paddle_is_floating_point( ) +# mean +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="mean", + dtype_and_x=_statistical_dtype_values(function="mean"), + keepdim=st.booleans(), +) +def test_paddle_tensor_mean( + dtype_and_x, + keepdim, + frontend, + backend_fw, + frontend_method_data, + init_flags, + method_flags, + on_device, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + "keepdim": keepdim, + }, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + backend_to_test=backend_fw, + method_flags=method_flags, + on_device=on_device, + ) + + # __add__ @handle_frontend_method( class_tree=CLASS_TREE,