From 45a643c7a6b3418cd8a10dce401dcd93481d2345 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 3 May 2021 12:31:26 +0000
Subject: [PATCH 1/6] WIP Added antialias option to transforms.functional.resize

---
 test/test_functional_tensor.py                  |  44 ++
 .../csrc/ops/cpu/interpolate_aa_kernels.cpp     | 495 ++++++++++++++++++
 torchvision/csrc/ops/interpolate_aa.cpp         |  43 ++
 torchvision/csrc/ops/interpolate_aa.h           |  24 +
 torchvision/transforms/functional.py            |  11 +-
 torchvision/transforms/functional_tensor.py     |  20 +-
 6 files changed, 633 insertions(+), 4 deletions(-)
 create mode 100644 torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
 create mode 100644 torchvision/csrc/ops/interpolate_aa.cpp
 create mode 100644 torchvision/csrc/ops/interpolate_aa.h

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 1964e3134ec..0b4995c0b98 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -472,6 +472,50 @@ def test_resize(self):
         with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
             F.resize(img, size=32, max_size=32)
 
+    def test_resize_antialias(self):
+        script_fn = torch.jit.script(F.resize)
+        tensor, pil_img = self._create_data(320, 290, device=self.device)
+
+        for dt in [None, torch.float32, torch.float64, torch.float16]:
+
+            if dt == torch.float16 and torch.device(self.device).type == "cpu":
+                # skip float16 on CPU case
+                continue
+
+            if dt is not None:
+                # This is a trivial cast to float of uint8 data to test all cases
+                tensor = tensor.to(dt)
+
+            for size in [[96, 72], ]:
+                for interpolation in [BILINEAR, ]:
+                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
+                    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
+
+                    self.assertEqual(
+                        resized_tensor.size()[1:], resized_pil_img.size[::-1],
+                        msg=f"{size}, {interpolation}, {dt}"
+                    )
+
+                    resized_tensor_f = resized_tensor
+                    # we need to cast to uint8 to compare with PIL image
+                    if resized_tensor_f.dtype == torch.uint8:
+                        resized_tensor_f = resized_tensor_f.to(torch.float)
+
+                    self.approxEqualTensorToPIL(
+                        resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
+                    )
+                    self.approxEqualTensorToPIL(
+                        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max", msg=f"{size}, {interpolation}, {dt}"
+                    )
+
+                    if isinstance(size, int):
+                        script_size = [size, ]
+                    else:
+                        script_size = size
+
+                    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
+                    self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+
     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
new file mode 100644
index 00000000000..ededbea1f7b
--- /dev/null
+++ b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
@@ -0,0 +1,495 @@
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+// Code is temporarily in torchvision before being merged into PyTorch
+namespace at {
+namespace native {
+namespace internal_upsample {
+
+using scale_t = std::vector<c10::optional<double>>;
+
+template <typename scalar_t, typename index_t>
+static inline scalar_t interpolate_aa_single_dim_zero_strides(
+    char* src,
+    char** data,
+    int64_t i,
+    const index_t ids_stride) {
+  const index_t ids_min = *(index_t*)&data[0][0];
+  const index_t ids_size = *(index_t*)&data[1][0];
+
+  char* src_min = src + ids_min;
+
+  scalar_t t = *(scalar_t*)&src_min[0];
+  index_t wts_idx = *(index_t*)&data[4][0];
+  char* wts_ptr = &data[3][wts_idx];
+  scalar_t wts = *(scalar_t*)&wts_ptr[0];
+
+  scalar_t output = t * wts;
+  int j = 1;
+
+  // Using partial loop unroll gives a small speed-up
+  for (; j < 2; j++) {
+    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
+    t = *(scalar_t*)&src_min[j * ids_stride];
+    output += t * wts;
+  }
+  for (; j < ids_size; j++) {
+    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
+    t = *(scalar_t*)&src_min[j * ids_stride];
+    output += t * wts;
+  }
+  return output;
+}
+
+template <typename scalar_t, typename index_t>
+static inline scalar_t interpolate_aa_single_dim(
+    char* src,
+    char** data,
+    const int64_t* strides,
+    int64_t i,
+    const index_t ids_stride) {
+  index_t ids_min = *(index_t*)&data[0][i * strides[0]];
+  index_t ids_size = *(index_t*)&data[1][i * strides[1]];
+
+  char* src_min = src + ids_min;
+
+  scalar_t t = *(scalar_t*)&src_min[0];
+  index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
+  char* wts_ptr = &data[3][wts_idx];
+  scalar_t wts = *(scalar_t*)&wts_ptr[0];
+
+  scalar_t output = t * wts;
+  int j = 1;
+  // Using partial loop unroll gives a small speed-up
+  for (; j < 2; j++) {
+    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
+    t = *(scalar_t*)&src_min[j * ids_stride];
+    output += t * wts;
+  }
+  for (; j < ids_size; j++) {
+    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
+    t = *(scalar_t*)&src_min[j * ids_stride];
+    output += t * wts;
+  }
+  return output;
+}
+
+template <typename scalar_t, typename index_t>
+static inline void basic_loop_aa_single_dim_zero_strides(
+    char** data,
+    const int64_t* strides,
+    int64_t n) {
+  char* dst = data[0];
+  char* src = data[1];
+  // index stride is constant for the given dimension
+  const index_t ids_stride = *(index_t*)&data[2 + 2][0];
+
+  for (int64_t i = 0; i < n; i++) {
+    *(scalar_t*)&dst[i * strides[0]] =
+        interpolate_aa_single_dim_zero_strides<scalar_t, index_t>(
+            src + i * strides[1], &data[2], i, ids_stride);
+  }
+}
+
+template <typename scalar_t, typename index_t>
+static inline void basic_loop_aa_single_dim_nonzero_strides(
+    char** data,
+    const int64_t* strides,
+    int64_t n) {
+  char* dst = data[0];
+  char* src = data[1];
+  // index stride is constant for the given dimension
+  const index_t ids_stride = *(index_t*)&data[2 + 2][0];
+
+  if (strides[1] == 0) {
+    for (int64_t i = 0; i < n; i++) {
+      *(scalar_t*)&dst[i * strides[0]] =
+          interpolate_aa_single_dim<scalar_t, index_t>(
+              src, &data[2], &strides[2], i, ids_stride);
+    }
+  } else {
+    for (int64_t i = 0; i < n; i++) {
+      *(scalar_t*)&dst[i * strides[0]] =
+          interpolate_aa_single_dim<scalar_t, index_t>(
+              src + i * strides[1], &data[2], &strides[2], i, ids_stride);
+    }
+  }
+}
+
+template <int m>
+static inline bool is_zero_stride(const int64_t* strides) {
+  bool output = strides[0] == 0;
+  for (int i = 1; i < m; i++) {
+    output &= (strides[i] == 0);
+  }
+  return output;
+}
+
+template <typename scalar_t, typename index_t>
+void ti_cpu_upsample_generic_aa(
+    at::TensorIterator& iter,
+    int interp_size = -1) {
+  TORCH_INTERNAL_ASSERT(interp_size > 0);
+
+  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
+    if ((strides[0] == sizeof(scalar_t)) && (strides[1] == sizeof(scalar_t)) &&
+        is_zero_stride<3 + 2>(&strides[2])) {
+      basic_loop_aa_single_dim_zero_strides<scalar_t, index_t>(
+          data, strides, n);
+    } else {
+      basic_loop_aa_single_dim_nonzero_strides<scalar_t, index_t>(
+          data, strides, n);
+    }
+  };
+
+  iter.for_each(loop);
+}
+
+// Helper structs to use with ti_upsample_generic_Nd_kernel_impl
+template <typename index_t, typename scalar_t>
+struct HelperInterpBase {
+  static inline void init_indices_weights(
+      std::vector<Tensor>& output,
+      int64_t output_size,
+      int64_t ndims,
+      int64_t reshape_dim,
+      int interp_size) {
+    auto new_shape = std::vector<int64_t>(ndims, 1);
+    new_shape[reshape_dim] = output_size;
+
+    for (int j = 0; j < interp_size; j++) {
+      output.emplace_back(
+          empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
+      output.emplace_back(
+          empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
+    }
+  }
+};
+
+template <typename index_t, typename scalar_t>
+struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
+  static const int interp_size = 2;
+
+  static inline std::vector<Tensor> compute_indices_weights(
+      int64_t input_size,
+      int64_t output_size,
+      int64_t stride,
+      int64_t ndims,
+      int64_t reshape_dim,
+      bool align_corners,
+      const c10::optional<double> opt_scale,
+      bool antialias,
+      int& out_interp_size) {
+    scalar_t scale = area_pixel_compute_scale<scalar_t>(
+        input_size, output_size, align_corners, opt_scale);
+    TORCH_INTERNAL_ASSERT(antialias && scale > 1.0);
+
+    return _compute_indices_weights_aa(
+        input_size,
+        output_size,
+        stride,
+        ndims,
+        reshape_dim,
+        align_corners,
+        scale,
+        out_interp_size);
+  }
+
+  // taken from
+  // https://github.com/python-pillow/Pillow/blob/6812205f18ca4ef54372e87e1a13ce4a859434df/
+  // src/libImaging/Resample.c#L20-L29
+  static inline scalar_t _filter(scalar_t x) {
+    if (x < 0.0) {
+      x = -x;
+    }
+    if (x < 1.0) {
+      return 1.0 - x;
+    }
+    return 0.0;
+  }
+
+  static inline std::vector<Tensor> _compute_indices_weights_aa(
+      int64_t input_size,
+      int64_t output_size,
+      int64_t stride,
+      int64_t ndims,
+      int64_t reshape_dim,
+      bool align_corners,
+      scalar_t scale,
+      int& out_interp_size) {
+    int interp_size = HelperInterpLinear::interp_size;
+    scalar_t support = (interp_size / 2) * scale;
+    interp_size = (int)ceilf(support) * 2 + 1;
+
+    // return interp_size
+    out_interp_size = interp_size;
+
+    std::vector<Tensor> output;
+    auto new_shape = std::vector<int64_t>(ndims, 1);
+    new_shape[reshape_dim] = output_size;
+
+    // ---- Bounds approach as in PIL -----
+    // bounds: xmin/xmax
+    output.emplace_back(
+        empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
+    output.emplace_back(
+        empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
+    output.emplace_back(
+        empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
+
+    {
+      // Weights
+      new_shape[reshape_dim] = output_size * interp_size;
+      auto wts = empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>()));
+      auto strides = wts.strides().vec();
+      strides[reshape_dim] = 0;
+      new_shape[reshape_dim] = output_size;
+      wts = wts.as_strided(new_shape, strides);
+      output.emplace_back(wts);
+      // Weights indices
+      output.emplace_back(
+          empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
+    }
+
+    scalar_t center, total_w, invscale = 1.0 / scale;
+    index_t zero = static_cast<index_t>(0);
+    int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
+    int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
+    int64_t* idx_ptr_stride = output[2].data_ptr<int64_t>();
+    scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
+    int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
+
+    int64_t xmin, xmax, j;
+
+    for (int64_t i = 0; i < output_size; i++) {
+      center = scale * (i + 0.5);
+      xmin = std::max(static_cast<index_t>(center - support + 0.5), zero);
+      xmax =
+          std::min(static_cast<index_t>(center + support + 0.5), input_size) -
+          xmin;
+      idx_ptr_xmin[i] = xmin * stride;
+      idx_ptr_size[i] = xmax;
+      idx_ptr_stride[i] = stride;
+
+      wt_idx_ptr[i] = i * interp_size * sizeof(scalar_t);
+
+      total_w = 0.0;
+      for (j = 0; j < xmax; j++) {
+        scalar_t w = _filter((j + xmin - center + 0.5) * invscale);
+        wt_ptr[i * interp_size + j] = w;
+        total_w += w;
+      }
+      for (j = 0; j < xmax; j++) {
+        if (total_w != 0.0) {
+          wt_ptr[i * interp_size + j] /= total_w;
+        }
+      }
+
+      for (; j < interp_size; j++) {
+        wt_ptr[i * interp_size + j] = static_cast<scalar_t>(0.0);
+      }
+    }
+    return output;
+  }
+};
+
+template <
+    typename index_t,
+    int out_ndims,
+    typename scale_type,
+    template <typename, typename>
+    class F>
+void _ti_separable_upsample_generic_Nd_kernel_impl_single_dim(
+    Tensor& output,
+    const Tensor& input,
+    int interp_dim,
+    bool align_corners,
+    const scale_type& scales,
+    bool antialias) {
+  // input can be NCHW, NCL or NCKHW
+  auto shape = input.sizes().vec();
+  auto strides = input.strides().vec();
+  auto oshape = output.sizes();
+
+  TORCH_INTERNAL_ASSERT(
+      shape.size() == oshape.size() && shape.size() == 2 + out_ndims);
+  TORCH_INTERNAL_ASSERT(strides.size() == 2 + out_ndims);
+  TORCH_INTERNAL_ASSERT(antialias);
+
+  for (int i = 0; i < out_ndims; i++) {
+    shape[i + 2] = oshape[i + 2];
+  }
+  strides[interp_dim] = 0;
+  auto restrided_input = input.as_strided(shape, strides);
+
+  std::vector<std::vector<Tensor>> indices_weights;
+
+  int interp_size = F<index_t, float>::interp_size;
+  auto input_scalar_type = input.scalar_type();
+
+  if (interp_size == 1 && input_scalar_type == at::ScalarType::Byte) {
+    // nearest also supports uint8 tensor, but we have to use float
+    // with compute_indices_weights
+    input_scalar_type = at::ScalarType::Float;
+  }
+
+  AT_DISPATCH_FLOATING_TYPES_AND(
+      at::ScalarType::Byte,
+      input_scalar_type,
+      "compute_indices_weights_generic",
+      [&] {
+        indices_weights.emplace_back(
+            F<index_t, scalar_t>::compute_indices_weights(
+                input.size(interp_dim),
+                oshape[interp_dim],
+                input.stride(interp_dim) * input.element_size(),
+                input.dim(),
+                interp_dim,
+                align_corners,
+                scales[interp_dim - 2],
+                antialias,
+                interp_size));
+      });
+
+  TensorIteratorConfig config;
+  config.check_all_same_dtype(false)
+      .declare_static_dtype_and_device(input.scalar_type(), input.device())
+      .add_output(output)
+      .add_input(restrided_input);
+
+  for (auto& idx_weight : indices_weights) {
+    for (auto& tensor : idx_weight) {
+      config.add_input(tensor);
+    }
+  }
+
+  auto iter = config.build();
+
+  if (interp_size > 1) {
+    // Nearest also supports uint8 tensor, so need to handle it separately
+    AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "upsample_generic_Nd", [&] {
+      ti_cpu_upsample_generic_aa<scalar_t, index_t>(iter, interp_size);
+    });
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd", [&] {
+          ti_cpu_upsample_generic_aa<scalar_t, index_t>(iter, interp_size);
+        });
+  }
+}
+
+template <
+    typename index_t,
+    int out_ndims,
+    typename scale_type,
+    template <typename, typename>
+    class F>
+void ti_separable_upsample_generic_Nd_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    bool align_corners,
+    const scale_type& scales,
+    bool antialias) {
+  auto temp_oshape = input.sizes().vec();
+  at::Tensor temp_output, temp_input = input;
+  for (int i = 0; i < out_ndims - 1; i++) {
+    int interp_dim = 2 + out_ndims - 1 - i;
+    temp_oshape[interp_dim] = output.sizes()[interp_dim];
+    temp_output = at::empty(temp_oshape, input.options());
+    _ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
+        index_t,
+        out_ndims,
+        scale_t,
+        HelperInterpLinear>(
+        temp_output, temp_input, interp_dim, align_corners, scales, antialias);
+    temp_input = temp_output;
+  }
+  _ti_separable_upsample_generic_Nd_kernel_impl_single_dim<
+      index_t,
+      out_ndims,
+      scale_t,
+      HelperInterpLinear>(
+      output, temp_input, 2, align_corners, scales, antialias);
+}
+
+void _ti_upsample_bilinear2d_kernel_impl(
+    Tensor& output,
+    const Tensor& input,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    bool antialias) {
+  ti_separable_upsample_generic_Nd_kernel_impl<
+      int64_t,
+      2,
+      scale_t,
+      HelperInterpLinear>(
+      output, input, align_corners, {scales_h, scales_w}, antialias);
+}
+
+} // namespace internal_upsample
+} // namespace native
+} // namespace at
+
+namespace vision {
+namespace ops {
+
+namespace {
+
+at::Tensor interpolate_linear_aa_forward_kernel(
+    const at::Tensor& input,
+    at::IntArrayRef output_size,
+    bool align_corners) {
+  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
+
+  c10::optional<c10::ArrayRef<double>> scale_factors = {};
+
+  // Copied from UpSampleBilinear2d.cpp
+  auto output = at::empty({0}, input.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input.sizes(), output_size, scale_factors);
+  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
+  auto full_output_size =
+      at::native::upsample_2d_common_check(input.sizes(), osize);
+
+  // Allow for empty batch size but not other dimensions
+  TORCH_CHECK(
+      input.numel() != 0 ||
+          c10::multiply_integers(
+              input.sizes().begin() + 1, input.sizes().end()),
+      "Non-empty 4D data tensor expected but got a tensor with sizes ",
+      input.sizes());
+
+  output.resize_(full_output_size, input.suggest_memory_format());
+  at::native::internal_upsample::_ti_upsample_bilinear2d_kernel_impl(
+      output, input, align_corners, scale_h, scale_w, /*antialias=*/true);
+  return output;
+}
+
+// at::Tensor interpolate_linear_aa_backward_kernel(
+//     const at::Tensor& grad) {
+//   return grad_input;
+// }
+
+} // namespace
+
+TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::interpolate_linear_aa"),
+      TORCH_FN(interpolate_linear_aa_forward_kernel));
+  // m.impl(
+  //     TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
+  //     TORCH_FN(interpolate_linear_aa_backward_kernel));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/interpolate_aa.cpp b/torchvision/csrc/ops/interpolate_aa.cpp
new file mode 100644
index 00000000000..58bec2aa33d
--- /dev/null
+++ b/torchvision/csrc/ops/interpolate_aa.cpp
@@ -0,0 +1,43 @@
+#include "interpolate_aa.h"
+
+#include
+
+namespace vision {
+namespace ops {
+
+at::Tensor interpolate_linear_aa(
+    const at::Tensor& input, // Input image
+    at::IntArrayRef output_size, // Output image size
+    bool align_corners) // The flag to align corners
+{
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::interpolate_linear_aa", "")
+          .typed<decltype(interpolate_linear_aa)>();
+  return op.call(input, output_size, align_corners);
+}
+
+namespace detail {
+
+// at::Tensor _interpolate_linear_aa_backward(
+//     const at::Tensor& grad,
+//     at::IntArrayRef output_size,
+//     bool align_corners)
+// {
+//   return at::Tensor();
+// }
+
+} // namespace detail
+
+TORCH_LIBRARY_FRAGMENT(torchvision, m) {
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+  // m.def(TORCH_SELECTIVE_SCHEMA(
+  //     "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
+  //     float spatial_scale, int pooled_height, int pooled_width, int
+  //     batch_size, int channels, int height, int width, int sampling_ratio,
+  //     bool aligned) -> Tensor"));
+}
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/csrc/ops/interpolate_aa.h b/torchvision/csrc/ops/interpolate_aa.h
new file mode 100644
index 00000000000..95066e64ebe
--- /dev/null
+++ b/torchvision/csrc/ops/interpolate_aa.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include
+#include "../macros.h"
+
+namespace vision {
+namespace ops {
+
+VISION_API at::Tensor interpolate_linear_aa(
+    const at::Tensor& input,
+    at::IntArrayRef output_size,
+    bool align_corners = false);
+
+namespace detail {
+
+// at::Tensor _interpolate_linear_aa_backward(
+//     const at::Tensor& grad,
+//     at::IntArrayRef output_size,
+//     bool align_corners=false);
+
+} // namespace detail
+
+} // namespace ops
+} // namespace vision
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 855ce19bde4..b735196d44f 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -341,7 +341,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
 
 
 def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
-           max_size: Optional[int] = None) -> Tensor:
+           max_size: Optional[int] = None, antialias: Optional[bool] = None) -> Tensor:
     r"""Resize the input image to the given size.
     If the image is torch Tensor, it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -375,6 +375,9 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
             smaller edge may be shorter than ``size``. This is only supported if ``size`` is an int
             (or a sequence of length 1 in torchscript mode).
+        antialias (bool, optional): antialias flag. If ``img`` is a PIL Image, the flag is ignored and
+            anti-aliasing is always applied. If ``img`` is a Tensor, the flag is False by default and
+            can be set to True only for ``InterpolationMode.BILINEAR``.
 
     Returns:
         PIL Image or Tensor: Resized image.
@@ -391,10 +394,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
         raise TypeError("Argument interpolation should be a InterpolationMode")
 
     if not isinstance(img, torch.Tensor):
+        if antialias is not None:
+            warnings.warn(
+                "Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
+            )
         pil_interpolation = pil_modes_mapping[interpolation]
         return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
 
-    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size)
+    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
 
 
 def scale(*args, **kwargs):
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 156d49150bc..97c3a4a6992 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -470,7 +470,13 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     return img
 
 
-def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None) -> Tensor:
+def resize(
+    img: Tensor,
+    size: List[int],
+    interpolation: str = "bilinear",
+    max_size: Optional[int] = None,
+    antialias: Optional[bool] = None
+) -> Tensor:
     _assert_image_tensor(img)
 
     if not isinstance(size, (int, tuple, list)):
@@ -494,6 +500,12 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
             "i.e. size should be an int or a sequence of length 1 in torchscript mode."
         )
+    if antialias is None:
+        antialias = False
+
+    if antialias and interpolation not in ["bilinear", ]:
+        raise ValueError("Antialias option is supported for bilinear interpolation mode only")
+
     w, h = _get_image_size(img)
     if isinstance(size, int) or len(size) == 1:
         # specified size only for the smallest edge
@@ -524,7 +536,11 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_si
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
+    if antialias and (new_w < w and new_h < h):
+        # Apply antialias for downsampling on both dims
+        img = torch.ops.torchvision.interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
+    else:
+        img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
 
     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)

From 5cb18b5a2b2888c1cab2cdb3f35ba610d999351b Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 3 May 2021 14:10:04 +0000
Subject: [PATCH 2/6] Updates according to the review

---
 test/test_functional_tensor.py                  |  9 +++++++--
 test/test_transforms.py                         |  4 ++++
 .../csrc/ops/cpu/interpolate_aa_kernels.cpp     |  4 +++-
 torchvision/csrc/ops/interpolate_aa.cpp         |  4 +++-
 torchvision/csrc/ops/interpolate_aa.h           |  3 ++-
 torchvision/transforms/functional.py            |  5 ++++-
 torchvision/transforms/transforms.py            | 15 +++++++++++----
 7 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 0b4995c0b98..4dcc307d723 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -473,6 +473,10 @@ def test_resize(self):
             F.resize(img, size=32, max_size=32)
 
     def test_resize_antialias(self):
+
+        if self.device == "cuda":
+            self.skipTest("Not implemented for CUDA device")
+
         script_fn = torch.jit.script(F.resize)
         tensor, pil_img = self._create_data(320, 290, device=self.device)
 
@@ -486,7 +490,7 @@ def test_resize_antialias(self):
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
 
-            for size in [[96, 72], ]:
+            for size in [[96, 72], [96, 420]]:
                 for interpolation in [BILINEAR, ]:
                     resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
                     resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
@@ -505,7 +509,8 @@ def test_resize_antialias(self):
                         resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
                     )
                     self.approxEqualTensorToPIL(
-                        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max", msg=f"{size}, {interpolation}, {dt}"
+                        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
+                        msg=f"{size}, {interpolation}, {dt}"
                     )
 
                     if isinstance(size, int):
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 0a01247aa87..9402a37bc35 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -348,6 +348,10 @@ def test_resize(self):
 
                 self.assertEqual((owidth, oheight), result.size)
 
+        with self.assertWarnsRegex(UserWarning, r"Anti-alias option is always applied for PIL Image input"):
+            t = transforms.Resize(osize, antialias=False)
+            t(img)
+
     def test_random_crop(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
index ededbea1f7b..9b134a2c539 100644
--- a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
+++ b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
@@ -475,6 +475,7 @@ at::Tensor interpolate_linear_aa_forward_kernel(
   return output;
 }
 
+// TODO: Implement backward function
 // at::Tensor interpolate_linear_aa_backward_kernel(
 //     const at::Tensor& grad) {
 //   return grad_input;
@@ -484,8 +485,9 @@ at::Tensor interpolate_linear_aa_forward_kernel(
 
 TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
   m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::interpolate_linear_aa"),
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
       TORCH_FN(interpolate_linear_aa_forward_kernel));
+  // TODO: Implement backward function
   // m.impl(
   //     TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
   //     TORCH_FN(interpolate_linear_aa_backward_kernel));
diff --git a/torchvision/csrc/ops/interpolate_aa.cpp b/torchvision/csrc/ops/interpolate_aa.cpp
index 58bec2aa33d..209cc73cf76 100644
--- a/torchvision/csrc/ops/interpolate_aa.cpp
+++ b/torchvision/csrc/ops/interpolate_aa.cpp
@@ -19,6 +19,7 @@ at::Tensor interpolate_linear_aa(
 
 namespace detail {
 
+// TODO: Implement backward function
 // at::Tensor _interpolate_linear_aa_backward(
 //     const at::Tensor& grad,
 //     at::IntArrayRef output_size,
@@ -31,7 +32,8 @@ namespace detail {
 
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+      "torchvision::_interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+  // TODO: Implement backward function
   // m.def(TORCH_SELECTIVE_SCHEMA(
   //     "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
   //     float spatial_scale, int pooled_height, int pooled_width, int
diff --git a/torchvision/csrc/ops/interpolate_aa.h b/torchvision/csrc/ops/interpolate_aa.h
index 95066e64ebe..acadce49392 100644
--- a/torchvision/csrc/ops/interpolate_aa.h
+++ b/torchvision/csrc/ops/interpolate_aa.h
@@ -6,13 +6,14 @@
 namespace vision {
 namespace ops {
 
-VISION_API at::Tensor interpolate_linear_aa(
+VISION_API at::Tensor _interpolate_linear_aa(
     const at::Tensor& input,
     at::IntArrayRef output_size,
     bool align_corners = false);
 
 namespace detail {
 
+// TODO: Implement backward function
 // at::Tensor _interpolate_linear_aa_backward(
 //     const at::Tensor& grad,
 //     at::IntArrayRef output_size,
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index b735196d44f..17dd649159e 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -379,6 +379,9 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
             anti-aliasing is always applied. If ``img`` is a Tensor, the flag is False by default and
             can be set to True only for ``InterpolationMode.BILINEAR``.
 
+        .. warning::
+            There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
+
     Returns:
         PIL Image or Tensor: Resized image.
 
@@ -394,7 +397,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
         raise TypeError("Argument interpolation should be a InterpolationMode")
 
     if not isinstance(img, torch.Tensor):
-        if antialias is not None:
+        if antialias is not None and not antialias:
             warnings.warn(
                 "Anti-alias option is always applied for PIL Image input. Argument antialias is ignored."
             )
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index dd87cc2b82c..4e013227be1 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -257,10 +257,16 @@ class Resize(torch.nn.Module):
             smaller edge may be shorter than ``size``. This is only supported if ``size`` is an int
             (or a sequence of length 1 in torchscript mode).
+        antialias (bool, optional): antialias flag. If ``img`` is a PIL Image, the flag is ignored and
+            anti-aliasing is always applied. If ``img`` is a Tensor, the flag is False by default and
+            can be set to True only for ``InterpolationMode.BILINEAR``.
+
+        .. warning::
+            There is no autodiff support for ``antialias=True`` option with input ``img`` as Tensor.
 
     """
 
-    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None):
+    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None, antialias=None):
         super().__init__()
         if not isinstance(size, (int, Sequence)):
             raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
@@ -278,6 +284,7 @@ def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None
             interpolation = _interpolation_modes_from_int(interpolation)
 
         self.interpolation = interpolation
+        self.antialias = antialias
 
     def forward(self, img):
         """
@@ -287,12 +294,12 @@ def forward(self, img):
         Returns:
             PIL Image or Tensor: Rescaled image.
         """
-        return F.resize(img, self.size, self.interpolation, self.max_size)
+        return F.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
 
     def __repr__(self):
         interpolate_str = self.interpolation.value
-        return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(
-            self.size, interpolate_str, self.max_size)
+        return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format(
+            self.size, interpolate_str, self.max_size, self.antialias)
 
 
 class Scale(Resize):

From 0e10604f8a60df32e8b37685aaca8779bca04903 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 3 May 2021 14:49:36 +0000
Subject: [PATCH 3/6] Excluded interpolate_aa C++ files from the iOS build

---
 ios/CMakeLists.txt | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/ios/CMakeLists.txt b/ios/CMakeLists.txt
index 6b9fd3925b2..8cff8d6ec79 100644
--- a/ios/CMakeLists.txt
+++ b/ios/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.4.1)
+cmake_minimum_required(VERSION 3.6)
 set(TARGET torchvision_ops)
 project(${TARGET} CXX)
 set(CMAKE_CXX_STANDARD 14)
@@ -11,6 +11,12 @@ file(GLOB VISION_SRCS
   ../torchvision/csrc/ops/*.h
   ../torchvision/csrc/ops/*.cpp)
 
+# Remove interpolate_aa sources as they are temporary code
+# see https://github.com/pytorch/vision/pull/3761
+# and using TensorIterator unavailable with iOS
+# FILTER was added in CMake>=3.6 => 3.4.1 -> 3.6
+list(FILTER VISION_SRCS EXCLUDE REGEX ".+(interpolate_aa).+")
+
 add_library(${TARGET} STATIC
   ${VISION_SRCS}
 )

From c2247de434672b3d001c7348738fe64e472aca3f Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 3 May 2021 17:21:22 +0000
Subject: [PATCH 4/6] Added support for mixed downsampling/upsampling

---
 test/test_functional_tensor.py                      | 4 ++--
 torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp | 7 ++++---
 torchvision/transforms/functional_tensor.py         | 4 ++--
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 4dcc307d723..d1b5430012a 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -490,7 +490,7 @@ def test_resize_antialias(self):
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
 
-            for size in [[96, 72], [96, 420]]:
+            for size in [[96, 72], [96, 420], [420, 72]]:
                 for interpolation in [BILINEAR, ]:
                     resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
                     resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
@@ -519,7 +519,7 @@ def test_resize_antialias(self):
                         script_size = size
 
                     resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
-                    self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+                    self.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
 
     def test_resized_crop(self):
diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
index 9b134a2c539..cee831a590c 100644
--- a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
+++ b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
@@ -188,7 +188,7 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
       int& out_interp_size) {
     scalar_t scale = area_pixel_compute_scale<scalar_t>(
         input_size, output_size, align_corners, opt_scale);
-    TORCH_INTERNAL_ASSERT(antialias && scale > 1.0);
+    TORCH_INTERNAL_ASSERT(antialias);
 
     return _compute_indices_weights_aa(
         input_size,
@@ -224,7 +224,8 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
       int& out_interp_size) {
     int interp_size = HelperInterpLinear::interp_size;
-    scalar_t support = (interp_size / 2) * scale;
+    scalar_t support =
+        (scale > 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
     interp_size = (int)ceilf(support) * 2 + 1;
 
     // return interp_size
@@ -258,7 +258,7 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
           empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
     }
 
-    scalar_t center, total_w, invscale = 1.0 / scale;
+    scalar_t center, total_w, invscale = (scale > 1.0) ? 1.0 / scale : 1.0;
     index_t zero = static_cast<index_t>(0);
     int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
     int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 97c3a4a6992..e29065e6ec0 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -536,9 +536,9 @@ def resize(
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    if antialias and (new_w < w and new_h < h):
+    if antialias:
         # Apply antialias for downsampling on both dims
-        img = torch.ops.torchvision.interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
+        img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
     else:
         img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)

From 379d91f1b56d9a32d45afed763b0f5fd3facf499 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 4 May 2021 14:05:42 +0000
Subject: [PATCH 5/6] Fixed heap overflow caused by explicit loop unrolling

---
 .../csrc/ops/cpu/interpolate_aa_kernels.cpp | 17 ++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
index cee831a590c..29d390165d0 100644
--- a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
+++ b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
@@ -32,13 +32,6 @@ static inline scalar_t interpolate_aa_single_dim_zero_strides(
 
   scalar_t output = t * wts;
   int j = 1;
-
-  // Using partial loop unroll gives a small speed-up
-  for (; j < 2; j++) {
-    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
-    t = *(scalar_t*)&src_min[j * ids_stride];
-    output += t * wts;
-  }
   for (; j < ids_size; j++) {
     wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
     t = *(scalar_t*)&src_min[j * ids_stride];
     output += t * wts;
@@ -66,12 +59,6 @@ static inline scalar_t interpolate_aa_single_dim(
 
   scalar_t output = t * wts;
   int j = 1;
-  // Using partial loop unroll gives a small speed-up
-  for (; j < 2; j++) {
-    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
-    t = *(scalar_t*)&src_min[j * ids_stride];
-    output += t * wts;
-  }
   for (; j < ids_size; j++) {
     wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
     t = *(scalar_t*)&src_min[j * ids_stride];
     output += t * wts;
@@ -225,7 +212,7 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
       int& out_interp_size) {
     int interp_size = HelperInterpLinear::interp_size;
     scalar_t support =
-        (scale > 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
+        (scale >= 1.0) ? (interp_size / 2) * scale : interp_size / 2 * 1.0;
     interp_size = (int)ceilf(support) * 2 + 1;
 
     // return interp_size
@@ -259,7 +245,7 @@ struct HelperInterpLinear : public HelperInterpBase<index_t, scalar_t> {
           empty(new_shape, CPU(c10::CppTypeToScalarType<index_t>())));
     }
 
-    scalar_t center, total_w, invscale = (scale > 1.0) ? 1.0 / scale : 1.0;
+    scalar_t center, total_w, invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
     index_t zero = static_cast<index_t>(0);
     int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
     int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();

From 4b641edd6b8e9c0e9102332f22df2cc9b1d8842a Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 10 May 2021 10:57:32 +0000
Subject: [PATCH 6/6] Applied PR review suggestions

- used pytest parametrize instead of unittest
- cast to scalar_t ptr
- removed interpolate_aa files for iOS/Android, keeping the original cmake version

---
 android/ops/CMakeLists.txt                          |  7 ++
 ios/CMakeLists.txt                                  |  7 +-
 test/test_functional_tensor.py                      | 96 +++++++++----------
 .../csrc/ops/cpu/interpolate_aa_kernels.cpp         | 12 +--
 4 files changed, 64 insertions(+), 58 deletions(-)

diff --git a/android/ops/CMakeLists.txt b/android/ops/CMakeLists.txt
index ad42adbfa71..3210925a85c 100644
--- a/android/ops/CMakeLists.txt
+++ b/android/ops/CMakeLists.txt
@@ -14,6 +14,13 @@ file(GLOB VISION_SRCS
   ../../torchvision/csrc/ops/*.h
   ../../torchvision/csrc/ops/*.cpp)
 
+# Remove interpolate_aa sources as they are temporary code
+# see https://github.com/pytorch/vision/pull/3761
+# and IndexingUtils.h is unavailable on Android build
+list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
+list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.cpp")
+list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../../torchvision/csrc/ops/interpolate_aa.h")
+
 add_library(${TARGET} SHARED
   ${VISION_SRCS}
 )
diff --git a/ios/CMakeLists.txt b/ios/CMakeLists.txt
index 8cff8d6ec79..2ac46c15018 100644
--- a/ios/CMakeLists.txt
+++ b/ios/CMakeLists.txt
@@ -1,4 +1,4 @@
-cmake_minimum_required(VERSION 3.6)
+cmake_minimum_required(VERSION 3.4.1)
 set(TARGET torchvision_ops)
 project(${TARGET} CXX)
 set(CMAKE_CXX_STANDARD 14)
@@ -14,8 +14,9 @@ file(GLOB VISION_SRCS
 # Remove interpolate_aa sources as they are temporary code
 # see https://github.com/pytorch/vision/pull/3761
 # and using TensorIterator unavailable with iOS
-# FILTER was added in CMake>=3.6 => 3.4.1 -> 3.6
-list(FILTER VISION_SRCS EXCLUDE REGEX ".+(interpolate_aa).+")
+list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp")
+list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.cpp")
+list(REMOVE_ITEM VISION_SRCS "${CMAKE_CURRENT_LIST_DIR}/../torchvision/csrc/ops/interpolate_aa.h")
 
 add_library(${TARGET} STATIC
   ${VISION_SRCS}
diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index d1b5430012a..f28b6e633d9 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -472,55 +472,6 @@ def test_resize(self):
         with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
             F.resize(img, size=32, max_size=32)
 
-    def test_resize_antialias(self):
-
-        if self.device == "cuda":
-            self.skipTest("Not implemented for CUDA device")
-
-        script_fn = torch.jit.script(F.resize)
-        tensor, pil_img = self._create_data(320, 290, device=self.device)
-
-        for dt in [None, torch.float32, torch.float64, torch.float16]:
-
-            if dt == torch.float16 and torch.device(self.device).type == "cpu":
-                # skip float16 on CPU case
-                continue
-
-            if dt is not None:
-                # This is a trivial cast to float of uint8 data to test all cases
-                tensor = tensor.to(dt)
-
-            for size in [[96, 72], [96, 420], [420, 72]]:
-                for interpolation in [BILINEAR, ]:
-                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
-                    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
-
-                    self.assertEqual(
-                        resized_tensor.size()[1:], resized_pil_img.size[::-1],
-                        msg=f"{size}, {interpolation}, {dt}"
-                    )
-
-                    resized_tensor_f = resized_tensor
-                    # we need to cast to uint8 to compare with PIL image
-                    if resized_tensor_f.dtype == torch.uint8:
-                        resized_tensor_f = resized_tensor_f.to(torch.float)
-
-                    self.approxEqualTensorToPIL(
-                        resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
-                    )
-                    self.approxEqualTensorToPIL(
-                        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
-                        msg=f"{size}, {interpolation}, {dt}"
-                    )
-
-                    if isinstance(size, int):
-                        script_size = [size, ]
-                    else:
-                        script_size = size
-
-                    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
-                    self.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
-
     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
@@ -1067,5 +1018,52 @@ def test_perspective_interpolation_warning(tester):
     tester.assertTrue(res1.equal(res2))
 
 
+@pytest.mark.parametrize('device', ["cpu", ])
+@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16])
+@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]])
+@pytest.mark.parametrize('interpolation', [BILINEAR, ])
+def test_resize_antialias(device, dt, size, interpolation, tester):
+
+    if dt == torch.float16 and device == "cpu":
+        # skip float16 on CPU case
+        return
+
+    script_fn = torch.jit.script(F.resize)
+    tensor, pil_img = tester._create_data(320, 290, device=device)
+
+    if dt is not None:
+        # This is a trivial cast to float of uint8 data to test all cases
+        tensor = tensor.to(dt)
+
+    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, antialias=True)
+    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
+
+    tester.assertEqual(
+        resized_tensor.size()[1:], resized_pil_img.size[::-1],
+        msg=f"{size}, {interpolation}, {dt}"
+    )
+
+    resized_tensor_f = resized_tensor
+    # we need to cast to uint8 to compare with PIL image
+    if resized_tensor_f.dtype == torch.uint8:
+        resized_tensor_f = resized_tensor_f.to(torch.float)
+
+    tester.approxEqualTensorToPIL(
+        resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}"
+    )
+    tester.approxEqualTensorToPIL(
+        resized_tensor_f, resized_pil_img, tol=1.0 + 1e-5, agg_method="max",
+        msg=f"{size}, {interpolation}, {dt}"
+    )
+
+    if isinstance(size, int):
+        script_size = [size, ]
+    else:
+        script_size = size
+
+    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, antialias=True)
+    tester.assertTrue(resized_tensor.equal(resize_result), msg=f"{size}, {interpolation}, {dt}")
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
index 29d390165d0..62fec046850 100644
--- a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
+++ b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
@@ -27,13 +27,13 @@ static inline scalar_t interpolate_aa_single_dim_zero_strides(
 
   scalar_t t = *(scalar_t*)&src_min[0];
   index_t wts_idx = *(index_t*)&data[4][0];
-  char* wts_ptr = &data[3][wts_idx];
-  scalar_t wts = *(scalar_t*)&wts_ptr[0];
+  scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
+  scalar_t wts = wts_ptr[0];
 
   scalar_t output = t * wts;
   int j = 1;
   for (; j < ids_size; j++) {
-    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
+    wts = wts_ptr[j];
     t = *(scalar_t*)&src_min[j * ids_stride];
     output += t * wts;
   }
@@ -54,13 +54,13 @@ static inline scalar_t interpolate_aa_single_dim(
 
   scalar_t t = *(scalar_t*)&src_min[0];
   index_t wts_idx = *(index_t*)&data[4][i * strides[4]];
-  char* wts_ptr = &data[3][wts_idx];
-  scalar_t wts = *(scalar_t*)&wts_ptr[0];
+  scalar_t* wts_ptr = (scalar_t*)&data[3][wts_idx];
+  scalar_t wts = wts_ptr[0];
 
   scalar_t output = t * wts;
   int j = 1;
   for (; j < ids_size; j++) {
-    wts = *(scalar_t*)&wts_ptr[j * sizeof(scalar_t)];
+    wts = wts_ptr[j];
     t = *(scalar_t*)&src_min[j * ids_stride];
     output += t * wts;
   }
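
For anyone trying the series locally, a minimal usage sketch of the new flag (this assumes torchvision is built from this branch, so that the tensor path can dispatch to the torchvision::_interpolate_linear_aa CPU op; it mirrors the tests above rather than adding anything new):

    import torch
    from PIL import Image
    import torchvision.transforms.functional as F
    from torchvision.transforms import InterpolationMode

    # uint8 CHW tensor and the equivalent PIL image, as in _create_data above
    tensor = torch.randint(0, 256, (3, 290, 320), dtype=torch.uint8)
    pil_img = Image.fromarray(tensor.permute(1, 2, 0).numpy())

    # Tensor path: anti-aliasing must be requested explicitly and, per this
    # series, is supported only for bilinear interpolation on CPU.
    out_t = F.resize(tensor, size=[96, 72], interpolation=InterpolationMode.BILINEAR, antialias=True)

    # PIL path: anti-aliasing is always applied; passing antialias=False only
    # triggers the UserWarning added in PATCH 2/6.
    out_p = F.resize(pil_img, size=[96, 72], interpolation=InterpolationMode.BILINEAR)

    print(out_t.shape, out_p.size)  # torch.Size([3, 96, 72]) (72, 96)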
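The index/weight computation in _compute_indices_weights_aa follows Pillow's precompute_coeffs. An illustrative pure-Python re-derivation of the linear (triangle) filter case, with the scale >= 1.0 behavior from PATCH 5/6 — names here are mine, not the kernel's, and this is a readability sketch rather than the shipped implementation:

    def triangle_filter(x):
        # Pillow's bilinear filter: hat function with support 1
        x = abs(x)
        return 1.0 - x if x < 1.0 else 0.0

    def aa_indices_weights(in_size, out_size, interp_size=2):
        scale = in_size / out_size
        # widen the filter support by the scale factor when downsampling;
        # when upsampling, keep the plain bilinear support
        support = (interp_size / 2) * scale if scale >= 1.0 else interp_size / 2
        invscale = 1.0 / scale if scale >= 1.0 else 1.0
        out = []
        for i in range(out_size):
            center = scale * (i + 0.5)
            xmin = max(int(center - support + 0.5), 0)
            xsize = min(int(center + support + 0.5), in_size) - xmin
            ws = [triangle_filter((j + xmin - center + 0.5) * invscale) for j in range(xsize)]
            total = sum(ws)
            ws = [w / total for w in ws] if total != 0.0 else ws
            out.append((xmin, ws))
        return out

    def resize_1d(row, out_size):
        # one separable pass; the 2D kernel applies this along W, then along H,
        # which is what ti_separable_upsample_generic_Nd_kernel_impl does
        return [sum(w * row[xmin + k] for k, w in enumerate(ws))
                for xmin, ws in aa_indices_weights(len(row), out_size)]

Its output can be cross-checked row by row against PIL's Image.resize with the bilinear resample mode, which is exactly the comparison the tests in this series perform at the tensor level.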