Knowledge Boosting (Model Collaboration During Low-Latency Inference)

Gradio demo

This repository provides code for the Knowledge Boosting architecture proposed in the paper Knowledge Boosting During Low-Latency Inference, presented at Interspeech 2024. Knowledge Boosting is a technique that enhances the performance of a small on-device model during inference with assistance from a large model running on a remote device: the large model provides time-delayed hints to the small model at inference time.

kb-video-preview.mp4

Architecture

Our system architecture. The green arrow is present only during large-model pre-training. The red arrows are present during knowledge boosting. The black arrows are present during both pre-training and knowledge boosting. The TF-GridNet model is used to demonstrate results and is the model documented in this repository.

kb-animation

Setup

# Commands in all sections are run from the repo's top level directory
conda create --name kb python=3.9
conda activate kb
pip install -r requirements.txt
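
A minimal sanity check that the environment is ready, assuming requirements.txt pulls in PyTorch (the checkpoints referenced below are PyTorch .pt files):

# Verify that PyTorch imports inside the kb environment
python -c "import torch; print(torch.__version__)"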

Training and Evaluation

Dataset

We use Zenodo to host our datasets. You can access the different datasets below (download both parts 1 and 2 for a given dataset). Each dataset contains a train, validation, and test partition.

Create the data directory and untar the data from Zenodo. The example below is for target speaker extraction; replace 'tse' with 'ss' for source separation:

# Create data directory
mkdir data

Download all tarballs from the datasets specified above.

# Assemble the train dataset tarball from its parts
cat kb-tse-dataset-train-part*.tar > data/kb-tse-dataset-train.tar

Untar the datasets from Zenodo into the data directory.

cd data
tar -xvf kb-tse-dataset-train.tar -C .
tar -xvf kb-tse-dataset-val.tar -C .
tar -xvf kb-tse-dataset-test.tar -C .
cd ..
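
A quick check that extraction produced the expected partitions (the exact directory names inside the tarballs may differ):

# The data directory should now contain the train, validation, and test partitions
ls data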

Training

You can either run the baseline configurations (which train the large and small models separately before joint training) or run the joint configurations. These configurations are under configs/baselines, and under configs/TSE_joint or configs/SS_joint depending on the task.

Note that the joint configurations specifically require the big_model_init_ckpt argument, which is a PyTorch (.pt) model checkpoint. You may generate your own by training the provided baseline configurations, or use our model checkpoints (TSE, SS).

# Usage: trainer.py [-h] --config CONFIG --run_dir RUN_DIR [--resume] [--ckpt CKPT] [--test]
python -m src.trainer --run_dir <NAME OF DIR TO LOG RUNS> --config <configs/PATH TO CONFIG.json>
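
For example, the optional flags from the usage string above can be combined as follows. The config path, run directory, and checkpoint path are placeholders, and the comments describe the typical intent of each flag; the exact behavior is defined in src/trainer.py:

# Start a new training run
python -m src.trainer --config configs/baselines/<CONFIG>.json --run_dir <RUN_DIR>

# Resume a previous run in the same run directory
python -m src.trainer --config configs/baselines/<CONFIG>.json --run_dir <RUN_DIR> --resume

# Evaluate a trained checkpoint instead of training
python -m src.trainer --config configs/baselines/<CONFIG>.json --run_dir <RUN_DIR> --test --ckpt <PATH TO CHECKPOINT.pt>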

Citation

@misc{srinivas2024knowledgeboosting,
  title={Knowledge boosting during low-latency inference}, 
  author={Vidya Srinivas and Malek Itani and Tuochao Chen and Emre Sefik Eskimez and Takuya Yoshioka and Shyamnath Gollakota},
  year={2024},
  eprint={2407.11055},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}