Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

port tests for F.gaussian_blur GaussianBlur #7935

Merged
merged 4 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,17 +269,6 @@ def __init__(
],
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
),
ConsistencyConfig(
v2_transforms.GaussianBlur,
legacy_transforms.GaussianBlur,
[
ArgsKwargs(kernel_size=3),
ArgsKwargs(kernel_size=(1, 5)),
ArgsKwargs(kernel_size=3, sigma=0.7),
ArgsKwargs(kernel_size=5, sigma=(0.3, 1.4)),
],
closeness_kwargs={"rtol": 1e-5, "atol": 1e-5},
),
ConsistencyConfig(
v2_transforms.RandomPerspective,
legacy_transforms.RandomPerspective,
Expand Down Expand Up @@ -512,7 +501,6 @@ def test_call_consistency(config, args_kwargs):
)
for transform_cls, get_params_args_kwargs in [
(v2_transforms.ColorJitter, ArgsKwargs(brightness=None, contrast=None, saturation=None, hue=None)),
(v2_transforms.GaussianBlur, ArgsKwargs(0.3, 1.4)),
(v2_transforms.RandomPerspective, ArgsKwargs(23, 17, 0.5)),
(v2_transforms.AutoAugment, ArgsKwargs(5)),
]
Expand Down
58 changes: 0 additions & 58 deletions test/test_transforms_v2_functional.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
import os
import re

import numpy as np
Expand Down Expand Up @@ -740,63 +739,6 @@ def _compute_expected_mask(mask, output_size):
torch.testing.assert_close(expected, actual)


# Copied from test/test_functional_tensor.py
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("canvas_size", ("small", "large"))
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)])
@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)])
def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma):
    """Compare ``F.gaussian_blur_image`` against stored OpenCV reference results.

    The reference file maps keys of the form
    ``"{height}_{width}_{channels}__{kernel_h}_{kernel_w}_{sigma}"`` to the output of
    ``cv2.GaussianBlur`` on a deterministic ``np.arange`` image, e.g.::

        np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
        cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8)
        np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
        cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7)

    Parameter combinations without a stored reference are reported as skipped
    rather than silently passing.
    """
    fn = F.gaussian_blur_image

    # Decide on the skip before doing any work, and report it as a skip instead of a pass.
    if dt == torch.float16 and device == "cpu":
        pytest.skip("float16 on CPU is not covered by this test")

    if canvas_size == "small":
        tensor = (
            torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device)
        )
    else:
        tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device)

    if dt is not None:
        tensor = tensor.to(dtype=dt)

    _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize
    _sigma = sigma[0] if sigma is not None else None
    shape = tensor.shape
    gt_key = f"{shape[-2]}_{shape[-1]}_{shape[-3]}__{_ksize[0]}_{_ksize[1]}_{_sigma}"

    p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
    true_cv2_results = torch.load(p)
    if gt_key not in true_cv2_results:
        pytest.skip(f"no OpenCV reference result stored for {gt_key}")

    # The reference is stored in HWC layout; bring it to CHW and match dtype/device of the input.
    true_out = (
        torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
    )

    image = tv_tensors.Image(tensor)

    out = fn(image, kernel_size=ksize, sigma=sigma)
    torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize(
"inpt",
[
Expand Down
103 changes: 102 additions & 1 deletion test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -2863,12 +2863,64 @@ def test_transform_passthrough(self, make_input):


class TestGaussianBlur:
@pytest.mark.parametrize("kernel_size", [1, 3, (3, 1), [3, 5]])
@pytest.mark.parametrize("sigma", [None, 1.0, 1, (0.5,), [0.3], (0.3, 0.7), [0.9, 0.2]])
def test_kernel_image(self, kernel_size, sigma):
    """Run the generic kernel checks on ``F.gaussian_blur_image``.

    The scripted-vs-eager comparison is disabled whenever ``kernel_size`` is an
    ``int`` or ``sigma`` is a plain ``float``/``int``.
    """
    has_scalar_argument = isinstance(kernel_size, int) or isinstance(sigma, (float, int))
    image = make_image()
    check_kernel(
        F.gaussian_blur_image,
        image,
        kernel_size=kernel_size,
        sigma=sigma,
        check_scripted_vs_eager=not has_scalar_argument,
    )

def test_kernel_image_errors(self):
    """Check that invalid ``kernel_size`` / ``sigma`` arguments raise the expected errors."""
    image = make_image_tensor()

    # (expected exception, expected message pattern, keyword arguments for the kernel)
    failure_cases = [
        (ValueError, "kernel_size is a sequence its length should be 2", dict(kernel_size=[1, 2, 3])),
        (ValueError, "kernel_size should have odd and positive integers", dict(kernel_size=2)),
        (ValueError, "kernel_size should have odd and positive integers", dict(kernel_size=-1)),
        (ValueError, "sigma is a sequence, its length should be 2", dict(kernel_size=1, sigma=[1, 2, 3])),
        (TypeError, "sigma should be either float or sequence of floats", dict(kernel_size=1, sigma=object())),
        (ValueError, "sigma should have positive values", dict(kernel_size=1, sigma=-1)),
    ]

    for expected_error, message, kwargs in failure_cases:
        with pytest.raises(expected_error, match=message):
            F.gaussian_blur_image(image, **kwargs)

def test_kernel_video(self):
    """Run the generic kernel checks on ``F.gaussian_blur_video`` with a 3x3 kernel."""
    video = make_video()
    check_kernel(F.gaussian_blur_video, video, kernel_size=(3, 3))

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image_pil, make_image, make_video])
def test_functional(self, make_input):
    """Run the generic functional checks on ``F.gaussian_blur`` for each input factory."""
    sample = make_input()
    check_functional(F.gaussian_blur, sample, kernel_size=(3, 3))

