Skip to content

Commit

Permalink
Add torch frontend grid_sample and test (#22539)
Browse files Browse the repository at this point in the history
Co-authored-by: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com>
  • Loading branch information
Deepdive543443 and Sam-Armstrong authored Sep 6, 2023
1 parent cef0521 commit 8cf4dc7
Show file tree
Hide file tree
Showing 2 changed files with 372 additions and 1 deletion.
282 changes: 281 additions & 1 deletion ivy/functional/frontends/torch/nn/functional/vision_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# local
import ivy
from ivy import with_unsupported_dtypes
from ivy import with_unsupported_dtypes, with_supported_dtypes
from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back
from ivy.utils.exceptions import IvyNotImplementedException

Expand Down Expand Up @@ -355,3 +355,283 @@ def upsample_bilinear(input, size=None, scale_factor=None):
@to_ivy_arrays_and_back
def upsample_nearest(input, size=None, scale_factor=None):
return interpolate(input, size=size, scale_factor=scale_factor, mode="nearest")


def reflect(x, low2, high2):
min = low2 / 2
span = (high2 - low2) / 2
x = ivy.abs(x - min)
frac_in = ivy.abs(x / span)
extra = (frac_in - ivy.floor(frac_in)) * ivy.abs(span)
flips = ivy.floor(x / span)
x[flips % 2 == 0] = (extra + min)[flips % 2 == 0]
x[flips % 2 != 0] = (span - extra + min)[flips % 2 != 0]
return x


def grid_sample_padding(grid, padding_mode, align_corners, borders=None):
if padding_mode == "reflection":
if align_corners:
for idx, border in enumerate(borders):
grid[..., idx] = reflect(grid[..., idx], 0, 2 * (border - 1))
grid[..., idx] = ivy.clip(grid[..., idx], 0, border - 1)

else:
for idx, border in enumerate(borders):
grid[..., idx] = reflect(grid[..., idx], -1, 2 * border - 1)
grid[..., idx] = ivy.clip(grid[..., idx], 0, border - 1)

elif padding_mode == "border":
for idx, border in enumerate(borders):
grid[..., idx] = ivy.clip(grid[..., idx], 0, border - 1)

masks = []
for idx, border in enumerate(borders):
masks.append(ivy.bitwise_or(grid[..., idx] < -4, grid[..., idx] > border + 2))
borders[idx] += 1

zeros_mask = masks[0]
for i in range(1, len(borders)):
zeros_mask = ivy.bitwise_or(zeros_mask, masks[i])

if grid[zeros_mask].shape[0] > 0:
grid[zeros_mask] = ivy.array(borders)
return grid


def bicubic_interp(x, t, alpha=-0.75):
n, h, w = t.shape
coeffs = []
coeffs.append(ivy.reshape(cubic_conv2(alpha, t + 1), (n, 1, h, w)))
coeffs.append(ivy.reshape(cubic_conv1(alpha, t), (n, 1, h, w)))
coeffs.append(ivy.reshape(cubic_conv1(alpha, 1 - t), (n, 1, h, w)))
coeffs.append(ivy.reshape(cubic_conv2(alpha, 2 - t), (n, 1, h, w)))
return x[0] * coeffs[0] + x[1] * coeffs[1] + x[2] * coeffs[2] + x[3] * coeffs[3]


cubic_conv1 = lambda A, x: ((A + 2) * x - (A + 3)) * x * x + 1
cubic_conv2 = lambda A, x: (((A * x) - (5 * A)) * x + (8 * A)) * x - (4 * A)


@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch")
@to_ivy_arrays_and_back
def grid_sample(
input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
):
input_clone = ivy.copy_array(input)
grid_clone = ivy.copy_array(grid)

if ivy.get_num_dims(input_clone) == 4: # sample from 2D images
n, c, h, w = input_clone.shape
n, to_h, to_w, gc = grid_clone.shape

# Un-normalize 2D grid
if align_corners: # to range[0, size - 1]
grid_clone[..., 0] = ((grid_clone[..., 0] + 1) / 2) * (w - 1)
grid_clone[..., 1] = ((grid_clone[..., 1] + 1) / 2) * (h - 1)

elif not align_corners: # to range[0.5, size - 0.5]
grid_clone[..., 0] = ((grid_clone[..., 0] + 1) * w - 1) / 2
grid_clone[..., 1] = ((grid_clone[..., 1] + 1) * h - 1) / 2

batch_coor = ivy.reshape(ivy.arange(n), (-1, 1))
batch_coor = ivy.repeat(batch_coor, to_h * to_w, axis=1)
batch_coor = ivy.reshape(batch_coor, (n, to_h, to_w))
padding = [(0, 0) for _ in range(2)] + [(4, 4) for _ in range(2)]
input_clone = ivy.pad(input_clone, padding, mode="constant", constant_values=0)

if mode == "bicubic":
grid_floor = ivy.floor(grid_clone)
distance = grid_clone - grid_floor

tx, ty = distance[..., 0], distance[..., 1]

grid_floor -= 1
grid_floor = [
grid_sample_padding(
grid_floor + i, padding_mode, align_corners, borders=[w, h]
)
for i in range(4)
]

w_cubic = [
ivy.astype(grid_floor[i][..., 0] + 4, ivy.int64) for i in range(4)
]
h_cubic = [
ivy.astype(grid_floor[i][..., 1] + 4, ivy.int64) for i in range(4)
]

coeffs = [
bicubic_interp(
[
ivy.permute_dims(
input_clone[batch_coor, :, h_cubic[i], w_cubic[0]],
(0, 3, 1, 2),
),
ivy.permute_dims(
input_clone[batch_coor, :, h_cubic[i], w_cubic[1]],
(0, 3, 1, 2),
),
ivy.permute_dims(
input_clone[batch_coor, :, h_cubic[i], w_cubic[2]],
(0, 3, 1, 2),
),
ivy.permute_dims(
input_clone[batch_coor, :, h_cubic[i], w_cubic[3]],
(0, 3, 1, 2),
),
],
tx,
)
for i in range(4)
]
return bicubic_interp(coeffs, ty)

else:
grid_clone = grid_sample_padding(
grid_clone, padding_mode, align_corners, borders=[w, h]
)

if mode == "bilinear":
grid_clone += 4
w_coor = ivy.reshape(grid_clone[..., 0], (n, to_h, to_w))
h_coor = ivy.reshape(grid_clone[..., 1], (n, to_h, to_w))

w0 = ivy.astype(ivy.floor(w_coor), ivy.int64)
h0 = ivy.astype(ivy.floor(h_coor), ivy.int64)
w1 = w0 + 1
h1 = h0 + 1

v00 = ivy.permute_dims(input_clone[batch_coor, :, h0, w0], (0, 3, 1, 2))
v01 = ivy.permute_dims(input_clone[batch_coor, :, h0, w1], (0, 3, 1, 2))
v10 = ivy.permute_dims(input_clone[batch_coor, :, h1, w0], (0, 3, 1, 2))
v11 = ivy.permute_dims(input_clone[batch_coor, :, h1, w1], (0, 3, 1, 2))

alpha = ivy.reshape(w_coor - w0, (n, 1, to_h, to_w))
beta = ivy.reshape(h_coor - h0, (n, 1, to_h, to_w))

alpha = ivy.astype(alpha, ivy.float32)
beta = ivy.astype(beta, ivy.float32)

v0 = v00 * (1 - alpha) + v01 * alpha
v1 = v10 * (1 - alpha) + v11 * alpha

return v0 * (1 - beta) + v1 * beta

elif mode == "nearest":
w_coor = ivy.reshape(grid_clone[..., 0], (n, to_h, to_w))
h_coor = ivy.reshape(grid_clone[..., 1], (n, to_h, to_w))

w_coor = ivy.astype(ivy.round(w_coor), ivy.int64) + 4
h_coor = ivy.astype(ivy.round(h_coor), ivy.int64) + 4
return ivy.permute_dims(
input_clone[batch_coor, :, h_coor, w_coor], (0, 3, 1, 2)
)

else:
raise ivy.exceptions.IvyError(f"Not supported mode {mode}")

elif ivy.get_num_dims(input_clone) == 5: # sample from 3D images
n, c, d, h, w = input_clone.shape
n, to_d, to_h, to_w, gc = grid_clone.shape

# Un-normalize 3D grid
if align_corners: # to range[0, size - 1]
grid_clone[..., 0] = ((grid_clone[..., 0] + 1) / 2) * (w - 1)
grid_clone[..., 1] = ((grid_clone[..., 1] + 1) / 2) * (h - 1)
grid_clone[..., 2] = ((grid_clone[..., 2] + 1) / 2) * (d - 1)
elif not align_corners: # to range[0.5, size - 0.5]
grid_clone[..., 0] = ((grid_clone[..., 0] + 1) * w - 1) / 2
grid_clone[..., 1] = ((grid_clone[..., 1] + 1) * h - 1) / 2
grid_clone[..., 2] = ((grid_clone[..., 2] + 1) * d - 1) / 2

batch_coor = ivy.reshape(ivy.arange(n), (-1, 1))
batch_coor = ivy.repeat(batch_coor, to_d * to_h * to_w, axis=1)
batch_coor = ivy.reshape(batch_coor, (n, to_d, to_h, to_w))
padding = [(0, 0) for _ in range(2)] + [(3, 3) for _ in range(3)]
input_clone = ivy.pad(input_clone, padding, mode="constant", constant_values=0)

grid_clone = grid_sample_padding(
grid_clone, padding_mode, align_corners, borders=[w, h, d]
)

if mode == "bilinear":
grid_clone += 3
w_coor = ivy.reshape(grid_clone[..., 0], (n, to_d, to_h, to_w))
h_coor = ivy.reshape(grid_clone[..., 1], (n, to_d, to_h, to_w))
d_coor = ivy.reshape(grid_clone[..., 2], (n, to_d, to_h, to_w))

w0 = ivy.astype(ivy.floor(w_coor), ivy.int64)
h0 = ivy.astype(ivy.floor(h_coor), ivy.int64)
d0 = ivy.astype(ivy.floor(d_coor), ivy.int64)
w1 = w0 + 1
h1 = h0 + 1
d1 = d0 + 1

v000 = ivy.permute_dims(
input_clone[batch_coor, :, d0, h0, w0], (0, 4, 1, 2, 3)
) # tnw
v001 = ivy.permute_dims(
input_clone[batch_coor, :, d0, h0, w1], (0, 4, 1, 2, 3)
) # tne
v010 = ivy.permute_dims(
input_clone[batch_coor, :, d0, h1, w0], (0, 4, 1, 2, 3)
) # tsw
v011 = ivy.permute_dims(
input_clone[batch_coor, :, d0, h1, w1], (0, 4, 1, 2, 3)
) # tse
v100 = ivy.permute_dims(
input_clone[batch_coor, :, d1, h0, w0], (0, 4, 1, 2, 3)
) # bnw
v101 = ivy.permute_dims(
input_clone[batch_coor, :, d1, h0, w1], (0, 4, 1, 2, 3)
) # bne
v110 = ivy.permute_dims(
input_clone[batch_coor, :, d1, h1, w0], (0, 4, 1, 2, 3)
) # bsw
v111 = ivy.permute_dims(
input_clone[batch_coor, :, d1, h1, w1], (0, 4, 1, 2, 3)
) # bse

alpha = ivy.reshape(w_coor - w0, (n, 1, to_d, to_h, to_w))
beta = ivy.reshape(h_coor - h0, (n, 1, to_d, to_h, to_w))
gamma = ivy.reshape(d_coor - d0, (n, 1, to_d, to_h, to_w))

alpha = ivy.astype(alpha, ivy.float32)
beta = ivy.astype(beta, ivy.float32)
gamma = ivy.astype(gamma, ivy.float32)

v = (alpha * beta * gamma) * v111
v += ((1 - alpha) * beta * gamma) * v110
v += (alpha * (1 - beta) * gamma) * v101
v += ((1 - alpha) * (1 - beta) * gamma) * v100

v += (alpha * beta * (1 - gamma)) * v011
v += ((1 - alpha) * beta * (1 - gamma)) * v010
v += (alpha * (1 - beta) * (1 - gamma)) * v001
v += ((1 - alpha) * (1 - beta) * (1 - gamma)) * v000
return v

elif mode == "nearest":
ceil_mask = grid_clone % 1 == 0.5
grid_clone[ceil_mask] = ivy.astype(
ivy.ceil(grid_clone[ceil_mask]), ivy.int64
)

w_coor = ivy.reshape(grid_clone[..., 0], (n, to_d, to_h, to_w))
h_coor = ivy.reshape(grid_clone[..., 1], (n, to_d, to_h, to_w))
d_coor = ivy.reshape(grid_clone[..., 2], (n, to_d, to_h, to_w))

w_coor = ivy.astype(ivy.round(w_coor), ivy.int64) + 3
h_coor = ivy.astype(ivy.round(h_coor), ivy.int64) + 3
d_coor = ivy.astype(ivy.round(d_coor), ivy.int64) + 3
return ivy.permute_dims(
input_clone[batch_coor, :, d_coor, h_coor, w_coor], (0, 4, 1, 2, 3)
)

elif mode == "bicubic":
raise ivy.exceptions.IvyError(f"Bicubic is not support in 3D grid sampling")

else:
raise ivy.exceptions.IvyError(f"Not supported input shape {input_clone.shape}")

Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,63 @@ def _pad_helper(draw):
return dtype, input[0], padding, value, mode


@st.composite
def grid_sample_helper(draw, dtype, mode, mode_3d, padding_mode):
dtype = draw(dtype)
align_corners = draw(st.booleans())
dims = draw(st.integers(4, 5))
height = draw(helpers.ints(min_value=5, max_value=10))
width = draw(helpers.ints(min_value=5, max_value=10))
channels = draw(helpers.ints(min_value=1, max_value=3))

grid_h = draw(helpers.ints(min_value=2, max_value=4))
grid_w = draw(helpers.ints(min_value=2, max_value=4))
batch = draw(helpers.ints(min_value=1, max_value=5))

padding_mode = draw(st.sampled_from(padding_mode))
if dims == 4:
mode = draw(st.sampled_from(mode))
x = draw(
helpers.array_values(
dtype=dtype[0],
shape=[batch, channels, height, width],
min_value=-1,
max_value=1,
)
)

grid = draw(
helpers.array_values(
dtype=dtype[0],
shape=[batch, grid_h, grid_w, 2],
min_value=-1,
max_value=1,
)
)
elif dims == 5:
mode = draw(st.sampled_from(mode_3d))
depth = draw(helpers.ints(min_value=10, max_value=15))
grid_d = draw(helpers.ints(min_value=5, max_value=10))
x = draw(
helpers.array_values(
dtype=dtype[0],
shape=[batch, channels, depth, height, width],
min_value=-1,
max_value=1,
)
)

grid = draw(
helpers.array_values(
dtype=dtype[0],
shape=[batch, grid_d, grid_h, grid_w, 3],
min_value=-1,
max_value=1,
)
)
return dtype, x, grid, mode, padding_mode, align_corners


# --- Main --- #
# ------------ #

Expand Down Expand Up @@ -144,6 +201,40 @@ def test_torch_affine_grid(
)


@handle_frontend_test(
fn_tree="torch.nn.functional.grid_sample",
dtype_x_grid_modes=grid_sample_helper(
dtype=helpers.get_dtypes("valid", full=False),
mode=["nearest", "bilinear", "bicubic"],
mode_3d=["nearest", "bilinear"],
padding_mode=["border", "zeros", "reflection"],
),
)
def test_torch_grid_sample(
*,
dtype_x_grid_modes,
on_device,
backend_fw,
fn_tree,
frontend,
test_flags,
):
dtype, x, grid, mode, padding_mode, align_corners = dtype_x_grid_modes
helpers.test_frontend_function(
input_dtypes=dtype,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
input=x,
grid=grid,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)


@handle_frontend_test(
fn_tree="torch.nn.functional.interpolate",
dtype_and_input_and_other=_interp_args(
Expand Down

0 comments on commit 8cf4dc7

Please sign in to comment.