This repository contains the implementation code for our preprint Aggregating Residue-Level Protein Language Model Embeddings with Optimal Transport, which showcases the benefits of sliced-Wasserstein embedding to summarize token-level representations produced by pre-trained ESM-2 protein language models (PLMs).
Protein language models (PLMs) have emerged as powerful approaches for mapping protein sequences into informative embeddings suitable for a range of applications. PLMs, as well as many other protein representation schemes, generate per-token (i.e., per-residue) representations, leading to variable-sized outputs based on protein length. This variability presents a challenge for protein-level prediction tasks, which require uniform-sized embeddings for consistent analysis across different proteins. Prior work has typically resorted to average pooling to summarize token-level PLM outputs. It is, however, unclear if such an aggregation operation effectively prioritizes the relevant information across token-level representations. Addressing this, we introduce a novel method utilizing sliced-Wasserstein embeddings to convert variable-length PLM outputs into fixed-length protein-level representations. Inspired by the success of optimal transport techniques in representation learning, we first conceptualize per-token PLM outputs as samples from a probabilistic distribution. We then employ sliced-Wasserstein distances to map these samples against a learnable reference set, creating a Euclidean embedding in the output space. The resulting embedding is agnostic to the length of the input and represents the entire protein. Across a range of state-of-the-art pre-trained ESM-2 PLMs, with varying model sizes, we show the superiority of our method over average pooling for protein-drug and protein-protein interaction. Our aggregation scheme is especially effective when model size is constrained, enabling smaller-scale PLMs to match or exceed the performance of average-pooled larger-scale PLMs. Since using smaller models reduces computational resource requirements, our approach not only promises more accurate inference but can also help democratize access to foundation models.
Create a new virtual Conda environment, called plm_swe
, with the required libraries using the following commands:
conda create -n plm_swe python=3.9
conda activate plm_swe
pip install -r requirements.txt
The download_data.py
script can be used to download the datasets for the experiments into a new folder called datasets
by running
python download_data.py --to datasets --benchmarks davis bindingdb ppi_gold
The following commands can be used to run the four numerical experiments in the paper. The number of points in the reference set, number of slices, and the size of the ESM-2 pre-trained PLM can be adjusted via command-line parameters, as well as the configuration files under config
.
python run_dti.py --run-id dti_davis_swepooling_128refpoints_128slices_esm2_8m --config config/dti_davis.yaml --num-ref-points 128 --num-slices 128 --target-model-type esm2_t6_8M_UR50D
python run_dti.py --run-id dti_bindingdb_swepooling_128refpoints_128slices_esm2_8m --config config/dti_bindingdb.yaml --num-ref-points 128 --num-slices 128 --target-model-type esm2_t6_8M_UR50D
python run_dti.py --run-id dti_tdc_dg_swepooling_128refpoints_128slices_esm2_8m --config config/dti_tdc_dg.yaml --num-ref-points 128 --num-slices 128 --target-model-type esm2_t6_8M_UR50D
python run_ppi.py --run-id ppi_gold_swepooling_128refpoints_128slices_esm2_8m --config config/ppi_gold.yaml --num-ref-points 128 --num-slices 128 --target-model-type esm2_t6_8M_UR50D
This repository is built upon the following GitHub repositories:
If you make use of this repository, please cite our preprint using the following BibTeX format:
@article{naderializadeh2024_plm_swe,
title={Aggregating Residue-Level Protein Language Model Embeddings with Optimal Transport},
author={NaderiAlizadeh, Navid and Singh, Rohit},
journal={bioRxiv},
year={2024},
publisher={Cold Spring Harbor Laboratory}
}