Skip to content

Commit

Permalink
[fbsync] deinterlacing PNG images with read_image (#4268)
Browse files Browse the repository at this point in the history
Summary:
* interlaced png images

Reviewed By: NicolasHug

Differential Revision: D30417198

fbshipit-source-id: fe5ba53b6e668aa55e2e6d4702d1be559f848b57

Co-authored-by: Vincent Moens <vmoens@fb.com>
  • Loading branch information
2 people authored and facebook-github-bot committed Aug 20, 2021
1 parent 5a24192 commit 12fc3d4
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/assets/interlaced_png/wizard_low.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png")
IS_WINDOWS = sys.platform in ('win32', 'cygwin')
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.'))

Expand Down Expand Up @@ -304,6 +305,15 @@ def test_read_1_bit_png_consistency(shape, mode):
assert_equal(img1, img2)


def test_read_interlaced_png():
imgs = list(get_images(INTERLACED_PNG, ".png"))
with Image.open(imgs[0]) as im1, Image.open(imgs[1]) as im2:
assert not (im1.info.get("interlace") is im2.info.get("interlace"))
img1 = read_image(imgs[0])
img2 = read_image(imgs[1])
assert_equal(img1, img2)


@needs_cuda
@pytest.mark.parametrize('img_path', [
pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path))
Expand Down
19 changes: 15 additions & 4 deletions torchvision/csrc/io/image/cpu/decode_png.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {

png_uint_32 width, height;
int bit_depth, color_type;
int interlace_type;
auto retval = png_get_IHDR(
png_ptr,
info_ptr,
&width,
&height,
&bit_depth,
&color_type,
nullptr,
&interlace_type,
nullptr,
nullptr);

Expand All @@ -81,6 +82,13 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
if (color_type == PNG_COLOR_TYPE_GRAY && bit_depth < 8)
png_set_expand_gray_1_2_4_to_8(png_ptr);

int number_of_passes;
if (interlace_type == PNG_INTERLACE_ADAM7) {
number_of_passes = png_set_interlace_handling(png_ptr);
} else {
number_of_passes = 1;
}

if (mode != IMAGE_READ_MODE_UNCHANGED) {
// TODO: consider supporting PNG_INFO_tRNS
bool is_palette = (color_type & PNG_COLOR_MASK_PALETTE) != 0;
Expand Down Expand Up @@ -163,9 +171,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
auto tensor =
torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8);
auto ptr = tensor.accessor<uint8_t, 3>().data();
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += width * channels;
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, ptr, nullptr);
ptr += width * channels;
}
ptr = tensor.accessor<uint8_t, 3>().data();
}
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
return tensor.permute({2, 0, 1});
Expand Down

0 comments on commit 12fc3d4

Please sign in to comment.