diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index f28b6e633d9..6d9b20870e5 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -1021,7 +1021,7 @@ def test_perspective_interpolation_warning(tester):
 @pytest.mark.parametrize('device', ["cpu", ])
 @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
 @pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
-@pytest.mark.parametrize('interpolation', [BILINEAR, ])
+@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
 def test_resize_antialias(device, dt, size, interpolation, tester):
 
     if dt == torch.float16 and device == "cpu":
@@ -1051,8 +1051,17 @@ def test_resize_antialias(device, dt, size, interpolation, tester):
     tester.approxEqualTensorToPIL(
         resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
     )
+
+    accepted_tol = 1.0 + 1e-5
+    if interpolation == BICUBIC:
+        # A higher tolerance is needed to make the tests pass.
+        # It is mostly required for test cases with
+        # downsampling and upsampling where we cannot exactly
+        # match the PIL implementation.
+        accepted_tol = 15.0
+
     tester.approxEqualTensorToPIL(
-        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
+        resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max",
         msg=f"{size}, {interpolation}, {dt}"
     )
diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
index 62fec046850..97b025aafb4 100644
--- a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
+++ b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
@@ -141,66 +141,7 @@ void ti_cpu_upsample_generic_aa(
 // Helper structs to use with ti_upsample_generic_Nd_kernel_impl
 template <typename scalar_t>
 struct HelperInterpBase {
-  static inline void init_indices_weights(
-      std::vector<Tensor>& output,
-      int64_t output_size,
-      int64_t ndims,
-      int64_t reshape_dim,
-      int interp_size) {
-    auto new_shape = std::vector<int64_t>(ndims, 1);
-    new_shape[reshape_dim] = output_size;
-
-    for (int j = 0; j < interp_size; j++) {
-      output.emplace_back(
-          empty(new_shape, CPU(c10::CppTypeToScalarType<int64_t>())));
-      output.emplace_back(
-          empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
-    }
-  }
-};
-
-template <typename scalar_t>
-struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
-  static const int interp_size = 2;
-
-  static inline std::vector<Tensor> compute_indices_weights(
-      int64_t input_size,
-      int64_t output_size,
-      int64_t stride,
-      int64_t ndims,
-      int64_t reshape_dim,
-      bool align_corners,
-      const c10::optional<double> opt_scale,
-      bool antialias,
-      int& out_interp_size) {
-    scalar_t scale = area_pixel_compute_scale<scalar_t>(
-        input_size, output_size, align_corners, opt_scale);
-    TORCH_INTERNAL_ASSERT(antialias);
-
-    return _compute_indices_weights_aa(
-        input_size,
-        output_size,
-        stride,
-        ndims,
-        reshape_dim,
-        align_corners,
-        scale,
-        out_interp_size);
-  }
-
-  // taken from
-  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
-  // src/libImaging/Resample.c#L20-L29
-  static inline scalar_t _filter(scalar_t x) {
-    if (x < 0.0) {
-      x = -x;
-    }
-    if (x < 1.0) {
-      return 1.0 - x;
-    }
-    return 0.0;
-  }
-
+  template <typename filter_fn_t>
   static inline std::vector<Tensor> _compute_indices_weights_aa(
       int64_t input_size,
       int64_t output_size,
@@ -209,14 +150,15 @@ struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
       int64_t reshape_dim,
       bool align_corners,
       scalar_t scale,
-      int& out_interp_size) {
-    int interp_size = HelperInterpLinear<scalar_t>::interp_size;
+      int& in_out_interp_size,
+      filter_fn_t filter_fn) {
+    int interp_size = in_out_interp_size;
     scalar_t support =
-        (scale >= 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
+        (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
     interp_size = (int)ceilf(support) * 2 + 1;
 
     // return interp_size
-    out_interp_size = interp_size;
+    in_out_interp_size = interp_size;
     std::vector<Tensor> output;
     auto new_shape = std::vector<int64_t>(ndims, 1);
@@ -269,7 +211,7 @@ struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
       total_w = 0.0;
       for (j = 0; j < xmax; j++) {
-        scalar_t w = _filter((j + xmin - center + 0.5) * invscale);
+        scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
         wt_ptr[i * interp_size + j] = w;
         total_w += w;
       }
@@ -287,6 +229,102 @@ struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
   }
 };
 
+template <typename scalar_t>
+struct HelperInterpLinear : public HelperInterpBase<scalar_t> {
+  static const int interp_size = 2;
+
+  // taken from
+  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+  // src/libImaging/Resample.c#L20-L29
+  static inline scalar_t _filter(scalar_t x) {
+    if (x < 0.0) {
+      x = -x;
+    }
+    if (x < 1.0) {
+      return 1.0 - x;
+    }
+    return 0.0;
+  }
+
+  static inline std::vector<Tensor> compute_indices_weights(
+      int64_t input_size,
+      int64_t output_size,
+      int64_t stride,
+      int64_t ndims,
+      int64_t reshape_dim,
+      bool align_corners,
+      const c10::optional<double> opt_scale,
+      bool antialias,
+      int& out_interp_size) {
+    TORCH_INTERNAL_ASSERT(antialias);
+    scalar_t scale = area_pixel_compute_scale<scalar_t>(
+        input_size, output_size, align_corners, opt_scale);
+
+    out_interp_size = HelperInterpLinear<scalar_t>::interp_size;
+    return HelperInterpLinear<scalar_t>::_compute_indices_weights_aa(
+        input_size,
+        output_size,
+        stride,
+        ndims,
+        reshape_dim,
+        align_corners,
+        scale,
+        out_interp_size,
+        _filter);
+  }
+};
+
+template <typename scalar_t>
+struct HelperInterpCubic : public HelperInterpBase<scalar_t> {
+  static const int interp_size = 4;
+
+  static inline std::vector<Tensor> compute_indices_weights(
+      int64_t input_size,
+      int64_t output_size,
+      int64_t stride,
+      int64_t ndims,
+      int64_t reshape_dim,
+      bool align_corners,
+      const c10::optional<double> opt_scale,
+      bool antialias,
+      int& out_interp_size) {
+    TORCH_INTERNAL_ASSERT(antialias);
+    scalar_t scale = area_pixel_compute_scale<scalar_t>(
+        input_size, output_size, align_corners, opt_scale);
+
+    out_interp_size = HelperInterpCubic<scalar_t>::interp_size;
+    return HelperInterpCubic<scalar_t>::_compute_indices_weights_aa(
+        input_size,
+        output_size,
+        stride,
+        ndims,
+        reshape_dim,
+        align_corners,
+        scale,
+        out_interp_size,
+        _filter);
+  }
+
+  // taken from
+  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+  // src/libImaging/Resample.c#L46-L62
+  static inline scalar_t _filter(scalar_t x) {
+    // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+#define a -0.5
+    if (x < 0.0) {
+      x = -x;
+    }
+    if (x < 1.0) {
+      return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
+    }
+    if (x < 2.0) {
+      return (((x - 5) * x + 8) * x - 4) * a;
+    }
+    return 0.0;
+#undef a
+  }
+};
+
 template <
     typename index_t,
     int out_ndims,
@@ -396,7 +434,7 @@ void ti_separable_upsample_generic_Nd_kernel_impl(
         index_t,
         out_ndims,
         scale_t,
-        HelperInterpLinear>(
+        F>(
         temp_output, temp_input, interp_dim, align_corners, scales, antialias);
     temp_input = temp_output;
   }
@@ -404,8 +442,7 @@ void ti_separable_upsample_generic_Nd_kernel_impl(
       index_t,
       out_ndims,
       scale_t,
-      HelperInterpLinear>(
-      output, temp_input, 2, align_corners, scales, antialias);
+      F>(output, temp_input, 2, align_corners, scales, antialias);
 }
 
 void _ti_upsample_bilinear2d_kernel_impl(
@@ -423,6 +460,21 @@ void _ti_upsample_bilinear2d_kernel_impl(
       output, input, align_corners, {scales_h, scales_w}, antialias);
 }
 
+void _ti_upsample_bicubic2d_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    bool antialias) {
+  ti_separable_upsample_generic_Nd_kernel_impl<
+      int64_t,
+      2,
+      scale_t,
+      HelperInterpCubic>(
+      output, input, align_corners, {scales_h, scales_w}, antialias);
+}
+
 } // namespace internal_upsample
 } // namespace native
 } // namespace at
@@ -463,6 +515,37 @@ at::Tensor interpolate_linear_aa_forward_kernel(
   return output;
 }
 
+at::Tensor interpolate_bicubic_aa_forward_kernel(
+    const at::Tensor& input,
+    at::IntArrayRef output_size,
+    bool align_corners) {
+  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
+
+  c10::optional<c10::ArrayRef<double>> scale_factors = {};
+
+  // Copied from UpSampleBilinear2d.cpp
+  auto output = at::empty({0}, input.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input.sizes(), output_size, scale_factors);
+  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
+  auto full_output_size =
+      at::native::upsample_2d_common_check(input.sizes(), osize);
+
+  // Allow for empty batch size but not other dimensions
+  TORCH_CHECK(
+      input.numel() != 0 ||
+          c10::multiply_integers(
+              input.sizes().begin() + 1, input.sizes().end()),
+      "Non-empty 4D data tensor expected but got a tensor with sizes ",
+      input.sizes());
+
+  output.resize_(full_output_size, input.suggest_memory_format());
+  at::native::internal_upsample::_ti_upsample_bicubic2d_kernel_impl(
+      output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
+  return output;
+}
+
 // TODO: Implement backward function
 // at::Tensor interpolate_linear_aa_backward_kernel(
 //     const at::Tensor& grad) {
@@ -475,6 +558,10 @@ TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
       TORCH_FN(interpolate_linear_aa_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic_aa"),
+      TORCH_FN(interpolate_bicubic_aa_forward_kernel));
+
   // TODO: Implement backward function
   // m.impl(
   //     TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
diff --git a/torchvision/csrc/ops/interpolate_aa.cpp b/torchvision/csrc/ops/interpolate_aa.cpp
index 209cc73cf76..90bc26a1fb5 100644
--- a/torchvision/csrc/ops/interpolate_aa.cpp
+++ b/torchvision/csrc/ops/interpolate_aa.cpp
@@ -12,11 +12,23 @@ at::Tensor interpolate_linear_aa(
 {
   static auto op = c10::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchvision::interpolate_linear_aa", "")
+                       .findSchemaOrThrow("torchvision::_interpolate_linear_aa", "")
                        .typed<decltype(interpolate_linear_aa)>();
   return op.call(input, output_size, align_corners);
 }
 
+at::Tensor interpolate_bicubic_aa(
+    const at::Tensor& input, // Input image
+    at::IntArrayRef output_size, // Output image size
+    bool align_corners) // The flag to align corners
+{
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_interpolate_bicubic_aa", "")
+          .typed<decltype(interpolate_bicubic_aa)>();
+  return op.call(input, output_size, align_corners);
+}
+
 namespace detail {
 
 // TODO: Implement backward function
@@ -33,6 +45,8 @@ namespace detail {
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
       "torchvision::_interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::_interpolate_bicubic_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
   // TODO: Implement backward function
   // m.def(TORCH_SELECTIVE_SCHEMA(
   //     "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
diff --git a/torchvision/csrc/ops/interpolate_aa.h b/torchvision/csrc/ops/interpolate_aa.h
index acadce49392..0a9ffb4b168 100644
--- a/torchvision/csrc/ops/interpolate_aa.h
+++ b/torchvision/csrc/ops/interpolate_aa.h
@@ -11,6 +11,11 @@ VISION_API at::Tensor _interpolate_linear_aa(
     at::IntArrayRef output_size,
     bool align_corners = false);
 
+VISION_API at::Tensor _interpolate_bicubic_aa(
+    const at::Tensor& input,
+    at::IntArrayRef output_size,
+    bool align_corners = false);
+
 namespace detail {
 
 // TODO: Implement backward function
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index e29065e6ec0..7aa63f539cb 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -503,8 +503,8 @@ def resize(
     if antialias is None:
         antialias = False
 
-    if antialias and interpolation not in ["bilinear", ]:
-        raise ValueError("Antialias option is supported for bilinear interpolation mode only")
+    if antialias and interpolation not in ["bilinear", "bicubic"]:
+        raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
 
     w, h = _get_image_size(img)
 
@@ -537,8 +537,10 @@ def resize(
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
     if antialias:
-        # Apply antialias for donwsampling on both dims
-        img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
+        if interpolation == "bilinear":
+            img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
+        elif interpolation == "bicubic":
+            img = torch.ops.torchvision._interpolate_bicubic_aa(img, [new_h, new_w], align_corners=False)
     else:
         img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
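
Example usage (a minimal sketch, not part of the diff): how the antialiased bicubic
path added above can be exercised from Python, assuming a torchvision build that
registers the "torchvision::_interpolate_bicubic_aa" op defined in this change. The
image shape, target size, and the F_t alias are arbitrary illustration values.

    import torch
    import torchvision.transforms.functional_tensor as F_t

    # CHW uint8 image; functional_tensor.resize handles the cast to float and the
    # unsqueeze to NCHW before dispatching to the C++ kernel, then converts back.
    img = torch.randint(0, 256, (3, 96, 72), dtype=torch.uint8)

    # With interpolation="bicubic" and antialias=True, resize now routes through
    # torch.ops.torchvision._interpolate_bicubic_aa (see functional_tensor.py above).
    out = F_t.resize(img, size=[48, 36], interpolation="bicubic", antialias=True)

    # The registered op can also be called directly; per the checks in the forward
    # kernel it expects a non-empty 4D (NCHW), floating-point, CPU tensor.
    out_direct = torch.ops.torchvision._interpolate_bicubic_aa(
        img.unsqueeze(0).float(), [48, 36], align_corners=False
    )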