forked from apple/ml-stuttering-events-dataset
-
Notifications
You must be signed in to change notification settings - Fork 0
/
extract_clips.py
84 lines (66 loc) · 2.19 KB
/
extract_clips.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2021 Apple Inc. All Rights Reserved.
#
"""
For each podcast episode:
* Get all clip information for that episode
* Save each clip as a new wav file.
"""
import os
import pathlib
import subprocess
import numpy as np
import pandas as pd
from scipy.io import wavfile
import argparse
# Command-line interface: where the label CSV and the downloaded episode
# audio live, where extracted clips go, and whether to show a progress bar.
parser = argparse.ArgumentParser(
    description="Extract clips from SEP-28k or FluencyBank.",
)
parser.add_argument(
    "--labels",
    type=str,
    required=True,
    help="Path to the labels csv files (e.g., SEP-28k_labels.csv)",
)
parser.add_argument(
    "--wavs",
    type=str,
    default="wavs",
    help="Path where audio files from download_audio.py are saved",
)
parser.add_argument(
    "--clips",
    type=str,
    default="clips",
    help="Path where clips should be extracted",
)
parser.add_argument(
    "--progress",
    action="store_true",
    help="Show progress",
)
args = parser.parse_args()

# Short aliases for the parsed paths, used by the rest of the script.
label_file, data_dir, output_dir = args.labels, args.wavs, args.clips
# Load the label/clip table. EpId is forced to str so episode IDs are used
# verbatim in file names (presumably so numeric-looking IDs keep their exact
# text, e.g. leading zeros -- confirm against the label files).
data = pd.read_csv(label_file, dtype={"EpId": str})

# Per-clip columns; each is indexed by the row number `i` in the extraction
# loop. (The label columns after Stop are not needed to cut the audio.)
shows = data["Show"]
episodes = data["EpId"]
clip_idxs = data["ClipId"]
starts = data["Start"]
stops = data["Stop"]
n_items = len(shows)

# Path of the recording currently held in memory; "" means none loaded yet.
loaded_wav = ""

# Row numbers to process, optionally wrapped in a tqdm progress bar.
cur_iter = range(n_items)
if args.progress:
    # Imported lazily so tqdm is only required when --progress is given.
    from tqdm import tqdm

    cur_iter = tqdm(cur_iter)
# Sample rate every source recording must have. Start/Stop are used directly
# as array indices into the decoded audio, i.e. they are sample offsets, and
# presumably were computed at this rate.
EXPECTED_SAMPLE_RATE = 16000

wav_root = pathlib.Path(data_dir)
clip_root = pathlib.Path(output_dir)

# Recordings already reported as missing, so each absent episode is reported
# once instead of once per clip.
missing_wavs = set()

for i in cur_iter:
    show = str(shows[i])
    episode = episodes[i].strip()
    clip_idx = clip_idxs[i]

    # Source recording: <wavs>/<show>/<episode>.wav
    # Output clip:      <clips>/<show>/<episode>/<show>_<episode>_<clip>.wav
    wav_path = wav_root / show / f"{episode}.wav"
    clip_dir = clip_root / show / episode
    clip_path = clip_dir / f"{show}_{episode}_{clip_idx}.wav"

    if not wav_path.exists():
        if wav_path not in missing_wavs:
            print("Missing", wav_path)
            missing_wavs.add(wav_path)
        continue

    # Make sure the per-episode output directory exists.
    clip_dir.mkdir(parents=True, exist_ok=True)

    # Only decode the recording when this row refers to a different file than
    # the previous processed row; this saves work when an episode's clips are
    # consecutive in the CSV.
    if wav_path != loaded_wav:
        sample_rate, audio = wavfile.read(str(wav_path))
        # Explicit check instead of `assert`, which is removed under -O.
        if sample_rate != EXPECTED_SAMPLE_RATE:
            raise ValueError(
                f"Sample rate must be 16 khz, got {sample_rate} Hz in {wav_path}"
            )
        loaded_wav = wav_path

    # Cut the clip by sample offsets and save it.
    clip = audio[starts[i]:stops[i]]
    wavfile.write(str(clip_path), sample_rate, clip)