diff --git a/test/assets/fakedata/logos/gray_pytorch.png b/test/assets/fakedata/logos/gray_pytorch.png
new file mode 100644
index 00000000000..412b931299e
Binary files /dev/null and b/test/assets/fakedata/logos/gray_pytorch.png differ
diff --git a/test/assets/fakedata/logos/grayalpha_pytorch.png b/test/assets/fakedata/logos/grayalpha_pytorch.png
new file mode 100644
index 00000000000..3e77d72b904
Binary files /dev/null and b/test/assets/fakedata/logos/grayalpha_pytorch.png differ
diff --git a/test/assets/fakedata/logos/pallete_pytorch.png b/test/assets/fakedata/logos/pallete_pytorch.png
new file mode 100644
index 00000000000..2108d1b315a
Binary files /dev/null and b/test/assets/fakedata/logos/pallete_pytorch.png differ
diff --git a/test/assets/fakedata/logos/rgb_pytorch.png b/test/assets/fakedata/logos/rgb_pytorch.png
new file mode 100644
index 00000000000..c9d08e6c7da
Binary files /dev/null and b/test/assets/fakedata/logos/rgb_pytorch.png differ
diff --git a/test/assets/fakedata/logos/rgbalpha_pytorch.png b/test/assets/fakedata/logos/rgbalpha_pytorch.png
new file mode 100644
index 00000000000..2108d1b315a
Binary files /dev/null and b/test/assets/fakedata/logos/rgbalpha_pytorch.png differ
diff --git a/test/test_image.py b/test/test_image.py
index af1641c3355..45a4258816e 100644
--- a/test/test_image.py
+++ b/test/test_image.py
@@ -16,7 +16,8 @@
 
 
 IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
-IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
+FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
+IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
 DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
 
 
@@ -133,9 +134,12 @@ def test_write_jpeg(self):
         self.assertEqual(torch_bytes, pil_bytes)
 
     def test_decode_png(self):
-        for img_path in get_images(IMAGE_DIR, ".png"):
+        for img_path in get_images(FAKEDATA_DIR, ".png"):
             img_pil = torch.from_numpy(np.array(Image.open(img_path)))
-            img_pil = img_pil.permute(2, 0, 1)
+            if len(img_pil.shape) == 3:
+                img_pil = img_pil.permute(2, 0, 1)
+            else:
+                img_pil = img_pil.unsqueeze(0)
             data = read_file(img_path)
             img_lpng = decode_png(data)
             self.assertTrue(img_lpng.equal(img_pil))
diff --git a/torchvision/csrc/cpu/image/readpng_cpu.cpp b/torchvision/csrc/cpu/image/readpng_cpu.cpp
index 3c2141aa2da..6fbe04ac033 100644
--- a/torchvision/csrc/cpu/image/readpng_cpu.cpp
+++ b/torchvision/csrc/cpu/image/readpng_cpu.cpp
@@ -71,17 +71,34 @@ torch::Tensor decodePNG(const torch::Tensor& data) {
     png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
     TORCH_CHECK(retval == 1, "Could read image metadata from content.")
   }
-  if (color_type != PNG_COLOR_TYPE_RGB) {
-    png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
-    TORCH_CHECK(
-        color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.")
+
+  int channels;
+  switch (color_type) {
+    case PNG_COLOR_TYPE_RGB:
+      channels = 3;
+      break;
+    case PNG_COLOR_TYPE_RGB_ALPHA:
+      channels = 4;
+      break;
+    case PNG_COLOR_TYPE_GRAY:
+      channels = 1;
+      break;
+    case PNG_COLOR_TYPE_GRAY_ALPHA:
+      channels = 2;
+      break;
+    case PNG_COLOR_TYPE_PALETTE:
+      channels = 1;
+      break;
+    default:
+      png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
+      TORCH_CHECK(false, "Image color type is not supported.");
   }
 
-  auto tensor =
-      torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8);
+  auto tensor = torch::empty(
+      {int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8);
   auto ptr = tensor.accessor<uint8_t, 3>().data();
   auto bytes = png_get_rowbytes(png_ptr, info_ptr);
-  for (decltype(height) i = 0; i < height; ++i) {
+  for (png_uint_32 i = 0; i < height; ++i) {
     png_read_row(png_ptr, ptr, nullptr);
     ptr += bytes;
   }
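
Usage note (not part of the diff): the sketch below illustrates the behaviour the readpng_cpu.cpp change enables, assuming a torchvision build that includes this patch, that decode_png and read_file are importable as in test/test_image.py at this revision, and that each logo added above has the PNG color type its file name suggests. Per the updated test, decode_png returns a channels-first uint8 tensor.

import os

# Import path assumed to match what test/test_image.py uses at this revision.
from torchvision.io.image import decode_png, read_file

LOGO_DIR = os.path.join("test", "assets", "fakedata", "logos")

# Channel count implied by the switch added in decodePNG, keyed by logo file
# name (assuming each file's color type matches its name).
EXPECTED_CHANNELS = {
    "gray_pytorch.png": 1,       # PNG_COLOR_TYPE_GRAY
    "grayalpha_pytorch.png": 2,  # PNG_COLOR_TYPE_GRAY_ALPHA
    "pallete_pytorch.png": 1,    # PNG_COLOR_TYPE_PALETTE (palette indices)
    "rgb_pytorch.png": 3,        # PNG_COLOR_TYPE_RGB
    "rgbalpha_pytorch.png": 4,   # PNG_COLOR_TYPE_RGB_ALPHA
}

for name, channels in EXPECTED_CHANNELS.items():
    data = read_file(os.path.join(LOGO_DIR, name))  # raw file bytes as a 1-D uint8 tensor
    img = decode_png(data)                          # channels-first uint8 tensor
    print(name, tuple(img.shape), "expected channels:", channels)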