From 42fa0ad63b18dba61b00defa83f9c74b54ebea79 Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Wed, 20 Sep 2023 12:14:18 +0100 Subject: [PATCH 01/14] Added prod function and associated unit tests to Jax Frontend --- .../jax/numpy/mathematical_functions.py | 35 +++++++++++++++ .../test_numpy/test_mathematical_functions.py | 43 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 3e5b19ce92e04..86ce2a6ed5a50 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -580,6 +580,41 @@ def power(x1, x2, /): return ivy.pow(x1, x2) +@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") +@to_ivy_arrays_and_back +def prod( + a, + *, + axis=None, + dtype=None, + keepdims=False, + initial=None, + where=None, + promote_integers=True, + out=None, +): + if ivy.is_array(where): + a = ivy.where(where, a, ivy.default(out, ivy.ones_like(a)), out=out) + + dtype or a.dtype + + if dtype is None and promote_integers: + if ivy.is_uint_dtype(a.dtype): + dtype = "uint64" + elif ivy.is_int_dtype(a.dtype): + dtype = "int64" + + if initial is not None: + if axis is not None: + s = ivy.to_list(ivy.shape(a, as_array=True)) + s[axis] = 1 + header = ivy.full(ivy.Shape(tuple(s)), initial) + a = ivy.concat([header, a], axis=axis) + else: + a[0] *= initial + return ivy.prod(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out) + + @to_ivy_arrays_and_back def product( a, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 65be2aa2b07c3..0b78a14565ba2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -2662,6 +2662,49 @@ def test_jax_power( ) +@handle_frontend_test( + fn_tree="jax.numpy.prod", + dtype_x_axis_dtype_where=_get_castable_dtypes_values(use_where=True), + keepdims=st.booleans(), + initial=st.one_of(st.floats(min_value=-100, max_value=100)), + promote_integers=st.booleans(), +) +def test_jax_prod( + dtype_x_axis_dtype_where, + keepdims, + initial, + promote_integers, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype_where + if ivy.current_backend_str() == "torch": + assume(not test_flags.as_variable[0]) + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + axis=axis, + dtype=dtype, + keepdims=keepdims, + initial=initial, + where=where, + promote_integers=promote_integers, + ) + + # rad2deg @handle_frontend_test( fn_tree="jax.numpy.rad2deg", From cae1b5c901350b6d7f8dfa5f57bfc84dc7a4b715 Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Wed, 20 Sep 2023 13:37:17 +0100 Subject: [PATCH 02/14] Added prod function and associated unit tests to Jax Frontend --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 86ce2a6ed5a50..2735e00734bdc 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -596,8 +596,6 @@ def prod( if ivy.is_array(where): a = ivy.where(where, a, ivy.default(out, ivy.ones_like(a)), out=out) - dtype or a.dtype - if dtype is None and promote_integers: if ivy.is_uint_dtype(a.dtype): dtype = "uint64" From 15c3f17fd8ad4f087214ad9ec22ee7a3c2e3613e Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Thu, 21 Sep 2023 05:35:23 +0100 Subject: [PATCH 03/14] Handled unsupported dtypes for paddle by adding supported dtypes decorator. --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 2735e00734bdc..4e08e6369a2db 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -4,6 +4,7 @@ to_ivy_arrays_and_back, ) from ivy.func_wrapper import with_unsupported_dtypes +from ivy.func_wrapper import with_supported_dtypes from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs from ivy.functional.frontends.numpy.manipulation_routines import trim_zeros @@ -580,7 +581,9 @@ def power(x1, x2, /): return ivy.pow(x1, x2) -@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) @to_ivy_arrays_and_back def prod( a, From f162c4c63b6cb23d40967f4b2a3d9f83f5bd08c2 Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Tue, 26 Sep 2023 10:12:57 +0100 Subject: [PATCH 04/14] Removed condition that was checking whether the backend is torch from the test file. --- .../test_jax/test_numpy/test_mathematical_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 0b78a14565ba2..fb7df4995fcfb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -2681,8 +2681,6 @@ def test_jax_prod( on_device, ): input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype_where - if ivy.current_backend_str() == "torch": - assume(not test_flags.as_variable[0]) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, From 5205cbdbd969b7c80ebf8f41a4614edfc9beaae4 Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Fri, 29 Sep 2023 14:33:05 +0100 Subject: [PATCH 05/14] Changed the with supported dtypes decorator to support dtypes for jax. --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 2 +- .../test_jax/test_numpy/test_mathematical_functions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 4e08e6369a2db..6e9354938c57a 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -582,7 +582,7 @@ def power(x1, x2, /): @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + {"0.4.14 and below": ("float32", "float64", "int32", "int64")}, "jax" ) @to_ivy_arrays_and_back def prod( diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index fb7df4995fcfb..010da08155246 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -113,7 +113,7 @@ def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): def _get_dtype_input_and_vector(draw): size1 = draw(helpers.ints(min_value=1, max_value=5)) size2 = draw(helpers.ints(min_value=1, max_value=5)) - dtype = draw(helpers.get_dtypes("integer")) + dtype = draw(helpers.get_dtypes("numeric")) vec1 = draw(helpers.array_values(dtype=dtype[0], shape=(size1, size2))) return dtype, vec1 From 5eedc1ef9010a05d43b3af75b2bef56b2c5ba52b Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Fri, 29 Sep 2023 15:15:50 +0100 Subject: [PATCH 06/14] Rolling back changes to _get_dtype_input_and_vector --- .../test_jax/test_numpy/test_mathematical_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 010da08155246..fb7df4995fcfb 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -113,7 +113,7 @@ def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): def _get_dtype_input_and_vector(draw): size1 = draw(helpers.ints(min_value=1, max_value=5)) size2 = draw(helpers.ints(min_value=1, max_value=5)) - dtype = draw(helpers.get_dtypes("numeric")) + dtype = draw(helpers.get_dtypes("integer")) vec1 = draw(helpers.array_values(dtype=dtype[0], shape=(size1, size2))) return dtype, vec1 From 324247630f87c0522c4be58ec9aeb5b3e30ec1cf Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Mon, 16 Oct 2023 19:40:14 +0100 Subject: [PATCH 07/14] Removed supported type handler from jax frontend and handled unsupported dtypes in paddle backend functions --- ivy/functional/backends/paddle/searching.py | 6 +++++- ivy/functional/backends/paddle/statistical.py | 5 +++++ .../frontends/jax/numpy/mathematical_functions.py | 4 ---- .../test_jax/test_numpy/test_mathematical_functions.py | 2 ++ 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/ivy/functional/backends/paddle/searching.py b/ivy/functional/backends/paddle/searching.py index 9519e03c23916..34e869893059d 100644 --- a/ivy/functional/backends/paddle/searching.py +++ b/ivy/functional/backends/paddle/searching.py @@ -4,10 +4,11 @@ import paddle import ivy.functional.backends.paddle as paddle_backend import ivy -from ivy.func_wrapper import with_unsupported_device_and_dtypes +from ivy.func_wrapper import with_unsupported_device_and_dtypes, with_supported_dtypes from . import backend_version from .elementwise import _elementwise_helper + # Array API Standard # # ------------------ # @@ -140,6 +141,9 @@ def nonzero( return res.T +@with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, backend_version +) def where( condition: paddle.Tensor, x1: Union[float, int, paddle.Tensor], diff --git a/ivy/functional/backends/paddle/statistical.py b/ivy/functional/backends/paddle/statistical.py index 0f294cde0e285..1d2f30c1c38f7 100644 --- a/ivy/functional/backends/paddle/statistical.py +++ b/ivy/functional/backends/paddle/statistical.py @@ -10,12 +10,14 @@ from ivy.func_wrapper import ( with_unsupported_device_and_dtypes, with_supported_device_and_dtypes, + with_supported_dtypes, ) import ivy.functional.backends.paddle as paddle_backend # local from . import backend_version + # Array API Standard # # -------------------# @@ -136,6 +138,9 @@ def mean( return ret.astype(ret_dtype) +@with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float32", "float64")}, backend_version +) def prod( x: paddle.Tensor, /, diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 6e9354938c57a..274c86a0fa1c0 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -4,7 +4,6 @@ to_ivy_arrays_and_back, ) from ivy.func_wrapper import with_unsupported_dtypes -from ivy.func_wrapper import with_supported_dtypes from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs from ivy.functional.frontends.numpy.manipulation_routines import trim_zeros @@ -581,9 +580,6 @@ def power(x1, x2, /): return ivy.pow(x1, x2) -@with_supported_dtypes( - {"0.4.14 and below": ("float32", "float64", "int32", "int64")}, "jax" -) @to_ivy_arrays_and_back def prod( a, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index fb7df4995fcfb..75941cadfb9c0 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -2700,6 +2700,8 @@ def test_jax_prod( initial=initial, where=where, promote_integers=promote_integers, + atol=1e-01, + rtol=1e-01, ) From 016bf45c6eb11f75bfba99f41c5d9c340de38d5a Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Wed, 20 Sep 2023 12:14:18 +0100 Subject: [PATCH 08/14] Added prod function and associated unit tests to Jax Frontend --- .../jax/numpy/mathematical_functions.py | 35 +++++++++++++++ .../test_numpy/test_mathematical_functions.py | 43 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index b146e5743b934..881723b93f92e 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -585,6 +585,41 @@ def power(x1, x2, /): return ivy.pow(x1, x2) +@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") +@to_ivy_arrays_and_back +def prod( + a, + *, + axis=None, + dtype=None, + keepdims=False, + initial=None, + where=None, + promote_integers=True, + out=None, +): + if ivy.is_array(where): + a = ivy.where(where, a, ivy.default(out, ivy.ones_like(a)), out=out) + + dtype or a.dtype + + if dtype is None and promote_integers: + if ivy.is_uint_dtype(a.dtype): + dtype = "uint64" + elif ivy.is_int_dtype(a.dtype): + dtype = "int64" + + if initial is not None: + if axis is not None: + s = ivy.to_list(ivy.shape(a, as_array=True)) + s[axis] = 1 + header = ivy.full(ivy.Shape(tuple(s)), initial) + a = ivy.concat([header, a], axis=axis) + else: + a[0] *= initial + return ivy.prod(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out) + + @to_ivy_arrays_and_back def product( a, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 9bc5d435dce60..33d164817ed30 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -2710,6 +2710,49 @@ def test_jax_power( ) +@handle_frontend_test( + fn_tree="jax.numpy.prod", + dtype_x_axis_dtype_where=_get_castable_dtypes_values(use_where=True), + keepdims=st.booleans(), + initial=st.one_of(st.floats(min_value=-100, max_value=100)), + promote_integers=st.booleans(), +) +def test_jax_prod( + dtype_x_axis_dtype_where, + keepdims, + initial, + promote_integers, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype_where + if ivy.current_backend_str() == "torch": + assume(not test_flags.as_variable[0]) + where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( + where=where, + input_dtype=input_dtypes, + test_flags=test_flags, + ) + helpers.test_frontend_function( + input_dtypes=input_dtypes, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + axis=axis, + dtype=dtype, + keepdims=keepdims, + initial=initial, + where=where, + promote_integers=promote_integers, + ) + + # rad2deg @handle_frontend_test( fn_tree="jax.numpy.rad2deg", From 2486076343aad98bef572f8a41dde2d18041d07d Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Wed, 20 Sep 2023 13:37:17 +0100 Subject: [PATCH 09/14] Added prod function and associated unit tests to Jax Frontend --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 881723b93f92e..4d124d7140c8b 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -601,8 +601,6 @@ def prod( if ivy.is_array(where): a = ivy.where(where, a, ivy.default(out, ivy.ones_like(a)), out=out) - dtype or a.dtype - if dtype is None and promote_integers: if ivy.is_uint_dtype(a.dtype): dtype = "uint64" From a9caef2965dae7a63512b7f02d34ea5882120036 Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Thu, 21 Sep 2023 05:35:23 +0100 Subject: [PATCH 10/14] Handled unsupported dtypes for paddle by adding supported dtypes decorator. --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 4d124d7140c8b..5225fe340610d 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -4,6 +4,7 @@ to_ivy_arrays_and_back, ) from ivy.func_wrapper import with_unsupported_dtypes +from ivy.func_wrapper import with_supported_dtypes from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs from ivy.functional.frontends.numpy.manipulation_routines import trim_zeros @@ -585,7 +586,9 @@ def power(x1, x2, /): return ivy.pow(x1, x2) -@with_unsupported_dtypes({"2.5.1 and below": "bfloat16"}, "paddle") +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" +) @to_ivy_arrays_and_back def prod( a, From 18dfc19358db213382f53ad93c9266a6302b237f Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Tue, 26 Sep 2023 10:12:57 +0100 Subject: [PATCH 11/14] Removed condition that was checking whether the backend is torch from the test file. --- .../test_jax/test_numpy/test_mathematical_functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 33d164817ed30..860a0f2160853 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -2729,8 +2729,6 @@ def test_jax_prod( on_device, ): input_dtypes, x, axis, dtype, where = dtype_x_axis_dtype_where - if ivy.current_backend_str() == "torch": - assume(not test_flags.as_variable[0]) where, input_dtypes, test_flags = np_frontend_helpers.handle_where_and_array_bools( where=where, input_dtype=input_dtypes, From e5dc61ee48d3e207f681339c291451c5a1ed787a Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Fri, 29 Sep 2023 14:33:05 +0100 Subject: [PATCH 12/14] Changed the with supported dtypes decorator to support dtypes for jax. --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 2 +- .../test_jax/test_numpy/test_mathematical_functions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 5225fe340610d..7d7491f3e035d 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -587,7 +587,7 @@ def power(x1, x2, /): @with_supported_dtypes( - {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + {"0.4.14 and below": ("float32", "float64", "int32", "int64")}, "jax" ) @to_ivy_arrays_and_back def prod( diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 860a0f2160853..ca01ea7e5f370 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -113,7 +113,7 @@ def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): def _get_dtype_input_and_vector(draw): size1 = draw(helpers.ints(min_value=1, max_value=5)) size2 = draw(helpers.ints(min_value=1, max_value=5)) - dtype = draw(helpers.get_dtypes("integer")) + dtype = draw(helpers.get_dtypes("numeric")) vec1 = draw(helpers.array_values(dtype=dtype[0], shape=(size1, size2))) return dtype, vec1 From 043feb3f57cd18fbf456f094e7490140b20f235d Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Fri, 29 Sep 2023 15:15:50 +0100 Subject: [PATCH 13/14] Rolling back changes to _get_dtype_input_and_vector --- .../test_jax/test_numpy/test_mathematical_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index ca01ea7e5f370..860a0f2160853 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -113,7 +113,7 @@ def _get_castable_dtypes_values(draw, *, allow_nan=False, use_where=False): def _get_dtype_input_and_vector(draw): size1 = draw(helpers.ints(min_value=1, max_value=5)) size2 = draw(helpers.ints(min_value=1, max_value=5)) - dtype = draw(helpers.get_dtypes("numeric")) + dtype = draw(helpers.get_dtypes("integer")) vec1 = draw(helpers.array_values(dtype=dtype[0], shape=(size1, size2))) return dtype, vec1 From a271e8d64dea4719818660952e32fa8f3301b127 Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Mon, 16 Oct 2023 19:40:14 +0100 Subject: [PATCH 14/14] Removed supported type handler from jax frontend and handled unsupported dtypes in paddle backend functions --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 4 ---- .../test_jax/test_numpy/test_mathematical_functions.py | 2 ++ 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 7d7491f3e035d..b65a5e591bfb1 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -4,7 +4,6 @@ to_ivy_arrays_and_back, ) from ivy.func_wrapper import with_unsupported_dtypes -from ivy.func_wrapper import with_supported_dtypes from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs from ivy.functional.frontends.numpy.manipulation_routines import trim_zeros @@ -586,9 +585,6 @@ def power(x1, x2, /): return ivy.pow(x1, x2) -@with_supported_dtypes( - {"0.4.14 and below": ("float32", "float64", "int32", "int64")}, "jax" -) @to_ivy_arrays_and_back def prod( a, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 860a0f2160853..055d7140e893c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -2748,6 +2748,8 @@ def test_jax_prod( initial=initial, where=where, promote_integers=promote_integers, + atol=1e-01, + rtol=1e-01, )