Skip to content

Commit

Permalink
Re-factor SLI scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
fauxneticien committed Apr 14, 2022
1 parent ce7d7b2 commit 5a40dd2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 39 deletions.
10 changes: 0 additions & 10 deletions scripts/_run-pipeline.sh

This file was deleted.

4 changes: 2 additions & 2 deletions scripts/helpers/sli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def get_sli_df(sli_train_dir):

return sli_df

def get_sb_encoder():
def get_sb_encoder(save_dir="tmp"):
sb_encoder = EncoderClassifier.from_hparams(
source="speechbrain/lang-id-voxlingua107-ecapa",
savedir="tmp/",
savedir=save_dir,
run_opts={"device": "cuda:1" if torch.cuda.is_available() else "cpu" }
)

Expand Down
35 changes: 8 additions & 27 deletions scripts/train_sli-by-sblr.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import pickle

from argparse import ArgumentParser
from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
from speechbrain.pretrained import EncoderClassifier
from tqdm import tqdm

import glob
import os
import pandas as pd
import pickle
import torch
import torchaudio
from helpers.sli import get_sli_df, get_sb_encoder, add_sbemb_cols, colsplit_feats_labels

parser = ArgumentParser(
prog='train_sli-by-sblr',
Expand All @@ -23,31 +18,17 @@

args = parser.parse_args()

language_id = EncoderClassifier.from_hparams(source="speechbrain/lang-id-voxlingua107-ecapa", savedir="tmp/")

def get_sb_emb(wav_path):
waveform, sample_rate = torchaudio.load(wav_path)

if sample_rate != 16_000:
print("Resampling audio to 16 kHz ...")
samp_to_16k = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16_000)
waveform = samp_to_16k(waveform)

emb = language_id.encode_batch(waveform)

return emb.reshape((1, 256))

wav_files = glob.glob(os.path.join(args.clips_dir, "*", "*.wav"))
langs = [ os.path.basename(os.path.dirname(f)) for f in wav_files ]
sli_df = get_sli_df(args.clips_dir)

print("Extracting features...")

embds = pd.concat([ pd.DataFrame(get_sb_emb(f)) for f in tqdm(wav_files) ])
sli_df = add_sbemb_cols(sli_df, sb_encoder=get_sb_encoder())

langs, embds = shuffle(langs, embds, random_state=0)
feats, labels = colsplit_feats_labels(sli_df)
feats, labels = shuffle(feats, labels, random_state=0)

print("Fitting classifier...")
clf = LogisticRegression(random_state=0, max_iter=args.logreg_maxiter).fit(embds, langs)
clf = LogisticRegression(random_state=0, max_iter=args.logreg_maxiter).fit(feats, labels)

pickle.dump(clf, open(args.logreg_pkl, 'wb'))
print(f"Saved classifier to {args.logreg_pkl}")

0 comments on commit 5a40dd2

Please sign in to comment.