diff --git a/ivy/functional/frontends/jax/random.py b/ivy/functional/frontends/jax/random.py index 449b6f35a5cc8..c93e26eb0d7ac 100644 --- a/ivy/functional/frontends/jax/random.py +++ b/ivy/functional/frontends/jax/random.py @@ -267,6 +267,28 @@ def maxwell(key, shape=None, dtype="float64"): return x +@handle_jax_dtype +@to_ivy_arrays_and_back +@with_unsupported_dtypes( + { + "0.4.14 and below": ( + "uint32" + ) + }, + "jax", +) +def double_sided_maxwell(key, loc, scale, shape=(), dtype="float64"): + params_shapes = ivy.broadcast_shapes(ivy.shape(loc), ivy.shape(scale)) + if not shape: + shape = params_shapes + + shape = shape + params_shapes + maxwell_rvs = maxwell(key, shape=shape, dtype=dtype) + random_sign = rademacher(key, shape=shape, dtype=dtype) + + return random_sign * maxwell_rvs * scale+loc + + @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( @@ -377,11 +399,11 @@ def poisson(key, lam, shape=None, dtype=None): ) def rademacher(key, shape, dtype="int64"): seed = _get_seed(key) - b = ivy.bernoulli(ivy.array([0.5]), shape=shape, dtype="float32", seed=seed) + prob = ivy.full(shape, 0.5, dtype="float32") + b = ivy.bernoulli(prob, shape=shape, dtype="float32", seed=seed) b = ivy.astype(b, dtype) return 2 * b - 1 - @handle_jax_dtype @to_ivy_arrays_and_back @with_unsupported_dtypes( diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py index 10ddbd13320db..cee4fcce80431 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_random.py @@ -875,6 +875,68 @@ def call(): assert u.shape == v.shape +@pytest.mark.xfail +@handle_frontend_test( + fn_tree="jax.random.double_sided_maxwell", + dtype_key=helpers.dtype_and_values( + available_dtypes=["uint32"], + min_value=1, + max_value=2000, + min_num_dims=1, + max_num_dims=1, + min_dim_size=2, + max_dim_size=2, + ), + shape=helpers.get_shape(), + dtype=helpers.get_dtypes("float", full=False), + loc=st.integers(min_value=10, max_value=100), + scale=st.floats(min_value=0, max_value=100, exclude_min=True), + test_with_out=st.just(False), +) +def test_jax_double_sided_maxwell( + *, + dtype_key, + loc, + scale, + shape, + dtype, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, key = dtype_key + + def call(): + return helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=False, + backend_to_test=backend_fw, + key=key[0], + loc=loc, + scale=scale, + shape=shape, + dtype=dtype[0], + ) + + ret = call() + + if not ivy.exists(ret): + return + + ret_np, ret_from_np = ret + ret_np = helpers.flatten_and_to_np(backend=backend_fw, ret=ret_np) + ret_from_np = helpers.flatten_and_to_np(backend=backend_fw, ret=ret_from_np) + for u, v in zip(ret_np, ret_from_np): + assert u.dtype == v.dtype + assert u.shape == v.shape + + @pytest.mark.xfail @handle_frontend_test( fn_tree="jax.random.multivariate_normal", @@ -1466,7 +1528,7 @@ def call(): @pytest.mark.xfail @handle_frontend_test( - fn_tree="jax.random.uniform", + fn_tree="jax.random.ball", dtype_key=helpers.dtype_and_values( available_dtypes=["uint32"], min_value=0,