Skip to content

Commit

Permalink
For double sided maxwell (ivy-llc#21264)
Browse files Browse the repository at this point in the history
  • Loading branch information
stalemate1 authored and druvdub committed Oct 14, 2023
1 parent 4c72922 commit e3c35a2
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 3 deletions.
26 changes: 24 additions & 2 deletions ivy/functional/frontends/jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,28 @@ def maxwell(key, shape, dtype="float64"):
return ivy.vector_norm(random_normal, axis=-1)


@handle_jax_dtype
@to_ivy_arrays_and_back
@with_unsupported_dtypes(
{
"0.4.14 and below": (
"uint32"
)
},
"jax",
)
def double_sided_maxwell(key, loc, scale, shape=(), dtype="float64"):
params_shapes = ivy.broadcast_shapes(ivy.shape(loc), ivy.shape(scale))
if not shape:
shape = params_shapes

shape = shape + params_shapes
maxwell_rvs = maxwell(key, shape=shape, dtype=dtype)
random_sign = rademacher(key, shape=shape, dtype=dtype)

return random_sign * maxwell_rvs * scale+loc


@handle_jax_dtype
@to_ivy_arrays_and_back
@with_unsupported_dtypes(
Expand Down Expand Up @@ -387,11 +409,11 @@ def poisson(key, lam, shape=None, dtype=None):
)
def rademacher(key, shape, dtype="int64"):
seed = _get_seed(key)
b = ivy.bernoulli(ivy.array([0.5]), shape=shape, dtype="float32", seed=seed)
prob = ivy.full(shape, 0.5, dtype="float32")
b = ivy.bernoulli(prob, shape=shape, dtype="float32", seed=seed)
b = ivy.astype(b, dtype)
return 2 * b - 1


@handle_jax_dtype
@to_ivy_arrays_and_back
@with_unsupported_dtypes(
Expand Down
64 changes: 63 additions & 1 deletion ivy_tests/test_ivy/test_frontends/test_jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,68 @@ def call():
assert u.shape == v.shape


@pytest.mark.xfail
@handle_frontend_test(
fn_tree="jax.random.double_sided_maxwell",
dtype_key=helpers.dtype_and_values(
available_dtypes=["uint32"],
min_value=1,
max_value=2000,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
max_dim_size=2,
),
shape=helpers.get_shape(),
dtype=helpers.get_dtypes("float", full=False),
loc=st.integers(min_value=10, max_value=100),
scale=st.floats(min_value=0, max_value=100, exclude_min=True),
test_with_out=st.just(False),
)
def test_jax_double_sided_maxwell(
*,
dtype_key,
loc,
scale,
shape,
dtype,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
):
input_dtype, key = dtype_key

def call():
return helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
test_values=False,
backend_to_test=backend_fw,
key=key[0],
loc=loc,
scale=scale,
shape=shape,
dtype=dtype[0],
)

ret = call()

if not ivy.exists(ret):
return

ret_np, ret_from_np = ret
ret_np = helpers.flatten_and_to_np(backend=backend_fw, ret=ret_np)
ret_from_np = helpers.flatten_and_to_np(backend=backend_fw, ret=ret_from_np)
for u, v in zip(ret_np, ret_from_np):
assert u.dtype == v.dtype
assert u.shape == v.shape


@pytest.mark.xfail
@handle_frontend_test(
fn_tree="jax.random.multivariate_normal",
Expand Down Expand Up @@ -1522,7 +1584,7 @@ def call():

@pytest.mark.xfail
@handle_frontend_test(
fn_tree="jax.random.uniform",
fn_tree="jax.random.ball",
dtype_key=helpers.dtype_and_values(
available_dtypes=["uint32"],
min_value=0,
Expand Down

0 comments on commit e3c35a2

Please sign in to comment.