This repository provides code for the Knowledge Boosting architecture proposed in the paper "Knowledge Boosting During Low-Latency Inference," presented at Interspeech 2024. Knowledge boosting is a technique that enhances the performance of a small on-device model during inference with assistance from a large model on a remote device: the large model provides time-delayed hints to the small model at inference time.
Video preview: kb-video-preview.mp4
Our system architecture. The green arrow is present only during large model pre-training. The red arrows are present during knowledge boosting. The black arrows are present both during pre-training and knowledge boosting. The TF-GridNet model is used to demonstrate results and is the model documented in this repository.
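For intuition, below is a minimal PyTorch sketch of the delayed-hint idea. It is not the repository's TF-GridNet implementation; the module names, feature sizes, and delay value are illustrative assumptions only.

```python
# Toy illustration of time-delayed hints (not the repository's TF-GridNet code).
# All module names, dimensions, and the delay value are illustrative assumptions.
import torch
import torch.nn as nn


class DelayedHintFusion(nn.Module):
    """Small model that fuses hints from a large model arriving `delay` frames late."""

    def __init__(self, feat_dim=64, hint_dim=128, delay=6):
        super().__init__()
        self.delay = delay
        self.small = nn.GRU(feat_dim + hint_dim, feat_dim, batch_first=True)
        self.proj = nn.Linear(feat_dim, feat_dim)

    def forward(self, x, hints):
        # x:     (batch, frames, feat_dim)  features seen by the on-device small model
        # hints: (batch, frames, hint_dim)  embeddings produced by the remote large model
        # Shift the hints right by `delay` frames so that frame t only sees the hint
        # computed for frame t - delay, modeling the communication delay between devices.
        delayed = torch.zeros_like(hints)
        frames = hints.shape[1]
        delayed[:, self.delay:] = hints[:, : frames - self.delay]
        out, _ = self.small(torch.cat([x, delayed], dim=-1))
        return self.proj(out)


model = DelayedHintFusion()
x = torch.randn(2, 100, 64)       # small-model input features
hints = torch.randn(2, 100, 128)  # large-model hint embeddings
print(model(x, hints).shape)      # torch.Size([2, 100, 64])
```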
# Commands in all sections are run from the repo's top level directory
conda create --name kb python=3.9
conda activate kb
pip install -r requirements.txt
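As a quick check that the environment is usable (this assumes requirements.txt installs PyTorch, which the training code relies on):

```python
# Verify that PyTorch imports and report whether a GPU is visible.
import torch
print(torch.__version__, torch.cuda.is_available())
```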
We use Zenodo to host our datasets. You can access the different datasets below (download both part 1 and 2 for a specific dataset). Each dataset contains a train, validation, and test partition.
- Target Speech Extraction (TSE) Part 1
- Target Speech Extraction (TSE) Part 2
- Source Separation (SS) Part 1
- Source Separation (SS) Part 2
Create the `data` directory and untar the data from Zenodo. This example is for target speaker extraction; replace 'tse' with 'ss' below for source separation:
# Create data directory
mkdir data
Download all tarballs from the datasets specified above.
# Assemble the train dataset tarball from its parts
cat kb-tse-dataset-train-part*.tar > data/kb-tse-dataset-train.tar
Untar the datasets from Zenodo into the `data` directory.
cd data
tar -xvf kb-tse-dataset-train.tar -C .
tar -xvf kb-tse-dataset-val.tar -C .
tar -xvf kb-tse-dataset-test.tar -C .
cd ..
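Optionally, run a short sanity check that the archives extracted into `data`. The exact subdirectory names depend on the Zenodo tarballs, so this only lists what is there:

```python
# List each entry under data/ with a rough file count, to confirm the extraction worked.
import os

for entry in sorted(os.listdir("data")):
    path = os.path.join("data", entry)
    if os.path.isdir(path):
        count = sum(len(files) for _, _, files in os.walk(path))
    else:
        count = 1
    print(f"{entry}: {count} files")
```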
You can run either the baseline configurations (which train the large and small models separately before joint training) or the joint configurations. These configurations are under `configs/baselines`, and `configs/TSE_joint` or `configs/SS_joint` depending on the task.
Note that in the joint configurations specifically, you will need to specify the `big_model_init_ckpt` argument, which is a PyTorch (.pt) model checkpoint. You may generate your own by training the baseline configurations provided, or refer to our model checkpoints (TSE, SS).
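If you are unsure whether a checkpoint file is suitable, a quick inspection like the one below can help. The path is hypothetical, and the assumption that the file holds either a plain state_dict or a dict with a 'state_dict' key is ours; adjust it to however your baseline run saved its weights.

```python
# Peek inside a .pt file before passing it as big_model_init_ckpt.
import torch

ckpt = torch.load("checkpoints/tse_big_model.pt", map_location="cpu")  # hypothetical path
# Assumed layout: either a raw state_dict or a wrapper dict containing one.
state_dict = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
print(f"{len(state_dict)} parameter tensors")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```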
# Usage: trainer.py [-h] --config CONFIG --run_dir RUN_DIR [--resume] [--ckpt CKPT] [--test]
python -m src.trainer --run_dir <NAME OF DIR TO LOG RUNS> --config <configs/PATH TO CONFIG.json>
@misc{srinivas2024knowledgeboosting,
title={Knowledge boosting during low-latency inference},
author={Vidya Srinivas and Malek Itani and Tuochao Chen and Emre Sefik Eskimez and Takuya Yoshioka and Shyamnath Gollakota},
year={2024},
eprint={2407.11055},
archivePrefix={arXiv},
primaryClass={cs.LG}
}