diff --git a/test/assets/interlaced_png/wizard_low-interlaced.png b/test/assets/interlaced_png/wizard_low-interlaced.png new file mode 100644 index 00000000000..3badd9264dc Binary files /dev/null and b/test/assets/interlaced_png/wizard_low-interlaced.png differ diff --git a/test/assets/interlaced_png/wizard_low.png b/test/assets/interlaced_png/wizard_low.png new file mode 100644 index 00000000000..7b1c264f030 Binary files /dev/null and b/test/assets/interlaced_png/wizard_low.png differ diff --git a/test/test_image.py b/test/test_image.py index 47023a45be2..7c6764dce64 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -20,6 +20,7 @@ IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") +INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png") IS_WINDOWS = sys.platform in ('win32', 'cygwin') PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) @@ -304,6 +305,15 @@ def test_read_1_bit_png_consistency(shape, mode): assert_equal(img1, img2) +def test_read_interlaced_png(): + imgs = list(get_images(INTERLACED_PNG, ".png")) + with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2: + assert not (im1.info.get("interlace") is im2.info.get("interlace")) + img1 = read_image(imgs[0]) + img2 = read_image(imgs[1]) + assert_equal(img1, img2) + + @needs_cuda @pytest.mark.parametrize('img_path', [ pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index b40fd951d5b..ea38272c978 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -55,6 +55,7 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { png_uint_32 width, height; int bit_depth, color_type; + int interlace_type; auto retval = png_get_IHDR( png_ptr, info_ptr, @@ -62,7 +63,7 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { &height, &bit_depth, &color_type, - nullptr, + &interlace_type, nullptr, nullptr); @@ -81,6 +82,13 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8) png_set_expand_gray_1_2_4_to_8(png_ptr); + int number_of_passes; + if (interlace_type == PNG_INTERLACE_ADAM7) { + number_of_passes = png_set_interlace_handling(png_ptr); + } else { + number_of_passes = 1; + } + if (mode != IMAGE_READ_MODE_UNCHANGED) { // TODO: consider supporting PNG_INFO_tRNS bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0; @@ -163,9 +171,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { auto tensor = torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); auto ptr = tensor.accessor().data(); - for (png_uint_32 i = 0; i < height; ++i) { - png_read_row(png_ptr, ptr, nullptr); - ptr += width * channels; + for (int pass = 0; pass < number_of_passes; pass++) { + for (png_uint_32 i = 0; i < height; ++i) { + png_read_row(png_ptr, ptr, nullptr); + ptr += width * channels; + } + ptr = tensor.accessor().data(); } png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); return tensor.permute({2, 0, 1});