Skip to content

Commit

Permalink
refactor: remove copy of dims_and_offset helper and make the call fro…
Browse files Browse the repository at this point in the history
…m helpers in test_jax_diagonal
  • Loading branch information
Ishticode committed Feb 17, 2024
1 parent 6004903 commit 8c8fd6d
Showing 1 changed file with 1 addition and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,6 @@ def _get_dtype_square_x(draw):
return dtype_x


# diagonal
@st.composite
def dims_and_offset(draw, shape):
shape_actual = draw(shape)
dim1 = draw(helpers.get_axis(shape=shape, force_int=True))
dim2 = draw(helpers.get_axis(shape=shape, force_int=True))
offset = draw(
st.integers(min_value=-shape_actual[dim1], max_value=shape_actual[dim1])
)
return dim1, dim2, offset


# unravel_index
@st.composite
def max_value_as_shape_prod(draw):
Expand Down Expand Up @@ -217,7 +205,7 @@ def test_jax_diag_indices_from(
available_dtypes=helpers.get_dtypes("float"),
shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape"),
),
dims_and_offset=dims_and_offset(
dims_and_offset=helpers.dims_and_offset(
shape=st.shared(helpers.get_shape(min_num_dims=2), key="shape")
),
)
Expand Down

0 comments on commit 8c8fd6d

Please sign in to comment.