Skip to content

Commit

Permalink
Add moveaxis to torch frontend 2 (ivy-llc#12927)
Browse files Browse the repository at this point in the history
Co-authored-by: WilliamHirst bod_holthe@outlook.com
  • Loading branch information
maxkom125 authored and Giac3 committed Mar 24, 2023
1 parent 1106614 commit 8622e00
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def movedim(input, source, destination):
return ivy.moveaxis(input, source, destination)


@to_ivy_arrays_and_back
def moveaxis(input, source, destination):
return ivy.moveaxis(input, source, destination)


@to_ivy_arrays_and_back
def hstack(tensors, *, out=None):
return ivy.hstack(tensors, out=out)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,78 @@ def test_torch_movedim(
)


# moveaxis
@handle_frontend_test(
fn_tree="torch.moveaxis",
dtype_and_input=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-100,
max_value=100,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d",
),
),
source=helpers.get_axis(
allow_none=False,
unique=True,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d",
),
min_size=1,
force_int=True,
),
destination=helpers.get_axis(
allow_none=False,
unique=True,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d",
),
min_size=1,
force_int=True,
),
test_with_out=st.just(False),
)
def test_torch_moveaxis(
*,
dtype_and_input,
source,
destination,
on_device,
fn_tree,
frontend,
test_flags,
):
input_dtype, value = dtype_and_input
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
input=value[0],
source=source,
destination=destination,
)


# hstack
@handle_frontend_test(
fn_tree="torch.hstack",
Expand Down

0 comments on commit 8622e00

Please sign in to comment.