Skip to content

Commit

Permalink
Extend the supported types of decodePNG (pytorch#2984)
Browse files Browse the repository at this point in the history
* Add support of different color types in readpng.

* Adding test images and unit-tests.

* Use closest possible type.

* Fix formatting.
  • Loading branch information
datumbox authored and vfdev-5 committed Dec 4, 2020
1 parent d3fdb71 commit e9190b2
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 10 deletions.
Binary file added test/assets/fakedata/logos/gray_pytorch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/grayalpha_pytorch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/pallete_pytorch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/rgb_pytorch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/fakedata/logos/rgbalpha_pytorch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 7 additions & 3 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@


IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')


Expand Down Expand Up @@ -133,9 +134,12 @@ def test_write_jpeg(self):
self.assertEqual(torch_bytes, pil_bytes)

def test_decode_png(self):
for img_path in get_images(IMAGE_DIR, ".png"):
for img_path in get_images(FAKEDATA_DIR, ".png"):
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
img_pil = img_pil.permute(2, 0, 1)
if len(img_pil.shape) == 3:
img_pil = img_pil.permute(2, 0, 1)
else:
img_pil = img_pil.unsqueeze(0)
data = read_file(img_path)
img_lpng = decode_png(data)
self.assertTrue(img_lpng.equal(img_pil))
Expand Down
31 changes: 24 additions & 7 deletions torchvision/csrc/cpu/image/readpng_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,34 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}
if (color_type != PNG_COLOR_TYPE_RGB) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(
color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.")

int channels;
switch (color_type) {
case PNG_COLOR_TYPE_RGB:
channels = 3;
break;
case PNG_COLOR_TYPE_RGB_ALPHA:
channels = 4;
break;
case PNG_COLOR_TYPE_GRAY:
channels = 1;
break;
case PNG_COLOR_TYPE_GRAY_ALPHA:
channels = 2;
break;
case PNG_COLOR_TYPE_PALETTE:
channels = 1;
break;
default:
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "Image color type is not supported.");
}

auto tensor =
torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
auto tensor = torch::empty(
{int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
auto bytes = png_get_rowbytes(png_ptr, info_ptr);
for (decltype(height) i = 0; i < height; ++i) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += bytes;
}
Expand Down

0 comments on commit e9190b2

Please sign in to comment.