Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ImageToTensor module and resize for non-batched images #978

Merged
merged 3 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion kornia/geometry/transform/affwarp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
warp_affine3d, get_projective_transform
)
from kornia.utils import _extract_device_dtype
from kornia.utils.image import _to_bchw

__all__ = [
"affine",
Expand Down Expand Up @@ -543,7 +544,16 @@ def resize(input: torch.Tensor, size: Union[int, Tuple[int, int]],
if size == input_size:
return input

return torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
# TODO: find a proper way to handle these cases in the future
input_tmp = _to_bchw(input)

output = torch.nn.functional.interpolate(
input_tmp, size=size, mode=interpolation, align_corners=align_corners)

if len(input.shape) != len(output.shape):
output = output.squeeze()

return output


def rescale(
Expand Down
3 changes: 2 additions & 1 deletion kornia/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .one_hot import one_hot
from .grid import create_meshgrid, create_meshgrid3d
from .image import tensor_to_image, image_to_tensor
from .image import tensor_to_image, image_to_tensor, ImageToTensor
from .pointcloud_io import save_pointcloud_ply, load_pointcloud_ply
from .draw import draw_rectangle
from .helpers import _extract_device_dtype
Expand All @@ -17,4 +17,5 @@
"load_pointcloud_ply",
"draw_rectangle",
"_extract_device_dtype",
"ImageToTensor",
]
16 changes: 16 additions & 0 deletions kornia/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch
import torch.nn as nn


def image_to_tensor(image: np.ndarray, keepdim: bool = True) -> torch.Tensor:
Expand Down Expand Up @@ -141,3 +142,18 @@ def tensor_to_image(tensor: torch.Tensor) -> np.array:
"Cannot process tensor with shape {}".format(input_shape))

return image


class ImageToTensor(nn.Module):
"""Converts a numpy image to a PyTorch 4d tensor image.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By numpy image, are we assuming it is uint8 or floating points within 0~1?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both - it just permutes, we don't modify dtype

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the problem. If I am a user, I would expect this function to return a Kornia image, which is shaped as BCHW and ranged from 0-1.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fine with it, but tensor_to_image doesn't do that. And this is just the module of that function

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably we could provide an extra operator TensorToDtype or similar.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shijianjian another option is the following:

def image_to_tensor(image: np.ndarray, keepdim: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    ...
    tensor = ...
    if dtype is None:
        dtype = tensor.dtype
    return tensor.to(device, dtype)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I meant is the conversion from 8bit image, 12bit image, or 16bit image to 0-1 scale.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for that we have normalize. With the definition below you could do the job

nn.Sequential(
  ImageToTensor(dtype=torch.float32),
  Normalize(mean=0., std=255.),
)

I wouldn't change the default behavior of image_to_tensor since it might affect a large set of users


Args:
keepdim (bool): If ``False`` unsqueeze the input image to match the shape
:math:`(B, H, W, C)`. Default: ``False``
"""
def __init__(self, keepdim: bool = False):
super().__init__()
self.keepdim = keepdim

def forward(self, x: np.ndarray) -> torch.Tensor:
return image_to_tensor(x, keepdim=self.keepdim)
5 changes: 5 additions & 0 deletions test/geometry/transform/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ def test_smoke(self, device, dtype):
out = kornia.resize(inp, (3, 4))
assert_allclose(inp, out, atol=1e-4, rtol=1e-4)

def test_no_batch(self, device, dtype):
    """Resizing a 3-D input (no batch dimension) must return a 3-D output."""
    unbatched = torch.rand(3, 3, 4, device=device, dtype=dtype)
    resized = kornia.resize(unbatched, (2, 5))
    # The leading dimension is untouched; the last two follow the requested size.
    assert resized.shape == (3, 2, 5)

ducha-aiki marked this conversation as resolved.
Show resolved Hide resolved
def test_upsize(self, device, dtype):
inp = torch.rand(1, 3, 3, 4, device=device, dtype=dtype)
out = kornia.resize(inp, (6, 8))
Expand Down
5 changes: 5 additions & 0 deletions test/utils/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np

import torch
from torch.testing import assert_allclose

import kornia as kornia
import kornia.testing as utils # test utils

Expand Down Expand Up @@ -43,6 +45,9 @@ def test_image_to_tensor(input_shape, expected):
assert tensor.shape == expected
assert isinstance(tensor, torch.Tensor)

to_tensor = kornia.utils.ImageToTensor(keepdim=False)
assert_allclose(tensor, to_tensor(image))


@pytest.mark.parametrize("input_shape, expected",
[((4, 4), (1, 4, 4)),
Expand Down