Skip to content

Commit

Permalink
Add HEIC decoder (#8597)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored Aug 30, 2024
1 parent a59c939 commit c10f938
Show file tree
Hide file tree
Showing 10 changed files with 274 additions and 15 deletions.
17 changes: 17 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_HEIC = os.getenv("TORCHVISION_USE_HEIC", "0") == "1" # TODO enable by default!
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
Expand Down Expand Up @@ -50,6 +51,7 @@
print(f"{USE_PNG = }")
print(f"{USE_JPEG = }")
print(f"{USE_WEBP = }")
print(f"{USE_HEIC = }")
print(f"{USE_AVIF = }")
print(f"{USE_NVJPEG = }")
print(f"{NVCC_FLAGS = }")
Expand Down Expand Up @@ -334,6 +336,21 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without WEBP support")

if USE_HEIC:
heic_found, heic_include_dir, heic_library_dir = find_library(header="libheif/heif.h")
if heic_found:
print("Building torchvision with HEIC support")
print(f"{heic_include_dir = }")
print(f"{heic_library_dir = }")
if heic_include_dir is not None and heic_library_dir is not None:
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
include_dirs.append(heic_include_dir)
library_dirs.append(heic_library_dir)
libraries.append("heif")
define_macros += [("HEIC_FOUND", 1)]
else:
warnings.warn("Building torchvision without HEIC support")

if USE_AVIF:
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
if avif_found:
Expand Down
Binary file not shown.
62 changes: 48 additions & 14 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_decode_avif,
_decode_heic,
decode_gif,
decode_image,
decode_jpeg,
Expand Down Expand Up @@ -928,11 +929,10 @@ def test_decode_avif(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.xfail(reason="AVIF and HEIC support not enabled yet.")
# Note: decode_image fails because some of these files have a (valid) signature
# we don't recognize. We should probably use libmagic....
# @pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("decode_fun", (_decode_avif,))
@pytest.mark.parametrize("decode_fun", (_decode_avif, _decode_heic))
@pytest.mark.parametrize("scripted", (False, True))
@pytest.mark.parametrize(
"mode, pil_mode",
Expand All @@ -942,7 +942,9 @@ def test_decode_avif(decode_fun, scripted):
(ImageReadMode.UNCHANGED, None),
),
)
@pytest.mark.parametrize("filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"))
@pytest.mark.parametrize(
"filename", Path("/home/nicolashug/dev/libavif/tests/data/").glob("*.avif"), ids=lambda p: p.name
)
def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename):
if "reversed_dimg_order" in str(filename):
# Pillow properly decodes this one, but we don't (order of parts of the
Expand All @@ -960,7 +962,14 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
except RuntimeError as e:
if any(
s in str(e)
for s in ("BMFF parsing failed", "avifDecoderParse failed: ", "file contains more than one image")
for s in (
"BMFF parsing failed",
"avifDecoderParse failed: ",
"file contains more than one image",
"no 'ispe' property",
"'iref' has double references",
"Invalid image grid",
)
):
pytest.skip(reason="Expected failure, that's OK")
else:
Expand All @@ -970,22 +979,47 @@ def test_decode_avif_against_pil(decode_fun, scripted, mode, pil_mode, filename)
assert img.shape[0] == 3
if mode == ImageReadMode.RGB_ALPHA:
assert img.shape[0] == 4

if img.dtype == torch.uint16:
img = F.to_dtype(img, dtype=torch.uint8, scale=True)
try:
from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
except RuntimeError as e:
if "Invalid image grid" in str(e):
pytest.skip(reason="PIL failure")
else:
raise e

from_pil = F.pil_to_tensor(Image.open(filename).convert(pil_mode))
if False:
if True:
from torchvision.utils import make_grid

g = make_grid([img, from_pil])
F.to_pil_image(g).save((f"/home/nicolashug/out_images/{filename.name}.{pil_mode}.png"))
if mode != ImageReadMode.RGB:
# We don't compare against PIL for RGB because results look pretty
# different on RGBA images (other images are fine). The result on
# torchvision basically just plainly ignores the alpha channel, resuting
# in transparent pixels looking dark. PIL seems to be using a sort of
# k-nn thing, looking at the output. Take a look at the resuting images.
torch.testing.assert_close(img, from_pil, rtol=0, atol=3)

is__decode_heic = getattr(decode_fun, "__name__", getattr(decode_fun, "name", None)) == "_decode_heic"
if mode == ImageReadMode.RGB and not is__decode_heic:
# We don't compare torchvision's AVIF against PIL for RGB because
# results look pretty different on RGBA images (other images are fine).
# The result on torchvision basically just plainly ignores the alpha
# channel, resuting in transparent pixels looking dark. PIL seems to be
# using a sort of k-nn thing (Take a look at the resuting images)
return
if filename.name == "sofa_grid1x5_420.avif" and is__decode_heic:
return

