From 0ba45decc6fe78f66307c1ce3112ed7b45d1562f Mon Sep 17 00:00:00 2001 From: Shivelight Date: Fri, 17 Nov 2023 11:27:34 +0800 Subject: [PATCH] fix(Subtitle): Correct timestamps when merging fragmented WebVTT This applies the X-TIMESTAMP-MAP data to timestamps as it reads through a concatenated (merged) WebVTT file to correct timestamps on segmented WebVTT streams. It then removes the X-TIMESTAMP-MAP header. The timescale and segment duration information is saved in the Subtitle's data dictionary under the hls/dash key: timescale (dash-only) and segment_durations. Note that this information will only be available post-download. This is done regardless if you are converting to another subtitle or not, since the downloader automatically and forcefully concatenated the segmented subtitle data. We do not support the use of segmented Subtitles for downloading or otherwise, nor do we plan to. --- devine/core/manifests/dash.py | 22 ++-- devine/core/manifests/hls.py | 6 ++ devine/core/tracks/subtitle.py | 51 +++++++-- devine/core/utils/webvtt.py | 191 +++++++++++++++++++++++++++++++++ 4 files changed, 255 insertions(+), 15 deletions(-) create mode 100644 devine/core/utils/webvtt.py diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py index c7eab18..b3a556a 100644 --- a/devine/core/manifests/dash.py +++ b/devine/core/manifests/dash.py @@ -285,12 +285,15 @@ def download_track( segment_base = adaptation_set.find("SegmentBase") segments: list[tuple[str, Optional[str]]] = [] + segment_timescale: float = 0 + segment_durations: list[int] = [] track_kid: Optional[UUID] = None if segment_template is not None: segment_template = copy(segment_template) start_number = int(segment_template.get("startNumber") or 1) segment_timeline = segment_template.find("SegmentTimeline") + segment_timescale = float(segment_template.get("timescale") or 1) for item in ("initialization", "media"): value = segment_template.get(item) @@ -318,17 +321,16 @@ def download_track( track_kid = track.get_key_id(init_data) if segment_timeline is not None: - seg_time_list = [] current_time = 0 for s in segment_timeline.findall("S"): if s.get("t"): current_time = int(s.get("t")) for _ in range(1 + (int(s.get("r") or 0))): - seg_time_list.append(current_time) + segment_durations.append(current_time) current_time += int(s.get("d")) - seg_num_list = list(range(start_number, len(seg_time_list) + start_number)) + seg_num_list = list(range(start_number, len(segment_durations) + start_number)) - for t, n in zip(seg_time_list, seg_num_list): + for t, n in zip(segment_durations, seg_num_list): segments.append(( DASH.replace_fields( segment_template.get("media"), @@ -342,8 +344,7 @@ def download_track( if not period_duration: raise ValueError("Duration of the Period was unable to be determined.") period_duration = DASH.pt_to_sec(period_duration) - segment_duration = float(segment_template.get("duration")) - segment_timescale = float(segment_template.get("timescale") or 1) + segment_duration = float(segment_template.get("duration")) or 1 total_segments = math.ceil(period_duration / (segment_duration / segment_timescale)) for s in range(start_number, start_number + total_segments): @@ -356,7 +357,11 @@ def download_track( Time=s ), None )) + # TODO: Should we floor/ceil/round, or is int() ok? + segment_durations.append(int(segment_duration)) elif segment_list is not None: + segment_timescale = float(segment_list.get("timescale") or 1) + init_data = None initialization = segment_list.find("Initialization") if initialization is not None: @@ -388,6 +393,7 @@ def download_track( media_url, segment_url.get("mediaRange") )) + segment_durations.append(int(segment_url.get("duration") or 1)) elif segment_base is not None: media_range = None init_data = None @@ -420,6 +426,10 @@ def download_track( log.debug(track.url) sys.exit(1) + # TODO: Should we floor/ceil/round, or is int() ok? + track.data["dash"]["timescale"] = int(segment_timescale) + track.data["dash"]["segment_durations"] = segment_durations + if not track.drm and isinstance(track, (Video, Audio)): try: track.drm = [Widevine.from_init_data(init_data)] diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py index c477dcb..de006e9 100644 --- a/devine/core/manifests/hls.py +++ b/devine/core/manifests/hls.py @@ -256,11 +256,15 @@ def download_track( downloader = track.downloader urls: list[dict[str, Any]] = [] + segment_durations: list[int] = [] + range_offset = 0 for segment in master.segments: if segment in unwanted_segments: continue + segment_durations.append(int(segment.duration)) + if segment.byterange: if downloader.__name__ == "aria2c": # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader @@ -277,6 +281,8 @@ def download_track( } if byte_range else {} }) + track.data["hls"]["segment_durations"] = segment_durations + segment_save_dir = save_dir / "segments" for status_update in downloader( diff --git a/devine/core/tracks/subtitle.py b/devine/core/tracks/subtitle.py index 866900a..56eae80 100644 --- a/devine/core/tracks/subtitle.py +++ b/devine/core/tracks/subtitle.py @@ -7,7 +7,7 @@ from functools import partial from io import BytesIO from pathlib import Path -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Iterable, Optional, Union import pycaption import requests @@ -20,6 +20,7 @@ from devine.core import binaries from devine.core.tracks.track import Track from devine.core.utilities import try_ensure_utf8 +from devine.core.utils.webvtt import merge_segmented_webvtt class Subtitle(Track): @@ -202,6 +203,24 @@ def download( self.convert(Subtitle.Codec.TimedTextMarkupLang) elif self.codec == Subtitle.Codec.fVTT: self.convert(Subtitle.Codec.WebVTT) + elif self.codec == Subtitle.Codec.WebVTT: + text = self.path.read_text("utf8") + if self.descriptor == Track.Descriptor.DASH: + text = merge_segmented_webvtt( + text, + segment_durations=self.data["dash"]["segment_durations"], + timescale=self.data["dash"]["timescale"] + ) + elif self.descriptor == Track.Descriptor.HLS: + text = merge_segmented_webvtt( + text, + segment_durations=self.data["hls"]["segment_durations"], + timescale=1 # ? + ) + caption_set = pycaption.WebVTTReader().read(text) + Subtitle.merge_same_cues(caption_set) + subtitle_text = pycaption.WebVTTWriter().write(caption_set) + self.path.write_text(subtitle_text, encoding="utf8") def convert(self, codec: Subtitle.Codec) -> Path: """ @@ -308,14 +327,7 @@ def parse(data: bytes, codec: Subtitle.Codec) -> pycaption.CaptionSet: caption_lists[language] = caption_list caption_set: pycaption.CaptionSet = pycaption.CaptionSet(caption_lists) elif codec == Subtitle.Codec.WebVTT: - text = try_ensure_utf8(data).decode("utf8") - # Segmented VTT when merged may have the WEBVTT headers part of the next caption - # if they are not separated far enough from the previous caption, hence the \n\n - text = text. \ - replace("WEBVTT", "\n\nWEBVTT"). \ - replace("\r", ""). \ - replace("\n\n\n", "\n \n\n"). \ - replace("\n\n<", "\n<") + text = Subtitle.space_webvtt_headers(data) caption_set = pycaption.WebVTTReader().read(text) else: raise ValueError(f"Unknown Subtitle format \"{codec}\"...") @@ -332,6 +344,27 @@ def parse(data: bytes, codec: Subtitle.Codec) -> pycaption.CaptionSet: return caption_set + @staticmethod + def space_webvtt_headers(data: Union[str, bytes]): + """ + Space out the WEBVTT Headers from Captions. + + Segmented VTT when merged may have the WEBVTT headers part of the next caption + as they were not separated far enough from the previous caption and ended up + being considered as caption text rather than the header for the next segment. + """ + if isinstance(data, bytes): + data = try_ensure_utf8(data).decode("utf8") + elif not isinstance(data, str): + raise ValueError(f"Expecting data to be a str, not {data!r}") + + text = data.replace("WEBVTT", "\n\nWEBVTT").\ + replace("\r", "").\ + replace("\n\n\n", "\n \n\n").\ + replace("\n\n<", "\n<") + + return text + @staticmethod def merge_same_cues(caption_set: pycaption.CaptionSet): """Merge captions with the same timecodes and text as one in-place.""" diff --git a/devine/core/utils/webvtt.py b/devine/core/utils/webvtt.py new file mode 100644 index 0000000..096c98f --- /dev/null +++ b/devine/core/utils/webvtt.py @@ -0,0 +1,191 @@ +import re +import sys +import typing +from typing import Optional + +from pycaption import Caption, CaptionList, CaptionNode, CaptionReadError, WebVTTReader, WebVTTWriter + + +class CaptionListExt(CaptionList): + @typing.no_type_check + def __init__(self, iterable=None, layout_info=None): + self.first_segment_mpegts = 0 + super().__init__(iterable, layout_info) + + +class CaptionExt(Caption): + @typing.no_type_check + def __init__(self, start, end, nodes, style=None, layout_info=None, segment_index=0, mpegts=0, cue_time=0.0): + style = style or {} + self.segment_index: int = segment_index + self.mpegts: float = mpegts + self.cue_time: float = cue_time + super().__init__(start, end, nodes, style, layout_info) + + +class WebVTTReaderExt(WebVTTReader): + # HLS extension support + RE_TIMESTAMP_MAP = re.compile(r"X-TIMESTAMP-MAP.*") + RE_MPEGTS = re.compile(r"MPEGTS:(\d+)") + RE_LOCAL = re.compile(r"LOCAL:((?:(\d{1,}):)?(\d{2}):(\d{2})\.(\d{3}))") + + def _parse(self, lines: list[str]) -> CaptionList: + captions = CaptionListExt() + start = None + end = None + nodes: list[CaptionNode] = [] + layout_info = None + found_timing = False + segment_index = -1 + mpegts = 0 + cue_time = 0.0 + + # The first segment MPEGTS is needed to calculate the rest. It is possible that + # the first segment contains no cue and is ignored by pycaption, this acts as a fallback. + captions.first_segment_mpegts = 0 + + for i, line in enumerate(lines): + if "-->" in line: + found_timing = True + timing_line = i + last_start_time = captions[-1].start if captions else 0 + try: + start, end, layout_info = self._parse_timing_line(line, last_start_time) + except CaptionReadError as e: + new_msg = f"{e.args[0]} (line {timing_line})" + tb = sys.exc_info()[2] + raise type(e)(new_msg).with_traceback(tb) from None + + elif "" == line: + if found_timing and nodes: + found_timing = False + caption = CaptionExt( + start, + end, + nodes, + layout_info=layout_info, + segment_index=segment_index, + mpegts=mpegts, + cue_time=cue_time, + ) + captions.append(caption) + nodes = [] + + elif "WEBVTT" in line: + # Merged segmented VTT doesn't have index information, track manually. + segment_index += 1 + mpegts = 0 + cue_time = 0.0 + elif m := self.RE_TIMESTAMP_MAP.match(line): + if r := self.RE_MPEGTS.search(m.group()): + mpegts = int(r.group(1)) + + cue_time = self._parse_local(m.group()) + + # Early assignment in case the first segment contains no cue. + if segment_index == 0: + captions.first_segment_mpegts = mpegts + + else: + if found_timing: + if nodes: + nodes.append(CaptionNode.create_break()) + nodes.append(CaptionNode.create_text(self._decode(line))) + else: + # it's a comment or some metadata; ignore it + pass + + # Add a last caption if there are remaining nodes + if nodes: + caption = CaptionExt(start, end, nodes, layout_info=layout_info, segment_index=segment_index, mpegts=mpegts) + captions.append(caption) + + return captions + + @staticmethod + def _parse_local(string: str) -> float: + """ + Parse WebVTT LOCAL time and convert it to seconds. + """ + m = WebVTTReaderExt.RE_LOCAL.search(string) + if not m: + return 0 + + parsed = m.groups() + if not parsed: + return 0 + hours = int(parsed[1]) + minutes = int(parsed[2]) + seconds = int(parsed[3]) + milliseconds = int(parsed[4]) + return (milliseconds / 1000) + seconds + (minutes * 60) + (hours * 3600) + + +def merge_segmented_webvtt(vtt_raw: str, segment_durations: Optional[list[int]] = None, timescale: int = 1) -> str: + """ + Merge Segmented WebVTT data. + + Parameters: + vtt_raw: The concatenated WebVTT files to merge. All WebVTT headers must be + appropriately spaced apart, or it may produce unwanted effects like + considering headers as captions, timestamp lines, etc. + segment_durations: A list of each segment's duration. If not provided it will try + to get it from the X-TIMESTAMP-MAP headers, specifically the MPEGTS number. + timescale: The number of time units per second. + + This parses the X-TIMESTAMP-MAP data to compute new absolute timestamps, replacing + the old start and end timestamp values. All X-TIMESTAMP-MAP header information will + be removed from the output as they are no longer of concern. Consider this function + the opposite of a WebVTT Segmenter, a WebVTT Joiner of sorts. + + Algorithm borrowed from N_m3u8DL-RE and shaka-player. + """ + MPEG_TIMESCALE = 90_000 + + vtt = WebVTTReaderExt().read(vtt_raw) + for lang in vtt.get_languages(): + prev_caption = None + duplicate_index: list[int] = [] + captions = vtt.get_captions(lang) + + if captions[0].segment_index == 0: + first_segment_mpegts = captions[0].mpegts + else: + first_segment_mpegts = segment_durations[0] if segment_durations else captions.first_segment_mpegts + + caption: CaptionExt + for i, caption in enumerate(captions): + # DASH WebVTT doesn't have MPEGTS timestamp like HLS. Instead, + # calculate the timestamp from SegmentTemplate/SegmentList duration. + likely_dash = first_segment_mpegts == 0 and caption.mpegts == 0 + if likely_dash and segment_durations: + duration = segment_durations[caption.segment_index] + caption.mpegts = MPEG_TIMESCALE * (duration / timescale) + + if caption.mpegts == 0: + continue + + seconds = (caption.mpegts - first_segment_mpegts) / MPEG_TIMESCALE - caption.cue_time + offset = seconds * 1_000_000 # pycaption use microseconds + + if caption.start < offset: + caption.start += offset + caption.end += offset + + # If the difference between current and previous captions is <=1ms + # and the payload is equal then splice. + if ( + prev_caption + and not caption.is_empty() + and (caption.start - prev_caption.end) <= 1000 # 1ms in microseconds + and caption.get_text() == prev_caption.get_text() + ): + prev_caption.end = caption.end + duplicate_index.append(i) + + prev_caption = caption + + # Remove duplicate + captions[:] = [c for c_index, c in enumerate(captions) if c_index not in set(duplicate_index)] + + return WebVTTWriter().write(vtt)