Skip to content

Commit

Permalink
added implementation to calculate embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
haeussma committed Oct 14, 2024
1 parent a1857d8 commit e6b469d
Show file tree
Hide file tree
Showing 5 changed files with 1,289 additions and 14 deletions.
54 changes: 54 additions & 0 deletions pyeed/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import gc

import torch
from transformers import EsmModel, EsmTokenizer


def _select_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # is_available() is used instead of is_built(): a torch binary can be
    # compiled with MPS support on a machine that has no usable MPS device.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def get_batch_embeddings(
    sequences: list[str], batch_size: int = 16
) -> list[torch.Tensor]:
    """Compute one mean-pooled ESM2 embedding per input sequence.

    The sequences are tokenized and passed through the
    ``facebook/esm2_t33_650M_UR50D`` model in batches. For every sequence the
    last hidden state is averaged over all positions that the tokenizer's
    attention mask marks as valid, so padding does not contribute.
    NOTE(review): the mask presumably also covers special tokens added by the
    tokenizer, so those are part of the mean -- confirm if that is intended.

    Args:
        sequences: Sequences to embed, as plain strings.
        batch_size: Number of sequences processed per forward pass.

    Returns:
        A list of 1-D CPU tensors, in the same order as ``sequences``.
        An empty input yields an empty list without loading the model.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")
    if not sequences:
        # Avoid loading the model when there is nothing to embed.
        return []

    # Load the ESM2 model and tokenizer
    model_name = "facebook/esm2_t33_650M_UR50D"
    model = EsmModel.from_pretrained(model_name)
    tokenizer = EsmTokenizer.from_pretrained(model_name)

    device = _select_device()
    model = model.to(device)
    model.eval()

    embedding_list = []
    with torch.no_grad():
        # Process sequences in batches
        for start in range(0, len(sequences), batch_size):
            batch = sequences[start : start + batch_size]

            # Tokenize the input sequences (must be a list of strings)
            inputs = tokenizer(
                batch, padding=True, truncation=True, return_tensors="pt"
            ).to(device)

            hidden_states = model(**inputs).last_hidden_state
            valid_token_masks = inputs["attention_mask"].bool()

            # Mean-pool each sequence over its valid (non-padding) positions
            # and move the result to the CPU so accelerator memory is released.
            for token_embeddings, valid_mask in zip(hidden_states, valid_token_masks):
                embedding_list.append(token_embeddings[valid_mask].mean(dim=0).cpu())

    return embedding_list


def free_memory():
    """Run Python garbage collection and empty torch's accelerator caches.

    The MPS and CUDA caches are each emptied only when that backend is
    actually available at runtime; on a CPU-only machine this reduces to a
    plain ``gc.collect()``.
    """
    gc.collect()  # Python garbage collection
    # is_available() rather than is_built(): a binary built with MPS support
    # may still run on a machine without a usable MPS device.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# When this module is executed as a script, it only calls free_memory().
if __name__ == "__main__":
    free_memory()
6 changes: 6 additions & 0 deletions pyeed/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ def _class_properties(cls):

def save(self, *args, **kwargs):
"""Validates the properties and then saves the node."""
allowed_properties = self.__class__._class_properties()

# Only validate properties defined in the model schema
for field, prop in self.__dict__.items():
if field not in allowed_properties:
continue # Skip non-class properties (like internal Neo4j fields)

if prop is None or callable(prop):
continue

Expand Down
24 changes: 24 additions & 0 deletions pyeed/pyeed.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio

import nest_asyncio
from loguru import logger

from pyeed.adapter.primary_db_adapter import PrimaryDBAdapter
from pyeed.adapter.uniprot_mapper import UniprotToPyeed
from pyeed.dbconnect import DatabaseConnector
from pyeed.embedding import free_memory, get_batch_embeddings
from pyeed.model import Protein


class Pyeed:
Expand Down Expand Up @@ -45,6 +48,27 @@ def fetch_from_primary_db(self, ids: list[str]):

asyncio.run(adapter.make_request())

def calculate_sequence_embeddings(self):
    """
    Calculates embeddings for all sequences in the database that do not have embeddings.

    Proteins whose ``embedding`` property is null are fetched once, their
    sequences are embedded with ``get_batch_embeddings``, and each resulting
    embedding is stored on the protein as a list and saved.
    ``free_memory`` is called afterwards, including when embedding or saving
    raises.

    Raises:
        ValueError: If the number of returned embeddings differs from the
            number of proteins fetched.
    """
    # Materialize the query once: the node set is lazy, so iterating it
    # several times would re-query the database and could return the nodes
    # in a different order or number between passes.
    proteins = Protein.nodes.filter(embedding__isnull=True).all()
    logger.debug(f"Found {len(proteins)} proteins without embeddings.")
    if not proteins:
        # Nothing to do; skip loading the embedding model entirely.
        return

    sequences = [protein.sequence for protein in proteins]

    logger.debug(f"Calculating embeddings for {len(sequences)} sequences.")
    try:
        embeddings = get_batch_embeddings(sequences)
        if len(embeddings) != len(proteins):
            raise ValueError(
                "Number of embeddings does not match number of proteins."
            )

        for protein, embedding in zip(proteins, embeddings):
            protein.embedding = embedding.tolist()
            protein.save()
    finally:
        # Always release cached memory, even if a step above failed.
        free_memory()


if __name__ == "__main__":
eedb = Pyeed("bolt://127.0.0.1:7687")
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pyeed"
version = "0.4.0"
version = "0.4.1"
description = "Toolkit to create, annotate, and analyze sequence data"
authors = ["haeussma <83341109+haeussma@users.noreply.github.com>"]
license = "MIT"
Expand Down Expand Up @@ -28,6 +28,10 @@ bio = "^1.7.1"
loguru = "^0.7.2"
neomodel = "^5.3.3"
shapely = "^2.0.6"
torch = "^2.4.1"
transformers = "^4.45.2"
scikit-learn = "^1.5.2"
numpy = "^2.1.2"

[tool.poetry.group.dev.dependencies]
mkdocs-material = "^9.5.9"
Expand Down
1,213 changes: 1,200 additions & 13 deletions test.ipynb

Large diffs are not rendered by default.

0 comments on commit e6b469d

Please sign in to comment.