Dev #8
Merged 2 commits on Jun 12, 2023
9 changes: 6 additions & 3 deletions embedders/prottrans.py
@@ -41,22 +41,25 @@ def main_prottrans(df: pd.DataFrame, args: argparse.Namespace, iterator: List[sl
for batch_id_filename, batchslice in tqdm(enumerate(iterator), total=len(iterator)):
seqlist = seqlist_all[batchslice]
lenlist = lenlist_all[batchslice]
# add a space character between all residues
# this is mandatory for pt5 embedders
seqlist = [' '.join(list(seq)) for seq in seqlist]
batch_index = list(range(batchslice.start, batchslice.stop))
ids = tokenizer.batch_encode_plus(seqlist, add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids['input_ids']).to(device, non_blocking=True)
attention_mask = torch.tensor(ids['attention_mask']).to(device, non_blocking=True)
with torch.no_grad():
embeddings = model(input_ids=input_ids, attention_mask=attention_mask)
embeddings = embeddings[0].float().cpu()
embeddings = embeddings.last_hidden_state.float().cpu()
# remove sequence padding
num_batch_embeddings = len(embeddings)
assert num_batch_embeddings == len(seqlist)
embeddings_filt = []
for i in range(num_batch_embeddings):
seq_len = lenlist[i]
emb = embeddings[i]
if emb.shape[1] < seq_len:
raise KeyError(f'sequence is longer than embedding {emb.shape[1]} and {seq_len}')
if emb.shape[0] < seq_len:
raise KeyError(f'sequence is longer than embedding {emb.shape} and {seq_len}')
embeddings_filt.append(emb[:seq_len])
# store each batch depending on save mode
if args.asdir:
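An aside on the two details this hunk touches, as a minimal sketch rather than the repo's code: pt5-style tokenizers expect residues separated by spaces, and the padded encoder output must be trimmed back to each sequence's true length (which is what the shape[0] check above enforces). Random tensors stand in for a real model; the sequences are hypothetical.

import torch

seqs = ["MKTAYIAK", "GAVLK"]                        # hypothetical input sequences
spaced = [' '.join(list(s)) for s in seqs]          # "M K T A Y I A K", "G A V L K"
lengths = [len(s) for s in seqs]

# stand-in for model(...).last_hidden_state: [batch, padded_len, emb_dim]
hidden = torch.randn(len(seqs), max(lengths) + 2, 16)

trimmed = []
for emb, seq_len in zip(hidden, lengths):
    # emb.shape[0] is the padded residue axis, so it must cover the sequence
    if emb.shape[0] < seq_len:
        raise ValueError(f'sequence is longer than embedding {emb.shape} and {seq_len}')
    trimmed.append(emb[:seq_len])                   # keep only per-residue rows

assert all(t.shape[0] == n for t, n in zip(trimmed, lengths))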
136 changes: 70 additions & 66 deletions scripts/run_plm_blast.py
@@ -7,10 +7,11 @@
import concurrent
import itertools
import datetime
from typing import List

import pandas as pd
import numpy as np
import torch
from torch.nn.functional import avg_pool1d
from tqdm import tqdm
from Bio.Align import substitution_matrices
blosum62 = substitution_matrices.load("BLOSUM62")
@@ -52,9 +53,9 @@ def get_parser():
type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument('-cosine_cutoff', help='Cosine similarity cut-off (0..1)',
type=range01, default=None, dest='COS_SIM_CUT')
type=range01, default=0.2, dest='COS_SIM_CUT')
group.add_argument('-cosine_percentile_cutoff', help='Cosine similarity percentile cut-off (0-100)',
type=range0100, default=None, dest='COS_PER_CUT')
type=range0100, default=99, dest='COS_PER_CUT')
parser.add_argument('-alignment_cutoff', help='Alignment score cut-off (default: %(default)s)',
type=float, default=0.4, dest='ALN_CUT')
parser.add_argument('-sigma_factor', help='The Sigma factor defines the greediness of the local alignment search procedure. Values <1 may result in longer alignments (default: %(default)s)',
@@ -85,7 +86,6 @@ def get_parser():
assert args.MAX_TARGETS > 0
assert args.MAX_WORKERS > 0, 'At least one CPU core is needed!'
assert args.COS_SIM_CUT != None or args.COS_PER_CUT != None, 'Please define COS_PER_CUT _or_ COS_SIM_CUT!'

return args
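One behavioural note on the new cut-off defaults above: argparse applies the defaults of a mutually exclusive group even when neither flag is passed, so COS_SIM_CUT and COS_PER_CUT are now always set and the assert above can no longer fail. A standalone sketch (not repo code) showing this:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument('-cosine_cutoff', type=float, default=0.2, dest='COS_SIM_CUT')
group.add_argument('-cosine_percentile_cutoff', type=float, default=99, dest='COS_PER_CUT')

args = parser.parse_args([])                 # no flags supplied on the command line
print(args.COS_SIM_CUT, args.COS_PER_CUT)    # 0.2 99 -> both defaults are populated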

def check_cohesion(frame, filedict, embeddings, truncate=600):
@@ -154,94 +154,98 @@ def calc_ident(s1, s2):

# read query
query_index = args.query + '.csv'
query_emb = args.query + '.pt_emb.p'

query_embs = args.query + '.pt_emb.p'
query_df = pd.read_csv(query_index)
query_emb = torch.load(query_emb)[0]
# select embeddings pool factor
if args.EMB_POOL == 1:
emb_type = ''
elif args.EMB_POOL == 2:
emb_type = '.512'
elif args.EMB_POOL == 4:
emb_type = '.256'
else:
raise ValueError(f'invalid EMB_POOL value: {args.EMB_POOL}')

query_seq = str(query_df.iloc[0].sequence)
query_embs = torch.load(query_embs)
if query_df.shape[0] != len(query_embs):
raise ValueError(f'number of embeddings and sequences differ: {query_df.shape[0]} != {len(query_embs)}')
if args.verbose:
print(f'query sequences: {query_df.shape[0]}')
query_seqs = query_df['sequence'].tolist()
query_seqs: List[str] = [str(seq) for seq in query_seqs]

##########################################################################
# filtering #
##########################################################################
query_filedict = dict()
if args.use_chunkcs:
if args.verbose:
print('chunk cosine similarity screening ...')
query_emb_chunkcs = torch.nn.functional.avg_pool1d(query_emb.unsqueeze(0), 16).squeeze()
##########################################################################
# fixed #
##########################################################################
query_emb_chunkcs = [
avg_pool1d(emb.unsqueeze(0), 16).squeeze() for emb in query_embs]
dbfile = os.path.join(args.db, 'emb.64')
filelist = [os.path.join(args.db, f'{f}.emb') for f in range(0, db_df.shape[0])]
embedding_list = torch.load(dbfile)
filedict = ds.local.chunk_cosine_similarity(query = query_emb_chunkcs,
targets = embedding_list,
quantile = args.COS_PER_CUT/100,
dataset_files = filelist,
stride = 10)
for i, emb in enumerate(query_emb_chunkcs):
filedict = ds.local.chunk_cosine_similarity(query=emb,
targets=embedding_list,
quantile=args.COS_PER_CUT/100,
dataset_files=filelist,
stride=10)
query_filedict[i] = filedict
else:
if args.verbose:
print('cosine similarity screening ...')
filedict = ds.load_and_score_database(query_emb,
dbpath = args.db,
quantile = args.COS_PER_CUT/100,
num_workers = args.MAX_WORKERS)
filedict = { k : v.replace('.emb.sum', f'.emb{emb_type}') for k, v in filedict.items()}
for i, emb in enumerate(query_embs):
filedict = ds.load_and_score_database(emb,
dbpath = args.db,
quantile = args.COS_PER_CUT/100,
num_workers = args.MAX_WORKERS)
filedict = { k : v.replace('.emb.sum', f'.emb{emb_type}') for k, v in filedict.items()}
query_filedict[i] = filedict
if args.verbose:
print(f'{len(filedict)} hits after pre-filtering')
print(f'loading per residue embeddings with pool: {emb_type} size: {len(filedict)}')
print('loading per residue embeddings')
filelist = list(filedict.values())
embedding_list = ds.load_full_embeddings(filelist=filelist)

#check_cohesion(db_df, filedict, embedding_list)


if len(filedict) == 0:
print('No hits after pre-filtering. Consider lowering `cosine_cutoff`')
sys.exit(0)
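For orientation, the two poolings in this script act on different axes: the chunk prescreen above averages the embedding dimension down to 64 (matching the emb.64 index file), while the EMB_POOL step just below averages along the residue axis. A small sketch assuming a [seq_len, 1024] per-residue embedding:

import torch
from torch.nn.functional import avg_pool1d

emb = torch.randn(120, 1024)                              # hypothetical [seq_len, emb_dim]
prescreen = avg_pool1d(emb.unsqueeze(0), 16).squeeze(0)   # [120, 64]   embedding dim pooled 16x
pooled = avg_pool1d(emb.T.unsqueeze(0), 2).squeeze(0).T   # [60, 1024]  residue axis pooled (EMB_POOL=2)
print(prescreen.shape, pooled.shape)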

query_emb_pool = torch.nn.functional.avg_pool1d(query_emb.T.unsqueeze(0), args.EMB_POOL).T.squeeze()
query_emb_pool = query_emb_pool.numpy()
query_embs_pool = [
avg_pool1d(emb.T.unsqueeze(0), args.EMB_POOL).T.squeeze() for emb in query_embs]
query_embs_pool = [emb.numpy() for emb in query_embs_pool]
iter_id = 0
records_stack = []
num_indices = len(filedict)
num_indices_per_query = [len(vals) for vals in query_filedict.values()]
batch_size = 20*args.MAX_WORKERS
batch_size = min(300, batch_size)
num_batch = max(math.floor(num_indices/batch_size), 1)
num_batches_per_query = [max(math.ceil(nind/batch_size), 1) for nind in num_indices_per_query]
num_batches = sum(num_batches_per_query)
# Multi-CPU search
print('running plm blast')
with tqdm(total=num_batch) as progress_bar:
for batch_start in range(0, num_batch):
bstart = batch_start*batch_size
bend = bstart + batch_size
# batch indices should not exceed num_indices
bend = min(bend, num_indices)
batchslice = slice(bstart, bend, 1)
filedictslice = itertools.islice(filedict.items(), bstart, bend)
# submit a batch of jobs
# ProcessPoolExecutor may spawn too many processes, which can lead
# to an OS error; batching should fix this issue
job_stack = {}
with concurrent.futures.ProcessPoolExecutor(max_workers = args.MAX_WORKERS) as executor:
for (idx, file), emb in zip(filedictslice, embedding_list[batchslice]):
job = executor.submit(module.full_compare, query_emb_pool, emb, idx, file)
job_stack[job] = iter_id
iter_id += 1
time.sleep(0.1)
for job in concurrent.futures.as_completed(job_stack):
try:
res = job.result()
if len(res) > 0:
records_stack.append(res)
except Exception as e:
raise AssertionError('job not done', e)
progress_bar.update(1)
gc.collect()
with tqdm(total=num_batches) as progress_bar:
for filedict, query_emb, batches in zip(query_filedict.values(), query_embs_pool, num_batches_per_query):
embedding_list = ds.load_full_embeddings(filelist=list(filedict.values()))
num_indices = len(embedding_list)
for batch_start in range(0, batches):
bstart = batch_start*batch_size
bend = bstart + batch_size
# batch indices should not exceed num_indices
bend = min(bend, num_indices)
batchslice = slice(bstart, bend, 1)
filedictslice = itertools.islice(filedict.items(), bstart, bend)
# submit a batch of jobs
# ProcessPoolExecutor may spawn too many processes, which can lead
# to an OS error; batching should fix this issue
job_stack = {}
with concurrent.futures.ProcessPoolExecutor(max_workers = args.MAX_WORKERS) as executor:
for (idx, file), emb in zip(filedictslice, embedding_list[batchslice]):
job = executor.submit(module.full_compare, query_emb, emb, idx, file)
job_stack[job] = iter_id
iter_id += 1
time.sleep(0.1)
for job in concurrent.futures.as_completed(job_stack):
try:
res = job.result()
if len(res) > 0:
records_stack.append(res)
except Exception as e:
raise AssertionError('job not done', e)
progress_bar.update(1)
gc.collect()

resdf = pd.concat(records_stack)
if resdf.score.max() > 1:
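The batched submission loop above keeps the number of live futures bounded so the ProcessPoolExecutor does not queue an unbounded amount of work at once. A stripped-down sketch of the same pattern, with a hypothetical stand-in for module.full_compare:

import concurrent.futures

def full_compare_stub(x):                     # stand-in for module.full_compare
    return x * x

if __name__ == '__main__':
    items = list(range(1000))
    batch_size = 300
    results = []
    for bstart in range(0, len(items), batch_size):
        batch = items[bstart:bstart + batch_size]
        with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
            jobs = [executor.submit(full_compare_stub, x) for x in batch]
            for job in concurrent.futures.as_completed(jobs):
                results.append(job.result())
    assert len(results) == len(items)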
29 changes: 24 additions & 5 deletions tests/_test_embeddings.py
@@ -2,6 +2,9 @@
import subprocess
import pytest

import pandas as pd
import torch as th

DIR = os.path.dirname(__file__)
EMBEDDING_SCRIPT = "embeddings.py"
EMBEDDING_DATA = os.path.join(DIR, "test_data/seq.p")
@@ -10,20 +13,36 @@

@pytest.mark.parametrize("embedder", ["pt", "esm"])
@pytest.mark.parametrize("truncate", ["200", "500"])
def test_embedding_generation(embedder, truncate):
@pytest.mark.parametrize("batchsize", ['16', '0'])
def test_embedding_generation(embedder, truncate, batchsize):
if not os.path.isdir("tests/output"):
os.mkdir("tests/output")
if not os.path.isfile(EMBEDDING_SCRIPT):
raise FileNotFoundError(f'no embedder script in: {EMBEDDING_SCRIPT}')
embdata = pd.read_pickle(EMBEDDING_DATA)
seqlist = embdata['seq'].tolist()
proc = subprocess.run(["python", "embeddings.py",
EMBEDDING_DATA, EMBEDDING_OUTPUT,
"-embedder", embedder,
"--truncate", truncate],
"--truncate", truncate,
"-bs", batchsize],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE)
# check output
# check process error code
assert proc.returncode == 0, proc.stderr
assert proc.stderr, proc.stderr
# check process output file/dir
assert os.path.isfile(EMBEDDING_OUTPUT), f'missing embedding output file, {EMBEDDING_OUTPUT} {proc.stderr}'
# check output consistency
embout = th.load(EMBEDDING_OUTPUT)
assert len(embout) == embdata.shape[0], proc.stderr
# check embedding size of each sequence
for i in range(embdata.shape[0]):
emblen = embout[i].shape
seqlen = len(seqlist[i])
assert emblen[0] == seqlen, f'{emblen[0]} != {seqlen}, emb full shape: {emblen}'
# remove output
os.remove(EMBEDDING_OUTPUT)
assert proc.returncode == 0, proc.stderr