torch.testing.assert_close(img, from_pil, rtol=0, atol=3)


@pytest.mark.xfail(reason="HEIC support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_heic, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_heic(decode_fun, scripted):
encoded_bytes = read_file(next(get_images(FAKEDATA_DIR, ".heic")))
if scripted:
decode_fun = torch.jit.script(decode_fun)
img = decode_fun(encoded_bytes)
assert img.shape == (3, 100, 100)
assert img[None].is_contiguous(memory_format=torch.channels_last)


if __name__ == "__main__":
Expand Down
152 changes: 152 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_heic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#include "decode_heic.h"

#if HEIC_FOUND
#include "libheif/heif_cxx.h"
#endif // HEIC_FOUND

namespace vision {
namespace image {

#if !HEIC_FOUND
torch::Tensor decode_heic(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(
false, "decode_heic: torchvision not compiled with libheif support");
}
#else

torch::Tensor decode_heic(
const torch::Tensor& encoded_data,
ImageReadMode mode) {
TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
TORCH_CHECK(
encoded_data.dtype() == torch::kU8,
"Input tensor must have uint8 data type, got ",
encoded_data.dtype());
TORCH_CHECK(
encoded_data.dim() == 1,
"Input tensor must be 1-dimensional, got ",
encoded_data.dim(),
" dims.");

if (mode != IMAGE_READ_MODE_UNCHANGED && mode != IMAGE_READ_MODE_RGB &&
mode != IMAGE_READ_MODE_RGB_ALPHA) {
// Other modes aren't supported, but we don't error or even warn because we
// have generic entry points like decode_image which may support all modes,
// it just depends on the underlying decoder.
mode = IMAGE_READ_MODE_UNCHANGED;
}

// If return_rgb is false it means we return rgba - nothing else.
auto return_rgb = true;

int height = 0;
int width = 0;
int num_channels = 0;
int stride = 0;
uint8_t* decoded_data = nullptr;
heif::Image img;
int bit_depth = 0;

try {
heif::Context ctx;
ctx.read_from_memory_without_copy(
encoded_data.data_ptr<uint8_t>(), encoded_data.numel());

// TODO properly error on (or support) image sequences. Right now, I think
// this function will always return the first image in a sequence, which is
// inconsistent with decode_gif (which returns a batch) and with decode_avif
// (which errors loudly).
// Why? I'm struggling to make sense of
// ctx.get_number_of_top_level_images(). It disagrees with libavif's
// imageCount. For example on some of the libavif test images:
//
// - colors-animated-12bpc-keyframes-0-2-3.avif
// avif num images = 5
// heif num images = 1 // Why is this 1 when clearly this is supposed to
// be a sequence?
// - sofa_grid1x5_420.avif
// avif num images = 1
// heif num images = 6 // If we were to error here we won't be able to
// decode this image which is otherwise properly
// decoded by libavif.
// I can't find a libheif function that does what we need here, or at least
// that agrees with libavif.

// TORCH_CHECK(
// ctx.get_number_of_top_level_images() == 1,
// "heic file contains more than one image");

heif::ImageHandle handle = ctx.get_primary_image_handle();
bit_depth = handle.get_luma_bits_per_pixel();

return_rgb =
(mode == IMAGE_READ_MODE_RGB ||
(mode == IMAGE_READ_MODE_UNCHANGED && !handle.has_alpha_channel()));

height = handle.get_height();
width = handle.get_width();

num_channels = return_rgb ? 3 : 4;
heif_chroma chroma;
if (bit_depth == 8) {
chroma = return_rgb ? heif_chroma_interleaved_RGB
: heif_chroma_interleaved_RGBA;
} else {
// TODO: This, along with our 10bits -> 16bits range mapping down below,
// may not work on BE platforms
chroma = return_rgb ? heif_chroma_interleaved_RRGGBB_LE
: heif_chroma_interleaved_RRGGBBAA_LE;
}

img = handle.decode_image(heif_colorspace_RGB, chroma);

decoded_data = img.get_plane(heif_channel_interleaved, &stride);
} catch (const heif::Error& err) {
// We need this try/catch block and call TORCH_CHECK, because libheif may
// otherwise throw heif::Error that would just be reported as "An unknown
// exception occurred" when we move back to Python.
TORCH_CHECK(false, "decode_heif failed: ", err.get_message());
}
TORCH_CHECK(decoded_data != nullptr, "Something went wrong during decoding.");

auto dtype = (bit_depth == 8) ? torch::kUInt8 : at::kUInt16;
auto out = torch::empty({height, width, num_channels}, dtype);
uint8_t* out_ptr = (uint8_t*)out.data_ptr();

// decoded_data is *almost* the raw decoded data, but not quite: for some
// images, there may be some padding at the end of each row, i.e. when stride
// != row_size_in_bytes. So we can't copy decoded_data into the tensor's
// memory directly, we have to copy row by row. Oh, and if you think you can
// take a shortcut when stride == row_size_in_bytes and just do:
// out = torch::from_blob(decoded_data, ...)
// you can't, because decoded_data is owned by the heif::Image object and it
// gets freed when it gets out of scope!
auto row_size_in_bytes = width * num_channels * ((bit_depth == 8) ? 1 : 2);
for (auto h = 0; h < height; h++) {
memcpy(
out_ptr + h * row_size_in_bytes,
decoded_data + h * stride,
row_size_in_bytes);
}
if (bit_depth > 8) {
// Say bit depth is 10. decodec_data and out_ptr contain 10bits values
// over 2 bytes, stored into uint16_t. In torchvision a uint16 value is
// expected to be in [0, 2**16), so we have to map the 10bits value to that
// range. Note that other libraries like libavif do that mapping
// automatically.
// TODO: It's possible to avoid the memcpy call above in this case, and do
// the copy at the same time as the conversation. Whether it's worth it
// should be benchmarked.
auto out_ptr_16 = (uint16_t*)out_ptr;
for (auto p = 0; p < height * width * num_channels; p++) {
out_ptr_16[p] <<= (16 - bit_depth);
}
}
return out.permute({2, 0, 1});
}
#endif // HEIC_FOUND

} // namespace image
} // namespace vision
14 changes: 14 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_heic.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_heic(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);

} // namespace image
} // namespace vision
12 changes: 12 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "decode_avif.h"
#include "decode_gif.h"
#include "decode_heic.h"
#include "decode_jpeg.h"
#include "decode_png.h"
#include "decode_webp.h"
Expand Down Expand Up @@ -61,6 +62,17 @@ torch::Tensor decode_image(
return decode_avif(data, mode);
}

