Skip to content

Commit

Permalink
interpolate ivy-llc#26133 issue
Browse files Browse the repository at this point in the history
  • Loading branch information
rfatihors committed Oct 1, 2023
1 parent 280b05d commit 2308ff1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 17 deletions.
19 changes: 2 additions & 17 deletions ivy/functional/backends/paddle/experimental/layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# global
from turtle import st
from typing import Optional, Union, Tuple, List, Literal, Sequence, Callable
import paddle
from hypothesis import given

from ivy.functional.ivy.layers import (
_handle_padding,
Expand Down Expand Up @@ -665,6 +665,7 @@ def sliding_window(
)


@st.composite
def interpolate_linear(
x: paddle.Tensor,
size: Union[Sequence[int], int],
Expand All @@ -679,19 +680,3 @@ def interpolate_linear(
return paddle.nn.functional.interpolate(
x, size, scale_factor, mode, align_corners, align_mode, data_format, name
)


@given()
def test_interpolate_linear(
x: paddle.Tensor,
size: Union[Sequence[int], int],
mode: Optional[Literal["linear", "bilinear", "trilinear"]] = "linear",
scale_factor: Optional[Union[Sequence[int], int]] = None,
align_corners: Optional[bool] = False,
align_mode: int = 0,
data_format: str = "NCHW",
name: Optional[str] = None,
):
assert paddle.nn.functional.interpolate(
x, size, scale_factor, mode, align_corners, align_mode, data_format, name
) == interpolate_linear(x)
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# global
from ctypes import Union
from typing import Optional, Literal

import numpy as np
import torch
from hypothesis import strategies as st, assume
from numpy.random.mtrand import Sequence

# local
import ivy
import ivy_tests.test_ivy.helpers as helpers
from ivy.functional.backends import paddle
from ivy.functional.backends.paddle import interpolate_linear
from ivy_tests.test_ivy.helpers import handle_test


Expand Down Expand Up @@ -1085,6 +1091,21 @@ def test_interpolate(
)


def test_interpolate_linear(
x: paddle.Tensor,
size: Union[Sequence[int], int],
mode: Optional[Literal["linear", "bilinear", "trilinear"]] = "linear",
scale_factor: Optional[Union[Sequence[int], int]] = None,
align_corners: Optional[bool] = False,
align_mode: int = 0,
data_format: str = "NCHW",
name: Optional[str] = None,
):
assert paddle.nn.functional.interpolate(
x, size, scale_factor, mode, align_corners, align_mode, data_format, name
) == interpolate_linear(x)


@handle_test(
fn_tree="functional.ivy.experimental.max_pool1d",
x_k_s_p=helpers.arrays_for_pooling(
Expand Down

0 comments on commit 2308ff1

Please sign in to comment.