diff --git a/embedders/prottrans.py b/embedders/prottrans.py index 65e4209..c7e4ed1 100644 --- a/embedders/prottrans.py +++ b/embedders/prottrans.py @@ -41,13 +41,16 @@ def main_prottrans(df: pd.DataFrame, args: argparse.Namespace, iterator: List[sl for batch_id_filename, batchslice in tqdm(enumerate(iterator), total=len(iterator)): seqlist = seqlist_all[batchslice] lenlist = lenlist_all[batchslice] + # add empty character between all residues + # his is mandatory for pt5 embedders + seqlist = [' '.join(list(seq)) for seq in seqlist] batch_index = list(range(batchslice.start, batchslice.stop)) ids = tokenizer.batch_encode_plus(seqlist, add_special_tokens=True, padding="longest") input_ids = torch.tensor(ids['input_ids']).to(device, non_blocking=True) attention_mask = torch.tensor(ids['attention_mask']).to(device, non_blocking=True) with torch.no_grad(): embeddings = model(input_ids=input_ids, attention_mask=attention_mask) - embeddings = embeddings[0].float().cpu() + embeddings = embeddings.last_hidden_state.float().cpu() # remove sequence padding num_batch_embeddings = len(embeddings) assert num_batch_embeddings == len(seqlist) @@ -55,8 +58,8 @@ def main_prottrans(df: pd.DataFrame, args: argparse.Namespace, iterator: List[sl for i in range(num_batch_embeddings): seq_len = lenlist[i] emb = embeddings[i] - if emb.shape[1] < seq_len: - raise KeyError(f'sequence is longer then embedding {emb.shape[1]} and {seq_len} ') + if emb.shape[0] < seq_len: + raise KeyError(f'sequence is longer then embedding {emb.shape} and {seq_len} ') embeddings_filt.append(emb[:seq_len]) # store each batch depending on save mode if args.asdir: diff --git a/scripts/run_plm_blast.py b/scripts/run_plm_blast.py index f999924..9cb1803 100644 --- a/scripts/run_plm_blast.py +++ b/scripts/run_plm_blast.py @@ -7,10 +7,11 @@ import concurrent import itertools import datetime +from typing import List import pandas as pd -import numpy as np import torch +from torch.nn.functional import avg_pool1d from tqdm import tqdm from Bio.Align import substitution_matrices blosum62 = substitution_matrices.load("BLOSUM62") @@ -52,9 +53,9 @@ def get_parser(): type=str) group = parser.add_mutually_exclusive_group() group.add_argument('-cosine_cutoff', help='Cosine similarity cut-off (0..1)', - type=range01, default=None, dest='COS_SIM_CUT') + type=range01, default=0.2, dest='COS_SIM_CUT') group.add_argument('-cosine_percentile_cutoff', help='Cosine similarity percentile cut-off (0-100)', - type=range0100, default=None, dest='COS_PER_CUT') + type=range0100, default=99, dest='COS_PER_CUT') parser.add_argument('-alignment_cutoff', help='Alignment score cut-off (default: %(default)s)', type=float, default=0.4, dest='ALN_CUT') parser.add_argument('-sigma_factor', help='The Sigma factor defines the greediness of the local alignment search procedure. Values <1 may result in longer alignments (default: %(default)s)', @@ -85,7 +86,6 @@ def get_parser(): assert args.MAX_TARGETS > 0 assert args.MAX_WORKERS > 0, 'At least one CPU core is needed!' assert args.COS_SIM_CUT != None or args.COS_PER_CUT != None, 'Please define COS_PER_CUT _or_ COS_SIM_CUT!' - return args def check_cohesion(frame, filedict, embeddings, truncate=600): @@ -154,50 +154,49 @@ def calc_ident(s1, s2): # read query query_index = args.query + '.csv' -query_emb = args.query + '.pt_emb.p' - +query_embs = args.query + '.pt_emb.p' query_df = pd.read_csv(query_index) -query_emb = torch.load(query_emb)[0] -# select embeddings pool factor -if args.EMB_POOL == 1: - emb_type = '' -elif args.EMB_POOL == 2: - emb_type = '.512' -elif args.EMB_POOL == 4: - emb_type = '.256' -else: - raise ValueError(f'invalid EMB_POOL value: {args.EMB_POOL}') - -query_seq = str(query_df.iloc[0].sequence) +query_embs = torch.load(query_embs) +if query_df.shape[0] != len(query_embs): + raise ValueError(f'length of embedding file and sequences df is different {query_df.shape[0]} != {len(query_emb)}') +if args.verbose: + print(f'query sequences: {query_df.shape[0]}') +query_seqs = query_df['sequence'].tolist() +query_seqs: List[str]= [str(seq) for seq in query_seqs] +########################################################################## +# filtering # +########################################################################## +query_filedict = dict() if args.use_chunkcs: if args.verbose: print('chunk cosine similarity screening ...') - query_emb_chunkcs = torch.nn.functional.avg_pool1d(query_emb.unsqueeze(0), 16).squeeze() - ########################################################################## - # fixed # - ########################################################################## + query_emb_chunkcs = [ + avg_pool1d(emb.unsqueeze(0), 16).squeeze() for emb in query_embs] dbfile = os.path.join(args.db, 'emb.64') filelist = [os.path.join(args.db, f'{f}.emb') for f in range(0, db_df.shape[0])] embedding_list = torch.load(dbfile) - filedict = ds.local.chunk_cosine_similarity(query = query_emb_chunkcs, - targets = embedding_list, - quantile = args.COS_PER_CUT/100, - dataset_files = filelist, - stride = 10) + for i, emb in enumerate(query_emb_chunkcs): + filedict = ds.local.chunk_cosine_similarity(query=emb, + targets=embedding_list, + quantile=args.COS_PER_CUT/100, + dataset_files=filelist, + stride=10) + query_filedict[i] = filedict else: if args.verbose: print('cosine similarity screening ...') - filedict = ds.load_and_score_database(query_emb, - dbpath = args.db, - quantile = args.COS_PER_CUT/100, - num_workers = args.MAX_WORKERS) - filedict = { k : v.replace('.emb.sum', f'.emb{emb_type}') for k, v in filedict.items()} + for i, emb in enumerate(query_embs): + filedict = ds.load_and_score_database(emb, + dbpath = args.db, + quantile = args.COS_PER_CUT/100, + num_workers = args.MAX_WORKERS) + filedict = { k : v.replace('.emb.sum', f'.emb{emb_type}') for k, v in filedict.items()} + query_filedict[i] = filedict if args.verbose: - print(f'{len(filedict)} hits after pre-filtering') - print(f'loading per residue embeddings with pool: {emb_type} size: {len(filedict)}') + print(f'loading per residue embeddings') filelist = list(filedict.values()) -embedding_list = ds.load_full_embeddings(filelist=filelist) + #check_cohesion(db_df, filedict, embedding_list) @@ -205,43 +204,48 @@ def calc_ident(s1, s2): print('No hits after pre-filtering. Consider lowering `cosine_cutoff`') sys.exit(0) -query_emb_pool = torch.nn.functional.avg_pool1d(query_emb.T.unsqueeze(0), args.EMB_POOL).T.squeeze() -query_emb_pool = query_emb_pool.numpy() +query_embs_pool = [ + avg_pool1d(emb.T.unsqueeze(0), args.EMB_POOL).T.squeeze() for emb in query_embs] +query_embs_pool = [emb.numpy() for emb in query_embs_pool] iter_id = 0 records_stack = [] -num_indices = len(filedict) +num_indices_per_query = [len(vals) for vals in query_filedict.values()] batch_size = 20*args.MAX_WORKERS batch_size = min(300, batch_size) -num_batch = max(math.floor(num_indices/batch_size), 1) +num_batches_per_query = [max(math.floor(nind/batch_size), 1) for nind in num_indices_per_query] +num_batches = sum(num_batches_per_query) # Multi-CPU search print('running plm blast') -with tqdm(total=num_batch) as progress_bar: - for batch_start in range(0, num_batch): - bstart = batch_start*batch_size - bend = bstart + batch_size - # batch indices should not exeed num_indices - bend = min(bend, num_indices) - batchslice = slice(bstart, bend, 1) - filedictslice = itertools.islice(filedict.items(), bstart, bend) - # submit a batch of jobs - # concurrent poolexecutor may spawn to many processes which will lead - # to OS error batching should fix this issue - job_stack = {} - with concurrent.futures.ProcessPoolExecutor(max_workers = args.MAX_WORKERS) as executor: - for (idx, file), emb in zip(filedictslice, embedding_list[batchslice]): - job = executor.submit(module.full_compare, query_emb_pool, emb, idx, file) - job_stack[job] = iter_id - iter_id += 1 - time.sleep(0.1) - for job in concurrent.futures.as_completed(job_stack): - try: - res = job.result() - if len(res) > 0: - records_stack.append(res) - except Exception as e: - raise AssertionError('job not done', e) - progress_bar.update(1) - gc.collect() +with tqdm(total=num_batches) as progress_bar: + for filedict, query_emb, batches in zip(query_filedict.values(), query_embs_pool, num_batches_per_query): + embedding_list = embedding_list = ds.load_full_embeddings(filelist=filelist) + num_indices = len(embedding_list) + for batch_start in range(0, batches): + bstart = batch_start*batch_size + bend = bstart + batch_size + # batch indices should not exeed num_indices + bend = min(bend, num_indices) + batchslice = slice(bstart, bend, 1) + filedictslice = itertools.islice(filedict.items(), bstart, bend) + # submit a batch of jobs + # concurrent poolexecutor may spawn to many processes which will lead + # to OS error batching should fix this issue + job_stack = {} + with concurrent.futures.ProcessPoolExecutor(max_workers = args.MAX_WORKERS) as executor: + for (idx, file), emb in zip(filedictslice, embedding_list[batchslice]): + job = executor.submit(module.full_compare, query_emb, emb, idx, file) + job_stack[job] = iter_id + iter_id += 1 + time.sleep(0.1) + for job in concurrent.futures.as_completed(job_stack): + try: + res = job.result() + if len(res) > 0: + records_stack.append(res) + except Exception as e: + raise AssertionError('job not done', e) + progress_bar.update(1) + gc.collect() resdf = pd.concat(records_stack) if resdf.score.max() > 1: diff --git a/tests/_test_embeddings.py b/tests/_test_embeddings.py index 312bf36..a9336d4 100644 --- a/tests/_test_embeddings.py +++ b/tests/_test_embeddings.py @@ -2,6 +2,9 @@ import subprocess import pytest +import pandas as pd +import torch as th + DIR = os.path.dirname(__file__) EMBEDDING_SCRIPT = "embeddings.py" EMBEDDING_DATA = os.path.join(DIR, "test_data/seq.p") @@ -10,20 +13,36 @@ @pytest.mark.parametrize("embedder", ["pt", "esm"]) @pytest.mark.parametrize("truncate", ["200", "500"]) -def test_embedding_generation(embedder, truncate): +@pytest.mark.parametrize("batchsize", ['16', '0']) +def test_embedding_generation(embedder, truncate, batchsize): if not os.path.isdir("tests/output"): os.mkdir("tests/output") + if not os.path.isfile(EMBEDDING_SCRIPT): + raise FileNotFoundError(f'no embedder script in: {EMBEDDING_SCRIPT}') + embdata = pd.read_pickle(EMBEDDING_DATA) + seqlist = embdata['seq'].tolist() proc = subprocess.run(["python", "embeddings.py", EMBEDDING_DATA, EMBEDDING_OUTPUT, "-embedder", embedder, - "--truncate", truncate], + "--truncate", truncate, + "-bs", batchsize], stderr=subprocess.PIPE, stdout=subprocess.PIPE) - # check output + # chech process error code + assert proc.returncode == 0, proc.stderr + assert proc.stderr, proc.stderr + # check process output file/dir assert os.path.isfile(EMBEDDING_OUTPUT), f'missing embedding output file, {EMBEDDING_OUTPUT} {proc.stderr}' + # check output consistency + embout = th.load(EMBEDDING_OUTPUT) + assert len(embout) == embdata.shape[0], proc.stderr + # check embedding size of each sequence + for i in range(embdata.shape[0]): + emblen = embout[i].shape + seqlen = len(seqlist[i]) + assert emblen[0] == seqlen, f'{emblen[0]} != {seqlen}, emb full shape: {emblen}' + # remove output os.remove(EMBEDDING_OUTPUT) - assert proc.returncode == 0, proc.stderr -