diff --git a/playground.py b/playground.py
index 1182518..ae013c9 100644
--- a/playground.py
+++ b/playground.py
@@ -1,3 +1,37 @@
+# # %%
+# import numpy as np
+# import torch
+
+# d = torch.load("/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full/data_pro.pt")
+# np.array(list(d['ABL1(F317I)p'].pro_seq))[d['ABL1(F317I)p'].pocket_mask].shape
+
+
+
+# %%
+# building pocket datasets:
+from src.utils.pocket_alignment import pocket_dataset_full
+import shutil
+import os
+
+data_dir = '/cluster/home/t122995uhn/projects/data/'
+db_type = ['kiba', 'davis']
+db_feat = ['nomsa_binary_original_binary', 'nomsa_aflow_original_binary',
+           'nomsa_binary_gvp_binary', 'nomsa_aflow_gvp_binary']
+
+for t in db_type:
+    for f in db_feat:
+        print(f'\n---{t}-{f}---\n')
+        dataset_dir = f"{data_dir}/DavisKibaDataset/{t}/{f}/full"
+        save_dir = f"{data_dir}/v131/DavisKibaDataset/{t}/{f}/full"
+
+        pocket_dataset_full(
+            dataset_dir=dataset_dir,
+            pocket_dir=f"{data_dir}/{t}/",
+            save_dir=save_dir,
+            skip_download=True
+        )
+
+
 #%%
 import pandas as pd
 
@@ -37,45 +71,65 @@ def get_test_oncokbs(train_df=pd.read_csv('/cluster/home/t122995uhn/projects/dat
 
 #%%
-########################################################################
-########################## BUILD DATASETS ##############################
-########################################################################
+##############################################################################
+########################## BUILD/SPLIT DATASETS ##############################
+##############################################################################
 import os
 from src.data_prep.init_dataset import create_datasets
 from src import cfg
 import logging
 cfg.logger.setLevel(logging.DEBUG)
 
-splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
-create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba],
+dbs = [cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba]
+splits = ['davis', 'kiba']
+splits = ['/cluster/home/t122995uhn/projects/MutDTA/splits/' + s for s in splits]
+print(splits)
+
+#%%
+for split, db in zip(splits, dbs):
+    print('\n', split, db)
+    create_datasets(db,
                 feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                 edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
                 ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
                 ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
                 k_folds=5,
-                test_prots_csv=f'{splits}/test.csv',
-                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],)
-                # data_root=os.path.abspath('../data/test/'))
+                test_prots_csv=f'{split}/test.csv',
+                val_prots_csv=[f'{split}/val{i}.csv' for i in range(5)])
 
-# %% Copy splits to commit them:
-#from to:
-import shutil
-from_dir_p = '/cluster/home/t122995uhn/projects/data/v131/'
-to_dir_p = '/cluster/home/t122995uhn/projects/MutDTA/splits/'
-from_db = ['PDBbindDataset', 'DavisKibaDataset/kiba', 'DavisKibaDataset/davis']
-to_db = ['pdbbind', 'kiba', 'davis']
-
-from_db = [f'{from_dir_p}/{f}/nomsa_binary_original_binary/' for f in from_db]
-to_db = [f'{to_dir_p}/{f}' for f in to_db]
-
-for src, dst in zip(from_db, to_db):
-    for x in ['train', 'val']:
-        for i in range(5):
-            print(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")
-            shutil.copy(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")
-
-    print(f"{src}/test/XY.csv", f"{dst}/test.csv")
-    shutil.copy(f"{src}/test/XY.csv", f"{dst}/test.csv")
-
-
+#%% TEST INFERENCE
+from src import cfg
+from src.utils.loader import Loader
+
+# db2 = Loader.load_dataset(cfg.DATA_OPT.davis,
+#                           cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
+#                           path='/cluster/home/t122995uhn/projects/data/',
+#                           subset="full")
+
+db2 = Loader.load_DataLoaders(cfg.DATA_OPT.davis,
+                              cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
+                              path='/cluster/home/t122995uhn/projects/data/v131',
+                              training_fold=0,
+                              batch_train=2)
+for b2 in db2['test']: break
+
+
+# %%
+m = Loader.init_model(cfg.MODEL_OPT.DG, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
+                      dropout=0.3480, output_dim=256,
+                      )
+
+#%%
+# m(b['protein'], b['ligand'])
+m(b2['protein'], b2['ligand'])
+#%%
+model = m
+loaders = db2
+device = 'cpu'
+NUM_EPOCHS = 1
+LEARNING_RATE = 0.001
+from src.train_test.training import train
+
+logs = train(model, loaders['train'], loaders['val'], device,
+             epochs=NUM_EPOCHS, lr_0=LEARNING_RATE)
 # %%
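Note on the TEST INFERENCE cell above: `for b2 in db2['test']: break` grabs the first batch for a quick forward-pass check. A minimal sketch of the same smoke test written out explicitly, assuming (as the cell implies) that `Loader.load_DataLoaders` returns a dict of PyTorch DataLoaders keyed by split and that each batch is a dict with 'protein' and 'ligand' entries:

    # Sketch only: db2 and m are the objects built in the cell above.
    b2 = next(iter(db2['test']))   # first batch, no loop needed
    out = m(b2['protein'], b2['ligand'])
    print(out.shape)
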
diff --git a/src/utils/pocket_alignment.py b/src/utils/pocket_alignment.py
index 8ff30eb..c0bbec9 100644
--- a/src/utils/pocket_alignment.py
+++ b/src/utils/pocket_alignment.py
@@ -9,6 +9,7 @@ from Bio import Align
 from Bio.Align import substitution_matrices
 
+import numpy as np
 import pandas as pd
 import torch
 
@@ -78,26 +79,34 @@ def mask_graph(data, mask: list[bool]):
         additional attributes:
         -pocket_mask : list[bool]
             The mask specified by the mask parameter of dimension [full_sequence_length]
-        -pocket_mask_x : torch.Tensor
+        -x : torch.Tensor
            The nodes of only the pocket of the protein sequence of dimension
            [pocket_sequence_length, num_features]
-        -pocket_mask_edge_index : torch.Tensor
+        -edge_index : torch.Tensor
            The edge connections in COO format only relating to the pocket nodes
            of the protein sequence of dimension [2, num_pocket_edges]
     """
+    # node map for updating edge indices after masking
+    node_map = np.cumsum(mask) - 1
+
     nodes = data.x[mask]
-    edges = data.edge_index
+    edges = []
     edge_mask = []
-    for i in range(edges.shape[1]):
-        # Throw out edges that are connected to at least one node not in the
-        # binding pocket
-        node_1, node_2 = edges[:,i][0], edges[:,i][1]
-        edge_mask.append(True) if mask[node_1] and mask[node_2] else edge_mask.append(False)
-    edges = torch.transpose(torch.transpose(edges, 0, 1)[edge_mask], 0, 1)
+    for i in range(data.edge_index.shape[1]):
+        # Throw out edges that do not connect two nodes inside the pocket
+        node_1, node_2 = data.edge_index[:,i][0], data.edge_index[:,i][1]
+        if mask[node_1] and mask[node_2]:
+            # append the remapped node indices
+            edges.append([node_map[node_1], node_map[node_2]])
+            edge_mask.append(True)
+        else:
+            edge_mask.append(False)
+
+    data.x = nodes
     data.pocket_mask = mask
-    data.pocket_mask_x = nodes
-    data.pocket_mask_edge_index = edges
+    data.edge_index = torch.tensor(edges).T # transpose to COO shape (2, E)
+    if 'edge_weight' in data:
+        data.edge_weight = data.edge_weight[edge_mask]
     return data
 
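The rewritten `mask_graph` hinges on `node_map = np.cumsum(mask) - 1`: the cumulative sum over the boolean mask gives each kept node its new index in the compacted node tensor, so surviving edges can be remapped with a single lookup. A self-contained toy illustration of the trick (not project code; names and values are arbitrary):

    import numpy as np
    import torch

    # Toy graph with 5 nodes; keep nodes 1, 2 and 4.
    mask = [False, True, True, False, True]
    edge_index = torch.tensor([[0, 1, 2],
                               [1, 2, 4]])  # COO edges 0->1, 1->2, 2->4

    # cumsum([0,1,1,0,1]) - 1 = [-1, 0, 1, 1, 2]; entries at masked-out
    # positions are meaningless and are never read.
    node_map = np.cumsum(mask) - 1

    remapped = [[node_map[a], node_map[b]]
                for a, b in edge_index.T.tolist()
                if mask[a] and mask[b]]
    print(torch.tensor(remapped).T)  # tensor([[0, 1], [1, 2]]): edges 1->2 and 2->4 survive

One edge case worth flagging: if no edges survive the mask, `torch.tensor([])` is 1-D rather than shape (2, 0), so downstream code that indexes `edge_index` as (2, E) may need a guard.
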
@@ -122,7 +131,8 @@ def _parse_json(json_path: str) -> str:
 
 def get_dataset_binding_pockets(
     dataset_path: str = 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full',
-    pockets_path: str = 'data/DavisKibaDataset/kiba_pocket'
+    pockets_path: str = 'data/DavisKibaDataset/kiba_pocket',
+    skip_download: bool = False,
 ) -> tuple[dict[str, str], set[str]]:
     """
     Get all binding pocket sequences for a dataset
@@ -149,14 +159,14 @@ def get_dataset_binding_pockets(
     # Strip out mutations and '-(alpha, beta, gamma)' tags if they are present,
     # the binding pocket sequence will be the same for mutated and non-mutated genes
     prot_ids = [id.split('(')[0].split('-')[0] for id in prot_ids]
-    dl = Downloader()
     seq_save_dir = os.path.join(pockets_path, 'pockets')
-    os.makedirs(seq_save_dir, exist_ok=True)
-    download_check = dl.download_pocket_seq(prot_ids, seq_save_dir)
+
+    if not skip_download: # use cached downloads only (useful when on a compute node)
+        dl = Downloader()
+        os.makedirs(seq_save_dir, exist_ok=True)
+        dl.download_pocket_seq(prot_ids, seq_save_dir)
+
     download_errors = set()
-    for key, val in download_check.items():
-        if val == 400:
-            download_errors.add(key)
     sequences = {}
     for file in os.listdir(seq_save_dir):
         pocket_seq = _parse_json(os.path.join(seq_save_dir, file))
@@ -164,6 +174,12 @@ def get_dataset_binding_pockets(
             download_errors.add(file.split('.')[0])
         else:
             sequences[file.split('.')[0]] = pocket_seq
+
+    # add any remaining prots whose pockets were not downloaded
+    for p in prot_ids:
+        if p not in sequences:
+            download_errors.add(p)
+
     return (sequences, download_errors)
 
@@ -197,7 +213,7 @@ def create_binding_pocket_dataset(
         new_data = mask_graph(data, mask)
         new_dataset[id] = new_data
     os.makedirs(os.path.dirname(new_dataset_path), exist_ok=True)
-    torch.save(dataset, new_dataset_path)
+    torch.save(new_dataset, new_dataset_path)
 
 def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_save_path: str):
@@ -215,8 +231,8 @@ def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_
     csv_save_path : str
         The path to save the new CSV file to.
     """
-    df = pd.read_csv(dataset_csv_path)
-    df = df[~df['prot_id'].isin(download_errors)]
+    df = pd.read_csv(dataset_csv_path, index_col=0)
+    df = df[~df.prot_id.str.split('(').str[0].str.split('-').str[0].isin(download_errors)]
 
     os.makedirs(os.path.dirname(csv_save_path), exist_ok=True)
     df.to_csv(csv_save_path)
 
@@ -224,7 +240,8 @@ def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_
 def pocket_dataset_full(
     dataset_dir: str,
     pocket_dir: str,
-    save_dir: str
+    save_dir: str,
+    skip_download: bool = False
 ) -> None:
     """
     Create all elements of a dataset that includes binding pockets. This
@@ -240,7 +257,7 @@ def pocket_dataset_full(
     save_dir : str
         The path to where the new dataset is to be saved
     """
-    pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir)
+    pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir, skip_download)
     print(f'Binding pocket sequences were not found for the following {len(download_errors)} protein IDs:')
     print(','.join(list(download_errors)))
     create_binding_pocket_dataset(
@@ -254,7 +271,9 @@ def pocket_dataset_full(
         download_errors,
         os.path.join(save_dir, 'cleaned_XY.csv')
     )
-    shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt'))
+    if dataset_dir != save_dir:
+        shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt'))
+        shutil.copy2(os.path.join(dataset_dir, 'XY.csv'), os.path.join(save_dir, 'XY.csv'))
 
 if __name__ == '__main__':
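`get_dataset_binding_pockets` and the updated `binding_pocket_filter` now normalize protein IDs with the same rule (strip the '(mutation)' suffix, then the '-tag' suffix) before comparing against `download_errors`. A hypothetical helper that makes the shared rule explicit; `base_prot_id` and all sample IDs except 'ABL1(F317I)p' are illustrative, not part of the codebase:

    import pandas as pd

    def base_prot_id(prot_id: str) -> str:
        # Same rule as the inlined id.split('(')[0].split('-')[0]:
        # 'ABL1(F317I)p' -> 'ABL1'
        return prot_id.split('(')[0].split('-')[0]

    # Equivalent to the pandas chain in binding_pocket_filter:
    #   df.prot_id.str.split('(').str[0].str.split('-').str[0]
    df = pd.DataFrame({'prot_id': ['ABL1(F317I)p', 'PDGFRA-beta', 'EGFR']})
    download_errors = {'PDGFRA'}
    kept = df[~df.prot_id.map(base_prot_id).isin(download_errors)]
    print(kept.prot_id.tolist())  # ['ABL1(F317I)p', 'EGFR']
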
diff --git a/train_test.py b/train_test.py
index 8bc9bba..d98b730 100644
--- a/train_test.py
+++ b/train_test.py
@@ -2,7 +2,22 @@ from src.utils.arg_parse import parse_train_test_args
 
 args, unknown_args = parse_train_test_args(verbose=True,
-                             jyp_args='-m DG -d PDBbind -f nomsa -e binary -bs 64')
+                             jyp_args='--model_opt DG \
+                                       --data_opt davis \
+                                       \
+                                       --feature_opt nomsa \
+                                       --edge_opt binary \
+                                       --ligand_feature_opt original \
+                                       --ligand_edge_opt binary \
+                                       \
+                                       --learning_rate 0.00012 \
+                                       --batch_size 128 \
+                                       --dropout 0.24 \
+                                       --output_dim 128 \
+                                       \
+                                       --train \
+                                       --fold_selection 0 \
+                                       --num_epochs 2000')
 
 FORCE_TRAINING = args.train
 DEBUG = args.debug
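A quick sanity check in the spirit of the commented-out snippet at the top of playground.py: after `pocket_dataset_full` runs, every masked graph should have exactly as many node rows as True entries in its `pocket_mask` (that is how `mask_graph` constructs `data.x`). The path below is the example one from that snippet; substitute your own data_pro.pt:

    import numpy as np
    import torch

    d = torch.load("/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full/data_pro.pt")

    for prot_id, data in d.items():
        # Masked node features must line up with the pocket residues.
        assert data.x.shape[0] == int(np.sum(data.pocket_mask)), prot_id
    print(f"checked {len(d)} proteins")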