From c10f938f5dd4c1849292daae3950c67e6915e0ec Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 30 Aug 2024 13:12:00 +0100 Subject: [PATCH] Add HEIC decoder (#8597) --- setup.py | 17 ++ ...rch_incorrectly_encoded_but_who_cares.heic | Bin 0 -> 4722 bytes test/test_image.py | 62 +++++-- torchvision/csrc/io/image/cpu/decode_heic.cpp | 152 ++++++++++++++++++ torchvision/csrc/io/image/cpu/decode_heic.h | 14 ++ .../csrc/io/image/cpu/decode_image.cpp | 12 ++ torchvision/csrc/io/image/image.cpp | 2 + torchvision/csrc/io/image/image.h | 1 + torchvision/io/__init__.py | 1 + torchvision/io/image.py | 28 +++- 10 files changed, 274 insertions(+), 15 deletions(-) create mode 100644 test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic create mode 100644 torchvision/csrc/io/image/cpu/decode_heic.cpp create mode 100644 torchvision/csrc/io/image/cpu/decode_heic.h diff --git a/setup.py b/setup.py index 7f383b82ec4..dbe8ce58aa2 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1" USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" +USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default! USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default! USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) @@ -50,6 +51,7 @@ print(f"{USE_PNG = }") print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") +print(f"{USE_HEIC = }") print(f"{USE_AVIF = }") print(f"{USE_NVJPEG = }") print(f"{NVCC_FLAGS = }") @@ -334,6 +336,21 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") + if USE_HEIC: + heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h") + if heic_found: + print("Building torchvision with HEIC support") + print(f"{heic_include_dir = }") + print(f"{heic_library_dir = }") + if heic_include_dir is not None and heic_library_dir is not None: + # if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add. + include_dirs.append(heic_include_dir) + library_dirs.append(heic_library_dir) + libraries.append("heif") + define_macros += [("HEIC_FOUND", 1)] + else: + warnings.warn("Building torchvision without HEIC support") + if USE_AVIF: avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h") if avif_found: diff --git a/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic b/test/assets/fakedata/logos/rgb_pytorch_incorrectly_encoded_but_who_cares.heic new file mode 100644 index 0000000000000000000000000000000000000000..4c29ac3c71cf432a7e1bf1be4f049c417705c9e9 GIT binary patch literal 4722 zcmeHKZ9J4&8^3R!VT`-^9EI_H0`59bB| zz$h9{5XEvKHh?rf6m36igVlU4%tT{IGmYi&#Iq&u{hIm3H z9o-#J1<(kfwPJPz0E|DNqC^5{1bmFx90WDnkcbKu0C4??%Y-07v@(p&8Xyi=xe&p` z#1c=imR(VCH$^6`JV52rzs=-3CY_dM_4!$y=V4LkZ0Bx`b_n-RSEg87?KsOK+50JB8O3i zkkg>%U+SQw|NAzdWKJ5-%QY|8oD`T7c)q&k<(iWMa{|v-*SuVFQeaNt|D)>q@~-6y zkjJbPd9@;+E#51@GQC|1n(w-8PsnFJv<+{XihRFKqs_U{zUq9=afR{Gz2{PXvuG|` z(&#e4$bFc5^hFcVer9Io>Ty?1+wu^@uFAL5&4vuED(-g5(?W;u(ucba(avRk{ZS7? zv^aalMi`Xe)mGUJUI>sgNAKttudZLNb*bL&GH-giQG1*-^HM9l0r07PZL!LJNyGHf z`2P0AYX&FEjQm0l)!3-{78iTg^iZ|4hVhImQKq0SzA!M{@P&1&;O*^6lV@kAoM?YI zIq|C$8CpxLHERTqBqwinENZ{w-R_%!9AzaIImt_1b7`^e{6k$Ssf{5SFCv6-UWR}VJ&PY;6C4&aCc z)PS3Ep45nQhAN+$q{mod)*4bHcl766&r9%1zEIqMD^LA-L04JV`qr@Uf^og(^MkoX z6>XEA?Du1Es(6_p=sI&(W{(kq*KGz_jZOnKS04dWhCTudL1hCRyp zrS@8f?*+dLxRw2V<%KT$Tg6#I`FeBSP>c($NrO}qrz2-vuD`h7)ce($37@jxt$Q3} zU&QJsk}?KP?;PAZkzx{}kWSIYjP^CAUQbtiFDn_D8NVxkb|g05ij&&h68uw3rk0LH z;=QlD%`!(W_C^*M6-=0vc$ujEoS7FTsn~C5?BWgP4YB0S8PZu9lrhZD5bHCI z`n?)c4OyPIo}UXlhIdkU;L98hdxdv57v0Do{3ZxPl9}m)E-DTYjv?M2AMwAQf(6!{ z(b@FLGsaX2vxZ#)jw}l~D;-jy+qy0DS)S6eK}~~|ziG{zOpB-9sRE9hz%V{YXb>I^ z{peQPOU(O-{Tj2gOQ7&jl==sFmaNwC$sHK^MI(g~v0)$;?`65Pr}dogzOeRX--kVt zt<&rKS+i=_;bz#i+$!;6;#kfCQGSxiX>s9&f{{`Xe^` zc#{?u9WMlqF?L0LFs{H}Ify%Ssr6aFV#QjIe$DpKWS=5@OZW3q+Lm|QJZtl!D|rUS zyo>jX^|NmbY+$8rTvJwSzBDDuizZyyZ5@iKh-2M@ud3MNA}xy2V*CZAeH?cNR2OsN zM17I%ci{!N>h!ot+=SdL>gZURzV(>T`pIGX=9@_^KVE9;;og;WgfphEu^QOVwcgyR z`?ul+w$I>}OGff82C0)m3Ha2Q(nfE8>e1jXog(GSk6<9y5P|G|^>{Pne-Lix3qyyM9 z3MMCAPXP##rvLx| literal 0 HcmV?d00001 diff --git a/test/test_image.py b/test/test_image.py index f1fe70135fe..d489b10af7c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -15,6 +15,7 @@ from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence from torchvision.io.image import ( _decode_avif, + _decode_heic, decode_gif, decode_image, decode_jpeg, @@ -928,11 +929,10 @@ def test_decode_avif(decode_fun, scripted): assert img[None].is_contiguous(memory_format=torch.channels_last) -@pytest.mark.xfail(reason="AVIF support not enabled yet.") +@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.") # Note: decode_image fails because some of these files have a (valid) signature # we don't recognize. We should probably use libmagic.... -# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image)) -@pytest.mark.parametrize("decode_fun", (_decode_avif,)) +@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic)) @pytest.mark.parametrize("scripted", (False, True)) @pytest.mark.parametrize( "mode, pil_mode", @@ -942,7 +942,9 @@ def test_decode_avif(decode_fun, scripted): (ImageReadMode.UNCHANGED, None), ), ) -@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif")) +@pytest.mark.parametrize( + "filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name +) def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename): if "reversed_dimg_order" in str(filename): # Pillow properly decodes this one, but we don't (order of parts of the @@ -960,7 +962,14 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename) except RuntimeError as e: if any( s in str(e) - for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image") + for s in ( + "BMFF parsing failed", + "avifDecoderParse failed: ", + "file contains more than one image", + "no 'ispe' property", + "'iref' has double references", + "Invalid image grid", + ) ): pytest.skip(reason="Expected failure, that's OK") else: @@ -970,22 +979,47 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename) assert img.shape[0] == 3 if mode == ImageReadMode.RGB_ALPHA: assert img.shape[0] == 4 + if img.dtype == torch.uint16: img = F.to_dtype(img, dtype=torch.uint8, scale=True) + try: + from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) + except RuntimeError as e: + if "Invalid image grid" in str(e): + pytest.skip(reason="PIL failure") + else: + raise e - from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode)) - if False: + if True: from torchvision.utils import make_grid g = make_grid([img, from_pil]) F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png")) - if mode != ImageReadMode.RGB: - # We don't compare against PIL for RGB because results look pretty - # different on RGBA images (other images are fine). The result on - # torchvision basically just plainly ignores the alpha channel, resuting - # in transparent pixels looking dark. PIL seems to be using a sort of - # k-nn thing, looking at the output. Take a look at the resuting images. - torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic" + if mode == ImageReadMode.RGB and not is__decode_heic: + # We don't compare torchvision's AVIF against PIL for RGB because + # results look pretty different on RGBA images (other images are fine). + # The result on torchvision basically just plainly ignores the alpha + # channel, resuting in transparent pixels looking dark. PIL seems to be + # using a sort of k-nn thing (Take a look at the resuting images) + return + if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic: + return + + torch.testing.assert_close(img, from_pil, rtol=0, atol=3) + + +@pytest.mark.xfail(reason="HEIC support not enabled yet.") +@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image)) +@pytest.mark.parametrize("scripted", (False, True)) +def test_decode_heic(decode_fun, scripted): + encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic"))) + if scripted: + decode_fun = torch.jit.script(decode_fun) + img = decode_fun(encoded_bytes) + assert img.shape == (3, 100, 100) + assert img[None].is_contiguous(memory_format=torch.channels_last) if __name__ == "__main__": diff --git a/torchvision/csrc/io/image/cpu/decode_heic.cpp b/torchvision/csrc/io/image/cpu/decode_heic.cpp new file mode 100644 index 00000000000..148d6043f10 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_heic.cpp @@ -0,0 +1,152 @@ +#include "decode_heic.h" + +#if HEIC_FOUND +#include "libheif/heif_cxx.h" +#endif // HEIC_FOUND + +namespace vision { +namespace image { + +#if !HEIC_FOUND +torch::Tensor decode_heic( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + TORCH_CHECK( + false, "decode_heic: torchvision not compiled with libheif support"); +} +#else + +torch::Tensor decode_heic( + const torch::Tensor& encoded_data, + ImageReadMode mode) { + TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous."); + TORCH_CHECK( + encoded_data.dtype() == torch::kU8, + "Input tensor must have uint8 data type, got ", + encoded_data.dtype()); + TORCH_CHECK( + encoded_data.dim() == 1, + "Input tensor must be 1-dimensional, got ", + encoded_data.dim(), + " dims."); + + if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB && + mode != IMAGE_READ_MODE_RGB_ALPHA) { + // Other modes aren't supported, but we don't error or even warn because we + // have generic entry points like decode_image which may support all modes, + // it just depends on the underlying decoder. + mode = IMAGE_READ_MODE_UNCHANGED; + } + + // If return_rgb is false it means we return rgba - nothing else. + auto return_rgb = true; + + int height = 0; + int width = 0; + int num_channels = 0; + int stride = 0; + uint8_t* decoded_data = nullptr; + heif::Image img; + int bit_depth = 0; + + try { + heif::Context ctx; + ctx.read_from_memory_without_copy( + encoded_data.data_ptr(), encoded_data.numel()); + + // TODO properly error on (or support) image sequences. Right now, I think + // this function will always return the first image in a sequence, which is + // inconsistent with decode_gif (which returns a batch) and with decode_avif + // (which errors loudly). + // Why? I'm struggling to make sense of + // ctx.get_number_of_top_level_images(). It disagrees with libavif's + // imageCount. For example on some of the libavif test images: + // + // - colors-animated-12bpc-keyframes-0-2-3.avif + // avif num images = 5 + // heif num images = 1 // Why is this 1 when clearly this is supposed to + // be a sequence? + // - sofa_grid1x5_420.avif + // avif num images = 1 + // heif num images = 6 // If we were to error here we won't be able to + // decode this image which is otherwise properly + // decoded by libavif. + // I can't find a libheif function that does what we need here, or at least + // that agrees with libavif. + + // TORCH_CHECK( + // ctx.get_number_of_top_level_images() == 1, + // "heic file contains more than one image"); + + heif::ImageHandle handle = ctx.get_primary_image_handle(); + bit_depth = handle.get_luma_bits_per_pixel(); + + return_rgb = + (mode == IMAGE_READ_MODE_RGB || + (mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel())); + + height = handle.get_height(); + width = handle.get_width(); + + num_channels = return_rgb ? 3 : 4; + heif_chroma chroma; + if (bit_depth == 8) { + chroma = return_rgb ? heif_chroma_interleaved_RGB + : heif_chroma_interleaved_RGBA; + } else { + // TODO: This, along with our 10bits -> 16bits range mapping down below, + // may not work on BE platforms + chroma = return_rgb ? heif_chroma_interleaved_RRGGBB_LE + : heif_chroma_interleaved_RRGGBBAA_LE; + } + + img = handle.decode_image(heif_colorspace_RGB, chroma); + + decoded_data = img.get_plane(heif_channel_interleaved, &stride); + } catch (const heif::Error& err) { + // We need this try/catch block and call TORCH_CHECK, because libheif may + // otherwise throw heif::Error that would just be reported as "An unknown + // exception occurred" when we move back to Python. + TORCH_CHECK(false, "decode_heif failed: ", err.get_message()); + } + TORCH_CHECK(decoded_data != nullptr, "Something went wrong during decoding."); + + auto dtype = (bit_depth == 8) ? torch::kUInt8 : at::kUInt16; + auto out = torch::empty({height, width, num_channels}, dtype); + uint8_t* out_ptr = (uint8_t*)out.data_ptr(); + + // decoded_data is *almost* the raw decoded data, but not quite: for some + // images, there may be some padding at the end of each row, i.e. when stride + // != row_size_in_bytes. So we can't copy decoded_data into the tensor's + // memory directly, we have to copy row by row. Oh, and if you think you can + // take a shortcut when stride == row_size_in_bytes and just do: + // out = torch::from_blob(decoded_data, ...) + // you can't, because decoded_data is owned by the heif::Image object and it + // gets freed when it gets out of scope! + auto row_size_in_bytes = width * num_channels * ((bit_depth == 8) ? 1 : 2); + for (auto h = 0; h < height; h++) { + memcpy( + out_ptr + h * row_size_in_bytes, + decoded_data + h * stride, + row_size_in_bytes); + } + if (bit_depth > 8) { + // Say bit depth is 10. decodec_data and out_ptr contain 10bits values + // over 2 bytes, stored into uint16_t. In torchvision a uint16 value is + // expected to be in [0, 2**16), so we have to map the 10bits value to that + // range. Note that other libraries like libavif do that mapping + // automatically. + // TODO: It's possible to avoid the memcpy call above in this case, and do + // the copy at the same time as the conversation. Whether it's worth it + // should be benchmarked. + auto out_ptr_16 = (uint16_t*)out_ptr; + for (auto p = 0; p < height * width * num_channels; p++) { + out_ptr_16[p] <<= (16 - bit_depth); + } + } + return out.permute({2, 0, 1}); +} +#endif // HEIC_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_heic.h b/torchvision/csrc/io/image/cpu/decode_heic.h new file mode 100644 index 00000000000..4a23e4c1431 --- /dev/null +++ b/torchvision/csrc/io/image/cpu/decode_heic.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_heic( + const torch::Tensor& data, + ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index e5a421b7287..9c1a7ff3ef4 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -2,6 +2,7 @@ #include "decode_avif.h" #include "decode_gif.h" +#include "decode_heic.h" #include "decode_jpeg.h" #include "decode_png.h" #include "decode_webp.h" @@ -61,6 +62,17 @@ torch::Tensor decode_image( return decode_avif(data, mode); } + // Similarly for heic we assume the signature is "ftypeheic" but some files + // may come as "ftypmif1" where the "heic" part is defined later in the file. + // We can't be re-inventing libmagic here. We might need to start relying on + // it though... + const uint8_t heic_signature[8] = { + 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63}; // == "ftypheic" + TORCH_CHECK(data.numel() >= 12, err_msg); + if ((memcmp(heic_signature, datap + 4, 8) == 0)) { + return decode_heic(data, mode); + } + const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF" const uint8_t webp_signature_end[7] = { 0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8" diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index a777d19d3bd..f0ce91144a6 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -23,6 +23,8 @@ static auto registry = &decode_jpeg) .op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor", &decode_webp) + .op("image::decode_heic(Tensor encoded_data, int mode) -> Tensor", + &decode_heic) .op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor", &decode_avif) .op("image::encode_jpeg", &encode_jpeg) diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index 91a5144fa1c..23493f3c030 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -2,6 +2,7 @@ #include "cpu/decode_avif.h" #include "cpu/decode_gif.h" +#include "cpu/decode_heic.h" #include "cpu/decode_image.h" #include "cpu/decode_jpeg.h" #include "cpu/decode_png.h" diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index 08a0d6d62b7..a604ea1fdb6 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -61,6 +61,7 @@ "decode_image", "decode_jpeg", "decode_png", + "decode_heic", "decode_webp", "decode_gif", "encode_jpeg", diff --git a/torchvision/io/image.py b/torchvision/io/image.py index e169c0a4f7a..f1df0d52672 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -417,5 +417,31 @@ def _decode_avif( Decoded image (Tensor[image_channels, image_height, image_width]) """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(decode_webp) + _log_api_usage_once(_decode_avif) return torch.ops.image.decode_avif(input, mode.value) + + +def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: + """ + Decode an HEIC image into a 3 dimensional RGB[A] Tensor. + + The values of the output tensor are in uint8 in [0, 255] for most images. If + the image has a bit-depth of more than 8, then the output tensor is uint16 + in [0, 65535]. Since uint16 support is limited in pytorch, we recommend + calling :func:`torchvision.transforms.v2.functional.to_dtype()` with + ``scale=True`` after this function to convert the decoded image into a uint8 + or float tensor. + + Args: + input (Tensor[1]): a one dimensional contiguous uint8 tensor containing + the raw bytes of the HEIC image. + mode (ImageReadMode): The read mode used for optionally + converting the image color space. Default: ``ImageReadMode.UNCHANGED``. + Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``. + + Returns: + Decoded image (Tensor[image_channels, image_height, image_width]) + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(_decode_heic) + return torch.ops.image.decode_heic(input, mode.value)