@pytest.mark.parametrize(
    ("kernel", "input_type"),
    [
        (F.gaussian_blur_image, torch.Tensor),
        (F._gaussian_blur_image_pil, PIL.Image.Image),
        (F.gaussian_blur_image, tv_tensors.Image),
        (F.gaussian_blur_video, tv_tensors.Video),
    ],
)
def test_functional_signature(self, kernel, input_type):
    """Check that each kernel's signature matches ``F.gaussian_blur`` for its input type."""
    functional = F.gaussian_blur
    check_functional_kernel_signature_match(functional, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
    "make_input",
    [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
# `sigma` must be parametrized exactly once; the older, shorter list `[5, (0.5, 2)]` is fully
# contained in this one.
@pytest.mark.parametrize("sigma", [5, 2.0, (0.5, 2), [1.3, 2.7]])
def test_transform(self, make_input, device, sigma):
    """Run the generic transform checks on ``transforms.GaussianBlur``.

    Covers every listed input factory on each device, with ``sigma`` given as an
    int, a float, a tuple and a list.
    """
    check_transform(transforms.GaussianBlur(kernel_size=3, sigma=sigma), make_input(device=device))

Expand Down Expand Up @@ -2904,6 +2956,55 @@ def test__get_params(self, sigma):
assert sigma[0] <= params["sigma"][0] <= sigma[1]
assert sigma[0] <= params["sigma"][1] <= sigma[1]

# Reference outputs for the correctness test below, loaded from the `assets` directory next to
# this file. According to the recipe below, the stored mapping was produced with OpenCV and is
# keyed by "{height}_{width}_{channels}__{kernel_h}_{kernel_w}_{sigma}":
#
# np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))
# np_img2 = np.arange(26 * 28, dtype="uint8").reshape((26, 28))
# {
#     "10_12_3__3_3_0.8": cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.8),
#     "10_12_3__3_3_0.5": cv2.GaussianBlur(np_img, ksize=(3, 3), sigmaX=0.5),
#     "10_12_3__3_5_0.8": cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.8),
#     "10_12_3__3_5_0.5": cv2.GaussianBlur(np_img, ksize=(3, 5), sigmaX=0.5),
#     "26_28_1__23_23_1.7": cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7),
# }
REFERENCE_GAUSSIAN_BLUR_IMAGE_RESULTS = torch.load(
    Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt"
)

@pytest.mark.parametrize(
    ("canvas_size", "kernel_size", "sigma"),
    [
        ("small", (3, 3), 0.8),
        ("small", (3, 3), 0.5),
        ("small", (3, 5), 0.8),
        ("small", (3, 5), 0.5),
        ("large", (23, 23), 1.7),
    ],
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_functional_image_correctness(self, kernel_size, sigma, canvas_size, dtype, device):
    """Compare ``F.gaussian_blur_image`` against the stored OpenCV reference results.

    Only parameter combinations that have a stored reference are listed, so every
    case that runs is actually compared. Results must match within an absolute
    tolerance of 1.
    """
    # Review note: refactored from the legacy
    # `test_gaussian_blur(device, image_size, dt, ksize, sigma, fn)`, which was hard to
    # understand and did unnecessary work just to skip later.

    if dtype == torch.float16 and device == "cpu":
        # Report the case as skipped instead of letting it count as a pass.
        pytest.skip("float16 on CPU is not covered by this test")

    if canvas_size == "small":
        data = torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1)
    else:
        data = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28)))
    data = data.to(dtype=dtype, device=device)

    num_channels, height, width = data.shape
    reference_results_key = f"{height}_{width}_{num_channels}__{kernel_size[0]}_{kernel_size[1]}_{sigma}"
    # The reference is stored in HWC layout; bring it to CHW and match dtype/device of the input.
    expected = (
        torch.tensor(self.REFERENCE_GAUSSIAN_BLUR_IMAGE_RESULTS[reference_results_key])
        .reshape(height, width, num_channels)
        .permute(2, 0, 1)
        .to(data)
    )

    actual = F.gaussian_blur_image(tv_tensors.Image(data), kernel_size=kernel_size, sigma=sigma)

    torch.testing.assert_close(actual, expected, rtol=0, atol=1)


