
Commit

Merge pull request #8 from labstructbioinf/dev
Dev
Argusmocny authored Jun 12, 2023
2 parents 248cd9b + 52e17cb commit a1cd4c0
Showing 3 changed files with 100 additions and 74 deletions.
9 changes: 6 additions & 3 deletions embedders/prottrans.py
@@ -41,22 +41,25 @@ def main_prottrans(df: pd.DataFrame, args: argparse.Namespace, iterator: List[sl
for batch_id_filename, batchslice in tqdm(enumerate(iterator), total=len(iterator)):
seqlist = seqlist_all[batchslice]
lenlist = lenlist_all[batchslice]
# add a space between all residues
# this is mandatory for pT5 embedders
seqlist = [' '.join(list(seq)) for seq in seqlist]
batch_index = list(range(batchslice.start, batchslice.stop))
ids = tokenizer.batch_encode_plus(seqlist, add_special_tokens=True, padding="longest")
input_ids = torch.tensor(ids['input_ids']).to(device, non_blocking=True)
attention_mask = torch.tensor(ids['attention_mask']).to(device, non_blocking=True)
with torch.no_grad():
embeddings = model(input_ids=input_ids, attention_mask=attention_mask)
embeddings = embeddings[0].float().cpu()
embeddings = embeddings.last_hidden_state.float().cpu()
# remove sequence padding
num_batch_embeddings = len(embeddings)
assert num_batch_embeddings == len(seqlist)
embeddings_filt = []
for i in range(num_batch_embeddings):
seq_len = lenlist[i]
emb = embeddings[i]
if emb.shape[1] < seq_len:
raise KeyError(f'sequence is longer than embedding {emb.shape[1]} and {seq_len}')
if emb.shape[0] < seq_len:
raise KeyError(f'sequence is longer than embedding {emb.shape} and {seq_len}')
embeddings_filt.append(emb[:seq_len])
# store each batch depending on save mode
if args.asdir:
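The hunk above space-separates residues before tokenization (the format the pT5 tokenizer expects) and, because the batch is padded to its longest sequence, slices each returned embedding back to the true sequence length before saving. A minimal sketch of that trimming step, not part of the commit; the tensors are toy stand-ins for real ProtT5 outputs.

# Editor's sketch (illustration only, not part of the diff): trim batch-padded
# embeddings back to each sequence's true length, mirroring emb[:seq_len] above.
import torch

seqlist = ['MKVLA', 'ACD']
lenlist = [len(s) for s in seqlist]               # true sequence lengths
spaced = [' '.join(list(s)) for s in seqlist]     # 'M K V L A' - spacing required by the tokenizer

# pretend the model returned a batch padded to the longest sequence plus one special token
embeddings = torch.zeros(len(seqlist), max(lenlist) + 1, 1024)

embeddings_filt = []
for emb, seq_len in zip(embeddings, lenlist):
    if emb.shape[0] < seq_len:
        raise KeyError(f'sequence is longer than embedding {emb.shape} and {seq_len}')
    embeddings_filt.append(emb[:seq_len])         # keep only real residue positions

print([e.shape for e in embeddings_filt])         # [torch.Size([5, 1024]), torch.Size([3, 1024])]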
136 changes: 70 additions & 66 deletions scripts/run_plm_blast.py
@@ -7,10 +7,11 @@
import concurrent
import itertools
import datetime
from typing import List

import pandas as pd
import numpy as np
import torch
from torch.nn.functional import avg_pool1d
from tqdm import tqdm
from Bio.Align import substitution_matrices
blosum62 = substitution_matrices.load("BLOSUM62")
@@ -52,9 +52,9 @@ def get_parser():
type=str)
group = parser.add_mutually_exclusive_group()
group.add_argument('-cosine_cutoff', help='Cosine similarity cut-off (0..1)',
type=range01, default=None, dest='COS_SIM_CUT')
type=range01, default=0.2, dest='COS_SIM_CUT')
group.add_argument('-cosine_percentile_cutoff', help='Cosine similarity percentile cut-off (0-100)',
type=range0100, default=None, dest='COS_PER_CUT')
type=range0100, default=99, dest='COS_PER_CUT')
parser.add_argument('-alignment_cutoff', help='Alignment score cut-off (default: %(default)s)',
type=float, default=0.4, dest='ALN_CUT')
parser.add_argument('-sigma_factor', help='The Sigma factor defines the greediness of the local alignment search procedure. Values <1 may result in longer alignments (default: %(default)s)',
@@ -85,7 +86,6 @@ def get_parser():
assert args.MAX_TARGETS > 0
assert args.MAX_WORKERS > 0, 'At least one CPU core is needed!'
assert args.COS_SIM_CUT != None or args.COS_PER_CUT != None, 'Please define COS_PER_CUT _or_ COS_SIM_CUT!'

return args

def check_cohesion(frame, filedict, embeddings, truncate=600):
@@ -154,94 +154,98 @@ def calc_ident(s1, s2):

# read query
query_index = args.query + '.csv'
query_emb = args.query + '.pt_emb.p'

query_embs = args.query + '.pt_emb.p'
query_df = pd.read_csv(query_index)
query_emb = torch.load(query_emb)[0]
# select embeddings pool factor
if args.EMB_POOL == 1:
emb_type = ''
elif args.EMB_POOL == 2:
emb_type = '.512'
elif args.EMB_POOL == 4:
emb_type = '.256'
else:
raise ValueError(f'invalid EMB_POOL value: {args.EMB_POOL}')

query_seq = str(query_df.iloc[0].sequence)
query_embs = torch.load(query_embs)
if query_df.shape[0] != len(query_embs):
raise ValueError(f'embedding file and sequence dataframe differ in length: {query_df.shape[0]} != {len(query_embs)}')
if args.verbose:
print(f'query sequences: {query_df.shape[0]}')
query_seqs = query_df['sequence'].tolist()
query_seqs: List[str] = [str(seq) for seq in query_seqs]

##########################################################################
# filtering #
##########################################################################
query_filedict = dict()
if args.use_chunkcs:
if args.verbose:
print('chunk cosine similarity screening ...')
query_emb_chunkcs = torch.nn.functional.avg_pool1d(query_emb.unsqueeze(0), 16).squeeze()
##########################################################################
# fixed #
##########################################################################
query_emb_chunkcs = [
avg_pool1d(emb.unsqueeze(0), 16).squeeze() for emb in query_embs]
dbfile = os.path.join(args.db, 'emb.64')
filelist = [os.path.join(args.db, f'{f}.emb') for f in range(0, db_df.shape[0])]
embedding_list = torch.load(dbfile)
filedict = ds.local.chunk_cosine_similarity(query = query_emb_chunkcs,
targets = embedding_list,
quantile = args.COS_PER_CUT/100,
dataset_files = filelist,
stride = 10)
for i, emb in enumerate(query_emb_chunkcs):
filedict = ds.local.chunk_cosine_similarity(query=emb,
targets=embedding_list,
quantile=args.COS_PER_CUT/100,
dataset_files=filelist,
stride=10)
query_filedict[i] = filedict
else:
if args.verbose:
print('cosine similarity screening ...')
filedict = ds.load_and_score_database(query_emb,
dbpath = args.db,
quantile = args.COS_PER_CUT/100,
num_workers = args.MAX_WORKERS)
filedict = { k : v.replace('.emb.sum', f'.emb{emb_type}') for k, v in filedict.items()}
for i, emb in enumerate(query_embs):
filedict = ds.load_and_score_database(emb,
dbpath = args.db,
quantile = args.COS_PER_CUT/100,
num_workers = args.MAX_WORKERS)
filedict = { k : v.replace('.emb.sum', f'.emb{emb_type}') for k, v in filedict.items()}
query_filedict[i] = filedict
if args.verbose:
print(f'{len(filedict)} hits after pre-filtering')
print(f'loading per residue embeddings with pool: {emb_type} size: {len(filedict)}')
print(f'loading per residue embeddings')
filelist = list(filedict.values())
embedding_list = ds.load_full_embeddings(filelist=filelist)

#check_cohesion(db_df, filedict, embedding_list)


if len(filedict) == 0:
print('No hits after pre-filtering. Consider lowering `cosine_cutoff`')
sys.exit(0)

query_emb_pool = torch.nn.functional.avg_pool1d(query_emb.T.unsqueeze(0), args.EMB_POOL).T.squeeze()
query_emb_pool = query_emb_pool.numpy()
query_embs_pool = [
avg_pool1d(emb.T.unsqueeze(0), args.EMB_POOL).T.squeeze() for emb in query_embs]
query_embs_pool = [emb.numpy() for emb in query_embs_pool]
iter_id = 0
records_stack = []
num_indices = len(filedict)
num_indices_per_query = [len(vals) for vals in query_filedict.values()]
batch_size = 20*args.MAX_WORKERS
batch_size = min(300, batch_size)
num_batch = max(math.floor(num_indices/batch_size), 1)
num_batches_per_query = [max(math.floor(nind/batch_size), 1) for nind in num_indices_per_query]
num_batches = sum(num_batches_per_query)
# Multi-CPU search
print('running plm blast')
with tqdm(total=num_batch) as progress_bar:
for batch_start in range(0, num_batch):
bstart = batch_start*batch_size
bend = bstart + batch_size
# batch indices should not exceed num_indices
bend = min(bend, num_indices)
batchslice = slice(bstart, bend, 1)
filedictslice = itertools.islice(filedict.items(), bstart, bend)
# submit a batch of jobs
# concurrent poolexecutor may spawn too many processes, which would lead
# to an OS error; batching should fix this issue
job_stack = {}
with concurrent.futures.ProcessPoolExecutor(max_workers = args.MAX_WORKERS) as executor:
for (idx, file), emb in zip(filedictslice, embedding_list[batchslice]):
job = executor.submit(module.full_compare, query_emb_pool, emb, idx, file)
job_stack[job] = iter_id
iter_id += 1
time.sleep(0.1)
for job in concurrent.futures.as_completed(job_stack):
try:
res = job.result()
if len(res) > 0:
records_stack.append(res)
except Exception as e:
raise AssertionError('job not done', e)
progress_bar.update(1)
gc.collect()
with tqdm(total=num_batches) as progress_bar:
for filedict, query_emb, batches in zip(query_filedict.values(), query_embs_pool, num_batches_per_query):
embedding_list = ds.load_full_embeddings(filelist=filelist)
num_indices = len(embedding_list)
for batch_start in range(0, batches):
bstart = batch_start*batch_size
bend = bstart + batch_size
# batch indices should not exceed num_indices
bend = min(bend, num_indices)
batchslice = slice(bstart, bend, 1)
filedictslice = itertools.islice(filedict.items(), bstart, bend)
# submit a batch of jobs
# concurrent poolexecutor may spawn too many processes, which would lead
# to an OS error; batching should fix this issue
job_stack = {}
with concurrent.futures.ProcessPoolExecutor(max_workers = args.MAX_WORKERS) as executor:
for (idx, file), emb in zip(filedictslice, embedding_list[batchslice]):
job = executor.submit(module.full_compare, query_emb, emb, idx, file)
job_stack[job] = iter_id
iter_id += 1
time.sleep(0.1)
for job in concurrent.futures.as_completed(job_stack):
try:
res = job.result()
if len(res) > 0:
records_stack.append(res)
except Exception as e:
raise AssertionError('job not done', e)
progress_bar.update(1)
gc.collect()

resdf = pd.concat(records_stack)
if resdf.score.max() > 1:
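The rewritten section above makes the search multi-query: every query embedding is screened against the database (optionally via chunked cosine similarity), pooled along the residue axis with avg_pool1d, and its surviving hits are compared batch by batch in a process pool. A minimal sketch of the pooling and the per-query batching, not part of the commit; the tensors and the contents of query_filedict are illustrative stand-ins, not the plm-blast API.

# Editor's sketch (illustration only, not part of the diff).
import math
import torch
from torch.nn.functional import avg_pool1d

EMB_POOL = 2
# toy per-residue query embeddings: (sequence length, embedding dimension)
query_embs = [torch.randn(120, 1024), torch.randn(80, 1024)]

# pool along the residue axis, as the script does: (L, D) -> (L // EMB_POOL, D)
query_embs_pool = [
    avg_pool1d(emb.unsqueeze(0).transpose(1, 2), EMB_POOL).transpose(1, 2).squeeze(0)
    for emb in query_embs]
print([tuple(e.shape) for e in query_embs_pool])   # [(60, 1024), (40, 1024)]

# hits surviving the cosine pre-filter, one dict per query (toy values)
query_filedict = {0: {3: 'db/3.emb', 7: 'db/7.emb'},
                  1: {1: 'db/1.emb'}}
batch_size = 300

for qid, filedict in query_filedict.items():
    hits = list(filedict.items())
    num_batches = max(math.floor(len(hits) / batch_size), 1)
    for b in range(num_batches):
        batch = hits[b * batch_size:(b + 1) * batch_size]
        # in the real script each (idx, file) pair is submitted to a
        # ProcessPoolExecutor together with the pooled query embedding
        print(qid, [idx for idx, _ in batch])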
29 changes: 24 additions & 5 deletions tests/_test_embeddings.py
@@ -2,6 +2,9 @@
import subprocess
import pytest

import pandas as pd
import torch as th

DIR = os.path.dirname(__file__)
EMBEDDING_SCRIPT = "embeddings.py"
EMBEDDING_DATA = os.path.join(DIR, "test_data/seq.p")
@@ -10,20 +13,36 @@

@pytest.mark.parametrize("embedder", ["pt", "esm"])
@pytest.mark.parametrize("truncate", ["200", "500"])
def test_embedding_generation(embedder, truncate):
@pytest.mark.parametrize("batchsize", ['16', '0'])
def test_embedding_generation(embedder, truncate, batchsize):
if not os.path.isdir("tests/output"):
os.mkdir("tests/output")
if not os.path.isfile(EMBEDDING_SCRIPT):
raise FileNotFoundError(f'no embedder script in: {EMBEDDING_SCRIPT}')
embdata = pd.read_pickle(EMBEDDING_DATA)
seqlist = embdata['seq'].tolist()
proc = subprocess.run(["python", "embeddings.py",
EMBEDDING_DATA, EMBEDDING_OUTPUT,
"-embedder", embedder,
"--truncate", truncate],
"--truncate", truncate,
"-bs", batchsize],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE)
# check output
# check process error code
assert proc.returncode == 0, proc.stderr
assert proc.stderr, proc.stderr
# check process output file/dir
assert os.path.isfile(EMBEDDING_OUTPUT), f'missing embedding output file, {EMBEDDING_OUTPUT} {proc.stderr}'
# check output consistency
embout = th.load(EMBEDDING_OUTPUT)
assert len(embout) == embdata.shape[0], proc.stderr
# check embedding size of each sequence
for i in range(embdata.shape[0]):
emblen = embout[i].shape
seqlen = len(seqlist[i])
assert emblen[0] == seqlen, f'{emblen[0]} != {seqlen}, emb full shape: {emblen}'
# remove output
os.remove(EMBEDDING_OUTPUT)
assert proc.returncode == 0, proc.stderr




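Besides the return code, the test now loads the pickled input and the saved embedding file and checks that there is one embedding per sequence and one row per residue. A minimal sketch of that consistency check, not part of the commit; the dataframe and tensors are toy stand-ins for the real test fixtures.

# Editor's sketch (illustration only, not part of the diff).
import pandas as pd
import torch as th

embdata = pd.DataFrame({'seq': ['MKVL', 'ACDEF']})   # stands in for test_data/seq.p
embout = [th.zeros(4, 1024), th.zeros(5, 1024)]      # stands in for the saved output file

assert len(embout) == embdata.shape[0]
for seq, emb in zip(embdata['seq'], embout):
    assert emb.shape[0] == len(seq), f'{emb.shape[0]} != {len(seq)}, emb full shape: {tuple(emb.shape)}'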

