Skip to content

Commit

Permalink
feat: endpoint /convert_alignment supports srt and vtt
Browse files Browse the repository at this point in the history
  • Loading branch information
joanise committed Sep 26, 2022
1 parent 6fcc404 commit e035cb2
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 25 deletions.
99 changes: 74 additions & 25 deletions readalongs/web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from lxml import etree
from pydantic import BaseModel, Field

from readalongs.align import create_tei_from_text, save_label_files
from readalongs.align import create_tei_from_text, save_label_files, save_subtitles
from readalongs.text.add_ids_to_xml import add_ids
from readalongs.text.convert_xml import convert_xml
from readalongs.text.make_dict import make_dict_object
Expand Down Expand Up @@ -222,8 +222,8 @@ class ConvertRequest(BaseModel):

output_format: str = Field(
example="TextGrid",
regex="^(?i)(eaf|TextGrid)$",
title="Format to convert to, one of TextGrid (Praat), eaf (ELAN).",
regex="^(?i)(eaf|TextGrid|srt|vtt)$",
title="Format to convert to, one of TextGrid (Praat), eaf (ELAN), srt (SRT subtitles), or vtt (VTT subtitles).",
)

xml: str = Field(
Expand Down Expand Up @@ -268,10 +268,28 @@ class ConvertRequest(BaseModel):
class ConvertResponse(BaseModel):
"""Convert response has the requesed converted file's contents"""

file_name: str = Field(
title="A suggested name for this output file: aligned + the standard extension"
)

file_contents: str = Field(
title="Full contents of the converted file in the format requested."
)

other_file_name: Union[None, str] = Field(
title="Suggested name for the second file, if the conversion generates two files"
)

other_file_contents: Union[None, str] = Field(
title="Full contents of the second file, if any"
)


def slurp_file(filename):
"""Slurp a file into one string"""
with open(filename, "r", encoding="utf-8") as f:
return f.read()


@v1.post("/convert_alignment", response_model=ConvertResponse)
async def convert_alignment(input: ConvertRequest) -> ConvertResponse:
Expand All @@ -280,16 +298,26 @@ async def convert_alignment(input: ConvertRequest) -> ConvertResponse:
Args (as dict items in the request body):
- audio_length: duration in seconds of the audio file used to create the alignment
- encoding: use utf-8, other encodings are not supported (yet)
- output_format: one of TextGrid (Praat), eaf (ELAN), ...
- output_format: one of TextGrid, eaf, srt, vtt
- xml: the XML file produced by /assemble
- smil: the SMIL file produced by SoundSwallower(.js)
Formats supported:
- TextGrid: Praat TextGrid file format
- eaf: ELAN eaf file format
- srt: SRT subtitle format, as two files, 1) for sentences, 2) for words
- vtt: WebVTT subtitle format, as two files, 1) for sentences, 2) for words
Data privacy consideration: due to limitations of the libraries used to perform
some of these conversions, the output file may be temporarily stored on disk,
but it gets deleted immediately, before it is even returned by this endpoint.
some of these conversions, the output files will be temporarily stored on disk,
but they get deleted immediately, before they are even returned by this endpoint.
Returns:
- file_name: a suggested name for the returned file
- file_contents: the contents of the file converted in the requested format
for srt and vtt:
- other_file_name: a suggested name for the second file
- other_file_contents: the contents of the second file
"""
try:
parsed_xml = etree.fromstring(bytes(input.xml, encoding=input.encoding))
Expand All @@ -307,30 +335,51 @@ async def convert_alignment(input: ConvertRequest) -> ConvertResponse:
except ValueError as e:
raise HTTPException(status_code=422, detail="SMIL provided is not valid") from e

output_format = input.output_format.lower()
if output_format == "textgrid":
with TemporaryDirectory() as temp_dir_name:
prefix = os.path.join(temp_dir_name, "f")
save_label_files(words, parsed_xml, input.audio_length, prefix, "textgrid")
with open(prefix + ".TextGrid", mode="r", encoding="utf-8") as f:
textgrid_text = f.read()
# Data privacy consideration: creating the temporary directory with a context
# manager like this guarantees, as we promise in the API documentation, that the
# temporary directory and all temporary files we create in it will be deleted
# when this function exits, whether it is with an error or with success.
with TemporaryDirectory() as temp_dir_name:
prefix = os.path.join(temp_dir_name, "f")

return ConvertResponse(file_contents=textgrid_text)
output_format = input.output_format.lower()
if output_format == "textgrid":
save_label_files(words, parsed_xml, input.audio_length, prefix, "textgrid")
return ConvertResponse(
file_name="aligned.TextGrid",
file_contents=slurp_file(prefix + ".TextGrid"),
)

elif output_format == "eaf":
with TemporaryDirectory() as temp_dir_name:
prefix = os.path.join(temp_dir_name, "f")
elif output_format == "eaf":
save_label_files(words, parsed_xml, input.audio_length, prefix, "eaf")
with open(prefix + ".eaf", mode="r", encoding="utf-8") as f:
eaf_text = f.read()
return ConvertResponse(
file_name="aligned.eaf",
file_contents=slurp_file(prefix + ".eaf"),
)

return ConvertResponse(file_contents=eaf_text)
elif output_format == "srt":
save_subtitles(words, parsed_xml, prefix, "srt")
return ConvertResponse(
file_name="aligned_sentences.srt",
file_contents=slurp_file(prefix + "_sentences.srt"),
other_file_name="aligned_words.srt",
other_file_contents=slurp_file(prefix + "_words.srt"),
)

else:
raise HTTPException(
status_code=500,
detail="Invalid output_format should have been caught by fastAPI already...",
)
elif output_format == "vtt":
save_subtitles(words, parsed_xml, prefix, "vtt")
return ConvertResponse(
file_name="aligned_sentences.vtt",
file_contents=slurp_file(prefix + "_sentences.vtt"),
other_file_name="aligned_words.vtt",
other_file_contents=slurp_file(prefix + "_words.vtt"),
)

else:
raise HTTPException(
status_code=500,
detail="Invalid output_format (should have been caught by regex validation!)",
)


# Mount the v1 version of the API to the root of the app
Expand Down
79 changes: 79 additions & 0 deletions test/test_web_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def test_convert_to_TextGrid(self):
response = API_CLIENT.post("/api/v1/convert_alignment", json=request)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["file_contents"], self.hej_verden_textgrid)
self.assertTrue(response.json()["file_name"].endswith(".TextGrid"))

def test_convert_to_TextGrid_errors(self):
request = {
Expand Down Expand Up @@ -273,6 +274,84 @@ def test_convert_to_eaf(self):
response = API_CLIENT.post("/api/v1/convert_alignment", json=request)
self.assertEqual(response.status_code, 200)
self.assertIn("<ANNOTATION_DOCUMENT", response.json()["file_contents"])
self.assertTrue(response.json()["file_name"].endswith(".eaf"))

def test_convert_to_srt(self):
request = {
"encoding": "utf-8",
"audio_length": 83.1,
"output_format": "srt",
"xml": self.hej_verden_xml,
"smil": self.hej_verden_smil,
}
response = API_CLIENT.post("/api/v1/convert_alignment", json=request)
self.assertEqual(response.status_code, 200)
self.assertTrue(response.json()["file_name"].endswith("_sentences.srt"))
self.assertTrue(response.json()["other_file_name"].endswith("_words.srt"))
self.assertEqual(
response.json()["file_contents"],
dedent(
"""\
1
00:00:17,745 --> 00:01:22,190
hej é verden à
"""
),
)
self.assertEqual(
response.json()["other_file_contents"],
dedent(
"""\
1
00:00:17,745 --> 00:00:58,600
hej é
2
00:00:58,600 --> 00:01:22,190
verden à
"""
),
)

def test_convert_to_vtt(self):
request = {
"encoding": "utf-8",
"audio_length": 83.1,
"output_format": "vtt",
"xml": self.hej_verden_xml,
"smil": self.hej_verden_smil,
}
response = API_CLIENT.post("/api/v1/convert_alignment", json=request)
self.assertEqual(response.status_code, 200)
self.assertTrue(response.json()["file_name"].endswith("_sentences.vtt"))
self.assertTrue(response.json()["other_file_name"].endswith("_words.vtt"))
self.assertEqual(
response.json()["file_contents"],
dedent(
"""\
WEBVTT
00:00:17.745 --> 00:01:22.190
hej é verden à
"""
),
)
self.assertEqual(
response.json()["other_file_contents"],
dedent(
"""\
WEBVTT
00:00:17.745 --> 00:00:58.600
hej é
00:00:58.600 --> 00:01:22.190
verden à
"""
),
)

def test_convert_to_bad_format(self):
request = {
Expand Down

0 comments on commit e035cb2

Please sign in to comment.