class TestAutoAugmentTransforms:
# These transforms have a lot of branches in their `forward()` passes which are conditioned on random sampling.
Expand Down
12 changes: 0 additions & 12 deletions test/transforms_v2_dispatcher_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
xfail_jit_python_scalar_arg("output_size"),
],
),
DispatcherInfo(
F.gaussian_blur,
kernels={
tv_tensors.Image: F.gaussian_blur_image,
tv_tensors.Video: F.gaussian_blur_video,
},
pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil),
test_marks=[
xfail_jit_python_scalar_arg("kernel_size"),
xfail_jit_python_scalar_arg("sigma"),
],
),
DispatcherInfo(
F.equalize,
kernels={
Expand Down
37 changes: 0 additions & 37 deletions test/transforms_v2_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,43 +686,6 @@ def sample_inputs_center_crop_video():
)


def sample_inputs_gaussian_blur_image_tensor():
    """Yield sample ``ArgsKwargs`` for ``F.gaussian_blur_image``.

    First varies ``kernel_size`` with ``sigma`` left unset, then varies ``sigma``
    with ``kernel_size=5``. All samples use RGB image loaders of size ``(7, 33)``.
    """

    def rgb_image_loaders():
        return make_image_loaders(sizes=[(7, 33)], color_spaces=["RGB"])

    kernel_sizes = [5, (3, 3), [3, 3]]
    sigmas = [None, (3.0, 3.0), [2.0, 2.0], 4.0, [1.5], (3.14,)]

    yield from (
        ArgsKwargs(loader, kernel_size=kernel_size)
        for loader, kernel_size in itertools.product(rgb_image_loaders(), kernel_sizes)
    )
    yield from (
        ArgsKwargs(loader, kernel_size=5, sigma=sigma)
        for loader, sigma in itertools.product(rgb_image_loaders(), sigmas)
    )


def sample_inputs_gaussian_blur_video():
    """Yield sample ``ArgsKwargs`` for ``F.gaussian_blur_video`` using a ``[3, 3]`` kernel."""
    yield from (
        ArgsKwargs(loader, kernel_size=[3, 3])
        for loader in make_video_loaders(sizes=[(7, 33)], num_frames=[5])
    )


# Register the Gaussian blur kernels (image and video) together with their sample-input
# generators in the shared KERNEL_INFOS list.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.gaussian_blur_image,
            sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor,
            # NOTE(review): presumably loosens tolerances for CUDA-vs-CPU comparisons - confirm
            # against the helper's definition.
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
            test_marks=[
                # NOTE(review): judging by the helper's name, these mark JIT tests that pass a
                # Python scalar for the named argument as expected failures - confirm.
                xfail_jit_python_scalar_arg("kernel_size"),
                xfail_jit_python_scalar_arg("sigma"),
            ],
        ),
        KernelInfo(
            F.gaussian_blur_video,
            sample_inputs_fn=sample_inputs_gaussian_blur_video,
            closeness_kwargs=cuda_vs_cpu_pixel_difference(),
        ),
    ]
)


def sample_inputs_equalize_image_tensor():
for image_loader in make_image_loaders(sizes=[DEFAULT_PORTRAIT_SPATIAL_SIZE], color_spaces=("GRAY", "RGB")):
yield ArgsKwargs(image_loader)
Expand Down