Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function grid_sample.py #27915

Closed
wants to merge 5 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions grid_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
if mode == 'nearest':
# Round the grid values to get the closest integer indices
x_rounded = ivy.round(grid[..., 0])
y_rounded = ivy.round(grid[..., 1])

if padding_mode == 'zeros':
# Create masks for out-of-bound x and y positions
mask_x = ivy.logical_or(x_rounded < 0, x_rounded >= input.shape[-1])
mask_y = ivy.logical_or(y_rounded < 0, y_rounded >= input.shape[-2])

# Combine the masks
mask = ivy.logical_or(mask_x, mask_y)

# Using the indices, gather the values from the input tensor
sampled_output = ivy.where(mask, ivy.zeros_like(input), input[..., y_rounded, x_rounded])

elif padding_mode == 'border':
# Clamp the indices to lie within the borders
x_clamped = ivy.clip(x_rounded, 0, input.shape[-1] - 1)
y_clamped = ivy.clip(y_rounded, 0, input.shape[-2] - 1)

# Using the clamped indices, gather the values from the input tensor
sampled_output = input[..., y_clamped, x_clamped]

else:
raise ValueError("Unsupported padding_mode. Expected 'zeros' or 'border'.")

elif mode == 'bilinear':
# Bilinear interpolation
raise NotImplementedError("Bilinear interpolation has not been implemented yet.")

else:
raise ValueError("Unsupported mode. Expected 'bilinear' or 'nearest'.")

return sampled_output
Loading