Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct timestamps when merging fragmented WebVTT #67

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions devine/core/manifests/dash.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,12 +285,15 @@ def download_track(
segment_base = adaptation_set.find("SegmentBase")

segments: list[tuple[str, Optional[str]]] = []
segment_timescale: float = 0
segment_durations: list[int] = []
track_kid: Optional[UUID] = None

if segment_template is not None:
segment_template = copy(segment_template)
start_number = int(segment_template.get("startNumber") or 1)
segment_timeline = segment_template.find("SegmentTimeline")
segment_timescale = float(segment_template.get("timescale") or 1)

for item in ("initialization", "media"):
value = segment_template.get(item)
Expand Down Expand Up @@ -318,17 +321,16 @@ def download_track(
track_kid = track.get_key_id(init_data)

if segment_timeline is not None:
seg_time_list = []
current_time = 0
for s in segment_timeline.findall("S"):
if s.get("t"):
current_time = int(s.get("t"))
for _ in range(1 + (int(s.get("r") or 0))):
seg_time_list.append(current_time)
segment_durations.append(current_time)
current_time += int(s.get("d"))
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
seg_num_list = list(range(start_number, len(segment_durations) + start_number))

for t, n in zip(seg_time_list, seg_num_list):
for t, n in zip(segment_durations, seg_num_list):
segments.append((
DASH.replace_fields(
segment_template.get("media"),
Expand All @@ -342,8 +344,7 @@ def download_track(
if not period_duration:
raise ValueError("Duration of the Period was unable to be determined.")
period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration"))
segment_timescale = float(segment_template.get("timescale") or 1)
segment_duration = float(segment_template.get("duration")) or 1
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))

for s in range(start_number, start_number + total_segments):
Expand All @@ -356,7 +357,11 @@ def download_track(
Time=s
), None
))
# TODO: Should we floor/ceil/round, or is int() ok?
segment_durations.append(int(segment_duration))
elif segment_list is not None:
segment_timescale = float(segment_list.get("timescale") or 1)

init_data = None
initialization = segment_list.find("Initialization")
if initialization is not None:
Expand Down Expand Up @@ -388,6 +393,7 @@ def download_track(
media_url,
segment_url.get("mediaRange")
))
segment_durations.append(int(segment_url.get("duration") or 1))
elif segment_base is not None:
media_range = None
init_data = None
Expand Down Expand Up @@ -420,6 +426,10 @@ def download_track(
log.debug(track.url)
sys.exit(1)

# TODO: Should we floor/ceil/round, or is int() ok?
track.data["dash"]["timescale"] = int(segment_timescale)
track.data["dash"]["segment_durations"] = segment_durations

if not track.drm and isinstance(track, (Video, Audio)):
try:
track.drm = [Widevine.from_init_data(init_data)]
Expand Down
6 changes: 6 additions & 0 deletions devine/core/manifests/hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,15 @@ def download_track(
downloader = track.downloader

urls: list[dict[str, Any]] = []
segment_durations: list[int] = []

range_offset = 0
for segment in master.segments:
if segment in unwanted_segments:
continue

segment_durations.append(int(segment.duration))

if segment.byterange:
if downloader.__name__ == "aria2c":
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
Expand All @@ -277,6 +281,8 @@ def download_track(
} if byte_range else {}
})

track.data["hls"]["segment_durations"] = segment_durations

segment_save_dir = save_dir / "segments"

for status_update in downloader(
Expand Down
51 changes: 42 additions & 9 deletions devine/core/tracks/subtitle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Iterable, Optional
from typing import Any, Callable, Iterable, Optional, Union

import pycaption
import requests
Expand All @@ -20,6 +20,7 @@
from devine.core import binaries
from devine.core.tracks.track import Track
from devine.core.utilities import try_ensure_utf8
from devine.core.utils.webvtt import merge_segmented_webvtt


class Subtitle(Track):
Expand Down Expand Up @@ -202,6 +203,24 @@ def download(
self.convert(Subtitle.Codec.TimedTextMarkupLang)
elif self.codec == Subtitle.Codec.fVTT:
self.convert(Subtitle.Codec.WebVTT)
elif self.codec == Subtitle.Codec.WebVTT:
text = self.path.read_text("utf8")
if self.descriptor == Track.Descriptor.DASH:
text = merge_segmented_webvtt(
text,
segment_durations=self.data["dash"]["segment_durations"],
timescale=self.data["dash"]["timescale"]
)
elif self.descriptor == Track.Descriptor.HLS:
text = merge_segmented_webvtt(
text,
segment_durations=self.data["hls"]["segment_durations"],
timescale=1 # ?
)
caption_set = pycaption.WebVTTReader().read(text)
Subtitle.merge_same_cues(caption_set)
subtitle_text = pycaption.WebVTTWriter().write(caption_set)
self.path.write_text(subtitle_text, encoding="utf8")

def convert(self, codec: Subtitle.Codec) -> Path:
"""
Expand Down Expand Up @@ -308,14 +327,7 @@ def parse(data: bytes, codec: Subtitle.Codec) -> pycaption.CaptionSet:
caption_lists[language] = caption_list
caption_set: pycaption.CaptionSet = pycaption.CaptionSet(caption_lists)
elif codec == Subtitle.Codec.WebVTT:
text = try_ensure_utf8(data).decode("utf8")
# Segmented VTT when merged may have the WEBVTT headers part of the next caption
# if they are not separated far enough from the previous caption, hence the \n\n
text = text. \
replace("WEBVTT", "\n\nWEBVTT"). \
replace("\r", ""). \
replace("\n\n\n", "\n \n\n"). \
replace("\n\n<", "\n<")
text = Subtitle.space_webvtt_headers(data)
caption_set = pycaption.WebVTTReader().read(text)
else:
raise ValueError(f"Unknown Subtitle format \"{codec}\"...")
Expand All @@ -332,6 +344,27 @@ def parse(data: bytes, codec: Subtitle.Codec) -> pycaption.CaptionSet:

return caption_set

@staticmethod
def space_webvtt_headers(data: Union[str, bytes]):
"""
Space out the WEBVTT Headers from Captions.

Segmented VTT when merged may have the WEBVTT headers part of the next caption
as they were not separated far enough from the previous caption and ended up
being considered as caption text rather than the header for the next segment.
"""
if isinstance(data, bytes):
data = try_ensure_utf8(data).decode("utf8")
elif not isinstance(data, str):
raise ValueError(f"Expecting data to be a str, not {data!r}")

text = data.replace("WEBVTT", "\n\nWEBVTT").\
replace("\r", "").\
replace("\n\n\n", "\n \n\n").\
replace("\n\n<", "\n<")

return text

@staticmethod
def merge_same_cues(caption_set: pycaption.CaptionSet):
"""Merge captions with the same timecodes and text as one in-place."""
Expand Down
191 changes: 191 additions & 0 deletions devine/core/utils/webvtt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import re
import sys
import typing
from typing import Optional

from pycaption import Caption, CaptionList, CaptionNode, CaptionReadError, WebVTTReader, WebVTTWriter


class CaptionListExt(CaptionList):
@typing.no_type_check
def __init__(self, iterable=None, layout_info=None):
self.first_segment_mpegts = 0
super().__init__(iterable, layout_info)


class CaptionExt(Caption):
@typing.no_type_check
def __init__(self, start, end, nodes, style=None, layout_info=None, segment_index=0, mpegts=0, cue_time=0.0):
style = style or {}
self.segment_index: int = segment_index
self.mpegts: float = mpegts
self.cue_time: float = cue_time
super().__init__(start, end, nodes, style, layout_info)


class WebVTTReaderExt(WebVTTReader):
# HLS extension support <https://datatracker.ietf.org/doc/html/rfc8216#section-3.5>
RE_TIMESTAMP_MAP = re.compile(r"X-TIMESTAMP-MAP.*")
RE_MPEGTS = re.compile(r"MPEGTS:(\d+)")
RE_LOCAL = re.compile(r"LOCAL:((?:(\d{1,}):)?(\d{2}):(\d{2})\.(\d{3}))")

def _parse(self, lines: list[str]) -> CaptionList:
captions = CaptionListExt()
start = None
end = None
nodes: list[CaptionNode] = []
layout_info = None
found_timing = False
segment_index = -1
mpegts = 0
cue_time = 0.0

# The first segment MPEGTS is needed to calculate the rest. It is possible that
# the first segment contains no cue and is ignored by pycaption, this acts as a fallback.
captions.first_segment_mpegts = 0

for i, line in enumerate(lines):
if "-->" in line:
found_timing = True
timing_line = i
last_start_time = captions[-1].start if captions else 0
try:
start, end, layout_info = self._parse_timing_line(line, last_start_time)
except CaptionReadError as e:
new_msg = f"{e.args[0]} (line {timing_line})"
tb = sys.exc_info()[2]
raise type(e)(new_msg).with_traceback(tb) from None

elif "" == line:
if found_timing and nodes:
found_timing = False
caption = CaptionExt(
start,
end,
nodes,
layout_info=layout_info,
segment_index=segment_index,
mpegts=mpegts,
cue_time=cue_time,
)
captions.append(caption)
nodes = []

elif "WEBVTT" in line:
# Merged segmented VTT doesn't have index information, track manually.
segment_index += 1
mpegts = 0
cue_time = 0.0
elif m := self.RE_TIMESTAMP_MAP.match(line):
if r := self.RE_MPEGTS.search(m.group()):
mpegts = int(r.group(1))

cue_time = self._parse_local(m.group())

# Early assignment in case the first segment contains no cue.
if segment_index == 0:
captions.first_segment_mpegts = mpegts

else:
if found_timing:
if nodes:
nodes.append(CaptionNode.create_break())
nodes.append(CaptionNode.create_text(self._decode(line)))
else:
# it's a comment or some metadata; ignore it
pass

# Add a last caption if there are remaining nodes
if nodes:
caption = CaptionExt(start, end, nodes, layout_info=layout_info, segment_index=segment_index, mpegts=mpegts)
captions.append(caption)

return captions

@staticmethod
def _parse_local(string: str) -> float:
"""
Parse WebVTT LOCAL time and convert it to seconds.
"""
m = WebVTTReaderExt.RE_LOCAL.search(string)
if not m:
return 0

parsed = m.groups()
if not parsed:
return 0
hours = int(parsed[1])
minutes = int(parsed[2])
seconds = int(parsed[3])
milliseconds = int(parsed[4])
return (milliseconds / 1000) + seconds + (minutes * 60) + (hours * 3600)


def merge_segmented_webvtt(vtt_raw: str, segment_durations: Optional[list[int]] = None, timescale: int = 1) -> str:
"""
Merge Segmented WebVTT data.

Parameters:
vtt_raw: The concatenated WebVTT files to merge. All WebVTT headers must be
appropriately spaced apart, or it may produce unwanted effects like
considering headers as captions, timestamp lines, etc.
segment_durations: A list of each segment's duration. If not provided it will try
to get it from the X-TIMESTAMP-MAP headers, specifically the MPEGTS number.
timescale: The number of time units per second.

This parses the X-TIMESTAMP-MAP data to compute new absolute timestamps, replacing
the old start and end timestamp values. All X-TIMESTAMP-MAP header information will
be removed from the output as they are no longer of concern. Consider this function
the opposite of a WebVTT Segmenter, a WebVTT Joiner of sorts.

Algorithm borrowed from N_m3u8DL-RE and shaka-player.
"""
MPEG_TIMESCALE = 90_000

vtt = WebVTTReaderExt().read(vtt_raw)
for lang in vtt.get_languages():
prev_caption = None
duplicate_index: list[int] = []
captions = vtt.get_captions(lang)

if captions[0].segment_index == 0:
first_segment_mpegts = captions[0].mpegts
else:
first_segment_mpegts = segment_durations[0] if segment_durations else captions.first_segment_mpegts

caption: CaptionExt
for i, caption in enumerate(captions):
# DASH WebVTT doesn't have MPEGTS timestamp like HLS. Instead,
# calculate the timestamp from SegmentTemplate/SegmentList duration.
likely_dash = first_segment_mpegts == 0 and caption.mpegts == 0
if likely_dash and segment_durations:
duration = segment_durations[caption.segment_index]
caption.mpegts = MPEG_TIMESCALE * (duration / timescale)

if caption.mpegts == 0:
continue

seconds = (caption.mpegts - first_segment_mpegts) / MPEG_TIMESCALE - caption.cue_time
offset = seconds * 1_000_000 # pycaption use microseconds

if caption.start < offset:
caption.start += offset
caption.end += offset

# If the difference between current and previous captions is <=1ms
# and the payload is equal then splice.
if (
prev_caption
and not caption.is_empty()
and (caption.start - prev_caption.end) <= 1000 # 1ms in microseconds
and caption.get_text() == prev_caption.get_text()
):
prev_caption.end = caption.end
duplicate_index.append(i)

prev_caption = caption

# Remove duplicate
captions[:] = [c for c_index, c in enumerate(captions) if c_index not in set(duplicate_index)]

return WebVTTWriter().write(vtt)
Loading