diff --git a/ivy/functional/backends/paddle/searching.py b/ivy/functional/backends/paddle/searching.py
index 64b68a8a63ba1..34e869893059d 100644
--- a/ivy/functional/backends/paddle/searching.py
+++ b/ivy/functional/backends/paddle/searching.py
@@ -4,19 +4,24 @@
 import paddle
 import ivy.functional.backends.paddle as paddle_backend
 import ivy
-from ivy.func_wrapper import (
-    with_supported_dtypes,
-    with_unsupported_dtypes,
-)
+from ivy.func_wrapper import with_unsupported_device_and_dtypes, with_supported_dtypes
 from . import backend_version
 from .elementwise import _elementwise_helper
 
+
 # Array API Standard #
 # ------------------ #
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int16", "int32", "int64", "uint8")},
+@with_unsupported_device_and_dtypes(
+    {
+        "2.5.1 and below": {
+            "cpu": (
+                "complex64",
+                "complex128",
+            )
+        }
+    },
     backend_version,
 )
 def argmax(
@@ -30,6 +35,8 @@ def argmax(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else paddle.int64
+    if x.dtype in [paddle.int8, paddle.float16, paddle.bool]:
+        x = x.cast("float32")
     if select_last_index:
         x = paddle_backend.flip(x, axis=axis)
     ret = paddle.argmax(x, axis=axis, keepdim=keepdims)
@@ -47,8 +54,15 @@ def argmax(
     return ret.astype(dtype)
 
 
-@with_unsupported_dtypes(
-    {"2.5.1 and below": ("bfloat16", "bool", "complex", "float16", "int8")},
+@with_unsupported_device_and_dtypes(
+    {
+        "2.5.1 and below": {
+            "cpu": (
+                "complex64",
+                "complex128",
+            )
+        }
+    },
     backend_version,
 )
 def argmin(
@@ -62,6 +76,8 @@ def argmin(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else paddle.int64
+    if x.dtype in [paddle.int8, paddle.float16, paddle.bool]:
+        x = x.cast("float32")
     if select_last_index:
         x = paddle_backend.flip(x, axis=axis)
     ret = paddle.argmin(x, axis=axis, keepdim=keepdims)
@@ -79,9 +95,6 @@ def argmin(
     return ret.astype(dtype)
 
 
-@with_unsupported_dtypes(
-    {"2.5.1 and below": ("float16", "int8", "uint8")}, backend_version
-)
 def nonzero(
     x: paddle.Tensor,
     /,
@@ -90,11 +103,20 @@ def nonzero(
     size: Optional[int] = None,
     fill_value: Number = 0,
 ) -> Union[paddle.Tensor, Tuple[paddle.Tensor]]:
-    if paddle.is_complex(x):
-        real_idx = paddle.nonzero(x.real())
-        imag_idx = paddle.nonzero(x.imag())
-        idx = paddle.concat([real_idx, imag_idx], axis=0)
-        res = paddle.unique(idx, axis=0)
+    if x.dtype in [
+        paddle.int8,
+        paddle.uint8,
+        paddle.float16,
+        paddle.complex64,
+        paddle.complex128,
+    ]:
+        if paddle.is_complex(x):
+            real_idx = paddle.nonzero(x.real())
+            imag_idx = paddle.nonzero(x.imag())
+            idx = paddle.concat([real_idx, imag_idx], axis=0)
+            res = paddle.unique(idx, axis=0)
+        else:
+            res = paddle.nonzero(x.cast("float32"))
     else:
         res = paddle.nonzero(x)
 
@@ -119,6 +141,9 @@ def nonzero(
     return res.T
 
 
+@with_supported_dtypes(
+    {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, backend_version
+)
 def where(
     condition: paddle.Tensor,
     x1: Union[float, int, paddle.Tensor],
@@ -160,17 +185,22 @@ def where(
 # ----- #
 
 
-@with_unsupported_dtypes(
-    {"2.5.1 and below": ("float16", "int8", "uint8")}, backend_version
-)
 def argwhere(
     x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
 ) -> paddle.Tensor:
     if x.ndim == 0:
         return paddle.zeros(shape=[int(bool(x.item())), 0], dtype="int64")
-    if paddle.is_complex(x):
-        real_idx = paddle.nonzero(x.real())
-        imag_idx = paddle.nonzero(x.imag())
-        idx = paddle.concat([real_idx, imag_idx], axis=0)
-        return paddle.unique(idx, axis=0)
+    if x.dtype in [
+        paddle.int8,
+        paddle.uint8,
+        paddle.float16,
+        paddle.complex64,
+        paddle.complex128,
+    ]:
+        if paddle.is_complex(x):
+            real_idx = paddle.nonzero(x.real())
+            imag_idx = paddle.nonzero(x.imag())
+            idx = paddle.concat([real_idx, imag_idx], axis=0)
+            return paddle.unique(idx, axis=0)
+        return paddle.nonzero(x.cast("float32"))
     return paddle.nonzero(x)
diff --git a/ivy/functional/backends/paddle/statistical.py b/ivy/functional/backends/paddle/statistical.py
index 791eaefee92cc..1d2f30c1c38f7 100644
--- a/ivy/functional/backends/paddle/statistical.py
+++ b/ivy/functional/backends/paddle/statistical.py
@@ -6,26 +6,22 @@
 
 import paddle
 import ivy
+from ivy.utils.exceptions import IvyNotImplementedException
 from ivy.func_wrapper import (
-    with_supported_dtypes,
-    with_unsupported_dtypes,
+    with_unsupported_device_and_dtypes,
     with_supported_device_and_dtypes,
+    with_supported_dtypes,
 )
 import ivy.functional.backends.paddle as paddle_backend
-from ivy.utils.einsum_parser import legalise_einsum_expr
-from ivy.functional.ivy.statistical import _get_promoted_type_of_operands
 
 # local
 from . import backend_version
 
+
 # Array API Standard #
 # -------------------#
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("complex", "float32", "float64", "int32", "int64")},
-    backend_version,
-)
 def min(
     x: paddle.Tensor,
     /,
@@ -35,10 +31,22 @@ def min(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     ret_dtype = x.dtype
-    if paddle.is_complex(x):
-        real = paddle.amin(x.real(), axis=axis, keepdim=keepdims)
-        imag = paddle.amin(x.imag(), axis=axis, keepdim=keepdims)
-        ret = paddle.complex(real, imag)
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.float16,
+        paddle.bfloat16,
+        paddle.complex64,
+        paddle.complex128,
+        paddle.bool,
+    ]:
+        if paddle.is_complex(x):
+            real = paddle.amin(x.real(), axis=axis, keepdim=keepdims)
+            imag = paddle.amin(x.imag(), axis=axis, keepdim=keepdims)
+            ret = paddle.complex(real, imag)
+        else:
+            ret = paddle.amin(x.cast("float32"), axis=axis, keepdim=keepdims)
     else:
         ret = paddle.amin(x, axis=axis, keepdim=keepdims)
     # The following code is to simulate other frameworks
@@ -51,10 +59,6 @@ def min(
     return ret.astype(ret_dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("complex", "float32", "float64", "int32", "int64")},
-    backend_version,
-)
 def max(
     x: paddle.Tensor,
     /,
@@ -64,18 +68,30 @@ def max(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     ret_dtype = x.dtype
-    if paddle.is_complex(x):
-        const = paddle.to_tensor(1j, dtype=x.dtype)
-        real_max = paddle.max(x.real(), axis=axis, keepdim=keepdims)
-        imag = paddle.where(
-            x.real() == real_max, x.imag(), paddle.full_like(x.imag(), -1e10)
-        )
-        # we consider the number with the biggest real and imag part
-        img_max = paddle.max(imag, axis=axis, keepdim=keepdims)
-        img_max = paddle.cast(img_max, x.dtype)
-        return paddle.add(
-            paddle.cast(real_max, x.dtype), paddle.multiply(img_max, const)
-        )
+    if x.dtype in [
+        paddle.int8,
+        paddle.int16,
+        paddle.uint8,
+        paddle.bfloat16,
+        paddle.float16,
+        paddle.complex64,
+        paddle.complex128,
+        paddle.bool,
+    ]:
+        if paddle.is_complex(x):
+            const = paddle.to_tensor(1j, dtype=x.dtype)
+            real_max = paddle.max(x.real(), axis=axis, keepdim=keepdims)
+            imag = paddle.where(
+                x.real() == real_max, x.imag(), paddle.full_like(x.imag(), -1e10)
+            )
+            # we consider the number with the biggest real and imag part
+            img_max = paddle.max(imag, axis=axis, keepdim=keepdims)
+            img_max = paddle.cast(img_max, x.dtype)
+            return paddle.add(
+                paddle.cast(real_max, x.dtype), paddle.multiply(img_max, const)
+            )
+        else:
+            ret = paddle.amax(x.cast("float32"), axis=axis, keepdim=keepdims)
     else:
         ret = paddle.amax(x, axis=axis, keepdim=keepdims)
 
@@ -89,9 +105,6 @@ def max(
     return ret.astype(ret_dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("bool", "complex", "float32", "float64")}, backend_version
-)
 def mean(
     x: paddle.Tensor,
     /,
@@ -101,11 +114,17 @@ def mean(
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     ret_dtype = x.dtype
-    if paddle.is_complex(x):
-        ret = paddle.complex(
-            paddle.mean(x.real(), axis=axis, keepdim=keepdims),
-            paddle.mean(x.imag(), axis=axis, keepdim=keepdims),
-        )
+    if x.dtype not in [
+        paddle.float32,
+        paddle.float64,
+    ]:
+        if paddle.is_complex(x):
+            ret = paddle.complex(
+                paddle.mean(x.real(), axis=axis, keepdim=keepdims),
+                paddle.mean(x.imag(), axis=axis, keepdim=keepdims),
+            )
+        else:
+            ret = paddle.mean(x.cast("float32"), axis=axis, keepdim=keepdims)
     else:
         ret = paddle.mean(x, axis=axis, keepdim=keepdims)
 
@@ -120,7 +139,7 @@ def mean(
 
 
 @with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, backend_version
+    {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, backend_version
 )
 def prod(
     x: paddle.Tensor,
@@ -131,9 +150,16 @@ def prod(
     keepdims: bool = False,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    x_dtype = x.dtype
+    supported_dtypes = ["int32", "int64", "float32", "float64"]
+    if str(x_dtype) not in supported_dtypes:
+        x = x.cast("float32")
+    dtype_ = dtype
+    if str(dtype) not in supported_dtypes:
+        dtype = None
     ret = paddle.prod(x, axis=axis, keepdim=keepdims, dtype=dtype)
-    if ret.dtype != dtype:
-        ret = ret.cast(dtype)
+    if ret.dtype != dtype_:
+        ret = ret.cast(dtype_)
     return ret
 
 
@@ -168,7 +194,6 @@ def std(
     return _std(x, axis, correction, keepdims).cast(x.dtype)
 
 
-@with_unsupported_dtypes({"2.5.1 and below": ("int8", "uint8")}, backend_version)
 def sum(
     x: paddle.Tensor,
     /,
@@ -180,7 +205,10 @@ def sum(
 ) -> paddle.Tensor:
     dtype = x.dtype if dtype is None else dtype
     dtype = ivy.as_ivy_dtype(dtype)
-    ret = paddle.sum(x, axis=axis, dtype=dtype, keepdim=keepdims)
+    if x.dtype in [paddle.int8, paddle.uint8]:
+        ret = paddle.sum(x.cast("float32"), axis=axis, dtype=dtype, keepdim=keepdims)
+    else:
+        ret = paddle.sum(x.cast(dtype), axis=axis, dtype=dtype, keepdim=keepdims)
     # The following code is to simulate other frameworks
     # output shapes behaviour since min output dim is 1 in paddle
     if isinstance(axis, Sequence):
@@ -206,8 +234,12 @@ def var(
 # Extra #
 # ----- #
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("complex", "float32", "float64", "int32", "int64")},
+@with_supported_device_and_dtypes(
+    {
+        "2.5.1 and below": {
+            "cpu": ("int32", "int64", "float64", "complex128", "float32", "complex64")
+        }
+    },
     backend_version,
 )
 def cumprod(
@@ -222,6 +254,13 @@ def cumprod(
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else x.dtype
     x = paddle.cast(x, dtype)
+    if ivy.as_native_dtype(dtype) in [
+        paddle.uint8,
+        paddle.int8,
+        paddle.int16,
+        paddle.float16,
+    ]:
+        x = paddle.cast(x, "float32")
     if not (exclusive or reverse):
         return paddle.cumprod(x, dim=axis).cast(dtype)
     elif exclusive and reverse:
@@ -256,8 +295,9 @@ def cumprod(
     return paddle_backend.flip(x, axis=axis).cast(dtype)
 
 
-@with_supported_dtypes(
-    {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, backend_version
+@with_unsupported_device_and_dtypes(
+    {"2.5.1 and below": {"cpu": ("complex64", "complex128")}},
+    backend_version,
 )
 def cumsum(
     x: paddle.Tensor,
@@ -270,6 +310,13 @@ def cumsum(
 ) -> paddle.Tensor:
     dtype = dtype if dtype is not None else x.dtype
     x = paddle.cast(x, dtype)
+    if ivy.as_native_dtype(dtype) in [
+        paddle.uint8,
+        paddle.int8,
+        paddle.float16,
+        paddle.bool,
+    ]:
+        x = paddle.cast(x, "float32")
     if not (exclusive or reverse):
         return paddle.cumsum(x, axis=axis).cast(dtype)
     elif exclusive and reverse:
@@ -304,31 +351,9 @@ def cumsum(
     return paddle_backend.flip(x, axis=axis).cast(dtype)
 
 
-@with_supported_device_and_dtypes(
-    {
-        "2.5.1 and below": {
-            "cpu": ("float32", "float64", "complex64", "complex128"),
-            "gpu": (
-                "bfloat16",
-                "float16",
-                "float32",
-                "float64",
-                "complex64",
-                "complex128",
-            ),
-        },
-        "2.4.2 and below": {
-            "cpu": ("float32", "float64", "complex64", "complex128"),
-            "gpu": ("float16", "float32", "float64", "complex64", "complex128"),
-        },
-    },
-    backend_version,
-)
 def einsum(
     equation: str,
     *operands: paddle.Tensor,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
-    dtype = _get_promoted_type_of_operands(operands)
-    equation = legalise_einsum_expr(*[equation, *operands])
-    return paddle.einsum(equation, *operands).astype(dtype)
+    raise IvyNotImplementedException()
diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py
index b146e5743b934..b65a5e591bfb1 100644
--- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py
+++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py
@@ -585,6 +585,38 @@ def power(x1, x2, /):
     return ivy.pow(x1, x2)
 
 
+@to_ivy_arrays_and_back
+def prod(
+    a,
+    *,
+    axis=None,
+    dtype=None,
+    keepdims=False,
+    initial=None,
+    where=None,
+    promote_integers=True,
+    out=None,
+):
+    if ivy.is_array(where):
+        a = ivy.where(where, a, ivy.default(out, ivy.ones_like(a)), out=out)
+
+    if dtype is None and promote_integers:
+        if ivy.is_uint_dtype(a.dtype):
+            dtype = "uint64"
+        elif ivy.is_int_dtype(a.dtype):
+            dtype = "int64"
+
+    if initial is not None:
+        if axis is not None:
+            s = ivy.to_list(ivy.shape(a, as_array=True))
+            s[axis] = 1
+            header = ivy.full(ivy.Shape(tuple(s)), initial)
+            a = ivy.concat([header, a], axis=axis)
+        else:
+            a[0] *= initial
+    return ivy.prod(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
+
+
 @to_ivy_arrays_and_back
 def product(
     a,
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py
index 9bc5d435dce60..055d7140e893c 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py
@@ -2710,6 +2710,49 @@ def test_jax_power(
     )
 
 
+@handle_frontend_test(
+    fn_tree="jax.numpy.prod",
+    dtype_x_axis_dtype_where=_get_castable_dtypes_values(use_where=True),
+    keepdims=st.booleans(),
+    initial=st.one_of(st.floats(min_value=-100, max_value=100)),
+    promote_integers=st.booleans(),
+)
+def test_jax_prod(
+    dtype_x_axis_dtype_where,
+    keepdims,
+    initial,
+    promote_integers,
+    frontend,
+    backend_fw,
+    test_flags,
+    fn_tree,
+    on_device,
+):
+    input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype_where
+    where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools(
+        where=where,
+        input_dtype=input_dtypes,
+        test_flags=test_flags,
+    )
+    helpers.test_frontend_function(
+        input_dtypes=input_dtypes,
+        frontend=frontend,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        a=x[0],
+        axis=axis,
+        dtype=dtype,
+        keepdims=keepdims,
+        initial=initial,
+        where=where,
+        promote_integers=promote_integers,
+        atol=1e-01,
+        rtol=1e-01,
+    )
+
+
 # rad2deg
 @handle_frontend_test(
     fn_tree="jax.numpy.rad2deg",