// Similarly for heic we assume the signature is "ftypeheic" but some files
// may come as "ftypmif1" where the "heic" part is defined later in the file.
// We can't be re-inventing libmagic here. We might need to start relying on
// it though...
const uint8_t heic_signature[8] = {
0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63}; // == "ftypheic"
TORCH_CHECK(data.numel() >= 12, err_msg);
if ((memcmp(heic_signature, datap + 4, 8) == 0)) {
return decode_heic(data, mode);
}

const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
const uint8_t webp_signature_end[7] = {
0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8"
Expand Down
2 changes: 2 additions & 0 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ static auto registry =
&decode_jpeg)
.op("image::decode_webp(Tensor encoded_data, int mode) -> Tensor",
&decode_webp)
.op("image::decode_heic(Tensor encoded_data, int mode) -> Tensor",
&decode_heic)
.op("image::decode_avif(Tensor encoded_data, int mode) -> Tensor",
&decode_avif)
.op("image::encode_jpeg", &encode_jpeg)
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "cpu/decode_avif.h"
#include "cpu/decode_gif.h"
#include "cpu/decode_heic.h"
#include "cpu/decode_image.h"
#include "cpu/decode_jpeg.h"
#include "cpu/decode_png.h"
Expand Down
1 change: 1 addition & 0 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
"decode_image",
"decode_jpeg",
"decode_png",
"decode_heic",
"decode_webp",
"decode_gif",
"encode_jpeg",
Expand Down
28 changes: 27 additions & 1 deletion torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,5 +417,31 @@ def _decode_avif(
Decoded image (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_webp)
_log_api_usage_once(_decode_avif)
return torch.ops.image.decode_avif(input, mode.value)


def _decode_heic(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
"""
Decode an HEIC image into a 3 dimensional RGB[A] Tensor.
The values of the output tensor are in uint8 in [0, 255] for most images. If
the image has a bit-depth of more than 8, then the output tensor is uint16
in [0, 65535]. Since uint16 support is limited in pytorch, we recommend
calling :func:`torchvision.transforms.v2.functional.to_dtype()` with
``scale=True`` after this function to convert the decoded image into a uint8
or float tensor.
Args:
input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
the raw bytes of the HEIC image.
mode (ImageReadMode): The read mode used for optionally
converting the image color space. Default: ``ImageReadMode.UNCHANGED``.
Other supported values are ``ImageReadMode.RGB`` and ``ImageReadMode.RGB_ALPHA``.
Returns:
Decoded image (Tensor[image_channels, image_height, image_width])
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(_decode_heic)
return torch.ops.image.decode_heic(input, mode.value)

0 comments on commit c10f938

Please sign in to comment.