feat: streamable mp4 transform #550

Merged · 5 commits · Sep 22, 2022
21 changes: 5 additions & 16 deletions mapillary_tools/geotag/camm_builder.py
@@ -1,4 +1,4 @@
import dataclasses
import io
import typing as T

from .. import geo
@@ -228,17 +228,6 @@ def create_camm_trak(
}


@dataclasses.dataclass
class CAMMPointReader(builder.Reader):
__slots__ = ("point",)

def __init__(self, point: geo.Point):
self.point = point

def read(self):
return build_camm_sample(self.point)


def extract_points(fp: T.BinaryIO) -> T.Tuple[str, T.List[geo.Point]]:
start_offset = fp.tell()
points = camm_parser.extract_points(fp)
@@ -262,8 +251,8 @@ def camm_sample_generator2(points: T.Sequence[geo.Point]):
def _f(
fp: T.BinaryIO,
moov_children: T.List[BoxDict],
) -> T.Generator[builder.Reader, None, None]:
movie_timescale = builder._find_movie_timescale(moov_children)
) -> T.Generator[io.IOBase, None, None]:
movie_timescale = builder.find_movie_timescale(moov_children)
# make sure the precision of timedeltas not lower than 0.001 (1ms)
media_timescale = max(1000, movie_timescale)
camm_samples = list(convert_points_to_raw_samples(points, media_timescale))
@@ -279,15 +268,15 @@ def _f(
moov_children.append(camm_trak)

# if yield, the moov_children will not be modified
return (CAMMPointReader(point) for point in points)
return (io.BytesIO(build_camm_sample(point)) for point in points)

return _f


def camm_sample_generator(
fp: T.BinaryIO,
moov_children: T.List[BoxDict],
) -> T.Iterator[builder.Reader]:
) -> T.Iterator[io.IOBase]:
fp.seek(0)
_, points = extract_points(fp)
if not points:
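
The diff above removes the custom CAMMPointReader wrapper: the sample generator now yields plain io.BytesIO objects, so each CAMM sample produced by build_camm_sample(point) is exposed as a small in-memory stream that the io_utils module introduced below can chain with other streams. A minimal sketch of what a consumer of such a generator sees, using placeholder payloads instead of real CAMM samples:

import io
import typing as T

def read_all(samples: T.Iterator[io.IOBase]) -> bytes:
    # Each yielded item is a readable stream; read() returns the raw sample bytes.
    return b"".join(s.read() for s in samples)

# Placeholder payloads standing in for build_camm_sample(point) output.
fake_samples = (io.BytesIO(payload) for payload in (b"\x01\x02", b"\x03\x04"))
assert read_all(fake_samples) == b"\x01\x02\x03\x04"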
178 changes: 178 additions & 0 deletions mapillary_tools/geotag/io_utils.py
@@ -0,0 +1,178 @@
import io
import typing as T


class ChainedIO(io.IOBase):
    # the streams to be chained and read in order
_streams: T.Sequence[io.IOBase]
# the beginning offset of the current stream
_begin_offset: int
# offset after SEEK_END
_offset_after_seek_end: int
# index of the current stream
_idx: int

def __init__(self, streams: T.Sequence[io.IOBase]):
for s in streams:
assert s.readable(), f"stream {s} must be readable"
assert s.seekable(), f"stream {s} must be seekable"
# required, otherwise inconsistent results when seeking back and forth
s.seek(0, io.SEEK_SET)
self._streams = streams
self._begin_offset = 0
self._offset_after_seek_end = 0
self._idx = 0

def _seek_next_stream(self) -> None:
"""
seek to the end of the current stream, and seek to the beginning of the next stream
"""
if self._idx < len(self._streams):
s = self._streams[self._idx]
ssize = s.seek(0, io.SEEK_END)

# update index
self._idx += 1

# seek to the beginning of the next stream
if self._idx < len(self._streams):
self._streams[self._idx].seek(0, io.SEEK_SET)

# update offset
self._begin_offset += ssize

def read(self, n: int = -1) -> bytes:
acc = []

while self._idx < len(self._streams):
data = self._streams[self._idx].read(n)
acc.append(data)
if n == -1:
self._seek_next_stream()
elif len(data) < n:
n = n - len(data)
self._seek_next_stream()
else:
break

return b"".join(acc)

def seekable(self) -> bool:
return True

def writable(self) -> bool:
return False

def readable(self) -> bool:
return True

def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_CUR:
if offset < 0:
raise ValueError("negative offset not supported yet")

while self._idx < len(self._streams):
s = self._streams[self._idx]
co = s.tell()
eo = s.seek(0, io.SEEK_END)
assert co <= eo
if offset <= eo - co:
s.seek(co + offset, io.SEEK_SET)
offset = 0
break
self._seek_next_stream()
offset = offset - (eo - co)

if 0 < offset:
self._offset_after_seek_end += offset

elif whence == io.SEEK_SET:
self._idx = 0
self._begin_offset = 0
self._offset_after_seek_end = 0
self._streams[self._idx].seek(0, io.SEEK_SET)
if offset:
self.seek(offset, io.SEEK_CUR)

elif whence == io.SEEK_END:
self._idx = 0
self._begin_offset = 0
self._offset_after_seek_end = 0
while self._idx < len(self._streams):
self._seek_next_stream()
if offset:
self.seek(offset, io.SEEK_CUR)

else:
raise IOError("invalid whence")

return self.tell()

def tell(self) -> int:
if self._idx < len(self._streams):
rel_offset = self._streams[self._idx].tell()
else:
rel_offset = self._offset_after_seek_end

return self._begin_offset + rel_offset

def close(self) -> None:
for b in self._streams:
b.close()
return None


class SlicedIO(io.IOBase):
__slots__ = ("_source", "_begin_offset", "_rel_offset", "_size")

_source: T.BinaryIO
_begin_offset: int
_rel_offset: int
_size: int

def __init__(self, source: T.BinaryIO, offset: int, size: int) -> None:
assert source.readable(), "source stream must be readable"
assert source.seekable(), "source stream must be seekable"
self._source = source
if offset < 0:
raise ValueError(f"negative offset {offset}")
self._begin_offset = offset
self._rel_offset = 0
self._size = size

def read(self, n: int = -1) -> bytes:
if self._rel_offset < self._size:
self._source.seek(self._begin_offset + self._rel_offset, io.SEEK_SET)
remaining = self._size - self._rel_offset
max_read = remaining if n == -1 else min(n, remaining)
data = self._source.read(max_read)
self._rel_offset += len(data)
return data
else:
return b""

def seekable(self) -> bool:
return True

def writable(self) -> bool:
return False

def readable(self) -> bool:
return True

def tell(self) -> int:
return self._rel_offset

def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if whence == io.SEEK_SET:
new_offset = offset
if new_offset < 0:
raise ValueError(f"negative seek value {new_offset}")
elif whence == io.SEEK_CUR:
new_offset = max(0, self._rel_offset + offset)
elif whence == io.SEEK_END:
new_offset = max(0, self._size + offset)
else:
raise IOError("invalid whence")
self._rel_offset = new_offset
return self._rel_offset
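
The new io_utils module is the core of the streamable transform: ChainedIO presents a sequence of readable, seekable streams as one logical read-only stream, and SlicedIO exposes a byte range of a larger stream without copying it. A minimal usage sketch with illustrative values (the classes are the ones added above; the byte strings are made up):

import io

from mapillary_tools.geotag.io_utils import ChainedIO, SlicedIO

source = io.BytesIO(b"0123456789")

# A 4-byte window into `source` starting at offset 3, i.e. b"3456".
window = SlicedIO(source, 3, 4)

# Lazily concatenate in-memory streams with the sliced view.
chained = ChainedIO([io.BytesIO(b"head-"), window, io.BytesIO(b"-tail")])
assert chained.read() == b"head-3456-tail"

# The chained stream is seekable as a whole.
chained.seek(5, io.SEEK_SET)
assert chained.read(4) == b"3456"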
60 changes: 19 additions & 41 deletions mapillary_tools/geotag/simple_mp4_builder.py
@@ -11,7 +11,7 @@

import construct as C

from . import simple_mp4_parser as parser
from . import io_utils, simple_mp4_parser as parser
from .simple_mp4_parser import (
ChunkLargeOffsetBox,
ChunkOffsetBox,
@@ -453,7 +453,7 @@ def _filter_moov_children_boxes(
yield box


def _find_movie_timescale(moov_children: T.Sequence[BoxDict]) -> int:
def find_movie_timescale(moov_children: T.Sequence[BoxDict]) -> int:
mvhd = _find_box_at_pathx(moov_children, [b"mvhd"])
return T.cast(T.Dict, mvhd["data"])["timescale"]

@@ -467,29 +467,10 @@ def _build_moov_bytes(moov_children: T.Sequence[BoxDict]) -> bytes:
)


class Reader:
def read(self):
raise NotImplementedError


@dataclasses.dataclass
class SampleReader(Reader):
__slots__ = ("fp", "offset", "size")

fp: T.BinaryIO
offset: int
size: int

def read(self):
self.fp.seek(self.offset)
return self.fp.read(self.size)


def transform_mp4(
src_fp: T.BinaryIO,
target_fp: T.BinaryIO,
sample_generator: T.Callable[[T.BinaryIO, T.List[BoxDict]], T.Iterator[Reader]],
):
sample_generator: T.Callable[[T.BinaryIO, T.List[BoxDict]], T.Iterator[io.IOBase]],
) -> io_utils.ChainedIO:
# extract ftyp
src_fp.seek(0)
source_ftyp_box_data = parser.parse_data_firstx(src_fp, [b"ftyp"])
@@ -507,24 +488,28 @@ def transform_mp4(

# extract video samples
source_samples = list(iterate_samples(moov_children))
movie_sample_readers = (
SampleReader(src_fp, sample.offset, sample.size) for sample in source_samples
)

sample_readers = itertools.chain(
movie_sample_readers, sample_generator(src_fp, moov_children)
)
movie_sample_readers = [
io_utils.SlicedIO(src_fp, sample.offset, sample.size)
for sample in source_samples
]
sample_readers = list(sample_generator(src_fp, moov_children))

_update_all_trak_tkhd(moov_children)

# moov_boxes should be immutable since here
target_fp.write(source_ftyp_data)
target_fp.write(rewrite_moov(target_fp.tell(), moov_children))
mdat_body_size = sum(sample.size for sample in iterate_samples(moov_children))
write_mdat(target_fp, mdat_body_size, sample_readers)
return io_utils.ChainedIO(
[
io.BytesIO(source_ftyp_data),
io.BytesIO(_rewrite_moov(len(source_ftyp_data), moov_children)),
io.BytesIO(_build_mdat_header_bytes(mdat_body_size)),
*movie_sample_readers,
*sample_readers,
]
)


def rewrite_moov(moov_offset: int, moov_boxes: T.Sequence[BoxDict]) -> bytes:
def _rewrite_moov(moov_offset: int, moov_boxes: T.Sequence[BoxDict]) -> bytes:
# build moov for calculating moov size
sample_offset = 0
for box in _filter_trak_boxes(moov_boxes):
@@ -544,10 +529,3 @@ def rewrite_moov(moov_offset: int, moov_boxes: T.Sequence[BoxDict]) -> bytes:
assert len(moov_data) == moov_data_size, f"{len(moov_data)} != {moov_data_size}"

return moov_data


def write_mdat(fp: T.BinaryIO, mdat_body_size: int, sample_readers: T.Iterable[Reader]):
mdat_header = _build_mdat_header_bytes(mdat_body_size)
fp.write(mdat_header)
for reader in sample_readers:
fp.write(reader.read())
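
With this change transform_mp4 no longer writes into a target file object: it returns a ChainedIO composed of the ftyp box, the rewritten moov, the mdat header, SlicedIO views over the original media samples, and the generated CAMM samples, so the transformed MP4 can be consumed as a stream (for example during upload) without materializing the whole file. A sketch of how a caller might persist the result; write_camm_mp4 and the file paths are hypothetical, while transform_mp4 and camm_sample_generator2 are the functions touched by this PR:

import shutil
import typing as T

from mapillary_tools import geo
from mapillary_tools.geotag import camm_builder, simple_mp4_builder

def write_camm_mp4(src_path: str, dst_path: str, points: T.Sequence[geo.Point]) -> None:
    # Stream the transformed MP4 to disk in 1 MiB chunks instead of
    # holding the whole output in memory.
    with open(src_path, "rb") as src_fp, open(dst_path, "wb") as dst_fp:
        transformed = simple_mp4_builder.transform_mp4(
            src_fp, camm_builder.camm_sample_generator2(points)
        )
        shutil.copyfileobj(transformed, dst_fp, length=1024 * 1024)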
2 changes: 1 addition & 1 deletion mapillary_tools/process_import_meta_properties.py
@@ -1,5 +1,5 @@
import time
import os
import time
import typing as T

from . import exceptions, types, VERSION
2 changes: 1 addition & 1 deletion tests/integration/test_gopro.py
@@ -1,5 +1,5 @@
import os
import json
import os
import subprocess
import typing as T
from pathlib import Path
8 changes: 3 additions & 5 deletions tests/unit/test_camm_parser.py
@@ -52,13 +52,11 @@ def build_mp4(points: T.List[geo.Point]) -> T.Optional[T.List[geo.Point]]:
{"type": b"moov", "data": [mvhd]},
]
src = simple_mp4_builder.QuickBoxStruct32.BoxList.build(empty_mp4)
target_fp = io.BytesIO()
simple_mp4_builder.transform_mp4(
io.BytesIO(src), target_fp, camm_builder.camm_sample_generator2(points)
target_fp = simple_mp4_builder.transform_mp4(
io.BytesIO(src), camm_builder.camm_sample_generator2(points)
)

target_fp.seek(0)
return camm_parser.extract_points(target_fp)
return camm_parser.extract_points(T.cast(T.BinaryIO, target_fp))


def approximate(expected, actual):