Skip to content

Commit

Permalink
add max_pool3d to pytorch frontend (#22038)
Browse files Browse the repository at this point in the history
  • Loading branch information
progs2002 authored Sep 14, 2023
1 parent 0eb0cfb commit 09473c3
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
25 changes: 25 additions & 0 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,28 @@ def max_pool2d(
dilation=dilation,
ceil_mode=ceil_mode,
)


@with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch")
@to_ivy_arrays_and_back
def max_pool3d(
input,
kernel_size,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False,
):
if stride is None:
stride = kernel_size

return ivy.max_pool3d(
input,
kernel_size,
stride,
padding,
data_format="NCDHW",
dilation=dilation,
ceil_mode=ceil_mode,
)
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,46 @@ def test_torch_max_pool2d(
dilation=dilation,
ceil_mode=ceil_mode,
)


# max_pool3d
@handle_frontend_test(
fn_tree="torch.nn.functional.max_pool3d",
x_k_s_p=helpers.arrays_for_pooling(
min_dims=5,
max_dims=5,
min_side=1,
max_side=5,
only_explicit_padding=True,
return_dilation=True,
data_format="channel_first",
),
test_with_out=st.just(False),
ceil_mode=st.booleans(),
)
def test_torch_max_pool3d(
x_k_s_p,
ceil_mode,
*,
test_flags,
frontend,
backend_fw,
fn_tree,
on_device,
):
dtype, x, kernel, stride, pad, dilation = x_k_s_p
padding = (pad[0][0], pad[1][0], pad[2][0])
helpers.test_frontend_function(
input_dtypes=dtype,
backend_to_test=backend_fw,
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
input=x[0],
kernel_size=kernel,
stride=stride,
padding=padding,
dilation=dilation,
ceil_mode=ceil_mode,
)

0 comments on commit 09473c3

Please sign in to comment.