diff --git a/mapillary_tools/geotag/camm_builder.py b/mapillary_tools/geotag/camm_builder.py
index f9729d77..521bea51 100644
--- a/mapillary_tools/geotag/camm_builder.py
+++ b/mapillary_tools/geotag/camm_builder.py
@@ -1,4 +1,4 @@
-import dataclasses
+import io
 import typing as T
 
 from .. import geo
@@ -228,17 +228,6 @@ def create_camm_trak(
     }
 
 
-@dataclasses.dataclass
-class CAMMPointReader(builder.Reader):
-    __slots__ = ("point",)
-
-    def __init__(self, point: geo.Point):
-        self.point = point
-
-    def read(self):
-        return build_camm_sample(self.point)
-
-
 def extract_points(fp: T.BinaryIO) -> T.Tuple[str, T.List[geo.Point]]:
     start_offset = fp.tell()
     points = camm_parser.extract_points(fp)
@@ -262,8 +251,8 @@ def camm_sample_generator2(points: T.Sequence[geo.Point]):
     def _f(
         fp: T.BinaryIO,
         moov_children: T.List[BoxDict],
-    ) -> T.Generator[builder.Reader, None, None]:
-        movie_timescale = builder._find_movie_timescale(moov_children)
+    ) -> T.Generator[io.IOBase, None, None]:
+        movie_timescale = builder.find_movie_timescale(moov_children)
         # make sure the precision of timedeltas not lower than 0.001 (1ms)
         media_timescale = max(1000, movie_timescale)
         camm_samples = list(convert_points_to_raw_samples(points, media_timescale))
@@ -279,7 +268,7 @@ def _f(
         moov_children.append(camm_trak)
 
         # if yield, the moov_children will not be modified
-        return (CAMMPointReader(point) for point in points)
+        return (io.BytesIO(build_camm_sample(point)) for point in points)
 
     return _f
 
@@ -287,7 +276,7 @@ def _f(
 def camm_sample_generator(
     fp: T.BinaryIO,
     moov_children: T.List[BoxDict],
-) -> T.Iterator[builder.Reader]:
+) -> T.Iterator[io.IOBase]:
     fp.seek(0)
     _, points = extract_points(fp)
     if not points:
diff --git a/mapillary_tools/geotag/io_utils.py b/mapillary_tools/geotag/io_utils.py
new file mode 100644
index 00000000..23edc50d
--- /dev/null
+++ b/mapillary_tools/geotag/io_utils.py
@@ -0,0 +1,178 @@
+import io
+import typing as T
+
+
+class ChainedIO(io.IOBase):
+    # the underlying streams to chain together, in order
+    _streams: T.Sequence[io.IOBase]
+    # the beginning offset of the current stream
+    _begin_offset: int
+    # offset after SEEK_END
+    _offset_after_seek_end: int
+    # index of the current stream
+    _idx: int
+
+    def __init__(self, streams: T.Sequence[io.IOBase]):
+        for s in streams:
+            assert s.readable(), f"stream {s} must be readable"
+            assert s.seekable(), f"stream {s} must be seekable"
+            # required, otherwise inconsistent results when seeking back and forth
+            s.seek(0, io.SEEK_SET)
+        self._streams = streams
+        self._begin_offset = 0
+        self._offset_after_seek_end = 0
+        self._idx = 0
+
+    def _seek_next_stream(self) -> None:
+        """
+        seek to the end of the current stream, and seek to the beginning of the next stream
+        """
+        if self._idx < len(self._streams):
+            s = self._streams[self._idx]
+            ssize = s.seek(0, io.SEEK_END)
+
+            # update index
+            self._idx += 1
+
+            # seek to the beginning of the next stream
+            if self._idx < len(self._streams):
+                self._streams[self._idx].seek(0, io.SEEK_SET)
+
+            # update offset
+            self._begin_offset += ssize
+
+    def read(self, n: int = -1) -> bytes:
+        acc = []
+
+        while self._idx < len(self._streams):
+            data = self._streams[self._idx].read(n)
+            acc.append(data)
+            if n == -1:
+                self._seek_next_stream()
+            elif len(data) < n:
+                n = n - len(data)
+                self._seek_next_stream()
+            else:
+                break
+
+        return b"".join(acc)
+
+    def seekable(self) -> bool:
+        return True
+
+    def writable(self) -> bool:
+        return False
+
+    def readable(self) -> bool:
+        return True
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        if whence == io.SEEK_CUR:
+            if offset < 0:
+                raise ValueError("negative offset not supported yet")
+
+            while self._idx < len(self._streams):
+                s = self._streams[self._idx]
+                co = s.tell()
+                eo = s.seek(0, io.SEEK_END)
+                assert co <= eo
+                if offset <= eo - co:
+                    s.seek(co + offset, io.SEEK_SET)
+                    offset = 0
+                    break
+                self._seek_next_stream()
+                offset = offset - (eo - co)
+
+            if 0 < offset:
+                self._offset_after_seek_end += offset
+
+        elif whence == io.SEEK_SET:
+            self._idx = 0
+            self._begin_offset = 0
+            self._offset_after_seek_end = 0
+            self._streams[self._idx].seek(0, io.SEEK_SET)
+            if offset:
+                self.seek(offset, io.SEEK_CUR)
+
+        elif whence == io.SEEK_END:
+            self._idx = 0
+            self._begin_offset = 0
+            self._offset_after_seek_end = 0
+            while self._idx < len(self._streams):
+                self._seek_next_stream()
+            if offset:
+                self.seek(offset, io.SEEK_CUR)
+
+        else:
+            raise IOError("invalid whence")
+
+        return self.tell()
+
+    def tell(self) -> int:
+        if self._idx < len(self._streams):
+            rel_offset = self._streams[self._idx].tell()
+        else:
+            rel_offset = self._offset_after_seek_end
+
+        return self._begin_offset + rel_offset
+
+    def close(self) -> None:
+        for b in self._streams:
+            b.close()
+        return None
+
+
+class SlicedIO(io.IOBase):
+    __slots__ = ("_source", "_begin_offset", "_rel_offset", "_size")
+
+    _source: T.BinaryIO
+    _begin_offset: int
+    _rel_offset: int
+    _size: int
+
+    def __init__(self, source: T.BinaryIO, offset: int, size: int) -> None:
+        assert source.readable(), "source stream must be readable"
+        assert source.seekable(), "source stream must be seekable"
+        self._source = source
+        if offset < 0:
+            raise ValueError(f"negative offset {offset}")
+        self._begin_offset = offset
+        self._rel_offset = 0
+        self._size = size
+
+    def read(self, n: int = -1) -> bytes:
+        if self._rel_offset < self._size:
+            self._source.seek(self._begin_offset + self._rel_offset, io.SEEK_SET)
+            remaining = self._size - self._rel_offset
+            max_read = remaining if n == -1 else min(n, remaining)
+            data = self._source.read(max_read)
+            self._rel_offset += len(data)
+            return data
+        else:
+            return b""
+
+    def seekable(self) -> bool:
+        return True
+
+    def writable(self) -> bool:
+        return False
+
+    def readable(self) -> bool:
+        return True
+
+    def tell(self) -> int:
+        return self._rel_offset
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        if whence == io.SEEK_SET:
+            new_offset = offset
+        elif whence == io.SEEK_CUR:
+            new_offset = self._rel_offset + offset
+        elif whence == io.SEEK_END:
+            new_offset = self._size + offset
+        else:
+            raise IOError("invalid whence")
+        if new_offset < 0:
+            raise ValueError(f"negative seek value {new_offset}")
+        self._rel_offset = new_offset
+        return self._rel_offset
diff --git a/mapillary_tools/geotag/simple_mp4_builder.py b/mapillary_tools/geotag/simple_mp4_builder.py
index acb20a5b..ccf1c8a2 100644
--- a/mapillary_tools/geotag/simple_mp4_builder.py
+++ b/mapillary_tools/geotag/simple_mp4_builder.py
@@ -11,7 +11,7 @@
 
 import construct as C
 
-from . import simple_mp4_parser as parser
+from . import io_utils, simple_mp4_parser as parser
 from .simple_mp4_parser import (
     ChunkLargeOffsetBox,
     ChunkOffsetBox,
@@ -453,7 +453,7 @@ def _filter_moov_children_boxes(
             yield box
 
 
-def _find_movie_timescale(moov_children: T.Sequence[BoxDict]) -> int:
+def find_movie_timescale(moov_children: T.Sequence[BoxDict]) -> int:
     mvhd = _find_box_at_pathx(moov_children, [b"mvhd"])
     return T.cast(T.Dict, mvhd["data"])["timescale"]
 
@@ -467,29 +467,10 @@ def _build_moov_bytes(moov_children: T.Sequence[BoxDict]) -> bytes:
     )
 
 
-class Reader:
-    def read(self):
-        raise NotImplementedError
-
-
-@dataclasses.dataclass
-class SampleReader(Reader):
-    __slots__ = ("fp", "offset", "size")
-
-    fp: T.BinaryIO
-    offset: int
-    size: int
-
-    def read(self):
-        self.fp.seek(self.offset)
-        return self.fp.read(self.size)
-
-
 def transform_mp4(
     src_fp: T.BinaryIO,
-    target_fp: T.BinaryIO,
-    sample_generator: T.Callable[[T.BinaryIO, T.List[BoxDict]], T.Iterator[Reader]],
-):
+    sample_generator: T.Callable[[T.BinaryIO, T.List[BoxDict]], T.Iterator[io.IOBase]],
+) -> io_utils.ChainedIO:
     # extract ftyp
     src_fp.seek(0)
     source_ftyp_box_data = parser.parse_data_firstx(src_fp, [b"ftyp"])
@@ -507,24 +488,28 @@ def transform_mp4(
 
     # extract video samples
     source_samples = list(iterate_samples(moov_children))
-    movie_sample_readers = (
-        SampleReader(src_fp, sample.offset, sample.size) for sample in source_samples
-    )
-
-    sample_readers = itertools.chain(
-        movie_sample_readers, sample_generator(src_fp, moov_children)
-    )
+    movie_sample_readers = [
+        io_utils.SlicedIO(src_fp, sample.offset, sample.size)
+        for sample in source_samples
+    ]
+    sample_readers = list(sample_generator(src_fp, moov_children))
 
     _update_all_trak_tkhd(moov_children)
 
     # moov_boxes should be immutable since here
-    target_fp.write(source_ftyp_data)
-    target_fp.write(rewrite_moov(target_fp.tell(), moov_children))
     mdat_body_size = sum(sample.size for sample in iterate_samples(moov_children))
-    write_mdat(target_fp, mdat_body_size, sample_readers)
+    return io_utils.ChainedIO(
+        [
+            io.BytesIO(source_ftyp_data),
+            io.BytesIO(_rewrite_moov(len(source_ftyp_data), moov_children)),
+            io.BytesIO(_build_mdat_header_bytes(mdat_body_size)),
+            *movie_sample_readers,
+            *sample_readers,
+        ]
+    )
 
 
-def rewrite_moov(moov_offset: int, moov_boxes: T.Sequence[BoxDict]) -> bytes:
+def _rewrite_moov(moov_offset: int, moov_boxes: T.Sequence[BoxDict]) -> bytes:
     # build moov for calculating moov size
     sample_offset = 0
     for box in _filter_trak_boxes(moov_boxes):
@@ -544,10 +529,3 @@ def rewrite_moov(moov_offset: int, moov_boxes: T.Sequence[BoxDict]) -> bytes:
     assert len(moov_data) == moov_data_size, f"{len(moov_data)} != {moov_data_size}"
 
     return moov_data
-
-
-def write_mdat(fp: T.BinaryIO, mdat_body_size: int, sample_readers: T.Iterable[Reader]):
-    mdat_header = _build_mdat_header_bytes(mdat_body_size)
-    fp.write(mdat_header)
-    for reader in sample_readers:
-        fp.write(reader.read())
diff --git a/mapillary_tools/process_import_meta_properties.py b/mapillary_tools/process_import_meta_properties.py
index 0e4dfb1a..4af2716b 100644
--- a/mapillary_tools/process_import_meta_properties.py
+++ b/mapillary_tools/process_import_meta_properties.py
@@ -1,5 +1,5 @@
-import time
 import os
+import time
 import typing as T
 
 from . import exceptions, types, VERSION
diff --git a/tests/integration/test_gopro.py b/tests/integration/test_gopro.py
index 4ea50c3e..ebf33a19 100644
--- a/tests/integration/test_gopro.py
+++ b/tests/integration/test_gopro.py
@@ -1,5 +1,5 @@
-import os
 import json
+import os
 import subprocess
 import typing as T
 from pathlib import Path
diff --git a/tests/unit/test_camm_parser.py b/tests/unit/test_camm_parser.py
index 0b96bb9a..cf8282ef 100644
--- a/tests/unit/test_camm_parser.py
+++ b/tests/unit/test_camm_parser.py
@@ -52,13 +52,11 @@ def build_mp4(points: T.List[geo.Point]) -> T.Optional[T.List[geo.Point]]:
         {"type": b"moov", "data": [mvhd]},
     ]
     src = simple_mp4_builder.QuickBoxStruct32.BoxList.build(empty_mp4)
-    target_fp = io.BytesIO()
-    simple_mp4_builder.transform_mp4(
-        io.BytesIO(src), target_fp, camm_builder.camm_sample_generator2(points)
+    target_fp = simple_mp4_builder.transform_mp4(
+        io.BytesIO(src), camm_builder.camm_sample_generator2(points)
     )
-    target_fp.seek(0)
-    return camm_parser.extract_points(target_fp)
+    return camm_parser.extract_points(T.cast(T.BinaryIO, target_fp))
 
 
 def approximate(expected, actual):
diff --git a/tests/unit/test_io_utils.py b/tests/unit/test_io_utils.py
new file mode 100644
index 00000000..1f93c28a
--- /dev/null
+++ b/tests/unit/test_io_utils.py
@@ -0,0 +1,123 @@
+import io
+import random
+
+from mapillary_tools.geotag.io_utils import ChainedIO, SlicedIO
+
+
+def test_chained():
+    data = b"helloworldworldfoobarworld"
+    c = io.BytesIO(data)
+    s = ChainedIO(
+        [
+            io.BytesIO(b"hello"),
+            ChainedIO([io.BytesIO(b"world")]),
+            ChainedIO(
+                [
+                    ChainedIO([io.BytesIO(b""), io.BytesIO(b"")]),
+                    io.BytesIO(b"world"),
+                    io.BytesIO(b"foo"),
+                    ChainedIO([io.BytesIO(b"")]),
+                ]
+            ),
+            ChainedIO([io.BytesIO(b"")]),
+            ChainedIO([io.BytesIO(b"bar")]),
+            ChainedIO(
+                [
+                    SlicedIO(io.BytesIO(data), 5, 5),
+                    ChainedIO([io.BytesIO(b"")]),
+                ]
+            ),
+        ]
+    )
+
+    assert s.seek(0) == 0
+    assert c.seek(0) == 0
+    assert s.read() == c.read()
+
+    assert s.seek(2, io.SEEK_CUR) == len(data) + 2
+    assert c.seek(2, io.SEEK_CUR) == len(data) + 2
+    assert s.read() == c.read()
+
+    assert s.seek(6) == 6
+    assert c.seek(6) == 6
+    assert s.read() == c.read()
+
+    assert s.seek(2, io.SEEK_END) == len(data) + 2
+    assert c.seek(2, io.SEEK_END) == len(data) + 2
+    assert s.read() == c.read()
+
+    assert s.seek(0) == 0
+    assert c.seek(0) == 0
+    assert s.read(1) == b"h"
+    assert s.read(1000) == data[1:]
+    assert s.read() == b""
+    assert s.read(1) == b""
+
+    assert s.seek(0, io.SEEK_END) == len(data)
+    assert c.seek(0, io.SEEK_END) == len(data)
+
+    c.seek(0)
+    s.seek(0)
+    for _ in range(10000):
+        whence = random.choice([io.SEEK_SET, io.SEEK_CUR, io.SEEK_END])
+        offset = random.randint(0, 30)
+        assert s.tell() == c.tell()
+ thrown_x = None + try: + x = s.seek(offset, whence) + except ValueError as ex: + thrown_x = ex + thrown_y = None + try: + y = c.seek(offset, whence) + except ValueError as ex: + thrown_y = ex + assert (thrown_x is not None and thrown_y is not None) or ( + thrown_x is None and thrown_y is None + ), (thrown_x, thrown_y, whence, offset) + if not thrown_x: + assert ( + x == y + ), f"whence={whence} offset={offset} x={x} y={y} {s.tell()} {c.tell()}" + + n = random.randint(-1, 20) + assert s.read(n) == c.read(n), f"n={n}" + assert s.tell() == c.tell() + + +def test_sliced(): + s = io.BytesIO(b"helloworldfoo") + sliced = SlicedIO(s, 5, 5) + c = io.BytesIO(b"world") + + for _ in range(10000): + whence = random.choice([io.SEEK_SET, io.SEEK_CUR, io.SEEK_END]) + offset = random.randint(-10, 10) + thrown_x = None + try: + x = sliced.seek(offset, whence) + except ValueError as ex: + thrown_x = ex + thrown_y = None + try: + y = c.seek(offset, whence) + except ValueError as ex: + thrown_y = ex + assert (thrown_x is not None and thrown_y is not None) or ( + thrown_x is None and thrown_y is None + ), (thrown_x, thrown_y, whence, offset) + if not thrown_x: + assert x == y + + n = random.randint(-1, 20) + assert sliced.read(n) == c.read(n) + assert sliced.tell() == c.tell() + + +def test_truncate(): + c = io.BytesIO(b"helloworld") + c.truncate(3) + assert c.read() == b"hel" + s = SlicedIO(c, 1, 5) + assert s.read() == b"el" + assert s.read() == b""
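
Note for reviewers: a minimal sketch of the semantics of the new io_utils primitives, restating what the unit tests above exercise. It uses only APIs introduced by this patch.

import io

from mapillary_tools.geotag.io_utils import ChainedIO, SlicedIO

# SlicedIO exposes bytes [5, 10) of the source as an independent,
# read-only, seekable stream; no bytes are copied up front
source = io.BytesIO(b"helloworld")
world = SlicedIO(source, 5, 5)

# ChainedIO presents a list of seekable streams as one logical stream
chained = ChainedIO([io.BytesIO(b"hello"), world])
assert chained.read() == b"helloworld"

# seeks are delegated to whichever member stream owns the target offset
assert chained.seek(5) == 5
assert chained.read(3) == b"wor"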
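Since transform_mp4 now returns a stream instead of writing into a caller-supplied target_fp, the caller decides where, and how incrementally, the bytes go. A sketch of the intended consumption, assuming a source video that contains CAMM points; the paths in.mp4 and out.mp4 are hypothetical:

import shutil

from mapillary_tools.geotag import camm_builder, simple_mp4_builder

with open("in.mp4", "rb") as src_fp, open("out.mp4", "wb") as dst_fp:
    # the returned ChainedIO serves the original samples through SlicedIO
    # views over src_fp, so src_fp must remain open while copying
    stream = simple_mp4_builder.transform_mp4(
        src_fp, camm_builder.camm_sample_generator
    )
    # copy in bounded chunks; only the rewritten ftyp/moov/mdat headers and
    # one chunk are materialized in memory at a time
    shutil.copyfileobj(stream, dst_fp, length=1024 * 1024)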