From 78159d61b966ec32ac330c28bea32e9e12fe6eb4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 10 Nov 2020 14:05:32 +0000 Subject: [PATCH] Extend the supported types of decodePNG (#2984) * Add support of different color types in readpng. * Adding test images and unit-tests. * Use closest possible type. * Fix formatting. --- test/assets/fakedata/logos/gray_pytorch.png | Bin 0 -> 433 bytes .../fakedata/logos/grayalpha_pytorch.png | Bin 0 -> 590 bytes .../assets/fakedata/logos/pallete_pytorch.png | Bin 0 -> 1151 bytes test/assets/fakedata/logos/rgb_pytorch.png | Bin 0 -> 575 bytes .../fakedata/logos/rgbalpha_pytorch.png | Bin 0 -> 1151 bytes test/test_image.py | 10 ++++-- torchvision/csrc/cpu/image/readpng_cpu.cpp | 31 ++++++++++++++---- 7 files changed, 31 insertions(+), 10 deletions(-) create mode 100644 test/assets/fakedata/logos/gray_pytorch.png create mode 100644 test/assets/fakedata/logos/grayalpha_pytorch.png create mode 100644 test/assets/fakedata/logos/pallete_pytorch.png create mode 100644 test/assets/fakedata/logos/rgb_pytorch.png create mode 100644 test/assets/fakedata/logos/rgbalpha_pytorch.png diff --git a/test/assets/fakedata/logos/gray_pytorch.png b/test/assets/fakedata/logos/gray_pytorch.png new file mode 100644 index 0000000000000000000000000000000000000000..412b931299ebcc6afa4d677c514e1c2b21681545 GIT binary patch literal 433 zcmV;i0Z#sjP)0ro^9)=28F~;SeD*4gD*!~G5If)rs8Wl*B;YfgVGXzLK zgF^`8kQW~_Mw!93@Bz6Rn3H1|7@{-q$B!Te1Amp;*~J-90J02~7=RQ*Oa`I|>FU{6 z1TwFSGUkBL3Uwm*pdk@tT9uTmDJZgH7u*>I5X%U)a;1}x0<;FN$J%3!b97QS_;UJFHV0qq*N-E bO8=ubnYTlDCaD3600000NkvXXu0mjft7Npb literal 0 HcmV?d00001 diff --git a/test/assets/fakedata/logos/grayalpha_pytorch.png b/test/assets/fakedata/logos/grayalpha_pytorch.png new file mode 100644 index 0000000000000000000000000000000000000000..3e77d72b904b14aebc99ff98d0b8acfbcedf7603 GIT binary patch literal 590 zcmV-U0cL5JhDg06L+Clt2g2g?F%27fDK@gj7H}2#df8(Z}}8ypeT|-(4j+)=1Z*KZHOC zA%qY@2q7usr9~|d-W9aba2IK%;VjZl!&RiEhNDPZ4L6b28crhZHC#m2&~Ol0M*}WW z)cC^IQ3fhf)`%~qC()T3aG@j`f1e_(epp5{8psg^5?QX{1w>?h4Twk&8UPWSMmK|9 zqj76IPQy6;_R*=X7&6k6hC#$b!?NBhMSL`>kr#m5Ye6*hs6+~t0yT6>3K3t8N)59> zy4G4V)qlhy`P;HDYk35DPRtp>tj_i*LXq@M;Wim*Bxz3EFR@!867kDNDVi)N+sw#4 z*nC(c3Pjppy-8|h*Xj+m2!4dvqB0}BL6P2|NN-SNJOf(+h^o)Xqq?L^)Gc-`pG^s> zM)W1N2}{H+b}jo&$27JZfo*0ags{8+ipxnBX;*x#ig+gSxH3gV6?d(|J?*D7#(V_% z%R!8hj}9W}{;FFT8Sn9+Yryyx_UUh$&Z^7EYxc`tH|)ntu2G1-0Odd}H(UB9F;T9l zXPjCFmNvs$KE*I~);rT&WAr8K;Qxi=#&bL;&p(ClsREW2$+(yycU!iF)wqttAslTH zemPqA2av;(H7|%w4(!yYl|h|5VcYmPgp;_?33q&kpTn1C27LqIbfGY%)BX~J5JCtc cgb>mtKXGfO%i=ibiU0rr07*qoM6N<$g2L7Ww*UYD literal 0 HcmV?d00001 diff --git a/test/assets/fakedata/logos/pallete_pytorch.png b/test/assets/fakedata/logos/pallete_pytorch.png new file mode 100644 index 0000000000000000000000000000000000000000..2108d1b315a73725115f22033954469a50718cb0 GIT binary patch literal 1151 zcmeAS@N?(olHy`uVBq!ia0vp^DImB|mLR^8|cRo5KAwfYwAt51Q zVPO#w5m8Z5F)=Z5ad8O=2}wywDJdywX=xc58Ch9bIXO9bd3gl|1w}gt-Bn%dghy1Kgh`uc{3hQ`LmrlzLm=H`}`me$tRwzjtR_V$jBj?T``uCA``?(UwR zp5ETxzP`Tx{{9IQCQO_-anhtolP6D}GG)rtsZ*y-n>Ky=^cgc|%$zxM)~s2xXV0E9 zXU^QYbLY*QH-G;81q&7|T)1%2qD6}rFJ7`_$TOXvSrKGty{Nk+qQlC_8mKR?A*C?*REZ= zckkY_XV2cfd-v_zw}1cs0|yQqJb3WXp+kocA3k#A$kC%mj~zR9{P^(`Cr+F^dGge$ zQ>Ra#K6B>G*|TTQojZ5_{P_zPE?m5L@zSMBmoHzwa^=d^t5>gGyLSEh^&2;C+`M`7 z)~#E&Z{NOi=g!@`ckkW1cmMwV2M-=ReE9IuqeqV)KYsG$$k{E1dg-FJ7e^mhs|3&^5@)cIC+sjHim%3}Mok{T)78djzFFitsZ+IOrlb?cdT3Trz;9>2euwBSgD zrq`*SDYKFnX_^}GB|FF%x*f6J>-MGWnELa1PQ~X0K%V`yzLK%|qN}c}IBPgC6d62S L{an^LB{Ts53B@u^ literal 0 HcmV?d00001 diff --git a/test/assets/fakedata/logos/rgb_pytorch.png b/test/assets/fakedata/logos/rgb_pytorch.png new file mode 100644 index 0000000000000000000000000000000000000000..c9d08e6c7da91991a780ded69d966fbd0c18eb5a GIT binary patch literal 575 zcmV-F0>J%=P)mlEGZGj7-Nhv#u#IaF~&Ub z`}JAfg#UT3ZorwgrlOmy&ZeT3tmdYokF5TtqKT|6OhpG-yO@e{SsR&(T3LIUib7f2 znTje|JDQ3TSx-%UeE*C;Ugi2tTyNdf{F~uOlZ7=kb3iOQS&ODRAd+PmaMjNnZUjJ9 zlUR+5Lc--_C1A-ayjS3rUX``c4B&cG-3=31RsxEw^1%(M0Zvu|%SvEb2`nptWhJnz z1eTScX5Ns^HuHuz#(W(4^W5BYWUP<8T=~wzvHVw=eSveE- z@vmU*u$W9x_LNA6ouqO*$_do4iT5CcyX{v$JfuD{;{F;oStf}>w4t0GG0jzx?!IRz zQ-)W1Qwl#ZaW_~0ufxGg%Bmjxm&rPqx3c N002ovPDHLkV1i6h2Xz1d literal 0 HcmV?d00001 diff --git a/test/assets/fakedata/logos/rgbalpha_pytorch.png b/test/assets/fakedata/logos/rgbalpha_pytorch.png new file mode 100644 index 0000000000000000000000000000000000000000..2108d1b315a73725115f22033954469a50718cb0 GIT binary patch literal 1151 zcmeAS@N?(olHy`uVBq!ia0vp^DImB|mLR^8|cRo5KAwfYwAt51Q zVPO#w5m8Z5F)=Z5ad8O=2}wywDJdywX=xc58Ch9bIXO9bd3gl|1w}gt-Bn%dghy1Kgh`uc{3hQ`LmrlzLm=H`}`me$tRwzjtR_V$jBj?T``uCA``?(UwR zp5ETxzP`Tx{{9IQCQO_-anhtolP6D}GG)rtsZ*y-n>Ky=^cgc|%$zxM)~s2xXV0E9 zXU^QYbLY*QH-G;81q&7|T)1%2qD6}rFJ7`_$TOXvSrKGty{Nk+qQlC_8mKR?A*C?*REZ= zckkY_XV2cfd-v_zw}1cs0|yQqJb3WXp+kocA3k#A$kC%mj~zR9{P^(`Cr+F^dGge$ zQ>Ra#K6B>G*|TTQojZ5_{P_zPE?m5L@zSMBmoHzwa^=d^t5>gGyLSEh^&2;C+`M`7 z)~#E&Z{NOi=g!@`ckkW1cmMwV2M-=ReE9IuqeqV)KYsG$$k{E1dg-FJ7e^mhs|3&^5@)cIC+sjHim%3}Mok{T)78djzFFitsZ+IOrlb?cdT3Trz;9>2euwBSgD zrq`*SDYKFnX_^}GB|FF%x*f6J>-MGWnELa1PQ~X0K%V`yzLK%|qN}c}IBPgC6d62S L{an^LB{Ts53B@u^ literal 0 HcmV?d00001 diff --git a/test/test_image.py b/test/test_image.py index af1641c3355..45a4258816e 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -16,7 +16,8 @@ IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") -IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder") +FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") +IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') @@ -133,9 +134,12 @@ def test_write_jpeg(self): self.assertEqual(torch_bytes, pil_bytes) def test_decode_png(self): - for img_path in get_images(IMAGE_DIR, ".png"): + for img_path in get_images(FAKEDATA_DIR, ".png"): img_pil = torch.from_numpy(np.array(Image.open(img_path))) - img_pil = img_pil.permute(2, 0, 1) + if len(img_pil.shape) == 3: + img_pil = img_pil.permute(2, 0, 1) + else: + img_pil = img_pil.unsqueeze(0) data = read_file(img_path) img_lpng = decode_png(data) self.assertTrue(img_lpng.equal(img_pil)) diff --git a/torchvision/csrc/cpu/image/readpng_cpu.cpp b/torchvision/csrc/cpu/image/readpng_cpu.cpp index 3c2141aa2da..6fbe04ac033 100644 --- a/torchvision/csrc/cpu/image/readpng_cpu.cpp +++ b/torchvision/csrc/cpu/image/readpng_cpu.cpp @@ -71,17 +71,34 @@ torch::Tensor decodePNG(const torch::Tensor& data) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); TORCH_CHECK(retval == 1, "Could read image metadata from content.") } - if (color_type != PNG_COLOR_TYPE_RGB) { - png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK( - color_type == PNG_COLOR_TYPE_RGB, "Non RGB images are not supported.") + + int channels; + switch (color_type) { + case PNG_COLOR_TYPE_RGB: + channels = 3; + break; + case PNG_COLOR_TYPE_RGB_ALPHA: + channels = 4; + break; + case PNG_COLOR_TYPE_GRAY: + channels = 1; + break; + case PNG_COLOR_TYPE_GRAY_ALPHA: + channels = 2; + break; + case PNG_COLOR_TYPE_PALETTE: + channels = 1; + break; + default: + png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); + TORCH_CHECK(false, "Image color type is not supported."); } - auto tensor = - torch::empty({int64_t(height), int64_t(width), int64_t(3)}, torch::kU8); + auto tensor = torch::empty( + {int64_t(height), int64_t(width), int64_t(channels)}, torch::kU8); auto ptr = tensor.accessor().data(); auto bytes = png_get_rowbytes(png_ptr, info_ptr); - for (decltype(height) i = 0; i < height; ++i) { + for (png_uint_32 i = 0; i < height; ++i) { png_read_row(png_ptr, ptr, nullptr); ptr += bytes; }