-
Notifications
You must be signed in to change notification settings - Fork 2
/
youtube_runner.py
243 lines (202 loc) · 8.3 KB
/
youtube_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import argparse
import json
import os
import pprint
import time
from os import PathLike
from pathlib import Path
from typing import Any
import sqlalchemy as db
from loguru import logger
from sqlalchemy.orm import sessionmaker
from youtube_transcript_api._errors import TranscriptsDisabled
from generators.youtube_generator import generate
from models import DummyTest
from src.database import YouTubeBase
from src.dataclasses import YouTubeVideo
from src.test_runner import TestRegistry, TestRunner
from src.utils import insert_youtube_result
@TestRegistry.register(DummyTest)
class YouTubeTestRunner(TestRunner):
    """
    Test runner for YouTube videos, uses a testplan generated by youtube_generator.

    For each video in the testplan the runner fetches the YouTube transcript
    (the target), downloads the audio, transcribes it with ``self.tester`` and
    compares the two transcripts. The comparison results, or the reason a video
    was skipped, are written back into the testplan item. The whole testplan is
    saved to a JSON file after every iteration (and optionally to a database).
    """

    # Characters replaced by "-" when the search query is used in a file name:
    # path separators (which would otherwise create subdirectories) plus the
    # characters Windows does not allow in file names.
    _UNSAFE_FILENAME_CHARS = str.maketrans(dict.fromkeys('\\/:*?"<>|', "-"))

    def __init__(
        self,
        testplan_path: PathLike,
        audio_dir: PathLike,
        output_dir: PathLike,
        iterations: int = 1,
        save_transcripts: bool = False,
        save_to_database: bool = False,
        keep_audio: bool = False,
        **kwargs,
    ):
        """
        Args:
            testplan_path: path to the JSON testplan run in the first iteration
            audio_dir: directory passed to ``YouTubeVideo.download_mp3``
            output_dir: directory the JSON results are written to; a relative
                path is resolved against the directory containing this file
            iterations: number of testplans to run; every iteration after the
                first generates the next page of the search with ``generate``
            save_transcripts: also store the model and target transcripts
            save_to_database: also insert results into ``youtube.sqlite``
                (created in the current working directory)
            keep_audio: keep downloaded audio instead of deleting it
            **kwargs: forwarded to ``TestRunner.__init__``
        """
        super().__init__(**kwargs)
        self._audio_dir = Path(audio_dir)
        self._output_dir = Path(output_dir)
        self._testplan_path = Path(testplan_path)
        self._iterations = iterations
        self._save_transcripts = save_transcripts
        self._save_to_database = save_to_database
        self._keep_audio = keep_audio
        self._session = None
        if self._save_to_database:
            engine = db.create_engine("sqlite:///youtube.sqlite")
            YouTubeBase.metadata.create_all(engine)
            self._session = sessionmaker(bind=engine)()

    def run(self) -> None:
        """
        Run the testplan(s), saving the results after every iteration.

        Raises:
            ValueError: if the tester has no transcriber, or if more than one
                iteration is requested without ``GoogleAPI`` in the environment
        """
        if self.tester.transcriber is None:
            raise ValueError("Transcriber is None")
        if self.tester.normalizer is None:
            logger.warning("Normalizer is None, running without normalizer")
        # generating further testplans (iterations > 1) needs the Google API key
        if self._iterations > 1 and "GoogleAPI" not in os.environ:
            raise ValueError(
                "GoogleAPI not in the environment, can not generate more test plans. "
                "Add GoogleAPI or set iterations to 1."
            )
        with open(self._testplan_path, encoding="utf8") as f:
            testplan = json.load(f)

        for i in range(self._iterations):
            logger.info(f"Starting {i + 1}/{self._iterations} testplan")
            logger.info(f"Testplan args:\n{pprint.pformat(testplan['args'])}")
            items = testplan["items"]
            for idx, video_details in enumerate(items):
                logger.info(
                    f"Testplan status: {idx + 1}/{len(items)} video, {i + 1}/{self._iterations} testplan"
                )
                self._run_video(video_details)
            self.save_results(testplan)

            # generate a new testplan if we need to
            if i + 1 < self._iterations:
                next_page = testplan.get("nextPageToken")
                if not next_page:
                    # the last page of search results carries no token
                    logger.warning(
                        "Testplan has no nextPageToken, no more testplans to generate"
                    )
                    break
                # copy so the already-saved testplan's args stay untouched
                args = {**testplan["args"], "pageToken": next_page}
                testplan = generate(args)
        logger.info("Testplan finished")

    def _run_video(self, video_details: dict[str, Any]) -> None:
        """
        Test a single testplan item, writing ``results`` or ``error`` into it.

        Args:
            video_details: testplan item; updated in place
        """
        video = YouTubeVideo.from_dict(video_details)

        # download the target transcript
        try:
            target_transcript = video.youtube_transcript(self.tester.language)
        except ValueError as e:
            self._skip_video(video, video_details, f"ValueError (youtube transcript): {e}")
            return
        except TranscriptsDisabled as e:
            self._skip_video(
                video, video_details, f"TranscriptsDisabled (youtube transcript): {e}"
            )
            return

        # download the audio
        try:
            audio = video.download_mp3(self._audio_dir)
        except ValueError as e:
            self._skip_video(video, video_details, f"ValueError (download): {e}")
            return

        # transcribe the audio; the file is removed afterwards (unless kept)
        # even when transcription fails, so skipped videos do not leak audio
        try:
            model_transcript = self.tester.transcribe(audio)
        except TimeoutError as e:
            self._skip_video(video, video_details, f"TimeoutError (model transcript): {e}")
            return
        finally:
            if not self._keep_audio:
                audio.unlink()

        # compare the transcripts
        try:
            results = self.tester.compare(model_transcript, target_transcript)
            results.update(self.tester.additional_info())
        except ValueError as e:
            self._skip_video(video, video_details, f"ValueError (compare): {e}")
            return
        video_details["results"] = results

        # add the transcripts to the video details if we want to save them
        if self._save_transcripts:
            video_details["modelTranscript"] = model_transcript
            video_details["targetTranscript"] = target_transcript

    @staticmethod
    def _skip_video(video: YouTubeVideo, video_details: dict[str, Any], reason: str) -> None:
        """Log that ``video`` is skipped and record ``reason`` as its ``error``."""
        logger.warning(f"Skipping the video {video.videoId}, {reason}")
        video_details["error"] = reason

    def save_results(self, results: dict[str, Any]) -> None:
        """
        Save the results to a JSON file, and to the database if enabled.

        The file is named ``<query>_<class name>_<timestamp>.json``. Characters
        that are unsafe in file names are replaced in the query, and a numeric
        suffix is added if a file with that name already exists.

        Args:
            results: testplan (with per-video results) to save
        """
        time_str = time.strftime("%Y%m%d-%H%M%S")
        query = str(results["args"]["q"]).translate(self._UNSAFE_FILENAME_CHARS)
        stem = f"{query}_{self.__class__.__name__}_{time_str}"
        output_dir = Path(__file__).parent.joinpath(self._output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # two saves within the same second would share a timestamp; the file is
        # opened exclusively ("x"), so pick a free name instead of crashing
        filename = f"{stem}.json"
        suffix = 1
        while (output_dir / filename).exists():
            filename = f"{stem}_{suffix}.json"
            suffix += 1
        path = output_dir / filename

        logger.info(f"Saving results - {path}")
        with open(path, "x", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False)
        if self._save_to_database and not insert_youtube_result(
            self._session, filename, results
        ):
            logger.warning("Failed to save results to database")

    def __repr__(self) -> str:
        """
        Return the class name, suffixed with ``_<category>`` when the testplan
        file name has a prefix before its first ``_`` (e.g. ``news_x.json``).
        """
        category, sep, _ = self._testplan_path.name.partition("_")
        if not sep:
            # no underscore: the file name carries no category
            return self.__class__.__name__
        return f"{self.__class__.__name__}_{category}"

    @staticmethod
    def runner_args(parser: argparse.ArgumentParser) -> None:
        """
        Add runner specific arguments to the parser

        Args:
            parser: parser to add arguments to
        """
        parser.add_argument(type=str, dest="testplan_path", help="Testplan path")
        parser.add_argument(
            "-st",
            "--save-transcript",
            required=False,
            action="store_true",
            dest="save_transcripts",
            default=False,
            help="Store model and target transcripts in the results",
        )
        parser.add_argument(
            "-db",
            "--save-to-database",
            required=False,
            action="store_true",
            dest="save_to_database",
            default=False,
            help="Also save the results to the youtube.sqlite database",
        )
        parser.add_argument(
            "--audio-path",
            required=False,
            type=str,
            default="./cache/audio",
            dest="audio_dir",
            help="Directory for downloaded audio",
        )
        parser.add_argument(
            "-k",
            "--keep-audio",
            required=False,
            action="store_true",
            dest="keep_audio",
            default=False,
            help="Keep downloaded audio instead of deleting it",
        )
        parser.add_argument(
            "-it",
            "--iterations",
            required=False,
            type=int,
            default=1,
            help="Number of testplans to run (more than 1 requires GoogleAPI)",
        )
        parser.add_argument(
            "-o",
            "--output",
            required=False,
            type=str,
            default="./output",
            dest="output_dir",
            help="Directory the JSON results are written to",
        )
if __name__ == "__main__":
    # build the runner from CLI arguments, then execute it
    runner = YouTubeTestRunner.from_command_line()
    runner.run()